1 //===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a general divergence analysis for loop vectorization
10 // and GPU programs. It determines which branches and values in a loop or GPU
11 // program are divergent. It can help branch optimizations such as jump
12 // threading and loop unswitching to make better decisions.
14 // GPU programs typically use the SIMD execution model, where multiple threads
15 // in the same execution group have to execute in lock-step. Therefore, if the
16 // code contains divergent branches (i.e., threads in a group do not agree on
17 // which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
21 // Due to this execution model, some optimizations such as jump
22 // threading and loop unswitching can interfere with thread re-convergence.
23 // Therefore, an analysis that computes which branches in a GPU program are
24 // divergent can help the compiler to selectively run these optimizations.
26 // This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). That implementation in turn is based on the approach
// described in
30 // Improving Performance of OpenCL on CPUs
// Ralf Karrenberg and Sebastian Hack
// CC '12
34 // This implementation is generic in the sense that it does
35 // not itself identify original sources of divergence.
36 // Instead specialized adapter classes, (LoopDivergenceAnalysis) for loops and
37 // (DivergenceAnalysis) for functions, identify the sources of divergence
38 // (e.g., special variables that hold the thread ID or the iteration variable).
40 // The generic implementation propagates divergence to variables that are data
41 // or sync dependent on a source of divergence.
43 // While data dependency is a well-known concept, the notion of sync dependency
44 // is worth more explanation. Sync dependence characterizes the control flow
45 // aspect of the propagation of branch divergence. For example,
47 // %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
56 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
57 // because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
61 // The sync dependence detection (which branch induces divergence in which join
62 // points) is implemented in the SyncDependenceAnalysis.
64 // The current implementation has the following limitations:
65 // 1. intra-procedural. It conservatively considers the arguments of a
// non-kernel-entry function and the return value of a function call as
// divergent.
68 // 2. memory as black box. It conservatively considers values loaded from
69 // generic or local address as divergent. This can be improved by leveraging
// pointer analysis and/or by modelling non-escaping memory objects in SSA
// as done in RV.
73 //===----------------------------------------------------------------------===//
75 #include "llvm/Analysis/DivergenceAnalysis.h"
76 #include "llvm/Analysis/CFG.h"
77 #include "llvm/Analysis/LoopInfo.h"
78 #include "llvm/Analysis/Passes.h"
79 #include "llvm/Analysis/PostDominators.h"
80 #include "llvm/Analysis/TargetTransformInfo.h"
81 #include "llvm/IR/Dominators.h"
82 #include "llvm/IR/InstIterator.h"
83 #include "llvm/IR/Instructions.h"
84 #include "llvm/IR/IntrinsicInst.h"
85 #include "llvm/IR/Value.h"
86 #include "llvm/Support/Debug.h"
87 #include "llvm/Support/raw_ostream.h"
91 #define DEBUG_TYPE "divergence"
93 DivergenceAnalysisImpl::DivergenceAnalysisImpl(
94 const Function
&F
, const Loop
*RegionLoop
, const DominatorTree
&DT
,
95 const LoopInfo
&LI
, SyncDependenceAnalysis
&SDA
, bool IsLCSSAForm
)
96 : F(F
), RegionLoop(RegionLoop
), DT(DT
), LI(LI
), SDA(SDA
),
97 IsLCSSAForm(IsLCSSAForm
) {}
99 bool DivergenceAnalysisImpl::markDivergent(const Value
&DivVal
) {
100 if (isAlwaysUniform(DivVal
))
102 assert(isa
<Instruction
>(DivVal
) || isa
<Argument
>(DivVal
));
103 assert(!isAlwaysUniform(DivVal
) && "cannot be a divergent");
104 return DivergentValues
.insert(&DivVal
).second
;
107 void DivergenceAnalysisImpl::addUniformOverride(const Value
&UniVal
) {
108 UniformOverrides
.insert(&UniVal
);
111 bool DivergenceAnalysisImpl::isTemporalDivergent(
112 const BasicBlock
&ObservingBlock
, const Value
&Val
) const {
113 const auto *Inst
= dyn_cast
<const Instruction
>(&Val
);
116 // check whether any divergent loop carrying Val terminates before control
117 // proceeds to ObservingBlock
118 for (const auto *Loop
= LI
.getLoopFor(Inst
->getParent());
119 Loop
!= RegionLoop
&& !Loop
->contains(&ObservingBlock
);
120 Loop
= Loop
->getParentLoop()) {
121 if (DivergentLoops
.contains(Loop
))
128 bool DivergenceAnalysisImpl::inRegion(const Instruction
&I
) const {
129 return I
.getParent() && inRegion(*I
.getParent());
132 bool DivergenceAnalysisImpl::inRegion(const BasicBlock
&BB
) const {
133 return (!RegionLoop
&& BB
.getParent() == &F
) || RegionLoop
->contains(&BB
);
136 void DivergenceAnalysisImpl::pushUsers(const Value
&V
) {
137 const auto *I
= dyn_cast
<const Instruction
>(&V
);
139 if (I
&& I
->isTerminator()) {
140 analyzeControlDivergence(*I
);
144 for (const auto *User
: V
.users()) {
145 const auto *UserInst
= dyn_cast
<const Instruction
>(User
);
149 // only compute divergent inside loop
150 if (!inRegion(*UserInst
))
153 // All users of divergent values are immediate divergent
154 if (markDivergent(*UserInst
))
155 Worklist
.push_back(UserInst
);
159 static const Instruction
*getIfCarriedInstruction(const Use
&U
,
160 const Loop
&DivLoop
) {
161 const auto *I
= dyn_cast
<const Instruction
>(&U
);
164 if (!DivLoop
.contains(I
))
169 void DivergenceAnalysisImpl::analyzeTemporalDivergence(
170 const Instruction
&I
, const Loop
&OuterDivLoop
) {
171 if (isAlwaysUniform(I
))
176 LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I
.getName() << "\n");
177 assert((isa
<PHINode
>(I
) || !IsLCSSAForm
) &&
178 "In LCSSA form all users of loop-exiting defs are Phi nodes.");
179 for (const Use
&Op
: I
.operands()) {
180 const auto *OpInst
= getIfCarriedInstruction(Op
, OuterDivLoop
);
183 if (markDivergent(I
))
189 // marks all users of loop-carried values of the loop headed by LoopHeader as
191 void DivergenceAnalysisImpl::analyzeLoopExitDivergence(
192 const BasicBlock
&DivExit
, const Loop
&OuterDivLoop
) {
193 // All users are in immediate exit blocks
195 for (const auto &Phi
: DivExit
.phis()) {
196 analyzeTemporalDivergence(Phi
, OuterDivLoop
);
201 // For non-LCSSA we have to follow all live out edges wherever they may lead.
202 const BasicBlock
&LoopHeader
= *OuterDivLoop
.getHeader();
203 SmallVector
<const BasicBlock
*, 8> TaintStack
;
204 TaintStack
.push_back(&DivExit
);
206 // Otherwise potential users of loop-carried values could be anywhere in the
207 // dominance region of DivLoop (including its fringes for phi nodes)
208 DenseSet
<const BasicBlock
*> Visited
;
209 Visited
.insert(&DivExit
);
212 auto *UserBlock
= TaintStack
.pop_back_val();
214 // don't spread divergence beyond the region
215 if (!inRegion(*UserBlock
))
218 assert(!OuterDivLoop
.contains(UserBlock
) &&
219 "irreducible control flow detected");
221 // phi nodes at the fringes of the dominance region
222 if (!DT
.dominates(&LoopHeader
, UserBlock
)) {
223 // all PHI nodes of UserBlock become divergent
224 for (auto &Phi
: UserBlock
->phis()) {
225 analyzeTemporalDivergence(Phi
, OuterDivLoop
);
230 // Taint outside users of values carried by OuterDivLoop.
231 for (auto &I
: *UserBlock
) {
232 analyzeTemporalDivergence(I
, OuterDivLoop
);
235 // visit all blocks in the dominance region
236 for (auto *SuccBlock
: successors(UserBlock
)) {
237 if (!Visited
.insert(SuccBlock
).second
) {
240 TaintStack
.push_back(SuccBlock
);
242 } while (!TaintStack
.empty());
245 void DivergenceAnalysisImpl::propagateLoopExitDivergence(
246 const BasicBlock
&DivExit
, const Loop
&InnerDivLoop
) {
247 LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit
.getName() << "\n");
249 // Find outer-most loop that does not contain \p DivExit
250 const Loop
*DivLoop
= &InnerDivLoop
;
251 const Loop
*OuterDivLoop
= DivLoop
;
252 const Loop
*ExitLevelLoop
= LI
.getLoopFor(&DivExit
);
253 const unsigned LoopExitDepth
=
254 ExitLevelLoop
? ExitLevelLoop
->getLoopDepth() : 0;
255 while (DivLoop
&& DivLoop
->getLoopDepth() > LoopExitDepth
) {
256 DivergentLoops
.insert(DivLoop
); // all crossed loops are divergent
257 OuterDivLoop
= DivLoop
;
258 DivLoop
= DivLoop
->getParentLoop();
260 LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop
->getName()
263 analyzeLoopExitDivergence(DivExit
, *OuterDivLoop
);
266 // this is a divergent join point - mark all phi nodes as divergent and push
267 // them onto the stack.
268 void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock
&JoinBlock
) {
269 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock
.getName()
272 // ignore divergence outside the region
273 if (!inRegion(JoinBlock
)) {
277 // push non-divergent phi nodes in JoinBlock to the worklist
278 for (const auto &Phi
: JoinBlock
.phis()) {
279 if (isDivergent(Phi
))
281 // FIXME Theoretically ,the 'undef' value could be replaced by any other
282 // value causing spurious divergence.
283 if (Phi
.hasConstantOrUndefValue())
285 if (markDivergent(Phi
))
286 Worklist
.push_back(&Phi
);
290 void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction
&Term
) {
291 LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term
.getParent()->getName()
294 // Don't propagate divergence from unreachable blocks.
295 if (!DT
.isReachableFromEntry(Term
.getParent()))
298 const auto *BranchLoop
= LI
.getLoopFor(Term
.getParent());
300 const auto &DivDesc
= SDA
.getJoinBlocks(Term
);
302 // Iterate over all blocks now reachable by a disjoint path join
303 for (const auto *JoinBlock
: DivDesc
.JoinDivBlocks
) {
304 taintAndPushPhiNodes(*JoinBlock
);
307 assert(DivDesc
.LoopDivBlocks
.empty() || BranchLoop
);
308 for (const auto *DivExitBlock
: DivDesc
.LoopDivBlocks
) {
309 propagateLoopExitDivergence(*DivExitBlock
, *BranchLoop
);
313 void DivergenceAnalysisImpl::compute() {
314 // Initialize worklist.
315 auto DivValuesCopy
= DivergentValues
;
316 for (const auto *DivVal
: DivValuesCopy
) {
317 assert(isDivergent(*DivVal
) && "Worklist invariant violated!");
321 // All values on the Worklist are divergent.
322 // Their users may not have been updated yed.
323 while (!Worklist
.empty()) {
324 const Instruction
&I
= *Worklist
.back();
327 // propagate value divergence to users
328 assert(isDivergent(I
) && "Worklist invariant violated!");
333 bool DivergenceAnalysisImpl::isAlwaysUniform(const Value
&V
) const {
334 return UniformOverrides
.contains(&V
);
337 bool DivergenceAnalysisImpl::isDivergent(const Value
&V
) const {
338 return DivergentValues
.contains(&V
);
341 bool DivergenceAnalysisImpl::isDivergentUse(const Use
&U
) const {
343 Instruction
&I
= *cast
<Instruction
>(U
.getUser());
344 return isDivergent(V
) || isTemporalDivergent(*I
.getParent(), V
);
347 DivergenceInfo::DivergenceInfo(Function
&F
, const DominatorTree
&DT
,
348 const PostDominatorTree
&PDT
, const LoopInfo
&LI
,
349 const TargetTransformInfo
&TTI
,
351 : F(F
), ContainsIrreducible(false) {
352 if (!KnownReducible
) {
353 using RPOTraversal
= ReversePostOrderTraversal
<const Function
*>;
354 RPOTraversal
FuncRPOT(&F
);
355 if (containsIrreducibleCFG
<const BasicBlock
*, const RPOTraversal
,
356 const LoopInfo
>(FuncRPOT
, LI
)) {
357 ContainsIrreducible
= true;
361 SDA
= std::make_unique
<SyncDependenceAnalysis
>(DT
, PDT
, LI
);
362 DA
= std::make_unique
<DivergenceAnalysisImpl
>(F
, nullptr, DT
, LI
, *SDA
,
364 for (auto &I
: instructions(F
)) {
365 if (TTI
.isSourceOfDivergence(&I
)) {
366 DA
->markDivergent(I
);
367 } else if (TTI
.isAlwaysUniform(&I
)) {
368 DA
->addUniformOverride(I
);
371 for (auto &Arg
: F
.args()) {
372 if (TTI
.isSourceOfDivergence(&Arg
)) {
373 DA
->markDivergent(Arg
);
380 AnalysisKey
DivergenceAnalysis::Key
;
382 DivergenceAnalysis::Result
383 DivergenceAnalysis::run(Function
&F
, FunctionAnalysisManager
&AM
) {
384 auto &DT
= AM
.getResult
<DominatorTreeAnalysis
>(F
);
385 auto &PDT
= AM
.getResult
<PostDominatorTreeAnalysis
>(F
);
386 auto &LI
= AM
.getResult
<LoopAnalysis
>(F
);
387 auto &TTI
= AM
.getResult
<TargetIRAnalysis
>(F
);
389 return DivergenceInfo(F
, DT
, PDT
, LI
, TTI
, /* KnownReducible = */ false);
// Printer pass for DivergenceAnalysis.
// - Writes a header naming the function.
// - If the analysis found any divergence, writes a marker ("DIVERGENT: " or
//   padding) for each argument and for each non-debug instruction, with the
//   instructions grouped under their basic block names.
// - Preserves all analyses.
// NOTE(review): the statements that print each argument/instruction after
// its marker are not visible in this view — confirm against the full file.
393 DivergenceAnalysisPrinterPass::run(Function
&F
, FunctionAnalysisManager
&FAM
) {
394 auto &DI
= FAM
.getResult
<DivergenceAnalysis
>(F
);
395 OS
<< "'Divergence Analysis' for function '" << F
.getName() << "':\n";
// Per-value output is only produced when the function has any divergence.
396 if (DI
.hasDivergence()) {
// Arguments first: prefix each with its divergence marker.
397 for (auto &Arg
: F
.args()) {
398 OS
<< (DI
.isDivergent(Arg
) ? "DIVERGENT: " : " ");
// Then one section per basic block: the block name, followed by its
// instructions (debug instructions are skipped).
401 for (const BasicBlock
&BB
: F
) {
402 OS
<< "\n " << BB
.getName() << ":\n";
403 for (auto &I
: BB
.instructionsWithoutDebug()) {
404 OS
<< (DI
.isDivergent(I
) ? "DIVERGENT: " : " ");
// Printing does not modify the IR.
409 return PreservedAnalyses::all();