1 // Ryzom - MMORPG Framework <http://dev.ryzom.com/projects/ryzom/>
2 // Copyright (C) 2010 Winch Gate Property Limited
4 // This source file has been modified by the following contributors:
5 // Copyright (C) 2010 Robert TIMM (rti) <mail@rtti.de>
7 // This program is free software: you can redistribute it and/or modify
8 // it under the terms of the GNU Affero General Public License as
9 // published by the Free Software Foundation, either version 3 of the
10 // License, or (at your option) any later version.
12 // This program is distributed in the hope that it will be useful,
13 // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 // GNU Affero General Public License for more details.
17 // You should have received a copy of the GNU Affero General Public License
18 // along with this program. If not, see <http://www.gnu.org/licenses/>.
27 #include "cond_node.h"
28 #include "result_node.h"
37 if ( _RootNode
!= NULL
)
41 void CTree::setKey(int key
)
52 bool CTree::getOutput(CRecord
*input
)
54 if ( _RootNode
!= NULL
)
55 return _RootNode
->propagRecord( input
);
60 int CTree::getNbRecords(std::vector
<CRecord
*> &records
,int key
, IValue
*value
) //
63 std::vector
<CRecord
*>::iterator it_r
= records
.begin();
64 while ( it_r
!= records
.end() )
66 if ( *((**it_r
)[key
]) == value
)
73 double CTree::log2(double val
) const
75 return (log(val
) / log(2.0));
78 double CTree::entropy(double a
, double b
) const
93 return ( p1
+ p2
) * -1;
96 double CTree::entropy(std::vector
<double> &p
) const
99 std::vector
<double>::iterator it_p
= p
.begin();
100 while ( it_p
!= p
.end() )
105 result
= result
+ val
* log2( val
);
112 double CTree::gain(std::vector
<CRecord
*> &records
, int attrib
, CField
*field
)
114 int nb_values
= (int)field
->getPossibleValues().size();
115 int nb_records
= (int)records
.size();
117 CValue
<bool> bool_true(true);
119 double nb_key_true
= getNbRecords(records
, _Key
, &bool_true
);
120 double nb_key_false
= nb_records
- nb_key_true
;
122 double entropy_records
= entropy( nb_key_true
/ nb_records
, nb_key_false
/ nb_records
);
124 double gain
= entropy_records
;
127 for ( i
= 0; i
< nb_values
; i
++ )
129 IValue
*val
= field
->getPossibleValues()[i
];
131 int nb_records_val
, nb_records_notval
;
132 splitRecords(records
, attrib
, val
, nb_records_val
, nb_records_notval
);
134 int nb_records_val_key
, nb_records_val_notkey
;
135 splitRecords(records
, attrib
, val
, true, nb_records_val_key
, nb_records_val_notkey
);
137 double entropy_val
= entropy( ((double)nb_records_val_key
) / ((double)nb_records_val
), ((double)nb_records_val_notkey
) / ((double)nb_records_val
) );
139 gain
= gain
- ( ( (double)nb_records_val
) / ( (double) nb_records
) * entropy_val
);
144 std::vector
<std::pair
<double,int> > CTree::getSortedFields( std::vector
<int> &attributes
, std::vector
<CRecord
*> &records
, std::vector
<CField
*> &fields
)
146 std::vector
<std::pair
<double,int> > attribs
;
148 if ( ! records
.empty() )
150 std::vector
<int>::iterator it_a
= attributes
.begin();
151 while ( it_a
!= attributes
.end() )
153 if ( (*it_a
) != _Key
)
154 attribs
.push_back( std::pair
<double,int>( gain(records
, (*it_a
), fields
[*it_a
] ) , (*it_a
) ) );
159 // Sorts the records by gain
160 std::sort(attribs
.begin(), attribs
.end(), greater() );
162 std::vector
<std::pair
<double,int> >::iterator it_f
= attribs
.begin();
164 std::cout
<< "Attributes(gain) :" << std::endl
;
165 while ( it_f
!= attribs
.end() )
167 std::cout
<< " " << fields
[ (*it_f
).second
]->getName() << " (" << (*it_f
).first
<< ") " << std::endl
;
170 std::cout
<< std::endl
;
175 // Looks for the attrib with the most gain
176 int CTree::getBestAttrib( std::vector
<int> &attributes
, std::vector
<CRecord
*> &records
, std::vector
<CField
*> &fields
)
180 int best_attrib
= -1;
182 std::cout
<< "Attributes(gain) :" << std::endl
;
183 if ( ! records
.empty() )
185 std::vector
<int>::iterator it_a
= attributes
.begin();
186 while ( it_a
!= attributes
.end() )
188 if ( (*it_a
) != _Key
)
191 tmp_gain
= gain( records
, *it_a
, fields
[ *it_a
] );
192 std::cout
<< " " << fields
[ *it_a
]->getName() << " (" << tmp_gain
<< ") " << std::endl
;
193 if ( tmp_gain
>= max_gain
)
206 void CTree::rebuild(std::vector
<CRecord
*> &records
, std::vector
<CField
*> &fields
)
208 std::vector
<int> left_fields
;
209 CRecord
*first
= *records
.begin();
211 for (int i
= 0; i
< first
->size(); i
++ )
213 left_fields
.push_back( i
);
215 _RootNode
= ID3( left_fields
, records
, fields
);
218 float CTree::findNumKeyValue(std::vector
<CRecord
*> &records
, int key
)
225 std::vector
<CRecord
*>::iterator it_r
= records
.begin();
226 while ( it_r
!= records
.end() )
228 bool result
= ((CValue
<bool> *)(**it_r
)[ _Key
])->getValue();
229 if ( result
== true )
231 sum_true
+= ((CValue
<int> *)(**it_r
)[ key
])->getValue();
236 sum_false
+= ((CValue
<int> *)(**it_r
)[ key
])->getValue();
242 return ( sum_true
/ nb_true
+ sum_false
/ nb_false
) / 2;
245 std::string
CTree::getDebugString(std::vector
<CRecord
*> &records
, std::vector
<CField
*> &fields
)
248 output
+= "CTree KEY = ";
249 output
+= fields
[ _Key
]->getName();
253 INode
*CTree::ID3(std::vector
<int> &attributes
, std::vector
<CRecord
*> &records
, std::vector
<CField
*> &fields
)
255 if ( records
.empty() )
257 return new CResultNode( false );
261 // If there is no attribute left and the records don't have the same key value,
262 // returns a result node with the most frequent key value
263 if ( attributes
.empty() )
267 splitRecords( records
, _Key
, nb_key_true
, nb_key_false
);
269 if ( nb_key_true
> nb_key_false
)
270 return new CResultNode( true );
272 return new CResultNode( false );
276 // Tests if all records have the same key value, if so returns a result node with this key value.
277 int nb_records
= (int)records
.size();
281 splitRecords( records
, _Key
, nb_key_true
, nb_key_false
);
283 if ( nb_key_true
== nb_records
)
284 return new CResultNode( true );
286 if ( nb_key_false
== nb_records
)
287 return new CResultNode( false );
290 // Gets the attribute with the most gain for the current record set,
291 // and recursively builds the subnodes corresponding to each
292 // possible value for this attribute.
293 int best_gain_attrib
= getBestAttrib( attributes
, records
, fields
);
295 std::vector
< std::vector
<CRecord
*> > sorted_records
;
296 splitRecords( records
, best_gain_attrib
, fields
, sorted_records
); // classifies the records depending on the value of the best gain attribute
298 std::vector
<int> new_attribs
;
299 for ( int i
= 0; i
< (int) attributes
.size(); i
++ ) // Creates a new attributes list from the current attributes list with the best gain attribute removed
300 if ( attributes
[i
] != best_gain_attrib
)
301 new_attribs
.push_back( attributes
[i
] );
303 ICondNode
*root_node
= fields
[best_gain_attrib
]->createNode(_Key
, best_gain_attrib
, records
);
305 std::vector
< std::vector
<CRecord
*> >::iterator it
= sorted_records
.begin(); // Constructs subnodes recursively
306 while ( it
!= sorted_records
.end() )
308 root_node
->addNode( ID3( new_attribs
, *it
, fields
) );
315 std::vector
<CRecord
*> CTree::getRecords(std::vector
<CRecord
*> &records
, int attrib
, bool value
)
317 std::vector
<CRecord
*> result
;
318 std::vector
<CRecord
*>::iterator it_r
= records
.begin();
319 while ( it_r
!= records
.end() )
321 if ( ((CValue
<bool> *)(**it_r
)[attrib
])->getValue() == value
)
322 result
.push_back( *it_r
);
328 void CTree::splitRecords(std::vector
<CRecord
*> &records
, int attrib
, int &true_records
, int &false_records
) //
333 std::vector
<CRecord
*>::iterator it_r
= records
.begin();
334 while ( it_r
!= records
.end() )
336 if ( ((CValue
<bool> *)(**it_r
)[attrib
])->getValue() == true )
344 void CTree::splitRecords(std::vector
<CRecord
*> &records
, int attrib
, IValue
*val
, int &true_records
, int &false_records
) //
349 std::vector
<CRecord
*>::iterator it_r
= records
.begin();
350 while ( it_r
!= records
.end() )
352 const IValue
*left_val
= (**it_r
)[ attrib
];
353 if ( ( *left_val
) == val
)
361 // count records with a certain value for an attrib and true or false for the key attrib
362 void CTree::splitRecords(std::vector
<CRecord
*> &records
, int attrib
, IValue
*val
, bool key
, int &true_records
, int &false_records
) //
367 std::vector
<CRecord
*>::iterator it_r
= records
.begin();
368 while ( it_r
!= records
.end() )
370 if ( (* ( (**it_r
)[attrib
] ) ) == val
)
372 if ( ( (CValue
<bool> *) (**it_r
)[ _Key
] )->getValue() == key
)
382 // Sorts records according to the possible values for an attribute.
383 void CTree::splitRecords( std::vector
<CRecord
*> &records
, int attrib
, std::vector
<CField
*> &fields
, std::vector
< std::vector
<CRecord
*> > &result
) //
385 if ( result
.size() < fields
[attrib
]->getPossibleValues().size() )
387 int nb_missing
= (int)(fields
[attrib
]->getPossibleValues().size() - result
.size());
388 for (int i
= 0; i
<= nb_missing
; i
++ )
390 result
.push_back( std::vector
<CRecord
*>() );
394 std::vector
<CRecord
*>::iterator it_r
= records
.begin();
395 while ( it_r
!= records
.end() )
397 std::vector
<IValue
*>::const_iterator it_vp
= fields
[attrib
]->getPossibleValues().begin();
398 std::vector
< std::vector
<CRecord
*> >::iterator it
= result
.begin();
400 while ( it_vp
!= fields
[attrib
]->getPossibleValues().end() )
402 const IValue
*left_value
= (**it_r
)[attrib
];
403 IValue
*right_value
= *it_vp
;
404 if ( (*left_value
) == right_value
)
405 (*it
).push_back( *it_r
);