//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Arith/Utils/Utils.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/UB/IR/UBOps.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinAttributes.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/DialectImplementation.h"
32 #include "mlir/IR/IRMapping.h"
33 #include "mlir/IR/OpImplementation.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Interfaces/SubsetOpInterface.h"
37 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
38 #include "mlir/Support/LLVM.h"
39 #include "mlir/Transforms/InliningUtils.h"
40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 #include "llvm/ADT/bit.h"
51 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 // Pull in all enum type and utility function definitions.
53 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
56 using namespace mlir::vector
;
/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};
/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayRef<int64_t> masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    if (allTrue)
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all
    // dims to be constants, so that'll be folded to a constant_mask, then
    // detected in the constant_mask case.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}
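// Illustration for getMaskFormat() above (example values chosen here, not
// taken from the surrounding code):
//   vector.constant_mask [3] : vector<3xi1>            -> MaskFormat::AllTrue
//   vector.constant_mask [0] : vector<3xi1>            -> MaskFormat::AllFalse
//   vector.create_mask %n : vector<3xi1> (%n dynamic)  -> MaskFormat::Unknown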
/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<vector::YieldOp>(loc);
}
// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}
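// Illustration for getTransferMinorIdentityMap() above (shapes chosen for
// exposition): a transfer between memref<?x?x?xf32> and vector<4x8xf32> gets
// the minor identity map (d0, d1, d2) -> (d1, d2); when the source element
// type is itself a vector, e.g. memref<?x?xvector<8xf32>> with vector<4x8xf32>,
// the element-vector rank is subtracted and the map becomes (d0, d1) -> (d1).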
/// Check if `write` is of a constant splat and the masked `read` is padded with
/// the same splat value -- meaning it could be the same value as the initial
/// constant splat.
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // Check if the masks are consistent. The splat value could be the same if the
  // read is masked (and padded with the splat value), and the write is unmasked
  // or has the same mask. Note this does not allow the case where the write is
  // masked and the read is unmasked, as then the read could be of more elements
  // than the write (which may not be the same value).
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for constant splat (as the source of the write).
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat()) {
    return false;
  }
  // The padding of the read and the constant splat value must be the same.
  Attribute padAttr;
  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
    return false;
  return padAttr == splatAttr.getSplatValue<Attribute>();
}
bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}
bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
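// Note on checkSameValueRAW/checkSameValueWAW above: both are conservative
// equality checks used by transfer-op forwarding and dead-store style
// rewrites; `true` means the read (RAW) or the later write (WAW) is known to
// observe/produce the same value, `false` only means it could not be proven.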
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that index are different we
      // know we are accessing disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For this dimension, we slice a part of the memref we need to make sure
      // the intervals accessed don't overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getSource() != transferB.getSource())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
// Helper to iterate over n-D vector slice elements. Calculate the next
// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
// Modifies the `position` in place. Returns a failure when `position` becomes
// the end position.
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();

    // Carry the overflow to the next loop iteration.
    posInDim = offsetInDim;
  }

  return failure();
}
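// Illustration for incSlicePosition() above (example values): with
// shape = [2, 3] and offsets = [0, 0], starting from position = [0, 0] the
// successive calls yield [0, 1], [0, 2], [1, 0], [1, 1], [1, 2] and then
// return failure() once the iteration space is exhausted.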
/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}
/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  return ints;
}
/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant values.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = foldResult.dyn_cast<Attribute>())
                      return builder
                          .create<arith::ConstantIndexOp>(
                              loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();

                    return cast<Value>(foldResult);
                  });
  return values;
}
std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  if (value.getDefiningOp<vector::VectorScaleOp>())
    return 1;
  auto mul = value.getDefiningOp<arith::MulIOp>();
  if (!mul)
    return {};
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  return {};
}
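// Illustration for getConstantVscaleMultiplier() above: a bare vector.vscale
// returns 1; `arith.muli %vscale, %c4` (or the commuted form) returns 4; any
// other producer returns std::nullopt.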
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All vector dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
}
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}
//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}
/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
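// Example for getExpectedMaskType() above: reducing a vector<2x4xf32> source
// expects a mask of type vector<2x4xi1>, i.e. the mask shape follows the
// source vector, not the (reduced) result.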
namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReduceOp.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // This means we are reducing all the dimensions, and all reduction
      // dimensions are of size 1. So a simple extraction would do.
      SmallVector<int64_t> zeroIdx(shape.size(), 0);
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
                                                zeroIdx);
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}
LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    Value result;
    if (vectorType.getRank() == 0) {
      if (mask)
        mask = rewriter.create<ExtractElementOp>(loc, mask);
      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
    } else {
      if (mask)
        mask = rewriter.create<ExtractOp>(loc, mask, 0);
      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
    }

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
      parser.parseComma() || parser.parseOperand(rhsInfo) ||
      parser.parseComma() || parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}
void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
ContractionOp::verify() {
971 VectorType lhsType
= getLhsType();
972 VectorType rhsType
= getRhsType();
973 Type accType
= getAccType();
974 Type resType
= getResultType();
976 if (llvm::isa
<IntegerType
>(lhsType
.getElementType())) {
977 if (!lhsType
.getElementType().isSignlessInteger())
978 return emitOpError("only supports signless integer types");
981 // Verify that an indexing map was specified for each vector operand.
982 if (getIndexingMapsArray().size() != 3)
983 return emitOpError("expected an indexing map for each vector operand");
985 // Verify that each index map has 'numIterators' inputs, no symbols, and
986 // that the number of map outputs equals the rank of its associated
988 unsigned numIterators
= getIteratorTypes().getValue().size();
989 for (const auto &it
: llvm::enumerate(getIndexingMapsArray())) {
990 auto index
= it
.index();
991 auto map
= it
.value();
992 if (map
.getNumSymbols() != 0)
993 return emitOpError("expected indexing map ")
994 << index
<< " to have no symbols";
995 auto vectorType
= llvm::dyn_cast
<VectorType
>(getOperand(index
).getType());
996 unsigned rank
= vectorType
? vectorType
.getShape().size() : 0;
997 // Verify that the map has the right number of inputs, outputs, and indices.
998 // This also correctly accounts for (..) -> () for rank-0 results.
999 if (map
.getNumDims() != numIterators
)
1000 return emitOpError("expected indexing map ")
1001 << index
<< " to have " << numIterators
<< " number of inputs";
1002 if (map
.getNumResults() != rank
)
1003 return emitOpError("expected indexing map ")
1004 << index
<< " to have " << rank
<< " number of outputs";
1005 if (!map
.isProjectedPermutation())
1006 return emitOpError("expected indexing map ")
1007 << index
<< " to be a projected permutation of its inputs";
1010 auto contractingDimMap
= getContractingDimMap();
1011 auto batchDimMap
= getBatchDimMap();
1013 // Verify at least one contracting dimension pair was specified.
1014 if (contractingDimMap
.empty())
1015 return emitOpError("expected at least one contracting dimension pair");
1017 // Verify contracting dimension map was properly constructed.
1018 if (!verifyDimMap(lhsType
, rhsType
, contractingDimMap
))
1019 return emitOpError("invalid contracting dimension map");
1021 // Verify batch dimension map was properly constructed.
1022 if (!verifyDimMap(lhsType
, rhsType
, batchDimMap
))
1023 return emitOpError("invalid batch dimension map");
1025 // Verify 'accType' and 'resType' shape.
1026 if (failed(verifyOutputShape(*this, lhsType
, rhsType
, accType
, resType
,
1027 contractingDimMap
, batchDimMap
)))
1030 // Verify supported combining kind.
1031 auto vectorType
= llvm::dyn_cast
<VectorType
>(resType
);
1032 auto elementType
= vectorType
? vectorType
.getElementType() : resType
;
1033 if (!isSupportedCombiningKind(getKind(), elementType
))
1034 return emitOpError("unsupported contraction type");
// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}
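// Example for getExpectedMaskType() above (shapes chosen for exposition): for
// a matmul-like contraction with lhs vector<4x8xf32> mapped by
// (m, n, k) -> (m, k) and rhs vector<8x16xf32> mapped by (m, n, k) -> (k, n),
// the expected mask type is vector<4x16x8xi1>, i.e. one i1 per (m, n, k)
// iteration point.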
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
/// Return a fused vector::ContractionOp which represents a pattern such as:
///
///    %c0 = vector.constant 0: ...
///    %c = vector.contract %a, %b, %c0: ...
///    %e = add %c, %d: ...
///
/// rewritten as:
///
///    %e = vector.contract %a, %b, %d: ...
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {
  result.addOperands({source});
  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
}

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}

OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here now.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement(broadcast(X)) -> X.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {
  build(builder, result, source, ArrayRef<OpFoldResult>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
        builder.getDenseI64ArrayAttr(position));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}
LogicalResult vector::ExtractOp::verify() {
  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
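// Example for foldExtractOpFromExtractChain() above:
//   %1 = vector.extract %0[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   %2 = vector.extract %1[2] : vector<4xf32> from vector<3x4xf32>
// folds the second op in place to: vector.extract %0[1, 2].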
namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result vector.
/// As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as equality
  /// of all entries of `a` that are >=0 with the corresponding entries of b.
  /// Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  }

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
  ///
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the result.
  /// They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
  llvm::append_range(extractPosition, sentinels);
}
// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
  return success();
}
1542 // Case 2: the insert position matches extractPosition exactly, early return.
1544 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1546 // TODO: Canonicalization for dynamic position not implemented yet.
1547 if (extractOp
.hasDynamicPosition() || nextInsertOp
.hasDynamicPosition())
1550 ArrayRef
<int64_t> insertedPos
= nextInsertOp
.getStaticPosition();
1551 if (insertedPos
!= llvm::ArrayRef(extractPosition
).take_front(extractedRank
))
1553 // Case 2.a. early-exit fold.
1554 res
= nextInsertOp
.getSource();
1555 // Case 2.b. if internal transposition is present, canFold will be false.
1556 return success(canFold());
1559 /// Case 3: if inserted position is a prefix of extractPosition,
1560 /// extract a portion of the source of the insertion.
1561 /// This method updates the internal state.
1563 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value
&res
) {
1564 // TODO: Canonicalization for dynamic position not implemented yet.
1565 if (extractOp
.hasDynamicPosition() || nextInsertOp
.hasDynamicPosition())
1568 ArrayRef
<int64_t> insertedPos
= nextInsertOp
.getStaticPosition();
1569 if (!isContainedWithin(insertedPos
, extractPosition
))
1571 // Set leading dims to zero.
1572 std::fill_n(extractPosition
.begin(), insertedPos
.size(), 0);
1573 // Drop extra leading dims.
1574 extractPosition
.erase(extractPosition
.begin(),
1575 extractPosition
.begin() + insertedPos
.size());
1576 extractedRank
= extractPosition
.size() - sentinels
.size();
1577 // Case 3.a. early-exit fold (break and delegate to post-while path).
1578 res
= nextInsertOp
.getSource();
1579 // Case 3.b. if internal transposition is present, canFold will be false.
1583 /// Try to fold in place to extract(source, extractPosition) and return the
1584 /// folded result. Return null if folding is not possible (e.g. due to an
1585 /// internal transposition in the result).
1586 Value
ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1588 // TODO: Canonicalization for dynamic position not implemented yet.
1589 if (extractOp
.hasDynamicPosition())
1592 // If we can't fold (either internal transposition, or nothing to fold), bail.
1593 bool nothingToFold
= (source
== extractOp
.getVector());
1594 if (nothingToFold
|| !canFold())
1597 // Otherwise, fold by updating the op inplace and return its result.
1598 OpBuilder
b(extractOp
.getContext());
1599 extractOp
.setStaticPosition(
1600 ArrayRef(extractPosition
).take_front(extractedRank
));
1601 extractOp
.getVectorMutable().assign(source
);
1602 return extractOp
.getResult();
1605 /// Iterate over producing insert and transpose ops until we find a fold.
1606 Value
ExtractFromInsertTransposeChainState::fold() {
1607 // TODO: Canonicalization for dynamic position not implemented yet.
1608 if (extractOp
.hasDynamicPosition())
1611 Value valueToExtractFrom
= extractOp
.getVector();
1612 updateStateForNextIteration(valueToExtractFrom
);
1613 while (nextInsertOp
|| nextTransposeOp
) {
1614 // Case 1. If we hit a transpose, just compose the map and iterate.
1615 // Invariant: insert + transpose do not change rank, we can always compose.
1616 if (succeeded(handleTransposeOp())) {
1617 valueToExtractFrom
= nextTransposeOp
.getVector();
1618 updateStateForNextIteration(valueToExtractFrom
);
1623 // Case 2: the position match exactly.
1624 if (succeeded(handleInsertOpWithMatchingPos(result
)))
1627 // Case 3: if the inserted position is a prefix of extractPosition, we can
1628 // just extract a portion of the source of the insert.
1629 if (succeeded(handleInsertOpWithPrefixPos(result
)))
1630 return tryToFoldExtractOpInPlace(result
);
1632 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1633 // values. This is a more difficult case and we bail.
1634 ArrayRef
<int64_t> insertedPos
= nextInsertOp
.getStaticPosition();
1635 if (isContainedWithin(extractPosition
, insertedPos
) ||
1636 intersectsWhereNonNegative(extractPosition
, insertedPos
))
1639 // Case 5: No intersection, we forward the extract to insertOp.dest().
1640 valueToExtractFrom
= nextInsertOp
.getDest();
1641 updateStateForNextIteration(valueToExtractFrom
);
1643 // If after all this we can fold, go for it.
1644 return tryToFoldExtractOpInPlace(valueToExtractFrom
);
/// Returns true if the operation has a 0-D vector type operand or result.
static bool hasZeroDimVectors(Operation *op) {
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };

  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();

  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                       : 0;
  };

  // If splat or broadcast from a scalar, just return the source scalar.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result haven't been broadcasted.
  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();

  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

  // Detect all the positions that come from "dim-1" broadcasting.
  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
  // extract position to `0` when extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
      extractPos[i] = 0;
  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
  // matching extract position when extracting from the source operand.
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setOperand(0, source);
  extractOp.setStaticPosition(extractPos);
  return extractOp.getResult();
}
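// For illustration (example IR only, not from the upstream tests):
//
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1, 2] : vector<4xf32> from vector<2x3x4xf32>
//
// The extracted type equals the broadcast source type, so %e folds to %src.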
/// Fold extractOp coming from ShuffleOp.
///
/// Example:
///
///   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
///     : vector<8xf32>, vector<8xf32>
///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
/// ->
///   %extract = vector.extract %b[7] : f32 from vector<8xf32>
///
static Value foldExtractFromShuffle(ExtractOp extractOp) {
  // Dynamic positions are not folded as the resulting code would be more
  // complex than the input code.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // TODO: 0-D or multi-dimensional vectors not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Find the shuffled vector to extract from based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
}
// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
}
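// For illustration (example IR only, not from the upstream tests):
//
//   %sc = vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>
//   %e  = vector.extract %sc[5] : f32 from vector<8xf32>
//
// The linearized position 5 delinearizes to [1, 1] in the 2x4 source shape,
// so the extract is rewritten in place to read %v[1, 1] directly.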
/// Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // 0-D vectors not supported.
  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
  if (hasZeroDimVectors(extractStridedSliceOp))
    return Value();

  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  SmallVector<int64_t, 4> sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());

  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
}
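// For illustration (example IR only, not from the upstream tests):
//
//   %s = vector.extract_strided_slice %v
//          {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
//          : vector<4x8xf32> to vector<2x4xf32>
//   %e = vector.extract %s[0, 1] : f32 from vector<2x4xf32>
//
// Adding the slice offsets to the extract position rewrites this in place to
// vector.extract %v[1, 3].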
/// Fold extract_op fed from a chain of insertStridedSlice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
  if (!insertOp)
    return Value();

  // 0-D vectors not supported.
  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
  if (hasZeroDimVectors(insertOp))
    return Value();

  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the inserted vector.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap and cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getVectorMutable().assign(insertOp.getSource());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(extractOp.getContext());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
/// Try to fold the extraction of a scalar from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b : vector<2xf32>
/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
/// ==> fold to %a
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute flattened/linearized index and fold to operand.
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
}
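// For illustration of the linearization above (example IR only):
//
//   %v = vector.from_elements %a, %b, %c, %d, %e, %f : vector<2x3xf32>
//   %x = vector.extract %v[1, 2] : f32 from vector<2x3xf32>
//
// flatIndex = 1 * 3 + 2 = 5, so %x folds to the sixth operand, %f.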
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
static ub::PoisonAttr
foldPoisonIndexInsertExtractOp(MLIRContext *context,
                               ArrayRef<int64_t> staticPos, int64_t poisonVal) {
  if (!llvm::is_contained(staticPos, poisonVal))
    return ub::PoisonAttr();

  return ub::PoisonAttr::get(context);
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
  // mismatch).
  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
    return getVector();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto res = foldExtractFromBroadcast(*this))
    return res;
  if (auto res = foldExtractFromShuffle(*this))
    return res;
  if (auto res = foldExtractFromShapeCast(*this))
    return res;
  if (auto val = foldExtractFromExtractStrided(*this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(*this))
    return val;
  if (auto val = foldScalarExtractFromFromElements(*this))
    return val;
  return OpFoldResult();
}
namespace {

// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return llvm::isa<VectorType>(type)
                 ? llvm::cast<VectorType>(type).getRank()
                 : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // We only consider the case where the rank of the source is less than or
    // equal to the rank of the extract dst. The other cases are handled in the
    // folding patterns.
    if (extractResultRank < broadcastSrcRank)
      return failure();

    // Special case if broadcast src is a 0D vector.
    if (extractResultRank == 0) {
      assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
      return success();
    }
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
    return success();
  }
};
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'ExtractOp' operand is not defined by a splat vector
    // ConstantOp.
    Value sourceVector = extractOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();
    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
    if (!splat)
      return failure();
    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
    return success();
  }
};
// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
class ExtractOpNonSplatConstantFolder final
    : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (extractOp.hasDynamicPosition())
      return failure();

    // Return if 'ExtractOp' operand is not defined by a compatible vector
    // ConstantOp.
    Value sourceVector = extractOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
    if (vecTy.isScalable())
      return failure();

    // The splat case is handled by `ExtractOpSplatConstantFolder`.
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())
      return failure();

    // Calculate the linearized position of the continuous chunk of elements to
    // extract.
    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
    copy(extractOp.getStaticPosition(), completePositions.begin());
    int64_t elemBeginPosition =
        linearize(completePositions, computeStrides(vecTy.getShape()));
    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;

    TypedAttr newAttr;
    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
      SmallVector<Attribute> elementValues(
          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
    } else {
      newAttr = *denseValuesBegin;
    }

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
    return success();
  }
};
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());

    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the range from the `create_mask`, then the
        // extracted mask will be all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true and since this is a dynamic index we don't
        // know if the extraction is within the true or false region.
        // Note: Zero dims have already been handled via getMaskFormat().
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};
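// For illustration (example IR only; %c2 and %c3 are assumed to be
// arith.constant index values):
//
//   %m = vector.create_mask %c2, %c3 : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
//
// Position 1 is below the bound 2, so this becomes
// vector.create_mask %c3 : vector<8xi1>; a position >= 2 would instead be
// replaced by an all-false constant mask.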
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}
/// Try to canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of first extracted element and flatten/linearize the
  // position.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
  int flatIndex = 0;
  int stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  // Replace the op with a smaller from_elements op.
  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
template <typename OpTy>
LogicalResult
canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
  if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
          op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
    return success();
  }

  return failure();
}

} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
  results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}

//===----------------------------------------------------------------------===//
// FMAOp
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

/// Rewrite a vector.from_elements into a vector.splat if all elements are the
/// same SSA value. E.g.:
///
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// ==> rewrite to vector.splat %a : vector<3xf32>
static LogicalResult
rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
                           PatternRewriter &rewriter) {
  if (!llvm::all_equal(fromElementsOp.getElements()))
    return failure();
  rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
                                       fromElementsOp.getElements().front());
  return success();
}

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add(rewriteFromElementsAsSplat);
}
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
                           ArrayRef<int64_t> dstShape) {
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected dim-1 broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
}

llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // Scalar broadcast is without any unit dim broadcast.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (!srcVectorType)
    return {};
  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
                                      getResultVectorType().getShape());
}
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
/// broadcasting.
/// Since vector.broadcast only allows expanding leading dimensions, an extra
/// vector.transpose may be inserted to make the broadcast possible.
/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
/// the helper will assert. This means:
///   1. `dstShape` must not be empty.
///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
///      must match the `value` shape.
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness check.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 2. If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. Since vector.broadcast only allows creating leading dims,
  //   vector -> dstShape broadcast may require a transpose.
  // Traverse the dims in order and construct:
  //   1. The leading entries of the broadcastShape that is guaranteed to be
  //      achievable by a simple broadcast.
  //   2. The induced permutation for the subsequent vector.transpose that will
  //      bring us from `broadcastShape` back to the desired `dstShape`.
  // If the induced permutation is not the identity, create a vector.transpose.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  // Consider the example:
  //   srcShape        = 2x4
  //   dstShape        = 1x2x3x4x5
  //   broadcastedDims = [0, 2, 4]
  //
  // We want to build:
  //   broadcastShape  = 1x3x5x2x4
  //   permutation     = [0, 2, 4,              1, 3]
  //                      ---V---           ----V----
  //             leading broadcast part   src shape part
  //
  // Note that the trailing dims of broadcastShape are exactly the srcShape
  // by construction.
  // nextSrcShapeDim is used to keep track of where in the permutation the
  // "src shape part" occurs.
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
      // bring it to the head of the broadcastShape.
      // It will need to be permuted back from `broadcastShape.size() - 1` into
      // position `i`.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
      // shape and needs to be permuted into position `i`.
      // Don't touch `broadcastShape` here, the whole srcShape will be
      // appended after.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 3.c. Append the srcShape.
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no dim-1 broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected dim-1 broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // Step 4. If we find any dimension that indeed needs to be permuted,
  // immediately return a new vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  // Otherwise return res.
  return res;
}
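// A minimal illustration of the helper above (not from the upstream tests):
// with `value` of type vector<3x4xf32>, dstShape = 2x3x4 and
// broadcastedDims = {0}, the computed broadcastShape is already 2x3x4, the
// induced permutation is the identity, and a single vector.broadcast is
// emitted with no trailing vector.transpose.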
BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast scalar to vector of the same element type.
  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // Source has an exact match or singleton value for all trailing dimensions
  // (all leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics) been
    // encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine, everything else should be rejected when mixing
        // fixed-width and scalable dims
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
LogicalResult BroadcastOp::verify() {
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");

  llvm_unreachable("unexpected vector.broadcast op error");
}
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  return {};
}
namespace {

// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
                                             broadcastOp.getResultVectorType(),
                                             srcBroadcast.getSource());
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
  // calling `populateCastAwayVectorLeadingOneDimPatterns`.
  results.add<BroadcastFolder>(context);
}
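// For illustration (example IR only, not from the upstream tests):
//
//   %0 = vector.broadcast %a : f32 to vector<4xf32>
//   %1 = vector.broadcast %0 : vector<4xf32> to vector<2x4xf32>
//
// BroadcastFolder collapses this into a single
// vector.broadcast %a : f32 to vector<2x4xf32>.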
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // Construct resulting type: leading dimension matches mask
  // length, all trailing dimensions match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
  return success();
}
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}

OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
  // but must be a canonicalization into a vector.broadcast.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  int64_t v1Size = v1Type.getDimSize(0);

  SmallVector<Attribute> results;
  auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
  auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    // Select v1[0] for poison indices.
    // TODO: Return a partial poison vector when supported by the UB dialect.
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = v1Elements[0];
    } else {
      indexedElm =
          maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
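// For illustration (example IR only, not from the upstream tests):
//
//   %0 = vector.shuffle %a, %b [0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
//
// The mask is the identity step over %a's elements, so %0 folds to %a;
// a mask of [4, 5, 6, 7] would fold to %b instead.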
namespace {

// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
// to a broadcast.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                PatternRewriter &rewriter) const override {
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();
    VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
    if (mask[0] == 0)
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
                                                       shuffleOp.getV1());
    else
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
                                                       shuffleOp.getV2());
    return success();
  }
};

/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();

    if (!v1Splat || !v2Splat)
      return failure();

    if (v1Splat.getInput() != v2Splat.getInput())
      return failure();

    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
    return success();
  }
};

/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
/// vector.interleave.
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");

    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");
    }

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
    }

    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
    return success();
  }
};

} // namespace
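// For illustration (example IR only, not from the upstream tests):
//
//   %0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
//          : vector<4xf32>, vector<4xf32>
//
// The mask alternates element i of %a with element i of %b, so the op is
// rewritten to vector.interleave %a, %b.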
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}
//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//

void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest) {
  build(builder, result, source, dest, {});
}

LogicalResult InsertElementOp::verify() {
  auto dstVectorType = getDestVectorType();
  if (dstVectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (dstVectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}

OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here.
  if (!adaptor.getPosition())
    return {};

  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!src || !dst || !pos)
    return {};

  if (src.getType() != getDestVectorType().getElementType())
    return {};

  auto dstElements = dst.getValues<Attribute>();

  SmallVector<Attribute> results(dstElements);

  uint64_t posIdx = pos.getInt();
  if (posIdx >= results.size())
    return {};
  results[posIdx] = src;

  return DenseElementsAttr::get(getDestVectorType(), results);
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {
  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<int64_t> position) {
  SmallVector<OpFoldResult> posVals;
  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult InsertOp::verify() {
  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = pos.dyn_cast<Attribute>()) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding dest vector dimension";
      }
    }
  }
  return success();
}
namespace {

// If insertOp is only inserting unit dimensions it can be transformed to a
// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getSource());
    return success();
  }
};
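// For illustration (example IR only, not from the upstream tests):
//
//   %0 = vector.insert %src, %dst[0, 0] : vector<4xf32> into vector<1x1x4xf32>
//
// Every element of %dst is overwritten, so this is rewritten to
// vector.broadcast %src : vector<4xf32> to vector<1x1x4xf32>.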
/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();

    if (!srcSplat || !dstSplat)
      return failure();

    if (srcSplat.getInput() != dstSplat.getInput())
      return failure();

    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
    return success();
  }
};
// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (op.hasDynamicPosition())
      return failure();

    // Return if 'InsertOp' operand is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();
    auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
    if (!denseDest)
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    Value sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // Calculate the linearized position of the continuous chunk of elements to
    // insert.
    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
    copy(op.getStaticPosition(), completePositions.begin());
    int64_t insertBeginPosition =
        linearize(completePositions, computeStrides(destTy.getShape()));

    SmallVector<Attribute> insertedValues;
    Type destEltType = destTy.getElementType();

    // The `convertIntegerAttr` method specifically handles the case
    // for `llvm.mlir.constant` which can hold an attribute with a
    // different type than the return type.
    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
      for (auto value : denseSource.getValues<Attribute>())
        insertedValues.push_back(convertIntegerAttr(value, destEltType));
    } else {
      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
    }

    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
    copy(insertedValues, allValues.begin() + insertBeginPosition);
    auto newAttr = DenseElementsAttr::get(destTy, allValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }

private:
  /// Converts the expected type to an IntegerAttr if there's
  /// a mismatch.
  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
    return attr;
  }
};

} // namespace
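// For illustration (example IR only, not from the upstream tests):
//
//   %dst = arith.constant dense<0.0> : vector<2x2xf32>
//   %c1  = arith.constant 1.0 : f32
//   %0   = vector.insert %c1, %dst[0, 1] : f32 into vector<2x2xf32>
//
// The linearized insert position is 1, so %0 is replaced by
// arith.constant dense<[[0.0, 1.0], [0.0, 0.0]]> : vector<2x2xf32>.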
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertOpConstantFolder>(context);
  results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
}

OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>"
  // to %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into
  // vector<f32>" (type mismatch).
  if (getNumIndices() == 0 && getSourceType() == getType())
    return getSource();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;

  return {};
}
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                 Value source, Value dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(dest.getType());
  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        ArrayRef<int64_t> shape,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
// Returns success if all integers in `arrayAttr` are confined to the given
// range. If `halfOpen` is true the admissible interval is [min, max);
// otherwise it is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
// Returns success if all integers in `arrayAttr` are confined to the
// corresponding value in `shape`. If `halfOpen` is true the admissible
// interval for dimension d is [min, shape[d]); otherwise it is
// [min, shape[d]].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
// Returns success if, for all indices i = 0..shape.size()-1, the sum
//   val = `arrayAttr1[i]` + `arrayAttr2[i]`
// is confined to the corresponding value in `shape`. If `halfOpen` is true the
// admissible interval is [min, shape[i]); otherwise it is [min, shape[i]].
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
namespace {

/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto srcSplatOp =
        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
    auto destSplatOp =
        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();

    if (!srcSplatOp || !destSplatOp)
      return failure();

    if (srcSplatOp.getInput() != destSplatOp.getInput())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getSource()
            .getDefiningOp<vector::ExtractStridedSliceOp>();

    if (!extractStridedSliceOp)
      return failure();

    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();

    // Check if they have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();

    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if 'InsertOp' operand is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);

    TypedValue<VectorType> sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them. The enumeration
    // order is lexicographic which yields a sequence of monotonically
    // increasing linearized position indices.
    // Because the destination may have higher dimensionality than the slice,
    // we keep track of two overlapping sets of positions and offsets.
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};

} // namespace
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getSource();
  return {};
}
//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}

void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}

ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand.
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what's currently supported in terms of
      // scalable vectors. However, we could relax this if there's a use case.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
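// Illustrative example (not from the original source): the "proper OUTER"
// form verified above takes two 1-D operands and produces a 2-D result, e.g.
//
//   %2 = vector.outerproduct %lhs, %rhs, %acc
//        : vector<4xf32>, vector<8xf32>
//
// yields vector<4x8xf32>, while a scalar %rhs selects the AXPY form with a
// 1-D result of the same shape as %lhs.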
//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//

// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}

void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}

LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(
          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(
          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
// to use the source of the InsertStrided ops if we can detect that the
// extracted vector is a subset of one of the vectors inserted.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  // Helper to extract integer out of ArrayAttr.
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    // If the rank of extract is greater than the rank of insert, we are likely
    // extracting a partial chunk of the vector inserted.
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        // If the extract interval overlaps but is not fully included we may
        // have a partial overlap that will prevent any folding.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extract element chunk is a subset of the insert element.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getSource());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep looking
    // in the insert chain.
    if (disjoint) {
      insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
    } else {
      // The extracted vector partially overlaps the inserted vector; we cannot
      // fold.
      return failure();
    }
  }
  return failure();
}

OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getVector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();
  return {};
}

void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getOffsets(), results);
}
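// Illustrative example (not from the original source): if %1 is produced by
//
//   %1 = vector.insert_strided_slice %v, %dst {offsets = [2, 2],
//          strides = [1, 1]} : vector<2x4xf32> into vector<8x8xf32>
//
// then an extract_strided_slice of %1 with offsets = [2, 4], sizes = [1, 2]
// reads entirely inside the inserted chunk, so the fold above rewrites it to
// extract from %v with rebased offsets [0, 2].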
namespace {

// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute slice of vector mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
    // region is a conjunction of mask dim intervals).
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
    // region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
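// Illustrative example (not from the original source): given
// vector.constant_mask [2, 3] : vector<4x8xi1> and a slice with offsets
// [1, 2] and sizes [2, 4], each sliced mask dim is
// max(0, min(offset + size, maskDim) - offset), i.e.
// [min(3, 2) - 1, min(6, 3) - 2] = [1, 1], so the pattern above produces
// vector.constant_mask [1, 1] : vector<2x4xi1>.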
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
class StridedSliceSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
    // ConstantOp.
    Value sourceVector = extractStridedSliceOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
    if (!splat)
      return failure();

    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
                                          splat.getSplatValue<Attribute>());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                   newAttr);
    return success();
  }
};
// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
// ConstantOp.
class StridedSliceNonSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
    // ConstantOp.
    Value sourceVector = extractStridedSliceOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    // The splat case is handled by `StridedSliceSplatConstantFolder`.
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
    if (!dense || dense.isSplat())
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();

    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

    VectorType sliceVecTy = extractStridedSliceOp.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t sliceRank = sliceVecTy.getRank();

    // Expand offsets and sizes to match the vector rank.
    SmallVector<int64_t, 4> offsets(sliceRank, 0);
    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());

    SmallVector<int64_t, 4> sizes(sourceShape);
    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());

    // Calculate the slice elements by enumerating all slice positions and
    // linearizing them. The enumeration order is lexicographic which yields a
    // sequence of monotonically increasing linearized position indices.
    auto denseValuesBegin = dense.value_begin<Attribute>();
    SmallVector<Attribute> sliceValues;
    sliceValues.reserve(sliceVecTy.getNumElements());
    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
    do {
      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
      assert(linearizedPosition < sourceVecTy.getNumElements() &&
             "Invalid index");
      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
    } while (
        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));

    assert(static_cast<int64_t>(sliceValues.size()) ==
               sliceVecTy.getNumElements() &&
           "Invalid number of slice elements");
    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                   newAttr);
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the source of the broadcast are the
    // same as the destination of the extract. If this is the case we can just
    // use a broadcast as the original dimensions are untouched.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.getSource();
    // If the inner dimensions don't match, it means we need to extract from the
    // source of the original broadcast and then broadcast the extracted value.
    // We also need to handle degenerated cases where the source is effectively
    // just a single scalar.
    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
    if (!lowerDimMatch && !isScalarSrc) {
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto splat = op.getVector().getDefiningOp<SplatOp>();
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
    return success();
  }
};
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
///
/// Example:
///     %1 = vector.extract_strided_slice %arg0 {
///            offsets = [0, 0, 0, 0, 0],
///            sizes = [1, 1, 1, 1, 8],
///            strides = [1, 1, 1, 1, 1]
///          } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
/// is rewritten to:
///     %0 = vector.extract %arg0[0, 0, 0, 0]
///            : vector<8xi8> from vector<8x1x1x2x8xi8>
///     %1 = vector.shape_cast %0
///            : vector<8xi8> to vector<1x1x1x1x8xi8>
///
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the desired slice rank. We walk
    // the dimensions from innermost out, and stop when the next slice dimension
    // is not full-size.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions. The shape_cast
    // op that we create below would take bad generic fallback patterns
    // (ShapeCastOpRewritePattern).
    while (sizes[numOffsets] == 1 &&
           numOffsets < static_cast<int>(sizes.size()) - 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};

} // namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
      context);
}
//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           /*optional*/ ArrayAttr inBoundsAttr) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 2. Builder that sets padding to zero and an empty mask (variant without
/// attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        inBoundsAttr);
}

/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding,
        /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, padding, inBounds);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
  // 0-D mask into a single-element 1-D mask.
  if (maskShape.empty())
    maskShape.push_back(1);

  SmallVector<bool> scalableDims =
      applyPermutationMap(invPermMap, vecType.getScalableDims());

  return VectorType::get(maskShape, i1Type, scalableDims);
}
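// Illustrative example (not from the original source): for a transfer of
// vector<4x8xf32> with permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>,
// compressing the unused dim and inverting the permutation maps the vector
// shape back to source-dimension order, so the inferred mask type is
// vector<8x4xi1>.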
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}

template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  // TODO: Be less conservative.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Idxs of non-bcast dims - used when analysing bcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }

    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims.
  // If all non-broadcast dims are "in bounds", then all bcast dims should be
  // "in bounds" as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}

template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();

  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();

  op.getMaskMutable().clear();
  return success();
}
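// Illustrative example (not from the original source): for
//
//   %r = vector.transfer_read %m[%c2, %c0], %pad
//        : memref<8x16xf32>, vector<4x16xf32>
//
// the constant index 2 plus the vector size 4 does not exceed the static dim
// size 8 (and 0 + 16 <= 16), so foldTransferInBoundsAttribute can rewrite the
// op with in_bounds = [true, true].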
/// Fold a transfer_read that reads the value just written by a matching
/// transfer_write (store-to-load forwarding), e.g.:
///
///   %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
///   %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///     : tensor<4x4xf32>, vector<1x4xf32>
///
/// is folded to %v0.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
namespace {
/// Store to load forwarding for transfer operations with permutation maps.
/// Even if the permutation maps are different we can still propagate the store
/// into the load if the size of the dimensions read and written match. Then we
/// can replace the transfer_read + transfer_write by vector.broadcast and
/// vector.transpose.
///
/// Example:
///
///   %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
///     {in_bounds = [true, true],
///      permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
///     vector<4x1xf32>, tensor<4x4x4xf32>
///   %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
///     {in_bounds = [true, true, true, true],
///      permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
///     tensor<4x4x4xf32>, vector<1x100x4x5xf32>
///
/// To:
///
///   %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
///   %r = vector.transpose %0, [3, 0, 2, 1] :
///     vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.hasOutOfBoundsDim() ||
        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
      return failure();
    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // TODO: Support cases where a dim is explicitly written but implicitly
    // read (i.e., a unit dim that is rank reduced).
    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
      return failure();
    if (readOp.getIndices() != defWrite.getIndices() ||
        readOp.getMask() != defWrite.getMask())
      return failure();
    Value vec = defWrite.getVector();
    // TODO: loop through the chain of transfer_write if we can prove that they
    // don't overlap with the transfer_read. This requires improving
    // `isDisjointTransferIndices` helper.
    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
    AffineMap map = readMap.compose(writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permutation to apply to go from the vector stored to the
    // vector read.
    SmallVector<unsigned> permutation;
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permutation to the
    // final shape we want.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    SmallVector<bool> broadcastScalableFlags(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
                                                     transposePerm);
    return success();
  }
};
} // namespace

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
4551 //===----------------------------------------------------------------------===//
4553 //===----------------------------------------------------------------------===//
/// 1. Builder with type inference.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

/// 2. Builder with type inference that sets an empty mask (variant with attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 3. Builder with type inference that sets an empty mask (variant without
///    attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder with type inference that sets an empty mask and sets permutation
///    map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
/// Fold:
///    %t1 = ...
///    %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// into:
///    %t0
///
/// The producer of t1 may or may not be DCE'd depending on whether it is a
/// block argument or has side effects.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity. Future: composition is minor identity.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getSource().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getSource());
  return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getSource() == write.getSource() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

/// Fold transfer_write write after read:
///    %t0 = ...
///    %v = vector.transfer_read %t0[%c0...] :
///      tensor<static_sizesxf32>, vector<static_sizesxf32>
///    %t1 = vector.transfer_write %v, %t0[%c0...] :
///      vector<static_sizesxf32>, tensor<static_sizesxf32>
/// into:
///    %t0
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();

  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getSource());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  if (succeeded(foldTransferFullMask(*this)))
    return success();
  return memref::foldMemRefCast(*this);
}
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
namespace {
/// Remove dead transfer write from the SSA chain so that it can be eliminated
/// by DCE:
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///
/// into:
///
///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///    : vector<1x4xf32>, tensor<4x4xf32>
///
/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
/// any other uses.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite =
        writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getSourceMutable().assign(defWrite.getSource());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely look
      // at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
/// overwritten and inserted into another tensor. After this rewrite, the
/// operations bufferize in-place since all of them work on the same slice.
///
/// For example:
///   %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
///        : vector<8x16xf32>, tensor<8x16xf32>
///   %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
///        : tensor<8x16xf32> to tensor<?x?xf32>
///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
///        : tensor<?x?xf32> into tensor<27x37xf32>
/// folds to
///   %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
///        : tensor<27x37xf32> to tensor<?x?xf32>
///   %1 = vector.transfer_write %vec, %0[%c0, %c0]
///        : vector<8x16xf32>, tensor<?x?xf32>
///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
///        : tensor<?x?xf32> into tensor<27x37xf32>
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
    // rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if tensor::TransferWriteOp has non-zero offset.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TransferWriteOp has non-zero offset");
    }

    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
    // Set all in_bounds to false and let the folder infer them.
    SmallVector<bool> newInBounds(vectorShape.size(), false);
    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
        extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
        insertOp.getMixedStrides());
    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(newInBounds));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};

} // namespace

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
  // need any strides limitations.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}
LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

OpFoldResult LoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}

OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}
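
// Illustrative summary of MaskedLoadFolder (SSA names and types below are
// hypothetical): with a statically all-true mask,
//   %r = vector.maskedload %base[%i], %all_true, %pass
//          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// becomes
//   %r = vector.load %base[%i] : memref<?xf32>, vector<16xf32>
// and with an all-false mask the op folds away to %pass directly.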
//===----------------------------------------------------------------------===//
// MaskedStoreOp
//===----------------------------------------------------------------------===//

LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}

namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
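
// Illustrative summary of MaskedStoreFolder (SSA names and types below are
// hypothetical): an all-true mask turns
//   vector.maskedstore %base[%i], %all_true, %val
//     : memref<?xf32>, vector<16xi1>, vector<16xf32>
// into the equivalent vector.store %val, %base[%i], while an all-false mask
// makes the store a no-op and the op is simply erased.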
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

/// Check if `indexVec` is a constant 1D vector of consecutive values
/// [0, 1, 2, ...].
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}

namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
/// maskedload. Only 1D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getIndices(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};
} // namespace

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}
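
// Illustrative effect of FoldContiguousGather (hypothetical SSA names): when
// the index vector is [0, 1, 2, ..., N-1] (e.g. produced by vector.step or an
// arith.constant dense<[0, 1, 2, 3]>), the gather reads a contiguous chunk, so
//   vector.gather %base[%i][%seq], %mask, %pass : ... into vector<4xf32>
// can be rewritten as
//   vector.maskedload %base[%i], %mask, %pass : ... into vector<4xf32>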
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
/// maskedstore. Only 1D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
    return success();
  }
};
} // namespace

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}
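
// Illustrative effect of FoldContiguousScatter (hypothetical SSA names): with
// a [0, 1, 2, ...] index vector the scatter writes a contiguous chunk, so
//   vector.scatter %base[%i][%seq], %mask, %val
// becomes
//   vector.maskedstore %base[%i], %mask, %val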
//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//

LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
} // namespace

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//

LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
} // namespace

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//

void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  unsigned rankA = a.size();
  unsigned rankB = b.size();
  assert(rankA < rankB);

  auto isOne = [](int64_t v) { return v == 1; };

  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
  // casted to a 0-d vector.
  if (rankA == 0 && llvm::all_of(b, isOne))
    return true;

  unsigned i = 0;
  unsigned j = 0;
  while (i < rankA && j < rankB) {
    int64_t dimA = a[i];
    int64_t dimB = 1;
    while (dimB < dimA && j < rankB)
      dimB *= b[j++];
    if (dimA != dimB)
      break;
    ++i;

    // Handle the case when trailing dimensions are of size 1.
    // Include them into the contiguous sequence.
    if (i < rankA && llvm::all_of(a.slice(i), isOne))
      i = rankA;
    if (j < rankB && llvm::all_of(b.slice(j), isOne))
      j = rankB;
  }

  return i == rankA && j == rankB;
}
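
// Worked example for isValidShapeCast (values are illustrative): with
// a = [4, 6] and b = [2, 2, 3, 2], the first element 4 = 2 * 2 and the second
// element 6 = 3 * 2, each covering a contiguous run of 'b', so the cast is
// valid. With a = [4, 6] and b = [2, 3, 2, 2] the scan stops at 2 * 3 = 6 != 4
// and the cast is rejected.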
static LogicalResult verifyVectorShapeCast(Operation *op,
                                           VectorType sourceVectorType,
                                           VectorType resultVectorType) {
  // Check that element type is the same.
  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
    return op->emitOpError("source/result vectors must have same element type");
  auto sourceShape = sourceVectorType.getShape();
  auto resultShape = resultVectorType.getShape();

  // Check that product of source dim sizes matches product of result dim sizes.
  int64_t sourceDimProduct = std::accumulate(
      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
  int64_t resultDimProduct = std::accumulate(
      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
  if (sourceDimProduct != resultDimProduct)
    return op->emitOpError("source/result number of elements must match");

  // Check the expanding/contracting rank cases.
  unsigned sourceRank = sourceVectorType.getRank();
  unsigned resultRank = resultVectorType.getRank();
  if (sourceRank < resultRank) {
    if (!isValidShapeCast(sourceShape, resultShape))
      return op->emitOpError("invalid shape cast");
  } else if (sourceRank > resultRank) {
    if (!isValidShapeCast(resultShape, sourceShape))
      return op->emitOpError("invalid shape cast");
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return op->emitOpError("different number of scalable dims at source (")
           << sourceNScalableDims << ") and result (" << resultNScalableDims
           << ")";
  sourceVectorType.getNumDynamicDims();

  return success();
}

LogicalResult ShapeCastOp::verify() {
  auto sourceVectorType =
      llvm::dyn_cast_or_null<VectorType>(getSource().getType());
  auto resultVectorType =
      llvm::dyn_cast_or_null<VectorType>(getResult().getType());

  // Check if source/result are of vector type.
  if (sourceVectorType && resultVectorType)
    return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);

  return success();
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  // No-op shape cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling shape casts.
  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Only allows valid transitive folding.
    VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
    VectorType resultType = llvm::cast<VectorType>(getResult().getType());
    if (srcType.getRank() < resultType.getRank()) {
      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
        return {};
    } else if (srcType.getRank() > resultType.getRank()) {
      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
        return {};
    } else {
      return {};
    }

    setOperand(otherOp.getSource());
    return getResult();
  }

  // Cancelling broadcast and shape cast ops.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == getType())
      return bcastOp.getSource();
  }

  return {};
}
namespace {

// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto constantOp =
        shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
    if (!constantOp)
      return failure();
    // Only handle splat for now.
    auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    auto newAttr =
        DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
                               dense.getSplatValue<Attribute>());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
    return success();
  }
};
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
///   vector<4x1x1xi1> --> vector<4xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  // TODO: Add support for 0-D vectors.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
/// Folds qualifying shape_cast(create_mask) into a new create_mask.
///
/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
/// dimension. If the input vector comes from `vector.create_mask` for which
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
/// to fold shape_cast into create_mask.
///
/// BEFORE:
///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
/// AFTER:
///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop.
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check every mask dim size to see whether it can be dropped.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check every mask dim size to see whether it can be dropped.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
/// This only applies when the shape of the broadcast source
/// 1. is a suffix of the shape of the result (i.e. when broadcast without
///    reshape is expressive enough to capture the result in a single op), or
/// 2. has the same element count as the shape cast result.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    ArrayRef<int64_t> broadcastSourceShape;
    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
      broadcastSourceShape = srcType.getShape();
    ArrayRef<int64_t> shapeCastTargetShape =
        shapeCastOp.getResultVectorType().getShape();

    // If `broadcastSourceShape` is a suffix of the result, we can just replace
    // with a broadcast to the final shape.
    if (broadcastSourceShape ==
        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, shapeCastOp.getResultVectorType(),
          broadcastOp.getSource());
      return success();
    }

    // Otherwise, if the final result has the same element count, we can replace
    // with a shape cast.
    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
      if (srcType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    return failure();
  }
};
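
// Illustrative examples for ShapeCastBroadcastFolder (shapes hypothetical):
//  1. Suffix case: a broadcast of vector<3x4xf32> to vector<2x3x4xf32>
//     followed by a shape_cast to vector<1x2x3x4xf32> keeps [3, 4] as a shape
//     suffix, so the pair becomes a single broadcast to vector<1x2x3x4xf32>.
//  2. Equal element count: a broadcast of vector<4xf32> to vector<1x4xf32>
//     followed by a shape_cast to vector<2x2xf32> has matching element counts
//     (4), so the pair becomes a shape_cast of the broadcast source itself.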
} // namespace

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
              ShapeCastBroadcastFolder>(context);
}
//===----------------------------------------------------------------------===//
// BitCastOp
//===----------------------------------------------------------------------===//

LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower width element.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}
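
// Worked example for the splat fold above (values illustrative): bitcasting a
// splat vector<4xf16> of 1.0 (bit pattern 0x3C00) to vector<2xf32> packs two
// copies of the 16-bit pattern into each 32-bit lane, i.e. 0x3C003C00, and the
// fold materializes that constant directly instead of keeping the bitcast.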
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto attr =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
    if (attr.isSplat())
      return attr.reshape(getResultVectorType());

  // Eliminate identity transpose ops. This happens when the dimensions of the
  // input vector remain in their original order after the transpose operation.
  ArrayRef<int64_t> perm = getPermutation();

  // Check if the permutation of the dimensions contains sequential values:
  // {0, 1, 2, ...}.
  for (int64_t i = 0, e = perm.size(); i < e; i++) {
    if (perm[i] != i)
      return {};
  }

  return getVector();
}

LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
namespace {

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};
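
// Illustrative composition (values hypothetical): transposing with [1, 0, 2]
// and then with [2, 0, 1] yields result[i] = permutation1[permutation2[i]],
// i.e. [perm1[2], perm1[0], perm1[1]] = [2, 1, 0], so the two transposes
// collapse into a single vector.transpose with permutation [2, 1, 0].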
// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
      return success();
    }

    return failure();
  }
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
    return success();
  }
};

/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
              TransposeFolder, FoldTransposeSplat>(context);
}
//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}

bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}
LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
///
/// Ex 1:
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
/// Becomes:
///   vector.constant_mask [3, 2] : vector<4x3xi1>
///
/// Ex 2:
///   %c_neg_1 = arith.constant -1 : index
///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
/// Becomes:
///   vector.constant_mask [0] : vector<[8]xi1>
///
/// Ex 3:
///   %c8 = arith.constant 8 : index
///   %c16 = arith.constant 16 : index
///   %0 = vector.vscale
///   %1 = arith.muli %0, %c16 : index
///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
/// Becomes:
///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: Rank zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the `constantDims` (for the ConstantMaskOp).
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value.
        // If the mask dim is non-scalable this can be any value.
        // If the mask dim is scalable only zero (all-false) is supported.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale).
        // Must be all-true to fold to a ConstantMask.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of dim sizes is zero, set all dims to zero.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
//===----------------------------------------------------------------------===//
// MaskOp
//===----------------------------------------------------------------------===//
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse all the operands.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  // Optional passthru operand.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the op region.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded())
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();

  return success();
}
void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // Keep the default yield terminator if the number of masked operations is
  // not the expected one. This case will trigger a verification failure.
  Block &block = region.front();
  if (block.getOperations().size() != 2)
    return;

  // Replace the default yield terminator with a new one that returns the
  // results from the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  Operation *oldYieldOp = &block.back();
  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");

  // Empty vector.mask op.
  if (maskedOp == oldYieldOp)
    return;

  opBuilder.setInsertionPoint(oldYieldOp);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
  oldYieldOp->dropAllReferences();
  oldYieldOp->erase();
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds vector.mask ops with an all-true mask.
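/// For illustration (example IR, not tied to a specific test), this turns
///   %0 = vector.mask %all_true {
///          vector.reduction <add>, %v : vector<8xf32> into f32
///        } : vector<8xi1> -> f32
/// into the equivalent unmasked
///   %0 = vector.reduction <add>, %v : vector<8xf32> into f32
/// when the mask is statically known to be all-true.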
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (isEmpty())
    return failure();

  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move the maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
// Elides empty vector.mask operations with or without return values.
// Propagates the values yielded by the vector.yield terminator, if any, or
// erases the op otherwise.
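//
// For illustration (example IR), an empty mask op such as
//   %0 = vector.mask %m { vector.yield %a : vector<8xf32> }
//          : vector<8xi1> -> vector<8xf32>
// is replaced by %a, while an empty, result-less mask op is simply erased.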
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
    if (maskingOp.getMaskableOp())
      return failure();

    if (!maskOp.isEmpty())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    if (terminator.getNumOperands() == 0)
      rewriter.eraseOp(maskOp);
    else
      rewriter.replaceOp(maskOp, terminator.getOperands());

    return success();
  }
};
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ElideEmptyMaskOp>(context);
}
// MaskingOpInterface definitions.

/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}
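// Usage sketch (illustrative; `b`, `loc`, `vec`, and `acc` are assumed to be
// in scope):
//   Value sum = mlir::vector::makeArithReduction(
//       b, loc, CombiningKind::ADD, vec, acc, /*fastmath=*/nullptr,
//       /*mask=*/Value());
// For integer operands ADD folds to arith.addi; for floats, to arith.addf with
// the given fastmath flags. When a mask is provided, selectPassthru keeps the
// accumulator value on masked-off lanes.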
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
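/// For illustration (example IR), after moving `maskableOp` the region body
/// has the shape:
///   {
///     %r = vector.reduction <add>, %v : vector<8xf32> into f32
///     vector.yield %r : f32
///   }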
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Create a block and move the op to that block.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
}
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns
/// the maskable operation itself.
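/// Usage sketch (illustrative; `builder`, `redOp`, and `maskVal` are assumed
/// to be in scope):
///   Operation *maybeMasked = maskOperation(builder, redOp, maskVal);
/// If `maskVal` is null, `redOp` is returned unmodified; otherwise the result
/// is the enclosing vector.mask operation.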
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return builder.create<MaskOp>(maskableOp->getLoc(),
                                  maskableOp->getResultTypes(), mask, passthru,
                                  maskableOp, createMaskOpRegion);
  return builder.create<MaskOp>(maskableOp->getLoc(),
                                maskableOp->getResultTypes(), mask, maskableOp,
                                createMaskOpRegion);
}
/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is
/// used to propagate the pass-thru value of vector.mask or for cases where
/// only the pass-thru value propagation is needed. VP intrinsics do not
/// support pass-thru values and every masked-out lane is set to poison. LLVM
/// backends are usually able to match op + select patterns and fold them into
/// native target instructions.
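/// For illustration (example IR), with a vector mask this emits:
///   %sel = arith.select %mask, %newValue, %passthru
///            : vector<8xi1>, vector<8xf32>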
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
                                         mask, newValue, passthru);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"