1 //===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a pass for testing the rank-reducing patterns for named
// contraction ops with unit dims.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 struct TestLinalgRankReduceContractionOps
26 : public PassWrapper
<TestLinalgRankReduceContractionOps
,
27 OperationPass
<func::FuncOp
>> {
28 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
29 TestLinalgRankReduceContractionOps
)
31 TestLinalgRankReduceContractionOps() = default;
32 TestLinalgRankReduceContractionOps(
33 const TestLinalgRankReduceContractionOps
&pass
)
34 : PassWrapper(pass
) {}
35 void getDependentDialects(DialectRegistry
®istry
) const override
{
36 registry
.insert
<affine::AffineDialect
, linalg::LinalgDialect
,
37 memref::MemRefDialect
, tensor::TensorDialect
>();
39 StringRef
getArgument() const final
{
40 return "test-linalg-rank-reduce-contraction-ops";
42 StringRef
getDescription() const final
{
43 return "Test Linalg rank reduce contraction ops with unit dims";
46 void runOnOperation() override
{
47 MLIRContext
*context
= &this->getContext();
48 func::FuncOp funcOp
= this->getOperation();
50 RewritePatternSet
patterns(context
);
51 linalg::populateContractionOpRankReducingPatterns(patterns
);
52 if (failed(applyPatternsGreedily(funcOp
.getBody(), std::move(patterns
))))
53 return signalPassFailure();
62 void registerTestLinalgRankReduceContractionOps() {
63 PassRegistration
<TestLinalgRankReduceContractionOps
>();