[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Conversion / IndexToSPIRV / IndexToSPIRV.cpp
blob7c441830e1e3be4d6a99bec59d92ba6c329e3d53
1 //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
10 #include "../SPIRVCommon/Pattern.h"
11 #include "mlir/Dialect/Index/IR/IndexDialect.h"
12 #include "mlir/Dialect/Index/IR/IndexOps.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "mlir/Pass/Pass.h"
18 using namespace mlir;
19 using namespace index;
21 namespace {
23 //===----------------------------------------------------------------------===//
24 // Trivial Conversions
25 //===----------------------------------------------------------------------===//
27 using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
28 using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
29 using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
30 using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
31 using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
32 using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
33 using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
34 using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
35 using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
36 using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
37 using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
39 using ConvertIndexShl =
40 spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
41 using ConvertIndexShrS =
42 spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
43 using ConvertIndexShrU =
44 spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
46 /// It is the case that when we convert bitwise operations to SPIR-V operations
47 /// we must take into account the special pattern in SPIR-V that if the
48 /// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
49 /// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
50 /// index.add is never a boolean operation so we can directly convert it to the
51 /// Bitwise[And|Or]Op.
52 using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
53 using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
54 using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
56 //===----------------------------------------------------------------------===//
57 // ConvertConstantBool
58 //===----------------------------------------------------------------------===//
60 // Converts index.bool.constant operation to spirv.Constant.
61 struct ConvertIndexConstantBoolOpPattern final
62 : OpConversionPattern<BoolConstantOp> {
63 using OpConversionPattern::OpConversionPattern;
65 LogicalResult
66 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
67 ConversionPatternRewriter &rewriter) const override {
68 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
69 op.getValueAttr());
70 return success();
74 //===----------------------------------------------------------------------===//
75 // ConvertConstant
76 //===----------------------------------------------------------------------===//
78 // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
79 // when required.
80 struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
81 using OpConversionPattern::OpConversionPattern;
83 LogicalResult
84 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
87 Type indexType = typeConverter->getIndexType();
89 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
90 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
91 op, indexType, IntegerAttr::get(indexType, value));
92 return success();
96 //===----------------------------------------------------------------------===//
97 // ConvertIndexCeilDivS
98 //===----------------------------------------------------------------------===//
100 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
101 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
102 /// conversion in IndexToLLVM.
103 struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
104 using OpConversionPattern::OpConversionPattern;
106 LogicalResult
107 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter) const override {
109 Location loc = op.getLoc();
110 Value n = adaptor.getLhs();
111 Type n_type = n.getType();
112 Value m = adaptor.getRhs();
114 // Define the constants
115 Value zero = rewriter.create<spirv::ConstantOp>(
116 loc, n_type, IntegerAttr::get(n_type, 0));
117 Value posOne = rewriter.create<spirv::ConstantOp>(
118 loc, n_type, IntegerAttr::get(n_type, 1));
119 Value negOne = rewriter.create<spirv::ConstantOp>(
120 loc, n_type, IntegerAttr::get(n_type, -1));
122 // Compute `x`.
123 Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
124 Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
126 // Compute the positive result.
127 Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
128 Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
129 Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
131 // Compute the negative result.
132 Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
133 Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
134 Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
136 // Pick the positive result if `n` and `m` have the same sign and `n` is
137 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
138 Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
139 Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
140 Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
141 Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
142 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
143 return success();
147 //===----------------------------------------------------------------------===//
148 // ConvertIndexCeilDivU
149 //===----------------------------------------------------------------------===//
151 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
152 /// from the equivalent conversion in IndexToLLVM.
153 struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
154 using OpConversionPattern::OpConversionPattern;
156 LogicalResult
157 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter) const override {
159 Location loc = op.getLoc();
160 Value n = adaptor.getLhs();
161 Type n_type = n.getType();
162 Value m = adaptor.getRhs();
164 // Define the constants
165 Value zero = rewriter.create<spirv::ConstantOp>(
166 loc, n_type, IntegerAttr::get(n_type, 0));
167 Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
168 IntegerAttr::get(n_type, 1));
170 // Compute the non-zero result.
171 Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
172 Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
173 Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
175 // Pick the result
176 Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
177 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
178 return success();
182 //===----------------------------------------------------------------------===//
183 // ConvertIndexFloorDivS
184 //===----------------------------------------------------------------------===//
186 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
187 /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
188 /// in IndexToLLVM.
189 struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
190 using OpConversionPattern::OpConversionPattern;
192 LogicalResult
193 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter) const override {
195 Location loc = op.getLoc();
196 Value n = adaptor.getLhs();
197 Type n_type = n.getType();
198 Value m = adaptor.getRhs();
200 // Define the constants
201 Value zero = rewriter.create<spirv::ConstantOp>(
202 loc, n_type, IntegerAttr::get(n_type, 0));
203 Value posOne = rewriter.create<spirv::ConstantOp>(
204 loc, n_type, IntegerAttr::get(n_type, 1));
205 Value negOne = rewriter.create<spirv::ConstantOp>(
206 loc, n_type, IntegerAttr::get(n_type, -1));
208 // Compute `x`.
209 Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
210 Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
212 // Compute the negative result
213 Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
214 Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
215 Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
217 // Compute the positive result.
218 Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
220 // Pick the negative result if `n` and `m` have different signs and `n` is
221 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
222 Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
223 Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
224 Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
226 Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
227 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
228 return success();
232 //===----------------------------------------------------------------------===//
233 // ConvertIndexCast
234 //===----------------------------------------------------------------------===//
236 /// Convert a cast op. If the materialized index type is the same as the other
237 /// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239 /// zero extend when the result bitwidth is larger.
240 template <typename CastOp, typename ConvertOp>
241 struct ConvertIndexCast final : OpConversionPattern<CastOp> {
242 using OpConversionPattern<CastOp>::OpConversionPattern;
244 LogicalResult
245 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
248 Type indexType = typeConverter->getIndexType();
250 Type srcType = adaptor.getInput().getType();
251 Type dstType = op.getType();
252 if (isa<IndexType>(srcType)) {
253 srcType = indexType;
255 if (isa<IndexType>(dstType)) {
256 dstType = indexType;
259 if (srcType == dstType) {
260 rewriter.replaceOp(op, adaptor.getInput());
261 } else {
262 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263 adaptor.getOperands());
265 return success();
269 using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270 using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
272 //===----------------------------------------------------------------------===//
273 // ConvertIndexCmp
274 //===----------------------------------------------------------------------===//
276 // Helper template to replace the operation
277 template <typename ICmpOp>
278 static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter) {
280 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
281 return success();
284 struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
285 using OpConversionPattern::OpConversionPattern;
287 LogicalResult
288 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 // We must convert the predicates to the corresponding int comparions.
291 switch (op.getPred()) {
292 case IndexCmpPredicate::EQ:
293 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294 case IndexCmpPredicate::NE:
295 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296 case IndexCmpPredicate::SGE:
297 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::SGT:
299 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SLE:
301 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SLT:
303 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::UGE:
305 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::UGT:
307 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::ULE:
309 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::ULT:
311 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
313 llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
317 //===----------------------------------------------------------------------===//
318 // ConvertIndexSizeOf
319 //===----------------------------------------------------------------------===//
321 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
322 struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
323 using OpConversionPattern::OpConversionPattern;
325 LogicalResult
326 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
327 ConversionPatternRewriter &rewriter) const override {
328 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
329 Type indexType = typeConverter->getIndexType();
330 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
331 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
332 op, indexType, IntegerAttr::get(indexType, bitwidth));
333 return success();
336 } // namespace
338 //===----------------------------------------------------------------------===//
339 // Pattern Population
340 //===----------------------------------------------------------------------===//
342 void index::populateIndexToSPIRVPatterns(
343 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
344 patterns.add<
345 // clang-format off
346 ConvertIndexAdd,
347 ConvertIndexSub,
348 ConvertIndexMul,
349 ConvertIndexDivS,
350 ConvertIndexDivU,
351 ConvertIndexRemS,
352 ConvertIndexRemU,
353 ConvertIndexMaxS,
354 ConvertIndexMaxU,
355 ConvertIndexMinS,
356 ConvertIndexMinU,
357 ConvertIndexShl,
358 ConvertIndexShrS,
359 ConvertIndexShrU,
360 ConvertIndexAnd,
361 ConvertIndexOr,
362 ConvertIndexXor,
363 ConvertIndexConstantBoolOpPattern,
364 ConvertIndexConstantOpPattern,
365 ConvertIndexCeilDivSPattern,
366 ConvertIndexCeilDivUPattern,
367 ConvertIndexFloorDivSPattern,
368 ConvertIndexCastS,
369 ConvertIndexCastU,
370 ConvertIndexCmpPattern,
371 ConvertIndexSizeOf
372 >(typeConverter, patterns.getContext());
375 //===----------------------------------------------------------------------===//
376 // ODS-Generated Definitions
377 //===----------------------------------------------------------------------===//
379 namespace mlir {
380 #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381 #include "mlir/Conversion/Passes.h.inc"
382 } // namespace mlir
384 //===----------------------------------------------------------------------===//
385 // Pass Definition
386 //===----------------------------------------------------------------------===//
388 namespace {
389 struct ConvertIndexToSPIRVPass
390 : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
391 using Base::Base;
393 void runOnOperation() override {
394 Operation *op = getOperation();
395 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
396 std::unique_ptr<SPIRVConversionTarget> target =
397 SPIRVConversionTarget::get(targetAttr);
399 SPIRVConversionOptions options;
400 options.use64bitIndex = this->use64bitIndex;
401 SPIRVTypeConverter typeConverter(targetAttr, options);
403 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
404 // in patterns for other dialects.
405 target->addLegalOp<UnrealizedConversionCastOp>();
407 // Allow the spirv operations we are converting to
408 target->addLegalDialect<spirv::SPIRVDialect>();
409 // Fail hard when there are any remaining 'index' ops.
410 target->addIllegalDialect<index::IndexDialect>();
412 RewritePatternSet patterns(&getContext());
413 index::populateIndexToSPIRVPatterns(typeConverter, patterns);
415 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
416 signalPassFailure();
419 } // namespace