//===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to replace out argument usage with a return of a
/// struct.
///
/// We can support returning a lot of values directly in registers, but
/// idiomatic C code frequently uses a pointer argument to return a second value
/// rather than returning a struct by value. GPU stack access is also quite
/// painful, so we want to avoid that if possible. Passing a stack object
/// pointer to a function also requires an additional address expansion code
/// sequence to convert the pointer to be relative to the kernel's scratch wave
/// offset register since the callee doesn't know what stack frame the incoming
/// pointer is relative to.
///
/// The goal is to try rewriting code that looks like this:
///
///  int foo(int a, int b, int *out) {
///     *out = bar();
///     return a + b;
///  }
///
/// into something like this:
///
///  std::pair<int, int> foo(int a, int b) {
///     return std::make_pair(a + b, bar());
///  }
///
/// Typically the incoming pointer is a simple alloca for a temporary variable
/// to use the API, which if replaced with a struct return will be easily SROA'd
/// out when the stub function we create is inlined.
///
/// This pass introduces the struct return, but leaves the unused pointer
/// arguments and introduces a new stub function calling the struct returning
/// body. DeadArgumentElimination should be run after this to clean these up.
//
//===----------------------------------------------------------------------===//
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <utility>

using namespace llvm;
#define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
80 static cl::opt
<bool> AnyAddressSpace(
81 "amdgpu-any-address-space-out-arguments",
82 cl::desc("Replace pointer out arguments with "
83 "struct returns for non-private address space"),
87 static cl::opt
<unsigned> MaxNumRetRegs(
88 "amdgpu-max-return-arg-num-regs",
89 cl::desc("Approximately limit number of return registers for replacing out arguments"),
93 STATISTIC(NumOutArgumentsReplaced
,
94 "Number out arguments moved to struct return values");
95 STATISTIC(NumOutArgumentFunctionsReplaced
,
96 "Number of functions with out arguments moved to struct return values");
100 class AMDGPURewriteOutArguments
: public FunctionPass
{
102 const DataLayout
*DL
= nullptr;
103 MemoryDependenceResults
*MDA
= nullptr;
105 bool checkArgumentUses(Value
&Arg
) const;
106 bool isOutArgumentCandidate(Argument
&Arg
) const;
109 bool isVec3ToVec4Shuffle(Type
*Ty0
, Type
* Ty1
) const;
115 AMDGPURewriteOutArguments() : FunctionPass(ID
) {}
117 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
118 AU
.addRequired
<MemoryDependenceWrapperPass
>();
119 FunctionPass::getAnalysisUsage(AU
);
122 bool doInitialization(Module
&M
) override
;
123 bool runOnFunction(Function
&F
) override
;
126 } // end anonymous namespace
128 INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments
, DEBUG_TYPE
,
129 "AMDGPU Rewrite Out Arguments", false, false)
130 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass
)
131 INITIALIZE_PASS_END(AMDGPURewriteOutArguments
, DEBUG_TYPE
,
132 "AMDGPU Rewrite Out Arguments", false, false)
134 char AMDGPURewriteOutArguments::ID
= 0;
136 bool AMDGPURewriteOutArguments::checkArgumentUses(Value
&Arg
) const {
137 const int MaxUses
= 10;
140 for (Use
&U
: Arg
.uses()) {
141 StoreInst
*SI
= dyn_cast
<StoreInst
>(U
.getUser());
142 if (UseCount
> MaxUses
)
146 auto *BCI
= dyn_cast
<BitCastInst
>(U
.getUser());
147 if (!BCI
|| !BCI
->hasOneUse())
150 // We don't handle multiple stores currently, so stores to aggregate
151 // pointers aren't worth the trouble since they are canonically split up.
152 Type
*DestEltTy
= BCI
->getType()->getPointerElementType();
153 if (DestEltTy
->isAggregateType())
156 // We could handle these if we had a convenient way to bitcast between
158 Type
*SrcEltTy
= Arg
.getType()->getPointerElementType();
159 if (SrcEltTy
->isArrayTy())
162 // Special case handle structs with single members. It is useful to handle
163 // some casts between structs and non-structs, but we can't bitcast
164 // directly between them. directly bitcast between them. Blender uses
165 // some casts that look like { <3 x float> }* to <4 x float>*
166 if ((SrcEltTy
->isStructTy() && (SrcEltTy
->getNumContainedTypes() != 1)))
169 // Clang emits OpenCL 3-vector type accesses with a bitcast to the
170 // equivalent 4-element vector and accesses that, and we're looking for
171 // this pointer cast.
172 if (DL
->getTypeAllocSize(SrcEltTy
) != DL
->getTypeAllocSize(DestEltTy
))
175 return checkArgumentUses(*BCI
);
178 if (!SI
->isSimple() ||
179 U
.getOperandNo() != StoreInst::getPointerOperandIndex())
185 // Skip unused arguments.
189 bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument
&Arg
) const {
190 const unsigned MaxOutArgSizeBytes
= 4 * MaxNumRetRegs
;
191 PointerType
*ArgTy
= dyn_cast
<PointerType
>(Arg
.getType());
193 // TODO: It might be useful for any out arguments, not just privates.
194 if (!ArgTy
|| (ArgTy
->getAddressSpace() != DL
->getAllocaAddrSpace() &&
196 Arg
.hasByValAttr() || Arg
.hasStructRetAttr() ||
197 DL
->getTypeStoreSize(ArgTy
->getPointerElementType()) > MaxOutArgSizeBytes
) {
201 return checkArgumentUses(Arg
);
204 bool AMDGPURewriteOutArguments::doInitialization(Module
&M
) {
205 DL
= &M
.getDataLayout();
210 bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type
*Ty0
, Type
* Ty1
) const {
211 VectorType
*VT0
= dyn_cast
<VectorType
>(Ty0
);
212 VectorType
*VT1
= dyn_cast
<VectorType
>(Ty1
);
216 if (VT0
->getNumElements() != 3 ||
217 VT1
->getNumElements() != 4)
220 return DL
->getTypeSizeInBits(VT0
->getElementType()) ==
221 DL
->getTypeSizeInBits(VT1
->getElementType());
225 bool AMDGPURewriteOutArguments::runOnFunction(Function
&F
) {
229 // TODO: Could probably handle variadic functions.
230 if (F
.isVarArg() || F
.hasStructRetAttr() ||
231 AMDGPU::isEntryFunctionCC(F
.getCallingConv()))
234 MDA
= &getAnalysis
<MemoryDependenceWrapperPass
>().getMemDep();
236 unsigned ReturnNumRegs
= 0;
237 SmallSet
<int, 4> OutArgIndexes
;
238 SmallVector
<Type
*, 4> ReturnTypes
;
239 Type
*RetTy
= F
.getReturnType();
240 if (!RetTy
->isVoidTy()) {
241 ReturnNumRegs
= DL
->getTypeStoreSize(RetTy
) / 4;
243 if (ReturnNumRegs
>= MaxNumRetRegs
)
246 ReturnTypes
.push_back(RetTy
);
249 SmallVector
<Argument
*, 4> OutArgs
;
250 for (Argument
&Arg
: F
.args()) {
251 if (isOutArgumentCandidate(Arg
)) {
252 LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
253 << " in function " << F
.getName() << '\n');
254 OutArgs
.push_back(&Arg
);
261 using ReplacementVec
= SmallVector
<std::pair
<Argument
*, Value
*>, 4>;
263 DenseMap
<ReturnInst
*, ReplacementVec
> Replacements
;
265 SmallVector
<ReturnInst
*, 4> Returns
;
266 for (BasicBlock
&BB
: F
) {
267 if (ReturnInst
*RI
= dyn_cast
<ReturnInst
>(&BB
.back()))
268 Returns
.push_back(RI
);
279 // Keep retrying if we are able to successfully eliminate an argument. This
280 // helps with cases with multiple arguments which may alias, such as in a
281 // sincos implemntation. If we have 2 stores to arguments, on the first
282 // attempt the MDA query will succeed for the second store but not the
283 // first. On the second iteration we've removed that out clobbering argument
284 // (by effectively moving it into another function) and will find the second
285 // argument is OK to move.
286 for (Argument
*OutArg
: OutArgs
) {
287 bool ThisReplaceable
= true;
288 SmallVector
<std::pair
<ReturnInst
*, StoreInst
*>, 4> ReplaceableStores
;
290 Type
*ArgTy
= OutArg
->getType()->getPointerElementType();
292 // Skip this argument if converting it will push us over the register
293 // count to return limit.
295 // TODO: This is an approximation. When legalized this could be more. We
296 // can ask TLI for exactly how many.
297 unsigned ArgNumRegs
= DL
->getTypeStoreSize(ArgTy
) / 4;
298 if (ArgNumRegs
+ ReturnNumRegs
> MaxNumRetRegs
)
301 // An argument is convertible only if all exit blocks are able to replace
303 for (ReturnInst
*RI
: Returns
) {
304 BasicBlock
*BB
= RI
->getParent();
306 MemDepResult Q
= MDA
->getPointerDependencyFrom(MemoryLocation(OutArg
),
307 true, BB
->end(), BB
, RI
);
308 StoreInst
*SI
= nullptr;
310 SI
= dyn_cast
<StoreInst
>(Q
.getInst());
313 LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI
<< '\n');
314 ReplaceableStores
.emplace_back(RI
, SI
);
316 ThisReplaceable
= false;
321 if (!ThisReplaceable
)
322 continue; // Try the next argument candidate.
324 for (std::pair
<ReturnInst
*, StoreInst
*> Store
: ReplaceableStores
) {
325 Value
*ReplVal
= Store
.second
->getValueOperand();
327 auto &ValVec
= Replacements
[Store
.first
];
328 if (llvm::find_if(ValVec
,
329 [OutArg
](const std::pair
<Argument
*, Value
*> &Entry
) {
330 return Entry
.first
== OutArg
;}) != ValVec
.end()) {
332 << "Saw multiple out arg stores" << *OutArg
<< '\n');
333 // It is possible to see stores to the same argument multiple times,
334 // but we expect these would have been optimized out already.
335 ThisReplaceable
= false;
339 ValVec
.emplace_back(OutArg
, ReplVal
);
340 Store
.second
->eraseFromParent();
343 if (ThisReplaceable
) {
344 ReturnTypes
.push_back(ArgTy
);
345 OutArgIndexes
.insert(OutArg
->getArgNo());
346 ++NumOutArgumentsReplaced
;
352 if (Replacements
.empty())
355 LLVMContext
&Ctx
= F
.getParent()->getContext();
356 StructType
*NewRetTy
= StructType::create(Ctx
, ReturnTypes
, F
.getName());
358 FunctionType
*NewFuncTy
= FunctionType::get(NewRetTy
,
359 F
.getFunctionType()->params(),
362 LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy
<< '\n');
364 Function
*NewFunc
= Function::Create(NewFuncTy
, Function::PrivateLinkage
,
365 F
.getName() + ".body");
366 F
.getParent()->getFunctionList().insert(F
.getIterator(), NewFunc
);
367 NewFunc
->copyAttributesFrom(&F
);
368 NewFunc
->setComdat(F
.getComdat());
370 // We want to preserve the function and param attributes, but need to strip
371 // off any return attributes, e.g. zeroext doesn't make sense with a struct.
372 NewFunc
->stealArgumentListFrom(F
);
374 AttrBuilder RetAttrs
;
375 RetAttrs
.addAttribute(Attribute::SExt
);
376 RetAttrs
.addAttribute(Attribute::ZExt
);
377 RetAttrs
.addAttribute(Attribute::NoAlias
);
378 NewFunc
->removeAttributes(AttributeList::ReturnIndex
, RetAttrs
);
379 // TODO: How to preserve metadata?
381 // Move the body of the function into the new rewritten function, and replace
382 // this function with a stub.
383 NewFunc
->getBasicBlockList().splice(NewFunc
->begin(), F
.getBasicBlockList());
385 for (std::pair
<ReturnInst
*, ReplacementVec
> &Replacement
: Replacements
) {
386 ReturnInst
*RI
= Replacement
.first
;
388 B
.SetCurrentDebugLocation(RI
->getDebugLoc());
391 Value
*NewRetVal
= UndefValue::get(NewRetTy
);
393 Value
*RetVal
= RI
->getReturnValue();
395 NewRetVal
= B
.CreateInsertValue(NewRetVal
, RetVal
, RetIdx
++);
397 for (std::pair
<Argument
*, Value
*> ReturnPoint
: Replacement
.second
) {
398 Argument
*Arg
= ReturnPoint
.first
;
399 Value
*Val
= ReturnPoint
.second
;
400 Type
*EltTy
= Arg
->getType()->getPointerElementType();
401 if (Val
->getType() != EltTy
) {
402 Type
*EffectiveEltTy
= EltTy
;
403 if (StructType
*CT
= dyn_cast
<StructType
>(EltTy
)) {
404 assert(CT
->getNumContainedTypes() == 1);
405 EffectiveEltTy
= CT
->getContainedType(0);
408 if (DL
->getTypeSizeInBits(EffectiveEltTy
) !=
409 DL
->getTypeSizeInBits(Val
->getType())) {
410 assert(isVec3ToVec4Shuffle(EffectiveEltTy
, Val
->getType()));
411 Val
= B
.CreateShuffleVector(Val
, UndefValue::get(Val
->getType()),
415 Val
= B
.CreateBitCast(Val
, EffectiveEltTy
);
417 // Re-create single element composite.
418 if (EltTy
!= EffectiveEltTy
)
419 Val
= B
.CreateInsertValue(UndefValue::get(EltTy
), Val
, 0);
422 NewRetVal
= B
.CreateInsertValue(NewRetVal
, Val
, RetIdx
++);
426 RI
->setOperand(0, NewRetVal
);
428 B
.CreateRet(NewRetVal
);
429 RI
->eraseFromParent();
433 SmallVector
<Value
*, 16> StubCallArgs
;
434 for (Argument
&Arg
: F
.args()) {
435 if (OutArgIndexes
.count(Arg
.getArgNo())) {
436 // It's easier to preserve the type of the argument list. We rely on
437 // DeadArgumentElimination to take care of these.
438 StubCallArgs
.push_back(UndefValue::get(Arg
.getType()));
440 StubCallArgs
.push_back(&Arg
);
444 BasicBlock
*StubBB
= BasicBlock::Create(Ctx
, "", &F
);
445 IRBuilder
<> B(StubBB
);
446 CallInst
*StubCall
= B
.CreateCall(NewFunc
, StubCallArgs
);
448 int RetIdx
= RetTy
->isVoidTy() ? 0 : 1;
449 for (Argument
&Arg
: F
.args()) {
450 if (!OutArgIndexes
.count(Arg
.getArgNo()))
453 PointerType
*ArgType
= cast
<PointerType
>(Arg
.getType());
455 auto *EltTy
= ArgType
->getElementType();
456 unsigned Align
= Arg
.getParamAlignment();
458 Align
= DL
->getABITypeAlignment(EltTy
);
460 Value
*Val
= B
.CreateExtractValue(StubCall
, RetIdx
++);
461 Type
*PtrTy
= Val
->getType()->getPointerTo(ArgType
->getAddressSpace());
463 // We can peek through bitcasts, so the type may not match.
464 Value
*PtrVal
= B
.CreateBitCast(&Arg
, PtrTy
);
466 B
.CreateAlignedStore(Val
, PtrVal
, Align
);
469 if (!RetTy
->isVoidTy()) {
470 B
.CreateRet(B
.CreateExtractValue(StubCall
, 0));
475 // The function is now a stub we want to inline.
476 F
.addFnAttr(Attribute::AlwaysInline
);
478 ++NumOutArgumentFunctionsReplaced
;
482 FunctionPass
*llvm::createAMDGPURewriteOutArgumentsPass() {
483 return new AMDGPURewriteOutArguments();