/*
 * support fma
 * [fpmath-consensus.git] / impl-libc / impl-libc.c
 * blob 2bde25d24231c25c499fad86997279d53f08d9cf
 */
#include <errno.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <unistd.h>
/* Whether we're looking at 32-bit (float) or 64-bit (double) values. */
typedef enum { P_SINGLE, P_DOUBLE } precision;
/* What type of arguments we expect, and what we'll give back. */
typedef enum {
	/* No function selected; value 0, so a zero-initialised action is unknown. */
	A_UNKNOWN,
	/* One floating-point input, one floating-point output. */
	A__FLT__FLT,
	/* Three floating-point inputs, one floating-point output. */
	A__FLT_FLT_FLT__FLT,
} argtype;
/*
 * Types of functions we could call.  Names read f__<argument types>__<result
 * type>; f32 is float and f64 is double.
 */
typedef float (*f__f32__f32)(float);
typedef float (*f__f32_f32_f32__f32)(float, float, float);
typedef double (*f__f64__f64)(double);
typedef double (*f__f64_f64_f64__f64)(double, double, double);
26 /* Wrapper around a function pointer */
27 typedef struct {
28 /* */
29 precision p;
30 argtype a;
32 union {
33 /* */
34 f__f32__f32 f32__f32;
35 f__f32_f32_f32__f32 f32_f32_f32__f32;
36 } f32;
38 union {
39 /* */
40 f__f64__f64 f64__f64;
41 f__f64_f64_f64__f64 f64_f64_f64__f64;
42 } f64;
44 } action;
/* Print the command-line synopsis to stderr and terminate with status 1. */
void usage(void)
{
	fprintf(stderr,
	        "usage: impl-libc [-s|-d] -f <function_name> -n <num_inputs>\n");
	_exit(1);
}
/* Single-precision identity: returns its argument unchanged. */
float idf(float f)
{
	return f;
}
/* Double-precision identity: returns its argument unchanged. */
double idd(double d)
{
	return d;
}
63 void determine_function(const char *f, action *a)
65 if (!strcmp(f, "zzzzzz")) {
66 a->a = A_UNKNOWN;
67 } else if (!strcmp(f, "id")) {
68 a->a = A__FLT__FLT;
69 a->f32.f32__f32 = idf;
70 a->f64.f64__f64 = idd;
71 } else if (!strcmp(f, "ceil")) {
72 a->a = A__FLT__FLT;
73 a->f32.f32__f32 = ceilf;
74 a->f64.f64__f64 = ceil;
75 } else if (!strcmp(f, "cos")) {
76 a->a = A__FLT__FLT;
77 a->f32.f32__f32 = cosf;
78 a->f64.f64__f64 = cos;
79 } else if (!strcmp(f, "floor")) {
80 a->a = A__FLT__FLT;
81 a->f32.f32__f32 = floorf;
82 a->f64.f64__f64 = floor;
83 } else if (!strcmp(f, "sin")) {
84 a->a = A__FLT__FLT;
85 a->f32.f32__f32 = sinf;
86 a->f64.f64__f64 = sin;
87 } else if (!strcmp(f, "trunc")) {
88 a->a = A__FLT__FLT;
89 a->f32.f32__f32 = truncf;
90 a->f64.f64__f64 = trunc;
91 } else if (!strcmp(f, "fma")) {
92 a->a = A__FLT_FLT_FLT__FLT;
93 a->f32.f32_f32_f32__f32 = fmaf;
94 a->f64.f64_f64_f64__f64 = fma;
95 } else {
96 fprintf(stderr, "impl-libc: unknown function \"%s\"\n", f);
97 _exit(1);
/*
 * Fill `b` with exactly `len` bytes from standard input (fd 0).
 *
 * Short reads are continued until the buffer is full, and a read interrupted
 * by a signal (EINTR) is retried.  End of file terminates the process with
 * status 0, even if part of the buffer was already filled; any other read
 * error is reported with perror() and terminates with status 1.
 */
void read_buf(char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = read(0, b + total, (size_t) (len - total));

		if (r == 0) {
			/* EOF */
			_exit(0);
		}

		if (r == -1) {
			if (errno == EINTR) {
				/* Nothing was transferred; just try again. */
				continue;
			}

			perror("impl-libc: read");
			_exit(1);
		}

		total += r;
	}
}
/*
 * Write exactly `len` bytes from `b` to standard output (fd 1).
 *
 * Short writes are continued until everything has been written, and a write
 * interrupted by a signal (EINTR) is retried.  Any other write error is
 * reported with perror() and terminates the process with status 1.
 */
void write_buf(const char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = write(1, b + total, (size_t) (len - total));

		if (r == -1) {
			if (errno == EINTR) {
				/* Nothing was transferred; just try again. */
				continue;
			}

			perror("impl-libc: write");
			_exit(1);
		}

		total += r;
	}
}
138 size_t input_width(argtype a, precision p)
140 size_t w = (p == P_SINGLE) ? 4 : 8;
142 switch (a) {
143 case A_UNKNOWN:
144 break;
145 case A__FLT__FLT:
147 return 1 * w;
148 case A__FLT_FLT_FLT__FLT:
150 return 3 * w;
153 return (size_t) -1;
156 size_t output_width(argtype a, precision p)
158 size_t w = (p == P_SINGLE) ? 4 : 8;
160 switch (a) {
161 case A_UNKNOWN:
162 break;
163 case A__FLT__FLT:
165 return 1 * w;
166 case A__FLT_FLT_FLT__FLT:
168 return 1 * w;
171 return (size_t) -1;
174 void io_loop(action a, size_t n)
176 char *in_buf = 0;
177 char *out_buf = 0;
178 size_t in_sz = input_width(a.a, a.p);
179 size_t out_sz = output_width(a.a, a.p);
181 if ((in_sz * n) / n != in_sz) {
182 fprintf(stderr, "impl-libc: input length overflow\n");
183 _exit(1);
186 if ((out_sz * n) / n != out_sz) {
187 fprintf(stderr, "impl-libc: output length overflow\n");
188 _exit(1);
191 if (!(in_buf = malloc(in_sz * n))) {
192 perror("impl-libc: malloc");
193 _exit(1);
196 if (!(out_buf = malloc(out_sz * n))) {
197 perror("impl-libc: malloc");
198 _exit(1);
201 while (1) {
202 read_buf(in_buf, in_sz * n);
204 switch (a.a) {
205 case A_UNKNOWN:
206 fprintf(stderr, "impl-libc: impossible\n");
207 _exit(1);
208 break;
209 case A__FLT__FLT:
211 switch (a.p) {
212 case P_SINGLE:
214 for (size_t j = 0; j < n; ++j) {
215 float *x = (float *) (in_buf + (in_sz *
216 j));
217 float *y = (float *) (out_buf +
218 (out_sz * j));
220 *y = a.f32.f32__f32(*x);
223 break;
224 case P_DOUBLE:
226 for (size_t j = 0; j < n; ++j) {
227 double *x = (double *) (in_buf +
228 (in_sz * j));
229 double *y = (double *) (out_buf +
230 (out_sz * j));
232 *y = a.f64.f64__f64(*x);
235 break;
238 break;
239 case A__FLT_FLT_FLT__FLT:
241 switch (a.p) {
242 case P_SINGLE:
244 for (size_t j = 0; j < n; ++j) {
245 float *x1 = (float *) (in_buf + (in_sz *
246 j));
247 float *x2 = (float *) (in_buf + (in_sz *
249 4));
250 float *x3 = (float *) (in_buf + (in_sz *
252 8));
253 float *y = (float *) (out_buf +
254 (out_sz * j));
256 *y = a.f32.f32_f32_f32__f32(*x1, *x2,
257 *x3);
260 break;
261 case P_DOUBLE:
263 for (size_t j = 0; j < n; ++j) {
264 double *x1 = (double *) (in_buf +
265 (in_sz * j));
266 double *x2 = (double *) (in_buf +
267 (in_sz * j) +
269 double *x3 = (double *) (in_buf +
270 (in_sz * j) +
271 16);
272 double *y = (double *) (out_buf +
273 (out_sz * j));
275 *y = a.f64.f64_f64_f64__f64(*x1, *x2,
276 *x3);
279 break;
282 break;
285 write_buf(out_buf, out_sz * n);
289 int main(int argc, char **argv)
291 int c = 0;
292 action a = { .p = P_SINGLE };
293 long long n = 0;
295 while ((c = getopt(argc, argv, "sdf:n:")) != -1) {
296 switch (c) {
297 case 's':
298 a.p = P_SINGLE;
299 break;
300 case 'd':
301 a.p = P_DOUBLE;
302 break;
303 case 'f':
304 determine_function(optarg, &a);
305 break;
306 case 'n':
307 errno = 0;
308 n = strtoll(optarg, 0, 0);
310 if (errno) {
311 perror("impl-libc: unparsable");
313 return 1;
316 break;
317 default:
318 usage();
319 break;
323 if (a.a == A_UNKNOWN) {
324 usage();
327 io_loop(a, n);