Added .gitignore
[vspell.git] / tests / wfst-train.cpp
blob9037cda2455f926881236d0dcdfacf24e8b7670e
1 // -*- tab-width: 2 -*-
2 #include "pfs.h"
3 #include "distance.h"
4 #include <string>
5 #include <fstream>
6 #include <cmath>
7 #include <cstdio>
8 #include <sstream>
9 #include <iostream>
10 #include "sentence.h"
11 #include "propername.h"
12 #include <boost/format.hpp>
14 using namespace std;
17 //NgramFractionalStats stats(sarch.get_dict(),2);
19 int main(int argc,char **argv)
21 if (argc < 3) {
22 fprintf(stderr,"Need at least 2 argument.\n");
23 return 0;
26 char *oldres = argv[1];
27 char *newres = argv[2];
28 bool nofuz = true;
29 bool nofuz2 = true;
30 const char *str;
32 dic_init();
34 cerr << "Loading... ";
35 //str = (boost::format("wordlist.%s") % oldres).str().c_str();
36 str = "wordlist";
37 warch.load(str);
38 str = (boost::format("ngram.%s") % oldres).str().c_str();
39 File f(str,"rt",0);
40 if (!f.error())
41 get_ngram().read(f);
42 else
43 cerr << "Ngram loading error..." << endl;
44 cerr << "done" << endl;
46 get_sarch().set_blocked(true);
48 //wfst.set_wordlist(get_root());
50 string s;
51 int i,ii,iii,n,nn,nnn,z;
52 int count = 0,ccount = 0;
53 NgramStats stats(get_sarch().get_dict(),3);
54 //NgramStats syllable_stats(get_sarch().get_dict(),2);
55 while (getline(cin,s)) {
56 ccount++;
57 //cerr << ">" << ccount << endl;
58 if (s.empty())
59 continue;
60 vector<string> ss;
61 sentences_split(s,ss);
62 for (z = 0;z < ss.size();z ++) {
63 count ++;
64 if (count % 1000 == 0)
65 cerr << count << endl;
66 Sentence st(ss[z]);
67 st.standardize();
68 st.tokenize();
69 if (!st.get_syllable_count())
70 continue;
71 //cerr << st << endl;
72 Lattice words;
73 set<WordEntry> wes;
74 WordStateFactories factories;
75 ExactWordStateFactory exact;
76 LowerWordStateFactory lower;
77 //FuzzyWordStateFactory fuzzy;
78 factories.push_back(&exact);
79 factories.push_back(&lower);
80 //factories.push_back(&fuzzy);
81 words.pre_construct(st,wes,factories);
82 mark_proper_name(st,wes);
83 words.post_construct(wes);
84 //cerr << words << endl;
85 Segmentation seg(words.we);
86 PFS wfst;
87 /* // pfs don't distinguish
88 if (nofuz2)
89 wfst.segment_best_no_fuzzy(words,seg);
90 else
91 wfst.segment_best(words,seg);
93 Path path;
94 WordDAG dag(&words);
95 wfst.search(dag,path);
96 seg.resize(path.size()-2);
97 copy(path.begin()+1,path.end()-1,seg.begin());
99 //seg.pretty_print(cout,st) << endl;
101 VocabIndex *vi;
102 n = path.size();
103 if (n > 3) {
104 vi = new VocabIndex[n+1];
105 vi[n] = Vocab_None;
106 for (i = 0;i < n;i ++) {
107 if (path[i] == dag.node_begin())
108 vi[i] = get_id(START_ID);
109 else if (path[i] == dag.node_end())
110 vi[i] = get_id(STOP_ID);
111 else {
112 vi[i] = ((WordEntry*)dag.node_info(path[i]))->node.node->get_id();
114 if (!sarch.in_dict(vi[i])) {
115 cerr << ">>" << ccount << " " << count << " " << vi[i] << endl;
116 vi[i] = get_id(UNK_ID);
120 //cerr << "<" << sarch[vi[i]] << "> ";
122 //cerr << endl;
123 //cerr << n << endl;
124 stats.countSentence(vi);
125 //cerr << "done" << endl;
126 delete[] vi;
130 const WordEntries &we = *words.we;
131 n = we.size();
132 for (i = 0;i < n;i ++) {
133 we[i].node.node->inc_b();
136 n = seg.size();
137 for (i = 0;i < n;i ++)
138 seg[i].node.node->inc_a();
141 sarch.clear_rest();
144 cerr << "Calculating... ";
145 //get_root()->get_next(unk_id)->get_b() = 0;
146 //get_root()->recalculate();
147 get_ngram().estimate(stats);
148 //wfst.enable_ngram(true);
150 cerr << "Saving... ";
151 //str = (boost::format("wordlist.wl.%s") % newres).str().c_str();
152 //get_root()->save(str);
154 str = (boost::format("ngram.%s") % newres).str().c_str();
155 File ff(str,"wt");
156 get_ngram().write(ff);
157 cerr << endl;
159 for (int i = 0;i < 50;i ++) {
160 ostringstream oss;
161 oss << "log." << i;
162 ofstream ofs(oss.str().c_str());
163 cerr << "Iteration " << i << "... ";
164 iterate(ofs,i);
165 cerr << "done" << endl;
168 return 0;