//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the SPIRVTargetLowering class.
//
//===----------------------------------------------------------------------===//
#include "SPIRVISelLowering.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVRegisterBankInfo.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"

#define DEBUG_TYPE "spirv-lower"

using namespace llvm;
28 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
29 LLVMContext
&Context
, CallingConv::ID CC
, EVT VT
) const {
30 // This code avoids CallLowering fail inside getVectorTypeBreakdown
31 // on v3i1 arguments. Maybe we need to return 1 for all types.
32 // TODO: remove it once this case is supported by the default implementation.
33 if (VT
.isVector() && VT
.getVectorNumElements() == 3 &&
34 (VT
.getVectorElementType() == MVT::i1
||
35 VT
.getVectorElementType() == MVT::i8
))
37 if (!VT
.isVector() && VT
.isInteger() && VT
.getSizeInBits() <= 64)
39 return getNumRegisters(Context
, VT
);
42 MVT
SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext
&Context
,
45 // This code avoids CallLowering fail inside getVectorTypeBreakdown
46 // on v3i1 arguments. Maybe we need to return i32 for all types.
47 // TODO: remove it once this case is supported by the default implementation.
48 if (VT
.isVector() && VT
.getVectorNumElements() == 3) {
49 if (VT
.getVectorElementType() == MVT::i1
)
51 else if (VT
.getVectorElementType() == MVT::i8
)
54 return getRegisterType(Context
, VT
);
57 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo
&Info
,
60 unsigned Intrinsic
) const {
61 unsigned AlignIdx
= 3;
63 case Intrinsic::spv_load
:
66 case Intrinsic::spv_store
: {
67 if (I
.getNumOperands() >= AlignIdx
+ 1) {
68 auto *AlignOp
= cast
<ConstantInt
>(I
.getOperand(AlignIdx
));
69 Info
.align
= Align(AlignOp
->getZExtValue());
71 Info
.flags
= static_cast<MachineMemOperand::Flags
>(
72 cast
<ConstantInt
>(I
.getOperand(AlignIdx
- 1))->getZExtValue());
73 Info
.memVT
= MVT::i64
;
74 // TODO: take into account opaque pointers (don't use getElementType).
75 // MVT::getVT(PtrTy->getElementType());
85 // Insert a bitcast before the instruction to keep SPIR-V code valid
86 // when there is a type mismatch between results and operand types.
87 static void validatePtrTypes(const SPIRVSubtarget
&STI
,
88 MachineRegisterInfo
*MRI
, SPIRVGlobalRegistry
&GR
,
89 MachineInstr
&I
, unsigned OpIdx
,
90 SPIRVType
*ResType
, const Type
*ResTy
= nullptr) {
92 MachineFunction
*MF
= I
.getParent()->getParent();
93 Register OpReg
= I
.getOperand(OpIdx
).getReg();
94 SPIRVType
*TypeInst
= MRI
->getVRegDef(OpReg
);
96 TypeInst
&& TypeInst
->getOpcode() == SPIRV::OpFunctionParameter
97 ? TypeInst
->getOperand(1).getReg()
99 SPIRVType
*OpType
= GR
.getSPIRVTypeForVReg(OpTypeReg
, MF
);
100 if (!ResType
|| !OpType
|| OpType
->getOpcode() != SPIRV::OpTypePointer
)
102 // Get operand's pointee type
103 Register ElemTypeReg
= OpType
->getOperand(2).getReg();
104 SPIRVType
*ElemType
= GR
.getSPIRVTypeForVReg(ElemTypeReg
, MF
);
107 // Check if we need a bitcast to make a statement valid
108 bool IsSameMF
= MF
== ResType
->getParent()->getParent();
109 bool IsEqualTypes
= IsSameMF
? ElemType
== ResType
110 : GR
.getTypeForSPIRVType(ElemType
) == ResTy
;
113 // There is a type mismatch between results and operand types
114 // and we insert a bitcast before the instruction to keep SPIR-V code valid
115 SPIRV::StorageClass::StorageClass SC
=
116 static_cast<SPIRV::StorageClass::StorageClass
>(
117 OpType
->getOperand(1).getImm());
118 MachineIRBuilder
MIB(I
);
119 SPIRVType
*NewBaseType
=
121 : GR
.getOrCreateSPIRVType(
122 ResTy
, MIB
, SPIRV::AccessQualifier::ReadWrite
, false);
123 SPIRVType
*NewPtrType
= GR
.getOrCreateSPIRVPointerType(NewBaseType
, MIB
, SC
);
124 if (!GR
.isBitcastCompatible(NewPtrType
, OpType
))
126 "insert validation bitcast: incompatible result and operand types");
127 Register NewReg
= MRI
->createGenericVirtualRegister(LLT::scalar(32));
128 bool Res
= MIB
.buildInstr(SPIRV::OpBitcast
)
130 .addUse(GR
.getSPIRVTypeID(NewPtrType
))
132 .constrainAllUses(*STI
.getInstrInfo(), *STI
.getRegisterInfo(),
133 *STI
.getRegBankInfo());
135 report_fatal_error("insert validation bitcast: cannot constrain all uses");
136 MRI
->setRegClass(NewReg
, &SPIRV::IDRegClass
);
137 GR
.assignSPIRVTypeToVReg(NewPtrType
, NewReg
, MIB
.getMF());
138 I
.getOperand(OpIdx
).setReg(NewReg
);
141 // Insert a bitcast before the function call instruction to keep SPIR-V code
142 // valid when there is a type mismatch between actual and expected types of an
144 // %formal = OpFunctionParameter %formal_type
146 // %res = OpFunctionCall %ty %fun %actual ...
147 // implies that %actual is of %formal_type, and in case of opaque pointers.
148 // We may need to insert a bitcast to ensure this.
149 void validateFunCallMachineDef(const SPIRVSubtarget
&STI
,
150 MachineRegisterInfo
*DefMRI
,
151 MachineRegisterInfo
*CallMRI
,
152 SPIRVGlobalRegistry
&GR
, MachineInstr
&FunCall
,
153 MachineInstr
*FunDef
) {
154 if (FunDef
->getOpcode() != SPIRV::OpFunction
)
157 for (FunDef
= FunDef
->getNextNode();
158 FunDef
&& FunDef
->getOpcode() == SPIRV::OpFunctionParameter
&&
159 OpIdx
< FunCall
.getNumOperands();
160 FunDef
= FunDef
->getNextNode(), OpIdx
++) {
161 SPIRVType
*DefPtrType
= DefMRI
->getVRegDef(FunDef
->getOperand(1).getReg());
162 SPIRVType
*DefElemType
=
163 DefPtrType
&& DefPtrType
->getOpcode() == SPIRV::OpTypePointer
164 ? GR
.getSPIRVTypeForVReg(DefPtrType
->getOperand(2).getReg(),
165 DefPtrType
->getParent()->getParent())
168 const Type
*DefElemTy
= GR
.getTypeForSPIRVType(DefElemType
);
169 // validatePtrTypes() works in the context if the call site
170 // When we process historical records about forward calls
171 // we need to switch context to the (forward) call site and
172 // then restore it back to the current machine function.
173 MachineFunction
*CurMF
=
174 GR
.setCurrentFunc(*FunCall
.getParent()->getParent());
175 validatePtrTypes(STI
, CallMRI
, GR
, FunCall
, OpIdx
, DefElemType
,
177 GR
.setCurrentFunc(*CurMF
);
182 // Ensure there is no mismatch between actual and expected arg types: calls
183 // with a processed definition. Return Function pointer if it's a forward
184 // call (ahead of definition), and nullptr otherwise.
185 const Function
*validateFunCall(const SPIRVSubtarget
&STI
,
186 MachineRegisterInfo
*CallMRI
,
187 SPIRVGlobalRegistry
&GR
,
188 MachineInstr
&FunCall
) {
189 const GlobalValue
*GV
= FunCall
.getOperand(2).getGlobal();
190 const Function
*F
= dyn_cast
<Function
>(GV
);
191 MachineInstr
*FunDef
=
192 const_cast<MachineInstr
*>(GR
.getFunctionDefinition(F
));
195 MachineRegisterInfo
*DefMRI
= &FunDef
->getParent()->getParent()->getRegInfo();
196 validateFunCallMachineDef(STI
, DefMRI
, CallMRI
, GR
, FunCall
, FunDef
);
200 // Ensure there is no mismatch between actual and expected arg types: calls
201 // ahead of a processed definition.
202 void validateForwardCalls(const SPIRVSubtarget
&STI
,
203 MachineRegisterInfo
*DefMRI
, SPIRVGlobalRegistry
&GR
,
204 MachineInstr
&FunDef
) {
205 const Function
*F
= GR
.getFunctionByDefinition(&FunDef
);
206 if (SmallPtrSet
<MachineInstr
*, 8> *FwdCalls
= GR
.getForwardCalls(F
))
207 for (MachineInstr
*FunCall
: *FwdCalls
) {
208 MachineRegisterInfo
*CallMRI
=
209 &FunCall
->getParent()->getParent()->getRegInfo();
210 validateFunCallMachineDef(STI
, DefMRI
, CallMRI
, GR
, *FunCall
, &FunDef
);
214 // Validation of an access chain.
215 void validateAccessChain(const SPIRVSubtarget
&STI
, MachineRegisterInfo
*MRI
,
216 SPIRVGlobalRegistry
&GR
, MachineInstr
&I
) {
217 SPIRVType
*BaseTypeInst
= GR
.getSPIRVTypeForVReg(I
.getOperand(0).getReg());
218 if (BaseTypeInst
&& BaseTypeInst
->getOpcode() == SPIRV::OpTypePointer
) {
219 SPIRVType
*BaseElemType
=
220 GR
.getSPIRVTypeForVReg(BaseTypeInst
->getOperand(2).getReg());
221 validatePtrTypes(STI
, MRI
, GR
, I
, 2, BaseElemType
);
225 // TODO: the logic of inserting additional bitcast's is to be moved
226 // to pre-IRTranslation passes eventually
227 void SPIRVTargetLowering::finalizeLowering(MachineFunction
&MF
) const {
228 // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
229 // We'd like to avoid the needless second processing pass.
230 if (ProcessedMF
.find(&MF
) != ProcessedMF
.end())
233 MachineRegisterInfo
*MRI
= &MF
.getRegInfo();
234 SPIRVGlobalRegistry
&GR
= *STI
.getSPIRVGlobalRegistry();
235 GR
.setCurrentFunc(MF
);
236 for (MachineFunction::iterator I
= MF
.begin(), E
= MF
.end(); I
!= E
; ++I
) {
237 MachineBasicBlock
*MBB
= &*I
;
238 for (MachineBasicBlock::iterator MBBI
= MBB
->begin(), MBBE
= MBB
->end();
240 MachineInstr
&MI
= *MBBI
++;
241 switch (MI
.getOpcode()) {
242 case SPIRV::OpAtomicLoad
:
243 case SPIRV::OpAtomicExchange
:
244 case SPIRV::OpAtomicCompareExchange
:
245 case SPIRV::OpAtomicCompareExchangeWeak
:
246 case SPIRV::OpAtomicIIncrement
:
247 case SPIRV::OpAtomicIDecrement
:
248 case SPIRV::OpAtomicIAdd
:
249 case SPIRV::OpAtomicISub
:
250 case SPIRV::OpAtomicSMin
:
251 case SPIRV::OpAtomicUMin
:
252 case SPIRV::OpAtomicSMax
:
253 case SPIRV::OpAtomicUMax
:
254 case SPIRV::OpAtomicAnd
:
255 case SPIRV::OpAtomicOr
:
256 case SPIRV::OpAtomicXor
:
257 // for the above listed instructions
258 // OpAtomicXXX <ResType>, ptr %Op, ...
259 // implies that %Op is a pointer to <ResType>
261 // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
262 validatePtrTypes(STI
, MRI
, GR
, MI
, 2,
263 GR
.getSPIRVTypeForVReg(MI
.getOperand(0).getReg()));
265 case SPIRV::OpAtomicStore
:
266 // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
267 // implies that %Op points to the <Obj>'s type
268 validatePtrTypes(STI
, MRI
, GR
, MI
, 0,
269 GR
.getSPIRVTypeForVReg(MI
.getOperand(3).getReg()));
272 // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
273 validatePtrTypes(STI
, MRI
, GR
, MI
, 0,
274 GR
.getSPIRVTypeForVReg(MI
.getOperand(1).getReg()));
276 case SPIRV::OpPtrCastToGeneric
:
277 validateAccessChain(STI
, MRI
, GR
, MI
);
279 case SPIRV::OpInBoundsPtrAccessChain
:
280 if (MI
.getNumOperands() == 4)
281 validateAccessChain(STI
, MRI
, GR
, MI
);
284 case SPIRV::OpFunctionCall
:
285 // ensure there is no mismatch between actual and expected arg types:
286 // calls with a processed definition
287 if (MI
.getNumOperands() > 3)
288 if (const Function
*F
= validateFunCall(STI
, MRI
, GR
, MI
))
289 GR
.addForwardCall(F
, &MI
);
291 case SPIRV::OpFunction
:
292 // ensure there is no mismatch between actual and expected arg types:
293 // calls ahead of a processed definition
294 validateForwardCalls(STI
, MRI
, GR
, MI
);
297 // ensure that LLVM IR bitwise instructions result in logical SPIR-V
298 // instructions when applied to bool type
299 case SPIRV::OpBitwiseOrS
:
300 case SPIRV::OpBitwiseOrV
:
301 if (GR
.isScalarOrVectorOfType(MI
.getOperand(1).getReg(),
303 MI
.setDesc(STI
.getInstrInfo()->get(SPIRV::OpLogicalOr
));
305 case SPIRV::OpBitwiseAndS
:
306 case SPIRV::OpBitwiseAndV
:
307 if (GR
.isScalarOrVectorOfType(MI
.getOperand(1).getReg(),
309 MI
.setDesc(STI
.getInstrInfo()->get(SPIRV::OpLogicalAnd
));
311 case SPIRV::OpBitwiseXorS
:
312 case SPIRV::OpBitwiseXorV
:
313 if (GR
.isScalarOrVectorOfType(MI
.getOperand(1).getReg(),
315 MI
.setDesc(STI
.getInstrInfo()->get(SPIRV::OpLogicalNotEqual
));
320 ProcessedMF
.insert(&MF
);
321 TargetLowering::finalizeLowering(MF
);