1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
11 #include "mlir/Interfaces/SideEffectInterfaces.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Pass/PassManager.h"
18 // Include the PDL rewrite support.
19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
21 #include "mlir/Dialect/PDL/IR/PDLOps.h"
24 convertPDLToPDLInterp(ModuleOp pdlModule
,
25 DenseMap
<Operation
*, PDLPatternConfigSet
*> &configMap
) {
26 // Skip the conversion if the module doesn't contain pdl.
27 if (pdlModule
.getOps
<pdl::PatternOp
>().empty())
30 // Simplify the provided PDL module. Note that we can't use the canonicalizer
31 // here because it would create a cyclic dependency.
32 auto simplifyFn
= [](Operation
*op
) {
33 // TODO: Add folding here if ever necessary.
34 if (isOpTriviallyDead(op
))
37 pdlModule
.getBody()->walk(simplifyFn
);
39 /// Lower the PDL pattern module to the interpreter dialect.
40 PassManager
pdlPipeline(pdlModule
->getName());
42 // We don't want to incur the hit of running the verifier when in release
44 pdlPipeline
.enableVerifier(false);
46 pdlPipeline
.addPass(createPDLToPDLInterpPass(configMap
));
47 if (failed(pdlPipeline
.run(pdlModule
)))
50 // Simplify again after running the lowering pipeline.
51 pdlModule
.getBody()->walk(simplifyFn
);
54 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
56 //===----------------------------------------------------------------------===//
57 // FrozenRewritePatternSet
58 //===----------------------------------------------------------------------===//
60 FrozenRewritePatternSet::FrozenRewritePatternSet()
61 : impl(std::make_shared
<Impl
>()) {}
63 FrozenRewritePatternSet::FrozenRewritePatternSet(
64 RewritePatternSet
&&patterns
, ArrayRef
<std::string
> disabledPatternLabels
,
65 ArrayRef
<std::string
> enabledPatternLabels
)
66 : impl(std::make_shared
<Impl
>()) {
67 DenseSet
<StringRef
> disabledPatterns
, enabledPatterns
;
68 disabledPatterns
.insert(disabledPatternLabels
.begin(),
69 disabledPatternLabels
.end());
70 enabledPatterns
.insert(enabledPatternLabels
.begin(),
71 enabledPatternLabels
.end());
73 // Functor used to walk all of the operations registered in the context. This
74 // is useful for patterns that get applied to multiple operations, such as
75 // interface and trait based patterns.
76 std::vector
<RegisteredOperationName
> opInfos
;
78 [&](std::unique_ptr
<RewritePattern
> &pattern
,
79 function_ref
<bool(RegisteredOperationName
)> callbackFn
) {
81 opInfos
= pattern
->getContext()->getRegisteredOperations();
82 for (RegisteredOperationName info
: opInfos
)
84 impl
->nativeOpSpecificPatternMap
[info
].push_back(pattern
.get());
85 impl
->nativeOpSpecificPatternList
.push_back(std::move(pattern
));
88 for (std::unique_ptr
<RewritePattern
> &pat
: patterns
.getNativePatterns()) {
89 // Don't add patterns that haven't been enabled by the user.
90 if (!enabledPatterns
.empty()) {
91 auto isEnabledFn
= [&](StringRef label
) {
92 return enabledPatterns
.count(label
);
94 if (!isEnabledFn(pat
->getDebugName()) &&
95 llvm::none_of(pat
->getDebugLabels(), isEnabledFn
))
98 // Don't add patterns that have been disabled by the user.
99 if (!disabledPatterns
.empty()) {
100 auto isDisabledFn
= [&](StringRef label
) {
101 return disabledPatterns
.count(label
);
103 if (isDisabledFn(pat
->getDebugName()) ||
104 llvm::any_of(pat
->getDebugLabels(), isDisabledFn
))
108 if (std::optional
<OperationName
> rootName
= pat
->getRootKind()) {
109 impl
->nativeOpSpecificPatternMap
[*rootName
].push_back(pat
.get());
110 impl
->nativeOpSpecificPatternList
.push_back(std::move(pat
));
113 if (std::optional
<TypeID
> interfaceID
= pat
->getRootInterfaceID()) {
114 addToOpsWhen(pat
, [&](RegisteredOperationName info
) {
115 return info
.hasInterface(*interfaceID
);
119 if (std::optional
<TypeID
> traitID
= pat
->getRootTraitID()) {
120 addToOpsWhen(pat
, [&](RegisteredOperationName info
) {
121 return info
.hasTrait(*traitID
);
125 impl
->nativeAnyOpPatterns
.push_back(std::move(pat
));
128 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
129 // Generate the bytecode for the PDL patterns if any were provided.
130 PDLPatternModule
&pdlPatterns
= patterns
.getPDLPatterns();
131 ModuleOp pdlModule
= pdlPatterns
.getModule();
134 DenseMap
<Operation
*, PDLPatternConfigSet
*> configMap
=
135 pdlPatterns
.takeConfigMap();
136 if (failed(convertPDLToPDLInterp(pdlModule
, configMap
)))
137 llvm::report_fatal_error(
138 "failed to lower PDL pattern module to the PDL Interpreter");
140 // Generate the pdl bytecode.
141 impl
->pdlByteCode
= std::make_unique
<detail::PDLByteCode
>(
142 pdlModule
, pdlPatterns
.takeConfigs(), configMap
,
143 pdlPatterns
.takeConstraintFunctions(),
144 pdlPatterns
.takeRewriteFunctions());
145 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
148 FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;