1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the interface to tear out a code region, such as an
10 // individual loop or a parallel section, into a new function, replacing it with
11 // a call to the new function.
13 //===----------------------------------------------------------------------===//
15 #include "llvm/Transforms/Utils/CodeExtractor.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Analysis/AssumptionCache.h"
24 #include "llvm/Analysis/BlockFrequencyInfo.h"
25 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
26 #include "llvm/Analysis/BranchProbabilityInfo.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/IR/Argument.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/CFG.h"
32 #include "llvm/IR/Constant.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DataLayout.h"
35 #include "llvm/IR/DerivedTypes.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/GlobalValue.h"
39 #include "llvm/IR/InstrTypes.h"
40 #include "llvm/IR/Instruction.h"
41 #include "llvm/IR/Instructions.h"
42 #include "llvm/IR/IntrinsicInst.h"
43 #include "llvm/IR/Intrinsics.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/IR/MDBuilder.h"
46 #include "llvm/IR/Module.h"
47 #include "llvm/IR/PatternMatch.h"
48 #include "llvm/IR/Type.h"
49 #include "llvm/IR/User.h"
50 #include "llvm/IR/Value.h"
51 #include "llvm/IR/Verifier.h"
52 #include "llvm/Pass.h"
53 #include "llvm/Support/BlockFrequency.h"
54 #include "llvm/Support/BranchProbability.h"
55 #include "llvm/Support/Casting.h"
56 #include "llvm/Support/CommandLine.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Support/ErrorHandling.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
61 #include "llvm/Transforms/Utils/Local.h"
71 using namespace llvm::PatternMatch
;
72 using ProfileCount
= Function::ProfileCount
;
74 #define DEBUG_TYPE "code-extractor"
76 // Provide a command-line option to aggregate function arguments into a struct
77 // for functions produced by the code extractor. This is useful when converting
78 // extracted functions to pthread-based code, as only one argument (void*) can
79 // be passed in to pthread_create().
// The flag is OR-ed with the AggregateArgs constructor parameter in both
// CodeExtractor constructors, so setting it turns aggregation on for every
// extractor instance.
// NOTE(review): embedded line 80 (the start of this declaration, presumably
// the option's type) is missing from this listing -- restore it from upstream.
81 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden
,
82 cl::desc("Aggregate arguments to code-extracted functions"));
84 /// Test whether a block is valid for extraction.
///
/// \p BB is the candidate block and \p Result is the whole candidate set; the
/// exception-handling checks below look up unwind destinations, catchswitch
/// handlers and catchret/cleanupret parent blocks in \p Result.
///
/// NOTE(review): the embedded line numbers in this listing skip (89 -> 92,
/// 102 -> 104, 119 -> 125, ...), so the statements guarded by many of the
/// conditions below -- and the places where \p AllowVarArgs / \p AllowAlloca
/// are consumed -- are not visible here. Confirm them against upstream.
85 static bool isBlockValidForExtraction(const BasicBlock
&BB
,
86 const SetVector
<BasicBlock
*> &Result
,
87 bool AllowVarArgs
, bool AllowAlloca
) {
88 // taking the address of a basic block moved to another function is illegal
89 if (BB
.hasAddressTaken())
92 // don't hoist code that uses another basicblock address, as it's likely to
93 // lead to unexpected behavior, like cross-function jumps
// Visited de-duplicates users; ToVisit is the worklist, seeded with every
// instruction of BB.
94 SmallPtrSet
<User
const *, 16> Visited
;
95 SmallVector
<User
const *, 16> ToVisit
;
97 for (Instruction
const &Inst
: BB
)
98 ToVisit
.push_back(&Inst
);
100 while (!ToVisit
.empty()) {
101 User
const *Curr
= ToVisit
.pop_back_val();
102 if (!Visited
.insert(Curr
).second
)
104 if (isa
<BlockAddress
const>(Curr
))
105 return false; // even a reference to self is likely to be not compatible
107 if (isa
<Instruction
>(Curr
) && cast
<Instruction
>(Curr
)->getParent() != &BB
)
// Queue every operand that is itself a User so it gets inspected too.
110 for (auto const &U
: Curr
->operands()) {
111 if (auto *UU
= dyn_cast
<User
>(U
))
112 ToVisit
.push_back(UU
);
116 // If explicitly requested, allow vastart and alloca. For invoke instructions
117 // verify that extraction is valid.
118 for (BasicBlock::const_iterator I
= BB
.begin(), E
= BB
.end(); I
!= E
; ++I
) {
119 if (isa
<AllocaInst
>(I
)) {
125 if (const auto *II
= dyn_cast
<InvokeInst
>(I
)) {
126 // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
127 // must be a part of the subgraph which is being extracted.
128 if (auto *UBB
= II
->getUnwindDest())
129 if (!Result
.count(UBB
))
134 // All catch handlers of a catchswitch instruction as well as the unwind
135 // destination must be in the subgraph.
136 if (const auto *CSI
= dyn_cast
<CatchSwitchInst
>(I
)) {
137 if (auto *UBB
= CSI
->getUnwindDest())
138 if (!Result
.count(UBB
))
140 for (auto *HBB
: CSI
->handlers())
141 if (!Result
.count(const_cast<BasicBlock
*>(HBB
)))
146 // Make sure that entire catch handler is within subgraph. It is sufficient
147 // to check that catch return's block is in the list.
148 if (const auto *CPI
= dyn_cast
<CatchPadInst
>(I
)) {
149 for (const auto *U
: CPI
->users())
150 if (const auto *CRI
= dyn_cast
<CatchReturnInst
>(U
))
151 if (!Result
.count(const_cast<BasicBlock
*>(CRI
->getParent())))
156 // And do similar checks for cleanup handler - the entire handler must be
157 // in subgraph which is going to be extracted. For cleanup return should
158 // additionally check that the unwind destination is also in the subgraph.
159 if (const auto *CPI
= dyn_cast
<CleanupPadInst
>(I
)) {
160 for (const auto *U
: CPI
->users())
161 if (const auto *CRI
= dyn_cast
<CleanupReturnInst
>(U
))
162 if (!Result
.count(const_cast<BasicBlock
*>(CRI
->getParent())))
166 if (const auto *CRI
= dyn_cast
<CleanupReturnInst
>(I
)) {
167 if (auto *UBB
= CRI
->getUnwindDest())
168 if (!Result
.count(UBB
))
// Only direct calls are inspected here: getCalledFunction() must be non-null
// before the intrinsic ID is examined.
173 if (const CallInst
*CI
= dyn_cast
<CallInst
>(I
)) {
174 if (const Function
*F
= CI
->getCalledFunction()) {
175 auto IID
= F
->getIntrinsicID();
176 if (IID
== Intrinsic::vastart
) {
183 // Currently, we miscompile outlined copies of eh_typeid_for. There are
184 // proposals for fixing this in llvm.org/PR39545.
185 if (IID
== Intrinsic::eh_typeid_for
)
194 /// Build a set of blocks to extract if the input blocks are viable.
///
/// \p BBs must be non-empty (asserted) and must not contain duplicates (a
/// failed insert hits llvm_unreachable). \p DT may be null; when present it is
/// used to test each block for reachability from the entry block.
/// \p AllowVarArgs and \p AllowAlloca are forwarded unchanged to
/// isBlockValidForExtraction().
///
/// NOTE(review): the statements executed on each rejection path (the comment
/// below says "aborting with an empty set") are not visible in this listing.
195 static SetVector
<BasicBlock
*>
196 buildExtractionBlockSet(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
,
197 bool AllowVarArgs
, bool AllowAlloca
) {
198 assert(!BBs
.empty() && "The set of blocks to extract must be non-empty");
199 SetVector
<BasicBlock
*> Result
;
201 // Loop over the blocks, adding them to our set-vector, and aborting with an
202 // empty set if we encounter invalid blocks.
203 for (BasicBlock
*BB
: BBs
) {
204 // If this block is dead, don't process it.
205 if (DT
&& !DT
->isReachableFromEntry(BB
))
208 if (!Result
.insert(BB
))
209 llvm_unreachable("Repeated basic blocks in extraction input");
212 LLVM_DEBUG(dbgs() << "Region front block: " << Result
.front()->getName()
// Second pass: validate every collected block against the complete set.
215 for (auto *BB
: Result
) {
216 if (!isBlockValidForExtraction(*BB
, Result
, AllowVarArgs
, AllowAlloca
))
219 // Make sure that the first block is not a landing pad.
220 if (BB
== Result
.front()) {
222 LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
228 // All blocks other than the first must not have predecessors outside of
229 // the subgraph which is being extracted.
230 for (auto *PBB
: predecessors(BB
))
231 if (!Result
.count(PBB
)) {
232 LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
233 "outside the region except for the first block!\n"
234 << "Problematic source BB: " << BB
->getName() << "\n"
235 << "Problematic destination BB: " << PBB
->getName()
// Construct an extractor over an explicit list of blocks.
// The candidate set (Blocks) is computed by buildExtractionBlockSet() from
// BBs, DT, AllowVarArgs and AllowAlloca. Aggregation of arguments is enabled
// when either the AggregateArgs parameter or the AggregateArgsOpt
// command-line flag is set.
// NOTE(review): embedded lines 248 and 252 (end of the parameter list and of
// the initializer list) are missing from this listing.
244 CodeExtractor::CodeExtractor(ArrayRef
<BasicBlock
*> BBs
, DominatorTree
*DT
,
245 bool AggregateArgs
, BlockFrequencyInfo
*BFI
,
246 BranchProbabilityInfo
*BPI
, AssumptionCache
*AC
,
247 bool AllowVarArgs
, bool AllowAlloca
,
249 : DT(DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
250 BPI(BPI
), AC(AC
), AllowVarArgs(AllowVarArgs
),
251 Blocks(buildExtractionBlockSet(BBs
, DT
, AllowVarArgs
, AllowAlloca
)),
// Construct an extractor for the blocks of loop L.
// Unlike the block-list overload, the dominator tree is mandatory (taken by
// reference) and both AllowVarArgs and AllowAlloca are fixed to false, for the
// member as well as for the call to buildExtractionBlockSet().
// NOTE(review): embedded lines 257 and 263 (end of the parameter list and of
// the initializer list) are missing from this listing.
254 CodeExtractor::CodeExtractor(DominatorTree
&DT
, Loop
&L
, bool AggregateArgs
,
255 BlockFrequencyInfo
*BFI
,
256 BranchProbabilityInfo
*BPI
, AssumptionCache
*AC
,
258 : DT(&DT
), AggregateArgs(AggregateArgs
|| AggregateArgsOpt
), BFI(BFI
),
259 BPI(BPI
), AC(AC
), AllowVarArgs(false),
260 Blocks(buildExtractionBlockSet(L
.getBlocks(), &DT
,
261 /* AllowVarArgs */ false,
262 /* AllowAlloca */ false)),
265 /// definedInRegion - Return true if the specified value is defined in the
266 /// extracted region.
///
/// The visible test is: \p V is an Instruction whose parent block is a member
/// of \p Blocks. Values that are not instructions never satisfy it.
/// NOTE(review): the return statements (embedded lines 270-272) are missing
/// from this listing.
267 static bool definedInRegion(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
268 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
269 if (Blocks
.count(I
->getParent()))
274 /// definedInCaller - Return true if the specified value is defined in the
275 /// function being code extracted, but not in the region being extracted.
276 /// These values must be passed in as live-ins to the function.
///
/// Function arguments always qualify. For instructions the visible test is
/// that the parent block is NOT a member of \p Blocks.
/// NOTE(review): the trailing return statements (embedded lines 281-283) are
/// missing from this listing.
277 static bool definedInCaller(const SetVector
<BasicBlock
*> &Blocks
, Value
*V
) {
278 if (isa
<Argument
>(V
)) return true;
279 if (Instruction
*I
= dyn_cast
<Instruction
>(V
))
280 if (!Blocks
.count(I
->getParent()))
// Look for a single block outside the region that every exiting edge targets.
// Each successor of each block in Blocks is examined: successors inside the
// region are internal edges; the first outside successor is remembered in
// CommonExitBlock and every later outside successor is compared against it.
// NOTE(review): the statements on the mismatch path (after embedded lines 296
// and 302) are missing from this listing; presumably no common exit is
// reported when two different outside successors are seen -- confirm upstream.
285 static BasicBlock
*getCommonExitBlock(const SetVector
<BasicBlock
*> &Blocks
) {
286 BasicBlock
*CommonExitBlock
= nullptr;
// Predicate handed to any_of() below; it captures CommonExitBlock by reference
// and updates it as a side effect.
287 auto hasNonCommonExitSucc
= [&](BasicBlock
*Block
) {
288 for (auto *Succ
: successors(Block
)) {
289 // Internal edges, ok.
290 if (Blocks
.count(Succ
))
292 if (!CommonExitBlock
) {
293 CommonExitBlock
= Succ
;
296 if (CommonExitBlock
!= Succ
)
302 if (any_of(Blocks
, hasNonCommonExitSucc
))
305 return CommonExitBlock
;
// Build the per-function cache: for every block of F, append each AllocaInst
// (debug instructions are skipped via instructionsWithoutDebug()) to Allocas,
// then record the block's side-effect summary with
// findSideEffectInfoForBlock().
308 CodeExtractorAnalysisCache::CodeExtractorAnalysisCache(Function
&F
) {
309 for (BasicBlock
&BB
: F
) {
310 for (Instruction
&II
: BB
.instructionsWithoutDebug())
311 if (auto *AI
= dyn_cast
<AllocaInst
>(&II
))
312 Allocas
.push_back(AI
);
314 findSideEffectInfoForBlock(BB
);
// Populate the summaries that doesBlockContainClobberOfAddr() consults.
// For loads and stores the pointer operand is stripped of in-bounds constant
// offsets: an alloca base is recorded in BaseMemAddrs[&BB], any other
// non-constant base puts BB into SideEffectingBlocks. Other instructions are
// checked for lifetime intrinsics and, conservatively, mayHaveSideEffects().
// NOTE(review): the switch header, the case labels other than Store/Load and
// the break/return statements (embedded lines 322, 328, 334, 338, 341-343,
// ...) are missing from this listing, so the exact control flow between the
// visible statements must be confirmed upstream.
318 void CodeExtractorAnalysisCache::findSideEffectInfoForBlock(BasicBlock
&BB
) {
319 for (Instruction
&II
: BB
.instructionsWithoutDebug()) {
320 unsigned Opcode
= II
.getOpcode();
321 Value
*MemAddr
= nullptr;
323 case Instruction::Store
:
324 case Instruction::Load
: {
325 if (Opcode
== Instruction::Store
) {
326 StoreInst
*SI
= cast
<StoreInst
>(&II
);
327 MemAddr
= SI
->getPointerOperand();
329 LoadInst
*LI
= cast
<LoadInst
>(&II
);
330 MemAddr
= LI
->getPointerOperand();
332 // Global variable can not be aliased with locals.
333 if (dyn_cast
<Constant
>(MemAddr
))
335 Value
*Base
= MemAddr
->stripInBoundsConstantOffsets();
336 if (!isa
<AllocaInst
>(Base
)) {
337 SideEffectingBlocks
.insert(&BB
);
340 BaseMemAddrs
[&BB
].insert(Base
);
344 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(&II
);
346 if (IntrInst
->isLifetimeStartOrEnd())
348 SideEffectingBlocks
.insert(&BB
);
351 // Treat all the other cases conservatively if it has side effects.
352 if (II
.mayHaveSideEffects()) {
353 SideEffectingBlocks
.insert(&BB
);
// Query the cached summaries for block BB with respect to alloca Addr.
// Two lookups are made: membership of BB in SideEffectingBlocks, and, if BB
// has an entry in BaseMemAddrs, whether Addr is one of the recorded bases
// (count() result is returned for that case).
// NOTE(review): the return statements after embedded lines 363 and 367 are
// missing from this listing.
361 bool CodeExtractorAnalysisCache::doesBlockContainClobberOfAddr(
362 BasicBlock
&BB
, AllocaInst
*Addr
) const {
363 if (SideEffectingBlocks
.count(&BB
))
365 auto It
= BaseMemAddrs
.find(&BB
);
366 if (It
!= BaseMemAddrs
.end())
367 return It
->second
.count(Addr
);
// Legality check used before moving lifetime markers for Addr into the
// region. Addr (after stripping in-bounds constant offsets) must be an alloca
// -- cast<> asserts this. Every block of the enclosing function is visited;
// membership in Blocks is tested first, then the cache is asked whether the
// block may clobber the alloca.
// NOTE(review): the statements guarded by the two conditions in the loop and
// the final return are missing from this listing; presumably region blocks
// are skipped and any clobber outside the region makes the result false.
371 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
372 const CodeExtractorAnalysisCache
&CEAC
, Instruction
*Addr
) const {
373 AllocaInst
*AI
= cast
<AllocaInst
>(Addr
->stripInBoundsConstantOffsets());
374 Function
*Func
= (*Blocks
.begin())->getParent();
375 for (BasicBlock
&BB
: *Func
) {
376 if (Blocks
.count(&BB
))
378 if (CEAC
.doesBlockContainClobberOfAddr(BB
, AI
))
// Return a block inside the region into which code can be hoisted from
// CommonExitBlock (which must lie outside the region -- asserted).
// If exactly one region block precedes CommonExitBlock, that predecessor is
// returned. Otherwise CommonExitBlock is split at its first non-PHI
// instruction, predecessor terminators are retargeted to the new tail block,
// and the old (now leading) part is added to Blocks and returned.
// NOTE(review): the return type line (embedded line 384) and several guarded
// statements (after embedded lines 390, 395, 408, 429) are missing from this
// listing, including which predecessors get retargeted -- confirm upstream.
385 CodeExtractor::findOrCreateBlockForHoisting(BasicBlock
*CommonExitBlock
) {
386 BasicBlock
*SinglePredFromOutlineRegion
= nullptr;
387 assert(!Blocks
.count(CommonExitBlock
) &&
388 "Expect a block outside the region!");
389 for (auto *Pred
: predecessors(CommonExitBlock
)) {
390 if (!Blocks
.count(Pred
))
392 if (!SinglePredFromOutlineRegion
) {
393 SinglePredFromOutlineRegion
= Pred
;
394 } else if (SinglePredFromOutlineRegion
!= Pred
) {
395 SinglePredFromOutlineRegion
= nullptr;
400 if (SinglePredFromOutlineRegion
)
401 return SinglePredFromOutlineRegion
;
// Helper used only by the assertion below.
404 auto getFirstPHI
= [](BasicBlock
*BB
) {
405 BasicBlock::iterator I
= BB
->begin();
406 PHINode
*FirstPhi
= nullptr;
407 while (I
!= BB
->end()) {
408 PHINode
*Phi
= dyn_cast
<PHINode
>(I
);
418 // If there are any phi nodes, the single pred either exists or has already
419 // been created before code extraction.
420 assert(!getFirstPHI(CommonExitBlock
) && "Phi not expected");
423 BasicBlock
*NewExitBlock
= CommonExitBlock
->splitBasicBlock(
424 CommonExitBlock
->getFirstNonPHI()->getIterator());
// The iterator is advanced (PI++) before the terminator is rewritten, so the
// rewrite cannot disturb the position being iterated.
426 for (auto PI
= pred_begin(CommonExitBlock
), PE
= pred_end(CommonExitBlock
);
428 BasicBlock
*Pred
= *PI
++;
429 if (Blocks
.count(Pred
))
431 Pred
->getTerminator()->replaceUsesOfWith(CommonExitBlock
, NewExitBlock
);
433 // Now add the old exit block to the outline region.
434 Blocks
.insert(CommonExitBlock
);
435 return CommonExitBlock
;
438 // Find the pair of lifetime markers for address 'Addr' that are either
439 // defined inside the outline region or can legally be shrinkwrapped into the
440 // outline region. If there are no other untracked uses of the address, return
441 // the pair of markers if found; otherwise return a pair of nullptr.
// SinkLifeStart / HoistLifeEnd in the returned info are set when the
// corresponding marker is currently defined outside the region.
// NOTE(review): the parameter line declaring Addr (embedded line 444) and the
// early-return statements on each bail-out path are missing from this listing.
442 CodeExtractor::LifetimeMarkerInfo
443 CodeExtractor::getLifetimeMarkers(const CodeExtractorAnalysisCache
&CEAC
,
445 BasicBlock
*ExitBlock
) const {
446 LifetimeMarkerInfo Info
;
448 for (User
*U
: Addr
->users()) {
449 IntrinsicInst
*IntrInst
= dyn_cast
<IntrinsicInst
>(U
);
451 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_start
) {
452 // Do not handle the case where Addr has multiple start markers.
455 Info
.LifeStart
= IntrInst
;
457 if (IntrInst
->getIntrinsicID() == Intrinsic::lifetime_end
) {
460 Info
.LifeEnd
= IntrInst
;
464 // Find untracked uses of the address, bail.
465 if (!definedInRegion(Blocks
, U
))
469 if (!Info
.LifeStart
|| !Info
.LifeEnd
)
472 Info
.SinkLifeStart
= !definedInRegion(Blocks
, Info
.LifeStart
);
473 Info
.HoistLifeEnd
= !definedInRegion(Blocks
, Info
.LifeEnd
);
474 // Do legality check.
475 if ((Info
.SinkLifeStart
|| Info
.HoistLifeEnd
) &&
476 !isLegalToShrinkwrapLifetimeMarkers(CEAC
, Addr
))
479 // Check to see if we have a place to do hoisting, if not, bail.
480 if (Info
.HoistLifeEnd
&& !ExitBlock
)
// Collect allocas (and related bitcasts / lifetime markers) that can be moved
// into the extracted region.
//   SinkCands  - receives values to sink into the region (allocas, bitcasts of
//                allocas, lifetime.start markers).
//   HoistCands - receives lifetime.end markers to hoist into the region.
//   ExitBlock  - out-parameter, set to getCommonExitBlock(Blocks).
// Allocas are taken from the analysis cache rather than by rescanning the
// function.
// NOTE(review): many guarded statements (continue/return/break) are missing
// from this listing; the conditions under which an alloca is skipped must be
// confirmed upstream.
486 void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache
&CEAC
,
487 ValueSet
&SinkCands
, ValueSet
&HoistCands
,
488 BasicBlock
*&ExitBlock
) const {
489 Function
*Func
= (*Blocks
.begin())->getParent();
490 ExitBlock
= getCommonExitBlock(Blocks
);
// Records the markers described by LMI in SinkCands / HoistCands according to
// its SinkLifeStart / HoistLifeEnd flags.
492 auto moveOrIgnoreLifetimeMarkers
=
493 [&](const LifetimeMarkerInfo
&LMI
) -> bool {
496 if (LMI
.SinkLifeStart
) {
497 LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI
.LifeStart
499 SinkCands
.insert(LMI
.LifeStart
);
501 if (LMI
.HoistLifeEnd
) {
502 LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI
.LifeEnd
<< "\n");
503 HoistCands
.insert(LMI
.LifeEnd
);
508 // Look up allocas in the original function in CodeExtractorAnalysisCache, as
509 // this is much faster than walking all the instructions.
510 for (AllocaInst
*AI
: CEAC
.getAllocas()) {
511 BasicBlock
*BB
= AI
->getParent();
512 if (Blocks
.count(BB
))
515 // As a prior call to extractCodeRegion() may have shrinkwrapped the alloca,
516 // check whether it is actually still in the original function.
517 Function
*AIFunc
= BB
->getParent();
// First try: lifetime markers attached directly to the alloca.
521 LifetimeMarkerInfo MarkerInfo
= getLifetimeMarkers(CEAC
, AI
, ExitBlock
);
522 bool Moved
= moveOrIgnoreLifetimeMarkers(MarkerInfo
);
524 LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI
<< "\n");
525 SinkCands
.insert(AI
);
529 // Follow any bitcasts.
// Bitcasts and BitcastLifetimeInfo are parallel vectors: entry I of one
// describes entry I of the other.
530 SmallVector
<Instruction
*, 2> Bitcasts
;
531 SmallVector
<LifetimeMarkerInfo
, 2> BitcastLifetimeInfo
;
532 for (User
*U
: AI
->users()) {
533 if (U
->stripInBoundsConstantOffsets() == AI
) {
534 Instruction
*Bitcast
= cast
<Instruction
>(U
);
535 LifetimeMarkerInfo LMI
= getLifetimeMarkers(CEAC
, Bitcast
, ExitBlock
);
537 Bitcasts
.push_back(Bitcast
);
538 BitcastLifetimeInfo
.push_back(LMI
);
543 // Found unknown use of AI.
544 if (!definedInRegion(Blocks
, U
)) {
550 // Either no bitcasts reference the alloca or there are unknown uses.
551 if (Bitcasts
.empty())
554 LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI
<< "\n");
555 SinkCands
.insert(AI
);
556 for (unsigned I
= 0, E
= Bitcasts
.size(); I
!= E
; ++I
) {
557 Instruction
*BitcastAddr
= Bitcasts
[I
];
558 const LifetimeMarkerInfo
&LMI
= BitcastLifetimeInfo
[I
];
559 assert(LMI
.LifeStart
&&
560 "Unsafe to sink bitcast without lifetime markers");
561 moveOrIgnoreLifetimeMarkers(LMI
);
562 if (!definedInRegion(Blocks
, BitcastAddr
)) {
563 LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
565 SinkCands
.insert(BitcastAddr
);
// Report whether the current block set can be extracted.
// The visible check applies only when AllowVarArgs is set and the enclosing
// function is variadic: blocks of that function are scanned for direct calls
// to the vastart / vaend intrinsics, with membership in Blocks tested first.
// NOTE(review): embedded lines 572-573 and the statements guarded by the
// conditions in the loop, plus the final return, are missing from this
// listing.
571 bool CodeExtractor::isEligible() const {
574 BasicBlock
*Header
= *Blocks
.begin();
575 Function
*F
= Header
->getParent();
577 // For functions with varargs, check that varargs handling is only done in the
578 // outlined function, i.e vastart and vaend are only used in outlined blocks.
579 if (AllowVarArgs
&& F
->getFunctionType()->isVarArg()) {
// True for a CallInst whose statically known callee is vastart or vaend.
580 auto containsVarArgIntrinsic
= [](const Instruction
&I
) {
581 if (const CallInst
*CI
= dyn_cast
<CallInst
>(&I
))
582 if (const Function
*Callee
= CI
->getCalledFunction())
583 return Callee
->getIntrinsicID() == Intrinsic::vastart
||
584 Callee
->getIntrinsicID() == Intrinsic::vaend
;
588 for (auto &BB
: *F
) {
589 if (Blocks
.count(&BB
))
591 if (llvm::any_of(BB
, containsVarArgIntrinsic
))
// Classify values crossing the region boundary.
// Operands of region instructions that are defined in the caller and are not
// sink candidates are inputs; region instructions with a user outside the
// region are outputs.
// NOTE(review): the declaration of V (embedded line 605) and the insert
// statements into Inputs / Outputs are missing from this listing.
598 void CodeExtractor::findInputsOutputs(ValueSet
&Inputs
, ValueSet
&Outputs
,
599 const ValueSet
&SinkCands
) const {
600 for (BasicBlock
*BB
: Blocks
) {
601 // If a used value is defined outside the region, it's an input. If an
602 // instruction is used outside the region, it's an output.
603 for (Instruction
&II
: *BB
) {
604 for (auto &OI
: II
.operands()) {
606 if (!SinkCands
.count(V
) && definedInCaller(Blocks
, V
))
610 for (User
*U
: II
.users())
611 if (!definedInRegion(Blocks
, U
)) {
619 /// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
620 /// of the region, we need to split the entry block of the region so that the
621 /// PHI node is easier to deal with.
///
/// \p Header is taken by reference. The old header is removed from Blocks and
/// the block produced by SplitBlock() is inserted in its place.
/// NOTE(review): a few lines (e.g. embedded lines 636, 655, 687-688) are
/// missing from this listing, including any assignment back to \p Header and
/// the index adjustment after removeIncomingValue().
622 void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock
*&Header
) {
623 unsigned NumPredsFromRegion
= 0;
624 unsigned NumPredsOutsideRegion
= 0;
626 if (Header
!= &Header
->getParent()->getEntryBlock()) {
627 PHINode
*PN
= dyn_cast
<PHINode
>(Header
->begin());
628 if (!PN
) return; // No PHI nodes.
630 // If the header node contains any PHI nodes, check to see if there is more
631 // than one entry from outside the region. If so, we need to sever the
632 // header block into two.
633 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
634 if (Blocks
.count(PN
->getIncomingBlock(i
)))
635 ++NumPredsFromRegion
;
637 ++NumPredsOutsideRegion
;
639 // If there is one (or fewer) predecessor from outside the region, we don't
640 // need to do anything special.
641 if (NumPredsOutsideRegion
<= 1) return;
644 // Otherwise, we need to split the header block into two pieces: one
645 // containing PHI nodes merging values from outside of the region, and a
646 // second that contains all of the code for the block and merges back any
647 // incoming values from inside of the region.
648 BasicBlock
*NewBB
= SplitBlock(Header
, Header
->getFirstNonPHI(), DT
);
650 // We only want to code extract the second block now, and it becomes the new
651 // header of the region.
652 BasicBlock
*OldPred
= Header
;
653 Blocks
.remove(OldPred
);
654 Blocks
.insert(NewBB
);
657 // Okay, now we need to adjust the PHI nodes and any branches from within the
658 // region to go to the new header block instead of the old header block.
659 if (NumPredsFromRegion
) {
660 PHINode
*PN
= cast
<PHINode
>(OldPred
->begin());
661 // Loop over all of the predecessors of OldPred that are in the region,
662 // changing them to branch to NewBB instead.
663 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
664 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
665 Instruction
*TI
= PN
->getIncomingBlock(i
)->getTerminator();
666 TI
->replaceUsesOfWith(OldPred
, NewBB
);
669 // Okay, everything within the region is now branching to the right block, we
670 // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
671 BasicBlock::iterator AfterPHIs
;
672 for (AfterPHIs
= OldPred
->begin(); isa
<PHINode
>(AfterPHIs
); ++AfterPHIs
) {
673 PHINode
*PN
= cast
<PHINode
>(AfterPHIs
);
674 // Create a new PHI node in the new region, which has an incoming value
675 // from OldPred of PN.
676 PHINode
*NewPN
= PHINode::Create(PN
->getType(), 1 + NumPredsFromRegion
,
677 PN
->getName() + ".ce", &NewBB
->front());
678 PN
->replaceAllUsesWith(NewPN
);
679 NewPN
->addIncoming(PN
, OldPred
);
681 // Loop over all of the incoming value in PN, moving them to NewPN if they
682 // are from the extracted region.
// The bound is re-read every iteration because removeIncomingValue() shrinks
// the PHI.
683 for (unsigned i
= 0; i
!= PN
->getNumIncomingValues(); ++i
) {
684 if (Blocks
.count(PN
->getIncomingBlock(i
))) {
685 NewPN
->addIncoming(PN
->getIncomingValue(i
), PN
->getIncomingBlock(i
));
686 PN
->removeIncomingValue(i
);
694 /// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
695 /// outlined region, we split these PHIs into two: one with inputs from region
696 /// and other with remaining incoming blocks; then first PHIs are placed in
/// NOTE(review): the end of the sentence above (embedded line 697) is missing
/// from this listing; the code below places the new PHIs in a ".split" block
/// that is inserted into Blocks.
698 void CodeExtractor::severSplitPHINodesOfExits(
699 const SmallPtrSetImpl
<BasicBlock
*> &Exits
) {
700 for (BasicBlock
*ExitBB
: Exits
) {
// Created at most once per exit block, on first need.
701 BasicBlock
*NewBB
= nullptr;
703 for (PHINode
&PN
: ExitBB
->phis()) {
704 // Find all incoming values from the outlining region.
// Holds indices into PN's incoming list, in ascending order.
705 SmallVector
<unsigned, 2> IncomingVals
;
706 for (unsigned i
= 0; i
< PN
.getNumIncomingValues(); ++i
)
707 if (Blocks
.count(PN
.getIncomingBlock(i
)))
708 IncomingVals
.push_back(i
);
710 // Do not process PHI if there is one (or fewer) predecessor from region.
711 // If PHI has exactly one predecessor from region, only this one incoming
712 // will be replaced on codeRepl block, so it should be safe to skip PHI.
713 if (IncomingVals
.size() <= 1)
716 // Create block for new PHIs and add it to the list of outlined if it
717 // wasn't done before.
719 NewBB
= BasicBlock::Create(ExitBB
->getContext(),
720 ExitBB
->getName() + ".split",
721 ExitBB
->getParent(), ExitBB
);
// Predecessors are copied first because the terminators rewritten below change
// ExitBB's predecessor list.
722 SmallVector
<BasicBlock
*, 4> Preds(pred_begin(ExitBB
),
724 for (BasicBlock
*PredBB
: Preds
)
725 if (Blocks
.count(PredBB
))
726 PredBB
->getTerminator()->replaceUsesOfWith(ExitBB
, NewBB
);
727 BranchInst::Create(ExitBB
, NewBB
);
728 Blocks
.insert(NewBB
);
733 PHINode::Create(PN
.getType(), IncomingVals
.size(),
734 PN
.getName() + ".ce", NewBB
->getFirstNonPHI());
735 for (unsigned i
: IncomingVals
)
736 NewPN
->addIncoming(PN
.getIncomingValue(i
), PN
.getIncomingBlock(i
));
// Remove from the highest index down so the remaining indices stay valid.
737 for (unsigned i
: reverse(IncomingVals
))
738 PN
.removeIncomingValue(i
, false);
739 PN
.addIncoming(NewPN
, NewBB
);
// For every region block that ends in a return, split the block just before
// the ReturnInst (the new tail is named "<block>.ret") and, using DT, register
// the new block as a child of the old one and re-parent the old block's
// previous dominator-tree children under it.
// NOTE(review): the declaration that binds the split result to `New` and the
// guard around the DT update (embedded lines 747, 749, 751, 754) are missing
// from this listing; DT is a pointer member elsewhere in this file, so a null
// check is presumably among them -- confirm upstream.
744 void CodeExtractor::splitReturnBlocks() {
745 for (BasicBlock
*Block
: Blocks
)
746 if (ReturnInst
*RI
= dyn_cast
<ReturnInst
>(Block
->getTerminator())) {
748 Block
->splitBasicBlock(RI
->getIterator(), Block
->getName() + ".ret");
750 // Old dominates New. New node dominates all other nodes dominated
752 DomTreeNode
*OldNode
= DT
->getNode(Block
);
// Children are copied first because changeImmediateDominator() below mutates
// OldNode's child list.
753 SmallVector
<DomTreeNode
*, 8> Children(OldNode
->begin(),
756 DomTreeNode
*NewNode
= DT
->addNewBlock(New
, Block
);
758 for (DomTreeNode
*I
: Children
)
759 DT
->changeImmediateDominator(I
, NewNode
);
764 /// constructFunction - make a function based on inputs and outputs, as follows:
765 /// f(in0, ..., inN, out0, ..., outN)
///
/// Visible steps: pick the return type from NumExitBlocks, build the parameter
/// type list from \p inputs and \p outputs (optionally wrapped in one struct
/// pointer when AggregateArgs is set), create an internal-linkage function
/// named "<oldFunction>.<suffix>", copy selected attributes from
/// \p oldFunction, rewrite in-region uses of the inputs to use the new
/// arguments, and retarget out-of-region branches from \p header to
/// \p newHeader.
///
/// NOTE(review): several parameter lines (embedded lines 768, 772) and many
/// statement lines are missing from this listing, including the declarations
/// of `header`, `M`, `RetTy`, `RewriteVal` and `Idx` and the final return.
766 Function
*CodeExtractor::constructFunction(const ValueSet
&inputs
,
767 const ValueSet
&outputs
,
769 BasicBlock
*newRootNode
,
770 BasicBlock
*newHeader
,
771 Function
*oldFunction
,
773 LLVM_DEBUG(dbgs() << "inputs: " << inputs
.size() << "\n");
774 LLVM_DEBUG(dbgs() << "outputs: " << outputs
.size() << "\n");
776 // This function returns unsigned, outputs will go back by reference.
// Visible cases: one exit block -> void, two -> i1, otherwise i16.
777 switch (NumExitBlocks
) {
779 case 1: RetTy
= Type::getVoidTy(header
->getContext()); break;
780 case 2: RetTy
= Type::getInt1Ty(header
->getContext()); break;
781 default: RetTy
= Type::getInt16Ty(header
->getContext()); break;
784 std::vector
<Type
*> paramTy
;
786 // Add the types of the input values to the function's argument list
787 for (Value
*value
: inputs
) {
788 LLVM_DEBUG(dbgs() << "value used in func: " << *value
<< "\n");
789 paramTy
.push_back(value
->getType());
792 // Add the types of the output values to the function's argument list.
793 for (Value
*output
: outputs
) {
794 LLVM_DEBUG(dbgs() << "instr used in func: " << *output
<< "\n");
796 paramTy
.push_back(output
->getType());
798 paramTy
.push_back(PointerType::getUnqual(output
->getType()));
802 dbgs() << "Function type: " << *RetTy
<< " f(";
803 for (Type
*i
: paramTy
)
804 dbgs() << *i
<< ", ";
808 StructType
*StructTy
;
809 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
810 StructTy
= StructType::get(M
->getContext(), paramTy
);
812 paramTy
.push_back(PointerType::getUnqual(StructTy
));
// The new function is variadic only if both AllowVarArgs is set and the
// original function is variadic.
814 FunctionType
*funcType
=
815 FunctionType::get(RetTy
, paramTy
,
816 AllowVarArgs
&& oldFunction
->isVarArg());
818 std::string SuffixToUse
=
820 ? (header
->getName().empty() ? "extracted" : header
->getName().str())
822 // Create the new function
823 Function
*newFunction
= Function::Create(
824 funcType
, GlobalValue::InternalLinkage
, oldFunction
->getAddressSpace(),
825 oldFunction
->getName() + "." + SuffixToUse
, M
);
826 // If the old function is no-throw, so is the new one.
827 if (oldFunction
->doesNotThrow())
828 newFunction
->setDoesNotThrow();
830 // Inherit the uwtable attribute if we need to.
831 if (oldFunction
->hasUWTable())
832 newFunction
->setHasUWTable();
834 // Inherit all of the target dependent attributes and white-listed
835 // target independent attributes.
836 // (e.g. If the extracted region contains a call to an x86.sse
837 // instruction we need to make sure that the extracted region has the
838 // "target-features" attribute allowing it to be lowered.)
839 // FIXME: This should be changed to check to see if a specific
840 // attribute can not be inherited.
841 for (const auto &Attr
: oldFunction
->getAttributes().getFnAttributes()) {
842 if (Attr
.isStringAttribute()) {
843 if (Attr
.getKindAsString() == "thunk")
846 switch (Attr
.getKindAsEnum()) {
847 // Those attributes cannot be propagated safely. Explicitly list them
848 // here so we get a warning if new attributes are added. This list also
849 // includes non-function attributes.
850 case Attribute::Alignment
:
851 case Attribute::AllocSize
:
852 case Attribute::ArgMemOnly
:
853 case Attribute::Builtin
:
854 case Attribute::ByVal
:
855 case Attribute::Convergent
:
856 case Attribute::Dereferenceable
:
857 case Attribute::DereferenceableOrNull
:
858 case Attribute::InAlloca
:
859 case Attribute::InReg
:
860 case Attribute::InaccessibleMemOnly
:
861 case Attribute::InaccessibleMemOrArgMemOnly
:
862 case Attribute::JumpTable
:
863 case Attribute::Naked
:
864 case Attribute::Nest
:
865 case Attribute::NoAlias
:
866 case Attribute::NoBuiltin
:
867 case Attribute::NoCapture
:
868 case Attribute::NoReturn
:
869 case Attribute::NoSync
:
870 case Attribute::None
:
871 case Attribute::NonNull
:
872 case Attribute::ReadNone
:
873 case Attribute::ReadOnly
:
874 case Attribute::Returned
:
875 case Attribute::ReturnsTwice
:
876 case Attribute::SExt
:
877 case Attribute::Speculatable
:
878 case Attribute::StackAlignment
:
879 case Attribute::StructRet
:
880 case Attribute::SwiftError
:
881 case Attribute::SwiftSelf
:
882 case Attribute::WillReturn
:
883 case Attribute::WriteOnly
:
884 case Attribute::ZExt
:
885 case Attribute::ImmArg
:
886 case Attribute::EndAttrKinds
:
888 // Those attributes should be safe to propagate to the extracted function.
889 case Attribute::AlwaysInline
:
890 case Attribute::Cold
:
891 case Attribute::NoRecurse
:
892 case Attribute::InlineHint
:
893 case Attribute::MinSize
:
894 case Attribute::NoDuplicate
:
895 case Attribute::NoFree
:
896 case Attribute::NoImplicitFloat
:
897 case Attribute::NoInline
:
898 case Attribute::NonLazyBind
:
899 case Attribute::NoRedZone
:
900 case Attribute::NoUnwind
:
901 case Attribute::OptForFuzzing
:
902 case Attribute::OptimizeNone
:
903 case Attribute::OptimizeForSize
:
904 case Attribute::SafeStack
:
905 case Attribute::ShadowCallStack
:
906 case Attribute::SanitizeAddress
:
907 case Attribute::SanitizeMemory
:
908 case Attribute::SanitizeThread
:
909 case Attribute::SanitizeHWAddress
:
910 case Attribute::SanitizeMemTag
:
911 case Attribute::SpeculativeLoadHardening
:
912 case Attribute::StackProtect
:
913 case Attribute::StackProtectReq
:
914 case Attribute::StackProtectStrong
:
915 case Attribute::StrictFP
:
916 case Attribute::UWTable
:
917 case Attribute::NoCfCheck
:
921 newFunction
->addFnAttr(Attr
);
923 newFunction
->getBasicBlockList().push_back(newRootNode
);
925 // Create an iterator to name all of the arguments we inserted.
926 Function::arg_iterator AI
= newFunction
->arg_begin();
928 // Rewrite all users of the inputs in the extracted region to use the
929 // arguments (or appropriate addressing into struct) instead.
930 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
// Aggregate path: index {0, i} selects field i of the argument struct; the GEP
// and load are inserted before the terminator of the new function's first
// block.
934 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(header
->getContext()));
935 Idx
[1] = ConstantInt::get(Type::getInt32Ty(header
->getContext()), i
);
936 Instruction
*TI
= newFunction
->begin()->getTerminator();
937 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
938 StructTy
, &*AI
, Idx
, "gep_" + inputs
[i
]->getName(), TI
);
939 RewriteVal
= new LoadInst(StructTy
->getElementType(i
), GEP
,
940 "loadgep_" + inputs
[i
]->getName(), TI
);
// Users are snapshotted before rewriting because replaceUsesOfWith() edits the
// use list being walked.
944 std::vector
<User
*> Users(inputs
[i
]->user_begin(), inputs
[i
]->user_end());
945 for (User
*use
: Users
)
946 if (Instruction
*inst
= dyn_cast
<Instruction
>(use
))
947 if (Blocks
.count(inst
->getParent()))
948 inst
->replaceUsesOfWith(inputs
[i
], RewriteVal
);
951 // Set names for input and output arguments.
952 if (!AggregateArgs
) {
953 AI
= newFunction
->arg_begin();
954 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
, ++AI
)
955 AI
->setName(inputs
[i
]->getName());
956 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
, ++AI
)
957 AI
->setName(outputs
[i
]->getName()+".out");
960 // Rewrite branches to basic blocks outside of the loop to new dummy blocks
961 // within the new function. This must be done before we lose track of which
962 // blocks were originally in the code region.
963 std::vector
<User
*> Users(header
->user_begin(), header
->user_end());
964 for (auto &U
: Users
)
965 // The BasicBlock which contains the branch is not in the region
966 // modify the branch target to a new block
967 if (Instruction
*I
= dyn_cast
<Instruction
>(U
))
968 if (I
->isTerminator() && I
->getFunction() == oldFunction
&&
969 !Blocks
.count(I
->getParent()))
970 I
->replaceUsesOfWith(header
, newHeader
);
975 /// Erase lifetime.start markers which reference inputs to the extraction
976 /// region, and insert the referenced memory into \p LifetimesStart.
978 /// The extraction region is defined by a set of blocks (\p Blocks), and a set
979 /// of allocas which will be moved from the caller function into the extracted
980 /// function (\p SunkAllocas).
///
/// The marker's memory operand is operand 1 with in-bounds offsets stripped.
/// Both lifetime.start and lifetime.end pass the isLifetimeStartOrEnd()
/// filter, but only lifetime.start operands are added to \p LifetimesStart.
/// NOTE(review): the iterator increment and the `continue` statements
/// (embedded lines 987, 989, 996) are missing from this listing.
981 static void eraseLifetimeMarkersOnInputs(const SetVector
<BasicBlock
*> &Blocks
,
982 const SetVector
<Value
*> &SunkAllocas
,
983 SetVector
<Value
*> &LifetimesStart
) {
984 for (BasicBlock
*BB
: Blocks
) {
// No increment in the loop header: instructions may be erased while iterating.
985 for (auto It
= BB
->begin(), End
= BB
->end(); It
!= End
;) {
986 auto *II
= dyn_cast
<IntrinsicInst
>(&*It
);
988 if (!II
|| !II
->isLifetimeStartOrEnd())
991 // Get the memory operand of the lifetime marker. If the underlying
992 // object is a sunk alloca, or is otherwise defined in the extraction
993 // region, the lifetime marker must not be erased.
994 Value
*Mem
= II
->getOperand(1)->stripInBoundsOffsets();
995 if (SunkAllocas
.count(Mem
) || definedInRegion(Blocks
, Mem
))
998 if (II
->getIntrinsicID() == Intrinsic::lifetime_start
)
999 LifetimesStart
.insert(Mem
);
1000 II
->eraseFromParent();
1005 /// Insert lifetime start/end markers surrounding the call to the new function
1006 /// for objects defined in the caller.
1007 static void insertLifetimeMarkersSurroundingCall(
1008 Module
*M
, ArrayRef
<Value
*> LifetimesStart
, ArrayRef
<Value
*> LifetimesEnd
,
1009 CallInst
*TheCall
) {
1010 LLVMContext
&Ctx
= M
->getContext();
1011 auto Int8PtrTy
= Type::getInt8PtrTy(Ctx
);
1012 auto NegativeOne
= ConstantInt::getSigned(Type::getInt64Ty(Ctx
), -1);
1013 Instruction
*Term
= TheCall
->getParent()->getTerminator();
1015 // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts
1016 // needed to satisfy this requirement so they may be reused.
1017 DenseMap
<Value
*, Value
*> Bitcasts
;
1019 // Emit lifetime markers for the pointers given in \p Objects. Insert the
1020 // markers before the call if \p InsertBefore, and after the call otherwise.
1021 auto insertMarkers
= [&](Function
*MarkerFunc
, ArrayRef
<Value
*> Objects
,
1022 bool InsertBefore
) {
1023 for (Value
*Mem
: Objects
) {
1024 assert((!isa
<Instruction
>(Mem
) || cast
<Instruction
>(Mem
)->getFunction() ==
1025 TheCall
->getFunction()) &&
1026 "Input memory not defined in original function");
1027 Value
*&MemAsI8Ptr
= Bitcasts
[Mem
];
1029 if (Mem
->getType() == Int8PtrTy
)
1033 CastInst::CreatePointerCast(Mem
, Int8PtrTy
, "lt.cast", TheCall
);
1036 auto Marker
= CallInst::Create(MarkerFunc
, {NegativeOne
, MemAsI8Ptr
});
1038 Marker
->insertBefore(TheCall
);
1040 Marker
->insertBefore(Term
);
1044 if (!LifetimesStart
.empty()) {
1045 auto StartFn
= llvm::Intrinsic::getDeclaration(
1046 M
, llvm::Intrinsic::lifetime_start
, Int8PtrTy
);
1047 insertMarkers(StartFn
, LifetimesStart
, /*InsertBefore=*/true);
1050 if (!LifetimesEnd
.empty()) {
1051 auto EndFn
= llvm::Intrinsic::getDeclaration(
1052 M
, llvm::Intrinsic::lifetime_end
, Int8PtrTy
);
1053 insertMarkers(EndFn
, LifetimesEnd
, /*InsertBefore=*/false);
1057 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
1058 /// the call instruction, splitting any PHI nodes in the header block as
1060 CallInst
*CodeExtractor::emitCallAndSwitchStatement(Function
*newFunction
,
1061 BasicBlock
*codeReplacer
,
1063 ValueSet
&outputs
) {
1064 // Emit a call to the new function, passing in: *pointer to struct (if
1065 // aggregating parameters), or plan inputs and allocated memory for outputs
1066 std::vector
<Value
*> params
, StructValues
, ReloadOutputs
, Reloads
;
1068 Module
*M
= newFunction
->getParent();
1069 LLVMContext
&Context
= M
->getContext();
1070 const DataLayout
&DL
= M
->getDataLayout();
1071 CallInst
*call
= nullptr;
1073 // Add inputs as params, or to be filled into the struct
1075 SmallVector
<unsigned, 1> SwiftErrorArgs
;
1076 for (Value
*input
: inputs
) {
1078 StructValues
.push_back(input
);
1080 params
.push_back(input
);
1081 if (input
->isSwiftError())
1082 SwiftErrorArgs
.push_back(ArgNo
);
1087 // Create allocas for the outputs
1088 for (Value
*output
: outputs
) {
1089 if (AggregateArgs
) {
1090 StructValues
.push_back(output
);
1092 AllocaInst
*alloca
=
1093 new AllocaInst(output
->getType(), DL
.getAllocaAddrSpace(),
1094 nullptr, output
->getName() + ".loc",
1095 &codeReplacer
->getParent()->front().front());
1096 ReloadOutputs
.push_back(alloca
);
1097 params
.push_back(alloca
);
1101 StructType
*StructArgTy
= nullptr;
1102 AllocaInst
*Struct
= nullptr;
1103 if (AggregateArgs
&& (inputs
.size() + outputs
.size() > 0)) {
1104 std::vector
<Type
*> ArgTypes
;
1105 for (ValueSet::iterator v
= StructValues
.begin(),
1106 ve
= StructValues
.end(); v
!= ve
; ++v
)
1107 ArgTypes
.push_back((*v
)->getType());
1109 // Allocate a struct at the beginning of this function
1110 StructArgTy
= StructType::get(newFunction
->getContext(), ArgTypes
);
1111 Struct
= new AllocaInst(StructArgTy
, DL
.getAllocaAddrSpace(), nullptr,
1113 &codeReplacer
->getParent()->front().front());
1114 params
.push_back(Struct
);
1116 for (unsigned i
= 0, e
= inputs
.size(); i
!= e
; ++i
) {
1118 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1119 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), i
);
1120 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1121 StructArgTy
, Struct
, Idx
, "gep_" + StructValues
[i
]->getName());
1122 codeReplacer
->getInstList().push_back(GEP
);
1123 StoreInst
*SI
= new StoreInst(StructValues
[i
], GEP
);
1124 codeReplacer
->getInstList().push_back(SI
);
1128 // Emit the call to the function
1129 call
= CallInst::Create(newFunction
, params
,
1130 NumExitBlocks
> 1 ? "targetBlock" : "");
1131 // Add debug location to the new call, if the original function has debug
1132 // info. In that case, the terminator of the entry block of the extracted
1133 // function contains the first debug location of the extracted function,
1134 // set in extractCodeRegion.
1135 if (codeReplacer
->getParent()->getSubprogram()) {
1136 if (auto DL
= newFunction
->getEntryBlock().getTerminator()->getDebugLoc())
1137 call
->setDebugLoc(DL
);
1139 codeReplacer
->getInstList().push_back(call
);
1141 // Set swifterror parameter attributes.
1142 for (unsigned SwiftErrArgNo
: SwiftErrorArgs
) {
1143 call
->addParamAttr(SwiftErrArgNo
, Attribute::SwiftError
);
1144 newFunction
->addParamAttr(SwiftErrArgNo
, Attribute::SwiftError
);
1147 Function::arg_iterator OutputArgBegin
= newFunction
->arg_begin();
1148 unsigned FirstOut
= inputs
.size();
1150 std::advance(OutputArgBegin
, inputs
.size());
1152 // Reload the outputs passed in by reference.
1153 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
) {
1154 Value
*Output
= nullptr;
1155 if (AggregateArgs
) {
1157 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1158 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), FirstOut
+ i
);
1159 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1160 StructArgTy
, Struct
, Idx
, "gep_reload_" + outputs
[i
]->getName());
1161 codeReplacer
->getInstList().push_back(GEP
);
1164 Output
= ReloadOutputs
[i
];
1166 LoadInst
*load
= new LoadInst(outputs
[i
]->getType(), Output
,
1167 outputs
[i
]->getName() + ".reload");
1168 Reloads
.push_back(load
);
1169 codeReplacer
->getInstList().push_back(load
);
1170 std::vector
<User
*> Users(outputs
[i
]->user_begin(), outputs
[i
]->user_end());
1171 for (unsigned u
= 0, e
= Users
.size(); u
!= e
; ++u
) {
1172 Instruction
*inst
= cast
<Instruction
>(Users
[u
]);
1173 if (!Blocks
.count(inst
->getParent()))
1174 inst
->replaceUsesOfWith(outputs
[i
], load
);
1178 // Now we can emit a switch statement using the call as a value.
1179 SwitchInst
*TheSwitch
=
1180 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context
)),
1181 codeReplacer
, 0, codeReplacer
);
1183 // Since there may be multiple exits from the original region, make the new
1184 // function return an unsigned, switch on that number. This loop iterates
1185 // over all of the blocks in the extracted region, updating any terminator
1186 // instructions in the to-be-extracted region that branch to blocks that are
1187 // not in the region to be extracted.
1188 std::map
<BasicBlock
*, BasicBlock
*> ExitBlockMap
;
1190 unsigned switchVal
= 0;
1191 for (BasicBlock
*Block
: Blocks
) {
1192 Instruction
*TI
= Block
->getTerminator();
1193 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
!= e
; ++i
)
1194 if (!Blocks
.count(TI
->getSuccessor(i
))) {
1195 BasicBlock
*OldTarget
= TI
->getSuccessor(i
);
1196 // add a new basic block which returns the appropriate value
1197 BasicBlock
*&NewTarget
= ExitBlockMap
[OldTarget
];
1199 // If we don't already have an exit stub for this non-extracted
1200 // destination, create one now!
1201 NewTarget
= BasicBlock::Create(Context
,
1202 OldTarget
->getName() + ".exitStub",
1204 unsigned SuccNum
= switchVal
++;
1206 Value
*brVal
= nullptr;
1207 switch (NumExitBlocks
) {
1209 case 1: break; // No value needed.
1210 case 2: // Conditional branch, return a bool
1211 brVal
= ConstantInt::get(Type::getInt1Ty(Context
), !SuccNum
);
1214 brVal
= ConstantInt::get(Type::getInt16Ty(Context
), SuccNum
);
1218 ReturnInst::Create(Context
, brVal
, NewTarget
);
1220 // Update the switch instruction.
1221 TheSwitch
->addCase(ConstantInt::get(Type::getInt16Ty(Context
),
1226 // rewrite the original branch instruction with this new target
1227 TI
->setSuccessor(i
, NewTarget
);
1231 // Store the arguments right after the definition of output value.
1232 // This should be proceeded after creating exit stubs to be ensure that invoke
1233 // result restore will be placed in the outlined function.
1234 Function::arg_iterator OAI
= OutputArgBegin
;
1235 for (unsigned i
= 0, e
= outputs
.size(); i
!= e
; ++i
) {
1236 auto *OutI
= dyn_cast
<Instruction
>(outputs
[i
]);
1240 // Find proper insertion point.
1241 BasicBlock::iterator InsertPt
;
1242 // In case OutI is an invoke, we insert the store at the beginning in the
1243 // 'normal destination' BB. Otherwise we insert the store right after OutI.
1244 if (auto *InvokeI
= dyn_cast
<InvokeInst
>(OutI
))
1245 InsertPt
= InvokeI
->getNormalDest()->getFirstInsertionPt();
1246 else if (auto *Phi
= dyn_cast
<PHINode
>(OutI
))
1247 InsertPt
= Phi
->getParent()->getFirstInsertionPt();
1249 InsertPt
= std::next(OutI
->getIterator());
1251 Instruction
*InsertBefore
= &*InsertPt
;
1252 assert((InsertBefore
->getFunction() == newFunction
||
1253 Blocks
.count(InsertBefore
->getParent())) &&
1254 "InsertPt should be in new function");
1255 assert(OAI
!= newFunction
->arg_end() &&
1256 "Number of output arguments should match "
1257 "the amount of defined values");
1258 if (AggregateArgs
) {
1260 Idx
[0] = Constant::getNullValue(Type::getInt32Ty(Context
));
1261 Idx
[1] = ConstantInt::get(Type::getInt32Ty(Context
), FirstOut
+ i
);
1262 GetElementPtrInst
*GEP
= GetElementPtrInst::Create(
1263 StructArgTy
, &*OAI
, Idx
, "gep_" + outputs
[i
]->getName(),
1265 new StoreInst(outputs
[i
], GEP
, InsertBefore
);
1266 // Since there should be only one struct argument aggregating
1267 // all the output values, we shouldn't increment OAI, which always
1268 // points to the struct argument, in this case.
1270 new StoreInst(outputs
[i
], &*OAI
, InsertBefore
);
1275 // Now that we've done the deed, simplify the switch instruction.
1276 Type
*OldFnRetTy
= TheSwitch
->getParent()->getParent()->getReturnType();
1277 switch (NumExitBlocks
) {
1279 // There are no successors (the block containing the switch itself), which
1280 // means that previously this was the last part of the function, and hence
1281 // this should be rewritten as a `ret'
1283 // Check if the function should return a value
1284 if (OldFnRetTy
->isVoidTy()) {
1285 ReturnInst::Create(Context
, nullptr, TheSwitch
); // Return void
1286 } else if (OldFnRetTy
== TheSwitch
->getCondition()->getType()) {
1287 // return what we have
1288 ReturnInst::Create(Context
, TheSwitch
->getCondition(), TheSwitch
);
1290 // Otherwise we must have code extracted an unwind or something, just
1291 // return whatever we want.
1292 ReturnInst::Create(Context
,
1293 Constant::getNullValue(OldFnRetTy
), TheSwitch
);
1296 TheSwitch
->eraseFromParent();
1299 // Only a single destination, change the switch into an unconditional
1301 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
);
1302 TheSwitch
->eraseFromParent();
1305 BranchInst::Create(TheSwitch
->getSuccessor(1), TheSwitch
->getSuccessor(2),
1307 TheSwitch
->eraseFromParent();
1310 // Otherwise, make the default destination of the switch instruction be one
1311 // of the other successors.
1312 TheSwitch
->setCondition(call
);
1313 TheSwitch
->setDefaultDest(TheSwitch
->getSuccessor(NumExitBlocks
));
1314 // Remove redundant case
1315 TheSwitch
->removeCase(SwitchInst::CaseIt(TheSwitch
, NumExitBlocks
-1));
1319 // Insert lifetime markers around the reloads of any output values. The
1320 // allocas output values are stored in are only in-use in the codeRepl block.
1321 insertLifetimeMarkersSurroundingCall(M
, ReloadOutputs
, ReloadOutputs
, call
);
1326 void CodeExtractor::moveCodeToFunction(Function
*newFunction
) {
1327 Function
*oldFunc
= (*Blocks
.begin())->getParent();
1328 Function::BasicBlockListType
&oldBlocks
= oldFunc
->getBasicBlockList();
1329 Function::BasicBlockListType
&newBlocks
= newFunction
->getBasicBlockList();
1331 for (BasicBlock
*Block
: Blocks
) {
1332 // Delete the basic block from the old function, and the list of blocks
1333 oldBlocks
.remove(Block
);
1335 // Insert this basic block into the new function
1336 newBlocks
.push_back(Block
);
1340 void CodeExtractor::calculateNewCallTerminatorWeights(
1341 BasicBlock
*CodeReplacer
,
1342 DenseMap
<BasicBlock
*, BlockFrequency
> &ExitWeights
,
1343 BranchProbabilityInfo
*BPI
) {
1344 using Distribution
= BlockFrequencyInfoImplBase::Distribution
;
1345 using BlockNode
= BlockFrequencyInfoImplBase::BlockNode
;
1347 // Update the branch weights for the exit block.
1348 Instruction
*TI
= CodeReplacer
->getTerminator();
1349 SmallVector
<unsigned, 8> BranchWeights(TI
->getNumSuccessors(), 0);
1351 // Block Frequency distribution with dummy node.
1352 Distribution BranchDist
;
1354 // Add each of the frequencies of the successors.
1355 for (unsigned i
= 0, e
= TI
->getNumSuccessors(); i
< e
; ++i
) {
1356 BlockNode
ExitNode(i
);
1357 uint64_t ExitFreq
= ExitWeights
[TI
->getSuccessor(i
)].getFrequency();
1359 BranchDist
.addExit(ExitNode
, ExitFreq
);
1361 BPI
->setEdgeProbability(CodeReplacer
, i
, BranchProbability::getZero());
1364 // Check for no total weight.
1365 if (BranchDist
.Total
== 0)
1368 // Normalize the distribution so that they can fit in unsigned.
1369 BranchDist
.normalize();
1371 // Create normalized branch weights and set the metadata.
1372 for (unsigned I
= 0, E
= BranchDist
.Weights
.size(); I
< E
; ++I
) {
1373 const auto &Weight
= BranchDist
.Weights
[I
];
1375 // Get the weight and update the current BFI.
1376 BranchWeights
[Weight
.TargetNode
.Index
] = Weight
.Amount
;
1377 BranchProbability
BP(Weight
.Amount
, BranchDist
.Total
);
1378 BPI
->setEdgeProbability(CodeReplacer
, Weight
.TargetNode
.Index
, BP
);
1381 LLVMContext::MD_prof
,
1382 MDBuilder(TI
->getContext()).createBranchWeights(BranchWeights
));
1386 CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache
&CEAC
) {
1390 // Assumption: this is a single-entry code region, and the header is the first
1391 // block in the region.
1392 BasicBlock
*header
= *Blocks
.begin();
1393 Function
*oldFunction
= header
->getParent();
1395 // Calculate the entry frequency of the new function before we change the root
1397 BlockFrequency EntryFreq
;
1399 assert(BPI
&& "Both BPI and BFI are required to preserve profile info");
1400 for (BasicBlock
*Pred
: predecessors(header
)) {
1401 if (Blocks
.count(Pred
))
1404 BFI
->getBlockFreq(Pred
) * BPI
->getEdgeProbability(Pred
, header
);
1409 // Remove @llvm.assume calls that were moved to the new function from the
1410 // old function's assumption cache.
1411 for (BasicBlock
*Block
: Blocks
)
1412 for (auto &I
: *Block
)
1413 if (match(&I
, m_Intrinsic
<Intrinsic::assume
>()))
1414 AC
->unregisterAssumption(cast
<CallInst
>(&I
));
1417 // If we have any return instructions in the region, split those blocks so
1418 // that the return is not in the region.
1419 splitReturnBlocks();
1421 // Calculate the exit blocks for the extracted region and the total exit
1422 // weights for each of those blocks.
1423 DenseMap
<BasicBlock
*, BlockFrequency
> ExitWeights
;
1424 SmallPtrSet
<BasicBlock
*, 1> ExitBlocks
;
1425 for (BasicBlock
*Block
: Blocks
) {
1426 for (succ_iterator SI
= succ_begin(Block
), SE
= succ_end(Block
); SI
!= SE
;
1428 if (!Blocks
.count(*SI
)) {
1429 // Update the branch weight for this successor.
1431 BlockFrequency
&BF
= ExitWeights
[*SI
];
1432 BF
+= BFI
->getBlockFreq(Block
) * BPI
->getEdgeProbability(Block
, *SI
);
1434 ExitBlocks
.insert(*SI
);
1438 NumExitBlocks
= ExitBlocks
.size();
1440 // If we have to split PHI nodes of the entry or exit blocks, do so now.
1441 severSplitPHINodesOfEntry(header
);
1442 severSplitPHINodesOfExits(ExitBlocks
);
1444 // This takes place of the original loop
1445 BasicBlock
*codeReplacer
= BasicBlock::Create(header
->getContext(),
1446 "codeRepl", oldFunction
,
1449 // The new function needs a root node because other nodes can branch to the
1450 // head of the region, but the entry node of a function cannot have preds.
1451 BasicBlock
*newFuncRoot
= BasicBlock::Create(header
->getContext(),
1453 auto *BranchI
= BranchInst::Create(header
);
1454 // If the original function has debug info, we have to add a debug location
1455 // to the new branch instruction from the artificial entry block.
1456 // We use the debug location of the first instruction in the extracted
1457 // blocks, as there is no other equivalent line in the source code.
1458 if (oldFunction
->getSubprogram()) {
1459 any_of(Blocks
, [&BranchI
](const BasicBlock
*BB
) {
1460 return any_of(*BB
, [&BranchI
](const Instruction
&I
) {
1461 if (!I
.getDebugLoc())
1463 BranchI
->setDebugLoc(I
.getDebugLoc());
1468 newFuncRoot
->getInstList().push_back(BranchI
);
1470 ValueSet inputs
, outputs
, SinkingCands
, HoistingCands
;
1471 BasicBlock
*CommonExit
= nullptr;
1472 findAllocas(CEAC
, SinkingCands
, HoistingCands
, CommonExit
);
1473 assert(HoistingCands
.empty() || CommonExit
);
1475 // Find inputs to, outputs from the code region.
1476 findInputsOutputs(inputs
, outputs
, SinkingCands
);
1478 // Now sink all instructions which only have non-phi uses inside the region.
1479 // Group the allocas at the start of the block, so that any bitcast uses of
1480 // the allocas are well-defined.
1481 AllocaInst
*FirstSunkAlloca
= nullptr;
1482 for (auto *II
: SinkingCands
) {
1483 if (auto *AI
= dyn_cast
<AllocaInst
>(II
)) {
1484 AI
->moveBefore(*newFuncRoot
, newFuncRoot
->getFirstInsertionPt());
1485 if (!FirstSunkAlloca
)
1486 FirstSunkAlloca
= AI
;
1489 assert((SinkingCands
.empty() || FirstSunkAlloca
) &&
1490 "Did not expect a sink candidate without any allocas");
1491 for (auto *II
: SinkingCands
) {
1492 if (!isa
<AllocaInst
>(II
)) {
1493 cast
<Instruction
>(II
)->moveAfter(FirstSunkAlloca
);
1497 if (!HoistingCands
.empty()) {
1498 auto *HoistToBlock
= findOrCreateBlockForHoisting(CommonExit
);
1499 Instruction
*TI
= HoistToBlock
->getTerminator();
1500 for (auto *II
: HoistingCands
)
1501 cast
<Instruction
>(II
)->moveBefore(TI
);
1504 // Collect objects which are inputs to the extraction region and also
1505 // referenced by lifetime start markers within it. The effects of these
1506 // markers must be replicated in the calling function to prevent the stack
1507 // coloring pass from merging slots which store input objects.
1508 ValueSet LifetimesStart
;
1509 eraseLifetimeMarkersOnInputs(Blocks
, SinkingCands
, LifetimesStart
);
1511 // Construct new function based on inputs/outputs & add allocas for all defs.
1512 Function
*newFunction
=
1513 constructFunction(inputs
, outputs
, header
, newFuncRoot
, codeReplacer
,
1514 oldFunction
, oldFunction
->getParent());
1516 // Update the entry count of the function.
1518 auto Count
= BFI
->getProfileCountFromFreq(EntryFreq
.getFrequency());
1519 if (Count
.hasValue())
1520 newFunction
->setEntryCount(
1521 ProfileCount(Count
.getValue(), Function::PCT_Real
)); // FIXME
1522 BFI
->setBlockFreq(codeReplacer
, EntryFreq
.getFrequency());
1526 emitCallAndSwitchStatement(newFunction
, codeReplacer
, inputs
, outputs
);
1528 moveCodeToFunction(newFunction
);
1530 // Replicate the effects of any lifetime start/end markers which referenced
1531 // input objects in the extraction region by placing markers around the call.
1532 insertLifetimeMarkersSurroundingCall(
1533 oldFunction
->getParent(), LifetimesStart
.getArrayRef(), {}, TheCall
);
1535 // Propagate personality info to the new function if there is one.
1536 if (oldFunction
->hasPersonalityFn())
1537 newFunction
->setPersonalityFn(oldFunction
->getPersonalityFn());
1539 // Update the branch weights for the exit block.
1540 if (BFI
&& NumExitBlocks
> 1)
1541 calculateNewCallTerminatorWeights(codeReplacer
, ExitWeights
, BPI
);
1543 // Loop over all of the PHI nodes in the header and exit blocks, and change
1544 // any references to the old incoming edge to be the new incoming edge.
1545 for (BasicBlock::iterator I
= header
->begin(); isa
<PHINode
>(I
); ++I
) {
1546 PHINode
*PN
= cast
<PHINode
>(I
);
1547 for (unsigned i
= 0, e
= PN
->getNumIncomingValues(); i
!= e
; ++i
)
1548 if (!Blocks
.count(PN
->getIncomingBlock(i
)))
1549 PN
->setIncomingBlock(i
, newFuncRoot
);
1552 for (BasicBlock
*ExitBB
: ExitBlocks
)
1553 for (PHINode
&PN
: ExitBB
->phis()) {
1554 Value
*IncomingCodeReplacerVal
= nullptr;
1555 for (unsigned i
= 0, e
= PN
.getNumIncomingValues(); i
!= e
; ++i
) {
1556 // Ignore incoming values from outside of the extracted region.
1557 if (!Blocks
.count(PN
.getIncomingBlock(i
)))
1560 // Ensure that there is only one incoming value from codeReplacer.
1561 if (!IncomingCodeReplacerVal
) {
1562 PN
.setIncomingBlock(i
, codeReplacer
);
1563 IncomingCodeReplacerVal
= PN
.getIncomingValue(i
);
1565 assert(IncomingCodeReplacerVal
== PN
.getIncomingValue(i
) &&
1566 "PHI has two incompatbile incoming values from codeRepl");
1570 // Erase debug info intrinsics. Variable updates within the new function are
1571 // invisible to debuggers. This could be improved by defining a DISubprogram
1572 // for the new function.
1573 for (BasicBlock
&BB
: *newFunction
) {
1574 auto BlockIt
= BB
.begin();
1575 // Remove debug info intrinsics from the new function.
1576 while (BlockIt
!= BB
.end()) {
1577 Instruction
*Inst
= &*BlockIt
;
1579 if (isa
<DbgInfoIntrinsic
>(Inst
))
1580 Inst
->eraseFromParent();
1582 // Remove debug info intrinsics which refer to values in the new function
1583 // from the old function.
1584 SmallVector
<DbgVariableIntrinsic
*, 4> DbgUsers
;
1585 for (Instruction
&I
: BB
)
1586 findDbgUsers(DbgUsers
, &I
);
1587 for (DbgVariableIntrinsic
*DVI
: DbgUsers
)
1588 DVI
->eraseFromParent();
1591 // Mark the new function `noreturn` if applicable. Terminators which resume
1592 // exception propagation are treated as returning instructions. This is to
1593 // avoid inserting traps after calls to outlined functions which unwind.
1594 bool doesNotReturn
= none_of(*newFunction
, [](const BasicBlock
&BB
) {
1595 const Instruction
*Term
= BB
.getTerminator();
1596 return isa
<ReturnInst
>(Term
) || isa
<ResumeInst
>(Term
);
1599 newFunction
->setDoesNotReturn();
1601 LLVM_DEBUG(if (verifyFunction(*newFunction
, &errs())) {
1602 newFunction
->dump();
1603 report_fatal_error("verification of newFunction failed!");
1605 LLVM_DEBUG(if (verifyFunction(*oldFunction
))
1606 report_fatal_error("verification of oldFunction failed!"));
1607 LLVM_DEBUG(if (AC
&& verifyAssumptionCache(*oldFunction
, AC
))
1608 report_fatal_error("Stale Asumption cache for old Function!"));
1612 bool CodeExtractor::verifyAssumptionCache(const Function
& F
,
1613 AssumptionCache
*AC
) {
1614 for (auto AssumeVH
: AC
->assumptions()) {
1615 CallInst
*I
= cast
<CallInst
>(AssumeVH
);
1616 if (I
->getFunction() != &F
)