1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/SelectionDAGNodes.h"
14 #include "llvm/CodeGen/TargetLowering.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/Support/BranchProbability.h"
20 class FunctionLoweringInfo
;
21 class MachineBasicBlock
;
25 enum CaseClusterKind
{
26 /// A cluster of adjacent case labels with the same destination, or just one
29 /// A cluster of cases suitable for jump table lowering.
31 /// A cluster of cases suitable for bit test lowering.
35 /// A cluster of case labels.
38 const ConstantInt
*Low
, *High
;
40 MachineBasicBlock
*MBB
;
41 unsigned JTCasesIndex
;
42 unsigned BTCasesIndex
;
44 BranchProbability Prob
;
46 static CaseCluster
range(const ConstantInt
*Low
, const ConstantInt
*High
,
47 MachineBasicBlock
*MBB
, BranchProbability Prob
) {
57 static CaseCluster
jumpTable(const ConstantInt
*Low
, const ConstantInt
*High
,
58 unsigned JTCasesIndex
, BranchProbability Prob
) {
60 C
.Kind
= CC_JumpTable
;
63 C
.JTCasesIndex
= JTCasesIndex
;
68 static CaseCluster
bitTests(const ConstantInt
*Low
, const ConstantInt
*High
,
69 unsigned BTCasesIndex
, BranchProbability Prob
) {
74 C
.BTCasesIndex
= BTCasesIndex
;
/// Storage for the case clusters of a switch being lowered.
using CaseClusterVector = std::vector<CaseCluster>;
/// Iterator into a CaseClusterVector; used, e.g., to delimit a run of clusters.
using CaseClusterIt = CaseClusterVector::iterator;
/// Sort Clusters and merge adjacent cases.
///
/// \p Clusters is taken by non-const reference and is rewritten in place.
/// NOTE(review): merging presumably shrinks the vector — confirm against the
/// definition.
void sortAndRangeify(CaseClusterVector &Clusters);
88 MachineBasicBlock
*BB
= nullptr;
90 BranchProbability ExtraProb
;
93 CaseBits(uint64_t mask
, MachineBasicBlock
*bb
, unsigned bits
,
94 BranchProbability Prob
)
95 : Mask(mask
), BB(bb
), Bits(bits
), ExtraProb(Prob
) {}
/// A list of CaseBits records.
using CaseBitsVector = std::vector<CaseBits>;
100 /// This structure is used to communicate between SelectionDAGBuilder and
101 /// SDISel for the code generation of additional basic blocks needed by
102 /// multi-case switch statements.
104 // For the GISel interface.
105 struct PredInfoPair
{
106 CmpInst::Predicate Pred
;
107 // Set when no comparison should be emitted.
111 // The condition code to use for the case block's setcc node.
112 // Besides the integer condition codes, this can also be SETTRUE, in which
113 // case no comparison gets emitted.
115 struct PredInfoPair PredInfo
;
118 // The LHS/MHS/RHS of the comparison to emit.
119 // Emit by default LHS op RHS. MHS is used for range comparisons:
120 // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
121 const Value
*CmpLHS
, *CmpMHS
, *CmpRHS
;
123 // The block to branch to if the setcc is true/false.
124 MachineBasicBlock
*TrueBB
, *FalseBB
;
126 // The block into which to emit the code for the setcc and branches.
127 MachineBasicBlock
*ThisBB
;
129 /// The debug location of the instruction this CaseBlock was
135 BranchProbability TrueProb
, FalseProb
;
137 // Constructor for SelectionDAG.
138 CaseBlock(ISD::CondCode cc
, const Value
*cmplhs
, const Value
*cmprhs
,
139 const Value
*cmpmiddle
, MachineBasicBlock
*truebb
,
140 MachineBasicBlock
*falsebb
, MachineBasicBlock
*me
, SDLoc dl
,
141 BranchProbability trueprob
= BranchProbability::getUnknown(),
142 BranchProbability falseprob
= BranchProbability::getUnknown())
143 : CC(cc
), CmpLHS(cmplhs
), CmpMHS(cmpmiddle
), CmpRHS(cmprhs
),
144 TrueBB(truebb
), FalseBB(falsebb
), ThisBB(me
), DL(dl
),
145 TrueProb(trueprob
), FalseProb(falseprob
) {}
147 // Constructor for GISel.
148 CaseBlock(CmpInst::Predicate pred
, bool nocmp
, const Value
*cmplhs
,
149 const Value
*cmprhs
, const Value
*cmpmiddle
,
150 MachineBasicBlock
*truebb
, MachineBasicBlock
*falsebb
,
151 MachineBasicBlock
*me
, DebugLoc dl
,
152 BranchProbability trueprob
= BranchProbability::getUnknown(),
153 BranchProbability falseprob
= BranchProbability::getUnknown())
154 : PredInfo({pred
, nocmp
}), CmpLHS(cmplhs
), CmpMHS(cmpmiddle
),
155 CmpRHS(cmprhs
), TrueBB(truebb
), FalseBB(falsebb
), ThisBB(me
),
156 DbgLoc(dl
), TrueProb(trueprob
), FalseProb(falseprob
) {}
160 /// The virtual register containing the index of the jump table entry
163 /// The JumpTableIndex for this jump table in the function.
165 /// The MBB into which to emit the code for the indirect jump.
166 MachineBasicBlock
*MBB
;
/// The MBB of the default bb, which is a successor of the range
/// check MBB. This is used when updating PHI nodes in successors.
169 MachineBasicBlock
*Default
;
171 JumpTable(unsigned R
, unsigned J
, MachineBasicBlock
*M
, MachineBasicBlock
*D
)
172 : Reg(R
), JTI(J
), MBB(M
), Default(D
) {}
174 struct JumpTableHeader
{
178 MachineBasicBlock
*HeaderBB
;
182 JumpTableHeader(APInt F
, APInt L
, const Value
*SV
, MachineBasicBlock
*H
,
184 : First(std::move(F
)), Last(std::move(L
)), SValue(SV
), HeaderBB(H
),
185 Emitted(E
), OmitRangeCheck(false) {}
/// One jump table lowering record: the JumpTableHeader (header block and
/// range-check data) paired with the JumpTable it guards.
using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
191 MachineBasicBlock
*ThisBB
;
192 MachineBasicBlock
*TargetBB
;
193 BranchProbability ExtraProb
;
195 BitTestCase(uint64_t M
, MachineBasicBlock
*T
, MachineBasicBlock
*Tr
,
196 BranchProbability Prob
)
197 : Mask(M
), ThisBB(T
), TargetBB(Tr
), ExtraProb(Prob
) {}
/// The list of BitTestCase entries owned by a BitTestBlock. It is a SmallVector
/// with inline capacity for 3 entries.
using BitTestInfo = SmallVector<BitTestCase, 3>;
202 struct BitTestBlock
{
209 bool ContiguousRange
;
210 MachineBasicBlock
*Parent
;
211 MachineBasicBlock
*Default
;
213 BranchProbability Prob
;
214 BranchProbability DefaultProb
;
216 BitTestBlock(APInt F
, APInt R
, const Value
*SV
, unsigned Rg
, MVT RgVT
, bool E
,
217 bool CR
, MachineBasicBlock
*P
, MachineBasicBlock
*D
,
218 BitTestInfo C
, BranchProbability Pr
)
219 : First(std::move(F
)), Range(std::move(R
)), SValue(SV
), Reg(Rg
),
220 RegVT(RgVT
), Emitted(E
), ContiguousRange(CR
), Parent(P
), Default(D
),
221 Cases(std::move(C
)), Prob(Pr
) {}
/// Return the range of values spanned by a range of clusters.
225 uint64_t getJumpTableRange(const CaseClusterVector
&Clusters
, unsigned First
,
/// Return the number of cases within a range of clusters.
///
/// \param TotalCases per-cluster case counts. NOTE(review): presumably a
///        running (prefix) total indexed by cluster position — confirm
///        against the definition.
/// \param First index of the first cluster in the range.
/// \param Last index of the last cluster in the range (inclusivity to be
///        confirmed against the definition).
uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                              unsigned First, unsigned Last);
232 struct SwitchWorkListItem
{
233 MachineBasicBlock
*MBB
;
234 CaseClusterIt FirstCluster
;
235 CaseClusterIt LastCluster
;
236 const ConstantInt
*GE
;
237 const ConstantInt
*LT
;
238 BranchProbability DefaultProb
;
/// Pending switch lowering work items. It is a SmallVector with inline
/// capacity for 4 entries.
using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
242 class SwitchLowering
{
244 SwitchLowering(FunctionLoweringInfo
&funcinfo
) : FuncInfo(funcinfo
) {}
246 void init(const TargetLowering
&tli
, const TargetMachine
&tm
,
247 const DataLayout
&dl
) {
/// Vector of CaseBlock structures used to communicate SwitchInst code
/// generation information.
std::vector<CaseBlock> SwitchCases;

/// Vector of JumpTableBlock structures (each a JumpTableHeader/JumpTable
/// pair) used to communicate SwitchInst code generation information.
/// NOTE(review): CaseCluster::JTCasesIndex presumably indexes this vector.
std::vector<JumpTableBlock> JTCases;

/// Vector of BitTestBlock structures used to communicate SwitchInst code
/// generation information.
/// NOTE(review): CaseCluster::BTCasesIndex presumably indexes this vector.
std::vector<BitTestBlock> BitTestCases;
/// NOTE(review): per its name, presumably searches \p Clusters for runs that
/// can be lowered as jump tables. \p Clusters is taken by non-const
/// reference, so it may be rewritten in place. \p DefaultMBB is the switch's
/// default block.
void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
                    MachineBasicBlock *DefaultMBB);

/// Try to build a jump table cluster from Clusters[First..Last]. The result
/// is written through the out-parameter \p JTCluster.
/// NOTE(review): presumably returns false when no jump table is built —
/// confirm against the definition.
bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
                    unsigned Last, const SwitchInst *SI,
                    MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
/// NOTE(review): per its name, presumably searches \p Clusters for runs that
/// can be lowered as bit tests. \p Clusters is taken by non-const reference,
/// so it may be rewritten in place.
void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);

/// Build a bit test cluster from Clusters[First..Last]. Returns false if it
/// decides it's not a good idea. The built cluster is presumably written
/// through the out-parameter \p BTCluster.
bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
                   const SwitchInst *SI, CaseCluster &BTCluster);
/// Pure-virtual hook that derived classes must implement. Per its name,
/// presumably adds \p Dst as a successor of \p Src with probability \p Prob.
/// Note: default arguments on virtual functions bind to the static type of
/// the call, so overrides should keep the same default (unknown).
virtual void addSuccessorWithProb(
    MachineBasicBlock *Src, MachineBasicBlock *Dst,
    BranchProbability Prob = BranchProbability::getUnknown()) = 0;

/// Virtual so that derived lowerings are destroyed correctly through a
/// SwitchLowering pointer.
virtual ~SwitchLowering() = default;
287 const TargetLowering
*TLI
;
288 const TargetMachine
*TM
;
289 const DataLayout
*DL
;
290 FunctionLoweringInfo
&FuncInfo
;
293 } // namespace SwitchCG
296 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H