1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This provides an abstract class for HLSL code generation. Concrete
// subclasses of this implement code generation for specific HLSL
// runtime libraries.
13 //===----------------------------------------------------------------------===//
15 #include "CGHLSLRuntime.h"
16 #include "CGDebugInfo.h"
17 #include "CodeGenModule.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/Basic/TargetOptions.h"
20 #include "llvm/IR/Metadata.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/Support/FormatVariadic.h"
24 using namespace clang
;
25 using namespace CodeGen
;
26 using namespace clang::hlsl
;
31 void addDxilValVersion(StringRef ValVersionStr
, llvm::Module
&M
) {
32 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
33 // Assume ValVersionStr is legal here.
35 if (Version
.tryParse(ValVersionStr
) || Version
.getBuild() ||
36 Version
.getSubminor() || !Version
.getMinor()) {
40 uint64_t Major
= Version
.getMajor();
41 uint64_t Minor
= *Version
.getMinor();
43 auto &Ctx
= M
.getContext();
44 IRBuilder
<> B(M
.getContext());
45 MDNode
*Val
= MDNode::get(Ctx
, {ConstantAsMetadata::get(B
.getInt32(Major
)),
46 ConstantAsMetadata::get(B
.getInt32(Minor
))});
47 StringRef DXILValKey
= "dx.valver";
48 auto *DXILValMD
= M
.getOrInsertNamedMetadata(DXILValKey
);
49 DXILValMD
->addOperand(Val
);
51 void addDisableOptimizations(llvm::Module
&M
) {
52 StringRef Key
= "dx.disable_optimizations";
53 M
.addModuleFlag(llvm::Module::ModFlagBehavior::Override
, Key
, 1);
// A cbuffer is translated into a global variable in a special address space.
// Expressed in C, the HLSL
//
//   cbuffer A {
//     float a;
//     float b;
//   }
//   float foo() { return a + b; }
//
// is translated into
//
//   struct A {
//     float a;
//     float b;
//   } cbuffer_A __attribute__((address_space(4)));
//   float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// layoutBuffer creates the struct A type.
// replaceBuffer replaces uses of the global variables a and b with cbuffer_A.a
// and cbuffer_A.b.
75 void layoutBuffer(CGHLSLRuntime::Buffer
&Buf
, const DataLayout
&DL
) {
76 if (Buf
.Constants
.empty())
79 std::vector
<llvm::Type
*> EltTys
;
80 for (auto &Const
: Buf
.Constants
) {
81 GlobalVariable
*GV
= Const
.first
;
82 Const
.second
= EltTys
.size();
83 llvm::Type
*Ty
= GV
->getValueType();
84 EltTys
.emplace_back(Ty
);
86 Buf
.LayoutStruct
= llvm::StructType::get(EltTys
[0]->getContext(), EltTys
);
89 GlobalVariable
*replaceBuffer(CGHLSLRuntime::Buffer
&Buf
) {
90 // Create global variable for CB.
91 GlobalVariable
*CBGV
= new GlobalVariable(
92 Buf
.LayoutStruct
, /*isConstant*/ true,
93 GlobalValue::LinkageTypes::ExternalLinkage
, nullptr,
94 llvm::formatv("{0}{1}", Buf
.Name
, Buf
.IsCBuffer
? ".cb." : ".tb."),
95 GlobalValue::NotThreadLocal
);
97 IRBuilder
<> B(CBGV
->getContext());
98 Value
*ZeroIdx
= B
.getInt32(0);
99 // Replace Const use with CB use.
100 for (auto &[GV
, Offset
] : Buf
.Constants
) {
102 B
.CreateGEP(Buf
.LayoutStruct
, CBGV
, {ZeroIdx
, B
.getInt32(Offset
)});
104 assert(Buf
.LayoutStruct
->getElementType(Offset
) == GV
->getValueType() &&
105 "constant type mismatch");
108 GV
->replaceAllUsesWith(GEP
);
110 GV
->removeDeadConstantUsers();
111 GV
->eraseFromParent();
118 llvm::Triple::ArchType
CGHLSLRuntime::getArch() {
119 return CGM
.getTarget().getTriple().getArch();
122 void CGHLSLRuntime::addConstant(VarDecl
*D
, Buffer
&CB
) {
123 if (D
->getStorageClass() == SC_Static
) {
124 // For static inside cbuffer, take as global static.
125 // Don't add to cbuffer.
130 auto *GV
= cast
<GlobalVariable
>(CGM
.GetAddrOfGlobalVar(D
));
131 // Add debug info for constVal.
132 if (CGDebugInfo
*DI
= CGM
.getModuleDebugInfo())
133 if (CGM
.getCodeGenOpts().getDebugInfo() >=
134 codegenoptions::DebugInfoKind::LimitedDebugInfo
)
135 DI
->EmitGlobalVariable(cast
<GlobalVariable
>(GV
), D
);
137 // FIXME: support packoffset.
138 // See https://github.com/llvm/llvm-project/issues/57914.
140 bool HasUserOffset
= false;
142 unsigned LowerBound
= HasUserOffset
? Offset
: UINT_MAX
;
143 CB
.Constants
.emplace_back(std::make_pair(GV
, LowerBound
));
146 void CGHLSLRuntime::addBufferDecls(const DeclContext
*DC
, Buffer
&CB
) {
147 for (Decl
*it
: DC
->decls()) {
148 if (auto *ConstDecl
= dyn_cast
<VarDecl
>(it
)) {
149 addConstant(ConstDecl
, CB
);
150 } else if (isa
<CXXRecordDecl
, EmptyDecl
>(it
)) {
151 // Nothing to do for this declaration.
152 } else if (isa
<FunctionDecl
>(it
)) {
153 // A function within an cbuffer is effectively a top-level function,
154 // as it only refers to globally scoped declarations.
155 CGM
.EmitTopLevelDecl(it
);
160 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl
*D
) {
161 Buffers
.emplace_back(Buffer(D
));
162 addBufferDecls(D
, Buffers
.back());
165 void CGHLSLRuntime::finishCodeGen() {
166 auto &TargetOpts
= CGM
.getTarget().getTargetOpts();
167 llvm::Module
&M
= CGM
.getModule();
168 Triple
T(M
.getTargetTriple());
169 if (T
.getArch() == Triple::ArchType::dxil
)
170 addDxilValVersion(TargetOpts
.DxilValidatorVersion
, M
);
172 generateGlobalCtorDtorCalls();
173 if (CGM
.getCodeGenOpts().OptimizationLevel
== 0)
174 addDisableOptimizations(M
);
176 const DataLayout
&DL
= M
.getDataLayout();
178 for (auto &Buf
: Buffers
) {
179 layoutBuffer(Buf
, DL
);
180 GlobalVariable
*GV
= replaceBuffer(Buf
);
181 M
.insertGlobalVariable(GV
);
182 llvm::hlsl::ResourceClass RC
= Buf
.IsCBuffer
183 ? llvm::hlsl::ResourceClass::CBuffer
184 : llvm::hlsl::ResourceClass::SRV
;
185 llvm::hlsl::ResourceKind RK
= Buf
.IsCBuffer
186 ? llvm::hlsl::ResourceKind::CBuffer
187 : llvm::hlsl::ResourceKind::TBuffer
;
188 addBufferResourceAnnotation(GV
, RC
, RK
, /*IsROV=*/false,
189 llvm::hlsl::ElementType::Invalid
, Buf
.Binding
);
193 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl
*D
)
194 : Name(D
->getName()), IsCBuffer(D
->isCBuffer()),
195 Binding(D
->getAttr
<HLSLResourceBindingAttr
>()) {}
197 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable
*GV
,
198 llvm::hlsl::ResourceClass RC
,
199 llvm::hlsl::ResourceKind RK
,
201 llvm::hlsl::ElementType ET
,
202 BufferResBinding
&Binding
) {
203 llvm::Module
&M
= CGM
.getModule();
205 NamedMDNode
*ResourceMD
= nullptr;
207 case llvm::hlsl::ResourceClass::UAV
:
208 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.uavs");
210 case llvm::hlsl::ResourceClass::SRV
:
211 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.srvs");
213 case llvm::hlsl::ResourceClass::CBuffer
:
214 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.cbufs");
217 assert(false && "Unsupported buffer type!");
220 assert(ResourceMD
!= nullptr &&
221 "ResourceMD must have been set by the switch above.");
223 llvm::hlsl::FrontendResource
Res(
224 GV
, RK
, ET
, IsROV
, Binding
.Reg
.value_or(UINT_MAX
), Binding
.Space
);
225 ResourceMD
->addOperand(Res
.getMetadata());
228 static llvm::hlsl::ElementType
229 calculateElementType(const ASTContext
&Context
, const clang::Type
*ResourceTy
) {
230 using llvm::hlsl::ElementType
;
232 // TODO: We may need to update this when we add things like ByteAddressBuffer
233 // that don't have a template parameter (or, indeed, an element type).
234 const auto *TST
= ResourceTy
->getAs
<TemplateSpecializationType
>();
235 assert(TST
&& "Resource types must be template specializations");
236 ArrayRef
<TemplateArgument
> Args
= TST
->template_arguments();
237 assert(!Args
.empty() && "Resource has no element type");
239 // At this point we have a resource with an element type, so we can assume
240 // that it's valid or we would have diagnosed the error earlier.
241 QualType ElTy
= Args
[0].getAsType();
243 // We should either have a basic type or a vector of a basic type.
244 if (const auto *VecTy
= ElTy
->getAs
<clang::VectorType
>())
245 ElTy
= VecTy
->getElementType();
247 if (ElTy
->isSignedIntegerType()) {
248 switch (Context
.getTypeSize(ElTy
)) {
250 return ElementType::I16
;
252 return ElementType::I32
;
254 return ElementType::I64
;
256 } else if (ElTy
->isUnsignedIntegerType()) {
257 switch (Context
.getTypeSize(ElTy
)) {
259 return ElementType::U16
;
261 return ElementType::U32
;
263 return ElementType::U64
;
265 } else if (ElTy
->isSpecificBuiltinType(BuiltinType::Half
))
266 return ElementType::F16
;
267 else if (ElTy
->isSpecificBuiltinType(BuiltinType::Float
))
268 return ElementType::F32
;
269 else if (ElTy
->isSpecificBuiltinType(BuiltinType::Double
))
270 return ElementType::F64
;
272 // TODO: We need to handle unorm/snorm float types here once we support them
273 llvm_unreachable("Invalid element type for resource");
276 void CGHLSLRuntime::annotateHLSLResource(const VarDecl
*D
, GlobalVariable
*GV
) {
277 const Type
*Ty
= D
->getType()->getPointeeOrArrayElementType();
280 const auto *RD
= Ty
->getAsCXXRecordDecl();
283 const auto *HLSLResAttr
= RD
->getAttr
<HLSLResourceAttr
>();
284 const auto *HLSLResClassAttr
= RD
->getAttr
<HLSLResourceClassAttr
>();
285 if (!HLSLResAttr
|| !HLSLResClassAttr
)
288 llvm::hlsl::ResourceClass RC
= HLSLResClassAttr
->getResourceClass();
289 llvm::hlsl::ResourceKind RK
= HLSLResAttr
->getResourceKind();
290 bool IsROV
= HLSLResAttr
->getIsROV();
291 llvm::hlsl::ElementType ET
= calculateElementType(CGM
.getContext(), Ty
);
293 BufferResBinding
Binding(D
->getAttr
<HLSLResourceBindingAttr
>());
294 addBufferResourceAnnotation(GV
, RC
, RK
, IsROV
, ET
, Binding
);
297 CGHLSLRuntime::BufferResBinding::BufferResBinding(
298 HLSLResourceBindingAttr
*Binding
) {
300 llvm::APInt
RegInt(64, 0);
301 Binding
->getSlot().substr(1).getAsInteger(10, RegInt
);
302 Reg
= RegInt
.getLimitedValue();
303 llvm::APInt
SpaceInt(64, 0);
304 Binding
->getSpace().substr(5).getAsInteger(10, SpaceInt
);
305 Space
= SpaceInt
.getLimitedValue();
311 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
312 const FunctionDecl
*FD
, llvm::Function
*Fn
) {
313 const auto *ShaderAttr
= FD
->getAttr
<HLSLShaderAttr
>();
314 assert(ShaderAttr
&& "All entry functions must have a HLSLShaderAttr");
315 const StringRef ShaderAttrKindStr
= "hlsl.shader";
316 Fn
->addFnAttr(ShaderAttrKindStr
,
317 llvm::Triple::getEnvironmentTypeName(ShaderAttr
->getType()));
318 if (HLSLNumThreadsAttr
*NumThreadsAttr
= FD
->getAttr
<HLSLNumThreadsAttr
>()) {
319 const StringRef NumThreadsKindStr
= "hlsl.numthreads";
320 std::string NumThreadsStr
=
321 formatv("{0},{1},{2}", NumThreadsAttr
->getX(), NumThreadsAttr
->getY(),
322 NumThreadsAttr
->getZ());
323 Fn
->addFnAttr(NumThreadsKindStr
, NumThreadsStr
);
327 static Value
*buildVectorInput(IRBuilder
<> &B
, Function
*F
, llvm::Type
*Ty
) {
328 if (const auto *VT
= dyn_cast
<FixedVectorType
>(Ty
)) {
329 Value
*Result
= PoisonValue::get(Ty
);
330 for (unsigned I
= 0; I
< VT
->getNumElements(); ++I
) {
331 Value
*Elt
= B
.CreateCall(F
, {B
.getInt32(I
)});
332 Result
= B
.CreateInsertElement(Result
, Elt
, I
);
336 return B
.CreateCall(F
, {B
.getInt32(0)});
339 llvm::Value
*CGHLSLRuntime::emitInputSemantic(IRBuilder
<> &B
,
340 const ParmVarDecl
&D
,
342 assert(D
.hasAttrs() && "Entry parameter missing annotation attribute!");
343 if (D
.hasAttr
<HLSLSV_GroupIndexAttr
>()) {
344 llvm::Function
*DxGroupIndex
=
345 CGM
.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group
);
346 return B
.CreateCall(FunctionCallee(DxGroupIndex
));
348 if (D
.hasAttr
<HLSLSV_DispatchThreadIDAttr
>()) {
349 llvm::Function
*ThreadIDIntrinsic
=
350 CGM
.getIntrinsic(getThreadIdIntrinsic());
351 return buildVectorInput(B
, ThreadIDIntrinsic
, Ty
);
353 assert(false && "Unhandled parameter attribute");
357 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl
*FD
,
358 llvm::Function
*Fn
) {
359 llvm::Module
&M
= CGM
.getModule();
360 llvm::LLVMContext
&Ctx
= M
.getContext();
361 auto *EntryTy
= llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx
), false);
363 Function::Create(EntryTy
, Function::ExternalLinkage
, FD
->getName(), &M
);
365 // Copy function attributes over, we have no argument or return attributes
366 // that can be valid on the real entry.
367 AttributeList NewAttrs
= AttributeList::get(Ctx
, AttributeList::FunctionIndex
,
368 Fn
->getAttributes().getFnAttrs());
369 EntryFn
->setAttributes(NewAttrs
);
370 setHLSLEntryAttributes(FD
, EntryFn
);
372 // Set the called function as internal linkage.
373 Fn
->setLinkage(GlobalValue::InternalLinkage
);
375 BasicBlock
*BB
= BasicBlock::Create(Ctx
, "entry", EntryFn
);
377 llvm::SmallVector
<Value
*> Args
;
378 // FIXME: support struct parameters where semantics are on members.
379 // See: https://github.com/llvm/llvm-project/issues/57874
380 unsigned SRetOffset
= 0;
381 for (const auto &Param
: Fn
->args()) {
382 if (Param
.hasStructRetAttr()) {
383 // FIXME: support output.
384 // See: https://github.com/llvm/llvm-project/issues/57874
386 Args
.emplace_back(PoisonValue::get(Param
.getType()));
389 const ParmVarDecl
*PD
= FD
->getParamDecl(Param
.getArgNo() - SRetOffset
);
390 Args
.push_back(emitInputSemantic(B
, *PD
, Param
.getType()));
393 CallInst
*CI
= B
.CreateCall(FunctionCallee(Fn
), Args
);
395 // FIXME: Handle codegen for return type semantics.
396 // See: https://github.com/llvm/llvm-project/issues/57875
400 static void gatherFunctions(SmallVectorImpl
<Function
*> &Fns
, llvm::Module
&M
,
403 M
.getNamedGlobal(CtorOrDtor
? "llvm.global_ctors" : "llvm.global_dtors");
406 const auto *CA
= dyn_cast
<ConstantArray
>(GV
->getInitializer());
409 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410 // HLSL neither supports priorities or COMDat values, so we will check those
411 // in an assert but not handle them.
413 llvm::SmallVector
<Function
*> CtorFns
;
414 for (const auto &Ctor
: CA
->operands()) {
415 if (isa
<ConstantAggregateZero
>(Ctor
))
417 ConstantStruct
*CS
= cast
<ConstantStruct
>(Ctor
);
419 assert(cast
<ConstantInt
>(CS
->getOperand(0))->getValue() == 65535 &&
420 "HLSL doesn't support setting priority for global ctors.");
421 assert(isa
<ConstantPointerNull
>(CS
->getOperand(2)) &&
422 "HLSL doesn't support COMDat for global ctors.");
423 Fns
.push_back(cast
<Function
>(CS
->getOperand(1)));
427 void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428 llvm::Module
&M
= CGM
.getModule();
429 SmallVector
<Function
*> CtorFns
;
430 SmallVector
<Function
*> DtorFns
;
431 gatherFunctions(CtorFns
, M
, true);
432 gatherFunctions(DtorFns
, M
, false);
434 // Insert a call to the global constructor at the beginning of the entry block
435 // to externally exported functions. This is a bit of a hack, but HLSL allows
436 // global constructors, but doesn't support driver initialization of globals.
437 for (auto &F
: M
.functions()) {
438 if (!F
.hasFnAttribute("hlsl.shader"))
440 IRBuilder
<> B(&F
.getEntryBlock(), F
.getEntryBlock().begin());
441 for (auto *Fn
: CtorFns
)
442 B
.CreateCall(FunctionCallee(Fn
));
444 // Insert global dtors before the terminator of the last instruction
445 B
.SetInsertPoint(F
.back().getTerminator());
446 for (auto *Fn
: DtorFns
)
447 B
.CreateCall(FunctionCallee(Fn
));
450 // No need to keep global ctors/dtors for non-lib profile after call to
451 // ctors/dtors added for entry.
452 Triple
T(M
.getTargetTriple());
453 if (T
.getEnvironment() != Triple::EnvironmentType::Library
) {
454 if (auto *GV
= M
.getNamedGlobal("llvm.global_ctors"))
455 GV
->eraseFromParent();
456 if (auto *GV
= M
.getNamedGlobal("llvm.global_dtors"))
457 GV
->eraseFromParent();