//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Vector dialect to the SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>
#include <utility>

using namespace mlir;
41 /// Returns the integer value from the first valid input element, assuming Value
42 /// inputs are defined by a constant index ops and Attribute inputs are integer
44 static uint64_t getFirstIntValue(ArrayAttr attr
) {
45 return (*attr
.getAsValueRange
<IntegerAttr
>().begin()).getZExtValue();
48 /// Returns the number of bits for the given scalar/vector type.
49 static int getNumBits(Type type
) {
50 // TODO: This does not take into account any memory layout or widening
51 // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
52 // though in practice it will likely be stored as in a 4xi64 vector register.
53 if (auto vectorType
= dyn_cast
<VectorType
>(type
))
54 return vectorType
.getNumElements() * vectorType
.getElementTypeBitWidth();
55 return type
.getIntOrFloatBitWidth();
60 struct VectorShapeCast final
: public OpConversionPattern
<vector::ShapeCastOp
> {
61 using OpConversionPattern::OpConversionPattern
;
64 matchAndRewrite(vector::ShapeCastOp shapeCastOp
, OpAdaptor adaptor
,
65 ConversionPatternRewriter
&rewriter
) const override
{
66 Type dstType
= getTypeConverter()->convertType(shapeCastOp
.getType());
70 // If dstType is same as the source type or the vector size is 1, it can be
71 // directly replaced by the source.
72 if (dstType
== adaptor
.getSource().getType() ||
73 shapeCastOp
.getResultVectorType().getNumElements() == 1) {
74 rewriter
.replaceOp(shapeCastOp
, adaptor
.getSource());
78 // Lowering for size-n vectors when n > 1 hasn't been implemented.
83 struct VectorBitcastConvert final
84 : public OpConversionPattern
<vector::BitCastOp
> {
85 using OpConversionPattern::OpConversionPattern
;
88 matchAndRewrite(vector::BitCastOp bitcastOp
, OpAdaptor adaptor
,
89 ConversionPatternRewriter
&rewriter
) const override
{
90 Type dstType
= getTypeConverter()->convertType(bitcastOp
.getType());
94 if (dstType
== adaptor
.getSource().getType()) {
95 rewriter
.replaceOp(bitcastOp
, adaptor
.getSource());
99 // Check that the source and destination type have the same bitwidth.
100 // Depending on the target environment, we may need to emulate certain
101 // types, which can cause issue with bitcast.
102 Type srcType
= adaptor
.getSource().getType();
103 if (getNumBits(dstType
) != getNumBits(srcType
)) {
104 return rewriter
.notifyMatchFailure(
106 llvm::formatv("different source ({0}) and target ({1}) bitwidth",
110 rewriter
.replaceOpWithNewOp
<spirv::BitcastOp
>(bitcastOp
, dstType
,
111 adaptor
.getSource());
116 struct VectorBroadcastConvert final
117 : public OpConversionPattern
<vector::BroadcastOp
> {
118 using OpConversionPattern::OpConversionPattern
;
121 matchAndRewrite(vector::BroadcastOp castOp
, OpAdaptor adaptor
,
122 ConversionPatternRewriter
&rewriter
) const override
{
124 getTypeConverter()->convertType(castOp
.getResultVectorType());
128 if (isa
<spirv::ScalarType
>(resultType
)) {
129 rewriter
.replaceOp(castOp
, adaptor
.getSource());
133 SmallVector
<Value
, 4> source(castOp
.getResultVectorType().getNumElements(),
134 adaptor
.getSource());
135 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(castOp
, resultType
,
141 struct VectorExtractOpConvert final
142 : public OpConversionPattern
<vector::ExtractOp
> {
143 using OpConversionPattern::OpConversionPattern
;
146 matchAndRewrite(vector::ExtractOp extractOp
, OpAdaptor adaptor
,
147 ConversionPatternRewriter
&rewriter
) const override
{
148 Type dstType
= getTypeConverter()->convertType(extractOp
.getType());
152 if (isa
<spirv::ScalarType
>(adaptor
.getVector().getType())) {
153 rewriter
.replaceOp(extractOp
, adaptor
.getVector());
157 if (std::optional
<int64_t> id
=
158 getConstantIntValue(extractOp
.getMixedPosition()[0]))
159 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(
160 extractOp
, dstType
, adaptor
.getVector(),
161 rewriter
.getI32ArrayAttr(id
.value()));
163 rewriter
.replaceOpWithNewOp
<spirv::VectorExtractDynamicOp
>(
164 extractOp
, dstType
, adaptor
.getVector(),
165 adaptor
.getDynamicPosition()[0]);
170 struct VectorExtractStridedSliceOpConvert final
171 : public OpConversionPattern
<vector::ExtractStridedSliceOp
> {
172 using OpConversionPattern::OpConversionPattern
;
175 matchAndRewrite(vector::ExtractStridedSliceOp extractOp
, OpAdaptor adaptor
,
176 ConversionPatternRewriter
&rewriter
) const override
{
177 Type dstType
= getTypeConverter()->convertType(extractOp
.getType());
181 uint64_t offset
= getFirstIntValue(extractOp
.getOffsets());
182 uint64_t size
= getFirstIntValue(extractOp
.getSizes());
183 uint64_t stride
= getFirstIntValue(extractOp
.getStrides());
187 Value srcVector
= adaptor
.getOperands().front();
189 // Extract vector<1xT> case.
190 if (isa
<spirv::ScalarType
>(dstType
)) {
191 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(extractOp
,
196 SmallVector
<int32_t, 2> indices(size
);
197 std::iota(indices
.begin(), indices
.end(), offset
);
199 rewriter
.replaceOpWithNewOp
<spirv::VectorShuffleOp
>(
200 extractOp
, dstType
, srcVector
, srcVector
,
201 rewriter
.getI32ArrayAttr(indices
));
207 template <class SPIRVFMAOp
>
208 struct VectorFmaOpConvert final
: public OpConversionPattern
<vector::FMAOp
> {
209 using OpConversionPattern::OpConversionPattern
;
212 matchAndRewrite(vector::FMAOp fmaOp
, OpAdaptor adaptor
,
213 ConversionPatternRewriter
&rewriter
) const override
{
214 Type dstType
= getTypeConverter()->convertType(fmaOp
.getType());
217 rewriter
.replaceOpWithNewOp
<SPIRVFMAOp
>(fmaOp
, dstType
, adaptor
.getLhs(),
218 adaptor
.getRhs(), adaptor
.getAcc());
223 struct VectorInsertOpConvert final
224 : public OpConversionPattern
<vector::InsertOp
> {
225 using OpConversionPattern::OpConversionPattern
;
228 matchAndRewrite(vector::InsertOp insertOp
, OpAdaptor adaptor
,
229 ConversionPatternRewriter
&rewriter
) const override
{
230 if (isa
<VectorType
>(insertOp
.getSourceType()))
231 return rewriter
.notifyMatchFailure(insertOp
, "unsupported vector source");
232 if (!getTypeConverter()->convertType(insertOp
.getDestVectorType()))
233 return rewriter
.notifyMatchFailure(insertOp
,
234 "unsupported dest vector type");
236 // Special case for inserting scalar values into size-1 vectors.
237 if (insertOp
.getSourceType().isIntOrFloat() &&
238 insertOp
.getDestVectorType().getNumElements() == 1) {
239 rewriter
.replaceOp(insertOp
, adaptor
.getSource());
243 if (std::optional
<int64_t> id
=
244 getConstantIntValue(insertOp
.getMixedPosition()[0]))
245 rewriter
.replaceOpWithNewOp
<spirv::CompositeInsertOp
>(
246 insertOp
, adaptor
.getSource(), adaptor
.getDest(), id
.value());
248 rewriter
.replaceOpWithNewOp
<spirv::VectorInsertDynamicOp
>(
249 insertOp
, insertOp
.getDest(), adaptor
.getSource(),
250 adaptor
.getDynamicPosition()[0]);
255 struct VectorExtractElementOpConvert final
256 : public OpConversionPattern
<vector::ExtractElementOp
> {
257 using OpConversionPattern::OpConversionPattern
;
260 matchAndRewrite(vector::ExtractElementOp extractOp
, OpAdaptor adaptor
,
261 ConversionPatternRewriter
&rewriter
) const override
{
262 Type resultType
= getTypeConverter()->convertType(extractOp
.getType());
266 if (isa
<spirv::ScalarType
>(adaptor
.getVector().getType())) {
267 rewriter
.replaceOp(extractOp
, adaptor
.getVector());
272 if (matchPattern(adaptor
.getPosition(), m_ConstantInt(&cstPos
)))
273 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(
274 extractOp
, resultType
, adaptor
.getVector(),
275 rewriter
.getI32ArrayAttr({static_cast<int>(cstPos
.getSExtValue())}));
277 rewriter
.replaceOpWithNewOp
<spirv::VectorExtractDynamicOp
>(
278 extractOp
, resultType
, adaptor
.getVector(), adaptor
.getPosition());
283 struct VectorInsertElementOpConvert final
284 : public OpConversionPattern
<vector::InsertElementOp
> {
285 using OpConversionPattern::OpConversionPattern
;
288 matchAndRewrite(vector::InsertElementOp insertOp
, OpAdaptor adaptor
,
289 ConversionPatternRewriter
&rewriter
) const override
{
290 Type vectorType
= getTypeConverter()->convertType(insertOp
.getType());
294 if (isa
<spirv::ScalarType
>(vectorType
)) {
295 rewriter
.replaceOp(insertOp
, adaptor
.getSource());
300 if (matchPattern(adaptor
.getPosition(), m_ConstantInt(&cstPos
)))
301 rewriter
.replaceOpWithNewOp
<spirv::CompositeInsertOp
>(
302 insertOp
, adaptor
.getSource(), adaptor
.getDest(),
303 cstPos
.getSExtValue());
305 rewriter
.replaceOpWithNewOp
<spirv::VectorInsertDynamicOp
>(
306 insertOp
, vectorType
, insertOp
.getDest(), adaptor
.getSource(),
307 adaptor
.getPosition());
312 struct VectorInsertStridedSliceOpConvert final
313 : public OpConversionPattern
<vector::InsertStridedSliceOp
> {
314 using OpConversionPattern::OpConversionPattern
;
317 matchAndRewrite(vector::InsertStridedSliceOp insertOp
, OpAdaptor adaptor
,
318 ConversionPatternRewriter
&rewriter
) const override
{
319 Value srcVector
= adaptor
.getOperands().front();
320 Value dstVector
= adaptor
.getOperands().back();
322 uint64_t stride
= getFirstIntValue(insertOp
.getStrides());
325 uint64_t offset
= getFirstIntValue(insertOp
.getOffsets());
327 if (isa
<spirv::ScalarType
>(srcVector
.getType())) {
328 assert(!isa
<spirv::ScalarType
>(dstVector
.getType()));
329 rewriter
.replaceOpWithNewOp
<spirv::CompositeInsertOp
>(
330 insertOp
, dstVector
.getType(), srcVector
, dstVector
,
331 rewriter
.getI32ArrayAttr(offset
));
335 uint64_t totalSize
= cast
<VectorType
>(dstVector
.getType()).getNumElements();
336 uint64_t insertSize
=
337 cast
<VectorType
>(srcVector
.getType()).getNumElements();
339 SmallVector
<int32_t, 2> indices(totalSize
);
340 std::iota(indices
.begin(), indices
.end(), 0);
341 std::iota(indices
.begin() + offset
, indices
.begin() + offset
+ insertSize
,
344 rewriter
.replaceOpWithNewOp
<spirv::VectorShuffleOp
>(
345 insertOp
, dstVector
.getType(), dstVector
, srcVector
,
346 rewriter
.getI32ArrayAttr(indices
));
352 static SmallVector
<Value
> extractAllElements(
353 vector::ReductionOp reduceOp
, vector::ReductionOp::Adaptor adaptor
,
354 VectorType srcVectorType
, ConversionPatternRewriter
&rewriter
) {
355 int numElements
= static_cast<int>(srcVectorType
.getDimSize(0));
356 SmallVector
<Value
> values
;
357 values
.reserve(numElements
+ (adaptor
.getAcc() ? 1 : 0));
358 Location loc
= reduceOp
.getLoc();
360 for (int i
= 0; i
< numElements
; ++i
) {
361 values
.push_back(rewriter
.create
<spirv::CompositeExtractOp
>(
362 loc
, srcVectorType
.getElementType(), adaptor
.getVector(),
363 rewriter
.getI32ArrayAttr({i
})));
365 if (Value acc
= adaptor
.getAcc())
366 values
.push_back(acc
);
371 struct ReductionRewriteInfo
{
373 SmallVector
<Value
> extractedElements
;
376 FailureOr
<ReductionRewriteInfo
> static getReductionInfo(
377 vector::ReductionOp op
, vector::ReductionOp::Adaptor adaptor
,
378 ConversionPatternRewriter
&rewriter
, const TypeConverter
&typeConverter
) {
379 Type resultType
= typeConverter
.convertType(op
.getType());
383 auto srcVectorType
= dyn_cast
<VectorType
>(adaptor
.getVector().getType());
384 if (!srcVectorType
|| srcVectorType
.getRank() != 1)
385 return rewriter
.notifyMatchFailure(op
, "not a 1-D vector source");
387 SmallVector
<Value
> extractedElements
=
388 extractAllElements(op
, adaptor
, srcVectorType
, rewriter
);
390 return ReductionRewriteInfo
{resultType
, std::move(extractedElements
)};
393 template <typename SPIRVUMaxOp
, typename SPIRVUMinOp
, typename SPIRVSMaxOp
,
394 typename SPIRVSMinOp
>
395 struct VectorReductionPattern final
: OpConversionPattern
<vector::ReductionOp
> {
396 using OpConversionPattern::OpConversionPattern
;
399 matchAndRewrite(vector::ReductionOp reduceOp
, OpAdaptor adaptor
,
400 ConversionPatternRewriter
&rewriter
) const override
{
402 getReductionInfo(reduceOp
, adaptor
, rewriter
, *getTypeConverter());
403 if (failed(reductionInfo
))
406 auto [resultType
, extractedElements
] = *reductionInfo
;
407 Location loc
= reduceOp
->getLoc();
408 Value result
= extractedElements
.front();
409 for (Value next
: llvm::drop_begin(extractedElements
)) {
410 switch (reduceOp
.getKind()) {
412 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
413 case vector::CombiningKind::kind: \
414 if (llvm::isa<IntegerType>(resultType)) { \
415 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
417 assert(llvm::isa<FloatType>(resultType)); \
418 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
422 #define INT_OR_FLOAT_CASE(kind, fop) \
423 case vector::CombiningKind::kind: \
424 result = rewriter.create<fop>(loc, resultType, result, next); \
427 INT_AND_FLOAT_CASE(ADD
, IAddOp
, FAddOp
);
428 INT_AND_FLOAT_CASE(MUL
, IMulOp
, FMulOp
);
429 INT_OR_FLOAT_CASE(MINUI
, SPIRVUMinOp
);
430 INT_OR_FLOAT_CASE(MINSI
, SPIRVSMinOp
);
431 INT_OR_FLOAT_CASE(MAXUI
, SPIRVUMaxOp
);
432 INT_OR_FLOAT_CASE(MAXSI
, SPIRVSMaxOp
);
434 case vector::CombiningKind::AND
:
435 case vector::CombiningKind::OR
:
436 case vector::CombiningKind::XOR
:
437 return rewriter
.notifyMatchFailure(reduceOp
, "unimplemented");
439 return rewriter
.notifyMatchFailure(reduceOp
, "not handled here");
441 #undef INT_AND_FLOAT_CASE
442 #undef INT_OR_FLOAT_CASE
445 rewriter
.replaceOp(reduceOp
, result
);
450 template <typename SPIRVFMaxOp
, typename SPIRVFMinOp
>
451 struct VectorReductionFloatMinMax final
452 : OpConversionPattern
<vector::ReductionOp
> {
453 using OpConversionPattern::OpConversionPattern
;
456 matchAndRewrite(vector::ReductionOp reduceOp
, OpAdaptor adaptor
,
457 ConversionPatternRewriter
&rewriter
) const override
{
459 getReductionInfo(reduceOp
, adaptor
, rewriter
, *getTypeConverter());
460 if (failed(reductionInfo
))
463 auto [resultType
, extractedElements
] = *reductionInfo
;
464 Location loc
= reduceOp
->getLoc();
465 Value result
= extractedElements
.front();
466 for (Value next
: llvm::drop_begin(extractedElements
)) {
467 switch (reduceOp
.getKind()) {
469 #define INT_OR_FLOAT_CASE(kind, fop) \
470 case vector::CombiningKind::kind: \
471 result = rewriter.create<fop>(loc, resultType, result, next); \
474 INT_OR_FLOAT_CASE(MAXIMUMF
, SPIRVFMaxOp
);
475 INT_OR_FLOAT_CASE(MINIMUMF
, SPIRVFMinOp
);
476 INT_OR_FLOAT_CASE(MAXNUMF
, SPIRVFMaxOp
);
477 INT_OR_FLOAT_CASE(MINNUMF
, SPIRVFMinOp
);
480 return rewriter
.notifyMatchFailure(reduceOp
, "not handled here");
482 #undef INT_OR_FLOAT_CASE
485 rewriter
.replaceOp(reduceOp
, result
);
490 class VectorSplatPattern final
: public OpConversionPattern
<vector::SplatOp
> {
492 using OpConversionPattern
<vector::SplatOp
>::OpConversionPattern
;
495 matchAndRewrite(vector::SplatOp op
, OpAdaptor adaptor
,
496 ConversionPatternRewriter
&rewriter
) const override
{
497 Type dstType
= getTypeConverter()->convertType(op
.getType());
500 if (isa
<spirv::ScalarType
>(dstType
)) {
501 rewriter
.replaceOp(op
, adaptor
.getInput());
503 auto dstVecType
= cast
<VectorType
>(dstType
);
504 SmallVector
<Value
, 4> source(dstVecType
.getNumElements(),
506 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(op
, dstType
,
513 struct VectorShuffleOpConvert final
514 : public OpConversionPattern
<vector::ShuffleOp
> {
515 using OpConversionPattern::OpConversionPattern
;
518 matchAndRewrite(vector::ShuffleOp shuffleOp
, OpAdaptor adaptor
,
519 ConversionPatternRewriter
&rewriter
) const override
{
520 VectorType oldResultType
= shuffleOp
.getResultVectorType();
521 Type newResultType
= getTypeConverter()->convertType(oldResultType
);
523 return rewriter
.notifyMatchFailure(shuffleOp
,
524 "unsupported result vector type");
526 auto mask
= llvm::to_vector_of
<int32_t>(shuffleOp
.getMask());
528 VectorType oldV1Type
= shuffleOp
.getV1VectorType();
529 VectorType oldV2Type
= shuffleOp
.getV2VectorType();
531 // When both operands and the result are SPIR-V vectors, emit a SPIR-V
533 if (oldV1Type
.getNumElements() > 1 && oldV2Type
.getNumElements() > 1 &&
534 oldResultType
.getNumElements() > 1) {
535 rewriter
.replaceOpWithNewOp
<spirv::VectorShuffleOp
>(
536 shuffleOp
, newResultType
, adaptor
.getV1(), adaptor
.getV2(),
537 rewriter
.getI32ArrayAttr(mask
));
541 // When at least one of the operands or the result becomes a scalar after
542 // type conversion for SPIR-V, extract all the required elements and
543 // construct the result vector.
544 auto getElementAtIdx
= [&rewriter
, loc
= shuffleOp
.getLoc()](
545 Value scalarOrVec
, int32_t idx
) -> Value
{
546 if (auto vecTy
= dyn_cast
<VectorType
>(scalarOrVec
.getType()))
547 return rewriter
.create
<spirv::CompositeExtractOp
>(loc
, scalarOrVec
,
550 assert(idx
== 0 && "Invalid scalar element index");
554 int32_t numV1Elems
= oldV1Type
.getNumElements();
555 SmallVector
<Value
> newOperands(mask
.size());
556 for (auto [shuffleIdx
, newOperand
] : llvm::zip_equal(mask
, newOperands
)) {
557 Value vec
= adaptor
.getV1();
558 int32_t elementIdx
= shuffleIdx
;
559 if (elementIdx
>= numV1Elems
) {
560 vec
= adaptor
.getV2();
561 elementIdx
-= numV1Elems
;
564 newOperand
= getElementAtIdx(vec
, elementIdx
);
567 // Handle the scalar result corner case.
568 if (newOperands
.size() == 1) {
569 rewriter
.replaceOp(shuffleOp
, newOperands
.front());
573 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(
574 shuffleOp
, newResultType
, newOperands
);
579 struct VectorInterleaveOpConvert final
580 : public OpConversionPattern
<vector::InterleaveOp
> {
581 using OpConversionPattern::OpConversionPattern
;
584 matchAndRewrite(vector::InterleaveOp interleaveOp
, OpAdaptor adaptor
,
585 ConversionPatternRewriter
&rewriter
) const override
{
586 // Check the result vector type.
587 VectorType oldResultType
= interleaveOp
.getResultVectorType();
588 Type newResultType
= getTypeConverter()->convertType(oldResultType
);
590 return rewriter
.notifyMatchFailure(interleaveOp
,
591 "unsupported result vector type");
593 // Interleave the indices.
594 VectorType sourceType
= interleaveOp
.getSourceVectorType();
595 int n
= sourceType
.getNumElements();
597 // Input vectors of size 1 are converted to scalars by the type converter.
598 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
599 // use `spirv::CompositeConstructOp`.
601 Value newOperands
[] = {adaptor
.getLhs(), adaptor
.getRhs()};
602 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(
603 interleaveOp
, newResultType
, newOperands
);
607 auto seq
= llvm::seq
<int64_t>(2 * n
);
608 auto indices
= llvm::map_to_vector(
609 seq
, [n
](int i
) { return (i
% 2 ? n
: 0) + i
/ 2; });
611 // Emit a SPIR-V shuffle.
612 rewriter
.replaceOpWithNewOp
<spirv::VectorShuffleOp
>(
613 interleaveOp
, newResultType
, adaptor
.getLhs(), adaptor
.getRhs(),
614 rewriter
.getI32ArrayAttr(indices
));
620 struct VectorDeinterleaveOpConvert final
621 : public OpConversionPattern
<vector::DeinterleaveOp
> {
622 using OpConversionPattern::OpConversionPattern
;
625 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp
, OpAdaptor adaptor
,
626 ConversionPatternRewriter
&rewriter
) const override
{
628 // Check the result vector type.
629 VectorType oldResultType
= deinterleaveOp
.getResultVectorType();
630 Type newResultType
= getTypeConverter()->convertType(oldResultType
);
632 return rewriter
.notifyMatchFailure(deinterleaveOp
,
633 "unsupported result vector type");
635 Location loc
= deinterleaveOp
->getLoc();
637 // Deinterleave the indices.
638 Value sourceVector
= adaptor
.getSource();
639 VectorType sourceType
= deinterleaveOp
.getSourceVectorType();
640 int n
= sourceType
.getNumElements();
642 // Output vectors of size 1 are converted to scalars by the type converter.
643 // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
644 // use `spirv::CompositeExtractOp`.
646 auto elem0
= rewriter
.create
<spirv::CompositeExtractOp
>(
647 loc
, newResultType
, sourceVector
, rewriter
.getI32ArrayAttr({0}));
649 auto elem1
= rewriter
.create
<spirv::CompositeExtractOp
>(
650 loc
, newResultType
, sourceVector
, rewriter
.getI32ArrayAttr({1}));
652 rewriter
.replaceOp(deinterleaveOp
, {elem0
, elem1
});
656 // Indices for `shuffleEven` (result 0).
657 auto seqEven
= llvm::seq
<int64_t>(n
/ 2);
659 llvm::map_to_vector(seqEven
, [](int i
) { return i
* 2; });
661 // Indices for `shuffleOdd` (result 1).
662 auto seqOdd
= llvm::seq
<int64_t>(n
/ 2);
664 llvm::map_to_vector(seqOdd
, [](int i
) { return i
* 2 + 1; });
666 // Create two SPIR-V shuffles.
667 auto shuffleEven
= rewriter
.create
<spirv::VectorShuffleOp
>(
668 loc
, newResultType
, sourceVector
, sourceVector
,
669 rewriter
.getI32ArrayAttr(indicesEven
));
671 auto shuffleOdd
= rewriter
.create
<spirv::VectorShuffleOp
>(
672 loc
, newResultType
, sourceVector
, sourceVector
,
673 rewriter
.getI32ArrayAttr(indicesOdd
));
675 rewriter
.replaceOp(deinterleaveOp
, {shuffleEven
, shuffleOdd
});
680 struct VectorLoadOpConverter final
681 : public OpConversionPattern
<vector::LoadOp
> {
682 using OpConversionPattern::OpConversionPattern
;
685 matchAndRewrite(vector::LoadOp loadOp
, OpAdaptor adaptor
,
686 ConversionPatternRewriter
&rewriter
) const override
{
687 auto memrefType
= loadOp
.getMemRefType();
689 dyn_cast_or_null
<spirv::StorageClassAttr
>(memrefType
.getMemorySpace());
691 return rewriter
.notifyMatchFailure(
692 loadOp
, "expected spirv.storage_class memory space");
694 const auto &typeConverter
= *getTypeConverter
<SPIRVTypeConverter
>();
695 auto loc
= loadOp
.getLoc();
697 spirv::getElementPtr(typeConverter
, memrefType
, adaptor
.getBase(),
698 adaptor
.getIndices(), loc
, rewriter
);
700 return rewriter
.notifyMatchFailure(
701 loadOp
, "failed to get memref element pointer");
703 spirv::StorageClass storageClass
= attr
.getValue();
704 auto vectorType
= loadOp
.getVectorType();
705 auto vectorPtrType
= spirv::PointerType::get(vectorType
, storageClass
);
706 Value castedAccessChain
=
707 rewriter
.create
<spirv::BitcastOp
>(loc
, vectorPtrType
, accessChain
);
708 rewriter
.replaceOpWithNewOp
<spirv::LoadOp
>(loadOp
, vectorType
,
715 struct VectorStoreOpConverter final
716 : public OpConversionPattern
<vector::StoreOp
> {
717 using OpConversionPattern::OpConversionPattern
;
720 matchAndRewrite(vector::StoreOp storeOp
, OpAdaptor adaptor
,
721 ConversionPatternRewriter
&rewriter
) const override
{
722 auto memrefType
= storeOp
.getMemRefType();
724 dyn_cast_or_null
<spirv::StorageClassAttr
>(memrefType
.getMemorySpace());
726 return rewriter
.notifyMatchFailure(
727 storeOp
, "expected spirv.storage_class memory space");
729 const auto &typeConverter
= *getTypeConverter
<SPIRVTypeConverter
>();
730 auto loc
= storeOp
.getLoc();
732 spirv::getElementPtr(typeConverter
, memrefType
, adaptor
.getBase(),
733 adaptor
.getIndices(), loc
, rewriter
);
735 return rewriter
.notifyMatchFailure(
736 storeOp
, "failed to get memref element pointer");
738 spirv::StorageClass storageClass
= attr
.getValue();
739 auto vectorType
= storeOp
.getVectorType();
740 auto vectorPtrType
= spirv::PointerType::get(vectorType
, storageClass
);
741 Value castedAccessChain
=
742 rewriter
.create
<spirv::BitcastOp
>(loc
, vectorPtrType
, accessChain
);
743 rewriter
.replaceOpWithNewOp
<spirv::StoreOp
>(storeOp
, castedAccessChain
,
744 adaptor
.getValueToStore());
750 struct VectorReductionToIntDotProd final
751 : OpRewritePattern
<vector::ReductionOp
> {
752 using OpRewritePattern::OpRewritePattern
;
754 LogicalResult
matchAndRewrite(vector::ReductionOp op
,
755 PatternRewriter
&rewriter
) const override
{
756 if (op
.getKind() != vector::CombiningKind::ADD
)
757 return rewriter
.notifyMatchFailure(op
, "combining kind is not 'add'");
759 auto resultType
= dyn_cast
<IntegerType
>(op
.getType());
761 return rewriter
.notifyMatchFailure(op
, "result is not an integer");
763 int64_t resultBitwidth
= resultType
.getIntOrFloatBitWidth();
764 if (!llvm::is_contained({32, 64}, resultBitwidth
))
765 return rewriter
.notifyMatchFailure(op
, "unsupported integer bitwidth");
767 VectorType inVecTy
= op
.getSourceVectorType();
768 if (!llvm::is_contained({4, 3}, inVecTy
.getNumElements()) ||
769 inVecTy
.getShape().size() != 1 || inVecTy
.isScalable())
770 return rewriter
.notifyMatchFailure(op
, "unsupported vector shape");
772 auto mul
= op
.getVector().getDefiningOp
<arith::MulIOp
>();
774 return rewriter
.notifyMatchFailure(
775 op
, "reduction operand is not 'arith.muli'");
777 if (succeeded(handleCase
<arith::ExtSIOp
, arith::ExtSIOp
, spirv::SDotOp
,
778 spirv::SDotAccSatOp
, false>(op
, mul
, rewriter
)))
781 if (succeeded(handleCase
<arith::ExtUIOp
, arith::ExtUIOp
, spirv::UDotOp
,
782 spirv::UDotAccSatOp
, false>(op
, mul
, rewriter
)))
785 if (succeeded(handleCase
<arith::ExtSIOp
, arith::ExtUIOp
, spirv::SUDotOp
,
786 spirv::SUDotAccSatOp
, false>(op
, mul
, rewriter
)))
789 if (succeeded(handleCase
<arith::ExtUIOp
, arith::ExtSIOp
, spirv::SUDotOp
,
790 spirv::SUDotAccSatOp
, true>(op
, mul
, rewriter
)))
797 template <typename LhsExtensionOp
, typename RhsExtensionOp
, typename DotOp
,
798 typename DotAccOp
, bool SwapOperands
>
799 static LogicalResult
handleCase(vector::ReductionOp op
, arith::MulIOp mul
,
800 PatternRewriter
&rewriter
) {
801 auto lhs
= mul
.getLhs().getDefiningOp
<LhsExtensionOp
>();
804 Value lhsIn
= lhs
.getIn();
805 auto lhsInType
= cast
<VectorType
>(lhsIn
.getType());
806 if (!lhsInType
.getElementType().isInteger(8))
809 auto rhs
= mul
.getRhs().getDefiningOp
<RhsExtensionOp
>();
812 Value rhsIn
= rhs
.getIn();
813 auto rhsInType
= cast
<VectorType
>(rhsIn
.getType());
814 if (!rhsInType
.getElementType().isInteger(8))
817 if (op
.getSourceVectorType().getNumElements() == 3) {
818 IntegerType i8Type
= rewriter
.getI8Type();
819 auto v4i8Type
= VectorType::get({4}, i8Type
);
820 Location loc
= op
.getLoc();
821 Value zero
= spirv::ConstantOp::getZero(i8Type
, loc
, rewriter
);
822 lhsIn
= rewriter
.create
<spirv::CompositeConstructOp
>(
823 loc
, v4i8Type
, ValueRange
{lhsIn
, zero
});
824 rhsIn
= rewriter
.create
<spirv::CompositeConstructOp
>(
825 loc
, v4i8Type
, ValueRange
{rhsIn
, zero
});
828 // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
829 // we have to swap operands instead in that case.
831 std::swap(lhsIn
, rhsIn
);
833 if (Value acc
= op
.getAcc()) {
834 rewriter
.replaceOpWithNewOp
<DotAccOp
>(op
, op
.getType(), lhsIn
, rhsIn
, acc
,
837 rewriter
.replaceOpWithNewOp
<DotOp
>(op
, op
.getType(), lhsIn
, rhsIn
,
845 struct VectorReductionToFPDotProd final
846 : OpConversionPattern
<vector::ReductionOp
> {
847 using OpConversionPattern::OpConversionPattern
;
850 matchAndRewrite(vector::ReductionOp op
, OpAdaptor adaptor
,
851 ConversionPatternRewriter
&rewriter
) const override
{
852 if (op
.getKind() != vector::CombiningKind::ADD
)
853 return rewriter
.notifyMatchFailure(op
, "combining kind is not 'add'");
855 auto resultType
= getTypeConverter()->convertType
<FloatType
>(op
.getType());
857 return rewriter
.notifyMatchFailure(op
, "result is not a float");
859 Value vec
= adaptor
.getVector();
860 Value acc
= adaptor
.getAcc();
862 auto vectorType
= dyn_cast
<VectorType
>(vec
.getType());
864 assert(isa
<FloatType
>(vec
.getType()) &&
865 "Expected the vector to be scalarized");
867 rewriter
.replaceOpWithNewOp
<spirv::FAddOp
>(op
, acc
, vec
);
871 rewriter
.replaceOp(op
, vec
);
875 Location loc
= op
.getLoc();
878 if (auto mul
= vec
.getDefiningOp
<arith::MulFOp
>()) {
882 // If the operand is not a mul, use a vector of ones for the dot operand
883 // to just sum up all values.
886 rewriter
.getFloatAttr(vectorType
.getElementType(), 1.0);
887 oneAttr
= SplatElementsAttr::get(vectorType
, oneAttr
);
888 rhs
= rewriter
.create
<spirv::ConstantOp
>(loc
, vectorType
, oneAttr
);
893 Value res
= rewriter
.create
<spirv::DotOp
>(loc
, resultType
, lhs
, rhs
);
895 res
= rewriter
.create
<spirv::FAddOp
>(loc
, acc
, res
);
897 rewriter
.replaceOp(op
, res
);
902 struct VectorStepOpConvert final
: OpConversionPattern
<vector::StepOp
> {
903 using OpConversionPattern::OpConversionPattern
;
906 matchAndRewrite(vector::StepOp stepOp
, OpAdaptor adaptor
,
907 ConversionPatternRewriter
&rewriter
) const override
{
908 const auto &typeConverter
= *getTypeConverter
<SPIRVTypeConverter
>();
909 Type dstType
= typeConverter
.convertType(stepOp
.getType());
913 Location loc
= stepOp
.getLoc();
914 int64_t numElements
= stepOp
.getType().getNumElements();
916 rewriter
.getIntegerType(typeConverter
.getIndexTypeBitwidth());
918 // Input vectors of size 1 are converted to scalars by the type converter.
919 // We just create a constant in this case.
920 if (numElements
== 1) {
921 Value zero
= spirv::ConstantOp::getZero(intType
, loc
, rewriter
);
922 rewriter
.replaceOp(stepOp
, zero
);
926 SmallVector
<Value
> source
;
927 source
.reserve(numElements
);
928 for (int64_t i
= 0; i
< numElements
; ++i
) {
929 Attribute intAttr
= rewriter
.getIntegerAttr(intType
, i
);
930 Value constOp
= rewriter
.create
<spirv::ConstantOp
>(loc
, intType
, intAttr
);
931 source
.push_back(constOp
);
933 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(stepOp
, dstType
,
940 #define CL_INT_MAX_MIN_OPS \
941 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
943 #define GL_INT_MAX_MIN_OPS \
944 spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
946 #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
947 #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
949 void mlir::populateVectorToSPIRVPatterns(
950 const SPIRVTypeConverter
&typeConverter
, RewritePatternSet
&patterns
) {
952 VectorBitcastConvert
, VectorBroadcastConvert
,
953 VectorExtractElementOpConvert
, VectorExtractOpConvert
,
954 VectorExtractStridedSliceOpConvert
, VectorFmaOpConvert
<spirv::GLFmaOp
>,
955 VectorFmaOpConvert
<spirv::CLFmaOp
>, VectorInsertElementOpConvert
,
956 VectorInsertOpConvert
, VectorReductionPattern
<GL_INT_MAX_MIN_OPS
>,
957 VectorReductionPattern
<CL_INT_MAX_MIN_OPS
>,
958 VectorReductionFloatMinMax
<CL_FLOAT_MAX_MIN_OPS
>,
959 VectorReductionFloatMinMax
<GL_FLOAT_MAX_MIN_OPS
>, VectorShapeCast
,
960 VectorInsertStridedSliceOpConvert
, VectorShuffleOpConvert
,
961 VectorInterleaveOpConvert
, VectorDeinterleaveOpConvert
,
962 VectorSplatPattern
, VectorLoadOpConverter
, VectorStoreOpConverter
,
963 VectorStepOpConvert
>(typeConverter
, patterns
.getContext(),
966 // Make sure that the more specialized dot product pattern has higher benefit
967 // than the generic one that extracts all elements.
968 patterns
.add
<VectorReductionToFPDotProd
>(typeConverter
, patterns
.getContext(),
972 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
973 RewritePatternSet
&patterns
) {
974 patterns
.add
<VectorReductionToIntDotProd
>(patterns
.getContext());