1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This file implements the PTXTargetLowering class.
12 //===----------------------------------------------------------------------===//
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "PTXSubtarget.h"
19 #include "llvm/Support/ErrorHandling.h"
20 #include "llvm/CodeGen/CallingConvLower.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/SelectionDAG.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/Support/raw_ostream.h"
29 //===----------------------------------------------------------------------===//
30 // Calling Convention Implementation
31 //===----------------------------------------------------------------------===//
33 #include "PTXGenCallingConv.inc"
35 //===----------------------------------------------------------------------===//
36 // TargetLowering Implementation
37 //===----------------------------------------------------------------------===//
39 PTXTargetLowering::PTXTargetLowering(TargetMachine
&TM
)
40 : TargetLowering(TM
, new TargetLoweringObjectFileELF()) {
41 // Set up the register classes.
42 addRegisterClass(MVT::i1
, PTX::RegPredRegisterClass
);
43 addRegisterClass(MVT::i16
, PTX::RegI16RegisterClass
);
44 addRegisterClass(MVT::i32
, PTX::RegI32RegisterClass
);
45 addRegisterClass(MVT::i64
, PTX::RegI64RegisterClass
);
46 addRegisterClass(MVT::f32
, PTX::RegF32RegisterClass
);
47 addRegisterClass(MVT::f64
, PTX::RegF64RegisterClass
);
49 setBooleanContents(ZeroOrOneBooleanContent
);
50 setMinFunctionAlignment(2);
52 ////////////////////////////////////
53 /////////// Expansion //////////////
54 ////////////////////////////////////
56 // (any/zero/sign) extload => load + (any/zero/sign) extend
58 setLoadExtAction(ISD::EXTLOAD
, MVT::i16
, Expand
);
59 setLoadExtAction(ISD::ZEXTLOAD
, MVT::i16
, Expand
);
60 setLoadExtAction(ISD::SEXTLOAD
, MVT::i16
, Expand
);
62 // f32 extload => load + fextend
64 setLoadExtAction(ISD::EXTLOAD
, MVT::f32
, Expand
);
66 // f64 truncstore => trunc + store
68 setTruncStoreAction(MVT::f64
, MVT::f32
, Expand
);
70 // sign_extend_inreg => sign_extend
72 setOperationAction(ISD::SIGN_EXTEND_INREG
, MVT::i1
, Expand
);
76 setOperationAction(ISD::BR_CC
, MVT::Other
, Expand
);
80 setOperationAction(ISD::SELECT_CC
, MVT::Other
, Expand
);
81 setOperationAction(ISD::SELECT_CC
, MVT::f32
, Expand
);
82 setOperationAction(ISD::SELECT_CC
, MVT::f64
, Expand
);
84 ////////////////////////////////////
85 //////////// Legal /////////////////
86 ////////////////////////////////////
88 setOperationAction(ISD::ConstantFP
, MVT::f32
, Legal
);
89 setOperationAction(ISD::ConstantFP
, MVT::f64
, Legal
);
91 ////////////////////////////////////
92 //////////// Custom ////////////////
93 ////////////////////////////////////
95 // customise setcc to use bitwise logic if possible
97 setOperationAction(ISD::SETCC
, MVT::i1
, Custom
);
99 // customize translation of memory addresses
101 setOperationAction(ISD::GlobalAddress
, MVT::i32
, Custom
);
102 setOperationAction(ISD::GlobalAddress
, MVT::i64
, Custom
);
104 // Compute derived properties from the register classes
105 computeRegisterProperties();
// Takes the value type VT and returns an MVT::SimpleValueType.
// NOTE(review): the body of this definition is not visible in this excerpt.
// LowerSETCC asserts that a SETCC node has type MVT::i1, so this presumably
// returns MVT::i1 -- confirm against the full file.
108 MVT::SimpleValueType
PTXTargetLowering::getSetCCResultType(EVT VT
) const {
// Entry point for custom lowering: switches on the opcode of Op and forwards
// the node to the matching Lower* helper, returning whatever that helper
// produces.
// NOTE(review): some case/default labels of the switch are not visible in
// this excerpt.
112 SDValue
PTXTargetLowering::LowerOperation(SDValue Op
, SelectionDAG
&DAG
) const {
113 switch (Op
.getOpcode()) {
// Aborts with "Unimplemented operand"; presumably the default label of the
// switch (label not visible here -- confirm).
115 llvm_unreachable("Unimplemented operand");
// Presumably the ISD::SETCC case (label not visible here -- confirm).
117 return LowerSETCC(Op
, DAG
);
118 case ISD::GlobalAddress
:
119 return LowerGlobalAddress(Op
, DAG
);
// Returns the printable name of a PTX-specific DAG node opcode as a string
// literal of the form "PTXISD::<NAME>".
// NOTE(review): the switch header and some case labels are not visible in
// this excerpt.
123 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode
) const {
// Aborts with "Unknown opcode"; presumably the default label (not visible
// here -- confirm).
126 llvm_unreachable("Unknown opcode");
127 case PTXISD::COPY_ADDRESS
:
128 return "PTXISD::COPY_ADDRESS";
129 case PTXISD::LOAD_PARAM
:
130 return "PTXISD::LOAD_PARAM";
131 case PTXISD::STORE_PARAM
:
132 return "PTXISD::STORE_PARAM";
// Presumably under `case PTXISD::EXIT:` (label not visible -- confirm).
134 return "PTXISD::EXIT";
// Presumably under `case PTXISD::RET:` (label not visible -- confirm).
136 return "PTXISD::RET";
140 //===----------------------------------------------------------------------===//
141 // Custom Lower Operation
142 //===----------------------------------------------------------------------===//
144 SDValue
PTXTargetLowering::LowerSETCC(SDValue Op
, SelectionDAG
&DAG
) const {
145 assert(Op
.getValueType() == MVT::i1
&& "SetCC type must be 1-bit integer");
146 SDValue Op0
= Op
.getOperand(0);
147 SDValue Op1
= Op
.getOperand(1);
148 SDValue Op2
= Op
.getOperand(2);
149 DebugLoc dl
= Op
.getDebugLoc();
150 ISD::CondCode CC
= cast
<CondCodeSDNode
>(Op
.getOperand(2))->get();
152 // Look for X == 0, X == 1, X != 0, or X != 1
153 // We can simplify these to bitwise logic
155 if (Op1
.getOpcode() == ISD::Constant
&&
156 (cast
<ConstantSDNode
>(Op1
)->getZExtValue() == 1 ||
157 cast
<ConstantSDNode
>(Op1
)->isNullValue()) &&
158 (CC
== ISD::SETEQ
|| CC
== ISD::SETNE
)) {
160 return DAG
.getNode(ISD::AND
, dl
, MVT::i1
, Op0
, Op1
);
163 return DAG
.getNode(ISD::SETCC
, dl
, MVT::i1
, Op0
, Op1
, Op2
);
// Custom lowering for a GlobalAddress node: wraps the referenced global in a
// target global-address node of pointer type and starts building a
// PTXISD::COPY_ADDRESS node from it.
// NOTE(review): the definition is cut off in this excerpt after the opcode
// argument of the COPY_ADDRESS node; its remaining operands and the return
// statement are not visible.
166 SDValue
PTXTargetLowering::
167 LowerGlobalAddress(SDValue Op
, SelectionDAG
&DAG
) const {
168 EVT PtrVT
= getPointerTy();
169 DebugLoc dl
= Op
.getDebugLoc();
// Op is cast unconditionally, so it must be a GlobalAddressSDNode.
170 const GlobalValue
*GV
= cast
<GlobalAddressSDNode
>(Op
)->getGlobal();
172 assert(PtrVT
.isSimple() && "Pointer must be to primitive type.");
174 SDValue targetGlobal
= DAG
.getTargetGlobalAddress(GV
, dl
, PtrVT
);
175 SDValue movInstr
= DAG
.getNode(PTXISD::COPY_ADDRESS
,
183 //===----------------------------------------------------------------------===//
184 // Calling Convention Implementation
185 //===----------------------------------------------------------------------===//
// Lowers the incoming formal arguments of a function, appending one SDValue
// per argument to InVals. Two strategies are used:
//   * kernels, and device functions on subtargets that report
//     useParamSpaceForDeviceArgs(): each argument becomes a
//     PTXISD::LOAD_PARAM node indexed by argument position;
//   * other device functions: CC_PTX assigns a register to each argument and
//     the value is read with CopyFromReg from a fresh virtual register.
// In both cases per-argument information is recorded on
// PTXMachineFunctionInfo through addArgReg().
// NOTE(review): several lines of this definition are not visible in this
// excerpt (the isVarArg, dl and DAG parameters that the body uses, the
// switch header on the calling convention, some else/closing-brace lines and
// the return statement).
187 SDValue
PTXTargetLowering::
188 LowerFormalArguments(SDValue Chain
,
189 CallingConv::ID CallConv
,
191 const SmallVectorImpl
<ISD::InputArg
> &Ins
,
194 SmallVectorImpl
<SDValue
> &InVals
) const {
// Variadic functions are rejected outright.
195 if (isVarArg
) llvm_unreachable("PTX does not support varargs");
197 MachineFunction
&MF
= DAG
.getMachineFunction();
198 const PTXSubtarget
& ST
= getTargetMachine().getSubtarget
<PTXSubtarget
>();
199 PTXMachineFunctionInfo
*MFI
= MF
.getInfo
<PTXMachineFunctionInfo
>();
// Record on the function info whether this is a kernel or a device function;
// other calling conventions abort.
203 llvm_unreachable("Unsupported calling convention");
205 case CallingConv::PTX_Kernel
:
206 MFI
->setKernel(true);
208 case CallingConv::PTX_Device
:
209 MFI
->setKernel(false);
213 // We do one of two things here:
214 // IsKernel || SM >= 2.0 -> Use param space for arguments
215 // SM < 2.0 -> Use registers for arguments
216 if (MFI
->isKernel() || ST
.useParamSpaceForDeviceArgs()) {
217 // We just need to emit the proper LOAD_PARAM ISDs
218 for (unsigned i
= 0, e
= Ins
.size(); i
!= e
; ++i
) {
// i1 arguments are only rejected for kernels; device functions may take them.
220 assert((!MFI
->isKernel() || Ins
[i
].VT
!= MVT::i1
) &&
221 "Kernels cannot take pred operands");
// The LOAD_PARAM node carries the argument index as an i32 target constant.
223 SDValue ArgValue
= DAG
.getNode(PTXISD::LOAD_PARAM
, dl
, Ins
[i
].VT
, Chain
,
224 DAG
.getTargetConstant(i
, MVT::i32
));
225 InVals
.push_back(ArgValue
);
227 // Instead of storing a physical register in our argument list, we just
228 // store the total size of the parameter, in bits. The ASM printer
229 // knows how to process this.
230 MFI
->addArgReg(Ins
[i
].VT
.getStoreSizeInBits());
234 // For device functions, we use the PTX calling convention to do register
235 // assignments then create CopyFromReg ISDs for the allocated registers
237 SmallVector
<CCValAssign
, 16> ArgLocs
;
// NOTE(review): the trailing argument(s) of this CCState constructor call are
// not visible in this excerpt.
238 CCState
CCInfo(CallConv
, isVarArg
, MF
, getTargetMachine(), ArgLocs
,
241 CCInfo
.AnalyzeFormalArguments(Ins
, CC_PTX
);
243 for (unsigned i
= 0, e
= ArgLocs
.size(); i
!= e
; ++i
) {
245 CCValAssign
& VA
= ArgLocs
[i
];
246 EVT RegVT
= VA
.getLocVT();
247 TargetRegisterClass
* TRC
= 0;
// Only register locations are handled on this path.
249 assert(VA
.isRegLoc() && "CCValAssign must be RegLoc");
251 // Determine which register class we need
252 if (RegVT
== MVT::i1
) {
253 TRC
= PTX::RegPredRegisterClass
;
255 else if (RegVT
== MVT::i16
) {
256 TRC
= PTX::RegI16RegisterClass
;
258 else if (RegVT
== MVT::i32
) {
259 TRC
= PTX::RegI32RegisterClass
;
261 else if (RegVT
== MVT::i64
) {
262 TRC
= PTX::RegI64RegisterClass
;
264 else if (RegVT
== MVT::f32
) {
265 TRC
= PTX::RegF32RegisterClass
;
267 else if (RegVT
== MVT::f64
) {
268 TRC
= PTX::RegF64RegisterClass
;
// Presumably the final else branch for any other RegVT (the else line is not
// visible here -- confirm).
271 llvm_unreachable("Unknown parameter type");
// Create a virtual register of the chosen class, mark the assigned register
// as live-in mapped to it, and read the argument through CopyFromReg.
274 unsigned Reg
= MF
.getRegInfo().createVirtualRegister(TRC
);
275 MF
.getRegInfo().addLiveIn(VA
.getLocReg(), Reg
);
277 SDValue ArgValue
= DAG
.getCopyFromReg(Chain
, dl
, Reg
, RegVT
);
278 InVals
.push_back(ArgValue
);
// On this path the assigned register itself (not a bit size) is recorded.
280 MFI
->addArgReg(VA
.getLocReg());
// Lowers a function return.
//   * PTX_Kernel: asserts that nothing is returned and emits PTXISD::EXIT on
//     the chain.
//   * PTX_Device: RetCC_PTX assigns a register to each returned value; each
//     value is copied into its register with CopyToReg, the register is
//     marked live-out, and a PTXISD::RET node is emitted (with Flag as an
//     extra operand when Flag holds a node).
// Any other calling convention aborts, as do variadic functions.
// NOTE(review): several lines of this definition are not visible in this
// excerpt (the isVarArg and dl parameters, the switch header on the calling
// convention, and the declaration of Flag).
287 SDValue
PTXTargetLowering::
288 LowerReturn(SDValue Chain
,
289 CallingConv::ID CallConv
,
291 const SmallVectorImpl
<ISD::OutputArg
> &Outs
,
292 const SmallVectorImpl
<SDValue
> &OutVals
,
294 SelectionDAG
&DAG
) const {
295 if (isVarArg
) llvm_unreachable("PTX does not support varargs");
299 llvm_unreachable("Unsupported calling convention.");
300 case CallingConv::PTX_Kernel
:
301 assert(Outs
.size() == 0 && "Kernel must return void.");
302 return DAG
.getNode(PTXISD::EXIT
, dl
, MVT::Other
, Chain
);
303 case CallingConv::PTX_Device
:
304 //assert(Outs.size() <= 1 && "Can at most return one value.");
// NOTE(review): MF and MFI are not used by any line visible in this excerpt.
308 MachineFunction
& MF
= DAG
.getMachineFunction();
309 PTXMachineFunctionInfo
*MFI
= MF
.getInfo
<PTXMachineFunctionInfo
>();
// NOTE(review): the sentence in the comment below is cut off in this excerpt.
313 // Even though we could use the .param space for return arguments for
314 // device functions if SM >= 2.0 and the number of return arguments is
315 // only 1, we just always use registers since this makes the codegen
317 SmallVector
<CCValAssign
, 16> RVLocs
;
318 CCState
CCInfo(CallConv
, isVarArg
, DAG
.getMachineFunction(),
319 getTargetMachine(), RVLocs
, *DAG
.getContext());
321 CCInfo
.AnalyzeReturn(Outs
, RetCC_PTX
);
// RVLocs[i] is paired with OutVals[i]: each returned value is copied into the
// register chosen for it.
323 for (unsigned i
= 0, e
= RVLocs
.size(); i
!= e
; ++i
) {
324 CCValAssign
& VA
= RVLocs
[i
];
326 assert(VA
.isRegLoc() && "CCValAssign must be RegLoc");
328 unsigned Reg
= VA
.getLocReg();
330 DAG
.getMachineFunction().getRegInfo().addLiveOut(Reg
);
332 Chain
= DAG
.getCopyToReg(Chain
, dl
, Reg
, OutVals
[i
], Flag
);
334 // Guarantee that all emitted copies are stuck together,
335 // avoiding something bad
336 Flag
= Chain
.getValue(1);
// Flag holds no node when the loop above emitted no copy; in that case RET
// is built from the chain alone.
341 if (Flag
.getNode() == 0) {
342 return DAG
.getNode(PTXISD::RET
, dl
, MVT::Other
, Chain
);
345 return DAG
.getNode(PTXISD::RET
, dl
, MVT::Other
, Chain
, Flag
);