//===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Union-find algorithm to compute Minimum Spanning Tree
// for a given CFG.
//
//===----------------------------------------------------------------------===//
14 #ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
15 #define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
29 #define DEBUG_TYPE "cfgmst"
35 /// An union-find based Minimum Spanning Tree for CFG
37 /// Implements a Union-find algorithm to compute Minimum Spanning Tree
39 template <class Edge
, class BBInfo
> class CFGMST
{
43 // Store all the edges in CFG. It may contain some stale edges
44 // when Removed is set.
45 std::vector
<std::unique_ptr
<Edge
>> AllEdges
;
47 // This map records the auxiliary information for each BB.
48 DenseMap
<const BasicBlock
*, std::unique_ptr
<BBInfo
>> BBInfos
;
50 // Whehter the function has an exit block with no successors.
51 // (For function with an infinite loop, this block may be absent)
52 bool ExitBlockFound
= false;
54 // Find the root group of the G and compress the path from G to the root.
55 BBInfo
*findAndCompressGroup(BBInfo
*G
) {
57 G
->Group
= findAndCompressGroup(static_cast<BBInfo
*>(G
->Group
));
58 return static_cast<BBInfo
*>(G
->Group
);
61 // Union BB1 and BB2 into the same group and return true.
62 // Returns false if BB1 and BB2 are already in the same group.
63 bool unionGroups(const BasicBlock
*BB1
, const BasicBlock
*BB2
) {
64 BBInfo
*BB1G
= findAndCompressGroup(&getBBInfo(BB1
));
65 BBInfo
*BB2G
= findAndCompressGroup(&getBBInfo(BB2
));
70 // Make the smaller rank tree a direct child or the root of high rank tree.
71 if (BB1G
->Rank
< BB2G
->Rank
)
75 // If the ranks are the same, increment root of one tree by one.
76 if (BB1G
->Rank
== BB2G
->Rank
)
82 // Give BB, return the auxiliary information.
83 BBInfo
&getBBInfo(const BasicBlock
*BB
) const {
84 auto It
= BBInfos
.find(BB
);
85 assert(It
->second
.get() != nullptr);
86 return *It
->second
.get();
89 // Give BB, return the auxiliary information if it's available.
90 BBInfo
*findBBInfo(const BasicBlock
*BB
) const {
91 auto It
= BBInfos
.find(BB
);
92 if (It
== BBInfos
.end())
94 return It
->second
.get();
97 // Traverse the CFG using a stack. Find all the edges and assign the weight.
98 // Edges with large weight will be put into MST first so they are less likely
99 // to be instrumented.
101 LLVM_DEBUG(dbgs() << "Build Edge on " << F
.getName() << "\n");
103 const BasicBlock
*Entry
= &(F
.getEntryBlock());
104 uint64_t EntryWeight
= (BFI
!= nullptr ? BFI
->getEntryFreq() : 2);
105 // If we want to instrument the entry count, lower the weight to 0.
106 if (InstrumentFuncEntry
)
108 Edge
*EntryIncoming
= nullptr, *EntryOutgoing
= nullptr,
109 *ExitOutgoing
= nullptr, *ExitIncoming
= nullptr;
110 uint64_t MaxEntryOutWeight
= 0, MaxExitOutWeight
= 0, MaxExitInWeight
= 0;
112 // Add a fake edge to the entry.
113 EntryIncoming
= &addEdge(nullptr, Entry
, EntryWeight
);
114 LLVM_DEBUG(dbgs() << " Edge: from fake node to " << Entry
->getName()
115 << " w = " << EntryWeight
<< "\n");
117 // Special handling for single BB functions.
118 if (succ_empty(Entry
)) {
119 addEdge(Entry
, nullptr, EntryWeight
);
123 static const uint32_t CriticalEdgeMultiplier
= 1000;
125 for (BasicBlock
&BB
: F
) {
126 Instruction
*TI
= BB
.getTerminator();
128 (BFI
!= nullptr ? BFI
->getBlockFreq(&BB
).getFrequency() : 2);
130 if (int successors
= TI
->getNumSuccessors()) {
131 for (int i
= 0; i
!= successors
; ++i
) {
132 BasicBlock
*TargetBB
= TI
->getSuccessor(i
);
133 bool Critical
= isCriticalEdge(TI
, i
);
134 uint64_t scaleFactor
= BBWeight
;
136 if (scaleFactor
< UINT64_MAX
/ CriticalEdgeMultiplier
)
137 scaleFactor
*= CriticalEdgeMultiplier
;
139 scaleFactor
= UINT64_MAX
;
142 Weight
= BPI
->getEdgeProbability(&BB
, TargetBB
).scale(scaleFactor
);
145 auto *E
= &addEdge(&BB
, TargetBB
, Weight
);
146 E
->IsCritical
= Critical
;
147 LLVM_DEBUG(dbgs() << " Edge: from " << BB
.getName() << " to "
148 << TargetBB
->getName() << " w=" << Weight
<< "\n");
150 // Keep track of entry/exit edges:
152 if (Weight
> MaxEntryOutWeight
) {
153 MaxEntryOutWeight
= Weight
;
158 auto *TargetTI
= TargetBB
->getTerminator();
159 if (TargetTI
&& !TargetTI
->getNumSuccessors()) {
160 if (Weight
> MaxExitInWeight
) {
161 MaxExitInWeight
= Weight
;
167 ExitBlockFound
= true;
168 Edge
*ExitO
= &addEdge(&BB
, nullptr, BBWeight
);
169 if (BBWeight
> MaxExitOutWeight
) {
170 MaxExitOutWeight
= BBWeight
;
171 ExitOutgoing
= ExitO
;
173 LLVM_DEBUG(dbgs() << " Edge: from " << BB
.getName() << " to fake exit"
174 << " w = " << BBWeight
<< "\n");
178 // Entry/exit edge adjustment heurisitic:
179 // prefer instrumenting entry edge over exit edge
180 // if possible. Those exit edges may never have a chance to be
181 // executed (for instance the program is an event handling loop)
182 // before the profile is asynchronously dumped.
184 // If EntryIncoming and ExitOutgoing has similar weight, make sure
185 // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing
186 // and ExitIncoming has similar weight, make sure ExitIncoming becomes
188 uint64_t EntryInWeight
= EntryWeight
;
190 if (EntryInWeight
>= MaxExitOutWeight
&&
191 EntryInWeight
* 2 < MaxExitOutWeight
* 3) {
192 EntryIncoming
->Weight
= MaxExitOutWeight
;
193 ExitOutgoing
->Weight
= EntryInWeight
+ 1;
196 if (MaxEntryOutWeight
>= MaxExitInWeight
&&
197 MaxEntryOutWeight
* 2 < MaxExitInWeight
* 3) {
198 EntryOutgoing
->Weight
= MaxExitInWeight
;
199 ExitIncoming
->Weight
= MaxEntryOutWeight
+ 1;
203 // Sort CFG edges based on its weight.
204 void sortEdgesByWeight() {
205 llvm::stable_sort(AllEdges
, [](const std::unique_ptr
<Edge
> &Edge1
,
206 const std::unique_ptr
<Edge
> &Edge2
) {
207 return Edge1
->Weight
> Edge2
->Weight
;
211 // Traverse all the edges and compute the Minimum Weight Spanning Tree
212 // using union-find algorithm.
213 void computeMinimumSpanningTree() {
214 // First, put all the critical edge with landing-pad as the Dest to MST.
215 // This works around the insufficient support of critical edges split
216 // when destination BB is a landing pad.
217 for (auto &Ei
: AllEdges
) {
220 if (Ei
->IsCritical
) {
221 if (Ei
->DestBB
&& Ei
->DestBB
->isLandingPad()) {
222 if (unionGroups(Ei
->SrcBB
, Ei
->DestBB
))
228 for (auto &Ei
: AllEdges
) {
231 // If we detect infinite loops, force
232 // instrumenting the entry edge:
233 if (!ExitBlockFound
&& Ei
->SrcBB
== nullptr)
235 if (unionGroups(Ei
->SrcBB
, Ei
->DestBB
))
240 // Dump the Debug information about the instrumentation.
241 void dumpEdges(raw_ostream
&OS
, const Twine
&Message
) const {
242 if (!Message
.str().empty())
243 OS
<< Message
<< "\n";
244 OS
<< " Number of Basic Blocks: " << BBInfos
.size() << "\n";
245 for (auto &BI
: BBInfos
) {
246 const BasicBlock
*BB
= BI
.first
;
247 OS
<< " BB: " << (BB
== nullptr ? "FakeNode" : BB
->getName()) << " "
248 << BI
.second
->infoString() << "\n";
251 OS
<< " Number of Edges: " << AllEdges
.size()
252 << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
254 for (auto &EI
: AllEdges
)
255 OS
<< " Edge " << Count
++ << ": " << getBBInfo(EI
->SrcBB
).Index
<< "-->"
256 << getBBInfo(EI
->DestBB
).Index
<< EI
->infoString() << "\n";
259 // Add an edge to AllEdges with weight W.
260 Edge
&addEdge(const BasicBlock
*Src
, const BasicBlock
*Dest
, uint64_t W
) {
261 uint32_t Index
= BBInfos
.size();
262 auto Iter
= BBInfos
.end();
264 std::tie(Iter
, Inserted
) = BBInfos
.insert(std::make_pair(Src
, nullptr));
266 // Newly inserted, update the real info.
267 Iter
->second
= std::move(std::make_unique
<BBInfo
>(Index
));
270 std::tie(Iter
, Inserted
) = BBInfos
.insert(std::make_pair(Dest
, nullptr));
272 // Newly inserted, update the real info.
273 Iter
->second
= std::move(std::make_unique
<BBInfo
>(Index
));
274 AllEdges
.emplace_back(new Edge(Src
, Dest
, W
));
275 return *AllEdges
.back();
278 BranchProbabilityInfo
*BPI
;
279 BlockFrequencyInfo
*BFI
;
281 // If function entry will be always instrumented.
282 bool InstrumentFuncEntry
;
285 CFGMST(Function
&Func
, bool InstrumentFuncEntry_
,
286 BranchProbabilityInfo
*BPI_
= nullptr,
287 BlockFrequencyInfo
*BFI_
= nullptr)
288 : F(Func
), BPI(BPI_
), BFI(BFI_
),
289 InstrumentFuncEntry(InstrumentFuncEntry_
) {
292 computeMinimumSpanningTree();
293 if (AllEdges
.size() > 1 && InstrumentFuncEntry
)
294 std::iter_swap(std::move(AllEdges
.begin()),
295 std::move(AllEdges
.begin() + AllEdges
.size() - 1));
299 } // end namespace llvm
301 #undef DEBUG_TYPE // "cfgmst"
303 #endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H