/* ///////////////////////////////////////////////////////////////////////
 * File:		bagging_networks.h
 *
 * Brief:		The bagging_networks class
 *
 * Copyright (c) 2008-2020, Waruqi All rights reserved.
 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_BAGGING_NETWORKS_H
14 #define EXTL_INTELLIGENCE_ANN_BAGGING_NETWORKS_H
/*!\file bagging_networks.h
 * \brief bagging_networks class
 */
/* //////////////////////////////////////////////////////////////////// */
23 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
24 # error bagging_networks.h is not supported by the current compiler.
/* //////////////////////////////////////////////////////////////////// */

/* ///////////////////////////////////////////////////////////////////////
 * ::extl::intelligence namespace
 */
34 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
36 /*!brief bagging_networks
38 * \param NetWork the network type
40 * \ingroup extl_group_intelligence
42 template<typename_param_k NetWork
>
43 class bagging_networks
46 /// \name Public Types
49 typedef bagging_networks class_type
;
50 typedef NetWork network_type
;
51 typedef typename_type_k
network_type::size_type size_type
;
52 typedef typename_type_k
network_type::bool_type bool_type
;
53 typedef typename_type_k
network_type::index_type index_type
;
54 typedef typename_type_k
network_type::float_type float_type
;
55 typedef typename_type_k
network_type::sample_type sample_type
;
56 typedef typename_type_k
network_type::samples_type samples_type
;
57 typedef typename_type_k
network_type::rand_type rand_type
;
58 typedef typename_type_k buffer_selector
<network_type
*>::buffer_type networks_type
;
59 typedef typename_type_k buffer_selector
<float_type
>::buffer_type float_buffer_type
;
65 networks_type m_networks
;
66 float_buffer_type m_foutputs
;
69 /// \name Constructors
72 explicit_k
bagging_networks ( network_type
* network_0
= NULL
73 , network_type
* network_1
= NULL
74 , network_type
* network_2
= NULL
75 , network_type
* network_3
= NULL
76 , network_type
* network_4
= NULL
77 , network_type
* network_5
= NULL
78 , network_type
* network_6
= NULL
79 , network_type
* network_7
= NULL
80 , network_type
* network_8
= NULL
81 , network_type
* network_9
= NULL
86 if (NULL
!= network_0
) m_networks
.push_back(network_0
);
87 if (NULL
!= network_1
) m_networks
.push_back(network_1
);
88 if (NULL
!= network_2
) m_networks
.push_back(network_2
);
89 if (NULL
!= network_3
) m_networks
.push_back(network_3
);
90 if (NULL
!= network_4
) m_networks
.push_back(network_4
);
91 if (NULL
!= network_5
) m_networks
.push_back(network_5
);
92 if (NULL
!= network_6
) m_networks
.push_back(network_6
);
93 if (NULL
!= network_7
) m_networks
.push_back(network_7
);
94 if (NULL
!= network_8
) m_networks
.push_back(network_8
);
95 if (NULL
!= network_9
) m_networks
.push_back(network_9
);
97 bagging_networks(class_type
const& rhs
)
98 : m_networks(rhs
.m_networks
)
99 , m_foutputs(rhs
.m_foutputs
)
107 networks_type
& networks() { return m_networks
; }
108 networks_type
const& networks() const { return m_networks
; }
110 float_type
foutput(index_type i
) const;
111 bool_type
boutput(index_type i
) const;
113 /// returns the random generator
114 rand_type
& rand() { return rand_type(); }
115 /// returns the const layers
116 rand_type
const& rand() const { return rand_type(); }
122 void swap(class_type
& rhs
);
128 /// train samples and return mse
129 void train(samples_type
& sps
, size_type train_n
= 1);
130 /// run the given sample
131 void run(sample_type
& sp
);
132 /// run the given samples
133 //void run(samples_type& sps);
139 float_buffer_type
& foutputs() { return m_foutputs
; }
140 float_buffer_type
const& foutputs() const { return m_foutputs
; }
/* ///////////////////////////////////////////////////////////////////////
 * implementation
 */
146 template<typename_param_k NetWork
>
147 inline void bagging_networks
<NetWork
>::swap(class_type
& rhs
)
149 m_networks
.swap(rhs
.m_networks
);
150 m_foutputs
.swap(rhs
.m_foutputs
);
152 template<typename_param_k NetWork
>
153 inline void bagging_networks
<NetWork
>::train(samples_type
& sps
, size_type train_n
)
156 size_type sps_n
= sps
.size();
157 size_type networks_n
= networks().size();
158 for (i
= 0; i
< networks_n
; ++i
)
160 EXTL_ASSERT(NULL
!= networks()[i
]);
161 // select sps_n samples randomly
163 for (j
= 0; j
< sps_n
; ++j
)
165 index_type sps_i
= networks()[i
]->rand().generate(0, sps_n
);
166 tsps
.push_back(sps
[sps_i
]);
170 networks()[i
]->train(tsps
, train_n
);
174 template<typename_param_k NetWork
>
175 inline void bagging_networks
<NetWork
>::run(sample_type
& sp
)
179 // initialize foutputs
180 size_type output_n
= sp
.output_size();
181 foutputs().resize(output_n
);
182 for (i
= 0; i
< output_n
; ++i
)
185 size_type networks_n
= networks().size();
186 for (i
= 0; i
< networks_n
; ++i
)
188 EXTL_ASSERT(NULL
!= networks()[i
]);
189 networks()[i
]->run(sp
);
191 // accumulate outputs
192 for (j
= 0; j
< output_n
; ++j
)
193 foutputs()[j
] += static_cast<float_type
>(sp
.get_boutput(j
));
196 // calculate the average of the outputs
197 for (j
= 0; j
< output_n
; ++j
)
199 foutputs()[j
] /= networks_n
;
200 sp
.set_foutput(j
, foutputs()[j
]);
203 /*template<typename_param_k NetWork>
204 inline void bagging_networks<NetWork>::run(samples_type& sps)
206 size_type sps_n = sps.size();
207 for (index_type i = 0; i < sps_n; ++i)
211 template<typename_param_k NetWork
>
212 inline typename_type_ret_k bagging_networks
<NetWork
>::
213 float_type bagging_networks
<NetWork
>::foutput(index_type i
) const
215 EXTL_ASSERT(i
< m_foutputs
.size());
216 return m_foutputs
[i
];
218 template<typename_param_k NetWork
>
219 inline typename_type_ret_k bagging_networks
<NetWork
>::
220 bool_type bagging_networks
<NetWork
>::boutput(index_type i
) const
222 EXTL_ASSERT(i
< m_foutputs
.size());
223 return xtl_round45(m_foutputs
[i
] == 1);
/* ///////////////////////////////////////////////////////////////////////
 * ::extl::intelligence namespace
 */
228 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
/* //////////////////////////////////////////////////////////////////// */
231 #endif /* EXTL_INTELLIGENCE_ANN_BAGGING_NETWORKS_H */
/* //////////////////////////////////////////////////////////////////// */