[HLSL][SPIRV] Add any intrinsic lowering (#88325)
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the SPIRVTargetLowering class.
//
//===----------------------------------------------------------------------===//

#include "SPIRVISelLowering.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVRegisterBankInfo.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"

#define DEBUG_TYPE "spirv-lower"

using namespace llvm;

unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
  // This code avoids a CallLowering failure inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return 1 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3 &&
      (VT.getVectorElementType() == MVT::i1 ||
       VT.getVectorElementType() == MVT::i8))
    return 1;
  if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
    return 1;
  return getNumRegisters(Context, VT);
}

MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
                                                       CallingConv::ID CC,
                                                       EVT VT) const {
  // This code avoids a CallLowering failure inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return i32 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3) {
    if (VT.getVectorElementType() == MVT::i1)
      return MVT::v4i1;
    else if (VT.getVectorElementType() == MVT::i8)
      return MVT::v4i8;
  }
  return getRegisterType(Context, VT);
}

bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
                                             const CallInst &I,
                                             MachineFunction &MF,
                                             unsigned Intrinsic) const {
  unsigned AlignIdx = 3;
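  // Operand layout consumed below (a sketch inferred from how AlignIdx is
  // used; the exact pointer/value positions are an assumption, not checked
  // here):
  //   llvm.spv.load(%ptr, flags, align)         -> alignment at index 2
  //   llvm.spv.store(%val, %ptr, flags, align)  -> alignment at index 3
  // with the MachineMemOperand flags immediately preceding the alignment.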
  switch (Intrinsic) {
  case Intrinsic::spv_load:
    AlignIdx = 2;
    [[fallthrough]];
  case Intrinsic::spv_store: {
    if (I.getNumOperands() >= AlignIdx + 1) {
      auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
      Info.align = Align(AlignOp->getZExtValue());
    }
    Info.flags = static_cast<MachineMemOperand::Flags>(
        cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
    Info.memVT = MVT::i64;
    // TODO: take into account opaque pointers (don't use getElementType).
    // MVT::getVT(PtrTy->getElementType());
    return true;
    break;
  }
  default:
    break;
  }
  return false;
}

// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between results and operand types.
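// For illustration (a sketch; the %-names and types are hypothetical):
//   %val = OpLoad %A %ptr           ; %ptr points to %B, but %B != %A
// is rewritten into
//   %cast = OpBitcast %ptr_to_A %ptr
//   %val  = OpLoad %A %cast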
static void validatePtrTypes(const SPIRVSubtarget &STI,
                             MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                             MachineInstr &I, unsigned OpIdx,
                             SPIRVType *ResType, const Type *ResTy = nullptr) {
  // Get the operand's type.
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
  Register OpTypeReg =
      TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
          ? TypeInst->getOperand(1).getReg()
          : OpReg;
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  // Get the operand's pointee type.
  Register ElemTypeReg = OpType->getOperand(2).getReg();
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
  if (!ElemType)
    return;
  // Check if we need a bitcast to make the instruction valid.
  bool IsSameMF = MF == ResType->getParent()->getParent();
  bool IsEqualTypes = IsSameMF ? ElemType == ResType
                               : GR.getTypeForSPIRVType(ElemType) == ResTy;
  if (IsEqualTypes)
    return;
  // There is a type mismatch between the result and operand types,
  // so insert a bitcast before the instruction to keep SPIR-V code valid.
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *NewBaseType =
      IsSameMF ? ResType
               : GR.getOrCreateSPIRVType(
                     ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
  if (!GR.isBitcastCompatible(NewPtrType, OpType))
    report_fatal_error(
        "insert validation bitcast: incompatible result and operand types");
  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
                 .addDef(NewReg)
                 .addUse(GR.getSPIRVTypeID(NewPtrType))
                 .addUse(OpReg)
                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
                                   *STI.getRegBankInfo());
  if (!Res)
    report_fatal_error("insert validation bitcast: cannot constrain all uses");
  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
  I.getOperand(OpIdx).setReg(NewReg);
}

// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between the actual and expected types
// of an argument:
// %formal = OpFunctionParameter %formal_type
// ...
// %res = OpFunctionCall %ty %fun %actual ...
// implies that %actual is of %formal_type, and in the case of opaque pointers
// we may need to insert a bitcast to ensure this.
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
                               MachineRegisterInfo *DefMRI,
                               MachineRegisterInfo *CallMRI,
                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
                               MachineInstr *FunDef) {
  if (FunDef->getOpcode() != SPIRV::OpFunction)
    return;
  unsigned OpIdx = 3;
  for (FunDef = FunDef->getNextNode();
       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
       OpIdx < FunCall.getNumOperands();
       FunDef = FunDef->getNextNode(), OpIdx++) {
    SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
    SPIRVType *DefElemType =
        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
                                     DefPtrType->getParent()->getParent())
            : nullptr;
    if (DefElemType) {
      const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
      // validatePtrTypes() works in the context of the call site.
      // When we process historical records about forward calls,
      // we need to switch context to the (forward) call site and
      // then restore it back to the current machine function.
      MachineFunction *CurMF =
          GR.setCurrentFunc(*FunCall.getParent()->getParent());
      validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
                       DefElemTy);
      GR.setCurrentFunc(*CurMF);
    }
  }
}

// Ensure there is no mismatch between actual and expected arg types: calls
// with a processed definition. Return the Function pointer if it is a forward
// call (ahead of the definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
                                MachineRegisterInfo *CallMRI,
                                SPIRVGlobalRegistry &GR,
                                MachineInstr &FunCall) {
  const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
  const Function *F = dyn_cast<Function>(GV);
  MachineInstr *FunDef =
      const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
  if (!FunDef)
    return F;
  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
  return nullptr;
}

// Ensure there is no mismatch between actual and expected arg types: calls
// ahead of a processed definition.
void validateForwardCalls(const SPIRVSubtarget &STI,
                          MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
                          MachineInstr &FunDef) {
  const Function *F = GR.getFunctionByDefinition(&FunDef);
  if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
    for (MachineInstr *FunCall : *FwdCalls) {
      MachineRegisterInfo *CallMRI =
          &FunCall->getParent()->getParent()->getRegInfo();
      validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
    }
}

// Validation of an access chain.
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                         SPIRVGlobalRegistry &GR, MachineInstr &I) {
  SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
  if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
    SPIRVType *BaseElemType =
        GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
    validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
  }
}

// TODO: the logic of inserting additional bitcasts is to be moved
// to pre-IRTranslation passes eventually.
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp).
  // We'd like to avoid the needless second processing pass.
  if (ProcessedMF.find(&MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // For the above listed instructions,
        //   OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>.
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to
        // <ResType>.
        validatePtrTypes(STI, MRI, GR, MI, 2,
                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type.
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type.
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
        validateAccessChain(STI, MRI, GR, MI);
        break;
      case SPIRV::OpInBoundsPtrAccessChain:
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, MI);
        break;
      case SPIRV::OpFunctionCall:
        // Ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition.
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, MRI, GR, MI))
            GR.addForwardCall(F, &MI);
        break;
      case SPIRV::OpFunction:
        // Ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition.
        validateForwardCalls(STI, MRI, GR, MI);
        break;
      // Ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type.
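      // For illustration (a sketch; %-names hypothetical): with %bool
      // operands,
      //   %r = OpBitwiseOr %bool %a %b   ; bitwise ops require integer types
      // is rewritten in place into the equivalent logical form
      //   %r = OpLogicalOr %bool %a %b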
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;
      }
    }
  }
  ProcessedMF.insert(&MF);
  TargetLowering::finalizeLowering(MF);
}