1 //===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===---------------------------------------------------------------------===//
// This pass performs various peephole optimisations that fold masks into vector
// pseudo instructions after instruction selection.
//
// Currently it converts
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
17 //===---------------------------------------------------------------------===//
20 #include "RISCVSubtarget.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/TargetInstrInfo.h"
24 #include "llvm/CodeGen/TargetRegisterInfo.h"
28 #define DEBUG_TYPE "riscv-fold-masks"
32 class RISCVFoldMasks
: public MachineFunctionPass
{
35 const TargetInstrInfo
*TII
;
36 MachineRegisterInfo
*MRI
;
37 const TargetRegisterInfo
*TRI
;
38 RISCVFoldMasks() : MachineFunctionPass(ID
) {
39 initializeRISCVFoldMasksPass(*PassRegistry::getPassRegistry());
42 bool runOnMachineFunction(MachineFunction
&MF
) override
;
43 MachineFunctionProperties
getRequiredProperties() const override
{
44 return MachineFunctionProperties().set(
45 MachineFunctionProperties::Property::IsSSA
);
48 StringRef
getPassName() const override
{ return "RISC-V Fold Masks"; }
51 bool convertVMergeToVMv(MachineInstr
&MI
, MachineInstr
*MaskDef
);
53 bool isAllOnesMask(MachineInstr
*MaskCopy
);
58 char RISCVFoldMasks::ID
= 0;
60 INITIALIZE_PASS(RISCVFoldMasks
, DEBUG_TYPE
, "RISC-V Fold Masks", false, false)
62 bool RISCVFoldMasks::isAllOnesMask(MachineInstr
*MaskCopy
) {
65 assert(MaskCopy
->isCopy() && MaskCopy
->getOperand(0).getReg() == RISCV::V0
);
67 TRI
->lookThruCopyLike(MaskCopy
->getOperand(1).getReg(), MRI
);
68 if (!SrcReg
.isVirtual())
70 MachineInstr
*SrcDef
= MRI
->getVRegDef(SrcReg
);
74 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
75 // undefined behaviour if it's the wrong bitwidth, so we could choose to
76 // assume that it's all-ones? Same applies to its VL.
77 switch (SrcDef
->getOpcode()) {
78 case RISCV::PseudoVMSET_M_B1
:
79 case RISCV::PseudoVMSET_M_B2
:
80 case RISCV::PseudoVMSET_M_B4
:
81 case RISCV::PseudoVMSET_M_B8
:
82 case RISCV::PseudoVMSET_M_B16
:
83 case RISCV::PseudoVMSET_M_B32
:
84 case RISCV::PseudoVMSET_M_B64
:
91 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
92 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
93 bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr
&MI
, MachineInstr
*V0Def
) {
94 #define CASE_VMERGE_TO_VMV(lmul) \
95 case RISCV::PseudoVMERGE_VVM_##lmul: \
96 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
99 switch (MI
.getOpcode()) {
101 llvm_unreachable("Expected VMERGE_VVM_<LMUL> instruction.");
102 CASE_VMERGE_TO_VMV(MF8
)
103 CASE_VMERGE_TO_VMV(MF4
)
104 CASE_VMERGE_TO_VMV(MF2
)
105 CASE_VMERGE_TO_VMV(M1
)
106 CASE_VMERGE_TO_VMV(M2
)
107 CASE_VMERGE_TO_VMV(M4
)
108 CASE_VMERGE_TO_VMV(M8
)
111 Register MergeReg
= MI
.getOperand(1).getReg();
112 Register FalseReg
= MI
.getOperand(2).getReg();
113 // Check merge == false (or merge == undef)
114 if (MergeReg
!= RISCV::NoRegister
&& TRI
->lookThruCopyLike(MergeReg
, MRI
) !=
115 TRI
->lookThruCopyLike(FalseReg
, MRI
))
118 assert(MI
.getOperand(4).isReg() && MI
.getOperand(4).getReg() == RISCV::V0
);
119 if (!isAllOnesMask(V0Def
))
122 MI
.setDesc(TII
->get(NewOpc
));
123 MI
.removeOperand(1); // Merge operand
124 MI
.tieOperands(0, 1); // Tie false to dest
125 MI
.removeOperand(3); // Mask operand
127 MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED
));
129 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
130 // register class for the destination and merge operands e.g. VRNoV0 -> VR
131 MRI
->recomputeRegClass(MI
.getOperand(0).getReg());
132 MRI
->recomputeRegClass(MI
.getOperand(1).getReg());
136 bool RISCVFoldMasks::runOnMachineFunction(MachineFunction
&MF
) {
137 if (skipFunction(MF
.getFunction()))
140 // Skip if the vector extension is not enabled.
141 const RISCVSubtarget
&ST
= MF
.getSubtarget
<RISCVSubtarget
>();
142 if (!ST
.hasVInstructions())
145 TII
= ST
.getInstrInfo();
146 MRI
= &MF
.getRegInfo();
147 TRI
= MRI
->getTargetRegisterInfo();
149 bool Changed
= false;
151 // Masked pseudos coming out of isel will have their mask operand in the form:
153 // $v0:vr = COPY %mask:vr
154 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
156 // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
158 MachineInstr
*CurrentV0Def
;
159 for (MachineBasicBlock
&MBB
: MF
) {
160 CurrentV0Def
= nullptr;
161 for (MachineInstr
&MI
: MBB
) {
162 unsigned BaseOpc
= RISCV::getRVVMCOpcode(MI
.getOpcode());
163 if (BaseOpc
== RISCV::VMERGE_VVM
)
164 Changed
|= convertVMergeToVMv(MI
, CurrentV0Def
);
166 if (MI
.definesRegister(RISCV::V0
, TRI
))
174 FunctionPass
*llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }