Rename CODE_OWNERS -> Maintainers (#114544)
[llvm-project.git] / mlir / lib / Conversion / VectorToLLVM / ConvertVectorToLLVM.cpp
blob58ca84c8d7bca62f0d75cbb53bdf5734facce9d6
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24 #include "mlir/IR/BuiltinAttributes.h"
25 #include "mlir/IR/BuiltinTypeInterfaces.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
32 #include <optional>
34 using namespace mlir;
35 using namespace mlir::vector;
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39 assert((tp.getRank() > 1) && "unlowerable vector type");
40 return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41 tp.getScalableDims().take_back());
44 // Helper that picks the proper sequence for inserting.
45 static Value insertOne(ConversionPatternRewriter &rewriter,
46 const LLVMTypeConverter &typeConverter, Location loc,
47 Value val1, Value val2, Type llvmType, int64_t rank,
48 int64_t pos) {
49 assert(rank > 0 && "0-D vector corner case should have been handled already");
50 if (rank == 1) {
51 auto idxType = rewriter.getIndexType();
52 auto constant = rewriter.create<LLVM::ConstantOp>(
53 loc, typeConverter.convertType(idxType),
54 rewriter.getIntegerAttr(idxType, pos));
55 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
56 constant);
58 return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
61 // Helper that picks the proper sequence for extracting.
62 static Value extractOne(ConversionPatternRewriter &rewriter,
63 const LLVMTypeConverter &typeConverter, Location loc,
64 Value val, Type llvmType, int64_t rank, int64_t pos) {
65 if (rank <= 1) {
66 auto idxType = rewriter.getIndexType();
67 auto constant = rewriter.create<LLVM::ConstantOp>(
68 loc, typeConverter.convertType(idxType),
69 rewriter.getIntegerAttr(idxType, pos));
70 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
71 constant);
73 return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
76 // Helper that returns data layout alignment of a memref.
77 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
78 MemRefType memrefType, unsigned &align) {
79 Type elementTy = typeConverter.convertType(memrefType.getElementType());
80 if (!elementTy)
81 return failure();
83 // TODO: this should use the MLIR data layout when it becomes available and
84 // stop depending on translation.
85 llvm::LLVMContext llvmContext;
86 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
87 .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
88 return success();
91 // Check if the last stride is non-unit and has a valid memory space.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
93 const LLVMTypeConverter &converter) {
94 if (!isLastMemrefDimUnitStride(memRefType))
95 return failure();
96 if (failed(converter.getMemRefAddressSpace(memRefType)))
97 return failure();
98 return success();
101 // Add an index vector component to a base pointer.
102 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
103 const LLVMTypeConverter &typeConverter,
104 MemRefType memRefType, Value llvmMemref, Value base,
105 Value index, VectorType vectorType) {
106 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107 "unsupported memref type");
108 assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
109 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
110 auto ptrsType =
111 LLVM::getVectorType(pType, vectorType.getDimSize(0),
112 /*isScalable=*/vectorType.getScalableDims()[0]);
113 return rewriter.create<LLVM::GEPOp>(
114 loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
115 base, index);
/// Convert `foldResult` into a Value. Integer attribute is converted to
/// an LLVM constant op.
// NOTE(review): `OpFoldResult::dyn_cast`/`get` are deprecated in newer MLIR in
// favor of `llvm::dyn_cast_if_present` / `llvm::cast` — confirm against this
// tree's MLIR version before modernizing.
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
                            OpFoldResult foldResult) {
  if (auto attr = foldResult.dyn_cast<Attribute>()) {
    // Static case: materialize the integer attribute as an llvm.mlir.constant.
    auto intAttr = cast<IntegerAttr>(attr);
    return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
  }

  // Dynamic case: the fold result already holds an SSA value.
  return foldResult.get<Value>();
}
namespace {

/// Trivial Vector to LLVM conversions:
/// vector.vscale maps one-to-one onto the LLVM::vscale intrinsic op.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
136 /// Conversion pattern for a vector.bitcast.
137 class VectorBitCastOpConversion
138 : public ConvertOpToLLVMPattern<vector::BitCastOp> {
139 public:
140 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
142 LogicalResult
143 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
144 ConversionPatternRewriter &rewriter) const override {
145 // Only 0-D and 1-D vectors can be lowered to LLVM.
146 VectorType resultTy = bitCastOp.getResultVectorType();
147 if (resultTy.getRank() > 1)
148 return failure();
149 Type newResultTy = typeConverter->convertType(resultTy);
150 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
151 adaptor.getOperands()[0]);
152 return success();
156 /// Conversion pattern for a vector.matrix_multiply.
157 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
158 class VectorMatmulOpConversion
159 : public ConvertOpToLLVMPattern<vector::MatmulOp> {
160 public:
161 using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
163 LogicalResult
164 matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter) const override {
166 rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
167 matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
168 adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
169 matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
170 return success();
174 /// Conversion pattern for a vector.flat_transpose.
175 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
176 class VectorFlatTransposeOpConversion
177 : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
178 public:
179 using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
181 LogicalResult
182 matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
185 transOp, typeConverter->convertType(transOp.getRes().getType()),
186 adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
187 return success();
191 /// Overloaded utility that replaces a vector.load, vector.store,
192 /// vector.maskedload and vector.maskedstore with their respective LLVM
193 /// couterparts.
194 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
195 vector::LoadOpAdaptor adaptor,
196 VectorType vectorTy, Value ptr, unsigned align,
197 ConversionPatternRewriter &rewriter) {
198 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
199 /*volatile_=*/false,
200 loadOp.getNontemporal());
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
204 vector::MaskedLoadOpAdaptor adaptor,
205 VectorType vectorTy, Value ptr, unsigned align,
206 ConversionPatternRewriter &rewriter) {
207 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
208 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
212 vector::StoreOpAdaptor adaptor,
213 VectorType vectorTy, Value ptr, unsigned align,
214 ConversionPatternRewriter &rewriter) {
215 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
216 ptr, align, /*volatile_=*/false,
217 storeOp.getNontemporal());
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
221 vector::MaskedStoreOpAdaptor adaptor,
222 VectorType vectorTy, Value ptr, unsigned align,
223 ConversionPatternRewriter &rewriter) {
224 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
225 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    // Dispatch to the replaceLoadOrStoreOp overload matching the concrete op.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                         rewriter);
    return success();
  }
};
/// Conversion pattern for a vector.gather.
/// 1-D gathers lower to a single llvm.intr.masked.gather; higher-rank gathers
/// are unrolled into one 1-D gather per innermost vector.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Base pointer already offset by the (scalar) leading indices.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs =
          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                         base, ptr, adaptor.getIndexVec(), vType);
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: emit one 1-D gather per innermost vector; the callback is
    // invoked by handleMultidimensionalVectors for each unrolled slice.
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};
/// Conversion pattern for a vector.scatter.
/// Lowers to llvm.intr.masked.scatter over a vector of per-lane pointers.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
360 /// Conversion pattern for a vector.expandload.
361 class VectorExpandLoadOpConversion
362 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
363 public:
364 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
366 LogicalResult
367 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter) const override {
369 auto loc = expand->getLoc();
370 MemRefType memRefType = expand.getMemRefType();
372 // Resolve address.
373 auto vtype = typeConverter->convertType(expand.getVectorType());
374 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
375 adaptor.getIndices(), rewriter);
377 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
378 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
379 return success();
383 /// Conversion pattern for a vector.compressstore.
384 class VectorCompressStoreOpConversion
385 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
386 public:
387 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
389 LogicalResult
390 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
391 ConversionPatternRewriter &rewriter) const override {
392 auto loc = compress->getLoc();
393 MemRefType memRefType = compress.getMemRefType();
395 // Resolve address.
396 Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397 adaptor.getIndices(), rewriter);
399 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
400 compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
401 return success();
/// Reduction neutral classes for overloading.
/// Empty tag types that select, via overload resolution, which
/// `createReductionNeutralValue` overload materializes the identity element
/// for a given reduction kind.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
417 /// Create the reduction neutral zero value.
418 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
419 ConversionPatternRewriter &rewriter,
420 Location loc, Type llvmType) {
421 return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
422 rewriter.getZeroAttr(llvmType));
425 /// Create the reduction neutral integer one value.
426 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
427 ConversionPatternRewriter &rewriter,
428 Location loc, Type llvmType) {
429 return rewriter.create<LLVM::ConstantOp>(
430 loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
433 /// Create the reduction neutral fp one value.
434 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
435 ConversionPatternRewriter &rewriter,
436 Location loc, Type llvmType) {
437 return rewriter.create<LLVM::ConstantOp>(
438 loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
441 /// Create the reduction neutral all-ones value.
442 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
443 ConversionPatternRewriter &rewriter,
444 Location loc, Type llvmType) {
445 return rewriter.create<LLVM::ConstantOp>(
446 loc, llvmType,
447 rewriter.getIntegerAttr(
448 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
451 /// Create the reduction neutral signed int minimum value.
452 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
453 ConversionPatternRewriter &rewriter,
454 Location loc, Type llvmType) {
455 return rewriter.create<LLVM::ConstantOp>(
456 loc, llvmType,
457 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
458 llvmType.getIntOrFloatBitWidth())));
461 /// Create the reduction neutral unsigned int minimum value.
462 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
463 ConversionPatternRewriter &rewriter,
464 Location loc, Type llvmType) {
465 return rewriter.create<LLVM::ConstantOp>(
466 loc, llvmType,
467 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
468 llvmType.getIntOrFloatBitWidth())));
471 /// Create the reduction neutral signed int maximum value.
472 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
473 ConversionPatternRewriter &rewriter,
474 Location loc, Type llvmType) {
475 return rewriter.create<LLVM::ConstantOp>(
476 loc, llvmType,
477 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
478 llvmType.getIntOrFloatBitWidth())));
481 /// Create the reduction neutral unsigned int maximum value.
482 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
483 ConversionPatternRewriter &rewriter,
484 Location loc, Type llvmType) {
485 return rewriter.create<LLVM::ConstantOp>(
486 loc, llvmType,
487 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
488 llvmType.getIntOrFloatBitWidth())));
491 /// Create the reduction neutral fp minimum value.
492 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
493 ConversionPatternRewriter &rewriter,
494 Location loc, Type llvmType) {
495 auto floatType = cast<FloatType>(llvmType);
496 return rewriter.create<LLVM::ConstantOp>(
497 loc, llvmType,
498 rewriter.getFloatAttr(
499 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
500 /*Negative=*/false)));
503 /// Create the reduction neutral fp maximum value.
504 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
505 ConversionPatternRewriter &rewriter,
506 Location loc, Type llvmType) {
507 auto floatType = cast<FloatType>(llvmType);
508 return rewriter.create<LLVM::ConstantOp>(
509 loc, llvmType,
510 rewriter.getFloatAttr(
511 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
512 /*Negative=*/true)));
515 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
516 /// returns a new accumulator value using `ReductionNeutral`.
517 template <class ReductionNeutral>
518 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
519 Location loc, Type llvmType,
520 Value accumulator) {
521 if (accumulator)
522 return accumulator;
524 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
525 llvmType);
528 /// Creates a value with the 1-D vector shape provided in `llvmType`.
529 /// This is used as effective vector length by some intrinsics supporting
530 /// dynamic vector lengths at runtime.
531 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
532 Location loc, Type llvmType) {
533 VectorType vType = cast<VectorType>(llvmType);
534 auto vShape = vType.getShape();
535 assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
537 Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
538 loc, rewriter.getI32Type(),
539 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
541 if (!vType.getScalableDims()[0])
542 return baseVecLength;
544 // For a scalable vector type, create and return `vScale * baseVecLength`.
545 Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
546 vScale =
547 rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
548 Value scalableVecLength =
549 rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
550 return scalableVecLength;
553 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
554 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
555 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
556 /// non-null.
557 template <class LLVMRedIntrinOp, class ScalarOp>
558 static Value createIntegerReductionArithmeticOpLowering(
559 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
560 Value vectorOperand, Value accumulator) {
562 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
564 if (accumulator)
565 result = rewriter.create<ScalarOp>(loc, accumulator, result);
566 return result;
569 /// Helper method to lower a `vector.reduction` operation that performs
570 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
571 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
572 /// the accumulator value if non-null.
573 template <class LLVMRedIntrinOp>
574 static Value createIntegerReductionComparisonOpLowering(
575 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
576 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
577 Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
578 if (accumulator) {
579 Value cmp =
580 rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581 result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
583 return result;
namespace {
/// Maps a vector-reduce intrinsic to the scalar LLVM op used to fold its
/// result into an accumulator value.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
607 template <class LLVMRedIntrinOp>
608 static Value createFPReductionComparisonOpLowering(
609 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
610 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
611 Value result =
612 rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
614 if (accumulator) {
615 result =
616 rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
617 loc, result, accumulator);
620 return result;
/// Reduction neutral classes for overloading.
/// Tag types selecting the `getMaskNeutralValue` overload used to fill
/// masked-off lanes before an unmasked fmaximum/fminimum reduction.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
627 /// Get the mask neutral floating point maximum value
628 static llvm::APFloat
629 getMaskNeutralValue(MaskNeutralFMaximum,
630 const llvm::fltSemantics &floatSemantics) {
631 return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
633 /// Get the mask neutral floating point minimum value
634 static llvm::APFloat
635 getMaskNeutralValue(MaskNeutralFMinimum,
636 const llvm::fltSemantics &floatSemantics) {
637 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
640 /// Create the mask neutral floating point MLIR vector constant
641 template <typename MaskNeutral>
642 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
643 Location loc, Type llvmType,
644 Type vectorType) {
645 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
646 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
647 auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
648 return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
651 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
652 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
653 /// `fmaximum`/`fminimum`.
654 /// More information: https://github.com/llvm/llvm-project/issues/64940
655 template <class LLVMRedIntrinOp, class MaskNeutral>
656 static Value
657 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
658 Location loc, Type llvmType,
659 Value vectorOperand, Value accumulator,
660 Value mask, LLVM::FastmathFlagsAttr fmf) {
661 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
662 rewriter, loc, llvmType, vectorOperand.getType());
663 const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
664 loc, mask, vectorOperand, vectorMaskNeutral);
665 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
666 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
669 template <class LLVMRedIntrinOp, class ReductionNeutral>
670 static Value
671 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
672 Type llvmType, Value vectorOperand,
673 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
674 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675 llvmType, accumulator);
676 return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
677 /*startValue=*/accumulator,
678 vectorOperand, fmf);
681 /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
682 /// that requires a start value. This start value format spans across fp
683 /// reductions without mask and all the masked reduction intrinsics.
684 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
685 static Value
686 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
687 Location loc, Type llvmType,
688 Value vectorOperand, Value accumulator) {
689 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
690 llvmType, accumulator);
691 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
692 /*startValue=*/accumulator,
693 vectorOperand);
696 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
697 static Value lowerPredicatedReductionWithStartValue(
698 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
699 Value vectorOperand, Value accumulator, Value mask) {
700 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701 llvmType, accumulator);
702 Value vectorLength =
703 createVectorLengthValue(rewriter, loc, vectorOperand.getType());
704 return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
705 /*startValue=*/accumulator,
706 vectorOperand, mask, vectorLength);
709 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
710 class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
711 static Value lowerPredicatedReductionWithStartValue(
712 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
713 Value vectorOperand, Value accumulator, Value mask) {
714 if (llvmType.isIntOrIndex())
715 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
716 IntReductionNeutral>(
717 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
719 // FP dispatch.
720 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
721 FPReductionNeutral>(
722 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
/// Conversion pattern for all vector reductions.
/// Integer kinds dispatch through a switch to arithmetic/comparison helpers;
/// fp kinds additionally carry fast-math flags (with `reassoc` optionally
/// forced by the pattern's `reassociateFPReductions` option).
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds (fp-only) are not valid for integer elements.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Translate the op's arith fast-math flags to the LLVM dialect attribute,
    // optionally OR-ing in `reassoc` when the pattern was configured to allow
    // reassociation of fp reductions.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                                  : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  // Pattern option: when true, emitted fp reductions carry `reassoc`.
  const bool reassociateFPReductions;
};
853 /// Base class to convert a `vector.mask` operation while matching traits
854 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
855 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
856 /// method performs a second match against the maskable operation `MaskedOp`.
857 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
858 /// implemented by the concrete conversion classes. This method can match
859 /// against specific traits of the `vector.mask` and the maskable operation. It
860 /// must replace the `vector.mask` operation.
861 template <class MaskedOp>
862 class VectorMaskOpConversionBase
863 : public ConvertOpToLLVMPattern<vector::MaskOp> {
864 public:
865 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
867 LogicalResult
868 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
869 ConversionPatternRewriter &rewriter) const final {
870 // Match against the maskable operation kind.
871 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
872 if (!maskedOp)
873 return failure();
874 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
877 protected:
878 virtual LogicalResult
879 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
880 vector::MaskableOpInterface maskableOp,
881 ConversionPatternRewriter &rewriter) const = 0;
/// Converts a `vector.reduction` nested inside a `vector.mask` into the
/// corresponding LLVM vector-predicated (VPReduce*) intrinsic, pairing each
/// combining kind with the neutral start value used when no accumulator is
/// provided. MAXIMUMF/MINIMUMF have no VP counterpart here and are instead
/// lowered to regular reductions after neutralizing masked-off lanes
/// (MaskNeutralFMaximum/MaskNeutralFMinimum).
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    // Scalar result type of the reduction; bail (below, implicitly via the
    // created ops) only through the unhandled-kind path.
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Translate arith fast-math flags to their LLVM dialect equivalent.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      // Integer and FP variants are both listed; the helper selects based on
      // the element type.
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case CombiningKind::MAXIMUMF:
      // No VP intrinsic: neutralize masked lanes, then reduce normally.
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};
984 class VectorShuffleOpConversion
985 : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
986 public:
987 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
989 LogicalResult
990 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
991 ConversionPatternRewriter &rewriter) const override {
992 auto loc = shuffleOp->getLoc();
993 auto v1Type = shuffleOp.getV1VectorType();
994 auto v2Type = shuffleOp.getV2VectorType();
995 auto vectorType = shuffleOp.getResultVectorType();
996 Type llvmType = typeConverter->convertType(vectorType);
997 ArrayRef<int64_t> mask = shuffleOp.getMask();
999 // Bail if result type cannot be lowered.
1000 if (!llvmType)
1001 return failure();
1003 // Get rank and dimension sizes.
1004 int64_t rank = vectorType.getRank();
1005 #ifndef NDEBUG
1006 bool wellFormed0DCase =
1007 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1008 bool wellFormedNDCase =
1009 v1Type.getRank() == rank && v2Type.getRank() == rank;
1010 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1011 #endif
1013 // For rank 0 and 1, where both operands have *exactly* the same vector
1014 // type, there is direct shuffle support in LLVM. Use it!
1015 if (rank <= 1 && v1Type == v2Type) {
1016 Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1017 loc, adaptor.getV1(), adaptor.getV2(),
1018 llvm::to_vector_of<int32_t>(mask));
1019 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1020 return success();
1023 // For all other cases, insert the individual values individually.
1024 int64_t v1Dim = v1Type.getDimSize(0);
1025 Type eltType;
1026 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1027 eltType = arrayType.getElementType();
1028 else
1029 eltType = cast<VectorType>(llvmType).getElementType();
1030 Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1031 int64_t insPos = 0;
1032 for (int64_t extPos : mask) {
1033 Value value = adaptor.getV1();
1034 if (extPos >= v1Dim) {
1035 extPos -= v1Dim;
1036 value = adaptor.getV2();
1038 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1039 eltType, rank, extPos);
1040 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1041 llvmType, rank, insPos++);
1043 rewriter.replaceOp(shuffleOp, insert);
1044 return success();
1048 class VectorExtractElementOpConversion
1049 : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1050 public:
1051 using ConvertOpToLLVMPattern<
1052 vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1054 LogicalResult
1055 matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1056 ConversionPatternRewriter &rewriter) const override {
1057 auto vectorType = extractEltOp.getSourceVectorType();
1058 auto llvmType = typeConverter->convertType(vectorType.getElementType());
1060 // Bail if result type cannot be lowered.
1061 if (!llvmType)
1062 return failure();
1064 if (vectorType.getRank() == 0) {
1065 Location loc = extractEltOp.getLoc();
1066 auto idxType = rewriter.getIndexType();
1067 auto zero = rewriter.create<LLVM::ConstantOp>(
1068 loc, typeConverter->convertType(idxType),
1069 rewriter.getIntegerAttr(idxType, 0));
1070 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1071 extractEltOp, llvmType, adaptor.getVector(), zero);
1072 return success();
1075 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1076 extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1077 return success();
/// Lowers `vector.extract` to LLVM extractvalue (for vectors nested in the
/// array part of an N-D lowering) and/or extractelement (for the innermost
/// 1-D vector). Dynamic positions are only supported for the innermost
/// dimension, since LLVM extractvalue requires constant indices.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Merge static and dynamic position operands into one list.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // Extract entire vector. Should be handled by folder, but just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    // Except for extracting 1-element vectors.
    if (isa<VectorType>(resultType) &&
        position.size() !=
            static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
      // extractvalue only takes constant indices.
      if (extractOp.hasDynamicPosition())
        return failure();

      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), getAsIntegers(position));
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    if (position.size() > 1) {
      if (extractOp.hasDynamicPosition())
        return failure();

      // Peel off all outer (array) dimensions with one extractvalue.
      SmallVector<int64_t> nMinusOnePosition =
          getAsIntegers(position.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // The innermost index may be dynamic; materialize it as an LLVM value.
    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
    // Remaining extraction of element from 1-D LLVM vector.
    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
                                                        lastPosition);
    return success();
  }
};
1140 /// Conversion pattern that turns a vector.fma on a 1-D vector
1141 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1142 /// This does not match vectors of n >= 2 rank.
1144 /// Example:
1145 /// ```
1146 /// vector.fma %a, %a, %a : vector<8xf32>
1147 /// ```
1148 /// is converted to:
1149 /// ```
1150 /// llvm.intr.fmuladd %va, %va, %va:
1151 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1152 /// -> !llvm."<8 x f32>">
1153 /// ```
1154 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1155 public:
1156 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1158 LogicalResult
1159 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1160 ConversionPatternRewriter &rewriter) const override {
1161 VectorType vType = fmaOp.getVectorType();
1162 if (vType.getRank() > 1)
1163 return failure();
1165 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1166 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1167 return success();
1171 class VectorInsertElementOpConversion
1172 : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1173 public:
1174 using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1176 LogicalResult
1177 matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1178 ConversionPatternRewriter &rewriter) const override {
1179 auto vectorType = insertEltOp.getDestVectorType();
1180 auto llvmType = typeConverter->convertType(vectorType);
1182 // Bail if result type cannot be lowered.
1183 if (!llvmType)
1184 return failure();
1186 if (vectorType.getRank() == 0) {
1187 Location loc = insertEltOp.getLoc();
1188 auto idxType = rewriter.getIndexType();
1189 auto zero = rewriter.create<LLVM::ConstantOp>(
1190 loc, typeConverter->convertType(idxType),
1191 rewriter.getIntegerAttr(idxType, 0));
1192 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1193 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1194 return success();
1197 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1198 insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1199 adaptor.getPosition());
1200 return success();
/// Lowers `vector.insert` to LLVM insertvalue (vector-into-array) and/or an
/// extractvalue + insertelement + insertvalue sequence when a scalar is
/// inserted into the innermost 1-D vector of an N-D lowering. Dynamic
/// positions are only supported for the innermost dimension, since LLVM
/// insertvalue/extractvalue require constant indices.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Merge static and dynamic position operands into one list.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (isa<VectorType>(sourceType)) {
      // insertvalue only takes constant indices.
      if (insertOp.hasDynamicPosition())
        return failure();

      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto oneDVectorType = destVectorType;
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      // Pull out the innermost 1-D vector the scalar goes into.
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, getAsIntegers(position.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));

    // Potential insertion of resulting 1-D vector into array.
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          getAsIntegers(position.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
1274 /// Lower vector.scalable.insert ops to LLVM vector.insert
1275 struct VectorScalableInsertOpLowering
1276 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1277 using ConvertOpToLLVMPattern<
1278 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1280 LogicalResult
1281 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1282 ConversionPatternRewriter &rewriter) const override {
1283 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1284 insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1285 return success();
1289 /// Lower vector.scalable.extract ops to LLVM vector.extract
1290 struct VectorScalableExtractOpLowering
1291 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1292 using ConvertOpToLLVMPattern<
1293 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1295 LogicalResult
1296 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1297 ConversionPatternRewriter &rewriter) const override {
1298 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1299 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1300 adaptor.getSource(), adaptor.getPos());
1301 return success();
1305 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1307 /// Example:
1308 /// ```
1309 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1310 /// ```
1311 /// is rewritten into:
1312 /// ```
1313 /// %r = splat %f0: vector<2x4xf32>
1314 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1315 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1316 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1317 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1318 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1319 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1320 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1321 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1322 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1323 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1324 /// // %r3 holds the final value.
1325 /// ```
1326 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1327 public:
1328 using OpRewritePattern<FMAOp>::OpRewritePattern;
1330 void initialize() {
1331 // This pattern recursively unpacks one dimension at a time. The recursion
1332 // bounded as the rank is strictly decreasing.
1333 setHasBoundedRewriteRecursion();
1336 LogicalResult matchAndRewrite(FMAOp op,
1337 PatternRewriter &rewriter) const override {
1338 auto vType = op.getVectorType();
1339 if (vType.getRank() < 2)
1340 return failure();
1342 auto loc = op.getLoc();
1343 auto elemType = vType.getElementType();
1344 Value zero = rewriter.create<arith::ConstantOp>(
1345 loc, elemType, rewriter.getZeroAttr(elemType));
1346 Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1347 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1348 Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1349 Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1350 Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1351 Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1352 desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1354 rewriter.replaceOp(op, desc);
1355 return success();
1359 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1360 /// static layout.
1361 static std::optional<SmallVector<int64_t, 4>>
1362 computeContiguousStrides(MemRefType memRefType) {
1363 int64_t offset;
1364 SmallVector<int64_t, 4> strides;
1365 if (failed(getStridesAndOffset(memRefType, strides, offset)))
1366 return std::nullopt;
1367 if (!strides.empty() && strides.back() != 1)
1368 return std::nullopt;
1369 // If no layout or identity layout, this is contiguous by definition.
1370 if (memRefType.getLayout().isIdentity())
1371 return strides;
1373 // Otherwise, we must determine contiguity form shapes. This can only ever
1374 // work in static cases because MemRefType is underspecified to represent
1375 // contiguous dynamic shapes in other ways than with just empty/identity
1376 // layout.
1377 auto sizes = memRefType.getShape();
1378 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1379 if (ShapedType::isDynamic(sizes[index + 1]) ||
1380 ShapedType::isDynamic(strides[index]) ||
1381 ShapedType::isDynamic(strides[index + 1]))
1382 return std::nullopt;
1383 if (strides[index] != strides[index + 1] * sizes[index + 1])
1384 return std::nullopt;
1386 return strides;
/// Lowers `vector.type_cast` between memref descriptors by reusing the
/// source descriptor's allocated/aligned pointers and rebuilding the target
/// descriptor with a zero offset and freshly materialized static sizes and
/// strides. Restricted to static shapes and contiguous buffers.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The converted source must already be a memref descriptor struct.
    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
1464 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1465 /// Non-scalable versions of this operation are handled in Vector Transforms.
1466 class VectorCreateMaskOpRewritePattern
1467 : public OpRewritePattern<vector::CreateMaskOp> {
1468 public:
1469 explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1470 bool enableIndexOpt)
1471 : OpRewritePattern<vector::CreateMaskOp>(context),
1472 force32BitVectorIndices(enableIndexOpt) {}
1474 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1475 PatternRewriter &rewriter) const override {
1476 auto dstType = op.getType();
1477 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1478 return failure();
1479 IntegerType idxType =
1480 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1481 auto loc = op->getLoc();
1482 Value indices = rewriter.create<LLVM::StepVectorOp>(
1483 loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1484 /*isScalable=*/true));
1485 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1486 op.getOperand(0));
1487 Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1488 Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1489 indices, bounds);
1490 rewriter.replaceOp(op, comp);
1491 return success();
1494 private:
1495 const bool force32BitVectorIndices;
/// Lowers `vector.print` of scalars (plus punctuation / string literals) to
/// calls into a small runtime print-support library.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Lowering implementation that relies on a small runtime support library,
  // which only needs to provide a few printing methods (single value for all
  // data types, opening/closing bracket, comma, newline). The lowering splits
  // the vector into elementary printing operations. The advantage of this
  // approach is that the library can remain unaware of all low-level
  // implementation details of vectors while still supporting output of any
  // shaped and dimensioned vector.
  //
  // Note: This lowering only handles scalars, n-D vectors are broken into
  // printing scalars in loops in VectorToSCF.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The print-support functions are declared at module scope.
    auto parent = printOp->getParentOfType<ModuleOp>();
    if (!parent)
      return failure();

    auto loc = printOp->getLoc();

    // Print the operand value, if any.
    if (auto value = adaptor.getSource()) {
      Type printType = printOp.getPrintType();
      if (isa<VectorType>(printType)) {
        // Vectors should be broken into elementary print ops in VectorToSCF.
        return failure();
      }
      if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
        return failure();
    }

    // Then emit the punctuation (or a verbatim string literal, which takes
    // precedence over punctuation).
    auto punct = printOp.getPunctuation();
    if (auto stringLiteral = printOp.getStringLiteral()) {
      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
                               *stringLiteral, *getTypeConverter(),
                               /*addNewline=*/false);
    } else if (punct != PrintPunctuation::NoPunctuation) {
      emitCall(rewriter, printOp->getLoc(), [&] {
        switch (punct) {
        case PrintPunctuation::Close:
          return LLVM::lookupOrCreatePrintCloseFn(parent);
        case PrintPunctuation::Open:
          return LLVM::lookupOrCreatePrintOpenFn(parent);
        case PrintPunctuation::Comma:
          return LLVM::lookupOrCreatePrintCommaFn(parent);
        case PrintPunctuation::NewLine:
          return LLVM::lookupOrCreatePrintNewlineFn(parent);
        default:
          llvm_unreachable("unexpected punctuation");
        }
      }());
    }

    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Conversion applied to the operand before calling the runtime printer.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  /// Emits a call to the runtime function printing a single scalar of type
  /// `printType`, extending/bitcasting `value` as needed. Fails for types
  /// without runtime support (e.g. integers wider than 64 bits).
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    Operation *printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (printType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Apply the selected operand conversion before calling the printer.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};
1650 /// The Splat operation is lowered to an insertelement + a shufflevector
1651 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1652 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1653 using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1655 LogicalResult
1656 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1657 ConversionPatternRewriter &rewriter) const override {
1658 VectorType resultType = cast<VectorType>(splatOp.getType());
1659 if (resultType.getRank() > 1)
1660 return failure();
1662 // First insert it into an undef vector so we can shuffle it.
1663 auto vectorType = typeConverter->convertType(splatOp.getType());
1664 Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1665 auto zero = rewriter.create<LLVM::ConstantOp>(
1666 splatOp.getLoc(),
1667 typeConverter->convertType(rewriter.getIntegerType(32)),
1668 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1670 // For 0-d vector, we simply do `insertelement`.
1671 if (resultType.getRank() == 0) {
1672 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1673 splatOp, vectorType, undef, adaptor.getInput(), zero);
1674 return success();
1677 // For 1-d vector, we additionally do a `vectorshuffle`.
1678 auto v = rewriter.create<LLVM::InsertElementOp>(
1679 splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1681 int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1682 SmallVector<int32_t> zeroValues(width, 0);
1684 // Shuffle the value across the desired number of elements.
1685 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1686 zeroValues);
1687 return success();
1691 /// The Splat operation is lowered to an insertelement + a shufflevector
1692 /// operation. Splat to only 2+-d vector result types are lowered by the
1693 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1694 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1695 using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1697 LogicalResult
1698 matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1699 ConversionPatternRewriter &rewriter) const override {
1700 VectorType resultType = splatOp.getType();
1701 if (resultType.getRank() <= 1)
1702 return failure();
1704 // First insert it into an undef vector so we can shuffle it.
1705 auto loc = splatOp.getLoc();
1706 auto vectorTypeInfo =
1707 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1708 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1709 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1710 if (!llvmNDVectorTy || !llvm1DVectorTy)
1711 return failure();
1713 // Construct returned value.
1714 Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1716 // Construct a 1-D vector with the splatted value that we insert in all the
1717 // places within the returned descriptor.
1718 Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1719 auto zero = rewriter.create<LLVM::ConstantOp>(
1720 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1721 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1722 Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1723 adaptor.getInput(), zero);
1725 // Shuffle the value across the desired number of elements.
1726 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1727 SmallVector<int32_t> zeroValues(width, 0);
1728 v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1730 // Iterate of linear index, convert to coords space and insert splatted 1-D
1731 // vector in each position.
1732 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1733 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1735 rewriter.replaceOp(splatOp, desc);
1736 return success();
1740 /// Conversion pattern for a `vector.interleave`.
1741 /// This supports fixed-sized vectors and scalable vectors.
1742 struct VectorInterleaveOpLowering
1743 : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1744 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1746 LogicalResult
1747 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1748 ConversionPatternRewriter &rewriter) const override {
1749 VectorType resultType = interleaveOp.getResultVectorType();
1750 // n-D interleaves should have been lowered already.
1751 if (resultType.getRank() != 1)
1752 return rewriter.notifyMatchFailure(interleaveOp,
1753 "InterleaveOp not rank 1");
1754 // If the result is rank 1, then this directly maps to LLVM.
1755 if (resultType.isScalable()) {
1756 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1757 interleaveOp, typeConverter->convertType(resultType),
1758 adaptor.getLhs(), adaptor.getRhs());
1759 return success();
1761 // Lower fixed-size interleaves to a shufflevector. While the
1762 // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1763 // langref still recommends fixed-vectors use shufflevector, see:
1764 // https://llvm.org/docs/LangRef.html#id876.
1765 int64_t resultVectorSize = resultType.getNumElements();
1766 SmallVector<int32_t> interleaveShuffleMask;
1767 interleaveShuffleMask.reserve(resultVectorSize);
1768 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1769 interleaveShuffleMask.push_back(i);
1770 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1772 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1773 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1774 interleaveShuffleMask);
1775 return success();
1779 /// Conversion pattern for a `vector.deinterleave`.
1780 /// This supports fixed-sized vectors and scalable vectors.
1781 struct VectorDeinterleaveOpLowering
1782 : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1783 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1785 LogicalResult
1786 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1787 ConversionPatternRewriter &rewriter) const override {
1788 VectorType resultType = deinterleaveOp.getResultVectorType();
1789 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1790 auto loc = deinterleaveOp.getLoc();
1792 // Note: n-D deinterleave operations should be lowered to the 1-D before
1793 // converting to LLVM.
1794 if (resultType.getRank() != 1)
1795 return rewriter.notifyMatchFailure(deinterleaveOp,
1796 "DeinterleaveOp not rank 1");
1798 if (resultType.isScalable()) {
1799 auto llvmTypeConverter = this->getTypeConverter();
1800 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1801 auto packedOpResults =
1802 llvmTypeConverter->packOperationResults(deinterleaveResults);
1803 auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1804 loc, packedOpResults, adaptor.getSource());
1806 auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1807 loc, intrinsic->getResult(0), 0);
1808 auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1809 loc, intrinsic->getResult(0), 1);
1811 rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1812 return success();
1814 // Lower fixed-size deinterleave to two shufflevectors. While the
1815 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1816 // langref still recommends fixed-vectors use shufflevector, see:
1817 // https://llvm.org/docs/LangRef.html#id889.
1818 int64_t resultVectorSize = resultType.getNumElements();
1819 SmallVector<int32_t> evenShuffleMask;
1820 SmallVector<int32_t> oddShuffleMask;
1822 evenShuffleMask.reserve(resultVectorSize);
1823 oddShuffleMask.reserve(resultVectorSize);
1825 for (int i = 0; i < sourceType.getNumElements(); ++i) {
1826 if (i % 2 == 0)
1827 evenShuffleMask.push_back(i);
1828 else
1829 oddShuffleMask.push_back(i);
1832 auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1833 auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1834 loc, adaptor.getSource(), poison, evenShuffleMask);
1835 auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1836 loc, adaptor.getSource(), poison, oddShuffleMask);
1838 rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1839 return success();
1843 /// Conversion pattern for a `vector.from_elements`.
1844 struct VectorFromElementsLowering
1845 : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1846 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1848 LogicalResult
1849 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1850 ConversionPatternRewriter &rewriter) const override {
1851 Location loc = fromElementsOp.getLoc();
1852 VectorType vectorType = fromElementsOp.getType();
1853 // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1854 // Such ops should be handled in the same way as vector.insert.
1855 if (vectorType.getRank() > 1)
1856 return rewriter.notifyMatchFailure(fromElementsOp,
1857 "rank > 1 vectors are not supported");
1858 Type llvmType = typeConverter->convertType(vectorType);
1859 Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1860 for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
1861 result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
1862 rewriter.replaceOp(fromElementsOp, result);
1863 return success();
1867 /// Conversion pattern for vector.step.
1868 struct VectorScalableStepOpLowering
1869 : public ConvertOpToLLVMPattern<vector::StepOp> {
1870 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1872 LogicalResult
1873 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1874 ConversionPatternRewriter &rewriter) const override {
1875 auto resultType = cast<VectorType>(stepOp.getType());
1876 if (!resultType.isScalable()) {
1877 return failure();
1879 Type llvmType = typeConverter->convertType(stepOp.getType());
1880 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1881 return success();
1885 } // namespace
1887 /// Populate the given list with patterns that convert from Vector to LLVM.
1888 void mlir::populateVectorToLLVMConversionPatterns(
1889 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1890 bool reassociateFPReductions, bool force32BitVectorIndices) {
1891 MLIRContext *ctx = converter.getDialect()->getContext();
1892 patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1893 populateVectorInsertExtractStridedSliceTransforms(patterns);
1894 populateVectorStepLoweringPatterns(patterns);
1895 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1896 patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1897 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1898 VectorExtractElementOpConversion, VectorExtractOpConversion,
1899 VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1900 VectorInsertOpConversion, VectorPrintOpConversion,
1901 VectorTypeCastOpConversion, VectorScaleOpConversion,
1902 VectorLoadStoreConversion<vector::LoadOp>,
1903 VectorLoadStoreConversion<vector::MaskedLoadOp>,
1904 VectorLoadStoreConversion<vector::StoreOp>,
1905 VectorLoadStoreConversion<vector::MaskedStoreOp>,
1906 VectorGatherOpConversion, VectorScatterOpConversion,
1907 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1908 VectorSplatOpLowering, VectorSplatNdOpLowering,
1909 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1910 MaskedReductionOpConversion, VectorInterleaveOpLowering,
1911 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1912 VectorScalableStepOpLowering>(converter);
1913 // Transfer ops with rank > 1 are handled by VectorToSCF.
1914 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1917 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1918 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1919 patterns.add<VectorMatmulOpConversion>(converter);
1920 patterns.add<VectorFlatTransposeOpConversion>(converter);