1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction and map it to a virtual register. It also builds
13 // constants and global variables and keeps them consistent.
15 //===----------------------------------------------------------------------===//
17 #include "SPIRVGlobalRegistry.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/Support/Casting.h"
// Construct an empty global registry.
// PointerSize is stored unchanged; elsewhere in this registry it is passed as
// the size argument of LLT::pointer(AddressSpace, PointerSize), i.e. it is the
// pointer width in bits. Bound is initialized to 0.
30 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize
)
31 : PointerSize(PointerSize
), Bound(0) {}
33 SPIRVType
*SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth
,
36 const SPIRVInstrInfo
&TII
) {
37 SPIRVType
*SpirvType
= getOrCreateSPIRVIntegerType(BitWidth
, I
, TII
);
38 assignSPIRVTypeToVReg(SpirvType
, VReg
, *CurMF
);
43 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth
, Register VReg
,
45 const SPIRVInstrInfo
&TII
) {
46 SPIRVType
*SpirvType
= getOrCreateSPIRVFloatType(BitWidth
, I
, TII
);
47 assignSPIRVTypeToVReg(SpirvType
, VReg
, *CurMF
);
51 SPIRVType
*SPIRVGlobalRegistry::assignVectTypeToVReg(
52 SPIRVType
*BaseType
, unsigned NumElements
, Register VReg
, MachineInstr
&I
,
53 const SPIRVInstrInfo
&TII
) {
54 SPIRVType
*SpirvType
=
55 getOrCreateSPIRVVectorType(BaseType
, NumElements
, I
, TII
);
56 assignSPIRVTypeToVReg(SpirvType
, VReg
, *CurMF
);
60 SPIRVType
*SPIRVGlobalRegistry::assignTypeToVReg(
61 const Type
*Type
, Register VReg
, MachineIRBuilder
&MIRBuilder
,
62 SPIRV::AccessQualifier::AccessQualifier AccessQual
, bool EmitIR
) {
63 SPIRVType
*SpirvType
=
64 getOrCreateSPIRVType(Type
, MIRBuilder
, AccessQual
, EmitIR
);
65 assignSPIRVTypeToVReg(SpirvType
, VReg
, MIRBuilder
.getMF());
69 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType
*SpirvType
,
71 MachineFunction
&MF
) {
72 VRegToTypeMap
[&MF
][VReg
] = SpirvType
;
75 static Register
createTypeVReg(MachineIRBuilder
&MIRBuilder
) {
76 auto &MRI
= MIRBuilder
.getMF().getRegInfo();
77 auto Res
= MRI
.createGenericVirtualRegister(LLT::scalar(32));
78 MRI
.setRegClass(Res
, &SPIRV::TYPERegClass
);
82 static Register
createTypeVReg(MachineRegisterInfo
&MRI
) {
83 auto Res
= MRI
.createGenericVirtualRegister(LLT::scalar(32));
84 MRI
.setRegClass(Res
, &SPIRV::TYPERegClass
);
88 SPIRVType
*SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder
&MIRBuilder
) {
89 return MIRBuilder
.buildInstr(SPIRV::OpTypeBool
)
90 .addDef(createTypeVReg(MIRBuilder
));
93 SPIRVType
*SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width
,
94 MachineIRBuilder
&MIRBuilder
,
96 assert(Width
<= 64 && "Unsupported integer width!");
97 const SPIRVSubtarget
&ST
=
98 cast
<SPIRVSubtarget
>(MIRBuilder
.getMF().getSubtarget());
99 if (ST
.canUseExtension(
100 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers
)) {
101 MIRBuilder
.buildInstr(SPIRV::OpExtension
)
102 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers
);
103 MIRBuilder
.buildInstr(SPIRV::OpCapability
)
104 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL
);
105 } else if (Width
<= 8)
107 else if (Width
<= 16)
109 else if (Width
<= 32)
111 else if (Width
<= 64)
114 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeInt
)
115 .addDef(createTypeVReg(MIRBuilder
))
117 .addImm(IsSigned
? 1 : 0);
121 SPIRVType
*SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width
,
122 MachineIRBuilder
&MIRBuilder
) {
123 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeFloat
)
124 .addDef(createTypeVReg(MIRBuilder
))
129 SPIRVType
*SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder
&MIRBuilder
) {
130 return MIRBuilder
.buildInstr(SPIRV::OpTypeVoid
)
131 .addDef(createTypeVReg(MIRBuilder
));
134 SPIRVType
*SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems
,
136 MachineIRBuilder
&MIRBuilder
) {
137 auto EleOpc
= ElemType
->getOpcode();
139 assert((EleOpc
== SPIRV::OpTypeInt
|| EleOpc
== SPIRV::OpTypeFloat
||
140 EleOpc
== SPIRV::OpTypeBool
) &&
141 "Invalid vector element type");
143 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeVector
)
144 .addDef(createTypeVReg(MIRBuilder
))
145 .addUse(getSPIRVTypeID(ElemType
))
150 std::tuple
<Register
, ConstantInt
*, bool>
151 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val
, SPIRVType
*SpvType
,
152 MachineIRBuilder
*MIRBuilder
,
154 const SPIRVInstrInfo
*TII
) {
155 const IntegerType
*LLVMIntTy
;
157 LLVMIntTy
= cast
<IntegerType
>(getTypeForSPIRVType(SpvType
));
159 LLVMIntTy
= IntegerType::getInt32Ty(CurMF
->getFunction().getContext());
160 bool NewInstr
= false;
161 // Find a constant in DT or build a new one.
162 ConstantInt
*CI
= ConstantInt::get(const_cast<IntegerType
*>(LLVMIntTy
), Val
);
163 Register Res
= DT
.find(CI
, CurMF
);
164 if (!Res
.isValid()) {
165 unsigned BitWidth
= SpvType
? getScalarOrVectorBitWidth(SpvType
) : 32;
166 // TODO: handle cases where the type is not 32bit wide
167 // TODO: https://github.com/llvm/llvm-project/issues/88129
168 LLT LLTy
= LLT::scalar(32);
169 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
170 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
172 assignTypeToVReg(LLVMIntTy
, Res
, *MIRBuilder
);
174 assignIntTypeToVReg(BitWidth
, Res
, *I
, *TII
);
175 DT
.add(CI
, CurMF
, Res
);
178 return std::make_tuple(Res
, CI
, NewInstr
);
181 std::tuple
<Register
, ConstantFP
*, bool, unsigned>
182 SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val
, SPIRVType
*SpvType
,
183 MachineIRBuilder
*MIRBuilder
,
185 const SPIRVInstrInfo
*TII
) {
186 const Type
*LLVMFloatTy
;
187 LLVMContext
&Ctx
= CurMF
->getFunction().getContext();
188 unsigned BitWidth
= 32;
190 LLVMFloatTy
= getTypeForSPIRVType(SpvType
);
192 LLVMFloatTy
= Type::getFloatTy(Ctx
);
194 SpvType
= getOrCreateSPIRVType(LLVMFloatTy
, *MIRBuilder
);
196 bool NewInstr
= false;
197 // Find a constant in DT or build a new one.
198 auto *const CI
= ConstantFP::get(Ctx
, Val
);
199 Register Res
= DT
.find(CI
, CurMF
);
200 if (!Res
.isValid()) {
202 BitWidth
= getScalarOrVectorBitWidth(SpvType
);
203 // TODO: handle cases where the type is not 32bit wide
204 // TODO: https://github.com/llvm/llvm-project/issues/88129
205 LLT LLTy
= LLT::scalar(32);
206 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
207 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
209 assignTypeToVReg(LLVMFloatTy
, Res
, *MIRBuilder
);
211 assignFloatTypeToVReg(BitWidth
, Res
, *I
, *TII
);
212 DT
.add(CI
, CurMF
, Res
);
215 return std::make_tuple(Res
, CI
, NewInstr
, BitWidth
);
218 Register
SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val
, MachineInstr
&I
,
220 const SPIRVInstrInfo
&TII
,
227 std::tie(Res
, CI
, New
, BitWidth
) =
228 getOrCreateConstFloatReg(Val
, SpvType
, nullptr, &I
, &TII
);
229 // If we have found Res register which is defined by the passed G_CONSTANT
230 // machine instruction, a new constant instruction should be created.
231 if (!New
&& (!I
.getOperand(0).isReg() || Res
!= I
.getOperand(0).getReg()))
233 MachineInstrBuilder MIB
;
234 MachineBasicBlock
&BB
= *I
.getParent();
235 // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
236 if (Val
.isPosZero() && ZeroAsNull
) {
237 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantNull
))
239 .addUse(getSPIRVTypeID(SpvType
));
241 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantF
))
243 .addUse(getSPIRVTypeID(SpvType
));
245 APInt(BitWidth
, CI
->getValueAPF().bitcastToAPInt().getZExtValue()),
248 const auto &ST
= CurMF
->getSubtarget();
249 constrainSelectedInstRegOperands(*MIB
, *ST
.getInstrInfo(),
250 *ST
.getRegisterInfo(), *ST
.getRegBankInfo());
254 Register
SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val
, MachineInstr
&I
,
256 const SPIRVInstrInfo
&TII
,
262 std::tie(Res
, CI
, New
) =
263 getOrCreateConstIntReg(Val
, SpvType
, nullptr, &I
, &TII
);
264 // If we have found Res register which is defined by the passed G_CONSTANT
265 // machine instruction, a new constant instruction should be created.
266 if (!New
&& (!I
.getOperand(0).isReg() || Res
!= I
.getOperand(0).getReg()))
268 MachineInstrBuilder MIB
;
269 MachineBasicBlock
&BB
= *I
.getParent();
270 if (Val
|| !ZeroAsNull
) {
271 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantI
))
273 .addUse(getSPIRVTypeID(SpvType
));
274 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType
), Val
), MIB
);
276 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantNull
))
278 .addUse(getSPIRVTypeID(SpvType
));
280 const auto &ST
= CurMF
->getSubtarget();
281 constrainSelectedInstRegOperands(*MIB
, *ST
.getInstrInfo(),
282 *ST
.getRegisterInfo(), *ST
.getRegBankInfo());
286 Register
SPIRVGlobalRegistry::buildConstantInt(uint64_t Val
,
287 MachineIRBuilder
&MIRBuilder
,
290 auto &MF
= MIRBuilder
.getMF();
291 const IntegerType
*LLVMIntTy
;
293 LLVMIntTy
= cast
<IntegerType
>(getTypeForSPIRVType(SpvType
));
295 LLVMIntTy
= IntegerType::getInt32Ty(MF
.getFunction().getContext());
296 // Find a constant in DT or build a new one.
297 const auto ConstInt
=
298 ConstantInt::get(const_cast<IntegerType
*>(LLVMIntTy
), Val
);
299 Register Res
= DT
.find(ConstInt
, &MF
);
300 if (!Res
.isValid()) {
301 unsigned BitWidth
= SpvType
? getScalarOrVectorBitWidth(SpvType
) : 32;
302 LLT LLTy
= LLT::scalar(EmitIR
? BitWidth
: 32);
303 Res
= MF
.getRegInfo().createGenericVirtualRegister(LLTy
);
304 MF
.getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
305 assignTypeToVReg(LLVMIntTy
, Res
, MIRBuilder
,
306 SPIRV::AccessQualifier::ReadWrite
, EmitIR
);
307 DT
.add(ConstInt
, &MIRBuilder
.getMF(), Res
);
309 MIRBuilder
.buildConstant(Res
, *ConstInt
);
311 MachineInstrBuilder MIB
;
314 MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantI
)
316 .addUse(getSPIRVTypeID(SpvType
));
317 addNumImm(APInt(BitWidth
, Val
), MIB
);
320 MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantNull
)
322 .addUse(getSPIRVTypeID(SpvType
));
324 const auto &Subtarget
= CurMF
->getSubtarget();
325 constrainSelectedInstRegOperands(*MIB
, *Subtarget
.getInstrInfo(),
326 *Subtarget
.getRegisterInfo(),
327 *Subtarget
.getRegBankInfo());
333 Register
SPIRVGlobalRegistry::buildConstantFP(APFloat Val
,
334 MachineIRBuilder
&MIRBuilder
,
335 SPIRVType
*SpvType
) {
336 auto &MF
= MIRBuilder
.getMF();
337 auto &Ctx
= MF
.getFunction().getContext();
339 const Type
*LLVMFPTy
= Type::getFloatTy(Ctx
);
340 SpvType
= getOrCreateSPIRVType(LLVMFPTy
, MIRBuilder
);
342 // Find a constant in DT or build a new one.
343 const auto ConstFP
= ConstantFP::get(Ctx
, Val
);
344 Register Res
= DT
.find(ConstFP
, &MF
);
345 if (!Res
.isValid()) {
346 Res
= MF
.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
347 MF
.getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
348 assignSPIRVTypeToVReg(SpvType
, Res
, MF
);
349 DT
.add(ConstFP
, &MF
, Res
);
351 MachineInstrBuilder MIB
;
352 MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantF
)
354 .addUse(getSPIRVTypeID(SpvType
));
355 addNumImm(ConstFP
->getValueAPF().bitcastToAPInt(), MIB
);
361 Register
SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant
*Val
,
364 const SPIRVInstrInfo
&TII
,
366 SPIRVType
*Type
= SpvType
;
367 if (SpvType
->getOpcode() == SPIRV::OpTypeVector
||
368 SpvType
->getOpcode() == SPIRV::OpTypeArray
) {
369 auto EleTypeReg
= SpvType
->getOperand(1).getReg();
370 Type
= getSPIRVTypeForVReg(EleTypeReg
);
372 if (Type
->getOpcode() == SPIRV::OpTypeFloat
) {
373 SPIRVType
*SpvBaseType
= getOrCreateSPIRVFloatType(BitWidth
, I
, TII
);
374 return getOrCreateConstFP(dyn_cast
<ConstantFP
>(Val
)->getValue(), I
,
377 assert(Type
->getOpcode() == SPIRV::OpTypeInt
);
378 SPIRVType
*SpvBaseType
= getOrCreateSPIRVIntegerType(BitWidth
, I
, TII
);
379 return getOrCreateConstInt(Val
->getUniqueInteger().getSExtValue(), I
,
383 Register
SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
384 Constant
*Val
, MachineInstr
&I
, SPIRVType
*SpvType
,
385 const SPIRVInstrInfo
&TII
, Constant
*CA
, unsigned BitWidth
,
386 unsigned ElemCnt
, bool ZeroAsNull
) {
387 // Find a constant vector in DT or build a new one.
388 Register Res
= DT
.find(CA
, CurMF
);
389 // If no values are attached, the composite is null constant.
390 bool IsNull
= Val
->isNullValue() && ZeroAsNull
;
391 if (!Res
.isValid()) {
392 // SpvScalConst should be created before SpvVecConst to avoid undefined ID
393 // error on validation.
394 // TODO: can moved below once sorting of types/consts/defs is implemented.
395 Register SpvScalConst
;
397 SpvScalConst
= getOrCreateBaseRegister(Val
, I
, SpvType
, TII
, BitWidth
);
399 // TODO: handle cases where the type is not 32bit wide
400 // TODO: https://github.com/llvm/llvm-project/issues/88129
401 LLT LLTy
= LLT::scalar(32);
402 Register SpvVecConst
=
403 CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
404 CurMF
->getRegInfo().setRegClass(SpvVecConst
, &SPIRV::IDRegClass
);
405 assignSPIRVTypeToVReg(SpvType
, SpvVecConst
, *CurMF
);
406 DT
.add(CA
, CurMF
, SpvVecConst
);
407 MachineInstrBuilder MIB
;
408 MachineBasicBlock
&BB
= *I
.getParent();
410 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantComposite
))
412 .addUse(getSPIRVTypeID(SpvType
));
413 for (unsigned i
= 0; i
< ElemCnt
; ++i
)
414 MIB
.addUse(SpvScalConst
);
416 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantNull
))
418 .addUse(getSPIRVTypeID(SpvType
));
420 const auto &Subtarget
= CurMF
->getSubtarget();
421 constrainSelectedInstRegOperands(*MIB
, *Subtarget
.getInstrInfo(),
422 *Subtarget
.getRegisterInfo(),
423 *Subtarget
.getRegBankInfo());
429 Register
SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val
,
432 const SPIRVInstrInfo
&TII
,
434 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
435 assert(LLVMTy
->isVectorTy());
436 const FixedVectorType
*LLVMVecTy
= cast
<FixedVectorType
>(LLVMTy
);
437 Type
*LLVMBaseTy
= LLVMVecTy
->getElementType();
438 assert(LLVMBaseTy
->isIntegerTy());
439 auto *ConstVal
= ConstantInt::get(LLVMBaseTy
, Val
);
441 ConstantVector::getSplat(LLVMVecTy
->getElementCount(), ConstVal
);
442 unsigned BW
= getScalarOrVectorBitWidth(SpvType
);
443 return getOrCreateCompositeOrNull(ConstVal
, I
, SpvType
, TII
, ConstVec
, BW
,
444 SpvType
->getOperand(2).getImm(),
448 Register
SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val
,
451 const SPIRVInstrInfo
&TII
,
453 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
454 assert(LLVMTy
->isVectorTy());
455 const FixedVectorType
*LLVMVecTy
= cast
<FixedVectorType
>(LLVMTy
);
456 Type
*LLVMBaseTy
= LLVMVecTy
->getElementType();
457 assert(LLVMBaseTy
->isFloatingPointTy());
458 auto *ConstVal
= ConstantFP::get(LLVMBaseTy
, Val
);
460 ConstantVector::getSplat(LLVMVecTy
->getElementCount(), ConstVal
);
461 unsigned BW
= getScalarOrVectorBitWidth(SpvType
);
462 return getOrCreateCompositeOrNull(ConstVal
, I
, SpvType
, TII
, ConstVec
, BW
,
463 SpvType
->getOperand(2).getImm(),
468 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val
, MachineInstr
&I
,
470 const SPIRVInstrInfo
&TII
) {
471 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
472 assert(LLVMTy
->isArrayTy());
473 const ArrayType
*LLVMArrTy
= cast
<ArrayType
>(LLVMTy
);
474 Type
*LLVMBaseTy
= LLVMArrTy
->getElementType();
475 auto *ConstInt
= ConstantInt::get(LLVMBaseTy
, Val
);
477 ConstantArray::get(const_cast<ArrayType
*>(LLVMArrTy
), {ConstInt
});
478 SPIRVType
*SpvBaseTy
= getSPIRVTypeForVReg(SpvType
->getOperand(1).getReg());
479 unsigned BW
= getScalarOrVectorBitWidth(SpvBaseTy
);
480 return getOrCreateCompositeOrNull(ConstInt
, I
, SpvType
, TII
, ConstArr
, BW
,
481 LLVMArrTy
->getNumElements());
484 Register
SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
485 uint64_t Val
, MachineIRBuilder
&MIRBuilder
, SPIRVType
*SpvType
, bool EmitIR
,
486 Constant
*CA
, unsigned BitWidth
, unsigned ElemCnt
) {
487 Register Res
= DT
.find(CA
, CurMF
);
488 if (!Res
.isValid()) {
489 Register SpvScalConst
;
491 SPIRVType
*SpvBaseType
=
492 getOrCreateSPIRVIntegerType(BitWidth
, MIRBuilder
);
493 SpvScalConst
= buildConstantInt(Val
, MIRBuilder
, SpvBaseType
, EmitIR
);
495 LLT LLTy
= EmitIR
? LLT::fixed_vector(ElemCnt
, BitWidth
) : LLT::scalar(32);
496 Register SpvVecConst
=
497 CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
498 CurMF
->getRegInfo().setRegClass(SpvVecConst
, &SPIRV::IDRegClass
);
499 assignSPIRVTypeToVReg(SpvType
, SpvVecConst
, *CurMF
);
500 DT
.add(CA
, CurMF
, SpvVecConst
);
502 MIRBuilder
.buildSplatVector(SpvVecConst
, SpvScalConst
);
505 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantComposite
)
507 .addUse(getSPIRVTypeID(SpvType
));
508 for (unsigned i
= 0; i
< ElemCnt
; ++i
)
509 MIB
.addUse(SpvScalConst
);
511 MIRBuilder
.buildInstr(SPIRV::OpConstantNull
)
513 .addUse(getSPIRVTypeID(SpvType
));
522 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val
,
523 MachineIRBuilder
&MIRBuilder
,
524 SPIRVType
*SpvType
, bool EmitIR
) {
525 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
526 assert(LLVMTy
->isVectorTy());
527 const FixedVectorType
*LLVMVecTy
= cast
<FixedVectorType
>(LLVMTy
);
528 Type
*LLVMBaseTy
= LLVMVecTy
->getElementType();
529 const auto ConstInt
= ConstantInt::get(LLVMBaseTy
, Val
);
531 ConstantVector::getSplat(LLVMVecTy
->getElementCount(), ConstInt
);
532 unsigned BW
= getScalarOrVectorBitWidth(SpvType
);
533 return getOrCreateIntCompositeOrNull(Val
, MIRBuilder
, SpvType
, EmitIR
,
535 SpvType
->getOperand(2).getImm());
539 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val
,
540 MachineIRBuilder
&MIRBuilder
,
541 SPIRVType
*SpvType
, bool EmitIR
) {
542 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
543 assert(LLVMTy
->isArrayTy());
544 const ArrayType
*LLVMArrTy
= cast
<ArrayType
>(LLVMTy
);
545 Type
*LLVMBaseTy
= LLVMArrTy
->getElementType();
546 const auto ConstInt
= ConstantInt::get(LLVMBaseTy
, Val
);
548 ConstantArray::get(const_cast<ArrayType
*>(LLVMArrTy
), {ConstInt
});
549 SPIRVType
*SpvBaseTy
= getSPIRVTypeForVReg(SpvType
->getOperand(1).getReg());
550 unsigned BW
= getScalarOrVectorBitWidth(SpvBaseTy
);
551 return getOrCreateIntCompositeOrNull(Val
, MIRBuilder
, SpvType
, EmitIR
,
553 LLVMArrTy
->getNumElements());
557 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder
&MIRBuilder
,
558 SPIRVType
*SpvType
) {
559 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
560 const TypedPointerType
*LLVMPtrTy
= cast
<TypedPointerType
>(LLVMTy
);
561 // Find a constant in DT or build a new one.
562 Constant
*CP
= ConstantPointerNull::get(PointerType::get(
563 LLVMPtrTy
->getElementType(), LLVMPtrTy
->getAddressSpace()));
564 Register Res
= DT
.find(CP
, CurMF
);
565 if (!Res
.isValid()) {
566 LLT LLTy
= LLT::pointer(LLVMPtrTy
->getAddressSpace(), PointerSize
);
567 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
568 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
569 assignSPIRVTypeToVReg(SpvType
, Res
, *CurMF
);
570 MIRBuilder
.buildInstr(SPIRV::OpConstantNull
)
572 .addUse(getSPIRVTypeID(SpvType
));
573 DT
.add(CP
, CurMF
, Res
);
578 Register
SPIRVGlobalRegistry::buildConstantSampler(
579 Register ResReg
, unsigned AddrMode
, unsigned Param
, unsigned FilerMode
,
580 MachineIRBuilder
&MIRBuilder
, SPIRVType
*SpvType
) {
583 SampTy
= getOrCreateSPIRVType(getTypeForSPIRVType(SpvType
), MIRBuilder
);
584 else if ((SampTy
= getOrCreateSPIRVTypeByName("opencl.sampler_t",
585 MIRBuilder
)) == nullptr)
586 report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
591 : MIRBuilder
.getMRI()->createVirtualRegister(&SPIRV::IDRegClass
);
592 auto Res
= MIRBuilder
.buildInstr(SPIRV::OpConstantSampler
)
594 .addUse(getSPIRVTypeID(SampTy
))
598 assert(Res
->getOperand(0).isReg());
599 return Res
->getOperand(0).getReg();
602 Register
SPIRVGlobalRegistry::buildGlobalVariable(
603 Register ResVReg
, SPIRVType
*BaseType
, StringRef Name
,
604 const GlobalValue
*GV
, SPIRV::StorageClass::StorageClass Storage
,
605 const MachineInstr
*Init
, bool IsConst
, bool HasLinkageTy
,
606 SPIRV::LinkageType::LinkageType LinkageType
, MachineIRBuilder
&MIRBuilder
,
607 bool IsInstSelector
) {
608 const GlobalVariable
*GVar
= nullptr;
610 GVar
= cast
<const GlobalVariable
>(GV
);
612 // If GV is not passed explicitly, use the name to find or construct
613 // the global variable.
614 Module
*M
= MIRBuilder
.getMF().getFunction().getParent();
615 GVar
= M
->getGlobalVariable(Name
);
616 if (GVar
== nullptr) {
617 const Type
*Ty
= getTypeForSPIRVType(BaseType
); // TODO: check type.
618 // Module takes ownership of the global var.
619 GVar
= new GlobalVariable(*M
, const_cast<Type
*>(Ty
), false,
620 GlobalValue::ExternalLinkage
, nullptr,
625 Register Reg
= DT
.find(GVar
, &MIRBuilder
.getMF());
628 MIRBuilder
.buildCopy(ResVReg
, Reg
);
632 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpVariable
)
634 .addUse(getSPIRVTypeID(BaseType
))
635 .addImm(static_cast<uint32_t>(Storage
));
638 MIB
.addUse(Init
->getOperand(0).getReg());
641 // ISel may introduce a new register on this step, so we need to add it to
642 // DT and correct its type avoiding fails on the next stage.
643 if (IsInstSelector
) {
644 const auto &Subtarget
= CurMF
->getSubtarget();
645 constrainSelectedInstRegOperands(*MIB
, *Subtarget
.getInstrInfo(),
646 *Subtarget
.getRegisterInfo(),
647 *Subtarget
.getRegBankInfo());
649 Reg
= MIB
->getOperand(0).getReg();
650 DT
.add(GVar
, &MIRBuilder
.getMF(), Reg
);
652 // Set to Reg the same type as ResVReg has.
653 auto MRI
= MIRBuilder
.getMRI();
654 assert(MRI
->getType(ResVReg
).isPointer() && "Pointer type is expected");
655 if (Reg
!= ResVReg
) {
657 LLT::pointer(MRI
->getType(ResVReg
).getAddressSpace(), getPointerSize());
658 MRI
->setType(Reg
, RegLLTy
);
659 assignSPIRVTypeToVReg(BaseType
, Reg
, MIRBuilder
.getMF());
661 // Our knowledge about the type may be updated.
662 // If that's the case, we need to update a type
663 // associated with the register.
664 SPIRVType
*DefType
= getSPIRVTypeForVReg(ResVReg
);
665 if (!DefType
|| DefType
!= BaseType
)
666 assignSPIRVTypeToVReg(BaseType
, Reg
, MIRBuilder
.getMF());
669 // If it's a global variable with name, output OpName for it.
670 if (GVar
&& GVar
->hasName())
671 buildOpName(Reg
, GVar
->getName(), MIRBuilder
);
673 // Output decorations for the GV.
674 // TODO: maybe move to GenerateDecorations pass.
675 const SPIRVSubtarget
&ST
=
676 cast
<SPIRVSubtarget
>(MIRBuilder
.getMF().getSubtarget());
677 if (IsConst
&& ST
.isOpenCLEnv())
678 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::Constant
, {});
680 if (GVar
&& GVar
->getAlign().valueOrOne().value() != 1) {
681 unsigned Alignment
= (unsigned)GVar
->getAlign().valueOrOne().value();
682 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::Alignment
, {Alignment
});
686 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::LinkageAttributes
,
687 {static_cast<uint32_t>(LinkageType
)}, Name
);
689 SPIRV::BuiltIn::BuiltIn BuiltInId
;
690 if (getSpirvBuiltInIdByName(Name
, BuiltInId
))
691 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::BuiltIn
,
692 {static_cast<uint32_t>(BuiltInId
)});
697 SPIRVType
*SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems
,
699 MachineIRBuilder
&MIRBuilder
,
701 assert((ElemType
->getOpcode() != SPIRV::OpTypeVoid
) &&
702 "Invalid array element type");
703 Register NumElementsVReg
=
704 buildConstantInt(NumElems
, MIRBuilder
, nullptr, EmitIR
);
705 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeArray
)
706 .addDef(createTypeVReg(MIRBuilder
))
707 .addUse(getSPIRVTypeID(ElemType
))
708 .addUse(NumElementsVReg
);
712 SPIRVType
*SPIRVGlobalRegistry::getOpTypeOpaque(const StructType
*Ty
,
713 MachineIRBuilder
&MIRBuilder
) {
714 assert(Ty
->hasName());
715 const StringRef Name
= Ty
->hasName() ? Ty
->getName() : "";
716 Register ResVReg
= createTypeVReg(MIRBuilder
);
717 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeOpaque
).addDef(ResVReg
);
718 addStringImm(Name
, MIB
);
719 buildOpName(ResVReg
, Name
, MIRBuilder
);
723 SPIRVType
*SPIRVGlobalRegistry::getOpTypeStruct(const StructType
*Ty
,
724 MachineIRBuilder
&MIRBuilder
,
726 SmallVector
<Register
, 4> FieldTypes
;
727 for (const auto &Elem
: Ty
->elements()) {
729 findSPIRVType(toTypedPointer(Elem
, Ty
->getContext()), MIRBuilder
);
730 assert(ElemTy
&& ElemTy
->getOpcode() != SPIRV::OpTypeVoid
&&
731 "Invalid struct element type");
732 FieldTypes
.push_back(getSPIRVTypeID(ElemTy
));
734 Register ResVReg
= createTypeVReg(MIRBuilder
);
735 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeStruct
).addDef(ResVReg
);
736 for (const auto &Ty
: FieldTypes
)
739 buildOpName(ResVReg
, Ty
->getName(), MIRBuilder
);
741 buildOpDecorate(ResVReg
, MIRBuilder
, SPIRV::Decoration::CPacked
, {});
745 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSpecialType(
746 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
747 SPIRV::AccessQualifier::AccessQualifier AccQual
) {
748 assert(isSpecialOpaqueType(Ty
) && "Not a special opaque builtin type");
749 return SPIRV::lowerBuiltinType(Ty
, AccQual
, MIRBuilder
, this);
752 SPIRVType
*SPIRVGlobalRegistry::getOpTypePointer(
753 SPIRV::StorageClass::StorageClass SC
, SPIRVType
*ElemType
,
754 MachineIRBuilder
&MIRBuilder
, Register Reg
) {
756 Reg
= createTypeVReg(MIRBuilder
);
757 return MIRBuilder
.buildInstr(SPIRV::OpTypePointer
)
759 .addImm(static_cast<uint32_t>(SC
))
760 .addUse(getSPIRVTypeID(ElemType
));
763 SPIRVType
*SPIRVGlobalRegistry::getOpTypeForwardPointer(
764 SPIRV::StorageClass::StorageClass SC
, MachineIRBuilder
&MIRBuilder
) {
765 return MIRBuilder
.buildInstr(SPIRV::OpTypeForwardPointer
)
766 .addUse(createTypeVReg(MIRBuilder
))
767 .addImm(static_cast<uint32_t>(SC
));
770 SPIRVType
*SPIRVGlobalRegistry::getOpTypeFunction(
771 SPIRVType
*RetType
, const SmallVectorImpl
<SPIRVType
*> &ArgTypes
,
772 MachineIRBuilder
&MIRBuilder
) {
773 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeFunction
)
774 .addDef(createTypeVReg(MIRBuilder
))
775 .addUse(getSPIRVTypeID(RetType
));
776 for (const SPIRVType
*ArgType
: ArgTypes
)
777 MIB
.addUse(getSPIRVTypeID(ArgType
));
781 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
782 const Type
*Ty
, SPIRVType
*RetType
,
783 const SmallVectorImpl
<SPIRVType
*> &ArgTypes
,
784 MachineIRBuilder
&MIRBuilder
) {
785 Register Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
787 return getSPIRVTypeForVReg(Reg
);
788 SPIRVType
*SpirvType
= getOpTypeFunction(RetType
, ArgTypes
, MIRBuilder
);
789 DT
.add(Ty
, CurMF
, getSPIRVTypeID(SpirvType
));
790 return finishCreatingSPIRVType(Ty
, SpirvType
);
793 SPIRVType
*SPIRVGlobalRegistry::findSPIRVType(
794 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
795 SPIRV::AccessQualifier::AccessQualifier AccQual
, bool EmitIR
) {
796 Register Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
798 return getSPIRVTypeForVReg(Reg
);
799 if (ForwardPointerTypes
.contains(Ty
))
800 return ForwardPointerTypes
[Ty
];
801 return restOfCreateSPIRVType(Ty
, MIRBuilder
, AccQual
, EmitIR
);
804 Register
SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType
*SpirvType
) const {
805 assert(SpirvType
&& "Attempting to get type id for nullptr type.");
806 if (SpirvType
->getOpcode() == SPIRV::OpTypeForwardPointer
)
807 return SpirvType
->uses().begin()->getReg();
808 return SpirvType
->defs().begin()->getReg();
811 SPIRVType
*SPIRVGlobalRegistry::createSPIRVType(
812 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
813 SPIRV::AccessQualifier::AccessQualifier AccQual
, bool EmitIR
) {
814 if (isSpecialOpaqueType(Ty
))
815 return getOrCreateSpecialType(Ty
, MIRBuilder
, AccQual
);
816 auto &TypeToSPIRVTypeMap
= DT
.getTypes()->getAllUses();
817 auto t
= TypeToSPIRVTypeMap
.find(Ty
);
818 if (t
!= TypeToSPIRVTypeMap
.end()) {
819 auto tt
= t
->second
.find(&MIRBuilder
.getMF());
820 if (tt
!= t
->second
.end())
821 return getSPIRVTypeForVReg(tt
->second
);
824 if (auto IType
= dyn_cast
<IntegerType
>(Ty
)) {
825 const unsigned Width
= IType
->getBitWidth();
826 return Width
== 1 ? getOpTypeBool(MIRBuilder
)
827 : getOpTypeInt(Width
, MIRBuilder
, false);
829 if (Ty
->isFloatingPointTy())
830 return getOpTypeFloat(Ty
->getPrimitiveSizeInBits(), MIRBuilder
);
832 return getOpTypeVoid(MIRBuilder
);
833 if (Ty
->isVectorTy()) {
835 findSPIRVType(cast
<FixedVectorType
>(Ty
)->getElementType(), MIRBuilder
);
836 return getOpTypeVector(cast
<FixedVectorType
>(Ty
)->getNumElements(), El
,
839 if (Ty
->isArrayTy()) {
840 SPIRVType
*El
= findSPIRVType(Ty
->getArrayElementType(), MIRBuilder
);
841 return getOpTypeArray(Ty
->getArrayNumElements(), El
, MIRBuilder
, EmitIR
);
843 if (auto SType
= dyn_cast
<StructType
>(Ty
)) {
844 if (SType
->isOpaque())
845 return getOpTypeOpaque(SType
, MIRBuilder
);
846 return getOpTypeStruct(SType
, MIRBuilder
, EmitIR
);
848 if (auto FType
= dyn_cast
<FunctionType
>(Ty
)) {
849 SPIRVType
*RetTy
= findSPIRVType(FType
->getReturnType(), MIRBuilder
);
850 SmallVector
<SPIRVType
*, 4> ParamTypes
;
851 for (const auto &t
: FType
->params()) {
852 ParamTypes
.push_back(findSPIRVType(t
, MIRBuilder
));
854 return getOpTypeFunction(RetTy
, ParamTypes
, MIRBuilder
);
856 unsigned AddrSpace
= 0xFFFF;
857 if (auto PType
= dyn_cast
<TypedPointerType
>(Ty
))
858 AddrSpace
= PType
->getAddressSpace();
859 else if (auto PType
= dyn_cast
<PointerType
>(Ty
))
860 AddrSpace
= PType
->getAddressSpace();
862 report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
864 SPIRVType
*SpvElementType
= nullptr;
865 if (auto PType
= dyn_cast
<TypedPointerType
>(Ty
))
866 SpvElementType
= getOrCreateSPIRVType(PType
->getElementType(), MIRBuilder
,
869 SpvElementType
= getOrCreateSPIRVIntegerType(8, MIRBuilder
);
871 // Get access to information about available extensions
872 const SPIRVSubtarget
*ST
=
873 static_cast<const SPIRVSubtarget
*>(&MIRBuilder
.getMF().getSubtarget());
874 auto SC
= addressSpaceToStorageClass(AddrSpace
, *ST
);
875 // Null pointer means we have a loop in type definitions, make and
876 // return corresponding OpTypeForwardPointer.
877 if (SpvElementType
== nullptr) {
878 if (!ForwardPointerTypes
.contains(Ty
))
879 ForwardPointerTypes
[Ty
] = getOpTypeForwardPointer(SC
, MIRBuilder
);
880 return ForwardPointerTypes
[Ty
];
882 // If we have forward pointer associated with this type, use its register
883 // operand to create OpTypePointer.
884 if (ForwardPointerTypes
.contains(Ty
)) {
885 Register Reg
= getSPIRVTypeID(ForwardPointerTypes
[Ty
]);
886 return getOpTypePointer(SC
, SpvElementType
, MIRBuilder
, Reg
);
889 return getOrCreateSPIRVPointerType(SpvElementType
, MIRBuilder
, SC
);
892 SPIRVType
*SPIRVGlobalRegistry::restOfCreateSPIRVType(
893 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
894 SPIRV::AccessQualifier::AccessQualifier AccessQual
, bool EmitIR
) {
895 if (TypesInProcessing
.count(Ty
) && !isPointerTy(Ty
))
897 TypesInProcessing
.insert(Ty
);
898 SPIRVType
*SpirvType
= createSPIRVType(Ty
, MIRBuilder
, AccessQual
, EmitIR
);
899 TypesInProcessing
.erase(Ty
);
900 VRegToTypeMap
[&MIRBuilder
.getMF()][getSPIRVTypeID(SpirvType
)] = SpirvType
;
901 SPIRVToLLVMType
[SpirvType
] = Ty
;
902 Register Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
903 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
904 // will be added later. For special types it is already added to DT.
905 if (SpirvType
->getOpcode() != SPIRV::OpTypeForwardPointer
&& !Reg
.isValid() &&
906 !isSpecialOpaqueType(Ty
)) {
907 if (!isPointerTy(Ty
))
908 DT
.add(Ty
, &MIRBuilder
.getMF(), getSPIRVTypeID(SpirvType
));
909 else if (isTypedPointerTy(Ty
))
910 DT
.add(cast
<TypedPointerType
>(Ty
)->getElementType(),
911 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF(),
912 getSPIRVTypeID(SpirvType
));
914 DT
.add(Type::getInt8Ty(MIRBuilder
.getMF().getFunction().getContext()),
915 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF(),
916 getSPIRVTypeID(SpirvType
));
923 SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg
,
924 const MachineFunction
*MF
) const {
925 auto t
= VRegToTypeMap
.find(MF
? MF
: CurMF
);
926 if (t
!= VRegToTypeMap
.end()) {
927 auto tt
= t
->second
.find(VReg
);
928 if (tt
!= t
->second
.end())
934 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVType(
935 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
936 SPIRV::AccessQualifier::AccessQualifier AccessQual
, bool EmitIR
) {
938 if (!isPointerTy(Ty
))
939 Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
940 else if (isTypedPointerTy(Ty
))
941 Reg
= DT
.find(cast
<TypedPointerType
>(Ty
)->getElementType(),
942 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF());
945 DT
.find(Type::getInt8Ty(MIRBuilder
.getMF().getFunction().getContext()),
946 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF());
948 if (Reg
.isValid() && !isSpecialOpaqueType(Ty
))
949 return getSPIRVTypeForVReg(Reg
);
950 TypesInProcessing
.clear();
951 SPIRVType
*STy
= restOfCreateSPIRVType(Ty
, MIRBuilder
, AccessQual
, EmitIR
);
952 // Create normal pointer types for the corresponding OpTypeForwardPointers.
953 for (auto &CU
: ForwardPointerTypes
) {
954 const Type
*Ty2
= CU
.first
;
955 SPIRVType
*STy2
= CU
.second
;
956 if ((Reg
= DT
.find(Ty2
, &MIRBuilder
.getMF())).isValid())
957 STy2
= getSPIRVTypeForVReg(Reg
);
959 STy2
= restOfCreateSPIRVType(Ty2
, MIRBuilder
, AccessQual
, EmitIR
);
963 ForwardPointerTypes
.clear();
967 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg
,
968 unsigned TypeOpcode
) const {
969 SPIRVType
*Type
= getSPIRVTypeForVReg(VReg
);
970 assert(Type
&& "isScalarOfType VReg has no type assigned");
971 return Type
->getOpcode() == TypeOpcode
;
974 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg
,
975 unsigned TypeOpcode
) const {
976 SPIRVType
*Type
= getSPIRVTypeForVReg(VReg
);
977 assert(Type
&& "isScalarOrVectorOfType VReg has no type assigned");
978 if (Type
->getOpcode() == TypeOpcode
)
980 if (Type
->getOpcode() == SPIRV::OpTypeVector
) {
981 Register ScalarTypeVReg
= Type
->getOperand(1).getReg();
982 SPIRVType
*ScalarType
= getSPIRVTypeForVReg(ScalarTypeVReg
);
983 return ScalarType
->getOpcode() == TypeOpcode
;
989 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg
) const {
990 return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg
));
994 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType
*Type
) const {
997 return Type
->getOpcode() == SPIRV::OpTypeVector
998 ? static_cast<unsigned>(Type
->getOperand(2).getImm())
1003 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType
*Type
) const {
1004 assert(Type
&& "Invalid Type pointer");
1005 if (Type
->getOpcode() == SPIRV::OpTypeVector
) {
1006 auto EleTypeReg
= Type
->getOperand(1).getReg();
1007 Type
= getSPIRVTypeForVReg(EleTypeReg
);
1009 if (Type
->getOpcode() == SPIRV::OpTypeInt
||
1010 Type
->getOpcode() == SPIRV::OpTypeFloat
)
1011 return Type
->getOperand(1).getImm();
1012 if (Type
->getOpcode() == SPIRV::OpTypeBool
)
1014 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1017 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1018 const SPIRVType
*Type
) const {
1019 assert(Type
&& "Invalid Type pointer");
1020 unsigned NumElements
= 1;
1021 if (Type
->getOpcode() == SPIRV::OpTypeVector
) {
1022 NumElements
= static_cast<unsigned>(Type
->getOperand(2).getImm());
1023 Type
= getSPIRVTypeForVReg(Type
->getOperand(1).getReg());
1025 return Type
->getOpcode() == SPIRV::OpTypeInt
||
1026 Type
->getOpcode() == SPIRV::OpTypeFloat
1027 ? NumElements
* Type
->getOperand(1).getImm()
1031 const SPIRVType
*SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
1032 const SPIRVType
*Type
) const {
1033 if (Type
&& Type
->getOpcode() == SPIRV::OpTypeVector
)
1034 Type
= getSPIRVTypeForVReg(Type
->getOperand(1).getReg());
1035 return Type
&& Type
->getOpcode() == SPIRV::OpTypeInt
? Type
: nullptr;
1038 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType
*Type
) const {
1039 const SPIRVType
*IntType
= retrieveScalarOrVectorIntType(Type
);
1040 return IntType
&& IntType
->getOperand(2).getImm() != 0;
1043 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg
) {
1044 SPIRVType
*PtrType
= getSPIRVTypeForVReg(PtrReg
);
1045 SPIRVType
*ElemType
=
1046 PtrType
&& PtrType
->getOpcode() == SPIRV::OpTypePointer
1047 ? getSPIRVTypeForVReg(PtrType
->getOperand(2).getReg())
1049 return ElemType
? ElemType
->getOpcode() : 0;
1052 bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType
*Type1
,
1053 const SPIRVType
*Type2
) const {
1054 if (!Type1
|| !Type2
)
1056 auto Op1
= Type1
->getOpcode(), Op2
= Type2
->getOpcode();
1057 // Ignore difference between <1.5 and >=1.5 protocol versions:
1058 // it's valid if either Result Type or Operand is a pointer, and the other
1059 // is a pointer, an integer scalar, or an integer vector.
1060 if (Op1
== SPIRV::OpTypePointer
&&
1061 (Op2
== SPIRV::OpTypePointer
|| retrieveScalarOrVectorIntType(Type2
)))
1063 if (Op2
== SPIRV::OpTypePointer
&&
1064 (Op1
== SPIRV::OpTypePointer
|| retrieveScalarOrVectorIntType(Type1
)))
1066 unsigned Bits1
= getNumScalarOrVectorTotalBitWidth(Type1
),
1067 Bits2
= getNumScalarOrVectorTotalBitWidth(Type2
);
1068 return Bits1
> 0 && Bits1
== Bits2
;
1071 SPIRV::StorageClass::StorageClass
1072 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg
) const {
1073 SPIRVType
*Type
= getSPIRVTypeForVReg(VReg
);
1074 assert(Type
&& Type
->getOpcode() == SPIRV::OpTypePointer
&&
1075 Type
->getOperand(1).isImm() && "Pointer type is expected");
1076 return static_cast<SPIRV::StorageClass::StorageClass
>(
1077 Type
->getOperand(1).getImm());
1080 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1081 MachineIRBuilder
&MIRBuilder
, SPIRVType
*SampledType
, SPIRV::Dim::Dim Dim
,
1082 uint32_t Depth
, uint32_t Arrayed
, uint32_t Multisampled
, uint32_t Sampled
,
1083 SPIRV::ImageFormat::ImageFormat ImageFormat
,
1084 SPIRV::AccessQualifier::AccessQualifier AccessQual
) {
1085 SPIRV::ImageTypeDescriptor
TD(SPIRVToLLVMType
.lookup(SampledType
), Dim
, Depth
,
1086 Arrayed
, Multisampled
, Sampled
, ImageFormat
,
1088 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1090 Register ResVReg
= createTypeVReg(MIRBuilder
);
1091 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1092 return MIRBuilder
.buildInstr(SPIRV::OpTypeImage
)
1094 .addUse(getSPIRVTypeID(SampledType
))
1096 .addImm(Depth
) // Depth (whether or not it is a Depth image).
1097 .addImm(Arrayed
) // Arrayed.
1098 .addImm(Multisampled
) // Multisampled (0 = only single-sample).
1099 .addImm(Sampled
) // Sampled (0 = usage known at runtime).
1100 .addImm(ImageFormat
)
1101 .addImm(AccessQual
);
1105 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder
&MIRBuilder
) {
1106 SPIRV::SamplerTypeDescriptor TD
;
1107 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1109 Register ResVReg
= createTypeVReg(MIRBuilder
);
1110 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1111 return MIRBuilder
.buildInstr(SPIRV::OpTypeSampler
).addDef(ResVReg
);
1114 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1115 MachineIRBuilder
&MIRBuilder
,
1116 SPIRV::AccessQualifier::AccessQualifier AccessQual
) {
1117 SPIRV::PipeTypeDescriptor
TD(AccessQual
);
1118 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1120 Register ResVReg
= createTypeVReg(MIRBuilder
);
1121 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1122 return MIRBuilder
.buildInstr(SPIRV::OpTypePipe
)
1124 .addImm(AccessQual
);
1127 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1128 MachineIRBuilder
&MIRBuilder
) {
1129 SPIRV::DeviceEventTypeDescriptor TD
;
1130 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1132 Register ResVReg
= createTypeVReg(MIRBuilder
);
1133 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1134 return MIRBuilder
.buildInstr(SPIRV::OpTypeDeviceEvent
).addDef(ResVReg
);
1137 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1138 SPIRVType
*ImageType
, MachineIRBuilder
&MIRBuilder
) {
1139 SPIRV::SampledImageTypeDescriptor
TD(
1140 SPIRVToLLVMType
.lookup(MIRBuilder
.getMF().getRegInfo().getVRegDef(
1141 ImageType
->getOperand(1).getReg())),
1143 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1145 Register ResVReg
= createTypeVReg(MIRBuilder
);
1146 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1147 return MIRBuilder
.buildInstr(SPIRV::OpTypeSampledImage
)
1149 .addUse(getSPIRVTypeID(ImageType
));
1152 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1153 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
, unsigned Opcode
) {
1154 Register ResVReg
= DT
.find(Ty
, &MIRBuilder
.getMF());
1155 if (ResVReg
.isValid())
1156 return MIRBuilder
.getMF().getRegInfo().getUniqueVRegDef(ResVReg
);
1157 ResVReg
= createTypeVReg(MIRBuilder
);
1158 SPIRVType
*SpirvTy
= MIRBuilder
.buildInstr(Opcode
).addDef(ResVReg
);
1159 DT
.add(Ty
, &MIRBuilder
.getMF(), ResVReg
);
1163 const MachineInstr
*
1164 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor
&TD
,
1165 MachineIRBuilder
&MIRBuilder
) {
1166 Register Reg
= DT
.find(TD
, &MIRBuilder
.getMF());
1168 return MIRBuilder
.getMF().getRegInfo().getUniqueVRegDef(Reg
);
1172 // Returns nullptr if unable to recognize SPIRV type name
1173 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1174 StringRef TypeStr
, MachineIRBuilder
&MIRBuilder
,
1175 SPIRV::StorageClass::StorageClass SC
,
1176 SPIRV::AccessQualifier::AccessQualifier AQ
) {
1177 unsigned VecElts
= 0;
1178 auto &Ctx
= MIRBuilder
.getMF().getFunction().getContext();
1180 // Parse strings representing either a SPIR-V or OpenCL builtin type.
1181 if (hasBuiltinTypePrefix(TypeStr
))
1182 return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
1183 TypeStr
.str(), MIRBuilder
.getContext()),
1186 // Parse type name in either "typeN" or "type vector[N]" format, where
1187 // N is the number of elements of the vector.
1190 Ty
= parseBasicTypeName(TypeStr
, Ctx
);
1192 // Unable to recognize SPIRV type name
1195 auto SpirvTy
= getOrCreateSPIRVType(Ty
, MIRBuilder
, AQ
);
1197 // Handle "type*" or "type* vector[N]".
1198 if (TypeStr
.starts_with("*")) {
1199 SpirvTy
= getOrCreateSPIRVPointerType(SpirvTy
, MIRBuilder
, SC
);
1200 TypeStr
= TypeStr
.substr(strlen("*"));
1203 // Handle "typeN*" or "type vector[N]*".
1204 bool IsPtrToVec
= TypeStr
.consume_back("*");
1206 if (TypeStr
.consume_front(" vector[")) {
1207 TypeStr
= TypeStr
.substr(0, TypeStr
.find(']'));
1209 TypeStr
.getAsInteger(10, VecElts
);
1211 SpirvTy
= getOrCreateSPIRVVectorType(SpirvTy
, VecElts
, MIRBuilder
);
1214 SpirvTy
= getOrCreateSPIRVPointerType(SpirvTy
, MIRBuilder
, SC
);
1220 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth
,
1221 MachineIRBuilder
&MIRBuilder
) {
1222 return getOrCreateSPIRVType(
1223 IntegerType::get(MIRBuilder
.getMF().getFunction().getContext(), BitWidth
),
1227 SPIRVType
*SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type
*LLVMTy
,
1228 SPIRVType
*SpirvType
) {
1229 assert(CurMF
== SpirvType
->getMF());
1230 VRegToTypeMap
[CurMF
][getSPIRVTypeID(SpirvType
)] = SpirvType
;
1231 SPIRVToLLVMType
[SpirvType
] = LLVMTy
;
1235 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth
,
1237 const SPIRVInstrInfo
&TII
,
1238 unsigned SPIRVOPcode
,
1240 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1242 return getSPIRVTypeForVReg(Reg
);
1243 MachineBasicBlock
&BB
= *I
.getParent();
1244 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRVOPcode
))
1245 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1248 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1249 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1252 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1253 unsigned BitWidth
, MachineInstr
&I
, const SPIRVInstrInfo
&TII
) {
1254 Type
*LLVMTy
= IntegerType::get(CurMF
->getFunction().getContext(), BitWidth
);
1255 return getOrCreateSPIRVType(BitWidth
, I
, TII
, SPIRV::OpTypeInt
, LLVMTy
);
1257 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1258 unsigned BitWidth
, MachineInstr
&I
, const SPIRVInstrInfo
&TII
) {
1259 LLVMContext
&Ctx
= CurMF
->getFunction().getContext();
1263 LLVMTy
= Type::getHalfTy(Ctx
);
1266 LLVMTy
= Type::getFloatTy(Ctx
);
1269 LLVMTy
= Type::getDoubleTy(Ctx
);
1272 llvm_unreachable("Bit width is of unexpected size.");
1274 return getOrCreateSPIRVType(BitWidth
, I
, TII
, SPIRV::OpTypeFloat
, LLVMTy
);
1278 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder
&MIRBuilder
) {
1279 return getOrCreateSPIRVType(
1280 IntegerType::get(MIRBuilder
.getMF().getFunction().getContext(), 1),
1285 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr
&I
,
1286 const SPIRVInstrInfo
&TII
) {
1287 Type
*LLVMTy
= IntegerType::get(CurMF
->getFunction().getContext(), 1);
1288 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1290 return getSPIRVTypeForVReg(Reg
);
1291 MachineBasicBlock
&BB
= *I
.getParent();
1292 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpTypeBool
))
1293 .addDef(createTypeVReg(CurMF
->getRegInfo()));
1294 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1295 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1298 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1299 SPIRVType
*BaseType
, unsigned NumElements
, MachineIRBuilder
&MIRBuilder
) {
1300 return getOrCreateSPIRVType(
1301 FixedVectorType::get(const_cast<Type
*>(getTypeForSPIRVType(BaseType
)),
1306 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1307 SPIRVType
*BaseType
, unsigned NumElements
, MachineInstr
&I
,
1308 const SPIRVInstrInfo
&TII
) {
1309 Type
*LLVMTy
= FixedVectorType::get(
1310 const_cast<Type
*>(getTypeForSPIRVType(BaseType
)), NumElements
);
1311 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1313 return getSPIRVTypeForVReg(Reg
);
1314 MachineBasicBlock
&BB
= *I
.getParent();
1315 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpTypeVector
))
1316 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1317 .addUse(getSPIRVTypeID(BaseType
))
1318 .addImm(NumElements
);
1319 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1320 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1323 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1324 SPIRVType
*BaseType
, unsigned NumElements
, MachineInstr
&I
,
1325 const SPIRVInstrInfo
&TII
) {
1326 Type
*LLVMTy
= ArrayType::get(
1327 const_cast<Type
*>(getTypeForSPIRVType(BaseType
)), NumElements
);
1328 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1330 return getSPIRVTypeForVReg(Reg
);
1331 MachineBasicBlock
&BB
= *I
.getParent();
1332 SPIRVType
*SpirvType
= getOrCreateSPIRVIntegerType(32, I
, TII
);
1333 Register Len
= getOrCreateConstInt(NumElements
, I
, SpirvType
, TII
);
1334 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpTypeArray
))
1335 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1336 .addUse(getSPIRVTypeID(BaseType
))
1338 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1339 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1342 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1343 SPIRVType
*BaseType
, MachineIRBuilder
&MIRBuilder
,
1344 SPIRV::StorageClass::StorageClass SC
) {
1345 const Type
*PointerElementType
= getTypeForSPIRVType(BaseType
);
1346 unsigned AddressSpace
= storageClassToAddressSpace(SC
);
1347 Type
*LLVMTy
= TypedPointerType::get(const_cast<Type
*>(PointerElementType
),
1349 // check if this type is already available
1350 Register Reg
= DT
.find(PointerElementType
, AddressSpace
, CurMF
);
1352 return getSPIRVTypeForVReg(Reg
);
1353 // create a new type
1354 auto MIB
= BuildMI(MIRBuilder
.getMBB(), MIRBuilder
.getInsertPt(),
1355 MIRBuilder
.getDebugLoc(),
1356 MIRBuilder
.getTII().get(SPIRV::OpTypePointer
))
1357 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1358 .addImm(static_cast<uint32_t>(SC
))
1359 .addUse(getSPIRVTypeID(BaseType
));
1360 DT
.add(PointerElementType
, AddressSpace
, CurMF
, getSPIRVTypeID(MIB
));
1361 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1364 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1365 SPIRVType
*BaseType
, MachineInstr
&I
, const SPIRVInstrInfo
&,
1366 SPIRV::StorageClass::StorageClass SC
) {
1367 MachineIRBuilder
MIRBuilder(I
);
1368 return getOrCreateSPIRVPointerType(BaseType
, MIRBuilder
, SC
);
1371 Register
SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr
&I
,
1373 const SPIRVInstrInfo
&TII
) {
1375 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
1377 // Find a constant in DT or build a new one.
1378 UndefValue
*UV
= UndefValue::get(const_cast<Type
*>(LLVMTy
));
1379 Register Res
= DT
.find(UV
, CurMF
);
1382 LLT LLTy
= LLT::scalar(32);
1383 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
1384 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
1385 assignSPIRVTypeToVReg(SpvType
, Res
, *CurMF
);
1386 DT
.add(UV
, CurMF
, Res
);
1388 MachineInstrBuilder MIB
;
1389 MIB
= BuildMI(*I
.getParent(), I
, I
.getDebugLoc(), TII
.get(SPIRV::OpUndef
))
1391 .addUse(getSPIRVTypeID(SpvType
));
1392 const auto &ST
= CurMF
->getSubtarget();
1393 constrainSelectedInstRegOperands(*MIB
, *ST
.getInstrInfo(),
1394 *ST
.getRegisterInfo(), *ST
.getRegBankInfo());