//===- TestPadFusion.cpp - Test fusion of pad op with Linalg ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing fusion of pad ops with its producer
// Linalg op.
//
//===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassManager.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 struct TestPadFusionPass
24 : public PassWrapper
<TestPadFusionPass
, OperationPass
<>> {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPadFusionPass
)
27 void getDependentDialects(DialectRegistry
®istry
) const override
{
28 registry
.insert
<affine::AffineDialect
, linalg::LinalgDialect
,
29 tensor::TensorDialect
>();
32 StringRef
getArgument() const final
{ return "test-linalg-pad-fusion"; }
33 StringRef
getDescription() const final
{ return "Test PadOp fusion"; }
35 void runOnOperation() override
{
36 MLIRContext
*context
= &getContext();
37 RewritePatternSet
patterns(context
);
38 linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns
);
39 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns
))))
40 return signalPassFailure();
47 void registerTestPadFusion() { PassRegistration
<TestPadFusionPass
>(); }