1 //===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
/// \file This pass attempts to replace out argument usage with a return of a
/// struct value.
12 /// We can support returning a lot of values directly in registers, but
13 /// idiomatic C code frequently uses a pointer argument to return a second value
14 /// rather than returning a struct by value. GPU stack access is also quite
15 /// painful, so we want to avoid that if possible. Passing a stack object
16 /// pointer to a function also requires an additional address expansion code
17 /// sequence to convert the pointer to be relative to the kernel's scratch wave
18 /// offset register since the callee doesn't know what stack frame the incoming
19 /// pointer is relative to.
21 /// The goal is to try rewriting code that looks like this:
/// int foo(int a, int b, int* out) {
///     *out = bar();
///     return a + b;
/// }
///
/// into something like this:
///
/// std::pair<int, int> foo(int a, int b) {
///   return std::make_pair(a + b, bar());
/// }
34 /// Typically the incoming pointer is a simple alloca for a temporary variable
35 /// to use the API, which if replaced with a struct return will be easily SROA'd
/// out when the stub function we create is inlined.
38 /// This pass introduces the struct return, but leaves the unused pointer
39 /// arguments and introduces a new stub function calling the struct returning
40 /// body. DeadArgumentElimination should be run after this to clean these up.
42 //===----------------------------------------------------------------------===//
45 #include "Utils/AMDGPUBaseInfo.h"
46 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
47 #include "llvm/ADT/DenseMap.h"
48 #include "llvm/ADT/STLExtras.h"
49 #include "llvm/ADT/SmallSet.h"
50 #include "llvm/ADT/SmallVector.h"
51 #include "llvm/ADT/Statistic.h"
52 #include "llvm/Analysis/MemoryLocation.h"
53 #include "llvm/IR/Argument.h"
54 #include "llvm/IR/Attributes.h"
55 #include "llvm/IR/BasicBlock.h"
56 #include "llvm/IR/Constants.h"
57 #include "llvm/IR/DataLayout.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/IRBuilder.h"
61 #include "llvm/IR/Instructions.h"
62 #include "llvm/IR/Module.h"
63 #include "llvm/IR/Type.h"
64 #include "llvm/IR/Use.h"
65 #include "llvm/IR/User.h"
66 #include "llvm/IR/Value.h"
67 #include "llvm/Pass.h"
68 #include "llvm/Support/Casting.h"
69 #include "llvm/Support/CommandLine.h"
70 #include "llvm/Support/Debug.h"
71 #include "llvm/Support/raw_ostream.h"
75 #define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
79 static cl::opt
<bool> AnyAddressSpace(
80 "amdgpu-any-address-space-out-arguments",
81 cl::desc("Replace pointer out arguments with "
82 "struct returns for non-private address space"),
86 static cl::opt
<unsigned> MaxNumRetRegs(
87 "amdgpu-max-return-arg-num-regs",
88 cl::desc("Approximately limit number of return registers for replacing out arguments"),
92 STATISTIC(NumOutArgumentsReplaced
,
93 "Number out arguments moved to struct return values");
94 STATISTIC(NumOutArgumentFunctionsReplaced
,
95 "Number of functions with out arguments moved to struct return values");
99 class AMDGPURewriteOutArguments
: public FunctionPass
{
101 const DataLayout
*DL
= nullptr;
102 MemoryDependenceResults
*MDA
= nullptr;
104 bool checkArgumentUses(Value
&Arg
) const;
105 bool isOutArgumentCandidate(Argument
&Arg
) const;
108 bool isVec3ToVec4Shuffle(Type
*Ty0
, Type
* Ty1
) const;
114 AMDGPURewriteOutArguments() : FunctionPass(ID
) {}
116 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
117 AU
.addRequired
<MemoryDependenceWrapperPass
>();
118 FunctionPass::getAnalysisUsage(AU
);
121 bool doInitialization(Module
&M
) override
;
122 bool runOnFunction(Function
&F
) override
;
125 } // end anonymous namespace
127 INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments
, DEBUG_TYPE
,
128 "AMDGPU Rewrite Out Arguments", false, false)
129 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass
)
130 INITIALIZE_PASS_END(AMDGPURewriteOutArguments
, DEBUG_TYPE
,
131 "AMDGPU Rewrite Out Arguments", false, false)
133 char AMDGPURewriteOutArguments::ID
= 0;
135 bool AMDGPURewriteOutArguments::checkArgumentUses(Value
&Arg
) const {
136 const int MaxUses
= 10;
139 for (Use
&U
: Arg
.uses()) {
140 StoreInst
*SI
= dyn_cast
<StoreInst
>(U
.getUser());
141 if (UseCount
> MaxUses
)
145 auto *BCI
= dyn_cast
<BitCastInst
>(U
.getUser());
146 if (!BCI
|| !BCI
->hasOneUse())
149 // We don't handle multiple stores currently, so stores to aggregate
150 // pointers aren't worth the trouble since they are canonically split up.
151 Type
*DestEltTy
= BCI
->getType()->getPointerElementType();
152 if (DestEltTy
->isAggregateType())
155 // We could handle these if we had a convenient way to bitcast between
157 Type
*SrcEltTy
= Arg
.getType()->getPointerElementType();
158 if (SrcEltTy
->isArrayTy())
161 // Special case handle structs with single members. It is useful to handle
162 // some casts between structs and non-structs, but we can't bitcast
163 // directly between them. directly bitcast between them. Blender uses
164 // some casts that look like { <3 x float> }* to <4 x float>*
165 if ((SrcEltTy
->isStructTy() && (SrcEltTy
->getStructNumElements() != 1)))
168 // Clang emits OpenCL 3-vector type accesses with a bitcast to the
169 // equivalent 4-element vector and accesses that, and we're looking for
170 // this pointer cast.
171 if (DL
->getTypeAllocSize(SrcEltTy
) != DL
->getTypeAllocSize(DestEltTy
))
174 return checkArgumentUses(*BCI
);
177 if (!SI
->isSimple() ||
178 U
.getOperandNo() != StoreInst::getPointerOperandIndex())
184 // Skip unused arguments.
188 bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument
&Arg
) const {
189 const unsigned MaxOutArgSizeBytes
= 4 * MaxNumRetRegs
;
190 PointerType
*ArgTy
= dyn_cast
<PointerType
>(Arg
.getType());
192 // TODO: It might be useful for any out arguments, not just privates.
193 if (!ArgTy
|| (ArgTy
->getAddressSpace() != DL
->getAllocaAddrSpace() &&
195 Arg
.hasByValAttr() || Arg
.hasStructRetAttr() ||
196 DL
->getTypeStoreSize(ArgTy
->getPointerElementType()) > MaxOutArgSizeBytes
) {
200 return checkArgumentUses(Arg
);
203 bool AMDGPURewriteOutArguments::doInitialization(Module
&M
) {
204 DL
= &M
.getDataLayout();
209 bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type
*Ty0
, Type
* Ty1
) const {
210 VectorType
*VT0
= dyn_cast
<VectorType
>(Ty0
);
211 VectorType
*VT1
= dyn_cast
<VectorType
>(Ty1
);
215 if (VT0
->getNumElements() != 3 ||
216 VT1
->getNumElements() != 4)
219 return DL
->getTypeSizeInBits(VT0
->getElementType()) ==
220 DL
->getTypeSizeInBits(VT1
->getElementType());
224 bool AMDGPURewriteOutArguments::runOnFunction(Function
&F
) {
228 // TODO: Could probably handle variadic functions.
229 if (F
.isVarArg() || F
.hasStructRetAttr() ||
230 AMDGPU::isEntryFunctionCC(F
.getCallingConv()))
233 MDA
= &getAnalysis
<MemoryDependenceWrapperPass
>().getMemDep();
235 unsigned ReturnNumRegs
= 0;
236 SmallSet
<int, 4> OutArgIndexes
;
237 SmallVector
<Type
*, 4> ReturnTypes
;
238 Type
*RetTy
= F
.getReturnType();
239 if (!RetTy
->isVoidTy()) {
240 ReturnNumRegs
= DL
->getTypeStoreSize(RetTy
) / 4;
242 if (ReturnNumRegs
>= MaxNumRetRegs
)
245 ReturnTypes
.push_back(RetTy
);
248 SmallVector
<Argument
*, 4> OutArgs
;
249 for (Argument
&Arg
: F
.args()) {
250 if (isOutArgumentCandidate(Arg
)) {
251 LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
252 << " in function " << F
.getName() << '\n');
253 OutArgs
.push_back(&Arg
);
260 using ReplacementVec
= SmallVector
<std::pair
<Argument
*, Value
*>, 4>;
262 DenseMap
<ReturnInst
*, ReplacementVec
> Replacements
;
264 SmallVector
<ReturnInst
*, 4> Returns
;
265 for (BasicBlock
&BB
: F
) {
266 if (ReturnInst
*RI
= dyn_cast
<ReturnInst
>(&BB
.back()))
267 Returns
.push_back(RI
);
278 // Keep retrying if we are able to successfully eliminate an argument. This
279 // helps with cases with multiple arguments which may alias, such as in a
280 // sincos implemntation. If we have 2 stores to arguments, on the first
281 // attempt the MDA query will succeed for the second store but not the
282 // first. On the second iteration we've removed that out clobbering argument
283 // (by effectively moving it into another function) and will find the second
284 // argument is OK to move.
285 for (Argument
*OutArg
: OutArgs
) {
286 bool ThisReplaceable
= true;
287 SmallVector
<std::pair
<ReturnInst
*, StoreInst
*>, 4> ReplaceableStores
;
289 Type
*ArgTy
= OutArg
->getType()->getPointerElementType();
291 // Skip this argument if converting it will push us over the register
292 // count to return limit.
294 // TODO: This is an approximation. When legalized this could be more. We
295 // can ask TLI for exactly how many.
296 unsigned ArgNumRegs
= DL
->getTypeStoreSize(ArgTy
) / 4;
297 if (ArgNumRegs
+ ReturnNumRegs
> MaxNumRetRegs
)
300 // An argument is convertible only if all exit blocks are able to replace
302 for (ReturnInst
*RI
: Returns
) {
303 BasicBlock
*BB
= RI
->getParent();
305 MemDepResult Q
= MDA
->getPointerDependencyFrom(MemoryLocation(OutArg
),
306 true, BB
->end(), BB
, RI
);
307 StoreInst
*SI
= nullptr;
309 SI
= dyn_cast
<StoreInst
>(Q
.getInst());
312 LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI
<< '\n');
313 ReplaceableStores
.emplace_back(RI
, SI
);
315 ThisReplaceable
= false;
320 if (!ThisReplaceable
)
321 continue; // Try the next argument candidate.
323 for (std::pair
<ReturnInst
*, StoreInst
*> Store
: ReplaceableStores
) {
324 Value
*ReplVal
= Store
.second
->getValueOperand();
326 auto &ValVec
= Replacements
[Store
.first
];
327 if (llvm::find_if(ValVec
,
328 [OutArg
](const std::pair
<Argument
*, Value
*> &Entry
) {
329 return Entry
.first
== OutArg
;}) != ValVec
.end()) {
331 << "Saw multiple out arg stores" << *OutArg
<< '\n');
332 // It is possible to see stores to the same argument multiple times,
333 // but we expect these would have been optimized out already.
334 ThisReplaceable
= false;
338 ValVec
.emplace_back(OutArg
, ReplVal
);
339 Store
.second
->eraseFromParent();
342 if (ThisReplaceable
) {
343 ReturnTypes
.push_back(ArgTy
);
344 OutArgIndexes
.insert(OutArg
->getArgNo());
345 ++NumOutArgumentsReplaced
;
351 if (Replacements
.empty())
354 LLVMContext
&Ctx
= F
.getParent()->getContext();
355 StructType
*NewRetTy
= StructType::create(Ctx
, ReturnTypes
, F
.getName());
357 FunctionType
*NewFuncTy
= FunctionType::get(NewRetTy
,
358 F
.getFunctionType()->params(),
361 LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy
<< '\n');
363 Function
*NewFunc
= Function::Create(NewFuncTy
, Function::PrivateLinkage
,
364 F
.getName() + ".body");
365 F
.getParent()->getFunctionList().insert(F
.getIterator(), NewFunc
);
366 NewFunc
->copyAttributesFrom(&F
);
367 NewFunc
->setComdat(F
.getComdat());
369 // We want to preserve the function and param attributes, but need to strip
370 // off any return attributes, e.g. zeroext doesn't make sense with a struct.
371 NewFunc
->stealArgumentListFrom(F
);
373 AttrBuilder RetAttrs
;
374 RetAttrs
.addAttribute(Attribute::SExt
);
375 RetAttrs
.addAttribute(Attribute::ZExt
);
376 RetAttrs
.addAttribute(Attribute::NoAlias
);
377 NewFunc
->removeAttributes(AttributeList::ReturnIndex
, RetAttrs
);
378 // TODO: How to preserve metadata?
380 // Move the body of the function into the new rewritten function, and replace
381 // this function with a stub.
382 NewFunc
->getBasicBlockList().splice(NewFunc
->begin(), F
.getBasicBlockList());
384 for (std::pair
<ReturnInst
*, ReplacementVec
> &Replacement
: Replacements
) {
385 ReturnInst
*RI
= Replacement
.first
;
387 B
.SetCurrentDebugLocation(RI
->getDebugLoc());
390 Value
*NewRetVal
= UndefValue::get(NewRetTy
);
392 Value
*RetVal
= RI
->getReturnValue();
394 NewRetVal
= B
.CreateInsertValue(NewRetVal
, RetVal
, RetIdx
++);
396 for (std::pair
<Argument
*, Value
*> ReturnPoint
: Replacement
.second
) {
397 Argument
*Arg
= ReturnPoint
.first
;
398 Value
*Val
= ReturnPoint
.second
;
399 Type
*EltTy
= Arg
->getType()->getPointerElementType();
400 if (Val
->getType() != EltTy
) {
401 Type
*EffectiveEltTy
= EltTy
;
402 if (StructType
*CT
= dyn_cast
<StructType
>(EltTy
)) {
403 assert(CT
->getNumElements() == 1);
404 EffectiveEltTy
= CT
->getElementType(0);
407 if (DL
->getTypeSizeInBits(EffectiveEltTy
) !=
408 DL
->getTypeSizeInBits(Val
->getType())) {
409 assert(isVec3ToVec4Shuffle(EffectiveEltTy
, Val
->getType()));
410 Val
= B
.CreateShuffleVector(Val
, UndefValue::get(Val
->getType()),
414 Val
= B
.CreateBitCast(Val
, EffectiveEltTy
);
416 // Re-create single element composite.
417 if (EltTy
!= EffectiveEltTy
)
418 Val
= B
.CreateInsertValue(UndefValue::get(EltTy
), Val
, 0);
421 NewRetVal
= B
.CreateInsertValue(NewRetVal
, Val
, RetIdx
++);
425 RI
->setOperand(0, NewRetVal
);
427 B
.CreateRet(NewRetVal
);
428 RI
->eraseFromParent();
432 SmallVector
<Value
*, 16> StubCallArgs
;
433 for (Argument
&Arg
: F
.args()) {
434 if (OutArgIndexes
.count(Arg
.getArgNo())) {
435 // It's easier to preserve the type of the argument list. We rely on
436 // DeadArgumentElimination to take care of these.
437 StubCallArgs
.push_back(UndefValue::get(Arg
.getType()));
439 StubCallArgs
.push_back(&Arg
);
443 BasicBlock
*StubBB
= BasicBlock::Create(Ctx
, "", &F
);
444 IRBuilder
<> B(StubBB
);
445 CallInst
*StubCall
= B
.CreateCall(NewFunc
, StubCallArgs
);
447 int RetIdx
= RetTy
->isVoidTy() ? 0 : 1;
448 for (Argument
&Arg
: F
.args()) {
449 if (!OutArgIndexes
.count(Arg
.getArgNo()))
452 PointerType
*ArgType
= cast
<PointerType
>(Arg
.getType());
454 auto *EltTy
= ArgType
->getElementType();
455 unsigned Align
= Arg
.getParamAlignment();
457 Align
= DL
->getABITypeAlignment(EltTy
);
459 Value
*Val
= B
.CreateExtractValue(StubCall
, RetIdx
++);
460 Type
*PtrTy
= Val
->getType()->getPointerTo(ArgType
->getAddressSpace());
462 // We can peek through bitcasts, so the type may not match.
463 Value
*PtrVal
= B
.CreateBitCast(&Arg
, PtrTy
);
465 B
.CreateAlignedStore(Val
, PtrVal
, Align
);
468 if (!RetTy
->isVoidTy()) {
469 B
.CreateRet(B
.CreateExtractValue(StubCall
, 0));
474 // The function is now a stub we want to inline.
475 F
.addFnAttr(Attribute::AlwaysInline
);
477 ++NumOutArgumentFunctionsReplaced
;
481 FunctionPass
*llvm::createAMDGPURewriteOutArgumentsPass() {
482 return new AMDGPURewriteOutArguments();