Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / lib / Target / RISCV / RISCVFoldMasks.cpp
blobd1c77a6cc7756dfb7bad4dcff3152bfc1a9abcf0
//===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass performs various peephole optimisations that fold masks into vector
// pseudo instructions after instruction selection.
//
// Currently it converts
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
//===---------------------------------------------------------------------===//
19 #include "RISCV.h"
20 #include "RISCVSubtarget.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/TargetInstrInfo.h"
24 #include "llvm/CodeGen/TargetRegisterInfo.h"
26 using namespace llvm;
28 #define DEBUG_TYPE "riscv-fold-masks"
30 namespace {
32 class RISCVFoldMasks : public MachineFunctionPass {
33 public:
34 static char ID;
35 const TargetInstrInfo *TII;
36 MachineRegisterInfo *MRI;
37 const TargetRegisterInfo *TRI;
38 RISCVFoldMasks() : MachineFunctionPass(ID) {
39 initializeRISCVFoldMasksPass(*PassRegistry::getPassRegistry());
42 bool runOnMachineFunction(MachineFunction &MF) override;
43 MachineFunctionProperties getRequiredProperties() const override {
44 return MachineFunctionProperties().set(
45 MachineFunctionProperties::Property::IsSSA);
48 StringRef getPassName() const override { return "RISC-V Fold Masks"; }
50 private:
51 bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef);
53 bool isAllOnesMask(MachineInstr *MaskCopy);
56 } // namespace
58 char RISCVFoldMasks::ID = 0;
60 INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
62 bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskCopy) {
63 if (!MaskCopy)
64 return false;
65 assert(MaskCopy->isCopy() && MaskCopy->getOperand(0).getReg() == RISCV::V0);
66 Register SrcReg =
67 TRI->lookThruCopyLike(MaskCopy->getOperand(1).getReg(), MRI);
68 if (!SrcReg.isVirtual())
69 return false;
70 MachineInstr *SrcDef = MRI->getVRegDef(SrcReg);
71 if (!SrcDef)
72 return false;
74 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
75 // undefined behaviour if it's the wrong bitwidth, so we could choose to
76 // assume that it's all-ones? Same applies to its VL.
77 switch (SrcDef->getOpcode()) {
78 case RISCV::PseudoVMSET_M_B1:
79 case RISCV::PseudoVMSET_M_B2:
80 case RISCV::PseudoVMSET_M_B4:
81 case RISCV::PseudoVMSET_M_B8:
82 case RISCV::PseudoVMSET_M_B16:
83 case RISCV::PseudoVMSET_M_B32:
84 case RISCV::PseudoVMSET_M_B64:
85 return true;
86 default:
87 return false;
91 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
92 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
93 bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) {
94 #define CASE_VMERGE_TO_VMV(lmul) \
95 case RISCV::PseudoVMERGE_VVM_##lmul: \
96 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
97 break;
98 unsigned NewOpc;
99 switch (MI.getOpcode()) {
100 default:
101 llvm_unreachable("Expected VMERGE_VVM_<LMUL> instruction.");
102 CASE_VMERGE_TO_VMV(MF8)
103 CASE_VMERGE_TO_VMV(MF4)
104 CASE_VMERGE_TO_VMV(MF2)
105 CASE_VMERGE_TO_VMV(M1)
106 CASE_VMERGE_TO_VMV(M2)
107 CASE_VMERGE_TO_VMV(M4)
108 CASE_VMERGE_TO_VMV(M8)
111 Register MergeReg = MI.getOperand(1).getReg();
112 Register FalseReg = MI.getOperand(2).getReg();
113 // Check merge == false (or merge == undef)
114 if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) !=
115 TRI->lookThruCopyLike(FalseReg, MRI))
116 return false;
118 assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
119 if (!isAllOnesMask(V0Def))
120 return false;
122 MI.setDesc(TII->get(NewOpc));
123 MI.removeOperand(1); // Merge operand
124 MI.tieOperands(0, 1); // Tie false to dest
125 MI.removeOperand(3); // Mask operand
126 MI.addOperand(
127 MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));
129 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
130 // register class for the destination and merge operands e.g. VRNoV0 -> VR
131 MRI->recomputeRegClass(MI.getOperand(0).getReg());
132 MRI->recomputeRegClass(MI.getOperand(1).getReg());
133 return true;
136 bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
137 if (skipFunction(MF.getFunction()))
138 return false;
140 // Skip if the vector extension is not enabled.
141 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
142 if (!ST.hasVInstructions())
143 return false;
145 TII = ST.getInstrInfo();
146 MRI = &MF.getRegInfo();
147 TRI = MRI->getTargetRegisterInfo();
149 bool Changed = false;
151 // Masked pseudos coming out of isel will have their mask operand in the form:
153 // $v0:vr = COPY %mask:vr
154 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
156 // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
157 // on each pseudo.
158 MachineInstr *CurrentV0Def;
159 for (MachineBasicBlock &MBB : MF) {
160 CurrentV0Def = nullptr;
161 for (MachineInstr &MI : MBB) {
162 unsigned BaseOpc = RISCV::getRVVMCOpcode(MI.getOpcode());
163 if (BaseOpc == RISCV::VMERGE_VVM)
164 Changed |= convertVMergeToVMv(MI, CurrentV0Def);
166 if (MI.definesRegister(RISCV::V0, TRI))
167 CurrentV0Def = &MI;
171 return Changed;
174 FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }