support tan and cot
[fpmath-consensus.git] / impl-mpfr / impl-mpfr.c
blob42d9c852ece45517440b7e8edcc6afb679ec13c0
1 #include <errno.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
6 #include <unistd.h>
8 #include <mpfr.h>
/*
 * Width of the floating-point values exchanged on stdin/stdout.
 * Chosen with -s (single, the default) or -d (double).
 */
typedef enum {
	P_SINGLE, /* values are C floats, 4 bytes each */
	P_DOUBLE  /* values are C doubles, 8 bytes each */
} precision;
/*
 * Calling convention of the selected function: which arguments it takes
 * (floats, optionally followed by a rounding mode) and what it returns.
 * A_UNKNOWN must stay 0 so that a zero-initialized action means
 * "no function chosen yet".
 */
typedef enum {
	A_UNKNOWN = 0,               /* nothing selected */
	A__FLT__FLT = 1,             /* (flt) -> flt, no rounding mode */
	A__FLT_FLT_RND__FLT = 2,     /* (flt, flt, rnd) -> flt */
	A__FLT_FLT_FLT_RND__FLT = 3, /* (flt, flt, flt, rnd) -> flt */
	A__FLT_RND__FLT = 4          /* (flt, rnd) -> flt */
} argtype;
/*
 * Signatures of the MPFR routines we may dispatch to, one per argtype.
 * In every case the first parameter receives the result and the int
 * return value is whatever the MPFR routine reports.
 */

/* One operand, no rounding mode (e.g. ceil, floor, trunc). */
typedef int (*f__flt__flt)(mpfr_ptr, mpfr_srcptr);

/* One operand plus rounding mode (e.g. sin, exp, sqrt). */
typedef int (*f__flt_rnd__flt)(mpfr_ptr, mpfr_srcptr, mpfr_rnd_t);

/* Two operands plus rounding mode (e.g. pow). */
typedef int (*f__flt_flt_rnd__flt)(mpfr_ptr, mpfr_srcptr, mpfr_srcptr,
                                   mpfr_rnd_t);

/* Three operands plus rounding mode (e.g. fma). */
typedef int (*f__flt_flt_flt_rnd__flt)(mpfr_ptr, mpfr_srcptr, mpfr_srcptr,
                                       mpfr_srcptr, mpfr_rnd_t);
31 /* Wrapper around a function pointer */
32 typedef struct {
33 /* */
34 precision p;
35 argtype a;
37 union {
38 /* */
39 f__flt__flt flt__flt;
40 f__flt_flt_rnd__flt flt_flt_rnd__flt;
41 f__flt_flt_flt_rnd__flt flt_flt_flt_rnd__flt;
42 f__flt_rnd__flt flt_rnd__flt;
43 } f;
44 } action;
/* Print the command-line synopsis to stderr and terminate with status 1. */
void usage(void)
{
	static const char synopsis[] =
		"usage: impl-mpfr [-s|-d] -f <function_name> -n <num_inputs>\n";

	fputs(synopsis, stderr);
	_exit(1);
}
53 void determine_function(const char *f, action *a)
55 if (!strcmp(f, "zzzzzzzzz")) {
56 a->a = A_UNKNOWN;
57 } else if (!strcmp(f, "id")) {
58 a->a = A__FLT_RND__FLT;
59 a->f.flt_rnd__flt = mpfr_set;
60 } else if (!strcmp(f, "ceil")) {
61 a->a = A__FLT__FLT;
62 a->f.flt__flt = mpfr_ceil;
63 } else if (!strcmp(f, "cos")) {
64 a->a = A__FLT_RND__FLT;
65 a->f.flt_rnd__flt = mpfr_cos;
66 } else if (!strcmp(f, "cot")) {
67 a->a = A__FLT_RND__FLT;
68 a->f.flt_rnd__flt = mpfr_cot;
69 } else if (!strcmp(f, "floor")) {
70 a->a = A__FLT__FLT;
71 a->f.flt__flt = mpfr_floor;
72 } else if (!strcmp(f, "exp")) {
73 a->a = A__FLT_RND__FLT;
74 a->f.flt_rnd__flt = mpfr_exp;
75 } else if (!strcmp(f, "expm1")) {
76 a->a = A__FLT_RND__FLT;
77 a->f.flt_rnd__flt = mpfr_expm1;
78 } else if (!strcmp(f, "fma")) {
79 a->a = A__FLT_FLT_FLT_RND__FLT;
80 a->f.flt_flt_flt_rnd__flt = mpfr_fma;
81 } else if (!strcmp(f, "log")) {
82 a->a = A__FLT_RND__FLT;
83 a->f.flt_rnd__flt = mpfr_log;
84 } else if (!strcmp(f, "log1p")) {
85 a->a = A__FLT_RND__FLT;
86 a->f.flt_rnd__flt = mpfr_log1p;
87 } else if (!strcmp(f, "powr")) {
88 a->a = A__FLT_FLT_RND__FLT;
89 a->f.flt_flt_rnd__flt = mpfr_pow;
90 } else if (!strcmp(f, "sin")) {
91 a->a = A__FLT_RND__FLT;
92 a->f.flt_rnd__flt = mpfr_sin;
93 } else if (!strcmp(f, "sqrt")) {
94 a->a = A__FLT_RND__FLT;
95 a->f.flt_rnd__flt = mpfr_sqrt;
96 } else if (!strcmp(f, "tan")) {
97 a->a = A__FLT_RND__FLT;
98 a->f.flt_rnd__flt = mpfr_tan;
99 } else if (!strcmp(f, "trunc")) {
100 a->a = A__FLT__FLT;
101 a->f.flt__flt = mpfr_trunc;
102 } else {
103 fprintf(stderr, "impl-mpfr: unknown function \"%s\"\n", f);
104 _exit(1);
/*
 * Fill b with exactly len bytes from standard input (fd 0).
 *
 * Short reads are continued until the buffer is full.  End of input --
 * even in the middle of a buffer -- ends the process with status 0; that
 * is how the program terminates normally.  A read interrupted by a
 * signal (EINTR) is retried; any other read error is reported with
 * perror() and ends the process with status 1.
 */
void read_buf(char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = read(0, b + total, (size_t) (len - total));

		if (r > 0) {
			total += r;
		} else if (r == 0) {
			/* EOF */
			_exit(0);
		} else if (errno != EINTR) {
			perror("impl-mpfr: read");
			_exit(1);
		}
		/* r < 0 with EINTR: nothing was transferred, try again */
	}
}
/*
 * Write exactly len bytes from b to standard output (fd 1).
 *
 * Short writes are continued until everything has been sent.  A write
 * interrupted by a signal (EINTR) is retried; any other write error is
 * reported with perror() and ends the process with status 1.
 */
void write_buf(const char *b, ssize_t len)
{
	ssize_t total = 0;

	while (total < len) {
		ssize_t r = write(1, b + total, (size_t) (len - total));

		if (r >= 0) {
			total += r;
		} else if (errno != EINTR) {
			perror("impl-mpfr: write");
			_exit(1);
		}
		/* r < 0 with EINTR: nothing was transferred, try again */
	}
}
145 size_t input_width(argtype a, precision p)
147 size_t w = (p == P_SINGLE) ? 4 : 8;
149 switch (a) {
150 case A_UNKNOWN:
151 break;
152 case A__FLT__FLT:
154 return 1 * w;
155 case A__FLT_FLT_RND__FLT:
157 return 2 * w;
158 case A__FLT_FLT_FLT_RND__FLT:
160 return 3 * w;
161 case A__FLT_RND__FLT:
163 return 1 * w;
166 return (size_t) -1;
169 size_t output_width(argtype a, precision p)
171 size_t w = (p == P_SINGLE) ? 4 : 8;
173 switch (a) {
174 case A_UNKNOWN:
175 break;
176 case A__FLT__FLT:
178 return 1 * w;
179 case A__FLT_FLT_RND__FLT:
181 return 1 * w;
182 case A__FLT_FLT_FLT_RND__FLT:
184 return 1 * w;
185 case A__FLT_RND__FLT:
187 return 1 * w;
190 return (size_t) -1;
193 void io_loop(action a, size_t n)
195 char *in_buf = 0;
196 char *out_buf = 0;
197 size_t in_sz = input_width(a.a, a.p);
198 size_t out_sz = output_width(a.a, a.p);
199 mpfr_t x1;
200 mpfr_t x2;
201 mpfr_t x3;
202 mpfr_t y;
204 if ((in_sz * n) / n != in_sz) {
205 fprintf(stderr, "impl-libc: input length overflow\n");
206 _exit(1);
209 if ((out_sz * n) / n != out_sz) {
210 fprintf(stderr, "impl-libc: output length overflow\n");
211 _exit(1);
214 if (!(in_buf = malloc(in_sz * n))) {
215 perror("impl-libc: malloc");
216 _exit(1);
219 if (!(out_buf = malloc(out_sz * n))) {
220 perror("impl-libc: malloc");
221 _exit(1);
224 /* I'm pretty sure 53 precision would be enough */
225 mpfr_init2(x1, 75);
226 mpfr_init2(x2, 75);
227 mpfr_init2(x3, 75);
228 mpfr_init2(y, 75);
230 while (1) {
231 read_buf(in_buf, in_sz * n);
233 switch (a.a) {
234 case A_UNKNOWN:
235 fprintf(stderr, "impl-libc: impossible\n");
236 _exit(1);
237 break;
238 case A__FLT__FLT:
240 switch (a.p) {
241 case P_SINGLE:
243 for (size_t j = 0; j < n; ++j) {
244 float *xf1 = (float *) (in_buf +
245 (in_sz * j));
246 float *yf = (float *) (out_buf +
247 (out_sz * j));
249 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
250 a.f.flt__flt(y, x1);
251 *yf = mpfr_get_flt(y, MPFR_RNDN);
254 break;
255 case P_DOUBLE:
257 for (size_t j = 0; j < n; ++j) {
258 double *xf1 = (double *) (in_buf +
259 (in_sz * j));
260 double *yf = (double *) (out_buf +
261 (out_sz * j));
263 mpfr_set_d(x1, *xf1, MPFR_RNDN);
264 a.f.flt__flt(y, x1);
265 *yf = mpfr_get_d(y, MPFR_RNDN);
268 break;
271 break;
272 case A__FLT_FLT_RND__FLT:
274 switch (a.p) {
275 case P_SINGLE:
277 for (size_t j = 0; j < n; ++j) {
278 float *xf1 = (float *) (in_buf +
279 (in_sz * j));
280 float *xf2 = (float *) (in_buf +
281 (in_sz * j) +
283 float *yf = (float *) (out_buf +
284 (out_sz * j));
286 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
287 mpfr_set_flt(x2, *xf2, MPFR_RNDN);
288 a.f.flt_flt_rnd__flt(y, x1, x2,
289 MPFR_RNDN);
290 *yf = mpfr_get_flt(y, MPFR_RNDN);
293 break;
294 case P_DOUBLE:
296 for (size_t j = 0; j < n; ++j) {
297 double *xf1 = (double *) (in_buf +
298 (in_sz * j));
299 double *xf2 = (double *) (in_buf +
300 (in_sz * j) +
302 double *yf = (double *) (out_buf +
303 (out_sz * j));
305 mpfr_set_d(x1, *xf1, MPFR_RNDN);
306 mpfr_set_d(x2, *xf2, MPFR_RNDN);
307 a.f.flt_flt_rnd__flt(y, x1, x2,
308 MPFR_RNDN);
309 *yf = mpfr_get_d(y, MPFR_RNDN);
312 break;
315 break;
316 case A__FLT_FLT_FLT_RND__FLT:
318 switch (a.p) {
319 case P_SINGLE:
321 for (size_t j = 0; j < n; ++j) {
322 float *xf1 = (float *) (in_buf +
323 (in_sz * j));
324 float *xf2 = (float *) (in_buf +
325 (in_sz * j) +
327 float *xf3 = (float *) (in_buf +
328 (in_sz * j) +
330 float *yf = (float *) (out_buf +
331 (out_sz * j));
333 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
334 mpfr_set_flt(x2, *xf2, MPFR_RNDN);
335 mpfr_set_flt(x3, *xf3, MPFR_RNDN);
336 a.f.flt_flt_flt_rnd__flt(y, x1, x2, x3,
337 MPFR_RNDN);
338 *yf = mpfr_get_flt(y, MPFR_RNDN);
341 break;
342 case P_DOUBLE:
344 for (size_t j = 0; j < n; ++j) {
345 double *xf1 = (double *) (in_buf +
346 (in_sz * j));
347 double *xf2 = (double *) (in_buf +
348 (in_sz * j) +
350 double *xf3 = (double *) (in_buf +
351 (in_sz * j) +
352 16);
353 double *yf = (double *) (out_buf +
354 (out_sz * j));
356 mpfr_set_d(x1, *xf1, MPFR_RNDN);
357 mpfr_set_d(x2, *xf2, MPFR_RNDN);
358 mpfr_set_d(x3, *xf3, MPFR_RNDN);
359 a.f.flt_flt_flt_rnd__flt(y, x1, x2, x3,
360 MPFR_RNDN);
361 *yf = mpfr_get_d(y, MPFR_RNDN);
364 break;
367 break;
368 case A__FLT_RND__FLT:
370 switch (a.p) {
371 case P_SINGLE:
373 for (size_t j = 0; j < n; ++j) {
374 float *xf1 = (float *) (in_buf +
375 (in_sz * j));
376 float *yf = (float *) (out_buf +
377 (out_sz * j));
379 mpfr_set_flt(x1, *xf1, MPFR_RNDN);
380 a.f.flt_rnd__flt(y, x1, MPFR_RNDN);
381 *yf = mpfr_get_flt(y, MPFR_RNDN);
384 break;
385 case P_DOUBLE:
387 for (size_t j = 0; j < n; ++j) {
388 double *xf1 = (double *) (in_buf +
389 (in_sz * j));
390 double *yf = (double *) (out_buf +
391 (out_sz * j));
393 mpfr_set_d(x1, *xf1, MPFR_RNDN);
394 a.f.flt_rnd__flt(y, x1, MPFR_RNDN);
395 *yf = mpfr_get_d(y, MPFR_RNDN);
398 break;
401 break;
404 write_buf(out_buf, out_sz * n);
408 int main(int argc, char **argv)
410 int c = 0;
411 action a = { .p = P_SINGLE };
412 long long n = 0;
414 while ((c = getopt(argc, argv, "sdf:n:")) != -1) {
415 switch (c) {
416 case 's':
417 a.p = P_SINGLE;
418 break;
419 case 'd':
420 a.p = P_DOUBLE;
421 break;
422 case 'f':
423 determine_function(optarg, &a);
424 break;
425 case 'n':
426 errno = 0;
427 n = strtoll(optarg, 0, 0);
429 if (errno) {
430 perror("impl-libc: unparsable");
432 return 1;
435 break;
436 default:
437 usage();
438 break;
442 if (a.a == A_UNKNOWN) {
443 usage();
446 io_loop(a, n);