1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/IR/Function.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
23 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
26 using namespace sampleprof
;
27 using namespace llvm::sampleprofutil
;
28 using ProfileCount
= Function::ProfileCount
;
30 #define DEBUG_TYPE "fs-profile-loader"
32 static cl::opt
<bool> ShowFSBranchProb(
33 "show-fs-branchprob", cl::Hidden
, cl::init(false),
34 cl::desc("Print setting flow sensitive branch probabilities"));
35 static cl::opt
<unsigned> FSProfileDebugProbDiffThreshold(
36 "fs-profile-debug-prob-diff-threshold", cl::init(10),
37 cl::desc("Only show debug message if the branch probility is greater than "
38 "this value (in percentage)."));
40 static cl::opt
<unsigned> FSProfileDebugBWThreshold(
41 "fs-profile-debug-bw-threshold", cl::init(10000),
42 cl::desc("Only show debug message if the source branch weight is greater "
43 " than this value."));
45 static cl::opt
<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden
,
47 cl::desc("View BFI before MIR loader"));
48 static cl::opt
<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden
,
50 cl::desc("View BFI after MIR loader"));
52 char MIRProfileLoaderPass::ID
= 0;
54 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass
, DEBUG_TYPE
,
55 "Load MIR Sample Profile",
56 /* cfg = */ false, /* is_analysis = */ false)
57 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo
)
58 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree
)
59 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree
)
60 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo
)
61 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass
)
62 INITIALIZE_PASS_END(MIRProfileLoaderPass
, DEBUG_TYPE
, "Load MIR Sample Profile",
63 /* cfg = */ false, /* is_analysis = */ false)
65 char &llvm::MIRProfileLoaderPassID
= MIRProfileLoaderPass::ID
;
67 FunctionPass
*llvm::createMIRProfileLoaderPass(std::string File
,
68 std::string RemappingFile
,
69 FSDiscriminatorPass P
) {
70 return new MIRProfileLoaderPass(File
, RemappingFile
, P
);
75 // Internal option used to control BFI display only after MBP pass.
76 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
77 // -view-block-layout-with-bfi={none | fraction | integer | count}
78 extern cl::opt
<GVDAGType
> ViewBlockLayoutWithBFI
;
80 // Command line option to specify the name of the function for CFG dump
81 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
82 extern cl::opt
<std::string
> ViewBlockFreqFuncName
;
84 namespace afdo_detail
{
85 template <> struct IRTraits
<MachineBasicBlock
> {
86 using InstructionT
= MachineInstr
;
87 using BasicBlockT
= MachineBasicBlock
;
88 using FunctionT
= MachineFunction
;
89 using BlockFrequencyInfoT
= MachineBlockFrequencyInfo
;
90 using LoopT
= MachineLoop
;
91 using LoopInfoPtrT
= MachineLoopInfo
*;
92 using DominatorTreePtrT
= MachineDominatorTree
*;
93 using PostDominatorTreePtrT
= MachinePostDominatorTree
*;
94 using PostDominatorTreeT
= MachinePostDominatorTree
;
95 using OptRemarkEmitterT
= MachineOptimizationRemarkEmitter
;
96 using OptRemarkAnalysisT
= MachineOptimizationRemarkAnalysis
;
97 using PredRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
98 using SuccRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
99 static Function
&getFunction(MachineFunction
&F
) { return F
.getFunction(); }
100 static const MachineBasicBlock
*getEntryBB(const MachineFunction
*F
) {
101 return GraphTraits
<const MachineFunction
*>::getEntryNode(F
);
103 static PredRangeT
getPredecessors(MachineBasicBlock
*BB
) {
104 return BB
->predecessors();
106 static SuccRangeT
getSuccessors(MachineBasicBlock
*BB
) {
107 return BB
->successors();
110 } // namespace afdo_detail
// MIRProfileLoader: the Machine IR instantiation of the generic sample
// profile loader (it derives from SampleProfileLoaderBaseImpl<MachineBasicBlock>).
// NOTE(review): this view of the class is incomplete. Access specifiers, the
// body of setInitVals, several closing braces and some member declarations
// are not visible; the comments below describe only what is visible.
112 class MIRProfileLoader final
113 : public SampleProfileLoaderBaseImpl
<MachineBasicBlock
> {
// Receives the dominator tree, post-dominator tree, loop info, block
// frequency info and remark emitter. runOnMachineFunction calls this with
// the results of getAnalysis<>(). NOTE(review): the body is not visible here.
115 void setInitVals(MachineDominatorTree
*MDT
, MachinePostDominatorTree
*MPDT
,
116 MachineLoopInfo
*MLI
, MachineBlockFrequencyInfo
*MBFI
,
117 MachineOptimizationRemarkEmitter
*MORE
) {
// Derives LowBit/HighBit from P via getFSPassBitBegin/getFSPassBitEnd and
// asserts LowBit < HighBit. NOTE(review): the parameter Pass is not
// referenced in the visible lines; presumably an assignment to P was lost
// from this view -- confirm against the real file.
124 void setFSPass(FSDiscriminatorPass Pass
) {
126 LowBit
= getFSPassBitBegin(P
);
127 HighBit
= getFSPassBitEnd(P
);
128 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
// Constructor: copies both StringRefs into std::string and forwards them to
// the base-class constructor.
131 MIRProfileLoader(StringRef Name
, StringRef RemapName
)
132 : SampleProfileLoaderBaseImpl(std::string(Name
), std::string(RemapName
)) {
// Out-of-line members, defined later in this file.
135 void setBranchProbs(MachineFunction
&F
);
136 bool runOnFunction(MachineFunction
&F
);
137 bool doInitialization(Module
&M
);
// Reports ProfileIsValid, which doInitialization sets from the result of
// Reader->read().
138 bool isValid() const { return ProfileIsValid
; }
141 friend class SampleCoverageTracker
;
143 /// Hold the information of the basic block frequency.
144 MachineBlockFrequencyInfo
*BFI
;
146 /// The FS discriminator pass this instance serves (the sequence number of the pass, starting from 1).
147 FSDiscriminatorPass P
;
149 // LowBit in the FS discriminator used by this instance. Note the number is
150 // 0-based. The base discriminator uses bit 0 to bit 11.
152 // HighBit in the FS discriminator used by this instance. Note the number
// Defaults to true; overwritten in doInitialization.
156 bool ProfileIsValid
= true;
160 void SampleProfileLoaderBaseImpl
<
161 MachineBasicBlock
>::computeDominanceAndLoopInfo(MachineFunction
&F
) {}
// Sets successor branch probabilities from the propagated sample weights.
// For a block BB the visible code: sums EdgeWeights over BB's successors,
// replaces BBWeight by that sum when the two differ, scales edge weights by
// Factor when BBWeight exceeds the uint32_t maximum, and calls
// BB->setSuccProbability(SI, NewProb) with NewProb = EdgeWeight / BBWeight.
// When ShowFSBranchProb is set it also prints the change to dbgs().
// NOTE(review): many lines of this function are missing from this view (the
// loop that produces BI, the tails of both successor-loop headers, the
// statements guarded by several `if`s, the declarations of Factor and Show,
// and the closing braces). Comments describe only what is visible.
163 void MIRProfileLoader::setBranchProbs(MachineFunction
&F
) {
164 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
166 MachineBasicBlock
*BB
= &BI
;
// Blocks with fewer than two successors are special-cased; the guarded
// statement is not visible (presumably they are skipped -- confirm).
167 if (BB
->succ_size() < 2)
// The block weight is looked up through the block's equivalence class.
169 const MachineBasicBlock
*EC
= EquivalenceClass
[BB
];
170 uint64_t BBWeight
= BlockWeights
[EC
];
// First pass over the successors: accumulate the outgoing edge weights.
171 uint64_t SumEdgeWeight
= 0;
172 for (MachineBasicBlock::succ_iterator SI
= BB
->succ_begin(),
175 MachineBasicBlock
*Succ
= *SI
;
176 Edge E
= std::make_pair(BB
, Succ
);
177 SumEdgeWeight
+= EdgeWeights
[E
];
// If the block weight disagrees with the edge-weight sum, the sum wins.
// NOTE(review): the debug text spells the label "BBWWeight=" (extra W).
180 if (BBWeight
!= SumEdgeWeight
) {
181 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
182 << BBWeight
<< " SumEdgeWeight= " << SumEdgeWeight
184 BBWeight
= SumEdgeWeight
;
// NOTE(review): the condition guarding this message is not visible.
187 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
// Keep the unscaled weight for the debug threshold test and the printout.
192 uint64_t BBWeightOrig
= BBWeight
;
194 uint32_t MaxWeight
= std::numeric_limits
<uint32_t>::max();
// When BBWeight exceeds the uint32_t maximum, a scaling Factor is computed.
// NOTE(review): Factor's declaration and any rescaling of BBWeight are not
// visible in this view.
196 if (BBWeight
> MaxWeight
) {
197 Factor
= BBWeight
/ MaxWeight
+ 1;
199 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor
<< "\n");
// Second pass over the successors: compute and install the new probability.
202 for (MachineBasicBlock::succ_iterator SI
= BB
->succ_begin(),
205 MachineBasicBlock
*Succ
= *SI
;
206 Edge E
= std::make_pair(BB
, Succ
);
207 uint64_t EdgeWeight
= EdgeWeights
[E
];
208 EdgeWeight
/= Factor
;
// NOTE(review): the assertion requires BBWeight >= EdgeWeight, so it fires
// when EdgeWeight is the larger one; the message text states the opposite.
210 assert(BBWeight
>= EdgeWeight
&&
211 "BBweight is larger than EdgeWeight -- should not happen.\n");
213 BranchProbability OldProb
= BFI
->getMBPI()->getEdgeProbability(BB
, SI
);
214 BranchProbability
NewProb(EdgeWeight
, BBWeight
);
// Unchanged probabilities are special-cased; the guarded statement is not
// visible (presumably the edge is skipped -- confirm).
215 if (OldProb
== NewProb
)
217 BB
->setSuccProbability(SI
, NewProb
);
// Everything below is diagnostic output tied to -show-fs-branchprob.
219 if (!ShowFSBranchProb
)
// Diff is the absolute difference between the old and new probability.
222 BranchProbability Diff
;
223 if (OldProb
> NewProb
)
224 Diff
= OldProb
- NewProb
;
226 Diff
= NewProb
- OldProb
;
// Show requires both the probability change (as a percentage) and the
// unscaled block weight to reach their command-line thresholds.
227 Show
= (Diff
>= BranchProbability(FSProfileDebugProbDiffThreshold
, 100));
228 Show
&= (BBWeightOrig
>= FSProfileDebugBWThreshold
);
230 auto DIL
= BB
->findBranchDebugLoc();
231 auto SuccDIL
= Succ
->findBranchDebugLoc();
233 dbgs() << "Set branch fs prob: MBB (" << BB
->getNumber() << " -> "
234 << Succ
->getNumber() << "): ";
// NOTE(review): any checks on Show, DIL and SuccDIL that guard these prints
// are not visible in this view.
236 dbgs() << DIL
->getFilename() << ":" << DIL
->getLine() << ":"
239 dbgs() << "-->" << SuccDIL
->getFilename() << ":" << SuccDIL
->getLine()
240 << ":" << SuccDIL
->getColumn();
241 dbgs() << " W=" << BBWeightOrig
<< " " << OldProb
<< " --> " << NewProb
// Creates the sample profile reader for Filename and reads the profile.
// ProfileIsValid records whether Reader->read() returned
// sampleprof_error::success.
// NOTE(review): lines are missing from this view (the remaining arguments of
// SampleProfileReader::create, the statements that end the error branch, the
// return statements and the closing braces).
249 bool MIRProfileLoader::doInitialization(Module
&M
) {
250 auto &Ctx
= M
.getContext();
252 auto ReaderOrErr
= sampleprof::SampleProfileReader::create(Filename
, Ctx
, P
,
// On failure, report the error through the LLVMContext as a sample-profile
// diagnostic.
254 if (std::error_code EC
= ReaderOrErr
.getError()) {
255 std::string Msg
= "Could not open profile: " + EC
.message();
256 Ctx
.diagnose(DiagnosticInfoSampleProfile(Filename
, Msg
));
// Take ownership of the reader, attach the module, then read the profile.
260 Reader
= std::move(ReaderOrErr
.get());
261 Reader
->setModule(&M
);
262 ProfileIsValid
= (Reader
->read() == sampleprof_error::success
);
// NOTE(review): the value returned by getSummary() is not used in the
// visible code.
263 Reader
->getSummary();
// Loads the samples for MF's IR function and runs
// computeAndPropagateWeights on MF.
// NOTE(review): lines are missing from this view (the statements guarded by
// the two `if`s, everything after the "Set the new BPI, BFI" comment, the
// return statement and the closing brace).
268 bool MIRProfileLoader::runOnFunction(MachineFunction
&MF
) {
269 Function
&Func
= MF
.getFunction();
270 clearFunctionData(false);
271 Samples
= Reader
->getSamplesFor(Func
);
// No samples, or an empty sample set, is special-cased; the guarded
// statement is not visible (presumably an early return -- confirm).
272 if (!Samples
|| Samples
->empty())
// A function location of 0 is special-cased as well; the guarded statement
// is not visible.
275 if (getFunctionLoc(MF
) == 0)
// InlinedGUIDs is passed to computeAndPropagateWeights and not read again in
// the visible code.
278 DenseSet
<GlobalValue::GUID
> InlinedGUIDs
;
279 bool Changed
= computeAndPropagateWeights(MF
, InlinedGUIDs
);
281 // Set the new BPI, BFI.
// Constructor: initialises the MachineFunctionPass base with ID, stores the
// profile file name and the FS discriminator pass, builds a MIRProfileLoader
// with std::make_unique, then derives LowBit/HighBit from P and asserts
// LowBit < HighBit.
// NOTE(review): the member-initializer name that receives the make_unique
// result is missing from this view (presumably MIRSampleLoader, which the
// other methods dereference -- confirm), as is the closing brace.
289 MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName
,
290 std::string RemappingFileName
,
291 FSDiscriminatorPass P
)
292 : MachineFunctionPass(ID
), ProfileFileName(FileName
), P(P
),
294 std::make_unique
<MIRProfileLoader
>(FileName
, RemappingFileName
)) {
295 LowBit
= getFSPassBitBegin(P
);
296 HighBit
= getFSPassBitEnd(P
);
297 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
// Per-function entry point of the pass. It fetches the required analyses,
// hands them to the loader via setInitVals, optionally views the BFI graph
// before and after, and runs MIRSampleLoader->runOnFunction(MF).
// NOTE(review): lines are missing from this view (the statement guarded by
// the validity check, the closing braces of both view blocks, the return
// statement and the function's closing brace).
300 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction
&MF
) {
// An invalid profile is special-cased; the guarded statement is not visible
// (presumably an early return -- confirm).
301 if (!MIRSampleLoader
->isValid())
304 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
305 << MF
.getFunction().getName() << "\n");
306 MBFI
= &getAnalysis
<MachineBlockFrequencyInfo
>();
307 MIRSampleLoader
->setInitVals(
308 &getAnalysis
<MachineDominatorTree
>(),
309 &getAnalysis
<MachinePostDominatorTree
>(), &getAnalysis
<MachineLoopInfo
>(),
310 MBFI
, &getAnalysis
<MachineOptimizationRemarkEmitterPass
>().getORE());
// View the BFI before loading when -fs-viewbfi-before is set, the
// layout-view option is not GVDT_None, and the function name matches
// -view-bfi-func-name (or that option is empty).
313 if (ViewBFIBefore
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
314 (ViewBlockFreqFuncName
.empty() ||
315 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
316 MBFI
->view("MIR_Prof_loader_b." + MF
.getName(), false);
319 bool Changed
= MIRSampleLoader
->runOnFunction(MF
);
// Same conditions as above, keyed on -fs-viewbfi-after. Note the graph-name
// prefix here is "MIR_prof_loader_a." (lower-case 'p'), unlike the
// "MIR_Prof_loader_b." used before loading.
321 if (ViewBFIAfter
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
322 (ViewBlockFreqFuncName
.empty() ||
323 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
324 MBFI
->view("MIR_prof_loader_a." + MF
.getName(), false);
// Module-level initialisation: passes the FS discriminator pass P to the
// loader via setFSPass, then returns the result of the loader's own
// doInitialization(M).
// NOTE(review): the tail of the LLVM_DEBUG statement and the closing brace
// are missing from this view.
330 bool MIRProfileLoaderPass::doInitialization(Module
&M
) {
331 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M
.getName()
334 MIRSampleLoader
->setFSPass(P
);
335 return MIRSampleLoader
->doInitialization(M
);
338 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage
&AU
) const {
339 AU
.setPreservesAll();
340 AU
.addRequired
<MachineBlockFrequencyInfo
>();
341 AU
.addRequired
<MachineDominatorTree
>();
342 AU
.addRequired
<MachinePostDominatorTree
>();
343 AU
.addRequiredTransitive
<MachineLoopInfo
>();
344 AU
.addRequired
<MachineOptimizationRemarkEmitterPass
>();
345 MachineFunctionPass::getAnalysisUsage(AU
);