//===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// The divergence analysis determines which instructions and branches are
// divergent given a set of divergent source instructions.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
#define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H

#include "llvm/ADT/DenseSet.h"
#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/IR/Function.h"
#include "llvm/Pass.h"
#include <vector>

namespace llvm {
class Module;
class Value;
class Instruction;
class Loop;
class raw_ostream;
class TargetTransformInfo;

/// \brief Generic divergence analysis for reducible CFGs.
///
/// This analysis propagates divergence in a data-parallel context from sources
/// of divergence to all users. It requires reducible CFGs. All assignments
/// should be in SSA form.
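///
/// A minimal usage sketch (illustrative only): \c ThreadIdVal and \c SomeVal
/// are placeholder \c Value references, the DominatorTree \c DT,
/// PostDominatorTree \c PDT and LoopInfo \c LI are assumed to have been
/// computed by the caller, and the SyncDependenceAnalysis constructor is
/// assumed to take DT, PDT and LI.
/// \code
///   SyncDependenceAnalysis SDA(DT, PDT, LI);
///   DivergenceAnalysis DA(F, /*RegionLoop=*/nullptr, DT, LI, SDA,
///                         /*IsLCSSAForm=*/false);
///   DA.markDivergent(ThreadIdVal); // seed with a known source of divergence
///   DA.compute();                  // propagate divergence to all users
///   bool Divergent = DA.isDivergent(SomeVal);
/// \endcode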
class DivergenceAnalysis {
public:
  /// \brief This instance will analyze the whole function \p F or the loop \p
  /// RegionLoop.
  ///
  /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
  /// Otherwise the whole function is analyzed.
  /// \param IsLCSSAForm whether the analysis may assume that the IR in the
  /// region is in LCSSA form.
  DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
                     const DominatorTree &DT, const LoopInfo &LI,
                     SyncDependenceAnalysis &SDA, bool IsLCSSAForm);

  /// \brief The loop that defines the analyzed region (if any).
  const Loop *getRegionLoop() const { return RegionLoop; }

  /// \brief The function this analysis result is for.
  const Function &getFunction() const { return F; }

  /// \brief Whether \p BB is part of the region.
  bool inRegion(const BasicBlock &BB) const;

  /// \brief Whether \p I is part of the region.
  bool inRegion(const Instruction &I) const;

  /// \brief Mark \p UniVal as a value that is always uniform.
  void addUniformOverride(const Value &UniVal);

  /// \brief Mark \p DivVal as a value that is always divergent.
  void markDivergent(const Value &DivVal);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDetectedDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Val will always return a uniform value regardless of its
  /// operands.
  bool isAlwaysUniform(const Value &Val) const;

  /// \brief Whether \p Val is divergent at its definition.
  bool isDivergent(const Value &Val) const;

  /// \brief Whether \p U is divergent. Uses of a uniform value can be divergent.
  bool isDivergentUse(const Use &U) const;
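
  // Example of a divergent use of a uniform value (an illustrative pseudo-kernel,
  // not tied to any particular GPU API; 'tid' stands for the thread id):
  //
  //   int i = 0;
  //   do { i++; } while (i < tid);  // loop exit condition is divergent
  //   use(i);                       // divergent use: threads leave the loop in
  //                                 // different iterations and so observe
  //                                 // different values of 'i' here
  //
  // Inside the loop, 'i' is uniform in every iteration; only the use outside
  // the divergent loop is divergent (temporal divergence).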

  void print(raw_ostream &OS, const Module *) const;

private:
  bool updateTerminator(const Instruction &Term) const;
  bool updatePHINode(const PHINode &Phi) const;

  /// \brief Computes whether \p Inst is divergent based on the
  /// divergence of its operands.
  ///
  /// \returns Whether \p Inst is divergent.
  ///
  /// This should only be called for non-phi, non-terminator instructions.
  bool updateNormalInstruction(const Instruction &Inst) const;

  /// \brief Mark users of loop live-out values as divergent.
  ///
  /// \param LoopHeader the header of the divergent loop.
  ///
  /// Marks all users of live-out values of the loop headed by \p LoopHeader
  /// as divergent and puts them on the worklist.
  void taintLoopLiveOuts(const BasicBlock &LoopHeader);

  /// \brief Push all users of \p Val (in the region) to the worklist.
  void pushUsers(const Value &Val);

  /// \brief Push all phi nodes in \p Block to the worklist.
  void pushPHINodes(const BasicBlock &Block);

  /// \brief Mark \p Block as join divergent.
  ///
  /// A block is join divergent if two threads may reach it from different
  /// incoming blocks at the same time.
  void markBlockJoinDivergent(const BasicBlock &Block) {
    DivergentJoinBlocks.insert(&Block);
  }
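
  // Example (illustrative source-level sketch; 'tid' stands for the thread id):
  //
  //   if (tid < 16) { A(); } else { B(); }
  //   C();   // join divergent: at the same time, some threads reach C() from
  //          // the A() block and others from the B() block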

  /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BasicBlock &ObservingBlock,
                           const Value &Val) const;

  /// \brief Whether \p Block is join divergent (see markBlockJoinDivergent).
  bool isJoinDivergent(const BasicBlock &Block) const {
    return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end();
  }

  /// \brief Propagate control-induced divergence to users (phi nodes and
  /// instructions).
  ///
  /// \param JoinBlock is a divergent loop exit or join point of two disjoint
  /// paths.
  /// \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
  bool propagateJoinDivergence(const BasicBlock &JoinBlock,
                               const Loop *TermLoop);

  /// \brief Propagate induced value divergence due to control divergence in \p
  /// Term.
  void propagateBranchDivergence(const Instruction &Term);

  /// \brief Propagate divergence caused by a divergent loop exit.
  ///
  /// \param ExitingLoop is a divergent loop.
  void propagateLoopDivergence(const Loop &ExitingLoop);

private:
  const Function &F;
  // If RegionLoop != nullptr, analysis is only performed within \p RegionLoop.
  // Otherwise, the whole function is analyzed.
  const Loop *RegionLoop;

  const DominatorTree &DT;
  const LoopInfo &LI;

  // Recognized divergent loops.
  DenseSet<const Loop *> DivergentLoops;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysis &SDA;

  // Use the simplified code path for LCSSA form.
  bool IsLCSSAForm;

  // Set of known-uniform values.
  DenseSet<const Value *> UniformOverrides;

  // Blocks with joining divergent control from different predecessors.
  DenseSet<const BasicBlock *> DivergentJoinBlocks;

  // Detected/marked divergent values.
  DenseSet<const Value *> DivergentValues;

  // Internal worklist for divergence propagation.
  std::vector<const Instruction *> Worklist;
};

/// \brief Divergence analysis frontend for GPU kernels.
class GPUDivergenceAnalysis {
  SyncDependenceAnalysis SDA;
  DivergenceAnalysis DA;

public:
  /// Runs the divergence analysis on \p F, a GPU kernel.
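  ///
  /// A minimal usage sketch (illustrative only; the DominatorTree \c DT,
  /// PostDominatorTree \c PDT, LoopInfo \c LI and TargetTransformInfo \c TTI
  /// for \p F are assumed to have been computed by the caller):
  /// \code
  ///   GPUDivergenceAnalysis GDA(F, DT, PDT, LI, TTI);
  ///   if (GDA.hasDivergence())
  ///     GDA.print(errs(), F.getParent());
  /// \endcode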
  GPUDivergenceAnalysis(Function &F, const DominatorTree &DT,
                        const PostDominatorTree &PDT, const LoopInfo &LI,
                        const TargetTransformInfo &TTI);

  /// Whether any divergence was detected.
  bool hasDivergence() const { return DA.hasDetectedDivergence(); }

  /// The GPU kernel this analysis result is for.
  const Function &getFunction() const { return DA.getFunction(); }

  /// Whether \p V is divergent at its definition.
  bool isDivergent(const Value &V) const;

  /// Whether \p U is divergent. Uses of a uniform value can be divergent.
  bool isDivergentUse(const Use &U) const;

  /// Whether \p V is uniform/non-divergent.
  bool isUniform(const Value &V) const { return !isDivergent(V); }

  /// Whether \p U is uniform/non-divergent. Uses of a uniform value can be
  /// divergent.
  bool isUniformUse(const Use &U) const { return !isDivergentUse(U); }

  /// Print all divergent values in the kernel.
  void print(raw_ostream &OS, const Module *) const;
};

} // namespace llvm

#endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H