test
[ws10smt.git] / vest / mbr_kbest.cc
blob5d70b4e2dcccbbe0ada5f061bbfc16b74f4952f3
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include <boost/program_options.hpp>

#include "prob.h"
#include "tdict.h"
#include "scorer.h"
#include "filelib.h"
#include "stringlib.h"

using namespace std;

namespace po = boost::program_options;
16 void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
17 po::options_description opts("Configuration options");
18 opts.add_options()
19 ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)")
20 ("loss_function,l",po::value<string>()->default_value("bleu"), "Loss function")
21 ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from")
22 ("output_list,L", "Show reranked list as output")
23 ("help,h", "Help");
24 po::options_description dcmdline_options;
25 dcmdline_options.add(opts);
26 po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
27 bool flag = false;
28 if (flag || conf->count("help")) {
29 cerr << dcmdline_options << endl;
30 exit(1);
34 struct LossComparer {
35 bool operator()(const pair<vector<WordID>, double>& a, const pair<vector<WordID>, double>& b) const {
36 return a.second < b.second;
40 bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) {
41 static string cache_id;
42 static pair<vector<WordID>, prob_t> cache_pair;
43 list->clear();
44 string cur_id;
45 if (cache_pair.first.size() > 0) {
46 list->push_back(cache_pair);
47 cur_id = cache_id;
48 cache_pair.first.clear();
50 string line;
51 string tstr;
52 while(*in) {
53 getline(*in, line);
54 if (line.empty()) continue;
55 size_t p1 = line.find(" ||| ");
56 if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
57 size_t p2 = line.find(" ||| ", p1 + 4);
58 if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
59 size_t p3 = line.rfind(" ||| ");
60 cache_id = line.substr(0, p1);
61 tstr = line.substr(p1 + 5, p2 - p1 - 5);
62 double val = strtod(line.substr(p3 + 5).c_str(), NULL);
63 TD::ConvertSentence(tstr, &cache_pair.first);
64 cache_pair.second.logeq(val);
65 if (cur_id.empty()) cur_id = cache_id;
66 if (cur_id == cache_id) {
67 list->push_back(cache_pair);
68 *sent_id = cur_id;
69 cache_pair.first.clear();
70 } else { break; }
72 return !list->empty();
75 int main(int argc, char** argv) {
76 po::variables_map conf;
77 InitCommandLine(argc, argv, &conf);
78 const string metric = conf["loss_function"].as<string>();
79 const bool output_list = conf.count("output_list") > 0;
80 const string file = conf["input"].as<string>();
81 const double mbr_scale = conf["scale"].as<double>();
82 cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl;
84 ScoreType type = ScoreTypeFromString(metric);
85 vector<pair<vector<WordID>, prob_t> > list;
86 ReadFile rf(file);
87 string sent_id;
88 while(ReadKBestList(rf.stream(), &sent_id, &list)) {
89 vector<prob_t> joints(list.size());
90 const prob_t max_score = pow(list.front().second, mbr_scale);
91 prob_t marginal = prob_t::Zero();
92 for (int i = 0 ; i < list.size(); ++i) {
93 const prob_t joint = pow(list[i].second, mbr_scale) / max_score;
94 joints[i] = joint;
95 // cerr << "list[" << i << "] joint=" << log(joint) << endl;
96 marginal += joint;
98 int mbr_idx = -1;
99 vector<double> mbr_scores(output_list ? list.size() : 0);
100 double mbr_loss = numeric_limits<double>::max();
101 for (int i = 0 ; i < list.size(); ++i) {
102 vector<vector<WordID> > refs(1, list[i].first);
103 //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl;
104 SentenceScorer* scorer = SentenceScorer::CreateSentenceScorer(type, refs);
105 double wl_acc = 0;
106 for (int j = 0; j < list.size(); ++j) {
107 if (i != j) {
108 Score* s = scorer->ScoreCandidate(list[j].first);
109 double loss = 1.0 - s->ComputeScore();
110 if (type == TER || type == AER) loss = 1.0 - loss;
111 delete s;
112 double weighted_loss = loss * (joints[j] / marginal);
113 wl_acc += weighted_loss;
114 if ((!output_list) && wl_acc > mbr_loss) break;
117 if (output_list) mbr_scores[i] = wl_acc;
118 if (wl_acc < mbr_loss) {
119 mbr_loss = wl_acc;
120 mbr_idx = i;
122 delete scorer;
124 // cerr << "ML translation: " << TD::GetString(list[0].first) << endl;
125 cerr << "MBR Best idx: " << mbr_idx << endl;
126 if (output_list) {
127 for (int i = 0; i < list.size(); ++i)
128 list[i].second.logeq(mbr_scores[i]);
129 sort(list.begin(), list.end(), LossComparer());
130 for (int i = 0; i < list.size(); ++i)
131 cout << sent_id << " ||| "
132 << TD::GetString(list[i].first) << " ||| "
133 << log(list[i].second) << endl;
134 } else {
135 cout << TD::GetString(list[mbr_idx].first) << endl;
138 return 0;