1 //===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir-c/AffineMap.h"
10 #include "mlir-c/Dialect/SparseTensor.h"
11 #include "mlir-c/IR.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
14 #include <pybind11/cast.h>
15 #include <pybind11/detail/common.h>
16 #include <pybind11/pybind11.h>
17 #include <pybind11/pytypes.h>
20 namespace py
= pybind11
;
23 using namespace mlir::python::adaptors
;
25 static void populateDialectSparseTensorSubmodule(const py::module
&m
) {
26 py::enum_
<MlirSparseTensorLevelFormat
>(m
, "LevelFormat", py::module_local())
27 .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE
)
28 .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M
)
29 .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED
)
30 .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON
)
31 .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED
);
33 py::enum_
<MlirSparseTensorLevelPropertyNondefault
>(m
, "LevelProperty",
35 .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED
)
36 .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE
)
37 .value("soa", MLIR_SPARSE_PROPERTY_SOA
);
39 mlir_attribute_subclass(m
, "EncodingAttr",
40 mlirAttributeIsASparseTensorEncodingAttr
)
43 [](py::object cls
, std::vector
<MlirSparseTensorLevelType
> lvlTypes
,
44 std::optional
<MlirAffineMap
> dimToLvl
,
45 std::optional
<MlirAffineMap
> lvlToDim
, int posWidth
, int crdWidth
,
46 std::optional
<MlirAttribute
> explicitVal
,
47 std::optional
<MlirAttribute
> implicitVal
, MlirContext context
) {
48 return cls(mlirSparseTensorEncodingAttrGet(
49 context
, lvlTypes
.size(), lvlTypes
.data(),
50 dimToLvl
? *dimToLvl
: MlirAffineMap
{nullptr},
51 lvlToDim
? *lvlToDim
: MlirAffineMap
{nullptr}, posWidth
,
52 crdWidth
, explicitVal
? *explicitVal
: MlirAttribute
{nullptr},
53 implicitVal
? *implicitVal
: MlirAttribute
{nullptr}));
55 py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
56 py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
57 py::arg("explicit_val") = py::none(),
58 py::arg("implicit_val") = py::none(), py::arg("context") = py::none(),
59 "Gets a sparse_tensor.encoding from parameters.")
62 [](py::object cls
, MlirSparseTensorLevelFormat lvlFmt
,
63 const std::vector
<MlirSparseTensorLevelPropertyNondefault
>
65 unsigned n
, unsigned m
) {
66 return mlirSparseTensorEncodingAttrBuildLvlType(
67 lvlFmt
, properties
.data(), properties
.size(), n
, m
);
69 py::arg("cls"), py::arg("lvl_fmt"),
70 py::arg("properties") =
71 std::vector
<MlirSparseTensorLevelPropertyNondefault
>(),
72 py::arg("n") = 0, py::arg("m") = 0,
73 "Builds a sparse_tensor.encoding.level_type from parameters.")
74 .def_property_readonly(
76 [](MlirAttribute self
) {
77 const int lvlRank
= mlirSparseTensorEncodingGetLvlRank(self
);
78 std::vector
<MlirSparseTensorLevelType
> ret
;
80 for (int l
= 0; l
< lvlRank
; ++l
)
81 ret
.push_back(mlirSparseTensorEncodingAttrGetLvlType(self
, l
));
84 .def_property_readonly(
86 [](MlirAttribute self
) -> std::optional
<MlirAffineMap
> {
87 MlirAffineMap ret
= mlirSparseTensorEncodingAttrGetDimToLvl(self
);
88 if (mlirAffineMapIsNull(ret
))
92 .def_property_readonly(
94 [](MlirAttribute self
) -> std::optional
<MlirAffineMap
> {
95 MlirAffineMap ret
= mlirSparseTensorEncodingAttrGetLvlToDim(self
);
96 if (mlirAffineMapIsNull(ret
))
100 .def_property_readonly("pos_width",
101 mlirSparseTensorEncodingAttrGetPosWidth
)
102 .def_property_readonly("crd_width",
103 mlirSparseTensorEncodingAttrGetCrdWidth
)
104 .def_property_readonly(
106 [](MlirAttribute self
) -> std::optional
<MlirAttribute
> {
108 mlirSparseTensorEncodingAttrGetExplicitVal(self
);
109 if (mlirAttributeIsNull(ret
))
113 .def_property_readonly(
115 [](MlirAttribute self
) -> std::optional
<MlirAttribute
> {
117 mlirSparseTensorEncodingAttrGetImplicitVal(self
);
118 if (mlirAttributeIsNull(ret
))
122 .def_property_readonly(
124 [](MlirAttribute self
) -> unsigned {
125 const int lvlRank
= mlirSparseTensorEncodingGetLvlRank(self
);
126 return mlirSparseTensorEncodingAttrGetStructuredN(
127 mlirSparseTensorEncodingAttrGetLvlType(self
, lvlRank
- 1));
129 .def_property_readonly(
131 [](MlirAttribute self
) -> unsigned {
132 const int lvlRank
= mlirSparseTensorEncodingGetLvlRank(self
);
133 return mlirSparseTensorEncodingAttrGetStructuredM(
134 mlirSparseTensorEncodingAttrGetLvlType(self
, lvlRank
- 1));
136 .def_property_readonly("lvl_formats_enum", [](MlirAttribute self
) {
137 const int lvlRank
= mlirSparseTensorEncodingGetLvlRank(self
);
138 std::vector
<MlirSparseTensorLevelFormat
> ret
;
139 ret
.reserve(lvlRank
);
140 for (int l
= 0; l
< lvlRank
; l
++)
141 ret
.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self
, l
));
// Python extension module `_mlirDialectsSparseTensor`: sets the module
// docstring and registers the sparse_tensor dialect bindings on the module
// object via populateDialectSparseTensorSubmodule.
146 PYBIND11_MODULE(_mlirDialectsSparseTensor
, m
) {
147 m
.doc() = "MLIR SparseTensor dialect.";
148 populateDialectSparseTensorSubmodule(m
);