//===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements reading and writing sparse tensor files.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/SparseTensor/File.h"

#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstring>
#include <iterator>

using namespace mlir::sparse_tensor;
20 /// Opens the file for reading.
21 void SparseTensorReader::openFile() {
23 MLIR_SPARSETENSOR_FATAL("Already opened file %s\n", filename
);
24 file
= fopen(filename
, "r");
26 MLIR_SPARSETENSOR_FATAL("Cannot find file %s\n", filename
);
30 void SparseTensorReader::closeFile() {
37 /// Attempts to read a line from the file.
38 void SparseTensorReader::readLine() {
39 if (!fgets(line
, kColWidth
, file
))
40 MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename
);
43 /// Reads and parses the file's header.
44 void SparseTensorReader::readHeader() {
45 assert(file
&& "Attempt to readHeader() before openFile()");
46 if (strstr(filename
, ".mtx"))
48 else if (strstr(filename
, ".tns"))
49 readExtFROSTTHeader();
51 MLIR_SPARSETENSOR_FATAL("Unknown format %s\n", filename
);
52 assert(isValid() && "Failed to read the header");
55 /// Asserts the shape subsumes the actual dimension sizes. Is only
56 /// valid after parsing the header.
57 void SparseTensorReader::assertMatchesShape(uint64_t rank
,
58 const uint64_t *shape
) const {
59 assert(rank
== getRank() && "Rank mismatch");
60 for (uint64_t r
= 0; r
< rank
; ++r
)
61 assert((shape
[r
] == 0 || shape
[r
] == idata
[2 + r
]) &&
62 "Dimension size mismatch");
65 bool SparseTensorReader::canReadAs(PrimaryType valTy
) const {
67 case ValueKind::kInvalid
:
68 assert(false && "Must readHeader() before calling canReadAs()");
69 return false; // In case assertions are disabled.
70 case ValueKind::kPattern
:
72 case ValueKind::kInteger
:
73 // When the file is specified to store integer values, we still
74 // allow implicitly converting those to floating primary-types.
75 return isRealPrimaryType(valTy
);
76 case ValueKind::kReal
:
77 // When the file is specified to store real/floating values, then
78 // we disallow implicit conversion to integer primary-types.
79 return isFloatingPrimaryType(valTy
);
80 case ValueKind::kComplex
:
81 // When the file is specified to store complex values, then we
82 // require a complex primary-type.
83 return isComplexPrimaryType(valTy
);
84 case ValueKind::kUndefined
:
85 // The "extended" FROSTT format doesn't specify a ValueKind.
86 // So we allow implicitly converting the stored values to both
87 // integer and floating primary-types.
88 return isRealPrimaryType(valTy
);
90 MLIR_SPARSETENSOR_FATAL("Unknown ValueKind: %d\n",
91 static_cast<uint8_t>(valueKind_
));
/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case,
/// in place.
///
/// Each byte is passed to `tolower` as an `unsigned char`. Passing a negative
/// `char` (e.g. a non-ASCII byte where `char` is signed) is undefined
/// behavior.
static inline void toLower(char *token) {
  for (char *c = token; *c; ++c)
    *c = static_cast<char>(std::tolower(static_cast<unsigned char>(*c)));
}
/// Idiomatic name for checking string equality.
/// Returns true iff the '\0'-terminated strings `lhs` and `rhs` are identical.
static inline bool streq(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) == 0;
}
/// Idiomatic name for checking string inequality.
/// Returns true iff the '\0'-terminated strings `lhs` and `rhs` differ.
static inline bool strne(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) != 0;
}
110 /// Read the MME header of a general sparse matrix of type real.
111 void SparseTensorReader::readMMEHeader() {
118 if (fscanf(file
, "%63s %63s %63s %63s %63s\n", header
, object
, format
, field
,
120 MLIR_SPARSETENSOR_FATAL("Corrupt header in %s\n", filename
);
121 // Convert all to lowercase up front (to avoid accidental redundancy).
127 // Process `field`, which specify pattern or the data type of the values.
128 if (streq(field
, "pattern"))
129 valueKind_
= ValueKind::kPattern
;
130 else if (streq(field
, "real"))
131 valueKind_
= ValueKind::kReal
;
132 else if (streq(field
, "integer"))
133 valueKind_
= ValueKind::kInteger
;
134 else if (streq(field
, "complex"))
135 valueKind_
= ValueKind::kComplex
;
137 MLIR_SPARSETENSOR_FATAL("Unexpected header field value in %s\n", filename
);
139 isSymmetric_
= streq(symmetry
, "symmetric");
140 // Make sure this is a general sparse matrix.
141 if (strne(header
, "%%matrixmarket") || strne(object
, "matrix") ||
142 strne(format
, "coordinate") ||
143 (strne(symmetry
, "general") && !isSymmetric_
))
144 MLIR_SPARSETENSOR_FATAL("Cannot find a general sparse matrix in %s\n",
152 // Next line contains M N NNZ.
153 idata
[0] = 2; // rank
154 if (sscanf(line
, "%" PRIu64
"%" PRIu64
"%" PRIu64
"\n", idata
+ 2, idata
+ 3,
156 MLIR_SPARSETENSOR_FATAL("Cannot find size in %s\n", filename
);
159 /// Read the "extended" FROSTT header. Although not part of the documented
160 /// format, we assume that the file starts with optional comments followed
161 /// by two lines that define the rank, the number of nonzeros, and the
162 /// dimensions sizes (one per rank) of the sparse tensor.
163 void SparseTensorReader::readExtFROSTTHeader() {
170 // Next line contains RANK and NNZ.
171 if (sscanf(line
, "%" PRIu64
"%" PRIu64
"\n", idata
, idata
+ 1) != 2)
172 MLIR_SPARSETENSOR_FATAL("Cannot find metadata in %s\n", filename
);
173 // Followed by a line with the dimension sizes (one per rank).
174 for (uint64_t r
= 0; r
< idata
[0]; ++r
)
175 if (fscanf(file
, "%" PRIu64
, idata
+ 2 + r
) != 1)
176 MLIR_SPARSETENSOR_FATAL("Cannot find dimension size %s\n", filename
);
177 readLine(); // end of line
178 // The FROSTT format does not define the data type of the nonzero elements.
179 valueKind_
= ValueKind::kUndefined
;