1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This transformation pass converts operations into their canonical forms by
10 // folding constants, applying operation identity transformations etc.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Transforms/Passes.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #define GEN_PASS_DEF_CANONICALIZER
21 #include "mlir/Transforms/Passes.h.inc"
27 /// Canonicalize operations in nested regions.
28 struct Canonicalizer
: public impl::CanonicalizerBase
<Canonicalizer
> {
29 Canonicalizer() = default;
30 Canonicalizer(const GreedyRewriteConfig
&config
,
31 ArrayRef
<std::string
> disabledPatterns
,
32 ArrayRef
<std::string
> enabledPatterns
)
34 this->topDownProcessingEnabled
= config
.useTopDownTraversal
;
35 this->enableRegionSimplification
= config
.enableRegionSimplification
;
36 this->maxIterations
= config
.maxIterations
;
37 this->maxNumRewrites
= config
.maxNumRewrites
;
38 this->disabledPatterns
= disabledPatterns
;
39 this->enabledPatterns
= enabledPatterns
;
42 /// Initialize the canonicalizer by building the set of patterns used during
44 LogicalResult
initialize(MLIRContext
*context
) override
{
45 // Set the config from possible pass options set in the meantime.
46 config
.useTopDownTraversal
= topDownProcessingEnabled
;
47 config
.enableRegionSimplification
= enableRegionSimplification
;
48 config
.maxIterations
= maxIterations
;
49 config
.maxNumRewrites
= maxNumRewrites
;
51 RewritePatternSet
owningPatterns(context
);
52 for (auto *dialect
: context
->getLoadedDialects())
53 dialect
->getCanonicalizationPatterns(owningPatterns
);
54 for (RegisteredOperationName op
: context
->getRegisteredOperations())
55 op
.getCanonicalizationPatterns(owningPatterns
, context
);
57 patterns
= std::make_shared
<FrozenRewritePatternSet
>(
58 std::move(owningPatterns
), disabledPatterns
, enabledPatterns
);
61 void runOnOperation() override
{
62 LogicalResult converged
=
63 applyPatternsAndFoldGreedily(getOperation(), *patterns
, config
);
64 // Canonicalization is best-effort. Non-convergence is not a pass failure.
65 if (testConvergence
&& failed(converged
))
68 GreedyRewriteConfig config
;
69 std::shared_ptr
<const FrozenRewritePatternSet
> patterns
;
73 /// Create a Canonicalizer pass.
74 std::unique_ptr
<Pass
> mlir::createCanonicalizerPass() {
75 return std::make_unique
<Canonicalizer
>();
78 /// Creates an instance of the Canonicalizer pass with the specified config.
80 mlir::createCanonicalizerPass(const GreedyRewriteConfig
&config
,
81 ArrayRef
<std::string
> disabledPatterns
,
82 ArrayRef
<std::string
> enabledPatterns
) {
83 return std::make_unique
<Canonicalizer
>(config
, disabledPatterns
,