//===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This is a variant of the UnifyDivergentExitNodes pass. Rather than ensuring
// there is at most one ret and one unreachable instruction, it ensures there
// is at most one divergent exiting block.
//
// StructurizeCFG can't deal with multi-exit regions formed by branches to
// multiple return nodes. It is not desirable to structurize regions with
// uniform branches, and unifying uniform exits into the same return block as
// divergent exits would inhibit the use of scalar branching, so only divergent
// exits are unified. StructurizeCFG still can't deal with the case where one
// branch goes to a return and another to an unreachable; replace the
// unreachable with a return in that case.
//
//===----------------------------------------------------------------------===//
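
// Illustrative sketch (block names are hypothetical, not taken from any test):
// for a void function whose divergent branch reaches two separate returns, the
// pass rewrites both exits to branch to one new block:
//
//   exit0:                        exit0:
//     ret void                      br label %UnifiedReturnBlock
//   exit1:                  =>    exit1:
//     ret void                      br label %UnifiedReturnBlock
//                                 UnifiedReturnBlock:
//                                   ret void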

#include "AMDGPU.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"

using namespace llvm;

#define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"

namespace {

class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
    initializeAMDGPUUnifyDivergentExitNodesPass(*PassRegistry::getPassRegistry());
  }

  // We can preserve non-critical-edgeness when we unify function exit nodes.
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

char AMDGPUUnifyDivergentExitNodes::ID = 0;

char &llvm::AMDGPUUnifyDivergentExitNodesID = AMDGPUUnifyDivergentExitNodes::ID;

INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                      "Unify divergent function exit nodes", false, false)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis)
INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                    "Unify divergent function exit nodes", false, false)

void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const {
  // TODO: Preserve dominator tree.
  AU.addRequired<PostDominatorTreeWrapperPass>();

  AU.addRequired<LegacyDivergenceAnalysis>();

  // No divergent values are changed, only blocks and branch edges.
  AU.addPreserved<LegacyDivergenceAnalysis>();

  // We preserve the non-critical-edgeness property.
  AU.addPreservedID(BreakCriticalEdgesID);

  // This is a cluster of orthogonal Transforms.
  AU.addPreservedID(LowerSwitchID);

  FunctionPass::getAnalysisUsage(AU);

  AU.addRequired<TargetTransformInfoWrapperPass>();
}

/// \returns true if \p BB is reachable through only uniform branches.
/// XXX - Is there a more efficient way to find this?
static bool isUniformlyReached(const LegacyDivergenceAnalysis &DA,
                               BasicBlock &BB) {
  SmallVector<BasicBlock *, 8> Stack;
  SmallPtrSet<BasicBlock *, 8> Visited;

  for (BasicBlock *Pred : predecessors(&BB))
    Stack.push_back(Pred);

  while (!Stack.empty()) {
    BasicBlock *Top = Stack.pop_back_val();
    if (!DA.isUniform(Top->getTerminator()))
      return false;

    for (BasicBlock *Pred : predecessors(Top)) {
      if (Visited.insert(Pred).second)
        Stack.push_back(Pred);
    }
  }

  return true;
}

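/// Unify the blocks in \p ReturningBlocks into a single new return block named
/// \p Name: create one block ending in a ret (with a PHI collecting the return
/// values if the function is non-void) and rewrite every block in the set to
/// branch to it unconditionally.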
static BasicBlock *unifyReturnBlockSet(Function &F,
                                       ArrayRef<BasicBlock *> ReturningBlocks,
                                       const TargetTransformInfo &TTI,
                                       StringRef Name) {
  // We need to insert a new basic block into the function, add a PHI node to
  // it (if the function returns values), and convert all of the return
  // instructions into unconditional branches.
  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);

  PHINode *PN = nullptr;
  if (F.getReturnType()->isVoidTy()) {
    ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
  } else {
    // If the function doesn't return void... add a PHI node to the block...
    PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
                         "UnifiedRetVal");
    NewRetBlock->getInstList().push_back(PN);
    ReturnInst::Create(F.getContext(), PN, NewRetBlock);
  }

  // Loop over all of the blocks, replacing the return instruction with an
  // unconditional branch.
  for (BasicBlock *BB : ReturningBlocks) {
    // Add an incoming element to the PHI node for every return instruction
    // that is merging into this new block...
    if (PN)
      PN->addIncoming(BB->getTerminator()->getOperand(0), BB);

    // Remove and delete the return inst.
    BB->getTerminator()->eraseFromParent();
    BranchInst::Create(NewRetBlock, BB);
  }

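  // Note on the cleanup below: the '{2}' argument builds the SimplifyCFGOptions
  // passed to simplifyCFG with a bonus-instruction threshold of 2, leaving the
  // remaining options at their defaults.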
  for (BasicBlock *BB : ReturningBlocks) {
    // Cleanup possible branch to unconditional branch to the return.
    simplifyCFG(BB, TTI, {2});
  }

  return NewRetBlock;
}

bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) {
  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
  if (PDT.getRoots().size() <= 1)
    return false;

  LegacyDivergenceAnalysis &DA = getAnalysis<LegacyDivergenceAnalysis>();

  // Loop over all of the blocks in a function, tracking all of the blocks that
  // return.
  SmallVector<BasicBlock *, 4> ReturningBlocks;
  SmallVector<BasicBlock *, 4> UnreachableBlocks;

  // Dummy return block for infinite loop.
  BasicBlock *DummyReturnBB = nullptr;

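  // The post-dominator tree roots are the function's exit blocks: blocks
  // ending in a return or unreachable, plus blocks chosen as extra roots when
  // the function contains infinite loops (those end in a branch).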
  for (BasicBlock *BB : PDT.getRoots()) {
    if (isa<ReturnInst>(BB->getTerminator())) {
      if (!isUniformlyReached(DA, *BB))
        ReturningBlocks.push_back(BB);
    } else if (isa<UnreachableInst>(BB->getTerminator())) {
      if (!isUniformlyReached(DA, *BB))
        UnreachableBlocks.push_back(BB);
    } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {

      ConstantInt *BoolTrue = ConstantInt::getTrue(F.getContext());
      if (DummyReturnBB == nullptr) {
        DummyReturnBB = BasicBlock::Create(F.getContext(),
                                           "DummyReturnBlock", &F);
        Type *RetTy = F.getReturnType();
        Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
        ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB);
        ReturningBlocks.push_back(DummyReturnBB);
      }

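      // Both cases below branch on a constant 'true', so the edge to
      // DummyReturnBB is never taken at run time; it only exists to give the
      // otherwise infinite loop an exit edge that the post-dominator tree and
      // the structurizer can see.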
      if (BI->isUnconditional()) {
        BasicBlock *LoopHeaderBB = BI->getSuccessor(0);
        BI->eraseFromParent(); // Delete the unconditional branch.
        // Add a new conditional branch with a dummy edge to the return block.
        BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB);
      } else { // Conditional branch.
        // Create a new transition block to hold the conditional branch.
        BasicBlock *TransitionBB = BasicBlock::Create(F.getContext(),
                                                      "TransitionBlock", &F);

        // Move BI from BB to the new transition block.
        BI->removeFromParent();
        TransitionBB->getInstList().push_back(BI);

        // Create a branch that will always branch to the transition block.
        BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB);
      }
    }
  }

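  // Merge all divergent 'unreachable' exits found above into a single block;
  // if the function also has return blocks, that block is converted into a
  // return below so the structurizer only sees one kind of exit.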
  if (!UnreachableBlocks.empty()) {
    BasicBlock *UnreachableBlock = nullptr;

    if (UnreachableBlocks.size() == 1) {
      UnreachableBlock = UnreachableBlocks.front();
    } else {
      UnreachableBlock = BasicBlock::Create(F.getContext(),
                                            "UnifiedUnreachableBlock", &F);
      new UnreachableInst(F.getContext(), UnreachableBlock);

      for (BasicBlock *BB : UnreachableBlocks) {
        // Remove and delete the unreachable inst.
        BB->getTerminator()->eraseFromParent();
        BranchInst::Create(UnreachableBlock, BB);
      }
    }

    if (!ReturningBlocks.empty()) {
      // Don't create a new unreachable inst if we have a return. The
      // structurizer/annotator can't handle the multiple exits.

      Type *RetTy = F.getReturnType();
      Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
      // Remove and delete the unreachable inst.
      UnreachableBlock->getTerminator()->eraseFromParent();

      Function *UnreachableIntrin =
        Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);

      // Insert a call to an intrinsic tracking that this is an unreachable
      // point, in case we want to kill the active lanes or something later.
      CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);

      // Don't create a scalar trap. We would only want to trap if this code
      // was really reached, but a scalar trap would happen even if no lanes
      // actually reached here.
      ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
      ReturningBlocks.push_back(UnreachableBlock);
    }
  }

  // Now handle return blocks.
  if (ReturningBlocks.empty())
    return false; // No blocks return.

  if (ReturningBlocks.size() == 1)
    return false; // Already has a single return block.

  const TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  unifyReturnBlockSet(F, ReturningBlocks, TTI, "UnifiedReturnBlock");
  return true;
}