some files I forgot to add
[dmvccm.git] / src / io-between.py
blob3170845304d775bf0e78bf8a53f793b7284de3a4
1 # io-between.py, trying out sentence locations between words (i<k<j)
# Debug levels that are switched on: debug() only prints messages whose
# level is in this set ('IO' turns on the trace inside inner()).
DEBUG = set()
DEBUG.add('TODO')

# some of dmv-module bleeding in here... todo: prettier (in inner())
# NOTE(review): STOP presumably is the DMV stop symbol written as a
# (bar, h) pair with bar == NOBAR; neither name is referenced elsewhere
# in this file -- confirm against the dmv module.
NOBAR = 0
STOP = NOBAR, -2
def debug(string, level='TODO'):
    '''Print string if level is one of the levels switched on in the
    global DEBUG set; otherwise do nothing.

    This is the single switch for the inline debug printouts. There's a
    lot of cluttering debug statements here, todo: clean up'''
    if level not in DEBUG:
        return
    print(string)
class Grammar(object):
    '''The PCFG used in the I/O-algorithm.

    Public members:
    p_terminals -- dictionary mapping (preterminal, terminal) pairs to
                   the probability of preterminal -> terminal, e.g.
                   p_terminals[3, 'n'] = 1.0 for n -> 'n'. A pair that
                   is missing from the dictionary is treated as
                   probability 0.0 by inner().

    Todo: as of now, this allows duplicate rules... should we check
    for this? (eg. g = Grammar([x,x],{}) where x.prob == 1 may give
    inner probabilities of 2.)'''

    def __init__(self, p_rules, p_terminals, numtag, tagnum):
        '''p_rules is a list of CNF_Rule's (rules() only needs their LHS()
        method). p_terminals is a dictionary from (preterminal, terminal)
        pairs to probabilities -- NOT a list. numtag maps symbol numbers
        to tags and tagnum maps tags back to numbers.'''
        self.__p_rules = p_rules  # todo: could check for summing to 1 (+/- epsilon)
        self.__numtag = numtag
        self.__tagnum = tagnum
        self.p_terminals = p_terminals

    def all_rules(self):
        "Return the list of rules, as given to the constructor."
        return self.__p_rules

    def rules(self, LHS):
        "Return the rules whose left-hand side is LHS, in their original order."
        return [rule for rule in self.all_rules() if rule.LHS() == LHS]

    def nums(self):
        "Return the number -> tag mapping (the numtag constructor argument)."
        return self.__numtag

    def sent_nums(self, sent):
        "Translate a sentence (a sequence of tags) into a list of symbol numbers."
        return [self.tagnum(tag) for tag in sent]

    def numtag(self, num):
        "Return the tag for symbol number num."
        return self.__numtag[num]

    def tagnum(self, tag):
        "Return the symbol number for tag."
        return self.__tagnum[tag]
class CNF_Rule(object):
    '''A single CNF rule in the PCFG, of the form
        LHS -> L R
    where these are just integers (symbol numbers; a Grammar holds the
    mapping between numbers and tags such as 'vbd' in its numtag/tagnum
    members).

    Rules compare and hash equal when LHS, L and R agree; the probability
    is not part of the comparison.'''

    def __init__(self, LHS, L, R, prob):
        self.__LHS = LHS
        self.__R = R
        self.__L = L
        self.prob = prob

    def _key(self):
        "The (LHS, L, R) triple that identifies this rule."
        return (self.LHS(), self.L(), self.R())

    def __eq__(self, other):
        try:
            other_LHS, other_L, other_R = other.LHS, other.L, other.R
        except AttributeError:
            # Not a rule: let Python try the reflected comparison or fall
            # back to identity, instead of raising AttributeError.
            return NotImplemented
        return self._key() == (other_LHS(), other_L(), other_R())

    def __ne__(self, other):
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        # Must agree with __eq__ so rules can live in sets / dict keys
        # (e.g. to detect duplicate rules); prob is deliberately excluded
        # because it is excluded from __eq__ and is mutable.
        return hash(self._key())

    def __str__(self):
        return "%s -> %s %s [%.2f]" % (self.LHS(), self.L(), self.R(), self.prob)

    def p(self, *arg):
        "Return a probability, doesn't care about attachment..."
        return self.prob

    def LHS(self):
        return self.__LHS

    def L(self):
        return self.__L

    def R(self):
        return self.__R
def inner(i, j, LHS, g, sent, chart):
    ''' Give the inner probability of having the node LHS cover whatever's
    between positions i and j in sentence sent, using grammar g.

    Positions lie between the words (i<k<j): span (i, j) covers the words
    sent[i] .. sent[j-1], so a single word has j == i+1, and binary rules
    split a span at every k with i < k < j.

    Returns a list [inner_prob, chart].

    chart is a dictionary from (i, j, LHS) to the inner probability of
    that span. This call adds entries to it (it is mutated in place and
    also returned). The keys do not identify the sentence, so only pass
    in a chart that was built for the same sentence and grammar.

    For DMV, LHS is a pair (bar, h), but this function ought to be
    agnostic about that.

    e() is an internal function, so the variable chart (a dictionary)
    is available to all calls of e().

    Since terminal probabilities are just simple lookups, they are not
    put in the chart (although we could put them in there later to
    optimize). A missing (LHS, word) entry in g.p_terminals counts as
    probability 0.0 and is reported on stdout.
    '''

    def O(i, j):
        "The word just after position i (j is unused)."
        return sent[i]

    def e(i, j, LHS):
        '''Inner probability of LHS spanning positions i..j, memoized in
        chart. Chart entries are plain floats; attachment is not tracked
        here.'''
        if (i, j, LHS) in chart:
            return chart[i, j, LHS]
        debug("trying from %d to %d with %s" % (i, j, LHS), "IO")
        if i + 1 == j:
            # a single word: preterminal lookup, not stored in the chart
            if (LHS, O(i, j)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(i, j)]  # b[LHS, O(s)] in L&Y
            else:
                prob = 0.0
                print("\t LACKING TERMINAL:%s -> %s : %.1f" % (LHS, O(i, j), prob))
            debug("\t terminal: %s -> %s : %.1f" % (LHS, O(i, j), prob), "IO")
            # terminals have no attachment
            return prob
        # Start the sum at 0.0; the += below accumulates over all rules
        # headed by LHS ("a[i,j,k]") and all split points k.
        chart[i, j, LHS] = 0.0
        for rule in g.rules(LHS):
            debug("\tsumming rule %s" % rule, "IO")
            L = rule.L()
            R = rule.R()
            p = rule.p()  # independent of k, so look it up once per rule
            for k in range(i + 1, j):  # i<k<j
                p_L = e(i, k, L)
                p_R = e(k, j, R)
                chart[i, j, LHS] += p * p_L * p_R
        debug("\tchart[%d,%d,%s] = %.2f" % (i, j, LHS, chart[i, j, LHS]), "IO")
        return chart[i, j, LHS]
    # end of e-function

    inner_prob = e(i, j, LHS)
    if 'IO' in DEBUG:
        print("---CHART:---")
        for (start, end, lhs), v in chart.items():
            # end is exclusive, so the last word of the span is sent[end-1];
            # sent[end] would be out of range for spans ending the sentence.
            print("\t%s -> %s_%d ... %s_%d : %.1f"
                  % (lhs, O(start, end), start, O(end - 1, end), end - 1, v))
        print("---CHART:end---")
    return [inner_prob, chart]
if __name__ == "__main__":
    print("IO-module tests:")
    b = {}
    s = CNF_Rule(0, 1, 2, 1.0)    # s->np vp
    np = CNF_Rule(1, 3, 4, 0.3)   # np->n p
    b[1, 'n'] = 0.7               # np->'n'
    b[3, 'n'] = 1.0               # n->'n'
    b[4, 'p'] = 1.0               # p->'p'
    vp = CNF_Rule(2, 5, 1, 0.1)   # vp->v np (two parses use this rule)
    vp2 = CNF_Rule(2, 2, 4, 0.9)  # vp->vp p
    b[5, 'v'] = 1.0               # v->'v'

    g = Grammar([s, np, vp, vp2], b,
                {0: 's', 1: 'np', 2: 'vp', 3: 'n', 4: 'p', 5: 'v'},
                {'s': 0, 'np': 1, 'vp': 2, 'n': 3, 'p': 4, 'v': 5})

    # print("The rules:")
    # for i in g.nums():
    #     for r in g.rules(i):
    #         print(r)
    # print("")

    # Inner probabilities are sums of float products, so compare with a
    # tolerance rather than with != (0.1*0.3 is not exactly 0.03 in
    # binary floating point).
    EPSILON = 1e-9

    test1 = inner(0, 1, 1, g, ['n'], {})
    if abs(test1[0] - 0.7) > EPSILON:
        print("should be 0.70 : %.3f" % test1[0])
    print("")

    # A chart is only valid for the sentence it was built on (its keys
    # are just spans), so start this sentence with a fresh chart...
    test2 = inner(0, 3, 2, g, ['v', 'n', 'p'], {})
    if abs(test2[0] - 0.0930) > EPSILON:
        print("should be 0.0930 : %.4f" % test2[0])
    # ...then check that reusing its filled-in chart gives the same answer.
    test2 = inner(0, 3, 2, g, ['v', 'n', 'p'], test2[1])
    if abs(test2[0] - 0.0930) > EPSILON:
        print("should be 0.0930 : %.4f" % test2[0])