Prevent crashes when get node_begin from an empty DAG
[vspell.git] / utils / softcount.cpp
blob2938ec91ba2570d4bc750cae6e511d2a36ae5890
1 #include "softcount.h" // -*- tab-width: 2 -*-
2 #include <deque>
4 using namespace std;
6 /**
7 No short description.
8 */
9 void SoftCounter::count(const Lattice &w,NgramFractionalStats &stats)
11 vector<float> Sleft,Sright;
12 vector<vector<uint> > prev;
13 //boost::shared_ptr<WordEntries> p_we = w.we;
14 //const WordEntries &we = *p_we;
15 int i,n = w.get_word_count(),v,vv;
16 float sum = 0;
17 VocabIndex vi[3];
18 vector<uint> ends;
19 float fc;
21 Sleft.resize(w.we->size());
22 Sright.resize(w.we->size());
23 prev.resize(w.we->size());
25 vi[1] = Vocab_None;
27 // first pass: Sleft
28 for (i = 0;i < n;i ++) {
29 const WordEntryRefs &wers = w.get_we(i);
30 int ii,nn = wers.size();
31 float add;
32 for (ii = 0;ii < nn;ii ++) {
33 // wers[ii] is the first node (W).
34 v = wers[ii]->id;
36 if (i == 0) {
37 vi[0] = get_id(START_ID);
38 Sleft[v] = -get_ngram().wordProb((*w.we)[v].node.node->get_id(),vi);
39 //cerr << "Sleft(" << vi[0] << "," << v << ") = " << Sleft[v] << endl;
42 int next = wers[ii]->pos+wers[ii]->len;
43 if (next < n) {
44 const WordEntryRefs &wers2 = w.get_we(next);
45 int iii,nnn = wers2.size();
46 for (iii = 0;iii < nnn;iii ++) {
47 // wers2[iii] is the second node (W').
48 vv = wers2[iii]->id;
49 vi[0] = (*w.we)[v].node.node->get_id();
50 add = Sleft[v]*(-get_ngram().wordProb((*w.we)[vv].node.node->get_id(),vi));
51 Sleft[vv] += add;
53 //cerr << "Sleft(" << vi[0] << "," << vv << ") = " << Sleft[vv] << endl;
55 // init prev references for Sright phase
56 prev[vv].push_back(v);
58 } else {
59 vi[0] = (*w.we)[v].node.node->get_id();
60 Sright[v] = -get_ngram().wordProb(get_id(STOP_ID),vi);
61 //cerr << "Sright(" << vi[0] << "," << v << ") = " << Sright[v] << endl;
62 sum += Sleft[v];
63 //cerr << "Sum " << sum << endl;
64 ends.push_back(v);
69 // second pass: Sright
70 Sright[w.we->size()-1] = 1;
71 for (i = n-1;i >= 0;i --) {
72 const WordEntryRefs &wers = w.get_we(i);
73 int ii,nn = wers.size();
74 float add;
75 for (ii = 0;ii < nn;ii ++) {
76 // wers[ii] is the first node (W).
77 v = wers[ii]->id;
79 int iii,nnn = prev[v].size();
80 for (iii = 0;iii < nnn;iii ++) {
81 // vv is the second node (W').
82 vv = prev[v][iii];
83 vi[0] = (*w.we)[vv].node.node->get_id();
84 add = Sright[v]*(-get_ngram().wordProb((*w.we)[v].node.node->get_id(),vi));
85 Sright[vv] += add;
87 //cerr << "Sright(" << vi[0] << "," << vv << ") = " << Sright[vv] << endl;
89 // collect fractional counts
90 fc = Sleft[vv]*add/sum; // P(v/vv)
91 vi[0] = (*w.we)[vv].node.node->get_id();
92 vi[1] = (*w.we)[v].node.node->get_id();
93 vi[2] = Vocab_None;
94 *stats.insertCount(vi) += fc;
95 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
96 vi[1] = Vocab_None;
101 // collect fc of ends
102 // we can use Sright with no problems becase there is only one edge
103 // from ends[i] to the end.
104 n = ends.size();
105 vi[2] = Vocab_None;
106 vi[1] = get_id(STOP_ID);
107 for (i = 0;i < n;i ++) {
108 fc = Sleft[ends[i]]*Sright[ends[i]]/sum;
109 vi[0] = (*w.we)[ends[i]].node.node->get_id();
110 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
111 *stats.insertCount(vi) += fc;
114 vi[0] = get_id(START_ID);
115 const WordEntryRefs &wers = w.get_we(0);
116 n = wers.size();
117 for (i = 0;i < n;i ++) {
118 fc = Sleft[wers[i]->id]*Sright[wers[i]->id]/sum;
119 vi[1] = (*w.we)[wers[i]->id].node.node->get_id();
120 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
121 *stats.insertCount(vi) += fc;
126 void SoftCounter::count(const DAG &dag,NgramFractionalStats &stats)
128 vector<float> Sleft,Sright;
129 vector<set<uint> > prev;
130 int i,n,v,vv;
131 float add;
133 n = dag.node_count();
134 Sleft.resize(n);
135 Sright.resize(n);
136 prev.resize(n);
137 //cerr << "Nodes: " << n << endl;
139 vector<bool> mark;
140 mark.resize(n);
141 deque<uint> traces;
142 Sleft[dag.node_begin()] = 1;
143 traces.push_back(dag.node_begin());
144 // first pass: Sleft
145 while (!traces.empty()) {
146 v = traces.front();
147 traces.pop_front();
148 if (mark[v])
149 continue;
150 else
151 mark[v] = true;
152 std::vector<uint> nexts;
153 dag.get_next(v,nexts);
154 n = nexts.size();
156 for (i = 0;i < n;i ++) {
157 vv = nexts[i];
158 add = Sleft[v]*LogPtoProb(-dag.edge_value(v,vv));
159 Sleft[vv] += add;
161 traces.push_back(vv);
163 // init prev references for Sright phase
164 prev[vv].insert(v);
168 //cerr << "Sleft done" << endl;
170 float fc;
171 // second pass: Sright
172 float sum = Sleft[dag.node_end()];
173 Sright[dag.node_end()] = 1;
174 traces.clear();
175 traces.push_back(dag.node_end()); // the last v above
176 while (!traces.empty()) {
177 vv = traces.front();
178 traces.pop_front();
179 if (!mark[vv])
180 continue;
181 else
182 mark[vv] = false;
183 //cerr << vv << " ";
184 set<uint>::iterator iter;
185 for (iter = prev[vv].begin();iter != prev[vv].end(); ++iter) {
186 v = *iter;
187 traces.push_back(v);
189 add = Sright[vv]*LogPtoProb(-dag.edge_value(v,vv));
190 Sright[v] += add;
192 // collect fractional counts
193 fc = Sleft[v]*add/sum; // P(vv/v)
194 VocabIndex vi[10];
195 VocabIndex vvv;
196 if (dag.fill_vi(v,vv,vvv,vi,9)) {
197 uint t,jn,j;
198 for (jn = 0;vi[jn] != Vocab_None && jn < 9; jn ++);
199 if (jn < 9 && vi[jn] == Vocab_None) {
200 for (j = 0;j < jn/2;j ++) {
201 t = vi[j];
202 vi[j] = vi[jn-j-1];
203 vi[jn-j-1] = t;
205 vi[jn] = vvv;
206 vi[jn+1] = Vocab_None;
207 //stats.countSentence(vi,/*LogPtoProb(--fc)*/fc);
208 *stats.insertCount(vi) += fc;
213 //cerr << "Sright done" << endl;
217 A work-around because NgramFractionalStats is still buggy.
221 void SoftCounter::count(const DAG &dag,NgramStats &stats)
223 vector<double> Sleft,Sright;
224 vector<set<uint> > prev;
225 int i,n,v,vv;
226 double add;
228 n = dag.node_count();
229 Sleft.resize(n);
230 Sright.resize(n);
231 prev.resize(n);
232 //cerr << "Nodes: " << n << endl;
234 // topo sort
235 vector<uint> traces;
236 traces.push_back(dag.node_begin());
237 int itrace = 0;
238 while (itrace < traces.size()) {
239 v = traces[itrace++];
240 std::vector<uint> nexts;
241 dag.get_next(v,nexts);
242 n = nexts.size();
244 for (i = 0;i < n;i ++) {
245 vector<uint>::iterator iter = find(traces.begin(),traces.end(),nexts[i]);
246 if (iter != traces.end()) {
247 traces.erase(iter);
248 if (iter - traces.begin() <= itrace-1)
249 itrace --;
251 traces.push_back(nexts[i]);
255 // first pass: Sleft
256 uint ntrace = traces.size();
257 Sleft[dag.node_begin()] = 1; // log(1) = 0
258 for (itrace = 0;itrace < ntrace;itrace ++) {
259 v = traces[itrace];
260 //cout << " " << v << ":" << Sleft[v];
261 std::vector<uint> nexts;
262 dag.get_next(v,nexts);
263 n = nexts.size();
265 for (i = 0;i < n;i ++) {
266 vv = nexts[i];
267 add = Sleft[v]*dag.edge_value(v,vv);
268 //cout << "-" << dag.edge_value(v,vv) << "-" << Sleft[vv] << ">" << vv;
269 Sleft[vv] += add;
271 // init prev references for Sright phase
272 prev[vv].insert(v);
276 //cerr << "Sleft done" << endl;
278 unsigned int fc;
279 // second pass: Sright
280 double sum = Sleft[dag.node_end()];
281 if (sum == 0)
282 return;
283 //cout << "Sum " << sum << endl;
285 traces.clear();
286 traces.push_back(dag.node_end());
287 itrace = 0;
288 while (itrace < traces.size()) {
289 vv = traces[itrace++];
291 set<uint>::iterator iter;
292 for (iter = prev[vv].begin();iter != prev[vv].end(); ++iter) {
293 vector<uint>::iterator iiter = find(traces.begin(),traces.end(),*iter);
294 if (iiter != traces.end()) {
295 traces.erase(iiter);
296 if (iiter - traces.begin() <= itrace-1)
297 itrace --;
299 traces.push_back(*iter);
303 ntrace = traces.size();
304 Sright[dag.node_end()] = 1; // log(1) = 0
305 for (itrace = 0;itrace < ntrace;itrace ++) {
306 vv = traces[itrace];
307 //cout << " " << vv << ":" << Sright[vv];
308 set<uint>::iterator iter;
309 for (iter = prev[vv].begin();iter != prev[vv].end(); ++iter) {
310 v = *iter;
311 add = Sright[vv]*dag.edge_value(v,vv);
312 //cout << "-" << dag.edge_value(v,vv)<< "-" << Sright[v] << ">" << v;
313 Sright[v] += add;
315 // collect fractional counts
316 fc = 100-(unsigned int)((Sleft[v]*add)*100.0/sum); // P(vv/v)
317 //cout << Sleft[v] << "+" << dag.edge_value(v,vv) << Sright[vv]<< "=("<< (Sleft[v]+add)<< ")"<<((Sleft[v]+add)/sum) <<"_" << fc << endl;
318 cout << v << " " << vv << " " << fc << endl;
319 if (fc != 0) {
320 VocabIndex vi[10];
321 VocabIndex vvv;
322 if (dag.fill_vi(v,vv,vvv,vi,9)) {
323 uint t,jn,j;
324 for (jn = 0;vi[jn] != Vocab_None && jn < 9; jn ++);
325 if (jn < 9 && vi[jn] == Vocab_None) {
326 for (j = 0;j < jn/2;j ++) {
327 t = vi[j];
328 vi[j] = vi[jn-j-1];
329 vi[jn-j-1] = t;
331 vi[jn] = vvv;
332 vi[jn+1] = Vocab_None;
333 //stats.countSentence(vi,(unsigned int)(fc*10.0));
334 *stats.insertCount(vi) += fc;
340 cout << endl;
341 //double sum2 = Sright[dag.node_begin()];
342 //cout << sum2 << " " << (traces[ntrace-1] == dag.node_begin()) << " " << (sum2 == sum ? "Ok" : "Failed") << endl;
343 //cerr << "Sright done" << endl;