1 //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This pass extracts the specified basic blocks from the module into their
// own functions.
12 //===----------------------------------------------------------------------===//
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Pass.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/MemoryBuffer.h"
22 #include "llvm/Transforms/IPO.h"
23 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
24 #include "llvm/Transforms/Utils/CodeExtractor.h"
28 #define DEBUG_TYPE "block-extractor"
30 STATISTIC(NumExtracted
, "Number of basic blocks extracted");
32 static cl::opt
<std::string
> BlockExtractorFile(
33 "extract-blocks-file", cl::value_desc("filename"),
34 cl::desc("A file containing list of basic blocks to extract"), cl::Hidden
);
36 cl::opt
<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs",
37 cl::desc("Erase the existing functions"),
40 class BlockExtractor
: public ModulePass
{
41 SmallVector
<SmallVector
<BasicBlock
*, 16>, 4> GroupsOfBlocks
;
43 /// Map a function name to groups of blocks.
44 SmallVector
<std::pair
<std::string
, SmallVector
<std::string
, 4>>, 4>
47 void init(const SmallVectorImpl
<SmallVector
<BasicBlock
*, 16>>
48 &GroupsOfBlocksToExtract
) {
49 for (const SmallVectorImpl
<BasicBlock
*> &GroupOfBlocks
:
50 GroupsOfBlocksToExtract
) {
51 SmallVector
<BasicBlock
*, 16> NewGroup
;
52 NewGroup
.append(GroupOfBlocks
.begin(), GroupOfBlocks
.end());
53 GroupsOfBlocks
.emplace_back(NewGroup
);
55 if (!BlockExtractorFile
.empty())
61 BlockExtractor(const SmallVectorImpl
<BasicBlock
*> &BlocksToExtract
,
63 : ModulePass(ID
), EraseFunctions(EraseFunctions
) {
64 // We want one group per element of the input list.
65 SmallVector
<SmallVector
<BasicBlock
*, 16>, 4> MassagedGroupsOfBlocks
;
66 for (BasicBlock
*BB
: BlocksToExtract
) {
67 SmallVector
<BasicBlock
*, 16> NewGroup
;
68 NewGroup
.push_back(BB
);
69 MassagedGroupsOfBlocks
.push_back(NewGroup
);
71 init(MassagedGroupsOfBlocks
);
74 BlockExtractor(const SmallVectorImpl
<SmallVector
<BasicBlock
*, 16>>
75 &GroupsOfBlocksToExtract
,
77 : ModulePass(ID
), EraseFunctions(EraseFunctions
) {
78 init(GroupsOfBlocksToExtract
);
81 BlockExtractor() : BlockExtractor(SmallVector
<BasicBlock
*, 0>(), false) {}
82 bool runOnModule(Module
&M
) override
;
86 void splitLandingPadPreds(Function
&F
);
88 } // end anonymous namespace
90 char BlockExtractor::ID
= 0;
91 INITIALIZE_PASS(BlockExtractor
, "extract-blocks",
92 "Extract basic blocks from module", false, false)
94 ModulePass
*llvm::createBlockExtractorPass() { return new BlockExtractor(); }
95 ModulePass
*llvm::createBlockExtractorPass(
96 const SmallVectorImpl
<BasicBlock
*> &BlocksToExtract
, bool EraseFunctions
) {
97 return new BlockExtractor(BlocksToExtract
, EraseFunctions
);
99 ModulePass
*llvm::createBlockExtractorPass(
100 const SmallVectorImpl
<SmallVector
<BasicBlock
*, 16>>
101 &GroupsOfBlocksToExtract
,
102 bool EraseFunctions
) {
103 return new BlockExtractor(GroupsOfBlocksToExtract
, EraseFunctions
);
106 /// Gets all of the blocks specified in the input file.
107 void BlockExtractor::loadFile() {
108 auto ErrOrBuf
= MemoryBuffer::getFile(BlockExtractorFile
);
109 if (ErrOrBuf
.getError())
110 report_fatal_error("BlockExtractor couldn't load the file.");
112 auto &Buf
= *ErrOrBuf
;
113 SmallVector
<StringRef
, 16> Lines
;
114 Buf
->getBuffer().split(Lines
, '\n', /*MaxSplit=*/-1,
115 /*KeepEmpty=*/false);
116 for (const auto &Line
: Lines
) {
117 SmallVector
<StringRef
, 4> LineSplit
;
118 Line
.split(LineSplit
, ' ', /*MaxSplit=*/-1,
119 /*KeepEmpty=*/false);
120 if (LineSplit
.empty())
122 if (LineSplit
.size()!=2)
123 report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'");
124 SmallVector
<StringRef
, 4> BBNames
;
125 LineSplit
[1].split(BBNames
, ';', /*MaxSplit=*/-1,
126 /*KeepEmpty=*/false);
128 report_fatal_error("Missing bbs name");
129 BlocksByName
.push_back({LineSplit
[0], {BBNames
.begin(), BBNames
.end()}});
133 /// Extracts the landing pads to make sure all of them have only one
135 void BlockExtractor::splitLandingPadPreds(Function
&F
) {
136 for (BasicBlock
&BB
: F
) {
137 for (Instruction
&I
: BB
) {
138 if (!isa
<InvokeInst
>(&I
))
140 InvokeInst
*II
= cast
<InvokeInst
>(&I
);
141 BasicBlock
*Parent
= II
->getParent();
142 BasicBlock
*LPad
= II
->getUnwindDest();
144 // Look through the landing pad's predecessors. If one of them ends in an
145 // 'invoke', then we want to split the landing pad.
147 for (auto PredBB
: predecessors(LPad
)) {
148 if (PredBB
->isLandingPad() && PredBB
!= Parent
&&
149 isa
<InvokeInst
>(Parent
->getTerminator())) {
158 SmallVector
<BasicBlock
*, 2> NewBBs
;
159 SplitLandingPadPredecessors(LPad
, Parent
, ".1", ".2", NewBBs
);
164 bool BlockExtractor::runOnModule(Module
&M
) {
166 bool Changed
= false;
168 // Get all the functions.
169 SmallVector
<Function
*, 4> Functions
;
170 for (Function
&F
: M
) {
171 splitLandingPadPreds(F
);
172 Functions
.push_back(&F
);
175 // Get all the blocks specified in the input file.
176 unsigned NextGroupIdx
= GroupsOfBlocks
.size();
177 GroupsOfBlocks
.resize(NextGroupIdx
+ BlocksByName
.size());
178 for (const auto &BInfo
: BlocksByName
) {
179 Function
*F
= M
.getFunction(BInfo
.first
);
181 report_fatal_error("Invalid function name specified in the input file");
182 for (const auto &BBInfo
: BInfo
.second
) {
183 auto Res
= llvm::find_if(*F
, [&](const BasicBlock
&BB
) {
184 return BB
.getName().equals(BBInfo
);
187 report_fatal_error("Invalid block name specified in the input file");
188 GroupsOfBlocks
[NextGroupIdx
].push_back(&*Res
);
193 // Extract each group of basic blocks.
194 for (auto &BBs
: GroupsOfBlocks
) {
195 SmallVector
<BasicBlock
*, 32> BlocksToExtractVec
;
196 for (BasicBlock
*BB
: BBs
) {
197 // Check if the module contains BB.
198 if (BB
->getParent()->getParent() != &M
)
199 report_fatal_error("Invalid basic block");
200 LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting "
201 << BB
->getParent()->getName() << ":" << BB
->getName()
203 BlocksToExtractVec
.push_back(BB
);
204 if (const InvokeInst
*II
= dyn_cast
<InvokeInst
>(BB
->getTerminator()))
205 BlocksToExtractVec
.push_back(II
->getUnwindDest());
209 Function
*F
= CodeExtractor(BlocksToExtractVec
).extractCodeRegion();
211 LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs
.begin())->getName()
212 << "' in: " << F
->getName() << '\n');
214 LLVM_DEBUG(dbgs() << "Failed to extract for group '"
215 << (*BBs
.begin())->getName() << "'\n");
218 // Erase the functions.
219 if (EraseFunctions
|| BlockExtractorEraseFuncs
) {
220 for (Function
*F
: Functions
) {
221 LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F
->getName()
225 // Set linkage as ExternalLinkage to avoid erasing unreachable functions.
226 for (Function
&F
: M
)
227 F
.setLinkage(GlobalValue::ExternalLinkage
);