[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / IR / Dialect.cpp
blob965386681f270942a865f3d7d05734c75e299dd5
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/ExtensibleDialect.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Operation.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/Regex.h"
23 #define DEBUG_TYPE "dialect"
25 using namespace mlir;
26 using namespace detail;
28 //===----------------------------------------------------------------------===//
29 // Dialect
30 //===----------------------------------------------------------------------===//
32 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
33 : name(name), dialectID(id), context(context) {
34 assert(isValidNamespace(name) && "invalid dialect namespace");
37 Dialect::~Dialect() = default;
39 /// Verify an attribute from this dialect on the argument at 'argIndex' for
40 /// the region at 'regionIndex' on the given operation. Returns failure if
41 /// the verification failed, success otherwise. This hook may optionally be
42 /// invoked from any operation containing a region.
43 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
44 NamedAttribute) {
45 return success();
48 /// Verify an attribute from this dialect on the result at 'resultIndex' for
49 /// the region at 'regionIndex' on the given operation. Returns failure if
50 /// the verification failed, success otherwise. This hook may optionally be
51 /// invoked from any operation containing a region.
52 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
53 unsigned, NamedAttribute) {
54 return success();
57 /// Parse an attribute registered to this dialect.
58 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
59 parser.emitError(parser.getNameLoc())
60 << "dialect '" << getNamespace()
61 << "' provides no attribute parsing hook";
62 return Attribute();
65 /// Parse a type registered to this dialect.
66 Type Dialect::parseType(DialectAsmParser &parser) const {
67 // If this dialect allows unknown types, then represent this with OpaqueType.
68 if (allowsUnknownTypes()) {
69 StringAttr ns = StringAttr::get(getContext(), getNamespace());
70 return OpaqueType::get(ns, parser.getFullSymbolSpec());
73 parser.emitError(parser.getNameLoc())
74 << "dialect '" << getNamespace() << "' provides no type parsing hook";
75 return Type();
78 std::optional<Dialect::ParseOpHook>
79 Dialect::getParseOperationHook(StringRef opName) const {
80 return std::nullopt;
83 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
84 Dialect::getOperationPrinter(Operation *op) const {
85 assert(op->getDialect() == this &&
86 "Dialect hook invoked on non-dialect owned operation");
87 return nullptr;
90 /// Utility function that returns if the given string is a valid dialect
91 /// namespace
92 bool Dialect::isValidNamespace(StringRef str) {
93 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
94 return dialectNameRegex.match(str);
97 /// Register a set of dialect interfaces with this dialect instance.
98 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
99 // Handle the case where the models resolve a promised interface.
100 handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID());
102 auto it = registeredInterfaces.try_emplace(interface->getID(),
103 std::move(interface));
104 (void)it;
105 LLVM_DEBUG({
106 if (!it.second) {
107 llvm::dbgs() << "[" DEBUG_TYPE
108 "] repeated interface registration for dialect "
109 << getNamespace();
114 //===----------------------------------------------------------------------===//
115 // Dialect Interface
116 //===----------------------------------------------------------------------===//
118 DialectInterface::~DialectInterface() = default;
120 MLIRContext *DialectInterface::getContext() const {
121 return dialect->getContext();
124 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
125 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
126 for (auto *dialect : ctx->getLoadedDialects()) {
127 #ifndef NDEBUG
128 dialect->handleUseOfUndefinedPromisedInterface(
129 dialect->getTypeID(), interfaceKind, interfaceName);
130 #endif
131 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
132 interfaces.insert(interface);
133 orderedInterfaces.push_back(interface);
138 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
140 /// Get the interface for the dialect of given operation, or null if one
141 /// is not registered.
142 const DialectInterface *
143 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
144 return getInterfaceFor(op->getDialect());
147 //===----------------------------------------------------------------------===//
148 // DialectExtension
149 //===----------------------------------------------------------------------===//
151 DialectExtensionBase::~DialectExtensionBase() = default;
153 void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
154 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
155 StringRef interfaceName) {
156 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
157 interfaceID, interfaceName);
160 void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
161 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
162 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
163 interfaceID);
166 bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
167 TypeID interfaceRequestorID,
168 TypeID interfaceID) {
169 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
172 //===----------------------------------------------------------------------===//
173 // DialectRegistry
174 //===----------------------------------------------------------------------===//
176 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
178 DialectAllocatorFunctionRef
179 DialectRegistry::getDialectAllocator(StringRef name) const {
180 auto it = registry.find(name.str());
181 if (it == registry.end())
182 return nullptr;
183 return it->second.second;
186 void DialectRegistry::insert(TypeID typeID, StringRef name,
187 const DialectAllocatorFunction &ctor) {
188 auto inserted = registry.insert(
189 std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
190 if (!inserted.second && inserted.first->second.first != typeID) {
191 llvm::report_fatal_error(
192 "Trying to register different dialects for the same namespace: " +
193 name);
197 void DialectRegistry::insertDynamic(
198 StringRef name, const DynamicDialectPopulationFunction &ctor) {
199 // This TypeID marks dynamic dialects. We cannot give a TypeID for the
200 // dialect yet, since the TypeID of a dynamic dialect is defined at its
201 // construction.
202 TypeID typeID = TypeID::get<void>();
204 // Create the dialect, and then call ctor, which allocates its components.
205 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
206 auto *dynDialect = ctx->getOrLoadDynamicDialect(
207 nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
208 assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
209 return dynDialect;
212 insert(typeID, name, constructor);
215 void DialectRegistry::applyExtensions(Dialect *dialect) const {
216 MLIRContext *ctx = dialect->getContext();
217 StringRef dialectName = dialect->getNamespace();
219 // Functor used to try to apply the given extension.
220 auto applyExtension = [&](const DialectExtensionBase &extension) {
221 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
222 // An empty set is equivalent to always invoke.
223 if (dialectNames.empty()) {
224 extension.apply(ctx, dialect);
225 return;
228 // Handle the simple case of a single dialect name. In this case, the
229 // required dialect should be the current dialect.
230 if (dialectNames.size() == 1) {
231 if (dialectNames.front() == dialectName)
232 extension.apply(ctx, dialect);
233 return;
236 // Otherwise, check to see if this extension requires this dialect.
237 const StringRef *nameIt = llvm::find(dialectNames, dialectName);
238 if (nameIt == dialectNames.end())
239 return;
241 // If it does, ensure that all of the other required dialects have been
242 // loaded.
243 SmallVector<Dialect *> requiredDialects;
244 requiredDialects.reserve(dialectNames.size());
245 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
246 ++it) {
247 // The current dialect is known to be loaded.
248 if (it == nameIt) {
249 requiredDialects.push_back(dialect);
250 continue;
252 // Otherwise, check if it is loaded.
253 Dialect *loadedDialect = ctx->getLoadedDialect(*it);
254 if (!loadedDialect)
255 return;
256 requiredDialects.push_back(loadedDialect);
258 extension.apply(ctx, requiredDialects);
261 // Note: Additional extensions may be added while applying an extension.
262 for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
263 applyExtension(*extensions[i]);
266 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
267 // Functor used to try to apply the given extension.
268 auto applyExtension = [&](const DialectExtensionBase &extension) {
269 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
270 if (dialectNames.empty()) {
271 auto loadedDialects = ctx->getLoadedDialects();
272 extension.apply(ctx, loadedDialects);
273 return;
276 // Check to see if all of the dialects for this extension are loaded.
277 SmallVector<Dialect *> requiredDialects;
278 requiredDialects.reserve(dialectNames.size());
279 for (StringRef dialectName : dialectNames) {
280 Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
281 if (!loadedDialect)
282 return;
283 requiredDialects.push_back(loadedDialect);
285 extension.apply(ctx, requiredDialects);
288 // Note: Additional extensions may be added while applying an extension.
289 for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
290 applyExtension(*extensions[i]);
293 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
294 // Treat any extensions conservatively.
295 if (!extensions.empty())
296 return false;
297 // Check that the current dialects fully overlap with the dialects in 'rhs'.
298 return llvm::all_of(
299 registry, [&](const auto &it) { return rhs.registry.count(it.first); });