Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / VectorToSPIRV / VectorToSPIRV.cpp
blob656b1cb3e99a1d54cbe34478bbe62c9e66725dcd
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Location.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/TypeUtilities.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/SmallVectorExtras.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include <cassert>
36 #include <cstdint>
37 #include <numeric>
39 using namespace mlir;
41 /// Returns the integer value from the first valid input element, assuming Value
42 /// inputs are defined by a constant index ops and Attribute inputs are integer
43 /// attributes.
44 static uint64_t getFirstIntValue(ArrayAttr attr) {
45 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
48 /// Returns the number of bits for the given scalar/vector type.
49 static int getNumBits(Type type) {
50 // TODO: This does not take into account any memory layout or widening
51 // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
52 // though in practice it will likely be stored as in a 4xi64 vector register.
53 if (auto vectorType = dyn_cast<VectorType>(type))
54 return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
55 return type.getIntOrFloatBitWidth();
58 namespace {
60 struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
61 using OpConversionPattern::OpConversionPattern;
63 LogicalResult
64 matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
65 ConversionPatternRewriter &rewriter) const override {
66 Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
67 if (!dstType)
68 return failure();
70 // If dstType is same as the source type or the vector size is 1, it can be
71 // directly replaced by the source.
72 if (dstType == adaptor.getSource().getType() ||
73 shapeCastOp.getResultVectorType().getNumElements() == 1) {
74 rewriter.replaceOp(shapeCastOp, adaptor.getSource());
75 return success();
78 // Lowering for size-n vectors when n > 1 hasn't been implemented.
79 return failure();
83 struct VectorBitcastConvert final
84 : public OpConversionPattern<vector::BitCastOp> {
85 using OpConversionPattern::OpConversionPattern;
87 LogicalResult
88 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter) const override {
90 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
91 if (!dstType)
92 return failure();
94 if (dstType == adaptor.getSource().getType()) {
95 rewriter.replaceOp(bitcastOp, adaptor.getSource());
96 return success();
99 // Check that the source and destination type have the same bitwidth.
100 // Depending on the target environment, we may need to emulate certain
101 // types, which can cause issue with bitcast.
102 Type srcType = adaptor.getSource().getType();
103 if (getNumBits(dstType) != getNumBits(srcType)) {
104 return rewriter.notifyMatchFailure(
105 bitcastOp,
106 llvm::formatv("different source ({0}) and target ({1}) bitwidth",
107 srcType, dstType));
110 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
111 adaptor.getSource());
112 return success();
116 struct VectorBroadcastConvert final
117 : public OpConversionPattern<vector::BroadcastOp> {
118 using OpConversionPattern::OpConversionPattern;
120 LogicalResult
121 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter) const override {
123 Type resultType =
124 getTypeConverter()->convertType(castOp.getResultVectorType());
125 if (!resultType)
126 return failure();
128 if (isa<spirv::ScalarType>(resultType)) {
129 rewriter.replaceOp(castOp, adaptor.getSource());
130 return success();
133 SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
134 adaptor.getSource());
135 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
136 source);
137 return success();
141 struct VectorExtractOpConvert final
142 : public OpConversionPattern<vector::ExtractOp> {
143 using OpConversionPattern::OpConversionPattern;
145 LogicalResult
146 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 Type dstType = getTypeConverter()->convertType(extractOp.getType());
149 if (!dstType)
150 return failure();
152 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
153 rewriter.replaceOp(extractOp, adaptor.getVector());
154 return success();
157 if (std::optional<int64_t> id =
158 getConstantIntValue(extractOp.getMixedPosition()[0]))
159 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
160 extractOp, dstType, adaptor.getVector(),
161 rewriter.getI32ArrayAttr(id.value()));
162 else
163 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
164 extractOp, dstType, adaptor.getVector(),
165 adaptor.getDynamicPosition()[0]);
166 return success();
170 struct VectorExtractStridedSliceOpConvert final
171 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
172 using OpConversionPattern::OpConversionPattern;
174 LogicalResult
175 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
176 ConversionPatternRewriter &rewriter) const override {
177 Type dstType = getTypeConverter()->convertType(extractOp.getType());
178 if (!dstType)
179 return failure();
181 uint64_t offset = getFirstIntValue(extractOp.getOffsets());
182 uint64_t size = getFirstIntValue(extractOp.getSizes());
183 uint64_t stride = getFirstIntValue(extractOp.getStrides());
184 if (stride != 1)
185 return failure();
187 Value srcVector = adaptor.getOperands().front();
189 // Extract vector<1xT> case.
190 if (isa<spirv::ScalarType>(dstType)) {
191 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
192 srcVector, offset);
193 return success();
196 SmallVector<int32_t, 2> indices(size);
197 std::iota(indices.begin(), indices.end(), offset);
199 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
200 extractOp, dstType, srcVector, srcVector,
201 rewriter.getI32ArrayAttr(indices));
203 return success();
207 template <class SPIRVFMAOp>
208 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
209 using OpConversionPattern::OpConversionPattern;
211 LogicalResult
212 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
213 ConversionPatternRewriter &rewriter) const override {
214 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
215 if (!dstType)
216 return failure();
217 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
218 adaptor.getRhs(), adaptor.getAcc());
219 return success();
223 struct VectorInsertOpConvert final
224 : public OpConversionPattern<vector::InsertOp> {
225 using OpConversionPattern::OpConversionPattern;
227 LogicalResult
228 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter) const override {
230 if (isa<VectorType>(insertOp.getSourceType()))
231 return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
232 if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
233 return rewriter.notifyMatchFailure(insertOp,
234 "unsupported dest vector type");
236 // Special case for inserting scalar values into size-1 vectors.
237 if (insertOp.getSourceType().isIntOrFloat() &&
238 insertOp.getDestVectorType().getNumElements() == 1) {
239 rewriter.replaceOp(insertOp, adaptor.getSource());
240 return success();
243 if (std::optional<int64_t> id =
244 getConstantIntValue(insertOp.getMixedPosition()[0]))
245 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
246 insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
247 else
248 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
249 insertOp, insertOp.getDest(), adaptor.getSource(),
250 adaptor.getDynamicPosition()[0]);
251 return success();
255 struct VectorExtractElementOpConvert final
256 : public OpConversionPattern<vector::ExtractElementOp> {
257 using OpConversionPattern::OpConversionPattern;
259 LogicalResult
260 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter) const override {
262 Type resultType = getTypeConverter()->convertType(extractOp.getType());
263 if (!resultType)
264 return failure();
266 if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
267 rewriter.replaceOp(extractOp, adaptor.getVector());
268 return success();
271 APInt cstPos;
272 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
273 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
274 extractOp, resultType, adaptor.getVector(),
275 rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
276 else
277 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
278 extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
279 return success();
283 struct VectorInsertElementOpConvert final
284 : public OpConversionPattern<vector::InsertElementOp> {
285 using OpConversionPattern::OpConversionPattern;
287 LogicalResult
288 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 Type vectorType = getTypeConverter()->convertType(insertOp.getType());
291 if (!vectorType)
292 return failure();
294 if (isa<spirv::ScalarType>(vectorType)) {
295 rewriter.replaceOp(insertOp, adaptor.getSource());
296 return success();
299 APInt cstPos;
300 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
301 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
302 insertOp, adaptor.getSource(), adaptor.getDest(),
303 cstPos.getSExtValue());
304 else
305 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
306 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
307 adaptor.getPosition());
308 return success();
312 struct VectorInsertStridedSliceOpConvert final
313 : public OpConversionPattern<vector::InsertStridedSliceOp> {
314 using OpConversionPattern::OpConversionPattern;
316 LogicalResult
317 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
318 ConversionPatternRewriter &rewriter) const override {
319 Value srcVector = adaptor.getOperands().front();
320 Value dstVector = adaptor.getOperands().back();
322 uint64_t stride = getFirstIntValue(insertOp.getStrides());
323 if (stride != 1)
324 return failure();
325 uint64_t offset = getFirstIntValue(insertOp.getOffsets());
327 if (isa<spirv::ScalarType>(srcVector.getType())) {
328 assert(!isa<spirv::ScalarType>(dstVector.getType()));
329 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
330 insertOp, dstVector.getType(), srcVector, dstVector,
331 rewriter.getI32ArrayAttr(offset));
332 return success();
335 uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
336 uint64_t insertSize =
337 cast<VectorType>(srcVector.getType()).getNumElements();
339 SmallVector<int32_t, 2> indices(totalSize);
340 std::iota(indices.begin(), indices.end(), 0);
341 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
342 totalSize);
344 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
345 insertOp, dstVector.getType(), dstVector, srcVector,
346 rewriter.getI32ArrayAttr(indices));
348 return success();
352 static SmallVector<Value> extractAllElements(
353 vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
354 VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
355 int numElements = static_cast<int>(srcVectorType.getDimSize(0));
356 SmallVector<Value> values;
357 values.reserve(numElements + (adaptor.getAcc() ? 1 : 0));
358 Location loc = reduceOp.getLoc();
360 for (int i = 0; i < numElements; ++i) {
361 values.push_back(rewriter.create<spirv::CompositeExtractOp>(
362 loc, srcVectorType.getElementType(), adaptor.getVector(),
363 rewriter.getI32ArrayAttr({i})));
365 if (Value acc = adaptor.getAcc())
366 values.push_back(acc);
368 return values;
371 struct ReductionRewriteInfo {
372 Type resultType;
373 SmallVector<Value> extractedElements;
376 FailureOr<ReductionRewriteInfo> static getReductionInfo(
377 vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
378 ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
379 Type resultType = typeConverter.convertType(op.getType());
380 if (!resultType)
381 return failure();
383 auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
384 if (!srcVectorType || srcVectorType.getRank() != 1)
385 return rewriter.notifyMatchFailure(op, "not a 1-D vector source");
387 SmallVector<Value> extractedElements =
388 extractAllElements(op, adaptor, srcVectorType, rewriter);
390 return ReductionRewriteInfo{resultType, std::move(extractedElements)};
393 template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
394 typename SPIRVSMinOp>
395 struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
396 using OpConversionPattern::OpConversionPattern;
398 LogicalResult
399 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
400 ConversionPatternRewriter &rewriter) const override {
401 auto reductionInfo =
402 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
403 if (failed(reductionInfo))
404 return failure();
406 auto [resultType, extractedElements] = *reductionInfo;
407 Location loc = reduceOp->getLoc();
408 Value result = extractedElements.front();
409 for (Value next : llvm::drop_begin(extractedElements)) {
410 switch (reduceOp.getKind()) {
412 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
413 case vector::CombiningKind::kind: \
414 if (llvm::isa<IntegerType>(resultType)) { \
415 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
416 } else { \
417 assert(llvm::isa<FloatType>(resultType)); \
418 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
420 break
422 #define INT_OR_FLOAT_CASE(kind, fop) \
423 case vector::CombiningKind::kind: \
424 result = rewriter.create<fop>(loc, resultType, result, next); \
425 break
427 INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
428 INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
429 INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
430 INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
431 INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
432 INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
434 case vector::CombiningKind::AND:
435 case vector::CombiningKind::OR:
436 case vector::CombiningKind::XOR:
437 return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
438 default:
439 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
441 #undef INT_AND_FLOAT_CASE
442 #undef INT_OR_FLOAT_CASE
445 rewriter.replaceOp(reduceOp, result);
446 return success();
450 template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
451 struct VectorReductionFloatMinMax final
452 : OpConversionPattern<vector::ReductionOp> {
453 using OpConversionPattern::OpConversionPattern;
455 LogicalResult
456 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter) const override {
458 auto reductionInfo =
459 getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
460 if (failed(reductionInfo))
461 return failure();
463 auto [resultType, extractedElements] = *reductionInfo;
464 Location loc = reduceOp->getLoc();
465 Value result = extractedElements.front();
466 for (Value next : llvm::drop_begin(extractedElements)) {
467 switch (reduceOp.getKind()) {
469 #define INT_OR_FLOAT_CASE(kind, fop) \
470 case vector::CombiningKind::kind: \
471 result = rewriter.create<fop>(loc, resultType, result, next); \
472 break
474 INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
475 INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
476 INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
477 INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);
479 default:
480 return rewriter.notifyMatchFailure(reduceOp, "not handled here");
482 #undef INT_OR_FLOAT_CASE
485 rewriter.replaceOp(reduceOp, result);
486 return success();
490 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
491 public:
492 using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
494 LogicalResult
495 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
496 ConversionPatternRewriter &rewriter) const override {
497 Type dstType = getTypeConverter()->convertType(op.getType());
498 if (!dstType)
499 return failure();
500 if (isa<spirv::ScalarType>(dstType)) {
501 rewriter.replaceOp(op, adaptor.getInput());
502 } else {
503 auto dstVecType = cast<VectorType>(dstType);
504 SmallVector<Value, 4> source(dstVecType.getNumElements(),
505 adaptor.getInput());
506 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
507 source);
509 return success();
513 struct VectorShuffleOpConvert final
514 : public OpConversionPattern<vector::ShuffleOp> {
515 using OpConversionPattern::OpConversionPattern;
517 LogicalResult
518 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
519 ConversionPatternRewriter &rewriter) const override {
520 VectorType oldResultType = shuffleOp.getResultVectorType();
521 Type newResultType = getTypeConverter()->convertType(oldResultType);
522 if (!newResultType)
523 return rewriter.notifyMatchFailure(shuffleOp,
524 "unsupported result vector type");
526 auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
528 VectorType oldV1Type = shuffleOp.getV1VectorType();
529 VectorType oldV2Type = shuffleOp.getV2VectorType();
531 // When both operands and the result are SPIR-V vectors, emit a SPIR-V
532 // shuffle.
533 if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
534 oldResultType.getNumElements() > 1) {
535 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
536 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
537 rewriter.getI32ArrayAttr(mask));
538 return success();
541 // When at least one of the operands or the result becomes a scalar after
542 // type conversion for SPIR-V, extract all the required elements and
543 // construct the result vector.
544 auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
545 Value scalarOrVec, int32_t idx) -> Value {
546 if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
547 return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
548 idx);
550 assert(idx == 0 && "Invalid scalar element index");
551 return scalarOrVec;
554 int32_t numV1Elems = oldV1Type.getNumElements();
555 SmallVector<Value> newOperands(mask.size());
556 for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
557 Value vec = adaptor.getV1();
558 int32_t elementIdx = shuffleIdx;
559 if (elementIdx >= numV1Elems) {
560 vec = adaptor.getV2();
561 elementIdx -= numV1Elems;
564 newOperand = getElementAtIdx(vec, elementIdx);
567 // Handle the scalar result corner case.
568 if (newOperands.size() == 1) {
569 rewriter.replaceOp(shuffleOp, newOperands.front());
570 return success();
573 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
574 shuffleOp, newResultType, newOperands);
575 return success();
579 struct VectorInterleaveOpConvert final
580 : public OpConversionPattern<vector::InterleaveOp> {
581 using OpConversionPattern::OpConversionPattern;
583 LogicalResult
584 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
585 ConversionPatternRewriter &rewriter) const override {
586 // Check the result vector type.
587 VectorType oldResultType = interleaveOp.getResultVectorType();
588 Type newResultType = getTypeConverter()->convertType(oldResultType);
589 if (!newResultType)
590 return rewriter.notifyMatchFailure(interleaveOp,
591 "unsupported result vector type");
593 // Interleave the indices.
594 VectorType sourceType = interleaveOp.getSourceVectorType();
595 int n = sourceType.getNumElements();
597 // Input vectors of size 1 are converted to scalars by the type converter.
598 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
599 // use `spirv::CompositeConstructOp`.
600 if (n == 1) {
601 Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
602 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
603 interleaveOp, newResultType, newOperands);
604 return success();
607 auto seq = llvm::seq<int64_t>(2 * n);
608 auto indices = llvm::map_to_vector(
609 seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
611 // Emit a SPIR-V shuffle.
612 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
613 interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
614 rewriter.getI32ArrayAttr(indices));
616 return success();
620 struct VectorDeinterleaveOpConvert final
621 : public OpConversionPattern<vector::DeinterleaveOp> {
622 using OpConversionPattern::OpConversionPattern;
624 LogicalResult
625 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
626 ConversionPatternRewriter &rewriter) const override {
628 // Check the result vector type.
629 VectorType oldResultType = deinterleaveOp.getResultVectorType();
630 Type newResultType = getTypeConverter()->convertType(oldResultType);
631 if (!newResultType)
632 return rewriter.notifyMatchFailure(deinterleaveOp,
633 "unsupported result vector type");
635 Location loc = deinterleaveOp->getLoc();
637 // Deinterleave the indices.
638 Value sourceVector = adaptor.getSource();
639 VectorType sourceType = deinterleaveOp.getSourceVectorType();
640 int n = sourceType.getNumElements();
642 // Output vectors of size 1 are converted to scalars by the type converter.
643 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
644 // use `spirv::CompositeExtractOp`.
645 if (n == 2) {
646 auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
647 loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
649 auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
650 loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
652 rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
653 return success();
656 // Indices for `shuffleEven` (result 0).
657 auto seqEven = llvm::seq<int64_t>(n / 2);
658 auto indicesEven =
659 llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
661 // Indices for `shuffleOdd` (result 1).
662 auto seqOdd = llvm::seq<int64_t>(n / 2);
663 auto indicesOdd =
664 llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
666 // Create two SPIR-V shuffles.
667 auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
668 loc, newResultType, sourceVector, sourceVector,
669 rewriter.getI32ArrayAttr(indicesEven));
671 auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
672 loc, newResultType, sourceVector, sourceVector,
673 rewriter.getI32ArrayAttr(indicesOdd));
675 rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
676 return success();
680 struct VectorLoadOpConverter final
681 : public OpConversionPattern<vector::LoadOp> {
682 using OpConversionPattern::OpConversionPattern;
684 LogicalResult
685 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
686 ConversionPatternRewriter &rewriter) const override {
687 auto memrefType = loadOp.getMemRefType();
688 auto attr =
689 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
690 if (!attr)
691 return rewriter.notifyMatchFailure(
692 loadOp, "expected spirv.storage_class memory space");
694 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
695 auto loc = loadOp.getLoc();
696 Value accessChain =
697 spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
698 adaptor.getIndices(), loc, rewriter);
699 if (!accessChain)
700 return rewriter.notifyMatchFailure(
701 loadOp, "failed to get memref element pointer");
703 spirv::StorageClass storageClass = attr.getValue();
704 auto vectorType = loadOp.getVectorType();
705 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
706 Value castedAccessChain =
707 rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
708 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
709 castedAccessChain);
711 return success();
715 struct VectorStoreOpConverter final
716 : public OpConversionPattern<vector::StoreOp> {
717 using OpConversionPattern::OpConversionPattern;
719 LogicalResult
720 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter) const override {
722 auto memrefType = storeOp.getMemRefType();
723 auto attr =
724 dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
725 if (!attr)
726 return rewriter.notifyMatchFailure(
727 storeOp, "expected spirv.storage_class memory space");
729 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
730 auto loc = storeOp.getLoc();
731 Value accessChain =
732 spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
733 adaptor.getIndices(), loc, rewriter);
734 if (!accessChain)
735 return rewriter.notifyMatchFailure(
736 storeOp, "failed to get memref element pointer");
738 spirv::StorageClass storageClass = attr.getValue();
739 auto vectorType = storeOp.getVectorType();
740 auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
741 Value castedAccessChain =
742 rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
743 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
744 adaptor.getValueToStore());
746 return success();
750 struct VectorReductionToIntDotProd final
751 : OpRewritePattern<vector::ReductionOp> {
752 using OpRewritePattern::OpRewritePattern;
754 LogicalResult matchAndRewrite(vector::ReductionOp op,
755 PatternRewriter &rewriter) const override {
756 if (op.getKind() != vector::CombiningKind::ADD)
757 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
759 auto resultType = dyn_cast<IntegerType>(op.getType());
760 if (!resultType)
761 return rewriter.notifyMatchFailure(op, "result is not an integer");
763 int64_t resultBitwidth = resultType.getIntOrFloatBitWidth();
764 if (!llvm::is_contained({32, 64}, resultBitwidth))
765 return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");
767 VectorType inVecTy = op.getSourceVectorType();
768 if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) ||
769 inVecTy.getShape().size() != 1 || inVecTy.isScalable())
770 return rewriter.notifyMatchFailure(op, "unsupported vector shape");
772 auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
773 if (!mul)
774 return rewriter.notifyMatchFailure(
775 op, "reduction operand is not 'arith.muli'");
777 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
778 spirv::SDotAccSatOp, false>(op, mul, rewriter)))
779 return success();
781 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
782 spirv::UDotAccSatOp, false>(op, mul, rewriter)))
783 return success();
785 if (succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
786 spirv::SUDotAccSatOp, false>(op, mul, rewriter)))
787 return success();
789 if (succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
790 spirv::SUDotAccSatOp, true>(op, mul, rewriter)))
791 return success();
793 return failure();
796 private:
797 template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
798 typename DotAccOp, bool SwapOperands>
799 static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
800 PatternRewriter &rewriter) {
801 auto lhs = mul.getLhs().getDefiningOp<LhsExtensionOp>();
802 if (!lhs)
803 return failure();
804 Value lhsIn = lhs.getIn();
805 auto lhsInType = cast<VectorType>(lhsIn.getType());
806 if (!lhsInType.getElementType().isInteger(8))
807 return failure();
809 auto rhs = mul.getRhs().getDefiningOp<RhsExtensionOp>();
810 if (!rhs)
811 return failure();
812 Value rhsIn = rhs.getIn();
813 auto rhsInType = cast<VectorType>(rhsIn.getType());
814 if (!rhsInType.getElementType().isInteger(8))
815 return failure();
817 if (op.getSourceVectorType().getNumElements() == 3) {
818 IntegerType i8Type = rewriter.getI8Type();
819 auto v4i8Type = VectorType::get({4}, i8Type);
820 Location loc = op.getLoc();
821 Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
822 lhsIn = rewriter.create<spirv::CompositeConstructOp>(
823 loc, v4i8Type, ValueRange{lhsIn, zero});
824 rhsIn = rewriter.create<spirv::CompositeConstructOp>(
825 loc, v4i8Type, ValueRange{rhsIn, zero});
828 // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
829 // we have to swap operands instead in that case.
830 if (SwapOperands)
831 std::swap(lhsIn, rhsIn);
833 if (Value acc = op.getAcc()) {
834 rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsIn, rhsIn, acc,
835 nullptr);
836 } else {
837 rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsIn, rhsIn,
838 nullptr);
841 return success();
845 struct VectorReductionToFPDotProd final
846 : OpConversionPattern<vector::ReductionOp> {
847 using OpConversionPattern::OpConversionPattern;
849 LogicalResult
850 matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
851 ConversionPatternRewriter &rewriter) const override {
852 if (op.getKind() != vector::CombiningKind::ADD)
853 return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
855 auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
856 if (!resultType)
857 return rewriter.notifyMatchFailure(op, "result is not a float");
859 Value vec = adaptor.getVector();
860 Value acc = adaptor.getAcc();
862 auto vectorType = dyn_cast<VectorType>(vec.getType());
863 if (!vectorType) {
864 assert(isa<FloatType>(vec.getType()) &&
865 "Expected the vector to be scalarized");
866 if (acc) {
867 rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
868 return success();
871 rewriter.replaceOp(op, vec);
872 return success();
875 Location loc = op.getLoc();
876 Value lhs;
877 Value rhs;
878 if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
879 lhs = mul.getLhs();
880 rhs = mul.getRhs();
881 } else {
882 // If the operand is not a mul, use a vector of ones for the dot operand
883 // to just sum up all values.
884 lhs = vec;
885 Attribute oneAttr =
886 rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
887 oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
888 rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
890 assert(lhs);
891 assert(rhs);
893 Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
894 if (acc)
895 res = rewriter.create<spirv::FAddOp>(loc, acc, res);
897 rewriter.replaceOp(op, res);
898 return success();
902 struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
903 using OpConversionPattern::OpConversionPattern;
905 LogicalResult
906 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
907 ConversionPatternRewriter &rewriter) const override {
908 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
909 Type dstType = typeConverter.convertType(stepOp.getType());
910 if (!dstType)
911 return failure();
913 Location loc = stepOp.getLoc();
914 int64_t numElements = stepOp.getType().getNumElements();
915 auto intType =
916 rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
918 // Input vectors of size 1 are converted to scalars by the type converter.
919 // We just create a constant in this case.
920 if (numElements == 1) {
921 Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
922 rewriter.replaceOp(stepOp, zero);
923 return success();
926 SmallVector<Value> source;
927 source.reserve(numElements);
928 for (int64_t i = 0; i < numElements; ++i) {
929 Attribute intAttr = rewriter.getIntegerAttr(intType, i);
930 Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
931 source.push_back(constOp);
933 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
934 source);
935 return success();
939 } // namespace
// Template-argument lists for the reduction patterns. The integer lists are
// ordered <UMax, UMin, SMax, SMin>, matching the template parameters of
// VectorReductionPattern.
#define CL_INT_MAX_MIN_OPS                                                     \
  spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp

#define GL_INT_MAX_MIN_OPS                                                     \
  spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp

// The float lists are ordered <FMax, FMin>, matching the template parameters
// of VectorReductionFloatMinMax.
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
949 void mlir::populateVectorToSPIRVPatterns(
950 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
951 patterns.add<
952 VectorBitcastConvert, VectorBroadcastConvert,
953 VectorExtractElementOpConvert, VectorExtractOpConvert,
954 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
955 VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
956 VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
957 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
958 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
959 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
960 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
961 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
962 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
963 VectorStepOpConvert>(typeConverter, patterns.getContext(),
964 PatternBenefit(1));
966 // Make sure that the more specialized dot product pattern has higher benefit
967 // than the generic one that extracts all elements.
968 patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
969 PatternBenefit(2));
972 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
973 RewritePatternSet &patterns) {
974 patterns.add<VectorReductionToIntDotProd>(patterns.getContext());