[mlir] Update Ch-2.md (#121379)
[llvm-project.git] / mlir / test / lib / Dialect / Linalg / TestLinalgRankReduceContractionOps.cpp
blob750ba6b5d9872662f0968690d17a4d7374fe0756
1 //===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for testing rank-reducing patterns for named
10 // contraction ops with unit dims.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 using namespace mlir;
23 namespace {
25 struct TestLinalgRankReduceContractionOps
26 : public PassWrapper<TestLinalgRankReduceContractionOps,
27 OperationPass<func::FuncOp>> {
28 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
29 TestLinalgRankReduceContractionOps)
31 TestLinalgRankReduceContractionOps() = default;
32 TestLinalgRankReduceContractionOps(
33 const TestLinalgRankReduceContractionOps &pass)
34 : PassWrapper(pass) {}
35 void getDependentDialects(DialectRegistry &registry) const override {
36 registry.insert<affine::AffineDialect, linalg::LinalgDialect,
37 memref::MemRefDialect, tensor::TensorDialect>();
39 StringRef getArgument() const final {
40 return "test-linalg-rank-reduce-contraction-ops";
42 StringRef getDescription() const final {
43 return "Test Linalg rank reduce contraction ops with unit dims";
46 void runOnOperation() override {
47 MLIRContext *context = &this->getContext();
48 func::FuncOp funcOp = this->getOperation();
50 RewritePatternSet patterns(context);
51 linalg::populateContractionOpRankReducingPatterns(patterns);
52 if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
53 return signalPassFailure();
54 return;
58 } // namespace
60 namespace mlir {
61 namespace test {
62 void registerTestLinalgRankReduceContractionOps() {
63 PassRegistration<TestLinalgRankReduceContractionOps>();
65 } // namespace test
66 } // namespace mlir