[Alignment][NFC] Support compile time constants
[llvm-core.git] / include / llvm / CodeGen / SwitchLoweringUtils.h
blob b8adcf759b197cdd71eafa6879e621ee46e4f78d
1 //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10 #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/SelectionDAGNodes.h"
14 #include "llvm/CodeGen/TargetLowering.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/Support/BranchProbability.h"
18 namespace llvm {
20 class FunctionLoweringInfo;
21 class MachineBasicBlock;
23 namespace SwitchCG {
/// The kinds of case cluster.
enum CaseClusterKind {
  /// Either a single case, or several adjacent case labels that all share
  /// one destination.
  CC_Range = 0,
  /// A group of cases that is suitable for jump table lowering.
  CC_JumpTable = 1,
  /// A group of cases that is suitable for bit test lowering.
  CC_BitTests = 2
};
35 /// A cluster of case labels.
36 struct CaseCluster {
37 CaseClusterKind Kind;
38 const ConstantInt *Low, *High;
39 union {
40 MachineBasicBlock *MBB;
41 unsigned JTCasesIndex;
42 unsigned BTCasesIndex;
44 BranchProbability Prob;
46 static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
47 MachineBasicBlock *MBB, BranchProbability Prob) {
48 CaseCluster C;
49 C.Kind = CC_Range;
50 C.Low = Low;
51 C.High = High;
52 C.MBB = MBB;
53 C.Prob = Prob;
54 return C;
57 static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
58 unsigned JTCasesIndex, BranchProbability Prob) {
59 CaseCluster C;
60 C.Kind = CC_JumpTable;
61 C.Low = Low;
62 C.High = High;
63 C.JTCasesIndex = JTCasesIndex;
64 C.Prob = Prob;
65 return C;
68 static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
69 unsigned BTCasesIndex, BranchProbability Prob) {
70 CaseCluster C;
71 C.Kind = CC_BitTests;
72 C.Low = Low;
73 C.High = High;
74 C.BTCasesIndex = BTCasesIndex;
75 C.Prob = Prob;
76 return C;
80 using CaseClusterVector = std::vector<CaseCluster>;
81 using CaseClusterIt = CaseClusterVector::iterator;
83 /// Sort Clusters and merge adjacent cases.
84 void sortAndRangeify(CaseClusterVector &Clusters);
86 struct CaseBits {
87 uint64_t Mask = 0;
88 MachineBasicBlock *BB = nullptr;
89 unsigned Bits = 0;
90 BranchProbability ExtraProb;
92 CaseBits() = default;
93 CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
94 BranchProbability Prob)
95 : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
98 using CaseBitsVector = std::vector<CaseBits>;
100 /// This structure is used to communicate between SelectionDAGBuilder and
101 /// SDISel for the code generation of additional basic blocks needed by
102 /// multi-case switch statements.
103 struct CaseBlock {
104 // For the GISel interface.
105 struct PredInfoPair {
106 CmpInst::Predicate Pred;
107 // Set when no comparison should be emitted.
108 bool NoCmp;
110 union {
111 // The condition code to use for the case block's setcc node.
112 // Besides the integer condition codes, this can also be SETTRUE, in which
113 // case no comparison gets emitted.
114 ISD::CondCode CC;
115 struct PredInfoPair PredInfo;
118 // The LHS/MHS/RHS of the comparison to emit.
119 // Emit by default LHS op RHS. MHS is used for range comparisons:
120 // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
121 const Value *CmpLHS, *CmpMHS, *CmpRHS;
123 // The block to branch to if the setcc is true/false.
124 MachineBasicBlock *TrueBB, *FalseBB;
126 // The block into which to emit the code for the setcc and branches.
127 MachineBasicBlock *ThisBB;
129 /// The debug location of the instruction this CaseBlock was
130 /// produced from.
131 SDLoc DL;
132 DebugLoc DbgLoc;
134 // Branch weights.
135 BranchProbability TrueProb, FalseProb;
137 // Constructor for SelectionDAG.
138 CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
139 const Value *cmpmiddle, MachineBasicBlock *truebb,
140 MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
141 BranchProbability trueprob = BranchProbability::getUnknown(),
142 BranchProbability falseprob = BranchProbability::getUnknown())
143 : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
144 TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
145 TrueProb(trueprob), FalseProb(falseprob) {}
147 // Constructor for GISel.
148 CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
149 const Value *cmprhs, const Value *cmpmiddle,
150 MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
151 MachineBasicBlock *me, DebugLoc dl,
152 BranchProbability trueprob = BranchProbability::getUnknown(),
153 BranchProbability falseprob = BranchProbability::getUnknown())
154 : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
155 CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
156 DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
159 struct JumpTable {
160 /// The virtual register containing the index of the jump table entry
161 /// to jump to.
162 unsigned Reg;
163 /// The JumpTableIndex for this jump table in the function.
164 unsigned JTI;
165 /// The MBB into which to emit the code for the indirect jump.
166 MachineBasicBlock *MBB;
167 /// The MBB of the default bb, which is a successor of the range
168 /// check MBB. This is when updating PHI nodes in successors.
169 MachineBasicBlock *Default;
171 JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
172 : Reg(R), JTI(J), MBB(M), Default(D) {}
174 struct JumpTableHeader {
175 APInt First;
176 APInt Last;
177 const Value *SValue;
178 MachineBasicBlock *HeaderBB;
179 bool Emitted;
180 bool OmitRangeCheck;
182 JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
183 bool E = false)
184 : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
185 Emitted(E), OmitRangeCheck(false) {}
187 using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
189 struct BitTestCase {
190 uint64_t Mask;
191 MachineBasicBlock *ThisBB;
192 MachineBasicBlock *TargetBB;
193 BranchProbability ExtraProb;
195 BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
196 BranchProbability Prob)
197 : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
200 using BitTestInfo = SmallVector<BitTestCase, 3>;
202 struct BitTestBlock {
203 APInt First;
204 APInt Range;
205 const Value *SValue;
206 unsigned Reg;
207 MVT RegVT;
208 bool Emitted;
209 bool ContiguousRange;
210 MachineBasicBlock *Parent;
211 MachineBasicBlock *Default;
212 BitTestInfo Cases;
213 BranchProbability Prob;
214 BranchProbability DefaultProb;
215 bool OmitRangeCheck;
217 BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
218 bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
219 BitTestInfo C, BranchProbability Pr)
220 : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
221 RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
222 Cases(std::move(C)), Prob(Pr), OmitRangeCheck(false) {}
225 /// Return the range of values within a range.
226 uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
227 unsigned Last);
229 /// Return the number of cases within a range.
230 uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
231 unsigned First, unsigned Last);
233 struct SwitchWorkListItem {
234 MachineBasicBlock *MBB;
235 CaseClusterIt FirstCluster;
236 CaseClusterIt LastCluster;
237 const ConstantInt *GE;
238 const ConstantInt *LT;
239 BranchProbability DefaultProb;
241 using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
243 class SwitchLowering {
244 public:
245 SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
247 void init(const TargetLowering &tli, const TargetMachine &tm,
248 const DataLayout &dl) {
249 TLI = &tli;
250 TM = &tm;
251 DL = &dl;
254 /// Vector of CaseBlock structures used to communicate SwitchInst code
255 /// generation information.
256 std::vector<CaseBlock> SwitchCases;
258 /// Vector of JumpTable structures used to communicate SwitchInst code
259 /// generation information.
260 std::vector<JumpTableBlock> JTCases;
262 /// Vector of BitTestBlock structures used to communicate SwitchInst code
263 /// generation information.
264 std::vector<BitTestBlock> BitTestCases;
266 void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
267 MachineBasicBlock *DefaultMBB);
269 bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
270 unsigned Last, const SwitchInst *SI,
271 MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
274 void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
276 /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
277 /// decides it's not a good idea.
278 bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
279 const SwitchInst *SI, CaseCluster &BTCluster);
281 virtual void addSuccessorWithProb(
282 MachineBasicBlock *Src, MachineBasicBlock *Dst,
283 BranchProbability Prob = BranchProbability::getUnknown()) = 0;
285 virtual ~SwitchLowering() = default;
287 private:
288 const TargetLowering *TLI;
289 const TargetMachine *TM;
290 const DataLayout *DL;
291 FunctionLoweringInfo &FuncInfo;
294 } // namespace SwitchCG
295 } // namespace llvm
297 #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H