create regress_simple for Python
[liba.git] / test / regress_linear.h
blob d5a81cdcdacfd694ac8dbe018216c69ebc615f51
1 #define MAIN(x) regress_linear##x
2 #include "test.h"
3 #include "a/regress_linear.h"
4 #include "a/math.h"
5 #include <time.h>
7 #define RAND_MAX_ 2147483647
8 static unsigned long rand_seed = 1;
9 static void srand_(unsigned long seed)
11 rand_seed = seed;
13 static long rand_(void)
15 rand_seed = (rand_seed * 1103515245 + 12345) % 2147483648;
16 return a_cast_s(long, rand_seed);
19 typedef struct config
21 a_float delta;
22 a_float lrmax;
23 a_float lrmin;
24 a_size lrtim;
25 a_size epoch;
26 a_size batch;
27 } config;
29 static void main_1(int m, a_float a, a_float b, a_size n, config const *cfg)
31 a_float *x = a_new(a_float, A_NULL, n);
32 a_float *y = a_new(a_float, A_NULL, n);
33 a_float *e = a_new(a_float, A_NULL, n);
34 long x_n = a_cast_s(long, n) * 10;
35 long y_n = a_cast_s(long, n) * 2;
37 for (a_size i = 0; i < n; ++i)
39 x[i] = a_cast_s(a_float, rand_() % x_n);
40 y[i] = a * x[i] + b + a_cast_s(a_float, rand_() % y_n) - a_cast_s(a_float, n);
43 a_float coef[] = {1};
45 a_regress_linear ctx;
46 a_regress_linear_init(&ctx, coef, 1, 1);
47 a_regress_linear_zero(&ctx);
49 switch (m)
51 default:
52 case 's':
54 a_float lrcur = 0;
55 a_float const lramp = (cfg->lrmax - cfg->lrmin) / 2;
56 a_float const lrper = A_FLOAT_PI / a_float_c(cfg->lrtim);
57 a_regress_linear_err1(&ctx, n, x, y, e);
58 a_float r = a_float_sum2(e, n);
59 for (a_size i = 0; i < cfg->epoch; ++i)
61 a_float alpha = cfg->lrmin + lramp * (a_float_cos(lrcur) + 1);
62 a_regress_linear_sgd1(&ctx, n, x, y, alpha);
63 a_regress_linear_err1(&ctx, n, x, y, e);
64 a_float s = a_float_sum2(e, n);
65 if (A_ABS_(r, s) < cfg->delta)
67 break;
69 lrcur += lrper;
70 r = s;
72 break;
74 case 'b':
76 a_float lrcur = 0;
77 a_float const lramp = (cfg->lrmax - cfg->lrmin) / 2;
78 a_float const lrper = A_FLOAT_PI / a_float_c(cfg->lrtim);
79 a_regress_linear_err1(&ctx, n, x, y, e);
80 a_float r = a_float_sum2(e, n);
81 for (a_size i = 0; i < cfg->epoch; ++i)
83 a_float alpha = cfg->lrmin + lramp * (a_float_cos(lrcur) + 1);
84 a_regress_linear_bgd1(&ctx, n, x, e, alpha);
85 a_regress_linear_err1(&ctx, n, x, y, e);
86 a_float s = a_float_sum2(e, n);
87 if (A_ABS_(r, s) < cfg->delta)
89 break;
91 lrcur += lrper;
92 r = s;
94 break;
96 case 'm':
97 a_regress_linear_mgd1(&ctx, n, x, y, e, cfg->delta, cfg->lrmax, cfg->lrmin, cfg->lrtim, cfg->epoch, cfg->batch);
100 for (unsigned int i = 0; i < n; ++i)
102 a_float u = a_cast_s(a_float, i * 10);
103 a_float v = a_regress_linear_eval(&ctx, &u);
104 debug(A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,")
105 A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f\n"),
106 u, v, x[i], y[i], e[i]);
109 a_regress_linear_zero(&ctx);
111 a_die(x);
112 a_die(y);
113 a_die(e);
116 static void main_2(int m, a_float a, a_float b, a_float c, a_size n, config const *cfg)
118 a_float *x = a_new(a_float, A_NULL, n * 2);
119 a_float *y = a_new(a_float, A_NULL, n);
120 a_float *e = a_new(a_float, A_NULL, n);
121 long x_n = a_cast_s(long, n) * 10;
122 long y_n = a_cast_s(long, n) * 2;
124 for (a_size i = 0; i < n; ++i)
126 x[i * 2 + 0] = a_cast_s(a_float, rand_() % x_n);
127 x[i * 2 + 1] = a_cast_s(a_float, rand_() % x_n);
128 y[i] = a * x[i * 2 + 0] + b * x[i * 2 + 1] + c +
129 a_cast_s(a_float, rand_() % y_n) - a_cast_s(a_float, n);
132 a_float coef[2] = {1, 1};
134 a_regress_linear ctx;
135 a_regress_linear_init(&ctx, coef, 2, 1);
136 a_regress_linear_zero(&ctx);
138 switch (m)
140 default:
141 case 's':
143 a_float lrcur = 0;
144 a_float const lramp = (cfg->lrmax - cfg->lrmin) / 2;
145 a_float const lrper = A_FLOAT_PI / a_float_c(cfg->lrtim);
146 a_regress_linear_err1(&ctx, n, x, y, e);
147 a_float r = a_float_sum2(e, n);
148 for (a_size i = 0; i < cfg->epoch; ++i)
150 a_float alpha = cfg->lrmin + lramp * (a_float_cos(lrcur) + 1);
151 a_regress_linear_sgd1(&ctx, n, x, y, alpha);
152 a_regress_linear_err1(&ctx, n, x, y, e);
153 a_float s = a_float_sum2(e, n);
154 if (A_ABS_(r, s) < cfg->delta)
156 break;
158 lrcur += lrper;
159 r = s;
161 break;
163 case 'b':
165 a_float lrcur = 0;
166 a_float const lramp = (cfg->lrmax - cfg->lrmin) / 2;
167 a_float const lrper = A_FLOAT_PI / a_float_c(cfg->lrtim);
168 a_regress_linear_err1(&ctx, n, x, y, e);
169 a_float r = a_float_sum2(e, n);
170 for (a_size i = 0; i < cfg->epoch; ++i)
172 a_float alpha = cfg->lrmin + lramp * (a_float_cos(lrcur) + 1);
173 a_regress_linear_bgd1(&ctx, n, x, e, alpha);
174 a_regress_linear_err1(&ctx, n, x, y, e);
175 a_float s = a_float_sum2(e, n);
176 if (A_ABS_(r, s) < cfg->delta)
178 break;
180 lrcur += lrper;
181 r = s;
183 break;
185 case 'm':
186 a_regress_linear_mgd1(&ctx, n, x, y, e, cfg->delta, cfg->lrmax, cfg->lrmin, cfg->lrtim, cfg->epoch, cfg->batch);
189 for (unsigned int i = 0; i < n; ++i)
191 a_float u[2];
192 u[0] = a_cast_s(a_float, i * 10);
193 for (unsigned int j = 0; j < n; ++j)
195 u[1] = a_cast_s(a_float, j * 10);
196 a_float v = a_regress_linear_eval(&ctx, u);
197 debug(A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f"), u[0], u[1], v);
198 debug("%c", i ? '\n' : ',');
199 if (i == 0)
201 debug(A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,")
202 A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f\n"),
203 x[j * 2 + 0], x[j * 2 + 1], y[j], e[j]);
206 debug("\n");
209 a_regress_linear_zero(&ctx);
211 a_die(x);
212 a_die(y);
213 a_die(e);
216 int main(int argc, char *argv[]) // NOLINT(misc-definitions-in-headers)
218 srand_(a_cast_s(a_ulong, time(A_NULL)));
219 main_init(argc, argv, 1);
221 a_float a = A_FLOAT_C(0.7);
222 a_float b = A_FLOAT_C(1.4);
223 a_float c = 12;
224 a_size n = 100;
225 char m = 'm';
226 int d = 1;
228 config cfg;
229 cfg.delta = 1;
230 cfg.lrmax = A_FLOAT_C(5e-7);
231 cfg.lrmin = A_FLOAT_C(5e-9);
232 cfg.lrtim = 100;
233 cfg.epoch = 1000;
234 cfg.batch = 16;
236 if (argc > 1)
238 char const *s = strstr(argv[1], "regress_linear_");
239 if (s) { sscanf(s, "regress_linear_%i%c", &d, &m); } // NOLINT
240 else
242 debug("regress_linear_1bgd.csv\n");
243 debug("regress_linear_1sgd.csv\n");
244 debug("regress_linear_1mgd.csv\n");
245 debug("regress_linear_2bgd.csv\n");
246 debug("regress_linear_2sgd.csv\n");
247 debug("regress_linear_2mgd.csv\n");
248 return 0;
252 char *endptr;
253 if (d == 1)
255 if (argc > 2) { a = strtonum(argv[2], &endptr); }
256 if (argc > 3) { c = strtonum(argv[3], &endptr); }
257 if (argc > 4) { n = strtoul(argv[4], &endptr, 0); }
258 main_1(m, a, c, n, &cfg);
260 if (d == 2)
262 if (argc > 2) { a = strtonum(argv[2], &endptr); }
263 if (argc > 3) { b = strtonum(argv[3], &endptr); }
264 if (argc > 4) { c = strtonum(argv[4], &endptr); }
265 if (argc > 5) { n = strtoul(argv[5], &endptr, 0); }
266 main_2(m, a, b, c, n, &cfg);
269 #if defined(__cplusplus) && (__cplusplus > 201100L)
270 A_BUILD_ASSERT(std::is_pod<a_regress_linear>::value);
271 #endif /* __cplusplus */
273 return 0;