[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / llvm / lib / Target / SPIRV / SPIRVLegalizerInfo.cpp
blobe7b35555293a3e676dcfaea58a1b3570c5d4191b
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the targeting of the MachineLegalizer class for SPIR-V.
11 //===----------------------------------------------------------------------===//
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
53 bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58 using namespace TargetOpcode;
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 =
106 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
107 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
109 // TODO: remove copy-pasting here by using concatenation in some way.
110 auto allPtrsScalarsAndVectors = {
111 p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
112 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
113 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
114 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
116 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
117 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
118 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
119 v16s8, v16s16, v16s32, v16s64};
121 auto allScalarsAndVectors = {
122 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
123 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
124 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
126 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
127 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
128 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
129 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
131 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
133 auto allIntScalars = {s8, s16, s32, s64};
135 auto allFloatScalars = {s16, s32, s64};
137 auto allFloatScalarsAndVectors = {
138 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
139 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
141 auto allFloatAndIntScalars = allIntScalars;
143 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
144 auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
146 for (auto Opc : TypeFoldingSupportingOpcs)
147 getActionDefinitionsBuilder(Opc).custom();
149 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
151 // TODO: add proper rules for vectors legalization.
152 getActionDefinitionsBuilder(
153 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
154 .alwaysLegal();
156 // Vector Reduction Operations
157 getActionDefinitionsBuilder(
158 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
159 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
160 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
161 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
162 .legalFor(allVectors)
163 .scalarize(1)
164 .lower();
166 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
167 .scalarize(2)
168 .lower();
170 // Merge/Unmerge
171 // TODO: add proper legalization rules.
172 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
174 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
175 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
177 getActionDefinitionsBuilder(G_MEMSET).legalIf(
178 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
180 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
181 .legalForCartesianProduct(allPtrs, allPtrs);
183 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
185 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
187 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
189 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
190 .legalForCartesianProduct(allIntScalarsAndVectors,
191 allFloatScalarsAndVectors);
193 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
194 .legalForCartesianProduct(allFloatScalarsAndVectors,
195 allScalarsAndVectors);
197 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
198 .legalFor(allIntScalarsAndVectors);
200 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
201 allIntScalarsAndVectors, allIntScalarsAndVectors);
203 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
205 getActionDefinitionsBuilder(G_BITCAST).legalIf(
206 all(typeInSet(0, allPtrsScalarsAndVectors),
207 typeInSet(1, allPtrsScalarsAndVectors)));
209 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
211 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
213 getActionDefinitionsBuilder(G_INTTOPTR)
214 .legalForCartesianProduct(allPtrs, allIntScalars);
215 getActionDefinitionsBuilder(G_PTRTOINT)
216 .legalForCartesianProduct(allIntScalars, allPtrs);
217 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
218 allPtrs, allIntScalars);
220 // ST.canDirectlyComparePointers() for pointer args is supported in
221 // legalizeCustom().
222 getActionDefinitionsBuilder(G_ICMP).customIf(
223 all(typeInSet(0, allBoolScalarsAndVectors),
224 typeInSet(1, allPtrsScalarsAndVectors)));
226 getActionDefinitionsBuilder(G_FCMP).legalIf(
227 all(typeInSet(0, allBoolScalarsAndVectors),
228 typeInSet(1, allFloatScalarsAndVectors)));
230 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
231 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
232 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
233 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
234 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
236 getActionDefinitionsBuilder(
237 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
238 .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
240 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
241 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
243 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
244 // TODO: add proper legalization rules.
245 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
247 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
248 .alwaysLegal();
250 // Extensions.
251 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
252 .legalForCartesianProduct(allScalarsAndVectors);
254 // FP conversions.
255 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
256 .legalForCartesianProduct(allFloatScalarsAndVectors);
258 // Pointer-handling.
259 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
261 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
262 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
264 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
265 // tighten these requirements. Many of these math functions are only legal on
266 // specific bitwidths, so they are not selectable for
267 // allFloatScalarsAndVectors.
268 getActionDefinitionsBuilder({G_FPOW,
269 G_FEXP,
270 G_FEXP2,
271 G_FLOG,
272 G_FLOG2,
273 G_FLOG10,
274 G_FABS,
275 G_FMINNUM,
276 G_FMAXNUM,
277 G_FCEIL,
278 G_FCOS,
279 G_FSIN,
280 G_FTAN,
281 G_FSQRT,
282 G_FFLOOR,
283 G_FRINT,
284 G_FNEARBYINT,
285 G_INTRINSIC_ROUND,
286 G_INTRINSIC_TRUNC,
287 G_FMINIMUM,
288 G_FMAXIMUM,
289 G_INTRINSIC_ROUNDEVEN})
290 .legalFor(allFloatScalarsAndVectors);
292 getActionDefinitionsBuilder(G_FCOPYSIGN)
293 .legalForCartesianProduct(allFloatScalarsAndVectors,
294 allFloatScalarsAndVectors);
296 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
297 allFloatScalarsAndVectors, allIntScalarsAndVectors);
299 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
300 getActionDefinitionsBuilder(
301 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
302 .legalForCartesianProduct(allIntScalarsAndVectors,
303 allIntScalarsAndVectors);
305 // Struct return types become a single scalar, so cannot easily legalize.
306 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
309 getLegacyLegalizerInfo().computeTables();
310 verify(*ST.getInstrInfo());
313 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
314 LegalizerHelper &Helper,
315 MachineRegisterInfo &MRI,
316 SPIRVGlobalRegistry *GR) {
317 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
318 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
319 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
320 .addDef(ConvReg)
321 .addUse(Reg);
322 return ConvReg;
325 bool SPIRVLegalizerInfo::legalizeCustom(
326 LegalizerHelper &Helper, MachineInstr &MI,
327 LostDebugLocObserver &LocObserver) const {
328 auto Opc = MI.getOpcode();
329 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
330 if (!isTypeFoldingSupported(Opc)) {
331 assert(Opc == TargetOpcode::G_ICMP);
332 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
333 auto &Op0 = MI.getOperand(2);
334 auto &Op1 = MI.getOperand(3);
335 Register Reg0 = Op0.getReg();
336 Register Reg1 = Op1.getReg();
337 CmpInst::Predicate Cond =
338 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
339 if ((!ST->canDirectlyComparePointers() ||
340 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
341 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
342 LLT ConvT = LLT::scalar(ST->getPointerSize());
343 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
344 ST->getPointerSize());
345 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
346 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
347 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
349 return true;
351 // TODO: implement legalization for other opcodes.
352 return true;