1 //===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/PatternMatch.h"
10 #include "gtest/gtest.h"
12 #include "../../test/lib/Dialect/Test/TestDialect.h"
// Rewrite pattern templated on test::OpA. Its constructor forwards the given
// context to the OpRewritePattern base together with a benefit of 1 and a
// generated-names list holding only test::OpB's operation name.
// NOTE(review): the struct's closing `};` is not visible in this excerpt.
17 struct AnOpRewritePattern
: OpRewritePattern
<test::OpA
> {
18 AnOpRewritePattern(MLIRContext
*context
)
19 : OpRewritePattern(context
, /*benefit=*/1,
20 /*generatedNames=*/{test::OpB::getOperationName()}) {}
// Builds an AnOpRewritePattern and checks that getGeneratedOps() reports
// exactly one entry whose string name equals test::OpB's operation name.
// NOTE(review): the declaration of `context` and this test's closing brace are
// not visible in this excerpt -- confirm against the full file.
22 TEST(OpRewritePatternTest
, GetGeneratedNames
) {
24 AnOpRewritePattern
pattern(&context
);
25 ArrayRef
<OperationName
> ops
= pattern
.getGeneratedOps();
// Size is asserted first so that front() below is only reached on a
// non-empty list.
27 ASSERT_EQ(ops
.size(), 1u);
28 ASSERT_EQ(ops
.front().getStringRef(), test::OpB::getOperationName());
30 } // end anonymous namespace
// Free-function rewrite callback taking a test::OpA and a PatternRewriter; it
// is passed to RewritePatternSet::add in the PatternFuncAttributes test.
// NOTE(review): the function body and closing brace are not visible in this
// excerpt, so the returned LogicalResult cannot be documented from here.
33 LogicalResult
anOpRewritePatternFunc(test::OpA op
, PatternRewriter
&rewriter
) {
// Registers anOpRewritePatternFunc in a RewritePatternSet with a benefit of 3
// and test::OpB as the only generated name, then checks that the set holds a
// single native pattern carrying that benefit and that generated-op list.
// NOTE(review): the declaration of `context` and this test's closing brace are
// not visible in this excerpt -- confirm against the full file.
36 TEST(AnOpRewritePatternTest
, PatternFuncAttributes
) {
38 RewritePatternSet
patterns(&context
);
40 patterns
.add(anOpRewritePatternFunc
, /*benefit=*/3,
41 /*generatedNames=*/{test::OpB::getOperationName()});
// Exactly one native pattern is expected before front() is taken.
42 ASSERT_EQ(patterns
.getNativePatterns().size(), 1U);
43 auto &pattern
= patterns
.getNativePatterns().front();
// The benefit and generated names given to add() should be readable back
// from the stored pattern.
44 ASSERT_EQ(pattern
->getBenefit(), 3);
45 ASSERT_EQ(pattern
->getGeneratedOps().size(), 1U);
46 ASSERT_EQ(pattern
->getGeneratedOps().front().getStringRef(),
47 test::OpB::getOperationName());
49 } // end anonymous namespace