remove \r
[extl.git] / extl / intelligence / ann / bayes_classifier.h
blob4af4fb7c13ad09dbe1d6845315590b7f6f87895e
1 /* ///////////////////////////////////////////////////////////////////////
2 * File: bayes_classifier.h
4 * Created: 09.06.14
5 * Updated: 09.06.14
7 * Brief: The bayes_classifier class - for classifying samples using network
9 * [<Home>]
10 * Copyright (c) 2008-2020, Waruqi All rights reserved.
11 * //////////////////////////////////////////////////////////////////// */
13 #ifndef EXTL_INTELLIGENCE_ANN_BAYES_CLASSIFIER_H
14 #define EXTL_INTELLIGENCE_ANN_BAYES_CLASSIFIER_H
16 /*!\file bayes_classifier.h
17 * \brief bayes_classifier class
20 /* ///////////////////////////////////////////////////////////////////////
21 * Includes
23 #include "prefix.h"
24 #include "network_validator.h"
26 /* ///////////////////////////////////////////////////////////////////////
27 * Compatibility
29 #if !defined(EXTL_INTELLIGENCE_ANN_NETWORK_SUPPORT)
30 # error basic_classifier_validator.h is not supported by the current compiler.
31 #endif
33 /* ///////////////////////////////////////////////////////////////////////
34 * ::extl::intelligence namespace
36 EXTL_INTELLIGENCE_BEGIN_WHOLE_NAMESPACE
38 /*!brief bayes_classifier
40 * \param NetWork the network type
42 * \ingroup extl_group_intelligence
44 template<typename_param_k NetWork>
45 class bayes_classifier
48 /// \name Types
49 /// @{
50 public:
51 typedef bayes_classifier class_type;
52 typedef NetWork network_type;
53 typedef typename_type_k network_type::size_type size_type;
54 typedef typename_type_k network_type::bool_type bool_type;
55 typedef typename_type_k network_type::index_type index_type;
56 typedef typename_type_k network_type::float_type float_type;
57 typedef typename_type_k network_type::sample_type sample_type;
58 typedef typename_type_k network_type::samples_type samples_type;
59 typedef typename_type_k network_type::rand_type rand_type;
60 typedef typename_type_k hash_selector<size_type, float_type>::hash_type hash_type;
61 typedef typename_type_k hash_type::iterator hash_iterator;
62 typedef typename_type_k hash_type::const_iterator hash_const_iterator;
63 typedef typename_type_k hash_selector<size_type, hash_type>::hash_type hashes_type;
64 typedef typename_type_k hashes_type::iterator hashes_iterator;
65 typedef typename_type_k hashes_type::const_iterator hashes_const_iterator;
66 /// @}
68 /// \name Members
69 /// @{
70 private:
71 /// the network
72 network_type* m_network;
73 /// trained samples number
74 hash_type m_samples_n;
75 /// trained samples distribution
76 hashes_type m_samples_d;
77 /// @}
79 /// \name Constants
80 /// @{
81 public:
82 enum { en_input_size = network_type::en_input_size };
83 enum { en_output_size = network_type::en_output_size };
84 /// @}
86 /// \name Constructors
87 /// @{
88 public:
89 explicit_k bayes_classifier(network_type* network = NULL)
90 : m_network(network)
91 , m_samples_n()
92 , m_samples_d()
95 bayes_classifier(class_type const& rhs)
96 : m_network(rhs.m_network)
97 , m_samples_n(rhs.m_samples_n)
98 , m_samples_d(rhs.m_samples_d)
101 /// @}
103 /// \name Accessors
104 /// @{
105 public:
106 network_type& network() { EXTL_ASSERT(NULL != m_network); return *m_network; }
107 network_type const& network() const { EXTL_ASSERT(NULL != m_network); return *m_network; }
109 size_type classify_n() const { return m_samples_n.size(); }
110 /// @}
112 /// \name Methods
113 /// @{
114 public:
115 size_type hamming(size_type v1, size_type v2)
117 size_type count = 0;
118 for (size_type i = 0; i < (size_type)en_output_size; ++i)
120 size_type v1_i = (size_type(0x1) == ((v1 >> i) & size_type(0x1)));
121 size_type v2_i = (size_type(0x1) == ((v2 >> i) & size_type(0x1)));
122 if (v1_i != v2_i) count++;
124 return count;
126 /// init classifier using trained samples
127 /// \note must be called after network training
128 void init(samples_type& sps)
130 // clear data
131 m_samples_n.clear();
132 for (hashes_iterator pd = m_samples_d.begin(); pd != m_samples_d.end(); ++pd) pd->second().clear();
133 m_samples_d.clear();
135 // generate trained samples distribution
136 size_type i = 0;
137 for (i = 0; i < sps.size(); i++)
139 // gets trained sample output
140 network().run(sps[i]);
141 size_type dreal = sps[i].dreal();
142 size_type doutput = sps[i].doutput();
144 // stats the number of samples which belong to dreal
145 // n(dreal)
146 if (m_samples_n.count(dreal) == 0) m_samples_n[dreal] = 1;
147 else ++m_samples_n[dreal];
149 // stats the number of samples which belong to dreal and are considered to be doutput
150 // n(doutput|dreal)
151 if (m_samples_d[dreal].count(doutput) == 0) m_samples_d[dreal][doutput] = 1;
152 else ++m_samples_d[dreal][doutput];
155 // calcuate samples distribution: p(doutput|dreal)
156 for (hash_const_iterator p = m_samples_n.begin(); p != m_samples_n.end(); ++p)
158 //printf("dreal:%d count:%f => ", p->first(), p->second());
159 for (hash_iterator pd = m_samples_d[p->first()].begin(); pd != m_samples_d[p->first()].end(); ++pd)
161 // p(doutput|dreal) = n(doutput|dreal) / n(dreal)
162 pd->second() /= p->second();
163 //printf("doutput:%d distribution:%f", pd->first(), pd->second());
165 //printf("\n");
169 void classify(sample_type& sp)
171 // gets sample output
172 network().run(sp);
173 size_type doutput = sp.doutput();
175 // cumulatives hamming distance
176 hash_type real_pros;
177 float_type hamming_total = 0;
178 for (hash_const_iterator p = m_samples_n.begin(); p != m_samples_n.end(); ++p)
180 size_type dreal = p->first();
181 real_pros[dreal] = en_output_size - hamming(doutput, dreal);
182 hamming_total += real_pros[dreal];
184 if (hamming_total < 0.00001) hamming_total = 0.00001;
186 // calcuate real pros
187 for (hash_const_iterator p = m_samples_n.begin(); p != m_samples_n.end(); ++p)
189 size_type dreal = p->first();
190 real_pros[dreal] /= hamming_total;
193 // cumulatives p(doutput_i|dreal) * p(dreal)
194 hash_type pre_pros;
195 float_type pre_total = 0;
196 for (hash_const_iterator p = m_samples_n.begin(); p != m_samples_n.end(); ++p)
198 size_type dreal = p->first();
200 // p(doutput_i|dreal) * p(dreal)
201 if (m_samples_d[dreal].count(doutput) == 0) pre_pros[dreal] = 0;
202 else pre_pros[dreal] = m_samples_d[dreal][doutput] * real_pros[dreal];
204 pre_total += pre_pros[dreal];
207 size_type max_dreal = 0;
208 float_type max_pro = -1;
210 if (pre_total > 0.00001)
212 for (hash_const_iterator p = m_samples_n.begin(); p != m_samples_n.end(); ++p)
214 // bayes_pro = p(doutput_i|dreal) * p(dreal) / total
215 size_type dreal = p->first();
216 pre_pros[dreal] /= pre_total;
218 // maximum bayes_pro
219 if (pre_pros[dreal] > max_pro)
221 max_dreal = dreal;
222 max_pro = pre_pros[dreal];
226 else
228 for (hash_const_iterator p = m_samples_n.begin(); p != m_samples_n.end(); ++p)
230 // bayes_pro = p(doutput_i|dreal) * p(dreal) / total
231 size_type dreal = p->first();
232 // maximum bayes_pro
233 if (real_pros[dreal] > max_pro)
235 max_dreal = dreal;
236 max_pro = real_pros[dreal];
241 // reject to clasify if the max_pro is too small
242 //if (max_pro < 1 / float_type(classify_n())) max_dreal = -1;
244 // dreal with maximum posterior probability is considered to be doutput
245 sp.doutput(max_dreal);
247 /// @}
249 /// \name Attributes
250 /// @{
251 public:
254 /* ///////////////////////////////////////////////////////////////////////
255 * ::extl::intelligence namespace
257 EXTL_INTELLIGENCE_END_WHOLE_NAMESPACE
259 /* //////////////////////////////////////////////////////////////////// */
260 #endif /* EXTL_INTELLIGENCE_ANN_BAYES_CLASSIFIER_H */
261 /* //////////////////////////////////////////////////////////////////// */