Minor changes
[matilda.git] / src / data_set.c
blob7c7a7e481b81f3b82801af14fccd991e04e2629a
1 /*
2 Data set collection manipulation functions.
4 A data set file is defined as 4 bytes (unsigned int) indicating the
5 number of training cases, followed by the elements of struct
6 training_example type.
8 The examples are stored deduplicated up to flips and rotations; when
9 loaded via the data_set_load function they are flipped and rotated to increase
10 the data set size.
13 #include "config.h"
15 #include <stdio.h>
16 #include <stdlib.h>
17 #include <assert.h>
18 #include <unistd.h>
19 #include <string.h>
21 #include "alloc.h"
22 #include "data_set.h"
23 #include "engine.h"
24 #include "flog.h"
25 #include "randg.h"
26 #include "types.h"
28 static u32 data_set_size;
29 static training_example ** data_set = NULL;
32 Shuffle all first num entries.
33 Fisher–Yates shuffle.
35 void data_set_shuffle(
36 u32 num
37 ) {
38 assert(num > 1 && num <= data_set_size);
40 u32 i;
41 for (i = num - 1; i > 0; --i) {
42 u32 j = rand_u32(i + 1);
43 training_example * tmp = data_set[i];
44 data_set[i] = data_set[j];
45 data_set[j] = tmp;
50 Shuffle all data set.
51 Fisher–Yates shuffle.
53 void data_set_shuffle_all() {
54 data_set_shuffle(data_set_size);
58 Read a data set and shuffle it.
59 RETURNS table set size (number of cases)
61 u32 data_set_load() {
62 return data_set_load2(UINT32_MAX);
66 Read a data set, with a maximum size, and shuffles it.
67 RETURNS table set size (number of cases)
69 u32 data_set_load2(
70 u32 max
71 ) {
72 assert(data_set == NULL);
74 char * filename = alloc();
75 snprintf(filename, MAX_PAGE_SIZ, "%s%dx%d.ds", data_folder(), BOARD_SIZ, BOARD_SIZ);
76 FILE * fp = fopen(filename, "rb");
77 release(filename);
79 if (fp == NULL) {
80 flog_crit("dset", "could not open file for reading\n");
83 u32 ds_elems;
84 size_t r = fread(&ds_elems, sizeof(u32), 1, fp);
86 if (r != 1) {
87 flog_crit("dset", "communication failure\n");
90 assert(ds_elems > 0);
92 ds_elems = MIN(ds_elems, max);
94 data_set = malloc(sizeof(training_example *) * ds_elems * 8);
95 if (data_set == NULL) {
96 flog_crit("dset", "system out of memory\n");
99 u32 insert = 0;
100 u32 i;
101 for (i = 0; i < ds_elems; ++i) {
102 data_set[insert] = malloc(sizeof(training_example));
104 if (data_set[insert] == NULL) {
105 flog_crit("dset", "system out of memory (1)\n");
108 r = fread(data_set[insert], sizeof(training_example), 1, fp);
109 assert(r == 1);
110 u32 base_insert = insert;
111 ++insert;
114 Generate more (0-7) cases from reduced ones
116 board tmp;
117 for (d8 r = 2; r < 9; ++r) {
118 memcpy(&tmp.p, &data_set[base_insert]->p, TOTAL_BOARD_SIZ);
119 tmp.last_played = tmp.last_eaten = NONE;
120 reduce_fixed(&tmp, r);
122 bool repeated = false;
123 for (u32 j = base_insert; j < insert; ++j) {
124 if (memcmp(&tmp.p, &data_set[j]->p, TOTAL_BOARD_SIZ) == 0) {
125 repeated = true;
126 break;
130 if (repeated) {
131 continue;
134 data_set[insert] = malloc(sizeof(training_example));
135 if (data_set[insert] == NULL) {
136 flog_crit("dset", "system out of memory (2)\n");
139 memcpy(&data_set[insert]->p, &tmp.p, TOTAL_BOARD_SIZ);
141 data_set[insert]->m = data_set[base_insert]->m;
142 data_set[insert]->m = reduce_move(data_set[insert]->m, r);
143 ++insert;
146 data_set_size = insert;
149 fclose(fp);
151 data_set_shuffle_all();
153 char * s = alloc();
154 snprintf(s, MAX_PAGE_SIZ, "Data set loaded with %u examples, yielding %u examples\n", ds_elems, data_set_size);
155 flog_info("dset", s);
156 release(s);
157 return data_set_size;
161 Get a specific data set element by position.
163 training_example * data_set_get(
164 u32 pos
166 return data_set[pos];