1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
15 //===----------------------------------------------------------------------===//
17 #include "SPIRVGlobalRegistry.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/Support/Casting.h"
30 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize
)
31 : PointerSize(PointerSize
), Bound(0) {}
33 SPIRVType
*SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth
,
36 const SPIRVInstrInfo
&TII
) {
37 SPIRVType
*SpirvType
= getOrCreateSPIRVIntegerType(BitWidth
, I
, TII
);
38 assignSPIRVTypeToVReg(SpirvType
, VReg
, *CurMF
);
43 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth
, Register VReg
,
45 const SPIRVInstrInfo
&TII
) {
46 SPIRVType
*SpirvType
= getOrCreateSPIRVFloatType(BitWidth
, I
, TII
);
47 assignSPIRVTypeToVReg(SpirvType
, VReg
, *CurMF
);
51 SPIRVType
*SPIRVGlobalRegistry::assignVectTypeToVReg(
52 SPIRVType
*BaseType
, unsigned NumElements
, Register VReg
, MachineInstr
&I
,
53 const SPIRVInstrInfo
&TII
) {
54 SPIRVType
*SpirvType
=
55 getOrCreateSPIRVVectorType(BaseType
, NumElements
, I
, TII
);
56 assignSPIRVTypeToVReg(SpirvType
, VReg
, *CurMF
);
60 SPIRVType
*SPIRVGlobalRegistry::assignTypeToVReg(
61 const Type
*Type
, Register VReg
, MachineIRBuilder
&MIRBuilder
,
62 SPIRV::AccessQualifier::AccessQualifier AccessQual
, bool EmitIR
) {
63 SPIRVType
*SpirvType
=
64 getOrCreateSPIRVType(Type
, MIRBuilder
, AccessQual
, EmitIR
);
65 assignSPIRVTypeToVReg(SpirvType
, VReg
, MIRBuilder
.getMF());
69 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType
*SpirvType
,
71 MachineFunction
&MF
) {
72 VRegToTypeMap
[&MF
][VReg
] = SpirvType
;
75 static Register
createTypeVReg(MachineIRBuilder
&MIRBuilder
) {
76 auto &MRI
= MIRBuilder
.getMF().getRegInfo();
77 auto Res
= MRI
.createGenericVirtualRegister(LLT::scalar(32));
78 MRI
.setRegClass(Res
, &SPIRV::TYPERegClass
);
82 static Register
createTypeVReg(MachineRegisterInfo
&MRI
) {
83 auto Res
= MRI
.createGenericVirtualRegister(LLT::scalar(32));
84 MRI
.setRegClass(Res
, &SPIRV::TYPERegClass
);
88 SPIRVType
*SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder
&MIRBuilder
) {
89 return MIRBuilder
.buildInstr(SPIRV::OpTypeBool
)
90 .addDef(createTypeVReg(MIRBuilder
));
93 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width
) const {
95 report_fatal_error("Unsupported integer width!");
96 const SPIRVSubtarget
&ST
= cast
<SPIRVSubtarget
>(CurMF
->getSubtarget());
97 if (ST
.canUseExtension(
98 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers
))
102 else if (Width
<= 16)
104 else if (Width
<= 32)
111 SPIRVType
*SPIRVGlobalRegistry::getOpTypeInt(unsigned Width
,
112 MachineIRBuilder
&MIRBuilder
,
114 Width
= adjustOpTypeIntWidth(Width
);
115 const SPIRVSubtarget
&ST
=
116 cast
<SPIRVSubtarget
>(MIRBuilder
.getMF().getSubtarget());
117 if (ST
.canUseExtension(
118 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers
)) {
119 MIRBuilder
.buildInstr(SPIRV::OpExtension
)
120 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers
);
121 MIRBuilder
.buildInstr(SPIRV::OpCapability
)
122 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL
);
124 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeInt
)
125 .addDef(createTypeVReg(MIRBuilder
))
127 .addImm(IsSigned
? 1 : 0);
131 SPIRVType
*SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width
,
132 MachineIRBuilder
&MIRBuilder
) {
133 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeFloat
)
134 .addDef(createTypeVReg(MIRBuilder
))
139 SPIRVType
*SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder
&MIRBuilder
) {
140 return MIRBuilder
.buildInstr(SPIRV::OpTypeVoid
)
141 .addDef(createTypeVReg(MIRBuilder
));
144 SPIRVType
*SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems
,
146 MachineIRBuilder
&MIRBuilder
) {
147 auto EleOpc
= ElemType
->getOpcode();
149 assert((EleOpc
== SPIRV::OpTypeInt
|| EleOpc
== SPIRV::OpTypeFloat
||
150 EleOpc
== SPIRV::OpTypeBool
) &&
151 "Invalid vector element type");
153 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeVector
)
154 .addDef(createTypeVReg(MIRBuilder
))
155 .addUse(getSPIRVTypeID(ElemType
))
160 std::tuple
<Register
, ConstantInt
*, bool>
161 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val
, SPIRVType
*SpvType
,
162 MachineIRBuilder
*MIRBuilder
,
164 const SPIRVInstrInfo
*TII
) {
165 const IntegerType
*LLVMIntTy
;
167 LLVMIntTy
= cast
<IntegerType
>(getTypeForSPIRVType(SpvType
));
169 LLVMIntTy
= IntegerType::getInt32Ty(CurMF
->getFunction().getContext());
170 bool NewInstr
= false;
171 // Find a constant in DT or build a new one.
172 ConstantInt
*CI
= ConstantInt::get(const_cast<IntegerType
*>(LLVMIntTy
), Val
);
173 Register Res
= DT
.find(CI
, CurMF
);
174 if (!Res
.isValid()) {
175 unsigned BitWidth
= SpvType
? getScalarOrVectorBitWidth(SpvType
) : 32;
176 // TODO: handle cases where the type is not 32bit wide
177 // TODO: https://github.com/llvm/llvm-project/issues/88129
178 LLT LLTy
= LLT::scalar(32);
179 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
180 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
182 assignTypeToVReg(LLVMIntTy
, Res
, *MIRBuilder
);
184 assignIntTypeToVReg(BitWidth
, Res
, *I
, *TII
);
185 DT
.add(CI
, CurMF
, Res
);
188 return std::make_tuple(Res
, CI
, NewInstr
);
191 std::tuple
<Register
, ConstantFP
*, bool, unsigned>
192 SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val
, SPIRVType
*SpvType
,
193 MachineIRBuilder
*MIRBuilder
,
195 const SPIRVInstrInfo
*TII
) {
196 const Type
*LLVMFloatTy
;
197 LLVMContext
&Ctx
= CurMF
->getFunction().getContext();
198 unsigned BitWidth
= 32;
200 LLVMFloatTy
= getTypeForSPIRVType(SpvType
);
202 LLVMFloatTy
= Type::getFloatTy(Ctx
);
204 SpvType
= getOrCreateSPIRVType(LLVMFloatTy
, *MIRBuilder
);
206 bool NewInstr
= false;
207 // Find a constant in DT or build a new one.
208 auto *const CI
= ConstantFP::get(Ctx
, Val
);
209 Register Res
= DT
.find(CI
, CurMF
);
210 if (!Res
.isValid()) {
212 BitWidth
= getScalarOrVectorBitWidth(SpvType
);
213 // TODO: handle cases where the type is not 32bit wide
214 // TODO: https://github.com/llvm/llvm-project/issues/88129
215 LLT LLTy
= LLT::scalar(32);
216 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
217 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
219 assignTypeToVReg(LLVMFloatTy
, Res
, *MIRBuilder
);
221 assignFloatTypeToVReg(BitWidth
, Res
, *I
, *TII
);
222 DT
.add(CI
, CurMF
, Res
);
225 return std::make_tuple(Res
, CI
, NewInstr
, BitWidth
);
228 Register
SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val
, MachineInstr
&I
,
230 const SPIRVInstrInfo
&TII
,
237 std::tie(Res
, CI
, New
, BitWidth
) =
238 getOrCreateConstFloatReg(Val
, SpvType
, nullptr, &I
, &TII
);
239 // If we have found Res register which is defined by the passed G_CONSTANT
240 // machine instruction, a new constant instruction should be created.
241 if (!New
&& (!I
.getOperand(0).isReg() || Res
!= I
.getOperand(0).getReg()))
243 MachineInstrBuilder MIB
;
244 MachineBasicBlock
&BB
= *I
.getParent();
245 // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
246 if (Val
.isPosZero() && ZeroAsNull
) {
247 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantNull
))
249 .addUse(getSPIRVTypeID(SpvType
));
251 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantF
))
253 .addUse(getSPIRVTypeID(SpvType
));
255 APInt(BitWidth
, CI
->getValueAPF().bitcastToAPInt().getZExtValue()),
258 const auto &ST
= CurMF
->getSubtarget();
259 constrainSelectedInstRegOperands(*MIB
, *ST
.getInstrInfo(),
260 *ST
.getRegisterInfo(), *ST
.getRegBankInfo());
264 Register
SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val
, MachineInstr
&I
,
266 const SPIRVInstrInfo
&TII
,
272 std::tie(Res
, CI
, New
) =
273 getOrCreateConstIntReg(Val
, SpvType
, nullptr, &I
, &TII
);
274 // If we have found Res register which is defined by the passed G_CONSTANT
275 // machine instruction, a new constant instruction should be created.
276 if (!New
&& (!I
.getOperand(0).isReg() || Res
!= I
.getOperand(0).getReg()))
278 MachineInstrBuilder MIB
;
279 MachineBasicBlock
&BB
= *I
.getParent();
280 if (Val
|| !ZeroAsNull
) {
281 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantI
))
283 .addUse(getSPIRVTypeID(SpvType
));
284 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType
), Val
), MIB
);
286 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantNull
))
288 .addUse(getSPIRVTypeID(SpvType
));
290 const auto &ST
= CurMF
->getSubtarget();
291 constrainSelectedInstRegOperands(*MIB
, *ST
.getInstrInfo(),
292 *ST
.getRegisterInfo(), *ST
.getRegBankInfo());
296 Register
SPIRVGlobalRegistry::buildConstantInt(uint64_t Val
,
297 MachineIRBuilder
&MIRBuilder
,
300 auto &MF
= MIRBuilder
.getMF();
301 const IntegerType
*LLVMIntTy
;
303 LLVMIntTy
= cast
<IntegerType
>(getTypeForSPIRVType(SpvType
));
305 LLVMIntTy
= IntegerType::getInt32Ty(MF
.getFunction().getContext());
306 // Find a constant in DT or build a new one.
307 const auto ConstInt
=
308 ConstantInt::get(const_cast<IntegerType
*>(LLVMIntTy
), Val
);
309 Register Res
= DT
.find(ConstInt
, &MF
);
310 if (!Res
.isValid()) {
311 unsigned BitWidth
= SpvType
? getScalarOrVectorBitWidth(SpvType
) : 32;
312 LLT LLTy
= LLT::scalar(EmitIR
? BitWidth
: 32);
313 Res
= MF
.getRegInfo().createGenericVirtualRegister(LLTy
);
314 MF
.getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
315 assignTypeToVReg(LLVMIntTy
, Res
, MIRBuilder
,
316 SPIRV::AccessQualifier::ReadWrite
, EmitIR
);
317 DT
.add(ConstInt
, &MIRBuilder
.getMF(), Res
);
319 MIRBuilder
.buildConstant(Res
, *ConstInt
);
322 SpvType
= getOrCreateSPIRVIntegerType(BitWidth
, MIRBuilder
);
323 MachineInstrBuilder MIB
;
325 MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantI
)
327 .addUse(getSPIRVTypeID(SpvType
));
328 addNumImm(APInt(BitWidth
, Val
), MIB
);
330 MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantNull
)
332 .addUse(getSPIRVTypeID(SpvType
));
334 const auto &Subtarget
= CurMF
->getSubtarget();
335 constrainSelectedInstRegOperands(*MIB
, *Subtarget
.getInstrInfo(),
336 *Subtarget
.getRegisterInfo(),
337 *Subtarget
.getRegBankInfo());
343 Register
SPIRVGlobalRegistry::buildConstantFP(APFloat Val
,
344 MachineIRBuilder
&MIRBuilder
,
345 SPIRVType
*SpvType
) {
346 auto &MF
= MIRBuilder
.getMF();
347 auto &Ctx
= MF
.getFunction().getContext();
349 const Type
*LLVMFPTy
= Type::getFloatTy(Ctx
);
350 SpvType
= getOrCreateSPIRVType(LLVMFPTy
, MIRBuilder
);
352 // Find a constant in DT or build a new one.
353 const auto ConstFP
= ConstantFP::get(Ctx
, Val
);
354 Register Res
= DT
.find(ConstFP
, &MF
);
355 if (!Res
.isValid()) {
356 Res
= MF
.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
357 MF
.getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
358 assignSPIRVTypeToVReg(SpvType
, Res
, MF
);
359 DT
.add(ConstFP
, &MF
, Res
);
361 MachineInstrBuilder MIB
;
362 MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantF
)
364 .addUse(getSPIRVTypeID(SpvType
));
365 addNumImm(ConstFP
->getValueAPF().bitcastToAPInt(), MIB
);
371 Register
SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant
*Val
,
374 const SPIRVInstrInfo
&TII
,
376 SPIRVType
*Type
= SpvType
;
377 if (SpvType
->getOpcode() == SPIRV::OpTypeVector
||
378 SpvType
->getOpcode() == SPIRV::OpTypeArray
) {
379 auto EleTypeReg
= SpvType
->getOperand(1).getReg();
380 Type
= getSPIRVTypeForVReg(EleTypeReg
);
382 if (Type
->getOpcode() == SPIRV::OpTypeFloat
) {
383 SPIRVType
*SpvBaseType
= getOrCreateSPIRVFloatType(BitWidth
, I
, TII
);
384 return getOrCreateConstFP(dyn_cast
<ConstantFP
>(Val
)->getValue(), I
,
387 assert(Type
->getOpcode() == SPIRV::OpTypeInt
);
388 SPIRVType
*SpvBaseType
= getOrCreateSPIRVIntegerType(BitWidth
, I
, TII
);
389 return getOrCreateConstInt(Val
->getUniqueInteger().getSExtValue(), I
,
393 Register
SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
394 Constant
*Val
, MachineInstr
&I
, SPIRVType
*SpvType
,
395 const SPIRVInstrInfo
&TII
, Constant
*CA
, unsigned BitWidth
,
396 unsigned ElemCnt
, bool ZeroAsNull
) {
397 // Find a constant vector or array in DT or build a new one.
398 Register Res
= DT
.find(CA
, CurMF
);
399 // If no values are attached, the composite is null constant.
400 bool IsNull
= Val
->isNullValue() && ZeroAsNull
;
401 if (!Res
.isValid()) {
402 // SpvScalConst should be created before SpvVecConst to avoid undefined ID
403 // error on validation.
404 // TODO: can moved below once sorting of types/consts/defs is implemented.
405 Register SpvScalConst
;
407 SpvScalConst
= getOrCreateBaseRegister(Val
, I
, SpvType
, TII
, BitWidth
);
409 // TODO: handle cases where the type is not 32bit wide
410 // TODO: https://github.com/llvm/llvm-project/issues/88129
411 LLT LLTy
= LLT::scalar(32);
412 Register SpvVecConst
=
413 CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
414 CurMF
->getRegInfo().setRegClass(SpvVecConst
, &SPIRV::IDRegClass
);
415 assignSPIRVTypeToVReg(SpvType
, SpvVecConst
, *CurMF
);
416 DT
.add(CA
, CurMF
, SpvVecConst
);
417 MachineInstrBuilder MIB
;
418 MachineBasicBlock
&BB
= *I
.getParent();
420 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantComposite
))
422 .addUse(getSPIRVTypeID(SpvType
));
423 for (unsigned i
= 0; i
< ElemCnt
; ++i
)
424 MIB
.addUse(SpvScalConst
);
426 MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpConstantNull
))
428 .addUse(getSPIRVTypeID(SpvType
));
430 const auto &Subtarget
= CurMF
->getSubtarget();
431 constrainSelectedInstRegOperands(*MIB
, *Subtarget
.getInstrInfo(),
432 *Subtarget
.getRegisterInfo(),
433 *Subtarget
.getRegBankInfo());
439 Register
SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val
,
442 const SPIRVInstrInfo
&TII
,
444 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
445 assert(LLVMTy
->isVectorTy());
446 const FixedVectorType
*LLVMVecTy
= cast
<FixedVectorType
>(LLVMTy
);
447 Type
*LLVMBaseTy
= LLVMVecTy
->getElementType();
448 assert(LLVMBaseTy
->isIntegerTy());
449 auto *ConstVal
= ConstantInt::get(LLVMBaseTy
, Val
);
451 ConstantVector::getSplat(LLVMVecTy
->getElementCount(), ConstVal
);
452 unsigned BW
= getScalarOrVectorBitWidth(SpvType
);
453 return getOrCreateCompositeOrNull(ConstVal
, I
, SpvType
, TII
, ConstVec
, BW
,
454 SpvType
->getOperand(2).getImm(),
458 Register
SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val
,
461 const SPIRVInstrInfo
&TII
,
463 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
464 assert(LLVMTy
->isVectorTy());
465 const FixedVectorType
*LLVMVecTy
= cast
<FixedVectorType
>(LLVMTy
);
466 Type
*LLVMBaseTy
= LLVMVecTy
->getElementType();
467 assert(LLVMBaseTy
->isFloatingPointTy());
468 auto *ConstVal
= ConstantFP::get(LLVMBaseTy
, Val
);
470 ConstantVector::getSplat(LLVMVecTy
->getElementCount(), ConstVal
);
471 unsigned BW
= getScalarOrVectorBitWidth(SpvType
);
472 return getOrCreateCompositeOrNull(ConstVal
, I
, SpvType
, TII
, ConstVec
, BW
,
473 SpvType
->getOperand(2).getImm(),
477 Register
SPIRVGlobalRegistry::getOrCreateConstIntArray(
478 uint64_t Val
, size_t Num
, MachineInstr
&I
, SPIRVType
*SpvType
,
479 const SPIRVInstrInfo
&TII
) {
480 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
481 assert(LLVMTy
->isArrayTy());
482 const ArrayType
*LLVMArrTy
= cast
<ArrayType
>(LLVMTy
);
483 Type
*LLVMBaseTy
= LLVMArrTy
->getElementType();
484 Constant
*CI
= ConstantInt::get(LLVMBaseTy
, Val
);
485 SPIRVType
*SpvBaseTy
= getSPIRVTypeForVReg(SpvType
->getOperand(1).getReg());
486 unsigned BW
= getScalarOrVectorBitWidth(SpvBaseTy
);
487 // The following is reasonably unique key that is better that [Val]. The naive
488 // alternative would be something along the lines of:
489 // SmallVector<Constant *> NumCI(Num, CI);
490 // Constant *UniqueKey =
491 // ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
492 // that would be a truly unique but dangerous key, because it could lead to
493 // the creation of constants of arbitrary length (that is, the parameter of
494 // memset) which were missing in the original module.
495 Constant
*UniqueKey
= ConstantStruct::getAnon(
496 {PoisonValue::get(const_cast<ArrayType
*>(LLVMArrTy
)),
497 ConstantInt::get(LLVMBaseTy
, Val
), ConstantInt::get(LLVMBaseTy
, Num
)});
498 return getOrCreateCompositeOrNull(CI
, I
, SpvType
, TII
, UniqueKey
, BW
,
499 LLVMArrTy
->getNumElements());
502 Register
SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
503 uint64_t Val
, MachineIRBuilder
&MIRBuilder
, SPIRVType
*SpvType
, bool EmitIR
,
504 Constant
*CA
, unsigned BitWidth
, unsigned ElemCnt
) {
505 Register Res
= DT
.find(CA
, CurMF
);
506 if (!Res
.isValid()) {
507 Register SpvScalConst
;
509 SPIRVType
*SpvBaseType
=
510 getOrCreateSPIRVIntegerType(BitWidth
, MIRBuilder
);
511 SpvScalConst
= buildConstantInt(Val
, MIRBuilder
, SpvBaseType
, EmitIR
);
513 LLT LLTy
= EmitIR
? LLT::fixed_vector(ElemCnt
, BitWidth
) : LLT::scalar(32);
514 Register SpvVecConst
=
515 CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
516 CurMF
->getRegInfo().setRegClass(SpvVecConst
, &SPIRV::IDRegClass
);
517 assignSPIRVTypeToVReg(SpvType
, SpvVecConst
, *CurMF
);
518 DT
.add(CA
, CurMF
, SpvVecConst
);
520 MIRBuilder
.buildSplatVector(SpvVecConst
, SpvScalConst
);
523 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpConstantComposite
)
525 .addUse(getSPIRVTypeID(SpvType
));
526 for (unsigned i
= 0; i
< ElemCnt
; ++i
)
527 MIB
.addUse(SpvScalConst
);
529 MIRBuilder
.buildInstr(SPIRV::OpConstantNull
)
531 .addUse(getSPIRVTypeID(SpvType
));
540 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val
,
541 MachineIRBuilder
&MIRBuilder
,
542 SPIRVType
*SpvType
, bool EmitIR
) {
543 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
544 assert(LLVMTy
->isVectorTy());
545 const FixedVectorType
*LLVMVecTy
= cast
<FixedVectorType
>(LLVMTy
);
546 Type
*LLVMBaseTy
= LLVMVecTy
->getElementType();
547 const auto ConstInt
= ConstantInt::get(LLVMBaseTy
, Val
);
549 ConstantVector::getSplat(LLVMVecTy
->getElementCount(), ConstInt
);
550 unsigned BW
= getScalarOrVectorBitWidth(SpvType
);
551 return getOrCreateIntCompositeOrNull(Val
, MIRBuilder
, SpvType
, EmitIR
,
553 SpvType
->getOperand(2).getImm());
557 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder
&MIRBuilder
,
558 SPIRVType
*SpvType
) {
559 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
560 const TypedPointerType
*LLVMPtrTy
= cast
<TypedPointerType
>(LLVMTy
);
561 // Find a constant in DT or build a new one.
562 Constant
*CP
= ConstantPointerNull::get(PointerType::get(
563 LLVMPtrTy
->getElementType(), LLVMPtrTy
->getAddressSpace()));
564 Register Res
= DT
.find(CP
, CurMF
);
565 if (!Res
.isValid()) {
566 LLT LLTy
= LLT::pointer(LLVMPtrTy
->getAddressSpace(), PointerSize
);
567 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
568 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
569 assignSPIRVTypeToVReg(SpvType
, Res
, *CurMF
);
570 MIRBuilder
.buildInstr(SPIRV::OpConstantNull
)
572 .addUse(getSPIRVTypeID(SpvType
));
573 DT
.add(CP
, CurMF
, Res
);
578 Register
SPIRVGlobalRegistry::buildConstantSampler(
579 Register ResReg
, unsigned AddrMode
, unsigned Param
, unsigned FilerMode
,
580 MachineIRBuilder
&MIRBuilder
, SPIRVType
*SpvType
) {
583 SampTy
= getOrCreateSPIRVType(getTypeForSPIRVType(SpvType
), MIRBuilder
);
584 else if ((SampTy
= getOrCreateSPIRVTypeByName("opencl.sampler_t",
585 MIRBuilder
)) == nullptr)
586 report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
591 : MIRBuilder
.getMRI()->createVirtualRegister(&SPIRV::IDRegClass
);
592 auto Res
= MIRBuilder
.buildInstr(SPIRV::OpConstantSampler
)
594 .addUse(getSPIRVTypeID(SampTy
))
598 assert(Res
->getOperand(0).isReg());
599 return Res
->getOperand(0).getReg();
602 Register
SPIRVGlobalRegistry::buildGlobalVariable(
603 Register ResVReg
, SPIRVType
*BaseType
, StringRef Name
,
604 const GlobalValue
*GV
, SPIRV::StorageClass::StorageClass Storage
,
605 const MachineInstr
*Init
, bool IsConst
, bool HasLinkageTy
,
606 SPIRV::LinkageType::LinkageType LinkageType
, MachineIRBuilder
&MIRBuilder
,
607 bool IsInstSelector
) {
608 const GlobalVariable
*GVar
= nullptr;
610 GVar
= cast
<const GlobalVariable
>(GV
);
612 // If GV is not passed explicitly, use the name to find or construct
613 // the global variable.
614 Module
*M
= MIRBuilder
.getMF().getFunction().getParent();
615 GVar
= M
->getGlobalVariable(Name
);
616 if (GVar
== nullptr) {
617 const Type
*Ty
= getTypeForSPIRVType(BaseType
); // TODO: check type.
618 // Module takes ownership of the global var.
619 GVar
= new GlobalVariable(*M
, const_cast<Type
*>(Ty
), false,
620 GlobalValue::ExternalLinkage
, nullptr,
625 Register Reg
= DT
.find(GVar
, &MIRBuilder
.getMF());
628 MIRBuilder
.buildCopy(ResVReg
, Reg
);
632 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpVariable
)
634 .addUse(getSPIRVTypeID(BaseType
))
635 .addImm(static_cast<uint32_t>(Storage
));
638 MIB
.addUse(Init
->getOperand(0).getReg());
641 // ISel may introduce a new register on this step, so we need to add it to
642 // DT and correct its type avoiding fails on the next stage.
643 if (IsInstSelector
) {
644 const auto &Subtarget
= CurMF
->getSubtarget();
645 constrainSelectedInstRegOperands(*MIB
, *Subtarget
.getInstrInfo(),
646 *Subtarget
.getRegisterInfo(),
647 *Subtarget
.getRegBankInfo());
649 Reg
= MIB
->getOperand(0).getReg();
650 DT
.add(GVar
, &MIRBuilder
.getMF(), Reg
);
652 // Set to Reg the same type as ResVReg has.
653 auto MRI
= MIRBuilder
.getMRI();
654 assert(MRI
->getType(ResVReg
).isPointer() && "Pointer type is expected");
655 if (Reg
!= ResVReg
) {
657 LLT::pointer(MRI
->getType(ResVReg
).getAddressSpace(), getPointerSize());
658 MRI
->setType(Reg
, RegLLTy
);
659 assignSPIRVTypeToVReg(BaseType
, Reg
, MIRBuilder
.getMF());
661 // Our knowledge about the type may be updated.
662 // If that's the case, we need to update a type
663 // associated with the register.
664 SPIRVType
*DefType
= getSPIRVTypeForVReg(ResVReg
);
665 if (!DefType
|| DefType
!= BaseType
)
666 assignSPIRVTypeToVReg(BaseType
, Reg
, MIRBuilder
.getMF());
669 // If it's a global variable with name, output OpName for it.
670 if (GVar
&& GVar
->hasName())
671 buildOpName(Reg
, GVar
->getName(), MIRBuilder
);
673 // Output decorations for the GV.
674 // TODO: maybe move to GenerateDecorations pass.
675 const SPIRVSubtarget
&ST
=
676 cast
<SPIRVSubtarget
>(MIRBuilder
.getMF().getSubtarget());
677 if (IsConst
&& ST
.isOpenCLEnv())
678 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::Constant
, {});
680 if (GVar
&& GVar
->getAlign().valueOrOne().value() != 1) {
681 unsigned Alignment
= (unsigned)GVar
->getAlign().valueOrOne().value();
682 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::Alignment
, {Alignment
});
686 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::LinkageAttributes
,
687 {static_cast<uint32_t>(LinkageType
)}, Name
);
689 SPIRV::BuiltIn::BuiltIn BuiltInId
;
690 if (getSpirvBuiltInIdByName(Name
, BuiltInId
))
691 buildOpDecorate(Reg
, MIRBuilder
, SPIRV::Decoration::BuiltIn
,
692 {static_cast<uint32_t>(BuiltInId
)});
694 // If it's a global variable with "spirv.Decorations" metadata node
695 // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
697 MDNode
*GVarMD
= nullptr;
698 if (GVar
&& (GVarMD
= GVar
->getMetadata("spirv.Decorations")) != nullptr)
699 buildOpSpirvDecorations(Reg
, MIRBuilder
, GVarMD
);
704 SPIRVType
*SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems
,
706 MachineIRBuilder
&MIRBuilder
,
708 assert((ElemType
->getOpcode() != SPIRV::OpTypeVoid
) &&
709 "Invalid array element type");
710 Register NumElementsVReg
=
711 buildConstantInt(NumElems
, MIRBuilder
, nullptr, EmitIR
);
712 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeArray
)
713 .addDef(createTypeVReg(MIRBuilder
))
714 .addUse(getSPIRVTypeID(ElemType
))
715 .addUse(NumElementsVReg
);
719 SPIRVType
*SPIRVGlobalRegistry::getOpTypeOpaque(const StructType
*Ty
,
720 MachineIRBuilder
&MIRBuilder
) {
721 assert(Ty
->hasName());
722 const StringRef Name
= Ty
->hasName() ? Ty
->getName() : "";
723 Register ResVReg
= createTypeVReg(MIRBuilder
);
724 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeOpaque
).addDef(ResVReg
);
725 addStringImm(Name
, MIB
);
726 buildOpName(ResVReg
, Name
, MIRBuilder
);
730 SPIRVType
*SPIRVGlobalRegistry::getOpTypeStruct(const StructType
*Ty
,
731 MachineIRBuilder
&MIRBuilder
,
733 SmallVector
<Register
, 4> FieldTypes
;
734 for (const auto &Elem
: Ty
->elements()) {
735 SPIRVType
*ElemTy
= findSPIRVType(toTypedPointer(Elem
), MIRBuilder
);
736 assert(ElemTy
&& ElemTy
->getOpcode() != SPIRV::OpTypeVoid
&&
737 "Invalid struct element type");
738 FieldTypes
.push_back(getSPIRVTypeID(ElemTy
));
740 Register ResVReg
= createTypeVReg(MIRBuilder
);
741 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeStruct
).addDef(ResVReg
);
742 for (const auto &Ty
: FieldTypes
)
745 buildOpName(ResVReg
, Ty
->getName(), MIRBuilder
);
747 buildOpDecorate(ResVReg
, MIRBuilder
, SPIRV::Decoration::CPacked
, {});
751 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSpecialType(
752 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
753 SPIRV::AccessQualifier::AccessQualifier AccQual
) {
754 assert(isSpecialOpaqueType(Ty
) && "Not a special opaque builtin type");
755 return SPIRV::lowerBuiltinType(Ty
, AccQual
, MIRBuilder
, this);
758 SPIRVType
*SPIRVGlobalRegistry::getOpTypePointer(
759 SPIRV::StorageClass::StorageClass SC
, SPIRVType
*ElemType
,
760 MachineIRBuilder
&MIRBuilder
, Register Reg
) {
762 Reg
= createTypeVReg(MIRBuilder
);
763 return MIRBuilder
.buildInstr(SPIRV::OpTypePointer
)
765 .addImm(static_cast<uint32_t>(SC
))
766 .addUse(getSPIRVTypeID(ElemType
));
769 SPIRVType
*SPIRVGlobalRegistry::getOpTypeForwardPointer(
770 SPIRV::StorageClass::StorageClass SC
, MachineIRBuilder
&MIRBuilder
) {
771 return MIRBuilder
.buildInstr(SPIRV::OpTypeForwardPointer
)
772 .addUse(createTypeVReg(MIRBuilder
))
773 .addImm(static_cast<uint32_t>(SC
));
776 SPIRVType
*SPIRVGlobalRegistry::getOpTypeFunction(
777 SPIRVType
*RetType
, const SmallVectorImpl
<SPIRVType
*> &ArgTypes
,
778 MachineIRBuilder
&MIRBuilder
) {
779 auto MIB
= MIRBuilder
.buildInstr(SPIRV::OpTypeFunction
)
780 .addDef(createTypeVReg(MIRBuilder
))
781 .addUse(getSPIRVTypeID(RetType
));
782 for (const SPIRVType
*ArgType
: ArgTypes
)
783 MIB
.addUse(getSPIRVTypeID(ArgType
));
787 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
788 const Type
*Ty
, SPIRVType
*RetType
,
789 const SmallVectorImpl
<SPIRVType
*> &ArgTypes
,
790 MachineIRBuilder
&MIRBuilder
) {
791 Register Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
793 return getSPIRVTypeForVReg(Reg
);
794 SPIRVType
*SpirvType
= getOpTypeFunction(RetType
, ArgTypes
, MIRBuilder
);
795 DT
.add(Ty
, CurMF
, getSPIRVTypeID(SpirvType
));
796 return finishCreatingSPIRVType(Ty
, SpirvType
);
799 SPIRVType
*SPIRVGlobalRegistry::findSPIRVType(
800 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
801 SPIRV::AccessQualifier::AccessQualifier AccQual
, bool EmitIR
) {
802 Ty
= adjustIntTypeByWidth(Ty
);
803 Register Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
805 return getSPIRVTypeForVReg(Reg
);
806 if (ForwardPointerTypes
.contains(Ty
))
807 return ForwardPointerTypes
[Ty
];
808 return restOfCreateSPIRVType(Ty
, MIRBuilder
, AccQual
, EmitIR
);
811 Register
SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType
*SpirvType
) const {
812 assert(SpirvType
&& "Attempting to get type id for nullptr type.");
813 if (SpirvType
->getOpcode() == SPIRV::OpTypeForwardPointer
)
814 return SpirvType
->uses().begin()->getReg();
815 return SpirvType
->defs().begin()->getReg();
818 // We need to use a new LLVM integer type if there is a mismatch between
819 // number of bits in LLVM and SPIRV integer types to let DuplicateTracker
820 // ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
821 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
822 // same "OpTypeInt 8" type for a series of LLVM integer types with number of
823 // bits less than 8. This would lead to duplicate type definitions
824 // eventually due to the method that DuplicateTracker utilizes to reason
825 // about uniqueness of type records.
826 const Type
*SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type
*Ty
) const {
827 if (auto IType
= dyn_cast
<IntegerType
>(Ty
)) {
828 unsigned SrcBitWidth
= IType
->getBitWidth();
829 if (SrcBitWidth
> 1) {
830 unsigned BitWidth
= adjustOpTypeIntWidth(SrcBitWidth
);
831 // Maybe change source LLVM type to keep DuplicateTracker consistent.
832 if (SrcBitWidth
!= BitWidth
)
833 Ty
= IntegerType::get(Ty
->getContext(), BitWidth
);
839 SPIRVType
*SPIRVGlobalRegistry::createSPIRVType(
840 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
841 SPIRV::AccessQualifier::AccessQualifier AccQual
, bool EmitIR
) {
842 if (isSpecialOpaqueType(Ty
))
843 return getOrCreateSpecialType(Ty
, MIRBuilder
, AccQual
);
844 auto &TypeToSPIRVTypeMap
= DT
.getTypes()->getAllUses();
845 auto t
= TypeToSPIRVTypeMap
.find(Ty
);
846 if (t
!= TypeToSPIRVTypeMap
.end()) {
847 auto tt
= t
->second
.find(&MIRBuilder
.getMF());
848 if (tt
!= t
->second
.end())
849 return getSPIRVTypeForVReg(tt
->second
);
852 if (auto IType
= dyn_cast
<IntegerType
>(Ty
)) {
853 const unsigned Width
= IType
->getBitWidth();
854 return Width
== 1 ? getOpTypeBool(MIRBuilder
)
855 : getOpTypeInt(Width
, MIRBuilder
, false);
857 if (Ty
->isFloatingPointTy())
858 return getOpTypeFloat(Ty
->getPrimitiveSizeInBits(), MIRBuilder
);
860 return getOpTypeVoid(MIRBuilder
);
861 if (Ty
->isVectorTy()) {
863 findSPIRVType(cast
<FixedVectorType
>(Ty
)->getElementType(), MIRBuilder
);
864 return getOpTypeVector(cast
<FixedVectorType
>(Ty
)->getNumElements(), El
,
867 if (Ty
->isArrayTy()) {
868 SPIRVType
*El
= findSPIRVType(Ty
->getArrayElementType(), MIRBuilder
);
869 return getOpTypeArray(Ty
->getArrayNumElements(), El
, MIRBuilder
, EmitIR
);
871 if (auto SType
= dyn_cast
<StructType
>(Ty
)) {
872 if (SType
->isOpaque())
873 return getOpTypeOpaque(SType
, MIRBuilder
);
874 return getOpTypeStruct(SType
, MIRBuilder
, EmitIR
);
876 if (auto FType
= dyn_cast
<FunctionType
>(Ty
)) {
877 SPIRVType
*RetTy
= findSPIRVType(FType
->getReturnType(), MIRBuilder
);
878 SmallVector
<SPIRVType
*, 4> ParamTypes
;
879 for (const auto &t
: FType
->params()) {
880 ParamTypes
.push_back(findSPIRVType(t
, MIRBuilder
));
882 return getOpTypeFunction(RetTy
, ParamTypes
, MIRBuilder
);
884 unsigned AddrSpace
= 0xFFFF;
885 if (auto PType
= dyn_cast
<TypedPointerType
>(Ty
))
886 AddrSpace
= PType
->getAddressSpace();
887 else if (auto PType
= dyn_cast
<PointerType
>(Ty
))
888 AddrSpace
= PType
->getAddressSpace();
890 report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
892 SPIRVType
*SpvElementType
= nullptr;
893 if (auto PType
= dyn_cast
<TypedPointerType
>(Ty
))
894 SpvElementType
= getOrCreateSPIRVType(PType
->getElementType(), MIRBuilder
,
897 SpvElementType
= getOrCreateSPIRVIntegerType(8, MIRBuilder
);
899 // Get access to information about available extensions
900 const SPIRVSubtarget
*ST
=
901 static_cast<const SPIRVSubtarget
*>(&MIRBuilder
.getMF().getSubtarget());
902 auto SC
= addressSpaceToStorageClass(AddrSpace
, *ST
);
903 // Null pointer means we have a loop in type definitions, make and
904 // return corresponding OpTypeForwardPointer.
905 if (SpvElementType
== nullptr) {
906 if (!ForwardPointerTypes
.contains(Ty
))
907 ForwardPointerTypes
[Ty
] = getOpTypeForwardPointer(SC
, MIRBuilder
);
908 return ForwardPointerTypes
[Ty
];
910 // If we have forward pointer associated with this type, use its register
911 // operand to create OpTypePointer.
912 if (ForwardPointerTypes
.contains(Ty
)) {
913 Register Reg
= getSPIRVTypeID(ForwardPointerTypes
[Ty
]);
914 return getOpTypePointer(SC
, SpvElementType
, MIRBuilder
, Reg
);
917 return getOrCreateSPIRVPointerType(SpvElementType
, MIRBuilder
, SC
);
920 SPIRVType
*SPIRVGlobalRegistry::restOfCreateSPIRVType(
921 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
922 SPIRV::AccessQualifier::AccessQualifier AccessQual
, bool EmitIR
) {
923 if (TypesInProcessing
.count(Ty
) && !isPointerTy(Ty
))
925 TypesInProcessing
.insert(Ty
);
926 SPIRVType
*SpirvType
= createSPIRVType(Ty
, MIRBuilder
, AccessQual
, EmitIR
);
927 TypesInProcessing
.erase(Ty
);
928 VRegToTypeMap
[&MIRBuilder
.getMF()][getSPIRVTypeID(SpirvType
)] = SpirvType
;
929 SPIRVToLLVMType
[SpirvType
] = unifyPtrType(Ty
);
930 Register Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
931 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
932 // will be added later. For special types it is already added to DT.
933 if (SpirvType
->getOpcode() != SPIRV::OpTypeForwardPointer
&& !Reg
.isValid() &&
934 !isSpecialOpaqueType(Ty
)) {
935 if (!isPointerTy(Ty
))
936 DT
.add(Ty
, &MIRBuilder
.getMF(), getSPIRVTypeID(SpirvType
));
937 else if (isTypedPointerTy(Ty
))
938 DT
.add(cast
<TypedPointerType
>(Ty
)->getElementType(),
939 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF(),
940 getSPIRVTypeID(SpirvType
));
942 DT
.add(Type::getInt8Ty(MIRBuilder
.getMF().getFunction().getContext()),
943 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF(),
944 getSPIRVTypeID(SpirvType
));
951 SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg
,
952 const MachineFunction
*MF
) const {
953 auto t
= VRegToTypeMap
.find(MF
? MF
: CurMF
);
954 if (t
!= VRegToTypeMap
.end()) {
955 auto tt
= t
->second
.find(VReg
);
956 if (tt
!= t
->second
.end())
962 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVType(
963 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
,
964 SPIRV::AccessQualifier::AccessQualifier AccessQual
, bool EmitIR
) {
966 if (!isPointerTy(Ty
)) {
967 Ty
= adjustIntTypeByWidth(Ty
);
968 Reg
= DT
.find(Ty
, &MIRBuilder
.getMF());
969 } else if (isTypedPointerTy(Ty
)) {
970 Reg
= DT
.find(cast
<TypedPointerType
>(Ty
)->getElementType(),
971 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF());
974 DT
.find(Type::getInt8Ty(MIRBuilder
.getMF().getFunction().getContext()),
975 getPointerAddressSpace(Ty
), &MIRBuilder
.getMF());
978 if (Reg
.isValid() && !isSpecialOpaqueType(Ty
))
979 return getSPIRVTypeForVReg(Reg
);
980 TypesInProcessing
.clear();
981 SPIRVType
*STy
= restOfCreateSPIRVType(Ty
, MIRBuilder
, AccessQual
, EmitIR
);
982 // Create normal pointer types for the corresponding OpTypeForwardPointers.
983 for (auto &CU
: ForwardPointerTypes
) {
984 const Type
*Ty2
= CU
.first
;
985 SPIRVType
*STy2
= CU
.second
;
986 if ((Reg
= DT
.find(Ty2
, &MIRBuilder
.getMF())).isValid())
987 STy2
= getSPIRVTypeForVReg(Reg
);
989 STy2
= restOfCreateSPIRVType(Ty2
, MIRBuilder
, AccessQual
, EmitIR
);
993 ForwardPointerTypes
.clear();
997 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg
,
998 unsigned TypeOpcode
) const {
999 SPIRVType
*Type
= getSPIRVTypeForVReg(VReg
);
1000 assert(Type
&& "isScalarOfType VReg has no type assigned");
1001 return Type
->getOpcode() == TypeOpcode
;
1004 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg
,
1005 unsigned TypeOpcode
) const {
1006 SPIRVType
*Type
= getSPIRVTypeForVReg(VReg
);
1007 assert(Type
&& "isScalarOrVectorOfType VReg has no type assigned");
1008 if (Type
->getOpcode() == TypeOpcode
)
1010 if (Type
->getOpcode() == SPIRV::OpTypeVector
) {
1011 Register ScalarTypeVReg
= Type
->getOperand(1).getReg();
1012 SPIRVType
*ScalarType
= getSPIRVTypeForVReg(ScalarTypeVReg
);
1013 return ScalarType
->getOpcode() == TypeOpcode
;
1019 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg
) const {
1020 return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg
));
1024 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType
*Type
) const {
1027 return Type
->getOpcode() == SPIRV::OpTypeVector
1028 ? static_cast<unsigned>(Type
->getOperand(2).getImm())
1033 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType
*Type
) const {
1034 assert(Type
&& "Invalid Type pointer");
1035 if (Type
->getOpcode() == SPIRV::OpTypeVector
) {
1036 auto EleTypeReg
= Type
->getOperand(1).getReg();
1037 Type
= getSPIRVTypeForVReg(EleTypeReg
);
1039 if (Type
->getOpcode() == SPIRV::OpTypeInt
||
1040 Type
->getOpcode() == SPIRV::OpTypeFloat
)
1041 return Type
->getOperand(1).getImm();
1042 if (Type
->getOpcode() == SPIRV::OpTypeBool
)
1044 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1047 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1048 const SPIRVType
*Type
) const {
1049 assert(Type
&& "Invalid Type pointer");
1050 unsigned NumElements
= 1;
1051 if (Type
->getOpcode() == SPIRV::OpTypeVector
) {
1052 NumElements
= static_cast<unsigned>(Type
->getOperand(2).getImm());
1053 Type
= getSPIRVTypeForVReg(Type
->getOperand(1).getReg());
1055 return Type
->getOpcode() == SPIRV::OpTypeInt
||
1056 Type
->getOpcode() == SPIRV::OpTypeFloat
1057 ? NumElements
* Type
->getOperand(1).getImm()
1061 const SPIRVType
*SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
1062 const SPIRVType
*Type
) const {
1063 if (Type
&& Type
->getOpcode() == SPIRV::OpTypeVector
)
1064 Type
= getSPIRVTypeForVReg(Type
->getOperand(1).getReg());
1065 return Type
&& Type
->getOpcode() == SPIRV::OpTypeInt
? Type
: nullptr;
1068 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType
*Type
) const {
1069 const SPIRVType
*IntType
= retrieveScalarOrVectorIntType(Type
);
1070 return IntType
&& IntType
->getOperand(2).getImm() != 0;
1073 SPIRVType
*SPIRVGlobalRegistry::getPointeeType(SPIRVType
*PtrType
) {
1074 return PtrType
&& PtrType
->getOpcode() == SPIRV::OpTypePointer
1075 ? getSPIRVTypeForVReg(PtrType
->getOperand(2).getReg())
1079 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg
) {
1080 SPIRVType
*ElemType
= getPointeeType(getSPIRVTypeForVReg(PtrReg
));
1081 return ElemType
? ElemType
->getOpcode() : 0;
1084 bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType
*Type1
,
1085 const SPIRVType
*Type2
) const {
1086 if (!Type1
|| !Type2
)
1088 auto Op1
= Type1
->getOpcode(), Op2
= Type2
->getOpcode();
1089 // Ignore difference between <1.5 and >=1.5 protocol versions:
1090 // it's valid if either Result Type or Operand is a pointer, and the other
1091 // is a pointer, an integer scalar, or an integer vector.
1092 if (Op1
== SPIRV::OpTypePointer
&&
1093 (Op2
== SPIRV::OpTypePointer
|| retrieveScalarOrVectorIntType(Type2
)))
1095 if (Op2
== SPIRV::OpTypePointer
&&
1096 (Op1
== SPIRV::OpTypePointer
|| retrieveScalarOrVectorIntType(Type1
)))
1098 unsigned Bits1
= getNumScalarOrVectorTotalBitWidth(Type1
),
1099 Bits2
= getNumScalarOrVectorTotalBitWidth(Type2
);
1100 return Bits1
> 0 && Bits1
== Bits2
;
1103 SPIRV::StorageClass::StorageClass
1104 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg
) const {
1105 SPIRVType
*Type
= getSPIRVTypeForVReg(VReg
);
1106 assert(Type
&& Type
->getOpcode() == SPIRV::OpTypePointer
&&
1107 Type
->getOperand(1).isImm() && "Pointer type is expected");
1108 return static_cast<SPIRV::StorageClass::StorageClass
>(
1109 Type
->getOperand(1).getImm());
1112 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1113 MachineIRBuilder
&MIRBuilder
, SPIRVType
*SampledType
, SPIRV::Dim::Dim Dim
,
1114 uint32_t Depth
, uint32_t Arrayed
, uint32_t Multisampled
, uint32_t Sampled
,
1115 SPIRV::ImageFormat::ImageFormat ImageFormat
,
1116 SPIRV::AccessQualifier::AccessQualifier AccessQual
) {
1117 auto TD
= SPIRV::make_descr_image(SPIRVToLLVMType
.lookup(SampledType
), Dim
,
1118 Depth
, Arrayed
, Multisampled
, Sampled
,
1119 ImageFormat
, AccessQual
);
1120 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1122 Register ResVReg
= createTypeVReg(MIRBuilder
);
1123 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1124 return MIRBuilder
.buildInstr(SPIRV::OpTypeImage
)
1126 .addUse(getSPIRVTypeID(SampledType
))
1128 .addImm(Depth
) // Depth (whether or not it is a Depth image).
1129 .addImm(Arrayed
) // Arrayed.
1130 .addImm(Multisampled
) // Multisampled (0 = only single-sample).
1131 .addImm(Sampled
) // Sampled (0 = usage known at runtime).
1132 .addImm(ImageFormat
)
1133 .addImm(AccessQual
);
1137 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder
&MIRBuilder
) {
1138 auto TD
= SPIRV::make_descr_sampler();
1139 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1141 Register ResVReg
= createTypeVReg(MIRBuilder
);
1142 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1143 return MIRBuilder
.buildInstr(SPIRV::OpTypeSampler
).addDef(ResVReg
);
1146 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1147 MachineIRBuilder
&MIRBuilder
,
1148 SPIRV::AccessQualifier::AccessQualifier AccessQual
) {
1149 auto TD
= SPIRV::make_descr_pipe(AccessQual
);
1150 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1152 Register ResVReg
= createTypeVReg(MIRBuilder
);
1153 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1154 return MIRBuilder
.buildInstr(SPIRV::OpTypePipe
)
1156 .addImm(AccessQual
);
1159 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1160 MachineIRBuilder
&MIRBuilder
) {
1161 auto TD
= SPIRV::make_descr_event();
1162 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1164 Register ResVReg
= createTypeVReg(MIRBuilder
);
1165 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1166 return MIRBuilder
.buildInstr(SPIRV::OpTypeDeviceEvent
).addDef(ResVReg
);
1169 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1170 SPIRVType
*ImageType
, MachineIRBuilder
&MIRBuilder
) {
1171 auto TD
= SPIRV::make_descr_sampled_image(
1172 SPIRVToLLVMType
.lookup(MIRBuilder
.getMF().getRegInfo().getVRegDef(
1173 ImageType
->getOperand(1).getReg())),
1175 if (auto *Res
= checkSpecialInstr(TD
, MIRBuilder
))
1177 Register ResVReg
= createTypeVReg(MIRBuilder
);
1178 DT
.add(TD
, &MIRBuilder
.getMF(), ResVReg
);
1179 return MIRBuilder
.buildInstr(SPIRV::OpTypeSampledImage
)
1181 .addUse(getSPIRVTypeID(ImageType
));
1184 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
1185 MachineIRBuilder
&MIRBuilder
, const TargetExtType
*ExtensionType
,
1186 const SPIRVType
*ElemType
, uint32_t Scope
, uint32_t Rows
, uint32_t Columns
,
1188 Register ResVReg
= DT
.find(ExtensionType
, &MIRBuilder
.getMF());
1189 if (ResVReg
.isValid())
1190 return MIRBuilder
.getMF().getRegInfo().getUniqueVRegDef(ResVReg
);
1191 ResVReg
= createTypeVReg(MIRBuilder
);
1192 SPIRVType
*SpirvTy
=
1193 MIRBuilder
.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR
)
1195 .addUse(getSPIRVTypeID(ElemType
))
1196 .addUse(buildConstantInt(Scope
, MIRBuilder
, nullptr, true))
1197 .addUse(buildConstantInt(Rows
, MIRBuilder
, nullptr, true))
1198 .addUse(buildConstantInt(Columns
, MIRBuilder
, nullptr, true))
1199 .addUse(buildConstantInt(Use
, MIRBuilder
, nullptr, true));
1200 DT
.add(ExtensionType
, &MIRBuilder
.getMF(), ResVReg
);
1204 SPIRVType
*SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1205 const Type
*Ty
, MachineIRBuilder
&MIRBuilder
, unsigned Opcode
) {
1206 Register ResVReg
= DT
.find(Ty
, &MIRBuilder
.getMF());
1207 if (ResVReg
.isValid())
1208 return MIRBuilder
.getMF().getRegInfo().getUniqueVRegDef(ResVReg
);
1209 ResVReg
= createTypeVReg(MIRBuilder
);
1210 SPIRVType
*SpirvTy
= MIRBuilder
.buildInstr(Opcode
).addDef(ResVReg
);
1211 DT
.add(Ty
, &MIRBuilder
.getMF(), ResVReg
);
1215 const MachineInstr
*
1216 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor
&TD
,
1217 MachineIRBuilder
&MIRBuilder
) {
1218 Register Reg
= DT
.find(TD
, &MIRBuilder
.getMF());
1220 return MIRBuilder
.getMF().getRegInfo().getUniqueVRegDef(Reg
);
1224 // Returns nullptr if unable to recognize SPIRV type name
1225 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1226 StringRef TypeStr
, MachineIRBuilder
&MIRBuilder
,
1227 SPIRV::StorageClass::StorageClass SC
,
1228 SPIRV::AccessQualifier::AccessQualifier AQ
) {
1229 unsigned VecElts
= 0;
1230 auto &Ctx
= MIRBuilder
.getMF().getFunction().getContext();
1232 // Parse strings representing either a SPIR-V or OpenCL builtin type.
1233 if (hasBuiltinTypePrefix(TypeStr
))
1234 return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
1235 TypeStr
.str(), MIRBuilder
.getContext()),
1238 // Parse type name in either "typeN" or "type vector[N]" format, where
1239 // N is the number of elements of the vector.
1242 Ty
= parseBasicTypeName(TypeStr
, Ctx
);
1244 // Unable to recognize SPIRV type name
1247 auto SpirvTy
= getOrCreateSPIRVType(Ty
, MIRBuilder
, AQ
);
1249 // Handle "type*" or "type* vector[N]".
1250 if (TypeStr
.starts_with("*")) {
1251 SpirvTy
= getOrCreateSPIRVPointerType(SpirvTy
, MIRBuilder
, SC
);
1252 TypeStr
= TypeStr
.substr(strlen("*"));
1255 // Handle "typeN*" or "type vector[N]*".
1256 bool IsPtrToVec
= TypeStr
.consume_back("*");
1258 if (TypeStr
.consume_front(" vector[")) {
1259 TypeStr
= TypeStr
.substr(0, TypeStr
.find(']'));
1261 TypeStr
.getAsInteger(10, VecElts
);
1263 SpirvTy
= getOrCreateSPIRVVectorType(SpirvTy
, VecElts
, MIRBuilder
);
1266 SpirvTy
= getOrCreateSPIRVPointerType(SpirvTy
, MIRBuilder
, SC
);
1272 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth
,
1273 MachineIRBuilder
&MIRBuilder
) {
1274 return getOrCreateSPIRVType(
1275 IntegerType::get(MIRBuilder
.getMF().getFunction().getContext(), BitWidth
),
1279 SPIRVType
*SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type
*LLVMTy
,
1280 SPIRVType
*SpirvType
) {
1281 assert(CurMF
== SpirvType
->getMF());
1282 VRegToTypeMap
[CurMF
][getSPIRVTypeID(SpirvType
)] = SpirvType
;
1283 SPIRVToLLVMType
[SpirvType
] = unifyPtrType(LLVMTy
);
1287 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth
,
1289 const SPIRVInstrInfo
&TII
,
1290 unsigned SPIRVOPcode
,
1292 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1294 return getSPIRVTypeForVReg(Reg
);
1295 MachineBasicBlock
&BB
= *I
.getParent();
1296 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRVOPcode
))
1297 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1300 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1301 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1304 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1305 unsigned BitWidth
, MachineInstr
&I
, const SPIRVInstrInfo
&TII
) {
1306 // Maybe adjust bit width to keep DuplicateTracker consistent. Without
1307 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
1308 // example, the same "OpTypeInt 8" type for a series of LLVM integer types
1309 // with number of bits less than 8, causing duplicate type definitions.
1310 BitWidth
= adjustOpTypeIntWidth(BitWidth
);
1311 Type
*LLVMTy
= IntegerType::get(CurMF
->getFunction().getContext(), BitWidth
);
1312 return getOrCreateSPIRVType(BitWidth
, I
, TII
, SPIRV::OpTypeInt
, LLVMTy
);
1315 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1316 unsigned BitWidth
, MachineInstr
&I
, const SPIRVInstrInfo
&TII
) {
1317 LLVMContext
&Ctx
= CurMF
->getFunction().getContext();
1321 LLVMTy
= Type::getHalfTy(Ctx
);
1324 LLVMTy
= Type::getFloatTy(Ctx
);
1327 LLVMTy
= Type::getDoubleTy(Ctx
);
1330 llvm_unreachable("Bit width is of unexpected size.");
1332 return getOrCreateSPIRVType(BitWidth
, I
, TII
, SPIRV::OpTypeFloat
, LLVMTy
);
1336 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder
&MIRBuilder
) {
1337 return getOrCreateSPIRVType(
1338 IntegerType::get(MIRBuilder
.getMF().getFunction().getContext(), 1),
1343 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr
&I
,
1344 const SPIRVInstrInfo
&TII
) {
1345 Type
*LLVMTy
= IntegerType::get(CurMF
->getFunction().getContext(), 1);
1346 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1348 return getSPIRVTypeForVReg(Reg
);
1349 MachineBasicBlock
&BB
= *I
.getParent();
1350 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpTypeBool
))
1351 .addDef(createTypeVReg(CurMF
->getRegInfo()));
1352 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1353 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1356 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1357 SPIRVType
*BaseType
, unsigned NumElements
, MachineIRBuilder
&MIRBuilder
) {
1358 return getOrCreateSPIRVType(
1359 FixedVectorType::get(const_cast<Type
*>(getTypeForSPIRVType(BaseType
)),
1364 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1365 SPIRVType
*BaseType
, unsigned NumElements
, MachineInstr
&I
,
1366 const SPIRVInstrInfo
&TII
) {
1367 Type
*LLVMTy
= FixedVectorType::get(
1368 const_cast<Type
*>(getTypeForSPIRVType(BaseType
)), NumElements
);
1369 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1371 return getSPIRVTypeForVReg(Reg
);
1372 MachineBasicBlock
&BB
= *I
.getParent();
1373 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpTypeVector
))
1374 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1375 .addUse(getSPIRVTypeID(BaseType
))
1376 .addImm(NumElements
);
1377 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1378 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1381 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1382 SPIRVType
*BaseType
, unsigned NumElements
, MachineInstr
&I
,
1383 const SPIRVInstrInfo
&TII
) {
1384 Type
*LLVMTy
= ArrayType::get(
1385 const_cast<Type
*>(getTypeForSPIRVType(BaseType
)), NumElements
);
1386 Register Reg
= DT
.find(LLVMTy
, CurMF
);
1388 return getSPIRVTypeForVReg(Reg
);
1389 MachineBasicBlock
&BB
= *I
.getParent();
1390 SPIRVType
*SpirvType
= getOrCreateSPIRVIntegerType(32, I
, TII
);
1391 Register Len
= getOrCreateConstInt(NumElements
, I
, SpirvType
, TII
);
1392 auto MIB
= BuildMI(BB
, I
, I
.getDebugLoc(), TII
.get(SPIRV::OpTypeArray
))
1393 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1394 .addUse(getSPIRVTypeID(BaseType
))
1396 DT
.add(LLVMTy
, CurMF
, getSPIRVTypeID(MIB
));
1397 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1400 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1401 SPIRVType
*BaseType
, MachineIRBuilder
&MIRBuilder
,
1402 SPIRV::StorageClass::StorageClass SC
) {
1403 const Type
*PointerElementType
= getTypeForSPIRVType(BaseType
);
1404 unsigned AddressSpace
= storageClassToAddressSpace(SC
);
1405 Type
*LLVMTy
= TypedPointerType::get(const_cast<Type
*>(PointerElementType
),
1407 // check if this type is already available
1408 Register Reg
= DT
.find(PointerElementType
, AddressSpace
, CurMF
);
1410 return getSPIRVTypeForVReg(Reg
);
1411 // create a new type
1412 auto MIB
= BuildMI(MIRBuilder
.getMBB(), MIRBuilder
.getInsertPt(),
1413 MIRBuilder
.getDebugLoc(),
1414 MIRBuilder
.getTII().get(SPIRV::OpTypePointer
))
1415 .addDef(createTypeVReg(CurMF
->getRegInfo()))
1416 .addImm(static_cast<uint32_t>(SC
))
1417 .addUse(getSPIRVTypeID(BaseType
));
1418 DT
.add(PointerElementType
, AddressSpace
, CurMF
, getSPIRVTypeID(MIB
));
1419 return finishCreatingSPIRVType(LLVMTy
, MIB
);
1422 SPIRVType
*SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1423 SPIRVType
*BaseType
, MachineInstr
&I
, const SPIRVInstrInfo
&,
1424 SPIRV::StorageClass::StorageClass SC
) {
1425 MachineIRBuilder
MIRBuilder(I
);
1426 return getOrCreateSPIRVPointerType(BaseType
, MIRBuilder
, SC
);
1429 Register
SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr
&I
,
1431 const SPIRVInstrInfo
&TII
) {
1433 const Type
*LLVMTy
= getTypeForSPIRVType(SpvType
);
1435 // Find a constant in DT or build a new one.
1436 UndefValue
*UV
= UndefValue::get(const_cast<Type
*>(LLVMTy
));
1437 Register Res
= DT
.find(UV
, CurMF
);
1440 LLT LLTy
= LLT::scalar(32);
1441 Res
= CurMF
->getRegInfo().createGenericVirtualRegister(LLTy
);
1442 CurMF
->getRegInfo().setRegClass(Res
, &SPIRV::IDRegClass
);
1443 assignSPIRVTypeToVReg(SpvType
, Res
, *CurMF
);
1444 DT
.add(UV
, CurMF
, Res
);
1446 MachineInstrBuilder MIB
;
1447 MIB
= BuildMI(*I
.getParent(), I
, I
.getDebugLoc(), TII
.get(SPIRV::OpUndef
))
1449 .addUse(getSPIRVTypeID(SpvType
));
1450 const auto &ST
= CurMF
->getSubtarget();
1451 constrainSelectedInstRegOperands(*MIB
, *ST
.getInstrInfo(),
1452 *ST
.getRegisterInfo(), *ST
.getRegBankInfo());