1 //===- UniformityAnalysis.cpp ---------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/InitializePasses.h"
21 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::hasDivergentDefs(
22 const Instruction
&I
) const {
23 return isDivergent((const Value
*)&I
);
27 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::markDefsDivergent(
28 const Instruction
&Instr
) {
29 return markDivergent(cast
<Value
>(&Instr
));
// SSAContext specialization of initialize().
//
// First loop: visits every instruction of F. Each one is tested with
// TTI->isSourceOfDivergence(); otherwise, instructions accepted by
// TTI->isAlwaysUniform() are registered through addUniformOverride().
// Second loop: visits every argument of F and tests it with
// TTI->isSourceOfDivergence().
//
// NOTE(review): the statements controlled by the two isSourceOfDivergence()
// checks, and the closing braces, are absent from this text. Presumably they
// mark the instruction / argument divergent -- restore from the upstream file.
32 template <> void llvm::GenericUniformityAnalysisImpl
<SSAContext
>::initialize() {
33 for (auto &I
: instructions(F
)) {
34 if (TTI
->isSourceOfDivergence(&I
))
// NOTE(review): the statement for the branch above is missing here.
36 else if (TTI
->isAlwaysUniform(&I
))
37 addUniformOverride(I
);
39 for (auto &Arg
: F
.args()) {
40 if (TTI
->isSourceOfDivergence(&Arg
)) {
// SSAContext specialization of pushUsers() for a value.
//
// Iterates over V->users(); every user that dyn_cast<> recognises as an
// Instruction is handed to markDivergent(). Users that are not instructions
// are ignored.
//
// NOTE(review): the parameter list is absent from this text. The body reads a
// pointer named V, so the parameter is presumably `const Value *V` -- confirm
// against the header declaration.
47 void llvm::GenericUniformityAnalysisImpl
<SSAContext
>::pushUsers(
49 for (const auto *User
: V
->users()) {
50 if (const auto *UserInstr
= dyn_cast
<const Instruction
>(User
)) {
51 markDivergent(*UserInstr
);
// SSAContext specialization of pushUsers() for an instruction.
//
// Asserts that Instr is not always-uniform, tests Instr.isTerminator(), and
// forwards Instr (viewed as a Value) to the pushUsers() overload for values.
//
// NOTE(review): the statement controlled by the isTerminator() check is absent
// from this text. As written, the `if` would govern the pushUsers() call;
// presumably the original has an early exit for terminators -- confirm.
57 void llvm::GenericUniformityAnalysisImpl
<SSAContext
>::pushUsers(
58 const Instruction
&Instr
) {
59 assert(!isAlwaysUniform(Instr
));
60 if (Instr
.isTerminator())
62 pushUsers(cast
<Value
>(&Instr
));
// SSAContext specialization of usesValueFromCycle().
//
// Asserts that I is not always-uniform, then walks I.operands(). For each
// operand that dyn_cast<> recognises as an Instruction, it checks whether
// DefCycle contains that instruction's parent block.
//
// NOTE(review): the return statements and closing braces are absent from this
// text; presumably the function returns true on the first match and false
// otherwise -- confirm.
// NOTE(review): the loop-local `I` declared inside the `if` shadows the
// parameter `I`; inside that scope `I->getParent()` refers to the operand's
// defining instruction, not to the instruction passed in.
66 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::usesValueFromCycle(
67 const Instruction
&I
, const Cycle
&DefCycle
) const {
68 assert(!isAlwaysUniform(I
));
69 for (const Use
&U
: I
.operands()) {
70 if (auto *I
= dyn_cast
<Instruction
>(&U
)) {
71 if (DefCycle
.contains(I
->getParent()))
// SSAContext specialization of propagateTemporalDivergence().
//
// Walks I.users(), treating every user as an Instruction (cast<> asserts
// this). For each user it tests whether DefCycle contains the user's parent
// block, and calls markDivergent() on the user.
//
// NOTE(review): lines are absent from this text between the signature and the
// loop, and directly after the DefCycle.contains() check. As written, the
// `if` would govern markDivergent(); presumably users inside DefCycle are
// skipped so that only users outside the cycle are marked -- confirm.
79 void llvm::GenericUniformityAnalysisImpl
<
80 SSAContext
>::propagateTemporalDivergence(const Instruction
&I
,
81 const Cycle
&DefCycle
) {
84 for (auto *User
: I
.users()) {
85 auto *UserInstr
= cast
<Instruction
>(User
);
86 if (DefCycle
.contains(UserInstr
->getParent()))
88 markDivergent(*UserInstr
);
// SSAContext specialization of isDivergentUse().
//
// Fetches the used value with U.get(). When that value is defined by an
// Instruction, the user of U is taken as an Instruction (cast<> asserts this)
// and the result of isTemporalDivergent(user's parent block, defining
// instruction) is returned.
//
// NOTE(review): the parameter list, the statements between U.get() and the
// dyn_cast<> check, and the final return are absent from this text. The body
// reads a Use named U, so the parameter is presumably `const Use &U` --
// confirm against the header declaration.
93 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::isDivergentUse(
95 const auto *V
= U
.get();
98 if (const auto *DefInstr
= dyn_cast
<Instruction
>(V
)) {
99 const auto *UseInstr
= cast
<Instruction
>(U
.getUser());
100 return isTemporalDivergent(*UseInstr
->getParent(), *DefInstr
);
105 // This ensures explicit instantiation of
106 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
107 template class llvm::GenericUniformityInfo
<SSAContext
>;
108 template struct llvm::GenericUniformityAnalysisImplDeleter
<
109 llvm::GenericUniformityAnalysisImpl
<SSAContext
>>;
111 //===----------------------------------------------------------------------===//
112 // UniformityInfoAnalysis and related pass implementations
113 //===----------------------------------------------------------------------===//
// New-pass-manager entry point for the uniformity analysis of function F.
//
// Fetches the DominatorTreeAnalysis, TargetIRAnalysis and CycleAnalysis
// results for F from FAM, builds a UniformityInfo from the dominator tree,
// the cycle info and a pointer to the TTI result, and then tests
// TTI.hasBranchDivergence(&F).
//
// NOTE(review): the statement controlled by the hasBranchDivergence() check,
// the return statement and the closing brace are absent from this text.
// Judging by the "Skip computation" comment, the check presumably guards the
// actual computation on UI -- restore from the upstream file.
115 llvm::UniformityInfo
UniformityInfoAnalysis::run(Function
&F
,
116 FunctionAnalysisManager
&FAM
) {
117 auto &DT
= FAM
.getResult
<DominatorTreeAnalysis
>(F
);
118 auto &TTI
= FAM
.getResult
<TargetIRAnalysis
>(F
);
119 auto &CI
= FAM
.getResult
<CycleAnalysis
>(F
);
120 UniformityInfo UI
{DT
, CI
, &TTI
};
121 // Skip computation if we can assume everything is uniform.
122 if (TTI
.hasBranchDivergence(&F
))
128 AnalysisKey
UniformityInfoAnalysis::Key
;
// Constructor of the printer pass; takes the stream the pass writes to.
//
// NOTE(review): the member-initializer list and body are absent from this
// text. UniformityInfoPrinterPass::run() writes to a member named OS, so the
// constructor presumably stores the argument there -- confirm.
130 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream
&OS
)
// Prints the uniformity information of function F.
//
// Writes a header line containing F's name to OS, requests the
// UniformityInfoAnalysis result for F from AM and calls print(OS) on it, then
// returns PreservedAnalyses::all().
//
// NOTE(review): one line between the print() call and the return, and the
// closing brace, are absent from this text; the missing in-body line is
// presumably blank -- confirm.
133 PreservedAnalyses
UniformityInfoPrinterPass::run(Function
&F
,
134 FunctionAnalysisManager
&AM
) {
135 OS
<< "UniformityInfo for function '" << F
.getName() << "':\n";
136 AM
.getResult
<UniformityInfoAnalysis
>(F
).print(OS
);
138 return PreservedAnalyses::all();
141 //===----------------------------------------------------------------------===//
142 // UniformityInfoWrapperPass Implementation
143 //===----------------------------------------------------------------------===//
145 char UniformityInfoWrapperPass::ID
= 0;
147 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID
) {
148 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
151 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass
, "uniformity",
152 "Uniformity Analysis", true, true)
153 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass
)
154 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass
)
155 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass
)
156 INITIALIZE_PASS_END(UniformityInfoWrapperPass
, "uniformity",
157 "Uniformity Analysis", true, true)
159 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage
&AU
) const {
160 AU
.setPreservesAll();
161 AU
.addRequired
<DominatorTreeWrapperPass
>();
162 AU
.addRequiredTransitive
<CycleInfoWrapperPass
>();
163 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
// Legacy-pass-manager entry point: builds the uniformity info for function F.
//
// Fetches the cycle info, dominator tree and TTI for F from the required
// wrapper passes, assigns a fresh UniformityInfo built from them to
// m_uniformityInfo, and calls m_uniformityInfo.compute() only when
// targetTransformInfo.hasBranchDivergence(m_function) is true.
//
// NOTE(review): lines are absent from this text between the TTI lookup and
// the m_uniformityInfo assignment, and after the compute() call. m_function
// is read below but no assignment to it is visible; presumably the missing
// lines store &F in m_function and return the "IR modified" flag -- restore
// from the upstream file.
166 bool UniformityInfoWrapperPass::runOnFunction(Function
&F
) {
167 auto &cycleInfo
= getAnalysis
<CycleInfoWrapperPass
>().getResult();
168 auto &domTree
= getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
169 auto &targetTransformInfo
=
170 getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
173 m_uniformityInfo
= UniformityInfo
{domTree
, cycleInfo
, &targetTransformInfo
};
175 // Skip computation if we can assume everything is uniform.
176 if (targetTransformInfo
.hasBranchDivergence(m_function
))
177 m_uniformityInfo
.compute();
182 void UniformityInfoWrapperPass::print(raw_ostream
&OS
, const Module
*) const {
183 OS
<< "UniformityInfo for function '" << m_function
->getName() << "':\n";
186 void UniformityInfoWrapperPass::releaseMemory() {
187 m_uniformityInfo
= UniformityInfo
{};
188 m_function
= nullptr;