remove \r
[extl.git] / extl / intelligence / ann / kfold_cross_validator.h
blob9e516511c86ef92232ba0e62e5182bfabaf7d97e
/* ///////////////////////////////////////////////////////////////////////
 * File:        kfold_cross_validator.h
 * Created:     09.04.05
 * Updated:     09.04.05
 * Brief:       The kfold_cross_validator class
 * [<Home>]
 * Copyright (c) 2008-2020, Waruqi. All rights reserved.
 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_KFOLD_CROSS_VALIDATOR_H
14 #define EXTL_INTELLIGENCE_ANN_KFOLD_CROSS_VALIDATOR_H
16 /*!\file kfold_cross_validator.h
17 * \brief kfold_cross_validator class
20 /* ///////////////////////////////////////////////////////////////////////
21 * Compatibility
23 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
24 # error kfold_cross_validator.h is not supported by the current compiler.
25 #endif
27 /* ///////////////////////////////////////////////////////////////////////
28 * Includes
30 #include "prefix.h"
32 /* ///////////////////////////////////////////////////////////////////////
33 * ::extl::intelligence namespace
35 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
38 /*!brief kfold_cross_validator
40 * \param NetWork the networks type
41 * \param NetVdtr the network validator type
42 * \param K k-fold validation
44 * \ingroup extl_group_intelligence
46 template< typename_param_k NetWork
47 , typename_param_k NetVdtr
48 , e_size_t K = 10
50 class kfold_cross_validator
52 /// \name Types
53 /// @{
54 public:
55 typedef kfold_cross_validator class_type;
56 typedef NetWork network_type;
57 typedef typename_type_k network_type::size_type size_type;
58 typedef typename_type_k network_type::bool_type bool_type;
59 typedef typename_type_k network_type::index_type index_type;
60 typedef typename_type_k network_type::float_type float_type;
61 typedef typename_type_k network_type::sample_type sample_type;
62 typedef typename_type_k network_type::samples_type samples_type;
63 typedef NetVdtr validator_type;
64 /// @}
66 /// \name Members
67 /// @{
68 private:
69 /// the mean square error, range: [0, 1]
70 float_type m_mse;
71 /// the error rate, range: [0, 1]
72 float_type m_erate;
73 /// network validator
74 validator_type m_validator;
75 /// @}
77 /// \name Constants
78 /// @{
79 public:
80 enum { en_kfold = K };
81 /// @}
83 /// \name Constructors
84 /// @{
85 public:
86 kfold_cross_validator()
87 : m_mse(0)
88 , m_erate(0)
91 /// @}
93 /// \name Attributes
94 /// @{
95 public:
96 /// returns the mean square error, range: [0, 1]
97 float_type mse() const { return m_mse; }
99 /// returns the error rate, range: [0, 1]
100 float_type erate() const { return m_erate; }
101 /// @}
103 /// \name Methods
104 /// @{
105 public:
106 /// \note sps must have real outputs
107 void validate(network_type const& network, samples_type& sps, size_type train_n = 100)
109 // initialize mse & error rate
110 m_mse = 0;
111 m_erate = 0;
113 size_type sps_n = sps.size();
114 EXTL_ASSERT(size_type(en_kfold) <= sps_n);
116 // validate samples
117 size_type part_n = sps_n / en_kfold;
118 for (index_type k = 0; k < en_kfold; ++k)
120 // calculates validated range
121 index_type b = part_n * k;
122 index_type e = part_n * (k + 1);
123 if (k == en_kfold - 1) e = sps_n;
125 samples_type vsps; // validated samples
126 samples_type tsps; // trained samples
127 for (index_type i = 0; i < sps_n; ++i)
129 // split samples
130 if (i >= b && i < e) vsps.push_back(sps[i]);
131 else tsps.push_back(sps[i]);
134 // train network
135 network_type net(network);
136 net.init();
137 net.train(tsps, train_n);
138 m_validator.validate(net, vsps);
140 // validate network
141 m_mse += m_validator.mse();
142 m_erate += m_validator.erate();
145 m_mse /= en_kfold;
146 m_erate /= en_kfold;
148 /// @}
152 /* ///////////////////////////////////////////////////////////////////////
153 * ::extl::intelligence namespace
155 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
157 /* //////////////////////////////////////////////////////////////////// */
158 #endif /* EXTL_INTELLIGENCE_ANN_KFOLD_CROSS_VALIDATOR_H */
159 /* //////////////////////////////////////////////////////////////////// */