1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/Action.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinAttributes.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/ExtensibleDialect.h"
24 #include "mlir/IR/IntegerSet.h"
25 #include "mlir/IR/Location.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/OperationSupport.h"
28 #include "mlir/IR/Types.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/DenseSet.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/StringSet.h"
33 #include "llvm/ADT/Twine.h"
34 #include "llvm/Support/Allocator.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Compiler.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/Mutex.h"
39 #include "llvm/Support/RWMutex.h"
40 #include "llvm/Support/ThreadPool.h"
41 #include "llvm/Support/raw_ostream.h"
45 #define DEBUG_TYPE "mlircontext"
48 using namespace mlir::detail
;
50 //===----------------------------------------------------------------------===//
51 // MLIRContext CommandLine Options
52 //===----------------------------------------------------------------------===//
55 /// This struct contains command line options that can be used to initialize
56 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
57 /// for global command line options.
58 struct MLIRContextOptions
{
59 llvm::cl::opt
<bool> disableThreading
{
60 "mlir-disable-threading",
61 llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
62 "further call to MLIRContext::enableMultiThreading()")};
64 llvm::cl::opt
<bool> printOpOnDiagnostic
{
65 "mlir-print-op-on-diagnostic",
66 llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
67 "the operation as an attached note"),
68 llvm::cl::init(true)};
70 llvm::cl::opt
<bool> printStackTraceOnDiagnostic
{
71 "mlir-print-stacktrace-on-diagnostic",
72 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
73 "as an attached note")};
77 static llvm::ManagedStatic
<MLIRContextOptions
> clOptions
;
79 static bool isThreadingGloballyDisabled() {
80 #if LLVM_ENABLE_THREADS != 0
81 return clOptions
.isConstructed() && clOptions
->disableThreading
;
87 /// Register a set of useful command-line options that can be used to configure
88 /// various flags within the MLIRContext. These flags are used when constructing
89 /// an MLIR context for initialization.
90 void mlir::registerMLIRContextCLOptions() {
91 // Make sure that the options struct has been initialized.
95 //===----------------------------------------------------------------------===//
97 //===----------------------------------------------------------------------===//
/// Utility writer lock that takes a runtime flag specifying whether the mutex
/// should actually be locked.
102 struct ScopedWriterLock
{
103 ScopedWriterLock(llvm::sys::SmartRWMutex
<true> &mutexParam
, bool shouldLock
)
104 : mutex(shouldLock
? &mutexParam
: nullptr) {
108 ~ScopedWriterLock() {
112 llvm::sys::SmartRWMutex
<true> *mutex
;
116 //===----------------------------------------------------------------------===//
118 //===----------------------------------------------------------------------===//
121 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
122 /// This class is completely private to this file, so everything is public.
123 class MLIRContextImpl
{
125 //===--------------------------------------------------------------------===//
127 //===--------------------------------------------------------------------===//
129 /// An action handler for handling actions that are dispatched through this
131 std::function
<void(function_ref
<void()>, const tracing::Action
&)>
134 //===--------------------------------------------------------------------===//
136 //===--------------------------------------------------------------------===//
137 DiagnosticEngine diagEngine
;
139 //===--------------------------------------------------------------------===//
141 //===--------------------------------------------------------------------===//
143 /// In most cases, creating operation in unregistered dialect is not desired
144 /// and indicate a misconfiguration of the compiler. This option enables to
145 /// detect such use cases
146 bool allowUnregisteredDialects
= false;
148 /// Enable support for multi-threading within MLIR.
149 bool threadingIsEnabled
= true;
151 /// Track if we are currently executing in a threaded execution environment
152 /// (like the pass-manager): this is only a debugging feature to help reducing
153 /// the chances of data races one some context APIs.
155 std::atomic
<int> multiThreadedExecutionContext
{0};
158 /// If the operation should be attached to diagnostics printed via the
159 /// Operation::emit methods.
160 bool printOpOnDiagnostic
= true;
162 /// If the current stack trace should be attached when emitting diagnostics.
163 bool printStackTraceOnDiagnostic
= false;
165 //===--------------------------------------------------------------------===//
167 //===--------------------------------------------------------------------===//
169 /// This points to the ThreadPool used when processing MLIR tasks in parallel.
170 /// It can't be nullptr when multi-threading is enabled. Otherwise if
171 /// multi-threading is disabled, and the threadpool wasn't externally provided
172 /// using `setThreadPool`, this will be nullptr.
173 llvm::ThreadPool
*threadPool
= nullptr;
175 /// In case where the thread pool is owned by the context, this ensures
176 /// destruction with the context.
177 std::unique_ptr
<llvm::ThreadPool
> ownedThreadPool
;
179 /// An allocator used for AbstractAttribute and AbstractType objects.
180 llvm::BumpPtrAllocator abstractDialectSymbolAllocator
;
182 /// This is a mapping from operation name to the operation info describing it.
183 llvm::StringMap
<std::unique_ptr
<OperationName::Impl
>> operations
;
185 /// A vector of operation info specifically for registered operations.
186 llvm::StringMap
<RegisteredOperationName
> registeredOperations
;
188 /// This is a sorted container of registered operations for a deterministic
189 /// and efficient `getRegisteredOperations` implementation.
190 SmallVector
<RegisteredOperationName
, 0> sortedRegisteredOperations
;
192 /// This is a list of dialects that are created referring to this context.
193 /// The MLIRContext owns the objects. These need to be declared after the
194 /// registered operations to ensure correct destruction order.
195 DenseMap
<StringRef
, std::unique_ptr
<Dialect
>> loadedDialects
;
196 DialectRegistry dialectsRegistry
;
198 /// A mutex used when accessing operation information.
199 llvm::sys::SmartRWMutex
<true> operationInfoMutex
;
201 //===--------------------------------------------------------------------===//
203 //===--------------------------------------------------------------------===//
205 // Affine expression, map and integer set uniquing.
206 StorageUniquer affineUniquer
;
208 //===--------------------------------------------------------------------===//
210 //===--------------------------------------------------------------------===//
212 DenseMap
<TypeID
, AbstractType
*> registeredTypes
;
213 StorageUniquer typeUniquer
;
215 /// Cached Type Instances.
216 Float8E5M2Type f8E5M2Ty
;
217 Float8E4M3FNType f8E4M3FNTy
;
218 Float8E5M2FNUZType f8E5M2FNUZTy
;
219 Float8E4M3FNUZType f8E4M3FNUZTy
;
220 Float8E4M3B11FNUZType f8E4M3B11FNUZTy
;
223 FloatTF32Type tf32Ty
;
229 IntegerType int1Ty
, int8Ty
, int16Ty
, int32Ty
, int64Ty
, int128Ty
;
232 //===--------------------------------------------------------------------===//
233 // Attribute uniquing
234 //===--------------------------------------------------------------------===//
236 DenseMap
<TypeID
, AbstractAttribute
*> registeredAttributes
;
237 StorageUniquer attributeUniquer
;
239 /// Cached Attribute Instances.
240 BoolAttr falseAttr
, trueAttr
;
242 UnknownLoc unknownLocAttr
;
243 DictionaryAttr emptyDictionaryAttr
;
244 StringAttr emptyStringAttr
;
246 /// Map of string attributes that may reference a dialect, that are awaiting
247 /// that dialect to be loaded.
248 llvm::sys::SmartMutex
<true> dialectRefStrAttrMutex
;
249 DenseMap
<StringRef
, SmallVector
<StringAttrStorage
*>>
250 dialectReferencingStrAttrs
;
252 /// A distinct attribute allocator that allocates every time since the
253 /// address of the distinct attribute storage serves as unique identifier. The
254 /// allocator is thread safe and frees the allocated storage after its
256 DistinctAttributeAllocator distinctAttributeAllocator
;
259 MLIRContextImpl(bool threadingIsEnabled
)
260 : threadingIsEnabled(threadingIsEnabled
) {
261 if (threadingIsEnabled
) {
262 ownedThreadPool
= std::make_unique
<llvm::ThreadPool
>();
263 threadPool
= ownedThreadPool
.get();
267 for (auto typeMapping
: registeredTypes
)
268 typeMapping
.second
->~AbstractType();
269 for (auto attrMapping
: registeredAttributes
)
270 attrMapping
.second
->~AbstractAttribute();
275 MLIRContext::MLIRContext(Threading setting
)
276 : MLIRContext(DialectRegistry(), setting
) {}
278 MLIRContext::MLIRContext(const DialectRegistry
®istry
, Threading setting
)
279 : impl(new MLIRContextImpl(setting
== Threading::ENABLED
&&
280 !isThreadingGloballyDisabled())) {
281 // Initialize values based on the command line flags if they were provided.
282 if (clOptions
.isConstructed()) {
283 printOpOnDiagnostic(clOptions
->printOpOnDiagnostic
);
284 printStackTraceOnDiagnostic(clOptions
->printStackTraceOnDiagnostic
);
287 // Pre-populate the registry.
288 registry
.appendTo(impl
->dialectsRegistry
);
290 // Ensure the builtin dialect is always pre-loaded.
291 getOrLoadDialect
<BuiltinDialect
>();
293 // Initialize several common attributes and types to avoid the need to lock
294 // the context when accessing them.
297 /// Floating-point Types.
298 impl
->f8E5M2Ty
= TypeUniquer::get
<Float8E5M2Type
>(this);
299 impl
->f8E4M3FNTy
= TypeUniquer::get
<Float8E4M3FNType
>(this);
300 impl
->f8E5M2FNUZTy
= TypeUniquer::get
<Float8E5M2FNUZType
>(this);
301 impl
->f8E4M3FNUZTy
= TypeUniquer::get
<Float8E4M3FNUZType
>(this);
302 impl
->f8E4M3B11FNUZTy
= TypeUniquer::get
<Float8E4M3B11FNUZType
>(this);
303 impl
->bf16Ty
= TypeUniquer::get
<BFloat16Type
>(this);
304 impl
->f16Ty
= TypeUniquer::get
<Float16Type
>(this);
305 impl
->tf32Ty
= TypeUniquer::get
<FloatTF32Type
>(this);
306 impl
->f32Ty
= TypeUniquer::get
<Float32Type
>(this);
307 impl
->f64Ty
= TypeUniquer::get
<Float64Type
>(this);
308 impl
->f80Ty
= TypeUniquer::get
<Float80Type
>(this);
309 impl
->f128Ty
= TypeUniquer::get
<Float128Type
>(this);
311 impl
->indexTy
= TypeUniquer::get
<IndexType
>(this);
313 impl
->int1Ty
= TypeUniquer::get
<IntegerType
>(this, 1, IntegerType::Signless
);
314 impl
->int8Ty
= TypeUniquer::get
<IntegerType
>(this, 8, IntegerType::Signless
);
316 TypeUniquer::get
<IntegerType
>(this, 16, IntegerType::Signless
);
318 TypeUniquer::get
<IntegerType
>(this, 32, IntegerType::Signless
);
320 TypeUniquer::get
<IntegerType
>(this, 64, IntegerType::Signless
);
322 TypeUniquer::get
<IntegerType
>(this, 128, IntegerType::Signless
);
324 impl
->noneType
= TypeUniquer::get
<NoneType
>(this);
327 //// Note: These must be registered after the types as they may generate one
328 //// of the above types internally.
329 /// Unknown Location Attribute.
330 impl
->unknownLocAttr
= AttributeUniquer::get
<UnknownLoc
>(this);
332 impl
->falseAttr
= IntegerAttr::getBoolAttrUnchecked(impl
->int1Ty
, false);
333 impl
->trueAttr
= IntegerAttr::getBoolAttrUnchecked(impl
->int1Ty
, true);
335 impl
->unitAttr
= AttributeUniquer::get
<UnitAttr
>(this);
336 /// The empty dictionary attribute.
337 impl
->emptyDictionaryAttr
= DictionaryAttr::getEmptyUnchecked(this);
338 /// The empty string attribute.
339 impl
->emptyStringAttr
= StringAttr::getEmptyStringAttrUnchecked(this);
341 // Register the affine storage objects with the uniquer.
343 .registerParametricStorageType
<AffineBinaryOpExprStorage
>();
345 .registerParametricStorageType
<AffineConstantExprStorage
>();
346 impl
->affineUniquer
.registerParametricStorageType
<AffineDimExprStorage
>();
347 impl
->affineUniquer
.registerParametricStorageType
<AffineMapStorage
>();
348 impl
->affineUniquer
.registerParametricStorageType
<IntegerSetStorage
>();
351 MLIRContext::~MLIRContext() = default;
353 /// Copy the specified array of elements into memory managed by the provided
354 /// bump pointer allocator. This assumes the elements are all PODs.
355 template <typename T
>
356 static ArrayRef
<T
> copyArrayRefInto(llvm::BumpPtrAllocator
&allocator
,
357 ArrayRef
<T
> elements
) {
358 auto result
= allocator
.Allocate
<T
>(elements
.size());
359 std::uninitialized_copy(elements
.begin(), elements
.end(), result
);
360 return ArrayRef
<T
>(result
, elements
.size());
363 //===----------------------------------------------------------------------===//
365 //===----------------------------------------------------------------------===//
367 void MLIRContext::registerActionHandler(HandlerTy handler
) {
368 getImpl().actionHandler
= std::move(handler
);
/// Dispatch the provided action to the registered action handler. A handler
/// must have been registered beforehand (see hasActionHandler).
372 void MLIRContext::executeActionInternal(function_ref
<void()> actionFn
,
373 const tracing::Action
&action
) {
374 assert(getImpl().actionHandler
);
375 getImpl().actionHandler(actionFn
, action
);
378 bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler
; }
380 //===----------------------------------------------------------------------===//
381 // Diagnostic Handlers
382 //===----------------------------------------------------------------------===//
384 /// Returns the diagnostic engine for this context.
385 DiagnosticEngine
&MLIRContext::getDiagEngine() { return getImpl().diagEngine
; }
387 //===----------------------------------------------------------------------===//
388 // Dialect and Operation Registration
389 //===----------------------------------------------------------------------===//
391 void MLIRContext::appendDialectRegistry(const DialectRegistry
®istry
) {
392 if (registry
.isSubsetOf(impl
->dialectsRegistry
))
395 assert(impl
->multiThreadedExecutionContext
== 0 &&
396 "appending to the MLIRContext dialect registry while in a "
397 "multi-threaded execution context");
398 registry
.appendTo(impl
->dialectsRegistry
);
400 // For the already loaded dialects, apply any possible extensions immediately.
401 registry
.applyExtensions(this);
404 const DialectRegistry
&MLIRContext::getDialectRegistry() {
405 return impl
->dialectsRegistry
;
408 /// Return information about all registered IR dialects.
409 std::vector
<Dialect
*> MLIRContext::getLoadedDialects() {
410 std::vector
<Dialect
*> result
;
411 result
.reserve(impl
->loadedDialects
.size());
412 for (auto &dialect
: impl
->loadedDialects
)
413 result
.push_back(dialect
.second
.get());
414 llvm::array_pod_sort(result
.begin(), result
.end(),
415 [](Dialect
*const *lhs
, Dialect
*const *rhs
) -> int {
416 return (*lhs
)->getNamespace() < (*rhs
)->getNamespace();
420 std::vector
<StringRef
> MLIRContext::getAvailableDialects() {
421 std::vector
<StringRef
> result
;
422 for (auto dialect
: impl
->dialectsRegistry
.getDialectNames())
423 result
.push_back(dialect
);
/// Get a loaded IR dialect with the given namespace. If none is found, or the
/// dialect is still in the process of being loaded, return nullptr.
429 Dialect
*MLIRContext::getLoadedDialect(StringRef name
) {
430 // Dialects are sorted by name, so we can use binary search for lookup.
431 auto it
= impl
->loadedDialects
.find(name
);
432 return (it
!= impl
->loadedDialects
.end()) ? it
->second
.get() : nullptr;
435 Dialect
*MLIRContext::getOrLoadDialect(StringRef name
) {
436 Dialect
*dialect
= getLoadedDialect(name
);
439 DialectAllocatorFunctionRef allocator
=
440 impl
->dialectsRegistry
.getDialectAllocator(name
);
441 return allocator
? allocator(this) : nullptr;
/// Get a dialect for the provided namespace and TypeID: abort the program if a
/// dialect exists for this namespace with a different TypeID. Returns a pointer
/// to the dialect owned by the context.
448 MLIRContext::getOrLoadDialect(StringRef dialectNamespace
, TypeID dialectID
,
449 function_ref
<std::unique_ptr
<Dialect
>()> ctor
) {
450 auto &impl
= getImpl();
451 // Get the correct insertion position sorted by namespace.
452 auto dialectIt
= impl
.loadedDialects
.try_emplace(dialectNamespace
, nullptr);
454 if (dialectIt
.second
) {
455 LLVM_DEBUG(llvm::dbgs()
456 << "Load new dialect in Context " << dialectNamespace
<< "\n");
458 if (impl
.multiThreadedExecutionContext
!= 0)
459 llvm::report_fatal_error(
460 "Loading a dialect (" + dialectNamespace
+
461 ") while in a multi-threaded execution context (maybe "
462 "the PassManager): this can indicate a "
463 "missing `dependentDialects` in a pass for example.");
465 // loadedDialects entry is initialized to nullptr, indicating that the
466 // dialect is currently being loaded. Re-lookup the address in
467 // loadedDialects because the table might have been rehashed by recursive
468 // dialect loading in ctor().
469 std::unique_ptr
<Dialect
> &dialectOwned
=
470 impl
.loadedDialects
[dialectNamespace
] = ctor();
471 Dialect
*dialect
= dialectOwned
.get();
472 assert(dialect
&& "dialect ctor failed");
474 // Refresh all the identifiers dialect field, this catches cases where a
475 // dialect may be loaded after identifier prefixed with this dialect name
476 // were already created.
477 auto stringAttrsIt
= impl
.dialectReferencingStrAttrs
.find(dialectNamespace
);
478 if (stringAttrsIt
!= impl
.dialectReferencingStrAttrs
.end()) {
479 for (StringAttrStorage
*storage
: stringAttrsIt
->second
)
480 storage
->referencedDialect
= dialect
;
481 impl
.dialectReferencingStrAttrs
.erase(stringAttrsIt
);
484 // Apply any extensions to this newly loaded dialect.
485 impl
.dialectsRegistry
.applyExtensions(dialect
);
490 if (dialectIt
.first
->second
== nullptr)
491 llvm::report_fatal_error(
492 "Loading (and getting) a dialect (" + dialectNamespace
+
493 ") while the same dialect is still loading: use loadDialect instead "
494 "of getOrLoadDialect.");
497 // Abort if dialect with namespace has already been registered.
498 std::unique_ptr
<Dialect
> &dialect
= dialectIt
.first
->second
;
499 if (dialect
->getTypeID() != dialectID
)
500 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace
+
501 "' has already been registered");
503 return dialect
.get();
506 bool MLIRContext::isDialectLoading(StringRef dialectNamespace
) {
507 auto it
= getImpl().loadedDialects
.find(dialectNamespace
);
508 // nullptr indicates that the dialect is currently being loaded.
509 return it
!= getImpl().loadedDialects
.end() && it
->second
== nullptr;
512 DynamicDialect
*MLIRContext::getOrLoadDynamicDialect(
513 StringRef dialectNamespace
, function_ref
<void(DynamicDialect
*)> ctor
) {
514 auto &impl
= getImpl();
515 // Get the correct insertion position sorted by namespace.
516 auto dialectIt
= impl
.loadedDialects
.find(dialectNamespace
);
518 if (dialectIt
!= impl
.loadedDialects
.end()) {
519 if (auto *dynDialect
= dyn_cast
<DynamicDialect
>(dialectIt
->second
.get()))
521 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace
+
522 "' has already been registered");
525 LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
526 << dialectNamespace
<< "\n");
528 if (impl
.multiThreadedExecutionContext
!= 0)
529 llvm::report_fatal_error(
530 "Loading a dynamic dialect (" + dialectNamespace
+
531 ") while in a multi-threaded execution context (maybe "
532 "the PassManager): this can indicate a "
533 "missing `dependentDialects` in a pass for example.");
536 auto name
= StringAttr::get(this, dialectNamespace
);
537 auto *dialect
= new DynamicDialect(name
, this);
538 (void)getOrLoadDialect(name
, dialect
->getTypeID(), [dialect
, ctor
]() {
540 return std::unique_ptr
<DynamicDialect
>(dialect
);
542 // This is the same result as `getOrLoadDialect` (if it didn't failed),
543 // since it has the same TypeID, and TypeIDs are unique.
547 void MLIRContext::loadAllAvailableDialects() {
548 for (StringRef name
: getAvailableDialects())
549 getOrLoadDialect(name
);
552 llvm::hash_code
MLIRContext::getRegistryHash() {
553 llvm::hash_code
hash(0);
554 // Factor in number of loaded dialects, attributes, operations, types.
555 hash
= llvm::hash_combine(hash
, impl
->loadedDialects
.size());
556 hash
= llvm::hash_combine(hash
, impl
->registeredAttributes
.size());
557 hash
= llvm::hash_combine(hash
, impl
->registeredOperations
.size());
558 hash
= llvm::hash_combine(hash
, impl
->registeredTypes
.size());
562 bool MLIRContext::allowsUnregisteredDialects() {
563 return impl
->allowUnregisteredDialects
;
566 void MLIRContext::allowUnregisteredDialects(bool allowing
) {
567 assert(impl
->multiThreadedExecutionContext
== 0 &&
568 "changing MLIRContext `allow-unregistered-dialects` configuration "
569 "while in a multi-threaded execution context");
570 impl
->allowUnregisteredDialects
= allowing
;
573 /// Return true if multi-threading is enabled by the context.
574 bool MLIRContext::isMultithreadingEnabled() {
575 return impl
->threadingIsEnabled
&& llvm::llvm_is_multithreaded();
578 /// Set the flag specifying if multi-threading is disabled by the context.
579 void MLIRContext::disableMultithreading(bool disable
) {
580 // This API can be overridden by the global debugging flag
581 // --mlir-disable-threading
582 if (isThreadingGloballyDisabled())
584 assert(impl
->multiThreadedExecutionContext
== 0 &&
585 "changing MLIRContext `disable-threading` configuration while "
586 "in a multi-threaded execution context");
588 impl
->threadingIsEnabled
= !disable
;
590 // Update the threading mode for each of the uniquers.
591 impl
->affineUniquer
.disableMultithreading(disable
);
592 impl
->attributeUniquer
.disableMultithreading(disable
);
593 impl
->typeUniquer
.disableMultithreading(disable
);
595 // Destroy thread pool (stop all threads) if it is no longer needed, or create
596 // a new one if multithreading was re-enabled.
598 // If the thread pool is owned, explicitly set it to nullptr to avoid
599 // keeping a dangling pointer around. If the thread pool is externally
600 // owned, we don't do anything.
601 if (impl
->ownedThreadPool
) {
602 assert(impl
->threadPool
);
603 impl
->threadPool
= nullptr;
604 impl
->ownedThreadPool
.reset();
606 } else if (!impl
->threadPool
) {
607 // The thread pool isn't externally provided.
608 assert(!impl
->ownedThreadPool
);
609 impl
->ownedThreadPool
= std::make_unique
<llvm::ThreadPool
>();
610 impl
->threadPool
= impl
->ownedThreadPool
.get();
614 void MLIRContext::setThreadPool(llvm::ThreadPool
&pool
) {
615 assert(!isMultithreadingEnabled() &&
616 "expected multi-threading to be disabled when setting a ThreadPool");
617 impl
->threadPool
= &pool
;
618 impl
->ownedThreadPool
.reset();
619 enableMultithreading();
622 unsigned MLIRContext::getNumThreads() {
623 if (isMultithreadingEnabled()) {
624 assert(impl
->threadPool
&&
625 "multi-threading is enabled but threadpool not set");
626 return impl
->threadPool
->getThreadCount();
628 // No multithreading or active thread pool. Return 1 thread.
632 llvm::ThreadPool
&MLIRContext::getThreadPool() {
633 assert(isMultithreadingEnabled() &&
634 "expected multi-threading to be enabled within the context");
635 assert(impl
->threadPool
&&
636 "multi-threading is enabled but threadpool not set");
637 return *impl
->threadPool
;
640 void MLIRContext::enterMultiThreadedExecution() {
642 ++impl
->multiThreadedExecutionContext
;
645 void MLIRContext::exitMultiThreadedExecution() {
647 --impl
->multiThreadedExecutionContext
;
/// Return true if we should attach the operation to diagnostics emitted via
/// Operation::emit.
653 bool MLIRContext::shouldPrintOpOnDiagnostic() {
654 return impl
->printOpOnDiagnostic
;
657 /// Set the flag specifying if we should attach the operation to diagnostics
658 /// emitted via Operation::emit.
659 void MLIRContext::printOpOnDiagnostic(bool enable
) {
660 assert(impl
->multiThreadedExecutionContext
== 0 &&
661 "changing MLIRContext `print-op-on-diagnostic` configuration while in "
662 "a multi-threaded execution context");
663 impl
->printOpOnDiagnostic
= enable
;
/// Return true if we should attach the current stacktrace when emitting
/// diagnostics.
668 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
669 return impl
->printStackTraceOnDiagnostic
;
672 /// Set the flag specifying if we should attach the current stacktrace when
673 /// emitting diagnostics.
674 void MLIRContext::printStackTraceOnDiagnostic(bool enable
) {
675 assert(impl
->multiThreadedExecutionContext
== 0 &&
676 "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
677 "while in a multi-threaded execution context");
678 impl
->printStackTraceOnDiagnostic
= enable
;
681 /// Return information about all registered operations.
682 ArrayRef
<RegisteredOperationName
> MLIRContext::getRegisteredOperations() {
683 return impl
->sortedRegisteredOperations
;
686 bool MLIRContext::isOperationRegistered(StringRef name
) {
687 return RegisteredOperationName::lookup(name
, this).has_value();
690 void Dialect::addType(TypeID typeID
, AbstractType
&&typeInfo
) {
691 auto &impl
= context
->getImpl();
692 assert(impl
.multiThreadedExecutionContext
== 0 &&
693 "Registering a new type kind while in a multi-threaded execution "
696 new (impl
.abstractDialectSymbolAllocator
.Allocate
<AbstractType
>())
697 AbstractType(std::move(typeInfo
));
698 if (!impl
.registeredTypes
.insert({typeID
, newInfo
}).second
)
699 llvm::report_fatal_error("Dialect Type already registered.");
702 void Dialect::addAttribute(TypeID typeID
, AbstractAttribute
&&attrInfo
) {
703 auto &impl
= context
->getImpl();
704 assert(impl
.multiThreadedExecutionContext
== 0 &&
705 "Registering a new attribute kind while in a multi-threaded execution "
708 new (impl
.abstractDialectSymbolAllocator
.Allocate
<AbstractAttribute
>())
709 AbstractAttribute(std::move(attrInfo
));
710 if (!impl
.registeredAttributes
.insert({typeID
, newInfo
}).second
)
711 llvm::report_fatal_error("Dialect Attribute already registered.");
714 //===----------------------------------------------------------------------===//
716 //===----------------------------------------------------------------------===//
/// Get the AbstractAttribute registered for the provided TypeID, reporting a
/// fatal error if no such attribute was registered in this MLIRContext.
719 const AbstractAttribute
&AbstractAttribute::lookup(TypeID typeID
,
720 MLIRContext
*context
) {
721 const AbstractAttribute
*abstract
= lookupMutable(typeID
, context
);
723 llvm::report_fatal_error("Trying to create an Attribute that was not "
724 "registered in this MLIRContext.");
728 AbstractAttribute
*AbstractAttribute::lookupMutable(TypeID typeID
,
729 MLIRContext
*context
) {
730 auto &impl
= context
->getImpl();
731 return impl
.registeredAttributes
.lookup(typeID
);
734 //===----------------------------------------------------------------------===//
736 //===----------------------------------------------------------------------===//
738 OperationName::Impl::Impl(StringRef name
, Dialect
*dialect
, TypeID typeID
,
739 detail::InterfaceMap interfaceMap
)
740 : Impl(StringAttr::get(dialect
->getContext(), name
), dialect
, typeID
,
741 std::move(interfaceMap
)) {}
743 OperationName::OperationName(StringRef name
, MLIRContext
*context
) {
744 MLIRContextImpl
&ctxImpl
= context
->getImpl();
746 // Check for an existing name in read-only mode.
747 bool isMultithreadingEnabled
= context
->isMultithreadingEnabled();
748 if (isMultithreadingEnabled
) {
749 // Check the registered info map first. In the overwhelmingly common case,
750 // the entry will be in here and it also removes the need to acquire any
752 auto registeredIt
= ctxImpl
.registeredOperations
.find(name
);
753 if (LLVM_LIKELY(registeredIt
!= ctxImpl
.registeredOperations
.end())) {
754 impl
= registeredIt
->second
.impl
;
758 llvm::sys::SmartScopedReader
<true> contextLock(ctxImpl
.operationInfoMutex
);
759 auto it
= ctxImpl
.operations
.find(name
);
760 if (it
!= ctxImpl
.operations
.end()) {
761 impl
= it
->second
.get();
766 // Acquire a writer-lock so that we can safely create the new instance.
767 ScopedWriterLock
lock(ctxImpl
.operationInfoMutex
, isMultithreadingEnabled
);
769 auto it
= ctxImpl
.operations
.insert({name
, nullptr});
771 auto nameAttr
= StringAttr::get(context
, name
);
772 it
.first
->second
= std::make_unique
<UnregisteredOpModel
>(
773 nameAttr
, nameAttr
.getReferencedDialect(), TypeID::get
<void>(),
774 detail::InterfaceMap());
776 impl
= it
.first
->second
.get();
779 StringRef
OperationName::getDialectNamespace() const {
780 if (Dialect
*dialect
= getDialect())
781 return dialect
->getNamespace();
782 return getStringRef().split('.').first
;
786 OperationName::UnregisteredOpModel::foldHook(Operation
*, ArrayRef
<Attribute
>,
787 SmallVectorImpl
<OpFoldResult
> &) {
790 void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
791 RewritePatternSet
&, MLIRContext
*) {}
792 bool OperationName::UnregisteredOpModel::hasTrait(TypeID
) { return false; }
794 OperationName::ParseAssemblyFn
795 OperationName::UnregisteredOpModel::getParseAssemblyFn() {
796 llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op");
798 void OperationName::UnregisteredOpModel::populateDefaultAttrs(
799 const OperationName
&, NamedAttrList
&) {}
800 void OperationName::UnregisteredOpModel::printAssembly(
801 Operation
*op
, OpAsmPrinter
&p
, StringRef defaultDialect
) {
802 p
.printGenericOp(op
);
805 OperationName::UnregisteredOpModel::verifyInvariants(Operation
*) {
809 OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation
*) {
813 std::optional
<Attribute
>
814 OperationName::UnregisteredOpModel::getInherentAttr(Operation
*op
,
816 auto dict
= dyn_cast_or_null
<DictionaryAttr
>(getPropertiesAsAttr(op
));
819 if (Attribute attr
= dict
.get(name
))
823 void OperationName::UnregisteredOpModel::setInherentAttr(Operation
*op
,
826 auto dict
= dyn_cast_or_null
<DictionaryAttr
>(getPropertiesAsAttr(op
));
828 NamedAttrList
attrs(dict
);
829 attrs
.set(name
, value
);
830 *op
->getPropertiesStorage().as
<Attribute
*>() =
831 attrs
.getDictionary(op
->getContext());
833 void OperationName::UnregisteredOpModel::populateInherentAttrs(
834 Operation
*op
, NamedAttrList
&attrs
) {}
835 LogicalResult
OperationName::UnregisteredOpModel::verifyInherentAttrs(
836 OperationName opName
, NamedAttrList
&attributes
,
837 function_ref
<InFlightDiagnostic()> emitError
) {
840 int OperationName::UnregisteredOpModel::getOpPropertyByteSize() {
841 return sizeof(Attribute
);
843 void OperationName::UnregisteredOpModel::initProperties(
844 OperationName opName
, OpaqueProperties storage
, OpaqueProperties init
) {
845 new (storage
.as
<Attribute
*>()) Attribute();
847 void OperationName::UnregisteredOpModel::deleteProperties(
848 OpaqueProperties prop
) {
849 prop
.as
<Attribute
*>()->~Attribute();
851 void OperationName::UnregisteredOpModel::populateDefaultProperties(
852 OperationName opName
, OpaqueProperties properties
) {}
853 LogicalResult
OperationName::UnregisteredOpModel::setPropertiesFromAttr(
854 OperationName opName
, OpaqueProperties properties
, Attribute attr
,
855 function_ref
<InFlightDiagnostic()> emitError
) {
856 *properties
.as
<Attribute
*>() = attr
;
860 OperationName::UnregisteredOpModel::getPropertiesAsAttr(Operation
*op
) {
861 return *op
->getPropertiesStorage().as
<Attribute
*>();
863 void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs
,
864 OpaqueProperties rhs
) {
865 *lhs
.as
<Attribute
*>() = *rhs
.as
<Attribute
*>();
867 bool OperationName::UnregisteredOpModel::compareProperties(OpaqueProperties lhs
,
868 OpaqueProperties rhs
) {
869 return *lhs
.as
<Attribute
*>() == *rhs
.as
<Attribute
*>();
872 OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop
) {
873 return llvm::hash_combine(*prop
.as
<Attribute
*>());
876 //===----------------------------------------------------------------------===//
877 // RegisteredOperationName
878 //===----------------------------------------------------------------------===//
880 std::optional
<RegisteredOperationName
>
881 RegisteredOperationName::lookup(StringRef name
, MLIRContext
*ctx
) {
882 auto &impl
= ctx
->getImpl();
883 auto it
= impl
.registeredOperations
.find(name
);
884 if (it
!= impl
.registeredOperations
.end())
885 return it
->getValue();
889 void RegisteredOperationName::insert(
890 std::unique_ptr
<RegisteredOperationName::Impl
> ownedImpl
,
891 ArrayRef
<StringRef
> attrNames
) {
892 RegisteredOperationName::Impl
*impl
= ownedImpl
.get();
893 MLIRContext
*ctx
= impl
->getDialect()->getContext();
894 auto &ctxImpl
= ctx
->getImpl();
895 assert(ctxImpl
.multiThreadedExecutionContext
== 0 &&
896 "registering a new operation kind while in a multi-threaded execution "
899 // Register the attribute names of this operation.
900 MutableArrayRef
<StringAttr
> cachedAttrNames
;
901 if (!attrNames
.empty()) {
902 cachedAttrNames
= MutableArrayRef
<StringAttr
>(
903 ctxImpl
.abstractDialectSymbolAllocator
.Allocate
<StringAttr
>(
906 for (unsigned i
: llvm::seq
<unsigned>(0, attrNames
.size()))
907 new (&cachedAttrNames
[i
]) StringAttr(StringAttr::get(ctx
, attrNames
[i
]));
908 impl
->attributeNames
= cachedAttrNames
;
910 StringRef name
= impl
->getName().strref();
911 // Insert the operation info if it doesn't exist yet.
912 auto it
= ctxImpl
.operations
.insert({name
, nullptr});
913 it
.first
->second
= std::move(ownedImpl
);
915 // Update the registered info for this operation.
916 auto emplaced
= ctxImpl
.registeredOperations
.try_emplace(
917 name
, RegisteredOperationName(impl
));
918 assert(emplaced
.second
&& "operation name registration must be successful");
920 // Add emplaced operation name to the sorted operations container.
921 RegisteredOperationName
&value
= emplaced
.first
->getValue();
922 ctxImpl
.sortedRegisteredOperations
.insert(
923 llvm::upper_bound(ctxImpl
.sortedRegisteredOperations
, value
,
924 [](auto &lhs
, auto &rhs
) {
925 return lhs
.getIdentifier().compare(
926 rhs
.getIdentifier());
931 //===----------------------------------------------------------------------===//
933 //===----------------------------------------------------------------------===//
935 const AbstractType
&AbstractType::lookup(TypeID typeID
, MLIRContext
*context
) {
936 const AbstractType
*type
= lookupMutable(typeID
, context
);
938 llvm::report_fatal_error(
939 "Trying to create a Type that was not registered in this MLIRContext.");
943 AbstractType
*AbstractType::lookupMutable(TypeID typeID
, MLIRContext
*context
) {
944 auto &impl
= context
->getImpl();
945 return impl
.registeredTypes
.lookup(typeID
);
948 //===----------------------------------------------------------------------===//
950 //===----------------------------------------------------------------------===//
952 /// Returns the storage uniquer used for constructing type storage instances.
953 /// This should not be used directly.
954 StorageUniquer
&MLIRContext::getTypeUniquer() { return getImpl().typeUniquer
; }
956 Float8E5M2Type
Float8E5M2Type::get(MLIRContext
*context
) {
957 return context
->getImpl().f8E5M2Ty
;
959 Float8E4M3FNType
Float8E4M3FNType::get(MLIRContext
*context
) {
960 return context
->getImpl().f8E4M3FNTy
;
962 Float8E5M2FNUZType
Float8E5M2FNUZType::get(MLIRContext
*context
) {
963 return context
->getImpl().f8E5M2FNUZTy
;
965 Float8E4M3FNUZType
Float8E4M3FNUZType::get(MLIRContext
*context
) {
966 return context
->getImpl().f8E4M3FNUZTy
;
968 Float8E4M3B11FNUZType
Float8E4M3B11FNUZType::get(MLIRContext
*context
) {
969 return context
->getImpl().f8E4M3B11FNUZTy
;
971 BFloat16Type
BFloat16Type::get(MLIRContext
*context
) {
972 return context
->getImpl().bf16Ty
;
974 Float16Type
Float16Type::get(MLIRContext
*context
) {
975 return context
->getImpl().f16Ty
;
977 FloatTF32Type
FloatTF32Type::get(MLIRContext
*context
) {
978 return context
->getImpl().tf32Ty
;
980 Float32Type
Float32Type::get(MLIRContext
*context
) {
981 return context
->getImpl().f32Ty
;
983 Float64Type
Float64Type::get(MLIRContext
*context
) {
984 return context
->getImpl().f64Ty
;
986 Float80Type
Float80Type::get(MLIRContext
*context
) {
987 return context
->getImpl().f80Ty
;
989 Float128Type
Float128Type::get(MLIRContext
*context
) {
990 return context
->getImpl().f128Ty
;
993 /// Get an instance of the IndexType.
994 IndexType
IndexType::get(MLIRContext
*context
) {
995 return context
->getImpl().indexTy
;
998 /// Return an existing integer type instance if one is cached within the
1001 getCachedIntegerType(unsigned width
,
1002 IntegerType::SignednessSemantics signedness
,
1003 MLIRContext
*context
) {
1004 if (signedness
!= IntegerType::Signless
)
1005 return IntegerType();
1009 return context
->getImpl().int1Ty
;
1011 return context
->getImpl().int8Ty
;
1013 return context
->getImpl().int16Ty
;
1015 return context
->getImpl().int32Ty
;
1017 return context
->getImpl().int64Ty
;
1019 return context
->getImpl().int128Ty
;
1021 return IntegerType();
1025 IntegerType
IntegerType::get(MLIRContext
*context
, unsigned width
,
1026 IntegerType::SignednessSemantics signedness
) {
1027 if (auto cached
= getCachedIntegerType(width
, signedness
, context
))
1029 return Base::get(context
, width
, signedness
);
1033 IntegerType::getChecked(function_ref
<InFlightDiagnostic()> emitError
,
1034 MLIRContext
*context
, unsigned width
,
1035 SignednessSemantics signedness
) {
1036 if (auto cached
= getCachedIntegerType(width
, signedness
, context
))
1038 return Base::getChecked(emitError
, context
, width
, signedness
);
1041 /// Get an instance of the NoneType.
1042 NoneType
NoneType::get(MLIRContext
*context
) {
1043 if (NoneType cachedInst
= context
->getImpl().noneType
)
1045 // Note: May happen when initializing the singleton attributes of the builtin
1047 return Base::get(context
);
1050 //===----------------------------------------------------------------------===//
1051 // Attribute uniquing
1052 //===----------------------------------------------------------------------===//
1054 /// Returns the storage uniquer used for constructing attribute storage
1055 /// instances. This should not be used directly.
1056 StorageUniquer
&MLIRContext::getAttributeUniquer() {
1057 return getImpl().attributeUniquer
;
1060 /// Initialize the given attribute storage instance.
1061 void AttributeUniquer::initializeAttributeStorage(AttributeStorage
*storage
,
1064 storage
->initializeAbstractAttribute(AbstractAttribute::lookup(attrID
, ctx
));
1067 BoolAttr
BoolAttr::get(MLIRContext
*context
, bool value
) {
1068 return value
? context
->getImpl().trueAttr
: context
->getImpl().falseAttr
;
1071 UnitAttr
UnitAttr::get(MLIRContext
*context
) {
1072 return context
->getImpl().unitAttr
;
1075 UnknownLoc
UnknownLoc::get(MLIRContext
*context
) {
1076 return context
->getImpl().unknownLocAttr
;
1079 DistinctAttrStorage
*
1080 detail::DistinctAttributeUniquer::allocateStorage(MLIRContext
*context
,
1081 Attribute referencedAttr
) {
1082 return context
->getImpl().distinctAttributeAllocator
.allocate(referencedAttr
);
1085 /// Return empty dictionary.
1086 DictionaryAttr
DictionaryAttr::getEmpty(MLIRContext
*context
) {
1087 return context
->getImpl().emptyDictionaryAttr
;
1090 void StringAttrStorage::initialize(MLIRContext
*context
) {
1091 // Check for a dialect namespace prefix, if there isn't one we don't need to
1092 // do any additional initialization.
1093 auto dialectNamePair
= value
.split('.');
1094 if (dialectNamePair
.first
.empty() || dialectNamePair
.second
.empty())
1097 // If one exists, we check to see if this dialect is loaded. If it is, we set
1098 // the dialect now, if it isn't we record this storage for initialization
1099 // later if the dialect ever gets loaded.
1100 if ((referencedDialect
= context
->getLoadedDialect(dialectNamePair
.first
)))
1103 MLIRContextImpl
&impl
= context
->getImpl();
1104 llvm::sys::SmartScopedLock
<true> lock(impl
.dialectRefStrAttrMutex
);
1105 impl
.dialectReferencingStrAttrs
[dialectNamePair
.first
].push_back(this);
1108 /// Return an empty string.
1109 StringAttr
StringAttr::get(MLIRContext
*context
) {
1110 return context
->getImpl().emptyStringAttr
;
1113 //===----------------------------------------------------------------------===//
1114 // AffineMap uniquing
1115 //===----------------------------------------------------------------------===//
1117 StorageUniquer
&MLIRContext::getAffineUniquer() {
1118 return getImpl().affineUniquer
;
1121 AffineMap
AffineMap::getImpl(unsigned dimCount
, unsigned symbolCount
,
1122 ArrayRef
<AffineExpr
> results
,
1123 MLIRContext
*context
) {
1124 auto &impl
= context
->getImpl();
1125 auto *storage
= impl
.affineUniquer
.get
<AffineMapStorage
>(
1126 [&](AffineMapStorage
*storage
) { storage
->context
= context
; }, dimCount
,
1127 symbolCount
, results
);
1128 return AffineMap(storage
);
1131 /// Check whether the arguments passed to the AffineMap::get() are consistent.
1132 /// This method checks whether the highest index of dimensional identifier
1133 /// present in result expressions is less than `dimCount` and the highest index
1134 /// of symbolic identifier present in result expressions is less than
1136 LLVM_ATTRIBUTE_UNUSED
static bool
1137 willBeValidAffineMap(unsigned dimCount
, unsigned symbolCount
,
1138 ArrayRef
<AffineExpr
> results
) {
1139 int64_t maxDimPosition
= -1;
1140 int64_t maxSymbolPosition
= -1;
1141 getMaxDimAndSymbol(ArrayRef
<ArrayRef
<AffineExpr
>>(results
), maxDimPosition
,
1143 if ((maxDimPosition
>= dimCount
) || (maxSymbolPosition
>= symbolCount
)) {
1146 << "maximum dimensional identifier position in result expression must "
1147 "be less than `dimCount` and maximum symbolic identifier position "
1148 "in result expression must be less than `symbolCount`\n");
1154 AffineMap
AffineMap::get(MLIRContext
*context
) {
1155 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context
);
1158 AffineMap
AffineMap::get(unsigned dimCount
, unsigned symbolCount
,
1159 MLIRContext
*context
) {
1160 return getImpl(dimCount
, symbolCount
, /*results=*/{}, context
);
1163 AffineMap
AffineMap::get(unsigned dimCount
, unsigned symbolCount
,
1164 AffineExpr result
) {
1165 assert(willBeValidAffineMap(dimCount
, symbolCount
, {result
}));
1166 return getImpl(dimCount
, symbolCount
, {result
}, result
.getContext());
1169 AffineMap
AffineMap::get(unsigned dimCount
, unsigned symbolCount
,
1170 ArrayRef
<AffineExpr
> results
, MLIRContext
*context
) {
1171 assert(willBeValidAffineMap(dimCount
, symbolCount
, results
));
1172 return getImpl(dimCount
, symbolCount
, results
, context
);
1175 //===----------------------------------------------------------------------===//
1176 // Integer Sets: these are allocated into the bump pointer, and are immutable.
1177 // Unlike AffineMap's, these are uniqued only if they are small.
1178 //===----------------------------------------------------------------------===//
1180 IntegerSet
IntegerSet::get(unsigned dimCount
, unsigned symbolCount
,
1181 ArrayRef
<AffineExpr
> constraints
,
1182 ArrayRef
<bool> eqFlags
) {
1183 // The number of constraints can't be zero.
1184 assert(!constraints
.empty());
1185 assert(constraints
.size() == eqFlags
.size());
1187 auto &impl
= constraints
[0].getContext()->getImpl();
1188 auto *storage
= impl
.affineUniquer
.get
<IntegerSetStorage
>(
1189 [](IntegerSetStorage
*) {}, dimCount
, symbolCount
, constraints
, eqFlags
);
1190 return IntegerSet(storage
);
1193 //===----------------------------------------------------------------------===//
1194 // StorageUniquerSupport
1195 //===----------------------------------------------------------------------===//
1197 /// Utility method to generate a callback that can be used to generate a
1198 /// diagnostic when checking the construction invariants of a storage object.
1199 /// This is defined out-of-line to avoid the need to include Location.h.
1200 llvm::unique_function
<InFlightDiagnostic()>
1201 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext
*ctx
) {
1202 return [ctx
] { return emitError(UnknownLoc::get(ctx
)); };
1204 llvm::unique_function
<InFlightDiagnostic()>
1205 mlir::detail::getDefaultDiagnosticEmitFn(const Location
&loc
) {
1206 return [=] { return emitError(loc
); };