[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / llvm / lib / Target / DirectX / DXILResource.cpp
blob8e5b9867e6661bf6f04dd27e58b8821ec15eabef
1 //===- DXILResource.cpp - DXIL Resource helper objects --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects for working with DXIL Resources.
10 ///
11 //===----------------------------------------------------------------------===//
13 #include "DXILResource.h"
14 #include "CBufferDataLayout.h"
15 #include "llvm/ADT/StringSwitch.h"
16 #include "llvm/IR/IRBuilder.h"
17 #include "llvm/IR/Metadata.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/Format.h"
22 using namespace llvm;
23 using namespace llvm::dxil;
25 template <typename T> void ResourceTable<T>::collect(Module &M) {
26 NamedMDNode *Entry = M.getNamedMetadata(MDName);
27 if (!Entry || Entry->getNumOperands() == 0)
28 return;
30 uint32_t Counter = 0;
31 for (auto *Res : Entry->operands()) {
32 Data.push_back(T(Counter++, hlsl::FrontendResource(cast<MDNode>(Res))));
36 template <> void ResourceTable<ConstantBuffer>::collect(Module &M) {
37 NamedMDNode *Entry = M.getNamedMetadata(MDName);
38 if (!Entry || Entry->getNumOperands() == 0)
39 return;
41 uint32_t Counter = 0;
42 for (auto *Res : Entry->operands()) {
43 Data.push_back(
44 ConstantBuffer(Counter++, hlsl::FrontendResource(cast<MDNode>(Res))));
46 // FIXME: share CBufferDataLayout with CBuffer load lowering.
47 // See https://github.com/llvm/llvm-project/issues/58381
48 CBufferDataLayout CBDL(M.getDataLayout(), /*IsLegacy*/ true);
49 for (auto &CB : Data)
50 CB.setSize(CBDL);
53 void Resources::collect(Module &M) {
54 UAVs.collect(M);
55 CBuffers.collect(M);
58 ResourceBase::ResourceBase(uint32_t I, hlsl::FrontendResource R)
59 : ID(I), GV(R.getGlobalVariable()), Name(""), Space(R.getSpace()),
60 LowerBound(R.getResourceIndex()), RangeSize(1) {
61 if (auto *ArrTy = dyn_cast<ArrayType>(GV->getValueType()))
62 RangeSize = ArrTy->getNumElements();
65 StringRef ResourceBase::getElementTypeName(ElementType ElTy) {
66 switch (ElTy) {
67 case ElementType::Invalid:
68 return "invalid";
69 case ElementType::I1:
70 return "i1";
71 case ElementType::I16:
72 return "i16";
73 case ElementType::U16:
74 return "u16";
75 case ElementType::I32:
76 return "i32";
77 case ElementType::U32:
78 return "u32";
79 case ElementType::I64:
80 return "i64";
81 case ElementType::U64:
82 return "u64";
83 case ElementType::F16:
84 return "f16";
85 case ElementType::F32:
86 return "f32";
87 case ElementType::F64:
88 return "f64";
89 case ElementType::SNormF16:
90 return "snorm_f16";
91 case ElementType::UNormF16:
92 return "unorm_f16";
93 case ElementType::SNormF32:
94 return "snorm_f32";
95 case ElementType::UNormF32:
96 return "unorm_f32";
97 case ElementType::SNormF64:
98 return "snorm_f64";
99 case ElementType::UNormF64:
100 return "unorm_f64";
101 case ElementType::PackedS8x32:
102 return "p32i8";
103 case ElementType::PackedU8x32:
104 return "p32u8";
106 llvm_unreachable("All ElementType enums are handled in switch");
109 void ResourceBase::printElementType(ResourceKind Kind, ElementType ElTy,
110 unsigned Alignment, raw_ostream &OS) {
111 switch (Kind) {
112 default:
113 // TODO: add vector size.
114 OS << right_justify(getElementTypeName(ElTy), Alignment);
115 break;
116 case ResourceKind::RawBuffer:
117 OS << right_justify("byte", Alignment);
118 break;
119 case ResourceKind::StructuredBuffer:
120 OS << right_justify("struct", Alignment);
121 break;
122 case ResourceKind::CBuffer:
123 case ResourceKind::Sampler:
124 OS << right_justify("NA", Alignment);
125 break;
126 case ResourceKind::Invalid:
127 case ResourceKind::NumEntries:
128 break;
132 StringRef ResourceBase::getKindName(ResourceKind Kind) {
133 switch (Kind) {
134 case ResourceKind::NumEntries:
135 case ResourceKind::Invalid:
136 return "invalid";
137 case ResourceKind::Texture1D:
138 return "1d";
139 case ResourceKind::Texture2D:
140 return "2d";
141 case ResourceKind::Texture2DMS:
142 return "2dMS";
143 case ResourceKind::Texture3D:
144 return "3d";
145 case ResourceKind::TextureCube:
146 return "cube";
147 case ResourceKind::Texture1DArray:
148 return "1darray";
149 case ResourceKind::Texture2DArray:
150 return "2darray";
151 case ResourceKind::Texture2DMSArray:
152 return "2darrayMS";
153 case ResourceKind::TextureCubeArray:
154 return "cubearray";
155 case ResourceKind::TypedBuffer:
156 return "buf";
157 case ResourceKind::RawBuffer:
158 return "rawbuf";
159 case ResourceKind::StructuredBuffer:
160 return "structbuf";
161 case ResourceKind::CBuffer:
162 return "cbuffer";
163 case ResourceKind::Sampler:
164 return "sampler";
165 case ResourceKind::TBuffer:
166 return "tbuffer";
167 case ResourceKind::RTAccelerationStructure:
168 return "ras";
169 case ResourceKind::FeedbackTexture2D:
170 return "fbtex2d";
171 case ResourceKind::FeedbackTexture2DArray:
172 return "fbtex2darray";
174 llvm_unreachable("All ResourceKind enums are handled in switch");
177 void ResourceBase::printKind(ResourceKind Kind, unsigned Alignment,
178 raw_ostream &OS, bool SRV, bool HasCounter,
179 uint32_t SampleCount) {
180 switch (Kind) {
181 default:
182 OS << right_justify(getKindName(Kind), Alignment);
183 break;
185 case ResourceKind::RawBuffer:
186 case ResourceKind::StructuredBuffer:
187 if (SRV)
188 OS << right_justify("r/o", Alignment);
189 else {
190 if (!HasCounter)
191 OS << right_justify("r/w", Alignment);
192 else
193 OS << right_justify("r/w+cnt", Alignment);
195 break;
196 case ResourceKind::TypedBuffer:
197 OS << right_justify("buf", Alignment);
198 break;
199 case ResourceKind::Texture2DMS:
200 case ResourceKind::Texture2DMSArray: {
201 std::string DimName = getKindName(Kind).str();
202 if (SampleCount)
203 DimName += std::to_string(SampleCount);
204 OS << right_justify(DimName, Alignment);
205 } break;
206 case ResourceKind::CBuffer:
207 case ResourceKind::Sampler:
208 OS << right_justify("NA", Alignment);
209 break;
210 case ResourceKind::Invalid:
211 case ResourceKind::NumEntries:
212 break;
216 void ResourceBase::print(raw_ostream &OS, StringRef IDPrefix,
217 StringRef BindingPrefix) const {
218 std::string ResID = IDPrefix.str();
219 ResID += std::to_string(ID);
220 OS << right_justify(ResID, 8);
222 std::string Bind = BindingPrefix.str();
223 Bind += std::to_string(LowerBound);
224 if (Space)
225 Bind += ",space" + std::to_string(Space);
227 OS << right_justify(Bind, 15);
228 if (RangeSize != UINT_MAX)
229 OS << right_justify(std::to_string(RangeSize), 6) << "\n";
230 else
231 OS << right_justify("unbounded", 6) << "\n";
234 void UAVResource::print(raw_ostream &OS) const {
235 OS << "; " << left_justify(Name, 31);
237 OS << right_justify("UAV", 10);
239 printElementType(Shape, ExtProps.ElementType.value_or(ElementType::Invalid),
240 8, OS);
242 // FIXME: support SampleCount.
243 // See https://github.com/llvm/llvm-project/issues/58175
244 printKind(Shape, 12, OS, /*SRV*/ false, HasCounter);
245 // Print the binding part.
246 ResourceBase::print(OS, "U", "u");
249 ConstantBuffer::ConstantBuffer(uint32_t I, hlsl::FrontendResource R)
250 : ResourceBase(I, R) {}
252 void ConstantBuffer::setSize(CBufferDataLayout &DL) {
253 CBufferSizeInBytes = DL.getTypeAllocSizeInBytes(GV->getValueType());
256 void ConstantBuffer::print(raw_ostream &OS) const {
257 OS << "; " << left_justify(Name, 31);
259 OS << right_justify("cbuffer", 10);
261 printElementType(ResourceKind::CBuffer, ElementType::Invalid, 8, OS);
263 printKind(ResourceKind::CBuffer, 12, OS, /*SRV*/ false, /*HasCounter*/ false);
264 // Print the binding part.
265 ResourceBase::print(OS, "CB", "cb");
268 template <typename T> void ResourceTable<T>::print(raw_ostream &OS) const {
269 for (auto &Res : Data)
270 Res.print(OS);
273 MDNode *ResourceBase::ExtendedProperties::write(LLVMContext &Ctx) const {
274 IRBuilder<> B(Ctx);
275 SmallVector<Metadata *> Entries;
276 if (ElementType) {
277 Entries.emplace_back(
278 ConstantAsMetadata::get(B.getInt32(TypedBufferElementType)));
279 Entries.emplace_back(ConstantAsMetadata::get(
280 B.getInt32(static_cast<uint32_t>(*ElementType))));
282 if (Entries.empty())
283 return nullptr;
284 return MDNode::get(Ctx, Entries);
287 void ResourceBase::write(LLVMContext &Ctx,
288 MutableArrayRef<Metadata *> Entries) const {
289 IRBuilder<> B(Ctx);
290 Entries[0] = ConstantAsMetadata::get(B.getInt32(ID));
291 Entries[1] = ConstantAsMetadata::get(GV);
292 Entries[2] = MDString::get(Ctx, Name);
293 Entries[3] = ConstantAsMetadata::get(B.getInt32(Space));
294 Entries[4] = ConstantAsMetadata::get(B.getInt32(LowerBound));
295 Entries[5] = ConstantAsMetadata::get(B.getInt32(RangeSize));
298 MDNode *UAVResource::write() const {
299 auto &Ctx = GV->getContext();
300 IRBuilder<> B(Ctx);
301 Metadata *Entries[11];
302 ResourceBase::write(Ctx, Entries);
303 Entries[6] =
304 ConstantAsMetadata::get(B.getInt32(static_cast<uint32_t>(Shape)));
305 Entries[7] = ConstantAsMetadata::get(B.getInt1(GloballyCoherent));
306 Entries[8] = ConstantAsMetadata::get(B.getInt1(HasCounter));
307 Entries[9] = ConstantAsMetadata::get(B.getInt1(IsROV));
308 Entries[10] = ExtProps.write(Ctx);
309 return MDNode::get(Ctx, Entries);
312 MDNode *ConstantBuffer::write() const {
313 auto &Ctx = GV->getContext();
314 IRBuilder<> B(Ctx);
315 Metadata *Entries[7];
316 ResourceBase::write(Ctx, Entries);
318 Entries[6] = ConstantAsMetadata::get(B.getInt32(CBufferSizeInBytes));
319 return MDNode::get(Ctx, Entries);
322 template <typename T> MDNode *ResourceTable<T>::write(Module &M) const {
323 if (Data.empty())
324 return nullptr;
325 SmallVector<Metadata *> MDs;
326 for (auto &Res : Data)
327 MDs.emplace_back(Res.write());
329 NamedMDNode *Entry = M.getNamedMetadata(MDName);
330 if (Entry)
331 Entry->eraseFromParent();
333 return MDNode::get(M.getContext(), MDs);
336 void Resources::write(Module &M) const {
337 Metadata *ResourceMDs[4] = {nullptr, nullptr, nullptr, nullptr};
339 ResourceMDs[1] = UAVs.write(M);
341 ResourceMDs[2] = CBuffers.write(M);
343 bool HasResource = ResourceMDs[0] != nullptr || ResourceMDs[1] != nullptr ||
344 ResourceMDs[2] != nullptr || ResourceMDs[3] != nullptr;
346 if (HasResource) {
347 NamedMDNode *DXResMD = M.getOrInsertNamedMetadata("dx.resources");
348 DXResMD->addOperand(MDNode::get(M.getContext(), ResourceMDs));
351 NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs");
352 if (Entry)
353 Entry->eraseFromParent();
356 void Resources::print(raw_ostream &O) const {
357 O << ";\n"
358 << "; Resource Bindings:\n"
359 << ";\n"
360 << "; Name Type Format Dim "
361 "ID HLSL Bind Count\n"
362 << "; ------------------------------ ---------- ------- ----------- "
363 "------- -------------- ------\n";
365 CBuffers.print(O);
366 UAVs.print(O);
369 void Resources::dump() const { print(dbgs()); }