support log and log1p
[fpmath-consensus.git] / checker / checker.myr
blobc959d43d694bac3bca01a14f73303d543c786914
1 use std
3 use bio
4 use iter
5 use fileutil
6 use sys
8 type impl_prog = struct
9         name : byte[:]
10         pid : std.pid
11         stdin : std.fd
12         stdout : std.fd
13         alive : bool
15         output_bits : byte[:]
18 var rng : std.rng#
20 /* Flt is ``whatever precision we're testing''. */
21 type fp_type = union
22         `Flt
25 type fn_desc = struct
26         name : byte[:]
27         inputs : fp_type[:]
28         outputs : fp_type[:]
31 type flt_prec = union
32         `Single
33         `Double
36 type exactness = union
37         `Exact
38         `Inexact
41 /* (name, number-of-flt-args, constant-extra-bytes) */
42 var available_fns : fn_desc[:] = [][:]
44 const nop = {;}
46 const main = {args : byte[:][:]
47         available_fns = [
48                 [.name = "id",    .inputs = [`Flt][:], .outputs = [`Flt][:]],
49                 [.name = "ceil",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
50                 [.name = "cos",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
51                 [.name = "exp",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
52                 [.name = "expm1", .inputs = [`Flt][:], .outputs = [`Flt][:]],
53                 [.name = "floor", .inputs = [`Flt][:], .outputs = [`Flt][:]],
54                 [.name = "fma",   .inputs = [`Flt, `Flt, `Flt][:], .outputs = [`Flt][:]],
55                 [.name = "log",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
56                 [.name = "log1p", .inputs = [`Flt][:], .outputs = [`Flt][:]],
57                 [.name = "sin",   .inputs = [`Flt][:], .outputs = [`Flt][:]],
58                 [.name = "sqrt",  .inputs = [`Flt][:], .outputs = [`Flt][:]],
59                 [.name = "trunc", .inputs = [`Flt][:], .outputs = [`Flt][:]],
60         ][:]
62         var old
63         var sa = [
64                 .handler = (nop : byte#),
65                 .flags = sys.Saresethand,
66         ]
67         sys.sigaction(sys.Sigpipe, &sa, &old)
69         var fn : fn_desc = available_fns[0]
70         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
72         var precision : flt_prec = `Single
73         var exactness : exactness = `Exact
74         var impls : impl_prog#[:]
75         rng = std.mksrng((std.now() : uint32))
77         (precision, exactness, fn, next_bits_fn) = read_args(args)
79         var input_sz = args_width(fn.inputs, precision)
80         var num_inputs = (1 << 18)
81         var buf_sz = num_inputs * input_sz
83         impls = start_impls([prec_arg(precision), "-f", fn.name, "-n", std.fmt("{}", num_inputs)][:])
85         io_loop(impls, precision, exactness, num_inputs, fn, next_bits_fn)
86         std.put("\n")
89 const prec_width = {p
90         match p
91         | `Single: -> 4
92         | `Double: -> 8
93         ;;
96 const prec_arg = {p
97         match p
98         | `Single: -> "-s"
99         | `Double: -> "-d"
100         ;;
103 const args_width = {ts : fp_type[:], p : flt_prec
104         var w : std.size = 0
106         for t : ts
107                 match t
108                 | `Flt: w += prec_width(p)
109                 ;;
110         ;;
112         -> w
115 const read_args = {args : byte[:][:]
116         var exactness : exactness = `Exact
117         var precision : flt_prec = `Single
118         var fn_name : byte[:] = "UNSPECIFIED"
119         var fn : fn_desc = available_fns[0]
120         var next_bits_fn : (b : byte[:], n : std.size -> bool) = next_rand
122         var cmd = std.optparse(args, &[
123                 .argdesc = "",
124                 .opts = [
125                         [.opt = 's', .desc = "use single precision (default)"],
126                         [.opt = 'd', .desc = "use double precision"],
127                         [.opt = 'l', .desc = "list available functions"],
128                         [.opt = 'f', .arg = "f", .desc = "test function ‘f’"],
129                         [.opt = 'r', .arg = "s", .desc = "choose inputs randomly, with seed ‘s’"],
130                         [.opt = 'e', .desc = "exhaust input space"],
131                         [.opt = 'i', .desc = "allow inexact results (off by 1 bit)"],
132                 ][:]
133         ])
135         for opt : cmd.opts
136                 match opt
137                 | ('s', _):
138                         precision = `Single
139                 | ('d', _):
140                         precision = `Double
141                 | ('i', _):
142                         exactness = `Inexact
143                 | ('l', _):
144                         list_functions()
145                         std.exit(0)
146                 | ('f', arg): fn_name = arg
147                 | ('r', arg):
148                         next_bits_fn = next_rand
149                         match std.intparse(arg)
150                         | `std.Some n:
151                                 rng = std.mksrng(n)
152                         | `std.None:
153                                 std.put("cannot parse seed “{}”\n", arg)
154                                 std.exit(1)
155                         ;;
156                 | ('e', _): next_bits_fn = next_exhaust
157                 | _: std.die("impossible\n")
158                 ;;
159         ;;
161         var good_fn : bool = false
162         for f : available_fns
163                 if std.eq(f.name, fn_name)
164                         fn = f
165                         good_fn = true
166                         break
167                 ;;
168         ;;
170         if !good_fn
171                 std.put("unknown function “{}”\n", fn_name)
172                 std.exit(1)
173         ;;
175         -> (precision, exactness, fn, next_bits_fn)
178 const io_loop = {impls : impl_prog#[:], p : flt_prec, x : exactness, num_inputs : std.size, fn : fn_desc, next_bits_fn : (b : byte[:], n : std.size -> bool)
179         var input_sz = args_width(fn.inputs, p)
180         var output_sz = args_width(fn.outputs, p)
181         var draw_line : bool = false
182         var n = 0
183         var bits : byte[:] = std.slalloc(num_inputs * input_sz)
184         var last_bits : byte[:] = std.slalloc(num_inputs * input_sz)
185         std.slfill(bits, 0)
186         std.slfill(last_bits, 0)
187         for i : impls
188                 i.output_bits = std.slalloc(num_inputs * output_sz)
189                 std.slfill(i.output_bits, 0)
190         ;;
192         /* Now, loop perhaps infinitely with the comparisons */
193 :again
194         draw_line = false
196         /* Send question */
197         for i : impls
198                 match std.writeall(i.stdin, bits)
199                 | (_, `std.Some e):
200                         std.put("CRASH: {w=20} [{w=6}] failed to receive data\n", i.name, i.pid)
201                         i.alive = false
202                         draw_line = true
203                 | _:
204                 ;;
205         ;;
207         /* Reap zombies */
208         for var j = 0; j < impls.len; ++j
209                 if impls[j].alive
210                         continue
211                 ;;
212                 impls[j] = impls[impls.len - 1]
213                 std.slgrow(&impls, impls.len - 1)
214         ;;
216         /* Gather consensus on last time's answers */
217         consensus(last_bits, num_inputs, fn, p, x, impls)
218         if n % 100 == 0
219                 if n % 8000 == 0
220                         std.put("\x1b[1G\x1b[0K")
221                 ;;
222                 std.put(".")
223         ;;
224         n++
226         std.slcp(last_bits, bits)
228         /* Receive new answers */
229         for i : impls
230                 if !i.alive
231                         continue
232                 ;;
234                 match std.readall(i.stdout, i.output_bits)
235                 | `std.Ok _:
236                 | `std.Err e:
237                         std.put("CRASH: {w=20} [{w=6}] failed to send data\n", i.name, i.pid)
238                         i.alive = false
239                         draw_line = true
240                 |_:
241                 ;;
242         ;;
244         if impls.len < 2
245                 std.put("Less than 2 implementations left. Consensus impossible.\n")
246                 std.put("----------\n")
247                 std.exit(1)
248         ;;
250         if draw_line
251                 std.put("----------\n")
252         ;;
255         /* Onward */
256         if next_bits_fn(bits, num_inputs)
257                 goto again
258         ;;
261 const next_rand = { b : byte[:], n : std.size
262         std.rngrandbytes(rng, b)
263         -> true
266 const next_exhaust = { b : byte[:], n : std.size
267         var one_arg : byte[:] = std.slalloc(b.len / n)
268         std.slcp(one_arg, b[b.len - one_arg.len:])
269         var finished : bool = false
271         /* n is the number of total argument groups */
272         for var i = 0; i < n; ++i
273                 /* Increment this particular argument */
274                 var j = one_arg.len - 1
275                 while j >= 0
276                         one_arg[j]++
277                         if (one_arg[j] != 0)
278                                 break
279                         ;;
280                         j--
281                 ;;
283                 finished = finished || j < 0
284                 var z = one_arg.len * i
285                 std.slcp(b[z:z + one_arg.len], one_arg)
286         ;;
288         -> !finished
291 const list_functions = {
292         std.put("Available functions:\n")
293         std.put("--------------------\n")
294         for f : available_fns
295                 std.put("  {}\n", f)
296         ;;
299 const start_impls = {opts : byte[:][:]
300         var cmd : byte[:][:] = std.slalloc(opts.len + 1)
301         var nice_name : byte[:] = [][:]
302         var started_impls : impl_prog#[:] = [][:]
303         var survived_impls : impl_prog#[:] = [][:]
305         for var j = 0; j < opts.len; ++j
306                 cmd[j + 1] = opts[j]
307         ;;
309         /* Start everything */
310         for f  : fileutil.bywalk(".")
311                 match std.strrfind(f, "/")
312                 | `std.Some j: nice_name = std.sldup(f[j+1:])
313                 | `std.None: nice_name = std.sldup(f)
314                 ;;
316                 if nice_name.len < 5 || !std.eq(nice_name[:5], "impl-") || \
317                         std.eq(nice_name[nice_name.len - 2:nice_name.len - 1], ".")
318                         std.slfree(nice_name)
319                         continue
320                 ;;
322                 cmd[0] = f
323                 match std.spork(cmd)
324                 | `std.Ok (p, fi, fo) :
325                         std.slpush(&started_impls, std.mk([
326                                 .name = nice_name,
327                                 .pid = p,
328                                 .stdin = fi,
329                                 .stdout = fo,
330                                 .alive = true,
331                         ]))
332                 | `std.Err e: std.slfree(nice_name)
333                 ;;
334         ;;
336         /* Give them a bit of time to die */
337         std.usleep(500_000)
339         /* Reap the zombies */
340         var z
341         var l
342         var WNOHANG = 1 /* HACK */
343         while ((z = sys.waitpid(-1, &l, WNOHANG)) > 0)
344                 match sys.waitstatus(l)
345                 | `sys.Waitexit _:
346                 | `sys.Waitsig _:
347                 | `sys.Waitfail _:
348                 | `sys.Waitstop _: continue
349                 ;;
350                 for i : started_impls
351                         if i.pid == (z : std.pid)
352                                 i.alive = false
353                         ;;
354                 ;;
355         ;;
357         /* What remains? */
358         for i : started_impls
359                 if !i.alive
360                         continue
361                 ;;
363                 std.slpush(&survived_impls, i)
364         ;;
366         match survived_impls.len
367         | 0:
368                 std.put("No implementations found. Try running from fpmath-consensus root dir.\n")
369                 std.exit(1)
370         | 1:
371                 std.put("Only one implementation found. Comparisons will be impossible.\n")
372                 std.exit(1)
373         | _:
374         ;;
376         std.put("Executing:\n")
377         std.put("----------\n")
378         for i : survived_impls
379                 std.put("  [{w=6}] {w=20}", i.pid, i.name)
380                 for o : opts
381                         std.put(" {}", o)
382                 ;;
383                 std.put("\n")
384         ;;
385         std.put("----------\n")
387         std.slfree(started_impls)
389         -> survived_impls
392 const consensus = { input : byte[:], num_inputs : std.size, fn : fn_desc, p : flt_prec, x : exactness, impls : impl_prog#[:]
393         var all_agree : bool = true
394         var flt_sz : std.size = prec_width(p)
395         var inputs_sz = args_width(fn.inputs, p)
396         var outputs_sz = args_width(fn.outputs, p)
398         for var z = 0; z < num_inputs; ++z
399                 all_agree = true
401                 /* The input is possibly multiple entries */
402                 var i_start = z * inputs_sz
403                 var i_end = (z + 1) * inputs_sz
405                 /* The output might also be strange */
406                 var a_start = z * outputs_sz
407                 var a_end = (z + 1) * outputs_sz
409                 for var j = 0; j + 1 < impls.len; ++j
410                         match x
411                         | `Exact:
412                                 if std.sleq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end])
413                                         continue
414                                 ;;
416                                 /* The memory patterns don't agree. But perhaps this is due to NaN? */
417                                 if detailed_eq(impls[j].output_bits[a_start:a_end], impls[j+1].output_bits[a_start:a_end], fn.outputs, p, false)
418                                         continue
419                                 ;;
420                         | `Inexact:
421                                 var good : bool = true
423                                 for var k = j + 1; k < impls.len; ++k
424                                         good = good && detailed_eq(impls[j].output_bits[a_start:a_end], impls[k].output_bits[a_start:a_end], fn.outputs, p, true)
425                                 ;;
427                                 if good
428                                         continue
429                                 ;;
430                         ;;
432                         all_agree = false
433                         break
434                 ;;
436                 if all_agree
437                         continue
438                 ;;
440                 
441                 std.put("For input: ")
442                 extract_and_describe(input[i_start:i_end], fn.inputs, p)
443                 std.put("\n")
444                 for i : impls
445                         std.put("  [{w=6}] {w=20}: ", i.pid, i.name)
446                         extract_and_describe(i.output_bits[a_start:a_end], fn.outputs, p)
447                 ;;
448                 std.put("----------\n")
449         ;;
452 const extract_and_describe = {bits : byte[:], ts : fp_type[:], p : flt_prec
453         var w : std.size = prec_width(p)
455         if ts.len > 1
456                 std.put("\n")
457         ;;
459         for t : ts
460                 if ts.len > 1
461                         std.put("    ")
462                 ;;
464                 match t
465                 | `Flt:
466                         match p
467                         | `Single:
468                                 var u = std.getle32(bits[:w])
469                                 std.put("0x{w=8,p=0,x} ({})\n", u, std.flt32frombits(u))
470                         | `Double:
471                                 var u = std.getle64(bits[:w])
472                                 std.put("0x{w=16,p=0,x} ({})\n", u, std.flt64frombits(u))
473                         ;;
475                         bits = bits[w:]
476                 ;;
477         ;;
480 const detailed_eq = {a : byte[:], b : byte[:], ts : fp_type[:], p : flt_prec, o : bool
481         var w : std.size = prec_width(p)
483         for t : ts
484                 match t
485                 | `Flt:
486                         match p
487                         | `Single:
488                                 var u1 = std.getle32(a[:w])
489                                 var u2 = std.getle32(b[:w])
490                                 var f1 = std.flt32frombits(u1)
491                                 var f2 = std.flt32frombits(u2)
492                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
493                                         if !(o && (u1 + 1 == u2  || u2 + 1 == u1))
494                                                 -> false
495                                         ;;
496                                 ;;
497                         | `Double:
498                                 var u1 = std.getle64(a[:w])
499                                 var u2 = std.getle64(b[:w])
500                                 var f1 = std.flt64frombits(u1)
501                                 var f2 = std.flt64frombits(u2)
502                                 if !((f1 == f2) || (std.isnan(f1) && std.isnan(f2)))
503                                         if !(o && (u1 + 1 == u2  || u2 + 1 == u1))
504                                                 -> false
505                                         ;;
506                                 ;;
507                         ;;
509                         a = a[w:]
510                         b = b[w:]
511                 ;;
512         ;;
514         -> true