1 //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file contains switch instruction lowering optimizations and utilities
// for codegen, so that they can be shared by both SelectionDAG and GlobalISel.
12 //===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/SwitchLoweringUtils.h"

using namespace llvm;
using namespace SwitchCG;
20 uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector
&Clusters
,
21 unsigned First
, unsigned Last
) {
22 assert(Last
>= First
);
23 const APInt
&LowCase
= Clusters
[First
].Low
->getValue();
24 const APInt
&HighCase
= Clusters
[Last
].High
->getValue();
25 assert(LowCase
.getBitWidth() == HighCase
.getBitWidth());
27 // FIXME: A range of consecutive cases has 100% density, but only requires one
28 // comparison to lower. We should discriminate against such consecutive ranges
30 return (HighCase
- LowCase
).getLimitedValue((UINT64_MAX
- 1) / 100) + 1;
34 SwitchCG::getJumpTableNumCases(const SmallVectorImpl
<unsigned> &TotalCases
,
35 unsigned First
, unsigned Last
) {
36 assert(Last
>= First
);
37 assert(TotalCases
[Last
] >= TotalCases
[First
]);
39 TotalCases
[Last
] - (First
== 0 ? 0 : TotalCases
[First
- 1]);
43 void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector
&Clusters
,
45 MachineBasicBlock
*DefaultMBB
) {
47 // Clusters must be non-empty, sorted, and only contain Range clusters.
48 assert(!Clusters
.empty());
49 for (CaseCluster
&C
: Clusters
)
50 assert(C
.Kind
== CC_Range
);
51 for (unsigned i
= 1, e
= Clusters
.size(); i
< e
; ++i
)
52 assert(Clusters
[i
- 1].High
->getValue().slt(Clusters
[i
].Low
->getValue()));
55 assert(TLI
&& "TLI not set!");
56 if (!TLI
->areJTsAllowed(SI
->getParent()->getParent()))
59 const unsigned MinJumpTableEntries
= TLI
->getMinimumJumpTableEntries();
60 const unsigned SmallNumberOfEntries
= MinJumpTableEntries
/ 2;
62 // Bail if not enough cases.
63 const int64_t N
= Clusters
.size();
64 if (N
< 2 || N
< MinJumpTableEntries
)
67 // Accumulated number of cases in each cluster and those prior to it.
68 SmallVector
<unsigned, 8> TotalCases(N
);
69 for (unsigned i
= 0; i
< N
; ++i
) {
70 const APInt
&Hi
= Clusters
[i
].High
->getValue();
71 const APInt
&Lo
= Clusters
[i
].Low
->getValue();
72 TotalCases
[i
] = (Hi
- Lo
).getLimitedValue() + 1;
74 TotalCases
[i
] += TotalCases
[i
- 1];
77 uint64_t Range
= getJumpTableRange(Clusters
,0, N
- 1);
78 uint64_t NumCases
= getJumpTableNumCases(TotalCases
, 0, N
- 1);
79 assert(NumCases
< UINT64_MAX
/ 100);
80 assert(Range
>= NumCases
);
82 // Cheap case: the whole range may be suitable for jump table.
83 if (TLI
->isSuitableForJumpTable(SI
, NumCases
, Range
)) {
84 CaseCluster JTCluster
;
85 if (buildJumpTable(Clusters
, 0, N
- 1, SI
, DefaultMBB
, JTCluster
)) {
86 Clusters
[0] = JTCluster
;
92 // The algorithm below is not suitable for -O0.
93 if (TM
->getOptLevel() == CodeGenOpt::None
)
96 // Split Clusters into minimum number of dense partitions. The algorithm uses
97 // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
98 // for the Case Statement'" (1994), but builds the MinPartitions array in
99 // reverse order to make it easier to reconstruct the partitions in ascending
100 // order. In the choice between two optimal partitionings, it picks the one
101 // which yields more jump tables.
103 // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
104 SmallVector
<unsigned, 8> MinPartitions(N
);
105 // LastElement[i] is the last element of the partition starting at i.
106 SmallVector
<unsigned, 8> LastElement(N
);
107 // PartitionsScore[i] is used to break ties when choosing between two
108 // partitionings resulting in the same number of partitions.
109 SmallVector
<unsigned, 8> PartitionsScore(N
);
110 // For PartitionsScore, a small number of comparisons is considered as good as
111 // a jump table and a single comparison is considered better than a jump
113 enum PartitionScores
: unsigned {
120 // Base case: There is only one way to partition Clusters[N-1].
121 MinPartitions
[N
- 1] = 1;
122 LastElement
[N
- 1] = N
- 1;
123 PartitionsScore
[N
- 1] = PartitionScores::SingleCase
;
125 // Note: loop indexes are signed to avoid underflow.
126 for (int64_t i
= N
- 2; i
>= 0; i
--) {
127 // Find optimal partitioning of Clusters[i..N-1].
128 // Baseline: Put Clusters[i] into a partition on its own.
129 MinPartitions
[i
] = MinPartitions
[i
+ 1] + 1;
131 PartitionsScore
[i
] = PartitionsScore
[i
+ 1] + PartitionScores::SingleCase
;
133 // Search for a solution that results in fewer partitions.
134 for (int64_t j
= N
- 1; j
> i
; j
--) {
135 // Try building a partition from Clusters[i..j].
136 Range
= getJumpTableRange(Clusters
, i
, j
);
137 NumCases
= getJumpTableNumCases(TotalCases
, i
, j
);
138 assert(NumCases
< UINT64_MAX
/ 100);
139 assert(Range
>= NumCases
);
141 if (TLI
->isSuitableForJumpTable(SI
, NumCases
, Range
)) {
142 unsigned NumPartitions
= 1 + (j
== N
- 1 ? 0 : MinPartitions
[j
+ 1]);
143 unsigned Score
= j
== N
- 1 ? 0 : PartitionsScore
[j
+ 1];
144 int64_t NumEntries
= j
- i
+ 1;
147 Score
+= PartitionScores::SingleCase
;
148 else if (NumEntries
<= SmallNumberOfEntries
)
149 Score
+= PartitionScores::FewCases
;
150 else if (NumEntries
>= MinJumpTableEntries
)
151 Score
+= PartitionScores::Table
;
153 // If this leads to fewer partitions, or to the same number of
154 // partitions with better score, it is a better partitioning.
155 if (NumPartitions
< MinPartitions
[i
] ||
156 (NumPartitions
== MinPartitions
[i
] && Score
> PartitionsScore
[i
])) {
157 MinPartitions
[i
] = NumPartitions
;
159 PartitionsScore
[i
] = Score
;
165 // Iterate over the partitions, replacing some with jump tables in-place.
166 unsigned DstIndex
= 0;
167 for (unsigned First
= 0, Last
; First
< N
; First
= Last
+ 1) {
168 Last
= LastElement
[First
];
169 assert(Last
>= First
);
170 assert(DstIndex
<= First
);
171 unsigned NumClusters
= Last
- First
+ 1;
173 CaseCluster JTCluster
;
174 if (NumClusters
>= MinJumpTableEntries
&&
175 buildJumpTable(Clusters
, First
, Last
, SI
, DefaultMBB
, JTCluster
)) {
176 Clusters
[DstIndex
++] = JTCluster
;
178 for (unsigned I
= First
; I
<= Last
; ++I
)
179 std::memmove(&Clusters
[DstIndex
++], &Clusters
[I
], sizeof(Clusters
[I
]));
182 Clusters
.resize(DstIndex
);
185 bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector
&Clusters
,
186 unsigned First
, unsigned Last
,
187 const SwitchInst
*SI
,
188 MachineBasicBlock
*DefaultMBB
,
189 CaseCluster
&JTCluster
) {
190 assert(First
<= Last
);
192 auto Prob
= BranchProbability::getZero();
193 unsigned NumCmps
= 0;
194 std::vector
<MachineBasicBlock
*> Table
;
195 DenseMap
<MachineBasicBlock
*, BranchProbability
> JTProbs
;
197 // Initialize probabilities in JTProbs.
198 for (unsigned I
= First
; I
<= Last
; ++I
)
199 JTProbs
[Clusters
[I
].MBB
] = BranchProbability::getZero();
201 for (unsigned I
= First
; I
<= Last
; ++I
) {
202 assert(Clusters
[I
].Kind
== CC_Range
);
203 Prob
+= Clusters
[I
].Prob
;
204 const APInt
&Low
= Clusters
[I
].Low
->getValue();
205 const APInt
&High
= Clusters
[I
].High
->getValue();
206 NumCmps
+= (Low
== High
) ? 1 : 2;
208 // Fill the gap between this and the previous cluster.
209 const APInt
&PreviousHigh
= Clusters
[I
- 1].High
->getValue();
210 assert(PreviousHigh
.slt(Low
));
211 uint64_t Gap
= (Low
- PreviousHigh
).getLimitedValue() - 1;
212 for (uint64_t J
= 0; J
< Gap
; J
++)
213 Table
.push_back(DefaultMBB
);
215 uint64_t ClusterSize
= (High
- Low
).getLimitedValue() + 1;
216 for (uint64_t J
= 0; J
< ClusterSize
; ++J
)
217 Table
.push_back(Clusters
[I
].MBB
);
218 JTProbs
[Clusters
[I
].MBB
] += Clusters
[I
].Prob
;
221 unsigned NumDests
= JTProbs
.size();
222 if (TLI
->isSuitableForBitTests(NumDests
, NumCmps
,
223 Clusters
[First
].Low
->getValue(),
224 Clusters
[Last
].High
->getValue(), *DL
)) {
225 // Clusters[First..Last] should be lowered as bit tests instead.
229 // Create the MBB that will load from and jump through the table.
230 // Note: We create it here, but it's not inserted into the function yet.
231 MachineFunction
*CurMF
= FuncInfo
.MF
;
232 MachineBasicBlock
*JumpTableMBB
=
233 CurMF
->CreateMachineBasicBlock(SI
->getParent());
235 // Add successors. Note: use table order for determinism.
236 SmallPtrSet
<MachineBasicBlock
*, 8> Done
;
237 for (MachineBasicBlock
*Succ
: Table
) {
238 if (Done
.count(Succ
))
240 addSuccessorWithProb(JumpTableMBB
, Succ
, JTProbs
[Succ
]);
243 JumpTableMBB
->normalizeSuccProbs();
245 unsigned JTI
= CurMF
->getOrCreateJumpTableInfo(TLI
->getJumpTableEncoding())
246 ->createJumpTableIndex(Table
);
248 // Set up the jump table info.
249 JumpTable
JT(-1U, JTI
, JumpTableMBB
, nullptr);
250 JumpTableHeader
JTH(Clusters
[First
].Low
->getValue(),
251 Clusters
[Last
].High
->getValue(), SI
->getCondition(),
253 JTCases
.emplace_back(std::move(JTH
), std::move(JT
));
255 JTCluster
= CaseCluster::jumpTable(Clusters
[First
].Low
, Clusters
[Last
].High
,
256 JTCases
.size() - 1, Prob
);
260 void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector
&Clusters
,
261 const SwitchInst
*SI
) {
262 // Partition Clusters into as few subsets as possible, where each subset has a
263 // range that fits in a machine word and has <= 3 unique destinations.
266 // Clusters must be sorted and contain Range or JumpTable clusters.
267 assert(!Clusters
.empty());
268 assert(Clusters
[0].Kind
== CC_Range
|| Clusters
[0].Kind
== CC_JumpTable
);
269 for (const CaseCluster
&C
: Clusters
)
270 assert(C
.Kind
== CC_Range
|| C
.Kind
== CC_JumpTable
);
271 for (unsigned i
= 1; i
< Clusters
.size(); ++i
)
272 assert(Clusters
[i
-1].High
->getValue().slt(Clusters
[i
].Low
->getValue()));
275 // The algorithm below is not suitable for -O0.
276 if (TM
->getOptLevel() == CodeGenOpt::None
)
279 // If target does not have legal shift left, do not emit bit tests at all.
280 EVT PTy
= TLI
->getPointerTy(*DL
);
281 if (!TLI
->isOperationLegal(ISD::SHL
, PTy
))
284 int BitWidth
= PTy
.getSizeInBits();
285 const int64_t N
= Clusters
.size();
287 // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
288 SmallVector
<unsigned, 8> MinPartitions(N
);
289 // LastElement[i] is the last element of the partition starting at i.
290 SmallVector
<unsigned, 8> LastElement(N
);
292 // FIXME: This might not be the best algorithm for finding bit test clusters.
294 // Base case: There is only one way to partition Clusters[N-1].
295 MinPartitions
[N
- 1] = 1;
296 LastElement
[N
- 1] = N
- 1;
298 // Note: loop indexes are signed to avoid underflow.
299 for (int64_t i
= N
- 2; i
>= 0; --i
) {
300 // Find optimal partitioning of Clusters[i..N-1].
301 // Baseline: Put Clusters[i] into a partition on its own.
302 MinPartitions
[i
] = MinPartitions
[i
+ 1] + 1;
305 // Search for a solution that results in fewer partitions.
306 // Note: the search is limited by BitWidth, reducing time complexity.
307 for (int64_t j
= std::min(N
- 1, i
+ BitWidth
- 1); j
> i
; --j
) {
308 // Try building a partition from Clusters[i..j].
311 if (!TLI
->rangeFitsInWord(Clusters
[i
].Low
->getValue(),
312 Clusters
[j
].High
->getValue(), *DL
))
315 // Check nbr of destinations and cluster types.
316 // FIXME: This works, but doesn't seem very efficient.
317 bool RangesOnly
= true;
318 BitVector
Dests(FuncInfo
.MF
->getNumBlockIDs());
319 for (int64_t k
= i
; k
<= j
; k
++) {
320 if (Clusters
[k
].Kind
!= CC_Range
) {
324 Dests
.set(Clusters
[k
].MBB
->getNumber());
326 if (!RangesOnly
|| Dests
.count() > 3)
329 // Check if it's a better partition.
330 unsigned NumPartitions
= 1 + (j
== N
- 1 ? 0 : MinPartitions
[j
+ 1]);
331 if (NumPartitions
< MinPartitions
[i
]) {
332 // Found a better partition.
333 MinPartitions
[i
] = NumPartitions
;
339 // Iterate over the partitions, replacing with bit-test clusters in-place.
340 unsigned DstIndex
= 0;
341 for (unsigned First
= 0, Last
; First
< N
; First
= Last
+ 1) {
342 Last
= LastElement
[First
];
343 assert(First
<= Last
);
344 assert(DstIndex
<= First
);
346 CaseCluster BitTestCluster
;
347 if (buildBitTests(Clusters
, First
, Last
, SI
, BitTestCluster
)) {
348 Clusters
[DstIndex
++] = BitTestCluster
;
350 size_t NumClusters
= Last
- First
+ 1;
351 std::memmove(&Clusters
[DstIndex
], &Clusters
[First
],
352 sizeof(Clusters
[0]) * NumClusters
);
353 DstIndex
+= NumClusters
;
356 Clusters
.resize(DstIndex
);
359 bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector
&Clusters
,
360 unsigned First
, unsigned Last
,
361 const SwitchInst
*SI
,
362 CaseCluster
&BTCluster
) {
363 assert(First
<= Last
);
367 BitVector
Dests(FuncInfo
.MF
->getNumBlockIDs());
368 unsigned NumCmps
= 0;
369 for (int64_t I
= First
; I
<= Last
; ++I
) {
370 assert(Clusters
[I
].Kind
== CC_Range
);
371 Dests
.set(Clusters
[I
].MBB
->getNumber());
372 NumCmps
+= (Clusters
[I
].Low
== Clusters
[I
].High
) ? 1 : 2;
374 unsigned NumDests
= Dests
.count();
376 APInt Low
= Clusters
[First
].Low
->getValue();
377 APInt High
= Clusters
[Last
].High
->getValue();
378 assert(Low
.slt(High
));
380 if (!TLI
->isSuitableForBitTests(NumDests
, NumCmps
, Low
, High
, *DL
))
386 const int BitWidth
= TLI
->getPointerTy(*DL
).getSizeInBits();
387 assert(TLI
->rangeFitsInWord(Low
, High
, *DL
) &&
388 "Case range must fit in bit mask!");
390 // Check if the clusters cover a contiguous range such that no value in the
391 // range will jump to the default statement.
392 bool ContiguousRange
= true;
393 for (int64_t I
= First
+ 1; I
<= Last
; ++I
) {
394 if (Clusters
[I
].Low
->getValue() != Clusters
[I
- 1].High
->getValue() + 1) {
395 ContiguousRange
= false;
400 if (Low
.isStrictlyPositive() && High
.slt(BitWidth
)) {
401 // Optimize the case where all the case values fit in a word without having
402 // to subtract minValue. In this case, we can optimize away the subtraction.
403 LowBound
= APInt::getNullValue(Low
.getBitWidth());
405 ContiguousRange
= false;
408 CmpRange
= High
- Low
;
412 auto TotalProb
= BranchProbability::getZero();
413 for (unsigned i
= First
; i
<= Last
; ++i
) {
414 // Find the CaseBits for this destination.
416 for (j
= 0; j
< CBV
.size(); ++j
)
417 if (CBV
[j
].BB
== Clusters
[i
].MBB
)
421 CaseBits(0, Clusters
[i
].MBB
, 0, BranchProbability::getZero()));
422 CaseBits
*CB
= &CBV
[j
];
424 // Update Mask, Bits and ExtraProb.
425 uint64_t Lo
= (Clusters
[i
].Low
->getValue() - LowBound
).getZExtValue();
426 uint64_t Hi
= (Clusters
[i
].High
->getValue() - LowBound
).getZExtValue();
427 assert(Hi
>= Lo
&& Hi
< 64 && "Invalid bit case!");
428 CB
->Mask
|= (-1ULL >> (63 - (Hi
- Lo
))) << Lo
;
429 CB
->Bits
+= Hi
- Lo
+ 1;
430 CB
->ExtraProb
+= Clusters
[i
].Prob
;
431 TotalProb
+= Clusters
[i
].Prob
;
435 llvm::sort(CBV
, [](const CaseBits
&a
, const CaseBits
&b
) {
436 // Sort by probability first, number of bits second, bit mask third.
437 if (a
.ExtraProb
!= b
.ExtraProb
)
438 return a
.ExtraProb
> b
.ExtraProb
;
439 if (a
.Bits
!= b
.Bits
)
440 return a
.Bits
> b
.Bits
;
441 return a
.Mask
< b
.Mask
;
444 for (auto &CB
: CBV
) {
445 MachineBasicBlock
*BitTestBB
=
446 FuncInfo
.MF
->CreateMachineBasicBlock(SI
->getParent());
447 BTI
.push_back(BitTestCase(CB
.Mask
, BitTestBB
, CB
.BB
, CB
.ExtraProb
));
449 BitTestCases
.emplace_back(std::move(LowBound
), std::move(CmpRange
),
450 SI
->getCondition(), -1U, MVT::Other
, false,
451 ContiguousRange
, nullptr, nullptr, std::move(BTI
),
454 BTCluster
= CaseCluster::bitTests(Clusters
[First
].Low
, Clusters
[Last
].High
,
455 BitTestCases
.size() - 1, TotalProb
);
459 void SwitchCG::sortAndRangeify(CaseClusterVector
&Clusters
) {
461 for (const CaseCluster
&CC
: Clusters
)
462 assert(CC
.Low
== CC
.High
&& "Input clusters must be single-case");
465 llvm::sort(Clusters
, [](const CaseCluster
&a
, const CaseCluster
&b
) {
466 return a
.Low
->getValue().slt(b
.Low
->getValue());
469 // Merge adjacent clusters with the same destination.
470 const unsigned N
= Clusters
.size();
471 unsigned DstIndex
= 0;
472 for (unsigned SrcIndex
= 0; SrcIndex
< N
; ++SrcIndex
) {
473 CaseCluster
&CC
= Clusters
[SrcIndex
];
474 const ConstantInt
*CaseVal
= CC
.Low
;
475 MachineBasicBlock
*Succ
= CC
.MBB
;
477 if (DstIndex
!= 0 && Clusters
[DstIndex
- 1].MBB
== Succ
&&
478 (CaseVal
->getValue() - Clusters
[DstIndex
- 1].High
->getValue()) == 1) {
479 // If this case has the same successor and is a neighbour, merge it into
480 // the previous cluster.
481 Clusters
[DstIndex
- 1].High
= CaseVal
;
482 Clusters
[DstIndex
- 1].Prob
+= CC
.Prob
;
484 std::memmove(&Clusters
[DstIndex
++], &Clusters
[SrcIndex
],
485 sizeof(Clusters
[SrcIndex
]));
488 Clusters
.resize(DstIndex
);