1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
17 #include "NanobindUtils.h"
18 #include "mlir-c/AffineExpr.h"
19 #include "mlir-c/AffineMap.h"
20 #include "mlir-c/IntegerSet.h"
21 #include "mlir/Bindings/Python/Nanobind.h"
22 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
23 #include "mlir/Support/LLVM.h"
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/Twine.h"
29 namespace nb
= nanobind
;
31 using namespace mlir::python
;
33 using llvm::SmallVector
;
34 using llvm::StringRef
;
37 static const char kDumpDocstring
[] =
38 R
"(Dumps a debug representation of the object to stderr.)";
40 /// Attempts to populate `result` with the content of `list` casted to the
41 /// appropriate type (Python and C types are provided as template arguments).
42 /// Throws errors in case of failure, using "action" to describe what the caller
43 /// was attempting to do.
44 template <typename PyType
, typename CType
>
45 static void pyListToVector(const nb::list
&list
,
46 llvm::SmallVectorImpl
<CType
> &result
,
48 result
.reserve(nb::len(list
));
49 for (nb::handle item
: list
) {
51 result
.push_back(nb::cast
<PyType
>(item
));
52 } catch (nb::cast_error
&err
) {
53 std::string msg
= (llvm::Twine("Invalid expression when ") + action
+
54 " (" + err
.what() + ")")
56 throw std::runtime_error(msg
.c_str());
57 } catch (std::runtime_error
&err
) {
58 std::string msg
= (llvm::Twine("Invalid expression (None?) when ") +
59 action
+ " (" + err
.what() + ")")
61 throw std::runtime_error(msg
.c_str());
66 template <typename PermutationTy
>
67 static bool isPermutation(std::vector
<PermutationTy
> permutation
) {
68 llvm::SmallVector
<bool, 8> seen(permutation
.size(), false);
69 for (auto val
: permutation
) {
70 if (val
< permutation
.size()) {
83 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
84 /// and should be castable from it. Intermediate hierarchy classes can be
85 /// modeled by specifying BaseTy.
86 template <typename DerivedTy
, typename BaseTy
= PyAffineExpr
>
87 class PyConcreteAffineExpr
: public BaseTy
{
89 // Derived classes must define statics for:
90 // IsAFunctionTy isaFunction
91 // const char *pyClassName
92 // and redefine bindDerived.
93 using ClassTy
= nb::class_
<DerivedTy
, BaseTy
>;
94 using IsAFunctionTy
= bool (*)(MlirAffineExpr
);
96 PyConcreteAffineExpr() = default;
97 PyConcreteAffineExpr(PyMlirContextRef contextRef
, MlirAffineExpr affineExpr
)
98 : BaseTy(std::move(contextRef
), affineExpr
) {}
99 PyConcreteAffineExpr(PyAffineExpr
&orig
)
100 : PyConcreteAffineExpr(orig
.getContext(), castFrom(orig
)) {}
102 static MlirAffineExpr
castFrom(PyAffineExpr
&orig
) {
103 if (!DerivedTy::isaFunction(orig
)) {
104 auto origRepr
= nb::cast
<std::string
>(nb::repr(nb::cast(orig
)));
105 throw nb::value_error((Twine("Cannot cast affine expression to ") +
106 DerivedTy::pyClassName
+ " (from " + origRepr
+
114 static void bind(nb::module_
&m
) {
115 auto cls
= ClassTy(m
, DerivedTy::pyClassName
);
116 cls
.def(nb::init
<PyAffineExpr
&>(), nb::arg("expr"));
119 [](PyAffineExpr
&otherAffineExpr
) -> bool {
120 return DerivedTy::isaFunction(otherAffineExpr
);
123 DerivedTy::bindDerived(cls
);
126 /// Implemented by derived classes to add methods to the Python subclass.
127 static void bindDerived(ClassTy
&m
) {}
130 class PyAffineConstantExpr
: public PyConcreteAffineExpr
<PyAffineConstantExpr
> {
132 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAConstant
;
133 static constexpr const char *pyClassName
= "AffineConstantExpr";
134 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
136 static PyAffineConstantExpr
get(intptr_t value
,
137 DefaultingPyMlirContext context
) {
138 MlirAffineExpr affineExpr
=
139 mlirAffineConstantExprGet(context
->get(), static_cast<int64_t>(value
));
140 return PyAffineConstantExpr(context
->getRef(), affineExpr
);
143 static void bindDerived(ClassTy
&c
) {
144 c
.def_static("get", &PyAffineConstantExpr::get
, nb::arg("value"),
145 nb::arg("context").none() = nb::none());
146 c
.def_prop_ro("value", [](PyAffineConstantExpr
&self
) {
147 return mlirAffineConstantExprGetValue(self
);
152 class PyAffineDimExpr
: public PyConcreteAffineExpr
<PyAffineDimExpr
> {
154 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsADim
;
155 static constexpr const char *pyClassName
= "AffineDimExpr";
156 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
158 static PyAffineDimExpr
get(intptr_t pos
, DefaultingPyMlirContext context
) {
159 MlirAffineExpr affineExpr
= mlirAffineDimExprGet(context
->get(), pos
);
160 return PyAffineDimExpr(context
->getRef(), affineExpr
);
163 static void bindDerived(ClassTy
&c
) {
164 c
.def_static("get", &PyAffineDimExpr::get
, nb::arg("position"),
165 nb::arg("context").none() = nb::none());
166 c
.def_prop_ro("position", [](PyAffineDimExpr
&self
) {
167 return mlirAffineDimExprGetPosition(self
);
172 class PyAffineSymbolExpr
: public PyConcreteAffineExpr
<PyAffineSymbolExpr
> {
174 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsASymbol
;
175 static constexpr const char *pyClassName
= "AffineSymbolExpr";
176 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
178 static PyAffineSymbolExpr
get(intptr_t pos
, DefaultingPyMlirContext context
) {
179 MlirAffineExpr affineExpr
= mlirAffineSymbolExprGet(context
->get(), pos
);
180 return PyAffineSymbolExpr(context
->getRef(), affineExpr
);
183 static void bindDerived(ClassTy
&c
) {
184 c
.def_static("get", &PyAffineSymbolExpr::get
, nb::arg("position"),
185 nb::arg("context").none() = nb::none());
186 c
.def_prop_ro("position", [](PyAffineSymbolExpr
&self
) {
187 return mlirAffineSymbolExprGetPosition(self
);
192 class PyAffineBinaryExpr
: public PyConcreteAffineExpr
<PyAffineBinaryExpr
> {
194 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsABinary
;
195 static constexpr const char *pyClassName
= "AffineBinaryExpr";
196 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
199 MlirAffineExpr lhsExpr
= mlirAffineBinaryOpExprGetLHS(get());
200 return PyAffineExpr(getContext(), lhsExpr
);
204 MlirAffineExpr rhsExpr
= mlirAffineBinaryOpExprGetRHS(get());
205 return PyAffineExpr(getContext(), rhsExpr
);
208 static void bindDerived(ClassTy
&c
) {
209 c
.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs
);
210 c
.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs
);
214 class PyAffineAddExpr
215 : public PyConcreteAffineExpr
<PyAffineAddExpr
, PyAffineBinaryExpr
> {
217 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAAdd
;
218 static constexpr const char *pyClassName
= "AffineAddExpr";
219 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
221 static PyAffineAddExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
222 MlirAffineExpr expr
= mlirAffineAddExprGet(lhs
, rhs
);
223 return PyAffineAddExpr(lhs
.getContext(), expr
);
226 static PyAffineAddExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
227 MlirAffineExpr expr
= mlirAffineAddExprGet(
228 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
229 return PyAffineAddExpr(lhs
.getContext(), expr
);
232 static PyAffineAddExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
233 MlirAffineExpr expr
= mlirAffineAddExprGet(
234 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
235 return PyAffineAddExpr(rhs
.getContext(), expr
);
238 static void bindDerived(ClassTy
&c
) {
239 c
.def_static("get", &PyAffineAddExpr::get
);
243 class PyAffineMulExpr
244 : public PyConcreteAffineExpr
<PyAffineMulExpr
, PyAffineBinaryExpr
> {
246 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAMul
;
247 static constexpr const char *pyClassName
= "AffineMulExpr";
248 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
250 static PyAffineMulExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
251 MlirAffineExpr expr
= mlirAffineMulExprGet(lhs
, rhs
);
252 return PyAffineMulExpr(lhs
.getContext(), expr
);
255 static PyAffineMulExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
256 MlirAffineExpr expr
= mlirAffineMulExprGet(
257 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
258 return PyAffineMulExpr(lhs
.getContext(), expr
);
261 static PyAffineMulExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
262 MlirAffineExpr expr
= mlirAffineMulExprGet(
263 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
264 return PyAffineMulExpr(rhs
.getContext(), expr
);
267 static void bindDerived(ClassTy
&c
) {
268 c
.def_static("get", &PyAffineMulExpr::get
);
272 class PyAffineModExpr
273 : public PyConcreteAffineExpr
<PyAffineModExpr
, PyAffineBinaryExpr
> {
275 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAMod
;
276 static constexpr const char *pyClassName
= "AffineModExpr";
277 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
279 static PyAffineModExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
280 MlirAffineExpr expr
= mlirAffineModExprGet(lhs
, rhs
);
281 return PyAffineModExpr(lhs
.getContext(), expr
);
284 static PyAffineModExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
285 MlirAffineExpr expr
= mlirAffineModExprGet(
286 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
287 return PyAffineModExpr(lhs
.getContext(), expr
);
290 static PyAffineModExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
291 MlirAffineExpr expr
= mlirAffineModExprGet(
292 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
293 return PyAffineModExpr(rhs
.getContext(), expr
);
296 static void bindDerived(ClassTy
&c
) {
297 c
.def_static("get", &PyAffineModExpr::get
);
301 class PyAffineFloorDivExpr
302 : public PyConcreteAffineExpr
<PyAffineFloorDivExpr
, PyAffineBinaryExpr
> {
304 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsAFloorDiv
;
305 static constexpr const char *pyClassName
= "AffineFloorDivExpr";
306 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
308 static PyAffineFloorDivExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
309 MlirAffineExpr expr
= mlirAffineFloorDivExprGet(lhs
, rhs
);
310 return PyAffineFloorDivExpr(lhs
.getContext(), expr
);
313 static PyAffineFloorDivExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
314 MlirAffineExpr expr
= mlirAffineFloorDivExprGet(
315 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
316 return PyAffineFloorDivExpr(lhs
.getContext(), expr
);
319 static PyAffineFloorDivExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
320 MlirAffineExpr expr
= mlirAffineFloorDivExprGet(
321 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
322 return PyAffineFloorDivExpr(rhs
.getContext(), expr
);
325 static void bindDerived(ClassTy
&c
) {
326 c
.def_static("get", &PyAffineFloorDivExpr::get
);
330 class PyAffineCeilDivExpr
331 : public PyConcreteAffineExpr
<PyAffineCeilDivExpr
, PyAffineBinaryExpr
> {
333 static constexpr IsAFunctionTy isaFunction
= mlirAffineExprIsACeilDiv
;
334 static constexpr const char *pyClassName
= "AffineCeilDivExpr";
335 using PyConcreteAffineExpr::PyConcreteAffineExpr
;
337 static PyAffineCeilDivExpr
get(PyAffineExpr lhs
, const PyAffineExpr
&rhs
) {
338 MlirAffineExpr expr
= mlirAffineCeilDivExprGet(lhs
, rhs
);
339 return PyAffineCeilDivExpr(lhs
.getContext(), expr
);
342 static PyAffineCeilDivExpr
getRHSConstant(PyAffineExpr lhs
, intptr_t rhs
) {
343 MlirAffineExpr expr
= mlirAffineCeilDivExprGet(
344 lhs
, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs
), rhs
));
345 return PyAffineCeilDivExpr(lhs
.getContext(), expr
);
348 static PyAffineCeilDivExpr
getLHSConstant(intptr_t lhs
, PyAffineExpr rhs
) {
349 MlirAffineExpr expr
= mlirAffineCeilDivExprGet(
350 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs
), lhs
), rhs
);
351 return PyAffineCeilDivExpr(rhs
.getContext(), expr
);
354 static void bindDerived(ClassTy
&c
) {
355 c
.def_static("get", &PyAffineCeilDivExpr::get
);
361 bool PyAffineExpr::operator==(const PyAffineExpr
&other
) const {
362 return mlirAffineExprEqual(affineExpr
, other
.affineExpr
);
365 nb::object
PyAffineExpr::getCapsule() {
366 return nb::steal
<nb::object
>(mlirPythonAffineExprToCapsule(*this));
369 PyAffineExpr
PyAffineExpr::createFromCapsule(nb::object capsule
) {
370 MlirAffineExpr rawAffineExpr
= mlirPythonCapsuleToAffineExpr(capsule
.ptr());
371 if (mlirAffineExprIsNull(rawAffineExpr
))
372 throw nb::python_error();
374 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr
)),
378 //------------------------------------------------------------------------------
379 // PyAffineMap and utilities.
380 //------------------------------------------------------------------------------
383 /// A list of expressions contained in an affine map. Internally these are
384 /// stored as a consecutive array leading to inexpensive random access. Both
385 /// the map and the expression are owned by the context so we need not bother
386 /// with lifetime extension.
387 class PyAffineMapExprList
388 : public Sliceable
<PyAffineMapExprList
, PyAffineExpr
> {
390 static constexpr const char *pyClassName
= "AffineExprList";
392 PyAffineMapExprList(const PyAffineMap
&map
, intptr_t startIndex
= 0,
393 intptr_t length
= -1, intptr_t step
= 1)
394 : Sliceable(startIndex
,
395 length
== -1 ? mlirAffineMapGetNumResults(map
) : length
,
400 /// Give the parent CRTP class access to hook implementations below.
401 friend class Sliceable
<PyAffineMapExprList
, PyAffineExpr
>;
403 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap
); }
405 PyAffineExpr
getRawElement(intptr_t pos
) {
406 return PyAffineExpr(affineMap
.getContext(),
407 mlirAffineMapGetResult(affineMap
, pos
));
410 PyAffineMapExprList
slice(intptr_t startIndex
, intptr_t length
,
412 return PyAffineMapExprList(affineMap
, startIndex
, length
, step
);
415 PyAffineMap affineMap
;
419 bool PyAffineMap::operator==(const PyAffineMap
&other
) const {
420 return mlirAffineMapEqual(affineMap
, other
.affineMap
);
423 nb::object
PyAffineMap::getCapsule() {
424 return nb::steal
<nb::object
>(mlirPythonAffineMapToCapsule(*this));
427 PyAffineMap
PyAffineMap::createFromCapsule(nb::object capsule
) {
428 MlirAffineMap rawAffineMap
= mlirPythonCapsuleToAffineMap(capsule
.ptr());
429 if (mlirAffineMapIsNull(rawAffineMap
))
430 throw nb::python_error();
432 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap
)),
436 //------------------------------------------------------------------------------
437 // PyIntegerSet and utilities.
438 //------------------------------------------------------------------------------
441 class PyIntegerSetConstraint
{
443 PyIntegerSetConstraint(PyIntegerSet set
, intptr_t pos
)
444 : set(std::move(set
)), pos(pos
) {}
446 PyAffineExpr
getExpr() {
447 return PyAffineExpr(set
.getContext(),
448 mlirIntegerSetGetConstraint(set
, pos
));
451 bool isEq() { return mlirIntegerSetIsConstraintEq(set
, pos
); }
453 static void bind(nb::module_
&m
) {
454 nb::class_
<PyIntegerSetConstraint
>(m
, "IntegerSetConstraint")
455 .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr
)
456 .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq
);
464 class PyIntegerSetConstraintList
465 : public Sliceable
<PyIntegerSetConstraintList
, PyIntegerSetConstraint
> {
467 static constexpr const char *pyClassName
= "IntegerSetConstraintList";
469 PyIntegerSetConstraintList(const PyIntegerSet
&set
, intptr_t startIndex
= 0,
470 intptr_t length
= -1, intptr_t step
= 1)
471 : Sliceable(startIndex
,
472 length
== -1 ? mlirIntegerSetGetNumConstraints(set
) : length
,
477 /// Give the parent CRTP class access to hook implementations below.
478 friend class Sliceable
<PyIntegerSetConstraintList
, PyIntegerSetConstraint
>;
480 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set
); }
482 PyIntegerSetConstraint
getRawElement(intptr_t pos
) {
483 return PyIntegerSetConstraint(set
, pos
);
486 PyIntegerSetConstraintList
slice(intptr_t startIndex
, intptr_t length
,
488 return PyIntegerSetConstraintList(set
, startIndex
, length
, step
);
495 bool PyIntegerSet::operator==(const PyIntegerSet
&other
) const {
496 return mlirIntegerSetEqual(integerSet
, other
.integerSet
);
499 nb::object
PyIntegerSet::getCapsule() {
500 return nb::steal
<nb::object
>(mlirPythonIntegerSetToCapsule(*this));
503 PyIntegerSet
PyIntegerSet::createFromCapsule(nb::object capsule
) {
504 MlirIntegerSet rawIntegerSet
= mlirPythonCapsuleToIntegerSet(capsule
.ptr());
505 if (mlirIntegerSetIsNull(rawIntegerSet
))
506 throw nb::python_error();
508 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet
)),
512 void mlir::python::populateIRAffine(nb::module_
&m
) {
513 //----------------------------------------------------------------------------
514 // Mapping of PyAffineExpr and derived classes.
515 //----------------------------------------------------------------------------
516 nb::class_
<PyAffineExpr
>(m
, "AffineExpr")
517 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR
, &PyAffineExpr::getCapsule
)
518 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyAffineExpr::createFromCapsule
)
519 .def("__add__", &PyAffineAddExpr::get
)
520 .def("__add__", &PyAffineAddExpr::getRHSConstant
)
521 .def("__radd__", &PyAffineAddExpr::getRHSConstant
)
522 .def("__mul__", &PyAffineMulExpr::get
)
523 .def("__mul__", &PyAffineMulExpr::getRHSConstant
)
524 .def("__rmul__", &PyAffineMulExpr::getRHSConstant
)
525 .def("__mod__", &PyAffineModExpr::get
)
526 .def("__mod__", &PyAffineModExpr::getRHSConstant
)
528 [](PyAffineExpr
&self
, intptr_t other
) {
529 return PyAffineModExpr::get(
530 PyAffineConstantExpr::get(other
, *self
.getContext().get()),
534 [](PyAffineExpr
&self
, PyAffineExpr
&other
) {
536 PyAffineConstantExpr::get(-1, *self
.getContext().get());
537 return PyAffineAddExpr::get(self
,
538 PyAffineMulExpr::get(negOne
, other
));
541 [](PyAffineExpr
&self
, intptr_t other
) {
542 return PyAffineAddExpr::get(
544 PyAffineConstantExpr::get(-other
, *self
.getContext().get()));
547 [](PyAffineExpr
&self
, intptr_t other
) {
548 return PyAffineAddExpr::getLHSConstant(
549 other
, PyAffineMulExpr::getLHSConstant(-1, self
));
551 .def("__eq__", [](PyAffineExpr
&self
,
552 PyAffineExpr
&other
) { return self
== other
; })
554 [](PyAffineExpr
&self
, nb::object
&other
) { return false; })
556 [](PyAffineExpr
&self
) {
557 PyPrintAccumulator printAccum
;
558 mlirAffineExprPrint(self
, printAccum
.getCallback(),
559 printAccum
.getUserData());
560 return printAccum
.join();
563 [](PyAffineExpr
&self
) {
564 PyPrintAccumulator printAccum
;
565 printAccum
.parts
.append("AffineExpr(");
566 mlirAffineExprPrint(self
, printAccum
.getCallback(),
567 printAccum
.getUserData());
568 printAccum
.parts
.append(")");
569 return printAccum
.join();
572 [](PyAffineExpr
&self
) {
573 return static_cast<size_t>(llvm::hash_value(self
.get().ptr
));
577 [](PyAffineExpr
&self
) { return self
.getContext().getObject(); })
579 [](PyAffineExpr
&self
, PyAffineMap
&other
) {
580 return PyAffineExpr(self
.getContext(),
581 mlirAffineExprCompose(self
, other
));
584 "get_add", &PyAffineAddExpr::get
,
585 "Gets an affine expression containing a sum of two expressions.")
586 .def_static("get_add", &PyAffineAddExpr::getLHSConstant
,
587 "Gets an affine expression containing a sum of a constant "
588 "and another expression.")
589 .def_static("get_add", &PyAffineAddExpr::getRHSConstant
,
590 "Gets an affine expression containing a sum of an expression "
593 "get_mul", &PyAffineMulExpr::get
,
594 "Gets an affine expression containing a product of two expressions.")
595 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant
,
596 "Gets an affine expression containing a product of a "
597 "constant and another expression.")
598 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant
,
599 "Gets an affine expression containing a product of an "
600 "expression and a constant.")
601 .def_static("get_mod", &PyAffineModExpr::get
,
602 "Gets an affine expression containing the modulo of dividing "
603 "one expression by another.")
604 .def_static("get_mod", &PyAffineModExpr::getLHSConstant
,
605 "Gets a semi-affine expression containing the modulo of "
606 "dividing a constant by an expression.")
607 .def_static("get_mod", &PyAffineModExpr::getRHSConstant
,
608 "Gets an affine expression containing the module of dividing"
609 "an expression by a constant.")
610 .def_static("get_floor_div", &PyAffineFloorDivExpr::get
,
611 "Gets an affine expression containing the rounded-down "
612 "result of dividing one expression by another.")
613 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant
,
614 "Gets a semi-affine expression containing the rounded-down "
615 "result of dividing a constant by an expression.")
616 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant
,
617 "Gets an affine expression containing the rounded-down "
618 "result of dividing an expression by a constant.")
619 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get
,
620 "Gets an affine expression containing the rounded-up result "
621 "of dividing one expression by another.")
622 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant
,
623 "Gets a semi-affine expression containing the rounded-up "
624 "result of dividing a constant by an expression.")
625 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant
,
626 "Gets an affine expression containing the rounded-up result "
627 "of dividing an expression by a constant.")
628 .def_static("get_constant", &PyAffineConstantExpr::get
, nb::arg("value"),
629 nb::arg("context").none() = nb::none(),
630 "Gets a constant affine expression with the given value.")
632 "get_dim", &PyAffineDimExpr::get
, nb::arg("position"),
633 nb::arg("context").none() = nb::none(),
634 "Gets an affine expression of a dimension at the given position.")
636 "get_symbol", &PyAffineSymbolExpr::get
, nb::arg("position"),
637 nb::arg("context").none() = nb::none(),
638 "Gets an affine expression of a symbol at the given position.")
640 "dump", [](PyAffineExpr
&self
) { mlirAffineExprDump(self
); },
642 PyAffineConstantExpr::bind(m
);
643 PyAffineDimExpr::bind(m
);
644 PyAffineSymbolExpr::bind(m
);
645 PyAffineBinaryExpr::bind(m
);
646 PyAffineAddExpr::bind(m
);
647 PyAffineMulExpr::bind(m
);
648 PyAffineModExpr::bind(m
);
649 PyAffineFloorDivExpr::bind(m
);
650 PyAffineCeilDivExpr::bind(m
);
652 //----------------------------------------------------------------------------
653 // Mapping of PyAffineMap.
654 //----------------------------------------------------------------------------
655 nb::class_
<PyAffineMap
>(m
, "AffineMap")
656 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR
, &PyAffineMap::getCapsule
)
657 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyAffineMap::createFromCapsule
)
659 [](PyAffineMap
&self
, PyAffineMap
&other
) { return self
== other
; })
660 .def("__eq__", [](PyAffineMap
&self
, nb::object
&other
) { return false; })
662 [](PyAffineMap
&self
) {
663 PyPrintAccumulator printAccum
;
664 mlirAffineMapPrint(self
, printAccum
.getCallback(),
665 printAccum
.getUserData());
666 return printAccum
.join();
669 [](PyAffineMap
&self
) {
670 PyPrintAccumulator printAccum
;
671 printAccum
.parts
.append("AffineMap(");
672 mlirAffineMapPrint(self
, printAccum
.getCallback(),
673 printAccum
.getUserData());
674 printAccum
.parts
.append(")");
675 return printAccum
.join();
678 [](PyAffineMap
&self
) {
679 return static_cast<size_t>(llvm::hash_value(self
.get().ptr
));
681 .def_static("compress_unused_symbols",
682 [](nb::list affineMaps
, DefaultingPyMlirContext context
) {
683 SmallVector
<MlirAffineMap
> maps
;
684 pyListToVector
<PyAffineMap
, MlirAffineMap
>(
685 affineMaps
, maps
, "attempting to create an AffineMap");
686 std::vector
<MlirAffineMap
> compressed(affineMaps
.size());
687 auto populate
= [](void *result
, intptr_t idx
,
689 static_cast<MlirAffineMap
*>(result
)[idx
] = (m
);
691 mlirAffineMapCompressUnusedSymbols(
692 maps
.data(), maps
.size(), compressed
.data(), populate
);
693 std::vector
<PyAffineMap
> res
;
694 res
.reserve(compressed
.size());
695 for (auto m
: compressed
)
696 res
.emplace_back(context
->getRef(), m
);
701 [](PyAffineMap
&self
) { return self
.getContext().getObject(); },
702 "Context that owns the Affine Map")
704 "dump", [](PyAffineMap
&self
) { mlirAffineMapDump(self
); },
708 [](intptr_t dimCount
, intptr_t symbolCount
, nb::list exprs
,
709 DefaultingPyMlirContext context
) {
710 SmallVector
<MlirAffineExpr
> affineExprs
;
711 pyListToVector
<PyAffineExpr
, MlirAffineExpr
>(
712 exprs
, affineExprs
, "attempting to create an AffineMap");
714 mlirAffineMapGet(context
->get(), dimCount
, symbolCount
,
715 affineExprs
.size(), affineExprs
.data());
716 return PyAffineMap(context
->getRef(), map
);
718 nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"),
719 nb::arg("context").none() = nb::none(),
720 "Gets a map with the given expressions as results.")
723 [](intptr_t value
, DefaultingPyMlirContext context
) {
724 MlirAffineMap affineMap
=
725 mlirAffineMapConstantGet(context
->get(), value
);
726 return PyAffineMap(context
->getRef(), affineMap
);
728 nb::arg("value"), nb::arg("context").none() = nb::none(),
729 "Gets an affine map with a single constant result")
732 [](DefaultingPyMlirContext context
) {
733 MlirAffineMap affineMap
= mlirAffineMapEmptyGet(context
->get());
734 return PyAffineMap(context
->getRef(), affineMap
);
736 nb::arg("context").none() = nb::none(), "Gets an empty affine map.")
739 [](intptr_t nDims
, DefaultingPyMlirContext context
) {
740 MlirAffineMap affineMap
=
741 mlirAffineMapMultiDimIdentityGet(context
->get(), nDims
);
742 return PyAffineMap(context
->getRef(), affineMap
);
744 nb::arg("n_dims"), nb::arg("context").none() = nb::none(),
745 "Gets an identity map with the given number of dimensions.")
747 "get_minor_identity",
748 [](intptr_t nDims
, intptr_t nResults
,
749 DefaultingPyMlirContext context
) {
750 MlirAffineMap affineMap
=
751 mlirAffineMapMinorIdentityGet(context
->get(), nDims
, nResults
);
752 return PyAffineMap(context
->getRef(), affineMap
);
754 nb::arg("n_dims"), nb::arg("n_results"),
755 nb::arg("context").none() = nb::none(),
756 "Gets a minor identity map with the given number of dimensions and "
760 [](std::vector
<unsigned> permutation
,
761 DefaultingPyMlirContext context
) {
762 if (!isPermutation(permutation
))
763 throw std::runtime_error("Invalid permutation when attempting to "
764 "create an AffineMap");
765 MlirAffineMap affineMap
= mlirAffineMapPermutationGet(
766 context
->get(), permutation
.size(), permutation
.data());
767 return PyAffineMap(context
->getRef(), affineMap
);
769 nb::arg("permutation"), nb::arg("context").none() = nb::none(),
770 "Gets an affine map that permutes its inputs.")
773 [](PyAffineMap
&self
, std::vector
<intptr_t> &resultPos
) {
774 intptr_t numResults
= mlirAffineMapGetNumResults(self
);
775 for (intptr_t pos
: resultPos
) {
776 if (pos
< 0 || pos
>= numResults
)
777 throw nb::value_error("result position out of bounds");
779 MlirAffineMap affineMap
= mlirAffineMapGetSubMap(
780 self
, resultPos
.size(), resultPos
.data());
781 return PyAffineMap(self
.getContext(), affineMap
);
783 nb::arg("result_positions"))
786 [](PyAffineMap
&self
, intptr_t nResults
) {
787 if (nResults
>= mlirAffineMapGetNumResults(self
))
788 throw nb::value_error("number of results out of bounds");
789 MlirAffineMap affineMap
=
790 mlirAffineMapGetMajorSubMap(self
, nResults
);
791 return PyAffineMap(self
.getContext(), affineMap
);
793 nb::arg("n_results"))
796 [](PyAffineMap
&self
, intptr_t nResults
) {
797 if (nResults
>= mlirAffineMapGetNumResults(self
))
798 throw nb::value_error("number of results out of bounds");
799 MlirAffineMap affineMap
=
800 mlirAffineMapGetMinorSubMap(self
, nResults
);
801 return PyAffineMap(self
.getContext(), affineMap
);
803 nb::arg("n_results"))
806 [](PyAffineMap
&self
, PyAffineExpr
&expression
,
807 PyAffineExpr
&replacement
, intptr_t numResultDims
,
808 intptr_t numResultSyms
) {
809 MlirAffineMap affineMap
= mlirAffineMapReplace(
810 self
, expression
, replacement
, numResultDims
, numResultSyms
);
811 return PyAffineMap(self
.getContext(), affineMap
);
813 nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"),
814 nb::arg("n_result_syms"))
817 [](PyAffineMap
&self
) { return mlirAffineMapIsPermutation(self
); })
818 .def_prop_ro("is_projected_permutation",
819 [](PyAffineMap
&self
) {
820 return mlirAffineMapIsProjectedPermutation(self
);
824 [](PyAffineMap
&self
) { return mlirAffineMapGetNumDims(self
); })
827 [](PyAffineMap
&self
) { return mlirAffineMapGetNumInputs(self
); })
830 [](PyAffineMap
&self
) { return mlirAffineMapGetNumSymbols(self
); })
831 .def_prop_ro("results",
832 [](PyAffineMap
&self
) { return PyAffineMapExprList(self
); });
833 PyAffineMapExprList::bind(m
);
835 //----------------------------------------------------------------------------
836 // Mapping of PyIntegerSet.
837 //----------------------------------------------------------------------------
838 nb::class_
<PyIntegerSet
>(m
, "IntegerSet")
839 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR
, &PyIntegerSet::getCapsule
)
840 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyIntegerSet::createFromCapsule
)
841 .def("__eq__", [](PyIntegerSet
&self
,
842 PyIntegerSet
&other
) { return self
== other
; })
843 .def("__eq__", [](PyIntegerSet
&self
, nb::object other
) { return false; })
845 [](PyIntegerSet
&self
) {
846 PyPrintAccumulator printAccum
;
847 mlirIntegerSetPrint(self
, printAccum
.getCallback(),
848 printAccum
.getUserData());
849 return printAccum
.join();
852 [](PyIntegerSet
&self
) {
853 PyPrintAccumulator printAccum
;
854 printAccum
.parts
.append("IntegerSet(");
855 mlirIntegerSetPrint(self
, printAccum
.getCallback(),
856 printAccum
.getUserData());
857 printAccum
.parts
.append(")");
858 return printAccum
.join();
861 [](PyIntegerSet
&self
) {
862 return static_cast<size_t>(llvm::hash_value(self
.get().ptr
));
866 [](PyIntegerSet
&self
) { return self
.getContext().getObject(); })
868 "dump", [](PyIntegerSet
&self
) { mlirIntegerSetDump(self
); },
872 [](intptr_t numDims
, intptr_t numSymbols
, nb::list exprs
,
873 std::vector
<bool> eqFlags
, DefaultingPyMlirContext context
) {
874 if (exprs
.size() != eqFlags
.size())
875 throw nb::value_error(
876 "Expected the number of constraints to match "
877 "that of equality flags");
878 if (exprs
.size() == 0)
879 throw nb::value_error("Expected non-empty list of constraints");
881 // Copy over to a SmallVector because std::vector has a
882 // specialization for booleans that packs data and does not
883 // expose a `bool *`.
884 SmallVector
<bool, 8> flags(eqFlags
.begin(), eqFlags
.end());
886 SmallVector
<MlirAffineExpr
> affineExprs
;
887 pyListToVector
<PyAffineExpr
>(exprs
, affineExprs
,
888 "attempting to create an IntegerSet");
889 MlirIntegerSet set
= mlirIntegerSetGet(
890 context
->get(), numDims
, numSymbols
, exprs
.size(),
891 affineExprs
.data(), flags
.data());
892 return PyIntegerSet(context
->getRef(), set
);
894 nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"),
895 nb::arg("eq_flags"), nb::arg("context").none() = nb::none())
898 [](intptr_t numDims
, intptr_t numSymbols
,
899 DefaultingPyMlirContext context
) {
901 mlirIntegerSetEmptyGet(context
->get(), numDims
, numSymbols
);
902 return PyIntegerSet(context
->getRef(), set
);
904 nb::arg("num_dims"), nb::arg("num_symbols"),
905 nb::arg("context").none() = nb::none())
908 [](PyIntegerSet
&self
, nb::list dimExprs
, nb::list symbolExprs
,
909 intptr_t numResultDims
, intptr_t numResultSymbols
) {
910 if (static_cast<intptr_t>(dimExprs
.size()) !=
911 mlirIntegerSetGetNumDims(self
))
912 throw nb::value_error(
913 "Expected the number of dimension replacement expressions "
914 "to match that of dimensions");
915 if (static_cast<intptr_t>(symbolExprs
.size()) !=
916 mlirIntegerSetGetNumSymbols(self
))
917 throw nb::value_error(
918 "Expected the number of symbol replacement expressions "
919 "to match that of symbols");
921 SmallVector
<MlirAffineExpr
> dimAffineExprs
, symbolAffineExprs
;
922 pyListToVector
<PyAffineExpr
>(
923 dimExprs
, dimAffineExprs
,
924 "attempting to create an IntegerSet by replacing dimensions");
925 pyListToVector
<PyAffineExpr
>(
926 symbolExprs
, symbolAffineExprs
,
927 "attempting to create an IntegerSet by replacing symbols");
928 MlirIntegerSet set
= mlirIntegerSetReplaceGet(
929 self
, dimAffineExprs
.data(), symbolAffineExprs
.data(),
930 numResultDims
, numResultSymbols
);
931 return PyIntegerSet(self
.getContext(), set
);
933 nb::arg("dim_exprs"), nb::arg("symbol_exprs"),
934 nb::arg("num_result_dims"), nb::arg("num_result_symbols"))
935 .def_prop_ro("is_canonical_empty",
936 [](PyIntegerSet
&self
) {
937 return mlirIntegerSetIsCanonicalEmpty(self
);
941 [](PyIntegerSet
&self
) { return mlirIntegerSetGetNumDims(self
); })
944 [](PyIntegerSet
&self
) { return mlirIntegerSetGetNumSymbols(self
); })
947 [](PyIntegerSet
&self
) { return mlirIntegerSetGetNumInputs(self
); })
948 .def_prop_ro("n_equalities",
949 [](PyIntegerSet
&self
) {
950 return mlirIntegerSetGetNumEqualities(self
);
952 .def_prop_ro("n_inequalities",
953 [](PyIntegerSet
&self
) {
954 return mlirIntegerSetGetNumInequalities(self
);
956 .def_prop_ro("constraints", [](PyIntegerSet
&self
) {
957 return PyIntegerSetConstraintList(self
);
959 PyIntegerSetConstraint::bind(m
);
960 PyIntegerSetConstraintList::bind(m
);