Remove the default clause from a fully-covering switch
[llvm-core.git] / lib / Transforms / Utils / CodeExtractor.cpp
blob1189714dfab10be5c18f89309cf6971aeccf8a04
1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the interface to tear out a code region, such as an
11 // individual loop or a parallel section, into a new function, replacing it with
12 // a call to the new function.
14 //===----------------------------------------------------------------------===//
16 #include "llvm/Transforms/Utils/CodeExtractor.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Analysis/BlockFrequencyInfo.h"
21 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
22 #include "llvm/Analysis/BranchProbabilityInfo.h"
23 #include "llvm/Analysis/LoopInfo.h"
24 #include "llvm/Analysis/RegionInfo.h"
25 #include "llvm/Analysis/RegionIterator.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Intrinsics.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/MDBuilder.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/Verifier.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/BlockFrequency.h"
38 #include "llvm/Support/CommandLine.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/ErrorHandling.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include <algorithm>
44 #include <set>
45 using namespace llvm;
47 #define DEBUG_TYPE "code-extractor"
49 // Provide a command-line option to aggregate function arguments into a struct
50 // for functions produced by the code extractor. This is useful when converting
51 // extracted functions to pthread-based code, as only one argument (void*) can
52 // be passed in to pthread_create().
//
// Both CodeExtractor constructors in this file OR this flag with their own
// AggregateArgs parameter, so setting it forces aggregation for every
// extractor regardless of what the caller requested.
53 static cl::opt<bool>
54 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
55 cl::desc("Aggregate arguments to code-extracted functions"));
57 /// \brief Test whether a block is valid for extraction.
58 bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) {
59 // Landing pads must be in the function where they were inserted for cleanup.
60 if (BB.isEHPad())
61 return false;
62 // taking the address of a basic block moved to another function is illegal
63 if (BB.hasAddressTaken())
64 return false;
66 // don't hoist code that uses another basicblock address, as it's likely to
67 // lead to unexpected behavior, like cross-function jumps
68 SmallPtrSet<User const *, 16> Visited;
69 SmallVector<User const *, 16> ToVisit;
71 for (Instruction const &Inst : BB)
72 ToVisit.push_back(&Inst);
74 while (!ToVisit.empty()) {
75 User const *Curr = ToVisit.pop_back_val();
76 if (!Visited.insert(Curr).second)
77 continue;
78 if (isa<BlockAddress const>(Curr))
79 return false; // even a reference to self is likely to be not compatible
81 if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
82 continue;
84 for (auto const &U : Curr->operands()) {
85 if (auto *UU = dyn_cast<User>(U))
86 ToVisit.push_back(UU);
90 // Don't hoist code containing allocas, invokes, or vastarts.
91 for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
92 if (isa<AllocaInst>(I) || isa<InvokeInst>(I))
93 return false;
94 if (const CallInst *CI = dyn_cast<CallInst>(I))
95 if (const Function *F = CI->getCalledFunction())
96 if (F->getIntrinsicID() == Intrinsic::vastart)
97 return false;
100 return true;
103 /// \brief Build a set of blocks to extract if the input blocks are viable.
104 static SetVector<BasicBlock *>
105 buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) {
106 assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
107 SetVector<BasicBlock *> Result;
109 // Loop over the blocks, adding them to our set-vector, and aborting with an
110 // empty set if we encounter invalid blocks.
111 for (BasicBlock *BB : BBs) {
113 // If this block is dead, don't process it.
114 if (DT && !DT->isReachableFromEntry(BB))
115 continue;
117 if (!Result.insert(BB))
118 llvm_unreachable("Repeated basic blocks in extraction input");
119 if (!CodeExtractor::isBlockValidForExtraction(*BB)) {
120 Result.clear();
121 return Result;
125 #ifndef NDEBUG
126 for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()),
127 E = Result.end();
128 I != E; ++I)
129 for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I);
130 PI != PE; ++PI)
131 assert(Result.count(*PI) &&
132 "No blocks in this region may have entries from outside the region"
133 " except for the first block!");
134 #endif
136 return Result;
139 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
140 bool AggregateArgs, BlockFrequencyInfo *BFI,
141 BranchProbabilityInfo *BPI)
142 : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
143 BPI(BPI), Blocks(buildExtractionBlockSet(BBs, DT)), NumExitBlocks(~0U) {}
145 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
146 BlockFrequencyInfo *BFI,
147 BranchProbabilityInfo *BPI)
148 : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
149 BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT)),
150 NumExitBlocks(~0U) {}
152 /// definedInRegion - Return true if the specified value is defined in the
153 /// extracted region.
154 static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
155 if (Instruction *I = dyn_cast<Instruction>(V))
156 if (Blocks.count(I->getParent()))
157 return true;
158 return false;
161 /// definedInCaller - Return true if the specified value is defined in the
162 /// function being code extracted, but not in the region being extracted.
163 /// These values must be passed in as live-ins to the function.
164 static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
165 if (isa<Argument>(V)) return true;
166 if (Instruction *I = dyn_cast<Instruction>(V))
167 if (!Blocks.count(I->getParent()))
168 return true;
169 return false;
172 static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
173 BasicBlock *CommonExitBlock = nullptr;
174 auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
175 for (auto *Succ : successors(Block)) {
176 // Internal edges, ok.
177 if (Blocks.count(Succ))
178 continue;
179 if (!CommonExitBlock) {
180 CommonExitBlock = Succ;
181 continue;
183 if (CommonExitBlock == Succ)
184 continue;
186 return true;
188 return false;
191 if (any_of(Blocks, hasNonCommonExitSucc))
192 return nullptr;
194 return CommonExitBlock;
197 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
198 Instruction *Addr) const {
199 AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
200 Function *Func = (*Blocks.begin())->getParent();
201 for (BasicBlock &BB : *Func) {
202 if (Blocks.count(&BB))
203 continue;
204 for (Instruction &II : BB) {
206 if (isa<DbgInfoIntrinsic>(II))
207 continue;
209 unsigned Opcode = II.getOpcode();
210 Value *MemAddr = nullptr;
211 switch (Opcode) {
212 case Instruction::Store:
213 case Instruction::Load: {
214 if (Opcode == Instruction::Store) {
215 StoreInst *SI = cast<StoreInst>(&II);
216 MemAddr = SI->getPointerOperand();
217 } else {
218 LoadInst *LI = cast<LoadInst>(&II);
219 MemAddr = LI->getPointerOperand();
221 // Global variable can not be aliased with locals.
222 if (dyn_cast<Constant>(MemAddr))
223 break;
224 Value *Base = MemAddr->stripInBoundsConstantOffsets();
225 if (!dyn_cast<AllocaInst>(Base) || Base == AI)
226 return false;
227 break;
229 default: {
230 IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
231 if (IntrInst) {
232 if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start ||
233 IntrInst->getIntrinsicID() == Intrinsic::lifetime_end)
234 break;
235 return false;
237 // Treat all the other cases conservatively if it has side effects.
238 if (II.mayHaveSideEffects())
239 return false;
245 return true;
248 BasicBlock *
249 CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
250 BasicBlock *SinglePredFromOutlineRegion = nullptr;
251 assert(!Blocks.count(CommonExitBlock) &&
252 "Expect a block outside the region!");
253 for (auto *Pred : predecessors(CommonExitBlock)) {
254 if (!Blocks.count(Pred))
255 continue;
256 if (!SinglePredFromOutlineRegion) {
257 SinglePredFromOutlineRegion = Pred;
258 } else if (SinglePredFromOutlineRegion != Pred) {
259 SinglePredFromOutlineRegion = nullptr;
260 break;
264 if (SinglePredFromOutlineRegion)
265 return SinglePredFromOutlineRegion;
267 #ifndef NDEBUG
268 auto getFirstPHI = [](BasicBlock *BB) {
269 BasicBlock::iterator I = BB->begin();
270 PHINode *FirstPhi = nullptr;
271 while (I != BB->end()) {
272 PHINode *Phi = dyn_cast<PHINode>(I);
273 if (!Phi)
274 break;
275 if (!FirstPhi) {
276 FirstPhi = Phi;
277 break;
280 return FirstPhi;
282 // If there are any phi nodes, the single pred either exists or has already
283 // be created before code extraction.
284 assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
285 #endif
287 BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
288 CommonExitBlock->getFirstNonPHI()->getIterator());
290 for (auto *Pred : predecessors(CommonExitBlock)) {
291 if (Blocks.count(Pred))
292 continue;
293 Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
295 // Now add the old exit block to the outline region.
296 Blocks.insert(CommonExitBlock);
297 return CommonExitBlock;
/// Look at every alloca located outside the extraction region and collect the
/// values that qualify for moving into the region together with it.
///
/// \param SinkCands  receives each qualifying alloca, the cast of it that
///                   carries the lifetime markers when that cast is outside
///                   the region, and the lifetime.start marker when that
///                   marker is outside the region.
/// \param HoistCands receives the lifetime.end marker when that marker is
///                   outside the region.
/// \param ExitBlock  set to getCommonExitBlock(Blocks); this is null when the
///                   region's leaving edges do not all target one block.
300 void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
301 BasicBlock *&ExitBlock) const {
302 Function *Func = (*Blocks.begin())->getParent();
303 ExitBlock = getCommonExitBlock(Blocks);
// Only blocks that are not part of the region are scanned for allocas.
305 for (BasicBlock &BB : *Func) {
306 if (Blocks.count(&BB))
307 continue;
308 for (Instruction &II : BB) {
309 auto *AI = dyn_cast<AllocaInst>(&II);
310 if (!AI)
311 continue;
313 // Find the pair of life time markers for address 'Addr' that are either
314 // defined inside the outline region or can legally be shrinkwrapped into
315 // the outline region. If there are not other untracked uses of the
316 // address, return the pair of markers if found; otherwise return a pair
317 // of nullptr.
// SinkLifeStart / HoistLifeEnd are only written on the path that gets past
// the "both markers found" check.
318 auto GetLifeTimeMarkers =
319 [&](Instruction *Addr, bool &SinkLifeStart,
320 bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
321 Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
323 for (User *U : Addr->users()) {
324 IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
325 if (IntrInst) {
326 if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
327 // Do not handle the case where AI has multiple start markers.
328 if (LifeStart)
329 return std::make_pair<Instruction *>(nullptr, nullptr);
330 LifeStart = IntrInst;
// Likewise, a second lifetime.end marker makes the address unsupported.
332 if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
333 if (LifeEnd)
334 return std::make_pair<Instruction *>(nullptr, nullptr);
335 LifeEnd = IntrInst;
337 continue;
339 // Find untracked uses of the address, bail.
340 if (!definedInRegion(Blocks, U))
341 return std::make_pair<Instruction *>(nullptr, nullptr);
// Both a start and an end marker are required.
344 if (!LifeStart || !LifeEnd)
345 return std::make_pair<Instruction *>(nullptr, nullptr);
// A marker needs moving only when it currently sits outside the region.
347 SinkLifeStart = !definedInRegion(Blocks, LifeStart);
348 HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
349 // Do legality Check.
350 if ((SinkLifeStart || HoistLifeEnd) &&
351 !isLegalToShrinkwrapLifetimeMarkers(Addr))
352 return std::make_pair<Instruction *>(nullptr, nullptr);
354 // Check to see if we have a place to do hoisting, if not, bail.
355 if (HoistLifeEnd && !ExitBlock)
356 return std::make_pair<Instruction *>(nullptr, nullptr);
358 return std::make_pair(LifeStart, LifeEnd);
// First try markers attached directly to the alloca.
361 bool SinkLifeStart = false, HoistLifeEnd = false;
362 auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
364 if (Markers.first) {
365 if (SinkLifeStart)
366 SinkCands.insert(Markers.first);
367 SinkCands.insert(AI);
368 if (HoistLifeEnd)
369 HoistCands.insert(Markers.second);
370 continue;
373 // Follow the bitcast.
// Otherwise look for markers on a user of the alloca that strips back to the
// alloca itself; MarkerAddr remembers the user whose markers were accepted.
374 Instruction *MarkerAddr = nullptr;
375 for (User *U : AI->users()) {
377 if (U->stripInBoundsConstantOffsets() == AI) {
378 SinkLifeStart = false;
379 HoistLifeEnd = false;
380 Instruction *Bitcast = cast<Instruction>(U);
381 Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
382 if (Markers.first) {
383 MarkerAddr = Bitcast;
384 continue;
388 // Found unknown use of AI.
389 if (!definedInRegion(Blocks, U)) {
390 MarkerAddr = nullptr;
391 break;
// NOTE(review): Markers and the two flags hold whatever the most recent
// GetLifeTimeMarkers call left in them, which is not necessarily the call
// made for MarkerAddr - confirm this is intended.
395 if (MarkerAddr) {
396 if (SinkLifeStart)
397 SinkCands.insert(Markers.first);
398 if (!definedInRegion(Blocks, MarkerAddr))
399 SinkCands.insert(MarkerAddr);
400 SinkCands.insert(AI);
401 if (HoistLifeEnd)
402 HoistCands.insert(Markers.second);
408 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
409 const ValueSet &SinkCands) const {
411 for (BasicBlock *BB : Blocks) {
412 // If a used value is defined outside the region, it's an input. If an
413 // instruction is used outside the region, it's an output.
414 for (Instruction &II : *BB) {
415 for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
416 ++OI) {
417 Value *V = *OI;
418 if (!SinkCands.count(V) && definedInCaller(Blocks, V))
419 Inputs.insert(V);
422 for (User *U : II.users())
423 if (!definedInRegion(Blocks, U)) {
424 Outputs.insert(&II);
425 break;
431 /// severSplitPHINodes - If a PHI node has multiple inputs from outside of the
432 /// region, we need to split the entry block of the region so that the PHI node
433 /// is easier to deal with.
///
/// \param Header in/out: the region's entry block. When a split is performed
///               it is replaced by the new block that holds the non-PHI part,
///               and the Blocks set is updated to match.
434 void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) {
435 unsigned NumPredsFromRegion = 0;
436 unsigned NumPredsOutsideRegion = 0;
// NOTE(review): the early-outs below appear to apply only when Header is not
// the function's entry block; for the entry block the split further down
// seems to run unconditionally - confirm against the unmangled file.
438 if (Header != &Header->getParent()->getEntryBlock()) {
439 PHINode *PN = dyn_cast<PHINode>(Header->begin());
440 if (!PN) return; // No PHI nodes.
442 // If the header node contains any PHI nodes, check to see if there is more
443 // than one entry from outside the region. If so, we need to sever the
444 // header block into two.
445 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
446 if (Blocks.count(PN->getIncomingBlock(i)))
447 ++NumPredsFromRegion;
448 else
449 ++NumPredsOutsideRegion;
451 // If there is one (or fewer) predecessor from outside the region, we don't
452 // need to do anything special.
453 if (NumPredsOutsideRegion <= 1) return;
456 // Otherwise, we need to split the header block into two pieces: one
457 // containing PHI nodes merging values from outside of the region, and a
458 // second that contains all of the code for the block and merges back any
459 // incoming values from inside of the region.
460 BasicBlock *NewBB = llvm::SplitBlock(Header, Header->getFirstNonPHI(), DT);
462 // We only want to code extract the second block now, and it becomes the new
463 // header of the region.
464 BasicBlock *OldPred = Header;
465 Blocks.remove(OldPred);
466 Blocks.insert(NewBB);
467 Header = NewBB;
469 // Okay, now we need to adjust the PHI nodes and any branches from within the
470 // region to go to the new header block instead of the old header block.
471 if (NumPredsFromRegion) {
// The first PHI of OldPred is used only to enumerate OldPred's incoming blocks.
472 PHINode *PN = cast<PHINode>(OldPred->begin());
473 // Loop over all of the predecessors of OldPred that are in the region,
474 // changing them to branch to NewBB instead.
475 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
476 if (Blocks.count(PN->getIncomingBlock(i))) {
477 TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator();
478 TI->replaceUsesOfWith(OldPred, NewBB);
481 // Okay, everything within the region is now branching to the right block, we
482 // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
483 BasicBlock::iterator AfterPHIs;
484 for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
485 PHINode *PN = cast<PHINode>(AfterPHIs);
486 // Create a new PHI node in the new region, which has an incoming value
487 // from OldPred of PN.
488 PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
489 PN->getName() + ".ce", &NewBB->front());
// All users are redirected to NewPN first; NewPN then takes PN itself as its
// incoming value along the OldPred edge.
490 PN->replaceAllUsesWith(NewPN);
491 NewPN->addIncoming(PN, OldPred);
493 // Loop over all of the incoming value in PN, moving them to NewPN if they
494 // are from the extracted region.
// The bound is re-read every iteration because entries are removed inside the
// loop; the --i presumably re-examines the slot that removeIncomingValue just
// refilled - verify against the PHINode documentation.
495 for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
496 if (Blocks.count(PN->getIncomingBlock(i))) {
497 NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
498 PN->removeIncomingValue(i);
499 --i;
506 void CodeExtractor::splitReturnBlocks() {
507 for (BasicBlock *Block : Blocks)
508 if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
509 BasicBlock *New =
510 Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
511 if (DT) {
512 // Old dominates New. New node dominates all other nodes dominated
513 // by Old.
514 DomTreeNode *OldNode = DT->getNode(Block);
515 SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
516 OldNode->end());
518 DomTreeNode *NewNode = DT->addNewBlock(New, Block);
520 for (DomTreeNode *I : Children)
521 DT->changeImmediateDominator(I, NewNode);
526 /// constructFunction - make a function based on inputs and outputs, as follows:
527 /// f(in0, ..., inN, out0, ..., outN)
///
/// When arguments are aggregated and there is at least one input or output,
/// the signature is instead a single pointer to a struct holding all of them.
///
/// \param inputs      values defined outside the region and used inside it.
/// \param outputs     region values that are used outside the region.
/// \param header      entry block of the region; its name is part of the new
///                    function's name.
/// \param newRootNode block appended to the new function's block list here.
/// \param newHeader   block that terminators outside the region are pointed
///                    at in place of \p header.
/// \param oldFunction function being extracted from; source of the name
///                    prefix and of the inherited attributes.
/// \param M           module that receives the new function.
/// \returns the newly created function (internal linkage).
529 Function *CodeExtractor::constructFunction(const ValueSet &inputs,
530 const ValueSet &outputs,
531 BasicBlock *header,
532 BasicBlock *newRootNode,
533 BasicBlock *newHeader,
534 Function *oldFunction,
535 Module *M) {
536 DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
537 DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
539 // This function returns unsigned, outputs will go back by reference.
// Return type by exit count: void for 0 or 1, i1 for 2, i16 otherwise.
540 switch (NumExitBlocks) {
541 case 0:
542 case 1: RetTy = Type::getVoidTy(header->getContext()); break;
543 case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
544 default: RetTy = Type::getInt16Ty(header->getContext()); break;
547 std::vector<Type*> paramTy;
549 // Add the types of the input values to the function's argument list
550 for (Value *value : inputs) {
551 DEBUG(dbgs() << "value used in func: " << *value << "\n");
552 paramTy.push_back(value->getType());
555 // Add the types of the output values to the function's argument list.
// Outputs are pointer-typed parameters unless they become struct fields.
556 for (Value *output : outputs) {
557 DEBUG(dbgs() << "instr used in func: " << *output << "\n");
558 if (AggregateArgs)
559 paramTy.push_back(output->getType());
560 else
561 paramTy.push_back(PointerType::getUnqual(output->getType()));
564 DEBUG({
565 dbgs() << "Function type: " << *RetTy << " f(";
566 for (Type *i : paramTy)
567 dbgs() << *i << ", ";
568 dbgs() << ")\n";
// StructTy is assigned only on the aggregated path below; the one later read
// is guarded by AggregateArgs inside the loop over inputs.
571 StructType *StructTy;
572 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
573 StructTy = StructType::get(M->getContext(), paramTy);
574 paramTy.clear();
575 paramTy.push_back(PointerType::getUnqual(StructTy));
577 FunctionType *funcType =
578 FunctionType::get(RetTy, paramTy, false);
580 // Create the new function
// It is named "<old function name>_<header name>".
581 Function *newFunction = Function::Create(funcType,
582 GlobalValue::InternalLinkage,
583 oldFunction->getName() + "_" +
584 header->getName(), M);
585 // If the old function is no-throw, so is the new one.
586 if (oldFunction->doesNotThrow())
587 newFunction->setDoesNotThrow();
589 // Inherit the uwtable attribute if we need to.
590 if (oldFunction->hasUWTable())
591 newFunction->setHasUWTable();
593 // Inherit all of the target dependent attributes.
594 // (e.g. If the extracted region contains a call to an x86.sse
595 // instruction we need to make sure that the extracted region has the
596 // "target-features" attribute allowing it to be lowered.
597 // FIXME: This should be changed to check to see if a specific
598 // attribute can not be inherited.
599 AttrBuilder AB(oldFunction->getAttributes().getFnAttributes());
600 for (const auto &Attr : AB.td_attrs())
601 newFunction->addFnAttr(Attr.first, Attr.second);
603 newFunction->getBasicBlockList().push_back(newRootNode);
605 // Create an iterator to name all of the arguments we inserted.
606 Function::arg_iterator AI = newFunction->arg_begin();
608 // Rewrite all users of the inputs in the extracted region to use the
609 // arguments (or appropriate addressing into struct) instead.
610 for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
611 Value *RewriteVal;
612 if (AggregateArgs) {
// Aggregated: address field i of the struct argument and load it; both
// instructions are inserted before the first block's terminator.
613 Value *Idx[2];
614 Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
615 Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
616 TerminatorInst *TI = newFunction->begin()->getTerminator();
617 GetElementPtrInst *GEP = GetElementPtrInst::Create(
618 StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
619 RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI);
620 } else
621 RewriteVal = &*AI++;
// Work on a copy of the user list because the loop below rewrites uses.
623 std::vector<User*> Users(inputs[i]->user_begin(), inputs[i]->user_end());
624 for (User *use : Users)
625 if (Instruction *inst = dyn_cast<Instruction>(use))
626 if (Blocks.count(inst->getParent()))
627 inst->replaceUsesOfWith(inputs[i], RewriteVal);
630 // Set names for input and output arguments.
631 if (!AggregateArgs) {
632 AI = newFunction->arg_begin();
633 for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
634 AI->setName(inputs[i]->getName());
635 for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
636 AI->setName(outputs[i]->getName()+".out");
639 // Rewrite branches to basic blocks outside of the loop to new dummy blocks
640 // within the new function. This must be done before we lose track of which
641 // blocks were originally in the code region.
642 std::vector<User*> Users(header->user_begin(), header->user_end());
643 for (unsigned i = 0, e = Users.size(); i != e; ++i)
644 // The BasicBlock which contains the branch is not in the region
645 // modify the branch target to a new block
646 if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i]))
647 if (!Blocks.count(TI->getParent()) &&
648 TI->getParent()->getParent() == oldFunction)
649 TI->replaceUsesOfWith(header, newHeader);
651 return newFunction;
654 /// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI
655 /// that uses the value within the basic block, and return the predecessor
656 /// block associated with that use, or return 0 if none is found.
657 static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) {
658 for (Use &U : Used->uses()) {
659 PHINode *P = dyn_cast<PHINode>(U.getUser());
660 if (P && P->getParent() == BB)
661 return P->getIncomingBlock(U);
664 return nullptr;
667 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
668 /// the call instruction, splitting any PHI nodes in the header block as
669 /// necessary.
///
/// \param newFunction  the extracted function that will be called.
/// \param codeReplacer caller-side block that receives the argument setup,
///                     the call, the reloads of the outputs and the final
///                     terminator.
/// \param inputs       values passed into \p newFunction.
/// \param outputs      region values used outside the region; each one is
///                     stored by the exit stubs and reloaded after the call.
670 void CodeExtractor::
671 emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
672 ValueSet &inputs, ValueSet &outputs) {
673 // Emit a call to the new function, passing in: *pointer to struct (if
674 // aggregating parameters), or plain inputs and allocated memory for outputs
675 std::vector<Value*> params, StructValues, ReloadOutputs, Reloads;
677 Module *M = newFunction->getParent();
678 LLVMContext &Context = M->getContext();
679 const DataLayout &DL = M->getDataLayout();
681 // Add inputs as params, or to be filled into the struct
682 for (Value *input : inputs)
683 if (AggregateArgs)
684 StructValues.push_back(input);
685 else
686 params.push_back(input);
688 // Create allocas for the outputs
// Each alloca is inserted before the first instruction of the first block of
// the function that contains codeReplacer.
689 for (Value *output : outputs) {
690 if (AggregateArgs) {
691 StructValues.push_back(output);
692 } else {
693 AllocaInst *alloca =
694 new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
695 nullptr, output->getName() + ".loc",
696 &codeReplacer->getParent()->front().front());
697 ReloadOutputs.push_back(alloca);
698 params.push_back(alloca);
702 StructType *StructArgTy = nullptr;
703 AllocaInst *Struct = nullptr;
704 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
705 std::vector<Type*> ArgTypes;
706 for (ValueSet::iterator v = StructValues.begin(),
707 ve = StructValues.end(); v != ve; ++v)
708 ArgTypes.push_back((*v)->getType());
710 // Allocate a struct at the beginning of this function
711 StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
712 Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
713 "structArg",
714 &codeReplacer->getParent()->front().front());
715 params.push_back(Struct);
// Only the input fields are stored before the call; the output fields are
// written by the exit stubs created further down.
717 for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
718 Value *Idx[2];
719 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
720 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
721 GetElementPtrInst *GEP = GetElementPtrInst::Create(
722 StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
723 codeReplacer->getInstList().push_back(GEP);
724 StoreInst *SI = new StoreInst(StructValues[i], GEP);
725 codeReplacer->getInstList().push_back(SI);
729 // Emit the call to the function
// The call result is named only when there is more than one exit block.
730 CallInst *call = CallInst::Create(newFunction, params,
731 NumExitBlocks > 1 ? "targetBlock" : "");
732 codeReplacer->getInstList().push_back(call);
// Without aggregation the output arguments follow the inputs in the argument
// list; with aggregation FirstOut is the struct index of the first output.
734 Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
735 unsigned FirstOut = inputs.size();
736 if (!AggregateArgs)
737 std::advance(OutputArgBegin, inputs.size());
739 // Reload the outputs passed in by reference
740 for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
741 Value *Output = nullptr;
742 if (AggregateArgs) {
743 Value *Idx[2];
744 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
745 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
746 GetElementPtrInst *GEP = GetElementPtrInst::Create(
747 StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
748 codeReplacer->getInstList().push_back(GEP);
749 Output = GEP;
750 } else {
751 Output = ReloadOutputs[i];
753 LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
754 Reloads.push_back(load);
755 codeReplacer->getInstList().push_back(load);
// Users outside the region are switched to the reloaded value. The user list
// is copied first because the loop rewrites uses.
756 std::vector<User*> Users(outputs[i]->user_begin(), outputs[i]->user_end());
757 for (unsigned u = 0, e = Users.size(); u != e; ++u) {
758 Instruction *inst = cast<Instruction>(Users[u]);
759 if (!Blocks.count(inst->getParent()))
760 inst->replaceUsesOfWith(outputs[i], load);
764 // Now we can emit a switch statement using the call as a value.
// The switch starts with a null i16 condition and codeReplacer as its default
// destination; both are replaced or the switch is erased at the end.
765 SwitchInst *TheSwitch =
766 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
767 codeReplacer, 0, codeReplacer);
769 // Since there may be multiple exits from the original region, make the new
770 // function return an unsigned, switch on that number. This loop iterates
771 // over all of the blocks in the extracted region, updating any terminator
772 // instructions in the to-be-extracted region that branch to blocks that are
773 // not in the region to be extracted.
774 std::map<BasicBlock*, BasicBlock*> ExitBlockMap;
776 unsigned switchVal = 0;
777 for (BasicBlock *Block : Blocks) {
778 TerminatorInst *TI = Block->getTerminator();
779 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
780 if (!Blocks.count(TI->getSuccessor(i))) {
781 BasicBlock *OldTarget = TI->getSuccessor(i);
782 // add a new basic block which returns the appropriate value
783 BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
784 if (!NewTarget) {
785 // If we don't already have an exit stub for this non-extracted
786 // destination, create one now!
787 NewTarget = BasicBlock::Create(Context,
788 OldTarget->getName() + ".exitStub",
789 newFunction);
790 unsigned SuccNum = switchVal++;
792 Value *brVal = nullptr;
793 switch (NumExitBlocks) {
794 case 0:
795 case 1: break; // No value needed.
796 case 2: // Conditional branch, return a bool
// The first stub created (SuccNum 0) returns true, the second returns false.
797 brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
798 break;
799 default:
800 brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
801 break;
804 ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget);
806 // Update the switch instruction.
807 TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
808 SuccNum),
809 OldTarget);
811 // Restore values just before we exit
812 Function::arg_iterator OAI = OutputArgBegin;
813 for (unsigned out = 0, e = outputs.size(); out != e; ++out) {
814 // For an invoke, the normal destination is the only one that is
815 // dominated by the result of the invocation
816 BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent();
818 bool DominatesDef = true;
820 BasicBlock *NormalDest = nullptr;
821 if (auto *Invoke = dyn_cast<InvokeInst>(outputs[out]))
822 NormalDest = Invoke->getNormalDest();
824 if (NormalDest) {
825 DefBlock = NormalDest;
827 // Make sure we are looking at the original successor block, not
828 // at a newly inserted exit block, which won't be in the dominator
829 // info.
830 for (const auto &I : ExitBlockMap)
831 if (DefBlock == I.second) {
832 DefBlock = I.first;
833 break;
836 // In the extract block case, if the block we are extracting ends
837 // with an invoke instruction, make sure that we don't emit a
838 // store of the invoke value for the unwind block.
839 if (!DT && DefBlock != OldTarget)
840 DominatesDef = false;
843 if (DT) {
844 DominatesDef = DT->dominates(DefBlock, OldTarget);
846 // If the output value is used by a phi in the target block,
847 // then we need to test for dominance of the phi's predecessor
848 // instead. Unfortunately, this is a little complicated since we
849 // have already rewritten uses of the value to uses of the reload.
850 BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out],
851 OldTarget);
852 if (pred && DT && DT->dominates(DefBlock, pred))
853 DominatesDef = true;
// The output is stored in this stub only when the dominance test passed.
856 if (DominatesDef) {
857 if (AggregateArgs) {
858 Value *Idx[2];
859 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
860 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context),
861 FirstOut+out);
862 GetElementPtrInst *GEP = GetElementPtrInst::Create(
863 StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(),
864 NTRet);
865 new StoreInst(outputs[out], GEP, NTRet);
866 } else {
867 new StoreInst(outputs[out], &*OAI, NTRet);
870 // Advance output iterator even if we don't emit a store
871 if (!AggregateArgs) ++OAI;
875 // rewrite the original branch instruction with this new target
876 TI->setSuccessor(i, NewTarget);
880 // Now that we've done the deed, simplify the switch instruction.
881 Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
882 switch (NumExitBlocks) {
883 case 0:
884 // There are no successors (the block containing the switch itself), which
885 // means that previously this was the last part of the function, and hence
886 // this should be rewritten as a `ret'
888 // Check if the function should return a value
889 if (OldFnRetTy->isVoidTy()) {
890 ReturnInst::Create(Context, nullptr, TheSwitch); // Return void
891 } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
892 // return what we have
893 ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
894 } else {
895 // Otherwise we must have code extracted an unwind or something, just
896 // return whatever we want.
897 ReturnInst::Create(Context,
898 Constant::getNullValue(OldFnRetTy), TheSwitch);
901 TheSwitch->eraseFromParent();
902 break;
903 case 1:
904 // Only a single destination, change the switch into an unconditional
905 // branch.
906 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
907 TheSwitch->eraseFromParent();
908 break;
909 case 2:
// Two destinations: branch on the i1 call result between successors 1 and 2
// of the switch.
910 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
911 call, TheSwitch);
912 TheSwitch->eraseFromParent();
913 break;
914 default:
915 // Otherwise, make the default destination of the switch instruction be one
916 // of the other successors.
917 TheSwitch->setCondition(call);
918 TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
919 // Remove redundant case
920 TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
921 break;
925 void CodeExtractor::moveCodeToFunction(Function *newFunction) {
926 Function *oldFunc = (*Blocks.begin())->getParent();
927 Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
928 Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
930 for (BasicBlock *Block : Blocks) {
931 // Delete the basic block from the old function, and the list of blocks
932 oldBlocks.remove(Block);
934 // Insert this basic block into the new function
935 newBlocks.push_back(Block);
939 void CodeExtractor::calculateNewCallTerminatorWeights(
940 BasicBlock *CodeReplacer,
941 DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
942 BranchProbabilityInfo *BPI) {
943 typedef BlockFrequencyInfoImplBase::Distribution Distribution;
944 typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
946 // Update the branch weights for the exit block.
947 TerminatorInst *TI = CodeReplacer->getTerminator();
948 SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
950 // Block Frequency distribution with dummy node.
951 Distribution BranchDist;
953 // Add each of the frequencies of the successors.
954 for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
955 BlockNode ExitNode(i);
956 uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
957 if (ExitFreq != 0)
958 BranchDist.addExit(ExitNode, ExitFreq);
959 else
960 BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
963 // Check for no total weight.
964 if (BranchDist.Total == 0)
965 return;
967 // Normalize the distribution so that they can fit in unsigned.
968 BranchDist.normalize();
970 // Create normalized branch weights and set the metadata.
971 for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
972 const auto &Weight = BranchDist.Weights[I];
974 // Get the weight and update the current BFI.
975 BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
976 BranchProbability BP(Weight.Amount, BranchDist.Total);
977 BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
979 TI->setMetadata(
980 LLVMContext::MD_prof,
981 MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
984 Function *CodeExtractor::extractCodeRegion() {
985 if (!isEligible())
986 return nullptr;
988 ValueSet inputs, outputs, SinkingCands, HoistingCands;
989 BasicBlock *CommonExit = nullptr;
991 // Assumption: this is a single-entry code region, and the header is the first
992 // block in the region.
993 BasicBlock *header = *Blocks.begin();
995 // Calculate the entry frequency of the new function before we change the root
996 // block.
997 BlockFrequency EntryFreq;
998 if (BFI) {
999 assert(BPI && "Both BPI and BFI are required to preserve profile info");
1000 for (BasicBlock *Pred : predecessors(header)) {
1001 if (Blocks.count(Pred))
1002 continue;
1003 EntryFreq +=
1004 BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
1008 // If we have to split PHI nodes or the entry block, do so now.
1009 severSplitPHINodes(header);
1011 // If we have any return instructions in the region, split those blocks so
1012 // that the return is not in the region.
1013 splitReturnBlocks();
1015 Function *oldFunction = header->getParent();
1017 // This takes place of the original loop
1018 BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
1019 "codeRepl", oldFunction,
1020 header);
1022 // The new function needs a root node because other nodes can branch to the
1023 // head of the region, but the entry node of a function cannot have preds.
1024 BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
1025 "newFuncRoot");
1026 newFuncRoot->getInstList().push_back(BranchInst::Create(header));
1028 findAllocas(SinkingCands, HoistingCands, CommonExit);
1029 assert(HoistingCands.empty() || CommonExit);
1031 // Find inputs to, outputs from the code region.
1032 findInputsOutputs(inputs, outputs, SinkingCands);
1034 // Now sink all instructions which only have non-phi uses inside the region
1035 for (auto *II : SinkingCands)
1036 cast<Instruction>(II)->moveBefore(*newFuncRoot,
1037 newFuncRoot->getFirstInsertionPt());
1039 if (!HoistingCands.empty()) {
1040 auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
1041 Instruction *TI = HoistToBlock->getTerminator();
1042 for (auto *II : HoistingCands)
1043 cast<Instruction>(II)->moveBefore(TI);
1046 // Calculate the exit blocks for the extracted region and the total exit
1047 // weights for each of those blocks.
1048 DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
1049 SmallPtrSet<BasicBlock *, 1> ExitBlocks;
1050 for (BasicBlock *Block : Blocks) {
1051 for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
1052 ++SI) {
1053 if (!Blocks.count(*SI)) {
1054 // Update the branch weight for this successor.
1055 if (BFI) {
1056 BlockFrequency &BF = ExitWeights[*SI];
1057 BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
1059 ExitBlocks.insert(*SI);
1063 NumExitBlocks = ExitBlocks.size();
1065 // Construct new function based on inputs/outputs & add allocas for all defs.
1066 Function *newFunction = constructFunction(inputs, outputs, header,
1067 newFuncRoot,
1068 codeReplacer, oldFunction,
1069 oldFunction->getParent());
1071 // Update the entry count of the function.
1072 if (BFI) {
1073 Optional<uint64_t> EntryCount =
1074 BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
1075 if (EntryCount.hasValue())
1076 newFunction->setEntryCount(EntryCount.getValue());
1077 BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
1080 emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
1082 moveCodeToFunction(newFunction);
1084 // Update the branch weights for the exit block.
1085 if (BFI && NumExitBlocks > 1)
1086 calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
1088 // Loop over all of the PHI nodes in the header block, and change any
1089 // references to the old incoming edge to be the new incoming edge.
1090 for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
1091 PHINode *PN = cast<PHINode>(I);
1092 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
1093 if (!Blocks.count(PN->getIncomingBlock(i)))
1094 PN->setIncomingBlock(i, newFuncRoot);
1097 // Look at all successors of the codeReplacer block. If any of these blocks
1098 // had PHI nodes in them, we need to update the "from" block to be the code
1099 // replacer, not the original block in the extracted region.
1100 std::vector<BasicBlock*> Succs(succ_begin(codeReplacer),
1101 succ_end(codeReplacer));
1102 for (unsigned i = 0, e = Succs.size(); i != e; ++i)
1103 for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) {
1104 PHINode *PN = cast<PHINode>(I);
1105 std::set<BasicBlock*> ProcessedPreds;
1106 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
1107 if (Blocks.count(PN->getIncomingBlock(i))) {
1108 if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second)
1109 PN->setIncomingBlock(i, codeReplacer);
1110 else {
1111 // There were multiple entries in the PHI for this block, now there
1112 // is only one, so remove the duplicated entries.
1113 PN->removeIncomingValue(i, false);
1114 --i; --e;
1119 DEBUG(if (verifyFunction(*newFunction))
1120 report_fatal_error("verifyFunction failed!"));
1121 return newFunction;