Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)"
[llvm-project.git] / mlir / unittests / Rewrite / PatternBenefit.cpp
blob 65ea4ee6683db8d0156a6b86857f7ba4b324a00b
1 //===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/OwningOpRef.h"
10 #include "mlir/IR/PatternMatch.h"
11 #include "mlir/Rewrite/PatternApplicator.h"
12 #include "gtest/gtest.h"
14 using namespace mlir;
16 namespace {
17 TEST(PatternBenefitTest, BenefitOrder) {
18 // There was a bug which caused low-benefit op-specific patterns to never be
19 // called in presence of high-benefit op-agnostic pattern
21 MLIRContext context;
23 OpBuilder builder(&context);
24 OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc());
26 struct Pattern1 : public OpRewritePattern<ModuleOp> {
27 Pattern1(mlir::MLIRContext *context, bool *called)
28 : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {}
30 llvm::LogicalResult
31 matchAndRewrite(ModuleOp /*op*/,
32 mlir::PatternRewriter & /*rewriter*/) const override {
33 *called = true;
34 return failure();
37 private:
38 bool *called;
41 struct Pattern2 : public RewritePattern {
42 Pattern2(MLIRContext *context, bool *called)
43 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
44 called(called) {}
46 llvm::LogicalResult
47 matchAndRewrite(Operation * /*op*/,
48 mlir::PatternRewriter & /*rewriter*/) const override {
49 *called = true;
50 return failure();
53 private:
54 bool *called;
57 RewritePatternSet patterns(&context);
59 bool called1 = false;
60 bool called2 = false;
62 patterns.add<Pattern1>(&context, &called1);
63 patterns.add<Pattern2>(&context, &called2);
65 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
66 PatternApplicator pa(frozenPatterns);
67 pa.applyDefaultCostModel();
69 class MyPatternRewriter : public PatternRewriter {
70 public:
71 MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
74 MyPatternRewriter rewriter(&context);
75 (void)pa.matchAndRewrite(*module, rewriter);
77 EXPECT_TRUE(called1);
78 EXPECT_TRUE(called2);
80 } // namespace