1 //===- StorageBase.cpp - TACO-flavored sparse tensor representation -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
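// This file contains method definitions for `SparseTensorStorageBase`.
// In particular we want to ensure that the default implementations of
// the "partial method specialization" trick aren't inline (since there's
// no benefit).  Defining them out of line also helps avoid weak vtables:
// <https://llvm.org/docs/CodingStandards.html#provide-a-virtual-method-anchor-for-classes-in-headers>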
14 //===----------------------------------------------------------------------===//
16 #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
18 using namespace mlir::sparse_tensor
;
20 SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
21 uint64_t dimRank
, const uint64_t *dimSizes
, uint64_t lvlRank
,
22 const uint64_t *lvlSizes
, const DimLevelType
*lvlTypes
,
23 const uint64_t *dim2lvl
, const uint64_t *lvl2dim
)
24 : dimSizes(dimSizes
, dimSizes
+ dimRank
),
25 lvlSizes(lvlSizes
, lvlSizes
+ lvlRank
),
26 lvlTypes(lvlTypes
, lvlTypes
+ lvlRank
),
27 dim2lvlVec(dim2lvl
, dim2lvl
+ lvlRank
),
28 lvl2dimVec(lvl2dim
, lvl2dim
+ dimRank
),
29 map(dimRank
, lvlRank
, dim2lvlVec
.data(), lvl2dimVec
.data()) {
30 assert(dimSizes
&& lvlSizes
&& lvlTypes
&& dim2lvl
&& lvl2dim
);
31 // Validate dim-indexed parameters.
32 assert(dimRank
> 0 && "Trivial shape is unsupported");
33 for (uint64_t d
= 0; d
< dimRank
; ++d
)
34 assert(dimSizes
[d
] > 0 && "Dimension size zero has trivial storage");
35 // Validate lvl-indexed parameters.
36 assert(lvlRank
> 0 && "Trivial shape is unsupported");
37 for (uint64_t l
= 0; l
< lvlRank
; ++l
) {
38 assert(lvlSizes
[l
] > 0 && "Level size zero has trivial storage");
39 assert(isDenseLvl(l
) || isCompressedLvl(l
) || isLooseCompressedLvl(l
) ||
40 isSingletonLvl(l
) || is2OutOf4Lvl(l
));
// Helper macro for wrong "partial method specialization" errors: reports a
// "<P,I,V> type mismatch" for the named method via MLIR_SPARSETENSOR_FATAL.
//
// `NAME` must be a string literal (or several adjacent literals, as every
// caller in this file passes). It is concatenated onto the message rather
// than stringized. With a PNAME of 64, for example, the report reads
// `getPositions64` instead of `"getPositions" "64"` with embedded quotes.
// Callers supply the terminating semicolon.
#define FATAL_PIV(NAME)                                                        \
  MLIR_SPARSETENSOR_FATAL("<P,I,V> type mismatch for: " NAME)
48 #define IMPL_GETPOSITIONS(PNAME, P) \
49 void SparseTensorStorageBase::getPositions(std::vector<P> **, uint64_t) { \
50 FATAL_PIV("getPositions" #PNAME); \
52 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS
)
53 #undef IMPL_GETPOSITIONS
55 #define IMPL_GETCOORDINATES(CNAME, C) \
56 void SparseTensorStorageBase::getCoordinates(std::vector<C> **, uint64_t) { \
57 FATAL_PIV("getCoordinates" #CNAME); \
59 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES
)
60 #undef IMPL_GETCOORDINATES
62 #define IMPL_GETVALUES(VNAME, V) \
63 void SparseTensorStorageBase::getValues(std::vector<V> **) { \
64 FATAL_PIV("getValues" #VNAME); \
66 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES
)
69 #define IMPL_FORWARDINGINSERT(VNAME, V) \
70 void SparseTensorStorageBase::forwardingInsert(const uint64_t *, V) { \
71 FATAL_PIV("forwardingInsert" #VNAME); \
73 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT
)
74 #undef IMPL_FORWARDINGINSERT
76 #define IMPL_LEXINSERT(VNAME, V) \
77 void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \
78 FATAL_PIV("lexInsert" #VNAME); \
80 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT
)
83 #define IMPL_EXPINSERT(VNAME, V) \
84 void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \
85 uint64_t, uint64_t) { \
86 FATAL_PIV("expInsert" #VNAME); \
88 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT
)