Import code from my Subversion repository
[black_box_cellml.git] / runtime / ComputeAverageRSS.cpp
blob 8f3fdb631a8a7d0cd59ee575f9a8c134702d20a0
#include "TrainingDatabase.hpp"

#include <errno.h>
#include <fcntl.h>
#include <stdint.h>
#include <unistd.h>

#include <cstdio>
#include <exception>
#include <string>
#include <vector>

#include <mysql++/mysql++.h>

#include "Utilities.hxx"
/// Exception thrown for every failure detected while computing the
/// leave-one-out average RSS (unreadable training file, variant id out of
/// range, ...). what() returns the message given to the constructor.
class CompAvgException
  : public std::exception
{
public:
  /// @param aWhat human-readable description of the failure.
  explicit CompAvgException(const std::string& aWhat)
    : mWhat(aWhat)
  {
  }

  ~CompAvgException() throw()
  {
  }

  /// @return the constructor's message; valid for this object's lifetime.
  const char* what() const throw()
  {
    return mWhat.c_str();
  }

private:
  std::string mWhat;
};
29 class AverageComputer
31 public:
32 AverageComputer(mysqlpp::Connection& aConn, const char* aFilename)
33 : mConn(aConn), mAnswers(NULL), mSum(NULL), mTemp(NULL)
35 mFile = open(aFilename, O_RDONLY | O_LARGEFILE);
36 if (mFile == -1)
37 throw CompAvgException("Can't open training file.");
39 if (read(mFile, &mTdh, sizeof(TrainingDatabaseHeader)) != sizeof(TrainingDatabaseHeader))
40 throw CompAvgException("Can't read the training file.");
42 mJump = mTdh.ninput_signals * sizeof(double);
44 mAnswers = new double[mTdh.noutput_signals];
45 mSum = new double[mTdh.noutput_signals];
46 mTemp = new double[mTdh.noutput_signals];
49 ~AverageComputer()
51 if (mFile != -1)
52 close(mFile);
53 if (mAnswers)
54 delete [] mAnswers;
55 if (mTemp)
56 delete [] mTemp;
57 if (mSum)
58 delete [] mSum;
61 void
62 computeRSSForAllVariants()
64 uint32_t v = mTdh.first_variant_id;
65 uint32_t lv = v + mTdh.nvariants;
67 for (; v < lv; v++)
69 double RSS = computeRSSFromAverageExcluding(v);
70 mysqlpp::Query q(mConn.query());
71 q << "UPDATE crossval_rss SET average_rss='" << RSS
72 << "' WHERE leftout_variant_id='" << v << "'";
73 q.execute();
77 double
78 computeRSSFromAverageExcluding(uint32_t aExclude)
80 if (aExclude < mTdh.first_variant_id)
81 throw CompAvgException("Excluded variant out of range.");
83 aExclude -= mTdh.first_variant_id;
84 if (aExclude >= mTdh.nvariants)
85 throw CompAvgException("Excluded variant out of range.");
87 lseek64(mFile, sizeof(mTdh), SEEK_SET);
88 memset(mSum, 0, sizeof(double) * mTdh.noutput_signals);
90 uint32_t i(0);
91 for (; i < mTdh.nvariants; i++)
93 // Skip the inputs...
94 if (mJump != 0)
95 lseek64(mFile, mJump, SEEK_CUR);
97 double * into = (aExclude == i) ? mAnswers : mTemp;
99 if (read(mFile, into, sizeof(double) * mTdh.noutput_signals) !=
100 sizeof(double) * mTdh.noutput_signals)
101 throw CompAvgException("Couldn't load signals.");
103 if (into == mTemp)
105 uint32_t j;
106 for (j = 0; j < mTdh.noutput_signals; j++)
107 mSum[j] += mTemp[j];
111 double factor = 1.0 / (mTdh.nvariants - 1);
112 double RSS = 0.0;
113 for (i = 0; i < mTdh.noutput_signals; i++)
115 double t = mAnswers[i] - factor * mSum[i];
116 RSS += t * t;
119 return RSS;
122 private:
123 int mFile;
124 TrainingDatabaseHeader mTdh;
125 mysqlpp::Connection& mConn;
126 double* mAnswers, * mSum, * mTemp;
127 off64_t mJump;
131 main(int argc, char** argv)
133 if (argc < 6)
135 printf("Usage: ComputeAverageRSS db_host db_user db_password db_database trainingfile\n");
136 // printf("Set variant=-1 if RSS is sum for all variants.\n");
137 return 1;
142 mysqlpp::Connection conn(mysqlpp::use_exceptions);
143 conn.connect(argv[4], argv[1], argv[2], argv[3]);
145 AverageComputer ac(conn, argv[5]);
146 ac.computeRSSForAllVariants();
148 catch (std::exception& e)
150 printf("Error: %s\n", e.what());