1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 #include <pybind11/cast.h>
12 #include <pybind11/detail/common.h>
13 #include <pybind11/pybind11.h>
14 #include <pybind11/pytypes.h>
20 #include "mlir-c/BuiltinAttributes.h"
21 #include "mlir-c/IR.h"
22 #include "mlir-c/Interfaces.h"
23 #include "mlir-c/Support.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
// Short alias for the pybind11 namespace, used throughout this file.
27 namespace py
= pybind11
;
// Python docstrings attached by the bindings in this file. constructorDoc is
// passed to the interface constructor binding, inferReturnTypesDoc to
// InferTypeOpInterface.inferReturnTypes and inferReturnTypeComponentsDoc to
// InferShapedTypeOpInterface.inferReturnTypeComponents. operationDoc and
// opviewDoc presumably document the `operation`/`opview` properties - confirm.
32 constexpr static const char *constructorDoc
=
33 R
"(Creates an interface from a given operation/opview object or from a
34 subclass of OpView. Raises ValueError if the operation does not implement the
37 constexpr static const char *operationDoc
=
38 R
"(Returns an Operation for which the interface was constructed.)";
40 constexpr static const char *opviewDoc
=
41 R
"(Returns an OpView subclass _instance_ for which the interface was
44 constexpr static const char *inferReturnTypesDoc
=
45 R
"(Given the arguments required to build an operation, attempts to infer
46 its return types. Raises ValueError on failure.)";
48 constexpr static const char *inferReturnTypeComponentsDoc
=
49 R
"(Given the arguments required to build an operation, attempts to infer
50 its return shaped type components. Raises ValueError on failure.)";
54 /// Takes in an optional list of operands and converts them into a SmallVector
55 /// of MlirValues. Returns an empty SmallVector if the list is absent or empty.
///
/// Each list entry is first cast to a PyValue. If that cast fails, the entry
/// is cast to a sequence whose elements must each be PyValues, so one level of
/// nesting is accepted. Entries that satisfy neither form raise
/// py::value_error naming the entry's index in the list. None entries are
/// checked for separately (presumably skipped - confirm).
56 llvm::SmallVector
<MlirValue
> wrapOperands(std::optional
<py::list
> operandList
) {
57 llvm::SmallVector
<MlirValue
> mlirOperands
;
// An absent (nullopt) or empty list is handled up front.
59 if (!operandList
|| operandList
->empty()) {
63 // Note: as the list may contain other lists this may not be final size.
64 mlirOperands
.reserve(operandList
->size());
65 for (const auto &&it
: llvm::enumerate(*operandList
)) {
66 if (it
.value().is_none())
// Attempt 1: treat the entry as a single Value.
71 val
= py::cast
<PyValue
*>(it
.value());
73 throw py::cast_error();
74 mlirOperands
.push_back(val
->get());
76 } catch (py::cast_error
&err
) {
77 // Intentionally unhandled to try sequence below first.
// Attempt 2: treat the entry as a sequence of Values; any element that is
// not a Value is reported with the index of the enclosing entry.
82 auto vals
= py::cast
<py::sequence
>(it
.value());
83 for (py::object v
: vals
) {
85 val
= py::cast
<PyValue
*>(v
);
87 throw py::cast_error();
88 mlirOperands
.push_back(val
->get());
89 } catch (py::cast_error
&err
) {
90 throw py::value_error(
91 (llvm::Twine("Operand ") + llvm::Twine(it
.index()) +
92 " must be a Value or Sequence of Values (" + err
.what() + ")")
97 } catch (py::cast_error
&err
) {
98 throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it
.index()) +
99 " must be a Value or Sequence of Values (" +
// NOTE(review): presumably reached only if neither attempt above continues
// or throws - confirm this is effectively unreachable.
104 throw py::cast_error();
110 /// Takes in an optional vector of PyRegions and returns a SmallVector of
111 /// MlirRegion. Returns an empty SmallVector if the list is empty.
112 llvm::SmallVector
<MlirRegion
>
113 wrapRegions(std::optional
<std::vector
<PyRegion
>> regions
) {
114 llvm::SmallVector
<MlirRegion
> mlirRegions
;
117 mlirRegions
.reserve(regions
->size());
118 for (PyRegion
®ion
: *regions
) {
119 mlirRegions
.push_back(region
);
128 /// CRTP base class for Python classes representing MLIR Op interfaces.
129 /// Interface hierarchies are flat so no base class is expected here. The
130 /// derived class is expected to define the following static fields:
131 /// - `const char *pyClassName` - the name of the Python class to create;
132 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
133 /// of the interface.
134 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
135 /// interface-specific methods.
137 /// An interface class may be constructed from either an Operation/OpView object
138 /// or from a subclass of OpView. In the latter case, only the static interface
139 /// methods are available, similarly to calling ConcreteOp::staticMethod on the
140 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
141 /// method to check whether the interface object was constructed from a class or
142 /// an operation/opview instance. The `getOpName` always succeeds and returns a
143 /// canonical name of the operation suitable for lookups.
144 template <typename ConcreteIface
>
145 class PyConcreteOpInterface
{
147 using ClassTy
= py::class_
<ConcreteIface
>;
148 using GetTypeIDFunctionTy
= MlirTypeID (*)();
151 /// Constructs an interface instance from an object that is either an
152 /// operation or a subclass of OpView. In the latter case, only the static
153 /// methods of the interface are accessible to the caller.
/// `context` is only used on the static (class) path, where it supplies the
/// context for the by-name interface check.
/// Raises py::value_error if the operation does not implement the interface,
/// and py::type_error if a class object's OPERATION_NAME cannot be converted
/// to a string.
154 PyConcreteOpInterface(py::object object
, DefaultingPyMlirContext context
)
155 : obj(std::move(object
)) {
// Try to interpret `obj` as an Operation first, then as an OpView.
157 operation
= &py::cast
<PyOperation
&>(obj
);
158 } catch (py::cast_error
&) {
163 operation
= &py::cast
<PyOpView
&>(obj
).getOperation();
164 } catch (py::cast_error
&) {
// Instance path: one of the casts above produced an operation. Check the
// interface on that operation and record its name.
168 if (operation
!= nullptr) {
169 if (!mlirOperationImplementsInterface(*operation
,
170 ConcreteIface::getInterfaceID())) {
171 std::string msg
= "the operation does not implement ";
172 throw py::value_error(msg
+ ConcreteIface::pyClassName
);
175 MlirIdentifier identifier
= mlirOperationGetName(*operation
);
176 MlirStringRef stringRef
= mlirIdentifierStr(identifier
);
177 opName
= std::string(stringRef
.data
, stringRef
.length
);
// Static path: `obj` is expected to be an OpView subclass exposing an
// OPERATION_NAME attribute convertible to std::string. A failed conversion is
// reported as a type_error.
// NOTE(review): if the attribute is missing entirely, attr() raises
// py::error_already_set (AttributeError) rather than py::cast_error, so that
// case does not produce the type_error below - confirm this is intended.
180 opName
= obj
.attr("OPERATION_NAME").template cast
<std::string
>();
181 } catch (py::cast_error
&) {
182 throw py::type_error(
183 "Op interface does not refer to an operation or OpView class");
// Check the interface by operation name in the resolved context.
186 if (!mlirOperationImplementsInterfaceStatic(
187 mlirStringRefCreate(opName
.data(), opName
.length()),
188 context
.resolve().get(), ConcreteIface::getInterfaceID())) {
189 std::string msg
= "the operation does not implement ";
190 throw py::value_error(msg
+ ConcreteIface::pyClassName
);
195 /// Creates the Python bindings for this class in the given module.
/// Registers the constructor and the `operation`/`opview` properties, then
/// calls ConcreteIface::bindDerived for interface-specific methods.
196 static void bind(py::module
&m
) {
197 py::class_
<ConcreteIface
> cls(m
, ConcreteIface::pyClassName
,
199 cls
.def(py::init
<py::object
, DefaultingPyMlirContext
>(), py::arg("object"),
200 py::arg("context") = py::none(), constructorDoc
)
201 .def_property_readonly("operation",
202 &PyConcreteOpInterface::getOperationObject
,
204 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView
,
206 ConcreteIface::bindDerived(cls
);
209 /// Hook for derived classes to add class-specific bindings.
210 static void bindDerived(ClassTy
&cls
) {}
212 /// Returns `true` if this object was constructed from a subclass of OpView
213 /// rather than from an operation instance.
214 bool isStatic() { return operation
== nullptr; }
216 /// Returns the operation instance from which this object was constructed.
217 /// Throws a type error if this object was constructed from a subclass of
/// OpView.
219 py::object
getOperationObject() {
220 if (operation
== nullptr) {
221 throw py::type_error("Cannot get an operation from a static interface");
224 return operation
->getRef().releaseObject();
227 /// Returns the opview of the operation instance from which this object was
228 /// constructed. Throws a type error if this object was constructed from a
229 /// subclass of OpView.
230 py::object
getOpView() {
231 if (operation
== nullptr) {
232 throw py::type_error("Cannot get an opview from a static interface");
235 return operation
->createOpView();
238 /// Returns the canonical name of the operation this interface is constructed
/// for.
240 const std::string
&getOpName() { return opName
; }
/// Operation this interface was constructed from; stays null when constructed
/// from an OpView subclass, which is what isStatic() checks.
243 PyOperation
*operation
= nullptr;
248 /// Python wrapper for InferTypeOpInterface. This interface has only static
/// methods.
250 class PyInferTypeOpInterface
251 : public PyConcreteOpInterface
<PyInferTypeOpInterface
> {
253 using PyConcreteOpInterface
<PyInferTypeOpInterface
>::PyConcreteOpInterface
;
255 constexpr static const char *pyClassName
= "InferTypeOpInterface";
256 constexpr static GetTypeIDFunctionTy getInterfaceID
=
257 &mlirInferTypeOpInterfaceTypeID
;
259 /// C-style user-data structure for type appending callback.
260 struct AppendResultsCallbackData
{
// Output vector the callback appends to (owned by the caller).
261 std::vector
<PyType
> &inferredTypes
;
// Context whose reference is attached to every PyType the callback creates.
262 PyMlirContext
&pyMlirContext
;
265 /// Appends the types provided as the two first arguments to the user-data
266 /// structure (expects AppendResultsCallbackData).
267 static void appendResultsCallback(intptr_t nTypes
, MlirType
*types
,
269 auto *data
= static_cast<AppendResultsCallbackData
*>(userData
);
// Reserve once for the whole batch, then wrap each MlirType in a PyType
// bound to the stored context.
270 data
->inferredTypes
.reserve(data
->inferredTypes
.size() + nTypes
);
271 for (intptr_t i
= 0; i
< nTypes
; ++i
) {
272 data
->inferredTypes
.emplace_back(data
->pyMlirContext
.getRef(), types
[i
]);
276 /// Given the arguments required to build an operation, attempts to infer its
277 /// return types. Throws value_error on failure.
/// Absent `attributes` are passed to the C API as a null attribute;
/// `properties` is forwarded unchanged as an opaque pointer.
279 inferReturnTypes(std::optional
<py::list
> operandList
,
280 std::optional
<PyAttribute
> attributes
, void *properties
,
281 std::optional
<std::vector
<PyRegion
>> regions
,
282 DefaultingPyMlirContext context
,
283 DefaultingPyLocation location
) {
// Convert the Python operands/regions into C API arrays for the call below.
284 llvm::SmallVector
<MlirValue
> mlirOperands
=
285 wrapOperands(std::move(operandList
));
286 llvm::SmallVector
<MlirRegion
> mlirRegions
= wrapRegions(std::move(regions
));
288 std::vector
<PyType
> inferredTypes
;
289 PyMlirContext
&pyContext
= context
.resolve();
290 AppendResultsCallbackData data
{inferredTypes
, pyContext
};
291 MlirStringRef opNameRef
=
292 mlirStringRefCreate(getOpName().data(), getOpName().length());
// Pass a null attribute when no attribute dictionary was supplied.
293 MlirAttribute attributeDict
=
294 attributes
? attributes
->get() : mlirAttributeGetNull();
// appendResultsCallback fills `inferredTypes` through `data`.
296 MlirLogicalResult result
= mlirInferTypeOpInterfaceInferReturnTypes(
297 opNameRef
, pyContext
.get(), location
.resolve(), mlirOperands
.size(),
298 mlirOperands
.data(), attributeDict
, properties
, mlirRegions
.size(),
299 mlirRegions
.data(), &appendResultsCallback
, &data
);
301 if (mlirLogicalResultIsFailure(result
)) {
302 throw py::value_error("Failed to infer result types");
305 return inferredTypes
;
308 static void bindDerived(ClassTy
&cls
) {
// pybind11 binds py::arg names to parameters by position, so the names below
// follow inferReturnTypes' parameter order (operands, attributes, properties,
// regions, context, loc).
309 cls
.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes
,
310 py::arg("operands") = py::none(),
311 py::arg("attributes") = py::none(),
312 py::arg("properties") = py::none(), py::arg("regions") = py::none(),
313 py::arg("context") = py::none(), py::arg("loc") = py::none(),
314 inferReturnTypesDoc
);
318 /// Wrapper around shaped type components: an element type and, for ranked components, a shape (Python list) and optionally an attribute.
319 class PyShapedTypeComponents
{
// Unranked components carrying only an element type.
// NOTE(review): `ranked` and `attribute` presumably rely on default member
// initializers here - confirm.
321 PyShapedTypeComponents(MlirType elementType
) : elementType(elementType
) {}
// Ranked components with a shape and an element type.
322 PyShapedTypeComponents(py::list shape
, MlirType elementType
)
323 : shape(std::move(shape
)), elementType(elementType
), ranked(true) {}
// Ranked components with a shape, an element type and an attribute.
324 PyShapedTypeComponents(py::list shape
, MlirType elementType
,
325 MlirAttribute attribute
)
326 : shape(std::move(shape
)), elementType(elementType
), attribute(attribute
),
328 PyShapedTypeComponents(PyShapedTypeComponents
&) = delete;
// NOTE(review): this move constructor copies `shape` (a py::list handle copy,
// i.e. a reference-count increment) rather than moving it. The observable
// result is the same.
329 PyShapedTypeComponents(PyShapedTypeComponents
&&other
) noexcept
330 : shape(other
.shape
), elementType(other
.elementType
),
331 attribute(other
.attribute
), ranked(other
.ranked
) {}
/// Registers the ShapedTypeComponents Python class, with its element_type,
/// is_ranked-style and rank/shape read-only properties and its factory
/// functions, on module `m`.
333 static void bind(py::module
&m
) {
334 py::class_
<PyShapedTypeComponents
>(m
, "ShapedTypeComponents",
336 .def_property_readonly(
338 [](PyShapedTypeComponents
&self
) { return self
.elementType
; },
339 "Returns the element type of the shaped type components.")
342 [](PyType
&elementType
) {
343 return PyShapedTypeComponents(elementType
);
345 py::arg("element_type"),
346 "Create an shaped type components object with only the element "
350 [](py::list shape
, PyType
&elementType
) {
351 return PyShapedTypeComponents(std::move(shape
), elementType
);
353 py::arg("shape"), py::arg("element_type"),
354 "Create a ranked shaped type components object.")
357 [](py::list shape
, PyType
&elementType
, PyAttribute
&attribute
) {
358 return PyShapedTypeComponents(std::move(shape
), elementType
,
361 py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
362 "Create a ranked shaped type components object with attribute.")
363 .def_property_readonly(
365 [](PyShapedTypeComponents
&self
) -> bool { return self
.ranked
; },
366 "Returns whether the given shaped type component is ranked.")
367 .def_property_readonly(
369 [](PyShapedTypeComponents
&self
) -> py::object
{
373 return py::int_(self
.shape
.size());
375 "Returns the rank of the given ranked shaped type components. If "
376 "the shaped type components does not have a rank, None is "
378 .def_property_readonly(
380 [](PyShapedTypeComponents
&self
) -> py::object
{
384 return py::list(self
.shape
);
386 "Returns the shape of the ranked shaped type components as a list "
387 "of integers. Returns none if the shaped type component does not "
// Capsule conversion helpers: declared here, defined outside this class body.
391 pybind11::object
getCapsule();
392 static PyShapedTypeComponents
createFromCapsule(pybind11::object capsule
);
// Element type; every constructor sets it explicitly.
396 MlirType elementType
;
// Attribute; only the shape/element-type/attribute constructor sets it
// explicitly.
397 MlirAttribute attribute
;
401 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
/// static methods.
403 class PyInferShapedTypeOpInterface
404 : public PyConcreteOpInterface
<PyInferShapedTypeOpInterface
> {
406 using PyConcreteOpInterface
<
407 PyInferShapedTypeOpInterface
>::PyConcreteOpInterface
;
409 constexpr static const char *pyClassName
= "InferShapedTypeOpInterface";
410 constexpr static GetTypeIDFunctionTy getInterfaceID
=
411 &mlirInferShapedTypeOpInterfaceTypeID
;
413 /// C-style user-data structure for type appending callback.
414 struct AppendResultsCallbackData
{
// Output vector the callback appends to (owned by the caller).
415 std::vector
<PyShapedTypeComponents
> &inferredShapedTypeComponents
;
418 /// Appends the shaped type components provided as unpacked shape, element
419 /// type, attribute to the user-data.
/// `userData` must point to an AppendResultsCallbackData.
420 static void appendResultsCallback(bool hasRank
, intptr_t rank
,
421 const int64_t *shape
, MlirType elementType
,
422 MlirAttribute attribute
, void *userData
) {
423 auto *data
= static_cast<AppendResultsCallbackData
*>(userData
);
// NOTE(review): presumably the unranked (!hasRank) case - only the element
// type is recorded. Confirm against the branch condition.
425 data
->inferredShapedTypeComponents
.emplace_back(elementType
);
// Ranked case: copy the `rank` dimension sizes from `shape` into a Python
// list, then record that list together with the element type.
428 for (intptr_t i
= 0; i
< rank
; ++i
) {
429 shapeList
.append(shape
[i
]);
431 data
->inferredShapedTypeComponents
.emplace_back(shapeList
, elementType
,
436 /// Given the arguments required to build an operation, attempts to infer the
437 /// shaped type components. Throws value_error on failure.
/// Absent `attributes` are passed to the C API as a null attribute;
/// `properties` is forwarded unchanged as an opaque pointer.
438 std::vector
<PyShapedTypeComponents
> inferReturnTypeComponents(
439 std::optional
<py::list
> operandList
,
440 std::optional
<PyAttribute
> attributes
, void *properties
,
441 std::optional
<std::vector
<PyRegion
>> regions
,
442 DefaultingPyMlirContext context
, DefaultingPyLocation location
) {
// Convert the Python operands/regions into C API arrays for the call below.
443 llvm::SmallVector
<MlirValue
> mlirOperands
=
444 wrapOperands(std::move(operandList
));
445 llvm::SmallVector
<MlirRegion
> mlirRegions
= wrapRegions(std::move(regions
));
447 std::vector
<PyShapedTypeComponents
> inferredShapedTypeComponents
;
448 PyMlirContext
&pyContext
= context
.resolve();
449 AppendResultsCallbackData data
{inferredShapedTypeComponents
};
450 MlirStringRef opNameRef
=
451 mlirStringRefCreate(getOpName().data(), getOpName().length());
// Pass a null attribute when no attribute dictionary was supplied.
452 MlirAttribute attributeDict
=
453 attributes
? attributes
->get() : mlirAttributeGetNull();
// appendResultsCallback fills `inferredShapedTypeComponents` through `data`.
455 MlirLogicalResult result
= mlirInferShapedTypeOpInterfaceInferReturnTypes(
456 opNameRef
, pyContext
.get(), location
.resolve(), mlirOperands
.size(),
457 mlirOperands
.data(), attributeDict
, properties
, mlirRegions
.size(),
458 mlirRegions
.data(), &appendResultsCallback
, &data
);
460 if (mlirLogicalResultIsFailure(result
)) {
461 throw py::value_error("Failed to infer result shape type components");
464 return inferredShapedTypeComponents
;
467 static void bindDerived(ClassTy
&cls
) {
468 cls
.def("inferReturnTypeComponents",
469 &PyInferShapedTypeOpInterface::inferReturnTypeComponents
,
470 py::arg("operands") = py::none(),
471 py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
472 py::arg("properties") = py::none(), py::arg("context") = py::none(),
473 py::arg("loc") = py::none(), inferReturnTypeComponentsDoc
);
/// Registers the Python classes defined in this file on module `m`:
/// InferTypeOpInterface, ShapedTypeComponents and InferShapedTypeOpInterface.
477 void populateIRInterfaces(py::module
&m
) {
478 PyInferTypeOpInterface::bind(m
);
// NOTE(review): ShapedTypeComponents is bound before InferShapedTypeOpInterface,
// presumably so its Python name is known when the latter's method signature
// is generated. Confirm this ordering is intended.
479 PyShapedTypeComponents::bind(m
);
480 PyInferShapedTypeOpInterface::bind(m
);
483 } // namespace python