1 //===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements an algorithm that returns for a divergent branch
10 // the set of basic blocks whose phi nodes become divergent due to divergent
11 // control. These are the blocks that are reachable by two disjoint paths from
12 // the branch or loop exits that have a reaching path that is disjoint from a
13 // path to the loop latch.
15 // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
16 // control-induced divergence in phi nodes.
19 // The SyncDependenceAnalysis lazily computes sync dependences [3].
20 // The analysis evaluates the disjoint path criterion [2] by a reduction
21 // to SSA construction. The SSA construction algorithm is implemented as
22 // a simple data-flow analysis [1].
24 // [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
25 // [2] "Efficiently Computing Static Single Assignment Form
26 // and the Control Dependence Graph", TOPLAS '91,
27 // Cytron, Ferrante, Rosen, Wegman and Zadeck
28 // [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
29 // [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
31 // -- Sync dependence --
32 // Sync dependence [4] characterizes the control flow aspect of the
33 // propagation of branch divergence. For example,
35 // %cond = icmp slt i32 %tid, 10
36 // br i1 %cond, label %then, label %else
42 // %a = phi i32 [ 0, %then ], [ 1, %else ]
44 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
45 // because %tid is not on its use-def chains, %a is sync dependent on %tid
46 // because the branch "br i1 %cond" depends on %tid and affects which value %a
49 // -- Reduction to SSA construction --
50 // There are two disjoint paths from A to X, if a certain variant of SSA
51 // construction places a phi node in X under the following set-up scheme [2].
53 // This variant of SSA construction ignores incoming undef values.
54 // That is paths from the entry without a definition do not result in
66 // Assume that A contains a divergent branch. We are interested
67 // in the set of all blocks where each block is reachable from A
68 // via two disjoint paths. This would be the set {D, F} in this
70 // To generally reduce this query to SSA construction we introduce
71 // a virtual variable x and assign to x different values in each
72 // successor block of A.
82 // Our flavor of SSA construction for x will construct the following
92 // The blocks D and F contain phi nodes and are thus each reachable
93 // by two disjoint paths from A.
96 // In case of loop exits we need to check the disjoint path criterion for loops
97 // [2]. To this end, we check whether the definition of x differs between the
98 // loop exit and the loop header (_after_ SSA construction).
100 //===----------------------------------------------------------------------===//
101 #include "llvm/Analysis/SyncDependenceAnalysis.h"
102 #include "llvm/ADT/PostOrderIterator.h"
103 #include "llvm/ADT/SmallPtrSet.h"
104 #include "llvm/Analysis/PostDominators.h"
105 #include "llvm/IR/BasicBlock.h"
106 #include "llvm/IR/CFG.h"
107 #include "llvm/IR/Dominators.h"
108 #include "llvm/IR/Function.h"
110 #include <functional>
112 #include <unordered_set>
114 #define DEBUG_TYPE "sync-dependence"
116 // The SDA algorithm operates on a modified CFG - we modify the edges leaving
117 // loop headers as follows:
119 // * We remove all edges leaving all loop headers.
120 // * We add additional edges from the loop headers to their exit blocks.
122 // The modification is virtual, that is whenever we visit a loop header we
123 // pretend it had different successors.
125 using namespace llvm
;
127 // Custom Post-Order Traversal
129 // We cannot use the vanilla (R)PO computation of LLVM because:
130 // * We (virtually) modify the CFG.
131 // * We want a loop-compact block enumeration, that is the numbers assigned by
132 // the traversal to the blocks of a loop are an interval.
// Callback type that receives a basic block by const reference. The
// traversal routines below take one as their CallBack parameter.
133 using POCB
= std::function
<void(const BasicBlock
&)>;
// Ordered set of block pointers; used below for the Finalized set that the
// traversal routines share.
134 using VisitedSet
= std::set
<const BasicBlock
*>;
// Vector of block pointers; used below as the traversal work stack
// (back() / push_back()).
135 using BlockStack
= std::vector
<const BasicBlock
*>;
// Forward declaration: computeStackPO calls computeLoopPO before its
// definition appears further down in this file.
138 static void computeLoopPO(const LoopInfo
&LI
, Loop
&Loop
, POCB CallBack
,
139 VisitedSet
&Finalized
);
// Drains the work list \p Stack for the region described by \p Loop.
// computeTopLevelPO passes nullptr for \p Loop; computeLoopPO passes the loop
// whose body is being traversed.
//
// \param Stack     work list of blocks; the block at the back is inspected.
// \param LI        used to find the innermost loop of each inspected block.
// \param Loop      enclosing loop of the region, may be null.
// \param CallBack  forwarded to computeLoopPO for nested loops.
// \param Finalized blocks that were already handled; candidates found in it
//                  are tested with count() before being pushed.
//
// NOTE(review): the embedded line numbers skip inside this function, so some
// statements (guards, continue statements, closing braces) are not visible in
// this listing. The comments below only describe what the visible lines do.
141 // for a nested region (top-level loop or nested loop)
142 static void computeStackPO(BlockStack
&Stack
, const LoopInfo
&LI
, Loop
*Loop
,
143 POCB CallBack
, VisitedSet
&Finalized
) {
// Header of the enclosing loop, or nullptr when there is no enclosing loop.
144 const auto *LoopHeader
= Loop
? Loop
->getHeader() : nullptr;
145 while (!Stack
.empty()) {
146 const auto *NextBB
= Stack
.back();
// The block belongs to a nested loop when its innermost loop, as reported by
// LI, is not the loop of the current region.
148 auto *NestedLoop
= LI
.getLoopFor(NextBB
);
149 bool IsNestedLoop
= NestedLoop
!= Loop
;
151 // Treat the loop as a node
153 SmallVector
<BasicBlock
*, 3> NestedExits
;
154 NestedLoop
->getUniqueExitBlocks(NestedExits
);
155 bool PushedNodes
= false;
// Candidate exits are tested against the region header, region membership
// and the Finalized set before being pushed.
156 for (const auto *NestedExitBB
: NestedExits
) {
157 if (NestedExitBB
== LoopHeader
)
159 if (Loop
&& !Loop
->contains(NestedExitBB
))
161 if (Finalized
.count(NestedExitBB
))
164 Stack
.push_back(NestedExitBB
);
167 // All loop exits finalized -> finish this node
169 computeLoopPO(LI
, *NestedLoop
, CallBack
, Finalized
);
// Same three tests as above, applied to the CFG successors of NextBB.
175 bool PushedNodes
= false;
176 for (const auto *SuccBB
: successors(NextBB
)) {
177 if (SuccBB
== LoopHeader
)
179 if (Loop
&& !Loop
->contains(SuccBB
))
181 if (Finalized
.count(SuccBB
))
184 Stack
.push_back(SuccBB
);
187 // Never push nodes twice
// insert().second is false when NextBB was already in Finalized.
189 if (!Finalized
.insert(NextBB
).second
)
// Function-level entry point of the traversal: seeds the work stack with the
// entry block of \p F and runs computeStackPO with a null enclosing loop and
// a fresh Finalized set. \p CallBack is forwarded to computeStackPO.
//
// NOTE(review): the declaration of Stack is on a line not visible in this
// listing.
196 static void computeTopLevelPO(Function
&F
, const LoopInfo
&LI
, POCB CallBack
) {
197 VisitedSet Finalized
;
199 Stack
.reserve(24); // FIXME made-up number
200 Stack
.push_back(&F
.getEntryBlock());
201 computeStackPO(Stack
, LI
, nullptr, CallBack
, Finalized
);
// Traverses the body of \p Loop. The loop header is inserted into
// \p Finalized and reported to \p CallBack first; afterwards a local work
// stack is handed to computeStackPO with \p Loop as the enclosing region.
//
// \param LI        forwarded to computeStackPO.
// \param Loop      the loop to traverse.
// \param CallBack  invoked on the header here and forwarded to
//                  computeStackPO.
// \param Finalized shared set of already handled blocks.
//
// NOTE(review): the statements executed for each header successor (after the
// two visible tests) are not visible in this listing.
204 static void computeLoopPO(const LoopInfo
&LI
, Loop
&Loop
, POCB CallBack
,
205 VisitedSet
&Finalized
) {
206 /// Call CallBack on all loop blocks.
207 std::vector
<const BasicBlock
*> Stack
;
208 const auto *LoopHeader
= Loop
.getHeader();
210 // Visit the header last
// The header is finalized and passed to CallBack before any other block of
// the loop is processed by this function.
211 Finalized
.insert(LoopHeader
);
212 CallBack(*LoopHeader
);
214 // Initialize with immediate successors
// Successors are tested for loop membership and for being the header itself.
215 for (const auto *BB
: successors(LoopHeader
)) {
216 if (!Loop
.contains(BB
))
218 if (BB
== LoopHeader
)
223 // Compute PO inside region
224 computeStackPO(Stack
, LI
, &Loop
, CallBack
, Finalized
);
// Out-of-class definition of the EmptyDivergenceDesc member. getJoinBlocks
// returns a reference to it for terminators with at most one successor.
231 ControlDivergenceDesc
SyncDependenceAnalysis::EmptyDivergenceDesc
;
// Stores references to the given analyses and eagerly computes the block
// order: every block reported by computeTopLevelPO is appended to LoopPO.
// The function to traverse is taken from the parent of the dominator tree's
// root block.
//
// NOTE(review): the parameter that supplies LI is declared on a line not
// visible in this listing.
233 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree
&DT
,
234 const PostDominatorTree
&PDT
,
236 : DT(DT
), PDT(PDT
), LI(LI
) {
237 computeTopLevelPO(*DT
.getRoot()->getParent(), LI
,
238 [&](const BasicBlock
&BB
) { LoopPO
.appendBlock(BB
); });
// Out-of-line destructor; it performs no work beyond destroying the members.
241 SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;
// Per-query helper: one instance is constructed in getJoinBlocks for the
// block of a terminator, and computeJoinPoints() is called on it once.
// It holds references to the block order and analyses, the per-block label
// vector and the result descriptor that is filled during propagation.
243 // divergence propagator for reducible CFGs
244 struct DivergencePropagator
{
// Block order used to translate between blocks and indices
// (getIndexOf / getBlockAt / size are the operations used in this struct).
245 const ModifiedPO
&LoopPOT
;
246 const DominatorTree
&DT
;
247 const PostDominatorTree
&PDT
;
// Block whose terminator is being analysed; its successors seed the labels.
249 const BasicBlock
&DivTermBlock
;
251 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
253 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
254 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
255 // from X or B is an immediate successor of X (initial value).
// One slot per index of LoopPOT; the constructor initialises every slot to
// nullptr.
256 using BlockLabelVec
= std::vector
<const BasicBlock
*>;
257 BlockLabelVec BlockLabels
;
258 // divergent join and loop exit descriptor.
259 std::unique_ptr
<ControlDivergenceDesc
> DivDesc
;
// Binds the references, sizes BlockLabels to LoopPOT.size() with all entries
// null, and allocates an empty ControlDivergenceDesc owned by DivDesc.
261 DivergencePropagator(const ModifiedPO
&LoopPOT
, const DominatorTree
&DT
,
262 const PostDominatorTree
&PDT
, const LoopInfo
&LI
,
263 const BasicBlock
&DivTermBlock
)
264 : LoopPOT(LoopPOT
), DT(DT
), PDT(PDT
), LI(LI
), DivTermBlock(DivTermBlock
),
265 BlockLabels(LoopPOT
.size(), nullptr),
266 DivDesc(new ControlDivergenceDesc
) {}
// Debug helper: writes the current contents of BlockLabels to \p Out, walking
// the indices from the highest one downwards.
//
// NOTE(review): the loop condition is `BlockIdx > 0`, so the entry at index 0
// is never printed - confirm whether that is intended.
// NOTE(review): the lines between the block-name output and the label output
// are not visible in this listing.
268 void printDefs(raw_ostream
&Out
) {
269 Out
<< "Propagator::BlockLabels {\n";
270 for (int BlockIdx
= (int)BlockLabels
.size() - 1; BlockIdx
> 0; --BlockIdx
) {
271 const auto *Label
= BlockLabels
[BlockIdx
];
272 Out
<< LoopPOT
.getBlockAt(BlockIdx
)->getName().str() << "(" << BlockIdx
277 Out
<< Label
->getName() << "\n";
// Label update rule for a single block:
//  * slot of \p SuccBlock is empty, or already holds \p PushedLabel:
//    the slot is set to \p PushedLabel;
//  * the other visible assignment sets the slot to \p SuccBlock itself, which
//    per the BlockLabels description marks the block as a join point.
//
// NOTE(review): the return statements are not visible in this listing; the
// result is described by the original comment below.
283 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
284 // causes a divergent join.
285 bool computeJoin(const BasicBlock
&SuccBlock
, const BasicBlock
&PushedLabel
) {
286 auto SuccIdx
= LoopPOT
.getIndexOf(SuccBlock
);
288 // unset or same reaching label
289 const auto *OldLabel
= BlockLabels
[SuccIdx
];
290 if (!OldLabel
|| (OldLabel
== &PushedLabel
)) {
291 BlockLabels
[SuccIdx
] = &PushedLabel
;
295 // Update the definition
296 BlockLabels
[SuccIdx
] = &SuccBlock
;
// Pushes the label \p DefBlock to the loop exit \p ExitBlock.
// One visible path delegates to visitEdge; the other runs computeJoin and
// records \p ExitBlock in DivDesc->LoopDivBlocks.
//
// NOTE(review): the condition that selects between the two paths (presumably
// on \p FromParentLoop) and the return statements are not visible in this
// listing - confirm against the full source.
300 // visiting a virtual loop exit edge from the loop header --> temporal
301 // divergence on join
302 bool visitLoopExitEdge(const BasicBlock
&ExitBlock
,
303 const BasicBlock
&DefBlock
, bool FromParentLoop
) {
304 // Pushing from a non-parent loop cannot cause temporal divergence.
306 return visitEdge(ExitBlock
, DefBlock
);
308 if (!computeJoin(ExitBlock
, DefBlock
))
311 // Identified a divergent loop exit
312 DivDesc
->LoopDivBlocks
.insert(&ExitBlock
);
313 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock
.getName()
// Pushes the label \p DefBlock to \p SuccBlock through computeJoin and
// records \p SuccBlock in DivDesc->JoinDivBlocks on the path that follows
// the computeJoin test.
//
// NOTE(review): the return statements are not visible in this listing.
318 // process \p SuccBlock with reaching definition \p DefBlock
319 bool visitEdge(const BasicBlock
&SuccBlock
, const BasicBlock
&DefBlock
) {
320 if (!computeJoin(SuccBlock
, DefBlock
))
323 // Divergent, disjoint paths join.
324 DivDesc
->JoinDivBlocks
.insert(&SuccBlock
);
325 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock
.getName());
// Runs the label propagation for DivTermBlock and returns the accumulated
// descriptor (ownership of DivDesc is moved to the caller).
//
// Outline of the visible steps:
//  1. Every successor of DivTermBlock is labelled with itself; the highest
//     successor index becomes the start index (BlockIdx) and the lowest one
//     the initial FloorIdx.
//  2. Block indices are visited downwards from BlockIdx while they are
//     >= FloorIdx. The label of the visited block is pushed either to the
//     exit blocks of its loop (visitLoopExitEdge) or to its CFG successors
//     (visitEdge).
//  3. FloorIdx is replaced by LoweredFloorIdx in the two numbered cases
//     commented near the end of the loop.
//
// NOTE(review): the embedded line numbers skip inside this function; among
// others the declaration of BlockIdx, the test that selects between the
// loop-exit and the successor case (presumably on IsLoopHeader) and the
// handling of a null Label are not visible in this listing.
329 std::unique_ptr
<ControlDivergenceDesc
> computeJoinPoints() {
332 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock
.getName()
335 const auto *DivBlockLoop
= LI
.getLoopFor(&DivTermBlock
);
337 // Early stopping criterion
// FloorIdx starts at the last index of LoopPOT and is lowered below.
338 int FloorIdx
= LoopPOT
.size() - 1;
339 const BasicBlock
*FloorLabel
= nullptr;
341 // bootstrap with branch targets
344 for (const auto *SuccBlock
: successors(&DivTermBlock
)) {
345 auto SuccIdx
= LoopPOT
.getIndexOf(*SuccBlock
);
// Initial value: each successor is its own label.
346 BlockLabels
[SuccIdx
] = SuccBlock
;
348 // Find the successor with the highest index to start with
349 BlockIdx
= std::max
<int>(BlockIdx
, SuccIdx
);
350 FloorIdx
= std::min
<int>(FloorIdx
, SuccIdx
);
352 // Identify immediate divergent loop exits
356 const auto *BlockLoop
= LI
.getLoopFor(SuccBlock
);
357 if (BlockLoop
&& DivBlockLoop
->contains(BlockLoop
))
359 DivDesc
->LoopDivBlocks
.insert(SuccBlock
);
360 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
361 << SuccBlock
->getName() << "\n");
364 // propagate definitions at the immediate successors of the node in RPO
365 for (; BlockIdx
>= FloorIdx
; --BlockIdx
) {
366 LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
368 // Any label available here
369 const auto *Label
= BlockLabels
[BlockIdx
];
374 const auto *Block
= LoopPOT
.getBlockAt(BlockIdx
);
375 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block
->getName() << "\n");
377 auto *BlockLoop
= LI
.getLoopFor(Block
);
378 bool IsLoopHeader
= BlockLoop
&& BlockLoop
->getHeader() == Block
;
379 bool CausedJoin
= false;
// Working copy; FloorIdx itself is only updated after all pushes of this
// block are done.
380 int LoweredFloorIdx
= FloorIdx
;
382 // Disconnect from immediate successors and propagate directly to loop
384 SmallVector
<BasicBlock
*, 4> BlockLoopExits
;
385 BlockLoop
->getExitBlocks(BlockLoopExits
);
// True when the loop of the visited block also contains DivTermBlock.
387 bool IsParentLoop
= BlockLoop
->contains(&DivTermBlock
);
388 for (const auto *BlockLoopExit
: BlockLoopExits
) {
389 CausedJoin
|= visitLoopExitEdge(*BlockLoopExit
, *Label
, IsParentLoop
);
390 LoweredFloorIdx
= std::min
<int>(LoweredFloorIdx
,
391 LoopPOT
.getIndexOf(*BlockLoopExit
));
394 // Acyclic successor case
395 for (const auto *SuccBlock
: successors(Block
)) {
396 CausedJoin
|= visitEdge(*SuccBlock
, *Label
);
// NOTE(review): as shown, the result of this std::min is not assigned.
// Presumably the assignment target (LoweredFloorIdx =) sits on the preceding
// line, which is not visible in this listing - confirm.
398 std::min
<int>(LoweredFloorIdx
, LoopPOT
.getIndexOf(*SuccBlock
));
404 // 1. Different labels pushed to successors
405 FloorIdx
= LoweredFloorIdx
;
406 } else if (FloorLabel
!= Label
) {
407 // 2. No join caused BUT we pushed a label that is different than the
409 FloorIdx
= LoweredFloorIdx
;
414 LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
// DivDesc is a data member, so the explicit move is what transfers ownership
// out of this object.
416 return std::move(DivDesc
);
// Debug helper: writes the name of every block in \p Blocks to \p Out, each
// name preceded by LS.
//
// NOTE(review): LS is declared on a line not visible in this listing -
// presumably a list separator object; confirm.
421 static void printBlockSet(ConstBlockSet
&Blocks
, raw_ostream
&Out
) {
424 for (const auto *BB
: Blocks
)
425 Out
<< LS
<< BB
->getName();
// Returns the control divergence descriptor for the terminator \p Term.
//
//  * Terminators with at most one successor yield EmptyDivergenceDesc.
//  * Otherwise the result is looked up in CachedControlDivDescs, keyed by the
//    address of \p Term.
//  * On a cache miss a DivergencePropagator is run for the parent block of
//    \p Term and its result is moved into the cache.
//
// The returned reference refers either to EmptyDivergenceDesc or to the
// descriptor stored in the cache entry.
430 const ControlDivergenceDesc
&
431 SyncDependenceAnalysis::getJoinBlocks(const Instruction
&Term
) {
433 if (Term
.getNumSuccessors() <= 1) {
434 return EmptyDivergenceDesc
;
437 // already available in cache?
438 auto ItCached
= CachedControlDivDescs
.find(&Term
);
439 if (ItCached
!= CachedControlDivDescs
.end())
440 return *ItCached
->second
;
442 // compute all join points
443 // Special handling of divergent loop exits is not needed for LCSSA
444 const auto &TermBlock
= *Term
.getParent();
445 DivergencePropagator
Propagator(LoopPO
, DT
, PDT
, LI
, TermBlock
);
446 auto DivDesc
= Propagator
.computeJoinPoints();
448 LLVM_DEBUG(dbgs() << "Result (" << Term
.getParent()->getName() << "):\n";
449 dbgs() << "JoinDivBlocks: ";
450 printBlockSet(DivDesc
->JoinDivBlocks
, dbgs());
451 dbgs() << "\nLoopDivBlocks: ";
452 printBlockSet(DivDesc
->LoopDivBlocks
, dbgs()); dbgs() << "\n";);
// The lookup above found no entry for Term, so the emplace is expected to
// insert; the assert checks exactly that.
454 auto ItInserted
= CachedControlDivDescs
.emplace(&Term
, std::move(DivDesc
));
455 assert(ItInserted
.second
);
456 return *ItInserted
.first
->second
;