verify: support phis
[ajla.git] / stdlib / compiler / optimize / inline.ajla
blob6264079cb4092911dd5c2f412007c33ff1bcad06
1 {*
2  * Copyright (C) 2024, 2025 Mikulas Patocka
3  *
4  * This file is part of Ajla.
5  *
6  * Ajla is free software: you can redistribute it and/or modify it under the
7  * terms of the GNU General Public License as published by the Free Software
8  * Foundation, either version 3 of the License, or (at your option) any later
9  * version.
10  *
11  * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13  * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License along with
16  * Ajla. If not, see <https://www.gnu.org/licenses/>.
17  *}
19 private unit compiler.optimize.inline;
21 uses compiler.optimize.defs;
// Exported entry point of this unit: inlines eligible calls found in ctx and
// returns the updated context. get_inline is the callback used to obtain the
// pcode of a function that is about to be inlined.
23 fn inline_functions(ctx : context, get_inline : fn(function, bool) : list(pcode_t)) : context;
25 implementation
27 uses compiler.optimize.utils;
// dump_graph: disabled debugging helper. The whole function is enclosed in a
// brace comment (the closing brace is not visible in this view - confirm), and
// the only references to it in this file are commented-out eval lines.
// For every block of ctx it builds one text record and passes it to debug():
// inactive blocks are reported as "******** block N - inactive"; for the other
// blocks the record lists post_list, pred_list and pred_position.
// It always returns true.
30 {fn dump_graph(ctx : context) : bool
32         for i := 0 to len(ctx.blocks) do [
33                 var str := empty(byte);
34                 var block := ctx.blocks[i];
35                 str += "******** block " + ntos(i);
36                 if not block.active then [
37                         str += " - inactive";
38                         eval debug(str);
39                         goto next;
40                 ]
41                 str += nl;
42                 str += "post_list:";
43                 for j := 0 to len(block.post_list) do
44                         str += " " + ntos(block.post_list[j]);
45                 str += nl;
46                 str += "pred_list:";
47                 for j := 0 to len(block.pred_list) do
48                         str += " " + ntos(block.pred_list[j]);
49                 str += nl;
50                 str += "pred_position:";
51                 for j := 0 to len(block.pred_position) do
52                         str += " " + ntos(block.pred_position[j]);
53                 eval debug(str);
54 next:
55         ]
56         return true;
// do_inline: process the call instruction at ctx.blocks[bgi].instrs[ili].
//
// Parameters:
//   ctx        - context of the calling function; it is modified and returned
//   bgi        - index of the block in ctx.blocks that contains the call
//   ili        - position of the call inside that block's instrs list
//   call_mode  - call mode of the call (inline_functions passes ins.params[0])
//   get_inline - callback that returns the pcode of the callee; it is invoked
//                as get_inline~save(f, true)
//
// Returns the context and a flag that is true when the context was changed.
//
// Behaviour by callee kind (new_pc[0] and Fn_Mask):
//   - anything other than Fn_Function, Fn_Record or Fn_Option: internal abort
//   - a non-Fn_Function callee with call_mode <> Call_Mode_Flat: nothing is
//     done and false is returned
//   - Fn_Option: only is_option_type / always_flat_option of the first result
//     variable are updated; no code is inlined
//   - otherwise the callee body is spliced into ctx: the calling block is split
//     at the call, every index stored in the callee context is rebased, the
//     Args and Return instructions are replaced by P_Copy instructions, and the
//     callee's local types, instructions, blocks and variables are appended
//     to ctx.
//
// Layout of the call's params as used here: params[1] is the number of result
// slots, params[2] the number of arguments, the function reference starts at
// params[3] and is followed by two entries per argument and then the results.
60 fn do_inline(ctx : context, bgi : int, ili : int, call_mode : pcode_t, get_inline : fn(function, bool) : list(pcode_t)) : (context, bool)
62         //eval dump_graph(ctx);
63         var block := ctx.blocks[bgi];
64         var igi := block.instrs[ili];
65         var ins := ctx.instrs[igi];
           // function_load decodes the function reference and reports how many
           // params it occupied; l becomes the index of the first argument entry
66         var l, f := function_load(ins.params[3 .. ]);
67         l := 3 + l;
68         var call_params := ins.params[l .. l + ins.params[2] * 2];
69         var call_result := ins.params[    l + ins.params[2] * 2 ..
70                                           l + ins.params[2] * 2 + ins.params[1]];
72         //eval debug("considering inline " + ntos(f.path_idx) + " " + f.un + " " + ntos(f.fn_idx) + " in block " + ntos(bgi));
73         //var new_pc := load_optimized_pcode(f.path_idx, f.un, f.fn_idx);
74         var new_pc := get_inline~save(f, true);
           // kind of the callee, taken from the first pcode word
76         var ft := new_pc[0] and Fn_Mask;
78         if ft <> Fn_Function and ft <> Fn_Record and ft <> Fn_Option then
79                 abort internal("unsupported function type " + ntos(new_pc[0]));
           // records and options are only handled for flat calls
81         if ft <> Fn_Function and call_mode <> Call_Mode_Flat then
82                 return ctx, false;
           // option callee: propagate the flat-option flags found in new_pc[0] to
           // the first result variable; report a change only when a flag was not
           // already set, so a repeated visit of the same call returns false
           // NOTE(review): "bt bsf FLAG" presumably tests the bit of FLAG in
           // new_pc[0] - confirm against the language reference
84         if ft = Fn_Option then [
85                 var did_something := false;
86                 if new_pc[0] bt bsf Fn_IsFlatOption, ctx.variables[call_result[0]].is_option_type <> true then [
87                         ctx.variables[call_result[0]].is_option_type := true;
88                         did_something := true;
89                 ]
90                 if new_pc[0] bt bsf Fn_AlwaysFlatOption, ctx.variables[call_result[0]].always_flat_option <> true then [
91                         ctx.variables[call_result[0]].always_flat_option := true;
92                         did_something := true;
93                 ]
94                 return ctx, did_something;
95         ]
96         //if ft <> Fn_Function then
97         //      return ctx, false;
           // build a separate context for the callee; everything in it is indexed
           // from zero and is rebased below before being appended to ctx
99         var new_ctx := load_function_context(new_pc);
101         //eval debug("inlining " + new_ctx.name + " into " + ctx.name);
103         //eval dump_basic_blocks(new_ctx, true);
105         //eval dump_graph(new_ctx);
            // the following brace comment is a disabled dump of caller and callee
            // argument variables and their runtime types
107         {
108         for i := 0 to ins.params[2] do [
109                 var cp := call_params[i * 2 + 1];
110                 var rt : int;
111                 if cp >= 0 then rt := ctx.variables[cp].runtime_type; else rt := -1;
112                 eval debug("caller arg " + ntos(i) + ": " + ntos(cp) + ", rt " + ntos(rt));
113                 var instr_args := new_ctx.instrs[new_ctx.blocks[0].instrs[0]];
114                 var arg := instr_args.params[i];
115                 eval debug("callee arg " + ntos(i) + ": " + ntos(arg) + ", rt " + ntos(new_ctx.variables[arg].runtime_type));
116         ]
117         }
            // Split the calling block at the call. The new tail block takes over
            // the successors of the calling block and the instructions after the
            // call; it will be stored at index tail_block_bgi.
119         var tail_block := new_basic_block;
120         var tail_block_bgi := len(ctx.blocks);
121         tail_block.post_list := block.post_list;
            // successors that listed bgi as a predecessor now list the tail block
122         for i := 0 to len(tail_block.post_list) do [
123                 var post_bgi := tail_block.post_list[i];
124                 var post_block := ctx.blocks[post_bgi];
125                 for j := 0 to len(post_block.pred_list) do [
126                         if post_block.pred_list[j] = bgi then
127                                 ctx.blocks[post_bgi].pred_list[j] := tail_block_bgi;
128                 ]
129         ]
            // instructions after the call move to the tail block and get their
            // owning-block field updated
130         tail_block.instrs := block.instrs[ ili + 1 .. ];
131         for i := 0 to len(tail_block.instrs) do [
132                 var vgi := tail_block.instrs[i];
133                 ctx.instrs[vgi].bb := tail_block_bgi;
134         ]
135         ctx.blocks +<= tail_block;
            // the calling block keeps the instructions in front of the call
            // (presumably the slice bound is exclusive, so the call itself is in
            // neither half - confirm); its only successor is len(ctx.blocks),
            // the index at which callee block 0 lands when new_ctx.blocks is
            // appended at the end of this function
136         ctx.blocks[bgi].instrs := ctx.blocks[bgi].instrs[ .. ili ];
137         ctx.blocks[bgi].post_list := list(int).[ len(ctx.blocks) ];
            // rebase local-type references stored inside the callee's local types
            // (negative flat type entries are left unchanged)
139         for i := 0 to len(new_ctx.local_types) do [
140                 var lt := new_ctx.local_types[i];
141                 if lt is flat_rec then [
142                         lt.flat_rec.non_flat_record += len(ctx.local_types);
143                         for j := 0 to len(lt.flat_rec.flat_types) do [
144                                 if lt.flat_rec.flat_types[j] >= 0 then
145                                         lt.flat_rec.flat_types[j] += len(ctx.local_types);
146                         ]
147                         new_ctx.local_types[i] := lt;
148                 ] else if lt is flat_array then [
149                         if lt.flat_array.flat_type >= 0 then
150                                 lt.flat_array.flat_type += len(ctx.local_types);
151                         new_ctx.local_types[i] := lt;
152                 ]
153         ]
            // adjust every callee instruction: clear Flag_Fused_Bin_Jmp on binary
            // operations, turn P_Claim into P_Assume, and rebase the params that
            // hold variable indices (positions given by the bits of read_set and
            // write_set), the owning block, and the params that hold local type
            // indices (positions given by the bits of lt_set)
155         for i := 0 to len(new_ctx.instrs) do [
156                 if new_ctx.instrs[i].opcode = P_BinaryOp then
157                         new_ctx.instrs[i].params[2] and= not Flag_Fused_Bin_Jmp;
158                 if new_ctx.instrs[i].opcode = P_BinaryConstOp then
159                         new_ctx.instrs[i].params[2] and= not Flag_Fused_Bin_Jmp;
160                 if new_ctx.instrs[i].opcode = P_Claim then
161                         new_ctx.instrs[i].opcode := P_Assume;
162                 var s := new_ctx.instrs[i].read_set or new_ctx.instrs[i].write_set;
163                 while s <> 0 do [
164                         var p : int := bsr s;
165                         s btr= p;
166                         new_ctx.instrs[i].params[p] += len(ctx.variables);
167                 ]
168                 new_ctx.instrs[i].bb += len(ctx.blocks);
170                 var ls := new_ctx.instrs[i].lt_set;
171                 while ls <> 0 do [
172                         var p := bsr ls;
173                         ls btr= p;
174                         new_ctx.instrs[i].params[p] += len(ctx.local_types);
175                 ]
176         ]
            // rebase the index fields of the callee variables; negative values
            // are left unchanged
178         for i := 0 to len(new_ctx.variables) do [
179                 if new_ctx.variables[i].type_index >= 0 then
180                         new_ctx.variables[i].type_index += len(ctx.variables);
181                 if new_ctx.variables[i].runtime_type >= 0 then
182                         new_ctx.variables[i].runtime_type += len(ctx.local_types);
183         ]
            // wire the callee blocks into the caller's graph; inactive blocks are
            // skipped and stay as they are
185         for i := 0 to len(new_ctx.blocks) do [
186                 var new_block := new_ctx.blocks[i];
187                 if not new_block.active then
188                         continue;
                    // entry block: its first instruction must be P_Args. It is
                    // removed and, for each argument, a P_Copy with params
                    // [ arg, 0, cp ] is appended to the calling block (arg is
                    // the already rebased callee variable, cp the caller's).
                    // These copies grow ctx.instrs, so they must be emitted
                    // before len(ctx.instrs) is used as the rebasing offset at
                    // the end of this loop body.
190                 if i = 0 then [
191                         var first_instr_igi := new_block.instrs[0];
192                         var first_instr := new_ctx.instrs[first_instr_igi];
193                         if first_instr.opcode <> P_Args then
194                                 abort internal("the first instruction is not Args");
195                         new_ctx.blocks[i].instrs := new_ctx.blocks[i].instrs[1 .. ];
196                         if len(first_instr.params) * 2 <> len(call_params) then
197                                 abort internal("call mismatch when inlining " + new_ctx.name + " into " + ctx.name);
198                         for j := 0 to len(first_instr.params) do [
199                                 var cp := call_params[j * 2 + 1];
200                                 var arg := first_instr.params[j];
201                                 //eval debug("args types " + ntos(j) + ": " + ntos(cp_rt) + " " + ntos(arg_rt));
203                                 var copy_in := create_instr(P_Copy, list(pcode_t).[ arg, 0, cp ], bgi);
204                                 var copy_igi := len(ctx.instrs);
205                                 ctx.blocks[bgi].instrs +<= copy_igi;
206                                 ctx.instrs +<= copy_in;
207                         ]
                            // the calling block is the only predecessor of the entry block
208                         new_ctx.blocks[i].pred_list := list(int).[ bgi ];
209                         new_ctx.blocks[i].pred_position := list(int).[ 0 ];
210                 ] else [
                            // other blocks: predecessors are callee blocks, rebase them
211                         for j := 0 to len(new_block.pred_list) do
212                                 new_ctx.blocks[i].pred_list[j] += len(ctx.blocks);
213                 ]
                    // reload the block, it may have been changed above
215                 new_block := new_ctx.blocks[i];
                    // a block ending in P_Return: drop the Return, append one P_Copy
                    // with params [ call_result[j], 0, rp ] per returned value
                    // (the Return params are read in pairs, the variable is the
                    // second entry of each pair), and connect the block to the
                    // tail block; the generic successor rebasing is skipped
217                 if len_greater_than(int, new_block.instrs, 0) then [
218                         var last_instr_igi := new_block.instrs[len(new_block.instrs) - 1];
219                         var last_instr := new_ctx.instrs[last_instr_igi];
220                         if last_instr.opcode = P_Return then [
221                                 new_ctx.blocks[i].instrs := new_ctx.blocks[i].instrs[ .. len(new_block.instrs) - 1];
222                                 for j := 0 to len(last_instr.params) shr 1 do [
223                                         var rp := last_instr.params[j * 2 + 1];
224                                         var cr := call_result[j];
225                                         //eval debug("ret types " + ntos(j) + ": " + ntos(rp_rt) + " " + ntos(cr_rt));
227                                         var copy_in := create_instr(P_Copy, list(pcode_t).[ call_result[j], 0, rp ], len(ctx.blocks) + i);
228                                         var copy_igi := len(new_ctx.instrs);
229                                         new_ctx.blocks[i].instrs +<= copy_igi;
230                                         new_ctx.instrs +<= copy_in;
231                                 ]
232                                 new_ctx.blocks[i].post_list := list(int).[ tail_block_bgi ];
233                                 ctx.blocks[tail_block_bgi].pred_list +<= len(ctx.blocks) + i;
234                                 ctx.blocks[tail_block_bgi].pred_position +<= 0;
235                                 goto done_post_list;
236                         ]
237                 ]
                    // block without a trailing Return: successors are callee blocks
239                 for j := 0 to len(new_ctx.blocks[i].post_list) do
240                         new_ctx.blocks[i].post_list[j] += len(ctx.blocks);
241 done_post_list:
                    // drop P_Line_Info instructions from the block's instruction list
243                 var instrs := empty(int);
244                 for j := 0 to len(new_ctx.blocks[i].instrs) do [
245                         var ins := new_ctx.instrs[new_ctx.blocks[i].instrs[j]];
246                         if ins.opcode <> P_Line_Info then [
247                                 instrs +<= new_ctx.blocks[i].instrs[j];
248                         ]
249                 ]
250                 new_ctx.blocks[i].instrs := instrs;
                    // rebase instruction indices to their position after the append below
252                 for j := 0 to len(new_ctx.blocks[i].instrs) do
253                         new_ctx.blocks[i].instrs[j] += len(ctx.instrs);
254         ]
            // append the rebased callee data; the offsets used above are the
            // lengths of these ctx lists at this point
256         ctx.local_types += new_ctx.local_types;
257         ctx.instrs += new_ctx.instrs;
258         ctx.blocks += new_ctx.blocks;
259         ctx.variables += new_ctx.variables;
261         //eval dump_graph(ctx);
263         return ctx, true;
// inline_functions: scan all active blocks of ctx for P_Call instructions and
// hand the eligible ones to do_inline.
//
// A call is attempted when its call mode (params[0]) is Call_Mode_Inline or
// Call_Mode_Flat. A call with Call_Mode_Unspecified is attempted only when the
// callee's pcode has the Fn_AutoInline flag in its first word and
// len_greater_than(pc, 1024) is false.
//
// When do_inline reports progress, ctx.should_retry is set, the scan of the
// current block stops, and after the remaining blocks have been visited the
// whole scan is restarted. The function returns when a full scan makes no
// progress.
//
// Parameters:
//   ctx        - context of the function being optimized
//   get_inline - callback forwarded to do_inline
// Returns the updated context.
266 fn inline_functions(ctx : context, get_inline : fn(function, bool) : list(pcode_t)) : context
268 again:
269         var did_something := false;
270         for bgi := 0 to len(ctx.blocks) do [
271                 var block := ctx.blocks[bgi];
272                 if not block.active then
273                         continue;
                    // the length is sampled once from the snapshot taken above;
                    // do_inline may shorten this block, which is why the inner
                    // loop is left as soon as progress is reported
274                 var l := len(block.instrs);
275                 for ili := 0 to l do [
276                         var igi := ctx.blocks[bgi].instrs[ili];
277                         var ins := ctx.instrs[igi];
278                         if ins.opcode = P_Call then [
279                                 //xeval debug("considering inline of " + ctx.name + " -> " + function_name(ins.params[3 .. ]) + ": " + ntos(ins.params[0]));
280                                 var call_mode := ins.params[0];
281                                 if call_mode = Call_Mode_Inline or call_mode = Call_Mode_Flat then [
                                            // this label is also the target of the goto in the
                                            // Call_Mode_Unspecified branch below
282 do_inline:
283                                         var progress : bool;
284                                         ctx, progress := do_inline(ctx, bgi, ili, call_mode, get_inline);
285                                         if progress then [
286                                                 did_something := true;
287                                                 ctx.should_retry := true;
288                                                 break;
289                                         ]
290                                 ] else if call_mode = Call_Mode_Unspecified then [
                                            // automatic inlining: needs the Fn_AutoInline flag and
                                            // a pcode list that is not longer than 1024 entries
291                                         var pc := function_pcode(ins.params[3 .. ]);
292                                         if pc[0] bt bsf Fn_AutoInline, not len_greater_than(pc, 1024) then [
293                                                 //xeval debug("auto-inlining " + ctx.name + " -> " + function_name(ins.params[3 .. ]) + ", " + ntos(len(pc)));
294                                                 goto do_inline;
295                                         ]
296                                         //xeval function_name(ins.params[3 .. ]);
297                                 ]
298                         ]
299                 ]
300         ]
            // repeat the whole scan until no call was changed
301         if did_something then
302                 goto again;
304         return ctx;