1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/IR.h"
18 #include "mlir-c/Interfaces.h"
19 #include "mlir-c/Support.h"
20 #include "mlir/Bindings/Python/Nanobind.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
24 namespace nb
= nanobind
;
29 constexpr static const char *constructorDoc
=
30 R
"(Creates an interface from a given operation/opview object or from a
31 subclass of OpView. Raises ValueError if the operation does not implement the
34 constexpr static const char *operationDoc
=
35 R
"(Returns an Operation for which the interface was constructed.)";
37 constexpr static const char *opviewDoc
=
38 R
"(Returns an OpView subclass _instance_ for which the interface was
41 constexpr static const char *inferReturnTypesDoc
=
42 R
"(Given the arguments required to build an operation, attempts to infer
43 its return types. Raises ValueError on failure.)";
45 constexpr static const char *inferReturnTypeComponentsDoc
=
46 R
"(Given the arguments required to build an operation, attempts to infer
47 its return shaped type components. Raises ValueError on failure.)";
51 /// Takes in an optional list of operands and converts them into a SmallVector
52 /// of MlirValues. Returns an empty SmallVector if the list is empty.
53 llvm::SmallVector
<MlirValue
> wrapOperands(std::optional
<nb::list
> operandList
) {
54 llvm::SmallVector
<MlirValue
> mlirOperands
;
56 if (!operandList
|| operandList
->size() == 0) {
60 // Note: as the list may contain other lists this may not be final size.
61 mlirOperands
.reserve(operandList
->size());
62 for (const auto &&it
: llvm::enumerate(*operandList
)) {
63 if (it
.value().is_none())
68 val
= nb::cast
<PyValue
*>(it
.value());
70 throw nb::cast_error();
71 mlirOperands
.push_back(val
->get());
73 } catch (nb::cast_error
&err
) {
74 // Intentionally unhandled to try sequence below first.
79 auto vals
= nb::cast
<nb::sequence
>(it
.value());
80 for (nb::handle v
: vals
) {
82 val
= nb::cast
<PyValue
*>(v
);
84 throw nb::cast_error();
85 mlirOperands
.push_back(val
->get());
86 } catch (nb::cast_error
&err
) {
87 throw nb::value_error(
88 (llvm::Twine("Operand ") + llvm::Twine(it
.index()) +
89 " must be a Value or Sequence of Values (" + err
.what() + ")")
95 } catch (nb::cast_error
&err
) {
96 throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it
.index()) +
97 " must be a Value or Sequence of Values (" +
103 throw nb::cast_error();
109 /// Takes in an optional vector of PyRegions and returns a SmallVector of
110 /// MlirRegion. Returns an empty SmallVector if the list is empty.
111 llvm::SmallVector
<MlirRegion
>
112 wrapRegions(std::optional
<std::vector
<PyRegion
>> regions
) {
113 llvm::SmallVector
<MlirRegion
> mlirRegions
;
116 mlirRegions
.reserve(regions
->size());
117 for (PyRegion
®ion
: *regions
) {
118 mlirRegions
.push_back(region
);
127 /// CRTP base class for Python classes representing MLIR Op interfaces.
128 /// Interface hierarchies are flat so no base class is expected here. The
129 /// derived class is expected to define the following static fields:
130 /// - `const char *pyClassName` - the name of the Python class to create;
131 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
132 /// of the interface.
133 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
134 /// interface-specific methods.
136 /// An interface class may be constructed from either an Operation/OpView object
137 /// or from a subclass of OpView. In the latter case, only the static interface
138 /// methods are available, similarly to calling ConcreteOp::staticMethod on the
139 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
140 /// method to check whether the interface object was constructed from a class or
141 /// an operation/opview instance. The `getOpName` always succeeds and returns a
142 /// canonical name of the operation suitable for lookups.
143 template <typename ConcreteIface
>
144 class PyConcreteOpInterface
{
146 using ClassTy
= nb::class_
<ConcreteIface
>;
147 using GetTypeIDFunctionTy
= MlirTypeID (*)();
150 /// Constructs an interface instance from an object that is either an
151 /// operation or a subclass of OpView. In the latter case, only the static
152 /// methods of the interface are accessible to the caller.
153 PyConcreteOpInterface(nb::object object
, DefaultingPyMlirContext context
)
154 : obj(std::move(object
)) {
156 operation
= &nb::cast
<PyOperation
&>(obj
);
157 } catch (nb::cast_error
&) {
162 operation
= &nb::cast
<PyOpView
&>(obj
).getOperation();
163 } catch (nb::cast_error
&) {
167 if (operation
!= nullptr) {
168 if (!mlirOperationImplementsInterface(*operation
,
169 ConcreteIface::getInterfaceID())) {
170 std::string msg
= "the operation does not implement ";
171 throw nb::value_error((msg
+ ConcreteIface::pyClassName
).c_str());
174 MlirIdentifier identifier
= mlirOperationGetName(*operation
);
175 MlirStringRef stringRef
= mlirIdentifierStr(identifier
);
176 opName
= std::string(stringRef
.data
, stringRef
.length
);
179 opName
= nb::cast
<std::string
>(obj
.attr("OPERATION_NAME"));
180 } catch (nb::cast_error
&) {
181 throw nb::type_error(
182 "Op interface does not refer to an operation or OpView class");
185 if (!mlirOperationImplementsInterfaceStatic(
186 mlirStringRefCreate(opName
.data(), opName
.length()),
187 context
.resolve().get(), ConcreteIface::getInterfaceID())) {
188 std::string msg
= "the operation does not implement ";
189 throw nb::value_error((msg
+ ConcreteIface::pyClassName
).c_str());
194 /// Creates the Python bindings for this class in the given module.
195 static void bind(nb::module_
&m
) {
196 nb::class_
<ConcreteIface
> cls(m
, ConcreteIface::pyClassName
);
197 cls
.def(nb::init
<nb::object
, DefaultingPyMlirContext
>(), nb::arg("object"),
198 nb::arg("context").none() = nb::none(), constructorDoc
)
199 .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject
,
201 .def_prop_ro("opview", &PyConcreteOpInterface::getOpView
, opviewDoc
);
202 ConcreteIface::bindDerived(cls
);
205 /// Hook for derived classes to add class-specific bindings.
/// The default implementation is empty. `bind` invokes it as
/// `ConcreteIface::bindDerived(cls)`, so a derived class that declares its own
/// static `bindDerived(ClassTy &)` is picked up instead of this one.
206 static void bindDerived(ClassTy
&cls
) {}
208 /// Returns `true` if this object was constructed from a subclass of OpView
209 /// rather than from an operation instance.
210 bool isStatic() { return operation
== nullptr; }
212 /// Returns the operation instance from which this object was constructed.
213 /// Throws a type error if this object was constructed from a subclass of
215 nb::object
getOperationObject() {
216 if (operation
== nullptr) {
217 throw nb::type_error("Cannot get an operation from a static interface");
220 return operation
->getRef().releaseObject();
223 /// Returns the opview of the operation instance from which this object was
224 /// constructed. Throws a type error if this object was constructed from a
225 /// subclass of OpView.
226 nb::object
getOpView() {
227 if (operation
== nullptr) {
228 throw nb::type_error("Cannot get an opview from a static interface");
231 return operation
->createOpView();
234 /// Returns the canonical name of the operation this interface is constructed
/// from, as a const reference to `opName`.
236 const std::string
&getOpName() { return opName
; }
239 PyOperation
*operation
= nullptr;
244 /// Python wrapper for InferTypeOpInterface. This interface has only static
246 class PyInferTypeOpInterface
247 : public PyConcreteOpInterface
<PyInferTypeOpInterface
> {
249 using PyConcreteOpInterface
<PyInferTypeOpInterface
>::PyConcreteOpInterface
;
251 constexpr static const char *pyClassName
= "InferTypeOpInterface";
252 constexpr static GetTypeIDFunctionTy getInterfaceID
=
253 &mlirInferTypeOpInterfaceTypeID
;
255 /// C-style user-data structure for type appending callback.
256 struct AppendResultsCallbackData
{
257 std::vector
<PyType
> &inferredTypes
;
258 PyMlirContext
&pyMlirContext
;
261 /// Appends the types provided as the two first arguments to the user-data
262 /// structure (expects AppendResultsCallbackData).
263 static void appendResultsCallback(intptr_t nTypes
, MlirType
*types
,
265 auto *data
= static_cast<AppendResultsCallbackData
*>(userData
);
266 data
->inferredTypes
.reserve(data
->inferredTypes
.size() + nTypes
);
267 for (intptr_t i
= 0; i
< nTypes
; ++i
) {
268 data
->inferredTypes
.emplace_back(data
->pyMlirContext
.getRef(), types
[i
]);
272 /// Given the arguments required to build an operation, attempts to infer its
273 /// return types. Throws value_error on failure.
275 inferReturnTypes(std::optional
<nb::list
> operandList
,
276 std::optional
<PyAttribute
> attributes
, void *properties
,
277 std::optional
<std::vector
<PyRegion
>> regions
,
278 DefaultingPyMlirContext context
,
279 DefaultingPyLocation location
) {
280 llvm::SmallVector
<MlirValue
> mlirOperands
=
281 wrapOperands(std::move(operandList
));
282 llvm::SmallVector
<MlirRegion
> mlirRegions
= wrapRegions(std::move(regions
));
284 std::vector
<PyType
> inferredTypes
;
285 PyMlirContext
&pyContext
= context
.resolve();
286 AppendResultsCallbackData data
{inferredTypes
, pyContext
};
287 MlirStringRef opNameRef
=
288 mlirStringRefCreate(getOpName().data(), getOpName().length());
289 MlirAttribute attributeDict
=
290 attributes
? attributes
->get() : mlirAttributeGetNull();
292 MlirLogicalResult result
= mlirInferTypeOpInterfaceInferReturnTypes(
293 opNameRef
, pyContext
.get(), location
.resolve(), mlirOperands
.size(),
294 mlirOperands
.data(), attributeDict
, properties
, mlirRegions
.size(),
295 mlirRegions
.data(), &appendResultsCallback
, &data
);
297 if (mlirLogicalResultIsFailure(result
)) {
298 throw nb::value_error("Failed to infer result types");
301 return inferredTypes
;
/// Binds the `inferReturnTypes` method on the Python class. Each argument
/// accepts None and defaults to None. The keyword names, in order, are:
/// operands, attributes, properties, regions, context, loc.
304 static void bindDerived(ClassTy
&cls
) {
305 cls
.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes
,
306 nb::arg("operands").none() = nb::none(),
307 nb::arg("attributes").none() = nb::none(),
308 nb::arg("properties").none() = nb::none(),
309 nb::arg("regions").none() = nb::none(),
310 nb::arg("context").none() = nb::none(),
311 nb::arg("loc").none() = nb::none(), inferReturnTypesDoc
);
315 /// Wrapper around shaped type components.
316 class PyShapedTypeComponents
{
318 PyShapedTypeComponents(MlirType elementType
) : elementType(elementType
) {}
319 PyShapedTypeComponents(nb::list shape
, MlirType elementType
)
320 : shape(std::move(shape
)), elementType(elementType
), ranked(true) {}
321 PyShapedTypeComponents(nb::list shape
, MlirType elementType
,
322 MlirAttribute attribute
)
323 : shape(std::move(shape
)), elementType(elementType
), attribute(attribute
),
325 PyShapedTypeComponents(PyShapedTypeComponents
&) = delete;
326 PyShapedTypeComponents(PyShapedTypeComponents
&&other
) noexcept
327 : shape(other
.shape
), elementType(other
.elementType
),
328 attribute(other
.attribute
), ranked(other
.ranked
) {}
330 static void bind(nb::module_
&m
) {
331 nb::class_
<PyShapedTypeComponents
>(m
, "ShapedTypeComponents")
334 [](PyShapedTypeComponents
&self
) { return self
.elementType
; },
335 "Returns the element type of the shaped type components.")
338 [](PyType
&elementType
) {
339 return PyShapedTypeComponents(elementType
);
341 nb::arg("element_type"),
342 "Create an shaped type components object with only the element "
346 [](nb::list shape
, PyType
&elementType
) {
347 return PyShapedTypeComponents(std::move(shape
), elementType
);
349 nb::arg("shape"), nb::arg("element_type"),
350 "Create a ranked shaped type components object.")
353 [](nb::list shape
, PyType
&elementType
, PyAttribute
&attribute
) {
354 return PyShapedTypeComponents(std::move(shape
), elementType
,
357 nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
358 "Create a ranked shaped type components object with attribute.")
361 [](PyShapedTypeComponents
&self
) -> bool { return self
.ranked
; },
362 "Returns whether the given shaped type component is ranked.")
365 [](PyShapedTypeComponents
&self
) -> nb::object
{
369 return nb::int_(self
.shape
.size());
371 "Returns the rank of the given ranked shaped type components. If "
372 "the shaped type components does not have a rank, None is "
376 [](PyShapedTypeComponents
&self
) -> nb::object
{
380 return nb::list(self
.shape
);
382 "Returns the shape of the ranked shaped type components as a list "
383 "of integers. Returns none if the shaped type component does not "
387 nb::object
getCapsule();
388 static PyShapedTypeComponents
createFromCapsule(nb::object capsule
);
392 MlirType elementType
;
393 MlirAttribute attribute
;
397 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
399 class PyInferShapedTypeOpInterface
400 : public PyConcreteOpInterface
<PyInferShapedTypeOpInterface
> {
402 using PyConcreteOpInterface
<
403 PyInferShapedTypeOpInterface
>::PyConcreteOpInterface
;
405 constexpr static const char *pyClassName
= "InferShapedTypeOpInterface";
406 constexpr static GetTypeIDFunctionTy getInterfaceID
=
407 &mlirInferShapedTypeOpInterfaceTypeID
;
409 /// C-style user-data structure for type appending callback.
410 struct AppendResultsCallbackData
{
411 std::vector
<PyShapedTypeComponents
> &inferredShapedTypeComponents
;
414 /// Appends the shaped type components provided as unpacked shape, element
415 /// type, attribute to the user-data.
416 static void appendResultsCallback(bool hasRank
, intptr_t rank
,
417 const int64_t *shape
, MlirType elementType
,
418 MlirAttribute attribute
, void *userData
) {
419 auto *data
= static_cast<AppendResultsCallbackData
*>(userData
);
421 data
->inferredShapedTypeComponents
.emplace_back(elementType
);
424 for (intptr_t i
= 0; i
< rank
; ++i
) {
425 shapeList
.append(shape
[i
]);
427 data
->inferredShapedTypeComponents
.emplace_back(shapeList
, elementType
,
432 /// Given the arguments required to build an operation, attempts to infer the
433 /// shaped type components. Throws value_error on failure.
434 std::vector
<PyShapedTypeComponents
> inferReturnTypeComponents(
435 std::optional
<nb::list
> operandList
,
436 std::optional
<PyAttribute
> attributes
, void *properties
,
437 std::optional
<std::vector
<PyRegion
>> regions
,
438 DefaultingPyMlirContext context
, DefaultingPyLocation location
) {
439 llvm::SmallVector
<MlirValue
> mlirOperands
=
440 wrapOperands(std::move(operandList
));
441 llvm::SmallVector
<MlirRegion
> mlirRegions
= wrapRegions(std::move(regions
));
443 std::vector
<PyShapedTypeComponents
> inferredShapedTypeComponents
;
444 PyMlirContext
&pyContext
= context
.resolve();
445 AppendResultsCallbackData data
{inferredShapedTypeComponents
};
446 MlirStringRef opNameRef
=
447 mlirStringRefCreate(getOpName().data(), getOpName().length());
448 MlirAttribute attributeDict
=
449 attributes
? attributes
->get() : mlirAttributeGetNull();
451 MlirLogicalResult result
= mlirInferShapedTypeOpInterfaceInferReturnTypes(
452 opNameRef
, pyContext
.get(), location
.resolve(), mlirOperands
.size(),
453 mlirOperands
.data(), attributeDict
, properties
, mlirRegions
.size(),
454 mlirRegions
.data(), &appendResultsCallback
, &data
);
456 if (mlirLogicalResultIsFailure(result
)) {
457 throw nb::value_error("Failed to infer result shape type components");
460 return inferredShapedTypeComponents
;
463 static void bindDerived(ClassTy
&cls
) {
464 cls
.def("inferReturnTypeComponents",
465 &PyInferShapedTypeOpInterface::inferReturnTypeComponents
,
466 nb::arg("operands").none() = nb::none(),
467 nb::arg("attributes").none() = nb::none(),
468 nb::arg("regions").none() = nb::none(),
469 nb::arg("properties").none() = nb::none(),
470 nb::arg("context").none() = nb::none(),
471 nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc
);
/// Registers the interface bindings on module `m` by calling the static
/// `bind` of PyInferTypeOpInterface, PyShapedTypeComponents and
/// PyInferShapedTypeOpInterface, in that order.
475 void populateIRInterfaces(nb::module_
&m
) {
476 PyInferTypeOpInterface::bind(m
);
477 PyShapedTypeComponents::bind(m
);
478 PyInferShapedTypeOpInterface::bind(m
);
481 } // namespace python