1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "pattern-application"
21 using namespace mlir::detail
;
23 PatternApplicator::PatternApplicator(
24 const FrozenRewritePatternSet
&frozenPatternList
)
25 : frozenPatternList(frozenPatternList
) {
26 if (const PDLByteCode
*bytecode
= frozenPatternList
.getPDLByteCode()) {
27 mutableByteCodeState
= std::make_unique
<PDLByteCodeMutableState
>();
28 bytecode
->initializeMutableState(*mutableByteCodeState
);
31 PatternApplicator::~PatternApplicator() = default;
34 /// Log a message for a pattern that is impossible to match.
35 static void logImpossibleToMatch(const Pattern
&pattern
) {
36 llvm::dbgs() << "Ignoring pattern '" << pattern
.getRootKind()
37 << "' because it is impossible to match or cannot lead "
38 "to legal IR (by cost model)\n";
41 /// Log IR after pattern application.
42 static Operation
*getDumpRootOp(Operation
*op
) {
43 Operation
*isolatedParent
=
44 op
->getParentWithTrait
<mlir::OpTrait::IsIsolatedFromAbove
>();
46 return isolatedParent
;
49 static void logSucessfulPatternApplication(Operation
*op
) {
50 llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
52 llvm::dbgs() << "\n\n";
56 void PatternApplicator::applyCostModel(CostModel model
) {
57 // Apply the cost model to the bytecode patterns first, and then the native
59 if (const PDLByteCode
*bytecode
= frozenPatternList
.getPDLByteCode()) {
60 for (const auto &it
: llvm::enumerate(bytecode
->getPatterns()))
61 mutableByteCodeState
->updatePatternBenefit(it
.index(), model(it
.value()));
64 // Copy over the patterns so that we can sort by benefit based on the cost
65 // model. Patterns that are already impossible to match are ignored.
67 for (const auto &it
: frozenPatternList
.getOpSpecificNativePatterns()) {
68 for (const RewritePattern
*pattern
: it
.second
) {
69 if (pattern
->getBenefit().isImpossibleToMatch())
70 LLVM_DEBUG(logImpossibleToMatch(*pattern
));
72 patterns
[it
.first
].push_back(pattern
);
75 anyOpPatterns
.clear();
76 for (const RewritePattern
&pattern
:
77 frozenPatternList
.getMatchAnyOpNativePatterns()) {
78 if (pattern
.getBenefit().isImpossibleToMatch())
79 LLVM_DEBUG(logImpossibleToMatch(pattern
));
81 anyOpPatterns
.push_back(&pattern
);
84 // Sort the patterns using the provided cost model.
85 llvm::SmallDenseMap
<const Pattern
*, PatternBenefit
> benefits
;
86 auto cmp
= [&benefits
](const Pattern
*lhs
, const Pattern
*rhs
) {
87 return benefits
[lhs
] > benefits
[rhs
];
89 auto processPatternList
= [&](SmallVectorImpl
<const RewritePattern
*> &list
) {
90 // Special case for one pattern in the list, which is the most common case.
91 if (list
.size() == 1) {
92 if (model(*list
.front()).isImpossibleToMatch()) {
93 LLVM_DEBUG(logImpossibleToMatch(*list
.front()));
99 // Collect the dynamic benefits for the current pattern list.
101 for (const Pattern
*pat
: list
)
102 benefits
.try_emplace(pat
, model(*pat
));
104 // Sort patterns with highest benefit first, and remove those that are
105 // impossible to match.
106 std::stable_sort(list
.begin(), list
.end(), cmp
);
107 while (!list
.empty() && benefits
[list
.back()].isImpossibleToMatch()) {
108 LLVM_DEBUG(logImpossibleToMatch(*list
.back()));
112 for (auto &it
: patterns
)
113 processPatternList(it
.second
);
114 processPatternList(anyOpPatterns
);
117 void PatternApplicator::walkAllPatterns(
118 function_ref
<void(const Pattern
&)> walk
) {
119 for (const auto &it
: frozenPatternList
.getOpSpecificNativePatterns())
120 for (const auto &pattern
: it
.second
)
122 for (const Pattern
&it
: frozenPatternList
.getMatchAnyOpNativePatterns())
124 if (const PDLByteCode
*bytecode
= frozenPatternList
.getPDLByteCode()) {
125 for (const Pattern
&it
: bytecode
->getPatterns())
130 LogicalResult
PatternApplicator::matchAndRewrite(
131 Operation
*op
, PatternRewriter
&rewriter
,
132 function_ref
<bool(const Pattern
&)> canApply
,
133 function_ref
<void(const Pattern
&)> onFailure
,
134 function_ref
<LogicalResult(const Pattern
&)> onSuccess
) {
135 // Before checking native patterns, first match against the bytecode. This
136 // won't automatically perform any rewrites so there is no need to worry about
138 SmallVector
<PDLByteCode::MatchResult
, 4> pdlMatches
;
139 const PDLByteCode
*bytecode
= frozenPatternList
.getPDLByteCode();
141 bytecode
->match(op
, rewriter
, pdlMatches
, *mutableByteCodeState
);
143 // Check to see if there are patterns matching this specific operation type.
144 MutableArrayRef
<const RewritePattern
*> opPatterns
;
145 auto patternIt
= patterns
.find(op
->getName());
146 if (patternIt
!= patterns
.end())
147 opPatterns
= patternIt
->second
;
149 // Process the patterns for that match the specific operation type, and any
150 // operation type in an interleaved fashion.
151 unsigned opIt
= 0, opE
= opPatterns
.size();
152 unsigned anyIt
= 0, anyE
= anyOpPatterns
.size();
153 unsigned pdlIt
= 0, pdlE
= pdlMatches
.size();
154 LogicalResult result
= failure();
156 // Find the next pattern with the highest benefit.
157 const Pattern
*bestPattern
= nullptr;
158 unsigned *bestPatternIt
= &opIt
;
160 /// Operation specific patterns.
162 bestPattern
= opPatterns
[opIt
];
163 /// Operation agnostic patterns.
166 bestPattern
->getBenefit() < anyOpPatterns
[anyIt
]->getBenefit())) {
167 bestPatternIt
= &anyIt
;
168 bestPattern
= anyOpPatterns
[anyIt
];
171 const PDLByteCode::MatchResult
*pdlMatch
= nullptr;
173 if (pdlIt
< pdlE
&& (!bestPattern
|| bestPattern
->getBenefit() <
174 pdlMatches
[pdlIt
].benefit
)) {
175 bestPatternIt
= &pdlIt
;
176 pdlMatch
= &pdlMatches
[pdlIt
];
177 bestPattern
= pdlMatch
->pattern
;
183 // Update the pattern iterator on failure so that this pattern isn't
187 // Check that the pattern can be applied.
188 if (canApply
&& !canApply(*bestPattern
))
191 // Try to match and rewrite this pattern. The patterns are sorted by
192 // benefit, so if we match we can immediately rewrite. For PDL patterns, the
193 // match has already been performed, we just need to rewrite.
194 bool matched
= false;
195 op
->getContext()->executeAction
<ApplyPatternAction
>(
197 rewriter
.setInsertionPoint(op
);
199 // Operation `op` may be invalidated after applying the rewrite
201 Operation
*dumpRootOp
= getDumpRootOp(op
);
205 bytecode
->rewrite(rewriter
, *pdlMatch
, *mutableByteCodeState
);
207 LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
208 << bestPattern
->getDebugName() << "\"\n");
210 const auto *pattern
=
211 static_cast<const RewritePattern
*>(bestPattern
);
212 result
= pattern
->matchAndRewrite(op
, rewriter
);
214 LLVM_DEBUG(llvm::dbgs()
215 << "\"" << bestPattern
->getDebugName() << "\" result "
216 << succeeded(result
) << "\n");
219 // Process the result of the pattern application.
220 if (succeeded(result
) && onSuccess
&& failed(onSuccess(*bestPattern
)))
222 if (succeeded(result
)) {
223 LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp
));
228 // Perform any necessary cleanups.
230 onFailure(*bestPattern
);
237 if (mutableByteCodeState
)
238 mutableByteCodeState
->cleanupAfterMatchAndRewrite();