//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>

using namespace mlir;
using namespace mlir::tensor;

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               llvm::cast<ArrayAttr>(value));
  return nullptr;
}
OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  return builder.getIndexAttr(tensorType.getDimSize(dim));
}
SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}
FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
                                                OpResult opResult) {
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, it implements DestinationStyleOpInterface and
  // we can query the destination operand from that interface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  // Otherwise, create a new destination tensor with the same shape.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(opResult.getDefiningOp());

  // Compute sizes.
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: Take static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create empty tensor.
  Value emptyTensor =
      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
}
LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
                                              Operation *op,
                                              SmallVector<Value> &result) {
  for (OpResult opResult : op->getResults()) {
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  }
  return success();
}
bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2; // default implementation
}
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        size.value().is<Attribute>() &&
        llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: Dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}
/// Given a ranked tensor type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }

  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
}
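// Illustration (sketch; %c4 and %n are hypothetical values): for a type
// tensor<?x?xf32> with dynamic sizes (%c4, %n), where %c4 is a constant 4 and
// %n is unknown, the helper returns tensor<4x?xf32> and only %n remains in
// foldedDynamicSizes.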
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}
namespace {

/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
/// operation.
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};

} // namespace

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ChainedTensorBitcast>(context);
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}
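// Illustration (sketch):
//   preservesStaticInformation(tensor<8x?xf32>, tensor<8x16xf32>) == true
//     (the target only adds static information)
//   preservesStaticInformation(tensor<8x16xf32>, tensor<?x16xf32>) == false
//     (dim 0 goes from static to dynamic)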
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the tensor type has more static information than the result.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of cast has at least as much static information as
  // its results.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}
/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
/// canonicalization patterns with the `tensor.cast` op as the root, but the
/// producer being from a different dialect. Returns true when all conditions
/// are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
/// Example:
/// ```mlir
///   %1 = producer ... : tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
/// ```
///
/// can be canonicalized to:
///
/// ```mlir
///   %2 = producer ... : tensor<8x16xf32>
/// ```
///
/// Not all ops might be canonicalizable this way, but for those that can be,
/// this method provides a check that it is worth doing the canonicalization.
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
  if (!castOp)
    return false;
  return preservesStaticInformation(castOp.getSource().getType(),
                                    castOp.getType());
}
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}
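// Illustration (sketch):
//   joinShapes(tensor<?x4xf32>, tensor<8x?xf32>)  -> tensor<8x4xf32>
//   joinShapes(tensor<8x4xf32>, tensor<7x4xf32>)  -> null TensorType
//     (conflicting static sizes, i.e. no valid join)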
namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just contain
    // less information. If so, we cannot drop the intermediate cast, as doing
    // so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///    tensor<128x512xf32> to tensor<?x512xf32>
///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
///  %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///    tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};

} // namespace
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}
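// Illustration (sketch): concatenating tensor<3x?xf32> and tensor<?x4xf32>
// along dim 0 infers tensor<?x4xf32>: the concatenated extent 3 + ? saturates
// to dynamic, while the non-concatenated extent is the join of ? and 4.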
void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}
LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
}
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
  }
  return SmallVector<Value>{replacement};
}
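// Illustration (sketch of the decomposition above, static shapes assumed):
//   %r = tensor.concat dim(1) %a, %b
//          : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
// becomes roughly
//   %e  = tensor.empty() : tensor<2x7xf32>
//   %i0 = tensor.insert_slice %a into %e[0, 0] [2, 3] [1, 1]
//   %i1 = tensor.insert_slice %b into %i0[0, 3] [2, 4] [1, 1]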
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // Pre-populate the result sizes with as much static information as possible
  // from the given result type, as well as the inferred result type, otherwise
  // use the dim sizes from the first input.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Take the sum of the input sizes along the concatenated dim.
    AffineExpr sum = builder.getAffineDimExpr(0);
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // If the result shape is static along the concatenated dim, use the static
    // shape.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

OpFoldResult ConcatOp::fold(FoldAdaptor) {
  ValueRange inputs = getInputs();
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
  return {};
}
namespace {

/// Fold a concat op with a single input to a cast.
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};

} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SingleInputConcatOp>(context);
}
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}
Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
    // `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}
namespace {

/// Fold dim of a cast into the dim of the source of the tensor cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};
/// Fold dim of a destination passing style op into the dim of the corresponding
/// init.
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};
/// Fold dim of a tensor reshape operation to an extract into the reshape's
/// shape operand.
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    if (!reshape)
      return failure();

    // Since tensors are immutable we don't need to worry about where to place
    // the extract call.
    rewriter.setInsertionPointAfter(dim);
    Location loc = dim.getLoc();
    Value extract =
        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
};

} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(all_of(staticShape,
                [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}
LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}
LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}
EmptyOp::getDynamicSize(unsigned idx
) {
968 assert(getType().isDynamicDim(idx
) && "expected dynamic dim");
970 for (int64_t i
= 0; i
< static_cast<int64_t>(idx
); ++i
)
971 if (getType().isDynamicDim(i
))
973 return getDynamicSizes()[ctr
];
SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
  return result;
}
namespace {

/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example:
///
///  %c5 = arith.constant 5: index
///  %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                          foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
/// Canonicalize
///
/// ```mlir
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducer`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2 : The tensor cast shape is static, but empty tensor result
      // shape is dynamic.
      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3 : The tensor cast shape is dynamic and empty tensor result
      // shape is dynamic. Use the dynamic value from the empty tensor op.
      newMixedSizes.push_back(currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};

} // namespace
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}
/// Try to remove a tensor operation if it would only reshape a constant.
/// Removes the op and replaces the constant with a new constant of the result
/// shape. When an optional cst attribute is passed, it is reshaped only if the
/// splat value matches the value in the attribute.
static OpFoldResult
reshapeConstantSource(DenseElementsAttr source, TensorType result,
                      std::optional<Attribute> cst = std::nullopt) {
  if (source && source.isSplat() && result.hasStaticShape() &&
      (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
    return source.resizeSplat(result);

  return {};
}
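// Illustration (sketch): a splat source such as dense<1.0> : tensor<6xf32>
// reshaped to the static type tensor<2x3xf32> folds to
// dense<1.0> : tensor<2x3xf32>; non-splat sources or dynamically shaped
// results are left untouched.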
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

namespace {

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

} // namespace
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // If this is a dense resource elements attribute, return.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  return {};
}
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}
namespace {

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// to just %element.
//
// Consider expanding this to a template and handle all tensor cast
// operations.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
/// Return the inferred result type for a gatherOp where:
/// - sourceType is the type of the source tensor gathered from
/// - indicesType is the type of the indices used to gather
/// - gatherDims are the dims along which the gather occurs.
/// Return a full rank or rank-reduced variant of the type depending on
/// the value of rankReduced.
///
/// The leading dimensions of the index tensor give the result tensor its
/// leading dimensions.
/// The trailing dimensions of the result tensor are obtained from the source
/// tensor by setting the dimensions specified in gather_dims to `1` (if
/// rankReduced is false), or skipping them (otherwise).
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
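// Illustration (sketch): for a source tensor<4x5x6xf32>, indices
// tensor<2x3x1xindex> and gather_dims = [1], the inferred result type is
// tensor<2x3x4x1x6xf32>, or tensor<2x3x4x6xf32> when rankReduced is true.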
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       "mismatch: expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}
OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}
LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  return success();
}
LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};

} // namespace
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}
ReshapeOp::verify() {
1612 TensorType operandType
= llvm::cast
<TensorType
>(getSource().getType());
1613 TensorType resultType
= llvm::cast
<TensorType
>(getResult().getType());
1615 if (operandType
.getElementType() != resultType
.getElementType())
1616 return emitOpError("element types of source and destination tensor "
1617 "types should be the same");
1620 llvm::cast
<RankedTensorType
>(getShape().getType()).getDimSize(0);
1621 auto resultRankedType
= llvm::dyn_cast
<RankedTensorType
>(resultType
);
1622 auto operandRankedType
= llvm::dyn_cast
<RankedTensorType
>(operandType
);
1624 if (resultRankedType
) {
1625 if (operandRankedType
&& resultRankedType
.hasStaticShape() &&
1626 operandRankedType
.hasStaticShape()) {
1627 if (getNumElements(operandRankedType
) != getNumElements(resultRankedType
))
1628 return emitOpError("source and destination tensor should have the "
1629 "same number of elements");
1631 if (ShapedType::isDynamic(shapeSize
))
1632 return emitOpError("cannot use shape operand with dynamic length to "
1633 "reshape to statically-ranked tensor type");
1634 if (shapeSize
!= resultRankedType
.getRank())
1636 "length of shape operand differs from the result's tensor rank");
1641 OpFoldResult
ReshapeOp::fold(FoldAdaptor adaptor
) {
1642 if (OpFoldResult reshapedSource
= reshapeConstantSource(
1643 llvm::dyn_cast_if_present
<DenseElementsAttr
>(adaptor
.getSource()),
1644 getResult().getType()))
1645 return reshapedSource
;
1647 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1648 // producer's input instead as the original tensor to reshape. This could
1649 // render such producer dead code.
1650 if (auto reshapeOpProducer
= getSource().getDefiningOp
<ReshapeOp
>()) {
1651 getSourceMutable().assign(reshapeOpProducer
.getSource());
1655 auto source
= getSource();
1656 auto sourceTy
= dyn_cast
<RankedTensorType
>(source
.getType());
1657 auto resultTy
= dyn_cast
<RankedTensorType
>(getType());
1658 if (!sourceTy
|| !resultTy
|| sourceTy
!= resultTy
)
1661 // If the source and result are both 1D tensors and have the same type, the
1662 // reshape has no effect, even if the tensor is dynamically shaped.
1663 if (sourceTy
.getRank() == 1)
1666 if (auto fromElements
= getShape().getDefiningOp
<tensor::FromElementsOp
>()) {
1667 auto elements
= fromElements
.getElements();
1669 sourceTy
.getRank() == static_cast<int64_t>(elements
.size());
1670 for (int id
= 0, s
= elements
.size(); id
< s
&& dynamicNoop
; ++id
) {
1671 auto element
= elements
[id
];
1673 if (auto cst
= getConstantIntValue(element
)) {
1674 dynamicNoop
&= cst
.value() == sourceTy
.getDimSize(id
);
1678 if (auto dimOp
= element
.getDefiningOp
<tensor::DimOp
>()) {
1679 dynamicNoop
&= dimOp
.getSource() == source
;
1682 auto cst
= getConstantIntValue(dimOp
.getIndex());
1684 cst
.has_value() && cst
.value() == static_cast<int64_t>(id
);
1688 dynamicNoop
= false;
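// Sketch of the dynamic no-op case handled above (assumed IR): the shape
// operand is built from tensor.dim values of the reshaped source itself, in
// order, so the reshape cannot change the type and folds away.
//   %c0 = arith.constant 0 : index
//   %c8 = arith.constant 8 : index
//   %d = tensor.dim %src, %c0 : tensor<?x8xf32>
//   %shape = tensor.from_elements %d, %c8 : tensor<2xindex>
//   %r = tensor.reshape %src(%shape)
//       : (tensor<?x8xf32>, tensor<2xindex>) -> tensor<?x8xf32>
// folds to %src.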
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}
int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}
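// For example (assumed reassociation): with indices [[0, 1], [2]], result dims
// 0 and 1 both map back to source dim 0, and result dim 2 maps to source dim 1.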
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}
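// Worked example (assumed shapes): collapsing tensor<2x3x?x5xf32> with
// reassociation {[0, 1], [2, 3]} yields tensor<6x?xf32>; the first group
// multiplies 2 * 3, while the second contains a dynamic size and stays dynamic.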
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  return verifyTensorReshapeOp(*this, resultType, srcType);
}

LogicalResult CollapseShapeOp::verify() {
  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}
/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};

// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};

/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type.
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};
// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<CollapseShapeOp>(
          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
          collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};
struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only constant dimension values are supported.
    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value())
      return failure();

    // Skip static dims. These are folded to constant ops.
    RankedTensorType resultType = expandShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      return failure();

    // Find reassociation group that contains this result dimension.
    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);

    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
    // ExpandShapeOp would be ambiguous.)
    int64_t product = 1;
    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
    for (int64_t d : grp) {
      if (d != dim) {
        assert(!resultType.isDynamicDim(d) && "expected static dim");
        product *= resultType.getDimSize(d);
      }
    }

    // result dim size = src dim size / (product(other dims in reassoc group))
    Value srcDimSz =
        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
    AffineExpr expr;
    bindSymbols(dimOp.getContext(), expr);
    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
        dimOp, expr.floorDiv(product), srcDimSz);
    return success();
  }
};
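// Worked example of the rewrite above (assumed IR and syntax):
//   %e = tensor.expand_shape %t [[0, 1, 2]] output_shape [%n, 4, 5]
//       : tensor<?xf32> into tensor<?x4x5xf32>
//   %d = tensor.dim %e, %c0 : tensor<?x4x5xf32>
// The other dims in the group are static (4 * 5 = 20), so %d is replaced by an
// affine.apply of s0 floordiv 20 over tensor.dim %t, %c0.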
struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    // Only constant dimension values are supported.
    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value())
      return failure();

    // Skip static dims. These are folded to constant ops.
    RankedTensorType resultType = collapseShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      return failure();

    // Get reassociation group of the result dimension.
    ReassociationIndices group =
        collapseShapeOp.getReassociationIndices()[*dim];

    // result dim size = product(dims in reassoc group)
    SmallVector<Value> srcDimSizes;
    SmallVector<AffineExpr> syms;
    AffineExpr product;
    for (const auto &it : llvm::enumerate(group)) {
      srcDimSizes.push_back(rewriter.create<DimOp>(
          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
      product = product ? product * syms.back() : syms.back();
    }
    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
                                                       srcDimSizes);
    return success();
  }
};
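// Worked example of the rewrite above (assumed IR): for
//   %c = tensor.collapse_shape %t [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
// a dynamic tensor.dim %c, %c0 is rewritten to an affine.apply of s0 * s1 over
// tensor.dim %t, 0 and tensor.dim %t, 1.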
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (!ShapedType::isDynamic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires at
        // least one of the expanded dimensions to be dynamic if the input is
        // dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
                                             expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
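// Sketch of the intended rewrite (assumed IR): when the cast source is static
// and the dynamic output_shape entries are constants,
//   %c = tensor.cast %arg : tensor<8x32xf32> to tensor<?x32xf32>
//   %e = tensor.expand_shape %c [[0, 1], [2]] output_shape [%c2, 4, 32]
//       : tensor<?x32xf32> into tensor<?x4x32xf32>
// becomes an expand_shape on a statically typed cast of the source
// (tensor<8x32xf32>) producing tensor<2x4x32xf32>, followed by a tensor.cast
// back to the original result type.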
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
      FoldDimOfCollapseShape>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  // An extract_slice op may specify only a leading subset of offset/sizes/
  // strides in which case we complete with offset=0, sizes from the tensor
  // type and strides=1.
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}
/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size 1 as needed to produce an inferred
/// type with the desired rank.
///
/// Note that there may be multiple ways to compute this rank-reduced type:
///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
///
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best effort rank-reducing: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
/// a Range vector.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
/// Build an ExtractSliceOp with dynamic entries and custom result type. If
/// the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}

/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
  // Verify result type against inferred type.
  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
  SliceVerificationResult result = isRankReducedType(expectedType, getType());
  return produceSliceErrorMsg(result, *this, expectedType);
}

llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
  return ::getDroppedDims(getType().getShape(), getMixedSizes());
}
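// For example (assumed slice): a rank-reducing extract_slice with sizes
// [1, 4, 1, 8] and result type tensor<4x8xf32> reports dims 0 and 2 as
// dropped, since those unit sizes do not appear in the rank-reduced result.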
FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingExtractSliceOp(
      b, loc, value,
      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
}

LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
  return success();
}
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
///   tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///   tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // Create folded extract.
    Location loc = sliceOp.getLoc();
    Value newResult = rewriter.create<ExtractSliceOp>(
        loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
    if (newResult.getType() != sliceOp.getType())
      newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
    rewriter.replaceOp(sliceOp, newResult);
    return success();
  }
};
/// Slice elements from `values` into `outValues`. `counts` represents the
/// numbers of elements to stride in the original values for each dimension.
/// The output values can be used to construct a DenseElementsAttr.
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
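// Worked example (assumed values): for a 3x4 source, the caller passes
// counts = {4, 1} (advancing 4 elements moves one row, 1 element moves one
// column). With offsets {1, 0}, sizes {2, 2}, strides {1, 2} the recursion
// visits rows 1 and 2 and, within each row, columns 0 and 2.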
/// Fold arith.constant and tensor.extract_slice into arith.constant. The
/// folded operation might introduce more constant data; Users can control
/// their heuristics by the control function.
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();

    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shape is not supported.
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed by the sliced values.
    DenseElementsAttr newAttr;

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    }

    if (newAttr) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
      return success();
    }

    return failure();
  }

private:
  /// This additionally controls whether the fold happens or not. Users can
  /// impose their heuristics in the function.
  ControlConstantExtractSliceFusionFn controlFn;
};

void mlir::tensor::populateFoldConstantExtractSlicePatterns(
    RewritePatternSet &patterns,
    const ControlConstantExtractSliceFusionFn &controlFn) {
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
}
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    return ExtractSliceOp::inferCanonicalRankReducedResultType(
        op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
        mixedStrides);
  }
};

/// A canonicalizer wrapper to replace ExtractSliceOps.
struct SliceCanonicalizer {
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }
};

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  OpBuilder b(op.getContext());
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
      return failure();
  // Rank-reducing noops only need to inspect the leading dimensions:
  // llvm::zip is appropriate.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
      return failure();
  return success();
}
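// Example of the identity case this helper detects (assumed IR):
//   %s = tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
//       : tensor<4x8xf32> to tensor<4x8xf32>
// All offsets are 0, all strides are 1, and the sizes equal the operand shape,
// so the consuming folds below can replace %s with %t.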
/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
/// slice, we can return the InsertSliceOp's source directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();

  return {};
}

OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (Value slice = foldExtractAfterInsertSlice(*this))
    return slice;

  return OpFoldResult();
}
Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
}

//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}
// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          ArrayRef<OpFoldResult> strides,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
/// Range vector.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<Range> ranges,
                          ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}

// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
static SliceVerificationResult verifyInsertSliceOp(
    RankedTensorType srcType, RankedTensorType dstType,
    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
    ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice, use the same type
  // inference.
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}

/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  return produceSliceErrorMsg(result, *this, expectedType);
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
///
/// Example:
///
/// ```mlir
///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
/// ```
///
/// folds into:
///
/// ```mlir
///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
/// ```
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());
  return success();
}
/// Folds round-trip extract/insert slice op pairs.
/// Example:
/// ```mlir
///   %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
///   %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
/// ```
/// can be folded into %val.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();
}

OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;
  if (llvm::any_of(getMixedSizes(),
                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
    return getDest();
  return OpFoldResult();
}
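// Note on the last case above (assumed IR): an insert with any zero-sized
// dimension, e.g.
//   %r = tensor.insert_slice %s into %d[0, 0] [0, 4] [1, 1]
//       : tensor<0x4xf32> into tensor<8x8xf32>
// writes nothing and folds to the destination %d.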
LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
  return success();
}
/// Pattern to rewrite an insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return.
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
/// Fold tensor_casts with insert_slice operations. If the source or
/// destination tensor is a tensor_cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast to ensure that the type of the result did
/// not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not seen
    // in the insert slice op static sizes, so we ignore dynamic dims when
    // computing the rank reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();
    // Replace dimensions in the insert slice op with corresponding static dims
    // from the cast source type. If the insert slice sizes have static dims
    // that are not static in the tensor.cast source (i.e., when the cast op
    // casts a dynamic dim to static), the dim should not be replaced, and the
    // pattern will fail later in `verifyInsertSliceOp`.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();

    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        mixedSizes, insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                    insertSliceOp.getDestType(),
                                                    replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
/// If additional static type information can be deduced from an insert_slice's
/// size operands, insert an explicit cast of the op's source operand. This
/// enables other canonicalization patterns that are matching for tensor_cast
/// ops such as `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
///       : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
///       : tensor<64x64xf32> into ...
/// ```
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is:
    //   1) Different from srcType.
    //   2) "More static" than srcType.
    //   3) Cast-compatible with srcType.
    // Insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
    // that the insertion point is just before the ParallelCombiningOp in the
    // parallel case.
    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}

void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                             Location loc,
                                                             Value tensor,
                                                             Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
3020 //===----------------------------------------------------------------------===//
3022 void PadOp::getAsmResultNames(function_ref
<void(Value
, StringRef
)> setNameFn
) {
3023 setNameFn(getResult(), "padded");
3026 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3027 // supports optional types.
3028 void printInferType(OpAsmPrinter
&printer
, Operation
*op
, Value optOperand
,
3029 Type typeToInfer
, Type typeToInferFrom
) {}
3032 parseInferType(OpAsmParser
&parser
,
3033 std::optional
<OpAsmParser::UnresolvedOperand
> optOperand
,
3034 Type
&typeToInfer
, Type typeToInferFrom
) {
3036 typeToInfer
= typeToInferFrom
;
LogicalResult PadOp::verify() {
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }

  return success();
}

LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
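// Worked example (assumed sizes): padding tensor<4x?xf32> with staticLow
// {1, 1} and staticHigh {2, 1} infers tensor<7x?xf32>; dim 0 becomes
// 4 + 1 + 2 = 7 and dim 1 stays dynamic because the source dim is dynamic.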
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<int64_t> staticLow,
                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}

void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ValueRange low, ValueRange high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh have full information of the padding config.
  // This will grow staticLow and staticHigh with 1 value. If the config is
  // dynamic (i.e. not a constant), dynamicLow and dynamicHigh will grow with 1
  // value as well.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
  if (!resultType) {
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  }
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}

void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, Value constantPadValue,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  build(b, result, resultType, source, low, high, nofold, attrs);

  // Add a region and a block to yield the pad value.
  Region *region = result.regions[0].get();
  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);

  // `builder.createBlock` changes the insertion point within the block. Create
  // a guard to reset the insertion point of the builder after it is destroyed.
  OpBuilder::InsertionGuard guard(b);
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
  b.create<tensor::YieldOp>(result.location, constantPadValue);
}
3185 llvm::SmallBitVector
PadOp::getPaddedDims() {
3186 llvm::SmallBitVector
paddedDims(getSourceType().getRank());
3187 auto extractPaddedDims
= [&](ArrayRef
<OpFoldResult
> paddingWidths
) {
3188 for (const auto &en
: enumerate(paddingWidths
))
3189 if (getConstantIntValue(en
.value()) != static_cast<int64_t>(0))
3190 paddedDims
.set(en
.index());
3192 extractPaddedDims(getMixedLowPad());
3193 extractPaddedDims(getMixedHighPad());
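// Illustrative note (a sketch, hypothetical widths): for a pad with
// low[0, 2] and high[1, 0], both dimensions end up set in the returned bit
// vector: dimension 0 via its non-zero high padding and dimension 1 via its
// non-zero low padding. Dynamic padding widths also count as padded, because
// getConstantIntValue returns std::nullopt, which compares unequal to 0.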
namespace {
// Folds tensor.pad when padding is static zeros and the attribute
// doesn't request otherwise.
struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();
  }
};
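// Illustrative example (hypothetical IR, not from a test): given
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<?x5xf32> to tensor<4x5xf32>
//
// the pattern rewrites %0 to
//   %0 = tensor.cast %src : tensor<?x5xf32> to tensor<4x5xf32>
// provided the `nofold` attribute is not set.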
// Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<PadOp>(
          padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
          padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
          padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};

// Fold CastOp using the result of PadOp back into the latter if it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = rewriter.create<PadOp>(
        padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
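// Illustrative example (hypothetical IR): a single-use chain such as
//
//   %0 = tensor.pad %src low[0] high[%h] { ... }
//       : tensor<4xf32> to tensor<?xf32>
//   %1 = tensor.cast %0 : tensor<?xf32> to tensor<9xf32>
//
// is rewritten into one tensor.pad that directly produces tensor<9xf32>,
// since the cast only adds static information to the pad result.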
/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
/// different dimensions. The pattern applies if the following preconditions
/// hold:
///   1) the tensor::ExtractSliceOps are not rank-reducing,
///   2) the tensor::ExtractSliceOps have only unit-strides,
///   3) the tensor::PadOps perform only high-padding,
///   4) the tensor::PadOps have the same constant padding value,
///   5) the tensor::PadOps do not have common padding dimensions,
///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
///      zero-offset for every dimension,
///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
///      the padded source dimensions.
///
/// Example:
///
///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
///        : tensor<64x64xf32> to tensor<?x64xf32>
///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
///     } : tensor<?x64xf32> to tensor<8x64xf32>
///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
///        : tensor<8x64xf32> to tensor<8x?xf32>
///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
///     } : tensor<8x?xf32> to tensor<8x4xf32>
///
/// folds into:
///
///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
///        : tensor<64x64xf32> to tensor<?x?xf32>
///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
///     } : tensor<?x?xf32> to tensor<8x4xf32>
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // for every dimension, and use the offset of the other pair. Fail if no
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // exists.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
    // of the outer tensor::ExtractSliceOp for the dimensions padded by the
    // outer tensor::PadOp and fail if the size of the inner
    // tensor::ExtractSliceOp does not match the size of the padded dimension.
    // Otherwise, take the size of the inner tensor::ExtractSliceOp.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(!ShapedType::isDynamic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // the two paddings in one step.
    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
        innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    int lowCount = 0;
    int highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = rewriter.create<PadOp>(
        padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
        newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);
    return success();
  }
};
/// Folds a chain of `tensor.pad` ops with the same constant padding value.
///
/// Example:
///
///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
///     tensor.yield %val
///   } : tensor<1x2xf32> to tensor<1x5xf32>
///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
///     tensor.yield %val
///   } : tensor<1x5xf32> to tensor<4x7xf32>
///
/// folds into:
///
///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
///     tensor.yield %val
///   } : tensor<1x2xf32> to tensor<4x7xf32>
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
        newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
} // namespace

void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}

/// Return the padding value of the PadOp if it is constant. In this context,
/// "constant" means an actual constant or "defined outside of the block".
///
/// Values are considered constant in three cases:
///  - A ConstantLike value.
///  - A basic block argument from a different block.
///  - A value defined outside of the block.
///
/// If the padding value is not constant, an empty Value is returned.
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // Check if yield value is defined inside the PadOp block.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else: Yield value defined outside of the PadOp block.
  return padValue;
}

OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
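// Illustrative note (a sketch, assuming the fold condition above): the folder
// only fires when the pad is a no-op, e.g.
//   tensor.pad %src low[0, 0] high[0, 0] : tensor<2x3xf32> to tensor<2x3xf32>
// folds to %src; any dynamic shape, a differing result type, or the `nofold`
// attribute keeps the op in place.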
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//

OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}

// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
/// packed into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}

// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  return produceSliceErrorMsg(result, *this, expectedType);
}

void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}

llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");
  // TODO: we could also check statically that there are fewer leading index
  // tensor dims than the dest dims. If this is not the case, the unique
  // attribute cannot be true.

  // Use the GatherOp::inferResultType on the `dest` type and verify the
  // expected type matches the source type.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type "
                       "must match the expected type ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as being a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
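// Illustrative note (hypothetical values): splatting a constant scalar such as
// `arith.constant 1.0 : f32` into a statically shaped result like
// tensor<4x4xf32> folds to a SplatElementsAttr (dense<1.000000e+00>), while a
// splat with dynamic sizes or a non-constant scalar is left untouched.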
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//

template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}

template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind the dimension `i` to its tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
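// Illustrative note (a sketch, hypothetical values): for
// inner_dims_pos = [0, 2] and mixed tiles [8, %t], the returned map is
// {0 -> 8, 2 -> %t}, i.e. each tiled source dimension is bound to its
// (static or dynamic) tile factor.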
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}

template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}

/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) It contains duplicates.
/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
/// c) The number of elements in `dimsPos` is > than `rank`.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  DenseSet<int64_t> uniqued;
  for (int64_t dim : dimsPos)
    uniqued.insert(dim);
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}
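// Illustrative note (hypothetical values): with rank = 3, dimsPos = [0, 0] is
// invalid (duplicate entry), dimsPos = [0, 3] is invalid (3 is out of bounds),
// and dimsPos = [2, 0] is a valid specification.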
/// Returns true if each dimension of `sourceShape` is smaller than or equal to
/// the corresponding dimension of `limitShape`.
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}

template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(
        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
  };

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and must match the number of `inner_dims_pos`.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // Require output rank to match input rank + number of blocking factors.
  if (unpackedRank + mixedTiles.size() != packedRank) {
    return op->emitError(
        "packed rank must equal unpacked rank + tiling factors");
  }

  // Verify result shape is greater than the minimum expected
  // by the pack operation, and that the output shape
  // represents full tiles.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  return success();
}

/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
// these. These may or may not become true foldings / canonicalizations
// depending on how aggressive we want to be in automatically folding
// transposes.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}

LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}

DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}

SmallVector<OpFoldResult> PackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}

SmallVector<int64_t> PackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}

ArrayRef<int64_t> PackOp::getAllOuterDims() {
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
}

SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getDestType().getShape();
  SmallVector<int64_t> res;

  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);

  return res;
}

bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

    if (!constantTile) {
      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}
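// Illustrative note (hypothetical shapes): packing a tensor<13x64xf32> with
// inner_dims_pos = [0] and a constant inner tile of 8 requires a padding
// value because 13 % 8 != 0, whereas packing a tensor<16x64xf32> with the
// same tile needs no padding since 16 % 8 == 0.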
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify padding value, and bail out if the tile does not divide the
  // dimension fully. In the case of dynamic tile factors or dimensions, having
  // a partial tile is undefined behavior.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}

/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
/// Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}

/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
/// the packed type. Having a shared helper helps implement these two methods in
/// a way that ensures that they agree on which dimensions are dynamic.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }

  // Swap tile loops if outer_dims_perm is available.
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultShape, outerDimsPerm);

  // Append the inner tile dimensions.
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
  return resultShape;
}
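// Illustrative note (hypothetical shapes): for sourceShape = [128, ?],
// innerTileSizes = [8, 16], innerDimsPos = [0, 1] and outerDimsPerm = [1, 0],
// the tiled outer dims are [ceil(128/8), ?] = [16, ?], the permutation turns
// them into [?, 16], and appending the tiles yields [?, 16, 8, 16].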
SmallVector<OpFoldResult> PackOp::getResultShape(
    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultDims, outerDimsPerm);
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  SmallVector<int64_t> resultTypeShape =
      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
                               asShapeWithAnyValueAsDynamic(innerTileSizes),
                               innerDimsPos, outerDimsPerm);

  // Fix-up `resultDims` to ensure that they are Value's if and only if the
  // result type shape says it's a dynamic dim. This is needed as callers may
  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
  // dynamic dims returned by that.
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (!ShapedType::isDynamic(resultTypeShape[i]))
      continue;
    resultDims[i] =
        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
  }

  return resultDims;
}

/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}

PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

Speculation::Speculatability PackOp::getSpeculatability() {
  if (getPaddingValue())
    return Speculation::Speculatable;

  // The verifier already rejects operations if we can statically prove that
  // the sizes of the tiles do not divide the dimension perfectly; thus, check
  // only for constant tiles and tiled inner dimensions.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
// dimensions for pack and unpack.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // Outer dims permutation is optional.
  // To compare an unbalanced pack-unpack pair, treat no permutation as equal
  // to the identity permutation.
  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
         isIdentityPermutation(unPackOp.getOuterDimsPerm());
}

// Return true if pack and unpack have the same tiles.
// Same SSA values or same integer constants.
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}

/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(packOp.getInnerDimsPos().begin(),
                   packOp.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}

LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold a pack(unpack(x)) to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Fold optional PaddingValue operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}
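// Illustrative example (hypothetical IR) for the pack(unpack(x)) fold above:
//
//   %0 = tensor.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %d0 : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
//   %1 = tensor.pack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %d1 : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//
// canonicalizes %1 to %x, since both ops agree on dims and tiles and there is
// no padding value involved.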
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
                    std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dimensions.

  // Check that we are not transposing any dimensions.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  // At this point we know that we are taking numPackedDims outer
  // dimensions and pushing them all the way as the inner most dimensions.
  // What's left on the outer most dimensions is, in this order:
  // - the factor of the packed dimensions, then
  // - the untouched dimensions
  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
  // if all the dimensions that bubble outerward are ones.
  // Therefore check that all the dimensions but the numPackedDims inner most
  // ones are ones.
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}
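// Illustrative note (hypothetical shapes): a pack with innerDimsPos = [0, 1]
// producing tensor<1x1x8x8xf32> from tensor<5x7xf32> qualifies, because every
// non-tile dimension of the packed type is 1; the op then behaves like a pad
// (to 8x8) plus a reshape rather than a data transposition.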
bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}

//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//

void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}

DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}

SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}

SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}

ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}

SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getSourceType().getShape();
  SmallVector<int64_t> res;

  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);

  return res;
}

LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(*this);
}

Speculation::Speculatability UnPackOp::getSpeculatability() {
  // See PackOp::getSpeculatability.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}

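// Minimal builder-usage sketch (illustrative, not from the upstream source;
// the value names, tile sizes and shapes below are assumed):
//
//   SmallVector<OpFoldResult> tiles = {b.getIndexAttr(8), b.getIndexAttr(16)};
//   SmallVector<int64_t> innerDimsPos = {0, 1};
//   Value dest = UnPackOp::createDestinationTensor(
//       b, loc, packedSrc, tiles, innerDimsPos, /*outerDimsPerm=*/{});
//   auto unpack = b.create<UnPackOp>(loc, packedSrc, dest, innerDimsPos,
//                                    tiles, /*outerDimsPerm=*/{});
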
Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}

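// Worked example (illustrative, not from the upstream source; shapes are
// assumed): unpacking a source of type tensor<2x4x8x16xf32> with
// inner_dims_pos = [0, 1], inner_tiles = [8, 16] and no outer permutation
// keeps the leading outer sizes [2, 4] and multiplies each tiled dimension by
// its tile size, producing a tensor.empty of type tensor<16x64xf32>
// (2 * 8 = 16, 4 * 16 = 64).
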
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}

/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `op` and populates each with the inferred static shape.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}

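// Illustrative example (not from the upstream source; types are assumed):
// for an unpack with inner_dims_pos = [1], inner_tiles = [16], no outer
// permutation, source tensor<4x?x16xf32> and dest tensor<?x32xf32>, the
// untiled outer dimension 0 is static (4) on the source side but dynamic on
// the dest side, so inferStaticShape reports a change and refines the dest
// shape to [4, 32]; the canonicalization below then wraps a new unpack with
// the refined types in tensor.cast ops.
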
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }

  /// unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = rewriter.create<tensor::CastOp>(loc, newDestType,
                                             unPackOp.getDest());
    }
    Value newOp = rewriter.create<tensor::UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}

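// Illustrative IR for the unpack(pack(x)) -> x case (not from the upstream
// source; shapes and value names are assumed):
//
//   %p = tensor.pack %x inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %d0 : tensor<16x64xf32> -> tensor<2x4x8x16xf32>
//   %y = tensor.unpack %p inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %d1 : tensor<2x4x8x16xf32> -> tensor<16x64xf32>
//
// canonicalizes to uses of %x directly, provided the pack has no padding
// value and both ops agree on the inner/outer dim attributes and tile sizes.
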
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
  // 2. Exclude DPS ops that are also LoopLike from this interface as they
  // might need special handling of attached regions.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  // If no operand comes from a tensor::CastOp that can be folded, then fail.
  bool hasTensorCastOperand =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
        if (llvm::isa<BlockArgument>(opOperand.get()))
          return false;
        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
        return castOp && canFoldIntoConsumerOp(castOp);
      });

  return hasTensorCastOperand;
}

static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
                                         SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}

/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has a source that is more static than the consuming op.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
/// ```
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp op,
                                PatternRewriter &rewriter) const override {
    if (!foldTensorCastPrecondition(op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes;
    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
                                 .getShape()
                                 .take_back(op.getMixedTiles().size()),
                             op.getMixedTiles())) {
      int64_t shape = std::get<0>(it);
      if (shape == ShapedType::kDynamic) {
        newMixedTileSizes.push_back(std::get<1>(it));
        continue;
      }

      if (Attribute attr =
              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
        // Already a constant.
        newMixedTileSizes.push_back(std::get<1>(it));
      } else {
        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
        assert(tileSize == shape && "tile size and dim size don't match!");
        (void)tileSize;
        newMixedTileSizes.push_back(
            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
      }
    }

    // Clone the op with the updated operands, result types and tile sizes.
    PackOp newOp = rewriter.create<PackOp>(
        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace the old op, casting back to the original result type if needed.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has a source that is more static than the consuming op.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
/// can add the pattern to their canonicalizers.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject tensor::PackOp - there's a dedicated pattern for that instead.
    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

    // Clone the op with the more static operands and result types.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};

//===----------------------------------------------------------------------===//
// TensorDialect
//===----------------------------------------------------------------------===//

void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastPackOp>(getContext());
  results.add<FoldTensorCastProducerOp>(getContext());
}

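// Minimal sketch of applying these dialect canonicalization patterns outside
// of the canonicalizer pass (illustrative; `module` and `ctx` are assumed to
// be an existing ModuleOp and MLIRContext):
//
//   RewritePatternSet patterns(ctx);
//   ctx->getLoadedDialect<tensor::TensorDialect>()
//       ->getCanonicalizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
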
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"