[lit] Improve lit.Run class
[llvm-complete.git] / lib / Analysis / BranchProbabilityInfo.cpp
bloba06ee096d54c28f9b09f2b579d5c170b5b1df049
1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Loops should be simplified before this analysis.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Analysis/BranchProbabilityInfo.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/CFG.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/LLVMContext.h"
30 #include "llvm/IR/Metadata.h"
31 #include "llvm/IR/PassManager.h"
32 #include "llvm/IR/Type.h"
33 #include "llvm/IR/Value.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Support/BranchProbability.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <cassert>
40 #include <cstdint>
41 #include <iterator>
42 #include <utility>
using namespace llvm;

// Debug category used by the LLVM_DEBUG output in this file.
#define DEBUG_TYPE "branch-prob"

// -print-bpi: request printing of the computed branch probability info.
static cl::opt<bool> PrintBranchProb(
    "print-bpi", cl::init(false), cl::Hidden,
    cl::desc("Print the branch probability info."));

// -print-bpi-func-name: name of the function whose branch probability info
// is printed.
// NOTE(review): unlike PrintBranchProb this option is not `static`;
// presumably it is referenced from another translation unit — confirm before
// changing its linkage.
cl::opt<std::string> PrintBranchProbFuncName(
    "print-bpi-func-name", cl::Hidden,
    cl::desc("The option to specify the name of the function "
             "whose branch probability info is printed."));

// Register the legacy wrapper pass together with the analyses it depends on.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)

char BranchProbabilityInfoWrapperPass::ID = 0;
// Weights are for internal use only. They are used by heuristics to help to
// estimate edges' probability. Example:
//
// Using "Loop Branch Heuristics" we predict weights of edges for the
// block BB2.
//         ...
//          |
//          V
//         BB1<-+
//          |   |
//          |   | (Weight = 124)
//          V   |
//         BB2--+
//          |
//          | (Weight = 4)
//          V
//         BB3
//
// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
static const uint32_t LBH_TAKEN_WEIGHT = 124;
static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop are half as likely as other edges
// (62 is half of LBH_TAKEN_WEIGHT).
static const uint32_t LBH_UNLIKELY_WEIGHT = 62;

/// Unreachable-terminating branch taken probability.
///
/// This is the probability for a branch being taken to a block that terminates
/// (eventually) in unreachable. These are predicted as unlikely as possible
/// (a raw value of 1). All reachable successors equally share the remaining
/// probability.
static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);

/// Weight for a branch taken going into a cold block.
///
/// This is the weight for a branch taken toward a block marked
/// cold. A block is marked cold if it's postdominated by a
/// block containing a call to a cold function. Cold functions
/// are those marked with attribute 'cold'.
static const uint32_t CC_TAKEN_WEIGHT = 4;

/// Weight for a branch not-taken into a cold block.
///
/// This is the weight for a branch not taken toward a block marked
/// cold.
static const uint32_t CC_NONTAKEN_WEIGHT = 64;

// Weights used by the pointer-comparison heuristic (calcPointerHeuristics).
static const uint32_t PH_TAKEN_WEIGHT = 20;
static const uint32_t PH_NONTAKEN_WEIGHT = 12;

// Weights used by the compare-against-zero heuristic (calcZeroHeuristics).
static const uint32_t ZH_TAKEN_WEIGHT = 20;
static const uint32_t ZH_NONTAKEN_WEIGHT = 12;

// Weights used by calcFloatingPointHeuristics for equality comparisons.
static const uint32_t FPH_TAKEN_WEIGHT = 20;
static const uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// This is the weight for an ordered floating point comparison (FCMP_ORD),
/// i.e. neither operand is NaN; it is predicted to be almost always true.
static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
/// This is the weight for an unordered floating point comparison (FCMP_UNO),
/// which means one or both of the operands are NaN. Usually it is used to test
/// for an exceptional case, so the result is unlikely.
static const uint32_t FPH_UNO_WEIGHT = 1;

/// Invoke-terminating normal branch taken weight
///
/// This is the weight for branching to the normal destination of an invoke
/// instruction. We expect this to happen most of the time. Set the weight to an
/// absurdly high value so that nested loops subsume it.
static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// Invoke-terminating normal branch not-taken weight.
///
/// This is the weight for branching to the unwind destination of an invoke
/// instruction. This is essentially never taken.
static const uint32_t IH_NONTAKEN_WEIGHT = 1;
141 /// Add \p BB to PostDominatedByUnreachable set if applicable.
142 void
143 BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
144 const Instruction *TI = BB->getTerminator();
145 if (TI->getNumSuccessors() == 0) {
146 if (isa<UnreachableInst>(TI) ||
147 // If this block is terminated by a call to
148 // @llvm.experimental.deoptimize then treat it like an unreachable since
149 // the @llvm.experimental.deoptimize call is expected to practically
150 // never execute.
151 BB->getTerminatingDeoptimizeCall())
152 PostDominatedByUnreachable.insert(BB);
153 return;
156 // If the terminator is an InvokeInst, check only the normal destination block
157 // as the unwind edge of InvokeInst is also very unlikely taken.
158 if (auto *II = dyn_cast<InvokeInst>(TI)) {
159 if (PostDominatedByUnreachable.count(II->getNormalDest()))
160 PostDominatedByUnreachable.insert(BB);
161 return;
164 for (auto *I : successors(BB))
165 // If any of successor is not post dominated then BB is also not.
166 if (!PostDominatedByUnreachable.count(I))
167 return;
169 PostDominatedByUnreachable.insert(BB);
172 /// Add \p BB to PostDominatedByColdCall set if applicable.
173 void
174 BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
175 assert(!PostDominatedByColdCall.count(BB));
176 const Instruction *TI = BB->getTerminator();
177 if (TI->getNumSuccessors() == 0)
178 return;
180 // If all of successor are post dominated then BB is also done.
181 if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
182 return PostDominatedByColdCall.count(SuccBB);
183 })) {
184 PostDominatedByColdCall.insert(BB);
185 return;
188 // If the terminator is an InvokeInst, check only the normal destination
189 // block as the unwind edge of InvokeInst is also very unlikely taken.
190 if (auto *II = dyn_cast<InvokeInst>(TI))
191 if (PostDominatedByColdCall.count(II->getNormalDest())) {
192 PostDominatedByColdCall.insert(BB);
193 return;
196 // Otherwise, if the block itself contains a cold function, add it to the
197 // set of blocks post-dominated by a cold call.
198 for (auto &I : *BB)
199 if (const CallInst *CI = dyn_cast<CallInst>(&I))
200 if (CI->hasFnAttr(Attribute::Cold)) {
201 PostDominatedByColdCall.insert(BB);
202 return;
206 /// Calculate edge weights for successors lead to unreachable.
208 /// Predict that a successor which leads necessarily to an
209 /// unreachable-terminated block as extremely unlikely.
210 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
211 const Instruction *TI = BB->getTerminator();
212 (void) TI;
213 assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
214 assert(!isa<InvokeInst>(TI) &&
215 "Invokes should have already been handled by calcInvokeHeuristics");
217 SmallVector<unsigned, 4> UnreachableEdges;
218 SmallVector<unsigned, 4> ReachableEdges;
220 for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
221 if (PostDominatedByUnreachable.count(*I))
222 UnreachableEdges.push_back(I.getSuccessorIndex());
223 else
224 ReachableEdges.push_back(I.getSuccessorIndex());
226 // Skip probabilities if all were reachable.
227 if (UnreachableEdges.empty())
228 return false;
230 if (ReachableEdges.empty()) {
231 BranchProbability Prob(1, UnreachableEdges.size());
232 for (unsigned SuccIdx : UnreachableEdges)
233 setEdgeProbability(BB, SuccIdx, Prob);
234 return true;
237 auto UnreachableProb = UR_TAKEN_PROB;
238 auto ReachableProb =
239 (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
240 ReachableEdges.size();
242 for (unsigned SuccIdx : UnreachableEdges)
243 setEdgeProbability(BB, SuccIdx, UnreachableProb);
244 for (unsigned SuccIdx : ReachableEdges)
245 setEdgeProbability(BB, SuccIdx, ReachableProb);
247 return true;
// Propagate existing explicit probabilities from either profile data or
// 'expect' intrinsic processing. Examine metadata against unreachable
// heuristic. The probability of the edge coming to unreachable block is
// set to min of metadata and unreachable heuristic.
//
// Only branch, switch and indirectbr terminators are considered. The MD_prof
// node must carry exactly one ConstantInt operand per successor after its
// leading name operand; otherwise nothing is set and false is returned.
// Returns true once a probability has been set for every successor of BB.
bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
  const Instruction *TI = BB->getTerminator();
  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
    return false;

  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
  if (!WeightsNode)
    return false;

  // Check that the number of successors is manageable.
  assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");

  // Ensure there are weights for all of the successors. Note that the first
  // operand to the metadata node is a name, not a weight.
  // NOTE(review): the name operand itself is not inspected here, so any
  // MD_prof node of the right arity with ConstantInt operands is accepted.
  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
    return false;

  // Build up the final weights that will be used in a temporary buffer.
  // Compute the sum of all weights to later decide whether they need to
  // be scaled to fit in 32 bits.
  uint64_t WeightSum = 0;
  SmallVector<uint32_t, 2> Weights;
  SmallVector<unsigned, 2> UnreachableIdxs;
  SmallVector<unsigned, 2> ReachableIdxs;
  Weights.reserve(TI->getNumSuccessors());
  for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
    ConstantInt *Weight =
        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
    if (!Weight)
      return false;
    assert(Weight->getValue().getActiveBits() <= 32 &&
           "Too many bits for uint32_t");
    Weights.push_back(Weight->getZExtValue());
    WeightSum += Weights.back();
    // Metadata operand i describes successor i - 1 (operand 0 is the name).
    if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
      UnreachableIdxs.push_back(i - 1);
    else
      ReachableIdxs.push_back(i - 1);
  }
  assert(Weights.size() == TI->getNumSuccessors() && "Checked above");

  // If the sum of weights does not fit in 32 bits, scale every weight down
  // accordingly.
  uint64_t ScalingFactor =
      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;

  if (ScalingFactor > 1) {
    WeightSum = 0;
    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
      Weights[i] /= ScalingFactor;
      WeightSum += Weights[i];
    }
  }
  assert(WeightSum <= UINT32_MAX &&
         "Expected weights to scale down to 32 bits");

  // Fall back to uniform weights when the metadata carries no information
  // (all weights are zero) or when every successor leads to unreachable.
  if (WeightSum == 0 || ReachableIdxs.size() == 0) {
    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
      Weights[i] = 1;
    WeightSum = TI->getNumSuccessors();
  }

  // Set the probability.
  SmallVector<BranchProbability, 2> BP;
  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
    BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });

  // Examine the metadata against unreachable heuristic.
  // If the unreachable heuristic is more strong then we use it for this edge.
  if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
    auto ToDistribute = BranchProbability::getZero();
    auto UnreachableProb = UR_TAKEN_PROB;
    // Clamp each unreachable edge down to UR_TAKEN_PROB and remember how much
    // probability mass was removed.
    for (auto i : UnreachableIdxs)
      if (UnreachableProb < BP[i]) {
        ToDistribute += BP[i] - UnreachableProb;
        BP[i] = UnreachableProb;
      }

    // If we modified the probability of some edges then we must distribute
    // the difference between reachable blocks.
    if (ToDistribute > BranchProbability::getZero()) {
      BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
      for (auto i : ReachableIdxs)
        BP[i] += PerEdge;
    }
  }

  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
    setEdgeProbability(BB, i, BP[i]);

  return true;
}
348 /// Calculate edge weights for edges leading to cold blocks.
350 /// A cold block is one post-dominated by a block with a call to a
351 /// cold function. Those edges are unlikely to be taken, so we give
352 /// them relatively low weight.
354 /// Return true if we could compute the weights for cold edges.
355 /// Return false, otherwise.
356 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
357 const Instruction *TI = BB->getTerminator();
358 (void) TI;
359 assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
360 assert(!isa<InvokeInst>(TI) &&
361 "Invokes should have already been handled by calcInvokeHeuristics");
363 // Determine which successors are post-dominated by a cold block.
364 SmallVector<unsigned, 4> ColdEdges;
365 SmallVector<unsigned, 4> NormalEdges;
366 for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
367 if (PostDominatedByColdCall.count(*I))
368 ColdEdges.push_back(I.getSuccessorIndex());
369 else
370 NormalEdges.push_back(I.getSuccessorIndex());
372 // Skip probabilities if no cold edges.
373 if (ColdEdges.empty())
374 return false;
376 if (NormalEdges.empty()) {
377 BranchProbability Prob(1, ColdEdges.size());
378 for (unsigned SuccIdx : ColdEdges)
379 setEdgeProbability(BB, SuccIdx, Prob);
380 return true;
383 auto ColdProb = BranchProbability::getBranchProbability(
384 CC_TAKEN_WEIGHT,
385 (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
386 auto NormalProb = BranchProbability::getBranchProbability(
387 CC_NONTAKEN_WEIGHT,
388 (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
390 for (unsigned SuccIdx : ColdEdges)
391 setEdgeProbability(BB, SuccIdx, ColdProb);
392 for (unsigned SuccIdx : NormalEdges)
393 setEdgeProbability(BB, SuccIdx, NormalProb);
395 return true;
398 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
399 // between two pointer or pointer and NULL will fail.
400 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
401 const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
402 if (!BI || !BI->isConditional())
403 return false;
405 Value *Cond = BI->getCondition();
406 ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
407 if (!CI || !CI->isEquality())
408 return false;
410 Value *LHS = CI->getOperand(0);
412 if (!LHS->getType()->isPointerTy())
413 return false;
415 assert(CI->getOperand(1)->getType()->isPointerTy());
417 // p != 0 -> isProb = true
418 // p == 0 -> isProb = false
419 // p != q -> isProb = true
420 // p == q -> isProb = false;
421 unsigned TakenIdx = 0, NonTakenIdx = 1;
422 bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
423 if (!isProb)
424 std::swap(TakenIdx, NonTakenIdx);
426 BranchProbability TakenProb(PH_TAKEN_WEIGHT,
427 PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
428 setEdgeProbability(BB, TakenIdx, TakenProb);
429 setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
430 return true;
433 static int getSCCNum(const BasicBlock *BB,
434 const BranchProbabilityInfo::SccInfo &SccI) {
435 auto SccIt = SccI.SccNums.find(BB);
436 if (SccIt == SccI.SccNums.end())
437 return -1;
438 return SccIt->second;
441 // Consider any block that is an entry point to the SCC as a header.
442 static bool isSCCHeader(const BasicBlock *BB, int SccNum,
443 BranchProbabilityInfo::SccInfo &SccI) {
444 assert(getSCCNum(BB, SccI) == SccNum);
446 // Lazily compute the set of headers for a given SCC and cache the results
447 // in the SccHeaderMap.
448 if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
449 SccI.SccHeaders.resize(SccNum + 1);
450 auto &HeaderMap = SccI.SccHeaders[SccNum];
451 bool Inserted;
452 BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
453 std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
454 if (Inserted) {
455 bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
456 [&](const BasicBlock *Pred) {
457 return getSCCNum(Pred, SccI) != SccNum;
459 HeaderMapIt->second = IsHeader;
460 return IsHeader;
461 } else
462 return HeaderMapIt->second;
// Compute the unlikely successors to the block BB in the loop L, specifically
// those that are unlikely because this is a loop, and add them to the
// UnlikelyBlocks set.
//
// Only successors of BB that are inside L can be added. Nothing is added
// unless BB ends in a conditional branch on a CmpInst comparing an
// instruction against a constant.
static void
computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
                          SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
  // Sometimes in a loop we have a branch whose condition is made false by
  // taking it. This is typically something like
  //  int n = 0;
  //  while (...) {
  //    if (++n >= MAX) {
  //      n = 0;
  //    }
  //  }
  // In this sort of situation taking the branch means that at the very least it
  // won't be taken again in the next iteration of the loop, so we should
  // consider it less likely than a typical branch.
  //
  // We detect this by looking back through the graph of PHI nodes that sets the
  // value that the condition depends on, and seeing if we can reach a successor
  // block which can be determined to make the condition false.
  //
  // FIXME: We currently consider unlikely blocks to be half as likely as other
  // blocks, but if we consider the example above the likelyhood is actually
  // 1/MAX. We could therefore be more precise in how unlikely we consider
  // blocks to be, but it would require more careful examination of the form
  // of the comparison expression.
  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  if (!BI || !BI->isConditional())
    return;

  // Check if the branch is based on an instruction compared with a constant
  CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
  if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
      !isa<Constant>(CI->getOperand(1)))
    return;

  // Either the instruction must be a PHI, or a chain of operations involving
  // constants that ends in a PHI which we can then collapse into a single value
  // if the PHI value is known.
  Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
  PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
  Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
  // Collect the instructions until we hit a PHI. Only binary operators whose
  // second operand is a constant are followed, through their first operand.
  SmallVector<BinaryOperator *, 1> InstChain;
  while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
         isa<Constant>(CmpLHS->getOperand(1))) {
    // Stop if the chain extends outside of the loop
    if (!L->contains(CmpLHS))
      return;
    InstChain.push_back(cast<BinaryOperator>(CmpLHS));
    CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
    if (CmpLHS)
      CmpPHI = dyn_cast<PHINode>(CmpLHS);
  }
  if (!CmpPHI || !L->contains(CmpPHI))
    return;

  // Trace the phi node to find all values that come from successors of BB
  SmallPtrSet<PHINode*, 8> VisitedInsts;
  SmallVector<PHINode*, 8> WorkList;
  WorkList.push_back(CmpPHI);
  VisitedInsts.insert(CmpPHI);
  while (!WorkList.empty()) {
    PHINode *P = WorkList.back();
    WorkList.pop_back();
    for (BasicBlock *B : P->blocks()) {
      // Skip blocks that aren't part of the loop
      if (!L->contains(B))
        continue;
      Value *V = P->getIncomingValueForBlock(B);
      // If the source is a PHI add it to the work list if we haven't
      // already visited it.
      if (PHINode *PN = dyn_cast<PHINode>(V)) {
        if (VisitedInsts.insert(PN).second)
          WorkList.push_back(PN);
        continue;
      }
      // If this incoming value is a constant and B is a successor of BB, then
      // we can constant-evaluate the compare to see if it makes the branch be
      // taken or not.
      Constant *CmpLHSConst = dyn_cast<Constant>(V);
      if (!CmpLHSConst ||
          std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
        continue;
      // First collapse InstChain, replaying the collected operations from the
      // PHI side outwards. NOTE(review): the trailing `true` presumably asks
      // for a folded result only (nullptr otherwise); the null checks below
      // rely on a null result meaning "could not fold".
      for (Instruction *I : llvm::reverse(InstChain)) {
        CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
                                        cast<Constant>(I->getOperand(1)), true);
        if (!CmpLHSConst)
          break;
      }
      if (!CmpLHSConst)
        continue;
      // Now constant-evaluate the compare
      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
                                                  CmpLHSConst, CmpConst, true);
      // If the result means we don't branch to the block then that block is
      // unlikely.
      if (Result &&
          ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
           (Result->isOneValue() && B == BI->getSuccessor(1))))
        UnlikelyBlocks.insert(B);
    }
  }
}
572 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
573 // as taken, exiting edges as not-taken.
574 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
575 const LoopInfo &LI,
576 SccInfo &SccI) {
577 int SccNum;
578 Loop *L = LI.getLoopFor(BB);
579 if (!L) {
580 SccNum = getSCCNum(BB, SccI);
581 if (SccNum < 0)
582 return false;
585 SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
586 if (L)
587 computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
589 SmallVector<unsigned, 8> BackEdges;
590 SmallVector<unsigned, 8> ExitingEdges;
591 SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
592 SmallVector<unsigned, 8> UnlikelyEdges;
594 for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
595 // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
596 // irreducible loops.
597 if (L) {
598 if (UnlikelyBlocks.count(*I) != 0)
599 UnlikelyEdges.push_back(I.getSuccessorIndex());
600 else if (!L->contains(*I))
601 ExitingEdges.push_back(I.getSuccessorIndex());
602 else if (L->getHeader() == *I)
603 BackEdges.push_back(I.getSuccessorIndex());
604 else
605 InEdges.push_back(I.getSuccessorIndex());
606 } else {
607 if (getSCCNum(*I, SccI) != SccNum)
608 ExitingEdges.push_back(I.getSuccessorIndex());
609 else if (isSCCHeader(*I, SccNum, SccI))
610 BackEdges.push_back(I.getSuccessorIndex());
611 else
612 InEdges.push_back(I.getSuccessorIndex());
616 if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
617 return false;
619 // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
620 // normalize them so that they sum up to one.
621 unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
622 (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
623 (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
624 (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
626 if (uint32_t numBackEdges = BackEdges.size()) {
627 BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
628 auto Prob = TakenProb / numBackEdges;
629 for (unsigned SuccIdx : BackEdges)
630 setEdgeProbability(BB, SuccIdx, Prob);
633 if (uint32_t numInEdges = InEdges.size()) {
634 BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
635 auto Prob = TakenProb / numInEdges;
636 for (unsigned SuccIdx : InEdges)
637 setEdgeProbability(BB, SuccIdx, Prob);
640 if (uint32_t numExitingEdges = ExitingEdges.size()) {
641 BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
642 Denom);
643 auto Prob = NotTakenProb / numExitingEdges;
644 for (unsigned SuccIdx : ExitingEdges)
645 setEdgeProbability(BB, SuccIdx, Prob);
648 if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
649 BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
650 Denom);
651 auto Prob = UnlikelyProb / numUnlikelyEdges;
652 for (unsigned SuccIdx : UnlikelyEdges)
653 setEdgeProbability(BB, SuccIdx, Prob);
656 return true;
// Calculate edge probabilities using the "Zero Heuristics".
//
// Applies to a conditional branch on an integer compare (ICmpInst) of a value
// against the constant 0, 1 or -1 (possibly behind a bitcast). It also covers
// equality tests on the result of strcmp-like library calls. The predicted
// successor gets ZH_TAKEN_WEIGHT against ZH_NONTAKEN_WEIGHT.
//
// TLI may be null, in which case library calls are not recognized. Returns
// false, setting nothing, when no rule applies.
bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
                                               const TargetLibraryInfo *TLI) {
  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
  if (!BI || !BI->isConditional())
    return false;

  Value *Cond = BI->getCondition();
  ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
  if (!CI)
    return false;

  // Accept an integer constant either directly or as the operand of a bitcast.
  auto GetConstantInt = [](Value *V) {
    if (auto *I = dyn_cast<BitCastInst>(V))
      return dyn_cast<ConstantInt>(I->getOperand(0));
    return dyn_cast<ConstantInt>(V);
  };

  Value *RHS = CI->getOperand(1);
  ConstantInt *CV = GetConstantInt(RHS);
  if (!CV)
    return false;

  // If the LHS is the result of AND'ing a value with a single bit bitmask,
  // we don't have information about probabilities.
  if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
    if (LHS->getOpcode() == Instruction::And)
      if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
        if (AndRHS->getValue().isPowerOf2())
          return false;

  // Check if the LHS is the return value of a library function.
  // Func starts as NumLibFuncs, which none of the checks below match.
  LibFunc Func = NumLibFuncs;
  if (TLI)
    if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
      if (Function *CalledFn = Call->getCalledFunction())
        TLI->getLibFunc(*CalledFn, Func);

  // isProb == true predicts successor 0 (the "condition holds" edge) as the
  // likely one; false predicts successor 1.
  bool isProb;
  if (Func == LibFunc_strcasecmp ||
      Func == LibFunc_strcmp ||
      Func == LibFunc_strncasecmp ||
      Func == LibFunc_strncmp ||
      Func == LibFunc_memcmp) {
    // strcmp and similar functions return zero, negative, or positive, if the
    // first string is equal, less, or greater than the second. We consider it
    // likely that the strings are not equal, so a comparison with zero is
    // probably false, but also a comparison with any other number is also
    // probably false given that what exactly is returned for nonzero values is
    // not specified. Any kind of comparison other than equality we know
    // nothing about.
    switch (CI->getPredicate()) {
    case CmpInst::ICMP_EQ:
      isProb = false;
      break;
    case CmpInst::ICMP_NE:
      isProb = true;
      break;
    default:
      return false;
    }
  } else if (CV->isZero()) {
    switch (CI->getPredicate()) {
    case CmpInst::ICMP_EQ:
      // X == 0   ->  Unlikely
      isProb = false;
      break;
    case CmpInst::ICMP_NE:
      // X != 0   ->  Likely
      isProb = true;
      break;
    case CmpInst::ICMP_SLT:
      // X < 0   ->  Unlikely
      isProb = false;
      break;
    case CmpInst::ICMP_SGT:
      // X > 0   ->  Likely
      isProb = true;
      break;
    default:
      return false;
    }
  } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
    // InstCombine canonicalizes X <= 0 into X < 1.
    // X <= 0   ->  Unlikely
    isProb = false;
  } else if (CV->isMinusOne()) {
    switch (CI->getPredicate()) {
    case CmpInst::ICMP_EQ:
      // X == -1  ->  Unlikely
      isProb = false;
      break;
    case CmpInst::ICMP_NE:
      // X != -1  ->  Likely
      isProb = true;
      break;
    case CmpInst::ICMP_SGT:
      // InstCombine canonicalizes X >= 0 into X > -1.
      // X >= 0   ->  Likely
      isProb = true;
      break;
    default:
      return false;
    }
  } else {
    return false;
  }

  unsigned TakenIdx = 0, NonTakenIdx = 1;

  if (!isProb)
    std::swap(TakenIdx, NonTakenIdx);

  BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
                              ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
  setEdgeProbability(BB, TakenIdx, TakenProb);
  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
  return true;
}
778 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
779 const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
780 if (!BI || !BI->isConditional())
781 return false;
783 Value *Cond = BI->getCondition();
784 FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
785 if (!FCmp)
786 return false;
788 uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
789 uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
790 bool isProb;
791 if (FCmp->isEquality()) {
792 // f1 == f2 -> Unlikely
793 // f1 != f2 -> Likely
794 isProb = !FCmp->isTrueWhenEqual();
795 } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
796 // !isnan -> Likely
797 isProb = true;
798 TakenWeight = FPH_ORD_WEIGHT;
799 NontakenWeight = FPH_UNO_WEIGHT;
800 } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
801 // isnan -> Unlikely
802 isProb = false;
803 TakenWeight = FPH_ORD_WEIGHT;
804 NontakenWeight = FPH_UNO_WEIGHT;
805 } else {
806 return false;
809 unsigned TakenIdx = 0, NonTakenIdx = 1;
811 if (!isProb)
812 std::swap(TakenIdx, NonTakenIdx);
814 BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
815 setEdgeProbability(BB, TakenIdx, TakenProb);
816 setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
817 return true;
820 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
821 const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
822 if (!II)
823 return false;
825 BranchProbability TakenProb(IH_TAKEN_WEIGHT,
826 IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
827 setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
828 setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
829 return true;
832 void BranchProbabilityInfo::releaseMemory() {
833 Probs.clear();
836 void BranchProbabilityInfo::print(raw_ostream &OS) const {
837 OS << "---- Branch Probabilities ----\n";
838 // We print the probabilities from the last function the analysis ran over,
839 // or the function it is currently running over.
840 assert(LastF && "Cannot print prior to running over a function");
841 for (const auto &BI : *LastF) {
842 for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
843 ++SI) {
844 printEdgeProbability(OS << " ", &BI, *SI);
849 bool BranchProbabilityInfo::
850 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
851 // Hot probability is at least 4/5 = 80%
852 // FIXME: Compare against a static "hot" BranchProbability.
853 return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
856 const BasicBlock *
857 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
858 auto MaxProb = BranchProbability::getZero();
859 const BasicBlock *MaxSucc = nullptr;
861 for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
862 const BasicBlock *Succ = *I;
863 auto Prob = getEdgeProbability(BB, Succ);
864 if (Prob > MaxProb) {
865 MaxProb = Prob;
866 MaxSucc = Succ;
870 // Hot probability is at least 4/5 = 80%
871 if (MaxProb > BranchProbability(4, 5))
872 return MaxSucc;
874 return nullptr;
877 /// Get the raw edge probability for the edge. If can't find it, return a
878 /// default probability 1/N where N is the number of successors. Here an edge is
879 /// specified using PredBlock and an
880 /// index to the successors.
881 BranchProbability
882 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
883 unsigned IndexInSuccessors) const {
884 auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
886 if (I != Probs.end())
887 return I->second;
889 return {1, static_cast<uint32_t>(succ_size(Src))};
892 BranchProbability
893 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
894 succ_const_iterator Dst) const {
895 return getEdgeProbability(Src, Dst.getSuccessorIndex());
898 /// Get the raw edge probability calculated for the block pair. This returns the
899 /// sum of all raw edge probabilities from Src to Dst.
900 BranchProbability
901 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
902 const BasicBlock *Dst) const {
903 auto Prob = BranchProbability::getZero();
904 bool FoundProb = false;
905 for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
906 if (*I == Dst) {
907 auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
908 if (MapI != Probs.end()) {
909 FoundProb = true;
910 Prob += MapI->second;
913 uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
914 return FoundProb ? Prob : BranchProbability(1, succ_num);
917 /// Set the edge probability for a given edge specified by PredBlock and an
918 /// index to the successors.
919 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
920 unsigned IndexInSuccessors,
921 BranchProbability Prob) {
922 Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
923 Handles.insert(BasicBlockCallbackVH(Src, this));
924 LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
925 << IndexInSuccessors << " successor probability to " << Prob
926 << "\n");
929 raw_ostream &
930 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
931 const BasicBlock *Src,
932 const BasicBlock *Dst) const {
933 const BranchProbability Prob = getEdgeProbability(Src, Dst);
934 OS << "edge " << Src->getName() << " -> " << Dst->getName()
935 << " probability is " << Prob
936 << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
938 return OS;
941 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
942 for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
943 auto Key = I->first;
944 if (Key.first == BB)
945 Probs.erase(Key);
949 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
950 const TargetLibraryInfo *TLI) {
951 LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
952 << " ----\n\n");
953 LastF = &F; // Store the last function we ran on for printing.
954 assert(PostDominatedByUnreachable.empty());
955 assert(PostDominatedByColdCall.empty());
957 // Record SCC numbers of blocks in the CFG to identify irreducible loops.
958 // FIXME: We could only calculate this if the CFG is known to be irreducible
959 // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
960 int SccNum = 0;
961 SccInfo SccI;
962 for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
963 ++It, ++SccNum) {
964 // Ignore single-block SCCs since they either aren't loops or LoopInfo will
965 // catch them.
966 const std::vector<const BasicBlock *> &Scc = *It;
967 if (Scc.size() == 1)
968 continue;
970 LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
971 for (auto *BB : Scc) {
972 LLVM_DEBUG(dbgs() << " " << BB->getName());
973 SccI.SccNums[BB] = SccNum;
975 LLVM_DEBUG(dbgs() << "\n");
978 // Walk the basic blocks in post-order so that we can build up state about
979 // the successors of a block iteratively.
980 for (auto BB : post_order(&F.getEntryBlock())) {
981 LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
982 << "\n");
983 updatePostDominatedByUnreachable(BB);
984 updatePostDominatedByColdCall(BB);
985 // If there is no at least two successors, no sense to set probability.
986 if (BB->getTerminator()->getNumSuccessors() < 2)
987 continue;
988 if (calcMetadataWeights(BB))
989 continue;
990 if (calcInvokeHeuristics(BB))
991 continue;
992 if (calcUnreachableHeuristics(BB))
993 continue;
994 if (calcColdCallHeuristics(BB))
995 continue;
996 if (calcLoopBranchHeuristics(BB, LI, SccI))
997 continue;
998 if (calcPointerHeuristics(BB))
999 continue;
1000 if (calcZeroHeuristics(BB, TLI))
1001 continue;
1002 if (calcFloatingPointHeuristics(BB))
1003 continue;
1006 PostDominatedByUnreachable.clear();
1007 PostDominatedByColdCall.clear();
1009 if (PrintBranchProb &&
1010 (PrintBranchProbFuncName.empty() ||
1011 F.getName().equals(PrintBranchProbFuncName))) {
1012 print(dbgs());
1016 void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
1017 AnalysisUsage &AU) const {
1018 // We require DT so it's available when LI is available. The LI updating code
1019 // asserts that DT is also present so if we don't make sure that we have DT
1020 // here, that assert will trigger.
1021 AU.addRequired<DominatorTreeWrapperPass>();
1022 AU.addRequired<LoopInfoWrapperPass>();
1023 AU.addRequired<TargetLibraryInfoWrapperPass>();
1024 AU.setPreservesAll();
1027 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1028 const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1029 const TargetLibraryInfo &TLI =
1030 getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1031 BPI.calculate(F, LI, &TLI);
1032 return false;
1035 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1037 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1038 const Module *) const {
1039 BPI.print(OS);
1042 AnalysisKey BranchProbabilityAnalysis::Key;
1043 BranchProbabilityInfo
1044 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1045 BranchProbabilityInfo BPI;
1046 BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1047 return BPI;
1050 PreservedAnalyses
1051 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1052 OS << "Printing analysis results of BPI for function "
1053 << "'" << F.getName() << "':"
1054 << "\n";
1055 AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1056 return PreservedAnalyses::all();