//===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Tensor transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct TestTensorTransforms
    : public PassWrapper<TestTensorTransforms, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms)

  TestTensorTransforms() = default;
  TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
                    transform::TransformDialect>();
  }

  StringRef getArgument() const final {
    return "test-tensor-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Tensor transformation patterns by applying them greedily.";
  }

  void runOnOperation() override;

  Option<bool> testFoldConstantExtractSlice{
      *this, "test-fold-constant-extract-slice",
      llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
      llvm::cl::init(false)};

  Option<bool> testFoldConsecutiveInsertExtractSlice{
      *this, "test-fold-consecutive-insert-extract-slice",
      llvm::cl::desc(
          "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
      llvm::cl::init(false)};

  Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
      *this, "test-rewrite-extract-slice-from-collapse-shape",
      llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
                     "with loop nest"),
      llvm::cl::init(false)};

  Option<bool> testDropRedundantInsertSliceRankExpansion{
      *this, "test-drop-redundant-insert-slice-rank-expansion",
      llvm::cl::desc("Test dropping redundant insert_slice rank expansions"),
      llvm::cl::init(false)};

  Option<bool> testReassociativeReshapeFolding{
      *this, "test-reassociative-reshape-folding",
      llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
      llvm::cl::init(false)};

  Option<bool> testBubbleUpExpandShapePatterns{
      *this, "test-expand-shape-bubbling",
      llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
      llvm::cl::init(false)};

  Option<bool> testFoldIntoPackAndUnpack{
      *this, "test-fold-into-pack-and-unpack",
      llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
      llvm::cl::init(false)};

  Option<bool> useForeach{
      *this, "use-foreach",
      llvm::cl::desc(
          "Use the scf.forall operation when generating loop nests for "
          "the extract_slice of collapse_shape pattern"),
      llvm::cl::init(false)};

  Option<bool> testSimplifyPackUnpackPatterns{
      *this, "test-simplify-pack-unpack-patterns",
      llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"),
      llvm::cl::init(false)};

  Option<bool> testTrackingListener{
      *this, "test-tracking-listener",
      llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
      llvm::cl::init(false)};
};
} // namespace

static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyBubbleUpExpandShapePatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateBubbleUpExpandShapePatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
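
/// Constant folding of `arith.constant` + `tensor.extract_slice` is gated by a
/// control function: the slice source must have a single use and the resulting
/// slice must hold at most 1024 elements (see the lambda below).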
static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::ControlConstantExtractSliceFusionFn controlFn =
      [](tensor::ExtractSliceOp op) {
        if (!op.getSource().hasOneUse())
          return false;

        auto resultType = cast<ShapedType>(op.getResult().getType());
        constexpr int64_t kConstantFoldingMaxNumElements = 1024;
        return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
      };

  tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void
applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateSimplifyPackAndUnpackPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
/// stitches together the desired tile from slices of the source of the
/// collapse shape op.
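///
/// Illustrative sketch (names and shapes are made up, not taken from the
/// actual test cases):
///
///   %c = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<?x?x?xf32> into tensor<?x?xf32>
///   %s = tensor.extract_slice %c[%o0, %o1] [%sz0, %sz1] [1, 1]
///       : tensor<?x?xf32> to tensor<?x?xf32>
///
/// becomes a loop nest (or an scf.forall) that extracts slices directly from
/// %src and inserts them into a destination tensor of the slice's shape.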
struct RewriteExtractSliceFromCollapseShapeBase
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
      : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}

  /// Emit a loop or gather operation that uses `helper` to take each point in
  /// the parallel iteration space bounds, extract a slice from the source
  /// tensor and insert it into `dest`. For examples, see below for `scf.for`
  /// and `scf.foreach`.
  virtual LogicalResult
  emitReplacement(tensor::ExtractSliceOp op, Value dest,
                  tensor::ExtractSliceFromCollapseHelper &helper,
                  PatternRewriter &rewriter) const = 0;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return rewriter.notifyMatchFailure(
          op, "producer is not a tensor.collapse_shape op");

    // Try to simplify the collapse shape using a rank-reducing slice, if
    // possible.
    FailureOr<Operation *> simplifiedCollapseShapeResult =
        tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
                                                                  rewriter);
    if (succeeded(simplifiedCollapseShapeResult)) {
      auto newCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
      // The collapse shape op might have been simplified away, so we can just
      // return.
      if (!newCollapseOp)
        return success();
      collapseOp = newCollapseOp;
    }

    // Materialize the output shape values of the slice operation.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(rewriter, op, reifiedShapes)))
      return rewriter.notifyMatchFailure(op, "failed to reify result shapes");

    // Create the destination tensor using the above values.
    Type elementType = op.getSourceType().getElementType();
    SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
    Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
                                                  elementType);

    // Calculate the parameters for the tile loop nest.
    FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
        tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
                                                       op);
    if (failed(params))
      return rewriter.notifyMatchFailure(
          op, "could not calculate tiling parameters");
    return emitReplacement(op, dest, *params, rewriter);
  }
};
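
/// Replacement strategy that emits an scf.for loop nest over the tiled
/// dimensions, inserting each extracted tile into the destination with
/// tensor.insert_slice.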
struct RewriteExtractSliceFromCollapseShapeUsingScfFor
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value> lbs(numTiledDims, zero);
    SmallVector<Value> steps(numTiledDims, one);

    scf::LoopNest nest = scf::buildLoopNest(
        rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
            ValueRange iterArgs) -> scf::ValueVector {
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);

          // Insert the slice into the destination.
          return {nestedBuilder.create<tensor::InsertSliceOp>(
              loc, tile, iterArgs[0], insertParams)};
        });
    rewriter.replaceOp(op, nest.results);
    return success();
  }
};
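
/// Replacement strategy that emits a single scf.forall op over the tiled
/// dimensions, writing each extracted tile into the shared output with
/// tensor.parallel_insert_slice inside the scf.in_parallel terminator.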
struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto forallOp = rewriter.create<scf::ForallOp>(
        loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
        /*outputs=*/dest,
        /*mapping=*/std::nullopt,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
          unsigned numThreadIdRegionArgs =
              helper.getIterationSpaceSizes().size();
          unsigned numOutputRegionArgs =
              regionArgs.size() - numThreadIdRegionArgs;
          ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
          ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
          assert(outputArgs.size() == 1 &&
                 "there should only be one output region argument");
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
          // Insert the slice into the destination.
          auto term = nestedBuilder.create<scf::InParallelOp>(loc);
          nestedBuilder.setInsertionPointToStart(term.getBody());
          nestedBuilder.create<tensor::ParallelInsertSliceOp>(
              loc, tile, outputArgs[0], insertParams);
        });
    rewriter.replaceOp(op, forallOp->getResult(0));
    return success();
  }
};
} // namespace

static LogicalResult
applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
                                             bool useForeach) {
  RewritePatternSet patterns(rootOp->getContext());
  if (useForeach)
    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
        rootOp->getContext());
  else
    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
        rootOp->getContext());
  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

namespace {
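/// Tracking listener for testing only: it exposes the replacement lookup
/// (`findReplacementOp`) through a public wrapper so it can be driven directly
/// from this pass.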
class DummyTrackingListener : public transform::TrackingListener {
public:
  using transform::TrackingListener::TrackingListener;

  // Expose `findReplacementOp` as a public function, so that it can be tested.
  Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
    Operation *replacementOp;
    if (!findReplacementOp(replacementOp, op, newValues).succeeded())
      return nullptr;
    return replacementOp;
  }
};
} // namespace
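
/// Test driver for the tracking listener: the payload IR is expected to tag
/// exactly one op with a "replaced" unit attribute and, for each of its
/// results, some op with an integer "replacement_<i>" attribute selecting the
/// result that acts as the replacement value. The listener is then queried for
/// the replacement op and a remark or error is emitted.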
static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
  // Find the op that is expected to be replaced.
  Operation *replaced = nullptr;
  WalkResult status = rootOp->walk([&](Operation *op) {
    if (op->hasAttr("replaced")) {
      if (replaced) {
        op->emitError("only one 'replaced' op is allowed per test case");
        replaced->emitRemark("other 'replaced' op");
        return WalkResult::interrupt();
      }
      replaced = op;
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();
  if (!replaced) {
    rootOp->emitError("could not find 'replaced' op");
    return failure();
  }

  // Find replacements.
  SmallVector<Value> replacements(replaced->getNumResults(), Value());
  status = rootOp->walk([&](Operation *op) {
    for (int64_t i = 0; i < replaced->getNumResults(); ++i) {
      if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" +
                                                     std::to_string(i))) {
        if (replacements[i]) {
          op->emitError("only one 'replacement_" + std::to_string(i) +
                        "' is allowed per test case");
          replacements[i].getDefiningOp()->emitRemark("other 'replacement_" +
                                                      std::to_string(i) + "'");
          return WalkResult::interrupt();
        }
        replacements[i] = op->getResult(attr.getInt());
      }
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();

  if (!llvm::all_of(replacements,
                    [](Value v) { return static_cast<bool>(v); })) {
    replaced->emitError("insufficient replacement values");
    return failure();
  }

  // Find the replacement op (if any) and emit a remark/error.
  transform::TransformState transformState =
      transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
                                                      /*payloadRoot=*/nullptr);
  MLIRContext *context = rootOp->getContext();
  OpBuilder builder(context);
  OwningOpRef<transform::NamedSequenceOp> transformOp =
      builder.create<transform::NamedSequenceOp>(
          rootOp->getLoc(),
          /*sym_name=*/"test_sequence",
          /*function_type=*/
          TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
          /*sym_visibility*/ StringAttr::get(context, "public"),
          /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
          /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
  DummyTrackingListener listener(transformState, transformOp.get());
  Operation *replacement = listener.getReplacementOp(replaced, replacements);
  if (!replacement) {
    replaced->emitError("listener could not find replacement op");
    return failure();
  }

  replacement->emitRemark("replacement found");
  return success();
}
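
// Each enabled test flag applies its pattern set greedily and independently.
// Only the collapse_shape rewrite and the tracking-listener test can signal
// pass failure.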
void TestTensorTransforms::runOnOperation() {
  Operation *rootOp = getOperation();
  if (testSimplifyPackUnpackPatterns)
    applySimplifyPackUnpackPatterns(rootOp);
  if (testFoldConstantExtractSlice)
    applyFoldConstantExtractSlicePatterns(rootOp);
  if (testFoldConsecutiveInsertExtractSlice)
    applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
  if (testDropRedundantInsertSliceRankExpansion)
    applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
  if (testReassociativeReshapeFolding)
    applyReassociativeReshapeFoldingPatterns(rootOp);
  if (testBubbleUpExpandShapePatterns)
    applyBubbleUpExpandShapePatterns(rootOp);
  if (testFoldIntoPackAndUnpack)
    applyFoldIntoPackAndUnpackPatterns(rootOp);
  if (testRewriteExtractSliceWithTiledCollapseShape) {
    if (failed(
            applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
      return signalPassFailure();
  }
  if (testTrackingListener)
    if (failed(testTrackingListenerReplacements(rootOp)))
      return signalPassFailure();
}

namespace mlir {
namespace test {
void registerTestTensorTransforms() {
  PassRegistration<TestTensorTransforms>();
}
} // namespace test
} // namespace mlir