//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for working with Profiling Metadata.
//
//===----------------------------------------------------------------------===//
13 #include "llvm/IR/ProfDataUtils.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/IR/Function.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/LLVMContext.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/IR/Metadata.h"
21 #include "llvm/IR/ProfDataUtils.h"
// Layout of MD_prof nodes as read by the helpers in this file:
//
//   operand 0       : MDString naming the kind ("branch_weights" or "VP")
//   operands 1..N-1 : kind-specific payload
//
// Branch weights look like  !{!"branch_weights", i32 1, i32 10000}.
// They may carry an optional origin string at operand 1 (tested by
// hasBranchWeightOrigin). In that case the first weight moves to operand 2
// (see getBranchWeightOffset).
//
// The operand-count thresholds live here so that every accessor agrees on the
// layout, and so it can be changed in one place if it ever evolves.

// Fewest operands an MD_prof node may have and still count as branch weights.
constexpr unsigned MinBWOps{3};

// Fewest operands an MD_prof node may have and still count as a value profile.
constexpr unsigned MinVPOps{5};
47 // We may want to add support for other MD_prof types, so provide an abstraction
48 // for checking the metadata type.
49 bool isTargetMD(const MDNode
*ProfData
, const char *Name
, unsigned MinOps
) {
50 // TODO: This routine may be simplified if MD_prof used an enum instead of a
51 // string to differentiate the types of MD_prof nodes.
52 if (!ProfData
|| !Name
|| MinOps
< 2)
55 unsigned NOps
= ProfData
->getNumOperands();
59 auto *ProfDataName
= dyn_cast
<MDString
>(ProfData
->getOperand(0));
63 return ProfDataName
->getString() == Name
;
67 typename
= typename
std::enable_if
<std::is_arithmetic_v
<T
>>>
68 static void extractFromBranchWeightMD(const MDNode
*ProfileData
,
69 SmallVectorImpl
<T
> &Weights
) {
70 assert(isBranchWeightMD(ProfileData
) && "wrong metadata");
72 unsigned NOps
= ProfileData
->getNumOperands();
73 unsigned WeightsIdx
= getBranchWeightOffset(ProfileData
);
74 assert(WeightsIdx
< NOps
&& "Weights Index must be less than NOps.");
75 Weights
.resize(NOps
- WeightsIdx
);
77 for (unsigned Idx
= WeightsIdx
, E
= NOps
; Idx
!= E
; ++Idx
) {
79 mdconst::dyn_extract
<ConstantInt
>(ProfileData
->getOperand(Idx
));
80 assert(Weight
&& "Malformed branch_weight in MD_prof node");
81 assert(Weight
->getValue().getActiveBits() <= (sizeof(T
) * 8) &&
82 "Too many bits for MD_prof branch_weight");
83 Weights
[Idx
- WeightsIdx
] = Weight
->getZExtValue();
91 bool hasProfMD(const Instruction
&I
) {
92 return I
.hasMetadata(LLVMContext::MD_prof
);
95 bool isBranchWeightMD(const MDNode
*ProfileData
) {
96 return isTargetMD(ProfileData
, "branch_weights", MinBWOps
);
99 bool isValueProfileMD(const MDNode
*ProfileData
) {
100 return isTargetMD(ProfileData
, "VP", MinVPOps
);
103 bool hasBranchWeightMD(const Instruction
&I
) {
104 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
105 return isBranchWeightMD(ProfileData
);
108 bool hasCountTypeMD(const Instruction
&I
) {
109 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
110 // Value profiles record count-type information.
111 if (isValueProfileMD(ProfileData
))
113 // Conservatively assume non CallBase instruction only get taken/not-taken
114 // branch probability, so not interpret them as count.
115 return isa
<CallBase
>(I
) && !isBranchWeightMD(ProfileData
);
118 bool hasValidBranchWeightMD(const Instruction
&I
) {
119 return getValidBranchWeightMDNode(I
);
122 bool hasBranchWeightOrigin(const Instruction
&I
) {
123 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
124 return hasBranchWeightOrigin(ProfileData
);
127 bool hasBranchWeightOrigin(const MDNode
*ProfileData
) {
128 if (!isBranchWeightMD(ProfileData
))
130 auto *ProfDataName
= dyn_cast
<MDString
>(ProfileData
->getOperand(1));
131 // NOTE: if we ever have more types of branch weight provenance,
132 // we need to check the string value is "expected". For now, we
133 // supply a more generic API, and avoid the spurious comparisons.
134 assert(ProfDataName
== nullptr || ProfDataName
->getString() == "expected");
135 return ProfDataName
!= nullptr;
138 unsigned getBranchWeightOffset(const MDNode
*ProfileData
) {
139 return hasBranchWeightOrigin(ProfileData
) ? 2 : 1;
142 unsigned getNumBranchWeights(const MDNode
&ProfileData
) {
143 return ProfileData
.getNumOperands() - getBranchWeightOffset(&ProfileData
);
146 MDNode
*getBranchWeightMDNode(const Instruction
&I
) {
147 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
148 if (!isBranchWeightMD(ProfileData
))
153 MDNode
*getValidBranchWeightMDNode(const Instruction
&I
) {
154 auto *ProfileData
= getBranchWeightMDNode(I
);
155 if (ProfileData
&& getNumBranchWeights(*ProfileData
) == I
.getNumSuccessors())
160 void extractFromBranchWeightMD32(const MDNode
*ProfileData
,
161 SmallVectorImpl
<uint32_t> &Weights
) {
162 extractFromBranchWeightMD(ProfileData
, Weights
);
165 void extractFromBranchWeightMD64(const MDNode
*ProfileData
,
166 SmallVectorImpl
<uint64_t> &Weights
) {
167 extractFromBranchWeightMD(ProfileData
, Weights
);
170 bool extractBranchWeights(const MDNode
*ProfileData
,
171 SmallVectorImpl
<uint32_t> &Weights
) {
172 if (!isBranchWeightMD(ProfileData
))
174 extractFromBranchWeightMD(ProfileData
, Weights
);
178 bool extractBranchWeights(const Instruction
&I
,
179 SmallVectorImpl
<uint32_t> &Weights
) {
180 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
181 return extractBranchWeights(ProfileData
, Weights
);
184 bool extractBranchWeights(const Instruction
&I
, uint64_t &TrueVal
,
185 uint64_t &FalseVal
) {
186 assert((I
.getOpcode() == Instruction::Br
||
187 I
.getOpcode() == Instruction::Select
) &&
188 "Looking for branch weights on something besides branch, select, or "
191 SmallVector
<uint32_t, 2> Weights
;
192 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
193 if (!extractBranchWeights(ProfileData
, Weights
))
196 if (Weights
.size() > 2)
199 TrueVal
= Weights
[0];
200 FalseVal
= Weights
[1];
204 bool extractProfTotalWeight(const MDNode
*ProfileData
, uint64_t &TotalVal
) {
209 auto *ProfDataName
= dyn_cast
<MDString
>(ProfileData
->getOperand(0));
213 if (ProfDataName
->getString() == "branch_weights") {
214 unsigned Offset
= getBranchWeightOffset(ProfileData
);
215 for (unsigned Idx
= Offset
; Idx
< ProfileData
->getNumOperands(); ++Idx
) {
216 auto *V
= mdconst::extract
<ConstantInt
>(ProfileData
->getOperand(Idx
));
217 TotalVal
+= V
->getValue().getZExtValue();
222 if (ProfDataName
->getString() == "VP" && ProfileData
->getNumOperands() > 3) {
223 TotalVal
= mdconst::dyn_extract
<ConstantInt
>(ProfileData
->getOperand(2))
231 bool extractProfTotalWeight(const Instruction
&I
, uint64_t &TotalVal
) {
232 return extractProfTotalWeight(I
.getMetadata(LLVMContext::MD_prof
), TotalVal
);
235 void setBranchWeights(Instruction
&I
, ArrayRef
<uint32_t> Weights
,
237 MDBuilder
MDB(I
.getContext());
238 MDNode
*BranchWeights
= MDB
.createBranchWeights(Weights
, IsExpected
);
239 I
.setMetadata(LLVMContext::MD_prof
, BranchWeights
);
242 void scaleProfData(Instruction
&I
, uint64_t S
, uint64_t T
) {
243 assert(T
!= 0 && "Caller should guarantee");
244 auto *ProfileData
= I
.getMetadata(LLVMContext::MD_prof
);
245 if (ProfileData
== nullptr)
248 auto *ProfDataName
= dyn_cast
<MDString
>(ProfileData
->getOperand(0));
249 if (!ProfDataName
|| (ProfDataName
->getString() != "branch_weights" &&
250 ProfDataName
->getString() != "VP"))
253 if (!hasCountTypeMD(I
))
256 LLVMContext
&C
= I
.getContext();
259 SmallVector
<Metadata
*, 3> Vals
;
260 Vals
.push_back(ProfileData
->getOperand(0));
261 APInt
APS(128, S
), APT(128, T
);
262 if (ProfDataName
->getString() == "branch_weights" &&
263 ProfileData
->getNumOperands() > 0) {
264 // Using APInt::div may be expensive, but most cases should fit 64 bits.
266 mdconst::dyn_extract
<ConstantInt
>(
267 ProfileData
->getOperand(getBranchWeightOffset(ProfileData
)))
271 Vals
.push_back(MDB
.createConstant(ConstantInt::get(
272 Type::getInt32Ty(C
), Val
.udiv(APT
).getLimitedValue(UINT32_MAX
))));
273 } else if (ProfDataName
->getString() == "VP")
274 for (unsigned i
= 1; i
< ProfileData
->getNumOperands(); i
+= 2) {
275 // The first value is the key of the value profile, which will not change.
276 Vals
.push_back(ProfileData
->getOperand(i
));
278 mdconst::dyn_extract
<ConstantInt
>(ProfileData
->getOperand(i
+ 1))
281 // Don't scale the magic number.
282 if (Count
== NOMORE_ICP_MAGICNUM
) {
283 Vals
.push_back(ProfileData
->getOperand(i
+ 1));
286 // Using APInt::div may be expensive, but most cases should fit 64 bits.
287 APInt
Val(128, Count
);
289 Vals
.push_back(MDB
.createConstant(ConstantInt::get(
290 Type::getInt64Ty(C
), Val
.udiv(APT
).getLimitedValue())));
292 I
.setMetadata(LLVMContext::MD_prof
, MDNode::get(C
, Vals
));