1 //===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This is a variant of the UnifyFunctionExitNodes pass. Rather than ensuring
// there is at most one ret and one unreachable instruction, it ensures there is
// at most one divergent exiting block.
//
// StructurizeCFG can't deal with multi-exit regions formed by branches to
// multiple return nodes. It is not desirable to structurize regions with
// uniform branches, so only divergently reached exits are unified: merging
// uniformly reached returns into the same return block as divergent ones would
// inhibit the use of scalar branching. StructurizeCFG also can't deal with the
// case where one branch goes to a return and another to an unreachable, so in
// that case the unreachable is replaced with a return.
20 //===----------------------------------------------------------------------===//
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Analysis/LegacyDivergenceAnalysis.h"
28 #include "llvm/Analysis/PostDominators.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/Transforms/Utils/Local.h"
31 #include "llvm/IR/BasicBlock.h"
32 #include "llvm/IR/CFG.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/Function.h"
35 #include "llvm/IR/InstrTypes.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/Intrinsics.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils.h"
46 #define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"
50 class AMDGPUUnifyDivergentExitNodes
: public FunctionPass
{
52 static char ID
; // Pass identification, replacement for typeid
54 AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID
) {
55 initializeAMDGPUUnifyDivergentExitNodesPass(*PassRegistry::getPassRegistry());
58 // We can preserve non-critical-edgeness when we unify function exit nodes
59 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
60 bool runOnFunction(Function
&F
) override
;
63 } // end anonymous namespace
65 char AMDGPUUnifyDivergentExitNodes::ID
= 0;
67 char &llvm::AMDGPUUnifyDivergentExitNodesID
= AMDGPUUnifyDivergentExitNodes::ID
;
69 INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes
, DEBUG_TYPE
,
70 "Unify divergent function exit nodes", false, false)
71 INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass
)
72 INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis
)
73 INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes
, DEBUG_TYPE
,
74 "Unify divergent function exit nodes", false, false)
76 void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage
&AU
) const{
77 // TODO: Preserve dominator tree.
78 AU
.addRequired
<PostDominatorTreeWrapperPass
>();
80 AU
.addRequired
<LegacyDivergenceAnalysis
>();
82 // No divergent values are changed, only blocks and branch edges.
83 AU
.addPreserved
<LegacyDivergenceAnalysis
>();
85 // We preserve the non-critical-edgeness property
86 AU
.addPreservedID(BreakCriticalEdgesID
);
88 // This is a cluster of orthogonal Transforms
89 AU
.addPreservedID(LowerSwitchID
);
90 FunctionPass::getAnalysisUsage(AU
);
92 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
95 /// \returns true if \p BB is reachable through only uniform branches.
96 /// XXX - Is there a more efficient way to find this?
97 static bool isUniformlyReached(const LegacyDivergenceAnalysis
&DA
,
99 SmallVector
<BasicBlock
*, 8> Stack
;
100 SmallPtrSet
<BasicBlock
*, 8> Visited
;
102 for (BasicBlock
*Pred
: predecessors(&BB
))
103 Stack
.push_back(Pred
);
105 while (!Stack
.empty()) {
106 BasicBlock
*Top
= Stack
.pop_back_val();
107 if (!DA
.isUniform(Top
->getTerminator()))
110 for (BasicBlock
*Pred
: predecessors(Top
)) {
111 if (Visited
.insert(Pred
).second
)
112 Stack
.push_back(Pred
);
119 static BasicBlock
*unifyReturnBlockSet(Function
&F
,
120 ArrayRef
<BasicBlock
*> ReturningBlocks
,
121 const TargetTransformInfo
&TTI
,
123 // Otherwise, we need to insert a new basic block into the function, add a PHI
124 // nodes (if the function returns values), and convert all of the return
125 // instructions into unconditional branches.
126 BasicBlock
*NewRetBlock
= BasicBlock::Create(F
.getContext(), Name
, &F
);
128 PHINode
*PN
= nullptr;
129 if (F
.getReturnType()->isVoidTy()) {
130 ReturnInst::Create(F
.getContext(), nullptr, NewRetBlock
);
132 // If the function doesn't return void... add a PHI node to the block...
133 PN
= PHINode::Create(F
.getReturnType(), ReturningBlocks
.size(),
135 NewRetBlock
->getInstList().push_back(PN
);
136 ReturnInst::Create(F
.getContext(), PN
, NewRetBlock
);
139 // Loop over all of the blocks, replacing the return instruction with an
140 // unconditional branch.
141 for (BasicBlock
*BB
: ReturningBlocks
) {
142 // Add an incoming element to the PHI node for every return instruction that
143 // is merging into this new block...
145 PN
->addIncoming(BB
->getTerminator()->getOperand(0), BB
);
147 // Remove and delete the return inst.
148 BB
->getTerminator()->eraseFromParent();
149 BranchInst::Create(NewRetBlock
, BB
);
152 for (BasicBlock
*BB
: ReturningBlocks
) {
153 // Cleanup possible branch to unconditional branch to the return.
154 simplifyCFG(BB
, TTI
, {2});
160 bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function
&F
) {
161 auto &PDT
= getAnalysis
<PostDominatorTreeWrapperPass
>().getPostDomTree();
162 if (PDT
.getRoots().size() <= 1)
165 LegacyDivergenceAnalysis
&DA
= getAnalysis
<LegacyDivergenceAnalysis
>();
167 // Loop over all of the blocks in a function, tracking all of the blocks that
169 SmallVector
<BasicBlock
*, 4> ReturningBlocks
;
170 SmallVector
<BasicBlock
*, 4> UnreachableBlocks
;
172 // Dummy return block for infinite loop.
173 BasicBlock
*DummyReturnBB
= nullptr;
175 for (BasicBlock
*BB
: PDT
.getRoots()) {
176 if (isa
<ReturnInst
>(BB
->getTerminator())) {
177 if (!isUniformlyReached(DA
, *BB
))
178 ReturningBlocks
.push_back(BB
);
179 } else if (isa
<UnreachableInst
>(BB
->getTerminator())) {
180 if (!isUniformlyReached(DA
, *BB
))
181 UnreachableBlocks
.push_back(BB
);
182 } else if (BranchInst
*BI
= dyn_cast
<BranchInst
>(BB
->getTerminator())) {
184 ConstantInt
*BoolTrue
= ConstantInt::getTrue(F
.getContext());
185 if (DummyReturnBB
== nullptr) {
186 DummyReturnBB
= BasicBlock::Create(F
.getContext(),
187 "DummyReturnBlock", &F
);
188 Type
*RetTy
= F
.getReturnType();
189 Value
*RetVal
= RetTy
->isVoidTy() ? nullptr : UndefValue::get(RetTy
);
190 ReturnInst::Create(F
.getContext(), RetVal
, DummyReturnBB
);
191 ReturningBlocks
.push_back(DummyReturnBB
);
194 if (BI
->isUnconditional()) {
195 BasicBlock
*LoopHeaderBB
= BI
->getSuccessor(0);
196 BI
->eraseFromParent(); // Delete the unconditional branch.
197 // Add a new conditional branch with a dummy edge to the return block.
198 BranchInst::Create(LoopHeaderBB
, DummyReturnBB
, BoolTrue
, BB
);
199 } else { // Conditional branch.
200 // Create a new transition block to hold the conditional branch.
201 BasicBlock
*TransitionBB
= BasicBlock::Create(F
.getContext(),
202 "TransitionBlock", &F
);
204 // Move BI from BB to the new transition block.
205 BI
->removeFromParent();
206 TransitionBB
->getInstList().push_back(BI
);
208 // Create a branch that will always branch to the transition block.
209 BranchInst::Create(TransitionBB
, DummyReturnBB
, BoolTrue
, BB
);
214 if (!UnreachableBlocks
.empty()) {
215 BasicBlock
*UnreachableBlock
= nullptr;
217 if (UnreachableBlocks
.size() == 1) {
218 UnreachableBlock
= UnreachableBlocks
.front();
220 UnreachableBlock
= BasicBlock::Create(F
.getContext(),
221 "UnifiedUnreachableBlock", &F
);
222 new UnreachableInst(F
.getContext(), UnreachableBlock
);
224 for (BasicBlock
*BB
: UnreachableBlocks
) {
225 // Remove and delete the unreachable inst.
226 BB
->getTerminator()->eraseFromParent();
227 BranchInst::Create(UnreachableBlock
, BB
);
231 if (!ReturningBlocks
.empty()) {
232 // Don't create a new unreachable inst if we have a return. The
233 // structurizer/annotator can't handle the multiple exits
235 Type
*RetTy
= F
.getReturnType();
236 Value
*RetVal
= RetTy
->isVoidTy() ? nullptr : UndefValue::get(RetTy
);
237 // Remove and delete the unreachable inst.
238 UnreachableBlock
->getTerminator()->eraseFromParent();
240 Function
*UnreachableIntrin
=
241 Intrinsic::getDeclaration(F
.getParent(), Intrinsic::amdgcn_unreachable
);
243 // Insert a call to an intrinsic tracking that this is an unreachable
244 // point, in case we want to kill the active lanes or something later.
245 CallInst::Create(UnreachableIntrin
, {}, "", UnreachableBlock
);
247 // Don't create a scalar trap. We would only want to trap if this code was
248 // really reached, but a scalar trap would happen even if no lanes
249 // actually reached here.
250 ReturnInst::Create(F
.getContext(), RetVal
, UnreachableBlock
);
251 ReturningBlocks
.push_back(UnreachableBlock
);
255 // Now handle return blocks.
256 if (ReturningBlocks
.empty())
257 return false; // No blocks return
259 if (ReturningBlocks
.size() == 1)
260 return false; // Already has a single return block
262 const TargetTransformInfo
&TTI
263 = getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
265 unifyReturnBlockSet(F
, ReturningBlocks
, TTI
, "UnifiedReturnBlock");