1 /* ///////////////////////////////////////////////////////////////////////
2 * File: bp_network_impl.h
7 * Brief: The bp_network_impl class
10 * Copyright (c) 2008-2020, Waruqi All rights reserved.
11 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_DETAIL_BP_NETWORK_IMPL_H
14 #define EXTL_INTELLIGENCE_ANN_DETAIL_BP_NETWORK_IMPL_H
16 /*!\file bp_network_impl.h
17 * \brief bp_network_impl class
20 /* ///////////////////////////////////////////////////////////////////////
23 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
24 # error bp_network_impl.h is not supported by the current compiler.
27 /* ///////////////////////////////////////////////////////////////////////
30 #include "../prefix.h"
31 #include "basic_network_base.h"
32 #include "../basic_network_validator.h"
33 /* ///////////////////////////////////////////////////////////////////////
34 * ::extl::intelligence::detail namespace
36 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
37 EXTL_DETAIL_BEGIN_NAMESPACE
39 /*!\brief bp_network_impl
41 * \param Dev The derived type
42 * \param InN The input dimension
43 * \param OutN The output dimension
44 * \param Nt The network traits type
47 * note: node id: 0 1 2 3 4 5
59 * \ingroup extl_group_intelligence
// Back-propagation network implementation layered on basic_network_base.
// Dev is the most-derived type; calls are routed to it through derive().
// NOTE(review): part of the class header, the access specifiers, the data member
// declarations and the closing brace are not visible in this excerpt.
61 template< typename_param_k Dev
67 : public basic_network_base
<Dev
, InN
, OutN
, Nt
>
// Type aliases. Most are re-exported from base_type; afunc_type and sample_type
// come from network_traits_type, and the two buffer types from buffer_selector.
72 typedef basic_network_base
<Dev
, InN
, OutN
, Nt
> base_type
;
78 typedef bp_network_impl class_type
;
79 typedef Dev derived_type
;
80 typedef typename_type_k
base_type::size_type size_type
;
81 typedef typename_type_k
base_type::bool_type bool_type
;
82 typedef typename_type_k
base_type::index_type index_type
;
83 typedef typename_type_k
base_type::network_traits_type network_traits_type
;
84 typedef typename_type_k
base_type::rand_type rand_type
;
85 typedef typename_type_k
base_type::node_type node_type
;
86 typedef typename_type_k
base_type::weight_type weight_type
;
87 typedef typename_type_k
base_type::float_type float_type
;
88 typedef typename_type_k
base_type::net_type net_type
;
89 typedef typename_type_k
base_type::net_in_adjnode_iterator net_in_adjnode_iterator
;
90 typedef typename_type_k
base_type::net_out_adjnode_iterator net_out_adjnode_iterator
;
91 typedef typename_type_k
base_type::layers_type layers_type
;
92 typedef typename_type_k
network_traits_type::afunc_type afunc_type
;
93 typedef typename_type_k
network_traits_type::sample_type sample_type
;
94 typedef typename_type_k buffer_selector
<sample_type
>::large_buffer_type samples_type
;
95 typedef typename_type_k buffer_selector
<float_type
>::buffer_type floats_type
;
101 /// the learning rate, range: [0, 1]
103 /// the activation function
// Copy construction and assignment taking class_type are only declared here.
// NOTE(review): no definition is visible in this file - presumably these
// overloads are meant to be unusable; confirm.
108 bp_network_impl(class_type
const& rhs
);
109 class_type
& operator =(class_type
& rhs
);
111 /// \name Constructors
// Constructs the network from a layer description, an activation function and a
// random generator; the base class is initialized in the member-init list.
// NOTE(review): some parameters and initializers are not visible in this excerpt.
114 explicit_k
bp_network_impl ( layers_type
const& layers
117 , afunc_type
const& afunc
118 , rand_type
const& rand
120 : base_type(layers
, hr
, rand
)
// Constructs from another derived object, copying its learning rate and
// activation function through a class_type view of rhs.
125 explicit_k
bp_network_impl(derived_type
const& rhs
)
127 , m_lrate(static_cast<class_type
const&>(rhs
).m_lrate
)
128 , m_afunc(static_cast<class_type
const&>(rhs
).m_afunc
)
137 /// gets the activation function
138 afunc_type
& afunc() { return m_afunc
; }
139 /// gets the activation function (read-only overload)
140 afunc_type
const& afunc() const { return m_afunc
; }
142 /// gets the learning rate, range: [0, 1]
143 float_type
lrate() const { return m_lrate
; }
144 /// sets the learning rate, range: [0, 1]
// Note: the value is stored as given; no range check is performed here.
145 void lrate(float_type lr
) { m_lrate
= lr
; }
/// swaps the state of this network with rhs
151 void swap(derived_type
& rhs
);
/// assigns from rhs
152 derived_type
& operator =(derived_type
const& rhs
);
158 /// train the single sample
159 void train(sample_type
& sp
);
160 /// train the multi-samples
// train_n is the number of passes over sps (defaults to 1).
161 void train(samples_type
& sps
, size_type train_n
= 1);
162 /// train the multi-samples with weights
// sps_ws holds the per-sample weights; train_n is the number of passes (defaults to 1).
163 void train(samples_type
& sps
, floats_type
const& sps_ws
, size_type train_n
= 1);
164 /// run the given sample
165 void run(sample_type
& sp
);
166 /// run the given samples
167 //void run(samples_type& sps);
/// casts this object to the derived type
173 derived_type
& derive() { return static_cast<derived_type
&>(*this); }
/// casts this object to the derived type (read-only overload)
174 derived_type
const& derive() const { return static_cast<derived_type
const&>(*this); }
180 /// forward calculation: input sample & calculate output
181 void forward(sample_type
& sp
);
182 /// backward calculation: modify weight
187 /* ///////////////////////////////////////////////////////////////////////
// Helper macros that shorten the out-of-class member definitions below.
// All three are #undef'd again at the end of this header.
190 // Template declaration
191 #ifdef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
192 # undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
195 #define EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL \
196 template< typename_param_k Dev \
199 , typename_param_k Nt \
202 // Class qualification
203 #ifdef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
204 # undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
207 #define EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL bp_network_impl<Dev, InN, OutN, Nt>
209 // Return-type qualification: expands to the class-scoped return type followed by the class qualification
210 #ifdef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL
211 # undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL
214 #define EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL(ret_type) \
215 typename_type_ret_k EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::ret_type \
216 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
218 /* ///////////////////////////////////////////////////////////////////////
// Exchanges the state of this object with rhs.
// The two members held at this level (m_afunc, m_lrate) are swapped first,
// reached through a class_type view of rhs; the remaining state is then
// exchanged by base_type::swap.
222 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
223 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::swap(derived_type
& rhs
)
// swap the activation function
225 std_swap(m_afunc
, static_cast<class_type
&>(rhs
).m_afunc
);
// swap the learning rate
226 std_swap(m_lrate
, static_cast<class_type
&>(rhs
).m_lrate
);
// let the base class swap everything it owns
228 static_cast<base_type
&>(*this).swap(rhs
);
// Assignment from another derived object, implemented as copy-and-swap:
// a temporary derived_type is copy-constructed from rhs and swapped with
// the object returned by derive().
// NOTE(review): the return statement is not visible in this excerpt.
231 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
232 inline EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL(derived_type
&)::operator =(derived_type
const& rhs
)
234 derived_type(rhs
).swap(derive());
// Forward pass for one sample.
//  1. Copies the sample inputs into the input-layer nodes (node 0 is skipped).
//  2. For every node after the input layer, in index order, sums
//     weight * output over its incoming neighbours and hands the sum to
//     derive().do_activate_output().
//  3. For nodes reporting in_olayer(), stores the error as
//     (expected value from the sample) - (node output).
// \param sp the sample supplying the inputs and the expected outputs
238 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
239 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::forward(sample_type
& sp
)
241 EXTL_ASSERT(derive().is_valid());
// the input layer holds one node more than the sample has inputs
245 size_type input_n
= sp
.input_size();
246 EXTL_ASSERT(this->layers().inodes_size() == input_n
+ 1); // + a threshold
247 for (i
= 0; i
< input_n
; ++i
)
249 // note: this->net().at(0) is threshold
250 this->net().at(i
+ 1).output(sp
.get_finput(i
));
253 // calculate hidden-layer & output-layer output
// ihs_n: number of nodes that precede the output layer (used below to map a
// node index to an index into the sample's expected outputs)
254 size_type ihs_n
= this->layers().nodes_size() - this->layers().onodes_size();
255 index_type start_n
= this->layers().inodes_size();
256 size_type nodes_n
= this->layers().nodes_size();
257 EXTL_ASSERT(start_n
<= nodes_n
);
258 for (i
= start_n
; i
< nodes_n
; ++i
)
260 // weighted_sum = sum(weight_k * val_k)
261 float_type weighted_sum
= 0;
// iterate over the nodes that have an edge into node i
262 net_in_adjnode_iterator pe
= this->net().in_adjnode_end(i
);
263 for (net_in_adjnode_iterator pi
= this->net().in_adjnode_begin(i
); pi
!= pe
; ++pi
)
265 weighted_sum
+= this->net().weight(*pi
, i
).value() * this->net().at(*pi
).output();
// the derived class turns the weighted sum into the node's output
268 node_type
& node
= this->net().at(i
);
269 derive().do_activate_output(node
, weighted_sum
);
271 // calculate the error of real_output & the output-layer node
272 if (node
.in_olayer())
274 // calculate output error
275 node
.error(sp
.get_freal(i
- ihs_n
) - node
.output());
279 EXTL_ASSERT(derive().is_valid());
// Backward pass: visits the nodes from the last one (nodes_size() - 1) down to
// the first node after the input layer (inodes_size()). For each node it calls
// derive().do_prepare_modify_weight(), then derive().do_modify_weight() for
// every incoming edge.
282 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
283 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::backward()
285 EXTL_ASSERT(derive().is_valid());
287 // modify the weights of the output-layer and preceding non-input nodes
288 index_type start_n
= this->layers().nodes_size() - 1;
289 index_type end_n
= this->layers().inodes_size();
290 for (index_type i
= start_n
+ 1; i
>= end_n
+ 1; --i
) // the loop index is offset by +1 so the bound check is safe when index_type is unsigned
// cur is the actual node index for this iteration
292 index_type cur
= i
- 1;
293 node_type
& node
= this->net().at(cur
);
295 // prepare modify weight
296 derive().do_prepare_modify_weight(node
);
// update the weight of every edge entering the current node
299 net_in_adjnode_iterator pe
= this->net().in_adjnode_end(cur
);
300 for (net_in_adjnode_iterator pi
= this->net().in_adjnode_begin(cur
); pi
!= pe
; ++pi
)
302 derive().do_modify_weight(this->net().weight(*pi
, cur
), node
, this->net().at(*pi
));
308 EXTL_ASSERT(derive().is_valid());
// Trains the network on a single sample: a forward pass through
// derive().forward(), followed (per the comment below) by the backward pass.
// NOTE(review): the backward call itself is not visible in this excerpt.
// \param sp the training sample
310 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
311 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::train(sample_type
& sp
)
313 EXTL_ASSERT(derive().is_valid());
315 // forward calculation: input sample & calculate output
316 derive().forward(sp
);
317 /// backward calculation: modify weight
320 EXTL_ASSERT(derive().is_valid());
// Trains the network on a set of samples for train_n passes.
// Two variants are selected at preprocessing time: a plain one that only loops
// over the samples, and one that also adapts the learning rate.
// NOTE(review): the #if condition and the loop bodies are not visible in this
// excerpt - presumably each sample is passed to the single-sample train(); confirm.
// \param sps     the training samples
// \param train_n the number of passes over sps
323 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
324 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::train(samples_type
& sps
, size_type train_n
)
326 EXTL_ASSERT(derive().is_valid());
// plain variant: fixed learning rate
329 size_type sps_n
= sps
.size();
330 for (size_type train_i
= 0; train_i
< train_n
; ++train_i
)
331 for (index_type i
= 0; i
< sps_n
; ++i
)
333 #else // for optimization: adaptive learning rate
334 typedef basic_network_validator
<derived_type
> validator_type
;
335 validator_type validator
;
337 size_type sps_n
= sps
.size();
// a negative old_mse marks "no previous measurement yet"
338 float_type old_mse
= -1;
339 for (size_type train_i
= 0; train_i
< train_n
; ++train_i
)
341 for (index_type i
= 0; i
< sps_n
; ++i
)
344 // modify learning rate
// the network is validated against the training samples; compared with the
// previous MSE: lower -> learning rate * 1.06, more than 6% higher ->
// learning rate * 0.8, otherwise unchanged
345 validator
.validate(derive(), sps
);
346 if (old_mse
< 0) old_mse
= validator
.mse();
347 else if (validator
.mse() < old_mse
) m_lrate
*= 1.06;
348 else if (validator
.mse() > (1.06 * old_mse
)) m_lrate
*= 0.8;
349 old_mse
= validator
.mse();
352 EXTL_ASSERT(derive().is_valid());
// Trains the network on a set of weighted samples for train_n passes.
// For each step a random value p is drawn with derive().rand().fgenerate(0, 1)
// and an index j is advanced while (sum <= p) and j is in range; j is then
// stepped back by one (clamped at 0).
// NOTE(review): the declaration of sum/j, the accumulation into sum and the
// use of j are not visible in this excerpt - this looks like weight-proportional
// (roulette-wheel) sample selection over sps_ws; confirm.
// As in the unweighted overload, a second variant also adapts the learning rate.
// \param sps     the training samples
// \param sps_ws  the per-sample weights
// \param train_n the number of passes over sps
355 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
356 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::train(samples_type
& sps
, floats_type
const& sps_ws
, size_type train_n
)
358 EXTL_ASSERT(derive().is_valid());
// plain variant: fixed learning rate
360 size_type sps_n
= sps
.size();
361 for (size_type train_i
= 0; train_i
< train_n
; ++train_i
)
363 for (index_type i
= 0; i
< sps_n
; ++i
)
// draw the random threshold used to pick a sample index
366 float_type p
= derive().rand().fgenerate(0, 1);
369 for (j
= 0; (sum
<= p
) && (j
< sps_n
); ++j
)
// the loop leaves j one past the selected position; step back, never below 0
371 j
= j
> 0? j
- 1 : 0;
377 #else // for optimization: adaptive learning rate
379 typedef basic_network_validator
<derived_type
> validator_type
;
380 validator_type validator
;
382 size_type sps_n
= sps
.size();
// a negative old_mse marks "no previous measurement yet"
383 float_type old_mse
= -1;
384 for (size_type train_i
= 0; train_i
< train_n
; ++train_i
)
386 for (index_type i
= 0; i
< sps_n
; ++i
)
// draw the random threshold used to pick a sample index
389 float_type p
= derive().rand().fgenerate(0, 1);
392 for (j
= 0; (sum
<= p
) && (j
< sps_n
); ++j
)
// the loop leaves j one past the selected position; step back, never below 0
394 j
= j
> 0? j
- 1 : 0;
400 // modify learning rate
// compared with the previous MSE: lower -> learning rate * 1.06,
// more than 6% higher -> learning rate * 0.8, otherwise unchanged
401 validator
.validate(derive(), sps
);
402 if (old_mse
< 0) old_mse
= validator
.mse();
403 else if (validator
.mse() < old_mse
) m_lrate
*= 1.06;
404 else if (validator
.mse() > (1.06 * old_mse
)) m_lrate
*= 0.8;
405 old_mse
= validator
.mse();
408 EXTL_ASSERT(derive().is_valid());
// Runs the network on one sample: performs the forward pass, then copies the
// output of every output-layer node into the sample via set_foutput().
// \param sp the sample; its inputs are read and its outputs are written
411 EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
412 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::run(sample_type
& sp
)
414 EXTL_ASSERT(derive().is_valid());
416 // forward calculation: input sample & calculate output
417 derive().forward(sp
);
// the output-layer nodes occupy the last onodes_size() indices: [start_n, end_n)
420 index_type start_n
= this->layers().nodes_size() - this->layers().onodes_size();
421 index_type end_n
= this->layers().nodes_size();
422 EXTL_ASSERT(start_n
<= end_n
);
// the sample must provide exactly one output slot per output-layer node
423 EXTL_ASSERT(sp
.output_size() == this->layers().onodes_size());
424 for (index_type i
= start_n
; i
< end_n
; ++i
)
426 // the sample type is the bit array type
// i - start_n converts the node index into the sample's output index
427 sp
.set_foutput(i
- start_n
, this->net().at(i
).output());
429 EXTL_ASSERT(derive().is_valid());
431 /*EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
432 inline void EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL::run(samples_type& sps)
434 EXTL_ASSERT(derive().is_valid());
436 size_type sps_n = sps.size();
437 for (index_type i = 0; i < sps_n; ++i)
440 EXTL_ASSERT(derive().is_valid());
443 /* //////////////////////////////////////////////////////////////////// */
444 #undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_TEMPLATE_DECL
445 #undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_QUAL
446 #undef EXTL_INTELLI_ANN_DETAIL_BP_NETWORK_IMPL_RET_QUAL
449 /* ///////////////////////////////////////////////////////////////////////
450 * ::extl::intelligence::detail namespace
452 EXTL_DETAIL_END_NAMESPACE
453 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
455 /* //////////////////////////////////////////////////////////////////// */
456 #endif /* EXTL_INTELLIGENCE_ANN_DETAIL_BP_NETWORK_IMPL_H */
457 /* //////////////////////////////////////////////////////////////////// */