Import code from my Subversion repository
[black_box_cellml.git] / runtime / BuildTrainingDatabase.cpp
blobacfa88ec6eb3d42a9e069746d4a12578ac101973
#include "TrainingDatabase.hpp"

#include <sys/errno.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <exception>
#include <string>
#include <vector>

#include <mysql++/mysql++.h>
/**
 * Exception thrown by the training-database builder to report a failure.
 *
 * It stores its own copy of the message, so the pointer returned by what()
 * stays valid for as long as the exception object exists.
 */
class DBBException
  : public std::exception
{
public:
  /// @param aWhat Human-readable description of the failure.
  // explicit: a plain std::string should never silently turn into an
  // exception object.
  explicit DBBException(const std::string& aWhat)
    : mWhat(aWhat)
  {
  }

  ~DBBException() throw()
  {
  }

  /// @return The message passed to the constructor.
  const char*
  what() const throw()
  {
    return mWhat.c_str();
  }

private:
  std::string mWhat;
};
28 class DatabaseBuilder
30 public:
31 DatabaseBuilder(const char* sql_host, const char* sql_user,
32 const char* sql_password, const char* sql_database,
33 const char* db_filename)
34 : mFile(NULL), mConn(mysqlpp::use_exceptions)
36 mFile = fopen(db_filename, "w");
37 if (mFile == NULL)
38 throw DBBException("Cannot open the database input file.");
40 mConn.connect(sql_database, sql_host, sql_user, sql_password);
43 ~DatabaseBuilder()
45 if (mFile != NULL)
46 fclose(mFile);
49 void
50 BuildDatabase()
52 mysqlpp::Query q(mConn.query());
53 q << "SELECT id FROM variants ORDER BY id";
54 mysqlpp::Result r(q.store());
56 struct TrainingDatabaseHeader tdh;
58 tdh.first_variant_id = r.at(0)["id"];
59 tdh.nvariants = r.size();
61 q.reset();
62 q << "SELECT v.dirtype, COUNT(*) FROM training_data AS td, variables AS v "
63 "WHERE v.id = td.variable_id AND td.variant_id="
64 << tdh.first_variant_id << " GROUP BY v.dirtype";
65 mysqlpp::Result r2(q.store());
67 tdh.ninput_signals = 0;
68 tdh.noutput_signals = 0;
69 uint32_t i, l = r2.size();
71 for (i = 0; i < l; i++)
73 uint32_t type = r2.at(i).at(0);
74 if (type == 0)
75 tdh.ninput_signals = r2.at(i).at(1);
76 else
77 tdh.noutput_signals = r2.at(i).at(1);
80 if (fwrite(&tdh, sizeof(tdh), 1, mFile) != 1)
81 throw DBBException(strerror(errno));
83 l = r2.size();
84 uint32_t mExpectVariant = tdh.first_variant_id;
85 double inputs[tdh.ninput_signals + 1], outputs[tdh.noutput_signals];
86 double * iotab[] = {inputs, outputs};
87 uint32_t iolims[] = {tdh.ninput_signals, tdh.noutput_signals};
89 l = r.size();
90 for (i = 0; i < l; i++)
92 uint32_t variant = r.at(i).at(0);
94 if (variant != mExpectVariant++)
95 throw DBBException("Expected consecutive variant IDs.");
97 printf("Building row for variant %u\n", variant);
99 q.reset();
100 q << "SELECT td.value, td.array_index, v.dirtype FROM training_data AS "
101 "td, variables AS v WHERE td.variable_id = v.id AND td.variant_id = "
102 << variant;
103 mysqlpp::ResUse ru(q.use());
107 while (true)
109 mysqlpp::Row row(ru.fetch_row());
110 uint32_t io = row["dirtype"];
111 uint32_t idx = row["array_index"];
113 if (io > 1)
115 throw DBBException("dirtype field invalid.");
117 else if (idx >= iolims[io])
119 throw DBBException("array_index field invalid.");
122 iotab[io][idx] = row["value"];
125 catch (mysqlpp::EndOfResults& eor)
129 if (fwrite(inputs, tdh.ninput_signals * sizeof(double), 1, mFile) < 0)
131 perror("fwrite");
132 throw DBBException("Error writing the input signals.");
135 if (fwrite(outputs, tdh.noutput_signals * sizeof(double), 1, mFile) < 0)
137 perror("fwrite");
138 throw DBBException("Error writing the output signals.");
143 private:
144 FILE* mFile;
145 mysqlpp::Connection mConn;
149 main(int argc, char** argv)
151 if (argc < 6)
153 printf("Usage: BuildTrainingDatabase sql_host sql_user sql_password sql_database db_filename\n");
154 return 0;
159 DatabaseBuilder dbb(argv[1], argv[2], argv[3], argv[4], argv[5]);
160 dbb.BuildDatabase();
162 catch (std::exception& e)
164 printf("Error: %s\n", e.what());