// Updated source code from upstream SVN
// [svmtool++.git] / src / bin / SVMTeval.cc
// blob dc02601676f118a30504d4a4ccb574685febd26b
/*
 * Copyright (C) 2004 Jesus Gimenez, Lluis Marquez and Senen Moya
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include "list.h"
22 #include "hash.h"
23 #include "common.h"
24 #include "dict.h"
26 int verbose = TRUE;
28 /******************************************************************/
/*
 * Per-category tagging counters. For the words of one class (key) this
 * records the number of correct tags (hits), the number of attempts
 * (trials), and how many of those words the most-frequent-tag baseline
 * would have tagged correctly (mft).
 *
 * The constructor zeroes the counters so a freshly created entry never
 * holds indeterminate values.
 */
struct stat_t{
    std::string key;
    int hits;
    int trials;
    int mft;

    stat_t() : hits(0), trials(0), mft(0) {}
};
43 /******************************************************************/
/*
 * Writes the SVMTeval command-line usage to stderr. `progname` is
 * substituted into the "Usage" and "Example" lines.
 */
void printHelp(char *progname)
{
    // Fixed part of the help text (no format specifiers, emitted verbatim).
    static const char argumentHelp[] =
        "\t- mode:\t0 - complete report (everything)\n"
        "\t\t1 - overall accuracy only [default]\n"
        "\t\t2 - accuracy of known vs unknown words\n"
        "\t\t3 - accuracy per level of ambiguity\n"
        "\t\t4 - accuracy per kind of ambiguity\n"
        "\t\t5 - accuracy per part-of-speech\n"
        "\t- model: model location (path + name)\n"
        "\t- gold: correct tagging file\n"
        "\t- pred: predicted tagging file\n\n";

    fputs("\nSVMTool++ v 1.1.6 -- SVMTeval\n\n", stderr);
    fprintf(stderr, "Usage : %s [mode] <model> <gold> <pred>\n\n", progname);
    fputs(argumentHelp, stderr);
    fprintf(stderr, "Example : %s WSJTP WSJTP.IN WSJTP.OUT\n\n", progname);
}
61 /************************************************************************/
63 class eval
65 private:
66 char sModel[500];
67 char sGold[500];
68 char sPred[500];
69 FILE *gold;
70 FILE *pred;
71 dictionary *d;
72 hash_t<stat_t *> stat_Amb_Level;
73 hash_t<stat_t *> stat_Class_Amb;
74 hash_t<stat_t *> stat_POS;
75 int report_type;
76 int numAmbLevel;
77 int numAmbClass;
79 void printHashStats(hash_t<stat_t *> *tptr, int put_eol, const char *column_name);
80 void printKnownVsUnknown(int knownAmb, int knownUnamb, int unknown,int unkHits, int knownHitsAmb,int knownHitsUnamb);
81 void printTaggingSumary(int known,int unknown,int ambiguous,int well,int wellMFT);
82 void printOverallAccuracy(int total, int well, int wellMFT, float pAmb);
83 void printStatsByLevel(hash_t<stat_t *> *h);
84 void printStatsByAmbiguityClass(hash_t<stat_t *> *h);
85 void printStatsByPOS(hash_t<stat_t *> *h);
86 void addStatsToHash(hash_t< stat_t* >* h, const std::string& key, int is_hit, int is_mft);
87 void makeReport(dictionary *d,FILE *gold, FILE *pred);
89 public:
90 eval(char *model, char *goldName, char *predName);
91 void evalPutReportType(int report);
92 void evalRun();
95 /******************************************************************/
97 void eval::printHashStats(hash_t<stat_t *> *tptr, int put_eol, const char *column_name)
99 fprintf(stderr,"%s\tHITS\t\tTRIALS\t\tACCURACY\t\tMFT-ACCURACY\n",column_name);
100 fprintf(stderr,"* ------------------------------------------------------------------------- \n");
101 // hash_node_t *node, *last;
102 // int i;
104 char c='\0';
105 if (put_eol==TRUE) c='\n';
107 for (hash_t<stat_t *>::iterator it = tptr->begin(); it != tptr->end(); it++)
109 stat_t *s = (stat_t *)((*it).second);
110 fprintf(stderr,"%s%c\t%d\t/\t%d\t=\t%.4f %%\t\t%.4f %%\n",s->key.c_str(),c,s->hits,s->trials, 100*((float)s->hits/s->trials), 100*((float)s->mft/s->trials) );
114 /******************************************************************/
116 void eval::printKnownVsUnknown(int knownAmb, int knownUnamb, int unknown,int unkHits, int knownHitsAmb,int knownHitsUnamb)
118 int knownHits = knownHitsAmb + knownHitsUnamb;
119 int known = knownAmb + knownUnamb;
121 fprintf(stderr,"* ================= KNOWN vs UNKNOWN WORDS ================================\n");
122 fprintf(stderr,"\tHITS\t\tTRIALS\t\tACCURACY\n");
123 fprintf(stderr,"* -------------------------------------------------------------------------\n");
124 fprintf(stderr,"* ======= known ===========================================================\n");
125 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %%\n",knownHits,known,100*((float)knownHits/known));
126 fprintf(stderr,"-------- known unambiguous words ------------------------------------------\n");
127 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %%\n",knownHitsUnamb,knownUnamb,100*((float)knownHitsUnamb/knownUnamb));
128 fprintf(stderr,"-------- known ambiguous words --------------------------------------------\n");
129 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %%\n",knownHitsAmb,knownAmb,100*((float)knownHitsAmb/knownAmb));
130 fprintf(stderr,"* ======= unknown =========================================================\n");
131 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f %%\n",unkHits,unknown,100*((float)unkHits/unknown));
132 fprintf(stderr,"* =========================================================================\n");
135 /******************************************************************/
137 void eval::printTaggingSumary(int known,int unknown,int ambiguous,int /*well*/,int wellMFT)
139 fprintf(stderr,"* ================= TAGGING SUMMARY =======================================\n");
140 fprintf(stderr,"#WORDS\t\t = %d\n",known+unknown);
141 fprintf(stderr,"#KNOWN\t\t = %d\t/\t%d\t--> (%.4f %%)\n",known,known+unknown,100*((float)known/(known+unknown)));
142 fprintf(stderr,"#UNKNOWN\t = %d\t/\t%d\t--> (%.4f %%)\n",unknown,known+unknown,100*((float)unknown/(known+unknown)));
143 fprintf(stderr,"#AMBIGUOUS\t = %d\t/\t%d\t--> (%.4f %%)\n",ambiguous,known+unknown,100*((float)ambiguous/(known+unknown)));
144 fprintf(stderr,"#MFT baseline\t = %d\t/\t%d\t--> (%.4f %%)\n",wellMFT,known+unknown,100*((float)wellMFT/(known+unknown)));
147 /******************************************************************/
149 void eval::printOverallAccuracy(int total, int well, int wellMFT, float pAmb)
151 fprintf(stderr,"* ================= OVERALL ACCURACY ======================================\n");
152 fprintf(stderr,"\tHITS\t\tTRIALS\t\tACCURACY\tMFT-baseline\n");
153 fprintf(stderr,"* -------------------------------------------------------------------------\n");
154 fprintf(stderr,"\t%d\t/\t%d\t=\t%.4f\t\t%.4f%%\n",well,total,100*((float)well/total),100*((float)wellMFT/total));
155 fprintf(stderr,"* =========================================================================\n");
156 fprintf(stderr,"\tAmbiguity Average for Known words = %5f POS/word\n",pAmb);
157 fprintf(stderr,"* =========================================================================\n");
160 /******************************************************************/
162 void eval::printStatsByLevel(hash_t<stat_t *> *h)
164 fprintf(stderr,"* ================= ACCURACY PER LEVEL OF AMBIGUITY =======================\n");
165 fprintf(stderr,"#CLASSES = %d\n",numAmbLevel);
166 fprintf(stderr,"* =========================================================================\n");
167 printHashStats(h,FALSE,"LEVEL");
170 /******************************************************************/
172 void eval::printStatsByAmbiguityClass(hash_t<stat_t *> *h)
174 fprintf(stderr,"* ================= ACCURACY PER CLASS OF AMBIGUITY =======================\n");
175 fprintf(stderr,"#CLASSES = %d\n",numAmbClass);
176 fprintf(stderr,"* =========================================================================\n");
177 printHashStats(h,TRUE,"CLASS");
180 /******************************************************************/
182 void eval::printStatsByPOS(hash_t<stat_t *> *h)
184 fprintf(stderr,"* =================== ACCURACY PER PART-OF_SPEECH =========================\n");
185 printHashStats(h,FALSE,"POS");
188 /******************************************************************/
190 void eval::addStatsToHash(hash_t<stat_t *>* h, const std::string& key, int is_hit, int is_mft)
192 stat_t * s = h->hash_lookup(key);
193 if ((long)s!=HASH_FAIL)
195 s->key = key;
196 if (is_hit==TRUE) s->hits++;
197 if (is_mft==TRUE) s->mft++;
198 s->trials++;
200 else
202 if (report_type==3 || report_type==0) numAmbLevel++;
203 if (report_type==4 || report_type==0) numAmbClass++;
204 s = new stat_t;
205 s->key = key;
206 s->trials=1;
207 s->hits=0;
208 s->mft=0;
209 if (is_hit==TRUE) s->hits++;
210 if (is_mft==TRUE) s->mft++;
211 h->hash_insert(s->key,s);
215 /******************************************************************/
217 void eval::makeReport(dictionary *d,FILE *gold, FILE *pred)
219 std::string mft;
220 char wrd1[150],wrd2[150],pos1[5],pos2[5];
221 int totalWords=0, well = 0,known=0,unknown=0,wellMFT=0;
222 int ambiguous=0,unambiguous=0;
223 int unkHits=0,knownHitsAmb=0,knownHitsUnamb=0;
224 int ret1=0,ret2=0;
225 int is_mft=FALSE,is_hit=FALSE;
226 int contAmbiguities=0;
228 while (!feof(gold) && !feof(pred))
230 is_mft=FALSE;
231 is_hit=FALSE;
233 char gold_line[250] = "\n";
234 char pred_line[250] = "\n";
236 while ( !feof(gold) && ( strcmp(gold_line,"\n") == 0 || ( gold_line[0]=='#' && gold_line[1]=='#') ) )
237 fgets(gold_line,250,gold);
238 while ( !feof(pred) && ( strcmp(pred_line,"\n") == 0 || ( pred_line[0]=='#' && pred_line[1]=='#') ) )
239 fgets(pred_line,250,pred);
241 ret1 = sscanf (gold_line,"%s %s",wrd1,pos1);
242 ret2 = sscanf (pred_line,"%s %s",wrd2,pos2);
244 if ( ret1 >= 0 && ret2 >= 0 )
246 dataDict* w = d->getElement(wrd1);
247 int numMaybe;
248 if ((long)w!=HASH_FAIL) //Si es conocida
250 known++;
251 numMaybe = d->getElementNumMaybe(w);
253 if (numMaybe>1)
254 ambiguous++; //Si es ambigua
255 else unambiguous++; //Si no es ambigua
257 contAmbiguities += numMaybe;
259 mft = d->getMFT(w)->pos;
260 if (mft == pos1)
261 { is_mft = TRUE;
262 wellMFT++;
265 else unknown++;
267 if (strcmp(wrd1,wrd2)==0 && strcmp(pos1,pos2)==0)
269 well++;
270 is_hit=TRUE; //Es acierto
271 if ((long)w==HASH_FAIL) unkHits++; //Acierto para desconocidas
272 else if (numMaybe>1)
273 knownHitsAmb++; //Acierto para conocidas ambiguas
274 else knownHitsUnamb++; //Acierto para conocidas no ambiguas
277 if (report_type==3 || report_type==0)
279 //Acumulamos por nivel de ambigüedad
280 char level[4];
281 if ((long)w!=HASH_FAIL) sprintf(level,"%d",numMaybe);
282 else sprintf(level,"UNKOWN");
283 addStatsToHash(&stat_Amb_Level,level,is_hit,is_mft);
285 if (report_type==4 || report_type==0)
287 //Acumulamos por clase de ambigüedad
288 std::string ambClass = d->getAmbiguityClass(w);
289 addStatsToHash(&stat_Class_Amb,ambClass,is_hit,is_mft);
291 if (report_type==5 || report_type==0)
293 //Acumulamos por etiqueta
294 addStatsToHash(&stat_POS,pos2,is_hit,is_mft);
297 showProcessDone(totalWords,2000,FALSE,"words");
299 totalWords++;
304 showProcessDone(totalWords,2000,TRUE,"words");
305 printTaggingSumary(known,unknown,ambiguous,well,wellMFT);
306 if (report_type==2 || report_type==0) printKnownVsUnknown(ambiguous, unambiguous, unknown,unkHits,knownHitsAmb,knownHitsUnamb);
307 if (report_type==3 || report_type==0) printStatsByLevel(&stat_Amb_Level);
308 if (report_type==4 || report_type==0) printStatsByAmbiguityClass(&stat_Class_Amb);
309 if (report_type==5 || report_type==0) printStatsByPOS(&stat_POS);
311 float porcentageAmbiguedad = (float)contAmbiguities/ (float) known;
312 printOverallAccuracy(unknown+known,well,wellMFT,porcentageAmbiguedad);
316 /************************************************************************/
318 eval::eval(char *model, char *goldName, char *predName)
320 report_type = 1;
321 numAmbLevel = 0;
322 numAmbClass = 0;
324 stat_Amb_Level.hash_init(10);
325 stat_Class_Amb.hash_init(100);
326 stat_POS.hash_init(100);
328 strcpy(sModel, model);
329 strcpy(sGold, goldName);
330 strcpy(sPred, predName);
334 void eval::evalPutReportType(int report)
336 report_type = report;
339 void eval::evalRun()
341 fprintf(stderr,"* ========================= SVMTeval report ==============================\n");
342 fprintf(stderr,"* model = [%s]\n* testset = [%s]\n* predicted = [%s]\n",sModel,sGold,sPred);
343 fprintf(stderr,"* ========================================================================\n");
345 char name[200];
346 sprintf(name,"%s.DICT",sModel);
347 gold = openFile(sGold,"rt");
348 pred = openFile(sPred,"rt");
349 d = new dictionary(name);
351 makeReport(d,gold,pred);
353 delete d;
355 /************************************************************************/
357 int main(int argc, char *argv[])
359 int i = 0;
360 int report_type;
362 if (argc<4)
364 fprintf(stderr,"Waiting 3 or more parameters\n");
365 printHelp(argv[0]);
366 exit(0);
368 if (argc>4)
370 report_type = atoi(argv[1]);
371 i=1;
372 if (report_type>5 || report_type<0)
374 printHelp(argv[0]);
375 exit(0);
379 eval *e = new eval(argv[1+i], argv[2+i], argv[3+i]);
380 if (i == 1) { e->evalPutReportType(report_type); }
381 e->evalRun();