Revert " [LoongArch][ISel] Check the number of sign bits in `PatGprGpr_32` (#107432)"
[llvm-project.git] / llvm / lib / Target / AMDGPU / AMDGPUGlobalISelDivergenceLowering.cpp
blobfb258547e8fb90dab9892b2b2cc40eb008402761
1 //===-- AMDGPUGlobalISelDivergenceLowering.cpp ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// GlobalISel pass that selects divergent i1 phis as lane mask phis.
11 /// Lane mask merging uses the same algorithm as SDAG in SILowerI1Copies.
12 /// Handles all cases of temporal divergence.
13 /// For divergent non-phi i1 and uniform i1 uses outside of the cycle, this
14 /// pass currently depends on LCSSA to insert phis with a single incoming value.
16 //===----------------------------------------------------------------------===//
18 #include "AMDGPU.h"
19 #include "SILowerI1Copies.h"
20 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
23 #include "llvm/InitializePasses.h"
25 #define DEBUG_TYPE "amdgpu-global-isel-divergence-lowering"
27 using namespace llvm;
29 namespace {
31 class AMDGPUGlobalISelDivergenceLowering : public MachineFunctionPass {
32 public:
33 static char ID;
35 public:
36 AMDGPUGlobalISelDivergenceLowering() : MachineFunctionPass(ID) {
37 initializeAMDGPUGlobalISelDivergenceLoweringPass(
38 *PassRegistry::getPassRegistry());
41 bool runOnMachineFunction(MachineFunction &MF) override;
43 StringRef getPassName() const override {
44 return "AMDGPU GlobalISel divergence lowering";
47 void getAnalysisUsage(AnalysisUsage &AU) const override {
48 AU.setPreservesCFG();
49 AU.addRequired<MachineDominatorTreeWrapperPass>();
50 AU.addRequired<MachinePostDominatorTreeWrapperPass>();
51 AU.addRequired<MachineUniformityAnalysisPass>();
52 MachineFunctionPass::getAnalysisUsage(AU);
56 class DivergenceLoweringHelper : public PhiLoweringHelper {
57 public:
58 DivergenceLoweringHelper(MachineFunction *MF, MachineDominatorTree *DT,
59 MachinePostDominatorTree *PDT,
60 MachineUniformityInfo *MUI);
62 private:
63 MachineUniformityInfo *MUI = nullptr;
64 MachineIRBuilder B;
65 Register buildRegCopyToLaneMask(Register Reg);
67 public:
68 void markAsLaneMask(Register DstReg) const override;
69 void getCandidatesForLowering(
70 SmallVectorImpl<MachineInstr *> &Vreg1Phis) const override;
71 void collectIncomingValuesFromPhi(
72 const MachineInstr *MI,
73 SmallVectorImpl<Incoming> &Incomings) const override;
74 void replaceDstReg(Register NewReg, Register OldReg,
75 MachineBasicBlock *MBB) override;
76 void buildMergeLaneMasks(MachineBasicBlock &MBB,
77 MachineBasicBlock::iterator I, const DebugLoc &DL,
78 Register DstReg, Register PrevReg,
79 Register CurReg) override;
80 void constrainAsLaneMask(Incoming &In) override;
83 DivergenceLoweringHelper::DivergenceLoweringHelper(
84 MachineFunction *MF, MachineDominatorTree *DT,
85 MachinePostDominatorTree *PDT, MachineUniformityInfo *MUI)
86 : PhiLoweringHelper(MF, DT, PDT), MUI(MUI), B(*MF) {}
88 // _(s1) -> SReg_32/64(s1)
89 void DivergenceLoweringHelper::markAsLaneMask(Register DstReg) const {
90 assert(MRI->getType(DstReg) == LLT::scalar(1));
92 if (MRI->getRegClassOrNull(DstReg)) {
93 if (MRI->constrainRegClass(DstReg, ST->getBoolRC()))
94 return;
95 llvm_unreachable("Failed to constrain register class");
98 MRI->setRegClass(DstReg, ST->getBoolRC());
101 void DivergenceLoweringHelper::getCandidatesForLowering(
102 SmallVectorImpl<MachineInstr *> &Vreg1Phis) const {
103 LLT S1 = LLT::scalar(1);
105 // Add divergent i1 phis to the list
106 for (MachineBasicBlock &MBB : *MF) {
107 for (MachineInstr &MI : MBB.phis()) {
108 Register Dst = MI.getOperand(0).getReg();
109 if (MRI->getType(Dst) == S1 && MUI->isDivergent(Dst))
110 Vreg1Phis.push_back(&MI);
115 void DivergenceLoweringHelper::collectIncomingValuesFromPhi(
116 const MachineInstr *MI, SmallVectorImpl<Incoming> &Incomings) const {
117 for (unsigned i = 1; i < MI->getNumOperands(); i += 2) {
118 Incomings.emplace_back(MI->getOperand(i).getReg(),
119 MI->getOperand(i + 1).getMBB(), Register());
123 void DivergenceLoweringHelper::replaceDstReg(Register NewReg, Register OldReg,
124 MachineBasicBlock *MBB) {
125 BuildMI(*MBB, MBB->getFirstNonPHI(), {}, TII->get(AMDGPU::COPY), OldReg)
126 .addReg(NewReg);
129 // Copy Reg to new lane mask register, insert a copy after instruction that
130 // defines Reg while skipping phis if needed.
131 Register DivergenceLoweringHelper::buildRegCopyToLaneMask(Register Reg) {
132 Register LaneMask = createLaneMaskReg(MRI, LaneMaskRegAttrs);
133 MachineInstr *Instr = MRI->getVRegDef(Reg);
134 MachineBasicBlock *MBB = Instr->getParent();
135 B.setInsertPt(*MBB, MBB->SkipPHIsAndLabels(std::next(Instr->getIterator())));
136 B.buildCopy(LaneMask, Reg);
137 return LaneMask;
140 // bb.previous
141 // %PrevReg = ...
143 // bb.current
144 // %CurReg = ...
146 // %DstReg - not defined
148 // -> (wave32 example, new registers have sreg_32 reg class and S1 LLT)
150 // bb.previous
151 // %PrevReg = ...
152 // %PrevRegCopy:sreg_32(s1) = COPY %PrevReg
154 // bb.current
155 // %CurReg = ...
156 // %CurRegCopy:sreg_32(s1) = COPY %CurReg
157 // ...
158 // %PrevMaskedReg:sreg_32(s1) = ANDN2 %PrevRegCopy, ExecReg - active lanes 0
159 // %CurMaskedReg:sreg_32(s1) = AND %ExecReg, CurRegCopy - inactive lanes to 0
160 // %DstReg:sreg_32(s1) = OR %PrevMaskedReg, CurMaskedReg
162 // DstReg = for active lanes rewrite bit in PrevReg with bit from CurReg
163 void DivergenceLoweringHelper::buildMergeLaneMasks(
164 MachineBasicBlock &MBB, MachineBasicBlock::iterator I, const DebugLoc &DL,
165 Register DstReg, Register PrevReg, Register CurReg) {
166 // DstReg = (PrevReg & !EXEC) | (CurReg & EXEC)
167 // TODO: check if inputs are constants or results of a compare.
169 Register PrevRegCopy = buildRegCopyToLaneMask(PrevReg);
170 Register CurRegCopy = buildRegCopyToLaneMask(CurReg);
171 Register PrevMaskedReg = createLaneMaskReg(MRI, LaneMaskRegAttrs);
172 Register CurMaskedReg = createLaneMaskReg(MRI, LaneMaskRegAttrs);
174 B.setInsertPt(MBB, I);
175 B.buildInstr(AndN2Op, {PrevMaskedReg}, {PrevRegCopy, ExecReg});
176 B.buildInstr(AndOp, {CurMaskedReg}, {ExecReg, CurRegCopy});
177 B.buildInstr(OrOp, {DstReg}, {PrevMaskedReg, CurMaskedReg});
180 // GlobalISel has to constrain S1 incoming taken as-is with lane mask register
181 // class. Insert a copy of Incoming.Reg to new lane mask inside Incoming.Block,
182 // Incoming.Reg becomes that new lane mask.
183 void DivergenceLoweringHelper::constrainAsLaneMask(Incoming &In) {
184 B.setInsertPt(*In.Block, In.Block->getFirstTerminator());
186 auto Copy = B.buildCopy(LLT::scalar(1), In.Reg);
187 MRI->setRegClass(Copy.getReg(0), ST->getBoolRC());
188 In.Reg = Copy.getReg(0);
191 } // End anonymous namespace.
193 INITIALIZE_PASS_BEGIN(AMDGPUGlobalISelDivergenceLowering, DEBUG_TYPE,
194 "AMDGPU GlobalISel divergence lowering", false, false)
195 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
196 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
197 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
198 INITIALIZE_PASS_END(AMDGPUGlobalISelDivergenceLowering, DEBUG_TYPE,
199 "AMDGPU GlobalISel divergence lowering", false, false)
201 char AMDGPUGlobalISelDivergenceLowering::ID = 0;
203 char &llvm::AMDGPUGlobalISelDivergenceLoweringID =
204 AMDGPUGlobalISelDivergenceLowering::ID;
206 FunctionPass *llvm::createAMDGPUGlobalISelDivergenceLoweringPass() {
207 return new AMDGPUGlobalISelDivergenceLowering();
210 bool AMDGPUGlobalISelDivergenceLowering::runOnMachineFunction(
211 MachineFunction &MF) {
212 MachineDominatorTree &DT =
213 getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
214 MachinePostDominatorTree &PDT =
215 getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
216 MachineUniformityInfo &MUI =
217 getAnalysis<MachineUniformityAnalysisPass>().getUniformityInfo();
219 DivergenceLoweringHelper Helper(&MF, &DT, &PDT, &MUI);
221 return Helper.lowerPhis();