support tan and cot
[fpmath-consensus.git] / impl-myrddin / impl-myrddin.myr
blob38d1c6aef65132c6dfd364e0f3a467076aac84da
1 use std
3 use math
5 type Fn_flt__flt = struct
6         f32 : (x : flt32 -> flt32)
7         f64 : (x : flt64 -> flt64)
8 ;;
10 type Fn_flt_flt__flt = struct
11         f32 : (x : flt32, y : flt32 -> flt32)
12         f64 : (x : flt64, y : flt64 -> flt64)
15 type Fn_flt_flt_flt__flt = struct
16         f32 : (x : flt32, y : flt32, z : flt32 -> flt32)
17         f64 : (x : flt64, y : flt64, z : flt64 -> flt64)
20 type fn_desc = struct
21         name : byte[:]
22         f : union
23                 `Flt__flt Fn_flt__flt
24                 `Flt_flt__flt Fn_flt_flt__flt
25                 `Flt_flt_flt__flt Fn_flt_flt_flt__flt
26         ;;
29 type flt_prec = union
30         `Single
31         `Double
34 var available_fns : fn_desc[:] = [][:]
36 generic id : (a : @a -> @a) = {x; -> x}
38 const main = {args : byte[:][:]
39         available_fns = [
40                 [.name = "id",    .f = `Flt__flt         [ .f32 = id,         .f64 = id]],
41                 [.name = "ceil",  .f = `Flt__flt         [ .f32 = math.ceil,  .f64 = math.ceil]],
42                 [.name = "cos",   .f = `Flt__flt         [ .f32 = math.cos,   .f64 = math.cos]],
43                 [.name = "cot",   .f = `Flt__flt         [ .f32 = math.cot,   .f64 = math.cot]],
44                 [.name = "exp",   .f = `Flt__flt         [ .f32 = math.exp,   .f64 = math.exp]],
45                 [.name = "expm1", .f = `Flt__flt         [ .f32 = math.expm1, .f64 = math.expm1]],
46                 [.name = "floor", .f = `Flt__flt         [ .f32 = math.floor, .f64 = math.floor]],
47                 [.name = "fma",   .f = `Flt_flt_flt__flt [ .f32 = math.fma,   .f64 = math.fma]],
48                 [.name = "log",   .f = `Flt__flt         [ .f32 = math.log,   .f64 = math.log]],
49                 [.name = "log1p", .f = `Flt__flt         [ .f32 = math.log1p, .f64 = math.log1p]],
50                 [.name = "powr",  .f = `Flt_flt__flt     [ .f32 = math.powr,  .f64 = math.powr]],
51                 [.name = "sqrt",  .f = `Flt__flt         [ .f32 = math.sqrt,  .f64 = math.sqrt]],
52                 [.name = "sin",   .f = `Flt__flt         [ .f32 = math.sin,   .f64 = math.sin]],
53                 [.name = "tan",   .f = `Flt__flt         [ .f32 = math.tan,   .f64 = math.tan]],
54                 [.name = "trunc", .f = `Flt__flt         [ .f32 = math.trunc, .f64 = math.trunc]],
55         ][:]
57         var p : flt_prec = `Single
58         var f : fn_desc = available_fns[0]
59         var n : std.size = 0
61         (p, f, n) = read_args(args)
63         io_loop(p, f, n)
66 const read_args = {args : byte[:][:]
67         var p : flt_prec = `Single
68         var n : std.size = 0
69         var fname : byte[:] = ""
70         var fn : fn_desc = available_fns[0]
71         var cmd = std.optparse(args, &[
72                 .argdesc = "",
73                 .opts = [
74                         [.opt = 's', .desc = "use single precision (default)"],
75                         [.opt = 'd', .desc = "use double precision"],
76                         [.opt = 'n', .arg = "N", .desc = "read/write ‘N’ entries at a time"],
77                         [.opt = 'f', .arg = "func", .desc = "use function ‘f’"],
78                 ][:]
79         ])
81         for opt : cmd.opts
82                 match opt
83                 | ('s', _): p = `Single
84                 | ('d', _): p = `Double
85                 | ('n', ns):
86                         match std.intparse(ns)
87                         | `std.Some np: n = np
88                         | `std.None:
89                                 std.put("impl-myrddin: unparsable number “{}”\n", ns)
90                                 std.exit(1)
91                         ;;
92                 | ('f', fs): fname = fs
93                 | _ : std.die("impl-myrddin: impossible\n")
94                 ;;
95         ;;
97         var good_fn : bool = false
98         for f : available_fns
99                 if std.eq(f.name, fname)
100                         fn = f
101                         good_fn = true
102                         break
103                 ;;
104         ;;
106         if !good_fn
107                 std.put("impl-myrddin: unknown function “{}”\n", fname)
108                 std.exit(1)
109         ;;
111         if n <= 0
112                 std.put("impl-myrddin: positive number of entries required\n")
113                 std.exit(1)
114         ;;
116         -> (p, fn, n)
120 const io_loop = {p : flt_prec, fn : fn_desc, n : std.size
121         var input_sz : std.size = 0
122         var output_sz : std.size = 0
123         var in_buf : byte[:] = [][:]
124         var out_buf : byte[:] = [][:]
125         var w = prec_width(p)
127         (input_sz, output_sz) = io_widths(p, fn)
129         if (((input_sz * n) / input_sz) != n) || (((output_sz * n) / output_sz) != n)
130                 std.put("impl-myrddin: overflow in i/o buffer size\n")
131                 std.exit(1)
132         ;;
134         in_buf = std.slalloc(input_sz * n)
135         out_buf = std.slalloc(output_sz * n)
137         while true
138                 match std.readall(0, in_buf)
139                 | `std.Ok _:
140                 | `std.Err e:
141                         std.put("impl-myrddin: std.readall(): {}\n", e)
142                         std.exit(1)
143                 ;;
145                 for var j = 0; j < n; ++j
146                         var ib : byte[:] = in_buf[j * input_sz:(j + 1) * input_sz]
147                         var ob : byte[:] = out_buf[j * output_sz:(j + 1) * output_sz]
148                         match (p, fn.f)
149                         | (`Single, `Flt__flt f):
150                                 var x : flt32 = std.flt32frombits(std.getle32(ib))
151                                 std.putle32(ob, std.flt32bits(f.f32(x)))
152                         | (`Double, `Flt__flt f):
153                                 var x : flt64 = std.flt64frombits(std.getle64(ib))
154                                 std.putle64(ob, std.flt64bits(f.f64(x)))
155                         | (`Single, `Flt_flt__flt f):
156                                 var x1 : flt32 = std.flt32frombits(std.getle32(ib[0: 4]))
157                                 var x2 : flt32 = std.flt32frombits(std.getle32(ib[4: 8]))
158                                 std.putle32(ob, std.flt32bits(f.f32(x1, x2)))
159                         | (`Double, `Flt_flt__flt f):
160                                 var x1 : flt64 = std.flt64frombits(std.getle64(ib[ 0: 8]))
161                                 var x2 : flt64 = std.flt64frombits(std.getle64(ib[ 8:16]))
162                                 std.putle64(ob, std.flt64bits(f.f64(x1, x2)))
163                         | (`Single, `Flt_flt_flt__flt f):
164                                 var x1 : flt32 = std.flt32frombits(std.getle32(ib[0: 4]))
165                                 var x2 : flt32 = std.flt32frombits(std.getle32(ib[4: 8]))
166                                 var x3 : flt32 = std.flt32frombits(std.getle32(ib[8:12]))
167                                 std.putle32(ob, std.flt32bits(f.f32(x1, x2, x3)))
168                         | (`Double, `Flt_flt_flt__flt f):
169                                 var x1 : flt64 = std.flt64frombits(std.getle64(ib[ 0: 8]))
170                                 var x2 : flt64 = std.flt64frombits(std.getle64(ib[ 8:16]))
171                                 var x3 : flt64 = std.flt64frombits(std.getle64(ib[16:24]))
172                                 std.putle64(ob, std.flt64bits(f.f64(x1, x2, x3)))
173                         ;;
174                 ;;
176                 match std.writeall(1, out_buf)
177                 | `std.Ok _:
178                 | `std.Err (_, e):
179                         std.put("impl-myrddin: std.writeall(): {}\n", e)
180                         std.exit(1)
181                 ;;
182         ;;
185 const prec_width = {p : flt_prec
186         match p
187         | `Single: -> 4
188         | `Double: -> 8
189         ;;
192 const io_widths = {p : flt_prec, fn : fn_desc
193         var w : std.size = prec_width(p)
195         match fn.f
196         | `Flt__flt _ : -> (w, w)
197         | `Flt_flt__flt _ : -> (2*w, w)
198         | `Flt_flt_flt__flt _ : -> (3*w, w)
199         ;;