[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / Target / RISCV / RISCVFoldMasks.cpp
blob6ee006525df560f20cb0ed031f0b209ab8709330
1 //===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass performs various peephole optimisations that fold masks into vector
10 // pseudo instructions after instruction selection.
12 // Currently it converts
13 // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14 // ->
15 // PseudoVMV_V_V %false, %true, %vl, %sew
17 //===---------------------------------------------------------------------===//
19 #include "RISCV.h"
20 #include "RISCVISelDAGToDAG.h"
21 #include "RISCVSubtarget.h"
22 #include "llvm/CodeGen/MachineFunctionPass.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetInstrInfo.h"
25 #include "llvm/CodeGen/TargetRegisterInfo.h"
27 using namespace llvm;
29 #define DEBUG_TYPE "riscv-fold-masks"
31 namespace {
33 class RISCVFoldMasks : public MachineFunctionPass {
34 public:
35 static char ID;
36 const TargetInstrInfo *TII;
37 MachineRegisterInfo *MRI;
38 const TargetRegisterInfo *TRI;
39 RISCVFoldMasks() : MachineFunctionPass(ID) {}
41 bool runOnMachineFunction(MachineFunction &MF) override;
42 MachineFunctionProperties getRequiredProperties() const override {
43 return MachineFunctionProperties().set(
44 MachineFunctionProperties::Property::IsSSA);
47 StringRef getPassName() const override { return "RISC-V Fold Masks"; }
49 private:
50 bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef);
51 bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef);
53 bool isAllOnesMask(MachineInstr *MaskDef);
56 } // namespace
58 char RISCVFoldMasks::ID = 0;
60 INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
62 bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) {
63 if (!MaskDef)
64 return false;
65 assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0);
66 Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
67 if (!SrcReg.isVirtual())
68 return false;
69 MaskDef = MRI->getVRegDef(SrcReg);
70 if (!MaskDef)
71 return false;
73 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
74 // undefined behaviour if it's the wrong bitwidth, so we could choose to
75 // assume that it's all-ones? Same applies to its VL.
76 switch (MaskDef->getOpcode()) {
77 case RISCV::PseudoVMSET_M_B1:
78 case RISCV::PseudoVMSET_M_B2:
79 case RISCV::PseudoVMSET_M_B4:
80 case RISCV::PseudoVMSET_M_B8:
81 case RISCV::PseudoVMSET_M_B16:
82 case RISCV::PseudoVMSET_M_B32:
83 case RISCV::PseudoVMSET_M_B64:
84 return true;
85 default:
86 return false;
90 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
91 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
92 bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) {
93 #define CASE_VMERGE_TO_VMV(lmul) \
94 case RISCV::PseudoVMERGE_VVM_##lmul: \
95 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
96 break;
97 unsigned NewOpc;
98 switch (MI.getOpcode()) {
99 default:
100 return false;
101 CASE_VMERGE_TO_VMV(MF8)
102 CASE_VMERGE_TO_VMV(MF4)
103 CASE_VMERGE_TO_VMV(MF2)
104 CASE_VMERGE_TO_VMV(M1)
105 CASE_VMERGE_TO_VMV(M2)
106 CASE_VMERGE_TO_VMV(M4)
107 CASE_VMERGE_TO_VMV(M8)
110 Register MergeReg = MI.getOperand(1).getReg();
111 Register FalseReg = MI.getOperand(2).getReg();
112 // Check merge == false (or merge == undef)
113 if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) !=
114 TRI->lookThruCopyLike(FalseReg, MRI))
115 return false;
117 assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
118 if (!isAllOnesMask(V0Def))
119 return false;
121 MI.setDesc(TII->get(NewOpc));
122 MI.removeOperand(1); // Merge operand
123 MI.tieOperands(0, 1); // Tie false to dest
124 MI.removeOperand(3); // Mask operand
125 MI.addOperand(
126 MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));
128 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
129 // register class for the destination and merge operands e.g. VRNoV0 -> VR
130 MRI->recomputeRegClass(MI.getOperand(0).getReg());
131 MRI->recomputeRegClass(MI.getOperand(1).getReg());
132 return true;
135 bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI,
136 MachineInstr *MaskDef) {
137 const RISCV::RISCVMaskedPseudoInfo *I =
138 RISCV::getMaskedPseudoInfo(MI.getOpcode());
139 if (!I)
140 return false;
142 if (!isAllOnesMask(MaskDef))
143 return false;
145 // There are two classes of pseudos in the table - compares and
146 // everything else. See the comment on RISCVMaskedPseudo for details.
147 const unsigned Opc = I->UnmaskedPseudo;
148 const MCInstrDesc &MCID = TII->get(Opc);
149 const bool HasPolicyOp = RISCVII::hasVecPolicyOp(MCID.TSFlags);
150 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
151 #ifndef NDEBUG
152 const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
153 assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ==
154 RISCVII::hasVecPolicyOp(MCID.TSFlags) &&
155 "Masked and unmasked pseudos are inconsistent");
156 assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
157 #endif
158 (void)HasPolicyOp;
160 MI.setDesc(MCID);
162 // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
163 unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
164 MI.removeOperand(MaskOpIdx);
166 // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
167 // so try and relax it to vr.
168 MRI->recomputeRegClass(MI.getOperand(0).getReg());
169 unsigned PassthruOpIdx = MI.getNumExplicitDefs();
170 if (HasPassthru) {
171 if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
172 MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
173 } else
174 MI.removeOperand(PassthruOpIdx);
176 return true;
179 bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
180 if (skipFunction(MF.getFunction()))
181 return false;
183 // Skip if the vector extension is not enabled.
184 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
185 if (!ST.hasVInstructions())
186 return false;
188 TII = ST.getInstrInfo();
189 MRI = &MF.getRegInfo();
190 TRI = MRI->getTargetRegisterInfo();
192 bool Changed = false;
194 // Masked pseudos coming out of isel will have their mask operand in the form:
196 // $v0:vr = COPY %mask:vr
197 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
199 // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
200 // on each pseudo.
201 MachineInstr *CurrentV0Def;
202 for (MachineBasicBlock &MBB : MF) {
203 CurrentV0Def = nullptr;
204 for (MachineInstr &MI : MBB) {
205 Changed |= convertToUnmasked(MI, CurrentV0Def);
206 Changed |= convertVMergeToVMv(MI, CurrentV0Def);
208 if (MI.definesRegister(RISCV::V0, TRI))
209 CurrentV0Def = &MI;
213 return Changed;
216 FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }