More notes on the possible min/max method.
[pachi/pachi-r6144.git] / uct / slave.c
blob50da248751dfe067885ec3c3070a58fe1c67d667
1 /* This is the slave specific part of the distributed engine.
2 * See introduction at top of distributed/distributed.c.
3 * The slave maintains a hash table of nodes received from the
4 * master. When receiving stats the hash table gives a pointer to the
5 * tree node to update. When sending stats we remember in the tree
6 * what was previously sent so that only the incremental part has to
7 * be sent. The incremental part is smaller and can be compressed.
8 * The compression is not yet done in this version. */
10 /* Similarly the master only sends stats increments.
11 * They include only contributions from other slaves. */
13 /* The keys for the hash table are coordinate paths from
14 * a root child to a given node. See distributed/distributed.h
15 * for the encoding of a path to a 64 bit integer. */
17 /* To allow the master to select the best move, slaves also send
18 * absolute playout counts for the best top level nodes (children
19 * of the root node), including contributions from other slaves. */
21 /* Pass me arguments like a=b,c=d,...
22 * Slave specific arguments (see uct.c for the other uct arguments
23 * and distributed.c for the port arguments) :
24 * slave required to indicate slave mode
25 * max_nodes=MAX_NODES default 80K
26 * stats_hbits=STATS_HBITS default 24. 2^stats_hbits = hash table size
29 #include <assert.h>
30 #include <math.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
35 #define MAX_VERBOSE_LOGS 1000
36 #define DEBUG
38 #include "debug.h"
39 #include "board.h"
40 #include "fbook.h"
41 #include "gtp.h"
42 #include "move.h"
43 #include "timeinfo.h"
44 #include "uct/internal.h"
45 #include "uct/search.h"
46 #include "uct/slave.h"
47 #include "uct/tree.h"
50 /* UCT infrastructure for a distributed engine slave. */
52 /* For debugging only. */
53 static struct hash_counts h_counts;
54 static long parent_not_found = 0;
55 static long parent_leaf = 0;
56 static long node_not_found = 0;
58 /* Hash table entry mapping path to node. */
59 struct tree_hash {
60 path_t coord_path;
61 struct tree_node *node;
64 void *
65 uct_htable_alloc(int hbits)
67 return calloc2(1 << hbits, sizeof(struct tree_hash));
70 /* Clear the hash table. Used only when running as slave for the distributed engine. */
71 void uct_htable_reset(struct tree *t)
73 if (!t->htable) return;
74 double start = time_now();
75 memset(t->htable, 0, (1 << t->hbits) * sizeof(t->htable[0]));
76 if (DEBUGL(3))
77 fprintf(stderr, "tree occupied %ld %.1f%% inserts %ld collisions %ld/%ld %.1f%% clear %.3fms\n"
78 "parent_not_found %.1f%% parent_leaf %.1f%% node_not_found %.1f%%\n",
79 h_counts.occupied, h_counts.occupied * 100.0 / (1 << t->hbits),
80 h_counts.inserts, h_counts.collisions, h_counts.lookups,
81 h_counts.collisions * 100.0 / (h_counts.lookups + 1),
82 (time_now() - start)*1000,
83 parent_not_found * 100.0 / (h_counts.lookups + 1),
84 parent_leaf * 100.0 / (h_counts.lookups + 1),
85 node_not_found * 100.0 / (h_counts.lookups + 1));
86 if (DEBUG_MODE) h_counts.occupied = 0;
89 /* Find a node given its coord path from root. Insert it in the
90 * hash table if it is not already there.
91 * Return the tree node, or NULL if the node cannot be found.
92 * The tree is modified in background while this function is running.
93 * prev is only used to optimize the tree search, given that calls to
94 * tree_find_node are made with sorted coordinates (increasing levels
95 * and increasing coord within a level). */
96 static struct tree_node *
97 tree_find_node(struct tree *t, struct incr_stats *is, struct tree_node *prev)
99 assert(t && t->htable);
100 path_t path = is->coord_path;
101 /* pass and resign must never be inserted in the hash table. */
102 assert(path > 0);
104 int hash, parent_hash;
105 bool found;
106 find_hash(hash, t->htable, t->hbits, path, found, h_counts);
107 struct tree_hash *hnode = &t->htable[hash];
109 if (DEBUGVV(7))
110 fprintf(stderr,
111 "find_node %"PRIpath" %s found %d hash %d playouts %d node %p\n", path,
112 path2sstr(path, t->board), found, hash, is->incr.playouts, hnode->node);
114 if (found) return hnode->node;
116 /* The master sends parents before children so the parent should
117 * already be in the hash table. */
118 path_t parent_p = parent_path(path, t->board);
119 struct tree_node *parent;
120 if (parent_p) {
121 find_hash(parent_hash, t->htable, t->hbits,
122 parent_p, found, h_counts);
123 parent = t->htable[parent_hash].node;
124 } else {
125 parent = t->root;
127 struct tree_node *node = NULL;
128 if (parent) {
129 /* Search for the node in parent's children. */
130 coord_t leaf = leaf_coord(path, t->board);
131 node = (prev && prev->parent == parent ? prev->sibling : parent->children);
132 while (node && node->coord != leaf) node = node->sibling;
134 if (DEBUG_MODE) parent_leaf += !parent->is_expanded;
135 } else {
136 if (DEBUG_MODE) parent_not_found++;
137 if (DEBUGVV(7))
138 fprintf(stderr, "parent of %"PRIpath" %s not found\n",
139 path, path2sstr(path, t->board));
142 /* Insert the node in the hash table. */
143 hnode->node = node;
144 if (DEBUG_MODE) h_counts.inserts++, h_counts.occupied++;
145 if (DEBUGVV(7))
146 fprintf(stderr, "insert path %"PRIpath" %s hash %d playouts %d node %p\n",
147 path, path2sstr(path, t->board), hash, is->incr.playouts, node);
149 if (DEBUG_MODE && !node) node_not_found++;
151 hnode->coord_path = path;
152 return node;
/* Read and discard any binary arguments following a command on stdin.
 * The number of bytes to skip is given by "@size" in @args; with no '@'
 * nothing is read. A malformed or non-positive size skips nothing. This
 * keeps a negative count from ever reaching fread(), where it would become
 * a huge size_t and overrun the local buffer. Stops early on EOF or read error. */
static void
discard_bin_args(char *args)
{
	char *s = strchr(args, '@');
	if (!s)
		return;

	long size = strtol(s + 1, NULL, 10);
	while (size > 0) {
		char buf[64 * 1024];
		size_t len = sizeof(buf);
		if ((long)len > size)
			len = (size_t)size;
		size_t got = fread(buf, 1, len, stdin);
		if (got == 0)
			break;
		size -= (long)got;
	}
}
174 enum parse_code
175 uct_notify(struct engine *e, struct board *b, int id, char *cmd, char *args, char **reply)
177 struct uct *u = e->data;
179 static bool board_resized = false;
180 if (is_gamestart(cmd)) {
181 board_resized = true;
182 uct_pondering_stop(u);
185 /* Force resending the whole command history if we are out of sync
186 * but do it only once, not if already getting the history. */
187 if ((move_number(id) != b->moves || !board_resized)
188 && !reply_disabled(id) && !is_reset(cmd)) {
189 static char buf[128];
190 snprintf(buf, sizeof(buf), "Out of sync, %d %s, move %d expected", id, cmd, b->moves);
191 if (UDEBUGL(0))
192 fprintf(stderr, "%s\n", buf);
193 discard_bin_args(args);
195 *reply = buf;
196 /* Let gtp_parse() complain about invalid commands. */
197 if (!gtp_is_valid(cmd) && !is_repeated(cmd)) return P_OK;
198 return P_DONE_ERROR;
200 return reply_disabled(id) ? P_NOREPLY : P_OK;
204 /* Read the move stats sent by the master, as a binary array of
205 * incr_stats structs. The stats come sorted by increasing coord path.
206 * To simplify the code, we assume that master and slave have the same
207 * architecture (store values identically).
208 * Keep this code in sync with distributed/merge.c:output_stats()
209 * Return true if ok, false if error. */
210 static bool
211 receive_stats(struct uct *u, int size)
213 if (size % sizeof(struct incr_stats)) return false;
214 int nodes = size / sizeof(struct incr_stats);
215 if (nodes > (1 << u->stats_hbits)) return false;
217 struct tree *t = u->t;
218 assert(nodes && t->htable);
219 struct tree_node *prev = NULL;
220 double start_time = time_now();
222 for (int n = 0; n < nodes; n++) {
223 struct incr_stats is;
224 if (fread(&is, sizeof(struct incr_stats), 1, stdin) != 1)
225 return false;
227 if (UDEBUGL(7))
228 fprintf(stderr, "read %5d/%d %6d %.3f %"PRIpath" %s\n", n, nodes,
229 is.incr.playouts, is.incr.value, is.coord_path,
230 path2sstr(is.coord_path, t->board));
232 struct tree_node *node = tree_find_node(t, &is, prev);
233 if (!node) continue;
235 /* node_total += others_incr */
236 stats_add_result(&node->u, is.incr.value, is.incr.playouts);
238 /* last_total += others_incr */
239 stats_add_result(&node->pu, is.incr.value, is.incr.playouts);
241 prev = node;
243 if (DEBUGVV(2))
244 fprintf(stderr, "read args for %d nodes in %.4fms\n", nodes,
245 (time_now() - start_time)*1000);
246 return true;
249 /* A tree traversal fills this array, then the nodes with most increments are sent. */
250 struct stats_candidate {
251 path_t coord_path;
252 int playout_incr;
253 struct tree_node *node;
256 /* We maintain counts per bucket to avoid sorting stats_queue.
257 * All nodes with n updates since last send go to bucket n.
258 * If we put all nodes above 1023 updates in the top bucket,
259 * we get at most 27 nodes in this bucket. So we can select
260 * exactly the best shared_nodes nodes if shared_nodes >= 27. */
261 #define MAX_BUCKETS 1024
262 static int bucket_count[MAX_BUCKETS];
264 /* Traverse the tree rooted at node, and append incremental stats
265 * for children to stats_queue. start_path is the coordinate path
266 * for the top node. Stats for a node are only appended if enough playouts
267 * have been made since the last send, and the level is not too deep.
268 * Return the updated stats count. */
269 static int
270 append_stats(struct stats_candidate *stats_queue, struct tree_node *node, int stats_count,
271 int max_count, path_t start_path, path_t max_path, int min_increment, struct board *b)
273 /* The children field is set only after all children are created
274 * so we can traverse the the tree while it is updated. */
275 for (struct tree_node *ni = node->children; ni; ni = ni->sibling) {
277 if (is_pass(ni->coord)) continue;
278 if (ni->hints & TREE_HINT_INVALID) continue;
280 int incr = ni->u.playouts - ni->pu.playouts;
281 if (incr < min_increment) continue;
283 /* min_increment should be tuned to avoid overflow. */
284 if (stats_count >= max_count) {
285 if (DEBUGL(0))
286 fprintf(stderr, "*** stats overflow %d nodes\n", stats_count);
287 return stats_count;
289 path_t child_path = append_child(start_path, ni->coord, b);
290 stats_queue[stats_count].playout_incr = incr;
291 stats_queue[stats_count].coord_path = child_path;
292 stats_queue[stats_count++].node = ni;
294 if (incr >= MAX_BUCKETS) incr = MAX_BUCKETS - 1;
295 bucket_count[incr]++;
297 /* Do not recurse if level deep enough. */
298 if (child_path >= max_path) continue;
300 stats_count = append_stats(stats_queue, ni, stats_count, max_count,
301 child_path, max_path, min_increment, b);
303 return stats_count;
306 /* Used to sort by coord path the incremental stats to be sent. */
307 static int
308 coord_cmp(const void *p1, const void *p2)
310 path_t diff = ((struct incr_stats *)p1)->coord_path
311 - ((struct incr_stats *)p2)->coord_path;
312 return (int)(diff >> 32) | !!(int)diff;
315 /* Select from stats_queue at most shared_nodes candidates with
316 * biggest increments. Return a binary array sorted by coord path. */
317 static struct incr_stats *
318 select_best_stats(struct stats_candidate *stats_queue, int stats_count,
319 int shared_nodes, int *byte_size)
321 static struct incr_stats *out_stats = NULL;
322 if (!out_stats)
323 out_stats = malloc2(shared_nodes * sizeof(*out_stats));
325 /* Find the minimum increment to send. The bucket with minimum
326 * increment may be sent only partially. */
327 int out_count = 0;
328 int min_incr = MAX_BUCKETS;
329 do {
330 out_count += bucket_count[--min_incr];
331 } while (min_incr > 1 && out_count < shared_nodes);
333 /* Send all all increments > min_incr plus whatever we can at min_incr. */
334 int min_count = bucket_count[min_incr] - (out_count - shared_nodes);
335 struct incr_stats *os = out_stats;
336 out_count = 0;
337 for (int count = 0; count < stats_count; count++) {
338 int delta = stats_queue[count].playout_incr - min_incr;
339 if (delta < 0 || (delta == 0 && --min_count < 0)) continue;
341 struct tree_node *node = stats_queue[count].node;
342 os->incr = node->u;
343 stats_rm_result(&os->incr, node->pu.value, node->pu.playouts);
345 /* With virtual loss os->incr.playouts might be <= 0; we only
346 * send positive increments to other slaves so a virtual loss
347 * can be propagated to other machines (good). The undo of the
348 * virtual loss will be propagated later when node->u gets
349 * above node->pu. */
350 if (os->incr.playouts > 0) {
351 node->pu = node->u;
352 os->coord_path = stats_queue[count].coord_path;
353 assert(os->coord_path > 0);
354 os++;
355 out_count++;
357 assert (out_count <= shared_nodes);
359 *byte_size = (char *)os - (char *)out_stats;
361 /* Sort the increments by increasing coord path (required by master).
362 * Can be done in linear time with radix sort if qsort is too slow. */
363 qsort(out_stats, out_count, sizeof(*os), coord_cmp);
364 return out_stats;
367 /* Get incremental stats updates for the distributed engine.
368 * Return a binary array of incr_stats structs in coordinate order
369 * (increasing levels and increasing coordinates within a level).
370 * This function is called only by the main thread, but may be
371 * called while the tree is updated by the worker threads. Keep this
372 * code in sync with distributed/merge.c:merge_new_stats(). */
373 static void *
374 report_incr_stats(struct uct *u, int *stats_size)
376 double start_time = time_now();
378 struct tree_node *root = u->t->root;
379 struct board *b = u->t->board;
381 /* The factor 3 below has experimentally been found to be
382 * sufficient. At worst if we fill stats_queue we will
383 * discard some stats updates but this is rare. */
384 int max_nodes = 3 * u->shared_nodes;
385 static struct stats_candidate *stats_queue = NULL;
386 if (!stats_queue) stats_queue = malloc2(max_nodes * sizeof(*stats_queue));
388 memset(bucket_count, 0, sizeof(bucket_count));
390 /* Try to fill the output buffer with the most important
391 * nodes (highest increments), while still traversing
392 * as little of the tree as possible. If we set min_increment
393 * too low we waste time. If we set it too high we can't
394 * fill the output buffer with the desired number of nodes.
395 * The best min_increment results in stats_count just above
396 * shared_nodes. However perfect tuning is not necessary:
397 * if we send too few nodes we just send shorter buffers
398 * more frequently. */
399 static int min_increment = 1;
400 static int stats_count = 0;
401 if (stats_count > 2 * u->shared_nodes) {
402 min_increment++;
403 } else if (stats_count < u->shared_nodes / 2 && min_increment > 1) {
404 min_increment--;
407 stats_count = append_stats(stats_queue, root, 0, max_nodes, 0,
408 max_parent_path(u, b), min_increment, b);
410 void *buf = select_best_stats(stats_queue, stats_count, u->shared_nodes, stats_size);
412 if (DEBUGVV(2))
413 fprintf(stderr,
414 "min_incr %d games %d stats_queue %d/%d sending %d/%d in %.3fms\n",
415 min_increment, root->u.playouts - root->pu.playouts, stats_count,
416 max_nodes, *stats_size / (int)sizeof(struct incr_stats), u->shared_nodes,
417 (time_now() - start_time)*1000);
418 root->pu = root->u;
419 return buf;
422 /* Get stats for the distributed engine. Return a buffer with one
423 * line "played_own root_playouts threads keep_looking @size", then
424 * a list of lines "coord playouts value" with absolute counts for
425 * children of the root node (including contributions from other
426 * slaves). The last line must not end with \n.
427 * If c is pass or resign, add this move with a large weight.
428 * This function is called only by the main thread, but may be
429 * called while the tree is updated by the worker threads. Keep this
430 * code in sync with distributed/distributed.c:select_best_move(). */
431 static char *
432 report_stats(struct uct *u, struct board *b, coord_t c,
433 bool keep_looking, int bin_size)
435 static char reply[10240];
436 char *r = reply;
437 char *end = reply + sizeof(reply);
438 struct tree_node *root = u->t->root;
439 r += snprintf(r, end - r, "%d %d %d %d @%d", u->played_own, root->u.playouts,
440 u->threads, keep_looking, bin_size);
441 int min_playouts = root->u.playouts / 100;
442 int max_playouts = 1;
444 /* We rely on the fact that root->children is set only
445 * after all children are created. */
446 for (struct tree_node *ni = root->children; ni; ni = ni->sibling) {
448 if (is_pass(ni->coord)) continue;
449 if (ni->u.playouts > max_playouts)
450 max_playouts = ni->u.playouts;
451 if (ni->u.playouts <= min_playouts || ni->hints & TREE_HINT_INVALID)
452 continue;
454 assert(ni->coord > 0 && ni->coord < board_size2(b));
455 char buf[4];
456 /* We return the values as stored in the tree, so from black's view. */
457 r += snprintf(r, end - r, "\n%s %d %.16f", coord2bstr(buf, ni->coord, b),
458 ni->u.playouts, ni->u.value);
460 /* Give a large but not infinite weight to pass or resign, to avoid forcing
461 * resign if other slaves don't like it. */
462 if (is_pass(c) || is_resign(c)) {
463 double resign_value = u->t->root_color == S_WHITE ? 0.0 : 1.0;
464 double c_value = is_resign(c) ? resign_value : 1.0 - resign_value;
465 r += snprintf(r, end - r, "\n%s %d %.1f", coord2sstr(c, b),
466 2 * max_playouts, c_value);
468 return reply;
471 /* genmoves is issued by the distributed engine master to all slaves, to:
472 * 1. Start a MCTS search if not running yet
473 * 2. Report current move statistics of the on-going search.
474 * The MCTS search is left running on the background when uct_genmoves()
475 * returns. It is stopped by receiving a play GTP command, triggering
476 * uct_pondering_stop(). */
477 /* genmoves gets in the args parameter
478 * "played_games nodes main_time byoyomi_time byoyomi_periods byoyomi_stones @size"
479 * and reads a binary array of coord, playouts, value to get stats of other slaves,
480 * except possibly for the first call at a given move number.
481 * See report_stats() for the description of the return value. */
482 char *
483 uct_genmoves(struct engine *e, struct board *b, struct time_info *ti, enum stone color,
484 char *args, bool pass_all_alive, void **stats_buf, int *stats_size)
486 struct uct *u = e->data;
487 assert(u->slave);
489 /* Prepare the state if the search is not already running.
490 * We must do this first since we tweak the state below
491 * based on instructions from the master. */
492 if (!thread_manager_running)
493 uct_genmove_setup(u, b, color);
495 /* Get playouts and time information from master. Keep this code
496 * in sync with distibuted/distributed.c:distributed_genmove(). */
497 if ((ti->dim == TD_WALLTIME
498 && sscanf(args, "%d %lf %lf %d %d", &u->played_all,
499 &ti->len.t.main_time, &ti->len.t.byoyomi_time,
500 &ti->len.t.byoyomi_periods, &ti->len.t.byoyomi_stones) != 5)
502 || (ti->dim == TD_GAMES && sscanf(args, "%d", &u->played_all) != 1)) {
503 return NULL;
506 static struct uct_search_state s;
507 if (!thread_manager_running) {
508 /* This is the first genmoves issue, start the MCTS
509 * now and let it run while we receive stats. */
510 memset(&s, 0, sizeof(s));
511 uct_search_start(u, b, color, u->t, ti, &s);
514 /* Read binary incremental stats if present, otherwise
515 * wait a bit to populate the statistics. */
516 int size = 0;
517 char *sizep = strchr(args, '@');
518 if (sizep) size = atoi(sizep+1);
519 if (!size) {
520 time_sleep(u->stats_delay);
521 } else if (!receive_stats(u, size)) {
522 return NULL;
525 /* Check the state of the Monte Carlo Tree Search. */
527 int played_games = uct_search_games(&s);
528 uct_search_progress(u, b, color, u->t, ti, &s, played_games);
529 u->played_own = played_games - s.base_playouts;
531 *stats_size = 0;
532 bool keep_looking = false;
533 coord_t best_coord = pass;
534 if (b->fbook)
535 best_coord = fbook_check(b);
536 if (best_coord == pass) {
537 keep_looking = !uct_search_check_stop(u, b, color, u->t, ti, &s, played_games);
538 uct_search_result(u, b, color, pass_all_alive, played_games, s.base_playouts, &best_coord);
540 if (u->shared_levels) {
541 *stats_buf = report_incr_stats(u, stats_size);
544 char *reply = report_stats(u, b, best_coord, keep_looking, *stats_size);
545 return reply;