Import code from my Subversion repository
[black_box_cellml.git] / runtime / BlackBoxRuntime.cpp
bloba7f339c4db0b7d5a001ad26ed9bfa4e5e45048bd
1 #include "BlackBoxRuntime.hpp"
2 #include <mysql++/mysql++.h>
4 BlackBoxCache* gBlackBoxCache = NULL;
/*
 * Unicode helpers for BBArun_wchar_to_UTF8().
 *
 * The conversion works on whole code points so that it does not depend on
 * sizeof(wchar_t): a high/low surrogate pair (as produced by a 16-bit UTF-16
 * wchar_t) is combined into one supplementary code point, while a 32-bit
 * wchar_t value is used as a code point directly.
 */

/* Emitted in place of wchar_t values that lie outside the Unicode range. */
static const uint32_t BBARUN_REPLACEMENT_CHARACTER = 0xFFFD;

/*
 * Decodes the code point starting at aCursor and advances aCursor past the
 * wchar_t unit(s) consumed. Must not be called with aCursor on the NUL.
 *
 * A high surrogate is combined only with an immediately following low
 * surrogate; an unpaired surrogate is returned unchanged. The terminating
 * NUL is never consumed, so malformed input cannot cause a read past the end.
 */
static uint32_t
BBArun_next_code_point(const wchar_t*& aCursor)
{
  // Going through uint32_t makes a negative (signed 32-bit) wchar_t land
  // above 0x10FFFF, where it is replaced below.
  uint32_t cp = (uint32_t)*aCursor++;

  if (cp >= 0xD800 && cp <= 0xDBFF)
  {
    uint32_t low = (uint32_t)*aCursor;
    if (low >= 0xDC00 && low <= 0xDFFF)
    {
      aCursor++;
      return 0x10000 + ((cp - 0xD800) << 10) + (low - 0xDC00);
    }
  }

  return (cp > 0x10FFFF) ? BBARUN_REPLACEMENT_CHARACTER : cp;
}

/* Number of UTF-8 bytes needed for aCP (aCP must be <= 0x10FFFF). */
static size_t
BBArun_utf8_length(uint32_t aCP)
{
  if (aCP <= 0x7F)
    return 1;
  if (aCP <= 0x7FF)
    return 2;
  if (aCP <= 0xFFFF)
    return 3;
  return 4;
}

/* Writes the UTF-8 encoding of aCP at aOut and returns the next free slot. */
static char*
BBArun_put_utf8(char* aOut, uint32_t aCP)
{
  if (aCP <= 0x7F)
  {
    *aOut++ = (char)aCP;
  }
  else if (aCP <= 0x7FF)
  {
    *aOut++ = (char)(0xC0 | (aCP >> 6));
    *aOut++ = (char)(0x80 | (aCP & 0x3F));
  }
  else if (aCP <= 0xFFFF)
  {
    // Lead byte of a 3-byte sequence is 1110xxxx (0xE0).
    *aOut++ = (char)(0xE0 | (aCP >> 12));
    *aOut++ = (char)(0x80 | ((aCP >> 6) & 0x3F));
    *aOut++ = (char)(0x80 | (aCP & 0x3F));
  }
  else
  {
    *aOut++ = (char)(0xF0 | (aCP >> 18));
    *aOut++ = (char)(0x80 | ((aCP >> 12) & 0x3F));
    *aOut++ = (char)(0x80 | ((aCP >> 6) & 0x3F));
    *aOut++ = (char)(0x80 | (aCP & 0x3F));
  }
  return aOut;
}

/*
 * Converts the NUL-terminated wide string str to a newly malloc()ed,
 * NUL-terminated UTF-8 string. The caller owns the result and must free()
 * it. Returns NULL if the allocation fails.
 */
char*
BBArun_wchar_to_UTF8(const wchar_t *str)
{
  // Pass 1: exact output size. It uses the same decoder as pass 2, so the
  // buffer can never be overrun whatever the width of wchar_t.
  size_t len = 0;
  for (const wchar_t* p = str; *p != 0; )
    len += BBArun_utf8_length(BBArun_next_code_point(p));

  char* newData = (char*)malloc(len + 1);
  if (newData == NULL)
    return NULL;

  // Pass 2: encode.
  char* np = newData;
  for (const wchar_t* p = str; *p != 0; )
    np = BBArun_put_utf8(np, BBArun_next_code_point(p));
  *np = 0;

  return newData;
}
76 class SQLModelPersister
77 : public BlackBoxPersistBase
79 public:
80 SQLModelPersister()
81 : mConn(mysqlpp::use_exceptions), mReady(false)
83 const char* db = getenv("BLABOC_DB");
84 if (db == NULL)
86 printf("Missing BLABOC_DB. Not creating persister.\n");
87 return;
89 const char* host = getenv("BLABOC_HOST");
90 if (host == NULL)
92 printf("Missing BLABOC_HOST. Not creating persister.\n");
93 return;
95 const char* user = getenv("BLABOC_USER");
96 if (user == NULL)
98 printf("Missing BLABOC_USER. Not creating persister.\n");
99 return;
101 const char* pass = getenv("BLABOC_PASSWORD");
102 if (pass == NULL)
104 printf("Missing BLABOC_PASSWORD. Not creating persister.\n");
105 return;
110 mConn.connect(db, host, user, pass);
111 mReady = true;
113 catch (std::exception& e)
115 printf("Error connecting to database: %s. Not creating persister.\n",
116 e.what());
120 ~SQLModelPersister()
124 void
125 saveState(const char* aURL, const std::string& aData)
127 if (!mReady)
128 return;
130 // Figure out if we need to update or insert...
131 mysqlpp::Query query(mConn.query());
133 query << "SELECT COUNT(*) FROM black_box_state WHERE url=\""
134 << mysqlpp::escape << aURL << "\"";
138 mysqlpp::Result res = query.store();
139 unsigned int count = res.at(0)[(unsigned int)0];
140 query.reset();
141 if (count == 0)
143 query << "INSERT INTO black_box_state (url,data) VALUES (\""
144 << mysqlpp::escape << aURL << "\", \""
145 << mysqlpp::escape << aData << "\")";
146 query.execute();
148 else
150 query << "UPDATE black_box_state SET data=\""
151 << mysqlpp::escape << aData << "\" WHERE url=\""
152 << mysqlpp::escape << aURL << "\"";
153 query.execute();
156 catch (std::exception& e)
158 printf("Exception saving state: %s\n", e.what());
162 void
163 restoreState(const char* aURL, std::string& aData)
165 if (!mReady)
167 aData.assign("");
168 return;
171 mysqlpp::Query query(mConn.query());
172 query << "SELECT data FROM black_box_state WHERE url=\""
173 << mysqlpp::escape << aURL << "\"";
177 mysqlpp::Result res = query.store();
178 if (res.rows() == 0)
180 aData.assign("");
181 return;
183 mysqlpp::ColData col = res.at(0)["data"];
184 aData.assign(col.data(), col.size());
186 catch (std::exception& e)
188 printf("Exception restoring state: %s\n", e.what());
189 aData.assign("");
193 private:
194 mysqlpp::Connection mConn;
195 bool mReady;
198 MixedModelBase::MixedModelBase()
199 : mPersister(new SQLModelPersister()),
200 VARIABLES(NULL), INPUTS(NULL), OUTPUTS(NULL),
201 CONSTANTS(NULL), TMPINPUTS(NULL), TMPOUTPUTS(NULL),
202 blackBoxes(NULL), blackBoxNames(NULL),
203 BACKVARIABLES(NULL), BACKOUTPUTS(NULL)
207 MixedModelBase::~MixedModelBase()
209 size_t i;
211 if (BACKVARIABLES != NULL)
212 delete [] BACKVARIABLES;
214 if (BACKOUTPUTS != NULL)
215 delete [] BACKOUTPUTS;
217 if (blackBoxes != NULL)
219 std::string state;
221 if (gBlackBoxCache == NULL)
223 for (i = 0; i < nBlackBoxes; i++)
224 if (blackBoxes[i] != NULL)
226 blackBoxes[i]->Serialise(state);
227 if (state.size() != 0)
228 mPersister->saveState(blackBoxNames[i].c_str(), state);
229 delete blackBoxes[i];
232 delete [] blackBoxes;
235 delete mPersister;
237 if (blackBoxNames != NULL)
238 delete [] blackBoxNames;
241 bool
242 MixedModelBase::runConverged()
244 if (mConvergeSteps++ > 10)
246 printf("Warning: It took more than 10 steps for the black-boxes to "
247 "converge within tolerance. Quitting early.\n");
248 return true;
251 uint32_t i, l = countOtherVariables();
252 double RSS = 0, t;
254 #if 0
255 printf("Computing RSS with...\n");
256 for (i = 0; i < countOtherVariables(); i++)
257 printf(" * OTHERVARIABLES[%u] = %g\n", i, VARIABLES[i]);
258 for (i = 0; i < countOutputVariables(); i++)
259 printf(" * OUTPUTS[%u] = %g\n", i, OUTPUTS[i]);
260 printf("Listing complete.\n");
261 #endif
263 for (i = 0; i < l; i++)
265 t = VARIABLES[i] - BACKVARIABLES[i];
266 RSS += t * t;
268 l = countOutputVariables();
269 for (i = 0; i < l; i++)
271 t = OUTPUTS[i] - BACKOUTPUTS[i];
272 RSS += t * t;
275 #if 0
276 printf("Step %u: RSS = %g\n", mConvergeSteps, RSS);
277 #endif
279 if (RSS < THRESHOLD)
280 return true;
282 return false;
285 void
286 MixedModelBase::createBlackBox
288 uint32_t aIdx,
289 const wchar_t* URI,
290 uint32_t inputs,
291 uint32_t outputs
294 if (gBlackBoxCache)
296 BlackBoxModelBase* cachedModel = gBlackBoxCache->FindModelInCache(URI);
297 if (cachedModel != NULL)
299 blackBoxes[aIdx] = cachedModel;
301 char* utf8uri = BBArun_wchar_to_UTF8(URI);
302 blackBoxNames[aIdx] = utf8uri;
303 free(utf8uri);
305 return;
309 const wchar_t* offs = wcschr(URI, '#');
310 std::wstring name, id;
311 if (offs == NULL)
312 name = URI;
313 else
314 name = std::wstring(URI, offs - URI);
316 if (offs != NULL)
317 id = offs + 1;
319 std::map<std::wstring, BlackBoxModelFactory*>::iterator i
320 = blackBoxFactories().find(name);
322 if (i == blackBoxFactories().end())
324 printf("Warning: Cannot create a black box with URI base %S because "
325 "no factory could be found.\n",
326 name.c_str());
327 throw std::exception();
330 blackBoxes[aIdx] = (*i).second->createBlackBoxModel(id, inputs, outputs);
331 char* utf8uri = BBArun_wchar_to_UTF8(URI);
332 blackBoxNames[aIdx] = utf8uri;
333 free(utf8uri);
335 std::string data;
336 mPersister->restoreState(blackBoxNames[aIdx].c_str(), data);
337 if (data != "")
338 blackBoxes[aIdx]->Deserialise(data);
340 if (gBlackBoxCache != NULL)
341 gBlackBoxCache->CacheModel(URI, blackBoxes[aIdx]);
344 std::map<std::wstring,BlackBoxModelFactory*>* MixedModelBase::sBlackBoxFactories = NULL;
346 class StandardBlackBoxCache
347 : public BlackBoxCache
349 public:
350 StandardBlackBoxCache()
354 ~StandardBlackBoxCache()
356 DiscardCache();
359 BlackBoxModelBase*
360 FindModelInCache(const std::wstring& aURI)
362 std::map<std::wstring, BlackBoxModelBase*>::iterator i = mModels.find(aURI);
364 if (i == mModels.end())
365 return NULL;
367 return (*i).second;
370 void
371 CacheModel(const std::wstring& aURI, BlackBoxModelBase* aModel)
373 mModels.insert(std::pair<std::wstring, BlackBoxModelBase*>(aURI, aModel));
376 void
377 SaveAllCachedModels()
379 std::map<std::wstring, BlackBoxModelBase*>::iterator i;
380 std::string state;
382 SQLModelPersister smp;
384 for (i = mModels.begin(); i != mModels.end(); i++)
386 (*i).second->Serialise(state);
387 if (state.size() != 0)
389 char* utf8uri = BBArun_wchar_to_UTF8((*i).first.c_str());
390 smp.saveState(utf8uri, state);
391 free(utf8uri);
396 void
397 DiscardCache()
399 std::map<std::wstring, BlackBoxModelBase*>::iterator i;
400 for (i = mModels.begin(); i != mModels.end(); i++)
401 delete (*i).second;
402 mModels.clear();
404 private:
405 std::map<std::wstring, BlackBoxModelBase*> mModels;
408 void setupStandardBlackBoxCache()
410 gBlackBoxCache = new StandardBlackBoxCache();