handle large hfas correctly on arm64
[qbe.git] / tools / mgen / cgen.ml
blob297265cc3086e07c2d5d9c6af5b1ce74c54c9b98
1 open Match
(* Settings of the C code generator. *)
type options = {
  pfx : string;  (* prefix of the generated C function/array names *)
  static : bool;  (* declare opn, refn, match[] and matcher[] static *)
  oc : out_channel;  (* channel receiving the generated C code *)
}
(* Operand of a binary operation: [L] is the left one (printed as
 * [l] in the generated C code) and [R] the right one ([r]). *)
type side =
  | L
  | R

(* Predicate on the state id of one operand. *)
type id_pred =
  | InBitSet of Int64.t  (* the id's bit is set in the mask *)
  | Ge of int  (* the id is at least the given value *)
  | Eq of int  (* the id equals the given value *)

(* Conjunction of predicates on the operands' state ids. *)
and id_test =
  | Pred of (side * id_pred)
  | And of id_test * id_test

(* Plan of the C code computing the result state of one operation. *)
type case_code =
  | Table of ((int * int) * int) list
      (* lookup table from (l, r) state pairs to result states;
       * pairs that are absent get the case's default state *)
  | IfThen of {
      test : id_test;
      cif : case_code;
      cthen : case_code option;
    }
      (* [cif] runs when [test] holds; [cthen], if any, is emitted
       * after the conditional block *)
  | Return of int  (* produce the given state *)

(* Code generated for one operation. *)
type case = {
  swap : bool;
      (* the operation is symmetric: only pairs with r <= l are
       * kept and the C code orders the operands so that l >= r *)
  code : case_code;
}
(* Decide how to compute the result state of one operation.
 * [map] lists ((l, r), st) entries: the operation applied to operands
 * in states l and r yields st.  [tmp] is the state of pairs absent
 * from [map] and [nstates] the total number of states.
 * If [map] is symmetric, only the pairs with r <= l are kept and
 * [swap] is set.  When all entries yield the same state and the pairs
 * form a product in which one side is a single state, a conditional
 * test is generated; otherwise the map becomes a lookup table. *)
let cgen_case tmp nstates map =
  (* raised to give up on a conditional and use a lookup table *)
  let exception BailToTable in
  (* Predicate accepting exactly the (distinct) state ids [ids]. *)
  let cgen_test ids =
    match ids with
    | [id] -> Eq id
    | _ ->
        let min_id = List.fold_left min max_int ids in
        if List.length ids = nstates - min_id then
          (* ids are exactly min_id .. nstates-1 *)
          Ge min_id
        else if nstates > 64 then
          (* the ids may not fit in a 64-bit mask: fall back to a
           * table instead of failing *)
          raise BailToTable
        else
          InBitSet
            (List.fold_left
               (fun bs id -> Int64.logor bs (Int64.shift_left 1L id))
               0L ids)
  in
  let symmetric =
    let inverse ((l, r), x) = ((r, l), x) in
    setify map = setify (List.map inverse map)
  in
  let map =
    let ordered ((l, r), _) = r <= l in
    if symmetric then List.filter ordered map else map
  in
  try
    (* the operation considered can only generate a single state *)
    let st =
      match setify (List.map snd map) with
      | [st] -> st
      | _ -> raise BailToTable
    in
    let ls, rs = List.split (List.map fst map) in
    let ls = setify ls and rs = setify rs in
    (* "l in ls && r in rs" describes the map only if it is the full
     * product ls x rs, which holds when one side is a singleton *)
    if List.length ls > 1 && List.length rs > 1 then
      raise BailToTable;
    let pl = Pred (L, cgen_test ls)
    and pr = Pred (R, cgen_test rs) in
    { swap = symmetric
    ; code =
        IfThen
          { test = And (pl, pr)
          ; cif = Return st
          ; cthen = Some (Return tmp) } }
  with BailToTable ->
    { swap = symmetric
    ; code = Table map }
(* Name of the C enumerator of an operation: "O" followed by the base
 * name of [op]; the class half of the pair does not matter. *)
let show_op (_cls, op) =
  Printf.sprintf "O%s" (show_op_base op)
(* Write [i] tab characters to [oc], indenting the generated C code by
 * [i] levels.  Any depth is accepted (the previous implementation was
 * limited to 5 levels).
 * @raise Invalid_argument if [i] is negative. *)
let indent oc i =
  output_string oc (String.make i '\t')
(* Emit, at indentation level [i], C code ordering the operand states
 * so that afterwards l >= r.  The generated code uses a scratch
 * variable [t] that must be declared by the caller. *)
let emit_swap oc i =
  indent oc i;
  output_string oc "if (l < r)\n";
  indent oc (i + 1);
  output_string oc "t = l, l = r, r = t;\n"
(* Emit the static lookup tables used by the case [(op, c)] of opn().
 * One table is produced per [Table] node of [c.code], named
 * <pfx>O<op>tbl, <pfx>O<op>1tbl, ... in visiting order; that order
 * must match the one used by [emit_case], or the references it emits
 * would not find their tables.  Pairs missing from a map yield [tmp].
 * When [c.swap] is set only the pairs with r <= l are stored, row by
 * row, in a flat array. *)
let gen_tables oc tmp pfx nstates (op, c) =
  let i = 1 in
  let pf m = Printf.fprintf oc m in
  let pfi n m = indent oc n; pf m in
  let ntables = ref 0 in
  let base = pfx ^ show_op op in
  let swap = c.swap in
  let rec gen c =
    match c with
    | Table map ->
        let name =
          if !ntables = 0 then base else
          base ^ string_of_int !ntables
        in
        (* advance the counter exactly as emit_case does, otherwise
         * a second table would reuse this name *)
        incr ntables;
        (* the entries are state numbers stored as uchar *)
        assert (nstates <= 256);
        if swap then
          pfi i "static uchar %stbl[%d] = {\n"
            name (nstates * (nstates + 1) / 2)
        else
          pfi i "static uchar %stbl[%d][%d] = {\n"
            name nstates nstates;
        for l = 0 to nstates - 1 do
          pfi (i+1) "";
          for r = 0 to nstates - 1 do
            if not swap || r <= l then begin
              let st =
                Option.value ~default:tmp (List.assoc_opt (l, r) map)
              in
              pf "%d," st
            end
          done;
          pf "\n"
        done;
        pfi i "};\n"
    | IfThen {cif; cthen; _} ->
        gen cif;
        Option.iter gen cthen
    | Return _ -> ()
  in
  gen c.code
(* Emit the [case] label of the opn() switch for [(op, c)], followed by
 * the C statements returning the result state for operand states [l]
 * and [r].  Table names are numbered in the same visiting order as in
 * [gen_tables].  When [no_swap] is set the operands were already
 * ordered before the switch, so no per-case swap is emitted. *)
let emit_case oc pfx no_swap (op, c) =
  let fpf = Printf.fprintf in
  let pf m = fpf oc m in
  let pfi n m = indent oc n; pf m in
  (* printers used through the %a conversions below *)
  let side oc = function
    | L -> fpf oc "l"
    | R -> fpf oc "r"
  in
  let pred oc (s, p) =
    match p with
    | InBitSet bs -> fpf oc "BIT(%a) & %#Lx" side s bs
    | Eq id -> fpf oc "%a == %d" side s id
    | Ge id -> fpf oc "%d <= %a" id side s
  in
  let base = pfx ^ show_op op in
  let swap = c.swap in
  let ntables = ref 0 in
  let rec code i c =
    match c with
    | Return id -> pfi i "return %d;\n" id
    | Table _ ->
        let name =
          if !ntables = 0 then base else
          base ^ string_of_int !ntables
        in
        incr ntables;
        if swap then
          (* row l of the triangular table starts at l*(l+1)/2 *)
          pfi i "return %stbl[(l + l*l)/2 + r];\n" name
        else pfi i "return %stbl[l][r];\n" name
    | IfThen ({test = And (And (t1, t2), t3); _} as r) ->
        (* re-associate to the right so that the head is a Pred *)
        code i @@ IfThen
          {r with test = And (t1, And (t2, t3))}
    | IfThen {test = And (Pred p, t); cif; cthen} ->
        (* one single-line if per conjunct guards the innermost block *)
        pfi i "if (%a)\n" pred p;
        code i (IfThen {test = t; cif; cthen})
    | IfThen {test = Pred p; cif; cthen} ->
        pfi i "if (%a) {\n" pred p;
        code (i+1) cif;
        pfi i "}\n";
        Option.iter (code i) cthen
  in
  pfi 1 "case %s:\n" (show_op op);
  if not no_swap && c.swap then
    emit_swap oc 2;
  code 2 c.code
(* Print the strings [List.map f l] on [oc], separated by [sep], and
 * break lines so that they stay within column [limit] where possible
 * (a tab counts as 8 columns).  [col] is the column printing starts
 * at; continuation lines are indented with [i] tabs.  At a break, the
 * separator ends the line (minus one trailing space) or, when
 * [cut_before_sep] is set, starts the next one (minus one leading
 * space).  Every line holds at least one item, so an item too wide
 * for the limit is printed on its own line instead of leaving an
 * empty line part (a dangling separator) or looping forever. *)
let emit_list
    ?(limit=60) ?(cut_before_sep=false)
    ~col ~indent:i ~sep ~f oc l =
  let sl = String.length sep in
  (* [sep] without one trailing space, and its length *)
  let rstripped_sep, rssl =
    if sl > 0 && sep.[sl - 1] = ' ' then
      String.sub sep 0 (sl - 1), sl - 1
    else sep, sl
  in
  (* [sep] without one leading space, and its length *)
  let lstripped_sep, lssl =
    if sl > 0 && sep.[0] = ' ' then
      String.sub sep 1 (sl - 1), sl - 1
    else sep, sl
  in
  (* Split off the items that fit on a line started at [col]; the
   * first item is always taken so that printing makes progress. *)
  let rec line col acc = function
    | [] -> (List.rev acc, [])
    | s :: l ->
        let col = col + sl + String.length s in
        let no_space =
          if cut_before_sep || l = [] then
            col > limit
          else
            col + rssl > limit
        in
        if no_space && acc <> [] then
          (List.rev acc, s :: l)
        else
          line col (s :: acc) l
  in
  let rec go col l =
    if l <> [] then begin
      let ll, rest = line col [] l in
      output_string oc (String.concat sep ll);
      if rest = [] then ()
      else if cut_before_sep then begin
        output_string oc "\n";
        indent oc i;
        output_string oc lstripped_sep;
        go (8*i + lssl) rest
      end else begin
        output_string oc rstripped_sep;
        output_string oc "\n";
        indent oc i;
        go (8*i) rest
      end
    end
  in
  go col (List.map f l)
(* Emit the C code of the state numberer for [n] on [opts.oc]:
 *   - <pfx>opn(op, l, r): result state of operation [op] on operands
 *     in states [l] and [r], built from [cgen_case];
 *   - <pfx>refn(r, tn, con): state of a reference; temporaries read
 *     tn[].n (set to the temporary state while still 0), constants
 *     are classified by their value;
 *   - <pfx>match[]: for each state, the patterns listed as [Top] in
 *     its point.
 * The three definitions are declared static when [opts.static] is
 * set. *)
let emit_numberer opts n =
  let oc = opts.oc in
  let pf m = Printf.fprintf oc m in
  let storage () = if opts.static then pf "static " in
  let tmp = (atom_state n Tmp).id in
  let con = (atom_state n AnyCon).id in
  let nst = Array.length n.states in
  let cases =
    StateMap.by_ops n.statemap |>
    List.map (fun (op, map) ->
      (op, cgen_case tmp nst map))
  in
  (* opn(): lookup tables, then a switch over the operations *)
  let emit_opn () =
    (* without any case there is nothing to swap for, and no [t]
     * would be declared for the swap code to use *)
    let all_swap =
      match cases with
      | [] -> false
      | _ -> List.for_all (fun (_, c) -> c.swap) cases
    in
    storage ();
    pf "int\n";
    pf "%sopn(int op, int l, int r)\n" opts.pfx;
    pf "{\n";
    cases |> List.iter (gen_tables oc tmp opts.pfx nst);
    if List.exists (fun (_, c) -> c.swap) cases then
      pf "\tint t;\n\n";
    (* when every case is symmetric, order the operands only once *)
    if all_swap then emit_swap oc 1;
    pf "\tswitch (op) {\n";
    cases |> List.iter (emit_case oc opts.pfx all_swap);
    pf "\tdefault:\n";
    pf "\t\treturn %d;\n" tmp;
    pf "\t}\n";
    pf "}\n\n"
  in
  (* refn(): state of a temporary or constant reference *)
  let emit_refn () =
    let cons =
      List.filter_map (function
        | (Con c, s) -> Some (c, s.id)
        | _ -> None)
        n.atoms
    in
    storage ();
    pf "int\n";
    pf "%srefn(Ref r, Num *tn, Con *con)\n" opts.pfx;
    pf "{\n";
    if cons <> [] then
      pf "\tint64_t n;\n\n";
    pf "\tswitch (rtype(r)) {\n";
    pf "\tcase RTmp:\n";
    if tmp <> 0 then begin
      (* the emitted code treats a 0 in tn[].n as "not numbered yet",
       * so state 0 must exist and no temp should ever get it *)
      assert
        (List.exists (fun (_, s) -> s.id = 0) n.atoms
         && List.for_all (fun (a, s) ->
              s.id <> 0 ||
              (match a with
               | AnyCon | Con _ -> true
               | _ -> false))
              n.atoms);
      pf "\t\tif (!tn[r.val].n)\n";
      pf "\t\t\ttn[r.val].n = %d;\n" tmp
    end;
    pf "\t\treturn tn[r.val].n;\n";
    pf "\tcase RCon:\n";
    if cons <> [] then begin
      pf "\t\tif (con[r.val].type != CBits)\n";
      pf "\t\t\treturn %d;\n" con;
      pf "\t\tn = con[r.val].bits.i;\n";
      cons |> inverse |> group_by_fst
      |> List.iter (fun (id, cs) ->
           pf "\t\tif (";
           emit_list ~cut_before_sep:true
             ~col:20 ~indent:2 ~sep:" || "
             ~f:(fun c -> "n == " ^ Int64.to_string c)
             oc cs;
           pf ")\n";
           pf "\t\t\treturn %d;\n" id)
    end;
    pf "\t\treturn %d;\n" con;
    pf "\tdefault:\n";
    pf "\t\treturn INT_MIN;\n";
    pf "\t}\n";
    pf "}\n\n"
  in
  (* match[]: patterns per state *)
  let emit_match () =
    storage ();
    pf "bits %smatch[%d] = {\n" opts.pfx nst;
    n.states |> Array.iteri (fun sn s ->
      let tops =
        List.filter_map (function
          | Top ("$" | "%") -> None
          | Top r -> Some ("BIT(P" ^ r ^ ")")
          | _ -> None) s.point |> setify
      in
      if tops <> [] then
        pf "\t[%d] = %s,\n"
          sn (String.concat " | " tops));
    pf "};\n\n"
  in
  emit_opn ();
  emit_refn ();
  emit_match ()
(* Position of [f] in the variable list [vars] (0 for the first
 * element, first occurrence wins); this is the operand of the "set"
 * opcode in the matcher bytecode.
 * @raise Not_found if [f] does not occur in [vars]. *)
let var_id vars f =
  let rec index_from i = function
    | [] -> raise Not_found
    | x :: rest -> if compare x f = 0 then i else index_from (i + 1) rest
  in
  index_from 0 vars
(* Compile the matching action [act] to the byte list stored in
 * matcher[].  Opcodes produced below:
 *   0         stop
 *   1 / 2     push (1 when the Push flag is set)
 *   3 v       set variable number v (see var_id)
 *   4         pop
 *   5 n ...   switch: n (state, offset) pairs, a default offset,
 *             then the case bodies; offsets count from the 5 opcode
 *   10 + pc   refers back to the code already emitted at absolute
 *             offset pc for the same action (presumably a jump)
 * Every action other than Stop is recorded by id once compiled, so
 * later occurrences only emit the back reference. *)
let compile_action vars act =
  let pcs = Hashtbl.create 100 in
  let rec gen pc (act: Action.t) =
    match Hashtbl.find_opt pcs act.id with
    | Some prev -> [10 + prev]
    | None ->
        let code =
          match act.node with
          | Action.Stop -> [0]
          | Action.Push (sym, k) ->
              let opcode = if sym then 1 else 2 in
              opcode :: gen (pc + 1) k
          | Action.Set (v, {node = Action.Pop k; _})
          | Action.Set (v, ({node = Action.Stop; _} as k)) ->
              let slot = var_id vars v in
              3 :: slot :: gen (pc + 2) k
          | Action.Set _ ->
              (* for now, only atomic patterns can be
               * tied to a variable, so Set must be
               * followed by either Pop or Stop *)
              assert false
          | Action.Pop k -> 4 :: gen (pc + 1) k
          | Action.Switch cases -> gen_switch pc cases
        in
        (match act.node with
         | Action.Stop -> ()
         | _ -> Hashtbl.replace pcs act.id pc);
        code
  and gen_switch pc cases =
    (* group the states by target action, largest group first *)
    let groups =
      inverse cases |> group_by_fst |>
      List.sort (fun (_, cs1) (_, cs2) ->
        compare (List.length cs2) (List.length cs1))
    in
    (* the largest group is the default, emitted last; the other
     * groups are listed explicitly *)
    let explicit = List.rev (List.tl groups)
    and fallback = fst (List.hd groups) in
    let npairs =
      List.fold_left (fun acc (_, cs) -> acc + List.length cs)
        0 explicit
    in
    (* opcode, count, (state, offset) pairs, default offset *)
    let body_off = 2 + 2 * npairs + 1 in
    let pc, tbl, body =
      List.fold_left
        (fun (pc, tbl, body) (a, cs) ->
          let ofs = body_off + List.length body in
          let case = gen pc a in
          let entries = List.flatten (List.map (fun c -> [c; ofs]) cs) in
          (pc + List.length case, tbl @ entries, body @ case))
        (pc + body_off, [], [])
        explicit
    in
    let tbl = tbl @ [body_off + List.length body] in
    assert (2 + List.length tbl = body_off);
    [5; npairs] @ tbl @ body @ gen pc fallback
  in
  gen 0 act
(* Emit <pfx>matcher[]: for every (vars, pname, action) triple of [ms],
 * the bytecode produced by [compile_action], stored as a uchar array
 * at index P<pname>.  The array is declared static when [opts.static]
 * is set. *)
let emit_matchers opts ms =
  let pf m = Printf.fprintf opts.oc m in
  if opts.static then pf "static ";
  pf "uchar *%smatcher[] = {\n" opts.pfx;
  List.iter (fun (vars, pname, m) ->
      pf "\t[P%s] = (uchar[]){\n" pname;
      pf "\t\t";
      let bytes = compile_action vars m in
      (* the bytes land in a uchar array: refuse values that would
       * wrap modulo 256 (offsets, counts, back references), like
       * gen_tables does for state numbers *)
      assert (List.for_all (fun b -> 0 <= b && b <= 255) bytes);
      emit_list
        ~col:16 ~indent:2 ~sep:","
        ~f:string_of_int opts.oc bytes;
      pf "\n";
      pf "\t},\n")
    ms;
  pf "};\n\n"
(* Emit the C numberer (opn, refn and match[]) for [n] as configured by
 * [opts]; emit_matchers is not called from here. *)
let emit_c opts n =
  n |> emit_numberer opts