[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / Target / RISCV / RISCVRVVInitUndef.cpp
blob735fc1350c0091559877f5a2aa3bfadee53c8ea3
1 //===- RISCVRVVInitUndef.cpp - Initialize undef vector value to pseudo ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a function pass that initializes undef vector value to
10 // temporary pseudo instruction and remove it in expandpseudo pass to prevent
11 // register allocation resulting in a constraint violated result for vector
12 // instruction. It also rewrites the NoReg tied operand back to an
13 // IMPLICIT_DEF.
15 // RISC-V vector instruction has register overlapping constraint for certain
16 // instructions, and will cause illegal instruction trap if violated, we use
17 // early clobber to model this constraint, but it can't prevent register
18 // allocator allocated same or overlapped if the input register is undef value,
19 // so convert IMPLICIT_DEF to temporary pseudo instruction and remove it later
20 // could prevent that happen, it's not best way to resolve this, and it might
21 // change the order of program or increase the register pressure, so ideally we
22 // should model the constraint right, but before we model the constraint right,
23 // it's the only way to prevent that happen.
25 // When we enable the subregister liveness option, it will also trigger same
26 // issue due to the partial of register is undef. If we pseudoinit the whole
27 // register, then it will generate redundant COPY instruction. Currently, it
28 // will generate INSERT_SUBREG to make sure the whole register is occupied
29 // when program encounter operation that has early-clobber constraint.
32 // See also: https://github.com/llvm/llvm-project/issues/50157
34 // Additionally, this pass rewrites tied operands of vector instructions
35 // from NoReg to IMPLICIT_DEF. (Note that this is a non-overlapping set of
36 // operands to the above.) We use NoReg to side step a MachineCSE
37 // optimization quality problem but need to convert back before
38 // TwoAddressInstruction. See pr64282 for context.
40 //===----------------------------------------------------------------------===//
42 #include "RISCV.h"
43 #include "RISCVSubtarget.h"
44 #include "llvm/ADT/SmallSet.h"
45 #include "llvm/ADT/SmallVector.h"
46 #include "llvm/CodeGen/DetectDeadLanes.h"
47 #include "llvm/CodeGen/MachineFunctionPass.h"
48 using namespace llvm;
50 #define DEBUG_TYPE "riscv-init-undef"
51 #define RISCV_INIT_UNDEF_NAME "RISC-V init undef pass"
53 namespace {
55 class RISCVInitUndef : public MachineFunctionPass {
56 const TargetInstrInfo *TII;
57 MachineRegisterInfo *MRI;
58 const RISCVSubtarget *ST;
59 const TargetRegisterInfo *TRI;
61 // Newly added vregs, assumed to be fully rewritten
62 SmallSet<Register, 8> NewRegs;
63 SmallVector<MachineInstr *, 8> DeadInsts;
65 public:
66 static char ID;
68 RISCVInitUndef() : MachineFunctionPass(ID) {}
69 bool runOnMachineFunction(MachineFunction &MF) override;
71 void getAnalysisUsage(AnalysisUsage &AU) const override {
72 AU.setPreservesCFG();
73 MachineFunctionPass::getAnalysisUsage(AU);
76 StringRef getPassName() const override { return RISCV_INIT_UNDEF_NAME; }
78 private:
79 bool processBasicBlock(MachineFunction &MF, MachineBasicBlock &MBB,
80 const DeadLaneDetector &DLD);
81 bool isVectorRegClass(const Register R);
82 const TargetRegisterClass *
83 getVRLargestSuperClass(const TargetRegisterClass *RC) const;
84 bool handleSubReg(MachineFunction &MF, MachineInstr &MI,
85 const DeadLaneDetector &DLD);
86 bool fixupIllOperand(MachineInstr *MI, MachineOperand &MO);
87 bool handleReg(MachineInstr *MI);
90 } // end anonymous namespace
// Out-of-line definition of the static pass identifier declared in the class.
char RISCVInitUndef::ID = 0;
// Pass registration, using DEBUG_TYPE as the first macro argument and
// RISCV_INIT_UNDEF_NAME as the descriptive name.
INITIALIZE_PASS(RISCVInitUndef, DEBUG_TYPE, RISCV_INIT_UNDEF_NAME, false, false)
// Reference to the pass ID exported in the llvm namespace.
char &llvm::RISCVInitUndefID = RISCVInitUndef::ID;
96 const TargetRegisterClass *
97 RISCVInitUndef::getVRLargestSuperClass(const TargetRegisterClass *RC) const {
98 if (RISCV::VRM8RegClass.hasSubClassEq(RC))
99 return &RISCV::VRM8RegClass;
100 if (RISCV::VRM4RegClass.hasSubClassEq(RC))
101 return &RISCV::VRM4RegClass;
102 if (RISCV::VRM2RegClass.hasSubClassEq(RC))
103 return &RISCV::VRM2RegClass;
104 if (RISCV::VRRegClass.hasSubClassEq(RC))
105 return &RISCV::VRRegClass;
106 return RC;
109 bool RISCVInitUndef::isVectorRegClass(const Register R) {
110 const TargetRegisterClass *RC = MRI->getRegClass(R);
111 return RISCV::VRRegClass.hasSubClassEq(RC) ||
112 RISCV::VRM2RegClass.hasSubClassEq(RC) ||
113 RISCV::VRM4RegClass.hasSubClassEq(RC) ||
114 RISCV::VRM8RegClass.hasSubClassEq(RC);
117 static unsigned getUndefInitOpcode(unsigned RegClassID) {
118 switch (RegClassID) {
119 case RISCV::VRRegClassID:
120 return RISCV::PseudoRVVInitUndefM1;
121 case RISCV::VRM2RegClassID:
122 return RISCV::PseudoRVVInitUndefM2;
123 case RISCV::VRM4RegClassID:
124 return RISCV::PseudoRVVInitUndefM4;
125 case RISCV::VRM8RegClassID:
126 return RISCV::PseudoRVVInitUndefM8;
127 default:
128 llvm_unreachable("Unexpected register class.");
132 static bool isEarlyClobberMI(MachineInstr &MI) {
133 return llvm::any_of(MI.defs(), [](const MachineOperand &DefMO) {
134 return DefMO.isReg() && DefMO.isEarlyClobber();
138 static bool findImplictDefMIFromReg(Register Reg, MachineRegisterInfo *MRI) {
139 for (auto &DefMI : MRI->def_instructions(Reg)) {
140 if (DefMI.getOpcode() == TargetOpcode::IMPLICIT_DEF)
141 return true;
143 return false;
146 bool RISCVInitUndef::handleReg(MachineInstr *MI) {
147 bool Changed = false;
148 for (auto &UseMO : MI->uses()) {
149 if (!UseMO.isReg())
150 continue;
151 if (UseMO.isTied())
152 continue;
153 if (!UseMO.getReg().isVirtual())
154 continue;
155 if (!isVectorRegClass(UseMO.getReg()))
156 continue;
158 if (UseMO.isUndef() || findImplictDefMIFromReg(UseMO.getReg(), MRI))
159 Changed |= fixupIllOperand(MI, UseMO);
161 return Changed;
164 bool RISCVInitUndef::handleSubReg(MachineFunction &MF, MachineInstr &MI,
165 const DeadLaneDetector &DLD) {
166 bool Changed = false;
168 for (MachineOperand &UseMO : MI.uses()) {
169 if (!UseMO.isReg())
170 continue;
171 if (!UseMO.getReg().isVirtual())
172 continue;
173 if (UseMO.isTied())
174 continue;
176 Register Reg = UseMO.getReg();
177 if (NewRegs.count(Reg))
178 continue;
179 DeadLaneDetector::VRegInfo Info =
180 DLD.getVRegInfo(Register::virtReg2Index(Reg));
182 if (Info.UsedLanes == Info.DefinedLanes)
183 continue;
185 const TargetRegisterClass *TargetRegClass =
186 getVRLargestSuperClass(MRI->getRegClass(Reg));
188 LaneBitmask NeedDef = Info.UsedLanes & ~Info.DefinedLanes;
190 LLVM_DEBUG({
191 dbgs() << "Instruction has undef subregister.\n";
192 dbgs() << printReg(Reg, nullptr)
193 << " Used: " << PrintLaneMask(Info.UsedLanes)
194 << " Def: " << PrintLaneMask(Info.DefinedLanes)
195 << " Need Def: " << PrintLaneMask(NeedDef) << "\n";
198 SmallVector<unsigned> SubRegIndexNeedInsert;
199 TRI->getCoveringSubRegIndexes(*MRI, TargetRegClass, NeedDef,
200 SubRegIndexNeedInsert);
202 Register LatestReg = Reg;
203 for (auto ind : SubRegIndexNeedInsert) {
204 Changed = true;
205 const TargetRegisterClass *SubRegClass =
206 getVRLargestSuperClass(TRI->getSubRegisterClass(TargetRegClass, ind));
207 Register TmpInitSubReg = MRI->createVirtualRegister(SubRegClass);
208 BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(),
209 TII->get(getUndefInitOpcode(SubRegClass->getID())),
210 TmpInitSubReg);
211 Register NewReg = MRI->createVirtualRegister(TargetRegClass);
212 BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(),
213 TII->get(TargetOpcode::INSERT_SUBREG), NewReg)
214 .addReg(LatestReg)
215 .addReg(TmpInitSubReg)
216 .addImm(ind);
217 LatestReg = NewReg;
220 UseMO.setReg(LatestReg);
223 return Changed;
226 bool RISCVInitUndef::fixupIllOperand(MachineInstr *MI, MachineOperand &MO) {
228 LLVM_DEBUG(
229 dbgs() << "Emitting PseudoRVVInitUndef for implicit vector register "
230 << MO.getReg() << '\n');
232 const TargetRegisterClass *TargetRegClass =
233 getVRLargestSuperClass(MRI->getRegClass(MO.getReg()));
234 unsigned Opcode = getUndefInitOpcode(TargetRegClass->getID());
235 Register NewReg = MRI->createVirtualRegister(TargetRegClass);
236 BuildMI(*MI->getParent(), MI, MI->getDebugLoc(), TII->get(Opcode), NewReg);
237 MO.setReg(NewReg);
238 if (MO.isUndef())
239 MO.setIsUndef(false);
240 return true;
243 bool RISCVInitUndef::processBasicBlock(MachineFunction &MF,
244 MachineBasicBlock &MBB,
245 const DeadLaneDetector &DLD) {
246 bool Changed = false;
247 for (MachineBasicBlock::iterator I = MBB.begin(); I != MBB.end(); ++I) {
248 MachineInstr &MI = *I;
250 // If we used NoReg to represent the passthru, switch this back to being
251 // an IMPLICIT_DEF before TwoAddressInstructions.
252 unsigned UseOpIdx;
253 if (MI.getNumDefs() != 0 && MI.isRegTiedToUseOperand(0, &UseOpIdx)) {
254 MachineOperand &UseMO = MI.getOperand(UseOpIdx);
255 if (UseMO.getReg() == RISCV::NoRegister) {
256 const TargetRegisterClass *RC =
257 TII->getRegClass(MI.getDesc(), UseOpIdx, TRI, MF);
258 Register NewDest = MRI->createVirtualRegister(RC);
259 // We don't have a way to update dead lanes, so keep track of the
260 // new register so that we avoid querying it later.
261 NewRegs.insert(NewDest);
262 BuildMI(MBB, I, I->getDebugLoc(),
263 TII->get(TargetOpcode::IMPLICIT_DEF), NewDest);
264 UseMO.setReg(NewDest);
265 Changed = true;
269 if (isEarlyClobberMI(MI)) {
270 if (ST->enableSubRegLiveness())
271 Changed |= handleSubReg(MF, MI, DLD);
272 Changed |= handleReg(&MI);
275 return Changed;
278 bool RISCVInitUndef::runOnMachineFunction(MachineFunction &MF) {
279 ST = &MF.getSubtarget<RISCVSubtarget>();
280 if (!ST->hasVInstructions())
281 return false;
283 MRI = &MF.getRegInfo();
284 TII = ST->getInstrInfo();
285 TRI = MRI->getTargetRegisterInfo();
287 bool Changed = false;
288 DeadLaneDetector DLD(MRI, TRI);
289 DLD.computeSubRegisterLaneBitInfo();
291 for (MachineBasicBlock &BB : MF)
292 Changed |= processBasicBlock(MF, BB, DLD);
294 for (auto *DeadMI : DeadInsts)
295 DeadMI->eraseFromParent();
296 DeadInsts.clear();
298 return Changed;
301 FunctionPass *llvm::createRISCVInitUndefPass() { return new RISCVInitUndef(); }