3 #import numpy # numpy provides Fast Arrays, for future optimization
6 from common_dmv
import *
# Seal values accepted for rule nodes (CNF_DMV_Rule.__init__ rejects any
# node whose seals() is not listed here).  Presumably this replaces a SEALS
# brought in by the star-import above ("overwriting here") -- confirm.
SEALS = [
    GOR,
    RGOL,
    SEAL,
    NGOR,
    NRGOL,
]
10 if __name__
== "__main__":
11 print "cnf_dmv module tests:"
def make_GO_AT(p_STOP, p_ATTACH):
    '''Combine attachment and stop probabilities into one lookup table.

    For every key (a, h, dir) of p_ATTACH and for both NON and ADJ, the
    result contains

        p_GO_AT[a, h, dir, adj] = p_ATTACH[a, h, dir] * (1 - p_STOP[h, dir, adj])

    p_STOP must hold an entry for (h, dir, NON) and (h, dir, ADJ) for each
    (h, dir) that occurs in p_ATTACH; a missing entry raises KeyError.

    Returns a new dict; neither argument is modified.
    '''
    p_GO_AT = {}
    # .items() instead of .iteritems(): the latter only exists on Python 2
    # dicts, while .items() is available on every mapping.
    for (a, h, direction), p_ah in p_ATTACH.items():
        for adj in (NON, ADJ):
            p_GO_AT[a, h, direction, adj] = p_ah * (1 - p_STOP[h, direction, adj])
    return p_GO_AT
20 class CNF_DMV_Grammar(io
.Grammar
):
24 p_STOP, p_ROOT, p_ATTACH, p_terminals
25 These are changed in the Maximation step, then used to set the
26 new probabilities of each CNF_DMV_Rule.
28 __p_rules is private, but we can still say stuff like:
29 for r in g.all_rules():
30 r.prob = (1-p_STOP[...]) * p_ATTACH[...]
34 for r
in self
.all_rules():
35 str += "%s\n" % r
.__str
__(self
.numtag
)
39 return [ROOT
] + [(s_h
,h
)
40 for h
in self
.headnums()
def sent_rules(self, sent_nums):
    '''Return the rules whose daughters both fit the given sentence.

    Candidates are collected per LHS through self.arg_rules(LHS, sent_nums).
    A candidate is kept only when the POS of its left daughter and the POS
    of its right daughter are each either in sent_nums or equal to
    POS(STOP).  sent_nums itself is not modified.
    '''
    allowed = sent_nums + [POS(STOP)]
    selected = []
    for lhs in self.LHSs():
        for rule in self.arg_rules(lhs, sent_nums):
            if POS(rule.L()) not in allowed:
                continue
            if POS(rule.R()) not in allowed:
                continue
            selected.append(rule)
    return selected
def mothersR(self, w_node, argnums):
    '''For all LHS and x, return all rules of the form 'LHS->x w_node'.

    The rules having w_node as right daughter are computed once per w_node
    and cached in self.__mothersR; nothing in this method invalidates that
    cache.  Of those, only the rules whose left daughter has a POS that is
    in argnums, or equals POS(STOP), are returned.

    argnums is left untouched.  Earlier code appended POS(STOP) to the
    caller's list on every call, so the list kept growing.
    '''
    if w_node not in self.__mothersR:
        self.__mothersR[w_node] = [r for LHS in self.LHSs()
                                   for r in self.rules(LHS)
                                   if r.R() == w_node]
    # Work on a copy so that the caller's list is not mutated.
    allowed = list(argnums) + [POS(STOP)]
    return [r for r in self.__mothersR[w_node]
            if POS(r.L()) in allowed]
def mothersL(self, w_node, argnums):
    '''For all LHS and x, return all rules of the form 'LHS->w_node x'.

    The rules having w_node as left daughter are computed once per w_node
    and cached in self.__mothersL; nothing in this method invalidates that
    cache.  Of those, only the rules whose right daughter has a POS that is
    in argnums, or equals POS(STOP), are returned.

    argnums is left untouched.  Earlier code appended POS(STOP) to the
    caller's list on every call, so the list kept growing.
    '''
    if w_node not in self.__mothersL:
        self.__mothersL[w_node] = [r for LHS in self.LHSs()
                                   for r in self.rules(LHS)
                                   if r.L() == w_node]
    # Work on a copy so that the caller's list is not mutated.
    allowed = list(argnums) + [POS(STOP)]
    return [r for r in self.__mothersL[w_node]
            if POS(r.R()) in allowed]
def arg_rules(self, LHS, argnums):
    '''Return the rules of LHS that have a daughter whose POS is in argnums.

    A rule from self.rules(LHS) qualifies as soon as one daughter matches;
    the order of self.rules(LHS) is preserved.
    '''
    matching = []
    for rule in self.rules(LHS):
        # The right daughter is tried first, the left one only if that fails.
        if POS(rule.R()) in argnums:
            matching.append(rule)
        elif POS(rule.L()) in argnums:
            matching.append(rule)
    return matching
def make_all_rules(self):
    '''Build the rules for every LHS and hand them to self.new_rules().

    For each LHS in self.LHSs(), self._make_rules(LHS, self.headnums()) is
    called; the results are concatenated in that order into a single list.
    '''
    generated = []
    for lhs in self.LHSs():
        generated.extend(self._make_rules(lhs, self.headnums()))
    self.new_rules(generated)
83 def _make_rules(self
, LHS
, argnums
):
84 '''This is where the CNF grammar is defined. Also, s_dir_typ shows how
85 useful it'd be to split up the seals into direction and
89 return [CNF_DMV_Rule(LEFT
, LHS
, (SEAL
,h
), STOP
, self
.p_ROOT
[h
])
90 for h
in set(argnums
)]
93 return [] # only terminals from here on
94 s_dir_type
= { # seal of LHS
95 RGOL
: (RIGHT
, 'STOP'), NGOR
: (RIGHT
, 'ATTACH'),
96 SEAL
: (LEFT
, 'STOP'), NRGOL
: (LEFT
, 'ATTACH') }
97 dir_s_adj
= { # seal of h_daughter
98 RIGHT
: [(GOR
, True),(NGOR
, False)] ,
99 LEFT
: [(RGOL
,True),(NRGOL
,False)] }
100 dir,type = s_dir_type
[s_h
]
102 'ATTACH': [CNF_DMV_Rule(dir, LHS
, (s
, h
), (SEAL
,a
), self
.p_GO_AT
[a
,h
,dir,adj
])
103 for a
in set(argnums
) if (a
,h
,dir) in self
.p_ATTACH
104 for s
, adj
in dir_s_adj
[dir]] ,
105 'STOP': [CNF_DMV_Rule(dir, LHS
, (s
, h
), STOP
, self
.p_STOP
[h
,dir,adj
])
106 for s
, adj
in dir_s_adj
[dir]] }
110 def __init__(self
, numtag
, tagnum
, p_ROOT
, p_STOP
, p_ATTACH
, p_terminals
):
111 io
.Grammar
.__init
__(self
, numtag
, tagnum
, [], p_terminals
)
113 self
.p_ATTACH
= p_ATTACH
115 self
.p_GO_AT
= make_GO_AT(self
.p_STOP
, self
.p_ATTACH
)
116 self
.make_all_rules()
121 class CNF_DMV_Rule(io
.CNF_Rule
):
122 '''A single CNF rule in the PCFG, of the form
124 where LHS, L and R are 'nodes', eg. of the form (seals, head).
132 Different rule-types have different probabilities associated with
133 them, see formulas.pdf
136 return seals(self
.LHS())
139 return POS(self
.LHS())
141 def __init__(self
, dir, LHS
, h_daughter
, a_daughter
, prob
):
144 L
, R
= a_daughter
, h_daughter
146 L
, R
= h_daughter
, a_daughter
148 raise ValueError, "dir must be LEFT or RIGHT, given: %s"%dir
149 for b_h
in [LHS
, L
, R
]:
150 if seals(b_h
) not in SEALS
:
151 raise ValueError("seals must be in %s; was given: %s"
152 % (SEALS
, seals(b_h
)))
153 io
.CNF_Rule
.__init
__(self
, LHS
, L
, R
, prob
)
156 "'undefined' for ROOT"
157 if self
.__dir
== LEFT
:
158 return seals(self
.R()) == RGOL
160 return seals(self
.L()) == GOR
162 def __str__(self
, tag
=lambda x
:x
):
163 if self
.adj(): adj_str
= "adj"
164 else: adj_str
= "non_adj"
165 if self
.LHS() == ROOT
: adj_str
= ""
166 return "%s --> %s %s\t[%.2f] %s" % (node_str(self
.LHS(), tag
),
167 node_str(self
.L(), tag
),
168 node_str(self
.R(), tag
),
178 ###################################
179 # dmv-specific version of inner() #
180 ###################################
181 def inner(i
, j
, LHS
, g
, sent
, ichart
={}):
182 ''' A CNF rewrite of io.inner(), to take STOP rules into accord. '''
186 sent_nums
= g
.sent_nums(sent
)
190 "Tabs for debug output"
192 if (i
, j
, LHS
) in ichart
:
194 print "%s*= %.4f in ichart: i:%d j:%d LHS:%s" % (tab(), ichart
[i
, j
, LHS
], i
, j
, node_str(LHS
))
195 return ichart
[i
, j
, LHS
]
197 # if seals(LHS) == RGOL then we have to STOP first
198 if i
== j
-1 and seals(LHS
) == GOR
:
199 if (LHS
, O(i
,j
)) in g
.p_terminals
:
200 prob
= g
.p_terminals
[LHS
, O(i
,j
)] # "b[LHS, O(s)]" in Lari&Young
204 print "%sLACKING TERMINAL:" % tab()
206 print "%s*= %.4f (terminal: %s -> %s)" % (tab(),prob
, node_str(LHS
), O(i
,j
))
209 p
= 0.0 # "sum over j,k in a[LHS,j,k]"
210 for rule
in g
.arg_rules(LHS
, sent_nums
):
212 print "%ssumming rule %s i:%d j:%d" % (tab(),rule
,i
,j
)
215 # if it's a STOP rule, rewrite for the same xrange:
216 if (L
== STOP
) or (R
== STOP
):
218 pLR
= e(i
, j
, R
, n_t
+1)
220 pLR
= e(i
, j
, L
, n_t
+1)
223 print "%sp= %.4f (STOP)" % (tab(), p
)
225 elif j
> i
+1 and seals(LHS
) != GOR
:
226 # not a STOP, attachment rewrite:
227 for k
in xtween(i
, j
): # i<k<j
228 p_L
= e(i
, k
, L
, n_t
+1)
229 p_R
= e(k
, j
, R
, n_t
+1)
230 p
+= rule
.p() * p_L
* p_R
232 print "%sp= %.4f (ATTACH, p_L:%.4f, p_R:%.4f, rule:%.4f)" % (tab(), p
,p_L
,p_R
,rule
.p())
233 ichart
[i
, j
, LHS
] = p
237 inner_prob
= e(i
,j
,LHS
, 0)
239 print debug_ichart(g
,sent
,ichart
)
241 # end of cnf_dmv.inner(i, j, LHS, g, sent, ichart={})
def debug_ichart(g, sent, ichart):
    '''Return a printable dump of the non-dict entries of ichart.

    Every key of ichart is unpacked as (i, j, LHS).  Entries whose value is
    a dict are skipped (the 'tree' entries).  Each remaining entry gives one
    line showing node_str(LHS, g.numtag), sent[i], sent[j-1] and the value
    formatted with %.4f.  The lines are framed by a header and a footer
    line.  Entries appear in the iteration order of ichart.
    '''
    lines = ["---ICHART:---\n"]
    # .items() instead of the Python-2-only .iteritems().
    for (i, j, LHS), v in ichart.items():
        if isinstance(v, dict):  # skip 'tree'
            continue
        lines.append("%s -> %s ... %s: \t%.4f\n"
                     % (node_str(LHS, g.numtag), sent[i], sent[j-1], v))
    lines.append("---ICHART:end---\n")
    # Joined once at the end; this also avoids a local named `str`, which
    # would shadow the builtin.
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Return inner(0, len(sent), ROOT, g, sent, ichart) for the whole sentence.

    g      -- the grammar, passed through to inner()
    sent   -- the sentence (a sequence; only len(sent) is used here)
    ichart -- optional dict used by inner() as its chart/cache.  When it is
              omitted a fresh dict is created for this call.

    A `{}` default would be a single dict shared by every call; since the
    chart is keyed on (i, j, LHS) only, values cached for one sentence or
    grammar would then be returned for another.
    '''
    if ichart is None:
        ichart = {}
    return inner(0, len(sent), ROOT, g, sent, ichart)
259 #######################################
260 # cnf_dmv-specific version of outer() #
261 #######################################
262 def outer(i
,j
,w_node
, g
, sent
, ichart
={}, ochart
={}):
264 # or we could just look it up in ichart, assuming ichart to be done
265 return inner(i
, j
, LHS
, g
, sent
, ichart
)
267 sent_nums
= g
.sent_nums(sent
)
268 if POS(w_node
) not in sent_nums
[i
:j
]:
269 # sanity check, w must be able to dominate sent[i:j]
273 if (i
,j
,w_node
) in ochart
:
274 return ochart
[(i
, j
, w_node
)]
276 if i
== 0 and j
== len(sent
):
278 else: # ROOT may only be used on full sentence
279 return 0.0 # but we may have non-ROOTs over full sentence too
283 for rule
in g
.mothersL(w_node
, sent_nums
): # rule.L() == w_node
284 if 'OUTER' in DEBUG
: print "w_node:%s (L) ; %s"%(node_str(w_node
),rule
)
286 p0
= f(i
,j
,rule
.LHS()) * rule
.p()
287 if 'OUTER' in DEBUG
: print p0
290 for k
in xgt(j
,sent
): # i<j<k
291 p0
= f(i
,k
, rule
.LHS() ) * rule
.p() * e(j
,k
, rule
.R() )
292 if 'OUTER' in DEBUG
: print p0
295 for rule
in g
.mothersR(w_node
, sent_nums
): # rule.R() == w_node
296 if 'OUTER' in DEBUG
: print "w_node:%s (R) ; %s"%(node_str(w_node
),rule
)
298 p0
= f(i
,j
,rule
.LHS()) * rule
.p()
299 if 'OUTER' in DEBUG
: print p0
302 for k
in xlt(i
): # k<i<j
303 p0
= e(k
,i
, rule
.L() ) * rule
.p() * f(k
,j
, rule
.LHS() )
304 if 'OUTER' in DEBUG
: print p0
307 ochart
[i
,j
,w_node
] = p
312 # end outer(i,j,w_node, g,sent, ichart,ochart)
316 ##############################
317 # reestimation, todo: #
318 ##############################
def reest_zeros(rules):
    '''Return a frequency table with every counter set to 0.0.

    The table has one ('den', ROOT) entry plus, for each rule r in rules,
    the entries ('num', r.LHS(), r.L(), r.R()) and
    ('den', r.LHS(), r.L(), r.R()).
    '''
    counts = {('den', ROOT): 0.0}
    for rule in rules:
        key = (rule.LHS(), rule.L(), rule.R())
        counts[('num',) + key] = 0.0
        counts[('den',) + key] = 0.0
    return counts
326 def reest_freq(g
, corpus
):
327 ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
328 f
= reest_zeros(g
.all_rules())
332 p_sent
= None # 50 % speed increase on storing this locally
334 def c_g(i
,j
,LHS
,sent
):
337 return e_g(i
,j
,LHS
,sent
) * f_g(i
,j
,LHS
,sent
) / p_sent
339 def w1_g(i
,j
,rule
,sent
): # unary (stop) rules, LHS -> child_node
340 if rule
.L() == STOP
: child
= rule
.R()
341 elif rule
.R() == STOP
: child
= rule
.L()
342 else: raise ValueError, "expected a stop rule: %s"%(rule
,)
344 if p_sent
== 0.0: return 0.0
346 p_out
= f_g(i
,j
,rule
.LHS(),sent
)
347 if p_out
== 0.0: return 0.0
349 return rule
.p() * e_g(i
,j
,child
,sent
) * p_out
/ p_sent
351 def w_g(i
,j
,rule
,sent
):
352 if p_sent
== 0.0 or i
+1 == j
: return 0.0
354 p_out
= f_g(i
,j
,rule
.LHS(),sent
)
355 if p_out
== 0.0: return 0.0
358 for k
in xtween(i
,j
):
359 p
+= rule
.p() * e_g(i
,k
,rule
.L(),sent
) * e_g(k
,j
,rule
.R(),sent
) * p_out
362 def f_g(i
,j
,LHS
,sent
):
363 if (i
,j
,LHS
) in ochart
:
365 return ochart
[i
,j
,LHS
]
367 return outer(i
,j
,LHS
,g
,sent
,ichart
,ochart
)
369 def e_g(i
,j
,LHS
,sent
):
370 if (i
,j
,LHS
) in ichart
:
372 return ichart
[i
,j
,LHS
]
374 return inner(i
,j
,LHS
,g
,sent
,ichart
)
376 for s_num
,sent
in enumerate(corpus
):
377 if s_num
%5==0: print "s.num %d"%s_num
,
378 if 'REEST' in DEBUG
: print sent
381 # since we keep re-using p_sent, it seems better to have
382 # sentences as the outer loop; o/w we'd have to keep every chart
383 p_sent
= inner_sent(g
, sent
, ichart
)
385 sent_nums
= g
.sent_nums(sent
)
386 sent_rules
= g
.sent_rules(sent_nums
)
388 LHS
, L
, R
= r
.LHS(), r
.L(), r
.R()
389 if 'REEST' in DEBUG
: print r
391 f
['num',LHS
,L
,R
] += r
.p() * e_g(0, len(sent
), R
, sent
)
392 f
['den',ROOT
] += p_sent
393 continue # !!! o/w we add wrong values to it below
394 if L
== STOP
or R
== STOP
:
398 for i
in xlt(len(sent
)):
399 for j
in xgt(i
, sent
):
400 f
['num',LHS
,L
,R
] += w(i
,j
, r
, sent
)
401 f
['den',LHS
,L
,R
] += c_g(i
,j
, LHS
, sent
) # v_q
405 def reestimate(g
, corpus
):
406 f
= reest_freq(g
, corpus
)
407 print "applying f to rules"
408 for r
in g
.all_rules():
410 r
.prob
= f
['den',ROOT
]
412 r
.prob
= f
['den',r
.LHS(),r
.L(),r
.R()]
414 r
.prob
= f
['num',r
.LHS(),r
.L(),r
.R()] / r
.prob
418 ##############################
419 # Testing functions: #
420 ##############################
422 # make sure we use the same data:
423 from loc_h_dmv
import testcorpus
427 return cnf_harmonic
.initialize(testcorpus
)
429 def testreestimation():
430 from loc_h_dmv
import testcorpus
432 g
= reestimate(g
, testcorpus
[0:4])
435 def testgrammar_a(): # Non, Adj
436 _h_
= CNF_DMV_Rule((SEAL
,0), STOP
, ( RGOL
,0), 1.0, 1.0) # LSTOP
437 h_S
= CNF_DMV_Rule(( RGOL
,0),(GOR
,0), STOP
, 0.4, 0.3) # RSTOP
438 h_A
= CNF_DMV_Rule(( RGOL
,0),(SEAL
,0),( RGOL
,0),0.2, 0.1) # Lattach
439 h_Aa
= CNF_DMV_Rule(( RGOL
,0),(SEAL
,1),( RGOL
,0),0.4, 0.6) # Lattach to a
440 h
= CNF_DMV_Rule((GOR
,0),(GOR
,0),(SEAL
,0), 1.0, 1.0) # Rattach
441 ha
= CNF_DMV_Rule((GOR
,0),(GOR
,0),(SEAL
,1), 1.0, 1.0) # Rattach to a
442 rh
= CNF_DMV_Rule( ROOT
, STOP
, (SEAL
,0), 0.9, 0.9) # ROOT
444 _a_
= CNF_DMV_Rule((SEAL
,1), STOP
, ( RGOL
,1), 1.0, 1.0) # LSTOP
445 a_S
= CNF_DMV_Rule(( RGOL
,1),(GOR
,1), STOP
, 0.4, 0.3) # RSTOP
446 a_A
= CNF_DMV_Rule(( RGOL
,1),(SEAL
,1),( RGOL
,1),0.4, 0.6) # Lattach
447 a_Ah
= CNF_DMV_Rule(( RGOL
,1),(SEAL
,0),( RGOL
,1),0.2, 0.1) # Lattach to h
448 a
= CNF_DMV_Rule((GOR
,1),(GOR
,1),(SEAL
,1), 1.0, 1.0) # Rattach
449 ah
= CNF_DMV_Rule((GOR
,1),(GOR
,1),(SEAL
,0), 1.0, 1.0) # Rattach to h
450 ra
= CNF_DMV_Rule( ROOT
, STOP
, (SEAL
,1), 0.1, 0.1) # ROOT
452 p_rules
= [ h_Aa
, ha
, a_Ah
, ah
, ra
, _a_
, a_S
, a_A
, a
, rh
, _h_
, h_S
, h_A
, h
]
456 b
[(GOR
, 0), 'h'] = 1.0
457 b
[(GOR
, 1), 'a'] = 1.0
459 return CNF_DMV_Grammar({0:'h',1:'a'}, {'h':0,'a':1},
464 p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
= {},{},{},{}
466 p_STOP
[h
,LEFT
,NON
] = 1.0
467 p_STOP
[h
,LEFT
,ADJ
] = 1.0
468 p_STOP
[h
,RIGHT
,NON
] = 0.4
469 p_STOP
[h
,RIGHT
,ADJ
] = 0.3
470 p_ATTACH
[h
,h
,LEFT
] = 1.0 # not used
471 p_ATTACH
[h
,h
,RIGHT
] = 1.0 # not used
473 p_terminals
[(GOR
, 0), 'h'] = 1.0
475 g
= CNF_DMV_Grammar({h
:'h'}, {'h':h
}, p_ROOT
, p_STOP
, p_ATTACH
, p_terminals
)
477 g
.p_GO_AT
[h
,h
,LEFT
,NON
] = 0.6 # these probabilities are impossible
478 g
.p_GO_AT
[h
,h
,LEFT
,ADJ
] = 0.7 # so add them manually...
479 g
.p_GO_AT
[h
,h
,RIGHT
,NON
] = 1.0
480 g
.p_GO_AT
[h
,h
,RIGHT
,ADJ
] = 1.0
485 def testreestimation_h():
488 return reestimate(g
,['h h h'.split()])
490 def regression_tests():
491 test("0.1830", # = .120 + .063, since we have no loc_h
492 "%.4f" % inner(0, 2, (SEAL
,0), testgrammar_h(), 'h h'.split(), {}))
494 test("0.1842", # = .0498 + .1092 +.0252
495 "%.4f" % inner(0, 3, (SEAL
,0), testgrammar_h(), 'h h h'.split(), {}))
497 "%.4f" % inner_sent(testgrammar_h(), 'h h h'.split()))
500 "%.2f" % outer(1, 3, ( RGOL
,0), testgrammar_h(),'h h h'.split(),{},{}))
502 "%.2f" % outer(1, 3, (NRGOL
,0), testgrammar_h(),'h h h'.split(),{},{}))
505 if __name__
== "__main__":
509 # profile.run('testreestimation()')
511 # DEBUG.add('reest_attach')
513 # print timeit.Timer("cnf_dmv.testreestimation_h()",'''import cnf_dmv
514 # reload(cnf_dmv)''').timeit(1)
516 if __name__
== "__main__":