//===- RegistryManager.cpp - Matcher registry -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Registry map populated at static initialization time. This file implements
// registration and lookup of matcher descriptors, autocompletion of matcher
// names and argument kinds, and construction of matchers from parsed values.
//
//===----------------------------------------------------------------------===//
#include "RegistryManager.h"
#include "mlir/Query/Matcher/Registry.h"

#include <algorithm>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>
19 namespace mlir::query::matcher
{
22 // This is needed because these matchers are defined as overloaded functions.
23 using IsConstantOp
= detail::constant_op_matcher();
24 using HasOpAttrName
= detail::AttrOpMatcher(llvm::StringRef
);
25 using HasOpName
= detail::NameOpMatcher(llvm::StringRef
);
27 // Enum to string for autocomplete.
28 static std::string
asArgString(ArgKind kind
) {
30 case ArgKind::Matcher
:
35 llvm_unreachable("Unhandled ArgKind");
40 void Registry::registerMatcherDescriptor(
41 llvm::StringRef matcherName
,
42 std::unique_ptr
<internal::MatcherDescriptor
> callback
) {
43 assert(!constructorMap
.contains(matcherName
));
44 constructorMap
[matcherName
] = std::move(callback
);
47 std::optional
<MatcherCtor
>
48 RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName
,
49 const Registry
&matcherRegistry
) {
50 auto it
= matcherRegistry
.constructors().find(matcherName
);
51 return it
== matcherRegistry
.constructors().end()
52 ? std::optional
<MatcherCtor
>()
56 std::vector
<ArgKind
> RegistryManager::getAcceptedCompletionTypes(
57 llvm::ArrayRef
<std::pair
<MatcherCtor
, unsigned>> context
) {
58 // Starting with the above seed of acceptable top-level matcher types, compute
59 // the acceptable type set for the argument indicated by each context element.
60 std::set
<ArgKind
> typeSet
;
61 typeSet
.insert(ArgKind::Matcher
);
63 for (const auto &ctxEntry
: context
) {
64 MatcherCtor ctor
= ctxEntry
.first
;
65 unsigned argNumber
= ctxEntry
.second
;
66 std::vector
<ArgKind
> nextTypeSet
;
68 if (argNumber
< ctor
->getNumArgs())
69 ctor
->getArgKinds(argNumber
, nextTypeSet
);
71 typeSet
.insert(nextTypeSet
.begin(), nextTypeSet
.end());
74 return std::vector
<ArgKind
>(typeSet
.begin(), typeSet
.end());
77 std::vector
<MatcherCompletion
>
78 RegistryManager::getMatcherCompletions(llvm::ArrayRef
<ArgKind
> acceptedTypes
,
79 const Registry
&matcherRegistry
) {
80 std::vector
<MatcherCompletion
> completions
;
82 // Search the registry for acceptable matchers.
83 for (const auto &m
: matcherRegistry
.constructors()) {
84 const internal::MatcherDescriptor
&matcher
= *m
.getValue();
85 llvm::StringRef name
= m
.getKey();
87 unsigned numArgs
= matcher
.getNumArgs();
88 std::vector
<std::vector
<ArgKind
>> argKinds(numArgs
);
90 for (const ArgKind
&kind
: acceptedTypes
) {
91 if (kind
!= ArgKind::Matcher
)
94 for (unsigned arg
= 0; arg
!= numArgs
; ++arg
)
95 matcher
.getArgKinds(arg
, argKinds
[arg
]);
99 llvm::raw_string_ostream
os(decl
);
101 std::string typedText
= std::string(name
);
102 os
<< "Matcher: " << name
<< "(";
104 for (const std::vector
<ArgKind
> &arg
: argKinds
) {
105 if (&arg
!= &argKinds
[0])
108 bool firstArgKind
= true;
109 // Two steps. First all non-matchers, then matchers only.
110 for (const ArgKind
&argKind
: arg
) {
114 firstArgKind
= false;
115 os
<< asArgString(argKind
);
122 if (argKinds
.empty())
124 else if (argKinds
[0][0] == ArgKind::String
)
127 completions
.emplace_back(typedText
, decl
);
133 VariantMatcher
RegistryManager::constructMatcher(
134 MatcherCtor ctor
, internal::SourceRange nameRange
,
135 llvm::StringRef functionName
, llvm::ArrayRef
<ParserValue
> args
,
136 internal::Diagnostics
*error
) {
137 VariantMatcher out
= ctor
->create(nameRange
, args
, error
);
138 if (functionName
.empty() || out
.isNull())
141 if (std::optional
<DynMatcher
> result
= out
.getDynMatcher()) {
142 result
->setFunctionName(functionName
);
143 return VariantMatcher::SingleMatcher(*result
);
146 error
->addError(nameRange
, internal::ErrorType::RegistryNotBindable
);
150 } // namespace mlir::query::matcher