[AArch64][SME2] Add multi-vector saturating doubling multiply high intrinsics
[llvm-project.git] / mlir / lib / Conversion / VectorToSPIRV / VectorToSPIRV.cpp
blob0dbf67e0b69f84cfeb2d15f10d44886fcb187863
1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <numeric>
29 using namespace mlir;
31 /// Gets the first integer value from `attr`, assuming it is an integer array
32 /// attribute.
33 static uint64_t getFirstIntValue(ArrayAttr attr) {
34 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
37 /// Returns the number of bits for the given scalar/vector type.
38 static int getNumBits(Type type) {
39 if (auto vectorType = type.dyn_cast<VectorType>())
40 return vectorType.cast<ShapedType>().getSizeInBits();
41 return type.getIntOrFloatBitWidth();
44 namespace {
46 struct VectorBitcastConvert final
47 : public OpConversionPattern<vector::BitCastOp> {
48 using OpConversionPattern::OpConversionPattern;
50 LogicalResult
51 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter) const override {
53 Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
54 if (!dstType)
55 return failure();
57 if (dstType == adaptor.getSource().getType()) {
58 rewriter.replaceOp(bitcastOp, adaptor.getSource());
59 return success();
62 // Check that the source and destination type have the same bitwidth.
63 // Depending on the target environment, we may need to emulate certain
64 // types, which can cause issue with bitcast.
65 Type srcType = adaptor.getSource().getType();
66 if (getNumBits(dstType) != getNumBits(srcType)) {
67 return rewriter.notifyMatchFailure(
68 bitcastOp,
69 llvm::formatv("different source ({0}) and target ({1}) bitwidth",
70 srcType, dstType));
73 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
74 adaptor.getSource());
75 return success();
79 struct VectorBroadcastConvert final
80 : public OpConversionPattern<vector::BroadcastOp> {
81 using OpConversionPattern::OpConversionPattern;
83 LogicalResult
84 matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
87 if (!resultType)
88 return failure();
90 if (resultType.isa<spirv::ScalarType>()) {
91 rewriter.replaceOp(castOp, adaptor.getSource());
92 return success();
95 SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
96 adaptor.getSource());
97 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
98 castOp, castOp.getVectorType(), source);
99 return success();
103 struct VectorExtractOpConvert final
104 : public OpConversionPattern<vector::ExtractOp> {
105 using OpConversionPattern::OpConversionPattern;
107 LogicalResult
108 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
109 ConversionPatternRewriter &rewriter) const override {
110 // Only support extracting a scalar value now.
111 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
112 if (resultVectorType && resultVectorType.getNumElements() > 1)
113 return failure();
115 Type dstType = getTypeConverter()->convertType(extractOp.getType());
116 if (!dstType)
117 return failure();
119 if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
120 rewriter.replaceOp(extractOp, adaptor.getVector());
121 return success();
124 int32_t id = getFirstIntValue(extractOp.getPosition());
125 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
126 extractOp, adaptor.getVector(), id);
127 return success();
131 struct VectorExtractStridedSliceOpConvert final
132 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
133 using OpConversionPattern::OpConversionPattern;
135 LogicalResult
136 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
137 ConversionPatternRewriter &rewriter) const override {
138 Type dstType = getTypeConverter()->convertType(extractOp.getType());
139 if (!dstType)
140 return failure();
142 uint64_t offset = getFirstIntValue(extractOp.getOffsets());
143 uint64_t size = getFirstIntValue(extractOp.getSizes());
144 uint64_t stride = getFirstIntValue(extractOp.getStrides());
145 if (stride != 1)
146 return failure();
148 Value srcVector = adaptor.getOperands().front();
150 // Extract vector<1xT> case.
151 if (dstType.isa<spirv::ScalarType>()) {
152 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
153 srcVector, offset);
154 return success();
157 SmallVector<int32_t, 2> indices(size);
158 std::iota(indices.begin(), indices.end(), offset);
160 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
161 extractOp, dstType, srcVector, srcVector,
162 rewriter.getI32ArrayAttr(indices));
164 return success();
168 template <class SPIRVFMAOp>
169 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
170 using OpConversionPattern::OpConversionPattern;
172 LogicalResult
173 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
174 ConversionPatternRewriter &rewriter) const override {
175 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
176 if (!dstType)
177 return failure();
178 rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
179 adaptor.getRhs(), adaptor.getAcc());
180 return success();
184 struct VectorInsertOpConvert final
185 : public OpConversionPattern<vector::InsertOp> {
186 using OpConversionPattern::OpConversionPattern;
188 LogicalResult
189 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
190 ConversionPatternRewriter &rewriter) const override {
191 // Special case for inserting scalar values into size-1 vectors.
192 if (insertOp.getSourceType().isIntOrFloat() &&
193 insertOp.getDestVectorType().getNumElements() == 1) {
194 rewriter.replaceOp(insertOp, adaptor.getSource());
195 return success();
198 if (insertOp.getSourceType().isa<VectorType>() ||
199 !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
200 return failure();
201 int32_t id = getFirstIntValue(insertOp.getPosition());
202 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
203 insertOp, adaptor.getSource(), adaptor.getDest(), id);
204 return success();
208 struct VectorExtractElementOpConvert final
209 : public OpConversionPattern<vector::ExtractElementOp> {
210 using OpConversionPattern::OpConversionPattern;
212 LogicalResult
213 matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter) const override {
215 Type resultType = getTypeConverter()->convertType(extractOp.getType());
216 if (!resultType)
217 return failure();
219 if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
220 rewriter.replaceOp(extractOp, adaptor.getVector());
221 return success();
224 APInt cstPos;
225 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
226 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
227 extractOp, resultType, adaptor.getVector(),
228 rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
229 else
230 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
231 extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
232 return success();
236 struct VectorInsertElementOpConvert final
237 : public OpConversionPattern<vector::InsertElementOp> {
238 using OpConversionPattern::OpConversionPattern;
240 LogicalResult
241 matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
242 ConversionPatternRewriter &rewriter) const override {
243 Type vectorType = getTypeConverter()->convertType(insertOp.getType());
244 if (!vectorType)
245 return failure();
247 if (vectorType.isa<spirv::ScalarType>()) {
248 rewriter.replaceOp(insertOp, adaptor.getSource());
249 return success();
252 APInt cstPos;
253 if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
254 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
255 insertOp, adaptor.getSource(), adaptor.getDest(),
256 cstPos.getSExtValue());
257 else
258 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
259 insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
260 adaptor.getPosition());
261 return success();
265 struct VectorInsertStridedSliceOpConvert final
266 : public OpConversionPattern<vector::InsertStridedSliceOp> {
267 using OpConversionPattern::OpConversionPattern;
269 LogicalResult
270 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter) const override {
272 Value srcVector = adaptor.getOperands().front();
273 Value dstVector = adaptor.getOperands().back();
275 uint64_t stride = getFirstIntValue(insertOp.getStrides());
276 if (stride != 1)
277 return failure();
278 uint64_t offset = getFirstIntValue(insertOp.getOffsets());
280 if (srcVector.getType().isa<spirv::ScalarType>()) {
281 assert(!dstVector.getType().isa<spirv::ScalarType>());
282 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
283 insertOp, dstVector.getType(), srcVector, dstVector,
284 rewriter.getI32ArrayAttr(offset));
285 return success();
288 uint64_t totalSize =
289 dstVector.getType().cast<VectorType>().getNumElements();
290 uint64_t insertSize =
291 srcVector.getType().cast<VectorType>().getNumElements();
293 SmallVector<int32_t, 2> indices(totalSize);
294 std::iota(indices.begin(), indices.end(), 0);
295 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
296 totalSize);
298 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
299 insertOp, dstVector.getType(), dstVector, srcVector,
300 rewriter.getI32ArrayAttr(indices));
302 return success();
306 template <class SPIRVFMaxOp, class SPIRVFMinOp, class SPIRVUMaxOp,
307 class SPIRVUMinOp, class SPIRVSMaxOp, class SPIRVSMinOp>
308 struct VectorReductionPattern final
309 : public OpConversionPattern<vector::ReductionOp> {
310 using OpConversionPattern::OpConversionPattern;
312 LogicalResult
313 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter) const override {
315 Type resultType = typeConverter->convertType(reduceOp.getType());
316 if (!resultType)
317 return failure();
319 auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
320 if (!srcVectorType || srcVectorType.getRank() != 1)
321 return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
323 // Extract all elements.
324 int numElements = srcVectorType.getDimSize(0);
325 SmallVector<Value, 4> values;
326 values.reserve(numElements + (adaptor.getAcc() != nullptr));
327 Location loc = reduceOp.getLoc();
328 for (int i = 0; i < numElements; ++i) {
329 values.push_back(rewriter.create<spirv::CompositeExtractOp>(
330 loc, srcVectorType.getElementType(), adaptor.getVector(),
331 rewriter.getI32ArrayAttr({i})));
333 if (Value acc = adaptor.getAcc())
334 values.push_back(acc);
336 // Reduce them.
337 Value result = values.front();
338 for (Value next : llvm::ArrayRef(values).drop_front()) {
339 switch (reduceOp.getKind()) {
341 #define INT_AND_FLOAT_CASE(kind, iop, fop) \
342 case vector::CombiningKind::kind: \
343 if (resultType.isa<IntegerType>()) { \
344 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
345 } else { \
346 assert(resultType.isa<FloatType>()); \
347 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
349 break
351 #define INT_OR_FLOAT_CASE(kind, fop) \
352 case vector::CombiningKind::kind: \
353 result = rewriter.create<fop>(loc, resultType, result, next); \
354 break
356 INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
357 INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
359 INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp);
360 INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp);
361 INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
362 INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
363 INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
364 INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);
366 case vector::CombiningKind::AND:
367 case vector::CombiningKind::OR:
368 case vector::CombiningKind::XOR:
369 return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
373 rewriter.replaceOp(reduceOp, result);
374 return success();
378 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
379 public:
380 using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
382 LogicalResult
383 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
384 ConversionPatternRewriter &rewriter) const override {
385 Type dstType = getTypeConverter()->convertType(op.getType());
386 if (!dstType)
387 return failure();
388 if (dstType.isa<spirv::ScalarType>()) {
389 rewriter.replaceOp(op, adaptor.getInput());
390 } else {
391 auto dstVecType = dstType.cast<VectorType>();
392 SmallVector<Value, 4> source(dstVecType.getNumElements(),
393 adaptor.getInput());
394 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
395 source);
397 return success();
401 struct VectorShuffleOpConvert final
402 : public OpConversionPattern<vector::ShuffleOp> {
403 using OpConversionPattern::OpConversionPattern;
405 LogicalResult
406 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter) const override {
408 auto oldResultType = shuffleOp.getVectorType();
409 if (!spirv::CompositeType::isValid(oldResultType))
410 return failure();
411 Type newResultType = getTypeConverter()->convertType(oldResultType);
413 auto oldSourceType = shuffleOp.getV1VectorType();
414 if (oldSourceType.getNumElements() > 1) {
415 SmallVector<int32_t, 4> components = llvm::to_vector<4>(
416 llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
417 return attr.cast<IntegerAttr>().getValue().getZExtValue();
418 }));
419 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
420 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
421 rewriter.getI32ArrayAttr(components));
422 return success();
425 SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
426 SmallVector<Value, 4> newOperands;
427 newOperands.reserve(oldResultType.getNumElements());
428 for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
429 newOperands.push_back(oldOperands[i.getZExtValue()]);
431 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
432 shuffleOp, newResultType, newOperands);
434 return success();
438 } // namespace
// Template-argument lists for VectorReductionPattern, in its parameter order:
// FMax, FMin, UMax, UMin, SMax, SMin.

// spirv::CL* flavor of the max/min ops.
#define CL_MAX_MIN_OPS                                                         \
  spirv::CLFMaxOp, spirv::CLFMinOp,                                            \
  spirv::CLUMaxOp, spirv::CLUMinOp,                                            \
  spirv::CLSMaxOp, spirv::CLSMinOp

// spirv::GL* flavor of the max/min ops.
#define GL_MAX_MIN_OPS                                                         \
  spirv::GLFMaxOp, spirv::GLFMinOp,                                            \
  spirv::GLUMaxOp, spirv::GLUMinOp,                                            \
  spirv::GLSMaxOp, spirv::GLSMinOp
447 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
448 RewritePatternSet &patterns) {
449 patterns.add<
450 VectorBitcastConvert, VectorBroadcastConvert,
451 VectorExtractElementOpConvert, VectorExtractOpConvert,
452 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
453 VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
454 VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
455 VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
456 VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
457 patterns.getContext());