1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This provides an abstract class for HLSL code generation. Concrete
// subclasses of this implement code generation for specific HLSL
// runtime libraries.
13 //===----------------------------------------------------------------------===//
#include "CGHLSLRuntime.h"
#include "CGDebugInfo.h"
#include "CodeGenModule.h"
#include "clang/AST/Decl.h"
#include "clang/Basic/TargetOptions.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace clang;
using namespace CodeGen;
using namespace clang::hlsl;
using namespace llvm;
32 void addDxilValVersion(StringRef ValVersionStr
, llvm::Module
&M
) {
33 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
34 // Assume ValVersionStr is legal here.
36 if (Version
.tryParse(ValVersionStr
) || Version
.getBuild() ||
37 Version
.getSubminor() || !Version
.getMinor()) {
41 uint64_t Major
= Version
.getMajor();
42 uint64_t Minor
= *Version
.getMinor();
44 auto &Ctx
= M
.getContext();
45 IRBuilder
<> B(M
.getContext());
46 MDNode
*Val
= MDNode::get(Ctx
, {ConstantAsMetadata::get(B
.getInt32(Major
)),
47 ConstantAsMetadata::get(B
.getInt32(Minor
))});
48 StringRef DXILValKey
= "dx.valver";
49 auto *DXILValMD
= M
.getOrInsertNamedMetadata(DXILValKey
);
50 DXILValMD
->addOperand(Val
);
52 void addDisableOptimizations(llvm::Module
&M
) {
53 StringRef Key
= "dx.disable_optimizations";
54 M
.addModuleFlag(llvm::Module::ModFlagBehavior::Override
, Key
, 1);
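// A cbuffer is translated into a global variable that conceptually lives in a
// special address space. Expressed in C, the HLSL
//   cbuffer A {
//     float a;
//     float b;
//   }
//   float foo() { return a + b; }
//
// will be translated into
//
//   struct A {
//     float a;
//     float b;
//   } cbuffer_A __attribute__((address_space(4)));
//   float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// layoutBuffer creates the struct A type.
// replaceBuffer replaces uses of the global variables a and b with
// cbuffer_A.a and cbuffer_A.b.
// NOTE: replaceBuffer currently creates the buffer global without an explicit
// address space.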
76 void layoutBuffer(CGHLSLRuntime::Buffer
&Buf
, const DataLayout
&DL
) {
77 if (Buf
.Constants
.empty())
80 std::vector
<llvm::Type
*> EltTys
;
81 for (auto &Const
: Buf
.Constants
) {
82 GlobalVariable
*GV
= Const
.first
;
83 Const
.second
= EltTys
.size();
84 llvm::Type
*Ty
= GV
->getValueType();
85 EltTys
.emplace_back(Ty
);
87 Buf
.LayoutStruct
= llvm::StructType::get(EltTys
[0]->getContext(), EltTys
);
90 GlobalVariable
*replaceBuffer(CGHLSLRuntime::Buffer
&Buf
) {
91 // Create global variable for CB.
92 GlobalVariable
*CBGV
= new GlobalVariable(
93 Buf
.LayoutStruct
, /*isConstant*/ true,
94 GlobalValue::LinkageTypes::ExternalLinkage
, nullptr,
95 llvm::formatv("{0}{1}", Buf
.Name
, Buf
.IsCBuffer
? ".cb." : ".tb."),
96 GlobalValue::NotThreadLocal
);
98 IRBuilder
<> B(CBGV
->getContext());
99 Value
*ZeroIdx
= B
.getInt32(0);
100 // Replace Const use with CB use.
101 for (auto &[GV
, Offset
] : Buf
.Constants
) {
103 B
.CreateGEP(Buf
.LayoutStruct
, CBGV
, {ZeroIdx
, B
.getInt32(Offset
)});
105 assert(Buf
.LayoutStruct
->getElementType(Offset
) == GV
->getValueType() &&
106 "constant type mismatch");
109 GV
->replaceAllUsesWith(GEP
);
111 GV
->removeDeadConstantUsers();
112 GV
->eraseFromParent();
119 void CGHLSLRuntime::addConstant(VarDecl
*D
, Buffer
&CB
) {
120 if (D
->getStorageClass() == SC_Static
) {
121 // For static inside cbuffer, take as global static.
122 // Don't add to cbuffer.
127 auto *GV
= cast
<GlobalVariable
>(CGM
.GetAddrOfGlobalVar(D
));
128 // Add debug info for constVal.
129 if (CGDebugInfo
*DI
= CGM
.getModuleDebugInfo())
130 if (CGM
.getCodeGenOpts().getDebugInfo() >=
131 codegenoptions::DebugInfoKind::LimitedDebugInfo
)
132 DI
->EmitGlobalVariable(cast
<GlobalVariable
>(GV
), D
);
134 // FIXME: support packoffset.
135 // See https://github.com/llvm/llvm-project/issues/57914.
137 bool HasUserOffset
= false;
139 unsigned LowerBound
= HasUserOffset
? Offset
: UINT_MAX
;
140 CB
.Constants
.emplace_back(std::make_pair(GV
, LowerBound
));
143 void CGHLSLRuntime::addBufferDecls(const DeclContext
*DC
, Buffer
&CB
) {
144 for (Decl
*it
: DC
->decls()) {
145 if (auto *ConstDecl
= dyn_cast
<VarDecl
>(it
)) {
146 addConstant(ConstDecl
, CB
);
147 } else if (isa
<CXXRecordDecl
, EmptyDecl
>(it
)) {
148 // Nothing to do for this declaration.
149 } else if (isa
<FunctionDecl
>(it
)) {
150 // A function within an cbuffer is effectively a top-level function,
151 // as it only refers to globally scoped declarations.
152 CGM
.EmitTopLevelDecl(it
);
157 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl
*D
) {
158 Buffers
.emplace_back(Buffer(D
));
159 addBufferDecls(D
, Buffers
.back());
162 void CGHLSLRuntime::finishCodeGen() {
163 auto &TargetOpts
= CGM
.getTarget().getTargetOpts();
164 llvm::Module
&M
= CGM
.getModule();
165 Triple
T(M
.getTargetTriple());
166 if (T
.getArch() == Triple::ArchType::dxil
)
167 addDxilValVersion(TargetOpts
.DxilValidatorVersion
, M
);
169 generateGlobalCtorDtorCalls();
170 if (CGM
.getCodeGenOpts().OptimizationLevel
== 0)
171 addDisableOptimizations(M
);
173 const DataLayout
&DL
= M
.getDataLayout();
175 for (auto &Buf
: Buffers
) {
176 layoutBuffer(Buf
, DL
);
177 GlobalVariable
*GV
= replaceBuffer(Buf
);
178 M
.insertGlobalVariable(GV
);
179 llvm::hlsl::ResourceClass RC
= Buf
.IsCBuffer
180 ? llvm::hlsl::ResourceClass::CBuffer
181 : llvm::hlsl::ResourceClass::SRV
;
182 llvm::hlsl::ResourceKind RK
= Buf
.IsCBuffer
183 ? llvm::hlsl::ResourceKind::CBuffer
184 : llvm::hlsl::ResourceKind::TBuffer
;
186 Buf
.Name
.str() + (Buf
.IsCBuffer
? ".cb." : ".tb.") + "ty";
187 addBufferResourceAnnotation(GV
, TyName
, RC
, RK
, Buf
.Binding
);
191 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl
*D
)
192 : Name(D
->getName()), IsCBuffer(D
->isCBuffer()),
193 Binding(D
->getAttr
<HLSLResourceBindingAttr
>()) {}
195 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable
*GV
,
196 llvm::StringRef TyName
,
197 llvm::hlsl::ResourceClass RC
,
198 llvm::hlsl::ResourceKind RK
,
199 BufferResBinding
&Binding
) {
200 llvm::Module
&M
= CGM
.getModule();
202 NamedMDNode
*ResourceMD
= nullptr;
204 case llvm::hlsl::ResourceClass::UAV
:
205 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.uavs");
207 case llvm::hlsl::ResourceClass::SRV
:
208 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.srvs");
210 case llvm::hlsl::ResourceClass::CBuffer
:
211 ResourceMD
= M
.getOrInsertNamedMetadata("hlsl.cbufs");
214 assert(false && "Unsupported buffer type!");
218 assert(ResourceMD
!= nullptr &&
219 "ResourceMD must have been set by the switch above.");
221 llvm::hlsl::FrontendResource
Res(
222 GV
, TyName
, RK
, Binding
.Reg
.value_or(UINT_MAX
), Binding
.Space
);
223 ResourceMD
->addOperand(Res
.getMetadata());
226 static llvm::hlsl::ResourceKind
227 castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK
) {
229 case HLSLResourceAttr::ResourceKind::Texture1D
:
230 return llvm::hlsl::ResourceKind::Texture1D
;
231 case HLSLResourceAttr::ResourceKind::Texture2D
:
232 return llvm::hlsl::ResourceKind::Texture2D
;
233 case HLSLResourceAttr::ResourceKind::Texture2DMS
:
234 return llvm::hlsl::ResourceKind::Texture2DMS
;
235 case HLSLResourceAttr::ResourceKind::Texture3D
:
236 return llvm::hlsl::ResourceKind::Texture3D
;
237 case HLSLResourceAttr::ResourceKind::TextureCube
:
238 return llvm::hlsl::ResourceKind::TextureCube
;
239 case HLSLResourceAttr::ResourceKind::Texture1DArray
:
240 return llvm::hlsl::ResourceKind::Texture1DArray
;
241 case HLSLResourceAttr::ResourceKind::Texture2DArray
:
242 return llvm::hlsl::ResourceKind::Texture2DArray
;
243 case HLSLResourceAttr::ResourceKind::Texture2DMSArray
:
244 return llvm::hlsl::ResourceKind::Texture2DMSArray
;
245 case HLSLResourceAttr::ResourceKind::TextureCubeArray
:
246 return llvm::hlsl::ResourceKind::TextureCubeArray
;
247 case HLSLResourceAttr::ResourceKind::TypedBuffer
:
248 return llvm::hlsl::ResourceKind::TypedBuffer
;
249 case HLSLResourceAttr::ResourceKind::RawBuffer
:
250 return llvm::hlsl::ResourceKind::RawBuffer
;
251 case HLSLResourceAttr::ResourceKind::StructuredBuffer
:
252 return llvm::hlsl::ResourceKind::StructuredBuffer
;
253 case HLSLResourceAttr::ResourceKind::CBufferKind
:
254 return llvm::hlsl::ResourceKind::CBuffer
;
255 case HLSLResourceAttr::ResourceKind::SamplerKind
:
256 return llvm::hlsl::ResourceKind::Sampler
;
257 case HLSLResourceAttr::ResourceKind::TBuffer
:
258 return llvm::hlsl::ResourceKind::TBuffer
;
259 case HLSLResourceAttr::ResourceKind::RTAccelerationStructure
:
260 return llvm::hlsl::ResourceKind::RTAccelerationStructure
;
261 case HLSLResourceAttr::ResourceKind::FeedbackTexture2D
:
262 return llvm::hlsl::ResourceKind::FeedbackTexture2D
;
263 case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray
:
264 return llvm::hlsl::ResourceKind::FeedbackTexture2DArray
;
266 // Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to
267 // hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for
268 // HLSLResourceAttr::ResourceKind.
270 static_cast<uint32_t>(
271 HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray
) ==
272 (static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries
) - 2));
273 llvm_unreachable("all switch cases should be covered");
276 void CGHLSLRuntime::annotateHLSLResource(const VarDecl
*D
, GlobalVariable
*GV
) {
277 const Type
*Ty
= D
->getType()->getPointeeOrArrayElementType();
280 const auto *RD
= Ty
->getAsCXXRecordDecl();
283 const auto *Attr
= RD
->getAttr
<HLSLResourceAttr
>();
287 HLSLResourceAttr::ResourceClass RC
= Attr
->getResourceType();
288 llvm::hlsl::ResourceKind RK
=
289 castResourceShapeToResourceKind(Attr
->getResourceShape());
292 BufferResBinding
Binding(D
->getAttr
<HLSLResourceBindingAttr
>());
293 addBufferResourceAnnotation(GV
, QT
.getAsString(),
294 static_cast<llvm::hlsl::ResourceClass
>(RC
), RK
,
298 CGHLSLRuntime::BufferResBinding::BufferResBinding(
299 HLSLResourceBindingAttr
*Binding
) {
301 llvm::APInt
RegInt(64, 0);
302 Binding
->getSlot().substr(1).getAsInteger(10, RegInt
);
303 Reg
= RegInt
.getLimitedValue();
304 llvm::APInt
SpaceInt(64, 0);
305 Binding
->getSpace().substr(5).getAsInteger(10, SpaceInt
);
306 Space
= SpaceInt
.getLimitedValue();
312 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
313 const FunctionDecl
*FD
, llvm::Function
*Fn
) {
314 const auto *ShaderAttr
= FD
->getAttr
<HLSLShaderAttr
>();
315 assert(ShaderAttr
&& "All entry functions must have a HLSLShaderAttr");
316 const StringRef ShaderAttrKindStr
= "hlsl.shader";
317 Fn
->addFnAttr(ShaderAttrKindStr
,
318 ShaderAttr
->ConvertShaderTypeToStr(ShaderAttr
->getType()));
319 if (HLSLNumThreadsAttr
*NumThreadsAttr
= FD
->getAttr
<HLSLNumThreadsAttr
>()) {
320 const StringRef NumThreadsKindStr
= "hlsl.numthreads";
321 std::string NumThreadsStr
=
322 formatv("{0},{1},{2}", NumThreadsAttr
->getX(), NumThreadsAttr
->getY(),
323 NumThreadsAttr
->getZ());
324 Fn
->addFnAttr(NumThreadsKindStr
, NumThreadsStr
);
328 static Value
*buildVectorInput(IRBuilder
<> &B
, Function
*F
, llvm::Type
*Ty
) {
329 if (const auto *VT
= dyn_cast
<FixedVectorType
>(Ty
)) {
330 Value
*Result
= PoisonValue::get(Ty
);
331 for (unsigned I
= 0; I
< VT
->getNumElements(); ++I
) {
332 Value
*Elt
= B
.CreateCall(F
, {B
.getInt32(I
)});
333 Result
= B
.CreateInsertElement(Result
, Elt
, I
);
337 return B
.CreateCall(F
, {B
.getInt32(0)});
340 llvm::Value
*CGHLSLRuntime::emitInputSemantic(IRBuilder
<> &B
,
341 const ParmVarDecl
&D
,
343 assert(D
.hasAttrs() && "Entry parameter missing annotation attribute!");
344 if (D
.hasAttr
<HLSLSV_GroupIndexAttr
>()) {
345 llvm::Function
*DxGroupIndex
=
346 CGM
.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group
);
347 return B
.CreateCall(FunctionCallee(DxGroupIndex
));
349 if (D
.hasAttr
<HLSLSV_DispatchThreadIDAttr
>()) {
350 llvm::Function
*DxThreadID
= CGM
.getIntrinsic(Intrinsic::dx_thread_id
);
351 return buildVectorInput(B
, DxThreadID
, Ty
);
353 assert(false && "Unhandled parameter attribute");
357 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl
*FD
,
358 llvm::Function
*Fn
) {
359 llvm::Module
&M
= CGM
.getModule();
360 llvm::LLVMContext
&Ctx
= M
.getContext();
361 auto *EntryTy
= llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx
), false);
363 Function::Create(EntryTy
, Function::ExternalLinkage
, FD
->getName(), &M
);
365 // Copy function attributes over, we have no argument or return attributes
366 // that can be valid on the real entry.
367 AttributeList NewAttrs
= AttributeList::get(Ctx
, AttributeList::FunctionIndex
,
368 Fn
->getAttributes().getFnAttrs());
369 EntryFn
->setAttributes(NewAttrs
);
370 setHLSLEntryAttributes(FD
, EntryFn
);
372 // Set the called function as internal linkage.
373 Fn
->setLinkage(GlobalValue::InternalLinkage
);
375 BasicBlock
*BB
= BasicBlock::Create(Ctx
, "entry", EntryFn
);
377 llvm::SmallVector
<Value
*> Args
;
378 // FIXME: support struct parameters where semantics are on members.
379 // See: https://github.com/llvm/llvm-project/issues/57874
380 unsigned SRetOffset
= 0;
381 for (const auto &Param
: Fn
->args()) {
382 if (Param
.hasStructRetAttr()) {
383 // FIXME: support output.
384 // See: https://github.com/llvm/llvm-project/issues/57874
386 Args
.emplace_back(PoisonValue::get(Param
.getType()));
389 const ParmVarDecl
*PD
= FD
->getParamDecl(Param
.getArgNo() - SRetOffset
);
390 Args
.push_back(emitInputSemantic(B
, *PD
, Param
.getType()));
393 CallInst
*CI
= B
.CreateCall(FunctionCallee(Fn
), Args
);
395 // FIXME: Handle codegen for return type semantics.
396 // See: https://github.com/llvm/llvm-project/issues/57875
400 static void gatherFunctions(SmallVectorImpl
<Function
*> &Fns
, llvm::Module
&M
,
403 M
.getNamedGlobal(CtorOrDtor
? "llvm.global_ctors" : "llvm.global_dtors");
406 const auto *CA
= dyn_cast
<ConstantArray
>(GV
->getInitializer());
409 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410 // HLSL neither supports priorities or COMDat values, so we will check those
411 // in an assert but not handle them.
413 llvm::SmallVector
<Function
*> CtorFns
;
414 for (const auto &Ctor
: CA
->operands()) {
415 if (isa
<ConstantAggregateZero
>(Ctor
))
417 ConstantStruct
*CS
= cast
<ConstantStruct
>(Ctor
);
419 assert(cast
<ConstantInt
>(CS
->getOperand(0))->getValue() == 65535 &&
420 "HLSL doesn't support setting priority for global ctors.");
421 assert(isa
<ConstantPointerNull
>(CS
->getOperand(2)) &&
422 "HLSL doesn't support COMDat for global ctors.");
423 Fns
.push_back(cast
<Function
>(CS
->getOperand(1)));
427 void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428 llvm::Module
&M
= CGM
.getModule();
429 SmallVector
<Function
*> CtorFns
;
430 SmallVector
<Function
*> DtorFns
;
431 gatherFunctions(CtorFns
, M
, true);
432 gatherFunctions(DtorFns
, M
, false);
434 // Insert a call to the global constructor at the beginning of the entry block
435 // to externally exported functions. This is a bit of a hack, but HLSL allows
436 // global constructors, but doesn't support driver initialization of globals.
437 for (auto &F
: M
.functions()) {
438 if (!F
.hasFnAttribute("hlsl.shader"))
440 IRBuilder
<> B(&F
.getEntryBlock(), F
.getEntryBlock().begin());
441 for (auto *Fn
: CtorFns
)
442 B
.CreateCall(FunctionCallee(Fn
));
444 // Insert global dtors before the terminator of the last instruction
445 B
.SetInsertPoint(F
.back().getTerminator());
446 for (auto *Fn
: DtorFns
)
447 B
.CreateCall(FunctionCallee(Fn
));
450 // No need to keep global ctors/dtors for non-lib profile after call to
451 // ctors/dtors added for entry.
452 Triple
T(M
.getTargetTriple());
453 if (T
.getEnvironment() != Triple::EnvironmentType::Library
) {
454 if (auto *GV
= M
.getNamedGlobal("llvm.global_ctors"))
455 GV
->eraseFromParent();
456 if (auto *GV
= M
.getNamedGlobal("llvm.global_dtors"))
457 GV
->eraseFromParent();