[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / unittests / Transforms / Canonicalizer.cpp
blob 4b94e0602b5097f62efcf6886ad197d5114a7c88
//===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Parser/Parser.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "gtest/gtest.h"
16 using namespace mlir;
18 namespace {
20 struct DisabledPattern : public RewritePattern {
21 DisabledPattern(MLIRContext *context)
22 : RewritePattern("test.foo", /*benefit=*/0, context,
23 /*generatedNamed=*/{}) {
24 setDebugName("DisabledPattern");
27 LogicalResult matchAndRewrite(Operation *op,
28 PatternRewriter &rewriter) const override {
29 if (op->getNumResults() != 1)
30 return failure();
31 rewriter.eraseOp(op);
32 return success();
36 struct EnabledPattern : public RewritePattern {
37 EnabledPattern(MLIRContext *context)
38 : RewritePattern("test.foo", /*benefit=*/0, context,
39 /*generatedNamed=*/{}) {
40 setDebugName("EnabledPattern");
43 LogicalResult matchAndRewrite(Operation *op,
44 PatternRewriter &rewriter) const override {
45 if (op->getNumResults() == 1)
46 return failure();
47 rewriter.eraseOp(op);
48 return success();
52 struct TestDialect : public Dialect {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
55 static StringRef getDialectNamespace() { return "test"; }
57 TestDialect(MLIRContext *context)
58 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
59 allowUnknownOperations();
62 void getCanonicalizationPatterns(RewritePatternSet &results) const override {
63 results.add<DisabledPattern, EnabledPattern>(results.getContext());
67 TEST(CanonicalizerTest, TestDisablePatterns) {
68 MLIRContext context;
69 context.getOrLoadDialect<TestDialect>();
70 PassManager mgr(&context);
71 mgr.addPass(
72 createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));
74 const char *const code = R"mlir(
75 %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
76 %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
77 )mlir";
79 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
80 ASSERT_TRUE(succeeded(mgr.run(*module)));
82 EXPECT_TRUE(module->lookupSymbol("B"));
83 EXPECT_FALSE(module->lookupSymbol("A"));
86 } // end anonymous namespace