4 #include <boost/program_options.hpp>
10 #include "stringlib.h"
14 namespace po
= boost::program_options
;
// Declares the command-line options, parses argv against them and stores the
// parsed values in *conf.
//   argc, argv - forwarded unchanged to parse_command_line.
//   conf       - output; receives the parsed option values via po::store.
// NOTE(review): this listing has gaps (the embedded numbers skip 18, 23, 27
// and everything after 29). The statement that opens the option chain, the
// declaration of flag, the declaration of the help option and the rest of
// the help branch are therefore not visible here.
16 void InitCommandLine(int argc
, char** argv
, po::variables_map
* conf
) {
17 po::options_description
opts("Configuration options");
// Options carrying a default: scale / a (double, 1.0), loss_function / l
// (string, bleu) and input / i (string, a single dash).
19 ("scale,a",po::value
<double>()->default_value(1.0), "Posterior scaling factor (alpha)")
20 ("loss_function,l",po::value
<string
>()->default_value("bleu"), "Loss function")
21 ("input,i",po::value
<string
>()->default_value("-"), "File to read k-best lists from")
// output_list / L is declared with a description only, no value type.
22 ("output_list,L", "Show reranked list as output")
// The options are added to a second description object, which is the one
// actually handed to the parser and printed below.
24 po::options_description dcmdline_options
;
25 dcmdline_options
.add(opts
);
26 po::store(parse_command_line(argc
, argv
, dcmdline_options
), *conf
);
// Write the option summary to stderr when help is present in *conf or when
// flag is set (flag is declared in a line not shown in this listing).
28 if (flag
|| conf
->count("help")) {
29 cerr
<< dcmdline_options
<< endl
;
// Comparison predicate on (word-id sequence, double) pairs: returns true when
// the second component of a is smaller than the second component of b, so a
// sort using it orders ascending by that value. The word sequences are not
// looked at.
// NOTE(review): the enclosing type header is not visible in this listing;
// presumably this is the call operator of the LossComparer that main passes
// to sort - confirm. The vector sorted there holds
// pair<vector<WordID>, prob_t>, not pair<vector<WordID>, double>, so check
// whether every comparison has to convert (and thereby copy) its arguments.
35 bool operator()(const pair
<vector
<WordID
>, double>& a
, const pair
<vector
<WordID
>, double>& b
) const {
36 return a
.second
< b
.second
;
// Collects k-best entries into *list and reports whether any were collected.
//
// Every non-empty line is split on the five-character delimiter made of a
// space, three vertical bars and a space:
//   - text before the first delimiter          -> entry id (cache_id)
//   - text between first and second delimiter  -> hypothesis (tstr)
//   - text after the last delimiter            -> score, parsed with strtod
// A line with fewer than two delimiters is reported on stderr and the
// process aborts.
//
//   in      - input stream (the read itself is in lines not shown here).
//   sent_id - presumably receives the id of the returned group; the line
//             that assigns it is not visible in this listing - confirm.
//   list    - receives the (word ids, score) pairs.
// Returns true when *list is non-empty at the end.
//
// cache_id and cache_pair are function-local statics, so their contents
// persist between calls.
// NOTE(review): this listing has gaps (the embedded numbers skip 43-44, 47,
// 49-53, 68 and 70-71). The declarations of line, tstr and cur_id, the read
// loop itself and the branch taken when the id changes are not visible.
40 bool ReadKBestList(istream
* in
, string
* sent_id
, vector
<pair
<vector
<WordID
>, prob_t
> >* list
) {
41 static string cache_id
;
42 static pair
<vector
<WordID
>, prob_t
> cache_pair
;
// A non-empty cache_pair.first means an entry is still pending from an
// earlier call: hand it to the caller first, then mark the cache as empty.
45 if (cache_pair
.first
.size() > 0) {
46 list
->push_back(cache_pair
);
48 cache_pair
.first
.clear();
// Blank lines are skipped.
54 if (line
.empty()) continue;
55 size_t p1
= line
.find(" ||| ");
56 if (p1
== string::npos
) { cerr
<< "Bad format: " << line
<< endl
; abort(); }
// The second search starts at p1 + 4, i.e. on the last character of the
// first delimiter rather than just past it.
// NOTE(review): if a match is found exactly at p1 + 4, the length
// p2 - p1 - 5 computed for tstr below wraps around as size_t - confirm
// that this case is intended.
57 size_t p2
= line
.find(" ||| ", p1
+ 4);
58 if (p2
== string::npos
) { cerr
<< "Bad format: " << line
<< endl
; abort(); }
// rfind cannot fail here: the forward searches above already found a
// delimiter in this line.
59 size_t p3
= line
.rfind(" ||| ");
60 cache_id
= line
.substr(0, p1
);
61 tstr
= line
.substr(p1
+ 5, p2
- p1
- 5);
// strtod is called without an end pointer, so a malformed score field is
// not detected here.
62 double val
= strtod(line
.substr(p3
+ 5).c_str(), NULL
);
63 TD::ConvertSentence(tstr
, &cache_pair
.first
);
// NOTE(review): logeq presumably stores val as a log-domain value; confirm
// against the prob_t definition.
64 cache_pair
.second
.logeq(val
);
// The first entry seen fixes cur_id; entries whose id equals cur_id are
// appended to the list and the cache is marked empty again.
65 if (cur_id
.empty()) cur_id
= cache_id
;
66 if (cur_id
== cache_id
) {
67 list
->push_back(cache_pair
);
69 cache_pair
.first
.clear();
72 return !list
->empty();
// Program entry point.
//
// Reads the options through InitCommandLine, then processes every list that
// ReadKBestList returns. For each entry i of a list, a scorer is built with
// entry i as its single reference, every entry j is scored against it, and
// the per-entry losses weighted by joints[j] / marginal are accumulated in
// wl_acc. mbr_loss holds the smallest accumulated loss seen so far.
// Finally either the reranked list or the single hypothesis at mbr_idx is
// written to stdout.
//
// NOTE(review): this listing has gaps (the embedded numbers skip 83, 86-87,
// 94, 96-98, 105, 107, 111, 115-116, 119-123, 126, 134 and everything after
// 135). The declarations of rf, sent_id, wl_acc and mbr_idx, the statements
// that fill joints and marginal, the update of mbr_idx and the conditions
// selecting the two output paths are therefore not visible here.
75 int main(int argc
, char** argv
) {
76 po::variables_map conf
;
77 InitCommandLine(argc
, argv
, &conf
);
78 const string metric
= conf
["loss_function"].as
<string
>();
// output_list is a pure presence flag: only its count is tested.
79 const bool output_list
= conf
.count("output_list") > 0;
80 const string file
= conf
["input"].as
<string
>();
81 const double mbr_scale
= conf
["scale"].as
<double>();
82 cerr
<< "Posterior scaling factor (alpha) = " << mbr_scale
<< endl
;
// Translate the loss-function name into a ScoreType.
84 ScoreType type
= ScoreTypeFromString(metric
);
85 vector
<pair
<vector
<WordID
>, prob_t
> > list
;
// One iteration per list delivered by ReadKBestList.
88 while(ReadKBestList(rf
.stream(), &sent_id
, &list
)) {
89 vector
<prob_t
> joints(list
.size());
// Every score is raised to mbr_scale and divided by the scaled score of the
// first entry.
// NOTE(review): the name max_score suggests the list is expected to arrive
// best-first - confirm against the input format.
90 const prob_t max_score
= pow(list
.front().second
, mbr_scale
);
91 prob_t marginal
= prob_t::Zero();
92 for (int i
= 0 ; i
< list
.size(); ++i
) {
93 const prob_t joint
= pow(list
[i
].second
, mbr_scale
) / max_score
;
95 // cerr << "list[" << i << "] joint=" << log(joint) << endl;
// mbr_scores only gets one slot per entry when the reranked list is wanted;
// otherwise it stays empty.
99 vector
<double> mbr_scores(output_list
? list
.size() : 0);
// Smallest accumulated loss so far; starts at the largest double.
100 double mbr_loss
= numeric_limits
<double>::max();
101 for (int i
= 0 ; i
< list
.size(); ++i
) {
// Entry i is the single reference given to the scorer.
102 vector
<vector
<WordID
> > refs(1, list
[i
].first
);
103 //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl;
// NOTE(review): scorer and s (below) are raw pointers and no release is
// visible in this listing - confirm who owns them.
104 SentenceScorer
* scorer
= SentenceScorer::CreateSentenceScorer(type
, refs
);
106 for (int j
= 0; j
< list
.size(); ++j
) {
108 Score
* s
= scorer
->ScoreCandidate(list
[j
].first
);
109 double loss
= 1.0 - s
->ComputeScore();
// For TER and AER the subtraction is applied a second time, so loss ends up
// equal to the value returned by ComputeScore.
110 if (type
== TER
|| type
== AER
) loss
= 1.0 - loss
;
112 double weighted_loss
= loss
* (joints
[j
] / marginal
);
113 wl_acc
+= weighted_loss
;
// When only the best entry is needed, stop accumulating as soon as this
// entry is already worse than the best one found so far.
114 if ((!output_list
) && wl_acc
> mbr_loss
) break;
117 if (output_list
) mbr_scores
[i
] = wl_acc
;
118 if (wl_acc
< mbr_loss
) {
124 // cerr << "ML translation: " << TD::GetString(list[0].first) << endl;
125 cerr
<< "MBR Best idx: " << mbr_idx
<< endl
;
// Reranked-list output: overwrite the stored value of every entry with its
// accumulated loss (through logeq), sort with LossComparer, then print one
// line per entry holding sent_id, the hypothesis string and the log of the
// stored value.
// NOTE(review): the condition guarding this path is not visible here.
127 for (int i
= 0; i
< list
.size(); ++i
)
128 list
[i
].second
.logeq(mbr_scores
[i
]);
129 sort(list
.begin(), list
.end(), LossComparer());
130 for (int i
= 0; i
< list
.size(); ++i
)
131 cout
<< sent_id
<< " ||| "
132 << TD::GetString(list
[i
].first
) << " ||| "
133 << log(list
[i
].second
) << endl
;
// Single-best output: print only the hypothesis stored at mbr_idx.
135 cout
<< TD::GetString(list
[mbr_idx
].first
) << endl
;