test
[ws10smt.git] / extools / mr_stripe_rule_reduce.cc
blobeaf1b6d77a6ead1fce50d88f36fa65b5c54a069c
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include <tr1/unordered_map>

#include <boost/functional/hash.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "tdict.h"
#include "sentence_pair.h"
#include "fdict.h"
#include "extract.h"

using namespace std;
using namespace std::tr1;
namespace po = boost::program_options;
// Line-length limit in bytes (64 million).
static const size_t MAX_LINE_LENGTH = 64 * 1000 * 1000;

// Set from the --use_hadoop_counters flag in main. When true, WriteKeyValue
// also writes a "reporter:counter:..." line to stderr for every record.
bool use_hadoop_counters = false;
namespace {

// True only for the two intra-line separators this tool recognises: a space
// or a tab. Newlines and other control characters are NOT whitespace here.
inline bool IsWhitespace(char c) {
  switch (c) {
    case ' ':
    case '\t':
      return true;
    default:
      return false;
  }
}

// Advances *ptr past a (possibly empty) run of spaces/tabs in buf, stopping
// at the first other character or at the terminating NUL.
inline void SkipWhitespace(const char* buf, int* ptr) {
  int pos = *ptr;
  for (; buf[pos] != '\0'; ++pos) {
    if (!IsWhitespace(buf[pos])) break;
  }
  *ptr = pos;
}

}  // namespace
31 void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
32 po::options_description opts("Configuration options");
33 opts.add_options()
34 ("phrase_marginals,p", "Compute phrase marginals")
35 ("use_hadoop_counters,C", "Enable this if running inside Hadoop")
36 ("bidir,b", "Rules are tagged as being F->E or E->F, invert E rules in output")
37 ("help,h", "Print this help message and exit");
38 po::options_description clo("Command line options");
39 po::options_description dcmdline_options;
40 dcmdline_options.add(opts);
42 po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
43 po::notify(*conf);
45 if (conf->count("help")) {
46 cerr << "\nUsage: mr_stripe_rule_reduce [-options]\n";
47 cerr << dcmdline_options << endl;
48 exit(1);
52 typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics;
54 void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) {
55 for (ID2RuleStatistics::const_iterator it = v.begin(); it != v.end(); ++it) {
56 RuleStatistics& dest = (*self)[it->first];
57 dest += it->second;
58 // TODO - do something smarter about alignments?
59 if (dest.aligns.empty() && !it->second.aligns.empty())
60 dest.aligns = it->second.aligns;
64 int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) {
65 static const WordID kDIV = TD::Convert("|||");
66 int ptr = sstart;
67 while(ptr < end) {
68 while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; }
69 int start = ptr;
70 while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; }
71 if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; }
72 const WordID w = TD::Convert(string(buf, start, ptr - start));
73 if (w == kDIV) return ptr;
74 p->push_back(w);
76 return ptr;
79 void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) {
80 static const WordID kDIV = TD::Convert("|||");
81 counts->clear();
82 int ptr = 0;
83 while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; }
84 if (buf[ptr] != '\t') {
85 cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl;
86 exit(1);
88 cur_key->clear();
89 // key is: "[X] ||| word word word"
90 int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key);
91 if (buf[tmpp] != '\t') {
92 cur_key->push_back(kDIV);
93 ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key);
95 ++ptr;
96 int start = ptr;
97 int end = ptr;
98 int state = 0; // 0=reading label, 1=reading count
99 vector<WordID> name;
100 while(buf[ptr] != 0) {
101 while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; }
102 if (buf[ptr] == '|') {
103 ++ptr;
104 if (buf[ptr] == '|') {
105 ++ptr;
106 if (buf[ptr] == '|') {
107 ++ptr;
108 end = ptr - 3;
109 while (end > start && IsWhitespace(buf[end-1])) { --end; }
110 if (start == end) {
111 cerr << "Got empty token!\n LINE=" << buf << endl;
112 exit(1);
114 switch (state) {
115 case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
116 case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
117 default: cerr << "Can't happen\n"; abort();
119 SkipWhitespace(buf, &ptr);
120 start = ptr;
125 end=ptr;
126 while (end > start && IsWhitespace(buf[end-1])) { --end; }
127 if (end > start) {
128 switch (state) {
129 case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
130 case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
131 default: cerr << "Can't happen\n"; abort();
136 void WriteKeyValue(const vector<WordID>& key, const ID2RuleStatistics& val) {
137 cout << TD::GetString(key) << '\t';
138 bool needdiv = false;
139 for (ID2RuleStatistics::const_iterator it = val.begin(); it != val.end(); ++it) {
140 if (needdiv) cout << " ||| "; else needdiv = true;
141 cout << TD::GetString(it->first) << " ||| " << it->second;
143 cout << endl;
144 if (use_hadoop_counters) cerr << "reporter:counter:UserCounters,RuleCount," << val.size() << endl;
147 void DoPhraseMarginals(const vector<WordID>& key, const bool bidir, ID2RuleStatistics* val) {
148 static const WordID kF = TD::Convert("F");
149 static const WordID kE = TD::Convert("E");
150 static const int kCF = FD::Convert("CF");
151 static const int kCE = FD::Convert("CE");
152 static const int kCFE = FD::Convert("CFE");
153 assert(key.size() > 0);
154 int cur_marginal_id = kCF;
155 if (bidir) {
156 if (key[0] != kF && key[0] != kE) {
157 cerr << "DoPhraseMarginals expects keys to have the from 'F|E [NT] word word word'\n";
158 cerr << " but got: " << TD::GetString(key) << endl;
159 exit(1);
161 if (key[0] == kE) cur_marginal_id = kCE;
163 double tot = 0;
164 for (ID2RuleStatistics::iterator it = val->begin(); it != val->end(); ++it)
165 tot += it->second.counts.value(kCFE);
166 for (ID2RuleStatistics::iterator it = val->begin(); it != val->end(); ++it) {
167 it->second.counts.set_value(cur_marginal_id, tot);
169 // prevent double counting of the joint
170 if (cur_marginal_id == kCE) it->second.counts.clear_value(kCFE);
174 void WriteWithInversions(const vector<WordID>& key, const ID2RuleStatistics& val) {
175 static const WordID kE = TD::Convert("E");
176 static const WordID kDIV = TD::Convert("|||");
177 vector<WordID> new_key(key.size() - 1);
178 for (int i = 1; i < key.size(); ++i)
179 new_key[i - 1] = key[i];
180 const bool do_invert = (key[0] == kE);
181 if (!do_invert) {
182 WriteKeyValue(new_key, val);
183 } else {
184 ID2RuleStatistics inv;
185 assert(new_key.size() > 2);
186 vector<WordID> tk(new_key.size() - 2);
187 for (int i = 0; i < tk.size(); ++i)
188 tk[i] = new_key[2 + i];
189 RuleStatistics& inv_stats = inv[tk];
190 for (ID2RuleStatistics::const_iterator it = val.begin(); it != val.end(); ++it) {
191 inv_stats.counts = it->second.counts;
192 vector<WordID> ekey(2 + it->first.size());
193 ekey[0] = key[1];
194 ekey[1] = kDIV;
195 for (int i = 0; i < it->first.size(); ++i)
196 ekey[2+i] = it->first[i];
197 WriteKeyValue(ekey, inv);
202 int main(int argc, char** argv) {
203 po::variables_map conf;
204 InitCommandLine(argc, argv, &conf);
206 char* buf = new char[MAX_LINE_LENGTH];
207 ID2RuleStatistics acc, cur_counts;
208 vector<WordID> key, cur_key;
209 int line = 0;
210 use_hadoop_counters = conf.count("use_hadoop_counters") > 0;
211 const bool phrase_marginals = conf.count("phrase_marginals") > 0;
212 const bool bidir = conf.count("bidir") > 0;
213 while(cin) {
214 ++line;
215 cin.getline(buf, MAX_LINE_LENGTH);
216 if (buf[0] == 0) continue;
217 ParseLine(buf, &cur_key, &cur_counts);
218 if (cur_key != key) {
219 if (key.size() > 0) {
220 if (phrase_marginals)
221 DoPhraseMarginals(key, bidir, &acc);
222 if (bidir)
223 WriteWithInversions(key, acc);
224 else
225 WriteKeyValue(key, acc);
226 acc.clear();
228 key = cur_key;
230 PlusEquals(cur_counts, &acc);
232 if (key.size() > 0) {
233 if (phrase_marginals)
234 DoPhraseMarginals(key, bidir, &acc);
235 if (bidir)
236 WriteWithInversions(key, acc);
237 else
238 WriteKeyValue(key, acc);
240 return 0;