//===- MergeFunctions.cpp - Merge identical functions ---------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass looks for equivalent functions that are mergeable and folds them.
//
// A hash is computed from each function, based on its signature (varargs-ness
// and the type IDs of its return and parameter types), its calling
// convention, whether it has a GC, and its number of basic blocks.
//
// Once all hashes are computed, we perform an expensive equality comparison
// on each function pair. This takes n^2/2 comparisons per bucket, so it's
// important that the hash function be high quality. The equality comparison
// iterates through each instruction in each basic block.
//
// When a match is found, the functions are folded. We can only fold two
// functions when we know that the definition of one of them is not
// overridable.
//
//===----------------------------------------------------------------------===//
//
// Future work:
//
// * fold vector<T*>::push_back and vector<S*>::push_back.
//
// These two functions have different types, but in a way that doesn't matter
// to us. As long as we never see an S or T itself, using S* and S** is the
// same as using a T* and T**.
//
// * virtual functions.
//
// Many functions have their address taken by the virtual function table for
// the object they belong to. However, as long as it's only used for a lookup
// and call, this is irrelevant, and we'd like to fold such implementations.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "mergefunc"
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Constants.h"
#include "llvm/InlineAsm.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/CallSite.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <map>
#include <vector>
using namespace llvm;
62 STATISTIC(NumFunctionsMerged
, "Number of functions merged");
65 struct VISIBILITY_HIDDEN MergeFunctions
: public ModulePass
{
66 static char ID
; // Pass identification, replacement for typeid
67 MergeFunctions() : ModulePass(&ID
) {}
69 bool runOnModule(Module
&M
);
73 char MergeFunctions::ID
= 0;
74 static RegisterPass
<MergeFunctions
>
75 X("mergefunc", "Merge Functions");
77 ModulePass
*llvm::createMergeFunctionsPass() {
78 return new MergeFunctions();
// ===----------------------------------------------------------------------===
// Comparison of functions
// ===----------------------------------------------------------------------===
85 static unsigned long hash(const Function
*F
) {
86 const FunctionType
*FTy
= F
->getFunctionType();
89 ID
.AddInteger(F
->size());
90 ID
.AddInteger(F
->getCallingConv());
91 ID
.AddBoolean(F
->hasGC());
92 ID
.AddBoolean(FTy
->isVarArg());
93 ID
.AddInteger(FTy
->getReturnType()->getTypeID());
94 for (unsigned i
= 0, e
= FTy
->getNumParams(); i
!= e
; ++i
)
95 ID
.AddInteger(FTy
->getParamType(i
)->getTypeID());
96 return ID
.ComputeHash();
99 /// IgnoreBitcasts - given a bitcast, returns the first non-bitcast found by
100 /// walking the chain of cast operands. Otherwise, returns the argument.
101 static Value
* IgnoreBitcasts(Value
*V
) {
102 while (BitCastInst
*BC
= dyn_cast
<BitCastInst
>(V
))
103 V
= BC
->getOperand(0);
108 /// isEquivalentType - any two pointers are equivalent. Otherwise, standard
109 /// type equivalence rules apply.
110 static bool isEquivalentType(const Type
*Ty1
, const Type
*Ty2
) {
113 if (Ty1
->getTypeID() != Ty2
->getTypeID())
116 switch(Ty1
->getTypeID()) {
118 case Type::FloatTyID
:
119 case Type::DoubleTyID
:
120 case Type::X86_FP80TyID
:
121 case Type::FP128TyID
:
122 case Type::PPC_FP128TyID
:
123 case Type::LabelTyID
:
124 case Type::MetadataTyID
:
127 case Type::IntegerTyID
:
128 case Type::OpaqueTyID
:
129 // Ty1 == Ty2 would have returned true earlier.
133 llvm_unreachable("Unknown type!");
136 case Type::PointerTyID
: {
137 const PointerType
*PTy1
= cast
<PointerType
>(Ty1
);
138 const PointerType
*PTy2
= cast
<PointerType
>(Ty2
);
139 return PTy1
->getAddressSpace() == PTy2
->getAddressSpace();
142 case Type::StructTyID
: {
143 const StructType
*STy1
= cast
<StructType
>(Ty1
);
144 const StructType
*STy2
= cast
<StructType
>(Ty2
);
145 if (STy1
->getNumElements() != STy2
->getNumElements())
148 if (STy1
->isPacked() != STy2
->isPacked())
151 for (unsigned i
= 0, e
= STy1
->getNumElements(); i
!= e
; ++i
) {
152 if (!isEquivalentType(STy1
->getElementType(i
), STy2
->getElementType(i
)))
158 case Type::FunctionTyID
: {
159 const FunctionType
*FTy1
= cast
<FunctionType
>(Ty1
);
160 const FunctionType
*FTy2
= cast
<FunctionType
>(Ty2
);
161 if (FTy1
->getNumParams() != FTy2
->getNumParams() ||
162 FTy1
->isVarArg() != FTy2
->isVarArg())
165 if (!isEquivalentType(FTy1
->getReturnType(), FTy2
->getReturnType()))
168 for (unsigned i
= 0, e
= FTy1
->getNumParams(); i
!= e
; ++i
) {
169 if (!isEquivalentType(FTy1
->getParamType(i
), FTy2
->getParamType(i
)))
175 case Type::ArrayTyID
:
176 case Type::VectorTyID
: {
177 const SequentialType
*STy1
= cast
<SequentialType
>(Ty1
);
178 const SequentialType
*STy2
= cast
<SequentialType
>(Ty2
);
179 return isEquivalentType(STy1
->getElementType(), STy2
->getElementType());
184 /// isEquivalentOperation - determine whether the two operations are the same
185 /// except that pointer-to-A and pointer-to-B are equivalent. This should be
186 /// kept in sync with Instruction::isSameOperationAs.
188 isEquivalentOperation(const Instruction
*I1
, const Instruction
*I2
) {
189 if (I1
->getOpcode() != I2
->getOpcode() ||
190 I1
->getNumOperands() != I2
->getNumOperands() ||
191 !isEquivalentType(I1
->getType(), I2
->getType()) ||
192 !I1
->hasSameSubclassOptionalData(I2
))
195 // We have two instructions of identical opcode and #operands. Check to see
196 // if all operands are the same type
197 for (unsigned i
= 0, e
= I1
->getNumOperands(); i
!= e
; ++i
)
198 if (!isEquivalentType(I1
->getOperand(i
)->getType(),
199 I2
->getOperand(i
)->getType()))
202 // Check special state that is a part of some instructions.
203 if (const LoadInst
*LI
= dyn_cast
<LoadInst
>(I1
))
204 return LI
->isVolatile() == cast
<LoadInst
>(I2
)->isVolatile() &&
205 LI
->getAlignment() == cast
<LoadInst
>(I2
)->getAlignment();
206 if (const StoreInst
*SI
= dyn_cast
<StoreInst
>(I1
))
207 return SI
->isVolatile() == cast
<StoreInst
>(I2
)->isVolatile() &&
208 SI
->getAlignment() == cast
<StoreInst
>(I2
)->getAlignment();
209 if (const CmpInst
*CI
= dyn_cast
<CmpInst
>(I1
))
210 return CI
->getPredicate() == cast
<CmpInst
>(I2
)->getPredicate();
211 if (const CallInst
*CI
= dyn_cast
<CallInst
>(I1
))
212 return CI
->isTailCall() == cast
<CallInst
>(I2
)->isTailCall() &&
213 CI
->getCallingConv() == cast
<CallInst
>(I2
)->getCallingConv() &&
214 CI
->getAttributes().getRawPointer() ==
215 cast
<CallInst
>(I2
)->getAttributes().getRawPointer();
216 if (const InvokeInst
*CI
= dyn_cast
<InvokeInst
>(I1
))
217 return CI
->getCallingConv() == cast
<InvokeInst
>(I2
)->getCallingConv() &&
218 CI
->getAttributes().getRawPointer() ==
219 cast
<InvokeInst
>(I2
)->getAttributes().getRawPointer();
220 if (const InsertValueInst
*IVI
= dyn_cast
<InsertValueInst
>(I1
)) {
221 if (IVI
->getNumIndices() != cast
<InsertValueInst
>(I2
)->getNumIndices())
223 for (unsigned i
= 0, e
= IVI
->getNumIndices(); i
!= e
; ++i
)
224 if (IVI
->idx_begin()[i
] != cast
<InsertValueInst
>(I2
)->idx_begin()[i
])
228 if (const ExtractValueInst
*EVI
= dyn_cast
<ExtractValueInst
>(I1
)) {
229 if (EVI
->getNumIndices() != cast
<ExtractValueInst
>(I2
)->getNumIndices())
231 for (unsigned i
= 0, e
= EVI
->getNumIndices(); i
!= e
; ++i
)
232 if (EVI
->idx_begin()[i
] != cast
<ExtractValueInst
>(I2
)->idx_begin()[i
])
240 static bool compare(const Value
*V
, const Value
*U
) {
241 assert(!isa
<BasicBlock
>(V
) && !isa
<BasicBlock
>(U
) &&
242 "Must not compare basic blocks.");
244 assert(isEquivalentType(V
->getType(), U
->getType()) &&
245 "Two of the same operation have operands of different type.");
247 // TODO: If the constant is an expression of F, we should accept that it's
248 // equal to the same expression in terms of G.
249 if (isa
<Constant
>(V
))
252 // The caller has ensured that ValueMap[V] != U. Since Arguments are
253 // pre-loaded into the ValueMap, and Instructions are added as we go, we know
254 // that this can only be a mis-match.
255 if (isa
<Instruction
>(V
) || isa
<Argument
>(V
))
258 if (isa
<InlineAsm
>(V
) && isa
<InlineAsm
>(U
)) {
259 const InlineAsm
*IAF
= cast
<InlineAsm
>(V
);
260 const InlineAsm
*IAG
= cast
<InlineAsm
>(U
);
261 return IAF
->getAsmString() == IAG
->getAsmString() &&
262 IAF
->getConstraintString() == IAG
->getConstraintString();
268 static bool equals(const BasicBlock
*BB1
, const BasicBlock
*BB2
,
269 DenseMap
<const Value
*, const Value
*> &ValueMap
,
270 DenseMap
<const Value
*, const Value
*> &SpeculationMap
) {
271 // Speculatively add it anyways. If it's false, we'll notice a difference
272 // later, and this won't matter.
275 BasicBlock::const_iterator FI
= BB1
->begin(), FE
= BB1
->end();
276 BasicBlock::const_iterator GI
= BB2
->begin(), GE
= BB2
->end();
279 if (isa
<BitCastInst
>(FI
)) {
283 if (isa
<BitCastInst
>(GI
)) {
288 if (!isEquivalentOperation(FI
, GI
))
291 if (isa
<GetElementPtrInst
>(FI
)) {
292 const GetElementPtrInst
*GEPF
= cast
<GetElementPtrInst
>(FI
);
293 const GetElementPtrInst
*GEPG
= cast
<GetElementPtrInst
>(GI
);
294 if (GEPF
->hasAllZeroIndices() && GEPG
->hasAllZeroIndices()) {
295 // It's effectively a bitcast.
300 // TODO: we only really care about the elements before the index
301 if (FI
->getOperand(0)->getType() != GI
->getOperand(0)->getType())
305 if (ValueMap
[FI
] == GI
) {
310 if (ValueMap
[FI
] != NULL
)
313 for (unsigned i
= 0, e
= FI
->getNumOperands(); i
!= e
; ++i
) {
314 Value
*OpF
= IgnoreBitcasts(FI
->getOperand(i
));
315 Value
*OpG
= IgnoreBitcasts(GI
->getOperand(i
));
317 if (ValueMap
[OpF
] == OpG
)
320 if (ValueMap
[OpF
] != NULL
)
323 if (OpF
->getValueID() != OpG
->getValueID() ||
324 !isEquivalentType(OpF
->getType(), OpG
->getType()))
327 if (isa
<PHINode
>(FI
)) {
328 if (SpeculationMap
[OpF
] == NULL
)
329 SpeculationMap
[OpF
] = OpG
;
330 else if (SpeculationMap
[OpF
] != OpG
)
333 } else if (isa
<BasicBlock
>(OpF
)) {
334 assert(isa
<TerminatorInst
>(FI
) &&
335 "BasicBlock referenced by non-Terminator non-PHI");
336 // This call changes the ValueMap, hence we can't use
337 // Value *& = ValueMap[...]
338 if (!equals(cast
<BasicBlock
>(OpF
), cast
<BasicBlock
>(OpG
), ValueMap
,
342 if (!compare(OpF
, OpG
))
351 } while (FI
!= FE
&& GI
!= GE
);
353 return FI
== FE
&& GI
== GE
;
356 static bool equals(const Function
*F
, const Function
*G
) {
357 // We need to recheck everything, but check the things that weren't included
358 // in the hash first.
360 if (F
->getAttributes() != G
->getAttributes())
363 if (F
->hasGC() != G
->hasGC())
366 if (F
->hasGC() && F
->getGC() != G
->getGC())
369 if (F
->hasSection() != G
->hasSection())
372 if (F
->hasSection() && F
->getSection() != G
->getSection())
375 if (F
->isVarArg() != G
->isVarArg())
378 // TODO: if it's internal and only used in direct calls, we could handle this
380 if (F
->getCallingConv() != G
->getCallingConv())
383 if (!isEquivalentType(F
->getFunctionType(), G
->getFunctionType()))
386 DenseMap
<const Value
*, const Value
*> ValueMap
;
387 DenseMap
<const Value
*, const Value
*> SpeculationMap
;
390 assert(F
->arg_size() == G
->arg_size() &&
391 "Identical functions have a different number of args.");
393 for (Function::const_arg_iterator fi
= F
->arg_begin(), gi
= G
->arg_begin(),
394 fe
= F
->arg_end(); fi
!= fe
; ++fi
, ++gi
)
397 if (!equals(&F
->getEntryBlock(), &G
->getEntryBlock(), ValueMap
,
401 for (DenseMap
<const Value
*, const Value
*>::iterator
402 I
= SpeculationMap
.begin(), E
= SpeculationMap
.end(); I
!= E
; ++I
) {
403 if (ValueMap
[I
->first
] != I
->second
)
// ===----------------------------------------------------------------------===
// Folding of functions
// ===----------------------------------------------------------------------===
// Cases:
// * F is external strong, G is external strong:
//   turn G into a thunk to F    (1)
// * F is external strong, G is external weak:
//   turn G into a thunk to F    (1)
// * F is external weak, G is external weak:
//   move F's body into a new internal function and turn both the original
//   F symbol and G into thunks to it    (3)
// * F is external strong, G is internal:
//   address of G taken:
//     turn G into a thunk to F  (1)
//   address of G not taken:
//     make G an alias to F      (2)
// * F is internal, G is external weak
//   address of F is taken:
//     turn G into a thunk to F  (1)
//   address of F is not taken:
//     make G an alias of F      (2)
// * F is internal, G is internal:
//   address of F and G are taken:
//     turn G into a thunk to F  (1)
//   address of G is not taken:
//     make G an alias to F      (2)
//
// an alias requires linkage == (external,local,weak); otherwise we fall back
// to creating a thunk
// external means 'externally visible' linkage != (internal,private)
// internal means linkage == (internal,private,linker_private)
// weak means linkage mayBeOverridable
// being external implies that the address is taken
//
// 1. turn G into a thunk to F
// 2. make G an alias to F
// 3. make both F and G thunks to a shared internal copy of F's body
//
446 enum LinkageCategory
{
452 static LinkageCategory
categorize(const Function
*F
) {
453 switch (F
->getLinkage()) {
454 case GlobalValue::InternalLinkage
:
455 case GlobalValue::PrivateLinkage
:
456 case GlobalValue::LinkerPrivateLinkage
:
459 case GlobalValue::WeakAnyLinkage
:
460 case GlobalValue::WeakODRLinkage
:
461 case GlobalValue::ExternalWeakLinkage
:
464 case GlobalValue::ExternalLinkage
:
465 case GlobalValue::AvailableExternallyLinkage
:
466 case GlobalValue::LinkOnceAnyLinkage
:
467 case GlobalValue::LinkOnceODRLinkage
:
468 case GlobalValue::AppendingLinkage
:
469 case GlobalValue::DLLImportLinkage
:
470 case GlobalValue::DLLExportLinkage
:
471 case GlobalValue::GhostLinkage
:
472 case GlobalValue::CommonLinkage
:
473 return ExternalStrong
;
476 llvm_unreachable("Unknown LinkageType.");
480 static void ThunkGToF(Function
*F
, Function
*G
) {
481 Function
*NewG
= Function::Create(G
->getFunctionType(), G
->getLinkage(), "",
483 BasicBlock
*BB
= BasicBlock::Create(F
->getContext(), "", NewG
);
485 std::vector
<Value
*> Args
;
487 const FunctionType
*FFTy
= F
->getFunctionType();
488 for (Function::arg_iterator AI
= NewG
->arg_begin(), AE
= NewG
->arg_end();
490 if (FFTy
->getParamType(i
) == AI
->getType())
493 Value
*BCI
= new BitCastInst(AI
, FFTy
->getParamType(i
), "", BB
);
499 CallInst
*CI
= CallInst::Create(F
, Args
.begin(), Args
.end(), "", BB
);
501 CI
->setCallingConv(F
->getCallingConv());
502 if (NewG
->getReturnType() == Type::getVoidTy(F
->getContext())) {
503 ReturnInst::Create(F
->getContext(), BB
);
504 } else if (CI
->getType() != NewG
->getReturnType()) {
505 Value
*BCI
= new BitCastInst(CI
, NewG
->getReturnType(), "", BB
);
506 ReturnInst::Create(F
->getContext(), BCI
, BB
);
508 ReturnInst::Create(F
->getContext(), CI
, BB
);
511 NewG
->copyAttributesFrom(G
);
513 G
->replaceAllUsesWith(NewG
);
514 G
->eraseFromParent();
516 // TODO: look at direct callers to G and make them all direct callers to F.
519 static void AliasGToF(Function
*F
, Function
*G
) {
520 if (!G
->hasExternalLinkage() && !G
->hasLocalLinkage() && !G
->hasWeakLinkage())
521 return ThunkGToF(F
, G
);
523 GlobalAlias
*GA
= new GlobalAlias(
524 G
->getType(), G
->getLinkage(), "",
525 ConstantExpr::getBitCast(F
, G
->getType()), G
->getParent());
526 F
->setAlignment(std::max(F
->getAlignment(), G
->getAlignment()));
528 GA
->setVisibility(G
->getVisibility());
529 G
->replaceAllUsesWith(GA
);
530 G
->eraseFromParent();
533 static bool fold(std::vector
<Function
*> &FnVec
, unsigned i
, unsigned j
) {
534 Function
*F
= FnVec
[i
];
535 Function
*G
= FnVec
[j
];
537 LinkageCategory catF
= categorize(F
);
538 LinkageCategory catG
= categorize(G
);
540 if (catF
== ExternalWeak
|| (catF
== Internal
&& catG
== ExternalStrong
)) {
541 std::swap(FnVec
[i
], FnVec
[j
]);
543 std::swap(catF
, catG
);
554 if (G
->hasAddressTaken())
563 assert(catG
== ExternalWeak
);
565 // Make them both thunks to the same internal function.
566 F
->setAlignment(std::max(F
->getAlignment(), G
->getAlignment()));
567 Function
*H
= Function::Create(F
->getFunctionType(), F
->getLinkage(), "",
569 H
->copyAttributesFrom(F
);
571 F
->replaceAllUsesWith(H
);
576 F
->setLinkage(GlobalValue::InternalLinkage
);
585 if (F
->hasAddressTaken())
591 bool addrTakenF
= F
->hasAddressTaken();
592 bool addrTakenG
= G
->hasAddressTaken();
593 if (!addrTakenF
&& addrTakenG
) {
594 std::swap(FnVec
[i
], FnVec
[j
]);
596 std::swap(addrTakenF
, addrTakenG
);
599 if (addrTakenF
&& addrTakenG
) {
610 ++NumFunctionsMerged
;
// ===----------------------------------------------------------------------===
// Pass definition
// ===----------------------------------------------------------------------===
618 bool MergeFunctions::runOnModule(Module
&M
) {
619 bool Changed
= false;
621 std::map
<unsigned long, std::vector
<Function
*> > FnMap
;
623 for (Module::iterator F
= M
.begin(), E
= M
.end(); F
!= E
; ++F
) {
624 if (F
->isDeclaration() || F
->isIntrinsic())
627 FnMap
[hash(F
)].push_back(F
);
630 // TODO: instead of running in a loop, we could also fold functions in
631 // callgraph order. Constructing the CFG probably isn't cheaper than just
632 // running in a loop, unless it happened to already be available.
636 LocalChanged
= false;
637 DEBUG(errs() << "size: " << FnMap
.size() << "\n");
638 for (std::map
<unsigned long, std::vector
<Function
*> >::iterator
639 I
= FnMap
.begin(), E
= FnMap
.end(); I
!= E
; ++I
) {
640 std::vector
<Function
*> &FnVec
= I
->second
;
641 DEBUG(errs() << "hash (" << I
->first
<< "): " << FnVec
.size() << "\n");
643 for (int i
= 0, e
= FnVec
.size(); i
!= e
; ++i
) {
644 for (int j
= i
+ 1; j
!= e
; ++j
) {
645 bool isEqual
= equals(FnVec
[i
], FnVec
[j
]);
647 DEBUG(errs() << " " << FnVec
[i
]->getName()
648 << (isEqual
? " == " : " != ")
649 << FnVec
[j
]->getName() << "\n");
652 if (fold(FnVec
, i
, j
)) {
654 FnVec
.erase(FnVec
.begin() + j
);
662 Changed
|= LocalChanged
;
663 } while (LocalChanged
);