More code cleanup
[svmtool++.git] / src / SVMTeval.cc
blobfc915a1adb9cd04c4bb953341ff2091407c037c1
1 /*
2 * Copyright (C) 2004 Jesus Gimenez, Lluis Marquez and Senen Moya
4 * This library is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU Lesser General Public
6 * License as published by the Free Software Foundation; either
7 * version 2.1 of the License, or (at your option) any later version.
9 * This library is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 * Lesser General Public License for more details.
14 * You should have received a copy of the GNU Lesser General Public
15 * License along with this library; if not, write to the Free Software
16 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include "list.h"
22 #include "hash.h"
23 #include "common.h"
24 #include "dict.h"
26 extern int verbose_svmtool;
28 /******************************************************************/
/*
 * Accuracy counters for the words belonging to one class of words,
 * identified by `key`:
 *   hits   - number of correctly tagged words,
 *   trials - number of tagged words (attempts),
 *   mft    - number of words that would have been tagged correctly by
 *            always choosing the most frequent tag (MFT).
 */
struct stat_t
{
  char key[100];
  int  hits;
  int  trials;
  int  mft;
};
44 /******************************************************************/
/*
 * Writes the command-line usage of SVMTeval to stderr.
 *
 * The usage line names the program "SVMTeval", matching the banner and
 * the example line below it.
 */
void printHelp()
{
  fputs("\nSVMTool++ v 1.1.2 -- SVMTeval\n\n"
        "Usage : SVMTeval [mode] <model> <gold> <pred>\n\n"
        "\t- mode:\t0 - complete report (everything)\n"
        "\t\t1 - overall accuracy only [default]\n"
        "\t\t2 - accuracy of known vs unknown words\n"
        "\t\t3 - accuracy per level of ambiguity\n"
        "\t\t4 - accuracy per kind of ambiguity\n"
        "\t\t5 - accuracy per part-of-speech\n"
        "\t- model: model location (path + name)\n"
        "\t- gold: correct tagging file\n"
        "\t- pred: predicted tagging file\n\n"
        "Example : SVMTeval WSJTP WSJTP.IN WSJTP.OUT\n\n",
        stderr);
}
63 /************************************************************************/
65 class eval
67 private:
68 char sModel[500];
69 char sGold[500];
70 char sPred[500];
71 FILE *gold;
72 FILE *pred;
73 dictionary *d;
74 hash_t stat_Amb_Level;
75 hash_t stat_Class_Amb;
76 hash_t stat_POS;
77 int report_type;
78 int numAmbLevel;
79 int numAmbClass;
81 void printHashStats(hash_t *tptr, int put_eol, const char *column_name);
82 void printKnownVsUnknown(int knownAmb, int knownUnamb, int unknown,int unkHits, int knownHitsAmb,int knownHitsUnamb);
83 void printTaggingSumary(int known,int unknown,int ambiguous,int well,int wellMFT);
84 void printOverallAccuracy(int total, int well, int wellMFT, float pAmb);
85 void printStatsByLevel(hash_t *h);
86 void printStatsByAmbiguityClass(hash_t *h);
87 void printStatsByPOS(hash_t *h);
88 void addStatsToHash(hash_t *h,char *key,int is_hit, int is_mft);
89 void makeReport(dictionary *d,FILE *gold, FILE *pred);
91 public:
92 eval(char *model, char *goldName, char *predName);
93 void evalPutReportType(int report);
94 void evalRun();
97 /******************************************************************/
99 void eval::printHashStats(hash_t *tptr, int put_eol, const char *column_name)
101 fprintf(stderr,"%s\tHITS\t\tTRIALS\t\tACCURACY\t\tMFT-ACCURACY\n",column_name);
102 fprintf(stderr,"* ------------------------------------------------------------------------- \n");
103 hash_node_t *node, *last;
104 int i;
106 char c='\0';
107 if (put_eol==TRUE) c='\n';
109 for (i=0; i<tptr->size; i++)
111 node = tptr->bucket[i];
112 while (node != NULL)
114 last = node;
115 node = node->next;
116 stat_t *s = (stat_t *)last->data;
117 fprintf(stderr,"%s%c\t%d\t/\t%d\t=\t%.4f %\t\t%.4f %\n",s->key,c,s->hits,s->trials, 100*((float)s->hits/s->trials), 100*((float)s->mft/s->trials) );
124 /******************************************************************/
126 void eval::printKnownVsUnknown(int knownAmb, int knownUnamb, int unknown,int unkHits, int knownHitsAmb,int knownHitsUnamb)
128 int knownHits = knownHitsAmb + knownHitsUnamb;
129 int known = knownAmb + knownUnamb;
131 fprintf(stderr,"* ================= KNOWN vs UNKNOWN WORDS ================================\n");
132 fprintf(stderr,"\tHITS\t\tTRIALS\t\tACCURACY\n");
133 fprintf(stderr,"* -------------------------------------------------------------------------\n");
134 fprintf(stderr,"* ======= known ===========================================================\n");
135 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %\n",knownHits,known,100*((float)knownHits/known));
136 fprintf(stderr,"-------- known unambiguous words ------------------------------------------\n");
137 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %\n",knownHitsUnamb,knownUnamb,100*((float)knownHitsUnamb/knownUnamb));
138 fprintf(stderr,"-------- known ambiguous words --------------------------------------------\n");
139 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %\n",knownHitsAmb,knownAmb,100*((float)knownHitsAmb/knownAmb));
140 fprintf(stderr,"* ======= unknown =========================================================\n");
141 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %\n",unkHits,unknown,100*((float)unkHits/unknown));
142 fprintf(stderr,"* =========================================================================\n");
146 /******************************************************************/
148 void eval::printTaggingSumary(int known,int unknown,int ambiguous,int well,int wellMFT)
150 fprintf(stderr,"* ================= TAGGING SUMMARY =======================================\n");
151 fprintf(stderr,"#WORDS\t\t = %d\n",known+unknown);
152 fprintf(stderr,"#KNOWN\t\t = %d\t/\t%d\t--> (%.4f %)\n",known,known+unknown,100*((float)known/(known+unknown)));
153 fprintf(stderr,"#UNKNOWN\t = %d\t/\t%d\t--> (%.4f %)\n",unknown,known+unknown,100*((float)unknown/(known+unknown)));
154 fprintf(stderr,"#AMBIGUOUS\t = %d\t/\t%d\t--> (%.4f %)\n",ambiguous,known+unknown,100*((float)ambiguous/(known+unknown)));
155 fprintf(stderr,"#MFT baseline\t = %d\t/\t%d\t--> (%.4f %)\n",wellMFT,known+unknown,100*((float)wellMFT/(known+unknown)));
159 /******************************************************************/
161 void eval::printOverallAccuracy(int total, int well, int wellMFT, float pAmb)
163 fprintf(stderr,"* ================= OVERALL ACCURACY ======================================\n");
164 fprintf(stderr,"\tHITS\t\tTRIALS\t\tACCURACY\tMFT-baseline\n");
165 fprintf(stderr,"* -------------------------------------------------------------------------\n");
166 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f\t\t%.4f%\n",well,total,100*((float)well/total),100*((float)wellMFT/total));
167 fprintf(stderr,"* =========================================================================\n");
168 fprintf(stderr,"\tAmbiguity Average for Known words = %.5f POS/word\n",pAmb);
169 fprintf(stderr,"* =========================================================================\n");
173 /******************************************************************/
175 void eval::printStatsByLevel(hash_t *h)
177 fprintf(stderr,"* ================= ACCURACY PER LEVEL OF AMBIGUITY =======================\n");
178 fprintf(stderr,"#CLASSES = %d\n",numAmbLevel);
179 fprintf(stderr,"* =========================================================================\n");
180 printHashStats(h,FALSE,"LEVEL");
184 /******************************************************************/
186 void eval::printStatsByAmbiguityClass(hash_t *h)
188 fprintf(stderr,"* ================= ACCURACY PER CLASS OF AMBIGUITY =======================\n");
189 fprintf(stderr,"#CLASSES = %d\n",numAmbClass);
190 fprintf(stderr,"* =========================================================================\n");
191 printHashStats(h,TRUE,"CLASS");
195 /******************************************************************/
197 void eval::printStatsByPOS(hash_t *h)
199 fprintf(stderr,"* =================== ACCURACY PER PART-OF_SPEECH =========================\n");
200 printHashStats(h,FALSE,"POS");
204 /******************************************************************/
206 void eval::addStatsToHash(hash_t *h,char *key,int is_hit, int is_mft)
208 uintptr_t p = hash_lookup(h,key);
209 if (p!=HASH_FAIL)
211 stat_t *s = (stat_t *)p;
212 strcpy(s->key,key);
213 if (is_hit==TRUE) s->hits++;
214 if (is_mft==TRUE) s->mft++;
215 s->trials++;
217 else
219 if (report_type==3 || report_type==0) numAmbLevel++;
220 if (report_type==4 || report_type==0) numAmbClass++;
221 stat_t *s = new stat_t;
222 strcpy(s->key,key);
223 s->trials=1;
224 s->hits=0;
225 s->mft=0;
226 if (is_hit==TRUE) s->hits++;
227 if (is_mft==TRUE) s->mft++;
228 hash_insert(h,s->key,(uintptr_t) s);
233 /******************************************************************/
235 void eval::makeReport(dictionary *d,FILE *gold, FILE *pred)
237 char *mft,wrd1[150],wrd2[150],pos1[5],pos2[5];
238 int totalWords=0, well = 0,known=0,unknown=0,wellMFT=0;
239 int ambiguous=0,unambiguous=0;
240 int unkHits=0,knownHitsAmb=0,knownHitsUnamb=0;
241 int ret1=0,ret2=0;
242 int is_mft=FALSE,is_hit=FALSE;
243 int contAmbiguities=0;
245 while (!feof(gold) && !feof(pred))
247 is_mft=FALSE;
248 is_hit=FALSE;
250 char gold_line[250] = "\n";
251 char pred_line[250] = "\n";
253 while ( !feof(gold) && ( strcmp(gold_line,"\n") == 0 || ( gold_line[0]=='#' && gold_line[1]=='#') ) )
254 fgets(gold_line,250,gold);
255 while ( !feof(pred) && ( strcmp(pred_line,"\n") == 0 || ( pred_line[0]=='#' && pred_line[1]=='#') ) )
256 fgets(pred_line,250,pred);
258 ret1 = sscanf (gold_line,"%s %s",wrd1,pos1);
259 ret2 = sscanf (pred_line,"%s %s",wrd2,pos2);
261 if ( ret1 >= 0 && ret2 >= 0 )
263 int w = d->getElement(wrd1);
264 int numMaybe;
265 if (w!=HASH_FAIL) //Si es conocida
267 known++;
268 numMaybe = d->getElementNumMaybe(w);
270 if (numMaybe>1)
271 ambiguous++; //Si es ambigua
272 //Si no es ambigua
273 else unambiguous++;
275 contAmbiguities += numMaybe;
277 mft = d->getMFT(w);
278 if (strcmp(mft,pos1)==0)
280 is_mft = TRUE;
281 wellMFT++;
283 delete mft;
285 else unknown++;
287 if (strcmp(wrd1,wrd2)==0 && strcmp(pos1,pos2)==0)
289 well++;
290 is_hit=TRUE; //Es acierto
291 //Acierto para desconocidas
292 if (w==HASH_FAIL) unkHits++;
293 else if (numMaybe>1)
294 //Acierto para conocidas ambiguas
295 knownHitsAmb++;
296 //Acierto para conocidas no ambiguas
297 else knownHitsUnamb++;
300 if (report_type==3 || report_type==0)
302 //Acumulamos por nivel de ambigedad
303 char level[4];
304 if (w!=HASH_FAIL) sprintf(level,"%d",numMaybe);
305 else sprintf(level,"UNKOWN");
306 addStatsToHash(&stat_Amb_Level,level,is_hit,is_mft);
308 if (report_type==4 || report_type==0)
310 //Acumulamos por clase de ambigedad
311 char *ambClass = d->getAmbiguityClass(w);
312 addStatsToHash(&stat_Class_Amb,ambClass,is_hit,is_mft);
313 delete ambClass;
315 if (report_type==5 || report_type==0)
317 //Acumulamos por etiqueta
318 addStatsToHash(&stat_POS,pos2,is_hit,is_mft);
321 showProcessDone(totalWords,2000,FALSE,"words");
323 totalWords++;
328 showProcessDone(totalWords,2000,TRUE,"words");
329 printTaggingSumary(known,unknown,ambiguous,well,wellMFT);
330 if (report_type==2 || report_type==0) printKnownVsUnknown(ambiguous, unambiguous, unknown,unkHits,knownHitsAmb,knownHitsUnamb);
331 if (report_type==3 || report_type==0) printStatsByLevel(&stat_Amb_Level);
332 if (report_type==4 || report_type==0) printStatsByAmbiguityClass(&stat_Class_Amb);
333 if (report_type==5 || report_type==0) printStatsByPOS(&stat_POS);
335 float porcentageAmbiguedad = (float)contAmbiguities/ (float) known;
336 printOverallAccuracy(unknown+known,well,wellMFT,porcentageAmbiguedad);
340 /************************************************************************/
342 eval::eval(char *model, char *goldName, char *predName)
344 report_type = 1;
345 numAmbLevel = 0;
346 numAmbClass = 0;
348 hash_init(&stat_Amb_Level,10);
349 hash_init(&stat_Class_Amb,100);
350 hash_init(&stat_POS,100);
352 strcpy(sModel, model);
353 strcpy(sGold, goldName);
354 strcpy(sPred, predName);
358 void eval::evalPutReportType(int report)
360 report_type = report;
364 void eval::evalRun()
366 fprintf(stderr,"* ========================= SVMTeval report ==============================\n");
367 fprintf(stderr,"* model = [%s]\n* testset = [%s]\n* predicted = [%s]\n",sModel,sGold,sPred);
368 fprintf(stderr,"* ========================================================================\n");
370 char name[200];
371 sprintf(name,"%s.DICT",sModel);
372 gold = openFile(sGold,"rt");
373 pred = openFile(sPred,"rt");
374 d = new dictionary(name);
376 makeReport(d,gold,pred);
378 delete d;
382 /************************************************************************/
384 int main(int argc, char *argv[])
386 int i = 0;
387 int report_type;
389 verbose_svmtool = TRUE;
391 if (argc<4)
393 fprintf(stderr,"Waiting 3 or more parameters\n");
394 printHelp();
395 exit(0);
397 if (argc>4)
399 report_type = atoi(argv[1]);
400 i=1;
401 if (report_type>5 || report_type<0)
403 printHelp();
404 exit(0);
408 eval *e = new eval(argv[1+i], argv[2+i], argv[3+i]);
409 if (i == 1) { e->evalPutReportType(report_type); }
410 e->evalRun();