1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24 #include "mlir/IR/BuiltinAttributes.h"
25 #include "mlir/IR/BuiltinTypeInterfaces.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
35 using namespace mlir::vector
;
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType
reducedVectorTypeBack(VectorType tp
) {
39 assert((tp
.getRank() > 1) && "unlowerable vector type");
40 return VectorType::get(tp
.getShape().take_back(), tp
.getElementType(),
41 tp
.getScalableDims().take_back());
44 // Helper that picks the proper sequence for inserting.
45 static Value
insertOne(ConversionPatternRewriter
&rewriter
,
46 const LLVMTypeConverter
&typeConverter
, Location loc
,
47 Value val1
, Value val2
, Type llvmType
, int64_t rank
,
49 assert(rank
> 0 && "0-D vector corner case should have been handled already");
51 auto idxType
= rewriter
.getIndexType();
52 auto constant
= rewriter
.create
<LLVM::ConstantOp
>(
53 loc
, typeConverter
.convertType(idxType
),
54 rewriter
.getIntegerAttr(idxType
, pos
));
55 return rewriter
.create
<LLVM::InsertElementOp
>(loc
, llvmType
, val1
, val2
,
58 return rewriter
.create
<LLVM::InsertValueOp
>(loc
, val1
, val2
, pos
);
61 // Helper that picks the proper sequence for extracting.
62 static Value
extractOne(ConversionPatternRewriter
&rewriter
,
63 const LLVMTypeConverter
&typeConverter
, Location loc
,
64 Value val
, Type llvmType
, int64_t rank
, int64_t pos
) {
66 auto idxType
= rewriter
.getIndexType();
67 auto constant
= rewriter
.create
<LLVM::ConstantOp
>(
68 loc
, typeConverter
.convertType(idxType
),
69 rewriter
.getIntegerAttr(idxType
, pos
));
70 return rewriter
.create
<LLVM::ExtractElementOp
>(loc
, llvmType
, val
,
73 return rewriter
.create
<LLVM::ExtractValueOp
>(loc
, val
, pos
);
76 // Helper that returns data layout alignment of a memref.
77 LogicalResult
getMemRefAlignment(const LLVMTypeConverter
&typeConverter
,
78 MemRefType memrefType
, unsigned &align
) {
79 Type elementTy
= typeConverter
.convertType(memrefType
.getElementType());
83 // TODO: this should use the MLIR data layout when it becomes available and
84 // stop depending on translation.
85 llvm::LLVMContext llvmContext
;
86 align
= LLVM::TypeToLLVMIRTranslator(llvmContext
)
87 .getPreferredAlignment(elementTy
, typeConverter
.getDataLayout());
91 // Check if the last stride is non-unit and has a valid memory space.
92 static LogicalResult
isMemRefTypeSupported(MemRefType memRefType
,
93 const LLVMTypeConverter
&converter
) {
94 if (!isLastMemrefDimUnitStride(memRefType
))
96 if (failed(converter
.getMemRefAddressSpace(memRefType
)))
101 // Add an index vector component to a base pointer.
102 static Value
getIndexedPtrs(ConversionPatternRewriter
&rewriter
, Location loc
,
103 const LLVMTypeConverter
&typeConverter
,
104 MemRefType memRefType
, Value llvmMemref
, Value base
,
105 Value index
, VectorType vectorType
) {
106 assert(succeeded(isMemRefTypeSupported(memRefType
, typeConverter
)) &&
107 "unsupported memref type");
108 assert(vectorType
.getRank() == 1 && "expected a 1-d vector type");
109 auto pType
= MemRefDescriptor(llvmMemref
).getElementPtrType();
111 LLVM::getVectorType(pType
, vectorType
.getDimSize(0),
112 /*isScalable=*/vectorType
.getScalableDims()[0]);
113 return rewriter
.create
<LLVM::GEPOp
>(
114 loc
, ptrsType
, typeConverter
.convertType(memRefType
.getElementType()),
118 /// Convert `foldResult` into a Value. Integer attribute is converted to
119 /// an LLVM constant op.
120 static Value
getAsLLVMValue(OpBuilder
&builder
, Location loc
,
121 OpFoldResult foldResult
) {
122 if (auto attr
= foldResult
.dyn_cast
<Attribute
>()) {
123 auto intAttr
= cast
<IntegerAttr
>(attr
);
124 return builder
.create
<LLVM::ConstantOp
>(loc
, intAttr
).getResult();
127 return foldResult
.get
<Value
>();
132 /// Trivial Vector to LLVM conversions
133 using VectorScaleOpConversion
=
134 OneToOneConvertToLLVMPattern
<vector::VectorScaleOp
, LLVM::vscale
>;
136 /// Conversion pattern for a vector.bitcast.
137 class VectorBitCastOpConversion
138 : public ConvertOpToLLVMPattern
<vector::BitCastOp
> {
140 using ConvertOpToLLVMPattern
<vector::BitCastOp
>::ConvertOpToLLVMPattern
;
143 matchAndRewrite(vector::BitCastOp bitCastOp
, OpAdaptor adaptor
,
144 ConversionPatternRewriter
&rewriter
) const override
{
145 // Only 0-D and 1-D vectors can be lowered to LLVM.
146 VectorType resultTy
= bitCastOp
.getResultVectorType();
147 if (resultTy
.getRank() > 1)
149 Type newResultTy
= typeConverter
->convertType(resultTy
);
150 rewriter
.replaceOpWithNewOp
<LLVM::BitcastOp
>(bitCastOp
, newResultTy
,
151 adaptor
.getOperands()[0]);
156 /// Conversion pattern for a vector.matrix_multiply.
157 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
158 class VectorMatmulOpConversion
159 : public ConvertOpToLLVMPattern
<vector::MatmulOp
> {
161 using ConvertOpToLLVMPattern
<vector::MatmulOp
>::ConvertOpToLLVMPattern
;
164 matchAndRewrite(vector::MatmulOp matmulOp
, OpAdaptor adaptor
,
165 ConversionPatternRewriter
&rewriter
) const override
{
166 rewriter
.replaceOpWithNewOp
<LLVM::MatrixMultiplyOp
>(
167 matmulOp
, typeConverter
->convertType(matmulOp
.getRes().getType()),
168 adaptor
.getLhs(), adaptor
.getRhs(), matmulOp
.getLhsRows(),
169 matmulOp
.getLhsColumns(), matmulOp
.getRhsColumns());
174 /// Conversion pattern for a vector.flat_transpose.
175 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
176 class VectorFlatTransposeOpConversion
177 : public ConvertOpToLLVMPattern
<vector::FlatTransposeOp
> {
179 using ConvertOpToLLVMPattern
<vector::FlatTransposeOp
>::ConvertOpToLLVMPattern
;
182 matchAndRewrite(vector::FlatTransposeOp transOp
, OpAdaptor adaptor
,
183 ConversionPatternRewriter
&rewriter
) const override
{
184 rewriter
.replaceOpWithNewOp
<LLVM::MatrixTransposeOp
>(
185 transOp
, typeConverter
->convertType(transOp
.getRes().getType()),
186 adaptor
.getMatrix(), transOp
.getRows(), transOp
.getColumns());
191 /// Overloaded utility that replaces a vector.load, vector.store,
192 /// vector.maskedload and vector.maskedstore with their respective LLVM
194 static void replaceLoadOrStoreOp(vector::LoadOp loadOp
,
195 vector::LoadOpAdaptor adaptor
,
196 VectorType vectorTy
, Value ptr
, unsigned align
,
197 ConversionPatternRewriter
&rewriter
) {
198 rewriter
.replaceOpWithNewOp
<LLVM::LoadOp
>(loadOp
, vectorTy
, ptr
, align
,
200 loadOp
.getNontemporal());
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp
,
204 vector::MaskedLoadOpAdaptor adaptor
,
205 VectorType vectorTy
, Value ptr
, unsigned align
,
206 ConversionPatternRewriter
&rewriter
) {
207 rewriter
.replaceOpWithNewOp
<LLVM::MaskedLoadOp
>(
208 loadOp
, vectorTy
, ptr
, adaptor
.getMask(), adaptor
.getPassThru(), align
);
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp
,
212 vector::StoreOpAdaptor adaptor
,
213 VectorType vectorTy
, Value ptr
, unsigned align
,
214 ConversionPatternRewriter
&rewriter
) {
215 rewriter
.replaceOpWithNewOp
<LLVM::StoreOp
>(storeOp
, adaptor
.getValueToStore(),
216 ptr
, align
, /*volatile_=*/false,
217 storeOp
.getNontemporal());
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp
,
221 vector::MaskedStoreOpAdaptor adaptor
,
222 VectorType vectorTy
, Value ptr
, unsigned align
,
223 ConversionPatternRewriter
&rewriter
) {
224 rewriter
.replaceOpWithNewOp
<LLVM::MaskedStoreOp
>(
225 storeOp
, adaptor
.getValueToStore(), ptr
, adaptor
.getMask(), align
);
228 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
229 /// vector.maskedstore.
230 template <class LoadOrStoreOp
>
231 class VectorLoadStoreConversion
: public ConvertOpToLLVMPattern
<LoadOrStoreOp
> {
233 using ConvertOpToLLVMPattern
<LoadOrStoreOp
>::ConvertOpToLLVMPattern
;
236 matchAndRewrite(LoadOrStoreOp loadOrStoreOp
,
237 typename
LoadOrStoreOp::Adaptor adaptor
,
238 ConversionPatternRewriter
&rewriter
) const override
{
239 // Only 1-D vectors can be lowered to LLVM.
240 VectorType vectorTy
= loadOrStoreOp
.getVectorType();
241 if (vectorTy
.getRank() > 1)
244 auto loc
= loadOrStoreOp
->getLoc();
245 MemRefType memRefTy
= loadOrStoreOp
.getMemRefType();
247 // Resolve alignment.
249 if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy
, align
)))
253 auto vtype
= cast
<VectorType
>(
254 this->typeConverter
->convertType(loadOrStoreOp
.getVectorType()));
255 Value dataPtr
= this->getStridedElementPtr(loc
, memRefTy
, adaptor
.getBase(),
256 adaptor
.getIndices(), rewriter
);
257 replaceLoadOrStoreOp(loadOrStoreOp
, adaptor
, vtype
, dataPtr
, align
,
263 /// Conversion pattern for a vector.gather.
264 class VectorGatherOpConversion
265 : public ConvertOpToLLVMPattern
<vector::GatherOp
> {
267 using ConvertOpToLLVMPattern
<vector::GatherOp
>::ConvertOpToLLVMPattern
;
270 matchAndRewrite(vector::GatherOp gather
, OpAdaptor adaptor
,
271 ConversionPatternRewriter
&rewriter
) const override
{
272 MemRefType memRefType
= dyn_cast
<MemRefType
>(gather
.getBaseType());
273 assert(memRefType
&& "The base should be bufferized");
275 if (failed(isMemRefTypeSupported(memRefType
, *this->getTypeConverter())))
278 auto loc
= gather
->getLoc();
280 // Resolve alignment.
282 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType
, align
)))
285 Value ptr
= getStridedElementPtr(loc
, memRefType
, adaptor
.getBase(),
286 adaptor
.getIndices(), rewriter
);
287 Value base
= adaptor
.getBase();
289 auto llvmNDVectorTy
= adaptor
.getIndexVec().getType();
290 // Handle the simple case of 1-D vector.
291 if (!isa
<LLVM::LLVMArrayType
>(llvmNDVectorTy
)) {
292 auto vType
= gather
.getVectorType();
295 getIndexedPtrs(rewriter
, loc
, *this->getTypeConverter(), memRefType
,
296 base
, ptr
, adaptor
.getIndexVec(), vType
);
297 // Replace with the gather intrinsic.
298 rewriter
.replaceOpWithNewOp
<LLVM::masked_gather
>(
299 gather
, typeConverter
->convertType(vType
), ptrs
, adaptor
.getMask(),
300 adaptor
.getPassThru(), rewriter
.getI32IntegerAttr(align
));
304 const LLVMTypeConverter
&typeConverter
= *this->getTypeConverter();
305 auto callback
= [align
, memRefType
, base
, ptr
, loc
, &rewriter
,
306 &typeConverter
](Type llvm1DVectorTy
,
307 ValueRange vectorOperands
) {
309 Value ptrs
= getIndexedPtrs(
310 rewriter
, loc
, typeConverter
, memRefType
, base
, ptr
,
311 /*index=*/vectorOperands
[0], cast
<VectorType
>(llvm1DVectorTy
));
312 // Create the gather intrinsic.
313 return rewriter
.create
<LLVM::masked_gather
>(
314 loc
, llvm1DVectorTy
, ptrs
, /*mask=*/vectorOperands
[1],
315 /*passThru=*/vectorOperands
[2], rewriter
.getI32IntegerAttr(align
));
317 SmallVector
<Value
> vectorOperands
= {
318 adaptor
.getIndexVec(), adaptor
.getMask(), adaptor
.getPassThru()};
319 return LLVM::detail::handleMultidimensionalVectors(
320 gather
, vectorOperands
, *getTypeConverter(), callback
, rewriter
);
324 /// Conversion pattern for a vector.scatter.
325 class VectorScatterOpConversion
326 : public ConvertOpToLLVMPattern
<vector::ScatterOp
> {
328 using ConvertOpToLLVMPattern
<vector::ScatterOp
>::ConvertOpToLLVMPattern
;
331 matchAndRewrite(vector::ScatterOp scatter
, OpAdaptor adaptor
,
332 ConversionPatternRewriter
&rewriter
) const override
{
333 auto loc
= scatter
->getLoc();
334 MemRefType memRefType
= scatter
.getMemRefType();
336 if (failed(isMemRefTypeSupported(memRefType
, *this->getTypeConverter())))
339 // Resolve alignment.
341 if (failed(getMemRefAlignment(*getTypeConverter(), memRefType
, align
)))
345 VectorType vType
= scatter
.getVectorType();
346 Value ptr
= getStridedElementPtr(loc
, memRefType
, adaptor
.getBase(),
347 adaptor
.getIndices(), rewriter
);
349 getIndexedPtrs(rewriter
, loc
, *this->getTypeConverter(), memRefType
,
350 adaptor
.getBase(), ptr
, adaptor
.getIndexVec(), vType
);
352 // Replace with the scatter intrinsic.
353 rewriter
.replaceOpWithNewOp
<LLVM::masked_scatter
>(
354 scatter
, adaptor
.getValueToStore(), ptrs
, adaptor
.getMask(),
355 rewriter
.getI32IntegerAttr(align
));
360 /// Conversion pattern for a vector.expandload.
361 class VectorExpandLoadOpConversion
362 : public ConvertOpToLLVMPattern
<vector::ExpandLoadOp
> {
364 using ConvertOpToLLVMPattern
<vector::ExpandLoadOp
>::ConvertOpToLLVMPattern
;
367 matchAndRewrite(vector::ExpandLoadOp expand
, OpAdaptor adaptor
,
368 ConversionPatternRewriter
&rewriter
) const override
{
369 auto loc
= expand
->getLoc();
370 MemRefType memRefType
= expand
.getMemRefType();
373 auto vtype
= typeConverter
->convertType(expand
.getVectorType());
374 Value ptr
= getStridedElementPtr(loc
, memRefType
, adaptor
.getBase(),
375 adaptor
.getIndices(), rewriter
);
377 rewriter
.replaceOpWithNewOp
<LLVM::masked_expandload
>(
378 expand
, vtype
, ptr
, adaptor
.getMask(), adaptor
.getPassThru());
383 /// Conversion pattern for a vector.compressstore.
384 class VectorCompressStoreOpConversion
385 : public ConvertOpToLLVMPattern
<vector::CompressStoreOp
> {
387 using ConvertOpToLLVMPattern
<vector::CompressStoreOp
>::ConvertOpToLLVMPattern
;
390 matchAndRewrite(vector::CompressStoreOp compress
, OpAdaptor adaptor
,
391 ConversionPatternRewriter
&rewriter
) const override
{
392 auto loc
= compress
->getLoc();
393 MemRefType memRefType
= compress
.getMemRefType();
396 Value ptr
= getStridedElementPtr(loc
, memRefType
, adaptor
.getBase(),
397 adaptor
.getIndices(), rewriter
);
399 rewriter
.replaceOpWithNewOp
<LLVM::masked_compressstore
>(
400 compress
, adaptor
.getValueToStore(), ptr
, adaptor
.getMask());
/// Reduction neutral classes for overloading.
/// Empty tag types used to select the proper createReductionNeutralValue
/// overload at compile time.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
417 /// Create the reduction neutral zero value.
418 static Value
createReductionNeutralValue(ReductionNeutralZero neutral
,
419 ConversionPatternRewriter
&rewriter
,
420 Location loc
, Type llvmType
) {
421 return rewriter
.create
<LLVM::ConstantOp
>(loc
, llvmType
,
422 rewriter
.getZeroAttr(llvmType
));
425 /// Create the reduction neutral integer one value.
426 static Value
createReductionNeutralValue(ReductionNeutralIntOne neutral
,
427 ConversionPatternRewriter
&rewriter
,
428 Location loc
, Type llvmType
) {
429 return rewriter
.create
<LLVM::ConstantOp
>(
430 loc
, llvmType
, rewriter
.getIntegerAttr(llvmType
, 1));
433 /// Create the reduction neutral fp one value.
434 static Value
createReductionNeutralValue(ReductionNeutralFPOne neutral
,
435 ConversionPatternRewriter
&rewriter
,
436 Location loc
, Type llvmType
) {
437 return rewriter
.create
<LLVM::ConstantOp
>(
438 loc
, llvmType
, rewriter
.getFloatAttr(llvmType
, 1.0));
441 /// Create the reduction neutral all-ones value.
442 static Value
createReductionNeutralValue(ReductionNeutralAllOnes neutral
,
443 ConversionPatternRewriter
&rewriter
,
444 Location loc
, Type llvmType
) {
445 return rewriter
.create
<LLVM::ConstantOp
>(
447 rewriter
.getIntegerAttr(
448 llvmType
, llvm::APInt::getAllOnes(llvmType
.getIntOrFloatBitWidth())));
451 /// Create the reduction neutral signed int minimum value.
452 static Value
createReductionNeutralValue(ReductionNeutralSIntMin neutral
,
453 ConversionPatternRewriter
&rewriter
,
454 Location loc
, Type llvmType
) {
455 return rewriter
.create
<LLVM::ConstantOp
>(
457 rewriter
.getIntegerAttr(llvmType
, llvm::APInt::getSignedMinValue(
458 llvmType
.getIntOrFloatBitWidth())));
461 /// Create the reduction neutral unsigned int minimum value.
462 static Value
createReductionNeutralValue(ReductionNeutralUIntMin neutral
,
463 ConversionPatternRewriter
&rewriter
,
464 Location loc
, Type llvmType
) {
465 return rewriter
.create
<LLVM::ConstantOp
>(
467 rewriter
.getIntegerAttr(llvmType
, llvm::APInt::getMinValue(
468 llvmType
.getIntOrFloatBitWidth())));
471 /// Create the reduction neutral signed int maximum value.
472 static Value
createReductionNeutralValue(ReductionNeutralSIntMax neutral
,
473 ConversionPatternRewriter
&rewriter
,
474 Location loc
, Type llvmType
) {
475 return rewriter
.create
<LLVM::ConstantOp
>(
477 rewriter
.getIntegerAttr(llvmType
, llvm::APInt::getSignedMaxValue(
478 llvmType
.getIntOrFloatBitWidth())));
481 /// Create the reduction neutral unsigned int maximum value.
482 static Value
createReductionNeutralValue(ReductionNeutralUIntMax neutral
,
483 ConversionPatternRewriter
&rewriter
,
484 Location loc
, Type llvmType
) {
485 return rewriter
.create
<LLVM::ConstantOp
>(
487 rewriter
.getIntegerAttr(llvmType
, llvm::APInt::getMaxValue(
488 llvmType
.getIntOrFloatBitWidth())));
491 /// Create the reduction neutral fp minimum value.
492 static Value
createReductionNeutralValue(ReductionNeutralFPMin neutral
,
493 ConversionPatternRewriter
&rewriter
,
494 Location loc
, Type llvmType
) {
495 auto floatType
= cast
<FloatType
>(llvmType
);
496 return rewriter
.create
<LLVM::ConstantOp
>(
498 rewriter
.getFloatAttr(
499 llvmType
, llvm::APFloat::getQNaN(floatType
.getFloatSemantics(),
500 /*Negative=*/false)));
503 /// Create the reduction neutral fp maximum value.
504 static Value
createReductionNeutralValue(ReductionNeutralFPMax neutral
,
505 ConversionPatternRewriter
&rewriter
,
506 Location loc
, Type llvmType
) {
507 auto floatType
= cast
<FloatType
>(llvmType
);
508 return rewriter
.create
<LLVM::ConstantOp
>(
510 rewriter
.getFloatAttr(
511 llvmType
, llvm::APFloat::getQNaN(floatType
.getFloatSemantics(),
512 /*Negative=*/true)));
515 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
516 /// returns a new accumulator value using `ReductionNeutral`.
517 template <class ReductionNeutral
>
518 static Value
getOrCreateAccumulator(ConversionPatternRewriter
&rewriter
,
519 Location loc
, Type llvmType
,
524 return createReductionNeutralValue(ReductionNeutral(), rewriter
, loc
,
528 /// Creates a value with the 1-D vector shape provided in `llvmType`.
529 /// This is used as effective vector length by some intrinsics supporting
530 /// dynamic vector lengths at runtime.
531 static Value
createVectorLengthValue(ConversionPatternRewriter
&rewriter
,
532 Location loc
, Type llvmType
) {
533 VectorType vType
= cast
<VectorType
>(llvmType
);
534 auto vShape
= vType
.getShape();
535 assert(vShape
.size() == 1 && "Unexpected multi-dim vector type");
537 Value baseVecLength
= rewriter
.create
<LLVM::ConstantOp
>(
538 loc
, rewriter
.getI32Type(),
539 rewriter
.getIntegerAttr(rewriter
.getI32Type(), vShape
[0]));
541 if (!vType
.getScalableDims()[0])
542 return baseVecLength
;
544 // For a scalable vector type, create and return `vScale * baseVecLength`.
545 Value vScale
= rewriter
.create
<vector::VectorScaleOp
>(loc
);
547 rewriter
.create
<arith::IndexCastOp
>(loc
, rewriter
.getI32Type(), vScale
);
548 Value scalableVecLength
=
549 rewriter
.create
<arith::MulIOp
>(loc
, baseVecLength
, vScale
);
550 return scalableVecLength
;
553 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
554 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
555 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
557 template <class LLVMRedIntrinOp
, class ScalarOp
>
558 static Value
createIntegerReductionArithmeticOpLowering(
559 ConversionPatternRewriter
&rewriter
, Location loc
, Type llvmType
,
560 Value vectorOperand
, Value accumulator
) {
562 Value result
= rewriter
.create
<LLVMRedIntrinOp
>(loc
, llvmType
, vectorOperand
);
565 result
= rewriter
.create
<ScalarOp
>(loc
, accumulator
, result
);
569 /// Helper method to lower a `vector.reduction` operation that performs
570 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
571 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
572 /// the accumulator value if non-null.
573 template <class LLVMRedIntrinOp
>
574 static Value
createIntegerReductionComparisonOpLowering(
575 ConversionPatternRewriter
&rewriter
, Location loc
, Type llvmType
,
576 Value vectorOperand
, Value accumulator
, LLVM::ICmpPredicate predicate
) {
577 Value result
= rewriter
.create
<LLVMRedIntrinOp
>(loc
, llvmType
, vectorOperand
);
580 rewriter
.create
<LLVM::ICmpOp
>(loc
, predicate
, accumulator
, result
);
581 result
= rewriter
.create
<LLVM::SelectOp
>(loc
, cmp
, accumulator
, result
);
587 template <typename Source
>
588 struct VectorToScalarMapper
;
590 struct VectorToScalarMapper
<LLVM::vector_reduce_fmaximum
> {
591 using Type
= LLVM::MaximumOp
;
594 struct VectorToScalarMapper
<LLVM::vector_reduce_fminimum
> {
595 using Type
= LLVM::MinimumOp
;
598 struct VectorToScalarMapper
<LLVM::vector_reduce_fmax
> {
599 using Type
= LLVM::MaxNumOp
;
602 struct VectorToScalarMapper
<LLVM::vector_reduce_fmin
> {
603 using Type
= LLVM::MinNumOp
;
607 template <class LLVMRedIntrinOp
>
608 static Value
createFPReductionComparisonOpLowering(
609 ConversionPatternRewriter
&rewriter
, Location loc
, Type llvmType
,
610 Value vectorOperand
, Value accumulator
, LLVM::FastmathFlagsAttr fmf
) {
612 rewriter
.create
<LLVMRedIntrinOp
>(loc
, llvmType
, vectorOperand
, fmf
);
616 rewriter
.create
<typename VectorToScalarMapper
<LLVMRedIntrinOp
>::Type
>(
617 loc
, result
, accumulator
);
/// Reduction neutral classes for overloading
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
627 /// Get the mask neutral floating point maximum value
629 getMaskNeutralValue(MaskNeutralFMaximum
,
630 const llvm::fltSemantics
&floatSemantics
) {
631 return llvm::APFloat::getSmallest(floatSemantics
, /*Negative=*/true);
633 /// Get the mask neutral floating point minimum value
635 getMaskNeutralValue(MaskNeutralFMinimum
,
636 const llvm::fltSemantics
&floatSemantics
) {
637 return llvm::APFloat::getLargest(floatSemantics
, /*Negative=*/false);
640 /// Create the mask neutral floating point MLIR vector constant
641 template <typename MaskNeutral
>
642 static Value
createMaskNeutralValue(ConversionPatternRewriter
&rewriter
,
643 Location loc
, Type llvmType
,
645 const auto &floatSemantics
= cast
<FloatType
>(llvmType
).getFloatSemantics();
646 auto value
= getMaskNeutralValue(MaskNeutral
{}, floatSemantics
);
647 auto denseValue
= DenseElementsAttr::get(cast
<ShapedType
>(vectorType
), value
);
648 return rewriter
.create
<LLVM::ConstantOp
>(loc
, vectorType
, denseValue
);
651 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
652 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
653 /// `fmaximum`/`fminimum`.
654 /// More information: https://github.com/llvm/llvm-project/issues/64940
655 template <class LLVMRedIntrinOp
, class MaskNeutral
>
657 lowerMaskedReductionWithRegular(ConversionPatternRewriter
&rewriter
,
658 Location loc
, Type llvmType
,
659 Value vectorOperand
, Value accumulator
,
660 Value mask
, LLVM::FastmathFlagsAttr fmf
) {
661 const Value vectorMaskNeutral
= createMaskNeutralValue
<MaskNeutral
>(
662 rewriter
, loc
, llvmType
, vectorOperand
.getType());
663 const Value selectedVectorByMask
= rewriter
.create
<LLVM::SelectOp
>(
664 loc
, mask
, vectorOperand
, vectorMaskNeutral
);
665 return createFPReductionComparisonOpLowering
<LLVMRedIntrinOp
>(
666 rewriter
, loc
, llvmType
, selectedVectorByMask
, accumulator
, fmf
);
669 template <class LLVMRedIntrinOp
, class ReductionNeutral
>
671 lowerReductionWithStartValue(ConversionPatternRewriter
&rewriter
, Location loc
,
672 Type llvmType
, Value vectorOperand
,
673 Value accumulator
, LLVM::FastmathFlagsAttr fmf
) {
674 accumulator
= getOrCreateAccumulator
<ReductionNeutral
>(rewriter
, loc
,
675 llvmType
, accumulator
);
676 return rewriter
.create
<LLVMRedIntrinOp
>(loc
, llvmType
,
677 /*startValue=*/accumulator
,
681 /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
682 /// that requires a start value. This start value format spans across fp
683 /// reductions without mask and all the masked reduction intrinsics.
684 template <class LLVMVPRedIntrinOp
, class ReductionNeutral
>
686 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter
&rewriter
,
687 Location loc
, Type llvmType
,
688 Value vectorOperand
, Value accumulator
) {
689 accumulator
= getOrCreateAccumulator
<ReductionNeutral
>(rewriter
, loc
,
690 llvmType
, accumulator
);
691 return rewriter
.create
<LLVMVPRedIntrinOp
>(loc
, llvmType
,
692 /*startValue=*/accumulator
,
696 template <class LLVMVPRedIntrinOp
, class ReductionNeutral
>
697 static Value
lowerPredicatedReductionWithStartValue(
698 ConversionPatternRewriter
&rewriter
, Location loc
, Type llvmType
,
699 Value vectorOperand
, Value accumulator
, Value mask
) {
700 accumulator
= getOrCreateAccumulator
<ReductionNeutral
>(rewriter
, loc
,
701 llvmType
, accumulator
);
703 createVectorLengthValue(rewriter
, loc
, vectorOperand
.getType());
704 return rewriter
.create
<LLVMVPRedIntrinOp
>(loc
, llvmType
,
705 /*startValue=*/accumulator
,
706 vectorOperand
, mask
, vectorLength
);
709 template <class LLVMIntVPRedIntrinOp
, class IntReductionNeutral
,
710 class LLVMFPVPRedIntrinOp
, class FPReductionNeutral
>
711 static Value
lowerPredicatedReductionWithStartValue(
712 ConversionPatternRewriter
&rewriter
, Location loc
, Type llvmType
,
713 Value vectorOperand
, Value accumulator
, Value mask
) {
714 if (llvmType
.isIntOrIndex())
715 return lowerPredicatedReductionWithStartValue
<LLVMIntVPRedIntrinOp
,
716 IntReductionNeutral
>(
717 rewriter
, loc
, llvmType
, vectorOperand
, accumulator
, mask
);
720 return lowerPredicatedReductionWithStartValue
<LLVMFPVPRedIntrinOp
,
722 rewriter
, loc
, llvmType
, vectorOperand
, accumulator
, mask
);
725 /// Conversion pattern for all vector reductions.
726 class VectorReductionOpConversion
727 : public ConvertOpToLLVMPattern
<vector::ReductionOp
> {
729 explicit VectorReductionOpConversion(const LLVMTypeConverter
&typeConv
,
730 bool reassociateFPRed
)
731 : ConvertOpToLLVMPattern
<vector::ReductionOp
>(typeConv
),
732 reassociateFPReductions(reassociateFPRed
) {}
735 matchAndRewrite(vector::ReductionOp reductionOp
, OpAdaptor adaptor
,
736 ConversionPatternRewriter
&rewriter
) const override
{
737 auto kind
= reductionOp
.getKind();
738 Type eltType
= reductionOp
.getDest().getType();
739 Type llvmType
= typeConverter
->convertType(eltType
);
740 Value operand
= adaptor
.getVector();
741 Value acc
= adaptor
.getAcc();
742 Location loc
= reductionOp
.getLoc();
744 if (eltType
.isIntOrIndex()) {
745 // Integer reductions: add/mul/min/max/and/or/xor.
748 case vector::CombiningKind::ADD
:
750 createIntegerReductionArithmeticOpLowering
<LLVM::vector_reduce_add
,
752 rewriter
, loc
, llvmType
, operand
, acc
);
754 case vector::CombiningKind::MUL
:
756 createIntegerReductionArithmeticOpLowering
<LLVM::vector_reduce_mul
,
758 rewriter
, loc
, llvmType
, operand
, acc
);
760 case vector::CombiningKind::MINUI
:
761 result
= createIntegerReductionComparisonOpLowering
<
762 LLVM::vector_reduce_umin
>(rewriter
, loc
, llvmType
, operand
, acc
,
763 LLVM::ICmpPredicate::ule
);
765 case vector::CombiningKind::MINSI
:
766 result
= createIntegerReductionComparisonOpLowering
<
767 LLVM::vector_reduce_smin
>(rewriter
, loc
, llvmType
, operand
, acc
,
768 LLVM::ICmpPredicate::sle
);
770 case vector::CombiningKind::MAXUI
:
771 result
= createIntegerReductionComparisonOpLowering
<
772 LLVM::vector_reduce_umax
>(rewriter
, loc
, llvmType
, operand
, acc
,
773 LLVM::ICmpPredicate::uge
);
775 case vector::CombiningKind::MAXSI
:
776 result
= createIntegerReductionComparisonOpLowering
<
777 LLVM::vector_reduce_smax
>(rewriter
, loc
, llvmType
, operand
, acc
,
778 LLVM::ICmpPredicate::sge
);
780 case vector::CombiningKind::AND
:
782 createIntegerReductionArithmeticOpLowering
<LLVM::vector_reduce_and
,
784 rewriter
, loc
, llvmType
, operand
, acc
);
786 case vector::CombiningKind::OR
:
788 createIntegerReductionArithmeticOpLowering
<LLVM::vector_reduce_or
,
790 rewriter
, loc
, llvmType
, operand
, acc
);
792 case vector::CombiningKind::XOR
:
794 createIntegerReductionArithmeticOpLowering
<LLVM::vector_reduce_xor
,
796 rewriter
, loc
, llvmType
, operand
, acc
);
801 rewriter
.replaceOp(reductionOp
, result
);
806 if (!isa
<FloatType
>(eltType
))
809 arith::FastMathFlagsAttr fMFAttr
= reductionOp
.getFastMathFlagsAttr();
810 LLVM::FastmathFlagsAttr fmf
= LLVM::FastmathFlagsAttr::get(
811 reductionOp
.getContext(),
812 convertArithFastMathFlagsToLLVM(fMFAttr
.getValue()));
813 fmf
= LLVM::FastmathFlagsAttr::get(
814 reductionOp
.getContext(),
815 fmf
.getValue() | (reassociateFPReductions
? LLVM::FastmathFlags::reassoc
816 : LLVM::FastmathFlags::none
));
818 // Floating-point reductions: add/mul/min/max
820 if (kind
== vector::CombiningKind::ADD
) {
821 result
= lowerReductionWithStartValue
<LLVM::vector_reduce_fadd
,
822 ReductionNeutralZero
>(
823 rewriter
, loc
, llvmType
, operand
, acc
, fmf
);
824 } else if (kind
== vector::CombiningKind::MUL
) {
825 result
= lowerReductionWithStartValue
<LLVM::vector_reduce_fmul
,
826 ReductionNeutralFPOne
>(
827 rewriter
, loc
, llvmType
, operand
, acc
, fmf
);
828 } else if (kind
== vector::CombiningKind::MINIMUMF
) {
830 createFPReductionComparisonOpLowering
<LLVM::vector_reduce_fminimum
>(
831 rewriter
, loc
, llvmType
, operand
, acc
, fmf
);
832 } else if (kind
== vector::CombiningKind::MAXIMUMF
) {
834 createFPReductionComparisonOpLowering
<LLVM::vector_reduce_fmaximum
>(
835 rewriter
, loc
, llvmType
, operand
, acc
, fmf
);
836 } else if (kind
== vector::CombiningKind::MINNUMF
) {
837 result
= createFPReductionComparisonOpLowering
<LLVM::vector_reduce_fmin
>(
838 rewriter
, loc
, llvmType
, operand
, acc
, fmf
);
839 } else if (kind
== vector::CombiningKind::MAXNUMF
) {
840 result
= createFPReductionComparisonOpLowering
<LLVM::vector_reduce_fmax
>(
841 rewriter
, loc
, llvmType
, operand
, acc
, fmf
);
845 rewriter
.replaceOp(reductionOp
, result
);
850 const bool reassociateFPReductions
;
853 /// Base class to convert a `vector.mask` operation while matching traits
854 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
855 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
856 /// method performs a second match against the maskable operation `MaskedOp`.
857 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
858 /// implemented by the concrete conversion classes. This method can match
859 /// against specific traits of the `vector.mask` and the maskable operation. It
860 /// must replace the `vector.mask` operation.
861 template <class MaskedOp
>
862 class VectorMaskOpConversionBase
863 : public ConvertOpToLLVMPattern
<vector::MaskOp
> {
865 using ConvertOpToLLVMPattern
<vector::MaskOp
>::ConvertOpToLLVMPattern
;
868 matchAndRewrite(vector::MaskOp maskOp
, OpAdaptor adaptor
,
869 ConversionPatternRewriter
&rewriter
) const final
{
870 // Match against the maskable operation kind.
871 auto maskedOp
= llvm::dyn_cast_or_null
<MaskedOp
>(maskOp
.getMaskableOp());
874 return matchAndRewriteMaskableOp(maskOp
, maskedOp
, rewriter
);
878 virtual LogicalResult
879 matchAndRewriteMaskableOp(vector::MaskOp maskOp
,
880 vector::MaskableOpInterface maskableOp
,
881 ConversionPatternRewriter
&rewriter
) const = 0;
884 class MaskedReductionOpConversion
885 : public VectorMaskOpConversionBase
<vector::ReductionOp
> {
888 using VectorMaskOpConversionBase
<
889 vector::ReductionOp
>::VectorMaskOpConversionBase
;
891 LogicalResult
matchAndRewriteMaskableOp(
892 vector::MaskOp maskOp
, MaskableOpInterface maskableOp
,
893 ConversionPatternRewriter
&rewriter
) const override
{
894 auto reductionOp
= cast
<ReductionOp
>(maskableOp
.getOperation());
895 auto kind
= reductionOp
.getKind();
896 Type eltType
= reductionOp
.getDest().getType();
897 Type llvmType
= typeConverter
->convertType(eltType
);
898 Value operand
= reductionOp
.getVector();
899 Value acc
= reductionOp
.getAcc();
900 Location loc
= reductionOp
.getLoc();
902 arith::FastMathFlagsAttr fMFAttr
= reductionOp
.getFastMathFlagsAttr();
903 LLVM::FastmathFlagsAttr fmf
= LLVM::FastmathFlagsAttr::get(
904 reductionOp
.getContext(),
905 convertArithFastMathFlagsToLLVM(fMFAttr
.getValue()));
909 case vector::CombiningKind::ADD
:
910 result
= lowerPredicatedReductionWithStartValue
<
911 LLVM::VPReduceAddOp
, ReductionNeutralZero
, LLVM::VPReduceFAddOp
,
912 ReductionNeutralZero
>(rewriter
, loc
, llvmType
, operand
, acc
,
915 case vector::CombiningKind::MUL
:
916 result
= lowerPredicatedReductionWithStartValue
<
917 LLVM::VPReduceMulOp
, ReductionNeutralIntOne
, LLVM::VPReduceFMulOp
,
918 ReductionNeutralFPOne
>(rewriter
, loc
, llvmType
, operand
, acc
,
921 case vector::CombiningKind::MINUI
:
922 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceUMinOp
,
923 ReductionNeutralUIntMax
>(
924 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
926 case vector::CombiningKind::MINSI
:
927 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceSMinOp
,
928 ReductionNeutralSIntMax
>(
929 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
931 case vector::CombiningKind::MAXUI
:
932 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceUMaxOp
,
933 ReductionNeutralUIntMin
>(
934 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
936 case vector::CombiningKind::MAXSI
:
937 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceSMaxOp
,
938 ReductionNeutralSIntMin
>(
939 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
941 case vector::CombiningKind::AND
:
942 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceAndOp
,
943 ReductionNeutralAllOnes
>(
944 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
946 case vector::CombiningKind::OR
:
947 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceOrOp
,
948 ReductionNeutralZero
>(
949 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
951 case vector::CombiningKind::XOR
:
952 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceXorOp
,
953 ReductionNeutralZero
>(
954 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
956 case vector::CombiningKind::MINNUMF
:
957 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceFMinOp
,
958 ReductionNeutralFPMax
>(
959 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
961 case vector::CombiningKind::MAXNUMF
:
962 result
= lowerPredicatedReductionWithStartValue
<LLVM::VPReduceFMaxOp
,
963 ReductionNeutralFPMin
>(
964 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask());
966 case CombiningKind::MAXIMUMF
:
967 result
= lowerMaskedReductionWithRegular
<LLVM::vector_reduce_fmaximum
,
968 MaskNeutralFMaximum
>(
969 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask(), fmf
);
971 case CombiningKind::MINIMUMF
:
972 result
= lowerMaskedReductionWithRegular
<LLVM::vector_reduce_fminimum
,
973 MaskNeutralFMinimum
>(
974 rewriter
, loc
, llvmType
, operand
, acc
, maskOp
.getMask(), fmf
);
978 // Replace `vector.mask` operation altogether.
979 rewriter
.replaceOp(maskOp
, result
);
984 class VectorShuffleOpConversion
985 : public ConvertOpToLLVMPattern
<vector::ShuffleOp
> {
987 using ConvertOpToLLVMPattern
<vector::ShuffleOp
>::ConvertOpToLLVMPattern
;
990 matchAndRewrite(vector::ShuffleOp shuffleOp
, OpAdaptor adaptor
,
991 ConversionPatternRewriter
&rewriter
) const override
{
992 auto loc
= shuffleOp
->getLoc();
993 auto v1Type
= shuffleOp
.getV1VectorType();
994 auto v2Type
= shuffleOp
.getV2VectorType();
995 auto vectorType
= shuffleOp
.getResultVectorType();
996 Type llvmType
= typeConverter
->convertType(vectorType
);
997 ArrayRef
<int64_t> mask
= shuffleOp
.getMask();
999 // Bail if result type cannot be lowered.
1003 // Get rank and dimension sizes.
1004 int64_t rank
= vectorType
.getRank();
1006 bool wellFormed0DCase
=
1007 v1Type
.getRank() == 0 && v2Type
.getRank() == 0 && rank
== 1;
1008 bool wellFormedNDCase
=
1009 v1Type
.getRank() == rank
&& v2Type
.getRank() == rank
;
1010 assert((wellFormed0DCase
|| wellFormedNDCase
) && "op is not well-formed");
1013 // For rank 0 and 1, where both operands have *exactly* the same vector
1014 // type, there is direct shuffle support in LLVM. Use it!
1015 if (rank
<= 1 && v1Type
== v2Type
) {
1016 Value llvmShuffleOp
= rewriter
.create
<LLVM::ShuffleVectorOp
>(
1017 loc
, adaptor
.getV1(), adaptor
.getV2(),
1018 llvm::to_vector_of
<int32_t>(mask
));
1019 rewriter
.replaceOp(shuffleOp
, llvmShuffleOp
);
1023 // For all other cases, insert the individual values individually.
1024 int64_t v1Dim
= v1Type
.getDimSize(0);
1026 if (auto arrayType
= dyn_cast
<LLVM::LLVMArrayType
>(llvmType
))
1027 eltType
= arrayType
.getElementType();
1029 eltType
= cast
<VectorType
>(llvmType
).getElementType();
1030 Value insert
= rewriter
.create
<LLVM::UndefOp
>(loc
, llvmType
);
1032 for (int64_t extPos
: mask
) {
1033 Value value
= adaptor
.getV1();
1034 if (extPos
>= v1Dim
) {
1036 value
= adaptor
.getV2();
1038 Value extract
= extractOne(rewriter
, *getTypeConverter(), loc
, value
,
1039 eltType
, rank
, extPos
);
1040 insert
= insertOne(rewriter
, *getTypeConverter(), loc
, insert
, extract
,
1041 llvmType
, rank
, insPos
++);
1043 rewriter
.replaceOp(shuffleOp
, insert
);
1048 class VectorExtractElementOpConversion
1049 : public ConvertOpToLLVMPattern
<vector::ExtractElementOp
> {
1051 using ConvertOpToLLVMPattern
<
1052 vector::ExtractElementOp
>::ConvertOpToLLVMPattern
;
1055 matchAndRewrite(vector::ExtractElementOp extractEltOp
, OpAdaptor adaptor
,
1056 ConversionPatternRewriter
&rewriter
) const override
{
1057 auto vectorType
= extractEltOp
.getSourceVectorType();
1058 auto llvmType
= typeConverter
->convertType(vectorType
.getElementType());
1060 // Bail if result type cannot be lowered.
1064 if (vectorType
.getRank() == 0) {
1065 Location loc
= extractEltOp
.getLoc();
1066 auto idxType
= rewriter
.getIndexType();
1067 auto zero
= rewriter
.create
<LLVM::ConstantOp
>(
1068 loc
, typeConverter
->convertType(idxType
),
1069 rewriter
.getIntegerAttr(idxType
, 0));
1070 rewriter
.replaceOpWithNewOp
<LLVM::ExtractElementOp
>(
1071 extractEltOp
, llvmType
, adaptor
.getVector(), zero
);
1075 rewriter
.replaceOpWithNewOp
<LLVM::ExtractElementOp
>(
1076 extractEltOp
, llvmType
, adaptor
.getVector(), adaptor
.getPosition());
1081 class VectorExtractOpConversion
1082 : public ConvertOpToLLVMPattern
<vector::ExtractOp
> {
1084 using ConvertOpToLLVMPattern
<vector::ExtractOp
>::ConvertOpToLLVMPattern
;
1087 matchAndRewrite(vector::ExtractOp extractOp
, OpAdaptor adaptor
,
1088 ConversionPatternRewriter
&rewriter
) const override
{
1089 auto loc
= extractOp
->getLoc();
1090 auto resultType
= extractOp
.getResult().getType();
1091 auto llvmResultType
= typeConverter
->convertType(resultType
);
1092 // Bail if result type cannot be lowered.
1093 if (!llvmResultType
)
1096 SmallVector
<OpFoldResult
> positionVec
= getMixedValues(
1097 adaptor
.getStaticPosition(), adaptor
.getDynamicPosition(), rewriter
);
1099 // Extract entire vector. Should be handled by folder, but just to be safe.
1100 ArrayRef
<OpFoldResult
> position(positionVec
);
1101 if (position
.empty()) {
1102 rewriter
.replaceOp(extractOp
, adaptor
.getVector());
1106 // One-shot extraction of vector from array (only requires extractvalue).
1107 // Except for extracting 1-element vectors.
1108 if (isa
<VectorType
>(resultType
) &&
1110 static_cast<size_t>(extractOp
.getSourceVectorType().getRank())) {
1111 if (extractOp
.hasDynamicPosition())
1114 Value extracted
= rewriter
.create
<LLVM::ExtractValueOp
>(
1115 loc
, adaptor
.getVector(), getAsIntegers(position
));
1116 rewriter
.replaceOp(extractOp
, extracted
);
1120 // Potential extraction of 1-D vector from array.
1121 Value extracted
= adaptor
.getVector();
1122 if (position
.size() > 1) {
1123 if (extractOp
.hasDynamicPosition())
1126 SmallVector
<int64_t> nMinusOnePosition
=
1127 getAsIntegers(position
.drop_back());
1128 extracted
= rewriter
.create
<LLVM::ExtractValueOp
>(loc
, extracted
,
1132 Value lastPosition
= getAsLLVMValue(rewriter
, loc
, position
.back());
1133 // Remaining extraction of element from 1-D LLVM vector.
1134 rewriter
.replaceOpWithNewOp
<LLVM::ExtractElementOp
>(extractOp
, extracted
,
1140 /// Conversion pattern that turns a vector.fma on a 1-D vector
1141 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1142 /// This does not match vectors of n >= 2 rank.
1146 /// vector.fma %a, %a, %a : vector<8xf32>
1148 /// is converted to:
1150 /// llvm.intr.fmuladd %va, %va, %va:
1151 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1152 /// -> !llvm."<8 x f32>">
1154 class VectorFMAOp1DConversion
: public ConvertOpToLLVMPattern
<vector::FMAOp
> {
1156 using ConvertOpToLLVMPattern
<vector::FMAOp
>::ConvertOpToLLVMPattern
;
1159 matchAndRewrite(vector::FMAOp fmaOp
, OpAdaptor adaptor
,
1160 ConversionPatternRewriter
&rewriter
) const override
{
1161 VectorType vType
= fmaOp
.getVectorType();
1162 if (vType
.getRank() > 1)
1165 rewriter
.replaceOpWithNewOp
<LLVM::FMulAddOp
>(
1166 fmaOp
, adaptor
.getLhs(), adaptor
.getRhs(), adaptor
.getAcc());
1171 class VectorInsertElementOpConversion
1172 : public ConvertOpToLLVMPattern
<vector::InsertElementOp
> {
1174 using ConvertOpToLLVMPattern
<vector::InsertElementOp
>::ConvertOpToLLVMPattern
;
1177 matchAndRewrite(vector::InsertElementOp insertEltOp
, OpAdaptor adaptor
,
1178 ConversionPatternRewriter
&rewriter
) const override
{
1179 auto vectorType
= insertEltOp
.getDestVectorType();
1180 auto llvmType
= typeConverter
->convertType(vectorType
);
1182 // Bail if result type cannot be lowered.
1186 if (vectorType
.getRank() == 0) {
1187 Location loc
= insertEltOp
.getLoc();
1188 auto idxType
= rewriter
.getIndexType();
1189 auto zero
= rewriter
.create
<LLVM::ConstantOp
>(
1190 loc
, typeConverter
->convertType(idxType
),
1191 rewriter
.getIntegerAttr(idxType
, 0));
1192 rewriter
.replaceOpWithNewOp
<LLVM::InsertElementOp
>(
1193 insertEltOp
, llvmType
, adaptor
.getDest(), adaptor
.getSource(), zero
);
1197 rewriter
.replaceOpWithNewOp
<LLVM::InsertElementOp
>(
1198 insertEltOp
, llvmType
, adaptor
.getDest(), adaptor
.getSource(),
1199 adaptor
.getPosition());
1204 class VectorInsertOpConversion
1205 : public ConvertOpToLLVMPattern
<vector::InsertOp
> {
1207 using ConvertOpToLLVMPattern
<vector::InsertOp
>::ConvertOpToLLVMPattern
;
1210 matchAndRewrite(vector::InsertOp insertOp
, OpAdaptor adaptor
,
1211 ConversionPatternRewriter
&rewriter
) const override
{
1212 auto loc
= insertOp
->getLoc();
1213 auto sourceType
= insertOp
.getSourceType();
1214 auto destVectorType
= insertOp
.getDestVectorType();
1215 auto llvmResultType
= typeConverter
->convertType(destVectorType
);
1216 // Bail if result type cannot be lowered.
1217 if (!llvmResultType
)
1220 SmallVector
<OpFoldResult
> positionVec
= getMixedValues(
1221 adaptor
.getStaticPosition(), adaptor
.getDynamicPosition(), rewriter
);
1223 // Overwrite entire vector with value. Should be handled by folder, but
1225 ArrayRef
<OpFoldResult
> position(positionVec
);
1226 if (position
.empty()) {
1227 rewriter
.replaceOp(insertOp
, adaptor
.getSource());
1231 // One-shot insertion of a vector into an array (only requires insertvalue).
1232 if (isa
<VectorType
>(sourceType
)) {
1233 if (insertOp
.hasDynamicPosition())
1236 Value inserted
= rewriter
.create
<LLVM::InsertValueOp
>(
1237 loc
, adaptor
.getDest(), adaptor
.getSource(), getAsIntegers(position
));
1238 rewriter
.replaceOp(insertOp
, inserted
);
1242 // Potential extraction of 1-D vector from array.
1243 Value extracted
= adaptor
.getDest();
1244 auto oneDVectorType
= destVectorType
;
1245 if (position
.size() > 1) {
1246 if (insertOp
.hasDynamicPosition())
1249 oneDVectorType
= reducedVectorTypeBack(destVectorType
);
1250 extracted
= rewriter
.create
<LLVM::ExtractValueOp
>(
1251 loc
, extracted
, getAsIntegers(position
.drop_back()));
1254 // Insertion of an element into a 1-D LLVM vector.
1255 Value inserted
= rewriter
.create
<LLVM::InsertElementOp
>(
1256 loc
, typeConverter
->convertType(oneDVectorType
), extracted
,
1257 adaptor
.getSource(), getAsLLVMValue(rewriter
, loc
, position
.back()));
1259 // Potential insertion of resulting 1-D vector into array.
1260 if (position
.size() > 1) {
1261 if (insertOp
.hasDynamicPosition())
1264 inserted
= rewriter
.create
<LLVM::InsertValueOp
>(
1265 loc
, adaptor
.getDest(), inserted
,
1266 getAsIntegers(position
.drop_back()));
1269 rewriter
.replaceOp(insertOp
, inserted
);
1274 /// Lower vector.scalable.insert ops to LLVM vector.insert
1275 struct VectorScalableInsertOpLowering
1276 : public ConvertOpToLLVMPattern
<vector::ScalableInsertOp
> {
1277 using ConvertOpToLLVMPattern
<
1278 vector::ScalableInsertOp
>::ConvertOpToLLVMPattern
;
1281 matchAndRewrite(vector::ScalableInsertOp insOp
, OpAdaptor adaptor
,
1282 ConversionPatternRewriter
&rewriter
) const override
{
1283 rewriter
.replaceOpWithNewOp
<LLVM::vector_insert
>(
1284 insOp
, adaptor
.getDest(), adaptor
.getSource(), adaptor
.getPos());
1289 /// Lower vector.scalable.extract ops to LLVM vector.extract
1290 struct VectorScalableExtractOpLowering
1291 : public ConvertOpToLLVMPattern
<vector::ScalableExtractOp
> {
1292 using ConvertOpToLLVMPattern
<
1293 vector::ScalableExtractOp
>::ConvertOpToLLVMPattern
;
1296 matchAndRewrite(vector::ScalableExtractOp extOp
, OpAdaptor adaptor
,
1297 ConversionPatternRewriter
&rewriter
) const override
{
1298 rewriter
.replaceOpWithNewOp
<LLVM::vector_extract
>(
1299 extOp
, typeConverter
->convertType(extOp
.getResultVectorType()),
1300 adaptor
.getSource(), adaptor
.getPos());
1305 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1309 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1311 /// is rewritten into:
1313 /// %r = splat %f0: vector<2x4xf32>
1314 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1315 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1316 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1317 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1318 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1319 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1320 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1321 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1322 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1323 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1324 /// // %r3 holds the final value.
1326 class VectorFMAOpNDRewritePattern
: public OpRewritePattern
<FMAOp
> {
1328 using OpRewritePattern
<FMAOp
>::OpRewritePattern
;
1331 // This pattern recursively unpacks one dimension at a time. The recursion
1332 // bounded as the rank is strictly decreasing.
1333 setHasBoundedRewriteRecursion();
1336 LogicalResult
matchAndRewrite(FMAOp op
,
1337 PatternRewriter
&rewriter
) const override
{
1338 auto vType
= op
.getVectorType();
1339 if (vType
.getRank() < 2)
1342 auto loc
= op
.getLoc();
1343 auto elemType
= vType
.getElementType();
1344 Value zero
= rewriter
.create
<arith::ConstantOp
>(
1345 loc
, elemType
, rewriter
.getZeroAttr(elemType
));
1346 Value desc
= rewriter
.create
<vector::SplatOp
>(loc
, vType
, zero
);
1347 for (int64_t i
= 0, e
= vType
.getShape().front(); i
!= e
; ++i
) {
1348 Value extrLHS
= rewriter
.create
<ExtractOp
>(loc
, op
.getLhs(), i
);
1349 Value extrRHS
= rewriter
.create
<ExtractOp
>(loc
, op
.getRhs(), i
);
1350 Value extrACC
= rewriter
.create
<ExtractOp
>(loc
, op
.getAcc(), i
);
1351 Value fma
= rewriter
.create
<FMAOp
>(loc
, extrLHS
, extrRHS
, extrACC
);
1352 desc
= rewriter
.create
<InsertOp
>(loc
, fma
, desc
, i
);
1354 rewriter
.replaceOp(op
, desc
);
1359 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1361 static std::optional
<SmallVector
<int64_t, 4>>
1362 computeContiguousStrides(MemRefType memRefType
) {
1364 SmallVector
<int64_t, 4> strides
;
1365 if (failed(getStridesAndOffset(memRefType
, strides
, offset
)))
1366 return std::nullopt
;
1367 if (!strides
.empty() && strides
.back() != 1)
1368 return std::nullopt
;
1369 // If no layout or identity layout, this is contiguous by definition.
1370 if (memRefType
.getLayout().isIdentity())
1373 // Otherwise, we must determine contiguity form shapes. This can only ever
1374 // work in static cases because MemRefType is underspecified to represent
1375 // contiguous dynamic shapes in other ways than with just empty/identity
1377 auto sizes
= memRefType
.getShape();
1378 for (int index
= 0, e
= strides
.size() - 1; index
< e
; ++index
) {
1379 if (ShapedType::isDynamic(sizes
[index
+ 1]) ||
1380 ShapedType::isDynamic(strides
[index
]) ||
1381 ShapedType::isDynamic(strides
[index
+ 1]))
1382 return std::nullopt
;
1383 if (strides
[index
] != strides
[index
+ 1] * sizes
[index
+ 1])
1384 return std::nullopt
;
1389 class VectorTypeCastOpConversion
1390 : public ConvertOpToLLVMPattern
<vector::TypeCastOp
> {
1392 using ConvertOpToLLVMPattern
<vector::TypeCastOp
>::ConvertOpToLLVMPattern
;
1395 matchAndRewrite(vector::TypeCastOp castOp
, OpAdaptor adaptor
,
1396 ConversionPatternRewriter
&rewriter
) const override
{
1397 auto loc
= castOp
->getLoc();
1398 MemRefType sourceMemRefType
=
1399 cast
<MemRefType
>(castOp
.getOperand().getType());
1400 MemRefType targetMemRefType
= castOp
.getType();
1402 // Only static shape casts supported atm.
1403 if (!sourceMemRefType
.hasStaticShape() ||
1404 !targetMemRefType
.hasStaticShape())
1407 auto llvmSourceDescriptorTy
=
1408 dyn_cast
<LLVM::LLVMStructType
>(adaptor
.getOperands()[0].getType());
1409 if (!llvmSourceDescriptorTy
)
1411 MemRefDescriptor
sourceMemRef(adaptor
.getOperands()[0]);
1413 auto llvmTargetDescriptorTy
= dyn_cast_or_null
<LLVM::LLVMStructType
>(
1414 typeConverter
->convertType(targetMemRefType
));
1415 if (!llvmTargetDescriptorTy
)
1418 // Only contiguous source buffers supported atm.
1419 auto sourceStrides
= computeContiguousStrides(sourceMemRefType
);
1422 auto targetStrides
= computeContiguousStrides(targetMemRefType
);
1425 // Only support static strides for now, regardless of contiguity.
1426 if (llvm::any_of(*targetStrides
, ShapedType::isDynamic
))
1429 auto int64Ty
= IntegerType::get(rewriter
.getContext(), 64);
1431 // Create descriptor.
1432 auto desc
= MemRefDescriptor::undef(rewriter
, loc
, llvmTargetDescriptorTy
);
1433 // Set allocated ptr.
1434 Value allocated
= sourceMemRef
.allocatedPtr(rewriter
, loc
);
1435 desc
.setAllocatedPtr(rewriter
, loc
, allocated
);
1438 Value ptr
= sourceMemRef
.alignedPtr(rewriter
, loc
);
1439 desc
.setAlignedPtr(rewriter
, loc
, ptr
);
1441 auto attr
= rewriter
.getIntegerAttr(rewriter
.getIndexType(), 0);
1442 auto zero
= rewriter
.create
<LLVM::ConstantOp
>(loc
, int64Ty
, attr
);
1443 desc
.setOffset(rewriter
, loc
, zero
);
1445 // Fill size and stride descriptors in memref.
1446 for (const auto &indexedSize
:
1447 llvm::enumerate(targetMemRefType
.getShape())) {
1448 int64_t index
= indexedSize
.index();
1450 rewriter
.getIntegerAttr(rewriter
.getIndexType(), indexedSize
.value());
1451 auto size
= rewriter
.create
<LLVM::ConstantOp
>(loc
, int64Ty
, sizeAttr
);
1452 desc
.setSize(rewriter
, loc
, index
, size
);
1453 auto strideAttr
= rewriter
.getIntegerAttr(rewriter
.getIndexType(),
1454 (*targetStrides
)[index
]);
1455 auto stride
= rewriter
.create
<LLVM::ConstantOp
>(loc
, int64Ty
, strideAttr
);
1456 desc
.setStride(rewriter
, loc
, index
, stride
);
1459 rewriter
.replaceOp(castOp
, {desc
});
1464 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1465 /// Non-scalable versions of this operation are handled in Vector Transforms.
1466 class VectorCreateMaskOpRewritePattern
1467 : public OpRewritePattern
<vector::CreateMaskOp
> {
1469 explicit VectorCreateMaskOpRewritePattern(MLIRContext
*context
,
1470 bool enableIndexOpt
)
1471 : OpRewritePattern
<vector::CreateMaskOp
>(context
),
1472 force32BitVectorIndices(enableIndexOpt
) {}
1474 LogicalResult
matchAndRewrite(vector::CreateMaskOp op
,
1475 PatternRewriter
&rewriter
) const override
{
1476 auto dstType
= op
.getType();
1477 if (dstType
.getRank() != 1 || !cast
<VectorType
>(dstType
).isScalable())
1479 IntegerType idxType
=
1480 force32BitVectorIndices
? rewriter
.getI32Type() : rewriter
.getI64Type();
1481 auto loc
= op
->getLoc();
1482 Value indices
= rewriter
.create
<LLVM::StepVectorOp
>(
1483 loc
, LLVM::getVectorType(idxType
, dstType
.getShape()[0],
1484 /*isScalable=*/true));
1485 auto bound
= getValueOrCreateCastToIndexLike(rewriter
, loc
, idxType
,
1487 Value bounds
= rewriter
.create
<SplatOp
>(loc
, indices
.getType(), bound
);
1488 Value comp
= rewriter
.create
<arith::CmpIOp
>(loc
, arith::CmpIPredicate::slt
,
1490 rewriter
.replaceOp(op
, comp
);
1495 const bool force32BitVectorIndices
;
1498 class VectorPrintOpConversion
: public ConvertOpToLLVMPattern
<vector::PrintOp
> {
1500 using ConvertOpToLLVMPattern
<vector::PrintOp
>::ConvertOpToLLVMPattern
;
1502 // Lowering implementation that relies on a small runtime support library,
1503 // which only needs to provide a few printing methods (single value for all
1504 // data types, opening/closing bracket, comma, newline). The lowering splits
1505 // the vector into elementary printing operations. The advantage of this
1506 // approach is that the library can remain unaware of all low-level
1507 // implementation details of vectors while still supporting output of any
1508 // shaped and dimensioned vector.
1510 // Note: This lowering only handles scalars, n-D vectors are broken into
1511 // printing scalars in loops in VectorToSCF.
1513 // TODO: rely solely on libc in future? something else?
1516 matchAndRewrite(vector::PrintOp printOp
, OpAdaptor adaptor
,
1517 ConversionPatternRewriter
&rewriter
) const override
{
1518 auto parent
= printOp
->getParentOfType
<ModuleOp
>();
1522 auto loc
= printOp
->getLoc();
1524 if (auto value
= adaptor
.getSource()) {
1525 Type printType
= printOp
.getPrintType();
1526 if (isa
<VectorType
>(printType
)) {
1527 // Vectors should be broken into elementary print ops in VectorToSCF.
1530 if (failed(emitScalarPrint(rewriter
, parent
, loc
, printType
, value
)))
1534 auto punct
= printOp
.getPunctuation();
1535 if (auto stringLiteral
= printOp
.getStringLiteral()) {
1536 LLVM::createPrintStrCall(rewriter
, loc
, parent
, "vector_print_str",
1537 *stringLiteral
, *getTypeConverter(),
1538 /*addNewline=*/false);
1539 } else if (punct
!= PrintPunctuation::NoPunctuation
) {
1540 emitCall(rewriter
, printOp
->getLoc(), [&] {
1542 case PrintPunctuation::Close
:
1543 return LLVM::lookupOrCreatePrintCloseFn(parent
);
1544 case PrintPunctuation::Open
:
1545 return LLVM::lookupOrCreatePrintOpenFn(parent
);
1546 case PrintPunctuation::Comma
:
1547 return LLVM::lookupOrCreatePrintCommaFn(parent
);
1548 case PrintPunctuation::NewLine
:
1549 return LLVM::lookupOrCreatePrintNewlineFn(parent
);
1551 llvm_unreachable("unexpected punctuation");
1556 rewriter
.eraseOp(printOp
);
1561 enum class PrintConversion
{
1570 LogicalResult
emitScalarPrint(ConversionPatternRewriter
&rewriter
,
1571 ModuleOp parent
, Location loc
, Type printType
,
1572 Value value
) const {
1573 if (typeConverter
->convertType(printType
) == nullptr)
1576 // Make sure element type has runtime support.
1577 PrintConversion conversion
= PrintConversion::None
;
1579 if (printType
.isF32()) {
1580 printer
= LLVM::lookupOrCreatePrintF32Fn(parent
);
1581 } else if (printType
.isF64()) {
1582 printer
= LLVM::lookupOrCreatePrintF64Fn(parent
);
1583 } else if (printType
.isF16()) {
1584 conversion
= PrintConversion::Bitcast16
; // bits!
1585 printer
= LLVM::lookupOrCreatePrintF16Fn(parent
);
1586 } else if (printType
.isBF16()) {
1587 conversion
= PrintConversion::Bitcast16
; // bits!
1588 printer
= LLVM::lookupOrCreatePrintBF16Fn(parent
);
1589 } else if (printType
.isIndex()) {
1590 printer
= LLVM::lookupOrCreatePrintU64Fn(parent
);
1591 } else if (auto intTy
= dyn_cast
<IntegerType
>(printType
)) {
1592 // Integers need a zero or sign extension on the operand
1593 // (depending on the source type) as well as a signed or
1594 // unsigned print method. Up to 64-bit is supported.
1595 unsigned width
= intTy
.getWidth();
1596 if (intTy
.isUnsigned()) {
1599 conversion
= PrintConversion::ZeroExt64
;
1600 printer
= LLVM::lookupOrCreatePrintU64Fn(parent
);
1605 assert(intTy
.isSignless() || intTy
.isSigned());
1607 // Note that we *always* zero extend booleans (1-bit integers),
1608 // so that true/false is printed as 1/0 rather than -1/0.
1610 conversion
= PrintConversion::ZeroExt64
;
1611 else if (width
< 64)
1612 conversion
= PrintConversion::SignExt64
;
1613 printer
= LLVM::lookupOrCreatePrintI64Fn(parent
);
1622 switch (conversion
) {
1623 case PrintConversion::ZeroExt64
:
1624 value
= rewriter
.create
<arith::ExtUIOp
>(
1625 loc
, IntegerType::get(rewriter
.getContext(), 64), value
);
1627 case PrintConversion::SignExt64
:
1628 value
= rewriter
.create
<arith::ExtSIOp
>(
1629 loc
, IntegerType::get(rewriter
.getContext(), 64), value
);
1631 case PrintConversion::Bitcast16
:
1632 value
= rewriter
.create
<LLVM::BitcastOp
>(
1633 loc
, IntegerType::get(rewriter
.getContext(), 16), value
);
1635 case PrintConversion::None
:
1638 emitCall(rewriter
, loc
, printer
, value
);
1642 // Helper to emit a call.
1643 static void emitCall(ConversionPatternRewriter
&rewriter
, Location loc
,
1644 Operation
*ref
, ValueRange params
= ValueRange()) {
1645 rewriter
.create
<LLVM::CallOp
>(loc
, TypeRange(), SymbolRefAttr::get(ref
),
1650 /// The Splat operation is lowered to an insertelement + a shufflevector
1651 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1652 struct VectorSplatOpLowering
: public ConvertOpToLLVMPattern
<vector::SplatOp
> {
1653 using ConvertOpToLLVMPattern
<vector::SplatOp
>::ConvertOpToLLVMPattern
;
1656 matchAndRewrite(vector::SplatOp splatOp
, OpAdaptor adaptor
,
1657 ConversionPatternRewriter
&rewriter
) const override
{
1658 VectorType resultType
= cast
<VectorType
>(splatOp
.getType());
1659 if (resultType
.getRank() > 1)
1662 // First insert it into an undef vector so we can shuffle it.
1663 auto vectorType
= typeConverter
->convertType(splatOp
.getType());
1664 Value undef
= rewriter
.create
<LLVM::UndefOp
>(splatOp
.getLoc(), vectorType
);
1665 auto zero
= rewriter
.create
<LLVM::ConstantOp
>(
1667 typeConverter
->convertType(rewriter
.getIntegerType(32)),
1668 rewriter
.getZeroAttr(rewriter
.getIntegerType(32)));
1670 // For 0-d vector, we simply do `insertelement`.
1671 if (resultType
.getRank() == 0) {
1672 rewriter
.replaceOpWithNewOp
<LLVM::InsertElementOp
>(
1673 splatOp
, vectorType
, undef
, adaptor
.getInput(), zero
);
1677 // For 1-d vector, we additionally do a `vectorshuffle`.
1678 auto v
= rewriter
.create
<LLVM::InsertElementOp
>(
1679 splatOp
.getLoc(), vectorType
, undef
, adaptor
.getInput(), zero
);
1681 int64_t width
= cast
<VectorType
>(splatOp
.getType()).getDimSize(0);
1682 SmallVector
<int32_t> zeroValues(width
, 0);
1684 // Shuffle the value across the desired number of elements.
1685 rewriter
.replaceOpWithNewOp
<LLVM::ShuffleVectorOp
>(splatOp
, v
, undef
,
1691 /// The Splat operation is lowered to an insertelement + a shufflevector
1692 /// operation. Splat to only 2+-d vector result types are lowered by the
1693 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1694 struct VectorSplatNdOpLowering
: public ConvertOpToLLVMPattern
<SplatOp
> {
1695 using ConvertOpToLLVMPattern
<SplatOp
>::ConvertOpToLLVMPattern
;
1698 matchAndRewrite(SplatOp splatOp
, OpAdaptor adaptor
,
1699 ConversionPatternRewriter
&rewriter
) const override
{
1700 VectorType resultType
= splatOp
.getType();
1701 if (resultType
.getRank() <= 1)
1704 // First insert it into an undef vector so we can shuffle it.
1705 auto loc
= splatOp
.getLoc();
1706 auto vectorTypeInfo
=
1707 LLVM::detail::extractNDVectorTypeInfo(resultType
, *getTypeConverter());
1708 auto llvmNDVectorTy
= vectorTypeInfo
.llvmNDVectorTy
;
1709 auto llvm1DVectorTy
= vectorTypeInfo
.llvm1DVectorTy
;
1710 if (!llvmNDVectorTy
|| !llvm1DVectorTy
)
1713 // Construct returned value.
1714 Value desc
= rewriter
.create
<LLVM::UndefOp
>(loc
, llvmNDVectorTy
);
1716 // Construct a 1-D vector with the splatted value that we insert in all the
1717 // places within the returned descriptor.
1718 Value vdesc
= rewriter
.create
<LLVM::UndefOp
>(loc
, llvm1DVectorTy
);
1719 auto zero
= rewriter
.create
<LLVM::ConstantOp
>(
1720 loc
, typeConverter
->convertType(rewriter
.getIntegerType(32)),
1721 rewriter
.getZeroAttr(rewriter
.getIntegerType(32)));
1722 Value v
= rewriter
.create
<LLVM::InsertElementOp
>(loc
, llvm1DVectorTy
, vdesc
,
1723 adaptor
.getInput(), zero
);
1725 // Shuffle the value across the desired number of elements.
1726 int64_t width
= resultType
.getDimSize(resultType
.getRank() - 1);
1727 SmallVector
<int32_t> zeroValues(width
, 0);
1728 v
= rewriter
.create
<LLVM::ShuffleVectorOp
>(loc
, v
, v
, zeroValues
);
1730 // Iterate of linear index, convert to coords space and insert splatted 1-D
1731 // vector in each position.
1732 nDVectorIterate(vectorTypeInfo
, rewriter
, [&](ArrayRef
<int64_t> position
) {
1733 desc
= rewriter
.create
<LLVM::InsertValueOp
>(loc
, desc
, v
, position
);
1735 rewriter
.replaceOp(splatOp
, desc
);
1740 /// Conversion pattern for a `vector.interleave`.
1741 /// This supports fixed-sized vectors and scalable vectors.
1742 struct VectorInterleaveOpLowering
1743 : public ConvertOpToLLVMPattern
<vector::InterleaveOp
> {
1744 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
1747 matchAndRewrite(vector::InterleaveOp interleaveOp
, OpAdaptor adaptor
,
1748 ConversionPatternRewriter
&rewriter
) const override
{
1749 VectorType resultType
= interleaveOp
.getResultVectorType();
1750 // n-D interleaves should have been lowered already.
1751 if (resultType
.getRank() != 1)
1752 return rewriter
.notifyMatchFailure(interleaveOp
,
1753 "InterleaveOp not rank 1");
1754 // If the result is rank 1, then this directly maps to LLVM.
1755 if (resultType
.isScalable()) {
1756 rewriter
.replaceOpWithNewOp
<LLVM::vector_interleave2
>(
1757 interleaveOp
, typeConverter
->convertType(resultType
),
1758 adaptor
.getLhs(), adaptor
.getRhs());
1761 // Lower fixed-size interleaves to a shufflevector. While the
1762 // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1763 // langref still recommends fixed-vectors use shufflevector, see:
1764 // https://llvm.org/docs/LangRef.html#id876.
1765 int64_t resultVectorSize
= resultType
.getNumElements();
1766 SmallVector
<int32_t> interleaveShuffleMask
;
1767 interleaveShuffleMask
.reserve(resultVectorSize
);
1768 for (int i
= 0, end
= resultVectorSize
/ 2; i
< end
; ++i
) {
1769 interleaveShuffleMask
.push_back(i
);
1770 interleaveShuffleMask
.push_back((resultVectorSize
/ 2) + i
);
1772 rewriter
.replaceOpWithNewOp
<LLVM::ShuffleVectorOp
>(
1773 interleaveOp
, adaptor
.getLhs(), adaptor
.getRhs(),
1774 interleaveShuffleMask
);
1779 /// Conversion pattern for a `vector.deinterleave`.
1780 /// This supports fixed-sized vectors and scalable vectors.
1781 struct VectorDeinterleaveOpLowering
1782 : public ConvertOpToLLVMPattern
<vector::DeinterleaveOp
> {
1783 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
1786 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp
, OpAdaptor adaptor
,
1787 ConversionPatternRewriter
&rewriter
) const override
{
1788 VectorType resultType
= deinterleaveOp
.getResultVectorType();
1789 VectorType sourceType
= deinterleaveOp
.getSourceVectorType();
1790 auto loc
= deinterleaveOp
.getLoc();
1792 // Note: n-D deinterleave operations should be lowered to the 1-D before
1793 // converting to LLVM.
1794 if (resultType
.getRank() != 1)
1795 return rewriter
.notifyMatchFailure(deinterleaveOp
,
1796 "DeinterleaveOp not rank 1");
1798 if (resultType
.isScalable()) {
1799 auto llvmTypeConverter
= this->getTypeConverter();
1800 auto deinterleaveResults
= deinterleaveOp
.getResultTypes();
1801 auto packedOpResults
=
1802 llvmTypeConverter
->packOperationResults(deinterleaveResults
);
1803 auto intrinsic
= rewriter
.create
<LLVM::vector_deinterleave2
>(
1804 loc
, packedOpResults
, adaptor
.getSource());
1806 auto evenResult
= rewriter
.create
<LLVM::ExtractValueOp
>(
1807 loc
, intrinsic
->getResult(0), 0);
1808 auto oddResult
= rewriter
.create
<LLVM::ExtractValueOp
>(
1809 loc
, intrinsic
->getResult(0), 1);
1811 rewriter
.replaceOp(deinterleaveOp
, ValueRange
{evenResult
, oddResult
});
1814 // Lower fixed-size deinterleave to two shufflevectors. While the
1815 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1816 // langref still recommends fixed-vectors use shufflevector, see:
1817 // https://llvm.org/docs/LangRef.html#id889.
1818 int64_t resultVectorSize
= resultType
.getNumElements();
1819 SmallVector
<int32_t> evenShuffleMask
;
1820 SmallVector
<int32_t> oddShuffleMask
;
1822 evenShuffleMask
.reserve(resultVectorSize
);
1823 oddShuffleMask
.reserve(resultVectorSize
);
1825 for (int i
= 0; i
< sourceType
.getNumElements(); ++i
) {
1827 evenShuffleMask
.push_back(i
);
1829 oddShuffleMask
.push_back(i
);
1832 auto poison
= rewriter
.create
<LLVM::PoisonOp
>(loc
, sourceType
);
1833 auto evenShuffle
= rewriter
.create
<LLVM::ShuffleVectorOp
>(
1834 loc
, adaptor
.getSource(), poison
, evenShuffleMask
);
1835 auto oddShuffle
= rewriter
.create
<LLVM::ShuffleVectorOp
>(
1836 loc
, adaptor
.getSource(), poison
, oddShuffleMask
);
1838 rewriter
.replaceOp(deinterleaveOp
, ValueRange
{evenShuffle
, oddShuffle
});
1843 /// Conversion pattern for a `vector.from_elements`.
1844 struct VectorFromElementsLowering
1845 : public ConvertOpToLLVMPattern
<vector::FromElementsOp
> {
1846 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
1849 matchAndRewrite(vector::FromElementsOp fromElementsOp
, OpAdaptor adaptor
,
1850 ConversionPatternRewriter
&rewriter
) const override
{
1851 Location loc
= fromElementsOp
.getLoc();
1852 VectorType vectorType
= fromElementsOp
.getType();
1853 // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1854 // Such ops should be handled in the same way as vector.insert.
1855 if (vectorType
.getRank() > 1)
1856 return rewriter
.notifyMatchFailure(fromElementsOp
,
1857 "rank > 1 vectors are not supported");
1858 Type llvmType
= typeConverter
->convertType(vectorType
);
1859 Value result
= rewriter
.create
<LLVM::UndefOp
>(loc
, llvmType
);
1860 for (auto [idx
, val
] : llvm::enumerate(adaptor
.getElements()))
1861 result
= rewriter
.create
<vector::InsertOp
>(loc
, val
, result
, idx
);
1862 rewriter
.replaceOp(fromElementsOp
, result
);
1867 /// Conversion pattern for vector.step.
1868 struct VectorScalableStepOpLowering
1869 : public ConvertOpToLLVMPattern
<vector::StepOp
> {
1870 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern
;
1873 matchAndRewrite(vector::StepOp stepOp
, OpAdaptor adaptor
,
1874 ConversionPatternRewriter
&rewriter
) const override
{
1875 auto resultType
= cast
<VectorType
>(stepOp
.getType());
1876 if (!resultType
.isScalable()) {
1879 Type llvmType
= typeConverter
->convertType(stepOp
.getType());
1880 rewriter
.replaceOpWithNewOp
<LLVM::StepVectorOp
>(stepOp
, llvmType
);
1887 /// Populate the given list with patterns that convert from Vector to LLVM.
1888 void mlir::populateVectorToLLVMConversionPatterns(
1889 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
,
1890 bool reassociateFPReductions
, bool force32BitVectorIndices
) {
1891 MLIRContext
*ctx
= converter
.getDialect()->getContext();
1892 patterns
.add
<VectorFMAOpNDRewritePattern
>(ctx
);
1893 populateVectorInsertExtractStridedSliceTransforms(patterns
);
1894 populateVectorStepLoweringPatterns(patterns
);
1895 patterns
.add
<VectorReductionOpConversion
>(converter
, reassociateFPReductions
);
1896 patterns
.add
<VectorCreateMaskOpRewritePattern
>(ctx
, force32BitVectorIndices
);
1897 patterns
.add
<VectorBitCastOpConversion
, VectorShuffleOpConversion
,
1898 VectorExtractElementOpConversion
, VectorExtractOpConversion
,
1899 VectorFMAOp1DConversion
, VectorInsertElementOpConversion
,
1900 VectorInsertOpConversion
, VectorPrintOpConversion
,
1901 VectorTypeCastOpConversion
, VectorScaleOpConversion
,
1902 VectorLoadStoreConversion
<vector::LoadOp
>,
1903 VectorLoadStoreConversion
<vector::MaskedLoadOp
>,
1904 VectorLoadStoreConversion
<vector::StoreOp
>,
1905 VectorLoadStoreConversion
<vector::MaskedStoreOp
>,
1906 VectorGatherOpConversion
, VectorScatterOpConversion
,
1907 VectorExpandLoadOpConversion
, VectorCompressStoreOpConversion
,
1908 VectorSplatOpLowering
, VectorSplatNdOpLowering
,
1909 VectorScalableInsertOpLowering
, VectorScalableExtractOpLowering
,
1910 MaskedReductionOpConversion
, VectorInterleaveOpLowering
,
1911 VectorDeinterleaveOpLowering
, VectorFromElementsLowering
,
1912 VectorScalableStepOpLowering
>(converter
);
1913 // Transfer ops with rank > 1 are handled by VectorToSCF.
1914 populateVectorTransferLoweringPatterns(patterns
, /*maxTransferRank=*/1);
1917 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1918 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
) {
1919 patterns
.add
<VectorMatmulOpConversion
>(converter
);
1920 patterns
.add
<VectorFlatTransposeOpConversion
>(converter
);