1 //===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // DXContainerGlobalsPass implementation.
11 //===----------------------------------------------------------------------===//
13 #include "DXILShaderFlags.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/BinaryFormat/DXContainer.h"
19 #include "llvm/CodeGen/Passes.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/InitializePasses.h"
23 #include "llvm/MC/DXContainerPSVInfo.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Support/MD5.h"
26 #include "llvm/Transforms/Utils/ModuleUtils.h"
29 using namespace llvm::dxil
;
30 using namespace llvm::mcdxbc
;
// Module pass (derives from llvm::ModulePass) that declares the helpers used
// to build the DXContainer part globals: feature flags, shader hash,
// signatures and pipeline state validation info.
// NOTE(review): some original lines of this class (access specifiers, closing
// braces) are not visible in this view; the code tokens are left untouched.
33 class DXContainerGlobals
: public llvm::ModulePass
{
// Creates a global named Name in M, initialized with Content and placed in
// section SectionName.
35 GlobalVariable
*buildContainerGlobal(Module
&M
, Constant
*Content
,
36 StringRef Name
, StringRef SectionName
);
// Builds the "dx.sfi0" global (section "SFI0") from the shader flags analysis.
37 GlobalVariable
*getFeatureFlags(Module
&M
);
// Builds the "dx.hash" global (section "HASH") from an MD5 digest of the
// "dx.dxil" global's initializer data.
38 GlobalVariable
*computeShaderHash(Module
&M
);
// Builds one signature global with the given Name and SectionName.
39 GlobalVariable
*buildSignature(Module
&M
, Signature
&Sig
, StringRef Name
,
40 StringRef SectionName
);
// Appends the "dx.isg1" and "dx.osg1" signature globals to Globals.
41 void addSignature(Module
&M
, SmallVector
<GlobalValue
*> &Globals
);
// Appends the "dx.psv0" global (section "PSV0") to Globals.
42 void addPipelineStateValidationInfo(Module
&M
,
43 SmallVector
<GlobalValue
*> &Globals
);
46 static char ID
; // Pass identification, replacement for typeid
// The constructor registers this pass with the global PassRegistry.
47 DXContainerGlobals() : ModulePass(ID
) {
48 initializeDXContainerGlobalsPass(*PassRegistry::getPassRegistry());
// Human-readable pass name.
51 StringRef
getPassName() const override
{
52 return "DXContainer Global Emitter";
// Pass entry point; defined out of line.
55 bool runOnModule(Module
&M
) override
;
// Declares ShaderFlagsAnalysisWrapper as a required analysis; getFeatureFlags
// queries it through getAnalysis.
57 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
59 AU
.addRequired
<ShaderFlagsAnalysisWrapper
>();
// Pass entry point: builds every container part global for M, collects them in
// Globals, and hands the list to appendToCompilerUsed.
// NOTE(review): the return statement and closing brace are not visible in this
// view.
65 bool DXContainerGlobals::runOnModule(Module
&M
) {
66 llvm::SmallVector
<GlobalValue
*> Globals
;
// Order of collection: feature flags, shader hash, signatures, PSV info.
67 Globals
.push_back(getFeatureFlags(M
));
68 Globals
.push_back(computeShaderHash(M
));
69 addSignature(M
, Globals
);
70 addPipelineStateValidationInfo(M
, Globals
);
// Record the globals in the module's compiler-used list (presumably so later
// passes do not drop them -- confirm against appendToCompilerUsed's docs).
71 appendToCompilerUsed(M
, Globals
);
// Builds the feature-flags global for M: a 64-bit integer constant stored in a
// global named "dx.sfi0" in section "SFI0".
75 GlobalVariable
*DXContainerGlobals::getFeatureFlags(Module
&M
) {
// The value comes from the ShaderFlagsAnalysisWrapper analysis result.
// NOTE(review): the remainder of this expression (the accessor chain after
// getAnalysis<...>()) is not visible in this view.
76 const uint64_t FeatureFlags
=
77 static_cast<uint64_t>(getAnalysis
<ShaderFlagsAnalysisWrapper
>()
// Wrap the flag bits in a 64-bit ConstantInt.
81 Constant
*FeatureFlagsConstant
=
82 ConstantInt::get(M
.getContext(), APInt(64, FeatureFlags
));
83 return buildContainerGlobal(M
, FeatureFlagsConstant
, "dx.sfi0", "SFI0");
// Builds the shader-hash global for M: an MD5 digest of the "dx.dxil" global's
// raw initializer data, packed into a dxbc::ShaderHash and stored in a global
// named "dx.hash" in section "HASH".
// NOTE(review): the declarations of DXILConstant and Digest are not visible in
// this view.
86 GlobalVariable
*DXContainerGlobals::computeShaderHash(Module
&M
) {
// NOTE(review): the getNamedGlobal result is dereferenced without a null
// check, so this assumes the "dx.dxil" global already exists -- confirm.
88 cast
<ConstantDataArray
>(M
.getNamedGlobal("dx.dxil")->getInitializer());
90 Digest
.update(DXILConstant
->getRawDataValues());
91 MD5::MD5Result Result
= Digest
.final();
// Start from zeroed flags and digest.
93 dxbc::ShaderHash HashData
= {0, {0}};
94 // The Hash's IncludesSource flag gets set whenever the hashed shader includes
// at least one debug compile unit (this is what the check below tests).
96 if (M
.debug_compile_units_begin() != M
.debug_compile_units_end())
97 HashData
.Flags
= static_cast<uint32_t>(dxbc::HashFlags::IncludesSource
);
// Copy the 16 digest bytes into the hash record.
99 memcpy(reinterpret_cast<void *>(&HashData
.Digest
), Result
.data(), 16);
// On big-endian hosts the record is byte-swapped before being emitted
// (presumably so the stored bytes are little-endian -- confirm).
100 if (sys::IsBigEndianHost
)
101 HashData
.swapBytes();
// View the whole ShaderHash record as raw bytes and emit it as a data array.
102 StringRef
Data(reinterpret_cast<char *>(&HashData
), sizeof(dxbc::ShaderHash
));
104 Constant
*ModuleConstant
=
105 ConstantDataArray::get(M
.getContext(), arrayRefFromStringRef(Data
));
106 return buildContainerGlobal(M
, ModuleConstant
, "dx.hash", "HASH");
// Creates a new global variable in M that holds Content.
//   M           - module that receives the global.
//   Content     - initializer; its type becomes the global's value type.
//   Name        - name given to the global.
//   SectionName - section assigned to the global.
// NOTE(review): the return statement is not visible in this view; callers use
// the result as the created global.
109 GlobalVariable
*DXContainerGlobals::buildContainerGlobal(
110 Module
&M
, Constant
*Content
, StringRef Name
, StringRef SectionName
) {
// Private linkage; the literal `true` is the constructor's isConstant
// argument.
111 auto *GV
= new llvm::GlobalVariable(
112 M
, Content
->getType(), true, GlobalValue::PrivateLinkage
, Content
, Name
);
113 GV
->setSection(SectionName
);
// Every container part global is given 4-byte alignment.
114 GV
->setAlignment(Align(4));
// Builds a signature global: the bytes accumulated in Data are turned into a
// string constant (no trailing null) and wrapped by buildContainerGlobal
// under the given Name and SectionName.
// NOTE(review): the line declaring the Name parameter, the statement that
// fills OS (presumably by writing Sig -- confirm), and the declaration of the
// local named Constant are not visible in this view.
118 GlobalVariable
*DXContainerGlobals::buildSignature(Module
&M
, Signature
&Sig
,
120 StringRef SectionName
) {
// OS streams into the Data buffer.
121 SmallString
<256> Data
;
122 raw_svector_ostream
OS(Data
);
// AddNull is false, so the constant holds exactly the bytes in Data.
125 ConstantDataArray::getString(M
.getContext(), Data
, /*AddNull*/ false);
126 return buildContainerGlobal(M
, Constant
, Name
, SectionName
);
// Appends the input ("dx.isg1", section "ISG1") and output ("dx.osg1", section
// "OSG1") signature globals to Globals.
// NOTE(review): the declarations of InputSig and OutputSig are not visible in
// this view.
129 void DXContainerGlobals::addSignature(Module
&M
,
130 SmallVector
<GlobalValue
*> &Globals
) {
131 // FIXME: support graphics shader.
132 // see issue https://github.com/llvm/llvm-project/issues/90504.
135 Globals
.emplace_back(buildSignature(M
, InputSig
, "dx.isg1", "ISG1"));
138 Globals
.emplace_back(buildSignature(M
, OutputSig
, "dx.osg1", "OSG1"));
// Fills in pipeline state validation (PSV) data for M and appends the
// resulting "dx.psv0" global (section "PSV0") to Globals.
// NOTE(review): the declaration of PSV, the statement that fills OS, and the
// declaration of the local named Constant are not visible in this view.
141 void DXContainerGlobals::addPipelineStateValidationInfo(
142 Module
&M
, SmallVector
<GlobalValue
*> &Globals
) {
// OS streams into the Data buffer.
143 SmallString
<256> Data
;
144 raw_svector_ostream
OS(Data
);
146 Triple
TT(M
.getTargetTriple());
// The wave lane count range is left fully open: 0 .. UINT32_MAX.
147 PSV
.BaseData
.MinimumWaveLaneCount
= 0;
148 PSV
.BaseData
.MaximumWaveLaneCount
= std::numeric_limits
<uint32_t>::max();
// The shader stage is the triple's environment expressed as an offset from
// Triple::Pixel.
// NOTE(review): assumes the shader-stage environments are contiguous starting
// at Pixel and that the offset matches the PSV stage numbering -- confirm.
149 PSV
.BaseData
.ShaderStage
=
150 static_cast<uint8_t>(TT
.getEnvironment() - Triple::Pixel
);
152 // Hardcoded values here to unblock loading the shader into D3D.
154 // TODO: Lots more stuff to do here!
156 // See issue https://github.com/llvm/llvm-project/issues/96674.
157 PSV
.BaseData
.NumThreadsX
= 1;
158 PSV
.BaseData
.NumThreadsY
= 1;
159 PSV
.BaseData
.NumThreadsZ
= 1;
// The entry point name is hardcoded as well.
160 PSV
.EntryName
= "main";
162 PSV
.finalize(TT
.getEnvironment());
// AddNull is false, so the constant holds exactly the bytes in Data.
165 ConstantDataArray::getString(M
.getContext(), Data
, /*AddNull*/ false);
166 Globals
.emplace_back(buildContainerGlobal(M
, Constant
, "dx.psv0", "PSV0"));
// Definition of the pass's static ID member.
169 char DXContainerGlobals::ID
= 0;
// Legacy pass registration under the command-line name "dxil-globals", with
// ShaderFlagsAnalysisWrapper listed as a dependency.
// NOTE(review): the trailing `false, true` arguments are presumably the
// macros' (cfg-only, is-analysis) flags -- confirm against PassSupport.h.
170 INITIALIZE_PASS_BEGIN(DXContainerGlobals
, "dxil-globals",
171 "DXContainer Global Emitter", false, true)
172 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper
)
173 INITIALIZE_PASS_END(DXContainerGlobals
, "dxil-globals",
174 "DXContainer Global Emitter", false, true)
// Factory function: returns a newly allocated DXContainerGlobals pass as a
// ModulePass pointer.
// NOTE(review): the result is a raw owning pointer; ownership presumably
// passes to the pass manager that adds it -- confirm at the call site.
176 ModulePass
*llvm::createDXContainerGlobalsPass() {
177 return new DXContainerGlobals();