[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Dialect / Tensor / IR / TensorOps.cpp
blob616d4a7d0a0ab545562f95fe9afbf4ee5b7c34b6
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinTypeInterfaces.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/MathExtras.h"
33 #include <algorithm>
34 #include <optional>
36 using namespace mlir;
37 using namespace mlir::tensor;
39 using llvm::divideCeilSigned;
40 using llvm::divideFloorSigned;
41 using llvm::mod;
43 /// Materialize a single constant operation from a given attribute value with
44 /// the desired resultant type.
45 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
46 Attribute value, Type type,
47 Location loc) {
48 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
49 return op;
50 if (complex::ConstantOp::isBuildableWith(value, type))
51 return builder.create<complex::ConstantOp>(loc, type,
52 llvm::cast<ArrayAttr>(value));
53 return nullptr;
56 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
57 int64_t dim) {
58 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
59 SmallVector<OpFoldResult> result;
60 if (tensorType.isDynamicDim(dim))
61 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
63 return builder.getIndexAttr(tensorType.getDimSize(dim));
66 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
67 Location loc, Value value) {
68 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
69 SmallVector<OpFoldResult> result;
70 for (int64_t i = 0; i < tensorType.getRank(); ++i)
71 result.push_back(getMixedSize(builder, loc, value, i));
72 return result;
75 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
76 OpResult opResult) {
77 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
78 assert(tensorType && "expected tensor type");
80 // If the op has a destination, it implements DestinationStyleOpInterface and
81 // we can query the destination operand from that interface.
82 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
83 if (destOp)
84 return destOp.getTiedOpOperand(opResult)->get();
86 // Otherwise, create a new destination tensor with the same shape.
87 OpBuilder::InsertionGuard g(b);
88 b.setInsertionPoint(opResult.getDefiningOp());
90 // Compute sizes.
91 SmallVector<OpFoldResult> mixedSizes;
92 if (!tensorType.hasStaticShape()) {
93 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
94 ReifiedRankedShapedTypeDims reifiedShapes;
95 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
96 return failure();
97 mixedSizes = reifiedShapes[opResult.getResultNumber()];
98 } else {
99 // Static shape: Take static sizes directly.
100 for (int64_t sz : tensorType.getShape())
101 mixedSizes.push_back(b.getIndexAttr(sz));
104 // Create empty tensor.
105 Value emptyTensor =
106 b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
107 return emptyTensor;
110 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
111 Operation *op,
112 SmallVector<Value> &result) {
113 for (OpResult opResult : op->getResults()) {
114 if (llvm::isa<TensorType>(opResult.getType())) {
115 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
116 if (failed(destination))
117 return failure();
118 result.push_back(*destination);
121 return success();
124 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
125 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
126 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
127 return rtp1.getShape() == rtp2.getShape() &&
128 rtp1.getElementType() == rtp2.getElementType();
129 return false;
131 return tp1 == tp2; // default implementation
134 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135 /// rank-extending tensor.insert_slice op.
136 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
137 ArrayRef<OpFoldResult> mixedSizes) {
138 llvm::SmallBitVector droppedDims(mixedSizes.size());
139 int64_t shapePos = reducedShape.size() - 1;
141 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
142 size_t idx = mixedSizes.size() - size.index() - 1;
143 // Rank-reduced dims must have a static unit dimension.
144 bool isStaticUnitSize =
145 size.value().is<Attribute>() &&
146 llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
148 if (shapePos < 0) {
149 // There are no more dims in the reduced shape. All remaining sizes must
150 // be rank-reduced dims.
151 assert(isStaticUnitSize && "expected unit dim");
152 droppedDims.set(idx);
153 continue;
156 // Dim is preserved if the size is not a static 1.
157 if (!isStaticUnitSize) {
158 --shapePos;
159 continue;
162 // Dim is preserved if the reduced shape dim is also 1.
163 if (reducedShape[shapePos] == 1) {
164 --shapePos;
165 continue;
168 // Otherwise: Dim is dropped.
169 droppedDims.set(idx);
172 assert(shapePos < 0 && "dimension mismatch");
173 return droppedDims;
176 /// Given a ranked tensor type and a range of values that defines its dynamic
177 /// dimension sizes, turn all dynamic sizes that have a constant value into
178 /// static dimension sizes.
179 static RankedTensorType
180 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
181 SmallVector<Value> &foldedDynamicSizes) {
182 SmallVector<int64_t> staticShape(type.getShape());
183 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
184 "incorrect number of dynamic sizes");
186 // Compute new static and dynamic sizes.
187 unsigned ctr = 0;
188 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189 if (type.isDynamicDim(i)) {
190 Value dynamicSize = dynamicSizes[ctr++];
191 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192 if (cst.has_value()) {
193 // Dynamic size must be non-negative.
194 if (cst.value() < 0) {
195 foldedDynamicSizes.push_back(dynamicSize);
196 continue;
198 staticShape[i] = *cst;
199 } else {
200 foldedDynamicSizes.push_back(dynamicSize);
205 return RankedTensorType::get(staticShape, type.getElementType(),
206 type.getEncoding());
209 //===----------------------------------------------------------------------===//
210 // BitcastOp
211 //===----------------------------------------------------------------------===//
213 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214 if (inputs.size() != 1 || outputs.size() != 1)
215 return false;
216 Type a = inputs.front(), b = outputs.front();
217 auto aT = dyn_cast<TensorType>(a);
218 auto bT = dyn_cast<TensorType>(b);
219 if (!aT || !bT)
220 return false;
222 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223 return false;
225 return succeeded(verifyCompatibleShape(aT, bT));
228 namespace {
230 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231 /// operation.
232 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233 using OpRewritePattern<BitcastOp>::OpRewritePattern;
235 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236 PatternRewriter &rewriter) const final {
237 auto tensorBitcastOperand =
238 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239 if (!tensorBitcastOperand)
240 return failure();
242 auto resultType = cast<TensorType>(tensorBitcast.getType());
243 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244 tensorBitcastOperand.getOperand());
245 return success();
249 } // namespace
251 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252 MLIRContext *context) {
253 results.add<ChainedTensorBitcast>(context);
256 //===----------------------------------------------------------------------===//
257 // CastOp
258 //===----------------------------------------------------------------------===//
260 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261 setNameFn(getResult(), "cast");
264 /// Returns true if `target` is a ranked tensor type that preserves static
265 /// information available in the `source` ranked tensor type.
266 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
267 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
270 // Requires RankedTensorType.
271 if (!sourceType || !targetType)
272 return false;
274 // Requires same elemental type.
275 if (sourceType.getElementType() != targetType.getElementType())
276 return false;
278 // Requires same rank.
279 if (sourceType.getRank() != targetType.getRank())
280 return false;
282 // Requires same encoding.
283 if (sourceType.getEncoding() != targetType.getEncoding())
284 return false;
286 // If cast is towards more static sizes along any dimension, don't fold.
287 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288 if (!ShapedType::isDynamic(std::get<0>(t)) &&
289 ShapedType::isDynamic(std::get<1>(t)))
290 return false;
293 return true;
296 /// Determines whether tensor::CastOp casts to a more dynamic version of the
297 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
298 /// implement canonicalization patterns for ops in different dialects that may
299 /// consume the results of tensor.cast operations. Such foldable tensor.cast
300 /// operations are typically inserted as `slice` ops and are canonicalized,
301 /// to preserve the type compatibility of their uses.
303 /// Returns true when all conditions are met:
304 /// 1. source and result are ranked tensors with same element type and rank.
305 /// 2. the tensor type has more static information than the result
307 /// Example:
308 /// ```mlir
309 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
311 /// ```
313 /// folds into:
315 /// ```mlir
316 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
317 /// ```
318 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
319 if (!castOp)
320 return false;
322 // Can fold if the source of cast has at least as much static information as
323 // its results.
324 return preservesStaticInformation(castOp.getType(),
325 castOp.getSource().getType());
328 /// Determines whether the tensor::CastOp casts to a more static version of the
329 /// source tensor. This is useful to fold into a producing op and implement
330 /// canonicaliation patterns with the `tensor.cast` op as the root, but producer
331 /// being from different dialects. Returns true when all conditions are met:
332 /// 1. source and result and ranked tensors with same element type and rank.
333 /// 2. the result type has more static information than the source.
335 /// Example:
336 /// ```mlir
337 /// %1 = producer ... : tensor<?x?xf32>
338 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
339 /// ```
341 /// can be canonicalized to :
343 /// ```mlir
344 /// %2 = producer ... : tensor<8x16xf32>
345 /// ```
346 /// Not all ops might be canonicalizable this way, but for those that can be,
347 /// this method provides a check that it is worth doing the canonicalization.
348 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
349 if (!castOp)
350 return false;
351 return preservesStaticInformation(castOp.getSource().getType(),
352 castOp.getType());
355 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
356 /// that can be folded.
357 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
358 bool folded = false;
359 for (OpOperand &operand : op->getOpOperands()) {
360 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
361 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
362 operand.set(castOp.getOperand());
363 folded = true;
366 return success(folded);
369 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
370 if (inputs.size() != 1 || outputs.size() != 1)
371 return false;
372 Type a = inputs.front(), b = outputs.front();
373 auto aT = llvm::dyn_cast<TensorType>(a);
374 auto bT = llvm::dyn_cast<TensorType>(b);
375 if (!aT || !bT)
376 return false;
378 if (aT.getElementType() != bT.getElementType())
379 return false;
381 return succeeded(verifyCompatibleShape(aT, bT));
384 /// Compute a TensorType that has the joined shape knowledge of the two
385 /// given TensorTypes. The element types need to match.
386 static TensorType joinShapes(TensorType one, TensorType two) {
387 assert(one.getElementType() == two.getElementType());
389 if (!one.hasRank())
390 return two;
391 if (!two.hasRank())
392 return one;
394 int64_t rank = one.getRank();
395 if (rank != two.getRank())
396 return {};
398 SmallVector<int64_t, 4> join;
399 join.reserve(rank);
400 for (int64_t i = 0; i < rank; ++i) {
401 if (one.isDynamicDim(i)) {
402 join.push_back(two.getDimSize(i));
403 continue;
405 if (two.isDynamicDim(i)) {
406 join.push_back(one.getDimSize(i));
407 continue;
409 if (one.getDimSize(i) != two.getDimSize(i))
410 return {};
411 join.push_back(one.getDimSize(i));
413 return RankedTensorType::get(join, one.getElementType());
416 namespace {
418 /// Replaces chains of two tensor.cast operations by a single tensor.cast
419 /// operation if doing so does not remove runtime constraints.
420 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
421 using OpRewritePattern<CastOp>::OpRewritePattern;
423 LogicalResult matchAndRewrite(CastOp tensorCast,
424 PatternRewriter &rewriter) const final {
425 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
427 if (!tensorCastOperand)
428 return failure();
430 auto sourceType =
431 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
432 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
433 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
435 // We can remove the intermediate cast if joining all three produces the
436 // same result as just joining the source and result shapes.
437 auto firstJoin =
438 joinShapes(joinShapes(sourceType, intermediateType), resultType);
440 // The join might not exist if the cast sequence would fail at runtime.
441 if (!firstJoin)
442 return failure();
444 // The newJoin always exists if the above join exists, it might just contain
445 // less information. If so, we cannot drop the intermediate cast, as doing
446 // so would remove runtime checks.
447 auto newJoin = joinShapes(sourceType, resultType);
448 if (firstJoin != newJoin)
449 return failure();
451 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
452 tensorCastOperand.getOperand());
453 return success();
457 /// Fold tensor.cast into tesor.extract_slice producer.
458 /// Example:
459 /// ```
460 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
461 /// tensor<128x512xf32> to tensor<?x512xf32>
462 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
463 /// ```
464 /// ->
465 /// ```
466 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
467 /// tensor<128x512xf32> to tensor<16x512xf32>
468 /// ```
469 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
470 using OpRewritePattern<CastOp>::OpRewritePattern;
472 LogicalResult matchAndRewrite(CastOp tensorCast,
473 PatternRewriter &rewriter) const final {
474 auto extractOperand =
475 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
477 // Cannot fold cast to unranked tensor.
478 auto rankedResultType =
479 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
480 if (!rankedResultType)
481 return failure();
483 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
484 rankedResultType.getShape() ==
485 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
486 .getShape())
487 return failure();
489 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
490 auto dimMask = computeRankReductionMask(
491 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
492 size_t dimIndex = 0;
493 for (size_t i = 0, e = sizes.size(); i < e; i++) {
494 if (dimMask && dimMask->count(i))
495 continue;
496 int64_t dim = rankedResultType.getShape()[dimIndex++];
497 if (ShapedType::isDynamic(dim))
498 continue;
499 sizes[i] = rewriter.getIndexAttr(dim);
502 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
503 tensorCast, rankedResultType, extractOperand.getSource(),
504 extractOperand.getMixedOffsets(), sizes,
505 extractOperand.getMixedStrides());
506 return success();
510 } // namespace
512 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
513 MLIRContext *context) {
514 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
517 //===----------------------------------------------------------------------===//
518 // ConcatOp
519 //===----------------------------------------------------------------------===//
521 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
522 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
523 auto tensorTypes =
524 llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
525 return llvm::cast<RankedTensorType>(type);
526 }));
527 int64_t concatRank = tensorTypes[0].getRank();
529 // The concatenation dim must be in the range [0, rank).
530 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
532 SmallVector<int64_t> sizes(concatRank);
533 for (int64_t i = 0, e = concatRank; i < e; ++i) {
534 if (i == dim)
535 continue;
536 SaturatedInteger size;
537 for (auto tensorType : tensorTypes)
538 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
539 sizes[i] = size.asInteger();
541 auto concatSize = SaturatedInteger::wrap(0);
542 for (auto tensorType : tensorTypes)
543 concatSize =
544 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
545 sizes[dim] = concatSize.asInteger();
546 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
549 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
550 ValueRange inputs) {
551 FailureOr<RankedTensorType> resultType =
552 inferResultType(dim, inputs.getTypes());
553 assert(succeeded(resultType) && "failed to infer concatenation result type");
554 build(builder, result, *resultType, dim, inputs);
557 LogicalResult ConcatOp::verify() {
558 if (getInputs().size() < 1)
559 return emitOpError("requires at least one input");
561 SmallVector<RankedTensorType> inputTypes;
562 for (auto input : getInputs())
563 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
565 RankedTensorType resultType = getResultType();
566 int64_t resultRank = getRank();
567 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
568 return type.getRank() != resultRank;
570 return emitOpError("rank of concatenated inputs must match result rank");
572 Type resultElementType = resultType.getElementType();
573 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
574 return type.getElementType() != resultElementType;
576 return emitOpError("inputs and result element type must match");
578 int64_t dim = getDim();
579 if (dim >= resultRank)
580 return emitOpError("concatenation dim must be less than the tensor rank");
582 SmallVector<int64_t> sizes(resultRank);
583 for (int64_t i = 0, e = resultRank; i < e; ++i) {
584 if (i == dim)
585 continue;
586 SaturatedInteger size;
587 for (auto tensorType : inputTypes) {
588 FailureOr<SaturatedInteger> maybeSize =
589 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
590 if (failed(maybeSize))
591 return emitOpError("static concatenation size mismatch along ")
592 << "non-concatenated dimension " << i;
593 size = *maybeSize;
595 sizes[i] = size.asInteger();
597 auto concatSize = SaturatedInteger::wrap(0);
598 for (auto tensorType : inputTypes)
599 concatSize =
600 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
601 sizes[dim] = concatSize.asInteger();
602 auto inferredResultType =
603 RankedTensorType::get(sizes, inputTypes[0].getElementType());
605 for (auto [inferredSize, actualSize] :
606 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
607 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
608 ShapedType::isDynamic(actualSize);
609 if (!hasDynamic && inferredSize != actualSize)
610 return emitOpError("result type ")
611 << resultType << "does not match inferred shape "
612 << inferredResultType << " static sizes";
615 return success();
618 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
619 size_t numInputs = getInputs().size();
620 uint64_t concatDim = getDim();
622 SmallVector<SmallVector<OpFoldResult>> inputShapes;
623 inputShapes.reserve(numInputs);
624 SmallVector<OpFoldResult> concatOffsets;
625 concatOffsets.reserve(numInputs);
626 SmallVector<OpFoldResult> outputShape;
628 AffineExpr addExpr =
629 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
630 OpFoldResult zero = builder.getIndexAttr(0);
631 Location loc = getLoc();
632 for (auto [index, input] : llvm::enumerate(getInputs())) {
633 SmallVector<OpFoldResult> inputShape =
634 tensor::getMixedSizes(builder, input.getLoc(), input);
635 if (index == 0) {
636 outputShape = inputShape;
637 concatOffsets.push_back(zero);
638 } else {
639 concatOffsets.push_back(outputShape[concatDim]);
640 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
641 builder, loc, addExpr,
642 {outputShape[concatDim], inputShape[concatDim]});
644 inputShapes.emplace_back(std::move(inputShape));
647 Value replacement = builder.create<tensor::EmptyOp>(
648 loc, outputShape, getType().getElementType());
650 int64_t rank = getType().getRank();
651 OpFoldResult one = builder.getIndexAttr(1);
652 SmallVector<OpFoldResult> strides(rank, one);
653 SmallVector<OpFoldResult> offsets(rank, zero);
654 for (auto [index, input] : llvm::enumerate(getInputs())) {
655 offsets[concatDim] = concatOffsets[index];
656 auto insertSlice = builder.create<tensor::InsertSliceOp>(
657 loc, input, replacement, offsets, inputShapes[index], strides);
658 replacement = insertSlice.getResult();
660 if (replacement.getType() != getType()) {
661 replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
663 return SmallVector<Value>{replacement};
666 LogicalResult
667 ConcatOp::reifyResultShapes(OpBuilder &builder,
668 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
669 ValueRange inputs = getInputs();
670 int64_t dim = getDim();
671 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
673 Value init = inputs[0];
674 int64_t rank = getType().getRank();
676 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
678 // Pre-populate the result sizes with as much static information as possible
679 // from the given result type, as well as the inferred result type, otherwise
680 // use the dim sizes from the first input.
681 for (int64_t i = 0; i < rank; ++i) {
682 if (i == dim)
683 continue;
684 if (!getType().isDynamicDim(i)) {
685 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
686 } else if (!inferredResultType.isDynamicDim(i)) {
687 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
688 builder, getLoc(),
689 builder.getIndexAttr(inferredResultType.getDimSize(i)));
690 } else {
691 reifiedReturnShapes[0][i] =
692 builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
696 if (getType().isDynamicDim(dim)) {
697 // Take the sum of the input sizes along the concatenated dim.
698 AffineExpr sum = builder.getAffineDimExpr(0);
699 SmallVector<OpFoldResult> sizes = {
700 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
701 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
702 sum = sum + builder.getAffineDimExpr(idx + 1);
703 sizes.push_back(
704 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
706 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
707 builder, getLoc(),
708 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
709 } else {
710 // If the result shape is static along the concatenated dim, use the static
711 // shape.
712 reifiedReturnShapes[0][dim] =
713 builder.getIndexAttr(getType().getDimSize(dim));
715 return success();
718 void ConcatOp::getAsmResultNames(
719 function_ref<void(Value, StringRef)> setNameFn) {
720 setNameFn(getResult(), "concat");
723 OpFoldResult ConcatOp::fold(FoldAdaptor) {
724 ValueRange inputs = getInputs();
725 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
726 return inputs[0];
727 return {};
730 namespace {
731 /// Fold a concat op with a single input to a cast.
732 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
733 using OpRewritePattern<ConcatOp>::OpRewritePattern;
735 LogicalResult matchAndRewrite(ConcatOp concatOp,
736 PatternRewriter &rewriter) const override {
737 if (concatOp.getInputs().size() != 1)
738 return failure();
739 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
740 concatOp.getInputs()[0]);
741 return success();
744 } // namespace
746 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
747 MLIRContext *context) {
748 results.add<SingleInputConcatOp>(context);
751 //===----------------------------------------------------------------------===//
752 // DimOp
753 //===----------------------------------------------------------------------===//
755 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
756 setNameFn(getResult(), "dim");
759 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
760 int64_t index) {
761 auto loc = result.location;
762 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
763 build(builder, result, source, indexValue);
766 std::optional<int64_t> DimOp::getConstantIndex() {
767 return getConstantIntValue(getIndex());
770 Speculation::Speculatability DimOp::getSpeculatability() {
771 auto constantIndex = getConstantIndex();
772 if (!constantIndex)
773 return Speculation::NotSpeculatable;
775 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
776 if (!rankedSourceType)
777 return Speculation::NotSpeculatable;
779 if (rankedSourceType.getRank() <= constantIndex)
780 return Speculation::NotSpeculatable;
782 return Speculation::Speculatable;
785 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
786 // All forms of folding require a known index.
787 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
788 if (!index)
789 return {};
791 // Folding for unranked types (UnrankedTensorType) is not supported.
792 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
793 if (!tensorType)
794 return {};
796 // Out of bound indices produce undefined behavior but are still valid IR.
797 // Don't choke on them.
798 int64_t indexVal = index.getInt();
799 if (indexVal < 0 || indexVal >= tensorType.getRank())
800 return {};
802 // Fold if the shape extent along the given index is known.
803 if (!tensorType.isDynamicDim(index.getInt())) {
804 Builder builder(getContext());
805 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
808 Operation *definingOp = getSource().getDefiningOp();
810 // Fold dim to the operand of tensor.generate.
811 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
812 auto resultType =
813 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
814 // The case where the type encodes the size of the dimension is handled
815 // above.
816 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
818 // Find the operand of the fromElements that corresponds to this index.
819 auto dynExtents = fromElements.getDynamicExtents().begin();
820 for (auto dim : resultType.getShape().take_front(index.getInt()))
821 if (ShapedType::isDynamic(dim))
822 dynExtents++;
824 return Value{*dynExtents};
827 // The size at the given index is now known to be a dynamic size.
828 unsigned unsignedIndex = index.getValue().getZExtValue();
830 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
831 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
832 // `resolve-shaped-type-result-dims` pass.
833 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
834 sliceOp.isDynamicSize(unsignedIndex)) {
835 return {sliceOp.getDynamicSize(unsignedIndex)};
839 // dim(cast) -> dim
840 if (succeeded(foldTensorCast(*this)))
841 return getResult();
843 return {};
846 namespace {
847 /// Fold dim of a cast into the dim of the source of the tensor cast.
848 struct DimOfCastOp : public OpRewritePattern<DimOp> {
849 using OpRewritePattern<DimOp>::OpRewritePattern;
851 LogicalResult matchAndRewrite(DimOp dimOp,
852 PatternRewriter &rewriter) const override {
853 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
854 if (!castOp)
855 return failure();
856 Value newSource = castOp.getOperand();
857 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
858 return success();
862 /// Fold dim of a destination passing style op into the dim of the corresponding
863 /// init.
864 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
865 using OpRewritePattern<DimOp>::OpRewritePattern;
867 LogicalResult matchAndRewrite(DimOp dimOp,
868 PatternRewriter &rewriter) const override {
869 auto source = dimOp.getSource();
870 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
871 if (!destOp)
872 return failure();
874 auto resultIndex = cast<OpResult>(source).getResultNumber();
875 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
877 rewriter.modifyOpInPlace(
878 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
879 return success();
883 /// Fold dim of a tensor reshape operation to a extract into the reshape's shape
884 /// operand.
885 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
886 using OpRewritePattern<DimOp>::OpRewritePattern;
888 LogicalResult matchAndRewrite(DimOp dim,
889 PatternRewriter &rewriter) const override {
890 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
892 if (!reshape)
893 return failure();
895 // Since tensors are immutable we don't need to worry about where to place
896 // the extract call
897 rewriter.setInsertionPointAfter(dim);
898 Location loc = dim.getLoc();
899 Value extract =
900 rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
901 if (extract.getType() != dim.getType())
902 extract =
903 rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
904 rewriter.replaceOp(dim, extract);
905 return success();
908 } // namespace
910 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
911 MLIRContext *context) {
912 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
915 //===----------------------------------------------------------------------===//
916 // EmptyOp
917 //===----------------------------------------------------------------------===//
919 void EmptyOp::build(OpBuilder &builder, OperationState &result,
920 ArrayRef<int64_t> staticShape, Type elementType,
921 Attribute encoding) {
922 assert(all_of(staticShape,
923 [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
924 "expected only static sizes");
925 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
928 void EmptyOp::build(OpBuilder &builder, OperationState &result,
929 ArrayRef<int64_t> staticShape, Type elementType,
930 ValueRange dynamicSizes, Attribute encoding) {
931 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
932 build(builder, result, tensorType, dynamicSizes);
935 void EmptyOp::build(OpBuilder &builder, OperationState &result,
936 ArrayRef<OpFoldResult> sizes, Type elementType,
937 Attribute encoding) {
938 SmallVector<int64_t> staticShape;
939 SmallVector<Value> dynamicSizes;
940 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
941 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
944 LogicalResult EmptyOp::verify() {
945 if (getType().getNumDynamicDims() != getDynamicSizes().size())
946 return emitOpError("incorrect number of dynamic sizes, has ")
947 << getDynamicSizes().size() << ", expected "
948 << getType().getNumDynamicDims();
949 return success();
952 LogicalResult
953 EmptyOp::reifyResultShapes(OpBuilder &builder,
954 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
955 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
956 unsigned ctr = 0;
957 for (int64_t i = 0; i < getType().getRank(); ++i) {
958 if (getType().isDynamicDim(i)) {
959 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
960 } else {
961 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
964 return success();
967 Value EmptyOp::getDynamicSize(unsigned idx) {
968 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
969 unsigned ctr = 0;
970 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
971 if (getType().isDynamicDim(i))
972 ++ctr;
973 return getDynamicSizes()[ctr];
976 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
977 SmallVector<OpFoldResult> result;
978 unsigned ctr = 0;
979 OpBuilder b(getContext());
980 for (int64_t i = 0; i < getType().getRank(); ++i) {
981 if (getType().isDynamicDim(i)) {
982 result.push_back(getDynamicSizes()[ctr++]);
983 } else {
984 result.push_back(b.getIndexAttr(getType().getShape()[i]));
987 return result;
990 namespace {
991 /// Change the type of the result of a `tensor.empty` by making the result
992 /// type statically sized along dimensions that in the original operation were
993 /// defined as dynamic, but the size was defined using a `constant` op. For
994 /// example
996 /// %c5 = arith.constant 5: index
997 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
999 /// to
1001 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1002 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1003 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1005 LogicalResult matchAndRewrite(EmptyOp op,
1006 PatternRewriter &rewriter) const override {
1007 SmallVector<Value> foldedDynamicSizes;
1008 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1009 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1011 // Stop here if no dynamic size was promoted to static.
1012 if (foldedTensorType == op.getType())
1013 return failure();
1015 auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
1016 foldedDynamicSizes);
1017 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1018 return success();
1022 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1023 using OpRewritePattern<DimOp>::OpRewritePattern;
1025 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1026 PatternRewriter &rewriter) const override {
1027 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1028 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1029 if (!emptyTensorOp || !maybeConstantIndex)
1030 return failure();
1031 auto emptyTensorType = emptyTensorOp.getType();
1032 if (*maybeConstantIndex < 0 ||
1033 *maybeConstantIndex >= emptyTensorType.getRank() ||
1034 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1035 return failure();
1036 rewriter.replaceOp(dimOp,
1037 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1038 return success();
1042 /// Canonicalize
1044 /// ```mlir
1045 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1046 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1047 /// ```
1049 /// into
1051 /// ```mlir
1052 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1053 /// ```
1055 /// This assumes the input program is correct in terms of its shape. So it is
1056 /// safe to assume that `%d0` is in fact 4.
1057 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1058 using OpRewritePattern<CastOp>::OpRewritePattern;
1060 LogicalResult matchAndRewrite(CastOp castOp,
1061 PatternRewriter &rewriter) const override {
1062 if (!canFoldIntoProducerOp(castOp))
1063 return failure();
1064 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1065 if (!producer)
1066 return failure();
1068 auto resultType =
1069 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1070 ArrayRef<int64_t> resultShape = resultType.getShape();
1071 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1072 SmallVector<OpFoldResult> newMixedSizes;
1073 newMixedSizes.reserve(currMixedSizes.size());
1074 assert(resultShape.size() == currMixedSizes.size() &&
1075 "mismatch in result shape and sizes of empty op");
1076 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1077 int64_t newDim = std::get<0>(it);
1078 OpFoldResult currDim = std::get<1>(it);
1079 // Case 1: The empty tensor dim is static. Check that the tensor cast
1080 // result dim matches.
1081 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1082 if (ShapedType::isDynamic(newDim) ||
1083 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1084 // Something is off, the cast result shape cannot be more dynamic
1085 // than the empty tensor result shape (enforced by
1086 // `canFoldIntoProducer`). Abort for now.
1087 return rewriter.notifyMatchFailure(
1088 producer, "mismatch in static value of shape of empty tensor "
1089 "result and cast result");
1091 newMixedSizes.push_back(attr);
1092 continue;
1095 // Case 2 : The tensor cast shape is static, but empty tensor result
1096 // shape is dynamic.
1097 if (!ShapedType::isDynamic(newDim)) {
1098 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1099 continue;
1102 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1103 // shape is dynamic. Use the dynamic value from the empty tensor op.
1104 newMixedSizes.push_back(currDim);
1107 // TODO: Do not drop tensor encoding.
1108 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1109 resultType.getElementType());
1110 return success();
1114 } // namespace
1116 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1117 MLIRContext *context) {
1118 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1119 ReplaceEmptyTensorStaticShapeDims>(context);
1122 /// Try to remove a tensor operation if it would only reshape a constant.
1123 /// Removes the op and replaces the constant with a new constant of the result
1124 /// shape. When an optional cst attribute is passed, it is reshaped only if the
1125 /// splat value matches the value in the attribute.
1126 static OpFoldResult
1127 reshapeConstantSource(DenseElementsAttr source, TensorType result,
1128 std::optional<Attribute> cst = std::nullopt) {
1129 if (source && source.isSplat() && result.hasStaticShape() &&
1130 (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
1131 return source.resizeSplat(result);
1133 return {};
1136 //===----------------------------------------------------------------------===//
1137 // ExtractOp
1138 //===----------------------------------------------------------------------===//
1140 namespace {
1142 /// Canonicalizes the pattern of the form
1144 /// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
1145 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1147 /// to
1149 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1150 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1151 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1153 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1154 PatternRewriter &rewriter) const final {
1155 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1156 if (!tensorCast)
1157 return failure();
1158 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1159 return failure();
1160 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1161 extract, tensorCast.getSource(), extract.getIndices());
1162 return success();
1166 } // namespace
1168 void ExtractOp::getAsmResultNames(
1169 function_ref<void(Value, StringRef)> setNameFn) {
1170 setNameFn(getResult(), "extracted");
1173 LogicalResult ExtractOp::verify() {
1174 // Verify the # indices match if we have a ranked type.
1175 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1176 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1177 return emitOpError("incorrect number of indices for extract_element");
1178 return success();
1181 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1182 if (Attribute tensor = adaptor.getTensor()) {
1183 // If this is a splat elements attribute, simply return the value.
1184 // All of the elements of a splat attribute are the same.
1185 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1186 return splatTensor.getSplatValue<Attribute>();
1188 // If this is a dense resource elements attribute, return.
1189 if (isa<DenseResourceElementsAttr>(tensor))
1190 return {};
1193 // Collect the constant indices into the tensor.
1194 SmallVector<uint64_t, 8> indices;
1195 for (Attribute indice : adaptor.getIndices()) {
1196 if (!indice || !llvm::isa<IntegerAttr>(indice))
1197 return {};
1198 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1201 // Fold extract(from_elements(...)).
1202 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1203 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1204 auto rank = tensorType.getRank();
1205 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1206 "rank mismatch");
1207 int flatIndex = 0;
1208 int stride = 1;
1209 for (int i = rank - 1; i >= 0; --i) {
1210 flatIndex += indices[i] * stride;
1211 stride *= tensorType.getDimSize(i);
1213 // Prevent out of bounds accesses. This can happen in invalid code that
1214 // will never execute.
1215 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1216 flatIndex < 0)
1217 return {};
1218 return fromElementsOp.getElements()[flatIndex];
1221 // If this is an elements attribute, query the value at the given indices.
1222 if (Attribute tensor = adaptor.getTensor()) {
1223 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1224 if (elementsAttr && elementsAttr.isValidIndex(indices))
1225 return elementsAttr.getValues<Attribute>()[indices];
1228 return {};
1231 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1232 MLIRContext *context) {
1233 results.add<ExtractFromTensorCast>(context);
1236 //===----------------------------------------------------------------------===//
1237 // FromElementsOp
1238 //===----------------------------------------------------------------------===//
1240 void FromElementsOp::getAsmResultNames(
1241 function_ref<void(Value, StringRef)> setNameFn) {
1242 setNameFn(getResult(), "from_elements");
1245 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1246 ValueRange elements) {
1247 assert(!elements.empty() && "expected at least one element");
1248 Type resultType = RankedTensorType::get(
1249 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1250 build(builder, result, resultType, elements);
1253 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1254 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1255 return DenseElementsAttr::get(getType(), adaptor.getElements());
1256 return {};
1259 namespace {
1261 // Pushes the index_casts that occur before extractions to after the extract.
1262 // This minimizes type conversion in some cases and enables the extract
1263 // canonicalizer. This changes:
1265 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1266 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1268 // to the following:
1270 // %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1271 // %cast = arith.index_cast %extract : i32 to index
1273 // to just %element.
1275 // Consider expanding this to a template and handle all tensor cast
1276 // operations.
1277 struct ExtractElementFromIndexCast
1278 : public OpRewritePattern<tensor::ExtractOp> {
1279 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1281 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1282 PatternRewriter &rewriter) const final {
1283 Location loc = extract.getLoc();
1284 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1285 if (!indexCast)
1286 return failure();
1288 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1290 auto newExtract = rewriter.create<tensor::ExtractOp>(
1291 loc, elementTy, indexCast.getIn(), extract.getIndices());
1293 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1294 newExtract);
1296 return success();
1300 } // namespace
1302 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1303 MLIRContext *context) {
1304 results.add<ExtractElementFromIndexCast>(context);
1307 //===----------------------------------------------------------------------===//
1308 // GatherOp
1309 //===----------------------------------------------------------------------===//
1311 void GatherOp::getAsmResultNames(
1312 function_ref<void(Value, StringRef)> setNameFn) {
1313 setNameFn(getResult(), "gather");
1316 /// Return the inferred result type for a gatherOp where:
1317 /// - sourceType is the type of the source tensor gathered from
1318 /// - indicesType is the type of the indices used to gather
1319 /// - gatherDims are the dims along which the gather occurs.
1320 /// Return a full rank or ranked-reduced variant of the type depending on
1321 /// the value of rankReduced.
1323 /// The leading dimensions of the index tensor give the result tensor its
1324 /// leading dimensions.
1325 /// The trailing dimensions of the result tensor are obtained from the source
1326 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1327 /// rankedReduced is false), or skipping them (otherwise).
1328 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1329 RankedTensorType indicesType,
1330 ArrayRef<int64_t> gatherDims,
1331 bool rankReduced) {
1332 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1333 resultShape.reserve(resultShape.size() + sourceType.getRank());
1334 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1335 if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1336 if (!rankReduced)
1337 resultShape.push_back(1);
1338 continue;
1340 resultShape.push_back(sourceType.getDimSize(idx));
1342 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1345 static LogicalResult
1346 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1347 ArrayRef<int64_t> indices, int64_t rank,
1348 StringRef gatherOrScatter, StringRef sourceOrDest) {
1349 if (dims.empty())
1350 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1352 int64_t numGatherDims = dims.size();
1353 if (numGatherDims > rank)
1354 return op->emitOpError(gatherOrScatter)
1355 << "_dims overflow " << sourceOrDest << " rank";
1356 if (indices.empty() || indices.back() != numGatherDims)
1357 return op->emitOpError(gatherOrScatter)
1358 << "_dims length must match the size of last dimension of indices";
1359 for (int64_t val : dims) {
1360 if (val < 0)
1361 return op->emitOpError(gatherOrScatter)
1362 << "_dims value must be non-negative";
1363 if (val >= rank)
1364 return op->emitOpError(gatherOrScatter)
1365 << "_dims value must be smaller than " << sourceOrDest << " rank";
1367 for (int64_t i = 1; i < numGatherDims; ++i) {
1368 if (dims[i - 1] >= dims[i])
1369 return op->emitOpError(gatherOrScatter)
1370 << "_dims values must be strictly increasing";
1372 return success();
1375 LogicalResult GatherOp::verify() {
1376 int64_t sourceRank = getSourceType().getRank();
1377 ArrayRef<int64_t> gatherDims = getGatherDims();
1378 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1379 getIndicesType().getShape(), sourceRank,
1380 "gather", "source")))
1381 return failure();
1383 RankedTensorType expectedResultType = GatherOp::inferResultType(
1384 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1385 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1386 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1387 if (getResultType() != expectedResultType &&
1388 getResultType() != expectedRankReducedResultType) {
1389 return emitOpError("result type "
1390 "mismatch: "
1391 "expected ")
1392 << expectedResultType << " or its rank-reduced variant "
1393 << expectedRankReducedResultType << " (got: " << getResultType()
1394 << ")";
1397 return success();
1400 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1401 if (OpFoldResult reshapedSource = reshapeConstantSource(
1402 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1403 getResult().getType()))
1404 return reshapedSource;
1405 return {};
1408 //===----------------------------------------------------------------------===//
1409 // InsertOp
1410 //===----------------------------------------------------------------------===//
1412 void InsertOp::getAsmResultNames(
1413 function_ref<void(Value, StringRef)> setNameFn) {
1414 setNameFn(getResult(), "inserted");
1417 LogicalResult InsertOp::verify() {
1418 // Verify the # indices match if we have a ranked type.
1419 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1420 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1421 return emitOpError("incorrect number of indices");
1422 return success();
1425 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1426 Attribute scalar = adaptor.getScalar();
1427 Attribute dest = adaptor.getDest();
1428 if (scalar && dest)
1429 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1430 if (scalar == splatDest.getSplatValue<Attribute>())
1431 return dest;
1432 return {};
1435 //===----------------------------------------------------------------------===//
1436 // GenerateOp
1437 //===----------------------------------------------------------------------===//
1439 void GenerateOp::getAsmResultNames(
1440 function_ref<void(Value, StringRef)> setNameFn) {
1441 setNameFn(getResult(), "generated");
1444 LogicalResult GenerateOp::reifyResultShapes(
1445 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1446 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1447 int idx = 0;
1448 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1449 if (getType().isDynamicDim(dim)) {
1450 reifiedReturnShapes[0][dim] = getOperand(idx++);
1451 } else {
1452 reifiedReturnShapes[0][dim] =
1453 builder.getIndexAttr(getType().getDimSize(dim));
1456 return success();
1459 LogicalResult GenerateOp::verify() {
1460 // Ensure that the tensor type has as many dynamic dimensions as are
1461 // specified by the operands.
1462 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1463 if (getNumOperands() != resultType.getNumDynamicDims())
1464 return emitError("must have as many index operands as dynamic extents "
1465 "in the result type");
1466 return success();
1469 LogicalResult GenerateOp::verifyRegions() {
1470 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1471 // Ensure that region arguments span the index space.
1472 if (!llvm::all_of(getBody().getArgumentTypes(),
1473 [](Type ty) { return ty.isIndex(); }))
1474 return emitError("all body arguments must be index");
1475 if (getBody().getNumArguments() != resultTy.getRank())
1476 return emitError("must have one body argument per input dimension");
1478 // Ensure that the region yields an element of the right type.
1479 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1481 if (yieldOp.getValue().getType() != resultTy.getElementType())
1482 return emitOpError(
1483 "body must be terminated with a `yield` operation of the tensor "
1484 "element type");
1486 return success();
1489 void GenerateOp::build(
1490 OpBuilder &b, OperationState &result, Type resultTy,
1491 ValueRange dynamicExtents,
1492 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1493 build(b, result, resultTy, dynamicExtents);
1495 // Build and populate body.
1496 OpBuilder::InsertionGuard guard(b);
1497 Region *bodyRegion = result.regions.front().get();
1498 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1499 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1500 SmallVector<Location, 2> argumentLocs(rank, result.location);
1501 Block *bodyBlock =
1502 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1503 bodyBuilder(b, result.location, bodyBlock->getArguments());
1506 namespace {
1508 /// Canonicalizes tensor.generate operations with a constant
1509 /// operand into the equivalent operation with the operand expressed in the
1510 /// result type, instead. We also insert a type cast to make sure that the
1511 /// resulting IR is still well-typed.
1512 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1513 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1515 LogicalResult matchAndRewrite(GenerateOp generateOp,
1516 PatternRewriter &rewriter) const final {
1517 SmallVector<Value> foldedDynamicSizes;
1518 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1519 generateOp.getType(), generateOp.getDynamicExtents(),
1520 foldedDynamicSizes);
1522 // Stop here if no dynamic size was promoted to static.
1523 if (foldedTensorType == generateOp.getType())
1524 return failure();
1526 auto loc = generateOp.getLoc();
1527 auto newOp =
1528 rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1529 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1530 newOp.getBody().begin());
1531 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1532 generateOp.getType(), newOp);
1533 return success();
1537 /// Canonicalizes the pattern of the form
1539 /// %tensor = tensor.generate %x {
1540 /// ^bb0(%arg0: index):
1541 /// <computation>
1542 /// yield %1 : index
1543 /// } : tensor<?xindex>
1544 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1546 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1547 /// tensor.generate operation has no side-effects.
1548 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1549 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1551 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1552 PatternRewriter &rewriter) const final {
1553 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1554 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1555 return failure();
1557 IRMapping mapping;
1558 Block *body = &tensorFromElements.getBody().front();
1559 mapping.map(body->getArguments(), extract.getIndices());
1560 for (auto &op : body->without_terminator())
1561 rewriter.clone(op, mapping);
1563 auto yield = cast<YieldOp>(body->getTerminator());
1565 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1566 return success();
1570 } // namespace
1572 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1573 MLIRContext *context) {
1574 // TODO: Move extract pattern to tensor::ExtractOp.
1575 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1578 //===----------------------------------------------------------------------===//
1579 // RankOp
1580 //===----------------------------------------------------------------------===//
1582 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1583 setNameFn(getResult(), "rank");
1586 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1587 // Constant fold rank when the rank of the operand is known.
1588 auto type = getOperand().getType();
1589 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1590 if (shapedType && shapedType.hasRank())
1591 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1592 return IntegerAttr();
1595 //===----------------------------------------------------------------------===//
1596 // ReshapeOp
1597 //===----------------------------------------------------------------------===//
1599 void ReshapeOp::getAsmResultNames(
1600 function_ref<void(Value, StringRef)> setNameFn) {
1601 setNameFn(getResult(), "reshape");
1604 static int64_t getNumElements(ShapedType type) {
1605 int64_t numElements = 1;
1606 for (auto dim : type.getShape())
1607 numElements *= dim;
1608 return numElements;
1611 LogicalResult ReshapeOp::verify() {
1612 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1613 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1615 if (operandType.getElementType() != resultType.getElementType())
1616 return emitOpError("element types of source and destination tensor "
1617 "types should be the same");
1619 int64_t shapeSize =
1620 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1621 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1622 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1624 if (resultRankedType) {
1625 if (operandRankedType && resultRankedType.hasStaticShape() &&
1626 operandRankedType.hasStaticShape()) {
1627 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1628 return emitOpError("source and destination tensor should have the "
1629 "same number of elements");
1631 if (ShapedType::isDynamic(shapeSize))
1632 return emitOpError("cannot use shape operand with dynamic length to "
1633 "reshape to statically-ranked tensor type");
1634 if (shapeSize != resultRankedType.getRank())
1635 return emitOpError(
1636 "length of shape operand differs from the result's tensor rank");
1638 return success();
1641 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1642 if (OpFoldResult reshapedSource = reshapeConstantSource(
1643 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1644 getResult().getType()))
1645 return reshapedSource;
1647 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1648 // producer's input instead as the original tensor to reshape. This could
1649 // render such producer dead code.
1650 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1651 getSourceMutable().assign(reshapeOpProducer.getSource());
1652 return getResult();
1655 auto source = getSource();
1656 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1657 auto resultTy = dyn_cast<RankedTensorType>(getType());
1658 if (!sourceTy || !resultTy || sourceTy != resultTy)
1659 return {};
1661 // If the source and result are both 1D tensors and have the same type, the
1662 // reshape has no effect, even if the tensor is dynamically shaped.
1663 if (sourceTy.getRank() == 1)
1664 return source;
1666 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1667 auto elements = fromElements.getElements();
1668 bool dynamicNoop =
1669 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1670 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1671 auto element = elements[id];
1673 if (auto cst = getConstantIntValue(element)) {
1674 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1675 continue;
1678 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1679 dynamicNoop &= dimOp.getSource() == source;
1681 APSInt dim;
1682 auto cst = getConstantIntValue(dimOp.getIndex());
1683 dynamicNoop &=
1684 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1685 continue;
1688 dynamicNoop = false;
1689 break;
1692 if (dynamicNoop)
1693 return source;
1696 return {};
1699 //===----------------------------------------------------------------------===//
1700 // Reassociative reshape ops
1701 //===----------------------------------------------------------------------===//
1703 void CollapseShapeOp::getAsmResultNames(
1704 function_ref<void(Value, StringRef)> setNameFn) {
1705 setNameFn(getResult(), "collapsed");
1708 void ExpandShapeOp::getAsmResultNames(
1709 function_ref<void(Value, StringRef)> setNameFn) {
1710 setNameFn(getResult(), "expanded");
1713 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1714 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1715 "invalid resultDim");
1716 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1717 if (llvm::is_contained(it.value(), resultDim))
1718 return it.index();
1719 llvm_unreachable("could not find reassociation group");
1722 FailureOr<SmallVector<OpFoldResult>>
1723 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1724 RankedTensorType expandedType,
1725 ArrayRef<ReassociationIndices> reassociation,
1726 ArrayRef<OpFoldResult> inputShape) {
1727 std::optional<SmallVector<OpFoldResult>> outputShape =
1728 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1729 inputShape);
1730 if (!outputShape)
1731 return failure();
1732 return *outputShape;
1735 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1736 Type resultType, Value src,
1737 ArrayRef<ReassociationIndices> reassociation,
1738 ArrayRef<OpFoldResult> outputShape) {
1739 auto [staticOutputShape, dynamicOutputShape] =
1740 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1741 build(builder, result, cast<RankedTensorType>(resultType), src,
1742 getReassociationIndicesAttribute(builder, reassociation),
1743 dynamicOutputShape, staticOutputShape);
1746 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1747 Type resultType, Value src,
1748 ArrayRef<ReassociationIndices> reassociation) {
1749 SmallVector<OpFoldResult> inputShape =
1750 getMixedSizes(builder, result.location, src);
1751 auto tensorResultTy = cast<RankedTensorType>(resultType);
1752 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1753 builder, result.location, tensorResultTy, reassociation, inputShape);
1754 SmallVector<OpFoldResult> outputShapeOrEmpty;
1755 if (succeeded(outputShape)) {
1756 outputShapeOrEmpty = *outputShape;
1758 build(builder, result, tensorResultTy, src, reassociation,
1759 outputShapeOrEmpty);
1762 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1763 return getSymbolLessAffineMaps(getReassociationExprs());
1765 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1766 return convertReassociationIndicesToExprs(getContext(),
1767 getReassociationIndices());
1770 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1771 return getSymbolLessAffineMaps(getReassociationExprs());
1773 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1774 return convertReassociationIndicesToExprs(getContext(),
1775 getReassociationIndices());
1778 RankedTensorType CollapseShapeOp::inferCollapsedType(
1779 RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1780 return inferCollapsedType(
1781 type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1782 type.getContext(), reassociation)));
1785 /// Compute the RankedTensorType obtained by applying `reassociation` to
1786 /// `type`.
1787 RankedTensorType
1788 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1789 ArrayRef<AffineMap> reassociation) {
1790 auto shape = type.getShape();
1791 SmallVector<int64_t, 4> newShape;
1792 newShape.reserve(reassociation.size());
1794 // Use the fact that reassociation is valid to simplify the logic: only use
1795 // each map's rank.
1796 assert(isReassociationValid(reassociation) && "invalid reassociation");
1797 unsigned currentDim = 0;
1798 for (AffineMap m : reassociation) {
1799 unsigned dim = m.getNumResults();
1800 auto band = shape.slice(currentDim, dim);
1801 int64_t size = 1;
1802 if (llvm::is_contained(band, ShapedType::kDynamic))
1803 size = ShapedType::kDynamic;
1804 else
1805 for (unsigned d = 0; d < dim; ++d)
1806 size *= shape[currentDim + d];
1807 newShape.push_back(size);
1808 currentDim += dim;
1811 return RankedTensorType::get(newShape, type.getElementType());
1814 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1815 ArrayRef<ReassociationIndices> reassociation,
1816 ArrayRef<NamedAttribute> attrs) {
1817 auto resultType = inferCollapsedType(
1818 llvm::cast<RankedTensorType>(src.getType()),
1819 getSymbolLessAffineMaps(
1820 convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1821 result.addAttribute(getReassociationAttrStrName(),
1822 getReassociationIndicesAttribute(b, reassociation));
1823 build(b, result, resultType, src, attrs);
1826 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1827 TensorReshapeOp, ExpandShapeOp>::value>
1828 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1829 RankedTensorType expandedType,
1830 RankedTensorType collapsedType) {
1831 if (failed(
1832 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1833 return failure();
1835 auto maps = op.getReassociationMaps();
1836 RankedTensorType expectedType =
1837 CollapseShapeOp::inferCollapsedType(expandedType, maps);
1838 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1839 return op.emitOpError("expected collapsed type to be ")
1840 << expectedType << ", but got " << collapsedType;
1841 return success();
1844 LogicalResult ExpandShapeOp::verify() {
1845 auto srcType = getSrcType();
1846 auto resultType = getResultType();
1848 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1849 return emitOpError("expected number of static shape dims to be equal to "
1850 "the output rank (")
1851 << resultType.getRank() << ") but found "
1852 << getStaticOutputShape().size() << " inputs instead";
1854 if ((int64_t)getOutputShape().size() !=
1855 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1856 return emitOpError("mismatch in dynamic dims in output_shape and "
1857 "static_output_shape: static_output_shape has ")
1858 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1859 << " dynamic dims while output_shape has " << getOutputShape().size()
1860 << " values";
1862 return verifyTensorReshapeOp(*this, resultType, srcType);
1865 LogicalResult CollapseShapeOp::verify() {
1866 return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1869 namespace {
1870 /// Reshape of a splat constant can be replaced with a constant of the result
1871 /// type.
1872 template <typename TensorReshapeOp>
1873 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1874 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1875 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1876 PatternRewriter &rewriter) const override {
1877 DenseElementsAttr attr;
1878 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1879 return failure();
1880 if (!attr || !attr.isSplat())
1881 return failure();
1882 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1883 reshapeOp.getResultType(), attr.getRawData());
1884 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1885 return success();
1889 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1890 template <typename TensorReshapeOp>
1891 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1892 public:
1893 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1895 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1896 PatternRewriter &rewriter) const override {
1897 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1898 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1899 return failure();
1901 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1902 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1903 return success();
1907 /// Reshape of a FromElements can be replaced with a FromElements of the
1908 /// result type
1909 template <typename TensorReshapeOp>
1910 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1911 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1912 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1913 PatternRewriter &rewriter) const override {
1914 auto fromElements =
1915 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1916 if (!fromElements)
1917 return failure();
1919 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1921 if (!shapedTy.hasStaticShape())
1922 return failure();
1924 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1925 fromElements.getElements());
1926 return success();
1930 // Fold CastOp into CollapseShapeOp when adding static information.
1931 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1932 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1934 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1935 PatternRewriter &rewriter) const override {
1936 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1937 if (!tensor::canFoldIntoConsumerOp(castOp))
1938 return failure();
1940 RankedTensorType srcType =
1941 llvm::cast<RankedTensorType>(castOp.getSource().getType());
1942 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1943 srcType, collapseShapeOp.getReassociationMaps());
1945 if (newResultType == collapseShapeOp.getResultType()) {
1946 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1947 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1949 } else {
1950 auto newOp = rewriter.create<CollapseShapeOp>(
1951 collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1952 collapseShapeOp.getReassociation());
1953 rewriter.replaceOpWithNewOp<tensor::CastOp>(
1954 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1956 return success();
1960 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1961 using OpRewritePattern<DimOp>::OpRewritePattern;
1963 LogicalResult matchAndRewrite(DimOp dimOp,
1964 PatternRewriter &rewriter) const override {
1965 auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1966 if (!expandShapeOp)
1967 return failure();
1969 // Only constant dimension values are supported.
1970 std::optional<int64_t> dim = dimOp.getConstantIndex();
1971 if (!dim.has_value())
1972 return failure();
1974 // Skip static dims. These are folded to constant ops.
1975 RankedTensorType resultType = expandShapeOp.getResultType();
1976 if (!resultType.isDynamicDim(*dim))
1977 return failure();
1979 // Find reassociation group that contains this result dimension.
1980 int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1982 // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1983 // ExpandShapeOp would be ambiguous.)
1984 int64_t product = 1;
1985 ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1986 for (int64_t d : grp) {
1987 if (d != dim) {
1988 assert(!resultType.isDynamicDim(d) && "expected static dim");
1989 product *= resultType.getDimSize(d);
1993 // result dim size = src dim size / (product(other dims in reassoc group))
1994 Value srcDimSz =
1995 rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1996 AffineExpr expr;
1997 bindSymbols(dimOp.getContext(), expr);
1998 rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1999 dimOp, expr.floorDiv(product), srcDimSz);
2000 return success();
2004 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
2005 using OpRewritePattern<DimOp>::OpRewritePattern;
2007 LogicalResult matchAndRewrite(DimOp dimOp,
2008 PatternRewriter &rewriter) const override {
2009 auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
2010 if (!collapseShapeOp)
2011 return failure();
2013 // Only constant dimension values are supported.
2014 std::optional<int64_t> dim = dimOp.getConstantIndex();
2015 if (!dim.has_value())
2016 return failure();
2018 // Skip static dims. These are folded to constant ops.
2019 RankedTensorType resultType = collapseShapeOp.getResultType();
2020 if (!resultType.isDynamicDim(*dim))
2021 return failure();
2023 // Get reassociation group of the result dimension.
2024 ReassociationIndices group =
2025 collapseShapeOp.getReassociationIndices()[*dim];
2027 // result dim size = product(dims in reassoc group)
2028 SmallVector<Value> srcDimSizes;
2029 SmallVector<AffineExpr> syms;
2030 AffineExpr product;
2031 for (const auto &it : llvm::enumerate(group)) {
2032 srcDimSizes.push_back(rewriter.create<DimOp>(
2033 dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
2034 syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
2035 product = product ? product * syms.back() : syms.back();
2037 rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
2038 srcDimSizes);
2039 return success();
2043 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2044 /// matching constant output_shape operands of the expand. This makes the
2045 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2046 /// propagated further.
2047 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2048 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2050 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2051 PatternRewriter &rewriter) const override {
2052 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2053 if (!canFoldIntoConsumerOp(castOp))
2054 return failure();
2056 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2057 SmallVector<ReassociationIndices, 4> reassoc =
2058 expandOp.getReassociationIndices();
2060 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2061 SmallVector<Value> dynamicOutputShape;
2062 auto outputIt = expandOp.getOutputShape().begin();
2064 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2065 for (uint64_t outDim : innerReassoc) {
2066 if (!ShapedType::isDynamic(newOutputShape[outDim]))
2067 continue;
2069 // If the cast's src type is dynamic, don't infer any of the
2070 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2071 // least one of the expanded dimensions to be dynamic if the input is
2072 // dynamic.
2073 Value val = *outputIt;
2074 ++outputIt;
2075 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2076 dynamicOutputShape.push_back(val);
2077 continue;
2080 APInt cst;
2081 if (matchPattern(val, m_ConstantInt(&cst))) {
2082 newOutputShape[outDim] = cst.getSExtValue();
2083 } else {
2084 dynamicOutputShape.push_back(val);
2089 // Couldn't match any values, nothing to change
2090 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2091 return failure();
2093 // Calculate the input shape from the output
2094 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2095 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2096 for (auto outDim : reassoc[inDim]) {
2097 auto ofr = newOutputShape[outDim];
2098 if (ShapedType::isDynamic(ofr)) {
2099 newInputShape[inDim] = ShapedType::kDynamic;
2100 break;
2102 newInputShape[inDim] *= ofr;
2106 SmallVector<OpFoldResult> outputOfr =
2107 getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2108 auto inputType = RankedTensorType::get(
2109 newInputShape, expandOp.getSrcType().getElementType());
2110 auto outputType = RankedTensorType::get(
2111 newOutputShape, expandOp.getSrcType().getElementType());
2112 auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
2113 expandOp.getSrc());
2114 auto newExpand = rewriter.create<ExpandShapeOp>(
2115 expandOp.getLoc(), outputType, inputCast.getResult(),
2116 expandOp.getReassociationIndices(), outputOfr);
2117 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2118 newExpand.getResult());
2119 return success();
2122 } // namespace
2124 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2125 MLIRContext *context) {
2126 results.add<
2127 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2128 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2129 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2130 FoldReshapeWithSplat<ExpandShapeOp>,
2131 FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
2132 FoldDimOfCollapseShape>(context);
2135 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2136 MLIRContext *context) {
2137 results.add<
2138 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2139 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2140 tensor::DimOp, RankedTensorType>,
2141 FoldReshapeWithConstant<CollapseShapeOp>,
2142 FoldReshapeWithSplat<CollapseShapeOp>,
2143 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2144 context);
2147 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2148 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2149 adaptor.getOperands());
2152 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2153 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2154 adaptor.getOperands());
2157 //===----------------------------------------------------------------------===//
2158 // ExtractSliceOp
2159 //===----------------------------------------------------------------------===//
2161 void ExtractSliceOp::getAsmResultNames(
2162 function_ref<void(Value, StringRef)> setNameFn) {
2163 setNameFn(getResult(), "extracted_slice");
2166 /// An extract_slice result type can be inferred, when it is not
2167 /// rank-reduced, from the source type and the static representation of
2168 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2169 RankedTensorType ExtractSliceOp::inferResultType(
2170 RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2171 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2172 // An extract_slice op may specify only a leading subset of offset/sizes/
2173 // strides in which case we complete with offset=0, sizes from memref type
2174 // and strides=1.
2175 assert(static_cast<int64_t>(staticSizes.size()) ==
2176 sourceTensorType.getRank() &&
2177 "unexpected staticSizes not equal to rank of source");
2178 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2179 sourceTensorType.getEncoding());
2182 RankedTensorType ExtractSliceOp::inferResultType(
2183 RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2184 ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2185 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2186 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2187 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2188 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2189 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2190 return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2191 staticSizes, staticStrides);
2194 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2195 /// number of sizes), drop as many size 1 as needed to produce an inferred
2196 /// type with the desired rank.
2198 /// Note that there may be multiple ways to compute this rank-reduced type:
2199 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2201 /// To disambiguate, this function always drops the first 1 sizes occurrences.
2202 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2203 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2204 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2205 ArrayRef<int64_t> strides) {
2206 // Type inferred in the absence of rank-reducing behavior.
2207 auto inferredType = llvm::cast<RankedTensorType>(
2208 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2209 int rankDiff = inferredType.getRank() - desiredResultRank;
2210 if (rankDiff > 0) {
2211 auto shape = inferredType.getShape();
2212 llvm::SmallBitVector dimsToProject =
2213 getPositionsOfShapeOne(rankDiff, shape);
2214 SmallVector<int64_t> projectedShape;
2215 // Best effort rank-reducing: drop 1s in order.
2216 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2217 if (!dimsToProject.test(pos))
2218 projectedShape.push_back(shape[pos]);
2219 inferredType =
2220 RankedTensorType::get(projectedShape, inferredType.getElementType());
2222 return inferredType;
2225 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2226 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2227 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2228 ArrayRef<OpFoldResult> strides) {
2229 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2230 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2231 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2232 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2233 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2234 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2235 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2236 staticStrides);
2239 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2240 /// result type. If the type passed is nullptr, it is inferred.
2241 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2242 RankedTensorType resultType, Value source,
2243 ArrayRef<OpFoldResult> offsets,
2244 ArrayRef<OpFoldResult> sizes,
2245 ArrayRef<OpFoldResult> strides,
2246 ArrayRef<NamedAttribute> attrs) {
2247 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2248 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2249 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2250 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2251 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2252 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2253 // Structuring implementation this way avoids duplication between builders.
2254 if (!resultType) {
2255 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2256 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2258 result.addAttributes(attrs);
2259 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2260 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2261 b.getDenseI64ArrayAttr(staticSizes),
2262 b.getDenseI64ArrayAttr(staticStrides));
2265 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2266 /// result type.
2267 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2268 ArrayRef<OpFoldResult> offsets,
2269 ArrayRef<OpFoldResult> sizes,
2270 ArrayRef<OpFoldResult> strides,
2271 ArrayRef<NamedAttribute> attrs) {
2272 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2275 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2276 /// a Range vector.
2277 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2278 ArrayRef<Range> ranges,
2279 ArrayRef<NamedAttribute> attrs) {
2280 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2281 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2284 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2285 /// the type passed is nullptr, it is inferred.
2286 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2287 RankedTensorType resultType, Value source,
2288 ValueRange offsets, ValueRange sizes,
2289 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2290 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2291 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2292 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2293 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2294 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2295 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2296 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2299 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2300 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2301 ValueRange offsets, ValueRange sizes,
2302 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2303 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2306 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2307 Operation *op,
2308 RankedTensorType expectedType) {
2309 switch (result) {
2310 case SliceVerificationResult::Success:
2311 return success();
2312 case SliceVerificationResult::RankTooLarge:
2313 return op->emitError("expected rank to be smaller or equal to ")
2314 << "the other rank. ";
2315 case SliceVerificationResult::SizeMismatch:
2316 return op->emitError("expected type to be ")
2317 << expectedType << " or a rank-reduced version. (size mismatch) ";
2318 case SliceVerificationResult::ElemTypeMismatch:
2319 return op->emitError("expected element type to be ")
2320 << expectedType.getElementType();
2321 default:
2322 llvm_unreachable("unexpected extract_slice op verification result");
2326 /// Verifier for ExtractSliceOp.
2327 LogicalResult ExtractSliceOp::verify() {
2328 // Verify result type against inferred type.
2329 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2330 getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2331 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2332 return produceSliceErrorMsg(result, *this, expectedType);
2335 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2336 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2339 FailureOr<Value>
2340 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2341 ArrayRef<int64_t> desiredShape) {
2342 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2343 assert(sourceTensorType && "not a ranked tensor type");
2344 auto sourceShape = sourceTensorType.getShape();
2345 if (sourceShape.equals(desiredShape))
2346 return value;
2347 auto maybeRankReductionMask =
2348 mlir::computeRankReductionMask(sourceShape, desiredShape);
2349 if (!maybeRankReductionMask)
2350 return failure();
2351 return createCanonicalRankReducingExtractSliceOp(
2352 b, loc, value,
2353 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2356 LogicalResult ExtractSliceOp::reifyResultShapes(
2357 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2358 reifiedReturnShapes.resize(1);
2359 reifiedReturnShapes[0].reserve(getType().getRank());
2360 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2361 llvm::SmallBitVector droppedDims = getDroppedDims();
2362 for (const auto &size : enumerate(mixedSizes)) {
2363 if (droppedDims.test(size.index()))
2364 continue;
2365 reifiedReturnShapes[0].push_back(size.value());
2367 return success();
2370 namespace {
2371 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2372 /// This essentially pushes memref_cast past its consuming slice when
2373 /// `canFoldIntoConsumerOp` is true.
2375 /// Example:
2376 /// ```
2377 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2378 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2379 /// tensor<3x4xf32>
2380 /// ```
2381 /// is rewritten into:
2382 /// ```
2383 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2384 /// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2385 /// ```
2386 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2387 public:
2388 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2390 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2391 PatternRewriter &rewriter) const override {
2392 // Any constant operand, just return to let the constant folder kick in.
2393 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2394 return matchPattern(operand, matchConstantIndex());
2396 return failure();
2398 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2399 if (!castOp)
2400 return failure();
2402 if (!canFoldIntoConsumerOp(castOp))
2403 return failure();
2405 // Create folded extract.
2406 Location loc = sliceOp.getLoc();
2407 Value newResult = rewriter.create<ExtractSliceOp>(
2408 loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2409 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2410 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2411 if (newResult.getType() != sliceOp.getType())
2412 newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2413 rewriter.replaceOp(sliceOp, newResult);
2414 return success();
2418 /// Slice elements from `values` into `outValues`. `counts` represents the
2419 /// numbers of elements to stride in the original values for each dimension.
2420 /// The output values can be used to construct a DenseElementsAttr.
2421 template <typename IterTy, typename ElemTy>
2422 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2423 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2424 ArrayRef<int64_t> strides,
2425 llvm::SmallVectorImpl<ElemTy> *outValues) {
2426 assert(offsets.size() == sizes.size());
2427 assert(offsets.size() == strides.size());
2428 if (offsets.empty())
2429 return;
2431 int64_t offset = offsets.front();
2432 int64_t size = sizes.front();
2433 int64_t stride = strides.front();
2434 if (offsets.size() == 1) {
2435 for (int64_t i = 0; i < size; ++i, offset += stride)
2436 outValues->push_back(*(values + offset));
2438 return;
2441 for (int64_t i = 0; i < size; ++i, offset += stride) {
2442 auto begin = values + offset * counts.front();
2443 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2444 offsets.drop_front(), sizes.drop_front(),
2445 strides.drop_front(), outValues);
2449 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2450 /// folded operation might introduce more constant data; Users can control
2451 /// their heuristics by the control function.
2452 class ConstantOpExtractSliceFolder final
2453 : public OpRewritePattern<ExtractSliceOp> {
2454 public:
2455 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2457 ConstantOpExtractSliceFolder(MLIRContext *context,
2458 ControlConstantExtractSliceFusionFn controlFn)
2459 : OpRewritePattern<ExtractSliceOp>(context),
2460 controlFn(std::move(controlFn)) {}
2462 LogicalResult matchAndRewrite(ExtractSliceOp op,
2463 PatternRewriter &rewriter) const override {
2464 DenseElementsAttr attr;
2465 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2466 return failure();
2468 // A constant splat is handled by fold().
2469 if (attr.isSplat())
2470 return failure();
2472 // Dynamic result shape is not supported.
2473 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2474 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2475 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2476 return failure();
2478 // Customized control over the folding.
2479 if (!controlFn(op))
2480 return failure();
2482 int64_t count = sourceType.getNumElements();
2483 if (count == 0)
2484 return failure();
2486 // Check if there are any dynamic parts, which are not supported.
2487 auto offsets = op.getStaticOffsets();
2488 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2489 return failure();
2490 auto sizes = op.getStaticSizes();
2491 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2492 return failure();
2493 auto strides = op.getStaticStrides();
2494 if (llvm::is_contained(strides, ShapedType::kDynamic))
2495 return failure();
2497 // Compute the stride for each dimension.
2498 SmallVector<int64_t> counts;
2499 ArrayRef<int64_t> shape = sourceType.getShape();
2500 counts.reserve(shape.size());
2501 for (int64_t v : shape) {
2502 count = count / v;
2503 counts.push_back(count);
2506 // New attribute constructed by the sliced values.
2507 DenseElementsAttr newAttr;
2509 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2510 SmallVector<APInt> outValues;
2511 outValues.reserve(sourceType.getNumElements());
2512 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2513 elems.begin(), counts, offsets, sizes, strides, &outValues);
2514 newAttr = DenseElementsAttr::get(resultType, outValues);
2515 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2516 SmallVector<APFloat> outValues;
2517 outValues.reserve(sourceType.getNumElements());
2518 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2519 elems.begin(), counts, offsets, sizes, strides, &outValues);
2520 newAttr = DenseElementsAttr::get(resultType, outValues);
2523 if (newAttr) {
2524 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2525 return success();
2528 return failure();
2531 private:
2532 /// This additionally controls whether the fold happens or not. Users can
2533 /// impose their heuristics in the function.
2534 ControlConstantExtractSliceFusionFn controlFn;
2537 } // namespace
2539 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2540 RewritePatternSet &patterns,
2541 const ControlConstantExtractSliceFusionFn &controlFn) {
2542 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2545 /// Return the canonical type of the result of an extract_slice op.
2546 struct SliceReturnTypeCanonicalizer {
2547 RankedTensorType operator()(ExtractSliceOp op,
2548 ArrayRef<OpFoldResult> mixedOffsets,
2549 ArrayRef<OpFoldResult> mixedSizes,
2550 ArrayRef<OpFoldResult> mixedStrides) {
2551 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2552 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2553 mixedStrides);
2557 /// A canonicalizer wrapper to replace ExtractSliceOps.
2558 struct SliceCanonicalizer {
2559 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2560 ExtractSliceOp newOp) {
2561 Value replacement = newOp.getResult();
2562 if (replacement.getType() != op.getType())
2563 replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2564 replacement);
2565 rewriter.replaceOp(op, replacement);
2569 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2570 MLIRContext *context) {
2571 results.add<
2572 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2573 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2574 ExtractSliceOpCastFolder>(context);
2578 static LogicalResult
2579 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2580 ShapedType shapedType) {
2581 OpBuilder b(op.getContext());
2582 for (OpFoldResult ofr : op.getMixedOffsets())
2583 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2584 return failure();
2585 // Rank-reducing noops only need to inspect the leading dimensions:
2586 // llvm::zip is appropriate.
2587 auto shape = shapedType.getShape();
2588 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2589 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2590 return failure();
2591 for (OpFoldResult ofr : op.getMixedStrides())
2592 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2593 return failure();
2594 return success();
2597 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2598 /// slice, we can return the InsertSliceOp's source directly.
2599 // TODO: This only checks the immediate producer; extend to go up the
2600 // insert/extract chain if the slices are disjoint.
2601 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2602 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2604 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2605 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2606 insertOp.isSameAs(extractOp, isSame))
2607 return insertOp.getSource();
2609 return {};
2612 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2613 if (OpFoldResult reshapedSource = reshapeConstantSource(
2614 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2615 getResult().getType()))
2616 return reshapedSource;
2617 if (getSourceType() == getType() &&
2618 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2619 return this->getSource();
2620 if (Value slice = foldExtractAfterInsertSlice(*this))
2621 return slice;
2623 return OpFoldResult();
2626 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2627 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2628 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2629 unsigned rank = rankedTensorType.getRank();
2630 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2631 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2632 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2633 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2634 offsets, sizes, strides);
2637 //===----------------------------------------------------------------------===//
2638 // InsertSliceOp
2639 //===----------------------------------------------------------------------===//
2641 void InsertSliceOp::getAsmResultNames(
2642 function_ref<void(Value, StringRef)> setNameFn) {
2643 setNameFn(getResult(), "inserted_slice");
2646 // Build a InsertSliceOp with mixed static and dynamic entries.
2647 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2648 Value dest, ArrayRef<OpFoldResult> offsets,
2649 ArrayRef<OpFoldResult> sizes,
2650 ArrayRef<OpFoldResult> strides,
2651 ArrayRef<NamedAttribute> attrs) {
2652 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2653 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2654 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2655 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2656 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2657 result.addAttributes(attrs);
2658 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2659 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2660 b.getDenseI64ArrayAttr(staticSizes),
2661 b.getDenseI64ArrayAttr(staticStrides));
2664 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2665 /// Range vector.
2666 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2667 Value dest, ArrayRef<Range> ranges,
2668 ArrayRef<NamedAttribute> attrs) {
2669 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2670 build(b, result, source, dest, offsets, sizes, strides, attrs);
2673 // Build a InsertSliceOp with dynamic entries.
2674 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2675 Value dest, ValueRange offsets, ValueRange sizes,
2676 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2677 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2678 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2679 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2680 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2681 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2682 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2683 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2686 /// Rank-reducing type verification for both InsertSliceOp and
2687 /// ParallelInsertSliceOp.
2688 static SliceVerificationResult verifyInsertSliceOp(
2689 RankedTensorType srcType, RankedTensorType dstType,
2690 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2691 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2692 // insert_slice is the inverse of extract_slice, use the same type
2693 // inference.
2694 RankedTensorType expected = ExtractSliceOp::inferResultType(
2695 dstType, staticOffsets, staticSizes, staticStrides);
2696 if (expectedType)
2697 *expectedType = expected;
2698 return isRankReducedType(expected, srcType);
2701 /// Verifier for InsertSliceOp.
2702 LogicalResult InsertSliceOp::verify() {
2703 RankedTensorType expectedType;
2704 SliceVerificationResult result =
2705 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2706 getStaticSizes(), getStaticStrides(), &expectedType);
2707 return produceSliceErrorMsg(result, *this, expectedType);
2710 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2711 /// can mutate the second InsertSliceOp's destination to the first one's.
2713 /// Example:
2715 /// ```mlir
2716 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2717 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2718 /// ```
2720 /// folds into:
2722 /// ```mlir
2723 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2724 /// ```
2726 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2727 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2728 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2730 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2731 if (!prevInsertOp ||
2732 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2733 !prevInsertOp.isSameAs(insertOp, isSame))
2734 return failure();
2736 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2737 return success();
2740 /// Folds round-trip extract/insert slice op pairs.
2741 /// Example:
2742 /// ```mlir
2743 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2744 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2745 /// ```
2746 /// can be folded into %val.
2747 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2748 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2750 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2751 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2752 !extractOp.isSameAs(insertOp, isSame))
2753 return nullptr;
2755 return extractOp.getSource();
2758 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2759 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2760 getSourceType() == getType() &&
2761 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2762 return this->getSource();
2763 if (succeeded(foldInsertAfterInsertSlice(*this)))
2764 return getResult();
2765 if (auto result = foldInsertAfterExtractSlice(*this))
2766 return result;
2767 if (llvm::any_of(getMixedSizes(),
2768 [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2769 return getDest();
2770 return OpFoldResult();
2773 LogicalResult InsertSliceOp::reifyResultShapes(
2774 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2775 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2776 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2777 return success();
2780 namespace {
2781 /// Pattern to rewrite a insert_slice op with constant arguments.
2783 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2784 template <typename InsertOpTy>
2785 class InsertSliceOpConstantArgumentFolder final
2786 : public OpRewritePattern<InsertOpTy> {
2787 public:
2788 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2790 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2791 PatternRewriter &rewriter) const override {
2792 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2793 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2794 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2796 // No constant operands were folded, just return;
2797 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2798 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2799 failed(foldDynamicStrideList(mixedStrides)))
2800 return failure();
2802 // Create the new op in canonical form.
2803 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2804 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2805 mixedOffsets, mixedSizes, mixedStrides);
2806 Value toInsert = insertSliceOp.getSource();
2807 if (sourceType != insertSliceOp.getSourceType()) {
2808 OpBuilder::InsertionGuard g(rewriter);
2809 // The only difference between InsertSliceOp and ParallelInsertSliceOp
2810 // is that the insertion point is just before the ParallelCombiningOp in
2811 // the parallel case.
2812 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2813 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2814 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2815 sourceType, toInsert);
2817 rewriter.replaceOpWithNewOp<InsertOpTy>(
2818 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2819 mixedSizes, mixedStrides);
2820 return success();
2824 /// Fold tensor_casts with insert_slice operations. If the source or
2825 /// destination tensor is a tensor_cast that removes static type information,
2826 /// the cast is folded into the insert_slice operation. E.g.:
2828 /// ```mlir
2829 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2830 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2831 /// ```
2833 /// folds into:
2835 /// ```mlir
2836 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2837 /// ```
2839 /// Note: When folding a cast on the destination tensor, the result of the
2840 /// insert_slice operation is casted to ensure that the type of the result did
2841 /// not change.
2843 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2844 template <typename InsertOpTy>
2845 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2846 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2848 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2849 PatternRewriter &rewriter) const override {
2850 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2851 return matchPattern(operand, matchConstantIndex());
2853 return failure();
2855 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2856 auto castOp = v.getDefiningOp<tensor::CastOp>();
2857 if (!castOp || !canFoldIntoConsumerOp(castOp))
2858 return std::nullopt;
2859 return castOp.getSource();
2861 std::optional<Value> sourceCastSource =
2862 getSourceOfCastOp(insertSliceOp.getSource());
2863 std::optional<Value> destCastSource =
2864 getSourceOfCastOp(insertSliceOp.getDest());
2865 if (!sourceCastSource && !destCastSource)
2866 return failure();
2868 auto src =
2869 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2870 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2871 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2872 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2873 if (!srcType || !dstType)
2874 return failure();
2876 // The tensor.cast source could have additional static information not seen
2877 // in the insert slice op static sizes, so we ignore dynamic dims when
2878 // computing the rank reduction mask.
2879 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2880 auto rankReductionMask = computeRankReductionMask(
2881 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2882 if (!rankReductionMask.has_value())
2883 return failure();
2884 // Replace dimensions in the insert slice op with corresponding static dims
2885 // from the cast source type. If the insert slice sizes have static dims
2886 // that are not static in the tensor.cast source (i.e., when the cast op
2887 // casts a dynamic dim to static), the dim should not be replaced, and the
2888 // pattern will fail later in `verifyInsertSliceOp`.
2889 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2890 int64_t rankReducedIdx = 0;
2891 for (auto [idx, size] : enumerate(staticSizes)) {
2892 if (!rankReductionMask.value().contains(idx) &&
2893 !srcType.isDynamicDim(rankReducedIdx)) {
2894 mixedSizes[idx] = getAsIndexOpFoldResult(
2895 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2896 size = srcType.getDimSize(rankReducedIdx++);
2899 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2900 staticSizes, insertSliceOp.getStaticStrides()) !=
2901 SliceVerificationResult::Success)
2902 return failure();
2904 Operation *replacement = rewriter.create<InsertOpTy>(
2905 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2906 mixedSizes, insertSliceOp.getMixedStrides());
2908 // In the parallel case there is no result and so nothing to cast.
2909 bool isParallelInsert =
2910 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2911 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2912 replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2913 insertSliceOp.getDestType(),
2914 replacement->getResult(0));
2916 rewriter.replaceOp(insertSliceOp, replacement->getResults());
2917 return success();
2921 /// If additional static type information can be deduced from a insert_slice's
2922 /// size operands, insert an explicit cast of the op's source operand. This
2923 /// enables other canonicalization patterns that are matching for tensor_cast
2924 /// ops such as `ForOpTensorCastFolder` in SCF.
2926 /// Example:
2928 /// ```mlir
2929 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2930 /// : tensor<?x?xf32> into ...
2931 /// ```
2933 /// folds into:
2935 /// ```mlir
2936 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2937 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2938 /// : tensor<64x64xf32> into ...
2939 /// ```
2941 /// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
2942 template <typename InsertOpTy>
2943 struct InsertSliceOpSourceCastInserter final
2944 : public OpRewritePattern<InsertOpTy> {
2945 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2947 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2948 PatternRewriter &rewriter) const override {
2949 RankedTensorType srcType = insertSliceOp.getSourceType();
2950 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2951 return failure();
2952 SmallVector<int64_t> newSrcShape(srcType.getShape());
2953 for (int64_t i = 0; i < srcType.getRank(); ++i) {
2954 if (std::optional<int64_t> constInt =
2955 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2956 // Bail on invalid IR.
2957 if (*constInt < 0)
2958 return failure();
2959 newSrcShape[i] = *constInt;
2962 if (!hasValidSizesOffsets(newSrcShape))
2963 return failure();
2965 RankedTensorType newSrcType = RankedTensorType::get(
2966 newSrcShape, srcType.getElementType(), srcType.getEncoding());
2967 if (srcType == newSrcType ||
2968 !preservesStaticInformation(srcType, newSrcType) ||
2969 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2970 return failure();
2972 // newSrcType is:
2973 // 1) Different from srcType.
2974 // 2) "More static" than srcType.
2975 // 3) Cast-compatible with srcType.
2976 // Insert the cast.
2977 OpBuilder::InsertionGuard g(rewriter);
2978 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2979 // that the insertion point is just before the ParallelCombiningOp in the
2980 // parallel case.
2981 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2982 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2983 Value cast = rewriter.create<tensor::CastOp>(
2984 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2985 rewriter.replaceOpWithNewOp<InsertOpTy>(
2986 insertSliceOp, cast, insertSliceOp.getDest(),
2987 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2988 insertSliceOp.getMixedStrides());
2989 return success();
2992 } // namespace
2994 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2995 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2998 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2999 MLIRContext *context) {
3000 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3001 InsertSliceOpCastFolder<InsertSliceOp>,
3002 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3005 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3006 Location loc,
3007 Value tensor,
3008 Value dest) {
3009 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3010 unsigned rank = rankedTensorType.getRank();
3011 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3012 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3013 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3014 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3015 sizes, strides);
3018 //===----------------------------------------------------------------------===//
3019 // PadOp
3020 //===----------------------------------------------------------------------===//
3022 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3023 setNameFn(getResult(), "padded");
3026 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3027 // supports optional types.
3028 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3029 Type typeToInfer, Type typeToInferFrom) {}
3031 ParseResult
3032 parseInferType(OpAsmParser &parser,
3033 std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3034 Type &typeToInfer, Type typeToInferFrom) {
3035 if (optOperand)
3036 typeToInfer = typeToInferFrom;
3037 return success();
3040 LogicalResult PadOp::verify() {
3041 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3042 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3043 auto expectedType =
3044 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3045 if (!expectedType) {
3046 return emitError("failed to infer expectedType from sourceType ")
3047 << sourceType << ", specified resultType is " << resultType;
3049 if (resultType.getRank() != expectedType.getRank()) {
3050 return emitError("specified type ")
3051 << resultType << " does not match the inferred type "
3052 << expectedType;
3054 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3055 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3056 continue;
3057 if (expectedType.isDynamicDim(i))
3058 continue;
3059 return emitError("specified type ")
3060 << resultType << " does not match the inferred type "
3061 << expectedType;
3064 return success();
3067 LogicalResult PadOp::verifyRegions() {
3068 auto &region = getRegion();
3069 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3070 Block &block = region.front();
3071 if (block.getNumArguments() != rank)
3072 return emitError("expected the block to have ") << rank << " arguments";
3074 // Note: the number and type of yield values are checked in the YieldOp.
3075 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3076 if (!en.value().isIndex())
3077 return emitOpError("expected block argument ")
3078 << (en.index() + 1) << " to be an index";
3081 // Ensure that the region yields an element of the right type.
3082 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3083 if (yieldOp.getValue().getType() !=
3084 llvm::cast<ShapedType>(getType()).getElementType())
3085 return emitOpError("expected yield type to match shape element type");
3087 return success();
3090 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3091 ArrayRef<int64_t> staticLow,
3092 ArrayRef<int64_t> staticHigh,
3093 ArrayRef<int64_t> resultShape) {
3094 unsigned rank = sourceType.getRank();
3095 if (staticLow.size() != rank)
3096 return RankedTensorType();
3097 if (staticHigh.size() != rank)
3098 return RankedTensorType();
3099 if (!resultShape.empty() && resultShape.size() != rank)
3100 return RankedTensorType();
3102 SmallVector<int64_t, 4> inferredShape;
3103 for (auto i : llvm::seq<unsigned>(0, rank)) {
3104 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3105 staticHigh[i] == ShapedType::kDynamic) {
3106 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3107 : resultShape[i]);
3108 } else {
3109 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3110 assert((resultShape.empty() || size == resultShape[i] ||
3111 resultShape[i] == ShapedType::kDynamic) &&
3112 "mismatch between inferred shape and result shape");
3113 inferredShape.push_back(size);
3117 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3120 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3121 Value source, ArrayRef<int64_t> staticLow,
3122 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3123 bool nofold, ArrayRef<NamedAttribute> attrs) {
3124 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3125 if (!resultType)
3126 resultType = inferResultType(sourceType, staticLow, staticHigh);
3127 result.addAttributes(attrs);
3128 build(b, result, resultType, source, low, high,
3129 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3130 nofold ? b.getUnitAttr() : UnitAttr());
3133 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3134 Value source, ValueRange low, ValueRange high, bool nofold,
3135 ArrayRef<NamedAttribute> attrs) {
3136 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3137 unsigned rank = sourceType.getRank();
3138 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3139 build(b, result, resultType, source, staticVector, staticVector, low, high,
3140 nofold, attrs);
3143 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3144 Value source, ArrayRef<OpFoldResult> low,
3145 ArrayRef<OpFoldResult> high, bool nofold,
3146 ArrayRef<NamedAttribute> attrs) {
3147 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3148 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3149 SmallVector<int64_t, 4> staticLow, staticHigh;
3150 // staticLow and staticHigh have full information of the padding config.
3151 // This will grow staticLow and staticHigh with 1 value. If the config is
3152 // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
3153 // value as well.
3154 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3155 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3156 if (!resultType) {
3157 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3159 assert(llvm::isa<RankedTensorType>(resultType));
3160 result.addAttributes(attrs);
3161 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3162 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3163 nofold ? b.getUnitAttr() : UnitAttr());
3166 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3167 Value source, ArrayRef<OpFoldResult> low,
3168 ArrayRef<OpFoldResult> high, Value constantPadValue,
3169 bool nofold, ArrayRef<NamedAttribute> attrs) {
3170 build(b, result, resultType, source, low, high, nofold, attrs);
3172 // Add a region and a block to yield the pad value.
3173 Region *region = result.regions[0].get();
3174 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3175 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3176 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3178 // `builder.createBlock` changes the insertion point within the block. Create
3179 // a guard to reset the insertion point of the builder after it is destroyed.
3180 OpBuilder::InsertionGuard guard(b);
3181 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3182 b.create<tensor::YieldOp>(result.location, constantPadValue);
3185 llvm::SmallBitVector PadOp::getPaddedDims() {
3186 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3187 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3188 for (const auto &en : enumerate(paddingWidths))
3189 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3190 paddedDims.set(en.index());
3192 extractPaddedDims(getMixedLowPad());
3193 extractPaddedDims(getMixedHighPad());
3194 return paddedDims;
3197 namespace {
3198 // Folds tensor.pad when padding is static zeros and the attribute
3199 // doesn't request otherwise.
3200 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3201 using OpRewritePattern<PadOp>::OpRewritePattern;
3203 LogicalResult matchAndRewrite(PadOp padTensorOp,
3204 PatternRewriter &rewriter) const override {
3205 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3206 return failure();
3207 if (padTensorOp.getNofold())
3208 return failure();
3209 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3210 padTensorOp, padTensorOp.getResult().getType(),
3211 padTensorOp.getSource());
3212 return success();
3216 // Fold CastOp into PadOp when adding static information.
3217 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3218 using OpRewritePattern<PadOp>::OpRewritePattern;
3220 LogicalResult matchAndRewrite(PadOp padTensorOp,
3221 PatternRewriter &rewriter) const override {
3222 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3223 if (!tensor::canFoldIntoConsumerOp(castOp))
3224 return failure();
3226 auto newResultType = PadOp::inferResultType(
3227 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3228 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3229 padTensorOp.getResultType().getShape());
3231 if (newResultType == padTensorOp.getResultType()) {
3232 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3233 padTensorOp.getSourceMutable().assign(castOp.getSource());
3235 } else {
3236 auto newOp = rewriter.create<PadOp>(
3237 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3238 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3239 padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3240 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3241 IRMapping mapper;
3242 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3244 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3245 padTensorOp, padTensorOp.getResultType(), newOp);
3247 return success();
3251 // Fold CastOp using the result of PadOp back into the latter if it adds
3252 // static information.
3253 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3254 using OpRewritePattern<PadOp>::OpRewritePattern;
3256 LogicalResult matchAndRewrite(PadOp padTensorOp,
3257 PatternRewriter &rewriter) const override {
3258 if (!padTensorOp.getResult().hasOneUse())
3259 return failure();
3260 auto tensorCastOp =
3261 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3262 if (!tensorCastOp)
3263 return failure();
3264 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3265 tensorCastOp.getDest().getType()))
3266 return failure();
3268 auto replacementOp = rewriter.create<PadOp>(
3269 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3270 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3271 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3272 padTensorOp.getHigh(), padTensorOp.getNofold(),
3273 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3274 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3276 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3277 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3278 return success();
3282 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3283 /// different dimensions. The pattern applies if the following preconditions
3284 /// hold:
3285 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3286 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3287 /// 3) the tensor::PadOps perform only high-padding,
3288 /// 4) the tensor::PadOps have the same constant padding value,
3289 /// 5) the tensor::PadOps do not have common padding dimensions,
3290 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3291 /// zero-offset for every dimension.
3292 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3293 /// the
3294 /// padded source dimensions.
3296 /// Example:
3298 /// ```mlir
3299 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3300 /// : tensor<64x64xf32> to tensor<?x64xf32>
3301 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3302 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3303 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3304 /// : tensor<8x64xf32> to tensor<8x?xf32>
3305 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3306 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3307 /// ```
3309 /// folds into:
3311 /// ```mlir
3312 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3313 /// : tensor<64x64xf32> to tensor<?x?xf32>
3314 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3315 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3316 /// ```
3317 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3318 using OpRewritePattern<PadOp>::OpRewritePattern;
3320 LogicalResult matchAndRewrite(PadOp padOp,
3321 PatternRewriter &rewriter) const override {
3322 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3323 if (!innerSliceOp)
3324 return failure();
3325 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3326 if (!outerPadOp || outerPadOp.getNofold())
3327 return failure();
3328 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3329 if (!outerSliceOp)
3330 return failure();
3332 // 1) Fail if the chain is rank-reducing.
3333 int64_t rank = padOp.getSourceType().getRank();
3334 if (outerSliceOp.getSourceType().getRank() != rank) {
3335 return rewriter.notifyMatchFailure(padOp,
3336 "cannot fold rank-reducing chain");
3339 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3340 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3341 return rewriter.notifyMatchFailure(
3342 padOp, "cannot fold non-unit stride ExtractSliceOps");
3345 // 3) Fail if the tensor::PadOps have non-zero low padding.
3346 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3347 return rewriter.notifyMatchFailure(padOp,
3348 "cannot fold PadOps with low padding");
3351 // 4) Fail if the tensor::PadOps padding values do not match.
3352 Attribute innerAttr, outerAttr;
3353 Value innerValue = padOp.getConstantPaddingValue();
3354 Value outerValue = outerPadOp.getConstantPaddingValue();
3355 if (!innerValue || !outerValue ||
3356 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3357 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3358 innerAttr != outerAttr) {
3359 return rewriter.notifyMatchFailure(
3360 padOp, "cannot fold PadOps with different padding values");
3363 // 5) Fail if a dimension is padded by both tensor::PadOps.
3364 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3365 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3366 if (innerDims.anyCommon(outerDims)) {
3367 return rewriter.notifyMatchFailure(
3368 padOp, "cannot fold PadOps with common padding dimensions");
3371 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3372 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3373 // for every dimension, and use the offset the other pair. Fail if no
3374 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3375 // exists.
3376 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3377 for (auto en : enumerate(newOffsets)) {
3378 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3379 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3380 if (!innerDims.test(en.index()) &&
3381 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3382 en.value() = outerOffset;
3383 continue;
3385 if (!outerDims.test(en.index()) &&
3386 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3387 en.value() = innerOffset;
3388 continue;
3390 return rewriter.notifyMatchFailure(
3391 padOp, "cannot find zero-offset and zero-padding pair");
3394 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3395 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3396 // outer tensor::PadOp and fail if the size of the inner
3397 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3398 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3399 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3400 for (auto en : enumerate(newSizes)) {
3401 if (!outerDims.test(en.index()))
3402 continue;
3403 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3404 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3405 assert(!ShapedType::isDynamic(sourceSize) &&
3406 "expected padded dimension to have a static size");
3407 if (getConstantIntValue(sliceSize) != sourceSize) {
3408 return rewriter.notifyMatchFailure(
3409 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3410 "match the size of the outer padding");
3412 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3415 // Combine the high paddings of the two tensor::PadOps.
3416 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3417 for (auto en : enumerate(newHighPad)) {
3418 if (innerDims.test(en.index()))
3419 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3420 if (outerDims.test(en.index()))
3421 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3424 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3425 // the two paddings in one step.
3426 auto newSliceOp = rewriter.create<ExtractSliceOp>(
3427 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3428 innerSliceOp.getMixedStrides());
3429 auto newPadOp = rewriter.create<PadOp>(
3430 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3431 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3432 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3433 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3434 newPadOp.getRegion().begin());
3435 rewriter.replaceOp(padOp, newPadOp.getResult());
3436 return success();
3440 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3441 using OpRewritePattern<PadOp>::OpRewritePattern;
3443 LogicalResult matchAndRewrite(PadOp padTensorOp,
3444 PatternRewriter &rewriter) const override {
3445 Value input = padTensorOp.getSource();
3446 if (!llvm::isa<RankedTensorType>(input.getType()))
3447 return failure();
3448 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3449 auto inputRank = inputDims.size();
3451 auto oldResultType =
3452 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3453 if (!oldResultType)
3454 return failure();
3456 auto outputDims = oldResultType.getShape();
3458 // Extract the static info from the high and low operands.
3459 SmallVector<int64_t> constOperandsLow;
3460 SmallVector<Value> newLows;
3461 for (auto operand : padTensorOp.getLow()) {
3462 APSInt intOp;
3463 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3464 constOperandsLow.push_back(ShapedType::kDynamic);
3465 newLows.push_back(operand);
3466 continue;
3468 constOperandsLow.push_back(intOp.getExtValue());
3470 SmallVector<int64_t> constOperandsHigh;
3471 SmallVector<Value> newHighs;
3472 for (auto operand : padTensorOp.getHigh()) {
3473 APSInt intOp;
3474 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3475 constOperandsHigh.push_back(ShapedType::kDynamic);
3476 newHighs.push_back(operand);
3477 continue;
3479 constOperandsHigh.push_back(intOp.getExtValue());
3482 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3483 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3485 // Verify the op is well-formed.
3486 if (inputDims.size() != outputDims.size() ||
3487 inputDims.size() != constLow.size() ||
3488 inputDims.size() != constHigh.size())
3489 return failure();
3491 auto lowCount = 0;
3492 auto highCount = 0;
3493 for (size_t i = 0; i < inputRank; i++) {
3494 if (constLow[i] == ShapedType::kDynamic)
3495 constLow[i] = constOperandsLow[lowCount++];
3496 if (constHigh[i] == ShapedType::kDynamic)
3497 constHigh[i] = constOperandsHigh[highCount++];
3500 auto staticLow = ArrayRef<int64_t>(constLow);
3501 auto staticHigh = ArrayRef<int64_t>(constHigh);
3503 // Calculate the output sizes with the static information.
3504 SmallVector<int64_t> newOutDims;
3505 for (size_t i = 0; i < inputRank; i++) {
3506 if (outputDims[i] == ShapedType::kDynamic) {
3507 newOutDims.push_back(
3508 (staticLow[i] == ShapedType::kDynamic ||
3509 staticHigh[i] == ShapedType::kDynamic ||
3510 inputDims[i] == ShapedType::kDynamic
3511 ? ShapedType::kDynamic
3512 : inputDims[i] + staticLow[i] + staticHigh[i]));
3513 } else {
3514 newOutDims.push_back(outputDims[i]);
3518 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3519 llvm::all_of(newOutDims,
3520 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3521 return failure();
3523 // Rewrite the op using the new static type.
3524 auto newResultType = RankedTensorType::get(
3525 newOutDims, padTensorOp.getType().getElementType());
3526 auto newOp = rewriter.create<PadOp>(
3527 padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3528 newLows, newHighs, padTensorOp.getNofold(),
3529 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3531 IRMapping mapper;
3532 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3533 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3534 newOp);
3536 return success();
3540 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3542 /// Example:
3544 /// ```mlir
3545 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3546 /// tensor.yield %val
3547 /// } : tensor<1x2xf32> to tensor<2x5xf32>
3548 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3549 /// tensor.yield %val
3550 /// } : tensor<1x5xf32> to tensor<5x7xf32>
3551 /// ```
3553 /// folds into:
3555 /// ```mlir
3556 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3557 /// tensor.yield %val
3558 /// } : tensor<1x2xf32> to tensor<5x7xf32>
3559 /// ```
3560 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3561 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3563 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3564 PatternRewriter &rewriter) const override {
3565 if (padOp.getNofold()) {
3566 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3569 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3570 if (!producerPad || producerPad.getNofold()) {
3571 return rewriter.notifyMatchFailure(
3572 padOp, "producer is not a foldable tensor.pad op");
3575 // Fail if the tensor::PadOps padding values do not match.
3576 Value consumerPadValue = padOp.getConstantPaddingValue();
3577 Value producerPadValue = producerPad.getConstantPaddingValue();
3578 if (!consumerPadValue || !producerPadValue ||
3579 consumerPadValue != producerPadValue) {
3580 return rewriter.notifyMatchFailure(
3581 padOp,
3582 "cannot fold PadOps with different or non-constant padding values");
3585 Location loc = padOp.getLoc();
3586 AffineExpr d0, d1;
3587 bindDims(rewriter.getContext(), d0, d1);
3589 // Combine the low/high paddings of the two tensor::PadOps.
3590 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3591 ArrayRef<OpFoldResult> producerPaddings) {
3592 SmallVector<OpFoldResult> sumPaddings;
3593 for (auto [consumerIndex, producerIndex] :
3594 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3595 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3596 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3598 return sumPaddings;
3601 SmallVector<OpFoldResult> newHighPad =
3602 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3603 SmallVector<OpFoldResult> newLowPad =
3604 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3606 auto newPadOp = rewriter.create<tensor::PadOp>(
3607 padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3608 newLowPad, newHighPad, padOp.getNofold(),
3609 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3610 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3611 newPadOp.getRegion().begin());
3612 rewriter.replaceOp(padOp, newPadOp.getResult());
3613 return success();
3617 } // namespace
3619 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3620 MLIRContext *context) {
3621 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3622 FoldOrthogonalPaddings, FoldStaticPadding,
3623 FoldConsecutiveConstantPadding>(context);
3626 /// Return the padding value of the PadOp if it constant. In this context,
3627 /// "constant" means an actual constant or "defined outside of the block".
3629 /// Values are considered constant in three cases:
3630 /// - A ConstantLike value.
3631 /// - A basic block argument from a different block.
3632 /// - A value defined outside of the block.
3634 /// If the padding value is not constant, an empty Value is returned.
3635 Value PadOp::getConstantPaddingValue() {
3636 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3637 if (!yieldOp)
3638 return {};
3639 Value padValue = yieldOp.getValue();
3640 // Check if yield value is a constant.
3641 if (matchPattern(padValue, m_Constant()))
3642 return padValue;
3643 // Check if yield value is defined inside the PadOp block.
3644 if (padValue.getParentBlock() == &getRegion().front())
3645 return {};
3646 // Else: Yield value defined outside of the PadOp block.
3647 return padValue;
3650 OpFoldResult PadOp::fold(FoldAdaptor) {
3651 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3652 !getNofold())
3653 return getSource();
3654 return {};
3657 //===----------------------------------------------------------------------===//
3658 // ParallelInsertSliceOp
3659 //===----------------------------------------------------------------------===//
3661 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3662 ParallelCombiningOpInterface parallelCombiningParent =
3663 getParallelCombiningParent();
3664 for (const auto &it :
3665 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3666 Operation &nextOp = it.value();
3667 if (&nextOp == getOperation())
3668 return parallelCombiningParent.getParentResult(it.index());
3670 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3673 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3674 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3675 Value source, Value dest,
3676 ArrayRef<OpFoldResult> offsets,
3677 ArrayRef<OpFoldResult> sizes,
3678 ArrayRef<OpFoldResult> strides,
3679 ArrayRef<NamedAttribute> attrs) {
3680 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3681 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3682 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3683 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3684 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3685 result.addAttributes(attrs);
3686 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3687 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3688 b.getDenseI64ArrayAttr(staticSizes),
3689 b.getDenseI64ArrayAttr(staticStrides));
3692 /// Build an ParallelInsertSliceOp with mixed static and dynamic entries
3693 /// packed into a Range vector.
3694 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3695 Value source, Value dest,
3696 ArrayRef<Range> ranges,
3697 ArrayRef<NamedAttribute> attrs) {
3698 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3699 build(b, result, source, dest, offsets, sizes, strides, attrs);
3702 // Build a ParallelInsertSliceOp with dynamic entries.
3703 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3704 Value source, Value dest, ValueRange offsets,
3705 ValueRange sizes, ValueRange strides,
3706 ArrayRef<NamedAttribute> attrs) {
3707 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3708 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3709 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3710 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3711 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3712 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3713 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3716 LogicalResult ParallelInsertSliceOp::verify() {
3717 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3718 return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3719 << *(getOperation()->getParentOp());
3721 RankedTensorType expectedType;
3722 SliceVerificationResult result =
3723 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3724 getStaticSizes(), getStaticStrides(), &expectedType);
3725 return produceSliceErrorMsg(result, *this, expectedType);
3728 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3729 RewritePatternSet &results, MLIRContext *context) {
3730 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3731 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3732 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3735 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3736 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3739 //===----------------------------------------------------------------------===//
3740 // ScatterOp
3741 //===----------------------------------------------------------------------===//
3743 void ScatterOp::getAsmResultNames(
3744 function_ref<void(Value, StringRef)> setNameFn) {
3745 setNameFn(getResult(), "scatter");
3748 LogicalResult ScatterOp::verify() {
3749 int64_t destRank = getDestType().getRank();
3750 ArrayRef<int64_t> scatterDims = getScatterDims();
3751 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3752 getIndicesType().getShape(), destRank,
3753 "scatter", "dest")))
3754 return failure();
3756 if (!getUnique())
3757 return emitOpError("requires 'unique' attribute to be set");
3758 // TODO: we could also check statically that there are fewer leading index
3759 // tensor dims than the dest dims. If this is not the case, the unique
3760 // attribute cannot be true.
3762 // Use the GatherOp::inferResultType on the `dest` type and verify the
3763 // expected type matches the source type.
3764 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3765 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3766 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3767 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3768 if (getSourceType() != expectedSourceType &&
3769 getSourceType() != expectedRankReducedSourceType) {
3770 return emitOpError("source type "
3771 "mismatch: "
3772 "expected ")
3773 << expectedSourceType << " or its rank-reduced variant "
3774 << expectedRankReducedSourceType << " (got: " << getSourceType()
3775 << ")";
3778 return success();
3781 //===----------------------------------------------------------------------===//
3782 // SplatOp
3783 //===----------------------------------------------------------------------===//
3785 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3786 Type aggregateType, ValueRange dynamicSizes) {
3787 build(builder, result, aggregateType, element, dynamicSizes);
3790 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3791 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3792 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3793 build(builder, result, aggregateType, element, dynamicSizes);
3796 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3797 ArrayRef<OpFoldResult> sizes) {
3798 SmallVector<int64_t> staticShape;
3799 SmallVector<Value> dynamicSizes;
3800 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3801 build(builder, result, element, staticShape, dynamicSizes);
3804 void SplatOp::getAsmResultNames(
3805 function_ref<void(Value, StringRef)> setNameFn) {
3806 setNameFn(getResult(), "splat");
3809 LogicalResult SplatOp::verify() {
3810 if (getType().getNumDynamicDims() != getDynamicSizes().size())
3811 return emitOpError("incorrect number of dynamic sizes, has ")
3812 << getDynamicSizes().size() << ", expected "
3813 << getType().getNumDynamicDims();
3814 return success();
3817 LogicalResult
3818 SplatOp::reifyResultShapes(OpBuilder &builder,
3819 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3820 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3821 unsigned ctr = 0;
3822 for (int64_t i = 0; i < getType().getRank(); ++i) {
3823 if (getType().isDynamicDim(i)) {
3824 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3825 } else {
3826 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3829 return success();
3832 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3833 auto constOperand = adaptor.getInput();
3834 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3835 return {};
3837 // Do not fold if the splat is not statically shaped
3838 if (!getType().hasStaticShape())
3839 return {};
3841 // SplatElementsAttr::get treats single value for second arg as being a
3842 // splat.
3843 return SplatElementsAttr::get(getType(), {constOperand});
3846 //===----------------------------------------------------------------------===//
3847 // PackOp/UnPackOp Common
3848 //===----------------------------------------------------------------------===//
3850 template <typename OpTy>
3851 static LogicalResult
3852 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3853 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3854 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3855 "applies to only pack or unpack operations");
3856 int64_t destRank = op.getDestRank();
3857 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3858 reifiedReturnShapes[0] =
3859 tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3860 return success();
3863 template <typename OpTy>
3864 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3865 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3866 "applies to only pack or unpack operations");
3867 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3868 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3869 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3870 assert(tiles.size() == dimsToTile.size() &&
3871 "tiles must match indices of dimension to block");
3872 // bind the dimension `i` with the tile factor.
3873 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3874 dimAndTileMapping[dimsToTile[i]] = tiles[i];
3875 return dimAndTileMapping;
3878 template <typename OpTy>
3879 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3880 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3881 "applies to only pack or unpack operations");
3882 Builder builder(op);
3883 SmallVector<OpFoldResult> mixedInnerTiles;
3884 unsigned dynamicValIndex = 0;
3885 for (int64_t staticTile : op.getStaticInnerTiles()) {
3886 if (!ShapedType::isDynamic(staticTile))
3887 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3888 else
3889 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3891 return mixedInnerTiles;
3894 template <typename OpTy>
3895 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3896 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3897 "applies to only pack or unpack operations");
3898 SmallVector<Value> dynamicTiles;
3899 SmallVector<int64_t> staticTiles;
3900 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3901 return staticTiles;
3904 /// Returns true if `dimsPos` is invalid. It is invalid when:
3905 /// a) It contains duplicate.
3906 /// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
3907 /// c) The number of elements in `dimsPos` is > than `rank`.
3908 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3909 size_t rank) {
3910 size_t dimsPosSize = dimsPos.size();
3911 if (dimsPosSize > rank)
3912 return true;
3913 DenseSet<int64_t> uniqued;
3914 for (int64_t dim : dimsPos)
3915 uniqued.insert(dim);
3916 if (dimsPosSize != uniqued.size())
3917 return true;
3918 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3919 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3923 /// Returns true if the dimension of `sourceShape` is smaller than the dimension
3924 /// of the `limitShape`.
3925 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3926 ArrayRef<int64_t> limitShape) {
3927 assert(
3928 sourceShape.size() == limitShape.size() &&
3929 "expected source shape rank, and limit of the shape to have same rank");
3930 return llvm::all_of(
3931 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3932 int64_t sourceExtent = std::get<0>(it);
3933 int64_t limit = std::get<1>(it);
3934 return ShapedType::isDynamic(sourceExtent) ||
3935 ShapedType::isDynamic(limit) || sourceExtent <= limit;
3939 template <typename OpTy>
3940 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3941 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3942 "applies to only pack or unpack operations");
3943 Operation *op = packOrUnPack.getOperation();
3945 // Return true if we have a zero-value tile.
3946 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3947 return llvm::any_of(
3948 tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3951 // Verify tiles. Do not allow zero tiles.
3952 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3953 if (hasZeros(mixedTiles))
3954 return op->emitError("invalid zero tile factor");
3956 // Verify inner_dims_pos and outer_dims_perm.
3957 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3958 ? packOrUnPack.getSourceType()
3959 : packOrUnPack.getDestType();
3960 size_t unpackedRank = unpackedType.getRank();
3961 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3962 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3963 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3964 return op->emitError("invalid inner_dims_pos vector");
3965 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3966 return op->emitError("invalid outer_dims_perm vector");
3967 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3968 return op->emitError("outer_dims_perm must be a permutation or empty");
3970 // Tiling factors must be less than or equal to the input rank for pack (or
3971 // output rank for unpack), and must match the number of `inner_dims_pos`.
3972 if (mixedTiles.size() > unpackedRank) {
3973 return op->emitError("tiling factors must be less than or equal to the "
3974 "input rank for pack or output rank for unpack");
3976 if (mixedTiles.size() != innerDimsPos.size()) {
3977 return op->emitError(
3978 "tiling factors must equal the number of dimensions to tile");
3981 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3982 ? packOrUnPack.getDestType()
3983 : packOrUnPack.getSourceType();
3984 size_t packedRank = packedType.getRank();
3985 // Require output rank to match input rank + number of blocking factors.
3986 if (unpackedRank + mixedTiles.size() != packedRank) {
3987 return op->emitError(
3988 "packed rank must equal unpacked rank + tiling factors");
3991 // Verify result shape is greater than the minimum expected
3992 // by the pack operation, and that the output shape
3993 // represents full tiles.
3994 RankedTensorType expectedPackedType = PackOp::inferPackedType(
3995 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3996 if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3997 return op->emitError("the shape of output is not large enough to hold the "
3998 "packed data. Expected at least ")
3999 << expectedPackedType << ", got " << packedType;
4001 if (!llvm::all_of(
4002 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4003 mixedTiles),
4004 [](std::tuple<int64_t, OpFoldResult> it) {
4005 int64_t shape = std::get<0>(it);
4006 if (Attribute attr =
4007 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4008 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4009 int64_t staticTileSize = intAttr.getValue().getSExtValue();
4010 return shape == staticTileSize;
4012 return ShapedType::isDynamic(shape);
4013 })) {
4014 return op->emitError("mismatch in inner tile sizes specified and shaped of "
4015 "tiled dimension in the packed type");
4017 return success();
4020 namespace {
4021 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4022 /// various permutations to the op.
4023 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4024 // these. These may or may not become true foldings / canonicalizations
4025 // depending on how aggressive we want to be in automatically folding
4026 // transposes.
4027 struct PackOrUnPackTransposeResult {
4028 SmallVector<int64_t> innerDimsPos;
4029 SmallVector<OpFoldResult> innerTiles;
4030 SmallVector<int64_t> outerDimsPerm;
4032 } // namespace
4034 template <typename OpTy>
4035 static PackOrUnPackTransposeResult
4036 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4037 ArrayRef<int64_t> innerPermutation,
4038 ArrayRef<int64_t> outerPermutation) {
4039 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4040 "applies to only pack or unpack operations");
4041 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4042 "some permutation must be non-empty");
4043 PackOrUnPackTransposeResult metadata;
4044 metadata.innerDimsPos =
4045 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4046 metadata.innerTiles =
4047 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4048 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4049 ? packOrUnPackOp.getSourceRank()
4050 : packOrUnPackOp.getDestRank();
4051 metadata.outerDimsPerm =
4052 packOrUnPackOp.getOuterDimsPerm().empty()
4053 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4054 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4055 if (!innerPermutation.empty()) {
4056 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4057 isPermutationVector(innerPermutation) &&
4058 "invalid inner permutation");
4059 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4060 applyPermutationToVector(metadata.innerTiles, innerPermutation);
4062 if (!outerPermutation.empty()) {
4063 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4064 isPermutationVector(outerPermutation) &&
4065 "invalid outer permutation");
4066 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4068 return metadata;
4071 //===----------------------------------------------------------------------===//
4072 // PackOp
4073 //===----------------------------------------------------------------------===//
4075 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4076 setNameFn(getResult(), "pack");
4079 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4080 Value dest, ArrayRef<int64_t> innerDimsPos,
4081 ArrayRef<OpFoldResult> innerTiles,
4082 std::optional<Value> paddingValue,
4083 ArrayRef<int64_t> outerDimsPerm) {
4084 assert(innerDimsPos.size() == innerTiles.size() &&
4085 "number of tile sizes specified must match the specified number of "
4086 "original dimensions to be tiled");
4087 SmallVector<int64_t> staticTileSizes;
4088 SmallVector<Value> dynamicTileSizes;
4089 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4090 build(builder, state, dest.getType(), source, dest,
4091 paddingValue ? *paddingValue : nullptr,
4092 outerDimsPerm.empty() ? nullptr
4093 : builder.getDenseI64ArrayAttr(outerDimsPerm),
4094 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4095 builder.getDenseI64ArrayAttr(staticTileSizes));
4098 LogicalResult
4099 PackOp::reifyResultShapes(OpBuilder &builder,
4100 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4101 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4104 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4105 return getDimAndTileMappingImpl(*this);
4108 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4109 return getMixedTilesImpl(*this);
4112 SmallVector<int64_t> PackOp::getStaticTiles() {
4113 return getStaticTilesImpl(*this);
4116 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4117 ShapedType inputType = getSourceType();
4118 int64_t inputRank = inputType.getRank();
4119 return getDestType().getShape().take_front(inputRank);
4122 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4123 auto innerDimsPos = getInnerDimsPos();
4124 auto packedShape = getDestType().getShape();
4125 SmallVector<int64_t> res;
4127 for (auto index : innerDimsPos)
4128 res.push_back(packedShape[index]);
4130 return res;
4133 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4134 ArrayRef<int64_t> innerDimsPos,
4135 ArrayRef<int64_t> outputShape,
4136 ArrayRef<int64_t> outerDimsPerm,
4137 ArrayRef<OpFoldResult> innerTiles) {
4138 SmallVector<int64_t> outputTileSizes(
4139 outputShape.take_front(inputShape.size()));
4140 if (!outerDimsPerm.empty()) {
4141 assert(outerDimsPerm.size() == outputTileSizes.size() &&
4142 "expected output and outer_dims_perm to have same size");
4143 applyPermutationToVector(outputTileSizes,
4144 invertPermutationVector(outerDimsPerm));
4146 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4147 if (ShapedType::isDynamic(inputShape[pos]))
4148 continue;
4149 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4151 if (!constantTile) {
4152 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4153 (inputShape[pos] % outputTileSizes[pos] != 0))
4154 return true;
4155 } else if (inputShape[pos] % (*constantTile) != 0) {
4156 return true;
4159 return false;
4162 LogicalResult PackOp::verify() {
4163 if (failed(commonVerifierPackAndUnPackOp(*this)))
4164 return failure();
4166 // Verify padding value, and bail out if the tile does not divide the
4167 // dimension fully. In the case of dynamic tile factors or dimensions, having
4168 // a partial tile is undefined behavior.
4169 auto paddingValue = getPaddingValue();
4170 if (paddingValue &&
4171 paddingValue.getType() != getSourceType().getElementType()) {
4172 return emitOpError("expected padding_value has ")
4173 << getSourceType().getElementType()
4174 << " but got: " << paddingValue.getType();
4177 if (!paddingValue &&
4178 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4179 getDestType().getShape(), getOuterDimsPerm(),
4180 getMixedTiles())) {
4181 return emitOpError(
4182 "invalid tile factor or output size provided. Only full tiles are "
4183 "supported when padding_value is not set");
4185 return success();
4188 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
4189 /// Value's to kDynamic, even if they are arith.constant values.
4190 static SmallVector<int64_t>
4191 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4192 SmallVector<int64_t> result;
4193 for (auto o : ofrs) {
4194 // Have to do this first, as getConstantIntValue special-cases constants.
4195 if (llvm::dyn_cast_if_present<Value>(o))
4196 result.push_back(ShapedType::kDynamic);
4197 else
4198 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4200 return result;
4203 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4204 /// the packed type. Having a shared helper helps implement these two methods in
4205 /// a way that ensures that they agree on which dimensions are dynamic.
4206 static SmallVector<int64_t> getPackOpResultTypeShape(
4207 ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4208 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4209 SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4210 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4211 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4212 continue;
4213 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4214 resultShape[tiledDim.value()] = ShapedType::kDynamic;
4215 continue;
4217 resultShape[tiledDim.value()] = divideCeilSigned(
4218 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4221 // Swap tile loops if outer_dims_perm is available.
4222 if (!outerDimsPerm.empty())
4223 applyPermutationToVector(resultShape, outerDimsPerm);
4225 // Append the inner tile dimensions.
4226 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4227 return resultShape;
4230 SmallVector<OpFoldResult> PackOp::getResultShape(
4231 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4232 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4233 ArrayRef<int64_t> outerDimsPerm) {
4234 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4236 AffineExpr s0, s1;
4237 bindSymbols(builder.getContext(), s0, s1);
4238 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4239 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4240 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4241 builder, loc, ceilDivExpr,
4242 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4244 if (!outerDimsPerm.empty())
4245 applyPermutationToVector(resultDims, outerDimsPerm);
4246 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4248 SmallVector<int64_t> resultTypeShape =
4249 getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4250 asShapeWithAnyValueAsDynamic(innerTileSizes),
4251 innerDimsPos, outerDimsPerm);
4253 // Fix-up `resultDims` to ensure that they are Value's if and only if the
4254 // result type shape says it's a dynamic dim. This is needed as callers may
4255 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4256 // dynamic dims returned by that.
4257 for (unsigned i = 0; i < resultDims.size(); ++i) {
4258 if (!ShapedType::isDynamic(resultTypeShape[i]))
4259 continue;
4260 resultDims[i] =
4261 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4264 return resultDims;
4267 /// Get the expected packed type based on source type, tile factors, position of
4268 /// the inner tiles and permutation of the outer tiled loop.
4269 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4270 ArrayRef<int64_t> innerTileSizes,
4271 ArrayRef<int64_t> innerDimsPos,
4272 ArrayRef<int64_t> outerDimsPerm) {
4273 SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4274 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4275 return RankedTensorType::get(resultShape, sourceType.getElementType());
4278 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4279 ArrayRef<OpFoldResult> innerTileSizes,
4280 ArrayRef<int64_t> innerDimsPos,
4281 ArrayRef<int64_t> outerDimsPerm) {
4282 AffineExpr dim0, dim1;
4283 bindDims(b.getContext(), dim0, dim1);
4284 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4285 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4286 {v1, v2});
4289 SmallVector<OpFoldResult> mixedSizes;
4290 for (auto [index, value] : llvm::enumerate(
4291 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4292 if (ShapedType::isDynamic(value))
4293 mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
4294 else
4295 mixedSizes.push_back(b.getIndexAttr(value));
4297 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4298 int64_t dimPos = std::get<0>(it);
4299 OpFoldResult tileSize = std::get<1>(it);
4300 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4302 if (!outerDimsPerm.empty())
4303 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4305 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4306 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4307 return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4310 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4311 ArrayRef<int64_t> innerPermutation,
4312 ArrayRef<int64_t> outerPermutation) {
4313 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4314 *this, innerPermutation, outerPermutation);
4315 Value transposedDest =
4316 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4317 metadata.innerDimsPos, metadata.outerDimsPerm);
4318 return b.create<PackOp>(loc, getSource(), transposedDest,
4319 metadata.innerDimsPos, metadata.innerTiles,
4320 getPaddingValue(), metadata.outerDimsPerm);
4323 /// Returns true if the tiles and the tiled dims are constant.
4324 template <typename OpTy>
4325 bool areTilesAndTiledDimsAllConstant(OpTy op) {
4326 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4327 "applies to only pack or unpack operations");
4328 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4329 ? op.getDestType()
4330 : op.getSourceType();
4331 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4332 for (auto [dimDest, tile] : llvm::zip(
4333 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4334 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4335 if (!constTileSize || ShapedType::isDynamic(dimDest))
4336 return false;
4338 return true;
4341 Speculation::Speculatability PackOp::getSpeculatability() {
4342 if (getPaddingValue())
4343 return Speculation::Speculatable;
4345 // The verifier rejects already operations if we can statically prove that the
4346 // sizes of the tiles do not divide perfectly the dimension; thus, check only
4347 // to have constant tiles and tiled inner dimensions.
4348 if (!areTilesAndTiledDimsAllConstant(*this))
4349 return Speculation::NotSpeculatable;
4351 return Speculation::Speculatable;
4354 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4355 // dimensions for pack and unpack.
4356 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4357 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4358 return false;
4359 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4360 return true;
4361 // Outer dims permutation is optional.
4362 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
4363 // identity permutation.
4364 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4365 isIdentityPermutation(unPackOp.getOuterDimsPerm());
4368 // Return true if pack and unpack have the same tiles.
4369 // Same SSA values or same integer constants.
4370 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4371 auto packTiles = packOp.getMixedTiles();
4372 auto unPackTiles = unPackOp.getMixedTiles();
4373 if (packTiles.size() != unPackTiles.size())
4374 return false;
4375 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4376 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4377 return false;
4379 return true;
4382 /// Returns true if the pack op does not need a padding value.
4383 static bool paddingIsNotNeeded(PackOp op) {
4384 auto srcType = op.getSourceType();
4385 if (llvm::any_of(op.getInnerDimsPos(),
4386 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4387 return false;
4388 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4389 return false;
4390 return !PackOp::requirePaddingValue(
4391 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4392 op.getOuterDimsPerm(), op.getMixedTiles());
4395 /// Returns true if the `srcShape` or `destShape` is different from the one in
4396 /// `packOp` and populates each with the inferred static shape.
4397 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4398 SmallVectorImpl<int64_t> &destShape) {
4399 bool changeNeeded = false;
4400 srcShape.assign(packOp.getSourceType().getShape().begin(),
4401 packOp.getSourceType().getShape().end());
4402 destShape.assign(packOp.getDestType().getShape().begin(),
4403 packOp.getDestType().getShape().end());
4404 llvm::SmallSetVector<int64_t, 4> innerDims;
4405 innerDims.insert(packOp.getInnerDimsPos().begin(),
4406 packOp.getInnerDimsPos().end());
4407 SmallVector<int64_t> inverseOuterDimsPerm;
4408 if (!packOp.getOuterDimsPerm().empty())
4409 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4410 int srcRank = packOp.getSourceRank();
4411 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4412 if (innerDims.contains(i))
4413 continue;
4414 int64_t srcPos = i;
4415 int64_t destPos = i;
4416 if (!inverseOuterDimsPerm.empty())
4417 destPos = inverseOuterDimsPerm[srcPos];
4418 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4419 ShapedType::isDynamic(destShape[destPos])) {
4420 continue;
4422 int64_t size = srcShape[srcPos];
4423 if (ShapedType::isDynamic(size))
4424 size = destShape[destPos];
4425 srcShape[srcPos] = size;
4426 destShape[destPos] = size;
4427 changeNeeded = true;
4429 return changeNeeded;
4432 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4433 // Fold an pack(unpack(x)) to x.
4434 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4435 if (unPackOp.getSourceType() != packOp.getDestType())
4436 return failure();
4437 if (packOp.getPaddingValue() ||
4438 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4439 !haveSameTiles(packOp, unPackOp))
4440 return failure();
4441 rewriter.replaceOp(packOp, unPackOp.getSource());
4442 return success();
4445 // Fold optional PaddingValue operand away if padding is not needed.
4446 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4447 rewriter.startOpModification(packOp);
4448 packOp.getPaddingValueMutable().clear();
4449 rewriter.finalizeOpModification(packOp);
4450 return success();
4453 // Insert tensor.cast ops if static shape inference is available..
4454 SmallVector<int64_t> srcShape, destShape;
4455 if (inferStaticShape(packOp, srcShape, destShape)) {
4456 Location loc = packOp.getLoc();
4457 Value source = packOp.getSource();
4458 if (srcShape != packOp.getSourceType().getShape()) {
4459 auto newSrcType = packOp.getSourceType().clone(srcShape);
4460 source =
4461 rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4463 Value dest = packOp.getDest();
4464 RankedTensorType originalResultType = packOp.getDestType();
4465 bool needUpdateDestType = (destShape != originalResultType.getShape());
4466 if (needUpdateDestType) {
4467 auto newDestType = packOp.getDestType().clone(destShape);
4468 dest =
4469 rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4471 rewriter.modifyOpInPlace(packOp, [&] {
4472 packOp.getSourceMutable().assign(source);
4473 packOp.getDestMutable().assign(dest);
4474 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4476 // Insert a cast if needed
4477 if (needUpdateDestType) {
4478 rewriter.setInsertionPointAfter(packOp);
4479 auto castOp =
4480 rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4481 rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4483 return success();
4486 return failure();
4489 template <typename PackOrUnpackOp>
4490 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4491 RankedTensorType packedTensorType) {
4492 static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4493 std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4494 "Function meant for pack/unpack");
4495 // This is a pad if packing only adds ones and we don't transpose dimensions.
4497 // Check that we are not transposing any dimensions.
4498 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4499 int64_t numPackedDims = innerDimsPos.size();
4500 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4501 if (orderedDims != innerDimsPos) {
4502 // Dimensions don't happen in order.
4503 return false;
4506 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4507 int64_t packedRank = packedTensorType.getRank();
4508 // At this point we know that we are taking numPackedDims outer
4509 // dimensions and pushing them all the way as the inner most dimensions.
4510 // What's left on the outer most dimensions is, in this order:
4511 // - the factor of the packed dimensions, then
4512 // - the untouched dimensions
4513 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4514 // if all the dimensions that bubble outerward are ones.
4515 // Therefore check that all the dimensions but the numPackedDims inner most
4516 // ones are ones.
4517 return llvm::all_of(
4518 llvm::seq<int64_t>(0, packedRank - numPackedDims),
4519 [&packedShape](int64_t i) { return packedShape[i] == 1; });
4522 bool PackOp::isLikePad() {
4523 auto packedTensorType =
4524 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4525 return isLikePadUnPad(*this, packedTensorType);
4528 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4529 std::optional<Attribute> paddingValue;
4530 if (auto pad = adaptor.getPaddingValue())
4531 paddingValue = pad;
4532 if (OpFoldResult reshapedSource = reshapeConstantSource(
4533 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4534 getDestType(), paddingValue))
4535 return reshapedSource;
4536 return {};
4539 //===----------------------------------------------------------------------===//
4540 // UnPackOp
4541 //===----------------------------------------------------------------------===//
4543 void UnPackOp::getAsmResultNames(
4544 function_ref<void(Value, StringRef)> setNameFn) {
4545 setNameFn(getResult(), "unpack");
4548 LogicalResult
4549 UnPackOp::reifyResultShapes(OpBuilder &builder,
4550 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4551 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4554 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4555 return getDimAndTileMappingImpl(*this);
4558 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4559 return getMixedTilesImpl(*this);
4562 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4563 return getStaticTilesImpl(*this);
4566 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
4567 ShapedType destType = getDestType();
4568 int64_t destRank = destType.getRank();
4569 return getSourceType().getShape().take_front(destRank);
4572 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
4573 auto innerDimsPos = getInnerDimsPos();
4574 auto packedShape = getSourceType().getShape();
4575 SmallVector<int64_t> res;
4577 for (auto index : innerDimsPos)
4578 res.push_back(packedShape[index]);
4580 return res;
4583 LogicalResult UnPackOp::verify() {
4584 return commonVerifierPackAndUnPackOp(*this);
4587 Speculation::Speculatability UnPackOp::getSpeculatability() {
4588 // See PackOp::getSpeculatability.
4589 if (!areTilesAndTiledDimsAllConstant(*this))
4590 return Speculation::NotSpeculatable;
4592 return Speculation::Speculatable;
4595 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4596 Value dest, ArrayRef<int64_t> innerDimsPos,
4597 ArrayRef<OpFoldResult> innerTiles,
4598 ArrayRef<int64_t> outerDimsPerm) {
4599 assert(innerDimsPos.size() == innerTiles.size() &&
4600 "number of tile sizes specified must match the specified number of "
4601 "original dimensions to be tiled");
4602 SmallVector<int64_t> staticTileSizes;
4603 SmallVector<Value> dynamicTileSizes;
4604 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4605 build(builder, state, dest.getType(), source, dest,
4606 outerDimsPerm.empty() ? nullptr
4607 : builder.getDenseI64ArrayAttr(outerDimsPerm),
4608 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4609 builder.getDenseI64ArrayAttr(staticTileSizes));
4612 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4613 Value source,
4614 ArrayRef<OpFoldResult> innerTileSizes,
4615 ArrayRef<int64_t> innerDimsPos,
4616 ArrayRef<int64_t> outerDimsPerm) {
4617 AffineExpr sym0, sym1;
4618 bindSymbols(b.getContext(), sym0, sym1);
4619 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4620 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4623 SmallVector<OpFoldResult> mixedSizes;
4624 auto srcType = llvm::cast<RankedTensorType>(source.getType());
4625 for (auto i :
4626 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4627 if (srcType.isDynamicDim(i))
4628 mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4629 else
4630 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4632 if (!outerDimsPerm.empty()) {
4633 applyPermutationToVector<OpFoldResult>(
4634 mixedSizes, invertPermutationVector(outerDimsPerm));
4637 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4638 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4640 auto elemType = srcType.getElementType();
4641 return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4644 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4645 Value transposedSource,
4646 ArrayRef<int64_t> innerPermutation,
4647 ArrayRef<int64_t> outerPermutation) {
4648 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4649 *this, innerPermutation, outerPermutation);
4650 return b.create<UnPackOp>(loc, transposedSource, getDest(),
4651 metadata.innerDimsPos, metadata.innerTiles,
4652 metadata.outerDimsPerm);
4655 /// Returns true if the `srcShape` or `destShape` is different from the one in
4656 /// `op` and populates each with the inferred static shape.
4657 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4658 SmallVectorImpl<int64_t> &destShape) {
4659 bool changeNeeded = false;
4660 srcShape.assign(op.getSourceType().getShape().begin(),
4661 op.getSourceType().getShape().end());
4662 destShape.assign(op.getDestType().getShape().begin(),
4663 op.getDestType().getShape().end());
4664 llvm::SmallSetVector<int64_t, 4> innerDims;
4665 innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4666 SmallVector<int64_t> inverseOuterDimsPerm;
4667 if (!op.getOuterDimsPerm().empty())
4668 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4669 int destRank = op.getDestRank();
4670 for (auto i : llvm::seq<int64_t>(0, destRank)) {
4671 if (innerDims.contains(i))
4672 continue;
4673 int64_t srcPos = i;
4674 int64_t destPos = i;
4675 if (!inverseOuterDimsPerm.empty())
4676 srcPos = inverseOuterDimsPerm[destPos];
4677 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4678 ShapedType::isDynamic(destShape[destPos])) {
4679 continue;
4681 int64_t size = srcShape[srcPos];
4682 if (ShapedType::isDynamic(size))
4683 size = destShape[destPos];
4684 srcShape[srcPos] = size;
4685 destShape[destPos] = size;
4686 changeNeeded = true;
4688 return changeNeeded;
4691 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4692 PatternRewriter &rewriter) {
4693 /// unpack(pack(x)) -> x
4694 if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4695 if (packOp.getSourceType() != unPackOp.getDestType())
4696 return failure();
4697 if (packOp.getPaddingValue() ||
4698 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4699 !haveSameTiles(packOp, unPackOp))
4700 return failure();
4701 rewriter.replaceOp(unPackOp, packOp.getSource());
4702 return success();
4704 /// unpack(destinationStyleOp(x)) -> unpack(x)
4705 if (auto dstStyleOp =
4706 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4707 auto destValue = cast<OpResult>(unPackOp.getDest());
4708 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4709 rewriter.modifyOpInPlace(unPackOp,
4710 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4711 return success();
4714 // Insert tensor.cast ops if static shape inference is available..
4715 SmallVector<int64_t> srcShape, destShape;
4716 if (inferStaticShape(unPackOp, srcShape, destShape)) {
4717 Location loc = unPackOp.getLoc();
4718 Value source = unPackOp.getSource();
4719 if (srcShape != unPackOp.getSourceType().getShape()) {
4720 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4721 source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4722 unPackOp.getSource());
4724 Value dest = unPackOp.getDest();
4725 if (destShape != unPackOp.getDestType().getShape()) {
4726 auto newDestType = unPackOp.getDestType().clone(destShape);
4727 dest =
4728 rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4730 Value newOp = rewriter.create<tensor::UnPackOp>(
4731 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4732 unPackOp.getOuterDimsPerm());
4733 rewriter.replaceOpWithNewOp<tensor::CastOp>(
4734 unPackOp, unPackOp.getResult().getType(), newOp);
4735 return success();
4738 return failure();
4741 bool UnPackOp::isLikeUnPad() {
4742 RankedTensorType packedTensorType = getSourceType();
4743 return isLikePadUnPad(*this, packedTensorType);
4746 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4747 if (OpFoldResult reshapedSource = reshapeConstantSource(
4748 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4749 getResult().getType()))
4750 return reshapedSource;
4751 return {};
4754 //===----------------------------------------------------------------------===//
4755 // Common Canonicalizers and Folders.
4756 //===----------------------------------------------------------------------===//
4757 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4758 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4759 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4760 // might need special handling of attached regions.
4761 if (isa<InsertSliceOp>(op.getOperation()) ||
4762 isa<LoopLikeOpInterface>(op.getOperation()))
4763 return false;
4765 // If no operand comes from a tensor::CastOp and can be folded then fail.
4766 bool hasTensorCastOperand =
4767 llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4768 if (llvm::isa<BlockArgument>(opOperand.get()))
4769 return false;
4770 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4771 return castOp && canFoldIntoConsumerOp(castOp);
4774 return hasTensorCastOperand;
4777 static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
4778 SmallVector<Type> &newResTy) {
4779 SmallVector<Value> newOperands;
4780 newOperands.reserve(op->getNumOperands());
4782 // Assumes that the result has dpsInits followed by nonDpsInits.
4783 int64_t dpsInitIdx = 0;
4784 for (OpOperand &opOperand : op->getOpOperands()) {
4785 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4786 bool fold = canFoldIntoConsumerOp(tensorCastOp);
4787 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4788 if (op.isDpsInit(&opOperand) &&
4789 !llvm::isa<MemRefType>(newOperands.back().getType()))
4790 newResTy[dpsInitIdx++] = newOperands.back().getType();
4792 return newOperands;
4795 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
4796 /// `tensor.cast` has source that is more static than the consuming op.
4798 /// Example:
4799 /// ```mlir
4800 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4801 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
4802 /// ```
4804 /// folds into:
4806 /// ```mlir
4807 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
4808 /// ```
4809 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
4810 using OpRewritePattern<PackOp>::OpRewritePattern;
4812 LogicalResult matchAndRewrite(PackOp op,
4813 PatternRewriter &rewriter) const override {
4814 if (!foldTensorCastPrecondition(op))
4815 return failure();
4817 SmallVector<Type> newResultTypes(op->getResultTypes());
4818 SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4820 // Get the updated mixed-tile-sizes attribute.
4821 SmallVector<OpFoldResult> newMixedTileSizes;
4822 for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
4823 .getShape()
4824 .take_back(op.getMixedTiles().size()),
4825 op.getMixedTiles())) {
4826 int64_t shape = std::get<0>(it);
4827 if (shape == ShapedType::kDynamic) {
4828 newMixedTileSizes.push_back(std::get<1>(it));
4829 continue;
4832 if (Attribute attr =
4833 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4834 // Already a constant
4835 newMixedTileSizes.push_back(std::get<1>(it));
4836 } else {
4837 int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
4838 assert(tileSize == shape && "tile size and dim size don't match!");
4839 (void)tileSize;
4840 newMixedTileSizes.push_back(
4841 (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4845 // Clone op.
4846 PackOp newOp = rewriter.create<PackOp>(
4847 op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
4848 newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
4849 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
4851 // Replace op.
4852 Value oldResult = op.getResult();
4853 Value newResult = newOp.getResult();
4854 Value replacement = (newResult.getType() != oldResult.getType())
4855 ? rewriter.create<tensor::CastOp>(
4856 op->getLoc(), oldResult.getType(), newResult)
4857 : newResult;
4859 rewriter.replaceOp(op, {replacement});
4861 return success();
4865 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4866 /// the `tensor.cast` has source that is more static than the consuming op.
4868 /// Example:
4869 /// ```mlir
4870 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4871 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4872 /// ```
4874 /// folds into:
4876 /// ```mlir
4877 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4878 /// ```
4879 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4880 /// can add the pattern to their canonicalizers.
4881 struct FoldTensorCastProducerOp
4882 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4883 using OpInterfaceRewritePattern<
4884 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4886 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4887 PatternRewriter &rewriter) const override {
4889 // Reject tensor::PackOp - there's dedicated pattern for that instead.
4890 if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
4891 return failure();
4893 SmallVector<Type> newResultTypes(op->getResultTypes());
4894 SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4896 // Clone op
4897 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4899 SmallVector<Value, 4> replacements;
4900 replacements.reserve(newOp->getNumResults());
4901 for (auto [oldResult, newResult] :
4902 llvm::zip(op->getResults(), newOp->getResults())) {
4903 if (newResult.getType() != oldResult.getType()) {
4904 replacements.push_back(rewriter.create<tensor::CastOp>(
4905 op->getLoc(), oldResult.getType(), newResult));
4906 } else {
4907 replacements.push_back(newResult);
4910 rewriter.replaceOp(op, replacements);
4912 return success();
4916 //===----------------------------------------------------------------------===//
4917 // TensorDialect
4918 //===----------------------------------------------------------------------===//
4920 void TensorDialect::getCanonicalizationPatterns(
4921 RewritePatternSet &results) const {
4922 results.add<FoldTensorCastPackOp>(getContext());
4923 results.add<FoldTensorCastProducerOp>(getContext());
4926 //===----------------------------------------------------------------------===//
4927 // TableGen'd op method definitions
4928 //===----------------------------------------------------------------------===//
4930 #define GET_OP_CLASSES
4931 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"