Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / lib / Analysis / UniformityAnalysis.cpp
blob2d617db431c58880c10ca82a01b1444fa9836cfe
1 //===- UniformityAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/InitializePasses.h"
19 using namespace llvm;
21 template <>
22 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23 const Instruction &I) const {
24 return isDivergent((const Value *)&I);
27 template <>
28 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29 const Instruction &Instr) {
30 return markDivergent(cast<Value>(&Instr));
33 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
34 for (auto &I : instructions(F)) {
35 if (TTI->isSourceOfDivergence(&I))
36 markDivergent(I);
37 else if (TTI->isAlwaysUniform(&I))
38 addUniformOverride(I);
40 for (auto &Arg : F.args()) {
41 if (TTI->isSourceOfDivergence(&Arg)) {
42 markDivergent(&Arg);
47 template <>
48 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
49 const Value *V) {
50 for (const auto *User : V->users()) {
51 if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
52 markDivergent(*UserInstr);
57 template <>
58 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
59 const Instruction &Instr) {
60 assert(!isAlwaysUniform(Instr));
61 if (Instr.isTerminator())
62 return;
63 pushUsers(cast<Value>(&Instr));
66 template <>
67 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
68 const Instruction &I, const Cycle &DefCycle) const {
69 assert(!isAlwaysUniform(I));
70 for (const Use &U : I.operands()) {
71 if (auto *I = dyn_cast<Instruction>(&U)) {
72 if (DefCycle.contains(I->getParent()))
73 return true;
76 return false;
79 template <>
80 void llvm::GenericUniformityAnalysisImpl<
81 SSAContext>::propagateTemporalDivergence(const Instruction &I,
82 const Cycle &DefCycle) {
83 if (isDivergent(I))
84 return;
85 for (auto *User : I.users()) {
86 auto *UserInstr = cast<Instruction>(User);
87 if (DefCycle.contains(UserInstr->getParent()))
88 continue;
89 markDivergent(*UserInstr);
93 template <>
94 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
95 const Use &U) const {
96 const auto *V = U.get();
97 if (isDivergent(V))
98 return true;
99 if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
100 const auto *UseInstr = cast<Instruction>(U.getUser());
101 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
103 return false;
106 // This ensures explicit instantiation of
107 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
108 template class llvm::GenericUniformityInfo<SSAContext>;
109 template struct llvm::GenericUniformityAnalysisImplDeleter<
110 llvm::GenericUniformityAnalysisImpl<SSAContext>>;
112 //===----------------------------------------------------------------------===//
113 // UniformityInfoAnalysis and related pass implementations
114 //===----------------------------------------------------------------------===//
116 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
117 FunctionAnalysisManager &FAM) {
118 auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
119 auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
120 auto &CI = FAM.getResult<CycleAnalysis>(F);
121 UniformityInfo UI{DT, CI, &TTI};
122 // Skip computation if we can assume everything is uniform.
123 if (TTI.hasBranchDivergence(&F))
124 UI.compute();
126 return UI;
129 AnalysisKey UniformityInfoAnalysis::Key;
131 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
132 : OS(OS) {}
134 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
135 FunctionAnalysisManager &AM) {
136 OS << "UniformityInfo for function '" << F.getName() << "':\n";
137 AM.getResult<UniformityInfoAnalysis>(F).print(OS);
139 return PreservedAnalyses::all();
142 //===----------------------------------------------------------------------===//
143 // UniformityInfoWrapperPass Implementation
144 //===----------------------------------------------------------------------===//
146 char UniformityInfoWrapperPass::ID = 0;
148 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
149 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
152 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
153 "Uniformity Analysis", true, true)
154 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
155 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
156 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
157 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
158 "Uniformity Analysis", true, true)
160 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
161 AU.setPreservesAll();
162 AU.addRequired<DominatorTreeWrapperPass>();
163 AU.addRequiredTransitive<CycleInfoWrapperPass>();
164 AU.addRequired<TargetTransformInfoWrapperPass>();
167 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
168 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
169 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
170 auto &targetTransformInfo =
171 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
173 m_function = &F;
174 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
176 // Skip computation if we can assume everything is uniform.
177 if (targetTransformInfo.hasBranchDivergence(m_function))
178 m_uniformityInfo.compute();
180 return false;
183 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
184 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
187 void UniformityInfoWrapperPass::releaseMemory() {
188 m_uniformityInfo = UniformityInfo{};
189 m_function = nullptr;