remove \r
[extl.git] / extl / intelligence / ann / bagging_networks.h
blobc046d09342125f4bc36afc6622f58ced0621c77e
1 /* ///////////////////////////////////////////////////////////////////////
2 * File: bagging_networks.h
4 * Created: 08.12.27
5 * Updated: 08.12.27
7 * Brief: The bagging_networks class
9 * [<Home>]
10 * Copyright (c) 2008-2020, Waruqi All rights reserved.
11 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_BAGGING_NETWORKS_H
14 #define EXTL_INTELLIGENCE_ANN_BAGGING_NETWORKS_H
16 /*!\file bagging_networks.h
17 * \brief bagging_networks class
20 /* ///////////////////////////////////////////////////////////////////////
21 * Compatibility
23 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
24 # error bagging_networks.h is not supported by the current compiler.
25 #endif
27 /* ///////////////////////////////////////////////////////////////////////
28 * Includes
30 #include "prefix.h"
31 /* ///////////////////////////////////////////////////////////////////////
32 * ::extl::intelligence namespace
34 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
36 /*!brief bagging_networks
38 * \param NetWork the network type
40 * \ingroup extl_group_intelligence
42 template<typename_param_k NetWork>
43 class bagging_networks
46 /// \name Public Types
47 /// @{
48 public:
49 typedef bagging_networks class_type;
50 typedef NetWork network_type;
51 typedef typename_type_k network_type::size_type size_type;
52 typedef typename_type_k network_type::bool_type bool_type;
53 typedef typename_type_k network_type::index_type index_type;
54 typedef typename_type_k network_type::float_type float_type;
55 typedef typename_type_k network_type::sample_type sample_type;
56 typedef typename_type_k network_type::samples_type samples_type;
57 typedef typename_type_k network_type::rand_type rand_type;
58 typedef typename_type_k buffer_selector<network_type*>::buffer_type networks_type;
59 typedef typename_type_k buffer_selector<float_type>::buffer_type float_buffer_type;
60 /// @}
62 /// \name Members
63 /// @{
64 private:
65 networks_type m_networks;
66 float_buffer_type m_foutputs;
67 /// @}
69 /// \name Constructors
70 /// @{
71 public:
72 explicit_k bagging_networks ( network_type* network_0 = NULL
73 , network_type* network_1 = NULL
74 , network_type* network_2 = NULL
75 , network_type* network_3 = NULL
76 , network_type* network_4 = NULL
77 , network_type* network_5 = NULL
78 , network_type* network_6 = NULL
79 , network_type* network_7 = NULL
80 , network_type* network_8 = NULL
81 , network_type* network_9 = NULL
83 : m_networks()
84 , m_foutputs()
86 if (NULL != network_0) m_networks.push_back(network_0);
87 if (NULL != network_1) m_networks.push_back(network_1);
88 if (NULL != network_2) m_networks.push_back(network_2);
89 if (NULL != network_3) m_networks.push_back(network_3);
90 if (NULL != network_4) m_networks.push_back(network_4);
91 if (NULL != network_5) m_networks.push_back(network_5);
92 if (NULL != network_6) m_networks.push_back(network_6);
93 if (NULL != network_7) m_networks.push_back(network_7);
94 if (NULL != network_8) m_networks.push_back(network_8);
95 if (NULL != network_9) m_networks.push_back(network_9);
97 bagging_networks(class_type const& rhs)
98 : m_networks(rhs.m_networks)
99 , m_foutputs(rhs.m_foutputs)
102 /// @}
104 /// \name Accessors
105 /// @{
106 public:
107 networks_type& networks() { return m_networks; }
108 networks_type const& networks() const { return m_networks; }
110 float_type foutput(index_type i) const;
111 bool_type boutput(index_type i) const;
113 /// returns the random generator
114 rand_type& rand() { return rand_type(); }
115 /// returns the const layers
116 rand_type const& rand() const { return rand_type(); }
117 /// @}
119 /// \name Operators
120 /// @{
121 public:
122 void swap(class_type& rhs);
123 /// @}
125 /// \name Methods
126 /// @{
127 public:
128 /// train samples and return mse
129 void train(samples_type& sps, size_type train_n = 1);
130 /// run the given sample
131 void run(sample_type& sp);
132 /// run the given samples
133 //void run(samples_type& sps);
134 /// @}
136 /// \name helpers
137 /// @{
138 private:
139 float_buffer_type& foutputs() { return m_foutputs; }
140 float_buffer_type const& foutputs() const { return m_foutputs; }
141 /// @}
143 /* ///////////////////////////////////////////////////////////////////////
144 * Implementation
146 template<typename_param_k NetWork>
147 inline void bagging_networks<NetWork>::swap(class_type& rhs)
149 m_networks.swap(rhs.m_networks);
150 m_foutputs.swap(rhs.m_foutputs);
152 template<typename_param_k NetWork>
153 inline void bagging_networks<NetWork>::train(samples_type& sps, size_type train_n)
155 index_type i, j;
156 size_type sps_n = sps.size();
157 size_type networks_n = networks().size();
158 for (i = 0; i < networks_n; ++i)
160 EXTL_ASSERT(NULL != networks()[i]);
161 // select sps_n samples randomly
162 samples_type tsps;
163 for (j = 0; j < sps_n; ++j)
165 index_type sps_i = networks()[i]->rand().generate(0, sps_n);
166 tsps.push_back(sps[sps_i]);
169 // train
170 networks()[i]->train(tsps, train_n);
174 template<typename_param_k NetWork>
175 inline void bagging_networks<NetWork>::run(sample_type& sp)
177 index_type i, j;
179 // initialize foutputs
180 size_type output_n = sp.output_size();
181 foutputs().resize(output_n);
182 for (i = 0; i < output_n; ++i)
183 foutputs()[i] = 0;
185 size_type networks_n = networks().size();
186 for (i = 0; i < networks_n; ++i)
188 EXTL_ASSERT(NULL != networks()[i]);
189 networks()[i]->run(sp);
191 // accumulate outputs
192 for (j = 0; j < output_n; ++j)
193 foutputs()[j] += static_cast<float_type>(sp.get_boutput(j));
196 // calculate the average of the outputs
197 for (j = 0; j < output_n; ++j)
199 foutputs()[j] /= networks_n;
200 sp.set_foutput(j, foutputs()[j]);
203 /*template<typename_param_k NetWork>
204 inline void bagging_networks<NetWork>::run(samples_type& sps)
206 size_type sps_n = sps.size();
207 for (index_type i = 0; i < sps_n; ++i)
208 run(sps[i]);
211 template<typename_param_k NetWork>
212 inline typename_type_ret_k bagging_networks<NetWork>::
213 float_type bagging_networks<NetWork>::foutput(index_type i) const
215 EXTL_ASSERT(i < m_foutputs.size());
216 return m_foutputs[i];
218 template<typename_param_k NetWork>
219 inline typename_type_ret_k bagging_networks<NetWork>::
220 bool_type bagging_networks<NetWork>::boutput(index_type i) const
222 EXTL_ASSERT(i < m_foutputs.size());
223 return xtl_round45(m_foutputs[i] == 1);
225 /* ///////////////////////////////////////////////////////////////////////
226 * ::extl::intelligence namespace
228 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
230 /* //////////////////////////////////////////////////////////////////// */
231 #endif /* EXTL_INTELLIGENCE_ANN_BAGGING_NETWORKS_H */
232 /* //////////////////////////////////////////////////////////////////// */