//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Vector dialect to the SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>

using namespace mlir;
31 /// Gets the first integer value from `attr`, assuming it is an integer array
33 static uint64_t getFirstIntValue(ArrayAttr attr
) {
34 return (*attr
.getAsValueRange
<IntegerAttr
>().begin()).getZExtValue();
37 /// Returns the number of bits for the given scalar/vector type.
38 static int getNumBits(Type type
) {
39 if (auto vectorType
= type
.dyn_cast
<VectorType
>())
40 return vectorType
.cast
<ShapedType
>().getSizeInBits();
41 return type
.getIntOrFloatBitWidth();
46 struct VectorBitcastConvert final
47 : public OpConversionPattern
<vector::BitCastOp
> {
48 using OpConversionPattern::OpConversionPattern
;
51 matchAndRewrite(vector::BitCastOp bitcastOp
, OpAdaptor adaptor
,
52 ConversionPatternRewriter
&rewriter
) const override
{
53 Type dstType
= getTypeConverter()->convertType(bitcastOp
.getType());
57 if (dstType
== adaptor
.getSource().getType()) {
58 rewriter
.replaceOp(bitcastOp
, adaptor
.getSource());
62 // Check that the source and destination type have the same bitwidth.
63 // Depending on the target environment, we may need to emulate certain
64 // types, which can cause issue with bitcast.
65 Type srcType
= adaptor
.getSource().getType();
66 if (getNumBits(dstType
) != getNumBits(srcType
)) {
67 return rewriter
.notifyMatchFailure(
69 llvm::formatv("different source ({0}) and target ({1}) bitwidth",
73 rewriter
.replaceOpWithNewOp
<spirv::BitcastOp
>(bitcastOp
, dstType
,
79 struct VectorBroadcastConvert final
80 : public OpConversionPattern
<vector::BroadcastOp
> {
81 using OpConversionPattern::OpConversionPattern
;
84 matchAndRewrite(vector::BroadcastOp castOp
, OpAdaptor adaptor
,
85 ConversionPatternRewriter
&rewriter
) const override
{
86 Type resultType
= getTypeConverter()->convertType(castOp
.getVectorType());
90 if (resultType
.isa
<spirv::ScalarType
>()) {
91 rewriter
.replaceOp(castOp
, adaptor
.getSource());
95 SmallVector
<Value
, 4> source(castOp
.getVectorType().getNumElements(),
97 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(
98 castOp
, castOp
.getVectorType(), source
);
103 struct VectorExtractOpConvert final
104 : public OpConversionPattern
<vector::ExtractOp
> {
105 using OpConversionPattern::OpConversionPattern
;
108 matchAndRewrite(vector::ExtractOp extractOp
, OpAdaptor adaptor
,
109 ConversionPatternRewriter
&rewriter
) const override
{
110 // Only support extracting a scalar value now.
111 VectorType resultVectorType
= extractOp
.getType().dyn_cast
<VectorType
>();
112 if (resultVectorType
&& resultVectorType
.getNumElements() > 1)
115 Type dstType
= getTypeConverter()->convertType(extractOp
.getType());
119 if (adaptor
.getVector().getType().isa
<spirv::ScalarType
>()) {
120 rewriter
.replaceOp(extractOp
, adaptor
.getVector());
124 int32_t id
= getFirstIntValue(extractOp
.getPosition());
125 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(
126 extractOp
, adaptor
.getVector(), id
);
131 struct VectorExtractStridedSliceOpConvert final
132 : public OpConversionPattern
<vector::ExtractStridedSliceOp
> {
133 using OpConversionPattern::OpConversionPattern
;
136 matchAndRewrite(vector::ExtractStridedSliceOp extractOp
, OpAdaptor adaptor
,
137 ConversionPatternRewriter
&rewriter
) const override
{
138 Type dstType
= getTypeConverter()->convertType(extractOp
.getType());
142 uint64_t offset
= getFirstIntValue(extractOp
.getOffsets());
143 uint64_t size
= getFirstIntValue(extractOp
.getSizes());
144 uint64_t stride
= getFirstIntValue(extractOp
.getStrides());
148 Value srcVector
= adaptor
.getOperands().front();
150 // Extract vector<1xT> case.
151 if (dstType
.isa
<spirv::ScalarType
>()) {
152 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(extractOp
,
157 SmallVector
<int32_t, 2> indices(size
);
158 std::iota(indices
.begin(), indices
.end(), offset
);
160 rewriter
.replaceOpWithNewOp
<spirv::VectorShuffleOp
>(
161 extractOp
, dstType
, srcVector
, srcVector
,
162 rewriter
.getI32ArrayAttr(indices
));
168 template <class SPIRVFMAOp
>
169 struct VectorFmaOpConvert final
: public OpConversionPattern
<vector::FMAOp
> {
170 using OpConversionPattern::OpConversionPattern
;
173 matchAndRewrite(vector::FMAOp fmaOp
, OpAdaptor adaptor
,
174 ConversionPatternRewriter
&rewriter
) const override
{
175 Type dstType
= getTypeConverter()->convertType(fmaOp
.getType());
178 rewriter
.replaceOpWithNewOp
<SPIRVFMAOp
>(fmaOp
, dstType
, adaptor
.getLhs(),
179 adaptor
.getRhs(), adaptor
.getAcc());
184 struct VectorInsertOpConvert final
185 : public OpConversionPattern
<vector::InsertOp
> {
186 using OpConversionPattern::OpConversionPattern
;
189 matchAndRewrite(vector::InsertOp insertOp
, OpAdaptor adaptor
,
190 ConversionPatternRewriter
&rewriter
) const override
{
191 // Special case for inserting scalar values into size-1 vectors.
192 if (insertOp
.getSourceType().isIntOrFloat() &&
193 insertOp
.getDestVectorType().getNumElements() == 1) {
194 rewriter
.replaceOp(insertOp
, adaptor
.getSource());
198 if (insertOp
.getSourceType().isa
<VectorType
>() ||
199 !spirv::CompositeType::isValid(insertOp
.getDestVectorType()))
201 int32_t id
= getFirstIntValue(insertOp
.getPosition());
202 rewriter
.replaceOpWithNewOp
<spirv::CompositeInsertOp
>(
203 insertOp
, adaptor
.getSource(), adaptor
.getDest(), id
);
208 struct VectorExtractElementOpConvert final
209 : public OpConversionPattern
<vector::ExtractElementOp
> {
210 using OpConversionPattern::OpConversionPattern
;
213 matchAndRewrite(vector::ExtractElementOp extractOp
, OpAdaptor adaptor
,
214 ConversionPatternRewriter
&rewriter
) const override
{
215 Type resultType
= getTypeConverter()->convertType(extractOp
.getType());
219 if (adaptor
.getVector().getType().isa
<spirv::ScalarType
>()) {
220 rewriter
.replaceOp(extractOp
, adaptor
.getVector());
225 if (matchPattern(adaptor
.getPosition(), m_ConstantInt(&cstPos
)))
226 rewriter
.replaceOpWithNewOp
<spirv::CompositeExtractOp
>(
227 extractOp
, resultType
, adaptor
.getVector(),
228 rewriter
.getI32ArrayAttr({static_cast<int>(cstPos
.getSExtValue())}));
230 rewriter
.replaceOpWithNewOp
<spirv::VectorExtractDynamicOp
>(
231 extractOp
, resultType
, adaptor
.getVector(), adaptor
.getPosition());
236 struct VectorInsertElementOpConvert final
237 : public OpConversionPattern
<vector::InsertElementOp
> {
238 using OpConversionPattern::OpConversionPattern
;
241 matchAndRewrite(vector::InsertElementOp insertOp
, OpAdaptor adaptor
,
242 ConversionPatternRewriter
&rewriter
) const override
{
243 Type vectorType
= getTypeConverter()->convertType(insertOp
.getType());
247 if (vectorType
.isa
<spirv::ScalarType
>()) {
248 rewriter
.replaceOp(insertOp
, adaptor
.getSource());
253 if (matchPattern(adaptor
.getPosition(), m_ConstantInt(&cstPos
)))
254 rewriter
.replaceOpWithNewOp
<spirv::CompositeInsertOp
>(
255 insertOp
, adaptor
.getSource(), adaptor
.getDest(),
256 cstPos
.getSExtValue());
258 rewriter
.replaceOpWithNewOp
<spirv::VectorInsertDynamicOp
>(
259 insertOp
, vectorType
, insertOp
.getDest(), adaptor
.getSource(),
260 adaptor
.getPosition());
265 struct VectorInsertStridedSliceOpConvert final
266 : public OpConversionPattern
<vector::InsertStridedSliceOp
> {
267 using OpConversionPattern::OpConversionPattern
;
270 matchAndRewrite(vector::InsertStridedSliceOp insertOp
, OpAdaptor adaptor
,
271 ConversionPatternRewriter
&rewriter
) const override
{
272 Value srcVector
= adaptor
.getOperands().front();
273 Value dstVector
= adaptor
.getOperands().back();
275 uint64_t stride
= getFirstIntValue(insertOp
.getStrides());
278 uint64_t offset
= getFirstIntValue(insertOp
.getOffsets());
280 if (srcVector
.getType().isa
<spirv::ScalarType
>()) {
281 assert(!dstVector
.getType().isa
<spirv::ScalarType
>());
282 rewriter
.replaceOpWithNewOp
<spirv::CompositeInsertOp
>(
283 insertOp
, dstVector
.getType(), srcVector
, dstVector
,
284 rewriter
.getI32ArrayAttr(offset
));
289 dstVector
.getType().cast
<VectorType
>().getNumElements();
290 uint64_t insertSize
=
291 srcVector
.getType().cast
<VectorType
>().getNumElements();
293 SmallVector
<int32_t, 2> indices(totalSize
);
294 std::iota(indices
.begin(), indices
.end(), 0);
295 std::iota(indices
.begin() + offset
, indices
.begin() + offset
+ insertSize
,
298 rewriter
.replaceOpWithNewOp
<spirv::VectorShuffleOp
>(
299 insertOp
, dstVector
.getType(), dstVector
, srcVector
,
300 rewriter
.getI32ArrayAttr(indices
));
306 template <class SPIRVFMaxOp
, class SPIRVFMinOp
, class SPIRVUMaxOp
,
307 class SPIRVUMinOp
, class SPIRVSMaxOp
, class SPIRVSMinOp
>
308 struct VectorReductionPattern final
309 : public OpConversionPattern
<vector::ReductionOp
> {
310 using OpConversionPattern::OpConversionPattern
;
313 matchAndRewrite(vector::ReductionOp reduceOp
, OpAdaptor adaptor
,
314 ConversionPatternRewriter
&rewriter
) const override
{
315 Type resultType
= typeConverter
->convertType(reduceOp
.getType());
319 auto srcVectorType
= adaptor
.getVector().getType().dyn_cast
<VectorType
>();
320 if (!srcVectorType
|| srcVectorType
.getRank() != 1)
321 return rewriter
.notifyMatchFailure(reduceOp
, "not 1-D vector source");
323 // Extract all elements.
324 int numElements
= srcVectorType
.getDimSize(0);
325 SmallVector
<Value
, 4> values
;
326 values
.reserve(numElements
+ (adaptor
.getAcc() != nullptr));
327 Location loc
= reduceOp
.getLoc();
328 for (int i
= 0; i
< numElements
; ++i
) {
329 values
.push_back(rewriter
.create
<spirv::CompositeExtractOp
>(
330 loc
, srcVectorType
.getElementType(), adaptor
.getVector(),
331 rewriter
.getI32ArrayAttr({i
})));
333 if (Value acc
= adaptor
.getAcc())
334 values
.push_back(acc
);
337 Value result
= values
.front();
338 for (Value next
: llvm::ArrayRef(values
).drop_front()) {
339 switch (reduceOp
.getKind()) {
341 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
342 case vector::CombiningKind::kind: \
343 if (resultType.isa<IntegerType>()) { \
344 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
346 assert(resultType.isa<FloatType>()); \
347 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
351 #define INT_OR_FLOAT_CASE(kind, fop) \
352 case vector::CombiningKind::kind: \
353 result = rewriter.create<fop>(loc, resultType, result, next); \
356 INT_AND_FLOAT_CASE(ADD
, IAddOp
, FAddOp
);
357 INT_AND_FLOAT_CASE(MUL
, IMulOp
, FMulOp
);
359 INT_OR_FLOAT_CASE(MAXF
, SPIRVFMaxOp
);
360 INT_OR_FLOAT_CASE(MINF
, SPIRVFMinOp
);
361 INT_OR_FLOAT_CASE(MINUI
, SPIRVUMinOp
);
362 INT_OR_FLOAT_CASE(MINSI
, SPIRVSMinOp
);
363 INT_OR_FLOAT_CASE(MAXUI
, SPIRVUMaxOp
);
364 INT_OR_FLOAT_CASE(MAXSI
, SPIRVSMaxOp
);
366 case vector::CombiningKind::AND
:
367 case vector::CombiningKind::OR
:
368 case vector::CombiningKind::XOR
:
369 return rewriter
.notifyMatchFailure(reduceOp
, "unimplemented");
373 rewriter
.replaceOp(reduceOp
, result
);
378 class VectorSplatPattern final
: public OpConversionPattern
<vector::SplatOp
> {
380 using OpConversionPattern
<vector::SplatOp
>::OpConversionPattern
;
383 matchAndRewrite(vector::SplatOp op
, OpAdaptor adaptor
,
384 ConversionPatternRewriter
&rewriter
) const override
{
385 Type dstType
= getTypeConverter()->convertType(op
.getType());
388 if (dstType
.isa
<spirv::ScalarType
>()) {
389 rewriter
.replaceOp(op
, adaptor
.getInput());
391 auto dstVecType
= dstType
.cast
<VectorType
>();
392 SmallVector
<Value
, 4> source(dstVecType
.getNumElements(),
394 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(op
, dstType
,
401 struct VectorShuffleOpConvert final
402 : public OpConversionPattern
<vector::ShuffleOp
> {
403 using OpConversionPattern::OpConversionPattern
;
406 matchAndRewrite(vector::ShuffleOp shuffleOp
, OpAdaptor adaptor
,
407 ConversionPatternRewriter
&rewriter
) const override
{
408 auto oldResultType
= shuffleOp
.getVectorType();
409 if (!spirv::CompositeType::isValid(oldResultType
))
411 Type newResultType
= getTypeConverter()->convertType(oldResultType
);
413 auto oldSourceType
= shuffleOp
.getV1VectorType();
414 if (oldSourceType
.getNumElements() > 1) {
415 SmallVector
<int32_t, 4> components
= llvm::to_vector
<4>(
416 llvm::map_range(shuffleOp
.getMask(), [](Attribute attr
) -> int32_t {
417 return attr
.cast
<IntegerAttr
>().getValue().getZExtValue();
419 rewriter
.replaceOpWithNewOp
<spirv::VectorShuffleOp
>(
420 shuffleOp
, newResultType
, adaptor
.getV1(), adaptor
.getV2(),
421 rewriter
.getI32ArrayAttr(components
));
425 SmallVector
<Value
, 2> oldOperands
= {adaptor
.getV1(), adaptor
.getV2()};
426 SmallVector
<Value
, 4> newOperands
;
427 newOperands
.reserve(oldResultType
.getNumElements());
428 for (const APInt
&i
: shuffleOp
.getMask().getAsValueRange
<IntegerAttr
>()) {
429 newOperands
.push_back(oldOperands
[i
.getZExtValue()]);
431 rewriter
.replaceOpWithNewOp
<spirv::CompositeConstructOp
>(
432 shuffleOp
, newResultType
, newOperands
);
439 #define CL_MAX_MIN_OPS \
440 spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
441 spirv::CLSMaxOp, spirv::CLSMinOp
443 #define GL_MAX_MIN_OPS \
444 spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
445 spirv::GLSMaxOp, spirv::GLSMinOp
447 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter
&typeConverter
,
448 RewritePatternSet
&patterns
) {
450 VectorBitcastConvert
, VectorBroadcastConvert
,
451 VectorExtractElementOpConvert
, VectorExtractOpConvert
,
452 VectorExtractStridedSliceOpConvert
, VectorFmaOpConvert
<spirv::GLFmaOp
>,
453 VectorFmaOpConvert
<spirv::CLFmaOp
>, VectorInsertElementOpConvert
,
454 VectorInsertOpConvert
, VectorReductionPattern
<GL_MAX_MIN_OPS
>,
455 VectorReductionPattern
<CL_MAX_MIN_OPS
>, VectorInsertStridedSliceOpConvert
,
456 VectorShuffleOpConvert
, VectorSplatPattern
>(typeConverter
,
457 patterns
.getContext());