1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly by exercising the options of the Linalg fusion patterns.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/TypeSwitch.h"
24 static void addOperands(Operation
*op
, SetVector
<Value
> &operandSet
) {
27 TypeSwitch
<Operation
*, void>(op
)
28 .Case
<linalg::LinalgOp
>([&](linalg::LinalgOp linalgOp
) {
29 SmallVector
<Value
> inputOperands
= linalgOp
.getDpsInputs();
30 operandSet
.insert(inputOperands
.begin(), inputOperands
.end());
32 .Default([&](Operation
*operation
) {
33 operandSet
.insert(operation
->operand_begin(), operation
->operand_end());
37 template <int limit
= 3>
38 static bool setFusedOpOperandLimit(OpOperand
*fusedOperand
) {
39 Operation
*producer
= fusedOperand
->get().getDefiningOp();
43 Operation
*consumer
= fusedOperand
->getOwner();
44 SetVector
<Value
> fusedOpOperands
;
45 if (producer
->getNumResults() != 1)
47 addOperands(consumer
, fusedOpOperands
);
48 fusedOpOperands
.remove(producer
->getResult(0));
49 addOperands(producer
, fusedOpOperands
);
50 return fusedOpOperands
.size() <= limit
;
55 /// Pattern to test fusion of producer with consumer, even if producer has
57 struct TestMultiUseProducerFusion
: public OpRewritePattern
<linalg::GenericOp
> {
58 using OpRewritePattern
<linalg::GenericOp
>::OpRewritePattern
;
60 LogicalResult
matchAndRewrite(linalg::GenericOp genericOp
,
61 PatternRewriter
&rewriter
) const override
{
62 OpOperand
*fusableOperand
= nullptr;
63 for (OpOperand
&operand
: genericOp
->getOpOperands()) {
64 if (linalg::areElementwiseOpsFusable(&operand
)) {
65 fusableOperand
= &operand
;
69 if (!fusableOperand
) {
70 return rewriter
.notifyMatchFailure(genericOp
, "no fusable operand found");
72 std::optional
<linalg::ElementwiseOpFusionResult
> fusionResult
=
73 linalg::fuseElementwiseOps(rewriter
, fusableOperand
);
75 return rewriter
.notifyMatchFailure(genericOp
, "fusion failed");
76 for (auto [origValue
, replacement
] : fusionResult
->replacements
) {
77 rewriter
.replaceUsesWithIf(origValue
, replacement
, [&](OpOperand
&use
) {
78 return use
.getOwner() != genericOp
.getOperation();
81 rewriter
.eraseOp(genericOp
);
86 struct TestLinalgElementwiseFusion
87 : public PassWrapper
<TestLinalgElementwiseFusion
,
88 OperationPass
<func::FuncOp
>> {
89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion
)
91 TestLinalgElementwiseFusion() = default;
92 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion
&pass
)
93 : PassWrapper(pass
) {}
94 void getDependentDialects(DialectRegistry
®istry
) const override
{
95 registry
.insert
<affine::AffineDialect
, linalg::LinalgDialect
,
96 memref::MemRefDialect
, tensor::TensorDialect
>();
98 StringRef
getArgument() const final
{
99 return "test-linalg-elementwise-fusion-patterns";
101 StringRef
getDescription() const final
{
102 return "Test Linalg element wise operation fusion patterns";
105 Option
<bool> fuseGenericOps
{
106 *this, "fuse-generic-ops",
107 llvm::cl::desc("Test fusion of generic operations."),
108 llvm::cl::init(false)};
110 Option
<bool> fuseGenericOpsControl
{
111 *this, "fuse-generic-ops-control",
113 "Test fusion of generic operations with a control function."),
114 llvm::cl::init(false)};
116 Option
<bool> fuseWithReshapeByExpansion
{
117 *this, "fuse-with-reshape-by-expansion",
119 "Test fusion of generic operations with reshape by expansion"),
120 llvm::cl::init(false)};
122 Option
<bool> controlFuseByExpansion
{
123 *this, "control-fusion-by-expansion",
125 "Test controlling fusion of reshape with generic op by expansion"),
126 llvm::cl::init(false)};
128 Option
<bool> fuseWithReshapeByCollapsing
{
129 *this, "fuse-with-reshape-by-collapsing",
130 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
131 "collapse the iteration space of the consumer"),
132 llvm::cl::init(false)};
134 Option
<bool> fuseWithReshapeByCollapsingWithControlFn
{
135 *this, "fuse-with-reshape-by-collapsing-control",
136 llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
137 "fusion patterns that "
138 "collapse the iteration space of the consumer"),
139 llvm::cl::init(false)};
141 Option
<bool> fuseMultiUseProducer
{
142 *this, "fuse-multiuse-producer",
143 llvm::cl::desc("Test fusion of producer ops with multiple uses"),
144 llvm::cl::init(false)};
146 ListOption
<int64_t> collapseDimensions
{
147 *this, "collapse-dimensions-control",
148 llvm::cl::desc("Test controlling dimension collapse pattern")};
150 void runOnOperation() override
{
151 MLIRContext
*context
= &this->getContext();
152 func::FuncOp funcOp
= this->getOperation();
154 if (fuseGenericOps
) {
155 RewritePatternSet
fusionPatterns(context
);
156 auto controlFn
= [](OpOperand
*operand
) { return true; };
157 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns
, controlFn
);
158 if (failed(applyPatternsGreedily(funcOp
.getBody(),
159 std::move(fusionPatterns
))))
160 return signalPassFailure();
164 if (fuseGenericOpsControl
) {
165 RewritePatternSet
fusionPatterns(context
);
166 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns
,
167 setFusedOpOperandLimit
<4>);
169 if (failed(applyPatternsGreedily(funcOp
.getBody(),
170 std::move(fusionPatterns
))))
171 return signalPassFailure();
175 if (fuseWithReshapeByExpansion
) {
176 RewritePatternSet
fusionPatterns(context
);
177 linalg::populateFoldReshapeOpsByExpansionPatterns(
178 fusionPatterns
, [](OpOperand
* /*fusedOperand*/) { return true; });
179 if (failed(applyPatternsGreedily(funcOp
.getBody(),
180 std::move(fusionPatterns
))))
181 return signalPassFailure();
185 if (controlFuseByExpansion
) {
186 RewritePatternSet
fusionPatterns(context
);
188 linalg::ControlFusionFn controlReshapeFusionFn
=
189 [](OpOperand
*fusedOperand
) {
190 Operation
*producer
= fusedOperand
->get().getDefiningOp();
194 if (auto collapseOp
= dyn_cast
<tensor::CollapseShapeOp
>(producer
)) {
195 if (!collapseOp
.getSrc().getDefiningOp
<linalg::LinalgOp
>()) {
200 Operation
*consumer
= fusedOperand
->getOwner();
201 if (auto expandOp
= dyn_cast
<tensor::ExpandShapeOp
>(consumer
)) {
202 if (expandOp
->hasOneUse()) {
203 OpOperand
&use
= *expandOp
->getUses().begin();
204 auto linalgOp
= dyn_cast
<linalg::LinalgOp
>(use
.getOwner());
205 if (linalgOp
&& linalgOp
.isDpsInit(&use
))
213 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns
,
214 controlReshapeFusionFn
);
215 if (failed(applyPatternsGreedily(funcOp
.getBody(),
216 std::move(fusionPatterns
))))
217 return signalPassFailure();
221 if (fuseWithReshapeByCollapsing
) {
222 RewritePatternSet
patterns(context
);
223 linalg::populateFoldReshapeOpsByCollapsingPatterns(
224 patterns
, [](OpOperand
* /*fusedOperand */) { return true; });
225 if (failed(applyPatternsGreedily(funcOp
.getBody(), std::move(patterns
))))
226 return signalPassFailure();
230 if (fuseWithReshapeByCollapsingWithControlFn
) {
231 RewritePatternSet
patterns(context
);
232 linalg::ControlFusionFn controlFn
= [](OpOperand
*fusedOperand
) -> bool {
233 Operation
*producer
= fusedOperand
->get().getDefiningOp();
234 if (isa
<tensor::ExpandShapeOp
>(producer
)) {
235 // Skip fusing the first operand.
236 return fusedOperand
->getOperandNumber();
240 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns
, controlFn
);
241 if (failed(applyPatternsGreedily(funcOp
.getBody(), std::move(patterns
))))
242 return signalPassFailure();
246 if (fuseMultiUseProducer
) {
247 RewritePatternSet
patterns(context
);
248 patterns
.insert
<TestMultiUseProducerFusion
>(context
);
249 if (failed(applyPatternsGreedily(funcOp
.getBody(), std::move(patterns
))))
250 return signalPassFailure();
254 if (!collapseDimensions
.empty()) {
255 SmallVector
<int64_t, 2> dims(collapseDimensions
.begin(),
256 collapseDimensions
.end());
257 linalg::GetCollapsableDimensionsFn collapseFn
=
258 [&dims
](linalg::LinalgOp op
) {
259 SmallVector
<ReassociationIndices
> reassociations
;
260 reassociations
.emplace_back(dims
);
261 return reassociations
;
263 RewritePatternSet
patterns(context
);
264 linalg::populateCollapseDimensions(patterns
, collapseFn
);
265 if (failed(applyPatternsGreedily(funcOp
.getBody(), std::move(patterns
))))
266 return signalPassFailure();
276 void registerTestLinalgElementwiseFusion() {
277 PassRegistration
<TestLinalgElementwiseFusion
>();