5 #include <tr1/unordered_map>
7 #include <boost/functional/hash.hpp>
8 #include <boost/program_options.hpp>
9 #include <boost/program_options/variables_map.hpp>
12 #include "sentence_pair.h"
17 using namespace std::tr1
;
18 namespace po
= boost::program_options
;
// Capacity, in chars, of the single input-line buffer that main() allocates
// and hands to cin.getline().
static const size_t MAX_LINE_LENGTH = 64000000;

// When true, WriteKeyValue() additionally writes a
// "reporter:counter:UserCounters,RuleCount,<n>" line to cerr.
// main() sets it from the use_hadoop_counters command-line option.
bool use_hadoop_counters = false;
// Returns true when c is a blank that separates tokens inside a field:
// a space or a horizontal tab. Newlines and NUL are not treated as whitespace.
inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
27 inline void SkipWhitespace(const char* buf
, int* ptr
) {
28 while (buf
[*ptr
] && IsWhitespace(buf
[*ptr
])) { ++(*ptr
); }
// Builds the option table for this tool, parses argc/argv into *conf, and writes a
// usage message to cerr when the "help" option was supplied.
// NOTE(review): the listing numbers embedded in this text skip 33, 41, 43-44 and 48-51,
// so some statements (including whatever call the option entries below chain from, and
// the closing braces) are not visible here; restore them from the pristine file.
31 void InitCommandLine(int argc
, char** argv
, po::variables_map
* conf
) {
32 po::options_description
opts("Configuration options");
// Option entries, each given as ("long,short", "help text"); none attaches a value type.
34 ("phrase_marginals,p", "Compute phrase marginals")
35 ("use_hadoop_counters,C", "Enable this if running inside Hadoop")
36 ("bidir,b", "Rules are tagged as being F->E or E->F, invert E rules in output")
37 ("help,h", "Print this help message and exit");
// clo is declared but not referenced again in the visible lines.
38 po::options_description
clo("Command line options");
// dcmdline_options aggregates opts and is the description handed to the parser below.
39 po::options_description dcmdline_options
;
40 dcmdline_options
.add(opts
);
// Parse the command line against dcmdline_options and record the result in *conf.
42 po::store(parse_command_line(argc
, argv
, dcmdline_options
), *conf
);
// When "help" was given, print the usage banner followed by the option summary.
45 if (conf
->count("help")) {
46 cerr
<< "\nUsage: mr_stripe_rule_reduce [-options]\n";
47 cerr
<< dcmdline_options
<< endl
;
52 typedef unordered_map
<vector
<WordID
>, RuleStatistics
, boost::hash
<vector
<WordID
> > > ID2RuleStatistics
;
// Folds the entries of v into *self, visiting every (key, statistics) pair of v.
//   v    - source table; not modified.
//   self - destination table; an entry is created for each key of v that it lacks
//          (operator[] default-constructs it).
// NOTE(review): listing line 57 and the closing braces (61-62) are not visible in this
// text; presumably the statistics themselves are accumulated into dest there - confirm
// against the pristine file.
54 void PlusEquals(const ID2RuleStatistics
& v
, ID2RuleStatistics
* self
) {
55 for (ID2RuleStatistics::const_iterator it
= v
.begin(); it
!= v
.end(); ++it
) {
// Look up (or create) the destination entry for this key.
56 RuleStatistics
& dest
= (*self
)[it
->first
];
58 // TODO - do something smarter about alignments?
// Alignments are only copied when the destination has none yet and the source has some;
// an existing destination alignment is never overwritten.
59 if (dest
.aligns
.empty() && !it
->second
.aligns
.empty())
60 dest
.aligns
= it
->second
.aligns
;
// Tokenizes buf on spaces/tabs within [sstart, end), converting each token to a WordID
// with TD::Convert, and returns a buffer index (see the return statements below).
//   buf    - character buffer being scanned.
//   sstart - index at which scanning is meant to begin.
//   end    - one-past-the-last index that may be read.
//   p      - output vector for the phrase.
// NOTE(review): listing lines 66-67, 69 and 74-77 are not visible in this text. The
// declarations/initialisation of `ptr` and `start`, any enclosing loop, the use of `p`,
// and the final return are therefore missing here; restore them from the pristine file.
64 int ReadPhraseUntilDividerOrEnd(const char* buf
, const int sstart
, const int end
, vector
<WordID
>* p
) {
// WordID of the "|||" divider token; converted once and cached in a function-local static.
65 static const WordID kDIV
= TD::Convert("|||");
// Skip leading spaces/tabs.
68 while(ptr
< end
&& IsWhitespace(buf
[ptr
])) { ++ptr
; }
// Advance to the end of the current token (next space/tab or `end`).
70 while(ptr
< end
&& !IsWhitespace(buf
[ptr
])) { ++ptr
; }
// No characters consumed for this token: warn and return the current index.
71 if (ptr
== start
) {cerr
<< "Warning! empty token.\n"; return ptr
; }
// NOTE(review): string(buf, start, len) with a const char* first argument selects the
// (const string&, pos, len) constructor, so buf is first converted to a temporary
// std::string up to its NUL; string(buf + start, len) would avoid that copy.
72 const WordID w
= TD::Convert(string(buf
, start
, ptr
- start
));
// Stop at the divider; the returned index is just past the "|||" token.
73 if (w
== kDIV
) return ptr
;
// Parses one NUL-terminated input record, which is expected to hold a key and a value
// separated by a tab.
//   buf     - the record text.
//   cur_key - receives the WordIDs of the key (filled via ReadPhraseUntilDividerOrEnd).
//   counts  - table whose entries are filled via RuleStatistics::ParseRuleStatistics.
// NOTE(review): many listing lines (81-82, 86-88, 94-97, 99, 103, 105, 107-108, 110,
// 112-114, 118, 120-125, 127-128, 132-135) are not visible in this text. The
// declarations of `ptr`, `start`, `end` and `name`, the switch headers, and the closing
// braces are therefore missing here; restore them from the pristine file.
79 void ParseLine(const char* buf
, vector
<WordID
>* cur_key
, ID2RuleStatistics
* counts
) {
// WordID of the "|||" divider token; converted once and cached in a function-local static.
80 static const WordID kDIV
= TD::Convert("|||");
// Advance ptr to the first tab, or to the terminating NUL if there is none.
83 while(buf
[ptr
] != 0 && buf
[ptr
] != '\t') { ++ptr
; }
// A record without a tab has no value part: report it together with the offending input.
84 if (buf
[ptr
] != '\t') {
85 cerr
<< "Missing tab separator between key and value!\n INPUT=" << buf
<< endl
;
89 // key is: "[X] ||| word word word"
// Read the part of the key before the first divider (or the whole key if it has none).
90 int tmpp
= ReadPhraseUntilDividerOrEnd(buf
, 0, ptr
, cur_key
);
// Reading stopped before the tab, i.e. at a divider: keep the divider in the key and
// continue reading the remainder of the key up to the tab.
91 if (buf
[tmpp
] != '\t') {
92 cur_key
->push_back(kDIV
);
93 ReadPhraseUntilDividerOrEnd(buf
, tmpp
, ptr
, cur_key
);
98 int state
= 0; // 0=reading label, 1=reading count
// Walk the rest of the record until the terminating NUL.
100 while(buf
[ptr
] != 0) {
// Advance to the next '|' (or to the NUL).
101 while(buf
[ptr
] != 0 && buf
[ptr
] != '|') { ++ptr
; }
// Three nested '|' tests; the statements between them are not visible in this text
// (presumably ptr is advanced so that a full "|||" is recognised - confirm).
102 if (buf
[ptr
] == '|') {
104 if (buf
[ptr
] == '|') {
106 if (buf
[ptr
] == '|') {
// Trim trailing spaces/tabs from the segment [start, end).
109 while (end
> start
&& IsWhitespace(buf
[end
-1])) { --end
; }
111 cerr
<< "Got empty token!\n LINE=" << buf
<< endl
;
// state 0: the segment is a label - parse it into `name`, then expect a count next.
115 case 0: ++state
; name
.clear(); ReadPhraseUntilDividerOrEnd(buf
, start
, end
, &name
); break;
// state 1: the segment holds statistics - parse it into the entry for `name`, then
// expect a label next.
116 case 1: --state
; (*counts
)[name
].ParseRuleStatistics(buf
, start
, end
); break;
117 default: cerr
<< "Can't happen\n"; abort();
// Skip spaces/tabs before the next segment.
119 SkipWhitespace(buf
, &ptr
);
// Same trimming and state dispatch as above, applied a second time (listing lines
// 126-131); presumably this handles the final segment that is not followed by a
// divider - confirm against the pristine file.
126 while (end
> start
&& IsWhitespace(buf
[end
-1])) { --end
; }
129 case 0: ++state
; name
.clear(); ReadPhraseUntilDividerOrEnd(buf
, start
, end
, &name
); break;
130 case 1: --state
; (*counts
)[name
].ParseRuleStatistics(buf
, start
, end
); break;
131 default: cerr
<< "Can't happen\n"; abort();
// Writes one record to cout: the key as text, a tab, then every entry of val as
// "<entry key> ||| <entry statistics>", with " ||| " between consecutive entries.
//   key - word-ID sequence printed (via TD::GetString) before the tab.
//   val - table whose entries make up the value part; iteration order is that of the map.
// NOTE(review): listing lines 142-143 and 145 are not visible in this text, so the loop's
// closing brace, anything written after the loop, and the function's closing brace are
// missing here; restore them from the pristine file.
136 void WriteKeyValue(const vector
<WordID
>& key
, const ID2RuleStatistics
& val
) {
137 cout
<< TD::GetString(key
) << '\t';
// needdiv is false only for the first entry, so the separator goes between entries only.
138 bool needdiv
= false;
139 for (ID2RuleStatistics::const_iterator it
= val
.begin(); it
!= val
.end(); ++it
) {
140 if (needdiv
) cout
<< " ||| "; else needdiv
= true;
141 cout
<< TD::GetString(it
->first
) << " ||| " << it
->second
;
// With use_hadoop_counters set, also write a counter line carrying the number of
// entries in val to cerr.
144 if (use_hadoop_counters
) cerr
<< "reporter:counter:UserCounters,RuleCount," << val
.size() << endl
;
// Sums the "CFE" count over all entries of *val and stores that total in every entry
// under a marginal feature id ("CF" by default, "CE" when the key starts with "E").
//   key   - must be non-empty (asserted); key[0] is compared against the WordIDs of "F"
//           and "E".
//   bidir - not referenced in the visible lines.
//   val   - table updated in place.
// NOTE(review): listing lines 155, 159-160, 162-163, 168 and 171-173 are not visible in
// this text. The declaration of `tot`, whatever uses `bidir`, the handling after the
// error message, and the closing braces are therefore missing here; restore them from
// the pristine file.
147 void DoPhraseMarginals(const vector
<WordID
>& key
, const bool bidir
, ID2RuleStatistics
* val
) {
// Word and feature ids are converted once and cached in function-local statics.
148 static const WordID kF
= TD::Convert("F");
149 static const WordID kE
= TD::Convert("E");
150 static const int kCF
= FD::Convert("CF");
151 static const int kCE
= FD::Convert("CE");
152 static const int kCFE
= FD::Convert("CFE");
153 assert(key
.size() > 0);
// Default marginal is "CF"; switched to "CE" below when the key is tagged "E".
154 int cur_marginal_id
= kCF
;
// A key whose first word is neither "F" nor "E" is reported on cerr.
156 if (key
[0] != kF
&& key
[0] != kE
) {
157 cerr
<< "DoPhraseMarginals expects keys to have the from 'F|E [NT] word word word'\n";
158 cerr
<< " but got: " << TD::GetString(key
) << endl
;
161 if (key
[0] == kE
) cur_marginal_id
= kCE
;
// First pass: accumulate the joint ("CFE") count of every entry into tot.
164 for (ID2RuleStatistics::iterator it
= val
->begin(); it
!= val
->end(); ++it
)
165 tot
+= it
->second
.counts
.value(kCFE
);
// Second pass: write the total into each entry under the chosen marginal id.
166 for (ID2RuleStatistics::iterator it
= val
->begin(); it
!= val
->end(); ++it
) {
167 it
->second
.counts
.set_value(cur_marginal_id
, tot
);
169 // prevent double counting of the joint
170 if (cur_marginal_id
== kCE
) it
->second
.counts
.clear_value(kCFE
);
// Writes the record for key/val with the leading direction tag (key[0]) removed and,
// for keys tagged "E", writes per-entry records built from the entries of val.
//   key - tagged key; key[0] is compared against the WordID of "E".
//   val - entries to write.
// NOTE(review): listing lines 181, 183, 193-194 and 198-201 are not visible in this text.
// The branch on do_invert, the assignments to ekey[0] and ekey[1], and the closing
// braces are therefore missing here, and kDIV has no visible use; restore them from the
// pristine file.
// NOTE(review): key.size() - 1 is computed before any visible emptiness check; with an
// empty key the unsigned subtraction would wrap - confirm callers never pass one.
174 void WriteWithInversions(const vector
<WordID
>& key
, const ID2RuleStatistics
& val
) {
175 static const WordID kE
= TD::Convert("E");
176 static const WordID kDIV
= TD::Convert("|||");
// new_key = key without its first element.
177 vector
<WordID
> new_key(key
.size() - 1);
178 for (int i
= 1; i
< key
.size(); ++i
)
179 new_key
[i
- 1] = key
[i
];
180 const bool do_invert
= (key
[0] == kE
);
// Write the untagged key with val unchanged.
182 WriteKeyValue(new_key
, val
);
// inv is a one-entry table reused for every record written in the loop below.
184 ID2RuleStatistics inv
;
185 assert(new_key
.size() > 2);
// tk = new_key without its first two elements; it becomes the single key of inv.
186 vector
<WordID
> tk(new_key
.size() - 2);
187 for (int i
= 0; i
< tk
.size(); ++i
)
188 tk
[i
] = new_key
[2 + i
];
189 RuleStatistics
& inv_stats
= inv
[tk
];
// For each entry of val: copy its counts into inv's single entry, build ekey with the
// entry's own key placed from index 2 onward, and write ekey with inv.
190 for (ID2RuleStatistics::const_iterator it
= val
.begin(); it
!= val
.end(); ++it
) {
191 inv_stats
.counts
= it
->second
.counts
;
192 vector
<WordID
> ekey(2 + it
->first
.size());
195 for (int i
= 0; i
< it
->first
.size(); ++i
)
196 ekey
[2+i
] = it
->first
[i
];
197 WriteKeyValue(ekey
, inv
);
// Entry point: reads records from cin, accumulates the statistics of consecutive records
// into acc, and writes the accumulated record whenever the key changes and once more
// for the last key.
// NOTE(review): listing lines 205, 209, 213-214, 222, 224, 226-229, 231, 235, 237 and
// 239 onward are not visible in this text. The read-loop header, the branches selecting
// between WriteWithInversions and WriteKeyValue, the statements that reset
// acc/key/cur_key/cur_counts, and the end of the function are therefore missing here;
// restore them from the pristine file.
202 int main(int argc
, char** argv
) {
203 po::variables_map conf
;
204 InitCommandLine(argc
, argv
, &conf
);
// One line buffer of MAX_LINE_LENGTH chars; no matching delete[] appears in the
// visible lines.
206 char* buf
= new char[MAX_LINE_LENGTH
];
// acc holds the statistics gathered for `key`; cur_counts/cur_key are filled by
// ParseLine for each record read.
207 ID2RuleStatistics acc
, cur_counts
;
208 vector
<WordID
> key
, cur_key
;
// Each flag is true when its option was present on the command line.
210 use_hadoop_counters
= conf
.count("use_hadoop_counters") > 0;
211 const bool phrase_marginals
= conf
.count("phrase_marginals") > 0;
212 const bool bidir
= conf
.count("bidir") > 0;
// Read one line, of at most MAX_LINE_LENGTH - 1 characters, into buf.
215 cin
.getline(buf
, MAX_LINE_LENGTH
);
// Empty lines are skipped.
216 if (buf
[0] == 0) continue;
217 ParseLine(buf
, &cur_key
, &cur_counts
);
// Key changed: write what was accumulated for the previous key, if there was one.
218 if (cur_key
!= key
) {
219 if (key
.size() > 0) {
220 if (phrase_marginals
)
221 DoPhraseMarginals(key
, bidir
, &acc
);
223 WriteWithInversions(key
, acc
);
225 WriteKeyValue(key
, acc
);
// Fold this record's statistics into the accumulator.
230 PlusEquals(cur_counts
, &acc
);
// Same write-out as above (listing lines 232-238), for the key still held in `key`.
232 if (key
.size() > 0) {
233 if (phrase_marginals
)
234 DoPhraseMarginals(key
, bidir
, &acc
);
236 WriteWithInversions(key
, acc
);
238 WriteKeyValue(key
, acc
);