[AMDGPU][AsmParser][NFC] Translate parsed MIMG instructions to MCInsts automatically.
[llvm-project.git] / llvm / lib / CodeGen / MIRSampleProfile.cpp
blob96f8589e682d55b6a5f8ec367652cf19dd79cb45
1 //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides the implementation of the MIRSampleProfile loader, mainly
10 // for flow sensitive SampleFDO.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/CodeGen/MIRSampleProfile.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20 #include "llvm/CodeGen/MachineDominators.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineLoopInfo.h"
23 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
24 #include "llvm/CodeGen/MachinePostDominators.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/PseudoProbe.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/VirtualFileSystem.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
34 #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
35 #include <optional>
37 using namespace llvm;
38 using namespace sampleprof;
39 using namespace llvm::sampleprofutil;
40 using ProfileCount = Function::ProfileCount;
42 #define DEBUG_TYPE "fs-profile-loader"
44 static cl::opt<bool> ShowFSBranchProb(
45 "show-fs-branchprob", cl::Hidden, cl::init(false),
46 cl::desc("Print setting flow sensitive branch probabilities"));
47 static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
48 "fs-profile-debug-prob-diff-threshold", cl::init(10),
49 cl::desc("Only show debug message if the branch probility is greater than "
50 "this value (in percentage)."));
52 static cl::opt<unsigned> FSProfileDebugBWThreshold(
53 "fs-profile-debug-bw-threshold", cl::init(10000),
54 cl::desc("Only show debug message if the source branch weight is greater "
55 " than this value."));
57 static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
58 cl::init(false),
59 cl::desc("View BFI before MIR loader"));
60 static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
61 cl::init(false),
62 cl::desc("View BFI after MIR loader"));
64 extern cl::opt<bool> ImprovedFSDiscriminator;
65 char MIRProfileLoaderPass::ID = 0;
67 INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
68 "Load MIR Sample Profile",
69 /* cfg = */ false, /* is_analysis = */ false)
70 INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
71 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
72 INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
73 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
74 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
75 INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
76 /* cfg = */ false, /* is_analysis = */ false)
78 char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
80 FunctionPass *
81 llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
82 FSDiscriminatorPass P,
83 IntrusiveRefCntPtr<vfs::FileSystem> FS) {
84 return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
87 namespace llvm {
89 // Internal option used to control BFI display only after MBP pass.
90 // Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
91 // -view-block-layout-with-bfi={none | fraction | integer | count}
92 extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;
94 // Command line option to specify the name of the function for CFG dump
95 // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
96 extern cl::opt<std::string> ViewBlockFreqFuncName;
98 std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
99 if (MI.isPseudoProbe()) {
100 PseudoProbe Probe;
101 Probe.Id = MI.getOperand(1).getImm();
102 Probe.Type = MI.getOperand(2).getImm();
103 Probe.Attr = MI.getOperand(3).getImm();
104 Probe.Factor = 1;
105 DILocation *DebugLoc = MI.getDebugLoc();
106 Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
107 return Probe;
110 // Ignore callsite probes since they do not have FS discriminators.
111 return std::nullopt;
114 namespace afdo_detail {
115 template <> struct IRTraits<MachineBasicBlock> {
116 using InstructionT = MachineInstr;
117 using BasicBlockT = MachineBasicBlock;
118 using FunctionT = MachineFunction;
119 using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
120 using LoopT = MachineLoop;
121 using LoopInfoPtrT = MachineLoopInfo *;
122 using DominatorTreePtrT = MachineDominatorTree *;
123 using PostDominatorTreePtrT = MachinePostDominatorTree *;
124 using PostDominatorTreeT = MachinePostDominatorTree;
125 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
126 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
127 using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
128 using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
129 static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
130 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
131 return GraphTraits<const MachineFunction *>::getEntryNode(F);
133 static PredRangeT getPredecessors(MachineBasicBlock *BB) {
134 return BB->predecessors();
136 static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
137 return BB->successors();
140 } // namespace afdo_detail
142 class MIRProfileLoader final
143 : public SampleProfileLoaderBaseImpl<MachineFunction> {
144 public:
145 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
146 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
147 MachineOptimizationRemarkEmitter *MORE) {
148 DT = MDT;
149 PDT = MPDT;
150 LI = MLI;
151 BFI = MBFI;
152 ORE = MORE;
154 void setFSPass(FSDiscriminatorPass Pass) {
155 P = Pass;
156 LowBit = getFSPassBitBegin(P);
157 HighBit = getFSPassBitEnd(P);
158 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
161 MIRProfileLoader(StringRef Name, StringRef RemapName,
162 IntrusiveRefCntPtr<vfs::FileSystem> FS)
163 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
164 std::move(FS)) {}
166 void setBranchProbs(MachineFunction &F);
167 bool runOnFunction(MachineFunction &F);
168 bool doInitialization(Module &M);
169 bool isValid() const { return ProfileIsValid; }
171 protected:
172 friend class SampleCoverageTracker;
174 /// Hold the information of the basic block frequency.
175 MachineBlockFrequencyInfo *BFI;
177 /// PassNum is the sequence number this pass is called, start from 1.
178 FSDiscriminatorPass P;
180 // LowBit in the FS discriminator used by this instance. Note the number is
181 // 0-based. Base discrimnator use bit 0 to bit 11.
182 unsigned LowBit;
183 // HighwBit in the FS discriminator used by this instance. Note the number
184 // is 0-based.
185 unsigned HighBit;
187 bool ProfileIsValid = true;
188 ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override {
189 if (FunctionSamples::ProfileIsProbeBased)
190 return getProbeWeight(MI);
191 if (ImprovedFSDiscriminator && MI.isMetaInstruction())
192 return std::error_code();
193 return getInstWeightImpl(MI);
197 template <>
198 void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo(
199 MachineFunction &F) {}
201 void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
202 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
203 for (auto &BI : F) {
204 MachineBasicBlock *BB = &BI;
205 if (BB->succ_size() < 2)
206 continue;
207 const MachineBasicBlock *EC = EquivalenceClass[BB];
208 uint64_t BBWeight = BlockWeights[EC];
209 uint64_t SumEdgeWeight = 0;
210 for (MachineBasicBlock *Succ : BB->successors()) {
211 Edge E = std::make_pair(BB, Succ);
212 SumEdgeWeight += EdgeWeights[E];
215 if (BBWeight != SumEdgeWeight) {
216 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
217 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
218 << "\n");
219 BBWeight = SumEdgeWeight;
221 if (BBWeight == 0) {
222 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
223 continue;
226 #ifndef NDEBUG
227 uint64_t BBWeightOrig = BBWeight;
228 #endif
229 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
230 uint32_t Factor = 1;
231 if (BBWeight > MaxWeight) {
232 Factor = BBWeight / MaxWeight + 1;
233 BBWeight /= Factor;
234 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
237 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
238 SE = BB->succ_end();
239 SI != SE; ++SI) {
240 MachineBasicBlock *Succ = *SI;
241 Edge E = std::make_pair(BB, Succ);
242 uint64_t EdgeWeight = EdgeWeights[E];
243 EdgeWeight /= Factor;
245 assert(BBWeight >= EdgeWeight &&
246 "BBweight is larger than EdgeWeight -- should not happen.\n");
248 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
249 BranchProbability NewProb(EdgeWeight, BBWeight);
250 if (OldProb == NewProb)
251 continue;
252 BB->setSuccProbability(SI, NewProb);
253 #ifndef NDEBUG
254 if (!ShowFSBranchProb)
255 continue;
256 bool Show = false;
257 BranchProbability Diff;
258 if (OldProb > NewProb)
259 Diff = OldProb - NewProb;
260 else
261 Diff = NewProb - OldProb;
262 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
263 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
265 auto DIL = BB->findBranchDebugLoc();
266 auto SuccDIL = Succ->findBranchDebugLoc();
267 if (Show) {
268 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
269 << Succ->getNumber() << "): ";
270 if (DIL)
271 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
272 << DIL->getColumn();
273 if (SuccDIL)
274 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
275 << ":" << SuccDIL->getColumn();
276 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb
277 << "\n";
279 #endif
284 bool MIRProfileLoader::doInitialization(Module &M) {
285 auto &Ctx = M.getContext();
287 auto ReaderOrErr = sampleprof::SampleProfileReader::create(
288 Filename, Ctx, *FS, P, RemappingFilename);
289 if (std::error_code EC = ReaderOrErr.getError()) {
290 std::string Msg = "Could not open profile: " + EC.message();
291 Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
292 return false;
295 Reader = std::move(ReaderOrErr.get());
296 Reader->setModule(&M);
297 ProfileIsValid = (Reader->read() == sampleprof_error::success);
299 // Load pseudo probe descriptors for probe-based function samples.
300 if (Reader->profileIsProbeBased()) {
301 ProbeManager = std::make_unique<PseudoProbeManager>(M);
302 if (!ProbeManager->moduleIsProbed(M)) {
303 return false;
307 return true;
310 bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
311 // Do not load non-FS profiles. A line or probe can get a zero-valued
312 // discriminator at certain pass which could result in accidentally loading
313 // the corresponding base counter in the non-FS profile, while a non-zero
314 // discriminator would end up getting zero samples. This could in turn undo
315 // the sample distribution effort done by previous BFI maintenance and the
316 // probe distribution factor work for pseudo probes.
317 if (!Reader->profileIsFS())
318 return false;
320 Function &Func = MF.getFunction();
321 clearFunctionData(false);
322 Samples = Reader->getSamplesFor(Func);
323 if (!Samples || Samples->empty())
324 return false;
326 if (FunctionSamples::ProfileIsProbeBased) {
327 if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples))
328 return false;
329 } else {
330 if (getFunctionLoc(MF) == 0)
331 return false;
334 DenseSet<GlobalValue::GUID> InlinedGUIDs;
335 bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
337 // Set the new BPI, BFI.
338 setBranchProbs(MF);
340 return Changed;
343 } // namespace llvm
345 MIRProfileLoaderPass::MIRProfileLoaderPass(
346 std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
347 IntrusiveRefCntPtr<vfs::FileSystem> FS)
348 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
349 LowBit = getFSPassBitBegin(P);
350 HighBit = getFSPassBitEnd(P);
352 auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
353 MIRSampleLoader = std::make_unique<MIRProfileLoader>(
354 FileName, RemappingFileName, std::move(VFS));
355 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
358 bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
359 if (!MIRSampleLoader->isValid())
360 return false;
362 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
363 << MF.getFunction().getName() << "\n");
364 MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
365 MIRSampleLoader->setInitVals(
366 &getAnalysis<MachineDominatorTree>(),
367 &getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
368 MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
370 MF.RenumberBlocks();
371 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
372 (ViewBlockFreqFuncName.empty() ||
373 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
374 MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
377 bool Changed = MIRSampleLoader->runOnFunction(MF);
378 if (Changed)
379 MBFI->calculate(MF, *MBFI->getMBPI(), *&getAnalysis<MachineLoopInfo>());
381 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
382 (ViewBlockFreqFuncName.empty() ||
383 MF.getFunction().getName().equals(ViewBlockFreqFuncName))) {
384 MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
387 return Changed;
390 bool MIRProfileLoaderPass::doInitialization(Module &M) {
391 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
392 << "\n");
394 MIRSampleLoader->setFSPass(P);
395 return MIRSampleLoader->doInitialization(M);
398 void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
399 AU.setPreservesAll();
400 AU.addRequired<MachineBlockFrequencyInfo>();
401 AU.addRequired<MachineDominatorTree>();
402 AU.addRequired<MachinePostDominatorTree>();
403 AU.addRequiredTransitive<MachineLoopInfo>();
404 AU.addRequired<MachineOptimizationRemarkEmitterPass>();
405 MachineFunctionPass::getAnalysisUsage(AU);