remove \r
[extl.git] / extl / intelligence / ann / detail / bp_network_impl.h
blob43a8f3f00611c25ea797f2c4edef2d5bb837f2d4
1 /* ///////////////////////////////////////////////////////////////////////
2 * File: bp_network_impl.h
4 * Created: 08.12.17
5 * Updated: 08.12.17
7 * Brief: The bp_network_impl class
9 * [<Home>]
10 * Copyright (c) 2008-2020, Waruqi All rights reserved.
11 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_DETAIL_BP_NETWORK_IMPL_H
14 #define EXTL_INTELLIGENCE_ANN_DETAIL_BP_NETWORK_IMPL_H
16 /*!\file bp_network_impl.h
17 * \brief bp_network_impl class
20 /* ///////////////////////////////////////////////////////////////////////
21 * Compatibility
23 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
24 # error bp_network_impl.h is not supported by the current compiler.
25 #endif
27 /* ///////////////////////////////////////////////////////////////////////
28 * Includes
30 #include "../prefix.h"
31 #include "basic_network_base.h"
32 #include "../basic_network_validator.h"
33 /* ///////////////////////////////////////////////////////////////////////
34 * ::extl::intelligence::detail namespace
36 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
37 EXTL_DETAIL_BEGIN_NAMESPACE
39 /*!brief bp_network_impl
41 * \param Dev The derived type
42 * \param InN the input demension
43 * \param OutN the output demension
44 * \param Nt The network traits type
46 * <pre>
47 * note: node id: 0 1 2 3 4 5
48 * [0]: hreshold
50 * input hide0 output
51 * [0]--------------\
52 * \ /[4]
53 * [1]---->[3]---->|
54 * / \[5]
55 * [2]---
57 * </pre>
59 * \ingroup extl_group_intelligence
61 template< typename_param_k Dev
62 , e_size_t InN
63 , e_size_t OutN
64 , typename_param_k Nt
66 class bp_network_impl
67 : public basic_network_base<Dev, InN, OutN, Nt>
69 /// \name Types
70 /// @{
71 protected:
72 typedef basic_network_base<Dev, InN, OutN, Nt> base_type;
73 /// @}
75 /// \name Types
76 /// @{
77 public:
78 typedef bp_network_impl class_type;
79 typedef Dev derived_type;
80 typedef typename_type_k base_type::size_type size_type;
81 typedef typename_type_k base_type::bool_type bool_type;
82 typedef typename_type_k base_type::index_type index_type;
83 typedef typename_type_k base_type::network_traits_type network_traits_type;
84 typedef typename_type_k base_type::rand_type rand_type;
85 typedef typename_type_k base_type::node_type node_type;
86 typedef typename_type_k base_type::weight_type weight_type;
87 typedef typename_type_k base_type::float_type float_type;
88 typedef typename_type_k base_type::net_type net_type;
89 typedef typename_type_k base_type::net_in_adjnode_iterator net_in_adjnode_iterator;
90 typedef typename_type_k base_type::net_out_adjnode_iterator net_out_adjnode_iterator;
91 typedef typename_type_k base_type::layers_type layers_type;
92 typedef typename_type_k network_traits_type::afunc_type afunc_type;
93 typedef typename_type_k network_traits_type::sample_type sample_type;
94 typedef typename_type_k buffer_selector<sample_type>::large_buffer_type samples_type;
95 typedef typename_type_k buffer_selector<float_type>::buffer_type floats_type;
96 /// @}
98 /// \name Members
99 /// @{
100 private:
101 /// the learning rate, range: [0, 1]
102 float_type m_lrate;
103 /// the activation function
104 afunc_type m_afunc;
105 /// @}
107 private:
108 bp_network_impl(class_type const& rhs);
109 class_type& operator =(class_type& rhs);
111 /// \name Constructors
112 /// @{
113 public:
114 explicit_k bp_network_impl ( layers_type const& layers
115 , float_type lrate
116 , float_type hr
117 , afunc_type const& afunc
118 , rand_type const& rand
120 : base_type(layers, hr, rand)
121 , m_lrate(lrate)
122 , m_afunc(afunc)
125 explicit_k bp_network_impl(derived_type const& rhs)
126 : base_type(rhs)
127 , m_lrate(static_cast<class_type const&>(rhs).m_lrate)
128 , m_afunc(static_cast<class_type const&>(rhs).m_afunc)
132 /// @}
134 /// \name Attributes
135 /// @{
136 public:
137 /// gets the activation function
138 afunc_type& afunc() { return m_afunc; }
139 /// gets the activation function
140 afunc_type const& afunc() const { return m_afunc; }
142 /// gets the learning rate, range: [0, 1]
143 float_type lrate() const { return m_lrate; }
144 /// sets the learning rate, range: [0, 1]
145 void lrate(float_type lr) { m_lrate = lr; }
146 /// @}
148 /// \name Mutators
149 /// @{
150 public:
151 void swap(derived_type& rhs);
152 derived_type& operator =(derived_type const& rhs);
153 /// @}
155 /// \name Methods
156 /// @{
157 public:
158 /// train the single sample
159 void train(sample_type& sp);
160 /// train the multi-samples
161 void train(samples_type& sps, size_type train_n = 1);
162 /// train the multi-samples with weights
163 void train(samples_type& sps, floats_type const& sps_ws, size_type train_n = 1);
164 /// run the given sample
165 void run(sample_type& sp);
166 /// run the given samples
167 //void run(samples_type& sps);
168 /// @}
170 /// \name Others
171 /// @{
172 protected:
173 derived_type& derive() { return static_cast<derived_type&>(*this); }
174 derived_type const& derive() const { return static_cast<derived_type const&>(*this); }
175 /// @}
177 /// \name Helpers
178 /// @{
179 protected:
180 /// forward calculation: input sample & calcucate output
181 void forward(sample_type& sp);
182 /// backward calculation: modify weight
183 void backward();
184 /// @}
/* ///////////////////////////////////////////////////////////////////////
 * Macro
 *
 * Local helpers that shorten the out-of-class member definitions below;
 * they are #undef'd again at the end of this header.
 */

// Template declaration: the template<...> header shared by every member definition
#ifdef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
#   undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
#endif

#define EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL \
template<   typename_param_k Dev \
        ,   e_size_t InN \
        ,   e_size_t OutN \
        ,   typename_param_k Nt \
        >

// Class qualification: the fully specialised class name used before ::member
#ifdef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
#   undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
#endif

#define EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL bp_network_impl<Dev, InN, OutN, Nt>

// Return type qualification: expands to "typename <class>::ret_type <class>",
// so a definition can spell a member-typedef return type followed by ::member
#ifdef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL
#   undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL
#endif

#define EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL(ret_type) \
    typename_type_ret_k EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::ret_type \
    EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
218 /* ///////////////////////////////////////////////////////////////////////
219 * Implementation
222 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
223 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::swap(derived_type& rhs)
225 std_swap(m_afunc, static_cast<class_type&>(rhs).m_afunc);
226 std_swap(m_lrate, static_cast<class_type&>(rhs).m_lrate);
228 static_cast<base_type&>(*this).swap(rhs);
231 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
232 inline EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL(derived_type&)::operator =(derived_type const& rhs)
234 derived_type(rhs).swap(derive());
235 return derive();
238 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
239 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::forward(sample_type& sp)
241 EXTL_ASSERT(derive().is_valid());
243 // input sample
244 index_type i;
245 size_type input_n = sp.input_size();
246 EXTL_ASSERT(this->layers().inodes_size() == input_n + 1); // + a hreshold
247 for (i = 0; i < input_n; ++i)
249 // note: this->net().at(0) is hreshold
250 this->net().at(i + 1).output(sp.get_finput(i));
253 // calculate hide-layer & output-layer output
254 size_type ihs_n = this->layers().nodes_size() - this->layers().onodes_size();
255 index_type start_n = this->layers().inodes_size();
256 size_type nodes_n = this->layers().nodes_size();
257 EXTL_ASSERT(start_n <= nodes_n);
258 for (i = start_n; i < nodes_n; ++i)
260 // weighted_sum = sum(weight_k * val_k)
261 float_type weighted_sum = 0;
262 net_in_adjnode_iterator pe = this->net().in_adjnode_end(i);
263 for (net_in_adjnode_iterator pi = this->net().in_adjnode_begin(i); pi != pe; ++pi)
265 weighted_sum += this->net().weight(*pi, i).value() * this->net().at(*pi).output();
267 // activate output
268 node_type& node = this->net().at(i);
269 derive().do_activate_output(node, weighted_sum);
271 // calculate the error of real_output & the output-layer node
272 if (node.in_olayer())
274 // calculate output error
275 node.error(sp.get_freal(i - ihs_n) - node.output());
279 EXTL_ASSERT(derive().is_valid());
282 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
283 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::backward()
285 EXTL_ASSERT(derive().is_valid());
287 // modify output-layer weight
288 index_type start_n = this->layers().nodes_size() - 1;
289 index_type end_n = this->layers().inodes_size();
290 for (index_type i = start_n + 1; i >= end_n + 1; --i) // prevent generate bug for index_type is unsigned type
292 index_type cur = i - 1;
293 node_type& node = this->net().at(cur);
295 // prepare modify weight
296 derive().do_prepare_modify_weight(node);
298 // modify weight
299 net_in_adjnode_iterator pe = this->net().in_adjnode_end(cur);
300 for (net_in_adjnode_iterator pi = this->net().in_adjnode_begin(cur); pi != pe; ++pi)
302 derive().do_modify_weight(this->net().weight(*pi, cur), node, this->net().at(*pi));
306 //EXTL_TRACEA("\n");
308 EXTL_ASSERT(derive().is_valid());
310 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
311 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::train(sample_type& sp)
313 EXTL_ASSERT(derive().is_valid());
315 // forward calculation: input sample & calcucate output
316 derive().forward(sp);
317 /// backward calculation: modify weight
318 derive().backward();
320 EXTL_ASSERT(derive().is_valid());
323 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
324 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::train(samples_type& sps, size_type train_n)
326 EXTL_ASSERT(derive().is_valid());
328 #if 0
329 size_type sps_n = sps.size();
330 for (size_type train_i = 0; train_i < train_n; ++train_i)
331 for (index_type i = 0; i < sps_n; ++i)
332 train(sps[i]);
333 #else // for optimization: adaptive learning rate
334 typedef basic_network_validator<derived_type> validator_type;
335 validator_type validator;
337 size_type sps_n = sps.size();
338 float_type old_mse = -1;
339 for (size_type train_i = 0; train_i < train_n; ++train_i)
341 for (index_type i = 0; i < sps_n; ++i)
342 train(sps[i]);
344 // modify learning rate
345 validator.validate(derive(), sps);
346 if (old_mse < 0) old_mse = validator.mse();
347 else if (validator.mse() < old_mse) m_lrate *= 1.06;
348 else if (validator.mse() > (1.06 * old_mse)) m_lrate *= 0.8;
349 old_mse = validator.mse();
351 #endif
352 EXTL_ASSERT(derive().is_valid());
355 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
356 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::train(samples_type& sps, floats_type const& sps_ws, size_type train_n)
358 EXTL_ASSERT(derive().is_valid());
359 #if 0
360 size_type sps_n = sps.size();
361 for (size_type train_i = 0; train_i < train_n; ++train_i)
363 for (index_type i = 0; i < sps_n; ++i)
365 // select sample
366 float_type p = derive().rand().fgenerate(0, 1);
367 float_type sum = 0;
368 index_type j;
369 for (j = 0; (sum <= p) && (j < sps_n); ++j)
370 sum += sps_ws[j];
371 j = j > 0? j - 1 : 0;
373 // train sample
374 train(sps[j]);
377 #else // for optimization: adaptive learning rate
379 typedef basic_network_validator<derived_type> validator_type;
380 validator_type validator;
382 size_type sps_n = sps.size();
383 float_type old_mse = -1;
384 for (size_type train_i = 0; train_i < train_n; ++train_i)
386 for (index_type i = 0; i < sps_n; ++i)
388 // select sample
389 float_type p = derive().rand().fgenerate(0, 1);
390 float_type sum = 0;
391 index_type j;
392 for (j = 0; (sum <= p) && (j < sps_n); ++j)
393 sum += sps_ws[j];
394 j = j > 0? j - 1 : 0;
396 // train sample
397 train(sps[j]);
400 // modify learning rate
401 validator.validate(derive(), sps);
402 if (old_mse < 0) old_mse = validator.mse();
403 else if (validator.mse() < old_mse) m_lrate *= 1.06;
404 else if (validator.mse() > (1.06 * old_mse)) m_lrate *= 0.8;
405 old_mse = validator.mse();
407 #endif
408 EXTL_ASSERT(derive().is_valid());
411 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
412 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::run(sample_type& sp)
414 EXTL_ASSERT(derive().is_valid());
416 // forward calculation: input sample & calcucate output
417 derive().forward(sp);
419 // get output
420 index_type start_n = this->layers().nodes_size() - this->layers().onodes_size();
421 index_type end_n = this->layers().nodes_size();
422 EXTL_ASSERT(start_n <= end_n);
423 EXTL_ASSERT(sp.output_size() == this->layers().onodes_size());
424 for (index_type i = start_n; i < end_n; ++i)
426 // the sample type is the bit array type
427 sp.set_foutput(i - start_n, this->net().at(i).output());
429 EXTL_ASSERT(derive().is_valid());
431 /*EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
432 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::run(samples_type& sps)
434 EXTL_ASSERT(derive().is_valid());
436 size_type sps_n = sps.size();
437 for (index_type i = 0; i < sps_n; ++i)
438 run(sps[i]);
440 EXTL_ASSERT(derive().is_valid());
/* //////////////////////////////////////////////////////////////////// */
// The definition helpers are local to this header: remove them again.
#undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL
#undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
#undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
449 /* ///////////////////////////////////////////////////////////////////////
450 * ::extl::intelligence::detail namespace
452 EXTL_DETAIL_END_NAMESPACE
453 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
455 /* //////////////////////////////////////////////////////////////////// */
456 #endif /* EXTL_INTELLIGENCE_ANN_DETAIL_BP_NETWORK_IMPL_H */
457 /* //////////////////////////////////////////////////////////////////// */