//===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// A pass wrapper around the ExtractLoop() scalar transformation to extract each
// top-level loop into its own new function. If the loop is the ONLY loop in a
// given function and that function is nothing more than a minimal wrapper
// around it, it is not touched. This is a pass most useful for debugging via
// bugpoint.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "loop-extract"
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/FunctionUtils.h"
#include <fstream>
#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>
using namespace llvm;
32 STATISTIC(NumExtracted
, "Number of loops extracted");
35 struct LoopExtractor
: public LoopPass
{
36 static char ID
; // Pass identification, replacement for typeid
39 explicit LoopExtractor(unsigned numLoops
= ~0)
40 : LoopPass(ID
), NumLoops(numLoops
) {
41 initializeLoopExtractorPass(*PassRegistry::getPassRegistry());
44 virtual bool runOnLoop(Loop
*L
, LPPassManager
&LPM
);
46 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
47 AU
.addRequiredID(BreakCriticalEdgesID
);
48 AU
.addRequiredID(LoopSimplifyID
);
49 AU
.addRequired
<DominatorTree
>();
54 char LoopExtractor::ID
= 0;
55 INITIALIZE_PASS_BEGIN(LoopExtractor
, "loop-extract",
56 "Extract loops into new functions", false, false)
57 INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges
)
58 INITIALIZE_PASS_DEPENDENCY(LoopSimplify
)
59 INITIALIZE_PASS_DEPENDENCY(DominatorTree
)
60 INITIALIZE_PASS_END(LoopExtractor
, "loop-extract",
61 "Extract loops into new functions", false, false)
64 /// SingleLoopExtractor - For bugpoint.
65 struct SingleLoopExtractor
: public LoopExtractor
{
66 static char ID
; // Pass identification, replacement for typeid
67 SingleLoopExtractor() : LoopExtractor(1) {}
69 } // End anonymous namespace
71 char SingleLoopExtractor::ID
= 0;
72 INITIALIZE_PASS(SingleLoopExtractor
, "loop-extract-single",
73 "Extract at most one loop into a new function", false, false)
75 // createLoopExtractorPass - This pass extracts all natural loops from the
76 // program into a function if it can.
78 Pass
*llvm::createLoopExtractorPass() { return new LoopExtractor(); }
80 bool LoopExtractor::runOnLoop(Loop
*L
, LPPassManager
&LPM
) {
81 // Only visit top-level loops.
82 if (L
->getParentLoop())
85 // If LoopSimplify form is not available, stay out of trouble.
86 if (!L
->isLoopSimplifyForm())
89 DominatorTree
&DT
= getAnalysis
<DominatorTree
>();
92 // If there is more than one top-level loop in this function, extract all of
93 // the loops. Otherwise there is exactly one top-level loop; in this case if
94 // this function is more than a minimal wrapper around the loop, extract
96 bool ShouldExtractLoop
= false;
98 // Extract the loop if the entry block doesn't branch to the loop header.
99 TerminatorInst
*EntryTI
=
100 L
->getHeader()->getParent()->getEntryBlock().getTerminator();
101 if (!isa
<BranchInst
>(EntryTI
) ||
102 !cast
<BranchInst
>(EntryTI
)->isUnconditional() ||
103 EntryTI
->getSuccessor(0) != L
->getHeader())
104 ShouldExtractLoop
= true;
106 // Check to see if any exits from the loop are more than just return
108 SmallVector
<BasicBlock
*, 8> ExitBlocks
;
109 L
->getExitBlocks(ExitBlocks
);
110 for (unsigned i
= 0, e
= ExitBlocks
.size(); i
!= e
; ++i
)
111 if (!isa
<ReturnInst
>(ExitBlocks
[i
]->getTerminator())) {
112 ShouldExtractLoop
= true;
116 if (ShouldExtractLoop
) {
117 if (NumLoops
== 0) return Changed
;
119 if (ExtractLoop(DT
, L
) != 0) {
121 // After extraction, the loop is replaced by a function call, so
122 // we shouldn't try to run any more loop passes on it.
123 LPM
.deleteLoopFromQueue(L
);
131 // createSingleLoopExtractorPass - This pass extracts one natural loop from the
132 // program into a function if it can. This is used by bugpoint.
134 Pass
*llvm::createSingleLoopExtractorPass() {
135 return new SingleLoopExtractor();
139 // BlockFile - A file which contains a list of blocks that should not be
141 static cl::opt
<std::string
>
142 BlockFile("extract-blocks-file", cl::value_desc("filename"),
143 cl::desc("A file containing list of basic blocks to not extract"),
147 /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks
148 /// from the module into their own functions except for those specified by the
149 /// BlocksToNotExtract list.
150 class BlockExtractorPass
: public ModulePass
{
151 void LoadFile(const char *Filename
);
153 std::vector
<BasicBlock
*> BlocksToNotExtract
;
154 std::vector
<std::pair
<std::string
, std::string
> > BlocksToNotExtractByName
;
156 static char ID
; // Pass identification, replacement for typeid
157 BlockExtractorPass() : ModulePass(ID
) {
158 if (!BlockFile
.empty())
159 LoadFile(BlockFile
.c_str());
162 bool runOnModule(Module
&M
);
166 char BlockExtractorPass::ID
= 0;
167 INITIALIZE_PASS(BlockExtractorPass
, "extract-blocks",
168 "Extract Basic Blocks From Module (for bugpoint use)",
171 // createBlockExtractorPass - This pass extracts all blocks (except those
172 // specified in the argument list) from the functions in the module.
174 ModulePass
*llvm::createBlockExtractorPass()
176 return new BlockExtractorPass();
179 void BlockExtractorPass::LoadFile(const char *Filename
) {
180 // Load the BlockFile...
181 std::ifstream
In(Filename
);
183 errs() << "WARNING: BlockExtractor couldn't load file '" << Filename
188 std::string FunctionName
, BlockName
;
191 if (!BlockName
.empty())
192 BlocksToNotExtractByName
.push_back(
193 std::make_pair(FunctionName
, BlockName
));
197 bool BlockExtractorPass::runOnModule(Module
&M
) {
198 std::set
<BasicBlock
*> TranslatedBlocksToNotExtract
;
199 for (unsigned i
= 0, e
= BlocksToNotExtract
.size(); i
!= e
; ++i
) {
200 BasicBlock
*BB
= BlocksToNotExtract
[i
];
201 Function
*F
= BB
->getParent();
203 // Map the corresponding function in this module.
204 Function
*MF
= M
.getFunction(F
->getName());
205 assert(MF
->getFunctionType() == F
->getFunctionType() && "Wrong function?");
207 // Figure out which index the basic block is in its function.
208 Function::iterator BBI
= MF
->begin();
209 std::advance(BBI
, std::distance(F
->begin(), Function::iterator(BB
)));
210 TranslatedBlocksToNotExtract
.insert(BBI
);
213 while (!BlocksToNotExtractByName
.empty()) {
214 // There's no way to find BBs by name without looking at every BB inside
215 // every Function. Fortunately, this is always empty except when used by
216 // bugpoint in which case correctness is more important than performance.
218 std::string
&FuncName
= BlocksToNotExtractByName
.back().first
;
219 std::string
&BlockName
= BlocksToNotExtractByName
.back().second
;
221 for (Module::iterator FI
= M
.begin(), FE
= M
.end(); FI
!= FE
; ++FI
) {
223 if (F
.getName() != FuncName
) continue;
225 for (Function::iterator BI
= F
.begin(), BE
= F
.end(); BI
!= BE
; ++BI
) {
226 BasicBlock
&BB
= *BI
;
227 if (BB
.getName() != BlockName
) continue;
229 TranslatedBlocksToNotExtract
.insert(BI
);
233 BlocksToNotExtractByName
.pop_back();
236 // Now that we know which blocks to not extract, figure out which ones we WANT
238 std::vector
<BasicBlock
*> BlocksToExtract
;
239 for (Module::iterator F
= M
.begin(), E
= M
.end(); F
!= E
; ++F
)
240 for (Function::iterator BB
= F
->begin(), E
= F
->end(); BB
!= E
; ++BB
)
241 if (!TranslatedBlocksToNotExtract
.count(BB
))
242 BlocksToExtract
.push_back(BB
);
244 for (unsigned i
= 0, e
= BlocksToExtract
.size(); i
!= e
; ++i
)
245 ExtractBasicBlock(BlocksToExtract
[i
]);
247 return !BlocksToExtract
.empty();