// extools/extractor.cc
// (recovered from gitweb view of ws10smt.git, blob a3791d2a17e9f3a78f58135830437001a741f2a7)
#include <cstdlib>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

#include <tr1/unordered_map>

#include <boost/functional/hash.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "sparse_vector.h"
#include "sentence_pair.h"
#include "extract.h"
#include "tdict.h"
#include "fdict.h"
#include "wordid.h"
#include "array2d.h"
#include "filelib.h"

using namespace std;
using namespace std::tr1;
namespace po = boost::program_options;
24 static const size_t MAX_LINE_LENGTH = 100000;
25 WordID kBOS, kEOS, kDIVIDER, kGAP;
26 int kCOUNT;
28 void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
29 po::options_description opts("Configuration options");
30 opts.add_options()
31 ("input,i", po::value<string>()->default_value("-"), "Input file")
32 ("default_category,d", po::value<string>(), "Default span type (use X for 'Hiero')")
33 ("loose", "Use loose phrase extraction heuristic for base phrases")
34 ("base_phrase,B", "Write base phrases")
35 ("base_phrase_spans", "Write base sentences and phrase spans")
36 ("bidir,b", "Extract bidirectional rules (for computing p(f|e) in addition to p(e|f))")
37 ("combiner_size,c", po::value<size_t>()->default_value(800000), "Number of unique items to store in cache before writing rule counts. Set to 0 to disable cache.")
38 ("silent", "Write nothing to stderr except errors")
39 ("phrase_context,C", "Write base phrase contexts")
40 ("phrase_context_size,S", po::value<int>()->default_value(2), "Use this many words of context on left and write when writing base phrase contexts")
41 ("max_base_phrase_size,L", po::value<int>()->default_value(10), "Maximum starting phrase size")
42 ("max_syms,l", po::value<int>()->default_value(5), "Maximum number of symbols in final phrase size")
43 ("max_vars,v", po::value<int>()->default_value(2), "Maximum number of nonterminal variables in final phrase size")
44 ("permit_adjacent_nonterminals,A", "Permit adjacent nonterminals in source side of rules")
45 ("no_required_aligned_terminal,n", "Do not require an aligned terminal")
46 ("help,h", "Print this help message and exit");
47 po::options_description clo("Command line options");
48 po::options_description dcmdline_options;
49 dcmdline_options.add(opts);
51 po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
52 po::notify(*conf);
54 if (conf->count("help") || conf->count("input") == 0) {
55 cerr << "\nUsage: extractor [-options]\n";
56 cerr << dcmdline_options << endl;
57 exit(1);
61 // TODO how to handle alignment information?
62 void WriteBasePhrases(const AnnotatedParallelSentence& sentence,
63 const vector<ParallelSpan>& phrases) {
64 vector<WordID> e,f;
65 for (int it = 0; it < phrases.size(); ++it) {
66 const ParallelSpan& phrase = phrases[it];
67 e.clear();
68 f.clear();
69 for (int i = phrase.i1; i < phrase.i2; ++i)
70 f.push_back(sentence.f[i]);
71 for (int j = phrase.j1; j < phrase.j2; ++j)
72 e.push_back(sentence.e[j]);
73 cout << TD::GetString(f) << " ||| " << TD::GetString(e) << endl;
77 void WriteBasePhraseSpans(const AnnotatedParallelSentence& sentence,
78 const vector<ParallelSpan>& phrases) {
79 cout << TD::GetString(sentence.f) << " ||| " << TD::GetString(sentence.e) << " |||";
80 for (int it = 0; it < phrases.size(); ++it) {
81 const ParallelSpan& phrase = phrases[it];
82 cout << " " << phrase.i1 << "-" << phrase.i2
83 << "-" << phrase.j1 << "-" << phrase.j2;
85 cout << endl;
88 struct CountCombiner {
89 CountCombiner(size_t csize) : combiner_size(csize) {}
90 ~CountCombiner() {
91 if (!cache.empty()) WriteAndClearCache();
94 void Count(const vector<WordID>& key,
95 const vector<WordID>& val,
96 const int count_type,
97 const vector<pair<short,short> >& aligns) {
98 if (combiner_size > 0) {
99 RuleStatistics& v = cache[key][val];
100 float newcount = v.counts.add_value(count_type, 1.0f);
101 // hack for adding alignments
102 if (newcount < 7.0f && aligns.size() > v.aligns.size())
103 v.aligns = aligns;
104 if (cache.size() > combiner_size) WriteAndClearCache();
105 } else {
106 cout << TD::GetString(key) << '\t' << TD::GetString(val) << " ||| ";
107 cout << RuleStatistics(count_type, 1.0f, aligns) << endl;
111 private:
112 void WriteAndClearCache() {
113 for (unordered_map<vector<WordID>, Vec2PhraseCount, boost::hash<vector<WordID> > >::iterator it = cache.begin();
114 it != cache.end(); ++it) {
115 cout << TD::GetString(it->first) << '\t';
116 const Vec2PhraseCount& vals = it->second;
117 bool needdiv = false;
118 for (Vec2PhraseCount::const_iterator vi = vals.begin(); vi != vals.end(); ++vi) {
119 if (needdiv) cout << " ||| "; else needdiv = true;
120 cout << TD::GetString(vi->first) << " ||| " << vi->second;
122 cout << endl;
124 cache.clear();
127 const size_t combiner_size;
128 typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > Vec2PhraseCount;
129 unordered_map<vector<WordID>, Vec2PhraseCount, boost::hash<vector<WordID> > > cache;
132 // TODO optional source context
133 // output <k, v> : k = phrase "document" v = context "term"
134 void WritePhraseContexts(const AnnotatedParallelSentence& sentence,
135 const vector<ParallelSpan>& phrases,
136 const int ctx_size,
137 CountCombiner* o) {
138 vector<WordID> context(ctx_size * 2 + 1);
139 context[ctx_size] = kGAP;
140 vector<WordID> key;
141 key.reserve(100);
142 for (int it = 0; it < phrases.size(); ++it) {
143 const ParallelSpan& phrase = phrases[it];
145 // TODO, support src keys as well
146 key.resize(phrase.j2 - phrase.j1);
147 for (int j = phrase.j1; j < phrase.j2; ++j)
148 key[j - phrase.j1] = sentence.e[j];
150 for (int i = 0; i < ctx_size; ++i) {
151 int epos = phrase.j1 - 1 - i;
152 const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos];
153 context[ctx_size - i - 1] = left_ctx;
154 epos = phrase.j2 + i;
155 const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos];
156 context[ctx_size + i + 1] = right_ctx;
158 o->Count(key, context, kCOUNT, vector<pair<short,short> >());
162 struct SimpleRuleWriter : public Extract::RuleObserver {
163 protected:
164 virtual void CountRuleImpl(WordID lhs,
165 const vector<WordID>& rhs_f,
166 const vector<WordID>& rhs_e,
167 const vector<pair<short,short> >& fe_terminal_alignments) {
168 cout << "[" << TD::Convert(-lhs) << "] |||";
169 for (int i = 0; i < rhs_f.size(); ++i) {
170 if (rhs_f[i] < 0) cout << " [" << TD::Convert(-rhs_f[i]) << ']';
171 else cout << ' ' << TD::Convert(rhs_f[i]);
173 cout << " |||";
174 for (int i = 0; i < rhs_e.size(); ++i) {
175 if (rhs_e[i] <= 0) cout << " [" << (1-rhs_e[i]) << ']';
176 else cout << ' ' << TD::Convert(rhs_e[i]);
178 cout << " |||";
179 for (int i = 0; i < fe_terminal_alignments.size(); ++i) {
180 cout << ' ' << fe_terminal_alignments[i].first << '-' << fe_terminal_alignments[i].second;
182 cout << endl;
186 struct HadoopStreamingRuleObserver : public Extract::RuleObserver {
187 HadoopStreamingRuleObserver(CountCombiner* cc, bool bidir_flag) :
188 bidir(bidir_flag),
189 kF(TD::Convert("F")),
190 kE(TD::Convert("E")),
191 kDIVIDER(TD::Convert("|||")),
192 kLB("["), kRB("]"),
193 combiner(*cc),
194 kEMPTY(),
195 kCFE(FD::Convert("CFE")) {
196 for (int i=1; i < 50; ++i)
197 index2sym[1-i] = TD::Convert(kLB + boost::lexical_cast<string>(i) + kRB);
198 fmajor_key.resize(10, kF);
199 emajor_key.resize(10, kE);
200 if (bidir)
201 fmajor_key[2] = emajor_key[2] = kDIVIDER;
202 else
203 fmajor_key[1] = kDIVIDER;
206 protected:
207 virtual void CountRuleImpl(WordID lhs,
208 const vector<WordID>& rhs_f,
209 const vector<WordID>& rhs_e,
210 const vector<pair<short,short> >& fe_terminal_alignments) {
211 if (bidir) { // extract rules in "both directions" E->F and F->E
212 fmajor_key.resize(3 + rhs_f.size());
213 emajor_key.resize(3 + rhs_e.size());
214 fmajor_val.resize(rhs_e.size());
215 emajor_val.resize(rhs_f.size());
216 emajor_key[1] = fmajor_key[1] = MapSym(lhs);
217 int nt = 1;
218 for (int i = 0; i < rhs_f.size(); ++i) {
219 const WordID id = rhs_f[i];
220 if (id < 0) {
221 fmajor_key[3 + i] = MapSym(id, nt);
222 emajor_val[i] = MapSym(id, nt);
223 ++nt;
224 } else {
225 fmajor_key[3 + i] = id;
226 emajor_val[i] = id;
229 for (int i = 0; i < rhs_e.size(); ++i) {
230 WordID id = rhs_e[i];
231 if (id <= 0) {
232 fmajor_val[i] = index2sym[id];
233 emajor_key[3 + i] = index2sym[id];
234 } else {
235 fmajor_val[i] = id;
236 emajor_key[3 + i] = id;
239 combiner.Count(fmajor_key, fmajor_val, kCFE, fe_terminal_alignments);
240 combiner.Count(emajor_key, emajor_val, kCFE, kEMPTY);
241 } else { // extract rules only in F->E
242 fmajor_key.resize(2 + rhs_f.size());
243 fmajor_val.resize(rhs_e.size());
244 fmajor_key[0] = MapSym(lhs);
245 int nt = 1;
246 for (int i = 0; i < rhs_f.size(); ++i) {
247 const WordID id = rhs_f[i];
248 if (id < 0)
249 fmajor_key[2 + i] = MapSym(id, nt++);
250 else
251 fmajor_key[2 + i] = id;
253 for (int i = 0; i < rhs_e.size(); ++i) {
254 const WordID id = rhs_e[i];
255 if (id <= 0)
256 fmajor_val[i] = index2sym[id];
257 else
258 fmajor_val[i] = id;
260 combiner.Count(fmajor_key, fmajor_val, kCFE, fe_terminal_alignments);
264 private:
265 WordID MapSym(WordID sym, int ind = 0) {
266 WordID& r = cat2ind2sym[sym][ind];
267 if (!r) {
268 if (ind == 0)
269 r = TD::Convert(kLB + TD::Convert(-sym) + kRB);
270 else
271 r = TD::Convert(kLB + TD::Convert(-sym) + "," + boost::lexical_cast<string>(ind) + kRB);
273 return r;
276 const bool bidir;
277 const WordID kF, kE, kDIVIDER;
278 const string kLB, kRB;
279 CountCombiner& combiner;
280 const vector<pair<short,short> > kEMPTY;
281 const int kCFE;
282 map<WordID, map<int, WordID> > cat2ind2sym;
283 map<int, WordID> index2sym;
284 vector<WordID> emajor_key, emajor_val, fmajor_key, fmajor_val;
287 int main(int argc, char** argv) {
288 po::variables_map conf;
289 InitCommandLine(argc, argv, &conf);
290 kBOS = TD::Convert("<s>");
291 kEOS = TD::Convert("</s>");
292 kDIVIDER = TD::Convert("|||");
293 kGAP = TD::Convert("<PHRASE>");
294 kCOUNT = FD::Convert("C");
296 WordID default_cat = 0; // 0 means no default- extraction will
297 // fail if a phrase is extracted without a
298 // category
299 if (conf.count("default_category")) {
300 string sdefault_cat = conf["default_category"].as<string>();
301 default_cat = -TD::Convert(sdefault_cat);
302 cerr << "Default category: " << sdefault_cat << endl;
303 } else {
304 cerr << "No default category (use --default_category if you want to set one)\n";
306 ReadFile rf(conf["input"].as<string>());
307 istream& in = *rf.stream();
309 char buf[MAX_LINE_LENGTH];
310 AnnotatedParallelSentence sentence;
311 vector<ParallelSpan> phrases;
312 const int max_base_phrase_size = conf["max_base_phrase_size"].as<int>();
313 const bool write_phrase_contexts = conf.count("phrase_context") > 0;
314 const bool write_base_phrases = conf.count("base_phrase") > 0;
315 const bool write_base_phrase_spans = conf.count("base_phrase_spans") > 0;
316 const bool loose_phrases = conf.count("loose") > 0;
317 const bool silent = conf.count("silent") > 0;
318 const int max_syms = conf["max_syms"].as<int>();
319 const int max_vars = conf["max_vars"].as<int>();
320 const int ctx_size = conf["phrase_context_size"].as<int>();
321 const bool permit_adjacent_nonterminals = conf.count("permit_adjacent_nonterminals") > 0;
322 const bool require_aligned_terminal = conf.count("no_required_aligned_terminal") == 0;
323 int line = 0;
324 CountCombiner cc(conf["combiner_size"].as<size_t>());
325 HadoopStreamingRuleObserver o(&cc,
326 conf.count("bidir") > 0);
327 //SimpleRuleWriter o;
328 while(in) {
329 ++line;
330 in.getline(buf, MAX_LINE_LENGTH);
331 if (buf[0] == 0) continue;
332 if (!silent) {
333 if (line % 200 == 0) cerr << '.';
334 if (line % 8000 == 0) cerr << " [" << line << "]\n" << flush;
336 sentence.ParseInputLine(buf);
337 phrases.clear();
338 Extract::ExtractBasePhrases(max_base_phrase_size, sentence, &phrases);
339 if (loose_phrases)
340 Extract::LoosenPhraseBounds(sentence, max_base_phrase_size, &phrases);
341 if (phrases.empty()) {
342 cerr << "WARNING no phrases extracted line: " << line << endl;
343 continue;
345 if (write_phrase_contexts) {
346 WritePhraseContexts(sentence, phrases, ctx_size, &cc);
347 continue;
349 if (write_base_phrases) {
350 WriteBasePhrases(sentence, phrases);
351 continue;
353 if (write_base_phrase_spans) {
354 WriteBasePhraseSpans(sentence, phrases);
355 continue;
357 Extract::AnnotatePhrasesWithCategoryTypes(default_cat, sentence.span_types, &phrases);
358 Extract::ExtractConsistentRules(sentence, phrases, max_vars, max_syms, permit_adjacent_nonterminals, require_aligned_terminal, &o);
360 if (!silent) cerr << endl;
361 return 0;