1 //===- DXILOpBuilder.cpp - Helper class to build DXILOp functions ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 /// \file This file contains a class to help build DXIL op functions.
10 //===----------------------------------------------------------------------===//
12 #include "DXILOpBuilder.h"
13 #include "DXILConstants.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/IR/Module.h"
16 #include "llvm/Support/DXILABI.h"
17 #include "llvm/Support/ErrorHandling.h"
20 using namespace llvm::dxil
;
// Prefix used by constructOverloadName when forming DXIL op function names.
22 constexpr StringLiteral DXILOpNamePrefix
= "dx.op.";
// Kinds of overload type a DXIL op can have, stored as uint16_t.
// createDXILOpCall tests a kind as a bit mask against
// OpCodeProperty::OverloadTys.
// NOTE(review): only the UserDefineType enumerator is visible in this
// listing; the remaining enumerators and their values are not shown.
26 enum OverloadKind
: uint16_t {
36 UserDefineType
= 1 << 9,
/// Map an OverloadKind to its type-name string.
/// NOTE(review): the per-case return statements are not visible in this
/// listing, so the exact strings cannot be confirmed here. Control that
/// leaves the case list reaches
/// llvm_unreachable("invalid overload type for name") - presumably for the
/// VOID, ObjectType and UserDefineType cases; verify against the full file.
42 static const char *getOverloadTypeName(OverloadKind Kind
) {
44 case OverloadKind::HALF
:
46 case OverloadKind::FLOAT
:
48 case OverloadKind::DOUBLE
:
50 case OverloadKind::I1
:
52 case OverloadKind::I8
:
54 case OverloadKind::I16
:
56 case OverloadKind::I32
:
58 case OverloadKind::I64
:
60 case OverloadKind::VOID
:
61 case OverloadKind::ObjectType
:
62 case OverloadKind::UserDefineType
:
65 llvm_unreachable("invalid overload type for name");
/// Classify the LLVM type \p Ty as an OverloadKind, based on its TypeID.
/// Mappings visible here: double -> DOUBLE, integer -> one of
/// I1/I8/I16/I32/I64 chosen from the bit width, pointer -> UserDefineType,
/// struct -> ObjectType. Integer widths and type IDs that are not handled
/// reach llvm_unreachable("invalid overload type").
/// NOTE(review): the case labels guarding the VOID, HALF and FLOAT returns
/// and the individual bit-width labels are not visible in this listing -
/// presumably void/half/float type IDs and widths 1/8/16/32/64; verify.
69 static OverloadKind
getOverloadKind(Type
*Ty
) {
70 Type::TypeID T
= Ty
->getTypeID();
73 return OverloadKind::VOID
;
75 return OverloadKind::HALF
;
77 return OverloadKind::FLOAT
;
78 case Type::DoubleTyID
:
79 return OverloadKind::DOUBLE
;
80 case Type::IntegerTyID
: {
81 IntegerType
*ITy
= cast
<IntegerType
>(Ty
);
82 unsigned Bits
= ITy
->getBitWidth();
85 return OverloadKind::I1
;
87 return OverloadKind::I8
;
89 return OverloadKind::I16
;
91 return OverloadKind::I32
;
93 return OverloadKind::I64
;
95 llvm_unreachable("invalid overload type");
96 return OverloadKind::VOID
;
99 case Type::PointerTyID
:
100 return OverloadKind::UserDefineType
;
101 case Type::StructTyID
:
102 return OverloadKind::ObjectType
;
104 llvm_unreachable("invalid overload type");
105 return OverloadKind::VOID
;
/// Produce the name string used for overload type \p Ty of kind \p Kind.
/// Kinds below UserDefineType use getOverloadTypeName(Kind). UserDefineType
/// and ObjectType both cast \p Ty to StructType and return its struct name.
/// NOTE(review): getOverloadKind returns UserDefineType for pointer types,
/// while the UserDefineType branch here does cast<StructType>(Ty) - confirm
/// that no pointer type can reach this branch.
/// NOTE(review): the remainder of the function (a raw_string_ostream over a
/// local string) is not visible in this listing.
109 static std::string
getTypeName(OverloadKind Kind
, Type
*Ty
) {
110 if (Kind
< OverloadKind::UserDefineType
) {
111 return getOverloadTypeName(Kind
);
112 } else if (Kind
== OverloadKind::UserDefineType
) {
113 StructType
*ST
= cast
<StructType
>(Ty
);
114 return ST
->getStructName().str();
115 } else if (Kind
== OverloadKind::ObjectType
) {
116 StructType
*ST
= cast
<StructType
>(Ty
);
117 return ST
->getStructName().str();
120 raw_string_ostream
OS(Str
);
126 // Static properties of a DXIL operation, obtained via getOpCodeProperty.
127 struct OpCodeProperty
{
129 // Offset of the opcode name in DXILOpCodeNameTable.
130 unsigned OpCodeNameOffset
;
131 dxil::OpCodeClass OpCodeClass
;
132 // Offset of the opcode class name in DXILOpCodeClassNameTable.
133 unsigned OpCodeClassNameOffset
;
// Checked against OverloadKind values as a bit mask in createDXILOpCall.
134 uint16_t OverloadTys
;
135 llvm::Attribute::AttrKind FuncAttr
;
136 int OverloadParamIndex
; // Index of the parameter that controls the overload.
137 // When < 0, there should be only one overload type.
138 unsigned NumOfParameters
; // Number of parameters, including the return value.
139 unsigned ParameterTableOffset
; // Offset in ParameterTable.
142 // Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
143 // getOpCodeParameterKind, which are generated by TableGen.
144 #define DXIL_OP_OPERATION_TABLE
145 #include "DXILOperation.inc"
146 #undef DXIL_OP_OPERATION_TABLE
/// Build the name of the DXIL op function for the op described by \p Prop.
/// For OverloadKind::VOID the name is DXILOpNamePrefix followed by the opcode
/// class name. Otherwise "." and getTypeName(Kind, Ty) are appended to that.
/// NOTE(review): the end of the second return expression is not visible in
/// this listing.
148 static std::string
constructOverloadName(OverloadKind Kind
, Type
*Ty
,
149 const OpCodeProperty
&Prop
) {
150 if (Kind
== OverloadKind::VOID
) {
151 return (Twine(DXILOpNamePrefix
) + getOpCodeClassName(Prop
)).str();
153 return (Twine(DXILOpNamePrefix
) + getOpCodeClassName(Prop
) + "." +
154 getTypeName(Kind
, Ty
))
158 static std::string
constructOverloadTypeName(OverloadKind Kind
,
159 StringRef TypeName
) {
160 if (Kind
== OverloadKind::VOID
)
161 return TypeName
.str();
163 assert(Kind
< OverloadKind::UserDefineType
&& "invalid overload kind");
164 return (Twine(TypeName
) + getOverloadTypeName(Kind
)).str();
/// Look up the struct type called \p Name in the context via
/// StructType::getTypeByName; the last visible statement creates a new named
/// struct type with element types \p EltTys.
/// NOTE(review): the end of the signature and the statements between the
/// lookup and the create call are not visible in this listing - presumably
/// an existing type is returned when the lookup succeeds; verify.
167 static StructType
*getOrCreateStructType(StringRef Name
,
168 ArrayRef
<Type
*> EltTys
,
170 StructType
*ST
= StructType::getTypeByName(Ctx
, Name
);
174 return StructType::create(Ctx
, EltTys
, Name
);
177 static StructType
*getResRetType(Type
*OverloadTy
, LLVMContext
&Ctx
) {
178 OverloadKind Kind
= getOverloadKind(OverloadTy
);
179 std::string TypeName
= constructOverloadTypeName(Kind
, "dx.types.ResRet.");
180 Type
*FieldTypes
[5] = {OverloadTy
, OverloadTy
, OverloadTy
, OverloadTy
,
181 Type::getInt32Ty(Ctx
)};
182 return getOrCreateStructType(TypeName
, FieldTypes
, Ctx
);
/// Get the struct type named "dx.types.Handle" through getOrCreateStructType,
/// passing PointerType::getUnqual(Ctx) as its element type.
/// NOTE(review): the remaining argument(s) of the call are not visible in
/// this listing.
185 static StructType
*getHandleType(LLVMContext
&Ctx
) {
186 return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx
),
/// Translate a DXIL ParameterKind into an LLVM type, using the LLVMContext of
/// \p OverloadTy. Scalar kinds map to the matching void/half/float/double/iN
/// type; ResourceRet maps to getResRetType(OverloadTy, Ctx) and DXILHandle to
/// getHandleType(Ctx). Kinds that are not handled reach
/// llvm_unreachable("Invalid parameter kind").
/// NOTE(review): the statement under the Overload case is not visible in this
/// listing - presumably it returns \p OverloadTy; verify. No case for
/// ParameterKind::CBufferRet is visible here either.
190 static Type
*getTypeFromParameterKind(ParameterKind Kind
, Type
*OverloadTy
) {
191 auto &Ctx
= OverloadTy
->getContext();
193 case ParameterKind::Void
:
194 return Type::getVoidTy(Ctx
);
195 case ParameterKind::Half
:
196 return Type::getHalfTy(Ctx
);
197 case ParameterKind::Float
:
198 return Type::getFloatTy(Ctx
);
199 case ParameterKind::Double
:
200 return Type::getDoubleTy(Ctx
);
201 case ParameterKind::I1
:
202 return Type::getInt1Ty(Ctx
);
203 case ParameterKind::I8
:
204 return Type::getInt8Ty(Ctx
);
205 case ParameterKind::I16
:
206 return Type::getInt16Ty(Ctx
);
207 case ParameterKind::I32
:
208 return Type::getInt32Ty(Ctx
);
209 case ParameterKind::I64
:
210 return Type::getInt64Ty(Ctx
);
211 case ParameterKind::Overload
:
213 case ParameterKind::ResourceRet
:
214 return getResRetType(OverloadTy
, Ctx
);
215 case ParameterKind::DXILHandle
:
216 return getHandleType(Ctx
);
220 llvm_unreachable("Invalid parameter kind");
224 /// Construct DXIL function type. This is the type of a function with
225 /// the following prototype
226 /// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
227 /// <param-types> are constructed from types in Prop.
228 /// \param Prop Structure containing DXIL Operation properties based on
229 /// its specification in DXIL.td.
230 /// \param OverloadTy Return type to be used to construct DXIL function type.
231 static FunctionType
*getDXILOpFunctionType(const OpCodeProperty
*Prop
,
232 Type
*ReturnTy
, Type
*OverloadTy
) {
233 SmallVector
<Type
*> ArgTys
;
235 auto ParamKinds
= getOpCodeParameterKind(*Prop
);
237 // Add ReturnTy as return type of the function
238 ArgTys
.emplace_back(ReturnTy
);
240 // Add DXIL Opcode value type viz., Int32 as first argument
241 ArgTys
.emplace_back(Type::getInt32Ty(OverloadTy
->getContext()));
243 // Add DXIL Operation parameter types as specified in DXIL properties
244 for (unsigned I
= 0; I
< Prop
->NumOfParameters
; ++I
) {
245 ParameterKind Kind
= ParamKinds
[I
];
246 ArgTys
.emplace_back(getTypeFromParameterKind(Kind
, OverloadTy
));
248 return FunctionType::get(
249 ArgTys
[0], ArrayRef
<Type
*>(&ArgTys
[1], ArgTys
.size() - 1), false);
/// Emit a call to the DXIL op function for \p OpCode with arguments \p Args.
/// Steps visible here: look up the op's properties, derive the OverloadKind
/// of OverloadTy, and call report_fatal_error("Invalid Overload Type") when
/// that kind is not set in Prop->OverloadTys. The function name comes from
/// constructOverloadName. A function of that name already in M is reused;
/// the code after that builds the type with getDXILOpFunctionType and calls
/// M.getOrInsertFunction. The call itself is created with B.CreateCall.
/// NOTE(review): the parameter line declaring OverloadTy and the braces /
/// else that join the two lookup paths are not visible in this listing.
255 CallInst
*DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode
, Type
*ReturnTy
,
257 SmallVector
<Value
*> Args
) {
258 const OpCodeProperty
*Prop
= getOpCodeProperty(OpCode
);
260 OverloadKind Kind
= getOverloadKind(OverloadTy
);
261 if ((Prop
->OverloadTys
& (uint16_t)Kind
) == 0) {
262 report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
265 std::string DXILFnName
= constructOverloadName(Kind
, OverloadTy
, *Prop
);
266 FunctionCallee DXILFn
;
267 // Get the function with name DXILFnName, if one exists
268 if (auto *Func
= M
.getFunction(DXILFnName
)) {
269 DXILFn
= FunctionCallee(Func
);
271 // Construct and add a function with name DXILFnName
272 FunctionType
*DXILOpFT
= getDXILOpFunctionType(Prop
, ReturnTy
, OverloadTy
);
273 DXILFn
= M
.getOrInsertFunction(DXILFnName
, DXILOpFT
);
276 return B
.CreateCall(DXILFn
, Args
);
/// Determine the overload type of DXIL op \p OpCode from function type \p FT.
/// When Prop->OverloadParamIndex is negative, the type is chosen by switching
/// on Prop->OverloadTys against single OverloadKind values; values that match
/// no case reach llvm_unreachable("invalid overload type").
/// Otherwise the candidate is FT's return type for index 0, or
/// FT->getParamType(index - 1) for any other index. For CBufferRet and
/// ResourceRet parameter kinds the candidate is cast to StructType and
/// replaced by its element 0.
/// NOTE(review): the final return statement is not visible in this listing.
279 Type
*DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode
, FunctionType
*FT
) {
281 const OpCodeProperty
*Prop
= getOpCodeProperty(OpCode
);
282 // If DXIL Op has no overload parameter, just return the
283 // precise return type specified.
284 if (Prop
->OverloadParamIndex
< 0) {
285 auto &Ctx
= FT
->getContext();
286 switch (Prop
->OverloadTys
) {
287 case OverloadKind::VOID
:
288 return Type::getVoidTy(Ctx
);
289 case OverloadKind::HALF
:
290 return Type::getHalfTy(Ctx
);
291 case OverloadKind::FLOAT
:
292 return Type::getFloatTy(Ctx
);
293 case OverloadKind::DOUBLE
:
294 return Type::getDoubleTy(Ctx
);
295 case OverloadKind::I1
:
296 return Type::getInt1Ty(Ctx
);
297 case OverloadKind::I8
:
298 return Type::getInt8Ty(Ctx
);
299 case OverloadKind::I16
:
300 return Type::getInt16Ty(Ctx
);
301 case OverloadKind::I32
:
302 return Type::getInt32Ty(Ctx
);
303 case OverloadKind::I64
:
304 return Type::getInt64Ty(Ctx
);
306 llvm_unreachable("invalid overload type");
311 // When Prop->OverloadParamIndex is 0, the overload type is FT->getReturnType().
312 Type
*OverloadType
= FT
->getReturnType();
313 if (Prop
->OverloadParamIndex
!= 0) {
315 OverloadType
= FT
->getParamType(Prop
->OverloadParamIndex
- 1);
318 auto ParamKinds
= getOpCodeParameterKind(*Prop
);
319 auto Kind
= ParamKinds
[Prop
->OverloadParamIndex
];
320 // For ResRet and CBufferRet, the overload type is element 0 of the StructType.
321 if (Kind
== ParameterKind::CBufferRet
||
322 Kind
== ParameterKind::ResourceRet
) {
323 auto *ST
= cast
<StructType
>(OverloadType
);
324 OverloadType
= ST
->getElementType(0);
/// Return the name of DXIL op \p DXILOp by forwarding to the file-scope
/// getOpCodeName function (the `::` qualifier avoids calling this member
/// function recursively).
329 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp
) {
330 return ::getOpCodeName(DXILOp
);