[ARM] More MVE compare vector splat combines for ANDs
[llvm-complete.git] / lib / Analysis / DivergenceAnalysis.cpp
blob0ccd59ef2bfd46134d79a3c855125862847383d3
1 //===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a general divergence analysis for loop vectorization
10 // and GPU programs. It determines which branches and values in a loop or GPU
11 // program are divergent. It can help branch optimizations such as jump
12 // threading and loop unswitching to make better decisions.
14 // GPU programs typically use the SIMD execution model, where multiple threads
15 // in the same execution group have to execute in lock-step. Therefore, if the
16 // code contains divergent branches (i.e., threads in a group do not agree on
17 // which path of the branch to take), the group of threads has to execute all
18 // the paths from that branch with different subsets of threads enabled until
19 // they re-converge.
21 // Due to this execution model, some optimizations such as jump
22 // threading and loop unswitching can interfere with thread re-convergence.
23 // Therefore, an analysis that computes which branches in a GPU program are
24 // divergent can help the compiler to selectively run these optimizations.
26 // This implementation is derived from the Vectorization Analysis of the
27 // Region Vectorizer (RV). That implementation in turn is based on the approach
28 // described in
30 // Improving Performance of OpenCL on CPUs
31 // Ralf Karrenberg and Sebastian Hack
32 // CC '12
34 // This DivergenceAnalysis implementation is generic in the sense that it does
35 // not itself identify original sources of divergence.
36 // Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
37 // GPUDivergenceAnalysis for GPU programs) identify the sources of divergence
38 // (e.g., special variables that hold the thread ID or the iteration variable).
40 // The generic implementation propagates divergence to variables that are data
41 // or sync dependent on a source of divergence.
43 // While data dependency is a well-known concept, the notion of sync dependency
44 // is worth more explanation. Sync dependence characterizes the control flow
45 // aspect of the propagation of branch divergence. For example,
47 // %cond = icmp slt i32 %tid, 10
48 // br i1 %cond, label %then, label %else
49 // then:
50 // br label %merge
51 // else:
52 // br label %merge
53 // merge:
54 // %a = phi i32 [ 0, %then ], [ 1, %else ]
56 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
57 // because %tid is not on its use-def chains, %a is sync dependent on %tid
58 // because the branch "br i1 %cond" depends on %tid and affects which value %a
59 // is assigned to.
61 // The sync dependence detection (which branch induces divergence in which join
62 // points) is implemented in the SyncDependenceAnalysis.
64 // The current DivergenceAnalysis implementation has the following limitations:
65 // 1. intra-procedural. It conservatively considers the arguments of a
66 // non-kernel-entry function and the return value of a function call as
67 // divergent.
68 // 2. memory as black box. It conservatively considers values loaded from
69 // generic or local addresses as divergent. This can be improved by leveraging
70 // pointer analysis and/or by modelling non-escaping memory objects in SSA
71 // as done in RV.
73 //===----------------------------------------------------------------------===//
75 #include "llvm/Analysis/DivergenceAnalysis.h"
76 #include "llvm/Analysis/LoopInfo.h"
77 #include "llvm/Analysis/Passes.h"
78 #include "llvm/Analysis/PostDominators.h"
79 #include "llvm/Analysis/TargetTransformInfo.h"
80 #include "llvm/IR/Dominators.h"
81 #include "llvm/IR/InstIterator.h"
82 #include "llvm/IR/Instructions.h"
83 #include "llvm/IR/IntrinsicInst.h"
84 #include "llvm/IR/Value.h"
85 #include "llvm/Support/Debug.h"
86 #include "llvm/Support/raw_ostream.h"
87 #include <vector>
89 using namespace llvm;
91 #define DEBUG_TYPE "divergence-analysis"
93 // class DivergenceAnalysis
94 DivergenceAnalysis::DivergenceAnalysis(
95 const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
96 const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
97 : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
98 IsLCSSAForm(IsLCSSAForm) {}
100 void DivergenceAnalysis::markDivergent(const Value &DivVal) {
101 assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
102 assert(!isAlwaysUniform(DivVal) && "cannot be a divergent");
103 DivergentValues.insert(&DivVal);
106 void DivergenceAnalysis::addUniformOverride(const Value &UniVal) {
107 UniformOverrides.insert(&UniVal);
110 bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const {
111 if (Term.getNumSuccessors() <= 1)
112 return false;
113 if (auto *BranchTerm = dyn_cast<BranchInst>(&Term)) {
114 assert(BranchTerm->isConditional());
115 return isDivergent(*BranchTerm->getCondition());
117 if (auto *SwitchTerm = dyn_cast<SwitchInst>(&Term)) {
118 return isDivergent(*SwitchTerm->getCondition());
120 if (isa<InvokeInst>(Term)) {
121 return false; // ignore abnormal executions through landingpad
124 llvm_unreachable("unexpected terminator");
127 bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const {
128 // TODO function calls with side effects, etc
129 for (const auto &Op : I.operands()) {
130 if (isDivergent(*Op))
131 return true;
133 return false;
136 bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock,
137 const Value &Val) const {
138 const auto *Inst = dyn_cast<const Instruction>(&Val);
139 if (!Inst)
140 return false;
141 // check whether any divergent loop carrying Val terminates before control
142 // proceeds to ObservingBlock
143 for (const auto *Loop = LI.getLoopFor(Inst->getParent());
144 Loop != RegionLoop && !Loop->contains(&ObservingBlock);
145 Loop = Loop->getParentLoop()) {
146 if (DivergentLoops.find(Loop) != DivergentLoops.end())
147 return true;
150 return false;
153 bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const {
154 // joining divergent disjoint path in Phi parent block
155 if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) {
156 return true;
159 // An incoming value could be divergent by itself.
160 // Otherwise, an incoming value could be uniform within the loop
161 // that carries its definition but it may appear divergent
162 // from outside the loop. This happens when divergent loop exits
163 // drop definitions of that uniform value in different iterations.
165 // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop
166 // if (i % thread_id == 0) break; // divergent loop exit
167 // }
168 // int divI = i; // divI is divergent
169 for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) {
170 const auto *InVal = Phi.getIncomingValue(i);
171 if (isDivergent(*Phi.getIncomingValue(i)) ||
172 isTemporalDivergent(*Phi.getParent(), *InVal)) {
173 return true;
176 return false;
179 bool DivergenceAnalysis::inRegion(const Instruction &I) const {
180 return I.getParent() && inRegion(*I.getParent());
183 bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
184 return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB);
// Taints outside users of values defined inside the (divergent) loop headed by
// LoopHeader. Such users are marked divergent and have their users queued.
// Phi nodes on the fringe of the header's dominance region are queued for
// re-evaluation. propagateLoopDivergence calls this only when !IsLCSSAForm.
void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
  auto *DivLoop = LI.getLoopFor(&LoopHeader);
  assert(DivLoop && "loopHeader is not actually part of a loop");

  // Seed the walk with the exit blocks of the divergent loop.
  SmallVector<BasicBlock *, 8> TaintStack;
  DivLoop->getExitBlocks(TaintStack);

  // Potential users of loop-carried values could be anywhere in the dominance
  // region of DivLoop's header (including its fringes for phi nodes).
  // The exit blocks are already on the stack, so mark them visited. The header
  // is marked visited so the walk never re-enters DivLoop through it.
  DenseSet<const BasicBlock *> Visited;
  for (auto *Block : TaintStack) {
    Visited.insert(Block);
  }
  Visited.insert(&LoopHeader);

  // Depth-first walk over the blocks reachable from the loop exits.
  while (!TaintStack.empty()) {
    auto *UserBlock = TaintStack.back();
    TaintStack.pop_back();

    // don't spread divergence beyond the region
    if (!inRegion(*UserBlock))
      continue;

    // The walk starts outside DivLoop and the header is never visited, so
    // reaching a block of DivLoop here implies irreducible control flow.
    assert(!DivLoop->contains(UserBlock) &&
           "irreducible control flow detected");

    // Blocks not dominated by the header lie on the fringe of the dominance
    // region. Only their phi nodes can refer to loop-carried values (as
    // incoming values). Queue every phi for re-evaluation by compute(). Whether
    // it actually becomes divergent is decided there. The walk does not
    // continue past fringe blocks.
    if (!DT.dominates(&LoopHeader, UserBlock)) {
      for (auto &Phi : UserBlock->phis()) {
        Worklist.push_back(&Phi);
      }
      continue;
    }

    // Taint outside users of values carried by DivLoop. Any instruction that
    // is neither a uniform override nor already divergent, and that has an
    // instruction operand defined inside DivLoop, is marked divergent and its
    // users are queued.
    for (auto &I : *UserBlock) {
      if (isAlwaysUniform(I))
        continue;
      if (isDivergent(I))
        continue;

      for (auto &Op : I.operands()) {
        auto *OpInst = dyn_cast<Instruction>(&Op);
        if (!OpInst)
          continue;
        if (DivLoop->contains(OpInst->getParent())) {
          markDivergent(I);
          pushUsers(I);
          break;
        }
      }
    }

    // Visit all blocks in the dominance region: continue into every successor
    // that has not been visited yet.
    for (auto *SuccBlock : successors(UserBlock)) {
      if (!Visited.insert(SuccBlock).second) {
        continue;
      }
      TaintStack.push_back(SuccBlock);
    }
  }
}
253 void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) {
254 for (const auto &Phi : Block.phis()) {
255 if (isDivergent(Phi))
256 continue;
257 Worklist.push_back(&Phi);
261 void DivergenceAnalysis::pushUsers(const Value &V) {
262 for (const auto *User : V.users()) {
263 const auto *UserInst = dyn_cast<const Instruction>(User);
264 if (!UserInst)
265 continue;
267 if (isDivergent(*UserInst))
268 continue;
270 // only compute divergent inside loop
271 if (!inRegion(*UserInst))
272 continue;
273 Worklist.push_back(UserInst);
277 bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock,
278 const Loop *BranchLoop) {
279 LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n");
281 // ignore divergence outside the region
282 if (!inRegion(JoinBlock)) {
283 return false;
286 // push non-divergent phi nodes in JoinBlock to the worklist
287 pushPHINodes(JoinBlock);
289 // JoinBlock is a divergent loop exit
290 if (BranchLoop && !BranchLoop->contains(&JoinBlock)) {
291 return true;
294 // disjoint-paths divergent at JoinBlock
295 markBlockJoinDivergent(JoinBlock);
296 return false;
299 void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
300 LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n");
302 markDivergent(Term);
304 const auto *BranchLoop = LI.getLoopFor(Term.getParent());
306 // whether there is a divergent loop exit from BranchLoop (if any)
307 bool IsBranchLoopDivergent = false;
309 // iterate over all blocks reachable by disjoint from Term within the loop
310 // also iterates over loop exits that become divergent due to Term.
311 for (const auto *JoinBlock : SDA.join_blocks(Term)) {
312 IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
315 // Branch loop is a divergent loop due to the divergent branch in Term
316 if (IsBranchLoopDivergent) {
317 assert(BranchLoop);
318 if (!DivergentLoops.insert(BranchLoop).second) {
319 return;
321 propagateLoopDivergence(*BranchLoop);
325 void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) {
326 LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n");
328 // don't propagate beyond region
329 if (!inRegion(*ExitingLoop.getHeader()))
330 return;
332 const auto *BranchLoop = ExitingLoop.getParentLoop();
334 // Uses of loop-carried values could occur anywhere
335 // within the dominance region of the definition. All loop-carried
336 // definitions are dominated by the loop header (reducible control).
337 // Thus all users have to be in the dominance region of the loop header,
338 // except PHI nodes that can also live at the fringes of the dom region
339 // (incoming defining value).
340 if (!IsLCSSAForm)
341 taintLoopLiveOuts(*ExitingLoop.getHeader());
343 // whether there is a divergent loop exit from BranchLoop (if any)
344 bool IsBranchLoopDivergent = false;
346 // iterate over all blocks reachable by disjoint paths from exits of
347 // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn
348 // become divergent.
349 for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) {
350 IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
353 // Branch loop is a divergent due to divergent loop exit in ExitingLoop
354 if (IsBranchLoopDivergent) {
355 assert(BranchLoop);
356 if (!DivergentLoops.insert(BranchLoop).second) {
357 return;
359 propagateLoopDivergence(*BranchLoop);
363 void DivergenceAnalysis::compute() {
364 for (auto *DivVal : DivergentValues) {
365 pushUsers(*DivVal);
368 // propagate divergence
369 while (!Worklist.empty()) {
370 const Instruction &I = *Worklist.back();
371 Worklist.pop_back();
373 // maintain uniformity of overrides
374 if (isAlwaysUniform(I))
375 continue;
377 bool WasDivergent = isDivergent(I);
378 if (WasDivergent)
379 continue;
381 // propagate divergence caused by terminator
382 if (I.isTerminator()) {
383 if (updateTerminator(I)) {
384 // propagate control divergence to affected instructions
385 propagateBranchDivergence(I);
386 continue;
390 // update divergence of I due to divergent operands
391 bool DivergentUpd = false;
392 const auto *Phi = dyn_cast<const PHINode>(&I);
393 if (Phi) {
394 DivergentUpd = updatePHINode(*Phi);
395 } else {
396 DivergentUpd = updateNormalInstruction(I);
399 // propagate value divergence to users
400 if (DivergentUpd) {
401 markDivergent(I);
402 pushUsers(I);
407 bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const {
408 return UniformOverrides.find(&V) != UniformOverrides.end();
411 bool DivergenceAnalysis::isDivergent(const Value &V) const {
412 return DivergentValues.find(&V) != DivergentValues.end();
415 void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
416 if (DivergentValues.empty())
417 return;
418 // iterate instructions using instructions() to ensure a deterministic order.
419 for (auto &I : instructions(F)) {
420 if (isDivergent(I))
421 OS << "DIVERGENT:" << I << '\n';
425 // class GPUDivergenceAnalysis
426 GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F,
427 const DominatorTree &DT,
428 const PostDominatorTree &PDT,
429 const LoopInfo &LI,
430 const TargetTransformInfo &TTI)
431 : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) {
432 for (auto &I : instructions(F)) {
433 if (TTI.isSourceOfDivergence(&I)) {
434 DA.markDivergent(I);
435 } else if (TTI.isAlwaysUniform(&I)) {
436 DA.addUniformOverride(I);
439 for (auto &Arg : F.args()) {
440 if (TTI.isSourceOfDivergence(&Arg)) {
441 DA.markDivergent(Arg);
445 DA.compute();
448 bool GPUDivergenceAnalysis::isDivergent(const Value &val) const {
449 return DA.isDivergent(val);
452 void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const {
453 OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n";
454 DA.print(OS, mod);
455 OS << "}\n";