softcount: Count as long double instead of double
[vspell.git] / utils / softcount.cpp
blob622daa46a04048993f40a4d89da2c7cad4a195e0
1 #include "softcount.h" // -*- tab-width: 2 -*-
2 #include <deque>
3 #include <boost/format.hpp>
/**
   Soft counter.
   Based on the article "Discovering Chinese Words from Unsegmented Text".
   The idea is to count every possible word in a sentence with a fractional
   count, instead of counting only the words of the best segmentation of the
   sentence (each with count = 1).
 */
12 using namespace std;
14 /**
15 No short description.
17 ostream& SoftCounter::count_lattice(const Lattice &w, ostream &os, bool first_count)
19 vector<long double> Sleft,Sright;
20 vector<vector<uint> > prev;
21 //boost::shared_ptr<WordEntries> p_we = w.we;
22 //const WordEntries &we = *p_we;
23 int i,n = w.get_word_count(),v,vv;
24 long double sum = 0;
25 VocabIndex vi[3];
26 vector<uint> ends;
27 long double fc;
29 Sleft.resize(w.we->size());
30 Sright.resize(w.we->size());
31 prev.resize(w.we->size());
33 vi[1] = 0;
35 // first pass: Sleft
36 for (i = 0;i < n;i ++) {
37 const WordEntryRefs &wers = w.get_we(i);
38 int ii,nn = wers.size();
39 long double add;
40 for (ii = 0;ii < nn;ii ++) {
41 // wers[ii] is the first node (W).
42 v = wers[ii]->id;
44 if (i == 0) {
45 vi[0] = get_id(START_ID);
46 Sleft[v] = first_count ? 1 : -(long double)get_ngram().wordProb((*w.we)[v].node.node->get_id(),vi);
47 //cerr << "Sleft(" << vi[0] << "," << v << ") = " << Sleft[v] << endl;
50 int next = wers[ii]->pos+wers[ii]->len;
51 if (next < n) {
52 const WordEntryRefs &wers2 = w.get_we(next);
53 int iii,nnn = wers2.size();
54 for (iii = 0;iii < nnn;iii ++) {
55 // wers2[iii] is the second node (W').
56 vv = wers2[iii]->id;
57 vi[0] = (*w.we)[v].node.node->get_id();
58 add = first_count ? 1 : Sleft[v]*(-(long double)get_ngram().wordProb((*w.we)[vv].node.node->get_id(),vi));
59 Sleft[vv] += add;
61 //cerr << "Sleft(" << vi[0] << "," << vv << ") = " << Sleft[vv] << endl;
63 // init prev references for Sright phase
64 prev[vv].push_back(v);
66 } else {
67 vi[0] = (*w.we)[v].node.node->get_id();
68 Sright[v] = first_count ? 1 : -(long double)get_ngram().wordProb(get_id(STOP_ID),vi);
69 //cerr << "Sright(" << vi[0] << "," << v << ") = " << Sright[v] << endl;
70 sum += Sleft[v];
71 //cerr << "Sum " << sum << endl;
72 ends.push_back(v);
77 // second pass: Sright
78 Sright[w.we->size()-1] = 1;
79 for (i = n-1;i >= 0;i --) {
80 const WordEntryRefs &wers = w.get_we(i);
81 int ii,nn = wers.size();
82 long double add;
83 for (ii = 0;ii < nn;ii ++) {
84 // wers[ii] is the first node (W).
85 v = wers[ii]->id;
87 int iii,nnn = prev[v].size();
88 for (iii = 0;iii < nnn;iii ++) {
89 // vv is the second node (W').
90 vv = prev[v][iii];
91 vi[0] = (*w.we)[vv].node.node->get_id();
92 add = first_count ? 1 : Sright[v]*(-(long double)get_ngram().wordProb((*w.we)[v].node.node->get_id(),vi));
93 Sright[vv] += add;
95 //cerr << "Sright(" << vi[0] << "," << vv << ") = " << Sright[vv] << endl;
97 // collect fractional counts
98 fc = Sleft[vv]*add/sum; // P(v/vv)
99 vi[0] = (*w.we)[vv].node.node->get_id();
100 vi[1] = (*w.we)[v].node.node->get_id();
101 vi[2] = 0;
102 //*stats.insertCount(vi) += fc;
103 os << boost::format("%s %s %f\n") % get_ngram()[vi[0]] % get_ngram()[vi[1]] << fc;
104 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
105 vi[1] = 0;
110 // collect fc of ends
111 // we can use Sright with no problems becase there is only one edge
112 // from ends[i] to the end.
113 n = ends.size();
114 vi[2] = 0;
115 vi[1] = get_id(STOP_ID);
116 for (i = 0;i < n;i ++) {
117 fc = Sleft[ends[i]]*Sright[ends[i]]/sum;
118 vi[0] = (*w.we)[ends[i]].node.node->get_id();
119 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
120 //*stats.insertCount(vi) += fc;
121 os << boost::format("%s %s %f\n") % get_ngram()[vi[0]] % get_ngram()[vi[1]] << fc;
124 vi[0] = get_id(START_ID);
125 const WordEntryRefs &wers = w.get_we(0);
126 n = wers.size();
127 for (i = 0;i < n;i ++) {
128 fc = Sleft[wers[i]->id]*Sright[wers[i]->id]/sum;
129 vi[1] = (*w.we)[wers[i]->id].node.node->get_id();
130 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
131 //*stats.insertCount(vi) += fc;
132 os << boost::format("%s %s %f\n") % get_ngram()[vi[0]] % get_ngram()[vi[1]] << fc;
134 return os;
138 ostream& SoftCounter::count_dag(const DAG &dag,ostream &os,int id, bool first_count)
140 int n = dag.node_count();
141 vector<long double> Sleft(n),Sright(n);
142 vector<set<uint> > prev(n);
143 int i,v,vv;
144 long double add;
146 //cerr << "Nodes: " << n << endl;
148 vector<bool> mark(n);
149 deque<uint> traces;
150 Sleft[dag.node_begin()] = 1;
151 traces.push_back(dag.node_begin());
152 // first pass: Sleft
153 while (!traces.empty()) {
154 v = traces.front();
155 traces.pop_front();
156 if (mark[v])
157 continue;
158 else
159 mark[v] = true;
160 std::vector<uint> nexts;
161 dag.get_next(v,nexts);
162 n = nexts.size();
164 for (i = 0;i < n;i ++) {
165 vv = nexts[i];
166 add = Sleft[v]*(first_count ? 1 : LogPtoProb(-dag.edge_value(v,vv)));
167 if (add == 0.0 || add == -0.0)
168 cerr << boost::format("WARNING: %d: Sleft addition for %d is zero (Sleft[%d] = %Lg, prob=%g)") % id % vv % v % Sleft[v] % LogPtoProb(-dag.edge_value(v,vv)) << endl;
169 Sleft[vv] += add;
171 traces.push_back(vv);
173 // init prev references for Sright phase
174 prev[vv].insert(v);
178 //cerr << "Sleft done" << endl;
180 long double fc;
181 // second pass: Sright
182 long double sum = Sleft[dag.node_end()];
183 if (sum == 0.0 || sum == -0.0) {
184 cerr << boost::format("WARNING: %d: Sum is zero") % id << endl;
185 // Can do nothing more because sum is zero
186 return os;
188 Sright[dag.node_end()] = 1;
189 traces.clear();
190 traces.push_back(dag.node_end()); // the last v above
191 while (!traces.empty()) {
192 vv = traces.front();
193 traces.pop_front();
194 if (!mark[vv])
195 continue;
196 else
197 mark[vv] = false;
198 //cerr << vv << " ";
199 set<uint>::iterator iter;
200 for (iter = prev[vv].begin();iter != prev[vv].end(); ++iter) {
201 v = *iter;
202 traces.push_back(v);
204 add = Sright[vv]*(first_count ? 1 : LogPtoProb(-dag.edge_value(v,vv)));
205 if (add == 0.0 || add == -0.0)
206 cerr << boost::format("WARNING: %d: Sright addition for %d is zero") % id % v << endl;
207 Sright[v] += add;
209 // collect fractional counts
210 fc = Sleft[v]*add/sum; // P(vv/v)
211 VocabIndex vi[10];
212 VocabIndex vvv;
213 if (dag.fill_vi(v,vv,vvv,vi,9)) {
214 uint t,jn,j;
215 for (jn = 0;vi[jn] != 0 && jn < 9; jn ++);
216 if (jn < 9 && vi[jn] == 0) {
217 for (j = 0;j < jn/2;j ++) {
218 t = vi[j];
219 vi[j] = vi[jn-j-1];
220 vi[jn-j-1] = t;
222 vi[jn] = vvv;
223 vi[jn+1] = 0;
224 //stats.countSentence(vi,/*LogPtoProb(--fc)*/fc);
225 //*stats.insertCount(vi) += fc;
226 // FIXME
227 for (int i_vi = 0; vi[i_vi] != 0; i_vi ++)
228 cout << get_ngram()[vi[i_vi]] << " ";
229 cout << fc;
230 cout << endl;
235 //cerr << "Sright done" << endl;
236 return os;
239 void SoftCounter::record2(const DAG &dag,FILE *fp,int id)
241 int n = dag.node_count();
242 vector<set<uint> > prev(n);
243 int i,v,vv;
245 vector<bool> mark(n);
246 deque<uint> traces;
247 traces.push_back(dag.node_begin());
249 fprintf(fp,"%d %d %d\n",n,dag.node_begin(),dag.node_end());
251 while (!traces.empty()) {
252 v = traces.front();
253 traces.pop_front();
254 if (mark[v])
255 continue;
256 else
257 mark[v] = true;
258 std::vector<uint> nexts;
259 dag.get_next(v,nexts);
260 n = nexts.size();
262 for (i = 0;i < n;i ++) {
263 vv = nexts[i];
265 VocabIndex edge_vi[2],edge_v;
266 dag.fill_vi(v,vv,edge_v,edge_vi,2);
267 fprintf(fp,"L %d %d %s %s\n",v,vv,get_ngram()[edge_vi[0]],get_ngram()[edge_v]);
269 traces.push_back(vv);
271 // init prev references for Sright phase
272 prev[vv].insert(v);
276 traces.clear();
277 traces.push_back(dag.node_end()); // the last v above
278 while (!traces.empty()) {
279 vv = traces.front();
280 traces.pop_front();
281 if (!mark[vv])
282 continue;
283 else
284 mark[vv] = false;
286 set<uint>::iterator iter;
287 for (iter = prev[vv].begin();iter != prev[vv].end(); ++iter) {
288 v = *iter;
289 traces.push_back(v);
291 VocabIndex edge_vi[2],edge_v;
292 dag.fill_vi(v,vv,edge_v,edge_vi,2);
293 fprintf(fp,"R %d %d %s %s\n",v,vv,get_ngram()[edge_vi[0]],get_ngram()[edge_v]);
296 fprintf(fp,"E 0 0 none none\n");
299 int SoftCounter::replay2(FILE *fp_in,FILE *fp_out, int id,bool first_count)
301 int n, node_begin, node_end;
303 if (fscanf(fp_in,"%d %d %d\n",&n,&node_begin,&node_end) != 3) {
304 fprintf(stderr,"Error: %d: Could not read dag count\n",id);
305 return -2;
308 vector<long double> Sleft(n),Sright(n);
309 int i,v,vv;
310 long double add;
311 char type[2]; // should not be longer than one
312 char str1[100],str2[100];
313 int right_mode = 0;
314 long double sum;
316 //cerr << "Nodes: " << n << endl;
318 Sleft[node_begin] = 1;
320 while (fscanf(fp_in,"%s %d %d %s %s\n",type,&v,&vv,str1,str2) == 5) {
321 if (type[0] == 'E')
322 return 0;
323 //fprintf(stderr,"Got %s %d %d %s %s\n",type,v,vv,str1,str2);
324 VocabIndex edge_vi[2],edge_v;
325 edge_v = get_ngram()[str2];
326 edge_vi[0] = get_ngram()[str1];
327 edge_vi[1] = 0;
328 if (type[0] == 'L') {
329 add = Sleft[v]*(first_count ? 1 : (long double)LogPtoProb(get_ngram().wordProb(edge_v,edge_vi)));
330 if (add == 0.0 || add == -0.0)
331 fprintf(stderr,"WARNING: %d: Sleft addition for %d is zero (Sleft[%d] = %Lg, prob=%g %s %s)\n",id,vv,v,Sleft[v],LogPtoProb(get_ngram().wordProb(edge_v,edge_vi)),str1,str2);
332 Sleft[vv] += add;
334 else {
335 if (!right_mode) {
336 right_mode = 1;
337 sum = Sleft[node_end];
338 if (sum == 0.0 || sum == -0.0) {
339 fprintf(stderr,"WARNING: %d: Sum is zero\n",id);
341 Sright[node_end] = 1;
343 if (sum == 0.0 || sum == -0.0)
344 continue;
345 add = Sright[vv]*(first_count ? 1 : (long double)LogPtoProb(get_ngram().wordProb(edge_v,edge_vi)));
346 if (add == 0.0 || add == -0.0)
347 fprintf(stderr,"WARNING: %d: Sright addition for %d is zero (%s %s)\n",id,v,str1,str2);
348 Sright[v] += add;
350 // collect fractional counts
351 long double fc = Sleft[v]*add/sum; // P(vv/v)
352 fprintf(fp_out,"%s %s %Lg\n",str1,str2,fc);
355 return 0;
359 A work-around because NgramFractionalStats is still buggy.
363 ostream& SoftCounter::count_dag_fixed(const DAG &dag,ostream &os,bool first_count)
365 vector<long double> Sleft,Sright;
366 vector<set<uint> > prev;
367 int i,n,v,vv;
368 long double add;
370 n = dag.node_count();
371 Sleft.resize(n);
372 Sright.resize(n);
373 prev.resize(n);
374 //cerr << "Nodes: " << n << endl;
376 // topo sort
377 vector<uint> traces;
378 traces.push_back(dag.node_begin());
379 int itrace = 0;
380 while (itrace < traces.size()) {
381 v = traces[itrace++];
382 std::vector<uint> nexts;
383 dag.get_next(v,nexts);
384 n = nexts.size();
386 for (i = 0;i < n;i ++) {
387 vector<uint>::iterator iter = find(traces.begin(),traces.end(),nexts[i]);
388 if (iter != traces.end()) {
389 traces.erase(iter);
390 if (iter - traces.begin() <= itrace-1)
391 itrace --;
393 traces.push_back(nexts[i]);
397 // first pass: Sleft
398 uint ntrace = traces.size();
399 Sleft[dag.node_begin()] = 1; // log(1) = 0
400 for (itrace = 0;itrace < ntrace;itrace ++) {
401 v = traces[itrace];
402 //cout << " " << v << ":" << Sleft[v];
403 std::vector<uint> nexts;
404 dag.get_next(v,nexts);
405 n = nexts.size();
407 for (i = 0;i < n;i ++) {
408 vv = nexts[i];
409 add = Sleft[v]*dag.edge_value(v,vv);
410 //cout << "-" << dag.edge_value(v,vv) << "-" << Sleft[vv] << ">" << vv;
411 Sleft[vv] += add;
413 // init prev references for Sright phase
414 prev[vv].insert(v);
418 //cerr << "Sleft done" << endl;
420 unsigned int fc;
421 // second pass: Sright
422 long double sum = Sleft[dag.node_end()];
423 if (sum == 0)
424 return os;
425 //cout << "Sum " << sum << endl;
427 traces.clear();
428 traces.push_back(dag.node_end());
429 itrace = 0;
430 while (itrace < traces.size()) {
431 vv = traces[itrace++];
433 set<uint>::iterator iter;
434 for (iter = prev[vv].begin();iter != prev[vv].end(); ++iter) {
435 vector<uint>::iterator iiter = find(traces.begin(),traces.end(),*iter);
436 if (iiter != traces.end()) {
437 traces.erase(iiter);
438 if (iiter - traces.begin() <= itrace-1)
439 itrace --;
441 traces.push_back(*iter);
445 ntrace = traces.size();
446 Sright[dag.node_end()] = 1; // log(1) = 0
447 for (itrace = 0;itrace < ntrace;itrace ++) {
448 vv = traces[itrace];
449 //cout << " " << vv << ":" << Sright[vv];
450 set<uint>::iterator iter;
451 for (iter = prev[vv].begin();iter != prev[vv].end(); ++iter) {
452 v = *iter;
453 add = Sright[vv]*dag.edge_value(v,vv);
454 //cout << "-" << dag.edge_value(v,vv)<< "-" << Sright[v] << ">" << v;
455 Sright[v] += add;
457 // collect fractional counts
458 fc = 100-(unsigned int)((Sleft[v]*add)*100.0/sum); // P(vv/v)
459 //cout << Sleft[v] << "+" << dag.edge_value(v,vv) << Sright[vv]<< "=("<< (Sleft[v]+add)<< ")"<<((Sleft[v]+add)/sum) <<"_" << fc << endl;
460 cerr << v << " " << vv << " " << fc << endl;
461 if (fc != 0) {
462 VocabIndex vi[10];
463 VocabIndex vvv;
464 if (dag.fill_vi(v,vv,vvv,vi,9)) {
465 uint t,jn,j;
466 for (jn = 0;vi[jn] != 0 && jn < 9; jn ++);
467 if (jn < 9 && vi[jn] == 0) {
468 for (j = 0;j < jn/2;j ++) {
469 t = vi[j];
470 vi[j] = vi[jn-j-1];
471 vi[jn-j-1] = t;
473 vi[jn] = vvv;
474 vi[jn+1] = 0;
475 //stats.countSentence(vi,(unsigned int)(fc*10.0));
476 //*stats.insertCount(vi) += fc;
477 // FIXME
483 cerr << endl;
484 //long double sum2 = Sright[dag.node_begin()];
485 //cout << sum2 << " " << (traces[ntrace-1] == dag.node_begin()) << " " << (sum2 == sum ? "Ok" : "Failed") << endl;
486 //cerr << "Sright done" << endl;
487 return os;