//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements mlir::walkAndApplyPatterns.
//
//===----------------------------------------------------------------------===//
13 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/OperationSupport.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/Verifier.h"
19 #include "mlir/IR/Visitors.h"
20 #include "mlir/Rewrite/PatternApplicator.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/ErrorHandling.h"
24 #define DEBUG_TYPE "walk-rewriter"
29 struct WalkAndApplyPatternsAction final
30 : tracing::ActionImpl
<WalkAndApplyPatternsAction
> {
31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction
)
32 using ActionImpl::ActionImpl
;
33 static constexpr StringLiteral tag
= "walk-and-apply-patterns";
34 void print(raw_ostream
&os
) const override
{ os
<< tag
; }
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Forwarding listener to guard against unsupported erasures of non-descendant
// ops/blocks. Because we use walk-based pattern application, erasing the
// op/block from the *next* iteration (e.g., a user of the visited op) is not
// valid. Note that this is only used with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  using RewriterBase::ForwardingListener::ForwardingListener;

  /// Validates the erasure of `op` against the currently visited op, then
  /// forwards the notification to the wrapped listener.
  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  /// Validates the erasure of `block` (via its parent op) against the
  /// currently visited op, then forwards the notification to the wrapped
  /// listener.
  void notifyBlockErased(Block *block) override {
    checkErasure(block->getParentOp());
    ForwardingListener::notifyBlockErased(block);
  }

  /// Reports a fatal error unless `visitedOp` is `op` itself or one of its
  /// ancestors.
  void checkErasure(Operation *op) const {
    // Climb the parent chain until we either hit the visited op or run out of
    // ancestors.
    Operation *ancestorOp = op;
    while (ancestorOp && ancestorOp != visitedOp)
      ancestorOp = ancestorOp->getParentOp();

    if (ancestorOp != visitedOp)
      llvm::report_fatal_error(
          "unsupported erasure in WalkPatternRewriter; "
          "erasure is only supported for matched ops and their descendants");
  }

  /// The op currently being visited by the walk; assigned by the driver
  /// before each pattern application attempt.
  Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
71 void walkAndApplyPatterns(Operation
*op
,
72 const FrozenRewritePatternSet
&patterns
,
73 RewriterBase::Listener
*listener
) {
74 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
75 if (failed(verify(op
)))
76 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
77 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
79 MLIRContext
*ctx
= op
->getContext();
80 PatternRewriter
rewriter(ctx
);
81 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
82 ErasedOpsListener
erasedListener(listener
);
83 rewriter
.setListener(&erasedListener
);
85 rewriter
.setListener(listener
);
86 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
88 PatternApplicator
applicator(patterns
);
89 applicator
.applyDefaultCostModel();
91 ctx
->executeAction
<WalkAndApplyPatternsAction
>(
93 for (Region
®ion
: op
->getRegions()) {
94 region
.walk([&](Operation
*visitedOp
) {
95 LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp
->print(
96 llvm::dbgs(), OpPrintingFlags().skipRegions());
97 llvm::dbgs() << "\n";);
98 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99 erasedListener
.visitedOp
= visitedOp
;
100 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
101 if (succeeded(applicator
.matchAndRewrite(visitedOp
, rewriter
))) {
102 LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
109 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
110 if (failed(verify(op
)))
111 llvm::report_fatal_error(
112 "walk pattern rewriter result IR failed to verify");
113 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS