Pachi Dochi 6.00
[pachi/pachi-r6144.git] / uct / slave.c
blobfa58c71d493ecf8ad34947a9ec275770441e165b
1 /* This is the slave specific part of the distributed engine.
2 * See introduction at top of distributed/distributed.c.
3 * The slave maintains a hash table of nodes received from the
4 * master. When receiving stats the hash table gives a pointer to the
5 * tree node to update. When sending stats we remember in the tree
6 * what was previously sent so that only the incremental part has to
7 * be sent. The incremental part is smaller and can be compressed.
8 * The compression is not yet done in this version. */
10 /* Similarly the master only sends stats increments.
11 * They include only contributions from other slaves. */
13 /* The keys for the hash table are coordinate paths from
14 * a root child to a given node. See distributed/distributed.h
15 * for the encoding of a path to a 64 bit integer. */
17 /* To allow the master to select the best move, slaves also send
18 * absolute playout counts for the best top level nodes (children
19 * of the root node), including contributions from other slaves. */
21 /* Pass me arguments like a=b,c=d,...
22 * Slave specific arguments (see uct.c for the other uct arguments
23 * and distributed.c for the port arguments) :
24 * slave required to indicate slave mode
25 * max_nodes=MAX_NODES default 80K
26 * stats_hbits=STATS_HBITS default 24. 2^stats_hbits = hash table size
29 #include <assert.h>
30 #include <math.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
35 #define MAX_VERBOSE_LOGS 1000
36 #define DEBUG
38 #include "debug.h"
39 #include "board.h"
40 #include "gtp.h"
41 #include "move.h"
42 #include "timeinfo.h"
43 #include "uct/internal.h"
44 #include "uct/search.h"
45 #include "uct/slave.h"
46 #include "uct/tree.h"
49 /* UCT infrastructure for a distributed engine slave. */
51 /* For debugging only. */
52 static struct hash_counts h_counts;
53 static long parent_not_found = 0;
54 static long parent_leaf = 0;
55 static long node_not_found = 0;
57 /* Hash table entry mapping path to node. */
58 struct tree_hash {
59 path_t coord_path;
60 struct tree_node *node;
63 void *
64 uct_htable_alloc(int hbits)
66 return calloc2(1 << hbits, sizeof(struct tree_hash));
69 /* Clear the hash table. Used only when running as slave for the distributed engine. */
70 void uct_htable_reset(struct tree *t)
72 if (!t->htable) return;
73 double start = time_now();
74 memset(t->htable, 0, (1 << t->hbits) * sizeof(t->htable[0]));
75 if (DEBUGL(3))
76 fprintf(stderr, "tree occupied %ld %.1f%% inserts %ld collisions %ld/%ld %.1f%% clear %.3fms\n"
77 "parent_not_found %.1f%% parent_leaf %.1f%% node_not_found %.1f%%\n",
78 h_counts.occupied, h_counts.occupied * 100.0 / (1 << t->hbits),
79 h_counts.inserts, h_counts.collisions, h_counts.lookups,
80 h_counts.collisions * 100.0 / (h_counts.lookups + 1),
81 (time_now() - start)*1000,
82 parent_not_found * 100.0 / (h_counts.lookups + 1),
83 parent_leaf * 100.0 / (h_counts.lookups + 1),
84 node_not_found * 100.0 / (h_counts.lookups + 1));
85 if (DEBUG_MODE) h_counts.occupied = 0;
88 /* Find a node given its coord path from root. Insert it in the
89 * hash table if it is not already there.
90 * Return the tree node, or NULL if the node cannot be found.
91 * The tree is modified in background while this function is running.
92 * prev is only used to optimize the tree search, given that calls to
93 * tree_find_node are made with sorted coordinates (increasing levels
94 * and increasing coord within a level). */
95 static struct tree_node *
96 tree_find_node(struct tree *t, struct incr_stats *is, struct tree_node *prev)
98 assert(t && t->htable);
99 path_t path = is->coord_path;
100 /* pass and resign must never be inserted in the hash table. */
101 assert(path > 0);
103 int hash, parent_hash;
104 bool found;
105 find_hash(hash, t->htable, t->hbits, path, found, h_counts);
106 struct tree_hash *hnode = &t->htable[hash];
108 if (DEBUGVV(7))
109 fprintf(stderr,
110 "find_node %"PRIpath" %s found %d hash %d playouts %d node %p\n", path,
111 path2sstr(path, t->board), found, hash, is->incr.playouts, hnode->node);
113 if (found) return hnode->node;
115 /* The master sends parents before children so the parent should
116 * already be in the hash table. */
117 path_t parent_p = parent_path(path, t->board);
118 struct tree_node *parent;
119 if (parent_p) {
120 find_hash(parent_hash, t->htable, t->hbits,
121 parent_p, found, h_counts);
122 parent = t->htable[parent_hash].node;
123 } else {
124 parent = t->root;
126 struct tree_node *node = NULL;
127 if (parent) {
128 /* Search for the node in parent's children. */
129 coord_t leaf = leaf_coord(path, t->board);
130 node = (prev && prev->parent == parent ? prev->sibling : parent->children);
131 while (node && node->coord != leaf) node = node->sibling;
133 if (DEBUG_MODE) parent_leaf += !parent->is_expanded;
134 } else {
135 if (DEBUG_MODE) parent_not_found++;
136 if (DEBUGVV(7))
137 fprintf(stderr, "parent of %"PRIpath" %s not found\n",
138 path, path2sstr(path, t->board));
141 /* Insert the node in the hash table. */
142 hnode->node = node;
143 if (DEBUG_MODE) h_counts.inserts++, h_counts.occupied++;
144 if (DEBUGVV(7))
145 fprintf(stderr, "insert path %"PRIpath" %s hash %d playouts %d node %p\n",
146 path, path2sstr(path, t->board), hash, is->incr.playouts, node);
148 if (DEBUG_MODE && !node) node_not_found++;
150 hnode->coord_path = path;
151 return node;
/* Read and discard any binary arguments that follow a command on stdin.
 * The number of bytes to skip is given by @size in args. Nothing is read
 * when args contains no '@' or when the size is zero or negative.
 * Reading stops early if stdin delivers no more data. */
static void
discard_bin_args(char *args)
{
	long remaining = 0;
	const char *at = strchr(args, '@');
	if (at)
		remaining = strtol(at + 1, NULL, 10);

	/* Only loop on strictly positive sizes: a negative count handed to
	 * fread() is converted to a huge size_t and would overrun buf. */
	while (remaining > 0) {
		char buf[64*1024];
		size_t want = sizeof(buf);
		if ((long)want > remaining)
			want = (size_t)remaining;

		size_t got = fread(buf, 1, want, stdin);
		if (got == 0)
			break;
		remaining -= (long)got;
	}
}
173 enum parse_code
174 uct_notify(struct engine *e, struct board *b, int id, char *cmd, char *args, char **reply)
176 struct uct *u = e->data;
178 static bool board_resized = false;
179 if (is_gamestart(cmd)) {
180 board_resized = true;
181 uct_pondering_stop(u);
184 /* Force resending the whole command history if we are out of sync
185 * but do it only once, not if already getting the history. */
186 if ((move_number(id) != b->moves || !board_resized)
187 && !reply_disabled(id) && !is_reset(cmd)) {
188 static char buf[128];
189 snprintf(buf, sizeof(buf), "Out of sync, %d %s, move %d expected", id, cmd, b->moves);
190 if (UDEBUGL(0))
191 fprintf(stderr, "%s\n", buf);
192 discard_bin_args(args);
194 *reply = buf;
195 /* Let gtp_parse() complain about invalid commands. */
196 if (!gtp_is_valid(cmd) && !is_repeated(cmd)) return P_OK;
197 return P_DONE_ERROR;
199 return reply_disabled(id) ? P_NOREPLY : P_OK;
203 /* Read the move stats sent by the master, as a binary array of
204 * incr_stats structs. The stats come sorted by increasing coord path.
205 * To simplify the code, we assume that master and slave have the same
206 * architecture (store values identically).
207 * Keep this code in sync with distributed/distributed.c:select_best_move().
208 * Return true if ok, false if error. */
209 static bool
210 receive_stats(struct uct *u, int size)
212 if (size % sizeof(struct incr_stats)) return false;
213 int nodes = size / sizeof(struct incr_stats);
214 if (nodes > (1 << u->stats_hbits)) return false;
216 struct tree *t = u->t;
217 assert(nodes && t->htable);
218 struct tree_node *prev = NULL;
219 double start_time = time_now();
221 for (int n = 0; n < nodes; n++) {
222 struct incr_stats is;
223 if (fread(&is, sizeof(struct incr_stats), 1, stdin) != 1)
224 return false;
226 if (UDEBUGL(7))
227 fprintf(stderr, "read %5d/%d %6d %.3f %"PRIpath" %s\n", n, nodes,
228 is.incr.playouts, is.incr.value, is.coord_path,
229 path2sstr(is.coord_path, t->board));
231 struct tree_node *node = tree_find_node(t, &is, prev);
232 if (!node) continue;
234 /* node_total += others_incr */
235 stats_add_result(&node->u, is.incr.value, is.incr.playouts);
237 /* last_total += others_incr */
238 stats_add_result(&node->pu, is.incr.value, is.incr.playouts);
240 prev = node;
242 if (DEBUGVV(2))
243 fprintf(stderr, "read args for %d nodes in %.4fms\n", nodes,
244 (time_now() - start_time)*1000);
245 return true;
248 /* A tree traversal fills this array, then the nodes with most increments are sent. */
249 struct stats_candidate {
250 path_t coord_path;
251 int playout_incr;
252 struct tree_node *node;
/* Per-increment histogram that saves us from sorting stats_queue:
 * a node with n updates since the last send is counted in bucket n, and
 * every node with more than 1023 updates is counted in the top bucket.
 * According to the original analysis at most 27 nodes end up in the top
 * bucket, so exactly the best shared_nodes nodes can be selected as long
 * as shared_nodes >= 27. */
#define MAX_BUCKETS 1024
static int bucket_count[MAX_BUCKETS];
263 /* Traverse the tree rooted at node, and append incremental stats
264 * for children to stats_queue. start_path is the coordinate path
265 * for the top node. Stats for a node are only appended if enough playouts
266 * have been made since the last send, and the level is not too deep.
267 * Return the updated stats count. */
268 static int
269 append_stats(struct stats_candidate *stats_queue, struct tree_node *node, int stats_count,
270 int max_count, path_t start_path, path_t max_path, int min_increment, struct board *b)
272 /* The children field is set only after all children are created
273 * so we can traverse the the tree while it is updated. */
274 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
276 if (is_pass(ni->coord)) continue;
277 if (ni->hints & TREE_HINT_INVALID) continue;
279 int incr = ni->u.playouts - ni->pu.playouts;
280 if (incr < min_increment) continue;
282 /* min_increment should be tuned to avoid overflow. */
283 if (stats_count >= max_count) {
284 if (DEBUGL(0))
285 fprintf(stderr, "*** stats overflow %d nodes\n", stats_count);
286 return stats_count;
288 path_t child_path = append_child(start_path, ni->coord, b);
289 stats_queue[stats_count].playout_incr = incr;
290 stats_queue[stats_count].coord_path = child_path;
291 stats_queue[stats_count++].node = ni;
293 if (incr >= MAX_BUCKETS) incr = MAX_BUCKETS - 1;
294 bucket_count[incr]++;
296 /* Do not recurse if level deep enough. */
297 if (child_path >= max_path) continue;
299 stats_count = append_stats(stats_queue, ni, stats_count, max_count,
300 child_path, max_path, min_increment, b);
302 return stats_count;
305 /* Used to sort by coord path the incremental stats to be sent. */
306 static int
307 coord_cmp(const void *p1, const void *p2)
309 path_t diff = ((struct incr_stats *)p1)->coord_path
310 - ((struct incr_stats *)p2)->coord_path;
311 return (int)(diff >> 32) | !!(int)diff;
314 /* Select from stats_queue at most shared_nodes candidates with
315 * biggest increments. Return a binary array sorted by coord path. */
316 static struct incr_stats *
317 select_best_stats(struct stats_candidate *stats_queue, int stats_count,
318 int shared_nodes, int *byte_size)
320 static struct incr_stats *out_stats = NULL;
321 if (!out_stats)
322 out_stats = malloc2(shared_nodes * sizeof(*out_stats));
324 /* Find the minimum increment to send. The bucket with minimum
325 * increment may be sent only partially. */
326 int out_count = 0;
327 int min_incr = MAX_BUCKETS;
328 do {
329 out_count += bucket_count[--min_incr];
330 } while (min_incr > 1 && out_count < shared_nodes);
332 /* Send all all increments > min_incr plus whatever we can at min_incr. */
333 int min_count = bucket_count[min_incr] - (out_count - shared_nodes);
334 struct incr_stats *os = out_stats;
335 out_count = 0;
336 for (int count = 0; count < stats_count; count++) {
337 int delta = stats_queue[count].playout_incr - min_incr;
338 if (delta < 0 || (delta == 0 && --min_count < 0)) continue;
340 struct tree_node *node = stats_queue[count].node;
341 os->incr = node->u;
342 stats_rm_result(&os->incr, node->pu.value, node->pu.playouts);
344 /* With virtual loss os->incr.playouts might be <= 0; we only
345 * send positive increments to other slaves so a virtual loss
346 * can be propagated to other machines (good). The undo of the
347 * virtual loss will be propagated later when node->u gets
348 * above node->pu. */
349 if (os->incr.playouts > 0) {
350 node->pu = node->u;
351 os->coord_path = stats_queue[count].coord_path;
352 assert(os->coord_path > 0);
353 os++;
354 out_count++;
356 assert (out_count <= shared_nodes);
358 *byte_size = (char *)os - (char *)out_stats;
360 /* Sort the increments by increasing coord path (required by master).
361 * Can be done in linear time with radix sort if qsort is too slow. */
362 qsort(out_stats, out_count, sizeof(*os), coord_cmp);
363 return out_stats;
366 /* Get incremental stats updates for the distributed engine.
367 * Return a binary array of incr_stats structs in coordinate order
368 * (increasing levels and increasing coordinates within a level).
369 * This function is called only by the main thread, but may be
370 * called while the tree is updated by the worker threads. Keep this
371 * code in sync with distributed/distributed.c:select_best_move(). */
372 static void *
373 report_incr_stats(struct uct *u, int *stats_size)
375 double start_time = time_now();
377 struct tree_node *root = u->t->root;
378 struct board *b = u->t->board;
380 /* The factor 3 below has experimentally been found to be
381 * sufficient. At worst if we fill stats_queue we will
382 * discard some stats updates but this is rare. */
383 int max_nodes = 3 * u->shared_nodes;
384 static struct stats_candidate *stats_queue = NULL;
385 if (!stats_queue) stats_queue = malloc2(max_nodes * sizeof(*stats_queue));
387 memset(bucket_count, 0, sizeof(bucket_count));
389 /* Try to fill the output buffer with the most important
390 * nodes (highest increments), while still traversing
391 * as little of the tree as possible. If we set min_increment
392 * too low we waste time. If we set it too high we can't
393 * fill the output buffer with the desired number of nodes.
394 * The best min_increment results in stats_count just above
395 * shared_nodes. However perfect tuning is not necessary:
396 * if we send too few nodes we just send shorter buffers
397 * more frequently. */
398 static int min_increment = 1;
399 static int stats_count = 0;
400 if (stats_count > 2 * u->shared_nodes) {
401 min_increment++;
402 } else if (stats_count < u->shared_nodes / 2 && min_increment > 1) {
403 min_increment--;
406 stats_count = append_stats(stats_queue, root, 0, max_nodes, 0,
407 max_parent_path(u, b), min_increment, b);
409 void *buf = select_best_stats(stats_queue, stats_count, u->shared_nodes, stats_size);
411 if (DEBUGVV(2))
412 fprintf(stderr,
413 "min_incr %d games %d stats_queue %d/%d sending %d/%d in %.3fms\n",
414 min_increment, root->u.playouts - root->pu.playouts, stats_count,
415 max_nodes, *stats_size / (int)sizeof(struct incr_stats), u->shared_nodes,
416 (time_now() - start_time)*1000);
417 root->pu = root->u;
418 return buf;
421 /* Get stats for the distributed engine. Return a buffer with one
422 * line "played_own root_playouts threads keep_looking @size", then
423 * a list of lines "coord playouts value" with absolute counts for
424 * children of the root node (including contributions from other
425 * slaves). The last line must not end with \n.
426 * If c is pass or resign, add this move with root->playouts weight.
427 * This function is called only by the main thread, but may be
428 * called while the tree is updated by the worker threads. Keep this
429 * code in sync with distributed/distributed.c:select_best_move(). */
430 static char *
431 report_stats(struct uct *u, struct board *b, coord_t c,
432 bool keep_looking, int bin_size)
434 static char reply[10240];
435 char *r = reply;
436 char *end = reply + sizeof(reply);
437 struct tree_node *root = u->t->root;
438 r += snprintf(r, end - r, "%d %d %d %d @%d", u->played_own, root->u.playouts,
439 u->threads, keep_looking, bin_size);
440 int min_playouts = root->u.playouts / 100;
442 /* Give a large weight to pass or resign, but still allow other moves. */
443 if (is_pass(c) || is_resign(c))
444 r += snprintf(r, end - r, "\n%s %d %.1f", coord2sstr(c, b),
445 root->u.playouts, 0.0);
447 /* We rely on the fact that root->children is set only
448 * after all children are created. */
449 for (struct tree_node *ni = root->children; ni; ni = ni->sibling) {
451 if (is_pass(ni->coord)) continue;
452 if (ni->u.playouts <= min_playouts || ni->hints & TREE_HINT_INVALID)
453 continue;
455 assert(ni->coord > 0 && ni->coord < board_size2(b));
456 char buf[4];
457 /* We return the values as stored in the tree, so from black's view. */
458 r += snprintf(r, end - r, "\n%s %d %.7f", coord2bstr(buf, ni->coord, b),
459 ni->u.playouts, ni->u.value);
461 return reply;
/* Time in seconds the slave waits for initial stats to build up before
 * it replies to a genmoves command that carries no binary stats. */
#define MIN_STATS_INTERVAL 0.05 /* 50ms */
468 /* genmoves is issued by the distributed engine master to all slaves, to:
469 * 1. Start a MCTS search if not running yet
470 * 2. Report current move statistics of the on-going search.
471 * The MCTS search is left running on the background when uct_genmoves()
472 * returns. It is stopped by receiving a play GTP command, triggering
473 * uct_pondering_stop(). */
474 /* genmoves gets in the args parameter
475 * "played_games nodes main_time byoyomi_time byoyomi_periods byoyomi_stones @size"
476 * and reads a binary array of coord, playouts, value to get stats of other slaves,
477 * except possibly for the first call at a given move number.
478 * See report_stats() for the description of the return value. */
479 char *
480 uct_genmoves(struct engine *e, struct board *b, struct time_info *ti, enum stone color,
481 char *args, bool pass_all_alive, void **stats_buf, int *stats_size)
483 struct uct *u = e->data;
484 assert(u->slave);
486 /* Prepare the state if the search is not already running.
487 * We must do this first since we tweak the state below
488 * based on instructions from the master. */
489 if (!thread_manager_running)
490 uct_genmove_setup(u, b, color);
492 /* Get playouts and time information from master. Keep this code
493 * in sync with distibuted/distributed.c:distributed_genmove(). */
494 if ((ti->dim == TD_WALLTIME
495 && sscanf(args, "%d %lf %lf %d %d", &u->played_all,
496 &ti->len.t.main_time, &ti->len.t.byoyomi_time,
497 &ti->len.t.byoyomi_periods, &ti->len.t.byoyomi_stones) != 5)
499 || (ti->dim == TD_GAMES && sscanf(args, "%d", &u->played_all) != 1)) {
500 return NULL;
503 static struct uct_search_state s;
504 if (!thread_manager_running) {
505 /* This is the first genmoves issue, start the MCTS
506 * now and let it run while we receive stats. */
507 memset(&s, 0, sizeof(s));
508 uct_search_start(u, b, color, u->t, ti, &s);
511 /* Read binary incremental stats if present, otherwise
512 * wait a bit to populate the statistics. */
513 int size = 0;
514 char *sizep = strchr(args, '@');
515 if (sizep) size = atoi(sizep+1);
516 if (!size) {
517 time_sleep(MIN_STATS_INTERVAL);
518 } else if (!receive_stats(u, size)) {
519 return NULL;
522 /* Check the state of the Monte Carlo Tree Search. */
524 int played_games = uct_search_games(&s);
525 uct_search_progress(u, b, color, u->t, ti, &s, played_games);
526 u->played_own = played_games - s.base_playouts;
528 bool keep_looking = !uct_search_check_stop(u, b, color, u->t, ti, &s, played_games);
529 coord_t best_coord;
530 uct_search_result(u, b, color, pass_all_alive, played_games, s.base_playouts, &best_coord);
532 *stats_buf = report_incr_stats(u, stats_size);
534 char *reply = report_stats(u, b, best_coord, keep_looking, *stats_size);
535 return reply;