Import code from my Subversion repository
[black_box_cellml.git] / runtime / ComputeAverageRMS.cpp
blob2e528b3293d40a2e1c49a7736d668ec9bceba62d
1 #include "TrainingDatabase.hpp"
2 #include <string>
3 #include <fcntl.h>
4 #include "Utilities.hxx"
5 #include <mysql++/mysql++.h>
/**
 * Exception type carrying a human-readable message.
 *
 * The message given to the constructor is copied into the object and handed
 * back unchanged by what(), so it can be caught and reported through the
 * std::exception interface.
 */
class CompAvgException
  : public std::exception
{
public:
  /**
   * @param aWhat Description of the failure; copied into the exception.
   *
   * Explicit so that a std::string is never silently converted into an
   * exception object.
   */
  explicit CompAvgException(const std::string& aWhat)
    : mWhat(aWhat)
  {
  }

  ~CompAvgException() throw()
  {
  }

  /**
   * @return The message supplied at construction. The pointer refers to
   *         storage owned by this object and is valid only while it lives.
   */
  const char* what() const throw()
  {
    return mWhat.c_str();
  }

private:
  std::string mWhat;
};
29 class AverageComputer
31 public:
32 AverageComputer(mysqlpp::Connection& aConn, const char* aFilename)
33 : mConn(aConn), mAnswers(NULL), mSum(NULL), mTemp(NULL)
35 mFile = open(aFilename, O_RDONLY | O_LARGEFILE);
36 if (mFile == -1)
37 throw CompAvgException("Can't open training file.");
39 if (read(mFile, &mTdh, sizeof(TrainingDatabaseHeader)) != sizeof(TrainingDatabaseHeader))
40 throw CompAvgException("Can't read the training file.");
42 mJump = mTdh.ninput_signals * sizeof(double);
44 mAnswers = new double[mTdh.noutput_signals];
45 mSum = new double[mTdh.noutput_signals];
46 mTemp = new double[mTdh.noutput_signals];
49 ~AverageComputer()
51 if (mFile != -1)
52 close(mFile);
53 if (mAnswers)
54 delete [] mAnswers;
55 if (mTemp)
56 delete [] mTemp;
57 if (mSum)
58 delete [] mSum;
61 void
62 computeRMSForAllVariants()
64 uint32_t v = mTdh.first_variant_id;
65 uint32_t lv = v + mTdh.nvariants;
67 for (; v < lv; v++)
69 double RMS = computeRMSFromAverageExcluding(v);
70 mysqlpp::Query q(mConn.query());
71 q << "UPDATE crossval_rss SET average_rss='" << RMS
72 << "' WHERE leftout_variant_id='" << v << "'";
73 q.execute();
77 double
78 computeRMSFromAverageExcluding(uint32_t aExclude)
80 if (aExclude < mTdh.first_variant_id)
81 throw CompAvgException("Excluded variant out of range.");
83 aExclude -= mTdh.first_variant_id;
84 if (aExclude >= mTdh.nvariants)
85 throw CompAvgException("Excluded variant out of range.");
87 lseek64(mFile, sizeof(mTdh), SEEK_SET);
88 memset(mSum, 0, sizeof(double) * mTdh.noutput_signals);
90 uint32_t i(0);
91 for (; i < mTdh.nvariants; i++)
93 // Skip the inputs...
94 if (mJump != 0)
95 lseek64(mFile, mJump, SEEK_CUR);
97 double * into = (aExclude == i) ? mAnswers : mTemp;
99 if (read(mFile, into, sizeof(double) * mTdh.noutput_signals) !=
100 sizeof(double) * mTdh.noutput_signals)
101 throw CompAvgException("Couldn't load signals.");
103 if (into == mTemp)
105 uint32_t j;
106 for (j = 0; j < mTdh.noutput_signals; j++)
107 mSum[j] += mTemp[j];
111 double factor = 1.0 / (mTdh.nvariants - 1);
113 // RSS isn't robust against a few outliers, so use residual median of
114 // squares instead.
115 std::vector<double> residualsSq(mTdh.noutput_signals);
116 for (i = 0; i < mTdh.noutput_signals; i++)
118 double t = mAnswers[i] - factor * mSum[i];
119 residualsSq[i] = t * t;
121 std::sort(residualsSq.begin(), residualsSq.end());
123 double rms;
125 if ((mTdh.noutput_signals) & 1 == 1)
127 rms = residualsSq[mTdh.noutput_signals / 2];
129 else
131 uint32_t ofs = mTdh.noutput_signals / 2;
132 rms = (residualsSq[ofs - 1] + residualsSq[ofs]) / 2;
135 return rms;
138 private:
139 int mFile;
140 TrainingDatabaseHeader mTdh;
141 mysqlpp::Connection& mConn;
142 double* mAnswers, * mSum, * mTemp;
143 off64_t mJump;
147 main(int argc, char** argv)
149 if (argc < 6)
151 printf("Usage: ComputeAverageRMS db_host db_user db_password db_database trainingfile\n");
152 // printf("Set variant=-1 if RMS is sum for all variants.\n");
153 return 1;
158 mysqlpp::Connection conn(mysqlpp::use_exceptions);
159 conn.connect(argv[4], argv[1], argv[2], argv[3]);
161 AverageComputer ac(conn, argv[5]);
162 ac.computeRMSForAllVariants();
164 catch (std::exception& e)
166 printf("Error: %s\n", e.what());