[clang] Add test for CWG190 "Layout-compatible POD-struct types" (#121668)
[llvm-project.git] / llvm / lib / Analysis / UniformityAnalysis.cpp
blob 592de1067e191a16b53e0b1176586882d6381463
1 //===- UniformityAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/InitializePasses.h"
18 using namespace llvm;
20 template <>
21 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
22 const Instruction &I) const {
23 return isDivergent((const Value *)&I);
26 template <>
27 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
28 const Instruction &Instr) {
29 return markDivergent(cast<Value>(&Instr));
32 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
33 for (auto &I : instructions(F)) {
34 if (TTI->isSourceOfDivergence(&I))
35 markDivergent(I);
36 else if (TTI->isAlwaysUniform(&I))
37 addUniformOverride(I);
39 for (auto &Arg : F.args()) {
40 if (TTI->isSourceOfDivergence(&Arg)) {
41 markDivergent(&Arg);
46 template <>
47 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
48 const Value *V) {
49 for (const auto *User : V->users()) {
50 if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
51 markDivergent(*UserInstr);
56 template <>
57 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
58 const Instruction &Instr) {
59 assert(!isAlwaysUniform(Instr));
60 if (Instr.isTerminator())
61 return;
62 pushUsers(cast<Value>(&Instr));
65 template <>
66 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
67 const Instruction &I, const Cycle &DefCycle) const {
68 assert(!isAlwaysUniform(I));
69 for (const Use &U : I.operands()) {
70 if (auto *I = dyn_cast<Instruction>(&U)) {
71 if (DefCycle.contains(I->getParent()))
72 return true;
75 return false;
78 template <>
79 void llvm::GenericUniformityAnalysisImpl<
80 SSAContext>::propagateTemporalDivergence(const Instruction &I,
81 const Cycle &DefCycle) {
82 if (isDivergent(I))
83 return;
84 for (auto *User : I.users()) {
85 auto *UserInstr = cast<Instruction>(User);
86 if (DefCycle.contains(UserInstr->getParent()))
87 continue;
88 markDivergent(*UserInstr);
92 template <>
93 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
94 const Use &U) const {
95 const auto *V = U.get();
96 if (isDivergent(V))
97 return true;
98 if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
99 const auto *UseInstr = cast<Instruction>(U.getUser());
100 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
102 return false;
105 // This ensures explicit instantiation of
106 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
107 template class llvm::GenericUniformityInfo<SSAContext>;
108 template struct llvm::GenericUniformityAnalysisImplDeleter<
109 llvm::GenericUniformityAnalysisImpl<SSAContext>>;
111 //===----------------------------------------------------------------------===//
112 // UniformityInfoAnalysis and related pass implementations
113 //===----------------------------------------------------------------------===//
115 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
116 FunctionAnalysisManager &FAM) {
117 auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
118 auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
119 auto &CI = FAM.getResult<CycleAnalysis>(F);
120 UniformityInfo UI{DT, CI, &TTI};
121 // Skip computation if we can assume everything is uniform.
122 if (TTI.hasBranchDivergence(&F))
123 UI.compute();
125 return UI;
128 AnalysisKey UniformityInfoAnalysis::Key;
130 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
131 : OS(OS) {}
133 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
134 FunctionAnalysisManager &AM) {
135 OS << "UniformityInfo for function '" << F.getName() << "':\n";
136 AM.getResult<UniformityInfoAnalysis>(F).print(OS);
138 return PreservedAnalyses::all();
141 //===----------------------------------------------------------------------===//
142 // UniformityInfoWrapperPass Implementation
143 //===----------------------------------------------------------------------===//
145 char UniformityInfoWrapperPass::ID = 0;
147 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
148 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
151 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
152 "Uniformity Analysis", true, true)
153 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
154 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
155 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
156 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
157 "Uniformity Analysis", true, true)
159 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
160 AU.setPreservesAll();
161 AU.addRequired<DominatorTreeWrapperPass>();
162 AU.addRequiredTransitive<CycleInfoWrapperPass>();
163 AU.addRequired<TargetTransformInfoWrapperPass>();
166 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
167 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
168 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
169 auto &targetTransformInfo =
170 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
172 m_function = &F;
173 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
175 // Skip computation if we can assume everything is uniform.
176 if (targetTransformInfo.hasBranchDivergence(m_function))
177 m_uniformityInfo.compute();
179 return false;
182 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
183 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
186 void UniformityInfoWrapperPass::releaseMemory() {
187 m_uniformityInfo = UniformityInfo{};
188 m_function = nullptr;