//===- SparseTensorRuntime.cpp - SparseTensor runtime support lib ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library for
// manipulating sparse tensors from MLIR. More specifically, it provides
// C-API wrappers so that MLIR-generated code can call into the C++ runtime
// support library. The functionality provided in this library is meant
// to simplify benchmarking, testing, and debugging of MLIR code operating
// on sparse tensors. However, the provided functionality is **not**
// part of core MLIR itself.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by coordinate (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is visible
// externally only as an opaque pointer.
//
//===----------------------------------------------------------------------===//
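
// As a quick illustration of API (II), an external runtime that receives an
// opaque sparse tensor pointer from MLIR compiler-generated code could query
// and release it roughly as follows (an informal sketch only; the `tensor`
// handle and its rank are hypothetical):
//
//   index_type rows = sparseDimSize(tensor, 0);
//   index_type cols = sparseDimSize(tensor, 1);
//   ...                       // hand the sizes/contents to the host runtime
//   delSparseTensor(tensor);  // release the underlying storage
//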
#include "mlir/ExecutionEngine/SparseTensorRuntime.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
#include "mlir/ExecutionEngine/SparseTensor/COO.h"
#include "mlir/ExecutionEngine/SparseTensor/File.h"
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"

#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>

using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
//
// Utilities for manipulating `StridedMemRefType`.
//
//===----------------------------------------------------------------------===//

namespace {

#define ASSERT_NO_STRIDE(MEMREF)                                               \
  do {                                                                         \
    assert((MEMREF) && "Memref is nullptr");                                   \
    assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride");    \
  } while (false)

#define MEMREF_GET_USIZE(MEMREF)                                               \
  detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])

#define ASSERT_USIZE_EQ(MEMREF, SZ)                                            \
  assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) &&                   \
         "Memref size mismatch")

#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)

/// Initializes the memref with the provided size and data pointer. This
/// is designed for functions which want to "return" a memref that aliases
/// into memory owned by some other object (e.g., `SparseTensorStorage`),
/// without doing any actual copying. (The "return" is in scare quotes
/// because the `_mlir_ciface_` calling convention migrates any returned
/// memrefs into an out-parameter passed before all the other function
/// parameters.)
template <typename DataSizeT, typename T>
static inline void aliasIntoMemref(DataSizeT size, T *data,
                                   StridedMemRefType<T, 1> &ref) {
  ref.basePtr = ref.data = data;
  ref.offset = 0;
  using MemrefSizeT = std::remove_reference_t<decltype(ref.sizes[0])>;
  ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size);
  ref.strides[0] = 1;
}

} // anonymous namespace
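
// For instance, a wrapper that "returns" the values array of a tensor fills
// in the caller-provided out-parameter memref along these lines (an informal
// sketch; `vec` stands for the std::vector<double> owned by the storage
// object, which retains ownership; no data is copied):
//
//   StridedMemRefType<double, 1> ref;
//   aliasIntoMemref(vec->size(), vec->data(), ref);
//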
extern "C" {

//===----------------------------------------------------------------------===//
//
// Public functions which operate on MLIR buffers (memrefs) to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

#define CASE(p, c, v, P, C, V)                                                 \
  if (posTp == (p) && crdTp == (c) && valTp == (v)) {                          \
    switch (action) {                                                          \
    case Action::kEmpty: {                                                     \
      return SparseTensorStorage<P, C, V>::newEmpty(                           \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);   \
    }                                                                          \
    case Action::kFromReader: {                                                \
      assert(ptr && "Received nullptr for SparseTensorReader object");         \
      SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);    \
      return static_cast<void *>(reader.readSparseTensor<P, C, V>(             \
          lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));                     \
    }                                                                          \
    case Action::kPack: {                                                      \
      assert(ptr && "Received nullptr for SparseTensorStorage object");        \
      intptr_t *buffers = static_cast<intptr_t *>(ptr);                        \
      return SparseTensorStorage<P, C, V>::newFromBuffers(                     \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,    \
          dimRank, buffers);                                                   \
    }                                                                          \
    case Action::kSortCOOInPlace: {                                            \
      assert(ptr && "Received nullptr for SparseTensorStorage object");        \
      auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);        \
      tensor.sortInPlace();                                                    \
      return ptr;                                                              \
    }                                                                          \
    }                                                                          \
    fprintf(stderr, "unknown action %d\n", static_cast<uint32_t>(action));     \
    exit(1);                                                                   \
  }

#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
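
// For example, the invocation CASE_SECSAME(OverheadType::kU64,
// PrimaryType::kI32, uint64_t, int32_t) used below expands to a guard that
// fires only when <posTp, crdTp, valTp> == <kU64, kU64, kI32>, and then
// dispatches on `action` over a SparseTensorStorage<uint64_t, uint64_t,
// int32_t> instance (informal paraphrase of the macro above).
//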
// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
// that this file cannot get out of sync with its header.
static_assert(std::is_same<index_type, uint64_t>::value,
              "Expected index_type == uint64_t");

// The Swiss-army-knife for sparse tensor creation.
void *_mlir_ciface_newSparseTensor( // NOLINT
    StridedMemRefType<index_type, 1> *dimSizesRef,
    StridedMemRefType<index_type, 1> *lvlSizesRef,
    StridedMemRefType<LevelType, 1> *lvlTypesRef,
    StridedMemRefType<index_type, 1> *dim2lvlRef,
    StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
    OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
  ASSERT_NO_STRIDE(dimSizesRef);
  ASSERT_NO_STRIDE(lvlSizesRef);
  ASSERT_NO_STRIDE(lvlTypesRef);
  ASSERT_NO_STRIDE(dim2lvlRef);
  ASSERT_NO_STRIDE(lvl2dimRef);
  const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
  ASSERT_USIZE_EQ(dim2lvlRef, lvlRank);
  ASSERT_USIZE_EQ(lvl2dimRef, dimRank);
  const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
  const LevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);

  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
  // This is safe because of the static_assert above.
  if (posTp == OverheadType::kIndex)
    posTp = OverheadType::kU64;
  if (crdTp == OverheadType::kIndex)
    crdTp = OverheadType::kU64;

  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Two-byte floats with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);

  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Complex matrices with wide overhead.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);

  // Unsupported case (add above if needed).
  fprintf(stderr, "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
          static_cast<int>(posTp), static_cast<int>(crdTp),
          static_cast<int>(valTp));
  exit(1);
}
#undef CASE
#undef CASE_SECSAME
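
// Conceptually, MLIR compiler-generated code materializes the five memref
// arguments and then creates, say, an empty f64 tensor with 64-bit overhead
// along these lines (an informal sketch; the *Ref variables are placeholders
// for buffers set up by the compiler):
//
//   void *tensor = _mlir_ciface_newSparseTensor(
//       &dimSizesRef, &lvlSizesRef, &lvlTypesRef, &dim2lvlRef, &lvl2dimRef,
//       OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64,
//       Action::kEmpty, /*ptr=*/nullptr);
//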
#define IMPL_SPARSEVALUES(VNAME, V)                                            \
  void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref,          \
                                        void *tensor) {                        \
    assert(ref && tensor);                                                     \
    std::vector<V> *v;                                                         \
    static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);             \
    assert(v);                                                                 \
    aliasIntoMemref(v->size(), v->data(), *ref);                               \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
#undef IMPL_SPARSEVALUES

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                           index_type lvl) {                                   \
    assert(ref && tensor);                                                     \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, lvl);              \
    assert(v);                                                                 \
    aliasIntoMemref(v->size(), v->data(), *ref);                               \
  }

#define IMPL_SPARSEPOSITIONS(PNAME, P)                                         \
  IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
#undef IMPL_SPARSEPOSITIONS

#define IMPL_SPARSECOORDINATES(CNAME, C)                                       \
  IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_SPARSECOORDINATES

#define IMPL_SPARSECOORDINATESBUFFER(CNAME, C)                                 \
  IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
#undef IMPL_SPARSECOORDINATESBUFFER

#undef IMPL_GETOVERHEAD

#define IMPL_LEXINSERT(VNAME, V)                                               \
  void _mlir_ciface_lexInsert##VNAME(                                          \
      void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
      StridedMemRefType<V, 0> *vref) {                                         \
    assert(t && vref);                                                         \
    auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
    ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
    index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
    assert(lvlCoords);                                                         \
    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
    tensor.lexInsert(lvlCoords, *value);                                       \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
#undef IMPL_LEXINSERT

#define IMPL_EXPINSERT(VNAME, V)                                               \
  void _mlir_ciface_expInsert##VNAME(                                          \
      void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef,                 \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
      StridedMemRefType<index_type, 1> *aref, index_type count) {              \
    assert(t);                                                                 \
    auto &tensor = *static_cast<SparseTensorStorageBase *>(t);                 \
    ASSERT_NO_STRIDE(lvlCoordsRef);                                            \
    ASSERT_NO_STRIDE(vref);                                                    \
    ASSERT_NO_STRIDE(fref);                                                    \
    ASSERT_NO_STRIDE(aref);                                                    \
    ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref));                             \
    index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef);                  \
    V *values = MEMREF_GET_PAYLOAD(vref);                                      \
    bool *filled = MEMREF_GET_PAYLOAD(fref);                                   \
    index_type *added = MEMREF_GET_PAYLOAD(aref);                              \
    uint64_t expsz = vref->sizes[0];                                           \
    tensor.expInsert(lvlCoords, values, filled, added, count, expsz);          \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
#undef IMPL_EXPINSERT

void *_mlir_ciface_createCheckedSparseTensorReader(
    char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
    PrimaryType valTp) {
  ASSERT_NO_STRIDE(dimShapeRef);
  const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef);
  const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef);
  auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp);
  return static_cast<void *>(reader);
}

void _mlir_ciface_getSparseTensorReaderDimSizes(
    StridedMemRefType<index_type, 1> *out, void *p) {
  assert(out && p);
  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
  auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes());
  aliasIntoMemref(reader.getRank(), dimSizes, *out);
}

#define IMPL_GETNEXT(VNAME, V, CNAME, C)                                       \
  bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME(          \
      void *p, StridedMemRefType<index_type, 1> *dim2lvlRef,                   \
      StridedMemRefType<index_type, 1> *lvl2dimRef,                            \
      StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) {          \
    assert(p);                                                                 \
    auto &reader = *static_cast<SparseTensorReader *>(p);                      \
    ASSERT_NO_STRIDE(dim2lvlRef);                                              \
    ASSERT_NO_STRIDE(lvl2dimRef);                                              \
    ASSERT_NO_STRIDE(cref);                                                    \
    ASSERT_NO_STRIDE(vref);                                                    \
    const uint64_t dimRank = reader.getRank();                                 \
    const uint64_t lvlRank = MEMREF_GET_USIZE(dim2lvlRef);                     \
    const uint64_t cSize = MEMREF_GET_USIZE(cref);                             \
    const uint64_t vSize = MEMREF_GET_USIZE(vref);                             \
    ASSERT_USIZE_EQ(lvl2dimRef, dimRank);                                      \
    assert(cSize >= lvlRank * reader.getNSE());                                \
    assert(vSize >= reader.getNSE());                                          \
    (void)dimRank;                                                             \
    (void)cSize;                                                               \
    (void)vSize;                                                               \
    index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);                      \
    index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);                      \
    C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref);                              \
    V *values = MEMREF_GET_PAYLOAD(vref);                                      \
    return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim,               \
                                      lvlCoordinates, values);                 \
  }
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
#undef IMPL_GETNEXT
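
// Taken together, the reader entry points form a simple bulk-load protocol.
// A rough sketch of the expected call sequence (the memref arguments are
// placeholders, and the concrete ReadToBuffers symbol is one of the
// instantiations generated above):
//
//   void *r = _mlir_ciface_createCheckedSparseTensorReader(
//       filename, &dimShapeRef, PrimaryType::kF64);
//   index_type nse = getSparseTensorReaderNSE(r);
//   _mlir_ciface_getSparseTensorReaderDimSizes(&dimSizesRef, r);
//   ... allocate coordinate/value buffers of sufficient size ...
//   ... call the matching ReadToBuffers instantiation ...
//   delSparseTensorReader(r);
//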
void _mlir_ciface_outSparseTensorWriterMetaData(
    void *p, index_type dimRank, index_type nse,
    StridedMemRefType<index_type, 1> *dimSizesRef) {
  assert(p);
  ASSERT_NO_STRIDE(dimSizesRef);
  assert(dimRank != 0);
  index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
  std::ostream &file = *static_cast<std::ostream *>(p);
  file << dimRank << " " << nse << '\n';
  for (index_type d = 0; d < dimRank - 1; d++)
    file << dimSizes[d] << " ";
  file << dimSizes[dimRank - 1] << '\n';
}

#define IMPL_OUTNEXT(VNAME, V)                                                 \
  void _mlir_ciface_outSparseTensorWriterNext##VNAME(                          \
      void *p, index_type dimRank,                                             \
      StridedMemRefType<index_type, 1> *dimCoordsRef,                          \
      StridedMemRefType<V, 0> *vref) {                                         \
    assert(p && vref);                                                         \
    ASSERT_NO_STRIDE(dimCoordsRef);                                            \
    const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef);            \
    std::ostream &file = *static_cast<std::ostream *>(p);                      \
    for (index_type d = 0; d < dimRank; d++)                                   \
      file << (dimCoords[d] + 1) << " ";                                       \
    V *value = MEMREF_GET_PAYLOAD(vref);                                       \
    file << *value << '\n';                                                    \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
#undef IMPL_OUTNEXT

//===----------------------------------------------------------------------===//
//
// Public functions which accept only C-style data structures to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

index_type sparseLvlSize(void *tensor, index_type l) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
}

index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

void endLexInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
}

void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

char *getTensorFilename(index_type id) {
  constexpr size_t bufSize = 80;
  char var[bufSize];
  snprintf(var, bufSize, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env) {
    fprintf(stderr, "Environment variable %s is not set\n", var);
    exit(1);
  }
  return env;
}
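
// For example, with TENSOR0=/data/matrix.mtx in the environment (an
// illustrative path), getTensorFilename(0) returns that string; an unset
// variable aborts with the diagnostic above.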

index_type getSparseTensorReaderNSE(void *p) {
  return static_cast<SparseTensorReader *>(p)->getNSE();
}

void delSparseTensorReader(void *p) {
  delete static_cast<SparseTensorReader *>(p);
}

void *createSparseTensorWriter(char *filename) {
  std::ostream *file =
      (filename[0] == 0) ? &std::cout : new std::ofstream(filename);
  *file << "# extended FROSTT format\n";
  return static_cast<void *>(file);
}

void delSparseTensorWriter(void *p) {
  std::ostream *file = static_cast<std::ostream *>(p);
  file->flush();
  assert(file->good());
  if (file != &std::cout)
    delete file;
}
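
// A writer is typically driven in three stages (an informal sketch; the
// value-typed ...WriterNext suffix stands for one of the instantiations
// generated above):
//
//   void *w = createSparseTensorWriter(outputFilename);
//   _mlir_ciface_outSparseTensorWriterMetaData(w, dimRank, nse, &dimSizesRef);
//   ... one ...WriterNext call per stored entry ...
//   delSparseTensorWriter(w);
//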
} // extern "C"

#undef MEMREF_GET_PAYLOAD
#undef ASSERT_USIZE_EQ
#undef MEMREF_GET_USIZE
#undef ASSERT_NO_STRIDE

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS