4 #include <tr1/unordered_map>
6 #include <boost/functional/hash.hpp>
7 #include <boost/program_options.hpp>
8 #include <boost/program_options/variables_map.hpp>
9 #include <boost/lexical_cast.hpp>
11 #include "sparse_vector.h"
12 #include "sentence_pair.h"
// Bring TR1 names (unordered_map, used by CountCombiner below) into scope;
// this code predates std::unordered_map.
21 using namespace std::tr1
;
// Short alias for Boost.Program_options, used throughout InitCommandLine/main.
22 namespace po
= boost::program_options
;
// Size of the fixed line buffer used when reading input in main().
24 static const size_t MAX_LINE_LENGTH
= 100000;
// Sentinel word ids; initialized in main() via TD::Convert:
// kBOS = "<s>", kEOS = "</s>", kDIVIDER = "|||", kGAP = "<PHRASE>".
25 WordID kBOS
, kEOS
, kDIVIDER
, kGAP
;
// Parse command-line options into *conf using Boost.Program_options.
// Prints usage to stderr when --help is given or no input is configured.
// (NOTE(review): the opts.add_options() call that the option chain below
// attaches to is not visible in this excerpt.)
28 void InitCommandLine(int argc
, char** argv
, po::variables_map
* conf
) {
29 po::options_description
opts("Configuration options");
// Option chain: each ("name,short", value-spec, help-text) entry below
// registers one option on `opts`.
31 ("input,i", po::value
<string
>()->default_value("-"), "Input file")
32 ("default_category,d", po::value
<string
>(), "Default span type (use X for 'Hiero')")
33 ("loose", "Use loose phrase extraction heuristic for base phrases")
34 ("base_phrase,B", "Write base phrases")
35 ("base_phrase_spans", "Write base sentences and phrase spans")
36 ("bidir,b", "Extract bidirectional rules (for computing p(f|e) in addition to p(e|f))")
37 ("combiner_size,c", po::value
<size_t>()->default_value(800000), "Number of unique items to store in cache before writing rule counts. Set to 0 to disable cache.")
38 ("silent", "Write nothing to stderr except errors")
39 ("phrase_context,C", "Write base phrase contexts")
// NOTE(review): "on left and write" in the help text below likely means
// "on left and right" — cannot change the runtime string here.
40 ("phrase_context_size,S", po::value
<int>()->default_value(2), "Use this many words of context on left and write when writing base phrase contexts")
41 ("max_base_phrase_size,L", po::value
<int>()->default_value(10), "Maximum starting phrase size")
42 ("max_syms,l", po::value
<int>()->default_value(5), "Maximum number of symbols in final phrase size")
43 ("max_vars,v", po::value
<int>()->default_value(2), "Maximum number of nonterminal variables in final phrase size")
44 ("permit_adjacent_nonterminals,A", "Permit adjacent nonterminals in source side of rules")
45 ("no_required_aligned_terminal,n", "Do not require an aligned terminal")
46 ("help,h", "Print this help message and exit");
// `clo` is declared but never added to dcmdline_options in the visible code.
47 po::options_description
clo("Command line options");
48 po::options_description dcmdline_options
;
49 dcmdline_options
.add(opts
);
// Parse argv against the combined description and store results in *conf.
51 po::store(parse_command_line(argc
, argv
, dcmdline_options
), *conf
);
// Show usage when help was requested or no input option is present.
// ("input" has a default value, so count("input") == 0 is not expected
// after store(); presumably kept as a defensive check — confirm upstream.)
54 if (conf
->count("help") || conf
->count("input") == 0) {
55 cerr
<< "\nUsage: extractor [-options]\n";
56 cerr
<< dcmdline_options
<< endl
;
61 // TODO how to handle alignment information?
// Write each extracted base phrase pair to stdout as
// "<source words> ||| <target words>", one pair per line.
// `f` and `e` are word buffers whose declarations are outside this excerpt
// (presumably vector<WordID> cleared per phrase — confirm in full source).
62 void WriteBasePhrases(const AnnotatedParallelSentence
& sentence
,
63 const vector
<ParallelSpan
>& phrases
) {
65 for (int it
= 0; it
< phrases
.size(); ++it
) {
66 const ParallelSpan
& phrase
= phrases
[it
];
// Collect source-side words in the half-open span [i1, i2).
69 for (int i
= phrase
.i1
; i
< phrase
.i2
; ++i
)
70 f
.push_back(sentence
.f
[i
]);
// Collect target-side words in the half-open span [j1, j2).
71 for (int j
= phrase
.j1
; j
< phrase
.j2
; ++j
)
72 e
.push_back(sentence
.e
[j
]);
73 cout
<< TD::GetString(f
) << " ||| " << TD::GetString(e
) << endl
;
// Write the full sentence pair followed by all phrase span coordinates:
// "<f words> ||| <e words> ||| i1-i2-j1-j2 i1-i2-j1-j2 ...".
77 void WriteBasePhraseSpans(const AnnotatedParallelSentence
& sentence
,
78 const vector
<ParallelSpan
>& phrases
) {
79 cout
<< TD::GetString(sentence
.f
) << " ||| " << TD::GetString(sentence
.e
) << " |||";
80 for (int it
= 0; it
< phrases
.size(); ++it
) {
81 const ParallelSpan
& phrase
= phrases
[it
];
// Each span is emitted as a space-prefixed "i1-i2-j1-j2" quadruple.
82 cout
<< " " << phrase
.i1
<< "-" << phrase
.i2
83 << "-" << phrase
.j1
<< "-" << phrase
.j2
;
// In-memory combiner for rule counts: accumulates (key, val) -> RuleStatistics
// in a bounded cache and flushes it to stdout when it grows past
// combiner_size. A combiner_size of 0 disables caching (lines for the
// uncached else-branch are partially outside this excerpt).
88 struct CountCombiner
{
89 CountCombiner(size_t csize
) : combiner_size(csize
) {}
// Flush any remaining cached counts (this line appears to belong to the
// destructor, whose signature is outside this excerpt — confirm).
91 if (!cache
.empty()) WriteAndClearCache();
// Record one occurrence of val for key under feature `count_type`
// (count_type parameter declaration is outside this excerpt),
// keeping the first alignment set seen while the count is small.
94 void Count(const vector
<WordID
>& key
,
95 const vector
<WordID
>& val
,
97 const vector
<pair
<short,short> >& aligns
) {
98 if (combiner_size
> 0) {
99 RuleStatistics
& v
= cache
[key
][val
];
100 float newcount
= v
.counts
.add_value(count_type
, 1.0f
);
101 // hack for adding alignments
// Only adopt `aligns` for low-count rules with more alignment points
// than currently stored (assignment line is outside this excerpt).
102 if (newcount
< 7.0f
&& aligns
.size() > v
.aligns
.size())
// Spill the cache once it exceeds the configured bound.
104 if (cache
.size() > combiner_size
) WriteAndClearCache();
// Uncached path: emit the single observation directly.
106 cout
<< TD::GetString(key
) << '\t' << TD::GetString(val
) << " ||| ";
107 cout
<< RuleStatistics(count_type
, 1.0f
, aligns
) << endl
;
// Dump every cached key on one line: "key \t val ||| stats ||| val ||| stats..."
// then clear the cache (the clear() call is outside this excerpt).
112 void WriteAndClearCache() {
113 for (unordered_map
<vector
<WordID
>, Vec2PhraseCount
, boost::hash
<vector
<WordID
> > >::iterator it
= cache
.begin();
114 it
!= cache
.end(); ++it
) {
115 cout
<< TD::GetString(it
->first
) << '\t';
116 const Vec2PhraseCount
& vals
= it
->second
;
// " ||| " separates the val/stat pairs after the first one.
117 bool needdiv
= false;
118 for (Vec2PhraseCount::const_iterator vi
= vals
.begin(); vi
!= vals
.end(); ++vi
) {
119 if (needdiv
) cout
<< " ||| "; else needdiv
= true;
120 cout
<< TD::GetString(vi
->first
) << " ||| " << vi
->second
;
// Maximum number of distinct keys held before a flush.
127 const size_t combiner_size
;
// Map from a value phrase to its accumulated statistics.
128 typedef unordered_map
<vector
<WordID
>, RuleStatistics
, boost::hash
<vector
<WordID
> > > Vec2PhraseCount
;
// Two-level cache: key phrase -> (value phrase -> statistics).
129 unordered_map
<vector
<WordID
>, Vec2PhraseCount
, boost::hash
<vector
<WordID
> > > cache
;
132 // TODO optional source context
133 // output <k, v> : k = phrase "document" v = context "term"
// Emit, via the combiner, each target-side phrase (the key) together with
// ctx_size words of left/right context around it, with kGAP marking the
// phrase position in the middle of the context window. (The tail of the
// parameter list — ctx_size and the combiner pointer `o` — and the `key`
// declaration are outside this excerpt.)
134 void WritePhraseContexts(const AnnotatedParallelSentence
& sentence
,
135 const vector
<ParallelSpan
>& phrases
,
// Context layout: [left ctx_size words][kGAP][right ctx_size words].
138 vector
<WordID
> context(ctx_size
* 2 + 1);
139 context
[ctx_size
] = kGAP
;
142 for (int it
= 0; it
< phrases
.size(); ++it
) {
143 const ParallelSpan
& phrase
= phrases
[it
];
145 // TODO, support src keys as well
// Key = the target words of the phrase span [j1, j2).
146 key
.resize(phrase
.j2
- phrase
.j1
);
147 for (int j
= phrase
.j1
; j
< phrase
.j2
; ++j
)
148 key
[j
- phrase
.j1
] = sentence
.e
[j
];
// Fill symmetric context: position i words away on each side,
// padding with kBOS/kEOS past the sentence boundaries.
150 for (int i
= 0; i
< ctx_size
; ++i
) {
151 int epos
= phrase
.j1
- 1 - i
;
152 const WordID left_ctx
= (epos
< 0) ? kBOS
: sentence
.e
[epos
];
153 context
[ctx_size
- i
- 1] = left_ctx
;
154 epos
= phrase
.j2
+ i
;
155 const WordID right_ctx
= (epos
>= sentence
.e_len
) ? kEOS
: sentence
.e
[epos
];
156 context
[ctx_size
+ i
+ 1] = right_ctx
;
// Record (phrase, context) with the global count feature and no alignments.
158 o
->Count(key
, context
, kCOUNT
, vector
<pair
<short,short> >());
// Debug observer: prints each extracted SCFG rule directly to stdout as
// "[LHS] ||| f-side ||| e-side ||| alignments" (the " |||" separators
// between the loops are outside this excerpt).
162 struct SimpleRuleWriter
: public Extract::RuleObserver
{
164 virtual void CountRuleImpl(WordID lhs
,
165 const vector
<WordID
>& rhs_f
,
166 const vector
<WordID
>& rhs_e
,
167 const vector
<pair
<short,short> >& fe_terminal_alignments
) {
// Nonterminal ids are negative; negate before converting to a string.
168 cout
<< "[" << TD::Convert(-lhs
) << "] |||";
// Source side: negative entries are nonterminal categories, others terminals.
169 for (int i
= 0; i
< rhs_f
.size(); ++i
) {
170 if (rhs_f
[i
] < 0) cout
<< " [" << TD::Convert(-rhs_f
[i
]) << ']';
171 else cout
<< ' ' << TD::Convert(rhs_f
[i
]);
// Target side: entries <= 0 encode 1-based gap indices as (1 - id).
174 for (int i
= 0; i
< rhs_e
.size(); ++i
) {
175 if (rhs_e
[i
] <= 0) cout
<< " [" << (1-rhs_e
[i
]) << ']';
176 else cout
<< ' ' << TD::Convert(rhs_e
[i
]);
// Terminal alignment points, printed as "f-e" pairs.
179 for (int i
= 0; i
< fe_terminal_alignments
.size(); ++i
) {
180 cout
<< ' ' << fe_terminal_alignments
[i
].first
<< '-' << fe_terminal_alignments
[i
].second
;
// Observer that routes extracted rules into a CountCombiner in a
// Hadoop-streaming-friendly key/value format. In bidirectional mode it
// emits both an F-major (p(e|f)) and an E-major (p(f|e)) record per rule;
// otherwise only F-major. (Several ctor initializers — kLB/kRB, combiner,
// bidir — and the nt/terminal branch lines are outside this excerpt.)
186 struct HadoopStreamingRuleObserver
: public Extract::RuleObserver
{
187 HadoopStreamingRuleObserver(CountCombiner
* cc
, bool bidir_flag
) :
189 kF(TD::Convert("F")),
190 kE(TD::Convert("E")),
191 kDIVIDER(TD::Convert("|||")),
195 kCFE(FD::Convert("CFE")) {
// Precompute bracketed index symbols "[1]".."[49]" keyed by the
// negative/zero gap encoding (1-i), so lookup is a map read at count time.
196 for (int i
=1; i
< 50; ++i
)
197 index2sym
[1-i
] = TD::Convert(kLB
+ boost::lexical_cast
<string
>(i
) + kRB
);
// Pre-size key buffers; slot 0 carries the direction marker (F or E).
198 fmajor_key
.resize(10, kF
);
199 emajor_key
.resize(10, kE
);
// Bidirectional layout: divider sits at slot 2 (slot 1 holds the LHS);
// the presumably-else branch puts the divider at slot 1 instead — the
// branch condition lines are outside this excerpt.
201 fmajor_key
[2] = emajor_key
[2] = kDIVIDER
;
203 fmajor_key
[1] = kDIVIDER
;
// Called once per extracted rule; packs the rule into key/val vectors
// and forwards them to the combiner with feature kCFE.
207 virtual void CountRuleImpl(WordID lhs
,
208 const vector
<WordID
>& rhs_f
,
209 const vector
<WordID
>& rhs_e
,
210 const vector
<pair
<short,short> >& fe_terminal_alignments
) {
211 if (bidir
) { // extract rules in "both directions" E->F and F->E
212 fmajor_key
.resize(3 + rhs_f
.size());
213 emajor_key
.resize(3 + rhs_e
.size());
214 fmajor_val
.resize(rhs_e
.size());
215 emajor_val
.resize(rhs_f
.size());
216 emajor_key
[1] = fmajor_key
[1] = MapSym(lhs
);
// Source side: nonterminals go through MapSym with their index `nt`
// (declaration/increment outside this excerpt); terminals pass through.
218 for (int i
= 0; i
< rhs_f
.size(); ++i
) {
219 const WordID id
= rhs_f
[i
];
221 fmajor_key
[3 + i
] = MapSym(id
, nt
);
222 emajor_val
[i
] = MapSym(id
, nt
);
225 fmajor_key
[3 + i
] = id
;
// Target side: gap ids (encoded <= 0) are replaced by their "[n]" symbols.
229 for (int i
= 0; i
< rhs_e
.size(); ++i
) {
230 WordID id
= rhs_e
[i
];
232 fmajor_val
[i
] = index2sym
[id
];
233 emajor_key
[3 + i
] = index2sym
[id
];
236 emajor_key
[3 + i
] = id
;
// Alignments are only attached to the F-major record; the E-major
// record gets the shared empty vector.
239 combiner
.Count(fmajor_key
, fmajor_val
, kCFE
, fe_terminal_alignments
);
240 combiner
.Count(emajor_key
, emajor_val
, kCFE
, kEMPTY
);
241 } else { // extract rules only in F->E
242 fmajor_key
.resize(2 + rhs_f
.size());
243 fmajor_val
.resize(rhs_e
.size());
244 fmajor_key
[0] = MapSym(lhs
);
246 for (int i
= 0; i
< rhs_f
.size(); ++i
) {
247 const WordID id
= rhs_f
[i
];
249 fmajor_key
[2 + i
] = MapSym(id
, nt
++);
251 fmajor_key
[2 + i
] = id
;
253 for (int i
= 0; i
< rhs_e
.size(); ++i
) {
254 const WordID id
= rhs_e
[i
];
256 fmajor_val
[i
] = index2sym
[id
];
260 combiner
.Count(fmajor_key
, fmajor_val
, kCFE
, fe_terminal_alignments
);
// Memoized mapping from a (negative) category symbol plus occurrence index
// to a printable bracketed symbol, e.g. "[X]" or "[X,2]" (the conditions
// selecting between the two forms are outside this excerpt).
265 WordID
MapSym(WordID sym
, int ind
= 0) {
266 WordID
& r
= cat2ind2sym
[sym
][ind
];
269 r
= TD::Convert(kLB
+ TD::Convert(-sym
) + kRB
);
271 r
= TD::Convert(kLB
+ TD::Convert(-sym
) + "," + boost::lexical_cast
<string
>(ind
) + kRB
);
// Direction markers and divider symbol, fixed at construction.
277 const WordID kF
, kE
, kDIVIDER
;
// Left/right bracket strings used to render nonterminal symbols.
278 const string kLB
, kRB
;
279 CountCombiner
& combiner
;
// Shared empty alignment vector for records that carry no alignments.
280 const vector
<pair
<short,short> > kEMPTY
;
// Memo table backing MapSym.
282 map
<WordID
, map
<int, WordID
> > cat2ind2sym
;
// Gap-id (<= 0) to "[n]" symbol table, filled in the constructor.
283 map
<int, WordID
> index2sym
;
// Reusable key/value buffers, resized per rule to avoid reallocation.
284 vector
<WordID
> emajor_key
, emajor_val
, fmajor_key
, fmajor_val
;
// Entry point: parse options, initialize global sentinel symbols, then for
// each input line parse the annotated sentence pair, extract base phrases,
// optionally write phrases/spans/contexts, and extract consistent SCFG
// rules into the Hadoop-streaming observer. (The read-loop header, the
// `line` counter declaration, and closing braces are outside this excerpt.)
287 int main(int argc
, char** argv
) {
288 po::variables_map conf
;
289 InitCommandLine(argc
, argv
, &conf
);
// Initialize the global sentinel word/feature ids declared at file scope.
290 kBOS
= TD::Convert("<s>");
291 kEOS
= TD::Convert("</s>");
292 kDIVIDER
= TD::Convert("|||");
293 kGAP
= TD::Convert("<PHRASE>");
294 kCOUNT
= FD::Convert("C");
296 WordID default_cat
= 0; // 0 means no default- extraction will
297 // fail if a phrase is extracted without a
// Negated because nonterminal categories are stored as negative ids.
299 if (conf
.count("default_category")) {
300 string sdefault_cat
= conf
["default_category"].as
<string
>();
301 default_cat
= -TD::Convert(sdefault_cat
);
302 cerr
<< "Default category: " << sdefault_cat
<< endl
;
304 cerr
<< "No default category (use --default_category if you want to set one)\n";
// Open the input ("-" default presumably means stdin — confirm ReadFile).
306 ReadFile
rf(conf
["input"].as
<string
>());
307 istream
& in
= *rf
.stream();
309 char buf
[MAX_LINE_LENGTH
];
// Reused across lines; phrases is refilled per sentence.
310 AnnotatedParallelSentence sentence
;
311 vector
<ParallelSpan
> phrases
;
// Snapshot all configuration flags/limits once before the input loop.
312 const int max_base_phrase_size
= conf
["max_base_phrase_size"].as
<int>();
313 const bool write_phrase_contexts
= conf
.count("phrase_context") > 0;
314 const bool write_base_phrases
= conf
.count("base_phrase") > 0;
315 const bool write_base_phrase_spans
= conf
.count("base_phrase_spans") > 0;
316 const bool loose_phrases
= conf
.count("loose") > 0;
317 const bool silent
= conf
.count("silent") > 0;
318 const int max_syms
= conf
["max_syms"].as
<int>();
319 const int max_vars
= conf
["max_vars"].as
<int>();
320 const int ctx_size
= conf
["phrase_context_size"].as
<int>();
321 const bool permit_adjacent_nonterminals
= conf
.count("permit_adjacent_nonterminals") > 0;
322 const bool require_aligned_terminal
= conf
.count("no_required_aligned_terminal") == 0;
// Combiner caches counts up to --combiner_size before spilling to stdout.
324 CountCombiner
cc(conf
["combiner_size"].as
<size_t>());
325 HadoopStreamingRuleObserver
o(&cc
,
326 conf
.count("bidir") > 0);
327 //SimpleRuleWriter o;
// Per-line processing (loop header not visible in this excerpt).
330 in
.getline(buf
, MAX_LINE_LENGTH
);
331 if (buf
[0] == 0) continue;
// Progress dots/markers on stderr (suppressed by --silent upstream,
// presumably via a guard not visible here — confirm).
333 if (line
% 200 == 0) cerr
<< '.';
334 if (line
% 8000 == 0) cerr
<< " [" << line
<< "]\n" << flush
;
336 sentence
.ParseInputLine(buf
);
338 Extract::ExtractBasePhrases(max_base_phrase_size
, sentence
, &phrases
);
// --loose widens phrase bounds with the loose extraction heuristic.
340 Extract::LoosenPhraseBounds(sentence
, max_base_phrase_size
, &phrases
);
341 if (phrases
.empty()) {
342 cerr
<< "WARNING no phrases extracted line: " << line
<< endl
;
// Optional side outputs, each gated by its own flag.
345 if (write_phrase_contexts
) {
346 WritePhraseContexts(sentence
, phrases
, ctx_size
, &cc
);
349 if (write_base_phrases
) {
350 WriteBasePhrases(sentence
, phrases
);
353 if (write_base_phrase_spans
) {
354 WriteBasePhraseSpans(sentence
, phrases
);
// Assign span categories (default_cat fallback), then extract the
// consistent hierarchical rules, feeding each into observer `o`.
357 Extract::AnnotatePhrasesWithCategoryTypes(default_cat
, sentence
.span_types
, &phrases
);
358 Extract::ExtractConsistentRules(sentence
, phrases
, max_vars
, max_syms
, permit_adjacent_nonterminals
, require_aligned_terminal
, &o
);
360 if (!silent
) cerr
<< endl
;