1 #include "TrainingDatabase.hpp"
4 #include "Utilities.hxx"
5 #include <mysql++/mysql++.h>
// CompAvgException: error type thrown by this tool for training-file I/O
// failures and out-of-range variant ids. It derives from std::exception, so
// main() reports it through its generic std::exception handler.
// (The `class CompAvgException` header and the member bodies are not visible
// in this excerpt.)
8 : public std::exception
// Constructs the exception from a human-readable message.
// NOTE(review): where the message is stored is not visible here -- presumably
// a std::string member that what() returns.
11 CompAvgException(const std::string
& aWhat
)
// NOTE(review): `throw()` is a dynamic exception specification, deprecated
// since C++11 and removed in C++20; `noexcept` is the modern equivalent.
16 ~CompAvgException() throw()
// Overrides std::exception::what(); body not visible in this excerpt.
20 const char* what() const throw()
// AverageComputer constructor.
// Opens the training file, reads its fixed-size TrainingDatabaseHeader into
// mTdh, and allocates three buffers of mTdh.noutput_signals doubles each.
//   aConn     - MySQL connection used later to store results; kept by
//               reference, so it must outlive this object.
//   aFilename - path of the training database file.
// Throws CompAvgException if the file cannot be opened or the header cannot
// be read in full.
32 AverageComputer(mysqlpp::Connection
& aConn
, const char* aFilename
)
33 : mConn(aConn
), mAnswers(NULL
), mSum(NULL
), mTemp(NULL
)
// O_LARGEFILE allows files larger than 2 GiB on 32-bit builds.
35 mFile
= open(aFilename
, O_RDONLY
| O_LARGEFILE
);
// (The failure check on mFile that guards this throw is not visible here.)
37 throw CompAvgException("Can't open training file.");
// The header must be read in full; a short read or an error aborts.
// NOTE(review): as far as this excerpt shows, mFile is not closed before this
// throw, and a throwing constructor never runs the destructor, so the file
// descriptor would leak.
39 if (read(mFile
, &mTdh
, sizeof(TrainingDatabaseHeader
)) != sizeof(TrainingDatabaseHeader
))
40 throw CompAvgException("Can't read the training file.");
// Bytes occupied by one variant's input signals; the reader skips this many
// bytes to reach each variant's output signals.
42 mJump
= mTdh
.ninput_signals
* sizeof(double);
// Raw new[] allocations; their release is not visible in this excerpt.
// NOTE(review): confirm the destructor delete[]s all three (std::vector<double>
// would make this automatic).
44 mAnswers
= new double[mTdh
.noutput_signals
];
45 mSum
= new double[mTdh
.noutput_signals
];
46 mTemp
= new double[mTdh
.noutput_signals
];
// Computes, for each variant id, the RSS obtained when that variant is left
// out of the average, and writes it to crossval_rss.average_rss for the row
// whose leftout_variant_id matches.
// (The return-type line, the loop header -- presumably iterating v over
// [first_variant_id, lv) -- and the query's execute() call are not visible in
// this excerpt.)
62 computeRSSForAllVariants()
// First variant id stored in the training file.
64 uint32_t v
= mTdh
.first_variant_id
;
// One past the last variant id.
65 uint32_t lv
= v
+ mTdh
.nvariants
;
// Leave-one-out RSS for variant v.
69 double RSS
= computeRSSFromAverageExcluding(v
);
70 mysqlpp::Query
q(mConn
.query());
// Both interpolated values are numbers, so building the statement by
// concatenation is not an injection risk here.
// NOTE(review): unless precision is set elsewhere, RSS is formatted with the
// default stream precision (6 significant digits), which may truncate it.
71 q
<< "UPDATE crossval_rss SET average_rss='" << RSS
72 << "' WHERE leftout_variant_id='" << v
<< "'";
// Leave-one-out residual sum of squares for one variant.
// Reads every variant's output-signal vector from the training file:
//   - the excluded variant's vector goes into mAnswers;
//   - every other vector is read into mTemp (presumably added into mSum).
// mAnswers is then compared against the mean of the remaining nvariants - 1
// vectors.
//   aExclude - absolute variant id (offset by mTdh.first_variant_id).
// Returns the RSS; the accumulation and return lines are not visible here.
// Throws CompAvgException if aExclude is out of range or a read comes up short.
78 computeRSSFromAverageExcluding(uint32_t aExclude
)
// Range check in two steps: testing against first_variant_id before the
// subtraction keeps the unsigned subtraction from wrapping around.
80 if (aExclude
< mTdh
.first_variant_id
)
81 throw CompAvgException("Excluded variant out of range.");
// Convert the absolute id to a 0-based record index.
83 aExclude
-= mTdh
.first_variant_id
;
84 if (aExclude
>= mTdh
.nvariants
)
85 throw CompAvgException("Excluded variant out of range.");
// Position at the first record, immediately after the header.
// NOTE(review): lseek64's return value is not checked.
87 lseek64(mFile
, sizeof(mTdh
), SEEK_SET
);
// Clear the running sum of output signals.
88 memset(mSum
, 0, sizeof(double) * mTdh
.noutput_signals
);
// Walk every record. The declaration/initialisation of i is not visible here
// (presumably i = 0, matching the 0-based aExclude).
91 for (; i
< mTdh
.nvariants
; i
++)
// Skip this record's input signals.
95 lseek64(mFile
, mJump
, SEEK_CUR
);
// Route the excluded variant's outputs into mAnswers, all others into mTemp.
97 double * into
= (aExclude
== i
) ? mAnswers
: mTemp
;
// read() returns ssize_t; in this comparison a -1 error converts to a huge
// size_t, so both errors and short reads trigger the throw.
99 if (read(mFile
, into
, sizeof(double) * mTdh
.noutput_signals
) !=
100 sizeof(double) * mTdh
.noutput_signals
)
101 throw CompAvgException("Couldn't load signals.");
// Per-output-signal loop; its body is not visible in this excerpt
// (presumably adds mTemp[j] into mSum[j] for non-excluded records).
106 for (j
= 0; j
< mTdh
.noutput_signals
; j
++)
// Scale factor turning mSum into the mean over the other nvariants - 1 variants.
// NOTE(review): when nvariants == 1 the divisor is 0, so factor becomes inf
// and t below is likely NaN (inf * 0.0); this presumably needs a guard.
111 double factor
= 1.0 / (mTdh
.nvariants
- 1);
// i is reused here as the output-signal index.
113 for (i
= 0; i
< mTdh
.noutput_signals
; i
++)
// Residual: held-out answer minus the leave-one-out mean.
115 double t
= mAnswers
[i
] - factor
* mSum
[i
];
// Header read from the start of the training file by the constructor.
124 TrainingDatabaseHeader mTdh
;
// Caller-owned MySQL connection, held by reference (must outlive this object).
125 mysqlpp::Connection
& mConn
;
// Heap buffers of mTdh.noutput_signals doubles each:
//   mAnswers - outputs of the held-out (excluded) variant;
//   mSum     - zeroed per call, presumably the running sum of the other
//              variants' outputs;
//   mTemp    - scratch target for each non-excluded record read.
126 double* mAnswers
, * mSum
, * mTemp
;
// Program entry point.
// Usage: ComputeAverageRSS db_host db_user db_password db_database trainingfile
// Connects to MySQL and stores the leave-one-out average RSS of every variant
// found in the training file. Failures surface as exceptions and are printed by
// the std::exception handler at the end.
131 main(int argc
, char** argv
)
// Printed when the argument-count check (not visible here) fails.
135 printf("Usage: ComputeAverageRSS db_host db_user db_password db_database trainingfile\n");
136 // printf("Set variant=-1 if RSS is sum for all variants.\n");
// use_exceptions makes MySQL++ report failures by throwing.
142 mysqlpp::Connection
conn(mysqlpp::use_exceptions
);
// Connection::connect(db, server, user, password): argv[4] = database,
// argv[1] = host, argv[2] = user, argv[3] = password, matching the usage text.
143 conn
.connect(argv
[4], argv
[1], argv
[2], argv
[3]);
// argv[5] is the training file path.
145 AverageComputer
ac(conn
, argv
[5]);
146 ac
.computeRSSForAllVariants();
// Catches CompAvgException and MySQL++ exceptions alike (both derive from
// std::exception).
// NOTE(review): the message goes to stdout via printf; stderr is the usual
// channel for errors, and the exit status after this handler is not visible.
148 catch (std::exception
& e
)
150 printf("Error: %s\n", e
.what());