1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineLoopInfo.h"
23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
24 #include "llvm/CodeGen/MachinePostDominators.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/PseudoProbe.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/VirtualFileSystem.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
// Bring the sample-profile data types (sampleprof) and the shared
// sample-profile loader helpers (sampleprofutil) into scope.
38 using namespace sampleprof
;
39 using namespace llvm::sampleprofutil
;
// Short alias for Function::ProfileCount.
40 using ProfileCount
= Function::ProfileCount
;
// Debug tag for the LLVM_DEBUG output emitted in this file.
42 #define DEBUG_TYPE "fs-profile-loader"
44 static cl::opt
<bool> ShowFSBranchProb(
45 "show-fs-branchprob", cl::Hidden
, cl::init(false),
46 cl::desc("Print setting flow sensitive branch probabilities"));
47 static cl::opt
<unsigned> FSProfileDebugProbDiffThreshold(
48 "fs-profile-debug-prob-diff-threshold", cl::init(10),
49 cl::desc("Only show debug message if the branch probility is greater than "
50 "this value (in percentage)."));
52 static cl::opt
<unsigned> FSProfileDebugBWThreshold(
53 "fs-profile-debug-bw-threshold", cl::init(10000),
54 cl::desc("Only show debug message if the source branch weight is greater "
55 " than this value."));
57 static cl::opt
<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden
,
59 cl::desc("View BFI before MIR loader"));
60 static cl::opt
<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden
,
62 cl::desc("View BFI after MIR loader"));
// Option defined in another translation unit. When set, getInstWeight()
// returns no weight (an error_code) for meta-instructions.
65 extern cl::opt
<bool> ImprovedFSDiscriminator
;
// Legacy pass-manager identifier for MIRProfileLoaderPass.
67 char MIRProfileLoaderPass::ID
= 0;
// Register the pass under DEBUG_TYPE ("fs-profile-loader") together with the
// machine analyses it requests in getAnalysisUsage().
69 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass
, DEBUG_TYPE
,
70 "Load MIR Sample Profile",
71 /* cfg = */ false, /* is_analysis = */ false)
72 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo
)
73 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree
)
74 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree
)
75 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo
)
76 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass
)
77 INITIALIZE_PASS_END(MIRProfileLoaderPass
, DEBUG_TYPE
, "Load MIR Sample Profile",
78 /* cfg = */ false, /* is_analysis = */ false)
// Exported reference to the pass ID so other code can refer to this pass.
80 char &llvm::MIRProfileLoaderPassID
= MIRProfileLoaderPass::ID
;
// Factory: returns a newly allocated MIRProfileLoaderPass.
//   File          - sample profile to load.
//   RemappingFile - symbol remapping file passed to the profile reader
//                   (presumably empty when unused — TODO confirm).
//   P             - FS discriminator pass whose discriminator bits are used.
//   FS            - virtual file system to read from; the pass constructor
//                   falls back to the real file system when this is null.
// NOTE(review): the returned raw pointer is presumably owned by the pass
// manager that adds it — confirm at call sites.
83 llvm::createMIRProfileLoaderPass(std::string File
, std::string RemappingFile
,
84 FSDiscriminatorPass P
,
85 IntrusiveRefCntPtr
<vfs::FileSystem
> FS
) {
86 return new MIRProfileLoaderPass(File
, RemappingFile
, P
, std::move(FS
));
91 // Internal option used to control BFI display only after MBP pass.
92 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
93 // -view-block-layout-with-bfi={none | fraction | integer | count}
// In this file it also gates -fs-viewbfi-before/-fs-viewbfi-after: the BFI
// view is produced only when this option is not GVDT_None.
94 extern cl::opt
<GVDAGType
> ViewBlockLayoutWithBFI
;
96 // Command line option to specify the name of the function for CFG dump
97 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
// In this file an empty value means every function is viewed.
98 extern cl::opt
<std::string
> ViewBlockFreqFuncName
;
// Extract pseudo-probe information from MI when it is a pseudo-probe
// instruction: operands 1, 2 and 3 supply the probe Id, Type and Attr, and
// the discriminator is taken from MI's debug location (0 when MI has none).
// NOTE(review): what is returned for non-probe instructions is presumably an
// empty optional — confirm against the full definition.
100 std::optional
<PseudoProbe
> extractProbe(const MachineInstr
&MI
) {
101 if (MI
.isPseudoProbe()) {
103 Probe
.Id
= MI
.getOperand(1).getImm();
104 Probe
.Type
= MI
.getOperand(2).getImm();
105 Probe
.Attr
= MI
.getOperand(3).getImm();
// The FS discriminator lives on the instruction's debug location.
107 DILocation
*DebugLoc
= MI
.getDebugLoc();
108 Probe
.Discriminator
= DebugLoc
? DebugLoc
->getDiscriminator() : 0;
112 // Ignore callsite probes since they do not have FS discriminators.
// Machine-IR specialization of the sample-profile loader's IR traits (see
// SampleProfileLoaderBaseImpl.h): maps the generic instruction / block /
// function / analysis types onto their Machine* counterparts so the shared
// loader code can operate on MachineFunctions.
116 namespace afdo_detail
{
117 template <> struct IRTraits
<MachineBasicBlock
> {
118 using InstructionT
= MachineInstr
;
119 using BasicBlockT
= MachineBasicBlock
;
120 using FunctionT
= MachineFunction
;
121 using BlockFrequencyInfoT
= MachineBlockFrequencyInfo
;
122 using LoopT
= MachineLoop
;
123 using LoopInfoPtrT
= MachineLoopInfo
*;
124 using DominatorTreePtrT
= MachineDominatorTree
*;
125 using PostDominatorTreePtrT
= MachinePostDominatorTree
*;
126 using PostDominatorTreeT
= MachinePostDominatorTree
;
127 using OptRemarkEmitterT
= MachineOptimizationRemarkEmitter
;
128 using OptRemarkAnalysisT
= MachineOptimizationRemarkAnalysis
;
129 using PredRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
130 using SuccRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
// The IR Function underlying a MachineFunction.
131 static Function
&getFunction(MachineFunction
&F
) { return F
.getFunction(); }
// Entry block of F, obtained through GraphTraits.
132 static const MachineBasicBlock
*getEntryBB(const MachineFunction
*F
) {
133 return GraphTraits
<const MachineFunction
*>::getEntryNode(F
);
// Predecessor / successor ranges of a machine basic block.
135 static PredRangeT
getPredecessors(MachineBasicBlock
*BB
) {
136 return BB
->predecessors();
138 static SuccRangeT
getSuccessors(MachineBasicBlock
*BB
) {
139 return BB
->successors();
142 } // namespace afdo_detail
// Flow-sensitive AutoFDO (FSAFDO) sample-profile loader for Machine IR, built
// on the shared SampleProfileLoaderBaseImpl<MachineFunction>. It reads an FS
// profile and rewrites machine branch probabilities from the propagated
// block and edge weights.
144 class MIRProfileLoader final
145 : public SampleProfileLoaderBaseImpl
<MachineFunction
> {
// Supply the per-function analyses (dominators, post-dominators, loops,
// block frequency, remark emitter) before runOnFunction() is called.
147 void setInitVals(MachineDominatorTree
*MDT
, MachinePostDominatorTree
*MPDT
,
148 MachineLoopInfo
*MLI
, MachineBlockFrequencyInfo
*MBFI
,
149 MachineOptimizationRemarkEmitter
*MORE
) {
// Select the FS discriminator pass and derive this instance's discriminator
// bit range [LowBit, HighBit] via getFSPassBitBegin/getFSPassBitEnd.
156 void setFSPass(FSDiscriminatorPass Pass
) {
158 LowBit
= getFSPassBitBegin(P
);
159 HighBit
= getFSPassBitEnd(P
);
160 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
// Name: profile file; RemapName: remapping file; FS: file system to read from.
163 MIRProfileLoader(StringRef Name
, StringRef RemapName
,
164 IntrusiveRefCntPtr
<vfs::FileSystem
> FS
)
165 : SampleProfileLoaderBaseImpl(std::string(Name
), std::string(RemapName
),
// Driven by MIRProfileLoaderPass: doInitialization() once per module, then
// setInitVals() + runOnFunction() per machine function.
168 void setBranchProbs(MachineFunction
&F
);
169 bool runOnFunction(MachineFunction
&F
);
170 bool doInitialization(Module
&M
);
// False when the profile could not be read successfully.
171 bool isValid() const { return ProfileIsValid
; }
174 friend class SampleCoverageTracker
;
176 /// Hold the information of the basic block frequency.
177 MachineBlockFrequencyInfo
*BFI
;
179 /// FS discriminator pass this instance is configured for; setFSPass() derives LowBit/HighBit from it.
180 FSDiscriminatorPass P
;
182 // LowBit in the FS discriminator used by this instance. Note the number is
183 // 0-based. Base discriminator uses bits 0 to 11.
185 // HighBit in the FS discriminator used by this instance. Note the number
// Set by doInitialization() from the result of reading the profile.
189 bool ProfileIsValid
= true;
// Sample weight of MI: probe-based profiles use getProbeWeight(); with
// ImprovedFSDiscriminator, meta-instructions get no weight (an error_code);
// otherwise the shared getInstWeightImpl() is used.
190 ErrorOr
<uint64_t> getInstWeight(const MachineInstr
&MI
) override
{
191 if (FunctionSamples::ProfileIsProbeBased
)
192 return getProbeWeight(MI
);
193 if (ImprovedFSDiscriminator
&& MI
.isMetaInstruction())
194 return std::error_code();
195 return getInstWeightImpl(MI
);
// Machine-IR specialization with an intentionally empty body: this loader
// does not compute dominance or loop information itself (presumably because
// the Machine analyses are supplied through setInitVals() — confirm).
200 void SampleProfileLoaderBaseImpl
<MachineFunction
>::computeDominanceAndLoopInfo(
201 MachineFunction
&F
) {}
// Overwrite successor probabilities from the propagated weights.
// Only blocks with at least two successors are considered. For each one:
// - If the block weight (looked up through its equivalence class) differs
//   from the sum of its outgoing edge weights, the sum is used instead.
// - Weights above UINT32_MAX are divided by a common Factor (BranchProbability
//   works on 32-bit numerator/denominator).
// - Each edge's probability becomes EdgeWeight / BBWeight.
// With -show-fs-branchprob, changed probabilities are printed subject to the
// -fs-profile-debug-* thresholds.
203 void MIRProfileLoader::setBranchProbs(MachineFunction
&F
) {
204 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
206 MachineBasicBlock
*BB
= &BI
;
207 if (BB
->succ_size() < 2)
209 const MachineBasicBlock
*EC
= EquivalenceClass
[BB
];
210 uint64_t BBWeight
= BlockWeights
[EC
];
// Sum the inferred weights of all outgoing edges.
211 uint64_t SumEdgeWeight
= 0;
212 for (MachineBasicBlock
*Succ
: BB
->successors()) {
213 Edge E
= std::make_pair(BB
, Succ
);
214 SumEdgeWeight
+= EdgeWeights
[E
];
// Trust the edge-weight sum when it disagrees with the block weight.
217 if (BBWeight
!= SumEdgeWeight
) {
218 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
219 << BBWeight
<< " SumEdgeWeight= " << SumEdgeWeight
221 BBWeight
= SumEdgeWeight
;
224 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
// Keep the unscaled weight for the -fs-profile-debug-bw-threshold filter.
229 uint64_t BBWeightOrig
= BBWeight
;
231 uint32_t MaxWeight
= std::numeric_limits
<uint32_t>::max();
// Choose a divisor that brings the weights into 32-bit range.
233 if (BBWeight
> MaxWeight
) {
234 Factor
= BBWeight
/ MaxWeight
+ 1;
236 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor
<< "\n");
239 for (MachineBasicBlock::succ_iterator SI
= BB
->succ_begin(),
242 MachineBasicBlock
*Succ
= *SI
;
243 Edge E
= std::make_pair(BB
, Succ
);
244 uint64_t EdgeWeight
= EdgeWeights
[E
];
245 EdgeWeight
/= Factor
;
// NOTE(review): the message below reads inverted relative to the check; the
// failing case is an EdgeWeight larger than BBWeight.
247 assert(BBWeight
>= EdgeWeight
&&
248 "BBweight is larger than EdgeWeight -- should not happen.\n");
250 BranchProbability OldProb
= BFI
->getMBPI()->getEdgeProbability(BB
, SI
);
251 BranchProbability
NewProb(EdgeWeight
, BBWeight
);
// An unchanged probability needs no update or report.
252 if (OldProb
== NewProb
)
254 BB
->setSuccProbability(SI
, NewProb
);
// Everything below only produces -show-fs-branchprob diagnostics.
256 if (!ShowFSBranchProb
)
// Absolute difference between the old and new probability.
259 BranchProbability Diff
;
260 if (OldProb
> NewProb
)
261 Diff
= OldProb
- NewProb
;
263 Diff
= NewProb
- OldProb
;
264 Show
= (Diff
>= BranchProbability(FSProfileDebugProbDiffThreshold
, 100));
265 Show
&= (BBWeightOrig
>= FSProfileDebugBWThreshold
);
267 auto DIL
= BB
->findBranchDebugLoc();
268 auto SuccDIL
= Succ
->findBranchDebugLoc();
270 dbgs() << "Set branch fs prob: MBB (" << BB
->getNumber() << " -> "
271 << Succ
->getNumber() << "): ";
273 dbgs() << DIL
->getFilename() << ":" << DIL
->getLine() << ":"
276 dbgs() << "-->" << SuccDIL
->getFilename() << ":" << SuccDIL
->getLine()
277 << ":" << SuccDIL
->getColumn();
278 dbgs() << " W=" << BBWeightOrig
<< " " << OldProb
<< " --> " << NewProb
// Per-module setup: create the sample-profile reader for Filename (using the
// FS pass P and the optional remapping file, read through FS), then read the
// profile. ProfileIsValid records whether the read succeeded.
// - Failure to open the profile is reported through the LLVMContext
//   diagnostic handler.
// - For probe-based profiles a PseudoProbeManager is created for the module.
286 bool MIRProfileLoader::doInitialization(Module
&M
) {
287 auto &Ctx
= M
.getContext();
289 auto ReaderOrErr
= sampleprof::SampleProfileReader::create(
290 Filename
, Ctx
, *FS
, P
, RemappingFilename
);
291 if (std::error_code EC
= ReaderOrErr
.getError()) {
292 std::string Msg
= "Could not open profile: " + EC
.message();
293 Ctx
.diagnose(DiagnosticInfoSampleProfile(Filename
, Msg
));
297 Reader
= std::move(ReaderOrErr
.get());
298 Reader
->setModule(&M
);
299 ProfileIsValid
= (Reader
->read() == sampleprof_error::success
);
301 // Load pseudo probe descriptors for probe-based function samples.
302 if (Reader
->profileIsProbeBased()) {
303 ProbeManager
= std::make_unique
<PseudoProbeManager
>(M
);
304 if (!ProbeManager
->moduleIsProbed(M
)) {
// Apply the FS profile to MF: look up MF's samples, then compute and
// propagate block/edge weights with computeAndPropagateWeights().
// The function bails out when:
// - the profile is not an FS profile;
// - MF has no (or empty) samples;
// - a probe-based profile does not match MF;
// - getFunctionLoc(MF) is 0.
// NOTE(review): the final update/return is presumably setBranchProbs(MF)
// followed by returning Changed — confirm against the full definition.
312 bool MIRProfileLoader::runOnFunction(MachineFunction
&MF
) {
313 // Do not load non-FS profiles. A line or probe can get a zero-valued
314 // discriminator at certain pass which could result in accidentally loading
315 // the corresponding base counter in the non-FS profile, while a non-zero
316 // discriminator would end up getting zero samples. This could in turn undo
317 // the sample distribution effort done by previous BFI maintenance and the
318 // probe distribution factor work for pseudo probes.
319 if (!Reader
->profileIsFS())
322 Function
&Func
= MF
.getFunction();
// Reset per-function state left over from the previous function.
323 clearFunctionData(false);
324 Samples
= Reader
->getSamplesFor(Func
);
325 if (!Samples
|| Samples
->empty())
328 if (FunctionSamples::ProfileIsProbeBased
) {
329 if (!ProbeManager
->profileIsValid(MF
.getFunction(), *Samples
))
332 if (getFunctionLoc(MF
) == 0)
336 DenseSet
<GlobalValue::GUID
> InlinedGUIDs
;
337 bool Changed
= computeAndPropagateWeights(MF
, InlinedGUIDs
);
339 // Set the new BPI, BFI.
// Construct the pass for profile FileName, remapping file RemappingFileName
// and FS discriminator pass P. The discriminator bit range [LowBit, HighBit]
// is derived from P. The MIRProfileLoader that does the work is created
// here, reading through FS.
347 MIRProfileLoaderPass::MIRProfileLoaderPass(
348 std::string FileName
, std::string RemappingFileName
, FSDiscriminatorPass P
,
349 IntrusiveRefCntPtr
<vfs::FileSystem
> FS
)
350 : MachineFunctionPass(ID
), ProfileFileName(FileName
), P(P
) {
351 LowBit
= getFSPassBitBegin(P
);
352 HighBit
= getFSPassBitEnd(P
);
// Fall back to the real file system when no VFS was supplied.
354 auto VFS
= FS
? std::move(FS
) : vfs::getRealFileSystem();
355 MIRSampleLoader
= std::make_unique
<MIRProfileLoader
>(
356 FileName
, RemappingFileName
, std::move(VFS
));
357 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
// Per-function driver. It does nothing when the profile failed to load.
// Otherwise it:
// 1. hands the machine analyses to the loader;
// 2. optionally views BFI before loading;
// 3. applies the profile;
// 4. recomputes block frequencies;
// 5. optionally views BFI again.
360 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction
&MF
) {
361 if (!MIRSampleLoader
->isValid())
364 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
365 << MF
.getFunction().getName() << "\n");
366 MBFI
= &getAnalysis
<MachineBlockFrequencyInfo
>();
367 MIRSampleLoader
->setInitVals(
368 &getAnalysis
<MachineDominatorTree
>(),
369 &getAnalysis
<MachinePostDominatorTree
>(), &getAnalysis
<MachineLoopInfo
>(),
370 MBFI
, &getAnalysis
<MachineOptimizationRemarkEmitterPass
>().getORE());
// -fs-viewbfi-before, gated by -view-block-layout-with-bfi and
// -view-bfi-func-name.
373 if (ViewBFIBefore
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
374 (ViewBlockFreqFuncName
.empty() ||
375 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
376 MBFI
->view("MIR_Prof_loader_b." + MF
.getName(), false);
379 bool Changed
= MIRSampleLoader
->runOnFunction(MF
);
// Recompute block frequencies from the (possibly updated) branch
// probabilities so later passes observe the loaded profile.
381 MBFI
->calculate(MF
, *MBFI
->getMBPI(), *&getAnalysis
<MachineLoopInfo
>());
// -fs-viewbfi-after, same gating as above.
383 if (ViewBFIAfter
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
384 (ViewBlockFreqFuncName
.empty() ||
385 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
386 MBFI
->view("MIR_prof_loader_a." + MF
.getName(), false);
// Per-module setup: configure the loader for this pass's FS discriminator
// pass, then load the profile. Returns the loader's doInitialization() result.
392 bool MIRProfileLoaderPass::doInitialization(Module
&M
) {
393 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M
.getName()
396 MIRSampleLoader
->setFSPass(P
);
397 return MIRSampleLoader
->doInitialization(M
);
// Request the machine analyses used in runOnMachineFunction().
// The pass reports that it preserves all analyses: it edits successor
// probabilities but recomputes MachineBlockFrequencyInfo itself before
// returning.
// NOTE(review): MachineLoopInfo is required transitively, presumably because
// the recalculated block-frequency info keeps using it — confirm.
400 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage
&AU
) const {
401 AU
.setPreservesAll();
402 AU
.addRequired
<MachineBlockFrequencyInfo
>();
403 AU
.addRequired
<MachineDominatorTree
>();
404 AU
.addRequired
<MachinePostDominatorTree
>();
405 AU
.addRequiredTransitive
<MachineLoopInfo
>();
406 AU
.addRequired
<MachineOptimizationRemarkEmitterPass
>();
407 MachineFunctionPass::getAnalysisUsage(AU
);