1 #include "softcount.h" // -*- tab-width: 2 -*-
9 void SoftCounter::count(const Lattice
&w
,NgramFractionalStats
&stats
)
11 vector
<float> Sleft
,Sright
;
12 vector
<vector
<uint
> > prev
;
13 //boost::shared_ptr<WordEntries> p_we = w.we;
14 //const WordEntries &we = *p_we;
15 int i
,n
= w
.get_word_count(),v
,vv
;
21 Sleft
.resize(w
.we
->size());
22 Sright
.resize(w
.we
->size());
23 prev
.resize(w
.we
->size());
28 for (i
= 0;i
< n
;i
++) {
29 const WordEntryRefs
&wers
= w
.get_we(i
);
30 int ii
,nn
= wers
.size();
32 for (ii
= 0;ii
< nn
;ii
++) {
33 // wers[ii] is the first node (W).
37 vi
[0] = get_id(START_ID
);
38 Sleft
[v
] = -get_ngram().wordProb((*w
.we
)[v
].node
.node
->get_id(),vi
);
39 //cerr << "Sleft(" << vi[0] << "," << v << ") = " << Sleft[v] << endl;
42 int next
= wers
[ii
]->pos
+wers
[ii
]->len
;
44 const WordEntryRefs
&wers2
= w
.get_we(next
);
45 int iii
,nnn
= wers2
.size();
46 for (iii
= 0;iii
< nnn
;iii
++) {
47 // wers2[iii] is the second node (W').
49 vi
[0] = (*w
.we
)[v
].node
.node
->get_id();
50 add
= Sleft
[v
]*(-get_ngram().wordProb((*w
.we
)[vv
].node
.node
->get_id(),vi
));
53 //cerr << "Sleft(" << vi[0] << "," << vv << ") = " << Sleft[vv] << endl;
55 // init prev references for Sright phase
56 prev
[vv
].push_back(v
);
59 vi
[0] = (*w
.we
)[v
].node
.node
->get_id();
60 Sright
[v
] = -get_ngram().wordProb(get_id(STOP_ID
),vi
);
61 //cerr << "Sright(" << vi[0] << "," << v << ") = " << Sright[v] << endl;
63 //cerr << "Sum " << sum << endl;
69 // second pass: Sright
70 Sright
[w
.we
->size()-1] = 1;
71 for (i
= n
-1;i
>= 0;i
--) {
72 const WordEntryRefs
&wers
= w
.get_we(i
);
73 int ii
,nn
= wers
.size();
75 for (ii
= 0;ii
< nn
;ii
++) {
76 // wers[ii] is the first node (W).
79 int iii
,nnn
= prev
[v
].size();
80 for (iii
= 0;iii
< nnn
;iii
++) {
81 // vv is the second node (W').
83 vi
[0] = (*w
.we
)[vv
].node
.node
->get_id();
84 add
= Sright
[v
]*(-get_ngram().wordProb((*w
.we
)[v
].node
.node
->get_id(),vi
));
87 //cerr << "Sright(" << vi[0] << "," << vv << ") = " << Sright[vv] << endl;
89 // collect fractional counts
90 fc
= Sleft
[vv
]*add
/sum
; // P(v/vv)
91 vi
[0] = (*w
.we
)[vv
].node
.node
->get_id();
92 vi
[1] = (*w
.we
)[v
].node
.node
->get_id();
94 *stats
.insertCount(vi
) += fc
;
95 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
101 // collect fc of ends
102 // we can use Sright with no problems becase there is only one edge
103 // from ends[i] to the end.
106 vi
[1] = get_id(STOP_ID
);
107 for (i
= 0;i
< n
;i
++) {
108 fc
= Sleft
[ends
[i
]]*Sright
[ends
[i
]]/sum
;
109 vi
[0] = (*w
.we
)[ends
[i
]].node
.node
->get_id();
110 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
111 *stats
.insertCount(vi
) += fc
;
114 vi
[0] = get_id(START_ID
);
115 const WordEntryRefs
&wers
= w
.get_we(0);
117 for (i
= 0;i
< n
;i
++) {
118 fc
= Sleft
[wers
[i
]->id
]*Sright
[wers
[i
]->id
]/sum
;
119 vi
[1] = (*w
.we
)[wers
[i
]->id
].node
.node
->get_id();
120 //cerr << "Gram " << vi[0] << "," << vi[1] << "+=" << fc << endl;
121 *stats
.insertCount(vi
) += fc
;
126 void SoftCounter::count(const DAG
&dag
,NgramFractionalStats
&stats
)
128 vector
<float> Sleft
,Sright
;
129 vector
<set
<uint
> > prev
;
133 n
= dag
.node_count();
137 //cerr << "Nodes: " << n << endl;
142 Sleft
[dag
.node_begin()] = 1;
143 traces
.push_back(dag
.node_begin());
145 while (!traces
.empty()) {
152 std::vector
<uint
> nexts
;
153 dag
.get_next(v
,nexts
);
156 for (i
= 0;i
< n
;i
++) {
158 add
= Sleft
[v
]*LogPtoProb(-dag
.edge_value(v
,vv
));
161 traces
.push_back(vv
);
163 // init prev references for Sright phase
168 //cerr << "Sleft done" << endl;
171 // second pass: Sright
172 float sum
= Sleft
[dag
.node_end()];
173 Sright
[dag
.node_end()] = 1;
175 traces
.push_back(dag
.node_end()); // the last v above
176 while (!traces
.empty()) {
184 set
<uint
>::iterator iter
;
185 for (iter
= prev
[vv
].begin();iter
!= prev
[vv
].end(); ++iter
) {
189 add
= Sright
[vv
]*LogPtoProb(-dag
.edge_value(v
,vv
));
192 // collect fractional counts
193 fc
= Sleft
[v
]*add
/sum
; // P(vv/v)
196 if (dag
.fill_vi(v
,vv
,vvv
,vi
,9)) {
198 for (jn
= 0;vi
[jn
] != Vocab_None
&& jn
< 9; jn
++);
199 if (jn
< 9 && vi
[jn
] == Vocab_None
) {
200 for (j
= 0;j
< jn
/2;j
++) {
206 vi
[jn
+1] = Vocab_None
;
207 //stats.countSentence(vi,/*LogPtoProb(--fc)*/fc);
208 *stats
.insertCount(vi
) += fc
;
213 //cerr << "Sright done" << endl;
217 A work-around because NgramFractionalStats is still buggy.
221 void SoftCounter::count(const DAG
&dag
,NgramStats
&stats
)
223 vector
<double> Sleft
,Sright
;
224 vector
<set
<uint
> > prev
;
228 n
= dag
.node_count();
232 //cerr << "Nodes: " << n << endl;
236 traces
.push_back(dag
.node_begin());
238 while (itrace
< traces
.size()) {
239 v
= traces
[itrace
++];
240 std::vector
<uint
> nexts
;
241 dag
.get_next(v
,nexts
);
244 for (i
= 0;i
< n
;i
++) {
245 vector
<uint
>::iterator iter
= find(traces
.begin(),traces
.end(),nexts
[i
]);
246 if (iter
!= traces
.end()) {
248 if (iter
- traces
.begin() <= itrace
-1)
251 traces
.push_back(nexts
[i
]);
256 uint ntrace
= traces
.size();
257 Sleft
[dag
.node_begin()] = 1; // log(1) = 0
258 for (itrace
= 0;itrace
< ntrace
;itrace
++) {
260 //cout << " " << v << ":" << Sleft[v];
261 std::vector
<uint
> nexts
;
262 dag
.get_next(v
,nexts
);
265 for (i
= 0;i
< n
;i
++) {
267 add
= Sleft
[v
]*dag
.edge_value(v
,vv
);
268 //cout << "-" << dag.edge_value(v,vv) << "-" << Sleft[vv] << ">" << vv;
271 // init prev references for Sright phase
276 //cerr << "Sleft done" << endl;
279 // second pass: Sright
280 double sum
= Sleft
[dag
.node_end()];
283 //cout << "Sum " << sum << endl;
286 traces
.push_back(dag
.node_end());
288 while (itrace
< traces
.size()) {
289 vv
= traces
[itrace
++];
291 set
<uint
>::iterator iter
;
292 for (iter
= prev
[vv
].begin();iter
!= prev
[vv
].end(); ++iter
) {
293 vector
<uint
>::iterator iiter
= find(traces
.begin(),traces
.end(),*iter
);
294 if (iiter
!= traces
.end()) {
296 if (iiter
- traces
.begin() <= itrace
-1)
299 traces
.push_back(*iter
);
303 ntrace
= traces
.size();
304 Sright
[dag
.node_end()] = 1; // log(1) = 0
305 for (itrace
= 0;itrace
< ntrace
;itrace
++) {
307 //cout << " " << vv << ":" << Sright[vv];
308 set
<uint
>::iterator iter
;
309 for (iter
= prev
[vv
].begin();iter
!= prev
[vv
].end(); ++iter
) {
311 add
= Sright
[vv
]*dag
.edge_value(v
,vv
);
312 //cout << "-" << dag.edge_value(v,vv)<< "-" << Sright[v] << ">" << v;
315 // collect fractional counts
316 fc
= 100-(unsigned int)((Sleft
[v
]*add
)*100.0/sum
); // P(vv/v)
317 //cout << Sleft[v] << "+" << dag.edge_value(v,vv) << Sright[vv]<< "=("<< (Sleft[v]+add)<< ")"<<((Sleft[v]+add)/sum) <<"_" << fc << endl;
318 cout
<< v
<< " " << vv
<< " " << fc
<< endl
;
322 if (dag
.fill_vi(v
,vv
,vvv
,vi
,9)) {
324 for (jn
= 0;vi
[jn
] != Vocab_None
&& jn
< 9; jn
++);
325 if (jn
< 9 && vi
[jn
] == Vocab_None
) {
326 for (j
= 0;j
< jn
/2;j
++) {
332 vi
[jn
+1] = Vocab_None
;
333 //stats.countSentence(vi,(unsigned int)(fc*10.0));
334 *stats
.insertCount(vi
) += fc
;
341 //double sum2 = Sright[dag.node_begin()];
342 //cout << sum2 << " " << (traces[ntrace-1] == dag.node_begin()) << " " << (sum2 == sum ? "Ok" : "Failed") << endl;
343 //cerr << "Sright done" << endl;