1 //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // The LowerSwitch transformation rewrites switch instructions with a sequence
11 // of branches, which allows targets to get away with not implementing the
12 // switch instruction until it is convenient.
14 //===----------------------------------------------------------------------===//
16 #include "llvm/Transforms/Scalar.h"
17 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
18 #include "llvm/Constants.h"
19 #include "llvm/Function.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/LLVMContext.h"
22 #include "llvm/Pass.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/Compiler.h"
26 #include "llvm/Support/raw_ostream.h"
31 /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
32 /// instructions. Note that this cannot be a BasicBlock pass because it
// NOTE(review): the sentence above is cut off in this listing. The pass inserts
// new basic blocks into the enclosing function (see switchConvert and
// newLeafBlock), which is presumably the reason -- confirm against upstream.
34 class VISIBILITY_HIDDEN LowerSwitch
: public FunctionPass
{
36 static char ID
; // Pass identification, replacement for typeid
// The address of ID (not its value) is what gets handed to the FunctionPass base.
37 LowerSwitch() : FunctionPass(&ID
) {}
// Per-function entry point; the definition walks the function's blocks and
// hands every switch terminator to processSwitchInst.
39 virtual bool runOnFunction(Function
&F
);
// Declares what this pass leaves intact: UnifyFunctionExitNodes,
// PromoteMemoryToRegister, LowerInvoke and LowerAllocations are all marked
// as preserved.
41 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
42 // This is a cluster of orthogonal Transforms
43 AU
.addPreserved
<UnifyFunctionExitNodes
>();
44 AU
.addPreservedID(PromoteMemoryToRegisterID
);
45 AU
.addPreservedID(LowerInvokePassID
);
46 AU
.addPreservedID(LowerAllocationsID
);
// CaseRange constructors. A CaseRange carries a lower bound (Low), an upper
// bound (High) and the destination block (BB); the leaf comparisons in
// newLeafBlock treat the bounds as a closed range Low <= Val <= High.
// NOTE(review): the enclosing struct header and the field declarations are not
// visible in this listing.
// Default constructor: all three members are null.
54 CaseRange() : Low(0), High(0), BB(0) { }
// Stores the given bounds and destination block as-is.
55 CaseRange(Constant
* low
, Constant
* high
, BasicBlock
* bb
) :
56 Low(low
), High(high
), BB(bb
) { }
// Container of case ranges and its mutable iterator type.
59 typedef std::vector
<CaseRange
> CaseVector
;
60 typedef std::vector
<CaseRange
>::iterator CaseItr
;
// Replaces one switch instruction with a tree of compare-and-branch blocks.
62 void processSwitchInst(SwitchInst
*SI
);
// Recursively builds the binary decision tree for the ranges in [Begin, End)
// and returns a block of that tree.
64 BasicBlock
* switchConvert(CaseItr Begin
, CaseItr End
, Value
* Val
,
65 BasicBlock
* OrigBlock
, BasicBlock
* Default
);
// Builds one leaf of the tree: tests Val against a single CaseRange and
// branches to the range's block or to Default.
66 BasicBlock
* newLeafBlock(CaseRange
& Leaf
, Value
* Val
,
67 BasicBlock
* OrigBlock
, BasicBlock
* Default
);
// Fills Cases from the switch's successors, sorts them and merges neighbours;
// the returned count is a number of compares.
68 unsigned Clusterify(CaseVector
& Cases
, SwitchInst
*SI
);
71 /// The comparison function for sorting the switch case values in the vector.
72 /// WARNING: Case ranges should be disjoint!
// Returns true when C1.Low is signed-less-than C2.High. Both bounds are
// cast<> to ConstantInt, so they must actually be ConstantInt values.
// Comparing a Low against a High (instead of Low against Low) only yields a
// consistent ordering when the ranges do not overlap, hence the warning above.
// Used as the std::sort comparator in Clusterify.
// NOTE(review): the enclosing struct header is not visible in this listing.
74 bool operator () (const LowerSwitch::CaseRange
& C1
,
75 const LowerSwitch::CaseRange
& C2
) {
// Lower bound of the left-hand range.
77 const ConstantInt
* CI1
= cast
<const ConstantInt
>(C1
.Low
);
// Upper bound of the right-hand range.
78 const ConstantInt
* CI2
= cast
<const ConstantInt
>(C2
.High
);
// Signed comparison of the two values.
79 return CI1
->getValue().slt(CI2
->getValue());
// Storage for the pass identifier. Only its address is used: the LowerSwitch
// constructor passes &ID to the FunctionPass base.
84 char LowerSwitch::ID
= 0;
// File-local registration object for the pass.
// NOTE(review): the first string is presumably the pass's command-line name and
// the second its human-readable description -- confirm against RegisterPass.
85 static RegisterPass
<LowerSwitch
>
86 X("lowerswitch", "Lower SwitchInst's to branches");
88 // Publicly exposed interface to the pass: llvm::LowerSwitchID is the address
// of the registration object X above.
89 const PassInfo
*const llvm::LowerSwitchID
= &X
;
90 // createLowerSwitchPass - Interface to this file...
// Returns a freshly heap-allocated LowerSwitch pass object.
// NOTE(review): ownership presumably transfers to the caller (typically a pass
// manager) -- nothing in this file deletes it.
91 FunctionPass
*llvm::createLowerSwitchPass() {
92 return new LowerSwitch();
// runOnFunction - Visit every basic block of F and lower each block whose
// terminator is a SwitchInst by calling processSwitchInst on it.
// NOTE(review): the bookkeeping for the bool result and the closing lines of
// this function are not visible in this listing; presumably it reports whether
// any switch was lowered -- confirm.
95 bool LowerSwitch::runOnFunction(Function
&F
) {
// The loop header has no increment clause: the iterator is advanced inside the
// body, before the current block is processed.
98 for (Function::iterator I
= F
.begin(), E
= F
.end(); I
!= E
; ) {
99 BasicBlock
*Cur
= I
++; // Advance over block so we don't traverse new blocks
// dyn_cast yields null for any other kind of terminator, so only switch
// terminators are processed.
101 if (SwitchInst
*SI
= dyn_cast
<SwitchInst
>(Cur
->getTerminator())) {
103 processSwitchInst(SI
);
110 // operator<< - Used for debugging purposes.
// Streams a CaseVector to O. It is called from the DEBUG(...) statements in
// switchConvert and processSwitchInst.
// The forward declaration carries ATTRIBUTE_USED.
// NOTE(review): presumably this keeps the function from being flagged or
// dropped as unused when DEBUG compiles away -- confirm.
112 static raw_ostream
& operator<<(raw_ostream
&O
,
113 const LowerSwitch::CaseVector
&C
) ATTRIBUTE_USED
;
114 static raw_ostream
& operator<<(raw_ostream
&O
,
115 const LowerSwitch::CaseVector
&C
) {
// NOTE(review): any prefix/suffix output and the return statement are on lines
// not visible in this listing.
118 for (LowerSwitch::CaseVector::const_iterator B
= C
.begin(),
119 E
= C
.end(); B
!= E
; ) {
// Each entry is written as <Low>, the literal " -", then <High>.
120 O
<< *B
->Low
<< " -" << *B
->High
;
// The iterator is advanced here; the ", " separator is written only when
// another entry follows.
121 if (++B
!= E
) O
<< ", ";
127 // switchConvert - Convert the switch statement into a binary lookup of
128 // the case values. The function recursively builds this tree.
//
// Begin/End  - the case ranges to cover, as the half-open range [Begin, End).
//              processSwitchInst passes the vector produced by Clusterify,
//              which is sorted with CaseCmp.
// Val        - the value being switched on.
// OrigBlock  - the block that held the switch; new blocks are inserted
//              directly after it in the function's block list.
// Default    - block to branch to when no case matches (forwarded to
//              newLeafBlock).
// NOTE(review): the last parameter line, the guard in front of the leaf call,
// the trailing arguments of the two recursive calls and the final return are
// not visible in this listing.
130 BasicBlock
* LowerSwitch::switchConvert(CaseItr Begin
, CaseItr End
,
131 Value
* Val
, BasicBlock
* OrigBlock
,
// Number of ranges still to be split.
134 unsigned Size
= End
- Begin
;
// Base case: delegate to a leaf block for the first range.
// NOTE(review): presumably guarded by a Size == 1 check on a line not shown.
137 return newLeafBlock(*Begin
, Val
, OrigBlock
, Default
);
// Split in the middle: LHS copies [Begin, Begin+Mid), RHS copies
// [Begin+Mid, End). Both are fresh vectors, so the recursion works on copies.
139 unsigned Mid
= Size
/ 2;
140 std::vector
<CaseRange
> LHS(Begin
, Begin
+ Mid
);
141 DEBUG(errs() << "LHS: " << LHS
<< "\n");
142 std::vector
<CaseRange
> RHS(Begin
+ Mid
, End
);
143 DEBUG(errs() << "RHS: " << RHS
<< "\n");
// The pivot is the first range of the right half; it refers to the element in
// the caller's range, not to the RHS copy.
145 CaseRange
& Pivot
= *(Begin
+ Mid
);
146 DEBUG(errs() << "Pivot ==> "
147 << cast
<ConstantInt
>(Pivot
.Low
)->getValue() << " -"
148 << cast
<ConstantInt
>(Pivot
.High
)->getValue() << "\n");
// Build both subtrees before creating this node's own block.
150 BasicBlock
* LBranch
= switchConvert(LHS
.begin(), LHS
.end(), Val
,
152 BasicBlock
* RBranch
= switchConvert(RHS
.begin(), RHS
.end(), Val
,
155 // Create a new node that checks if the value is < pivot. Go to the
156 // left branch if it is and right branch if not.
157 Function
* F
= OrigBlock
->getParent();
158 BasicBlock
* NewNode
= BasicBlock::Create("NodeBlock");
// Insert the new block immediately after OrigBlock.
159 Function::iterator FI
= OrigBlock
;
160 F
->getBasicBlockList().insert(++FI
, NewNode
);
// Signed comparison Val < Pivot.Low; the compare is created detached and then
// appended to NewNode by hand.
162 ICmpInst
* Comp
= new ICmpInst(Default
->getContext(), ICmpInst::ICMP_SLT
,
163 Val
, Pivot
.Low
, "Pivot");
164 NewNode
->getInstList().push_back(Comp
);
// Conditional branch placed in NewNode: LBranch when Comp is true, RBranch
// otherwise.
165 BranchInst::Create(LBranch
, RBranch
, Comp
, NewNode
);
169 // newLeafBlock - Create a new leaf block for the binary lookup tree. It
170 // checks if the switch's value == the case's value. If not, then it
171 // jumps to the default branch. At this point in the tree, the value
172 // can't be another valid case value, so the jump to the "default" branch
// NOTE(review): the sentence above is cut off in this listing.
//
// Besides the single-value equality test, a leaf may also cover a range
// (Leaf.Low != Leaf.High), in which case a range comparison is emitted instead.
//
// Leaf       - the case range tested by this leaf; Leaf.BB is the target on a
//              match.
// Val        - the value being switched on.
// OrigBlock  - the block that held the switch; the leaf is inserted right
//              after it and PHI entries naming it are rewritten.
// Default    - target when the test fails.
// NOTE(review): the last parameter line, several `else` arms, the tails of some
// constructor calls and the final return are not visible in this listing.
175 BasicBlock
* LowerSwitch::newLeafBlock(CaseRange
& Leaf
, Value
* Val
,
176 BasicBlock
* OrigBlock
,
179 Function
* F
= OrigBlock
->getParent();
// NOTE(review): no use of Context is visible in this listing.
180 LLVMContext
&Context
= F
->getContext();
181 BasicBlock
* NewLeaf
= BasicBlock::Create("LeafBlock");
// Insert the leaf immediately after OrigBlock in the function's block list.
182 Function::iterator FI
= OrigBlock
;
183 F
->getBasicBlockList().insert(++FI
, NewLeaf
);
// Comp ends up holding the match condition, whichever form is chosen below.
186 ICmpInst
* Comp
= NULL
;
// Pointer comparison of the two bound constants: equal means a single value.
187 if (Leaf
.Low
== Leaf
.High
) {
188 // Make the seteq instruction...
// NOTE(review): passing *NewLeaf presumably appends the compare to NewLeaf.
189 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_EQ
, Val
,
190 Leaf
.Low
, "SwitchLeaf");
192 // Make range comparison
// Special case 1: the lower bound is the minimum signed value, so only the
// upper bound needs checking (signed <=).
193 if (cast
<ConstantInt
>(Leaf
.Low
)->isMinValue(true /*isSigned*/)) {
194 // Val >= Min && Val <= Hi --> Val <= Hi
195 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_SLE
, Val
, Leaf
.High
,
// Special case 2: the lower bound is zero, so one unsigned <= covers both
// bounds.
197 } else if (cast
<ConstantInt
>(Leaf
.Low
)->isZero()) {
198 // Val >= 0 && Val <= Hi --> Val <=u Hi
199 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_ULE
, Val
, Leaf
.High
,
// General case: shift the range down to start at zero, then do one unsigned
// comparison.
202 // Emit V-Lo <=u Hi-Lo
203 Constant
* NegLo
= ConstantExpr::getNeg(Leaf
.Low
);
// Add = Val + (-Low), named after Val with an ".off" suffix.
204 Instruction
* Add
= BinaryOperator::CreateAdd(Val
, NegLo
,
205 Val
->getName()+".off",
// UpperBound = (-Low) + High, i.e. High - Low, folded as a constant expression.
207 Constant
*UpperBound
= ConstantExpr::getAdd(NegLo
, Leaf
.High
);
208 Comp
= new ICmpInst(*NewLeaf
, ICmpInst::ICMP_ULE
, Add
, UpperBound
,
213 // Make the conditional branch...
// Branch placed in NewLeaf: to the case's block when Comp is true, to Default
// otherwise.
214 BasicBlock
* Succ
= Leaf
.BB
;
215 BranchInst::Create(Succ
, Default
, Comp
, NewLeaf
);
217 // If there were any PHI nodes in this successor, rewrite one entry
218 // from OrigBlock to come from NewLeaf.
// The loop relies on PHI nodes sitting at the start of the block and stops at
// the first non-PHI instruction.
219 for (BasicBlock::iterator I
= Succ
->begin(); isa
<PHINode
>(I
); ++I
) {
220 PHINode
* PN
= cast
<PHINode
>(I
);
221 // Remove all but one incoming entries from the cluster
// Range is High - Low, i.e. one less than the number of values in the cluster,
// so that many OrigBlock entries are dropped and one is left to retarget.
// NOTE(review): the bounds are read through getSExtValue() and the signed
// difference is stored in a uint64_t; behaviour for very wide or extreme
// ranges should be confirmed.
222 uint64_t Range
= cast
<ConstantInt
>(Leaf
.High
)->getSExtValue() -
223 cast
<ConstantInt
>(Leaf
.Low
)->getSExtValue();
224 for (uint64_t j
= 0; j
< Range
; ++j
) {
225 PN
->removeIncomingValue(OrigBlock
);
// Retarget the remaining OrigBlock entry so that it now names NewLeaf.
228 int BlockIdx
= PN
->getBasicBlockIndex(OrigBlock
);
229 assert(BlockIdx
!= -1 && "Switch didn't go to this successor??");
230 PN
->setIncomingBlock((unsigned)BlockIdx
, NewLeaf
);
236 // Clusterify - Transform simple list of Cases into list of CaseRange's
//
// Cases - output vector; one CaseRange is appended per switch successor
//         (starting at index 1), then the vector is sorted and neighbouring
//         ranges are merged.
// SI    - the switch instruction being lowered.
// The result, numCmps, counts one compare per resulting range, with
// multi-value ranges counting double (see the last loop).
// NOTE(review): the body of the merge branch, any guard around the merge loop,
// the extra increment for ranges and the final return are not visible in this
// listing.
237 unsigned LowerSwitch::Clusterify(CaseVector
& Cases
, SwitchInst
*SI
) {
238 unsigned numCmps
= 0;
240 // Start with "simple" cases
// Successor 0 is skipped.
// NOTE(review): presumably successor 0 is the default destination, which
// processSwitchInst obtains separately through getDefaultDest() -- confirm.
241 for (unsigned i
= 1; i
< SI
->getNumSuccessors(); ++i
)
// Each case starts out as a one-value range: Low and High are the same
// constant.
242 Cases
.push_back(CaseRange(SI
->getSuccessorValue(i
),
243 SI
->getSuccessorValue(i
),
244 SI
->getSuccessor(i
)));
// Order the ranges with CaseCmp (signed comparison of the bounds).
245 std::sort(Cases
.begin(), Cases
.end(), CaseCmp());
247 // Merge case into clusters
// I is the current cluster, J the candidate that follows it.
// NOTE(review): next(Cases.begin()) is not valid for an empty vector; a size
// guard, if there is one, sits on a line not visible in this listing -- confirm.
249 for (CaseItr I
=Cases
.begin(), J
=next(Cases
.begin()); J
!=Cases
.end(); ) {
// Bounds are read as sign-extended 64-bit integers.
250 int64_t nextValue
= cast
<ConstantInt
>(J
->Low
)->getSExtValue();
251 int64_t currentValue
= cast
<ConstantInt
>(I
->High
)->getSExtValue();
252 BasicBlock
* nextBB
= J
->BB
;
253 BasicBlock
* currentBB
= I
->BB
;
255 // If the two neighboring cases go to the same destination, merge them
256 // into a single case.
// Mergeable only when J starts exactly one above I's upper bound and both go
// to the same block.
// NOTE(review): nextValue-currentValue is a signed 64-bit subtraction and could
// overflow for bounds at opposite extremes -- confirm whether that can occur.
257 if ((nextValue
-currentValue
==1) && (currentBB
== nextBB
)) {
// Count compares: the loop header adds one per range.
265 for (CaseItr I
=Cases
.begin(), E
=Cases
.end(); I
!=E
; ++I
, ++numCmps
) {
// Pointer inequality of the bound constants marks a multi-value range.
266 if (I
->Low
!= I
->High
)
267 // A range counts double, since it requires two compares.
274 // processSwitchInst - Replace the specified switch instruction with a sequence
275 // of chained if-then insts in a balanced binary search.
//
// SI - the switch to lower. It is erased from its block by this function.
// Steps visible here: handle the default-only switch, create a "NewDefault"
// forwarding block, fix PHI nodes in the old default, cluster the cases, build
// the decision tree, branch to it, and delete the switch.
// NOTE(review): the declaration of Cases and the exit from the default-only
// branch are on lines not visible in this listing.
277 void LowerSwitch::processSwitchInst(SwitchInst
*SI
) {
278 BasicBlock
*CurBlock
= SI
->getParent();
// OrigBlock is the same block as CurBlock; the name is the one used by the
// tree-building helpers.
279 BasicBlock
*OrigBlock
= CurBlock
;
280 Function
*F
= CurBlock
->getParent();
281 Value
*Val
= SI
->getOperand(0); // The value we are switching on...
282 BasicBlock
* Default
= SI
->getDefaultDest();
284 // If there is only the default destination, don't bother with the code below.
// NOTE(review): two operands presumably means just the switched-on value plus
// the default destination -- confirm against SwitchInst's operand layout.
285 if (SI
->getNumOperands() == 2) {
// Replace the switch with an unconditional branch to the default block.
286 BranchInst::Create(SI
->getDefaultDest(), CurBlock
);
287 CurBlock
->getInstList().erase(SI
);
291 // Create a new, empty default block so that the new hierarchy of
292 // if-then statements go to this and the PHI nodes are happy.
293 BasicBlock
* NewDefault
= BasicBlock::Create("NewDefault");
// Inserted into the function's block list at Default's position.
294 F
->getBasicBlockList().insert(Default
, NewDefault
);
// NewDefault's only content is an unconditional branch on to the real default.
296 BranchInst::Create(Default
, NewDefault
);
298 // If there is an entry in any PHI nodes for the default edge, make sure
299 // to update them as well.
// The default edge now arrives from NewDefault instead of OrigBlock, so the
// PHI entry naming OrigBlock is retargeted.
300 for (BasicBlock::iterator I
= Default
->begin(); isa
<PHINode
>(I
); ++I
) {
301 PHINode
*PN
= cast
<PHINode
>(I
);
302 int BlockIdx
= PN
->getBasicBlockIndex(OrigBlock
);
303 assert(BlockIdx
!= -1 && "Switch didn't go to this successor??");
304 PN
->setIncomingBlock((unsigned)BlockIdx
, NewDefault
);
307 // Prepare cases vector.
// Clusterify fills and sorts Cases and reports the compare count, which is
// only used for the debug output below.
309 unsigned numCmps
= Clusterify(Cases
, SI
);
311 DEBUG(errs() << "Clusterify finished. Total clusters: " << Cases
.size()
312 << ". Total compares: " << numCmps
<< "\n");
313 DEBUG(errs() << "Cases: " << Cases
<< "\n");
// Build the decision tree; every failed test inside it leads to NewDefault.
316 BasicBlock
* SwitchBlock
= switchConvert(Cases
.begin(), Cases
.end(), Val
,
317 OrigBlock
, NewDefault
);
319 // Branch to our shiny new if-then stuff...
// The branch is created in OrigBlock while the switch is still there; the
// switch is erased right afterwards.
320 BranchInst::Create(SwitchBlock
, OrigBlock
);
322 // We are now done with the switch instruction, delete it.
323 CurBlock
->getInstList().erase(SI
);