//===- CodeExtractor.cpp - Pull code region into a new function -----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements the interface to tear out a code region, such as an
// individual loop or a parallel section, into a new function, replacing it with
// a call to the new function.
//
//===----------------------------------------------------------------------===//
16 #include "llvm/Transforms/Utils/CodeExtractor.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Analysis/BlockFrequencyInfo.h"
21 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
22 #include "llvm/Analysis/BranchProbabilityInfo.h"
23 #include "llvm/Analysis/LoopInfo.h"
24 #include "llvm/Analysis/RegionInfo.h"
25 #include "llvm/Analysis/RegionIterator.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Intrinsics.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/MDBuilder.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/Verifier.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/BlockFrequency.h"
38 #include "llvm/Support/CommandLine.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/ErrorHandling.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
47 #define DEBUG_TYPE "code-extractor"
49 // Provide a command-line option to aggregate function arguments into a struct
50 // for functions produced by the code extractor. This is useful when converting
51 // extracted functions to pthread-based code, as only one argument (void*) can
52 // be passed in to pthread_create().
54 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden
,
55 cl::desc("Aggregate arguments to code-extracted functions"));
57 /// \brief Test whether a block is valid for extraction.
58 bool CodeExtractor::isBlockValidForExtraction(const BasicBlock
&BB
) {
59 // Landing pads must be in the function where they were inserted for cleanup.
62 // taking the address of a basic block moved to another function is illegal
63 if (BB
.hasAddressTaken())
66 // don't hoist code that uses another basicblock address, as it's likely to
67 // lead to unexpected behavior, like cross-function jumps
68 SmallPtrSet
<User
const *, 16> Visited
;
69 SmallVector
<User
const *, 16> ToVisit
;
71 for (Instruction
const &Inst
: BB
)
72 ToVisit
.push_back(&Inst
);
74 while (!ToVisit
.empty()) {
75 User
const *Curr
= ToVisit
.pop_back_val();
76 if (!Visited
.insert(Curr
).second
)
78 if (isa
<BlockAddress
const>(Curr
))
79 return false; // even a reference to self is likely to be not compatible
81 if (isa
<Instruction
>(Curr
) && cast
<Instruction
>(Curr
)->getParent() != &BB
)
84 for (auto const &U
: Curr
->operands()) {
85 if (auto *UU
= dyn_cast
<User
>(U
))
86 ToVisit
.push_back(UU
);
90 // Don't hoist code containing allocas, invokes, or vastarts.
91 for (BasicBlock::const_iterator I
= BB
.begin(), E
= BB
.end(); I
!= E
; ++I
) {
92 if (isa
<AllocaInst
>(I
) || isa
<InvokeInst
>(I
))
94 if (const CallInst
*CI
= dyn_cast
<CallInst
>(I
))
95 if (const Function
*F
= CI
->getCalledFunction())
96 if (F
->getIntrinsicID() == Intrinsic::vastart
)
103 /// \brief Build a set of blocks to extract if the input blocks are viable.
104 static SetVector
<BasicBlock
*>
105 buildExtractionBlockSet(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
) {
106 assert(!BBs
.empty() && "The set of blocks to extract must be non-empty");
107 SetVector
<BasicBlock
*> Result
;
109 // Loop over the blocks, adding them to our set-vector, and aborting with an
110 // empty set if we encounter invalid blocks.
111 for (BasicBlock
*BB
: BBs
) {
113 // If this block is dead, don't process it.
114 if (DT
&& !DT
->isReachableFromEntry(BB
))
117 if (!Result
.insert(BB
))
118 llvm_unreachable("Repeated basic blocks in extraction input");
119 if (!CodeExtractor::isBlockValidForExtraction(*BB
)) {
126 for (SetVector
<BasicBlock
*>::iterator I
= std::next(Result
.begin()),
129 for (pred_iterator PI
= pred_begin(*I
), PE
= pred_end(*I
);
131 assert(Result
.count(*PI
) &&
132 "No blocks in this region may have entries from outside the region"
133 " except for the first block!");
139 CodeExtractor::CodeExtractor(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
,
140 bool AggregateArgs
, BlockFrequencyInfo
*BFI
,
141 BranchProbabilityInfo
*BPI
)
142 : DT(DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
143 BPI(BPI
), Blocks(buildExtractionBlockSet(BBs
, DT
)), NumExitBlocks(~0U) {}
145 CodeExtractor::CodeExtractor(DominatorTree
&DT
, Loop
&L
, bool AggregateArgs
,
146 BlockFrequencyInfo
*BFI
,
147 BranchProbabilityInfo
*BPI
)
148 : DT(&DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
149 BPI(BPI
), Blocks(buildExtractionBlockSet(L
.getBlocks(), &DT
)),
150 NumExitBlocks(~0U) {}
152 /// definedInRegion - Return true if the specified value is defined in the
153 /// extracted region.
154 static bool definedInRegion(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
155 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
156 if (Blocks
.count(I
->getParent()))
161 /// definedInCaller - Return true if the specified value is defined in the
162 /// function being code extracted, but not in the region being extracted.
163 /// These values must be passed in as live-ins to the function.
164 static bool definedInCaller(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
165 if (isa
<Argument
>(V
)) return true;
166 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
167 if (!Blocks
.count(I
->getParent()))
172 static BasicBlock
*getCommonExitBlock(const SetVector
<BasicBlock
*> &Blocks
) {
173 BasicBlock
*CommonExitBlock
= nullptr;
174 auto hasNonCommonExitSucc
= [&](BasicBlock
*Block
) {
175 for (auto *Succ
: successors(Block
)) {
176 // Internal edges, ok.
177 if (Blocks
.count(Succ
))
179 if (!CommonExitBlock
) {
180 CommonExitBlock
= Succ
;
183 if (CommonExitBlock
== Succ
)
191 if (any_of(Blocks
, hasNonCommonExitSucc
))
194 return CommonExitBlock
;
197 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
198 Instruction
*Addr
) const {
199 AllocaInst
*AI
= cast
<AllocaInst
>(Addr
->stripInBoundsConstantOffsets());
200 Function
*Func
= (*Blocks
.begin())->getParent();
201 for (BasicBlock
&BB
: *Func
) {
202 if (Blocks
.count(&BB
))
204 for (Instruction
&II
: BB
) {
206 if (isa
<DbgInfoIntrinsic
>(II
))
209 unsigned Opcode
= II
.getOpcode();
210 Value
*MemAddr
= nullptr;
212 case Instruction::Store
:
213 case Instruction::Load
: {
214 if (Opcode
== Instruction::Store
) {
215 StoreInst
*SI
= cast
<StoreInst
>(&II
);
216 MemAddr
= SI
->getPointerOperand();
218 LoadInst
*LI
= cast
<LoadInst
>(&II
);
219 MemAddr
= LI
->getPointerOperand();
221 // Global variable can not be aliased with locals.
222 if (dyn_cast
<Constant
>(MemAddr
))
224 Value
*Base
= MemAddr
->stripInBoundsConstantOffsets();
225 if (!dyn_cast
<AllocaInst
>(Base
) || Base
== AI
)
230 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(&II
);
232 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_start
||
233 IntrInst
->getIntrinsicID() == Intrinsic::lifetime_end
)
237 // Treat all the other cases conservatively if it has side effects.
238 if (II
.mayHaveSideEffects())
249 CodeExtractor::findOrCreateBlockForHoisting(BasicBlock
*CommonExitBlock
) {
250 BasicBlock
*SinglePredFromOutlineRegion
= nullptr;
251 assert(!Blocks
.count(CommonExitBlock
) &&
252 "Expect a block outside the region!");
253 for (auto *Pred
: predecessors(CommonExitBlock
)) {
254 if (!Blocks
.count(Pred
))
256 if (!SinglePredFromOutlineRegion
) {
257 SinglePredFromOutlineRegion
= Pred
;
258 } else if (SinglePredFromOutlineRegion
!= Pred
) {
259 SinglePredFromOutlineRegion
= nullptr;
264 if (SinglePredFromOutlineRegion
)
265 return SinglePredFromOutlineRegion
;
268 auto getFirstPHI
= [](BasicBlock
*BB
) {
269 BasicBlock::iterator I
= BB
->begin();
270 PHINode
*FirstPhi
= nullptr;
271 while (I
!= BB
->end()) {
272 PHINode
*Phi
= dyn_cast
<PHINode
>(I
);
282 // If there are any phi nodes, the single pred either exists or has already
283 // be created before code extraction.
284 assert(!getFirstPHI(CommonExitBlock
) && "Phi not expected");
287 BasicBlock
*NewExitBlock
= CommonExitBlock
->splitBasicBlock(
288 CommonExitBlock
->getFirstNonPHI()->getIterator());
290 for (auto *Pred
: predecessors(CommonExitBlock
)) {
291 if (Blocks
.count(Pred
))
293 Pred
->getTerminator()->replaceUsesOfWith(CommonExitBlock
, NewExitBlock
);
295 // Now add the old exit block to the outline region.
296 Blocks
.insert(CommonExitBlock
);
297 return CommonExitBlock
;
300 void CodeExtractor::findAllocas(ValueSet
&SinkCands
, ValueSet
&HoistCands
,
301 BasicBlock
*&ExitBlock
) const {
302 Function
*Func
= (*Blocks
.begin())->getParent();
303 ExitBlock
= getCommonExitBlock(Blocks
);
305 for (BasicBlock
&BB
: *Func
) {
306 if (Blocks
.count(&BB
))
308 for (Instruction
&II
: BB
) {
309 auto *AI
= dyn_cast
<AllocaInst
>(&II
);
313 // Find the pair of life time markers for address 'Addr' that are either
314 // defined inside the outline region or can legally be shrinkwrapped into
315 // the outline region. If there are not other untracked uses of the
316 // address, return the pair of markers if found; otherwise return a pair
318 auto GetLifeTimeMarkers
=
319 [&](Instruction
*Addr
, bool &SinkLifeStart
,
320 bool &HoistLifeEnd
) -> std::pair
<Instruction
*, Instruction
*> {
321 Instruction
*LifeStart
= nullptr, *LifeEnd
= nullptr;
323 for (User
*U
: Addr
->users()) {
324 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(U
);
326 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_start
) {
327 // Do not handle the case where AI has multiple start markers.
329 return std::make_pair
<Instruction
*>(nullptr, nullptr);
330 LifeStart
= IntrInst
;
332 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_end
) {
334 return std::make_pair
<Instruction
*>(nullptr, nullptr);
339 // Find untracked uses of the address, bail.
340 if (!definedInRegion(Blocks
, U
))
341 return std::make_pair
<Instruction
*>(nullptr, nullptr);
344 if (!LifeStart
|| !LifeEnd
)
345 return std::make_pair
<Instruction
*>(nullptr, nullptr);
347 SinkLifeStart
= !definedInRegion(Blocks
, LifeStart
);
348 HoistLifeEnd
= !definedInRegion(Blocks
, LifeEnd
);
349 // Do legality Check.
350 if ((SinkLifeStart
|| HoistLifeEnd
) &&
351 !isLegalToShrinkwrapLifetimeMarkers(Addr
))
352 return std::make_pair
<Instruction
*>(nullptr, nullptr);
354 // Check to see if we have a place to do hoisting, if not, bail.
355 if (HoistLifeEnd
&& !ExitBlock
)
356 return std::make_pair
<Instruction
*>(nullptr, nullptr);
358 return std::make_pair(LifeStart
, LifeEnd
);
361 bool SinkLifeStart
= false, HoistLifeEnd
= false;
362 auto Markers
= GetLifeTimeMarkers(AI
, SinkLifeStart
, HoistLifeEnd
);
366 SinkCands
.insert(Markers
.first
);
367 SinkCands
.insert(AI
);
369 HoistCands
.insert(Markers
.second
);
373 // Follow the bitcast.
374 Instruction
*MarkerAddr
= nullptr;
375 for (User
*U
: AI
->users()) {
377 if (U
->stripInBoundsConstantOffsets() == AI
) {
378 SinkLifeStart
= false;
379 HoistLifeEnd
= false;
380 Instruction
*Bitcast
= cast
<Instruction
>(U
);
381 Markers
= GetLifeTimeMarkers(Bitcast
, SinkLifeStart
, HoistLifeEnd
);
383 MarkerAddr
= Bitcast
;
388 // Found unknown use of AI.
389 if (!definedInRegion(Blocks
, U
)) {
390 MarkerAddr
= nullptr;
397 SinkCands
.insert(Markers
.first
);
398 if (!definedInRegion(Blocks
, MarkerAddr
))
399 SinkCands
.insert(MarkerAddr
);
400 SinkCands
.insert(AI
);
402 HoistCands
.insert(Markers
.second
);
408 void CodeExtractor::findInputsOutputs(ValueSet
&Inputs
, ValueSet
&Outputs
,
409 const ValueSet
&SinkCands
) const {
411 for (BasicBlock
*BB
: Blocks
) {
412 // If a used value is defined outside the region, it's an input. If an
413 // instruction is used outside the region, it's an output.
414 for (Instruction
&II
: *BB
) {
415 for (User::op_iterator OI
= II
.op_begin(), OE
= II
.op_end(); OI
!= OE
;
418 if (!SinkCands
.count(V
) && definedInCaller(Blocks
, V
))
422 for (User
*U
: II
.users())
423 if (!definedInRegion(Blocks
, U
)) {
431 /// severSplitPHINodes - If a PHI node has multiple inputs from outside of the
432 /// region, we need to split the entry block of the region so that the PHI node
433 /// is easier to deal with.
434 void CodeExtractor::severSplitPHINodes(BasicBlock
*&Header
) {
435 unsigned NumPredsFromRegion
= 0;
436 unsigned NumPredsOutsideRegion
= 0;
438 if (Header
!= &Header
->getParent()->getEntryBlock()) {
439 PHINode
*PN
= dyn_cast
<PHINode
>(Header
->begin());
440 if (!PN
) return; // No PHI nodes.
442 // If the header node contains any PHI nodes, check to see if there is more
443 // than one entry from outside the region. If so, we need to sever the
444 // header block into two.
445 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
446 if (Blocks
.count(PN
->getIncomingBlock(i
)))
447 ++NumPredsFromRegion
;
449 ++NumPredsOutsideRegion
;
451 // If there is one (or fewer) predecessor from outside the region, we don't
452 // need to do anything special.
453 if (NumPredsOutsideRegion
<= 1) return;
456 // Otherwise, we need to split the header block into two pieces: one
457 // containing PHI nodes merging values from outside of the region, and a
458 // second that contains all of the code for the block and merges back any
459 // incoming values from inside of the region.
460 BasicBlock
*NewBB
= llvm::SplitBlock(Header
, Header
->getFirstNonPHI(), DT
);
462 // We only want to code extract the second block now, and it becomes the new
463 // header of the region.
464 BasicBlock
*OldPred
= Header
;
465 Blocks
.remove(OldPred
);
466 Blocks
.insert(NewBB
);
469 // Okay, now we need to adjust the PHI nodes and any branches from within the
470 // region to go to the new header block instead of the old header block.
471 if (NumPredsFromRegion
) {
472 PHINode
*PN
= cast
<PHINode
>(OldPred
->begin());
473 // Loop over all of the predecessors of OldPred that are in the region,
474 // changing them to branch to NewBB instead.
475 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
476 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
477 TerminatorInst
*TI
= PN
->getIncomingBlock(i
)->getTerminator();
478 TI
->replaceUsesOfWith(OldPred
, NewBB
);
481 // Okay, everything within the region is now branching to the right block, we
482 // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
483 BasicBlock::iterator AfterPHIs
;
484 for (AfterPHIs
= OldPred
->begin(); isa
<PHINode
>(AfterPHIs
); ++AfterPHIs
) {
485 PHINode
*PN
= cast
<PHINode
>(AfterPHIs
);
486 // Create a new PHI node in the new region, which has an incoming value
487 // from OldPred of PN.
488 PHINode
*NewPN
= PHINode::Create(PN
->getType(), 1 + NumPredsFromRegion
,
489 PN
->getName() + ".ce", &NewBB
->front());
490 PN
->replaceAllUsesWith(NewPN
);
491 NewPN
->addIncoming(PN
, OldPred
);
493 // Loop over all of the incoming value in PN, moving them to NewPN if they
494 // are from the extracted region.
495 for (unsigned i
= 0; i
!= PN
->getNumIncomingValues(); ++i
) {
496 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
497 NewPN
->addIncoming(PN
->getIncomingValue(i
), PN
->getIncomingBlock(i
));
498 PN
->removeIncomingValue(i
);
506 void CodeExtractor::splitReturnBlocks() {
507 for (BasicBlock
*Block
: Blocks
)
508 if (ReturnInst
*RI
= dyn_cast
<ReturnInst
>(Block
->getTerminator())) {
510 Block
->splitBasicBlock(RI
->getIterator(), Block
->getName() + ".ret");
512 // Old dominates New. New node dominates all other nodes dominated
514 DomTreeNode
*OldNode
= DT
->getNode(Block
);
515 SmallVector
<DomTreeNode
*, 8> Children(OldNode
->begin(),
518 DomTreeNode
*NewNode
= DT
->addNewBlock(New
, Block
);
520 for (DomTreeNode
*I
: Children
)
521 DT
->changeImmediateDominator(I
, NewNode
);
526 /// constructFunction - make a function based on inputs and outputs, as follows:
527 /// f(in0, ..., inN, out0, ..., outN)
529 Function
*CodeExtractor::constructFunction(const ValueSet
&inputs
,
530 const ValueSet
&outputs
,
532 BasicBlock
*newRootNode
,
533 BasicBlock
*newHeader
,
534 Function
*oldFunction
,
536 DEBUG(dbgs() << "inputs: " << inputs
.size() << "\n");
537 DEBUG(dbgs() << "outputs: " << outputs
.size() << "\n");
539 // This function returns unsigned, outputs will go back by reference.
540 switch (NumExitBlocks
) {
542 case 1: RetTy
= Type::getVoidTy(header
->getContext()); break;
543 case 2: RetTy
= Type::getInt1Ty(header
->getContext()); break;
544 default: RetTy
= Type::getInt16Ty(header
->getContext()); break;
547 std::vector
<Type
*> paramTy
;
549 // Add the types of the input values to the function's argument list
550 for (Value
*value
: inputs
) {
551 DEBUG(dbgs() << "value used in func: " << *value
<< "\n");
552 paramTy
.push_back(value
->getType());
555 // Add the types of the output values to the function's argument list.
556 for (Value
*output
: outputs
) {
557 DEBUG(dbgs() << "instr used in func: " << *output
<< "\n");
559 paramTy
.push_back(output
->getType());
561 paramTy
.push_back(PointerType::getUnqual(output
->getType()));
565 dbgs() << "Function type: " << *RetTy
<< " f(";
566 for (Type
*i
: paramTy
)
567 dbgs() << *i
<< ", ";
571 StructType
*StructTy
;
572 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
573 StructTy
= StructType::get(M
->getContext(), paramTy
);
575 paramTy
.push_back(PointerType::getUnqual(StructTy
));
577 FunctionType
*funcType
=
578 FunctionType::get(RetTy
, paramTy
, false);
580 // Create the new function
581 Function
*newFunction
= Function::Create(funcType
,
582 GlobalValue::InternalLinkage
,
583 oldFunction
->getName() + "_" +
584 header
->getName(), M
);
585 // If the old function is no-throw, so is the new one.
586 if (oldFunction
->doesNotThrow())
587 newFunction
->setDoesNotThrow();
589 // Inherit the uwtable attribute if we need to.
590 if (oldFunction
->hasUWTable())
591 newFunction
->setHasUWTable();
593 // Inherit all of the target dependent attributes.
594 // (e.g. If the extracted region contains a call to an x86.sse
595 // instruction we need to make sure that the extracted region has the
596 // "target-features" attribute allowing it to be lowered.
597 // FIXME: This should be changed to check to see if a specific
598 // attribute can not be inherited.
599 AttrBuilder
AB(oldFunction
->getAttributes().getFnAttributes());
600 for (const auto &Attr
: AB
.td_attrs())
601 newFunction
->addFnAttr(Attr
.first
, Attr
.second
);
603 newFunction
->getBasicBlockList().push_back(newRootNode
);
605 // Create an iterator to name all of the arguments we inserted.
606 Function::arg_iterator AI
= newFunction
->arg_begin();
608 // Rewrite all users of the inputs in the extracted region to use the
609 // arguments (or appropriate addressing into struct) instead.
610 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
614 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(header
->getContext()));
615 Idx
[1] = ConstantInt::get(Type::getInt32Ty(header
->getContext()), i
);
616 TerminatorInst
*TI
= newFunction
->begin()->getTerminator();
617 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
618 StructTy
, &*AI
, Idx
, "gep_" + inputs
[i
]->getName(), TI
);
619 RewriteVal
= new LoadInst(GEP
, "loadgep_" + inputs
[i
]->getName(), TI
);
623 std::vector
<User
*> Users(inputs
[i
]->user_begin(), inputs
[i
]->user_end());
624 for (User
*use
: Users
)
625 if (Instruction
*inst
= dyn_cast
<Instruction
>(use
))
626 if (Blocks
.count(inst
->getParent()))
627 inst
->replaceUsesOfWith(inputs
[i
], RewriteVal
);
630 // Set names for input and output arguments.
631 if (!AggregateArgs
) {
632 AI
= newFunction
->arg_begin();
633 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
, ++AI
)
634 AI
->setName(inputs
[i
]->getName());
635 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
, ++AI
)
636 AI
->setName(outputs
[i
]->getName()+".out");
639 // Rewrite branches to basic blocks outside of the loop to new dummy blocks
640 // within the new function. This must be done before we lose track of which
641 // blocks were originally in the code region.
642 std::vector
<User
*> Users(header
->user_begin(), header
->user_end());
643 for (unsigned i
= 0, e
= Users
.size(); i
!= e
; ++i
)
644 // The BasicBlock which contains the branch is not in the region
645 // modify the branch target to a new block
646 if (TerminatorInst
*TI
= dyn_cast
<TerminatorInst
>(Users
[i
]))
647 if (!Blocks
.count(TI
->getParent()) &&
648 TI
->getParent()->getParent() == oldFunction
)
649 TI
->replaceUsesOfWith(header
, newHeader
);
654 /// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI
655 /// that uses the value within the basic block, and return the predecessor
656 /// block associated with that use, or return 0 if none is found.
657 static BasicBlock
* FindPhiPredForUseInBlock(Value
* Used
, BasicBlock
* BB
) {
658 for (Use
&U
: Used
->uses()) {
659 PHINode
*P
= dyn_cast
<PHINode
>(U
.getUser());
660 if (P
&& P
->getParent() == BB
)
661 return P
->getIncomingBlock(U
);
667 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
668 /// the call instruction, splitting any PHI nodes in the header block as
671 emitCallAndSwitchStatement(Function
*newFunction
, BasicBlock
*codeReplacer
,
672 ValueSet
&inputs
, ValueSet
&outputs
) {
673 // Emit a call to the new function, passing in: *pointer to struct (if
674 // aggregating parameters), or plan inputs and allocated memory for outputs
675 std::vector
<Value
*> params
, StructValues
, ReloadOutputs
, Reloads
;
677 Module
*M
= newFunction
->getParent();
678 LLVMContext
&Context
= M
->getContext();
679 const DataLayout
&DL
= M
->getDataLayout();
681 // Add inputs as params, or to be filled into the struct
682 for (Value
*input
: inputs
)
684 StructValues
.push_back(input
);
686 params
.push_back(input
);
688 // Create allocas for the outputs
689 for (Value
*output
: outputs
) {
691 StructValues
.push_back(output
);
694 new AllocaInst(output
->getType(), DL
.getAllocaAddrSpace(),
695 nullptr, output
->getName() + ".loc",
696 &codeReplacer
->getParent()->front().front());
697 ReloadOutputs
.push_back(alloca
);
698 params
.push_back(alloca
);
702 StructType
*StructArgTy
= nullptr;
703 AllocaInst
*Struct
= nullptr;
704 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
705 std::vector
<Type
*> ArgTypes
;
706 for (ValueSet::iterator v
= StructValues
.begin(),
707 ve
= StructValues
.end(); v
!= ve
; ++v
)
708 ArgTypes
.push_back((*v
)->getType());
710 // Allocate a struct at the beginning of this function
711 StructArgTy
= StructType::get(newFunction
->getContext(), ArgTypes
);
712 Struct
= new AllocaInst(StructArgTy
, DL
.getAllocaAddrSpace(), nullptr,
714 &codeReplacer
->getParent()->front().front());
715 params
.push_back(Struct
);
717 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
719 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
720 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), i
);
721 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
722 StructArgTy
, Struct
, Idx
, "gep_" + StructValues
[i
]->getName());
723 codeReplacer
->getInstList().push_back(GEP
);
724 StoreInst
*SI
= new StoreInst(StructValues
[i
], GEP
);
725 codeReplacer
->getInstList().push_back(SI
);
729 // Emit the call to the function
730 CallInst
*call
= CallInst::Create(newFunction
, params
,
731 NumExitBlocks
> 1 ? "targetBlock" : "");
732 codeReplacer
->getInstList().push_back(call
);
734 Function::arg_iterator OutputArgBegin
= newFunction
->arg_begin();
735 unsigned FirstOut
= inputs
.size();
737 std::advance(OutputArgBegin
, inputs
.size());
739 // Reload the outputs passed in by reference
740 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
) {
741 Value
*Output
= nullptr;
744 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
745 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), FirstOut
+ i
);
746 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
747 StructArgTy
, Struct
, Idx
, "gep_reload_" + outputs
[i
]->getName());
748 codeReplacer
->getInstList().push_back(GEP
);
751 Output
= ReloadOutputs
[i
];
753 LoadInst
*load
= new LoadInst(Output
, outputs
[i
]->getName()+".reload");
754 Reloads
.push_back(load
);
755 codeReplacer
->getInstList().push_back(load
);
756 std::vector
<User
*> Users(outputs
[i
]->user_begin(), outputs
[i
]->user_end());
757 for (unsigned u
= 0, e
= Users
.size(); u
!= e
; ++u
) {
758 Instruction
*inst
= cast
<Instruction
>(Users
[u
]);
759 if (!Blocks
.count(inst
->getParent()))
760 inst
->replaceUsesOfWith(outputs
[i
], load
);
764 // Now we can emit a switch statement using the call as a value.
765 SwitchInst
*TheSwitch
=
766 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context
)),
767 codeReplacer
, 0, codeReplacer
);
769 // Since there may be multiple exits from the original region, make the new
770 // function return an unsigned, switch on that number. This loop iterates
771 // over all of the blocks in the extracted region, updating any terminator
772 // instructions in the to-be-extracted region that branch to blocks that are
773 // not in the region to be extracted.
774 std::map
<BasicBlock
*, BasicBlock
*> ExitBlockMap
;
776 unsigned switchVal
= 0;
777 for (BasicBlock
*Block
: Blocks
) {
778 TerminatorInst
*TI
= Block
->getTerminator();
779 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
!= e
; ++i
)
780 if (!Blocks
.count(TI
->getSuccessor(i
))) {
781 BasicBlock
*OldTarget
= TI
->getSuccessor(i
);
782 // add a new basic block which returns the appropriate value
783 BasicBlock
*&NewTarget
= ExitBlockMap
[OldTarget
];
785 // If we don't already have an exit stub for this non-extracted
786 // destination, create one now!
787 NewTarget
= BasicBlock::Create(Context
,
788 OldTarget
->getName() + ".exitStub",
790 unsigned SuccNum
= switchVal
++;
792 Value
*brVal
= nullptr;
793 switch (NumExitBlocks
) {
795 case 1: break; // No value needed.
796 case 2: // Conditional branch, return a bool
797 brVal
= ConstantInt::get(Type::getInt1Ty(Context
), !SuccNum
);
800 brVal
= ConstantInt::get(Type::getInt16Ty(Context
), SuccNum
);
804 ReturnInst
*NTRet
= ReturnInst::Create(Context
, brVal
, NewTarget
);
806 // Update the switch instruction.
807 TheSwitch
->addCase(ConstantInt::get(Type::getInt16Ty(Context
),
811 // Restore values just before we exit
812 Function::arg_iterator OAI
= OutputArgBegin
;
813 for (unsigned out
= 0, e
= outputs
.size(); out
!= e
; ++out
) {
814 // For an invoke, the normal destination is the only one that is
815 // dominated by the result of the invocation
816 BasicBlock
*DefBlock
= cast
<Instruction
>(outputs
[out
])->getParent();
818 bool DominatesDef
= true;
820 BasicBlock
*NormalDest
= nullptr;
821 if (auto *Invoke
= dyn_cast
<InvokeInst
>(outputs
[out
]))
822 NormalDest
= Invoke
->getNormalDest();
825 DefBlock
= NormalDest
;
827 // Make sure we are looking at the original successor block, not
828 // at a newly inserted exit block, which won't be in the dominator
830 for (const auto &I
: ExitBlockMap
)
831 if (DefBlock
== I
.second
) {
836 // In the extract block case, if the block we are extracting ends
837 // with an invoke instruction, make sure that we don't emit a
838 // store of the invoke value for the unwind block.
839 if (!DT
&& DefBlock
!= OldTarget
)
840 DominatesDef
= false;
844 DominatesDef
= DT
->dominates(DefBlock
, OldTarget
);
846 // If the output value is used by a phi in the target block,
847 // then we need to test for dominance of the phi's predecessor
848 // instead. Unfortunately, this a little complicated since we
849 // have already rewritten uses of the value to uses of the reload.
850 BasicBlock
* pred
= FindPhiPredForUseInBlock(Reloads
[out
],
852 if (pred
&& DT
&& DT
->dominates(DefBlock
, pred
))
859 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
860 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
),
862 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
863 StructArgTy
, &*OAI
, Idx
, "gep_" + outputs
[out
]->getName(),
865 new StoreInst(outputs
[out
], GEP
, NTRet
);
867 new StoreInst(outputs
[out
], &*OAI
, NTRet
);
870 // Advance output iterator even if we don't emit a store
871 if (!AggregateArgs
) ++OAI
;
875 // rewrite the original branch instruction with this new target
876 TI
->setSuccessor(i
, NewTarget
);
880 // Now that we've done the deed, simplify the switch instruction.
881 Type
*OldFnRetTy
= TheSwitch
->getParent()->getParent()->getReturnType();
882 switch (NumExitBlocks
) {
884 // There are no successors (the block containing the switch itself), which
885 // means that previously this was the last part of the function, and hence
886 // this should be rewritten as a `ret'
888 // Check if the function should return a value
889 if (OldFnRetTy
->isVoidTy()) {
890 ReturnInst::Create(Context
, nullptr, TheSwitch
); // Return void
891 } else if (OldFnRetTy
== TheSwitch
->getCondition()->getType()) {
892 // return what we have
893 ReturnInst::Create(Context
, TheSwitch
->getCondition(), TheSwitch
);
895 // Otherwise we must have code extracted an unwind or something, just
896 // return whatever we want.
897 ReturnInst::Create(Context
,
898 Constant::getNullValue(OldFnRetTy
), TheSwitch
);
901 TheSwitch
->eraseFromParent();
904 // Only a single destination, change the switch into an unconditional
906 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
);
907 TheSwitch
->eraseFromParent();
910 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
->getSuccessor(2),
912 TheSwitch
->eraseFromParent();
915 // Otherwise, make the default destination of the switch instruction be one
916 // of the other successors.
917 TheSwitch
->setCondition(call
);
918 TheSwitch
->setDefaultDest(TheSwitch
->getSuccessor(NumExitBlocks
));
919 // Remove redundant case
920 TheSwitch
->removeCase(SwitchInst::CaseIt(TheSwitch
, NumExitBlocks
-1));
925 void CodeExtractor::moveCodeToFunction(Function
*newFunction
) {
926 Function
*oldFunc
= (*Blocks
.begin())->getParent();
927 Function::BasicBlockListType
&oldBlocks
= oldFunc
->getBasicBlockList();
928 Function::BasicBlockListType
&newBlocks
= newFunction
->getBasicBlockList();
930 for (BasicBlock
*Block
: Blocks
) {
931 // Delete the basic block from the old function, and the list of blocks
932 oldBlocks
.remove(Block
);
934 // Insert this basic block into the new function
935 newBlocks
.push_back(Block
);
// Recompute the branch weights for the terminator of CodeReplacer from the
// per-successor frequencies in ExitWeights, and update the edge probabilities
// that BPI holds for CodeReplacer's outgoing edges.
//
//   CodeReplacer - block whose terminator's successors are weighted.
//   ExitWeights  - frequency recorded for each successor block; indexed by
//                  the successor returned from TI->getSuccessor(i).
//   BPI          - updated through setEdgeProbability(); it is dereferenced
//                  without a null check in this function.
//
// NOTE(review): the embedded listing numbers skip 957, 959, 965 and 979, so
// some statements of this function are not visible here -- confirm the exact
// control flow against the complete file.
939 void CodeExtractor::calculateNewCallTerminatorWeights(
940 BasicBlock
*CodeReplacer
,
941 DenseMap
<BasicBlock
*, BlockFrequency
> &ExitWeights
,
942 BranchProbabilityInfo
*BPI
) {
943 typedef BlockFrequencyInfoImplBase::Distribution Distribution
;
944 typedef BlockFrequencyInfoImplBase::BlockNode BlockNode
;
946 // Update the branch weights for the exit block.
// One weight slot per successor of the terminator, initialised to zero.
947 TerminatorInst
*TI
= CodeReplacer
->getTerminator();
948 SmallVector
<unsigned, 8> BranchWeights(TI
->getNumSuccessors(), 0);
950 // Block Frequency distribution with dummy node.
951 Distribution BranchDist
;
953 // Add each of the frequencies of the successors.
954 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
< e
; ++i
) {
// The successor index doubles as the node id of this exit, so the weights
// can be mapped back to successor positions further down.
955 BlockNode
ExitNode(i
);
956 uint64_t ExitFreq
= ExitWeights
[TI
->getSuccessor(i
)].getFrequency();
// NOTE(review): listing lines 957 and 959 are not shown; presumably the two
// statements below are alternatives selected by a test on ExitFreq (add the
// exit to the distribution vs. give the edge zero probability) -- confirm.
958 BranchDist
.addExit(ExitNode
, ExitFreq
);
960 BPI
->setEdgeProbability(CodeReplacer
, i
, BranchProbability::getZero());
963 // Check for no total weight.
// NOTE(review): listing line 965 (the statement controlled by this check) is
// not shown; presumably the function stops here when the total is zero --
// confirm.
964 if (BranchDist
.Total
== 0)
967 // Normalize the distribution so that they can fit in unsigned.
968 BranchDist
.normalize();
970 // Create normalized branch weights and set the metadata.
971 for (unsigned I
= 0, E
= BranchDist
.Weights
.size(); I
< E
; ++I
) {
972 const auto &Weight
= BranchDist
.Weights
[I
];
974 // Record the weight and update the corresponding edge probability in BPI.
975 BranchWeights
[Weight
.TargetNode
.Index
] = Weight
.Amount
;
// The edge probability is this exit's weight over the distribution total.
976 BranchProbability
BP(Weight
.Amount
, BranchDist
.Total
);
977 BPI
->setEdgeProbability(CodeReplacer
, Weight
.TargetNode
.Index
, BP
);
// NOTE(review): listing line 979 is not shown; the two lines below look like
// the arguments of a call that attaches the MD_prof branch-weight metadata
// built from BranchWeights (presumably to TI) -- confirm.
980 LLVMContext::MD_prof
,
981 MDBuilder(TI
->getContext()).createBranchWeights(BranchWeights
));
984 Function
*CodeExtractor::extractCodeRegion() {
988 ValueSet inputs
, outputs
, SinkingCands
, HoistingCands
;
989 BasicBlock
*CommonExit
= nullptr;
991 // Assumption: this is a single-entry code region, and the header is the first
992 // block in the region.
993 BasicBlock
*header
= *Blocks
.begin();
995 // Calculate the entry frequency of the new function before we change the root
997 BlockFrequency EntryFreq
;
999 assert(BPI
&& "Both BPI and BFI are required to preserve profile info");
1000 for (BasicBlock
*Pred
: predecessors(header
)) {
1001 if (Blocks
.count(Pred
))
1004 BFI
->getBlockFreq(Pred
) * BPI
->getEdgeProbability(Pred
, header
);
1008 // If we have to split PHI nodes or the entry block, do so now.
1009 severSplitPHINodes(header
);
1011 // If we have any return instructions in the region, split those blocks so
1012 // that the return is not in the region.
1013 splitReturnBlocks();
1015 Function
*oldFunction
= header
->getParent();
1017 // This takes the place of the original loop
1018 BasicBlock
*codeReplacer
= BasicBlock::Create(header
->getContext(),
1019 "codeRepl", oldFunction
,
1022 // The new function needs a root node because other nodes can branch to the
1023 // head of the region, but the entry node of a function cannot have preds.
1024 BasicBlock
*newFuncRoot
= BasicBlock::Create(header
->getContext(),
1026 newFuncRoot
->getInstList().push_back(BranchInst::Create(header
));
1028 findAllocas(SinkingCands
, HoistingCands
, CommonExit
);
1029 assert(HoistingCands
.empty() || CommonExit
);
1031 // Find inputs to, outputs from the code region.
1032 findInputsOutputs(inputs
, outputs
, SinkingCands
);
1034 // Now sink all instructions which only have non-phi uses inside the region
1035 for (auto *II
: SinkingCands
)
1036 cast
<Instruction
>(II
)->moveBefore(*newFuncRoot
,
1037 newFuncRoot
->getFirstInsertionPt());
1039 if (!HoistingCands
.empty()) {
1040 auto *HoistToBlock
= findOrCreateBlockForHoisting(CommonExit
);
1041 Instruction
*TI
= HoistToBlock
->getTerminator();
1042 for (auto *II
: HoistingCands
)
1043 cast
<Instruction
>(II
)->moveBefore(TI
);
1046 // Calculate the exit blocks for the extracted region and the total exit
1047 // weights for each of those blocks.
1048 DenseMap
<BasicBlock
*, BlockFrequency
> ExitWeights
;
1049 SmallPtrSet
<BasicBlock
*, 1> ExitBlocks
;
1050 for (BasicBlock
*Block
: Blocks
) {
1051 for (succ_iterator SI
= succ_begin(Block
), SE
= succ_end(Block
); SI
!= SE
;
1053 if (!Blocks
.count(*SI
)) {
1054 // Update the branch weight for this successor.
1056 BlockFrequency
&BF
= ExitWeights
[*SI
];
1057 BF
+= BFI
->getBlockFreq(Block
) * BPI
->getEdgeProbability(Block
, *SI
);
1059 ExitBlocks
.insert(*SI
);
1063 NumExitBlocks
= ExitBlocks
.size();
1065 // Construct new function based on inputs/outputs & add allocas for all defs.
1066 Function
*newFunction
= constructFunction(inputs
, outputs
, header
,
1068 codeReplacer
, oldFunction
,
1069 oldFunction
->getParent());
1071 // Update the entry count of the function.
1073 Optional
<uint64_t> EntryCount
=
1074 BFI
->getProfileCountFromFreq(EntryFreq
.getFrequency());
1075 if (EntryCount
.hasValue())
1076 newFunction
->setEntryCount(EntryCount
.getValue());
1077 BFI
->setBlockFreq(codeReplacer
, EntryFreq
.getFrequency());
1080 emitCallAndSwitchStatement(newFunction
, codeReplacer
, inputs
, outputs
);
1082 moveCodeToFunction(newFunction
);
1084 // Update the branch weights for the exit block.
1085 if (BFI
&& NumExitBlocks
> 1)
1086 calculateNewCallTerminatorWeights(codeReplacer
, ExitWeights
, BPI
);
1088 // Loop over all of the PHI nodes in the header block, and change any
1089 // references to the old incoming edge to be the new incoming edge.
1090 for (BasicBlock::iterator I
= header
->begin(); isa
<PHINode
>(I
); ++I
) {
1091 PHINode
*PN
= cast
<PHINode
>(I
);
1092 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
1093 if (!Blocks
.count(PN
->getIncomingBlock(i
)))
1094 PN
->setIncomingBlock(i
, newFuncRoot
);
1097 // Look at all successors of the codeReplacer block. If any of these blocks
1098 // had PHI nodes in them, we need to update the "from" block to be the code
1099 // replacer, not the original block in the extracted region.
1100 std::vector
<BasicBlock
*> Succs(succ_begin(codeReplacer
),
1101 succ_end(codeReplacer
));
1102 for (unsigned i
= 0, e
= Succs
.size(); i
!= e
; ++i
)
1103 for (BasicBlock::iterator I
= Succs
[i
]->begin(); isa
<PHINode
>(I
); ++I
) {
1104 PHINode
*PN
= cast
<PHINode
>(I
);
1105 std::set
<BasicBlock
*> ProcessedPreds
;
1106 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
1107 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
1108 if (ProcessedPreds
.insert(PN
->getIncomingBlock(i
)).second
)
1109 PN
->setIncomingBlock(i
, codeReplacer
);
1111 // There were multiple entries in the PHI for this block, now there
1112 // is only one, so remove the duplicated entries.
1113 PN
->removeIncomingValue(i
, false);
1119 DEBUG(if (verifyFunction(*newFunction
))
1120 report_fatal_error("verifyFunction failed!"));