[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Query / Matcher / RegistryManager.cpp
blob645db7109c2deb3a51628653bd5c8fb4546c5839
1 //===- RegistryManager.cpp - Matcher registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Registry map populated at static initialization time.
11 //===----------------------------------------------------------------------===//
#include "RegistryManager.h"
#include "mlir/Query/Matcher/Registry.h"

#include <algorithm>
#include <set>
#include <utility>
19 namespace mlir::query::matcher {
20 namespace {
22 // This is needed because these matchers are defined as overloaded functions.
23 using IsConstantOp = detail::constant_op_matcher();
24 using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
25 using HasOpName = detail::NameOpMatcher(llvm::StringRef);
27 // Enum to string for autocomplete.
28 static std::string asArgString(ArgKind kind) {
29 switch (kind) {
30 case ArgKind::Matcher:
31 return "Matcher";
32 case ArgKind::String:
33 return "String";
35 llvm_unreachable("Unhandled ArgKind");
38 } // namespace
40 void Registry::registerMatcherDescriptor(
41 llvm::StringRef matcherName,
42 std::unique_ptr<internal::MatcherDescriptor> callback) {
43 assert(!constructorMap.contains(matcherName));
44 constructorMap[matcherName] = std::move(callback);
47 std::optional<MatcherCtor>
48 RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
49 const Registry &matcherRegistry) {
50 auto it = matcherRegistry.constructors().find(matcherName);
51 return it == matcherRegistry.constructors().end()
52 ? std::optional<MatcherCtor>()
53 : it->second.get();
56 std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
57 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
58 // Starting with the above seed of acceptable top-level matcher types, compute
59 // the acceptable type set for the argument indicated by each context element.
60 std::set<ArgKind> typeSet;
61 typeSet.insert(ArgKind::Matcher);
63 for (const auto &ctxEntry : context) {
64 MatcherCtor ctor = ctxEntry.first;
65 unsigned argNumber = ctxEntry.second;
66 std::vector<ArgKind> nextTypeSet;
68 if (argNumber < ctor->getNumArgs())
69 ctor->getArgKinds(argNumber, nextTypeSet);
71 typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
74 return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
77 std::vector<MatcherCompletion>
78 RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
79 const Registry &matcherRegistry) {
80 std::vector<MatcherCompletion> completions;
82 // Search the registry for acceptable matchers.
83 for (const auto &m : matcherRegistry.constructors()) {
84 const internal::MatcherDescriptor &matcher = *m.getValue();
85 llvm::StringRef name = m.getKey();
87 unsigned numArgs = matcher.getNumArgs();
88 std::vector<std::vector<ArgKind>> argKinds(numArgs);
90 for (const ArgKind &kind : acceptedTypes) {
91 if (kind != ArgKind::Matcher)
92 continue;
94 for (unsigned arg = 0; arg != numArgs; ++arg)
95 matcher.getArgKinds(arg, argKinds[arg]);
98 std::string decl;
99 llvm::raw_string_ostream os(decl);
101 std::string typedText = std::string(name);
102 os << "Matcher: " << name << "(";
104 for (const std::vector<ArgKind> &arg : argKinds) {
105 if (&arg != &argKinds[0])
106 os << ", ";
108 bool firstArgKind = true;
109 // Two steps. First all non-matchers, then matchers only.
110 for (const ArgKind &argKind : arg) {
111 if (!firstArgKind)
112 os << "|";
114 firstArgKind = false;
115 os << asArgString(argKind);
119 os << ")";
120 typedText += "(";
122 if (argKinds.empty())
123 typedText += ")";
124 else if (argKinds[0][0] == ArgKind::String)
125 typedText += "\"";
127 completions.emplace_back(typedText, decl);
130 return completions;
133 VariantMatcher RegistryManager::constructMatcher(
134 MatcherCtor ctor, internal::SourceRange nameRange,
135 llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
136 internal::Diagnostics *error) {
137 VariantMatcher out = ctor->create(nameRange, args, error);
138 if (functionName.empty() || out.isNull())
139 return out;
141 if (std::optional<DynMatcher> result = out.getDynMatcher()) {
142 result->setFunctionName(functionName);
143 return VariantMatcher::SingleMatcher(*result);
146 error->addError(nameRange, internal::ErrorType::RegistryNotBindable);
147 return {};
150 } // namespace mlir::query::matcher