[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Bindings / Python / IRAttributes.cpp
blob417c66b9165e3b4f4e80b1929b6eaa82138ecd1d
1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <optional>
10 #include <string_view>
11 #include <utility>
13 #include "IRModule.h"
15 #include "PybindUtils.h"
16 #include <pybind11/numpy.h>
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/raw_ostream.h"
21 #include "mlir-c/BuiltinAttributes.h"
22 #include "mlir-c/BuiltinTypes.h"
23 #include "mlir/Bindings/Python/PybindAdaptors.h"
25 namespace py = pybind11;
26 using namespace mlir;
27 using namespace mlir::python;
29 using llvm::SmallVector;
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
// Docstring text for building a DenseElementsAttr from a Python buffer or
// array. NOTE(review): presumably passed as the docstring of the `get`
// overload bound to getFromBuffer (that call site is outside this view) —
// confirm. As a docstring, its wording is user-visible from Python.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
// Docstring text for building a DenseElementsAttr from a Python list of
// attributes. The Raises section names the parameter `type`, consistent with
// the Args section above it (it previously referred to a non-existent
// `shaped_type` parameter).
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the attributes does not match the type
    specified by `type`.
)";
// Docstring text for building a DenseResourceElementsAttr from a Python
// buffer or array. NOTE(review): its binding site is outside this view —
// presumably a `get_from_buffer`-style static on DenseResourceElementsAttr;
// confirm. As a docstring, its wording is user-visible from Python.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
124 namespace {
126 static MlirStringRef toMlirStringRef(const std::string &s) {
127 return mlirStringRefCreate(s.data(), s.size());
130 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
131 public:
132 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
133 static constexpr const char *pyClassName = "AffineMapAttr";
134 using PyConcreteAttribute::PyConcreteAttribute;
135 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
136 mlirAffineMapAttrGetTypeID;
138 static void bindDerived(ClassTy &c) {
139 c.def_static(
140 "get",
141 [](PyAffineMap &affineMap) {
142 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
143 return PyAffineMapAttribute(affineMap.getContext(), attr);
145 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
146 c.def_property_readonly("value", mlirAffineMapAttrGetValue,
147 "Returns the value of the AffineMap attribute");
151 class PyIntegerSetAttribute
152 : public PyConcreteAttribute<PyIntegerSetAttribute> {
153 public:
154 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
155 static constexpr const char *pyClassName = "IntegerSetAttr";
156 using PyConcreteAttribute::PyConcreteAttribute;
157 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
158 mlirIntegerSetAttrGetTypeID;
160 static void bindDerived(ClassTy &c) {
161 c.def_static(
162 "get",
163 [](PyIntegerSet &integerSet) {
164 MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
165 return PyIntegerSetAttribute(integerSet.getContext(), attr);
167 py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
171 template <typename T>
172 static T pyTryCast(py::handle object) {
173 try {
174 return object.cast<T>();
175 } catch (py::cast_error &err) {
176 std::string msg =
177 std::string(
178 "Invalid attribute when attempting to create an ArrayAttribute (") +
179 err.what() + ")";
180 throw py::cast_error(msg);
181 } catch (py::reference_cast_error &err) {
182 std::string msg = std::string("Invalid attribute (None?) when attempting "
183 "to create an ArrayAttribute (") +
184 err.what() + ")";
185 throw py::cast_error(msg);
189 /// A python-wrapped dense array attribute with an element type and a derived
190 /// implementation class.
191 template <typename EltTy, typename DerivedT>
192 class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
193 public:
194 using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
196 /// Iterator over the integer elements of a dense array.
197 class PyDenseArrayIterator {
198 public:
199 PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
201 /// Return a copy of the iterator.
202 PyDenseArrayIterator dunderIter() { return *this; }
204 /// Return the next element.
205 EltTy dunderNext() {
206 // Throw if the index has reached the end.
207 if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
208 throw py::stop_iteration();
209 return DerivedT::getElement(attr.get(), nextIndex++);
212 /// Bind the iterator class.
213 static void bind(py::module &m) {
214 py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
215 py::module_local())
216 .def("__iter__", &PyDenseArrayIterator::dunderIter)
217 .def("__next__", &PyDenseArrayIterator::dunderNext);
220 private:
221 /// The referenced dense array attribute.
222 PyAttribute attr;
223 /// The next index to read.
224 int nextIndex = 0;
227 /// Get the element at the given index.
228 EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
230 /// Bind the attribute class.
231 static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
232 // Bind the constructor.
233 c.def_static(
234 "get",
235 [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
236 return getAttribute(values, ctx->getRef());
238 py::arg("values"), py::arg("context") = py::none(),
239 "Gets a uniqued dense array attribute");
240 // Bind the array methods.
241 c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
242 if (i >= mlirDenseArrayGetNumElements(arr))
243 throw py::index_error("DenseArray index out of range");
244 return arr.getItem(i);
246 c.def("__len__", [](const DerivedT &arr) {
247 return mlirDenseArrayGetNumElements(arr);
249 c.def("__iter__",
250 [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
251 c.def("__add__", [](DerivedT &arr, const py::list &extras) {
252 std::vector<EltTy> values;
253 intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
254 values.reserve(numOldElements + py::len(extras));
255 for (intptr_t i = 0; i < numOldElements; ++i)
256 values.push_back(arr.getItem(i));
257 for (py::handle attr : extras)
258 values.push_back(pyTryCast<EltTy>(attr));
259 return getAttribute(values, arr.getContext());
263 private:
264 static DerivedT getAttribute(const std::vector<EltTy> &values,
265 PyMlirContextRef ctx) {
266 if constexpr (std::is_same_v<EltTy, bool>) {
267 std::vector<int> intValues(values.begin(), values.end());
268 MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
269 intValues.data());
270 return DerivedT(ctx, attr);
271 } else {
272 MlirAttribute attr =
273 DerivedT::getAttribute(ctx->get(), values.size(), values.data());
274 return DerivedT(ctx, attr);
279 /// Instantiate the python dense array classes.
280 struct PyDenseBoolArrayAttribute
281 : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
282 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
283 static constexpr auto getAttribute = mlirDenseBoolArrayGet;
284 static constexpr auto getElement = mlirDenseBoolArrayGetElement;
285 static constexpr const char *pyClassName = "DenseBoolArrayAttr";
286 static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
287 using PyDenseArrayAttribute::PyDenseArrayAttribute;
289 struct PyDenseI8ArrayAttribute
290 : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
291 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
292 static constexpr auto getAttribute = mlirDenseI8ArrayGet;
293 static constexpr auto getElement = mlirDenseI8ArrayGetElement;
294 static constexpr const char *pyClassName = "DenseI8ArrayAttr";
295 static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
296 using PyDenseArrayAttribute::PyDenseArrayAttribute;
298 struct PyDenseI16ArrayAttribute
299 : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
300 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
301 static constexpr auto getAttribute = mlirDenseI16ArrayGet;
302 static constexpr auto getElement = mlirDenseI16ArrayGetElement;
303 static constexpr const char *pyClassName = "DenseI16ArrayAttr";
304 static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
305 using PyDenseArrayAttribute::PyDenseArrayAttribute;
307 struct PyDenseI32ArrayAttribute
308 : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
309 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
310 static constexpr auto getAttribute = mlirDenseI32ArrayGet;
311 static constexpr auto getElement = mlirDenseI32ArrayGetElement;
312 static constexpr const char *pyClassName = "DenseI32ArrayAttr";
313 static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
314 using PyDenseArrayAttribute::PyDenseArrayAttribute;
316 struct PyDenseI64ArrayAttribute
317 : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
318 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
319 static constexpr auto getAttribute = mlirDenseI64ArrayGet;
320 static constexpr auto getElement = mlirDenseI64ArrayGetElement;
321 static constexpr const char *pyClassName = "DenseI64ArrayAttr";
322 static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
323 using PyDenseArrayAttribute::PyDenseArrayAttribute;
325 struct PyDenseF32ArrayAttribute
326 : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
327 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
328 static constexpr auto getAttribute = mlirDenseF32ArrayGet;
329 static constexpr auto getElement = mlirDenseF32ArrayGetElement;
330 static constexpr const char *pyClassName = "DenseF32ArrayAttr";
331 static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
332 using PyDenseArrayAttribute::PyDenseArrayAttribute;
334 struct PyDenseF64ArrayAttribute
335 : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
336 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
337 static constexpr auto getAttribute = mlirDenseF64ArrayGet;
338 static constexpr auto getElement = mlirDenseF64ArrayGetElement;
339 static constexpr const char *pyClassName = "DenseF64ArrayAttr";
340 static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
341 using PyDenseArrayAttribute::PyDenseArrayAttribute;
344 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
345 public:
346 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
347 static constexpr const char *pyClassName = "ArrayAttr";
348 using PyConcreteAttribute::PyConcreteAttribute;
349 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
350 mlirArrayAttrGetTypeID;
352 class PyArrayAttributeIterator {
353 public:
354 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
356 PyArrayAttributeIterator &dunderIter() { return *this; }
358 MlirAttribute dunderNext() {
359 // TODO: Throw is an inefficient way to stop iteration.
360 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
361 throw py::stop_iteration();
362 return mlirArrayAttrGetElement(attr.get(), nextIndex++);
365 static void bind(py::module &m) {
366 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
367 py::module_local())
368 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
369 .def("__next__", &PyArrayAttributeIterator::dunderNext);
372 private:
373 PyAttribute attr;
374 int nextIndex = 0;
377 MlirAttribute getItem(intptr_t i) {
378 return mlirArrayAttrGetElement(*this, i);
381 static void bindDerived(ClassTy &c) {
382 c.def_static(
383 "get",
384 [](py::list attributes, DefaultingPyMlirContext context) {
385 SmallVector<MlirAttribute> mlirAttributes;
386 mlirAttributes.reserve(py::len(attributes));
387 for (auto attribute : attributes) {
388 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
390 MlirAttribute attr = mlirArrayAttrGet(
391 context->get(), mlirAttributes.size(), mlirAttributes.data());
392 return PyArrayAttribute(context->getRef(), attr);
394 py::arg("attributes"), py::arg("context") = py::none(),
395 "Gets a uniqued Array attribute");
396 c.def("__getitem__",
397 [](PyArrayAttribute &arr, intptr_t i) {
398 if (i >= mlirArrayAttrGetNumElements(arr))
399 throw py::index_error("ArrayAttribute index out of range");
400 return arr.getItem(i);
402 .def("__len__",
403 [](const PyArrayAttribute &arr) {
404 return mlirArrayAttrGetNumElements(arr);
406 .def("__iter__", [](const PyArrayAttribute &arr) {
407 return PyArrayAttributeIterator(arr);
409 c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
410 std::vector<MlirAttribute> attributes;
411 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
412 attributes.reserve(numOldElements + py::len(extras));
413 for (intptr_t i = 0; i < numOldElements; ++i)
414 attributes.push_back(arr.getItem(i));
415 for (py::handle attr : extras)
416 attributes.push_back(pyTryCast<PyAttribute>(attr));
417 MlirAttribute arrayAttr = mlirArrayAttrGet(
418 arr.getContext()->get(), attributes.size(), attributes.data());
419 return PyArrayAttribute(arr.getContext(), arrayAttr);
424 /// Float Point Attribute subclass - FloatAttr.
425 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
426 public:
427 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
428 static constexpr const char *pyClassName = "FloatAttr";
429 using PyConcreteAttribute::PyConcreteAttribute;
430 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
431 mlirFloatAttrGetTypeID;
433 static void bindDerived(ClassTy &c) {
434 c.def_static(
435 "get",
436 [](PyType &type, double value, DefaultingPyLocation loc) {
437 PyMlirContext::ErrorCapture errors(loc->getContext());
438 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
439 if (mlirAttributeIsNull(attr))
440 throw MLIRError("Invalid attribute", errors.take());
441 return PyFloatAttribute(type.getContext(), attr);
443 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
444 "Gets an uniqued float point attribute associated to a type");
445 c.def_static(
446 "get_f32",
447 [](double value, DefaultingPyMlirContext context) {
448 MlirAttribute attr = mlirFloatAttrDoubleGet(
449 context->get(), mlirF32TypeGet(context->get()), value);
450 return PyFloatAttribute(context->getRef(), attr);
452 py::arg("value"), py::arg("context") = py::none(),
453 "Gets an uniqued float point attribute associated to a f32 type");
454 c.def_static(
455 "get_f64",
456 [](double value, DefaultingPyMlirContext context) {
457 MlirAttribute attr = mlirFloatAttrDoubleGet(
458 context->get(), mlirF64TypeGet(context->get()), value);
459 return PyFloatAttribute(context->getRef(), attr);
461 py::arg("value"), py::arg("context") = py::none(),
462 "Gets an uniqued float point attribute associated to a f64 type");
463 c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
464 "Returns the value of the float attribute");
465 c.def("__float__", mlirFloatAttrGetValueDouble,
466 "Converts the value of the float attribute to a Python float");
470 /// Integer Attribute subclass - IntegerAttr.
471 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
472 public:
473 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
474 static constexpr const char *pyClassName = "IntegerAttr";
475 using PyConcreteAttribute::PyConcreteAttribute;
477 static void bindDerived(ClassTy &c) {
478 c.def_static(
479 "get",
480 [](PyType &type, int64_t value) {
481 MlirAttribute attr = mlirIntegerAttrGet(type, value);
482 return PyIntegerAttribute(type.getContext(), attr);
484 py::arg("type"), py::arg("value"),
485 "Gets an uniqued integer attribute associated to a type");
486 c.def_property_readonly("value", toPyInt,
487 "Returns the value of the integer attribute");
488 c.def("__int__", toPyInt,
489 "Converts the value of the integer attribute to a Python int");
490 c.def_property_readonly_static("static_typeid",
491 [](py::object & /*class*/) -> MlirTypeID {
492 return mlirIntegerAttrGetTypeID();
496 private:
497 static py::int_ toPyInt(PyIntegerAttribute &self) {
498 MlirType type = mlirAttributeGetType(self);
499 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
500 return mlirIntegerAttrGetValueInt(self);
501 if (mlirIntegerTypeIsSigned(type))
502 return mlirIntegerAttrGetValueSInt(self);
503 return mlirIntegerAttrGetValueUInt(self);
507 /// Bool Attribute subclass - BoolAttr.
508 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
509 public:
510 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
511 static constexpr const char *pyClassName = "BoolAttr";
512 using PyConcreteAttribute::PyConcreteAttribute;
514 static void bindDerived(ClassTy &c) {
515 c.def_static(
516 "get",
517 [](bool value, DefaultingPyMlirContext context) {
518 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
519 return PyBoolAttribute(context->getRef(), attr);
521 py::arg("value"), py::arg("context") = py::none(),
522 "Gets an uniqued bool attribute");
523 c.def_property_readonly("value", mlirBoolAttrGetValue,
524 "Returns the value of the bool attribute");
525 c.def("__bool__", mlirBoolAttrGetValue,
526 "Converts the value of the bool attribute to a Python bool");
530 class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
531 public:
532 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
533 static constexpr const char *pyClassName = "SymbolRefAttr";
534 using PyConcreteAttribute::PyConcreteAttribute;
536 static MlirAttribute fromList(const std::vector<std::string> &symbols,
537 PyMlirContext &context) {
538 if (symbols.empty())
539 throw std::runtime_error("SymbolRefAttr must be composed of at least "
540 "one symbol.");
541 MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
542 SmallVector<MlirAttribute, 3> referenceAttrs;
543 for (size_t i = 1; i < symbols.size(); ++i) {
544 referenceAttrs.push_back(
545 mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
547 return mlirSymbolRefAttrGet(context.get(), rootSymbol,
548 referenceAttrs.size(), referenceAttrs.data());
551 static void bindDerived(ClassTy &c) {
552 c.def_static(
553 "get",
554 [](const std::vector<std::string> &symbols,
555 DefaultingPyMlirContext context) {
556 return PySymbolRefAttribute::fromList(symbols, context.resolve());
558 py::arg("symbols"), py::arg("context") = py::none(),
559 "Gets a uniqued SymbolRef attribute from a list of symbol names");
560 c.def_property_readonly(
561 "value",
562 [](PySymbolRefAttribute &self) {
563 std::vector<std::string> symbols = {
564 unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
565 for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
566 ++i)
567 symbols.push_back(
568 unwrap(mlirSymbolRefAttrGetRootReference(
569 mlirSymbolRefAttrGetNestedReference(self, i)))
570 .str());
571 return symbols;
573 "Returns the value of the SymbolRef attribute as a list[str]");
577 class PyFlatSymbolRefAttribute
578 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
579 public:
580 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
581 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
582 using PyConcreteAttribute::PyConcreteAttribute;
584 static void bindDerived(ClassTy &c) {
585 c.def_static(
586 "get",
587 [](std::string value, DefaultingPyMlirContext context) {
588 MlirAttribute attr =
589 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
590 return PyFlatSymbolRefAttribute(context->getRef(), attr);
592 py::arg("value"), py::arg("context") = py::none(),
593 "Gets a uniqued FlatSymbolRef attribute");
594 c.def_property_readonly(
595 "value",
596 [](PyFlatSymbolRefAttribute &self) {
597 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
598 return py::str(stringRef.data, stringRef.length);
600 "Returns the value of the FlatSymbolRef attribute as a string");
604 class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
605 public:
606 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
607 static constexpr const char *pyClassName = "OpaqueAttr";
608 using PyConcreteAttribute::PyConcreteAttribute;
609 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
610 mlirOpaqueAttrGetTypeID;
612 static void bindDerived(ClassTy &c) {
613 c.def_static(
614 "get",
615 [](std::string dialectNamespace, py::buffer buffer, PyType &type,
616 DefaultingPyMlirContext context) {
617 const py::buffer_info bufferInfo = buffer.request();
618 intptr_t bufferSize = bufferInfo.size;
619 MlirAttribute attr = mlirOpaqueAttrGet(
620 context->get(), toMlirStringRef(dialectNamespace), bufferSize,
621 static_cast<char *>(bufferInfo.ptr), type);
622 return PyOpaqueAttribute(context->getRef(), attr);
624 py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
625 py::arg("context") = py::none(), "Gets an Opaque attribute.");
626 c.def_property_readonly(
627 "dialect_namespace",
628 [](PyOpaqueAttribute &self) {
629 MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
630 return py::str(stringRef.data, stringRef.length);
632 "Returns the dialect namespace for the Opaque attribute as a string");
633 c.def_property_readonly(
634 "data",
635 [](PyOpaqueAttribute &self) {
636 MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
637 return py::bytes(stringRef.data, stringRef.length);
639 "Returns the data for the Opaqued attributes as `bytes`");
643 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
644 public:
645 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
646 static constexpr const char *pyClassName = "StringAttr";
647 using PyConcreteAttribute::PyConcreteAttribute;
648 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
649 mlirStringAttrGetTypeID;
651 static void bindDerived(ClassTy &c) {
652 c.def_static(
653 "get",
654 [](std::string value, DefaultingPyMlirContext context) {
655 MlirAttribute attr =
656 mlirStringAttrGet(context->get(), toMlirStringRef(value));
657 return PyStringAttribute(context->getRef(), attr);
659 py::arg("value"), py::arg("context") = py::none(),
660 "Gets a uniqued string attribute");
661 c.def_static(
662 "get_typed",
663 [](PyType &type, std::string value) {
664 MlirAttribute attr =
665 mlirStringAttrTypedGet(type, toMlirStringRef(value));
666 return PyStringAttribute(type.getContext(), attr);
668 py::arg("type"), py::arg("value"),
669 "Gets a uniqued string attribute associated to a type");
670 c.def_property_readonly(
671 "value",
672 [](PyStringAttribute &self) {
673 MlirStringRef stringRef = mlirStringAttrGetValue(self);
674 return py::str(stringRef.data, stringRef.length);
676 "Returns the value of the string attribute");
677 c.def_property_readonly(
678 "value_bytes",
679 [](PyStringAttribute &self) {
680 MlirStringRef stringRef = mlirStringAttrGetValue(self);
681 return py::bytes(stringRef.data, stringRef.length);
683 "Returns the value of the string attribute as `bytes`");
687 // TODO: Support construction of string elements.
688 class PyDenseElementsAttribute
689 : public PyConcreteAttribute<PyDenseElementsAttribute> {
690 public:
691 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
692 static constexpr const char *pyClassName = "DenseElementsAttr";
693 using PyConcreteAttribute::PyConcreteAttribute;
695 static PyDenseElementsAttribute
696 getFromList(py::list attributes, std::optional<PyType> explicitType,
697 DefaultingPyMlirContext contextWrapper) {
699 const size_t numAttributes = py::len(attributes);
700 if (numAttributes == 0)
701 throw py::value_error("Attributes list must be non-empty.");
703 MlirType shapedType;
704 if (explicitType) {
705 if ((!mlirTypeIsAShaped(*explicitType) ||
706 !mlirShapedTypeHasStaticShape(*explicitType))) {
708 std::string message;
709 llvm::raw_string_ostream os(message);
710 os << "Expected a static ShapedType for the shaped_type parameter: "
711 << py::repr(py::cast(*explicitType));
712 throw py::value_error(message);
714 shapedType = *explicitType;
715 } else {
716 SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
717 shapedType = mlirRankedTensorTypeGet(
718 shape.size(), shape.data(),
719 mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
720 mlirAttributeGetNull());
723 SmallVector<MlirAttribute> mlirAttributes;
724 mlirAttributes.reserve(numAttributes);
725 for (const py::handle &attribute : attributes) {
726 MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
727 MlirType attrType = mlirAttributeGetType(mlirAttribute);
728 mlirAttributes.push_back(mlirAttribute);
730 if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
731 std::string message;
732 llvm::raw_string_ostream os(message);
733 os << "All attributes must be of the same type and match "
734 << "the type parameter: expected=" << py::repr(py::cast(shapedType))
735 << ", but got=" << py::repr(py::cast(attrType));
736 throw py::value_error(message);
740 MlirAttribute elements = mlirDenseElementsAttrGet(
741 shapedType, mlirAttributes.size(), mlirAttributes.data());
743 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
746 static PyDenseElementsAttribute
747 getFromBuffer(py::buffer array, bool signless,
748 std::optional<PyType> explicitType,
749 std::optional<std::vector<int64_t>> explicitShape,
750 DefaultingPyMlirContext contextWrapper) {
751 // Request a contiguous view. In exotic cases, this will cause a copy.
752 int flags = PyBUF_ND;
753 if (!explicitType) {
754 flags |= PyBUF_FORMAT;
756 Py_buffer view;
757 if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
758 throw py::error_already_set();
760 auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
762 MlirContext context = contextWrapper->get();
763 MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
764 explicitShape, context);
765 if (mlirAttributeIsNull(attr)) {
766 throw std::invalid_argument(
767 "DenseElementsAttr could not be constructed from the given buffer. "
768 "This may mean that the Python buffer layout does not match that "
769 "MLIR expected layout and is a bug.");
771 return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
774 static PyDenseElementsAttribute getSplat(const PyType &shapedType,
775 PyAttribute &elementAttr) {
776 auto contextWrapper =
777 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
778 if (!mlirAttributeIsAInteger(elementAttr) &&
779 !mlirAttributeIsAFloat(elementAttr)) {
780 std::string message = "Illegal element type for DenseElementsAttr: ";
781 message.append(py::repr(py::cast(elementAttr)));
782 throw py::value_error(message);
784 if (!mlirTypeIsAShaped(shapedType) ||
785 !mlirShapedTypeHasStaticShape(shapedType)) {
786 std::string message =
787 "Expected a static ShapedType for the shaped_type parameter: ";
788 message.append(py::repr(py::cast(shapedType)));
789 throw py::value_error(message);
791 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
792 MlirType attrType = mlirAttributeGetType(elementAttr);
793 if (!mlirTypeEqual(shapedElementType, attrType)) {
794 std::string message =
795 "Shaped element type and attribute type must be equal: shaped=";
796 message.append(py::repr(py::cast(shapedType)));
797 message.append(", element=");
798 message.append(py::repr(py::cast(elementAttr)));
799 throw py::value_error(message);
802 MlirAttribute elements =
803 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
804 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
807 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
809 py::buffer_info accessBuffer() {
810 MlirType shapedType = mlirAttributeGetType(*this);
811 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
812 std::string format;
814 if (mlirTypeIsAF32(elementType)) {
815 // f32
816 return bufferInfo<float>(shapedType);
818 if (mlirTypeIsAF64(elementType)) {
819 // f64
820 return bufferInfo<double>(shapedType);
822 if (mlirTypeIsAF16(elementType)) {
823 // f16
824 return bufferInfo<uint16_t>(shapedType, "e");
826 if (mlirTypeIsAIndex(elementType)) {
827 // Same as IndexType::kInternalStorageBitWidth
828 return bufferInfo<int64_t>(shapedType);
830 if (mlirTypeIsAInteger(elementType) &&
831 mlirIntegerTypeGetWidth(elementType) == 32) {
832 if (mlirIntegerTypeIsSignless(elementType) ||
833 mlirIntegerTypeIsSigned(elementType)) {
834 // i32
835 return bufferInfo<int32_t>(shapedType);
837 if (mlirIntegerTypeIsUnsigned(elementType)) {
838 // unsigned i32
839 return bufferInfo<uint32_t>(shapedType);
841 } else if (mlirTypeIsAInteger(elementType) &&
842 mlirIntegerTypeGetWidth(elementType) == 64) {
843 if (mlirIntegerTypeIsSignless(elementType) ||
844 mlirIntegerTypeIsSigned(elementType)) {
845 // i64
846 return bufferInfo<int64_t>(shapedType);
848 if (mlirIntegerTypeIsUnsigned(elementType)) {
849 // unsigned i64
850 return bufferInfo<uint64_t>(shapedType);
852 } else if (mlirTypeIsAInteger(elementType) &&
853 mlirIntegerTypeGetWidth(elementType) == 8) {
854 if (mlirIntegerTypeIsSignless(elementType) ||
855 mlirIntegerTypeIsSigned(elementType)) {
856 // i8
857 return bufferInfo<int8_t>(shapedType);
859 if (mlirIntegerTypeIsUnsigned(elementType)) {
860 // unsigned i8
861 return bufferInfo<uint8_t>(shapedType);
863 } else if (mlirTypeIsAInteger(elementType) &&
864 mlirIntegerTypeGetWidth(elementType) == 16) {
865 if (mlirIntegerTypeIsSignless(elementType) ||
866 mlirIntegerTypeIsSigned(elementType)) {
867 // i16
868 return bufferInfo<int16_t>(shapedType);
870 if (mlirIntegerTypeIsUnsigned(elementType)) {
871 // unsigned i16
872 return bufferInfo<uint16_t>(shapedType);
874 } else if (mlirTypeIsAInteger(elementType) &&
875 mlirIntegerTypeGetWidth(elementType) == 1) {
876 // i1 / bool
877 // We can not send the buffer directly back to Python, because the i1
878 // values are bitpacked within MLIR. We call numpy's unpackbits function
879 // to convert the bytes.
880 return getBooleanBufferFromBitpackedAttribute();
883 // TODO: Currently crashes the program.
884 // Reported as https://github.com/pybind/pybind11/issues/3336
885 throw std::invalid_argument(
886 "unsupported data type for conversion to Python buffer");
889 static void bindDerived(ClassTy &c) {
890 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
891 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
892 py::arg("array"), py::arg("signless") = true,
893 py::arg("type") = py::none(), py::arg("shape") = py::none(),
894 py::arg("context") = py::none(),
895 kDenseElementsAttrGetDocstring)
896 .def_static("get", PyDenseElementsAttribute::getFromList,
897 py::arg("attrs"), py::arg("type") = py::none(),
898 py::arg("context") = py::none(),
899 kDenseElementsAttrGetFromListDocstring)
900 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
901 py::arg("shaped_type"), py::arg("element_attr"),
902 "Gets a DenseElementsAttr where all values are the same")
903 .def_property_readonly("is_splat",
904 [](PyDenseElementsAttribute &self) -> bool {
905 return mlirDenseElementsAttrIsSplat(self);
907 .def("get_splat_value",
908 [](PyDenseElementsAttribute &self) {
909 if (!mlirDenseElementsAttrIsSplat(self))
910 throw py::value_error(
911 "get_splat_value called on a non-splat attribute");
912 return mlirDenseElementsAttrGetSplatValue(self);
914 .def_buffer(&PyDenseElementsAttribute::accessBuffer);
917 private:
/// Returns true when the buffer-protocol `format` string starts with one of
/// the unsigned integer codes 'B', 'H', 'I', 'L' or 'Q'. Any byte-order
/// prefix (e.g. '<') makes this return false.
static bool isUnsignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kUnsignedCodes = "IBHLQ";
  return !format.empty() &&
         kUnsignedCodes.find(format.front()) != std::string_view::npos;
}
/// Returns true when the buffer-protocol `format` string starts with one of
/// the signed integer codes 'b', 'h', 'i', 'l' or 'q'. Any byte-order prefix
/// (e.g. '<') makes this return false.
static bool isSignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kSignedCodes = "ibhlq";
  return !format.empty() &&
         kSignedCodes.find(format.front()) != std::string_view::npos;
}
934 static MlirType
935 getShapedType(std::optional<MlirType> bulkLoadElementType,
936 std::optional<std::vector<int64_t>> explicitShape,
937 Py_buffer &view) {
938 SmallVector<int64_t> shape;
939 if (explicitShape) {
940 shape.append(explicitShape->begin(), explicitShape->end());
941 } else {
942 shape.append(view.shape, view.shape + view.ndim);
945 if (mlirTypeIsAShaped(*bulkLoadElementType)) {
946 if (explicitShape) {
947 throw std::invalid_argument("Shape can only be specified explicitly "
948 "when the type is not a shaped type.");
950 return *bulkLoadElementType;
951 } else {
952 MlirAttribute encodingAttr = mlirAttributeGetNull();
953 return mlirRankedTensorTypeGet(shape.size(), shape.data(),
954 *bulkLoadElementType, encodingAttr);
958 static MlirAttribute getAttributeFromBuffer(
959 Py_buffer &view, bool signless, std::optional<PyType> explicitType,
960 std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
961 // Detect format codes that are suitable for bulk loading. This includes
962 // all byte aligned integer and floating point types up to 8 bytes.
963 // Notably, this excludes exotics types which do not have a direct
964 // representation in the buffer protocol (i.e. complex, etc).
965 std::optional<MlirType> bulkLoadElementType;
966 if (explicitType) {
967 bulkLoadElementType = *explicitType;
968 } else {
969 std::string_view format(view.format);
970 if (format == "f") {
971 // f32
972 assert(view.itemsize == 4 && "mismatched array itemsize");
973 bulkLoadElementType = mlirF32TypeGet(context);
974 } else if (format == "d") {
975 // f64
976 assert(view.itemsize == 8 && "mismatched array itemsize");
977 bulkLoadElementType = mlirF64TypeGet(context);
978 } else if (format == "e") {
979 // f16
980 assert(view.itemsize == 2 && "mismatched array itemsize");
981 bulkLoadElementType = mlirF16TypeGet(context);
982 } else if (format == "?") {
983 // i1
984 // The i1 type needs to be bit-packed, so we will handle it seperately
985 return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
986 context);
987 } else if (isSignedIntegerFormat(format)) {
988 if (view.itemsize == 4) {
989 // i32
990 bulkLoadElementType = signless
991 ? mlirIntegerTypeGet(context, 32)
992 : mlirIntegerTypeSignedGet(context, 32);
993 } else if (view.itemsize == 8) {
994 // i64
995 bulkLoadElementType = signless
996 ? mlirIntegerTypeGet(context, 64)
997 : mlirIntegerTypeSignedGet(context, 64);
998 } else if (view.itemsize == 1) {
999 // i8
1000 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1001 : mlirIntegerTypeSignedGet(context, 8);
1002 } else if (view.itemsize == 2) {
1003 // i16
1004 bulkLoadElementType = signless
1005 ? mlirIntegerTypeGet(context, 16)
1006 : mlirIntegerTypeSignedGet(context, 16);
1008 } else if (isUnsignedIntegerFormat(format)) {
1009 if (view.itemsize == 4) {
1010 // unsigned i32
1011 bulkLoadElementType = signless
1012 ? mlirIntegerTypeGet(context, 32)
1013 : mlirIntegerTypeUnsignedGet(context, 32);
1014 } else if (view.itemsize == 8) {
1015 // unsigned i64
1016 bulkLoadElementType = signless
1017 ? mlirIntegerTypeGet(context, 64)
1018 : mlirIntegerTypeUnsignedGet(context, 64);
1019 } else if (view.itemsize == 1) {
1020 // i8
1021 bulkLoadElementType = signless
1022 ? mlirIntegerTypeGet(context, 8)
1023 : mlirIntegerTypeUnsignedGet(context, 8);
1024 } else if (view.itemsize == 2) {
1025 // i16
1026 bulkLoadElementType = signless
1027 ? mlirIntegerTypeGet(context, 16)
1028 : mlirIntegerTypeUnsignedGet(context, 16);
1031 if (!bulkLoadElementType) {
1032 throw std::invalid_argument(
1033 std::string("unimplemented array format conversion from format: ") +
1034 std::string(format));
1038 MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1039 return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1042 // There is a complication for boolean numpy arrays, as numpy represents them
1043 // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
1044 // per byte.
1045 static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1046 Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1047 MlirContext &context) {
1048 if (llvm::endianness::native != llvm::endianness::little) {
1049 // Given we have no good way of testing the behavior on big-endian systems
1050 // we will throw
1051 throw py::type_error("Constructing a bit-packed MLIR attribute is "
1052 "unsupported on big-endian systems");
1055 py::array_t<uint8_t> unpackedArray(view.len,
1056 static_cast<uint8_t *>(view.buf));
1058 py::module numpy = py::module::import("numpy");
1059 py::object packbitsFunc = numpy.attr("packbits");
1060 py::object packedBooleans =
1061 packbitsFunc(unpackedArray, "bitorder"_a = "little");
1062 py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();
1064 MlirType bitpackedType =
1065 getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1066 assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
1067 // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1068 // packedBooleans, hence the MlirAttribute will remain valid even when
1069 // packedBooleans get reclaimed by the end of the function.
1070 return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1071 pythonBuffer.ptr);
1074 // This does the opposite transformation of
1075 // `getBitpackedAttributeFromBooleanBuffer`
1076 py::buffer_info getBooleanBufferFromBitpackedAttribute() {
1077 if (llvm::endianness::native != llvm::endianness::little) {
1078 // Given we have no good way of testing the behavior on big-endian systems
1079 // we will throw
1080 throw py::type_error("Constructing a numpy array from a MLIR attribute "
1081 "is unsupported on big-endian systems");
1084 int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1085 int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
1086 uint8_t *bitpackedData = static_cast<uint8_t *>(
1087 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1088 py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
1090 py::module numpy = py::module::import("numpy");
1091 py::object unpackbitsFunc = numpy.attr("unpackbits");
1092 py::object equalFunc = numpy.attr("equal");
1093 py::object reshapeFunc = numpy.attr("reshape");
1094 py::array unpackedBooleans =
1095 unpackbitsFunc(packedArray, "bitorder"_a = "little");
1097 // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1098 // We need to:
1099 // 1. Slice away the padded bits
1100 // 2. Make the boolean array have the correct shape
1101 // 3. Convert the array to a boolean array
1102 unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
1103 unpackedBooleans = equalFunc(unpackedBooleans, 1);
1105 std::vector<intptr_t> shape;
1106 MlirType shapedType = mlirAttributeGetType(*this);
1107 intptr_t rank = mlirShapedTypeGetRank(shapedType);
1108 for (intptr_t i = 0; i < rank; ++i) {
1109 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1111 unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
1113 // Make sure the returned py::buffer_view claims ownership of the data in
1114 // `pythonBuffer` so it remains valid when Python reads it
1115 py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
1116 return pythonBuffer.request();
1119 template <typename Type>
1120 py::buffer_info bufferInfo(MlirType shapedType,
1121 const char *explicitFormat = nullptr) {
1122 intptr_t rank = mlirShapedTypeGetRank(shapedType);
1123 // Prepare the data for the buffer_info.
1124 // Buffer is configured for read-only access below.
1125 Type *data = static_cast<Type *>(
1126 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1127 // Prepare the shape for the buffer_info.
1128 SmallVector<intptr_t, 4> shape;
1129 for (intptr_t i = 0; i < rank; ++i)
1130 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1131 // Prepare the strides for the buffer_info.
1132 SmallVector<intptr_t, 4> strides;
1133 if (mlirDenseElementsAttrIsSplat(*this)) {
1134 // Splats are special, only the single value is stored.
1135 strides.assign(rank, 0);
1136 } else {
1137 for (intptr_t i = 1; i < rank; ++i) {
1138 intptr_t strideFactor = 1;
1139 for (intptr_t j = i; j < rank; ++j)
1140 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1141 strides.push_back(sizeof(Type) * strideFactor);
1143 strides.push_back(sizeof(Type));
1145 std::string format;
1146 if (explicitFormat) {
1147 format = explicitFormat;
1148 } else {
1149 format = py::format_descriptor<Type>::format();
1151 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
1152 /*readonly=*/true);
1154 }; // namespace
1156 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
1157 /// (and boolean) values. Supports element access.
1158 class PyDenseIntElementsAttribute
1159 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1160 PyDenseElementsAttribute> {
1161 public:
1162 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1163 static constexpr const char *pyClassName = "DenseIntElementsAttr";
1164 using PyConcreteAttribute::PyConcreteAttribute;
1166 /// Returns the element at the given linear position. Asserts if the index is
1167 /// out of range.
1168 py::int_ dunderGetItem(intptr_t pos) {
1169 if (pos < 0 || pos >= dunderLen()) {
1170 throw py::index_error("attempt to access out of bounds element");
1173 MlirType type = mlirAttributeGetType(*this);
1174 type = mlirShapedTypeGetElementType(type);
1175 assert(mlirTypeIsAInteger(type) &&
1176 "expected integer element type in dense int elements attribute");
1177 // Dispatch element extraction to an appropriate C function based on the
1178 // elemental type of the attribute. py::int_ is implicitly constructible
1179 // from any C++ integral type and handles bitwidth correctly.
1180 // TODO: consider caching the type properties in the constructor to avoid
1181 // querying them on each element access.
1182 unsigned width = mlirIntegerTypeGetWidth(type);
1183 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1184 if (isUnsigned) {
1185 if (width == 1) {
1186 return mlirDenseElementsAttrGetBoolValue(*this, pos);
1188 if (width == 8) {
1189 return mlirDenseElementsAttrGetUInt8Value(*this, pos);
1191 if (width == 16) {
1192 return mlirDenseElementsAttrGetUInt16Value(*this, pos);
1194 if (width == 32) {
1195 return mlirDenseElementsAttrGetUInt32Value(*this, pos);
1197 if (width == 64) {
1198 return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1200 } else {
1201 if (width == 1) {
1202 return mlirDenseElementsAttrGetBoolValue(*this, pos);
1204 if (width == 8) {
1205 return mlirDenseElementsAttrGetInt8Value(*this, pos);
1207 if (width == 16) {
1208 return mlirDenseElementsAttrGetInt16Value(*this, pos);
1210 if (width == 32) {
1211 return mlirDenseElementsAttrGetInt32Value(*this, pos);
1213 if (width == 64) {
1214 return mlirDenseElementsAttrGetInt64Value(*this, pos);
1217 throw py::type_error("Unsupported integer type");
1220 static void bindDerived(ClassTy &c) {
1221 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1225 class PyDenseResourceElementsAttribute
1226 : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1227 public:
1228 static constexpr IsAFunctionTy isaFunction =
1229 mlirAttributeIsADenseResourceElements;
1230 static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1231 using PyConcreteAttribute::PyConcreteAttribute;
1233 static PyDenseResourceElementsAttribute
1234 getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type,
1235 std::optional<size_t> alignment, bool isMutable,
1236 DefaultingPyMlirContext contextWrapper) {
1237 if (!mlirTypeIsAShaped(type)) {
1238 throw std::invalid_argument(
1239 "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1242 // Do not request any conversions as we must ensure to use caller
1243 // managed memory.
1244 int flags = PyBUF_STRIDES;
1245 std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1246 if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1247 throw py::error_already_set();
1250 // This scope releaser will only release if we haven't yet transferred
1251 // ownership.
1252 auto freeBuffer = llvm::make_scope_exit([&]() {
1253 if (view)
1254 PyBuffer_Release(view.get());
1257 if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1258 throw std::invalid_argument("Contiguous buffer is required.");
1261 // Infer alignment to be the stride of one element if not explicit.
1262 size_t inferredAlignment;
1263 if (alignment)
1264 inferredAlignment = *alignment;
1265 else
1266 inferredAlignment = view->strides[view->ndim - 1];
1268 // The userData is a Py_buffer* that the deleter owns.
1269 auto deleter = [](void *userData, const void *data, size_t size,
1270 size_t align) {
1271 Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1272 PyBuffer_Release(ownedView);
1273 delete ownedView;
1276 size_t rawBufferSize = view->len;
1277 MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1278 type, toMlirStringRef(name), view->buf, rawBufferSize,
1279 inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1280 if (mlirAttributeIsNull(attr)) {
1281 throw std::invalid_argument(
1282 "DenseResourceElementsAttr could not be constructed from the given "
1283 "buffer. "
1284 "This may mean that the Python buffer layout does not match that "
1285 "MLIR expected layout and is a bug.");
1287 view.release();
1288 return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1291 static void bindDerived(ClassTy &c) {
1292 c.def_static("get_from_buffer",
1293 PyDenseResourceElementsAttribute::getFromBuffer,
1294 py::arg("array"), py::arg("name"), py::arg("type"),
1295 py::arg("alignment") = py::none(),
1296 py::arg("is_mutable") = false, py::arg("context") = py::none(),
1297 kDenseResourceElementsAttrGetFromBufferDocstring);
1301 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1302 public:
1303 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1304 static constexpr const char *pyClassName = "DictAttr";
1305 using PyConcreteAttribute::PyConcreteAttribute;
1306 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1307 mlirDictionaryAttrGetTypeID;
1309 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1311 bool dunderContains(const std::string &name) {
1312 return !mlirAttributeIsNull(
1313 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
1316 static void bindDerived(ClassTy &c) {
1317 c.def("__contains__", &PyDictAttribute::dunderContains);
1318 c.def("__len__", &PyDictAttribute::dunderLen);
1319 c.def_static(
1320 "get",
1321 [](py::dict attributes, DefaultingPyMlirContext context) {
1322 SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1323 mlirNamedAttributes.reserve(attributes.size());
1324 for (auto &it : attributes) {
1325 auto &mlirAttr = it.second.cast<PyAttribute &>();
1326 auto name = it.first.cast<std::string>();
1327 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1328 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1329 toMlirStringRef(name)),
1330 mlirAttr));
1332 MlirAttribute attr =
1333 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1334 mlirNamedAttributes.data());
1335 return PyDictAttribute(context->getRef(), attr);
1337 py::arg("value") = py::dict(), py::arg("context") = py::none(),
1338 "Gets an uniqued dict attribute");
1339 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1340 MlirAttribute attr =
1341 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1342 if (mlirAttributeIsNull(attr))
1343 throw py::key_error("attempt to access a non-existent attribute");
1344 return attr;
1346 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1347 if (index < 0 || index >= self.dunderLen()) {
1348 throw py::index_error("attempt to access out of bounds attribute");
1350 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1351 return PyNamedAttribute(
1352 namedAttr.attribute,
1353 std::string(mlirIdentifierStr(namedAttr.name).data));
1358 /// Refinement of PyDenseElementsAttribute for attributes containing
1359 /// floating-point values. Supports element access.
1360 class PyDenseFPElementsAttribute
1361 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1362 PyDenseElementsAttribute> {
1363 public:
1364 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1365 static constexpr const char *pyClassName = "DenseFPElementsAttr";
1366 using PyConcreteAttribute::PyConcreteAttribute;
1368 py::float_ dunderGetItem(intptr_t pos) {
1369 if (pos < 0 || pos >= dunderLen()) {
1370 throw py::index_error("attempt to access out of bounds element");
1373 MlirType type = mlirAttributeGetType(*this);
1374 type = mlirShapedTypeGetElementType(type);
1375 // Dispatch element extraction to an appropriate C function based on the
1376 // elemental type of the attribute. py::float_ is implicitly constructible
1377 // from float and double.
1378 // TODO: consider caching the type properties in the constructor to avoid
1379 // querying them on each element access.
1380 if (mlirTypeIsAF32(type)) {
1381 return mlirDenseElementsAttrGetFloatValue(*this, pos);
1383 if (mlirTypeIsAF64(type)) {
1384 return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1386 throw py::type_error("Unsupported floating-point type");
1389 static void bindDerived(ClassTy &c) {
1390 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1394 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1395 public:
1396 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1397 static constexpr const char *pyClassName = "TypeAttr";
1398 using PyConcreteAttribute::PyConcreteAttribute;
1399 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1400 mlirTypeAttrGetTypeID;
1402 static void bindDerived(ClassTy &c) {
1403 c.def_static(
1404 "get",
1405 [](PyType value, DefaultingPyMlirContext context) {
1406 MlirAttribute attr = mlirTypeAttrGet(value.get());
1407 return PyTypeAttribute(context->getRef(), attr);
1409 py::arg("value"), py::arg("context") = py::none(),
1410 "Gets a uniqued Type attribute");
1411 c.def_property_readonly("value", [](PyTypeAttribute &self) {
1412 return mlirTypeAttrGetValue(self.get());
1417 /// Unit Attribute subclass. Unit attributes don't have values.
1418 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1419 public:
1420 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1421 static constexpr const char *pyClassName = "UnitAttr";
1422 using PyConcreteAttribute::PyConcreteAttribute;
1423 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1424 mlirUnitAttrGetTypeID;
1426 static void bindDerived(ClassTy &c) {
1427 c.def_static(
1428 "get",
1429 [](DefaultingPyMlirContext context) {
1430 return PyUnitAttribute(context->getRef(),
1431 mlirUnitAttrGet(context->get()));
1433 py::arg("context") = py::none(), "Create a Unit attribute.");
1437 /// Strided layout attribute subclass.
1438 class PyStridedLayoutAttribute
1439 : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1440 public:
1441 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1442 static constexpr const char *pyClassName = "StridedLayoutAttr";
1443 using PyConcreteAttribute::PyConcreteAttribute;
1444 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1445 mlirStridedLayoutAttrGetTypeID;
1447 static void bindDerived(ClassTy &c) {
1448 c.def_static(
1449 "get",
1450 [](int64_t offset, const std::vector<int64_t> strides,
1451 DefaultingPyMlirContext ctx) {
1452 MlirAttribute attr = mlirStridedLayoutAttrGet(
1453 ctx->get(), offset, strides.size(), strides.data());
1454 return PyStridedLayoutAttribute(ctx->getRef(), attr);
1456 py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1457 "Gets a strided layout attribute.");
1458 c.def_static(
1459 "get_fully_dynamic",
1460 [](int64_t rank, DefaultingPyMlirContext ctx) {
1461 auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1462 std::vector<int64_t> strides(rank);
1463 std::fill(strides.begin(), strides.end(), dynamic);
1464 MlirAttribute attr = mlirStridedLayoutAttrGet(
1465 ctx->get(), dynamic, strides.size(), strides.data());
1466 return PyStridedLayoutAttribute(ctx->getRef(), attr);
1468 py::arg("rank"), py::arg("context") = py::none(),
1469 "Gets a strided layout attribute with dynamic offset and strides of a "
1470 "given rank.");
1471 c.def_property_readonly(
1472 "offset",
1473 [](PyStridedLayoutAttribute &self) {
1474 return mlirStridedLayoutAttrGetOffset(self);
1476 "Returns the value of the float point attribute");
1477 c.def_property_readonly(
1478 "strides",
1479 [](PyStridedLayoutAttribute &self) {
1480 intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1481 std::vector<int64_t> strides(size);
1482 for (intptr_t i = 0; i < size; i++) {
1483 strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1485 return strides;
1487 "Returns the value of the float point attribute");
1491 py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1492 if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1493 return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
1494 if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1495 return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
1496 if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1497 return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
1498 if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1499 return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
1500 if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1501 return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
1502 if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1503 return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
1504 if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1505 return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
1506 std::string msg =
1507 std::string("Can't cast unknown element type DenseArrayAttr (") +
1508 std::string(py::repr(py::cast(pyAttribute))) + ")";
1509 throw py::cast_error(msg);
1512 py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1513 if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1514 return py::cast(PyDenseFPElementsAttribute(pyAttribute));
1515 if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1516 return py::cast(PyDenseIntElementsAttribute(pyAttribute));
1517 std::string msg =
1518 std::string(
1519 "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1520 std::string(py::repr(py::cast(pyAttribute))) + ")";
1521 throw py::cast_error(msg);
1524 py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1525 if (PyBoolAttribute::isaFunction(pyAttribute))
1526 return py::cast(PyBoolAttribute(pyAttribute));
1527 if (PyIntegerAttribute::isaFunction(pyAttribute))
1528 return py::cast(PyIntegerAttribute(pyAttribute));
1529 std::string msg =
1530 std::string("Can't cast unknown element type DenseArrayAttr (") +
1531 std::string(py::repr(py::cast(pyAttribute))) + ")";
1532 throw py::cast_error(msg);
1535 py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1536 if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1537 return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
1538 if (PySymbolRefAttribute::isaFunction(pyAttribute))
1539 return py::cast(PySymbolRefAttribute(pyAttribute));
1540 std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1541 std::string(py::repr(py::cast(pyAttribute))) + ")";
1542 throw py::cast_error(msg);
1545 } // namespace
1547 void mlir::python::populateIRAttributes(py::module &m) {
1548 PyAffineMapAttribute::bind(m);
1549 PyDenseBoolArrayAttribute::bind(m);
1550 PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1551 PyDenseI8ArrayAttribute::bind(m);
1552 PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1553 PyDenseI16ArrayAttribute::bind(m);
1554 PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1555 PyDenseI32ArrayAttribute::bind(m);
1556 PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1557 PyDenseI64ArrayAttribute::bind(m);
1558 PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1559 PyDenseF32ArrayAttribute::bind(m);
1560 PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1561 PyDenseF64ArrayAttribute::bind(m);
1562 PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1563 PyGlobals::get().registerTypeCaster(
1564 mlirDenseArrayAttrGetTypeID(),
1565 pybind11::cpp_function(denseArrayAttributeCaster));
1567 PyArrayAttribute::bind(m);
1568 PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1569 PyBoolAttribute::bind(m);
1570 PyDenseElementsAttribute::bind(m);
1571 PyDenseFPElementsAttribute::bind(m);
1572 PyDenseIntElementsAttribute::bind(m);
1573 PyGlobals::get().registerTypeCaster(
1574 mlirDenseIntOrFPElementsAttrGetTypeID(),
1575 pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1576 PyDenseResourceElementsAttribute::bind(m);
1578 PyDictAttribute::bind(m);
1579 PySymbolRefAttribute::bind(m);
1580 PyGlobals::get().registerTypeCaster(
1581 mlirSymbolRefAttrGetTypeID(),
1582 pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
1584 PyFlatSymbolRefAttribute::bind(m);
1585 PyOpaqueAttribute::bind(m);
1586 PyFloatAttribute::bind(m);
1587 PyIntegerAttribute::bind(m);
1588 PyIntegerSetAttribute::bind(m);
1589 PyStringAttribute::bind(m);
1590 PyTypeAttribute::bind(m);
1591 PyGlobals::get().registerTypeCaster(
1592 mlirIntegerAttrGetTypeID(),
1593 pybind11::cpp_function(integerOrBoolAttributeCaster));
1594 PyUnitAttribute::bind(m);
1596 PyStridedLayoutAttribute::bind(m);