//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. As #1, it allows NVPTX to emit more ld/st.global. E.g.,
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. At the IR level,
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo({i32*, i32*}* byval %input) {
//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
//      %b = load i32*, i32** %b_ptr
//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
//      ; use %b_generic
//    }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
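//
// (For instance, a generic->global->generic round trip like
//   %g = addrspacecast float* %p to float addrspace(1)*
//   %p2 = addrspacecast float addrspace(1)* %g to float*
// can be folded back to plain %p by cast-pair simplification before
// NVPTXInferAddressSpaces gets a chance to exploit it; the value names here
// are illustrative only.)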
//===----------------------------------------------------------------------===//

#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  bool runOnKernelFunction(Function &F);
  bool runOnDeviceFunction(Function &F);

  // Handle byval parameters.
  void handleByValParam(Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
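  //
  // A sketch of the emitted pattern (value names invented for this comment):
  //   %p.global = addrspacecast i32* %p to i32 addrspace(1)*
  //   %p.generic = addrspacecast i32 addrspace(1)* %p.global to i32*
  // All uses of %p other than %p.global itself are then rewritten to use
  // %p.generic.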
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  const NVPTXTargetMachine *TM;
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
                "Lower arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space on the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
// =============================================================================
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
  Function *Func = Arg->getParent();
  Instruction *FirstInst = &(Func->getEntryBlock().front());
  PointerType *PType = dyn_cast<PointerType>(Arg->getType());

  assert(PType && "Expecting pointer type in handleByValParam");

  Type *StructType = PType->getElementType();
  unsigned AS = Func->getParent()->getDataLayout().getAllocaAddrSpace();
  AllocaInst *AllocA =
      new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
  // Set the alignment to the alignment of the byval parameter, because later
  // loads/stores assume that alignment, and we are going to replace all uses
  // of the byval parameter with this alloca instruction.
  AllocA->setAlignment(MaybeAlign(Func->getParamAlignment(Arg->getArgNo())));
  Arg->replaceAllUsesWith(AllocA);
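
  // Note that the replacement above is done before the param-space cast below
  // is created, so the cast keeps the original argument as its operand rather
  // than being rewritten to the alloca.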
  Value *ArgInParam = new AddrSpaceCastInst(
      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
      FirstInst);
  LoadInst *LI =
      new LoadInst(StructType, ArgInParam, Arg->getName(), FirstInst);
  new StoreInst(LI, AllocA, FirstInst);
}

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
    return;

  // Decide where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the function entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr, PointerType::get(Ptr->getType()->getPointerElementType(),
                            ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), &*InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), &*InsertPt);
  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
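    // For the CUDA driver interface, pointer kernel arguments (and pointers
    // loaded out of byval kernel parameters) refer to global memory; that is
    // the assumption that makes the generic-to-global casts below sound.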
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy()) {
            Value *UO = GetUnderlyingObject(LI->getPointerOperand(),
                                            F.getParent()->getDataLayout());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
                markPointerAsGlobal(LI);
              }
            }
          }
        }
      }
    }
  }

  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(&Arg);
      else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
      handleByValParam(&Arg);
  return true;
}

bool NVPTXLowerArgs::runOnFunction(Function &F) {
  return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
}

FunctionPass *
llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
  return new NVPTXLowerArgs(TM);
}