[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Rewrite / PatternApplicator.cpp
blobea43f8a147d479308a427cf4b737b6ac988b7696
1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements an applicator that applies pattern rewrites based upon a
// user-defined cost model.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "ByteCode.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "pattern-application"
20 using namespace mlir;
21 using namespace mlir::detail;
23 PatternApplicator::PatternApplicator(
24 const FrozenRewritePatternSet &frozenPatternList)
25 : frozenPatternList(frozenPatternList) {
26 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28 bytecode->initializeMutableState(*mutableByteCodeState);
31 PatternApplicator::~PatternApplicator() = default;
33 #ifndef NDEBUG
34 /// Log a message for a pattern that is impossible to match.
35 static void logImpossibleToMatch(const Pattern &pattern) {
36 llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
37 << "' because it is impossible to match or cannot lead "
38 "to legal IR (by cost model)\n";
41 /// Log IR after pattern application.
42 static Operation *getDumpRootOp(Operation *op) {
43 Operation *isolatedParent =
44 op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>();
45 if (isolatedParent)
46 return isolatedParent;
47 return op;
49 static void logSucessfulPatternApplication(Operation *op) {
50 llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
51 op->dump();
52 llvm::dbgs() << "\n\n";
54 #endif
56 void PatternApplicator::applyCostModel(CostModel model) {
57 // Apply the cost model to the bytecode patterns first, and then the native
58 // patterns.
59 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
60 for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
61 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
64 // Copy over the patterns so that we can sort by benefit based on the cost
65 // model. Patterns that are already impossible to match are ignored.
66 patterns.clear();
67 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
68 for (const RewritePattern *pattern : it.second) {
69 if (pattern->getBenefit().isImpossibleToMatch())
70 LLVM_DEBUG(logImpossibleToMatch(*pattern));
71 else
72 patterns[it.first].push_back(pattern);
75 anyOpPatterns.clear();
76 for (const RewritePattern &pattern :
77 frozenPatternList.getMatchAnyOpNativePatterns()) {
78 if (pattern.getBenefit().isImpossibleToMatch())
79 LLVM_DEBUG(logImpossibleToMatch(pattern));
80 else
81 anyOpPatterns.push_back(&pattern);
84 // Sort the patterns using the provided cost model.
85 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
86 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
87 return benefits[lhs] > benefits[rhs];
89 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
90 // Special case for one pattern in the list, which is the most common case.
91 if (list.size() == 1) {
92 if (model(*list.front()).isImpossibleToMatch()) {
93 LLVM_DEBUG(logImpossibleToMatch(*list.front()));
94 list.clear();
96 return;
99 // Collect the dynamic benefits for the current pattern list.
100 benefits.clear();
101 for (const Pattern *pat : list)
102 benefits.try_emplace(pat, model(*pat));
104 // Sort patterns with highest benefit first, and remove those that are
105 // impossible to match.
106 std::stable_sort(list.begin(), list.end(), cmp);
107 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
108 LLVM_DEBUG(logImpossibleToMatch(*list.back()));
109 list.pop_back();
112 for (auto &it : patterns)
113 processPatternList(it.second);
114 processPatternList(anyOpPatterns);
117 void PatternApplicator::walkAllPatterns(
118 function_ref<void(const Pattern &)> walk) {
119 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
120 for (const auto &pattern : it.second)
121 walk(*pattern);
122 for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
123 walk(it);
124 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
125 for (const Pattern &it : bytecode->getPatterns())
126 walk(it);
130 LogicalResult PatternApplicator::matchAndRewrite(
131 Operation *op, PatternRewriter &rewriter,
132 function_ref<bool(const Pattern &)> canApply,
133 function_ref<void(const Pattern &)> onFailure,
134 function_ref<LogicalResult(const Pattern &)> onSuccess) {
135 // Before checking native patterns, first match against the bytecode. This
136 // won't automatically perform any rewrites so there is no need to worry about
137 // conflicts.
138 SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
139 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
140 if (bytecode)
141 bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
143 // Check to see if there are patterns matching this specific operation type.
144 MutableArrayRef<const RewritePattern *> opPatterns;
145 auto patternIt = patterns.find(op->getName());
146 if (patternIt != patterns.end())
147 opPatterns = patternIt->second;
149 // Process the patterns for that match the specific operation type, and any
150 // operation type in an interleaved fashion.
151 unsigned opIt = 0, opE = opPatterns.size();
152 unsigned anyIt = 0, anyE = anyOpPatterns.size();
153 unsigned pdlIt = 0, pdlE = pdlMatches.size();
154 LogicalResult result = failure();
155 do {
156 // Find the next pattern with the highest benefit.
157 const Pattern *bestPattern = nullptr;
158 unsigned *bestPatternIt = &opIt;
160 /// Operation specific patterns.
161 if (opIt < opE)
162 bestPattern = opPatterns[opIt];
163 /// Operation agnostic patterns.
164 if (anyIt < anyE &&
165 (!bestPattern ||
166 bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
167 bestPatternIt = &anyIt;
168 bestPattern = anyOpPatterns[anyIt];
171 const PDLByteCode::MatchResult *pdlMatch = nullptr;
172 /// PDL patterns.
173 if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
174 pdlMatches[pdlIt].benefit)) {
175 bestPatternIt = &pdlIt;
176 pdlMatch = &pdlMatches[pdlIt];
177 bestPattern = pdlMatch->pattern;
180 if (!bestPattern)
181 break;
183 // Update the pattern iterator on failure so that this pattern isn't
184 // attempted again.
185 ++(*bestPatternIt);
187 // Check that the pattern can be applied.
188 if (canApply && !canApply(*bestPattern))
189 continue;
191 // Try to match and rewrite this pattern. The patterns are sorted by
192 // benefit, so if we match we can immediately rewrite. For PDL patterns, the
193 // match has already been performed, we just need to rewrite.
194 bool matched = false;
195 op->getContext()->executeAction<ApplyPatternAction>(
196 [&]() {
197 rewriter.setInsertionPoint(op);
198 #ifndef NDEBUG
199 // Operation `op` may be invalidated after applying the rewrite
200 // pattern.
201 Operation *dumpRootOp = getDumpRootOp(op);
202 #endif
203 if (pdlMatch) {
204 result =
205 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
206 } else {
207 LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
208 << bestPattern->getDebugName() << "\"\n");
210 const auto *pattern =
211 static_cast<const RewritePattern *>(bestPattern);
212 result = pattern->matchAndRewrite(op, rewriter);
214 LLVM_DEBUG(llvm::dbgs()
215 << "\"" << bestPattern->getDebugName() << "\" result "
216 << succeeded(result) << "\n");
219 // Process the result of the pattern application.
220 if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
221 result = failure();
222 if (succeeded(result)) {
223 LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
224 matched = true;
225 return;
228 // Perform any necessary cleanups.
229 if (onFailure)
230 onFailure(*bestPattern);
232 {op}, *bestPattern);
233 if (matched)
234 break;
235 } while (true);
237 if (mutableByteCodeState)
238 mutableByteCodeState->cleanupAfterMatchAndRewrite();
239 return result;