1 #include "TrainingDatabase.hpp"
4 #include "Utilities.hxx"
5 #include <mysql++/mysql++.h>
// Exception type thrown by this tool on I/O and range errors; it derives from
// std::exception.
// NOTE(review): the class head and all member bodies are missing from this
// extract (the embedded line numbers skip), so only the signatures are described.
8 : public std::exception
// Constructor taking the error message by const reference.
11 CompAvgException(const std::string
& aWhat
)
// Destructor declared non-throwing with the legacy throw() specification.
16 ~CompAvgException() throw()
// Overrides std::exception::what(); presumably returns the message passed to
// the constructor -- confirm against the (not visible) body.
20 const char* what() const throw()
// Constructs the computer from an open MySQL++ connection and the path of a
// training file.
//
// aConn     - connection stored by reference in mConn.
// aFilename - path of the training file, opened read-only with O_LARGEFILE.
//
// Reads a TrainingDatabaseHeader from the start of the file into mTdh, derives
// mJump from it, and allocates three arrays of noutput_signals doubles.
// Throws CompAvgException when the file cannot be opened or the header cannot
// be read in full.
32 AverageComputer(mysqlpp::Connection
& aConn
, const char* aFilename
)
// The three array pointers start out as NULL until the header is known.
33 : mConn(aConn
), mAnswers(NULL
), mSum(NULL
), mTemp(NULL
)
35 mFile
= open(aFilename
, O_RDONLY
| O_LARGEFILE
);
// NOTE(review): the condition guarding this throw (original line 36) is not
// visible in this extract; presumably it tests the result of open().
37 throw CompAvgException("Can't open training file.");
// A short read of the header is treated as a failure.
// NOTE(review): if this throws, no visible code closes mFile; a destructor is
// not run for an object whose constructor throws -- confirm there is no leak.
39 if (read(mFile
, &mTdh
, sizeof(TrainingDatabaseHeader
)) != sizeof(TrainingDatabaseHeader
))
40 throw CompAvgException("Can't read the training file.");
// Size in bytes of ninput_signals doubles; used later as a relative seek distance.
42 mJump
= mTdh
.ninput_signals
* sizeof(double);
// Three buffers, each holding noutput_signals doubles.
44 mAnswers
= new double[mTdh
.noutput_signals
];
45 mSum
= new double[mTdh
.noutput_signals
];
46 mTemp
= new double[mTdh
.noutput_signals
];
// Computes the leave-one-out value for variants in the header's id range and
// builds an UPDATE statement on crossval_rss for each.
// NOTE(review): the loop header, the return type, and the statement that
// executes the query are not visible in this extract.
62 computeRMSForAllVariants()
// v starts at the first variant id; lv is one past the last id
// (first_variant_id + nvariants).
64 uint32_t v
= mTdh
.first_variant_id
;
65 uint32_t lv
= v
+ mTdh
.nvariants
;
// Value obtained when variant v is the one excluded.
69 double RMS
= computeRMSFromAverageExcluding(v
);
70 mysqlpp::Query
q(mConn
.query());
// Both inserted values are numeric (a double and a uint32_t), streamed into
// the statement text between single quotes.
71 q
<< "UPDATE crossval_rss SET average_rss='" << RMS
72 << "' WHERE leftout_variant_id='" << v
<< "'";
78 computeRMSFromAverageExcluding(uint32_t aExclude
)
80 if (aExclude
< mTdh
.first_variant_id
)
81 throw CompAvgException("Excluded variant out of range.");
83 aExclude
-= mTdh
.first_variant_id
;
84 if (aExclude
>= mTdh
.nvariants
)
85 throw CompAvgException("Excluded variant out of range.");
87 lseek64(mFile
, sizeof(mTdh
), SEEK_SET
);
88 memset(mSum
, 0, sizeof(double) * mTdh
.noutput_signals
);
91 for (; i
< mTdh
.nvariants
; i
++)
95 lseek64(mFile
, mJump
, SEEK_CUR
);
97 double * into
= (aExclude
== i
) ? mAnswers
: mTemp
;
99 if (read(mFile
, into
, sizeof(double) * mTdh
.noutput_signals
) !=
100 sizeof(double) * mTdh
.noutput_signals
)
101 throw CompAvgException("Couldn't load signals.");
106 for (j
= 0; j
< mTdh
.noutput_signals
; j
++)
111 double factor
= 1.0 / (mTdh
.nvariants
- 1);
113 // RSS isn't robust against a few outliers, so use residual median of
115 std::vector
<double> residualsSq(mTdh
.noutput_signals
);
116 for (i
= 0; i
< mTdh
.noutput_signals
; i
++)
118 double t
= mAnswers
[i
] - factor
* mSum
[i
];
119 residualsSq
[i
] = t
* t
;
121 std::sort(residualsSq
.begin(), residualsSq
.end());
125 if ((mTdh
.noutput_signals
) & 1 == 1)
127 rms
= residualsSq
[mTdh
.noutput_signals
/ 2];
131 uint32_t ofs
= mTdh
.noutput_signals
/ 2;
132 rms
= (residualsSq
[ofs
- 1] + residualsSq
[ofs
]) / 2;
// Data members of the computer.
// NOTE(review): the declarations of mFile and mJump, both used by the member
// functions, are not visible in this extract.
// Header read from the start of the training file.
140 TrainingDatabaseHeader mTdh
;
// Database connection, held by reference (not owned by this object).
141 mysqlpp::Connection
& mConn
;
// Buffers of noutput_signals doubles each: the excluded variant's signals,
// the per-signal running sum, and a scratch buffer for the other records.
142 double* mAnswers
, * mSum
, * mTemp
;
// Program entry point: connects to the database, then computes and stores the
// value for every variant of the given training file.
// Expected arguments, per the usage text: db_host db_user db_password
// db_database trainingfile.
// NOTE(review): the argument-count check, the try keyword and the return
// statements are not visible in this extract.
147 main(int argc
, char** argv
)
// Usage message; presumably printed when too few arguments are given -- the
// guarding condition is not visible.
151 printf("Usage: ComputeAverageRMS db_host db_user db_password db_database trainingfile\n");
152 // printf("Set variant=-1 if RMS is sum for all variants.\n");
// Connection constructed with mysqlpp::use_exceptions.
158 mysqlpp::Connection
conn(mysqlpp::use_exceptions
);
// Going by the usage text, argv[4] is the database name and argv[1..3] are
// host, user and password, in that order.
159 conn
.connect(argv
[4], argv
[1], argv
[2], argv
[3]);
// argv[5] is the training file path.
161 AverageComputer
ac(conn
, argv
[5]);
162 ac
.computeRMSForAllVariants();
// Any std::exception is reported by printing its message.
164 catch (std::exception
& e
)
166 printf("Error: %s\n", e
.what());