Import code from my Subversion repository
[black_box_cellml.git] / tests / BlackBoxesForVariants.cpp
blob 0315c9c1af56ae48a2e42a1f36b4e2e1780ec5e7
#include "Utilities.hxx"
#include "IfaceBlaBoC.hxx"
#include "BlaBoCBootstrap.hpp"
#include "CellMLBootstrap.hpp"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <mysql++/mysql++.h>
#include <mysql++/transaction.h>
11 class BlackBoxesForVariants
13 public:
14 BlackBoxesForVariants()
15 : mInputIndex(0), mOutputIndex(0)
17 mBBS = already_AddRefd<iface::blaboc_api::BlaBoCBootstrap>
18 (createBlaBoCBootstrap());
19 mCBS = already_AddRefd<iface::cellml_api::CellMLBootstrap>
20 (CreateCellMLBootstrap());
21 mML = already_AddRefd<iface::cellml_api::DOMModelLoader>
22 (mCBS->modelLoader());
25 void
26 SetConnection
28 mysqlpp::Connection& aConn
31 mConn = &aConn;
32 mysqlpp::Query q(mConn->query());
33 q << "SELECT id, component, variable, dirtype FROM variables";
34 try
36 mVariables = q.store();
38 catch (std::exception& e)
40 printf("Error getting variables: %s\n", e.what());
44 void
45 DoVariant
47 uint32_t aVariantId,
48 const char* aVariantName
51 std::string inputFn(aVariantName);
52 std::string outputFn(aVariantName);
54 inputFn += ".cellml";
55 outputFn += ".obj";
57 mOutputIndex = 0;
58 mInputIndex = 0;
60 printf("Generating %s from %s\n", outputFn.c_str(), inputFn.c_str());
62 char* doingWhat;
64 try
66 doingWhat = "Loading the model";
67 uint32_t l = inputFn.length() + 1;
68 wchar_t data[l];
69 mbstowcs(data, inputFn.c_str(), l);
70 data[l - 1] = 0;
71 RETURN_INTO_OBJREF(m, iface::cellml_api::Model, mML->loadFromURL(data));
72 doingWhat = "Creating BlaBoC";
73 RETURN_INTO_OBJREF(bb, iface::blaboc_api::BlaBoC, mBBS->createBlaBoC(m));
75 mysqlpp::Transaction trans(*mConn);
77 mysqlpp::Query q(mConn->query());
78 q << "UPDATE training_data SET array_index=%0q "
79 "WHERE variable_id=%1q "
80 "AND variant_id=%2q;";
81 q.parse();
83 // Now we need to assign inputs and outputs...
84 uint32_t i = 0;
85 l = mVariables.rows();
86 for (; i < l; i++)
88 mysqlpp::Row r = mVariables.at(i);
89 const char* comp8 = r["component"];
90 uint32_t comp_l = strlen(comp8);
91 wchar_t comp[comp_l + 1];
92 mbstowcs(comp, comp8, comp_l + 1);
94 const char* var8 = r["variable"];
95 uint32_t var_l = strlen(var8);
96 wchar_t var[var_l + 1];
97 mbstowcs(var, var8, var_l + 1);
99 RETURN_INTO_OBJREF(mc, iface::cellml_api::CellMLComponentSet,
100 m->modelComponents());
101 RETURN_INTO_OBJREF(c, iface::cellml_api::CellMLComponent,
102 mc->getComponent(comp));
104 // Ignore it if it is knocked out in this model...
105 if (c == NULL)
106 continue;
108 RETURN_INTO_OBJREF(vs, iface::cellml_api::CellMLVariableSet,
109 c->variables());
110 RETURN_INTO_OBJREF(v, iface::cellml_api::CellMLVariable,
111 vs->getVariable(var));
112 if (v == NULL)
113 continue;
115 uint32_t idx;
116 int dt = r["dirtype"];
117 if (dt == 0)
119 idx = mInputIndex++;
120 bb->variableIsInput(v);
122 else
124 idx = mOutputIndex++;
125 bb->variableIsOutput(v);
130 uint32_t id = r["id"];
131 q.execute(mysqlpp::SQLString(idx), mysqlpp::SQLString(id),
132 mysqlpp::SQLString(aVariantId));
134 catch (std::exception& e)
136 printf("Error updating training array index: %s\n", e.what());
140 trans.commit();
142 doingWhat = "Computing the BlaBoC code";
146 bb->computeBlaBoC();
148 catch (iface::blaboc_api::BlaBoCException& bbe)
152 RETURN_INTO_WSTRING(em, bb->errorMessage());
153 if (em != L"")
155 printf("Error message: %S\n", em.c_str());
157 else
159 RETURN_INTO_WSTRING(ci, bb->classImplementation());
160 uint32_t len;
162 char* bb = mBBS->assembleBlaBoC(ci.c_str(), &len);
163 std::string code(bb, len);
164 free(bb);
166 std::ofstream obj(outputFn.c_str());
167 obj << code;
170 catch (...)
172 printf("Failure: %s.\n", doingWhat);
173 return;
177 private:
178 ObjRef<iface::blaboc_api::BlaBoCBootstrap> mBBS;
179 ObjRef<iface::cellml_api::CellMLBootstrap> mCBS;
180 ObjRef<iface::cellml_api::DOMModelLoader> mML;
181 mysqlpp::Connection* mConn;
182 mysqlpp::Result mVariables;
183 uint32_t mInputIndex, mOutputIndex;
187 main(int argc, char** argv)
189 if (argc < 5)
191 printf("Usage: BlackBoxesForVariants db_host db_user db_password db_database\n");
192 return 1;
195 mysqlpp::Connection conn(mysqlpp::use_exceptions);
196 mysqlpp::Connection conn1(mysqlpp::use_exceptions);
200 conn.connect(argv[4], argv[1], argv[2], argv[3]);
201 conn1.connect(argv[4], argv[1], argv[2], argv[3]);
203 catch (std::exception& e)
205 printf("Error connecting to database: %s\n", e.what());
206 return 2;
209 mysqlpp::Query query(conn.query());
211 BlackBoxesForVariants bbfv;
215 bbfv.SetConnection(conn1);
217 mysqlpp::ResUse r = query.store("SELECT id, name FROM variants");
221 while (true)
223 mysqlpp::Row row = r.fetch_row();
224 bbfv.DoVariant(row["id"], row["name"]);
227 catch (mysqlpp::EndOfResults& eor)
231 catch (std::exception& e)
233 printf("Error processing data: %s\n", e.what());
234 return 2;