1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 #include <pybind11/cast.h>
12 #include <pybind11/detail/common.h>
13 #include <pybind11/pybind11.h>
14 #include <pybind11/pytypes.h>
21 #include "PybindUtils.h"
23 #include "mlir-c/AffineExpr.h"
24 #include "mlir-c/AffineMap.h"
25 #include "mlir-c/Bindings/Python/Interop.h"
26 #include "mlir-c/IntegerSet.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/Hashing.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/Twine.h"
33 namespace py
= pybind11
;
35 using namespace mlir::python
;
37 using llvm::SmallVector
;
38 using llvm::StringRef
;
/// Docstring text for methods that print a debug representation of an object
/// to stderr.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
44 /// Attempts to populate `result` with the content of `list` casted to the
45 /// appropriate type (Python and C types are provided as template arguments).
46 /// Throws errors in case of failure, using "action" to describe what the caller
47 /// was attempting to do.
48 template <typename PyType
, typename CType
>
49 static void pyListToVector(const py::list
&list
,
50 llvm::SmallVectorImpl
<CType
> &result
,
52 result
.reserve(py::len(list
));
53 for (py::handle item
: list
) {
55 result
.push_back(item
.cast
<PyType
>());
56 } catch (py::cast_error
&err
) {
57 std::string msg
= (llvm::Twine("Invalid expression when ") + action
+
58 " (" + err
.what() + ")")
60 throw py::cast_error(msg
);
61 } catch (py::reference_cast_error
&err
) {
62 std::string msg
= (llvm::Twine("Invalid expression (None?) when ") +
63 action
+ " (" + err
.what() + ")")
65 throw py::cast_error(msg
);
/// Returns true when `permutation` holds each of the values
/// 0 .. permutation.size() - 1 exactly once. An empty vector is accepted.
///
/// The vector is taken by const reference so that the check does not copy the
/// caller's data. Entries are compared as unsigned indices, so a negative
/// entry of a signed `PermutationTy` is out of range and makes the function
/// return false.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const size_t count = permutation.size();
  // One flag per admissible value; a flag that is already set means the value
  // occurs twice.
  std::vector<char> seen(count, 0);
  for (const PermutationTy &val : permutation) {
    const size_t index = static_cast<size_t>(val);
    if (index >= count || seen[index])
      return false;
    seen[index] = 1;
  }
  // `count` distinct in-range values necessarily cover the whole range.
  return true;
}
87 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
88 /// and should be castable from it. Intermediate hierarchy classes can be
89 /// modeled by specifying BaseTy.
// NOTE(review): this listing omits some source lines of the class (for
// example the end of castFrom and part of bind); the comments added below
// describe only the statements that are visible here.
90 template <typename DerivedTy
, typename BaseTy
= PyAffineExpr
>
91 class PyConcreteAffineExpr
: public BaseTy
{
93 // Derived classes must define statics for:
94 // IsAFunctionTy isaFunction
95 // const char *pyClassName
96 // and redefine bindDerived.
// ClassTy is the pybind11 class type registered for DerivedTy with BaseTy as
// its base; IsAFunctionTy is the signature of the kind predicate that each
// derived class supplies as `isaFunction`.
97 using ClassTy
= py::class_
<DerivedTy
, BaseTy
>;
98 using IsAFunctionTy
= bool (*)(MlirAffineExpr
);
100 PyConcreteAffineExpr() = default;
// Wraps an existing C API expression; both arguments are forwarded to the
// BaseTy constructor.
101 PyConcreteAffineExpr(PyMlirContextRef contextRef
, MlirAffineExpr affineExpr
)
102 : BaseTy(std::move(contextRef
), affineExpr
) {}
// Converting constructor: reuses the context of `orig` and obtains the
// expression through castFrom, which throws when `orig` is not of the kind
// accepted by DerivedTy::isaFunction.
103 PyConcreteAffineExpr(PyAffineExpr
&orig
)
104 : PyConcreteAffineExpr(orig
.getContext(), castFrom(orig
)) {}
/// Tests `orig` with DerivedTy::isaFunction. When the test fails, throws
/// py::value_error with a message that names DerivedTy::pyClassName and the
/// Python repr of `orig`.
106 static MlirAffineExpr
castFrom(PyAffineExpr
&orig
) {
107 if (!DerivedTy::isaFunction(orig
)) {
108 auto origRepr
= py::repr(py::cast(orig
)).cast
<std::string
>();
109 throw py::value_error((Twine("Cannot cast affine expression to ") +
110 DerivedTy::pyClassName
+ " (from " + origRepr
+
/// Registers DerivedTy in module `m` under DerivedTy::pyClassName as a
/// module-local class, exposes a constructor that accepts an existing affine
/// expression (keyword `expr`), and finally passes the class object to
/// DerivedTy::bindDerived.
117 static void bind(py::module
&m
) {
118 auto cls
= ClassTy(m
, DerivedTy::pyClassName
, py::module_local());
119 cls
.def(py::init
<PyAffineExpr
&>(), py::arg("expr"));
// Predicate that forwards to DerivedTy::isaFunction.
// NOTE(review): the call that registers this lambda with `cls` is not
// visible in this listing.
122 [](PyAffineExpr
&otherAffineExpr
) -> bool {
123 return DerivedTy::isaFunction(otherAffineExpr
);
126 DerivedTy::bindDerived(cls
);
129 /// Implemented by derived classes to add methods to the Python subclass.
// The default implementation adds nothing.
130 static void bindDerived(ClassTy
&m
) {}
133 class PyAffineConstantExpr
: public PyConcreteAffineExpr
<PyAffineConstantExpr
> {
135 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAConstant
;
136 static constexpr const char *pyClassName
= "AffineConstantExpr";
137 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
139 static PyAffineConstantExpr
get(intptr_t value
,
140 DefaultingPyMlirContext context
) {
141 MlirAffineExpr affineExpr
=
142 mlirAffineConstantExprGet(context
->get(), static_cast<int64_t>(value
));
143 return PyAffineConstantExpr(context
->getRef(), affineExpr
);
146 static void bindDerived(ClassTy
&c
) {
147 c
.def_static("get", &PyAffineConstantExpr::get
, py::arg("value"),
148 py::arg("context") = py::none());
149 c
.def_property_readonly("value", [](PyAffineConstantExpr
&self
) {
150 return mlirAffineConstantExprGetValue(self
);
155 class PyAffineDimExpr
: public PyConcreteAffineExpr
<PyAffineDimExpr
> {
157 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsADim
;
158 static constexpr const char *pyClassName
= "AffineDimExpr";
159 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
161 static PyAffineDimExpr
get(intptr_t pos
, DefaultingPyMlirContext context
) {
162 MlirAffineExpr affineExpr
= mlirAffineDimExprGet(context
->get(), pos
);
163 return PyAffineDimExpr(context
->getRef(), affineExpr
);
166 static void bindDerived(ClassTy
&c
) {
167 c
.def_static("get", &PyAffineDimExpr::get
, py::arg("position"),
168 py::arg("context") = py::none());
169 c
.def_property_readonly("position", [](PyAffineDimExpr
&self
) {
170 return mlirAffineDimExprGetPosition(self
);
175 class PyAffineSymbolExpr
: public PyConcreteAffineExpr
<PyAffineSymbolExpr
> {
177 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsASymbol
;
178 static constexpr const char *pyClassName
= "AffineSymbolExpr";
179 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
181 static PyAffineSymbolExpr
get(intptr_t pos
, DefaultingPyMlirContext context
) {
182 MlirAffineExpr affineExpr
= mlirAffineSymbolExprGet(context
->get(), pos
);
183 return PyAffineSymbolExpr(context
->getRef(), affineExpr
);
186 static void bindDerived(ClassTy
&c
) {
187 c
.def_static("get", &PyAffineSymbolExpr::get
, py::arg("position"),
188 py::arg("context") = py::none());
189 c
.def_property_readonly("position", [](PyAffineSymbolExpr
&self
) {
190 return mlirAffineSymbolExprGetPosition(self
);
195 class PyAffineBinaryExpr
: public PyConcreteAffineExpr
<PyAffineBinaryExpr
> {
197 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsABinary
;
198 static constexpr const char *pyClassName
= "AffineBinaryExpr";
199 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
202 MlirAffineExpr lhsExpr
= mlirAffineBinaryOpExprGetLHS(get());
203 return PyAffineExpr(getContext(), lhsExpr
);
207 MlirAffineExpr rhsExpr
= mlirAffineBinaryOpExprGetRHS(get());
208 return PyAffineExpr(getContext(), rhsExpr
);
211 static void bindDerived(ClassTy
&c
) {
212 c
.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs
);
213 c
.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs
);
217 class PyAffineAddExpr
218 : public PyConcreteAffineExpr
<PyAffineAddExpr
, PyAffineBinaryExpr
> {
220 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAAdd
;
221 static constexpr const char *pyClassName
= "AffineAddExpr";
222 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
224 static PyAffineAddExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
225 MlirAffineExpr expr
= mlirAffineAddExprGet(lhs
, rhs
);
226 return PyAffineAddExpr(lhs
.getContext(), expr
);
229 static PyAffineAddExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
230 MlirAffineExpr expr
= mlirAffineAddExprGet(
231 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
232 return PyAffineAddExpr(lhs
.getContext(), expr
);
235 static PyAffineAddExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
236 MlirAffineExpr expr
= mlirAffineAddExprGet(
237 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
238 return PyAffineAddExpr(rhs
.getContext(), expr
);
241 static void bindDerived(ClassTy
&c
) {
242 c
.def_static("get", &PyAffineAddExpr::get
);
246 class PyAffineMulExpr
247 : public PyConcreteAffineExpr
<PyAffineMulExpr
, PyAffineBinaryExpr
> {
249 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAMul
;
250 static constexpr const char *pyClassName
= "AffineMulExpr";
251 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
253 static PyAffineMulExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
254 MlirAffineExpr expr
= mlirAffineMulExprGet(lhs
, rhs
);
255 return PyAffineMulExpr(lhs
.getContext(), expr
);
258 static PyAffineMulExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
259 MlirAffineExpr expr
= mlirAffineMulExprGet(
260 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
261 return PyAffineMulExpr(lhs
.getContext(), expr
);
264 static PyAffineMulExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
265 MlirAffineExpr expr
= mlirAffineMulExprGet(
266 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
267 return PyAffineMulExpr(rhs
.getContext(), expr
);
270 static void bindDerived(ClassTy
&c
) {
271 c
.def_static("get", &PyAffineMulExpr::get
);
275 class PyAffineModExpr
276 : public PyConcreteAffineExpr
<PyAffineModExpr
, PyAffineBinaryExpr
> {
278 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAMod
;
279 static constexpr const char *pyClassName
= "AffineModExpr";
280 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
282 static PyAffineModExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
283 MlirAffineExpr expr
= mlirAffineModExprGet(lhs
, rhs
);
284 return PyAffineModExpr(lhs
.getContext(), expr
);
287 static PyAffineModExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
288 MlirAffineExpr expr
= mlirAffineModExprGet(
289 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
290 return PyAffineModExpr(lhs
.getContext(), expr
);
293 static PyAffineModExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
294 MlirAffineExpr expr
= mlirAffineModExprGet(
295 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
296 return PyAffineModExpr(rhs
.getContext(), expr
);
299 static void bindDerived(ClassTy
&c
) {
300 c
.def_static("get", &PyAffineModExpr::get
);
304 class PyAffineFloorDivExpr
305 : public PyConcreteAffineExpr
<PyAffineFloorDivExpr
, PyAffineBinaryExpr
> {
307 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAFloorDiv
;
308 static constexpr const char *pyClassName
= "AffineFloorDivExpr";
309 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
311 static PyAffineFloorDivExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
312 MlirAffineExpr expr
= mlirAffineFloorDivExprGet(lhs
, rhs
);
313 return PyAffineFloorDivExpr(lhs
.getContext(), expr
);
316 static PyAffineFloorDivExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
317 MlirAffineExpr expr
= mlirAffineFloorDivExprGet(
318 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
319 return PyAffineFloorDivExpr(lhs
.getContext(), expr
);
322 static PyAffineFloorDivExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
323 MlirAffineExpr expr
= mlirAffineFloorDivExprGet(
324 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
325 return PyAffineFloorDivExpr(rhs
.getContext(), expr
);
328 static void bindDerived(ClassTy
&c
) {
329 c
.def_static("get", &PyAffineFloorDivExpr::get
);
333 class PyAffineCeilDivExpr
334 : public PyConcreteAffineExpr
<PyAffineCeilDivExpr
, PyAffineBinaryExpr
> {
336 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsACeilDiv
;
337 static constexpr const char *pyClassName
= "AffineCeilDivExpr";
338 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
340 static PyAffineCeilDivExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
341 MlirAffineExpr expr
= mlirAffineCeilDivExprGet(lhs
, rhs
);
342 return PyAffineCeilDivExpr(lhs
.getContext(), expr
);
345 static PyAffineCeilDivExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
346 MlirAffineExpr expr
= mlirAffineCeilDivExprGet(
347 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
348 return PyAffineCeilDivExpr(lhs
.getContext(), expr
);
351 static PyAffineCeilDivExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
352 MlirAffineExpr expr
= mlirAffineCeilDivExprGet(
353 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
354 return PyAffineCeilDivExpr(rhs
.getContext(), expr
);
357 static void bindDerived(ClassTy
&c
) {
358 c
.def_static("get", &PyAffineCeilDivExpr::get
);
364 bool PyAffineExpr::operator==(const PyAffineExpr
&other
) const {
365 return mlirAffineExprEqual(affineExpr
, other
.affineExpr
);
368 py::object
PyAffineExpr::getCapsule() {
369 return py::reinterpret_steal
<py::object
>(
370 mlirPythonAffineExprToCapsule(*this));
373 PyAffineExpr
PyAffineExpr::createFromCapsule(py::object capsule
) {
374 MlirAffineExpr rawAffineExpr
= mlirPythonCapsuleToAffineExpr(capsule
.ptr());
375 if (mlirAffineExprIsNull(rawAffineExpr
))
376 throw py::error_already_set();
378 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr
)),
382 //------------------------------------------------------------------------------
383 // PyAffineMap and utilities.
384 //------------------------------------------------------------------------------
387 /// A list of expressions contained in an affine map. Internally these are
388 /// stored as a consecutive array leading to inexpensive random access. Both
389 /// the map and the expression are owned by the context so we need not bother
390 /// with lifetime extension.
391 class PyAffineMapExprList
392 : public Sliceable
<PyAffineMapExprList
, PyAffineExpr
> {
394 static constexpr const char *pyClassName
= "AffineExprList";
396 PyAffineMapExprList(const PyAffineMap
&map
, intptr_t startIndex
= 0,
397 intptr_t length
= -1, intptr_t step
= 1)
398 : Sliceable(startIndex
,
399 length
== -1 ? mlirAffineMapGetNumResults(map
) : length
,
404 /// Give the parent CRTP class access to hook implementations below.
405 friend class Sliceable
<PyAffineMapExprList
, PyAffineExpr
>;
407 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap
); }
409 PyAffineExpr
getRawElement(intptr_t pos
) {
410 return PyAffineExpr(affineMap
.getContext(),
411 mlirAffineMapGetResult(affineMap
, pos
));
414 PyAffineMapExprList
slice(intptr_t startIndex
, intptr_t length
,
416 return PyAffineMapExprList(affineMap
, startIndex
, length
, step
);
419 PyAffineMap affineMap
;
423 bool PyAffineMap::operator==(const PyAffineMap
&other
) const {
424 return mlirAffineMapEqual(affineMap
, other
.affineMap
);
427 py::object
PyAffineMap::getCapsule() {
428 return py::reinterpret_steal
<py::object
>(mlirPythonAffineMapToCapsule(*this));
431 PyAffineMap
PyAffineMap::createFromCapsule(py::object capsule
) {
432 MlirAffineMap rawAffineMap
= mlirPythonCapsuleToAffineMap(capsule
.ptr());
433 if (mlirAffineMapIsNull(rawAffineMap
))
434 throw py::error_already_set();
436 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap
)),
440 //------------------------------------------------------------------------------
441 // PyIntegerSet and utilities.
442 //------------------------------------------------------------------------------
445 class PyIntegerSetConstraint
{
447 PyIntegerSetConstraint(PyIntegerSet set
, intptr_t pos
)
448 : set(std::move(set
)), pos(pos
) {}
450 PyAffineExpr
getExpr() {
451 return PyAffineExpr(set
.getContext(),
452 mlirIntegerSetGetConstraint(set
, pos
));
455 bool isEq() { return mlirIntegerSetIsConstraintEq(set
, pos
); }
457 static void bind(py::module
&m
) {
458 py::class_
<PyIntegerSetConstraint
>(m
, "IntegerSetConstraint",
460 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr
)
461 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq
);
469 class PyIntegerSetConstraintList
470 : public Sliceable
<PyIntegerSetConstraintList
, PyIntegerSetConstraint
> {
472 static constexpr const char *pyClassName
= "IntegerSetConstraintList";
474 PyIntegerSetConstraintList(const PyIntegerSet
&set
, intptr_t startIndex
= 0,
475 intptr_t length
= -1, intptr_t step
= 1)
476 : Sliceable(startIndex
,
477 length
== -1 ? mlirIntegerSetGetNumConstraints(set
) : length
,
482 /// Give the parent CRTP class access to hook implementations below.
483 friend class Sliceable
<PyIntegerSetConstraintList
, PyIntegerSetConstraint
>;
485 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set
); }
487 PyIntegerSetConstraint
getRawElement(intptr_t pos
) {
488 return PyIntegerSetConstraint(set
, pos
);
491 PyIntegerSetConstraintList
slice(intptr_t startIndex
, intptr_t length
,
493 return PyIntegerSetConstraintList(set
, startIndex
, length
, step
);
500 bool PyIntegerSet::operator==(const PyIntegerSet
&other
) const {
501 return mlirIntegerSetEqual(integerSet
, other
.integerSet
);
504 py::object
PyIntegerSet::getCapsule() {
505 return py::reinterpret_steal
<py::object
>(
506 mlirPythonIntegerSetToCapsule(*this));
509 PyIntegerSet
PyIntegerSet::createFromCapsule(py::object capsule
) {
510 MlirIntegerSet rawIntegerSet
= mlirPythonCapsuleToIntegerSet(capsule
.ptr());
511 if (mlirIntegerSetIsNull(rawIntegerSet
))
512 throw py::error_already_set();
514 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet
)),
518 void mlir::python::populateIRAffine(py::module
&m
) {
519 //----------------------------------------------------------------------------
520 // Mapping of PyAffineExpr and derived classes.
521 //----------------------------------------------------------------------------
522 py::class_
<PyAffineExpr
>(m
, "AffineExpr", py::module_local())
523 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR
,
524 &PyAffineExpr::getCapsule
)
525 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyAffineExpr::createFromCapsule
)
526 .def("__add__", &PyAffineAddExpr::get
)
527 .def("__add__", &PyAffineAddExpr::getRHSConstant
)
528 .def("__radd__", &PyAffineAddExpr::getRHSConstant
)
529 .def("__mul__", &PyAffineMulExpr::get
)
530 .def("__mul__", &PyAffineMulExpr::getRHSConstant
)
531 .def("__rmul__", &PyAffineMulExpr::getRHSConstant
)
532 .def("__mod__", &PyAffineModExpr::get
)
533 .def("__mod__", &PyAffineModExpr::getRHSConstant
)
535 [](PyAffineExpr
&self
, intptr_t other
) {
536 return PyAffineModExpr::get(
537 PyAffineConstantExpr::get(other
, *self
.getContext().get()),
541 [](PyAffineExpr
&self
, PyAffineExpr
&other
) {
543 PyAffineConstantExpr::get(-1, *self
.getContext().get());
544 return PyAffineAddExpr::get(self
,
545 PyAffineMulExpr::get(negOne
, other
));
548 [](PyAffineExpr
&self
, intptr_t other
) {
549 return PyAffineAddExpr::get(
551 PyAffineConstantExpr::get(-other
, *self
.getContext().get()));
554 [](PyAffineExpr
&self
, intptr_t other
) {
555 return PyAffineAddExpr::getLHSConstant(
556 other
, PyAffineMulExpr::getLHSConstant(-1, self
));
558 .def("__eq__", [](PyAffineExpr
&self
,
559 PyAffineExpr
&other
) { return self
== other
; })
561 [](PyAffineExpr
&self
, py::object
&other
) { return false; })
563 [](PyAffineExpr
&self
) {
564 PyPrintAccumulator printAccum
;
565 mlirAffineExprPrint(self
, printAccum
.getCallback(),
566 printAccum
.getUserData());
567 return printAccum
.join();
570 [](PyAffineExpr
&self
) {
571 PyPrintAccumulator printAccum
;
572 printAccum
.parts
.append("AffineExpr(");
573 mlirAffineExprPrint(self
, printAccum
.getCallback(),
574 printAccum
.getUserData());
575 printAccum
.parts
.append(")");
576 return printAccum
.join();
579 [](PyAffineExpr
&self
) {
580 return static_cast<size_t>(llvm::hash_value(self
.get().ptr
));
582 .def_property_readonly(
584 [](PyAffineExpr
&self
) { return self
.getContext().getObject(); })
586 [](PyAffineExpr
&self
, PyAffineMap
&other
) {
587 return PyAffineExpr(self
.getContext(),
588 mlirAffineExprCompose(self
, other
));
591 "get_add", &PyAffineAddExpr::get
,
592 "Gets an affine expression containing a sum of two expressions.")
593 .def_static("get_add", &PyAffineAddExpr::getLHSConstant
,
594 "Gets an affine expression containing a sum of a constant "
595 "and another expression.")
596 .def_static("get_add", &PyAffineAddExpr::getRHSConstant
,
597 "Gets an affine expression containing a sum of an expression "
600 "get_mul", &PyAffineMulExpr::get
,
601 "Gets an affine expression containing a product of two expressions.")
602 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant
,
603 "Gets an affine expression containing a product of a "
604 "constant and another expression.")
605 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant
,
606 "Gets an affine expression containing a product of an "
607 "expression and a constant.")
608 .def_static("get_mod", &PyAffineModExpr::get
,
609 "Gets an affine expression containing the modulo of dividing "
610 "one expression by another.")
611 .def_static("get_mod", &PyAffineModExpr::getLHSConstant
,
612 "Gets a semi-affine expression containing the modulo of "
613 "dividing a constant by an expression.")
614 .def_static("get_mod", &PyAffineModExpr::getRHSConstant
,
615 "Gets an affine expression containing the module of dividing"
616 "an expression by a constant.")
617 .def_static("get_floor_div", &PyAffineFloorDivExpr::get
,
618 "Gets an affine expression containing the rounded-down "
619 "result of dividing one expression by another.")
620 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant
,
621 "Gets a semi-affine expression containing the rounded-down "
622 "result of dividing a constant by an expression.")
623 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant
,
624 "Gets an affine expression containing the rounded-down "
625 "result of dividing an expression by a constant.")
626 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get
,
627 "Gets an affine expression containing the rounded-up result "
628 "of dividing one expression by another.")
629 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant
,
630 "Gets a semi-affine expression containing the rounded-up "
631 "result of dividing a constant by an expression.")
632 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant
,
633 "Gets an affine expression containing the rounded-up result "
634 "of dividing an expression by a constant.")
635 .def_static("get_constant", &PyAffineConstantExpr::get
, py::arg("value"),
636 py::arg("context") = py::none(),
637 "Gets a constant affine expression with the given value.")
639 "get_dim", &PyAffineDimExpr::get
, py::arg("position"),
640 py::arg("context") = py::none(),
641 "Gets an affine expression of a dimension at the given position.")
643 "get_symbol", &PyAffineSymbolExpr::get
, py::arg("position"),
644 py::arg("context") = py::none(),
645 "Gets an affine expression of a symbol at the given position.")
647 "dump", [](PyAffineExpr
&self
) { mlirAffineExprDump(self
); },
649 PyAffineConstantExpr::bind(m
);
650 PyAffineDimExpr::bind(m
);
651 PyAffineSymbolExpr::bind(m
);
652 PyAffineBinaryExpr::bind(m
);
653 PyAffineAddExpr::bind(m
);
654 PyAffineMulExpr::bind(m
);
655 PyAffineModExpr::bind(m
);
656 PyAffineFloorDivExpr::bind(m
);
657 PyAffineCeilDivExpr::bind(m
);
659 //----------------------------------------------------------------------------
660 // Mapping of PyAffineMap.
661 //----------------------------------------------------------------------------
662 py::class_
<PyAffineMap
>(m
, "AffineMap", py::module_local())
663 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR
,
664 &PyAffineMap::getCapsule
)
665 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyAffineMap::createFromCapsule
)
667 [](PyAffineMap
&self
, PyAffineMap
&other
) { return self
== other
; })
668 .def("__eq__", [](PyAffineMap
&self
, py::object
&other
) { return false; })
670 [](PyAffineMap
&self
) {
671 PyPrintAccumulator printAccum
;
672 mlirAffineMapPrint(self
, printAccum
.getCallback(),
673 printAccum
.getUserData());
674 return printAccum
.join();
677 [](PyAffineMap
&self
) {
678 PyPrintAccumulator printAccum
;
679 printAccum
.parts
.append("AffineMap(");
680 mlirAffineMapPrint(self
, printAccum
.getCallback(),
681 printAccum
.getUserData());
682 printAccum
.parts
.append(")");
683 return printAccum
.join();
686 [](PyAffineMap
&self
) {
687 return static_cast<size_t>(llvm::hash_value(self
.get().ptr
));
689 .def_static("compress_unused_symbols",
690 [](py::list affineMaps
, DefaultingPyMlirContext context
) {
691 SmallVector
<MlirAffineMap
> maps
;
692 pyListToVector
<PyAffineMap
, MlirAffineMap
>(
693 affineMaps
, maps
, "attempting to create an AffineMap");
694 std::vector
<MlirAffineMap
> compressed(affineMaps
.size());
695 auto populate
= [](void *result
, intptr_t idx
,
697 static_cast<MlirAffineMap
*>(result
)[idx
] = (m
);
699 mlirAffineMapCompressUnusedSymbols(
700 maps
.data(), maps
.size(), compressed
.data(), populate
);
701 std::vector
<PyAffineMap
> res
;
702 res
.reserve(compressed
.size());
703 for (auto m
: compressed
)
704 res
.emplace_back(context
->getRef(), m
);
707 .def_property_readonly(
709 [](PyAffineMap
&self
) { return self
.getContext().getObject(); },
710 "Context that owns the Affine Map")
712 "dump", [](PyAffineMap
&self
) { mlirAffineMapDump(self
); },
716 [](intptr_t dimCount
, intptr_t symbolCount
, py::list exprs
,
717 DefaultingPyMlirContext context
) {
718 SmallVector
<MlirAffineExpr
> affineExprs
;
719 pyListToVector
<PyAffineExpr
, MlirAffineExpr
>(
720 exprs
, affineExprs
, "attempting to create an AffineMap");
722 mlirAffineMapGet(context
->get(), dimCount
, symbolCount
,
723 affineExprs
.size(), affineExprs
.data());
724 return PyAffineMap(context
->getRef(), map
);
726 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
727 py::arg("context") = py::none(),
728 "Gets a map with the given expressions as results.")
731 [](intptr_t value
, DefaultingPyMlirContext context
) {
732 MlirAffineMap affineMap
=
733 mlirAffineMapConstantGet(context
->get(), value
);
734 return PyAffineMap(context
->getRef(), affineMap
);
736 py::arg("value"), py::arg("context") = py::none(),
737 "Gets an affine map with a single constant result")
740 [](DefaultingPyMlirContext context
) {
741 MlirAffineMap affineMap
= mlirAffineMapEmptyGet(context
->get());
742 return PyAffineMap(context
->getRef(), affineMap
);
744 py::arg("context") = py::none(), "Gets an empty affine map.")
747 [](intptr_t nDims
, DefaultingPyMlirContext context
) {
748 MlirAffineMap affineMap
=
749 mlirAffineMapMultiDimIdentityGet(context
->get(), nDims
);
750 return PyAffineMap(context
->getRef(), affineMap
);
752 py::arg("n_dims"), py::arg("context") = py::none(),
753 "Gets an identity map with the given number of dimensions.")
755 "get_minor_identity",
756 [](intptr_t nDims
, intptr_t nResults
,
757 DefaultingPyMlirContext context
) {
758 MlirAffineMap affineMap
=
759 mlirAffineMapMinorIdentityGet(context
->get(), nDims
, nResults
);
760 return PyAffineMap(context
->getRef(), affineMap
);
762 py::arg("n_dims"), py::arg("n_results"),
763 py::arg("context") = py::none(),
764 "Gets a minor identity map with the given number of dimensions and "
768 [](std::vector
<unsigned> permutation
,
769 DefaultingPyMlirContext context
) {
770 if (!isPermutation(permutation
))
771 throw py::cast_error("Invalid permutation when attempting to "
772 "create an AffineMap");
773 MlirAffineMap affineMap
= mlirAffineMapPermutationGet(
774 context
->get(), permutation
.size(), permutation
.data());
775 return PyAffineMap(context
->getRef(), affineMap
);
777 py::arg("permutation"), py::arg("context") = py::none(),
778 "Gets an affine map that permutes its inputs.")
781 [](PyAffineMap
&self
, std::vector
<intptr_t> &resultPos
) {
782 intptr_t numResults
= mlirAffineMapGetNumResults(self
);
783 for (intptr_t pos
: resultPos
) {
784 if (pos
< 0 || pos
>= numResults
)
785 throw py::value_error("result position out of bounds");
787 MlirAffineMap affineMap
= mlirAffineMapGetSubMap(
788 self
, resultPos
.size(), resultPos
.data());
789 return PyAffineMap(self
.getContext(), affineMap
);
791 py::arg("result_positions"))
794 [](PyAffineMap
&self
, intptr_t nResults
) {
795 if (nResults
>= mlirAffineMapGetNumResults(self
))
796 throw py::value_error("number of results out of bounds");
797 MlirAffineMap affineMap
=
798 mlirAffineMapGetMajorSubMap(self
, nResults
);
799 return PyAffineMap(self
.getContext(), affineMap
);
801 py::arg("n_results"))
804 [](PyAffineMap
&self
, intptr_t nResults
) {
805 if (nResults
>= mlirAffineMapGetNumResults(self
))
806 throw py::value_error("number of results out of bounds");
807 MlirAffineMap affineMap
=
808 mlirAffineMapGetMinorSubMap(self
, nResults
);
809 return PyAffineMap(self
.getContext(), affineMap
);
811 py::arg("n_results"))
814 [](PyAffineMap
&self
, PyAffineExpr
&expression
,
815 PyAffineExpr
&replacement
, intptr_t numResultDims
,
816 intptr_t numResultSyms
) {
817 MlirAffineMap affineMap
= mlirAffineMapReplace(
818 self
, expression
, replacement
, numResultDims
, numResultSyms
);
819 return PyAffineMap(self
.getContext(), affineMap
);
821 py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
822 py::arg("n_result_syms"))
823 .def_property_readonly(
825 [](PyAffineMap
&self
) { return mlirAffineMapIsPermutation(self
); })
826 .def_property_readonly("is_projected_permutation",
827 [](PyAffineMap
&self
) {
828 return mlirAffineMapIsProjectedPermutation(self
);
830 .def_property_readonly(
832 [](PyAffineMap
&self
) { return mlirAffineMapGetNumDims(self
); })
833 .def_property_readonly(
835 [](PyAffineMap
&self
) { return mlirAffineMapGetNumInputs(self
); })
836 .def_property_readonly(
838 [](PyAffineMap
&self
) { return mlirAffineMapGetNumSymbols(self
); })
839 .def_property_readonly("results", [](PyAffineMap
&self
) {
840 return PyAffineMapExprList(self
);
842 PyAffineMapExprList::bind(m
);
844 //----------------------------------------------------------------------------
845 // Mapping of PyIntegerSet.
846 //----------------------------------------------------------------------------
847 py::class_
<PyIntegerSet
>(m
, "IntegerSet", py::module_local())
848 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR
,
849 &PyIntegerSet::getCapsule
)
850 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyIntegerSet::createFromCapsule
)
851 .def("__eq__", [](PyIntegerSet
&self
,
852 PyIntegerSet
&other
) { return self
== other
; })
853 .def("__eq__", [](PyIntegerSet
&self
, py::object other
) { return false; })
855 [](PyIntegerSet
&self
) {
856 PyPrintAccumulator printAccum
;
857 mlirIntegerSetPrint(self
, printAccum
.getCallback(),
858 printAccum
.getUserData());
859 return printAccum
.join();
862 [](PyIntegerSet
&self
) {
863 PyPrintAccumulator printAccum
;
864 printAccum
.parts
.append("IntegerSet(");
865 mlirIntegerSetPrint(self
, printAccum
.getCallback(),
866 printAccum
.getUserData());
867 printAccum
.parts
.append(")");
868 return printAccum
.join();
871 [](PyIntegerSet
&self
) {
872 return static_cast<size_t>(llvm::hash_value(self
.get().ptr
));
874 .def_property_readonly(
876 [](PyIntegerSet
&self
) { return self
.getContext().getObject(); })
878 "dump", [](PyIntegerSet
&self
) { mlirIntegerSetDump(self
); },
882 [](intptr_t numDims
, intptr_t numSymbols
, py::list exprs
,
883 std::vector
<bool> eqFlags
, DefaultingPyMlirContext context
) {
884 if (exprs
.size() != eqFlags
.size())
885 throw py::value_error(
886 "Expected the number of constraints to match "
887 "that of equality flags");
889 throw py::value_error("Expected non-empty list of constraints");
891 // Copy over to a SmallVector because std::vector has a
892 // specialization for booleans that packs data and does not
893 // expose a `bool *`.
894 SmallVector
<bool, 8> flags(eqFlags
.begin(), eqFlags
.end());
896 SmallVector
<MlirAffineExpr
> affineExprs
;
897 pyListToVector
<PyAffineExpr
>(exprs
, affineExprs
,
898 "attempting to create an IntegerSet");
899 MlirIntegerSet set
= mlirIntegerSetGet(
900 context
->get(), numDims
, numSymbols
, exprs
.size(),
901 affineExprs
.data(), flags
.data());
902 return PyIntegerSet(context
->getRef(), set
);
904 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
905 py::arg("eq_flags"), py::arg("context") = py::none())
908 [](intptr_t numDims
, intptr_t numSymbols
,
909 DefaultingPyMlirContext context
) {
911 mlirIntegerSetEmptyGet(context
->get(), numDims
, numSymbols
);
912 return PyIntegerSet(context
->getRef(), set
);
914 py::arg("num_dims"), py::arg("num_symbols"),
915 py::arg("context") = py::none())
918 [](PyIntegerSet
&self
, py::list dimExprs
, py::list symbolExprs
,
919 intptr_t numResultDims
, intptr_t numResultSymbols
) {
920 if (static_cast<intptr_t>(dimExprs
.size()) !=
921 mlirIntegerSetGetNumDims(self
))
922 throw py::value_error(
923 "Expected the number of dimension replacement expressions "
924 "to match that of dimensions");
925 if (static_cast<intptr_t>(symbolExprs
.size()) !=
926 mlirIntegerSetGetNumSymbols(self
))
927 throw py::value_error(
928 "Expected the number of symbol replacement expressions "
929 "to match that of symbols");
931 SmallVector
<MlirAffineExpr
> dimAffineExprs
, symbolAffineExprs
;
932 pyListToVector
<PyAffineExpr
>(
933 dimExprs
, dimAffineExprs
,
934 "attempting to create an IntegerSet by replacing dimensions");
935 pyListToVector
<PyAffineExpr
>(
936 symbolExprs
, symbolAffineExprs
,
937 "attempting to create an IntegerSet by replacing symbols");
938 MlirIntegerSet set
= mlirIntegerSetReplaceGet(
939 self
, dimAffineExprs
.data(), symbolAffineExprs
.data(),
940 numResultDims
, numResultSymbols
);
941 return PyIntegerSet(self
.getContext(), set
);
943 py::arg("dim_exprs"), py::arg("symbol_exprs"),
944 py::arg("num_result_dims"), py::arg("num_result_symbols"))
945 .def_property_readonly("is_canonical_empty",
946 [](PyIntegerSet
&self
) {
947 return mlirIntegerSetIsCanonicalEmpty(self
);
949 .def_property_readonly(
951 [](PyIntegerSet
&self
) { return mlirIntegerSetGetNumDims(self
); })
952 .def_property_readonly(
954 [](PyIntegerSet
&self
) { return mlirIntegerSetGetNumSymbols(self
); })
955 .def_property_readonly(
957 [](PyIntegerSet
&self
) { return mlirIntegerSetGetNumInputs(self
); })
958 .def_property_readonly("n_equalities",
959 [](PyIntegerSet
&self
) {
960 return mlirIntegerSetGetNumEqualities(self
);
962 .def_property_readonly("n_inequalities",
963 [](PyIntegerSet
&self
) {
964 return mlirIntegerSetGetNumInequalities(self
);
966 .def_property_readonly("constraints", [](PyIntegerSet
&self
) {
967 return PyIntegerSetConstraintList(self
);
969 PyIntegerSetConstraint::bind(m
);
970 PyIntegerSetConstraintList::bind(m
);