//===-- StructRetPromotion.cpp - Promote sret arguments ------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass finds functions that return a struct (using a pointer to the struct
// as the first argument of the function, marked with the 'sret' attribute) and
// replaces them with a new function that simply returns each of the elements of
// that struct (using multiple return values).
//
// This pass works under a number of conditions:
//  1. The returned struct must not contain other structs (or other aggregates).
//  2. At each call site, the returned struct must only be used to load values
//     from.
//  3. The placeholder struct passed in at each call site is the result of an
//     alloca.
//  4. The function has local linkage, so every caller is visible and can be
//     rewritten.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "sretpromotion"
#include "llvm/Transforms/IPO.h"
#include "llvm/CallGraphSCCPass.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Support/CFG.h"
#include "llvm/Support/CallSite.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <vector>
using namespace llvm;
39 STATISTIC(NumRejectedSRETUses
, "Number of sret rejected due to unexpected uses");
40 STATISTIC(NumSRET
, "Number of sret promoted");
42 /// SRETPromotion - This pass removes sret parameter and updates
43 /// function to use multiple return value.
45 struct VISIBILITY_HIDDEN SRETPromotion
: public CallGraphSCCPass
{
46 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
47 CallGraphSCCPass::getAnalysisUsage(AU
);
50 virtual bool runOnSCC(const std::vector
<CallGraphNode
*> &SCC
);
51 static char ID
; // Pass identification, replacement for typeid
52 SRETPromotion() : CallGraphSCCPass(&ID
) {}
55 bool PromoteReturn(CallGraphNode
*CGN
);
56 bool isSafeToUpdateAllCallers(Function
*F
);
57 Function
*cloneFunctionBody(Function
*F
, const StructType
*STy
);
58 void updateCallSites(Function
*F
, Function
*NF
);
59 bool nestedStructType(const StructType
*STy
);
63 char SRETPromotion::ID
= 0;
64 static RegisterPass
<SRETPromotion
>
65 X("sretpromotion", "Promote sret arguments to multiple ret values");
67 Pass
*llvm::createStructRetPromotionPass() {
68 return new SRETPromotion();
71 bool SRETPromotion::runOnSCC(const std::vector
<CallGraphNode
*> &SCC
) {
74 for (unsigned i
= 0, e
= SCC
.size(); i
!= e
; ++i
)
75 Changed
|= PromoteReturn(SCC
[i
]);
80 /// PromoteReturn - This method promotes function that uses StructRet paramater
81 /// into a function that uses mulitple return value.
82 bool SRETPromotion::PromoteReturn(CallGraphNode
*CGN
) {
83 Function
*F
= CGN
->getFunction();
85 if (!F
|| F
->isDeclaration() || !F
->hasLocalLinkage())
88 // Make sure that function returns struct.
89 if (F
->arg_size() == 0 || !F
->hasStructRetAttr() || F
->doesNotReturn())
92 DOUT
<< "SretPromotion: Looking at sret function " << F
->getNameStart() << "\n";
94 assert (F
->getReturnType() == Type::VoidTy
&& "Invalid function return type");
95 Function::arg_iterator AI
= F
->arg_begin();
96 const llvm::PointerType
*FArgType
= dyn_cast
<PointerType
>(AI
->getType());
97 assert (FArgType
&& "Invalid sret parameter type");
98 const llvm::StructType
*STy
=
99 dyn_cast
<StructType
>(FArgType
->getElementType());
100 assert (STy
&& "Invalid sret parameter element type");
102 // Check if it is ok to perform this promotion.
103 if (isSafeToUpdateAllCallers(F
) == false) {
104 DOUT
<< "SretPromotion: Not all callers can be updated\n";
105 NumRejectedSRETUses
++;
109 DOUT
<< "SretPromotion: sret argument will be promoted\n";
111 // [1] Replace use of sret parameter
112 AllocaInst
*TheAlloca
= new AllocaInst (STy
, NULL
, "mrv",
113 F
->getEntryBlock().begin());
114 Value
*NFirstArg
= F
->arg_begin();
115 NFirstArg
->replaceAllUsesWith(TheAlloca
);
117 // [2] Find and replace ret instructions
118 for (Function::iterator FI
= F
->begin(), FE
= F
->end(); FI
!= FE
; ++FI
)
119 for(BasicBlock::iterator BI
= FI
->begin(), BE
= FI
->end(); BI
!= BE
; ) {
122 if (isa
<ReturnInst
>(I
)) {
123 Value
*NV
= new LoadInst(TheAlloca
, "mrv.ld", I
);
124 ReturnInst
*NR
= ReturnInst::Create(NV
, I
);
125 I
->replaceAllUsesWith(NR
);
126 I
->eraseFromParent();
130 // [3] Create the new function body and insert it into the module.
131 Function
*NF
= cloneFunctionBody(F
, STy
);
133 // [4] Update all call sites to use new function
134 updateCallSites(F
, NF
);
136 F
->eraseFromParent();
137 getAnalysis
<CallGraph
>().changeFunction(F
, NF
);
141 // Check if it is ok to perform this promotion.
142 bool SRETPromotion::isSafeToUpdateAllCallers(Function
*F
) {
145 // No users. OK to modify signature.
148 for (Value::use_iterator FnUseI
= F
->use_begin(), FnUseE
= F
->use_end();
149 FnUseI
!= FnUseE
; ++FnUseI
) {
150 // The function is passed in as an argument to (possibly) another function,
151 // we can't change it!
152 CallSite CS
= CallSite::get(*FnUseI
);
153 Instruction
*Call
= CS
.getInstruction();
154 // The function is used by something else than a call or invoke instruction,
155 // we can't change it!
156 if (!Call
|| !CS
.isCallee(FnUseI
))
158 CallSite::arg_iterator AI
= CS
.arg_begin();
159 Value
*FirstArg
= *AI
;
161 if (!isa
<AllocaInst
>(FirstArg
))
164 // Check FirstArg's users.
165 for (Value::use_iterator ArgI
= FirstArg
->use_begin(),
166 ArgE
= FirstArg
->use_end(); ArgI
!= ArgE
; ++ArgI
) {
168 // If FirstArg user is a CallInst that does not correspond to current
169 // call site then this function F is not suitable for sret promotion.
170 if (CallInst
*CI
= dyn_cast
<CallInst
>(ArgI
)) {
174 // If FirstArg user is a GEP whose all users are not LoadInst then
175 // this function F is not suitable for sret promotion.
176 else if (GetElementPtrInst
*GEP
= dyn_cast
<GetElementPtrInst
>(ArgI
)) {
177 // TODO : Use dom info and insert PHINodes to collect get results
178 // from multiple call sites for this GEP.
179 if (GEP
->getParent() != Call
->getParent())
181 for (Value::use_iterator GEPI
= GEP
->use_begin(), GEPE
= GEP
->use_end();
182 GEPI
!= GEPE
; ++GEPI
)
183 if (!isa
<LoadInst
>(GEPI
))
186 // Any other FirstArg users make this function unsuitable for sret
196 /// cloneFunctionBody - Create a new function based on F and
197 /// insert it into module. Remove first argument. Use STy as
198 /// the return type for new function.
199 Function
*SRETPromotion::cloneFunctionBody(Function
*F
,
200 const StructType
*STy
) {
202 const FunctionType
*FTy
= F
->getFunctionType();
203 std::vector
<const Type
*> Params
;
205 // Attributes - Keep track of the parameter attributes for the arguments.
206 SmallVector
<AttributeWithIndex
, 8> AttributesVec
;
207 const AttrListPtr
&PAL
= F
->getAttributes();
209 // Add any return attributes.
210 if (Attributes attrs
= PAL
.getRetAttributes())
211 AttributesVec
.push_back(AttributeWithIndex::get(0, attrs
));
213 // Skip first argument.
214 Function::arg_iterator I
= F
->arg_begin(), E
= F
->arg_end();
216 // 0th parameter attribute is reserved for return type.
217 // 1th parameter attribute is for first 1st sret argument.
218 unsigned ParamIndex
= 2;
220 Params
.push_back(I
->getType());
221 if (Attributes Attrs
= PAL
.getParamAttributes(ParamIndex
))
222 AttributesVec
.push_back(AttributeWithIndex::get(ParamIndex
- 1, Attrs
));
227 // Add any fn attributes.
228 if (Attributes attrs
= PAL
.getFnAttributes())
229 AttributesVec
.push_back(AttributeWithIndex::get(~0, attrs
));
232 FunctionType
*NFTy
= FunctionType::get(STy
, Params
, FTy
->isVarArg());
233 Function
*NF
= Function::Create(NFTy
, F
->getLinkage());
235 NF
->copyAttributesFrom(F
);
236 NF
->setAttributes(AttrListPtr::get(AttributesVec
.begin(), AttributesVec
.end()));
237 F
->getParent()->getFunctionList().insert(F
, NF
);
238 NF
->getBasicBlockList().splice(NF
->begin(), F
->getBasicBlockList());
243 Function::arg_iterator NI
= NF
->arg_begin();
246 I
->replaceAllUsesWith(NI
);
255 /// updateCallSites - Update all sites that call F to use NF.
256 void SRETPromotion::updateCallSites(Function
*F
, Function
*NF
) {
257 CallGraph
&CG
= getAnalysis
<CallGraph
>();
258 SmallVector
<Value
*, 16> Args
;
260 // Attributes - Keep track of the parameter attributes for the arguments.
261 SmallVector
<AttributeWithIndex
, 8> ArgAttrsVec
;
263 while (!F
->use_empty()) {
264 CallSite CS
= CallSite::get(*F
->use_begin());
265 Instruction
*Call
= CS
.getInstruction();
267 const AttrListPtr
&PAL
= F
->getAttributes();
268 // Add any return attributes.
269 if (Attributes attrs
= PAL
.getRetAttributes())
270 ArgAttrsVec
.push_back(AttributeWithIndex::get(0, attrs
));
272 // Copy arguments, however skip first one.
273 CallSite::arg_iterator AI
= CS
.arg_begin(), AE
= CS
.arg_end();
274 Value
*FirstCArg
= *AI
;
276 // 0th parameter attribute is reserved for return type.
277 // 1th parameter attribute is for first 1st sret argument.
278 unsigned ParamIndex
= 2;
281 if (Attributes Attrs
= PAL
.getParamAttributes(ParamIndex
))
282 ArgAttrsVec
.push_back(AttributeWithIndex::get(ParamIndex
- 1, Attrs
));
287 // Add any function attributes.
288 if (Attributes attrs
= PAL
.getFnAttributes())
289 ArgAttrsVec
.push_back(AttributeWithIndex::get(~0, attrs
));
291 AttrListPtr NewPAL
= AttrListPtr::get(ArgAttrsVec
.begin(), ArgAttrsVec
.end());
293 // Build new call instruction.
295 if (InvokeInst
*II
= dyn_cast
<InvokeInst
>(Call
)) {
296 New
= InvokeInst::Create(NF
, II
->getNormalDest(), II
->getUnwindDest(),
297 Args
.begin(), Args
.end(), "", Call
);
298 cast
<InvokeInst
>(New
)->setCallingConv(CS
.getCallingConv());
299 cast
<InvokeInst
>(New
)->setAttributes(NewPAL
);
301 New
= CallInst::Create(NF
, Args
.begin(), Args
.end(), "", Call
);
302 cast
<CallInst
>(New
)->setCallingConv(CS
.getCallingConv());
303 cast
<CallInst
>(New
)->setAttributes(NewPAL
);
304 if (cast
<CallInst
>(Call
)->isTailCall())
305 cast
<CallInst
>(New
)->setTailCall();
311 // Update the callgraph to know that the callsite has been transformed.
312 CG
[Call
->getParent()->getParent()]->replaceCallSite(Call
, New
);
314 // Update all users of sret parameter to extract value using extractvalue.
315 for (Value::use_iterator UI
= FirstCArg
->use_begin(),
316 UE
= FirstCArg
->use_end(); UI
!= UE
; ) {
318 CallInst
*C2
= dyn_cast
<CallInst
>(U2
);
319 if (C2
&& (C2
== Call
))
321 else if (GetElementPtrInst
*UGEP
= dyn_cast
<GetElementPtrInst
>(U2
)) {
322 ConstantInt
*Idx
= dyn_cast
<ConstantInt
>(UGEP
->getOperand(2));
323 assert (Idx
&& "Unexpected getelementptr index!");
324 Value
*GR
= ExtractValueInst::Create(New
, Idx
->getZExtValue(),
326 while(!UGEP
->use_empty()) {
327 // isSafeToUpdateAllCallers has checked that all GEP uses are
329 LoadInst
*L
= cast
<LoadInst
>(*UGEP
->use_begin());
330 L
->replaceAllUsesWith(GR
);
331 L
->eraseFromParent();
333 UGEP
->eraseFromParent();
335 else assert( 0 && "Unexpected sret parameter use");
337 Call
->eraseFromParent();
341 /// nestedStructType - Return true if STy includes any
342 /// other aggregate types
343 bool SRETPromotion::nestedStructType(const StructType
*STy
) {
344 unsigned Num
= STy
->getNumElements();
345 for (unsigned i
= 0; i
< Num
; i
++) {
346 const Type
*Ty
= STy
->getElementType(i
);
347 if (!Ty
->isSingleValueType() && Ty
!= Type::VoidTy
)