1 /* ///////////////////////////////////////////////////////////////////////
2 * File: bayes_classifier.h
7 * Brief: The bayes_classifier class - for classifying samples using network
10 * Copyright (c) 2008-2020, Waruqi All rights reserved.
11 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_BAYES_CLASSIFIER_H
14 #define EXTL_INTELLIGENCE_ANN_BAYES_CLASSIFIER_H
16 /*!\file bayes_classifier.h
17 * \brief bayes_classifier class
20 /* ///////////////////////////////////////////////////////////////////////
24 #include "network_validator.h"
26 /* ///////////////////////////////////////////////////////////////////////
29 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
30 # error basic_classifier_validator.h is not supported by the current compiler.
33 /* ///////////////////////////////////////////////////////////////////////
34 * ::extl::intelligence namespace
36 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
38 /*!brief bayes_classifier
40 * \param NetWork the network type
42 * \ingroup extl_group_intelligence
44 template<typename_param_k NetWork
>
45 class bayes_classifier
51 typedef bayes_classifier class_type
;
52 typedef NetWork network_type
;
53 typedef typename_type_k
network_type::size_type size_type
;
54 typedef typename_type_k
network_type::bool_type bool_type
;
55 typedef typename_type_k
network_type::index_type index_type
;
56 typedef typename_type_k
network_type::float_type float_type
;
57 typedef typename_type_k
network_type::sample_type sample_type
;
58 typedef typename_type_k
network_type::samples_type samples_type
;
59 typedef typename_type_k
network_type::rand_type rand_type
;
60 typedef typename_type_k hash_selector
<size_type
, float_type
>::hash_type hash_type
;
61 typedef typename_type_k
hash_type::iterator hash_iterator
;
62 typedef typename_type_k
hash_type::const_iterator hash_const_iterator
;
63 typedef typename_type_k hash_selector
<size_type
, hash_type
>::hash_type hashes_type
;
64 typedef typename_type_k
hashes_type::iterator hashes_iterator
;
65 typedef typename_type_k
hashes_type::const_iterator hashes_const_iterator
;
72 network_type
* m_network
;
73 /// trained samples number
74 hash_type m_samples_n
;
75 /// trained samples distribution
76 hashes_type m_samples_d
;
82 enum { en_input_size
= network_type::en_input_size
};
83 enum { en_output_size
= network_type::en_output_size
};
86 /// \name Constructors
89 explicit_k
bayes_classifier(network_type
* network
= NULL
)
95 bayes_classifier(class_type
const& rhs
)
96 : m_network(rhs
.m_network
)
97 , m_samples_n(rhs
.m_samples_n
)
98 , m_samples_d(rhs
.m_samples_d
)
106 network_type
& network() { EXTL_ASSERT(NULL
!= m_network
); return *m_network
; }
107 network_type
const& network() const { EXTL_ASSERT(NULL
!= m_network
); return *m_network
; }
109 size_type
classify_n() const { return m_samples_n
.size(); }
115 size_type
hamming(size_type v1
, size_type v2
)
118 for (size_type i
= 0; i
< (size_type
)en_output_size
; ++i
)
120 size_type v1_i
= (size_type(0x1) == ((v1
>> i
) & size_type(0x1)));
121 size_type v2_i
= (size_type(0x1) == ((v2
>> i
) & size_type(0x1)));
122 if (v1_i
!= v2_i
) count
++;
126 /// init classifier using trained samples
127 /// \note must be called after network training
128 void init(samples_type
& sps
)
132 for (hashes_iterator pd
= m_samples_d
.begin(); pd
!= m_samples_d
.end(); ++pd
) pd
->second().clear();
135 // generate trained samples distribution
137 for (i
= 0; i
< sps
.size(); i
++)
139 // gets trained sample output
140 network().run(sps
[i
]);
141 size_type dreal
= sps
[i
].dreal();
142 size_type doutput
= sps
[i
].doutput();
144 // stats the number of samples which belong to dreal
146 if (m_samples_n
.count(dreal
) == 0) m_samples_n
[dreal
] = 1;
147 else ++m_samples_n
[dreal
];
149 // stats the number of samples which belong to dreal and are considered to be doutput
151 if (m_samples_d
[dreal
].count(doutput
) == 0) m_samples_d
[dreal
][doutput
] = 1;
152 else ++m_samples_d
[dreal
][doutput
];
155 // calcuate samples distribution: p(doutput|dreal)
156 for (hash_const_iterator p
= m_samples_n
.begin(); p
!= m_samples_n
.end(); ++p
)
158 //printf("dreal:%d count:%f => ", p->first(), p->second());
159 for (hash_iterator pd
= m_samples_d
[p
->first()].begin(); pd
!= m_samples_d
[p
->first()].end(); ++pd
)
161 // p(doutput|dreal) = n(doutput|dreal) / n(dreal)
162 pd
->second() /= p
->second();
163 //printf("doutput:%d distribution:%f", pd->first(), pd->second());
169 void classify(sample_type
& sp
)
171 // gets sample output
173 size_type doutput
= sp
.doutput();
175 // cumulatives hamming distance
177 float_type hamming_total
= 0;
178 for (hash_const_iterator p
= m_samples_n
.begin(); p
!= m_samples_n
.end(); ++p
)
180 size_type dreal
= p
->first();
181 real_pros
[dreal
] = en_output_size
- hamming(doutput
, dreal
);
182 hamming_total
+= real_pros
[dreal
];
184 if (hamming_total
< 0.00001) hamming_total
= 0.00001;
186 // calcuate real pros
187 for (hash_const_iterator p
= m_samples_n
.begin(); p
!= m_samples_n
.end(); ++p
)
189 size_type dreal
= p
->first();
190 real_pros
[dreal
] /= hamming_total
;
193 // cumulatives p(doutput_i|dreal) * p(dreal)
195 float_type pre_total
= 0;
196 for (hash_const_iterator p
= m_samples_n
.begin(); p
!= m_samples_n
.end(); ++p
)
198 size_type dreal
= p
->first();
200 // p(doutput_i|dreal) * p(dreal)
201 if (m_samples_d
[dreal
].count(doutput
) == 0) pre_pros
[dreal
] = 0;
202 else pre_pros
[dreal
] = m_samples_d
[dreal
][doutput
] * real_pros
[dreal
];
204 pre_total
+= pre_pros
[dreal
];
207 size_type max_dreal
= 0;
208 float_type max_pro
= -1;
210 if (pre_total
> 0.00001)
212 for (hash_const_iterator p
= m_samples_n
.begin(); p
!= m_samples_n
.end(); ++p
)
214 // bayes_pro = p(doutput_i|dreal) * p(dreal) / total
215 size_type dreal
= p
->first();
216 pre_pros
[dreal
] /= pre_total
;
219 if (pre_pros
[dreal
] > max_pro
)
222 max_pro
= pre_pros
[dreal
];
228 for (hash_const_iterator p
= m_samples_n
.begin(); p
!= m_samples_n
.end(); ++p
)
230 // bayes_pro = p(doutput_i|dreal) * p(dreal) / total
231 size_type dreal
= p
->first();
233 if (real_pros
[dreal
] > max_pro
)
236 max_pro
= real_pros
[dreal
];
241 // reject to clasify if the max_pro is too small
242 //if (max_pro < 1 / float_type(classify_n())) max_dreal = -1;
244 // dreal with maximum posterior probability is considered to be doutput
245 sp
.doutput(max_dreal
);
254 /* ///////////////////////////////////////////////////////////////////////
255 * ::extl::intelligence namespace
257 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
259 /* //////////////////////////////////////////////////////////////////// */
260 #endif /* EXTL_INTELLIGENCE_ANN_BAYES_CLASSIFIER_H */
261 /* //////////////////////////////////////////////////////////////////// */