[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / ExecutionEngine / SparseTensor / File.cpp
blobc49ec0998fb444f2e9f238ce146021c550a30508
1 //===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements reading and writing sparse tensor files.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/ExecutionEngine/SparseTensor/File.h"
15 #include <cctype>
16 #include <cstring>
18 using namespace mlir::sparse_tensor;
20 /// Opens the file for reading.
21 void SparseTensorReader::openFile() {
22 if (file)
23 MLIR_SPARSETENSOR_FATAL("Already opened file %s\n", filename);
24 file = fopen(filename, "r");
25 if (!file)
26 MLIR_SPARSETENSOR_FATAL("Cannot find file %s\n", filename);
29 /// Closes the file.
30 void SparseTensorReader::closeFile() {
31 if (file) {
32 fclose(file);
33 file = nullptr;
37 /// Attempts to read a line from the file.
38 void SparseTensorReader::readLine() {
39 if (!fgets(line, kColWidth, file))
40 MLIR_SPARSETENSOR_FATAL("Cannot read next line of %s\n", filename);
43 /// Reads and parses the file's header.
44 void SparseTensorReader::readHeader() {
45 assert(file && "Attempt to readHeader() before openFile()");
46 if (strstr(filename, ".mtx"))
47 readMMEHeader();
48 else if (strstr(filename, ".tns"))
49 readExtFROSTTHeader();
50 else
51 MLIR_SPARSETENSOR_FATAL("Unknown format %s\n", filename);
52 assert(isValid() && "Failed to read the header");
55 /// Asserts the shape subsumes the actual dimension sizes. Is only
56 /// valid after parsing the header.
57 void SparseTensorReader::assertMatchesShape(uint64_t rank,
58 const uint64_t *shape) const {
59 assert(rank == getRank() && "Rank mismatch");
60 for (uint64_t r = 0; r < rank; ++r)
61 assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
62 "Dimension size mismatch");
65 bool SparseTensorReader::canReadAs(PrimaryType valTy) const {
66 switch (valueKind_) {
67 case ValueKind::kInvalid:
68 assert(false && "Must readHeader() before calling canReadAs()");
69 return false; // In case assertions are disabled.
70 case ValueKind::kPattern:
71 return true;
72 case ValueKind::kInteger:
73 // When the file is specified to store integer values, we still
74 // allow implicitly converting those to floating primary-types.
75 return isRealPrimaryType(valTy);
76 case ValueKind::kReal:
77 // When the file is specified to store real/floating values, then
78 // we disallow implicit conversion to integer primary-types.
79 return isFloatingPrimaryType(valTy);
80 case ValueKind::kComplex:
81 // When the file is specified to store complex values, then we
82 // require a complex primary-type.
83 return isComplexPrimaryType(valTy);
84 case ValueKind::kUndefined:
85 // The "extended" FROSTT format doesn't specify a ValueKind.
86 // So we allow implicitly converting the stored values to both
87 // integer and floating primary-types.
88 return isRealPrimaryType(valTy);
90 MLIR_SPARSETENSOR_FATAL("Unknown ValueKind: %d\n",
91 static_cast<uint8_t>(valueKind_));
/// Converts a C-style ('\0'-terminated) string to lower case, in place.
/// Each byte is passed to `tolower` as an `unsigned char`. Passing a negative
/// value other than EOF to `tolower` is undefined behavior, and that is what
/// a non-ASCII byte becomes when `char` is signed.
static inline void toLower(char *token) {
  for (char *c = token; *c; ++c)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
}
/// Returns true iff the C strings `lhs` and `rhs` are identical.
static inline bool streq(const char *lhs, const char *rhs) {
  return !std::strcmp(lhs, rhs);
}
/// Returns true iff the C strings `lhs` and `rhs` differ.
static inline bool strne(const char *lhs, const char *rhs) {
  return std::strcmp(lhs, rhs) != 0;
}
110 /// Read the MME header of a general sparse matrix of type real.
111 void SparseTensorReader::readMMEHeader() {
112 char header[64];
113 char object[64];
114 char format[64];
115 char field[64];
116 char symmetry[64];
117 // Read header line.
118 if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
119 symmetry) != 5)
120 MLIR_SPARSETENSOR_FATAL("Corrupt header in %s\n", filename);
121 // Convert all to lowercase up front (to avoid accidental redundancy).
122 toLower(header);
123 toLower(object);
124 toLower(format);
125 toLower(field);
126 toLower(symmetry);
127 // Process `field`, which specify pattern or the data type of the values.
128 if (streq(field, "pattern"))
129 valueKind_ = ValueKind::kPattern;
130 else if (streq(field, "real"))
131 valueKind_ = ValueKind::kReal;
132 else if (streq(field, "integer"))
133 valueKind_ = ValueKind::kInteger;
134 else if (streq(field, "complex"))
135 valueKind_ = ValueKind::kComplex;
136 else
137 MLIR_SPARSETENSOR_FATAL("Unexpected header field value in %s\n", filename);
138 // Set properties.
139 isSymmetric_ = streq(symmetry, "symmetric");
140 // Make sure this is a general sparse matrix.
141 if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
142 strne(format, "coordinate") ||
143 (strne(symmetry, "general") && !isSymmetric_))
144 MLIR_SPARSETENSOR_FATAL("Cannot find a general sparse matrix in %s\n",
145 filename);
146 // Skip comments.
147 while (true) {
148 readLine();
149 if (line[0] != '%')
150 break;
152 // Next line contains M N NNZ.
153 idata[0] = 2; // rank
154 if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
155 idata + 1) != 3)
156 MLIR_SPARSETENSOR_FATAL("Cannot find size in %s\n", filename);
159 /// Read the "extended" FROSTT header. Although not part of the documented
160 /// format, we assume that the file starts with optional comments followed
161 /// by two lines that define the rank, the number of nonzeros, and the
162 /// dimensions sizes (one per rank) of the sparse tensor.
163 void SparseTensorReader::readExtFROSTTHeader() {
164 // Skip comments.
165 while (true) {
166 readLine();
167 if (line[0] != '#')
168 break;
170 // Next line contains RANK and NNZ.
171 if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
172 MLIR_SPARSETENSOR_FATAL("Cannot find metadata in %s\n", filename);
173 // Followed by a line with the dimension sizes (one per rank).
174 for (uint64_t r = 0; r < idata[0]; ++r)
175 if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
176 MLIR_SPARSETENSOR_FATAL("Cannot find dimension size %s\n", filename);
177 readLine(); // end of line
178 // The FROSTT format does not define the data type of the nonzero elements.
179 valueKind_ = ValueKind::kUndefined;