some late comments :)
[gostyle.git] / game_to_vec.py
blobb77b5814b4ee44b0293d0e1e6cd368c5e9b09db9
1 import logging
2 import subprocess
3 from subprocess import PIPE
5 import os
6 import sys
7 from os import remove
8 from os.path import abspath, exists
9 import math
11 import itertools
12 from itertools import chain
13 import shutil
14 import re
15 import numpy
17 import utils
18 from utils import misc, godb_models, godb_session
19 from utils.db_cache import declare_pure_function, cache_result
20 from utils.misc import first_true_pred, partial_right, partial
21 from utils.utils import head
22 from utils.colors import *
23 from utils.godb_models import ProcessingError
26 from result_file import ResultFile, get_output_resultfile, get_output_resultpair
27 from config import OUTPUT_DIR
28 import pachi
29 from pachi import scan_raw_patterns, generate_spatial_dictionary
# Contains code for conversion of a game (or list of them) into a vector,
# using pachi.  It can also produce .tab files to be used by the Orange
# data-mining framework.

# Line format of summarized pattern files: "<count> <pattern>", where the
# count may be preceded by whitespace, e.g. "      3 (s:20)".
# A raw string, so that \s and \d are not treated as (invalid) string escapes.
pat_file_regexp = r'^\s*(\d+)\s*(.+)$'
def _make_interval_annotations(l, varname):
    """Build human-readable labels for the buckets defined by boundaries @l.

    The integer boundaries b_1 < ... < b_k split the integers into k + 1
    buckets: X <= b_1, b_1 < X <= b_2, ..., X > b_k.  A bucket containing
    exactly one value is labelled with '==' instead.

    @l       -- sequence (list, tuple, ...) of strictly increasing boundaries
    @varname -- name of the variable used in the labels

    Returns a list of len(l) + 1 labels, or ["any " + varname] for empty @l.
    Raises ValueError if a boundary is not an integer (as judged by
    misc.is_int) or if the boundaries are not strictly increasing.

    >>> _make_interval_annotations([10,11,12,13], 'X')
    ['X <= 10', 'X == 11', 'X == 12', 'X == 13', 'X > 13']
    >>> _make_interval_annotations([22], 'X')
    ['X <= 22', 'X > 22']
    >>> _make_interval_annotations([-1, 20], 'X')
    ['X <= -1', '-1 < X <= 20', 'X > 20']
    """
    points = list(l)
    if not all(misc.is_int(x) for x in points):
        raise ValueError("Interval boundaries must be a number.")
    if not points:
        return ["any " + varname]
    # unsorted or repeated boundaries would produce empty/overlapping buckets
    if any(nxt <= cur for cur, nxt in zip(points, points[1:])):
        raise ValueError("Interval boundaries must be strictly increasing.")

    # the first bucket has no left boundary
    annots = ["%s <= %d" % (varname, points[0])]
    for prev, point in zip(points, points[1:]):
        if point == prev + 1:
            # the interval has size 1 -- specify it precisely
            annots.append("%s == %d" % (varname, point))
        else:
            annots.append("%d < %s <= %d" % (prev, varname, point))
    # the last bucket has no right boundary
    annots.append("%s > %d" % (varname, points[-1]))
    return annots
## game -> BlackWhite( vector_black, vector_white )
class BWBdistVectorGenerator:
    """Game vector generator: histograms of the moves' distance from the border.

    For each player there is one histogram per game phase (given by
    @by_moves), each with the buckets given by @by_line; see __call__.
    """
    def __init__(self, by_line=(2, 3, 4), by_moves=(26, 76)):
        # store list copies: the defaults stay immutable, the caller's objects
        # are not shared, and repr() is the same as with list arguments
        self.by_line = list(by_line)
        self.by_moves = list(by_moves)

        if any(x % 2 for x in self.by_moves):
            logging.warning("BWBdistVectorGenerator called with odd number of moves "
                            "specifying the hist size => this means that the players "
                            "wont have the same number of moves in the buckets!!")

        # nice annotations
        line_annots = _make_interval_annotations(self.by_line, 'bdist')
        move_annots = _make_interval_annotations(self.by_moves, 'move')

        # move-major order, matching the bucket index computed in __call__
        self.annotations = ["(bdist histogram: %s, %s)" % (m, b)
                            for m, b in itertools.product(move_annots, line_annots)]
        self.types = ["continuous"] * len(self.annotations)

        def leq_fac(val):
            return lambda x: x <= val

        # predicates giving bucket coordinate; the last one catches the rest
        self.line_preds = [leq_fac(line) for line in self.by_line] + [lambda line: True]
        self.move_preds = [leq_fac(movenum) for movenum in self.by_moves] + [lambda movenum: True]

    def __repr__(self):
        return 'BWBdistVectorGenerator(by_line=%s, by_moves=%s)' % (repr(self.by_line),
                                                                    repr(self.by_moves))

    def __call__(self, game):
        """For a game, create histograms of the moves' distance from the border.

        The histograms' granularity is specified by @by_line and @by_moves.

        @by_moves makes a separate histogram for each game phase, e.g.:
            by_moves=[]        one histogram for the whole game
            by_moves=[50]      two histograms: the first 50 moves (inclusive),
                               and the rest of the game
            by_moves=[26, 76]  three histograms:
                               X <= 26       ~ opening
                               26 < X <= 76  ~ middle game
                               76 < X        ~ end game
        NOTE: the @by_moves boundaries should be even, so that both players
        have the same number of moves counted in every phase.

        @by_line specifies the buckets of each histogram, e.g.:
            by_line=[3]        2 buckets: moves on the first three lines, the rest
            by_line=[3, 4, 5]  4 buckets: X <= 3, X == 4, X == 5, X > 5

        Moves for which pat.first_payload('border') raises StopIteration are
        skipped.  Returns BlackWhite(black_histogram, white_histogram), each
        flattened move-major in the order of self.annotations.
        """
        # scan game, ignore spatials
        col_pat = pachi.scan_raw_patterns(game, patargs='xspat=0', skip_empty=False)

        # one flattened histogram per color
        buckets = {}
        for color in PLAYER_COLORS:
            buckets[color] = numpy.zeros(len(self.line_preds) * len(self.move_preds))

        for movenum, (color, pat) in enumerate(col_pat):
            try:
                bdist = pat.first_payload('border')
            except StopIteration:
                continue

            # X and Y coordinates
            line_bucket = first_true_pred(self.line_preds, bdist + 1)  # line = bdist + 1
            move_bucket = first_true_pred(self.move_preds, movenum + 1)  # movenum is counted from zero

            # histogram[color][X][Y] += 1, flattened move-major
            buckets[color][line_bucket + move_bucket * len(self.line_preds)] += 1

        return BlackWhite(buckets[PLAYER_COLOR_BLACK], buckets[PLAYER_COLOR_WHITE])
## game -> BlackWhite( vector_black, vector_white )
class BWLocalSeqVectorGenerator:
    """Game vector generator describing how local sequences end.

    For each player the vector is
        [#sequences he started that ended in sente,
         #sequences he started that ended in gote,
         sente - gote].
    """
    def __init__(self, local_threshold=5):
        self.local_threshold = local_threshold
        self.annotations = ['(local seq < %d: %s)' % (local_threshold, outcome)
                            for outcome in ('sente', 'gote', 'sente - gote')]
        self.types = ["continuous"] * len(self.annotations)

    def __repr__(self):
        return 'BWLocalSeqVectorGenerator(local_threshold=%r)' % (self.local_threshold,)

    def __call__(self, game):
        """Count sente / gote endings of local sequences for both players.

        A move is local when its pattern has a 'cont' feature whose payload
        (the gridcular distance to the previous move) is <= self.local_threshold.
        NOTE(review): the annotations say '<' while the comparison is '<='.
        A sequence still running when the game ends is not counted.

        Returns BlackWhite(black_vector, white_vector).
        """
        # scan game, ignore spatials
        col_pat = pachi.scan_raw_patterns(game, patargs='xspat=0', skip_empty=False)

        sente_idx, gote_idx, diff_idx = 0, 1, 2
        count = dict((c, numpy.zeros(3)) for c in (PLAYER_COLOR_BLACK, PLAYER_COLOR_WHITE))

        was_local = False
        starter = None
        for color, pat in col_pat:
            is_local = (pat.first_payload('cont') <= self.local_threshold
                        if pat.has_feature('cont') else False)

            if is_local and not was_local:
                # a sequence starts: this player had to answer locally,
                # so it was the opponent who started it
                starter = the_other_color(color)
            elif was_local and not is_local:
                # a sequence ends: if its starter is the one to play
                # elsewhere he kept sente, otherwise he lost tempo (gote)
                count[starter][sente_idx if color == starter else gote_idx] += 1

            was_local = is_local

        for color in PLAYER_COLORS:
            vec = count[color]
            vec[diff_idx] = vec[sente_idx] - vec[gote_idx]

        return BlackWhite(count[PLAYER_COLOR_BLACK], count[PLAYER_COLOR_WHITE])
## game -> BlackWhite( vector_black, vector_white )
class BWCaptureVectorGenerator:
    """Game vector generator: stones captured and lost in each game phase."""
    def __init__(self, by_moves=(26, 76), offset=6, payload_size=4):
        """The params @offset and @payload_size have to be the constants from
        pachi/pattern.h, corresponding to:
            offset = PF_CAPTURE_COUNTSTONES
            payload_size = CAPTURE_COUNTSTONES_PAYLOAD_SIZE

        @by_moves are the move-number boundaries of the game phases, e.g.
        [26, 76] gives the phases X <= 26, 26 < X <= 76 and X > 76.
        """
        self.offset = offset
        self.payload_size = payload_size
        # list copy: immutable default, unshared caller object, same repr()
        self.by_moves = list(by_moves)

        if any(x % 2 for x in self.by_moves):
            logging.warning("BWCaptureVectorGenerator called with odd number of moves "
                            "specifying the hist size => this means that the players "
                            "wont have the same number of moves in the buckets!!")

        # nice annotations
        capture_annots = ['captured', 'lost', 'difference']
        move_annots = _make_interval_annotations(self.by_moves, 'move')

        # phase-major order, matching the layout built in __call__
        self.annotations = ["(capture histogram: %s, %s)" % (m, b)
                            for m, b in itertools.product(move_annots, capture_annots)]
        self.types = ["continuous"] * len(self.annotations)

        def leq_fac(val):
            return lambda x: x <= val

        # predicates giving bucket coordinate; the last one catches the rest
        self.move_preds = [leq_fac(move) for move in self.by_moves] + [lambda movenum: True]

    def __repr__(self):
        return 'BWCaptureVectorGenerator(by_moves=%s, offset=%s, payload_size=%s)' % (
            repr(self.by_moves), repr(self.offset), repr(self.payload_size))

    def __call__(self, game):
        """Count captured stones per game phase.

        Returns BlackWhite(black_vector, white_vector); for every phase the
        vector holds [stones captured, stones lost, captured - lost], in the
        order of self.annotations.
        """
        # scan game, ignore spatials
        col_pat = pachi.scan_raw_patterns(game, patargs='xspat=0', skip_empty=False)

        n_phases = len(self.move_preds)
        mask = 2 ** self.payload_size - 1
        buckets = dict((color, numpy.zeros(n_phases)) for color in PLAYER_COLORS)

        for movenum, (color, pat) in enumerate(col_pat):
            if not pat.has_feature('capture'):
                continue
            # the number of captured stones is a bit field of the payload
            captured = (pat.first_payload('capture') >> self.offset) & mask
            move_bucket = first_true_pred(self.move_preds, movenum + 1)  # movenum counted from zero
            buckets[color][move_bucket] += captured

        ret = dict((color, numpy.zeros(3 * n_phases)) for color in PLAYER_COLORS)
        for mp in range(n_phases):
            for color in PLAYER_COLORS:
                mine = buckets[color][mp]
                # what I lost is what the opponent captured
                theirs = buckets[the_other_color(color)][mp]
                ret[color][3 * mp:3 * mp + 3] = (mine, theirs, mine - theirs)

        return BlackWhite(ret[PLAYER_COLOR_BLACK], ret[PLAYER_COLOR_WHITE])
## game -> BlackWhite( vector_black, vector_white )
class BWWinStatVectorGenerator:
    """Game vector generator describing how the game was won or lost.

    The winner gets [wins by points, wins by resign, wp - wr, 0, 0, 0],
    the loser gets [0, 0, 0, lost by points, lost by resign, lp - lr].
    """
    def __init__(self):
        self.annotations = [
            '(wins by points)',
            '(wins by resign)',
            '(wp - wr)',
            '(lost by points)',
            '(lost by resign)',
            '(lp - lr)',
        ]
        self.types = ["continuous"] * len(self.annotations)

    def __repr__(self):
        # NOTE(review): the trailing '2' does not match the class name; it is
        # kept because the repr is part of error messages (and possibly of
        # cache keys -- confirm before changing).
        return 'BWWinStatVectorGenerator2()'

    def __call__(self, game):
        """Encode the result stored in the SGF 'RE' property for both players.

        Raises ProcessingError for draws, missing/unparsable results, wins
        by forfeit or on time, and non-numeric point margins.
        """
        result = str(game.sgf_header.get('RE', '0'))
        if result.lower() in ('0', 'jigo', 'draw'):
            raise ProcessingError(repr(self) + " Jigo")

        match = re.match(r'^([BW])\+(.*)$', result)
        if match is None:
            raise ProcessingError(repr(self) + ' Could not find result sgf tag.')

        winner, how = match.groups()
        how_lc = how.lower()
        if how_lc.startswith(('f', 't')):  # forfeit, time
            raise ProcessingError(repr(self) + ' Forfeit, time.')

        if how_lc.startswith('r'):
            # by resign
            win_stat = [0, 1, -1]
        else:
            # by points -- the margin is only validated, not used
            try:
                float(how)
            except ValueError:
                raise ProcessingError(repr(self) + ' Points not float.')
            win_stat = [1, 0, 1]

        lose_stat = [0, 0, 0]
        winner_vec = numpy.array(win_stat + lose_stat)
        loser_vec = numpy.array(lose_stat + win_stat)
        if winner == 'B':
            return BlackWhite(winner_vec, loser_vec)
        return BlackWhite(loser_vec, winner_vec)
## game -> BlackWhite( vector_black, vector_white )
class BWWinPointsStatVectorGenerator:
    """Game vector generator: the point margin of a game won on points.

    The winner gets [margin, 0], the loser gets [0, margin].
    """
    def __init__(self):
        self.annotations = ['(wins #points)', '(loses #points)']
        self.types = ["continuous"] * len(self.annotations)

    def __repr__(self):
        # NOTE(review): the trailing '2' does not match the class name; kept
        # because the repr is part of error messages (and possibly of cache
        # keys -- confirm before changing).
        return 'BWWinPointsStatVectorGenerator2()'

    def __call__(self, game):
        """Encode the point margin from the SGF 'RE' property for both players.

        Raises ProcessingError for draws, missing/unparsable results, wins
        by forfeit, on time or by resignation, and non-numeric margins.
        """
        result = str(game.sgf_header.get('RE', '0'))
        if result.lower() in ('0', 'jigo', 'draw'):
            raise ProcessingError(repr(self) + " Jigo")

        match = re.match(r'^([BW])\+(.*)$', result)
        if match is None:
            raise ProcessingError(repr(self) + ' Could not find result sgf tag.')

        winner, how = match.groups()
        if how.lower().startswith(('f', 't', 'r')):  # forfeit, time, resign
            raise ProcessingError(repr(self) + ' Forfeit, time, resign.')

        try:
            points = float(how)
        except ValueError:
            raise ProcessingError(repr(self) + ' Points not float.')

        winner_vec = numpy.array([points, 0])
        loser_vec = numpy.array([0, points])
        if winner == 'W':
            return BlackWhite(loser_vec, winner_vec)
        return BlackWhite(winner_vec, loser_vec)
#                               /- for black -- transform_rawpatfile --
# game -> raw_patternscan_game -
#                               \- for white -------- || -------------

#@cache_result
@declare_pure_function
def raw_patternscan_game(game, spatial_dict, patargs=''):
    """Scan @game with pachi and split the raw patterns by player.

    @spatial_dict -- spatial dictionary passed to scan_raw_patterns; it must exist
    @patargs      -- extra pattern arguments passed to scan_raw_patterns

    Returns the result pair from get_output_resultpair; its .black / .white
    files contain one raw pattern per line for the moves of that player.
    Raises ValueError if spatial_dict.exists(warn=True) is false.
    """
    # an explicit check rather than `assert`, which is stripped under -O
    if not spatial_dict.exists(warn=True):
        raise ValueError("Spatial dictionary %r does not exist." % (spatial_dict,))
    ret = get_output_resultpair(suffix='.raw.pat')

    with open(ret.black.filename, mode='w') as fb, open(ret.white.filename, mode='w') as fw:
        for color, pat in scan_raw_patterns(game, spatial_dict, patargs=patargs):
            # write output for the desired player
            fd = fb if color == PLAYER_COLOR_BLACK else fw
            fd.write("%s\n" % pat)
    return ret
#@cache_result
@declare_pure_function
def transform_rawpatfile(rawpat_file, ignore=frozenset(), transform=None, ignore_empty=True):
    """Transform a raw pattern file line by line.

    Features (and their payloads) named in @ignore are dropped, and the
    payload p of every feature f present in the @transform dict is replaced
    by transform[f](p).  If @ignore_empty is true, patterns left with no
    features are not written.

    E.g. transform_rawpatfile(file, ignore=set(['s', 'cont']),
                              transform={'border': lambda x: x - 1})
    applied to
        (s:20)
        (s:10 border:5 cont:10)
        (s:20 cont:1)
        (capture:18)
    will produce
        (border:4)
        (capture:18)

    Returns the result file with the transformed patterns.
    """
    if transform is None:
        transform = {}
    # a set keeps the membership test cheap even when a list is passed
    ignore = set(ignore)

    def keep_payload(payload):
        return payload

    ret = get_output_resultfile('.raw.pat')
    with open(ret.filename, mode='w') as fout, open(rawpat_file.filename, mode='r') as fin:
        for line in fin:
            pat = pachi.Pattern(line).reduce(lambda feat, _: feat not in ignore)
            fpairs = [(f, transform.get(f, keep_payload)(p)) for f, p in pat]

            if ignore_empty and not fpairs:
                continue

            fout.write("%s\n" % pachi.Pattern(fpairs=fpairs))
    return ret
#@cache_result
@declare_pure_function
def summarize_rawpat_file(rawpat_file):
    """Transform a raw pattern file into a summarized one, e.g.:
        (s:20)
        (s:10 border:5)
        (s:20)
        (s:40)
        (s:20)
    ========>
        3 (s:20)
        1 (s:10 border:5)
        1 (s:40)

    Uses the shell pipeline `sort | uniq -c | sort -rn`, so the output is
    ordered by decreasing count.  Anything the pipeline writes to stderr is
    logged as a warning.  Returns the result file; raises RuntimeError if
    the pipeline exits with a non-zero status.
    """
    try:
        from shlex import quote as shell_quote  # Python 3.3+
    except ImportError:
        from pipes import quote as shell_quote  # Python 2

    result_file = get_output_resultfile('.pat')

    # quote the paths so that spaces or shell metacharacters cannot break
    # (or inject into) the command line
    script = "sort %s | uniq -c | sort -rn > %s" % (shell_quote(rawpat_file.filename),
                                                    shell_quote(result_file.filename))

    p = subprocess.Popen(script, shell=True, stderr=PIPE)
    _, stderr = p.communicate()
    if stderr:
        logging.warning("subprocess summarize stderr:\n%s", stderr)
    if p.returncode:
        raise RuntimeError("Child sumarize failed, exitcode %d." % (p.returncode,))

    return result_file
class SummarizeMerger(godb_models.Merger):
    """Used to sum summarized pattern files:
    patfile_1:
        3 (s:20)
        1 (s:10 border:5)
        1 (s:40)

    patfile_2:
        3 (s:20)
        2 (s:15)
        1 (s:10 border:5)

    m = SummarizeMerger()
    m.add(patfile_1)
    m.add(patfile_2)
    patres = m.finish()

    Now, patres is:
        6 (s:20)
        2 (s:15)
        2 (s:10 border:5)
        1 (s:40)

    Patterns with equal counts are written in alphabetical order.
    """
    def __init__(self):
        self.reset()

    def start(self, bw_gen):
        self.reset()

    def reset(self):
        # pattern -> accumulated count
        self.cd = {}

    def add(self, pat_file, color=None):
        """Add the counts of summarized pattern file @pat_file.

        @color is accepted for compatibility with the other mergers and ignored.
        Raises IOError on a line not matching pat_file_regexp.
        """
        with open(pat_file.filename) as fin:
            for line in fin:
                match = re.match(pat_file_regexp, line)
                if not match:
                    raise IOError("Wrong file format: " + str(pat_file))
                count, pattern = int(match.group(1)), match.group(2)
                self.cd[pattern] = self.cd.get(pattern, 0) + count

    def finish(self):
        """Write the merged counts (most frequent first), reset, return the file."""
        result_file = get_output_resultfile('.pat')
        # highest count first; ties broken by pattern for a deterministic order
        items = sorted(self.cd.items(), key=lambda kv: (-kv[1], kv[0]))
        with open(result_file.filename, 'w') as fout:
            if items:
                # pad to the number of digits of the largest count, so that
                # the file is nicely aligned; the 2 extra columns prefix the
                # count with spaces, see pat_file_regexp for the format
                width = 2 + len(str(items[0][1]))
                line_fmt = "%" + str(width) + "d %s\n"
                for pattern, count in items:
                    fout.write(line_fmt % (count, pattern))

        self.reset()
        return result_file
class VectorSumMerger(godb_models.Merger):
    """Merger summing the added vectors element-wise."""
    def __init__(self):
        self.reset()

    def start(self, bw_gen):
        """Start a new zero sum sized by @bw_gen.types.

        Raises ValueError unless all of @bw_gen.types are 'continuous'.
        """
        if not all(tp == 'continuous' for tp in bw_gen.types):
            raise ValueError("VectorSumMerger can only sum continuous features.")
        self.sofar = numpy.zeros(len(bw_gen.types))

    def reset(self):
        self.sofar = None

    def add(self, vector, color=None):
        # `is None`: `array == None` compares element-wise in numpy, so using
        # it in an `if` raises for arrays with more than one element
        if self.sofar is None:
            self.sofar = numpy.zeros(vector.shape)
        self.sofar += vector

    def finish(self):
        """Return the sum (an empty array if nothing was started) and reset."""
        ret = self.sofar if self.sofar is not None else numpy.zeros(0)
        self.reset()
        return ret
class VectorArithmeticMeanMerger(godb_models.Merger):
    """Merger computing the element-wise arithmetic mean of the added vectors.

    The summation is delegated to an internal VectorSumMerger; finish()
    divides the sum by the number of added vectors (if any).
    """
    def __init__(self):
        self.reset()

    def start(self, bw_gen):
        self.reset()
        self.summ.start(bw_gen)

    def reset(self):
        self.count = 0
        self.summ = VectorSumMerger()

    def add(self, vector, color=None):
        self.count += 1
        self.summ.add(vector)

    def finish(self):
        total = self.summ.finish()
        mean = total / self.count if self.count else total
        self.reset()
        return mean
# a named, declared-pure function (instead of `lambda x: x`) has a nice repr
@declare_pure_function
def identity(obj):
    """Return @obj unchanged (e.g. the default add_fc / finish_fc of VectorApply)."""
    return obj
@declare_pure_function
def linear_rescale(vec, a=-1, b=1):
    """Linearly rescale the elements of @vec so that
        min(vec) is mapped to @a,
        max(vec) is mapped to @b,
    and the intermediate values are mapped linearly in between.

    If all elements are equal, every element is mapped to (a + b) / 2.
    An empty @vec yields an empty float array.  @vec may be any array-like.
    Raises ValueError if a > b.
    """
    # explicit check instead of `assert`, which is stripped under -O
    if a > b:
        raise ValueError("linear_rescale: expected a <= b, got a=%r, b=%r" % (a, b))
    vec = numpy.asarray(vec)
    if vec.size == 0:
        # min()/max() of an empty array would raise
        return numpy.zeros(vec.shape)
    lo, hi = vec.min(), vec.max()
    if lo == hi:
        # return average value of the set
        return (float(a + b) / 2) * numpy.ones(vec.shape)
    return a + (vec - lo) * (float(b - a) / (hi - lo))
@declare_pure_function
def natural_rescale(vec):
    """Rescale @vec so that its elements sum to 1.

    The values are converted to float first, so integer input is not
    floor-divided under Python 2.  NOTE: a vector summing to 0 yields
    nan/inf values (numpy division by zero).
    """
    vec = numpy.asarray(vec, dtype=float)
    return vec / numpy.sum(vec)
@declare_pure_function
def log_rescale(vec, a=-1, b=1):
    """Compress @vec with log(1 + x), then linearly rescale the result to [a, b]."""
    compressed = numpy.log(1 + vec)
    return linear_rescale(compressed, a, b)
class VectorApply(godb_models.Merger):
    """Merger wrapper applying functions around another merger.

    Every added vector goes through @add_fc before it reaches @merger, and
    the merged result goes through @finish_fc before it is returned.
    """
    def __init__(self, merger, add_fc=identity, finish_fc=identity):
        self.merger = merger
        self.add_fc = add_fc
        self.finish_fc = finish_fc

    def start(self, bw_gen):
        self.merger.start(bw_gen)

    def add(self, vector, color=None):
        transformed = self.add_fc(vector)
        self.merger.add(transformed, color)

    def finish(self):
        merged = self.merger.finish()
        return self.finish_fc(merged)

    def __repr__(self):
        return "VectorApply(%r, add_fc=%r, finish_fc=%r)" % (self.merger,
                                                             self.add_fc,
                                                             self.finish_fc)
class PatternVectorMaker:
    """Map a summarized pattern file to a vector of pattern counts.

    The vector has one element for each of the first @n patterns of the
    summarized pattern file @all_pat; these patterns are the annotations.
    """
    def __init__(self, all_pat, n):
        """Read the first @n patterns of @all_pat.

        Raises IOError on a line not matching pat_file_regexp and ValueError
        if @all_pat has fewer than @n distinct patterns in its first n lines.
        """
        self.all_pat = all_pat
        self.n = n

        self.annotations = []
        # pattern -> index in the vector
        self.pat2order = {}

        with open(self.all_pat.filename, 'r') as fin:
            # take first n patterns
            for num, line in enumerate(fin):
                if num >= self.n:
                    break
                match = re.match(pat_file_regexp, line)
                if not match:
                    raise IOError("Wrong file format: " + str(self.all_pat))
                pattern = match.group(2)
                self.pat2order[pattern] = num
                self.annotations.append(pattern)

        self.types = ["continuous"] * len(self.annotations)

        if len(self.pat2order) < self.n:
            raise ValueError("Input file all_pat '%s' does not have enough lines." % (self.all_pat,))

    def __repr__(self):
        return "PatternVectorMaker(all_pat=%s, n=%d)" % (self.all_pat, self.n)

    def __call__(self, sum_patfile):
        """Return the counts of the selected patterns in @sum_patfile.

        Element i is the count of annotations[i] (0 if absent).
        Raises IOError on a line not matching pat_file_regexp.
        """
        vector = numpy.zeros(self.n)
        added = 0
        with open(sum_patfile.filename, 'r') as fin:
            for line in fin:
                match = re.match(pat_file_regexp, line)
                if not match:
                    raise IOError("Wrong file format: " + str(sum_patfile))

                index = self.pat2order.get(match.group(2))
                if index is not None:
                    vector[index] += int(match.group(1))
                    added += 1

                # no need to walk through the whole file: the patterns are
                # unique since the pattern file is summarized
                if added >= self.n:
                    break

        return vector
## game -> BlackWhite( vector_black, vector_white )
class BWPatternVectorGenerator:
    """Game vector generator built from pattern counts.

    @bw_game_summarize maps a game to an object with map_both (presumably a
    BlackWhite pair of summarized pattern files); each side is converted to
    a vector by @pattern_vector_maker.
    """
    def __init__(self, bw_game_summarize, pattern_vector_maker):
        self.bw_game_summarize = bw_game_summarize
        self.pattern_vector_maker = pattern_vector_maker
        # the features are exactly those produced by the vector maker
        self.annotations = pattern_vector_maker.annotations
        self.types = pattern_vector_maker.types

    def __repr__(self):
        return ("BWPatternVectorGenerator(bw_game_summarize=%r, pattern_vector_maker=%r)"
                % (self.bw_game_summarize, self.pattern_vector_maker))

    def __call__(self, game):
        summarized = self.bw_game_summarize(game)
        return summarized.map_both(self.pattern_vector_maker)
#@cache_result
@declare_pure_function
def process_game(game, init, pathway):
    """Push @game through a processing pathway.

    @init maps the game to an object with map_pathway (e.g. the result pair
    of raw_patternscan_game), which is then mapped along @pathway.
    """
    return init(game).map_pathway(pathway)
@cache_result
@declare_pure_function
def process_one_side_list(osl, merger, bw_processor):
    """Cached wrapper around osl.for_one_side_list(merger, bw_processor).

    Maps a one side list to a vector; the vectors produced by @bw_processor
    for the games of @osl are merged using @merger.
    """
    result = osl.for_one_side_list(merger, bw_processor)
    return result
## Process One Side List
class OSLVectorGenerator:
    """
    Maps one side lists to vectors, using different game vector generators
    (e.g. BWPatternVectorGenerator), e.g.:
    OSLVectorGenerator([(vg1, m1), (vg2, m2)])

        game1      m1.add(vg1(game1))        m2.add(vg2(game1))
        game2      m1.add(vg1(game2))        m2.add(vg2(game2))
          .                |                         |
          .                |                         |
        game666    m1.add(vg1(game666))      m2.add(vg2(game666))
                   m1.finish()               m2.finish()
                   = [1,2,3,4,5]             = [6,7,8,9,10]
                   vg1.annotations           vg2.annotations
                   = [f1, ..., f5]           = [f6, ..., f10]
        ----------------------------------------------------------
        result = [ 1,2,3,4,5,6,7,8,9,10 ]
        annotations = [ f1, ..., f10 ]

    If @annotate_featurewise is true, each annotation is prefixed with
    'f<index of its generator>'.  With no generators the result is an
    empty vector.
    """
    def __init__(self, gen_n_merge, annotate_featurewise=True):
        self.gen_n_merge = gen_n_merge
        self.annotate_featurewise = annotate_featurewise
        self.functions = []
        self.annotations = []
        self.types = []

        for num, (game_vg, merger) in enumerate(gen_n_merge):
            # this function maps one_side_list to a vector
            # where vectors from a game in the osl are merged using the merger
            self.functions.append(partial_right(process_one_side_list, merger, game_vg))

            anns = game_vg.annotations
            if annotate_featurewise:
                anns = ['f%d%s' % (num, an) for an in anns]

            self.annotations.extend(anns)
            self.types.extend(game_vg.types)

    def __repr__(self):
        return "OSLVectorGenerator(gen_n_merge=%s, annotate_featurewise=%s)" % (
            repr(self.gen_n_merge), repr(self.annotate_featurewise))

    def __call__(self, osl):
        """Stack the vectors from the different generators together."""
        vectors = [f(osl) for f in self.functions]
        if not vectors:
            # numpy.hstack refuses an empty sequence
            return numpy.zeros(0)
        return numpy.hstack(vectors)
def make_all_pat(osl, bw_summarize_pathway):
    """Merge the summarized patterns of all games in @osl into one pattern file.

    @bw_summarize_pathway maps a game to summarized pattern files; the
    results are summed with a SummarizeMerger.
    """
    merger = SummarizeMerger()
    return process_one_side_list(osl, merger, bw_summarize_pathway)
@cache_result
@declare_pure_function
def osl_vector_gen_cached(osl_gen, osl):
    """Cached equivalent of osl_gen(osl).

    A module-level function only because the caching decorators cannot
    decorate class methods; ideally the caching itself would handle that.
    """
    vector = osl_gen(osl)
    return vector
@declare_pure_function
def minus(a, b):
    """Return a - b; a named pure function usable e.g. with partial_right."""
    difference = a - b
    return difference
@cache_result
@declare_pure_function
def make_tab_file(datamap, vg_osl, osl_name_as_meta=True, osl_size_as_meta=True, image_name_as_meta=True):
    """Write @datamap into a tab-delimited file for the Orange framework.

    As specified in http://orange.biolab.si/doc/reference/Orange.data.formats/
    The first three lines hold the column names, the column types and the
    column flags; one line per (osl, image) pair of @datamap follows.  The
    columns are: the vg_osl vector, the image data (the first image column
    is the class attribute, the others are meta), then optional metas.

    If image_name_as_meta, osl_name_as_meta or osl_size_as_meta is true,
    the name of the image, the name of the OSL or its size, respectively,
    is added as a meta column.

    Returns the result file.  Raises RuntimeError if a value contains a tab
    or a newline (it would corrupt the file format).
    """
    tab_file = get_output_resultfile('.tab')

    def tab_denoted(fout, l):
        """Writes tab-denoted elements of @l to output stream @fout"""
        # a list, not map(): it is traversed twice, and on Python 3 a map
        # iterator would already be exhausted by the check below
        strings = [str(el) for el in l]
        for el in strings:
            if '\t' in el:
                raise RuntimeError("Elements of tab-denoted list must not contain tabs.")
            if '\n' in el:
                raise RuntimeError("Elements of tab-denoted list must not contain newlines.")
        fout.write('\t'.join(strings) + '\n')

    def get_meta(osl_m, osl_size_m, image_m):
        """Keep only the values of the requested meta columns."""
        return list(itertools.compress((osl_m, osl_size_m, image_m),
                                       (osl_name_as_meta, osl_size_as_meta, image_name_as_meta)))

    total = len(datamap)
    with open(tab_file.filename, 'w') as fout:
        # annotations - column names
        tab_denoted(fout, chain(vg_osl.annotations,
                                datamap.image_annotations,
                                get_meta('OSL name', 'OSL size', 'Image name')))

        # column data types
        tab_denoted(fout, chain(vg_osl.types,
                                datamap.image_types,
                                get_meta('string', 'continuous', 'string')))

        # column info type: empty (normal columns) / class (main class attribute) / multiclass / meta
        tab_denoted(fout, chain(
            # attributes are no class
            [''] * len(vg_osl.types),
            # for the first class attribute if present
            ['class'] * len(datamap.image_types[:1]),
            # the following class attributes (if present) are kept as meta
            # (alternatively they could be marked 'multiclass')
            ['meta'] * len(datamap.image_types[1:]),
            # meta information if requested
            get_meta('meta', 'meta', 'meta')))

        # the data itself
        for num, (osl, image) in enumerate(datamap):
            logging.info('Tab file %d%% (%d / %d)', 100 * (num + 1) // total, num + 1, total)

            tab_denoted(fout, chain(
                # the osl
                osl_vector_gen_cached(vg_osl, osl),
                # the image
                [float(x) for x in image.data],
                # the meta data
                get_meta(osl.name, float(len(osl)), image.name)))

    return tab_file
## Playground:

if __name__ == '__main__':
    def main():
        """Ad-hoc playground; by default only the processing pathway test runs."""
        # set to True to log into the file 'LOG'
        log_to_file = False
        if log_to_file:
            from logging import handlers
            logger = logging.getLogger()
            logger.setLevel(logging.INFO)
            logger.addHandler(handlers.WatchedFileHandler('LOG', mode='w'))

        from utils.godb_models import Game, GameList, OneSideList, PLAYER_COLOR_BLACK, PLAYER_COLOR_WHITE
        from utils.godb_session import godb_session_maker
        from utils import db_cache

        def test1():
            ## import'n'init
            s = godb_session_maker(filename=':memory:')

            ## Prepare data
            gl = GameList("pokus")
            s.godb_scan_dir_as_gamelist('./TEST_FILES/games', gl)
            s.add(gl)

            # add all the games into the all.pat file
            osl = OneSideList("all.pat")
            osl.batch_add(gl.games, PLAYER_COLOR_BLACK)
            osl.batch_add(gl.games, PLAYER_COLOR_WHITE)
            s.add(osl)
            s.commit()

            ## Prepare the pattern vector game processing pathway
            ## game -> BlackWhite( vector_black, vector_white )
            spatial_dict = generate_spatial_dictionary(gl, spatmin=2)

            # the pathway: game -> bw rawpat files -> bw transformed rawpat files -> bw summarized pat files
            bw_game_summarize = partial_right(process_game,
                                              partial_right(raw_patternscan_game, spatial_dict),
                                              [partial_right(transform_rawpatfile,
                                                             #transform={'border': partial_right(minus, 1)},
                                                             ignore=['border', 'cont']),
                                               summarize_rawpat_file])

            all_pat = make_all_pat(osl, bw_game_summarize)

            vg_pat = BWPatternVectorGenerator(bw_game_summarize,
                                              PatternVectorMaker(all_pat, 100))
            vg_local = BWLocalSeqVectorGenerator()
            vg_bdist = BWBdistVectorGenerator()

            ## Process One game
            # NOTE(review): the original selection of the sample game is not
            # visible here; the first game of the list is used -- confirm.
            game = gl.games[0]
            print(vg_pat(game))
            print(vg_local(game))
            print(vg_bdist(game))

            ## Process One Side List
            gen_n_merge = [(vg_pat, VectorApply(VectorSumMerger(), finish_fc=linear_rescale)),
                           (vg_local, VectorArithmeticMeanMerger()),
                           (vg_bdist, VectorArithmeticMeanMerger())]

            vg_osl = OSLVectorGenerator(gen_n_merge)

            # cached version of vg_osl(osl)
            generate = partial(osl_vector_gen_cached, vg_osl)

            ## now the pathway is ready, we can process whatever OSL we
            # feel up to, osl in the following is just an example
            vec, annotations = generate(osl), vg_osl.annotations

            for value, annotation in zip(vec, annotations):
                print("%s \t\t%s" % (value, annotation))

        def test_rescale():
            from pylab import figure, scatter, subplot, show

            vec = numpy.random.random(size=10)
            print(vec)
            print(linear_rescale(vec, a=-20, b=20))

            vec = numpy.array([452915., 288357., 271245., 111039., 84811., 74074.,
                               58663., 62257., 55296., 46359., 51022., 41049.,
                               31297., 35259., 34467., 30918., 29869., 36875.,
                               29592., 28075., 25823., 27479., 26343., 26964.,
                               24093., 24724., 23135., 22266., 21725., 21769.,
                               20130., 21625., 20200., 20619., 19741., 19049.,
                               17434., 20167., 19830., 16458., 16513., 21720.,
                               20933., 20216., 18414., 17442., 12046., 16186.,
                               16732., 16142., 15126., 15332., 15435., 12925.,
                               14072., 16321., 11391., 14884., 13147., 15162.,
                               14247., 15578., 11826., 12009., 11533., 12349.,
                               12219., 12590., 10581., 14550., 10699., 12384.,
                               11795., 10769., 12617., 12576., 12281., 11311.,
                               12479., 11327., 11398., 11814., 11050., 10248.,
                               10506., 11541., 12401., 9580., 11201., 10704.,
                               9766., 10402., 9422., 12888., 9473., 9536.,
                               10933., 10844., 11005., 8112., 0.])

            figure(1)
            panels = [(321, vec, 'r'),
                      (322, linear_rescale(vec), 'g'),
                      (323, numpy.log(1 + vec), 'b'),
                      (324, log_rescale(vec), 'y'),
                      (325, vec / sum(vec), 'b')]
            for position, values, colour in panels:
                subplot(position)
                scatter(range(len(vec)), values, marker='x', c=colour)
            show()

        def test_bdist_hist():
            s = godb_session_maker(filename=':memory:')
            game = s.godb_sgf_to_game('./TEST_FILES/test_bdist2.sgf')

            bdg = BWBdistVectorGenerator(by_line=[2, 3, 4], by_moves=[4, 6])
            bw = bdg(game)
            assert len(bdg.annotations) == len(bw[0]) == len(bw[1])

            print("Interval \t\tBlack\tWhite")
            print("-" * 40)
            for ann, b, w in zip(bdg.annotations, bw[0], bw[1]):
                print("%s\t\t%d \t%d" % (ann, int(b), int(w)))

        def test_win_stat():
            s = godb_session_maker(filename=':memory:')
            game = s.godb_sgf_to_game('../data/go_teaching_ladder/reviews/5443-breakfast-m711-A2.sgf')

            # BWWinPointsStatVectorGenerator() can be tried the same way
            bdg = BWWinStatVectorGenerator()
            bw = bdg(game)
            assert len(bdg.annotations) == len(bw[0]) == len(bw[1])

            print("Interval \t\tBlack\tWhite")
            print("-" * 40)
            for ann, b, w in zip(bdg.annotations, bw[0], bw[1]):
                print("%30s\t\t%s \t%s" % (ann, b, w))

        def header(text):
            print("%s \n%s\n%s" % ("=" * 10, text, "=" * 10))

        # set to True to run the remaining playground tests as well
        run_all = False

        header("PROCESSING PATHWAY TEST")
        test1()
        if not run_all:
            return

        header("RESCALE TEST")
        test_rescale()
        header("BDIST HIST TEST")
        test_bdist_hist()
        header("WINSTAT TEST")
        test_win_stat()

    main()