change e,y,x,n to n,x,y,e for regress
[liba.git] / test / regress_linear.h
blob6c2405e7bd6f0575f58cba7bff2d7f6c8a34d2b4
1 #define MAIN(x) regress_linear##x
2 #include "test.h"
3 #include "a/regress_linear.h"
4 #include <time.h>
6 #define RAND_MAX_ 2147483647
7 static unsigned long rand_seed = 1;
8 static void srand_(unsigned long seed)
10 rand_seed = seed;
12 static long rand_(void)
14 rand_seed = (rand_seed * 1103515245 + 12345) % 2147483648;
15 return a_cast_s(long, rand_seed);
18 static void main_1(int m, a_float a, a_float b, a_size n, a_float alpha, a_float threshold)
20 a_float *x = a_new(a_float, A_NULL, n);
21 a_float *y = a_new(a_float, A_NULL, n);
22 a_float *e = a_new(a_float, A_NULL, n);
23 long x_n = a_cast_s(long, n) * 10;
24 long y_n = a_cast_s(long, n) * 2;
26 for (a_size i = 0; i < n; ++i)
28 x[i] = a_cast_s(a_float, rand_() % x_n);
29 y[i] = a * x[i] + b + a_cast_s(a_float, rand_() % y_n) - a_cast_s(a_float, n);
32 a_float coef[] = {1};
34 a_regress_linear ctx;
35 a_regress_linear_init(&ctx, coef, 1, 1);
36 a_regress_linear_zero(&ctx);
38 switch (m)
40 default:
41 case 's':
43 a_regress_linear_err1(&ctx, n, x, y, e);
44 a_float r = a_float_sum2(e, n);
45 for (a_size i = 0; i < 100; ++i)
47 a_regress_linear_bgd1(&ctx, n, x, e, alpha);
48 a_regress_linear_err1(&ctx, n, x, y, e);
49 a_float s = a_float_sum2(e, n);
50 if (A_ABS_(r, s) < threshold)
52 break;
54 r = s;
56 break;
58 case 'b':
60 a_regress_linear_err1(&ctx, n, x, y, e);
61 a_float r = a_float_sum2(e, n);
62 for (a_size i = 0; i < 100; ++i)
64 a_regress_linear_sgd1(&ctx, n, x, y, alpha);
65 a_regress_linear_err1(&ctx, n, x, y, e);
66 a_float s = a_float_sum2(e, n);
67 if (A_ABS_(r, s) < threshold)
69 break;
71 r = s;
73 break;
75 case 'm':
76 a_regress_linear_mgd1(&ctx, n, x, y, e, alpha, threshold, 100, 16);
79 for (unsigned int i = 0; i < n; ++i)
81 a_float u = a_cast_s(a_float, i * 10);
82 a_float v = a_regress_linear_eval(&ctx, &u);
83 debug(A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,")
84 A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f\n"),
85 u, v, x[i], y[i], e[i]);
88 a_regress_linear_zero(&ctx);
90 a_die(x);
91 a_die(y);
92 a_die(e);
95 static void main_2(int m, a_float a, a_float b, a_float c, a_size n, a_float alpha, a_float threshold)
97 a_float *x = a_new(a_float, A_NULL, n * 2);
98 a_float *y = a_new(a_float, A_NULL, n);
99 a_float *e = a_new(a_float, A_NULL, n);
100 long x_n = a_cast_s(long, n) * 10;
101 long y_n = a_cast_s(long, n) * 2;
103 for (a_size i = 0; i < n; ++i)
105 x[i * 2 + 0] = a_cast_s(a_float, rand_() % x_n);
106 x[i * 2 + 1] = a_cast_s(a_float, rand_() % x_n);
107 y[i] = a * x[i * 2 + 0] + b * x[i * 2 + 1] + c +
108 a_cast_s(a_float, rand_() % y_n) - a_cast_s(a_float, n);
111 a_float coef[2] = {1, 1};
113 a_regress_linear ctx;
114 a_regress_linear_init(&ctx, coef, 2, 1);
115 a_regress_linear_zero(&ctx);
117 switch (m)
119 default:
120 case 's':
122 a_regress_linear_err1(&ctx, n, x, y, e);
123 a_float r = a_float_sum2(e, n);
124 for (a_size i = 0; i < 100; ++i)
126 a_regress_linear_bgd1(&ctx, n, x, e, alpha);
127 a_regress_linear_err1(&ctx, n, x, y, e);
128 a_float s = a_float_sum2(e, n);
129 if (A_ABS_(r, s) < threshold)
131 break;
133 r = s;
135 break;
137 case 'b':
139 a_regress_linear_err1(&ctx, n, x, y, e);
140 a_float r = a_float_sum2(e, n);
141 for (a_size i = 0; i < 100; ++i)
143 a_regress_linear_sgd1(&ctx, n, x, e, alpha);
144 a_regress_linear_err1(&ctx, n, x, y, e);
145 a_float s = a_float_sum2(e, n);
146 if (A_ABS_(r, s) < threshold)
148 break;
150 r = s;
152 break;
154 case 'm':
155 a_regress_linear_mgd1(&ctx, n, x, y, e, alpha, threshold, 100, 16);
158 for (unsigned int i = 0; i < n; ++i)
160 a_float u[2];
161 u[0] = a_cast_s(a_float, i * 10);
162 for (unsigned int j = 0; j < n; ++j)
164 u[1] = a_cast_s(a_float, j * 10);
165 a_float v = a_regress_linear_eval(&ctx, u);
166 debug(A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f"), u[0], u[1], v);
167 debug("%c", i ? '\n' : ',');
168 if (i == 0)
170 debug(A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f,")
171 A_FLOAT_PRI("+.1", "f,") A_FLOAT_PRI("+.1", "f\n"),
172 x[j * 2 + 0], x[j * 2 + 1], y[j], e[j]);
175 debug("\n");
178 a_regress_linear_zero(&ctx);
180 a_die(x);
181 a_die(y);
182 a_die(e);
185 int main(int argc, char *argv[]) // NOLINT(misc-definitions-in-headers)
187 srand_(a_cast_s(a_ulong, time(A_NULL)));
188 main_init(argc, argv, 1);
190 a_float threshold = A_FLOAT_C(1.0);
191 a_float alpha = A_FLOAT_C(5e-8);
192 a_float a = A_FLOAT_C(0.7);
193 a_float b = A_FLOAT_C(1.4);
194 a_float c = 12;
195 a_size n = 100;
196 char m = 'm';
197 int d = 1;
199 if (argc > 1)
201 char const *s = strstr(argv[1], "regress_linear_");
202 if (s) { sscanf(s, "regress_linear_%i%c", &d, &m); } // NOLINT
203 else
205 debug("regress_linear_1bgd.csv\n");
206 debug("regress_linear_1sgd.csv\n");
207 debug("regress_linear_1mgd.csv\n");
208 debug("regress_linear_2bgd.csv\n");
209 debug("regress_linear_2sgd.csv\n");
210 debug("regress_linear_2mgd.csv\n");
211 return 0;
215 if (m == 's') { alpha = A_FLOAT_C(2e-8); }
217 char *endptr;
218 if (d == 1)
220 if (argc > 2) { a = strtonum(argv[2], &endptr); }
221 if (argc > 3) { c = strtonum(argv[3], &endptr); }
222 if (argc > 4) { n = strtoul(argv[4], &endptr, 0); }
223 if (argc > 5) { alpha = strtonum(argv[5], &endptr); }
224 if (argc > 6) { threshold = strtonum(argv[6], &endptr); }
225 main_1(m, a, c, n, alpha, threshold);
227 if (d == 2)
229 if (argc > 2) { a = strtonum(argv[2], &endptr); }
230 if (argc > 3) { b = strtonum(argv[3], &endptr); }
231 if (argc > 4) { c = strtonum(argv[4], &endptr); }
232 if (argc > 5) { n = strtoul(argv[5], &endptr, 0); }
233 if (argc > 6) { alpha = strtonum(argv[6], &endptr); }
234 if (argc > 7) { threshold = strtonum(argv[7], &endptr); }
235 main_2(m, a, b, c, n, alpha, threshold);
238 #if defined(__cplusplus) && (__cplusplus > 201100L)
239 A_BUILD_ASSERT(std::is_pod<a_regress_linear>::value);
240 #endif /* __cplusplus */
242 return 0;