Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Linalg / TestLinalgElementwiseFusion.cpp
blob7f68f4aec3a10c3f946da279796f9758e72c8349
1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly through the pass options that select and control the fusion
// patterns.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/TypeSwitch.h"
22 using namespace mlir;
24 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
25 if (!op)
26 return;
27 TypeSwitch<Operation *, void>(op)
28 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
29 SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
30 operandSet.insert(inputOperands.begin(), inputOperands.end());
32 .Default([&](Operation *operation) {
33 operandSet.insert(operation->operand_begin(), operation->operand_end());
34 });
37 template <int limit = 3>
38 static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
39 Operation *producer = fusedOperand->get().getDefiningOp();
40 if (!producer)
41 return false;
43 Operation *consumer = fusedOperand->getOwner();
44 SetVector<Value> fusedOpOperands;
45 if (producer->getNumResults() != 1)
46 return false;
47 addOperands(consumer, fusedOpOperands);
48 fusedOpOperands.remove(producer->getResult(0));
49 addOperands(producer, fusedOpOperands);
50 return fusedOpOperands.size() <= limit;
53 namespace {
55 /// Pattern to test fusion of producer with consumer, even if producer has
56 /// multiple uses.
57 struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
58 using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
60 LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
61 PatternRewriter &rewriter) const override {
62 OpOperand *fusableOperand = nullptr;
63 for (OpOperand &operand : genericOp->getOpOperands()) {
64 if (linalg::areElementwiseOpsFusable(&operand)) {
65 fusableOperand = &operand;
66 break;
69 if (!fusableOperand) {
70 return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
72 std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
73 linalg::fuseElementwiseOps(rewriter, fusableOperand);
74 if (!fusionResult)
75 return rewriter.notifyMatchFailure(genericOp, "fusion failed");
76 for (auto [origValue, replacement] : fusionResult->replacements) {
77 rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) {
78 return use.getOwner() != genericOp.getOperation();
79 });
81 rewriter.eraseOp(genericOp);
82 return success();
86 struct TestLinalgElementwiseFusion
87 : public PassWrapper<TestLinalgElementwiseFusion,
88 OperationPass<func::FuncOp>> {
89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
91 TestLinalgElementwiseFusion() = default;
92 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
93 : PassWrapper(pass) {}
94 void getDependentDialects(DialectRegistry &registry) const override {
95 registry.insert<affine::AffineDialect, linalg::LinalgDialect,
96 memref::MemRefDialect, tensor::TensorDialect>();
98 StringRef getArgument() const final {
99 return "test-linalg-elementwise-fusion-patterns";
101 StringRef getDescription() const final {
102 return "Test Linalg element wise operation fusion patterns";
105 Option<bool> fuseGenericOps{
106 *this, "fuse-generic-ops",
107 llvm::cl::desc("Test fusion of generic operations."),
108 llvm::cl::init(false)};
110 Option<bool> fuseGenericOpsControl{
111 *this, "fuse-generic-ops-control",
112 llvm::cl::desc(
113 "Test fusion of generic operations with a control function."),
114 llvm::cl::init(false)};
116 Option<bool> fuseWithReshapeByExpansion{
117 *this, "fuse-with-reshape-by-expansion",
118 llvm::cl::desc(
119 "Test fusion of generic operations with reshape by expansion"),
120 llvm::cl::init(false)};
122 Option<bool> controlFuseByExpansion{
123 *this, "control-fusion-by-expansion",
124 llvm::cl::desc(
125 "Test controlling fusion of reshape with generic op by expansion"),
126 llvm::cl::init(false)};
128 Option<bool> fuseWithReshapeByCollapsing{
129 *this, "fuse-with-reshape-by-collapsing",
130 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
131 "collapse the iteration space of the consumer"),
132 llvm::cl::init(false)};
134 Option<bool> fuseWithReshapeByCollapsingWithControlFn{
135 *this, "fuse-with-reshape-by-collapsing-control",
136 llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
137 "fusion patterns that "
138 "collapse the iteration space of the consumer"),
139 llvm::cl::init(false)};
141 Option<bool> fuseMultiUseProducer{
142 *this, "fuse-multiuse-producer",
143 llvm::cl::desc("Test fusion of producer ops with multiple uses"),
144 llvm::cl::init(false)};
146 ListOption<int64_t> collapseDimensions{
147 *this, "collapse-dimensions-control",
148 llvm::cl::desc("Test controlling dimension collapse pattern")};
150 void runOnOperation() override {
151 MLIRContext *context = &this->getContext();
152 func::FuncOp funcOp = this->getOperation();
154 if (fuseGenericOps) {
155 RewritePatternSet fusionPatterns(context);
156 auto controlFn = [](OpOperand *operand) { return true; };
157 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
158 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
159 std::move(fusionPatterns))))
160 return signalPassFailure();
161 return;
164 if (fuseGenericOpsControl) {
165 RewritePatternSet fusionPatterns(context);
166 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
167 setFusedOpOperandLimit<4>);
169 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
170 std::move(fusionPatterns))))
171 return signalPassFailure();
172 return;
175 if (fuseWithReshapeByExpansion) {
176 RewritePatternSet fusionPatterns(context);
177 linalg::populateFoldReshapeOpsByExpansionPatterns(
178 fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
179 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
180 std::move(fusionPatterns))))
181 return signalPassFailure();
182 return;
185 if (controlFuseByExpansion) {
186 RewritePatternSet fusionPatterns(context);
188 linalg::ControlFusionFn controlReshapeFusionFn =
189 [](OpOperand *fusedOperand) {
190 Operation *producer = fusedOperand->get().getDefiningOp();
191 if (!producer)
192 return false;
194 if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
195 if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
196 return false;
200 Operation *consumer = fusedOperand->getOwner();
201 if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
202 if (expandOp->hasOneUse()) {
203 OpOperand &use = *expandOp->getUses().begin();
204 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
205 if (linalgOp && linalgOp.isDpsInit(&use))
206 return true;
208 return false;
210 return true;
213 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
214 controlReshapeFusionFn);
215 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
216 std::move(fusionPatterns))))
217 return signalPassFailure();
218 return;
221 if (fuseWithReshapeByCollapsing) {
222 RewritePatternSet patterns(context);
223 linalg::populateFoldReshapeOpsByCollapsingPatterns(
224 patterns, [](OpOperand * /*fusedOperand */) { return true; });
225 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
226 std::move(patterns))))
227 return signalPassFailure();
228 return;
231 if (fuseWithReshapeByCollapsingWithControlFn) {
232 RewritePatternSet patterns(context);
233 linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
234 Operation *producer = fusedOperand->get().getDefiningOp();
235 if (isa<tensor::ExpandShapeOp>(producer)) {
236 // Skip fusing the first operand.
237 return fusedOperand->getOperandNumber();
239 return true;
241 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
242 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
243 std::move(patterns))))
244 return signalPassFailure();
245 return;
248 if (fuseMultiUseProducer) {
249 RewritePatternSet patterns(context);
250 patterns.insert<TestMultiUseProducerFusion>(context);
251 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
252 std::move(patterns))))
253 return signalPassFailure();
254 return;
257 if (!collapseDimensions.empty()) {
258 SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
259 collapseDimensions.end());
260 linalg::GetCollapsableDimensionsFn collapseFn =
261 [&dims](linalg::LinalgOp op) {
262 SmallVector<ReassociationIndices> reassociations;
263 reassociations.emplace_back(dims);
264 return reassociations;
266 RewritePatternSet patterns(context);
267 linalg::populateCollapseDimensions(patterns, collapseFn);
268 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
269 std::move(patterns))))
270 return signalPassFailure();
271 return;
276 } // namespace
278 namespace mlir {
279 namespace test {
280 void registerTestLinalgElementwiseFusion() {
281 PassRegistration<TestLinalgElementwiseFusion>();
283 } // namespace test
284 } // namespace mlir