[mlir][Vector] Fix `vector.shuffle` folder for poison indices (#124863)
[llvm-project.git] / mlir / lib / Dialect / Vector / IR / VectorOps.cpp
blob93f89eda2da5a6b0f00b921f020e725def6d1ec6
1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Arith/Utils/Utils.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/UB/IR/UBOps.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinAttributes.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/DialectImplementation.h"
32 #include "mlir/IR/IRMapping.h"
33 #include "mlir/IR/OpImplementation.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Interfaces/SubsetOpInterface.h"
37 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
38 #include "mlir/Support/LLVM.h"
39 #include "mlir/Transforms/InliningUtils.h"
40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 #include "llvm/ADT/bit.h"
47 #include <cassert>
48 #include <cstdint>
49 #include <numeric>
51 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 // Pull in all enum type and utility function definitions.
53 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
55 using namespace mlir;
56 using namespace mlir::vector;
58 /// Helper enum to classify mask value.
59 enum class MaskFormat {
60 AllTrue = 0,
61 AllFalse = 1,
62 Unknown = 2,
65 /// Helper method to classify a mask value. Currently, the method
66 /// looks "under the hood" of a constant value with dense attributes
67 /// and a constant mask operation (since the client may be called at
68 /// various stages during progressive lowering).
69 static MaskFormat getMaskFormat(Value mask) {
70 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
71 // Inspect constant dense values. We count up for bits that
72 // are set, count down for bits that are cleared, and bail
73 // when a mix is detected.
74 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75 int64_t val = 0;
76 for (bool b : denseElts.getValues<bool>())
77 if (b && val >= 0)
78 val++;
79 else if (!b && val <= 0)
80 val--;
81 else
82 return MaskFormat::Unknown;
83 if (val > 0)
84 return MaskFormat::AllTrue;
85 if (val < 0)
86 return MaskFormat::AllFalse;
88 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
89 // Inspect constant mask index. If the index exceeds the
90 // dimension size, all bits are set. If the index is zero
91 // or less, no bits are set.
92 ArrayRef<int64_t> masks = m.getMaskDimSizes();
93 auto shape = m.getType().getShape();
94 bool allTrue = true;
95 bool allFalse = true;
96 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
97 if (maskIdx < dimSize)
98 allTrue = false;
99 if (maskIdx > 0)
100 allFalse = false;
102 if (allTrue)
103 return MaskFormat::AllTrue;
104 if (allFalse)
105 return MaskFormat::AllFalse;
106 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
107 // Finds all-false create_masks. An all-true create_mask requires all
108 // dims to be constants, so that'll be folded to a constant_mask, then
109 // detected in the constant_mask case.
110 auto maskOperands = m.getOperands();
111 for (Value operand : maskOperands) {
112 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113 int64_t dimSize =
114 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
115 if (dimSize <= 0)
116 return MaskFormat::AllFalse;
119 return MaskFormat::Unknown;
121 return MaskFormat::Unknown;
124 /// Default callback to build a region with a 'vector.yield' terminator with no
125 /// arguments.
126 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
127 builder.create<vector::YieldOp>(loc);
130 // Helper for verifying combining kinds in contractions and reductions.
131 static bool isSupportedCombiningKind(CombiningKind combiningKind,
132 Type elementType) {
133 switch (combiningKind) {
134 case CombiningKind::ADD:
135 case CombiningKind::MUL:
136 return elementType.isIntOrIndexOrFloat();
137 case CombiningKind::MINUI:
138 case CombiningKind::MINSI:
139 case CombiningKind::MAXUI:
140 case CombiningKind::MAXSI:
141 case CombiningKind::AND:
142 case CombiningKind::OR:
143 case CombiningKind::XOR:
144 return elementType.isIntOrIndex();
145 case CombiningKind::MINNUMF:
146 case CombiningKind::MAXNUMF:
147 case CombiningKind::MINIMUMF:
148 case CombiningKind::MAXIMUMF:
149 return llvm::isa<FloatType>(elementType);
151 return false;
154 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
155 VectorType vectorType) {
156 int64_t elementVectorRank = 0;
157 VectorType elementVectorType =
158 llvm::dyn_cast<VectorType>(shapedType.getElementType());
159 if (elementVectorType)
160 elementVectorRank += elementVectorType.getRank();
161 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
162 // TODO: replace once we have 0-d vectors.
163 if (shapedType.getRank() == 0 &&
164 vectorType.getShape() == ArrayRef<int64_t>{1})
165 return AffineMap::get(
166 /*numDims=*/0, /*numSymbols=*/0,
167 getAffineConstantExpr(0, shapedType.getContext()));
168 return AffineMap::getMinorIdentityMap(
169 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
170 shapedType.getContext());
173 /// Check if `write` is of a constant splat and the masked `read` is padded with
174 /// the same splat value -- meaning it could be the same value as the initial
175 /// constant splat.
176 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
177 vector::TransferReadOp read) {
178 auto readMask = read.getMask();
179 auto writeMask = write.getMask();
180 // Check if the masks are consistent. The splat value could be the same if the
181 // read is masked (and padded with the splat value), and the write is unmasked
182 // or has the same mask. Note this does not allow the case where the write is
183 // masked and the read is unmasked, as then the read could be of more elements
184 // than the write (which may not be the same value).
185 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
186 if (!couldBeSameSplat)
187 return false;
188 // Check for constant splat (as the source of the write).
189 DenseElementsAttr splatAttr;
190 if (!matchPattern(write.getVector(),
191 m_Constant<DenseElementsAttr>(&splatAttr)) ||
192 !splatAttr.isSplat()) {
193 return false;
195 // The padding of the read and the constant splat value must be the same.
196 Attribute padAttr;
197 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
198 return false;
199 return padAttr == splatAttr.getSplatValue<Attribute>();
202 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
203 vector::TransferReadOp read) {
204 return !defWrite.hasOutOfBoundsDim() &&
205 defWrite.getIndices() == read.getIndices() &&
206 defWrite.getVectorType() == read.getVectorType() &&
207 defWrite.getPermutationMap() == read.getPermutationMap() &&
208 ((!defWrite.getMask() && !read.getMask()) ||
209 isSplatWriteConsistentWithMaskedRead(defWrite, read));
212 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
213 vector::TransferWriteOp priorWrite) {
214 return priorWrite.getIndices() == write.getIndices() &&
215 priorWrite.getMask() == write.getMask() &&
216 priorWrite.getVectorType() == write.getVectorType() &&
217 priorWrite.getPermutationMap() == write.getPermutationMap();
220 bool mlir::vector::isDisjointTransferIndices(
221 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
222 bool testDynamicValueUsingBounds) {
223 // For simplicity only look at transfer of same type.
224 if (transferA.getVectorType() != transferB.getVectorType())
225 return false;
226 unsigned rankOffset = transferA.getLeadingShapedRank();
227 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
228 Value indexA = transferA.getIndices()[i];
229 Value indexB = transferB.getIndices()[i];
230 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
231 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
233 if (i < rankOffset) {
234 // For leading dimensions, if we can prove that index are different we
235 // know we are accessing disjoint slices.
236 if (cstIndexA.has_value() && cstIndexB.has_value()) {
237 if (*cstIndexA != *cstIndexB)
238 return true;
239 continue;
241 if (testDynamicValueUsingBounds) {
242 // First try to see if we can fully compose and simplify the affine
243 // expression as a fast track.
244 FailureOr<uint64_t> delta =
245 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
246 if (succeeded(delta) && *delta != 0)
247 return true;
249 FailureOr<bool> testEqual =
250 ValueBoundsConstraintSet::areEqual(indexA, indexB);
251 if (succeeded(testEqual) && !testEqual.value())
252 return true;
254 } else {
255 // For this dimension, we slice a part of the memref we need to make sure
256 // the intervals accessed don't overlap.
257 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
258 if (cstIndexA.has_value() && cstIndexB.has_value()) {
259 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
260 if (distance >= vectorDim)
261 return true;
262 continue;
264 if (testDynamicValueUsingBounds) {
265 // First try to see if we can fully compose and simplify the affine
266 // expression as a fast track.
267 FailureOr<int64_t> delta =
268 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
269 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
270 return true;
272 FailureOr<int64_t> computeDelta =
273 ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
274 if (succeeded(computeDelta)) {
275 if (std::abs(computeDelta.value()) >= vectorDim)
276 return true;
281 return false;
284 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
285 VectorTransferOpInterface transferB,
286 bool testDynamicValueUsingBounds) {
287 if (transferA.getSource() != transferB.getSource())
288 return false;
289 return isDisjointTransferIndices(transferA, transferB,
290 testDynamicValueUsingBounds);
293 // Helper to iterate over n-D vector slice elements. Calculate the next
294 // `position` in the n-D vector of size `shape`, applying an offset `offsets`.
295 // Modifies the `position` in place. Returns a failure when `position` becomes
296 // the end position.
297 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
298 ArrayRef<int64_t> shape,
299 ArrayRef<int64_t> offsets) {
300 for (auto [posInDim, dimSize, offsetInDim] :
301 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
302 ++posInDim;
303 if (posInDim < dimSize + offsetInDim)
304 return success();
306 // Carry the overflow to the next loop iteration.
307 posInDim = offsetInDim;
310 return failure();
313 /// Returns the integer numbers in `values`. `values` are expected to be
314 /// constant operations.
315 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
316 SmallVector<int64_t> ints;
317 llvm::transform(values, std::back_inserter(ints), [](Value value) {
318 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
319 assert(constOp && "Unexpected non-constant index");
320 return constOp.value();
322 return ints;
325 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
326 /// be constant operations.
327 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
328 SmallVector<int64_t> ints;
329 llvm::transform(
330 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
331 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
332 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
334 return ints;
337 /// Convert `foldResults` into Values. Integer attributes are converted to
338 /// constant op.
339 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
340 ArrayRef<OpFoldResult> foldResults) {
341 SmallVector<Value> values;
342 llvm::transform(foldResults, std::back_inserter(values),
343 [&](OpFoldResult foldResult) {
344 if (auto attr = foldResult.dyn_cast<Attribute>())
345 return builder
346 .create<arith::ConstantIndexOp>(
347 loc, cast<IntegerAttr>(attr).getInt())
348 .getResult();
350 return cast<Value>(foldResult);
352 return values;
355 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
356 if (value.getDefiningOp<vector::VectorScaleOp>())
357 return 1;
358 auto mul = value.getDefiningOp<arith::MulIOp>();
359 if (!mul)
360 return {};
361 auto lhs = mul.getLhs();
362 auto rhs = mul.getRhs();
363 if (lhs.getDefiningOp<vector::VectorScaleOp>())
364 return getConstantIntValue(rhs);
365 if (rhs.getDefiningOp<vector::VectorScaleOp>())
366 return getConstantIntValue(lhs);
367 return {};
370 //===----------------------------------------------------------------------===//
371 // CombiningKindAttr
372 //===----------------------------------------------------------------------===//
374 namespace mlir {
375 namespace vector {
376 namespace detail {
377 struct BitmaskEnumStorage : public AttributeStorage {
378 using KeyTy = uint64_t;
380 BitmaskEnumStorage(KeyTy val) : value(val) {}
382 bool operator==(const KeyTy &key) const { return value == key; }
384 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
385 const KeyTy &key) {
386 return new (allocator.allocate<BitmaskEnumStorage>())
387 BitmaskEnumStorage(key);
390 KeyTy value = 0;
392 } // namespace detail
393 } // namespace vector
394 } // namespace mlir
396 //===----------------------------------------------------------------------===//
397 // VectorDialect
398 //===----------------------------------------------------------------------===//
400 namespace {
401 /// This class defines the interface for handling inlining with vector dialect
402 /// operations.
403 struct VectorInlinerInterface : public DialectInlinerInterface {
404 using DialectInlinerInterface::DialectInlinerInterface;
406 /// All vector dialect ops can be inlined.
407 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
408 return true;
411 } // namespace
413 void VectorDialect::initialize() {
414 addAttributes<
415 #define GET_ATTRDEF_LIST
416 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
417 >();
419 addOperations<
420 #define GET_OP_LIST
421 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
422 >();
424 addInterfaces<VectorInlinerInterface>();
426 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
427 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
428 YieldOp>();
429 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
430 TransferWriteOp>();
431 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
432 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
435 /// Materialize a single constant operation from a given attribute value with
436 /// the desired resultant type.
437 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
438 Attribute value, Type type,
439 Location loc) {
440 return arith::ConstantOp::materialize(builder, value, type, loc);
443 IntegerType vector::getVectorSubscriptType(Builder &builder) {
444 return builder.getIntegerType(64);
447 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
448 ArrayRef<int64_t> values) {
449 return builder.getI64ArrayAttr(values);
452 //===----------------------------------------------------------------------===//
453 // MultiDimReductionOp
454 //===----------------------------------------------------------------------===//
456 void vector::MultiDimReductionOp::build(OpBuilder &builder,
457 OperationState &result, Value source,
458 Value acc, ArrayRef<bool> reductionMask,
459 CombiningKind kind) {
460 SmallVector<int64_t> reductionDims;
461 for (const auto &en : llvm::enumerate(reductionMask))
462 if (en.value())
463 reductionDims.push_back(en.index());
464 build(builder, result, kind, source, acc, reductionDims);
467 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
468 // Single parallel dim, this is a noop.
469 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
470 return getSource();
471 return {};
474 std::optional<SmallVector<int64_t, 4>>
475 MultiDimReductionOp::getShapeForUnroll() {
476 return llvm::to_vector<4>(getSourceVectorType().getShape());
479 LogicalResult MultiDimReductionOp::verify() {
480 SmallVector<int64_t> targetShape;
481 SmallVector<bool> scalableDims;
482 Type inferredReturnType;
483 auto sourceScalableDims = getSourceVectorType().getScalableDims();
484 for (auto [dimIdx, dimSize] :
485 llvm::enumerate(getSourceVectorType().getShape()))
486 if (!llvm::any_of(getReductionDims(),
487 [dimIdx = dimIdx](int64_t reductionDimIdx) {
488 return reductionDimIdx == static_cast<int64_t>(dimIdx);
489 })) {
490 targetShape.push_back(dimSize);
491 scalableDims.push_back(sourceScalableDims[dimIdx]);
493 // TODO: update to also allow 0-d vectors when available.
494 if (targetShape.empty())
495 inferredReturnType = getSourceVectorType().getElementType();
496 else
497 inferredReturnType = VectorType::get(
498 targetShape, getSourceVectorType().getElementType(), scalableDims);
499 if (getType() != inferredReturnType)
500 return emitOpError() << "destination type " << getType()
501 << " is incompatible with source type "
502 << getSourceVectorType();
504 return success();
507 /// Returns the mask type expected by this operation.
508 Type MultiDimReductionOp::getExpectedMaskType() {
509 auto vecType = getSourceVectorType();
510 return VectorType::get(vecType.getShape(),
511 IntegerType::get(vecType.getContext(), /*width=*/1),
512 vecType.getScalableDims());
515 namespace {
516 // Only unit dimensions that are being reduced are folded. If the dimension is
517 // unit, but not reduced, it is not folded, thereby keeping the output type the
518 // same. If not all dimensions which are reduced are of unit dimension, this
519 // transformation does nothing. This is just a generalization of
520 // ElideSingleElementReduction for ReduceOp.
521 struct ElideUnitDimsInMultiDimReduction
522 : public OpRewritePattern<MultiDimReductionOp> {
523 using OpRewritePattern::OpRewritePattern;
525 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
526 PatternRewriter &rewriter) const override {
527 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
528 for (const auto &dim : enumerate(shape)) {
529 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
530 return failure();
533 // Vector mask setup.
534 OpBuilder::InsertionGuard guard(rewriter);
535 Operation *rootOp;
536 Value mask;
537 if (reductionOp.isMasked()) {
538 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
539 rootOp = reductionOp.getMaskingOp();
540 mask = reductionOp.getMaskingOp().getMask();
541 } else {
542 rootOp = reductionOp;
545 Location loc = reductionOp.getLoc();
546 Value acc = reductionOp.getAcc();
547 Value cast;
548 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
549 if (mask) {
550 VectorType newMaskType =
551 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
552 dstVecType.getScalableDims());
553 mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
555 cast = rewriter.create<vector::ShapeCastOp>(
556 loc, reductionOp.getDestType(), reductionOp.getSource());
557 } else {
558 // This means we are reducing all the dimensions, and all reduction
559 // dimensions are of size 1. So a simple extraction would do.
560 SmallVector<int64_t> zeroIdx(shape.size(), 0);
561 if (mask)
562 mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
563 cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
564 zeroIdx);
567 Value result =
568 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
569 cast, /*fastmath=*/nullptr, mask);
570 rewriter.replaceOp(rootOp, result);
571 return success();
574 } // namespace
576 void MultiDimReductionOp::getCanonicalizationPatterns(
577 RewritePatternSet &results, MLIRContext *context) {
578 results.add<ElideUnitDimsInMultiDimReduction>(context);
581 //===----------------------------------------------------------------------===//
582 // ReductionOp
583 //===----------------------------------------------------------------------===//
585 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
586 CombiningKind kind, Value vector,
587 arith::FastMathFlags fastMathFlags) {
588 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
591 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
592 CombiningKind kind, Value vector, Value acc,
593 arith::FastMathFlags fastMathFlags) {
594 build(builder, result,
595 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
596 acc, fastMathFlags);
599 LogicalResult ReductionOp::verify() {
600 // Verify for 0-D and 1-D vector.
601 int64_t rank = getSourceVectorType().getRank();
602 if (rank > 1)
603 return emitOpError("unsupported reduction rank: ") << rank;
605 // Verify supported reduction kind.
606 Type eltType = getDest().getType();
607 if (!isSupportedCombiningKind(getKind(), eltType))
608 return emitOpError("unsupported reduction type '")
609 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
610 << "'";
612 return success();
615 // MaskableOpInterface methods.
617 /// Returns the mask type expected by this operation.
618 Type ReductionOp::getExpectedMaskType() {
619 auto vecType = getSourceVectorType();
620 return VectorType::get(vecType.getShape(),
621 IntegerType::get(vecType.getContext(), /*width=*/1),
622 vecType.getScalableDims());
625 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
626 OpBuilder &builder, Location loc,
627 Value vector) {
628 switch (op) {
629 case arith::AtomicRMWKind::addf:
630 case arith::AtomicRMWKind::addi:
631 return builder.create<vector::ReductionOp>(vector.getLoc(),
632 CombiningKind::ADD, vector);
633 case arith::AtomicRMWKind::mulf:
634 case arith::AtomicRMWKind::muli:
635 return builder.create<vector::ReductionOp>(vector.getLoc(),
636 CombiningKind::MUL, vector);
637 case arith::AtomicRMWKind::minimumf:
638 return builder.create<vector::ReductionOp>(vector.getLoc(),
639 CombiningKind::MINIMUMF, vector);
640 case arith::AtomicRMWKind::mins:
641 return builder.create<vector::ReductionOp>(vector.getLoc(),
642 CombiningKind::MINSI, vector);
643 case arith::AtomicRMWKind::minu:
644 return builder.create<vector::ReductionOp>(vector.getLoc(),
645 CombiningKind::MINUI, vector);
646 case arith::AtomicRMWKind::maximumf:
647 return builder.create<vector::ReductionOp>(vector.getLoc(),
648 CombiningKind::MAXIMUMF, vector);
649 case arith::AtomicRMWKind::maxs:
650 return builder.create<vector::ReductionOp>(vector.getLoc(),
651 CombiningKind::MAXSI, vector);
652 case arith::AtomicRMWKind::maxu:
653 return builder.create<vector::ReductionOp>(vector.getLoc(),
654 CombiningKind::MAXUI, vector);
655 case arith::AtomicRMWKind::andi:
656 return builder.create<vector::ReductionOp>(vector.getLoc(),
657 CombiningKind::AND, vector);
658 case arith::AtomicRMWKind::ori:
659 return builder.create<vector::ReductionOp>(vector.getLoc(),
660 CombiningKind::OR, vector);
661 // TODO: Add remaining reduction operations.
662 default:
663 (void)emitOptionalError(loc, "Reduction operation type not supported");
664 break;
666 return nullptr;
669 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
670 return llvm::to_vector<4>(getSourceVectorType().getShape());
673 namespace {
674 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
675 using OpRewritePattern::OpRewritePattern;
677 LogicalResult matchAndRewrite(ReductionOp reductionOp,
678 PatternRewriter &rewriter) const override {
679 // Vector mask setup.
680 OpBuilder::InsertionGuard guard(rewriter);
681 auto maskableOp =
682 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
683 Operation *rootOp;
684 Value mask;
685 if (maskableOp.isMasked()) {
686 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
687 rootOp = maskableOp.getMaskingOp();
688 mask = maskableOp.getMaskingOp().getMask();
689 } else {
690 rootOp = reductionOp;
693 auto vectorType = reductionOp.getSourceVectorType();
694 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
695 return failure();
697 Location loc = reductionOp.getLoc();
698 Value result;
699 if (vectorType.getRank() == 0) {
700 if (mask)
701 mask = rewriter.create<ExtractElementOp>(loc, mask);
702 result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
703 } else {
704 if (mask)
705 mask = rewriter.create<ExtractOp>(loc, mask, 0);
706 result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
709 if (Value acc = reductionOp.getAcc())
710 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
711 result, acc,
712 reductionOp.getFastmathAttr(), mask);
714 rewriter.replaceOp(rootOp, result);
715 return success();
718 } // namespace
720 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
721 MLIRContext *context) {
722 results.add<ElideSingleElementReduction>(context);
725 //===----------------------------------------------------------------------===//
726 // ContractionOp
727 //===----------------------------------------------------------------------===//
729 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
730 Value lhs, Value rhs, Value acc,
731 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
732 ArrayRef<IteratorType> iteratorTypes) {
733 result.addOperands({lhs, rhs, acc});
734 result.addTypes(acc.getType());
735 result.addAttribute(
736 getIndexingMapsAttrName(result.name),
737 builder.getAffineMapArrayAttr(
738 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
739 result.addAttribute(
740 getIteratorTypesAttrName(result.name),
741 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
742 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
743 return IteratorTypeAttr::get(builder.getContext(), t);
744 }))));
747 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
748 Value lhs, Value rhs, Value acc,
749 ArrayAttr indexingMaps,
750 ArrayAttr iteratorTypes) {
751 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
752 ContractionOp::getDefaultKind());
755 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
756 Value lhs, Value rhs, Value acc,
757 ArrayAttr indexingMaps,
758 ArrayAttr iteratorTypes, CombiningKind kind) {
759 result.addOperands({lhs, rhs, acc});
760 result.addTypes(acc.getType());
761 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
762 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
763 result.addAttribute(getKindAttrName(result.name),
764 CombiningKindAttr::get(builder.getContext(), kind));
767 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
768 OpAsmParser::UnresolvedOperand lhsInfo;
769 OpAsmParser::UnresolvedOperand rhsInfo;
770 OpAsmParser::UnresolvedOperand accInfo;
771 SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
772 SmallVector<Type, 2> types;
773 Type resultType;
774 auto loc = parser.getCurrentLocation();
775 DictionaryAttr dictAttr;
776 // TODO: Unify linalg op attribute parsing.
777 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
778 parser.parseComma() || parser.parseOperand(rhsInfo) ||
779 parser.parseComma() || parser.parseOperand(accInfo) ||
780 parser.parseTrailingOperandList(masksInfo) ||
781 parser.parseOptionalAttrDict(result.attributes) ||
782 parser.parseColonTypeList(types) ||
783 parser.parseKeywordType("into", resultType) ||
784 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
785 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
786 parser.resolveOperand(accInfo, resultType, result.operands) ||
787 parser.addTypeToList(resultType, result.types))
788 return failure();
789 result.attributes.append(dictAttr.getValue().begin(),
790 dictAttr.getValue().end());
792 // Convert array of string into an array of IteratyType enums. This is needed,
793 // because tests still use the old format when 'iterator_types' attribute is
794 // represented as an array of strings.
795 // TODO: Remove this conversion once tests are fixed.
796 ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
797 result.attributes.get(getIteratorTypesAttrName(result.name)));
799 SmallVector<Attribute> iteratorTypeAttrs;
801 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
802 auto maybeIteratorType = symbolizeIteratorType(s);
803 if (!maybeIteratorType.has_value())
804 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
806 iteratorTypeAttrs.push_back(
807 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
809 result.attributes.set(getIteratorTypesAttrName(result.name),
810 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
812 if (!result.attributes.get(getKindAttrName(result.name))) {
813 result.addAttribute(
814 getKindAttrName(result.name),
815 CombiningKindAttr::get(result.getContext(),
816 ContractionOp::getDefaultKind()));
818 if (masksInfo.empty())
819 return success();
820 if (masksInfo.size() != 2)
821 return parser.emitError(parser.getNameLoc(),
822 "expected zero or exactly 2 vector mask operands");
823 auto lhsType = llvm::cast<VectorType>(types[0]);
824 auto rhsType = llvm::cast<VectorType>(types[1]);
825 auto maskElementType = parser.getBuilder().getI1Type();
826 std::array<VectorType, 2> maskTypes = {
827 VectorType::Builder(lhsType).setElementType(maskElementType),
828 VectorType::Builder(rhsType).setElementType(maskElementType)};
829 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
830 return failure();
831 return success();
834 void ContractionOp::print(OpAsmPrinter &p) {
835 // TODO: Unify printing code with linalg ops.
836 auto attrNames = getTraitAttrNames();
837 llvm::StringSet<> traitAttrsSet;
838 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
839 SmallVector<NamedAttribute, 8> attrs;
840 for (auto attr : (*this)->getAttrs()) {
841 if (attr.getName() == getIteratorTypesAttrName()) {
842 auto iteratorTypes =
843 llvm::cast<ArrayAttr>(attr.getValue())
844 .getAsValueRange<IteratorTypeAttr, IteratorType>();
845 // Convert IteratorType enums into the string representation. This is
846 // needed, because tests still use the old format when 'iterator_types'
847 // attribute is represented as an array of strings.
848 // TODO: Remove this conversion once tests are fixed.
849 SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
850 llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
851 return StringAttr::get(getContext(), stringifyIteratorType(t));
852 }));
854 attrs.emplace_back(getIteratorTypesAttrName(),
855 ArrayAttr::get(getContext(), iteratorTypeNames));
856 } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
857 attrs.push_back(attr);
860 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
861 p << " " << dictAttr << " " << getLhs() << ", ";
862 p << getRhs() << ", " << getAcc();
864 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
865 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
866 << getResultType();
869 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
870 const std::vector<std::pair<int64_t, int64_t>> &map) {
871 for (auto &dimPair : map) {
872 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
873 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
874 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
875 return false;
877 return true;
880 static LogicalResult verifyOutputShape(
881 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
882 Type resType,
883 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
884 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
885 DenseSet<int64_t> lhsContractingDimSet;
886 DenseSet<int64_t> rhsContractingDimSet;
887 for (auto &dimPair : contractingDimMap) {
888 lhsContractingDimSet.insert(dimPair.first);
889 rhsContractingDimSet.insert(dimPair.second);
891 DenseSet<int64_t> rhsBatchDimSet;
892 for (auto &dimPair : batchDimMap)
893 rhsBatchDimSet.insert(dimPair.second);
895 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
896 SmallVector<int64_t, 4> expectedResultDims;
897 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
898 if (lhsContractingDimSet.count(i) > 0)
899 continue;
900 expectedResultDims.push_back(lhsType.getDimSize(i));
903 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
904 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
905 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
906 continue;
907 expectedResultDims.push_back(rhsType.getDimSize(i));
910 // Verify 'expectedResultDims'.
911 if (expectedResultDims.empty()) {
912 // No batch or free dimension implies a scalar result.
913 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
914 return op.emitOpError("invalid accumulator/result vector shape");
915 } else {
916 // At least one batch or free dimension implies a vector result.
917 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
918 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
919 if (!resVectorType || !accVectorType)
920 return op.emitOpError("invalid accumulator/result vector shape");
922 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
923 // types fully define the result vector type. This assumes the affine maps
924 // are well-formed, which must have been verified already.
925 MLIRContext *ctx = op.getContext();
926 AffineMap lhsMap = op.getIndexingMapsArray()[0];
927 AffineMap rhsMap = op.getIndexingMapsArray()[1];
928 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
929 return op.emitOpError(
930 "expected all dimensions to be either a LHS or a RHS dimension");
931 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
932 for (auto pair :
933 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
934 VectorType v = pair.first;
935 auto map = pair.second;
936 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
937 unsigned pos = map.getDimPosition(idx);
938 if (!extents[pos])
939 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
942 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
943 return op.emitOpError("expected all dimensions to get an extent as "
944 "either a LHS or a RHS dimension");
946 AffineMap resMap = op.getIndexingMapsArray()[2];
947 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
948 /*symbolCount=*/0, extents, ctx);
949 // Compose the resMap with the extentsMap, which is a constant map.
950 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
951 assert(llvm::all_of(expectedMap.getResults(),
952 llvm::IsaPred<AffineConstantExpr>) &&
953 "expected constant extent along all dimensions.");
954 // Extract the expected shape and build the type.
955 auto expectedShape = llvm::to_vector<4>(
956 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
957 return cast<AffineConstantExpr>(e).getValue();
958 }));
959 auto expected =
960 VectorType::get(expectedShape, resVectorType.getElementType(),
961 resVectorType.getScalableDims());
962 if (resVectorType != expected || accVectorType != expected)
963 return op.emitOpError(
964 "invalid accumulator/result vector shape, expected: ")
965 << expected;
967 return success();
970 LogicalResult ContractionOp::verify() {
971 VectorType lhsType = getLhsType();
972 VectorType rhsType = getRhsType();
973 Type accType = getAccType();
974 Type resType = getResultType();
976 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
977 if (!lhsType.getElementType().isSignlessInteger())
978 return emitOpError("only supports signless integer types");
981 // Verify that an indexing map was specified for each vector operand.
982 if (getIndexingMapsArray().size() != 3)
983 return emitOpError("expected an indexing map for each vector operand");
985 // Verify that each index map has 'numIterators' inputs, no symbols, and
986 // that the number of map outputs equals the rank of its associated
987 // vector operand.
988 unsigned numIterators = getIteratorTypes().getValue().size();
989 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
990 auto index = it.index();
991 auto map = it.value();
992 if (map.getNumSymbols() != 0)
993 return emitOpError("expected indexing map ")
994 << index << " to have no symbols";
995 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
996 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
997 // Verify that the map has the right number of inputs, outputs, and indices.
998 // This also correctly accounts for (..) -> () for rank-0 results.
999 if (map.getNumDims() != numIterators)
1000 return emitOpError("expected indexing map ")
1001 << index << " to have " << numIterators << " number of inputs";
1002 if (map.getNumResults() != rank)
1003 return emitOpError("expected indexing map ")
1004 << index << " to have " << rank << " number of outputs";
1005 if (!map.isProjectedPermutation())
1006 return emitOpError("expected indexing map ")
1007 << index << " to be a projected permutation of its inputs";
1010 auto contractingDimMap = getContractingDimMap();
1011 auto batchDimMap = getBatchDimMap();
1013 // Verify at least one contracting dimension pair was specified.
1014 if (contractingDimMap.empty())
1015 return emitOpError("expected at least one contracting dimension pair");
1017 // Verify contracting dimension map was properly constructed.
1018 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1019 return emitOpError("invalid contracting dimension map");
1021 // Verify batch dimension map was properly constructed.
1022 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1023 return emitOpError("invalid batch dimension map");
1025 // Verify 'accType' and 'resType' shape.
1026 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1027 contractingDimMap, batchDimMap)))
1028 return failure();
1030 // Verify supported combining kind.
1031 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1032 auto elementType = vectorType ? vectorType.getElementType() : resType;
1033 if (!isSupportedCombiningKind(getKind(), elementType))
1034 return emitOpError("unsupported contraction type");
1036 return success();
1039 // MaskableOpInterface methods.
1041 /// Returns the mask type expected by this operation. Mostly used for
1042 /// verification purposes. It requires the operation to be vectorized."
1043 Type ContractionOp::getExpectedMaskType() {
1044 auto indexingMaps = this->getIndexingMapsArray();
1045 AffineMap lhsIdxMap = indexingMaps[0];
1046 AffineMap rhsIdxMap = indexingMaps[1];
1047 VectorType lhsType = this->getLhsType();
1048 VectorType rhsType = this->getRhsType();
1050 unsigned numVecDims = lhsIdxMap.getNumDims();
1051 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1052 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1054 // Using the information in the indexing maps, extract the size of each
1055 // dimension in the vector.contract operation from the two input operands.
1056 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1057 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1058 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1059 lhsType.getScalableDims()[dimIdx];
1061 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1062 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1063 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1064 rhsType.getScalableDims()[dimIdx];
1067 assert(!ShapedType::isDynamicShape(maskShape) &&
1068 "Mask shape couldn't be computed");
1070 return VectorType::get(maskShape,
1071 IntegerType::get(lhsType.getContext(), /*width=*/1),
1072 maskShapeScalableDims);
1075 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1076 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1077 getIteratorTypesAttrName(), getKindAttrName()};
1080 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1081 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1082 if (targetExpr == map.getResult(i))
1083 return i;
1084 return -1;
1087 static std::vector<std::pair<int64_t, int64_t>>
1088 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1089 IteratorType targetIteratorType, MLIRContext *context) {
1090 std::vector<std::pair<int64_t, int64_t>> dimMap;
1091 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1092 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1093 if (iteratorType != targetIteratorType)
1094 continue;
1095 // Search lhs/rhs map results for 'targetExpr'.
1096 auto targetExpr = getAffineDimExpr(it.index(), context);
1097 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1098 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1099 if (lhsDim >= 0 && rhsDim >= 0)
1100 dimMap.emplace_back(lhsDim, rhsDim);
1102 return dimMap;
1105 void ContractionOp::getIterationBounds(
1106 SmallVectorImpl<int64_t> &iterationBounds) {
1107 auto lhsShape = getLhsType().getShape();
1108 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1109 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1110 SmallVector<int64_t, 2> iterationShape;
1111 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1112 // Search lhs/rhs map results for 'targetExpr'.
1113 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1114 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1115 if (iteratorType == IteratorType::reduction) {
1116 // Get reduction dim size from lhs shape (same size in rhsShape).
1117 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1118 assert(lhsDimIndex >= 0);
1119 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1120 continue;
1122 // Get parallel dimension size from result shape.
1123 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1124 assert(resDimIndex >= 0);
1125 assert(resVectorType != nullptr);
1126 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1130 void ContractionOp::getIterationIndexMap(
1131 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1132 unsigned numMaps = getIndexingMapsArray().size();
1133 iterationIndexMap.resize(numMaps);
1134 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1135 auto index = it.index();
1136 auto map = it.value();
1137 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1138 auto dim = cast<AffineDimExpr>(map.getResult(i));
1139 iterationIndexMap[index][dim.getPosition()] = i;
1144 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1145 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1146 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1147 getContext());
1150 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1151 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1152 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1153 getContext());
1156 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1157 SmallVector<int64_t, 4> shape;
1158 getIterationBounds(shape);
1159 return shape;
1162 /// Return a fused vector::ContractionOp which represents a patterns such as:
1164 /// ```mlir
1165 /// %c0 = vector.constant 0: ...
1166 /// %c = vector.contract %a, %b, %c0: ...
1167 /// %e = add %c, %d: ...
1168 /// ```
1170 /// by:
1172 /// ```mlir
1173 /// %e = vector.contract %a, %b, %d: ...
1174 /// ```
1176 /// Return null if the canonicalization does not apply.
1177 // TODO: This should be a folding of Add into Contract in core but while they
1178 // live in different dialects, it is not possible without unnatural
1179 // dependencies.
1180 template <typename AddOpType>
1181 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1182 using OpRewritePattern<AddOpType>::OpRewritePattern;
1184 LogicalResult matchAndRewrite(AddOpType addOp,
1185 PatternRewriter &rewriter) const override {
1186 auto canonicalize = [&](Value maybeContraction,
1187 Value otherOperand) -> vector::ContractionOp {
1188 vector::ContractionOp contractionOp =
1189 dyn_cast_or_null<vector::ContractionOp>(
1190 maybeContraction.getDefiningOp());
1191 if (!contractionOp)
1192 return vector::ContractionOp();
1193 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1194 contractionOp.getAcc().getDefiningOp())) {
1195 if (maybeZero.getValue() ==
1196 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1197 IRMapping bvm;
1198 bvm.map(contractionOp.getAcc(), otherOperand);
1199 auto newContraction =
1200 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1201 rewriter.replaceOp(addOp, newContraction.getResult());
1202 return newContraction;
1205 return vector::ContractionOp();
1208 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1209 vector::ContractionOp contract = canonicalize(a, b);
1210 contract = contract ? contract : canonicalize(b, a);
1211 return contract ? success() : failure();
1215 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1216 MLIRContext *context) {
1217 results.add<CanonicalizeContractAdd<arith::AddIOp>,
1218 CanonicalizeContractAdd<arith::AddFOp>>(context);
1221 //===----------------------------------------------------------------------===//
1222 // ExtractElementOp
1223 //===----------------------------------------------------------------------===//
1225 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1226 SetIntRangeFn setResultRanges) {
1227 setResultRanges(getResult(), argRanges.front());
1230 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1231 Value source) {
1232 result.addOperands({source});
1233 result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1236 LogicalResult vector::ExtractElementOp::verify() {
1237 VectorType vectorType = getSourceVectorType();
1238 if (vectorType.getRank() == 0) {
1239 if (getPosition())
1240 return emitOpError("expected position to be empty with 0-D vector");
1241 return success();
1243 if (vectorType.getRank() != 1)
1244 return emitOpError("unexpected >1 vector rank");
1245 if (!getPosition())
1246 return emitOpError("expected position for 1-D vector");
1247 return success();
1250 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1251 // Skip the 0-D vector here now.
1252 if (!adaptor.getPosition())
1253 return {};
1255 // Fold extractelement (splat X) -> X.
1256 if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1257 return splat.getInput();
1259 // Fold extractelement(broadcast(X)) -> X.
1260 if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1261 if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1262 return broadcast.getSource();
1264 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1265 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1266 if (!pos || !src)
1267 return {};
1269 auto srcElements = src.getValues<Attribute>();
1271 uint64_t posIdx = pos.getInt();
1272 if (posIdx >= srcElements.size())
1273 return {};
1275 return srcElements[posIdx];
1278 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1279 // `poisonValue`.
1280 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1281 int64_t maxIndex) {
1282 return index == poisonValue || (index >= 0 && index < maxIndex);
1285 //===----------------------------------------------------------------------===//
1286 // ExtractOp
1287 //===----------------------------------------------------------------------===//
1289 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1290 SetIntRangeFn setResultRanges) {
1291 setResultRanges(getResult(), argRanges.front());
1294 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1295 Value source, int64_t position) {
1296 build(builder, result, source, ArrayRef<int64_t>{position});
1299 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1300 Value source, OpFoldResult position) {
1301 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1304 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1305 Value source, ArrayRef<int64_t> position) {
1306 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1307 builder.getDenseI64ArrayAttr(position));
1310 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1311 Value source, ArrayRef<OpFoldResult> position) {
1312 SmallVector<int64_t> staticPos;
1313 SmallVector<Value> dynamicPos;
1314 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1315 build(builder, result, source, dynamicPos,
1316 builder.getDenseI64ArrayAttr(staticPos));
1319 LogicalResult
1320 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1321 ExtractOp::Adaptor adaptor,
1322 SmallVectorImpl<Type> &inferredReturnTypes) {
1323 auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1324 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1325 vectorType.getRank()) {
1326 inferredReturnTypes.push_back(vectorType.getElementType());
1327 } else {
1328 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1329 vectorType.getRank());
1330 inferredReturnTypes.push_back(VectorType::get(
1331 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1332 vectorType.getScalableDims().drop_front(n)));
1334 return success();
1337 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1338 // Allow extracting 1-element vectors instead of scalars.
1339 auto isCompatible = [](TypeRange l, TypeRange r) {
1340 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1341 return vectorType && vectorType.getShape().equals({1}) &&
1342 vectorType.getElementType() == r.front();
1344 if (l.size() == 1 && r.size() == 1 &&
1345 (isCompatible(l, r) || isCompatible(r, l)))
1346 return true;
1347 return l == r;
1350 LogicalResult vector::ExtractOp::verify() {
1351 // Note: This check must come before getMixedPosition() to prevent a crash.
1352 auto dynamicMarkersCount =
1353 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1354 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1355 return emitOpError(
1356 "mismatch between dynamic and static positions (kDynamic marker but no "
1357 "corresponding dynamic position) -- this can only happen due to an "
1358 "incorrect fold/rewrite");
1359 auto position = getMixedPosition();
1360 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1361 return emitOpError(
1362 "expected position attribute of rank no greater than vector rank");
1363 for (auto [idx, pos] : llvm::enumerate(position)) {
1364 if (auto attr = dyn_cast<Attribute>(pos)) {
1365 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1366 if (!isValidPositiveIndexOrPoison(
1367 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1368 return emitOpError("expected position attribute #")
1369 << (idx + 1)
1370 << " to be a non-negative integer smaller than the "
1371 "corresponding vector dimension or poison (-1)";
1375 return success();
1378 template <typename IntType>
1379 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1380 return llvm::to_vector<4>(llvm::map_range(
1381 arrayAttr.getAsRange<IntegerAttr>(),
1382 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1385 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1386 /// positions.
1387 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1388 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1389 return failure();
1391 // TODO: Canonicalization for dynamic position not implemented yet.
1392 if (extractOp.hasDynamicPosition())
1393 return failure();
1395 SmallVector<int64_t> globalPosition;
1396 ExtractOp currentOp = extractOp;
1397 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1398 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1399 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1400 currentOp = nextOp;
1401 // TODO: Canonicalization for dynamic position not implemented yet.
1402 if (currentOp.hasDynamicPosition())
1403 return failure();
1404 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1405 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1407 extractOp.setOperand(0, currentOp.getVector());
1408 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1409 OpBuilder b(extractOp.getContext());
1410 std::reverse(globalPosition.begin(), globalPosition.end());
1411 extractOp.setStaticPosition(globalPosition);
1412 return success();
1415 namespace {
1416 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1417 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1418 /// Compose TransposeOp permutations as we walk back.
1419 /// This helper class keeps an updated extraction position `extractPosition`
1420 /// with extra trailing sentinels.
1421 /// The sentinels encode the internal transposition status of the result vector.
1422 /// As we iterate, extractPosition is permuted and updated.
1423 class ExtractFromInsertTransposeChainState {
1424 public:
1425 ExtractFromInsertTransposeChainState(ExtractOp e);
1427 /// Iterate over producing insert and transpose ops until we find a fold.
1428 Value fold();
1430 private:
1431 /// Return true if the vector at position `a` is contained within the vector
1432 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1433 /// is a prefix of `b`.
1434 template <typename ContainerA, typename ContainerB>
1435 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1436 return a.size() <= b.size() &&
1437 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1440 /// Return true if the vector at position `a` intersects the vector at
1441 /// position `b`. Under insert/extract semantics, this is the same as equality
1442 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1443 /// Comparison is on the common prefix (i.e. zip).
1444 template <typename ContainerA, typename ContainerB>
1445 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1446 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1447 if (elemA < 0 || elemB < 0)
1448 continue;
1449 if (elemA != elemB)
1450 return false;
1452 return true;
1455 /// Folding is only possible in the absence of an internal permutation in the
1456 /// result vector.
1457 bool canFold() {
1458 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1461 // Helper to get the next defining op of interest.
1462 void updateStateForNextIteration(Value v) {
1463 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1464 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1467 // Case 1. If we hit a transpose, just compose the map and iterate.
1468 // Invariant: insert + transpose do not change rank, we can always compose.
1469 LogicalResult handleTransposeOp();
1471 // Case 2: the insert position matches extractPosition exactly, early return.
1472 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1474 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1475 /// portion of the source of the insert.
1476 /// Example:
1477 /// ```
1478 /// %ins = vector.insert %source, %dest[1]: vector<3x4x5> into vector<2x3x4x5>
1479 /// // extractPosition == [1, 2, 3]
1480 /// %ext = vector.extract %ins[1, 2, 3]: vector<5> from vector<2x3x4x5>
1481 /// // can fold to
1482 /// %ext = vector.extract %source[2, 3]: vector<5> from vector<3x4x5>
1483 /// ```
1484 /// To traverse through %source, we need to set the leading dims to 0 and
1485 /// drop the extra leading dims.
1486 /// This method updates the internal state.
1487 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1489 /// Try to fold in place to extract(source, extractPosition) and return the
1490 /// folded result. Return null if folding is not possible (e.g. due to an
1491 /// internal transposition in the result).
1492 Value tryToFoldExtractOpInPlace(Value source);
1494 ExtractOp extractOp;
1495 int64_t vectorRank;
1496 int64_t extractedRank;
1498 InsertOp nextInsertOp;
1499 TransposeOp nextTransposeOp;
1501 /// Sentinel values that encode the internal permutation status of the result.
1502 /// They are set to (-1, ... , -k) at the beginning and appended to
1503 /// `extractPosition`.
1504 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1505 /// ensure that there is no internal transposition.
1506 /// Internal transposition cannot be accounted for with a folding pattern.
1507 // TODO: We could relax the internal transposition with an extra transposition
1508 // operation in a future canonicalizer.
1509 SmallVector<int64_t> sentinels;
1510 SmallVector<int64_t> extractPosition;
1512 } // namespace
1514 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1515 ExtractOp e)
1516 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1517 extractedRank(extractOp.getNumIndices()) {
1518 assert(vectorRank >= extractedRank && "Extracted position overflow");
1519 sentinels.reserve(vectorRank - extractedRank);
1520 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1521 sentinels.push_back(-(i + 1));
1522 extractPosition.assign(extractOp.getStaticPosition().begin(),
1523 extractOp.getStaticPosition().end());
1524 llvm::append_range(extractPosition, sentinels);
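// Worked example of the initial state (hypothetical op): for
//   %e = vector.extract %v[1, 2] : vector<5x6xf32> from vector<3x4x5x6xf32>
// we get vectorRank = 4, extractedRank = 2, sentinels = [-1, -2], and an
// initial extractPosition of [1, 2, -1, -2].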
1527 // Case 1. If we hit a transpose, just compose the map and iterate.
1528 // Invariant: insert + transpose do not change rank, we can always compose.
1529 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1530 // TODO: Canonicalization for dynamic position not implemented yet.
1531 if (extractOp.hasDynamicPosition())
1532 return failure();
1534 if (!nextTransposeOp)
1535 return failure();
1536 AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1537 nextTransposeOp.getPermutation(), extractOp.getContext()));
1538 extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1539 return success();
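// Worked example (hypothetical values): for a transpose with permutation
// [1, 2, 0], result dim d reads source dim perm[d], so an extractPosition
// of [1, 2, -1] in the transpose result becomes [-1, 1, 2] in the
// transpose source.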
1542 // Case 2: the insert position matches extractPosition exactly, early return.
1543 LogicalResult
1544 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1545 Value &res) {
1546 // TODO: Canonicalization for dynamic position not implemented yet.
1547 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1548 return failure();
1550 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1551 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1552 return failure();
1553 // Case 2.a. early-exit fold.
1554 res = nextInsertOp.getSource();
1555 // Case 2.b. if internal transposition is present, canFold will be false.
1556 return success(canFold());
1559 /// Case 3: if the inserted position is a prefix of extractPosition,
1560 /// extract a portion of the source of the insertion.
1561 /// This method updates the internal state.
1562 LogicalResult
1563 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1564 // TODO: Canonicalization for dynamic position not implemented yet.
1565 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1566 return failure();
1568 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1569 if (!isContainedWithin(insertedPos, extractPosition))
1570 return failure();
1571 // Set leading dims to zero.
1572 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1573 // Drop extra leading dims.
1574 extractPosition.erase(extractPosition.begin(),
1575 extractPosition.begin() + insertedPos.size());
1576 extractedRank = extractPosition.size() - sentinels.size();
1577 // Case 3.a. early-exit fold (break and delegate to post-while path).
1578 res = nextInsertOp.getSource();
1579 // Case 3.b. if internal transposition is present, canFold will be false.
1580 return success();
1583 /// Try to fold in place to extract(source, extractPosition) and return the
1584 /// folded result. Return null if folding is not possible (e.g. due to an
1585 /// internal transposition in the result).
1586 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1587 Value source) {
1588 // TODO: Canonicalization for dynamic position not implemented yet.
1589 if (extractOp.hasDynamicPosition())
1590 return Value();
1592 // If we can't fold (either internal transposition, or nothing to fold), bail.
1593 bool nothingToFold = (source == extractOp.getVector());
1594 if (nothingToFold || !canFold())
1595 return Value();
1597 // Otherwise, fold by updating the op in place and return its result.
1598 OpBuilder b(extractOp.getContext());
1599 extractOp.setStaticPosition(
1600 ArrayRef(extractPosition).take_front(extractedRank));
1601 extractOp.getVectorMutable().assign(source);
1602 return extractOp.getResult();
1605 /// Iterate over producing insert and transpose ops until we find a fold.
1606 Value ExtractFromInsertTransposeChainState::fold() {
1607 // TODO: Canonicalization for dynamic position not implemented yet.
1608 if (extractOp.hasDynamicPosition())
1609 return Value();
1611 Value valueToExtractFrom = extractOp.getVector();
1612 updateStateForNextIteration(valueToExtractFrom);
1613 while (nextInsertOp || nextTransposeOp) {
1614 // Case 1. If we hit a transpose, just compose the map and iterate.
1615 // Invariant: insert + transpose do not change rank, we can always compose.
1616 if (succeeded(handleTransposeOp())) {
1617 valueToExtractFrom = nextTransposeOp.getVector();
1618 updateStateForNextIteration(valueToExtractFrom);
1619 continue;
1622 Value result;
1623 // Case 2: the insert position matches extractPosition exactly.
1624 if (succeeded(handleInsertOpWithMatchingPos(result)))
1625 return result;
1627 // Case 3: if the inserted position is a prefix of extractPosition, we can
1628 // just extract a portion of the source of the insert.
1629 if (succeeded(handleInsertOpWithPrefixPos(result)))
1630 return tryToFoldExtractOpInPlace(result);
1632 // Case 4: extractPosition intersects insertedPos on non-sentinel
1633 // values. This is a more difficult case and we bail.
1634 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1635 if (isContainedWithin(extractPosition, insertedPos) ||
1636 intersectsWhereNonNegative(extractPosition, insertedPos))
1637 return Value();
1639 // Case 5: No intersection, we forward the extract to insertOp.dest().
1640 valueToExtractFrom = nextInsertOp.getDest();
1641 updateStateForNextIteration(valueToExtractFrom);
1643 // If after all this we can fold, go for it.
1644 return tryToFoldExtractOpInPlace(valueToExtractFrom);
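// End-to-end sketch of the loop above (hypothetical IR): folding
//   %i = vector.insert %x, %v[0] : vector<4xf32> into vector<2x4xf32>
//   %e = vector.extract %i[1, 3] : f32 from vector<2x4xf32>
// hits Case 5 (position [0] is disjoint from [1, 3]), forwards to %v, and
// the post-while path rewrites the op in place to vector.extract %v[1, 3].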
1647 /// Returns true if the operation has a 0-D vector type operand or result.
1648 static bool hasZeroDimVectors(Operation *op) {
1649 auto hasZeroDimVectorType = [](Type type) -> bool {
1650 auto vecType = dyn_cast<VectorType>(type);
1651 return vecType && vecType.getRank() == 0;
1654 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1655 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1658 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1659 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1660 // TODO: Canonicalization for dynamic position not implemented yet.
1661 if (extractOp.hasDynamicPosition())
1662 return Value();
1664 Operation *defOp = extractOp.getVector().getDefiningOp();
1665 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1666 return Value();
1668 Value source = defOp->getOperand(0);
1669 if (extractOp.getType() == source.getType())
1670 return source;
1671 auto getRank = [](Type type) {
1672 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1673 : 0;
1676 // If splat or broadcast from a scalar, just return the source scalar.
1677 unsigned broadcastSrcRank = getRank(source.getType());
1678 if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1679 return source;
1681 unsigned extractResultRank = getRank(extractOp.getType());
1682 if (extractResultRank >= broadcastSrcRank)
1683 return Value();
1684 // Check that the dimensions of the result haven't been broadcasted.
1685 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1686 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1687 if (extractVecType && broadcastVecType &&
1688 extractVecType.getShape() !=
1689 broadcastVecType.getShape().take_back(extractResultRank))
1690 return Value();
1692 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1693 int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1695 // Detect all the positions that come from "dim-1" broadcasting.
1696 // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1697 // extract position to `0` when extracting from the source operand.
1698 llvm::SetVector<int64_t> broadcastedUnitDims =
1699 broadcastOp.computeBroadcastedUnitDims();
1700 SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1701 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1702 for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1703 if (broadcastedUnitDims.contains(i))
1704 extractPos[i] = 0;
1705 // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1706 // matching extract position when extracting from the source operand.
1707 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1708 extractPos.erase(extractPos.begin(),
1709 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1710 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1711 OpBuilder b(extractOp.getContext());
1712 extractOp.setOperand(0, source);
1713 extractOp.setStaticPosition(extractPos);
1714 return extractOp.getResult();
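// Worked example (hypothetical IR): for
//   %b = vector.broadcast %src : vector<1x4xf32> to vector<3x2x4xf32>
//   %e = vector.extract %b[2, 1] : vector<4xf32> from vector<3x2x4xf32>
// dst dim 1 is a "dim-1" broadcast, so position 1 is reset to 0; the new
// leading position is then dropped and the op folds in place to
//   %e = vector.extract %src[0] : vector<4xf32> from vector<1x4xf32>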
1717 /// Fold extractOp coming from ShuffleOp.
1719 /// Example:
1721 /// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1722 /// : vector<8xf32>, vector<8xf32>
1723 /// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1724 /// ->
1725 /// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1727 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1728 // Dynamic positions are not folded as the resulting code would be more
1729 // complex than the input code.
1730 if (extractOp.hasDynamicPosition())
1731 return Value();
1733 auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1734 if (!shuffleOp)
1735 return Value();
1737 // TODO: 0-D or multi-dimensional vectors not supported yet.
1738 if (shuffleOp.getResultVectorType().getRank() != 1)
1739 return Value();
1741 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1742 auto shuffleMask = shuffleOp.getMask();
1743 int64_t extractIdx = extractOp.getStaticPosition()[0];
1744 int64_t shuffleIdx = shuffleMask[extractIdx];
1746 // Find the shuffled vector to extract from based on the shuffle index.
1747 if (shuffleIdx < inputVecSize) {
1748 extractOp.setOperand(0, shuffleOp.getV1());
1749 extractOp.setStaticPosition({shuffleIdx});
1750 } else {
1751 extractOp.setOperand(0, shuffleOp.getV2());
1752 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1755 return extractOp.getResult();
1758 // Fold extractOp with source coming from ShapeCast op.
1759 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1760 // TODO: Canonicalization for dynamic position not implemented yet.
1761 if (extractOp.hasDynamicPosition())
1762 return Value();
1764 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1765 if (!shapeCastOp)
1766 return Value();
1768 // Get the n-th dimension size, starting from the lowest dimension.
1769 auto getDimReverse = [](VectorType type, int64_t n) {
1770 return type.getShape().take_back(n + 1).front();
1772 int64_t destinationRank =
1773 llvm::isa<VectorType>(extractOp.getType())
1774 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1775 : 0;
1776 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1777 return Value();
1778 if (destinationRank > 0) {
1779 auto destinationType =
1780 llvm::cast<VectorType>(extractOp.getResult().getType());
1781 for (int64_t i = 0; i < destinationRank; i++) {
1782 // The lowest dimensions of the destination must match the lowest
1783 // dimensions of the shape_cast op source.
1784 // TODO: This case could be supported by a canonicalization pattern.
1785 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1786 getDimReverse(destinationType, i))
1787 return Value();
1790 // Extract the strides associated with the extract op vector source. Then use
1791 // this to calculate a linearized position for the extract.
1792 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1793 std::reverse(extractedPos.begin(), extractedPos.end());
1794 SmallVector<int64_t, 4> strides;
1795 int64_t stride = 1;
1796 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1797 strides.push_back(stride);
1798 stride *=
1799 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1802 int64_t position = linearize(extractedPos, strides);
1803 // Then extract the strides associated with the shapeCast op vector source
1804 // and delinearize the position using those strides.
1805 SmallVector<int64_t, 4> newStrides;
1806 int64_t numDimension =
1807 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1808 stride = 1;
1809 for (int64_t i = 0; i < numDimension; i++) {
1810 newStrides.push_back(stride);
1811 stride *=
1812 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1814 std::reverse(newStrides.begin(), newStrides.end());
1815 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1816 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1817 OpBuilder b(extractOp.getContext());
1818 extractOp.setStaticPosition(newPosition);
1819 extractOp.setOperand(0, shapeCastOp.getSource());
1820 return extractOp.getResult();
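// Worked example (hypothetical IR): for
//   %sc = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
//   %e = vector.extract %sc[6] : f32 from vector<8xf32>
// the linearized position is 6, which delinearizes over the source strides
// [4, 1] to [1, 2], so the op folds in place to
//   %e = vector.extract %src[1, 2] : f32 from vector<2x4xf32>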
1823 /// Fold an ExtractOp from ExtractStridedSliceOp.
1824 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1825 // TODO: Canonicalization for dynamic position not implemented yet.
1826 if (extractOp.hasDynamicPosition())
1827 return Value();
1829 auto extractStridedSliceOp =
1830 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1831 if (!extractStridedSliceOp)
1832 return Value();
1834 // 0-D vectors not supported.
1835 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1836 if (hasZeroDimVectors(extractStridedSliceOp))
1837 return Value();
1839 // Return if 'extractStridedSliceOp' has non-unit strides.
1840 if (extractStridedSliceOp.hasNonUnitStrides())
1841 return Value();
1843 // Trim offsets for dimensions fully extracted.
1844 auto sliceOffsets =
1845 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1846 while (!sliceOffsets.empty()) {
1847 size_t lastOffset = sliceOffsets.size() - 1;
1848 if (sliceOffsets.back() != 0 ||
1849 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1850 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1851 break;
1852 sliceOffsets.pop_back();
1854 unsigned destinationRank = 0;
1855 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1856 destinationRank = vecType.getRank();
1857 // The dimensions of the result need to be untouched by the
1858 // extractStridedSlice op.
1859 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1860 sliceOffsets.size())
1861 return Value();
1863 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1864 assert(extractedPos.size() >= sliceOffsets.size());
1865 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1866 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1867 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1869 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1870 OpBuilder b(extractOp.getContext());
1871 extractOp.setStaticPosition(extractedPos);
1872 return extractOp.getResult();
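// Worked example (hypothetical IR): for
//   %s = vector.extract_strided_slice %v
//          {offsets = [1, 2], sizes = [2, 2], strides = [1, 1]}
//          : vector<4x8xf32> to vector<2x2xf32>
//   %e = vector.extract %s[1, 0] : f32 from vector<2x2xf32>
// the slice offsets are added to the extract position and the op folds in
// place to
//   %e = vector.extract %v[2, 2] : f32 from vector<4x8xf32>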
1875 /// Fold extract_op fed from a chain of insertStridedSlice ops.
1876 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1877 // TODO: Canonicalization for dynamic position not implemented yet.
1878 if (extractOp.hasDynamicPosition())
1879 return Value();
1881 int64_t destinationRank =
1882 llvm::isa<VectorType>(extractOp.getType())
1883 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1884 : 0;
1885 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1886 if (!insertOp)
1887 return Value();
1889 // 0-D vectors not supported.
1890 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1891 if (hasZeroDimVectors(insertOp))
1892 return Value();
1894 while (insertOp) {
1895 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1896 insertOp.getSourceVectorType().getRank();
1897 if (destinationRank > insertOp.getSourceVectorType().getRank())
1898 return Value();
1899 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1900 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1902 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1903 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1904 }))
1905 return Value();
1906 bool disjoint = false;
1907 SmallVector<int64_t, 4> offsetDiffs;
1908 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1909 int64_t start = insertOffsets[dim];
1910 int64_t size =
1911 (dim < insertRankDiff)
1912 ? 1
1913 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1914 int64_t end = start + size;
1915 int64_t offset = extractOffsets[dim];
1916 // Check if the start of the extract offset is in the interval inserted.
1917 if (start <= offset && offset < end) {
1918 if (dim >= insertRankDiff)
1919 offsetDiffs.push_back(offset - start);
1920 continue;
1922 disjoint = true;
1923 break;
1925 // The extracted element chunk overlaps with the inserted vector.
1926 if (!disjoint) {
1927 // If any of the inner dimensions are only partially inserted we have a
1928 // partial overlap.
1929 int64_t srcRankDiff =
1930 insertOp.getSourceVectorType().getRank() - destinationRank;
1931 for (int64_t i = 0; i < destinationRank; i++) {
1932 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1933 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1934 insertRankDiff))
1935 return Value();
1937 extractOp.getVectorMutable().assign(insertOp.getSource());
1938 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1939 OpBuilder b(extractOp.getContext());
1940 extractOp.setStaticPosition(offsetDiffs);
1941 return extractOp.getResult();
1943 // If the chunk extracted is disjoint from the chunk inserted, keep
1944 // looking in the insert chain.
1945 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1947 return Value();
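// Worked example (hypothetical IR): for
//   %i = vector.insert_strided_slice %src, %dst
//          {offsets = [2, 0], strides = [1, 1]}
//          : vector<2x4xf32> into vector<4x4xf32>
//   %e = vector.extract %i[3] : vector<4xf32> from vector<4x4xf32>
// the extract offset 3 falls inside the inserted rows [2, 4), so the op
// folds in place to
//   %e = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>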
1950 /// Try to fold the extraction of a scalar from a vector defined by
1951 /// vector.from_elements. E.g.:
1953 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1954 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1955 /// ==> fold to %a
1956 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1957 // Dynamic extractions cannot be folded.
1958 if (extractOp.hasDynamicPosition())
1959 return {};
1961 // Look for extract(from_elements).
1962 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1963 if (!fromElementsOp)
1964 return {};
1966 // Scalable vectors are not supported.
1967 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1968 if (vecType.isScalable())
1969 return {};
1971 // Only extractions of scalars are supported.
1972 int64_t rank = vecType.getRank();
1973 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1974 if (extractOp.getType() != vecType.getElementType())
1975 return {};
1976 assert(static_cast<int64_t>(indices.size()) == rank &&
1977 "unexpected number of indices");
1979 // Compute flattened/linearized index and fold to operand.
1980 int flatIndex = 0;
1981 int stride = 1;
1982 for (int i = rank - 1; i >= 0; --i) {
1983 flatIndex += indices[i] * stride;
1984 stride *= vecType.getDimSize(i);
1986 return fromElementsOp.getElements()[flatIndex];
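// Worked index computation (hypothetical IR): for
//   %0 = vector.from_elements %a, %b, %c, %d, %e, %f : vector<2x3xf32>
//   %1 = vector.extract %0[1, 2] : f32 from vector<2x3xf32>
// the flattened index is 1 * 3 + 2 = 5, so the fold returns %f.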
1989 /// Fold an insert or extract operation into a poison value when a poison index
1990 /// is found at any dimension of the static position.
1991 static ub::PoisonAttr
1992 foldPoisonIndexInsertExtractOp(MLIRContext *context,
1993 ArrayRef<int64_t> staticPos, int64_t poisonVal) {
1994 if (!llvm::is_contained(staticPos, poisonVal))
1995 return ub::PoisonAttr();
1997 return ub::PoisonAttr::get(context);
2000 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2001 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2002 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2003 // mismatch).
2004 if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2005 return getVector();
2006 if (auto res = foldPoisonIndexInsertExtractOp(
2007 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2008 return res;
2009 if (succeeded(foldExtractOpFromExtractChain(*this)))
2010 return getResult();
2011 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2012 return res;
2013 if (auto res = foldExtractFromBroadcast(*this))
2014 return res;
2015 if (auto res = foldExtractFromShuffle(*this))
2016 return res;
2017 if (auto res = foldExtractFromShapeCast(*this))
2018 return res;
2019 if (auto val = foldExtractFromExtractStrided(*this))
2020 return val;
2021 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2022 return val;
2023 if (auto val = foldScalarExtractFromFromElements(*this))
2024 return val;
2025 return OpFoldResult();
2028 namespace {
2030 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
2031 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2032 public:
2033 using OpRewritePattern::OpRewritePattern;
2035 LogicalResult matchAndRewrite(ExtractOp extractOp,
2036 PatternRewriter &rewriter) const override {
2037 Operation *defOp = extractOp.getVector().getDefiningOp();
2038 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2039 return failure();
2041 Value source = defOp->getOperand(0);
2042 if (extractOp.getType() == source.getType())
2043 return failure();
2044 auto getRank = [](Type type) {
2045 return llvm::isa<VectorType>(type)
2046 ? llvm::cast<VectorType>(type).getRank()
2047 : 0;
2049 unsigned broadcastSrcRank = getRank(source.getType());
2050 unsigned extractResultRank = getRank(extractOp.getType());
2051 // We only consider the case where the rank of the source is less than or
2052 // equal to the rank of the extract dst. The other cases are handled in the
2053 // folding patterns.
2054 if (extractResultRank < broadcastSrcRank)
2055 return failure();
2057 // Special case if broadcast src is a 0D vector.
2058 if (extractResultRank == 0) {
2059 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2060 rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2061 return success();
2063 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2064 extractOp, extractOp.getType(), source);
2065 return success();
2069 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
2070 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2071 public:
2072 using OpRewritePattern::OpRewritePattern;
2074 LogicalResult matchAndRewrite(ExtractOp extractOp,
2075 PatternRewriter &rewriter) const override {
2076 // Return if 'ExtractOp' operand is not defined by a splat vector
2077 // ConstantOp.
2078 Value sourceVector = extractOp.getVector();
2079 Attribute vectorCst;
2080 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2081 return failure();
2082 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2083 if (!splat)
2084 return failure();
2085 TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2086 if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2087 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2088 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2089 return success();
2093 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2094 class ExtractOpNonSplatConstantFolder final
2095 : public OpRewritePattern<ExtractOp> {
2096 public:
2097 using OpRewritePattern::OpRewritePattern;
2099 LogicalResult matchAndRewrite(ExtractOp extractOp,
2100 PatternRewriter &rewriter) const override {
2101 // TODO: Canonicalization for dynamic position not implemented yet.
2102 if (extractOp.hasDynamicPosition())
2103 return failure();
2105 // Return if 'ExtractOp' operand is not defined by a compatible vector
2106 // ConstantOp.
2107 Value sourceVector = extractOp.getVector();
2108 Attribute vectorCst;
2109 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2110 return failure();
2112 auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2113 if (vecTy.isScalable())
2114 return failure();
2116 // The splat case is handled by `ExtractOpSplatConstantFolder`.
2117 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2118 if (!dense || dense.isSplat())
2119 return failure();
2121 // Calculate the linearized position of the contiguous chunk of elements to
2122 // extract.
2123 llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2124 copy(extractOp.getStaticPosition(), completePositions.begin());
2125 int64_t elemBeginPosition =
2126 linearize(completePositions, computeStrides(vecTy.getShape()));
2127 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2129 TypedAttr newAttr;
2130 if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2131 SmallVector<Attribute> elementValues(
2132 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2133 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2134 } else {
2135 newAttr = *denseValuesBegin;
2138 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2139 return success();
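// Illustrative sketch (hypothetical IR): extracting row 1 of a non-splat
// constant
//   %c = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
//   %e = vector.extract %c[1] : vector<2xi32> from vector<2x2xi32>
// linearizes to element offset 2 and rewrites to
//   %e = arith.constant dense<[2, 3]> : vector<2xi32>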
2143 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
2144 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2145 public:
2146 using OpRewritePattern::OpRewritePattern;
2148 LogicalResult matchAndRewrite(ExtractOp extractOp,
2149 PatternRewriter &rewriter) const override {
2150 auto createMaskOp =
2151 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2152 if (!createMaskOp)
2153 return failure();
2155 VectorType extractedMaskType =
2156 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2158 if (!extractedMaskType)
2159 return failure();
2161 auto maskOperands = createMaskOp.getOperands();
2162 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2163 VectorType maskType = createMaskOp.getVectorType();
2165 bool containsUnknownDims = false;
2166 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2168 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2169 dimIdx++) {
2170 int64_t pos = extractOpPos[dimIdx];
2171 Value operand = maskOperands[dimIdx];
2172 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2173 if (!constantOp) {
2174 // Bounds of this dim unknown.
2175 containsUnknownDims = true;
2176 continue;
2179 int64_t createMaskBound =
2180 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2182 if (pos != ShapedType::kDynamic) {
2183 // If any position is outside the range from the `create_mask`, then the
2184 // extracted mask will be all-false.
2185 allFalse |= pos >= createMaskBound;
2186 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2187 // This dim is not all-true and, since this is a dynamic index, we don't
2188 // know if the extraction is within the true or false region.
2189 // Note: Zero dims have already been handled via getMaskFormat().
2190 containsUnknownDims = true;
2194 if (allFalse) {
2195 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2196 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2197 } else if (!containsUnknownDims) {
2198 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2199 extractOp, extractedMaskType,
2200 maskOperands.drop_front(extractOpPos.size()));
2201 } else {
2202 return failure();
2204 return success();
2208 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2209 // does not change.
2210 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2211 PatternRewriter &rewriter) {
2212 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2213 if (!castOp)
2214 return failure();
2216 VectorType sourceType = castOp.getSourceVectorType();
2217 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2218 if (!targetType)
2219 return failure();
2221 if (sourceType.getNumElements() != targetType.getNumElements())
2222 return failure();
2224 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2225 castOp.getSource());
2226 return success();
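// Illustrative example (hypothetical IR):
//   %sc = vector.shape_cast %src : vector<2x2xf32> to vector<1x4xf32>
//   %e = vector.extract %sc[0] : vector<4xf32> from vector<1x4xf32>
// has matching element counts (4 and 4), so the pattern rewrites it to
//   %e = vector.shape_cast %src : vector<2x2xf32> to vector<4xf32>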
2229 /// Try to canonicalize the extraction of a subvector from a vector defined by
2230 /// vector.from_elements. E.g.:
2232 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2233 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2234 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2235 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2236 PatternRewriter &rewriter) {
2237 // Dynamic positions are not supported.
2238 if (extractOp.hasDynamicPosition())
2239 return failure();
2241 // Scalar extracts are handled by the folder.
2242 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2243 if (!resultType)
2244 return failure();
2246 // Look for extracts from a from_elements op.
2247 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2248 if (!fromElementsOp)
2249 return failure();
2250 VectorType inputType = fromElementsOp.getType();
2252 // Scalable vectors are not supported.
2253 if (resultType.isScalable() || inputType.isScalable())
2254 return failure();
2256 // Compute the position of the first extracted element and flatten/linearize
2257 // the position.
2258 SmallVector<int64_t> firstElementPos =
2259 llvm::to_vector(extractOp.getStaticPosition());
2260 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2261 int flatIndex = 0;
2262 int stride = 1;
2263 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2264 flatIndex += firstElementPos[i] * stride;
2265 stride *= inputType.getDimSize(i);
2268 // Replace the op with a smaller from_elements op.
2269 rewriter.replaceOpWithNewOp<FromElementsOp>(
2270 extractOp, resultType,
2271 fromElementsOp.getElements().slice(flatIndex,
2272 resultType.getNumElements()));
2273 return success();
2276 /// Fold an insert or extract operation into a poison value when a poison index
2277 /// is found at any dimension of the static position.
2278 template <typename OpTy>
2279 LogicalResult
2280 canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
2281 if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
2282 op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
2283 rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
2284 return success();
2287 return failure();
2290 } // namespace
2292 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2293 MLIRContext *context) {
2294 results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2295 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2296 results.add(foldExtractFromShapeCastToShapeCast);
2297 results.add(foldExtractFromFromElements);
2298 results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
2301 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2302 SmallVectorImpl<int64_t> &results) {
2303 for (auto attr : arrayAttr)
2304 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2307 //===----------------------------------------------------------------------===//
2308 // FMAOp
2309 //===----------------------------------------------------------------------===//
2311 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2312 return llvm::to_vector<4>(getVectorType().getShape());
2315 //===----------------------------------------------------------------------===//
2316 // FromElementsOp
2317 //===----------------------------------------------------------------------===//
2319 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2320 /// same SSA value. E.g.:
2322 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2323 /// ==> rewrite to vector.splat %a : vector<3xf32>
2324 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2325 PatternRewriter &rewriter) {
2326 if (!llvm::all_equal(fromElementsOp.getElements()))
2327 return failure();
2328 rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2329 fromElementsOp.getElements().front());
2330 return success();
2333 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2334 MLIRContext *context) {
2335 results.add(rewriteFromElementsAsSplat);
2338 //===----------------------------------------------------------------------===//
2339 // BroadcastOp
2340 //===----------------------------------------------------------------------===//
2342 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2343 SetIntRangeFn setResultRanges) {
2344 setResultRanges(getResult(), argRanges.front());
2347 /// Return the dimensions of the result vector that were formerly ones in the
2348 /// source vector and thus correspond to "dim-1" broadcasting.
2349 static llvm::SetVector<int64_t>
2350 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2351 ArrayRef<int64_t> dstShape) {
2352 int64_t rankDiff = dstShape.size() - srcShape.size();
2353 int64_t dstDim = rankDiff;
2354 llvm::SetVector<int64_t> res;
2355 for (auto [s1, s2] :
2356 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2357 if (s1 != s2) {
2358 assert(s1 == 1 && "expected dim-1 broadcasting");
2359 res.insert(dstDim);
2361 ++dstDim;
2363 return res;
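// Hypothetical example: srcShape = [1, 4] and dstShape = [3, 2, 4] give
// rankDiff = 1; comparing [1, 4] against [2, 4] flags dst dim 1 as a
// "dim-1" broadcast, so the result is {1}.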
2366 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2367 // A scalar broadcast involves no unit-dim broadcasting.
2368 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2369 if (!srcVectorType)
2370 return {};
2371 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2372 getResultVectorType().getShape());
2375 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2376 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2377 /// This requires (and asserts) that the broadcast is free of dim-1
2378 /// broadcasting.
2379 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2380 /// vector.transpose may be inserted to make the broadcast possible.
2381 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2382 /// the helper will assert. This means:
2383 /// 1. `dstShape` must not be empty.
2384 /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)].
2385 /// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2386 ///    must match the `value` shape.
2387 Value BroadcastOp::createOrFoldBroadcastOp(
2388 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2389 const llvm::SetVector<int64_t> &broadcastedDims) {
2390 assert(!dstShape.empty() && "unexpected empty dst shape");
2392 // Step 1. Well-formedness check.
2393 SmallVector<int64_t> checkShape;
2394 for (int i = 0, e = dstShape.size(); i < e; ++i) {
2395 if (broadcastedDims.contains(i))
2396 continue;
2397 checkShape.push_back(dstShape[i]);
2399 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2400 "ill-formed broadcastedDims contains values not confined to "
2401 "destVectorShape");
2403 Location loc = value.getLoc();
2404 Type elementType = getElementTypeOrSelf(value.getType());
2405 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2406 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2408 // Step 2. If scalar -> dstShape broadcast, just do it.
2409 if (!srcVectorType) {
2410 assert(checkShape.empty() &&
2411 "ill-formed createOrFoldBroadcastOp arguments");
2412 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2415 assert(srcVectorType.getShape().equals(checkShape) &&
2416 "ill-formed createOrFoldBroadcastOp arguments");
2418 // Step 3. Since vector.broadcast only allows creating leading dims,
2419 // vector -> dstShape broadcast may require a transpose.
2420 // Traverse the dims in order and construct:
2421 // 1. The leading entries of the broadcastShape that are guaranteed to be
2422 // achievable by a simple broadcast.
2423 // 2. The induced permutation for the subsequent vector.transpose that will
2424 // bring us from `broadcastShape` back to the desired `dstShape`.
2425 // If the induced permutation is not the identity, create a vector.transpose.
2426 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2427 broadcastShape.reserve(dstShape.size());
2428 // Consider the example:
2429 // srcShape = 2x4
2430 // dstShape = 1x2x3x4x5
2431 // broadcastedDims = [0, 2, 4]
2433 // We want to build:
2434 // broadcastShape = 1x3x5x2x4
2435 // permutation = [0, 2, 4, 1, 3]
2436 // ---V--- -----V-----
2437 // leading broadcast part src shape part
2439 // Note that the trailing dims of broadcastShape are exactly the srcShape
2440 // by construction.
2441 // nextSrcShapeDim is used to keep track of where in the permutation the
2442 // "src shape part" occurs.
2443 int64_t nextSrcShapeDim = broadcastedDims.size();
2444 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2445 if (broadcastedDims.contains(i)) {
2446 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2447 // bring it to the head of the broadcastShape.
2448 // It will need to be permuted back from `broadcastShape.size() - 1` into
2449 // position `i`.
2450 broadcastShape.push_back(dstShape[i]);
2451 permutation[i] = broadcastShape.size() - 1;
2452 } else {
2453 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2454 // shape and needs to be permuted into position `i`.
2455 // Don't touch `broadcastShape` here, the whole srcShape will be
2456 // appended after.
2457 permutation[i] = nextSrcShapeDim++;
2460 // 3.c. Append the srcShape.
2461 llvm::append_range(broadcastShape, srcVectorType.getShape());
2463 // Ensure there are no dim-1 broadcasts.
2464 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2465 .empty() &&
2466 "unexpected dim-1 broadcast");
2468 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2469 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2470 vector::BroadcastableToResult::Success &&
2471 "must be broadcastable");
2472 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2473 // Step 4. If we find any dimension that indeed needs to be permuted,
2474 // immediately return a new vector.transpose.
2475 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2476 if (permutation[i] != i)
2477 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2478 // Otherwise return res.
2479 return res;
2482 BroadcastableToResult mlir::vector::isBroadcastableTo(
2483 Type srcType, VectorType dstVectorType,
2484 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2485 // Broadcast scalar to vector of the same element type.
2486 if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2487 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2488 return BroadcastableToResult::Success;
2489 // From now on, only vectors broadcast.
2490 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2491 if (!srcVectorType)
2492 return BroadcastableToResult::SourceTypeNotAVector;
2494 int64_t srcRank = srcVectorType.getRank();
2495 int64_t dstRank = dstVectorType.getRank();
2496 if (srcRank > dstRank)
2497 return BroadcastableToResult::SourceRankHigher;
2498 // Source has an exact match or singleton value for all trailing dimensions
2499 // (all leading dimensions are simply duplicated).
2500 int64_t lead = dstRank - srcRank;
2501 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2502 // Have mismatching dims (in the sense of vector.broadcast semantics) been
2503 // encountered?
2504 bool foundMismatchingDims = false;
2506 // Check fixed-width dims.
2507 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2508 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2509 if (srcDim != 1 && srcDim != dstDim)
2510 foundMismatchingDims = true;
2512 // Check scalable flags.
2513 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2514 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2515 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2516 // 1 -> [N] is fine, everything else should be rejected when mixing
2517 // fixed-width and scalable dims
2518 (srcDimScalableFlag != dstDimScalableFlag &&
2519 (srcDim != 1 || srcDimScalableFlag)))
2520 foundMismatchingDims = true;
2522 if (foundMismatchingDims) {
2523 if (mismatchingDims != nullptr) {
2524 mismatchingDims->first.dim = srcDim;
2525 mismatchingDims->first.isScalable = srcDimScalableFlag;
2527 mismatchingDims->second.dim = dstDim;
2528 mismatchingDims->second.isScalable = dstDimScalableFlag;
2530 return BroadcastableToResult::DimensionMismatch;
2534 return BroadcastableToResult::Success;
2537 LogicalResult BroadcastOp::verify() {
2538 std::pair<VectorDim, VectorDim> mismatchingDims;
2539 BroadcastableToResult res = isBroadcastableTo(
2540 getSourceType(), getResultVectorType(), &mismatchingDims);
2541 if (res == BroadcastableToResult::Success)
2542 return success();
2543 if (res == BroadcastableToResult::SourceRankHigher)
2544 return emitOpError("source rank higher than destination rank");
2545 if (res == BroadcastableToResult::DimensionMismatch) {
2546 return emitOpError("dimension mismatch (")
2547 << (mismatchingDims.first.isScalable ? "[" : "")
2548 << mismatchingDims.first.dim
2549 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2550 << (mismatchingDims.second.isScalable ? "[" : "")
2551 << mismatchingDims.second.dim
2552 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2554 if (res == BroadcastableToResult::SourceTypeNotAVector)
2555 return emitOpError("source type is not a vector");
2556 llvm_unreachable("unexpected vector.broadcast op error");
2559 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2560 if (getSourceType() == getResultVectorType())
2561 return getSource();
2562 if (!adaptor.getSource())
2563 return {};
2564 auto vectorType = getResultVectorType();
2565 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2566 if (vectorType.getElementType() != attr.getType())
2567 return {};
2568 return DenseElementsAttr::get(vectorType, attr);
2570 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2571 if (vectorType.getElementType() != attr.getType())
2572 return {};
2573 return DenseElementsAttr::get(vectorType, attr);
2575 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2576 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2577 return {};
2580 namespace {
2582 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2583 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2584 using OpRewritePattern::OpRewritePattern;
2586 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2587 PatternRewriter &rewriter) const override {
2588 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2589 if (!srcBroadcast)
2590 return failure();
2591 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2592 broadcastOp.getResultVectorType(),
2593 srcBroadcast.getSource());
2594 return success();
2597 } // namespace
2599 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2600 MLIRContext *context) {
2601 // BroadcastToShapeCast is not a default canonicalization; it is opted into
2602 // by calling `populateCastAwayVectorLeadingOneDimPatterns`.
2603 results.add<BroadcastFolder>(context);
2606 //===----------------------------------------------------------------------===//
2607 // ShuffleOp
2608 //===----------------------------------------------------------------------===//
2610 LogicalResult ShuffleOp::verify() {
2611 VectorType resultType = getResultVectorType();
2612 VectorType v1Type = getV1VectorType();
2613 VectorType v2Type = getV2VectorType();
2614 // Verify ranks.
2615 int64_t resRank = resultType.getRank();
2616 int64_t v1Rank = v1Type.getRank();
2617 int64_t v2Rank = v2Type.getRank();
2618 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2619 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2620 if (!wellFormed0DCase && !wellFormedNDCase)
2621 return emitOpError("rank mismatch");
2623 // Verify all but leading dimension sizes.
2624 for (int64_t r = 1; r < v1Rank; ++r) {
2625 int64_t resDim = resultType.getDimSize(r);
2626 int64_t v1Dim = v1Type.getDimSize(r);
2627 int64_t v2Dim = v2Type.getDimSize(r);
2628 if (resDim != v1Dim || v1Dim != v2Dim)
2629 return emitOpError("dimension mismatch");
2631 // Verify mask length.
2632 ArrayRef<int64_t> mask = getMask();
2633 int64_t maskLength = mask.size();
2634 if (maskLength <= 0)
2635 return emitOpError("invalid mask length");
2636 if (maskLength != resultType.getDimSize(0))
2637 return emitOpError("mask length mismatch");
2638 // Verify all indices.
2639 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2640 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2641 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2642 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2643 return emitOpError("mask index #") << (idx + 1) << " out of range";
2645 return success();
2648 LogicalResult
2649 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2650 ShuffleOp::Adaptor adaptor,
2651 SmallVectorImpl<Type> &inferredReturnTypes) {
2652 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2653 auto v1Rank = v1Type.getRank();
2654 // Construct resulting type: leading dimension matches mask
2655 // length, all trailing dimensions match the operands.
2656 SmallVector<int64_t, 4> shape;
2657 shape.reserve(v1Rank);
2658 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2659 // In the 0-D case there is no trailing shape to append.
2660 if (v1Rank > 0)
2661 llvm::append_range(shape, v1Type.getShape().drop_front());
2662 inferredReturnTypes.push_back(
2663 VectorType::get(shape, v1Type.getElementType()));
2664 return success();
2667 template <typename T>
2668 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2669 T expected = begin;
2670 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2671 return value == expected++;
2675 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2676 auto v1Type = getV1VectorType();
2677 auto v2Type = getV2VectorType();
2679 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2680 "Vector shuffle does not support scalable vectors");
2682 // For consistency: the 0-D shuffle return type is 1-D, so this cannot be a
2683 // fold; it must instead be canonicalized into a vector.broadcast.
2684 if (v1Type.getRank() == 0)
2685 return {};
2687 // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2688 auto mask = getMask();
2689 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
2690 return getV1();
2691 // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2692 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
2693 return getV2();
2695 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2696 if (!v1Attr || !v2Attr)
2697 return {};
2699 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2700 // manipulation.
2701 if (v1Type.getRank() != 1)
2702 return {};
2704 int64_t v1Size = v1Type.getDimSize(0);
2706 SmallVector<Attribute> results;
2707 auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
2708 auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
2709 for (int64_t maskIdx : mask) {
2710 Attribute indexedElm;
2711 // Select v1[0] for poison indices.
2712 // TODO: Return a partial poison vector when supported by the UB dialect.
2713 if (maskIdx == ShuffleOp::kPoisonIndex) {
2714 indexedElm = v1Elements[0];
2715 } else {
2716 indexedElm =
2717 maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
2720 results.push_back(indexedElm);
2723 return DenseElementsAttr::get(getResultVectorType(), results);
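// Worked example of the poison handling above (hypothetical IR, with -1
// standing for ShuffleOp::kPoisonIndex): given %c0 = dense<[0, 1]> and
// %c1 = dense<[2, 3]>,
//   %s = vector.shuffle %c0, %c1 [-1, 3] : vector<2xi32>, vector<2xi32>
// folds to dense<[0, 3]> : vector<2xi32>; the poison index conservatively
// selects v1[0] instead of reading out of bounds.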
2726 namespace {
2728 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2729 // to a broadcast.
2730 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2731 using OpRewritePattern::OpRewritePattern;
2733 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2734 PatternRewriter &rewriter) const override {
2735 VectorType v1VectorType = shuffleOp.getV1VectorType();
2736 ArrayRef<int64_t> mask = shuffleOp.getMask();
2737 if (v1VectorType.getRank() > 0)
2738 return failure();
2739 if (mask.size() != 1)
2740 return failure();
2741 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2742 if (mask[0] == 0)
2743 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2744 shuffleOp.getV1());
2745 else
2746 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2747 shuffleOp.getV2());
2748 return success();
2752 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2753 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2754 public:
2755 using OpRewritePattern::OpRewritePattern;
2757 LogicalResult matchAndRewrite(ShuffleOp op,
2758 PatternRewriter &rewriter) const override {
2759 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2760 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2762 if (!v1Splat || !v2Splat)
2763 return failure();
2765 if (v1Splat.getInput() != v2Splat.getInput())
2766 return failure();
2768 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2769 return success();
2773 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2774 /// vector.interleave.
2775 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2776 public:
2777 using OpRewritePattern::OpRewritePattern;
2779 LogicalResult matchAndRewrite(ShuffleOp op,
2780 PatternRewriter &rewriter) const override {
2781 VectorType resultType = op.getResultVectorType();
2782 if (resultType.isScalable())
2783 return rewriter.notifyMatchFailure(
2784 op, "ShuffleOp can't represent a scalable interleave");
2786 if (resultType.getRank() != 1)
2787 return rewriter.notifyMatchFailure(
2788 op, "ShuffleOp can't represent an n-D interleave");
2790 VectorType sourceType = op.getV1VectorType();
2791 if (sourceType != op.getV2VectorType() ||
2792 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2793 return rewriter.notifyMatchFailure(
2794 op, "ShuffleOp types don't match an interleave");
2797 ArrayRef<int64_t> shuffleMask = op.getMask();
2798 int64_t resultVectorSize = resultType.getNumElements();
2799 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2800 int64_t maskValueA = shuffleMask[i * 2];
2801 int64_t maskValueB = shuffleMask[(i * 2) + 1];
2802 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2803 return rewriter.notifyMatchFailure(op,
2804 "ShuffleOp mask not interleaving");
2807 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2808 return success();
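// Illustrative match (hypothetical IR): the mask [0, 2, 1, 3] on two
// vector<2xf32> operands interleaves them, so
//   %0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xf32>, vector<2xf32>
// is rewritten to vector.interleave %a, %b.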
2812 } // namespace
2814 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2815 MLIRContext *context) {
2816 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2817 context);
2820 //===----------------------------------------------------------------------===//
2821 // InsertElementOp
2822 //===----------------------------------------------------------------------===//
2824 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2825 SetIntRangeFn setResultRanges) {
2826 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2829 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2830 Value source, Value dest) {
2831 build(builder, result, source, dest, {});
2834 LogicalResult InsertElementOp::verify() {
2835 auto dstVectorType = getDestVectorType();
2836 if (dstVectorType.getRank() == 0) {
2837 if (getPosition())
2838 return emitOpError("expected position to be empty with 0-D vector");
2839 return success();
2841 if (dstVectorType.getRank() != 1)
2842 return emitOpError("unexpected >1 vector rank");
2843 if (!getPosition())
2844 return emitOpError("expected position for 1-D vector");
2845 return success();
2848 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2849 // Skip the 0-D vector here.
2850 if (!adaptor.getPosition())
2851 return {};
2853 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2854 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2855 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2856 if (!src || !dst || !pos)
2857 return {};
2859 if (src.getType() != getDestVectorType().getElementType())
2860 return {};
2862 auto dstElements = dst.getValues<Attribute>();
2864 SmallVector<Attribute> results(dstElements);
2866 uint64_t posIdx = pos.getInt();
2867 if (posIdx >= results.size())
2868 return {};
2869 results[posIdx] = src;
2871 return DenseElementsAttr::get(getDestVectorType(), results);
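// Illustrative example (hypothetical IR): with constant operands
// %v = dense<[10, 20]> : vector<2xi32>, %s = 30 : i32 and position 1,
//   %r = vector.insertelement %s, %v[%c1 : i32] : vector<2xi32>
// folds to dense<[10, 30]> : vector<2xi32>.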
2874 //===----------------------------------------------------------------------===//
2875 // InsertOp
2876 //===----------------------------------------------------------------------===//
2878 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2879 SetIntRangeFn setResultRanges) {
2880 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2883 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2884 Value source, Value dest, int64_t position) {
2885 build(builder, result, source, dest, ArrayRef<int64_t>{position});
2888 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2889 Value source, Value dest, OpFoldResult position) {
2890 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2893 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2894 Value source, Value dest,
2895 ArrayRef<int64_t> position) {
2896 SmallVector<OpFoldResult> posVals;
2897 posVals.reserve(position.size());
2898 llvm::transform(position, std::back_inserter(posVals),
2899 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2900 build(builder, result, source, dest, posVals);
2903 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2904 Value source, Value dest,
2905 ArrayRef<OpFoldResult> position) {
2906 SmallVector<int64_t> staticPos;
2907 SmallVector<Value> dynamicPos;
2908 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2909 build(builder, result, source, dest, dynamicPos,
2910 builder.getDenseI64ArrayAttr(staticPos));
2913 LogicalResult InsertOp::verify() {
2914 SmallVector<OpFoldResult> position = getMixedPosition();
2915 auto destVectorType = getDestVectorType();
2916 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2917 return emitOpError(
2918 "expected position attribute of rank no greater than dest vector rank");
2919 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2920 if (srcVectorType &&
2921 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2922 static_cast<unsigned>(destVectorType.getRank())))
2923 return emitOpError("expected position attribute rank + source rank to "
2924 "match dest vector rank");
2925 if (!srcVectorType &&
2926 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2927 return emitOpError(
2928 "expected position attribute rank to match the dest vector rank");
2929 for (auto [idx, pos] : llvm::enumerate(position)) {
2930 if (auto attr = pos.dyn_cast<Attribute>()) {
2931 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2932 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
2933 destVectorType.getDimSize(idx))) {
2934 return emitOpError("expected position attribute #")
2935 << (idx + 1)
2936 << " to be a non-negative integer smaller than the "
2937 "corresponding "
2938 "dest vector dimension";
2942 return success();
2945 namespace {
2947 // If insertOp is only inserting unit dimensions, it can be transformed into
2948 // a broadcast.
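// For example (illustrative):
//   %0 = vector.insert %src, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// is rewritten to
//   %0 = vector.broadcast %src : vector<4xf32> to vector<1x1x4xf32>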
2949 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2950 public:
2951 using OpRewritePattern::OpRewritePattern;
2953 LogicalResult matchAndRewrite(InsertOp insertOp,
2954 PatternRewriter &rewriter) const override {
2955 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2956 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2957 srcVecType.getNumElements())
2958 return failure();
2959 rewriter.replaceOpWithNewOp<BroadcastOp>(
2960 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2961 return success();
2965 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
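/// For example (illustrative; %x is any scalar):
///   %a = vector.splat %x : vector<4xf32>
///   %b = vector.splat %x : vector<2x4xf32>
///   %0 = vector.insert %a, %b [0] : vector<4xf32> into vector<2x4xf32>
/// is rewritten to
///   %0 = vector.splat %x : vector<2x4xf32>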
2966 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2967 public:
2968 using OpRewritePattern::OpRewritePattern;
2970 LogicalResult matchAndRewrite(InsertOp op,
2971 PatternRewriter &rewriter) const override {
2972 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2973 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2975 if (!srcSplat || !dstSplat)
2976 return failure();
2978 if (srcSplat.getInput() != dstSplat.getInput())
2979 return failure();
2981 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2982 return success();
2986 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
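// For example (illustrative):
//   %c = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %s = arith.constant 7 : i32
//   %0 = vector.insert %s, %c [1] : i32 into vector<4xi32>
// becomes
//   %0 = arith.constant dense<[0, 7, 2, 3]> : vector<4xi32>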
2987 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2988 public:
2989 using OpRewritePattern::OpRewritePattern;
2991 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2992 // unless the destination vector constant has a single use.
2993 static constexpr int64_t vectorSizeFoldThreshold = 256;
2995 LogicalResult matchAndRewrite(InsertOp op,
2996 PatternRewriter &rewriter) const override {
2997 // TODO: Canonicalization for dynamic position not implemented yet.
2998 if (op.hasDynamicPosition())
2999 return failure();
3001 // Return if 'InsertOp' operand is not defined by a compatible vector
3002 // ConstantOp.
3003 TypedValue<VectorType> destVector = op.getDest();
3004 Attribute vectorDestCst;
3005 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3006 return failure();
3007 auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3008 if (!denseDest)
3009 return failure();
3011 VectorType destTy = destVector.getType();
3012 if (destTy.isScalable())
3013 return failure();
3015 // Make sure we do not create too many large constants.
3016 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3017 !destVector.hasOneUse())
3018 return failure();
3020 Value sourceValue = op.getSource();
3021 Attribute sourceCst;
3022 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3023 return failure();
3025 // Calculate the linearized position of the contiguous chunk of elements to
3026 // insert.
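// For example (illustrative): with destTy = vector<3x4x5xf32> and static
// position [1, 2], completePositions = [1, 2, 0], the strides are [20, 5, 1],
// and insertBeginPosition = 1 * 20 + 2 * 5 = 30.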
3027 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3028 copy(op.getStaticPosition(), completePositions.begin());
3029 int64_t insertBeginPosition =
3030 linearize(completePositions, computeStrides(destTy.getShape()));
3032 SmallVector<Attribute> insertedValues;
3033 Type destEltType = destTy.getElementType();
3035 // The `convertIntegerAttr` method specifically handles the case
3036 // for `llvm.mlir.constant` which can hold an attribute with a
3037 // different type than the return type.
3038 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3039 for (auto value : denseSource.getValues<Attribute>())
3040 insertedValues.push_back(convertIntegerAttr(value, destEltType));
3041 } else {
3042 insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
3045 auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
3046 copy(insertedValues, allValues.begin() + insertBeginPosition);
3047 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3049 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3050 return success();
3053 private:
3054 /// Converts `attr` to an IntegerAttr of the expected type if there's
3055 /// a type mismatch.
3056 Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
3057 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3058 if (intAttr.getType() != expectedType)
3059 return IntegerAttr::get(expectedType, intAttr.getInt());
3061 return attr;
3065 } // namespace
3067 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3068 MLIRContext *context) {
3069 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3070 InsertOpConstantFolder>(context);
3071 results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
3074 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3075 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3076 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3077 // (type mismatch).
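// A poison index also folds the op away, e.g. (illustrative; -1 encodes
// kPoisonIndex):
//   %0 = vector.insert %v, %dest [-1] : f32 into vector<4xf32>
// folds to
//   %0 = ub.poison : vector<4xf32>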
3078 if (getNumIndices() == 0 && getSourceType() == getType())
3079 return getSource();
3080 if (auto res = foldPoisonIndexInsertExtractOp(
3081 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3082 return res;
3084 return {};
3087 //===----------------------------------------------------------------------===//
3088 // InsertStridedSliceOp
3089 //===----------------------------------------------------------------------===//
3091 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3092 Value source, Value dest,
3093 ArrayRef<int64_t> offsets,
3094 ArrayRef<int64_t> strides) {
3095 result.addOperands({source, dest});
3096 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3097 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3098 result.addTypes(dest.getType());
3099 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3100 offsetsAttr);
3101 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3102 stridesAttr);
3105 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3106 template <typename OpType>
3107 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3108 ArrayAttr arrayAttr,
3109 ArrayRef<int64_t> shape,
3110 StringRef attrName) {
3111 if (arrayAttr.size() > shape.size())
3112 return op.emitOpError("expected ")
3113 << attrName << " attribute of rank no greater than vector rank";
3114 return success();
3117 // Returns success if all integers in `arrayAttr` are confined to the given
3118 // interval. If `halfOpen` is true then the admissible interval is [min, max).
3119 // Otherwise, the admissible interval is [min, max].
3120 template <typename OpType>
3121 static LogicalResult
3122 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3123 int64_t max, StringRef attrName,
3124 bool halfOpen = true) {
3125 for (auto attr : arrayAttr) {
3126 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3127 auto upper = max;
3128 if (!halfOpen)
3129 upper += 1;
3130 if (val < min || val >= upper)
3131 return op.emitOpError("expected ") << attrName << " to be confined to ["
3132 << min << ", " << upper << ")";
3134 return success();
3137 // Returns success if all integers in `arrayAttr` are confined per dimension
3138 // by `shape`. If `halfOpen` is true then the admissible interval for dim i is
3139 // [min, shape[i]). Otherwise, the admissible interval is [min, shape[i]].
3140 template <typename OpType>
3141 static LogicalResult
3142 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3143 ArrayRef<int64_t> shape, StringRef attrName,
3144 bool halfOpen = true, int64_t min = 0) {
3145 for (auto [index, attrDimPair] :
3146 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3147 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3148 int64_t max = std::get<1>(attrDimPair);
3149 if (!halfOpen)
3150 max += 1;
3151 if (val < min || val >= max)
3152 return op.emitOpError("expected ")
3153 << attrName << " dimension " << index << " to be confined to ["
3154 << min << ", " << max << ")";
3156 return success();
3159 // Returns success if, for all indices i = 0..shape.size()-1, the sum
3160 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3161 // is confined by `shape`. If `halfOpen` is true then the admissible interval
3162 // for dimension i is [min, shape[i]). Otherwise, the admissible interval is
3163 // [min, shape[i]].
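// For example (illustrative): for arrayAttr1 = [2], arrayAttr2 = [3],
// shape = [4] and halfOpen = false, the sum 5 exceeds the inclusive upper
// bound 4, so verification fails.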
3164 template <typename OpType>
3165 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3166 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3167 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3168 bool halfOpen = true, int64_t min = 1) {
3169 assert(arrayAttr1.size() <= shape.size());
3170 assert(arrayAttr2.size() <= shape.size());
3171 for (auto [index, it] :
3172 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3173 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3174 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3175 int64_t max = std::get<2>(it);
3176 if (!halfOpen)
3177 max += 1;
3178 if (val1 + val2 < 0 || val1 + val2 >= max)
3179 return op.emitOpError("expected sum(")
3180 << attrName1 << ", " << attrName2 << ") dimension " << index
3181 << " to be confined to [" << min << ", " << max << ")";
3183 return success();
3186 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3187 MLIRContext *context) {
3188 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3189 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3191 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3194 LogicalResult InsertStridedSliceOp::verify() {
3195 auto sourceVectorType = getSourceVectorType();
3196 auto destVectorType = getDestVectorType();
3197 auto offsets = getOffsetsAttr();
3198 auto strides = getStridesAttr();
3199 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3200 return emitOpError(
3201 "expected offsets of same size as destination vector rank");
3202 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3203 return emitOpError("expected strides of same size as source vector rank");
3204 if (sourceVectorType.getRank() > destVectorType.getRank())
3205 return emitOpError(
3206 "expected source rank to be no greater than destination rank");
3208 auto sourceShape = sourceVectorType.getShape();
3209 auto destShape = destVectorType.getShape();
3210 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3211 destShape.size() - sourceShape.size(), 0);
3212 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3213 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3214 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3215 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3216 offName)) ||
3217 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3218 /*max=*/1, stridesName,
3219 /*halfOpen=*/false)) ||
3220 failed(isSumOfIntegerArrayAttrConfinedToShape(
3221 *this, offsets,
3222 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3223 offName, "source vector shape",
3224 /*halfOpen=*/false, /*min=*/1)))
3225 return failure();
3227 unsigned rankDiff = destShape.size() - sourceShape.size();
3228 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3229 if (sourceVectorType.getScalableDims()[idx] !=
3230 destVectorType.getScalableDims()[idx + rankDiff]) {
3231 return emitOpError("mismatching scalable flags (at source vector idx=")
3232 << idx << ")";
3234 if (sourceVectorType.getScalableDims()[idx]) {
3235 auto sourceSize = sourceShape[idx];
3236 auto destSize = destShape[idx + rankDiff];
3237 if (sourceSize != destSize) {
3238 return emitOpError("expected size at idx=")
3239 << idx
3240 << (" to match the corresponding base size from the input "
3241 "vector (")
3242 << sourceSize << (" vs ") << destSize << (")");
3247 return success();
3250 namespace {
3251 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3252 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
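/// For example (illustrative; %x is any scalar):
///   %a = vector.splat %x : vector<2xf32>
///   %b = vector.splat %x : vector<4xf32>
///   %0 = vector.insert_strided_slice %a, %b {offsets = [1], strides = [1]}
///       : vector<2xf32> into vector<4xf32>
/// folds to %b.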
3253 class FoldInsertStridedSliceSplat final
3254 : public OpRewritePattern<InsertStridedSliceOp> {
3255 public:
3256 using OpRewritePattern::OpRewritePattern;
3258 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3259 PatternRewriter &rewriter) const override {
3260 auto srcSplatOp =
3261 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3262 auto destSplatOp =
3263 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3265 if (!srcSplatOp || !destSplatOp)
3266 return failure();
3268 if (srcSplatOp.getInput() != destSplatOp.getInput())
3269 return failure();
3271 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3272 return success();
3276 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3277 /// to dst.
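/// For example (illustrative):
///   %e = vector.extract_strided_slice %dst
///       {offsets = [0], sizes = [2], strides = [1]}
///       : vector<4xf32> to vector<2xf32>
///   %0 = vector.insert_strided_slice %e, %dst {offsets = [0], strides = [1]}
///       : vector<2xf32> into vector<4xf32>
/// folds to %dst.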
3278 class FoldInsertStridedSliceOfExtract final
3279 : public OpRewritePattern<InsertStridedSliceOp> {
3280 public:
3281 using OpRewritePattern::OpRewritePattern;
3283 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3284 PatternRewriter &rewriter) const override {
3285 auto extractStridedSliceOp =
3286 insertStridedSliceOp.getSource()
3287 .getDefiningOp<vector::ExtractStridedSliceOp>();
3289 if (!extractStridedSliceOp)
3290 return failure();
3292 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3293 return failure();
3295 // Check if they have the same strides and offsets.
3296 if (extractStridedSliceOp.getStrides() !=
3297 insertStridedSliceOp.getStrides() ||
3298 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3299 return failure();
3301 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3302 return success();
3306 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3307 // ConstantOp.
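// For example (illustrative):
//   %c = arith.constant dense<0> : vector<4xi32>
//   %s = arith.constant dense<1> : vector<2xi32>
//   %0 = vector.insert_strided_slice %s, %c {offsets = [1], strides = [1]}
//       : vector<2xi32> into vector<4xi32>
// becomes
//   %0 = arith.constant dense<[0, 1, 1, 0]> : vector<4xi32>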
3308 class InsertStridedSliceConstantFolder final
3309 : public OpRewritePattern<InsertStridedSliceOp> {
3310 public:
3311 using OpRewritePattern::OpRewritePattern;
3313 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3314 // unless the destination vector constant has a single use.
3315 static constexpr int64_t vectorSizeFoldThreshold = 256;
3317 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3318 PatternRewriter &rewriter) const override {
3319 // Return if 'InsertStridedSliceOp' operand is not defined by a compatible
3320 // vector ConstantOp.
3321 TypedValue<VectorType> destVector = op.getDest();
3322 Attribute vectorDestCst;
3323 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3324 return failure();
3326 VectorType destTy = destVector.getType();
3327 if (destTy.isScalable())
3328 return failure();
3330 // Make sure we do not create too many large constants.
3331 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3332 !destVector.hasOneUse())
3333 return failure();
3335 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3337 TypedValue<VectorType> sourceValue = op.getSource();
3338 Attribute sourceCst;
3339 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3340 return failure();
3342 // TODO: Handle non-unit strides when they become available.
3343 if (op.hasNonUnitStrides())
3344 return failure();
3346 VectorType sliceVecTy = sourceValue.getType();
3347 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3348 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3349 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3350 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3352 // Calculate the destination element indices by enumerating all slice
3353 // positions within the destination and linearizing them. The enumeration
3354 // order is lexicographic which yields a sequence of monotonically
3355 // increasing linearized position indices.
3356 // Because the destination may have higher dimensionality than the slice,
3357 // we keep track of two overlapping sets of positions and offsets.
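// For example (illustrative): inserting a constant vector<2xi32> slice at
// offsets [1, 2] into a constant vector<3x8xi32> writes the two slice
// elements at linearized positions 10 and 11 (destStrides = [8, 1]).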
3358 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3359 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3360 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3361 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3362 MutableArrayRef<int64_t> currSlicePosition(
3363 currDestPosition.begin() + rankDifference, currDestPosition.end());
3364 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3365 offsets.end());
3366 do {
3367 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3368 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3369 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3370 "Invalid slice element");
3371 newValues[linearizedPosition] = *sliceValuesIt;
3372 ++sliceValuesIt;
3373 } while (succeeded(
3374 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3376 auto newAttr = DenseElementsAttr::get(destTy, newValues);
3377 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3378 return success();
3382 } // namespace
3384 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3385 RewritePatternSet &results, MLIRContext *context) {
3386 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3387 InsertStridedSliceConstantFolder>(context);
3390 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3391 if (getSourceVectorType() == getDestVectorType())
3392 return getSource();
3393 return {};
3396 //===----------------------------------------------------------------------===//
3397 // OuterProductOp
3398 //===----------------------------------------------------------------------===//
3400 /// Build an op without mask, use the type of `acc` as the return type.
3401 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3402 Value lhs, Value rhs, Value acc) {
3403 result.addOperands({lhs, rhs, acc});
3404 result.addTypes(acc.getType());
3407 void OuterProductOp::print(OpAsmPrinter &p) {
3408 p << " " << getLhs() << ", " << getRhs();
3409 if (getAcc()) {
3410 p << ", " << getAcc();
3411 p.printOptionalAttrDict((*this)->getAttrs());
3413 p << " : " << getLhs().getType() << ", " << getRhs().getType();
3416 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3417 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3418 Type tLHS, tRHS;
3419 if (parser.parseOperandList(operandsInfo) ||
3420 parser.parseOptionalAttrDict(result.attributes) ||
3421 parser.parseColonType(tLHS) || parser.parseComma() ||
3422 parser.parseType(tRHS))
3423 return failure();
3424 if (operandsInfo.size() < 2)
3425 return parser.emitError(parser.getNameLoc(),
3426 "expected at least 2 operands");
3427 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3428 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3429 if (!vLHS)
3430 return parser.emitError(parser.getNameLoc(),
3431 "expected vector type for operand #1");
3433 VectorType resType;
3434 if (vRHS) {
3435 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3436 vRHS.getScalableDims()[0]};
3437 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3438 vLHS.getElementType(), scalableDimsRes);
3439 } else {
3440 // Scalar RHS operand
3441 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3442 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3443 scalableDimsRes);
3446 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3447 result.attributes.append(
3448 OuterProductOp::getKindAttrName(result.name),
3449 CombiningKindAttr::get(result.getContext(),
3450 OuterProductOp::getDefaultKind()));
3453 return failure(
3454 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3455 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3456 (operandsInfo.size() > 2 &&
3457 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3458 parser.addTypeToList(resType, result.types));
3461 LogicalResult OuterProductOp::verify() {
3462 Type tRHS = getOperandTypeRHS();
3463 VectorType vLHS = getOperandVectorTypeLHS(),
3464 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3465 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3467 if (vLHS.getRank() != 1)
3468 return emitOpError("expected 1-d vector for operand #1");
3470 if (vRHS) {
3471 // Proper OUTER operation.
3472 if (vRHS.getRank() != 1)
3473 return emitOpError("expected 1-d vector for operand #2");
3474 if (vRES.getRank() != 2)
3475 return emitOpError("expected 2-d vector result");
3476 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3477 return emitOpError("expected #1 operand dim to match result dim #1");
3478 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3479 return emitOpError("expected #2 operand dim to match result dim #2");
3480 if (vLHS.isScalable() && !vRHS.isScalable()) {
3481 // This restriction reflects what's currently supported in terms of
3482 // scalable vectors. However, we could relax this if there's a use case.
3483 return emitOpError(
3484 "expected either both or only #2 operand dim to be scalable");
3486 } else {
3487 // An AXPY operation.
3488 if (vRES.getRank() != 1)
3489 return emitOpError("expected 1-d vector result");
3490 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3491 return emitOpError("expected #1 operand dim to match result dim #1");
3494 if (vACC && vACC != vRES)
3495 return emitOpError("expected operand #3 of same type as result type");
3497 // Verify supported combining kind.
3498 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3499 return emitOpError("unsupported outerproduct type");
3501 return success();
3504 // MaskableOpInterface methods.
3506 /// Returns the mask type expected by this operation. Mostly used for
3507 /// verification purposes. It requires the operation to be vectorized.
3508 Type OuterProductOp::getExpectedMaskType() {
3509 auto vecType = this->getResultVectorType();
3510 return VectorType::get(vecType.getShape(),
3511 IntegerType::get(vecType.getContext(), /*width=*/1),
3512 vecType.getScalableDims());
3515 //===----------------------------------------------------------------------===//
3516 // ExtractStridedSliceOp
3517 //===----------------------------------------------------------------------===//
3519 // Inference works as follows:
3520 // 1. Add 'sizes' from prefix of dims in 'offsets'.
3521 // 2. Add sizes from 'vectorType' for remaining dims.
3522 // Scalable flags are inherited from 'vectorType'.
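// For example (illustrative): for vectorType = vector<4x8x16xf32> with
// offsets = [0, 2], sizes = [2, 4] and strides = [1, 1], the inferred
// result type is vector<2x4x16xf32>.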
3523 static Type inferStridedSliceOpResultType(VectorType vectorType,
3524 ArrayAttr offsets, ArrayAttr sizes,
3525 ArrayAttr strides) {
3526 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3527 SmallVector<int64_t, 4> shape;
3528 shape.reserve(vectorType.getRank());
3529 unsigned idx = 0;
3530 for (unsigned e = offsets.size(); idx < e; ++idx)
3531 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3532 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3533 shape.push_back(vectorType.getShape()[idx]);
3535 return VectorType::get(shape, vectorType.getElementType(),
3536 vectorType.getScalableDims());
3539 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3540 Value source, ArrayRef<int64_t> offsets,
3541 ArrayRef<int64_t> sizes,
3542 ArrayRef<int64_t> strides) {
3543 result.addOperands(source);
3544 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3545 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3546 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3547 result.addTypes(
3548 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3549 offsetsAttr, sizesAttr, stridesAttr));
3550 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3551 offsetsAttr);
3552 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3553 sizesAttr);
3554 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3555 stridesAttr);
3558 LogicalResult ExtractStridedSliceOp::verify() {
3559 auto type = getSourceVectorType();
3560 auto offsets = getOffsetsAttr();
3561 auto sizes = getSizesAttr();
3562 auto strides = getStridesAttr();
3563 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3564 return emitOpError(
3565 "expected offsets, sizes and strides attributes of same size");
3567 auto shape = type.getShape();
3568 auto offName = getOffsetsAttrName();
3569 auto sizesName = getSizesAttrName();
3570 auto stridesName = getStridesAttrName();
3571 if (failed(
3572 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3573 failed(
3574 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3575 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3576 stridesName)) ||
3577 failed(
3578 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3579 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3580 /*halfOpen=*/false,
3581 /*min=*/1)) ||
3582 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3583 /*max=*/1, stridesName,
3584 /*halfOpen=*/false)) ||
3585 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3586 shape, offName, sizesName,
3587 /*halfOpen=*/false)))
3588 return failure();
3590 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3591 offsets, sizes, strides);
3592 if (getResult().getType() != resultType)
3593 return emitOpError("expected result type to be ") << resultType;
3595 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3596 if (type.getScalableDims()[idx]) {
3597 auto inputDim = type.getShape()[idx];
3598 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3599 if (inputDim != inputSize)
3600 return emitOpError("expected size at idx=")
3601 << idx
3602 << (" to match the corresponding base size from the input "
3603 "vector (")
3604 << inputSize << (" vs ") << inputDim << (")");
3608 return success();
3611 // When the source of ExtractStridedSliceOp comes from a chain of
3612 // InsertStridedSliceOp ops, try to use the source of one of those inserts if
3613 // the extracted vector is a subset of one of the vectors inserted.
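// For example (illustrative):
//   %i = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//       : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %i
//       {offsets = [3], sizes = [2], strides = [1]}
//       : vector<8xf32> to vector<2xf32>
// is folded to extract directly from %src with rebased offsets:
//   %e = vector.extract_strided_slice %src
//       {offsets = [1], sizes = [2], strides = [1]}
//       : vector<4xf32> to vector<2xf32>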
3614 static LogicalResult
3615 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3616 // Helper to extract integer out of ArrayAttr.
3617 auto getElement = [](ArrayAttr array, int idx) {
3618 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3620 ArrayAttr extractOffsets = op.getOffsets();
3621 ArrayAttr extractStrides = op.getStrides();
3622 ArrayAttr extractSizes = op.getSizes();
3623 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3624 while (insertOp) {
3625 if (op.getSourceVectorType().getRank() !=
3626 insertOp.getSourceVectorType().getRank())
3627 return failure();
3628 ArrayAttr insertOffsets = insertOp.getOffsets();
3629 ArrayAttr insertStrides = insertOp.getStrides();
3630 // If the rank of extract is greater than the rank of insert, we are likely
3631 // extracting a partial chunk of the vector inserted.
3632 if (extractOffsets.size() > insertOffsets.size())
3633 return failure();
3634 bool partialOverlap = false;
3635 bool disjoint = false;
3636 SmallVector<int64_t, 4> offsetDiffs;
3637 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3638 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3639 return failure();
3640 int64_t start = getElement(insertOffsets, dim);
3641 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3642 int64_t offset = getElement(extractOffsets, dim);
3643 int64_t size = getElement(extractSizes, dim);
3644 // Check if the start of the extract offset is in the interval inserted.
3645 if (start <= offset && offset < end) {
3646 // If the extract interval overlaps but is not fully included we may
3647 // have a partial overlap that will prevent any folding.
3648 if (offset + size > end)
3649 partialOverlap = true;
3650 offsetDiffs.push_back(offset - start);
3651 continue;
3653 disjoint = true;
3654 break;
3656 // The extracted element chunk is a subset of the inserted element chunk.
3657 if (!disjoint && !partialOverlap) {
3658 op.setOperand(insertOp.getSource());
3659 // OpBuilder is only used as a helper to build an I64ArrayAttr.
3660 OpBuilder b(op.getContext());
3661 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3662 return success();
3664 // If the chunk extracted is disjoint from the chunk inserted, keep looking
3665 // in the insert chain.
3666 if (disjoint)
3667 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3668 else {
3669 // The extracted vector partially overlaps the inserted vector; we cannot
3670 // fold.
3671 return failure();
3674 return failure();
3677 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3678 if (getSourceVectorType() == getResult().getType())
3679 return getVector();
3680 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3681 return getResult();
3682 return {};
3685 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3686 populateFromInt64AttrArray(getOffsets(), results);
3689 namespace {
3691 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3692 // ConstantMaskOp.
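// For example (illustrative):
//   %m = vector.constant_mask [4] : vector<8xi1>
//   %0 = vector.extract_strided_slice %m
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xi1> to vector<4xi1>
// becomes
//   %0 = vector.constant_mask [2] : vector<4xi1>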
3693 class StridedSliceConstantMaskFolder final
3694 : public OpRewritePattern<ExtractStridedSliceOp> {
3695 public:
3696 using OpRewritePattern::OpRewritePattern;
3698 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3699 PatternRewriter &rewriter) const override {
3700 // Return if 'extractStridedSliceOp' operand is not defined by a
3701 // ConstantMaskOp.
3702 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3703 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3704 if (!constantMaskOp)
3705 return failure();
3706 // Return if 'extractStridedSliceOp' has non-unit strides.
3707 if (extractStridedSliceOp.hasNonUnitStrides())
3708 return failure();
3709 // Gather constant mask dimension sizes.
3710 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3711 // Gather strided slice offsets and sizes.
3712 SmallVector<int64_t, 4> sliceOffsets;
3713 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3714 sliceOffsets);
3715 SmallVector<int64_t, 4> sliceSizes;
3716 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3718 // Compute slice of vector mask region.
3719 SmallVector<int64_t, 4> sliceMaskDimSizes;
3720 sliceMaskDimSizes.reserve(maskDimSizes.size());
3721 for (auto [maskDimSize, sliceOffset, sliceSize] :
3722 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3723 int64_t sliceMaskDimSize = std::max(
3724 static_cast<int64_t>(0),
3725 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3726 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3728 // Add unchanged dimensions.
3729 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3730 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3731 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3732 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3733 // region is a conjunction of mask dim intervals).
3734 if (llvm::is_contained(sliceMaskDimSizes, 0))
3735 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3737 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3738 // region.
3739 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3740 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3741 sliceMaskDimSizes);
3742 return success();
3746 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3747 class StridedSliceSplatConstantFolder final
3748 : public OpRewritePattern<ExtractStridedSliceOp> {
3749 public:
3750 using OpRewritePattern::OpRewritePattern;
3752 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3753 PatternRewriter &rewriter) const override {
3754 // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3755 // ConstantOp.
3756 Value sourceVector = extractStridedSliceOp.getVector();
3757 Attribute vectorCst;
3758 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3759 return failure();
3761 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3762 if (!splat)
3763 return failure();
3765 auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3766 splat.getSplatValue<Attribute>());
3767 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3768 newAttr);
3769 return success();
3773 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3774 // ConstantOp.
3775 class StridedSliceNonSplatConstantFolder final
3776 : public OpRewritePattern<ExtractStridedSliceOp> {
3777 public:
3778 using OpRewritePattern::OpRewritePattern;
3780 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3781 PatternRewriter &rewriter) const override {
3782 // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3783 // ConstantOp.
3784 Value sourceVector = extractStridedSliceOp.getVector();
3785 Attribute vectorCst;
3786 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3787 return failure();
3789 // The splat case is handled by `StridedSliceSplatConstantFolder`.
3790 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3791 if (!dense || dense.isSplat())
3792 return failure();
3794 // TODO: Handle non-unit strides when they become available.
3795 if (extractStridedSliceOp.hasNonUnitStrides())
3796 return failure();
3798 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3799 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3800 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3802 VectorType sliceVecTy = extractStridedSliceOp.getType();
3803 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3804 int64_t sliceRank = sliceVecTy.getRank();
3806 // Expand offsets and sizes to match the vector rank.
3807 SmallVector<int64_t, 4> offsets(sliceRank, 0);
3808 copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3810 SmallVector<int64_t, 4> sizes(sourceShape);
3811 copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3813 // Calculate the slice elements by enumerating all slice positions and
3814 // linearizing them. The enumeration order is lexicographic which yields a
3815 // sequence of monotonically increasing linearized position indices.
3816 auto denseValuesBegin = dense.value_begin<Attribute>();
3817 SmallVector<Attribute> sliceValues;
3818 sliceValues.reserve(sliceVecTy.getNumElements());
3819 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3820 do {
3821 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3822 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3823 "Invalid index");
3824 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3825 } while (
3826 succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3828 assert(static_cast<int64_t>(sliceValues.size()) ==
3829 sliceVecTy.getNumElements() &&
3830 "Invalid number of slice elements");
3831 auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3832 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3833 newAttr);
3834 return success();
3838 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3839 // BroadcastOp(ExtractStridedSliceOp).
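// For example (illustrative), when the broadcast dimensions are untouched:
//   %b = vector.broadcast %v : vector<8xf32> to vector<4x8xf32>
//   %0 = vector.extract_strided_slice %b
//       {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x8xf32>
// becomes
//   %0 = vector.broadcast %v : vector<8xf32> to vector<2x8xf32>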
3840 class StridedSliceBroadcast final
3841 : public OpRewritePattern<ExtractStridedSliceOp> {
3842 public:
3843 using OpRewritePattern::OpRewritePattern;
3845 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3846 PatternRewriter &rewriter) const override {
3847 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3848 if (!broadcast)
3849 return failure();
3850 auto srcVecType =
3851 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3852 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3853 auto dstVecType = llvm::cast<VectorType>(op.getType());
3854 unsigned dstRank = dstVecType.getRank();
3855 unsigned rankDiff = dstRank - srcRank;
3856 // Check if the innermost dimensions of the source of the broadcast are the
3857 // same as the destination of the extract. If this is the case we can just
3858 // use a broadcast as the original dimensions are untouched.
3859 bool lowerDimMatch = true;
3860 for (unsigned i = 0; i < srcRank; i++) {
3861 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3862 lowerDimMatch = false;
3863 break;
3866 Value source = broadcast.getSource();
3867 // If the inner dimensions don't match, it means we need to extract from the
3868 // source of the original broadcast and then broadcast the extracted value.
3869 // We also need to handle degenerate cases where the source is effectively
3870 // just a single scalar.
3871 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3872 if (!lowerDimMatch && !isScalarSrc) {
3873 source = rewriter.create<ExtractStridedSliceOp>(
3874 op->getLoc(), source,
3875 getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3876 getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3877 getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3879 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3880 return success();
3884 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
3885 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3886 public:
3887 using OpRewritePattern::OpRewritePattern;
3889 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3890 PatternRewriter &rewriter) const override {
3891 auto splat = op.getVector().getDefiningOp<SplatOp>();
3892 if (!splat)
3893 return failure();
3894 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3895 return success();
3899 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3900 /// slice is contiguous, into extract and shape_cast.
3902 /// Example:
3903 /// Before:
3904 /// %1 = vector.extract_strided_slice %arg0 {
3905 /// offsets = [0, 0, 0, 0, 0],
3906 /// sizes = [1, 1, 1, 1, 8],
3907 /// strides = [1, 1, 1, 1, 1]
3908 /// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3909 /// After:
3910 /// %0 = vector.extract %arg0[0, 0, 0, 0]
3911 /// : vector<8xi8> from vector<8x1x1x2x8xi8>
3912 /// %1 = vector.shape_cast %0
3913 /// : vector<8xi8> to vector<1x1x1x1x8xi8>
3915 class ContiguousExtractStridedSliceToExtract final
3916 : public OpRewritePattern<ExtractStridedSliceOp> {
3917 public:
3918 using OpRewritePattern::OpRewritePattern;
3920 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3921 PatternRewriter &rewriter) const override {
3922 if (op.hasNonUnitStrides())
3923 return failure();
3924 Value source = op.getOperand();
3925 auto sourceType = cast<VectorType>(source.getType());
3926 if (sourceType.isScalable() || sourceType.getRank() == 0)
3927 return failure();
3929 // Compute the number of offsets to pass to ExtractOp::build. That is the
3930 // difference between the source rank and the desired slice rank. We walk
3931 // the dimensions from innermost out, and stop when the next slice dimension
3932 // is not full-size.
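// For example (illustrative): for sizes = [1, 1, 8] extracted from
// vector<2x4x8xi8>, the walk stops at numOffsets = 2 because the innermost
// size (8) spans its full source dimension.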
3933 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3934 int numOffsets;
3935 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3936 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3937 break;
3940 // If the created extract op would have no offsets, then this whole
3941 // extract_strided_slice is the identity and should have been handled by
3942 // other canonicalizations.
3943 if (numOffsets == 0)
3944 return failure();
3946 // If not even the inner-most dimension is full-size, this op can't be
3947 // rewritten as an ExtractOp.
3948 if (numOffsets == sourceType.getRank() &&
3949 static_cast<int>(sizes.size()) == sourceType.getRank())
3950 return failure();
3952 // The outer dimensions must have unit size.
3953 for (int i = 0; i < numOffsets; ++i) {
3954 if (sizes[i] != 1)
3955 return failure();
3958 // Avoid generating slices that have leading unit dimensions. The shape_cast
3959 // op that we create below would otherwise go through inefficient generic
3960 // fallback patterns (ShapeCastOpRewritePattern).
3961 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
3962 sizes[numOffsets] == 1) {
3963 ++numOffsets;
3966 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
3967 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
3968 Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
3969 extractOffsets);
3970 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
3971 return success();
3975 } // namespace
3977 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3978 RewritePatternSet &results, MLIRContext *context) {
3979 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
3980 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3981 results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3982 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3983 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
3984 context);
3987 //===----------------------------------------------------------------------===//
3988 // TransferReadOp
3989 //===----------------------------------------------------------------------===//
3991 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3992 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3993 VectorType vectorType, Value source,
3994 ValueRange indices, AffineMapAttr permutationMapAttr,
3995 /*optional*/ ArrayAttr inBoundsAttr) {
3996 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3997 Value padding = builder.create<arith::ConstantOp>(
3998 result.location, elemType, builder.getZeroAttr(elemType));
3999 build(builder, result, vectorType, source, indices, permutationMapAttr,
4000 padding, /*mask=*/Value(), inBoundsAttr);
4003 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
4004 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4005 VectorType vectorType, Value source,
4006 ValueRange indices, AffineMap permutationMap,
4007 std::optional<ArrayRef<bool>> inBounds) {
4008 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4009 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4010 ? builder.getBoolArrayAttr(inBounds.value())
4011 : builder.getBoolArrayAttr(
4012 SmallVector<bool>(vectorType.getRank(), false));
4013 build(builder, result, vectorType, source, indices, permutationMapAttr,
4014 inBoundsAttr);
4017 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4018 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4019 VectorType vectorType, Value source,
4020 ValueRange indices, Value padding,
4021 std::optional<ArrayRef<bool>> inBounds) {
4022 AffineMap permutationMap = getTransferMinorIdentityMap(
4023 llvm::cast<ShapedType>(source.getType()), vectorType);
4024 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4025 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4026 ? builder.getBoolArrayAttr(inBounds.value())
4027 : builder.getBoolArrayAttr(
4028 SmallVector<bool>(vectorType.getRank(), false));
4029 build(builder, result, vectorType, source, indices, permutationMapAttr,
4030 padding,
4031 /*mask=*/Value(), inBoundsAttr);
4034 /// 4. Builder that sets padding to zero and permutation map to
4035 /// 'getMinorIdentityMap'.
4036 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4037 VectorType vectorType, Value source,
4038 ValueRange indices,
4039 std::optional<ArrayRef<bool>> inBounds) {
4040 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4041 Value padding = builder.create<arith::ConstantOp>(
4042 result.location, elemType, builder.getZeroAttr(elemType));
4043 build(builder, result, vectorType, source, indices, padding, inBounds);
4046 template <typename EmitFun>
4047 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4048 EmitFun emitOpError) {
4049 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4050 for (auto expr : permutationMap.getResults()) {
4051 auto dim = dyn_cast<AffineDimExpr>(expr);
4052 auto zero = dyn_cast<AffineConstantExpr>(expr);
4053 if (zero) {
4054 if (zero.getValue() != 0) {
4055 return emitOpError(
4056 "requires a projected permutation_map (at most one dim or the zero "
4057 "constant can appear in each result)");
4059 continue;
4061 if (!dim) {
4062 return emitOpError("requires a projected permutation_map (at most one "
4063 "dim or the zero constant can appear in each result)");
4065 if (seen[dim.getPosition()]) {
4066 return emitOpError(
4067 "requires a permutation_map that is a permutation (found one dim "
4068 "used more than once)");
4070 seen[dim.getPosition()] = true;
4072 return success();
4075 static LogicalResult
4076 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4077 VectorType vectorType, VectorType maskType,
4078 VectorType inferredMaskType, AffineMap permutationMap,
4079 ArrayAttr inBounds) {
4080 if (op->hasAttr("masked")) {
4081 return op->emitOpError("masked attribute has been removed. "
4082 "Use in_bounds instead.");
4085 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4086 return op->emitOpError(
4087 "requires source to be a memref or ranked tensor type");
4089 auto elementType = shapedType.getElementType();
4090 DataLayout dataLayout = DataLayout::closest(op);
4091 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4092 // Memref or tensor has vector element type.
4093 unsigned sourceVecSize =
4094 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4095 vectorElementType.getShape().back();
4096 unsigned resultVecSize =
4097 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4098 vectorType.getShape().back();
4099 if (resultVecSize % sourceVecSize != 0)
4100 return op->emitOpError(
4101 "requires the bitwidth of the minor 1-D vector to be an integral "
4102 "multiple of the bitwidth of the minor 1-D vector of the source");
4104 unsigned sourceVecEltRank = vectorElementType.getRank();
4105 unsigned resultVecRank = vectorType.getRank();
4106 if (sourceVecEltRank > resultVecRank)
4107 return op->emitOpError(
4108 "requires source vector element and vector result ranks to match.");
4109 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4110 // Check that permutation map results match 'rankOffset' of vector type.
4111 if (permutationMap.getNumResults() != rankOffset)
4112 return op->emitOpError("requires a permutation_map with result dims of "
4113 "the same rank as the vector type");
4115 if (maskType)
4116 return op->emitOpError("does not support masks with vector element type");
4117 } else {
4118 // Memref or tensor has scalar element type.
4119 unsigned minorSize =
4120 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4121 unsigned resultVecSize =
4122 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4123 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4124 return op->emitOpError(
4125 "requires the bitwidth of the minor 1-D vector to be an integral "
4126 "multiple of the bitwidth of the source element type");
4128 // Check that permutation map results match rank of vector type.
4129 if (permutationMap.getNumResults() != vectorType.getRank())
4130 return op->emitOpError("requires a permutation_map with result dims of "
4131 "the same rank as the vector type");
4134 if (permutationMap.getNumSymbols() != 0)
4135 return op->emitOpError("requires permutation_map without symbols");
4137 if (permutationMap.getNumInputs() != shapedType.getRank())
4138 return op->emitOpError("requires a permutation_map with input dims of the "
4139 "same rank as the source type");
4141 if (maskType && maskType != inferredMaskType)
4142 return op->emitOpError("inferred mask type (")
4143 << inferredMaskType << ") and mask operand type (" << maskType
4144 << ") don't match";
4146 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4147 return op->emitOpError("expects the in_bounds attr of same rank "
4148 "as permutation_map results: ")
4149 << AffineMapAttr::get(permutationMap)
4150 << " vs inBounds of size: " << inBounds.size();
4152 return success();
4155 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4156 SmallVector<StringRef, 3> elidedAttrs;
4157 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4158 if (op.getPermutationMap().isMinorIdentity())
4159 elidedAttrs.push_back(op.getPermutationMapAttrName());
4160 // Elide in_bounds attribute if all dims are out-of-bounds.
4161 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4162 elidedAttrs.push_back(op.getInBoundsAttrName());
4163 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4166 void TransferReadOp::print(OpAsmPrinter &p) {
4167 p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
4168 if (getMask())
4169 p << ", " << getMask();
4170 printTransferAttrs(p, *this);
4171 p << " : " << getShapedType() << ", " << getVectorType();
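/// Infers the mask type for a transfer op from its vector type and permutation
/// map. For example (illustrative), the map (d0, d1, d2) -> (d1, d0) applied
/// to vector<2x4xi32> yields a mask type of vector<4x2xi1>.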
4174 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4175 AffineMap permMap) {
4176 auto i1Type = IntegerType::get(permMap.getContext(), 1);
4177 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4178 assert(invPermMap && "Inverted permutation map couldn't be computed");
4179 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4181 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4182 // 0-D mask into a single-element 1-D mask.
4183 if (maskShape.empty())
4184 maskShape.push_back(1);
4186 SmallVector<bool> scalableDims =
4187 applyPermutationMap(invPermMap, vecType.getScalableDims());
4189 return VectorType::get(maskShape, i1Type, scalableDims);
4192 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4193 auto &builder = parser.getBuilder();
4194 SMLoc typesLoc;
4195 OpAsmParser::UnresolvedOperand sourceInfo;
4196 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4197 OpAsmParser::UnresolvedOperand paddingInfo;
4198 SmallVector<Type, 2> types;
4199 OpAsmParser::UnresolvedOperand maskInfo;
4200 // Parsing with support for paddingValue.
4201 if (parser.parseOperand(sourceInfo) ||
4202 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4203 parser.parseComma() || parser.parseOperand(paddingInfo))
4204 return failure();
4205 ParseResult hasMask = parser.parseOptionalComma();
4206 if (hasMask.succeeded()) {
4207 if (parser.parseOperand(maskInfo))
4208 return failure();
4210 if (parser.parseOptionalAttrDict(result.attributes) ||
4211 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4212 return failure();
4213 if (types.size() != 2)
4214 return parser.emitError(typesLoc, "requires two types");
4215 auto indexType = builder.getIndexType();
4216 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4217 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4218 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4219 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4220 if (!vectorType)
4221 return parser.emitError(typesLoc, "requires vector type");
4222 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4223 Attribute permMapAttr = result.attributes.get(permMapAttrName);
4224 AffineMap permMap;
4225 if (!permMapAttr) {
4226 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4227 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4228 } else {
4229 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4231 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4232 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4233 if (!inBoundsAttr) {
4234 result.addAttribute(inBoundsAttrName,
4235 builder.getBoolArrayAttr(
4236 SmallVector<bool>(permMap.getNumResults(), false)));
4238 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4239 parser.resolveOperands(indexInfo, indexType, result.operands) ||
4240 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4241 result.operands))
4242 return failure();
4243 if (hasMask.succeeded()) {
4244 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4245 return parser.emitError(
4246 maskInfo.location, "does not support masks with vector element type");
4247 if (vectorType.getRank() != permMap.getNumResults()) {
4248 return parser.emitError(typesLoc,
4249 "expected the same rank for the vector and the "
4250 "results of the permutation map");
4252 // Instead of adding the mask type as an op type, compute it based on the
4253 // vector type and the permutation map (to keep the type signature small).
4254 auto maskType = inferTransferOpMaskType(vectorType, permMap);
4255 if (parser.resolveOperand(maskInfo, maskType, result.operands))
4256 return failure();
4258 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4259 builder.getDenseI32ArrayAttr(
4260 {1, static_cast<int32_t>(indexInfo.size()), 1,
4261 static_cast<int32_t>(hasMask.succeeded())}));
4262 return parser.addTypeToList(vectorType, result.types);
4265 LogicalResult TransferReadOp::verify() {
4266 // Consistency of elemental types in source and vector.
4267 ShapedType shapedType = getShapedType();
4268 VectorType vectorType = getVectorType();
4269 VectorType maskType = getMaskType();
4270 auto paddingType = getPadding().getType();
4271 auto permutationMap = getPermutationMap();
4272 VectorType inferredMaskType =
4273 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4274 : VectorType();
4275 auto sourceElementType = shapedType.getElementType();
4277 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4278 return emitOpError("requires ") << shapedType.getRank() << " indices";
4280 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4281 shapedType, vectorType, maskType,
4282 inferredMaskType, permutationMap, getInBounds())))
4283 return failure();
4285 if (auto sourceVectorElementType =
4286 llvm::dyn_cast<VectorType>(sourceElementType)) {
4287 // Source has vector element type.
4288 // Check that 'sourceVectorElementType' and 'paddingType' types match.
4289 if (sourceVectorElementType != paddingType)
4290 return emitOpError(
4291 "requires source element type and padding type to match.");
4293 } else {
4294 // Check that 'paddingType' is valid to store in a vector type.
4295 if (!VectorType::isValidElementType(paddingType))
4296 return emitOpError("requires valid padding vector elemental type");
4298 // Check that padding type and vector element types match.
4299 if (paddingType != sourceElementType)
4300 return emitOpError(
4301 "requires formal padding and source of the same elemental type");
4304 return verifyPermutationMap(permutationMap,
4305 [&](Twine t) { return emitOpError(t); });
4308 // MaskableOpInterface methods.
4310 /// Returns the mask type expected by this operation. Mostly used for
4311 /// verification purposes. It requires the operation to be vectorized.
4312 Type TransferReadOp::getExpectedMaskType() {
4313 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4316 template <typename TransferOp>
4317 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4318 // TODO: support more aggressive createOrFold on:
4319 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4320 if (op.getShapedType().isDynamicDim(indicesIdx))
4321 return false;
4322 Value index = op.getIndices()[indicesIdx];
4323 std::optional<int64_t> cstOp = getConstantIntValue(index);
4324 if (!cstOp.has_value())
4325 return false;
4327 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4328 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4330 return cstOp.value() + vectorSize <= sourceSize;
4333 template <typename TransferOp>
4334 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4335 // TODO: support 0-d corner case.
4336 // TODO: Be less conservative.
4337 if (op.getTransferRank() == 0)
4338 return failure();
4339 AffineMap permutationMap = op.getPermutationMap();
4340 bool changed = false;
4341 SmallVector<bool, 4> newInBounds;
4342 newInBounds.reserve(op.getTransferRank());
4343 // Indices of non-broadcast dims, used when analyzing broadcast dims.
4344 SmallVector<unsigned> nonBcastDims;
4346 // 1. Process non-broadcast dims
4347 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4348 // 1.1. Already marked as in-bounds, nothing to see here.
4349 if (op.isDimInBounds(i)) {
4350 newInBounds.push_back(true);
4351 continue;
4353 // 1.2. Currently out-of-bounds, check whether we can statically determine
4354 // it is inBounds.
4355 bool inBounds = false;
4356 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4357 if (dimExpr) {
4358 inBounds = isInBounds(op, /*resultIdx=*/i,
4359 /*indicesIdx=*/dimExpr.getPosition());
4360 nonBcastDims.push_back(i);
4363 newInBounds.push_back(inBounds);
4364 // We commit the pattern if it is "more inbounds".
4365 changed |= inBounds;
4368 // 2. Handle broadcast dims
4369 // If all non-broadcast dims are "in bounds", then all bcast dims should be
4370 // "in bounds" as well.
4371 bool allNonBcastDimsInBounds = llvm::all_of(
4372 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4373 if (allNonBcastDimsInBounds) {
4374 for (size_t idx : permutationMap.getBroadcastDims()) {
4375 changed |= !newInBounds[idx];
4376 newInBounds[idx] = true;
4380 if (!changed)
4381 return failure();
4382 // OpBuilder is only used as a helper to build a BoolArrayAttr.
4383 OpBuilder b(op.getContext());
4384 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4385 return success();
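/// Drops the mask of a transfer op when it is statically all-true. E.g.
/// (an illustrative sketch; the operands and types are hypothetical):
/// ```
/// %m = vector.constant_mask [4] : vector<4xi1>
/// %r = vector.transfer_read %src[%c0], %pad, %m {in_bounds = [true]}
///     : memref<8xf32>, vector<4xf32>
/// ```
/// folds to
/// ```
/// %r = vector.transfer_read %src[%c0], %pad {in_bounds = [true]}
///     : memref<8xf32>, vector<4xf32>
/// ```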
4388 template <typename TransferOp>
4389 static LogicalResult foldTransferFullMask(TransferOp op) {
4390 auto mask = op.getMask();
4391 if (!mask)
4392 return failure();
4394 if (getMaskFormat(mask) != MaskFormat::AllTrue)
4395 return failure();
4397 op.getMaskMutable().clear();
4398 return success();
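/// Fold a read-after-write (RAW) sequence: a transfer_read of a tensor value
/// produced by a matching transfer_write folds to the written vector.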
4401 /// ```
4402 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4403 /// : vector<1x4xf32>, tensor<4x4xf32>
4404 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4405 /// : tensor<4x4xf32>, vector<1x4xf32>
4406 /// ```
4407 /// -> Folds into
4408 /// ```
4409 /// %v0
4410 /// ```
4411 static Value foldRAW(TransferReadOp readOp) {
4412 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4413 return {};
4414 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4415 while (defWrite) {
4416 if (checkSameValueRAW(defWrite, readOp))
4417 return defWrite.getVector();
4418 if (!isDisjointTransferIndices(
4419 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4420 cast<VectorTransferOpInterface>(readOp.getOperation())))
4421 break;
4422 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4424 return {};
4427 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4428 if (Value vec = foldRAW(*this))
4429 return vec;
4430 if (succeeded(foldTransferInBoundsAttribute(*this)))
4431 return getResult();
4432 if (succeeded(foldTransferFullMask(*this)))
4433 return getResult();
4434 // transfer_read(memref.cast) -> transfer_read
4435 if (succeeded(memref::foldMemRefCast(*this)))
4436 return getResult();
4437 if (succeeded(tensor::foldTensorCast(*this)))
4438 return getResult();
4439 return OpFoldResult();
4442 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4443 return llvm::to_vector<4>(getVectorType().getShape());
4446 void TransferReadOp::getEffects(
4447 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4448 &effects) {
4449 if (llvm::isa<MemRefType>(getShapedType()))
4450 effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4451 SideEffects::DefaultResource::get());
4454 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4455 if (hasPureTensorSemantics())
4456 return Speculation::Speculatable;
4457 return Speculation::NotSpeculatable;
4460 namespace {
4461 /// Store-to-load forwarding for transfer operations with permutation maps.
4462 /// Even if the permutation maps are different, we can still propagate the store
4463 /// into the load if the size of the dimensions read and written match. Then we
4464 /// can replace the transfer_read + transfer_write by vector.broadcast and
4465 /// vector.transpose.
4466 /// Example:
4467 /// ```
4468 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4469 /// {in_bounds = [true, true],
4470 /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4471 /// vector<4x1xf32>, tensor<4x4x4xf32>
4472 /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4473 /// {in_bounds = [true, true, true, true],
4474 /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4475 /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4476 /// ```
4477 /// To:
4478 /// ```
4479 /// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
4480 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4481 /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4482 /// ```
4483 struct TransferReadAfterWriteToBroadcast
4484 : public OpRewritePattern<TransferReadOp> {
4485 using OpRewritePattern::OpRewritePattern;
4487 LogicalResult matchAndRewrite(TransferReadOp readOp,
4488 PatternRewriter &rewriter) const override {
4489 if (readOp.hasOutOfBoundsDim() ||
4490 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4491 return failure();
4492 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4493 if (!defWrite)
4494 return failure();
4495 // TODO: If the written transfer chunk is a superset of the read transfer
4496 // chunk we could do an extract_strided_slice.
4497 if (readOp.getTransferChunkAccessed() !=
4498 defWrite.getTransferChunkAccessed())
4499 return failure();
4500 // TODO: Support cases where a dim is explicitly written but implicitly
4501 // read (i.e., a unit dim that is rank reduced).
4502 if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4503 getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4504 return failure();
4505 if (readOp.getIndices() != defWrite.getIndices() ||
4506 readOp.getMask() != defWrite.getMask())
4507 return failure();
4508 Value vec = defWrite.getVector();
4509 // TODO: loop through the chain of transfer_write if we can prove that they
4510 // don't overlap with the transfer_read. This requires improving
4511 // `isDisjointTransferIndices` helper.
4512 AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4513 AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4514 AffineMap map = readMap.compose(writeMap);
4515 if (map.getNumResults() == 0)
4516 return failure();
4517 // Calculate the permutation to apply to go from the vector stored to the
4518 // vector read.
4519 SmallVector<unsigned> permutation;
4520 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4521 return failure();
4523 Location loc = readOp.getLoc();
4524 // Calculate the broadcast shape by applying the reverse permutation to the
4525 // final shape we want.
4526 ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4527 SmallVector<int64_t> broadcastShape(destShape.size());
4528 SmallVector<bool> broadcastScalableFlags(destShape.size());
4529 for (const auto &pos : llvm::enumerate(permutation)) {
4530 broadcastShape[pos.value()] = destShape[pos.index()];
4531 broadcastScalableFlags[pos.value()] =
4532 readOp.getVectorType().getScalableDims()[pos.index()];
4534 VectorType broadcastedType = VectorType::get(
4535 broadcastShape, defWrite.getVectorType().getElementType(),
4536 broadcastScalableFlags);
4537 vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4538 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4539 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4540 transposePerm);
4541 return success();
4544 } // namespace
4546 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4547 MLIRContext *context) {
4548 results.add<TransferReadAfterWriteToBroadcast>(context);
4551 //===----------------------------------------------------------------------===//
4552 // TransferWriteOp
4553 //===----------------------------------------------------------------------===//
4555 /// 1. Builder with type inference.
4556 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4557 Value vector, Value dest, ValueRange indices,
4558 AffineMapAttr permutationMapAttr,
4559 /*optional*/ Value mask,
4560 /*optional*/ ArrayAttr inBoundsAttr) {
4561 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4562 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4563 mask, inBoundsAttr);
4566 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4567 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4568 Value vector, Value dest, ValueRange indices,
4569 AffineMapAttr permutationMapAttr,
4570 /*optional*/ ArrayAttr inBoundsAttr) {
4571 build(builder, result, vector, dest, indices, permutationMapAttr,
4572 /*mask=*/Value(), inBoundsAttr);
4575 /// 3. Builder with type inference that sets an empty mask (variant without
4576 /// attrs)
4577 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4578 Value vector, Value dest, ValueRange indices,
4579 AffineMap permutationMap,
4580 std::optional<ArrayRef<bool>> inBounds) {
4581 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4582 auto inBoundsAttr =
4583 (inBounds && !inBounds.value().empty())
4584 ? builder.getBoolArrayAttr(inBounds.value())
4585 : builder.getBoolArrayAttr(SmallVector<bool>(
4586 llvm::cast<VectorType>(vector.getType()).getRank(), false));
4587 build(builder, result, vector, dest, indices, permutationMapAttr,
4588 /*mask=*/Value(), inBoundsAttr);
4591 /// 4. Builder with type inference that sets an empty mask and sets permutation
4592 /// map to 'getMinorIdentityMap'.
4593 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4594 Value vector, Value dest, ValueRange indices,
4595 std::optional<ArrayRef<bool>> inBounds) {
4596 auto vectorType = llvm::cast<VectorType>(vector.getType());
4597 AffineMap permutationMap = getTransferMinorIdentityMap(
4598 llvm::cast<ShapedType>(dest.getType()), vectorType);
4599 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4602 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4603 OperationState &result) {
4604 auto &builder = parser.getBuilder();
4605 SMLoc typesLoc;
4606 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4607 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4608 SmallVector<Type, 2> types;
4609 OpAsmParser::UnresolvedOperand maskInfo;
4610 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4611 parser.parseOperand(sourceInfo) ||
4612 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4613 return failure();
4614 ParseResult hasMask = parser.parseOptionalComma();
4615 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4616 return failure();
4617 if (parser.parseOptionalAttrDict(result.attributes) ||
4618 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4619 return failure();
4620 if (types.size() != 2)
4621 return parser.emitError(typesLoc, "requires two types");
4622 auto indexType = builder.getIndexType();
4623 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4624 if (!vectorType)
4625 return parser.emitError(typesLoc, "requires vector type");
4626 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4627 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4628 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4629 auto permMapAttrName =
4630 TransferWriteOp::getPermutationMapAttrName(result.name);
4631 auto permMapAttr = result.attributes.get(permMapAttrName);
4632 AffineMap permMap;
4633 if (!permMapAttr) {
4634 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4635 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4636 } else {
4637 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4639 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4640 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4641 if (!inBoundsAttr) {
4642 result.addAttribute(inBoundsAttrName,
4643 builder.getBoolArrayAttr(
4644 SmallVector<bool>(permMap.getNumResults(), false)));
4646 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4647 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4648 parser.resolveOperands(indexInfo, indexType, result.operands))
4649 return failure();
4650 if (hasMask.succeeded()) {
4651 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4652 return parser.emitError(
4653 maskInfo.location, "does not support masks with vector element type");
4654 if (vectorType.getRank() != permMap.getNumResults()) {
4655 return parser.emitError(typesLoc,
4656 "expected the same rank for the vector and the "
4657 "results of the permutation map");
4659 auto maskType = inferTransferOpMaskType(vectorType, permMap);
4660 if (parser.resolveOperand(maskInfo, maskType, result.operands))
4661 return failure();
4663 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4664 builder.getDenseI32ArrayAttr(
4665 {1, 1, static_cast<int32_t>(indexInfo.size()),
4666 static_cast<int32_t>(hasMask.succeeded())}));
4667 return failure(llvm::isa<RankedTensorType>(shapedType) &&
4668 parser.addTypeToList(shapedType, result.types));
4671 void TransferWriteOp::print(OpAsmPrinter &p) {
4672 p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4673 if (getMask())
4674 p << ", " << getMask();
4675 printTransferAttrs(p, *this);
4676 p << " : " << getVectorType() << ", " << getShapedType();
4679 LogicalResult TransferWriteOp::verify() {
4680 // Consistency of elemental types in shape and vector.
4681 ShapedType shapedType = getShapedType();
4682 VectorType vectorType = getVectorType();
4683 VectorType maskType = getMaskType();
4684 auto permutationMap = getPermutationMap();
4685 VectorType inferredMaskType =
4686 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4687 : VectorType();
4689 if (llvm::size(getIndices()) != shapedType.getRank())
4690 return emitOpError("requires ") << shapedType.getRank() << " indices";
4692 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4693 // as the semantics is unclear. This can be revisited later if necessary.
4694 if (hasBroadcastDim())
4695 return emitOpError("should not have broadcast dimensions");
4697 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4698 shapedType, vectorType, maskType,
4699 inferredMaskType, permutationMap, getInBounds())))
4700 return failure();
4702 return verifyPermutationMap(permutationMap,
4703 [&](Twine t) { return emitOpError(t); });
4706 // MaskableOpInterface methods.
4708 /// Returns the mask type expected by this operation. Mostly used for
4709 /// verification purposes.
4710 Type TransferWriteOp::getExpectedMaskType() {
4711 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4714 /// Fold:
4715 /// ```
4716 /// %t1 = ...
4717 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4718 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4719 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4720 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4721 /// ```
4723 /// into:
4725 /// ```
4726 /// %t0
4727 /// ```
4729 /// The producer of %t1 may or may not be DCE'd, depending on whether it is a
4730 /// block argument or has side effects.
4731 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4732 ArrayRef<Attribute>,
4733 SmallVectorImpl<OpFoldResult> &results) {
4734 // TODO: support 0-d corner case.
4735 if (write.getTransferRank() == 0)
4736 return failure();
4737 auto rankedTensorType =
4738 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4739 // If not operating on tensors, bail.
4740 if (!rankedTensorType)
4741 return failure();
4742 // If no read, bail.
4743 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4744 if (!read)
4745 return failure();
4746 // TODO: support 0-d corner case.
4747 if (read.getTransferRank() == 0)
4748 return failure();
4749 // For now, only accept minor identity maps. Future: also accept maps whose composition is a minor identity.
4750 if (!read.getPermutationMap().isMinorIdentity() ||
4751 !write.getPermutationMap().isMinorIdentity())
4752 return failure();
4753 // Bail on mismatching ranks.
4754 if (read.getTransferRank() != write.getTransferRank())
4755 return failure();
4756 // Bail on potential out-of-bounds accesses.
4757 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4758 return failure();
4759 // Tensor types must be the same.
4760 if (read.getSource().getType() != rankedTensorType)
4761 return failure();
4762 // Vector types must be the same.
4763 if (read.getVectorType() != write.getVectorType())
4764 return failure();
4765 // Vector and Tensor shapes must match.
4766 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4767 return failure();
4768 // Bail if any read or write index is nonzero.
4769 auto isNotConstantZero = [](Value v) {
4770 auto cstOp = getConstantIntValue(v);
4771 return !cstOp.has_value() || cstOp.value() != 0;
4773 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4774 llvm::any_of(write.getIndices(), isNotConstantZero))
4775 return failure();
4776 // Success.
4777 results.push_back(read.getSource());
4778 return success();
4781 static bool checkSameValueWAR(vector::TransferReadOp read,
4782 vector::TransferWriteOp write) {
4783 return read.getSource() == write.getSource() &&
4784 read.getIndices() == write.getIndices() &&
4785 read.getPermutationMap() == write.getPermutationMap() &&
4786 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4787 !write.getMask();
4789 /// Fold a transfer_write after a matching transfer_read (WAR):
4790 /// ```
4791 /// %t0 = ...
4792 /// %v = vector.transfer_read %t0[%c0...] :
4793 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
4794 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
4795 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
4796 /// ```
4798 /// into:
4800 /// ```
4801 /// %t0
4802 /// ```
4803 static LogicalResult foldWAR(TransferWriteOp write,
4804 SmallVectorImpl<OpFoldResult> &results) {
4805 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4806 return failure();
4807 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4808 if (!read)
4809 return failure();
4811 if (!checkSameValueWAR(read, write))
4812 return failure();
4813 results.push_back(read.getSource());
4814 return success();
4817 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4818 SmallVectorImpl<OpFoldResult> &results) {
4819 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4820 return success();
4821 if (succeeded(foldWAR(*this, results)))
4822 return success();
4823 if (succeeded(foldTransferInBoundsAttribute(*this)))
4824 return success();
4825 if (succeeded(foldTransferFullMask(*this)))
4826 return success();
4827 return memref::foldMemRefCast(*this);
4830 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4831 return llvm::to_vector<4>(getVectorType().getShape());
4834 void TransferWriteOp::getEffects(
4835 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4836 &effects) {
4837 if (llvm::isa<MemRefType>(getShapedType()))
4838 effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4839 SideEffects::DefaultResource::get());
4842 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4843 if (hasPureTensorSemantics())
4844 return Speculation::Speculatable;
4845 return Speculation::NotSpeculatable;
4848 namespace {
4849 /// Remove a dead transfer write from the SSA chain so that it can be
4850 /// eliminated by DCE:
4851 /// ```
4852 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4853 /// : vector<1x4xf32>, tensor<4x4xf32>
4854 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4855 /// : vector<1x4xf32>, tensor<4x4xf32>
4856 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4857 /// : vector<1x4xf32>, tensor<4x4xf32>
4858 /// ```
4860 /// into:
4862 /// ```
4863 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4864 /// : vector<1x4xf32>, tensor<4x4xf32>
4865 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4866 /// : vector<1x4xf32>, tensor<4x4xf32>
4867 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4868 /// : vector<1x4xf32>, tensor<4x4xf32>
4869 /// ```
4871 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4872 /// any other uses.
4873 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4874 public:
4875 using OpRewritePattern::OpRewritePattern;
4876 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4877 PatternRewriter &rewriter) const override {
4878 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4879 return failure();
4880 vector::TransferWriteOp writeToModify = writeOp;
4882 auto defWrite =
4883 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4884 while (defWrite) {
4885 if (checkSameValueWAW(writeOp, defWrite)) {
4886 rewriter.modifyOpInPlace(writeToModify, [&]() {
4887 writeToModify.getSourceMutable().assign(defWrite.getSource());
4889 return success();
4891 if (!isDisjointTransferIndices(
4892 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4893 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4894 break;
4895 // If the previous write op doesn't have any other use, we can safely look
4896 // at the store before it to see if it can be removed.
4897 if (!defWrite->hasOneUse())
4898 break;
4899 writeToModify = defWrite;
4900 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4902 return failure();
4906 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4907 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4908 /// overwritten and inserted into another tensor. After this rewrite, the
4909 /// operations bufferize in-place since all of them work on the same slice.
4911 /// For example:
4912 /// ```mlir
4913 /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4914 /// : vector<8x16xf32>, tensor<8x16xf32>
4915 /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4916 /// : tensor<8x16xf32> to tensor<?x?xf32>
4917 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4918 /// : tensor<?x?xf32> into tensor<27x37xf32>
4919 /// ```
4920 /// folds to
4921 /// ```mlir
4922 /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4923 /// : tensor<27x37xf32> to tensor<?x?xf32>
4924 /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4925 /// : vector<8x16xf32>, tensor<?x?xf32>
4926 /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4927 /// : tensor<?x?xf32> into tensor<27x37xf32>
4928 /// ```
4929 struct SwapExtractSliceOfTransferWrite
4930 : public OpRewritePattern<tensor::InsertSliceOp> {
4931 public:
4932 using OpRewritePattern::OpRewritePattern;
4934 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4935 PatternRewriter &rewriter) const override {
4936 if (!insertOp.hasUnitStride())
4937 return failure();
4938 auto extractOp =
4939 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4940 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4941 return failure();
4942 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4943 if (!transferOp || !transferOp->hasOneUse())
4944 return failure();
4946 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4947 // rank-reducing.
4948 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4949 return rewriter.notifyMatchFailure(insertOp,
4950 "use-def chain is rank-reducing");
4953 // Fail if tensor::ExtractSliceOp has non-zero offset.
4954 if (!extractOp.hasZeroOffset()) {
4955 return rewriter.notifyMatchFailure(insertOp,
4956 "ExtractSliceOp has non-zero offset");
4959 // Fail if vector::TransferWriteOp has non-zero offset.
4960 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4961 return getConstantIntValue(value) == static_cast<int64_t>(0);
4962 })) {
4963 return rewriter.notifyMatchFailure(insertOp,
4964 "TranferWriteOp has non-zero offset");
4967 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4968 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4969 return rewriter.notifyMatchFailure(
4970 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4973 for (auto [insertSize, extractSize] :
4974 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4975 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4976 return rewriter.notifyMatchFailure(
4977 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4981 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4982 assert(transferOp.getVectorType().hasStaticShape() &&
4983 "expected vector to have a static shape");
4984 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4985 SmallVector<int64_t> resultShape = applyPermutationMap(
4986 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4987 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4988 return rewriter.notifyMatchFailure(
4989 insertOp, "TransferWriteOp may not write the full tensor.");
4992 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4993 // Set all in_bounds to false and let the folder infer them.
4994 SmallVector<bool> newInBounds(vectorShape.size(), false);
4995 auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4996 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4997 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4998 insertOp.getMixedStrides());
4999 auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5000 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5001 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5002 rewriter.getBoolArrayAttr(newInBounds));
5003 rewriter.modifyOpInPlace(insertOp, [&]() {
5004 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5006 return success();
5010 } // namespace
5012 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5013 MLIRContext *context) {
5014 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5017 //===----------------------------------------------------------------------===//
5018 // LoadOp
5019 //===----------------------------------------------------------------------===//
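/// Verifies that the memref layout is contiguous in its most-minor dimension.
/// An illustrative sketch: memref<4x8xf32, strided<[8, 1]>> is accepted, while
/// the column-major memref<4x8xf32, strided<[1, 4]>> is rejected.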
5021 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5022 VectorType vecTy,
5023 MemRefType memRefTy) {
5024 // If rank == 0 or size == 1, this is equivalent to a scalar load/store, so we
5025 // don't need any stride limitations.
5026 if (!vecTy.isScalable() &&
5027 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5028 return success();
5030 if (!memRefTy.isLastDimUnitStride())
5031 return op->emitOpError("most minor memref dim must have unit stride");
5032 return success();
5035 LogicalResult vector::LoadOp::verify() {
5036 VectorType resVecTy = getVectorType();
5037 MemRefType memRefTy = getMemRefType();
5039 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5040 return failure();
5042 // Checks for vector memrefs.
5043 Type memElemTy = memRefTy.getElementType();
5044 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5045 if (memVecTy != resVecTy)
5046 return emitOpError("base memref and result vector types should match");
5047 memElemTy = memVecTy.getElementType();
5050 if (resVecTy.getElementType() != memElemTy)
5051 return emitOpError("base and result element types should match");
5052 if (llvm::size(getIndices()) != memRefTy.getRank())
5053 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5054 return success();
5057 OpFoldResult LoadOp::fold(FoldAdaptor) {
5058 if (succeeded(memref::foldMemRefCast(*this)))
5059 return getResult();
5060 return OpFoldResult();
5063 //===----------------------------------------------------------------------===//
5064 // StoreOp
5065 //===----------------------------------------------------------------------===//
5067 LogicalResult vector::StoreOp::verify() {
5068 VectorType valueVecTy = getVectorType();
5069 MemRefType memRefTy = getMemRefType();
5071 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5072 return failure();
5074 // Checks for vector memrefs.
5075 Type memElemTy = memRefTy.getElementType();
5076 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5077 if (memVecTy != valueVecTy)
5078 return emitOpError(
5079 "base memref and valueToStore vector types should match");
5080 memElemTy = memVecTy.getElementType();
5083 if (valueVecTy.getElementType() != memElemTy)
5084 return emitOpError("base and valueToStore element type should match");
5085 if (llvm::size(getIndices()) != memRefTy.getRank())
5086 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5087 return success();
5090 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5091 SmallVectorImpl<OpFoldResult> &results) {
5092 return memref::foldMemRefCast(*this);
5095 //===----------------------------------------------------------------------===//
5096 // MaskedLoadOp
5097 //===----------------------------------------------------------------------===//
5099 LogicalResult MaskedLoadOp::verify() {
5100 VectorType maskVType = getMaskVectorType();
5101 VectorType passVType = getPassThruVectorType();
5102 VectorType resVType = getVectorType();
5103 MemRefType memType = getMemRefType();
5105 if (resVType.getElementType() != memType.getElementType())
5106 return emitOpError("base and result element type should match");
5107 if (llvm::size(getIndices()) != memType.getRank())
5108 return emitOpError("requires ") << memType.getRank() << " indices";
5109 if (resVType.getShape() != maskVType.getShape())
5110 return emitOpError("expected result shape to match mask shape");
5111 if (resVType != passVType)
5112 return emitOpError("expected pass_thru of same type as result type");
5113 return success();
5116 namespace {
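/// Folds a masked load with a statically known mask: an all-true mask becomes
/// a plain vector.load and an all-false mask becomes the pass-thru value. E.g.
/// (an illustrative sketch; the operands are hypothetical):
/// ```
/// %r = vector.maskedload %base[%c0], %all_true, %pass
///     : memref<8xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
/// ```
/// folds to
/// ```
/// %r = vector.load %base[%c0] : memref<8xf32>, vector<4xf32>
/// ```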
5117 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5118 public:
5119 using OpRewritePattern::OpRewritePattern;
5120 LogicalResult matchAndRewrite(MaskedLoadOp load,
5121 PatternRewriter &rewriter) const override {
5122 switch (getMaskFormat(load.getMask())) {
5123 case MaskFormat::AllTrue:
5124 rewriter.replaceOpWithNewOp<vector::LoadOp>(
5125 load, load.getType(), load.getBase(), load.getIndices());
5126 return success();
5127 case MaskFormat::AllFalse:
5128 rewriter.replaceOp(load, load.getPassThru());
5129 return success();
5130 case MaskFormat::Unknown:
5131 return failure();
5133 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5136 } // namespace
5138 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5139 MLIRContext *context) {
5140 results.add<MaskedLoadFolder>(context);
5143 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5144 if (succeeded(memref::foldMemRefCast(*this)))
5145 return getResult();
5146 return OpFoldResult();
5149 //===----------------------------------------------------------------------===//
5150 // MaskedStoreOp
5151 //===----------------------------------------------------------------------===//
5153 LogicalResult MaskedStoreOp::verify() {
5154 VectorType maskVType = getMaskVectorType();
5155 VectorType valueVType = getVectorType();
5156 MemRefType memType = getMemRefType();
5158 if (valueVType.getElementType() != memType.getElementType())
5159 return emitOpError("base and valueToStore element type should match");
5160 if (llvm::size(getIndices()) != memType.getRank())
5161 return emitOpError("requires ") << memType.getRank() << " indices";
5162 if (valueVType.getShape() != maskVType.getShape())
5163 return emitOpError("expected valueToStore shape to match mask shape");
5164 return success();
5167 namespace {
5168 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5169 public:
5170 using OpRewritePattern::OpRewritePattern;
5171 LogicalResult matchAndRewrite(MaskedStoreOp store,
5172 PatternRewriter &rewriter) const override {
5173 switch (getMaskFormat(store.getMask())) {
5174 case MaskFormat::AllTrue:
5175 rewriter.replaceOpWithNewOp<vector::StoreOp>(
5176 store, store.getValueToStore(), store.getBase(), store.getIndices());
5177 return success();
5178 case MaskFormat::AllFalse:
5179 rewriter.eraseOp(store);
5180 return success();
5181 case MaskFormat::Unknown:
5182 return failure();
5184 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5187 } // namespace
5189 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5190 MLIRContext *context) {
5191 results.add<MaskedStoreFolder>(context);
5194 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5195 SmallVectorImpl<OpFoldResult> &results) {
5196 return memref::foldMemRefCast(*this);
5199 //===----------------------------------------------------------------------===//
5200 // GatherOp
5201 //===----------------------------------------------------------------------===//
5203 LogicalResult GatherOp::verify() {
5204 VectorType indVType = getIndexVectorType();
5205 VectorType maskVType = getMaskVectorType();
5206 VectorType resVType = getVectorType();
5207 ShapedType baseType = getBaseType();
5209 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5210 return emitOpError("requires base to be a memref or ranked tensor type");
5212 if (resVType.getElementType() != baseType.getElementType())
5213 return emitOpError("base and result element type should match");
5214 if (llvm::size(getIndices()) != baseType.getRank())
5215 return emitOpError("requires ") << baseType.getRank() << " indices";
5216 if (resVType.getShape() != indVType.getShape())
5217 return emitOpError("expected result dim to match indices dim");
5218 if (resVType.getShape() != maskVType.getShape())
5219 return emitOpError("expected result dim to match mask dim");
5220 if (resVType != getPassThruVectorType())
5221 return emitOpError("expected pass_thru of same type as result type");
5222 return success();
5225 // MaskableOpInterface methods.
5227 /// Returns the mask type expected by this operation. Mostly used for
5228 /// verification purposes. It requires the operation to be vectorized.
5229 Type GatherOp::getExpectedMaskType() {
5230 auto vecType = this->getIndexVectorType();
5231 return VectorType::get(vecType.getShape(),
5232 IntegerType::get(vecType.getContext(), /*width=*/1),
5233 vecType.getScalableDims());
5236 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5237 return llvm::to_vector<4>(getVectorType().getShape());
5240 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
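/// Matches, e.g., `vector.step : vector<4xindex>` as well as the constant
/// `arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>` (illustrative).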
5241 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5242 auto vecType = dyn_cast<VectorType>(indexVec.getType());
5243 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5244 return failure();
5246 if (indexVec.getDefiningOp<StepOp>())
5247 return success();
5249 DenseIntElementsAttr elements;
5250 if (!matchPattern(indexVec, m_Constant(&elements)))
5251 return failure();
5253 return success(
5254 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5257 namespace {
5258 class GatherFolder final : public OpRewritePattern<GatherOp> {
5259 public:
5260 using OpRewritePattern::OpRewritePattern;
5261 LogicalResult matchAndRewrite(GatherOp gather,
5262 PatternRewriter &rewriter) const override {
5263 switch (getMaskFormat(gather.getMask())) {
5264 case MaskFormat::AllTrue:
5265 return failure(); // no unmasked equivalent
5266 case MaskFormat::AllFalse:
5267 rewriter.replaceOp(gather, gather.getPassThru());
5268 return success();
5269 case MaskFormat::Unknown:
5270 return failure();
5272 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5276 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5277 /// maskedload. Only 1D fixed vectors are supported for now.
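/// E.g. (illustrative) a gather of vector<4xf32> whose index vector is the
/// constant dense<[0, 1, 2, 3]> reads exactly the elements a vector.maskedload
/// would read at the same base and indices, so it is rewritten to one, reusing
/// the original mask and pass-thru.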
5278 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5279 public:
5280 using OpRewritePattern::OpRewritePattern;
5281 LogicalResult matchAndRewrite(GatherOp op,
5282 PatternRewriter &rewriter) const override {
5283 if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5284 return failure();
5286 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5287 op.getIndices(), op.getMask(),
5288 op.getPassThru());
5289 return success();
5292 } // namespace
5294 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5295 MLIRContext *context) {
5296 results.add<GatherFolder, FoldContiguousGather>(context);
5299 //===----------------------------------------------------------------------===//
5300 // ScatterOp
5301 //===----------------------------------------------------------------------===//
5303 LogicalResult ScatterOp::verify() {
5304 VectorType indVType = getIndexVectorType();
5305 VectorType maskVType = getMaskVectorType();
5306 VectorType valueVType = getVectorType();
5307 MemRefType memType = getMemRefType();
5309 if (valueVType.getElementType() != memType.getElementType())
5310 return emitOpError("base and valueToStore element type should match");
5311 if (llvm::size(getIndices()) != memType.getRank())
5312 return emitOpError("requires ") << memType.getRank() << " indices";
5313 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5314 return emitOpError("expected valueToStore dim to match indices dim");
5315 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5316 return emitOpError("expected valueToStore dim to match mask dim");
5317 return success();
5320 namespace {
5321 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5322 public:
5323 using OpRewritePattern::OpRewritePattern;
5324 LogicalResult matchAndRewrite(ScatterOp scatter,
5325 PatternRewriter &rewriter) const override {
5326 switch (getMaskFormat(scatter.getMask())) {
5327 case MaskFormat::AllTrue:
5328 return failure(); // no unmasked equivalent
5329 case MaskFormat::AllFalse:
5330 rewriter.eraseOp(scatter);
5331 return success();
5332 case MaskFormat::Unknown:
5333 return failure();
5335 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5339 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5340 /// maskedstore. Only 1D fixed vectors are supported for now.
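/// (Mirrors FoldContiguousGather above: with consecutive offsets the scatter
/// stores to a contiguous range and is equivalent to a vector.maskedstore.)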
5341 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5342 public:
5343 using OpRewritePattern::OpRewritePattern;
5344 LogicalResult matchAndRewrite(ScatterOp op,
5345 PatternRewriter &rewriter) const override {
5346 if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5347 return failure();
5349 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5350 op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5351 return success();
5354 } // namespace
5356 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5357 MLIRContext *context) {
5358 results.add<ScatterFolder, FoldContiguousScatter>(context);
5361 //===----------------------------------------------------------------------===//
5362 // ExpandLoadOp
5363 //===----------------------------------------------------------------------===//
5365 LogicalResult ExpandLoadOp::verify() {
5366 VectorType maskVType = getMaskVectorType();
5367 VectorType passVType = getPassThruVectorType();
5368 VectorType resVType = getVectorType();
5369 MemRefType memType = getMemRefType();
5371 if (resVType.getElementType() != memType.getElementType())
5372 return emitOpError("base and result element type should match");
5373 if (llvm::size(getIndices()) != memType.getRank())
5374 return emitOpError("requires ") << memType.getRank() << " indices";
5375 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5376 return emitOpError("expected result dim to match mask dim");
5377 if (resVType != passVType)
5378 return emitOpError("expected pass_thru of same type as result type");
5379 return success();
5382 namespace {
5383 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5384 public:
5385 using OpRewritePattern::OpRewritePattern;
5386 LogicalResult matchAndRewrite(ExpandLoadOp expand,
5387 PatternRewriter &rewriter) const override {
5388 switch (getMaskFormat(expand.getMask())) {
5389 case MaskFormat::AllTrue:
5390 rewriter.replaceOpWithNewOp<vector::LoadOp>(
5391 expand, expand.getType(), expand.getBase(), expand.getIndices());
5392 return success();
5393 case MaskFormat::AllFalse:
5394 rewriter.replaceOp(expand, expand.getPassThru());
5395 return success();
5396 case MaskFormat::Unknown:
5397 return failure();
5399 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5402 } // namespace
5404 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5405 MLIRContext *context) {
5406 results.add<ExpandLoadFolder>(context);
5409 //===----------------------------------------------------------------------===//
5410 // CompressStoreOp
5411 //===----------------------------------------------------------------------===//
5413 LogicalResult CompressStoreOp::verify() {
5414 VectorType maskVType = getMaskVectorType();
5415 VectorType valueVType = getVectorType();
5416 MemRefType memType = getMemRefType();
5418 if (valueVType.getElementType() != memType.getElementType())
5419 return emitOpError("base and valueToStore element type should match");
5420 if (llvm::size(getIndices()) != memType.getRank())
5421 return emitOpError("requires ") << memType.getRank() << " indices";
5422 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5423 return emitOpError("expected valueToStore dim to match mask dim");
5424 return success();
5427 namespace {
5428 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5429 public:
5430 using OpRewritePattern::OpRewritePattern;
5431 LogicalResult matchAndRewrite(CompressStoreOp compress,
5432 PatternRewriter &rewriter) const override {
5433 switch (getMaskFormat(compress.getMask())) {
5434 case MaskFormat::AllTrue:
5435 rewriter.replaceOpWithNewOp<vector::StoreOp>(
5436 compress, compress.getValueToStore(), compress.getBase(),
5437 compress.getIndices());
5438 return success();
5439 case MaskFormat::AllFalse:
5440 rewriter.eraseOp(compress);
5441 return success();
5442 case MaskFormat::Unknown:
5443 return failure();
5445 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5448 } // namespace
5450 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5451 MLIRContext *context) {
5452 results.add<CompressStoreFolder>(context);
5455 //===----------------------------------------------------------------------===//
5456 // ShapeCastOp
5457 //===----------------------------------------------------------------------===//
5459 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5460 SetIntRangeFn setResultRanges) {
5461 setResultRanges(getResult(), argRanges.front());
5464 /// Returns true if each element of 'a' is equal to the product of a contiguous
5465 /// sequence of the elements of 'b'. Returns false otherwise.
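/// E.g. isValidShapeCast({4, 6}, {2, 2, 6}) is true (4 = 2 * 2 and 6 = 6),
/// while isValidShapeCast({4, 6}, {2, 12}) is false (no contiguous prefix of
/// {2, 12} multiplies to exactly 4).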
5466 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5467 unsigned rankA = a.size();
5468 unsigned rankB = b.size();
5469 assert(rankA < rankB);
5471 auto isOne = [](int64_t v) { return v == 1; };
5473 // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5474 // casted to a 0-d vector.
5475 if (rankA == 0 && llvm::all_of(b, isOne))
5476 return true;
5478 unsigned i = 0;
5479 unsigned j = 0;
5480 while (i < rankA && j < rankB) {
5481 int64_t dimA = a[i];
5482 int64_t dimB = 1;
5483 while (dimB < dimA && j < rankB)
5484 dimB *= b[j++];
5485 if (dimA != dimB)
5486 break;
5487 ++i;
5489 // Handle the case when trailing dimensions are of size 1.
5490 // Include them into the contiguous sequence.
5491 if (i < rankA && llvm::all_of(a.slice(i), isOne))
5492 i = rankA;
5493 if (j < rankB && llvm::all_of(b.slice(j), isOne))
5494 j = rankB;
5497 return i == rankA && j == rankB;
5500 static LogicalResult verifyVectorShapeCast(Operation *op,
5501 VectorType sourceVectorType,
5502 VectorType resultVectorType) {
5503 // Check that element type is the same.
5504 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5505 return op->emitOpError("source/result vectors must have same element type");
5506 auto sourceShape = sourceVectorType.getShape();
5507 auto resultShape = resultVectorType.getShape();
5509 // Check that product of source dim sizes matches product of result dim sizes.
5510 int64_t sourceDimProduct = std::accumulate(
5511 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5512 int64_t resultDimProduct = std::accumulate(
5513 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5514 if (sourceDimProduct != resultDimProduct)
5515 return op->emitOpError("source/result number of elements must match");
5517 // Check the rank-expanding/contracting cases.
5518 unsigned sourceRank = sourceVectorType.getRank();
5519 unsigned resultRank = resultVectorType.getRank();
5520 if (sourceRank < resultRank) {
5521 if (!isValidShapeCast(sourceShape, resultShape))
5522 return op->emitOpError("invalid shape cast");
5523 } else if (sourceRank > resultRank) {
5524 if (!isValidShapeCast(resultShape, sourceShape))
5525 return op->emitOpError("invalid shape cast");
5528 // Check that (non-)scalability is preserved
5529 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5530 int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5531 if (sourceNScalableDims != resultNScalableDims)
5532 return op->emitOpError("different number of scalable dims at source (")
5533 << sourceNScalableDims << ") and result (" << resultNScalableDims
5534 << ")";
5537 return success();
5540 LogicalResult ShapeCastOp::verify() {
5541 auto sourceVectorType =
5542 llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5543 auto resultVectorType =
5544 llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5546 // Check if source/result are of vector type.
5547 if (sourceVectorType && resultVectorType)
5548 return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5550 return success();
5553 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5554 // No-op shape cast.
5555 if (getSource().getType() == getResult().getType())
5556 return getSource();
5558 // Canceling shape casts.
5559 if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5560 if (getResult().getType() == otherOp.getSource().getType())
5561 return otherOp.getSource();
5563 // Only allows valid transitive folding.
5564 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5565 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5566 if (srcType.getRank() < resultType.getRank()) {
5567 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5568 return {};
5569 } else if (srcType.getRank() > resultType.getRank()) {
5570 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5571 return {};
5572 } else {
5573 return {};
5576 setOperand(otherOp.getSource());
5577 return getResult();
5580 // Cancelling broadcast and shape cast ops.
5581 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5582 if (bcastOp.getSourceType() == getType())
5583 return bcastOp.getSource();
5586 return {};
5589 namespace {
5590 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
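// E.g. (illustrative) a shape_cast of the splat dense<1.0> : vector<2x4xf32>
// to vector<8xf32> is replaced by the constant dense<1.0> : vector<8xf32>.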
5591 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5592 public:
5593 using OpRewritePattern::OpRewritePattern;
5595 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5596 PatternRewriter &rewriter) const override {
5597 auto constantOp =
5598 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5599 if (!constantOp)
5600 return failure();
5601 // Only handle splat for now.
5602 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5603 if (!dense)
5604 return failure();
5605 auto newAttr =
5606 DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5607 dense.getSplatValue<Attribute>());
5608 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5609 return success();
5613 /// Helper function that computes a new vector type based on the input vector
5614 /// type by removing its trailing unit dims:
5616 /// vector<4x1x1xi1> --> vector<4xi1>
5618 static VectorType trimTrailingOneDims(VectorType oldType) {
5619 ArrayRef<int64_t> oldShape = oldType.getShape();
5620 ArrayRef<int64_t> newShape = oldShape;
5622 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5623 ArrayRef<bool> newScalableDims = oldScalableDims;
5625 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5626 newShape = newShape.drop_back(1);
5627 newScalableDims = newScalableDims.drop_back(1);
5630 // Make sure we have at least 1 dimension.
5631 // TODO: Add support for 0-D vectors.
5632 if (newShape.empty()) {
5633 newShape = oldShape.take_back();
5634 newScalableDims = oldScalableDims.take_back();
5637 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5640 /// Folds a qualifying shape_cast(create_mask) into a new create_mask.
5642 /// Looks at `vector.shape_cast` ops that simply "drop" trailing unit
5643 /// dimensions. If the input vector comes from `vector.create_mask` (or
5644 /// `vector.constant_mask`) for which the corresponding mask input value is 1
5645 /// (e.g. `%c1` below), then it is safe to fold the shape_cast into the mask op.
5647 /// BEFORE:
5648 /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5649 /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5650 /// AFTER:
5651 /// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5652 class ShapeCastCreateMaskFolderTrailingOneDim final
5653 : public OpRewritePattern<ShapeCastOp> {
5654 public:
5655 using OpRewritePattern::OpRewritePattern;
5657 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5658 PatternRewriter &rewriter) const override {
5659 Value shapeOpSrc = shapeOp->getOperand(0);
5660 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5661 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5662 if (!createMaskOp && !constantMaskOp)
5663 return failure();
5665 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5666 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5668 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5669 if (newVecType != shapeOpResTy)
5670 return failure();
5672 auto numDimsToDrop =
5673 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5675 // No unit dims to drop
5676 if (!numDimsToDrop)
5677 return failure();
5679 if (createMaskOp) {
5680 auto maskOperands = createMaskOp.getOperands();
5681 auto numMaskOperands = maskOperands.size();
5683 // Check every mask dim size to see whether it can be dropped
5684 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5685 --i) {
5686 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5687 if (!constant || (constant.value() != 1))
5688 return failure();
5690 SmallVector<Value> newMaskOperands =
5691 maskOperands.drop_back(numDimsToDrop);
5693 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5694 newMaskOperands);
5695 return success();
5698 if (constantMaskOp) {
5699 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5700 auto numMaskOperands = maskDimSizes.size();
5702 // Check every mask dim size to see whether it can be dropped
5703 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5704 --i) {
5705 if (maskDimSizes[i] != 1)
5706 return failure();
5709 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5710 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5711 newMaskOperands);
5712 return success();
5715 return failure();
5719 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5720 /// This only applies when the shape of the broadcast source
5721 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5722 /// reshape is expressive enough to capture the result in a single op), or
5723 /// 2. has the same element count as the shape cast result.
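/// E.g. for case 1 (an illustrative sketch):
/// ```
/// %b = vector.broadcast %v : vector<4xf32> to vector<2x3x4xf32>
/// %s = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
/// ```
/// folds to
/// ```
/// %s = vector.broadcast %v : vector<4xf32> to vector<6x4xf32>
/// ```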
5724 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5725 public:
5726 using OpRewritePattern::OpRewritePattern;
5728 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5729 PatternRewriter &rewriter) const override {
5730 auto broadcastOp =
5731 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5732 if (!broadcastOp)
5733 return failure();
5735 ArrayRef<int64_t> broadcastSourceShape;
5736 if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5737 broadcastSourceShape = srcType.getShape();
5738 ArrayRef<int64_t> shapeCastTargetShape =
5739 shapeCastOp.getResultVectorType().getShape();
5741 // If `broadcastSourceShape` is a suffix of the result, we can just replace
5742 // with a broadcast to the final shape.
5743 if (broadcastSourceShape ==
5744 shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5745 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5746 shapeCastOp, shapeCastOp.getResultVectorType(),
5747 broadcastOp.getSource());
5748 return success();
5751 // Otherwise, if the final result has the same element count, we can replace
5752 // it with a shape cast.
5753 if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5754 if (srcType.getNumElements() ==
5755 shapeCastOp.getResultVectorType().getNumElements()) {
5756 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5757 shapeCastOp, shapeCastOp.getResultVectorType(),
5758 broadcastOp.getSource());
5759 return success();
5763 return failure();
5767 } // namespace
5769 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5770 MLIRContext *context) {
5771 results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5772 ShapeCastBroadcastFolder>(context);
5775 //===----------------------------------------------------------------------===//
5776 // VectorBitCastOp
5777 //===----------------------------------------------------------------------===//
5779 LogicalResult BitCastOp::verify() {
5780 auto sourceVectorType = getSourceVectorType();
5781 auto resultVectorType = getResultVectorType();
5783 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5784 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5785 return emitOpError("dimension size mismatch at: ") << i;
5788 DataLayout dataLayout = DataLayout::closest(*this);
5789 auto sourceElementBits =
5790 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5791 auto resultElementBits =
5792 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5794 if (sourceVectorType.getRank() == 0) {
5795 if (sourceElementBits != resultElementBits)
5796 return emitOpError("source/result bitwidth of the 0-D vector element "
5797 "types must be equal");
5798 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5799 resultElementBits * resultVectorType.getShape().back()) {
5800 return emitOpError(
5801 "source/result bitwidth of the minor 1-D vectors must be equal");
5804 return success();
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
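        // E.g. (illustrative): an f16 splat of 1.0 has bit pattern 0x3C00;
        // packing two copies yields the 32-bit pattern 0x3C003C00, which is
        // then reinterpreted as the f32 splat value below.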
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower width element.
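          // E.g. (illustrative): an i8 splat of 0x01 bitcast to i32 produces
          // the packed value 0x01010101.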
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}

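// An illustrative fold example (the types are hypothetical): the identity
// permutation below leaves the dimensions in their original order, so the
// transpose folds to its operand.
//
//   %t = vector.transpose %v, [0, 1] : vector<2x3xf32> to vector<2x3xf32>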
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto attr =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
    if (attr.isSplat())
      return attr.reshape(getResultVectorType());

  // Eliminate identity transpose ops. This happens when the dimensions of the
  // input vector remain in their original order after the transpose operation.
  ArrayRef<int64_t> perm = getPermutation();

  // Check if the permutation of the dimensions contains sequential values:
  // {0, 1, 2, ...}.
  for (int64_t i = 0, e = perm.size(); i < e; i++) {
    if (perm[i] != i)
      return {};
  }

  return getVector();
}

LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}

namespace {

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
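//
// An illustrative composition (the shapes are hypothetical):
//
//   %0 = vector.transpose %v, [1, 0, 2]
//            : vector<2x3x4xf32> to vector<3x2x4xf32>
//   %1 = vector.transpose %0, [2, 1, 0]
//            : vector<3x2x4xf32> to vector<4x2x3xf32>
//
// rewrites to a single transpose with the composed permutation [2, 0, 1]:
//
//   %1 = vector.transpose %v, [2, 0, 1]
//            : vector<2x3x4xf32> to vector<4x2x3xf32>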
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};

// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
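//
// An illustrative rewrite (the types are hypothetical):
//
//   %b = vector.broadcast %scalar : f32 to vector<2x3xf32>
//   %t = vector.transpose %b, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//
// becomes a direct broadcast to the transposed shape:
//
//   %t = vector.broadcast %scalar : f32 to vector<3x2xf32>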
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
      return success();
    }

    return failure();
  }
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
    return success();
  }
};

/// Folds transpose(create_mask) into a new transposed create_mask.
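///
/// A minimal sketch (the operand names are hypothetical): the mask bounds are
/// permuted the same way as the vector dimensions,
///
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
///
/// becomes:
///
///   %t = vector.create_mask %b, %a : vector<8x4xi1>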
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
              TransposeFolder, FoldTransposeSplat>(context);
}

//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

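// E.g. (illustrative): building with ConstantMaskKind::AllTrue on
// vector<2x3xi1> produces mask dim sizes [2, 3]; AllFalse produces [0, 0].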
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are zero (because the
  // mask region is a conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}

bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}

//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}

namespace {

/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
///
/// Ex 1:
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
/// Becomes:
///   vector.constant_mask [3, 2] : vector<4x3xi1>
///
/// Ex 2:
///   %c_neg_1 = arith.constant -1 : index
///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
/// becomes:
///   vector.constant_mask [0] : vector<[8]xi1>
///
/// Ex 3:
///   %c8 = arith.constant 8 : index
///   %c16 = arith.constant 16 : index
///   %0 = vector.vscale
///   %1 = arith.muli %0, %c16 : index
///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
/// becomes:
///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: Rank zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the `constantDims` (for the ConstantMaskOp).
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value.
        // If the mask dim is non-scalable this can be any value.
        // If the mask dim is scalable only zero (all-false) is supported.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale).
        // Must be all-true to fold to a ConstantMask.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of dim sizes is zero, set all dims to zero.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};

} // namespace

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}

//===----------------------------------------------------------------------===//
// MaskOp
//===----------------------------------------------------------------------===//

void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}

ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse all the operands.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  // Optional passthru operand.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse op region.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded())
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();

  return success();
}

void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}

void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // Keep the default yield terminator if the number of masked operations is
  // not the expected one. This case will trigger a verification failure.
  Block &block = region.front();
  if (block.getOperations().size() != 2)
    return;

  // Replace the default yield terminator with a new one that returns the
  // results from the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  Operation *oldYieldOp = &block.back();
  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");

  // Empty vector.mask op.
  if (maskedOp == oldYieldOp)
    return;

  opBuilder.setInsertionPoint(oldYieldOp);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
  oldYieldOp->dropAllReferences();
  oldYieldOp->erase();
}

LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}

/// Folds vector.mask ops with an all-true mask.
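///
/// A minimal sketch (the masked op is just an example):
///
///   %all_true = vector.constant_mask [4] : vector<4xi1>
///   %r = vector.mask %all_true {
///     vector.reduction <add>, %v : vector<4xf32> into f32
///   } : vector<4xi1> -> f32
///
/// folds to an unmasked `vector.reduction` hoisted out of the region.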
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (isEmpty())
    return failure();

  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move maskable operation outside of the `vector.mask` region.
  Operation *maskableOp = getMaskableOp();
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}

// Elides empty vector.mask operations with or without return values.
// Propagates the values yielded by the vector.yield terminator, if any, or
// erases the op otherwise.
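//
// E.g. (illustrative): a mask region containing only its terminator,
//
//   %0 = vector.mask %m { vector.yield %a : vector<4xf32> } :
//          vector<4xi1> -> vector<4xf32>
//
// is elided by forwarding %a to the uses of %0.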
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
    if (maskingOp.getMaskableOp())
      return failure();

    if (!maskOp.isEmpty())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    if (terminator.getNumOperands() == 0)
      rewriter.eraseOp(maskOp);
    else
      rewriter.replaceOp(maskOp, terminator.getOperands());

    return success();
  }
};

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ElideEmptyMaskOp>(context);
}

// MaskingOpInterface definitions.

/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

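// E.g. (illustrative): a scan along dim 1 of a vector<4x8xf32> source takes
// an initial value of type vector<4xf32>, i.e. the source shape with the
// reduction dimension dropped.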
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

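// E.g. (illustrative): a `vector.splat` of a constant 1.0 : f32 folds to a
// dense splat attribute of the result vector type.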
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

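/// Ex (illustrative): with CombiningKind::ADD and float operands, this emits
/// `arith.addf %v1, %acc` (with the given fastmath flags); with integer
/// operands it emits `arith.addi`. A non-null `mask` wraps the result in an
/// `arith.select %mask, %result, %acc` via selectPassthru.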
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//

/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Create a block and move the op to that block.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns
/// the maskable operation itself.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return builder.create<MaskOp>(maskableOp->getLoc(),
                                  maskableOp->getResultTypes(), mask, passthru,
                                  maskableOp, createMaskOpRegion);
  return builder.create<MaskOp>(maskableOp->getLoc(),
                                maskableOp->getResultTypes(), mask, maskableOp,
                                createMaskOpRegion);
}

/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is
/// used to propagate the pass-thru value of vector.mask or for cases where
/// only the pass-thru value propagation is needed. VP intrinsics do not
/// support pass-thru values and every masked-out lane is set to poison. LLVM
/// backends are usually able to match op + select patterns and fold them into
/// native target instructions.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
                                         mask, newValue, passthru);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"