//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). That implementation in turn is based on the approach
// described in
//
//   Improving Performance of OpenCL on CPUs
//   Ralf Karrenberg and Sebastian Hack
//   CC '12
//
// This DivergenceAnalysis implementation is generic in the sense that it does
// not itself identify original sources of divergence.
// Instead specialized adapter classes, (LoopDivergenceAnalysis) for loops and
// (GPUDivergenceAnalysis) for GPU programs, identify the sources of divergence
// (e.g., special variables that hold the thread ID or the iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current DivergenceAnalysis implementation has the following limitations:
// 1. intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. memory as black box. It conservatively considers values loaded from
//    generic or local address as divergent. This can be improved by leveraging
//    pointer analysis and/or by modelling non-escaping memory objects in SSA
//    as done in RV.
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "divergence-analysis"
93 // class DivergenceAnalysis
94 DivergenceAnalysis::DivergenceAnalysis(
95 const Function
&F
, const Loop
*RegionLoop
, const DominatorTree
&DT
,
96 const LoopInfo
&LI
, SyncDependenceAnalysis
&SDA
, bool IsLCSSAForm
)
97 : F(F
), RegionLoop(RegionLoop
), DT(DT
), LI(LI
), SDA(SDA
),
98 IsLCSSAForm(IsLCSSAForm
) {}
100 void DivergenceAnalysis::markDivergent(const Value
&DivVal
) {
101 assert(isa
<Instruction
>(DivVal
) || isa
<Argument
>(DivVal
));
102 assert(!isAlwaysUniform(DivVal
) && "cannot be a divergent");
103 DivergentValues
.insert(&DivVal
);
106 void DivergenceAnalysis::addUniformOverride(const Value
&UniVal
) {
107 UniformOverrides
.insert(&UniVal
);
110 bool DivergenceAnalysis::updateTerminator(const Instruction
&Term
) const {
111 if (Term
.getNumSuccessors() <= 1)
113 if (auto *BranchTerm
= dyn_cast
<BranchInst
>(&Term
)) {
114 assert(BranchTerm
->isConditional());
115 return isDivergent(*BranchTerm
->getCondition());
117 if (auto *SwitchTerm
= dyn_cast
<SwitchInst
>(&Term
)) {
118 return isDivergent(*SwitchTerm
->getCondition());
120 if (isa
<InvokeInst
>(Term
)) {
121 return false; // ignore abnormal executions through landingpad
124 llvm_unreachable("unexpected terminator");
127 bool DivergenceAnalysis::updateNormalInstruction(const Instruction
&I
) const {
128 // TODO function calls with side effects, etc
129 for (const auto &Op
: I
.operands()) {
130 if (isDivergent(*Op
))
136 bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock
&ObservingBlock
,
137 const Value
&Val
) const {
138 const auto *Inst
= dyn_cast
<const Instruction
>(&Val
);
141 // check whether any divergent loop carrying Val terminates before control
142 // proceeds to ObservingBlock
143 for (const auto *Loop
= LI
.getLoopFor(Inst
->getParent());
144 Loop
!= RegionLoop
&& !Loop
->contains(&ObservingBlock
);
145 Loop
= Loop
->getParentLoop()) {
146 if (DivergentLoops
.find(Loop
) != DivergentLoops
.end())
153 bool DivergenceAnalysis::updatePHINode(const PHINode
&Phi
) const {
154 // joining divergent disjoint path in Phi parent block
155 if (!Phi
.hasConstantOrUndefValue() && isJoinDivergent(*Phi
.getParent())) {
159 // An incoming value could be divergent by itself.
160 // Otherwise, an incoming value could be uniform within the loop
161 // that carries its definition but it may appear divergent
162 // from outside the loop. This happens when divergent loop exits
163 // drop definitions of that uniform value in different iterations.
165 // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop
166 // if (i % thread_id == 0) break; // divergent loop exit
168 // int divI = i; // divI is divergent
169 for (size_t i
= 0; i
< Phi
.getNumIncomingValues(); ++i
) {
170 const auto *InVal
= Phi
.getIncomingValue(i
);
171 if (isDivergent(*Phi
.getIncomingValue(i
)) ||
172 isTemporalDivergent(*Phi
.getParent(), *InVal
)) {
179 bool DivergenceAnalysis::inRegion(const Instruction
&I
) const {
180 return I
.getParent() && inRegion(*I
.getParent());
183 bool DivergenceAnalysis::inRegion(const BasicBlock
&BB
) const {
184 return (!RegionLoop
&& BB
.getParent() == &F
) || RegionLoop
->contains(&BB
);
187 // marks all users of loop-carried values of the loop headed by LoopHeader as
189 void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock
&LoopHeader
) {
190 auto *DivLoop
= LI
.getLoopFor(&LoopHeader
);
191 assert(DivLoop
&& "loopHeader is not actually part of a loop");
193 SmallVector
<BasicBlock
*, 8> TaintStack
;
194 DivLoop
->getExitBlocks(TaintStack
);
196 // Otherwise potential users of loop-carried values could be anywhere in the
197 // dominance region of DivLoop (including its fringes for phi nodes)
198 DenseSet
<const BasicBlock
*> Visited
;
199 for (auto *Block
: TaintStack
) {
200 Visited
.insert(Block
);
202 Visited
.insert(&LoopHeader
);
204 while (!TaintStack
.empty()) {
205 auto *UserBlock
= TaintStack
.back();
206 TaintStack
.pop_back();
208 // don't spread divergence beyond the region
209 if (!inRegion(*UserBlock
))
212 assert(!DivLoop
->contains(UserBlock
) &&
213 "irreducible control flow detected");
215 // phi nodes at the fringes of the dominance region
216 if (!DT
.dominates(&LoopHeader
, UserBlock
)) {
217 // all PHI nodes of UserBlock become divergent
218 for (auto &Phi
: UserBlock
->phis()) {
219 Worklist
.push_back(&Phi
);
224 // taint outside users of values carried by DivLoop
225 for (auto &I
: *UserBlock
) {
226 if (isAlwaysUniform(I
))
231 for (auto &Op
: I
.operands()) {
232 auto *OpInst
= dyn_cast
<Instruction
>(&Op
);
235 if (DivLoop
->contains(OpInst
->getParent())) {
243 // visit all blocks in the dominance region
244 for (auto *SuccBlock
: successors(UserBlock
)) {
245 if (!Visited
.insert(SuccBlock
).second
) {
248 TaintStack
.push_back(SuccBlock
);
253 void DivergenceAnalysis::pushPHINodes(const BasicBlock
&Block
) {
254 for (const auto &Phi
: Block
.phis()) {
255 if (isDivergent(Phi
))
257 Worklist
.push_back(&Phi
);
261 void DivergenceAnalysis::pushUsers(const Value
&V
) {
262 for (const auto *User
: V
.users()) {
263 const auto *UserInst
= dyn_cast
<const Instruction
>(User
);
267 if (isDivergent(*UserInst
))
270 // only compute divergent inside loop
271 if (!inRegion(*UserInst
))
273 Worklist
.push_back(UserInst
);
277 bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock
&JoinBlock
,
278 const Loop
*BranchLoop
) {
279 LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock
.getName() << "\n");
281 // ignore divergence outside the region
282 if (!inRegion(JoinBlock
)) {
286 // push non-divergent phi nodes in JoinBlock to the worklist
287 pushPHINodes(JoinBlock
);
289 // JoinBlock is a divergent loop exit
290 if (BranchLoop
&& !BranchLoop
->contains(&JoinBlock
)) {
294 // disjoint-paths divergent at JoinBlock
295 markBlockJoinDivergent(JoinBlock
);
299 void DivergenceAnalysis::propagateBranchDivergence(const Instruction
&Term
) {
300 LLVM_DEBUG(dbgs() << "propBranchDiv " << Term
.getParent()->getName() << "\n");
304 const auto *BranchLoop
= LI
.getLoopFor(Term
.getParent());
306 // whether there is a divergent loop exit from BranchLoop (if any)
307 bool IsBranchLoopDivergent
= false;
309 // iterate over all blocks reachable by disjoint from Term within the loop
310 // also iterates over loop exits that become divergent due to Term.
311 for (const auto *JoinBlock
: SDA
.join_blocks(Term
)) {
312 IsBranchLoopDivergent
|= propagateJoinDivergence(*JoinBlock
, BranchLoop
);
315 // Branch loop is a divergent loop due to the divergent branch in Term
316 if (IsBranchLoopDivergent
) {
318 if (!DivergentLoops
.insert(BranchLoop
).second
) {
321 propagateLoopDivergence(*BranchLoop
);
325 void DivergenceAnalysis::propagateLoopDivergence(const Loop
&ExitingLoop
) {
326 LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop
.getName() << "\n");
328 // don't propagate beyond region
329 if (!inRegion(*ExitingLoop
.getHeader()))
332 const auto *BranchLoop
= ExitingLoop
.getParentLoop();
334 // Uses of loop-carried values could occur anywhere
335 // within the dominance region of the definition. All loop-carried
336 // definitions are dominated by the loop header (reducible control).
337 // Thus all users have to be in the dominance region of the loop header,
338 // except PHI nodes that can also live at the fringes of the dom region
339 // (incoming defining value).
341 taintLoopLiveOuts(*ExitingLoop
.getHeader());
343 // whether there is a divergent loop exit from BranchLoop (if any)
344 bool IsBranchLoopDivergent
= false;
346 // iterate over all blocks reachable by disjoint paths from exits of
347 // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn
349 for (const auto *JoinBlock
: SDA
.join_blocks(ExitingLoop
)) {
350 IsBranchLoopDivergent
|= propagateJoinDivergence(*JoinBlock
, BranchLoop
);
353 // Branch loop is a divergent due to divergent loop exit in ExitingLoop
354 if (IsBranchLoopDivergent
) {
356 if (!DivergentLoops
.insert(BranchLoop
).second
) {
359 propagateLoopDivergence(*BranchLoop
);
363 void DivergenceAnalysis::compute() {
364 for (auto *DivVal
: DivergentValues
) {
368 // propagate divergence
369 while (!Worklist
.empty()) {
370 const Instruction
&I
= *Worklist
.back();
373 // maintain uniformity of overrides
374 if (isAlwaysUniform(I
))
377 bool WasDivergent
= isDivergent(I
);
381 // propagate divergence caused by terminator
382 if (I
.isTerminator()) {
383 if (updateTerminator(I
)) {
384 // propagate control divergence to affected instructions
385 propagateBranchDivergence(I
);
390 // update divergence of I due to divergent operands
391 bool DivergentUpd
= false;
392 const auto *Phi
= dyn_cast
<const PHINode
>(&I
);
394 DivergentUpd
= updatePHINode(*Phi
);
396 DivergentUpd
= updateNormalInstruction(I
);
399 // propagate value divergence to users
407 bool DivergenceAnalysis::isAlwaysUniform(const Value
&V
) const {
408 return UniformOverrides
.find(&V
) != UniformOverrides
.end();
411 bool DivergenceAnalysis::isDivergent(const Value
&V
) const {
412 return DivergentValues
.find(&V
) != DivergentValues
.end();
415 bool DivergenceAnalysis::isDivergentUse(const Use
&U
) const {
417 Instruction
&I
= *cast
<Instruction
>(U
.getUser());
418 return isDivergent(V
) || isTemporalDivergent(*I
.getParent(), V
);
421 void DivergenceAnalysis::print(raw_ostream
&OS
, const Module
*) const {
422 if (DivergentValues
.empty())
424 // iterate instructions using instructions() to ensure a deterministic order.
425 for (auto &I
: instructions(F
)) {
427 OS
<< "DIVERGENT:" << I
<< '\n';
431 // class GPUDivergenceAnalysis
432 GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function
&F
,
433 const DominatorTree
&DT
,
434 const PostDominatorTree
&PDT
,
436 const TargetTransformInfo
&TTI
)
437 : SDA(DT
, PDT
, LI
), DA(F
, nullptr, DT
, LI
, SDA
, false) {
438 for (auto &I
: instructions(F
)) {
439 if (TTI
.isSourceOfDivergence(&I
)) {
441 } else if (TTI
.isAlwaysUniform(&I
)) {
442 DA
.addUniformOverride(I
);
445 for (auto &Arg
: F
.args()) {
446 if (TTI
.isSourceOfDivergence(&Arg
)) {
447 DA
.markDivergent(Arg
);
454 bool GPUDivergenceAnalysis::isDivergent(const Value
&val
) const {
455 return DA
.isDivergent(val
);
458 bool GPUDivergenceAnalysis::isDivergentUse(const Use
&use
) const {
459 return DA
.isDivergentUse(use
);
462 void GPUDivergenceAnalysis::print(raw_ostream
&OS
, const Module
*mod
) const {
463 OS
<< "Divergence of kernel " << DA
.getFunction().getName() << " {\n";