// Extraction residue removed: the commit title "[rtsan] Remove mkfifoat
// interceptor (#116997)" and the gitweb blob hash belong to the web view, not
// to this file. Source: mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
1 //===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Linalg/IR/Linalg.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
12 #include "mlir/Dialect/Utils/IndexingUtils.h"
13 #include "mlir/IR/PatternMatch.h"
15 namespace mlir {
16 namespace tensor {
17 namespace {
19 /// Returns the number of shape sizes that is either dynamic or greater than 1.
20 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
21 return llvm::count_if(
22 shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
25 /// Returns success() if there is only 1 dimension size in non-packed domain
26 /// being greater than 1 and packing only happens on the dimension.
27 /// Note: this method should only be used by pack/unpack to reshape conversion.
28 /// It assumes that non-unit inner tile size must be used by the non-unit
29 /// dimension.
30 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
31 ArrayRef<int64_t> srcShape,
32 ArrayRef<int64_t> innerPackTileSize) {
33 if (getNumGtOneDims(srcShape) > 1) {
34 return rewriter.notifyMatchFailure(
35 op, "expects non-packed domain to have at most one non-unit dims");
37 // Non-unit inner tile size must be used by the non-unit dimension. If not, it
38 // will faill on getting reassociation maps.
39 if (getNumGtOneDims(innerPackTileSize) > 1) {
40 return rewriter.notifyMatchFailure(
41 op, "expects at most one non-unit inner tiles");
43 return success();
46 // If the `linalgOp` represents a transpose, return the permutation vector for
47 // the transpose. Otherwise, return failure.
48 static FailureOr<SmallVector<int64_t>>
49 getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
50 if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
51 return SmallVector<int64_t>(transposeOp.getPermutation());
52 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
53 return failure();
55 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
56 return failure();
57 auto mapRange = linalgOp.getIndexingMapsArray();
58 if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
59 mapRange.front() == mapRange.back()) {
60 return failure();
62 if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
63 return failure();
64 AffineMap outMap = mapRange.back();
65 AffineMap inMap = mapRange.front();
66 // To get the permutation, look at each output index and find which
67 // dimension in the input we're reading from for that index.
68 return llvm::map_to_vector(outMap.getResults(),
69 [&](AffineExpr expr) -> int64_t {
70 return *inMap.getResultPosition(expr);
71 });
74 /// Packing one-dimensional tensor can be expressed as an expand shape op.
75 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
76 using OpRewritePattern<PackOp>::OpRewritePattern;
78 FailureOr<Value>
79 insertExpand(RewriterBase &rewriter, Location loc, Value operand,
80 Type newOperandType,
81 ArrayRef<ReassociationIndices> reassociation) const {
82 if (operand.getType() == newOperandType)
83 return operand;
84 return rewriter
85 .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
86 reassociation)
87 .getResult();
90 /// Returns success() if it is only packing on the innermost dimension.
91 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
92 PackOp packOp) const {
93 auto outerDimsPerm = packOp.getOuterDimsPerm();
94 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
95 return rewriter.notifyMatchFailure(
96 packOp,
97 "expects outer_dims_perm is empty or an identity permutation");
100 int64_t srcRank = packOp.getSourceRank();
101 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
102 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
103 return rewriter.notifyMatchFailure(
104 packOp, "expects packing at the innermost dimension");
106 return success();
109 LogicalResult matchAndRewrite(PackOp packOp,
110 PatternRewriter &rewriter) const override {
111 if (packOp.getPaddingValue())
112 return rewriter.notifyMatchFailure(packOp, "expects no padding value");
114 RankedTensorType sourceType = packOp.getSourceType();
115 if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
116 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
117 packOp.getStaticTiles())) &&
118 !packOp.isLikePad()) {
119 return failure();
122 RankedTensorType destType = packOp.getDestType();
123 auto reassociation =
124 getReassociationIndicesForReshape(sourceType, destType);
125 if (!reassociation)
126 return failure();
127 FailureOr<Value> expanded =
128 insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
129 *reassociation);
130 if (failed(expanded)) {
131 return rewriter.notifyMatchFailure(
132 packOp, "unable to expand source of tensor.pack");
134 rewriter.replaceOp(packOp, *expanded);
135 return success();
139 struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
140 using OpRewritePattern<UnPackOp>::OpRewritePattern;
142 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
143 Type newOperandType, ArrayAttr reassociation) const {
144 if (operand.getType() == newOperandType)
145 return operand;
146 return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
147 operand, reassociation);
150 /// Returns success() if it is unpacking on the innermost dimension.
151 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
152 UnPackOp unpackOp) const {
153 auto outerDimsPerm = unpackOp.getOuterDimsPerm();
154 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
155 return rewriter.notifyMatchFailure(
156 unpackOp,
157 "expects outer_dims_perm is empty or an identity permutation");
160 RankedTensorType sourceType = unpackOp.getSourceType();
161 RankedTensorType destType = unpackOp.getDestType();
162 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
163 return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
165 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
166 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
167 return rewriter.notifyMatchFailure(
168 unpackOp, "expects unpacking on the innermost dimension");
171 return success();
174 LogicalResult matchAndRewrite(UnPackOp unpackOp,
175 PatternRewriter &rewriter) const override {
176 RankedTensorType destType = unpackOp.getDestType();
177 if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
178 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
179 unpackOp.getStaticTiles())) &&
180 !unpackOp.isLikeUnPad()) {
181 return failure();
184 RankedTensorType sourceType = unpackOp.getSourceType();
185 auto reassociation =
186 getReassociationIndicesForReshape(sourceType, destType);
187 if (!reassociation)
188 return failure();
189 Value collapsed = insertCollapse(
190 rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
191 getReassociationIndicesAttribute(rewriter, *reassociation));
192 rewriter.replaceOp(unpackOp, collapsed);
193 return success();
197 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198 /// the pad op has zero low paddings, or if `pack` has no padding values.
199 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
200 using OpRewritePattern<PackOp>::OpRewritePattern;
202 LogicalResult matchAndRewrite(PackOp packOp,
203 PatternRewriter &rewriter) const override {
204 auto padOp = packOp.getSource().getDefiningOp<PadOp>();
206 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
207 return failure();
209 Value constantPaddingValue = padOp.getConstantPaddingValue();
210 if (!constantPaddingValue)
211 return failure();
213 if (auto paddingValue = packOp.getPaddingValue())
214 if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
215 return failure();
217 rewriter.replaceOpWithNewOp<PackOp>(
218 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
219 packOp.getMixedTiles(), constantPaddingValue,
220 packOp.getOuterDimsPerm());
221 return success();
225 /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
226 /// has extract_slice semantics.
227 struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
228 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
230 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
231 PatternRewriter &rewriter) const override {
232 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
233 if (!unpackOp)
234 return failure();
236 if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
237 return rewriter.notifyMatchFailure(
238 sliceOp, "rank-reduced folding is not supported");
241 // Check all offsets are zeros, and all strides are ones.
242 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
243 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
244 return rewriter.notifyMatchFailure(
245 sliceOp, "expects offsets to be 0s and strides to be 1s");
248 // Create a new empty output tensor.
249 Type elementType = unpackOp.getDestType().getElementType();
250 Value output = rewriter.create<EmptyOp>(
251 sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
252 rewriter.replaceOpWithNewOp<UnPackOp>(
253 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
254 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
255 return success();
259 // Applies 'permutation' on 'inVec' and stores the result in resVec.
260 // 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
261 // `rank` sets the boundary for permutation i.e., the permutation dim can't be
262 // greater than the rank specified. If it's so then return false.
263 // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
264 // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
265 // not allowed since `3` exceeds the value of the rank in the given range.
266 static bool checkAndPermute(ArrayRef<int64_t> permutation,
267 ArrayRef<int64_t> inVec,
268 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
270 for (unsigned int i = 0; i < rank; ++i) {
271 int64_t remappedPosition = permutation[i];
272 if (remappedPosition >= rank)
273 return false;
274 if (!inVec.empty())
275 remappedPosition = inVec[remappedPosition];
276 resVec.push_back(remappedPosition);
279 return true;
282 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
283 /// semantics.
284 struct FoldProducerPackWithConsumerLinalgTransposeOp
285 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
286 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
288 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
289 PatternRewriter &rewriter) const override {
290 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
292 if (!packOp)
293 return failure();
295 FailureOr<SmallVector<int64_t>> maybePerm =
296 getTransposeOpPermutation(linalgOp);
297 if (failed(maybePerm))
298 return failure();
300 auto innerDimsPos = packOp.getInnerDimsPos();
301 auto mixedInnerTiles = packOp.getMixedTiles();
302 auto outerDimsPerm = packOp.getOuterDimsPerm();
303 auto transposePerm = maybePerm.value();
304 SmallVector<int64_t> newOuterDimsPermVec;
305 SmallVector<int64_t> newInnerDimsPosVec;
306 SmallVector<OpFoldResult> newMixedInnerTilesVec;
307 int64_t srcRank = packOp.getSourceRank();
309 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
310 srcRank))
311 return rewriter.notifyMatchFailure(
312 linalgOp,
313 "Cannot fold in tensor.pack if a tile dimension was transposed "
314 "with a non-tile dimension in linalg.transpose.");
316 // Process transpose operation for tiled inner dimensions
317 for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
318 int64_t remappedPosition = transposePerm[i] - srcRank;
319 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
320 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
323 Value output = packOp.createDestinationTensor(
324 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
325 newInnerDimsPosVec, newOuterDimsPermVec);
327 rewriter.replaceOpWithNewOp<PackOp>(
328 linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
329 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
331 return success();
335 /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
336 /// semantics.
337 struct FoldConsumerPackWithProducerLinalgTransposeOp
338 : public OpRewritePattern<PackOp> {
339 using OpRewritePattern<PackOp>::OpRewritePattern;
341 LogicalResult matchAndRewrite(PackOp packOp,
342 PatternRewriter &rewriter) const override {
343 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
344 if (!linalgOp)
345 return failure();
347 FailureOr<SmallVector<int64_t>> maybePerm =
348 getTransposeOpPermutation(linalgOp);
349 if (failed(maybePerm))
350 return failure();
352 auto transposePermutation = maybePerm.value();
353 auto outerDimsPerm = packOp.getOuterDimsPerm();
354 auto innerDimsPos = packOp.getInnerDimsPos();
355 SmallVector<int64_t> newInnerDimsPosVec;
356 SmallVector<int64_t> newOuterDimsPermVec =
357 llvm::to_vector(transposePermutation);
359 if (!outerDimsPerm.empty())
360 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
362 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
363 // permutation rank won't necessarily be equal in all cases.
364 for (auto dim : innerDimsPos)
365 newInnerDimsPosVec.push_back(transposePermutation[dim]);
367 Value output = packOp.createDestinationTensor(
368 rewriter, packOp.getLoc(), linalgOp->getOperand(0),
369 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
371 rewriter.replaceOpWithNewOp<PackOp>(
372 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
373 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
375 return success();
379 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
380 /// transpose semantics.
381 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
382 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
383 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
385 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
386 PatternRewriter &rewriter) const override {
387 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
389 if (!unPackOp)
390 return failure();
392 FailureOr<SmallVector<int64_t>> maybePerm =
393 getTransposeOpPermutation(linalgOp);
394 if (failed(maybePerm))
395 return failure();
397 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
398 auto innerDimsPos = unPackOp.getInnerDimsPos();
399 SmallVector<int64_t> newInnerDimsPosVec;
400 SmallVector<int64_t> newOuterDimsPermVec =
401 invertPermutationVector(maybePerm.value());
403 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
404 // permutation rank won't necessarily be equal in all cases.
405 for (auto dim : innerDimsPos)
406 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
408 if (!outerDimsPerm.empty())
409 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
411 // Reuse the destination of the transpose op.
412 rewriter.replaceOpWithNewOp<UnPackOp>(
413 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
414 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
416 return success();
420 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
421 /// transpose semantics.
422 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
423 : public OpRewritePattern<UnPackOp> {
424 using OpRewritePattern<UnPackOp>::OpRewritePattern;
426 LogicalResult matchAndRewrite(UnPackOp unPackOp,
427 PatternRewriter &rewriter) const override {
428 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
429 if (!linalgOp)
430 return failure();
432 FailureOr<SmallVector<int64_t>> maybePerm =
433 getTransposeOpPermutation(linalgOp);
434 if (failed(maybePerm))
435 return failure();
437 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
438 if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
439 return failure();
442 SmallVector<int64_t> inverseTransposePerm =
443 invertPermutationVector(maybePerm.value());
444 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
445 auto innerDimsPos = unPackOp.getInnerDimsPos();
446 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
447 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
448 SmallVector<int64_t> newOuterDimsPermVec;
449 SmallVector<int64_t> newInnerDimsPosVec;
450 SmallVector<OpFoldResult> newMixedInnerTilesVec;
451 if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
452 newOuterDimsPermVec, destRank))
453 return rewriter.notifyMatchFailure(
454 unPackOp,
455 "Cannot fold in tensor.unpack if a tile dimension was transposed "
456 "with a non-tile dimension in linalg.transpose.");
458 // Process transpose operation for tiled inner dimensions
459 for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
460 int64_t remappedPosition = inverseTransposePerm[i] - destRank;
461 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
462 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
465 auto elemType =
466 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
467 Value output = rewriter.create<tensor::EmptyOp>(
468 unPackOp->getLoc(), unpackOpResultDims[0], elemType);
470 rewriter.replaceOpWithNewOp<UnPackOp>(
471 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
472 newMixedInnerTilesVec, newOuterDimsPermVec);
474 return success();
477 } // namespace
479 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
480 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
481 FoldProducerPackWithConsumerLinalgTransposeOp,
482 FoldConsumerPackWithProducerLinalgTransposeOp,
483 FoldConsumerUnPackWithProducerLinalgTransposeOp,
484 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
485 patterns.getContext());
488 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
489 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
490 patterns.getContext());
493 } // namespace tensor
494 } // namespace mlir