[MLIR][LLVM] Fold extract of extract (#125980)
[llvm-project.git] / mlir / lib / Bindings / Python / IRInterfaces.cpp
blob9e1fedaab52352ae5e1f59c12974ebceb25f4151
1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <cstdint>
10 #include <optional>
11 #include <string>
12 #include <utility>
13 #include <vector>
15 #include "IRModule.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/IR.h"
18 #include "mlir-c/Interfaces.h"
19 #include "mlir-c/Support.h"
20 #include "mlir/Bindings/Python/Nanobind.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
24 namespace nb = nanobind;
26 namespace mlir {
27 namespace python {
// Docstrings passed to the nanobind `def` / `def_prop_ro` calls below. The raw
// string literals are used verbatim, so their line breaks are significant.

/// Docstring of the interface constructor (see PyConcreteOpInterface::bind).
constexpr static const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Docstring of the `operation` property (see PyConcreteOpInterface::bind).
constexpr static const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Docstring of the `opview` property (see PyConcreteOpInterface::bind).
constexpr static const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";

/// Docstring of InferTypeOpInterface.inferReturnTypes.
constexpr static const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";

/// Docstring of InferShapedTypeOpInterface.inferReturnTypeComponents.
constexpr static const char *inferReturnTypeComponentsDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return shaped type components. Raises ValueError on failure.)";
49 namespace {
51 /// Takes in an optional ist of operands and converts them into a SmallVector
52 /// of MlirVlaues. Returns an empty SmallVector if the list is empty.
53 llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
54 llvm::SmallVector<MlirValue> mlirOperands;
56 if (!operandList || operandList->size() == 0) {
57 return mlirOperands;
60 // Note: as the list may contain other lists this may not be final size.
61 mlirOperands.reserve(operandList->size());
62 for (const auto &&it : llvm::enumerate(*operandList)) {
63 if (it.value().is_none())
64 continue;
66 PyValue *val;
67 try {
68 val = nb::cast<PyValue *>(it.value());
69 if (!val)
70 throw nb::cast_error();
71 mlirOperands.push_back(val->get());
72 continue;
73 } catch (nb::cast_error &err) {
74 // Intentionally unhandled to try sequence below first.
75 (void)err;
78 try {
79 auto vals = nb::cast<nb::sequence>(it.value());
80 for (nb::handle v : vals) {
81 try {
82 val = nb::cast<PyValue *>(v);
83 if (!val)
84 throw nb::cast_error();
85 mlirOperands.push_back(val->get());
86 } catch (nb::cast_error &err) {
87 throw nb::value_error(
88 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
89 " must be a Value or Sequence of Values (" + err.what() + ")")
90 .str()
91 .c_str());
94 continue;
95 } catch (nb::cast_error &err) {
96 throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
97 " must be a Value or Sequence of Values (" +
98 err.what() + ")")
99 .str()
100 .c_str());
103 throw nb::cast_error();
106 return mlirOperands;
109 /// Takes in an optional vector of PyRegions and returns a SmallVector of
110 /// MlirRegion. Returns an empty SmallVector if the list is empty.
111 llvm::SmallVector<MlirRegion>
112 wrapRegions(std::optional<std::vector<PyRegion>> regions) {
113 llvm::SmallVector<MlirRegion> mlirRegions;
115 if (regions) {
116 mlirRegions.reserve(regions->size());
117 for (PyRegion &region : *regions) {
118 mlirRegions.push_back(region);
122 return mlirRegions;
125 } // namespace
127 /// CRTP base class for Python classes representing MLIR Op interfaces.
128 /// Interface hierarchies are flat so no base class is expected here. The
129 /// derived class is expected to define the following static fields:
130 /// - `const char *pyClassName` - the name of the Python class to create;
131 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
132 /// of the interface.
133 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
134 /// interface-specific methods.
136 /// An interface class may be constructed from either an Operation/OpView object
137 /// or from a subclass of OpView. In the latter case, only the static interface
138 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
139 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
140 /// method to check whether the interface object was constructed from a class or
141 /// an operation/opview instance. The `getOpName` always succeeds and returns a
142 /// canonical name of the operation suitable for lookups.
143 template <typename ConcreteIface>
144 class PyConcreteOpInterface {
145 protected:
146 using ClassTy = nb::class_<ConcreteIface>;
147 using GetTypeIDFunctionTy = MlirTypeID (*)();
149 public:
150 /// Constructs an interface instance from an object that is either an
151 /// operation or a subclass of OpView. In the latter case, only the static
152 /// methods of the interface are accessible to the caller.
153 PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
154 : obj(std::move(object)) {
155 try {
156 operation = &nb::cast<PyOperation &>(obj);
157 } catch (nb::cast_error &) {
158 // Do nothing.
161 try {
162 operation = &nb::cast<PyOpView &>(obj).getOperation();
163 } catch (nb::cast_error &) {
164 // Do nothing.
167 if (operation != nullptr) {
168 if (!mlirOperationImplementsInterface(*operation,
169 ConcreteIface::getInterfaceID())) {
170 std::string msg = "the operation does not implement ";
171 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
174 MlirIdentifier identifier = mlirOperationGetName(*operation);
175 MlirStringRef stringRef = mlirIdentifierStr(identifier);
176 opName = std::string(stringRef.data, stringRef.length);
177 } else {
178 try {
179 opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
180 } catch (nb::cast_error &) {
181 throw nb::type_error(
182 "Op interface does not refer to an operation or OpView class");
185 if (!mlirOperationImplementsInterfaceStatic(
186 mlirStringRefCreate(opName.data(), opName.length()),
187 context.resolve().get(), ConcreteIface::getInterfaceID())) {
188 std::string msg = "the operation does not implement ";
189 throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
194 /// Creates the Python bindings for this class in the given module.
195 static void bind(nb::module_ &m) {
196 nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
197 cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
198 nb::arg("context").none() = nb::none(), constructorDoc)
199 .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
200 operationDoc)
201 .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
202 ConcreteIface::bindDerived(cls);
205 /// Hook for derived classes to add class-specific bindings.
206 static void bindDerived(ClassTy &cls) {}
208 /// Returns `true` if this object was constructed from a subclass of OpView
209 /// rather than from an operation instance.
210 bool isStatic() { return operation == nullptr; }
212 /// Returns the operation instance from which this object was constructed.
213 /// Throws a type error if this object was constructed from a subclass of
214 /// OpView.
215 nb::object getOperationObject() {
216 if (operation == nullptr) {
217 throw nb::type_error("Cannot get an operation from a static interface");
220 return operation->getRef().releaseObject();
223 /// Returns the opview of the operation instance from which this object was
224 /// constructed. Throws a type error if this object was constructed form a
225 /// subclass of OpView.
226 nb::object getOpView() {
227 if (operation == nullptr) {
228 throw nb::type_error("Cannot get an opview from a static interface");
231 return operation->createOpView();
234 /// Returns the canonical name of the operation this interface is constructed
235 /// from.
236 const std::string &getOpName() { return opName; }
238 private:
239 PyOperation *operation = nullptr;
240 std::string opName;
241 nb::object obj;
244 /// Python wrapper for InferTypeOpInterface. This interface has only static
245 /// methods.
246 class PyInferTypeOpInterface
247 : public PyConcreteOpInterface<PyInferTypeOpInterface> {
248 public:
249 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
251 constexpr static const char *pyClassName = "InferTypeOpInterface";
252 constexpr static GetTypeIDFunctionTy getInterfaceID =
253 &mlirInferTypeOpInterfaceTypeID;
255 /// C-style user-data structure for type appending callback.
256 struct AppendResultsCallbackData {
257 std::vector<PyType> &inferredTypes;
258 PyMlirContext &pyMlirContext;
261 /// Appends the types provided as the two first arguments to the user-data
262 /// structure (expects AppendResultsCallbackData).
263 static void appendResultsCallback(intptr_t nTypes, MlirType *types,
264 void *userData) {
265 auto *data = static_cast<AppendResultsCallbackData *>(userData);
266 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
267 for (intptr_t i = 0; i < nTypes; ++i) {
268 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
272 /// Given the arguments required to build an operation, attempts to infer its
273 /// return types. Throws value_error on failure.
274 std::vector<PyType>
275 inferReturnTypes(std::optional<nb::list> operandList,
276 std::optional<PyAttribute> attributes, void *properties,
277 std::optional<std::vector<PyRegion>> regions,
278 DefaultingPyMlirContext context,
279 DefaultingPyLocation location) {
280 llvm::SmallVector<MlirValue> mlirOperands =
281 wrapOperands(std::move(operandList));
282 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
284 std::vector<PyType> inferredTypes;
285 PyMlirContext &pyContext = context.resolve();
286 AppendResultsCallbackData data{inferredTypes, pyContext};
287 MlirStringRef opNameRef =
288 mlirStringRefCreate(getOpName().data(), getOpName().length());
289 MlirAttribute attributeDict =
290 attributes ? attributes->get() : mlirAttributeGetNull();
292 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
293 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
294 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
295 mlirRegions.data(), &appendResultsCallback, &data);
297 if (mlirLogicalResultIsFailure(result)) {
298 throw nb::value_error("Failed to infer result types");
301 return inferredTypes;
304 static void bindDerived(ClassTy &cls) {
305 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
306 nb::arg("operands").none() = nb::none(),
307 nb::arg("attributes").none() = nb::none(),
308 nb::arg("properties").none() = nb::none(),
309 nb::arg("regions").none() = nb::none(),
310 nb::arg("context").none() = nb::none(),
311 nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
315 /// Wrapper around an shaped type components.
316 class PyShapedTypeComponents {
317 public:
318 PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
319 PyShapedTypeComponents(nb::list shape, MlirType elementType)
320 : shape(std::move(shape)), elementType(elementType), ranked(true) {}
321 PyShapedTypeComponents(nb::list shape, MlirType elementType,
322 MlirAttribute attribute)
323 : shape(std::move(shape)), elementType(elementType), attribute(attribute),
324 ranked(true) {}
325 PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
326 PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
327 : shape(other.shape), elementType(other.elementType),
328 attribute(other.attribute), ranked(other.ranked) {}
330 static void bind(nb::module_ &m) {
331 nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
332 .def_prop_ro(
333 "element_type",
334 [](PyShapedTypeComponents &self) { return self.elementType; },
335 "Returns the element type of the shaped type components.")
336 .def_static(
337 "get",
338 [](PyType &elementType) {
339 return PyShapedTypeComponents(elementType);
341 nb::arg("element_type"),
342 "Create an shaped type components object with only the element "
343 "type.")
344 .def_static(
345 "get",
346 [](nb::list shape, PyType &elementType) {
347 return PyShapedTypeComponents(std::move(shape), elementType);
349 nb::arg("shape"), nb::arg("element_type"),
350 "Create a ranked shaped type components object.")
351 .def_static(
352 "get",
353 [](nb::list shape, PyType &elementType, PyAttribute &attribute) {
354 return PyShapedTypeComponents(std::move(shape), elementType,
355 attribute);
357 nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
358 "Create a ranked shaped type components object with attribute.")
359 .def_prop_ro(
360 "has_rank",
361 [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
362 "Returns whether the given shaped type component is ranked.")
363 .def_prop_ro(
364 "rank",
365 [](PyShapedTypeComponents &self) -> nb::object {
366 if (!self.ranked) {
367 return nb::none();
369 return nb::int_(self.shape.size());
371 "Returns the rank of the given ranked shaped type components. If "
372 "the shaped type components does not have a rank, None is "
373 "returned.")
374 .def_prop_ro(
375 "shape",
376 [](PyShapedTypeComponents &self) -> nb::object {
377 if (!self.ranked) {
378 return nb::none();
380 return nb::list(self.shape);
382 "Returns the shape of the ranked shaped type components as a list "
383 "of integers. Returns none if the shaped type component does not "
384 "have a rank.");
387 nb::object getCapsule();
388 static PyShapedTypeComponents createFromCapsule(nb::object capsule);
390 private:
391 nb::list shape;
392 MlirType elementType;
393 MlirAttribute attribute;
394 bool ranked{false};
397 /// Python wrapper for InferShapedTypeOpInterface. This interface has only
398 /// static methods.
399 class PyInferShapedTypeOpInterface
400 : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
401 public:
402 using PyConcreteOpInterface<
403 PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
405 constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
406 constexpr static GetTypeIDFunctionTy getInterfaceID =
407 &mlirInferShapedTypeOpInterfaceTypeID;
409 /// C-style user-data structure for type appending callback.
410 struct AppendResultsCallbackData {
411 std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
414 /// Appends the shaped type components provided as unpacked shape, element
415 /// type, attribute to the user-data.
416 static void appendResultsCallback(bool hasRank, intptr_t rank,
417 const int64_t *shape, MlirType elementType,
418 MlirAttribute attribute, void *userData) {
419 auto *data = static_cast<AppendResultsCallbackData *>(userData);
420 if (!hasRank) {
421 data->inferredShapedTypeComponents.emplace_back(elementType);
422 } else {
423 nb::list shapeList;
424 for (intptr_t i = 0; i < rank; ++i) {
425 shapeList.append(shape[i]);
427 data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
428 attribute);
432 /// Given the arguments required to build an operation, attempts to infer the
433 /// shaped type components. Throws value_error on failure.
434 std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
435 std::optional<nb::list> operandList,
436 std::optional<PyAttribute> attributes, void *properties,
437 std::optional<std::vector<PyRegion>> regions,
438 DefaultingPyMlirContext context, DefaultingPyLocation location) {
439 llvm::SmallVector<MlirValue> mlirOperands =
440 wrapOperands(std::move(operandList));
441 llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
443 std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
444 PyMlirContext &pyContext = context.resolve();
445 AppendResultsCallbackData data{inferredShapedTypeComponents};
446 MlirStringRef opNameRef =
447 mlirStringRefCreate(getOpName().data(), getOpName().length());
448 MlirAttribute attributeDict =
449 attributes ? attributes->get() : mlirAttributeGetNull();
451 MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
452 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
453 mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
454 mlirRegions.data(), &appendResultsCallback, &data);
456 if (mlirLogicalResultIsFailure(result)) {
457 throw nb::value_error("Failed to infer result shape type components");
460 return inferredShapedTypeComponents;
463 static void bindDerived(ClassTy &cls) {
464 cls.def("inferReturnTypeComponents",
465 &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
466 nb::arg("operands").none() = nb::none(),
467 nb::arg("attributes").none() = nb::none(),
468 nb::arg("regions").none() = nb::none(),
469 nb::arg("properties").none() = nb::none(),
470 nb::arg("context").none() = nb::none(),
471 nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
/// Registers the Python classes defined in this file on module `m`:
/// InferTypeOpInterface, ShapedTypeComponents and InferShapedTypeOpInterface.
// NOTE(review): ShapedTypeComponents is presumably bound before
// InferShapedTypeOpInterface because the latter's inferReturnTypeComponents
// returns a list of them. Keep this order unless that is confirmed irrelevant.
void populateIRInterfaces(nb::module_ &m) {
  PyInferTypeOpInterface::bind(m);
  PyShapedTypeComponents::bind(m);
  PyInferShapedTypeOpInterface::bind(m);
}
481 } // namespace python
482 } // namespace mlir