//===- Dialect.cpp - Dialect implementation -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"

#define DEBUG_TYPE "dialect"

using namespace mlir;
using namespace detail;
//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//
38 Dialect::Dialect(StringRef name
, MLIRContext
*context
, TypeID id
)
39 : name(name
), dialectID(id
), context(context
) {
40 assert(isValidNamespace(name
) && "invalid dialect namespace");
43 Dialect::~Dialect() = default;
45 /// Verify an attribute from this dialect on the argument at 'argIndex' for
46 /// the region at 'regionIndex' on the given operation. Returns failure if
47 /// the verification failed, success otherwise. This hook may optionally be
48 /// invoked from any operation containing a region.
49 LogicalResult
Dialect::verifyRegionArgAttribute(Operation
*, unsigned, unsigned,
54 /// Verify an attribute from this dialect on the result at 'resultIndex' for
55 /// the region at 'regionIndex' on the given operation. Returns failure if
56 /// the verification failed, success otherwise. This hook may optionally be
57 /// invoked from any operation containing a region.
58 LogicalResult
Dialect::verifyRegionResultAttribute(Operation
*, unsigned,
59 unsigned, NamedAttribute
) {
63 /// Parse an attribute registered to this dialect.
64 Attribute
Dialect::parseAttribute(DialectAsmParser
&parser
, Type type
) const {
65 parser
.emitError(parser
.getNameLoc())
66 << "dialect '" << getNamespace()
67 << "' provides no attribute parsing hook";
71 /// Parse a type registered to this dialect.
72 Type
Dialect::parseType(DialectAsmParser
&parser
) const {
73 // If this dialect allows unknown types, then represent this with OpaqueType.
74 if (allowsUnknownTypes()) {
75 StringAttr ns
= StringAttr::get(getContext(), getNamespace());
76 return OpaqueType::get(ns
, parser
.getFullSymbolSpec());
79 parser
.emitError(parser
.getNameLoc())
80 << "dialect '" << getNamespace() << "' provides no type parsing hook";
84 std::optional
<Dialect::ParseOpHook
>
85 Dialect::getParseOperationHook(StringRef opName
) const {
89 llvm::unique_function
<void(Operation
*, OpAsmPrinter
&printer
)>
90 Dialect::getOperationPrinter(Operation
*op
) const {
91 assert(op
->getDialect() == this &&
92 "Dialect hook invoked on non-dialect owned operation");
96 /// Utility function that returns if the given string is a valid dialect
98 bool Dialect::isValidNamespace(StringRef str
) {
99 llvm::Regex
dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
100 return dialectNameRegex
.match(str
);
103 /// Register a set of dialect interfaces with this dialect instance.
104 void Dialect::addInterface(std::unique_ptr
<DialectInterface
> interface
) {
105 // Handle the case where the models resolve a promised interface.
106 handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface
->getID());
108 auto it
= registeredInterfaces
.try_emplace(interface
->getID(),
109 std::move(interface
));
113 llvm::dbgs() << "[" DEBUG_TYPE
114 "] repeated interface registration for dialect "
//===----------------------------------------------------------------------===//
// Dialect Interface
//===----------------------------------------------------------------------===//
124 DialectInterface::~DialectInterface() = default;
126 MLIRContext
*DialectInterface::getContext() const {
127 return dialect
->getContext();
130 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
131 MLIRContext
*ctx
, TypeID interfaceKind
, StringRef interfaceName
) {
132 for (auto *dialect
: ctx
->getLoadedDialects()) {
134 dialect
->handleUseOfUndefinedPromisedInterface(
135 dialect
->getTypeID(), interfaceKind
, interfaceName
);
137 if (auto *interface
= dialect
->getRegisteredInterface(interfaceKind
)) {
138 interfaces
.insert(interface
);
139 orderedInterfaces
.push_back(interface
);
144 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
146 /// Get the interface for the dialect of given operation, or null if one
147 /// is not registered.
148 const DialectInterface
*
149 DialectInterfaceCollectionBase::getInterfaceFor(Operation
*op
) const {
150 return getInterfaceFor(op
->getDialect());
//===----------------------------------------------------------------------===//
// DialectExtension
//===----------------------------------------------------------------------===//
157 DialectExtensionBase::~DialectExtensionBase() = default;
159 void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
160 Dialect
&dialect
, TypeID interfaceRequestorID
, TypeID interfaceID
,
161 StringRef interfaceName
) {
162 dialect
.handleUseOfUndefinedPromisedInterface(interfaceRequestorID
,
163 interfaceID
, interfaceName
);
166 void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
167 Dialect
&dialect
, TypeID interfaceRequestorID
, TypeID interfaceID
) {
168 dialect
.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID
,
172 bool dialect_extension_detail::hasPromisedInterface(Dialect
&dialect
,
173 TypeID interfaceRequestorID
,
174 TypeID interfaceID
) {
175 return dialect
.hasPromisedInterface(interfaceRequestorID
, interfaceID
);
//===----------------------------------------------------------------------===//
// DialectRegistry
//===----------------------------------------------------------------------===//
183 template <typename Fn
>
184 void applyExtensionsFn(
186 const llvm::MapVector
<TypeID
, std::unique_ptr
<DialectExtensionBase
>>
188 // Note: Additional extensions may be added while applying an extension.
189 // The iterators will be invalidated if extensions are added so we'll keep
190 // a copy of the extensions for ourselves.
192 const auto extractExtension
=
193 [](const auto &entry
) -> DialectExtensionBase
* {
194 return entry
.second
.get();
197 auto startIt
= extensions
.begin(), endIt
= extensions
.end();
199 while (startIt
!= endIt
) {
200 count
+= endIt
- startIt
;
202 // Grab the subset of extensions we'll apply in this iteration.
204 llvm::map_to_vector(llvm::make_range(startIt
, endIt
), extractExtension
);
206 for (const auto *ext
: subset
)
207 applyExtension(*ext
);
209 // Book-keep for the next iteration.
210 startIt
= extensions
.begin() + count
;
211 endIt
= extensions
.end();
216 DialectRegistry::DialectRegistry() { insert
<BuiltinDialect
>(); }
218 DialectAllocatorFunctionRef
219 DialectRegistry::getDialectAllocator(StringRef name
) const {
220 auto it
= registry
.find(name
);
221 if (it
== registry
.end())
223 return it
->second
.second
;
226 void DialectRegistry::insert(TypeID typeID
, StringRef name
,
227 const DialectAllocatorFunction
&ctor
) {
228 auto inserted
= registry
.insert(
229 std::make_pair(std::string(name
), std::make_pair(typeID
, ctor
)));
230 if (!inserted
.second
&& inserted
.first
->second
.first
!= typeID
) {
231 llvm::report_fatal_error(
232 "Trying to register different dialects for the same namespace: " +
237 void DialectRegistry::insertDynamic(
238 StringRef name
, const DynamicDialectPopulationFunction
&ctor
) {
239 // This TypeID marks dynamic dialects. We cannot give a TypeID for the
240 // dialect yet, since the TypeID of a dynamic dialect is defined at its
242 TypeID typeID
= TypeID::get
<void>();
244 // Create the dialect, and then call ctor, which allocates its components.
245 auto constructor
= [nameStr
= name
.str(), ctor
](MLIRContext
*ctx
) {
246 auto *dynDialect
= ctx
->getOrLoadDynamicDialect(
247 nameStr
, [ctx
, ctor
](DynamicDialect
*dialect
) { ctor(ctx
, dialect
); });
248 assert(dynDialect
&& "Dynamic dialect creation unexpectedly failed");
252 insert(typeID
, name
, constructor
);
255 void DialectRegistry::applyExtensions(Dialect
*dialect
) const {
256 MLIRContext
*ctx
= dialect
->getContext();
257 StringRef dialectName
= dialect
->getNamespace();
259 // Functor used to try to apply the given extension.
260 auto applyExtension
= [&](const DialectExtensionBase
&extension
) {
261 ArrayRef
<StringRef
> dialectNames
= extension
.getRequiredDialects();
262 // An empty set is equivalent to always invoke.
263 if (dialectNames
.empty()) {
264 extension
.apply(ctx
, dialect
);
268 // Handle the simple case of a single dialect name. In this case, the
269 // required dialect should be the current dialect.
270 if (dialectNames
.size() == 1) {
271 if (dialectNames
.front() == dialectName
)
272 extension
.apply(ctx
, dialect
);
276 // Otherwise, check to see if this extension requires this dialect.
277 const StringRef
*nameIt
= llvm::find(dialectNames
, dialectName
);
278 if (nameIt
== dialectNames
.end())
281 // If it does, ensure that all of the other required dialects have been
283 SmallVector
<Dialect
*> requiredDialects
;
284 requiredDialects
.reserve(dialectNames
.size());
285 for (auto it
= dialectNames
.begin(), e
= dialectNames
.end(); it
!= e
;
287 // The current dialect is known to be loaded.
289 requiredDialects
.push_back(dialect
);
292 // Otherwise, check if it is loaded.
293 Dialect
*loadedDialect
= ctx
->getLoadedDialect(*it
);
296 requiredDialects
.push_back(loadedDialect
);
298 extension
.apply(ctx
, requiredDialects
);
301 applyExtensionsFn(applyExtension
, extensions
);
304 void DialectRegistry::applyExtensions(MLIRContext
*ctx
) const {
305 // Functor used to try to apply the given extension.
306 auto applyExtension
= [&](const DialectExtensionBase
&extension
) {
307 ArrayRef
<StringRef
> dialectNames
= extension
.getRequiredDialects();
308 if (dialectNames
.empty()) {
309 auto loadedDialects
= ctx
->getLoadedDialects();
310 extension
.apply(ctx
, loadedDialects
);
314 // Check to see if all of the dialects for this extension are loaded.
315 SmallVector
<Dialect
*> requiredDialects
;
316 requiredDialects
.reserve(dialectNames
.size());
317 for (StringRef dialectName
: dialectNames
) {
318 Dialect
*loadedDialect
= ctx
->getLoadedDialect(dialectName
);
321 requiredDialects
.push_back(loadedDialect
);
323 extension
.apply(ctx
, requiredDialects
);
326 applyExtensionsFn(applyExtension
, extensions
);
329 bool DialectRegistry::isSubsetOf(const DialectRegistry
&rhs
) const {
330 // Check that all extension keys are present in 'rhs'.
331 const auto hasExtension
= [&](const auto &key
) {
332 return rhs
.extensions
.contains(key
);
334 if (!llvm::all_of(make_first_range(extensions
), hasExtension
))
337 // Check that the current dialects fully overlap with the dialects in 'rhs'.
339 registry
, [&](const auto &it
) { return rhs
.registry
.count(it
.first
); });