1 //===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the SampleContextTracker used by CSSPGO.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Transforms/IPO/SampleContextTracker.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/IR/DebugInfoMetadata.h"
17 #include "llvm/IR/InstrTypes.h"
18 #include "llvm/IR/Instruction.h"
19 #include "llvm/ProfileData/SampleProf.h"
25 using namespace sampleprof
;
27 #define DEBUG_TYPE "sample-context-tracker"
// getChildContext: returns the child of this node that is reached through
// CallSite for callee CalleeName. An empty CalleeName delegates to
// getHottestChildContext(CallSite). Otherwise the child is looked up in
// AllChildContext under FunctionSamples::getCallSiteHash(CalleeName, CallSite).
// NOTE(review): the statements after the final `if` (presumably the
// found/not-found returns) are not present in this copy of the file.
31 ContextTrieNode
*ContextTrieNode::getChildContext(const LineLocation
&CallSite
,
32 StringRef CalleeName
) {
33 if (CalleeName
.empty())
34 return getHottestChildContext(CallSite
);
36 uint64_t Hash
= FunctionSamples::getCallSiteHash(CalleeName
, CallSite
);
37 auto It
= AllChildContext
.find(Hash
);
38 if (It
!= AllChildContext
.end())
// getHottestChildContext: scans every child in AllChildContext and compares
// its CallSiteLoc against CallSite. Among the children considered, it keeps in
// ChildNodeRet the one whose FunctionSamples reports the largest
// getTotalSamples(). ChildNodeRet starts as nullptr and MaxCalleeSamples as 0.
// NOTE(review): this copy is missing the return-type line, the statement
// controlled by the CallSiteLoc check (presumably a `continue`), the lines
// between fetching Samples and dereferencing it (confirm a null check exists
// there), and the tail of the function including its return.
44 ContextTrieNode::getHottestChildContext(const LineLocation
&CallSite
) {
45 // CSFDO-TODO: This could be slow, change AllChildContext so we can
46 // do a point lookup for the child node by call site alone.
47 // Retrieve the child node with max count for an indirect call.
48 ContextTrieNode
*ChildNodeRet
= nullptr;
49 uint64_t MaxCalleeSamples
= 0;
50 for (auto &It
: AllChildContext
) {
51 ContextTrieNode
&ChildNode
= It
.second
;
52 if (ChildNode
.CallSiteLoc
!= CallSite
)
54 FunctionSamples
*Samples
= ChildNode
.getFunctionSamples();
57 if (Samples
->getTotalSamples() > MaxCalleeSamples
) {
58 ChildNodeRet
= &ChildNode
;
59 MaxCalleeSamples
= Samples
->getTotalSamples();
// moveContextSamples: re-homes NodeToMove (and its subtree) as a child of
// ToNodeParent at CallSite. It does the following:
//  - stores the node in ToNodeParent's child map under
//    FunctionSamples::getCallSiteHash(NodeToMove.getFuncName(), CallSite);
//  - resets the node's call-site location and parent link;
//  - walks the moved subtree breadth-first, re-parenting every child,
//    re-registering each node's FunctionSamples via setContextNode, and
//    marking their context state as SyntheticContext.
// NOTE(review): several lines are missing from this copy: the return-type
// line, the declaration that the getCallSiteHash expression below
// initializes (`Hash`), the queue pop, and presumably a null check on
// FSamples before it is dereferenced.
67 SampleContextTracker::moveContextSamples(ContextTrieNode
&ToNodeParent
,
68 const LineLocation
&CallSite
,
69 ContextTrieNode
&&NodeToMove
) {
71 FunctionSamples::getCallSiteHash(NodeToMove
.getFuncName(), CallSite
);
72 std::map
<uint64_t, ContextTrieNode
> &AllChildContext
=
73 ToNodeParent
.getAllChildContext();
74 assert(!AllChildContext
.count(Hash
) && "Node to remove must exist");
// NOTE(review): the assert above checks that no child with this hash exists
// yet, but its message text reads as if the opposite were required.
// NOTE(review): NodeToMove is a named rvalue reference (an lvalue), so the
// assignment below copy-assigns rather than moves. Presumably this duplicates
// its whole child map if ContextTrieNode uses the implicit copy assignment.
// `= std::move(NodeToMove)` would avoid that deep copy.
75 AllChildContext
[Hash
] = NodeToMove
;
76 ContextTrieNode
&NewNode
= AllChildContext
[Hash
];
77 NewNode
.setCallSiteLoc(CallSite
);
79 // Walk through nodes in the moved subtree, and update
80 // FunctionSamples' context to reflect the context promotion.
81 // We also need to set the new parent link for all children.
82 std::queue
<ContextTrieNode
*> NodeToUpdate
;
83 NewNode
.setParentContext(&ToNodeParent
);
84 NodeToUpdate
.push(&NewNode
);
86 while (!NodeToUpdate
.empty()) {
87 ContextTrieNode
*Node
= NodeToUpdate
.front();
89 FunctionSamples
*FSamples
= Node
->getFunctionSamples();
92 setContextNode(FSamples
, Node
);
93 FSamples
->getContext().setState(SyntheticContext
);
96 for (auto &It
: Node
->getAllChildContext()) {
97 ContextTrieNode
*ChildNode
= &It
.second
;
// Children copied along with the node must point at their new parent.
98 ChildNode
->setParentContext(Node
);
99 NodeToUpdate
.push(ChildNode
);
// removeChildContext: erases the child keyed by
// FunctionSamples::getCallSiteHash(CalleeName, CallSite) from AllChildContext.
// Erasing an absent key is a no-op for std::map::erase. Pointers or
// references to the erased child and its subtree become dangling.
106 void ContextTrieNode::removeChildContext(const LineLocation
&CallSite
,
107 StringRef CalleeName
) {
108 uint64_t Hash
= FunctionSamples::getCallSiteHash(CalleeName
, CallSite
);
109 // Note this essentially calls dtor and destroys that child context
110 AllChildContext
.erase(Hash
);
// getAllChildContext: returns a mutable reference to this node's children,
// keyed by the hash produced by FunctionSamples::getCallSiteHash.
113 std::map
<uint64_t, ContextTrieNode
> &ContextTrieNode::getAllChildContext() {
114 return AllChildContext
;
// getFuncName: returns the function name stored in this node (FuncName).
117 StringRef
ContextTrieNode::getFuncName() const { return FuncName
; }
// getFunctionSamples: returns the FunctionSamples attached to this node.
// Callers in this file treat a null result as "no profile for this context".
// NOTE(review): the function body is not present in this copy.
119 FunctionSamples
*ContextTrieNode::getFunctionSamples() const {
// setFunctionSamples: attaches FSamples to this node, overwriting any previous
// pointer. The node appears not to own the samples; the constructor passes
// pointers into the caller's profile map.
123 void ContextTrieNode::setFunctionSamples(FunctionSamples
*FSamples
) {
124 FuncSamples
= FSamples
;
// getFunctionSize: returns the recorded function size, if any.
// NOTE(review): the function body is not present in this copy.
127 std::optional
<uint32_t> ContextTrieNode::getFunctionSize() const {
// addFunctionSize: accumulates FSize into FuncSize.
// NOTE(review): the lines between the signature and the addition are missing.
// Because `*FuncSize` is dereferenced, they presumably initialize FuncSize
// when it has no value yet. Confirm.
131 void ContextTrieNode::addFunctionSize(uint32_t FSize
) {
135 FuncSize
= *FuncSize
+ FSize
;
// getCallSiteLoc: returns the call-site location under which this node hangs
// off its parent.
138 LineLocation
ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc
; }
// getParentContext: returns this node's parent link. Callers such as
// getContextString tolerate a null result.
140 ContextTrieNode
*ContextTrieNode::getParentContext() const {
141 return ParentContext
;
// setParentContext: rewires this node's parent link. It is used when nodes are
// moved within the trie.
144 void ContextTrieNode::setParentContext(ContextTrieNode
*Parent
) {
145 ParentContext
= Parent
;
// setCallSiteLoc: sets the call-site location of this node.
// NOTE(review): the function body is not present in this copy.
148 void ContextTrieNode::setCallSiteLoc(const LineLocation
&Loc
) {
// dumpNode: debug helper. Prints this node's name, call site and size to
// dbgs(), followed by the names of its direct children.
// NOTE(review): some lines between the size output and the child loop are
// missing from this copy.
152 void ContextTrieNode::dumpNode() {
153 dbgs() << "Node: " << FuncName
<< "\n"
154 << " Callsite: " << CallSiteLoc
<< "\n"
155 << " Size: " << FuncSize
<< "\n"
158 for (auto &It
: AllChildContext
) {
159 dbgs() << " Node: " << It
.second
.getFuncName() << "\n";
// dumpTree: debug helper. Walks the subtree rooted at this node breadth-first
// with a queue, enqueueing every child.
// NOTE(review): the dequeue and per-node print are not present in this copy
// (presumably pop() followed by dumpNode()).
163 void ContextTrieNode::dumpTree() {
164 dbgs() << "Context Profile Tree:\n";
165 std::queue
<ContextTrieNode
*> NodeQueue
;
166 NodeQueue
.push(this);
168 while (!NodeQueue
.empty()) {
169 ContextTrieNode
*Node
= NodeQueue
.front();
173 for (auto &It
: Node
->getAllChildContext()) {
174 ContextTrieNode
*ChildNode
= &It
.second
;
175 NodeQueue
.push(ChildNode
);
// getOrCreateChildContext: returns the child for (CallSite, CalleeName), keyed
// by FunctionSamples::getCallSiteHash.
//  - An existing child is asserted to carry CalleeName, which guards against
//    hash collisions.
//  - Otherwise a new node is inserted, with this node as parent, no samples,
//    and CallSite as its location, and its address is returned.
// Returning the address of a std::map element is safe across later insertions
// because std::map does not relocate its elements.
// NOTE(review): the lines between the collision assert and the insertion are
// missing; presumably they return the existing node and honor AllowCreate.
// The insertion does two operator[] lookups, where a single
// try_emplace/emplace would suffice.
180 ContextTrieNode
*ContextTrieNode::getOrCreateChildContext(
181 const LineLocation
&CallSite
, StringRef CalleeName
, bool AllowCreate
) {
182 uint64_t Hash
= FunctionSamples::getCallSiteHash(CalleeName
, CallSite
);
183 auto It
= AllChildContext
.find(Hash
);
184 if (It
!= AllChildContext
.end()) {
185 assert(It
->second
.getFuncName() == CalleeName
&&
186 "Hash collision for child context node");
193 AllChildContext
[Hash
] = ContextTrieNode(this, CalleeName
, nullptr, CallSite
);
194 return &AllChildContext
[Hash
];
197 // Profile tracker that manages profiles and their associated contexts.
// Builds the context trie in three steps:
//  1. For every entry in Profiles, creates a trie path from its SampleContext
//     (getOrCreateContextPath with AllowCreate = true).
//  2. Asserts the resulting node has no samples yet, then attaches the entry's
//     FunctionSamples to it.
//  3. Calls populateFuncToCtxtMap() to index the nodes by function name.
// The trie stores pointers into Profiles (&FuncSample.second), so presumably
// Profiles must outlive this tracker. GUIDToFuncNameMap is kept for MD5 name
// lookups (used by getFuncNameFor).
198 SampleContextTracker::SampleContextTracker(
199 SampleProfileMap
&Profiles
,
200 const DenseMap
<uint64_t, StringRef
> *GUIDToFuncNameMap
)
201 : GUIDToFuncNameMap(GUIDToFuncNameMap
) {
202 for (auto &FuncSample
: Profiles
) {
203 FunctionSamples
*FSamples
= &FuncSample
.second
;
204 SampleContext Context
= FuncSample
.second
.getContext();
205 LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context
.toString()
207 ContextTrieNode
*NewNode
= getOrCreateContextPath(Context
, true);
208 assert(!NewNode
->getFunctionSamples() &&
209 "New node can't have sample profile");
210 NewNode
->setFunctionSamples(FSamples
);
212 populateFuncToCtxtMap();
// populateFuncToCtxtMap: visits every trie node (via this tracker's iteration)
// and, presumably only for nodes that have samples:
//  - resets the context state to RawContext;
//  - records the node for those samples via setContextNode;
//  - appends the samples to FuncToCtxtProfiles under the node's function name.
// NOTE(review): the line after fetching FSamples (presumably a null check) is
// absent from this copy.
215 void SampleContextTracker::populateFuncToCtxtMap() {
216 for (auto *Node
: *this) {
217 FunctionSamples
*FSamples
= Node
->getFunctionSamples();
219 FSamples
->getContext().setState(RawContext
);
220 setContextNode(FSamples
, Node
);
221 FuncToCtxtProfiles
[Node
->getFuncName()].push_back(FSamples
);
// getCalleeContextSamplesFor: returns the callee's context profile for the
// call Inst.
//  - CalleeName is canonicalized, then converted to the profile's name
//    representation (MD5 when FunctionSamples::UseMD5).
//  - It is then looked up under the caller context derived from Inst's debug
//    location. An empty CalleeName selects the hottest child at that call site.
// NOTE(review): missing from this copy: the return-type line, the declaration
// of FGUID, presumably null checks on DIL and CalleeContext, and the tail of
// the function.
227 SampleContextTracker::getCalleeContextSamplesFor(const CallBase
&Inst
,
228 StringRef CalleeName
) {
229 LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst
<< "\n");
230 DILocation
*DIL
= Inst
.getDebugLoc();
234 CalleeName
= FunctionSamples::getCanonicalFnName(CalleeName
);
235 // Convert real function names to MD5 names, if the input profile is
238 CalleeName
= getRepInFormat(CalleeName
, FunctionSamples::UseMD5
, FGUID
);
240 // For indirect call, CalleeName will be empty, in which case the context
241 // profile for callee with largest total samples will be returned.
242 ContextTrieNode
*CalleeContext
= getCalleeContextFor(DIL
, CalleeName
);
244 FunctionSamples
*FSamples
= CalleeContext
->getFunctionSamples();
245 LLVM_DEBUG(if (FSamples
) {
246 dbgs() << " Callee context found: " << getContextString(CalleeContext
)
// getIndirectCalleeContextSamplesFor: collects into R the FunctionSamples of
// every child of the caller context (derived from DIL) that sits at DIL's call
// site and has samples.
// NOTE(review): missing from this copy are the lines that presumably handle a
// null CallerNode, `continue` on mismatched call sites, and return R.
255 std::vector
<const FunctionSamples
*>
256 SampleContextTracker::getIndirectCalleeContextSamplesFor(
257 const DILocation
*DIL
) {
258 std::vector
<const FunctionSamples
*> R
;
262 ContextTrieNode
*CallerNode
= getContextFor(DIL
);
263 LineLocation CallSite
= FunctionSamples::getCallSiteIdentifier(DIL
);
264 for (auto &It
: CallerNode
->getAllChildContext()) {
265 ContextTrieNode
&ChildNode
= It
.second
;
266 if (ChildNode
.getCallSiteLoc() != CallSite
)
268 if (FunctionSamples
*CalleeSamples
= ChildNode
.getFunctionSamples())
269 R
.push_back(CalleeSamples
);
// getContextSamplesFor(DIL): looks up the trie node for the context described
// by DIL's inline stack. As a side effect, if that node has samples and is not
// a direct child of the root, its context is marked InlinedContext.
// NOTE(review): the return-type line, presumably a null check on ContextNode,
// and the return statement are missing from this copy.
276 SampleContextTracker::getContextSamplesFor(const DILocation
*DIL
) {
277 assert(DIL
&& "Expect non-null location");
279 ContextTrieNode
*ContextNode
= getContextFor(DIL
);
283 // We may have inlined callees during pre-LTO compilation, in which case
284 // we need to rely on the inline stack from !dbg to mark context profile
285 // as inlined, instead of `markContextSamplesInlined` during inlining.
286 // Sample profile loader walks through all instructions to get profile,
287 // which calls this function. So once that is done, all previously inlined
288 // context profile should be marked properly.
289 FunctionSamples
*Samples
= ContextNode
->getFunctionSamples();
290 if (Samples
&& ContextNode
->getParentContext() != &RootContext
)
291 Samples
->getContext().setState(InlinedContext
);
// getContextSamplesFor(Context): returns the FunctionSamples of the existing
// trie node for Context. This is a lookup only; no nodes are created
// (getContextFor).
// NOTE(review): the return-type line and the lines between the lookup and the
// return (presumably a null check on Node) are missing from this copy.
297 SampleContextTracker::getContextSamplesFor(const SampleContext
&Context
) {
298 ContextTrieNode
*Node
= getContextFor(Context
);
302 return Node
->getFunctionSamples();
// getAllContextSamplesFor(Func): returns the list of context profiles recorded
// for Func's canonical name.
// NOTE(review): this uses FuncToCtxtProfiles[...]. For map types that operator
// inserts an empty entry when the name is absent; confirm that insertion is
// intended.
305 SampleContextTracker::ContextSamplesTy
&
306 SampleContextTracker::getAllContextSamplesFor(const Function
&Func
) {
307 StringRef CanonName
= FunctionSamples::getCanonicalFnName(Func
);
308 return FuncToCtxtProfiles
[CanonName
];
// getAllContextSamplesFor(Name): returns the list of context profiles recorded
// under Name, used as-is with no canonicalization.
// NOTE(review): like the Function overload, operator[] may insert an empty
// entry for an unknown name.
311 SampleContextTracker::ContextSamplesTy
&
312 SampleContextTracker::getAllContextSamplesFor(StringRef Name
) {
313 return FuncToCtxtProfiles
[Name
];
// getBaseSamplesFor(Func, MergeContext): convenience overload. It canonicalizes
// Func's name and forwards to the StringRef overload with MergeContext.
// NOTE(review): the line declaring the MergeContext parameter is missing from
// this copy.
316 FunctionSamples
*SampleContextTracker::getBaseSamplesFor(const Function
&Func
,
318 StringRef CanonName
= FunctionSamples::getCanonicalFnName(Func
);
319 return getBaseSamplesFor(CanonName
, MergeContext
);
// getBaseSamplesFor(Name, MergeContext): returns the base (top-level,
// context-less) profile for Name.
//  1. Name is converted to the profile's name representation.
//  2. The top-level node for Name is fetched.
//  3. In the merge path, every context profile of Name that is neither
//     inlined nor already merged is promoted/merged to the top level.
//     Only one resulting base node is expected.
//  4. The samples of the top-level node are returned.
// NOTE(review): several lines are missing from this copy: the MergeContext
// parameter and its guard, the FGUID declaration, the `continue` statements,
// the update of Node, and the null check before the final return.
322 FunctionSamples
*SampleContextTracker::getBaseSamplesFor(StringRef Name
,
324 LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name
<< "\n");
325 // Convert real function names to MD5 names, if the input profile is
328 Name
= getRepInFormat(Name
, FunctionSamples::UseMD5
, FGUID
);
330 // Base profile is top-level node (child of root node), so try to retrieve
331 // existing top-level node for given function first. If it exists, it could be
332 // that we've merged base profile before, or there's actually context-less
333 // profile from the input (e.g. due to unreliable stack walking).
334 ContextTrieNode
*Node
= getTopLevelContextNode(Name
);
336 LLVM_DEBUG(dbgs() << " Merging context profile into base profile: " << Name
339 // We have profile for function under different contexts,
340 // create synthetic base profile and merge context profiles
341 // into base profile.
342 for (auto *CSamples
: FuncToCtxtProfiles
[Name
]) {
343 SampleContext
&Context
= CSamples
->getContext();
344 // Skip inlined context profile and also don't re-merge any context
345 if (Context
.hasState(InlinedContext
) || Context
.hasState(MergedContext
))
348 ContextTrieNode
*FromNode
= getContextNodeForProfile(CSamples
);
349 if (FromNode
== Node
)
352 ContextTrieNode
&ToNode
= promoteMergeContextSamplesTree(*FromNode
);
353 assert((!Node
|| Node
== &ToNode
) && "Expect only one base profile");
358 // Still no profile even after merge/promotion (if allowed)
362 return Node
->getFunctionSamples();
// markContextSamplesInlined: marks InlinedSamples' context state as
// InlinedContext. InlinedSamples must be non-null (asserted).
365 void SampleContextTracker::markContextSamplesInlined(
366 const FunctionSamples
*InlinedSamples
) {
367 assert(InlinedSamples
&& "Expect non-null inlined samples");
368 LLVM_DEBUG(dbgs() << "Marking context profile as inlined: "
369 << getContextString(*InlinedSamples
) << "\n");
370 InlinedSamples
->getContext().setState(InlinedContext
);
// getRootContext: returns the trie's root node.
373 ContextTrieNode
&SampleContextTracker::getRootContext() { return RootContext
; }
// promoteMergeContextSamplesTree(Inst, CalleeName): promotes callee context(s)
// at Inst's call site to the top level, merging into existing base profiles.
//  - Empty CalleeName (indirect call): every child at the call site whose
//    samples are not marked inlined is promoted.
//  - Otherwise: only the child for CalleeName is promoted.
// NOTE(review): missing from this copy are presumably the null checks on
// CallerNode and NodeToPromo, the `continue` statements, and the early
// `return` after the loop.
375 void SampleContextTracker::promoteMergeContextSamplesTree(
376 const Instruction
&Inst
, StringRef CalleeName
) {
377 LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n"
379 // Get the caller context for the call instruction, we don't use callee
380 // name from call because there can be context from indirect calls too.
381 DILocation
*DIL
= Inst
.getDebugLoc();
382 ContextTrieNode
*CallerNode
= getContextFor(DIL
);
386 // Get the context that needs to be promoted
387 LineLocation CallSite
= FunctionSamples::getCallSiteIdentifier(DIL
);
388 // For indirect call, CalleeName will be empty, in which case we need to
389 // promote all non-inlined child context profiles.
390 if (CalleeName
.empty()) {
// NOTE(review): promoting a node to the root appears to end by removing it
// from its old parent, which here is CallerNode, i.e. the map this loop
// iterates. Confirm this cannot invalidate the range-for iterator.
391 for (auto &It
: CallerNode
->getAllChildContext()) {
392 ContextTrieNode
*NodeToPromo
= &It
.second
;
393 if (CallSite
!= NodeToPromo
->getCallSiteLoc())
395 FunctionSamples
*FromSamples
= NodeToPromo
->getFunctionSamples();
396 if (FromSamples
&& FromSamples
->getContext().hasState(InlinedContext
))
398 promoteMergeContextSamplesTree(*NodeToPromo
);
403 // Get the context for the given callee that needs to be promoted
404 ContextTrieNode
*NodeToPromo
=
405 CallerNode
->getChildContext(CallSite
, CalleeName
);
409 promoteMergeContextSamplesTree(*NodeToPromo
);
// promoteMergeContextSamplesTree(NodeToPromo): promotes NodeToPromo's subtree
// to sit directly under the root. It delegates to the two-argument overload
// with RootContext as the destination parent. NodeToPromo must have samples
// that are not marked inlined (both asserted).
412 ContextTrieNode
&SampleContextTracker::promoteMergeContextSamplesTree(
413 ContextTrieNode
&NodeToPromo
) {
414 // Promote the input node to be directly under root. This can happen
415 // when we decided to not inline a function under context represented
416 // by the input node. The promote and merge is then needed to reflect
417 // the context profile in the base (context-less) profile.
418 FunctionSamples
*FromSamples
= NodeToPromo
.getFunctionSamples();
419 assert(FromSamples
&& "Shouldn't promote a context without profile");
420 (void)FromSamples
; // Unused in release build.
422 LLVM_DEBUG(dbgs() << " Found context tree root to promote: "
423 << getContextString(&NodeToPromo
) << "\n");
425 assert(!FromSamples
->getContext().hasState(InlinedContext
) &&
426 "Shouldn't promote inlined context profile");
427 return promoteMergeContextSamplesTree(NodeToPromo
, RootContext
);
// getContextString(FSamples): returns the context string for FSamples by first
// locating its trie node via getContextNodeForProfile.
// NOTE(review): the return-type line is missing from this copy.
432 SampleContextTracker::getContextString(const FunctionSamples
&FSamples
) const {
433 return getContextString(getContextNodeForProfile(&FSamples
));
// getContextString(Node): builds the textual context of Node.
//  - The root yields an empty string.
//  - Otherwise, Node's own frame is recorded with LineLocation(0, 0).
//  - Each ancestor up to (but excluding) the root is recorded with the call
//    site of the node below it.
//  - The frame list is reversed so the outermost caller comes first, then
//    formatted via SampleContext::getContextString.
// NOTE(review): the line that advances PreNode inside the loop is missing from
// this copy (presumably `PreNode = Node;`).
437 SampleContextTracker::getContextString(ContextTrieNode
*Node
) const {
438 SampleContextFrameVector Res
;
439 if (Node
== &RootContext
)
440 return std::string();
441 Res
.emplace_back(Node
->getFuncName(), LineLocation(0, 0));
443 ContextTrieNode
*PreNode
= Node
;
444 Node
= Node
->getParentContext();
445 while (Node
&& Node
!= &RootContext
) {
446 Res
.emplace_back(Node
->getFuncName(), PreNode
->getCallSiteLoc());
448 Node
= Node
->getParentContext();
451 std::reverse(Res
.begin(), Res
.end());
453 return SampleContext::getContextString(Res
);
// dump: debug helper that dumps the whole context trie starting at the root.
457 void SampleContextTracker::dump() { RootContext
.dumpTree(); }
// getFuncNameFor: returns a readable function name for Node.
//  - Without MD5 names, Node's stored name is returned as-is.
//  - With MD5 names, the stored name is parsed as a decimal GUID and looked up
//    in GUIDToFuncNameMap. DenseMap::lookup returns an empty StringRef when the
//    GUID is absent.
// NOTE(review): StringRef::data() is not guaranteed to be NUL-terminated, and
// std::stoull throws on non-numeric input. Parsing through StringRef
// (e.g. getAsInteger) would avoid both issues.
459 StringRef
SampleContextTracker::getFuncNameFor(ContextTrieNode
*Node
) const {
460 if (!FunctionSamples::UseMD5
)
461 return Node
->getFuncName();
462 assert(GUIDToFuncNameMap
&& "GUIDToFuncNameMap needs to be populated first");
463 return GUIDToFuncNameMap
->lookup(std::stoull(Node
->getFuncName().data()));
// getContextFor(Context): lookup-only path query (getOrCreateContextPath with
// AllowCreate = false). The result may be null, because the path builder only
// asserts non-null when creation is allowed.
// NOTE(review): the return-type line is missing from this copy.
467 SampleContextTracker::getContextFor(const SampleContext
&Context
) {
468 return getOrCreateContextPath(Context
, false);
// getCalleeContextFor: returns the child of DIL's context that is reached at
// DIL's call site for CalleeName. An empty CalleeName selects the hottest
// child at that call site.
// NOTE(review): the return-type line and the lines between the context lookup
// and the return (presumably a null check on CallContext) are missing.
472 SampleContextTracker::getCalleeContextFor(const DILocation
*DIL
,
473 StringRef CalleeName
) {
474 assert(DIL
&& "Expect non-null location");
476 ContextTrieNode
*CallContext
= getContextFor(DIL
);
480 // When CalleeName is empty, the child context profile with max
481 // total samples will be returned.
482 return CallContext
->getChildContext(
483 FunctionSamples::getCallSiteIdentifier(DIL
), CalleeName
);
// getContextFor(DIL): maps a DILocation's inline stack to a trie node.
//  1. Frames are collected innermost-first into S. Each frame pairs the call
//     site in the inlining caller with the inlined callee's name; linkage
//     names are preferred over plain names.
//  2. The outermost function is pushed last, with LineLocation(0, 0).
//  3. When profiles use MD5 names, the names are converted. The converted
//     strings live in a std::list so StringRefs taken earlier stay valid as
//     more strings are appended.
//  4. The trie is descended from the root, from S's back (outermost) to its
//     front, stopping early once a child is missing.
// NOTE(review): missing from this copy are presumably the `if (Name.empty())`
// guard, the `S.push_back(` call, the PrevDIL update, the declaration of I,
// and the return.
486 ContextTrieNode
*SampleContextTracker::getContextFor(const DILocation
*DIL
) {
487 assert(DIL
&& "Expect non-null location");
488 SmallVector
<std::pair
<LineLocation
, StringRef
>, 10> S
;
490 // Use C++ linkage name if possible.
491 const DILocation
*PrevDIL
= DIL
;
492 for (DIL
= DIL
->getInlinedAt(); DIL
; DIL
= DIL
->getInlinedAt()) {
493 StringRef Name
= PrevDIL
->getScope()->getSubprogram()->getLinkageName();
495 Name
= PrevDIL
->getScope()->getSubprogram()->getName();
497 std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL
), Name
));
501 // Push the root node; note that a root node like main may only have
502 // a name, but not a linkage name.
503 StringRef RootName
= PrevDIL
->getScope()->getSubprogram()->getLinkageName();
504 if (RootName
.empty())
505 RootName
= PrevDIL
->getScope()->getSubprogram()->getName();
506 S
.push_back(std::make_pair(LineLocation(0, 0), RootName
));
508 // Convert real function names to MD5 names, if the input profile is
510 std::list
<std::string
> MD5Names
;
511 if (FunctionSamples::UseMD5
) {
512 for (auto &Location
: S
) {
513 MD5Names
.emplace_back();
514 getRepInFormat(Location
.second
, FunctionSamples::UseMD5
, MD5Names
.back());
515 Location
.second
= MD5Names
.back();
519 ContextTrieNode
*ContextNode
= &RootContext
;
521 while (--I
>= 0 && ContextNode
) {
522 LineLocation
&CallSite
= S
[I
].first
;
523 StringRef CalleeName
= S
[I
].second
;
524 ContextNode
= ContextNode
->getChildContext(CallSite
, CalleeName
);
// getOrCreateContextPath: walks the trie path for Context's frames starting at
// the root, creating missing nodes when AllowCreate is set. Each frame's
// FuncName is looked up at the call-site location recorded by the previous
// frame; the first frame uses LineLocation(0, 0). When creation is allowed,
// the resulting node is asserted non-null.
// NOTE(review): missing from this copy are the return-type line, the
// AllowCreate parameter line, the if/else choosing create vs. lookup, any
// early exit on a missing node, and the return.
534 SampleContextTracker::getOrCreateContextPath(const SampleContext
&Context
,
536 ContextTrieNode
*ContextNode
= &RootContext
;
537 LineLocation
CallSiteLoc(0, 0);
539 for (const auto &Callsite
: Context
.getContextFrames()) {
540 // Create child node at parent line/disc location
543 ContextNode
->getOrCreateChildContext(CallSiteLoc
, Callsite
.FuncName
);
546 ContextNode
->getChildContext(CallSiteLoc
, Callsite
.FuncName
);
548 CallSiteLoc
= Callsite
.Location
;
551 assert((!AllowCreate
|| ContextNode
) &&
552 "Node must exist if creation is allowed");
// getTopLevelContextNode: returns the root's child for FName at
// LineLocation(0, 0), i.e. the base-profile node, as resolved by
// getChildContext. FName must be non-empty (asserted).
556 ContextTrieNode
*SampleContextTracker::getTopLevelContextNode(StringRef FName
) {
557 assert(!FName
.empty() && "Top level node query must provide valid name");
558 return RootContext
.getChildContext(LineLocation(0, 0), FName
);
// addTopLevelContextNode: creates the top-level node for FName under the root
// at LineLocation(0, 0) and returns it. The node must not exist yet (asserted).
561 ContextTrieNode
&SampleContextTracker::addTopLevelContextNode(StringRef FName
) {
562 assert(!getTopLevelContextNode(FName
) && "Node to add must not exist");
563 return *RootContext
.getOrCreateChildContext(LineLocation(0, 0), FName
);
// mergeContextNode: folds FromNode's samples into ToNode.
//  - Both have samples: FromSamples is merged into ToSamples. ToSamples is
//    marked SyntheticContext and FromSamples MergedContext, and the
//    ContextShouldBeInlined attribute is propagated to ToSamples.
//  - Only FromNode has samples: the samples pointer is transferred to ToNode,
//    re-registered via setContextNode, and marked SyntheticContext.
// NOTE(review): the rest of the transfer branch is missing from this copy
// (presumably it clears FromNode's samples pointer).
566 void SampleContextTracker::mergeContextNode(ContextTrieNode
&FromNode
,
567 ContextTrieNode
&ToNode
) {
568 FunctionSamples
*FromSamples
= FromNode
.getFunctionSamples();
569 FunctionSamples
*ToSamples
= ToNode
.getFunctionSamples();
570 if (FromSamples
&& ToSamples
) {
571 // Merge/duplicate FromSamples into ToSamples
572 ToSamples
->merge(*FromSamples
);
573 ToSamples
->getContext().setState(SyntheticContext
);
574 FromSamples
->getContext().setState(MergedContext
);
575 if (FromSamples
->getContext().hasAttribute(ContextShouldBeInlined
))
576 ToSamples
->getContext().setAttribute(ContextShouldBeInlined
);
577 } else if (FromSamples
) {
578 // Transfer FromSamples from FromNode to ToNode
579 ToNode
.setFunctionSamples(FromSamples
);
580 setContextNode(FromSamples
, &ToNode
);
581 FromSamples
->getContext().setState(SyntheticContext
);
// promoteMergeContextSamplesTree(FromNode, ToNodeParent): recursive worker
// that moves or merges FromNode's subtree under ToNodeParent and returns the
// destination node.
//  - The destination child is looked up by FromNode's name. Its call-site
//    location is LineLocation(0, 0) when moving to the root; otherwise it is
//    presumably FromNode's original call site.
//  - If the destination does not exist, the subtree is moved there with
//    moveContextSamples.
//  - If it exists, samples are merged with mergeContextNode, each child is
//    recursively promoted under the destination, and FromNode's children are
//    then cleared.
//  - Finally the subtree root is removed from its old parent, presumably only
//    when moving to the root.
// NOTE(review): missing from this copy are the return-type line, the
// `if (!MoveToRoot)` / `if (!ToNode)` / `else` structure, the LLVM_DEBUG
// wrappers, and the final return. The description above presumes them.
585 ContextTrieNode
&SampleContextTracker::promoteMergeContextSamplesTree(
586 ContextTrieNode
&FromNode
, ContextTrieNode
&ToNodeParent
) {
588 // Ignore call site location if destination is top level under root
589 LineLocation NewCallSiteLoc
= LineLocation(0, 0);
590 LineLocation OldCallSiteLoc
= FromNode
.getCallSiteLoc();
591 ContextTrieNode
&FromNodeParent
= *FromNode
.getParentContext();
592 ContextTrieNode
*ToNode
= nullptr;
593 bool MoveToRoot
= (&ToNodeParent
== &RootContext
);
595 NewCallSiteLoc
= OldCallSiteLoc
;
598 // Locate destination node, create/move if not existing
599 ToNode
= ToNodeParent
.getChildContext(NewCallSiteLoc
, FromNode
.getFuncName());
601 // Do not delete node to move from its parent here because
602 // caller is iterating over children of that parent node.
604 &moveContextSamples(ToNodeParent
, NewCallSiteLoc
, std::move(FromNode
));
606 dbgs() << " Context promoted and merged to: " << getContextString(ToNode
)
610 // Destination node exists, merge samples for the context tree
611 mergeContextNode(FromNode
, *ToNode
);
613 if (ToNode
->getFunctionSamples())
614 dbgs() << " Context promoted and merged to: "
615 << getContextString(ToNode
) << "\n";
618 // Recursively promote and merge children
619 for (auto &It
: FromNode
.getAllChildContext()) {
620 ContextTrieNode
&FromChildNode
= It
.second
;
621 promoteMergeContextSamplesTree(FromChildNode
, *ToNode
);
624 // Remove children once they're all merged
625 FromNode
.getAllChildContext().clear();
628 // For root of subtree, remove itself from old parent too
// The name is read from ToNode, which was located or created under
// FromNode's name.
630 FromNodeParent
.removeChildContext(OldCallSiteLoc
, ToNode
->getFuncName());
635 void SampleContextTracker::createContextLessProfileMap(
636 SampleProfileMap
&ContextLessProfiles
) {
637 for (auto *Node
: *this) {
638 FunctionSamples
*FProfile
= Node
->getFunctionSamples();
639 // Profile's context can be empty, use ContextNode's func name.
641 ContextLessProfiles
.Create(Node
->getFuncName()).merge(*FProfile
);