1 //===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/OwningOpRef.h"
10 #include "mlir/IR/PatternMatch.h"
11 #include "mlir/Rewrite/PatternApplicator.h"
12 #include "gtest/gtest.h"
17 TEST(PatternBenefitTest
, BenefitOrder
) {
18 // There was a bug which caused low-benefit op-specific patterns to never be
19 // called in presence of high-benefit op-agnostic pattern
23 OpBuilder
builder(&context
);
24 OwningOpRef
<ModuleOp
> module
= ModuleOp::create(builder
.getUnknownLoc());
26 struct Pattern1
: public OpRewritePattern
<ModuleOp
> {
27 Pattern1(mlir::MLIRContext
*context
, bool *called
)
28 : OpRewritePattern
<ModuleOp
>(context
, /*benefit*/ 1), called(called
) {}
31 matchAndRewrite(ModuleOp
/*op*/,
32 mlir::PatternRewriter
& /*rewriter*/) const override
{
41 struct Pattern2
: public RewritePattern
{
42 Pattern2(MLIRContext
*context
, bool *called
)
43 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context
),
47 matchAndRewrite(Operation
* /*op*/,
48 mlir::PatternRewriter
& /*rewriter*/) const override
{
57 RewritePatternSet
patterns(&context
);
62 patterns
.add
<Pattern1
>(&context
, &called1
);
63 patterns
.add
<Pattern2
>(&context
, &called2
);
65 FrozenRewritePatternSet
frozenPatterns(std::move(patterns
));
66 PatternApplicator
pa(frozenPatterns
);
67 pa
.applyDefaultCostModel();
69 class MyPatternRewriter
: public PatternRewriter
{
71 MyPatternRewriter(MLIRContext
*ctx
) : PatternRewriter(ctx
) {}
74 MyPatternRewriter
rewriter(&context
);
75 (void)pa
.matchAndRewrite(*module
, rewriter
);