[OpenMP] Adjust 'printf' handling in the OpenMP runtime (#123670)
[llvm-project.git] / mlir / lib / Transforms / Utils / GreedyPatternRewriteDriver.cpp
blob969c560c99ab7cb755d3b55cbc6f6e66d8a4c37c
1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsGreedily.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 #include "mlir/Config/mlir-config.h"
16 #include "mlir/IR/Action.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Verifier.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Rewrite/PatternApplicator.h"
21 #include "mlir/Transforms/FoldUtils.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/BitVector.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
31 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32 #include <random>
33 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
35 using namespace mlir;
37 #define DEBUG_TYPE "greedy-rewriter"
39 namespace {
41 //===----------------------------------------------------------------------===//
42 // Debugging Infrastructure
43 //===----------------------------------------------------------------------===//
45 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46 /// A helper struct that performs various "expensive checks" to detect broken
47 /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48 /// broken if:
49 /// * IR does not verify after pattern application / folding.
50 /// * Pattern returns "failure" but the IR has changed.
51 /// * Pattern returns "success" but the IR has not changed.
52 ///
53 /// This struct stores finger prints of ops to determine whether the IR has
54 /// changed or not.
55 struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56 ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57 : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
59 /// Compute finger prints of the given op and its nested ops.
60 void computeFingerPrints(Operation *topLevel) {
61 this->topLevel = topLevel;
62 this->topLevelFingerPrint.emplace(topLevel);
63 topLevel->walk([&](Operation *op) {
64 fingerprints.try_emplace(op, op, /*includeNested=*/false);
65 });
68 /// Clear all finger prints.
69 void clear() {
70 topLevel = nullptr;
71 topLevelFingerPrint.reset();
72 fingerprints.clear();
75 void notifyRewriteSuccess() {
76 if (!topLevel)
77 return;
79 // Make sure that the IR still verifies.
80 if (failed(verify(topLevel)))
81 llvm::report_fatal_error("IR failed to verify after pattern application");
83 // Pattern application success => IR must have changed.
84 OperationFingerPrint afterFingerPrint(topLevel);
85 if (*topLevelFingerPrint == afterFingerPrint) {
86 // Note: Run "mlir-opt -debug" to see which pattern is broken.
87 llvm::report_fatal_error(
88 "pattern returned success but IR did not change");
90 for (const auto &it : fingerprints) {
91 // Skip top-level op, its finger print is never invalidated.
92 if (it.first == topLevel)
93 continue;
94 // Note: Finger print computation may crash when an op was erased
95 // without notifying the rewriter. (Run with ASAN to see where the op was
96 // erased; the op was probably erased directly, bypassing the rewriter
97 // API.) Finger print computation does may not crash if a new op was
98 // created at the same memory location. (But then the finger print should
99 // have changed.)
100 if (it.second !=
101 OperationFingerPrint(it.first, /*includeNested=*/false)) {
102 // Note: Run "mlir-opt -debug" to see which pattern is broken.
103 llvm::report_fatal_error("operation finger print changed");
108 void notifyRewriteFailure() {
109 if (!topLevel)
110 return;
112 // Pattern application failure => IR must not have changed.
113 OperationFingerPrint afterFingerPrint(topLevel);
114 if (*topLevelFingerPrint != afterFingerPrint) {
115 // Note: Run "mlir-opt -debug" to see which pattern is broken.
116 llvm::report_fatal_error("pattern returned failure but IR did change");
120 void notifyFoldingSuccess() {
121 if (!topLevel)
122 return;
124 // Make sure that the IR still verifies.
125 if (failed(verify(topLevel)))
126 llvm::report_fatal_error("IR failed to verify after folding");
129 protected:
130 /// Invalidate the finger print of the given op, i.e., remove it from the map.
131 void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
133 void notifyBlockErased(Block *block) override {
134 RewriterBase::ForwardingListener::notifyBlockErased(block);
136 // The block structure (number of blocks, types of block arguments, etc.)
137 // is part of the fingerprint of the parent op.
138 // TODO: The parent op fingerprint should also be invalidated when modifying
139 // the block arguments of a block, but we do not have a
140 // `notifyBlockModified` callback yet.
141 invalidateFingerPrint(block->getParentOp());
144 void notifyOperationInserted(Operation *op,
145 OpBuilder::InsertPoint previous) override {
146 RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
147 invalidateFingerPrint(op->getParentOp());
150 void notifyOperationModified(Operation *op) override {
151 RewriterBase::ForwardingListener::notifyOperationModified(op);
152 invalidateFingerPrint(op);
155 void notifyOperationErased(Operation *op) override {
156 RewriterBase::ForwardingListener::notifyOperationErased(op);
157 op->walk([this](Operation *op) { invalidateFingerPrint(op); });
160 /// Operation finger prints to detect invalid pattern API usage. IR is checked
161 /// against these finger prints after pattern application to detect cases
162 /// where IR was modified directly, bypassing the rewriter API.
163 DenseMap<Operation *, OperationFingerPrint> fingerprints;
165 /// Top-level operation of the current greedy rewrite.
166 Operation *topLevel = nullptr;
168 /// Finger print of the top-level operation.
169 std::optional<OperationFingerPrint> topLevelFingerPrint;
171 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
173 #ifndef NDEBUG
174 static Operation *getDumpRootOp(Operation *op) {
175 // Dump the parent op so that materialized constants are visible. If the op
176 // is a top-level op, dump it directly.
177 if (Operation *parentOp = op->getParentOp())
178 return parentOp;
179 return op;
181 static void logSuccessfulFolding(Operation *op) {
182 llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183 op->dump();
184 llvm::dbgs() << "\n\n";
186 #endif // NDEBUG
188 //===----------------------------------------------------------------------===//
189 // Worklist
190 //===----------------------------------------------------------------------===//
192 /// A LIFO worklist of operations with efficient removal and set semantics.
194 /// This class maintains a vector of operations and a mapping of operations to
195 /// positions in the vector, so that operations can be removed efficiently at
196 /// random. When an operation is removed, its slot is overwritten with nullptr.
197 /// Such nullptr slots are skipped when pop'ing elements.
198 class Worklist {
199 public:
/// Create an empty worklist with pre-reserved storage.
200 Worklist();
202 /// Clear the worklist (both the vector and the position mapping).
203 void clear();
205 /// Return whether the worklist is empty. nullptr slots left behind by
/// removed operations do not count as elements.
206 bool empty() const;
208 /// Push an operation to the end of the worklist, unless the operation is
209 /// already on the worklist. `op` must not be nullptr.
210 void push(Operation *op);
212 /// Pop an operation from the end of the worklist. Only allowed on
213 /// non-empty worklists.
214 Operation *pop();
216 /// Remove an operation from the worklist. Does nothing if the operation is
/// not on the worklist. `op` must not be nullptr.
217 void remove(Operation *op);
219 /// Reverse the order of the worklist and update the position mapping.
220 void reverse();
222 protected:
223 /// The worklist of operations. May contain nullptr slots for removed ops.
224 std::vector<Operation *> list;
226 /// A mapping of operations to positions in `list`.
227 DenseMap<Operation *, unsigned> map;
230 Worklist::Worklist() { list.reserve(64); }
232 void Worklist::clear() {
233 list.clear();
234 map.clear();
237 bool Worklist::empty() const {
238 // Skip all nullptr.
239 return !llvm::any_of(list,
240 [](Operation *op) { return static_cast<bool>(op); });
243 void Worklist::push(Operation *op) {
244 assert(op && "cannot push nullptr to worklist");
245 // Check to see if the worklist already contains this op.
246 if (!map.insert({op, list.size()}).second)
247 return;
248 list.push_back(op);
251 Operation *Worklist::pop() {
252 assert(!empty() && "cannot pop from empty worklist");
253 // Skip and remove all trailing nullptr.
254 while (!list.back())
255 list.pop_back();
256 Operation *op = list.back();
257 list.pop_back();
258 map.erase(op);
259 // Cleanup: Remove all trailing nullptr.
260 while (!list.empty() && !list.back())
261 list.pop_back();
262 return op;
265 void Worklist::remove(Operation *op) {
266 assert(op && "cannot remove nullptr from worklist");
267 auto it = map.find(op);
268 if (it != map.end()) {
269 assert(list[it->second] == op && "malformed worklist data structure");
270 list[it->second] = nullptr;
271 map.erase(it);
275 void Worklist::reverse() {
276 std::reverse(list.begin(), list.end());
277 for (size_t i = 0, e = list.size(); i != e; ++i)
278 map[list[i]] = i;
281 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282 /// A worklist that pops elements at a random position. This worklist is for
283 /// testing/debugging purposes only. It can be used to ensure that lowering
284 /// pipelines work correctly regardless of the order in which ops are processed
285 /// by the GreedyPatternRewriteDriver.
286 class RandomizedWorklist : public Worklist {
287 public:
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
292 /// Pop a random non-empty op from the worklist.
293 Operation *pop() {
294 Operation *op = nullptr;
295 do {
296 assert(!list.empty() && "cannot pop from empty worklist");
297 int64_t pos = generator() % list.size();
298 op = list[pos];
299 list.erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
301 map[list[i]] = i;
302 map.erase(op);
303 } while (!op);
304 return op;
307 private:
308 std::minstd_rand0 generator;
310 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
312 //===----------------------------------------------------------------------===//
313 // GreedyPatternRewriteDriver
314 //===----------------------------------------------------------------------===//
316 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317 /// applies the locally optimal patterns.
319 /// This abstract class manages the worklist and contains helper methods for
320 /// rewriting ops on the worklist. Derived classes specify how ops are added
321 /// to the worklist in the beginning.
322 class GreedyPatternRewriteDriver : public RewriterBase::Listener {
323 protected:
/// Set up the rewriter, the pattern applicator and the listener for the
/// given context, patterns and configuration.
324 explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
325 const FrozenRewritePatternSet &patterns,
326 const GreedyRewriteConfig &config);
328 /// Add the given operation to the worklist. In strict mode, the operation
/// is only added if it is part of `strictModeFilteredOps`.
329 void addSingleOpToWorklist(Operation *op);
331 /// Add the given operation and its ancestors to the worklist. Nothing is
/// added if the configured scope is not reached when walking up the parents.
332 void addToWorklist(Operation *op);
334 /// Notify the driver that the specified operation may have been modified
335 /// in-place. The operation is added to the worklist.
336 void notifyOperationModified(Operation *op) override;
338 /// Notify the driver that the specified operation was inserted. Update the
339 /// worklist as needed: The operation is enqueued depending on scope and
340 /// strict mode.
341 void notifyOperationInserted(Operation *op,
342 OpBuilder::InsertPoint previous) override;
344 /// Notify the driver that the specified operation was removed. Update the
345 /// worklist as needed: The operation is removed from the worklist and the
346 /// defining ops of its operands may be enqueued.
347 void notifyOperationErased(Operation *op) override;
349 /// Notify the driver that the specified operation was replaced. The
350 /// notification is logged and forwarded to `config.listener`, if set.
351 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
353 /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354 /// reached. Return `true` if any IR was changed.
355 bool processWorklist();
357 /// The pattern rewriter that is used for making IR modifications and is
358 /// passed to rewrite patterns.
359 PatternRewriter rewriter;
361 /// The worklist for this transformation keeps track of the operations that
362 /// need to be (re)visited.
363 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364 RandomizedWorklist worklist;
365 #else
366 Worklist worklist;
367 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
369 /// Configuration information for how to simplify.
370 const GreedyRewriteConfig config;
372 /// The list of ops we are restricting our rewrites to. These include the
373 /// supplied set of ops as well as new ops created while rewriting those ops
374 /// depending on `strictMode`. This set is not maintained when
375 /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
376 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
378 private:
379 /// Look over the operands of the given operation for any defining
380 /// operations that should be re-added to the worklist. This function should
381 /// be called when an operation is modified or removed, as it may trigger
382 /// further simplifications.
383 void addOperandsToWorklist(Operation *op);
385 /// Notify the driver that the given block was inserted. The notification is
/// forwarded to `config.listener`, if set.
386 void notifyBlockInserted(Block *block, Region *previous,
387 Region::iterator previousIt) override;
389 /// Notify the driver that the given block is about to be removed. The
/// notification is forwarded to `config.listener`, if set.
390 void notifyBlockErased(Block *block) override;
392 /// For debugging only: Notify the driver of a pattern match failure.
393 void
394 notifyMatchFailure(Location loc,
395 function_ref<void(Diagnostic &)> reasonCallback) override;
397 #ifndef NDEBUG
398 /// A logger used to emit information during the application process.
399 llvm::ScopedPrinter logger{llvm::dbgs()};
400 #endif
402 /// The low-level pattern applicator.
403 PatternApplicator matcher;
405 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Debug listener that sits between the rewriter and this driver and
/// performs the expensive pattern API checks.
406 ExpensiveChecks expensiveChecks;
407 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
409 } // namespace
/// Construct the driver: create the rewriter for `ctx`, store a copy of
/// `config`, set up the pattern applicator for `patterns` with the default
/// cost model, and register the listener that receives IR notifications from
/// the rewriter.
411 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
412 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
413 const GreedyRewriteConfig &config)
414 : rewriter(ctx), config(config), matcher(patterns)
415 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// The expensive checks start out with the parent op of the scope region as
// top-level op, or without a top-level op if no scope is configured.
416 // clang-format off
417 , expensiveChecks(
418 /*driver=*/this,
419 /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
420 // clang-format on
421 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
423 // Apply a simple cost model based solely on pattern benefit.
424 matcher.applyDefaultCostModel();
426 // Set up listener.
427 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
428 // Send IR notifications to the debug handler. This handler will then forward
429 // all notifications to this GreedyPatternRewriteDriver.
430 rewriter.setListener(&expensiveChecks);
431 #else
// Without expensive checks, the driver receives IR notifications directly.
432 rewriter.setListener(this);
433 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Pop ops from the worklist until it is empty or the rewrite limit
/// (`config.maxNumRewrites`, unless it is `kNoLimit`) is reached. Each popped
/// op is handled by the first step that applies:
///   1. A trivially dead op is erased.
///   2. If `config.fold` is set and the op is not ConstantLike, the op is
///      folded (in-place or by replacing its results).
///   3. Otherwise, the pattern applicator tries to match and rewrite the op.
/// Only successful pattern applications (step 3) count towards the rewrite
/// limit. Returns `true` if the IR was changed by any of the steps.
436 bool GreedyPatternRewriteDriver::processWorklist() {
437 #ifndef NDEBUG
438 const char *logLineComment =
439 "//===-------------------------------------------===//\n";
441 /// A utility function to log a process result for the given reason.
442 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
443 logger.unindent();
444 logger.startLine() << "} -> " << result;
445 if (!msg.isTriviallyEmpty())
446 logger.getOStream() << " : " << msg;
447 logger.getOStream() << "\n";
/// Same as `logResult`, followed by a separator line.
449 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
450 logResult(result, msg);
451 logger.startLine() << logLineComment;
453 #endif
// `changed` tracks whether any IR was modified; `numRewrites` counts the
// successful pattern applications.
455 bool changed = false;
456 int64_t numRewrites = 0;
457 while (!worklist.empty() &&
458 (numRewrites < config.maxNumRewrites ||
459 config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
460 auto *op = worklist.pop();
462 LLVM_DEBUG({
463 logger.getOStream() << "\n";
464 logger.startLine() << logLineComment;
465 logger.startLine() << "Processing operation : '" << op->getName() << "'("
466 << op << ") {\n";
467 logger.indent();
469 // If the operation has no regions, just print it here.
470 if (op->getNumRegions() == 0) {
471 op->print(
472 logger.startLine(),
473 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
474 logger.getOStream() << "\n\n";
478 // If the operation is trivially dead - remove it.
479 if (isOpTriviallyDead(op)) {
480 rewriter.eraseOp(op);
481 changed = true;
483 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
484 continue;
487 // Try to fold this op. Do not fold constant ops. That would lead to an
488 // infinite folding loop, as every constant op would be folded to an
489 // Attribute and then immediately be rematerialized as a constant op, which
490 // is then put on the worklist.
491 if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
492 SmallVector<OpFoldResult> foldResults;
493 if (succeeded(op->fold(foldResults))) {
494 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
495 #ifndef NDEBUG
// Determine the op to dump now: `op` itself may be replaced below.
496 Operation *dumpRootOp = getDumpRootOp(op);
497 #endif // NDEBUG
498 if (foldResults.empty()) {
499 // Op was modified in-place.
500 notifyOperationModified(op);
501 changed = true;
502 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
503 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
504 expensiveChecks.notifyFoldingSuccess();
505 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
506 continue;
509 // Op results can be replaced with `foldResults`.
510 assert(foldResults.size() == op->getNumResults() &&
511 "folder produced incorrect number of results");
// New constant ops are created right before `op`; the guard restores the
// previous insertion point when leaving this scope.
512 OpBuilder::InsertionGuard g(rewriter);
513 rewriter.setInsertionPoint(op);
514 SmallVector<Value> replacements;
515 bool materializationSucceeded = true;
516 for (auto [ofr, resultType] :
517 llvm::zip_equal(foldResults, op->getResultTypes())) {
// A fold result that already is a Value is used directly.
518 if (auto value = ofr.dyn_cast<Value>()) {
519 assert(value.getType() == resultType &&
520 "folder produced value of incorrect type");
521 replacements.push_back(value);
522 continue;
524 // Materialize Attributes as SSA values.
525 Operation *constOp = op->getDialect()->materializeConstant(
526 rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
528 if (!constOp) {
529 // If materialization fails, cleanup any operations generated for
530 // the previous results.
531 llvm::SmallDenseSet<Operation *> replacementOps;
532 for (Value replacement : replacements) {
533 assert(replacement.use_empty() &&
534 "folder reused existing op for one result but constant "
535 "materialization failed for another result");
536 replacementOps.insert(replacement.getDefiningOp());
538 for (Operation *op : replacementOps) {
539 rewriter.eraseOp(op);
542 materializationSucceeded = false;
543 break;
546 assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
547 "materializeConstant produced op that is not a ConstantLike");
548 assert(constOp->getResultTypes()[0] == resultType &&
549 "materializeConstant produced incorrect result type");
550 replacements.push_back(constOp->getResult(0));
// If materialization failed, fall through to pattern application below.
553 if (materializationSucceeded) {
554 rewriter.replaceOp(op, replacements);
555 changed = true;
556 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
557 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
558 expensiveChecks.notifyFoldingSuccess();
559 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
560 continue;
565 // Try to match one of the patterns. The rewriter is automatically
566 // notified of any necessary changes, so there is nothing else to do
567 // here.
// The three callbacks below log the progress of the pattern applicator and
// forward the begin/end of each pattern application to `config.listener`.
568 auto canApplyCallback = [&](const Pattern &pattern) {
569 LLVM_DEBUG({
570 logger.getOStream() << "\n";
571 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
572 << op->getName() << " -> (";
573 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
574 logger.getOStream() << ")' {\n";
575 logger.indent();
577 if (config.listener)
578 config.listener->notifyPatternBegin(pattern, op);
579 return true;
581 function_ref<bool(const Pattern &)> canApply = canApplyCallback;
582 auto onFailureCallback = [&](const Pattern &pattern) {
583 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
584 if (config.listener)
585 config.listener->notifyPatternEnd(pattern, failure());
587 function_ref<void(const Pattern &)> onFailure = onFailureCallback;
588 auto onSuccessCallback = [&](const Pattern &pattern) {
589 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
590 if (config.listener)
591 config.listener->notifyPatternEnd(pattern, success());
592 return success();
594 function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
596 #ifdef NDEBUG
597 // Optimization: PatternApplicator callbacks are not needed when running in
598 // optimized mode and without a listener.
599 if (!config.listener) {
600 canApply = nullptr;
601 onFailure = nullptr;
602 onSuccess = nullptr;
604 #endif // NDEBUG
606 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Record finger prints right before the patterns run; they are cleared again
// when this scope is left.
607 if (config.scope) {
608 expensiveChecks.computeFingerPrints(config.scope->getParentOp());
610 auto clearFingerprints =
611 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
612 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
614 LogicalResult matchResult =
615 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
617 if (succeeded(matchResult)) {
618 LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
619 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
620 expensiveChecks.notifyRewriteSuccess();
621 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
622 changed = true;
623 ++numRewrites;
624 } else {
625 LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
626 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627 expensiveChecks.notifyRewriteFailure();
628 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
632 return changed;
635 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
636 assert(op && "expected valid op");
637 // Gather potential ancestors while looking for a "scope" parent region.
638 SmallVector<Operation *, 8> ancestors;
639 Region *region = nullptr;
640 do {
641 ancestors.push_back(op);
642 region = op->getParentRegion();
643 if (config.scope == region) {
644 // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
645 for (Operation *op : ancestors)
646 addSingleOpToWorklist(op);
647 return;
649 if (region == nullptr)
650 return;
651 } while ((op = region->getParentOp()));
654 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
655 if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
656 strictModeFilteredOps.contains(op))
657 worklist.push(op);
660 void GreedyPatternRewriteDriver::notifyBlockInserted(
661 Block *block, Region *previous, Region::iterator previousIt) {
662 if (config.listener)
663 config.listener->notifyBlockInserted(block, previous, previousIt);
666 void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
667 if (config.listener)
668 config.listener->notifyBlockErased(block);
671 void GreedyPatternRewriteDriver::notifyOperationInserted(
672 Operation *op, OpBuilder::InsertPoint previous) {
673 LLVM_DEBUG({
674 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
675 << ")\n";
677 if (config.listener)
678 config.listener->notifyOperationInserted(op, previous);
679 if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
680 strictModeFilteredOps.insert(op);
681 addToWorklist(op);
684 void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
685 LLVM_DEBUG({
686 logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
687 << ")\n";
689 if (config.listener)
690 config.listener->notifyOperationModified(op);
691 addToWorklist(op);
694 void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
695 for (Value operand : op->getOperands()) {
696 // If this operand currently has at most 2 users, add its defining op to the
697 // worklist. Indeed, after the op is deleted, then the operand will have at
698 // most 1 user left. If it has 0 users left, it can be deleted too,
699 // and if it has 1 user left, there may be further canonicalization
700 // opportunities.
701 if (!operand)
702 continue;
704 auto *defOp = operand.getDefiningOp();
705 if (!defOp)
706 continue;
708 Operation *otherUser = nullptr;
709 bool hasMoreThanTwoUses = false;
710 for (auto user : operand.getUsers()) {
711 if (user == op || user == otherUser)
712 continue;
713 if (!otherUser) {
714 otherUser = user;
715 continue;
717 hasMoreThanTwoUses = true;
718 break;
720 if (hasMoreThanTwoUses)
721 continue;
723 addToWorklist(defOp);
727 void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
728 LLVM_DEBUG({
729 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
730 << ")\n";
733 #ifndef NDEBUG
734 // Only ops that are within the configured scope are added to the worklist of
735 // the greedy pattern rewriter. Moreover, the parent op of the scope region is
736 // the part of the IR that is taken into account for the "expensive checks".
737 // A greedy pattern rewrite is not allowed to erase the parent op of the scope
738 // region, as that would break the worklist handling and the expensive checks.
739 if (config.scope && config.scope->getParentOp() == op)
740 llvm_unreachable(
741 "scope region must not be erased during greedy pattern rewrite");
742 #endif // NDEBUG
744 if (config.listener)
745 config.listener->notifyOperationErased(op);
747 addOperandsToWorklist(op);
748 worklist.remove(op);
750 if (config.strictMode != GreedyRewriteStrictness::AnyOp)
751 strictModeFilteredOps.erase(op);
754 void GreedyPatternRewriteDriver::notifyOperationReplaced(
755 Operation *op, ValueRange replacement) {
756 LLVM_DEBUG({
757 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
758 << ")\n";
760 if (config.listener)
761 config.listener->notifyOperationReplaced(op, replacement);
764 void GreedyPatternRewriteDriver::notifyMatchFailure(
765 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
766 LLVM_DEBUG({
767 Diagnostic diag(loc, DiagnosticSeverity::Remark);
768 reasonCallback(diag);
769 logger.startLine() << "** Match Failure : " << diag.str() << "\n";
771 if (config.listener)
772 config.listener->notifyMatchFailure(loc, reasonCallback);
775 //===----------------------------------------------------------------------===//
776 // RegionPatternRewriteDriver
777 //===----------------------------------------------------------------------===//
779 namespace {
780 /// This driver simplifies all ops in a region.
781 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
782 public:
/// Create a driver for the given region. Unless the strictness is `AnyOp`,
/// all ops nested in the region are registered as ops that may be rewritten.
783 explicit RegionPatternRewriteDriver(MLIRContext *ctx,
784 const FrozenRewritePatternSet &patterns,
785 const GreedyRewriteConfig &config,
786 Region &regions);
788 /// Simplify ops inside `region` and simplify the region itself. Return
789 /// success if the transformation converged. If `changed` is not null, it is
/// set to whether the IR was changed.
790 LogicalResult simplify(bool *changed) &&;
792 private:
793 /// The region that is simplified.
794 Region &region;
796 } // namespace
798 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
799 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
800 const GreedyRewriteConfig &config, Region &region)
801 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
802 // Populate strict mode ops.
803 if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
804 region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
808 namespace {
/// Action that wraps one iteration of the region-based greedy rewrite. It
/// carries the IR units it was created with and the iteration number.
809 class GreedyPatternRewriteIteration
810 : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
811 public:
812 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
/// Create the action for the given IR units and iteration number.
813 GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
814 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
815 iteration(iteration) {}
/// Tag of this action.
816 static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
/// Print the action name followed by the iteration number in parentheses.
817 void print(raw_ostream &os) const override {
818 os << "GreedyPatternRewriteIteration(" << iteration << ")";
821 private:
/// The iteration number this action was created for.
822 int64_t iteration = 0;
824 } // namespace
/// Repeatedly populate the worklist with the ops of `region` and process it
/// until an iteration makes no change or the iteration limit
/// (`config.maxIterations`, unless it is `kNoLimit`) is exceeded. If `changed`
/// is not null, it is set to whether a second iteration was started. Returns
/// success if the last executed iteration made no change.
826 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
827 bool continueRewrites = false;
828 int64_t iteration = 0;
829 MLIRContext *ctx = rewriter.getContext();
830 do {
831 // Check if the iteration limit was reached. `continueRewrites` keeps its
// value from the previous iteration in that case.
832 if (++iteration > config.maxIterations &&
833 config.maxIterations != GreedyRewriteConfig::kNoLimit)
834 break;
836 // New iteration: start with an empty worklist.
837 worklist.clear();
839 // `OperationFolder` CSE's constant ops (and may move them into parents
840 // regions to enable more aggressive CSE'ing).
841 OperationFolder folder(ctx, this);
// Returns true if `op` is a constant that the folder did not insert as a
// known constant; such ops are not added to the worklist below.
842 auto insertKnownConstant = [&](Operation *op) {
843 // Check for existing constants when populating the worklist. This avoids
844 // accidentally reversing the constant order during processing.
845 Attribute constValue;
846 if (matchPattern(op, m_Constant(&constValue)))
847 if (!folder.insertKnownConstant(op, constValue))
848 return true;
849 return false;
852 if (!config.useTopDownTraversal) {
853 // Add operations to the worklist in postorder.
854 region.walk([&](Operation *op) {
855 if (!config.cseConstants || !insertKnownConstant(op))
856 addToWorklist(op);
858 } else {
859 // Add all nested operations to the worklist in preorder.
860 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
861 if (!config.cseConstants || !insertKnownConstant(op)) {
862 addToWorklist(op);
863 return WalkResult::advance();
// The op was not enqueued; the walk skips it.
865 return WalkResult::skip();
868 // Reverse the list so our pop-back loop processes them in-order.
869 worklist.reverse();
// Run the iteration as a `GreedyPatternRewriteIteration` action on the region.
872 ctx->executeAction<GreedyPatternRewriteIteration>(
873 [&] {
874 continueRewrites = processWorklist();
876 // After applying patterns, make sure that the CFG of each of the
877 // regions is kept up to date.
878 if (config.enableRegionSimplification !=
879 GreedySimplifyRegionLevel::Disabled) {
880 continueRewrites |= succeeded(simplifyRegions(
881 rewriter, region,
882 /*mergeBlocks=*/config.enableRegionSimplification ==
883 GreedySimplifyRegionLevel::Aggressive));
886 {&region}, iteration);
887 } while (continueRewrites);
// The counter only exceeds 1 if the first iteration requested another one,
// i.e., if it changed the IR.
889 if (changed)
890 *changed = iteration > 1;
892 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
893 return success(!continueRewrites);
896 LogicalResult
897 mlir::applyPatternsGreedily(Region &region,
898 const FrozenRewritePatternSet &patterns,
899 GreedyRewriteConfig config, bool *changed) {
900 // The top-level operation must be known to be isolated from above to
901 // prevent performing canonicalizations on operations defined at or above
902 // the region containing 'op'.
903 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
904 "patterns can only be applied to operations IsolatedFromAbove");
906 // Set scope if not specified.
907 if (!config.scope)
908 config.scope = &region;
910 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
911 if (failed(verify(config.scope->getParentOp())))
912 llvm::report_fatal_error(
913 "greedy pattern rewriter input IR failed to verify");
914 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
916 // Start the pattern driver.
917 RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
918 region);
919 LogicalResult converged = std::move(driver).simplify(changed);
920 LLVM_DEBUG(if (failed(converged)) {
921 llvm::dbgs() << "The pattern rewrite did not converge after scanning "
922 << config.maxIterations << " times\n";
924 return converged;
927 //===----------------------------------------------------------------------===//
928 // MultiOpPatternRewriteDriver
929 //===----------------------------------------------------------------------===//
931 namespace {
932 /// This driver simplfies a list of ops.
933 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
934 public:
935 explicit MultiOpPatternRewriteDriver(
936 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
937 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
938 llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
940 /// Simplify `ops`. Return `success` if the transformation converged.
941 LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
943 private:
944 void notifyOperationErased(Operation *op) override {
945 GreedyPatternRewriteDriver::notifyOperationErased(op);
946 if (survivingOps)
947 survivingOps->erase(op);
950 /// An optional set of ops that survived the rewrite. This set is populated
951 /// at the beginning of `simplifyLocally` with the inititally provided list
952 /// of ops.
953 llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
955 } // namespace
957 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
958 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
959 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
960 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
961 : GreedyPatternRewriteDriver(ctx, patterns, config),
962 survivingOps(survivingOps) {
963 if (config.strictMode != GreedyRewriteStrictness::AnyOp)
964 strictModeFilteredOps.insert(ops.begin(), ops.end());
966 if (survivingOps) {
967 survivingOps->clear();
968 survivingOps->insert(ops.begin(), ops.end());
972 LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
973 bool *changed) && {
974 // Populate the initial worklist.
975 for (Operation *op : ops)
976 addSingleOpToWorklist(op);
978 // Process ops on the worklist.
979 bool result = processWorklist();
980 if (changed)
981 *changed = result;
983 return success(worklist.empty());
986 /// Find the region that is the closest common ancestor of all given ops.
988 /// Note: This function returns `nullptr` if there is a top-level op among the
989 /// given list of ops.
990 static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
991 assert(!ops.empty() && "expected at least one op");
992 // Fast path in case there is only one op.
993 if (ops.size() == 1)
994 return ops.front()->getParentRegion();
996 Region *region = ops.front()->getParentRegion();
997 ops = ops.drop_front();
998 int sz = ops.size();
999 llvm::BitVector remainingOps(sz, true);
1000 while (region) {
1001 int pos = -1;
1002 // Iterate over all remaining ops.
1003 while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
1004 // Is this op contained in `region`?
1005 if (region->findAncestorOpInRegion(*ops[pos]))
1006 remainingOps.reset(pos);
1008 if (remainingOps.none())
1009 break;
1010 region = region->getParentRegion();
1012 return region;
1015 LogicalResult mlir::applyOpPatternsGreedily(
1016 ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
1017 GreedyRewriteConfig config, bool *changed, bool *allErased) {
1018 if (ops.empty()) {
1019 if (changed)
1020 *changed = false;
1021 if (allErased)
1022 *allErased = true;
1023 return success();
1026 // Determine scope of rewrite.
1027 if (!config.scope) {
1028 // Compute scope if none was provided. The scope will remain `nullptr` if
1029 // there is a top-level op among `ops`.
1030 config.scope = findCommonAncestor(ops);
1031 } else {
1032 // If a scope was provided, make sure that all ops are in scope.
1033 #ifndef NDEBUG
1034 bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1035 return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
1037 assert(allOpsInScope && "ops must be within the specified scope");
1038 #endif // NDEBUG
1041 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1042 if (config.scope && failed(verify(config.scope->getParentOp())))
1043 llvm::report_fatal_error(
1044 "greedy pattern rewriter input IR failed to verify");
1045 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1047 // Start the pattern driver.
1048 llvm::SmallDenseSet<Operation *, 4> surviving;
1049 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1050 config, ops,
1051 allErased ? &surviving : nullptr);
1052 LogicalResult converged = std::move(driver).simplify(ops, changed);
1053 if (allErased)
1054 *allErased = surviving.empty();
1055 LLVM_DEBUG(if (failed(converged)) {
1056 llvm::dbgs() << "The pattern rewrite did not converge after "
1057 << config.maxNumRewrites << " rewrites";
1059 return converged;