//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace tensor {
namespace {

/// Returns the number of shape sizes that are either dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if the non-packed domain has at most one non-unit
/// dimension and the packing only happens on that dimension.
/// Note: this method should only be used by pack/unpack to reshape conversion.
/// It assumes that the non-unit inner tile size must be used by the non-unit
/// dimension.
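/// Illustrative example of the predicate only (hypothetical shapes, not taken
/// from a test): a source shape [1, 1, 64] with inner tile sizes [16] passes,
/// since each argument has at most one non-unit entry, whereas a source shape
/// [2, 64] is rejected.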
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dims");
  }
  // Non-unit inner tile size must be used by the non-unit dimension. If not, it
  // will fail when getting reassociation maps.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tiles");
  }
  return success();
}

// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
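// Illustrative example (assumed maps, mirroring the logic below): for a
// transpose-like generic with input map (d0, d1, d2) -> (d0, d1, d2) and
// output map (d0, d1, d2) -> (d1, d2, d0), the returned permutation is
// [1, 2, 0], i.e. output dimension i is read from input dimension perm[i].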
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing a one-dimensional tensor can be expressed as an expand shape op.
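/// A minimal sketch of the rewrite (hypothetical shapes; the exact
/// expand_shape syntax depends on the MLIR version):
///
///   %0 = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<64xf32> -> tensor<8x8xf32>
///
/// becomes
///
///   %0 = tensor.expand_shape %src [[0, 1]]
///       : tensor<64xf32> into tensor<8x8xf32>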
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
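/// A rough sketch of the fold (hypothetical types, attributes elided):
///
///   %padded = tensor.pad %src low[0, 0] high[...] { tensor.yield %cst : f32 }
///   %packed = tensor.pack %padded padding_value(%cst : f32)
///       inner_dims_pos = [1] inner_tiles = [16] into %dest
///
/// folds to
///
///   %packed = tensor.pack %src padding_value(%cst : f32)
///       inner_dims_pos = [1] inner_tiles = [16] into %dest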
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();

    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};

/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
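/// A minimal sketch (hypothetical shapes):
///
///   %u = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %d : tensor<4x8xf32> -> tensor<32xf32>
///   %s = tensor.extract_slice %u[0] [30] [1] : tensor<32xf32> to tensor<30xf32>
///
/// folds to an unpack that writes directly into a smaller destination:
///
///   %e = tensor.empty() : tensor<30xf32>
///   %s = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %e : tensor<4x8xf32> -> tensor<30xf32>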
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check all offsets are zeros, and all strides are ones.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};

// Applies 'permutation' on 'inVec' and stores the result in resVec.
// 'inVec' may be empty, in which case the mapping is one-to-one with the
// permutation. `rank` sets the boundary for the permutation, i.e., a
// permutation value must be smaller than the given rank; otherwise return
// false. For example, permutation {1, 0, 3, 2} with rank 2 is allowed since
// the values in permutation[:rank] stay below the rank, whereas permutation
// {1, 3, 0, 2} is not allowed since `3` exceeds the rank in the given range.
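// A concrete illustration of the remapping below (made-up values): with
// permutation {1, 0, 3, 2}, inVec {5, 7}, and rank 2, resVec becomes
// {inVec[1], inVec[0]} = {7, 5}; with an empty inVec it becomes {1, 0}.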
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {

  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }

  return true;
}

/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
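/// A rough sketch (hypothetical shapes and permutations):
///
///   %p = tensor.pack %src inner_dims_pos = [1] inner_tiles = [16]
///       into %d0 : tensor<?x32xf32> -> tensor<?x2x16xf32>
///   %t = linalg.transpose ins(%p : tensor<?x2x16xf32>)
///       outs(%d1 : tensor<2x?x16xf32>) permutation = [1, 0, 2]
///
/// can pack directly into the transposed layout:
///
///   %t = tensor.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [1]
///       inner_tiles = [16] into %d2 : tensor<?x32xf32> -> tensor<2x?x16xf32>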
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process transpose operation for tiled inner dimensions
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
        newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
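/// A rough sketch (hypothetical shapes): with
///
///   %t = linalg.transpose ins(%src : tensor<32x?xf32>)
///       outs(%i : tensor<?x32xf32>) permutation = [1, 0]
///   %p = tensor.pack %t inner_dims_pos = [1] inner_tiles = [16]
///       into %d : tensor<?x32xf32> -> tensor<?x2x16xf32>
///
/// the pack can read %src directly:
///
///   %p = tensor.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0]
///       inner_tiles = [16] into %d2 : tensor<32x?xf32> -> tensor<?x2x16xf32>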
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
    // permutation rank won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
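/// A rough sketch (hypothetical shapes):
///
///   %u = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %d : tensor<4x16x8xf32> -> tensor<32x16xf32>
///   %t = linalg.transpose ins(%u : tensor<32x16xf32>)
///       outs(%init : tensor<16x32xf32>) permutation = [1, 0]
///
/// can unpack straight into the transpose's destination:
///
///   %t = tensor.unpack %src outer_dims_perm = [1, 0] inner_dims_pos = [1]
///       inner_tiles = [8] into %init : tensor<4x16x8xf32> -> tensor<16x32xf32>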
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
    // permutation rank won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
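/// A rough sketch (hypothetical shapes):
///
///   %t = linalg.transpose ins(%p : tensor<2x4x8xf32>)
///       outs(%i : tensor<4x2x8xf32>) permutation = [1, 0, 2]
///   %u = tensor.unpack %t inner_dims_pos = [0] inner_tiles = [8]
///       into %d : tensor<4x2x8xf32> -> tensor<32x2xf32>
///
/// can unpack the untransposed operand %p directly:
///
///   %e = tensor.empty() : tensor<32x2xf32>
///   %u = tensor.unpack %p outer_dims_perm = [1, 0] inner_dims_pos = [0]
///       inner_tiles = [8] into %e : tensor<2x4x8xf32> -> tensor<32x2xf32>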
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
      return failure();
    }

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process transpose operation for tiled inner dimensions
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Type elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        unPackOp->getLoc(), unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }
};

} // namespace

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext());
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}

} // namespace tensor
} // namespace mlir