//===- UniformityAnalysis.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/InitializePasses.h"
22 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::hasDivergentDefs(
23 const Instruction
&I
) const {
24 return isDivergent((const Value
*)&I
);
28 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::markDefsDivergent(
29 const Instruction
&Instr
) {
30 return markDivergent(cast
<Value
>(&Instr
));
33 template <> void llvm::GenericUniformityAnalysisImpl
<SSAContext
>::initialize() {
34 for (auto &I
: instructions(F
)) {
35 if (TTI
->isSourceOfDivergence(&I
))
37 else if (TTI
->isAlwaysUniform(&I
))
38 addUniformOverride(I
);
40 for (auto &Arg
: F
.args()) {
41 if (TTI
->isSourceOfDivergence(&Arg
)) {
48 void llvm::GenericUniformityAnalysisImpl
<SSAContext
>::pushUsers(
50 for (const auto *User
: V
->users()) {
51 if (const auto *UserInstr
= dyn_cast
<const Instruction
>(User
)) {
52 markDivergent(*UserInstr
);
58 void llvm::GenericUniformityAnalysisImpl
<SSAContext
>::pushUsers(
59 const Instruction
&Instr
) {
60 assert(!isAlwaysUniform(Instr
));
61 if (Instr
.isTerminator())
63 pushUsers(cast
<Value
>(&Instr
));
67 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::usesValueFromCycle(
68 const Instruction
&I
, const Cycle
&DefCycle
) const {
69 assert(!isAlwaysUniform(I
));
70 for (const Use
&U
: I
.operands()) {
71 if (auto *I
= dyn_cast
<Instruction
>(&U
)) {
72 if (DefCycle
.contains(I
->getParent()))
80 void llvm::GenericUniformityAnalysisImpl
<
81 SSAContext
>::propagateTemporalDivergence(const Instruction
&I
,
82 const Cycle
&DefCycle
) {
85 for (auto *User
: I
.users()) {
86 auto *UserInstr
= cast
<Instruction
>(User
);
87 if (DefCycle
.contains(UserInstr
->getParent()))
89 markDivergent(*UserInstr
);
94 bool llvm::GenericUniformityAnalysisImpl
<SSAContext
>::isDivergentUse(
96 const auto *V
= U
.get();
99 if (const auto *DefInstr
= dyn_cast
<Instruction
>(V
)) {
100 const auto *UseInstr
= cast
<Instruction
>(U
.getUser());
101 return isTemporalDivergent(*UseInstr
->getParent(), *DefInstr
);
106 // This ensures explicit instantiation of
107 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
108 template class llvm::GenericUniformityInfo
<SSAContext
>;
109 template struct llvm::GenericUniformityAnalysisImplDeleter
<
110 llvm::GenericUniformityAnalysisImpl
<SSAContext
>>;
//===----------------------------------------------------------------------===//
//  UniformityInfoAnalysis and related pass implementations
//===----------------------------------------------------------------------===//
116 llvm::UniformityInfo
UniformityInfoAnalysis::run(Function
&F
,
117 FunctionAnalysisManager
&FAM
) {
118 auto &DT
= FAM
.getResult
<DominatorTreeAnalysis
>(F
);
119 auto &TTI
= FAM
.getResult
<TargetIRAnalysis
>(F
);
120 auto &CI
= FAM
.getResult
<CycleAnalysis
>(F
);
121 UniformityInfo UI
{DT
, CI
, &TTI
};
122 // Skip computation if we can assume everything is uniform.
123 if (TTI
.hasBranchDivergence(&F
))
129 AnalysisKey
UniformityInfoAnalysis::Key
;
131 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream
&OS
)
134 PreservedAnalyses
UniformityInfoPrinterPass::run(Function
&F
,
135 FunctionAnalysisManager
&AM
) {
136 OS
<< "UniformityInfo for function '" << F
.getName() << "':\n";
137 AM
.getResult
<UniformityInfoAnalysis
>(F
).print(OS
);
139 return PreservedAnalyses::all();
//===----------------------------------------------------------------------===//
//  UniformityInfoWrapperPass Implementation
//===----------------------------------------------------------------------===//
146 char UniformityInfoWrapperPass::ID
= 0;
148 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID
) {
149 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
152 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass
, "uniformity",
153 "Uniformity Analysis", true, true)
154 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass
)
155 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass
)
156 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass
)
157 INITIALIZE_PASS_END(UniformityInfoWrapperPass
, "uniformity",
158 "Uniformity Analysis", true, true)
160 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage
&AU
) const {
161 AU
.setPreservesAll();
162 AU
.addRequired
<DominatorTreeWrapperPass
>();
163 AU
.addRequiredTransitive
<CycleInfoWrapperPass
>();
164 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
167 bool UniformityInfoWrapperPass::runOnFunction(Function
&F
) {
168 auto &cycleInfo
= getAnalysis
<CycleInfoWrapperPass
>().getResult();
169 auto &domTree
= getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
170 auto &targetTransformInfo
=
171 getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
174 m_uniformityInfo
= UniformityInfo
{domTree
, cycleInfo
, &targetTransformInfo
};
176 // Skip computation if we can assume everything is uniform.
177 if (targetTransformInfo
.hasBranchDivergence(m_function
))
178 m_uniformityInfo
.compute();
183 void UniformityInfoWrapperPass::print(raw_ostream
&OS
, const Module
*) const {
184 OS
<< "UniformityInfo for function '" << m_function
->getName() << "':\n";
// Drops per-function state: the stored UniformityInfo is replaced by a
// default-constructed one, and the analyzed-function pointer is cleared.
187 void UniformityInfoWrapperPass::releaseMemory() {
188 m_uniformityInfo
= UniformityInfo
{};
189 m_function
= nullptr;