1 //===- DXContainer.cpp - DXContainer object file implementation -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Object/DXContainer.h"
10 #include "llvm/BinaryFormat/DXContainer.h"
11 #include "llvm/Object/Error.h"
12 #include "llvm/Support/Alignment.h"
13 #include "llvm/Support/FormatVariadic.h"
16 using namespace llvm::object
;
18 static Error
parseFailed(const Twine
&Msg
) {
19 return make_error
<GenericBinaryError
>(Msg
.str(), object_error::parse_failed
);
23 static Error
readStruct(StringRef Buffer
, const char *Src
, T
&Struct
) {
24 // Don't read before the beginning or past the end of the file
25 if (Src
< Buffer
.begin() || Src
+ sizeof(T
) > Buffer
.end())
26 return parseFailed("Reading structure out of file bounds");
28 memcpy(&Struct
, Src
, sizeof(T
));
29 // DXContainer is always little endian
30 if (sys::IsBigEndianHost
)
32 return Error::success();
36 static Error
readInteger(StringRef Buffer
, const char *Src
, T
&Val
,
37 Twine Str
= "structure") {
38 static_assert(std::is_integral_v
<T
>,
39 "Cannot call readInteger on non-integral type.");
40 // Don't read before the beginning or past the end of the file
41 if (Src
< Buffer
.begin() || Src
+ sizeof(T
) > Buffer
.end())
42 return parseFailed(Twine("Reading ") + Str
+ " out of file bounds");
44 // The DXContainer offset table is comprised of uint32_t values but not padded
45 // to a 64-bit boundary. So Parts may start unaligned if there is an odd
46 // number of parts and part data itself is not required to be padded.
47 if (reinterpret_cast<uintptr_t>(Src
) % alignof(T
) != 0)
48 memcpy(reinterpret_cast<char *>(&Val
), Src
, sizeof(T
));
50 Val
= *reinterpret_cast<const T
*>(Src
);
51 // DXContainer is always little endian
52 if (sys::IsBigEndianHost
)
53 sys::swapByteOrder(Val
);
54 return Error::success();
57 DXContainer::DXContainer(MemoryBufferRef O
) : Data(O
) {}
59 Error
DXContainer::parseHeader() {
60 return readStruct(Data
.getBuffer(), Data
.getBuffer().data(), Header
);
63 Error
DXContainer::parseDXILHeader(StringRef Part
) {
65 return parseFailed("More than one DXIL part is present in the file");
66 const char *Current
= Part
.begin();
67 dxbc::ProgramHeader Header
;
68 if (Error Err
= readStruct(Part
, Current
, Header
))
70 Current
+= offsetof(dxbc::ProgramHeader
, Bitcode
) + Header
.Bitcode
.Offset
;
71 DXIL
.emplace(std::make_pair(Header
, Current
));
72 return Error::success();
75 Error
DXContainer::parseShaderFlags(StringRef Part
) {
77 return parseFailed("More than one SFI0 part is present in the file");
78 uint64_t FlagValue
= 0;
79 if (Error Err
= readInteger(Part
, Part
.begin(), FlagValue
))
81 ShaderFlags
= FlagValue
;
82 return Error::success();
// Handles a HASH part: reads a dxbc::ShaderHash from the start of Part.
// NOTE(review): the embedded listing numbers skip 86, 90 and 91, so lines of
// this function are missing here — presumably the duplicate-part guard in
// front of the "More than one HASH part" return, the error return for the
// readStruct call, and the store of ReadHash. Confirm against the real file.
85 Error
DXContainer::parseHash(StringRef Part
) {
87 return parseFailed("More than one HASH part is present in the file");
88 dxbc::ShaderHash ReadHash
;
// Bounds-checked copy of the hash structure out of the part data.
89 if (Error Err
= readStruct(Part
, Part
.begin(), ReadHash
))
92 return Error::success();
95 Error
DXContainer::parsePSVInfo(StringRef Part
) {
97 return parseFailed("More than one PSV0 part is present in the file");
98 PSVInfo
= DirectX::PSVRuntimeInfo(Part
);
99 // Parsing the PSVRuntime info occurs late because we need to read data from
100 // other parts first.
101 return Error::success();
104 Error
DirectX::Signature::initialize(StringRef Part
) {
105 dxbc::ProgramSignatureHeader SigHeader
;
106 if (Error Err
= readStruct(Part
, Part
.begin(), SigHeader
))
108 size_t Size
= sizeof(dxbc::ProgramSignatureElement
) * SigHeader
.ParamCount
;
110 if (Part
.size() < Size
+ SigHeader
.FirstParamOffset
)
111 return parseFailed("Signature parameters extend beyond the part boundary");
113 Parameters
.Data
= Part
.substr(SigHeader
.FirstParamOffset
, Size
);
115 StringTableOffset
= SigHeader
.FirstParamOffset
+ static_cast<uint32_t>(Size
);
116 StringTable
= Part
.substr(SigHeader
.FirstParamOffset
+ Size
);
118 for (const auto &Param
: Parameters
) {
119 if (Param
.NameOffset
< StringTableOffset
)
120 return parseFailed("Invalid parameter name offset: name starts before "
121 "the first name offset");
122 if (Param
.NameOffset
- StringTableOffset
> StringTable
.size())
123 return parseFailed("Invalid parameter name offset: name starts after the "
124 "end of the part data");
126 return Error::success();
129 Error
DXContainer::parsePartOffsets() {
130 uint32_t LastOffset
=
131 sizeof(dxbc::Header
) + (Header
.PartCount
* sizeof(uint32_t));
132 const char *Current
= Data
.getBuffer().data() + sizeof(dxbc::Header
);
133 for (uint32_t Part
= 0; Part
< Header
.PartCount
; ++Part
) {
135 if (Error Err
= readInteger(Data
.getBuffer(), Current
, PartOffset
))
137 if (PartOffset
< LastOffset
)
140 "Part offset for part {0} begins before the previous part ends",
143 Current
+= sizeof(uint32_t);
144 if (PartOffset
>= Data
.getBufferSize())
145 return parseFailed("Part offset points beyond boundary of the file");
146 // To prevent overflow when reading the part name, we subtract the part name
147 // size from the buffer size, rather than adding to the offset. Since the
148 // file header is larger than the part header we can't reach this code
149 // unless the buffer is at least as large as a part header, so this
150 // subtraction can't underflow.
151 if (PartOffset
>= Data
.getBufferSize() - sizeof(dxbc::PartHeader::Name
))
152 return parseFailed("File not large enough to read part name");
153 PartOffsets
.push_back(PartOffset
);
156 dxbc::parsePartType(Data
.getBuffer().substr(PartOffset
, 4));
157 uint32_t PartDataStart
= PartOffset
+ sizeof(dxbc::PartHeader
);
159 if (Error Err
= readInteger(Data
.getBuffer(),
160 Data
.getBufferStart() + PartOffset
+ 4,
161 PartSize
, "part size"))
163 StringRef PartData
= Data
.getBuffer().substr(PartDataStart
, PartSize
);
164 LastOffset
= PartOffset
+ PartSize
;
166 case dxbc::PartType::DXIL
:
167 if (Error Err
= parseDXILHeader(PartData
))
170 case dxbc::PartType::SFI0
:
171 if (Error Err
= parseShaderFlags(PartData
))
174 case dxbc::PartType::HASH
:
175 if (Error Err
= parseHash(PartData
))
178 case dxbc::PartType::PSV0
:
179 if (Error Err
= parsePSVInfo(PartData
))
182 case dxbc::PartType::ISG1
:
183 if (Error Err
= InputSignature
.initialize(PartData
))
186 case dxbc::PartType::OSG1
:
187 if (Error Err
= OutputSignature
.initialize(PartData
))
190 case dxbc::PartType::PSG1
:
191 if (Error Err
= PatchConstantSignature
.initialize(PartData
))
194 case dxbc::PartType::Unknown
:
199 // Fully parsing the PSVInfo requires knowing the shader kind which we read
200 // out of the program header in the DXIL part.
203 return parseFailed("Cannot fully parse pipeline state validation "
204 "information without DXIL part.");
205 if (Error Err
= PSVInfo
->parse(DXIL
->first
.ShaderKind
))
208 return Error::success();
211 Expected
<DXContainer
> DXContainer::create(MemoryBufferRef Object
) {
212 DXContainer
Container(Object
);
213 if (Error Err
= Container
.parseHeader())
214 return std::move(Err
);
215 if (Error Err
= Container
.parsePartOffsets())
216 return std::move(Err
);
// Repositions the iterator state on the part that starts Offset bytes into
// the container's buffer: the part header is read into IteratorState.Part
// and IteratorState.Offset is set to Offset.
// NOTE(review): the embedded listing numbers skip 226, so the left-hand side
// of the StringRef expression at listing line 227 is missing here —
// presumably the part's data view is assigned to a field of IteratorState;
// confirm against the real file.
220 void DXContainer::PartIterator::updateIteratorImpl(const uint32_t Offset
) {
221 StringRef Buffer
= Container
.Data
.getBuffer();
222 const char *Current
= Buffer
.data() + Offset
;
223 // Offsets are validated during parsing, so all offsets in the container are
224 // valid and contain enough readable data to read a header.
225 cantFail(readStruct(Buffer
, Current
, IteratorState
.Part
));
// View of Part.Size bytes starting immediately after the part header.
227 StringRef(Current
+ sizeof(dxbc::PartHeader
), IteratorState
.Part
.Size
);
228 IteratorState
.Offset
= Offset
;
// Parses the pipeline state validation data held in Data. ShaderKind is
// mapped to a shader stage, which selects how the runtime info is
// byte-swapped on big-endian hosts and which mask/map tables are read.
// The visible code reads, in order: a size field, a versioned RuntimeInfo
// structure, the resource bindings, and (for versions other than 0) the
// string table, the semantic index table, the signature element arrays, the
// output/patch masks and the input/output/patch mapping tables.
// NOTE(review): the embedded listing numbers skip many lines of this function
// (error returns after each read, closing braces, several `Current`
// advances, and at least one enclosing condition around listing line 373).
// The comments below describe only what is visible; confirm the rest against
// the real file.
231 Error
DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind
) {
232 Triple::EnvironmentType ShaderStage
= dxbc::getShaderStage(ShaderKind
);
// Leading size field; the cursor then steps past it as a uint32_t.
234 const char *Current
= Data
.begin();
235 if (Error Err
= readInteger(Data
, Current
, Size
))
237 Current
+= sizeof(uint32_t);
// substr clamps to the available data, so a shorter result means the
// declared size runs past the end of the part.
239 StringRef PSVInfoData
= Data
.substr(sizeof(uint32_t), Size
);
241 if (PSVInfoData
.size() < Size
)
243 "Pipeline state data extends beyond the bounds of the part");
245 using namespace dxbc::PSV
;
247 const uint32_t PSVVersion
= getVersion();
249 // Detect the PSVVersion by looking at the size field.
// Each branch reads the RuntimeInfo layout for its version and swaps it to
// host byte order when the host is big endian.
250 if (PSVVersion
== 2) {
251 v2::RuntimeInfo Info
;
252 if (Error Err
= readStruct(PSVInfoData
, Current
, Info
))
254 if (sys::IsBigEndianHost
)
255 Info
.swapBytes(ShaderStage
);
257 } else if (PSVVersion
== 1) {
258 v1::RuntimeInfo Info
;
259 if (Error Err
= readStruct(PSVInfoData
, Current
, Info
))
261 if (sys::IsBigEndianHost
)
262 Info
.swapBytes(ShaderStage
);
264 } else if (PSVVersion
== 0) {
265 v0::RuntimeInfo Info
;
266 if (Error Err
= readStruct(PSVInfoData
, Current
, Info
))
268 if (sys::IsBigEndianHost
)
269 Info
.swapBytes(ShaderStage
);
273 "Cannot read PSV Runtime Info, unsupported PSV version.");
// Resource bindings: a count, then (when non-zero) a stride followed by
// Stride * ResourceCount bytes of binding records.
277 uint32_t ResourceCount
= 0;
278 if (Error Err
= readInteger(Data
, Current
, ResourceCount
))
280 Current
+= sizeof(uint32_t);
282 if (ResourceCount
> 0) {
283 if (Error Err
= readInteger(Data
, Current
, Resources
.Stride
))
285 Current
+= sizeof(uint32_t);
287 size_t BindingDataSize
= Resources
.Stride
* ResourceCount
;
288 Resources
.Data
= Data
.substr(Current
- Data
.begin(), BindingDataSize
);
290 if (Resources
.Data
.size() < BindingDataSize
)
292 "Resource binding data extends beyond the bounds of the part");
294 Current
+= BindingDataSize
;
// NOTE(review): presumably the `else` arm for ResourceCount == 0, giving the
// stride a default of sizeof(v2::ResourceBindInfo) — confirm.
296 Resources
.Stride
= sizeof(v2::ResourceBindInfo
);
298 // PSV version 0 ends after the resource bindings.
300 return Error::success();
302 // String table starts at a 4-byte offset.
303 Current
= reinterpret_cast<const char *>(
304 alignTo
<4>(reinterpret_cast<const uintptr_t>(Current
)));
// String table: a size that must be a multiple of 4, followed by that many
// bytes of string data.
306 uint32_t StringTableSize
= 0;
307 if (Error Err
= readInteger(Data
, Current
, StringTableSize
))
309 if (StringTableSize
% 4 != 0)
310 return parseFailed("String table misaligned");
311 Current
+= sizeof(uint32_t);
312 StringTable
= StringRef(Current
, StringTableSize
);
314 Current
+= StringTableSize
;
// Semantic index table: an entry count followed by that many integers, each
// stepped over as a uint32_t.
316 uint32_t SemanticIndexTableSize
= 0;
317 if (Error Err
= readInteger(Data
, Current
, SemanticIndexTableSize
))
319 Current
+= sizeof(uint32_t);
321 SemanticIndexTable
.reserve(SemanticIndexTableSize
);
322 for (uint32_t I
= 0; I
< SemanticIndexTableSize
; ++I
) {
324 if (Error Err
= readInteger(Data
, Current
, Index
))
326 Current
+= sizeof(uint32_t);
327 SemanticIndexTable
.push_back(Index
);
// Signature elements: one stride shared by three back-to-back arrays
// (input, output, patch-or-primitive).
330 uint8_t InputCount
= getSigInputCount();
331 uint8_t OutputCount
= getSigOutputCount();
332 uint8_t PatchOrPrimCount
= getSigPatchOrPrimCount();
334 uint32_t ElementCount
= InputCount
+ OutputCount
+ PatchOrPrimCount
;
336 if (ElementCount
> 0) {
337 if (Error Err
= readInteger(Data
, Current
, SigInputElements
.Stride
))
339 Current
+= sizeof(uint32_t);
340 // Assign the stride to all the arrays.
341 SigOutputElements
.Stride
= SigPatchOrPrimElements
.Stride
=
342 SigInputElements
.Stride
;
// All three arrays together must fit in what is left of the part.
344 if (Data
.end() - Current
< ElementCount
* SigInputElements
.Stride
)
346 "Signature elements extend beyond the size of the part");
348 size_t InputSize
= SigInputElements
.Stride
* InputCount
;
349 SigInputElements
.Data
= Data
.substr(Current
- Data
.begin(), InputSize
);
350 Current
+= InputSize
;
352 size_t OutputSize
= SigOutputElements
.Stride
* OutputCount
;
353 SigOutputElements
.Data
= Data
.substr(Current
- Data
.begin(), OutputSize
);
354 Current
+= OutputSize
;
356 size_t PSize
= SigPatchOrPrimElements
.Stride
* PatchOrPrimCount
;
357 SigPatchOrPrimElements
.Data
= Data
.substr(Current
- Data
.begin(), PSize
);
361 ArrayRef
<uint8_t> OutputVectorCounts
= getOutputVectorCounts();
362 uint8_t PatchConstOrPrimVectorCount
= getPatchConstOrPrimVectorCount();
363 uint8_t InputVectorCount
= getInputVectorCount();
// maskDwordSize: (Vector + 7) / 8, computed in 32 bits.
365 auto maskDwordSize
= [](uint8_t Vector
) {
366 return (static_cast<uint32_t>(Vector
) + 7) >> 3;
// mapTableSize: maskDwordSize(Y) * X * 4.
369 auto mapTableSize
= [maskDwordSize
](uint8_t X
, uint8_t Y
) {
370 return maskDwordSize(Y
) * X
* 4;
// Output vector masks, one view per entry of OutputVectorCounts.
374 for (uint32_t I
= 0; I
< OutputVectorCounts
.size(); ++I
) {
375 // The vector mask is one bit per component and 4 components per vector.
376 // We can compute the number of dwords required by rounding up to the next
379 maskDwordSize(static_cast<uint32_t>(OutputVectorCounts
[I
]));
380 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
381 OutputVectorMasks
[I
].Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
// Patch-constant/primitive mask, read only for hull shaders.
385 if (ShaderStage
== Triple::Hull
&& PatchConstOrPrimVectorCount
> 0) {
386 uint32_t NumDwords
= maskDwordSize(PatchConstOrPrimVectorCount
);
387 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
388 PatchOrPrimMasks
.Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
393 // Input/Output mapping table
394 for (uint32_t I
= 0; I
< OutputVectorCounts
.size(); ++I
) {
395 if (InputVectorCount
== 0 || OutputVectorCounts
[I
] == 0)
397 uint32_t NumDwords
= mapTableSize(InputVectorCount
, OutputVectorCounts
[I
]);
398 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
399 InputOutputMap
[I
].Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
403 // Hull shader: Input/Patch mapping table
404 if (ShaderStage
== Triple::Hull
&& PatchConstOrPrimVectorCount
> 0 &&
405 InputVectorCount
> 0) {
407 mapTableSize(InputVectorCount
, PatchConstOrPrimVectorCount
);
408 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
409 InputPatchMap
.Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
413 // Domain Shader: Patch/Output mapping table
414 if (ShaderStage
== Triple::Domain
&& PatchConstOrPrimVectorCount
> 0 &&
415 OutputVectorCounts
[0] > 0) {
417 mapTableSize(PatchConstOrPrimVectorCount
, OutputVectorCounts
[0]);
418 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
419 PatchOutputMap
.Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
423 return Error::success();
426 uint8_t DirectX::PSVRuntimeInfo::getSigInputCount() const {
427 if (const auto *P
= std::get_if
<dxbc::PSV::v2::RuntimeInfo
>(&BasicInfo
))
428 return P
->SigInputElements
;
429 if (const auto *P
= std::get_if
<dxbc::PSV::v1::RuntimeInfo
>(&BasicInfo
))
430 return P
->SigInputElements
;
434 uint8_t DirectX::PSVRuntimeInfo::getSigOutputCount() const {
435 if (const auto *P
= std::get_if
<dxbc::PSV::v2::RuntimeInfo
>(&BasicInfo
))
436 return P
->SigOutputElements
;
437 if (const auto *P
= std::get_if
<dxbc::PSV::v1::RuntimeInfo
>(&BasicInfo
))
438 return P
->SigOutputElements
;
442 uint8_t DirectX::PSVRuntimeInfo::getSigPatchOrPrimCount() const {
443 if (const auto *P
= std::get_if
<dxbc::PSV::v2::RuntimeInfo
>(&BasicInfo
))
444 return P
->SigPatchOrPrimElements
;
445 if (const auto *P
= std::get_if
<dxbc::PSV::v1::RuntimeInfo
>(&BasicInfo
))
446 return P
->SigPatchOrPrimElements
;