//===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// A pass wrapper around the ExtractLoop() transformation to extract each
// top-level loop into its own new function. If the loop is the ONLY loop in a
// given function and that function is just a minimal wrapper around it, the
// loop itself is not extracted (only its subloops are). This pass is mostly
// useful for debugging via bugpoint. The file also provides bugpoint's
// basic-block extractor pass.
//
//===----------------------------------------------------------------------===//
17 #define DEBUG_TYPE "loop-extract"
18 #include "llvm/Transforms/IPO.h"
19 #include "llvm/Instructions.h"
20 #include "llvm/Module.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Analysis/Dominators.h"
23 #include "llvm/Analysis/LoopInfo.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Compiler.h"
26 #include "llvm/Transforms/Scalar.h"
27 #include "llvm/Transforms/Utils/FunctionUtils.h"
28 #include "llvm/ADT/Statistic.h"
33 STATISTIC(NumExtracted
, "Number of loops extracted");
36 // FIXME: This is not a function pass, but the PassManager doesn't allow
37 // Module passes to require FunctionPasses, so we can't get loop info if we're
38 // not a function pass.
39 struct VISIBILITY_HIDDEN LoopExtractor
: public FunctionPass
{
40 static char ID
; // Pass identification, replacement for typeid
43 explicit LoopExtractor(unsigned numLoops
= ~0)
44 : FunctionPass(&ID
), NumLoops(numLoops
) {}
46 virtual bool runOnFunction(Function
&F
);
48 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
49 AU
.addRequiredID(BreakCriticalEdgesID
);
50 AU
.addRequiredID(LoopSimplifyID
);
51 AU
.addRequired
<DominatorTree
>();
52 AU
.addRequired
<LoopInfo
>();
57 char LoopExtractor::ID
= 0;
58 static RegisterPass
<LoopExtractor
>
59 X("loop-extract", "Extract loops into new functions");
62 /// SingleLoopExtractor - For bugpoint.
63 struct SingleLoopExtractor
: public LoopExtractor
{
64 static char ID
; // Pass identification, replacement for typeid
65 SingleLoopExtractor() : LoopExtractor(1) {}
67 } // End anonymous namespace
69 char SingleLoopExtractor::ID
= 0;
70 static RegisterPass
<SingleLoopExtractor
>
71 Y("loop-extract-single", "Extract at most one loop into a new function");
73 // createLoopExtractorPass - This pass extracts all natural loops from the
74 // program into a function if it can.
76 FunctionPass
*llvm::createLoopExtractorPass() { return new LoopExtractor(); }
78 bool LoopExtractor::runOnFunction(Function
&F
) {
79 LoopInfo
&LI
= getAnalysis
<LoopInfo
>();
81 // If this function has no loops, there is nothing to do.
85 DominatorTree
&DT
= getAnalysis
<DominatorTree
>();
87 // If there is more than one top-level loop in this function, extract all of
90 if (LI
.end()-LI
.begin() > 1) {
91 for (LoopInfo::iterator i
= LI
.begin(), e
= LI
.end(); i
!= e
; ++i
) {
92 if (NumLoops
== 0) return Changed
;
94 Changed
|= ExtractLoop(DT
, *i
) != 0;
98 // Otherwise there is exactly one top-level loop. If this function is more
99 // than a minimal wrapper around the loop, extract the loop.
100 Loop
*TLL
= *LI
.begin();
101 bool ShouldExtractLoop
= false;
103 // Extract the loop if the entry block doesn't branch to the loop header.
104 TerminatorInst
*EntryTI
= F
.getEntryBlock().getTerminator();
105 if (!isa
<BranchInst
>(EntryTI
) ||
106 !cast
<BranchInst
>(EntryTI
)->isUnconditional() ||
107 EntryTI
->getSuccessor(0) != TLL
->getHeader())
108 ShouldExtractLoop
= true;
110 // Check to see if any exits from the loop are more than just return
112 SmallVector
<BasicBlock
*, 8> ExitBlocks
;
113 TLL
->getExitBlocks(ExitBlocks
);
114 for (unsigned i
= 0, e
= ExitBlocks
.size(); i
!= e
; ++i
)
115 if (!isa
<ReturnInst
>(ExitBlocks
[i
]->getTerminator())) {
116 ShouldExtractLoop
= true;
121 if (ShouldExtractLoop
) {
122 if (NumLoops
== 0) return Changed
;
124 Changed
|= ExtractLoop(DT
, TLL
) != 0;
127 // Okay, this function is a minimal container around the specified loop.
128 // If we extract the loop, we will continue to just keep extracting it
129 // infinitely... so don't extract it. However, if the loop contains any
130 // subloops, extract them.
131 for (Loop::iterator i
= TLL
->begin(), e
= TLL
->end(); i
!= e
; ++i
) {
132 if (NumLoops
== 0) return Changed
;
134 Changed
|= ExtractLoop(DT
, *i
) != 0;
143 // createSingleLoopExtractorPass - This pass extracts one natural loop from the
144 // program into a function if it can. This is used by bugpoint.
146 FunctionPass
*llvm::createSingleLoopExtractorPass() {
147 return new SingleLoopExtractor();
151 // BlockFile - A file which contains a list of blocks that should not be
153 static cl::opt
<std::string
>
154 BlockFile("extract-blocks-file", cl::value_desc("filename"),
155 cl::desc("A file containing list of basic blocks to not extract"),
159 /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks
160 /// from the module into their own functions except for those specified by the
161 /// BlocksToNotExtract list.
162 class BlockExtractorPass
: public ModulePass
{
163 void LoadFile(const char *Filename
);
165 std::vector
<BasicBlock
*> BlocksToNotExtract
;
166 std::vector
<std::pair
<std::string
, std::string
> > BlocksToNotExtractByName
;
168 static char ID
; // Pass identification, replacement for typeid
169 explicit BlockExtractorPass(const std::vector
<BasicBlock
*> &B
)
170 : ModulePass(&ID
), BlocksToNotExtract(B
) {
171 if (!BlockFile
.empty())
172 LoadFile(BlockFile
.c_str());
174 BlockExtractorPass() : ModulePass(&ID
) {}
176 bool runOnModule(Module
&M
);
180 char BlockExtractorPass::ID
= 0;
181 static RegisterPass
<BlockExtractorPass
>
182 XX("extract-blocks", "Extract Basic Blocks From Module (for bugpoint use)");
184 // createBlockExtractorPass - This pass extracts all blocks (except those
185 // specified in the argument list) from the functions in the module.
187 ModulePass
*llvm::createBlockExtractorPass(const std::vector
<BasicBlock
*> &BTNE
)
189 return new BlockExtractorPass(BTNE
);
192 void BlockExtractorPass::LoadFile(const char *Filename
) {
193 // Load the BlockFile...
194 std::ifstream
In(Filename
);
196 errs() << "WARNING: BlockExtractor couldn't load file '" << Filename
201 std::string FunctionName
, BlockName
;
204 if (!BlockName
.empty())
205 BlocksToNotExtractByName
.push_back(
206 std::make_pair(FunctionName
, BlockName
));
210 bool BlockExtractorPass::runOnModule(Module
&M
) {
211 std::set
<BasicBlock
*> TranslatedBlocksToNotExtract
;
212 for (unsigned i
= 0, e
= BlocksToNotExtract
.size(); i
!= e
; ++i
) {
213 BasicBlock
*BB
= BlocksToNotExtract
[i
];
214 Function
*F
= BB
->getParent();
216 // Map the corresponding function in this module.
217 Function
*MF
= M
.getFunction(F
->getName());
218 assert(MF
->getFunctionType() == F
->getFunctionType() && "Wrong function?");
220 // Figure out which index the basic block is in its function.
221 Function::iterator BBI
= MF
->begin();
222 std::advance(BBI
, std::distance(F
->begin(), Function::iterator(BB
)));
223 TranslatedBlocksToNotExtract
.insert(BBI
);
226 while (!BlocksToNotExtractByName
.empty()) {
227 // There's no way to find BBs by name without looking at every BB inside
228 // every Function. Fortunately, this is always empty except when used by
229 // bugpoint in which case correctness is more important than performance.
231 std::string
&FuncName
= BlocksToNotExtractByName
.back().first
;
232 std::string
&BlockName
= BlocksToNotExtractByName
.back().second
;
234 for (Module::iterator FI
= M
.begin(), FE
= M
.end(); FI
!= FE
; ++FI
) {
236 if (F
.getName() != FuncName
) continue;
238 for (Function::iterator BI
= F
.begin(), BE
= F
.end(); BI
!= BE
; ++BI
) {
239 BasicBlock
&BB
= *BI
;
240 if (BB
.getName() != BlockName
) continue;
242 TranslatedBlocksToNotExtract
.insert(BI
);
246 BlocksToNotExtractByName
.pop_back();
249 // Now that we know which blocks to not extract, figure out which ones we WANT
251 std::vector
<BasicBlock
*> BlocksToExtract
;
252 for (Module::iterator F
= M
.begin(), E
= M
.end(); F
!= E
; ++F
)
253 for (Function::iterator BB
= F
->begin(), E
= F
->end(); BB
!= E
; ++BB
)
254 if (!TranslatedBlocksToNotExtract
.count(BB
))
255 BlocksToExtract
.push_back(BB
);
257 for (unsigned i
= 0, e
= BlocksToExtract
.size(); i
!= e
; ++i
)
258 ExtractBasicBlock(BlocksToExtract
[i
]);
260 return !BlocksToExtract
.empty();