2 * Copyright (C) 2024, 2025 Mikulas Patocka
4 * This file is part of Ajla.
6 * Ajla is free software: you can redistribute it and/or modify it under the
7 * terms of the GNU General Public License as published by the Free Software
8 * Foundation, either version 3 of the License, or (at your option) any later version.
11 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License along with
16 * Ajla. If not, see <https://www.gnu.org/licenses/>.
19 private unit compiler.optimize.utils;
21 uses compiler.optimize.defs;
// Exported declarations of this unit.
// Helpers that work on a "function specifier" encoded at the start of a pcode word list:
// function_pcode loads the pcode of the function the specifier refers to.
23 fn function_pcode(params : list(pcode_t)) : list(pcode_t);
// function_name returns the name blob stored in that function's pcode.
24 fn function_name(params : list(pcode_t)) : bytes;
// function_extract_nested narrows a pcode list to a nested function selected by fn_idx.
25 fn function_extract_nested(pc : list(pcode_t), fn_idx : list(int)) : list(pcode_t);
// function_specifier_length returns how many words the specifier occupies.
26 fn function_specifier_length(params : list(pcode_t)) : int;
// function_load decodes a specifier into (length, function record); function_store encodes one.
27 fn function_load(params : list(pcode_t)) : (int, function);
28 fn function_store(f : function) : list(pcode_t);
// create_instr builds an instruction record with its per-opcode parameter sets.
30 fn create_instr(opcode : pcode_t, params : list(pcode_t), bgi : int) : instruction;
// load_function_context decodes a whole function's pcode into a context with basic blocks.
31 fn load_function_context(pc : list(pcode_t)) : context;
// dump_basic_blocks turns the basic blocks of a context back into a pcode list.
32 fn dump_basic_blocks(ctx : context, dump_it : bool) : list(pcode_t);
37 uses compiler.common.blob;
38 uses compiler.common.evaluate;
39 uses compiler.parser.util;
// Returns the number of pcode words taken by the function specifier that
// starts at params[0]: one leading word, a blob of `bl` words starting at
// params[1], one count word at params[1 + bl], and then as many further
// words as that count says.
41 fn function_specifier_length(params : list(pcode_t)) : int
// `bl` is the length of the blob that begins at params[1].
43 var bl := blob_length(params[1 .. ]);
44 return 1 + bl + 1 + params[1 + bl];
// Decodes the function specifier at the start of `params`.
// Returns the specifier length (as computed by function_specifier_length)
// together with the decoded `function` record `r`.
// Layout read here: params[0] holds path_idx shifted left by one with the
// `program` flag in bit 0; a blob starting at params[1] is loaded into `un`;
// params[1 + bl] is a count that is followed by that many fn_idx entries.
// NOTE(review): the line that opens the record assigned to `r` is not
// visible in this view.
47 fn function_load(params : list(pcode_t)) : (int, function)
49 var bl := blob_length(params[1 .. ]);
51 path_idx : params[0] shr 1,
52 program : (params[0] and 1) <> 0,
53 un : blob_load(params[1 .. ]),
// fn_idx is pre-sized with zeros and filled by the loop below.
54 fn_idx : fill(0, params[1 + bl]),
// Copy the entries that follow the count word.
56 for i := 0 to params[1 + bl] do
57 r.fn_idx[i] := params[1 + bl + 1 + i];
58 return function_specifier_length(params), r;
// Encodes the `function` record `f` as a list of pcode words, using the same
// layout that function_load reads: a first word built from
// (f.path_idx shl 1) plus select(f.program, 0, 1), then the blob for f.un,
// then the number of fn_idx entries.
// NOTE(review): the body of the final loop and the return statement are not
// visible in this view; presumably the loop appends each fn_idx entry.
61 fn function_store(f : function) : list(pcode_t)
63 var pc := list(pcode_t).[ (f.path_idx shl 1) + select(f.program, 0, 1) ];
64 pc += blob_store(f.un);
65 pc += list(pcode_t).[ len(f.fn_idx) ];
66 for i := 0 to len(f.fn_idx) do
// Loads the pcode of the function described by the specifier in `params`:
// the specifier is decoded with function_load and the result is passed to
// load_optimized_pcode.
71 fn function_pcode(params : list(pcode_t)) : list(pcode_t)
// `l` (the specifier length) is not used by the lines visible here.
73 var l, f := function_load(params);
// `fi` starts with the number of fn_idx entries.
// NOTE(review): the loop body that fills the rest of `fi` is not visible in this view.
74 var fi := list(pcode_t).[ len(f.fn_idx) ];
75 for i := 0 to len(f.fn_idx) do
77 return load_optimized_pcode(f.path_idx, f.un, f.program, fi, false);
// Returns the name of the function identified by the specifier in `params`.
// The name is the blob stored at offset 9 of that function's pcode, the same
// location load_function_context reads its `name` field from.
80 fn function_name(params : list(pcode_t)) : bytes
82 var pc := function_pcode(params);
83 return blob_load(pc[9 .. ]);
// Narrows `pc` to the pcode of a nested function selected by `fn_idx`.
// The first element of fn_idx is dropped; each remaining index is consumed
// in turn and `pc` is replaced by a slice of itself.
// Aborts with an internal error when an index is too high (the message
// compares it against pc[2]).
// NOTE(review): the bounds test, the stepping of `ptr` and the return
// statement are not visible in this view.
86 fn function_extract_nested(pc : list(pcode_t), fn_idx : list(int)) : list(pcode_t)
88 fn_idx := fn_idx[1 .. ];
89 for idx in list_consumer(fn_idx) do [
91 abort internal("function_extract_nested: too high function index: " + ntos(idx) + " >= " + ntos(pc[2]));
// `ptr` starts right after the 9 leading words and the blob at offset 9.
92 var ptr := 9 + blob_length(pc[9 .. ]);
// pc[ptr] is a length word; the slice taken is the pc[ptr] words that follow it.
97 pc := pc[ptr + 1 .. ptr + 1 + pc[ptr]];
// Decodes the structured entries of an instruction's parameter list.
// Loops params[0] times starting at offset `offs`, dispatching on the entry
// code (Structured_Record, Structured_Option, Structured_Array) and aborting
// with an internal error on any other code.
// Returns an int and two param_set values; create_instr receives them as
// (xlen, pmask, lmask).
// NOTE(review): the branch bodies and the return statement are not visible
// in this view, so how `ps` and `ls` are filled cannot be confirmed here.
103 fn decode_structured_params(offs : int, params : list(pcode_t)) : (int, param_set, param_set)
105 var ps : param_set := 0;
106 var ls : param_set := 0;
107 for i := 0 to params[0] do [
108 var scode := params[offs];
109 if scode = Structured_Record then [
112 ] else if scode = Structured_Option then [
114 ] else if scode = Structured_Array then [
119 abort internal("invalid structured type");
// Builds the instruction record for `opcode` with parameter list `params`.
// For every known opcode it fills the instruction's parameter sets
// (read_set, free_set, write_set, lt_set, conflict_1, conflict_2); the bit
// numbers in these sets are indexes into `params` (the code after the opcode
// chain reads ins.params[s] for each set bit s).
// `xlen` is the parameter count expected for the opcode; opcodes with a
// variable layout compute it from count words inside `params`.
// Aborts with an internal error on an unknown opcode or when xlen differs
// from len(params).
// NOTE(review): many short lines (the record initializer, most fixed
// `xlen := ...` assignments, closing brackets, the return) are not visible
// in this view; `bgi` is presumably stored in the record - confirm.
125 fn create_instr(opcode : pcode_t, params : list(pcode_t), bgi : int) : instruction
128 var ins := instruction.[
// Per-opcode table. `0 bts a bts b` denotes a set with bits a and b.
140 if opcode = P_BinaryOp then [
141 ins.read_set := 0 bts 3 bts 5;
142 ins.free_set := ins.read_set;
143 ins.write_set := 0 bts 1;
145 ] else if opcode = P_BinaryConstOp then [
146 ins.read_set := 0 bts 3;
147 ins.free_set := ins.read_set;
148 ins.write_set := 0 bts 1;
150 ] else if opcode = P_UnaryOp then [
151 ins.read_set := 0 bts 3;
152 ins.free_set := ins.read_set;
153 ins.write_set := 0 bts 1;
155 ] else if opcode = P_Copy then [
156 ins.read_set := 0 bts 2;
157 ins.free_set := ins.read_set;
158 ins.write_set := 0 bts 0;
160 ] else if opcode = P_Copy_Type_Cast then [
161 ins.read_set := 0 bts 2;
162 ins.free_set := ins.read_set;
163 ins.write_set := 0 bts 0;
165 ] else if opcode = P_Free then [
166 ins.read_set := 0 bts 0;
168 ] else if opcode = P_Eval then [
169 ins.read_set := 0 bts 0;
171 ] else if opcode = P_Keep then [
172 ins.read_set := 0 bts 0;
174 ] else if opcode = P_Fn then [
175 ins.write_set := 0 bts 0;
176 xlen := 3 + params[1] + params[2];
177 for i := 3 to xlen do [
182 ] else if opcode = P_Load_Local_Type then [
183 ins.write_set := 0 bts 0;
185 ] else if opcode = P_Load_Fn then [
// A function specifier of `l` words starts at params[3]; params[1] argument pairs follow it.
186 var l := function_specifier_length(params[3 .. ]);
187 for i := 0 to params[1] do [
188 ins.read_set bts= 3 + l + i * 2 + 1;
189 ins.free_set bts= 3 + l + i * 2 + 1;
191 ins.write_set := 0 bts 0;
192 xlen := 3 + l + params[1] * 2;
193 ] else if opcode = P_Curry then [
194 ins.read_set := 0 bts 3;
195 ins.free_set := 0 bts 3;
196 for i := 0 to params[1] do [
197 ins.read_set bts= 4 + i * 2 + 1;
198 ins.free_set bts= 4 + i * 2 + 1;
200 ins.write_set := 0 bts 0;
201 xlen := 4 + params[1] * 2;
202 ] else if opcode = P_Call then [
// params[2] argument pairs follow the specifier, then params[1] written slots.
203 var l := function_specifier_length(params[3 .. ] );
204 for i := 0 to params[2] do [
205 ins.read_set bts= 3 + l + i * 2 + 1;
206 ins.free_set bts= 3 + l + i * 2 + 1;
208 for i := 0 to params[1] do
209 ins.write_set bts= 3 + l + params[2] * 2 + i;
210 xlen := 3 + l + params[2] * 2 + params[1];
211 ] else if opcode = P_Call_Indirect then [
212 ins.read_set := 0 bts 4;
213 ins.free_set := 0 bts 4;
214 for i := 0 to params[2] do [
215 ins.read_set bts= 5 + i * 2 + 1;
216 ins.free_set bts= 5 + i * 2 + 1;
218 for i := 0 to params[1] do
219 ins.write_set bts= 5 + params[2] * 2 + i;
220 xlen := 5 + params[2] * 2 + params[1];
221 ] else if opcode = P_Load_Const then [
222 ins.write_set := 0 bts 0;
223 xlen := 1 + blob_length(params[1 .. ]);
224 ] else if opcode = P_Structured_Write then [
225 ins.read_set := 0 bts 3 bts 5;
226 ins.free_set := 0 bts 3 bts 5;
227 ins.write_set := 0 bts 1;
228 var pmask lmask : param_set;
// The structured entries start at offset 6; pmask is merged into read_set and conflict_1.
229 xlen, pmask, lmask := decode_structured_params(6, params);
230 ins.read_set or= pmask;
232 ins.conflict_1 := 0 bts 5 or pmask;
233 ins.conflict_2 := 0 bts 1;
234 ] else if opcode = P_Record_Type or opcode = P_Option_Type then [
235 for i := 0 to params[1] do
236 ins.read_set bts= 2 + i;
237 ins.write_set := 0 bts 0;
238 var l := function_specifier_length(params[2 + params[1] .. ]);
239 xlen := 2 + params[1] + l;
240 ] else if opcode = P_Record_Create then [
241 ins.write_set := 0 bts 0;
242 for i := 0 to params[1] do [
243 ins.read_set bts= 2 + i * 2 + 1;
244 ins.free_set bts= 2 + i * 2 + 1;
246 xlen := 2 + params[1] * 2;
247 ] else if opcode = P_Record_Load_Slot then [
248 ins.read_set := 0 bts 1;
249 ins.write_set := 0 bts 0;
251 ] else if opcode = P_Record_Load then [
252 ins.read_set := 0 bts 2;
253 ins.write_set := 0 bts 0;
256 ] else if opcode = P_Option_Create then [
257 ins.read_set := 0 bts 3;
258 ins.free_set := 0 bts 3;
259 ins.write_set := 0 bts 0;
261 ] else if opcode = P_Option_Load then [
262 ins.read_set := 0 bts 2;
263 ins.write_set := 0 bts 0;
266 ] else if opcode = P_Option_Test then [
267 ins.read_set := 0 bts 1;
268 ins.write_set := 0 bts 0;
270 ] else if opcode = P_Option_Ord then [
271 ins.read_set := 0 bts 1;
272 ins.write_set := 0 bts 0;
274 ] else if opcode = P_Array_Flexible then [
275 ins.read_set := 0 bts 1;
276 ins.write_set := 0 bts 0;
278 ] else if opcode = P_Array_Fixed then [
279 ins.read_set := 0 bts 1 bts 2;
280 ins.write_set := 0 bts 0;
282 ] else if opcode = P_Array_Create then [
283 ins.read_set := 0 bts 3;
284 ins.write_set := 0 bts 0;
285 for i := 0 to params[2] do [
286 ins.read_set bts= 4 + i * 2 + 1;
287 ins.free_set bts= 4 + i * 2 + 1;
289 ins.lt_set := 0 bts 1;
290 xlen := 4 + params[2] * 2;
291 ] else if opcode = P_Array_Fill then [
292 ins.read_set := 0 bts 3 bts 4;
293 ins.free_set := 0 bts 3;
294 ins.write_set := 0 bts 0;
295 ins.lt_set := 0 bts 1;
297 ] else if opcode = P_Array_String then [
298 ins.write_set := 0 bts 0;
299 xlen := 1 + blob_length(params[1 .. ]);
300 ] else if opcode = P_Array_Unicode then [
301 ins.write_set := 0 bts 0;
303 ] else if opcode = P_Array_Load then [
304 ins.read_set := 0 bts 2 bts 3;
305 ins.write_set := 0 bts 0;
308 ] else if opcode = P_Array_Len then [
309 ins.read_set := 0 bts 1;
310 ins.write_set := 0 bts 0;
312 ] else if opcode = P_Array_Len_Greater_Than then [
313 ins.read_set := 0 bts 1 bts 2;
314 ins.write_set := 0 bts 0;
316 ] else if opcode = P_Array_Sub then [
317 ins.read_set := 0 bts 2 bts 3 bts 4;
318 ins.free_set := 0 bts 2;
319 ins.write_set := 0 bts 0;
321 ] else if opcode = P_Array_Skip then [
322 ins.read_set := 0 bts 2 bts 3;
323 ins.free_set := 0 bts 2;
324 ins.write_set := 0 bts 0;
326 ] else if opcode = P_Array_Append then [
327 ins.read_set := 0 bts 2 bts 4;
328 ins.free_set := 0 bts 2 bts 4;
329 ins.write_set := 0 bts 0;
331 ] else if opcode = P_Array_Append_One then [
332 ins.read_set := 0 bts 2 bts 4;
333 ins.free_set := 0 bts 2 bts 4;
334 ins.write_set := 0 bts 0;
336 ] else if opcode = P_Array_Flatten then [
337 ins.read_set := 0 bts 2;
338 ins.free_set := 0 bts 2;
339 ins.write_set := 0 bts 0;
341 ] else if opcode = P_Jmp then [
343 ] else if opcode = P_Jmp_False then [
344 ins.read_set := 0 bts 0;
346 ] else if opcode = P_Label then [
348 ] else if opcode = P_IO then [
// The first params[1] slots from index 4 are written, the next params[2] slots are read.
349 for i := 0 to params[1] do [
350 ins.write_set bts= 4 + i;
352 for i := params[1] to params[1] + params[2] do [
353 ins.read_set bts= 4 + i;
355 ins.conflict_1 := ins.read_set;
356 ins.conflict_2 := ins.write_set;
357 xlen := 4 + params[1] + params[2] + params[3];
358 ] else if opcode = P_Args then [
359 for i := 0 to len(params) do
360 ins.write_set bts= i;
362 ] else if opcode = P_Return_Vars then [
363 for i := 0 to len(params) do
364 ins.write_set bts= i;
366 ] else if opcode = P_Return then [
368 while i < len(params) do [
374 ] else if opcode = P_Assume then [
375 ins.read_set := 0 bts 0;
377 ] else if opcode = P_Claim then [
378 ins.read_set := 0 bts 0;
380 ] else if opcode = P_Invariant then [
381 ins.read_set := 0 bts 0;
383 ] else if opcode = P_Checkpoint then [
385 ] else if opcode = P_Line_Info then [
386 if params[0] < 0 then
387 abort internal("P_Line_Info: negative line info");
389 ] else if opcode = P_Phi then [
390 ins.write_set := 0 bts 0;
391 for i := 1 to len(params) do
395 abort internal("invalid opcode");
// Sanity check: the length computed for this opcode must equal the actual parameter count.
398 if xlen <> len(params) then
399 abort internal("length mismatch on opcode " + ntos(opcode) + ": " + ntos(xlen) + " <> " + ntos(len(params)));
// Walk the bits of read_set; where the parameter value is negative the bit is
// removed from both conflict sets.
// NOTE(review): the loop header and the remaining statements of this loop are not visible here.
401 var rs := ins.read_set;
403 var s : int := bsr rs;
405 if ins.params[s] < 0 then [
407 ins.conflict_1 btr= s;
408 ins.conflict_2 btr= s;
// Same kind of walk over lt_set.
// NOTE(review): what is done for negative parameters here is not visible in this view.
411 var ls := ins.lt_set;
413 var s : int := bsr ls;
415 if ins.params[s] < 0 then [
// Adds a control-flow edge from block `src` to block `dst` in `ctx`.
// Three lists are extended: dst's pred_position gets the index the new edge
// will have in src's post_list (its length before the append), dst's
// pred_list gets `src`, and src's post_list gets `dst`.
// NOTE(review): the `{` on the next line appears to start disabled debug
// code; its end and the function's return are not visible in this view.
424 fn set_arrow(ctx : context, src : int, dst : int) : context
426 {var sblk := ctx.blocks[src];
427 for i := 0 to len(sblk.instrs) do [
428 eval debug("pcode: " + ntos(sblk.instrs[i].opcode));
// This must run before the post_list append below so the recorded position is the new entry's index.
430 ctx.blocks[dst].pred_position +<= len(ctx.blocks[src].post_list);
431 ctx.blocks[dst].pred_list +<= src;
432 ctx.blocks[src].post_list +<= dst;
433 //eval debug("arrow from " + ntos(src) + " to " + ntos(dst) + " total " + ntos(len(ctx.blocks)));
// Decodes the pcode of one function (`pc`) into a `context`.
// Header words used here: pc[2] and pc[3] are loop counts (pc[3] entries are
// appended to ctx.local_types), pc[4] is the number of variables, pc[8] is
// the size of label_to_block, and the name blob starts at pc[9].
// After the header come the local type descriptors, the variable table and
// the instruction stream. Instructions are split into basic blocks and the
// blocks are then linked with set_arrow.
// Aborts with an internal error on an unknown local type, an invalid runtime
// type, a label defined twice, a stream that overruns `pc`, or an unfinished
// last block.
// NOTE(review): many short lines (pointer increments, block appends, closing
// brackets, the return) are not visible in this view.
437 fn load_function_context(pc : list(pcode_t)) : context
440 local_types : empty(local_type),
441 instrs : empty(instruction),
442 blocks : empty(basic_block),
// These fields start as exception values (error_record_field_not_initialized);
// variables and label_to_block are assigned further down.
444 variables : exception_make(list(variable), ec_sync, error_record_field_not_initialized, 0, false),
445 label_to_block : exception_make(list(int), ec_sync, error_record_field_not_initialized, 0, false),
446 var_map : exception_make(list(int), ec_sync, error_record_field_not_initialized, 0, false),
447 cm : exception_make(conflict_map, ec_sync, error_record_field_not_initialized, 0, false),
448 should_retry : exception_make(bool, ec_sync, error_record_field_not_initialized, 0, false),
450 name : blob_load(pc[9 .. ]),
// `ptr` starts right after the 9 leading words and the name blob.
453 var ptr := 9 + blob_length(pc[9 .. ]);
// NOTE(review): the body of this pc[2]-count loop is not visible; presumably it skips nested functions.
455 for i := 0 to pc[2] do [
// Decode pc[3] local type descriptors, dispatching on the type code `ft`.
459 for i := 0 to pc[3] do [
463 if ft = Local_Type_Record then [
464 var n, f := function_load(pc[ptr .. ]);
466 lt := local_type.rec.(f);
467 ] else if ft = Local_Type_Flat_Record then [
468 var non_flat_rec := pc[ptr];
469 var n_entries := pc[ptr + 1];
470 lt := local_type.flat_rec.(local_type_flat_record.[ non_flat_record : non_flat_rec, flat_types : empty(int) ]);
472 for j := 0 to n_entries do [
473 lt.flat_rec.flat_types +<= pc[ptr];
476 ] else if ft = Local_Type_Flat_Array then [
477 lt := local_type.flat_array.(local_type_flat_array.[ flat_type : pc[ptr], number_of_elements : pc[ptr + 1] ]);
480 abort internal("unknown local type " + ntos(ft));
482 ctx.local_types +<= lt;
// Variable table: four fixed words per variable followed by a name blob.
485 var n_variables := pc[4];
486 ctx.variables := fill(new_variable, n_variables);
// -1 marks a label that has not been defined yet.
488 ctx.label_to_block := fill(-1, pc[8]);
490 for i := 0 to n_variables do [
491 ctx.variables[i].type_index := pc[ptr];
492 ctx.variables[i].runtime_type := pc[ptr + 1];
493 ctx.variables[i].local_type := -1;
494 ctx.variables[i].color := pc[ptr + 2];
495 ctx.variables[i].must_be_flat := pc[ptr + 3] bt bsf VarFlag_Must_Be_Flat;
496 ctx.variables[i].must_be_data := pc[ptr + 3] bt bsf VarFlag_Must_Be_Data;
497 ctx.variables[i].is_option_type := false;
499 ctx.variables[i].name := blob_load(pc[ptr .. ]);
500 ptr += blob_length(pc[ptr .. ]);
502 if ctx.variables[i].runtime_type < T_Undetermined then
503 abort internal("load_function_context: invalid runtime type: " + ctx.name + ", " + ntos(i) + "(" + ctx.variables[i].name + "): " + ntos(ctx.variables[i].runtime_type));
// Instruction stream: pc[ptr] is the opcode, pc[ptr + 1] the parameter count, the parameters follow.
506 var b := new_basic_block;
508 while ptr < len(pc) do [
509 var instr_len := pc[ptr + 1] + 2;
510 var ins := create_instr(pc[ptr], pc[ptr + 2 .. ptr + instr_len], len(ctx.blocks));
// For every bit in free_set, Flag_Free_Argument is cleared in the parameter just before that position.
512 var free_set : param_set := ins.free_set;
513 while free_set <> 0 do [
514 var arg : int := bsr free_set;
516 ins.params[arg - 1] and= not Flag_Free_Argument;
// Jumps and returns end the current block; a label starts a new one when the current block is not empty.
519 if ins.opcode = P_Jmp or ins.opcode = P_Jmp_False or ins.opcode = P_Return then [
520 b.instrs +<= len(ctx.instrs);
523 b := new_basic_block;
524 ] else if ins.opcode = P_Label then [
525 if len_greater_than(int, b.instrs, 0) then [
527 b := new_basic_block;
530 if ctx.label_to_block[ins.params[0]] >= 0 then
531 abort internal("load_function_context: label already defined");
532 ctx.label_to_block[ins.params[0]] := len(ctx.blocks);
533 b.instrs +<= len(ctx.instrs);
// NOTE(review): no statement is visible for P_Free or P_Checkpoint; they appear to be left out of the block.
535 ] else if ins.opcode = P_Free then [
536 ] else if ins.opcode = P_Checkpoint then [
538 b.instrs +<= len(ctx.instrs);
544 if ptr > len(pc) then
545 abort internal("load_function_context: " + ctx.name + ": pcode doesn't match");
546 if len(b.instrs) > 0 then
547 abort internal("load_function_context: " + ctx.name + ": the last basic block is not finished");
// Second pass: link the blocks. An empty block falls through to block i + 1.
549 for i := 0 to len(ctx.blocks) do [
551 var block := ctx.blocks[i];
552 if len(block.instrs) = 0 then [
553 ctx := set_arrow(ctx, i, i + 1);
// A leading P_Label is removed from the block's instruction list.
556 var first := ctx.instrs[ block.instrs[0] ];
557 if first.opcode = P_Label then [
558 ctx.blocks[i].instrs := block.instrs[1 .. ];
561 var last := ctx.instrs[ block.instrs[ len(block.instrs) - 1 ] ];
// P_Jmp becomes a single edge and the jump is dropped from the block.
// P_Jmp_False gets three edges: fall-through, params[1] target, params[2] target.
// Anything else except P_Return falls through to the next block.
563 if last.opcode = P_Jmp then [
564 var target := last.params[0];
565 ctx := set_arrow(ctx, i, ctx.label_to_block[target]);
566 ctx.blocks[i].instrs := block.instrs[ .. len(block.instrs) - 1];
567 ] else if last.opcode = P_Jmp_False then [
568 var target1 := last.params[1];
569 var target2 := last.params[2];
570 ctx := set_arrow(ctx, i, i + 1);
571 ctx := set_arrow(ctx, i, ctx.label_to_block[target1]);
572 ctx := set_arrow(ctx, i, ctx.label_to_block[target2]);
573 ] else if last.opcode <> P_Return then [
574 ctx := set_arrow(ctx, i, i + 1);
// Serializes the basic blocks of `ctx` back into a pcode word list `rpc`.
// Blocks are taken from `worklist`, a bit set that starts as 1 (only bit 0
// set); `done` is tested before a successor is added to the worklist.
// Labels and jump targets are written as the block index minus one.
// A debug line is also built per instruction, wrapped in ANSI color escapes
// (byte 27 followed by "[31m", "[35m" or "[32m", reset with "[0m").
// NOTE(review): many short lines are not visible in this view (conditions on
// `dump_it`, label emission, updates of `done` and `worklist`, the return),
// and the end of the function is not shown; `dump_it` presumably gates the
// debug output - confirm.
581 fn dump_basic_blocks(ctx : context, dump_it : bool) : list(pcode_t)
583 //dump_it := ctx.name = "main" or ctx.name = "fact";
585 eval debug("-----------------------------------------------------------------");
586 eval debug("dump_basic_blocks: " + ctx.name);
588 var rpc := empty(pcode_t);
589 var worklist : node_set := 1;
590 var done : node_set := 0;
591 while worklist <> 0 do [
592 var bgi : int := bsr worklist;
595 eval debug("process block from worklist " + ntos(bgi));
598 goto process_block_no_label;
604 eval debug("generate label " + ntos(bgi - 1));
605 process_block_no_label:
// Emit every instruction of the block; the parameter count is written before the parameters.
608 var block := ctx.blocks[bgi];
609 for ili := 0 to len(block.instrs) do [
610 var ins := ctx.instrs[block.instrs[ili]];
612 rpc +<= len(ins.params);
615 var instr_name := (pcode_name(ins.opcode) + " ")[ .. 20];
616 var msg := "instr: " + instr_name + " ";
// For the debug text, show the P_Jmp_False targets taken from the block's successor list.
617 if ins.opcode = P_Jmp_False then [
618 ins.params[1] := block.post_list[1] - 1;
619 ins.params[2] := block.post_list[2] - 1;
621 for i := 0 to len(ins.params) do
622 msg += " " + ntos(ins.params[i]);
// read_elided / write_elided start true and are set to false by the loops
// below; the exact condition on the elided lines is not visible here.
624 var read_elided := true;
625 var write_elided := true;
626 var read_set := ins.read_set;
627 if ins.opcode = P_Free then [
630 while read_set <> 0 do [
631 var x : int := bsf read_set;
633 var v := ins.params[x];
634 var va := ctx.variables[v];
635 msg += " r(" + ntos(v) + "," + ntos(va.type_index) + "," + ntos(va.runtime_type) + ":" + ntos(va.color) + ")";
636 if v < 0 or ctx.variables[v].color = -1 then
639 read_elided := false;
641 var write_set := ins.write_set;
642 var something_written := false;
643 while write_set <> 0 do [
644 var x : int := bsf write_set;
646 var v := ins.params[x];
647 var va := ctx.variables[v];
648 msg += " w(" + ntos(v) + "," + ntos(va.type_index) + "," + ntos(va.runtime_type) + ":" + ntos(va.color) + ")";
649 if ctx.variables[v].color = -1 then
652 write_elided := false;
653 something_written := true;
// For calls and function loads, append the callee name decoded from the specifier at params[3].
655 if ins.opcode = P_Call or ins.opcode = P_Load_Fn then [
656 msg += " " + function_name(ins.params[3 .. ]);
// Color choice: "[31m" when both flags are set, "[35m" when only writes are
// elided and something was written, "[32m" on the remaining visible path.
658 if read_elided and write_elided then
659 msg := bytes.[27] + "[31m" + msg +< 27 + "[0m";
660 else if write_elided and something_written then
661 msg := bytes.[27] + "[35m" + msg +< 27 + "[0m";
663 msg := bytes.[27] + "[32m" + msg +< 27 + "[0m";
670 if len(block.post_list) = 0 then
// Successors from index 1 on are queued if not done; the block must end with
// P_Jmp_False (5 words from the end of rpc) and its target words are rewritten.
672 for i := 1 to len(block.post_list) do [
673 var post_idx := block.post_list[i];
674 if not done bt post_idx then
675 worklist bts= post_idx;
676 if rpc[len(rpc) - 5] <> P_Jmp_False then
677 abort internal("a block with multiple outputs doesn't end with P_Jmp_False");
678 rpc[len(rpc) - 3 + i] := post_idx - 1;
// post_list[0] is the next block; if it is already done a jump to its label is generated.
680 var next_bgi := block.post_list[0];
681 if done bt next_bgi then [
684 rpc +<= next_bgi - 1;
686 eval debug("generating jump to label " + ntos(next_bgi - 1));
// With at most one predecessor, processing continues at the label-free entry point.
691 eval debug("process following block " + ntos(bgi));
692 if len(ctx.blocks[bgi].pred_list) <= 1 then
693 goto process_block_no_label;