[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / clang / lib / CodeGen / CodeGenPGO.cpp
blob7d6c69f22d0e56282cf7251567e83e1da332f7f1
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
11 //===----------------------------------------------------------------------===//
#include "CodeGenPGO.h"
#include "CodeGenFunction.h"
#include "CoverageMappingGen.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"
#include <memory>
#include <optional>
// Hidden cl::opt gating value-profiling instrumentation/metadata emission
// (checked in CodeGenPGO::valueProfile and setValueProfilingFlag below).
// Off by default.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling",
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));
31 using namespace clang;
32 using namespace CodeGen;
/// Compute and cache the PGO name for a function with the given raw name and
/// linkage, using the profile reader's version when one is available so the
/// name scheme matches the profile being consumed.
void CodeGenPGO::setFuncName(StringRef Name,
                             llvm::GlobalValue::LinkageTypes Linkage) {
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  FuncName = llvm::getPGOFuncName(
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);

  // If we're generating a profile, create a variable for the name.
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
}
/// Convenience overload: derive the PGO name from an IR function and attach
/// it to the function as metadata for later consumers.
void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName meta data.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}
/// The version of the PGO hash algorithm.
///
/// The enumerator order is part of the on-disk profile contract: versions are
/// compared with `<=`/`>=` (see getPGOHashVersion and getHashType), so new
/// versions must only be appended.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};
namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes. We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  // Rolling accumulator of packed HashType values (6 bits each); flushed
  // into the MD5 stream once NumTypesPerWord values have been packed.
  uint64_t Working;
  // Number of HashType values combined so far.
  unsigned Count;
  // Which hash algorithm version to emulate (affects finalize()).
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable. All new members must be added at the end,
  /// and no members should be removed. Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last. It's for the static assert that follows.
    LastHashType
  };
  // Every HashType must fit in NumBitsPerType bits.
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion) {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};

// Out-of-line definitions for the static const members (required pre-C++17
// whenever the members are odr-used).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
144 /// Get the PGO hash version used in the given indexed profile.
145 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
146 CodeGenModule &CGM) {
147 if (PGOReader->getVersion() <= 4)
148 return PGO_HASH_V1;
149 if (PGOReader->getVersion() <= 5)
150 return PGO_HASH_V2;
151 return PGO_HASH_V3;
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// Traversal order determines both the counter numbering and the function
/// hash, so any change to which nodes are visited (or in what order) changes
/// profile compatibility.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  // Function-like declarations get the entry counter (counter 0) on their
  // body.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counter assignment always follows the V1 node set so that counter
    // numbering stays stable across hash versions.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For V2+ hashes, re-derive the hash type with the active version, which
    // recognizes additional node kinds (gotos, returns, comparisons, ...).
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      // Comparison operators only contribute to the hash from V2 onward.
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Node kinds recognized only by V2+ hashes.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
///
/// CurrentCount tracks the execution count at the current point of the
/// traversal; each visitor adjusts it using the region counters loaded from
/// the profile, and CountMap records it wherever codegen will later query it.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount = 0;
    uint64_t ContinueCount = 0;
    BreakContinue() = default;
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  // Record CurrentCount on \p S if a preceding visitor requested it (e.g.
  // the statement after a return/break/continue/loop exit).
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    // Control does not continue past a return; anything that follows is
    // only reachable via another edge.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Exit count: condition-false edges plus breaks.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // No edge falls directly into the body; each case starts from its own
    // counter (see VisitSwitchCase).
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);

    // `if consteval` / `if !consteval`: exactly one branch is compiled in, so
    // just visit it without any counter bookkeeping.
    if (S->isConsteval()) {
      const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
      if (Stm)
        Visit(Stm);
      return;
    }

    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
748 } // end anonymous namespace
/// Fold one HashType into the running hash: values are packed six bits at a
/// time into `Working`, which is streamed into MD5 whenever a full word has
/// accumulated.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    // Always hash the little-endian byte order so the result is
    // host-endianness independent.
    uint64_t Swapped =
        endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
    MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
/// Produce the final 64-bit function hash. Short functions (everything fit in
/// one packed word) bypass MD5 entirely; otherwise any leftover packed bits
/// are folded in and the low 64 bits of the MD5 digest are returned.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped =
          endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
      MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}
/// Set up PGO state for one function: decide whether it should be
/// instrumented/profiled at all, assign region counters, emit coverage
/// mapping if requested, and load existing counts when reading a profile.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  // Nothing to do unless we're instrumenting or consuming a profile.
  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  // Respect explicit opt-outs attached to the IR function.
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;
  if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
    return;

  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}
844 void CodeGenPGO::mapRegionCounters(const Decl *D) {
845 // Use the latest hash version when inserting instrumentation, but use the
846 // version in the indexed profile if we're reading PGO data.
847 PGOHashVersion HashVersion = PGO_HASH_LATEST;
848 uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
849 if (auto *PGOReader = CGM.getPGOReader()) {
850 HashVersion = getPGOHashVersion(PGOReader, CGM);
851 ProfileVersion = PGOReader->getVersion();
854 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
855 MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
856 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
857 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
858 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
859 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
860 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
861 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
862 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
863 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
864 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
865 NumRegionCounters = Walker.NextCounter;
866 FunctionHash = Walker.Hash.finalize();
/// Return true if no coverage mapping should be emitted for \p D: bodiless
/// decls, wrong-side CUDA functions, and anything whose body begins in a
/// system header.
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Skip host-only functions in the CUDA device compilation and device-only
  // functions in the host compilation. Just roughly filter them out based on
  // the function attributes. If there are effectively host-only or device-only
  // ones, their coverage mapping may still be generated.
  if (CGM.getLangOpts().CUDA &&
      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
        !D->hasAttr<CUDAGlobalAttr>()) ||
       (!CGM.getLangOpts().CUDAIsDevice &&
        (D->hasAttr<CUDAGlobalAttr>() ||
         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}
891 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
892 if (skipRegionMappingForDecl(D))
893 return;
895 std::string CoverageMapping;
896 llvm::raw_string_ostream OS(CoverageMapping);
897 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
898 CGM.getContext().getSourceManager(),
899 CGM.getLangOpts(), RegionCounterMap.get());
900 MappingGen.emitCounterMapping(D, OS);
901 OS.flush();
903 if (CoverageMapping.empty())
904 return;
906 CGM.getCoverageMapping()->addFunctionMappingRecord(
907 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
910 void
911 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
912 llvm::GlobalValue::LinkageTypes Linkage) {
913 if (skipRegionMappingForDecl(D))
914 return;
916 std::string CoverageMapping;
917 llvm::raw_string_ostream OS(CoverageMapping);
918 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
919 CGM.getContext().getSourceManager(),
920 CGM.getLangOpts());
921 MappingGen.emitEmptyMapping(D, OS);
922 OS.flush();
924 if (CoverageMapping.empty())
925 return;
927 setFuncName(Name, Linkage);
928 CGM.getCoverageMapping()->addFunctionMappingRecord(
929 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
932 void CodeGenPGO::computeRegionCounts(const Decl *D) {
933 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
934 ComputeRegionCounts Walker(*StmtCountMap, *this);
935 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
936 Walker.VisitFunctionDecl(FD);
937 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
938 Walker.VisitObjCMethodDecl(MD);
939 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
940 Walker.VisitBlockDecl(BD);
941 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
942 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
945 void
946 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
947 llvm::Function *Fn) {
948 if (!haveRegionCounts())
949 return;
951 uint64_t FunctionCount = getRegionCount(nullptr);
952 Fn->setEntryCount(FunctionCount);
/// Emit an llvm.instrprof.increment (or increment_step when \p StepV is
/// non-null) for the counter associated with statement \p S.
///
/// No-op unless Clang-level instrumentation is enabled, counters were
/// mapped, and the builder has a valid insertion point.
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];

  llvm::Value *Args[] = {FuncNameVar,
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  // Without a step, call the 4-argument increment intrinsic (drop StepV);
  // otherwise pass all five arguments to increment_step.
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       ArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        ArrayRef(Args));
}
977 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
978 if (CGM.getCodeGenOpts().hasProfileClangInstr())
979 M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
980 uint32_t(EnableValueProfiling));
983 // This method either inserts a call to the profile run-time during
984 // instrumentation or puts profile data into metadata for PGO use.
985 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
986 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
988 if (!EnableValueProfiling)
989 return;
991 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
992 return;
994 if (isa<llvm::Constant>(ValuePtr))
995 return;
997 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
998 if (InstrumentValueSites && RegionCounterMap) {
999 auto BuilderInsertPoint = Builder.saveIP();
1000 Builder.SetInsertPoint(ValueSite);
1001 llvm::Value *Args[5] = {
1002 FuncNameVar,
1003 Builder.getInt64(FunctionHash),
1004 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1005 Builder.getInt32(ValueKind),
1006 Builder.getInt32(NumValueSites[ValueKind]++)
1008 Builder.CreateCall(
1009 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1010 Builder.restoreIP(BuilderInsertPoint);
1011 return;
1014 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1015 if (PGOReader && haveRegionCounts()) {
1016 // We record the top most called three functions at each call site.
1017 // Profile metadata contains "VP" string identifying this metadata
1018 // as value profiling data, then a uint32_t value for the value profiling
1019 // kind, a uint64_t value for the total number of times the call is
1020 // executed, followed by the function hash and execution count (uint64_t)
1021 // pairs for each function.
1022 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1023 return;
1025 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1026 (llvm::InstrProfValueKind)ValueKind,
1027 NumValueSites[ValueKind]);
1029 NumValueSites[ValueKind]++;
1033 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1034 bool IsInMainFile) {
1035 CGM.getPGOStats().addVisited(IsInMainFile);
1036 RegionCounts.clear();
1037 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1038 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1039 if (auto E = RecordExpected.takeError()) {
1040 auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1041 if (IPE == llvm::instrprof_error::unknown_function)
1042 CGM.getPGOStats().addMissing(IsInMainFile);
1043 else if (IPE == llvm::instrprof_error::hash_mismatch)
1044 CGM.getPGOStats().addMismatched(IsInMainFile);
1045 else if (IPE == llvm::instrprof_error::malformed)
1046 // TODO: Consider a more specific warning for this case.
1047 CGM.getPGOStats().addMismatched(IsInMainFile);
1048 return;
1050 ProfRecord =
1051 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1052 RegionCounts = ProfRecord->Counts;
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Small weights need no scaling; otherwise pick the smallest divisor that
  // maps MaxWeight strictly below UINT32_MAX.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1079 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1080 uint64_t FalseCount) const {
1081 // Check for empty weights.
1082 if (!TrueCount && !FalseCount)
1083 return nullptr;
1085 // Calculate how to scale down to 32-bits.
1086 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1088 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1089 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1090 scaleBranchWeight(FalseCount, Scale));
1093 llvm::MDNode *
1094 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1095 // We need at least two elements to create meaningful weights.
1096 if (Weights.size() < 2)
1097 return nullptr;
1099 // Check for empty weights.
1100 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1101 if (MaxWeight == 0)
1102 return nullptr;
1104 // Calculate how to scale down to 32-bits.
1105 uint64_t Scale = calculateWeightScale(MaxWeight);
1107 SmallVector<uint32_t, 16> ScaledWeights;
1108 ScaledWeights.reserve(Weights.size());
1109 for (uint64_t W : Weights)
1110 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1112 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1113 return MDHelper.createBranchWeights(ScaledWeights);
1116 llvm::MDNode *
1117 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1118 uint64_t LoopCount) const {
1119 if (!PGO.haveRegionCounts())
1120 return nullptr;
1121 std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1122 if (!CondCount || *CondCount == 0)
1123 return nullptr;
1124 return createProfileWeights(LoopCount,
1125 std::max(*CondCount, LoopCount) - LoopCount);