//===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to replace out argument usage with a return of a
/// struct.
///
/// We can support returning a lot of values directly in registers, but
/// idiomatic C code frequently uses a pointer argument to return a second value
/// rather than returning a struct by value. GPU stack access is also quite
/// painful, so we want to avoid that if possible. Passing a stack object
/// pointer to a function also requires an additional address expansion code
/// sequence to convert the pointer to be relative to the kernel's scratch wave
/// offset register since the callee doesn't know what stack frame the incoming
/// pointer is relative to.
///
/// The goal is to try rewriting code that looks like this:
///
///  int foo(int a, int b, int* out) {
///     *out = bar();
///     return a + b;
/// }
///
/// into something like this:
///
///  std::pair<int, int> foo(int a, int b) {
///     return std::pair(a + b, bar());
/// }
///
/// Typically the incoming pointer is a simple alloca for a temporary variable
/// to use the API, which if replaced with a struct return will be easily SROA'd
/// out when the stub function we create is inlined.
///
/// This pass introduces the struct return, but leaves the unused pointer
/// arguments and introduces a new stub function calling the struct returning
/// body. DeadArgumentElimination should be run after this to clean these up.
//
//===----------------------------------------------------------------------===//
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "amdgpu-rewrite-out-arguments"

using namespace llvm;
61 static cl::opt
<bool> AnyAddressSpace(
62 "amdgpu-any-address-space-out-arguments",
63 cl::desc("Replace pointer out arguments with "
64 "struct returns for non-private address space"),
68 static cl::opt
<unsigned> MaxNumRetRegs(
69 "amdgpu-max-return-arg-num-regs",
70 cl::desc("Approximately limit number of return registers for replacing out arguments"),
74 STATISTIC(NumOutArgumentsReplaced
,
75 "Number out arguments moved to struct return values");
76 STATISTIC(NumOutArgumentFunctionsReplaced
,
77 "Number of functions with out arguments moved to struct return values");
81 class AMDGPURewriteOutArguments
: public FunctionPass
{
83 const DataLayout
*DL
= nullptr;
84 MemoryDependenceResults
*MDA
= nullptr;
86 Type
*getStoredType(Value
&Arg
) const;
87 Type
*getOutArgumentType(Argument
&Arg
) const;
92 AMDGPURewriteOutArguments() : FunctionPass(ID
) {}
94 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
95 AU
.addRequired
<MemoryDependenceWrapperPass
>();
96 FunctionPass::getAnalysisUsage(AU
);
99 bool doInitialization(Module
&M
) override
;
100 bool runOnFunction(Function
&F
) override
;
103 } // end anonymous namespace
105 INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments
, DEBUG_TYPE
,
106 "AMDGPU Rewrite Out Arguments", false, false)
107 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass
)
108 INITIALIZE_PASS_END(AMDGPURewriteOutArguments
, DEBUG_TYPE
,
109 "AMDGPU Rewrite Out Arguments", false, false)
111 char AMDGPURewriteOutArguments::ID
= 0;
113 Type
*AMDGPURewriteOutArguments::getStoredType(Value
&Arg
) const {
114 const int MaxUses
= 10;
117 SmallVector
<Use
*> Worklist
;
118 for (Use
&U
: Arg
.uses())
119 Worklist
.push_back(&U
);
121 Type
*StoredType
= nullptr;
122 while (!Worklist
.empty()) {
123 Use
*U
= Worklist
.pop_back_val();
125 if (auto *BCI
= dyn_cast
<BitCastInst
>(U
->getUser())) {
126 for (Use
&U
: BCI
->uses())
127 Worklist
.push_back(&U
);
131 if (auto *SI
= dyn_cast
<StoreInst
>(U
->getUser())) {
132 if (UseCount
++ > MaxUses
)
135 if (!SI
->isSimple() ||
136 U
->getOperandNo() != StoreInst::getPointerOperandIndex())
139 if (StoredType
&& StoredType
!= SI
->getValueOperand()->getType())
140 return nullptr; // More than one type.
141 StoredType
= SI
->getValueOperand()->getType();
152 Type
*AMDGPURewriteOutArguments::getOutArgumentType(Argument
&Arg
) const {
153 const unsigned MaxOutArgSizeBytes
= 4 * MaxNumRetRegs
;
154 PointerType
*ArgTy
= dyn_cast
<PointerType
>(Arg
.getType());
156 // TODO: It might be useful for any out arguments, not just privates.
157 if (!ArgTy
|| (ArgTy
->getAddressSpace() != DL
->getAllocaAddrSpace() &&
159 Arg
.hasByValAttr() || Arg
.hasStructRetAttr()) {
163 Type
*StoredType
= getStoredType(Arg
);
164 if (!StoredType
|| DL
->getTypeStoreSize(StoredType
) > MaxOutArgSizeBytes
)
170 bool AMDGPURewriteOutArguments::doInitialization(Module
&M
) {
171 DL
= &M
.getDataLayout();
175 bool AMDGPURewriteOutArguments::runOnFunction(Function
&F
) {
179 // TODO: Could probably handle variadic functions.
180 if (F
.isVarArg() || F
.hasStructRetAttr() ||
181 AMDGPU::isEntryFunctionCC(F
.getCallingConv()))
184 MDA
= &getAnalysis
<MemoryDependenceWrapperPass
>().getMemDep();
186 unsigned ReturnNumRegs
= 0;
187 SmallDenseMap
<int, Type
*, 4> OutArgIndexes
;
188 SmallVector
<Type
*, 4> ReturnTypes
;
189 Type
*RetTy
= F
.getReturnType();
190 if (!RetTy
->isVoidTy()) {
191 ReturnNumRegs
= DL
->getTypeStoreSize(RetTy
) / 4;
193 if (ReturnNumRegs
>= MaxNumRetRegs
)
196 ReturnTypes
.push_back(RetTy
);
199 SmallVector
<std::pair
<Argument
*, Type
*>, 4> OutArgs
;
200 for (Argument
&Arg
: F
.args()) {
201 if (Type
*Ty
= getOutArgumentType(Arg
)) {
202 LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
203 << " in function " << F
.getName() << '\n');
204 OutArgs
.push_back({&Arg
, Ty
});
211 using ReplacementVec
= SmallVector
<std::pair
<Argument
*, Value
*>, 4>;
213 DenseMap
<ReturnInst
*, ReplacementVec
> Replacements
;
215 SmallVector
<ReturnInst
*, 4> Returns
;
216 for (BasicBlock
&BB
: F
) {
217 if (ReturnInst
*RI
= dyn_cast
<ReturnInst
>(&BB
.back()))
218 Returns
.push_back(RI
);
229 // Keep retrying if we are able to successfully eliminate an argument. This
230 // helps with cases with multiple arguments which may alias, such as in a
231 // sincos implementation. If we have 2 stores to arguments, on the first
232 // attempt the MDA query will succeed for the second store but not the
233 // first. On the second iteration we've removed that out clobbering argument
234 // (by effectively moving it into another function) and will find the second
235 // argument is OK to move.
236 for (const auto &Pair
: OutArgs
) {
237 bool ThisReplaceable
= true;
238 SmallVector
<std::pair
<ReturnInst
*, StoreInst
*>, 4> ReplaceableStores
;
240 Argument
*OutArg
= Pair
.first
;
241 Type
*ArgTy
= Pair
.second
;
243 // Skip this argument if converting it will push us over the register
244 // count to return limit.
246 // TODO: This is an approximation. When legalized this could be more. We
247 // can ask TLI for exactly how many.
248 unsigned ArgNumRegs
= DL
->getTypeStoreSize(ArgTy
) / 4;
249 if (ArgNumRegs
+ ReturnNumRegs
> MaxNumRetRegs
)
252 // An argument is convertible only if all exit blocks are able to replace
254 for (ReturnInst
*RI
: Returns
) {
255 BasicBlock
*BB
= RI
->getParent();
257 MemDepResult Q
= MDA
->getPointerDependencyFrom(
258 MemoryLocation::getBeforeOrAfter(OutArg
), true, BB
->end(), BB
, RI
);
259 StoreInst
*SI
= nullptr;
261 SI
= dyn_cast
<StoreInst
>(Q
.getInst());
264 LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI
<< '\n');
265 ReplaceableStores
.emplace_back(RI
, SI
);
267 ThisReplaceable
= false;
272 if (!ThisReplaceable
)
273 continue; // Try the next argument candidate.
275 for (std::pair
<ReturnInst
*, StoreInst
*> Store
: ReplaceableStores
) {
276 Value
*ReplVal
= Store
.second
->getValueOperand();
278 auto &ValVec
= Replacements
[Store
.first
];
279 if (llvm::any_of(ValVec
,
280 [OutArg
](const std::pair
<Argument
*, Value
*> &Entry
) {
281 return Entry
.first
== OutArg
;
284 << "Saw multiple out arg stores" << *OutArg
<< '\n');
285 // It is possible to see stores to the same argument multiple times,
286 // but we expect these would have been optimized out already.
287 ThisReplaceable
= false;
291 ValVec
.emplace_back(OutArg
, ReplVal
);
292 Store
.second
->eraseFromParent();
295 if (ThisReplaceable
) {
296 ReturnTypes
.push_back(ArgTy
);
297 OutArgIndexes
.insert({OutArg
->getArgNo(), ArgTy
});
298 ++NumOutArgumentsReplaced
;
304 if (Replacements
.empty())
307 LLVMContext
&Ctx
= F
.getParent()->getContext();
308 StructType
*NewRetTy
= StructType::create(Ctx
, ReturnTypes
, F
.getName());
310 FunctionType
*NewFuncTy
= FunctionType::get(NewRetTy
,
311 F
.getFunctionType()->params(),
314 LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy
<< '\n');
316 Function
*NewFunc
= Function::Create(NewFuncTy
, Function::PrivateLinkage
,
317 F
.getName() + ".body");
318 F
.getParent()->getFunctionList().insert(F
.getIterator(), NewFunc
);
319 NewFunc
->copyAttributesFrom(&F
);
320 NewFunc
->setComdat(F
.getComdat());
322 // We want to preserve the function and param attributes, but need to strip
323 // off any return attributes, e.g. zeroext doesn't make sense with a struct.
324 NewFunc
->stealArgumentListFrom(F
);
326 AttributeMask RetAttrs
;
327 RetAttrs
.addAttribute(Attribute::SExt
);
328 RetAttrs
.addAttribute(Attribute::ZExt
);
329 RetAttrs
.addAttribute(Attribute::NoAlias
);
330 NewFunc
->removeRetAttrs(RetAttrs
);
331 // TODO: How to preserve metadata?
333 NewFunc
->setIsNewDbgInfoFormat(F
.IsNewDbgInfoFormat
);
335 // Move the body of the function into the new rewritten function, and replace
336 // this function with a stub.
337 NewFunc
->splice(NewFunc
->begin(), &F
);
339 for (std::pair
<ReturnInst
*, ReplacementVec
> &Replacement
: Replacements
) {
340 ReturnInst
*RI
= Replacement
.first
;
342 B
.SetCurrentDebugLocation(RI
->getDebugLoc());
345 Value
*NewRetVal
= PoisonValue::get(NewRetTy
);
347 Value
*RetVal
= RI
->getReturnValue();
349 NewRetVal
= B
.CreateInsertValue(NewRetVal
, RetVal
, RetIdx
++);
351 for (std::pair
<Argument
*, Value
*> ReturnPoint
: Replacement
.second
)
352 NewRetVal
= B
.CreateInsertValue(NewRetVal
, ReturnPoint
.second
, RetIdx
++);
355 RI
->setOperand(0, NewRetVal
);
357 B
.CreateRet(NewRetVal
);
358 RI
->eraseFromParent();
362 SmallVector
<Value
*, 16> StubCallArgs
;
363 for (Argument
&Arg
: F
.args()) {
364 if (OutArgIndexes
.count(Arg
.getArgNo())) {
365 // It's easier to preserve the type of the argument list. We rely on
366 // DeadArgumentElimination to take care of these.
367 StubCallArgs
.push_back(PoisonValue::get(Arg
.getType()));
369 StubCallArgs
.push_back(&Arg
);
373 BasicBlock
*StubBB
= BasicBlock::Create(Ctx
, "", &F
);
374 IRBuilder
<> B(StubBB
);
375 CallInst
*StubCall
= B
.CreateCall(NewFunc
, StubCallArgs
);
377 int RetIdx
= RetTy
->isVoidTy() ? 0 : 1;
378 for (Argument
&Arg
: F
.args()) {
379 if (!OutArgIndexes
.count(Arg
.getArgNo()))
382 Type
*EltTy
= OutArgIndexes
[Arg
.getArgNo()];
384 DL
->getValueOrABITypeAlignment(Arg
.getParamAlign(), EltTy
);
386 Value
*Val
= B
.CreateExtractValue(StubCall
, RetIdx
++);
387 B
.CreateAlignedStore(Val
, &Arg
, Align
);
390 if (!RetTy
->isVoidTy()) {
391 B
.CreateRet(B
.CreateExtractValue(StubCall
, 0));
396 // The function is now a stub we want to inline.
397 F
.addFnAttr(Attribute::AlwaysInline
);
399 ++NumOutArgumentFunctionsReplaced
;
403 FunctionPass
*llvm::createAMDGPURewriteOutArgumentsPass() {
404 return new AMDGPURewriteOutArguments();