// Provenance (from scraped page header, kept as comments):
//   [gn build] Port 69b8cf4f0621
//   llvm-project.git / mlir / lib / Dialect / Tensor / Transforms / ReshapePatterns.cpp
//   blob 5edd7a02bc42b1c2c71cd09c0c2b19120e23b360
//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::tensor;

namespace {
18 /// Fold expand_shape(extract_slice) ops that cancel itself out.
19 struct FoldExpandOfRankReducingExtract
20 : public OpRewritePattern<ExpandShapeOp> {
21 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
23 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
24 PatternRewriter &rewriter) const override {
25 RankedTensorType resultType = expandShapeOp.getResultType();
26 auto extractSliceOp =
27 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
28 if (!extractSliceOp)
29 return failure();
30 RankedTensorType srcType = extractSliceOp.getSourceType();
32 // Only cases where the ExpandShapeOp can be folded away entirely are
33 // supported. Moreover, only simple cases where the resulting ExtractSliceOp
34 // has no rank-reduction anymore are supported at the moment.
35 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
36 srcType, extractSliceOp.getStaticOffsets(),
37 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
38 if (nonReducingExtractType != resultType)
39 return failure();
41 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
42 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
43 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
44 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
45 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
46 mixedStrides);
47 return success();
51 /// Fold collapse_shape which only removes static dimensions of size `1`
52 /// into extract_slice.
53 struct FoldUnPaddingCollapseIntoExtract
54 : public OpRewritePattern<tensor::CollapseShapeOp> {
55 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
57 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
58 PatternRewriter &rewriter) const override {
59 auto extractSliceOp =
60 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
61 // Collapse cannot be folded away with multiple users of the extract slice
62 // and it is not necessarily beneficial to only convert the collapse into
63 // another extract slice.
64 if (!extractSliceOp || !extractSliceOp->hasOneUse())
65 return failure();
67 // Only fold away simple collapse where all removed dimensions have static
68 // size `1`.
69 SliceVerificationResult res = isRankReducedType(
70 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
71 if (res != SliceVerificationResult::Success)
72 return rewriter.notifyMatchFailure(collapseShapeOp,
73 "expected unpadding collapse");
75 Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
76 extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
77 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
78 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
79 rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
80 return success();
84 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
85 template <typename OpTy>
86 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
87 using OpRewritePattern<OpTy>::OpRewritePattern;
89 LogicalResult matchAndRewrite(OpTy insertSliceOp,
90 PatternRewriter &rewriter) const override {
91 auto collapseShapeOp =
92 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
93 if (!collapseShapeOp)
94 return failure();
95 RankedTensorType srcType = collapseShapeOp.getSrcType();
97 // Only cases where the CollapseShapeOp can be folded away entirely are
98 // supported. Moreover, only simple cases where the resulting InsertSliceOp
99 // has no rank-reduction anymore are supported at the moment.
100 RankedTensorType nonReducingInsertType =
101 RankedTensorType::get(insertSliceOp.getStaticSizes(),
102 insertSliceOp.getDestType().getElementType());
103 if (nonReducingInsertType != srcType)
104 return failure();
106 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
107 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
108 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
109 rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
110 insertSliceOp.getDest(), mixedOffsets,
111 mixedSizes, mixedStrides);
112 return success();
116 /// Fold expand_shape which only adds static dimensions of size `1`
117 /// into insert_slice.
118 template <typename OpTy>
119 struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
120 using OpRewritePattern<OpTy>::OpRewritePattern;
122 LogicalResult matchAndRewrite(OpTy insertSliceOp,
123 PatternRewriter &rewriter) const override {
124 auto expandShapeOp = insertSliceOp.getSource()
125 .template getDefiningOp<tensor::ExpandShapeOp>();
126 if (!expandShapeOp)
127 return failure();
129 // Only fold away simple expansion where all added dimensions have static
130 // size `1`.
131 SliceVerificationResult res = isRankReducedType(
132 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
133 if (res != SliceVerificationResult::Success)
134 return rewriter.notifyMatchFailure(insertSliceOp,
135 "expected rank increasing expansion");
137 rewriter.modifyOpInPlace(insertSliceOp, [&]() {
138 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
140 return success();
144 /// Pattern to bubble up a tensor.expand_shape op through a producer
145 /// tensor.collapse_shape op that has non intersecting reassociations.
146 struct BubbleUpExpandThroughParallelCollapse
147 : public OpRewritePattern<tensor::ExpandShapeOp> {
148 using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
150 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
151 PatternRewriter &rewriter) const override {
152 auto collapseOp =
153 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
154 if (!collapseOp)
155 return failure();
156 auto expandReInds = expandOp.getReassociationIndices();
157 auto collapseReInds = collapseOp.getReassociationIndices();
159 // Reshapes are parallel to each other if none of the reassociation indices
160 // have greater than 1 index for both reshapes.
161 for (auto [expandReassociation, collapseReassociation] :
162 llvm::zip_equal(expandReInds, collapseReInds)) {
163 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
164 return failure();
167 // Compute new reassociation indices and expanded/collaped shapes.
168 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
169 Location loc = expandOp->getLoc();
170 SmallVector<OpFoldResult> collapseSizes =
171 tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
172 SmallVector<OpFoldResult> expandSizes(getMixedValues(
173 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
174 SmallVector<OpFoldResult> newExpandSizes;
175 int64_t index = 0, expandIndex = 0, collapseIndex = 0;
176 for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
177 if (collapseReassociation.size() != 1) {
178 ReassociationIndices newCollapseReassociation;
179 for (size_t i = 0; i < collapseReassociation.size(); ++i) {
180 newCollapseReassociation.push_back(index);
181 newExpandReInds.push_back({index++});
182 newExpandSizes.push_back(collapseSizes[collapseIndex++]);
184 newCollapseReInds.push_back(newCollapseReassociation);
185 expandIndex++;
186 continue;
188 ReassociationIndices newExpandReassociation;
189 auto expandReassociation = expandReInds[idx];
190 for (size_t i = 0; i < expandReassociation.size(); ++i) {
191 newExpandReassociation.push_back(index);
192 newCollapseReInds.push_back({index++});
193 newExpandSizes.push_back(expandSizes[expandIndex++]);
195 newExpandReInds.push_back(newExpandReassociation);
196 collapseIndex++;
199 // Swap reshape order.
200 SmallVector<Value> dynamicSizes;
201 SmallVector<int64_t> staticSizes;
202 dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
203 auto expandResultType = expandOp.getResultType().clone(staticSizes);
204 auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
205 loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
206 newExpandSizes);
207 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
208 expandOp, newExpand.getResult(), newCollapseReInds);
209 return success();
} // namespace
215 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
216 RewritePatternSet &patterns) {
217 patterns
218 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
219 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
220 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
221 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
222 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
223 patterns.getContext());
226 void mlir::tensor::populateBubbleUpExpandShapePatterns(
227 RewritePatternSet &patterns) {
228 patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());