//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
//                     The LLVM Compiler Infrastructure
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//===----------------------------------------------------------------------===//
// The LowerSwitch transformation rewrites switch instructions with a sequence
// of branches, which allows targets to get away with not implementing the
// switch instruction until it is convenient.
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
#include "llvm/Constants.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <vector>
using namespace llvm;
31 /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
32 /// instructions. Note that this cannot be a BasicBlock pass because it
34 class VISIBILITY_HIDDEN LowerSwitch
: public FunctionPass
{
36 static char ID
; // Pass identification, replacement for typeid
37 LowerSwitch() : FunctionPass(&ID
) {}
39 virtual bool runOnFunction(Function
&F
);
41 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
42 // This is a cluster of orthogonal Transforms
43 AU
.addPreserved
<UnifyFunctionExitNodes
>();
44 AU
.addPreservedID(PromoteMemoryToRegisterID
);
45 AU
.addPreservedID(LowerInvokePassID
);
46 AU
.addPreservedID(LowerAllocationsID
);
54 CaseRange() : Low(0), High(0), BB(0) { }
55 CaseRange(Constant
* low
, Constant
* high
, BasicBlock
* bb
) :
56 Low(low
), High(high
), BB(bb
) { }
59 typedef std::vector
<CaseRange
> CaseVector
;
60 typedef std::vector
<CaseRange
>::iterator CaseItr
;
62 void processSwitchInst(SwitchInst
*SI
);
64 BasicBlock
* switchConvert(CaseItr Begin
, CaseItr End
, Value
* Val
,
65 BasicBlock
* OrigBlock
, BasicBlock
* Default
);
66 BasicBlock
* newLeafBlock(CaseRange
& Leaf
, Value
* Val
,
67 BasicBlock
* OrigBlock
, BasicBlock
* Default
);
68 unsigned Clusterify(CaseVector
& Cases
, SwitchInst
*SI
);
71 /// The comparison function for sorting the switch case values in the vector.
72 /// WARNING: Case ranges should be disjoint!
74 bool operator () (const LowerSwitch::CaseRange
& C1
,
75 const LowerSwitch::CaseRange
& C2
) {
77 const ConstantInt
* CI1
= cast
<const ConstantInt
>(C1
.Low
);
78 const ConstantInt
* CI2
= cast
<const ConstantInt
>(C2
.High
);
79 return CI1
->getValue().slt(CI2
->getValue());
84 char LowerSwitch::ID
= 0;
85 static RegisterPass
<LowerSwitch
>
86 X("lowerswitch", "Lower SwitchInst's to branches");
88 // Publically exposed interface to pass...
89 const PassInfo
*const llvm::LowerSwitchID
= &X
;
90 // createLowerSwitchPass - Interface to this file...
91 FunctionPass
*llvm::createLowerSwitchPass() {
92 return new LowerSwitch();
95 bool LowerSwitch::runOnFunction(Function
&F
) {
98 for (Function::iterator I
= F
.begin(), E
= F
.end(); I
!= E
; ) {
99 BasicBlock
*Cur
= I
++; // Advance over block so we don't traverse new blocks
101 if (SwitchInst
*SI
= dyn_cast
<SwitchInst
>(Cur
->getTerminator())) {
103 processSwitchInst(SI
);
110 // operator<< - Used for debugging purposes.
112 static raw_ostream
& operator<<(raw_ostream
&O
,
113 const LowerSwitch::CaseVector
&C
) ATTRIBUTE_USED
;
114 static raw_ostream
& operator<<(raw_ostream
&O
,
115 const LowerSwitch::CaseVector
&C
) {
118 for (LowerSwitch::CaseVector::const_iterator B
= C
.begin(),
119 E
= C
.end(); B
!= E
; ) {
120 O
<< *B
->Low
<< " -" << *B
->High
;
121 if (++B
!= E
) O
<< ", ";
127 // switchConvert - Convert the switch statement into a binary lookup of
128 // the case values. The function recursively builds this tree.
130 BasicBlock
* LowerSwitch::switchConvert(CaseItr Begin
, CaseItr End
,
131 Value
* Val
, BasicBlock
* OrigBlock
,
134 unsigned Size
= End
- Begin
;
137 return newLeafBlock(*Begin
, Val
, OrigBlock
, Default
);
139 unsigned Mid
= Size
/ 2;
140 std::vector
<CaseRange
> LHS(Begin
, Begin
+ Mid
);
141 DEBUG(errs() << "LHS: " << LHS
<< "\n");
142 std::vector
<CaseRange
> RHS(Begin
+ Mid
, End
);
143 DEBUG(errs() << "RHS: " << RHS
<< "\n");
145 CaseRange
& Pivot
= *(Begin
+ Mid
);
146 DEBUG(errs() << "Pivot ==> "
147 << cast
<ConstantInt
>(Pivot
.Low
)->getValue() << " -"
148 << cast
<ConstantInt
>(Pivot
.High
)->getValue() << "\n");
150 BasicBlock
* LBranch
= switchConvert(LHS
.begin(), LHS
.end(), Val
,
152 BasicBlock
* RBranch
= switchConvert(RHS
.begin(), RHS
.end(), Val
,
155 // Create a new node that checks if the value is < pivot. Go to the
156 // left branch if it is and right branch if not.
157 Function
* F
= OrigBlock
->getParent();
158 BasicBlock
* NewNode
= BasicBlock::Create("NodeBlock");
159 Function::iterator FI
= OrigBlock
;
160 F
->getBasicBlockList().insert(++FI
, NewNode
);
162 ICmpInst
* Comp
= new ICmpInst(Default
->getContext(), ICmpInst::ICMP_SLT
,
163 Val
, Pivot
.Low
, "Pivot");
164 NewNode
->getInstList().push_back(Comp
);
165 BranchInst::Create(LBranch
, RBranch
, Comp
, NewNode
);
169 // newLeafBlock - Create a new leaf block for the binary lookup tree. It
170 // checks if the switch's value == the case's value. If not, then it
171 // jumps to the default branch. At this point in the tree, the value
172 // can't be another valid case value, so the jump to the "default" branch
175 BasicBlock
* LowerSwitch::newLeafBlock(CaseRange
& Leaf
, Value
* Val
,
176 BasicBlock
* OrigBlock
,
179 Function
* F
= OrigBlock
->getParent();
180 BasicBlock
* NewLeaf
= BasicBlock::Create("LeafBlock");
181 Function::iterator FI
= OrigBlock
;
182 F
->getBasicBlockList().insert(++FI
, NewLeaf
);
185 ICmpInst
* Comp
= NULL
;
186 if (Leaf
.Low
== Leaf
.High
) {
187 // Make the seteq instruction...
188 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_EQ
, Val
,
189 Leaf
.Low
, "SwitchLeaf");
191 // Make range comparison
192 if (cast
<ConstantInt
>(Leaf
.Low
)->isMinValue(true /*isSigned*/)) {
193 // Val >= Min && Val <= Hi --> Val <= Hi
194 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_SLE
, Val
, Leaf
.High
,
196 } else if (cast
<ConstantInt
>(Leaf
.Low
)->isZero()) {
197 // Val >= 0 && Val <= Hi --> Val <=u Hi
198 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_ULE
, Val
, Leaf
.High
,
201 // Emit V-Lo <=u Hi-Lo
202 Constant
* NegLo
= ConstantExpr::getNeg(Leaf
.Low
);
203 Instruction
* Add
= BinaryOperator::CreateAdd(Val
, NegLo
,
204 Val
->getName()+".off",
206 Constant
*UpperBound
= ConstantExpr::getAdd(NegLo
, Leaf
.High
);
207 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_ULE
, Add
, UpperBound
,
212 // Make the conditional branch...
213 BasicBlock
* Succ
= Leaf
.BB
;
214 BranchInst::Create(Succ
, Default
, Comp
, NewLeaf
);
216 // If there were any PHI nodes in this successor, rewrite one entry
217 // from OrigBlock to come from NewLeaf.
218 for (BasicBlock::iterator I
= Succ
->begin(); isa
<PHINode
>(I
); ++I
) {
219 PHINode
* PN
= cast
<PHINode
>(I
);
220 // Remove all but one incoming entries from the cluster
221 uint64_t Range
= cast
<ConstantInt
>(Leaf
.High
)->getSExtValue() -
222 cast
<ConstantInt
>(Leaf
.Low
)->getSExtValue();
223 for (uint64_t j
= 0; j
< Range
; ++j
) {
224 PN
->removeIncomingValue(OrigBlock
);
227 int BlockIdx
= PN
->getBasicBlockIndex(OrigBlock
);
228 assert(BlockIdx
!= -1 && "Switch didn't go to this successor??");
229 PN
->setIncomingBlock((unsigned)BlockIdx
, NewLeaf
);
235 // Clusterify - Transform simple list of Cases into list of CaseRange's
236 unsigned LowerSwitch::Clusterify(CaseVector
& Cases
, SwitchInst
*SI
) {
237 unsigned numCmps
= 0;
239 // Start with "simple" cases
240 for (unsigned i
= 1; i
< SI
->getNumSuccessors(); ++i
)
241 Cases
.push_back(CaseRange(SI
->getSuccessorValue(i
),
242 SI
->getSuccessorValue(i
),
243 SI
->getSuccessor(i
)));
244 std::sort(Cases
.begin(), Cases
.end(), CaseCmp());
246 // Merge case into clusters
248 for (CaseItr I
=Cases
.begin(), J
=next(Cases
.begin()); J
!=Cases
.end(); ) {
249 int64_t nextValue
= cast
<ConstantInt
>(J
->Low
)->getSExtValue();
250 int64_t currentValue
= cast
<ConstantInt
>(I
->High
)->getSExtValue();
251 BasicBlock
* nextBB
= J
->BB
;
252 BasicBlock
* currentBB
= I
->BB
;
254 // If the two neighboring cases go to the same destination, merge them
255 // into a single case.
256 if ((nextValue
-currentValue
==1) && (currentBB
== nextBB
)) {
264 for (CaseItr I
=Cases
.begin(), E
=Cases
.end(); I
!=E
; ++I
, ++numCmps
) {
265 if (I
->Low
!= I
->High
)
266 // A range counts double, since it requires two compares.
273 // processSwitchInst - Replace the specified switch instruction with a sequence
274 // of chained if-then insts in a balanced binary search.
276 void LowerSwitch::processSwitchInst(SwitchInst
*SI
) {
277 BasicBlock
*CurBlock
= SI
->getParent();
278 BasicBlock
*OrigBlock
= CurBlock
;
279 Function
*F
= CurBlock
->getParent();
280 Value
*Val
= SI
->getOperand(0); // The value we are switching on...
281 BasicBlock
* Default
= SI
->getDefaultDest();
283 // If there is only the default destination, don't bother with the code below.
284 if (SI
->getNumOperands() == 2) {
285 BranchInst::Create(SI
->getDefaultDest(), CurBlock
);
286 CurBlock
->getInstList().erase(SI
);
290 // Create a new, empty default block so that the new hierarchy of
291 // if-then statements go to this and the PHI nodes are happy.
292 BasicBlock
* NewDefault
= BasicBlock::Create("NewDefault");
293 F
->getBasicBlockList().insert(Default
, NewDefault
);
295 BranchInst::Create(Default
, NewDefault
);
297 // If there is an entry in any PHI nodes for the default edge, make sure
298 // to update them as well.
299 for (BasicBlock::iterator I
= Default
->begin(); isa
<PHINode
>(I
); ++I
) {
300 PHINode
*PN
= cast
<PHINode
>(I
);
301 int BlockIdx
= PN
->getBasicBlockIndex(OrigBlock
);
302 assert(BlockIdx
!= -1 && "Switch didn't go to this successor??");
303 PN
->setIncomingBlock((unsigned)BlockIdx
, NewDefault
);
306 // Prepare cases vector.
308 unsigned numCmps
= Clusterify(Cases
, SI
);
310 DEBUG(errs() << "Clusterify finished. Total clusters: " << Cases
.size()
311 << ". Total compares: " << numCmps
<< "\n");
312 DEBUG(errs() << "Cases: " << Cases
<< "\n");
315 BasicBlock
* SwitchBlock
= switchConvert(Cases
.begin(), Cases
.end(), Val
,
316 OrigBlock
, NewDefault
);
318 // Branch to our shiny new if-then stuff...
319 BranchInst::Create(SwitchBlock
, OrigBlock
);
321 // We are now done with the switch instruction, delete it.
322 CurBlock
->getInstList().erase(SI
);