support tan and cot
[fpmath-consensus.git] / checker / checker.myr
blobd7c600b54c229c99ff9421b39752cf645ecdddc0
1 use std
3 use bio
4 use iter
5 use fileutil
6 use sys
8 type impl_prog = struct
9         name : byte[:]
10         pid : std.pid
11         stdin : std.fd
12         stdout : std.fd
13         alive : bool
15         output_bits : byte[:]
18 var rng : std.rng#
20 /* Flt is ``whatever precision we're testing''. */
21 type fp_type = union
22         `Flt
25 type fn_desc = struct
26         name : byte[:]
27         inputs : fp_type[:]
28         outputs : fp_type[:]
31 type flt_prec = union
32         `Single
33         `Double
36 type exactness = union
37         `Exact
38         `Inexact uint
41 /* (name, number-of-flt-args, constant-extra-bytes) */
42 var available_fns : fn_desc[:] = [][:]
44 const nop = {;}
46 const main = {args : byte[:][:]
47         available_fns = [
48                 [.name = "id",    .inputs = [`Flt][:], .outputs = [`Flt][:]],
49                 [.name = "ceil",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
50                 [.name = "cos",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
51                 [.name = "cot",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
52                 [.name = "exp",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
53                 [.name = "expm1", .inputs = [`Flt][:], .outputs = [`Flt][:]],
54                 [.name = "floor", .inputs = [`Flt][:], .outputs = [`Flt][:]],
55                 [.name = "fma",   .inputs = [`Flt, `Flt, `Flt][:], .outputs = [`Flt][:]],
56                 [.name = "log",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
57                 [.name = "log1p", .inputs = [`Flt][:], .outputs = [`Flt][:]],
58                 [.name = "powr",  .inputs = [`Flt, `Flt][:], .outputs = [`Flt][:]],
59                 [.name = "sin",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
60                 [.name = "sincos",.inputs = [`Flt][:], .outputs = [`Flt, `Flt][:]],
61                 [.name = "sqrt",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
62                 [.name = "tan",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
63                 [.name = "trunc", .inputs = [`Flt][:], .outputs = [`Flt][:]],
64         ][:]
66         var old
67         var sa = [
68                 .handler = (nop : byte#),
69                 .flags = sys.Saresethand,
70         ]
71         sys.sigaction(sys.Sigpipe, &sa, &old)
73         var fn : fn_desc = available_fns[0]
74         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
76         var precision : flt_prec = `Single
77         var exactness : exactness = `Exact
78         var impls : impl_prog#[:]
79         rng = std.mksrng((std.now() : uint32))
81         (precision, exactness, fn, next_bits_fn) = read_args(args)
83         var input_sz = args_width(fn.inputs, precision)
84         var num_inputs = (1 << 18)
85         var buf_sz = num_inputs * input_sz
87         impls = start_impls([prec_arg(precision), "-f", fn.name, "-n", std.fmt("{}", num_inputs)][:])
89         io_loop(impls, precision, exactness, num_inputs, fn, next_bits_fn)
90         std.put("\n")
93 const prec_width = {p
94         match p
95         | `Single: -> 4
96         | `Double: -> 8
97         ;;
100 const prec_arg = {p
101         match p
102         | `Single: -> "-s"
103         | `Double: -> "-d"
104         ;;
107 const args_width = {ts : fp_type[:], p : flt_prec
108         var w : std.size = 0
110         for t : ts
111                 match t
112                 | `Flt: w += prec_width(p)
113                 ;;
114         ;;
116         -> w
119 const read_args = {args : byte[:][:]
120         var exactness : exactness = `Exact
121         var precision : flt_prec = `Single
122         var fn_name : byte[:] = "UNSPECIFIED"
123         var fn : fn_desc = available_fns[0]
124         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
126         var cmd = std.optparse(args, &[
127                 .argdesc = "",
128                 .opts = [
129                         [.opt = 's', .desc = "use single precision (default)"],
130                         [.opt = 'd', .desc = "use double precision"],
131                         [.opt = 'l', .desc = "list available functions"],
132                         [.opt = 'f', .arg = "f", .desc = "test function ‘f’"],
133                         [.opt = 'r', .arg = "s", .desc = "choose inputs randomly, with seed ‘s’"],
134                         [.opt = 'e', .desc = "exhaust input space"],
135                         [.opt = 'i', .arg = "i", .desc = "allow inexact matches (by ‘i’, bitwise)"],
136                 ][:]
137         ])
139         for opt : cmd.opts
140                 match opt
141                 | ('s', _):
142                         precision = `Single
143                 | ('d', _):
144                         precision = `Double
145                 | ('i', arg):
146                         match std.intparse(arg)
147                         | `std.Some n:
148                                 if n >= 0
149                                         exactness = `Inexact (n : uint)
150                                 else
151                                         std.put("unacceptable inexactness “{}”\n", arg)
152                                 ;;
153                         | `std.None:
154                                 std.put("cannot parse inexactness “{}”\n", arg)
155                                 std.exit(1)
156                         ;;
157                 | ('l', _):
158                         list_functions()
159                         std.exit(0)
160                 | ('f', arg): fn_name = arg
161                 | ('r', arg):
162                         next_bits_fn = next_rand
163                         match std.intparse(arg)
164                         | `std.Some n:
165                                 rng = std.mksrng(n)
166                         | `std.None:
167                                 std.put("cannot parse seed “{}”\n", arg)
168                                 std.exit(1)
169                         ;;
170                 | ('e', _): next_bits_fn = next_exhaust
171                 | _: std.die("impossible\n")
172                 ;;
173         ;;
175         var good_fn : bool = false
176         for f : available_fns
177                 if std.eq(f.name, fn_name)
178                         fn = f
179                         good_fn = true
180                         break
181                 ;;
182         ;;
184         if !good_fn
185                 std.put("unknown function “{}”\n", fn_name)
186                 std.exit(1)
187         ;;
189         -> (precision, exactness, fn, next_bits_fn)
192 const io_loop = {impls : impl_prog#[:], p : flt_prec, x : exactness, num_inputs : std.size, fn : fn_desc, next_bits_fn : (b : byte[:], n : std.size -> bool)
193         var input_sz = args_width(fn.inputs, p)
194         var output_sz = args_width(fn.outputs, p)
195         var draw_line : bool = false
196         var n = 0
197         var bits : byte[:] = std.slalloc(num_inputs * input_sz)
198         var last_bits : byte[:] = std.slalloc(num_inputs * input_sz)
199         std.slfill(bits, 0)
200         std.slfill(last_bits, 0)
201         for i : impls
202                 i.output_bits = std.slalloc(num_inputs * output_sz)
203                 std.slfill(i.output_bits, 0)
204         ;;
206         /* Now, loop perhaps infinitely with the comparisons */
207 :again
208         draw_line = false
210         /* Send question */
211         for i : impls
212                 match std.writeall(i.stdin, bits)
213                 | `std.Ok _:
214                 | `std.Err (_, e):
215                         std.put("CRASH: {w=20} [{w=6}] failed to receive data\n", i.name, i.pid)
216                         i.alive = false
217                         draw_line = true
218                 ;;
219         ;;
221         /* Reap zombies */
222         for var j = 0; j < impls.len; ++j
223                 if impls[j].alive
224                         continue
225                 ;;
226                 impls[j] = impls[impls.len - 1]
227                 std.slgrow(&impls, impls.len - 1)
228         ;;
230         /* Gather consensus on last time's answers */
231         consensus(last_bits, num_inputs, fn, p, x, impls)
232         if n % 100 == 0
233                 if n % 8000 == 0
234                         std.put("\x1b[1G\x1b[0K")
235                 ;;
236                 std.put(".")
237         ;;
238         n++
240         std.slcp(last_bits, bits)
242         /* Receive new answers */
243         for i : impls
244                 if !i.alive
245                         continue
246                 ;;
248                 match std.readall(i.stdout, i.output_bits)
249                 | `std.Ok _:
250                 | `std.Err e:
251                         std.put("CRASH: {w=20} [{w=6}] failed to send data\n", i.name, i.pid)
252                         i.alive = false
253                         draw_line = true
254                 |_:
255                 ;;
256         ;;
258         if impls.len < 2
259                 std.put("Less than 2 implementations left. Consensus impossible.\n")
260                 std.put("----------\n")
261                 std.exit(1)
262         ;;
264         if draw_line
265                 std.put("----------\n")
266         ;;
269         /* Onward */
270         if next_bits_fn(bits, num_inputs)
271                 goto again
272         ;;
275 const next_rand = { b : byte[:], n : std.size
276         std.rngrandbytes(rng, b)
277         -> true
280 const next_exhaust = { b : byte[:], n : std.size
281         var one_arg : byte[:] = std.slalloc(b.len / n)
282         std.slcp(one_arg, b[b.len - one_arg.len:])
283         var finished : bool = false
285         /* n is the number of total argument groups */
286         for var i = 0; i < n; ++i
287                 /* Increment this particular argument */
288                 var j = one_arg.len - 1
289                 while j >= 0
290                         one_arg[j]++
291                         if (one_arg[j] != 0)
292                                 break
293                         ;;
294                         j--
295                 ;;
297                 finished = finished || j < 0
298                 var z = one_arg.len * i
299                 std.slcp(b[z:z + one_arg.len], one_arg)
300         ;;
302         -> !finished
305 const list_functions = {
306         std.put("Available functions:\n")
307         std.put("--------------------\n")
308         for f : available_fns
309                 std.put("  {}\n", f)
310         ;;
313 const start_impls = {opts : byte[:][:]
314         var cmd : byte[:][:] = std.slalloc(opts.len + 1)
315         var nice_name : byte[:] = [][:]
316         var started_impls : impl_prog#[:] = [][:]
317         var survived_impls : impl_prog#[:] = [][:]
319         for var j = 0; j < opts.len; ++j
320                 cmd[j + 1] = opts[j]
321         ;;
323         /* Start everything */
324         for f  : fileutil.bywalk(".")
325                 match std.strrfind(f, "/")
326                 | `std.Some j: nice_name = std.sldup(f[j+1:])
327                 | `std.None: nice_name = std.sldup(f)
328                 ;;
330                 if nice_name.len < 5 || !std.eq(nice_name[:5], "impl-") || \
331                         std.eq(nice_name[nice_name.len - 2:nice_name.len - 1], ".")
332                         std.slfree(nice_name)
333                         continue
334                 ;;
336                 cmd[0] = f
337                 match std.spork(cmd)
338                 | `std.Ok (p, fi, fo) :
339                         std.slpush(&started_impls, std.mk([
340                                 .name = nice_name,
341                                 .pid = p,
342                                 .stdin = fi,
343                                 .stdout = fo,
344                                 .alive = true,
345                         ]))
346                 | `std.Err e: std.slfree(nice_name)
347                 ;;
348         ;;
350         /* Give them a bit of time to die */
351         std.usleep(500_000)
353         /* Reap the zombies */
354         var z
355         var l
356         var WNOHANG = 1 /* HACK */
357         while ((z = sys.waitpid(-1, &l, WNOHANG)) > 0)
358                 match sys.waitstatus(l)
359                 | `sys.Waitexit _:
360                 | `sys.Waitsig _:
361                 | `sys.Waitfail _:
362                 | `sys.Waitstop _: continue
363                 ;;
364                 for i : started_impls
365                         if i.pid == (z : std.pid)
366                                 i.alive = false
367                         ;;
368                 ;;
369         ;;
371         /* What remains? */
372         for i : started_impls
373                 if !i.alive
374                         continue
375                 ;;
377                 std.slpush(&survived_impls, i)
378         ;;
380         match survived_impls.len
381         | 0:
382                 std.put("No implementations found. Try running from fpmath-consensus root dir.\n")
383                 std.exit(1)
384         | 1:
385                 std.put("Only one implementation found. Comparisons will be impossible.\n")
386                 std.exit(1)
387         | _:
388         ;;
390         std.put("Executing:\n")
391         std.put("----------\n")
392         for i : survived_impls
393                 std.put("  [{w=6}] {w=20}", i.pid, i.name)
394                 for o : opts
395                         std.put(" {}", o)
396                 ;;
397                 std.put("\n")
398         ;;
399         std.put("----------\n")
401         std.slfree(started_impls)
403         -> survived_impls
406 const consensus = { input : byte[:], num_inputs : std.size, fn : fn_desc, p : flt_prec, x : exactness, impls : impl_prog#[:]
407         var all_agree : bool = true
408         var flt_sz : std.size = prec_width(p)
409         var inputs_sz = args_width(fn.inputs, p)
410         var outputs_sz = args_width(fn.outputs, p)
412         for var z = 0; z < num_inputs; ++z
413                 all_agree = true
415                 /* The input is possibly multiple entries */
416                 var i_start = z * inputs_sz
417                 var i_end = (z + 1) * inputs_sz
419                 /* The output might also be strange */
420                 var a_start = z * outputs_sz
421                 var a_end = (z + 1) * outputs_sz
423                 for var j = 0; j + 1 < impls.len; ++j
424                         match x
425                         | `Exact:
426                                 if std.sleq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end])
427                                         continue
428                                 ;;
430                                 /* The memory patterns don't agree. But perhaps this is due to NaN? */
431                                 if detailed_eq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end], fn.outputs, p, 0)
432                                         continue
433                                 ;;
434                         | `Inexact i:
435                                 var good : bool = true
437                                 for var k = j + 1; k < impls.len; ++k
438                                         good = good && detailed_eq(impls[j].output_bits[a_start:a_end], impls[k].output_bits[a_start:a_end], fn.outputs, p, i)
439                                 ;;
441                                 if good
442                                         continue
443                                 ;;
444                         ;;
446                         all_agree = false
447                         break
448                 ;;
450                 if all_agree
451                         continue
452                 ;;
454                 
455                 std.put("For input: ")
456                 extract_and_describe(input[i_start:i_end], fn.inputs, p)
457                 std.put("\n")
458                 for i : impls
459                         std.put("  [{w=6}] {w=20}: ", i.pid, i.name)
460                         extract_and_describe(i.output_bits[a_start:a_end], fn.outputs, p)
461                 ;;
462                 std.put("----------\n")
463         ;;
466 const extract_and_describe = {bits : byte[:], ts : fp_type[:], p : flt_prec
467         var w : std.size = prec_width(p)
469         if ts.len > 1
470                 std.put("\n")
471         ;;
473         for t : ts
474                 if ts.len > 1
475                         std.put("    ")
476                 ;;
478                 match t
479                 | `Flt:
480                         match p
481                         | `Single:
482                                 var u = std.getle32(bits[:w])
483                                 std.put("0x{w=8,p=0,x} ({})\n", u, std.flt32frombits(u))
484                         | `Double:
485                                 var u = std.getle64(bits[:w])
486                                 std.put("0x{w=16,p=0,x} ({})\n", u, std.flt64frombits(u))
487                         ;;
489                         bits = bits[w:]
490                 ;;
491         ;;
494 const detailed_eq = {a : byte[:], b : byte[:], ts : fp_type[:], p : flt_prec, i : uint
495         var w : std.size = prec_width(p)
497         for t : ts
498                 match t
499                 | `Flt:
500                         match p
501                         | `Single:
502                                 var u1 = std.getle32(a[:w])
503                                 var u2 = std.getle32(b[:w])
504                                 var f1 = std.flt32frombits(u1)
505                                 var f2 = std.flt32frombits(u2)
506                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
507                                         if i == 0 || ((u1 - u2 > (i : uint32)) && (u2 - u1 > (i : uint32)))
508                                                 -> false
509                                         ;;
510                                 ;;
511                         | `Double:
512                                 var u1 = std.getle64(a[:w])
513                                 var u2 = std.getle64(b[:w])
514                                 var f1 = std.flt64frombits(u1)
515                                 var f2 = std.flt64frombits(u2)
516                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
517                                         if i == 0 || ((u1 - u2 > (i : uint64)) && (u2 - u1 > (i : uint64)))
518                                                 -> false
519                                         ;;
520                                 ;;
521                         ;;
523                         a = a[w:]
524                         b = b[w:]
525                 ;;
526         ;;
528         -> true