//===- SparseTensorRuntime.cpp - SparseTensor runtime support lib ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library for
// manipulating sparse tensors from MLIR. More specifically, it provides
// C-API wrappers so that MLIR-generated code can call into the C++ runtime
// support library. The functionality provided in this library is meant
// to simplify benchmarking, testing, and debugging of MLIR code operating
// on sparse tensors. However, the provided functionality is **not**
// part of core MLIR itself.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by coordinate (SparseTensorCOO).
//
// (b) A "one-size-fits-all" sparse tensor storage scheme defined by
//     per-dimension sparse/dense annotations together with a dimension
//     ordering used by MLIR compiler-generated code (SparseTensorStorage).
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
// Two public APIs are supported:
//
// (I) Methods operating on MLIR buffers (memrefs) to interact with sparse
//     tensors. These methods should be used exclusively by MLIR
//     compiler-generated code.
//
// (II) Methods that accept C-style data structures to interact with sparse
//      tensors. These methods can be used by any external runtime that wants
//      to interact with MLIR compiler-generated code.
//
// In both cases (I) and (II), the SparseTensorStorage format is externally
// only visible as an opaque pointer.
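//
// For example (an illustrative sketch only, not a prescribed usage pattern),
// an external runtime using API (II) could query and release a tensor that
// was handed to it as an opaque pointer roughly as follows:
//
//   void *tensor = ...; // produced by MLIR compiler-generated code
//   index_type rows = sparseDimSize(tensor, 0);
//   index_type cols = sparseDimSize(tensor, 1);
//   delSparseTensor(tensor);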
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensorRuntime.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
#include "mlir/ExecutionEngine/SparseTensor/COO.h"
#include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
#include "mlir/ExecutionEngine/SparseTensor/File.h"
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"

#include <cstring>
#include <numeric>

using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
//
// Utilities for manipulating `StridedMemRefType`.
//
//===----------------------------------------------------------------------===//

namespace {

#define ASSERT_NO_STRIDE(MEMREF) \
  do { \
    assert((MEMREF) && "Memref is nullptr"); \
    assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \
  } while (false)

#define MEMREF_GET_USIZE(MEMREF) \
  detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])

#define ASSERT_USIZE_EQ(MEMREF, SZ) \
  assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) && \
         "Memref size mismatch")

#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)

/// Initializes the memref with the provided size and data pointer. This
/// is designed for functions which want to "return" a memref that aliases
/// into memory owned by some other object (e.g., `SparseTensorStorage`),
/// without doing any actual copying. (The "return" is in scare quotes
/// because the `_mlir_ciface_` calling convention migrates any returned
/// memrefs into an out-parameter passed before all the other function
/// parameters.)
template <typename DataSizeT, typename T>
static inline void aliasIntoMemref(DataSizeT size, T *data,
                                   StridedMemRefType<T, 1> &ref) {
  ref.basePtr = ref.data = data;
  ref.offset = 0;
  using MemrefSizeT = std::remove_reference_t<decltype(ref.sizes[0])>;
  ref.sizes[0] = detail::checkOverflowCast<MemrefSizeT>(size);
  ref.strides[0] = 1;
}
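
// For example (an illustrative sketch only): a getter with the hypothetical
// name `getValuesF64` that exposes a `std::vector<double>` owned by the
// tensor storage as a rank-1 memref, without copying, could be written as
//
//   void getValuesF64(StridedMemRefType<double, 1> *ref,
//                     SparseTensorStorageBase *tensor) {
//     std::vector<double> *v;
//     tensor->getValues(&v);
//     aliasIntoMemref(v->size(), v->data(), *ref);
//   }
//
// The `_mlir_ciface_sparseValues*` wrappers defined below follow essentially
// this pattern (with additional assertions on their arguments).
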
} // anonymous namespace

extern "C" {

//===----------------------------------------------------------------------===//
//
// Public functions which operate on MLIR buffers (memrefs) to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

#define CASE(p, c, v, P, C, V) \
  if (posTp == (p) && crdTp == (c) && valTp == (v)) { \
    switch (action) { \
    case Action::kEmpty: { \
      return SparseTensorStorage<P, C, V>::newEmpty( \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
          false); \
    } \
    case Action::kEmptyForward: { \
      return SparseTensorStorage<P, C, V>::newEmpty( \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
          true); \
    } \
    case Action::kFromCOO: { \
      assert(ptr && "Received nullptr for SparseTensorCOO object"); \
      auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr); \
      return SparseTensorStorage<P, C, V>::newFromCOO( \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
          coo); \
    } \
    case Action::kFromReader: { \
      assert(ptr && "Received nullptr for SparseTensorReader object"); \
      SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \
      return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
          lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
    } \
    case Action::kToCOO: { \
      assert(ptr && "Received nullptr for SparseTensorStorage object"); \
      auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
      return tensor.toCOO(); \
    } \
    case Action::kPack: { \
      assert(ptr && "Received nullptr for SparseTensorStorage object"); \
      intptr_t *buffers = static_cast<intptr_t *>(ptr); \
      return SparseTensorStorage<P, C, V>::packFromLvlBuffers( \
          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
          dimRank, buffers); \
    } \
    case Action::kSortCOOInPlace: { \
      assert(ptr && "Received nullptr for SparseTensorStorage object"); \
      auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
      tensor.sortInPlace(); \
      return ptr; \
    } \
    } \
    MLIR_SPARSETENSOR_FATAL("unknown action: %d\n", \
                            static_cast<uint32_t>(action)); \
  }

#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)

// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
// can safely rewrite kIndex to kU64. We make this assertion to guarantee
// that this file cannot get out of sync with its header.
static_assert(std::is_same<index_type, uint64_t>::value,
              "Expected index_type == uint64_t");

// The Swiss-army-knife for sparse tensor creation.
void *_mlir_ciface_newSparseTensor( // NOLINT
    StridedMemRefType<index_type, 1> *dimSizesRef,
    StridedMemRefType<index_type, 1> *lvlSizesRef,
    StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
    StridedMemRefType<index_type, 1> *dim2lvlRef,
    StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
    OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
  ASSERT_NO_STRIDE(dimSizesRef);
  ASSERT_NO_STRIDE(lvlSizesRef);
  ASSERT_NO_STRIDE(lvlTypesRef);
  ASSERT_NO_STRIDE(dim2lvlRef);
  ASSERT_NO_STRIDE(lvl2dimRef);
  const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
  ASSERT_USIZE_EQ(dim2lvlRef, lvlRank);
  ASSERT_USIZE_EQ(lvl2dimRef, dimRank);
  const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
  const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);

  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
  // This is safe because of the static_assert above.
  if (posTp == OverheadType::kIndex)
    posTp = OverheadType::kU64;
  if (crdTp == OverheadType::kIndex)
    crdTp = OverheadType::kU64;

  // Double matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
       uint64_t, double);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
       uint32_t, double);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
       uint16_t, double);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
       uint8_t, double);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
       uint64_t, double);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
       uint32_t, double);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
       uint16_t, double);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
       uint8_t, double);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
       uint64_t, double);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
       uint32_t, double);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
       uint16_t, double);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
       uint8_t, double);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
       uint64_t, double);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
       uint32_t, double);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
       uint16_t, double);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
       uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
       uint64_t, float);
  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
       uint32_t, float);
  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
       uint16_t, float);
  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
       uint8_t, float);
  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
       uint64_t, float);
  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
       uint32_t, float);
  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
       uint16_t, float);
  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
       uint8_t, float);
  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
       uint64_t, float);
  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
       uint32_t, float);
  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
       uint16_t, float);
  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
       uint8_t, float);
  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
       uint64_t, float);
  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
       uint32_t, float);
  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
       uint16_t, float);
  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
       uint8_t, float);

  // Two-byte floats with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);

  // Integral matrices with both overheads of the same type.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);

  // Complex matrices with wide overhead.
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);

  // Unsupported case (add above if needed).
  MLIR_SPARSETENSOR_FATAL(
      "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
      static_cast<int>(posTp), static_cast<int>(crdTp),
      static_cast<int>(valTp));
}
#undef CASE
#undef CASE_SECSAME

#define IMPL_SPARSEVALUES(VNAME, V) \
  void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \
                                        void *tensor) { \
    assert(ref && tensor); \
    std::vector<V> *v; \
    static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
    assert(v); \
    aliasIntoMemref(v->size(), v->data(), *ref); \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
#undef IMPL_SPARSEVALUES

#define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
                           index_type lvl) { \
    assert(ref && tensor); \
    std::vector<TYPE> *v; \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, lvl); \
    assert(v); \
    aliasIntoMemref(v->size(), v->data(), *ref); \
  }
#define IMPL_SPARSEPOSITIONS(PNAME, P) \
  IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
#undef IMPL_SPARSEPOSITIONS

#define IMPL_SPARSECOORDINATES(CNAME, C) \
  IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_SPARSECOORDINATES
#undef IMPL_GETOVERHEAD
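
// For reference, an illustrative sketch of what the expansion above produces
// for each overhead type supplied by MLIR_SPARSETENSOR_FOREVERY_O: a function
// `_mlir_ciface_sparsePositions<PNAME>` that aliases the positions vector of
// level `lvl` (owned by the storage object) into the caller's memref, and an
// analogous `_mlir_ciface_sparseCoordinates<CNAME>` for the coordinates. No
// data is copied, so the resulting memref is only valid for as long as the
// underlying sparse tensor is alive.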

#define IMPL_FORWARDINGINSERT(VNAME, V) \
  void _mlir_ciface_forwardingInsert##VNAME( \
      void *t, StridedMemRefType<V, 0> *vref, \
      StridedMemRefType<index_type, 1> *dimCoordsRef) { \
    assert(t && vref); \
    ASSERT_NO_STRIDE(dimCoordsRef); \
    const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \
    assert(dimCoords); \
    const V *value = MEMREF_GET_PAYLOAD(vref); \
    static_cast<SparseTensorStorageBase *>(t)->forwardingInsert(dimCoords, \
                                                                *value); \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
#undef IMPL_FORWARDINGINSERT

#define IMPL_LEXINSERT(VNAME, V) \
  void _mlir_ciface_lexInsert##VNAME( \
      void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \
      StridedMemRefType<V, 0> *vref) { \
    assert(t && vref); \
    auto &tensor = *static_cast<SparseTensorStorageBase *>(t); \
    ASSERT_NO_STRIDE(lvlCoordsRef); \
    index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef); \
    assert(lvlCoords); \
    V *value = MEMREF_GET_PAYLOAD(vref); \
    tensor.lexInsert(lvlCoords, *value); \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
#undef IMPL_LEXINSERT
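
// Typical insertion protocol (a sketch of the intended usage, as driven by
// MLIR compiler-generated code): call the `_mlir_ciface_lexInsert*` entry
// point repeatedly with level coordinates that increase lexicographically,
// passing each scalar value through a rank-0 memref, and then finalize the
// tensor with endLexInsert() declared further below.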

#define IMPL_EXPINSERT(VNAME, V) \
  void _mlir_ciface_expInsert##VNAME( \
      void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
      StridedMemRefType<index_type, 1> *aref, index_type count) { \
    assert(t); \
    auto &tensor = *static_cast<SparseTensorStorageBase *>(t); \
    ASSERT_NO_STRIDE(lvlCoordsRef); \
    ASSERT_NO_STRIDE(vref); \
    ASSERT_NO_STRIDE(fref); \
    ASSERT_NO_STRIDE(aref); \
    ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref)); \
    index_type *lvlCoords = MEMREF_GET_PAYLOAD(lvlCoordsRef); \
    V *values = MEMREF_GET_PAYLOAD(vref); \
    bool *filled = MEMREF_GET_PAYLOAD(fref); \
    index_type *added = MEMREF_GET_PAYLOAD(aref); \
    uint64_t expsz = vref->sizes[0]; \
    tensor.expInsert(lvlCoords, values, filled, added, count, expsz); \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
#undef IMPL_EXPINSERT

void *_mlir_ciface_createCheckedSparseTensorReader(
    char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
    PrimaryType valTp) {
  ASSERT_NO_STRIDE(dimShapeRef);
  const uint64_t dimRank = MEMREF_GET_USIZE(dimShapeRef);
  const index_type *dimShape = MEMREF_GET_PAYLOAD(dimShapeRef);
  auto *reader = SparseTensorReader::create(filename, dimRank, dimShape, valTp);
  return static_cast<void *>(reader);
}

void _mlir_ciface_getSparseTensorReaderDimSizes(
    StridedMemRefType<index_type, 1> *out, void *p) {
  assert(out && p);
  SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
  auto *dimSizes = const_cast<uint64_t *>(reader.getDimSizes());
  aliasIntoMemref(reader.getRank(), dimSizes, *out);
}

#define IMPL_GETNEXT(VNAME, V, CNAME, C) \
  bool _mlir_ciface_getSparseTensorReaderReadToBuffers##CNAME##VNAME( \
      void *p, StridedMemRefType<index_type, 1> *dim2lvlRef, \
      StridedMemRefType<index_type, 1> *lvl2dimRef, \
      StridedMemRefType<C, 1> *cref, StridedMemRefType<V, 1> *vref) { \
    assert(p); \
    auto &reader = *static_cast<SparseTensorReader *>(p); \
    ASSERT_NO_STRIDE(dim2lvlRef); \
    ASSERT_NO_STRIDE(lvl2dimRef); \
    ASSERT_NO_STRIDE(cref); \
    ASSERT_NO_STRIDE(vref); \
    const uint64_t dimRank = reader.getRank(); \
    const uint64_t lvlRank = MEMREF_GET_USIZE(dim2lvlRef); \
    const uint64_t cSize = MEMREF_GET_USIZE(cref); \
    const uint64_t vSize = MEMREF_GET_USIZE(vref); \
    ASSERT_USIZE_EQ(lvl2dimRef, dimRank); \
    assert(cSize >= lvlRank * vSize); \
    assert(vSize >= reader.getNSE() && "Not enough space in buffers"); \
    (void)dimRank; \
    (void)cSize; \
    (void)vSize; \
    index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
    index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef); \
    C *lvlCoordinates = MEMREF_GET_PAYLOAD(cref); \
    V *values = MEMREF_GET_PAYLOAD(vref); \
    return reader.readToBuffers<C, V>(lvlRank, dim2lvl, lvl2dim, \
                                      lvlCoordinates, values); \
  }
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
#undef IMPL_GETNEXT
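
// Illustrative reader workflow (a sketch only; the concrete call sequence is
// normally emitted by the MLIR compiler): create a checked reader, query its
// metadata, read all nonzeros into preallocated level buffers, and finally
// release the reader.
//
//   void *reader = _mlir_ciface_createCheckedSparseTensorReader(
//       filename, &dimShape, PrimaryType::kF64);
//   index_type nse = getSparseTensorReaderNSE(reader);
//   _mlir_ciface_getSparseTensorReaderDimSizes(&dimSizes, reader);
//   // ... allocate coordinate and value buffers with room for at least
//   // lvlRank * nse and nse elements, respectively ...
//   _mlir_ciface_getSparseTensorReaderReadToBuffers<CNAME><VNAME>(
//       reader, &dim2lvl, &lvl2dim, &coords, &values);
//   delSparseTensorReader(reader);
//
// where the <CNAME><VNAME> suffix is produced by the FOREVERY_V_O expansion
// above.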

void _mlir_ciface_outSparseTensorWriterMetaData(
    void *p, index_type dimRank, index_type nse,
    StridedMemRefType<index_type, 1> *dimSizesRef) {
  assert(p);
  ASSERT_NO_STRIDE(dimSizesRef);
  assert(dimRank != 0);
  index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
  std::ostream &file = *static_cast<std::ostream *>(p);
  file << dimRank << " " << nse << std::endl;
  for (index_type d = 0; d < dimRank - 1; d++)
    file << dimSizes[d] << " ";
  file << dimSizes[dimRank - 1] << std::endl;
}

#define IMPL_OUTNEXT(VNAME, V) \
  void _mlir_ciface_outSparseTensorWriterNext##VNAME( \
      void *p, index_type dimRank, \
      StridedMemRefType<index_type, 1> *dimCoordsRef, \
      StridedMemRefType<V, 0> *vref) { \
    assert(p && vref); \
    ASSERT_NO_STRIDE(dimCoordsRef); \
    const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \
    std::ostream &file = *static_cast<std::ostream *>(p); \
    for (index_type d = 0; d < dimRank; d++) \
      file << (dimCoords[d] + 1) << " "; \
    V *value = MEMREF_GET_PAYLOAD(vref); \
    file << *value << std::endl; \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
#undef IMPL_OUTNEXT
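
// Together with createSparseTensorWriter() below, these writer entry points
// emit the "extended FROSTT" text format. For a 2x3 matrix holding two stored
// values, the output would look roughly as follows (coordinates are 1-based):
//
//   # extended FROSTT format
//   2 2
//   2 3
//   1 1 1.5
//   2 3 2.5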

//===----------------------------------------------------------------------===//
//
// Public functions which accept only C-style data structures to interact
// with sparse tensors (which are only visible as opaque pointers externally).
//
//===----------------------------------------------------------------------===//

index_type sparseLvlSize(void *tensor, index_type l) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
}

index_type sparseDimSize(void *tensor, index_type d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

void endForwardingInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endForwardingInsert();
}

void endLexInsert(void *tensor) {
  return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
}

void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

#define IMPL_DELCOO(VNAME, V) \
  void delSparseTensorCOO##VNAME(void *coo) { \
    delete static_cast<SparseTensorCOO<V> *>(coo); \
  }
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELCOO)
#undef IMPL_DELCOO

char *getTensorFilename(index_type id) {
  constexpr size_t BUF_SIZE = 80;
  char var[BUF_SIZE];
  snprintf(var, BUF_SIZE, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env)
    MLIR_SPARSETENSOR_FATAL("Environment variable %s is not set\n", var);
  return env;
}
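
// Usage note (a sketch of the typical setup): integration tests point the
// runtime at an input file through an environment variable, e.g.
//
//   export TENSOR0=/path/to/matrix.mtx
//
// so that getTensorFilename(0) resolves to that path at run time.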

index_type getSparseTensorReaderNSE(void *p) {
  return static_cast<SparseTensorReader *>(p)->getNSE();
}

void delSparseTensorReader(void *p) {
  delete static_cast<SparseTensorReader *>(p);
}

void *createSparseTensorWriter(char *filename) {
  std::ostream *file =
      (filename[0] == 0) ? &std::cout : new std::ofstream(filename);
  *file << "# extended FROSTT format\n";
  return static_cast<void *>(file);
}

void delSparseTensorWriter(void *p) {
  std::ostream *file = static_cast<std::ostream *>(p);
  file->flush();
  assert(file->good());
  if (file != &std::cout)
    delete file;
}

} // extern "C"

#undef MEMREF_GET_PAYLOAD
#undef ASSERT_USIZE_EQ
#undef MEMREF_GET_USIZE
#undef ASSERT_NO_STRIDE

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS