1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements logic for testing Linalg fusion patterns.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "mlir/Transforms/Passes.h"
23 using namespace mlir::linalg
;
25 static LogicalResult
fuseLinalgOpsGreedily(func::FuncOp f
) {
28 // Save original Linalg ops, we only want to make a pass over those.
29 SmallVector
<LinalgOp
, 8> linalgOps
;
30 f
.walk([&](LinalgOp op
) {
31 // TODO: support multi-results.
32 if (op
->getNumResults() <= 1)
33 linalgOps
.push_back(op
);
36 // Tile and Fuse for tensors inputs (TODO: all tensor operands).
38 for (LinalgOp linalgOp
: llvm::reverse(linalgOps
)) {
39 for (OpOperand
&opOperand
: linalgOp
->getOpOperands()) {
40 if (isa
<MemRefType
>(opOperand
.get().getType()))
42 if (isa
<RankedTensorType
>(opOperand
.get().getType())) {
43 // Tile and Fuse tensor input.
44 if (opOperand
.getOperandNumber() >= linalgOp
.getNumDpsInputs())
46 auto info
= fuseProducerOfTensor(b
, opOperand
);
49 auto *originalOp
= info
->originalProducer
.getOperation();
50 auto *originalOpInLinalgOpsVector
=
51 std::find(linalgOps
.begin(), linalgOps
.end(), originalOp
);
52 *originalOpInLinalgOpsVector
= info
->fusedProducer
;
53 // Don't mark for erasure in the tensor case, let DCE handle this.
59 return changed
? success() : failure();
63 struct TestLinalgGreedyFusion
64 : public PassWrapper
<TestLinalgGreedyFusion
, OperationPass
<func::FuncOp
>> {
65 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion
)
67 void getDependentDialects(DialectRegistry
®istry
) const override
{
68 registry
.insert
<affine::AffineDialect
, linalg::LinalgDialect
,
69 memref::MemRefDialect
, scf::SCFDialect
>();
71 StringRef
getArgument() const final
{ return "test-linalg-greedy-fusion"; }
72 StringRef
getDescription() const final
{
73 return "Test Linalg fusion by applying a greedy test transformation.";
75 void runOnOperation() override
{
76 MLIRContext
*context
= &getContext();
77 RewritePatternSet patterns
=
78 linalg::getLinalgTilingCanonicalizationPatterns(context
);
79 patterns
.add
<ExtractSliceOfPadTensorSwapPattern
>(context
);
80 scf::populateSCFForLoopCanonicalizationPatterns(patterns
);
81 FrozenRewritePatternSet
frozenPatterns(std::move(patterns
));
82 OpPassManager
pm(func::FuncOp::getOperationName());
83 pm
.addPass(createLoopInvariantCodeMotionPass());
84 pm
.addPass(createCanonicalizerPass());
85 pm
.addPass(createCSEPass());
87 (void)applyPatternsGreedily(getOperation(), frozenPatterns
);
88 if (failed(runPipeline(pm
, getOperation())))
89 this->signalPassFailure();
90 } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
97 void registerTestLinalgGreedyFusion() {
98 PassRegistration
<TestLinalgGreedyFusion
>();