[clang] NFC, add a "continue" bailout in the for-loop of
[llvm-project.git] / llvm / lib / Target / DirectX / DXILTranslateMetadata.cpp
blob5fd5c226eef8947a7969acf9f90c27971f03b1c1
1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "DXILTranslateMetadata.h"
10 #include "DXILResource.h"
11 #include "DXILResourceAnalysis.h"
12 #include "DXILShaderFlags.h"
13 #include "DirectX.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Analysis/DXILMetadataAnalysis.h"
17 #include "llvm/Analysis/DXILResource.h"
18 #include "llvm/IR/BasicBlock.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DiagnosticInfo.h"
21 #include "llvm/IR/DiagnosticPrinter.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/LLVMContext.h"
25 #include "llvm/IR/MDBuilder.h"
26 #include "llvm/IR/Metadata.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/VersionTuple.h"
32 #include "llvm/TargetParser/Triple.h"
33 #include <cstdint>
35 using namespace llvm;
36 using namespace llvm::dxil;
38 namespace {
39 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
40 /// for TranslateMetadata pass
41 class DiagnosticInfoTranslateMD : public DiagnosticInfo {
42 private:
43 const Twine &Msg;
44 const Module &Mod;
46 public:
47 /// \p M is the module for which the diagnostic is being emitted. \p Msg is
48 /// the message to show. Note that this class does not copy this message, so
49 /// this reference must be valid for the whole life time of the diagnostic.
50 DiagnosticInfoTranslateMD(const Module &M, const Twine &Msg,
51 DiagnosticSeverity Severity = DS_Error)
52 : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
54 void print(DiagnosticPrinter &DP) const override {
55 DP << Mod.getName() << ": " << Msg << '\n';
59 enum class EntryPropsTag {
60 ShaderFlags = 0,
61 GSState,
62 DSState,
63 HSState,
64 NumThreads,
65 AutoBindingSpace,
66 RayPayloadSize,
67 RayAttribSize,
68 ShaderKind,
69 MSState,
70 ASStateTag,
71 WaveSize,
72 EntryRootSig,
75 } // namespace
77 static NamedMDNode *emitResourceMetadata(Module &M, DXILBindingMap &DBM,
78 DXILResourceTypeMap &DRTM,
79 const dxil::Resources &MDResources) {
80 LLVMContext &Context = M.getContext();
82 for (ResourceBindingInfo &RI : DBM)
83 if (!RI.hasSymbol())
84 RI.createSymbol(M, DRTM[RI.getHandleTy()].createElementStruct());
86 SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps;
87 for (const ResourceBindingInfo &RI : DBM.srvs())
88 SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
89 for (const ResourceBindingInfo &RI : DBM.uavs())
90 UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
91 for (const ResourceBindingInfo &RI : DBM.cbuffers())
92 CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
93 for (const ResourceBindingInfo &RI : DBM.samplers())
94 Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
96 Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs);
97 Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs);
98 Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs);
99 Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps);
100 bool HasResources = !DBM.empty();
102 if (MDResources.hasUAVs()) {
103 assert(!UAVMD && "Old and new UAV representations can't coexist");
104 UAVMD = MDResources.writeUAVs(M);
105 HasResources = true;
108 if (MDResources.hasCBuffers()) {
109 assert(!CBufMD && "Old and new cbuffer representations can't coexist");
110 CBufMD = MDResources.writeCBuffers(M);
111 HasResources = true;
114 if (!HasResources)
115 return nullptr;
117 NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources");
118 ResourceMD->addOperand(
119 MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD}));
121 return ResourceMD;
124 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
125 switch (Env) {
126 case Triple::Pixel:
127 return "ps";
128 case Triple::Vertex:
129 return "vs";
130 case Triple::Geometry:
131 return "gs";
132 case Triple::Hull:
133 return "hs";
134 case Triple::Domain:
135 return "ds";
136 case Triple::Compute:
137 return "cs";
138 case Triple::Library:
139 return "lib";
140 case Triple::Mesh:
141 return "ms";
142 case Triple::Amplification:
143 return "as";
144 default:
145 break;
147 llvm_unreachable("Unsupported environment for DXIL generation.");
150 static uint32_t getShaderStage(Triple::EnvironmentType Env) {
151 return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
154 static SmallVector<Metadata *>
155 getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
156 SmallVector<Metadata *> MDVals;
157 MDVals.emplace_back(ConstantAsMetadata::get(
158 ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
159 switch (Tag) {
160 case EntryPropsTag::ShaderFlags:
161 MDVals.emplace_back(ConstantAsMetadata::get(
162 ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
163 break;
164 case EntryPropsTag::ShaderKind:
165 MDVals.emplace_back(ConstantAsMetadata::get(
166 ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
167 break;
168 case EntryPropsTag::GSState:
169 case EntryPropsTag::DSState:
170 case EntryPropsTag::HSState:
171 case EntryPropsTag::NumThreads:
172 case EntryPropsTag::AutoBindingSpace:
173 case EntryPropsTag::RayPayloadSize:
174 case EntryPropsTag::RayAttribSize:
175 case EntryPropsTag::MSState:
176 case EntryPropsTag::ASStateTag:
177 case EntryPropsTag::WaveSize:
178 case EntryPropsTag::EntryRootSig:
179 llvm_unreachable("NYI: Unhandled entry property tag");
181 return MDVals;
184 static MDTuple *
185 getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
186 const Triple::EnvironmentType ShaderProfile) {
187 SmallVector<Metadata *> MDVals;
188 LLVMContext &Ctx = EP.Entry->getContext();
189 if (EntryShaderFlags != 0)
190 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
191 EntryShaderFlags, Ctx));
193 if (EP.Entry != nullptr) {
194 // FIXME: support more props.
195 // See https://github.com/llvm/llvm-project/issues/57948.
196 // Add shader kind for lib entries.
197 if (ShaderProfile == Triple::EnvironmentType::Library &&
198 EP.ShaderStage != Triple::EnvironmentType::Library)
199 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
200 getShaderStage(EP.ShaderStage), Ctx));
202 if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
203 MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
204 Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
205 Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
206 Type::getInt32Ty(Ctx), EP.NumThreadsX)),
207 ConstantAsMetadata::get(ConstantInt::get(
208 Type::getInt32Ty(Ctx), EP.NumThreadsY)),
209 ConstantAsMetadata::get(ConstantInt::get(
210 Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
211 MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
214 if (MDVals.empty())
215 return nullptr;
216 return MDNode::get(Ctx, MDVals);
219 MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures,
220 MDNode *Resources, MDTuple *Properties,
221 LLVMContext &Ctx) {
222 // Each entry point metadata record specifies:
223 // * reference to the entry point function global symbol
224 // * unmangled name
225 // * list of signatures
226 // * list of resources
227 // * list of tag-value pairs of shader capabilities and other properties
228 Metadata *MDVals[5];
229 MDVals[0] =
230 EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr;
231 MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : "");
232 MDVals[2] = Signatures;
233 MDVals[3] = Resources;
234 MDVals[4] = Properties;
235 return MDNode::get(Ctx, MDVals);
238 static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
239 MDNode *MDResources,
240 const uint64_t EntryShaderFlags,
241 const Triple::EnvironmentType ShaderProfile) {
242 MDTuple *Properties =
243 getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
244 return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
245 EP.Entry->getContext());
248 static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) {
249 if (MMDI.ValidatorVersion.empty())
250 return;
252 LLVMContext &Ctx = M.getContext();
253 IRBuilder<> IRB(Ctx);
254 Metadata *MDVals[2];
255 MDVals[0] =
256 ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor()));
257 MDVals[1] = ConstantAsMetadata::get(
258 IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0)));
259 NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver");
260 // Set validator version obtained from DXIL Metadata Analysis pass
261 ValVerNode->clearOperands();
262 ValVerNode->addOperand(MDNode::get(Ctx, MDVals));
265 static void emitShaderModelVersionMD(Module &M,
266 const ModuleMetadataInfo &MMDI) {
267 LLVMContext &Ctx = M.getContext();
268 IRBuilder<> IRB(Ctx);
269 Metadata *SMVals[3];
270 VersionTuple SM = MMDI.ShaderModelVersion;
271 SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile));
272 SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor()));
273 SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0)));
274 NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel");
275 SMMDNode->addOperand(MDNode::get(Ctx, SMVals));
278 static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) {
279 LLVMContext &Ctx = M.getContext();
280 IRBuilder<> IRB(Ctx);
281 VersionTuple DXILVer = MMDI.DXILVersion;
282 Metadata *DXILVals[2];
283 DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor()));
284 DXILVals[1] =
285 ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0)));
286 NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version");
287 DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals));
290 static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
291 uint64_t ShaderFlags) {
292 LLVMContext &Ctx = M.getContext();
293 MDTuple *Properties = nullptr;
294 if (ShaderFlags != 0) {
295 SmallVector<Metadata *> MDVals;
296 MDVals.append(
297 getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
298 Properties = MDNode::get(Ctx, MDVals);
300 // Library has an entry metadata with resource table metadata and all other
301 // MDNodes as null.
302 return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
305 // TODO: We might need to refactor this to be more generic,
306 // in case we need more metadata to be replaced.
307 static void translateBranchMetadata(Module &M) {
308 for (Function &F : M) {
309 for (BasicBlock &BB : F) {
310 Instruction *BBTerminatorInst = BB.getTerminator();
312 MDNode *HlslControlFlowMD =
313 BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
315 if (!HlslControlFlowMD)
316 continue;
318 assert(HlslControlFlowMD->getNumOperands() == 2 &&
319 "invalid operands for hlsl.controlflow.hint");
321 MDBuilder MDHelper(M.getContext());
322 ConstantInt *Op1 =
323 mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
325 SmallVector<llvm::Metadata *, 2> Vals(
326 ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
327 MDHelper.createConstant(Op1)});
329 MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);
331 BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
332 BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
337 static void translateMetadata(Module &M, DXILBindingMap &DBM,
338 DXILResourceTypeMap &DRTM,
339 const Resources &MDResources,
340 const ModuleShaderFlags &ShaderFlags,
341 const ModuleMetadataInfo &MMDI) {
342 LLVMContext &Ctx = M.getContext();
343 IRBuilder<> IRB(Ctx);
344 SmallVector<MDNode *> EntryFnMDNodes;
346 emitValidatorVersionMD(M, MMDI);
347 emitShaderModelVersionMD(M, MMDI);
348 emitDXILVersionTupleMD(M, MMDI);
349 NamedMDNode *NamedResourceMD =
350 emitResourceMetadata(M, DBM, DRTM, MDResources);
351 auto *ResourceMD =
352 (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr;
353 // FIXME: Add support to construct Signatures
354 // See https://github.com/llvm/llvm-project/issues/57928
355 MDTuple *Signatures = nullptr;
357 if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
358 // Get the combined shader flag mask of all functions in the library to be
359 // used as shader flags mask value associated with top-level library entry
360 // metadata.
361 uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
362 EntryFnMDNodes.emplace_back(
363 emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
364 } else if (MMDI.EntryPropertyVec.size() > 1) {
365 M.getContext().diagnose(DiagnosticInfoTranslateMD(
366 M, "Non-library shader: One and only one entry expected"));
369 for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
370 const ComputedShaderFlags &EntrySFMask =
371 ShaderFlags.getFunctionFlags(EntryProp.Entry);
373 // If ShaderProfile is Library, mask is already consolidated in the
374 // top-level library node. Hence it is not emitted.
375 uint64_t EntryShaderFlags = 0;
376 if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
377 EntryShaderFlags = EntrySFMask;
378 if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
379 M.getContext().diagnose(DiagnosticInfoTranslateMD(
381 "Shader stage '" +
382 Twine(getShortShaderStage(EntryProp.ShaderStage) +
383 "' for entry '" + Twine(EntryProp.Entry->getName()) +
384 "' different from specified target profile '" +
385 Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
386 "'"))));
389 EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
390 EntryShaderFlags,
391 MMDI.ShaderProfile));
394 NamedMDNode *EntryPointsNamedMD =
395 M.getOrInsertNamedMetadata("dx.entryPoints");
396 for (auto *Entry : EntryFnMDNodes)
397 EntryPointsNamedMD->addOperand(Entry);
400 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
401 ModuleAnalysisManager &MAM) {
402 DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M);
403 DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
404 const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
405 const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M);
406 const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
408 translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
409 translateBranchMetadata(M);
411 return PreservedAnalyses::all();
414 namespace {
415 class DXILTranslateMetadataLegacy : public ModulePass {
416 public:
417 static char ID; // Pass identification, replacement for typeid
418 explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
420 StringRef getPassName() const override { return "DXIL Translate Metadata"; }
422 void getAnalysisUsage(AnalysisUsage &AU) const override {
423 AU.addRequired<DXILResourceTypeWrapperPass>();
424 AU.addRequired<DXILResourceBindingWrapperPass>();
425 AU.addRequired<DXILResourceMDWrapper>();
426 AU.addRequired<ShaderFlagsAnalysisWrapper>();
427 AU.addRequired<DXILMetadataAnalysisWrapperPass>();
428 AU.addPreserved<DXILResourceBindingWrapperPass>();
429 AU.addPreserved<DXILResourceMDWrapper>();
430 AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
431 AU.addPreserved<ShaderFlagsAnalysisWrapper>();
434 bool runOnModule(Module &M) override {
435 DXILBindingMap &DBM =
436 getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
437 DXILResourceTypeMap &DRTM =
438 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
439 const dxil::Resources &MDResources =
440 getAnalysis<DXILResourceMDWrapper>().getDXILResource();
441 const ModuleShaderFlags &ShaderFlags =
442 getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
443 dxil::ModuleMetadataInfo MMDI =
444 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
446 translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI);
447 translateBranchMetadata(M);
448 return true;
452 } // namespace
454 char DXILTranslateMetadataLegacy::ID = 0;
456 ModulePass *llvm::createDXILTranslateMetadataLegacyPass() {
457 return new DXILTranslateMetadataLegacy();
460 INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
461 "DXIL Translate Metadata", false, false)
462 INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass)
463 INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper)
464 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
465 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
466 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
467 "DXIL Translate Metadata", false, false)