#include "TrainingDatabase.hpp"

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <vector>

#include <mysql++/mysql++.h>
// Exception type thrown by DatabaseBuilder for file-handling and data
// validation failures. It derives from std::exception, so a single
// `catch (std::exception&)` (as in main) reports it.
// NOTE(review): the `class DBBException` line, the stored message and any
// what() override are on lines not visible in this excerpt. Confirm that
// what() really returns aWhat; otherwise main() prints only the
// implementation-defined std::exception text.
6 : public std::exception
// Creates the exception with a human-readable message describing the failure.
9 DBBException(const std::string
& aWhat
)
// Declared throw() because, before C++11, an overriding destructor may not
// have a looser exception specification than std::exception's destructor.
14 ~DBBException() throw()
31 DatabaseBuilder(const char* sql_host
, const char* sql_user
,
32 const char* sql_password
, const char* sql_database
,
33 const char* db_filename
)
34 : mFile(NULL
), mConn(mysqlpp::use_exceptions
)
36 mFile
= fopen(db_filename
, "w");
38 throw DBBException("Cannot open the database input file.");
40 mConn
.connect(sql_database
, sql_host
, sql_user
, sql_password
);
52 mysqlpp::Query
q(mConn
.query());
53 q
<< "SELECT id FROM variants ORDER BY id";
54 mysqlpp::Result
r(q
.store());
56 struct TrainingDatabaseHeader tdh
;
58 tdh
.first_variant_id
= r
.at(0)["id"];
59 tdh
.nvariants
= r
.size();
62 q
<< "SELECT v.dirtype, COUNT(*) FROM training_data AS td, variables AS v "
63 "WHERE v.id = td.variable_id AND td.variant_id="
64 << tdh
.first_variant_id
<< " GROUP BY v.dirtype";
65 mysqlpp::Result
r2(q
.store());
67 tdh
.ninput_signals
= 0;
68 tdh
.noutput_signals
= 0;
69 uint32_t i
, l
= r2
.size();
71 for (i
= 0; i
< l
; i
++)
73 uint32_t type
= r2
.at(i
).at(0);
75 tdh
.ninput_signals
= r2
.at(i
).at(1);
77 tdh
.noutput_signals
= r2
.at(i
).at(1);
80 if (fwrite(&tdh
, sizeof(tdh
), 1, mFile
) != 1)
81 throw DBBException(strerror(errno
));
84 uint32_t mExpectVariant
= tdh
.first_variant_id
;
85 double inputs
[tdh
.ninput_signals
+ 1], outputs
[tdh
.noutput_signals
];
86 double * iotab
[] = {inputs
, outputs
};
87 uint32_t iolims
[] = {tdh
.ninput_signals
, tdh
.noutput_signals
};
90 for (i
= 0; i
< l
; i
++)
92 uint32_t variant
= r
.at(i
).at(0);
94 if (variant
!= mExpectVariant
++)
95 throw DBBException("Expected consecutive variant IDs.");
97 printf("Building row for variant %u\n", variant
);
100 q
<< "SELECT td.value, td.array_index, v.dirtype FROM training_data AS "
101 "td, variables AS v WHERE td.variable_id = v.id AND td.variant_id = "
103 mysqlpp::ResUse
ru(q
.use());
109 mysqlpp::Row
row(ru
.fetch_row());
110 uint32_t io
= row
["dirtype"];
111 uint32_t idx
= row
["array_index"];
115 throw DBBException("dirtype field invalid.");
117 else if (idx
>= iolims
[io
])
119 throw DBBException("array_index field invalid.");
122 iotab
[io
][idx
] = row
["value"];
125 catch (mysqlpp::EndOfResults
& eor
)
129 if (fwrite(inputs
, tdh
.ninput_signals
* sizeof(double), 1, mFile
) < 0)
132 throw DBBException("Error writing the input signals.");
135 if (fwrite(outputs
, tdh
.noutput_signals
* sizeof(double), 1, mFile
) < 0)
138 throw DBBException("Error writing the output signals.");
// Database connection used by the constructor for every query. It is built
// with mysqlpp::use_exceptions, so mysql++ reports failures by throwing.
145 mysqlpp::Connection mConn
;
// Command-line entry point:
//   BuildTrainingDatabase sql_host sql_user sql_password sql_database db_filename
// Passes argv[1..5] straight to DatabaseBuilder, whose constructor performs
// the whole export, and reports any std::exception as "Error: <what()>".
// NOTE(review): the return type, the argc check guarding the usage message,
// the try keyword and the exit codes are on lines not visible here.
149 main(int argc
, char** argv
)
// Usage text, presumably printed when argc != 6. Confirm the guard.
153 printf("Usage: BuildTrainingDatabase sql_host sql_user sql_password sql_database db_filename\n");
// Constructing the builder runs the complete export.
159 DatabaseBuilder
dbb(argv
[1], argv
[2], argv
[3], argv
[4], argv
[5]);
// Catches DBBException (which derives from std::exception) and any other
// std::exception-derived error.
// NOTE(review): catching by `const std::exception&` would be preferable.
162 catch (std::exception
& e
)
164 printf("Error: %s\n", e
.what());