1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This provides an abstract class for HLSL code generation.  Concrete
// subclasses of this implement code generation for specific HLSL
// runtime libraries.
13 //===----------------------------------------------------------------------===//
15 #include "CGHLSLRuntime.h"
16 #include "CGDebugInfo.h"
17 #include "CodeGenModule.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/Basic/TargetOptions.h"
20 #include "llvm/IR/IntrinsicsDirectX.h"
21 #include "llvm/IR/Metadata.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/Support/FormatVariadic.h"
25 using namespace clang
;
26 using namespace CodeGen
;
27 using namespace clang::hlsl
;
32 void addDxilValVersion(StringRef ValVersionStr
, llvm::Module
&M
) {
33 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
34 // Assume ValVersionStr is legal here.
36 if (Version
.tryParse(ValVersionStr
) || Version
.getBuild() ||
37 Version
.getSubminor() || !Version
.getMinor()) {
41 uint64_t Major
= Version
.getMajor();
42 uint64_t Minor
= *Version
.getMinor();
44 auto &Ctx
= M
.getContext();
45 IRBuilder
<> B(M
.getContext());
46 MDNode
*Val
= MDNode::get(Ctx
, {ConstantAsMetadata::get(B
.getInt32(Major
)),
47 ConstantAsMetadata::get(B
.getInt32(Minor
))});
48 StringRef DXILValKey
= "dx.valver";
49 auto *DXILValMD
= M
.getOrInsertNamedMetadata(DXILValKey
);
50 DXILValMD
->addOperand(Val
);
52 void addDisableOptimizations(llvm::Module
&M
) {
53 StringRef Key
= "dx.disable_optimizations";
54 M
.addModuleFlag(llvm::Module::ModFlagBehavior::Override
, Key
, 1);
// cbuffer will be translated into global variable in special address space.
// If translate into C,
// cbuffer A {
//   float a;
//   float b;
// }
// float foo() { return a + b; }
//
// will be translated into
//
// struct A {
//   float a;
//   float b;
// } cbuffer_A __attribute__((address_space(4)));
// float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// layoutBuffer will create the struct A type.
// replaceBuffer will replace use of global variable a and b with cbuffer_A.a
// and cbuffer_A.b.
//
76 void layoutBuffer(CGHLSLRuntime::Buffer
&Buf
, const DataLayout
&DL
) {
77 if (Buf
.Constants
.empty())
80 std::vector
<llvm::Type
*> EltTys
;
81 for (auto &Const
: Buf
.Constants
) {
82 GlobalVariable
*GV
= Const
.first
;
83 Const
.second
= EltTys
.size();
84 llvm::Type
*Ty
= GV
->getValueType();
85 EltTys
.emplace_back(Ty
);
87 Buf
.LayoutStruct
= llvm::StructType::get(EltTys
[0]->getContext(), EltTys
);
90 GlobalVariable
*replaceBuffer(CGHLSLRuntime::Buffer
&Buf
) {
91 // Create global variable for CB.
92 GlobalVariable
*CBGV
= new GlobalVariable(
93 Buf
.LayoutStruct
, /*isConstant*/ true,
94 GlobalValue::LinkageTypes::ExternalLinkage
, nullptr,
95 llvm::formatv("{0}{1}", Buf
.Name
, Buf
.IsCBuffer
? ".cb." : ".tb."),
96 GlobalValue::NotThreadLocal
);
98 IRBuilder
<> B(CBGV
->getContext());
99 Value
*ZeroIdx
= B
.getInt32(0);
100 // Replace Const use with CB use.
101 for (auto &[GV
, Offset
] : Buf
.Constants
) {
103 B
.CreateGEP(Buf
.LayoutStruct
, CBGV
, {ZeroIdx
, B
.getInt32(Offset
)});
105 assert(Buf
.LayoutStruct
->getElementType(Offset
) == GV
->getValueType() &&
106 "constant type mismatch");
109 GV
->replaceAllUsesWith(GEP
);
111 GV
->removeDeadConstantUsers();
112 GV
->eraseFromParent();
119 void CGHLSLRuntime::addConstant(VarDecl
*D
, Buffer
&CB
) {
120 if (D
->getStorageClass() == SC_Static
) {
121 // For static inside cbuffer, take as global static.
122 // Don't add to cbuffer.
127 auto *GV
= cast
<GlobalVariable
>(CGM
.GetAddrOfGlobalVar(D
));
128 // Add debug info for constVal.
129 if (CGDebugInfo
*DI
= CGM
.getModuleDebugInfo())
130 if (CGM
.getCodeGenOpts().getDebugInfo() >=
131 codegenoptions::DebugInfoKind::LimitedDebugInfo
)
132 DI
->EmitGlobalVariable(cast
<GlobalVariable
>(GV
), D
);
134 // FIXME: support packoffset.
135 // See https://github.com/llvm/llvm-project/issues/57914.
137 bool HasUserOffset
= false;
139 unsigned LowerBound
= HasUserOffset
? Offset
: UINT_MAX
;
140 CB
.Constants
.emplace_back(std::make_pair(GV
, LowerBound
));
143 void CGHLSLRuntime::addBufferDecls(const DeclContext
*DC
, Buffer
&CB
) {
144 for (Decl
*it
: DC
->decls()) {
145 if (auto *ConstDecl
= dyn_cast
<VarDecl
>(it
)) {
146 addConstant(ConstDecl
, CB
);
147 } else if (isa
<CXXRecordDecl
, EmptyDecl
>(it
)) {
148 // Nothing to do for this declaration.
149 } else if (isa
<FunctionDecl
>(it
)) {
150 // A function within an cbuffer is effectively a top-level function,
151 // as it only refers to globally scoped declarations.
152 CGM
.EmitTopLevelDecl(it
);
157 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl
*D
) {
158 Buffers
.emplace_back(Buffer(D
));
159 addBufferDecls(D
, Buffers
.back());
162 void CGHLSLRuntime::finishCodeGen() {
163 auto &TargetOpts
= CGM
.getTarget().getTargetOpts();
164 llvm::Module
&M
= CGM
.getModule();
165 Triple
T(M
.getTargetTriple());
166 if (T
.getArch() == Triple::ArchType::dxil
)
167 addDxilValVersion(TargetOpts
.DxilValidatorVersion
, M
);
169 generateGlobalCtorDtorCalls();
170 if (CGM
.getCodeGenOpts().OptimizationLevel
== 0)
171 addDisableOptimizations(M
);
173 const DataLayout
&DL
= M
.getDataLayout();
175 for (auto &Buf
: Buffers
) {
176 layoutBuffer(Buf
, DL
);
177 GlobalVariable
*GV
= replaceBuffer(Buf
);
178 M
.insertGlobalVariable(GV
);
179 llvm::hlsl::ResourceClass RC
= Buf
.IsCBuffer
180 ? llvm::hlsl::ResourceClass::CBuffer
181 : llvm::hlsl::ResourceClass::SRV
;
182 llvm::hlsl::ResourceKind RK
= Buf
.IsCBuffer
183 ? llvm::hlsl::ResourceKind::CBuffer
184 : llvm::hlsl::ResourceKind::TBuffer
;
186 Buf
.Name
.str() + (Buf
.IsCBuffer
? ".cb." : ".tb.") + "ty";
187 addBufferResourceAnnotation(GV
, TyName
, RC
, RK
, Buf
.Binding
);
/// Build a Buffer from the declaration \p D: take its name, whether it is a
/// cbuffer, and its resource binding attribute (getAttr may yield null when
/// the declaration has no such attribute).
CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
    : Name(D->getName()), IsCBuffer(D->isCBuffer()),
      Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
195 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable
*GV
,
196 llvm::StringRef TyName
,
197 llvm::hlsl::ResourceClass RC
,
198 llvm::hlsl::ResourceKind RK
,
199 BufferResBinding
&Binding
) {
200 llvm::Module
&M
= CGM
.getModule();
202 NamedMDNode
*ResourceMD
= nullptr;
204 case llvm::hlsl::ResourceClass::UAV
:
205 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.uavs");
207 case llvm::hlsl::ResourceClass::SRV
:
208 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.srvs");
210 case llvm::hlsl::ResourceClass::CBuffer
:
211 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.cbufs");
214 assert(false && "Unsupported buffer type!");
218 assert(ResourceMD
!= nullptr &&
219 "ResourceMD must have been set by the switch above.");
221 llvm::hlsl::FrontendResource
Res(
222 GV
, TyName
, RK
, Binding
.Reg
.value_or(UINT_MAX
), Binding
.Space
);
223 ResourceMD
->addOperand(Res
.getMetadata());
226 void CGHLSLRuntime::annotateHLSLResource(const VarDecl
*D
, GlobalVariable
*GV
) {
227 const Type
*Ty
= D
->getType()->getPointeeOrArrayElementType();
230 const auto *RD
= Ty
->getAsCXXRecordDecl();
233 const auto *Attr
= RD
->getAttr
<HLSLResourceAttr
>();
237 llvm::hlsl::ResourceClass RC
= Attr
->getResourceClass();
238 llvm::hlsl::ResourceKind RK
= Attr
->getResourceKind();
241 BufferResBinding
Binding(D
->getAttr
<HLSLResourceBindingAttr
>());
242 addBufferResourceAnnotation(GV
, QT
.getAsString(), RC
, RK
, Binding
);
245 CGHLSLRuntime::BufferResBinding::BufferResBinding(
246 HLSLResourceBindingAttr
*Binding
) {
248 llvm::APInt
RegInt(64, 0);
249 Binding
->getSlot().substr(1).getAsInteger(10, RegInt
);
250 Reg
= RegInt
.getLimitedValue();
251 llvm::APInt
SpaceInt(64, 0);
252 Binding
->getSpace().substr(5).getAsInteger(10, SpaceInt
);
253 Space
= SpaceInt
.getLimitedValue();
259 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
260 const FunctionDecl
*FD
, llvm::Function
*Fn
) {
261 const auto *ShaderAttr
= FD
->getAttr
<HLSLShaderAttr
>();
262 assert(ShaderAttr
&& "All entry functions must have a HLSLShaderAttr");
263 const StringRef ShaderAttrKindStr
= "hlsl.shader";
264 Fn
->addFnAttr(ShaderAttrKindStr
,
265 ShaderAttr
->ConvertShaderTypeToStr(ShaderAttr
->getType()));
266 if (HLSLNumThreadsAttr
*NumThreadsAttr
= FD
->getAttr
<HLSLNumThreadsAttr
>()) {
267 const StringRef NumThreadsKindStr
= "hlsl.numthreads";
268 std::string NumThreadsStr
=
269 formatv("{0},{1},{2}", NumThreadsAttr
->getX(), NumThreadsAttr
->getY(),
270 NumThreadsAttr
->getZ());
271 Fn
->addFnAttr(NumThreadsKindStr
, NumThreadsStr
);
275 static Value
*buildVectorInput(IRBuilder
<> &B
, Function
*F
, llvm::Type
*Ty
) {
276 if (const auto *VT
= dyn_cast
<FixedVectorType
>(Ty
)) {
277 Value
*Result
= PoisonValue::get(Ty
);
278 for (unsigned I
= 0; I
< VT
->getNumElements(); ++I
) {
279 Value
*Elt
= B
.CreateCall(F
, {B
.getInt32(I
)});
280 Result
= B
.CreateInsertElement(Result
, Elt
, I
);
284 return B
.CreateCall(F
, {B
.getInt32(0)});
287 llvm::Value
*CGHLSLRuntime::emitInputSemantic(IRBuilder
<> &B
,
288 const ParmVarDecl
&D
,
290 assert(D
.hasAttrs() && "Entry parameter missing annotation attribute!");
291 if (D
.hasAttr
<HLSLSV_GroupIndexAttr
>()) {
292 llvm::Function
*DxGroupIndex
=
293 CGM
.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group
);
294 return B
.CreateCall(FunctionCallee(DxGroupIndex
));
296 if (D
.hasAttr
<HLSLSV_DispatchThreadIDAttr
>()) {
297 llvm::Function
*DxThreadID
= CGM
.getIntrinsic(Intrinsic::dx_thread_id
);
298 return buildVectorInput(B
, DxThreadID
, Ty
);
300 assert(false && "Unhandled parameter attribute");
304 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl
*FD
,
305 llvm::Function
*Fn
) {
306 llvm::Module
&M
= CGM
.getModule();
307 llvm::LLVMContext
&Ctx
= M
.getContext();
308 auto *EntryTy
= llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx
), false);
310 Function::Create(EntryTy
, Function::ExternalLinkage
, FD
->getName(), &M
);
312 // Copy function attributes over, we have no argument or return attributes
313 // that can be valid on the real entry.
314 AttributeList NewAttrs
= AttributeList::get(Ctx
, AttributeList::FunctionIndex
,
315 Fn
->getAttributes().getFnAttrs());
316 EntryFn
->setAttributes(NewAttrs
);
317 setHLSLEntryAttributes(FD
, EntryFn
);
319 // Set the called function as internal linkage.
320 Fn
->setLinkage(GlobalValue::InternalLinkage
);
322 BasicBlock
*BB
= BasicBlock::Create(Ctx
, "entry", EntryFn
);
324 llvm::SmallVector
<Value
*> Args
;
325 // FIXME: support struct parameters where semantics are on members.
326 // See: https://github.com/llvm/llvm-project/issues/57874
327 unsigned SRetOffset
= 0;
328 for (const auto &Param
: Fn
->args()) {
329 if (Param
.hasStructRetAttr()) {
330 // FIXME: support output.
331 // See: https://github.com/llvm/llvm-project/issues/57874
333 Args
.emplace_back(PoisonValue::get(Param
.getType()));
336 const ParmVarDecl
*PD
= FD
->getParamDecl(Param
.getArgNo() - SRetOffset
);
337 Args
.push_back(emitInputSemantic(B
, *PD
, Param
.getType()));
340 CallInst
*CI
= B
.CreateCall(FunctionCallee(Fn
), Args
);
342 // FIXME: Handle codegen for return type semantics.
343 // See: https://github.com/llvm/llvm-project/issues/57875
347 static void gatherFunctions(SmallVectorImpl
<Function
*> &Fns
, llvm::Module
&M
,
350 M
.getNamedGlobal(CtorOrDtor
? "llvm.global_ctors" : "llvm.global_dtors");
353 const auto *CA
= dyn_cast
<ConstantArray
>(GV
->getInitializer());
356 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
357 // HLSL neither supports priorities or COMDat values, so we will check those
358 // in an assert but not handle them.
360 llvm::SmallVector
<Function
*> CtorFns
;
361 for (const auto &Ctor
: CA
->operands()) {
362 if (isa
<ConstantAggregateZero
>(Ctor
))
364 ConstantStruct
*CS
= cast
<ConstantStruct
>(Ctor
);
366 assert(cast
<ConstantInt
>(CS
->getOperand(0))->getValue() == 65535 &&
367 "HLSL doesn't support setting priority for global ctors.");
368 assert(isa
<ConstantPointerNull
>(CS
->getOperand(2)) &&
369 "HLSL doesn't support COMDat for global ctors.");
370 Fns
.push_back(cast
<Function
>(CS
->getOperand(1)));
374 void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
375 llvm::Module
&M
= CGM
.getModule();
376 SmallVector
<Function
*> CtorFns
;
377 SmallVector
<Function
*> DtorFns
;
378 gatherFunctions(CtorFns
, M
, true);
379 gatherFunctions(DtorFns
, M
, false);
381 // Insert a call to the global constructor at the beginning of the entry block
382 // to externally exported functions. This is a bit of a hack, but HLSL allows
383 // global constructors, but doesn't support driver initialization of globals.
384 for (auto &F
: M
.functions()) {
385 if (!F
.hasFnAttribute("hlsl.shader"))
387 IRBuilder
<> B(&F
.getEntryBlock(), F
.getEntryBlock().begin());
388 for (auto *Fn
: CtorFns
)
389 B
.CreateCall(FunctionCallee(Fn
));
391 // Insert global dtors before the terminator of the last instruction
392 B
.SetInsertPoint(F
.back().getTerminator());
393 for (auto *Fn
: DtorFns
)
394 B
.CreateCall(FunctionCallee(Fn
));
397 // No need to keep global ctors/dtors for non-lib profile after call to
398 // ctors/dtors added for entry.
399 Triple
T(M
.getTargetTriple());
400 if (T
.getEnvironment() != Triple::EnvironmentType::Library
) {
401 if (auto *GV
= M
.getNamedGlobal("llvm.global_ctors"))
402 GV
->eraseFromParent();
403 if (auto *GV
= M
.getNamedGlobal("llvm.global_dtors"))
404 GV
->eraseFromParent();