//===- Globals.h - MLIR Python extension globals --------------------------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
9 #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
10 #define MLIR_BINDINGS_PYTHON_GLOBALS_H
#include "PybindUtils.h"

#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

#include <cassert>
#include <optional>
#include <string>
#include <vector>
/// Globals that are always accessible once the extension has been initialized.
33 /// Most code should get the globals via this static accessor.
34 static PyGlobals
&get() {
35 assert(instance
&& "PyGlobals is null");
39 /// Get and set the list of parent modules to search for dialect
40 /// implementation classes.
41 std::vector
<std::string
> &getDialectSearchPrefixes() {
42 return dialectSearchPrefixes
;
44 void setDialectSearchPrefixes(std::vector
<std::string
> newValues
) {
45 dialectSearchPrefixes
.swap(newValues
);
48 /// Loads a python module corresponding to the given dialect namespace.
49 /// No-ops if the module has already been loaded or is not found. Raises
50 /// an error on any evaluation issues.
51 /// Note that this returns void because it is expected that the module
52 /// contains calls to decorators and helpers that register the salient
53 /// entities. Returns true if dialect is successfully loaded.
54 bool loadDialectModule(llvm::StringRef dialectNamespace
);
56 /// Adds a user-friendly Attribute builder.
57 /// Raises an exception if the mapping already exists and replace == false.
58 /// This is intended to be called by implementation code.
59 void registerAttributeBuilder(const std::string
&attributeKind
,
60 pybind11::function pyFunc
,
61 bool replace
= false);
63 /// Adds a user-friendly type caster. Raises an exception if the mapping
64 /// already exists and replace == false. This is intended to be called by
65 /// implementation code.
66 void registerTypeCaster(MlirTypeID mlirTypeID
, pybind11::function typeCaster
,
67 bool replace
= false);
69 /// Adds a user-friendly value caster. Raises an exception if the mapping
70 /// already exists and replace == false. This is intended to be called by
71 /// implementation code.
72 void registerValueCaster(MlirTypeID mlirTypeID
,
73 pybind11::function valueCaster
,
74 bool replace
= false);
76 /// Adds a concrete implementation dialect class.
77 /// Raises an exception if the mapping already exists.
78 /// This is intended to be called by implementation code.
79 void registerDialectImpl(const std::string
&dialectNamespace
,
80 pybind11::object pyClass
);
82 /// Adds a concrete implementation operation class.
83 /// Raises an exception if the mapping already exists and replace == false.
84 /// This is intended to be called by implementation code.
85 void registerOperationImpl(const std::string
&operationName
,
86 pybind11::object pyClass
, bool replace
= false);
88 /// Returns the custom Attribute builder for Attribute kind.
89 std::optional
<pybind11::function
>
90 lookupAttributeBuilder(const std::string
&attributeKind
);
92 /// Returns the custom type caster registered for `mlirTypeID`, wrapped in
/// std::optional (empty when no caster is found).
93 std::optional
<pybind11::function
> lookupTypeCaster(MlirTypeID mlirTypeID
,
// NOTE(review): this declaration appears truncated: the parameter list ends
// in a trailing ',' and no further parameter, ')' or ';' follows here.
// Restore the remaining parameter(s) from the matching definition.
96 /// Returns the custom value caster registered for `mlirTypeID`, wrapped in
/// std::optional (empty when no caster is found).
97 std::optional
<pybind11::function
> lookupValueCaster(MlirTypeID mlirTypeID
,
// NOTE(review): this declaration appears truncated: the parameter list ends
// in a trailing ',' and no further parameter, ')' or ';' follows here.
// Restore the remaining parameter(s) from the matching definition.
100 /// Looks up a registered dialect class by namespace. Note that this may
101 /// trigger loading of the defining module and can arbitrarily re-enter.
102 std::optional
<pybind11::object
>
103 lookupDialectClass(const std::string
&dialectNamespace
);
105 /// Looks up a registered operation class (deriving from OpView) by operation
106 /// name. Note that this may trigger a load of the dialect, which can
107 /// arbitrarily re-enter.
108 std::optional
<pybind11::object
>
109 lookupOperationClass(llvm::StringRef operationName
);
112 static PyGlobals
*instance
;
113 /// Module name prefixes to search under for dialect implementation modules.
114 std::vector
<std::string
> dialectSearchPrefixes
;
115 /// Map of dialect namespace to external dialect class object.
116 llvm::StringMap
<pybind11::object
> dialectClassMap
;
117 /// Map of full operation name to external operation class object.
118 llvm::StringMap
<pybind11::object
> operationClassMap
;
119 /// Map of attribute ODS name to custom builder.
120 llvm::StringMap
<pybind11::object
> attributeBuilderMap
;
121 /// Map of MlirTypeID to custom type caster.
122 llvm::DenseMap
<MlirTypeID
, pybind11::object
> typeCasterMap
;
123 /// Map of MlirTypeID to custom value caster.
124 llvm::DenseMap
<MlirTypeID
, pybind11::object
> valueCasterMap
;
125 /// Set of dialect namespaces that we have attempted to import implementation
127 llvm::StringSet
<> loadedDialectModules
;
130 } // namespace python
133 #endif // MLIR_BINDINGS_PYTHON_GLOBALS_H