//===- MachineUniformityAnalysis.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineSSAContext.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;
21 bool llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::hasDivergentDefs(
22 const MachineInstr
&I
) const {
23 for (auto &op
: I
.all_defs()) {
24 if (isDivergent(op
.getReg()))
31 bool llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::markDefsDivergent(
32 const MachineInstr
&Instr
) {
33 bool insertedDivergent
= false;
34 const auto &MRI
= F
.getRegInfo();
35 const auto &RBI
= *F
.getSubtarget().getRegBankInfo();
36 const auto &TRI
= *MRI
.getTargetRegisterInfo();
37 for (auto &op
: Instr
.all_defs()) {
38 if (!op
.getReg().isVirtual())
40 assert(!op
.getSubReg());
41 if (TRI
.isUniformReg(MRI
, RBI
, op
.getReg()))
43 insertedDivergent
|= markDivergent(op
.getReg());
45 return insertedDivergent
;
49 void llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::initialize() {
50 const auto &InstrInfo
= *F
.getSubtarget().getInstrInfo();
52 for (const MachineBasicBlock
&block
: F
) {
53 for (const MachineInstr
&instr
: block
) {
54 auto uniformity
= InstrInfo
.getInstructionUniformity(instr
);
55 if (uniformity
== InstructionUniformity::AlwaysUniform
) {
56 addUniformOverride(instr
);
60 if (uniformity
== InstructionUniformity::NeverUniform
) {
68 void llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::pushUsers(
70 assert(isDivergent(Reg
));
71 const auto &RegInfo
= F
.getRegInfo();
72 for (MachineInstr
&UserInstr
: RegInfo
.use_instructions(Reg
)) {
73 markDivergent(UserInstr
);
78 void llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::pushUsers(
79 const MachineInstr
&Instr
) {
80 assert(!isAlwaysUniform(Instr
));
81 if (Instr
.isTerminator())
83 for (const MachineOperand
&op
: Instr
.all_defs()) {
84 auto Reg
= op
.getReg();
91 bool llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::usesValueFromCycle(
92 const MachineInstr
&I
, const MachineCycle
&DefCycle
) const {
93 assert(!isAlwaysUniform(I
));
94 for (auto &Op
: I
.operands()) {
95 if (!Op
.isReg() || !Op
.readsReg())
97 auto Reg
= Op
.getReg();
99 // FIXME: Physical registers need to be properly checked instead of always
101 if (Reg
.isPhysical())
104 auto *Def
= F
.getRegInfo().getVRegDef(Reg
);
105 if (DefCycle
.contains(Def
->getParent()))
112 void llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::
113 propagateTemporalDivergence(const MachineInstr
&I
,
114 const MachineCycle
&DefCycle
) {
115 const auto &RegInfo
= F
.getRegInfo();
116 for (auto &Op
: I
.all_defs()) {
117 if (!Op
.getReg().isVirtual())
119 auto Reg
= Op
.getReg();
120 if (isDivergent(Reg
))
122 for (MachineInstr
&UserInstr
: RegInfo
.use_instructions(Reg
)) {
123 if (DefCycle
.contains(UserInstr
.getParent()))
125 markDivergent(UserInstr
);
131 bool llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>::isDivergentUse(
132 const MachineOperand
&U
) const {
136 auto Reg
= U
.getReg();
137 if (isDivergent(Reg
))
140 const auto &RegInfo
= F
.getRegInfo();
141 auto *Def
= RegInfo
.getOneDef(Reg
);
145 auto *DefInstr
= Def
->getParent();
146 auto *UseInstr
= U
.getParent();
147 return isTemporalDivergent(*UseInstr
->getParent(), *DefInstr
);
150 // This ensures explicit instantiation of
151 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
152 template class llvm::GenericUniformityInfo
<MachineSSAContext
>;
153 template struct llvm::GenericUniformityAnalysisImplDeleter
<
154 llvm::GenericUniformityAnalysisImpl
<MachineSSAContext
>>;
156 MachineUniformityInfo
llvm::computeMachineUniformityInfo(
157 MachineFunction
&F
, const MachineCycleInfo
&cycleInfo
,
158 const MachineDomTree
&domTree
, bool HasBranchDivergence
) {
159 assert(F
.getRegInfo().isSSA() && "Expected to be run on SSA form!");
160 MachineUniformityInfo
UI(domTree
, cycleInfo
);
161 if (HasBranchDivergence
)
168 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
169 class MachineUniformityAnalysisPass
: public MachineFunctionPass
{
170 MachineUniformityInfo UI
;
175 MachineUniformityAnalysisPass();
177 MachineUniformityInfo
&getUniformityInfo() { return UI
; }
178 const MachineUniformityInfo
&getUniformityInfo() const { return UI
; }
180 bool runOnMachineFunction(MachineFunction
&F
) override
;
181 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
182 void print(raw_ostream
&OS
, const Module
*M
= nullptr) const override
;
184 // TODO: verify analysis
187 class MachineUniformityInfoPrinterPass
: public MachineFunctionPass
{
191 MachineUniformityInfoPrinterPass();
193 bool runOnMachineFunction(MachineFunction
&F
) override
;
194 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
199 char MachineUniformityAnalysisPass::ID
= 0;
201 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
202 : MachineFunctionPass(ID
) {
203 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
206 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass
, "machine-uniformity",
207 "Machine Uniformity Info Analysis", true, true)
208 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass
)
209 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree
)
210 INITIALIZE_PASS_END(MachineUniformityAnalysisPass
, "machine-uniformity",
211 "Machine Uniformity Info Analysis", true, true)
213 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage
&AU
) const {
214 AU
.setPreservesAll();
215 AU
.addRequired
<MachineCycleInfoWrapperPass
>();
216 AU
.addRequired
<MachineDominatorTree
>();
217 MachineFunctionPass::getAnalysisUsage(AU
);
220 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction
&MF
) {
221 auto &DomTree
= getAnalysis
<MachineDominatorTree
>().getBase();
222 auto &CI
= getAnalysis
<MachineCycleInfoWrapperPass
>().getCycleInfo();
223 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
225 UI
= computeMachineUniformityInfo(MF
, CI
, DomTree
, true);
229 void MachineUniformityAnalysisPass::print(raw_ostream
&OS
,
230 const Module
*) const {
231 OS
<< "MachineUniformityInfo for function: " << UI
.getFunction().getName()
236 char MachineUniformityInfoPrinterPass::ID
= 0;
238 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
239 : MachineFunctionPass(ID
) {
240 initializeMachineUniformityInfoPrinterPassPass(
241 *PassRegistry::getPassRegistry());
244 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass
,
245 "print-machine-uniformity",
246 "Print Machine Uniformity Info Analysis", true, true)
247 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass
)
248 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass
,
249 "print-machine-uniformity",
250 "Print Machine Uniformity Info Analysis", true, true)
252 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
253 AnalysisUsage
&AU
) const {
254 AU
.setPreservesAll();
255 AU
.addRequired
<MachineUniformityAnalysisPass
>();
256 MachineFunctionPass::getAnalysisUsage(AU
);
259 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
260 MachineFunction
&F
) {
261 auto &UI
= getAnalysis
<MachineUniformityAnalysisPass
>();