[MLIR][NVVM] Add support for griddepcontrol Ops (#124603)
[llvm-project.git] / mlir / lib / Rewrite / FrozenRewritePatternSet.cpp
blob17fe02df9f66cda6f43db1f9f390cbd7263ed6ea
1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Interfaces/SideEffectInterfaces.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Pass/PassManager.h"
14 #include <optional>
16 using namespace mlir;
18 // Include the PDL rewrite support.
19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
21 #include "mlir/Dialect/PDL/IR/PDLOps.h"
23 static LogicalResult
24 convertPDLToPDLInterp(ModuleOp pdlModule,
25 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
26 // Skip the conversion if the module doesn't contain pdl.
27 if (pdlModule.getOps<pdl::PatternOp>().empty())
28 return success();
30 // Simplify the provided PDL module. Note that we can't use the canonicalizer
31 // here because it would create a cyclic dependency.
32 auto simplifyFn = [](Operation *op) {
33 // TODO: Add folding here if ever necessary.
34 if (isOpTriviallyDead(op))
35 op->erase();
37 pdlModule.getBody()->walk(simplifyFn);
39 /// Lower the PDL pattern module to the interpreter dialect.
40 PassManager pdlPipeline(pdlModule->getName());
41 #ifdef NDEBUG
42 // We don't want to incur the hit of running the verifier when in release
43 // mode.
44 pdlPipeline.enableVerifier(false);
45 #endif
46 pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
47 if (failed(pdlPipeline.run(pdlModule)))
48 return failure();
50 // Simplify again after running the lowering pipeline.
51 pdlModule.getBody()->walk(simplifyFn);
52 return success();
54 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
56 //===----------------------------------------------------------------------===//
57 // FrozenRewritePatternSet
58 //===----------------------------------------------------------------------===//
60 FrozenRewritePatternSet::FrozenRewritePatternSet()
61 : impl(std::make_shared<Impl>()) {}
63 FrozenRewritePatternSet::FrozenRewritePatternSet(
64 RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
65 ArrayRef<std::string> enabledPatternLabels)
66 : impl(std::make_shared<Impl>()) {
67 DenseSet<StringRef> disabledPatterns, enabledPatterns;
68 disabledPatterns.insert(disabledPatternLabels.begin(),
69 disabledPatternLabels.end());
70 enabledPatterns.insert(enabledPatternLabels.begin(),
71 enabledPatternLabels.end());
73 // Functor used to walk all of the operations registered in the context. This
74 // is useful for patterns that get applied to multiple operations, such as
75 // interface and trait based patterns.
76 std::vector<RegisteredOperationName> opInfos;
77 auto addToOpsWhen =
78 [&](std::unique_ptr<RewritePattern> &pattern,
79 function_ref<bool(RegisteredOperationName)> callbackFn) {
80 if (opInfos.empty())
81 opInfos = pattern->getContext()->getRegisteredOperations();
82 for (RegisteredOperationName info : opInfos)
83 if (callbackFn(info))
84 impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
85 impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
88 for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
89 // Don't add patterns that haven't been enabled by the user.
90 if (!enabledPatterns.empty()) {
91 auto isEnabledFn = [&](StringRef label) {
92 return enabledPatterns.count(label);
94 if (!isEnabledFn(pat->getDebugName()) &&
95 llvm::none_of(pat->getDebugLabels(), isEnabledFn))
96 continue;
98 // Don't add patterns that have been disabled by the user.
99 if (!disabledPatterns.empty()) {
100 auto isDisabledFn = [&](StringRef label) {
101 return disabledPatterns.count(label);
103 if (isDisabledFn(pat->getDebugName()) ||
104 llvm::any_of(pat->getDebugLabels(), isDisabledFn))
105 continue;
108 if (std::optional<OperationName> rootName = pat->getRootKind()) {
109 impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
110 impl->nativeOpSpecificPatternList.push_back(std::move(pat));
111 continue;
113 if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
114 addToOpsWhen(pat, [&](RegisteredOperationName info) {
115 return info.hasInterface(*interfaceID);
117 continue;
119 if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
120 addToOpsWhen(pat, [&](RegisteredOperationName info) {
121 return info.hasTrait(*traitID);
123 continue;
125 impl->nativeAnyOpPatterns.push_back(std::move(pat));
128 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
129 // Generate the bytecode for the PDL patterns if any were provided.
130 PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
131 ModuleOp pdlModule = pdlPatterns.getModule();
132 if (!pdlModule)
133 return;
134 DenseMap<Operation *, PDLPatternConfigSet *> configMap =
135 pdlPatterns.takeConfigMap();
136 if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
137 llvm::report_fatal_error(
138 "failed to lower PDL pattern module to the PDL Interpreter");
140 // Generate the pdl bytecode.
141 impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
142 pdlModule, pdlPatterns.takeConfigs(), configMap,
143 pdlPatterns.takeConstraintFunctions(),
144 pdlPatterns.takeRewriteFunctions());
145 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
148 FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;