6 from data_about_players
import Data
class KNNOutputVectorGenerator(VectorGenerator):
    """k-nearest-neighbour output vector generator.

    Calling an instance with an input vector returns a weighted combination
    (via ``linear_combination``) of the output vectors belonging to the ``k``
    reference input vectors closest to it. Distance is the Euclidean distance
    scaled by ``dist_mult``. Each reference's weight is
    ``weight_param ** distance``.
    """

    def __init__(self, ref_dict, k=5, weight_param=0.8, dist_mult=10):
        """Store the reference data and the kNN parameters.

        ref_dict   -- dictionary mapping reference input vectors (tuples) to
                      their output vectors,
                      e.g. ref_dict = {(1.0, 2.0): (9.0, 16.0, 21.0)}
        k          -- number of nearest references combined per query.
        weight_param -- base of the distance-decay weight
                      (weight = weight_param ** distance).
        dist_mult  -- multiplier applied to the Euclidean distance.
        """
        self.ref_dict = ref_dict
        self.k = k
        # NOTE: the attribute keeps the original misspelling ("weigth") so that
        # any code reading it elsewhere keeps working.
        self.weigth_param = weight_param
        self.dist_mult = dist_mult

    def __call__(self, player_vector):
        """Return the weighted combination of the k nearest reference outputs.

        Raises RuntimeError if there is no reference vector to use (empty
        ref_dict or k <= 0).
        """
        scored = [(self.distance(ref_vec, player_vector), ref_vec)
                  for ref_vec in self.ref_dict]
        # Nearest first; sort on the distance only, so ties never fall back to
        # comparing the reference vectors themselves.
        scored.sort(key=lambda pair: pair[0])
        nearest = scored[:self.k]
        if not nearest:
            raise RuntimeError("No reference vectors to compare against.")
        ref_output_vecs = [self.ref_dict[ref_vec] for _, ref_vec in nearest]
        coefs = [self.weight_fc(dist) for dist, _ in nearest]
        return linear_combination(ref_output_vecs, coefs)

    def weight_fc(self, distance):
        """Return the weight for a reference at the given (scaled) distance.

        The weight is weigth_param ** distance. For 0 < weigth_param < 1
        (the default is 0.8), closer references get larger weights.
        """
        return self.weigth_param ** distance

    def distance(self, vec1, vec2):
        """Return dist_mult times the Euclidean distance between two vectors.

        Raises RuntimeError if the vectors differ in length.
        """
        if len(vec1) != len(vec2):
            raise RuntimeError("Dimensions of vectors mismatch.")
        # The multiplier (default 10) was determined empirically to give a
        # useful scale for weight_fc.
        squared = sum((float(a) - float(b)) ** 2 for a, b in zip(vec1, vec2))
        return self.dist_mult * sqrt(squared)
def _save_vectors(names, vecs, where):
    """Write labelled vectors to the file `where` via print_vector.

    For each (name, vector) pair, print_vector is called once with the name
    (whitespace runs replaced by '_') followed by the vector's components.
    The file is opened for writing and closed afterwards.

    Raises RuntimeError if names and vecs differ in length.
    """
    if len(names) != len(vecs):
        raise RuntimeError("Dimensions of vectors mismatch.")

    sys.stderr.write("Saving output_vectors to file: %s\n" % (where,))
    with open(where, 'w') as f:
        for name, vec in zip(names, vecs):
            # Presumably so gnuplot reads the label as one column.
            label = '_'.join(name.split())
            print_vector([label] + list(vec), f)


def main():
    """Estimate output vectors for players via weighted kNN and save them.

    Reference players (those with questionnaire vectors) provide the kNN
    reference data. Outputs are written for the reference players (original
    and kNN-estimated) and for all other players, followed by a hint on how
    to plot them in gnuplot.
    """
    main_pat_filename = Data.main_pat_filename
    filename_play_other = 'knn_other.data'
    filename_play_ref = 'knn_ref.data'
    filename_play_ref_orig = 'knn_ref_orig.data'
    # Change this to False if you do not want to use PCA.
    use_pca = True

    player_vector = Data.questionare_total
    players_ignore = set(["Yi Ch'ang-ho 2004-", "Yi Ch'ang-ho"])
    players_all = [p for p in Data.players_all if p not in players_ignore]
    players_ref = [p for p in player_vector if p not in players_ignore]
    ref_names = set(players_ref)  # O(1) membership for the filter below
    players_other = [p for p in players_all if p not in ref_names]

    # Object creating an input vector when called.
    sys.stderr.write("Creating input vector generator from main pat file: %s\n"
                     % (main_pat_filename,))
    # NOTE(review): num_features (and k below) are not defined in the visible
    # code; they are expected to be module-level globals -- confirm.
    ivg = InputVectorGenerator(main_pat_filename, num_features)

    input_vectors_ref = [ivg(Data.pat_files_folder + name)
                         for name in players_ref]
    input_vectors_other = [ivg(Data.pat_files_folder + name)
                           for name in players_other]

    if not input_vectors_ref:
        # kNN needs at least one reference; continuing would only fail later.
        sys.stderr.write("No reference vectors.\n")
        sys.exit(1)
    if not input_vectors_other:
        sys.stderr.write("No vectors to process.\n")

    if use_pca:
        # Train PCA on all input vectors, then project both sets.
        sys.stderr.write("Running PCA.\n")
        pca = PCA(input_vectors_ref + input_vectors_other, reduce=True)
        input_vectors_ref = pca.process_list_of_vectors(input_vectors_ref)
        input_vectors_other = pca.process_list_of_vectors(input_vectors_other)

    # Object creating an output vector when called.
    ref_dict = {}
    for name, input_vector in zip(players_ref, input_vectors_ref):
        ref_dict[tuple(input_vector)] = player_vector[name]
    oknn = KNNOutputVectorGenerator(ref_dict, k=k)

    # Output vectors from the weighted kNN approximation.
    output_vectors_other = [oknn(v) for v in input_vectors_other]
    output_vectors_ref = [oknn(v) for v in input_vectors_ref]

    _save_vectors(players_ref, [player_vector[name] for name in players_ref],
                  filename_play_ref_orig)
    _save_vectors(players_ref, output_vectors_ref, filename_play_ref)
    _save_vectors(players_other, output_vectors_other, filename_play_other)

    sys.stderr.write("\nNow plot that in Gnuplot by:\n")
    sys.stderr.write('set xtics 1 ; set ytics 1\n')
    sys.stderr.write('set grid ; set size square\n')
    sys.stderr.write(
        'plot "%s" using 2:3:1 with labels font "arial,11" point lt 10 pt 4 left, '
        '"%s" using 2:3:1 with labels font "arial,11" point lt 12 pt 4 left\n'
        % (filename_play_other, filename_play_ref))


if __name__ == '__main__':
    main()