[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Bindings / Python / IRAffine.cpp
blob b138e131e851eaeefd8350f36b2fcd4cc95276cc
1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <cstddef>
10 #include <cstdint>
11 #include <pybind11/cast.h>
12 #include <pybind11/detail/common.h>
13 #include <pybind11/pybind11.h>
14 #include <pybind11/pytypes.h>
15 #include <string>
16 #include <utility>
17 #include <vector>
19 #include "IRModule.h"
21 #include "PybindUtils.h"
23 #include "mlir-c/AffineExpr.h"
24 #include "mlir-c/AffineMap.h"
25 #include "mlir-c/Bindings/Python/Interop.h"
26 #include "mlir-c/IntegerSet.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/Hashing.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/Twine.h"
33 namespace py = pybind11;
34 using namespace mlir;
35 using namespace mlir::python;
37 using llvm::SmallVector;
38 using llvm::StringRef;
39 using llvm::Twine;
/// Docstring attached to every `dump` method bound in this file.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
44 /// Attempts to populate `result` with the content of `list` casted to the
45 /// appropriate type (Python and C types are provided as template arguments).
46 /// Throws errors in case of failure, using "action" to describe what the caller
47 /// was attempting to do.
48 template <typename PyType, typename CType>
49 static void pyListToVector(const py::list &list,
50 llvm::SmallVectorImpl<CType> &result,
51 StringRef action) {
52 result.reserve(py::len(list));
53 for (py::handle item : list) {
54 try {
55 result.push_back(item.cast<PyType>());
56 } catch (py::cast_error &err) {
57 std::string msg = (llvm::Twine("Invalid expression when ") + action +
58 " (" + err.what() + ")")
59 .str();
60 throw py::cast_error(msg);
61 } catch (py::reference_cast_error &err) {
62 std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
63 action + " (" + err.what() + ")")
64 .str();
65 throw py::cast_error(msg);
/// Returns true iff `permutation` contains every integer in
/// [0, permutation.size()) exactly once. Values outside that range, including
/// negative values of a signed element type (which wrap to huge unsigned
/// indices), make the function return false.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const std::size_t size = permutation.size();
  // One flag per admissible value; `char` avoids the packed vector<bool>.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    const auto index = static_cast<std::size_t>(val);
    if (index >= size || seen[index])
      return false;
    seen[index] = 1;
  }
  return true;
}
85 namespace {
87 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
88 /// and should be castable from it. Intermediate hierarchy classes can be
89 /// modeled by specifying BaseTy.
90 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
91 class PyConcreteAffineExpr : public BaseTy {
92 public:
93 // Derived classes must define statics for:
94 // IsAFunctionTy isaFunction
95 // const char *pyClassName
96 // and redefine bindDerived.
97 using ClassTy = py::class_<DerivedTy, BaseTy>;
98 using IsAFunctionTy = bool (*)(MlirAffineExpr);
100 PyConcreteAffineExpr() = default;
101 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
102 : BaseTy(std::move(contextRef), affineExpr) {}
103 PyConcreteAffineExpr(PyAffineExpr &orig)
104 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
106 static MlirAffineExpr castFrom(PyAffineExpr &orig) {
107 if (!DerivedTy::isaFunction(orig)) {
108 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
109 throw py::value_error((Twine("Cannot cast affine expression to ") +
110 DerivedTy::pyClassName + " (from " + origRepr +
111 ")")
112 .str());
114 return orig;
117 static void bind(py::module &m) {
118 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
119 cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
120 cls.def_static(
121 "isinstance",
122 [](PyAffineExpr &otherAffineExpr) -> bool {
123 return DerivedTy::isaFunction(otherAffineExpr);
125 py::arg("other"));
126 DerivedTy::bindDerived(cls);
129 /// Implemented by derived classes to add methods to the Python subclass.
130 static void bindDerived(ClassTy &m) {}
133 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
134 public:
135 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
136 static constexpr const char *pyClassName = "AffineConstantExpr";
137 using PyConcreteAffineExpr::PyConcreteAffineExpr;
139 static PyAffineConstantExpr get(intptr_t value,
140 DefaultingPyMlirContext context) {
141 MlirAffineExpr affineExpr =
142 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
143 return PyAffineConstantExpr(context->getRef(), affineExpr);
146 static void bindDerived(ClassTy &c) {
147 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
148 py::arg("context") = py::none());
149 c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
150 return mlirAffineConstantExprGetValue(self);
155 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
156 public:
157 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
158 static constexpr const char *pyClassName = "AffineDimExpr";
159 using PyConcreteAffineExpr::PyConcreteAffineExpr;
161 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
162 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
163 return PyAffineDimExpr(context->getRef(), affineExpr);
166 static void bindDerived(ClassTy &c) {
167 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
168 py::arg("context") = py::none());
169 c.def_property_readonly("position", [](PyAffineDimExpr &self) {
170 return mlirAffineDimExprGetPosition(self);
175 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
176 public:
177 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
178 static constexpr const char *pyClassName = "AffineSymbolExpr";
179 using PyConcreteAffineExpr::PyConcreteAffineExpr;
181 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
182 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
183 return PyAffineSymbolExpr(context->getRef(), affineExpr);
186 static void bindDerived(ClassTy &c) {
187 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
188 py::arg("context") = py::none());
189 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
190 return mlirAffineSymbolExprGetPosition(self);
195 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
196 public:
197 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
198 static constexpr const char *pyClassName = "AffineBinaryExpr";
199 using PyConcreteAffineExpr::PyConcreteAffineExpr;
201 PyAffineExpr lhs() {
202 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
203 return PyAffineExpr(getContext(), lhsExpr);
206 PyAffineExpr rhs() {
207 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
208 return PyAffineExpr(getContext(), rhsExpr);
211 static void bindDerived(ClassTy &c) {
212 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
213 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
217 class PyAffineAddExpr
218 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
219 public:
220 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
221 static constexpr const char *pyClassName = "AffineAddExpr";
222 using PyConcreteAffineExpr::PyConcreteAffineExpr;
224 static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
225 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
226 return PyAffineAddExpr(lhs.getContext(), expr);
229 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
230 MlirAffineExpr expr = mlirAffineAddExprGet(
231 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
232 return PyAffineAddExpr(lhs.getContext(), expr);
235 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
236 MlirAffineExpr expr = mlirAffineAddExprGet(
237 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
238 return PyAffineAddExpr(rhs.getContext(), expr);
241 static void bindDerived(ClassTy &c) {
242 c.def_static("get", &PyAffineAddExpr::get);
246 class PyAffineMulExpr
247 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
248 public:
249 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
250 static constexpr const char *pyClassName = "AffineMulExpr";
251 using PyConcreteAffineExpr::PyConcreteAffineExpr;
253 static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
254 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
255 return PyAffineMulExpr(lhs.getContext(), expr);
258 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
259 MlirAffineExpr expr = mlirAffineMulExprGet(
260 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
261 return PyAffineMulExpr(lhs.getContext(), expr);
264 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
265 MlirAffineExpr expr = mlirAffineMulExprGet(
266 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
267 return PyAffineMulExpr(rhs.getContext(), expr);
270 static void bindDerived(ClassTy &c) {
271 c.def_static("get", &PyAffineMulExpr::get);
275 class PyAffineModExpr
276 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
277 public:
278 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
279 static constexpr const char *pyClassName = "AffineModExpr";
280 using PyConcreteAffineExpr::PyConcreteAffineExpr;
282 static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
283 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
284 return PyAffineModExpr(lhs.getContext(), expr);
287 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
288 MlirAffineExpr expr = mlirAffineModExprGet(
289 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
290 return PyAffineModExpr(lhs.getContext(), expr);
293 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
294 MlirAffineExpr expr = mlirAffineModExprGet(
295 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
296 return PyAffineModExpr(rhs.getContext(), expr);
299 static void bindDerived(ClassTy &c) {
300 c.def_static("get", &PyAffineModExpr::get);
304 class PyAffineFloorDivExpr
305 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
306 public:
307 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
308 static constexpr const char *pyClassName = "AffineFloorDivExpr";
309 using PyConcreteAffineExpr::PyConcreteAffineExpr;
311 static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
312 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
313 return PyAffineFloorDivExpr(lhs.getContext(), expr);
316 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
317 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
318 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
319 return PyAffineFloorDivExpr(lhs.getContext(), expr);
322 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
323 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
324 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
325 return PyAffineFloorDivExpr(rhs.getContext(), expr);
328 static void bindDerived(ClassTy &c) {
329 c.def_static("get", &PyAffineFloorDivExpr::get);
333 class PyAffineCeilDivExpr
334 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
335 public:
336 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
337 static constexpr const char *pyClassName = "AffineCeilDivExpr";
338 using PyConcreteAffineExpr::PyConcreteAffineExpr;
340 static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
341 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
342 return PyAffineCeilDivExpr(lhs.getContext(), expr);
345 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
346 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
347 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
348 return PyAffineCeilDivExpr(lhs.getContext(), expr);
351 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
352 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
353 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
354 return PyAffineCeilDivExpr(rhs.getContext(), expr);
357 static void bindDerived(ClassTy &c) {
358 c.def_static("get", &PyAffineCeilDivExpr::get);
362 } // namespace
364 bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
365 return mlirAffineExprEqual(affineExpr, other.affineExpr);
368 py::object PyAffineExpr::getCapsule() {
369 return py::reinterpret_steal<py::object>(
370 mlirPythonAffineExprToCapsule(*this));
373 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
374 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
375 if (mlirAffineExprIsNull(rawAffineExpr))
376 throw py::error_already_set();
377 return PyAffineExpr(
378 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
379 rawAffineExpr);
382 //------------------------------------------------------------------------------
383 // PyAffineMap and utilities.
384 //------------------------------------------------------------------------------
385 namespace {
387 /// A list of expressions contained in an affine map. Internally these are
388 /// stored as a consecutive array leading to inexpensive random access. Both
389 /// the map and the expression are owned by the context so we need not bother
390 /// with lifetime extension.
391 class PyAffineMapExprList
392 : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
393 public:
394 static constexpr const char *pyClassName = "AffineExprList";
396 PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
397 intptr_t length = -1, intptr_t step = 1)
398 : Sliceable(startIndex,
399 length == -1 ? mlirAffineMapGetNumResults(map) : length,
400 step),
401 affineMap(map) {}
403 private:
404 /// Give the parent CRTP class access to hook implementations below.
405 friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
407 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
409 PyAffineExpr getRawElement(intptr_t pos) {
410 return PyAffineExpr(affineMap.getContext(),
411 mlirAffineMapGetResult(affineMap, pos));
414 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
415 intptr_t step) {
416 return PyAffineMapExprList(affineMap, startIndex, length, step);
419 PyAffineMap affineMap;
421 } // namespace
423 bool PyAffineMap::operator==(const PyAffineMap &other) const {
424 return mlirAffineMapEqual(affineMap, other.affineMap);
427 py::object PyAffineMap::getCapsule() {
428 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
431 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
432 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
433 if (mlirAffineMapIsNull(rawAffineMap))
434 throw py::error_already_set();
435 return PyAffineMap(
436 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
437 rawAffineMap);
440 //------------------------------------------------------------------------------
441 // PyIntegerSet and utilities.
442 //------------------------------------------------------------------------------
443 namespace {
445 class PyIntegerSetConstraint {
446 public:
447 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
448 : set(std::move(set)), pos(pos) {}
450 PyAffineExpr getExpr() {
451 return PyAffineExpr(set.getContext(),
452 mlirIntegerSetGetConstraint(set, pos));
455 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
457 static void bind(py::module &m) {
458 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
459 py::module_local())
460 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
461 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
464 private:
465 PyIntegerSet set;
466 intptr_t pos;
469 class PyIntegerSetConstraintList
470 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
471 public:
472 static constexpr const char *pyClassName = "IntegerSetConstraintList";
474 PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
475 intptr_t length = -1, intptr_t step = 1)
476 : Sliceable(startIndex,
477 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
478 step),
479 set(set) {}
481 private:
482 /// Give the parent CRTP class access to hook implementations below.
483 friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
485 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
487 PyIntegerSetConstraint getRawElement(intptr_t pos) {
488 return PyIntegerSetConstraint(set, pos);
491 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
492 intptr_t step) {
493 return PyIntegerSetConstraintList(set, startIndex, length, step);
496 PyIntegerSet set;
498 } // namespace
500 bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
501 return mlirIntegerSetEqual(integerSet, other.integerSet);
504 py::object PyIntegerSet::getCapsule() {
505 return py::reinterpret_steal<py::object>(
506 mlirPythonIntegerSetToCapsule(*this));
509 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
510 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
511 if (mlirIntegerSetIsNull(rawIntegerSet))
512 throw py::error_already_set();
513 return PyIntegerSet(
514 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
515 rawIntegerSet);
518 void mlir::python::populateIRAffine(py::module &m) {
519 //----------------------------------------------------------------------------
520 // Mapping of PyAffineExpr and derived classes.
521 //----------------------------------------------------------------------------
522 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
523 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
524 &PyAffineExpr::getCapsule)
525 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
526 .def("__add__", &PyAffineAddExpr::get)
527 .def("__add__", &PyAffineAddExpr::getRHSConstant)
528 .def("__radd__", &PyAffineAddExpr::getRHSConstant)
529 .def("__mul__", &PyAffineMulExpr::get)
530 .def("__mul__", &PyAffineMulExpr::getRHSConstant)
531 .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
532 .def("__mod__", &PyAffineModExpr::get)
533 .def("__mod__", &PyAffineModExpr::getRHSConstant)
534 .def("__rmod__",
535 [](PyAffineExpr &self, intptr_t other) {
536 return PyAffineModExpr::get(
537 PyAffineConstantExpr::get(other, *self.getContext().get()),
538 self);
540 .def("__sub__",
541 [](PyAffineExpr &self, PyAffineExpr &other) {
542 auto negOne =
543 PyAffineConstantExpr::get(-1, *self.getContext().get());
544 return PyAffineAddExpr::get(self,
545 PyAffineMulExpr::get(negOne, other));
547 .def("__sub__",
548 [](PyAffineExpr &self, intptr_t other) {
549 return PyAffineAddExpr::get(
550 self,
551 PyAffineConstantExpr::get(-other, *self.getContext().get()));
553 .def("__rsub__",
554 [](PyAffineExpr &self, intptr_t other) {
555 return PyAffineAddExpr::getLHSConstant(
556 other, PyAffineMulExpr::getLHSConstant(-1, self));
558 .def("__eq__", [](PyAffineExpr &self,
559 PyAffineExpr &other) { return self == other; })
560 .def("__eq__",
561 [](PyAffineExpr &self, py::object &other) { return false; })
562 .def("__str__",
563 [](PyAffineExpr &self) {
564 PyPrintAccumulator printAccum;
565 mlirAffineExprPrint(self, printAccum.getCallback(),
566 printAccum.getUserData());
567 return printAccum.join();
569 .def("__repr__",
570 [](PyAffineExpr &self) {
571 PyPrintAccumulator printAccum;
572 printAccum.parts.append("AffineExpr(");
573 mlirAffineExprPrint(self, printAccum.getCallback(),
574 printAccum.getUserData());
575 printAccum.parts.append(")");
576 return printAccum.join();
578 .def("__hash__",
579 [](PyAffineExpr &self) {
580 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
582 .def_property_readonly(
583 "context",
584 [](PyAffineExpr &self) { return self.getContext().getObject(); })
585 .def("compose",
586 [](PyAffineExpr &self, PyAffineMap &other) {
587 return PyAffineExpr(self.getContext(),
588 mlirAffineExprCompose(self, other));
590 .def_static(
591 "get_add", &PyAffineAddExpr::get,
592 "Gets an affine expression containing a sum of two expressions.")
593 .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
594 "Gets an affine expression containing a sum of a constant "
595 "and another expression.")
596 .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
597 "Gets an affine expression containing a sum of an expression "
598 "and a constant.")
599 .def_static(
600 "get_mul", &PyAffineMulExpr::get,
601 "Gets an affine expression containing a product of two expressions.")
602 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
603 "Gets an affine expression containing a product of a "
604 "constant and another expression.")
605 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
606 "Gets an affine expression containing a product of an "
607 "expression and a constant.")
608 .def_static("get_mod", &PyAffineModExpr::get,
609 "Gets an affine expression containing the modulo of dividing "
610 "one expression by another.")
611 .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
612 "Gets a semi-affine expression containing the modulo of "
613 "dividing a constant by an expression.")
614 .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
615 "Gets an affine expression containing the module of dividing"
616 "an expression by a constant.")
617 .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
618 "Gets an affine expression containing the rounded-down "
619 "result of dividing one expression by another.")
620 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
621 "Gets a semi-affine expression containing the rounded-down "
622 "result of dividing a constant by an expression.")
623 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
624 "Gets an affine expression containing the rounded-down "
625 "result of dividing an expression by a constant.")
626 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
627 "Gets an affine expression containing the rounded-up result "
628 "of dividing one expression by another.")
629 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
630 "Gets a semi-affine expression containing the rounded-up "
631 "result of dividing a constant by an expression.")
632 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
633 "Gets an affine expression containing the rounded-up result "
634 "of dividing an expression by a constant.")
635 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
636 py::arg("context") = py::none(),
637 "Gets a constant affine expression with the given value.")
638 .def_static(
639 "get_dim", &PyAffineDimExpr::get, py::arg("position"),
640 py::arg("context") = py::none(),
641 "Gets an affine expression of a dimension at the given position.")
642 .def_static(
643 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
644 py::arg("context") = py::none(),
645 "Gets an affine expression of a symbol at the given position.")
646 .def(
647 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
648 kDumpDocstring);
649 PyAffineConstantExpr::bind(m);
650 PyAffineDimExpr::bind(m);
651 PyAffineSymbolExpr::bind(m);
652 PyAffineBinaryExpr::bind(m);
653 PyAffineAddExpr::bind(m);
654 PyAffineMulExpr::bind(m);
655 PyAffineModExpr::bind(m);
656 PyAffineFloorDivExpr::bind(m);
657 PyAffineCeilDivExpr::bind(m);
659 //----------------------------------------------------------------------------
660 // Mapping of PyAffineMap.
661 //----------------------------------------------------------------------------
662 py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
663 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
664 &PyAffineMap::getCapsule)
665 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
666 .def("__eq__",
667 [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
668 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
669 .def("__str__",
670 [](PyAffineMap &self) {
671 PyPrintAccumulator printAccum;
672 mlirAffineMapPrint(self, printAccum.getCallback(),
673 printAccum.getUserData());
674 return printAccum.join();
676 .def("__repr__",
677 [](PyAffineMap &self) {
678 PyPrintAccumulator printAccum;
679 printAccum.parts.append("AffineMap(");
680 mlirAffineMapPrint(self, printAccum.getCallback(),
681 printAccum.getUserData());
682 printAccum.parts.append(")");
683 return printAccum.join();
685 .def("__hash__",
686 [](PyAffineMap &self) {
687 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
689 .def_static("compress_unused_symbols",
690 [](py::list affineMaps, DefaultingPyMlirContext context) {
691 SmallVector<MlirAffineMap> maps;
692 pyListToVector<PyAffineMap, MlirAffineMap>(
693 affineMaps, maps, "attempting to create an AffineMap");
694 std::vector<MlirAffineMap> compressed(affineMaps.size());
695 auto populate = [](void *result, intptr_t idx,
696 MlirAffineMap m) {
697 static_cast<MlirAffineMap *>(result)[idx] = (m);
699 mlirAffineMapCompressUnusedSymbols(
700 maps.data(), maps.size(), compressed.data(), populate);
701 std::vector<PyAffineMap> res;
702 res.reserve(compressed.size());
703 for (auto m : compressed)
704 res.emplace_back(context->getRef(), m);
705 return res;
707 .def_property_readonly(
708 "context",
709 [](PyAffineMap &self) { return self.getContext().getObject(); },
710 "Context that owns the Affine Map")
711 .def(
712 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
713 kDumpDocstring)
714 .def_static(
715 "get",
716 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
717 DefaultingPyMlirContext context) {
718 SmallVector<MlirAffineExpr> affineExprs;
719 pyListToVector<PyAffineExpr, MlirAffineExpr>(
720 exprs, affineExprs, "attempting to create an AffineMap");
721 MlirAffineMap map =
722 mlirAffineMapGet(context->get(), dimCount, symbolCount,
723 affineExprs.size(), affineExprs.data());
724 return PyAffineMap(context->getRef(), map);
726 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
727 py::arg("context") = py::none(),
728 "Gets a map with the given expressions as results.")
729 .def_static(
730 "get_constant",
731 [](intptr_t value, DefaultingPyMlirContext context) {
732 MlirAffineMap affineMap =
733 mlirAffineMapConstantGet(context->get(), value);
734 return PyAffineMap(context->getRef(), affineMap);
736 py::arg("value"), py::arg("context") = py::none(),
737 "Gets an affine map with a single constant result")
738 .def_static(
739 "get_empty",
740 [](DefaultingPyMlirContext context) {
741 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
742 return PyAffineMap(context->getRef(), affineMap);
744 py::arg("context") = py::none(), "Gets an empty affine map.")
745 .def_static(
746 "get_identity",
747 [](intptr_t nDims, DefaultingPyMlirContext context) {
748 MlirAffineMap affineMap =
749 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
750 return PyAffineMap(context->getRef(), affineMap);
752 py::arg("n_dims"), py::arg("context") = py::none(),
753 "Gets an identity map with the given number of dimensions.")
754 .def_static(
755 "get_minor_identity",
756 [](intptr_t nDims, intptr_t nResults,
757 DefaultingPyMlirContext context) {
758 MlirAffineMap affineMap =
759 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
760 return PyAffineMap(context->getRef(), affineMap);
762 py::arg("n_dims"), py::arg("n_results"),
763 py::arg("context") = py::none(),
764 "Gets a minor identity map with the given number of dimensions and "
765 "results.")
766 .def_static(
767 "get_permutation",
768 [](std::vector<unsigned> permutation,
769 DefaultingPyMlirContext context) {
770 if (!isPermutation(permutation))
771 throw py::cast_error("Invalid permutation when attempting to "
772 "create an AffineMap");
773 MlirAffineMap affineMap = mlirAffineMapPermutationGet(
774 context->get(), permutation.size(), permutation.data());
775 return PyAffineMap(context->getRef(), affineMap);
777 py::arg("permutation"), py::arg("context") = py::none(),
778 "Gets an affine map that permutes its inputs.")
779 .def(
780 "get_submap",
781 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
782 intptr_t numResults = mlirAffineMapGetNumResults(self);
783 for (intptr_t pos : resultPos) {
784 if (pos < 0 || pos >= numResults)
785 throw py::value_error("result position out of bounds");
787 MlirAffineMap affineMap = mlirAffineMapGetSubMap(
788 self, resultPos.size(), resultPos.data());
789 return PyAffineMap(self.getContext(), affineMap);
791 py::arg("result_positions"))
792 .def(
793 "get_major_submap",
794 [](PyAffineMap &self, intptr_t nResults) {
795 if (nResults >= mlirAffineMapGetNumResults(self))
796 throw py::value_error("number of results out of bounds");
797 MlirAffineMap affineMap =
798 mlirAffineMapGetMajorSubMap(self, nResults);
799 return PyAffineMap(self.getContext(), affineMap);
801 py::arg("n_results"))
802 .def(
803 "get_minor_submap",
804 [](PyAffineMap &self, intptr_t nResults) {
805 if (nResults >= mlirAffineMapGetNumResults(self))
806 throw py::value_error("number of results out of bounds");
807 MlirAffineMap affineMap =
808 mlirAffineMapGetMinorSubMap(self, nResults);
809 return PyAffineMap(self.getContext(), affineMap);
811 py::arg("n_results"))
812 .def(
813 "replace",
814 [](PyAffineMap &self, PyAffineExpr &expression,
815 PyAffineExpr &replacement, intptr_t numResultDims,
816 intptr_t numResultSyms) {
817 MlirAffineMap affineMap = mlirAffineMapReplace(
818 self, expression, replacement, numResultDims, numResultSyms);
819 return PyAffineMap(self.getContext(), affineMap);
821 py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
822 py::arg("n_result_syms"))
823 .def_property_readonly(
824 "is_permutation",
825 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
826 .def_property_readonly("is_projected_permutation",
827 [](PyAffineMap &self) {
828 return mlirAffineMapIsProjectedPermutation(self);
830 .def_property_readonly(
831 "n_dims",
832 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
833 .def_property_readonly(
834 "n_inputs",
835 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
836 .def_property_readonly(
837 "n_symbols",
838 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
839 .def_property_readonly("results", [](PyAffineMap &self) {
840 return PyAffineMapExprList(self);
842 PyAffineMapExprList::bind(m);
844 //----------------------------------------------------------------------------
845 // Mapping of PyIntegerSet.
846 //----------------------------------------------------------------------------
847 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
848 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
849 &PyIntegerSet::getCapsule)
850 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
851 .def("__eq__", [](PyIntegerSet &self,
852 PyIntegerSet &other) { return self == other; })
853 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
854 .def("__str__",
855 [](PyIntegerSet &self) {
856 PyPrintAccumulator printAccum;
857 mlirIntegerSetPrint(self, printAccum.getCallback(),
858 printAccum.getUserData());
859 return printAccum.join();
861 .def("__repr__",
862 [](PyIntegerSet &self) {
863 PyPrintAccumulator printAccum;
864 printAccum.parts.append("IntegerSet(");
865 mlirIntegerSetPrint(self, printAccum.getCallback(),
866 printAccum.getUserData());
867 printAccum.parts.append(")");
868 return printAccum.join();
870 .def("__hash__",
871 [](PyIntegerSet &self) {
872 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
874 .def_property_readonly(
875 "context",
876 [](PyIntegerSet &self) { return self.getContext().getObject(); })
877 .def(
878 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
879 kDumpDocstring)
880 .def_static(
881 "get",
882 [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
883 std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
884 if (exprs.size() != eqFlags.size())
885 throw py::value_error(
886 "Expected the number of constraints to match "
887 "that of equality flags");
888 if (exprs.empty())
889 throw py::value_error("Expected non-empty list of constraints");
891 // Copy over to a SmallVector because std::vector has a
892 // specialization for booleans that packs data and does not
893 // expose a `bool *`.
894 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
896 SmallVector<MlirAffineExpr> affineExprs;
897 pyListToVector<PyAffineExpr>(exprs, affineExprs,
898 "attempting to create an IntegerSet");
899 MlirIntegerSet set = mlirIntegerSetGet(
900 context->get(), numDims, numSymbols, exprs.size(),
901 affineExprs.data(), flags.data());
902 return PyIntegerSet(context->getRef(), set);
904 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
905 py::arg("eq_flags"), py::arg("context") = py::none())
906 .def_static(
907 "get_empty",
908 [](intptr_t numDims, intptr_t numSymbols,
909 DefaultingPyMlirContext context) {
910 MlirIntegerSet set =
911 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
912 return PyIntegerSet(context->getRef(), set);
914 py::arg("num_dims"), py::arg("num_symbols"),
915 py::arg("context") = py::none())
916 .def(
917 "get_replaced",
918 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
919 intptr_t numResultDims, intptr_t numResultSymbols) {
920 if (static_cast<intptr_t>(dimExprs.size()) !=
921 mlirIntegerSetGetNumDims(self))
922 throw py::value_error(
923 "Expected the number of dimension replacement expressions "
924 "to match that of dimensions");
925 if (static_cast<intptr_t>(symbolExprs.size()) !=
926 mlirIntegerSetGetNumSymbols(self))
927 throw py::value_error(
928 "Expected the number of symbol replacement expressions "
929 "to match that of symbols");
931 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
932 pyListToVector<PyAffineExpr>(
933 dimExprs, dimAffineExprs,
934 "attempting to create an IntegerSet by replacing dimensions");
935 pyListToVector<PyAffineExpr>(
936 symbolExprs, symbolAffineExprs,
937 "attempting to create an IntegerSet by replacing symbols");
938 MlirIntegerSet set = mlirIntegerSetReplaceGet(
939 self, dimAffineExprs.data(), symbolAffineExprs.data(),
940 numResultDims, numResultSymbols);
941 return PyIntegerSet(self.getContext(), set);
943 py::arg("dim_exprs"), py::arg("symbol_exprs"),
944 py::arg("num_result_dims"), py::arg("num_result_symbols"))
945 .def_property_readonly("is_canonical_empty",
946 [](PyIntegerSet &self) {
947 return mlirIntegerSetIsCanonicalEmpty(self);
949 .def_property_readonly(
950 "n_dims",
951 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
952 .def_property_readonly(
953 "n_symbols",
954 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
955 .def_property_readonly(
956 "n_inputs",
957 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
958 .def_property_readonly("n_equalities",
959 [](PyIntegerSet &self) {
960 return mlirIntegerSetGetNumEqualities(self);
962 .def_property_readonly("n_inequalities",
963 [](PyIntegerSet &self) {
964 return mlirIntegerSetGetNumInequalities(self);
966 .def_property_readonly("constraints", [](PyIntegerSet &self) {
967 return PyIntegerSetConstraintList(self);
969 PyIntegerSetConstraint::bind(m);
970 PyIntegerSetConstraintList::bind(m);