1 //===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //===----------------------------------------------------------------------===//
10 // This file implements an algorithm that returns for a divergent branch
11 // the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch, and the loop exits that have a reaching path that is disjoint
// from a path to the loop latch.
16 // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
17 // control-induced divergence in phi nodes.
20 // The SyncDependenceAnalysis lazily computes sync dependences [3].
21 // The analysis evaluates the disjoint path criterion [2] by a reduction
22 // to SSA construction. The SSA construction algorithm is implemented as
23 // a simple data-flow analysis [1].
// [1] "A Simple, Fast Dominance Algorithm", SPE '01, Cooper, Harvey and Kennedy
26 // [2] "Efficiently Computing Static Single Assignment Form
27 // and the Control Dependence Graph", TOPLAS '91,
28 // Cytron, Ferrante, Rosen, Wegman and Zadeck
29 // [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
30 // [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
32 // -- Sync dependence --
33 // Sync dependence [4] characterizes the control flow aspect of the
34 // propagation of branch divergence. For example,
36 // %cond = icmp slt i32 %tid, 10
37 // br i1 %cond, label %then, label %else
43 // %a = phi i32 [ 0, %then ], [ 1, %else ]
45 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
46 // because %tid is not on its use-def chains, %a is sync dependent on %tid
47 // because the branch "br i1 %cond" depends on %tid and affects which value %a
50 // -- Reduction to SSA construction --
51 // There are two disjoint paths from A to X, if a certain variant of SSA
52 // construction places a phi node in X under the following set-up scheme [2].
54 // This variant of SSA construction ignores incoming undef values.
55 // That is paths from the entry without a definition do not result in
67 // Assume that A contains a divergent branch. We are interested
68 // in the set of all blocks where each block is reachable from A
69 // via two disjoint paths. This would be the set {D, F} in this
71 // To generally reduce this query to SSA construction we introduce
72 // a virtual variable x and assign to x different values in each
73 // successor block of A.
83 // Our flavor of SSA construction for x will construct the following
93 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
97 // In case of loop exits we need to check the disjoint path criterion for loops
98 // [2]. To this end, we check whether the definition of x differs between the
99 // loop exit and the loop header (_after_ SSA construction).
101 //===----------------------------------------------------------------------===//
102 #include "llvm/ADT/PostOrderIterator.h"
103 #include "llvm/ADT/SmallPtrSet.h"
104 #include "llvm/Analysis/PostDominators.h"
105 #include "llvm/Analysis/SyncDependenceAnalysis.h"
106 #include "llvm/IR/BasicBlock.h"
107 #include "llvm/IR/CFG.h"
108 #include "llvm/IR/Dominators.h"
109 #include "llvm/IR/Function.h"
112 #include <unordered_set>
114 #define DEBUG_TYPE "sync-dependence"
118 ConstBlockSet
SyncDependenceAnalysis::EmptyBlockSet
;
120 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree
&DT
,
121 const PostDominatorTree
&PDT
,
123 : FuncRPOT(DT
.getRoot()->getParent()), DT(DT
), PDT(PDT
), LI(LI
) {}
125 SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
// Reverse post-order traversal over the basic blocks of a function;
// DivergencePropagator walks blocks in this order.
using FunctionRPOT = ReversePostOrderTraversal<const Function *>;
129 // divergence propagator for reducible CFGs
130 struct DivergencePropagator
{
131 const FunctionRPOT
&FuncRPOT
;
132 const DominatorTree
&DT
;
133 const PostDominatorTree
&PDT
;
136 // identified join points
137 std::unique_ptr
<ConstBlockSet
> JoinBlocks
;
139 // reached loop exits (by a path disjoint to a path to the loop header)
140 SmallPtrSet
<const BasicBlock
*, 4> ReachedLoopExits
;
142 // if DefMap[B] == C then C is the dominating definition at block B
143 // if DefMap[B] ~ undef then we haven't seen B yet
144 // if DefMap[B] == B then B is a join point of disjoint paths from X or B is
145 // an immediate successor of X (initial value).
146 using DefiningBlockMap
= std::map
<const BasicBlock
*, const BasicBlock
*>;
147 DefiningBlockMap DefMap
;
149 // all blocks with pending visits
150 std::unordered_set
<const BasicBlock
*> PendingUpdates
;
152 DivergencePropagator(const FunctionRPOT
&FuncRPOT
, const DominatorTree
&DT
,
153 const PostDominatorTree
&PDT
, const LoopInfo
&LI
)
154 : FuncRPOT(FuncRPOT
), DT(DT
), PDT(PDT
), LI(LI
),
155 JoinBlocks(new ConstBlockSet
) {}
157 // set the definition at @block and mark @block as pending for a visit
158 void addPending(const BasicBlock
&Block
, const BasicBlock
&DefBlock
) {
159 bool WasAdded
= DefMap
.emplace(&Block
, &DefBlock
).second
;
161 PendingUpdates
.insert(&Block
);
164 void printDefs(raw_ostream
&Out
) {
165 Out
<< "Propagator::DefMap {\n";
166 for (const auto *Block
: FuncRPOT
) {
167 auto It
= DefMap
.find(Block
);
168 Out
<< Block
->getName() << " : ";
169 if (It
== DefMap
.end()) {
172 const auto *DefBlock
= It
->second
;
173 Out
<< (DefBlock
? DefBlock
->getName() : "<null>") << "\n";
179 // process @succBlock with reaching definition @defBlock
180 // the original divergent branch was in @parentLoop (if any)
181 void visitSuccessor(const BasicBlock
&SuccBlock
, const Loop
*ParentLoop
,
182 const BasicBlock
&DefBlock
) {
184 // @succBlock is a loop exit
185 if (ParentLoop
&& !ParentLoop
->contains(&SuccBlock
)) {
186 DefMap
.emplace(&SuccBlock
, &DefBlock
);
187 ReachedLoopExits
.insert(&SuccBlock
);
191 // first reaching def?
192 auto ItLastDef
= DefMap
.find(&SuccBlock
);
193 if (ItLastDef
== DefMap
.end()) {
194 addPending(SuccBlock
, DefBlock
);
198 // a join of at least two definitions
199 if (ItLastDef
->second
!= &DefBlock
) {
200 // do we know this join already?
201 if (!JoinBlocks
->insert(&SuccBlock
).second
)
204 // update the definition
205 addPending(SuccBlock
, SuccBlock
);
209 // find all blocks reachable by two disjoint paths from @rootTerm.
210 // This method works for both divergent terminators and loops with
212 // @rootBlock is either the block containing the branch or the header of the
214 // @nodeSuccessors is the set of successors of the node (Loop or Terminator)
215 // headed by @rootBlock.
216 // @parentLoop is the parent loop of the Loop or the loop that contains the
218 template <typename SuccessorIterable
>
219 std::unique_ptr
<ConstBlockSet
>
220 computeJoinPoints(const BasicBlock
&RootBlock
,
221 SuccessorIterable NodeSuccessors
, const Loop
*ParentLoop
) {
224 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " << (ParentLoop
? ParentLoop
->getName() : "<null>") << "\n" );
226 // bootstrap with branch targets
227 for (const auto *SuccBlock
: NodeSuccessors
) {
228 DefMap
.emplace(SuccBlock
, SuccBlock
);
230 if (ParentLoop
&& !ParentLoop
->contains(SuccBlock
)) {
231 // immediate loop exit from node.
232 ReachedLoopExits
.insert(SuccBlock
);
235 PendingUpdates
.insert(SuccBlock
);
240 dbgs() << "SDA: rpo order:\n";
241 for (const auto * RpoBlock
: FuncRPOT
) {
242 dbgs() << "- " << RpoBlock
->getName() << "\n";
246 auto ItBeginRPO
= FuncRPOT
.begin();
248 // skip until term (TODO RPOT won't let us start at @term directly)
249 for (; *ItBeginRPO
!= &RootBlock
; ++ItBeginRPO
) {}
251 auto ItEndRPO
= FuncRPOT
.end();
252 assert(ItBeginRPO
!= ItEndRPO
);
254 // propagate definitions at the immediate successors of the node in RPO
255 auto ItBlockRPO
= ItBeginRPO
;
256 while ((++ItBlockRPO
!= ItEndRPO
) &&
257 !PendingUpdates
.empty()) {
258 const auto *Block
= *ItBlockRPO
;
259 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block
->getName() << "\n");
261 // skip Block if not pending update
262 auto ItPending
= PendingUpdates
.find(Block
);
263 if (ItPending
== PendingUpdates
.end())
265 PendingUpdates
.erase(ItPending
);
267 // propagate definition at Block to its successors
268 auto ItDef
= DefMap
.find(Block
);
269 const auto *DefBlock
= ItDef
->second
;
272 auto *BlockLoop
= LI
.getLoopFor(Block
);
274 (ParentLoop
!= BlockLoop
&& ParentLoop
->contains(BlockLoop
))) {
275 // if the successor is the header of a nested loop pretend its a
276 // single node with the loop's exits as successors
277 SmallVector
<BasicBlock
*, 4> BlockLoopExits
;
278 BlockLoop
->getExitBlocks(BlockLoopExits
);
279 for (const auto *BlockLoopExit
: BlockLoopExits
) {
280 visitSuccessor(*BlockLoopExit
, ParentLoop
, *DefBlock
);
284 // the successors are either on the same loop level or loop exits
285 for (const auto *SuccBlock
: successors(Block
)) {
286 visitSuccessor(*SuccBlock
, ParentLoop
, *DefBlock
);
291 LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
293 // We need to know the definition at the parent loop header to decide
294 // whether the definition at the header is different from the definition at
295 // the loop exits, which would indicate a divergent loop exits.
299 // B // nested loop header
301 // C -> X (exit from B loop) -..-> (A latch)
303 // D -> back to B (B latch)
305 // proper exit from both loops
307 // analyze reached loop exits
308 if (!ReachedLoopExits
.empty()) {
309 const BasicBlock
*ParentLoopHeader
=
310 ParentLoop
? ParentLoop
->getHeader() : nullptr;
313 auto ItHeaderDef
= DefMap
.find(ParentLoopHeader
);
314 const auto *HeaderDefBlock
= (ItHeaderDef
== DefMap
.end()) ? nullptr : ItHeaderDef
->second
;
316 LLVM_DEBUG(printDefs(dbgs()));
317 assert(HeaderDefBlock
&& "no definition at header of carrying loop");
319 for (const auto *ExitBlock
: ReachedLoopExits
) {
320 auto ItExitDef
= DefMap
.find(ExitBlock
);
321 assert((ItExitDef
!= DefMap
.end()) &&
322 "no reaching def at reachable loop exit");
323 if (ItExitDef
->second
!= HeaderDefBlock
) {
324 JoinBlocks
->insert(ExitBlock
);
329 return std::move(JoinBlocks
);
333 const ConstBlockSet
&SyncDependenceAnalysis::join_blocks(const Loop
&Loop
) {
334 using LoopExitVec
= SmallVector
<BasicBlock
*, 4>;
335 LoopExitVec LoopExits
;
336 Loop
.getExitBlocks(LoopExits
);
337 if (LoopExits
.size() < 1) {
338 return EmptyBlockSet
;
341 // already available in cache?
342 auto ItCached
= CachedLoopExitJoins
.find(&Loop
);
343 if (ItCached
!= CachedLoopExitJoins
.end()) {
344 return *ItCached
->second
;
347 // compute all join points
348 DivergencePropagator Propagator
{FuncRPOT
, DT
, PDT
, LI
};
349 auto JoinBlocks
= Propagator
.computeJoinPoints
<const LoopExitVec
&>(
350 *Loop
.getHeader(), LoopExits
, Loop
.getParentLoop());
352 auto ItInserted
= CachedLoopExitJoins
.emplace(&Loop
, std::move(JoinBlocks
));
353 assert(ItInserted
.second
);
354 return *ItInserted
.first
->second
;
357 const ConstBlockSet
&
358 SyncDependenceAnalysis::join_blocks(const Instruction
&Term
) {
360 if (Term
.getNumSuccessors() < 1) {
361 return EmptyBlockSet
;
364 // already available in cache?
365 auto ItCached
= CachedBranchJoins
.find(&Term
);
366 if (ItCached
!= CachedBranchJoins
.end())
367 return *ItCached
->second
;
369 // compute all join points
370 DivergencePropagator Propagator
{FuncRPOT
, DT
, PDT
, LI
};
371 const auto &TermBlock
= *Term
.getParent();
372 auto JoinBlocks
= Propagator
.computeJoinPoints
<succ_const_range
>(
373 TermBlock
, successors(Term
.getParent()), LI
.getLoopFor(&TermBlock
));
375 auto ItInserted
= CachedBranchJoins
.emplace(&Term
, std::move(JoinBlocks
));
376 assert(ItInserted
.second
);
377 return *ItInserted
.first
->second
;