1 //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements utilities for working with Profiling Metadata.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/IR/ProfDataUtils.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/Instructions.h"
19 #include "llvm/IR/LLVMContext.h"
20 #include "llvm/IR/Metadata.h"
21 #include "llvm/Support/BranchProbability.h"
22 #include "llvm/Support/CommandLine.h"
// MD_prof nodes have the following layout:
//
// In general:
// { String name, Array of i32 }
//
// In terms of types:
// { MDString, [i32, i32, ...]}
//
// Concretely, for branch weights:
// { "branch_weights", [i32 1, i32 10000]}
//
// We maintain some constants here to ensure that we access the branch weights
// correctly and can change the behavior in the future if the layout changes.
// Operand index of the first weight in an MD_prof node; operand 0 holds the
// MDString name that identifies the node's kind.
constexpr unsigned WeightsIdx = 1;

// Minimum operand count for a branch_weights MD_prof node: the name operand
// followed by at least two weights.
constexpr unsigned MinBWOps = WeightsIdx + 2;

// The weight extractors assume any node accepted as branch_weights has at
// least one operand at or after WeightsIdx.
static_assert(MinBWOps > WeightsIdx,
              "branch_weights nodes must contain at least one weight operand");
48 // We may want to add support for other MD_prof types, so provide an abstraction
49 // for checking the metadata type.
50 bool isTargetMD(const MDNode
*ProfData
, const char *Name
, unsigned MinOps
) {
51 // TODO: This routine may be simplified if MD_prof used an enum instead of a
52 // string to differentiate the types of MD_prof nodes.
53 if (!ProfData
|| !Name
|| MinOps
< 2)
56 unsigned NOps
= ProfData
->getNumOperands();
60 auto *ProfDataName
= dyn_cast
<MDString
>(ProfData
->getOperand(0));
64 return ProfDataName
->getString().equals(Name
);
71 bool hasProfMD(const Instruction
&I
) {
72 return nullptr != I
.getMetadata(LLVMContext::MD_prof
);
75 bool isBranchWeightMD(const MDNode
*ProfileData
) {
76 return isTargetMD(ProfileData
, "branch_weights", MinBWOps
);
79 bool hasBranchWeightMD(const Instruction
&I
) {
80 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
81 return isBranchWeightMD(ProfileData
);
84 bool hasValidBranchWeightMD(const Instruction
&I
) {
85 return getValidBranchWeightMDNode(I
);
88 MDNode
*getBranchWeightMDNode(const Instruction
&I
) {
89 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
90 if (!isBranchWeightMD(ProfileData
))
95 MDNode
*getValidBranchWeightMDNode(const Instruction
&I
) {
96 auto *ProfileData
= getBranchWeightMDNode(I
);
97 if (ProfileData
&& ProfileData
->getNumOperands() == 1 + I
.getNumSuccessors())
102 void extractFromBranchWeightMD(const MDNode
*ProfileData
,
103 SmallVectorImpl
<uint32_t> &Weights
) {
104 assert(isBranchWeightMD(ProfileData
) && "wrong metadata");
106 unsigned NOps
= ProfileData
->getNumOperands();
107 assert(WeightsIdx
< NOps
&& "Weights Index must be less than NOps.");
108 Weights
.resize(NOps
- WeightsIdx
);
110 for (unsigned Idx
= WeightsIdx
, E
= NOps
; Idx
!= E
; ++Idx
) {
111 ConstantInt
*Weight
=
112 mdconst::dyn_extract
<ConstantInt
>(ProfileData
->getOperand(Idx
));
113 assert(Weight
&& "Malformed branch_weight in MD_prof node");
114 assert(Weight
->getValue().getActiveBits() <= 32 &&
115 "Too many bits for uint32_t");
116 Weights
[Idx
- WeightsIdx
] = Weight
->getZExtValue();
120 bool extractBranchWeights(const MDNode
*ProfileData
,
121 SmallVectorImpl
<uint32_t> &Weights
) {
122 if (!isBranchWeightMD(ProfileData
))
124 extractFromBranchWeightMD(ProfileData
, Weights
);
128 bool extractBranchWeights(const Instruction
&I
,
129 SmallVectorImpl
<uint32_t> &Weights
) {
130 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
131 return extractBranchWeights(ProfileData
, Weights
);
134 bool extractBranchWeights(const Instruction
&I
, uint64_t &TrueVal
,
135 uint64_t &FalseVal
) {
136 assert((I
.getOpcode() == Instruction::Br
||
137 I
.getOpcode() == Instruction::Select
) &&
138 "Looking for branch weights on something besides branch, select, or "
141 SmallVector
<uint32_t, 2> Weights
;
142 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
143 if (!extractBranchWeights(ProfileData
, Weights
))
146 if (Weights
.size() > 2)
149 TrueVal
= Weights
[0];
150 FalseVal
= Weights
[1];
154 bool extractProfTotalWeight(const MDNode
*ProfileData
, uint64_t &TotalVal
) {
159 auto *ProfDataName
= dyn_cast
<MDString
>(ProfileData
->getOperand(0));
163 if (ProfDataName
->getString().equals("branch_weights")) {
164 for (unsigned Idx
= 1; Idx
< ProfileData
->getNumOperands(); Idx
++) {
165 auto *V
= mdconst::dyn_extract
<ConstantInt
>(ProfileData
->getOperand(Idx
));
166 assert(V
&& "Malformed branch_weight in MD_prof node");
167 TotalVal
+= V
->getValue().getZExtValue();
172 if (ProfDataName
->getString().equals("VP") &&
173 ProfileData
->getNumOperands() > 3) {
174 TotalVal
= mdconst::dyn_extract
<ConstantInt
>(ProfileData
->getOperand(2))
182 bool extractProfTotalWeight(const Instruction
&I
, uint64_t &TotalVal
) {
183 return extractProfTotalWeight(I
.getMetadata(LLVMContext::MD_prof
), TotalVal
);