added ccm.py
[dmvccm.git] / src / loc_h_dmv.py.new
blobb1d44e92d73030593ea7d09b9fcb5d4021e0defa
1 # loc_h_dmv.py
2
3 # dmv reestimation and inside-outside probabilities using loc_h
5 # for now, if 3 < loc_h < 4 then I use loc_h=3.5; although this is far
6 # from ideal, the code reads a lot easier -- and it's faster than
7 # using a pair for loc_h
9 #import numpy # numpy provides Fast Arrays, for future optimization
10 import io
11 from common_dmv import *
13 if __name__ == "__main__":
14     print "loc_h_dmv module tests:"
16 class DMV_Grammar(io.Grammar):
17     '''The DMV-PCFG.
19     Public members:
20     p_STOP, p_ROOT, p_CHOOSE, p_terminals
21     These are changed in the Maximation step, then used to set the
22     new probabilities of each DMV_Rule.
24     Todo: make p_terminals private? (But it has to be changable in
25     maximation step due to the short-cutting rules... could of course
26     make a DMV_Grammar function to update the short-cut rules...)
28     __p_rules is private, but we can still say stuff like:
29     for r in g.all_rules():
30         r.probN = newProbN
31     
32     What other representations do we need? (P_STOP formula uses
33     deps_D(h,l/r) at least)'''
34     def __str__(self):
35         str = ""
36         for r in self.all_rules():
37              str += "%s\n" % r.__str__(self.numtag)
38         return str
40     def h_rules(self, h):
41         return [r for r in self.all_rules() if r.head() == h]
42     
43     def mothersL(self, Node, sent_nums, loc_N):
44         # todo: speed-test with and without sent_nums/loc_N cut-off
45         return [r for r in self.all_rules() if r.L() == Node
46                 and (head(r.R()) in sent_nums[loc_N+1:] or r.R() == STOP)]
47     
48     def mothersR(self, Node, sent_nums, loc_N):
49         return [r for r in self.all_rules() if r.R() == Node
50                 and (head(r.L()) in sent_nums[:loc_N] or r.L() == STOP)]
52     def rules(self, LHS):
53         return [r for r in self.all_rules() if r.LHS() == LHS]
54     
55     def sent_rules(self, LHS, sent_nums):
56         '''Used in dmv.inner. Todo: this takes a _lot_ of time, it
57         seems. Could use some more space and cache some of this
58         somehow perhaps?'''
59         # We don't want to rule out STOPs!
60         nums = sent_nums + [ head(STOP) ]
61         return [r for r in self.all_rules() if r.LHS() == LHS
62                 and head(r.L()) in nums and head(r.R()) in nums]
63     
64     def deps_L(self, head): # todo: do I use this at all?
65         # todo test, probably this list comprehension doesn't work 
66         return [a for r in self.all_rules() if r.head() == head and a == r.L()]
67     
68     def deps_R(self, head):
69         # todo test, probably this list comprehension doesn't work 
70         return [a for r in self.all_rules() if r.head() == head and a == r.R()]
71     
72     def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
73         io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
74         self.p_STOP = p_STOP
75         self.p_CHOOSE = p_CHOOSE
76         self.p_ROOT = p_ROOT
77         self.head_nums = [k for k in numtag.iterkeys()]
78         
80 class DMV_Rule(io.CNF_Rule):
81     '''A single CNF rule in the PCFG, of the form 
82     LHS -> L R
83     where LHS, L and R are 'nodes', eg. of the form (seals, head).
84     
85     Public members:
86     probN, probA
87     
88     Private members:
89     __L, __R, __LHS
90     
91     Different rule-types have different probabilities associated with
92     them:
94     _h_ -> STOP  h_     P( STOP|h,L,    adj)
95     _h_ -> STOP  h_     P( STOP|h,L,non_adj)
96      h_ ->  h  STOP     P( STOP|h,R,    adj)
97      h_ ->  h  STOP     P( STOP|h,R,non_adj)
98      h_ -> _a_   h_     P(-STOP|h,L,    adj) * P(a|h,L)
99      h_ -> _a_   h_     P(-STOP|h,L,non_adj) * P(a|h,L)
100      h  ->  h   _a_     P(-STOP|h,R,    adj) * P(a|h,R)
101      h  ->  h   _a_     P(-STOP|h,R,non_adj) * P(a|h,R) 
102     '''
103     def p(self, middle, loc_h):
104         "middle is eg. k when rewriting for i<k<j (inside probabilities)."
105         if middle-1 < loc_h < middle+1:
106             return self.probA
107         else:
108             return self.probN
110     def seals(self):
111         return seals(self.LHS())
112     
113     def head(self):
114         return head(self.LHS())
115     
116     def __init__(self, LHS, L, R, probN, probA):
117         for b_h in [LHS, L, R]:
118             if seals(b_h) not in SEALS:
119                 raise ValueError("seals must be in %s; was given: %s"
120                                  % (SEALS, seals(b_h)))
121         io.CNF_Rule.__init__(self, LHS, L, R, probN)
122         self.probA = probA # adjacent
123         self.probN = probN # non_adj
124         
125     def __str__(self, tag=lambda x:x):
126         return "%s-->%s %s\t[N %.2f] [A %.2f]" % (bar_str(self.LHS(), tag),
127                                                   bar_str(self.L(), tag),
128                                                   bar_str(self.R(), tag),
129                                                   self.probN,
130                                                   self.probA)
131     
133     
138 ###################################
139 # dmv-specific version of inner() #
140 ###################################
141 def locs(h, sent, s=0, t=None, remove=None):
142     '''Return the locations of h in sent, or some fragment of sent (in the
143     latter case we make sure to offset the locations correctly so that
144     for any x in the returned list, sent[x]==h).
146     t is inclusive, to match the way indices work with inner()
147     (although python list-splicing has "exclusive" end indices)'''
148     if t == None:
149         t = len(sent)-1
150     return [i+s for i,w in enumerate(sent[s:t+1])
151             if w == h and not (i+s) == remove]
154 def inner(i, j, LHS, loc_h, g, sent, ichart={}):
155     ''' A rewrite of io.inner(), to take adjacency into accord.
157     The ichart is now of this form:
158     ichart[i,j,LHS, loc_h]
159     where i and j are between-word positions.
160     
161     loc_h gives adjacency (along with r and location of other child
162     for attachment rules), and is needed in P_STOP reestimation.
163     
164     Todo: if possible, refactor (move dmv-specific stuff back into
165     dmv, so this is "general" enough to be in io.py)
166     '''
167     
168     def O(i,j):
169         return sent[i]
170     
171     sent_nums = g.sent_nums(sent)
172     
173     def e(i,j,LHS, loc_h, n_t):
174         def tab():
175             "Tabs for debug output"
176             return "\t"*n_t
177         
178         if (i, j, LHS, loc_h) in ichart:
179             if 'INNER' in io.DEBUG:
180                 print "%s*= %.4f in ichart: s:%d t:%d LHS:%s loc:%s" % (tab(),ichart[s, t, LHS, loc_h], s, t,
181                                                                        bar_str(LHS), loc_h)
182             return ichart[i, j, LHS, loc_h]
183         else: # Either terminal rewrites:
184             if i+1 == j and (seals(LHS) == GOR or seals(LHS) == GOL):
185                 if not i < loc_h < j:
186                     if 'INNER' in io.DEBUG:
187                         print "%s*= 0.0 (wrong loc_h)" % tab()
188                     return 0.0
189                 elif (LHS, O(i,j)) in g.p_terminals:
190                     # this assumes g.p_terminals has encoded P_ORDER. Todo.
191                     prob = g.p_terminals[LHS, O(i,j)] 
192                 else:
193                     prob = 0.0 
194                     if 'INNER' in io.DEBUG:
195                         print "%sLACKING TERMINAL:" % tab()
196                 # todo: add to ichart perhaps? Although, it _is_ simple lookup..
197                 if 'INNER' in io.DEBUG:
198                     print "%s*= %.4f (terminal: %s -> %s_%d)" % (tab(),prob, bar_str(LHS), O(i,j), loc_h) 
199                 return prob
200             else: # Or not a terminal yet:
201                 p = 0.0 
202                 for rule in g.sent_rules(LHS, sent_nums): 
203                     if 'INNER' in io.DEBUG:
204                         print "%ssumming rule %s i:%d j:%d loc:%d" % (tab(),rule,i,j,loc_h) 
205                     L = rule.L()
206                     R = rule.R()
207                     # todo: speed-test, and check if it works with left-first
208 #                     if loc_h == j and rule.LHS() == L:
209 #                         continue 
210 #                     if loc_h == i and rule.LHS() == R:
211 #                         continue 
212                     # All STOP rules we rewrite for the same xrange,
213                     # independently of seals:
214                     if L == STOP or R == STOP:
215                         if L == STOP:
216                             p += rule.p(i,loc_h) * e(i, j, R, loc_h, n_t+1)
217                         elif R == STOP:
218                             p += rule.p(j,loc_h) * e(i, j, L, loc_h, n_t+1)
219                         if 'INNER' in io.DEBUG:
220                             print "%sp= %.4f (STOP)" % (tab(), p) 
221                             
222                     elif j > i: # Not a STOP => an attachment rewrite:
223                         rp_ATTACH = rule.p_ATTACH # todo: profile/speedtest
224                         for k in xrange(i+1, j):
225                             p_h = rp_ATTACH(r, loc_h, s=s)
226                             if rule.LHS() == L: 
227                                 locs_L = [loc_h]
228                                 locs_R = locs(head(R), sent_nums, r+1, t, loc_h)
229                             elif rule.LHS() == R: 
230                                 locs_L = locs(head(L), sent_nums,  s,  r, loc_h)
231                                 locs_R = [loc_h]
232                             for loc_L in locs_L:
233                                 pL = e(s, r, L, loc_L, n_t+1)
234                                 if pL > 0.0: 
235                                     for loc_R in locs_R:
236                                         pR = e(r+1, t, R, loc_R, n_t+1)
237                                         p += pL * p_h * pR
238                             if 'INNER' in io.DEBUG:
239                                 print "%sp= %.4f (ATTACH)" % (tab(), p) 
240                 ichart[s, t, LHS, loc_h] = p
241                 return p
242     # end of e-function
243             
244     inner_prob = e(i,j,LHS,loc_h, 0)
245     if 'INNER' in io.DEBUG:
246         print debug_ichart(g,sent,ichart)
247     return inner_prob
248 # end of dmv.inner(i, j, LHS, loc_h, g, sent, ichart={})
251 def debug_ichart(g,sent,ichart):
252     str = "---ICHART:---\n"
253     for (s,t,LHS,loc_h),v in ichart.iteritems():
254         if type(v) == dict: # skip 'tree'
255             continue
256         str += "%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (bar_str(LHS,g.numtag),
257                                                               sent[s], s, sent[s], t, loc_h, v)
258     str += "---ICHART:end---\n"
259     return str
262 def inner_sent(g, sent, ichart={}):
263     return sum([inner(0, len(sent)-1, ROOT, loc_h, g, sent, ichart)
264                 for loc_h in xrange(len(sent))])
267 ###################################
268 # dmv-specific version of outer() #
269 ###################################
270 def outer(s,t,Node,loc_N, g, sent, ichart={}, ochart={}):
271     ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer
272     '''
273     def e(s,t,LHS,loc_h):
274         # or we could just look it up in ichart, assuming ichart to be done
275         return inner(s, t, LHS, loc_h, g, sent, ichart)
276     
277     T = len(sent)-1
278     sent_nums = g.sent_nums(sent)
279     
280     def f(s,t,Node,loc_N):
281         if (s,t,Node,loc_N) in ochart:
282             return ochart[(s, t, Node,loc_N)]
283         if Node == ROOT:
284             if s == 0 and t == T:
285                 return 1.0
286             else: # ROOT may only be used on full sentence
287                 return 0.0 # but we may have non-ROOTs over full sentence too
288         p = 0.0
289         
290         for mom in g.mothersL(Node, sent_nums, loc_N): # mom.L() == Node
291             R = mom.R()
292             mLHS = mom.LHS()
293             if R == STOP:
294                 p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N) # == loc_m
295             else:
296                 if seals(mLHS) == RGOL: # left attachment, head(mLHS) == head(R)
297                     for r in xrange(t+1,T+1): # t+1 to lasT 
298                         for loc_m in locs(head(mLHS),sent_nums,t+1,r):
299                             p_m = mom.p(t+1 == loc_m)
300                             p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_m)
301                 elif seals(mLHS) == GOR: # right attachment, head(mLHS) == head(Node)
302                     loc_m = loc_N
303                     p_m = mom.p( t  == loc_m)
304                     for r in xrange(t+1,T+1): # t+1 to lasT 
305                         for loc_R in locs(head(R),sent_nums,t+1,r):
306                             p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_R)
307         
308         for mom in g.mothersR(Node, sent_nums, loc_N): # mom.R() == Node
309             L = mom.L()
310             mLHS = mom.LHS()
311             if L == STOP:
312                 p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N) # == loc_m
313             else:
314                 if seals(mLHS) == RGOL: # left attachment, head(mLHS) == head(Node)
315                     loc_m = loc_N
316                     p_m = mom.p( s  == loc_m)
317                     for r in xrange(0,s): # first to s-1 
318                         for loc_L in locs(head(L),sent_nums,r,s-1):
319                             p += e(r,s-1,L, loc_L) * p_m * f(r,t,mLHS,loc_m)
320                 elif seals(mLHS) == GOR: # right attachment, head(mLHS) == head(L)
321                     for r in xrange(0,s): # first to s-1
322                         for loc_m in locs(head(mLHS),sent_nums,r,s-1): 
323                             p_m = mom.p(s-1 == loc_m)
324                             p += e(r,s-1,L, loc_m) * p_m * f(r,t,mLHS,loc_m)
325         ochart[s,t,Node,loc_N] = p
326         return p
328     
329     return f(s,t,Node,loc_N)
330 # end outer(s,t,Node,loc_N, g,sent, ichart,ochart)
334 ##############################
335 #      reestimation, todo:   #
336 ##############################
337 ## using local version instead
338 # def c(s,t,LHS,loc_h,g,sent,ichart={},ochart={}):
339 #     # assuming P_sent = P(D(ROOT)) = inner(sent). todo: check K&M about this
340 #     p_sent = inner_sent(g, sent, ichart)
341 #     p_in = inner(s,t,LHS,loc_h,g,sent,ichart) 
342 #     p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
343 #     if p_sent > 0.0:
344 #         return p_in * p_out / p_sent
345 #     else:
346 #         return p_sent
348 def reest_zeros(h_nums):
349     # todo: p_ROOT? ... p_terminals?
350     f = {}
351     for h in h_nums:
352         for stop in ['LNSTOP','LASTOP','RNSTOP','RASTOP']:
353             for nd in ['num','den']:
354                 f[stop,nd,h] = 0.0
355         for choice in ['RCHOOSE', 'LCHOOSE']:
356             f[choice,'den',h] = 0.0
357     return f
359 def reest_freq(g, corpus):
360     ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
361     f = reest_zeros(g.head_nums)
362     ichart = {}
363     ochart = {}
364     
365     p_sent = None # 50 % speed increase on storing this locally
366     def c_g(s,t,LHS,loc_h,sent): # altogether 2x faster than the global c()
367         if (s,t,LHS,loc_h) in ichart:
368             p_in = ichart[s,t,LHS,loc_h]
369         else:
370             p_in = inner(s,t,LHS,loc_h,g,sent,ichart) 
371         if (s,t,LHS,loc_h) in ochart:
372             p_out = ochart[s,t,LHS,loc_h]
373         else:
374             p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
376         if p_sent > 0.0:
377             return p_in * p_out / p_sent
378         else:
379             return p_sent
381     def w_g(s,t,a,loc_a,LHS,loc_h,sent):
382         "Todo: should sum through all r in between s and t in sent(_nums)"
383         h = head(LHS)
384         b_h = seals(LHS)
385         if b_h == GOR:
386             return e_L * e_R * f_g(s,t,(GOR, h), loc_h, sent) * p_g(r,(GOR, h), (GOR, h), (SEAL, a), loc_h, sent_nums)
387         if b_h == RGOL:
388             return e_L * e_R * f_g(s,t,(RGOL, h), loc_h, sent) * p_g(r,(RGOL, h),(SEAL, a),(RGOL, h),loc_h,sent_nums)
390     def f_g(s,t,LHS,loc_h,sent): # todo: test with choose rules
391         if (s,t,LHS,loc_h) in ochart:
392             return ochart[s,t,LHS,loc_h]
393         else:
394             return outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
396     def e_g(s,t,LHS,loc_h,sent): # todo: test with choose rules
397         if (s,t,LHS,loc_h) in ichart:
398             return ichart[s,t,LHS,loc_h]
399         else:
400             return inner(s,t,LHS,loc_h,g,sent,ichart) 
401         
402     def p_g(r,LHS,L,R,loc_h,sent):
403         rules = [rule for rule in g.sent_rules(LHS, sent)
404                  if rule.L() == L and rule.R() == R]
405         rule = rules[0]
406         if len(rules) > 1:
407             raise Exception("Several rules matching a[i,j,k]")
408         return rule.p_ATTACH(r,loc_h)
410     for sent in corpus:
411         if 'reest' in io.DEBUG:
412             print sent
413         ichart = {}
414         ochart = {}
415         p_sent = inner_sent(g, sent, ichart)
417         sent_nums = g.sent_nums(sent)
418         # todo: use sum([ichart[s, t...] etc? but can we then
419         # keep den and num separate within _one_ sum()-call?
420         for loc_h,h in enumerate(sent_nums):
421             for t in xrange(loc_h, len(sent)):
422                 for s in xrange(loc_h): # s<loc(h), xrange gives strictly less
423                     # left non-adjacent stop:
424                     f['LNSTOP','num',h] += c_g(s, t, (SEAL, h), loc_h,sent)
425                     f['LNSTOP','den',h] += c_g(s, t, (RGOL,h), loc_h,sent)
426                 # left adjacent stop:
427                 f['LASTOP','num',h] += c_g(loc_h, t, (SEAL, h), loc_h,sent)
428                 f['LASTOP','den',h] += c_g(loc_h, t, (RGOL,h), loc_h,sent)
429             for t in xrange(loc_h+1, len(sent)):
430                 # right non-adjacent stop:
431                 f['RNSTOP','num',h] += c_g(loc_h, t, (RGOL,h), loc_h,sent)
432                 f['RNSTOP','den',h] += c_g(loc_h, t, (GOR, h), loc_h,sent)
433             # right adjacent stop:
434             f['RASTOP','num',h] += c_g(loc_h, loc_h, (RGOL,h), loc_h,sent)
435             f['RASTOP','den',h] += c_g(loc_h, loc_h, (GOR, h), loc_h,sent)
437             # right attachment:  TODO: try with p*e*e*f instead of c, for numerator
438             if 'reest_attach' in io.DEBUG:
439                 print "Rattach %s: for t in %s"%(g.numtag(h),sent[loc_h+1:len(sent)])
440             for t in xrange(loc_h+1, len(sent)): 
441                 cM = c_g(loc_h,t,(GOR, h), loc_h, sent) # v_q in L&Y 
442                 f['RCHOOSE','den',h] += cM
443                 if 'reest_attach' in io.DEBUG:
444                     print "\tc_g( %d , %d, %s, %s, sent)=%.4f"%(loc_h,t,g.numtag(h),loc_h,cM)
445                 args = {} # for summing w_q's in L&Y, without 1/P_q
446                 for r in xrange(loc_h+1, t+1): # loc_h < r <= t 
447                     e_L = e_g(loc_h, r-1, (GOR, h), loc_h, sent)
448                     if 'reest_attach' in io.DEBUG:
449                         print "\t\te_g( %d , %d, %s, %d, sent)=%.4f"%(loc_h,r-1,g.numtag(h),loc_h,e_L)
450                     for i,a in enumerate(sent_nums[r:t+1]):
451                         loc_a = i+r
452                         e_R = e_g(r, t, (SEAL, a), loc_a, sent)
453                         if a not in args:
454                             args[a] = 0.0
455                         args[a] += e_L * e_R * f_g(loc_h,t,(GOR, h), loc_h, sent) * p_g(r,(GOR, h), (GOR, h), (SEAL, a), loc_h, sent_nums)
456                     for a,sum_a in args.iteritems():
457                         f['RCHOOSE','num',h,a] = sum_a / p_sent
458                         
460             # left attachment:
461             if 'reest_attach' in io.DEBUG:
462                 print "Lattach %s: for s in %s"%(g.numtag(h),sent[0:loc_h])
463             for s in xrange(0, loc_h):
464                 if 'reest_attach' in io.DEBUG:
465                     print "\tfor t in %s"%sent[loc_h:len(sent)]
466                 for t in xrange(loc_h, len(sent)):
467                     c_M = c_g(s,t,(RGOL, h), loc_h, sent) # v_q in L&Y 
468                     f['LCHOOSE','den',h] += c_M
469                     if 'reest_attach' in io.DEBUG:
470                         print "\t\tc_g( %d , %d, %s_, %s, sent)=%.4f"%(s,t,g.numtag(h),loc_h,c_M)
471                     if 'reest_attach' in io.DEBUG:
472                         print "\t\tfor r in %s"%(sent[s:loc_h])
473                     args = {} # for summing w_q's in L&Y, without 1/P_q
474                     for r in xrange(s, loc_h): # s <= r < loc_h <= t
475                         e_R = e_g(r+1, t, (RGOL, h), loc_h, sent)
476                         if 'reest_attach' in io.DEBUG:
477                             print "\t\te_g( %d , %d, %s_, %d, sent)=%.4f"%(r+1,t,g.numtag(h),loc_h,e_R)
478                         for i,a in enumerate(sent_nums[s:r+1]):
479                             loc_a = i+s
480                             e_L = e_g( s , r, (SEAL, a), loc_a, sent)
481                             if a not in args:
482                                 args[a] = 0.0
483                             args[a] += e_L * e_R * f_g(s,t,(RGOL, h), loc_h, sent) * p_g(r,(RGOL, h),(SEAL, a),(RGOL, h),loc_h,sent_nums)
484                     for a,sum_a in args.iteritems():
485                         f['LCHOOSE', 'num',h,a] = sum_a / p_sent 
486     return f
488 def reestimate(g, corpus):
489     ""
490     f = reest_freq(g, corpus)
491     # we want to go through only non-ROOT left-STOPs.. 
492     for r in g.all_rules():
493         reest_rule(r,f, g)
494     return f
497 def reest_rule(r,f, g): # g just for numtag / debug output, remove eventually?
498     "remove 0-prob rules? todo"
499     h = r.head()
500     if r.LHS() == ROOT:
501         return None # not sure what todo yet here
502     if r.L() == STOP or head(r.R()) == h:
503         dir = 'L'
504     elif r.R() == STOP or head(r.L()) == h:
505         dir = 'R'
506     else:
507         raise Exception("Odd rule in reestimation.")
509     p_stopN = f[dir+'NSTOP','den',h]
510     if p_stopN > 0.0:
511         p_stopN = f[dir+'NSTOP','num',h] / p_stopN
513     p_stopA = f[dir+'ASTOP','den',h]
514     if p_stopA > 0.0:
515         p_stopA = f[dir+'ASTOP','num',h] / p_stopA
517     if r.L() == STOP or r.R() == STOP: # stop rules
518         if 'reest' in io.DEBUG:
519             print "p(STOP|%d=%s,%s,N): %.4f (was: %.4f)"%(h,g.numtag(h),dir, p_stopN, r.probN) 
520             print "p(STOP|%d=%s,%s,A): %.4f (was: %.4f)"%(h,g.numtag(h),dir, p_stopA, r.probA) 
521         r.probN = p_stopN
522         r.probA = p_stopA
524     else: # attachment rules
525         pchoose = f[dir+'CHOOSE','den',h]
526         if pchoose > 0.0:
527             if head(r.R()) == h: # left attachment
528                 a = head(r.L())
529             elif head(r.L()) == h: # right attachment
530                 a = head(r.R())
531             pchoose = f[dir+'CHOOSE','num',h,a] / pchoose 
532             r.probN = (1-p_stopN) * pchoose
533             r.probA = (1-p_stopA) * pchoose
534             if 'reest' in io.DEBUG:
535                 print "p(%d=%s|%d=%s,%s): %.4f,\tprobN: %.4f, probA: %.4f"%(a,g.numtag(a),h,g.numtag(h),dir, pchoose,r.probN,r.probA) 
543 ##############################
544 #     testing functions:     #
545 ##############################
547 testcorpus = [s.split() for s in ['det nn vbd c vbd','vbd nn c vbd',
548                                   'det nn vbd',      'det nn vbd c pp', 
549                                   'det nn vbd',      'det vbd vbd c pp', 
550                                   'det nn vbd',      'det nn vbd c vbd', 
551                                   'det nn vbd',      'det nn vbd c vbd', 
552                                   'det nn vbd',      'det nn vbd c vbd', 
553                                   'det nn vbd',      'det nn vbd c pp', 
554                                   'det nn vbd pp',   'det nn vbd', ]]
556 def testgrammar():
557     import loc_h_harmonic
558     reload(loc_h_harmonic)
559     return loc_h_harmonic.initialize(testcorpus)
561 def testreestimation():
562     io.DEBUG.add('reest')
563     g = testgrammar()
564     f = reestimate(g, testcorpus)
565     f_stops = {('LNSTOP', 'den', 3): 12.212773236178391, ('RASTOP', 'den', 2): 4.0, ('RNSTOP', 'num', 4): 2.5553487221351365, ('LNSTOP', 'den', 2): 1.274904052793207, ('LASTOP', 'num', 1): 14.999999999999995, ('RASTOP', 'den', 3): 15.0, ('LASTOP', 'num', 4): 16.65701084787457, ('LASTOP', 'num', 0): 4.1600647714443468, ('LNSTOP', 'den', 4): 6.0170669155897105, ('LASTOP', 'num', 3): 2.7872267638216113, ('LASTOP', 'num', 2): 2.9723139990470515, ('LASTOP', 'den', 2): 4.0, ('RNSTOP', 'den', 3): 12.945787931730905, ('LASTOP', 'den', 3): 14.999999999999996, ('RNSTOP', 'den', 2): 0.0, ('LASTOP', 'den', 0): 8.0, ('RASTOP', 'num', 4): 19.44465127786486, ('RNSTOP', 'den', 1): 3.1966410324085777, ('LASTOP', 'den', 1): 14.999999999999995, ('RASTOP', 'num', 3): 4.1061665495365558, ('RNSTOP', 'den', 0): 4.8282499043902476, ('LNSTOP', 'num', 4): 5.3429891521254289, ('RASTOP', 'num', 2): 4.0, ('LASTOP', 'den', 4): 22.0, ('RASTOP', 'num', 1): 12.400273895299103, ('LNSTOP', 'num', 2): 1.0276860009529487, ('RASTOP', 'num', 0): 3.1717500956097533, ('LNSTOP', 'num', 3): 12.212773236178391, ('RASTOP', 'den', 4): 22.0, ('RNSTOP', 'den', 4): 2.8705211946979836, ('LNSTOP', 'num', 0): 3.8399352285556518, ('LNSTOP', 'num', 1): 0.0, ('RNSTOP', 'num', 0): 4.8282499043902476, ('RNSTOP', 'num', 1): 2.5997261047008959, ('LNSTOP', 'den', 1): 0.0, ('RASTOP', 'den', 0): 8.0, ('RNSTOP', 'num', 2): 0.0, ('LNSTOP', 'den', 0): 4.6540557322109795, ('RASTOP', 'den', 1): 15.0, ('RNSTOP', 'num', 3): 10.893833450463443}
566     for k,v in f_stops.iteritems():
567         if not k in f:
568             pass
569 #             print '''Regression!(?) Something changed in the P_STOP reestimation,
570 # expected f[%s]=%.4f, but %s not in f'''%(k,v,k)
571         elif not f[k] == v:
572             pass
573 #             print '''Regression!(?) Something changed in the P_STOP reestimation,
574 # expected f[%s]=%.4f, got f[%s]=.%4f.'''%(k,v,k,f[k])
577 def testgrammar_a():                            # Non, Adj
578     _h_ = DMV_Rule((SEAL,0), STOP,    ( RGOL,0), 1.0, 1.0) # LSTOP
579     h_S = DMV_Rule(( RGOL,0),(GOR,0),  STOP,    0.4, 0.3) # RSTOP
580     h_A = DMV_Rule(( RGOL,0),(SEAL,0),( RGOL,0),0.2, 0.1) # Lattach
581     h_Aa= DMV_Rule(( RGOL,0),(SEAL,1),( RGOL,0),0.4, 0.6) # Lattach to a
582     h   = DMV_Rule((GOR,0),(GOR,0),(SEAL,0),    1.0, 1.0) # Rattach
583     ha  = DMV_Rule((GOR,0),(GOR,0),(SEAL,1),    1.0, 1.0) # Rattach to a
584     rh  = DMV_Rule(   ROOT,   STOP,    (SEAL,0),  0.9, 0.9) # ROOT
586     _a_ = DMV_Rule((SEAL,1), STOP,    ( RGOL,1), 1.0, 1.0) # LSTOP
587     a_S = DMV_Rule(( RGOL,1),(GOR,1),  STOP,    0.4, 0.3) # RSTOP
588     a_A = DMV_Rule(( RGOL,1),(SEAL,1),( RGOL,1),0.4, 0.6) # Lattach
589     a_Ah= DMV_Rule(( RGOL,1),(SEAL,0),( RGOL,1),0.2, 0.1) # Lattach to h
590     a   = DMV_Rule((GOR,1),(GOR,1),(SEAL,1),    1.0, 1.0) # Rattach
591     ah  = DMV_Rule((GOR,1),(GOR,1),(SEAL,0),    1.0, 1.0) # Rattach to h
592     ra  = DMV_Rule(   ROOT,   STOP,    (SEAL,1),  0.1, 0.1) # ROOT
594     b2  = {}
595     b2[(GOR, 0), 'h'] = 1.0
596     b2[(GOR, 1), 'a'] = 1.0
597     
598     return DMV_Grammar([ h_Aa, ha, a_Ah, ah, ra, _a_, a_S, a_A, a, rh, _h_, h_S, h_A, h ],b2,0,0,0, {0:'h',1:'a'}, {'h':0,'a':1})
599 def oa(s,t,LHS,loc_h):
600     return outer(s,t,LHS,loc_h,testgrammar_a(),'h a'.split())
601 def ia(s,t,LHS,loc_h):
602     return inner(s,t,LHS,loc_h,testgrammar_a(),'h a'.split())
603 def ca(s,t,LHS,loc_h):
604     return c(s,t,LHS,loc_h,testgrammar_a(),'h a'.split())
606 def testgrammar_h():                            # Non, Adj
607     _h_ = DMV_Rule((SEAL,0), STOP,    ( RGOL,0), 1.0, 1.0) # LSTOP
608     h_S = DMV_Rule(( RGOL,0),(GOR,0),  STOP,    0.4, 0.3) # RSTOP
609     h_A = DMV_Rule(( RGOL,0),(SEAL,0),( RGOL,0), 0.6, 0.7) # Lattach
610     h   = DMV_Rule((GOR,0),(GOR,0),(SEAL,0), 1.0, 1.0) # Rattach
611     rh  = DMV_Rule(   ROOT,   STOP,    (SEAL,0), 1.0, 1.0) # ROOT
612     b2  = {}
613     b2[(GOR, 0), 'h'] = 1.0
614     
615     return DMV_Grammar([ rh, _h_, h_S, h_A, h ],b2,0,0,0, {0:'h'}, {'h':0})
616     
618 def testreestimation_h():
619     io.DEBUG.add('reest')
620     g = testgrammar_h()
621     reestimate(g,['h h h'.split()])
624 def regression_tests():
625     def test(wanted, got):
626         if not wanted == got:
627             print "Regression! Should be %s: %s" % (wanted, got)
628             
629     g_dup = testgrammar_h()
630         
631     test("0.120",
632          "%.3f" % inner(0, 1, (SEAL,0), 0, g_dup, 'h h'.split(), {}))
633     
634     test("0.063",
635          "%.3f" % inner(0, 1, (SEAL,0), 1, g_dup, 'h h'.split(), {}))
636         
637     test("0.0498",
638          "%.4f" % inner(0, 2, (SEAL,0), 2, g_dup, 'h h h'.split(), {}))
639     
640     test("0.58" ,
641          "%.2f" % outer(1,2,(1,0),2,testgrammar_h(),'h h h'.split(),{},{}))
643     test("0.1089" ,
644          "%.4f" % outer(0,0,(0,0),0,testgrammar_a(),'h a'.split(),{},{}))
647 if __name__ == "__main__" and False:
648     io.DEBUG.clear()
650 #     import profile
651 #     profile.run('testreestimation()')
653 #    io.DEBUG.add('reest_attach')
654     import timeit
655     print timeit.Timer("loc_h_dmv.testreestimation()",'''import loc_h_dmv
656 reload(loc_h_dmv)''').timeit(1)
657     print "TODO: P_CHOOSE needs to be divided by sum_x(a[x|h])"
660     regression_tests()