IR: de-duplicate two CmpInst routines (NFC) (#116866)
[llvm-project.git] / mlir / lib / Transforms / Canonicalizer.cpp
blobd50019bd6aee5558bd2009a7c7b350dfdeddcd5d
1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass converts operations into their canonical forms by
10 // folding constants, applying operation identity transformations etc.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Transforms/Passes.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 namespace mlir {
20 #define GEN_PASS_DEF_CANONICALIZER
21 #include "mlir/Transforms/Passes.h.inc"
22 } // namespace mlir
24 using namespace mlir;
26 namespace {
27 /// Canonicalize operations in nested regions.
28 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
29 Canonicalizer() = default;
30 Canonicalizer(const GreedyRewriteConfig &config,
31 ArrayRef<std::string> disabledPatterns,
32 ArrayRef<std::string> enabledPatterns)
33 : config(config) {
34 this->topDownProcessingEnabled = config.useTopDownTraversal;
35 this->enableRegionSimplification = config.enableRegionSimplification;
36 this->maxIterations = config.maxIterations;
37 this->maxNumRewrites = config.maxNumRewrites;
38 this->disabledPatterns = disabledPatterns;
39 this->enabledPatterns = enabledPatterns;
42 /// Initialize the canonicalizer by building the set of patterns used during
43 /// execution.
44 LogicalResult initialize(MLIRContext *context) override {
45 // Set the config from possible pass options set in the meantime.
46 config.useTopDownTraversal = topDownProcessingEnabled;
47 config.enableRegionSimplification = enableRegionSimplification;
48 config.maxIterations = maxIterations;
49 config.maxNumRewrites = maxNumRewrites;
51 RewritePatternSet owningPatterns(context);
52 for (auto *dialect : context->getLoadedDialects())
53 dialect->getCanonicalizationPatterns(owningPatterns);
54 for (RegisteredOperationName op : context->getRegisteredOperations())
55 op.getCanonicalizationPatterns(owningPatterns, context);
57 patterns = std::make_shared<FrozenRewritePatternSet>(
58 std::move(owningPatterns), disabledPatterns, enabledPatterns);
59 return success();
61 void runOnOperation() override {
62 LogicalResult converged =
63 applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
64 // Canonicalization is best-effort. Non-convergence is not a pass failure.
65 if (testConvergence && failed(converged))
66 signalPassFailure();
68 GreedyRewriteConfig config;
69 std::shared_ptr<const FrozenRewritePatternSet> patterns;
71 } // namespace
73 /// Create a Canonicalizer pass.
74 std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
75 return std::make_unique<Canonicalizer>();
78 /// Creates an instance of the Canonicalizer pass with the specified config.
79 std::unique_ptr<Pass>
80 mlir::createCanonicalizerPass(const GreedyRewriteConfig &config,
81 ArrayRef<std::string> disabledPatterns,
82 ArrayRef<std::string> enabledPatterns) {
83 return std::make_unique<Canonicalizer>(config, disabledPatterns,
84 enabledPatterns);