1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the interface to tear out a code region, such as an
10 // individual loop or a parallel section, into a new function, replacing it with
11 // a call to the new function.
13 //===----------------------------------------------------------------------===//
15 #include "llvm/Transforms/Utils/CodeExtractor.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Analysis/AssumptionCache.h"
24 #include "llvm/Analysis/BlockFrequencyInfo.h"
25 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
26 #include "llvm/Analysis/BranchProbabilityInfo.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/IR/Argument.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/CFG.h"
32 #include "llvm/IR/Constant.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DataLayout.h"
35 #include "llvm/IR/DerivedTypes.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/GlobalValue.h"
39 #include "llvm/IR/InstrTypes.h"
40 #include "llvm/IR/Instruction.h"
41 #include "llvm/IR/Instructions.h"
42 #include "llvm/IR/IntrinsicInst.h"
43 #include "llvm/IR/Intrinsics.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/IR/MDBuilder.h"
46 #include "llvm/IR/Module.h"
47 #include "llvm/IR/PatternMatch.h"
48 #include "llvm/IR/Type.h"
49 #include "llvm/IR/User.h"
50 #include "llvm/IR/Value.h"
51 #include "llvm/IR/Verifier.h"
52 #include "llvm/Pass.h"
53 #include "llvm/Support/BlockFrequency.h"
54 #include "llvm/Support/BranchProbability.h"
55 #include "llvm/Support/Casting.h"
56 #include "llvm/Support/CommandLine.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Support/ErrorHandling.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
61 #include "llvm/Transforms/Utils/Local.h"
71 using namespace llvm::PatternMatch
;
72 using ProfileCount
= Function::ProfileCount
;
74 #define DEBUG_TYPE "code-extractor"
76 // Provide a command-line option to aggregate function arguments into a struct
77 // for functions produced by the code extractor. This is useful when converting
78 // extracted functions to pthread-based code, as only one argument (void*) can
79 // be passed in to pthread_create().
81 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden
,
82 cl::desc("Aggregate arguments to code-extracted functions"));
84 /// Test whether a block is valid for extraction.
85 static bool isBlockValidForExtraction(const BasicBlock
&BB
,
86 const SetVector
<BasicBlock
*> &Result
,
87 bool AllowVarArgs
, bool AllowAlloca
) {
88 // taking the address of a basic block moved to another function is illegal
89 if (BB
.hasAddressTaken())
92 // don't hoist code that uses another basicblock address, as it's likely to
93 // lead to unexpected behavior, like cross-function jumps
94 SmallPtrSet
<User
const *, 16> Visited
;
95 SmallVector
<User
const *, 16> ToVisit
;
97 for (Instruction
const &Inst
: BB
)
98 ToVisit
.push_back(&Inst
);
100 while (!ToVisit
.empty()) {
101 User
const *Curr
= ToVisit
.pop_back_val();
102 if (!Visited
.insert(Curr
).second
)
104 if (isa
<BlockAddress
const>(Curr
))
105 return false; // even a reference to self is likely to be not compatible
107 if (isa
<Instruction
>(Curr
) && cast
<Instruction
>(Curr
)->getParent() != &BB
)
110 for (auto const &U
: Curr
->operands()) {
111 if (auto *UU
= dyn_cast
<User
>(U
))
112 ToVisit
.push_back(UU
);
116 // If explicitly requested, allow vastart and alloca. For invoke instructions
117 // verify that extraction is valid.
118 for (BasicBlock::const_iterator I
= BB
.begin(), E
= BB
.end(); I
!= E
; ++I
) {
119 if (isa
<AllocaInst
>(I
)) {
125 if (const auto *II
= dyn_cast
<InvokeInst
>(I
)) {
126 // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
127 // must be a part of the subgraph which is being extracted.
128 if (auto *UBB
= II
->getUnwindDest())
129 if (!Result
.count(UBB
))
134 // All catch handlers of a catchswitch instruction as well as the unwind
135 // destination must be in the subgraph.
136 if (const auto *CSI
= dyn_cast
<CatchSwitchInst
>(I
)) {
137 if (auto *UBB
= CSI
->getUnwindDest())
138 if (!Result
.count(UBB
))
140 for (auto *HBB
: CSI
->handlers())
141 if (!Result
.count(const_cast<BasicBlock
*>(HBB
)))
146 // Make sure that entire catch handler is within subgraph. It is sufficient
147 // to check that catch return's block is in the list.
148 if (const auto *CPI
= dyn_cast
<CatchPadInst
>(I
)) {
149 for (const auto *U
: CPI
->users())
150 if (const auto *CRI
= dyn_cast
<CatchReturnInst
>(U
))
151 if (!Result
.count(const_cast<BasicBlock
*>(CRI
->getParent())))
156 // And do similar checks for cleanup handler - the entire handler must be
157 // in subgraph which is going to be extracted. For cleanup return should
158 // additionally check that the unwind destination is also in the subgraph.
159 if (const auto *CPI
= dyn_cast
<CleanupPadInst
>(I
)) {
160 for (const auto *U
: CPI
->users())
161 if (const auto *CRI
= dyn_cast
<CleanupReturnInst
>(U
))
162 if (!Result
.count(const_cast<BasicBlock
*>(CRI
->getParent())))
166 if (const auto *CRI
= dyn_cast
<CleanupReturnInst
>(I
)) {
167 if (auto *UBB
= CRI
->getUnwindDest())
168 if (!Result
.count(UBB
))
173 if (const CallInst
*CI
= dyn_cast
<CallInst
>(I
)) {
174 if (const Function
*F
= CI
->getCalledFunction()) {
175 auto IID
= F
->getIntrinsicID();
176 if (IID
== Intrinsic::vastart
) {
183 // Currently, we miscompile outlined copies of eh_typid_for. There are
184 // proposals for fixing this in llvm.org/PR39545.
185 if (IID
== Intrinsic::eh_typeid_for
)
194 /// Build a set of blocks to extract if the input blocks are viable.
195 static SetVector
<BasicBlock
*>
196 buildExtractionBlockSet(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
,
197 bool AllowVarArgs
, bool AllowAlloca
) {
198 assert(!BBs
.empty() && "The set of blocks to extract must be non-empty");
199 SetVector
<BasicBlock
*> Result
;
201 // Loop over the blocks, adding them to our set-vector, and aborting with an
202 // empty set if we encounter invalid blocks.
203 for (BasicBlock
*BB
: BBs
) {
204 // If this block is dead, don't process it.
205 if (DT
&& !DT
->isReachableFromEntry(BB
))
208 if (!Result
.insert(BB
))
209 llvm_unreachable("Repeated basic blocks in extraction input");
212 for (auto *BB
: Result
) {
213 if (!isBlockValidForExtraction(*BB
, Result
, AllowVarArgs
, AllowAlloca
))
216 // Make sure that the first block is not a landing pad.
217 if (BB
== Result
.front()) {
219 LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
225 // All blocks other than the first must not have predecessors outside of
226 // the subgraph which is being extracted.
227 for (auto *PBB
: predecessors(BB
))
228 if (!Result
.count(PBB
)) {
230 dbgs() << "No blocks in this region may have entries from "
231 "outside the region except for the first block!\n");
239 CodeExtractor::CodeExtractor(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
,
240 bool AggregateArgs
, BlockFrequencyInfo
*BFI
,
241 BranchProbabilityInfo
*BPI
, AssumptionCache
*AC
,
242 bool AllowVarArgs
, bool AllowAlloca
,
244 : DT(DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
245 BPI(BPI
), AC(AC
), AllowVarArgs(AllowVarArgs
),
246 Blocks(buildExtractionBlockSet(BBs
, DT
, AllowVarArgs
, AllowAlloca
)),
249 CodeExtractor::CodeExtractor(DominatorTree
&DT
, Loop
&L
, bool AggregateArgs
,
250 BlockFrequencyInfo
*BFI
,
251 BranchProbabilityInfo
*BPI
, AssumptionCache
*AC
,
253 : DT(&DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
254 BPI(BPI
), AC(AC
), AllowVarArgs(false),
255 Blocks(buildExtractionBlockSet(L
.getBlocks(), &DT
,
256 /* AllowVarArgs */ false,
257 /* AllowAlloca */ false)),
260 /// definedInRegion - Return true if the specified value is defined in the
261 /// extracted region.
262 static bool definedInRegion(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
263 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
264 if (Blocks
.count(I
->getParent()))
269 /// definedInCaller - Return true if the specified value is defined in the
270 /// function being code extracted, but not in the region being extracted.
271 /// These values must be passed in as live-ins to the function.
272 static bool definedInCaller(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
273 if (isa
<Argument
>(V
)) return true;
274 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
275 if (!Blocks
.count(I
->getParent()))
280 static BasicBlock
*getCommonExitBlock(const SetVector
<BasicBlock
*> &Blocks
) {
281 BasicBlock
*CommonExitBlock
= nullptr;
282 auto hasNonCommonExitSucc
= [&](BasicBlock
*Block
) {
283 for (auto *Succ
: successors(Block
)) {
284 // Internal edges, ok.
285 if (Blocks
.count(Succ
))
287 if (!CommonExitBlock
) {
288 CommonExitBlock
= Succ
;
291 if (CommonExitBlock
== Succ
)
299 if (any_of(Blocks
, hasNonCommonExitSucc
))
302 return CommonExitBlock
;
305 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
306 Instruction
*Addr
) const {
307 AllocaInst
*AI
= cast
<AllocaInst
>(Addr
->stripInBoundsConstantOffsets());
308 Function
*Func
= (*Blocks
.begin())->getParent();
309 for (BasicBlock
&BB
: *Func
) {
310 if (Blocks
.count(&BB
))
312 for (Instruction
&II
: BB
) {
313 if (isa
<DbgInfoIntrinsic
>(II
))
316 unsigned Opcode
= II
.getOpcode();
317 Value
*MemAddr
= nullptr;
319 case Instruction::Store
:
320 case Instruction::Load
: {
321 if (Opcode
== Instruction::Store
) {
322 StoreInst
*SI
= cast
<StoreInst
>(&II
);
323 MemAddr
= SI
->getPointerOperand();
325 LoadInst
*LI
= cast
<LoadInst
>(&II
);
326 MemAddr
= LI
->getPointerOperand();
328 // Global variable can not be aliased with locals.
329 if (dyn_cast
<Constant
>(MemAddr
))
331 Value
*Base
= MemAddr
->stripInBoundsConstantOffsets();
332 if (!dyn_cast
<AllocaInst
>(Base
) || Base
== AI
)
337 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(&II
);
339 if (IntrInst
->isLifetimeStartOrEnd())
343 // Treat all the other cases conservatively if it has side effects.
344 if (II
.mayHaveSideEffects())
355 CodeExtractor::findOrCreateBlockForHoisting(BasicBlock
*CommonExitBlock
) {
356 BasicBlock
*SinglePredFromOutlineRegion
= nullptr;
357 assert(!Blocks
.count(CommonExitBlock
) &&
358 "Expect a block outside the region!");
359 for (auto *Pred
: predecessors(CommonExitBlock
)) {
360 if (!Blocks
.count(Pred
))
362 if (!SinglePredFromOutlineRegion
) {
363 SinglePredFromOutlineRegion
= Pred
;
364 } else if (SinglePredFromOutlineRegion
!= Pred
) {
365 SinglePredFromOutlineRegion
= nullptr;
370 if (SinglePredFromOutlineRegion
)
371 return SinglePredFromOutlineRegion
;
374 auto getFirstPHI
= [](BasicBlock
*BB
) {
375 BasicBlock::iterator I
= BB
->begin();
376 PHINode
*FirstPhi
= nullptr;
377 while (I
!= BB
->end()) {
378 PHINode
*Phi
= dyn_cast
<PHINode
>(I
);
388 // If there are any phi nodes, the single pred either exists or has already
389 // be created before code extraction.
390 assert(!getFirstPHI(CommonExitBlock
) && "Phi not expected");
393 BasicBlock
*NewExitBlock
= CommonExitBlock
->splitBasicBlock(
394 CommonExitBlock
->getFirstNonPHI()->getIterator());
396 for (auto PI
= pred_begin(CommonExitBlock
), PE
= pred_end(CommonExitBlock
);
398 BasicBlock
*Pred
= *PI
++;
399 if (Blocks
.count(Pred
))
401 Pred
->getTerminator()->replaceUsesOfWith(CommonExitBlock
, NewExitBlock
);
403 // Now add the old exit block to the outline region.
404 Blocks
.insert(CommonExitBlock
);
405 return CommonExitBlock
;
408 void CodeExtractor::findAllocas(ValueSet
&SinkCands
, ValueSet
&HoistCands
,
409 BasicBlock
*&ExitBlock
) const {
410 Function
*Func
= (*Blocks
.begin())->getParent();
411 ExitBlock
= getCommonExitBlock(Blocks
);
413 for (BasicBlock
&BB
: *Func
) {
414 if (Blocks
.count(&BB
))
416 for (Instruction
&II
: BB
) {
417 auto *AI
= dyn_cast
<AllocaInst
>(&II
);
421 // Find the pair of life time markers for address 'Addr' that are either
422 // defined inside the outline region or can legally be shrinkwrapped into
423 // the outline region. If there are not other untracked uses of the
424 // address, return the pair of markers if found; otherwise return a pair
426 auto GetLifeTimeMarkers
=
427 [&](Instruction
*Addr
, bool &SinkLifeStart
,
428 bool &HoistLifeEnd
) -> std::pair
<Instruction
*, Instruction
*> {
429 Instruction
*LifeStart
= nullptr, *LifeEnd
= nullptr;
431 for (User
*U
: Addr
->users()) {
432 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(U
);
434 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_start
) {
435 // Do not handle the case where AI has multiple start markers.
437 return std::make_pair
<Instruction
*>(nullptr, nullptr);
438 LifeStart
= IntrInst
;
440 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_end
) {
442 return std::make_pair
<Instruction
*>(nullptr, nullptr);
447 // Find untracked uses of the address, bail.
448 if (!definedInRegion(Blocks
, U
))
449 return std::make_pair
<Instruction
*>(nullptr, nullptr);
452 if (!LifeStart
|| !LifeEnd
)
453 return std::make_pair
<Instruction
*>(nullptr, nullptr);
455 SinkLifeStart
= !definedInRegion(Blocks
, LifeStart
);
456 HoistLifeEnd
= !definedInRegion(Blocks
, LifeEnd
);
457 // Do legality Check.
458 if ((SinkLifeStart
|| HoistLifeEnd
) &&
459 !isLegalToShrinkwrapLifetimeMarkers(Addr
))
460 return std::make_pair
<Instruction
*>(nullptr, nullptr);
462 // Check to see if we have a place to do hoisting, if not, bail.
463 if (HoistLifeEnd
&& !ExitBlock
)
464 return std::make_pair
<Instruction
*>(nullptr, nullptr);
466 return std::make_pair(LifeStart
, LifeEnd
);
469 bool SinkLifeStart
= false, HoistLifeEnd
= false;
470 auto Markers
= GetLifeTimeMarkers(AI
, SinkLifeStart
, HoistLifeEnd
);
474 SinkCands
.insert(Markers
.first
);
475 SinkCands
.insert(AI
);
477 HoistCands
.insert(Markers
.second
);
481 // Follow the bitcast.
482 Instruction
*MarkerAddr
= nullptr;
483 for (User
*U
: AI
->users()) {
484 if (U
->stripInBoundsConstantOffsets() == AI
) {
485 SinkLifeStart
= false;
486 HoistLifeEnd
= false;
487 Instruction
*Bitcast
= cast
<Instruction
>(U
);
488 Markers
= GetLifeTimeMarkers(Bitcast
, SinkLifeStart
, HoistLifeEnd
);
490 MarkerAddr
= Bitcast
;
495 // Found unknown use of AI.
496 if (!definedInRegion(Blocks
, U
)) {
497 MarkerAddr
= nullptr;
504 SinkCands
.insert(Markers
.first
);
505 if (!definedInRegion(Blocks
, MarkerAddr
))
506 SinkCands
.insert(MarkerAddr
);
507 SinkCands
.insert(AI
);
509 HoistCands
.insert(Markers
.second
);
515 void CodeExtractor::findInputsOutputs(ValueSet
&Inputs
, ValueSet
&Outputs
,
516 const ValueSet
&SinkCands
) const {
517 for (BasicBlock
*BB
: Blocks
) {
518 // If a used value is defined outside the region, it's an input. If an
519 // instruction is used outside the region, it's an output.
520 for (Instruction
&II
: *BB
) {
521 for (User::op_iterator OI
= II
.op_begin(), OE
= II
.op_end(); OI
!= OE
;
524 if (!SinkCands
.count(V
) && definedInCaller(Blocks
, V
))
528 for (User
*U
: II
.users())
529 if (!definedInRegion(Blocks
, U
)) {
537 /// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
538 /// of the region, we need to split the entry block of the region so that the
539 /// PHI node is easier to deal with.
540 void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock
*&Header
) {
541 unsigned NumPredsFromRegion
= 0;
542 unsigned NumPredsOutsideRegion
= 0;
544 if (Header
!= &Header
->getParent()->getEntryBlock()) {
545 PHINode
*PN
= dyn_cast
<PHINode
>(Header
->begin());
546 if (!PN
) return; // No PHI nodes.
548 // If the header node contains any PHI nodes, check to see if there is more
549 // than one entry from outside the region. If so, we need to sever the
550 // header block into two.
551 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
552 if (Blocks
.count(PN
->getIncomingBlock(i
)))
553 ++NumPredsFromRegion
;
555 ++NumPredsOutsideRegion
;
557 // If there is one (or fewer) predecessor from outside the region, we don't
558 // need to do anything special.
559 if (NumPredsOutsideRegion
<= 1) return;
562 // Otherwise, we need to split the header block into two pieces: one
563 // containing PHI nodes merging values from outside of the region, and a
564 // second that contains all of the code for the block and merges back any
565 // incoming values from inside of the region.
566 BasicBlock
*NewBB
= SplitBlock(Header
, Header
->getFirstNonPHI(), DT
);
568 // We only want to code extract the second block now, and it becomes the new
569 // header of the region.
570 BasicBlock
*OldPred
= Header
;
571 Blocks
.remove(OldPred
);
572 Blocks
.insert(NewBB
);
575 // Okay, now we need to adjust the PHI nodes and any branches from within the
576 // region to go to the new header block instead of the old header block.
577 if (NumPredsFromRegion
) {
578 PHINode
*PN
= cast
<PHINode
>(OldPred
->begin());
579 // Loop over all of the predecessors of OldPred that are in the region,
580 // changing them to branch to NewBB instead.
581 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
582 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
583 Instruction
*TI
= PN
->getIncomingBlock(i
)->getTerminator();
584 TI
->replaceUsesOfWith(OldPred
, NewBB
);
587 // Okay, everything within the region is now branching to the right block, we
588 // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
589 BasicBlock::iterator AfterPHIs
;
590 for (AfterPHIs
= OldPred
->begin(); isa
<PHINode
>(AfterPHIs
); ++AfterPHIs
) {
591 PHINode
*PN
= cast
<PHINode
>(AfterPHIs
);
592 // Create a new PHI node in the new region, which has an incoming value
593 // from OldPred of PN.
594 PHINode
*NewPN
= PHINode::Create(PN
->getType(), 1 + NumPredsFromRegion
,
595 PN
->getName() + ".ce", &NewBB
->front());
596 PN
->replaceAllUsesWith(NewPN
);
597 NewPN
->addIncoming(PN
, OldPred
);
599 // Loop over all of the incoming value in PN, moving them to NewPN if they
600 // are from the extracted region.
601 for (unsigned i
= 0; i
!= PN
->getNumIncomingValues(); ++i
) {
602 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
603 NewPN
->addIncoming(PN
->getIncomingValue(i
), PN
->getIncomingBlock(i
));
604 PN
->removeIncomingValue(i
);
612 /// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
613 /// outlined region, we split these PHIs on two: one with inputs from region
614 /// and other with remaining incoming blocks; then first PHIs are placed in
616 void CodeExtractor::severSplitPHINodesOfExits(
617 const SmallPtrSetImpl
<BasicBlock
*> &Exits
) {
618 for (BasicBlock
*ExitBB
: Exits
) {
619 BasicBlock
*NewBB
= nullptr;
621 for (PHINode
&PN
: ExitBB
->phis()) {
622 // Find all incoming values from the outlining region.
623 SmallVector
<unsigned, 2> IncomingVals
;
624 for (unsigned i
= 0; i
< PN
.getNumIncomingValues(); ++i
)
625 if (Blocks
.count(PN
.getIncomingBlock(i
)))
626 IncomingVals
.push_back(i
);
628 // Do not process PHI if there is one (or fewer) predecessor from region.
629 // If PHI has exactly one predecessor from region, only this one incoming
630 // will be replaced on codeRepl block, so it should be safe to skip PHI.
631 if (IncomingVals
.size() <= 1)
634 // Create block for new PHIs and add it to the list of outlined if it
635 // wasn't done before.
637 NewBB
= BasicBlock::Create(ExitBB
->getContext(),
638 ExitBB
->getName() + ".split",
639 ExitBB
->getParent(), ExitBB
);
640 SmallVector
<BasicBlock
*, 4> Preds(pred_begin(ExitBB
),
642 for (BasicBlock
*PredBB
: Preds
)
643 if (Blocks
.count(PredBB
))
644 PredBB
->getTerminator()->replaceUsesOfWith(ExitBB
, NewBB
);
645 BranchInst::Create(ExitBB
, NewBB
);
646 Blocks
.insert(NewBB
);
651 PHINode::Create(PN
.getType(), IncomingVals
.size(),
652 PN
.getName() + ".ce", NewBB
->getFirstNonPHI());
653 for (unsigned i
: IncomingVals
)
654 NewPN
->addIncoming(PN
.getIncomingValue(i
), PN
.getIncomingBlock(i
));
655 for (unsigned i
: reverse(IncomingVals
))
656 PN
.removeIncomingValue(i
, false);
657 PN
.addIncoming(NewPN
, NewBB
);
662 void CodeExtractor::splitReturnBlocks() {
663 for (BasicBlock
*Block
: Blocks
)
664 if (ReturnInst
*RI
= dyn_cast
<ReturnInst
>(Block
->getTerminator())) {
666 Block
->splitBasicBlock(RI
->getIterator(), Block
->getName() + ".ret");
668 // Old dominates New. New node dominates all other nodes dominated
670 DomTreeNode
*OldNode
= DT
->getNode(Block
);
671 SmallVector
<DomTreeNode
*, 8> Children(OldNode
->begin(),
674 DomTreeNode
*NewNode
= DT
->addNewBlock(New
, Block
);
676 for (DomTreeNode
*I
: Children
)
677 DT
->changeImmediateDominator(I
, NewNode
);
682 /// constructFunction - make a function based on inputs and outputs, as follows:
683 /// f(in0, ..., inN, out0, ..., outN)
684 Function
*CodeExtractor::constructFunction(const ValueSet
&inputs
,
685 const ValueSet
&outputs
,
687 BasicBlock
*newRootNode
,
688 BasicBlock
*newHeader
,
689 Function
*oldFunction
,
691 LLVM_DEBUG(dbgs() << "inputs: " << inputs
.size() << "\n");
692 LLVM_DEBUG(dbgs() << "outputs: " << outputs
.size() << "\n");
694 // This function returns unsigned, outputs will go back by reference.
695 switch (NumExitBlocks
) {
697 case 1: RetTy
= Type::getVoidTy(header
->getContext()); break;
698 case 2: RetTy
= Type::getInt1Ty(header
->getContext()); break;
699 default: RetTy
= Type::getInt16Ty(header
->getContext()); break;
702 std::vector
<Type
*> paramTy
;
704 // Add the types of the input values to the function's argument list
705 for (Value
*value
: inputs
) {
706 LLVM_DEBUG(dbgs() << "value used in func: " << *value
<< "\n");
707 paramTy
.push_back(value
->getType());
710 // Add the types of the output values to the function's argument list.
711 for (Value
*output
: outputs
) {
712 LLVM_DEBUG(dbgs() << "instr used in func: " << *output
<< "\n");
714 paramTy
.push_back(output
->getType());
716 paramTy
.push_back(PointerType::getUnqual(output
->getType()));
720 dbgs() << "Function type: " << *RetTy
<< " f(";
721 for (Type
*i
: paramTy
)
722 dbgs() << *i
<< ", ";
726 StructType
*StructTy
;
727 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
728 StructTy
= StructType::get(M
->getContext(), paramTy
);
730 paramTy
.push_back(PointerType::getUnqual(StructTy
));
732 FunctionType
*funcType
=
733 FunctionType::get(RetTy
, paramTy
,
734 AllowVarArgs
&& oldFunction
->isVarArg());
736 std::string SuffixToUse
=
738 ? (header
->getName().empty() ? "extracted" : header
->getName().str())
740 // Create the new function
741 Function
*newFunction
= Function::Create(
742 funcType
, GlobalValue::InternalLinkage
, oldFunction
->getAddressSpace(),
743 oldFunction
->getName() + "." + SuffixToUse
, M
);
744 // If the old function is no-throw, so is the new one.
745 if (oldFunction
->doesNotThrow())
746 newFunction
->setDoesNotThrow();
748 // Inherit the uwtable attribute if we need to.
749 if (oldFunction
->hasUWTable())
750 newFunction
->setHasUWTable();
752 // Inherit all of the target dependent attributes and white-listed
753 // target independent attributes.
754 // (e.g. If the extracted region contains a call to an x86.sse
755 // instruction we need to make sure that the extracted region has the
756 // "target-features" attribute allowing it to be lowered.
757 // FIXME: This should be changed to check to see if a specific
758 // attribute can not be inherited.
759 for (const auto &Attr
: oldFunction
->getAttributes().getFnAttributes()) {
760 if (Attr
.isStringAttribute()) {
761 if (Attr
.getKindAsString() == "thunk")
764 switch (Attr
.getKindAsEnum()) {
765 // Those attributes cannot be propagated safely. Explicitly list them
766 // here so we get a warning if new attributes are added. This list also
767 // includes non-function attributes.
768 case Attribute::Alignment
:
769 case Attribute::AllocSize
:
770 case Attribute::ArgMemOnly
:
771 case Attribute::Builtin
:
772 case Attribute::ByVal
:
773 case Attribute::Convergent
:
774 case Attribute::Dereferenceable
:
775 case Attribute::DereferenceableOrNull
:
776 case Attribute::InAlloca
:
777 case Attribute::InReg
:
778 case Attribute::InaccessibleMemOnly
:
779 case Attribute::InaccessibleMemOrArgMemOnly
:
780 case Attribute::JumpTable
:
781 case Attribute::Naked
:
782 case Attribute::Nest
:
783 case Attribute::NoAlias
:
784 case Attribute::NoBuiltin
:
785 case Attribute::NoCapture
:
786 case Attribute::NoReturn
:
787 case Attribute::None
:
788 case Attribute::NonNull
:
789 case Attribute::ReadNone
:
790 case Attribute::ReadOnly
:
791 case Attribute::Returned
:
792 case Attribute::ReturnsTwice
:
793 case Attribute::SExt
:
794 case Attribute::Speculatable
:
795 case Attribute::StackAlignment
:
796 case Attribute::StructRet
:
797 case Attribute::SwiftError
:
798 case Attribute::SwiftSelf
:
799 case Attribute::WriteOnly
:
800 case Attribute::ZExt
:
801 case Attribute::EndAttrKinds
:
803 // Those attributes should be safe to propagate to the extracted function.
804 case Attribute::AlwaysInline
:
805 case Attribute::Cold
:
806 case Attribute::NoRecurse
:
807 case Attribute::InlineHint
:
808 case Attribute::MinSize
:
809 case Attribute::NoDuplicate
:
810 case Attribute::NoImplicitFloat
:
811 case Attribute::NoInline
:
812 case Attribute::NonLazyBind
:
813 case Attribute::NoRedZone
:
814 case Attribute::NoUnwind
:
815 case Attribute::OptForFuzzing
:
816 case Attribute::OptimizeNone
:
817 case Attribute::OptimizeForSize
:
818 case Attribute::SafeStack
:
819 case Attribute::ShadowCallStack
:
820 case Attribute::SanitizeAddress
:
821 case Attribute::SanitizeMemory
:
822 case Attribute::SanitizeThread
:
823 case Attribute::SanitizeHWAddress
:
824 case Attribute::SpeculativeLoadHardening
:
825 case Attribute::StackProtect
:
826 case Attribute::StackProtectReq
:
827 case Attribute::StackProtectStrong
:
828 case Attribute::StrictFP
:
829 case Attribute::UWTable
:
830 case Attribute::NoCfCheck
:
834 newFunction
->addFnAttr(Attr
);
836 newFunction
->getBasicBlockList().push_back(newRootNode
);
838 // Create an iterator to name all of the arguments we inserted.
839 Function::arg_iterator AI
= newFunction
->arg_begin();
841 // Rewrite all users of the inputs in the extracted region to use the
842 // arguments (or appropriate addressing into struct) instead.
843 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
847 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(header
->getContext()));
848 Idx
[1] = ConstantInt::get(Type::getInt32Ty(header
->getContext()), i
);
849 Instruction
*TI
= newFunction
->begin()->getTerminator();
850 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
851 StructTy
, &*AI
, Idx
, "gep_" + inputs
[i
]->getName(), TI
);
852 RewriteVal
= new LoadInst(StructTy
->getElementType(i
), GEP
,
853 "loadgep_" + inputs
[i
]->getName(), TI
);
857 std::vector
<User
*> Users(inputs
[i
]->user_begin(), inputs
[i
]->user_end());
858 for (User
*use
: Users
)
859 if (Instruction
*inst
= dyn_cast
<Instruction
>(use
))
860 if (Blocks
.count(inst
->getParent()))
861 inst
->replaceUsesOfWith(inputs
[i
], RewriteVal
);
864 // Set names for input and output arguments.
865 if (!AggregateArgs
) {
866 AI
= newFunction
->arg_begin();
867 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
, ++AI
)
868 AI
->setName(inputs
[i
]->getName());
869 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
, ++AI
)
870 AI
->setName(outputs
[i
]->getName()+".out");
873 // Rewrite branches to basic blocks outside of the loop to new dummy blocks
874 // within the new function. This must be done before we lose track of which
875 // blocks were originally in the code region.
876 std::vector
<User
*> Users(header
->user_begin(), header
->user_end());
877 for (unsigned i
= 0, e
= Users
.size(); i
!= e
; ++i
)
878 // The BasicBlock which contains the branch is not in the region
879 // modify the branch target to a new block
880 if (Instruction
*I
= dyn_cast
<Instruction
>(Users
[i
]))
881 if (I
->isTerminator() && !Blocks
.count(I
->getParent()) &&
882 I
->getParent()->getParent() == oldFunction
)
883 I
->replaceUsesOfWith(header
, newHeader
);
888 /// Erase lifetime.start markers which reference inputs to the extraction
889 /// region, and insert the referenced memory into \p LifetimesStart.
891 /// The extraction region is defined by a set of blocks (\p Blocks), and a set
892 /// of allocas which will be moved from the caller function into the extracted
893 /// function (\p SunkAllocas).
894 static void eraseLifetimeMarkersOnInputs(const SetVector
<BasicBlock
*> &Blocks
,
895 const SetVector
<Value
*> &SunkAllocas
,
896 SetVector
<Value
*> &LifetimesStart
) {
897 for (BasicBlock
*BB
: Blocks
) {
898 for (auto It
= BB
->begin(), End
= BB
->end(); It
!= End
;) {
899 auto *II
= dyn_cast
<IntrinsicInst
>(&*It
);
901 if (!II
|| !II
->isLifetimeStartOrEnd())
904 // Get the memory operand of the lifetime marker. If the underlying
905 // object is a sunk alloca, or is otherwise defined in the extraction
906 // region, the lifetime marker must not be erased.
907 Value
*Mem
= II
->getOperand(1)->stripInBoundsOffsets();
908 if (SunkAllocas
.count(Mem
) || definedInRegion(Blocks
, Mem
))
911 if (II
->getIntrinsicID() == Intrinsic::lifetime_start
)
912 LifetimesStart
.insert(Mem
);
913 II
->eraseFromParent();
918 /// Insert lifetime start/end markers surrounding the call to the new function
919 /// for objects defined in the caller.
920 static void insertLifetimeMarkersSurroundingCall(
921 Module
*M
, ArrayRef
<Value
*> LifetimesStart
, ArrayRef
<Value
*> LifetimesEnd
,
923 LLVMContext
&Ctx
= M
->getContext();
924 auto Int8PtrTy
= Type::getInt8PtrTy(Ctx
);
925 auto NegativeOne
= ConstantInt::getSigned(Type::getInt64Ty(Ctx
), -1);
926 Instruction
*Term
= TheCall
->getParent()->getTerminator();
928 // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts
929 // needed to satisfy this requirement so they may be reused.
930 DenseMap
<Value
*, Value
*> Bitcasts
;
932 // Emit lifetime markers for the pointers given in \p Objects. Insert the
933 // markers before the call if \p InsertBefore, and after the call otherwise.
934 auto insertMarkers
= [&](Function
*MarkerFunc
, ArrayRef
<Value
*> Objects
,
936 for (Value
*Mem
: Objects
) {
937 assert((!isa
<Instruction
>(Mem
) || cast
<Instruction
>(Mem
)->getFunction() ==
938 TheCall
->getFunction()) &&
939 "Input memory not defined in original function");
940 Value
*&MemAsI8Ptr
= Bitcasts
[Mem
];
942 if (Mem
->getType() == Int8PtrTy
)
946 CastInst::CreatePointerCast(Mem
, Int8PtrTy
, "lt.cast", TheCall
);
949 auto Marker
= CallInst::Create(MarkerFunc
, {NegativeOne
, MemAsI8Ptr
});
951 Marker
->insertBefore(TheCall
);
953 Marker
->insertBefore(Term
);
957 if (!LifetimesStart
.empty()) {
958 auto StartFn
= llvm::Intrinsic::getDeclaration(
959 M
, llvm::Intrinsic::lifetime_start
, Int8PtrTy
);
960 insertMarkers(StartFn
, LifetimesStart
, /*InsertBefore=*/true);
963 if (!LifetimesEnd
.empty()) {
964 auto EndFn
= llvm::Intrinsic::getDeclaration(
965 M
, llvm::Intrinsic::lifetime_end
, Int8PtrTy
);
966 insertMarkers(EndFn
, LifetimesEnd
, /*InsertBefore=*/false);
970 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
971 /// the call instruction, splitting any PHI nodes in the header block as
973 CallInst
*CodeExtractor::emitCallAndSwitchStatement(Function
*newFunction
,
974 BasicBlock
*codeReplacer
,
977 // Emit a call to the new function, passing in: *pointer to struct (if
978 // aggregating parameters), or plan inputs and allocated memory for outputs
979 std::vector
<Value
*> params
, StructValues
, ReloadOutputs
, Reloads
;
981 Module
*M
= newFunction
->getParent();
982 LLVMContext
&Context
= M
->getContext();
983 const DataLayout
&DL
= M
->getDataLayout();
984 CallInst
*call
= nullptr;
986 // Add inputs as params, or to be filled into the struct
988 SmallVector
<unsigned, 1> SwiftErrorArgs
;
989 for (Value
*input
: inputs
) {
991 StructValues
.push_back(input
);
993 params
.push_back(input
);
994 if (input
->isSwiftError())
995 SwiftErrorArgs
.push_back(ArgNo
);
1000 // Create allocas for the outputs
1001 for (Value
*output
: outputs
) {
1002 if (AggregateArgs
) {
1003 StructValues
.push_back(output
);
1005 AllocaInst
*alloca
=
1006 new AllocaInst(output
->getType(), DL
.getAllocaAddrSpace(),
1007 nullptr, output
->getName() + ".loc",
1008 &codeReplacer
->getParent()->front().front());
1009 ReloadOutputs
.push_back(alloca
);
1010 params
.push_back(alloca
);
1014 StructType
*StructArgTy
= nullptr;
1015 AllocaInst
*Struct
= nullptr;
1016 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
1017 std::vector
<Type
*> ArgTypes
;
1018 for (ValueSet::iterator v
= StructValues
.begin(),
1019 ve
= StructValues
.end(); v
!= ve
; ++v
)
1020 ArgTypes
.push_back((*v
)->getType());
1022 // Allocate a struct at the beginning of this function
1023 StructArgTy
= StructType::get(newFunction
->getContext(), ArgTypes
);
1024 Struct
= new AllocaInst(StructArgTy
, DL
.getAllocaAddrSpace(), nullptr,
1026 &codeReplacer
->getParent()->front().front());
1027 params
.push_back(Struct
);
1029 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
1031 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1032 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), i
);
1033 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1034 StructArgTy
, Struct
, Idx
, "gep_" + StructValues
[i
]->getName());
1035 codeReplacer
->getInstList().push_back(GEP
);
1036 StoreInst
*SI
= new StoreInst(StructValues
[i
], GEP
);
1037 codeReplacer
->getInstList().push_back(SI
);
1041 // Emit the call to the function
1042 call
= CallInst::Create(newFunction
, params
,
1043 NumExitBlocks
> 1 ? "targetBlock" : "");
1044 // Add debug location to the new call, if the original function has debug
1045 // info. In that case, the terminator of the entry block of the extracted
1046 // function contains the first debug location of the extracted function,
1047 // set in extractCodeRegion.
1048 if (codeReplacer
->getParent()->getSubprogram()) {
1049 if (auto DL
= newFunction
->getEntryBlock().getTerminator()->getDebugLoc())
1050 call
->setDebugLoc(DL
);
1052 codeReplacer
->getInstList().push_back(call
);
1054 // Set swifterror parameter attributes.
1055 for (unsigned SwiftErrArgNo
: SwiftErrorArgs
) {
1056 call
->addParamAttr(SwiftErrArgNo
, Attribute::SwiftError
);
1057 newFunction
->addParamAttr(SwiftErrArgNo
, Attribute::SwiftError
);
1060 Function::arg_iterator OutputArgBegin
= newFunction
->arg_begin();
1061 unsigned FirstOut
= inputs
.size();
1063 std::advance(OutputArgBegin
, inputs
.size());
1065 // Reload the outputs passed in by reference.
1066 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
) {
1067 Value
*Output
= nullptr;
1068 if (AggregateArgs
) {
1070 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1071 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), FirstOut
+ i
);
1072 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1073 StructArgTy
, Struct
, Idx
, "gep_reload_" + outputs
[i
]->getName());
1074 codeReplacer
->getInstList().push_back(GEP
);
1077 Output
= ReloadOutputs
[i
];
1079 LoadInst
*load
= new LoadInst(outputs
[i
]->getType(), Output
,
1080 outputs
[i
]->getName() + ".reload");
1081 Reloads
.push_back(load
);
1082 codeReplacer
->getInstList().push_back(load
);
1083 std::vector
<User
*> Users(outputs
[i
]->user_begin(), outputs
[i
]->user_end());
1084 for (unsigned u
= 0, e
= Users
.size(); u
!= e
; ++u
) {
1085 Instruction
*inst
= cast
<Instruction
>(Users
[u
]);
1086 if (!Blocks
.count(inst
->getParent()))
1087 inst
->replaceUsesOfWith(outputs
[i
], load
);
1091 // Now we can emit a switch statement using the call as a value.
1092 SwitchInst
*TheSwitch
=
1093 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context
)),
1094 codeReplacer
, 0, codeReplacer
);
1096 // Since there may be multiple exits from the original region, make the new
1097 // function return an unsigned, switch on that number. This loop iterates
1098 // over all of the blocks in the extracted region, updating any terminator
1099 // instructions in the to-be-extracted region that branch to blocks that are
1100 // not in the region to be extracted.
1101 std::map
<BasicBlock
*, BasicBlock
*> ExitBlockMap
;
1103 unsigned switchVal
= 0;
1104 for (BasicBlock
*Block
: Blocks
) {
1105 Instruction
*TI
= Block
->getTerminator();
1106 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
!= e
; ++i
)
1107 if (!Blocks
.count(TI
->getSuccessor(i
))) {
1108 BasicBlock
*OldTarget
= TI
->getSuccessor(i
);
1109 // add a new basic block which returns the appropriate value
1110 BasicBlock
*&NewTarget
= ExitBlockMap
[OldTarget
];
1112 // If we don't already have an exit stub for this non-extracted
1113 // destination, create one now!
1114 NewTarget
= BasicBlock::Create(Context
,
1115 OldTarget
->getName() + ".exitStub",
1117 unsigned SuccNum
= switchVal
++;
1119 Value
*brVal
= nullptr;
1120 switch (NumExitBlocks
) {
1122 case 1: break; // No value needed.
1123 case 2: // Conditional branch, return a bool
1124 brVal
= ConstantInt::get(Type::getInt1Ty(Context
), !SuccNum
);
1127 brVal
= ConstantInt::get(Type::getInt16Ty(Context
), SuccNum
);
1131 ReturnInst::Create(Context
, brVal
, NewTarget
);
1133 // Update the switch instruction.
1134 TheSwitch
->addCase(ConstantInt::get(Type::getInt16Ty(Context
),
1139 // rewrite the original branch instruction with this new target
1140 TI
->setSuccessor(i
, NewTarget
);
1144 // Store the arguments right after the definition of output value.
1145 // This should be proceeded after creating exit stubs to be ensure that invoke
1146 // result restore will be placed in the outlined function.
1147 Function::arg_iterator OAI
= OutputArgBegin
;
1148 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
) {
1149 auto *OutI
= dyn_cast
<Instruction
>(outputs
[i
]);
1153 // Find proper insertion point.
1154 BasicBlock::iterator InsertPt
;
1155 // In case OutI is an invoke, we insert the store at the beginning in the
1156 // 'normal destination' BB. Otherwise we insert the store right after OutI.
1157 if (auto *InvokeI
= dyn_cast
<InvokeInst
>(OutI
))
1158 InsertPt
= InvokeI
->getNormalDest()->getFirstInsertionPt();
1159 else if (auto *Phi
= dyn_cast
<PHINode
>(OutI
))
1160 InsertPt
= Phi
->getParent()->getFirstInsertionPt();
1162 InsertPt
= std::next(OutI
->getIterator());
1164 Instruction
*InsertBefore
= &*InsertPt
;
1165 assert((InsertBefore
->getFunction() == newFunction
||
1166 Blocks
.count(InsertBefore
->getParent())) &&
1167 "InsertPt should be in new function");
1168 assert(OAI
!= newFunction
->arg_end() &&
1169 "Number of output arguments should match "
1170 "the amount of defined values");
1171 if (AggregateArgs
) {
1173 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1174 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), FirstOut
+ i
);
1175 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1176 StructArgTy
, &*OAI
, Idx
, "gep_" + outputs
[i
]->getName(),
1178 new StoreInst(outputs
[i
], GEP
, InsertBefore
);
1179 // Since there should be only one struct argument aggregating
1180 // all the output values, we shouldn't increment OAI, which always
1181 // points to the struct argument, in this case.
1183 new StoreInst(outputs
[i
], &*OAI
, InsertBefore
);
1188 // Now that we've done the deed, simplify the switch instruction.
1189 Type
*OldFnRetTy
= TheSwitch
->getParent()->getParent()->getReturnType();
1190 switch (NumExitBlocks
) {
1192 // There are no successors (the block containing the switch itself), which
1193 // means that previously this was the last part of the function, and hence
1194 // this should be rewritten as a `ret'
1196 // Check if the function should return a value
1197 if (OldFnRetTy
->isVoidTy()) {
1198 ReturnInst::Create(Context
, nullptr, TheSwitch
); // Return void
1199 } else if (OldFnRetTy
== TheSwitch
->getCondition()->getType()) {
1200 // return what we have
1201 ReturnInst::Create(Context
, TheSwitch
->getCondition(), TheSwitch
);
1203 // Otherwise we must have code extracted an unwind or something, just
1204 // return whatever we want.
1205 ReturnInst::Create(Context
,
1206 Constant::getNullValue(OldFnRetTy
), TheSwitch
);
1209 TheSwitch
->eraseFromParent();
1212 // Only a single destination, change the switch into an unconditional
1214 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
);
1215 TheSwitch
->eraseFromParent();
1218 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
->getSuccessor(2),
1220 TheSwitch
->eraseFromParent();
1223 // Otherwise, make the default destination of the switch instruction be one
1224 // of the other successors.
1225 TheSwitch
->setCondition(call
);
1226 TheSwitch
->setDefaultDest(TheSwitch
->getSuccessor(NumExitBlocks
));
1227 // Remove redundant case
1228 TheSwitch
->removeCase(SwitchInst::CaseIt(TheSwitch
, NumExitBlocks
-1));
1232 // Insert lifetime markers around the reloads of any output values. The
1233 // allocas output values are stored in are only in-use in the codeRepl block.
1234 insertLifetimeMarkersSurroundingCall(M
, ReloadOutputs
, ReloadOutputs
, call
);
1239 void CodeExtractor::moveCodeToFunction(Function
*newFunction
) {
1240 Function
*oldFunc
= (*Blocks
.begin())->getParent();
1241 Function::BasicBlockListType
&oldBlocks
= oldFunc
->getBasicBlockList();
1242 Function::BasicBlockListType
&newBlocks
= newFunction
->getBasicBlockList();
1244 for (BasicBlock
*Block
: Blocks
) {
1245 // Delete the basic block from the old function, and the list of blocks
1246 oldBlocks
.remove(Block
);
1248 // Insert this basic block into the new function
1249 newBlocks
.push_back(Block
);
1251 // Remove @llvm.assume calls that were moved to the new function from the
1252 // old function's assumption cache.
1254 for (auto &I
: *Block
)
1255 if (match(&I
, m_Intrinsic
<Intrinsic::assume
>()))
1256 AC
->unregisterAssumption(cast
<CallInst
>(&I
));
1260 void CodeExtractor::calculateNewCallTerminatorWeights(
1261 BasicBlock
*CodeReplacer
,
1262 DenseMap
<BasicBlock
*, BlockFrequency
> &ExitWeights
,
1263 BranchProbabilityInfo
*BPI
) {
1264 using Distribution
= BlockFrequencyInfoImplBase::Distribution
;
1265 using BlockNode
= BlockFrequencyInfoImplBase::BlockNode
;
1267 // Update the branch weights for the exit block.
1268 Instruction
*TI
= CodeReplacer
->getTerminator();
1269 SmallVector
<unsigned, 8> BranchWeights(TI
->getNumSuccessors(), 0);
1271 // Block Frequency distribution with dummy node.
1272 Distribution BranchDist
;
1274 // Add each of the frequencies of the successors.
1275 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
< e
; ++i
) {
1276 BlockNode
ExitNode(i
);
1277 uint64_t ExitFreq
= ExitWeights
[TI
->getSuccessor(i
)].getFrequency();
1279 BranchDist
.addExit(ExitNode
, ExitFreq
);
1281 BPI
->setEdgeProbability(CodeReplacer
, i
, BranchProbability::getZero());
1284 // Check for no total weight.
1285 if (BranchDist
.Total
== 0)
1288 // Normalize the distribution so that they can fit in unsigned.
1289 BranchDist
.normalize();
1291 // Create normalized branch weights and set the metadata.
1292 for (unsigned I
= 0, E
= BranchDist
.Weights
.size(); I
< E
; ++I
) {
1293 const auto &Weight
= BranchDist
.Weights
[I
];
1295 // Get the weight and update the current BFI.
1296 BranchWeights
[Weight
.TargetNode
.Index
] = Weight
.Amount
;
1297 BranchProbability
BP(Weight
.Amount
, BranchDist
.Total
);
1298 BPI
->setEdgeProbability(CodeReplacer
, Weight
.TargetNode
.Index
, BP
);
1301 LLVMContext::MD_prof
,
1302 MDBuilder(TI
->getContext()).createBranchWeights(BranchWeights
));
1305 Function
*CodeExtractor::extractCodeRegion() {
1309 // Assumption: this is a single-entry code region, and the header is the first
1310 // block in the region.
1311 BasicBlock
*header
= *Blocks
.begin();
1312 Function
*oldFunction
= header
->getParent();
1314 // For functions with varargs, check that varargs handling is only done in the
1315 // outlined function, i.e vastart and vaend are only used in outlined blocks.
1316 if (AllowVarArgs
&& oldFunction
->getFunctionType()->isVarArg()) {
1317 auto containsVarArgIntrinsic
= [](Instruction
&I
) {
1318 if (const CallInst
*CI
= dyn_cast
<CallInst
>(&I
))
1319 if (const Function
*F
= CI
->getCalledFunction())
1320 return F
->getIntrinsicID() == Intrinsic::vastart
||
1321 F
->getIntrinsicID() == Intrinsic::vaend
;
1325 for (auto &BB
: *oldFunction
) {
1326 if (Blocks
.count(&BB
))
1328 if (llvm::any_of(BB
, containsVarArgIntrinsic
))
1332 ValueSet inputs
, outputs
, SinkingCands
, HoistingCands
;
1333 BasicBlock
*CommonExit
= nullptr;
1335 // Calculate the entry frequency of the new function before we change the root
1337 BlockFrequency EntryFreq
;
1339 assert(BPI
&& "Both BPI and BFI are required to preserve profile info");
1340 for (BasicBlock
*Pred
: predecessors(header
)) {
1341 if (Blocks
.count(Pred
))
1344 BFI
->getBlockFreq(Pred
) * BPI
->getEdgeProbability(Pred
, header
);
1348 // If we have any return instructions in the region, split those blocks so
1349 // that the return is not in the region.
1350 splitReturnBlocks();
1352 // Calculate the exit blocks for the extracted region and the total exit
1353 // weights for each of those blocks.
1354 DenseMap
<BasicBlock
*, BlockFrequency
> ExitWeights
;
1355 SmallPtrSet
<BasicBlock
*, 1> ExitBlocks
;
1356 for (BasicBlock
*Block
: Blocks
) {
1357 for (succ_iterator SI
= succ_begin(Block
), SE
= succ_end(Block
); SI
!= SE
;
1359 if (!Blocks
.count(*SI
)) {
1360 // Update the branch weight for this successor.
1362 BlockFrequency
&BF
= ExitWeights
[*SI
];
1363 BF
+= BFI
->getBlockFreq(Block
) * BPI
->getEdgeProbability(Block
, *SI
);
1365 ExitBlocks
.insert(*SI
);
1369 NumExitBlocks
= ExitBlocks
.size();
1371 // If we have to split PHI nodes of the entry or exit blocks, do so now.
1372 severSplitPHINodesOfEntry(header
);
1373 severSplitPHINodesOfExits(ExitBlocks
);
1375 // This takes place of the original loop
1376 BasicBlock
*codeReplacer
= BasicBlock::Create(header
->getContext(),
1377 "codeRepl", oldFunction
,
1380 // The new function needs a root node because other nodes can branch to the
1381 // head of the region, but the entry node of a function cannot have preds.
1382 BasicBlock
*newFuncRoot
= BasicBlock::Create(header
->getContext(),
1384 auto *BranchI
= BranchInst::Create(header
);
1385 // If the original function has debug info, we have to add a debug location
1386 // to the new branch instruction from the artificial entry block.
1387 // We use the debug location of the first instruction in the extracted
1388 // blocks, as there is no other equivalent line in the source code.
1389 if (oldFunction
->getSubprogram()) {
1390 any_of(Blocks
, [&BranchI
](const BasicBlock
*BB
) {
1391 return any_of(*BB
, [&BranchI
](const Instruction
&I
) {
1392 if (!I
.getDebugLoc())
1394 BranchI
->setDebugLoc(I
.getDebugLoc());
1399 newFuncRoot
->getInstList().push_back(BranchI
);
1401 findAllocas(SinkingCands
, HoistingCands
, CommonExit
);
1402 assert(HoistingCands
.empty() || CommonExit
);
1404 // Find inputs to, outputs from the code region.
1405 findInputsOutputs(inputs
, outputs
, SinkingCands
);
1407 // Now sink all instructions which only have non-phi uses inside the region
1408 for (auto *II
: SinkingCands
)
1409 cast
<Instruction
>(II
)->moveBefore(*newFuncRoot
,
1410 newFuncRoot
->getFirstInsertionPt());
1412 if (!HoistingCands
.empty()) {
1413 auto *HoistToBlock
= findOrCreateBlockForHoisting(CommonExit
);
1414 Instruction
*TI
= HoistToBlock
->getTerminator();
1415 for (auto *II
: HoistingCands
)
1416 cast
<Instruction
>(II
)->moveBefore(TI
);
1419 // Collect objects which are inputs to the extraction region and also
1420 // referenced by lifetime start markers within it. The effects of these
1421 // markers must be replicated in the calling function to prevent the stack
1422 // coloring pass from merging slots which store input objects.
1423 ValueSet LifetimesStart
;
1424 eraseLifetimeMarkersOnInputs(Blocks
, SinkingCands
, LifetimesStart
);
1426 // Construct new function based on inputs/outputs & add allocas for all defs.
1427 Function
*newFunction
=
1428 constructFunction(inputs
, outputs
, header
, newFuncRoot
, codeReplacer
,
1429 oldFunction
, oldFunction
->getParent());
1431 // Update the entry count of the function.
1433 auto Count
= BFI
->getProfileCountFromFreq(EntryFreq
.getFrequency());
1434 if (Count
.hasValue())
1435 newFunction
->setEntryCount(
1436 ProfileCount(Count
.getValue(), Function::PCT_Real
)); // FIXME
1437 BFI
->setBlockFreq(codeReplacer
, EntryFreq
.getFrequency());
1441 emitCallAndSwitchStatement(newFunction
, codeReplacer
, inputs
, outputs
);
1443 moveCodeToFunction(newFunction
);
1445 // Replicate the effects of any lifetime start/end markers which referenced
1446 // input objects in the extraction region by placing markers around the call.
1447 insertLifetimeMarkersSurroundingCall(
1448 oldFunction
->getParent(), LifetimesStart
.getArrayRef(), {}, TheCall
);
1450 // Propagate personality info to the new function if there is one.
1451 if (oldFunction
->hasPersonalityFn())
1452 newFunction
->setPersonalityFn(oldFunction
->getPersonalityFn());
1454 // Update the branch weights for the exit block.
1455 if (BFI
&& NumExitBlocks
> 1)
1456 calculateNewCallTerminatorWeights(codeReplacer
, ExitWeights
, BPI
);
1458 // Loop over all of the PHI nodes in the header and exit blocks, and change
1459 // any references to the old incoming edge to be the new incoming edge.
1460 for (BasicBlock::iterator I
= header
->begin(); isa
<PHINode
>(I
); ++I
) {
1461 PHINode
*PN
= cast
<PHINode
>(I
);
1462 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
1463 if (!Blocks
.count(PN
->getIncomingBlock(i
)))
1464 PN
->setIncomingBlock(i
, newFuncRoot
);
1467 for (BasicBlock
*ExitBB
: ExitBlocks
)
1468 for (PHINode
&PN
: ExitBB
->phis()) {
1469 Value
*IncomingCodeReplacerVal
= nullptr;
1470 for (unsigned i
= 0, e
= PN
.getNumIncomingValues(); i
!= e
; ++i
) {
1471 // Ignore incoming values from outside of the extracted region.
1472 if (!Blocks
.count(PN
.getIncomingBlock(i
)))
1475 // Ensure that there is only one incoming value from codeReplacer.
1476 if (!IncomingCodeReplacerVal
) {
1477 PN
.setIncomingBlock(i
, codeReplacer
);
1478 IncomingCodeReplacerVal
= PN
.getIncomingValue(i
);
1480 assert(IncomingCodeReplacerVal
== PN
.getIncomingValue(i
) &&
1481 "PHI has two incompatbile incoming values from codeRepl");
1485 // Erase debug info intrinsics. Variable updates within the new function are
1486 // invisible to debuggers. This could be improved by defining a DISubprogram
1487 // for the new function.
1488 for (BasicBlock
&BB
: *newFunction
) {
1489 auto BlockIt
= BB
.begin();
1490 // Remove debug info intrinsics from the new function.
1491 while (BlockIt
!= BB
.end()) {
1492 Instruction
*Inst
= &*BlockIt
;
1494 if (isa
<DbgInfoIntrinsic
>(Inst
))
1495 Inst
->eraseFromParent();
1497 // Remove debug info intrinsics which refer to values in the new function
1498 // from the old function.
1499 SmallVector
<DbgVariableIntrinsic
*, 4> DbgUsers
;
1500 for (Instruction
&I
: BB
)
1501 findDbgUsers(DbgUsers
, &I
);
1502 for (DbgVariableIntrinsic
*DVI
: DbgUsers
)
1503 DVI
->eraseFromParent();
1506 // Mark the new function `noreturn` if applicable. Terminators which resume
1507 // exception propagation are treated as returning instructions. This is to
1508 // avoid inserting traps after calls to outlined functions which unwind.
1509 bool doesNotReturn
= none_of(*newFunction
, [](const BasicBlock
&BB
) {
1510 const Instruction
*Term
= BB
.getTerminator();
1511 return isa
<ReturnInst
>(Term
) || isa
<ResumeInst
>(Term
);
1514 newFunction
->setDoesNotReturn();
1516 LLVM_DEBUG(if (verifyFunction(*newFunction
, &errs())) {
1517 newFunction
->dump();
1518 report_fatal_error("verification of newFunction failed!");
1520 LLVM_DEBUG(if (verifyFunction(*oldFunction
))
1521 report_fatal_error("verification of oldFunction failed!"));