# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the utilities to process sparse tensor outputs.

import ctypes
import functools
from typing import Tuple

import numpy as np


@functools.lru_cache()
def _get_c_shared_lib(lib_name: str) -> ctypes.CDLL:
    """Loads and returns the requested C shared library.

    Results are memoized per ``lib_name`` by ``functools.lru_cache``, so each
    distinct library is loaded and configured at most once (failed loads are
    not cached and will be retried on the next call).

    Args:
        lib_name: A string representing the C shared library.

    Returns:
        The loaded C shared library, with the ``restype`` of its
        ``convertToMLIRSparseTensor`` and ``convertFromMLIRSparseTensor``
        routines set to ``ctypes.c_void_p``.

    Raises:
        OSError: If there is any problem in loading the shared library.
        ValueError: If the shared library doesn't contain the needed routines.
    """
    # ctypes.CDLL raises OSError if there is any problem in loading the shared
    # library; that error is deliberately propagated unchanged.
    c_lib = ctypes.CDLL(lib_name)

    # ctypes assumes a C `int` return type by default, which would truncate a
    # returned pointer on 64-bit platforms, so declare the routines as returning
    # `void *`. Looking up a symbol the library does not export raises
    # AttributeError, which is translated into the documented ValueError.
    try:
        c_lib.convertToMLIRSparseTensor.restype = ctypes.c_void_p
    except AttributeError as e:
        raise ValueError("Missing function convertToMLIRSparseTensor from "
                         f"the supporting C shared library: {e} ") from e

    try:
        c_lib.convertFromMLIRSparseTensor.restype = ctypes.c_void_p
    except AttributeError as e:
        raise ValueError("Missing function convertFromMLIRSparseTensor from "
                         f"the C shared library: {e} ") from e

    return c_lib


def sparse_tensor_to_coo_tensor(
    lib_name: str,
    sparse_tensor: ctypes.c_void_p,
    dtype: np.dtype,
) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
    """Converts an MLIR sparse tensor to a COO-flavored format tensor.

    Args:
        lib_name: A string for the supporting C shared library.
        sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
        dtype: The numpy data type for the tensor elements.

    Returns:
        A tuple that contains the following values for the COO-flavored format
        tensor:
        rank: The rank of the tensor, as a ctypes.c_ulonglong (read ``.value``
          for a Python integer).
        nse: The number of non-zero values in the tensor, as a
          ctypes.c_ulonglong (read ``.value`` for a Python integer).
        shape: A 1D numpy array of integers, for the shape of the tensor.
        values: A 1D numpy array of ``dtype``, for the non-zero values in the
          tensor.
        indices: A 2D numpy array of integers with shape (nse, rank),
          representing the indices for the non-zero values in the tensor.

        The arrays are built with ``np.ctypeslib.as_array`` directly over the
        buffers produced by the C routine, i.e. without copying.
        NOTE(review): buffer ownership/lifetime is presumably managed by the C
        library — confirm before holding on to these arrays long-term.

    Raises:
        OSError: If there is any problem in loading the shared library.
        ValueError: If the shared library doesn't contain the needed routines.
    """
    c_lib = _get_c_shared_lib(lib_name)

    # Out-parameters passed by reference to convertFromMLIRSparseTensor and
    # read back afterwards: two scalar counts and three array pointers.
    rank = ctypes.c_ulonglong(0)
    nse = ctypes.c_ulonglong(0)
    shape_ptr = ctypes.POINTER(ctypes.c_ulonglong)()
    values_ptr = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
    indices_ptr = ctypes.POINTER(ctypes.c_ulonglong)()
    c_lib.convertFromMLIRSparseTensor(sparse_tensor, ctypes.byref(rank),
                                      ctypes.byref(nse),
                                      ctypes.byref(shape_ptr),
                                      ctypes.byref(values_ptr),
                                      ctypes.byref(indices_ptr))

    # Convert the returned values to the corresponding numpy types.
    shape = np.ctypeslib.as_array(shape_ptr, shape=[rank.value])
    values = np.ctypeslib.as_array(values_ptr, shape=[nse.value])
    indices = np.ctypeslib.as_array(indices_ptr,
                                    shape=[nse.value, rank.value])

    # NOTE(review): the return annotation promises `int` for rank/nse, but the
    # ctypes.c_ulonglong objects are returned as before so existing callers that
    # read `.value` keep working — confirm callers before switching to ints.
    return rank, nse, shape, values, indices


90 def coo_tensor_to_sparse_tensor(lib_name
: str, np_shape
: np
.ndarray
,
91 np_values
: np
.ndarray
,
92 np_indices
: np
.ndarray
) -> int:
93 """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
96 lib_name: A string for the supporting C shared library.
97 np_shape: A 1D numpy array of integers, for the shape of the tensor.
98 np_values: A 1D numpy array, for the non-zero values in the tensor.
99 np_indices: A 2D numpy array of integers, representing the indices for the
100 non-zero values in the tensor.
103 An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
107 OSError: If there is any problem in loading the shared library.
108 ValueError: If the shared library doesn't contain the needed routines.
111 rank
= ctypes
.c_ulonglong(len(np_shape
))
112 nse
= ctypes
.c_ulonglong(len(np_values
))
113 shape
= np_shape
.ctypes
.data_as(ctypes
.POINTER(ctypes
.c_ulonglong
))
114 values
= np_values
.ctypes
.data_as(
115 ctypes
.POINTER(np
.ctypeslib
.as_ctypes_type(np_values
.dtype
)))
116 indices
= np_indices
.ctypes
.data_as(ctypes
.POINTER(ctypes
.c_ulonglong
))
118 c_lib
= _get_c_shared_lib(lib_name
)
119 ptr
= c_lib
.convertToMLIRSparseTensor(rank
, nse
, shape
, values
, indices
)
120 assert ptr
is not None, "Problem with calling convertToMLIRSparseTensor"