//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// The LowerSwitch transformation rewrites switch instructions with a sequence
// of branches, which allows targets to get away with not implementing the
// switch instruction until it is convenient.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
#include "llvm/Constants.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <vector>
using namespace llvm;
30 /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
31 /// instructions. Note that this cannot be a BasicBlock pass because it
33 class VISIBILITY_HIDDEN LowerSwitch
: public FunctionPass
{
35 static char ID
; // Pass identification, replacement for typeid
36 LowerSwitch() : FunctionPass(&ID
) {}
38 virtual bool runOnFunction(Function
&F
);
40 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
41 // This is a cluster of orthogonal Transforms
42 AU
.addPreserved
<UnifyFunctionExitNodes
>();
43 AU
.addPreservedID(PromoteMemoryToRegisterID
);
44 AU
.addPreservedID(LowerInvokePassID
);
45 AU
.addPreservedID(LowerAllocationsID
);
53 CaseRange() : Low(0), High(0), BB(0) { }
54 CaseRange(Constant
* low
, Constant
* high
, BasicBlock
* bb
) :
55 Low(low
), High(high
), BB(bb
) { }
58 typedef std::vector
<CaseRange
> CaseVector
;
59 typedef std::vector
<CaseRange
>::iterator CaseItr
;
61 void processSwitchInst(SwitchInst
*SI
);
63 BasicBlock
* switchConvert(CaseItr Begin
, CaseItr End
, Value
* Val
,
64 BasicBlock
* OrigBlock
, BasicBlock
* Default
);
65 BasicBlock
* newLeafBlock(CaseRange
& Leaf
, Value
* Val
,
66 BasicBlock
* OrigBlock
, BasicBlock
* Default
);
67 unsigned Clusterify(CaseVector
& Cases
, SwitchInst
*SI
);
70 /// The comparison function for sorting the switch case values in the vector.
71 /// WARNING: Case ranges should be disjoint!
73 bool operator () (const LowerSwitch::CaseRange
& C1
,
74 const LowerSwitch::CaseRange
& C2
) {
76 const ConstantInt
* CI1
= cast
<const ConstantInt
>(C1
.Low
);
77 const ConstantInt
* CI2
= cast
<const ConstantInt
>(C2
.High
);
78 return CI1
->getValue().slt(CI2
->getValue());
83 char LowerSwitch::ID
= 0;
84 static RegisterPass
<LowerSwitch
>
85 X("lowerswitch", "Lower SwitchInst's to branches");
87 // Publically exposed interface to pass...
88 const PassInfo
*const llvm::LowerSwitchID
= &X
;
89 // createLowerSwitchPass - Interface to this file...
90 FunctionPass
*llvm::createLowerSwitchPass() {
91 return new LowerSwitch();
94 bool LowerSwitch::runOnFunction(Function
&F
) {
97 for (Function::iterator I
= F
.begin(), E
= F
.end(); I
!= E
; ) {
98 BasicBlock
*Cur
= I
++; // Advance over block so we don't traverse new blocks
100 if (SwitchInst
*SI
= dyn_cast
<SwitchInst
>(Cur
->getTerminator())) {
102 processSwitchInst(SI
);
109 // operator<< - Used for debugging purposes.
111 static std::ostream
& operator<<(std::ostream
&O
,
112 const LowerSwitch::CaseVector
&C
) {
115 for (LowerSwitch::CaseVector::const_iterator B
= C
.begin(),
116 E
= C
.end(); B
!= E
; ) {
117 O
<< *B
->Low
<< " -" << *B
->High
;
118 if (++B
!= E
) O
<< ", ";
124 static OStream
& operator<<(OStream
&O
, const LowerSwitch::CaseVector
&C
) {
125 if (O
.stream()) *O
.stream() << C
;
129 // switchConvert - Convert the switch statement into a binary lookup of
130 // the case values. The function recursively builds this tree.
132 BasicBlock
* LowerSwitch::switchConvert(CaseItr Begin
, CaseItr End
,
133 Value
* Val
, BasicBlock
* OrigBlock
,
136 unsigned Size
= End
- Begin
;
139 return newLeafBlock(*Begin
, Val
, OrigBlock
, Default
);
141 unsigned Mid
= Size
/ 2;
142 std::vector
<CaseRange
> LHS(Begin
, Begin
+ Mid
);
143 DOUT
<< "LHS: " << LHS
<< "\n";
144 std::vector
<CaseRange
> RHS(Begin
+ Mid
, End
);
145 DOUT
<< "RHS: " << RHS
<< "\n";
147 CaseRange
& Pivot
= *(Begin
+ Mid
);
148 DEBUG(errs() << "Pivot ==> "
149 << cast
<ConstantInt
>(Pivot
.Low
)->getValue() << " -"
150 << cast
<ConstantInt
>(Pivot
.High
)->getValue() << "\n");
152 BasicBlock
* LBranch
= switchConvert(LHS
.begin(), LHS
.end(), Val
,
154 BasicBlock
* RBranch
= switchConvert(RHS
.begin(), RHS
.end(), Val
,
157 // Create a new node that checks if the value is < pivot. Go to the
158 // left branch if it is and right branch if not.
159 Function
* F
= OrigBlock
->getParent();
160 BasicBlock
* NewNode
= BasicBlock::Create("NodeBlock");
161 Function::iterator FI
= OrigBlock
;
162 F
->getBasicBlockList().insert(++FI
, NewNode
);
164 ICmpInst
* Comp
= new ICmpInst(ICmpInst::ICMP_SLT
, Val
, Pivot
.Low
, "Pivot");
165 NewNode
->getInstList().push_back(Comp
);
166 BranchInst::Create(LBranch
, RBranch
, Comp
, NewNode
);
170 // newLeafBlock - Create a new leaf block for the binary lookup tree. It
171 // checks if the switch's value == the case's value. If not, then it
172 // jumps to the default branch. At this point in the tree, the value
173 // can't be another valid case value, so the jump to the "default" branch
176 BasicBlock
* LowerSwitch::newLeafBlock(CaseRange
& Leaf
, Value
* Val
,
177 BasicBlock
* OrigBlock
,
180 Function
* F
= OrigBlock
->getParent();
181 BasicBlock
* NewLeaf
= BasicBlock::Create("LeafBlock");
182 Function::iterator FI
= OrigBlock
;
183 F
->getBasicBlockList().insert(++FI
, NewLeaf
);
186 ICmpInst
* Comp
= NULL
;
187 if (Leaf
.Low
== Leaf
.High
) {
188 // Make the seteq instruction...
189 Comp
= new ICmpInst(ICmpInst::ICMP_EQ
, Val
, Leaf
.Low
,
190 "SwitchLeaf", NewLeaf
);
192 // Make range comparison
193 if (cast
<ConstantInt
>(Leaf
.Low
)->isMinValue(true /*isSigned*/)) {
194 // Val >= Min && Val <= Hi --> Val <= Hi
195 Comp
= new ICmpInst(ICmpInst::ICMP_SLE
, Val
, Leaf
.High
,
196 "SwitchLeaf", NewLeaf
);
197 } else if (cast
<ConstantInt
>(Leaf
.Low
)->isZero()) {
198 // Val >= 0 && Val <= Hi --> Val <=u Hi
199 Comp
= new ICmpInst(ICmpInst::ICMP_ULE
, Val
, Leaf
.High
,
200 "SwitchLeaf", NewLeaf
);
202 // Emit V-Lo <=u Hi-Lo
203 Constant
* NegLo
= ConstantExpr::getNeg(Leaf
.Low
);
204 Instruction
* Add
= BinaryOperator::CreateAdd(Val
, NegLo
,
205 Val
->getName()+".off",
207 Constant
*UpperBound
= ConstantExpr::getAdd(NegLo
, Leaf
.High
);
208 Comp
= new ICmpInst(ICmpInst::ICMP_ULE
, Add
, UpperBound
,
209 "SwitchLeaf", NewLeaf
);
213 // Make the conditional branch...
214 BasicBlock
* Succ
= Leaf
.BB
;
215 BranchInst::Create(Succ
, Default
, Comp
, NewLeaf
);
217 // If there were any PHI nodes in this successor, rewrite one entry
218 // from OrigBlock to come from NewLeaf.
219 for (BasicBlock::iterator I
= Succ
->begin(); isa
<PHINode
>(I
); ++I
) {
220 PHINode
* PN
= cast
<PHINode
>(I
);
221 // Remove all but one incoming entries from the cluster
222 uint64_t Range
= cast
<ConstantInt
>(Leaf
.High
)->getSExtValue() -
223 cast
<ConstantInt
>(Leaf
.Low
)->getSExtValue();
224 for (uint64_t j
= 0; j
< Range
; ++j
) {
225 PN
->removeIncomingValue(OrigBlock
);
228 int BlockIdx
= PN
->getBasicBlockIndex(OrigBlock
);
229 assert(BlockIdx
!= -1 && "Switch didn't go to this successor??");
230 PN
->setIncomingBlock((unsigned)BlockIdx
, NewLeaf
);
236 // Clusterify - Transform simple list of Cases into list of CaseRange's
237 unsigned LowerSwitch::Clusterify(CaseVector
& Cases
, SwitchInst
*SI
) {
238 unsigned numCmps
= 0;
240 // Start with "simple" cases
241 for (unsigned i
= 1; i
< SI
->getNumSuccessors(); ++i
)
242 Cases
.push_back(CaseRange(SI
->getSuccessorValue(i
),
243 SI
->getSuccessorValue(i
),
244 SI
->getSuccessor(i
)));
245 std::sort(Cases
.begin(), Cases
.end(), CaseCmp());
247 // Merge case into clusters
249 for (CaseItr I
=Cases
.begin(), J
=next(Cases
.begin()); J
!=Cases
.end(); ) {
250 int64_t nextValue
= cast
<ConstantInt
>(J
->Low
)->getSExtValue();
251 int64_t currentValue
= cast
<ConstantInt
>(I
->High
)->getSExtValue();
252 BasicBlock
* nextBB
= J
->BB
;
253 BasicBlock
* currentBB
= I
->BB
;
255 // If the two neighboring cases go to the same destination, merge them
256 // into a single case.
257 if ((nextValue
-currentValue
==1) && (currentBB
== nextBB
)) {
265 for (CaseItr I
=Cases
.begin(), E
=Cases
.end(); I
!=E
; ++I
, ++numCmps
) {
266 if (I
->Low
!= I
->High
)
267 // A range counts double, since it requires two compares.
274 // processSwitchInst - Replace the specified switch instruction with a sequence
275 // of chained if-then insts in a balanced binary search.
277 void LowerSwitch::processSwitchInst(SwitchInst
*SI
) {
278 BasicBlock
*CurBlock
= SI
->getParent();
279 BasicBlock
*OrigBlock
= CurBlock
;
280 Function
*F
= CurBlock
->getParent();
281 Value
*Val
= SI
->getOperand(0); // The value we are switching on...
282 BasicBlock
* Default
= SI
->getDefaultDest();
284 // If there is only the default destination, don't bother with the code below.
285 if (SI
->getNumOperands() == 2) {
286 BranchInst::Create(SI
->getDefaultDest(), CurBlock
);
287 CurBlock
->getInstList().erase(SI
);
291 // Create a new, empty default block so that the new hierarchy of
292 // if-then statements go to this and the PHI nodes are happy.
293 BasicBlock
* NewDefault
= BasicBlock::Create("NewDefault");
294 F
->getBasicBlockList().insert(Default
, NewDefault
);
296 BranchInst::Create(Default
, NewDefault
);
298 // If there is an entry in any PHI nodes for the default edge, make sure
299 // to update them as well.
300 for (BasicBlock::iterator I
= Default
->begin(); isa
<PHINode
>(I
); ++I
) {
301 PHINode
*PN
= cast
<PHINode
>(I
);
302 int BlockIdx
= PN
->getBasicBlockIndex(OrigBlock
);
303 assert(BlockIdx
!= -1 && "Switch didn't go to this successor??");
304 PN
->setIncomingBlock((unsigned)BlockIdx
, NewDefault
);
307 // Prepare cases vector.
309 unsigned numCmps
= Clusterify(Cases
, SI
);
311 DOUT
<< "Clusterify finished. Total clusters: " << Cases
.size()
312 << ". Total compares: " << numCmps
<< "\n";
313 DOUT
<< "Cases: " << Cases
<< "\n";
315 BasicBlock
* SwitchBlock
= switchConvert(Cases
.begin(), Cases
.end(), Val
,
316 OrigBlock
, NewDefault
);
318 // Branch to our shiny new if-then stuff...
319 BranchInst::Create(SwitchBlock
, OrigBlock
);
321 // We are now done with the switch instruction, delete it.
322 CurBlock
->getInstList().erase(SI
);