//===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass performs various peephole optimisations that fold masks into vector
// pseudo instructions after instruction selection.
//
// Currently it converts masked pseudos whose mask is all ones into their
// unmasked equivalents, and converts
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
//===---------------------------------------------------------------------===//
#include "RISCVISelDAGToDAG.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-fold-masks"
33 class RISCVFoldMasks
: public MachineFunctionPass
{
36 const TargetInstrInfo
*TII
;
37 MachineRegisterInfo
*MRI
;
38 const TargetRegisterInfo
*TRI
;
39 RISCVFoldMasks() : MachineFunctionPass(ID
) {}
41 bool runOnMachineFunction(MachineFunction
&MF
) override
;
42 MachineFunctionProperties
getRequiredProperties() const override
{
43 return MachineFunctionProperties().set(
44 MachineFunctionProperties::Property::IsSSA
);
47 StringRef
getPassName() const override
{ return "RISC-V Fold Masks"; }
50 bool convertToUnmasked(MachineInstr
&MI
, MachineInstr
*MaskDef
);
51 bool convertVMergeToVMv(MachineInstr
&MI
, MachineInstr
*MaskDef
);
53 bool isAllOnesMask(MachineInstr
*MaskDef
);
58 char RISCVFoldMasks::ID
= 0;
60 INITIALIZE_PASS(RISCVFoldMasks
, DEBUG_TYPE
, "RISC-V Fold Masks", false, false)
62 bool RISCVFoldMasks::isAllOnesMask(MachineInstr
*MaskDef
) {
65 assert(MaskDef
->isCopy() && MaskDef
->getOperand(0).getReg() == RISCV::V0
);
66 Register SrcReg
= TRI
->lookThruCopyLike(MaskDef
->getOperand(1).getReg(), MRI
);
67 if (!SrcReg
.isVirtual())
69 MaskDef
= MRI
->getVRegDef(SrcReg
);
73 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
74 // undefined behaviour if it's the wrong bitwidth, so we could choose to
75 // assume that it's all-ones? Same applies to its VL.
76 switch (MaskDef
->getOpcode()) {
77 case RISCV::PseudoVMSET_M_B1
:
78 case RISCV::PseudoVMSET_M_B2
:
79 case RISCV::PseudoVMSET_M_B4
:
80 case RISCV::PseudoVMSET_M_B8
:
81 case RISCV::PseudoVMSET_M_B16
:
82 case RISCV::PseudoVMSET_M_B32
:
83 case RISCV::PseudoVMSET_M_B64
:
90 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
91 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
92 bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr
&MI
, MachineInstr
*V0Def
) {
93 #define CASE_VMERGE_TO_VMV(lmul) \
94 case RISCV::PseudoVMERGE_VVM_##lmul: \
95 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
98 switch (MI
.getOpcode()) {
101 CASE_VMERGE_TO_VMV(MF8
)
102 CASE_VMERGE_TO_VMV(MF4
)
103 CASE_VMERGE_TO_VMV(MF2
)
104 CASE_VMERGE_TO_VMV(M1
)
105 CASE_VMERGE_TO_VMV(M2
)
106 CASE_VMERGE_TO_VMV(M4
)
107 CASE_VMERGE_TO_VMV(M8
)
110 Register MergeReg
= MI
.getOperand(1).getReg();
111 Register FalseReg
= MI
.getOperand(2).getReg();
112 // Check merge == false (or merge == undef)
113 if (MergeReg
!= RISCV::NoRegister
&& TRI
->lookThruCopyLike(MergeReg
, MRI
) !=
114 TRI
->lookThruCopyLike(FalseReg
, MRI
))
117 assert(MI
.getOperand(4).isReg() && MI
.getOperand(4).getReg() == RISCV::V0
);
118 if (!isAllOnesMask(V0Def
))
121 MI
.setDesc(TII
->get(NewOpc
));
122 MI
.removeOperand(1); // Merge operand
123 MI
.tieOperands(0, 1); // Tie false to dest
124 MI
.removeOperand(3); // Mask operand
126 MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED
));
128 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
129 // register class for the destination and merge operands e.g. VRNoV0 -> VR
130 MRI
->recomputeRegClass(MI
.getOperand(0).getReg());
131 MRI
->recomputeRegClass(MI
.getOperand(1).getReg());
135 bool RISCVFoldMasks::convertToUnmasked(MachineInstr
&MI
,
136 MachineInstr
*MaskDef
) {
137 const RISCV::RISCVMaskedPseudoInfo
*I
=
138 RISCV::getMaskedPseudoInfo(MI
.getOpcode());
142 if (!isAllOnesMask(MaskDef
))
145 // There are two classes of pseudos in the table - compares and
146 // everything else. See the comment on RISCVMaskedPseudo for details.
147 const unsigned Opc
= I
->UnmaskedPseudo
;
148 const MCInstrDesc
&MCID
= TII
->get(Opc
);
149 const bool HasPolicyOp
= RISCVII::hasVecPolicyOp(MCID
.TSFlags
);
150 const bool HasPassthru
= RISCVII::isFirstDefTiedToFirstUse(MCID
);
152 const MCInstrDesc
&MaskedMCID
= TII
->get(MI
.getOpcode());
153 assert(RISCVII::hasVecPolicyOp(MaskedMCID
.TSFlags
) ==
154 RISCVII::hasVecPolicyOp(MCID
.TSFlags
) &&
155 "Masked and unmasked pseudos are inconsistent");
156 assert(HasPolicyOp
== HasPassthru
&& "Unexpected pseudo structure");
162 // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
163 unsigned MaskOpIdx
= I
->MaskOpIdx
+ MI
.getNumExplicitDefs();
164 MI
.removeOperand(MaskOpIdx
);
166 // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
167 // so try and relax it to vr.
168 MRI
->recomputeRegClass(MI
.getOperand(0).getReg());
169 unsigned PassthruOpIdx
= MI
.getNumExplicitDefs();
171 if (MI
.getOperand(PassthruOpIdx
).getReg() != RISCV::NoRegister
)
172 MRI
->recomputeRegClass(MI
.getOperand(PassthruOpIdx
).getReg());
174 MI
.removeOperand(PassthruOpIdx
);
179 bool RISCVFoldMasks::runOnMachineFunction(MachineFunction
&MF
) {
180 if (skipFunction(MF
.getFunction()))
183 // Skip if the vector extension is not enabled.
184 const RISCVSubtarget
&ST
= MF
.getSubtarget
<RISCVSubtarget
>();
185 if (!ST
.hasVInstructions())
188 TII
= ST
.getInstrInfo();
189 MRI
= &MF
.getRegInfo();
190 TRI
= MRI
->getTargetRegisterInfo();
192 bool Changed
= false;
194 // Masked pseudos coming out of isel will have their mask operand in the form:
196 // $v0:vr = COPY %mask:vr
197 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
199 // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
201 MachineInstr
*CurrentV0Def
;
202 for (MachineBasicBlock
&MBB
: MF
) {
203 CurrentV0Def
= nullptr;
204 for (MachineInstr
&MI
: MBB
) {
205 Changed
|= convertToUnmasked(MI
, CurrentV0Def
);
206 Changed
|= convertVMergeToVMv(MI
, CurrentV0Def
);
208 if (MI
.definesRegister(RISCV::V0
, TRI
))
216 FunctionPass
*llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }