1 //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass extracts the specified basic blocks from the module into their own functions.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Pass.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/MemoryBuffer.h"
22 #include "llvm/Transforms/IPO.h"
23 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
24 #include "llvm/Transforms/Utils/CodeExtractor.h"
28 #define DEBUG_TYPE "block-extractor"
30 STATISTIC(NumExtracted
, "Number of basic blocks extracted");
32 static cl::opt
<std::string
> BlockExtractorFile(
33 "extract-blocks-file", cl::value_desc("filename"),
34 cl::desc("A file containing list of basic blocks to extract"), cl::Hidden
);
// Command-line flag: runOnModule enters its "Erase the functions" step when
// either this flag or the pass's EraseFunctions member is true
// (`if (EraseFunctions || BlockExtractorEraseFuncs)`).
// NOTE(review): unlike BlockExtractorFile this option is not declared `static`,
// and the remainder of the declaration (after the cl::desc argument) is missing
// from this text.
36 cl::opt
<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs",
37 cl::desc("Erase the existing functions"),
// Module pass that outlines groups of basic blocks into new functions:
// runOnModule hands each group to CodeExtractor::extractCodeRegion().
// NOTE(review): several lines of this class are missing from this text (the
// embedded line numbers skip, e.g. 42, 45, 56, 59-60, 62, 76, 84-85, 87);
// members used below such as EraseFunctions, ID and loadFile() are not declared
// in the visible text.
40 class BlockExtractor
: public ModulePass
{
// Groups of blocks to extract; runOnModule extracts each inner vector as one
// region.
41 SmallVector
<SmallVector
<BasicBlock
*, 16>, 4> GroupsOfBlocks
;
43 /// Map a function name to groups of blocks.
// Filled by loadFile() from the file named by -extract-blocks-file: each entry
// pairs a function name with the block names that form one group.
44 SmallVector
<std::pair
<std::string
, SmallVector
<std::string
, 4>>, 4>
// Copies every caller-supplied group into GroupsOfBlocks, then checks whether
// -extract-blocks-file was given.
47 void init(const SmallVectorImpl
<SmallVector
<BasicBlock
*, 16>>
48 &GroupsOfBlocksToExtract
) {
49 for (const SmallVectorImpl
<BasicBlock
*> &GroupOfBlocks
:
50 GroupsOfBlocksToExtract
) {
51 SmallVector
<BasicBlock
*, 16> NewGroup
;
52 NewGroup
.append(GroupOfBlocks
.begin(), GroupOfBlocks
.end());
53 GroupsOfBlocks
.emplace_back(NewGroup
);
// NOTE(review): the statement guarded by this check is missing from this
// text; presumably a call to loadFile() -- confirm.
55 if (!BlockExtractorFile
.empty())
// Treats every element of BlocksToExtract as its own single-block group.
// NOTE(review): the second constructor parameter (presumably
// `bool EraseFunctions`, used in the initializer list) is missing here.
61 BlockExtractor(const SmallVectorImpl
<BasicBlock
*> &BlocksToExtract
,
63 : ModulePass(ID
), EraseFunctions(EraseFunctions
) {
64 // We want one group per element of the input list.
65 SmallVector
<SmallVector
<BasicBlock
*, 16>, 4> MassagedGroupsOfBlocks
;
66 for (BasicBlock
*BB
: BlocksToExtract
) {
67 SmallVector
<BasicBlock
*, 16> NewGroup
;
68 NewGroup
.push_back(BB
);
69 MassagedGroupsOfBlocks
.push_back(NewGroup
);
71 init(MassagedGroupsOfBlocks
);
// Uses the caller-provided groups unchanged.
74 BlockExtractor(const SmallVectorImpl
<SmallVector
<BasicBlock
*, 16>>
75 &GroupsOfBlocksToExtract
,
77 : ModulePass(ID
), EraseFunctions(EraseFunctions
) {
78 init(GroupsOfBlocksToExtract
);
// Default constructor: no explicit blocks and EraseFunctions == false; groups
// can still come from -extract-blocks-file through init().
81 BlockExtractor() : BlockExtractor(SmallVector
<BasicBlock
*, 0>(), false) {}
// Pass entry point; defined out of line.
82 bool runOnModule(Module
&M
) override
;
// Run by runOnModule on every function before extraction; defined out of line.
86 void splitLandingPadPreds(Function
&F
);
88 } // end anonymous namespace
90 char BlockExtractor::ID
= 0;
91 INITIALIZE_PASS(BlockExtractor
, "extract-blocks",
92 "Extract basic blocks from module", false, false)
94 ModulePass
*llvm::createBlockExtractorPass() { return new BlockExtractor(); }
95 ModulePass
*llvm::createBlockExtractorPass(
96 const SmallVectorImpl
<BasicBlock
*> &BlocksToExtract
, bool EraseFunctions
) {
97 return new BlockExtractor(BlocksToExtract
, EraseFunctions
);
99 ModulePass
*llvm::createBlockExtractorPass(
100 const SmallVectorImpl
<SmallVector
<BasicBlock
*, 16>>
101 &GroupsOfBlocksToExtract
,
102 bool EraseFunctions
) {
103 return new BlockExtractor(GroupsOfBlocksToExtract
, EraseFunctions
);
106 /// Gets all of the blocks specified in the input file.
107 void BlockExtractor::loadFile() {
108 auto ErrOrBuf
= MemoryBuffer::getFile(BlockExtractorFile
);
109 if (ErrOrBuf
.getError())
110 report_fatal_error("BlockExtractor couldn't load the file.");
112 auto &Buf
= *ErrOrBuf
;
113 SmallVector
<StringRef
, 16> Lines
;
114 Buf
->getBuffer().split(Lines
, '\n', /*MaxSplit=*/-1,
115 /*KeepEmpty=*/false);
116 for (const auto &Line
: Lines
) {
117 SmallVector
<StringRef
, 4> LineSplit
;
118 Line
.split(LineSplit
, ' ', /*MaxSplit=*/-1,
119 /*KeepEmpty=*/false);
120 if (LineSplit
.empty())
122 SmallVector
<StringRef
, 4> BBNames
;
123 LineSplit
[1].split(BBNames
, ';', /*MaxSplit=*/-1,
124 /*KeepEmpty=*/false);
126 report_fatal_error("Missing bbs name");
127 BlocksByName
.push_back({LineSplit
[0], {BBNames
.begin(), BBNames
.end()}});
131 /// Extracts the landing pads to make sure all of them have only one
/// predecessor.
// Called from runOnModule on every function of the module before any block is
// extracted. For each invoke it inspects the predecessors of the invoke's
// unwind destination and may call SplitLandingPadPredecessors on it.
// NOTE(review): lines 132, 137, 141, 144, 148-155 and 158-161 (by the embedded
// numbering) are missing from this text, including the statement under the
// isa<InvokeInst> check and the code between the predecessor loop and the split.
133 void BlockExtractor::splitLandingPadPreds(Function
&F
) {
134 for (BasicBlock
&BB
: F
) {
135 for (Instruction
&I
: BB
) {
// Only invoke instructions are of interest here.
136 if (!isa
<InvokeInst
>(&I
))
138 InvokeInst
*II
= cast
<InvokeInst
>(&I
);
// Parent: the block holding the invoke; LPad: the invoke's unwind destination.
139 BasicBlock
*Parent
= II
->getParent();
140 BasicBlock
*LPad
= II
->getUnwindDest();
142 // Look through the landing pad's predecessors. If one of them ends in an
143 // 'invoke', then we want to split the landing pad.
145 for (auto PredBB
: predecessors(LPad
)) {
// NOTE(review): the condition below tests PredBB->isLandingPad() and the
// terminator of Parent (not of PredBB), which does not match the comment
// above. Because II lives in Parent and invokes terminate their block in LLVM
// IR, the isa<InvokeInst> test appears to be always true -- confirm intent.
146 if (PredBB
->isLandingPad() && PredBB
!= Parent
&&
147 isa
<InvokeInst
>(Parent
->getTerminator())) {
// Split LPad with respect to predecessor Parent; ".1"/".2" are the name
// suffixes handed to SplitLandingPadPredecessors, and NewBBs is its output
// vector (presumably the newly created blocks -- not used afterwards here).
156 SmallVector
<BasicBlock
*, 2> NewBBs
;
157 SplitLandingPadPredecessors(LPad
, Parent
, ".1", ".2", NewBBs
);
162 bool BlockExtractor::runOnModule(Module
&M
) {
164 bool Changed
= false;
166 // Get all the functions.
167 SmallVector
<Function
*, 4> Functions
;
168 for (Function
&F
: M
) {
169 splitLandingPadPreds(F
);
170 Functions
.push_back(&F
);
173 // Get all the blocks specified in the input file.
174 unsigned NextGroupIdx
= GroupsOfBlocks
.size();
175 GroupsOfBlocks
.resize(NextGroupIdx
+ BlocksByName
.size());
176 for (const auto &BInfo
: BlocksByName
) {
177 Function
*F
= M
.getFunction(BInfo
.first
);
179 report_fatal_error("Invalid function name specified in the input file");
180 for (const auto &BBInfo
: BInfo
.second
) {
181 auto Res
= llvm::find_if(*F
, [&](const BasicBlock
&BB
) {
182 return BB
.getName().equals(BBInfo
);
185 report_fatal_error("Invalid block name specified in the input file");
186 GroupsOfBlocks
[NextGroupIdx
].push_back(&*Res
);
191 // Extract each group of basic blocks.
192 for (auto &BBs
: GroupsOfBlocks
) {
193 SmallVector
<BasicBlock
*, 32> BlocksToExtractVec
;
194 for (BasicBlock
*BB
: BBs
) {
195 // Check if the module contains BB.
196 if (BB
->getParent()->getParent() != &M
)
197 report_fatal_error("Invalid basic block");
198 LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting "
199 << BB
->getParent()->getName() << ":" << BB
->getName()
201 BlocksToExtractVec
.push_back(BB
);
202 if (const InvokeInst
*II
= dyn_cast
<InvokeInst
>(BB
->getTerminator()))
203 BlocksToExtractVec
.push_back(II
->getUnwindDest());
207 Function
*F
= CodeExtractor(BlocksToExtractVec
).extractCodeRegion();
209 LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs
.begin())->getName()
210 << "' in: " << F
->getName() << '\n');
212 LLVM_DEBUG(dbgs() << "Failed to extract for group '"
213 << (*BBs
.begin())->getName() << "'\n");
216 // Erase the functions.
217 if (EraseFunctions
|| BlockExtractorEraseFuncs
) {
218 for (Function
*F
: Functions
) {
219 LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F
->getName()
223 // Set linkage as ExternalLinkage to avoid erasing unreachable functions.
224 for (Function
&F
: M
)
225 F
.setLinkage(GlobalValue::ExternalLinkage
);