[OptTable] Fix typo VALUE => VALUES (NFCI) (#121523)
[llvm-project.git] / mlir / lib / IR / MLIRContext.cpp
blobb9e745fdf4a13ee6071c6964a395d4a84a39af1b
1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/Action.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/ExtensibleDialect.h"
24 #include "mlir/IR/IntegerSet.h"
25 #include "mlir/IR/Location.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/OperationSupport.h"
28 #include "mlir/IR/Types.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/DenseSet.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/StringSet.h"
33 #include "llvm/ADT/Twine.h"
34 #include "llvm/Support/Allocator.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Compiler.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/Mutex.h"
39 #include "llvm/Support/RWMutex.h"
40 #include "llvm/Support/ThreadPool.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include <memory>
43 #include <optional>
45 #define DEBUG_TYPE "mlircontext"
47 using namespace mlir;
48 using namespace mlir::detail;
50 //===----------------------------------------------------------------------===//
51 // MLIRContext CommandLine Options
52 //===----------------------------------------------------------------------===//
54 namespace {
55 /// This struct contains command line options that can be used to initialize
56 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
57 /// for global command line options.
58 struct MLIRContextOptions {
59 llvm::cl::opt<bool> disableThreading{
60 "mlir-disable-threading",
61 llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
62 "further call to MLIRContext::enableMultiThreading()")};
64 llvm::cl::opt<bool> printOpOnDiagnostic{
65 "mlir-print-op-on-diagnostic",
66 llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
67 "the operation as an attached note"),
68 llvm::cl::init(true)};
70 llvm::cl::opt<bool> printStackTraceOnDiagnostic{
71 "mlir-print-stacktrace-on-diagnostic",
72 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
73 "as an attached note")};
75 } // namespace
77 static llvm::ManagedStatic<MLIRContextOptions> clOptions;
/// Return true if multi-threading must stay off for every context, either
/// because LLVM was built without thread support or because the
/// `--mlir-disable-threading` option was registered and set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  // No thread support compiled in: threading can never be enabled.
  return true;
#else
  // The flag can only be set once the options struct has been constructed.
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
87 /// Register a set of useful command-line options that can be used to configure
88 /// various flags within the MLIRContext. These flags are used when constructing
89 /// an MLIR context for initialization.
90 void mlir::registerMLIRContextCLOptions() {
91 // Make sure that the options struct has been initialized.
92 *clOptions;
95 //===----------------------------------------------------------------------===//
96 // Locking Utilities
97 //===----------------------------------------------------------------------===//
99 namespace {
100 /// Utility writer lock that takes a runtime flag that specifies if we really
101 /// need to lock.
102 struct ScopedWriterLock {
103 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
104 : mutex(shouldLock ? &mutexParam : nullptr) {
105 if (mutex)
106 mutex->lock();
108 ~ScopedWriterLock() {
109 if (mutex)
110 mutex->unlock();
112 llvm::sys::SmartRWMutex<true> *mutex;
114 } // namespace
116 //===----------------------------------------------------------------------===//
117 // MLIRContextImpl
118 //===----------------------------------------------------------------------===//
120 namespace mlir {
121 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
122 /// This class is completely private to this file, so everything is public.
123 class MLIRContextImpl {
124 public:
125 //===--------------------------------------------------------------------===//
126 // Debugging
127 //===--------------------------------------------------------------------===//
129 /// An action handler for handling actions that are dispatched through this
130 /// context.
131 std::function<void(function_ref<void()>, const tracing::Action &)>
132 actionHandler;
134 //===--------------------------------------------------------------------===//
135 // Diagnostics
136 //===--------------------------------------------------------------------===//
137 DiagnosticEngine diagEngine;
139 //===--------------------------------------------------------------------===//
140 // Options
141 //===--------------------------------------------------------------------===//
143 /// In most cases, creating operation in unregistered dialect is not desired
144 /// and indicate a misconfiguration of the compiler. This option enables to
145 /// detect such use cases
146 bool allowUnregisteredDialects = false;
148 /// Enable support for multi-threading within MLIR.
149 bool threadingIsEnabled = true;
151 /// Track if we are currently executing in a threaded execution environment
152 /// (like the pass-manager): this is only a debugging feature to help reducing
153 /// the chances of data races one some context APIs.
154 #ifndef NDEBUG
155 std::atomic<int> multiThreadedExecutionContext{0};
156 #endif
158 /// If the operation should be attached to diagnostics printed via the
159 /// Operation::emit methods.
160 bool printOpOnDiagnostic = true;
162 /// If the current stack trace should be attached when emitting diagnostics.
163 bool printStackTraceOnDiagnostic = false;
165 //===--------------------------------------------------------------------===//
166 // Other
167 //===--------------------------------------------------------------------===//
169 /// This points to the ThreadPool used when processing MLIR tasks in parallel.
170 /// It can't be nullptr when multi-threading is enabled. Otherwise if
171 /// multi-threading is disabled, and the threadpool wasn't externally provided
172 /// using `setThreadPool`, this will be nullptr.
173 llvm::ThreadPoolInterface *threadPool = nullptr;
175 /// In case where the thread pool is owned by the context, this ensures
176 /// destruction with the context.
177 std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
179 /// An allocator used for AbstractAttribute and AbstractType objects.
180 llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
182 /// This is a mapping from operation name to the operation info describing it.
183 llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
185 /// A vector of operation info specifically for registered operations.
186 llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
187 llvm::StringMap<RegisteredOperationName> registeredOperationsByName;
189 /// This is a sorted container of registered operations for a deterministic
190 /// and efficient `getRegisteredOperations` implementation.
191 SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
193 /// This is a list of dialects that are created referring to this context.
194 /// The MLIRContext owns the objects. These need to be declared after the
195 /// registered operations to ensure correct destruction order.
196 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
197 DialectRegistry dialectsRegistry;
199 /// A mutex used when accessing operation information.
200 llvm::sys::SmartRWMutex<true> operationInfoMutex;
202 //===--------------------------------------------------------------------===//
203 // Affine uniquing
204 //===--------------------------------------------------------------------===//
206 // Affine expression, map and integer set uniquing.
207 StorageUniquer affineUniquer;
209 //===--------------------------------------------------------------------===//
210 // Type uniquing
211 //===--------------------------------------------------------------------===//
213 DenseMap<TypeID, AbstractType *> registeredTypes;
214 StorageUniquer typeUniquer;
216 /// This is a mapping from type name to the abstract type describing it.
217 /// It is used by `AbstractType::lookup` to get an `AbstractType` from a name.
218 /// As this map needs to be populated before `StringAttr` is loaded, we
219 /// cannot use `StringAttr` as the key. The context does not take ownership
220 /// of the key, so the `StringRef` must outlive the context.
221 llvm::DenseMap<StringRef, AbstractType *> nameToType;
223 /// Cached Type Instances.
224 Float4E2M1FNType f4E2M1FNTy;
225 Float6E2M3FNType f6E2M3FNTy;
226 Float6E3M2FNType f6E3M2FNTy;
227 Float8E5M2Type f8E5M2Ty;
228 Float8E4M3Type f8E4M3Ty;
229 Float8E4M3FNType f8E4M3FNTy;
230 Float8E5M2FNUZType f8E5M2FNUZTy;
231 Float8E4M3FNUZType f8E4M3FNUZTy;
232 Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
233 Float8E3M4Type f8E3M4Ty;
234 Float8E8M0FNUType f8E8M0FNUTy;
235 BFloat16Type bf16Ty;
236 Float16Type f16Ty;
237 FloatTF32Type tf32Ty;
238 Float32Type f32Ty;
239 Float64Type f64Ty;
240 Float80Type f80Ty;
241 Float128Type f128Ty;
242 IndexType indexTy;
243 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
244 NoneType noneType;
246 //===--------------------------------------------------------------------===//
247 // Attribute uniquing
248 //===--------------------------------------------------------------------===//
250 DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
251 StorageUniquer attributeUniquer;
253 /// This is a mapping from attribute name to the abstract attribute describing
254 /// it. It is used by `AbstractType::lookup` to get an `AbstractType` from a
255 /// name.
256 /// As this map needs to be populated before `StringAttr` is loaded, we
257 /// cannot use `StringAttr` as the key. The context does not take ownership
258 /// of the key, so the `StringRef` must outlive the context.
259 llvm::DenseMap<StringRef, AbstractAttribute *> nameToAttribute;
261 /// Cached Attribute Instances.
262 BoolAttr falseAttr, trueAttr;
263 UnitAttr unitAttr;
264 UnknownLoc unknownLocAttr;
265 DictionaryAttr emptyDictionaryAttr;
266 StringAttr emptyStringAttr;
268 /// Map of string attributes that may reference a dialect, that are awaiting
269 /// that dialect to be loaded.
270 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
271 DenseMap<StringRef, SmallVector<StringAttrStorage *>>
272 dialectReferencingStrAttrs;
274 /// A distinct attribute allocator that allocates every time since the
275 /// address of the distinct attribute storage serves as unique identifier. The
276 /// allocator is thread safe and frees the allocated storage after its
277 /// destruction.
278 DistinctAttributeAllocator distinctAttributeAllocator;
280 public:
281 MLIRContextImpl(bool threadingIsEnabled)
282 : threadingIsEnabled(threadingIsEnabled) {
283 if (threadingIsEnabled) {
284 ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
285 threadPool = ownedThreadPool.get();
288 ~MLIRContextImpl() {
289 for (auto typeMapping : registeredTypes)
290 typeMapping.second->~AbstractType();
291 for (auto attrMapping : registeredAttributes)
292 attrMapping.second->~AbstractAttribute();
295 } // namespace mlir
297 MLIRContext::MLIRContext(Threading setting)
298 : MLIRContext(DialectRegistry(), setting) {}
300 MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
301 : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
302 !isThreadingGloballyDisabled())) {
303 // Initialize values based on the command line flags if they were provided.
304 if (clOptions.isConstructed()) {
305 printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
306 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
309 // Pre-populate the registry.
310 registry.appendTo(impl->dialectsRegistry);
312 // Ensure the builtin dialect is always pre-loaded.
313 getOrLoadDialect<BuiltinDialect>();
315 // Initialize several common attributes and types to avoid the need to lock
316 // the context when accessing them.
318 //// Types.
319 /// Floating-point Types.
320 impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
321 impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
322 impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
323 impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
324 impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
325 impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
326 impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
327 impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
328 impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
329 impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
330 impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
331 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
332 impl->f16Ty = TypeUniquer::get<Float16Type>(this);
333 impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
334 impl->f32Ty = TypeUniquer::get<Float32Type>(this);
335 impl->f64Ty = TypeUniquer::get<Float64Type>(this);
336 impl->f80Ty = TypeUniquer::get<Float80Type>(this);
337 impl->f128Ty = TypeUniquer::get<Float128Type>(this);
338 /// Index Type.
339 impl->indexTy = TypeUniquer::get<IndexType>(this);
340 /// Integer Types.
341 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
342 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
343 impl->int16Ty =
344 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
345 impl->int32Ty =
346 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
347 impl->int64Ty =
348 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
349 impl->int128Ty =
350 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
351 /// None Type.
352 impl->noneType = TypeUniquer::get<NoneType>(this);
354 //// Attributes.
355 //// Note: These must be registered after the types as they may generate one
356 //// of the above types internally.
357 /// Unknown Location Attribute.
358 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
359 /// Bool Attributes.
360 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
361 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
362 /// Unit Attribute.
363 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
364 /// The empty dictionary attribute.
365 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
366 /// The empty string attribute.
367 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);
369 // Register the affine storage objects with the uniquer.
370 impl->affineUniquer
371 .registerParametricStorageType<AffineBinaryOpExprStorage>();
372 impl->affineUniquer
373 .registerParametricStorageType<AffineConstantExprStorage>();
374 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
375 impl->affineUniquer.registerParametricStorageType<AffineMapStorage>();
376 impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
379 MLIRContext::~MLIRContext() = default;
381 /// Copy the specified array of elements into memory managed by the provided
382 /// bump pointer allocator. This assumes the elements are all PODs.
383 template <typename T>
384 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
385 ArrayRef<T> elements) {
386 auto result = allocator.Allocate<T>(elements.size());
387 std::uninitialized_copy(elements.begin(), elements.end(), result);
388 return ArrayRef<T>(result, elements.size());
391 //===----------------------------------------------------------------------===//
392 // Action Handling
393 //===----------------------------------------------------------------------===//
395 void MLIRContext::registerActionHandler(HandlerTy handler) {
396 getImpl().actionHandler = std::move(handler);
399 /// Dispatch the provided action to the handler if any, or just execute it.
400 void MLIRContext::executeActionInternal(function_ref<void()> actionFn,
401 const tracing::Action &action) {
402 assert(getImpl().actionHandler);
403 getImpl().actionHandler(actionFn, action);
406 bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
408 //===----------------------------------------------------------------------===//
409 // Diagnostic Handlers
410 //===----------------------------------------------------------------------===//
412 /// Returns the diagnostic engine for this context.
413 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
415 //===----------------------------------------------------------------------===//
416 // Dialect and Operation Registration
417 //===----------------------------------------------------------------------===//
419 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
420 if (registry.isSubsetOf(impl->dialectsRegistry))
421 return;
423 assert(impl->multiThreadedExecutionContext == 0 &&
424 "appending to the MLIRContext dialect registry while in a "
425 "multi-threaded execution context");
426 registry.appendTo(impl->dialectsRegistry);
428 // For the already loaded dialects, apply any possible extensions immediately.
429 registry.applyExtensions(this);
432 const DialectRegistry &MLIRContext::getDialectRegistry() {
433 return impl->dialectsRegistry;
436 /// Return information about all registered IR dialects.
437 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
438 std::vector<Dialect *> result;
439 result.reserve(impl->loadedDialects.size());
440 for (auto &dialect : impl->loadedDialects)
441 result.push_back(dialect.second.get());
442 llvm::array_pod_sort(result.begin(), result.end(),
443 [](Dialect *const *lhs, Dialect *const *rhs) -> int {
444 return (*lhs)->getNamespace() < (*rhs)->getNamespace();
446 return result;
448 std::vector<StringRef> MLIRContext::getAvailableDialects() {
449 std::vector<StringRef> result;
450 for (auto dialect : impl->dialectsRegistry.getDialectNames())
451 result.push_back(dialect);
452 return result;
455 /// Get a registered IR dialect with the given namespace. If none is found,
456 /// then return nullptr.
457 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
458 // Dialects are sorted by name, so we can use binary search for lookup.
459 auto it = impl->loadedDialects.find(name);
460 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
463 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
464 Dialect *dialect = getLoadedDialect(name);
465 if (dialect)
466 return dialect;
467 DialectAllocatorFunctionRef allocator =
468 impl->dialectsRegistry.getDialectAllocator(name);
469 return allocator ? allocator(this) : nullptr;
472 /// Get a dialect for the provided namespace and TypeID: abort the program if a
473 /// dialect exist for this namespace with different TypeID. Returns a pointer to
474 /// the dialect owned by the context.
475 Dialect *
476 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
477 function_ref<std::unique_ptr<Dialect>()> ctor) {
478 auto &impl = getImpl();
479 // Get the correct insertion position sorted by namespace.
480 auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr);
482 if (dialectIt.second) {
483 LLVM_DEBUG(llvm::dbgs()
484 << "Load new dialect in Context " << dialectNamespace << "\n");
485 #ifndef NDEBUG
486 if (impl.multiThreadedExecutionContext != 0)
487 llvm::report_fatal_error(
488 "Loading a dialect (" + dialectNamespace +
489 ") while in a multi-threaded execution context (maybe "
490 "the PassManager): this can indicate a "
491 "missing `dependentDialects` in a pass for example.");
492 #endif // NDEBUG
493 // loadedDialects entry is initialized to nullptr, indicating that the
494 // dialect is currently being loaded. Re-lookup the address in
495 // loadedDialects because the table might have been rehashed by recursive
496 // dialect loading in ctor().
497 std::unique_ptr<Dialect> &dialectOwned =
498 impl.loadedDialects[dialectNamespace] = ctor();
499 Dialect *dialect = dialectOwned.get();
500 assert(dialect && "dialect ctor failed");
502 // Refresh all the identifiers dialect field, this catches cases where a
503 // dialect may be loaded after identifier prefixed with this dialect name
504 // were already created.
505 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
506 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
507 for (StringAttrStorage *storage : stringAttrsIt->second)
508 storage->referencedDialect = dialect;
509 impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
512 // Apply any extensions to this newly loaded dialect.
513 impl.dialectsRegistry.applyExtensions(dialect);
514 return dialect;
517 #ifndef NDEBUG
518 if (dialectIt.first->second == nullptr)
519 llvm::report_fatal_error(
520 "Loading (and getting) a dialect (" + dialectNamespace +
521 ") while the same dialect is still loading: use loadDialect instead "
522 "of getOrLoadDialect.");
523 #endif // NDEBUG
525 // Abort if dialect with namespace has already been registered.
526 std::unique_ptr<Dialect> &dialect = dialectIt.first->second;
527 if (dialect->getTypeID() != dialectID)
528 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
529 "' has already been registered");
531 return dialect.get();
534 bool MLIRContext::isDialectLoading(StringRef dialectNamespace) {
535 auto it = getImpl().loadedDialects.find(dialectNamespace);
536 // nullptr indicates that the dialect is currently being loaded.
537 return it != getImpl().loadedDialects.end() && it->second == nullptr;
540 DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
541 StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) {
542 auto &impl = getImpl();
543 // Get the correct insertion position sorted by namespace.
544 auto dialectIt = impl.loadedDialects.find(dialectNamespace);
546 if (dialectIt != impl.loadedDialects.end()) {
547 if (auto *dynDialect = dyn_cast<DynamicDialect>(dialectIt->second.get()))
548 return dynDialect;
549 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
550 "' has already been registered");
553 LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
554 << dialectNamespace << "\n");
555 #ifndef NDEBUG
556 if (impl.multiThreadedExecutionContext != 0)
557 llvm::report_fatal_error(
558 "Loading a dynamic dialect (" + dialectNamespace +
559 ") while in a multi-threaded execution context (maybe "
560 "the PassManager): this can indicate a "
561 "missing `dependentDialects` in a pass for example.");
562 #endif
564 auto name = StringAttr::get(this, dialectNamespace);
565 auto *dialect = new DynamicDialect(name, this);
566 (void)getOrLoadDialect(name, dialect->getTypeID(), [dialect, ctor]() {
567 ctor(dialect);
568 return std::unique_ptr<DynamicDialect>(dialect);
570 // This is the same result as `getOrLoadDialect` (if it didn't failed),
571 // since it has the same TypeID, and TypeIDs are unique.
572 return dialect;
575 void MLIRContext::loadAllAvailableDialects() {
576 for (StringRef name : getAvailableDialects())
577 getOrLoadDialect(name);
580 llvm::hash_code MLIRContext::getRegistryHash() {
581 llvm::hash_code hash(0);
582 // Factor in number of loaded dialects, attributes, operations, types.
583 hash = llvm::hash_combine(hash, impl->loadedDialects.size());
584 hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
585 hash = llvm::hash_combine(hash, impl->registeredOperations.size());
586 hash = llvm::hash_combine(hash, impl->registeredTypes.size());
587 return hash;
590 bool MLIRContext::allowsUnregisteredDialects() {
591 return impl->allowUnregisteredDialects;
594 void MLIRContext::allowUnregisteredDialects(bool allowing) {
595 assert(impl->multiThreadedExecutionContext == 0 &&
596 "changing MLIRContext `allow-unregistered-dialects` configuration "
597 "while in a multi-threaded execution context");
598 impl->allowUnregisteredDialects = allowing;
601 /// Return true if multi-threading is enabled by the context.
602 bool MLIRContext::isMultithreadingEnabled() {
603 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
606 /// Set the flag specifying if multi-threading is disabled by the context.
607 void MLIRContext::disableMultithreading(bool disable) {
608 // This API can be overridden by the global debugging flag
609 // --mlir-disable-threading
610 if (isThreadingGloballyDisabled())
611 return;
612 assert(impl->multiThreadedExecutionContext == 0 &&
613 "changing MLIRContext `disable-threading` configuration while "
614 "in a multi-threaded execution context");
616 impl->threadingIsEnabled = !disable;
618 // Update the threading mode for each of the uniquers.
619 impl->affineUniquer.disableMultithreading(disable);
620 impl->attributeUniquer.disableMultithreading(disable);
621 impl->typeUniquer.disableMultithreading(disable);
623 // Destroy thread pool (stop all threads) if it is no longer needed, or create
624 // a new one if multithreading was re-enabled.
625 if (disable) {
626 // If the thread pool is owned, explicitly set it to nullptr to avoid
627 // keeping a dangling pointer around. If the thread pool is externally
628 // owned, we don't do anything.
629 if (impl->ownedThreadPool) {
630 assert(impl->threadPool);
631 impl->threadPool = nullptr;
632 impl->ownedThreadPool.reset();
634 } else if (!impl->threadPool) {
635 // The thread pool isn't externally provided.
636 assert(!impl->ownedThreadPool);
637 impl->ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
638 impl->threadPool = impl->ownedThreadPool.get();
642 void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) {
643 assert(!isMultithreadingEnabled() &&
644 "expected multi-threading to be disabled when setting a ThreadPool");
645 impl->threadPool = &pool;
646 impl->ownedThreadPool.reset();
647 enableMultithreading();
650 unsigned MLIRContext::getNumThreads() {
651 if (isMultithreadingEnabled()) {
652 assert(impl->threadPool &&
653 "multi-threading is enabled but threadpool not set");
654 return impl->threadPool->getMaxConcurrency();
656 // No multithreading or active thread pool. Return 1 thread.
657 return 1;
660 llvm::ThreadPoolInterface &MLIRContext::getThreadPool() {
661 assert(isMultithreadingEnabled() &&
662 "expected multi-threading to be enabled within the context");
663 assert(impl->threadPool &&
664 "multi-threading is enabled but threadpool not set");
665 return *impl->threadPool;
668 void MLIRContext::enterMultiThreadedExecution() {
669 #ifndef NDEBUG
670 ++impl->multiThreadedExecutionContext;
671 #endif
673 void MLIRContext::exitMultiThreadedExecution() {
674 #ifndef NDEBUG
675 --impl->multiThreadedExecutionContext;
676 #endif
679 /// Return true if we should attach the operation to diagnostics emitted via
680 /// Operation::emit.
681 bool MLIRContext::shouldPrintOpOnDiagnostic() {
682 return impl->printOpOnDiagnostic;
685 /// Set the flag specifying if we should attach the operation to diagnostics
686 /// emitted via Operation::emit.
687 void MLIRContext::printOpOnDiagnostic(bool enable) {
688 assert(impl->multiThreadedExecutionContext == 0 &&
689 "changing MLIRContext `print-op-on-diagnostic` configuration while in "
690 "a multi-threaded execution context");
691 impl->printOpOnDiagnostic = enable;
694 /// Return true if we should attach the current stacktrace to diagnostics when
695 /// emitted.
696 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
697 return impl->printStackTraceOnDiagnostic;
700 /// Set the flag specifying if we should attach the current stacktrace when
701 /// emitting diagnostics.
702 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
703 assert(impl->multiThreadedExecutionContext == 0 &&
704 "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
705 "while in a multi-threaded execution context");
706 impl->printStackTraceOnDiagnostic = enable;
709 /// Return information about all registered operations.
710 ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711 return impl->sortedRegisteredOperations;
714 /// Return information for registered operations by dialect.
715 ArrayRef<RegisteredOperationName>
716 MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
717 auto lowerBound =
718 std::lower_bound(impl->sortedRegisteredOperations.begin(),
719 impl->sortedRegisteredOperations.end(), dialectName,
720 [](auto &lhs, auto &rhs) {
721 return lhs.getDialect().getNamespace().compare(rhs);
724 if (lowerBound == impl->sortedRegisteredOperations.end() ||
725 lowerBound->getDialect().getNamespace() != dialectName)
726 return ArrayRef<RegisteredOperationName>();
728 auto upperBound =
729 std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(),
730 dialectName, [](auto &lhs, auto &rhs) {
731 return lhs.compare(rhs.getDialect().getNamespace());
734 size_t count = std::distance(lowerBound, upperBound);
735 return ArrayRef(&*lowerBound, count);
738 bool MLIRContext::isOperationRegistered(StringRef name) {
739 return RegisteredOperationName::lookup(name, this).has_value();
742 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
743 auto &impl = context->getImpl();
744 assert(impl.multiThreadedExecutionContext == 0 &&
745 "Registering a new type kind while in a multi-threaded execution "
746 "context");
747 auto *newInfo =
748 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
749 AbstractType(std::move(typeInfo));
750 if (!impl.registeredTypes.insert({typeID, newInfo}).second)
751 llvm::report_fatal_error("Dialect Type already registered.");
752 if (!impl.nameToType.insert({newInfo->getName(), newInfo}).second)
753 llvm::report_fatal_error("Dialect Type with name " + newInfo->getName() +
754 " is already registered.");
757 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
758 auto &impl = context->getImpl();
759 assert(impl.multiThreadedExecutionContext == 0 &&
760 "Registering a new attribute kind while in a multi-threaded execution "
761 "context");
762 auto *newInfo =
763 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
764 AbstractAttribute(std::move(attrInfo));
765 if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
766 llvm::report_fatal_error("Dialect Attribute already registered.");
767 if (!impl.nameToAttribute.insert({newInfo->getName(), newInfo}).second)
768 llvm::report_fatal_error("Dialect Attribute with name " +
769 newInfo->getName() + " is already registered.");
772 //===----------------------------------------------------------------------===//
773 // AbstractAttribute
774 //===----------------------------------------------------------------------===//
776 /// Get the dialect that registered the attribute with the provided typeid.
777 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
778 MLIRContext *context) {
779 const AbstractAttribute *abstract = lookupMutable(typeID, context);
780 if (!abstract)
781 llvm::report_fatal_error("Trying to create an Attribute that was not "
782 "registered in this MLIRContext.");
783 return *abstract;
786 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
787 MLIRContext *context) {
788 auto &impl = context->getImpl();
789 return impl.registeredAttributes.lookup(typeID);
792 std::optional<std::reference_wrapper<const AbstractAttribute>>
793 AbstractAttribute::lookup(StringRef name, MLIRContext *context) {
794 MLIRContextImpl &impl = context->getImpl();
795 const AbstractAttribute *type = impl.nameToAttribute.lookup(name);
797 if (!type)
798 return std::nullopt;
799 return {*type};
802 //===----------------------------------------------------------------------===//
803 // OperationName
804 //===----------------------------------------------------------------------===//
806 OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID,
807 detail::InterfaceMap interfaceMap)
808 : Impl(StringAttr::get(dialect->getContext(), name), dialect, typeID,
809 std::move(interfaceMap)) {}
811 OperationName::OperationName(StringRef name, MLIRContext *context) {
812 MLIRContextImpl &ctxImpl = context->getImpl();
814 // Check for an existing name in read-only mode.
815 bool isMultithreadingEnabled = context->isMultithreadingEnabled();
816 if (isMultithreadingEnabled) {
817 // Check the registered info map first. In the overwhelmingly common case,
818 // the entry will be in here and it also removes the need to acquire any
819 // locks.
820 auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
821 if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
822 impl = registeredIt->second.impl;
823 return;
826 llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
827 auto it = ctxImpl.operations.find(name);
828 if (it != ctxImpl.operations.end()) {
829 impl = it->second.get();
830 return;
834 // Acquire a writer-lock so that we can safely create the new instance.
835 ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
837 auto it = ctxImpl.operations.insert({name, nullptr});
838 if (it.second) {
839 auto nameAttr = StringAttr::get(context, name);
840 it.first->second = std::make_unique<UnregisteredOpModel>(
841 nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(),
842 detail::InterfaceMap());
844 impl = it.first->second.get();
847 StringRef OperationName::getDialectNamespace() const {
848 if (Dialect *dialect = getDialect())
849 return dialect->getNamespace();
850 return getStringRef().split('.').first;
853 LogicalResult
854 OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef<Attribute>,
855 SmallVectorImpl<OpFoldResult> &) {
856 return failure();
858 void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
859 RewritePatternSet &, MLIRContext *) {}
860 bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; }
862 OperationName::ParseAssemblyFn
863 OperationName::UnregisteredOpModel::getParseAssemblyFn() {
864 llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op");
866 void OperationName::UnregisteredOpModel::populateDefaultAttrs(
867 const OperationName &, NamedAttrList &) {}
868 void OperationName::UnregisteredOpModel::printAssembly(
869 Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
870 p.printGenericOp(op);
872 LogicalResult
873 OperationName::UnregisteredOpModel::verifyInvariants(Operation *) {
874 return success();
876 LogicalResult
877 OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
878 return success();
881 std::optional<Attribute>
882 OperationName::UnregisteredOpModel::getInherentAttr(Operation *op,
883 StringRef name) {
884 auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op));
885 if (!dict)
886 return std::nullopt;
887 if (Attribute attr = dict.get(name))
888 return attr;
889 return std::nullopt;
891 void OperationName::UnregisteredOpModel::setInherentAttr(Operation *op,
892 StringAttr name,
893 Attribute value) {
894 auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op));
895 assert(dict);
896 NamedAttrList attrs(dict);
897 attrs.set(name, value);
898 *op->getPropertiesStorage().as<Attribute *>() =
899 attrs.getDictionary(op->getContext());
901 void OperationName::UnregisteredOpModel::populateInherentAttrs(
902 Operation *op, NamedAttrList &attrs) {}
903 LogicalResult OperationName::UnregisteredOpModel::verifyInherentAttrs(
904 OperationName opName, NamedAttrList &attributes,
905 function_ref<InFlightDiagnostic()> emitError) {
906 return success();
908 int OperationName::UnregisteredOpModel::getOpPropertyByteSize() {
909 return sizeof(Attribute);
911 void OperationName::UnregisteredOpModel::initProperties(
912 OperationName opName, OpaqueProperties storage, OpaqueProperties init) {
913 new (storage.as<Attribute *>()) Attribute();
915 void OperationName::UnregisteredOpModel::deleteProperties(
916 OpaqueProperties prop) {
917 prop.as<Attribute *>()->~Attribute();
919 void OperationName::UnregisteredOpModel::populateDefaultProperties(
920 OperationName opName, OpaqueProperties properties) {}
921 LogicalResult OperationName::UnregisteredOpModel::setPropertiesFromAttr(
922 OperationName opName, OpaqueProperties properties, Attribute attr,
923 function_ref<InFlightDiagnostic()> emitError) {
924 *properties.as<Attribute *>() = attr;
925 return success();
927 Attribute
928 OperationName::UnregisteredOpModel::getPropertiesAsAttr(Operation *op) {
929 return *op->getPropertiesStorage().as<Attribute *>();
931 void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs,
932 OpaqueProperties rhs) {
933 *lhs.as<Attribute *>() = *rhs.as<Attribute *>();
935 bool OperationName::UnregisteredOpModel::compareProperties(
936 OpaqueProperties lhs, OpaqueProperties rhs) {
937 return *lhs.as<Attribute *>() == *rhs.as<Attribute *>();
939 llvm::hash_code
940 OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
941 return llvm::hash_combine(*prop.as<Attribute *>());
944 //===----------------------------------------------------------------------===//
945 // RegisteredOperationName
946 //===----------------------------------------------------------------------===//
948 std::optional<RegisteredOperationName>
949 RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
950 auto &impl = ctx->getImpl();
951 auto it = impl.registeredOperations.find(typeID);
952 if (it != impl.registeredOperations.end())
953 return it->second;
954 return std::nullopt;
957 std::optional<RegisteredOperationName>
958 RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
959 auto &impl = ctx->getImpl();
960 auto it = impl.registeredOperationsByName.find(name);
961 if (it != impl.registeredOperationsByName.end())
962 return it->getValue();
963 return std::nullopt;
966 void RegisteredOperationName::insert(
967 std::unique_ptr<RegisteredOperationName::Impl> ownedImpl,
968 ArrayRef<StringRef> attrNames) {
969 RegisteredOperationName::Impl *impl = ownedImpl.get();
970 MLIRContext *ctx = impl->getDialect()->getContext();
971 auto &ctxImpl = ctx->getImpl();
972 assert(ctxImpl.multiThreadedExecutionContext == 0 &&
973 "registering a new operation kind while in a multi-threaded execution "
974 "context");
976 // Register the attribute names of this operation.
977 MutableArrayRef<StringAttr> cachedAttrNames;
978 if (!attrNames.empty()) {
979 cachedAttrNames = MutableArrayRef<StringAttr>(
980 ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
981 attrNames.size()),
982 attrNames.size());
983 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
984 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
985 impl->attributeNames = cachedAttrNames;
987 StringRef name = impl->getName().strref();
988 // Insert the operation info if it doesn't exist yet.
989 ctxImpl.operations[name] = std::move(ownedImpl);
991 // Update the registered info for this operation.
992 auto emplaced = ctxImpl.registeredOperations.try_emplace(
993 impl->getTypeID(), RegisteredOperationName(impl));
994 assert(emplaced.second && "operation name registration must be successful");
995 auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
996 name, RegisteredOperationName(impl));
997 (void)emplacedByName;
998 assert(emplacedByName.second &&
999 "operation name registration must be successful");
1001 // Add emplaced operation name to the sorted operations container.
1002 RegisteredOperationName &value = emplaced.first->second;
1003 ctxImpl.sortedRegisteredOperations.insert(
1004 llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
1005 [](auto &lhs, auto &rhs) {
1006 return lhs.getIdentifier().compare(
1007 rhs.getIdentifier());
1009 value);
1012 //===----------------------------------------------------------------------===//
1013 // AbstractType
1014 //===----------------------------------------------------------------------===//
1016 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
1017 const AbstractType *type = lookupMutable(typeID, context);
1018 if (!type)
1019 llvm::report_fatal_error(
1020 "Trying to create a Type that was not registered in this MLIRContext.");
1021 return *type;
1024 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
1025 auto &impl = context->getImpl();
1026 return impl.registeredTypes.lookup(typeID);
1029 std::optional<std::reference_wrapper<const AbstractType>>
1030 AbstractType::lookup(StringRef name, MLIRContext *context) {
1031 MLIRContextImpl &impl = context->getImpl();
1032 const AbstractType *type = impl.nameToType.lookup(name);
1034 if (!type)
1035 return std::nullopt;
1036 return {*type};
1039 //===----------------------------------------------------------------------===//
1040 // Type uniquing
1041 //===----------------------------------------------------------------------===//
1043 /// Returns the storage uniquer used for constructing type storage instances.
1044 /// This should not be used directly.
1045 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
1047 Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
1048 return context->getImpl().f4E2M1FNTy;
1050 Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
1051 return context->getImpl().f6E2M3FNTy;
1053 Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
1054 return context->getImpl().f6E3M2FNTy;
1056 Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
1057 return context->getImpl().f8E5M2Ty;
1059 Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
1060 return context->getImpl().f8E4M3Ty;
1062 Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
1063 return context->getImpl().f8E4M3FNTy;
1065 Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
1066 return context->getImpl().f8E5M2FNUZTy;
1068 Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
1069 return context->getImpl().f8E4M3FNUZTy;
1071 Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
1072 return context->getImpl().f8E4M3B11FNUZTy;
1074 Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
1075 return context->getImpl().f8E3M4Ty;
1077 Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
1078 return context->getImpl().f8E8M0FNUTy;
1080 BFloat16Type BFloat16Type::get(MLIRContext *context) {
1081 return context->getImpl().bf16Ty;
1083 Float16Type Float16Type::get(MLIRContext *context) {
1084 return context->getImpl().f16Ty;
1086 FloatTF32Type FloatTF32Type::get(MLIRContext *context) {
1087 return context->getImpl().tf32Ty;
1089 Float32Type Float32Type::get(MLIRContext *context) {
1090 return context->getImpl().f32Ty;
1092 Float64Type Float64Type::get(MLIRContext *context) {
1093 return context->getImpl().f64Ty;
1095 Float80Type Float80Type::get(MLIRContext *context) {
1096 return context->getImpl().f80Ty;
1098 Float128Type Float128Type::get(MLIRContext *context) {
1099 return context->getImpl().f128Ty;
1102 /// Get an instance of the IndexType.
1103 IndexType IndexType::get(MLIRContext *context) {
1104 return context->getImpl().indexTy;
1107 /// Return an existing integer type instance if one is cached within the
1108 /// context.
1109 static IntegerType
1110 getCachedIntegerType(unsigned width,
1111 IntegerType::SignednessSemantics signedness,
1112 MLIRContext *context) {
1113 if (signedness != IntegerType::Signless)
1114 return IntegerType();
1116 switch (width) {
1117 case 1:
1118 return context->getImpl().int1Ty;
1119 case 8:
1120 return context->getImpl().int8Ty;
1121 case 16:
1122 return context->getImpl().int16Ty;
1123 case 32:
1124 return context->getImpl().int32Ty;
1125 case 64:
1126 return context->getImpl().int64Ty;
1127 case 128:
1128 return context->getImpl().int128Ty;
1129 default:
1130 return IntegerType();
1134 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
1135 IntegerType::SignednessSemantics signedness) {
1136 if (auto cached = getCachedIntegerType(width, signedness, context))
1137 return cached;
1138 return Base::get(context, width, signedness);
1141 IntegerType
1142 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1143 MLIRContext *context, unsigned width,
1144 SignednessSemantics signedness) {
1145 if (auto cached = getCachedIntegerType(width, signedness, context))
1146 return cached;
1147 return Base::getChecked(emitError, context, width, signedness);
1150 /// Get an instance of the NoneType.
1151 NoneType NoneType::get(MLIRContext *context) {
1152 if (NoneType cachedInst = context->getImpl().noneType)
1153 return cachedInst;
1154 // Note: May happen when initializing the singleton attributes of the builtin
1155 // dialect.
1156 return Base::get(context);
1159 //===----------------------------------------------------------------------===//
1160 // Attribute uniquing
1161 //===----------------------------------------------------------------------===//
1163 /// Returns the storage uniquer used for constructing attribute storage
1164 /// instances. This should not be used directly.
1165 StorageUniquer &MLIRContext::getAttributeUniquer() {
1166 return getImpl().attributeUniquer;
1169 /// Initialize the given attribute storage instance.
1170 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
1171 MLIRContext *ctx,
1172 TypeID attrID) {
1173 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
1176 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
1177 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
1180 UnitAttr UnitAttr::get(MLIRContext *context) {
1181 return context->getImpl().unitAttr;
1184 UnknownLoc UnknownLoc::get(MLIRContext *context) {
1185 return context->getImpl().unknownLocAttr;
1188 DistinctAttrStorage *
1189 detail::DistinctAttributeUniquer::allocateStorage(MLIRContext *context,
1190 Attribute referencedAttr) {
1191 return context->getImpl().distinctAttributeAllocator.allocate(referencedAttr);
1194 /// Return empty dictionary.
1195 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
1196 return context->getImpl().emptyDictionaryAttr;
1199 void StringAttrStorage::initialize(MLIRContext *context) {
1200 // Check for a dialect namespace prefix, if there isn't one we don't need to
1201 // do any additional initialization.
1202 auto dialectNamePair = value.split('.');
1203 if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
1204 return;
1206 // If one exists, we check to see if this dialect is loaded. If it is, we set
1207 // the dialect now, if it isn't we record this storage for initialization
1208 // later if the dialect ever gets loaded.
1209 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
1210 return;
1212 MLIRContextImpl &impl = context->getImpl();
1213 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
1214 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
1217 /// Return an empty string.
1218 StringAttr StringAttr::get(MLIRContext *context) {
1219 return context->getImpl().emptyStringAttr;
1222 //===----------------------------------------------------------------------===//
1223 // AffineMap uniquing
1224 //===----------------------------------------------------------------------===//
1226 StorageUniquer &MLIRContext::getAffineUniquer() {
1227 return getImpl().affineUniquer;
1230 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
1231 ArrayRef<AffineExpr> results,
1232 MLIRContext *context) {
1233 auto &impl = context->getImpl();
1234 auto *storage = impl.affineUniquer.get<AffineMapStorage>(
1235 [&](AffineMapStorage *storage) { storage->context = context; }, dimCount,
1236 symbolCount, results);
1237 return AffineMap(storage);
1240 /// Check whether the arguments passed to the AffineMap::get() are consistent.
1241 /// This method checks whether the highest index of dimensional identifier
1242 /// present in result expressions is less than `dimCount` and the highest index
1243 /// of symbolic identifier present in result expressions is less than
1244 /// `symbolCount`.
1245 LLVM_ATTRIBUTE_UNUSED static bool
1246 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
1247 ArrayRef<AffineExpr> results) {
1248 int64_t maxDimPosition = -1;
1249 int64_t maxSymbolPosition = -1;
1250 getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
1251 maxSymbolPosition);
1252 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
1253 LLVM_DEBUG(
1254 llvm::dbgs()
1255 << "maximum dimensional identifier position in result expression must "
1256 "be less than `dimCount` and maximum symbolic identifier position "
1257 "in result expression must be less than `symbolCount`\n");
1258 return false;
1260 return true;
1263 AffineMap AffineMap::get(MLIRContext *context) {
1264 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
1267 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1268 MLIRContext *context) {
1269 return getImpl(dimCount, symbolCount, /*results=*/{}, context);
1272 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1273 AffineExpr result) {
1274 assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
1275 return getImpl(dimCount, symbolCount, {result}, result.getContext());
1278 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1279 ArrayRef<AffineExpr> results, MLIRContext *context) {
1280 assert(willBeValidAffineMap(dimCount, symbolCount, results));
1281 return getImpl(dimCount, symbolCount, results, context);
1284 //===----------------------------------------------------------------------===//
// Integer Sets: these are immutable and are always uniqued in the context's
// affine storage uniquer, like affine maps.
1287 //===----------------------------------------------------------------------===//
1289 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
1290 ArrayRef<AffineExpr> constraints,
1291 ArrayRef<bool> eqFlags) {
1292 // The number of constraints can't be zero.
1293 assert(!constraints.empty());
1294 assert(constraints.size() == eqFlags.size());
1296 auto &impl = constraints[0].getContext()->getImpl();
1297 auto *storage = impl.affineUniquer.get<IntegerSetStorage>(
1298 [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags);
1299 return IntegerSet(storage);
1302 //===----------------------------------------------------------------------===//
1303 // StorageUniquerSupport
1304 //===----------------------------------------------------------------------===//
1306 /// Utility method to generate a callback that can be used to generate a
1307 /// diagnostic when checking the construction invariants of a storage object.
1308 /// This is defined out-of-line to avoid the need to include Location.h.
1309 llvm::unique_function<InFlightDiagnostic()>
1310 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1311 return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1313 llvm::unique_function<InFlightDiagnostic()>
1314 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1315 return [=] { return emitError(loc); };