1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements mlir::applyPatternsGreedily.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Verifier.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Rewrite/PatternApplicator.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
37 #define DEBUG_TYPE "greedy-rewriter"
41 //===----------------------------------------------------------------------===//
42 // Debugging Infrastructure
43 //===----------------------------------------------------------------------===//
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46 /// A helper struct that performs various "expensive checks" to detect broken
47 /// rewrite patterns that use the rewriter API incorrectly. A rewrite pattern is
49 /// * IR does not verify after pattern application / folding.
50 /// * Pattern returns "failure" but the IR has changed.
51 /// * Pattern returns "success" but the IR has not changed.
53 /// This struct stores finger prints of ops to determine whether the IR has
55 struct ExpensiveChecks
: public RewriterBase::ForwardingListener
{
56 ExpensiveChecks(RewriterBase::Listener
*driver
, Operation
*topLevel
)
57 : RewriterBase::ForwardingListener(driver
), topLevel(topLevel
) {}
59 /// Compute finger prints of the given op and its nested ops.
60 void computeFingerPrints(Operation
*topLevel
) {
61 this->topLevel
= topLevel
;
62 this->topLevelFingerPrint
.emplace(topLevel
);
63 topLevel
->walk([&](Operation
*op
) {
64 fingerprints
.try_emplace(op
, op
, /*includeNested=*/false);
68 /// Clear all finger prints.
71 topLevelFingerPrint
.reset();
75 void notifyRewriteSuccess() {
79 // Make sure that the IR still verifies.
80 if (failed(verify(topLevel
)))
81 llvm::report_fatal_error("IR failed to verify after pattern application");
83 // Pattern application success => IR must have changed.
84 OperationFingerPrint
afterFingerPrint(topLevel
);
85 if (*topLevelFingerPrint
== afterFingerPrint
) {
86 // Note: Run "mlir-opt -debug" to see which pattern is broken.
87 llvm::report_fatal_error(
88 "pattern returned success but IR did not change");
90 for (const auto &it
: fingerprints
) {
91 // Skip top-level op, its finger print is never invalidated.
92 if (it
.first
== topLevel
)
94 // Note: Finger print computation may crash when an op was erased
95 // without notifying the rewriter. (Run with ASAN to see where the op was
96 // erased; the op was probably erased directly, bypassing the rewriter
97 // API.) Finger print computation may not crash if a new op was
98 // created at the same memory location. (But then the finger print should
101 OperationFingerPrint(it
.first
, /*includeNested=*/false)) {
102 // Note: Run "mlir-opt -debug" to see which pattern is broken.
103 llvm::report_fatal_error("operation finger print changed");
108 void notifyRewriteFailure() {
112 // Pattern application failure => IR must not have changed.
113 OperationFingerPrint
afterFingerPrint(topLevel
);
114 if (*topLevelFingerPrint
!= afterFingerPrint
) {
115 // Note: Run "mlir-opt -debug" to see which pattern is broken.
116 llvm::report_fatal_error("pattern returned failure but IR did change");
120 void notifyFoldingSuccess() {
124 // Make sure that the IR still verifies.
125 if (failed(verify(topLevel
)))
126 llvm::report_fatal_error("IR failed to verify after folding");
130 /// Invalidate the finger print of the given op, i.e., remove it from the map.
131 void invalidateFingerPrint(Operation
*op
) { fingerprints
.erase(op
); }
133 void notifyBlockErased(Block
*block
) override
{
134 RewriterBase::ForwardingListener::notifyBlockErased(block
);
136 // The block structure (number of blocks, types of block arguments, etc.)
137 // is part of the fingerprint of the parent op.
138 // TODO: The parent op fingerprint should also be invalidated when modifying
139 // the block arguments of a block, but we do not have a
140 // `notifyBlockModified` callback yet.
141 invalidateFingerPrint(block
->getParentOp());
144 void notifyOperationInserted(Operation
*op
,
145 OpBuilder::InsertPoint previous
) override
{
146 RewriterBase::ForwardingListener::notifyOperationInserted(op
, previous
);
147 invalidateFingerPrint(op
->getParentOp());
150 void notifyOperationModified(Operation
*op
) override
{
151 RewriterBase::ForwardingListener::notifyOperationModified(op
);
152 invalidateFingerPrint(op
);
155 void notifyOperationErased(Operation
*op
) override
{
156 RewriterBase::ForwardingListener::notifyOperationErased(op
);
157 op
->walk([this](Operation
*op
) { invalidateFingerPrint(op
); });
160 /// Operation finger prints to detect invalid pattern API usage. IR is checked
161 /// against these finger prints after pattern application to detect cases
162 /// where IR was modified directly, bypassing the rewriter API.
163 DenseMap
<Operation
*, OperationFingerPrint
> fingerprints
;
165 /// Top-level operation of the current greedy rewrite.
166 Operation
*topLevel
= nullptr;
168 /// Finger print of the top-level operation.
169 std::optional
<OperationFingerPrint
> topLevelFingerPrint
;
171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
174 static Operation
*getDumpRootOp(Operation
*op
) {
175 // Dump the parent op so that materialized constants are visible. If the op
176 // is a top-level op, dump it directly.
177 if (Operation
*parentOp
= op
->getParentOp())
181 static void logSuccessfulFolding(Operation
*op
) {
182 llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
184 llvm::dbgs() << "\n\n";
188 //===----------------------------------------------------------------------===//
190 //===----------------------------------------------------------------------===//
192 /// A LIFO worklist of operations with efficient removal and set semantics.
194 /// This class maintains a vector of operations and a mapping of operations to
195 /// positions in the vector, so that operations can be removed efficiently at
196 /// random. When an operation is removed, it is replaced with nullptr. Such
197 /// nullptr are skipped when pop'ing elements.
202 /// Clear the worklist.
205 /// Return whether the worklist is empty.
208 /// Push an operation to the end of the worklist, unless the operation is
209 /// already on the worklist.
210 void push(Operation
*op
);
212 /// Pop an operation from the end of the worklist. Only allowed on
213 /// non-empty worklists.
216 /// Remove an operation from the worklist.
217 void remove(Operation
*op
);
219 /// Reverse the worklist.
223 /// The worklist of operations.
224 std::vector
<Operation
*> list
;
226 /// A mapping of operations to positions in `list`.
227 DenseMap
<Operation
*, unsigned> map
;
230 Worklist::Worklist() { list
.reserve(64); }
232 void Worklist::clear() {
237 bool Worklist::empty() const {
239 return !llvm::any_of(list
,
240 [](Operation
*op
) { return static_cast<bool>(op
); });
243 void Worklist::push(Operation
*op
) {
244 assert(op
&& "cannot push nullptr to worklist");
245 // Check to see if the worklist already contains this op.
246 if (!map
.insert({op
, list
.size()}).second
)
251 Operation
*Worklist::pop() {
252 assert(!empty() && "cannot pop from empty worklist");
253 // Skip and remove all trailing nullptr.
256 Operation
*op
= list
.back();
259 // Cleanup: Remove all trailing nullptr.
260 while (!list
.empty() && !list
.back())
265 void Worklist::remove(Operation
*op
) {
266 assert(op
&& "cannot remove nullptr from worklist");
267 auto it
= map
.find(op
);
268 if (it
!= map
.end()) {
269 assert(list
[it
->second
] == op
&& "malformed worklist data structure");
270 list
[it
->second
] = nullptr;
275 void Worklist::reverse() {
276 std::reverse(list
.begin(), list
.end());
277 for (size_t i
= 0, e
= list
.size(); i
!= e
; ++i
)
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282 /// A worklist that pops elements at a random position. This worklist is for
283 /// testing/debugging purposes only. It can be used to ensure that lowering
284 /// pipelines work correctly regardless of the order in which ops are processed
285 /// by the GreedyPatternRewriteDriver.
286 class RandomizedWorklist
: public Worklist
{
288 RandomizedWorklist() : Worklist() {
289 generator
.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
);
292 /// Pop a random non-empty op from the worklist.
294 Operation
*op
= nullptr;
296 assert(!list
.empty() && "cannot pop from empty worklist");
297 int64_t pos
= generator() % list
.size();
299 list
.erase(list
.begin() + pos
);
300 for (int64_t i
= pos
, e
= list
.size(); i
< e
; ++i
)
308 std::minstd_rand0 generator
;
310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317 /// applies the locally optimal patterns.
319 /// This abstract class manages the worklist and contains helper methods for
320 /// rewriting ops on the worklist. Derived classes specify how ops are added
321 /// to the worklist in the beginning.
322 class GreedyPatternRewriteDriver
: public RewriterBase::Listener
{
324 explicit GreedyPatternRewriteDriver(MLIRContext
*ctx
,
325 const FrozenRewritePatternSet
&patterns
,
326 const GreedyRewriteConfig
&config
);
328 /// Add the given operation to the worklist.
329 void addSingleOpToWorklist(Operation
*op
);
331 /// Add the given operation and its ancestors to the worklist.
332 void addToWorklist(Operation
*op
);
334 /// Notify the driver that the specified operation may have been modified
335 /// in-place. The operation is added to the worklist.
336 void notifyOperationModified(Operation
*op
) override
;
338 /// Notify the driver that the specified operation was inserted. Update the
339 /// worklist as needed: The operation is enqueued depending on scope and
341 void notifyOperationInserted(Operation
*op
,
342 OpBuilder::InsertPoint previous
) override
;
344 /// Notify the driver that the specified operation was removed. Update the
345 /// worklist as needed: The operation and its children are removed from the
347 void notifyOperationErased(Operation
*op
) override
;
349 /// Notify the driver that the specified operation was replaced. Update the
350 /// worklist as needed: New users are enqueued.
351 void notifyOperationReplaced(Operation
*op
, ValueRange replacement
) override
;
353 /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354 /// reached. Return `true` if any IR was changed.
355 bool processWorklist();
357 /// The pattern rewriter that is used for making IR modifications and is
358 /// passed to rewrite patterns.
359 PatternRewriter rewriter
;
361 /// The worklist for this transformation keeps track of the operations that
362 /// need to be (re)visited.
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 RandomizedWorklist worklist
;
367 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
369 /// Configuration information for how to simplify.
370 const GreedyRewriteConfig config
;
372 /// The list of ops we are restricting our rewrites to. These include the
373 /// supplied set of ops as well as new ops created while rewriting those ops
374 /// depending on `strictMode`. This set is not maintained when
375 /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376 llvm::SmallDenseSet
<Operation
*, 4> strictModeFilteredOps
;
379 /// Look over the provided operands for any defining operations that should
380 /// be re-added to the worklist. This function should be called when an
381 /// operation is modified or removed, as it may trigger further
383 void addOperandsToWorklist(Operation
*op
);
385 /// Notify the driver that the given block was inserted.
386 void notifyBlockInserted(Block
*block
, Region
*previous
,
387 Region::iterator previousIt
) override
;
389 /// Notify the driver that the given block is about to be removed.
390 void notifyBlockErased(Block
*block
) override
;
392 /// For debugging only: Notify the driver of a pattern match failure.
394 notifyMatchFailure(Location loc
,
395 function_ref
<void(Diagnostic
&)> reasonCallback
) override
;
398 /// A logger used to emit information during the application process.
399 llvm::ScopedPrinter logger
{llvm::dbgs()};
402 /// The low-level pattern applicator.
403 PatternApplicator matcher
;
405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
406 ExpensiveChecks expensiveChecks
;
407 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
412 MLIRContext
*ctx
, const FrozenRewritePatternSet
&patterns
,
413 const GreedyRewriteConfig
&config
)
414 : rewriter(ctx
), config(config
), matcher(patterns
)
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
419 /*topLevel=*/config
.scope
? config
.scope
->getParentOp() : nullptr)
421 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
423 // Apply a simple cost model based solely on pattern benefit.
424 matcher
.applyDefaultCostModel();
427 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
428 // Send IR notifications to the debug handler. This handler will then forward
429 // all notifications to this GreedyPatternRewriteDriver.
430 rewriter
.setListener(&expensiveChecks
);
432 rewriter
.setListener(this);
433 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
436 bool GreedyPatternRewriteDriver::processWorklist() {
438 const char *logLineComment
=
439 "//===-------------------------------------------===//\n";
441 /// A utility function to log a process result for the given reason.
442 auto logResult
= [&](StringRef result
, const llvm::Twine
&msg
= {}) {
444 logger
.startLine() << "} -> " << result
;
445 if (!msg
.isTriviallyEmpty())
446 logger
.getOStream() << " : " << msg
;
447 logger
.getOStream() << "\n";
449 auto logResultWithLine
= [&](StringRef result
, const llvm::Twine
&msg
= {}) {
450 logResult(result
, msg
);
451 logger
.startLine() << logLineComment
;
455 bool changed
= false;
456 int64_t numRewrites
= 0;
457 while (!worklist
.empty() &&
458 (numRewrites
< config
.maxNumRewrites
||
459 config
.maxNumRewrites
== GreedyRewriteConfig::kNoLimit
)) {
460 auto *op
= worklist
.pop();
463 logger
.getOStream() << "\n";
464 logger
.startLine() << logLineComment
;
465 logger
.startLine() << "Processing operation : '" << op
->getName() << "'("
469 // If the operation has no regions, just print it here.
470 if (op
->getNumRegions() == 0) {
473 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
474 logger
.getOStream() << "\n\n";
478 // If the operation is trivially dead - remove it.
479 if (isOpTriviallyDead(op
)) {
480 rewriter
.eraseOp(op
);
483 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
487 // Try to fold this op. Do not fold constant ops. That would lead to an
488 // infinite folding loop, as every constant op would be folded to an
489 // Attribute and then immediately be rematerialized as a constant op, which
490 // is then put on the worklist.
491 if (config
.fold
&& !op
->hasTrait
<OpTrait::ConstantLike
>()) {
492 SmallVector
<OpFoldResult
> foldResults
;
493 if (succeeded(op
->fold(foldResults
))) {
494 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
496 Operation
*dumpRootOp
= getDumpRootOp(op
);
498 if (foldResults
.empty()) {
499 // Op was modified in-place.
500 notifyOperationModified(op
);
502 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp
));
503 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
504 expensiveChecks
.notifyFoldingSuccess();
505 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
509 // Op results can be replaced with `foldResults`.
510 assert(foldResults
.size() == op
->getNumResults() &&
511 "folder produced incorrect number of results");
512 OpBuilder::InsertionGuard
g(rewriter
);
513 rewriter
.setInsertionPoint(op
);
514 SmallVector
<Value
> replacements
;
515 bool materializationSucceeded
= true;
516 for (auto [ofr
, resultType
] :
517 llvm::zip_equal(foldResults
, op
->getResultTypes())) {
518 if (auto value
= ofr
.dyn_cast
<Value
>()) {
519 assert(value
.getType() == resultType
&&
520 "folder produced value of incorrect type");
521 replacements
.push_back(value
);
524 // Materialize Attributes as SSA values.
525 Operation
*constOp
= op
->getDialect()->materializeConstant(
526 rewriter
, cast
<Attribute
>(ofr
), resultType
, op
->getLoc());
529 // If materialization fails, cleanup any operations generated for
530 // the previous results.
531 llvm::SmallDenseSet
<Operation
*> replacementOps
;
532 for (Value replacement
: replacements
) {
533 assert(replacement
.use_empty() &&
534 "folder reused existing op for one result but constant "
535 "materialization failed for another result");
536 replacementOps
.insert(replacement
.getDefiningOp());
538 for (Operation
*op
: replacementOps
) {
539 rewriter
.eraseOp(op
);
542 materializationSucceeded
= false;
546 assert(constOp
->hasTrait
<OpTrait::ConstantLike
>() &&
547 "materializeConstant produced op that is not a ConstantLike");
548 assert(constOp
->getResultTypes()[0] == resultType
&&
549 "materializeConstant produced incorrect result type");
550 replacements
.push_back(constOp
->getResult(0));
553 if (materializationSucceeded
) {
554 rewriter
.replaceOp(op
, replacements
);
556 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp
));
557 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
558 expensiveChecks
.notifyFoldingSuccess();
559 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
565 // Try to match one of the patterns. The rewriter is automatically
566 // notified of any necessary changes, so there is nothing else to do
568 auto canApplyCallback
= [&](const Pattern
&pattern
) {
570 logger
.getOStream() << "\n";
571 logger
.startLine() << "* Pattern " << pattern
.getDebugName() << " : '"
572 << op
->getName() << " -> (";
573 llvm::interleaveComma(pattern
.getGeneratedOps(), logger
.getOStream());
574 logger
.getOStream() << ")' {\n";
578 config
.listener
->notifyPatternBegin(pattern
, op
);
581 function_ref
<bool(const Pattern
&)> canApply
= canApplyCallback
;
582 auto onFailureCallback
= [&](const Pattern
&pattern
) {
583 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
585 config
.listener
->notifyPatternEnd(pattern
, failure());
587 function_ref
<void(const Pattern
&)> onFailure
= onFailureCallback
;
588 auto onSuccessCallback
= [&](const Pattern
&pattern
) {
589 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
591 config
.listener
->notifyPatternEnd(pattern
, success());
594 function_ref
<LogicalResult(const Pattern
&)> onSuccess
= onSuccessCallback
;
597 // Optimization: PatternApplicator callbacks are not needed when running in
598 // optimized mode and without a listener.
599 if (!config
.listener
) {
606 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
608 expensiveChecks
.computeFingerPrints(config
.scope
->getParentOp());
610 auto clearFingerprints
=
611 llvm::make_scope_exit([&]() { expensiveChecks
.clear(); });
612 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
614 LogicalResult matchResult
=
615 matcher
.matchAndRewrite(op
, rewriter
, canApply
, onFailure
, onSuccess
);
617 if (succeeded(matchResult
)) {
618 LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
619 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
620 expensiveChecks
.notifyRewriteSuccess();
621 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
625 LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
626 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627 expensiveChecks
.notifyRewriteFailure();
628 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
635 void GreedyPatternRewriteDriver::addToWorklist(Operation
*op
) {
636 assert(op
&& "expected valid op");
637 // Gather potential ancestors while looking for a "scope" parent region.
638 SmallVector
<Operation
*, 8> ancestors
;
639 Region
*region
= nullptr;
641 ancestors
.push_back(op
);
642 region
= op
->getParentRegion();
643 if (config
.scope
== region
) {
644 // Scope (can be `nullptr`) was reached. Stop traversal and enqueue ops.
645 for (Operation
*op
: ancestors
)
646 addSingleOpToWorklist(op
);
649 if (region
== nullptr)
651 } while ((op
= region
->getParentOp()));
654 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation
*op
) {
655 if (config
.strictMode
== GreedyRewriteStrictness::AnyOp
||
656 strictModeFilteredOps
.contains(op
))
660 void GreedyPatternRewriteDriver::notifyBlockInserted(
661 Block
*block
, Region
*previous
, Region::iterator previousIt
) {
663 config
.listener
->notifyBlockInserted(block
, previous
, previousIt
);
666 void GreedyPatternRewriteDriver::notifyBlockErased(Block
*block
) {
668 config
.listener
->notifyBlockErased(block
);
671 void GreedyPatternRewriteDriver::notifyOperationInserted(
672 Operation
*op
, OpBuilder::InsertPoint previous
) {
674 logger
.startLine() << "** Insert : '" << op
->getName() << "'(" << op
678 config
.listener
->notifyOperationInserted(op
, previous
);
679 if (config
.strictMode
== GreedyRewriteStrictness::ExistingAndNewOps
)
680 strictModeFilteredOps
.insert(op
);
684 void GreedyPatternRewriteDriver::notifyOperationModified(Operation
*op
) {
686 logger
.startLine() << "** Modified: '" << op
->getName() << "'(" << op
690 config
.listener
->notifyOperationModified(op
);
694 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation
*op
) {
695 for (Value operand
: op
->getOperands()) {
696 // If this operand currently has at most 2 users, add its defining op to the
697 // worklist. Indeed, after the op is deleted, then the operand will have at
698 // most 1 user left. If it has 0 users left, it can be deleted too,
699 // and if it has 1 user left, there may be further canonicalization
704 auto *defOp
= operand
.getDefiningOp();
708 Operation
*otherUser
= nullptr;
709 bool hasMoreThanTwoUses
= false;
710 for (auto user
: operand
.getUsers()) {
711 if (user
== op
|| user
== otherUser
)
717 hasMoreThanTwoUses
= true;
720 if (hasMoreThanTwoUses
)
723 addToWorklist(defOp
);
727 void GreedyPatternRewriteDriver::notifyOperationErased(Operation
*op
) {
729 logger
.startLine() << "** Erase : '" << op
->getName() << "'(" << op
734 // Only ops that are within the configured scope are added to the worklist of
735 // the greedy pattern rewriter. Moreover, the parent op of the scope region is
736 // the part of the IR that is taken into account for the "expensive checks".
737 // A greedy pattern rewrite is not allowed to erase the parent op of the scope
738 // region, as that would break the worklist handling and the expensive checks.
739 if (config
.scope
&& config
.scope
->getParentOp() == op
)
741 "scope region must not be erased during greedy pattern rewrite");
745 config
.listener
->notifyOperationErased(op
);
747 addOperandsToWorklist(op
);
750 if (config
.strictMode
!= GreedyRewriteStrictness::AnyOp
)
751 strictModeFilteredOps
.erase(op
);
754 void GreedyPatternRewriteDriver::notifyOperationReplaced(
755 Operation
*op
, ValueRange replacement
) {
757 logger
.startLine() << "** Replace : '" << op
->getName() << "'(" << op
761 config
.listener
->notifyOperationReplaced(op
, replacement
);
764 void GreedyPatternRewriteDriver::notifyMatchFailure(
765 Location loc
, function_ref
<void(Diagnostic
&)> reasonCallback
) {
767 Diagnostic
diag(loc
, DiagnosticSeverity::Remark
);
768 reasonCallback(diag
);
769 logger
.startLine() << "** Match Failure : " << diag
.str() << "\n";
772 config
.listener
->notifyMatchFailure(loc
, reasonCallback
);
775 //===----------------------------------------------------------------------===//
776 // RegionPatternRewriteDriver
777 //===----------------------------------------------------------------------===//
780 /// This driver simplifies all ops in a region.
781 class RegionPatternRewriteDriver
: public GreedyPatternRewriteDriver
{
783 explicit RegionPatternRewriteDriver(MLIRContext
*ctx
,
784 const FrozenRewritePatternSet
&patterns
,
785 const GreedyRewriteConfig
&config
,
788 /// Simplify ops inside `region` and simplify the region itself. Return
789 /// success if the transformation converged.
790 LogicalResult
simplify(bool *changed
) &&;
793 /// The region that is simplified.
798 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
799 MLIRContext
*ctx
, const FrozenRewritePatternSet
&patterns
,
800 const GreedyRewriteConfig
&config
, Region
®ion
)
801 : GreedyPatternRewriteDriver(ctx
, patterns
, config
), region(region
) {
802 // Populate strict mode ops.
803 if (config
.strictMode
!= GreedyRewriteStrictness::AnyOp
) {
804 region
.walk([&](Operation
*op
) { strictModeFilteredOps
.insert(op
); });
809 class GreedyPatternRewriteIteration
810 : public tracing::ActionImpl
<GreedyPatternRewriteIteration
> {
812 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration
)
813 GreedyPatternRewriteIteration(ArrayRef
<IRUnit
> units
, int64_t iteration
)
814 : tracing::ActionImpl
<GreedyPatternRewriteIteration
>(units
),
815 iteration(iteration
) {}
816 static constexpr StringLiteral tag
= "GreedyPatternRewriteIteration";
817 void print(raw_ostream
&os
) const override
{
818 os
<< "GreedyPatternRewriteIteration(" << iteration
<< ")";
822 int64_t iteration
= 0;
826 LogicalResult
RegionPatternRewriteDriver::simplify(bool *changed
) && {
827 bool continueRewrites
= false;
828 int64_t iteration
= 0;
829 MLIRContext
*ctx
= rewriter
.getContext();
831 // Check if the iteration limit was reached.
832 if (++iteration
> config
.maxIterations
&&
833 config
.maxIterations
!= GreedyRewriteConfig::kNoLimit
)
836 // New iteration: start with an empty worklist.
839 // `OperationFolder` CSE's constant ops (and may move them into parents
840 // regions to enable more aggressive CSE'ing).
841 OperationFolder
folder(ctx
, this);
842 auto insertKnownConstant
= [&](Operation
*op
) {
843 // Check for existing constants when populating the worklist. This avoids
844 // accidentally reversing the constant order during processing.
845 Attribute constValue
;
846 if (matchPattern(op
, m_Constant(&constValue
)))
847 if (!folder
.insertKnownConstant(op
, constValue
))
852 if (!config
.useTopDownTraversal
) {
853 // Add operations to the worklist in postorder.
854 region
.walk([&](Operation
*op
) {
855 if (!config
.cseConstants
|| !insertKnownConstant(op
))
859 // Add all nested operations to the worklist in preorder.
860 region
.walk
<WalkOrder::PreOrder
>([&](Operation
*op
) {
861 if (!config
.cseConstants
|| !insertKnownConstant(op
)) {
863 return WalkResult::advance();
865 return WalkResult::skip();
868 // Reverse the list so our pop-back loop processes them in-order.
872 ctx
->executeAction
<GreedyPatternRewriteIteration
>(
874 continueRewrites
= processWorklist();
876 // After applying patterns, make sure that the CFG of each of the
877 // regions is kept up to date.
878 if (config
.enableRegionSimplification
!=
879 GreedySimplifyRegionLevel::Disabled
) {
880 continueRewrites
|= succeeded(simplifyRegions(
882 /*mergeBlocks=*/config
.enableRegionSimplification
==
883 GreedySimplifyRegionLevel::Aggressive
));
886 {®ion
}, iteration
);
887 } while (continueRewrites
);
890 *changed
= iteration
> 1;
892 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
893 return success(!continueRewrites
);
897 mlir::applyPatternsGreedily(Region
®ion
,
898 const FrozenRewritePatternSet
&patterns
,
899 GreedyRewriteConfig config
, bool *changed
) {
900 // The top-level operation must be known to be isolated from above to
901 // prevent performing canonicalizations on operations defined at or above
902 // the region containing 'op'.
903 assert(region
.getParentOp()->hasTrait
<OpTrait::IsIsolatedFromAbove
>() &&
904 "patterns can only be applied to operations IsolatedFromAbove");
906 // Set scope if not specified.
908 config
.scope
= ®ion
;
910 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
911 if (failed(verify(config
.scope
->getParentOp())))
912 llvm::report_fatal_error(
913 "greedy pattern rewriter input IR failed to verify");
914 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
916 // Start the pattern driver.
917 RegionPatternRewriteDriver
driver(region
.getContext(), patterns
, config
,
919 LogicalResult converged
= std::move(driver
).simplify(changed
);
920 LLVM_DEBUG(if (failed(converged
)) {
921 llvm::dbgs() << "The pattern rewrite did not converge after scanning "
922 << config
.maxIterations
<< " times\n";
927 //===----------------------------------------------------------------------===//
928 // MultiOpPatternRewriteDriver
929 //===----------------------------------------------------------------------===//
932 /// This driver simplifies a list of ops.
933 class MultiOpPatternRewriteDriver
: public GreedyPatternRewriteDriver
{
935 explicit MultiOpPatternRewriteDriver(
936 MLIRContext
*ctx
, const FrozenRewritePatternSet
&patterns
,
937 const GreedyRewriteConfig
&config
, ArrayRef
<Operation
*> ops
,
938 llvm::SmallDenseSet
<Operation
*, 4> *survivingOps
= nullptr);
940 /// Simplify `ops`. Return `success` if the transformation converged.
941 LogicalResult
simplify(ArrayRef
<Operation
*> ops
, bool *changed
= nullptr) &&;
944 void notifyOperationErased(Operation
*op
) override
{
945 GreedyPatternRewriteDriver::notifyOperationErased(op
);
947 survivingOps
->erase(op
);
950 /// An optional set of ops that survived the rewrite. This set is populated
951 /// at the beginning of `simplifyLocally` with the initially provided list
953 llvm::SmallDenseSet
<Operation
*, 4> *const survivingOps
= nullptr;
957 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
958 MLIRContext
*ctx
, const FrozenRewritePatternSet
&patterns
,
959 const GreedyRewriteConfig
&config
, ArrayRef
<Operation
*> ops
,
960 llvm::SmallDenseSet
<Operation
*, 4> *survivingOps
)
961 : GreedyPatternRewriteDriver(ctx
, patterns
, config
),
962 survivingOps(survivingOps
) {
963 if (config
.strictMode
!= GreedyRewriteStrictness::AnyOp
)
964 strictModeFilteredOps
.insert(ops
.begin(), ops
.end());
967 survivingOps
->clear();
968 survivingOps
->insert(ops
.begin(), ops
.end());
972 LogicalResult
MultiOpPatternRewriteDriver::simplify(ArrayRef
<Operation
*> ops
,
974 // Populate the initial worklist.
975 for (Operation
*op
: ops
)
976 addSingleOpToWorklist(op
);
978 // Process ops on the worklist.
979 bool result
= processWorklist();
983 return success(worklist
.empty());
986 /// Find the region that is the closest common ancestor of all given ops.
988 /// Note: This function returns `nullptr` if there is a top-level op among the
989 /// given list of ops.
990 static Region
*findCommonAncestor(ArrayRef
<Operation
*> ops
) {
991 assert(!ops
.empty() && "expected at least one op");
992 // Fast path in case there is only one op.
994 return ops
.front()->getParentRegion();
996 Region
*region
= ops
.front()->getParentRegion();
997 ops
= ops
.drop_front();
999 llvm::BitVector
remainingOps(sz
, true);
1002 // Iterate over all remaining ops.
1003 while ((pos
= remainingOps
.find_first_in(pos
+ 1, sz
)) != -1) {
1004 // Is this op contained in `region`?
1005 if (region
->findAncestorOpInRegion(*ops
[pos
]))
1006 remainingOps
.reset(pos
);
1008 if (remainingOps
.none())
1010 region
= region
->getParentRegion();
1015 LogicalResult
mlir::applyOpPatternsGreedily(
1016 ArrayRef
<Operation
*> ops
, const FrozenRewritePatternSet
&patterns
,
1017 GreedyRewriteConfig config
, bool *changed
, bool *allErased
) {
1026 // Determine scope of rewrite.
1027 if (!config
.scope
) {
1028 // Compute scope if none was provided. The scope will remain `nullptr` if
1029 // there is a top-level op among `ops`.
1030 config
.scope
= findCommonAncestor(ops
);
1032 // If a scope was provided, make sure that all ops are in scope.
1034 bool allOpsInScope
= llvm::all_of(ops
, [&](Operation
*op
) {
1035 return static_cast<bool>(config
.scope
->findAncestorOpInRegion(*op
));
1037 assert(allOpsInScope
&& "ops must be within the specified scope");
1041 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1042 if (config
.scope
&& failed(verify(config
.scope
->getParentOp())))
1043 llvm::report_fatal_error(
1044 "greedy pattern rewriter input IR failed to verify");
1045 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1047 // Start the pattern driver.
1048 llvm::SmallDenseSet
<Operation
*, 4> surviving
;
1049 MultiOpPatternRewriteDriver
driver(ops
.front()->getContext(), patterns
,
1051 allErased
? &surviving
: nullptr);
1052 LogicalResult converged
= std::move(driver
).simplify(ops
, changed
);
1054 *allErased
= surviving
.empty();
1055 LLVM_DEBUG(if (failed(converged
)) {
1056 llvm::dbgs() << "The pattern rewrite did not converge after "
1057 << config
.maxNumRewrites
<< " rewrites";