//===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a variant of the UnifyFunctionExitNodes pass. Rather than ensuring
// there is at most one ret and one unreachable instruction, it ensures there
// is at most one divergent exiting block.
//
// StructurizeCFG can't deal with multi-exit regions formed by branches to
// multiple return nodes. It is not desirable to structurize regions with
// uniform branches, and unifying them into the same return block as divergent
// branches would inhibit the use of scalar branching, so only divergently
// reached exits are unified. StructurizeCFG still can't deal with the case
// where one branch goes to a return and another to an unreachable; in that
// case the unreachable is replaced with a return.
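//
// As a rough illustration only (block and value names here are invented, not
// taken from real output), a function with two divergently reached returns
//
//   entry:
//     br i1 %divergent.cond, label %a, label %b
//   a:
//     ret float %x
//   b:
//     ret float %y
//
// is rewritten so that both exits branch to a single return block resembling
//
//   UnifiedReturnBlock:
//     %UnifiedRetVal = phi float [ %x, %a ], [ %y, %b ]
//     ret float %UnifiedRetVal
//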
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "SIDefines.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LegacyDivergenceAnalysis.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Type.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

#define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"

namespace {

class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
private:
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
    initializeAMDGPUUnifyDivergentExitNodesPass(*PassRegistry::getPassRegistry());
  }

  // We can preserve non-critical-edgeness when we unify function exit nodes
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  BasicBlock *unifyReturnBlockSet(Function &F, DomTreeUpdater &DTU,
                                  ArrayRef<BasicBlock *> ReturningBlocks,
                                  StringRef Name);
  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace

char AMDGPUUnifyDivergentExitNodes::ID = 0;

char &llvm::AMDGPUUnifyDivergentExitNodesID =
    AMDGPUUnifyDivergentExitNodes::ID;

INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                      "Unify divergent function exit nodes", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis)
INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
                    "Unify divergent function exit nodes", false, false)

void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const {
  if (RequireAndPreserveDomTree)
    AU.addRequired<DominatorTreeWrapperPass>();

  AU.addRequired<PostDominatorTreeWrapperPass>();

  AU.addRequired<LegacyDivergenceAnalysis>();

  if (RequireAndPreserveDomTree) {
    AU.addPreserved<DominatorTreeWrapperPass>();
    // FIXME: preserve PostDominatorTreeWrapperPass
  }

  // No divergent values are changed, only blocks and branch edges.
  AU.addPreserved<LegacyDivergenceAnalysis>();

  // We preserve the non-critical-edgeness property
  AU.addPreservedID(BreakCriticalEdgesID);

  // This is a cluster of orthogonal Transforms
  AU.addPreservedID(LowerSwitchID);

  FunctionPass::getAnalysisUsage(AU);

  AU.addRequired<TargetTransformInfoWrapperPass>();
}

/// \returns true if \p BB is reachable through only uniform branches.
/// XXX - Is there a more efficient way to find this?
static bool isUniformlyReached(const LegacyDivergenceAnalysis &DA,
                               BasicBlock &BB) {
  // Walk the predecessor graph of BB; if any terminator on a path to BB is
  // divergent, BB is not uniformly reached.
  SmallVector<BasicBlock *, 8> Stack(predecessors(&BB));
  SmallPtrSet<BasicBlock *, 8> Visited;

  while (!Stack.empty()) {
    BasicBlock *Top = Stack.pop_back_val();
    if (!DA.isUniform(Top->getTerminator()))
      return false;

    for (BasicBlock *Pred : predecessors(Top)) {
      if (Visited.insert(Pred).second)
        Stack.push_back(Pred);
    }
  }

  return true;
}
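
/// Create a unified return block named \p Name for \p F, merge the return
/// values of \p ReturningBlocks into a PHI node when the function is non-void,
/// and replace each return in \p ReturningBlocks with a branch to the new
/// block. Returns the newly created block.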
BasicBlock *AMDGPUUnifyDivergentExitNodes::unifyReturnBlockSet(
    Function &F, DomTreeUpdater &DTU, ArrayRef<BasicBlock *> ReturningBlocks,
    StringRef Name) {
  // We need to insert a new basic block into the function, add PHI nodes (if
  // the function returns values), and convert all of the return instructions
  // into unconditional branches.
  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);
  IRBuilder<> B(NewRetBlock);

  PHINode *PN = nullptr;
  if (F.getReturnType()->isVoidTy()) {
    B.CreateRetVoid();
  } else {
    // If the function doesn't return void, add a PHI node to the block to
    // merge the incoming return values.
    PN = B.CreatePHI(F.getReturnType(), ReturningBlocks.size(),
                     "UnifiedRetVal");
    B.CreateRet(PN);
  }

  // Loop over all of the blocks, replacing the return instruction with an
  // unconditional branch.
  std::vector<DominatorTree::UpdateType> Updates;
  Updates.reserve(ReturningBlocks.size());
  for (BasicBlock *BB : ReturningBlocks) {
    // Add an incoming element to the PHI node for every return instruction
    // that is merging into this new block.
    if (PN)
      PN->addIncoming(BB->getTerminator()->getOperand(0), BB);

    // Remove and delete the return inst.
    BB->getTerminator()->eraseFromParent();
    BranchInst::Create(NewRetBlock, BB);
    Updates.push_back({DominatorTree::Insert, BB, NewRetBlock});
  }

  if (RequireAndPreserveDomTree)
    DTU.applyUpdates(Updates);

  for (BasicBlock *BB : ReturningBlocks) {
    // Cleanup possible branch to unconditional branch to the return.
    simplifyCFG(BB, *TTI, RequireAndPreserveDomTree ? &DTU : nullptr,
                SimplifyCFGOptions().bonusInstThreshold(2));
  }

  return NewRetBlock;
}
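
// Classify the exits of the function (the post-dominator tree roots): unify
// divergently reached returns, merge divergently reached unreachable blocks,
// and give exit-less infinite loops a fake edge to a dummy return block, so
// the function ends up with at most one divergent exiting block.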
bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) {
  DominatorTree *DT = nullptr;
  if (RequireAndPreserveDomTree)
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();

  // If there's only one exit, we don't need to do anything.
  if (PDT.root_size() <= 1)
    return false;

  LegacyDivergenceAnalysis &DA = getAnalysis<LegacyDivergenceAnalysis>();
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  // Loop over all of the blocks in a function, tracking all of the blocks
  // that return.
  SmallVector<BasicBlock *, 4> ReturningBlocks;
  SmallVector<BasicBlock *, 4> UnreachableBlocks;

  // Dummy return block for infinite loop.
  BasicBlock *DummyReturnBB = nullptr;

  bool Changed = false;
  std::vector<DominatorTree::UpdateType> Updates;
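
  // For illustration only (block names here are invented, not real output):
  // an exit-less infinite loop such as
  //
  //   loop:
  //     ...
  //     br label %loop
  //
  // is given a fake, never-taken edge out:
  //
  //   loop:
  //     ...
  //     br i1 true, label %loop, label %DummyReturnBlock
  //   DummyReturnBlock:
  //     ret void                ; or ret undef for non-void functions
  //
  // For a conditional branch, the original branch is first moved into a new
  // TransitionBlock and the block instead branches
  // 'br i1 true, label %TransitionBlock, label %DummyReturnBlock'.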

  for (BasicBlock *BB : PDT.roots()) {
    if (isa<ReturnInst>(BB->getTerminator())) {
      if (!isUniformlyReached(DA, *BB))
        ReturningBlocks.push_back(BB);
    } else if (isa<UnreachableInst>(BB->getTerminator())) {
      if (!isUniformlyReached(DA, *BB))
        UnreachableBlocks.push_back(BB);
    } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {

      ConstantInt *BoolTrue = ConstantInt::getTrue(F.getContext());
      if (DummyReturnBB == nullptr) {
        DummyReturnBB = BasicBlock::Create(F.getContext(),
                                           "DummyReturnBlock", &F);
        Type *RetTy = F.getReturnType();
        Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
        ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB);
        ReturningBlocks.push_back(DummyReturnBB);
      }

      if (BI->isUnconditional()) {
        BasicBlock *LoopHeaderBB = BI->getSuccessor(0);
        BI->eraseFromParent(); // Delete the unconditional branch.
        // Add a new conditional branch with a dummy edge to the return block.
        BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB);
        Updates.push_back({DominatorTree::Insert, BB, DummyReturnBB});
      } else { // Conditional branch.
        SmallVector<BasicBlock *, 2> Successors(succ_begin(BB), succ_end(BB));

        // Create a new transition block to hold the conditional branch.
        BasicBlock *TransitionBB = BB->splitBasicBlock(BI, "TransitionBlock");

        Updates.reserve(Updates.size() + 2 * Successors.size() + 2);

        // 'Successors' become successors of TransitionBB instead of BB,
        // and TransitionBB becomes a single successor of BB.
        Updates.push_back({DominatorTree::Insert, BB, TransitionBB});
        for (BasicBlock *Successor : Successors) {
          Updates.push_back({DominatorTree::Insert, TransitionBB, Successor});
          Updates.push_back({DominatorTree::Delete, BB, Successor});
        }

        // Create a branch that will always branch to the transition block and
        // references DummyReturnBB.
        BB->getTerminator()->eraseFromParent();
        BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB);
        Updates.push_back({DominatorTree::Insert, BB, DummyReturnBB});
      }
      Changed = true;
    }
  }

  if (!UnreachableBlocks.empty()) {
    BasicBlock *UnreachableBlock = nullptr;

    if (UnreachableBlocks.size() == 1) {
      UnreachableBlock = UnreachableBlocks.front();
    } else {
      UnreachableBlock = BasicBlock::Create(F.getContext(),
                                            "UnifiedUnreachableBlock", &F);
      new UnreachableInst(F.getContext(), UnreachableBlock);

      Updates.reserve(Updates.size() + UnreachableBlocks.size());
      for (BasicBlock *BB : UnreachableBlocks) {
        // Remove and delete the unreachable inst.
        BB->getTerminator()->eraseFromParent();
        BranchInst::Create(UnreachableBlock, BB);
        Updates.push_back({DominatorTree::Insert, BB, UnreachableBlock});
      }
      Changed = true;
    }

    if (!ReturningBlocks.empty()) {
      // Don't create a new unreachable inst if we have a return. The
      // structurizer/annotator can't handle the multiple exits.

      Type *RetTy = F.getReturnType();
      Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
      // Remove and delete the unreachable inst.
      UnreachableBlock->getTerminator()->eraseFromParent();

      Function *UnreachableIntrin =
        Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);

      // Insert a call to an intrinsic tracking that this is an unreachable
      // point, in case we want to kill the active lanes or something later.
      CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);

      // Don't create a scalar trap. We would only want to trap if this code
      // was really reached, but a scalar trap would happen even if no lanes
      // actually reached here.
      ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
      ReturningBlocks.push_back(UnreachableBlock);
      Changed = true;
    }
  }

  // FIXME: add PDT here once simplifycfg is ready.
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
  if (RequireAndPreserveDomTree)
    DTU.applyUpdates(Updates);

  // Now handle return blocks.
  if (ReturningBlocks.empty())
    return Changed; // No blocks return

  if (ReturningBlocks.size() == 1)
    return Changed; // Already has a single return block

  unifyReturnBlockSet(F, DTU, ReturningBlocks, "UnifiedReturnBlock");
  return true;
}