Close input file after done reading
[qbe.git] / tools / callgen.ml
blobd53eabb988219950ae63794b7b945d5c55a5920a
(* ABI fuzzer: generates two modules, one calling the other,
 * each possibly written in a different language (C or QBE IL)
 *)
(* Base (scalar) types.  The type index is the OCaml type holding test
 * values of that type: [int] for the integer types, [float] for the
 * floating-point ones. *)
type _ bty =
  | Char : int bty
  | Short : int bty
  | Int : int bty
  | Long : int bty
  | Float : float bty
  | Double : float bty
(* Struct types as a list of base-type fields.  The index is a nested
 * tuple (x1 * (x2 * (... * unit))) holding one value per field. *)
type _ sty =
  | Field : 'x bty * 'rest sty -> ('x * 'rest) sty
  | Empty : unit sty
(* Argument / return types: either a single base type or a struct. *)
type _ aty =
  | Base : 'v bty -> 'v aty
  | Struct : 'v sty -> 'v aty
(* Existential wrappers that hide the type index, so that types of
 * different OCaml index types can be returned and stored uniformly. *)
type anyb = AB : 'a bty -> anyb
type anys = AS : 'a sty -> anys
type anya = AA : 'a aty -> anya
(* A type packed together with a test value of the matching OCaml type. *)
type testb = TB : 'a bty * 'a -> testb
type testa = TA : 'a aty * 'a -> testa
(* Round offset [x] up to the next multiple of alignment [a]. *)
let align a x =
  match x mod a with
  | 0 -> x
  | rem -> x + a - rem
(* Size in bytes of a base type.  Long is 8 bytes, matching QBE's l class;
 * this presumes a C target where "long" is 64-bit (LP64). *)
let btysize: type a. a bty -> int = fun ty ->
  match ty with
  | Char -> 1
  | Short -> 2
  | Int -> 4
  | Float -> 4
  | Long -> 8
  | Double -> 8
(* Every base type is aligned to its own size. *)
let btyalign: type a. a bty -> int = fun ty -> btysize ty
(* True when a struct type has no (remaining) fields. *)
let styempty: type a. a sty -> bool = fun st ->
  match st with
  | Empty -> true
  | Field (_, _) -> false
(* Size in bytes of a struct laid out with C rules: each field is placed at
 * the next offset aligned to its own alignment, and the total is rounded
 * up to the struct's alignment (the largest field alignment, or 1), i.e.
 * trailing padding is included as in C's sizeof.  Without that rounding a
 * struct like { long; char } would be reported as 9 bytes instead of 16.
 * The alignment is accumulated here because [styalign] is defined later. *)
let stysize s =
  let rec f: type a. int -> int -> a sty -> int =
    fun sz al -> function
    | Field (b, s) ->
      let a = btyalign b in
      f (align a sz + btysize b) (max al a) s
    | Empty -> align al sz in
  f 0 1 s
(* Alignment of a struct: the largest alignment among its fields, or 1
 * when there are none. *)
let styalign: type a. a sty -> int = fun st ->
  let rec widest: type c. int -> c sty -> int =
    fun acc -> function
    | Empty -> acc
    | Field (fld, rest) -> widest (max acc (btyalign fld)) rest in
  widest 1 st
60 (* Generate types and test vectors. *)
module Gen = struct
  module R = Random

  (* Seed the PRNG and return the seed actually used.  [Some seed] replays
   * a previous run; [None] draws a fresh 24-bit seed from /dev/urandom. *)
  let init = function
    | None ->
      let f = open_in "/dev/urandom" in
      let byte () = Char.code (input_char f) in
      let seed =
        try
          let hi = byte () in
          let mid = byte () in
          let lo = byte () in
          hi lsl 16 + mid lsl 8 + lo
        with e ->
          (* Do not leak the channel if a read fails. *)
          close_in_noerr f;
          raise e in
      close_in f;
      R.init seed;
      seed
    | Some seed ->
      R.init seed;
      seed

  (* Random signed integer for a C integer type of [sz] bytes.  Magnitudes
   * stay below 2^(8 * min sz 3 - 1): this fits the signed range of char
   * and short, and keeps the bound far below Random.int's 2^30 limit for
   * int and long. *)
  let int sz =
    let bound = 1 lsl (8 * min sz 3 - 1) in
    let i = R.int bound in
    if R.bool () then - i else i

  (* Random float with magnitude at most 1000. *)
  let float () =
    let f = R.float 1000. in
    if R.bool () then -. f else f

  (* Random test value of the given type. *)
  let testv: type a. a aty -> a =
    let tb: type a. a bty -> a = function
      | Float -> float ()
      | Double -> float ()
      | Char -> int (btysize Char)
      | Short -> int (btysize Short)
      | Int -> int (btysize Int)
      | Long -> int (btysize Long) in
    let rec ts: type a. a sty -> a = function
      | Field (b, s) ->
        (* The order of random draws decides which test a seed produces.
         * Tuple evaluation order is unspecified in OCaml; the tail is drawn
         * first here, which is the order ocamlc/ocamlopt use in practice
         * for [(tb b, ts s)], so existing seeds keep their meaning. *)
        let rest = ts s in
        let v = tb b in
        (v, rest)
      | Empty -> () in
    function
    | Base b -> tb b
    | Struct s -> ts s

  (* A base type, drawn uniformly. *)
  let b () =
    match R.int 6 with
    | 0 -> AB Char
    | 1 -> AB Short
    | 2 -> AB Int
    | 3 -> AB Long
    | 4 -> AB Float
    | _ -> AB Double

  let smax = 5 (* max fields in a struct: structs get 2 to smax fields *)
  let structp = 0.3 (* probability that a value has a struct type *)
  let amax = 8 (* exclusive bound: 0 to amax-1 function arguments *)

  (* Struct type with 1 to smax-1 random fields. *)
  let s () =
    let rec f n =
      if n = 0 then AS Empty else
      let AB bt = b () in
      let AS st = f (n-1) in
      AS (Field (bt, st)) in
    f (1 + R.int (smax-1))

  (* Argument/return type: a base type, or a struct whose first field is
   * drawn here followed by the fields of [s ()]. *)
  let a () =
    if R.float 1.0 > structp then
      let AB bt = b () in
      AA (Base bt)
    else
      let AB bt = b () in
      let AS st = s () in
      AA (Struct (Field (bt, st)))

  (* A random type together with a random value of that type. *)
  let test () =
    let AA ty = a () in
    let t = testv ty in
    TA (ty, t)

  (* Between 0 and amax-1 random tests.  As in [testv], the tail is drawn
   * before the head to keep the historical seed -> tests mapping. *)
  let tests () =
    let rec f n =
      if n = 0 then [] else
      let rest = f (n-1) in
      let t = test () in
      t :: rest in
    f (R.int amax)
end
146 (* Code generation for C *)
module OutC = struct
  open Printf

  (* Print the full C type of a value.  For a struct this is the complete
   * definition "struct name { T f1; T f2; ... }"; for a base type [name]
   * is ignored.  Char is spelled "signed char": the generator produces
   * negative char values and [check] compares against negative literals,
   * but whether plain "char" is signed is implementation-defined in C
   * (it is unsigned on some targets). *)
  let ctypelong oc name =
    let cb: type a. a bty -> unit = function
      | Char -> fprintf oc "signed char"
      | Short -> fprintf oc "short"
      | Int -> fprintf oc "int"
      | Long -> fprintf oc "long"
      | Float -> fprintf oc "float"
      | Double -> fprintf oc "double" in
    (* Members are named f1, f2, ... in declaration order. *)
    let rec cs: type a. int -> a sty -> unit =
      fun i -> function
      | Field (b, s) ->
        cb b;
        fprintf oc " f%d; " i;
        cs (i+1) s
      | Empty -> () in
    function
    | Base b -> cb b
    | Struct s ->
      fprintf oc "struct %s { " name;
      cs 1 s;
      fprintf oc "}"

  (* Print the C type used to declare a value: "struct name" for structs
   * (whose definition is emitted by [typedef]), the base type otherwise. *)
  let ctype: type a. out_channel -> string -> a aty -> unit =
    fun oc name -> function
    | Struct _ -> fprintf oc "struct %s" name
    | t -> ctypelong oc "" t

  (* Print a test value as a C literal; float constants get an "f" suffix
   * so that they have type float rather than double. *)
  let base: type a. out_channel -> a bty * a -> unit =
    fun oc -> function
    | Char, i -> fprintf oc "%d" i
    | Short, i -> fprintf oc "%d" i
    | Int, i -> fprintf oc "%d" i
    | Long, i -> fprintf oc "%d" i
    | Float, f -> fprintf oc "%ff" f
    | Double, f -> fprintf oc "%f" f

  (* Emit a declaration of variable [name] initialized to the test value,
   * e.g. "int arg1 = 5;" or "struct arg1 arg1 = { 1, 2.000000, };". *)
  let init oc name (TA (ty, t)) =
    let inits s =
      let rec f: type a. a sty * a -> unit = function
        | Field (b, s), (tb, ts) ->
          base oc (b, tb);
          fprintf oc ", ";
          f (s, ts)
        | Empty, () -> () in
      fprintf oc "{ ";
      f s;
      fprintf oc "}" in
    ctype oc name ty;
    fprintf oc " %s = " name;
    begin match (ty, t) with
    | Base b, tb -> base oc (b, tb)
    | Struct s, ts -> inits (s, ts)
    end;
    fprintf oc ";\n"

  let extension = ".c"

  (* Emit a one-line C comment. *)
  let comment oc s =
    fprintf oc "/* %s */\n" s

  (* Includes and the [fail] helper emitted at the top of both C files. *)
  let prelude oc = List.iter (fprintf oc "%s\n")
    [ "#include <stdio.h>"
    ; "#include <stdlib.h>"
    ; ""
    ; "static void fail(char *chk)"
    ; "{"
    ; "\tfprintf(stderr, \"fail: checking %s\\n\", chk);"
    ; "\tabort();"
    ; "}"
    ; ""
    ]

  (* Emit the struct definition for a struct-typed value named [name];
   * nothing for base types. *)
  let typedef oc name = function
    | TA (Struct ts, _) ->
      ctypelong oc name (Struct ts);
      fprintf oc ";\n"
    | _ -> ()

  (* Emit checks comparing variable [name] (or each member name.fN of a
   * struct) with the expected test value; a mismatch calls fail("name"). *)
  let check oc name =
    let chkbase: type a. string -> a bty * a -> unit =
      fun name t ->
        fprintf oc "\tif (%s != " name;
        base oc t;
        fprintf oc ")\n\t\tfail(%S);\n" name in
    function
    | TA (Base b, tb) -> chkbase name (b, tb)
    | TA (Struct s, ts) ->
      let rec f: type a. int -> a sty * a -> unit =
        fun i -> function
        | Field (b, s), (tb, ts) ->
          chkbase (Printf.sprintf "%s.f%d" name i) (b, tb);
          f (i+1) (s, ts)
        | Empty, () -> () in
      f 1 (s, ts)

  (* C name of argument [i] (0-based): arg1, arg2, ... *)
  let argname i = "arg" ^ string_of_int (i+1)

  (* Emit the prototype "RET f(T1 arg1, T2 arg2, ...)" (no terminator). *)
  let proto oc (TA (tret, _)) args =
    ctype oc "ret" tret;
    fprintf oc " f(";
    let narg = List.length args in
    List.iteri (fun i (TA (targ, _)) ->
      ctype oc (argname i) targ;
      fprintf oc " %s" (argname i);
      if i <> narg-1 then
        fprintf oc ", ")
      args;
    fprintf oc ")"

  (* Emit the calling module: a [main] that initializes the arguments,
   * calls [f] and checks the returned value. *)
  let caller oc ret args =
    let narg = List.length args in
    prelude oc;
    typedef oc "ret" ret;
    List.iteri (fun i arg -> typedef oc (argname i) arg) args;
    proto oc ret args;
    fprintf oc ";\n\nint main()\n{\n";
    List.iteri (fun i arg ->
      fprintf oc "\t";
      init oc (argname i) arg)
      args;
    fprintf oc "\t";
    let TA (tret, _) = ret in
    ctype oc "ret" tret;
    fprintf oc " ret;\n\n";
    fprintf oc "\tret = f(";
    List.iteri (fun i _ ->
      fprintf oc "%s" (argname i);
      if i <> narg-1 then
        fprintf oc ", ")
      args;
    fprintf oc ");\n";
    check oc "ret" ret;
    fprintf oc "\n\treturn 0;\n}\n"

  (* Emit the called module: [f] checks its arguments and returns the
   * expected return value. *)
  let callee oc ret args =
    prelude oc;
    typedef oc "ret" ret;
    List.iteri (fun i arg -> typedef oc (argname i) arg) args;
    fprintf oc "\n";
    proto oc ret args;
    fprintf oc "\n{\n\t";
    init oc "ret" ret;
    fprintf oc "\n";
    List.iteri (fun i arg -> check oc (argname i) arg) args;
    fprintf oc "\n\treturn ret;\n}\n"
end
310 (* Code generation for QBE *)
module OutIL = struct
  open Printf

  (* Emit a one-line QBE comment. *)
  let comment oc s =
    fprintf oc "# %s\n" s

  (* Generators of fresh temporaries (%tN) and labels (@lN).  They share a
   * single counter, so every generated name is unique. *)
  let tmp, lbl =
    let next = ref 0 in
    (fun () -> incr next; "%t" ^ (string_of_int !next)),
    (fun () -> incr next; "@l" ^ (string_of_int !next))

  (* QBE constant for a test value; floats use the s_ / d_ prefixes. *)
  let bvalue: type a. a bty * a -> string = function
    | Char, i -> sprintf "%d" i
    | Short, i -> sprintf "%d" i
    | Int, i -> sprintf "%d" i
    | Long, i -> sprintf "%d" i
    | Float, f -> sprintf "s_%f" f
    | Double, f -> sprintf "d_%f" f

  (* Register class of a base type; chars and shorts are held in w. *)
  let btype: type a. a bty -> string = function
    | Char -> "w"
    | Short -> "w"
    | Int -> "w"
    | Long -> "l"
    | Float -> "s"
    | Double -> "d"

  let extension = ".ssa"

  (* Name of argument [i] (0-based): arg1, arg2, ... *)
  let argname i = "arg" ^ string_of_int (i+1)

  (* Walk the fields of the struct at address [base]: for each field emit
   * its address (each field aligned to its own alignment) into a fresh
   * temporary and call [g index address field], indices starting at 0. *)
  let siter oc base s g =
    let rec f: type a. int -> int -> a sty * a -> unit =
      fun id off -> function
      | Field (b, s), (tb, ts) ->
        let off = align (btyalign b) off in
        let addr = tmp () in
        fprintf oc "\t%s =l add %d, %s\n" addr off base;
        g id addr (TB (b, tb));
        f (id + 1) (off + btysize b) (s, ts)
      | Empty, () -> () in
    f 0 0 s

  (* Memory type of a base type (aggregate definitions and stores):
   * b / h for char / short, the register class otherwise. *)
  let bmemtype: type a. a bty -> string = function
    | Char -> "b"
    | Short -> "h"
    | b -> btype b

  (* Suffix of the load used to read back a struct field: chars and shorts
   * are loaded sign-extended, other types use the plain load. *)
  let bloadsuffix: type a. a bty -> string = function
    | Char -> "sb"
    | Short -> "sh"
    | _ -> ""

  (* Materialize a test value and return the operand denoting it: the
   * constant itself for a base type, or a pointer to a freshly allocated
   * and filled stack slot for a struct. *)
  let init oc = function
    | TA (Base b, tb) -> bvalue (b, tb)
    | TA (Struct s, ts) ->
      let base = tmp () in
      (* QBE only provides alloc4, alloc8 and alloc16, so structs made only
       * of chars and shorts (alignment 1 or 2) still get a 4-aligned slot;
       * "alloc1"/"alloc2" would be rejected. *)
      fprintf oc "\t%s =l alloc%d %d\n"
        base (max 4 (styalign s)) (stysize s);
      siter oc base (s, ts)
        begin fun _ addr (TB (b, tb)) ->
          fprintf oc "\tstore%s %s, %s\n"
            (bmemtype b) (bvalue (b, tb)) addr
        end;
      base

  (* Emit checks of operand [name] against the expected value.  Before each
   * comparison the failure code is stored in %failcode: [id] for a base
   * value, 100*id + field number (from 1) for struct fields; a mismatch
   * jumps to @fail. *)
  let check oc id name =
    let bcheck = fun id name (b, tb) ->
      let tcmp = tmp () in
      let nxtl = lbl () in
      fprintf oc "\t%s =w ceq%s %s, %s\n"
        tcmp (btype b) name (bvalue (b, tb));
      fprintf oc "\tstorew %d, %%failcode\n" id;
      fprintf oc "\tjnz %s, %s, @fail\n" tcmp nxtl;
      fprintf oc "%s\n" nxtl in
    function
    | TA (Base Char, i) ->
      (* Only the low bits of a sub-word value are meaningful: sign-extend
       * before comparing. *)
      let tval = tmp () in
      fprintf oc "\t%s =w extsb %s\n" tval name;
      bcheck id tval (Int, i)
    | TA (Base Short, i) ->
      let tval = tmp () in
      fprintf oc "\t%s =w extsh %s\n" tval name;
      bcheck id tval (Int, i)
    | TA (Base b, tb) ->
      bcheck id name (b, tb)
    | TA (Struct s, ts) ->
      siter oc name (s, ts)
        begin fun id' addr (TB (b, tb)) ->
          let tval = tmp () in
          fprintf oc "\t%s =%s load%s %s\n"
            tval (btype b) (bloadsuffix b) addr;
          bcheck (100*id + id'+1) tval (b, tb)
        end

  (* ABI type of a value: its base class, or the aggregate type :name. *)
  let ttype name = function
    | TA (Base b, _) -> btype b
    | TA (Struct _, _) -> ":" ^ name

  (* Emit "type :name = { ... }" for a struct-typed value; nothing for
   * base types. *)
  let typedef oc name =
    let rec f: type a. a sty -> unit = function
      | Field (b, s) ->
        fprintf oc "%s" (bmemtype b);
        if not (styempty s) then
          fprintf oc ", ";
        f s
      | Empty -> () in
    function
    | TA (Struct ts, _) ->
      fprintf oc "type :%s = { " name;
      f ts;
      fprintf oc " }\n"
    | _ -> ()

  (* Failure block reporting %failcode and aborting, the function's closing
   * brace, and the format string it uses. *)
  let postlude oc = List.iter (fprintf oc "%s\n")
    [ "@fail"
    ; "# failure code"
    ; "\t%fcode =w loadw %failcode"
    ; "\t%f0 =w call $printf(l $failstr, w %fcode)"
    ; "\t%f1 =w call $abort()"
    ; "\tret 0"
    ; "}"
    ; ""
    ; "data $failstr = { b \"fail on check %d\\n\", b 0 }"
    ]

  (* Emit the calling module: $main builds the arguments, calls $f and
   * checks the returned value (failure code 0, or 1.. for its fields). *)
  let caller oc ret args =
    let narg = List.length args in
    List.iteri (fun i arg -> typedef oc (argname i) arg) args;
    typedef oc "ret" ret;
    fprintf oc "\nexport function w $main() {\n";
    fprintf oc "@start\n";
    fprintf oc "\t%%failcode =l alloc4 4\n";
    let targs = List.mapi (fun i arg ->
      comment oc ("define argument " ^ (string_of_int (i+1)));
      (ttype (argname i) arg, init oc arg))
      args in
    comment oc "call test function";
    fprintf oc "\t%%ret =%s call $f(" (ttype "ret" ret);
    List.iteri (fun i (ty, tmp) ->
      fprintf oc "%s %s" ty tmp;
      if i <> narg-1 then
        fprintf oc ", ")
      targs;
    fprintf oc ")\n";
    comment oc "check the return value";
    check oc 0 "%ret" ret;
    fprintf oc "\tret 0\n";
    postlude oc

  (* Emit the called module: $f checks each argument (failure code i+1 for
   * argument i) and returns the expected return value. *)
  let callee oc ret args =
    let narg = List.length args in
    List.iteri (fun i arg -> typedef oc (argname i) arg) args;
    typedef oc "ret" ret;
    fprintf oc "\nexport function %s $f(" (ttype "ret" ret);
    List.iteri (fun i arg ->
      let a = argname i in
      fprintf oc "%s %%%s" (ttype a arg) a;
      if i <> narg-1 then
        fprintf oc ", ")
      args;
    fprintf oc ") {\n";
    fprintf oc "@start\n";
    fprintf oc "\t%%failcode =l alloc4 4\n";
    List.iteri (fun i arg ->
      comment oc ("checking argument " ^ (string_of_int (i+1)));
      check oc (i+1) ("%" ^ argname i) arg)
      args;
    comment oc "define the return value";
    let rettmp = init oc ret in
    fprintf oc "\tret %s\n" rettmp;
    postlude oc
end
(* Interface of a code generator for one target language. *)
module type OUT = sig
  (* Suffix appended to the generated file names (".c", ".ssa"). *)
  val extension : string

  (* [comment oc text] writes [text] as a one-line comment. *)
  val comment : out_channel -> string -> unit

  (* [caller oc ret args] writes a module whose entry point calls [f] with
   * the test values [args] and checks the result against [ret]. *)
  val caller : out_channel -> testa -> testa list -> unit

  (* [callee oc ret args] writes the definition of [f], which checks its
   * arguments against [args] and returns the test value [ret]. *)
  val callee : out_channel -> testa -> testa list -> unit
end
500 let _ =
501 let usage code =
502 Printf.eprintf "usage: abi.ml [-s SEED] DIR {c,ssa} {c,ssa}\n";
503 exit code in
505 let outmod = function
506 | "c" -> (module OutC : OUT)
507 | "ssa" -> (module OutIL: OUT)
508 | _ -> usage 1 in
510 let seed, dir, mcaller, mcallee =
511 match Sys.argv with
512 | [| _; "-s"; seed; dir; caller; callee |] ->
513 let seed =
514 try Some (int_of_string seed) with
515 Failure _ -> usage 1 in
516 seed, dir, outmod caller, outmod callee
517 | [| _; dir; caller; callee |] ->
518 None, dir, outmod caller, outmod callee
519 | [| _; "-h" |] ->
520 usage 0
521 | _ ->
522 usage 1 in
524 let seed = Gen.init seed in
525 let tret = Gen.test () in
526 let targs = Gen.tests () in
527 let module OCaller = (val mcaller : OUT) in
528 let module OCallee = (val mcallee : OUT) in
529 let ocaller = open_out (dir ^ "/caller" ^ OCaller.extension) in
530 let ocallee = open_out (dir ^ "/callee" ^ OCallee.extension) in
531 OCaller.comment ocaller (Printf.sprintf "seed %d" seed);
532 OCallee.comment ocallee (Printf.sprintf "seed %d" seed);
533 OCaller.caller ocaller tret targs;
534 OCallee.callee ocallee tret targs;