Change Words to Lattice
[vspell.git] / libvspell / wfst.cpp
blobbefb16bd0c2a5cfa7cc4286d0d0b0070bb73ce76
1 #include "config.h" // -*- tab-width: 2 -*-
2 #include "wfst.h"
3 #include <iterator>
4 #include <algorithm>
5 #include <sstream>
6 #include <iostream>
7 #include <fstream>
8 #include <boost/format.hpp>
9 #include <set>
10 #include "posgen.h"
11 #include <signal.h>
12 #include <unistd.h>
14 using namespace std;
// "for lazy men ;)" -- a file-level global instead of a proper setting.
// Presumably the order of the n-gram model (2 = bigram); in this part of the
// file it is only referenced from commented-out code. TODO confirm.
unsigned ngram_length = 2;
/**
   Don't know how to name this class.
   It is used to generate a "to hop" ;)
   (presumably Vietnamese "to hop" = "combination" -- TODO confirm).

   Usage protocol: init() once, then step() repeatedly until it returns
   false, then done().
 */
class ZZZ
{
private:
  std::vector<uint> limit;   // per-slot bound handed to init()
  std::vector<uint> iter;    // current value of every slot
  int nr_limit;              // number of slots
  int cur;                   // slot currently being worked on

public:
  /// Start a new enumeration bounded by \a limit.
  void init(const std::vector<uint> &limit);

  /// Fetch the next tuple into \a pos; returns false when none is left.
  bool step(std::vector<uint> &pos);

  /// End the enumeration.
  void done();
};
38 void Generator::init(const Sentence &_st)
40 nr_misspelled = 3; // REMEMBER THIS NUMBER
41 misspelled_pos.resize(nr_misspelled);
42 run = true;
44 st = &(Sentence&)_st;
46 len = st->get_syllable_count();
48 // initialize
49 for (i = 0;i < nr_misspelled;i ++)
50 misspelled_pos[i] = i;
52 i = 0;
55 void Generator::done()
59 /**
60 Generate every possible 3-misspelled-positions.
61 Then call WFST::generate_misspelled_words.
64 bool Generator::step(vector<uint> &_pos,uint &_len)
66 while (run) {
68 // move to the next counter
69 if (i+1 < nr_misspelled && misspelled_pos[i] < len) {
70 i ++;
72 // do something here with misspelled_pos[i]
73 _pos = misspelled_pos;
74 _len = i;
75 return true;
78 // the last counter is not full
79 if (misspelled_pos[i] < len) {
80 // do something here with misspelled_pos[nr_misspelled]
81 _pos = misspelled_pos;
82 _len = nr_misspelled;
83 misspelled_pos[i] ++;
84 return true;
87 // the last counter is full, step back
88 while (i >= 0 && misspelled_pos[i] == len) i --;
89 if (i < 0)
90 run = false;
91 else {
92 misspelled_pos[i] ++;
93 for (ii = i+1;ii < nr_misspelled;ii ++)
94 misspelled_pos[ii] = misspelled_pos[ii-1]+1;
97 return false;
100 typedef vector<Segmentations> Segmentation2;
103 Generate new Lattice info based on misspelled position info.
104 Call WFST::get_sections() to split into sections
105 then call WFST::segment_all1() to do real segmentation.
106 \param pos contains possibly misspelled position.
107 \param len specify the actual length pos. Don't use pos.size()
110 void WFST::generate_misspelled_words(const vector<uint> &pos,int len,Segmentation &final_seg)
112 const Lattice &words = *p_words;
113 Lattice w;
115 w.based_on(words);
117 // 2. Compute the score, jump to 1. if score is too low (pruning 1)
119 // create new (more compact) Lattice structure
120 int i,n = words.get_word_count();
121 for (i = 0;i < len;i ++) {
122 const WordEntryRefs &fmap = words.get_fuzzy_map(pos[i]);
123 int ii,nn = fmap.size();
124 for (ii = 0;ii < nn;++ii)
125 w.add(*fmap[ii]);
128 //cerr << w << endl;
130 // 4. Create sections
131 Sections sects;
132 sects.construct(words);
134 // 5. Call create_base_segmentation
135 //Segmentation base_seg(words.we);
136 //create_base_segmentation(words,sects,base_seg);
139 // 6. Get the best segmentation of each section,
140 // then merge to one big segment.
141 n = sects.size();
143 uint ii,nn;
145 i = ii = 0;
146 nn = words.get_word_count();
148 final_seg.clear();
149 while (ii < nn)
150 if (i < n && sects[i].start == ii) {
151 Segmentation seg;
152 sects[i].segment_best(words,seg);
153 copy(seg.begin(),
154 seg.end(),
155 back_insert_iterator< Segmentation >(final_seg));
156 ii += sects[i].len;
157 i ++;
158 } else {
159 // only word(i,*,0) exists
160 final_seg.push_back(words.get_we(ii)[0]->id);
161 ii += words.get_we(ii)[0]->len;
166 Create the best segmentation for a sentence.
167 \param words store Lattice info
168 \param seps return the best segmentation
171 void WFST::segment_best(const Lattice &words,Segmentation &seps)
173 //int i,ii,n,nn;
175 p_words = &words;
177 // in test mode, generate all positions where misspelled can appear,
178 // then create a new Lattice for them, re get_sections,
179 // create_base_segmentation and call segment_all1 for each sections.
181 // 1. Generate mispelled positions (pruning 0 - GA)
182 // 2. Compute the score, jump to 1. if score is too low (pruning 1)
183 // 3. Make a new Lattice based on the original Lattice
184 // 4. Call get_sections
185 // 5. Call create_base_segmentation
186 // 6. Call segment_all1 for each sections.
187 // 6.1. Recompute the score after each section processed. (pruning 2)
189 // 1. Bai toan hoan vi, tinh chap C(k,n) with modification. C(1,n)+C(2,n)+...+C(k,n)
190 Generator gen;
192 gen.init(*words.st);
193 vector<uint> pos;
194 uint len;
195 seps.prob = 100;
196 while (gen.step(pos,len)) {
197 Segmentation seg;
198 //cerr << "POS :";
199 //for (int i = 0;i < len;i ++) cerr << pos[i];
200 //cerr << endl;
201 generate_misspelled_words(pos,len,seg);
202 if (seg.prob < seps.prob)
203 seps = seg;
205 gen.done();
210 Create the best segmentation for a sentence. The biggest difference between
211 segment_best and segment_best_no_fuzzy is that segment_best_no_fuzzy don't
212 use Generator. It assumes there is no misspelled position at all.
213 \param words store Lattice info
214 \param seps return the best segmentation
217 void WFST::segment_best_no_fuzzy(const Lattice &words,Segmentation &seps)
219 p_words = &words;
221 vector<uint> pos;
222 generate_misspelled_words(pos,0,seps);
// WFST::segment_all -- marked obsolete by the original author.
// What the visible lines show: a Lattice is constructed from the sentence,
// then a triple loop prints, per start index and per length, the distance
// and the probability of every fuzzy match to cerr. "result" is never
// written in the visible lines.
// NOTE(review): this listing has lost its brace-only and comment-delimiter
// lines, so it cannot be told from here whether the dump loop is live code
// or sits inside a block comment. The loop index "i" has no visible
// declaration, which suggests the latter -- confirm against the real file.
225 // WFST (Weighted Finite State Transducer) implementation
226 // TUNE: think about genetic/greedy. Vietnamese is almost one-syllable words..
227 // we find where is likely a start of a word, then try to construct word
228 // and check if others are still valid words.
230 // the last item is always the incompleted item. We will try to complete
231 // a word from the item. If we reach the end of sentence, we will remove it
232 // from segs
234 // obsolete
235 void WFST::segment_all(const Sentence &sent,vector<Segmentation> &result)
237 Lattice words;
238 words.construct(sent);
239 // segment_all1(sent,words,0,sent.get_syllable_count(),result);a
// number of start positions in the lattice
241 int nn = words.size();
242 for (i = 0;i < nn;i ++) {
// number of word lengths recorded at start position i
243 int nnn = words[i].size();
244 cerr << "From " << i << endl;
245 for (int ii = 0;ii < nnn;ii ++) {
// number of fuzzy matches for this (start, length) pair
246 int nnnn = words[i][ii].fuzzy_match.size();
247 cerr << "Len " << ii << endl;
248 for (int iii = 0;iii < nnnn;iii ++) {
// one line per fuzzy match: "<distance> <probability>"
249 cerr << words[i][ii].fuzzy_match[iii].distance << " ";
250 cerr << words[i][ii].fuzzy_match[iii].node->get_prob() << endl;
259 * Segmentation comparator
262 class SegmentationComparator
264 public:
265 bool operator() (const Segmentation &seg1,const Segmentation &seg2) {
266 return seg1.prob > seg2.prob;
275 void Segmentor::init(const Lattice &words,
276 int from,
277 int to)
279 nr_syllables = to+1;
281 segs.clear();
282 segs.reserve(100);
284 Trace trace(words.we);
285 trace.next_syllable = from;
286 segs.push_back(trace); // empty seg
288 _words = &words;
291 bool Segmentor::step(Segmentation &result)
293 const Lattice &words = *_words;
294 // const Sentence &sent = *_words->st;
295 while (!segs.empty()) {
296 // get one
297 Trace trace = segs.back();
298 segs.pop_back();
300 Segmentation seg = trace.s;
301 int next_syllable = trace.next_syllable;
303 // segmentation completed. throw it away
304 if (next_syllable >= nr_syllables)
305 continue;
307 //WordEntries::iterator i_entry;
308 WordEntryRefs &wes = words.get_we(next_syllable);
309 int ii,nn = wes.size();
310 for (ii = 0;ii < nn;ii ++) {
311 WordEntryRef i_entry = wes[ii];
313 // New segmentation for longer incomplete word
314 Trace newtrace(words.we);
315 newtrace = trace;
316 newtrace.s.push_back(i_entry->id);
317 newtrace.s.prob += i_entry->node.node->get_prob();
319 /*-it's better to compute ngram outside this function
320 if (ngram_enabled) {
321 if (newtrace.s.size() >= ngram_length) {
322 VocabIndex *vi = new VocabIndex[ngram_length];
323 vi[0] = newtrace.s.items[len-1].node(sent).node->get_id();
324 vi[1] = Vocab_None;
325 newtrace.s.prob += -ngram.wordProb(newtrace.s.items[len-2].node(sent).node->get_id(),vi);
326 delete[] vi;
331 newtrace.next_syllable += i_entry->len;
332 if (newtrace.next_syllable == nr_syllables) {
333 //cout << count << " " << newtrace.s << endl;
334 result = newtrace.s;
335 return true;
336 //result.push_back(newtrace.s);
337 //push_heap(result.begin(),result.end(),SegmentationComparator());
338 //count ++;
339 } else {
340 segs.push_back(newtrace);
343 } // end while
344 return false;
347 void Segmentor::done()
351 void ZZZ::init(const vector<uint> &_limit)
353 limit = _limit;
354 nr_limit = limit.size();
355 iter.resize(nr_limit);
356 cur = 0;
359 void ZZZ::done()
364 Generate every possible 3-misspelled-positions.
365 Then call WFST::generate_misspelled_words.
368 bool ZZZ::step(vector<uint> &_pos)
370 while (cur >= 0) {
372 if (cur == nr_limit-1) {
373 _pos = iter;
374 while (cur >= 0 && iter[cur] == limit[cur]-1)
375 cur --;
376 if (cur >= 0)
377 iter[cur]++;
378 return true;
381 cur ++;
382 if (cur < nr_limit)
383 iter[cur] = 0;
385 return false;