1 //===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
/// \file This file contains helper objects and APIs for working with DXIL
/// shader flags.
12 //===----------------------------------------------------------------------===//
14 #include "DXILShaderFlags.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/DXILResource.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Intrinsics.h"
24 #include "llvm/IR/IntrinsicsDirectX.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
31 using namespace llvm::dxil
;
33 /// Update the shader flags mask based on the given instruction.
34 /// \param CSF Shader flags mask to update.
35 /// \param I Instruction to check.
36 void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags
&CSF
,
38 DXILResourceTypeMap
&DRTM
) {
40 CSF
.Doubles
= I
.getType()->isDoubleTy();
43 for (const Value
*Op
: I
.operands()) {
44 if (Op
->getType()->isDoubleTy()) {
52 switch (I
.getOpcode()) {
53 case Instruction::FDiv
:
54 case Instruction::UIToFP
:
55 case Instruction::SIToFP
:
56 case Instruction::FPToUI
:
57 case Instruction::FPToSI
:
58 CSF
.DX11_1_DoubleExtensions
= true;
63 if (auto *II
= dyn_cast
<IntrinsicInst
>(&I
)) {
64 switch (II
->getIntrinsicID()) {
67 case Intrinsic::dx_resource_handlefrombinding
:
68 switch (DRTM
[cast
<TargetExtType
>(II
->getType())].getResourceKind()) {
69 case dxil::ResourceKind::StructuredBuffer
:
70 case dxil::ResourceKind::RawBuffer
:
71 CSF
.EnableRawAndStructuredBuffers
= true;
77 case Intrinsic::dx_resource_load_typedbuffer
: {
78 dxil::ResourceTypeInfo
&RTI
=
79 DRTM
[cast
<TargetExtType
>(II
->getArgOperand(0)->getType())];
81 CSF
.TypedUAVLoadAdditionalFormats
|= RTI
.getTyped().ElementCount
> 1;
86 // Handle call instructions
87 if (auto *CI
= dyn_cast
<CallInst
>(&I
)) {
88 const Function
*CF
= CI
->getCalledFunction();
89 // Merge-in shader flags mask of the called function in the current module
90 if (FunctionFlags
.contains(CF
))
91 CSF
.merge(FunctionFlags
[CF
]);
93 // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
94 // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
98 /// Construct ModuleShaderFlags for module Module M
99 void ModuleShaderFlags::initialize(Module
&M
, DXILResourceTypeMap
&DRTM
) {
102 // Compute Shader Flags Mask for all functions using post-order visit of SCC
103 // of the call graph.
104 for (scc_iterator
<CallGraph
*> SCCI
= scc_begin(&CG
); !SCCI
.isAtEnd();
106 const std::vector
<CallGraphNode
*> &CurSCC
= *SCCI
;
108 // Union of shader masks of all functions in CurSCC
109 ComputedShaderFlags SCCSF
;
110 // List of functions in CurSCC that are neither external nor declarations
111 // and hence whose flags are collected
112 SmallVector
<Function
*> CurSCCFuncs
;
113 for (CallGraphNode
*CGN
: CurSCC
) {
114 Function
*F
= CGN
->getFunction();
118 if (F
->isDeclaration()) {
119 assert(!F
->getName().starts_with("dx.op.") &&
120 "DXIL Shader Flag analysis should not be run post-lowering.");
124 ComputedShaderFlags CSF
;
125 for (const auto &BB
: *F
)
126 for (const auto &I
: BB
)
127 updateFunctionFlags(CSF
, I
, DRTM
);
128 // Update combined shader flags mask for all functions in this SCC
131 CurSCCFuncs
.push_back(F
);
134 // Update combined shader flags mask for all functions of the module
135 CombinedSFMask
.merge(SCCSF
);
137 // Shader flags mask of each of the functions in an SCC of the call graph is
138 // the union of all functions in the SCC. Update shader flags masks of
139 // functions in CurSCC accordingly. This is trivially true if SCC contains
141 for (Function
*F
: CurSCCFuncs
)
142 // Merge SCCSF with that of F
143 FunctionFlags
[F
].merge(SCCSF
);
147 void ComputedShaderFlags::print(raw_ostream
&OS
) const {
148 uint64_t FlagVal
= (uint64_t) * this;
149 OS
<< formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal
);
152 OS
<< "; Note: shader requires additional functionality:\n";
153 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \
155 (OS << ";").indent(7) << Str << "\n";
156 #include "llvm/BinaryFormat/DXContainerConstants.def"
157 OS
<< "; Note: extra DXIL module flags:\n";
158 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
160 (OS << ";").indent(7) << Str << "\n";
161 #include "llvm/BinaryFormat/DXContainerConstants.def"
165 /// Return the shader flags mask of the specified function Func.
166 const ComputedShaderFlags
&
167 ModuleShaderFlags::getFunctionFlags(const Function
*Func
) const {
168 auto Iter
= FunctionFlags
.find(Func
);
169 assert((Iter
!= FunctionFlags
.end() && Iter
->first
== Func
) &&
170 "Get Shader Flags : No Shader Flags Mask exists for function");
174 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinter
177 // Provide an explicit template instantiation for the static ID.
178 AnalysisKey
ShaderFlagsAnalysis::Key
;
180 ModuleShaderFlags
ShaderFlagsAnalysis::run(Module
&M
,
181 ModuleAnalysisManager
&AM
) {
182 DXILResourceTypeMap
&DRTM
= AM
.getResult
<DXILResourceTypeAnalysis
>(M
);
184 ModuleShaderFlags MSFI
;
185 MSFI
.initialize(M
, DRTM
);
190 PreservedAnalyses
ShaderFlagsAnalysisPrinter::run(Module
&M
,
191 ModuleAnalysisManager
&AM
) {
192 const ModuleShaderFlags
&FlagsInfo
= AM
.getResult
<ShaderFlagsAnalysis
>(M
);
193 // Print description of combined shader flags for all module functions
194 OS
<< "; Combined Shader Flags for Module\n";
195 FlagsInfo
.getCombinedFlags().print(OS
);
196 // Print shader flags mask for each of the module functions
197 OS
<< "; Shader Flags for Module Functions\n";
198 for (const auto &F
: M
.getFunctionList()) {
199 if (F
.isDeclaration())
201 const ComputedShaderFlags
&SFMask
= FlagsInfo
.getFunctionFlags(&F
);
202 OS
<< formatv("; Function {0} : {1:x8}\n;\n", F
.getName(),
206 return PreservedAnalyses::all();
209 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysisWrapper (legacy pass manager)
212 bool ShaderFlagsAnalysisWrapper::runOnModule(Module
&M
) {
213 DXILResourceTypeMap
&DRTM
=
214 getAnalysis
<DXILResourceTypeWrapperPass
>().getResourceTypeMap();
216 MSFI
.initialize(M
, DRTM
);
220 void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage
&AU
) const {
221 AU
.setPreservesAll();
222 AU
.addRequiredTransitive
<DXILResourceTypeWrapperPass
>();
// Legacy pass manager identifier for ShaderFlagsAnalysisWrapper.
char ShaderFlagsAnalysisWrapper::ID = 0;

// Register the wrapper pass under the command-line name
// "dx-shader-flag-analysis", declaring its dependency on
// DXILResourceTypeWrapperPass. The two trailing `true` arguments are the
// INITIALIZE_PASS (cfg, analysis) parameters.
INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
                      "DXIL Shader Flag Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
                    "DXIL Shader Flag Analysis", true, true)