support sqrt
[fpmath-consensus.git] / impl-libc / impl-libc.c
blob9a1dc9d8b2d1d0935f67831a7ffaf86a2dab1ca9
1 #include <errno.h>
2 #include <math.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
7 #include <unistd.h>
/*
 * Width of the floating-point values exchanged on stdin/stdout:
 * P_SINGLE selects float (32-bit), P_DOUBLE selects double (64-bit).
 */
typedef enum {
	P_SINGLE,
	P_DOUBLE
} precision;
/*
 * Call signature of the selected function: how many floating-point values
 * make up one input record, and how many the function produces.
 */
typedef enum {
	/* No function selected.  Kept as the zero value so that a
	 * zero-initialised action means "nothing chosen yet". */
	A_UNKNOWN,
	/* One value in, one value out (e.g. sin, sqrt). */
	A__FLT__FLT,
	/* Three values in, one value out (fma). */
	A__FLT_FLT_FLT__FLT,
} argtype;
/*
 * Function-pointer types for the supported signatures, named
 * f__<argument types>__<result type>.
 */
typedef float (*f__f32__f32)(float x);
typedef float (*f__f32_f32_f32__f32)(float x, float y, float z);
typedef double (*f__f64__f64)(double x);
typedef double (*f__f64_f64_f64__f64)(double x, double y, double z);
26 /* Wrapper around a function pointer */
27 typedef struct {
28 /* */
29 precision p;
30 argtype a;
32 union {
33 /* */
34 f__f32__f32 f32__f32;
35 f__f32_f32_f32__f32 f32_f32_f32__f32;
36 } f32;
38 union {
39 /* */
40 f__f64__f64 f64__f64;
41 f__f64_f64_f64__f64 f64_f64_f64__f64;
42 } f64;
44 } action;
/* Print the command-line synopsis on stderr and terminate with status 1. */
void usage(void)
{
	static const char synopsis[] =
	    "usage: impl-libc [-s|-d] -f <function_name> -n <num_inputs>\n";

	fputs(synopsis, stderr);
	_exit(1);
}
/* Single-precision identity: hands its argument back unchanged. */
float idf(float f)
{
	const float same = f;

	return same;
}
/* Double-precision identity: hands its argument back unchanged. */
double idd(double d)
{
	const double same = d;

	return same;
}
63 void determine_function(const char *f, action *a)
65 if (!strcmp(f, "zzzzzz")) {
66 a->a = A_UNKNOWN;
67 } else if (!strcmp(f, "id")) {
68 a->a = A__FLT__FLT;
69 a->f32.f32__f32 = idf;
70 a->f64.f64__f64 = idd;
71 } else if (!strcmp(f, "ceil")) {
72 a->a = A__FLT__FLT;
73 a->f32.f32__f32 = ceilf;
74 a->f64.f64__f64 = ceil;
75 } else if (!strcmp(f, "cos")) {
76 a->a = A__FLT__FLT;
77 a->f32.f32__f32 = cosf;
78 a->f64.f64__f64 = cos;
79 } else if (!strcmp(f, "floor")) {
80 a->a = A__FLT__FLT;
81 a->f32.f32__f32 = floorf;
82 a->f64.f64__f64 = floor;
83 } else if (!strcmp(f, "sin")) {
84 a->a = A__FLT__FLT;
85 a->f32.f32__f32 = sinf;
86 a->f64.f64__f64 = sin;
87 } else if (!strcmp(f, "trunc")) {
88 a->a = A__FLT__FLT;
89 a->f32.f32__f32 = truncf;
90 a->f64.f64__f64 = trunc;
91 } else if (!strcmp(f, "fma")) {
92 a->a = A__FLT_FLT_FLT__FLT;
93 a->f32.f32_f32_f32__f32 = fmaf;
94 a->f64.f64_f64_f64__f64 = fma;
95 } else if (!strcmp(f, "sqrt")) {
96 a->a = A__FLT__FLT;
97 a->f32.f32__f32 = sqrtf;
98 a->f64.f64__f64 = sqrt;
99 } else {
100 fprintf(stderr, "impl-libc: unknown function \"%s\"\n", f);
101 _exit(1);
/*
 * Read exactly `len` bytes from standard input (fd 0) into `b`.
 *
 * Short reads are continued until the buffer is full, and reads interrupted
 * by a signal (EINTR) are retried.  End of file terminates the whole process
 * with status 0, even if part of the buffer was already filled; such a
 * trailing partial batch is discarded.  Any other read error is reported on
 * stderr and the process exits with status 1.
 */
void read_buf(char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = read(0, b + total, (size_t) (len - total));

		if (r == 0) {
			/* EOF */
			_exit(0);
		}

		if (r == -1) {
			if (errno == EINTR) {
				/* Interrupted before any data arrived: retry. */
				continue;
			}
			perror("impl-libc: read");
			_exit(1);
		}

		total += r;
	}
}
/*
 * Write exactly `len` bytes from `b` to standard output (fd 1).
 *
 * Short writes are continued until everything is written, and writes
 * interrupted by a signal (EINTR) are retried.  Any other write error is
 * reported on stderr and the process exits with status 1.
 */
void write_buf(const char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = write(1, b + total, (size_t) (len - total));

		if (r == -1) {
			if (errno == EINTR) {
				/* Interrupted before anything was written: retry. */
				continue;
			}
			perror("impl-libc: write");
			_exit(1);
		}

		total += r;
	}
}
142 size_t input_width(argtype a, precision p)
144 size_t w = (p == P_SINGLE) ? 4 : 8;
146 switch (a) {
147 case A_UNKNOWN:
148 break;
149 case A__FLT__FLT:
151 return 1 * w;
152 case A__FLT_FLT_FLT__FLT:
154 return 3 * w;
157 return (size_t) -1;
160 size_t output_width(argtype a, precision p)
162 size_t w = (p == P_SINGLE) ? 4 : 8;
164 switch (a) {
165 case A_UNKNOWN:
166 break;
167 case A__FLT__FLT:
169 return 1 * w;
170 case A__FLT_FLT_FLT__FLT:
172 return 1 * w;
175 return (size_t) -1;
178 void io_loop(action a, size_t n)
180 char *in_buf = 0;
181 char *out_buf = 0;
182 size_t in_sz = input_width(a.a, a.p);
183 size_t out_sz = output_width(a.a, a.p);
185 if ((in_sz * n) / n != in_sz) {
186 fprintf(stderr, "impl-libc: input length overflow\n");
187 _exit(1);
190 if ((out_sz * n) / n != out_sz) {
191 fprintf(stderr, "impl-libc: output length overflow\n");
192 _exit(1);
195 if (!(in_buf = malloc(in_sz * n))) {
196 perror("impl-libc: malloc");
197 _exit(1);
200 if (!(out_buf = malloc(out_sz * n))) {
201 perror("impl-libc: malloc");
202 _exit(1);
205 while (1) {
206 read_buf(in_buf, in_sz * n);
208 switch (a.a) {
209 case A_UNKNOWN:
210 fprintf(stderr, "impl-libc: impossible\n");
211 _exit(1);
212 break;
213 case A__FLT__FLT:
215 switch (a.p) {
216 case P_SINGLE:
218 for (size_t j = 0; j < n; ++j) {
219 float *x = (float *) (in_buf + (in_sz *
220 j));
221 float *y = (float *) (out_buf +
222 (out_sz * j));
224 *y = a.f32.f32__f32(*x);
227 break;
228 case P_DOUBLE:
230 for (size_t j = 0; j < n; ++j) {
231 double *x = (double *) (in_buf +
232 (in_sz * j));
233 double *y = (double *) (out_buf +
234 (out_sz * j));
236 *y = a.f64.f64__f64(*x);
239 break;
242 break;
243 case A__FLT_FLT_FLT__FLT:
245 switch (a.p) {
246 case P_SINGLE:
248 for (size_t j = 0; j < n; ++j) {
249 float *x1 = (float *) (in_buf + (in_sz *
250 j));
251 float *x2 = (float *) (in_buf + (in_sz *
253 4));
254 float *x3 = (float *) (in_buf + (in_sz *
256 8));
257 float *y = (float *) (out_buf +
258 (out_sz * j));
260 *y = a.f32.f32_f32_f32__f32(*x1, *x2,
261 *x3);
264 break;
265 case P_DOUBLE:
267 for (size_t j = 0; j < n; ++j) {
268 double *x1 = (double *) (in_buf +
269 (in_sz * j));
270 double *x2 = (double *) (in_buf +
271 (in_sz * j) +
273 double *x3 = (double *) (in_buf +
274 (in_sz * j) +
275 16);
276 double *y = (double *) (out_buf +
277 (out_sz * j));
279 *y = a.f64.f64_f64_f64__f64(*x1, *x2,
280 *x3);
283 break;
286 break;
289 write_buf(out_buf, out_sz * n);
293 int main(int argc, char **argv)
295 int c = 0;
296 action a = { .p = P_SINGLE };
297 long long n = 0;
299 while ((c = getopt(argc, argv, "sdf:n:")) != -1) {
300 switch (c) {
301 case 's':
302 a.p = P_SINGLE;
303 break;
304 case 'd':
305 a.p = P_DOUBLE;
306 break;
307 case 'f':
308 determine_function(optarg, &a);
309 break;
310 case 'n':
311 errno = 0;
312 n = strtoll(optarg, 0, 0);
314 if (errno) {
315 perror("impl-libc: unparsable");
317 return 1;
320 break;
321 default:
322 usage();
323 break;
327 if (a.a == A_UNKNOWN) {
328 usage();
331 io_loop(a, n);