handle large hfas correctly on arm64
[qbe.git] / tools / mgen / fuzz.ml
blob0821286477bb5aded88fed7f8584c71cc136a786
1 (* fuzz the tables and matchers generated *)
2 open Match
(* Minimal growable array used to store terms.
 * Note that it shadows Stdlib.Buffer for the
 * rest of this file.
 *
 * create, get and set raise Invalid_argument
 * (tagged with the function name) on a negative
 * capacity or an index outside [0, size). *)
module Buffer: sig
  type 'a t
  val create: ?capacity:int -> unit -> 'a t
  val reset: 'a t -> unit
  val size: 'a t -> int
  val get: 'a t -> int -> 'a
  val set: 'a t -> int -> 'a -> unit
  val push: 'a t -> 'a -> unit
end = struct
  type 'a t =
    { mutable size: int
    ; mutable data: 'a array }

  (* slots beyond [size] hold a dummy value; they
   * are never handed out because get and set are
   * bounded by [size] *)
  let mk_array n = Array.make n (Obj.magic 0)

  let create ?(capacity = 10) () =
    if capacity < 0 then invalid_arg "Buffer.create";
    {size = 0; data = mk_array capacity}

  (* forget all elements; the backing array is
   * kept as is *)
  let reset b = b.size <- 0

  let size b = b.size

  let get b n =
    if n < 0 || n >= size b then
      invalid_arg "Buffer.get";
    b.data.(n)

  let set b n x =
    if n < 0 || n >= size b then
      invalid_arg "Buffer.set";
    b.data.(n) <- x

  (* append x, growing the backing array to
   * 2 * cap + 1 slots when it is full *)
  let push b x =
    let cap = Array.length b.data in
    if size b = cap then begin
      let data = mk_array (2 * cap + 1) in
      Array.blit b.data 0 data 0 cap;
      b.data <- data
    end;
    let sz = size b in
    b.size <- sz + 1;
    set b sz x
end
(* State recorded in n.statemap for operator op
 * applied to operands in states s1 and s2; when
 * the map has no such entry, the state of the
 * Tmp atom is used instead *)
let binop_state n op s1 s2 =
  match StateMap.find (K (op, s1, s2)) n.statemap with
  | state -> state
  | exception Not_found -> atom_state n Tmp
(* index of a term in a term array or pool *)
type id = int

(* shape of a term; the children of a binary
 * node are referenced by id *)
type term_data =
  | Binop of op * id * id
  | Leaf of atomic_pattern

type term = {
  id: id;            (* own index *)
  data: term_data;   (* node contents *)
  state: p state;    (* state assigned by the numberer *)
}
(* Print the term rooted at id in the term array
 * ta. Binary nodes also show their own id and
 * the id of their state. *)
let pp_term fmt (ta, id) =
  let rec pp ppf id =
    match ta.(id).data with
    | Binop (op, id1, id2) ->
      Format.fprintf ppf
        "@[(%s@%d:%d @[<hov>%a@ %a@])@]"
        (show_op op) id ta.(id).state.id
        pp id1 pp id2
    | Leaf Tmp -> Format.fprintf ppf "%%%d" id
    | Leaf AnyCon -> Format.fprintf ppf "$%d" id
    | Leaf (Con c) -> Format.fprintf ppf "%Ld" c
  in
  pp fmt id
(* A term pool is a deduplicated set of terms
 * that maintains node numbering using the
 * numberer passed at creation time *)
module TermPool = struct
  type t =
    { terms: term Buffer.t
    ; hcons: (term_data, id) Hashtbl.t
    ; numbr: numberer }

  let create numbr =
    { terms = Buffer.create ()
    ; hcons = Hashtbl.create 100
    ; numbr }

  (* drop all terms; the numberer is kept *)
  let reset tp =
    Buffer.reset tp.terms;
    Hashtbl.clear tp.hcons

  let size tp = Buffer.size tp.terms

  (* term stored under id; raises Invalid_argument
   * (from Buffer.get) when id is out of range *)
  let term tp id = Buffer.get tp.terms id

  (* pooled leaf for atm, created on first use *)
  let mk_leaf tp atm =
    let data = Leaf atm in
    match Hashtbl.find tp.hcons data with
    | id -> term tp id
    | exception Not_found ->
      let id = Buffer.size tp.terms in
      let state = atom_state tp.numbr atm in
      Buffer.push tp.terms {id; data; state};
      Hashtbl.add tp.hcons data id;
      term tp id

  (* pooled binary node for op over the pooled
   * terms t1 and t2, created on first use *)
  let mk_binop tp op t1 t2 =
    let data = Binop (op, t1.id, t2.id) in
    match Hashtbl.find tp.hcons data with
    | id -> term tp id
    | exception Not_found ->
      let id = Buffer.size tp.terms in
      let state =
        binop_state tp.numbr op t1.state t2.state
      in
      Buffer.push tp.terms {id; data; state};
      Hashtbl.add tp.hcons data id;
      term tp id

  (* add a pattern and all its sub-patterns to
   * the pool, returning the term of the root;
   * variables are added as their atom *)
  let rec add_pattern tp = function
    | Bnr (op, p1, p2) ->
      let t1 = add_pattern tp p1 in
      let t2 = add_pattern tp p2 in
      mk_binop tp op t1 t2
    | Atm atm -> mk_leaf tp atm
    | Var (_, atm) -> add_pattern tp (Atm atm)

  (* unshare the term rooted at id: returns a
   * fresh term array in which every node occurs
   * once, ids are array indices assigned in
   * post-order, together with the id of the
   * root (the last index) *)
  let explode_term tp id =
    let rec aux tms n id =
      let t = term tp id in
      match t.data with
      | Leaf _ -> (n, {t with id = n} :: tms)
      | Binop (op, id1, id2) ->
        let n1, tms = aux tms n id1 in
        let n = n1 + 1 in
        let n2, tms = aux tms n id2 in
        let n = n2 + 1 in
        (n, { t with data = Binop (op, n1, n2)
                   ; id = n } :: tms)
    in
    let n, tms = aux [] 0 id in
    Array.of_list (List.rev tms), n
end
module R = Random

(* uniform pick in a list
 *
 * Walks the list once: the n+1-th element
 * replaces the current pick with probability
 * 1/(n+1).
 * Raises Invalid_argument on the empty list. *)
let list_pick l =
  let rec aux n l x =
    match l with
    | [] -> x
    | y :: l ->
      if R.int (n + 1) = 0 then
        aux (n + 1) l y
      else
        aux (n + 1) l x
  in
  match l with
  | [] -> invalid_arg "list_pick"
  | x :: l -> aux 1 l x
(* Random term generator. [term_pick ~numbr ~depth]
 * returns a pair (state, pattern) where state is
 * the state computed for the generated pattern
 * with atom states from numbr.atoms and
 * binop_state for binary nodes. *)
let term_pick ~numbr =
  (* the operator list is computed from the keys
   * of the statemap once and cached in numbr.ops *)
  let ops =
    if numbr.ops = [] then
      numbr.ops <-
        (StateMap.fold (fun k _ ops ->
             match k with
             | K (op, _, _) -> op :: ops)
           numbr.statemap [] |> setify);
    numbr.ops
  in
  let rec gen depth =
    (* exponential probability for leaves to
     * avoid skewing towards shallow terms *)
    let atm_prob = 0.75 ** float_of_int depth in
    (* atm_prob is 1 at depth 0, so the recursion
     * always stops there; it also stops at once
     * when there is no operator to pick *)
    if R.float 1.0 <= atm_prob || ops = [] then
      let atom, st = list_pick numbr.atoms in
      (st, Atm atom)
    else
      let op = list_pick ops in
      let s1, t1 = gen (depth - 1) in
      let s2, t2 = gen (depth - 1) in
      ( binop_state numbr op s1 s2
      , Bnr (op, t1, t2) )
  in fun ~depth -> gen depth
(* raised when fuzzing finds an inconsistency *)
exception FuzzError

(* number of nested binary nodes on the longest
 * path of a pattern; atoms and variables have
 * depth 0 *)
let rec pattern_depth = function
  | Atm _ | Var _ -> 0
  | Bnr (_, p1, p2) ->
    let d1 = pattern_depth p1 in
    let d2 = pattern_depth p2 in
    1 + max d1 d2
(* a as a percentage of b, as a float *)
let ( %% ) a b =
  let num = float_of_int a
  and den = float_of_int b in
  1e2 *. num /. den
(* Redraw a one-line progress report on stderr.
 *
 * msg is a format string; its arguments are
 * passed after pct. The line is cleared, msg is
 * printed, then a bar of at most width + 1
 * blocks and the percentage pct (0..100). *)
let progress ?(width = 50) msg pct =
  (* erase the current line and go back to its
   * first column *)
  Format.eprintf "\x1b[2K\r%!";
  let progress_bar fmt =
    let n =
      let fwidth = float_of_int width in
      (* never negative, so List.init below
       * cannot raise on an odd pct *)
      max 0 (1 + int_of_float (pct *. fwidth /. 1e2))
    in
    Format.fprintf fmt " %s%s %.0f%%@?"
      (String.concat "" (List.init n (fun _ -> "▒")))
      (String.make (max 0 (width - n)) '-')
      pct
  in
  (* print msg with its arguments first, then
   * the bar on the same formatter *)
  Format.kfprintf progress_bar
    Format.err_formatter msg
(* Check the numberer numbr against rules on
 * random terms.
 *
 * For every generated term, the rule names found
 * in the Top entries of its state (other than
 * "$" and "%") must be exactly the names of the
 * rules whose pattern matches the term. On a
 * mismatch a report is printed on stderr and
 * FuzzError is raised. Terms matched by at least
 * one rule are collected in a term pool. *)
let fuzz_numberer rules numbr =
  (* pick twice the max pattern depth so we
   * have a chance to find non-trivial numbers
   * for the atomic patterns in the rules *)
  let depth =
    List.fold_left (fun depth r ->
        max depth (pattern_depth r.pattern))
      0 rules * 2
  in
  (* fuzz until the term pool we are constructing
   * is no longer growing fast enough; or we just
   * went through sufficiently many iterations *)
  let max_iter = 1_000_000 in
  let low_insert_rate = 1e-2 in
  let tp = TermPool.create numbr in
  (* new_stats is (num, cnt, rate): num counts the
   * terms offered to the pool, cnt those that
   * were new since the last update, and rate is
   * the smoothed insert rate *)
  let rec loop new_stats i =
    let (_, _, insert_rate) = new_stats in
    if insert_rate <= low_insert_rate then () else
    if i >= max_iter then () else
    (* periodically update stats *)
    let new_stats =
      let (num, cnt, rate) = new_stats in
      if num land 1023 = 0 then
        let rate =
          0.5 *. (rate +. float_of_int cnt /. 1023.)
        in
        progress " insert_rate=%.1f%%"
          (i %% max_iter) (rate *. 1e2);
        (num + 1, 0, rate)
      else new_stats
    in
    (* create a term and check that its number is
     * accurate wrt the rules *)
    let st, term = term_pick ~numbr ~depth in
    let state_matched =
      List.filter_map (fun cu ->
          match cu with
          | Top ("$" | "%") -> None
          | Top name -> Some name
          | _ -> None)
        st.point |> setify
    in
    let rule_matched =
      List.filter_map (fun r ->
          if pattern_match r.pattern term then
            Some r.name
          else None)
        rules |> setify
    in
    if state_matched <> rule_matched then begin
      let open Format in
      let pp_str_list =
        let pp_sep fmt () = fprintf fmt ",@ " in
        pp_print_list ~pp_sep pp_print_string
      in
      eprintf "@.@[<v2>fuzz error for %s"
        (show_pattern term);
      eprintf "@ @[state matched: %a@]"
        pp_str_list state_matched;
      eprintf "@ @[rule matched: %a@]"
        pp_str_list rule_matched;
      eprintf "@]@.";
      raise FuzzError
    end;
    if state_matched = [] then
      loop new_stats (i + 1)
    else
      (* add to the term pool *)
      let old_size = TermPool.size tp in
      ignore (TermPool.add_pattern tp term : term);
      let new_stats =
        let (num, cnt, rate) = new_stats in
        if TermPool.size tp <> old_size then
          (num + 1, cnt + 1, rate)
        else
          (num + 1, cnt, rate)
      in
      loop new_stats (i + 1)
  in
  loop (1, 0, 1.0) 0;
  (* words / 128 / 1024 is MiB for 8-byte words *)
  Format.eprintf
    "@.@[ generated %.3fMiB of test terms@]@."
    (float_of_int (Obj.reachable_words (Obj.repr tp))
     /. 128. /. 1024.);
  (* NOTE(review): the result expression was lost
   * in this copy of the file; the pool is assumed
   * to be the result since test_matchers takes
   * one -- confirm against the upstream source *)
  tp
(* Interpret the matcher m on the term rooted at
 * id in the term array ta; stk is the stack of
 * pending node ids. Returns the list of
 * (variable, node id) pairs collected by the Set
 * actions. Raises Failure when the matcher
 * cannot proceed on this term. *)
let rec run_matcher stk m (ta, id as t) =
  let state id = ta.(id).state.id in
  match m.Action.node with
  | Action.Switch cases ->
    (* dispatch on the state of the current node *)
    let m =
      try List.assoc (state id) cases
      with Not_found -> failwith "no switch case"
    in
    run_matcher stk m t
  | Action.Push (sym, m) ->
    let l, r =
      match ta.(id).data with
      | Leaf _ -> failwith "push on leaf"
      | Binop (_, l, r) -> (l, r)
    in
    (* continue with one child and push the
     * other; when sym is set, the child with the
     * smaller state id is visited first *)
    if sym && state l > state r
    then run_matcher (l :: stk) m (ta, r)
    else run_matcher (r :: stk) m (ta, l)
  | Action.Pop m -> begin
      match stk with
      | id :: stk -> run_matcher stk m (ta, id)
      | [] -> failwith "pop on empty stack"
    end
  | Action.Set (v, m) ->
    (v, id) :: run_matcher stk m t
  | Action.Stop -> []
(* Reference matcher: match pattern p against the
 * term rooted at id in the term array ta.
 * Returns Some bindings, a list of
 * (variable, node id) pairs in left-to-right
 * order, or None when p does not match. *)
let rec term_match p (ta, id) =
  (* option bind *)
  let (|>>) x f =
    match x with None -> None | Some x -> f x
  in
  (* for atom patterns, a binary node is matched
   * as if it were a Tmp leaf *)
  let atom_match a =
    match ta.(id).data with
    | Leaf a' -> pattern_match (Atm a) (Atm a')
    | Binop _ -> pattern_match (Atm a) (Atm Tmp)
  in
  match p with
  | Var (v, a) when atom_match a ->
    Some [(v, id)]
  | Atm a when atom_match a -> Some []
  | (Atm _ | Var _) -> None
  | Bnr (op, pl, pr) -> begin
      match ta.(id).data with
      | Binop (op', idl, idr) when op' = op ->
        term_match pl (ta, idl) |>> fun l1 ->
        term_match pr (ta, idr) |>> fun l2 ->
        Some (l1 @ l2)
      | _ -> None
    end
(* Check the generated matchers on every term of
 * the pool tp.
 *
 * A matcher is built with lr_matcher for each
 * rule name and attached to every state whose
 * point contains Top of that name. Each term is
 * run through the matchers attached to its
 * state; the assignment a matcher returns must
 * be the one term_match computes for at least
 * one of the patterns of the rule. Otherwise a
 * report is printed on stderr and FuzzError is
 * raised. Progress and pattern coverage are
 * reported on stderr. *)
let test_matchers tp numbr rules =
  let {statemap = sm; states = sa; _} = numbr in
  (* total number of patterns, for coverage *)
  let total = ref 0 in
  (* state id -> (patterns, matcher); a state has
   * one binding per rule it can complete, hence
   * Hashtbl.add and find_all *)
  let matchers =
    let htbl = Hashtbl.create (Array.length sa) in
    List.map (fun r -> (r.name, r.pattern)) rules |>
    group_by_fst |>
    List.iter (fun (r, ps) ->
        total := !total + List.length ps;
        let pm = (ps, lr_matcher sm sa rules r) in
        sa |> Array.iter (fun s ->
            if List.mem (Top r) s.point then
              Hashtbl.add htbl s.id pm));
    htbl
  in
  (* patterns matched successfully so far *)
  let seen = Hashtbl.create !total in
  for id = 0 to TermPool.size tp - 1 do
    if id land 1023 = 0 ||
       id = TermPool.size tp - 1 then begin
      progress
        " coverage=%.1f%%"
        (id %% TermPool.size tp)
        (Hashtbl.length seen %% !total)
    end;
    let t = TermPool.explode_term tp id in
    Hashtbl.find_all matchers
      (TermPool.term tp id).state.id |>
    List.iter (fun (ps, m) ->
        (* assignments are compared up to order *)
        let norm = List.fast_sort compare in
        let ok =
          match norm (run_matcher [] m t) with
          | asn -> `Match (List.exists (fun p ->
              match term_match p t with
              | None -> false
              | Some asn' ->
                if asn = norm asn' then begin
                  Hashtbl.replace seen p ();
                  true
                end else false) ps)
          | exception e -> `RunFailure e
        in
        if ok <> `Match true then begin
          let open Format in
          let pp_asn fmt asn =
            fprintf fmt "@[<h>";
            pp_print_list
              ~pp_sep:(fun fmt () -> fprintf fmt ";@ ")
              (fun fmt (v, d) ->
                 fprintf fmt "@[%s%d@]" v d)
              fmt asn;
            fprintf fmt "@]"
          in
          eprintf "@.@[<v2>matcher error for";
          eprintf "@ @[%a@]" pp_term t;
          begin match ok with
            | `RunFailure e ->
              eprintf "@ @[exception: %s@]"
                (Printexc.to_string e)
            | `Match (* false *) _ ->
              (* the matcher ran fine above, so it
               * is run again to show its result *)
              let asn = run_matcher [] m t in
              eprintf "@ @[assignment: %a@]"
                pp_asn asn;
              eprintf "@ @[<v2>could not match";
              List.iter (fun p ->
                  eprintf "@ + @[%s@]"
                    (show_pattern p)) ps;
              eprintf "@]"
          end;
          eprintf "@]@.";
          raise FuzzError
        end)
  done;
  Format.eprintf "@."