1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineLoopInfo.h"
23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
24 #include "llvm/CodeGen/MachinePostDominators.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/PseudoProbe.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/VirtualFileSystem.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
38 using namespace sampleprof
;
39 using namespace llvm::sampleprofutil
;
40 using ProfileCount
= Function::ProfileCount
;
42 #define DEBUG_TYPE "fs-profile-loader"
// Hidden boolean option "show-fs-branchprob", default false.
// MIRProfileLoader::setBranchProbs tests this flag before the code that
// prints the old and new probability of an edge.
44 static cl::opt
<bool> ShowFSBranchProb(
45 "show-fs-branchprob", cl::Hidden
, cl::init(false),
46 cl::desc("Print setting flow sensitive branch probabilities"));
// Option "fs-profile-debug-prob-diff-threshold", default 10.
// MIRProfileLoader::setBranchProbs interprets the value as a percentage
// (BranchProbability(value, 100)) and compares the absolute difference
// between the old and the new edge probability against it with `>=`; the help
// text below describes exactly that comparison.
47 static cl::opt
<unsigned> FSProfileDebugProbDiffThreshold(
48 "fs-profile-debug-prob-diff-threshold", cl::init(10),
49 cl::desc("Only show debug message if the branch probability difference is "
50 "at least this value (in percentage)."));
// Option "fs-profile-debug-bw-threshold", default 10000.
// MIRProfileLoader::setBranchProbs compares the (unscaled) weight of the
// source block against this value when deciding whether to print the debug
// message for an edge. The two literals below are concatenated by the
// compiler, so only the first one carries the separating space.
52 static cl::opt
<unsigned> FSProfileDebugBWThreshold(
53 "fs-profile-debug-bw-threshold", cl::init(10000),
54 cl::desc("Only show debug message if the source branch weight is greater "
55 "than this value."));
// Hidden boolean option "fs-viewbfi-before".
// MIRProfileLoaderPass::runOnMachineFunction checks it (together with
// ViewBlockLayoutWithBFI and ViewBlockFreqFuncName) before calling
// MBFI->view() ahead of running the loader on the function.
57 static cl::opt
<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden
,
59 cl::desc("View BFI before MIR loader"));
// Hidden boolean option "fs-viewbfi-after".
// MIRProfileLoaderPass::runOnMachineFunction checks it (together with
// ViewBlockLayoutWithBFI and ViewBlockFreqFuncName) before calling
// MBFI->view() after the loader has run on the function.
60 static cl::opt
<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden
,
62 cl::desc("View BFI after MIR loader"));
// Boolean option declared here and defined in another translation unit.
// MIRProfileLoader::getInstWeight reads it: when it is set and the instruction
// is a meta instruction, no weight is computed for that instruction.
64 extern cl::opt
<bool> ImprovedFSDiscriminator
;
// Static pass identifier of MIRProfileLoaderPass, initialized to 0.
65 char MIRProfileLoaderPass::ID
= 0;
// Pass registration. The pass is registered under DEBUG_TYPE
// ("fs-profile-loader") with the description "Load MIR Sample Profile", is
// flagged as neither CFG-only nor an analysis, and lists the five passes
// below as initialization dependencies. The same five passes are requested in
// MIRProfileLoaderPass::getAnalysisUsage.
67 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass
, DEBUG_TYPE
,
68 "Load MIR Sample Profile",
69 /* cfg = */ false, /* is_analysis = */ false)
70 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo
)
71 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree
)
72 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree
)
73 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo
)
74 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass
)
75 INITIALIZE_PASS_END(MIRProfileLoaderPass
, DEBUG_TYPE
, "Load MIR Sample Profile",
76 /* cfg = */ false, /* is_analysis = */ false)
// Reference bound to the pass identifier above, exported in namespace llvm.
78 char &llvm::MIRProfileLoaderPassID
= MIRProfileLoaderPass::ID
;
// Factory for the MIR sample-profile loader pass.
//
// File          - profile file name, forwarded to the pass constructor.
// RemappingFile - remapping file name, forwarded to the pass constructor.
// P             - FS discriminator pass value, forwarded unchanged.
// FS            - virtual file system handle; moved into the new pass.
//
// Returns a pass object allocated with `new`.
// NOTE(review): the return type and the closing brace of this definition are
// not part of the visible text; presumably the caller takes ownership of the
// returned pointer -- confirm against the declaration.
81 llvm::createMIRProfileLoaderPass(std::string File
, std::string RemappingFile
,
82 FSDiscriminatorPass P
,
83 IntrusiveRefCntPtr
<vfs::FileSystem
> FS
) {
84 return new MIRProfileLoaderPass(File
, RemappingFile
, P
, std::move(FS
));
89 // Internal option used to control BFI display only after MBP pass.
90 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
91 // -view-block-layout-with-bfi={none | fraction | integer | count}
92 extern cl::opt
<GVDAGType
> ViewBlockLayoutWithBFI
;
94 // Command line option to specify the name of the function for CFG dump
95 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
96 extern cl::opt
<std::string
> ViewBlockFreqFuncName
;
// Builds a PseudoProbe description from a machine instruction.
//
// For an instruction for which MI.isPseudoProbe() is true, the probe fields
// are filled from the instruction's immediate operands (operand 1 -> Id,
// operand 2 -> Type, operand 3 -> Attr) and from its debug location.
// NOTE(review): the declaration of the local `Probe` and all return
// statements are not visible here; presumably the filled probe is returned
// for pseudo-probe instructions and std::nullopt otherwise -- confirm.
98 std::optional
<PseudoProbe
> extractProbe(const MachineInstr
&MI
) {
99 if (MI
.isPseudoProbe()) {
101 Probe
.Id
= MI
.getOperand(1).getImm();
102 Probe
.Type
= MI
.getOperand(2).getImm();
103 Probe
.Attr
= MI
.getOperand(3).getImm();
// The discriminator is taken from the debug location when one is attached;
// without a debug location it is set to 0.
105 DILocation
*DebugLoc
= MI
.getDebugLoc();
106 Probe
.Discriminator
= DebugLoc
? DebugLoc
->getDiscriminator() : 0;
110 // Ignore callsite probes since they do not have FS discriminators.
114 namespace afdo_detail
{
// Specialization of IRTraits for MachineBasicBlock. It names the Machine-IR
// types (instruction, block, function, block-frequency info, loop info,
// dominator and post-dominator trees, remark emitter/analysis) and the
// predecessor/successor range types, and supplies static accessors for the
// IR function, the entry block, and a block's predecessors and successors.
115 template <> struct IRTraits
<MachineBasicBlock
> {
116 using InstructionT
= MachineInstr
;
117 using BasicBlockT
= MachineBasicBlock
;
118 using FunctionT
= MachineFunction
;
119 using BlockFrequencyInfoT
= MachineBlockFrequencyInfo
;
120 using LoopT
= MachineLoop
;
121 using LoopInfoPtrT
= MachineLoopInfo
*;
122 using DominatorTreePtrT
= MachineDominatorTree
*;
123 using PostDominatorTreePtrT
= MachinePostDominatorTree
*;
124 using PostDominatorTreeT
= MachinePostDominatorTree
;
125 using OptRemarkEmitterT
= MachineOptimizationRemarkEmitter
;
126 using OptRemarkAnalysisT
= MachineOptimizationRemarkAnalysis
;
// Both range types are iterator ranges over a vector of block pointers.
127 using PredRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
128 using SuccRangeT
= iterator_range
<std::vector
<MachineBasicBlock
*>::iterator
>;
// Returns the IR Function that the machine function was created from.
129 static Function
&getFunction(MachineFunction
&F
) { return F
.getFunction(); }
// Returns the entry block as reported by GraphTraits for the machine function.
130 static const MachineBasicBlock
*getEntryBB(const MachineFunction
*F
) {
131 return GraphTraits
<const MachineFunction
*>::getEntryNode(F
);
// Returns BB->predecessors().
133 static PredRangeT
getPredecessors(MachineBasicBlock
*BB
) {
134 return BB
->predecessors();
// Returns BB->successors().
136 static SuccRangeT
getSuccessors(MachineBasicBlock
*BB
) {
137 return BB
->successors();
140 } // namespace afdo_detail
// Sample-profile loader operating on machine functions. It derives from
// SampleProfileLoaderBaseImpl<MachineFunction> and adds the members declared
// below: initialization hooks, branch-probability update, the per-function
// driver, and the instruction-weight override.
// NOTE(review): several lines of this class (access specifiers, some member
// declarations and function bodies) are not visible in this text; the
// comments below describe only what is shown.
142 class MIRProfileLoader final
143 : public SampleProfileLoaderBaseImpl
<MachineFunction
> {
// Receives the dominator tree, post-dominator tree, loop info, block
// frequency info and remark emitter from the pass. The body is not visible.
145 void setInitVals(MachineDominatorTree
*MDT
, MachinePostDominatorTree
*MPDT
,
146 MachineLoopInfo
*MLI
, MachineBlockFrequencyInfo
*MBFI
,
147 MachineOptimizationRemarkEmitter
*MORE
) {
// Derives LowBit/HighBit from the member P via getFSPassBitBegin and
// getFSPassBitEnd and asserts LowBit < HighBit.
// NOTE(review): the statement that uses the `Pass` parameter is not visible;
// presumably it is stored into P before the lines below -- confirm.
154 void setFSPass(FSDiscriminatorPass Pass
) {
156 LowBit
= getFSPassBitBegin(P
);
157 HighBit
= getFSPassBitEnd(P
);
158 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
// Constructor: forwards the profile name and remapping name, converted to
// std::string, to the base class. The rest of the initializer list and the
// body are not visible.
161 MIRProfileLoader(StringRef Name
, StringRef RemapName
,
162 IntrusiveRefCntPtr
<vfs::FileSystem
> FS
)
163 : SampleProfileLoaderBaseImpl(std::string(Name
), std::string(RemapName
),
// Defined out of line further down in this file.
166 void setBranchProbs(MachineFunction
&F
);
167 bool runOnFunction(MachineFunction
&F
);
168 bool doInitialization(Module
&M
);
// Reports the ProfileIsValid flag, which doInitialization sets from the
// result of reading the profile.
169 bool isValid() const { return ProfileIsValid
; }
172 friend class SampleCoverageTracker
;
174 /// Hold the information of the basic block frequency.
175 MachineBlockFrequencyInfo
*BFI
;
177 /// PassNum is the sequence number this pass is called, start from 1.
178 FSDiscriminatorPass P
;
180 // LowBit in the FS discriminator used by this instance. Note the number is
181 // 0-based. Base discriminator use bit 0 to bit 11.
183 // HighBit in the FS discriminator used by this instance. Note the number
// Defaults to true; overwritten in doInitialization.
187 bool ProfileIsValid
= true;
// Weight lookup for one machine instruction:
//  - probe-based profiles delegate to getProbeWeight;
//  - otherwise, when ImprovedFSDiscriminator is set and MI is a meta
//    instruction, a default-constructed std::error_code is returned instead
//    of a weight;
//  - all remaining instructions go through getInstWeightImpl.
188 ErrorOr
<uint64_t> getInstWeight(const MachineInstr
&MI
) override
{
189 if (FunctionSamples::ProfileIsProbeBased
)
190 return getProbeWeight(MI
);
191 if (ImprovedFSDiscriminator
&& MI
.isMetaInstruction())
192 return std::error_code();
193 return getInstWeightImpl(MI
);
// MachineFunction version of computeDominanceAndLoopInfo with an empty body:
// it performs no computation. Presumably this is because the analyses are
// handed to the loader through MIRProfileLoader::setInitVals by the pass --
// TODO confirm.
198 void SampleProfileLoaderBaseImpl
<MachineFunction
>::computeDominanceAndLoopInfo(
199 MachineFunction
&F
) {}
// Turns the propagated block and edge weights into successor probabilities on
// the machine CFG and, when requested, prints the edges whose probability
// changed noticeably.
// NOTE(review): the loop header over the function's blocks, several guarded
// statements, and the closing braces are not visible in this text; the
// comments below describe only the statements that are shown.
201 void MIRProfileLoader::setBranchProbs(MachineFunction
&F
) {
202 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
204 MachineBasicBlock
*BB
= &BI
;
// Blocks with fewer than two successors are singled out here; the statement
// controlled by this test is not shown.
205 if (BB
->succ_size() < 2)
// The block weight is read from BlockWeights using the EquivalenceClass entry
// recorded for BB.
207 const MachineBasicBlock
*EC
= EquivalenceClass
[BB
];
208 uint64_t BBWeight
= BlockWeights
[EC
];
// Sum the weights of all outgoing edges of BB.
209 uint64_t SumEdgeWeight
= 0;
210 for (MachineBasicBlock
*Succ
: BB
->successors()) {
211 Edge E
= std::make_pair(BB
, Succ
);
212 SumEdgeWeight
+= EdgeWeights
[E
];
// When the block weight disagrees with the edge-weight sum, the sum wins and
// is used as the denominator from here on.
215 if (BBWeight
!= SumEdgeWeight
) {
216 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
217 << BBWeight
<< " SumEdgeWeight= " << SumEdgeWeight
219 BBWeight
= SumEdgeWeight
;
222 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
// Keep the unscaled weight for the debug-output threshold test below.
227 uint64_t BBWeightOrig
= BBWeight
;
// Weights above the uint32_t maximum get a scaling divisor; each edge weight
// is divided by Factor in the successor loop below.
229 uint32_t MaxWeight
= std::numeric_limits
<uint32_t>::max();
231 if (BBWeight
> MaxWeight
) {
232 Factor
= BBWeight
/ MaxWeight
+ 1;
234 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor
<< "\n");
// Per-successor update, iterating with succ_iterator because both
// getEdgeProbability and setSuccProbability are called with the iterator.
237 for (MachineBasicBlock::succ_iterator SI
= BB
->succ_begin(),
240 MachineBasicBlock
*Succ
= *SI
;
241 Edge E
= std::make_pair(BB
, Succ
);
242 uint64_t EdgeWeight
= EdgeWeights
[E
];
243 EdgeWeight
/= Factor
;
// NOTE(review): the condition requires BBWeight >= EdgeWeight, so the failure
// case is BBWeight being smaller; the message wording reads the other way.
245 assert(BBWeight
>= EdgeWeight
&&
246 "BBweight is larger than EdgeWeight -- should not happen.\n");
// New probability is EdgeWeight / BBWeight; it is compared with the
// probability currently recorded for the edge before being installed.
248 BranchProbability OldProb
= BFI
->getMBPI()->getEdgeProbability(BB
, SI
);
249 BranchProbability
NewProb(EdgeWeight
, BBWeight
);
250 if (OldProb
== NewProb
)
252 BB
->setSuccProbability(SI
, NewProb
);
// Everything below concerns the optional debug printout controlled by
// ShowFSBranchProb.
254 if (!ShowFSBranchProb
)
// Absolute difference between the old and the new probability.
257 BranchProbability Diff
;
258 if (OldProb
> NewProb
)
259 Diff
= OldProb
- NewProb
;
261 Diff
= NewProb
- OldProb
;
// Print only when the difference reaches FSProfileDebugProbDiffThreshold
// percent and the unscaled block weight reaches FSProfileDebugBWThreshold.
262 Show
= (Diff
>= BranchProbability(FSProfileDebugProbDiffThreshold
, 100));
263 Show
&= (BBWeightOrig
>= FSProfileDebugBWThreshold
);
// Branch debug locations of the source and destination block, used to print
// file:line:column for both ends of the edge.
265 auto DIL
= BB
->findBranchDebugLoc();
266 auto SuccDIL
= Succ
->findBranchDebugLoc();
268 dbgs() << "Set branch fs prob: MBB (" << BB
->getNumber() << " -> "
269 << Succ
->getNumber() << "): ";
271 dbgs() << DIL
->getFilename() << ":" << DIL
->getLine() << ":"
274 dbgs() << "-->" << SuccDIL
->getFilename() << ":" << SuccDIL
->getLine()
275 << ":" << SuccDIL
->getColumn();
276 dbgs() << " W=" << BBWeightOrig
<< " " << OldProb
<< " --> " << NewProb
// Module-level setup of the loader: creates the sample-profile reader, reads
// the profile, and prepares pseudo-probe support when the profile is
// probe-based.
// NOTE(review): the return statements and the body of the moduleIsProbed
// check are not visible in this text.
284 bool MIRProfileLoader::doInitialization(Module
&M
) {
285 auto &Ctx
= M
.getContext();
// The reader is created from the profile file name, the file system held by
// the loader, the FS discriminator pass P and the remapping file name.
287 auto ReaderOrErr
= sampleprof::SampleProfileReader::create(
288 Filename
, Ctx
, *FS
, P
, RemappingFilename
);
// A creation failure is reported through the LLVMContext as a sample-profile
// diagnostic carrying the error message.
289 if (std::error_code EC
= ReaderOrErr
.getError()) {
290 std::string Msg
= "Could not open profile: " + EC
.message();
291 Ctx
.diagnose(DiagnosticInfoSampleProfile(Filename
, Msg
));
295 Reader
= std::move(ReaderOrErr
.get());
296 Reader
->setModule(&M
);
// ProfileIsValid records whether reading the profile succeeded; it is what
// isValid() reports.
297 ProfileIsValid
= (Reader
->read() == sampleprof_error::success
);
299 // Load pseudo probe descriptors for probe-based function samples.
300 if (Reader
->profileIsProbeBased()) {
301 ProbeManager
= std::make_unique
<PseudoProbeManager
>(M
);
302 if (!ProbeManager
->moduleIsProbed(M
)) {
// Per-function driver of the loader: fetches the samples for the function,
// runs a series of checks, and then computes and propagates weights.
// NOTE(review): the statements controlled by the checks below, the code after
// the "Set the new BPI, BFI" comment and the return statements are not
// visible in this text.
310 bool MIRProfileLoader::runOnFunction(MachineFunction
&MF
) {
311 // Do not load non-FS profiles. A line or probe can get a zero-valued
312 // discriminator at certain pass which could result in accidentally loading
313 // the corresponding base counter in the non-FS profile, while a non-zero
314 // discriminator would end up getting zero samples. This could in turn undo
315 // the sample distribution effort done by previous BFI maintenance and the
316 // probe distribution factor work for pseudo probes.
317 if (!Reader
->profileIsFS())
// Reset per-function state, then look up the samples recorded for the IR
// function behind MF.
320 Function
&Func
= MF
.getFunction();
321 clearFunctionData(false);
322 Samples
= Reader
->getSamplesFor(Func
);
// Check: samples are missing or empty.
323 if (!Samples
|| Samples
->empty())
// Check (probe-based profiles): the probe manager does not consider the
// profile valid for this function.
326 if (FunctionSamples::ProfileIsProbeBased
) {
327 if (!ProbeManager
->profileIsValid(MF
.getFunction(), *Samples
))
// Check: getFunctionLoc(MF) returned 0.
330 if (getFunctionLoc(MF
) == 0)
// InlinedGUIDs is passed in empty to computeAndPropagateWeights.
334 DenseSet
<GlobalValue::GUID
> InlinedGUIDs
;
335 bool Changed
= computeAndPropagateWeights(MF
, InlinedGUIDs
);
337 // Set the new BPI, BFI.
// Constructs the pass.
//
// FileName          - profile file name; stored in ProfileFileName and passed
//                     to the loader.
// RemappingFileName - remapping file name, passed to the loader.
// P                 - FS discriminator pass; stored and used to derive
//                     LowBit/HighBit.
// FS                - file system for the loader; when it is null the real
//                     file system is used instead.
345 MIRProfileLoaderPass::MIRProfileLoaderPass(
346 std::string FileName
, std::string RemappingFileName
, FSDiscriminatorPass P
,
347 IntrusiveRefCntPtr
<vfs::FileSystem
> FS
)
348 : MachineFunctionPass(ID
), ProfileFileName(FileName
), P(P
) {
349 LowBit
= getFSPassBitBegin(P
);
350 HighBit
= getFSPassBitEnd(P
);
// Fall back to vfs::getRealFileSystem() when no file system was supplied.
352 auto VFS
= FS
? std::move(FS
) : vfs::getRealFileSystem();
353 MIRSampleLoader
= std::make_unique
<MIRProfileLoader
>(
354 FileName
, RemappingFileName
, std::move(VFS
));
355 assert(LowBit
< HighBit
&& "HighBit needs to be greater than Lowbit");
// Runs the loader on one machine function: hands the required analyses to the
// loader, optionally displays the block frequencies before and after, runs
// the loader, and recalculates the block frequency info.
// NOTE(review): the statement controlled by the isValid() test, any condition
// guarding the calculate() call, and the return statement are not visible in
// this text.
358 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction
&MF
) {
// Tests whether the loader reported an invalid profile during
// doInitialization.
359 if (!MIRSampleLoader
->isValid())
362 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
363 << MF
.getFunction().getName() << "\n");
// Fetch the analyses and pass them to the loader.
364 MBFI
= &getAnalysis
<MachineBlockFrequencyInfo
>();
365 MIRSampleLoader
->setInitVals(
366 &getAnalysis
<MachineDominatorTree
>(),
367 &getAnalysis
<MachinePostDominatorTree
>(), &getAnalysis
<MachineLoopInfo
>(),
368 MBFI
, &getAnalysis
<MachineOptimizationRemarkEmitterPass
>().getORE());
// Optional view before loading: requires ViewBFIBefore, a
// ViewBlockLayoutWithBFI setting other than GVDT_None, and either an empty
// ViewBlockFreqFuncName or one equal to this function's name.
371 if (ViewBFIBefore
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
372 (ViewBlockFreqFuncName
.empty() ||
373 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
374 MBFI
->view("MIR_Prof_loader_b." + MF
.getName(), false);
377 bool Changed
= MIRSampleLoader
->runOnFunction(MF
);
// Recompute block frequencies from the branch probability info held by MBFI
// and the loop info.
379 MBFI
->calculate(MF
, *MBFI
->getMBPI(), *&getAnalysis
<MachineLoopInfo
>());
// Optional view after loading; same conditions as above but keyed on
// ViewBFIAfter.
381 if (ViewBFIAfter
&& ViewBlockLayoutWithBFI
!= GVDT_None
&&
382 (ViewBlockFreqFuncName
.empty() ||
383 MF
.getFunction().getName().equals(ViewBlockFreqFuncName
))) {
384 MBFI
->view("MIR_prof_loader_a." + MF
.getName(), false);
// Module-level initialization of the pass: logs the module name in debug
// builds, tells the loader which FS discriminator pass it serves, and returns
// the result of the loader's own doInitialization.
// NOTE(review): the end of the LLVM_DEBUG statement is not visible in this
// text.
390 bool MIRProfileLoaderPass::doInitialization(Module
&M
) {
391 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M
.getName()
394 MIRSampleLoader
->setFSPass(P
);
395 return MIRSampleLoader
->doInitialization(M
);
// Declares the analysis requirements of the pass: it marks all analyses as
// preserved, requires block frequency info, the dominator and post-dominator
// trees and the optimization remark emitter, requires loop info transitively,
// and then lets MachineFunctionPass add its own requirements.
398 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage
&AU
) const {
399 AU
.setPreservesAll();
400 AU
.addRequired
<MachineBlockFrequencyInfo
>();
401 AU
.addRequired
<MachineDominatorTree
>();
402 AU
.addRequired
<MachinePostDominatorTree
>();
// Loop info is the only analysis requested with addRequiredTransitive.
403 AU
.addRequiredTransitive
<MachineLoopInfo
>();
404 AU
.addRequired
<MachineOptimizationRemarkEmitterPass
>();
405 MachineFunctionPass::getAnalysisUsage(AU
);