1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the interface to tear out a code region, such as an
10 // individual loop or a parallel section, into a new function, replacing it with
11 // a call to the new function.
13 //===----------------------------------------------------------------------===//
15 #include "llvm/Transforms/Utils/CodeExtractor.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Analysis/AssumptionCache.h"
24 #include "llvm/Analysis/BlockFrequencyInfo.h"
25 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
26 #include "llvm/Analysis/BranchProbabilityInfo.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/IR/Argument.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/CFG.h"
32 #include "llvm/IR/Constant.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DataLayout.h"
35 #include "llvm/IR/DerivedTypes.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/GlobalValue.h"
39 #include "llvm/IR/InstrTypes.h"
40 #include "llvm/IR/Instruction.h"
41 #include "llvm/IR/Instructions.h"
42 #include "llvm/IR/IntrinsicInst.h"
43 #include "llvm/IR/Intrinsics.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/IR/MDBuilder.h"
46 #include "llvm/IR/Module.h"
47 #include "llvm/IR/PatternMatch.h"
48 #include "llvm/IR/Type.h"
49 #include "llvm/IR/User.h"
50 #include "llvm/IR/Value.h"
51 #include "llvm/IR/Verifier.h"
52 #include "llvm/Pass.h"
53 #include "llvm/Support/BlockFrequency.h"
54 #include "llvm/Support/BranchProbability.h"
55 #include "llvm/Support/Casting.h"
56 #include "llvm/Support/CommandLine.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Support/ErrorHandling.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
61 #include "llvm/Transforms/Utils/Local.h"
71 using namespace llvm::PatternMatch
;
72 using ProfileCount
= Function::ProfileCount
;
74 #define DEBUG_TYPE "code-extractor"
76 // Provide a command-line option to aggregate function arguments into a struct
77 // for functions produced by the code extractor. This is useful when converting
78 // extracted functions to pthread-based code, as only one argument (void*) can
79 // be passed in to pthread_create().
81 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden
,
82 cl::desc("Aggregate arguments to code-extracted functions"));
84 /// Test whether a block is valid for extraction.
85 static bool isBlockValidForExtraction(const BasicBlock
&BB
,
86 const SetVector
<BasicBlock
*> &Result
,
87 bool AllowVarArgs
, bool AllowAlloca
) {
88 // taking the address of a basic block moved to another function is illegal
89 if (BB
.hasAddressTaken())
92 // don't hoist code that uses another basicblock address, as it's likely to
93 // lead to unexpected behavior, like cross-function jumps
94 SmallPtrSet
<User
const *, 16> Visited
;
95 SmallVector
<User
const *, 16> ToVisit
;
97 for (Instruction
const &Inst
: BB
)
98 ToVisit
.push_back(&Inst
);
100 while (!ToVisit
.empty()) {
101 User
const *Curr
= ToVisit
.pop_back_val();
102 if (!Visited
.insert(Curr
).second
)
104 if (isa
<BlockAddress
const>(Curr
))
105 return false; // even a reference to self is likely to be not compatible
107 if (isa
<Instruction
>(Curr
) && cast
<Instruction
>(Curr
)->getParent() != &BB
)
110 for (auto const &U
: Curr
->operands()) {
111 if (auto *UU
= dyn_cast
<User
>(U
))
112 ToVisit
.push_back(UU
);
116 // If explicitly requested, allow vastart and alloca. For invoke instructions
117 // verify that extraction is valid.
118 for (BasicBlock::const_iterator I
= BB
.begin(), E
= BB
.end(); I
!= E
; ++I
) {
119 if (isa
<AllocaInst
>(I
)) {
125 if (const auto *II
= dyn_cast
<InvokeInst
>(I
)) {
126 // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
127 // must be a part of the subgraph which is being extracted.
128 if (auto *UBB
= II
->getUnwindDest())
129 if (!Result
.count(UBB
))
134 // All catch handlers of a catchswitch instruction as well as the unwind
135 // destination must be in the subgraph.
136 if (const auto *CSI
= dyn_cast
<CatchSwitchInst
>(I
)) {
137 if (auto *UBB
= CSI
->getUnwindDest())
138 if (!Result
.count(UBB
))
140 for (auto *HBB
: CSI
->handlers())
141 if (!Result
.count(const_cast<BasicBlock
*>(HBB
)))
146 // Make sure that entire catch handler is within subgraph. It is sufficient
147 // to check that catch return's block is in the list.
148 if (const auto *CPI
= dyn_cast
<CatchPadInst
>(I
)) {
149 for (const auto *U
: CPI
->users())
150 if (const auto *CRI
= dyn_cast
<CatchReturnInst
>(U
))
151 if (!Result
.count(const_cast<BasicBlock
*>(CRI
->getParent())))
156 // And do similar checks for cleanup handler - the entire handler must be
157 // in subgraph which is going to be extracted. For cleanup return should
158 // additionally check that the unwind destination is also in the subgraph.
159 if (const auto *CPI
= dyn_cast
<CleanupPadInst
>(I
)) {
160 for (const auto *U
: CPI
->users())
161 if (const auto *CRI
= dyn_cast
<CleanupReturnInst
>(U
))
162 if (!Result
.count(const_cast<BasicBlock
*>(CRI
->getParent())))
166 if (const auto *CRI
= dyn_cast
<CleanupReturnInst
>(I
)) {
167 if (auto *UBB
= CRI
->getUnwindDest())
168 if (!Result
.count(UBB
))
173 if (const CallInst
*CI
= dyn_cast
<CallInst
>(I
)) {
174 if (const Function
*F
= CI
->getCalledFunction()) {
175 auto IID
= F
->getIntrinsicID();
176 if (IID
== Intrinsic::vastart
) {
183 // Currently, we miscompile outlined copies of eh_typid_for. There are
184 // proposals for fixing this in llvm.org/PR39545.
185 if (IID
== Intrinsic::eh_typeid_for
)
194 /// Build a set of blocks to extract if the input blocks are viable.
195 static SetVector
<BasicBlock
*>
196 buildExtractionBlockSet(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
,
197 bool AllowVarArgs
, bool AllowAlloca
) {
198 assert(!BBs
.empty() && "The set of blocks to extract must be non-empty");
199 SetVector
<BasicBlock
*> Result
;
201 // Loop over the blocks, adding them to our set-vector, and aborting with an
202 // empty set if we encounter invalid blocks.
203 for (BasicBlock
*BB
: BBs
) {
204 // If this block is dead, don't process it.
205 if (DT
&& !DT
->isReachableFromEntry(BB
))
208 if (!Result
.insert(BB
))
209 llvm_unreachable("Repeated basic blocks in extraction input");
212 LLVM_DEBUG(dbgs() << "Region front block: " << Result
.front()->getName()
215 for (auto *BB
: Result
) {
216 if (!isBlockValidForExtraction(*BB
, Result
, AllowVarArgs
, AllowAlloca
))
219 // Make sure that the first block is not a landing pad.
220 if (BB
== Result
.front()) {
222 LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
228 // All blocks other than the first must not have predecessors outside of
229 // the subgraph which is being extracted.
230 for (auto *PBB
: predecessors(BB
))
231 if (!Result
.count(PBB
)) {
232 LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
233 "outside the region except for the first block!\n"
234 << "Problematic source BB: " << BB
->getName() << "\n"
235 << "Problematic destination BB: " << PBB
->getName()
244 CodeExtractor::CodeExtractor(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
,
245 bool AggregateArgs
, BlockFrequencyInfo
*BFI
,
246 BranchProbabilityInfo
*BPI
, AssumptionCache
*AC
,
247 bool AllowVarArgs
, bool AllowAlloca
,
249 : DT(DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
250 BPI(BPI
), AC(AC
), AllowVarArgs(AllowVarArgs
),
251 Blocks(buildExtractionBlockSet(BBs
, DT
, AllowVarArgs
, AllowAlloca
)),
254 CodeExtractor::CodeExtractor(DominatorTree
&DT
, Loop
&L
, bool AggregateArgs
,
255 BlockFrequencyInfo
*BFI
,
256 BranchProbabilityInfo
*BPI
, AssumptionCache
*AC
,
258 : DT(&DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
259 BPI(BPI
), AC(AC
), AllowVarArgs(false),
260 Blocks(buildExtractionBlockSet(L
.getBlocks(), &DT
,
261 /* AllowVarArgs */ false,
262 /* AllowAlloca */ false)),
265 /// definedInRegion - Return true if the specified value is defined in the
266 /// extracted region.
267 static bool definedInRegion(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
268 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
269 if (Blocks
.count(I
->getParent()))
274 /// definedInCaller - Return true if the specified value is defined in the
275 /// function being code extracted, but not in the region being extracted.
276 /// These values must be passed in as live-ins to the function.
277 static bool definedInCaller(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
278 if (isa
<Argument
>(V
)) return true;
279 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
280 if (!Blocks
.count(I
->getParent()))
285 static BasicBlock
*getCommonExitBlock(const SetVector
<BasicBlock
*> &Blocks
) {
286 BasicBlock
*CommonExitBlock
= nullptr;
287 auto hasNonCommonExitSucc
= [&](BasicBlock
*Block
) {
288 for (auto *Succ
: successors(Block
)) {
289 // Internal edges, ok.
290 if (Blocks
.count(Succ
))
292 if (!CommonExitBlock
) {
293 CommonExitBlock
= Succ
;
296 if (CommonExitBlock
== Succ
)
304 if (any_of(Blocks
, hasNonCommonExitSucc
))
307 return CommonExitBlock
;
310 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
311 Instruction
*Addr
) const {
312 AllocaInst
*AI
= cast
<AllocaInst
>(Addr
->stripInBoundsConstantOffsets());
313 Function
*Func
= (*Blocks
.begin())->getParent();
314 for (BasicBlock
&BB
: *Func
) {
315 if (Blocks
.count(&BB
))
317 for (Instruction
&II
: BB
) {
318 if (isa
<DbgInfoIntrinsic
>(II
))
321 unsigned Opcode
= II
.getOpcode();
322 Value
*MemAddr
= nullptr;
324 case Instruction::Store
:
325 case Instruction::Load
: {
326 if (Opcode
== Instruction::Store
) {
327 StoreInst
*SI
= cast
<StoreInst
>(&II
);
328 MemAddr
= SI
->getPointerOperand();
330 LoadInst
*LI
= cast
<LoadInst
>(&II
);
331 MemAddr
= LI
->getPointerOperand();
333 // Global variable can not be aliased with locals.
334 if (dyn_cast
<Constant
>(MemAddr
))
336 Value
*Base
= MemAddr
->stripInBoundsConstantOffsets();
337 if (!isa
<AllocaInst
>(Base
) || Base
== AI
)
342 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(&II
);
344 if (IntrInst
->isLifetimeStartOrEnd())
348 // Treat all the other cases conservatively if it has side effects.
349 if (II
.mayHaveSideEffects())
360 CodeExtractor::findOrCreateBlockForHoisting(BasicBlock
*CommonExitBlock
) {
361 BasicBlock
*SinglePredFromOutlineRegion
= nullptr;
362 assert(!Blocks
.count(CommonExitBlock
) &&
363 "Expect a block outside the region!");
364 for (auto *Pred
: predecessors(CommonExitBlock
)) {
365 if (!Blocks
.count(Pred
))
367 if (!SinglePredFromOutlineRegion
) {
368 SinglePredFromOutlineRegion
= Pred
;
369 } else if (SinglePredFromOutlineRegion
!= Pred
) {
370 SinglePredFromOutlineRegion
= nullptr;
375 if (SinglePredFromOutlineRegion
)
376 return SinglePredFromOutlineRegion
;
379 auto getFirstPHI
= [](BasicBlock
*BB
) {
380 BasicBlock::iterator I
= BB
->begin();
381 PHINode
*FirstPhi
= nullptr;
382 while (I
!= BB
->end()) {
383 PHINode
*Phi
= dyn_cast
<PHINode
>(I
);
393 // If there are any phi nodes, the single pred either exists or has already
394 // be created before code extraction.
395 assert(!getFirstPHI(CommonExitBlock
) && "Phi not expected");
398 BasicBlock
*NewExitBlock
= CommonExitBlock
->splitBasicBlock(
399 CommonExitBlock
->getFirstNonPHI()->getIterator());
401 for (auto PI
= pred_begin(CommonExitBlock
), PE
= pred_end(CommonExitBlock
);
403 BasicBlock
*Pred
= *PI
++;
404 if (Blocks
.count(Pred
))
406 Pred
->getTerminator()->replaceUsesOfWith(CommonExitBlock
, NewExitBlock
);
408 // Now add the old exit block to the outline region.
409 Blocks
.insert(CommonExitBlock
);
410 return CommonExitBlock
;
413 // Find the pair of life time markers for address 'Addr' that are either
414 // defined inside the outline region or can legally be shrinkwrapped into the
415 // outline region. If there are not other untracked uses of the address, return
416 // the pair of markers if found; otherwise return a pair of nullptr.
417 CodeExtractor::LifetimeMarkerInfo
418 CodeExtractor::getLifetimeMarkers(Instruction
*Addr
,
419 BasicBlock
*ExitBlock
) const {
420 LifetimeMarkerInfo Info
;
422 for (User
*U
: Addr
->users()) {
423 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(U
);
425 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_start
) {
426 // Do not handle the case where Addr has multiple start markers.
429 Info
.LifeStart
= IntrInst
;
431 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_end
) {
434 Info
.LifeEnd
= IntrInst
;
438 // Find untracked uses of the address, bail.
439 if (!definedInRegion(Blocks
, U
))
443 if (!Info
.LifeStart
|| !Info
.LifeEnd
)
446 Info
.SinkLifeStart
= !definedInRegion(Blocks
, Info
.LifeStart
);
447 Info
.HoistLifeEnd
= !definedInRegion(Blocks
, Info
.LifeEnd
);
448 // Do legality check.
449 if ((Info
.SinkLifeStart
|| Info
.HoistLifeEnd
) &&
450 !isLegalToShrinkwrapLifetimeMarkers(Addr
))
453 // Check to see if we have a place to do hoisting, if not, bail.
454 if (Info
.HoistLifeEnd
&& !ExitBlock
)
460 void CodeExtractor::findAllocas(ValueSet
&SinkCands
, ValueSet
&HoistCands
,
461 BasicBlock
*&ExitBlock
) const {
462 Function
*Func
= (*Blocks
.begin())->getParent();
463 ExitBlock
= getCommonExitBlock(Blocks
);
465 auto moveOrIgnoreLifetimeMarkers
=
466 [&](const LifetimeMarkerInfo
&LMI
) -> bool {
469 if (LMI
.SinkLifeStart
) {
470 LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI
.LifeStart
472 SinkCands
.insert(LMI
.LifeStart
);
474 if (LMI
.HoistLifeEnd
) {
475 LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI
.LifeEnd
<< "\n");
476 HoistCands
.insert(LMI
.LifeEnd
);
481 for (BasicBlock
&BB
: *Func
) {
482 if (Blocks
.count(&BB
))
484 for (Instruction
&II
: BB
) {
485 auto *AI
= dyn_cast
<AllocaInst
>(&II
);
489 LifetimeMarkerInfo MarkerInfo
= getLifetimeMarkers(AI
, ExitBlock
);
490 bool Moved
= moveOrIgnoreLifetimeMarkers(MarkerInfo
);
492 LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI
<< "\n");
493 SinkCands
.insert(AI
);
497 // Follow any bitcasts.
498 SmallVector
<Instruction
*, 2> Bitcasts
;
499 SmallVector
<LifetimeMarkerInfo
, 2> BitcastLifetimeInfo
;
500 for (User
*U
: AI
->users()) {
501 if (U
->stripInBoundsConstantOffsets() == AI
) {
502 Instruction
*Bitcast
= cast
<Instruction
>(U
);
503 LifetimeMarkerInfo LMI
= getLifetimeMarkers(Bitcast
, ExitBlock
);
505 Bitcasts
.push_back(Bitcast
);
506 BitcastLifetimeInfo
.push_back(LMI
);
511 // Found unknown use of AI.
512 if (!definedInRegion(Blocks
, U
)) {
518 // Either no bitcasts reference the alloca or there are unknown uses.
519 if (Bitcasts
.empty())
522 LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI
<< "\n");
523 SinkCands
.insert(AI
);
524 for (unsigned I
= 0, E
= Bitcasts
.size(); I
!= E
; ++I
) {
525 Instruction
*BitcastAddr
= Bitcasts
[I
];
526 const LifetimeMarkerInfo
&LMI
= BitcastLifetimeInfo
[I
];
527 assert(LMI
.LifeStart
&&
528 "Unsafe to sink bitcast without lifetime markers");
529 moveOrIgnoreLifetimeMarkers(LMI
);
530 if (!definedInRegion(Blocks
, BitcastAddr
)) {
531 LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
533 SinkCands
.insert(BitcastAddr
);
540 void CodeExtractor::findInputsOutputs(ValueSet
&Inputs
, ValueSet
&Outputs
,
541 const ValueSet
&SinkCands
) const {
542 for (BasicBlock
*BB
: Blocks
) {
543 // If a used value is defined outside the region, it's an input. If an
544 // instruction is used outside the region, it's an output.
545 for (Instruction
&II
: *BB
) {
546 for (User::op_iterator OI
= II
.op_begin(), OE
= II
.op_end(); OI
!= OE
;
549 if (!SinkCands
.count(V
) && definedInCaller(Blocks
, V
))
553 for (User
*U
: II
.users())
554 if (!definedInRegion(Blocks
, U
)) {
562 /// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
563 /// of the region, we need to split the entry block of the region so that the
564 /// PHI node is easier to deal with.
565 void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock
*&Header
) {
566 unsigned NumPredsFromRegion
= 0;
567 unsigned NumPredsOutsideRegion
= 0;
569 if (Header
!= &Header
->getParent()->getEntryBlock()) {
570 PHINode
*PN
= dyn_cast
<PHINode
>(Header
->begin());
571 if (!PN
) return; // No PHI nodes.
573 // If the header node contains any PHI nodes, check to see if there is more
574 // than one entry from outside the region. If so, we need to sever the
575 // header block into two.
576 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
577 if (Blocks
.count(PN
->getIncomingBlock(i
)))
578 ++NumPredsFromRegion
;
580 ++NumPredsOutsideRegion
;
582 // If there is one (or fewer) predecessor from outside the region, we don't
583 // need to do anything special.
584 if (NumPredsOutsideRegion
<= 1) return;
587 // Otherwise, we need to split the header block into two pieces: one
588 // containing PHI nodes merging values from outside of the region, and a
589 // second that contains all of the code for the block and merges back any
590 // incoming values from inside of the region.
591 BasicBlock
*NewBB
= SplitBlock(Header
, Header
->getFirstNonPHI(), DT
);
593 // We only want to code extract the second block now, and it becomes the new
594 // header of the region.
595 BasicBlock
*OldPred
= Header
;
596 Blocks
.remove(OldPred
);
597 Blocks
.insert(NewBB
);
600 // Okay, now we need to adjust the PHI nodes and any branches from within the
601 // region to go to the new header block instead of the old header block.
602 if (NumPredsFromRegion
) {
603 PHINode
*PN
= cast
<PHINode
>(OldPred
->begin());
604 // Loop over all of the predecessors of OldPred that are in the region,
605 // changing them to branch to NewBB instead.
606 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
607 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
608 Instruction
*TI
= PN
->getIncomingBlock(i
)->getTerminator();
609 TI
->replaceUsesOfWith(OldPred
, NewBB
);
612 // Okay, everything within the region is now branching to the right block, we
613 // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
614 BasicBlock::iterator AfterPHIs
;
615 for (AfterPHIs
= OldPred
->begin(); isa
<PHINode
>(AfterPHIs
); ++AfterPHIs
) {
616 PHINode
*PN
= cast
<PHINode
>(AfterPHIs
);
617 // Create a new PHI node in the new region, which has an incoming value
618 // from OldPred of PN.
619 PHINode
*NewPN
= PHINode::Create(PN
->getType(), 1 + NumPredsFromRegion
,
620 PN
->getName() + ".ce", &NewBB
->front());
621 PN
->replaceAllUsesWith(NewPN
);
622 NewPN
->addIncoming(PN
, OldPred
);
624 // Loop over all of the incoming value in PN, moving them to NewPN if they
625 // are from the extracted region.
626 for (unsigned i
= 0; i
!= PN
->getNumIncomingValues(); ++i
) {
627 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
628 NewPN
->addIncoming(PN
->getIncomingValue(i
), PN
->getIncomingBlock(i
));
629 PN
->removeIncomingValue(i
);
637 /// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
638 /// outlined region, we split these PHIs on two: one with inputs from region
639 /// and other with remaining incoming blocks; then first PHIs are placed in
641 void CodeExtractor::severSplitPHINodesOfExits(
642 const SmallPtrSetImpl
<BasicBlock
*> &Exits
) {
643 for (BasicBlock
*ExitBB
: Exits
) {
644 BasicBlock
*NewBB
= nullptr;
646 for (PHINode
&PN
: ExitBB
->phis()) {
647 // Find all incoming values from the outlining region.
648 SmallVector
<unsigned, 2> IncomingVals
;
649 for (unsigned i
= 0; i
< PN
.getNumIncomingValues(); ++i
)
650 if (Blocks
.count(PN
.getIncomingBlock(i
)))
651 IncomingVals
.push_back(i
);
653 // Do not process PHI if there is one (or fewer) predecessor from region.
654 // If PHI has exactly one predecessor from region, only this one incoming
655 // will be replaced on codeRepl block, so it should be safe to skip PHI.
656 if (IncomingVals
.size() <= 1)
659 // Create block for new PHIs and add it to the list of outlined if it
660 // wasn't done before.
662 NewBB
= BasicBlock::Create(ExitBB
->getContext(),
663 ExitBB
->getName() + ".split",
664 ExitBB
->getParent(), ExitBB
);
665 SmallVector
<BasicBlock
*, 4> Preds(pred_begin(ExitBB
),
667 for (BasicBlock
*PredBB
: Preds
)
668 if (Blocks
.count(PredBB
))
669 PredBB
->getTerminator()->replaceUsesOfWith(ExitBB
, NewBB
);
670 BranchInst::Create(ExitBB
, NewBB
);
671 Blocks
.insert(NewBB
);
676 PHINode::Create(PN
.getType(), IncomingVals
.size(),
677 PN
.getName() + ".ce", NewBB
->getFirstNonPHI());
678 for (unsigned i
: IncomingVals
)
679 NewPN
->addIncoming(PN
.getIncomingValue(i
), PN
.getIncomingBlock(i
));
680 for (unsigned i
: reverse(IncomingVals
))
681 PN
.removeIncomingValue(i
, false);
682 PN
.addIncoming(NewPN
, NewBB
);
687 void CodeExtractor::splitReturnBlocks() {
688 for (BasicBlock
*Block
: Blocks
)
689 if (ReturnInst
*RI
= dyn_cast
<ReturnInst
>(Block
->getTerminator())) {
691 Block
->splitBasicBlock(RI
->getIterator(), Block
->getName() + ".ret");
693 // Old dominates New. New node dominates all other nodes dominated
695 DomTreeNode
*OldNode
= DT
->getNode(Block
);
696 SmallVector
<DomTreeNode
*, 8> Children(OldNode
->begin(),
699 DomTreeNode
*NewNode
= DT
->addNewBlock(New
, Block
);
701 for (DomTreeNode
*I
: Children
)
702 DT
->changeImmediateDominator(I
, NewNode
);
707 /// constructFunction - make a function based on inputs and outputs, as follows:
708 /// f(in0, ..., inN, out0, ..., outN)
709 Function
*CodeExtractor::constructFunction(const ValueSet
&inputs
,
710 const ValueSet
&outputs
,
712 BasicBlock
*newRootNode
,
713 BasicBlock
*newHeader
,
714 Function
*oldFunction
,
716 LLVM_DEBUG(dbgs() << "inputs: " << inputs
.size() << "\n");
717 LLVM_DEBUG(dbgs() << "outputs: " << outputs
.size() << "\n");
719 // This function returns unsigned, outputs will go back by reference.
720 switch (NumExitBlocks
) {
722 case 1: RetTy
= Type::getVoidTy(header
->getContext()); break;
723 case 2: RetTy
= Type::getInt1Ty(header
->getContext()); break;
724 default: RetTy
= Type::getInt16Ty(header
->getContext()); break;
727 std::vector
<Type
*> paramTy
;
729 // Add the types of the input values to the function's argument list
730 for (Value
*value
: inputs
) {
731 LLVM_DEBUG(dbgs() << "value used in func: " << *value
<< "\n");
732 paramTy
.push_back(value
->getType());
735 // Add the types of the output values to the function's argument list.
736 for (Value
*output
: outputs
) {
737 LLVM_DEBUG(dbgs() << "instr used in func: " << *output
<< "\n");
739 paramTy
.push_back(output
->getType());
741 paramTy
.push_back(PointerType::getUnqual(output
->getType()));
745 dbgs() << "Function type: " << *RetTy
<< " f(";
746 for (Type
*i
: paramTy
)
747 dbgs() << *i
<< ", ";
751 StructType
*StructTy
;
752 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
753 StructTy
= StructType::get(M
->getContext(), paramTy
);
755 paramTy
.push_back(PointerType::getUnqual(StructTy
));
757 FunctionType
*funcType
=
758 FunctionType::get(RetTy
, paramTy
,
759 AllowVarArgs
&& oldFunction
->isVarArg());
761 std::string SuffixToUse
=
763 ? (header
->getName().empty() ? "extracted" : header
->getName().str())
765 // Create the new function
766 Function
*newFunction
= Function::Create(
767 funcType
, GlobalValue::InternalLinkage
, oldFunction
->getAddressSpace(),
768 oldFunction
->getName() + "." + SuffixToUse
, M
);
769 // If the old function is no-throw, so is the new one.
770 if (oldFunction
->doesNotThrow())
771 newFunction
->setDoesNotThrow();
773 // Inherit the uwtable attribute if we need to.
774 if (oldFunction
->hasUWTable())
775 newFunction
->setHasUWTable();
777 // Inherit all of the target dependent attributes and white-listed
778 // target independent attributes.
779 // (e.g. If the extracted region contains a call to an x86.sse
780 // instruction we need to make sure that the extracted region has the
781 // "target-features" attribute allowing it to be lowered.
782 // FIXME: This should be changed to check to see if a specific
783 // attribute can not be inherited.
784 for (const auto &Attr
: oldFunction
->getAttributes().getFnAttributes()) {
785 if (Attr
.isStringAttribute()) {
786 if (Attr
.getKindAsString() == "thunk")
789 switch (Attr
.getKindAsEnum()) {
790 // Those attributes cannot be propagated safely. Explicitly list them
791 // here so we get a warning if new attributes are added. This list also
792 // includes non-function attributes.
793 case Attribute::Alignment
:
794 case Attribute::AllocSize
:
795 case Attribute::ArgMemOnly
:
796 case Attribute::Builtin
:
797 case Attribute::ByVal
:
798 case Attribute::Convergent
:
799 case Attribute::Dereferenceable
:
800 case Attribute::DereferenceableOrNull
:
801 case Attribute::InAlloca
:
802 case Attribute::InReg
:
803 case Attribute::InaccessibleMemOnly
:
804 case Attribute::InaccessibleMemOrArgMemOnly
:
805 case Attribute::JumpTable
:
806 case Attribute::Naked
:
807 case Attribute::Nest
:
808 case Attribute::NoAlias
:
809 case Attribute::NoBuiltin
:
810 case Attribute::NoCapture
:
811 case Attribute::NoReturn
:
812 case Attribute::NoSync
:
813 case Attribute::None
:
814 case Attribute::NonNull
:
815 case Attribute::ReadNone
:
816 case Attribute::ReadOnly
:
817 case Attribute::Returned
:
818 case Attribute::ReturnsTwice
:
819 case Attribute::SExt
:
820 case Attribute::Speculatable
:
821 case Attribute::StackAlignment
:
822 case Attribute::StructRet
:
823 case Attribute::SwiftError
:
824 case Attribute::SwiftSelf
:
825 case Attribute::WillReturn
:
826 case Attribute::WriteOnly
:
827 case Attribute::ZExt
:
828 case Attribute::ImmArg
:
829 case Attribute::EndAttrKinds
:
831 // Those attributes should be safe to propagate to the extracted function.
832 case Attribute::AlwaysInline
:
833 case Attribute::Cold
:
834 case Attribute::NoRecurse
:
835 case Attribute::InlineHint
:
836 case Attribute::MinSize
:
837 case Attribute::NoDuplicate
:
838 case Attribute::NoFree
:
839 case Attribute::NoImplicitFloat
:
840 case Attribute::NoInline
:
841 case Attribute::NonLazyBind
:
842 case Attribute::NoRedZone
:
843 case Attribute::NoUnwind
:
844 case Attribute::OptForFuzzing
:
845 case Attribute::OptimizeNone
:
846 case Attribute::OptimizeForSize
:
847 case Attribute::SafeStack
:
848 case Attribute::ShadowCallStack
:
849 case Attribute::SanitizeAddress
:
850 case Attribute::SanitizeMemory
:
851 case Attribute::SanitizeThread
:
852 case Attribute::SanitizeHWAddress
:
853 case Attribute::SanitizeMemTag
:
854 case Attribute::SpeculativeLoadHardening
:
855 case Attribute::StackProtect
:
856 case Attribute::StackProtectReq
:
857 case Attribute::StackProtectStrong
:
858 case Attribute::StrictFP
:
859 case Attribute::UWTable
:
860 case Attribute::NoCfCheck
:
864 newFunction
->addFnAttr(Attr
);
866 newFunction
->getBasicBlockList().push_back(newRootNode
);
868 // Create an iterator to name all of the arguments we inserted.
869 Function::arg_iterator AI
= newFunction
->arg_begin();
871 // Rewrite all users of the inputs in the extracted region to use the
872 // arguments (or appropriate addressing into struct) instead.
873 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
877 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(header
->getContext()));
878 Idx
[1] = ConstantInt::get(Type::getInt32Ty(header
->getContext()), i
);
879 Instruction
*TI
= newFunction
->begin()->getTerminator();
880 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
881 StructTy
, &*AI
, Idx
, "gep_" + inputs
[i
]->getName(), TI
);
882 RewriteVal
= new LoadInst(StructTy
->getElementType(i
), GEP
,
883 "loadgep_" + inputs
[i
]->getName(), TI
);
887 std::vector
<User
*> Users(inputs
[i
]->user_begin(), inputs
[i
]->user_end());
888 for (User
*use
: Users
)
889 if (Instruction
*inst
= dyn_cast
<Instruction
>(use
))
890 if (Blocks
.count(inst
->getParent()))
891 inst
->replaceUsesOfWith(inputs
[i
], RewriteVal
);
894 // Set names for input and output arguments.
895 if (!AggregateArgs
) {
896 AI
= newFunction
->arg_begin();
897 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
, ++AI
)
898 AI
->setName(inputs
[i
]->getName());
899 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
, ++AI
)
900 AI
->setName(outputs
[i
]->getName()+".out");
903 // Rewrite branches to basic blocks outside of the loop to new dummy blocks
904 // within the new function. This must be done before we lose track of which
905 // blocks were originally in the code region.
906 std::vector
<User
*> Users(header
->user_begin(), header
->user_end());
907 for (unsigned i
= 0, e
= Users
.size(); i
!= e
; ++i
)
908 // The BasicBlock which contains the branch is not in the region
909 // modify the branch target to a new block
910 if (Instruction
*I
= dyn_cast
<Instruction
>(Users
[i
]))
911 if (I
->isTerminator() && !Blocks
.count(I
->getParent()) &&
912 I
->getParent()->getParent() == oldFunction
)
913 I
->replaceUsesOfWith(header
, newHeader
);
918 /// Erase lifetime.start markers which reference inputs to the extraction
919 /// region, and insert the referenced memory into \p LifetimesStart.
921 /// The extraction region is defined by a set of blocks (\p Blocks), and a set
922 /// of allocas which will be moved from the caller function into the extracted
923 /// function (\p SunkAllocas).
924 static void eraseLifetimeMarkersOnInputs(const SetVector
<BasicBlock
*> &Blocks
,
925 const SetVector
<Value
*> &SunkAllocas
,
926 SetVector
<Value
*> &LifetimesStart
) {
927 for (BasicBlock
*BB
: Blocks
) {
928 for (auto It
= BB
->begin(), End
= BB
->end(); It
!= End
;) {
929 auto *II
= dyn_cast
<IntrinsicInst
>(&*It
);
931 if (!II
|| !II
->isLifetimeStartOrEnd())
934 // Get the memory operand of the lifetime marker. If the underlying
935 // object is a sunk alloca, or is otherwise defined in the extraction
936 // region, the lifetime marker must not be erased.
937 Value
*Mem
= II
->getOperand(1)->stripInBoundsOffsets();
938 if (SunkAllocas
.count(Mem
) || definedInRegion(Blocks
, Mem
))
941 if (II
->getIntrinsicID() == Intrinsic::lifetime_start
)
942 LifetimesStart
.insert(Mem
);
943 II
->eraseFromParent();
948 /// Insert lifetime start/end markers surrounding the call to the new function
949 /// for objects defined in the caller.
950 static void insertLifetimeMarkersSurroundingCall(
951 Module
*M
, ArrayRef
<Value
*> LifetimesStart
, ArrayRef
<Value
*> LifetimesEnd
,
953 LLVMContext
&Ctx
= M
->getContext();
954 auto Int8PtrTy
= Type::getInt8PtrTy(Ctx
);
955 auto NegativeOne
= ConstantInt::getSigned(Type::getInt64Ty(Ctx
), -1);
956 Instruction
*Term
= TheCall
->getParent()->getTerminator();
958 // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts
959 // needed to satisfy this requirement so they may be reused.
960 DenseMap
<Value
*, Value
*> Bitcasts
;
962 // Emit lifetime markers for the pointers given in \p Objects. Insert the
963 // markers before the call if \p InsertBefore, and after the call otherwise.
964 auto insertMarkers
= [&](Function
*MarkerFunc
, ArrayRef
<Value
*> Objects
,
966 for (Value
*Mem
: Objects
) {
967 assert((!isa
<Instruction
>(Mem
) || cast
<Instruction
>(Mem
)->getFunction() ==
968 TheCall
->getFunction()) &&
969 "Input memory not defined in original function");
970 Value
*&MemAsI8Ptr
= Bitcasts
[Mem
];
972 if (Mem
->getType() == Int8PtrTy
)
976 CastInst::CreatePointerCast(Mem
, Int8PtrTy
, "lt.cast", TheCall
);
979 auto Marker
= CallInst::Create(MarkerFunc
, {NegativeOne
, MemAsI8Ptr
});
981 Marker
->insertBefore(TheCall
);
983 Marker
->insertBefore(Term
);
987 if (!LifetimesStart
.empty()) {
988 auto StartFn
= llvm::Intrinsic::getDeclaration(
989 M
, llvm::Intrinsic::lifetime_start
, Int8PtrTy
);
990 insertMarkers(StartFn
, LifetimesStart
, /*InsertBefore=*/true);
993 if (!LifetimesEnd
.empty()) {
994 auto EndFn
= llvm::Intrinsic::getDeclaration(
995 M
, llvm::Intrinsic::lifetime_end
, Int8PtrTy
);
996 insertMarkers(EndFn
, LifetimesEnd
, /*InsertBefore=*/false);
1000 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
1001 /// the call instruction, splitting any PHI nodes in the header block as
1003 CallInst
*CodeExtractor::emitCallAndSwitchStatement(Function
*newFunction
,
1004 BasicBlock
*codeReplacer
,
1006 ValueSet
&outputs
) {
1007 // Emit a call to the new function, passing in: *pointer to struct (if
1008 // aggregating parameters), or plan inputs and allocated memory for outputs
1009 std::vector
<Value
*> params
, StructValues
, ReloadOutputs
, Reloads
;
1011 Module
*M
= newFunction
->getParent();
1012 LLVMContext
&Context
= M
->getContext();
1013 const DataLayout
&DL
= M
->getDataLayout();
1014 CallInst
*call
= nullptr;
1016 // Add inputs as params, or to be filled into the struct
1018 SmallVector
<unsigned, 1> SwiftErrorArgs
;
1019 for (Value
*input
: inputs
) {
1021 StructValues
.push_back(input
);
1023 params
.push_back(input
);
1024 if (input
->isSwiftError())
1025 SwiftErrorArgs
.push_back(ArgNo
);
1030 // Create allocas for the outputs
1031 for (Value
*output
: outputs
) {
1032 if (AggregateArgs
) {
1033 StructValues
.push_back(output
);
1035 AllocaInst
*alloca
=
1036 new AllocaInst(output
->getType(), DL
.getAllocaAddrSpace(),
1037 nullptr, output
->getName() + ".loc",
1038 &codeReplacer
->getParent()->front().front());
1039 ReloadOutputs
.push_back(alloca
);
1040 params
.push_back(alloca
);
1044 StructType
*StructArgTy
= nullptr;
1045 AllocaInst
*Struct
= nullptr;
1046 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
1047 std::vector
<Type
*> ArgTypes
;
1048 for (ValueSet::iterator v
= StructValues
.begin(),
1049 ve
= StructValues
.end(); v
!= ve
; ++v
)
1050 ArgTypes
.push_back((*v
)->getType());
1052 // Allocate a struct at the beginning of this function
1053 StructArgTy
= StructType::get(newFunction
->getContext(), ArgTypes
);
1054 Struct
= new AllocaInst(StructArgTy
, DL
.getAllocaAddrSpace(), nullptr,
1056 &codeReplacer
->getParent()->front().front());
1057 params
.push_back(Struct
);
1059 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
1061 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1062 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), i
);
1063 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1064 StructArgTy
, Struct
, Idx
, "gep_" + StructValues
[i
]->getName());
1065 codeReplacer
->getInstList().push_back(GEP
);
1066 StoreInst
*SI
= new StoreInst(StructValues
[i
], GEP
);
1067 codeReplacer
->getInstList().push_back(SI
);
1071 // Emit the call to the function
1072 call
= CallInst::Create(newFunction
, params
,
1073 NumExitBlocks
> 1 ? "targetBlock" : "");
1074 // Add debug location to the new call, if the original function has debug
1075 // info. In that case, the terminator of the entry block of the extracted
1076 // function contains the first debug location of the extracted function,
1077 // set in extractCodeRegion.
1078 if (codeReplacer
->getParent()->getSubprogram()) {
1079 if (auto DL
= newFunction
->getEntryBlock().getTerminator()->getDebugLoc())
1080 call
->setDebugLoc(DL
);
1082 codeReplacer
->getInstList().push_back(call
);
1084 // Set swifterror parameter attributes.
1085 for (unsigned SwiftErrArgNo
: SwiftErrorArgs
) {
1086 call
->addParamAttr(SwiftErrArgNo
, Attribute::SwiftError
);
1087 newFunction
->addParamAttr(SwiftErrArgNo
, Attribute::SwiftError
);
1090 Function::arg_iterator OutputArgBegin
= newFunction
->arg_begin();
1091 unsigned FirstOut
= inputs
.size();
1093 std::advance(OutputArgBegin
, inputs
.size());
1095 // Reload the outputs passed in by reference.
1096 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
) {
1097 Value
*Output
= nullptr;
1098 if (AggregateArgs
) {
1100 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1101 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), FirstOut
+ i
);
1102 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1103 StructArgTy
, Struct
, Idx
, "gep_reload_" + outputs
[i
]->getName());
1104 codeReplacer
->getInstList().push_back(GEP
);
1107 Output
= ReloadOutputs
[i
];
1109 LoadInst
*load
= new LoadInst(outputs
[i
]->getType(), Output
,
1110 outputs
[i
]->getName() + ".reload");
1111 Reloads
.push_back(load
);
1112 codeReplacer
->getInstList().push_back(load
);
1113 std::vector
<User
*> Users(outputs
[i
]->user_begin(), outputs
[i
]->user_end());
1114 for (unsigned u
= 0, e
= Users
.size(); u
!= e
; ++u
) {
1115 Instruction
*inst
= cast
<Instruction
>(Users
[u
]);
1116 if (!Blocks
.count(inst
->getParent()))
1117 inst
->replaceUsesOfWith(outputs
[i
], load
);
1121 // Now we can emit a switch statement using the call as a value.
1122 SwitchInst
*TheSwitch
=
1123 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context
)),
1124 codeReplacer
, 0, codeReplacer
);
1126 // Since there may be multiple exits from the original region, make the new
1127 // function return an unsigned, switch on that number. This loop iterates
1128 // over all of the blocks in the extracted region, updating any terminator
1129 // instructions in the to-be-extracted region that branch to blocks that are
1130 // not in the region to be extracted.
1131 std::map
<BasicBlock
*, BasicBlock
*> ExitBlockMap
;
1133 unsigned switchVal
= 0;
1134 for (BasicBlock
*Block
: Blocks
) {
1135 Instruction
*TI
= Block
->getTerminator();
1136 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
!= e
; ++i
)
1137 if (!Blocks
.count(TI
->getSuccessor(i
))) {
1138 BasicBlock
*OldTarget
= TI
->getSuccessor(i
);
1139 // add a new basic block which returns the appropriate value
1140 BasicBlock
*&NewTarget
= ExitBlockMap
[OldTarget
];
1142 // If we don't already have an exit stub for this non-extracted
1143 // destination, create one now!
1144 NewTarget
= BasicBlock::Create(Context
,
1145 OldTarget
->getName() + ".exitStub",
1147 unsigned SuccNum
= switchVal
++;
1149 Value
*brVal
= nullptr;
1150 switch (NumExitBlocks
) {
1152 case 1: break; // No value needed.
1153 case 2: // Conditional branch, return a bool
1154 brVal
= ConstantInt::get(Type::getInt1Ty(Context
), !SuccNum
);
1157 brVal
= ConstantInt::get(Type::getInt16Ty(Context
), SuccNum
);
1161 ReturnInst::Create(Context
, brVal
, NewTarget
);
1163 // Update the switch instruction.
1164 TheSwitch
->addCase(ConstantInt::get(Type::getInt16Ty(Context
),
1169 // rewrite the original branch instruction with this new target
1170 TI
->setSuccessor(i
, NewTarget
);
1174 // Store the arguments right after the definition of output value.
1175 // This should be proceeded after creating exit stubs to be ensure that invoke
1176 // result restore will be placed in the outlined function.
1177 Function::arg_iterator OAI
= OutputArgBegin
;
1178 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
) {
1179 auto *OutI
= dyn_cast
<Instruction
>(outputs
[i
]);
1183 // Find proper insertion point.
1184 BasicBlock::iterator InsertPt
;
1185 // In case OutI is an invoke, we insert the store at the beginning in the
1186 // 'normal destination' BB. Otherwise we insert the store right after OutI.
1187 if (auto *InvokeI
= dyn_cast
<InvokeInst
>(OutI
))
1188 InsertPt
= InvokeI
->getNormalDest()->getFirstInsertionPt();
1189 else if (auto *Phi
= dyn_cast
<PHINode
>(OutI
))
1190 InsertPt
= Phi
->getParent()->getFirstInsertionPt();
1192 InsertPt
= std::next(OutI
->getIterator());
1194 Instruction
*InsertBefore
= &*InsertPt
;
1195 assert((InsertBefore
->getFunction() == newFunction
||
1196 Blocks
.count(InsertBefore
->getParent())) &&
1197 "InsertPt should be in new function");
1198 assert(OAI
!= newFunction
->arg_end() &&
1199 "Number of output arguments should match "
1200 "the amount of defined values");
1201 if (AggregateArgs
) {
1203 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1204 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), FirstOut
+ i
);
1205 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1206 StructArgTy
, &*OAI
, Idx
, "gep_" + outputs
[i
]->getName(),
1208 new StoreInst(outputs
[i
], GEP
, InsertBefore
);
1209 // Since there should be only one struct argument aggregating
1210 // all the output values, we shouldn't increment OAI, which always
1211 // points to the struct argument, in this case.
1213 new StoreInst(outputs
[i
], &*OAI
, InsertBefore
);
1218 // Now that we've done the deed, simplify the switch instruction.
1219 Type
*OldFnRetTy
= TheSwitch
->getParent()->getParent()->getReturnType();
1220 switch (NumExitBlocks
) {
1222 // There are no successors (the block containing the switch itself), which
1223 // means that previously this was the last part of the function, and hence
1224 // this should be rewritten as a `ret'
1226 // Check if the function should return a value
1227 if (OldFnRetTy
->isVoidTy()) {
1228 ReturnInst::Create(Context
, nullptr, TheSwitch
); // Return void
1229 } else if (OldFnRetTy
== TheSwitch
->getCondition()->getType()) {
1230 // return what we have
1231 ReturnInst::Create(Context
, TheSwitch
->getCondition(), TheSwitch
);
1233 // Otherwise we must have code extracted an unwind or something, just
1234 // return whatever we want.
1235 ReturnInst::Create(Context
,
1236 Constant::getNullValue(OldFnRetTy
), TheSwitch
);
1239 TheSwitch
->eraseFromParent();
1242 // Only a single destination, change the switch into an unconditional
1244 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
);
1245 TheSwitch
->eraseFromParent();
1248 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
->getSuccessor(2),
1250 TheSwitch
->eraseFromParent();
1253 // Otherwise, make the default destination of the switch instruction be one
1254 // of the other successors.
1255 TheSwitch
->setCondition(call
);
1256 TheSwitch
->setDefaultDest(TheSwitch
->getSuccessor(NumExitBlocks
));
1257 // Remove redundant case
1258 TheSwitch
->removeCase(SwitchInst::CaseIt(TheSwitch
, NumExitBlocks
-1));
1262 // Insert lifetime markers around the reloads of any output values. The
1263 // allocas output values are stored in are only in-use in the codeRepl block.
1264 insertLifetimeMarkersSurroundingCall(M
, ReloadOutputs
, ReloadOutputs
, call
);
1269 void CodeExtractor::moveCodeToFunction(Function
*newFunction
) {
1270 Function
*oldFunc
= (*Blocks
.begin())->getParent();
1271 Function::BasicBlockListType
&oldBlocks
= oldFunc
->getBasicBlockList();
1272 Function::BasicBlockListType
&newBlocks
= newFunction
->getBasicBlockList();
1274 for (BasicBlock
*Block
: Blocks
) {
1275 // Delete the basic block from the old function, and the list of blocks
1276 oldBlocks
.remove(Block
);
1278 // Insert this basic block into the new function
1279 newBlocks
.push_back(Block
);
1281 // Remove @llvm.assume calls that were moved to the new function from the
1282 // old function's assumption cache.
1284 for (auto &I
: *Block
)
1285 if (match(&I
, m_Intrinsic
<Intrinsic::assume
>()))
1286 AC
->unregisterAssumption(cast
<CallInst
>(&I
));
1290 void CodeExtractor::calculateNewCallTerminatorWeights(
1291 BasicBlock
*CodeReplacer
,
1292 DenseMap
<BasicBlock
*, BlockFrequency
> &ExitWeights
,
1293 BranchProbabilityInfo
*BPI
) {
1294 using Distribution
= BlockFrequencyInfoImplBase::Distribution
;
1295 using BlockNode
= BlockFrequencyInfoImplBase::BlockNode
;
1297 // Update the branch weights for the exit block.
1298 Instruction
*TI
= CodeReplacer
->getTerminator();
1299 SmallVector
<unsigned, 8> BranchWeights(TI
->getNumSuccessors(), 0);
1301 // Block Frequency distribution with dummy node.
1302 Distribution BranchDist
;
1304 // Add each of the frequencies of the successors.
1305 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
< e
; ++i
) {
1306 BlockNode
ExitNode(i
);
1307 uint64_t ExitFreq
= ExitWeights
[TI
->getSuccessor(i
)].getFrequency();
1309 BranchDist
.addExit(ExitNode
, ExitFreq
);
1311 BPI
->setEdgeProbability(CodeReplacer
, i
, BranchProbability::getZero());
1314 // Check for no total weight.
1315 if (BranchDist
.Total
== 0)
1318 // Normalize the distribution so that they can fit in unsigned.
1319 BranchDist
.normalize();
1321 // Create normalized branch weights and set the metadata.
1322 for (unsigned I
= 0, E
= BranchDist
.Weights
.size(); I
< E
; ++I
) {
1323 const auto &Weight
= BranchDist
.Weights
[I
];
1325 // Get the weight and update the current BFI.
1326 BranchWeights
[Weight
.TargetNode
.Index
] = Weight
.Amount
;
1327 BranchProbability
BP(Weight
.Amount
, BranchDist
.Total
);
1328 BPI
->setEdgeProbability(CodeReplacer
, Weight
.TargetNode
.Index
, BP
);
1331 LLVMContext::MD_prof
,
1332 MDBuilder(TI
->getContext()).createBranchWeights(BranchWeights
));
1335 Function
*CodeExtractor::extractCodeRegion() {
1339 // Assumption: this is a single-entry code region, and the header is the first
1340 // block in the region.
1341 BasicBlock
*header
= *Blocks
.begin();
1342 Function
*oldFunction
= header
->getParent();
1344 // For functions with varargs, check that varargs handling is only done in the
1345 // outlined function, i.e vastart and vaend are only used in outlined blocks.
1346 if (AllowVarArgs
&& oldFunction
->getFunctionType()->isVarArg()) {
1347 auto containsVarArgIntrinsic
= [](Instruction
&I
) {
1348 if (const CallInst
*CI
= dyn_cast
<CallInst
>(&I
))
1349 if (const Function
*F
= CI
->getCalledFunction())
1350 return F
->getIntrinsicID() == Intrinsic::vastart
||
1351 F
->getIntrinsicID() == Intrinsic::vaend
;
1355 for (auto &BB
: *oldFunction
) {
1356 if (Blocks
.count(&BB
))
1358 if (llvm::any_of(BB
, containsVarArgIntrinsic
))
1362 ValueSet inputs
, outputs
, SinkingCands
, HoistingCands
;
1363 BasicBlock
*CommonExit
= nullptr;
1365 // Calculate the entry frequency of the new function before we change the root
1367 BlockFrequency EntryFreq
;
1369 assert(BPI
&& "Both BPI and BFI are required to preserve profile info");
1370 for (BasicBlock
*Pred
: predecessors(header
)) {
1371 if (Blocks
.count(Pred
))
1374 BFI
->getBlockFreq(Pred
) * BPI
->getEdgeProbability(Pred
, header
);
1378 // If we have any return instructions in the region, split those blocks so
1379 // that the return is not in the region.
1380 splitReturnBlocks();
1382 // Calculate the exit blocks for the extracted region and the total exit
1383 // weights for each of those blocks.
1384 DenseMap
<BasicBlock
*, BlockFrequency
> ExitWeights
;
1385 SmallPtrSet
<BasicBlock
*, 1> ExitBlocks
;
1386 for (BasicBlock
*Block
: Blocks
) {
1387 for (succ_iterator SI
= succ_begin(Block
), SE
= succ_end(Block
); SI
!= SE
;
1389 if (!Blocks
.count(*SI
)) {
1390 // Update the branch weight for this successor.
1392 BlockFrequency
&BF
= ExitWeights
[*SI
];
1393 BF
+= BFI
->getBlockFreq(Block
) * BPI
->getEdgeProbability(Block
, *SI
);
1395 ExitBlocks
.insert(*SI
);
1399 NumExitBlocks
= ExitBlocks
.size();
1401 // If we have to split PHI nodes of the entry or exit blocks, do so now.
1402 severSplitPHINodesOfEntry(header
);
1403 severSplitPHINodesOfExits(ExitBlocks
);
1405 // This takes place of the original loop
1406 BasicBlock
*codeReplacer
= BasicBlock::Create(header
->getContext(),
1407 "codeRepl", oldFunction
,
1410 // The new function needs a root node because other nodes can branch to the
1411 // head of the region, but the entry node of a function cannot have preds.
1412 BasicBlock
*newFuncRoot
= BasicBlock::Create(header
->getContext(),
1414 auto *BranchI
= BranchInst::Create(header
);
1415 // If the original function has debug info, we have to add a debug location
1416 // to the new branch instruction from the artificial entry block.
1417 // We use the debug location of the first instruction in the extracted
1418 // blocks, as there is no other equivalent line in the source code.
1419 if (oldFunction
->getSubprogram()) {
1420 any_of(Blocks
, [&BranchI
](const BasicBlock
*BB
) {
1421 return any_of(*BB
, [&BranchI
](const Instruction
&I
) {
1422 if (!I
.getDebugLoc())
1424 BranchI
->setDebugLoc(I
.getDebugLoc());
1429 newFuncRoot
->getInstList().push_back(BranchI
);
1431 findAllocas(SinkingCands
, HoistingCands
, CommonExit
);
1432 assert(HoistingCands
.empty() || CommonExit
);
1434 // Find inputs to, outputs from the code region.
1435 findInputsOutputs(inputs
, outputs
, SinkingCands
);
1437 // Now sink all instructions which only have non-phi uses inside the region.
1438 // Group the allocas at the start of the block, so that any bitcast uses of
1439 // the allocas are well-defined.
1440 AllocaInst
*FirstSunkAlloca
= nullptr;
1441 for (auto *II
: SinkingCands
) {
1442 if (auto *AI
= dyn_cast
<AllocaInst
>(II
)) {
1443 AI
->moveBefore(*newFuncRoot
, newFuncRoot
->getFirstInsertionPt());
1444 if (!FirstSunkAlloca
)
1445 FirstSunkAlloca
= AI
;
1448 assert((SinkingCands
.empty() || FirstSunkAlloca
) &&
1449 "Did not expect a sink candidate without any allocas");
1450 for (auto *II
: SinkingCands
) {
1451 if (!isa
<AllocaInst
>(II
)) {
1452 cast
<Instruction
>(II
)->moveAfter(FirstSunkAlloca
);
1456 if (!HoistingCands
.empty()) {
1457 auto *HoistToBlock
= findOrCreateBlockForHoisting(CommonExit
);
1458 Instruction
*TI
= HoistToBlock
->getTerminator();
1459 for (auto *II
: HoistingCands
)
1460 cast
<Instruction
>(II
)->moveBefore(TI
);
1463 // Collect objects which are inputs to the extraction region and also
1464 // referenced by lifetime start markers within it. The effects of these
1465 // markers must be replicated in the calling function to prevent the stack
1466 // coloring pass from merging slots which store input objects.
1467 ValueSet LifetimesStart
;
1468 eraseLifetimeMarkersOnInputs(Blocks
, SinkingCands
, LifetimesStart
);
1470 // Construct new function based on inputs/outputs & add allocas for all defs.
1471 Function
*newFunction
=
1472 constructFunction(inputs
, outputs
, header
, newFuncRoot
, codeReplacer
,
1473 oldFunction
, oldFunction
->getParent());
1475 // Update the entry count of the function.
1477 auto Count
= BFI
->getProfileCountFromFreq(EntryFreq
.getFrequency());
1478 if (Count
.hasValue())
1479 newFunction
->setEntryCount(
1480 ProfileCount(Count
.getValue(), Function::PCT_Real
)); // FIXME
1481 BFI
->setBlockFreq(codeReplacer
, EntryFreq
.getFrequency());
1485 emitCallAndSwitchStatement(newFunction
, codeReplacer
, inputs
, outputs
);
1487 moveCodeToFunction(newFunction
);
1489 // Replicate the effects of any lifetime start/end markers which referenced
1490 // input objects in the extraction region by placing markers around the call.
1491 insertLifetimeMarkersSurroundingCall(
1492 oldFunction
->getParent(), LifetimesStart
.getArrayRef(), {}, TheCall
);
1494 // Propagate personality info to the new function if there is one.
1495 if (oldFunction
->hasPersonalityFn())
1496 newFunction
->setPersonalityFn(oldFunction
->getPersonalityFn());
1498 // Update the branch weights for the exit block.
1499 if (BFI
&& NumExitBlocks
> 1)
1500 calculateNewCallTerminatorWeights(codeReplacer
, ExitWeights
, BPI
);
1502 // Loop over all of the PHI nodes in the header and exit blocks, and change
1503 // any references to the old incoming edge to be the new incoming edge.
1504 for (BasicBlock::iterator I
= header
->begin(); isa
<PHINode
>(I
); ++I
) {
1505 PHINode
*PN
= cast
<PHINode
>(I
);
1506 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
1507 if (!Blocks
.count(PN
->getIncomingBlock(i
)))
1508 PN
->setIncomingBlock(i
, newFuncRoot
);
1511 for (BasicBlock
*ExitBB
: ExitBlocks
)
1512 for (PHINode
&PN
: ExitBB
->phis()) {
1513 Value
*IncomingCodeReplacerVal
= nullptr;
1514 for (unsigned i
= 0, e
= PN
.getNumIncomingValues(); i
!= e
; ++i
) {
1515 // Ignore incoming values from outside of the extracted region.
1516 if (!Blocks
.count(PN
.getIncomingBlock(i
)))
1519 // Ensure that there is only one incoming value from codeReplacer.
1520 if (!IncomingCodeReplacerVal
) {
1521 PN
.setIncomingBlock(i
, codeReplacer
);
1522 IncomingCodeReplacerVal
= PN
.getIncomingValue(i
);
1524 assert(IncomingCodeReplacerVal
== PN
.getIncomingValue(i
) &&
1525 "PHI has two incompatbile incoming values from codeRepl");
1529 // Erase debug info intrinsics. Variable updates within the new function are
1530 // invisible to debuggers. This could be improved by defining a DISubprogram
1531 // for the new function.
1532 for (BasicBlock
&BB
: *newFunction
) {
1533 auto BlockIt
= BB
.begin();
1534 // Remove debug info intrinsics from the new function.
1535 while (BlockIt
!= BB
.end()) {
1536 Instruction
*Inst
= &*BlockIt
;
1538 if (isa
<DbgInfoIntrinsic
>(Inst
))
1539 Inst
->eraseFromParent();
1541 // Remove debug info intrinsics which refer to values in the new function
1542 // from the old function.
1543 SmallVector
<DbgVariableIntrinsic
*, 4> DbgUsers
;
1544 for (Instruction
&I
: BB
)
1545 findDbgUsers(DbgUsers
, &I
);
1546 for (DbgVariableIntrinsic
*DVI
: DbgUsers
)
1547 DVI
->eraseFromParent();
1550 // Mark the new function `noreturn` if applicable. Terminators which resume
1551 // exception propagation are treated as returning instructions. This is to
1552 // avoid inserting traps after calls to outlined functions which unwind.
1553 bool doesNotReturn
= none_of(*newFunction
, [](const BasicBlock
&BB
) {
1554 const Instruction
*Term
= BB
.getTerminator();
1555 return isa
<ReturnInst
>(Term
) || isa
<ResumeInst
>(Term
);
1558 newFunction
->setDoesNotReturn();
1560 LLVM_DEBUG(if (verifyFunction(*newFunction
, &errs())) {
1561 newFunction
->dump();
1562 report_fatal_error("verification of newFunction failed!");
1564 LLVM_DEBUG(if (verifyFunction(*oldFunction
))
1565 report_fatal_error("verification of oldFunction failed!"));