1 //===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Parser/Parser.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "gtest/gtest.h"
// Test rewrite pattern rooted on "test.foo" with benefit 0, registered under
// the debug name "DisabledPattern". The canonicalizer test in this file lists
// that name in the pass's disabled-pattern set, so this pattern is expected
// never to be applied there.
20 struct DisabledPattern
: public RewritePattern
{
21 DisabledPattern(MLIRContext
*context
)
22 : RewritePattern("test.foo", /*benefit=*/0, context
,
23 /*generatedNames=*/{}) {
// The debug name is the key the canonicalizer's disabled-pattern list
// matches against.
24 setDebugName("DisabledPattern");
// Match logic is gated on the matched op's result count. NOTE(review): the
// rest of this body is not visible in this chunk; presumably it is the
// complement of EnabledPattern's `== 1` guard — confirm against full file.
27 LogicalResult
matchAndRewrite(Operation
*op
,
28 PatternRewriter
&rewriter
) const override
{
29 if (op
->getNumResults() != 1)
// Test rewrite pattern rooted on "test.foo" with benefit 0, registered under
// the debug name "EnabledPattern". Unlike DisabledPattern, this name is not
// in the canonicalizer test's disabled-pattern list, so it stays active.
36 struct EnabledPattern
: public RewritePattern
{
37 EnabledPattern(MLIRContext
*context
)
38 : RewritePattern("test.foo", /*benefit=*/0, context
,
39 /*generatedNames=*/{}) {
40 setDebugName("EnabledPattern");
// Match logic is gated on the matched op having exactly one result.
// NOTE(review): the rest of this body is not visible in this chunk;
// presumably single-result ops are rejected and others rewritten — confirm.
43 LogicalResult
matchAndRewrite(Operation
*op
,
44 PatternRewriter
&rewriter
) const override
{
45 if (op
->getNumResults() == 1)
// Minimal dialect registered under the "test" namespace. It accepts
// unregistered operations (so "test.foo" parses without an op definition) and
// contributes both test patterns through the dialect canonicalization hook.
52 struct TestDialect
: public Dialect
{
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect
)
55 static StringRef
getDialectNamespace() { return "test"; }
57 TestDialect(MLIRContext
*context
)
58 : Dialect(getDialectNamespace(), context
, TypeID::get
<TestDialect
>()) {
// Needed because "test.foo" has no registered operation definition.
59 allowUnknownOperations();
// Registers DisabledPattern and EnabledPattern as this dialect's
// canonicalization patterns.
62 void getCanonicalizationPatterns(RewritePatternSet
&results
) const override
{
63 results
.add
<DisabledPattern
, EnabledPattern
>(results
.getContext());
// Checks that a pattern named in the canonicalizer pass's disabled-pattern
// list is not applied. The pass is created with {"DisabledPattern"} disabled
// and run on a module with two "test.foo" ops:
//   - "A": two results (i32, i32)
//   - "B": one result (f32)
// Afterwards "B" must still be present and "A" must have been removed.
// NOTE(review): presumably "A" is removed by EnabledPattern and "B" survives
// only because DisabledPattern is excluded; the pattern bodies that would
// confirm this are not visible in this chunk.
// NOTE(review): `module` is dereferenced in mgr.run(*module) without first
// checking that parseSourceString succeeded; an ASSERT_TRUE(module) before
// the run would turn a parse failure into a clean test failure.
67 TEST(CanonicalizerTest
, TestDisablePatterns
) {
69 context
.getOrLoadDialect
<TestDialect
>();
// Canonicalizer pipeline with the "DisabledPattern" debug name excluded.
70 PassManager
mgr(&context
);
72 createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));
74 const char *const code
= R
"mlir(
75 %0:2 = "test
.foo
"() {sym_name = "A
"} : () -> (i32, i32)
76 %1 = "test
.foo
"() {sym_name = "B
"} : () -> (f32)
79 OwningOpRef
<ModuleOp
> module
= parseSourceString
<ModuleOp
>(code
, &context
);
80 ASSERT_TRUE(succeeded(mgr
.run(*module
)));
82 EXPECT_TRUE(module
->lookupSymbol("B"));
83 EXPECT_FALSE(module
->lookupSymbol("A"));
86 } // end anonymous namespace