1 //===- DXContainer.cpp - DXContainer object file implementation -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Object/DXContainer.h"
10 #include "llvm/BinaryFormat/DXContainer.h"
11 #include "llvm/Object/Error.h"
12 #include "llvm/Support/Alignment.h"
13 #include "llvm/Support/FormatVariadic.h"
16 using namespace llvm::object
;
18 static Error
parseFailed(const Twine
&Msg
) {
19 return make_error
<GenericBinaryError
>(Msg
.str(), object_error::parse_failed
);
23 static Error
readStruct(StringRef Buffer
, const char *Src
, T
&Struct
) {
24 // Don't read before the beginning or past the end of the file
25 if (Src
< Buffer
.begin() || Src
+ sizeof(T
) > Buffer
.end())
26 return parseFailed("Reading structure out of file bounds");
28 memcpy(&Struct
, Src
, sizeof(T
));
29 // DXContainer is always little endian
30 if (sys::IsBigEndianHost
)
32 return Error::success();
36 static Error
readInteger(StringRef Buffer
, const char *Src
, T
&Val
,
37 Twine Str
= "structure") {
38 static_assert(std::is_integral_v
<T
>,
39 "Cannot call readInteger on non-integral type.");
40 // Don't read before the beginning or past the end of the file
41 if (Src
< Buffer
.begin() || Src
+ sizeof(T
) > Buffer
.end())
42 return parseFailed(Twine("Reading ") + Str
+ " out of file bounds");
44 // The DXContainer offset table is comprised of uint32_t values but not padded
45 // to a 64-bit boundary. So Parts may start unaligned if there is an odd
46 // number of parts and part data itself is not required to be padded.
47 if (reinterpret_cast<uintptr_t>(Src
) % alignof(T
) != 0)
48 memcpy(reinterpret_cast<char *>(&Val
), Src
, sizeof(T
));
50 Val
= *reinterpret_cast<const T
*>(Src
);
51 // DXContainer is always little endian
52 if (sys::IsBigEndianHost
)
53 sys::swapByteOrder(Val
);
54 return Error::success();
57 DXContainer::DXContainer(MemoryBufferRef O
) : Data(O
) {}
59 Error
DXContainer::parseHeader() {
60 return readStruct(Data
.getBuffer(), Data
.getBuffer().data(), Header
);
63 Error
DXContainer::parseDXILHeader(StringRef Part
) {
65 return parseFailed("More than one DXIL part is present in the file");
66 const char *Current
= Part
.begin();
67 dxbc::ProgramHeader Header
;
68 if (Error Err
= readStruct(Part
, Current
, Header
))
70 Current
+= offsetof(dxbc::ProgramHeader
, Bitcode
) + Header
.Bitcode
.Offset
;
71 DXIL
.emplace(std::make_pair(Header
, Current
));
72 return Error::success();
75 Error
DXContainer::parseShaderFlags(StringRef Part
) {
77 return parseFailed("More than one SFI0 part is present in the file");
78 uint64_t FlagValue
= 0;
79 if (Error Err
= readInteger(Part
, Part
.begin(), FlagValue
))
81 ShaderFlags
= FlagValue
;
82 return Error::success();
85 Error
DXContainer::parseHash(StringRef Part
) {
87 return parseFailed("More than one HASH part is present in the file");
88 dxbc::ShaderHash ReadHash
;
89 if (Error Err
= readStruct(Part
, Part
.begin(), ReadHash
))
92 return Error::success();
95 Error
DXContainer::parsePSVInfo(StringRef Part
) {
97 return parseFailed("More than one PSV0 part is present in the file");
98 PSVInfo
= DirectX::PSVRuntimeInfo(Part
);
99 // Parsing the PSVRuntime info occurs late because we need to read data from
100 // other parts first.
101 return Error::success();
104 Error
DXContainer::parsePartOffsets() {
105 uint32_t LastOffset
=
106 sizeof(dxbc::Header
) + (Header
.PartCount
* sizeof(uint32_t));
107 const char *Current
= Data
.getBuffer().data() + sizeof(dxbc::Header
);
108 for (uint32_t Part
= 0; Part
< Header
.PartCount
; ++Part
) {
110 if (Error Err
= readInteger(Data
.getBuffer(), Current
, PartOffset
))
112 if (PartOffset
< LastOffset
)
115 "Part offset for part {0} begins before the previous part ends",
118 Current
+= sizeof(uint32_t);
119 if (PartOffset
>= Data
.getBufferSize())
120 return parseFailed("Part offset points beyond boundary of the file");
121 // To prevent overflow when reading the part name, we subtract the part name
122 // size from the buffer size, rather than adding to the offset. Since the
123 // file header is larger than the part header we can't reach this code
124 // unless the buffer is at least as large as a part header, so this
125 // subtraction can't underflow.
126 if (PartOffset
>= Data
.getBufferSize() - sizeof(dxbc::PartHeader::Name
))
127 return parseFailed("File not large enough to read part name");
128 PartOffsets
.push_back(PartOffset
);
131 dxbc::parsePartType(Data
.getBuffer().substr(PartOffset
, 4));
132 uint32_t PartDataStart
= PartOffset
+ sizeof(dxbc::PartHeader
);
134 if (Error Err
= readInteger(Data
.getBuffer(),
135 Data
.getBufferStart() + PartOffset
+ 4,
136 PartSize
, "part size"))
138 StringRef PartData
= Data
.getBuffer().substr(PartDataStart
, PartSize
);
139 LastOffset
= PartOffset
+ PartSize
;
141 case dxbc::PartType::DXIL
:
142 if (Error Err
= parseDXILHeader(PartData
))
145 case dxbc::PartType::SFI0
:
146 if (Error Err
= parseShaderFlags(PartData
))
149 case dxbc::PartType::HASH
:
150 if (Error Err
= parseHash(PartData
))
153 case dxbc::PartType::PSV0
:
154 if (Error Err
= parsePSVInfo(PartData
))
157 case dxbc::PartType::Unknown
:
162 // Fully parsing the PSVInfo requires knowing the shader kind which we read
163 // out of the program header in the DXIL part.
166 return parseFailed("Cannot fully parse pipeline state validation "
167 "information without DXIL part.");
168 if (Error Err
= PSVInfo
->parse(DXIL
->first
.ShaderKind
))
171 return Error::success();
174 Expected
<DXContainer
> DXContainer::create(MemoryBufferRef Object
) {
175 DXContainer
Container(Object
);
176 if (Error Err
= Container
.parseHeader())
177 return std::move(Err
);
178 if (Error Err
= Container
.parsePartOffsets())
179 return std::move(Err
);
183 void DXContainer::PartIterator::updateIteratorImpl(const uint32_t Offset
) {
184 StringRef Buffer
= Container
.Data
.getBuffer();
185 const char *Current
= Buffer
.data() + Offset
;
186 // Offsets are validated during parsing, so all offsets in the container are
187 // valid and contain enough readable data to read a header.
188 cantFail(readStruct(Buffer
, Current
, IteratorState
.Part
));
190 StringRef(Current
+ sizeof(dxbc::PartHeader
), IteratorState
.Part
.Size
);
191 IteratorState
.Offset
= Offset
;
194 Error
DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind
) {
195 Triple::EnvironmentType ShaderStage
= dxbc::getShaderStage(ShaderKind
);
197 const char *Current
= Data
.begin();
198 if (Error Err
= readInteger(Data
, Current
, Size
))
200 Current
+= sizeof(uint32_t);
202 StringRef PSVInfoData
= Data
.substr(sizeof(uint32_t), Size
);
204 if (PSVInfoData
.size() < Size
)
206 "Pipeline state data extends beyond the bounds of the part");
208 using namespace dxbc::PSV
;
210 const uint32_t PSVVersion
= getVersion();
212 // Detect the PSVVersion by looking at the size field.
213 if (PSVVersion
== 2) {
214 v2::RuntimeInfo Info
;
215 if (Error Err
= readStruct(PSVInfoData
, Current
, Info
))
217 if (sys::IsBigEndianHost
)
218 Info
.swapBytes(ShaderStage
);
220 } else if (PSVVersion
== 1) {
221 v1::RuntimeInfo Info
;
222 if (Error Err
= readStruct(PSVInfoData
, Current
, Info
))
224 if (sys::IsBigEndianHost
)
225 Info
.swapBytes(ShaderStage
);
227 } else if (PSVVersion
== 0) {
228 v0::RuntimeInfo Info
;
229 if (Error Err
= readStruct(PSVInfoData
, Current
, Info
))
231 if (sys::IsBigEndianHost
)
232 Info
.swapBytes(ShaderStage
);
236 "Cannot read PSV Runtime Info, unsupported PSV version.");
240 uint32_t ResourceCount
= 0;
241 if (Error Err
= readInteger(Data
, Current
, ResourceCount
))
243 Current
+= sizeof(uint32_t);
245 if (ResourceCount
> 0) {
246 if (Error Err
= readInteger(Data
, Current
, Resources
.Stride
))
248 Current
+= sizeof(uint32_t);
250 size_t BindingDataSize
= Resources
.Stride
* ResourceCount
;
251 Resources
.Data
= Data
.substr(Current
- Data
.begin(), BindingDataSize
);
253 if (Resources
.Data
.size() < BindingDataSize
)
255 "Resource binding data extends beyond the bounds of the part");
257 Current
+= BindingDataSize
;
259 Resources
.Stride
= sizeof(v2::ResourceBindInfo
);
261 // PSV version 0 ends after the resource bindings.
263 return Error::success();
265 // String table starts at a 4-byte offset.
266 Current
= reinterpret_cast<const char *>(
267 alignTo
<4>(reinterpret_cast<const uintptr_t>(Current
)));
269 uint32_t StringTableSize
= 0;
270 if (Error Err
= readInteger(Data
, Current
, StringTableSize
))
272 if (StringTableSize
% 4 != 0)
273 return parseFailed("String table misaligned");
274 Current
+= sizeof(uint32_t);
275 StringTable
= StringRef(Current
, StringTableSize
);
277 Current
+= StringTableSize
;
279 uint32_t SemanticIndexTableSize
= 0;
280 if (Error Err
= readInteger(Data
, Current
, SemanticIndexTableSize
))
282 Current
+= sizeof(uint32_t);
284 SemanticIndexTable
.reserve(SemanticIndexTableSize
);
285 for (uint32_t I
= 0; I
< SemanticIndexTableSize
; ++I
) {
287 if (Error Err
= readInteger(Data
, Current
, Index
))
289 Current
+= sizeof(uint32_t);
290 SemanticIndexTable
.push_back(Index
);
293 uint8_t InputCount
= getSigInputCount();
294 uint8_t OutputCount
= getSigOutputCount();
295 uint8_t PatchOrPrimCount
= getSigPatchOrPrimCount();
297 uint32_t ElementCount
= InputCount
+ OutputCount
+ PatchOrPrimCount
;
299 if (ElementCount
> 0) {
300 if (Error Err
= readInteger(Data
, Current
, SigInputElements
.Stride
))
302 Current
+= sizeof(uint32_t);
303 // Assign the stride to all the arrays.
304 SigOutputElements
.Stride
= SigPatchOrPrimElements
.Stride
=
305 SigInputElements
.Stride
;
307 if (Data
.end() - Current
< ElementCount
* SigInputElements
.Stride
)
309 "Signature elements extend beyond the size of the part");
311 size_t InputSize
= SigInputElements
.Stride
* InputCount
;
312 SigInputElements
.Data
= Data
.substr(Current
- Data
.begin(), InputSize
);
313 Current
+= InputSize
;
315 size_t OutputSize
= SigOutputElements
.Stride
* OutputCount
;
316 SigOutputElements
.Data
= Data
.substr(Current
- Data
.begin(), OutputSize
);
317 Current
+= OutputSize
;
319 size_t PSize
= SigPatchOrPrimElements
.Stride
* PatchOrPrimCount
;
320 SigPatchOrPrimElements
.Data
= Data
.substr(Current
- Data
.begin(), PSize
);
324 ArrayRef
<uint8_t> OutputVectorCounts
= getOutputVectorCounts();
325 uint8_t PatchConstOrPrimVectorCount
= getPatchConstOrPrimVectorCount();
326 uint8_t InputVectorCount
= getInputVectorCount();
328 auto maskDwordSize
= [](uint8_t Vector
) {
329 return (static_cast<uint32_t>(Vector
) + 7) >> 3;
332 auto mapTableSize
= [maskDwordSize
](uint8_t X
, uint8_t Y
) {
333 return maskDwordSize(Y
) * X
* 4;
337 for (uint32_t I
= 0; I
< OutputVectorCounts
.size(); ++I
) {
338 // The vector mask is one bit per component and 4 components per vector.
339 // We can compute the number of dwords required by rounding up to the next
342 maskDwordSize(static_cast<uint32_t>(OutputVectorCounts
[I
]));
343 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
344 OutputVectorMasks
[I
].Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
348 if (ShaderStage
== Triple::Hull
&& PatchConstOrPrimVectorCount
> 0) {
349 uint32_t NumDwords
= maskDwordSize(PatchConstOrPrimVectorCount
);
350 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
351 PatchOrPrimMasks
.Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
356 // Input/Output mapping table
357 for (uint32_t I
= 0; I
< OutputVectorCounts
.size(); ++I
) {
358 if (InputVectorCount
== 0 || OutputVectorCounts
[I
] == 0)
360 uint32_t NumDwords
= mapTableSize(InputVectorCount
, OutputVectorCounts
[I
]);
361 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
362 InputOutputMap
[I
].Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
366 // Hull shader: Input/Patch mapping table
367 if (ShaderStage
== Triple::Hull
&& PatchConstOrPrimVectorCount
> 0 &&
368 InputVectorCount
> 0) {
370 mapTableSize(InputVectorCount
, PatchConstOrPrimVectorCount
);
371 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
372 InputPatchMap
.Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
376 // Domain Shader: Patch/Output mapping table
377 if (ShaderStage
== Triple::Domain
&& PatchConstOrPrimVectorCount
> 0 &&
378 OutputVectorCounts
[0] > 0) {
380 mapTableSize(PatchConstOrPrimVectorCount
, OutputVectorCounts
[0]);
381 size_t NumBytes
= NumDwords
* sizeof(uint32_t);
382 PatchOutputMap
.Data
= Data
.substr(Current
- Data
.begin(), NumBytes
);
386 return Error::success();
389 uint8_t DirectX::PSVRuntimeInfo::getSigInputCount() const {
390 if (const auto *P
= std::get_if
<dxbc::PSV::v2::RuntimeInfo
>(&BasicInfo
))
391 return P
->SigInputElements
;
392 if (const auto *P
= std::get_if
<dxbc::PSV::v1::RuntimeInfo
>(&BasicInfo
))
393 return P
->SigInputElements
;
397 uint8_t DirectX::PSVRuntimeInfo::getSigOutputCount() const {
398 if (const auto *P
= std::get_if
<dxbc::PSV::v2::RuntimeInfo
>(&BasicInfo
))
399 return P
->SigOutputElements
;
400 if (const auto *P
= std::get_if
<dxbc::PSV::v1::RuntimeInfo
>(&BasicInfo
))
401 return P
->SigOutputElements
;
405 uint8_t DirectX::PSVRuntimeInfo::getSigPatchOrPrimCount() const {
406 if (const auto *P
= std::get_if
<dxbc::PSV::v2::RuntimeInfo
>(&BasicInfo
))
407 return P
->SigPatchOrPrimElements
;
408 if (const auto *P
= std::get_if
<dxbc::PSV::v1::RuntimeInfo
>(&BasicInfo
))
409 return P
->SigPatchOrPrimElements
;