1 //===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements reading and writing sparse tensor files.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/ExecutionEngine/SparseTensor/File.h"
18 using namespace mlir::sparse_tensor
;
20 /// Opens the file for reading.
/// Stores the result of `fopen(filename, "r")` in `file`. Two diagnostics can
/// be written to stderr, each followed by the filename: "Already opened file"
/// and "Cannot find file".
/// NOTE(review): the conditions guarding those diagnostics, and whatever
/// statements follow each of them, are not visible in this excerpt; confirm
/// against the full file before relying on this description.
21 void SparseTensorReader::openFile() {
23 fprintf(stderr
, "Already opened file %s\n", filename
);
// Open in text read mode; the resulting stream is kept in `file`.
26 file
= fopen(filename
, "r");
28 fprintf(stderr
, "Cannot find file %s\n", filename
);
/// NOTE(review): only the signature of this function is visible in this
/// excerpt. Presumably it releases the stream that openFile() stored in
/// `file` -- confirm against the full file.
34 void SparseTensorReader::closeFile() {
41 /// Attempts to read a line from the file.
/// Uses fgets to read from `file` into the `line` buffer, storing at most
/// `kColWidth - 1` characters plus the terminating NUL. When fgets returns
/// null (end of file or read error) the diagnostic "Cannot read next line of"
/// followed by the filename is written to stderr.
/// NOTE(review): the statements following the diagnostic are not visible in
/// this excerpt; confirm how the failure is propagated.
42 void SparseTensorReader::readLine() {
43 if (!fgets(line
, kColWidth
, file
)) {
44 fprintf(stderr
, "Cannot read next line of %s\n", filename
);
49 /// Reads and parses the file's header.
/// The header format is chosen from the filename: a name containing ".mtx"
/// takes the first branch, a name containing ".tns" is parsed by
/// readExtFROSTTHeader(), and anything else produces the "Unknown format"
/// diagnostic on stderr. Afterwards the result is checked with isValid().
/// NOTE(review): the body of the ".mtx" branch and the statements following
/// the "Unknown format" diagnostic are not visible in this excerpt; the
/// ".mtx" branch presumably calls readMMEHeader() -- confirm.
50 void SparseTensorReader::readHeader() {
// The stream must already have been opened.
51 assert(file
&& "Attempt to readHeader() before openFile()");
// strstr matches the extension anywhere in the name, not only as a suffix.
52 if (strstr(filename
, ".mtx")) {
54 } else if (strstr(filename
, ".tns")) {
55 readExtFROSTTHeader();
57 fprintf(stderr
, "Unknown format %s\n", filename
);
60 assert(isValid() && "Failed to read the header");
63 /// Asserts the shape subsumes the actual dimension sizes. Is only
64 /// valid after parsing the header.
65 void SparseTensorReader::assertMatchesShape(uint64_t rank
,
66 const uint64_t *shape
) const {
67 assert(rank
== getRank() && "Rank mismatch");
68 for (uint64_t r
= 0; r
< rank
; r
++)
69 assert((shape
[r
] == 0 || shape
[r
] == idata
[2 + r
]) &&
70 "Dimension size mismatch");
/// Reports whether the values stored in the file may be read as `valTy`,
/// based on the ValueKind recorded in `valueKind_` by the header parsers.
/// NOTE(review): the statement that selects among the `case` labels below,
/// the body of the `kPattern` case, and the statements after the final
/// diagnostic are not visible in this excerpt; confirm against the full file.
73 bool SparseTensorReader::canReadAs(PrimaryType valTy
) const {
// kInvalid means no header has been parsed yet.
75 case ValueKind::kInvalid
:
76 assert(false && "Must readHeader() before calling canReadAs()");
77 return false; // In case assertions are disabled.
78 case ValueKind::kPattern
:
80 case ValueKind::kInteger
:
81 // When the file is specified to store integer values, we still
82 // allow implicitly converting those to floating primary-types.
83 return isRealPrimaryType(valTy
);
84 case ValueKind::kReal
:
85 // When the file is specified to store real/floating values, then
86 // we disallow implicit conversion to integer primary-types.
87 return isFloatingPrimaryType(valTy
);
88 case ValueKind::kComplex
:
89 // When the file is specified to store complex values, then we
90 // require a complex primary-type.
91 return isComplexPrimaryType(valTy
);
92 case ValueKind::kUndefined
:
93 // The "extended" FROSTT format doesn't specify a ValueKind.
94 // So we allow implicitly converting the stored values to both
95 // integer and floating primary-types.
96 return isRealPrimaryType(valTy
);
// Diagnostic for a `valueKind_` value not covered by the cases above; the
// kind is printed as its uint8_t numeric value.
98 fprintf(stderr
, "Unknown ValueKind: %d\n", static_cast<uint8_t>(valueKind_
));
102 /// Helper to convert C-style strings (i.e., '\0' terminated) to lower case.
/// Walks `token` one character at a time until the terminating '\0'.
/// NOTE(review): the loop body is not visible in this excerpt; per the
/// comment above it presumably lower-cases `*c` in place -- confirm.
103 static inline void toLower(char *token
) {
104 for (char *c
= token
; *c
; c
++)
/// Returns true exactly when the two '\0'-terminated strings have identical
/// contents (i.e. `strcmp` reports no difference).
static inline bool streq(const char *lhs, const char *rhs) {
  const int difference = strcmp(lhs, rhs);
  return difference == 0;
}
/// Returns true exactly when the two '\0'-terminated strings differ
/// (i.e. `strcmp` reports a non-zero result).
static inline bool strne(const char *lhs, const char *rhs) {
  const int difference = strcmp(lhs, rhs);
  return difference != 0;
}
118 /// Read the MME header of a general or symmetric sparse matrix.
/// Parses the five-token banner line (header, object, format, field,
/// symmetry), records the value kind in `valueKind_` and the symmetry flag in
/// `isSymmetric_`, and then reads the matrix sizes into `idata`, with the rank
/// fixed at 2. Each failure writes a diagnostic naming the file to stderr.
/// NOTE(review): the declarations of the token buffers, the lowercase
/// conversion, the code that advances `line` to the size line, and the
/// statements following each diagnostic are not visible in this excerpt;
/// confirm against the full file.
119 void SparseTensorReader::readMMEHeader() {
// Each %63s conversion stores at most 63 characters plus a terminating NUL.
126 if (fscanf(file
, "%63s %63s %63s %63s %63s\n", header
, object
, format
, field
,
128 fprintf(stderr
, "Corrupt header in %s\n", filename
);
131 // Convert all to lowercase up front (to avoid accidental redundancy).
137 // Process `field`, which specifies pattern or the data type of the values.
138 if (streq(field
, "pattern")) {
139 valueKind_
= ValueKind::kPattern
;
140 } else if (streq(field
, "real")) {
141 valueKind_
= ValueKind::kReal
;
142 } else if (streq(field
, "integer")) {
143 valueKind_
= ValueKind::kInteger
;
144 } else if (streq(field
, "complex")) {
145 valueKind_
= ValueKind::kComplex
;
147 fprintf(stderr
, "Unexpected header field value in %s\n", filename
);
151 isSymmetric_
= streq(symmetry
, "symmetric");
152 // Make sure this is a general or symmetric sparse matrix in coordinate
// format: any other banner, object, format or symmetry value is rejected.
153 if (strne(header
, "%%matrixmarket") || strne(object
, "matrix") ||
154 strne(format
, "coordinate") ||
155 (strne(symmetry
, "general") && !isSymmetric_
)) {
156 fprintf(stderr
, "Cannot find a general sparse matrix in %s\n", filename
);
165 // Next line contains M N NNZ.
166 idata
[0] = 2; // rank
// M and N are stored at idata[2] and idata[3].
// NOTE(review): the destination of the third value (NNZ) is not visible
// here; presumably idata + 1 -- confirm.
167 if (sscanf(line
, "%" PRIu64
"%" PRIu64
"%" PRIu64
"\n", idata
+ 2, idata
+ 3,
169 fprintf(stderr
, "Cannot find size in %s\n", filename
);
174 /// Read the "extended" FROSTT header. Although not part of the documented
175 /// format, we assume that the file starts with optional comments followed
176 /// by two lines that define the rank, the number of nonzeros, and the
177 /// dimension sizes (one per rank) of the sparse tensor.
/// On return `idata[0]` holds the rank, `idata[1]` the number of nonzeros and
/// `idata[2 + r]` the size of dimension `r`; `valueKind_` is kUndefined.
/// NOTE(review): the code between the signature and the metadata parse
/// (presumably skipping the optional comments) and the statements following
/// each diagnostic are not visible in this excerpt; confirm against the full
/// file.
178 void SparseTensorReader::readExtFROSTTHeader() {
185 // Next line contains RANK and NNZ.
186 if (sscanf(line
, "%" PRIu64
"%" PRIu64
"\n", idata
, idata
+ 1) != 2) {
187 fprintf(stderr
, "Cannot find metadata in %s\n", filename
);
190 // Followed by a line with the dimension sizes (one per rank).
// NOTE(review): no upper bound on the rank read from the file is visible
// here; confirm that `idata` can hold 2 + idata[0] entries for any input.
191 for (uint64_t r
= 0; r
< idata
[0]; r
++) {
192 if (fscanf(file
, "%" PRIu64
, idata
+ 2 + r
) != 1) {
193 fprintf(stderr
, "Cannot find dimension size %s\n", filename
);
// Consume the remainder of the dimension-size line so that the next read
// starts on a fresh line.
197 readLine(); // end of line
198 // The FROSTT format does not define the data type of the nonzero elements.
199 valueKind_
= ValueKind::kUndefined
;