before trying new pchoose reestimation
[dmvccm.git] / src / io.py~20080525~
blobfeb5fa032da9bc2a461be68c93d21f02ea705a04
1 ##################################################################
2 #                            Changes:                            #
3 ##################################################################
4 # 2008-05-24, KBU:
5 # - the chart variable in inner() is treated by Python like a
6 #   "global", documented this in the function
7 # - 
9 # import numpy # numpy provides Fast Arrays, for future optimization
10 import pprint
12 DEBUG = 1
13 def debug(string):
14     "Easily turn on/off inline debug printouts with the global"
15     if DEBUG:
16         print string
18 class CNF_Rule():
19     '''A single CNF rule in the PCFG, of the form 
20     head -> L R
21     where these are just integers
22     (where do we save the connection between number and symbol?
23     symbols being 'vbd' etc.)'''
24     def __str__(self):
25         return "%s -> %s %s [%.2f]" % (self.head, self.L, self.R, self.prob)
26     def __init__(self, head, L, R, prob):
27         self.head = head
28         self.R = R
29         self.L = L
30         self.prob = prob
32 class Grammar():
33     '''The PCFG used in the I/O-algorithm.
35     Todo: as of now, this allows duplicate rules... should we check
36     for this?  (eg. g = Grammar([x,x],[]) where x.prob == 1 may give
37     inner probabilities of 2.)'''
38     def rules(self, head):
39         return [rule for rule in self.p_rules if rule.head == head]
40     
41     def __init__(self, p_rules, p_terminals):
42         '''p_rules and p_terminals should be arrays, where p_terminals are of
43         the form [preterminal, terminal], and p_rules are CNF_Rule's.'''        
44         self.p_rules = p_rules # could check for summing to 1 (+/- epsilon)
45         self.p_terminals = p_terminals
49 def inner(s, t, i, g, sent, chart = {}):
50     ''' Give the inner probability of having the node i cover whatever's
51     between s and t in sentence sent, using grammar g.
53     For DMV, i is a pair (bar, h), but this function ought to be
54     agnostic about that.
56     e() is an internal function, so the variable chart (a dictionary)
57     is available to all calls of e().
59     Also, importantly, on subsequent calls to inner, if no value for
60     chart is given, the last value is used, so if after calling >>>
61     inner(s,t,i,g,sent) chart is {foo:bar, fie:foe}, then within the
62     next such call to inner, chart is still {foo:bar, fie:foe}, even
63     though the default value is {}. So this is a sort of Python
64     "global".
65     '''
66     
67     def O(s):
68         return sent[s]
69     
70     def e(s,t,i):
71         if (s, t, i) in chart:
72             return chart[(s, t, i)]
73         else:
74             debug( "trying from %d to %d with %s" % (s,t,i) )
75             if s == t:
76                 if (i, O(s)) in g.p_terminals:
77                     prob = g.p_terminals[i, O(s)] # b[i, O(s)]
78                 else:
79                     prob = 0.0 # todo: is this the right way to deal with lacking rules?
80                     debug( "\tterminal: %s -> %s : %.1f" % (i, O(s), prob) )
81                 return prob
82             else:
83                 if (s,t,i) not in chart:
84                     chart[(s,t,i)] = 0.0
85                 for rule in g.rules(i): # summing over j,k in a[i,j,k]
86                     debug( "\tsumming rule %s" % rule ) 
87                     L = rule.L
88                     R = rule.R
89                     for r in range(s, t): # summing over r = s to r = t-1
90                         chart[(s, t, i)] += rule.prob * e(s, r, L) * e(r + 1, t, R)
91                 debug( "\tchart[(%d,%d,%s)] = %.2f" % (s,t,i, chart[(s,t,i)]) )
92                 return chart[(s, t, i)]
93     # end of e-function
94     
95     h = e(s,t,i)
96     if DEBUG:
97         print "---CHART:---"
98         for k,v in chart.iteritems():
99             print "\t%s -> %s_%d ... %s_%d : %.1f" % (k[2], O(k[0]), k[0], O(k[1]), k[1], v)
100         print "---CHART:end---"
101     return h
110 if __name__ == "__main__":
111     print "IO-module tests:"
112     b = {}
113     s   = CNF_Rule(0,1,2, 1.0) # s->np vp
114     np  = CNF_Rule(1,3,4, 0.3) # np->n p
115     b[1, 'n'] = 0.7 # np->'n'
116     b[3, 'n'] = 1.0 # n->'n'
117     b[4, 'p'] = 1.0 # p->'p'
118     vp  = CNF_Rule(2,5,1, 0.1) # vp->v np (two parses use this rule)
119     vp2 = CNF_Rule(2,2,4, 0.9) # vp->vp p
120     b[5, 'v'] = 1.0 # v->'v'
121     
122     g = Grammar([s,np,vp,vp2], b)
123     
124     print "The rules:"
125     for i in range(0,5):
126         for r in g.rules(i):
127             print r
128     print ""
130     test1 = inner(0,0, 1, g, ['n'], {})
131     if test1 != 0.7:
132         print "should be 0.70 : %.2f" % test1
133         print ""
134     
135     test2 = inner(0,2, 2, g, ['v','n','p'], {})
136     print "should be 0.?? : %.2f" % test2
138     
139 ##################################################################
140 #            just junk from here on down:                        #
141 ##################################################################
143 def io():
144     return "todo"
146 #    a = initialize(tagset)
147 #    b = initialize_t(tagset)
149 ## now I should be able to say a[X->X Y]
151 #     for i in a:
152 #         for j,k in a:
153 #         a[i,j,k] = sum sum sum w_q(s,t,i,j,k) / sum sum sum v_q(s,t,i)
154 #     for m in b:
155 #         b[i,m] = sum sum v_q(t,t,i) / sum sum sum v_q(s,t,i)
159 def virahanka3(n, lookup={0:[""], 1:["S"]}):
160     '''An example of a top-down recursive dynamic function; perhaps
161 inner(s,t,i) could have inner_chart as an argument also? (Or would we miss out
162 on certain values of inner_chart then?)
164 Check whatever is faster:
165 >>> from timeit import Timer
166 >>> Timer("all_inner(sent, grammar)", "sent = 'nn vbd det nn foo bar baz' grammar = ...").timeit()
168     if n not in lookup:
169         s = ["S" + prosody for prosody in virahanka3(n-1)]
170         l = ["L" + prosody for prosody in virahanka3(n-2)]
171         lookup[n] = s + l
172     return lookup[n]