1 //===-- StructRetPromotion.cpp - Promote sret arguments ------------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This pass finds functions that return a struct (using a pointer to the struct
11 // as the first argument of the function, marked with the 'sret' attribute) and
12 // replaces them with a new function that simply returns each of the elements of
13 // that struct (using multiple return values).
15 // This pass works under a number of conditions:
16 // 1. The returned struct must not contain other structs
17 // 2. The returned struct must only be used to load values from
18 // 3. The placeholder struct passed in is the result of an alloca
20 //===----------------------------------------------------------------------===//
22 #define DEBUG_TYPE "sretpromotion"
23 #include "llvm/Transforms/IPO.h"
24 #include "llvm/Constants.h"
25 #include "llvm/DerivedTypes.h"
26 #include "llvm/LLVMContext.h"
27 #include "llvm/Module.h"
28 #include "llvm/CallGraphSCCPass.h"
29 #include "llvm/Instructions.h"
30 #include "llvm/Analysis/CallGraph.h"
31 #include "llvm/Support/CallSite.h"
32 #include "llvm/Support/CFG.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/ADT/Statistic.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/Statistic.h"
37 #include "llvm/Support/Compiler.h"
38 #include "llvm/Support/raw_ostream.h"
41 STATISTIC(NumRejectedSRETUses
, "Number of sret rejected due to unexpected uses");
42 STATISTIC(NumSRET
, "Number of sret promoted");
44 /// SRETPromotion - This pass removes sret parameter and updates
45 /// function to use multiple return value.
47 struct VISIBILITY_HIDDEN SRETPromotion
: public CallGraphSCCPass
{
48 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
49 CallGraphSCCPass::getAnalysisUsage(AU
);
52 virtual bool runOnSCC(std::vector
<CallGraphNode
*> &SCC
);
53 static char ID
; // Pass identification, replacement for typeid
54 SRETPromotion() : CallGraphSCCPass(&ID
) {}
57 CallGraphNode
*PromoteReturn(CallGraphNode
*CGN
);
58 bool isSafeToUpdateAllCallers(Function
*F
);
59 Function
*cloneFunctionBody(Function
*F
, const StructType
*STy
);
60 CallGraphNode
*updateCallSites(Function
*F
, Function
*NF
);
61 bool nestedStructType(const StructType
*STy
);
65 char SRETPromotion::ID
= 0;
66 static RegisterPass
<SRETPromotion
>
67 X("sretpromotion", "Promote sret arguments to multiple ret values");
69 Pass
*llvm::createStructRetPromotionPass() {
70 return new SRETPromotion();
73 bool SRETPromotion::runOnSCC(std::vector
<CallGraphNode
*> &SCC
) {
76 for (unsigned i
= 0, e
= SCC
.size(); i
!= e
; ++i
)
77 if (CallGraphNode
*NewNode
= PromoteReturn(SCC
[i
])) {
85 /// PromoteReturn - This method promotes function that uses StructRet paramater
86 /// into a function that uses multiple return values.
87 CallGraphNode
*SRETPromotion::PromoteReturn(CallGraphNode
*CGN
) {
88 Function
*F
= CGN
->getFunction();
90 if (!F
|| F
->isDeclaration() || !F
->hasLocalLinkage())
93 // Make sure that function returns struct.
94 if (F
->arg_size() == 0 || !F
->hasStructRetAttr() || F
->doesNotReturn())
97 DEBUG(errs() << "SretPromotion: Looking at sret function "
98 << F
->getName() << "\n");
100 assert(F
->getReturnType() == Type::getVoidTy(F
->getContext()) &&
101 "Invalid function return type");
102 Function::arg_iterator AI
= F
->arg_begin();
103 const llvm::PointerType
*FArgType
= dyn_cast
<PointerType
>(AI
->getType());
104 assert(FArgType
&& "Invalid sret parameter type");
105 const llvm::StructType
*STy
=
106 dyn_cast
<StructType
>(FArgType
->getElementType());
107 assert(STy
&& "Invalid sret parameter element type");
109 // Check if it is ok to perform this promotion.
110 if (isSafeToUpdateAllCallers(F
) == false) {
111 DEBUG(errs() << "SretPromotion: Not all callers can be updated\n");
112 NumRejectedSRETUses
++;
116 DEBUG(errs() << "SretPromotion: sret argument will be promoted\n");
118 // [1] Replace use of sret parameter
119 AllocaInst
*TheAlloca
= new AllocaInst(STy
, NULL
, "mrv",
120 F
->getEntryBlock().begin());
121 Value
*NFirstArg
= F
->arg_begin();
122 NFirstArg
->replaceAllUsesWith(TheAlloca
);
124 // [2] Find and replace ret instructions
125 for (Function::iterator FI
= F
->begin(), FE
= F
->end(); FI
!= FE
; ++FI
)
126 for(BasicBlock::iterator BI
= FI
->begin(), BE
= FI
->end(); BI
!= BE
; ) {
129 if (isa
<ReturnInst
>(I
)) {
130 Value
*NV
= new LoadInst(TheAlloca
, "mrv.ld", I
);
131 ReturnInst
*NR
= ReturnInst::Create(F
->getContext(), NV
, I
);
132 I
->replaceAllUsesWith(NR
);
133 I
->eraseFromParent();
137 // [3] Create the new function body and insert it into the module.
138 Function
*NF
= cloneFunctionBody(F
, STy
);
140 // [4] Update all call sites to use new function
141 CallGraphNode
*NF_CFN
= updateCallSites(F
, NF
);
143 CallGraph
&CG
= getAnalysis
<CallGraph
>();
144 NF_CFN
->stealCalledFunctionsFrom(CG
[F
]);
146 delete CG
.removeFunctionFromModule(F
);
150 // Check if it is ok to perform this promotion.
151 bool SRETPromotion::isSafeToUpdateAllCallers(Function
*F
) {
154 // No users. OK to modify signature.
157 for (Value::use_iterator FnUseI
= F
->use_begin(), FnUseE
= F
->use_end();
158 FnUseI
!= FnUseE
; ++FnUseI
) {
159 // The function is passed in as an argument to (possibly) another function,
160 // we can't change it!
161 CallSite CS
= CallSite::get(*FnUseI
);
162 Instruction
*Call
= CS
.getInstruction();
163 // The function is used by something else than a call or invoke instruction,
164 // we can't change it!
165 if (!Call
|| !CS
.isCallee(FnUseI
))
167 CallSite::arg_iterator AI
= CS
.arg_begin();
168 Value
*FirstArg
= *AI
;
170 if (!isa
<AllocaInst
>(FirstArg
))
173 // Check FirstArg's users.
174 for (Value::use_iterator ArgI
= FirstArg
->use_begin(),
175 ArgE
= FirstArg
->use_end(); ArgI
!= ArgE
; ++ArgI
) {
177 // If FirstArg user is a CallInst that does not correspond to current
178 // call site then this function F is not suitable for sret promotion.
179 if (CallInst
*CI
= dyn_cast
<CallInst
>(ArgI
)) {
183 // If FirstArg user is a GEP whose all users are not LoadInst then
184 // this function F is not suitable for sret promotion.
185 else if (GetElementPtrInst
*GEP
= dyn_cast
<GetElementPtrInst
>(ArgI
)) {
186 // TODO : Use dom info and insert PHINodes to collect get results
187 // from multiple call sites for this GEP.
188 if (GEP
->getParent() != Call
->getParent())
190 for (Value::use_iterator GEPI
= GEP
->use_begin(), GEPE
= GEP
->use_end();
191 GEPI
!= GEPE
; ++GEPI
)
192 if (!isa
<LoadInst
>(GEPI
))
195 // Any other FirstArg users make this function unsuitable for sret
205 /// cloneFunctionBody - Create a new function based on F and
206 /// insert it into module. Remove first argument. Use STy as
207 /// the return type for new function.
208 Function
*SRETPromotion::cloneFunctionBody(Function
*F
,
209 const StructType
*STy
) {
211 const FunctionType
*FTy
= F
->getFunctionType();
212 std::vector
<const Type
*> Params
;
214 // Attributes - Keep track of the parameter attributes for the arguments.
215 SmallVector
<AttributeWithIndex
, 8> AttributesVec
;
216 const AttrListPtr
&PAL
= F
->getAttributes();
218 // Add any return attributes.
219 if (Attributes attrs
= PAL
.getRetAttributes())
220 AttributesVec
.push_back(AttributeWithIndex::get(0, attrs
));
222 // Skip first argument.
223 Function::arg_iterator I
= F
->arg_begin(), E
= F
->arg_end();
225 // 0th parameter attribute is reserved for return type.
226 // 1th parameter attribute is for first 1st sret argument.
227 unsigned ParamIndex
= 2;
229 Params
.push_back(I
->getType());
230 if (Attributes Attrs
= PAL
.getParamAttributes(ParamIndex
))
231 AttributesVec
.push_back(AttributeWithIndex::get(ParamIndex
- 1, Attrs
));
236 // Add any fn attributes.
237 if (Attributes attrs
= PAL
.getFnAttributes())
238 AttributesVec
.push_back(AttributeWithIndex::get(~0, attrs
));
241 FunctionType
*NFTy
= FunctionType::get(STy
, Params
, FTy
->isVarArg());
242 Function
*NF
= Function::Create(NFTy
, F
->getLinkage());
244 NF
->copyAttributesFrom(F
);
245 NF
->setAttributes(AttrListPtr::get(AttributesVec
.begin(), AttributesVec
.end()));
246 F
->getParent()->getFunctionList().insert(F
, NF
);
247 NF
->getBasicBlockList().splice(NF
->begin(), F
->getBasicBlockList());
252 Function::arg_iterator NI
= NF
->arg_begin();
255 I
->replaceAllUsesWith(NI
);
264 /// updateCallSites - Update all sites that call F to use NF.
265 CallGraphNode
*SRETPromotion::updateCallSites(Function
*F
, Function
*NF
) {
266 CallGraph
&CG
= getAnalysis
<CallGraph
>();
267 SmallVector
<Value
*, 16> Args
;
269 // Attributes - Keep track of the parameter attributes for the arguments.
270 SmallVector
<AttributeWithIndex
, 8> ArgAttrsVec
;
272 // Get a new callgraph node for NF.
273 CallGraphNode
*NF_CGN
= CG
.getOrInsertFunction(NF
);
275 while (!F
->use_empty()) {
276 CallSite CS
= CallSite::get(*F
->use_begin());
277 Instruction
*Call
= CS
.getInstruction();
279 const AttrListPtr
&PAL
= F
->getAttributes();
280 // Add any return attributes.
281 if (Attributes attrs
= PAL
.getRetAttributes())
282 ArgAttrsVec
.push_back(AttributeWithIndex::get(0, attrs
));
284 // Copy arguments, however skip first one.
285 CallSite::arg_iterator AI
= CS
.arg_begin(), AE
= CS
.arg_end();
286 Value
*FirstCArg
= *AI
;
288 // 0th parameter attribute is reserved for return type.
289 // 1th parameter attribute is for first 1st sret argument.
290 unsigned ParamIndex
= 2;
293 if (Attributes Attrs
= PAL
.getParamAttributes(ParamIndex
))
294 ArgAttrsVec
.push_back(AttributeWithIndex::get(ParamIndex
- 1, Attrs
));
299 // Add any function attributes.
300 if (Attributes attrs
= PAL
.getFnAttributes())
301 ArgAttrsVec
.push_back(AttributeWithIndex::get(~0, attrs
));
303 AttrListPtr NewPAL
= AttrListPtr::get(ArgAttrsVec
.begin(), ArgAttrsVec
.end());
305 // Build new call instruction.
307 if (InvokeInst
*II
= dyn_cast
<InvokeInst
>(Call
)) {
308 New
= InvokeInst::Create(NF
, II
->getNormalDest(), II
->getUnwindDest(),
309 Args
.begin(), Args
.end(), "", Call
);
310 cast
<InvokeInst
>(New
)->setCallingConv(CS
.getCallingConv());
311 cast
<InvokeInst
>(New
)->setAttributes(NewPAL
);
313 New
= CallInst::Create(NF
, Args
.begin(), Args
.end(), "", Call
);
314 cast
<CallInst
>(New
)->setCallingConv(CS
.getCallingConv());
315 cast
<CallInst
>(New
)->setAttributes(NewPAL
);
316 if (cast
<CallInst
>(Call
)->isTailCall())
317 cast
<CallInst
>(New
)->setTailCall();
323 // Update the callgraph to know that the callsite has been transformed.
324 CallGraphNode
*CalleeNode
= CG
[Call
->getParent()->getParent()];
325 CalleeNode
->removeCallEdgeFor(Call
);
326 CalleeNode
->addCalledFunction(New
, NF_CGN
);
328 // Update all users of sret parameter to extract value using extractvalue.
329 for (Value::use_iterator UI
= FirstCArg
->use_begin(),
330 UE
= FirstCArg
->use_end(); UI
!= UE
; ) {
332 CallInst
*C2
= dyn_cast
<CallInst
>(U2
);
333 if (C2
&& (C2
== Call
))
336 GetElementPtrInst
*UGEP
= cast
<GetElementPtrInst
>(U2
);
337 ConstantInt
*Idx
= cast
<ConstantInt
>(UGEP
->getOperand(2));
338 Value
*GR
= ExtractValueInst::Create(New
, Idx
->getZExtValue(),
340 while(!UGEP
->use_empty()) {
341 // isSafeToUpdateAllCallers has checked that all GEP uses are
343 LoadInst
*L
= cast
<LoadInst
>(*UGEP
->use_begin());
344 L
->replaceAllUsesWith(GR
);
345 L
->eraseFromParent();
347 UGEP
->eraseFromParent();
350 Call
->eraseFromParent();
356 /// nestedStructType - Return true if STy includes any
357 /// other aggregate types
358 bool SRETPromotion::nestedStructType(const StructType
*STy
) {
359 unsigned Num
= STy
->getNumElements();
360 for (unsigned i
= 0; i
< Num
; i
++) {
361 const Type
*Ty
= STy
->getElementType(i
);
362 if (!Ty
->isSingleValueType() && Ty
!= Type::getVoidTy(STy
->getContext()))