//===- StorageBase.cpp - TACO-flavored sparse tensor representation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains method definitions for `SparseTensorStorageBase`.
// In particular, we want to ensure that the default implementations used by
// the "partial method specialization" trick are not inline: they only report
// a type-mismatch error, so there is no benefit to inlining them.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"

#include <cstdio>
#include <cstdlib>

using namespace mlir::sparse_tensor;
20 static inline bool isAllDense(uint64_t lvlRank
, const LevelType
*lvlTypes
) {
21 for (uint64_t l
= 0; l
< lvlRank
; l
++)
22 if (!isDenseLT(lvlTypes
[l
]))
27 SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
28 uint64_t dimRank
, const uint64_t *dimSizes
, uint64_t lvlRank
,
29 const uint64_t *lvlSizes
, const LevelType
*lvlTypes
,
30 const uint64_t *dim2lvl
, const uint64_t *lvl2dim
)
31 : dimSizes(dimSizes
, dimSizes
+ dimRank
),
32 lvlSizes(lvlSizes
, lvlSizes
+ lvlRank
),
33 lvlTypes(lvlTypes
, lvlTypes
+ lvlRank
),
34 dim2lvlVec(dim2lvl
, dim2lvl
+ lvlRank
),
35 lvl2dimVec(lvl2dim
, lvl2dim
+ dimRank
),
36 map(dimRank
, lvlRank
, dim2lvlVec
.data(), lvl2dimVec
.data()),
37 allDense(isAllDense(lvlRank
, lvlTypes
)) {
38 assert(dimSizes
&& lvlSizes
&& lvlTypes
&& dim2lvl
&& lvl2dim
);
39 // Validate dim-indexed parameters.
40 assert(dimRank
> 0 && "Trivial shape is unsupported");
41 for (uint64_t d
= 0; d
< dimRank
; d
++)
42 assert(dimSizes
[d
] > 0 && "Dimension size zero has trivial storage");
43 // Validate lvl-indexed parameters.
44 assert(lvlRank
> 0 && "Trivial shape is unsupported");
45 for (uint64_t l
= 0; l
< lvlRank
; l
++) {
46 assert(lvlSizes
[l
] > 0 && "Level size zero has trivial storage");
47 assert(isDenseLvl(l
) || isCompressedLvl(l
) || isLooseCompressedLvl(l
) ||
48 isSingletonLvl(l
) || isNOutOfMLvl(l
));
// Helper macro for wrong "partial method specialization" errors: prints a
// <P,I,V> type-mismatch diagnostic to stderr and terminates the process.
// Terminating matters because the default stubs never fill in their
// out-parameters, so returning normally would leave the caller with
// unset data.
// NOTE(review): callers pass string literals, so `#NAME` re-stringifies them
// and the printed name includes their quote characters.
#define FATAL_PIV(NAME)                                                        \
  fprintf(stderr, "<P,I,V> type mismatch for: " #NAME "\n");                   \
  std::exit(1);
57 #define IMPL_GETPOSITIONS(PNAME, P) \
58 void SparseTensorStorageBase::getPositions(std::vector<P> **, uint64_t) { \
59 FATAL_PIV("getPositions" #PNAME); \
61 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS
)
62 #undef IMPL_GETPOSITIONS
64 #define IMPL_GETCOORDINATES(CNAME, C) \
65 void SparseTensorStorageBase::getCoordinates(std::vector<C> **, uint64_t) { \
66 FATAL_PIV("getCoordinates" #CNAME); \
68 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES
)
69 #undef IMPL_GETCOORDINATES
71 #define IMPL_GETCOORDINATESBUFFER(CNAME, C) \
72 void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **, \
74 FATAL_PIV("getCoordinatesBuffer" #CNAME); \
76 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER
)
77 #undef IMPL_GETCOORDINATESBUFFER
79 #define IMPL_GETVALUES(VNAME, V) \
80 void SparseTensorStorageBase::getValues(std::vector<V> **) { \
81 FATAL_PIV("getValues" #VNAME); \
83 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES
)
86 #define IMPL_LEXINSERT(VNAME, V) \
87 void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \
88 FATAL_PIV("lexInsert" #VNAME); \
90 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT
)
93 #define IMPL_EXPINSERT(VNAME, V) \
94 void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
95 uint64_t, uint64_t) { \
96 FATAL_PIV("expInsert" #VNAME); \
98 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT
)