[clang] NFC, add a "continue" bailout in the for-loop of
[llvm-project.git] / llvm / lib / Target / DirectX / DXILShaderFlags.cpp
blob6a15bac153d857829776badf6ba701a5c1da64df
1 //===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects and APIs for working with DXIL
10 /// Shader Flags.
11 ///
12 //===----------------------------------------------------------------------===//
14 #include "DXILShaderFlags.h"
15 #include "DirectX.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/DXILResource.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Intrinsics.h"
24 #include "llvm/IR/IntrinsicsDirectX.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
30 using namespace llvm;
31 using namespace llvm::dxil;
33 /// Update the shader flags mask based on the given instruction.
34 /// \param CSF Shader flags mask to update.
35 /// \param I Instruction to check.
36 void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
37 const Instruction &I,
38 DXILResourceTypeMap &DRTM) {
39 if (!CSF.Doubles)
40 CSF.Doubles = I.getType()->isDoubleTy();
42 if (!CSF.Doubles) {
43 for (const Value *Op : I.operands()) {
44 if (Op->getType()->isDoubleTy()) {
45 CSF.Doubles = true;
46 break;
51 if (CSF.Doubles) {
52 switch (I.getOpcode()) {
53 case Instruction::FDiv:
54 case Instruction::UIToFP:
55 case Instruction::SIToFP:
56 case Instruction::FPToUI:
57 case Instruction::FPToSI:
58 CSF.DX11_1_DoubleExtensions = true;
59 break;
63 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
64 switch (II->getIntrinsicID()) {
65 default:
66 break;
67 case Intrinsic::dx_resource_handlefrombinding:
68 switch (DRTM[cast<TargetExtType>(II->getType())].getResourceKind()) {
69 case dxil::ResourceKind::StructuredBuffer:
70 case dxil::ResourceKind::RawBuffer:
71 CSF.EnableRawAndStructuredBuffers = true;
72 break;
73 default:
74 break;
76 break;
77 case Intrinsic::dx_resource_load_typedbuffer: {
78 dxil::ResourceTypeInfo &RTI =
79 DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
80 if (RTI.isTyped())
81 CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
82 break;
86 // Handle call instructions
87 if (auto *CI = dyn_cast<CallInst>(&I)) {
88 const Function *CF = CI->getCalledFunction();
89 // Merge-in shader flags mask of the called function in the current module
90 if (FunctionFlags.contains(CF))
91 CSF.merge(FunctionFlags[CF]);
93 // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
94 // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
98 /// Construct ModuleShaderFlags for module Module M
99 void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) {
100 CallGraph CG(M);
102 // Compute Shader Flags Mask for all functions using post-order visit of SCC
103 // of the call graph.
104 for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
105 ++SCCI) {
106 const std::vector<CallGraphNode *> &CurSCC = *SCCI;
108 // Union of shader masks of all functions in CurSCC
109 ComputedShaderFlags SCCSF;
110 // List of functions in CurSCC that are neither external nor declarations
111 // and hence whose flags are collected
112 SmallVector<Function *> CurSCCFuncs;
113 for (CallGraphNode *CGN : CurSCC) {
114 Function *F = CGN->getFunction();
115 if (!F)
116 continue;
118 if (F->isDeclaration()) {
119 assert(!F->getName().starts_with("dx.op.") &&
120 "DXIL Shader Flag analysis should not be run post-lowering.");
121 continue;
124 ComputedShaderFlags CSF;
125 for (const auto &BB : *F)
126 for (const auto &I : BB)
127 updateFunctionFlags(CSF, I, DRTM);
128 // Update combined shader flags mask for all functions in this SCC
129 SCCSF.merge(CSF);
131 CurSCCFuncs.push_back(F);
134 // Update combined shader flags mask for all functions of the module
135 CombinedSFMask.merge(SCCSF);
137 // Shader flags mask of each of the functions in an SCC of the call graph is
138 // the union of all functions in the SCC. Update shader flags masks of
139 // functions in CurSCC accordingly. This is trivially true if SCC contains
140 // one function.
141 for (Function *F : CurSCCFuncs)
142 // Merge SCCSF with that of F
143 FunctionFlags[F].merge(SCCSF);
147 void ComputedShaderFlags::print(raw_ostream &OS) const {
148 uint64_t FlagVal = (uint64_t) * this;
149 OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
150 if (FlagVal == 0)
151 return;
152 OS << "; Note: shader requires additional functionality:\n";
153 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \
154 if (FlagName) \
155 (OS << ";").indent(7) << Str << "\n";
156 #include "llvm/BinaryFormat/DXContainerConstants.def"
157 OS << "; Note: extra DXIL module flags:\n";
158 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
159 if (FlagName) \
160 (OS << ";").indent(7) << Str << "\n";
161 #include "llvm/BinaryFormat/DXContainerConstants.def"
162 OS << ";\n";
165 /// Return the shader flags mask of the specified function Func.
166 const ComputedShaderFlags &
167 ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
168 auto Iter = FunctionFlags.find(Func);
169 assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
170 "Get Shader Flags : No Shader Flags Mask exists for function");
171 return Iter->second;
174 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinter
177 // Provide an explicit template instantiation for the static ID.
178 AnalysisKey ShaderFlagsAnalysis::Key;
180 ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
181 ModuleAnalysisManager &AM) {
182 DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M);
184 ModuleShaderFlags MSFI;
185 MSFI.initialize(M, DRTM);
187 return MSFI;
190 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
191 ModuleAnalysisManager &AM) {
192 const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
193 // Print description of combined shader flags for all module functions
194 OS << "; Combined Shader Flags for Module\n";
195 FlagsInfo.getCombinedFlags().print(OS);
196 // Print shader flags mask for each of the module functions
197 OS << "; Shader Flags for Module Functions\n";
198 for (const auto &F : M.getFunctionList()) {
199 if (F.isDeclaration())
200 continue;
201 const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
202 OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
203 (uint64_t)(SFMask));
206 return PreservedAnalyses::all();
209 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysisWrapper (legacy pass manager)
212 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
213 DXILResourceTypeMap &DRTM =
214 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
216 MSFI.initialize(M, DRTM);
217 return false;
220 void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
221 AU.setPreservesAll();
222 AU.addRequiredTransitive<DXILResourceTypeWrapperPass>();
225 char ShaderFlagsAnalysisWrapper::ID = 0;
227 INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
228 "DXIL Shader Flag Analysis", true, true)
229 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
230 INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
231 "DXIL Shader Flag Analysis", true, true)