support sin, cos; update mpfr Makefile for v4
[fpmath-consensus.git] / checker / checker.myr
blob572571bc4eb146b45917f332c799713e3b47de43
1 use std
3 use bio
4 use iter
5 use fileutil
6 use sys
8 type impl_prog = struct
9         name : byte[:]
10         pid : std.pid
11         stdin : std.fd
12         stdout : std.fd
13         alive : bool
15         output_bits : byte[:]
18 var rng : std.rng#
20 /* Flt is ``whatever precision we're testing''. */
21 type fp_type = union
22         `Flt
25 type fn_desc = struct
26         name : byte[:]
27         inputs : fp_type[:]
28         outputs : fp_type[:]
31 type flt_prec = union
32         `Single
33         `Double
36 type exactness = union
37         `Exact
38         `Inexact uint
41 /* (name, number-of-flt-args, constant-extra-bytes) */
42 var available_fns : fn_desc[:] = [][:]
44 const nop = {;}
46 const main = {args : byte[:][:]
47         available_fns = [
48                 [.name = "id",    .inputs = [`Flt][:], .outputs = [`Flt][:]],
49                 [.name = "ceil",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
50                 [.name = "cos",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
51                 [.name = "exp",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
52                 [.name = "expm1", .inputs = [`Flt][:], .outputs = [`Flt][:]],
53                 [.name = "floor", .inputs = [`Flt][:], .outputs = [`Flt][:]],
54                 [.name = "fma",   .inputs = [`Flt, `Flt, `Flt][:], .outputs = [`Flt][:]],
55                 [.name = "log",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
56                 [.name = "log1p", .inputs = [`Flt][:], .outputs = [`Flt][:]],
57                 [.name = "powr",  .inputs = [`Flt, `Flt][:], .outputs = [`Flt][:]],
58                 [.name = "sin",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
59                 [.name = "sincos",.inputs = [`Flt][:], .outputs = [`Flt, `Flt][:]],
60                 [.name = "sqrt",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
61                 [.name = "trunc", .inputs = [`Flt][:], .outputs = [`Flt][:]],
62         ][:]
64         var old
65         var sa = [
66                 .handler = (nop : byte#),
67                 .flags = sys.Saresethand,
68         ]
69         sys.sigaction(sys.Sigpipe, &sa, &old)
71         var fn : fn_desc = available_fns[0]
72         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
74         var precision : flt_prec = `Single
75         var exactness : exactness = `Exact
76         var impls : impl_prog#[:]
77         rng = std.mksrng((std.now() : uint32))
79         (precision, exactness, fn, next_bits_fn) = read_args(args)
81         var input_sz = args_width(fn.inputs, precision)
82         var num_inputs = (1 << 18)
83         var buf_sz = num_inputs * input_sz
85         impls = start_impls([prec_arg(precision), "-f", fn.name, "-n", std.fmt("{}", num_inputs)][:])
87         io_loop(impls, precision, exactness, num_inputs, fn, next_bits_fn)
88         std.put("\n")
91 const prec_width = {p
92         match p
93         | `Single: -> 4
94         | `Double: -> 8
95         ;;
98 const prec_arg = {p
99         match p
100         | `Single: -> "-s"
101         | `Double: -> "-d"
102         ;;
105 const args_width = {ts : fp_type[:], p : flt_prec
106         var w : std.size = 0
108         for t : ts
109                 match t
110                 | `Flt: w += prec_width(p)
111                 ;;
112         ;;
114         -> w
117 const read_args = {args : byte[:][:]
118         var exactness : exactness = `Exact
119         var precision : flt_prec = `Single
120         var fn_name : byte[:] = "UNSPECIFIED"
121         var fn : fn_desc = available_fns[0]
122         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
124         var cmd = std.optparse(args, &[
125                 .argdesc = "",
126                 .opts = [
127                         [.opt = 's', .desc = "use single precision (default)"],
128                         [.opt = 'd', .desc = "use double precision"],
129                         [.opt = 'l', .desc = "list available functions"],
130                         [.opt = 'f', .arg = "f", .desc = "test function ‘f’"],
131                         [.opt = 'r', .arg = "s", .desc = "choose inputs randomly, with seed ‘s’"],
132                         [.opt = 'e', .desc = "exhaust input space"],
133                         [.opt = 'i', .arg = "i", .desc = "allow inexact matches (by ‘i’, bitwise)"],
134                 ][:]
135         ])
137         for opt : cmd.opts
138                 match opt
139                 | ('s', _):
140                         precision = `Single
141                 | ('d', _):
142                         precision = `Double
143                 | ('i', arg):
144                         match std.intparse(arg)
145                         | `std.Some n:
146                                 if n >= 0
147                                         exactness = `Inexact (n : uint)
148                                 else
149                                         std.put("unacceptable inexactness “{}”\n", arg)
150                                 ;;
151                         | `std.None:
152                                 std.put("cannot parse inexactness “{}”\n", arg)
153                                 std.exit(1)
154                         ;;
155                 | ('l', _):
156                         list_functions()
157                         std.exit(0)
158                 | ('f', arg): fn_name = arg
159                 | ('r', arg):
160                         next_bits_fn = next_rand
161                         match std.intparse(arg)
162                         | `std.Some n:
163                                 rng = std.mksrng(n)
164                         | `std.None:
165                                 std.put("cannot parse seed “{}”\n", arg)
166                                 std.exit(1)
167                         ;;
168                 | ('e', _): next_bits_fn = next_exhaust
169                 | _: std.die("impossible\n")
170                 ;;
171         ;;
173         var good_fn : bool = false
174         for f : available_fns
175                 if std.eq(f.name, fn_name)
176                         fn = f
177                         good_fn = true
178                         break
179                 ;;
180         ;;
182         if !good_fn
183                 std.put("unknown function “{}”\n", fn_name)
184                 std.exit(1)
185         ;;
187         -> (precision, exactness, fn, next_bits_fn)
190 const io_loop = {impls : impl_prog#[:], p : flt_prec, x : exactness, num_inputs : std.size, fn : fn_desc, next_bits_fn : (b : byte[:], n : std.size -> bool)
191         var input_sz = args_width(fn.inputs, p)
192         var output_sz = args_width(fn.outputs, p)
193         var draw_line : bool = false
194         var n = 0
195         var bits : byte[:] = std.slalloc(num_inputs * input_sz)
196         var last_bits : byte[:] = std.slalloc(num_inputs * input_sz)
197         std.slfill(bits, 0)
198         std.slfill(last_bits, 0)
199         for i : impls
200                 i.output_bits = std.slalloc(num_inputs * output_sz)
201                 std.slfill(i.output_bits, 0)
202         ;;
204         /* Now, loop perhaps infinitely with the comparisons */
205 :again
206         draw_line = false
208         /* Send question */
209         for i : impls
210                 match std.writeall(i.stdin, bits)
211                 | `std.Ok _:
212                 | `std.Err (_, e):
213                         std.put("CRASH: {w=20} [{w=6}] failed to receive data\n", i.name, i.pid)
214                         i.alive = false
215                         draw_line = true
216                 ;;
217         ;;
219         /* Reap zombies */
220         for var j = 0; j < impls.len; ++j
221                 if impls[j].alive
222                         continue
223                 ;;
224                 impls[j] = impls[impls.len - 1]
225                 std.slgrow(&impls, impls.len - 1)
226         ;;
228         /* Gather consensus on last time's answers */
229         consensus(last_bits, num_inputs, fn, p, x, impls)
230         if n % 100 == 0
231                 if n % 8000 == 0
232                         std.put("\x1b[1G\x1b[0K")
233                 ;;
234                 std.put(".")
235         ;;
236         n++
238         std.slcp(last_bits, bits)
240         /* Receive new answers */
241         for i : impls
242                 if !i.alive
243                         continue
244                 ;;
246                 match std.readall(i.stdout, i.output_bits)
247                 | `std.Ok _:
248                 | `std.Err e:
249                         std.put("CRASH: {w=20} [{w=6}] failed to send data\n", i.name, i.pid)
250                         i.alive = false
251                         draw_line = true
252                 |_:
253                 ;;
254         ;;
256         if impls.len < 2
257                 std.put("Less than 2 implementations left. Consensus impossible.\n")
258                 std.put("----------\n")
259                 std.exit(1)
260         ;;
262         if draw_line
263                 std.put("----------\n")
264         ;;
267         /* Onward */
268         if next_bits_fn(bits, num_inputs)
269                 goto again
270         ;;
273 const next_rand = { b : byte[:], n : std.size
274         std.rngrandbytes(rng, b)
275         -> true
278 const next_exhaust = { b : byte[:], n : std.size
279         var one_arg : byte[:] = std.slalloc(b.len / n)
280         std.slcp(one_arg, b[b.len - one_arg.len:])
281         var finished : bool = false
283         /* n is the number of total argument groups */
284         for var i = 0; i < n; ++i
285                 /* Increment this particular argument */
286                 var j = one_arg.len - 1
287                 while j >= 0
288                         one_arg[j]++
289                         if (one_arg[j] != 0)
290                                 break
291                         ;;
292                         j--
293                 ;;
295                 finished = finished || j < 0
296                 var z = one_arg.len * i
297                 std.slcp(b[z:z + one_arg.len], one_arg)
298         ;;
300         -> !finished
303 const list_functions = {
304         std.put("Available functions:\n")
305         std.put("--------------------\n")
306         for f : available_fns
307                 std.put("  {}\n", f)
308         ;;
311 const start_impls = {opts : byte[:][:]
312         var cmd : byte[:][:] = std.slalloc(opts.len + 1)
313         var nice_name : byte[:] = [][:]
314         var started_impls : impl_prog#[:] = [][:]
315         var survived_impls : impl_prog#[:] = [][:]
317         for var j = 0; j < opts.len; ++j
318                 cmd[j + 1] = opts[j]
319         ;;
321         /* Start everything */
322         for f  : fileutil.bywalk(".")
323                 match std.strrfind(f, "/")
324                 | `std.Some j: nice_name = std.sldup(f[j+1:])
325                 | `std.None: nice_name = std.sldup(f)
326                 ;;
328                 if nice_name.len < 5 || !std.eq(nice_name[:5], "impl-") || \
329                         std.eq(nice_name[nice_name.len - 2:nice_name.len - 1], ".")
330                         std.slfree(nice_name)
331                         continue
332                 ;;
334                 cmd[0] = f
335                 match std.spork(cmd)
336                 | `std.Ok (p, fi, fo) :
337                         std.slpush(&started_impls, std.mk([
338                                 .name = nice_name,
339                                 .pid = p,
340                                 .stdin = fi,
341                                 .stdout = fo,
342                                 .alive = true,
343                         ]))
344                 | `std.Err e: std.slfree(nice_name)
345                 ;;
346         ;;
348         /* Give them a bit of time to die */
349         std.usleep(500_000)
351         /* Reap the zombies */
352         var z
353         var l
354         var WNOHANG = 1 /* HACK */
355         while ((z = sys.waitpid(-1, &l, WNOHANG)) > 0)
356                 match sys.waitstatus(l)
357                 | `sys.Waitexit _:
358                 | `sys.Waitsig _:
359                 | `sys.Waitfail _:
360                 | `sys.Waitstop _: continue
361                 ;;
362                 for i : started_impls
363                         if i.pid == (z : std.pid)
364                                 i.alive = false
365                         ;;
366                 ;;
367         ;;
369         /* What remains? */
370         for i : started_impls
371                 if !i.alive
372                         continue
373                 ;;
375                 std.slpush(&survived_impls, i)
376         ;;
378         match survived_impls.len
379         | 0:
380                 std.put("No implementations found. Try running from fpmath-consensus root dir.\n")
381                 std.exit(1)
382         | 1:
383                 std.put("Only one implementation found. Comparisons will be impossible.\n")
384                 std.exit(1)
385         | _:
386         ;;
388         std.put("Executing:\n")
389         std.put("----------\n")
390         for i : survived_impls
391                 std.put("  [{w=6}] {w=20}", i.pid, i.name)
392                 for o : opts
393                         std.put(" {}", o)
394                 ;;
395                 std.put("\n")
396         ;;
397         std.put("----------\n")
399         std.slfree(started_impls)
401         -> survived_impls
404 const consensus = { input : byte[:], num_inputs : std.size, fn : fn_desc, p : flt_prec, x : exactness, impls : impl_prog#[:]
405         var all_agree : bool = true
406         var flt_sz : std.size = prec_width(p)
407         var inputs_sz = args_width(fn.inputs, p)
408         var outputs_sz = args_width(fn.outputs, p)
410         for var z = 0; z < num_inputs; ++z
411                 all_agree = true
413                 /* The input is possibly multiple entries */
414                 var i_start = z * inputs_sz
415                 var i_end = (z + 1) * inputs_sz
417                 /* The output might also be strange */
418                 var a_start = z * outputs_sz
419                 var a_end = (z + 1) * outputs_sz
421                 for var j = 0; j + 1 < impls.len; ++j
422                         match x
423                         | `Exact:
424                                 if std.sleq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end])
425                                         continue
426                                 ;;
428                                 /* The memory patterns don't agree. But perhaps this is due to NaN? */
429                                 if detailed_eq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end], fn.outputs, p, 0)
430                                         continue
431                                 ;;
432                         | `Inexact i:
433                                 var good : bool = true
435                                 for var k = j + 1; k < impls.len; ++k
436                                         good = good && detailed_eq(impls[j].output_bits[a_start:a_end], impls[k].output_bits[a_start:a_end], fn.outputs, p, i)
437                                 ;;
439                                 if good
440                                         continue
441                                 ;;
442                         ;;
444                         all_agree = false
445                         break
446                 ;;
448                 if all_agree
449                         continue
450                 ;;
452                 
453                 std.put("For input: ")
454                 extract_and_describe(input[i_start:i_end], fn.inputs, p)
455                 std.put("\n")
456                 for i : impls
457                         std.put("  [{w=6}] {w=20}: ", i.pid, i.name)
458                         extract_and_describe(i.output_bits[a_start:a_end], fn.outputs, p)
459                 ;;
460                 std.put("----------\n")
461         ;;
464 const extract_and_describe = {bits : byte[:], ts : fp_type[:], p : flt_prec
465         var w : std.size = prec_width(p)
467         if ts.len > 1
468                 std.put("\n")
469         ;;
471         for t : ts
472                 if ts.len > 1
473                         std.put("    ")
474                 ;;
476                 match t
477                 | `Flt:
478                         match p
479                         | `Single:
480                                 var u = std.getle32(bits[:w])
481                                 std.put("0x{w=8,p=0,x} ({})\n", u, std.flt32frombits(u))
482                         | `Double:
483                                 var u = std.getle64(bits[:w])
484                                 std.put("0x{w=16,p=0,x} ({})\n", u, std.flt64frombits(u))
485                         ;;
487                         bits = bits[w:]
488                 ;;
489         ;;
492 const detailed_eq = {a : byte[:], b : byte[:], ts : fp_type[:], p : flt_prec, i : uint
493         var w : std.size = prec_width(p)
495         for t : ts
496                 match t
497                 | `Flt:
498                         match p
499                         | `Single:
500                                 var u1 = std.getle32(a[:w])
501                                 var u2 = std.getle32(b[:w])
502                                 var f1 = std.flt32frombits(u1)
503                                 var f2 = std.flt32frombits(u2)
504                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
505                                         if i == 0 || ((u1 - u2 > (i : uint32)) && (u2 - u1 > (i : uint32)))
506                                                 -> false
507                                         ;;
508                                 ;;
509                         | `Double:
510                                 var u1 = std.getle64(a[:w])
511                                 var u2 = std.getle64(b[:w])
512                                 var f1 = std.flt64frombits(u1)
513                                 var f2 = std.flt64frombits(u2)
514                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
515                                         if i == 0 || ((u1 - u2 > (i : uint64)) && (u2 - u1 > (i : uint64)))
516                                                 -> false
517                                         ;;
518                                 ;;
519                         ;;
521                         a = a[w:]
522                         b = b[w:]
523                 ;;
524         ;;
526         -> true