1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineLoopInfo.h"
22 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
23 #include "llvm/CodeGen/MachinePostDominators.h"
24 #include "llvm/CodeGen/Passes.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
31 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
34 using namespace sampleprof
;
35 using namespace llvm::sampleprofutil
;
36 using ProfileCount
= Function::ProfileCount
;
38 #define DEBUG_TYPE "fs-profile-loader"
40 static cl::opt
<bool> ShowFSBranchProb(
41 "show-fs-branchprob", cl::Hidden
, cl::init(false),
42 cl::desc("Print setting flow sensitive branch probabilities"));
43 static cl::opt
<unsigned> FSProfileDebugProbDiffThreshold(
44 "fs-profile-debug-prob-diff-threshold", cl::init(10),
45 cl::desc("Only show debug message if the branch probility is greater than "
46 "this value (in percentage)."));
48 static cl::opt
<unsigned> FSProfileDebugBWThreshold(
49 "fs-profile-debug-bw-threshold", cl::init(10000),
50 cl::desc("Only show debug message if the source branch weight is greater "
51 " than this value."));
53 static cl::opt
<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden
,
55 cl::desc("View BFI before MIR loader"));
56 static cl::opt
<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden
,
58 cl::desc("View BFI after MIR loader"));
60 char MIRProfileLoaderPass::ID
= 0;
62 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass
, DEBUG_TYPE
,
63 "Load MIR Sample Profile",
64 /* cfg = */ false, /* is_analysis = */ false)
65 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo
)
66 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree
)
67 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree
)
68 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo
)
69 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass
)
70 INITIALIZE_PASS_END(MIRProfileLoaderPass
, DEBUG_TYPE
, "Load MIR Sample Profile",
71 /* cfg = */ false, /* is_analysis = */ false)
73 char &llvm::MIRProfileLoaderPassID
= MIRProfileLoaderPass::ID
;
75 FunctionPass
*llvm::createMIRProfileLoaderPass(std::string File
,
76 std::string RemappingFile
,
77 FSDiscriminatorPass P
) {
78 return new MIRProfileLoaderPass(File
, RemappingFile
, P
);
83 // Internal option used to control BFI display only after MBP pass.
84 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
85 // -view-block-layout-with-bfi={none | fraction | integer | count}
86 extern cl::opt
<GVDAGType
> ViewBlockLayoutWithBFI
;
88 // Command line option to specify the name of the function for CFG dump
89 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
90 extern cl::opt
<std::string
> ViewBlockFreqFuncName
;
92 namespace afdo_detail
{
93 template <> struct IRTraits
<MachineBasicBlock
> {
94 using InstructionT
= MachineInstr
;
95 using BasicBlockT
= MachineBasicBlock
;
96 using FunctionT
= MachineFunction
;
97 using BlockFrequencyInfoT
= MachineBlockFrequencyInfo
;
98 using LoopT
= MachineLoop
;
99 using LoopInfoPtrT
= MachineLoopInfo
*;
100 using DominatorTreePtrT
= MachineDominatorTree
*;
101 using PostDominatorTreePtrT
= MachinePostDominatorTree
*;
102 using PostDominatorTreeT
= MachinePostDominatorTree
;
103 using OptRemarkEmitterT
= MachineOptimizationRemarkEmitter
;
104 using OptRemarkAnalysisT
= MachineOptimizationRemarkAnalysis
;
105 using PredRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
106 using SuccRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
107 static Function
&getFunction(MachineFunction
&F
) { return F
.getFunction(); }
108 static const MachineBasicBlock
*getEntryBB(const MachineFunction
*F
) {
109 return GraphTraits
<const MachineFunction
*>::getEntryNode(F
);
111 static PredRangeT
getPredecessors(MachineBasicBlock
*BB
) {
112 return BB
->predecessors();
114 static SuccRangeT
getSuccessors(MachineBasicBlock
*BB
) {
115 return BB
->successors();
118 } // namespace afdo_detail
120 class MIRProfileLoader final
121 : public SampleProfileLoaderBaseImpl
<MachineBasicBlock
> {
123 void setInitVals(MachineDominatorTree
*MDT
, MachinePostDominatorTree
*MPDT
,
124 MachineLoopInfo
*MLI
, MachineBlockFrequencyInfo
*MBFI
,
125 MachineOptimizationRemarkEmitter
*MORE
) {
132 void setFSPass(FSDiscriminatorPass Pass
) {
134 LowBit
= getFSPassBitBegin(P
);
135 HighBit
= getFSPassBitEnd(P
);
136 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
139 MIRProfileLoader(StringRef Name
, StringRef RemapName
)
140 : SampleProfileLoaderBaseImpl(std::string(Name
), std::string(RemapName
)) {
143 void setBranchProbs(MachineFunction
&F
);
144 bool runOnFunction(MachineFunction
&F
);
145 bool doInitialization(Module
&M
);
146 bool isValid() const { return ProfileIsValid
; }
149 friend class SampleCoverageTracker
;
151 /// Hold the information of the basic block frequency.
152 MachineBlockFrequencyInfo
*BFI
;
154 /// PassNum is the sequence number in which this pass is called, starting from 1.
155 FSDiscriminatorPass P
;
157 // LowBit in the FS discriminator used by this instance. Note the number is
158 // 0-based. Base discriminator uses bit 0 to bit 11.
160 // HighBit in the FS discriminator used by this instance. Note the number
164 bool ProfileIsValid
= true;
168 void SampleProfileLoaderBaseImpl
<
169 MachineBasicBlock
>::computeDominanceAndLoopInfo(MachineFunction
&F
) {}
171 void MIRProfileLoader::setBranchProbs(MachineFunction
&F
) {
172 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
174 MachineBasicBlock
*BB
= &BI
;
175 if (BB
->succ_size() < 2)
177 const MachineBasicBlock
*EC
= EquivalenceClass
[BB
];
178 uint64_t BBWeight
= BlockWeights
[EC
];
179 uint64_t SumEdgeWeight
= 0;
180 for (MachineBasicBlock
*Succ
: BB
->successors()) {
181 Edge E
= std::make_pair(BB
, Succ
);
182 SumEdgeWeight
+= EdgeWeights
[E
];
185 if (BBWeight
!= SumEdgeWeight
) {
186 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
187 << BBWeight
<< " SumEdgeWeight= " << SumEdgeWeight
189 BBWeight
= SumEdgeWeight
;
192 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
197 uint64_t BBWeightOrig
= BBWeight
;
199 uint32_t MaxWeight
= std::numeric_limits
<uint32_t>::max();
201 if (BBWeight
> MaxWeight
) {
202 Factor
= BBWeight
/ MaxWeight
+ 1;
204 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor
<< "\n");
207 for (MachineBasicBlock::succ_iterator SI
= BB
->succ_begin(),
210 MachineBasicBlock
*Succ
= *SI
;
211 Edge E
= std::make_pair(BB
, Succ
);
212 uint64_t EdgeWeight
= EdgeWeights
[E
];
213 EdgeWeight
/= Factor
;
215 assert(BBWeight
>= EdgeWeight
&&
216 "BBweight is larger than EdgeWeight -- should not happen.\n");
218 BranchProbability OldProb
= BFI
->getMBPI()->getEdgeProbability(BB
, SI
);
219 BranchProbability
NewProb(EdgeWeight
, BBWeight
);
220 if (OldProb
== NewProb
)
222 BB
->setSuccProbability(SI
, NewProb
);
224 if (!ShowFSBranchProb
)
227 BranchProbability Diff
;
228 if (OldProb
> NewProb
)
229 Diff
= OldProb
- NewProb
;
231 Diff
= NewProb
- OldProb
;
232 Show
= (Diff
>= BranchProbability(FSProfileDebugProbDiffThreshold
, 100));
233 Show
&= (BBWeightOrig
>= FSProfileDebugBWThreshold
);
235 auto DIL
= BB
->findBranchDebugLoc();
236 auto SuccDIL
= Succ
->findBranchDebugLoc();
238 dbgs() << "Set branch fs prob: MBB (" << BB
->getNumber() << " -> "
239 << Succ
->getNumber() << "): ";
241 dbgs() << DIL
->getFilename() << ":" << DIL
->getLine() << ":"
244 dbgs() << "-->" << SuccDIL
->getFilename() << ":" << SuccDIL
->getLine()
245 << ":" << SuccDIL
->getColumn();
246 dbgs() << " W=" << BBWeightOrig
<< " " << OldProb
<< " --> " << NewProb
254 bool MIRProfileLoader::doInitialization(Module
&M
) {
255 auto &Ctx
= M
.getContext();
257 auto ReaderOrErr
= sampleprof::SampleProfileReader::create(Filename
, Ctx
, P
,
259 if (std::error_code EC
= ReaderOrErr
.getError()) {
260 std::string Msg
= "Could not open profile: " + EC
.message();
261 Ctx
.diagnose(DiagnosticInfoSampleProfile(Filename
, Msg
));
265 Reader
= std::move(ReaderOrErr
.get());
266 Reader
->setModule(&M
);
267 ProfileIsValid
= (Reader
->read() == sampleprof_error::success
);
268 Reader
->getSummary();
273 bool MIRProfileLoader::runOnFunction(MachineFunction
&MF
) {
274 Function
&Func
= MF
.getFunction();
275 clearFunctionData(false);
276 Samples
= Reader
->getSamplesFor(Func
);
277 if (!Samples
|| Samples
->empty())
280 if (getFunctionLoc(MF
) == 0)
283 DenseSet
<GlobalValue::GUID
> InlinedGUIDs
;
284 bool Changed
= computeAndPropagateWeights(MF
, InlinedGUIDs
);
286 // Set the new BPI, BFI.
294 MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName
,
295 std::string RemappingFileName
,
296 FSDiscriminatorPass P
)
297 : MachineFunctionPass(ID
), ProfileFileName(FileName
), P(P
),
299 std::make_unique
<MIRProfileLoader
>(FileName
, RemappingFileName
)) {
300 LowBit
= getFSPassBitBegin(P
);
301 HighBit
= getFSPassBitEnd(P
);
302 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
305 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction
&MF
) {
306 if (!MIRSampleLoader
->isValid())
309 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
310 << MF
.getFunction().getName() << "\n");
311 MBFI
= &getAnalysis
<MachineBlockFrequencyInfo
>();
312 MIRSampleLoader
->setInitVals(
313 &getAnalysis
<MachineDominatorTree
>(),
314 &getAnalysis
<MachinePostDominatorTree
>(), &getAnalysis
<MachineLoopInfo
>(),
315 MBFI
, &getAnalysis
<MachineOptimizationRemarkEmitterPass
>().getORE());
318 if (ViewBFIBefore
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
319 (ViewBlockFreqFuncName
.empty() ||
320 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
321 MBFI
->view("MIR_Prof_loader_b." + MF
.getName(), false);
324 bool Changed
= MIRSampleLoader
->runOnFunction(MF
);
326 MBFI
->calculate(MF
, *MBFI
->getMBPI(), *&getAnalysis
<MachineLoopInfo
>());
328 if (ViewBFIAfter
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
329 (ViewBlockFreqFuncName
.empty() ||
330 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
331 MBFI
->view("MIR_prof_loader_a." + MF
.getName(), false);
337 bool MIRProfileLoaderPass::doInitialization(Module
&M
) {
338 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M
.getName()
341 MIRSampleLoader
->setFSPass(P
);
342 return MIRSampleLoader
->doInitialization(M
);
345 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage
&AU
) const {
346 AU
.setPreservesAll();
347 AU
.addRequired
<MachineBlockFrequencyInfo
>();
348 AU
.addRequired
<MachineDominatorTree
>();
349 AU
.addRequired
<MachinePostDominatorTree
>();
350 AU
.addRequiredTransitive
<MachineLoopInfo
>();
351 AU
.addRequired
<MachineOptimizationRemarkEmitterPass
>();
352 MachineFunctionPass::getAnalysisUsage(AU
);