test
[ws10smt.git] / extools / extract.h
blobf87aa6cb96c1d8ca557fbf4ab18e0324b9c65e73
1 #ifndef _EXTRACT_H_
2 #define _EXTRACT_H_
4 #include <iostream>
5 #include <utility>
6 #include <vector>
7 #include "array2d.h"
8 #include "wordid.h"
9 #include "sparse_vector.h"
11 struct AnnotatedParallelSentence;
13 // usually represents a consistent phrase, which may
14 // be annotated with a type (cat)
15 // inside the rule extractor, this class is also used to represent a word
16 // in a partial rule.
17 struct ParallelSpan {
18 // i1 = i of f side
19 // i2 = j of f side
20 // j1 = i of e side
21 // j2 = j of e side
22 short i1,i2,j1,j2;
23 // cat is set by AnnotatePhrasesWithCategoryTypes, otherwise it's 0
24 WordID cat; // category type of span (also overloaded by RuleItem class
25 // to be a word ID)
26 ParallelSpan() : i1(-1), i2(-1), j1(-1), j2(-1), cat() {}
27 // used by Rule class to represent a terminal symbol:
28 explicit ParallelSpan(WordID w) : i1(-1), i2(-1), j1(-1), j2(-1), cat(w) {}
29 ParallelSpan(int pi1, int pi2, int pj1, int pj2) : i1(pi1), i2(pi2), j1(pj1), j2(pj2), cat() {}
30 ParallelSpan(int pi1, int pi2, int pj1, int pj2, WordID c) : i1(pi1), i2(pi2), j1(pj1), j2(pj2), cat(c) {}
32 // ParallelSpan is used in the Rule class where it is
33 // overloaded to also represent terminal symbols
34 inline bool IsVariable() const { return i1 != -1; }
37 // rule extraction logic lives here. this has no data, it's just got
38 // static member functions.
39 struct Extract {
40 // RuleObserver's CountRule is called for each rule extracted
41 // implement CountRuleImpl to do things like count the rules,
42 // write them to a file, etc.
43 struct RuleObserver {
44 RuleObserver() : count() {}
45 virtual void CountRule(WordID lhs,
46 const std::vector<WordID>& rhs_f,
47 const std::vector<WordID>& rhs_e,
48 const std::vector<std::pair<short, short> >& fe_terminal_alignments) {
49 ++count;
50 CountRuleImpl(lhs, rhs_f, rhs_e, fe_terminal_alignments);
52 virtual ~RuleObserver();
54 protected:
55 virtual void CountRuleImpl(WordID lhs,
56 const std::vector<WordID>& rhs_f,
57 const std::vector<WordID>& rhs_e,
58 const std::vector<std::pair<short, short> >& fe_terminal_alignments) = 0;
59 private:
60 int count;
63 // given a set of "tight" phrases and the aligned sentence they were
64 // extracted from, "loosen" them
65 static void LoosenPhraseBounds(const AnnotatedParallelSentence& sentence,
66 const int max_base_phrase_size,
67 std::vector<ParallelSpan>* phrases);
69 // extract all consistent phrase pairs, up to size max_base_phrase_size
70 // (on the source side). these phrases will be "tight".
71 static void ExtractBasePhrases(const int max_base_phrase_size,
72 const AnnotatedParallelSentence& sentence,
73 std::vector<ParallelSpan>* phrases);
75 // this uses the TARGET span (i,j) to annotate phrases, will copy
76 // phrases if there is more than one annotation.
77 // TODO: support source annotation
78 static void AnnotatePhrasesWithCategoryTypes(const WordID default_cat,
79 const Array2D<std::vector<WordID> >& types,
80 std::vector<ParallelSpan>* phrases);
82 // use the Chiang (2007) extraction logic to extract consistent subphrases
83 // observer->CountRule is called once for each rule extracted
84 static void ExtractConsistentRules(const AnnotatedParallelSentence& sentence,
85 const std::vector<ParallelSpan>& phrases,
86 const int max_vars,
87 const int max_syms,
88 const bool permit_adjacent_nonterminals,
89 const bool require_aligned_terminal,
90 RuleObserver* observer);
93 // represents statistics / information about a rule pair
94 struct RuleStatistics {
95 SparseVector<float> counts;
96 std::vector<std::pair<short,short> > aligns;
97 RuleStatistics() {}
98 RuleStatistics(int name, float val, const std::vector<std::pair<short,short> >& al) :
99 aligns(al) {
100 counts.set_value(name, val);
102 void ParseRuleStatistics(const char* buf, int start, int end);
103 RuleStatistics& operator+=(const RuleStatistics& rhs) {
104 counts += rhs.counts;
105 return *this;
108 std::ostream& operator<<(std::ostream& os, const RuleStatistics& s);
110 #endif