terrible bug in PenaltyDAG and Penalty2DAG.
[vspell.git] / tests / sc-train.cpp
blobdb043af13707f7e48547318dddeca29e62454618
1 // -*- tab-width: 2 -*-
2 #include "distance.h"
3 #include <string>
4 #include <fstream>
5 #include <cmath>
6 #include <cstdio>
7 #include <sstream>
8 #include <iostream>
9 #include "sentence.h"
10 #include "softcount.h"
11 #include "propername.h"
12 #include <boost/format.hpp>
14 using namespace std;
16 void estimate(Ngram &ngram,NgramFractionalStats &stats);
18 int main(int argc,char **argv)
20 if (argc < 3) {
21 fprintf(stderr,"Need at least 2 argument.\n");
22 return 0;
25 char *oldres = argv[1];
26 char *newres = argv[2];
27 bool nofuz = true;
28 bool nofuz2 = true;
29 bool trigram = true;
30 const char *str;
32 dic_init();
34 cerr << "Loading... ";
35 str = "wordlist";
36 warch.load(str);
37 str = (boost::format("ngram.%s") % oldres).str().c_str();
38 File f(str,"rt",0);
39 if (!f.error())
40 get_ngram().read(f);
41 else
42 cerr << "Ngram loading error..." << endl;
43 cerr << "done" << endl;
45 get_sarch().set_blocked(true);
47 string s;
48 int i,ii,iii,n,nn,nnn,z;
49 int count = 0;
50 NgramStats stats(get_sarch().get_dict(),3);
51 while (getline(cin,s)) {
52 count ++;
53 if (count % 200 == 0)
54 cerr << count << endl;
55 if (s.empty())
56 continue;
57 vector<string> ss;
58 sentences_split(s,ss);
59 for (z = 0;z < ss.size();z ++) {
60 Sentence st(ss[z]);
61 st.standardize();
62 st.tokenize();
63 if (!st.get_syllable_count())
64 continue;
65 //cout << ">>" << count << endl;
66 Lattice words;
67 set<WordEntry> wes;
68 WordStateFactories factories;
69 ExactWordStateFactory exact;
70 LowerWordStateFactory lower;
71 //FuzzyWordStateFactory fuzzy;
72 factories.push_back(&exact);
73 factories.push_back(&lower);
74 //factories.push_back(&fuzzy);
75 words.pre_construct(st,wes,factories);
76 mark_proper_name(st,wes);
77 words.post_construct(wes);
78 //cerr << words << endl;
79 WordDAG dagw(&words);
80 DAG *dag = &dagw;
81 WordDAG2 *dagw2;
82 if (trigram) {
83 dagw2 = new WordDAG2(&dagw);
84 dag = dagw2;
86 SoftCounter sc;
87 //sc.count(words,stats);
88 sc.count(*dag,stats);
89 if (trigram)
90 delete (WordDAG2*)dag;
94 cerr << "Dumping...";
95 File fff("dump","wt");
96 stats.write(fff);
97 fff.close();
99 cerr << "Calculating... ";
100 //estimate(get_ngram(),stats);
101 get_ngram().estimate(stats);
102 //wfst.enable_ngram(true);
104 cerr << "Saving... ";
105 str = (boost::format("ngram.%s") % newres).str().c_str();
106 File ff(str,"wt");
107 get_ngram().write(ff);
108 cerr << endl;
110 for (int i = 0;i < 50;i ++) {
111 ostringstream oss;
112 oss << "log." << i;
113 ofstream ofs(oss.str().c_str());
114 cerr << "Iteration " << i << "... ";
115 iterate(ofs,i);
116 cerr << "done" << endl;
119 return 0;
122 void estimate(Ngram &ngram,NgramFractionalStats &stats)
125 * If no discount method was specified we do the default, standard
126 * thing. Good Turing discounting with the specified min and max counts
127 * for all orders.
129 unsigned order = get_ngram().setorder(0);
130 Discount *discounts[order];
131 unsigned i;
132 Boolean error = false;
134 for (i = 1; !error & i <= order; i++) {
135 discounts[i-1] = new WittenBell();
137 * Transfer the LMStats's debug level to the newly
138 * created discount objects
140 discounts[i-1]->debugme(stats.debuglevel());
142 if (!discounts[i-1]->estimate(stats, i)) {
143 std::cerr << "failed to estimate GT discount for order " << i + 1
144 << std::endl;
145 error = true;
149 if (!error) {
150 error = !get_ngram().estimate((NgramCounts<FloatCount>&)stats, discounts);
153 for (i = 1; i <= order; i++) {
154 delete discounts[i-1];