//===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
#include "gtest/gtest.h"

#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "../../test/lib/Dialect/Test/TestOps.h"

using namespace mlir;
18 struct AnOpRewritePattern
: OpRewritePattern
<test::OpA
> {
19 AnOpRewritePattern(MLIRContext
*context
)
20 : OpRewritePattern(context
, /*benefit=*/1,
21 /*generatedNames=*/{test::OpB::getOperationName()}) {}
23 TEST(OpRewritePatternTest
, GetGeneratedNames
) {
25 AnOpRewritePattern
pattern(&context
);
26 ArrayRef
<OperationName
> ops
= pattern
.getGeneratedOps();
28 ASSERT_EQ(ops
.size(), 1u);
29 ASSERT_EQ(ops
.front().getStringRef(), test::OpB::getOperationName());
31 } // end anonymous namespace
34 LogicalResult
anOpRewritePatternFunc(test::OpA op
, PatternRewriter
&rewriter
) {
37 TEST(AnOpRewritePatternTest
, PatternFuncAttributes
) {
39 RewritePatternSet
patterns(&context
);
41 patterns
.add(anOpRewritePatternFunc
, /*benefit=*/3,
42 /*generatedNames=*/{test::OpB::getOperationName()});
43 ASSERT_EQ(patterns
.getNativePatterns().size(), 1U);
44 auto &pattern
= patterns
.getNativePatterns().front();
45 ASSERT_EQ(pattern
->getBenefit(), 3);
46 ASSERT_EQ(pattern
->getGeneratedOps().size(), 1U);
47 ASSERT_EQ(pattern
->getGeneratedOps().front().getStringRef(),
48 test::OpB::getOperationName());
50 } // end anonymous namespace