Shhh.. dont know what i did
[vspell.git] / wfst-train.cpp
blob46768faf5169df5bbde52302152e3f0054b5034f
1 // -*- coding: viscii -*-
2 #include "wfst.h"
3 #include "distance.h"
4 #include <string>
5 #include <fstream>
6 #include <cmath>
7 #include <cstdio>
8 #include <sstream>
9 #include <srilm/NgramStats.h>
11 void iterate(ostream &os,int level);
13 using namespace std;
15 WFST wfst;
16 vector<Sentence> sentences;
18 int main()
20 Dictionary::initialize();
22 cerr << "Loading... ";
23 Dictionary::get_root()->load("wordlist.wl");
24 cerr << "done" << endl;
26 wfst.set_wordlist(Dictionary::get_root());
27 ifstream ifs("corpus2");
28 if (!ifs.is_open()) {
29 cerr << "Can not open corpus\n";
30 return 0;
33 string s;
34 while (getline(ifs,s)) {
35 if (!s.empty()) {
36 sentences.push_back(Sentence(s));
37 Sentence &st = sentences.back();
38 st.standardize();
39 st.tokenize();
43 for (int i = 0;i < 50;i ++) {
44 ostringstream oss;
45 oss << "log." << i;
46 ofstream ofs(oss.str().c_str());
47 cerr << "Iteration " << i << "... ";
48 iterate(ofs,i);
49 cerr << "done" << endl;
52 return 0;
55 void print_all_words(const Words &words);
56 void iterate(ostream &os,int level)
58 int ist,nr_sentences = sentences.size();
59 NgramStats stats(Dictionary::sarch.get_dict(),2);
60 for (ist = 0;ist < nr_sentences;ist ++) {
61 Sentence &st = sentences[ist];
63 Segmentation seg;
64 Words words;
65 wfst.get_all_words(st,words);
66 //print_all_words(words);
67 wfst.segment_best(st,words,seg);
68 seg.print(os,st);
71 #ifdef TRAINING
72 int i,ii,iii,n,nn,nnn;
73 n = seg.items.size();
74 VocabIndex *vi = new VocabIndex[n+1];
75 vi[n] = Vocab_None;
76 for (i = 0;i < n;i ++)
77 vi[i] = seg.items[i].state->get_id();
78 stats.countSentence(vi);
80 n = words.size();
81 for (i = 0;i < n;i ++) {
82 nn = words[i].size();
83 for (ii = 0;ii < nn;ii ++) {
84 nnn = words[i][ii].fuzzy_match.size();
85 for (iii = 0;iii < nnn;iii ++) {
86 words[i][ii].fuzzy_match[iii].node->inc_b();
91 n = seg.items.size();
92 for (i = 0;i < n;i ++)
93 seg.items[i].state->inc_a();
94 #endif
97 cerr << "Calculating... ";
98 Dictionary::get_root()->get_next(Dictionary::unk_id)->get_b() = 0;
99 Dictionary::get_root()->recalculate();
100 Dictionary::ngram.estimate(stats);
101 //wfst.enable_ngram(true);
103 cerr << "Saving... ";
104 ostringstream oss;
105 oss << "wordlist.wl." << level;
106 Dictionary::get_root()->save(oss.str().c_str());
108 ostringstream oss1;
109 oss1 << "ngram." << level;
110 File f(oss1.str().c_str(),"wt");
111 Dictionary::ngram.write(f);