//===- IRModule.cpp - IR pybind module ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// NOTE(review): the header that declares PyGlobals (e.g. the module's own
// header) is not visible in this chunk — confirm it is included above.
#include "PybindUtils.h"

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"

#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

namespace py = pybind11;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
27 PyGlobals
*PyGlobals::instance
= nullptr;
29 PyGlobals::PyGlobals() {
30 assert(!instance
&& "PyGlobals already constructed");
32 // The default search path include {mlir.}dialects, where {mlir.} is the
33 // package prefix configured at compile time.
34 dialectSearchPrefixes
.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
37 PyGlobals::~PyGlobals() { instance
= nullptr; }
39 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace
) {
40 if (loadedDialectModules
.contains(dialectNamespace
))
42 // Since re-entrancy is possible, make a copy of the search prefixes.
43 std::vector
<std::string
> localSearchPrefixes
= dialectSearchPrefixes
;
44 py::object loaded
= py::none();
45 for (std::string moduleName
: localSearchPrefixes
) {
46 moduleName
.push_back('.');
47 moduleName
.append(dialectNamespace
.data(), dialectNamespace
.size());
50 loaded
= py::module::import(moduleName
.c_str());
51 } catch (py::error_already_set
&e
) {
52 if (e
.matches(PyExc_ModuleNotFoundError
)) {
62 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
63 // may have occurred, which may do anything.
64 loadedDialectModules
.insert(dialectNamespace
);
68 void PyGlobals::registerAttributeBuilder(const std::string
&attributeKind
,
69 py::function pyFunc
, bool replace
) {
70 py::object
&found
= attributeBuilderMap
[attributeKind
];
71 if (found
&& !replace
) {
72 throw std::runtime_error((llvm::Twine("Attribute builder for '") +
74 "' is already registered with func: " +
75 py::str(found
).operator std::string())
78 found
= std::move(pyFunc
);
81 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID
,
82 pybind11::function typeCaster
,
84 pybind11::object
&found
= typeCasterMap
[mlirTypeID
];
85 if (found
&& !replace
)
86 throw std::runtime_error("Type caster is already registered with caster: " +
87 py::str(found
).operator std::string());
88 found
= std::move(typeCaster
);
91 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID
,
92 pybind11::function valueCaster
,
94 pybind11::object
&found
= valueCasterMap
[mlirTypeID
];
95 if (found
&& !replace
)
96 throw std::runtime_error("Value caster is already registered: " +
97 py::repr(found
).cast
<std::string
>());
98 found
= std::move(valueCaster
);
101 void PyGlobals::registerDialectImpl(const std::string
&dialectNamespace
,
102 py::object pyClass
) {
103 py::object
&found
= dialectClassMap
[dialectNamespace
];
105 throw std::runtime_error((llvm::Twine("Dialect namespace '") +
106 dialectNamespace
+ "' is already registered.")
109 found
= std::move(pyClass
);
112 void PyGlobals::registerOperationImpl(const std::string
&operationName
,
113 py::object pyClass
, bool replace
) {
114 py::object
&found
= operationClassMap
[operationName
];
115 if (found
&& !replace
) {
116 throw std::runtime_error((llvm::Twine("Operation '") + operationName
+
117 "' is already registered.")
120 found
= std::move(pyClass
);
123 std::optional
<py::function
>
124 PyGlobals::lookupAttributeBuilder(const std::string
&attributeKind
) {
125 const auto foundIt
= attributeBuilderMap
.find(attributeKind
);
126 if (foundIt
!= attributeBuilderMap
.end()) {
127 assert(foundIt
->second
&& "attribute builder is defined");
128 return foundIt
->second
;
133 std::optional
<py::function
> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID
,
134 MlirDialect dialect
) {
135 // Try to load dialect module.
136 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect
)));
137 const auto foundIt
= typeCasterMap
.find(mlirTypeID
);
138 if (foundIt
!= typeCasterMap
.end()) {
139 assert(foundIt
->second
&& "type caster is defined");
140 return foundIt
->second
;
145 std::optional
<py::function
> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID
,
146 MlirDialect dialect
) {
147 // Try to load dialect module.
148 (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect
)));
149 const auto foundIt
= valueCasterMap
.find(mlirTypeID
);
150 if (foundIt
!= valueCasterMap
.end()) {
151 assert(foundIt
->second
&& "value caster is defined");
152 return foundIt
->second
;
157 std::optional
<py::object
>
158 PyGlobals::lookupDialectClass(const std::string
&dialectNamespace
) {
159 // Make sure dialect module is loaded.
160 if (!loadDialectModule(dialectNamespace
))
162 const auto foundIt
= dialectClassMap
.find(dialectNamespace
);
163 if (foundIt
!= dialectClassMap
.end()) {
164 assert(foundIt
->second
&& "dialect class is defined");
165 return foundIt
->second
;
167 // Not found and loading did not yield a registration.
171 std::optional
<pybind11::object
>
172 PyGlobals::lookupOperationClass(llvm::StringRef operationName
) {
173 // Make sure dialect module is loaded.
174 auto split
= operationName
.split('.');
175 llvm::StringRef dialectNamespace
= split
.first
;
176 if (!loadDialectModule(dialectNamespace
))
179 auto foundIt
= operationClassMap
.find(operationName
);
180 if (foundIt
!= operationClassMap
.end()) {
181 assert(foundIt
->second
&& "OpView is defined");
182 return foundIt
->second
;
184 // Not found and loading did not yield a registration.