Merge pull request #50 from lemonsqueeze/can_countercap
[pachi.git] / patternprob.c
blob9a411cd05ab2519aa3d25cbb30c72a7f8225210b
1 #define DEBUG
2 #include <assert.h>
3 #include <ctype.h>
4 #include <stdio.h>
5 #include <stdlib.h>
7 #include "board.h"
8 #include "debug.h"
9 #include "pattern.h"
10 #include "patternsp.h"
11 #include "patternprob.h"
/* We try to avoid needlessly reloading the probability dictionary,
 * since loading it may take a rather long time. */
16 static struct pattern_pdict *cached_dict;
18 struct pattern_pdict *
19 pattern_pdict_init(char *filename, struct pattern_config *pc)
21 if (cached_dict) {
22 cached_dict->pc = pc;
23 return cached_dict;
26 if (!filename)
27 filename = "patterns.prob";
28 FILE *f = fopen(filename, "r");
29 if (!f) {
30 if (DEBUGL(1))
31 fprintf(stderr, "No pattern probtable, will not use learned patterns.\n");
32 return NULL;
35 struct pattern_pdict *dict = calloc2(1, sizeof(*dict));
36 dict->pc = pc;
37 dict->table = calloc2(pc->spat_dict->nspatials + 1, sizeof(*dict->table));
39 char *sphcachehit = calloc2(pc->spat_dict->nspatials, 1);
40 hash_t (*sphcache)[PTH__ROTATIONS] = malloc(pc->spat_dict->nspatials * sizeof(sphcache[0]));
42 int i = 0;
43 char sbuf[1024];
44 while (fgets(sbuf, sizeof(sbuf), f)) {
45 struct pattern_prob *pb = calloc2(1, sizeof(*pb));
46 int c, o;
48 char *buf = sbuf;
49 if (buf[0] == '#') continue;
50 while (isspace(*buf)) buf++;
51 while (!isspace(*buf)) buf++; // we recompute the probability
52 while (isspace(*buf)) buf++;
53 c = strtol(buf, &buf, 10);
54 while (isspace(*buf)) buf++;
55 o = strtol(buf, &buf, 10);
56 pb->prob = (floating_t) c / o;
57 while (isspace(*buf)) buf++;
58 str2pattern(buf, &pb->p);
60 uint32_t spi = pattern2spatial(dict, &pb->p);
61 pb->next = dict->table[spi];
62 dict->table[spi] = pb;
64 /* Some spatials may not have been loaded if they correspond
65 * to a radius larger than supported. */
66 if (pc->spat_dict->spatials[spi].dist > 0) {
67 /* We rehash spatials in the order of loaded patterns. This way
68 * we make sure that the most popular patterns will be hashed
69 * last and therefore take priority. */
70 if (!sphcachehit[spi]) {
71 sphcachehit[spi] = 1;
72 for (unsigned int r = 0; r < PTH__ROTATIONS; r++)
73 sphcache[spi][r] = spatial_hash(r, &pc->spat_dict->spatials[spi]);
75 for (unsigned int r = 0; r < PTH__ROTATIONS; r++)
76 spatial_dict_addh(pc->spat_dict, sphcache[spi][r], spi);
79 i++;
82 free(sphcache);
83 free(sphcachehit);
84 if (DEBUGL(3))
85 spatial_dict_hashstats(pc->spat_dict);
87 fclose(f);
88 if (DEBUGL(1))
89 fprintf(stderr, "Loaded %d pattern-probability pairs.\n", i);
90 cached_dict = dict;
91 return dict;
94 floating_t
95 pattern_rate_moves(struct pattern_setup *pat,
96 struct board *b, enum stone color,
97 struct pattern *pats, floating_t *probs)
99 double total = 0;
100 for (int f = 0; f < b->flen; f++) {
101 probs[f] = NAN;
103 struct move mo = { .coord = b->f[f], .color = color };
104 if (is_pass(mo.coord))
105 continue;
106 if (!board_is_valid_move(b, &mo))
107 continue;
109 pattern_match(&pat->pc, pat->ps, &pats[f], b, &mo);
110 floating_t prob = pattern_prob(pat->pd, &pats[f]);
111 if (!isnan(prob)) {
112 probs[f] = prob;
113 total += prob;
115 if (DEBUGL(5)) {
116 char buf[256]; pattern2str(buf, &pats[f]);
117 fprintf(stderr, "=> move %s pattern %s prob %.3f\n", coord2sstr(mo.coord, b), buf, prob);
120 return total;