[clang] Add test for CWG190 "Layout-compatible POD-struct types" (#121668)
[llvm-project.git] / llvm / lib / IR / ProfDataUtils.cpp
blob5441228b3291ee609307cfb754baf4e2bfb87c4f
1 //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for working with Profiling Metadata.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/IR/ProfDataUtils.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/IR/Function.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/LLVMContext.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/IR/Metadata.h"
21 #include "llvm/IR/ProfDataUtils.h"
23 using namespace llvm;
25 namespace {
// Every MD_prof node starts with an MDString tag naming its kind, followed by
// kind-specific payload operands:
//
//   { MDString, [i32, i32, ...] }
//
// Branch weights, for example, look like:
//
//   { "branch_weights", [i32 1, i32 10000] }
//
// The minimum operand counts are kept here, in one place, so every accessor in
// this file agrees on the layout and the layout can change later without
// hunting down scattered magic numbers.

// Fewest operands accepted as branch-weight metadata: the tag plus at least
// two further operands (as in the two-weight example above).
constexpr unsigned MinBWOps{3};

// Fewest operands accepted as value-profile ("VP") metadata.
// NOTE(review): presumably tag + kind + total + one (value, count) pair;
// confirm against the VP writer.
constexpr unsigned MinVPOps{5};
47 // We may want to add support for other MD_prof types, so provide an abstraction
48 // for checking the metadata type.
49 bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
50 // TODO: This routine may be simplified if MD_prof used an enum instead of a
51 // string to differentiate the types of MD_prof nodes.
52 if (!ProfData || !Name || MinOps < 2)
53 return false;
55 unsigned NOps = ProfData->getNumOperands();
56 if (NOps < MinOps)
57 return false;
59 auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
60 if (!ProfDataName)
61 return false;
63 return ProfDataName->getString() == Name;
66 template <typename T,
67 typename = typename std::enable_if<std::is_arithmetic_v<T>>>
68 static void extractFromBranchWeightMD(const MDNode *ProfileData,
69 SmallVectorImpl<T> &Weights) {
70 assert(isBranchWeightMD(ProfileData) && "wrong metadata");
72 unsigned NOps = ProfileData->getNumOperands();
73 unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
74 assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
75 Weights.resize(NOps - WeightsIdx);
77 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
78 ConstantInt *Weight =
79 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
80 assert(Weight && "Malformed branch_weight in MD_prof node");
81 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
82 "Too many bits for MD_prof branch_weight");
83 Weights[Idx - WeightsIdx] = Weight->getZExtValue();
87 } // namespace
89 namespace llvm {
91 bool hasProfMD(const Instruction &I) {
92 return I.hasMetadata(LLVMContext::MD_prof);
95 bool isBranchWeightMD(const MDNode *ProfileData) {
96 return isTargetMD(ProfileData, "branch_weights", MinBWOps);
99 bool isValueProfileMD(const MDNode *ProfileData) {
100 return isTargetMD(ProfileData, "VP", MinVPOps);
103 bool hasBranchWeightMD(const Instruction &I) {
104 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
105 return isBranchWeightMD(ProfileData);
108 bool hasCountTypeMD(const Instruction &I) {
109 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
110 // Value profiles record count-type information.
111 if (isValueProfileMD(ProfileData))
112 return true;
113 // Conservatively assume non CallBase instruction only get taken/not-taken
114 // branch probability, so not interpret them as count.
115 return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
118 bool hasValidBranchWeightMD(const Instruction &I) {
119 return getValidBranchWeightMDNode(I);
122 bool hasBranchWeightOrigin(const Instruction &I) {
123 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
124 return hasBranchWeightOrigin(ProfileData);
127 bool hasBranchWeightOrigin(const MDNode *ProfileData) {
128 if (!isBranchWeightMD(ProfileData))
129 return false;
130 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
131 // NOTE: if we ever have more types of branch weight provenance,
132 // we need to check the string value is "expected". For now, we
133 // supply a more generic API, and avoid the spurious comparisons.
134 assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
135 return ProfDataName != nullptr;
138 unsigned getBranchWeightOffset(const MDNode *ProfileData) {
139 return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
142 unsigned getNumBranchWeights(const MDNode &ProfileData) {
143 return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
146 MDNode *getBranchWeightMDNode(const Instruction &I) {
147 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
148 if (!isBranchWeightMD(ProfileData))
149 return nullptr;
150 return ProfileData;
153 MDNode *getValidBranchWeightMDNode(const Instruction &I) {
154 auto *ProfileData = getBranchWeightMDNode(I);
155 if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
156 return ProfileData;
157 return nullptr;
160 void extractFromBranchWeightMD32(const MDNode *ProfileData,
161 SmallVectorImpl<uint32_t> &Weights) {
162 extractFromBranchWeightMD(ProfileData, Weights);
165 void extractFromBranchWeightMD64(const MDNode *ProfileData,
166 SmallVectorImpl<uint64_t> &Weights) {
167 extractFromBranchWeightMD(ProfileData, Weights);
170 bool extractBranchWeights(const MDNode *ProfileData,
171 SmallVectorImpl<uint32_t> &Weights) {
172 if (!isBranchWeightMD(ProfileData))
173 return false;
174 extractFromBranchWeightMD(ProfileData, Weights);
175 return true;
178 bool extractBranchWeights(const Instruction &I,
179 SmallVectorImpl<uint32_t> &Weights) {
180 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
181 return extractBranchWeights(ProfileData, Weights);
184 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
185 uint64_t &FalseVal) {
186 assert((I.getOpcode() == Instruction::Br ||
187 I.getOpcode() == Instruction::Select) &&
188 "Looking for branch weights on something besides branch, select, or "
189 "switch");
191 SmallVector<uint32_t, 2> Weights;
192 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
193 if (!extractBranchWeights(ProfileData, Weights))
194 return false;
196 if (Weights.size() > 2)
197 return false;
199 TrueVal = Weights[0];
200 FalseVal = Weights[1];
201 return true;
204 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
205 TotalVal = 0;
206 if (!ProfileData)
207 return false;
209 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
210 if (!ProfDataName)
211 return false;
213 if (ProfDataName->getString() == "branch_weights") {
214 unsigned Offset = getBranchWeightOffset(ProfileData);
215 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
216 auto *V = mdconst::extract<ConstantInt>(ProfileData->getOperand(Idx));
217 TotalVal += V->getValue().getZExtValue();
219 return true;
222 if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
223 TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
224 ->getValue()
225 .getZExtValue();
226 return true;
228 return false;
231 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
232 return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
235 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
236 bool IsExpected) {
237 MDBuilder MDB(I.getContext());
238 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
239 I.setMetadata(LLVMContext::MD_prof, BranchWeights);
242 void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
243 assert(T != 0 && "Caller should guarantee");
244 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
245 if (ProfileData == nullptr)
246 return;
248 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
249 if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
250 ProfDataName->getString() != "VP"))
251 return;
253 if (!hasCountTypeMD(I))
254 return;
256 LLVMContext &C = I.getContext();
258 MDBuilder MDB(C);
259 SmallVector<Metadata *, 3> Vals;
260 Vals.push_back(ProfileData->getOperand(0));
261 APInt APS(128, S), APT(128, T);
262 if (ProfDataName->getString() == "branch_weights" &&
263 ProfileData->getNumOperands() > 0) {
264 // Using APInt::div may be expensive, but most cases should fit 64 bits.
265 APInt Val(128,
266 mdconst::dyn_extract<ConstantInt>(
267 ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
268 ->getValue()
269 .getZExtValue());
270 Val *= APS;
271 Vals.push_back(MDB.createConstant(ConstantInt::get(
272 Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
273 } else if (ProfDataName->getString() == "VP")
274 for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
275 // The first value is the key of the value profile, which will not change.
276 Vals.push_back(ProfileData->getOperand(i));
277 uint64_t Count =
278 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
279 ->getValue()
280 .getZExtValue();
281 // Don't scale the magic number.
282 if (Count == NOMORE_ICP_MAGICNUM) {
283 Vals.push_back(ProfileData->getOperand(i + 1));
284 continue;
286 // Using APInt::div may be expensive, but most cases should fit 64 bits.
287 APInt Val(128, Count);
288 Val *= APS;
289 Vals.push_back(MDB.createConstant(ConstantInt::get(
290 Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
292 I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
295 } // namespace llvm