//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SwitchLoweringUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace SwitchCG;

uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
                                     unsigned First, unsigned Last) {
  assert(Last >= First);
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  assert(LowCase.getBitWidth() == HighCase.getBitWidth());

  // FIXME: A range of consecutive cases has 100% density, but only requires one
  // comparison to lower. We should discriminate against such consecutive ranges
  // in jump tables.
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}
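
// Worked example for getJumpTableRange (illustrative, not from the original
// source): clusters covering [1, 3] and [7, 9] give Range = (9 - 1) + 1 = 9,
// even though only six case values exist. The getLimitedValue() clamp above
// keeps Range small enough that the later density math, which scales by 100,
// cannot overflow a uint64_t.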

uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
  uint64_t NumCases =
      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  return NumCases;
}
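
// Worked example for getJumpTableNumCases (illustrative): with running sums
// TotalCases = {3, 5, 9}, calling it with First = 1, Last = 2 returns
// 9 - 3 = 6, the number of case values covered by clusters 1 and 2.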

void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }
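
  // Worked example (illustrative): clusters [1..3] and [5..5] yield
  // TotalCases = {3, 4}: three cases in the first cluster, plus one more
  // accumulated in the second.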

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOptLevel::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.
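  //
  // Illustrative walk-through (hypothetical suitability): suppose N = 4 and
  // only Clusters[1..3] forms a suitable jump table. Scanning backwards, the
  // loop below computes MinPartitions = {2, 1, 2, 1} with LastElement[0] = 0
  // and LastElement[1] = 3, so the reconstruction pass emits the partitions
  // {0} and {1..3}.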

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock *> Table;
  DenseMap<MachineBasicBlock *, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }
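
  // Illustrative example (not from the original source): for clusters
  // 1..2 -> A and 5..5 -> B with default block D, the loop above produces
  // Table = [A, A, D, D, B], one entry for each value in [1, 5].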

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}

void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.
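  //
  // Illustrative example (not from the original source): on a 64-bit target,
  // cases {2, 4, 6} -> A and {3, 5} -> B span only five values and two
  // destinations, so one partition suffices and each destination is matched
  // with a single shift-and-mask test instead of five comparisons.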

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOptLevel::None)
    return;

  // If target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;
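
      // Note (assumption about rangeFitsInWord): the check above accepts a
      // span only when High - Low + 1 fits in a machine word, e.g. at most
      // 64 consecutive values on a 64-bit target.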

      // Check nbr of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }
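
  // For example (illustrative), cases {1, 2, 3} form a contiguous range,
  // while {1, 3} leave a hole at 2 that must fall through to the default
  // block.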

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getZero(Low.getBitWidth());
    CmpRange = High;
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }
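
  // Illustrative example (not from the original source): on a 64-bit target,
  // cases {1, 10, 63} take the branch above and the bit test can shift by the
  // switch value directly; cases {100, 101} take the else branch and are
  // rebased by subtracting LowBound = 100 first.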

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
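    // Worked example: with Lo = 2 and Hi = 4, this is
    // (-1ULL >> 61) << 2 == 0b11100, i.e. bits 2 through 4 set.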
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}

void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    return a.Low->getValue().slt(b.Low->getValue());
  });

  // Merge adjacent clusters with the same destination.
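  // For example (illustrative), the single-case clusters 1 -> A, 2 -> A and
  // 3 -> B become the range clusters [1, 2] -> A and [3, 3] -> B.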
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    CaseCluster &CC = Clusters[SrcIndex];
    const ConstantInt *CaseVal = CC.Low;
    MachineBasicBlock *Succ = CC.MBB;

    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      // If this case has the same successor and is a neighbour, merge it into
      // the previous cluster.
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);
}