1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the SPIRVTargetLowering class.
11 //===----------------------------------------------------------------------===//
13 #include "SPIRVISelLowering.h"
15 #include "llvm/IR/IntrinsicsSPIRV.h"
17 #define DEBUG_TYPE "spirv-lower"
21 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
22 LLVMContext
&Context
, CallingConv::ID CC
, EVT VT
) const {
23 // This code avoids CallLowering fail inside getVectorTypeBreakdown
24 // on v3i1 arguments. Maybe we need to return 1 for all types.
25 // TODO: remove it once this case is supported by the default implementation.
26 if (VT
.isVector() && VT
.getVectorNumElements() == 3 &&
27 (VT
.getVectorElementType() == MVT::i1
||
28 VT
.getVectorElementType() == MVT::i8
))
30 if (!VT
.isVector() && VT
.isInteger() && VT
.getSizeInBits() <= 64)
32 return getNumRegisters(Context
, VT
);
35 MVT
SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext
&Context
,
38 // This code avoids CallLowering fail inside getVectorTypeBreakdown
39 // on v3i1 arguments. Maybe we need to return i32 for all types.
40 // TODO: remove it once this case is supported by the default implementation.
41 if (VT
.isVector() && VT
.getVectorNumElements() == 3) {
42 if (VT
.getVectorElementType() == MVT::i1
)
44 else if (VT
.getVectorElementType() == MVT::i8
)
47 return getRegisterType(Context
, VT
);
50 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo
&Info
,
53 unsigned Intrinsic
) const {
54 unsigned AlignIdx
= 3;
56 case Intrinsic::spv_load
:
59 case Intrinsic::spv_store
: {
60 if (I
.getNumOperands() >= AlignIdx
+ 1) {
61 auto *AlignOp
= cast
<ConstantInt
>(I
.getOperand(AlignIdx
));
62 Info
.align
= Align(AlignOp
->getZExtValue());
64 Info
.flags
= static_cast<MachineMemOperand::Flags
>(
65 cast
<ConstantInt
>(I
.getOperand(AlignIdx
- 1))->getZExtValue());
66 Info
.memVT
= MVT::i64
;
67 // TODO: take into account opaque pointers (don't use getElementType).
68 // MVT::getVT(PtrTy->getElementType());