1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Instrumentation-based profile-guided optimization
11 //===----------------------------------------------------------------------===//
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
26 static llvm::cl::opt
<bool>
27 EnableValueProfiling("enable-value-profiling",
28 llvm::cl::desc("Enable value profiling"),
29 llvm::cl::Hidden
, llvm::cl::init(false));
31 using namespace clang
;
32 using namespace CodeGen
;
34 void CodeGenPGO::setFuncName(StringRef Name
,
35 llvm::GlobalValue::LinkageTypes Linkage
) {
36 llvm::IndexedInstrProfReader
*PGOReader
= CGM
.getPGOReader();
37 FuncName
= llvm::getPGOFuncName(
38 Name
, Linkage
, CGM
.getCodeGenOpts().MainFileName
,
39 PGOReader
? PGOReader
->getVersion() : llvm::IndexedInstrProf::Version
);
41 // If we're generating a profile, create a variable for the name.
42 if (CGM
.getCodeGenOpts().hasProfileClangInstr())
43 FuncNameVar
= llvm::createPGOFuncNameVar(CGM
.getModule(), Linkage
, FuncName
);
46 void CodeGenPGO::setFuncName(llvm::Function
*Fn
) {
47 setFuncName(Fn
->getName(), Fn
->getLinkage());
48 // Create PGOFuncName meta data.
49 llvm::createPGOFuncNameMetadata(*Fn
, FuncName
);
52 /// The version of the PGO hash algorithm.
53 enum PGOHashVersion
: unsigned {
58 // Keep this set to the latest hash version.
59 PGO_HASH_LATEST
= PGO_HASH_V3
63 /// Stable hasher for PGO region counters.
65 /// PGOHash produces a stable hash of a given function's control flow.
67 /// Changing the output of this hash will invalidate all previously generated
68 /// profiles -- i.e., don't do it.
70 /// \note When this hash does eventually change (years?), we still need to
71 /// support old hashes. We'll need to pull in the version number from the
72 /// profile data format and use the matching hash function.
76 PGOHashVersion HashVersion
;
79 static const int NumBitsPerType
= 6;
80 static const unsigned NumTypesPerWord
= sizeof(uint64_t) * 8 / NumBitsPerType
;
81 static const unsigned TooBig
= 1u << NumBitsPerType
;
84 /// Hash values for AST nodes.
86 /// Distinct values for AST nodes that have region counters attached.
88 /// These values must be stable. All new members must be added at the end,
89 /// and no members should be removed. Changing the enumeration value for an
90 /// AST node will affect the hash of every function that contains that node.
91 enum HashType
: unsigned char {
98 ObjCForCollectionStmt
,
108 BinaryConditionalOperator
,
109 // The preceding values are available with PGO_HASH_V1.
127 // The preceding values are available since PGO_HASH_V2.
129 // Keep this last. It's for the static assert that follows.
132 static_assert(LastHashType
<= TooBig
, "Too many types in HashType");
134 PGOHash(PGOHashVersion HashVersion
)
135 : Working(0), Count(0), HashVersion(HashVersion
) {}
136 void combine(HashType Type
);
138 PGOHashVersion
getHashVersion() const { return HashVersion
; }
140 const int PGOHash::NumBitsPerType
;
141 const unsigned PGOHash::NumTypesPerWord
;
142 const unsigned PGOHash::TooBig
;
144 /// Get the PGO hash version used in the given indexed profile.
145 static PGOHashVersion
getPGOHashVersion(llvm::IndexedInstrProfReader
*PGOReader
,
146 CodeGenModule
&CGM
) {
147 if (PGOReader
->getVersion() <= 4)
149 if (PGOReader
->getVersion() <= 5)
154 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
155 struct MapRegionCounters
: public RecursiveASTVisitor
<MapRegionCounters
> {
156 using Base
= RecursiveASTVisitor
<MapRegionCounters
>;
158 /// The next counter value to assign.
159 unsigned NextCounter
;
160 /// The function hash.
162 /// The map of statements to counters.
163 llvm::DenseMap
<const Stmt
*, unsigned> &CounterMap
;
164 /// The profile version.
165 uint64_t ProfileVersion
;
167 MapRegionCounters(PGOHashVersion HashVersion
, uint64_t ProfileVersion
,
168 llvm::DenseMap
<const Stmt
*, unsigned> &CounterMap
)
169 : NextCounter(0), Hash(HashVersion
), CounterMap(CounterMap
),
170 ProfileVersion(ProfileVersion
) {}
172 // Blocks and lambdas are handled as separate functions, so we need not
173 // traverse them in the parent context.
174 bool TraverseBlockExpr(BlockExpr
*BE
) { return true; }
175 bool TraverseLambdaExpr(LambdaExpr
*LE
) {
176 // Traverse the captures, but not the body.
177 for (auto C
: zip(LE
->captures(), LE
->capture_inits()))
178 TraverseLambdaCapture(LE
, &std::get
<0>(C
), std::get
<1>(C
));
181 bool TraverseCapturedStmt(CapturedStmt
*CS
) { return true; }
183 bool VisitDecl(const Decl
*D
) {
184 switch (D
->getKind()) {
188 case Decl::CXXMethod
:
189 case Decl::CXXConstructor
:
190 case Decl::CXXDestructor
:
191 case Decl::CXXConversion
:
192 case Decl::ObjCMethod
:
195 CounterMap
[D
->getBody()] = NextCounter
++;
201 /// If \p S gets a fresh counter, update the counter mappings. Return the
203 PGOHash::HashType
updateCounterMappings(Stmt
*S
) {
204 auto Type
= getHashType(PGO_HASH_V1
, S
);
205 if (Type
!= PGOHash::None
)
206 CounterMap
[S
] = NextCounter
++;
210 /// The RHS of all logical operators gets a fresh counter in order to count
211 /// how many times the RHS evaluates to true or false, depending on the
212 /// semantics of the operator. This is only valid for ">= v7" of the profile
213 /// version so that we facilitate backward compatibility.
214 bool VisitBinaryOperator(BinaryOperator
*S
) {
215 if (ProfileVersion
>= llvm::IndexedInstrProf::Version7
)
216 if (S
->isLogicalOp() &&
217 CodeGenFunction::isInstrumentedCondition(S
->getRHS()))
218 CounterMap
[S
->getRHS()] = NextCounter
++;
219 return Base::VisitBinaryOperator(S
);
222 /// Include \p S in the function hash.
223 bool VisitStmt(Stmt
*S
) {
224 auto Type
= updateCounterMappings(S
);
225 if (Hash
.getHashVersion() != PGO_HASH_V1
)
226 Type
= getHashType(Hash
.getHashVersion(), S
);
227 if (Type
!= PGOHash::None
)
232 bool TraverseIfStmt(IfStmt
*If
) {
233 // If we used the V1 hash, use the default traversal.
234 if (Hash
.getHashVersion() == PGO_HASH_V1
)
235 return Base::TraverseIfStmt(If
);
237 // Otherwise, keep track of which branch we're in while traversing.
239 for (Stmt
*CS
: If
->children()) {
242 if (CS
== If
->getThen())
243 Hash
.combine(PGOHash::IfThenBranch
);
244 else if (CS
== If
->getElse())
245 Hash
.combine(PGOHash::IfElseBranch
);
248 Hash
.combine(PGOHash::EndOfScope
);
252 // If the statement type \p N is nestable, and its nesting impacts profile
253 // stability, define a custom traversal which tracks the end of the statement
254 // in the hash (provided we're not using the V1 hash).
255 #define DEFINE_NESTABLE_TRAVERSAL(N) \
256 bool Traverse##N(N *S) { \
257 Base::Traverse##N(S); \
258 if (Hash.getHashVersion() != PGO_HASH_V1) \
259 Hash.combine(PGOHash::EndOfScope); \
263 DEFINE_NESTABLE_TRAVERSAL(WhileStmt
)
264 DEFINE_NESTABLE_TRAVERSAL(DoStmt
)
265 DEFINE_NESTABLE_TRAVERSAL(ForStmt
)
266 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt
)
267 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt
)
268 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt
)
269 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt
)
271 /// Get version \p HashVersion of the PGO hash for \p S.
272 PGOHash::HashType
getHashType(PGOHashVersion HashVersion
, const Stmt
*S
) {
273 switch (S
->getStmtClass()) {
276 case Stmt::LabelStmtClass
:
277 return PGOHash::LabelStmt
;
278 case Stmt::WhileStmtClass
:
279 return PGOHash::WhileStmt
;
280 case Stmt::DoStmtClass
:
281 return PGOHash::DoStmt
;
282 case Stmt::ForStmtClass
:
283 return PGOHash::ForStmt
;
284 case Stmt::CXXForRangeStmtClass
:
285 return PGOHash::CXXForRangeStmt
;
286 case Stmt::ObjCForCollectionStmtClass
:
287 return PGOHash::ObjCForCollectionStmt
;
288 case Stmt::SwitchStmtClass
:
289 return PGOHash::SwitchStmt
;
290 case Stmt::CaseStmtClass
:
291 return PGOHash::CaseStmt
;
292 case Stmt::DefaultStmtClass
:
293 return PGOHash::DefaultStmt
;
294 case Stmt::IfStmtClass
:
295 return PGOHash::IfStmt
;
296 case Stmt::CXXTryStmtClass
:
297 return PGOHash::CXXTryStmt
;
298 case Stmt::CXXCatchStmtClass
:
299 return PGOHash::CXXCatchStmt
;
300 case Stmt::ConditionalOperatorClass
:
301 return PGOHash::ConditionalOperator
;
302 case Stmt::BinaryConditionalOperatorClass
:
303 return PGOHash::BinaryConditionalOperator
;
304 case Stmt::BinaryOperatorClass
: {
305 const BinaryOperator
*BO
= cast
<BinaryOperator
>(S
);
306 if (BO
->getOpcode() == BO_LAnd
)
307 return PGOHash::BinaryOperatorLAnd
;
308 if (BO
->getOpcode() == BO_LOr
)
309 return PGOHash::BinaryOperatorLOr
;
310 if (HashVersion
>= PGO_HASH_V2
) {
311 switch (BO
->getOpcode()) {
315 return PGOHash::BinaryOperatorLT
;
317 return PGOHash::BinaryOperatorGT
;
319 return PGOHash::BinaryOperatorLE
;
321 return PGOHash::BinaryOperatorGE
;
323 return PGOHash::BinaryOperatorEQ
;
325 return PGOHash::BinaryOperatorNE
;
332 if (HashVersion
>= PGO_HASH_V2
) {
333 switch (S
->getStmtClass()) {
336 case Stmt::GotoStmtClass
:
337 return PGOHash::GotoStmt
;
338 case Stmt::IndirectGotoStmtClass
:
339 return PGOHash::IndirectGotoStmt
;
340 case Stmt::BreakStmtClass
:
341 return PGOHash::BreakStmt
;
342 case Stmt::ContinueStmtClass
:
343 return PGOHash::ContinueStmt
;
344 case Stmt::ReturnStmtClass
:
345 return PGOHash::ReturnStmt
;
346 case Stmt::CXXThrowExprClass
:
347 return PGOHash::ThrowExpr
;
348 case Stmt::UnaryOperatorClass
: {
349 const UnaryOperator
*UO
= cast
<UnaryOperator
>(S
);
350 if (UO
->getOpcode() == UO_LNot
)
351 return PGOHash::UnaryOperatorLNot
;
357 return PGOHash::None
;
361 /// A StmtVisitor that propagates the raw counts through the AST and
362 /// records the count at statements where the value may change.
363 struct ComputeRegionCounts
: public ConstStmtVisitor
<ComputeRegionCounts
> {
367 /// A flag that is set when the current count should be recorded on the
368 /// next statement, such as at the exit of a loop.
369 bool RecordNextStmtCount
;
371 /// The count at the current location in the traversal.
372 uint64_t CurrentCount
;
374 /// The map of statements to count values.
375 llvm::DenseMap
<const Stmt
*, uint64_t> &CountMap
;
377 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
378 struct BreakContinue
{
380 uint64_t ContinueCount
;
381 BreakContinue() : BreakCount(0), ContinueCount(0) {}
383 SmallVector
<BreakContinue
, 8> BreakContinueStack
;
385 ComputeRegionCounts(llvm::DenseMap
<const Stmt
*, uint64_t> &CountMap
,
387 : PGO(PGO
), RecordNextStmtCount(false), CountMap(CountMap
) {}
389 void RecordStmtCount(const Stmt
*S
) {
390 if (RecordNextStmtCount
) {
391 CountMap
[S
] = CurrentCount
;
392 RecordNextStmtCount
= false;
396 /// Set and return the current count.
397 uint64_t setCount(uint64_t Count
) {
398 CurrentCount
= Count
;
402 void VisitStmt(const Stmt
*S
) {
404 for (const Stmt
*Child
: S
->children())
409 void VisitFunctionDecl(const FunctionDecl
*D
) {
410 // Counter tracks entry to the function body.
411 uint64_t BodyCount
= setCount(PGO
.getRegionCount(D
->getBody()));
412 CountMap
[D
->getBody()] = BodyCount
;
416 // Skip lambda expressions. We visit these as FunctionDecls when we're
417 // generating them and aren't interested in the body when generating a
419 void VisitLambdaExpr(const LambdaExpr
*LE
) {}
421 void VisitCapturedDecl(const CapturedDecl
*D
) {
422 // Counter tracks entry to the capture body.
423 uint64_t BodyCount
= setCount(PGO
.getRegionCount(D
->getBody()));
424 CountMap
[D
->getBody()] = BodyCount
;
428 void VisitObjCMethodDecl(const ObjCMethodDecl
*D
) {
429 // Counter tracks entry to the method body.
430 uint64_t BodyCount
= setCount(PGO
.getRegionCount(D
->getBody()));
431 CountMap
[D
->getBody()] = BodyCount
;
435 void VisitBlockDecl(const BlockDecl
*D
) {
436 // Counter tracks entry to the block body.
437 uint64_t BodyCount
= setCount(PGO
.getRegionCount(D
->getBody()));
438 CountMap
[D
->getBody()] = BodyCount
;
442 void VisitReturnStmt(const ReturnStmt
*S
) {
444 if (S
->getRetValue())
445 Visit(S
->getRetValue());
447 RecordNextStmtCount
= true;
450 void VisitCXXThrowExpr(const CXXThrowExpr
*E
) {
453 Visit(E
->getSubExpr());
455 RecordNextStmtCount
= true;
458 void VisitGotoStmt(const GotoStmt
*S
) {
461 RecordNextStmtCount
= true;
464 void VisitLabelStmt(const LabelStmt
*S
) {
465 RecordNextStmtCount
= false;
466 // Counter tracks the block following the label.
467 uint64_t BlockCount
= setCount(PGO
.getRegionCount(S
));
468 CountMap
[S
] = BlockCount
;
469 Visit(S
->getSubStmt());
472 void VisitBreakStmt(const BreakStmt
*S
) {
474 assert(!BreakContinueStack
.empty() && "break not in a loop or switch!");
475 BreakContinueStack
.back().BreakCount
+= CurrentCount
;
477 RecordNextStmtCount
= true;
480 void VisitContinueStmt(const ContinueStmt
*S
) {
482 assert(!BreakContinueStack
.empty() && "continue stmt not in a loop!");
483 BreakContinueStack
.back().ContinueCount
+= CurrentCount
;
485 RecordNextStmtCount
= true;
488 void VisitWhileStmt(const WhileStmt
*S
) {
490 uint64_t ParentCount
= CurrentCount
;
492 BreakContinueStack
.push_back(BreakContinue());
493 // Visit the body region first so the break/continue adjustments can be
494 // included when visiting the condition.
495 uint64_t BodyCount
= setCount(PGO
.getRegionCount(S
));
496 CountMap
[S
->getBody()] = CurrentCount
;
498 uint64_t BackedgeCount
= CurrentCount
;
500 // ...then go back and propagate counts through the condition. The count
501 // at the start of the condition is the sum of the incoming edges,
502 // the backedge from the end of the loop body, and the edges from
503 // continue statements.
504 BreakContinue BC
= BreakContinueStack
.pop_back_val();
506 setCount(ParentCount
+ BackedgeCount
+ BC
.ContinueCount
);
507 CountMap
[S
->getCond()] = CondCount
;
509 setCount(BC
.BreakCount
+ CondCount
- BodyCount
);
510 RecordNextStmtCount
= true;
513 void VisitDoStmt(const DoStmt
*S
) {
515 uint64_t LoopCount
= PGO
.getRegionCount(S
);
517 BreakContinueStack
.push_back(BreakContinue());
518 // The count doesn't include the fallthrough from the parent scope. Add it.
519 uint64_t BodyCount
= setCount(LoopCount
+ CurrentCount
);
520 CountMap
[S
->getBody()] = BodyCount
;
522 uint64_t BackedgeCount
= CurrentCount
;
524 BreakContinue BC
= BreakContinueStack
.pop_back_val();
525 // The count at the start of the condition is equal to the count at the
526 // end of the body, plus any continues.
527 uint64_t CondCount
= setCount(BackedgeCount
+ BC
.ContinueCount
);
528 CountMap
[S
->getCond()] = CondCount
;
530 setCount(BC
.BreakCount
+ CondCount
- LoopCount
);
531 RecordNextStmtCount
= true;
534 void VisitForStmt(const ForStmt
*S
) {
539 uint64_t ParentCount
= CurrentCount
;
541 BreakContinueStack
.push_back(BreakContinue());
542 // Visit the body region first. (This is basically the same as a while
543 // loop; see further comments in VisitWhileStmt.)
544 uint64_t BodyCount
= setCount(PGO
.getRegionCount(S
));
545 CountMap
[S
->getBody()] = BodyCount
;
547 uint64_t BackedgeCount
= CurrentCount
;
548 BreakContinue BC
= BreakContinueStack
.pop_back_val();
550 // The increment is essentially part of the body but it needs to include
551 // the count for all the continue statements.
553 uint64_t IncCount
= setCount(BackedgeCount
+ BC
.ContinueCount
);
554 CountMap
[S
->getInc()] = IncCount
;
558 // ...then go back and propagate counts through the condition.
560 setCount(ParentCount
+ BackedgeCount
+ BC
.ContinueCount
);
562 CountMap
[S
->getCond()] = CondCount
;
565 setCount(BC
.BreakCount
+ CondCount
- BodyCount
);
566 RecordNextStmtCount
= true;
569 void VisitCXXForRangeStmt(const CXXForRangeStmt
*S
) {
573 Visit(S
->getLoopVarStmt());
574 Visit(S
->getRangeStmt());
575 Visit(S
->getBeginStmt());
576 Visit(S
->getEndStmt());
578 uint64_t ParentCount
= CurrentCount
;
579 BreakContinueStack
.push_back(BreakContinue());
580 // Visit the body region first. (This is basically the same as a while
581 // loop; see further comments in VisitWhileStmt.)
582 uint64_t BodyCount
= setCount(PGO
.getRegionCount(S
));
583 CountMap
[S
->getBody()] = BodyCount
;
585 uint64_t BackedgeCount
= CurrentCount
;
586 BreakContinue BC
= BreakContinueStack
.pop_back_val();
588 // The increment is essentially part of the body but it needs to include
589 // the count for all the continue statements.
590 uint64_t IncCount
= setCount(BackedgeCount
+ BC
.ContinueCount
);
591 CountMap
[S
->getInc()] = IncCount
;
594 // ...then go back and propagate counts through the condition.
596 setCount(ParentCount
+ BackedgeCount
+ BC
.ContinueCount
);
597 CountMap
[S
->getCond()] = CondCount
;
599 setCount(BC
.BreakCount
+ CondCount
- BodyCount
);
600 RecordNextStmtCount
= true;
603 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt
*S
) {
605 Visit(S
->getElement());
606 uint64_t ParentCount
= CurrentCount
;
607 BreakContinueStack
.push_back(BreakContinue());
608 // Counter tracks the body of the loop.
609 uint64_t BodyCount
= setCount(PGO
.getRegionCount(S
));
610 CountMap
[S
->getBody()] = BodyCount
;
612 uint64_t BackedgeCount
= CurrentCount
;
613 BreakContinue BC
= BreakContinueStack
.pop_back_val();
615 setCount(BC
.BreakCount
+ ParentCount
+ BackedgeCount
+ BC
.ContinueCount
-
617 RecordNextStmtCount
= true;
620 void VisitSwitchStmt(const SwitchStmt
*S
) {
626 BreakContinueStack
.push_back(BreakContinue());
628 // If the switch is inside a loop, add the continue counts.
629 BreakContinue BC
= BreakContinueStack
.pop_back_val();
630 if (!BreakContinueStack
.empty())
631 BreakContinueStack
.back().ContinueCount
+= BC
.ContinueCount
;
632 // Counter tracks the exit block of the switch.
633 setCount(PGO
.getRegionCount(S
));
634 RecordNextStmtCount
= true;
637 void VisitSwitchCase(const SwitchCase
*S
) {
638 RecordNextStmtCount
= false;
639 // Counter for this particular case. This counts only jumps from the
640 // switch header and does not include fallthrough from the case before
642 uint64_t CaseCount
= PGO
.getRegionCount(S
);
643 setCount(CurrentCount
+ CaseCount
);
644 // We need the count without fallthrough in the mapping, so it's more useful
645 // for branch probabilities.
646 CountMap
[S
] = CaseCount
;
647 RecordNextStmtCount
= true;
648 Visit(S
->getSubStmt());
651 void VisitIfStmt(const IfStmt
*S
) {
654 if (S
->isConsteval()) {
655 const Stmt
*Stm
= S
->isNegatedConsteval() ? S
->getThen() : S
->getElse();
661 uint64_t ParentCount
= CurrentCount
;
666 // Counter tracks the "then" part of an if statement. The count for
667 // the "else" part, if it exists, will be calculated from this counter.
668 uint64_t ThenCount
= setCount(PGO
.getRegionCount(S
));
669 CountMap
[S
->getThen()] = ThenCount
;
671 uint64_t OutCount
= CurrentCount
;
673 uint64_t ElseCount
= ParentCount
- ThenCount
;
676 CountMap
[S
->getElse()] = ElseCount
;
678 OutCount
+= CurrentCount
;
680 OutCount
+= ElseCount
;
682 RecordNextStmtCount
= true;
685 void VisitCXXTryStmt(const CXXTryStmt
*S
) {
687 Visit(S
->getTryBlock());
688 for (unsigned I
= 0, E
= S
->getNumHandlers(); I
< E
; ++I
)
689 Visit(S
->getHandler(I
));
690 // Counter tracks the continuation block of the try statement.
691 setCount(PGO
.getRegionCount(S
));
692 RecordNextStmtCount
= true;
695 void VisitCXXCatchStmt(const CXXCatchStmt
*S
) {
696 RecordNextStmtCount
= false;
697 // Counter tracks the catch statement's handler block.
698 uint64_t CatchCount
= setCount(PGO
.getRegionCount(S
));
699 CountMap
[S
] = CatchCount
;
700 Visit(S
->getHandlerBlock());
703 void VisitAbstractConditionalOperator(const AbstractConditionalOperator
*E
) {
705 uint64_t ParentCount
= CurrentCount
;
708 // Counter tracks the "true" part of a conditional operator. The
709 // count in the "false" part will be calculated from this counter.
710 uint64_t TrueCount
= setCount(PGO
.getRegionCount(E
));
711 CountMap
[E
->getTrueExpr()] = TrueCount
;
712 Visit(E
->getTrueExpr());
713 uint64_t OutCount
= CurrentCount
;
715 uint64_t FalseCount
= setCount(ParentCount
- TrueCount
);
716 CountMap
[E
->getFalseExpr()] = FalseCount
;
717 Visit(E
->getFalseExpr());
718 OutCount
+= CurrentCount
;
721 RecordNextStmtCount
= true;
724 void VisitBinLAnd(const BinaryOperator
*E
) {
726 uint64_t ParentCount
= CurrentCount
;
728 // Counter tracks the right hand side of a logical and operator.
729 uint64_t RHSCount
= setCount(PGO
.getRegionCount(E
));
730 CountMap
[E
->getRHS()] = RHSCount
;
732 setCount(ParentCount
+ RHSCount
- CurrentCount
);
733 RecordNextStmtCount
= true;
736 void VisitBinLOr(const BinaryOperator
*E
) {
738 uint64_t ParentCount
= CurrentCount
;
740 // Counter tracks the right hand side of a logical or operator.
741 uint64_t RHSCount
= setCount(PGO
.getRegionCount(E
));
742 CountMap
[E
->getRHS()] = RHSCount
;
744 setCount(ParentCount
+ RHSCount
- CurrentCount
);
745 RecordNextStmtCount
= true;
748 } // end anonymous namespace
750 void PGOHash::combine(HashType Type
) {
751 // Check that we never combine 0 and only have six bits.
752 assert(Type
&& "Hash is invalid: unexpected type 0");
753 assert(unsigned(Type
) < TooBig
&& "Hash is invalid: too many types");
755 // Pass through MD5 if enough work has built up.
756 if (Count
&& Count
% NumTypesPerWord
== 0) {
757 using namespace llvm::support
;
758 uint64_t Swapped
= endian::byte_swap
<uint64_t, little
>(Working
);
759 MD5
.update(llvm::ArrayRef((uint8_t *)&Swapped
, sizeof(Swapped
)));
763 // Accumulate the current type.
765 Working
= Working
<< NumBitsPerType
| Type
;
768 uint64_t PGOHash::finalize() {
769 // Use Working as the hash directly if we never used MD5.
770 if (Count
<= NumTypesPerWord
)
771 // No need to byte swap here, since none of the math was endian-dependent.
772 // This number will be byte-swapped as required on endianness transitions,
773 // so we will see the same value on the other side.
776 // Check for remaining work in Working.
778 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
779 // is buggy because it converts a uint64_t into an array of uint8_t.
780 if (HashVersion
< PGO_HASH_V3
) {
781 MD5
.update({(uint8_t)Working
});
783 using namespace llvm::support
;
784 uint64_t Swapped
= endian::byte_swap
<uint64_t, little
>(Working
);
785 MD5
.update(llvm::ArrayRef((uint8_t *)&Swapped
, sizeof(Swapped
)));
789 // Finalize the MD5 and return the hash.
790 llvm::MD5::MD5Result Result
;
795 void CodeGenPGO::assignRegionCounters(GlobalDecl GD
, llvm::Function
*Fn
) {
796 const Decl
*D
= GD
.getDecl();
800 // Skip CUDA/HIP kernel launch stub functions.
801 if (CGM
.getLangOpts().CUDA
&& !CGM
.getLangOpts().CUDAIsDevice
&&
802 D
->hasAttr
<CUDAGlobalAttr
>())
805 bool InstrumentRegions
= CGM
.getCodeGenOpts().hasProfileClangInstr();
806 llvm::IndexedInstrProfReader
*PGOReader
= CGM
.getPGOReader();
807 if (!InstrumentRegions
&& !PGOReader
)
811 // Constructors and destructors may be represented by several functions in IR.
812 // If so, instrument only base variant, others are implemented by delegation
813 // to the base one, it would be counted twice otherwise.
814 if (CGM
.getTarget().getCXXABI().hasConstructorVariants()) {
815 if (const auto *CCD
= dyn_cast
<CXXConstructorDecl
>(D
))
816 if (GD
.getCtorType() != Ctor_Base
&&
817 CodeGenFunction::IsConstructorDelegationValid(CCD
))
820 if (isa
<CXXDestructorDecl
>(D
) && GD
.getDtorType() != Dtor_Base
)
823 CGM
.ClearUnusedCoverageMapping(D
);
824 if (Fn
->hasFnAttribute(llvm::Attribute::NoProfile
))
826 if (Fn
->hasFnAttribute(llvm::Attribute::SkipProfile
))
831 mapRegionCounters(D
);
832 if (CGM
.getCodeGenOpts().CoverageMapping
)
833 emitCounterRegionMapping(D
);
835 SourceManager
&SM
= CGM
.getContext().getSourceManager();
836 loadRegionCounts(PGOReader
, SM
.isInMainFile(D
->getLocation()));
837 computeRegionCounts(D
);
838 applyFunctionAttributes(PGOReader
, Fn
);
842 void CodeGenPGO::mapRegionCounters(const Decl
*D
) {
843 // Use the latest hash version when inserting instrumentation, but use the
844 // version in the indexed profile if we're reading PGO data.
845 PGOHashVersion HashVersion
= PGO_HASH_LATEST
;
846 uint64_t ProfileVersion
= llvm::IndexedInstrProf::Version
;
847 if (auto *PGOReader
= CGM
.getPGOReader()) {
848 HashVersion
= getPGOHashVersion(PGOReader
, CGM
);
849 ProfileVersion
= PGOReader
->getVersion();
852 RegionCounterMap
.reset(new llvm::DenseMap
<const Stmt
*, unsigned>);
853 MapRegionCounters
Walker(HashVersion
, ProfileVersion
, *RegionCounterMap
);
854 if (const FunctionDecl
*FD
= dyn_cast_or_null
<FunctionDecl
>(D
))
855 Walker
.TraverseDecl(const_cast<FunctionDecl
*>(FD
));
856 else if (const ObjCMethodDecl
*MD
= dyn_cast_or_null
<ObjCMethodDecl
>(D
))
857 Walker
.TraverseDecl(const_cast<ObjCMethodDecl
*>(MD
));
858 else if (const BlockDecl
*BD
= dyn_cast_or_null
<BlockDecl
>(D
))
859 Walker
.TraverseDecl(const_cast<BlockDecl
*>(BD
));
860 else if (const CapturedDecl
*CD
= dyn_cast_or_null
<CapturedDecl
>(D
))
861 Walker
.TraverseDecl(const_cast<CapturedDecl
*>(CD
));
862 assert(Walker
.NextCounter
> 0 && "no entry counter mapped for decl");
863 NumRegionCounters
= Walker
.NextCounter
;
864 FunctionHash
= Walker
.Hash
.finalize();
867 bool CodeGenPGO::skipRegionMappingForDecl(const Decl
*D
) {
871 // Skip host-only functions in the CUDA device compilation and device-only
872 // functions in the host compilation. Just roughly filter them out based on
873 // the function attributes. If there are effectively host-only or device-only
874 // ones, their coverage mapping may still be generated.
875 if (CGM
.getLangOpts().CUDA
&&
876 ((CGM
.getLangOpts().CUDAIsDevice
&& !D
->hasAttr
<CUDADeviceAttr
>() &&
877 !D
->hasAttr
<CUDAGlobalAttr
>()) ||
878 (!CGM
.getLangOpts().CUDAIsDevice
&&
879 (D
->hasAttr
<CUDAGlobalAttr
>() ||
880 (!D
->hasAttr
<CUDAHostAttr
>() && D
->hasAttr
<CUDADeviceAttr
>())))))
883 // Don't map the functions in system headers.
884 const auto &SM
= CGM
.getContext().getSourceManager();
885 auto Loc
= D
->getBody()->getBeginLoc();
886 return SM
.isInSystemHeader(Loc
);
889 void CodeGenPGO::emitCounterRegionMapping(const Decl
*D
) {
890 if (skipRegionMappingForDecl(D
))
893 std::string CoverageMapping
;
894 llvm::raw_string_ostream
OS(CoverageMapping
);
895 CoverageMappingGen
MappingGen(*CGM
.getCoverageMapping(),
896 CGM
.getContext().getSourceManager(),
897 CGM
.getLangOpts(), RegionCounterMap
.get());
898 MappingGen
.emitCounterMapping(D
, OS
);
901 if (CoverageMapping
.empty())
904 CGM
.getCoverageMapping()->addFunctionMappingRecord(
905 FuncNameVar
, FuncName
, FunctionHash
, CoverageMapping
);
909 CodeGenPGO::emitEmptyCounterMapping(const Decl
*D
, StringRef Name
,
910 llvm::GlobalValue::LinkageTypes Linkage
) {
911 if (skipRegionMappingForDecl(D
))
914 std::string CoverageMapping
;
915 llvm::raw_string_ostream
OS(CoverageMapping
);
916 CoverageMappingGen
MappingGen(*CGM
.getCoverageMapping(),
917 CGM
.getContext().getSourceManager(),
919 MappingGen
.emitEmptyMapping(D
, OS
);
922 if (CoverageMapping
.empty())
925 setFuncName(Name
, Linkage
);
926 CGM
.getCoverageMapping()->addFunctionMappingRecord(
927 FuncNameVar
, FuncName
, FunctionHash
, CoverageMapping
, false);
930 void CodeGenPGO::computeRegionCounts(const Decl
*D
) {
931 StmtCountMap
.reset(new llvm::DenseMap
<const Stmt
*, uint64_t>);
932 ComputeRegionCounts
Walker(*StmtCountMap
, *this);
933 if (const FunctionDecl
*FD
= dyn_cast_or_null
<FunctionDecl
>(D
))
934 Walker
.VisitFunctionDecl(FD
);
935 else if (const ObjCMethodDecl
*MD
= dyn_cast_or_null
<ObjCMethodDecl
>(D
))
936 Walker
.VisitObjCMethodDecl(MD
);
937 else if (const BlockDecl
*BD
= dyn_cast_or_null
<BlockDecl
>(D
))
938 Walker
.VisitBlockDecl(BD
);
939 else if (const CapturedDecl
*CD
= dyn_cast_or_null
<CapturedDecl
>(D
))
940 Walker
.VisitCapturedDecl(const_cast<CapturedDecl
*>(CD
));
944 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader
*PGOReader
,
945 llvm::Function
*Fn
) {
946 if (!haveRegionCounts())
949 uint64_t FunctionCount
= getRegionCount(nullptr);
950 Fn
->setEntryCount(FunctionCount
);
953 void CodeGenPGO::emitCounterIncrement(CGBuilderTy
&Builder
, const Stmt
*S
,
954 llvm::Value
*StepV
) {
955 if (!CGM
.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap
)
957 if (!Builder
.GetInsertBlock())
960 unsigned Counter
= (*RegionCounterMap
)[S
];
961 auto *I8PtrTy
= llvm::Type::getInt8PtrTy(CGM
.getLLVMContext());
963 llvm::Value
*Args
[] = {llvm::ConstantExpr::getBitCast(FuncNameVar
, I8PtrTy
),
964 Builder
.getInt64(FunctionHash
),
965 Builder
.getInt32(NumRegionCounters
),
966 Builder
.getInt32(Counter
), StepV
};
968 Builder
.CreateCall(CGM
.getIntrinsic(llvm::Intrinsic::instrprof_increment
),
972 CGM
.getIntrinsic(llvm::Intrinsic::instrprof_increment_step
),
976 void CodeGenPGO::setValueProfilingFlag(llvm::Module
&M
) {
977 if (CGM
.getCodeGenOpts().hasProfileClangInstr())
978 M
.addModuleFlag(llvm::Module::Warning
, "EnableValueProfiling",
979 uint32_t(EnableValueProfiling
));
982 // This method either inserts a call to the profile run-time during
983 // instrumentation or puts profile data into metadata for PGO use.
984 void CodeGenPGO::valueProfile(CGBuilderTy
&Builder
, uint32_t ValueKind
,
985 llvm::Instruction
*ValueSite
, llvm::Value
*ValuePtr
) {
987 if (!EnableValueProfiling
)
990 if (!ValuePtr
|| !ValueSite
|| !Builder
.GetInsertBlock())
993 if (isa
<llvm::Constant
>(ValuePtr
))
996 bool InstrumentValueSites
= CGM
.getCodeGenOpts().hasProfileClangInstr();
997 if (InstrumentValueSites
&& RegionCounterMap
) {
998 auto BuilderInsertPoint
= Builder
.saveIP();
999 Builder
.SetInsertPoint(ValueSite
);
1000 llvm::Value
*Args
[5] = {
1001 llvm::ConstantExpr::getBitCast(FuncNameVar
, Builder
.getInt8PtrTy()),
1002 Builder
.getInt64(FunctionHash
),
1003 Builder
.CreatePtrToInt(ValuePtr
, Builder
.getInt64Ty()),
1004 Builder
.getInt32(ValueKind
),
1005 Builder
.getInt32(NumValueSites
[ValueKind
]++)
1008 CGM
.getIntrinsic(llvm::Intrinsic::instrprof_value_profile
), Args
);
1009 Builder
.restoreIP(BuilderInsertPoint
);
1013 llvm::IndexedInstrProfReader
*PGOReader
= CGM
.getPGOReader();
1014 if (PGOReader
&& haveRegionCounts()) {
1015 // We record the top most called three functions at each call site.
1016 // Profile metadata contains "VP" string identifying this metadata
1017 // as value profiling data, then a uint32_t value for the value profiling
1018 // kind, a uint64_t value for the total number of times the call is
1019 // executed, followed by the function hash and execution count (uint64_t)
1020 // pairs for each function.
1021 if (NumValueSites
[ValueKind
] >= ProfRecord
->getNumValueSites(ValueKind
))
1024 llvm::annotateValueSite(CGM
.getModule(), *ValueSite
, *ProfRecord
,
1025 (llvm::InstrProfValueKind
)ValueKind
,
1026 NumValueSites
[ValueKind
]);
1028 NumValueSites
[ValueKind
]++;
1032 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader
*PGOReader
,
1033 bool IsInMainFile
) {
1034 CGM
.getPGOStats().addVisited(IsInMainFile
);
1035 RegionCounts
.clear();
1036 llvm::Expected
<llvm::InstrProfRecord
> RecordExpected
=
1037 PGOReader
->getInstrProfRecord(FuncName
, FunctionHash
);
1038 if (auto E
= RecordExpected
.takeError()) {
1039 auto IPE
= std::get
<0>(llvm::InstrProfError::take(std::move(E
)));
1040 if (IPE
== llvm::instrprof_error::unknown_function
)
1041 CGM
.getPGOStats().addMissing(IsInMainFile
);
1042 else if (IPE
== llvm::instrprof_error::hash_mismatch
)
1043 CGM
.getPGOStats().addMismatched(IsInMainFile
);
1044 else if (IPE
== llvm::instrprof_error::malformed
)
1045 // TODO: Consider a more specific warning for this case.
1046 CGM
.getPGOStats().addMismatched(IsInMainFile
);
1050 std::make_unique
<llvm::InstrProfRecord
>(std::move(RecordExpected
.get()));
1051 RegionCounts
= ProfRecord
->Counts
;
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}
1062 /// Scale an individual branch weight (and add 1).
1064 /// Scale a 64-bit weight down to 32-bits using \c Scale.
1066 /// According to Laplace's Rule of Succession, it is better to compute the
1067 /// weight based on the count plus 1, so universally add 1 to the value.
1069 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
1070 /// greater than \c Weight.
1071 static uint32_t scaleBranchWeight(uint64_t Weight
, uint64_t Scale
) {
1072 assert(Scale
&& "scale by 0?");
1073 uint64_t Scaled
= Weight
/ Scale
+ 1;
1074 assert(Scaled
<= UINT32_MAX
&& "overflow 32-bits");
1078 llvm::MDNode
*CodeGenFunction::createProfileWeights(uint64_t TrueCount
,
1079 uint64_t FalseCount
) const {
1080 // Check for empty weights.
1081 if (!TrueCount
&& !FalseCount
)
1084 // Calculate how to scale down to 32-bits.
1085 uint64_t Scale
= calculateWeightScale(std::max(TrueCount
, FalseCount
));
1087 llvm::MDBuilder
MDHelper(CGM
.getLLVMContext());
1088 return MDHelper
.createBranchWeights(scaleBranchWeight(TrueCount
, Scale
),
1089 scaleBranchWeight(FalseCount
, Scale
));
1093 CodeGenFunction::createProfileWeights(ArrayRef
<uint64_t> Weights
) const {
1094 // We need at least two elements to create meaningful weights.
1095 if (Weights
.size() < 2)
1098 // Check for empty weights.
1099 uint64_t MaxWeight
= *std::max_element(Weights
.begin(), Weights
.end());
1103 // Calculate how to scale down to 32-bits.
1104 uint64_t Scale
= calculateWeightScale(MaxWeight
);
1106 SmallVector
<uint32_t, 16> ScaledWeights
;
1107 ScaledWeights
.reserve(Weights
.size());
1108 for (uint64_t W
: Weights
)
1109 ScaledWeights
.push_back(scaleBranchWeight(W
, Scale
));
1111 llvm::MDBuilder
MDHelper(CGM
.getLLVMContext());
1112 return MDHelper
.createBranchWeights(ScaledWeights
);
1116 CodeGenFunction::createProfileWeightsForLoop(const Stmt
*Cond
,
1117 uint64_t LoopCount
) const {
1118 if (!PGO
.haveRegionCounts())
1120 std::optional
<uint64_t> CondCount
= PGO
.getStmtCount(Cond
);
1121 if (!CondCount
|| *CondCount
== 0)
1123 return createProfileWeights(LoopCount
,
1124 std::max(*CondCount
, LoopCount
) - LoopCount
);