Linux multi-monitor fullscreen support
[ryzomcore.git] / ryzom / tools / assoc_mem / tree.cpp
blob544af78b901ff6a6c3d3ae8bb73130735e0ffcf7
1 // Ryzom - MMORPG Framework <http://dev.ryzom.com/projects/ryzom/>
2 // Copyright (C) 2010 Winch Gate Property Limited
3 //
4 // This source file has been modified by the following contributors:
5 // Copyright (C) 2010 Robert TIMM (rti) <mail@rtti.de>
6 //
7 // This program is free software: you can redistribute it and/or modify
8 // it under the terms of the GNU Affero General Public License as
9 // published by the Free Software Foundation, either version 3 of the
10 // License, or (at your option) any later version.
12 // This program is distributed in the hope that it will be useful,
13 // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 // GNU Affero General Public License for more details.
17 // You should have received a copy of the GNU Affero General Public License
18 // along with this program. If not, see <http://www.gnu.org/licenses/>.
20 #include <valarray>
21 #include <map>
22 #include <stdio.h>
23 #include <iostream>
24 #include <algorithm>
26 #include "tree.h"
27 #include "cond_node.h"
28 #include "result_node.h"
30 CTree::CTree()
32 _RootNode = NULL;
35 CTree::~CTree()
37 if ( _RootNode != NULL )
38 delete _RootNode;
41 void CTree::setKey(int key)
43 _Key = key;
46 int CTree::getKey()
48 return _Key;
52 bool CTree::getOutput(CRecord *input)
54 if ( _RootNode != NULL )
55 return _RootNode->propagRecord( input );
56 else
57 return false;
60 int CTree::getNbRecords(std::vector<CRecord *> &records,int key, IValue *value) //
62 int nb = 0;
63 std::vector<CRecord *>::iterator it_r = records.begin();
64 while ( it_r != records.end() )
66 if ( *((**it_r)[key]) == value )
67 nb++;
68 it_r++;
70 return nb;
73 double CTree::log2(double val) const
75 return (log(val) / log(2.0));
78 double CTree::entropy(double a, double b) const
80 double p1;
81 double p2;
83 if ( a > 0 )
84 p1 = a * log2(a);
85 else
86 p1 = 0;
88 if ( b > 0 )
89 p2 = b * log2(b);
90 else
91 p2 = 0;
93 return ( p1 + p2 ) * -1;
96 double CTree::entropy(std::vector<double> &p) const
98 double result = 0;
99 std::vector<double>::iterator it_p = p.begin();
100 while ( it_p != p.end() )
102 double val = *it_p;
104 if ( val > 0 )
105 result = result + val * log2( val );
107 it_p++;
109 return result * -1;
112 double CTree::gain(std::vector<CRecord *> &records, int attrib, CField *field)
114 int nb_values = (int)field->getPossibleValues().size();
115 int nb_records = (int)records.size();
117 CValue<bool> bool_true(true);
119 double nb_key_true = getNbRecords(records, _Key, &bool_true );
120 double nb_key_false = nb_records - nb_key_true;
122 double entropy_records = entropy( nb_key_true / nb_records, nb_key_false / nb_records );
124 double gain = entropy_records;
126 int i;
127 for ( i = 0; i < nb_values; i++ )
129 IValue *val = field->getPossibleValues()[i];
131 int nb_records_val, nb_records_notval;
132 splitRecords(records, attrib, val, nb_records_val, nb_records_notval );
134 int nb_records_val_key, nb_records_val_notkey;
135 splitRecords(records, attrib, val, true, nb_records_val_key, nb_records_val_notkey );
137 double entropy_val = entropy( ((double)nb_records_val_key) / ((double)nb_records_val), ((double)nb_records_val_notkey) / ((double)nb_records_val) );
139 gain = gain - ( ( (double)nb_records_val ) / ( (double) nb_records ) * entropy_val );
141 return gain;
144 std::vector<std::pair<double,int> > CTree::getSortedFields( std::vector<int> &attributes, std::vector<CRecord *> &records, std::vector<CField *> &fields )
146 std::vector<std::pair<double,int> > attribs;
148 if ( ! records.empty() )
150 std::vector<int>::iterator it_a = attributes.begin();
151 while ( it_a != attributes.end() )
153 if ( (*it_a) != _Key )
154 attribs.push_back( std::pair<double,int>( gain(records, (*it_a), fields[*it_a] ) , (*it_a) ) );
155 it_a++;
159 // Sorts the records by gain
160 std::sort(attribs.begin(), attribs.end(), greater() );
162 std::vector<std::pair<double,int> >::iterator it_f = attribs.begin();
164 std::cout << "Attributes(gain) :" << std::endl;
165 while ( it_f != attribs.end() )
167 std::cout << " " << fields[ (*it_f).second ]->getName() << " (" << (*it_f).first << ") " << std::endl;
168 it_f++;
170 std::cout << std::endl;
172 return attribs;
175 // Looks for the attrib with the most gain
176 int CTree::getBestAttrib( std::vector<int> &attributes, std::vector<CRecord *> &records, std::vector<CField *> &fields )
178 double tmp_gain;
179 double max_gain = 0;
180 int best_attrib = -1;
182 std::cout << "Attributes(gain) :" << std::endl;
183 if ( ! records.empty() )
185 std::vector<int>::iterator it_a = attributes.begin();
186 while ( it_a != attributes.end() )
188 if ( (*it_a) != _Key )
191 tmp_gain = gain( records, *it_a, fields[ *it_a ] );
192 std::cout << " " << fields[ *it_a ]->getName() << " (" << tmp_gain << ") " << std::endl;
193 if ( tmp_gain >= max_gain )
195 max_gain = tmp_gain;
196 best_attrib = *it_a;
199 it_a++;
202 return best_attrib;
206 void CTree::rebuild(std::vector<CRecord *> &records, std::vector<CField *> &fields)
208 std::vector<int> left_fields;
209 CRecord *first = *records.begin();
211 for (int i = 0; i < first->size(); i++ )
212 if ( i != _Key )
213 left_fields.push_back( i );
215 _RootNode = ID3( left_fields, records, fields );
218 float CTree::findNumKeyValue(std::vector<CRecord *> &records, int key)
220 float sum_true = 0;
221 float nb_true = 0;
222 float sum_false = 0;
223 float nb_false = 0;
225 std::vector<CRecord *>::iterator it_r = records.begin();
226 while ( it_r != records.end() )
228 bool result = ((CValue<bool> *)(**it_r)[ _Key ])->getValue();
229 if ( result == true )
231 sum_true += ((CValue<int> *)(**it_r)[ key ])->getValue();
232 nb_true ++;
234 else
236 sum_false += ((CValue<int> *)(**it_r)[ key ])->getValue();
237 nb_false ++;
239 it_r++;
242 return ( sum_true / nb_true + sum_false / nb_false ) / 2;
245 std::string CTree::getDebugString(std::vector<CRecord *> &records, std::vector<CField *> &fields)
247 std::string output;
248 output += "CTree KEY = ";
249 output += fields[ _Key ]->getName();
250 return output;
253 INode *CTree::ID3(std::vector<int> &attributes, std::vector<CRecord *> &records, std::vector<CField *> &fields)
255 if ( records.empty() )
257 return new CResultNode( false );
259 else
261 // If there is no attribute left and the records don't have the same key value,
262 // returns a result node with the most frequent key value
263 if ( attributes.empty() )
265 int nb_key_true;
266 int nb_key_false;
267 splitRecords( records, _Key , nb_key_true, nb_key_false );
269 if ( nb_key_true > nb_key_false )
270 return new CResultNode( true );
271 else
272 return new CResultNode( false );
276 // Tests if all records have the same key value, if so returns a result node with this key value.
277 int nb_records = (int)records.size();
278 int nb_key_true;
279 int nb_key_false;
281 splitRecords( records, _Key , nb_key_true, nb_key_false );
283 if ( nb_key_true == nb_records )
284 return new CResultNode( true );
286 if ( nb_key_false == nb_records )
287 return new CResultNode( false );
290 // Gets the attribute with the most gain for the current record set,
291 // and recursively builds the subnodes corresponding to each
292 // possible value for this attribute.
293 int best_gain_attrib = getBestAttrib( attributes, records, fields );
295 std::vector< std::vector<CRecord *> > sorted_records;
296 splitRecords( records, best_gain_attrib, fields, sorted_records ); // classifies the records depending on the value of the best gain attribute
298 std::vector<int> new_attribs;
299 for ( int i = 0; i < (int) attributes.size(); i++ ) // Creates a new attributes list from the current attributes list with the best gain attribute removed
300 if ( attributes[i] != best_gain_attrib )
301 new_attribs.push_back( attributes[i] );
303 ICondNode *root_node = fields[best_gain_attrib]->createNode(_Key, best_gain_attrib, records);
305 std::vector< std::vector<CRecord *> >::iterator it = sorted_records.begin(); // Constructs subnodes recursively
306 while ( it != sorted_records.end() )
308 root_node->addNode( ID3( new_attribs, *it, fields ) );
309 it++;
311 return root_node;
315 std::vector<CRecord *> CTree::getRecords(std::vector<CRecord *> &records, int attrib, bool value)
317 std::vector<CRecord *> result;
318 std::vector<CRecord *>::iterator it_r = records.begin();
319 while ( it_r != records.end() )
321 if ( ((CValue<bool> *)(**it_r)[attrib])->getValue() == value )
322 result.push_back( *it_r );
323 it_r++;
325 return result;
328 void CTree::splitRecords(std::vector<CRecord *> &records, int attrib, int &true_records, int &false_records) //
330 true_records = 0;
331 false_records = 0;
333 std::vector<CRecord *>::iterator it_r = records.begin();
334 while ( it_r != records.end() )
336 if ( ((CValue<bool> *)(**it_r)[attrib])->getValue() == true )
337 true_records++;
338 else
339 false_records++;
340 it_r++;
344 void CTree::splitRecords(std::vector<CRecord *> &records, int attrib, IValue *val, int &true_records, int &false_records) //
346 true_records = 0;
347 false_records = 0;
349 std::vector<CRecord *>::iterator it_r = records.begin();
350 while ( it_r != records.end() )
352 const IValue *left_val = (**it_r)[ attrib ];
353 if ( ( *left_val ) == val )
354 true_records++;
355 else
356 false_records++;
357 it_r++;
361 // count records with a certain value for an attrib and true or false for the key attrib
362 void CTree::splitRecords(std::vector<CRecord *> &records, int attrib, IValue *val, bool key, int &true_records, int &false_records) //
364 true_records = 0;
365 false_records = 0;
367 std::vector<CRecord *>::iterator it_r = records.begin();
368 while ( it_r != records.end() )
370 if ( (* ( (**it_r)[attrib] ) ) == val )
372 if ( ( (CValue<bool> *) (**it_r)[ _Key ] )->getValue() == key )
373 true_records++;
374 else
375 false_records++;
377 it_r++;
382 // Sorts records according to the possibles values for an attribute.
383 void CTree::splitRecords( std::vector<CRecord *> &records, int attrib, std::vector<CField *> &fields, std::vector< std::vector<CRecord *> > &result) //
385 if ( result.size() < fields[attrib]->getPossibleValues().size() )
387 int nb_missing = (int)(fields[attrib]->getPossibleValues().size() - result.size());
388 for (int i = 0; i <= nb_missing; i++ )
390 result.push_back( std::vector<CRecord *>() );
394 std::vector<CRecord *>::iterator it_r = records.begin();
395 while ( it_r != records.end() )
397 std::vector<IValue *>::const_iterator it_vp = fields[attrib]->getPossibleValues().begin();
398 std::vector< std::vector<CRecord *> >::iterator it = result.begin();
399 int id_val = 0;
400 while ( it_vp != fields[attrib]->getPossibleValues().end() )
402 const IValue *left_value = (**it_r)[attrib];
403 IValue *right_value = *it_vp;
404 if ( (*left_value) == right_value )
405 (*it).push_back( *it_r );
406 it_vp++;
407 it++;
409 it_r++;