1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the GlobalISel legalization rules (SPIRVLegalizerInfo) for SPIR-V.
11 //===----------------------------------------------------------------------===//
13 #include "SPIRVLegalizerInfo.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
24 using namespace llvm::LegalizeActions
;
25 using namespace llvm::LegalityPredicates
;
27 static const std::set
<unsigned> TypeFoldingSupportingOpcs
= {
41 TargetOpcode::G_CONSTANT
,
42 TargetOpcode::G_FCONSTANT
,
49 TargetOpcode::G_SELECT
,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT
,
53 bool isTypeFoldingSupported(unsigned Opcode
) {
54 return TypeFoldingSupportingOpcs
.count(Opcode
) > 0;
57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget
&ST
) {
58 using namespace TargetOpcode
;
61 GR
= ST
.getSPIRVGlobalRegistry();
63 const LLT s1
= LLT::scalar(1);
64 const LLT s8
= LLT::scalar(8);
65 const LLT s16
= LLT::scalar(16);
66 const LLT s32
= LLT::scalar(32);
67 const LLT s64
= LLT::scalar(64);
69 const LLT v16s64
= LLT::fixed_vector(16, 64);
70 const LLT v16s32
= LLT::fixed_vector(16, 32);
71 const LLT v16s16
= LLT::fixed_vector(16, 16);
72 const LLT v16s8
= LLT::fixed_vector(16, 8);
73 const LLT v16s1
= LLT::fixed_vector(16, 1);
75 const LLT v8s64
= LLT::fixed_vector(8, 64);
76 const LLT v8s32
= LLT::fixed_vector(8, 32);
77 const LLT v8s16
= LLT::fixed_vector(8, 16);
78 const LLT v8s8
= LLT::fixed_vector(8, 8);
79 const LLT v8s1
= LLT::fixed_vector(8, 1);
81 const LLT v4s64
= LLT::fixed_vector(4, 64);
82 const LLT v4s32
= LLT::fixed_vector(4, 32);
83 const LLT v4s16
= LLT::fixed_vector(4, 16);
84 const LLT v4s8
= LLT::fixed_vector(4, 8);
85 const LLT v4s1
= LLT::fixed_vector(4, 1);
87 const LLT v3s64
= LLT::fixed_vector(3, 64);
88 const LLT v3s32
= LLT::fixed_vector(3, 32);
89 const LLT v3s16
= LLT::fixed_vector(3, 16);
90 const LLT v3s8
= LLT::fixed_vector(3, 8);
91 const LLT v3s1
= LLT::fixed_vector(3, 1);
93 const LLT v2s64
= LLT::fixed_vector(2, 64);
94 const LLT v2s32
= LLT::fixed_vector(2, 32);
95 const LLT v2s16
= LLT::fixed_vector(2, 16);
96 const LLT v2s8
= LLT::fixed_vector(2, 8);
97 const LLT v2s1
= LLT::fixed_vector(2, 1);
99 const unsigned PSize
= ST
.getPointerSize();
100 const LLT p0
= LLT::pointer(0, PSize
); // Function
101 const LLT p1
= LLT::pointer(1, PSize
); // CrossWorkgroup
102 const LLT p2
= LLT::pointer(2, PSize
); // UniformConstant
103 const LLT p3
= LLT::pointer(3, PSize
); // Workgroup
104 const LLT p4
= LLT::pointer(4, PSize
); // Generic
106 LLT::pointer(5, PSize
); // Input, SPV_INTEL_usm_storage_classes (Device)
107 const LLT p6
= LLT::pointer(6, PSize
); // SPV_INTEL_usm_storage_classes (Host)
109 // TODO: remove copy-pasting here by using concatenation in some way.
110 auto allPtrsScalarsAndVectors
= {
111 p0
, p1
, p2
, p3
, p4
, p5
, p6
, s1
, s8
, s16
,
112 s32
, s64
, v2s1
, v2s8
, v2s16
, v2s32
, v2s64
, v3s1
, v3s8
, v3s16
,
113 v3s32
, v3s64
, v4s1
, v4s8
, v4s16
, v4s32
, v4s64
, v8s1
, v8s8
, v8s16
,
114 v8s32
, v8s64
, v16s1
, v16s8
, v16s16
, v16s32
, v16s64
};
116 auto allVectors
= {v2s1
, v2s8
, v2s16
, v2s32
, v2s64
, v3s1
, v3s8
,
117 v3s16
, v3s32
, v3s64
, v4s1
, v4s8
, v4s16
, v4s32
,
118 v4s64
, v8s1
, v8s8
, v8s16
, v8s32
, v8s64
, v16s1
,
119 v16s8
, v16s16
, v16s32
, v16s64
};
121 auto allScalarsAndVectors
= {
122 s1
, s8
, s16
, s32
, s64
, v2s1
, v2s8
, v2s16
, v2s32
, v2s64
,
123 v3s1
, v3s8
, v3s16
, v3s32
, v3s64
, v4s1
, v4s8
, v4s16
, v4s32
, v4s64
,
124 v8s1
, v8s8
, v8s16
, v8s32
, v8s64
, v16s1
, v16s8
, v16s16
, v16s32
, v16s64
};
126 auto allIntScalarsAndVectors
= {s8
, s16
, s32
, s64
, v2s8
, v2s16
,
127 v2s32
, v2s64
, v3s8
, v3s16
, v3s32
, v3s64
,
128 v4s8
, v4s16
, v4s32
, v4s64
, v8s8
, v8s16
,
129 v8s32
, v8s64
, v16s8
, v16s16
, v16s32
, v16s64
};
131 auto allBoolScalarsAndVectors
= {s1
, v2s1
, v3s1
, v4s1
, v8s1
, v16s1
};
133 auto allIntScalars
= {s8
, s16
, s32
, s64
};
135 auto allFloatScalars
= {s16
, s32
, s64
};
137 auto allFloatScalarsAndVectors
= {
138 s16
, s32
, s64
, v2s16
, v2s32
, v2s64
, v3s16
, v3s32
, v3s64
,
139 v4s16
, v4s32
, v4s64
, v8s16
, v8s32
, v8s64
, v16s16
, v16s32
, v16s64
};
141 auto allFloatAndIntScalars
= allIntScalars
;
143 auto allPtrs
= {p0
, p1
, p2
, p3
, p4
, p5
, p6
};
144 auto allWritablePtrs
= {p0
, p1
, p3
, p4
, p5
, p6
};
146 for (auto Opc
: TypeFoldingSupportingOpcs
)
147 getActionDefinitionsBuilder(Opc
).custom();
149 getActionDefinitionsBuilder(G_GLOBAL_VALUE
).alwaysLegal();
151 // TODO: add proper rules for vectors legalization.
152 getActionDefinitionsBuilder(
153 {G_BUILD_VECTOR
, G_SHUFFLE_VECTOR
, G_SPLAT_VECTOR
})
156 // Vector Reduction Operations
157 getActionDefinitionsBuilder(
158 {G_VECREDUCE_SMIN
, G_VECREDUCE_SMAX
, G_VECREDUCE_UMIN
, G_VECREDUCE_UMAX
,
159 G_VECREDUCE_ADD
, G_VECREDUCE_MUL
, G_VECREDUCE_FMUL
, G_VECREDUCE_FMIN
,
160 G_VECREDUCE_FMAX
, G_VECREDUCE_FMINIMUM
, G_VECREDUCE_FMAXIMUM
,
161 G_VECREDUCE_OR
, G_VECREDUCE_AND
, G_VECREDUCE_XOR
})
162 .legalFor(allVectors
)
166 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD
, G_VECREDUCE_SEQ_FMUL
})
171 // TODO: add proper legalization rules.
172 getActionDefinitionsBuilder(G_UNMERGE_VALUES
).alwaysLegal();
174 getActionDefinitionsBuilder({G_MEMCPY
, G_MEMMOVE
})
175 .legalIf(all(typeInSet(0, allWritablePtrs
), typeInSet(1, allPtrs
)));
177 getActionDefinitionsBuilder(G_MEMSET
).legalIf(
178 all(typeInSet(0, allWritablePtrs
), typeInSet(1, allIntScalars
)));
180 getActionDefinitionsBuilder(G_ADDRSPACE_CAST
)
181 .legalForCartesianProduct(allPtrs
, allPtrs
);
183 getActionDefinitionsBuilder({G_LOAD
, G_STORE
}).legalIf(typeInSet(1, allPtrs
));
185 getActionDefinitionsBuilder(G_BITREVERSE
).legalFor(allFloatScalarsAndVectors
);
187 getActionDefinitionsBuilder(G_FMA
).legalFor(allFloatScalarsAndVectors
);
189 getActionDefinitionsBuilder({G_FPTOSI
, G_FPTOUI
})
190 .legalForCartesianProduct(allIntScalarsAndVectors
,
191 allFloatScalarsAndVectors
);
193 getActionDefinitionsBuilder({G_SITOFP
, G_UITOFP
})
194 .legalForCartesianProduct(allFloatScalarsAndVectors
,
195 allScalarsAndVectors
);
197 getActionDefinitionsBuilder({G_SMIN
, G_SMAX
, G_UMIN
, G_UMAX
, G_ABS
})
198 .legalFor(allIntScalarsAndVectors
);
200 getActionDefinitionsBuilder(G_CTPOP
).legalForCartesianProduct(
201 allIntScalarsAndVectors
, allIntScalarsAndVectors
);
203 getActionDefinitionsBuilder(G_PHI
).legalFor(allPtrsScalarsAndVectors
);
205 getActionDefinitionsBuilder(G_BITCAST
).legalIf(
206 all(typeInSet(0, allPtrsScalarsAndVectors
),
207 typeInSet(1, allPtrsScalarsAndVectors
)));
209 getActionDefinitionsBuilder({G_IMPLICIT_DEF
, G_FREEZE
}).alwaysLegal();
211 getActionDefinitionsBuilder({G_STACKSAVE
, G_STACKRESTORE
}).alwaysLegal();
213 getActionDefinitionsBuilder(G_INTTOPTR
)
214 .legalForCartesianProduct(allPtrs
, allIntScalars
);
215 getActionDefinitionsBuilder(G_PTRTOINT
)
216 .legalForCartesianProduct(allIntScalars
, allPtrs
);
217 getActionDefinitionsBuilder(G_PTR_ADD
).legalForCartesianProduct(
218 allPtrs
, allIntScalars
);
220 // ST.canDirectlyComparePointers() for pointer args is supported in
222 getActionDefinitionsBuilder(G_ICMP
).customIf(
223 all(typeInSet(0, allBoolScalarsAndVectors
),
224 typeInSet(1, allPtrsScalarsAndVectors
)));
226 getActionDefinitionsBuilder(G_FCMP
).legalIf(
227 all(typeInSet(0, allBoolScalarsAndVectors
),
228 typeInSet(1, allFloatScalarsAndVectors
)));
230 getActionDefinitionsBuilder({G_ATOMICRMW_OR
, G_ATOMICRMW_ADD
, G_ATOMICRMW_AND
,
231 G_ATOMICRMW_MAX
, G_ATOMICRMW_MIN
,
232 G_ATOMICRMW_SUB
, G_ATOMICRMW_XOR
,
233 G_ATOMICRMW_UMAX
, G_ATOMICRMW_UMIN
})
234 .legalForCartesianProduct(allIntScalars
, allWritablePtrs
);
236 getActionDefinitionsBuilder(
237 {G_ATOMICRMW_FADD
, G_ATOMICRMW_FSUB
, G_ATOMICRMW_FMIN
, G_ATOMICRMW_FMAX
})
238 .legalForCartesianProduct(allFloatScalars
, allWritablePtrs
);
240 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG
)
241 .legalForCartesianProduct(allFloatAndIntScalars
, allWritablePtrs
);
243 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS
).lower();
244 // TODO: add proper legalization rules.
245 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG
).alwaysLegal();
247 getActionDefinitionsBuilder({G_UADDO
, G_USUBO
, G_SMULO
, G_UMULO
})
251 getActionDefinitionsBuilder({G_TRUNC
, G_ZEXT
, G_SEXT
, G_ANYEXT
})
252 .legalForCartesianProduct(allScalarsAndVectors
);
255 getActionDefinitionsBuilder({G_FPTRUNC
, G_FPEXT
})
256 .legalForCartesianProduct(allFloatScalarsAndVectors
);
259 getActionDefinitionsBuilder(G_FRAME_INDEX
).legalFor({p0
});
261 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
262 getActionDefinitionsBuilder(G_BRCOND
).legalFor({s1
, s32
});
264 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
265 // tighten these requirements. Many of these math functions are only legal on
266 // specific bitwidths, so they are not selectable for
267 // allFloatScalarsAndVectors.
268 getActionDefinitionsBuilder({G_FPOW
,
289 G_INTRINSIC_ROUNDEVEN
})
290 .legalFor(allFloatScalarsAndVectors
);
292 getActionDefinitionsBuilder(G_FCOPYSIGN
)
293 .legalForCartesianProduct(allFloatScalarsAndVectors
,
294 allFloatScalarsAndVectors
);
296 getActionDefinitionsBuilder(G_FPOWI
).legalForCartesianProduct(
297 allFloatScalarsAndVectors
, allIntScalarsAndVectors
);
299 if (ST
.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std
)) {
300 getActionDefinitionsBuilder(
301 {G_CTTZ
, G_CTTZ_ZERO_UNDEF
, G_CTLZ
, G_CTLZ_ZERO_UNDEF
})
302 .legalForCartesianProduct(allIntScalarsAndVectors
,
303 allIntScalarsAndVectors
);
305 // Struct return types become a single scalar, so cannot easily legalize.
306 getActionDefinitionsBuilder({G_SMULH
, G_UMULH
}).alwaysLegal();
309 getLegacyLegalizerInfo().computeTables();
310 verify(*ST
.getInstrInfo());
313 static Register
convertPtrToInt(Register Reg
, LLT ConvTy
, SPIRVType
*SpirvType
,
314 LegalizerHelper
&Helper
,
315 MachineRegisterInfo
&MRI
,
316 SPIRVGlobalRegistry
*GR
) {
317 Register ConvReg
= MRI
.createGenericVirtualRegister(ConvTy
);
318 GR
->assignSPIRVTypeToVReg(SpirvType
, ConvReg
, Helper
.MIRBuilder
.getMF());
319 Helper
.MIRBuilder
.buildInstr(TargetOpcode::G_PTRTOINT
)
325 bool SPIRVLegalizerInfo::legalizeCustom(
326 LegalizerHelper
&Helper
, MachineInstr
&MI
,
327 LostDebugLocObserver
&LocObserver
) const {
328 auto Opc
= MI
.getOpcode();
329 MachineRegisterInfo
&MRI
= MI
.getMF()->getRegInfo();
330 if (!isTypeFoldingSupported(Opc
)) {
331 assert(Opc
== TargetOpcode::G_ICMP
);
332 assert(GR
->getSPIRVTypeForVReg(MI
.getOperand(0).getReg()));
333 auto &Op0
= MI
.getOperand(2);
334 auto &Op1
= MI
.getOperand(3);
335 Register Reg0
= Op0
.getReg();
336 Register Reg1
= Op1
.getReg();
337 CmpInst::Predicate Cond
=
338 static_cast<CmpInst::Predicate
>(MI
.getOperand(1).getPredicate());
339 if ((!ST
->canDirectlyComparePointers() ||
340 (Cond
!= CmpInst::ICMP_EQ
&& Cond
!= CmpInst::ICMP_NE
)) &&
341 MRI
.getType(Reg0
).isPointer() && MRI
.getType(Reg1
).isPointer()) {
342 LLT ConvT
= LLT::scalar(ST
->getPointerSize());
343 Type
*LLVMTy
= IntegerType::get(MI
.getMF()->getFunction().getContext(),
344 ST
->getPointerSize());
345 SPIRVType
*SpirvTy
= GR
->getOrCreateSPIRVType(LLVMTy
, Helper
.MIRBuilder
);
346 Op0
.setReg(convertPtrToInt(Reg0
, ConvT
, SpirvTy
, Helper
, MRI
, GR
));
347 Op1
.setReg(convertPtrToInt(Reg1
, ConvT
, SpirvTy
, Helper
, MRI
, GR
));
351 // TODO: implement legalization for other opcodes.