1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "DXILTranslateMetadata.h"
10 #include "DXILResource.h"
11 #include "DXILResourceAnalysis.h"
12 #include "DXILShaderFlags.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Analysis/DXILMetadataAnalysis.h"
17 #include "llvm/Analysis/DXILResource.h"
18 #include "llvm/IR/BasicBlock.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DiagnosticInfo.h"
21 #include "llvm/IR/DiagnosticPrinter.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/LLVMContext.h"
25 #include "llvm/IR/MDBuilder.h"
26 #include "llvm/IR/Metadata.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/VersionTuple.h"
32 #include "llvm/TargetParser/Triple.h"
36 using namespace llvm::dxil
;
// NOTE(review): lines of the original class are missing from this listing
// (the embedded numbering jumps from 41 to 47 and ends at 55), so the
// declarations of the `Msg` and `Mod` members, any access specifiers and the
// closing braces are not visible here.
39 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
40 /// for TranslateMetadata pass
41 class DiagnosticInfoTranslateMD
: public DiagnosticInfo
{
47 /// \p M is the module for which the diagnostic is being emitted. \p Msg is
48 /// the message to show. Note that this class does not copy this message, so
49 /// this reference must be valid for the whole life time of the diagnostic.
// The diagnostic kind is DK_Unsupported; \p Severity defaults to DS_Error.
50 DiagnosticInfoTranslateMD(const Module
&M
, const Twine
&Msg
,
51 DiagnosticSeverity Severity
= DS_Error
)
52 : DiagnosticInfo(DK_Unsupported
, Severity
), Msg(Msg
), Mod(M
) {}
// Prints "<module name>: <message>" followed by a newline.
54 void print(DiagnosticPrinter
&DP
) const override
{
55 DP
<< Mod
.getName() << ": " << Msg
<< '\n';
// Tags for the tag/value pairs that make up an entry point's property list.
// A tag's integer value is written into the emitted metadata (see
// getTagValueAsMetadata and the NumThreads handling in
// getEntryPropAsMetadata), so changing enumerator values changes the output.
// NOTE(review): the enumerator list and closing brace are missing from this
// listing (the embedded numbering jumps from 59 to 77). Enumerators referenced
// elsewhere in this file (not in declaration order): ShaderFlags, ShaderKind,
// GSState, DSState, HSState, NumThreads, AutoBindingSpace, RayPayloadSize,
// RayAttribSize, MSState, ASStateTag, WaveSize, EntryRootSig.
59 enum class EntryPropsTag
{
// Emits the module-level "dx.resources" named metadata.
//
// First creates a symbol for every binding in \p DBM, using the element
// struct of its handle type from \p DRTM. Then builds one tuple per resource
// class (SRVs, UAVs, CBuffers, Samplers) from the binding map. A class with
// no entries is represented by a null operand. The UAV and cbuffer tuples
// may instead come from the legacy \p MDResources representation; the
// asserts require that the two representations are never both populated.
// The result is one 4-operand tuple {SRVs, UAVs, CBuffers, Samplers} appended
// to "dx.resources".
// NOTE(review): some lines are absent from this listing (embedded 83, 105-107,
// 111-116 and 120-122), so the later uses of HasResources and the return
// statement are not visible. The caller (translateMetadata) treats a null
// result as "no resource metadata".
77 static NamedMDNode
*emitResourceMetadata(Module
&M
, DXILBindingMap
&DBM
,
78 DXILResourceTypeMap
&DRTM
,
79 const dxil::Resources
&MDResources
) {
80 LLVMContext
&Context
= M
.getContext();
// NOTE(review): embedded line 83 (between this loop header and its body) is
// missing from this listing; it may guard the createSymbol call — confirm.
82 for (ResourceBindingInfo
&RI
: DBM
)
84 RI
.createSymbol(M
, DRTM
[RI
.getHandleTy()].createElementStruct());
// Collect the metadata record of each binding, grouped by resource class.
86 SmallVector
<Metadata
*> SRVs
, UAVs
, CBufs
, Smps
;
87 for (const ResourceBindingInfo
&RI
: DBM
.srvs())
88 SRVs
.push_back(RI
.getAsMetadata(M
, DRTM
[RI
.getHandleTy()]));
89 for (const ResourceBindingInfo
&RI
: DBM
.uavs())
90 UAVs
.push_back(RI
.getAsMetadata(M
, DRTM
[RI
.getHandleTy()]));
91 for (const ResourceBindingInfo
&RI
: DBM
.cbuffers())
92 CBufs
.push_back(RI
.getAsMetadata(M
, DRTM
[RI
.getHandleTy()]));
93 for (const ResourceBindingInfo
&RI
: DBM
.samplers())
94 Smps
.push_back(RI
.getAsMetadata(M
, DRTM
[RI
.getHandleTy()]));
// An empty resource class becomes a null operand in the final tuple.
96 Metadata
*SRVMD
= SRVs
.empty() ? nullptr : MDNode::get(Context
, SRVs
);
97 Metadata
*UAVMD
= UAVs
.empty() ? nullptr : MDNode::get(Context
, UAVs
);
98 Metadata
*CBufMD
= CBufs
.empty() ? nullptr : MDNode::get(Context
, CBufs
);
99 Metadata
*SmpMD
= Smps
.empty() ? nullptr : MDNode::get(Context
, Smps
);
100 bool HasResources
= !DBM
.empty();
// Legacy UAV/cbuffer metadata replaces the tuple built from the binding map.
102 if (MDResources
.hasUAVs()) {
103 assert(!UAVMD
&& "Old and new UAV representations can't coexist");
104 UAVMD
= MDResources
.writeUAVs(M
);
108 if (MDResources
.hasCBuffers()) {
109 assert(!CBufMD
&& "Old and new cbuffer representations can't coexist");
110 CBufMD
= MDResources
.writeCBuffers(M
);
117 NamedMDNode
*ResourceMD
= M
.getOrInsertNamedMetadata("dx.resources");
118 ResourceMD
->addOperand(
119 MDNode::get(M
.getContext(), {SRVMD
, UAVMD
, CBufMD
, SmpMD
}));
// Returns the short shader-stage name for \p Env. The name is written as the
// first operand of "dx.shaderModel" and is used in translateMetadata's
// stage-mismatch diagnostic. Environments the switch does not handle reach
// llvm_unreachable.
// NOTE(review): embedded lines 125-129, 131-135, 137, 139-141 and 143-146 are
// missing from this listing, so the switch header, the remaining case labels
// and the returned strings are not visible here.
124 static StringRef
getShortShaderStage(Triple::EnvironmentType Env
) {
130 case Triple::Geometry
:
136 case Triple::Compute
:
138 case Triple::Library
:
142 case Triple::Amplification
:
147 llvm_unreachable("Unsupported environment for DXIL generation.");
150 static uint32_t getShaderStage(Triple::EnvironmentType Env
) {
151 return (uint32_t)Env
- (uint32_t)llvm::Triple::Pixel
;
// Builds the [tag, value] metadata pair for one entry-point property. The tag
// is emitted as an i32. The value is an i64 for ShaderFlags and an i32 for
// ShaderKind. Every other tag is not yet supported and reaches
// llvm_unreachable.
// NOTE(review): embedded lines 159, 163, 167 and 180-182 are missing from this
// listing. They presumably hold the `switch (Tag)` header, the per-case
// `break`s and the final `return MDVals;` — confirm.
154 static SmallVector
<Metadata
*>
155 getTagValueAsMetadata(EntryPropsTag Tag
, uint64_t Value
, LLVMContext
&Ctx
) {
156 SmallVector
<Metadata
*> MDVals
;
157 MDVals
.emplace_back(ConstantAsMetadata::get(
158 ConstantInt::get(Type::getInt32Ty(Ctx
), static_cast<int>(Tag
))));
160 case EntryPropsTag::ShaderFlags
:
161 MDVals
.emplace_back(ConstantAsMetadata::get(
162 ConstantInt::get(Type::getInt64Ty(Ctx
), Value
)));
164 case EntryPropsTag::ShaderKind
:
165 MDVals
.emplace_back(ConstantAsMetadata::get(
166 ConstantInt::get(Type::getInt32Ty(Ctx
), Value
)));
168 case EntryPropsTag::GSState
:
169 case EntryPropsTag::DSState
:
170 case EntryPropsTag::HSState
:
171 case EntryPropsTag::NumThreads
:
172 case EntryPropsTag::AutoBindingSpace
:
173 case EntryPropsTag::RayPayloadSize
:
174 case EntryPropsTag::RayAttribSize
:
175 case EntryPropsTag::MSState
:
176 case EntryPropsTag::ASStateTag
:
177 case EntryPropsTag::WaveSize
:
178 case EntryPropsTag::EntryRootSig
:
179 llvm_unreachable("NYI: Unhandled entry property tag");
// Builds the tag/value property list for one entry point:
//  * ShaderFlags (i64), when EntryShaderFlags is non-zero;
//  * ShaderKind, when the target profile is Library and the entry's own stage
//    is not Library;
//  * NumThreads, as an [X, Y, Z] tuple of i32, for compute entries.
// NOTE(review): the return-type line (embedded 184) and lines 212-215 are
// missing from this listing.
185 getEntryPropAsMetadata(const EntryProperties
&EP
, uint64_t EntryShaderFlags
,
186 const Triple::EnvironmentType ShaderProfile
) {
187 SmallVector
<Metadata
*> MDVals
;
// NOTE(review): EP.Entry is dereferenced here, before the
// `EP.Entry != nullptr` check below, so that check cannot protect against a
// null entry. Either the check is dead or the context should come from
// elsewhere; confirm which invariant holds.
188 LLVMContext
&Ctx
= EP
.Entry
->getContext();
189 if (EntryShaderFlags
!= 0)
190 MDVals
.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags
,
191 EntryShaderFlags
, Ctx
));
193 if (EP
.Entry
!= nullptr) {
194 // FIXME: support more props.
195 // See https://github.com/llvm/llvm-project/issues/57948.
196 // Add shader kind for lib entries.
197 if (ShaderProfile
== Triple::EnvironmentType::Library
&&
198 EP
.ShaderStage
!= Triple::EnvironmentType::Library
)
199 MDVals
.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind
,
200 getShaderStage(EP
.ShaderStage
), Ctx
));
// Compute entries carry the NumThreads tag followed by an [X, Y, Z] tuple.
202 if (EP
.ShaderStage
== Triple::EnvironmentType::Compute
) {
203 MDVals
.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
204 Type::getInt32Ty(Ctx
), static_cast<int>(EntryPropsTag::NumThreads
))));
205 Metadata
*NumThreadVals
[] = {ConstantAsMetadata::get(ConstantInt::get(
206 Type::getInt32Ty(Ctx
), EP
.NumThreadsX
)),
207 ConstantAsMetadata::get(ConstantInt::get(
208 Type::getInt32Ty(Ctx
), EP
.NumThreadsY
)),
209 ConstantAsMetadata::get(ConstantInt::get(
210 Type::getInt32Ty(Ctx
), EP
.NumThreadsZ
))};
211 MDVals
.emplace_back(MDNode::get(Ctx
, NumThreadVals
));
216 return MDNode::get(Ctx
, MDVals
);
// Builds one entry-point record from these operands, in order:
//  * the entry function symbol (null when EntryFn is null);
//  * its name ("" when EntryFn is null);
//  * the signature list;
//  * the resource list;
//  * the property list.
// NOTE(review): embedded lines 221, 224, 228 and 229 are missing from this
// listing. They include the end of the parameter list (callers pass an
// LLVMContext as the fifth argument), the declaration of MDVals, and the
// start of the statement that ends on line 230 (presumably `MDVals[0] =`).
219 MDTuple
*constructEntryMetadata(const Function
*EntryFn
, MDTuple
*Signatures
,
220 MDNode
*Resources
, MDTuple
*Properties
,
222 // Each entry point metadata record specifies:
223 // * reference to the entry point function global symbol
225 // * list of signatures
226 // * list of resources
227 // * list of tag-value pairs of shader capabilities and other properties
230 EntryFn
? ValueAsMetadata::get(const_cast<Function
*>(EntryFn
)) : nullptr;
231 MDVals
[1] = MDString::get(Ctx
, EntryFn
? EntryFn
->getName() : "");
232 MDVals
[2] = Signatures
;
233 MDVals
[3] = Resources
;
234 MDVals
[4] = Properties
;
235 return MDNode::get(Ctx
, MDVals
);
// Emits the entry-point record for \p EP. It builds the property list with
// getEntryPropAsMetadata and combines it with the entry function, the
// signature list and the resource list via constructEntryMetadata.
// NOTE(review): embedded line 239 is missing from this listing. It presumably
// declares the `MDResources` parameter used below (the caller passes
// ResourceMD in that position) — confirm.
238 static MDTuple
*emitEntryMD(const EntryProperties
&EP
, MDTuple
*Signatures
,
240 const uint64_t EntryShaderFlags
,
241 const Triple::EnvironmentType ShaderProfile
) {
242 MDTuple
*Properties
=
243 getEntryPropAsMetadata(EP
, EntryShaderFlags
, ShaderProfile
);
244 return constructEntryMetadata(EP
.Entry
, Signatures
, MDResources
, Properties
,
245 EP
.Entry
->getContext());
// Writes "dx.valver" as a single [major, minor] i32 tuple taken from
// MMDI.ValidatorVersion (a missing minor version is written as 0). Any
// operands the node already had are cleared first, so only one version tuple
// remains.
// NOTE(review): embedded lines 250-251 and 254-255 are missing from this
// listing. They contain:
//  * the body of the `ValidatorVersion.empty()` check (presumably an early
//    return, so nothing is emitted for an empty version — confirm);
//  * the declaration of MDVals;
//  * the start of the statement that ends on line 256 (presumably
//    `MDVals[0] =`).
248 static void emitValidatorVersionMD(Module
&M
, const ModuleMetadataInfo
&MMDI
) {
249 if (MMDI
.ValidatorVersion
.empty())
252 LLVMContext
&Ctx
= M
.getContext();
253 IRBuilder
<> IRB(Ctx
);
256 ConstantAsMetadata::get(IRB
.getInt32(MMDI
.ValidatorVersion
.getMajor()));
257 MDVals
[1] = ConstantAsMetadata::get(
258 IRB
.getInt32(MMDI
.ValidatorVersion
.getMinor().value_or(0)));
259 NamedMDNode
*ValVerNode
= M
.getOrInsertNamedMetadata("dx.valver");
260 // Set validator version obtained from DXIL Metadata Analysis pass
261 ValVerNode
->clearOperands();
262 ValVerNode
->addOperand(MDNode::get(Ctx
, MDVals
));
265 static void emitShaderModelVersionMD(Module
&M
,
266 const ModuleMetadataInfo
&MMDI
) {
267 LLVMContext
&Ctx
= M
.getContext();
268 IRBuilder
<> IRB(Ctx
);
270 VersionTuple SM
= MMDI
.ShaderModelVersion
;
271 SMVals
[0] = MDString::get(Ctx
, getShortShaderStage(MMDI
.ShaderProfile
));
272 SMVals
[1] = ConstantAsMetadata::get(IRB
.getInt32(SM
.getMajor()));
273 SMVals
[2] = ConstantAsMetadata::get(IRB
.getInt32(SM
.getMinor().value_or(0)));
274 NamedMDNode
*SMMDNode
= M
.getOrInsertNamedMetadata("dx.shaderModel");
275 SMMDNode
->addOperand(MDNode::get(Ctx
, SMVals
));
278 static void emitDXILVersionTupleMD(Module
&M
, const ModuleMetadataInfo
&MMDI
) {
279 LLVMContext
&Ctx
= M
.getContext();
280 IRBuilder
<> IRB(Ctx
);
281 VersionTuple DXILVer
= MMDI
.DXILVersion
;
282 Metadata
*DXILVals
[2];
283 DXILVals
[0] = ConstantAsMetadata::get(IRB
.getInt32(DXILVer
.getMajor()));
285 ConstantAsMetadata::get(IRB
.getInt32(DXILVer
.getMinor().value_or(0)));
286 NamedMDNode
*DXILVerMDNode
= M
.getOrInsertNamedMetadata("dx.version");
287 DXILVerMDNode
->addOperand(MDNode::get(Ctx
, DXILVals
));
290 static MDTuple
*emitTopLevelLibraryNode(Module
&M
, MDNode
*RMD
,
291 uint64_t ShaderFlags
) {
292 LLVMContext
&Ctx
= M
.getContext();
293 MDTuple
*Properties
= nullptr;
294 if (ShaderFlags
!= 0) {
295 SmallVector
<Metadata
*> MDVals
;
297 getTagValueAsMetadata(EntryPropsTag::ShaderFlags
, ShaderFlags
, Ctx
));
298 Properties
= MDNode::get(Ctx
, MDVals
);
300 // Library has an entry metadata with resource table metadata and all other
302 return constructEntryMetadata(nullptr, nullptr, RMD
, Properties
, Ctx
);
305 // TODO: We might need to refactor this to be more generic,
306 // in case we need more metadata to be replaced.
307 static void translateBranchMetadata(Module
&M
) {
308 for (Function
&F
: M
) {
309 for (BasicBlock
&BB
: F
) {
310 Instruction
*BBTerminatorInst
= BB
.getTerminator();
312 MDNode
*HlslControlFlowMD
=
313 BBTerminatorInst
->getMetadata("hlsl.controlflow.hint");
315 if (!HlslControlFlowMD
)
318 assert(HlslControlFlowMD
->getNumOperands() == 2 &&
319 "invalid operands for hlsl.controlflow.hint");
321 MDBuilder
MDHelper(M
.getContext());
323 mdconst::extract
<ConstantInt
>(HlslControlFlowMD
->getOperand(1));
325 SmallVector
<llvm::Metadata
*, 2> Vals(
326 ArrayRef
<Metadata
*>{MDHelper
.createString("dx.controlflow.hints"),
327 MDHelper
.createConstant(Op1
)});
329 MDNode
*MDNode
= llvm::MDNode::get(M
.getContext(), Vals
);
331 BBTerminatorInst
->setMetadata("dx.controlflow.hints", MDNode
);
332 BBTerminatorInst
->setMetadata("hlsl.controlflow.hint", nullptr);
// Emits all module-level DXIL metadata, in order:
//  1. validator version, shader model and DXIL version;
//  2. resources;
//  3. the "dx.entryPoints" list.
// For a library profile, a top-level entry carrying the combined shader flags
// of all functions is emitted first, and per-entry shader flags are left out.
// For other profiles:
//  * more than one entry is diagnosed;
//  * an entry whose stage differs from the target profile is diagnosed;
//  * each entry carries its own shader flags.
// Signatures are always emitted as null for now (see the FIXME below).
// NOTE(review): several lines are missing from this listing (e.g. embedded
// 351, 380-381, 386-388, 390), including the declaration of ResourceMD and
// part of the stage-mismatch diagnostic text.
// NOTE(review): in the visible lines, `IRB` is never used and `Ctx` is used
// only to construct `IRB` — confirm whether they can be removed.
337 static void translateMetadata(Module
&M
, DXILBindingMap
&DBM
,
338 DXILResourceTypeMap
&DRTM
,
339 const Resources
&MDResources
,
340 const ModuleShaderFlags
&ShaderFlags
,
341 const ModuleMetadataInfo
&MMDI
) {
342 LLVMContext
&Ctx
= M
.getContext();
343 IRBuilder
<> IRB(Ctx
);
344 SmallVector
<MDNode
*> EntryFnMDNodes
;
346 emitValidatorVersionMD(M
, MMDI
);
347 emitShaderModelVersionMD(M
, MMDI
);
348 emitDXILVersionTupleMD(M
, MMDI
);
349 NamedMDNode
*NamedResourceMD
=
350 emitResourceMetadata(M
, DBM
, DRTM
, MDResources
);
// The resource list is operand 0 of "dx.resources", or null if none was
// emitted.
352 (NamedResourceMD
!= nullptr) ? NamedResourceMD
->getOperand(0) : nullptr;
353 // FIXME: Add support to construct Signatures
354 // See https://github.com/llvm/llvm-project/issues/57928
355 MDTuple
*Signatures
= nullptr;
357 if (MMDI
.ShaderProfile
== Triple::EnvironmentType::Library
) {
358 // Get the combined shader flag mask of all functions in the library to be
359 // used as shader flags mask value associated with top-level library entry
361 uint64_t CombinedMask
= ShaderFlags
.getCombinedFlags();
362 EntryFnMDNodes
.emplace_back(
363 emitTopLevelLibraryNode(M
, ResourceMD
, CombinedMask
));
364 } else if (MMDI
.EntryPropertyVec
.size() > 1) {
365 M
.getContext().diagnose(DiagnosticInfoTranslateMD(
366 M
, "Non-library shader: One and only one entry expected"));
// One entry record per entry point recorded by the metadata analysis.
369 for (const EntryProperties
&EntryProp
: MMDI
.EntryPropertyVec
) {
370 const ComputedShaderFlags
&EntrySFMask
=
371 ShaderFlags
.getFunctionFlags(EntryProp
.Entry
);
373 // If ShaderProfile is Library, mask is already consolidated in the
374 // top-level library node. Hence it is not emitted.
375 uint64_t EntryShaderFlags
= 0;
376 if (MMDI
.ShaderProfile
!= Triple::EnvironmentType::Library
) {
377 EntryShaderFlags
= EntrySFMask
;
378 if (EntryProp
.ShaderStage
!= MMDI
.ShaderProfile
) {
379 M
.getContext().diagnose(DiagnosticInfoTranslateMD(
382 Twine(getShortShaderStage(EntryProp
.ShaderStage
) +
383 "' for entry '" + Twine(EntryProp
.Entry
->getName()) +
384 "' different from specified target profile '" +
385 Twine(Triple::getEnvironmentTypeName(MMDI
.ShaderProfile
) +
389 EntryFnMDNodes
.emplace_back(emitEntryMD(EntryProp
, Signatures
, ResourceMD
,
391 MMDI
.ShaderProfile
));
394 NamedMDNode
*EntryPointsNamedMD
=
395 M
.getOrInsertNamedMetadata("dx.entryPoints");
396 for (auto *Entry
: EntryFnMDNodes
)
397 EntryPointsNamedMD
->addOperand(Entry
);
400 PreservedAnalyses
DXILTranslateMetadata::run(Module
&M
,
401 ModuleAnalysisManager
&MAM
) {
402 DXILBindingMap
&DBM
= MAM
.getResult
<DXILResourceBindingAnalysis
>(M
);
403 DXILResourceTypeMap
&DRTM
= MAM
.getResult
<DXILResourceTypeAnalysis
>(M
);
404 const dxil::Resources
&MDResources
= MAM
.getResult
<DXILResourceMDAnalysis
>(M
);
405 const ModuleShaderFlags
&ShaderFlags
= MAM
.getResult
<ShaderFlagsAnalysis
>(M
);
406 const dxil::ModuleMetadataInfo MMDI
= MAM
.getResult
<DXILMetadataAnalysis
>(M
);
408 translateMetadata(M
, DBM
, DRTM
, MDResources
, ShaderFlags
, MMDI
);
409 translateBranchMetadata(M
);
411 return PreservedAnalyses::all();
// Legacy-pass-manager wrapper that runs translateMetadata and
// translateBranchMetadata. It requires the resource type, resource binding,
// legacy resource metadata, shader flags and module metadata analyses.
// NOTE(review): lines are missing from this listing (e.g. embedded 416, 432,
// and everything after 447), so runOnModule's return value and the class's
// closing braces are not visible here.
415 class DXILTranslateMetadataLegacy
: public ModulePass
{
417 static char ID
; // Pass identification, replacement for typeid
418 explicit DXILTranslateMetadataLegacy() : ModulePass(ID
) {}
420 StringRef
getPassName() const override
{ return "DXIL Translate Metadata"; }
// Declares the analyses this pass consumes and those it leaves valid.
422 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
423 AU
.addRequired
<DXILResourceTypeWrapperPass
>();
424 AU
.addRequired
<DXILResourceBindingWrapperPass
>();
425 AU
.addRequired
<DXILResourceMDWrapper
>();
426 AU
.addRequired
<ShaderFlagsAnalysisWrapper
>();
427 AU
.addRequired
<DXILMetadataAnalysisWrapperPass
>();
428 AU
.addPreserved
<DXILResourceBindingWrapperPass
>();
429 AU
.addPreserved
<DXILResourceMDWrapper
>();
430 AU
.addPreserved
<DXILMetadataAnalysisWrapperPass
>();
431 AU
.addPreserved
<ShaderFlagsAnalysisWrapper
>();
434 bool runOnModule(Module
&M
) override
{
435 DXILBindingMap
&DBM
=
436 getAnalysis
<DXILResourceBindingWrapperPass
>().getBindingMap();
437 DXILResourceTypeMap
&DRTM
=
438 getAnalysis
<DXILResourceTypeWrapperPass
>().getResourceTypeMap();
439 const dxil::Resources
&MDResources
=
440 getAnalysis
<DXILResourceMDWrapper
>().getDXILResource();
441 const ModuleShaderFlags
&ShaderFlags
=
442 getAnalysis
<ShaderFlagsAnalysisWrapper
>().getShaderFlags();
// NOTE(review): MMDI is declared by value, so the module metadata is copied.
// A const reference (as used for the other analysis results) would avoid the
// copy if getModuleMetadata() returns a reference — confirm.
443 dxil::ModuleMetadataInfo MMDI
=
444 getAnalysis
<DXILMetadataAnalysisWrapperPass
>().getModuleMetadata();
446 translateMetadata(M
, DBM
, DRTM
, MDResources
, ShaderFlags
, MMDI
);
447 translateBranchMetadata(M
);
454 char DXILTranslateMetadataLegacy::ID
= 0;
456 ModulePass
*llvm::createDXILTranslateMetadataLegacyPass() {
457 return new DXILTranslateMetadataLegacy();
460 INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy
, "dxil-translate-metadata",
461 "DXIL Translate Metadata", false, false)
462 INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass
)
463 INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper
)
464 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper
)
465 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass
)
466 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy
, "dxil-translate-metadata",
467 "DXIL Translate Metadata", false, false)