//===- RankReductionPatterns.cpp - Patterns related to rank reductions ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) op pairs that cancel each other out.
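///
/// Example (an illustrative sketch; the shapes and SSA names are made up for
/// exposition, not taken from a test). A rank-reducing extract_slice followed
/// by an expand_shape that restores the dropped unit dimension:
///
///   %0 = tensor.extract_slice %t[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<2x8x16xf32> to tensor<8x16xf32>
///   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [1, 8, 16]
///       : tensor<8x16xf32> into tensor<1x8x16xf32>
///
/// folds to a single non-rank-reducing extract_slice:
///
///   %1 = tensor.extract_slice %t[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<2x8x16xf32> to tensor<1x8x16xf32>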
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};
/// Fold a collapse_shape that only removes static dimensions of size `1`
/// into an extract_slice.
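///
/// Example (an illustrative sketch; the shapes and SSA names are made up for
/// exposition, not taken from a test):
///
///   %0 = tensor.extract_slice %t[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<2x8x16xf32> to tensor<1x8x16xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
///       : tensor<1x8x16xf32> into tensor<8x16xf32>
///
/// folds to a single rank-reducing extract_slice:
///
///   %1 = tensor.extract_slice %t[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<2x8x16xf32> to tensor<8x16xf32>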
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has multiple
    // users, and it is not necessarily beneficial to only convert the
    // collapse into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away simple collapses where all removed dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};
/// Fold insert_slice(collapse_shape) op pairs that cancel each other out.
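///
/// Example (an illustrative sketch; the shapes and SSA names are made up for
/// exposition, not taken from a test):
///
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<1x8x16xf32> into tensor<8x16xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<8x16xf32> into tensor<2x8x16xf32>
///
/// folds to a single non-rank-reducing insert_slice:
///
///   %1 = tensor.insert_slice %src into %dest[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<1x8x16xf32> into tensor<2x8x16xf32>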
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};
/// Fold an expand_shape that only adds static dimensions of size `1`
/// into an insert_slice.
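///
/// Example (an illustrative sketch; the shapes and SSA names are made up for
/// exposition, not taken from a test):
///
///   %0 = tensor.expand_shape %src [[0, 1], [2]] output_shape [1, 8, 16]
///       : tensor<8x16xf32> into tensor<1x8x16xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<1x8x16xf32> into tensor<2x8x16xf32>
///
/// becomes a rank-reducing insert_slice of the expand's source:
///
///   %1 = tensor.insert_slice %src into %dest[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<8x16xf32> into tensor<2x8x16xf32>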
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away simple expansions where all added dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};
/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
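///
/// Example (an illustrative sketch; the shapes and SSA names are made up for
/// exposition, not taken from a test). The collapse only merges the leading
/// dimensions and the expand only splits the trailing dimension, so the
/// reassociations do not intersect:
///
///   %c = tensor.collapse_shape %in [[0, 1], [2]]
///       : tensor<2x4x16xf32> into tensor<8x16xf32>
///   %e = tensor.expand_shape %c [[0], [1, 2]] output_shape [8, 4, 4]
///       : tensor<8x16xf32> into tensor<8x4x4xf32>
///
/// is rewritten so that the expansion happens first:
///
///   %e = tensor.expand_shape %in [[0], [1], [2, 3]]
///       output_shape [2, 4, 4, 4]
///       : tensor<2x4x16xf32> into tensor<2x4x4x4xf32>
///   %c = tensor.collapse_shape %e [[0, 1], [2], [3]]
///       : tensor<2x4x4x4xf32> into tensor<8x4x4xf32>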
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // The reshapes are parallel to each other if, at every reassociation
    // group position, at most one of the two reshapes has a group with more
    // than one index.
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        // This group is collapsed: keep the collapse and expand each of its
        // source dimensions trivially.
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        expandIndex++;
        continue;
      }
      // This group is expanded: keep the expansion and collapse each of its
      // result dimensions trivially.
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      collapseIndex++;
    }

    // Swap the reshape order: expand first, then collapse.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};

} // namespace
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}
void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}