1 //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
10 #include "../SPIRVCommon/Pattern.h"
11 #include "mlir/Dialect/Index/IR/IndexDialect.h"
12 #include "mlir/Dialect/Index/IR/IndexOps.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "mlir/Pass/Pass.h"
19 using namespace index
;
23 //===----------------------------------------------------------------------===//
24 // Trivial Conversions
25 //===----------------------------------------------------------------------===//
27 using ConvertIndexAdd
= spirv::ElementwiseOpPattern
<AddOp
, spirv::IAddOp
>;
28 using ConvertIndexSub
= spirv::ElementwiseOpPattern
<SubOp
, spirv::ISubOp
>;
29 using ConvertIndexMul
= spirv::ElementwiseOpPattern
<MulOp
, spirv::IMulOp
>;
30 using ConvertIndexDivS
= spirv::ElementwiseOpPattern
<DivSOp
, spirv::SDivOp
>;
31 using ConvertIndexDivU
= spirv::ElementwiseOpPattern
<DivUOp
, spirv::UDivOp
>;
32 using ConvertIndexRemS
= spirv::ElementwiseOpPattern
<RemSOp
, spirv::SRemOp
>;
33 using ConvertIndexRemU
= spirv::ElementwiseOpPattern
<RemUOp
, spirv::UModOp
>;
34 using ConvertIndexMaxS
= spirv::ElementwiseOpPattern
<MaxSOp
, spirv::GLSMaxOp
>;
35 using ConvertIndexMaxU
= spirv::ElementwiseOpPattern
<MaxUOp
, spirv::GLUMaxOp
>;
36 using ConvertIndexMinS
= spirv::ElementwiseOpPattern
<MinSOp
, spirv::GLSMinOp
>;
37 using ConvertIndexMinU
= spirv::ElementwiseOpPattern
<MinUOp
, spirv::GLUMinOp
>;
39 using ConvertIndexShl
=
40 spirv::ElementwiseOpPattern
<ShlOp
, spirv::ShiftLeftLogicalOp
>;
41 using ConvertIndexShrS
=
42 spirv::ElementwiseOpPattern
<ShrSOp
, spirv::ShiftRightArithmeticOp
>;
43 using ConvertIndexShrU
=
44 spirv::ElementwiseOpPattern
<ShrUOp
, spirv::ShiftRightLogicalOp
>;
46 /// It is the case that when we convert bitwise operations to SPIR-V operations
47 /// we must take into account the special pattern in SPIR-V that if the
48 /// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
49 /// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
50 /// index.add is never a boolean operation so we can directly convert it to the
51 /// Bitwise[And|Or]Op.
52 using ConvertIndexAnd
= spirv::ElementwiseOpPattern
<AndOp
, spirv::BitwiseAndOp
>;
53 using ConvertIndexOr
= spirv::ElementwiseOpPattern
<OrOp
, spirv::BitwiseOrOp
>;
54 using ConvertIndexXor
= spirv::ElementwiseOpPattern
<XOrOp
, spirv::BitwiseXorOp
>;
56 //===----------------------------------------------------------------------===//
57 // ConvertConstantBool
58 //===----------------------------------------------------------------------===//
60 // Converts index.bool.constant operation to spirv.Constant.
61 struct ConvertIndexConstantBoolOpPattern final
62 : OpConversionPattern
<BoolConstantOp
> {
63 using OpConversionPattern::OpConversionPattern
;
66 matchAndRewrite(BoolConstantOp op
, BoolConstantOpAdaptor adaptor
,
67 ConversionPatternRewriter
&rewriter
) const override
{
68 rewriter
.replaceOpWithNewOp
<spirv::ConstantOp
>(op
, op
.getType(),
//===----------------------------------------------------------------------===//
// ConvertConstant
//===----------------------------------------------------------------------===//
78 // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
80 struct ConvertIndexConstantOpPattern final
: OpConversionPattern
<ConstantOp
> {
81 using OpConversionPattern::OpConversionPattern
;
84 matchAndRewrite(ConstantOp op
, ConstantOpAdaptor adaptor
,
85 ConversionPatternRewriter
&rewriter
) const override
{
86 auto *typeConverter
= this->template getTypeConverter
<SPIRVTypeConverter
>();
87 Type indexType
= typeConverter
->getIndexType();
89 APInt value
= op
.getValue().trunc(typeConverter
->getIndexTypeBitwidth());
90 rewriter
.replaceOpWithNewOp
<spirv::ConstantOp
>(
91 op
, indexType
, IntegerAttr::get(indexType
, value
));
96 //===----------------------------------------------------------------------===//
97 // ConvertIndexCeilDivS
98 //===----------------------------------------------------------------------===//
100 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
101 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
102 /// conversion in IndexToLLVM.
103 struct ConvertIndexCeilDivSPattern final
: OpConversionPattern
<CeilDivSOp
> {
104 using OpConversionPattern::OpConversionPattern
;
107 matchAndRewrite(CeilDivSOp op
, CeilDivSOpAdaptor adaptor
,
108 ConversionPatternRewriter
&rewriter
) const override
{
109 Location loc
= op
.getLoc();
110 Value n
= adaptor
.getLhs();
111 Type n_type
= n
.getType();
112 Value m
= adaptor
.getRhs();
114 // Define the constants
115 Value zero
= rewriter
.create
<spirv::ConstantOp
>(
116 loc
, n_type
, IntegerAttr::get(n_type
, 0));
117 Value posOne
= rewriter
.create
<spirv::ConstantOp
>(
118 loc
, n_type
, IntegerAttr::get(n_type
, 1));
119 Value negOne
= rewriter
.create
<spirv::ConstantOp
>(
120 loc
, n_type
, IntegerAttr::get(n_type
, -1));
123 Value mPos
= rewriter
.create
<spirv::SGreaterThanOp
>(loc
, m
, zero
);
124 Value x
= rewriter
.create
<spirv::SelectOp
>(loc
, mPos
, negOne
, posOne
);
126 // Compute the positive result.
127 Value nPlusX
= rewriter
.create
<spirv::IAddOp
>(loc
, n
, x
);
128 Value nPlusXDivM
= rewriter
.create
<spirv::SDivOp
>(loc
, nPlusX
, m
);
129 Value posRes
= rewriter
.create
<spirv::IAddOp
>(loc
, nPlusXDivM
, posOne
);
131 // Compute the negative result.
132 Value negN
= rewriter
.create
<spirv::ISubOp
>(loc
, zero
, n
);
133 Value negNDivM
= rewriter
.create
<spirv::SDivOp
>(loc
, negN
, m
);
134 Value negRes
= rewriter
.create
<spirv::ISubOp
>(loc
, zero
, negNDivM
);
136 // Pick the positive result if `n` and `m` have the same sign and `n` is
137 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
138 Value nPos
= rewriter
.create
<spirv::SGreaterThanOp
>(loc
, n
, zero
);
139 Value sameSign
= rewriter
.create
<spirv::LogicalEqualOp
>(loc
, nPos
, mPos
);
140 Value nNonZero
= rewriter
.create
<spirv::INotEqualOp
>(loc
, n
, zero
);
141 Value cmp
= rewriter
.create
<spirv::LogicalAndOp
>(loc
, sameSign
, nNonZero
);
142 rewriter
.replaceOpWithNewOp
<spirv::SelectOp
>(op
, cmp
, posRes
, negRes
);
147 //===----------------------------------------------------------------------===//
148 // ConvertIndexCeilDivU
149 //===----------------------------------------------------------------------===//
151 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
152 /// from the equivalent conversion in IndexToLLVM.
153 struct ConvertIndexCeilDivUPattern final
: OpConversionPattern
<CeilDivUOp
> {
154 using OpConversionPattern::OpConversionPattern
;
157 matchAndRewrite(CeilDivUOp op
, CeilDivUOpAdaptor adaptor
,
158 ConversionPatternRewriter
&rewriter
) const override
{
159 Location loc
= op
.getLoc();
160 Value n
= adaptor
.getLhs();
161 Type n_type
= n
.getType();
162 Value m
= adaptor
.getRhs();
164 // Define the constants
165 Value zero
= rewriter
.create
<spirv::ConstantOp
>(
166 loc
, n_type
, IntegerAttr::get(n_type
, 0));
167 Value one
= rewriter
.create
<spirv::ConstantOp
>(loc
, n_type
,
168 IntegerAttr::get(n_type
, 1));
170 // Compute the non-zero result.
171 Value minusOne
= rewriter
.create
<spirv::ISubOp
>(loc
, n
, one
);
172 Value quotient
= rewriter
.create
<spirv::UDivOp
>(loc
, minusOne
, m
);
173 Value plusOne
= rewriter
.create
<spirv::IAddOp
>(loc
, quotient
, one
);
176 Value cmp
= rewriter
.create
<spirv::IEqualOp
>(loc
, n
, zero
);
177 rewriter
.replaceOpWithNewOp
<spirv::SelectOp
>(op
, cmp
, zero
, plusOne
);
182 //===----------------------------------------------------------------------===//
183 // ConvertIndexFloorDivS
184 //===----------------------------------------------------------------------===//
186 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
187 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
189 struct ConvertIndexFloorDivSPattern final
: OpConversionPattern
<FloorDivSOp
> {
190 using OpConversionPattern::OpConversionPattern
;
193 matchAndRewrite(FloorDivSOp op
, FloorDivSOpAdaptor adaptor
,
194 ConversionPatternRewriter
&rewriter
) const override
{
195 Location loc
= op
.getLoc();
196 Value n
= adaptor
.getLhs();
197 Type n_type
= n
.getType();
198 Value m
= adaptor
.getRhs();
200 // Define the constants
201 Value zero
= rewriter
.create
<spirv::ConstantOp
>(
202 loc
, n_type
, IntegerAttr::get(n_type
, 0));
203 Value posOne
= rewriter
.create
<spirv::ConstantOp
>(
204 loc
, n_type
, IntegerAttr::get(n_type
, 1));
205 Value negOne
= rewriter
.create
<spirv::ConstantOp
>(
206 loc
, n_type
, IntegerAttr::get(n_type
, -1));
209 Value mNeg
= rewriter
.create
<spirv::SLessThanOp
>(loc
, m
, zero
);
210 Value x
= rewriter
.create
<spirv::SelectOp
>(loc
, mNeg
, posOne
, negOne
);
212 // Compute the negative result
213 Value xMinusN
= rewriter
.create
<spirv::ISubOp
>(loc
, x
, n
);
214 Value xMinusNDivM
= rewriter
.create
<spirv::SDivOp
>(loc
, xMinusN
, m
);
215 Value negRes
= rewriter
.create
<spirv::ISubOp
>(loc
, negOne
, xMinusNDivM
);
217 // Compute the positive result.
218 Value posRes
= rewriter
.create
<spirv::SDivOp
>(loc
, n
, m
);
220 // Pick the negative result if `n` and `m` have different signs and `n` is
221 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
222 Value nNeg
= rewriter
.create
<spirv::SLessThanOp
>(loc
, n
, zero
);
223 Value diffSign
= rewriter
.create
<spirv::LogicalNotEqualOp
>(loc
, nNeg
, mNeg
);
224 Value nNonZero
= rewriter
.create
<spirv::INotEqualOp
>(loc
, n
, zero
);
226 Value cmp
= rewriter
.create
<spirv::LogicalAndOp
>(loc
, diffSign
, nNonZero
);
227 rewriter
.replaceOpWithNewOp
<spirv::SelectOp
>(op
, cmp
, posRes
, negRes
);
//===----------------------------------------------------------------------===//
// ConvertIndexCast
//===----------------------------------------------------------------------===//
236 /// Convert a cast op. If the materialized index type is the same as the other
237 /// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239 /// zero extend when the result bitwidth is larger.
240 template <typename CastOp
, typename ConvertOp
>
241 struct ConvertIndexCast final
: OpConversionPattern
<CastOp
> {
242 using OpConversionPattern
<CastOp
>::OpConversionPattern
;
245 matchAndRewrite(CastOp op
, typename
CastOp::Adaptor adaptor
,
246 ConversionPatternRewriter
&rewriter
) const override
{
247 auto *typeConverter
= this->template getTypeConverter
<SPIRVTypeConverter
>();
248 Type indexType
= typeConverter
->getIndexType();
250 Type srcType
= adaptor
.getInput().getType();
251 Type dstType
= op
.getType();
252 if (isa
<IndexType
>(srcType
)) {
255 if (isa
<IndexType
>(dstType
)) {
259 if (srcType
== dstType
) {
260 rewriter
.replaceOp(op
, adaptor
.getInput());
262 rewriter
.template replaceOpWithNewOp
<ConvertOp
>(op
, dstType
,
263 adaptor
.getOperands());
269 using ConvertIndexCastS
= ConvertIndexCast
<CastSOp
, spirv::SConvertOp
>;
270 using ConvertIndexCastU
= ConvertIndexCast
<CastUOp
, spirv::UConvertOp
>;
//===----------------------------------------------------------------------===//
// ConvertIndexCmp
//===----------------------------------------------------------------------===//
276 // Helper template to replace the operation
277 template <typename ICmpOp
>
278 static LogicalResult
rewriteCmpOp(CmpOp op
, CmpOpAdaptor adaptor
,
279 ConversionPatternRewriter
&rewriter
) {
280 rewriter
.replaceOpWithNewOp
<ICmpOp
>(op
, adaptor
.getLhs(), adaptor
.getRhs());
284 struct ConvertIndexCmpPattern final
: OpConversionPattern
<CmpOp
> {
285 using OpConversionPattern::OpConversionPattern
;
288 matchAndRewrite(CmpOp op
, CmpOpAdaptor adaptor
,
289 ConversionPatternRewriter
&rewriter
) const override
{
290 // We must convert the predicates to the corresponding int comparions.
291 switch (op
.getPred()) {
292 case IndexCmpPredicate::EQ
:
293 return rewriteCmpOp
<spirv::IEqualOp
>(op
, adaptor
, rewriter
);
294 case IndexCmpPredicate::NE
:
295 return rewriteCmpOp
<spirv::INotEqualOp
>(op
, adaptor
, rewriter
);
296 case IndexCmpPredicate::SGE
:
297 return rewriteCmpOp
<spirv::SGreaterThanEqualOp
>(op
, adaptor
, rewriter
);
298 case IndexCmpPredicate::SGT
:
299 return rewriteCmpOp
<spirv::SGreaterThanOp
>(op
, adaptor
, rewriter
);
300 case IndexCmpPredicate::SLE
:
301 return rewriteCmpOp
<spirv::SLessThanEqualOp
>(op
, adaptor
, rewriter
);
302 case IndexCmpPredicate::SLT
:
303 return rewriteCmpOp
<spirv::SLessThanOp
>(op
, adaptor
, rewriter
);
304 case IndexCmpPredicate::UGE
:
305 return rewriteCmpOp
<spirv::UGreaterThanEqualOp
>(op
, adaptor
, rewriter
);
306 case IndexCmpPredicate::UGT
:
307 return rewriteCmpOp
<spirv::UGreaterThanOp
>(op
, adaptor
, rewriter
);
308 case IndexCmpPredicate::ULE
:
309 return rewriteCmpOp
<spirv::ULessThanEqualOp
>(op
, adaptor
, rewriter
);
310 case IndexCmpPredicate::ULT
:
311 return rewriteCmpOp
<spirv::ULessThanOp
>(op
, adaptor
, rewriter
);
313 llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
317 //===----------------------------------------------------------------------===//
318 // ConvertIndexSizeOf
319 //===----------------------------------------------------------------------===//
321 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
322 struct ConvertIndexSizeOf final
: OpConversionPattern
<SizeOfOp
> {
323 using OpConversionPattern::OpConversionPattern
;
326 matchAndRewrite(SizeOfOp op
, SizeOfOpAdaptor adaptor
,
327 ConversionPatternRewriter
&rewriter
) const override
{
328 auto *typeConverter
= this->template getTypeConverter
<SPIRVTypeConverter
>();
329 Type indexType
= typeConverter
->getIndexType();
330 unsigned bitwidth
= typeConverter
->getIndexTypeBitwidth();
331 rewriter
.replaceOpWithNewOp
<spirv::ConstantOp
>(
332 op
, indexType
, IntegerAttr::get(indexType
, bitwidth
));
338 //===----------------------------------------------------------------------===//
339 // Pattern Population
340 //===----------------------------------------------------------------------===//
342 void index::populateIndexToSPIRVPatterns(
343 const SPIRVTypeConverter
&typeConverter
, RewritePatternSet
&patterns
) {
363 ConvertIndexConstantBoolOpPattern
,
364 ConvertIndexConstantOpPattern
,
365 ConvertIndexCeilDivSPattern
,
366 ConvertIndexCeilDivUPattern
,
367 ConvertIndexFloorDivSPattern
,
370 ConvertIndexCmpPattern
,
372 >(typeConverter
, patterns
.getContext());
375 //===----------------------------------------------------------------------===//
376 // ODS-Generated Definitions
377 //===----------------------------------------------------------------------===//
380 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381 #include "mlir/Conversion/Passes.h.inc"
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
389 struct ConvertIndexToSPIRVPass
390 : public impl::ConvertIndexToSPIRVPassBase
<ConvertIndexToSPIRVPass
> {
393 void runOnOperation() override
{
394 Operation
*op
= getOperation();
395 spirv::TargetEnvAttr targetAttr
= spirv::lookupTargetEnvOrDefault(op
);
396 std::unique_ptr
<SPIRVConversionTarget
> target
=
397 SPIRVConversionTarget::get(targetAttr
);
399 SPIRVConversionOptions options
;
400 options
.use64bitIndex
= this->use64bitIndex
;
401 SPIRVTypeConverter
typeConverter(targetAttr
, options
);
403 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
404 // in patterns for other dialects.
405 target
->addLegalOp
<UnrealizedConversionCastOp
>();
407 // Allow the spirv operations we are converting to
408 target
->addLegalDialect
<spirv::SPIRVDialect
>();
409 // Fail hard when there are any remaining 'index' ops.
410 target
->addIllegalDialect
<index::IndexDialect
>();
412 RewritePatternSet
patterns(&getContext());
413 index::populateIndexToSPIRVPatterns(typeConverter
, patterns
);
415 if (failed(applyPartialConversion(op
, *target
, std::move(patterns
))))