1 /* ///////////////////////////////////////////////////////////////////////
2 * File: kfold_cross_validator.h
7 * Brief: The kfold_cross_validator class
10 * Copyright (c) 2008-2020, Waruqi All rights reserved.
11 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_KFOLD_CROSS_VALIDATOR_H
14 #define EXTL_INTELLIGENCE_ANN_KFOLD_CROSS_VALIDATOR_H
16 /*!\file kfold_cross_validator.h
17 * \brief kfold_cross_validator class
20 /* ///////////////////////////////////////////////////////////////////////
23 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
24 # error kfold_cross_validator.h is not supported by the current compiler.
27 /* ///////////////////////////////////////////////////////////////////////
32 /* ///////////////////////////////////////////////////////////////////////
33 * ::extl::intelligence namespace
35 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
38 /*!\brief kfold_cross_validator
40 * \param NetWork the network type
41 * \param NetVdtr the network validator type
42 * \param K the number of folds (k) used for the cross-validation
44 * \ingroup extl_group_intelligence
// k-fold cross-validator: for each fold, trains a copy of a network on the
// samples outside the fold and scores it on the samples inside the fold.
// NOTE(review): this listing looks like a lossy extraction - the leading
// numbers are original line numbers, and some original lines (e.g. the
// declaration of K, the m_mse / m_erate members, braces) are not visible here.
46 template< typename_param_k NetWork
47 , typename_param_k NetVdtr
50 class kfold_cross_validator
// the type of this class
55 typedef kfold_cross_validator class_type
;
// the network type being cross-validated
56 typedef NetWork network_type
;
// the following types are taken from the network type, so the validator
// works with the same size/bool/index/float/sample types as the network
57 typedef typename_type_k
network_type::size_type size_type
;
58 typedef typename_type_k
network_type::bool_type bool_type
;
59 typedef typename_type_k
network_type::index_type index_type
;
60 typedef typename_type_k
network_type::float_type float_type
;
61 typedef typename_type_k
network_type::sample_type sample_type
;
62 typedef typename_type_k
network_type::samples_type samples_type
;
// the per-fold validator type
63 typedef NetVdtr validator_type
;
69 /// the mean square error, range: [0, 1]
71 /// the error rate, range: [0, 1]
// the validator used to score the trained copy of the network on each fold
74 validator_type m_validator
;
// the number of folds as a compile-time constant
// NOTE(review): K is presumably the third template parameter (its
// declaration is not visible in this listing) - confirm
80 enum { en_kfold
= K
};
83 /// \name Constructors
86 kfold_cross_validator()
96 /// returns the mean square error, range: [0, 1]
// NOTE(review): validate() adds each fold's mse to m_mse; whether the sum is
// divided by the fold count is not visible in this listing - confirm
97 float_type
mse() const { return m_mse
; }
99 /// returns the error rate, range: [0, 1]
// NOTE(review): validate() adds each fold's error rate to m_erate; whether the
// sum is divided by the fold count is not visible in this listing - confirm
100 float_type
erate() const { return m_erate
; }
/// runs the k-fold cross-validation
/// \param network the network to validate; a fresh copy of it is trained for each fold
/// \param sps the samples, split by index into en_kfold consecutive parts
/// \param train_n forwarded as the second argument to network_type::train(), default: 100
106 /// \note sps must have real outputs
107 void validate(network_type
const& network
, samples_type
& sps
, size_type train_n
= 100)
109 // initialize mse & error rate
// there must be at least one sample per fold, otherwise part_n would be zero
113 size_type sps_n
= sps
.size();
114 EXTL_ASSERT(size_type(en_kfold
) <= sps_n
);
// the size of each fold (integer division; the last fold takes the remainder)
117 size_type part_n
= sps_n
/ en_kfold
;
// one round per fold: fold k is validated, all other samples are trained
118 for (index_type k
= 0; k
< en_kfold
; ++k
)
120 // calculates validated range
// the half-open index range [b, e) of fold k
121 index_type b
= part_n
* k
;
122 index_type e
= part_n
* (k
+ 1);
// the last fold is extended to the end of the samples, so the samples
// left over by the integer division are validated too
123 if (k
== en_kfold
- 1) e
= sps_n
;
125 samples_type vsps
; // validated samples
126 samples_type tsps
; // trained samples
// split the samples: indices in [b, e) go to vsps, all others go to tsps
127 for (index_type i
= 0; i
< sps_n
; ++i
)
130 if (i
>= b
&& i
< e
) vsps
.push_back(sps
[i
]);
131 else tsps
.push_back(sps
[i
]);
// train a copy, so every fold starts from the state of the given network
135 network_type
net(network
);
137 net
.train(tsps
, train_n
);
// score the trained copy on the samples held out for this fold
138 m_validator
.validate(net
, vsps
);
// accumulate the results of this fold
141 m_mse
+= m_validator
.mse();
142 m_erate
+= m_validator
.erate();
152 /* ///////////////////////////////////////////////////////////////////////
153 * ::extl::intelligence namespace
155 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
157 /* //////////////////////////////////////////////////////////////////// */
158 #endif /* EXTL_INTELLIGENCE_ANN_KFOLD_CROSS_VALIDATOR_H */
159 /* //////////////////////////////////////////////////////////////////// */