Minor changes
[matilda.git] / src / pat3.c
blob441bd29661e98e3196312b918afa8c892cea4d71
1 /*
2 Functions that support the use of small 3x3 patterns hand crafted by the authors
3 of GNU Go, MoGo and others over the years.
5 The life of these patterns is as follows:
6 * On startup a pat3 file is loaded with a number of 3x3 patterns suggesting
7 play at the center intersection. The pattern is flipped and rotated and stored
8 in a hash table for both players (with the color inverted for white). They are
9 stored in their 16-bit value form.
11 * In MCTS each candidate position can be transposed to a 3x3 array, with fixed
12 out of bounds codification, flipped and rotated (but the color remains the same)
13 and searched for in the appropriate hash table.
16 #include "config.h"
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <assert.h>
23 #include "alloc.h"
24 #include "board.h"
25 #include "engine.h"
26 #include "file_io.h"
27 #include "flog.h"
28 #include "hash_table.h"
29 #include "matrix.h"
30 #include "pat3.h"
31 #include "stringm.h"
32 #include "types.h"
35 static u16 b_pattern_table[65536];
36 static u16 w_pattern_table[65536];
37 static bool pat3_table_inited = false;
39 static hash_table * weights_table = NULL;
40 static u32 weights_found = 0;
41 static u32 weights_not_found = 0;
44 typedef struct __pat3_{
45 u16 value;
46 u16 weight;
47 } pat3;
50 Convert some final symbols; non-final symbols are left as-is here
52 static void clean_symbols(
53 u8 p[static 3][3]
54 ) {
55 char * s;
57 for (u8 i = 0; i < 3; ++i) {
58 for (u8 j = 0; j < 3; ++j) {
59 switch (p[i][j]) {
60 case SYMBOL_EMPTY:
61 p[i][j] = EMPTY;
62 break;
63 case SYMBOL_BORDER:
64 p[i][j] = ILLEGAL;
65 break;
66 case SYMBOL_OWN_STONE:
67 p[i][j] = BLACK_STONE;
68 break;
69 case SYMBOL_OPT_STONE:
70 p[i][j] = WHITE_STONE;
71 break;
72 case SYMBOL_OWN_OR_EMPTY:
73 case SYMBOL_OPT_OR_EMPTY:
74 case SYMBOL_STONE_OR_EMPTY:
75 break;
76 default:
77 s = alloc();
78 snprintf(s, MAX_PAGE_SIZ, "pattern file format error; unknown symbol: '%c', %u\n", p[i][j], p[i][j]);
79 flog_crit("pat3", s);
80 release(s);
86 static void pat3_insert(
87 u16 value,
88 u16 value_inv,
89 u16 weight
90 ) {
91 /* patterns from blacks perspective */
92 b_pattern_table[value] = weight;
94 /* patterns from whites perspective */
95 w_pattern_table[value_inv] = weight;
99 Lookup of pattern value for the specified player.
100 RETURNS pattern weight or 0 if not found
102 u16 pat3_find(
103 u16 value,
104 bool is_black
106 return is_black ? b_pattern_table[value] : w_pattern_table[value];
109 static void flip(
110 const u8 src[static 3][3],
111 u8 dst[static 3][3]
113 for (u8 i = 0; i < 3; ++i) {
114 for (u8 j = 0; j < 3; ++j) {
115 dst[i][j] = src[2 - i][j];
120 static void rotate(
121 const u8 src[static 3][3],
122 u8 dst[static 3][3],
123 u8 rotations
125 u8 i;
126 u8 j;
127 switch (rotations) {
128 case 1:
129 for (i = 0; i < 3; ++i) {
130 for (j = 0; j < 3; ++j) {
131 dst[i][j] = src[2 - j][i];
134 break;
135 case 2:
136 for (i = 0; i < 3; ++i) {
137 for (j = 0; j < 3; ++j) {
138 dst[i][j] = src[2 - i][2 - j];
141 break;
142 case 3:
143 for (i = 0; i < 3; ++i) {
144 for (j = 0; j < 3; ++j) {
145 dst[i][j] = src[j][2 - i];
151 static void reduce_pattern(
152 u8 v[static 3][3],
153 u8 method
155 u8 r[3][3];
156 u8 f[3][3];
157 switch (method) {
158 case ROTATE90:
159 rotate((const u8 (*)[3])v, r, 1);
160 break;
161 case ROTATE180:
162 rotate((const u8 (*)[3])v, r, 2);
163 break;
164 case ROTATE270:
165 rotate((const u8 (*)[3])v, r, 3);
166 break;
167 case ROTFLIP0:
168 flip((const u8 (*)[3])v, r);
169 break;
170 case ROTFLIP90:
171 rotate((const u8 (*)[3])v, f, 1);
172 flip((const u8 (*)[3])f, r);
173 break;
174 case ROTFLIP180:
175 rotate((const u8 (*)[3])v, f, 2);
176 flip((const u8 (*)[3])f, r);
177 break;
178 case ROTFLIP270:
179 rotate((const u8 (*)[3])v, f, 3);
180 flip((const u8 (*)[3])f, r);
181 break;
182 default: /* NOREDUCE */
183 return;
186 memcpy(v, r, 3 * 3);
190 Rotate and flip the pattern to its unique representative.
191 Avoid using, is not optimized.
193 void pat3_reduce_auto(
194 u8 v[static 3][3]
196 u8 b[3][3];
198 for (u8 reduction = ROTATE90; reduction <= ROTFLIP270; ++reduction) {
199 memcpy(b, v, 3 * 3);
200 reduce_pattern(b, ROTATE90);
202 if (memcmp(b, v, 3 * 3) < 0) {
203 memcpy(v, b, 3 * 3);
209 Transposes part of an input matrix board into a 3x3 matrix pattern codified,
210 with board safety.
212 void pat3_transpose(
213 u8 dst[static 3][3],
214 const u8 p[static TOTAL_BOARD_SIZ],
215 move m
217 assert(is_board_move(m));
218 assert(p[m] == EMPTY);
220 d8 x;
221 d8 y;
222 move_to_coord(m, (u8 *)&x, (u8 *)&y);
224 d8 i;
225 d8 j;
226 d8 ki;
227 d8 kj;
229 for (j = y - 1, kj = 0; j <= y + 1; ++j, ++kj) {
230 for (i = x - 1, ki = 0; i <= x + 1; ++i, ++ki) {
231 if (i >= 0 && j >= 0 && i < BOARD_SIZ && j < BOARD_SIZ) {
232 move n = coord_to_move(i, j);
233 dst[ki][kj] = p[n];
234 } else {
235 dst[ki][kj] = ILLEGAL; /* edge of the board */
242 Codifies the pattern in a 16 bit unsigned value.
244 u16 pat3_to_string(
245 const u8 p[static 3][3]
247 u16 ret = p[0][0] & 3;
248 ret = (ret << 2) + (p[0][1] & 3);
249 ret = (ret << 2) + (p[0][2] & 3);
251 ret = (ret << 2) + (p[1][0] & 3);
252 assert(p[1][1] == EMPTY);
253 ret = (ret << 2) + (p[1][2] & 3);
255 ret = (ret << 2) + (p[2][0] & 3);
256 ret = (ret << 2) + (p[2][1] & 3);
257 ret = (ret << 2) + (p[2][2] & 3);
258 return ret;
262 Decodes a 16-bit value into a 3x3 pattern, with empty center.
264 void string_to_pat3(
265 u8 dst[static 3][3],
266 u16 src
268 dst[2][2] = src & 3;
269 src >>= 2;
270 dst[2][1] = src & 3;
271 src >>= 2;
272 dst[2][0] = src & 3;
273 src >>= 2;
274 dst[1][2] = src & 3;
275 dst[1][1] = EMPTY;
276 src >>= 2;
277 dst[1][0] = src & 3;
278 src >>= 2;
279 dst[0][2] = src & 3;
280 src >>= 2;
281 dst[0][1] = src & 3;
282 src >>= 2;
283 dst[0][0] = src & 3;
286 static u8 _count_stones(
287 const u8 p[static 3][3]
289 u8 ret = 0;
291 for (u8 i = 0; i < 3; ++i) {
292 for (u8 j = 0; j < 3; ++j) {
293 if (p[i][j] == WHITE_STONE || p[i][j] == BLACK_STONE) {
294 ++ret;
299 return ret;
303 Invert stone colors.
305 void pat3_invert(
306 u8 p[static 3][3]
308 for (u8 i = 0; i < 3; ++i) {
309 for (u8 j = 0; j < 3; ++j) {
310 if (p[i][j] == BLACK_STONE) {
311 p[i][j] = WHITE_STONE;
312 } else if (p[i][j] == WHITE_STONE) {
313 p[i][j] = BLACK_STONE;
319 static void multiply_and_store(
320 const u8 pat[static 3][3]
322 u8 p[3][3];
323 u16 weight = 1;
325 if (weights_table != NULL) {
326 /* Discover weight */
327 memcpy(p, pat, 3 * 3);
328 pat3_reduce_auto(p);
329 u16 pattern = pat3_to_string((const u8 (*)[3])p);
330 pat3 tmp;
331 tmp.value = pattern;
332 pat3 * tmp2 = (pat3 *)hash_table_find(weights_table, &tmp);
334 if (tmp2 == NULL) {
335 weight = (65535 / WEIGHT_SCALE);
336 weights_not_found++;
337 } else {
338 weight = tmp2->weight;
339 weights_found++;
343 u8 p_inv[3][3];
344 for (u8 r = 1; r < 9; ++r) {
345 memcpy(p, pat, 3 * 3);
346 reduce_pattern(p, r);
347 u16 value = pat3_to_string((const u8 (*)[3])p);
349 if (pat3_find(value, true) == 0) {
350 memcpy(p_inv, p, 3 * 3);
351 pat3_invert(p_inv);
352 u16 value_inv = pat3_to_string((const u8 (*)[3])p);
353 assert(pat3_find(value_inv, false) == 0);
354 pat3_insert(value, value_inv, weight);
360 The original pattern is expanded into all possible forms, then rotated/flip and
361 saved if unique in the hash table under the correct patterns group (patterns
362 generated from same original p.). For uniqueness the attributes are also taken
363 into consideration.
365 RETURNS total number of new unique patterns
367 static void expand_pattern(
368 const u8 pat[static 3][3]
370 u8 p[3][3];
371 memcpy(p, pat, 3 * 3);
373 for (u8 i = 0; i < 3; ++i) {
374 for (u8 j = 0; j < 3; ++j) {
375 switch (p[i][j]) {
376 case SYMBOL_OWN_OR_EMPTY:
377 p[i][j] = BLACK_STONE;
378 expand_pattern((const u8 (*)[3])p);
379 p[i][j] = EMPTY;
380 expand_pattern((const u8 (*)[3])p);
381 return;
382 case SYMBOL_OPT_OR_EMPTY:
383 p[i][j] = WHITE_STONE;
384 expand_pattern((const u8 (*)[3])p);
385 p[i][j] = EMPTY;
386 expand_pattern((const u8 (*)[3])p);
387 return;
388 case SYMBOL_STONE_OR_EMPTY:
389 p[i][j] = BLACK_STONE;
390 expand_pattern((const u8 (*)[3])p);
391 p[i][j] = WHITE_STONE;
392 expand_pattern((const u8 (*)[3])p);
393 p[i][j] = EMPTY;
394 expand_pattern((const u8 (*)[3])p);
395 return;
400 if (_count_stones((const u8 (*)[3])p) < 2) {
401 flog_crit("pat3", "failed to open and expand patterns because the expansion would create patterns with a single stone or less");
404 /* invert color, rotate and flip to generate equivalent pattern
405 configurations */
406 multiply_and_store((const u8 (*)[3])p);
410 static u32 read_pat3_file(
411 const char * restrict filename,
412 char * restrict buffer
414 d32 chars_read = read_ascii_file(buffer, MAX_FILE_SIZ, filename);
415 if (chars_read < 0) {
416 flog_crit("pat3", "couldn't open file for reading");
419 u8 pat[3][3];
420 u8 pat_pos = 0;
421 u32 pats_read = 0;
423 char * line;
424 char * init_str = buffer;
425 char * save_ptr;
426 while ((line = strtok_r(init_str, "\r\n", &save_ptr)) != NULL) {
427 init_str = NULL;
429 line_cut_before(line, '#');
431 line = trim(line);
432 if (line == NULL) {
433 continue;
436 u16 len = strlen(line);
437 if (len == 0) {
438 continue;
441 if (len == 3) {
442 pat[pat_pos][0] = line[0];
443 pat[pat_pos][1] = line[1];
444 pat[pat_pos][2] = line[2];
445 ++pat_pos;
447 if (pat_pos == 3) {
448 clean_symbols(pat);
450 /* generate all combinations and store in hash tables */
451 expand_pattern((const u8 (*)[3])pat);
452 pats_read += 1;
454 pat_pos = 0;
459 return pats_read;
462 static u32 pat3_hash_function(
463 void * a
465 pat3 * b = (pat3 *)a;
467 return b->value;
470 static int pat3_compare_function(
471 const void * restrict a,
472 const void * restrict b
474 pat3 * f1 = (pat3 *)a;
475 pat3 * f2 = (pat3 *)b;
477 return ((d32)(f2->value)) - ((d32)(f1->value));
480 static u32 read_patern_weights(
481 char * buffer
483 weights_table = hash_table_create(1543, sizeof(pat3), pat3_hash_function, pat3_compare_function);
485 char * line;
486 char * init_str = buffer;
487 char * save_ptr;
488 while ((line = strtok_r(init_str, "\r\n", &save_ptr)) != NULL) {
489 init_str = NULL;
491 line_cut_before(line, '#');
493 line = trim(line);
494 if (line == NULL) {
495 continue;
498 u16 len = strlen(line);
499 if (len == 0) {
500 continue;
503 char * save_ptr2;
504 char * word1 = strtok_r(line, " ", &save_ptr2);
505 if (word1 == NULL) {
506 continue;
509 char * word2 = strtok_r(NULL, " ", &save_ptr2);
510 if (word2 == NULL) {
511 continue;
514 long int tmp1 = strtol(word1, NULL, 16);
515 long int tmp2 = strtol(word2, NULL, 10);
516 if (tmp1 > 65535 || tmp2 > 65535) {
517 continue;
520 u16 pattern = (u16)tmp1;
521 /* Weight scaling for totals to fit in 16 bit and not having 0s */
522 tmp2 = (tmp2 / WEIGHT_SCALE) + 1;
523 u16 weight = (u16)tmp2;
525 pat3 * p = malloc(sizeof(pat3));
526 p->value = pattern;
527 p->weight = weight;
529 if (!hash_table_exists(weights_table, p)) {
530 hash_table_insert_unique(weights_table, p);
534 return weights_table->elements;
538 Reads a .pat3 patterns file and expands all patterns into all possible and
539 patternable configurations.
541 void pat3_init() {
542 if (pat3_table_inited) {
543 return;
546 pat3_table_inited = true;
548 char * file_buf = malloc(MAX_FILE_SIZ);
549 if (file_buf == NULL) {
550 flog_crit("pat3", "system out of memory");
553 char * buf = alloc();
555 if (USE_PATTERN_WEIGHTS) {
557 Read pattern weights file
559 char * filename = alloc();
560 snprintf(filename, MAX_PAGE_SIZ, "%s%ux%u.weights", data_folder(), BOARD_SIZ, BOARD_SIZ);
562 d32 chars_read = read_ascii_file(file_buf, MAX_FILE_SIZ, filename);
563 if (chars_read < 0) {
564 char * s = alloc();
565 snprintf(s, MAX_PAGE_SIZ, "could not read %s", filename);
566 flog_warn("pat3", s);
567 release(s);
568 } else {
569 u32 weights = read_patern_weights(file_buf);
571 snprintf(buf, MAX_PAGE_SIZ, "read %s (%u weights)", filename, weights);
572 flog_info("pat3", buf);
575 release(filename);
579 Discover .pat3 files
581 char * pat3_filenames[128];
582 u32 files_found = recurse_find_files(data_folder(), ".pat3", pat3_filenames, 128);
584 snprintf(buf, MAX_PAGE_SIZ, "found %u 3x3 pattern files", files_found);
585 flog_info("pat3", buf);
587 for (u32 i = 0; i < files_found; ++i) {
588 u32 patterns_found = read_pat3_file(pat3_filenames[i], file_buf);
590 snprintf(buf, MAX_PAGE_SIZ, "read %s (%u patterns)", pat3_filenames[i], patterns_found);
591 flog_info("pat3", buf);
593 free(pat3_filenames[i]);
596 free(file_buf);
598 if (USE_PATTERN_WEIGHTS && weights_table != NULL) {
599 snprintf(buf, MAX_PAGE_SIZ, "%u/%u expanded patterns weighted", weights_found, weights_found + weights_not_found);
600 flog_info("pat3", buf);
602 hash_table_destroy(weights_table, true);
603 weights_table = NULL;
606 release(buf);