limit the size of inlined functions to 1024 instructions
[ajla.git] / newlib / compiler / optimize / inline.ajla
blob4977b325d519b1e4334aef3682c8f777396e4e83
1 {*
2  * Copyright (C) 2024 Mikulas Patocka
3  *
4  * This file is part of Ajla.
5  *
6  * Ajla is free software: you can redistribute it and/or modify it under the
7  * terms of the GNU General Public License as published by the Free Software
8  * Foundation, either version 3 of the License, or (at your option) any later
9  * version.
10  *
11  * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13  * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License along with
16  * Ajla. If not, see <https://www.gnu.org/licenses/>.
17  *}
19 private unit compiler.optimize.inline;
21 uses compiler.optimize.defs;
// Declaration of the inlining entry point: takes a context and a get_inline
// callback of type fn(function, bool) : list(pcode_t), and returns a context.
23 fn inline_functions(ctx : context, get_inline : fn(function, bool) : list(pcode_t)) : context;
25 implementation
27 uses compiler.optimize.utils;
// dump_graph: debugging helper that prints the block graph of a context.
// NOTE(review): the leading '{' looks like a block-comment opener (the licence
// header uses '{* ... *}'), so this whole function appears to be disabled; the
// only references to it in this file are '//eval dump_graph(...)' comments.
// For each index i in ctx.blocks it builds one string and passes it to debug():
// inactive blocks get "******** block <i> - inactive", active blocks get the
// header followed by their post_list, pred_list and pred_position entries.
// Always returns true.
30 {fn dump_graph(ctx : context) : bool
32         for i := 0 to len(ctx.blocks) do [
33                 var str := empty(byte);
34                 var block := ctx.blocks[i];
35                 str += "******** block " + ntos(i);
           // Inactive blocks are reported with the header line only.
36                 if not block.active then [
37                         str += " - inactive";
38                         eval debug(str);
39                         goto next;
40                 ]
41                 str += nl;
42                 str += "post_list:";
43                 for j := 0 to len(block.post_list) do
44                         str += " " + ntos(block.post_list[j]);
45                 str += nl;
46                 str += "pred_list:";
47                 for j := 0 to len(block.pred_list) do
48                         str += " " + ntos(block.pred_list[j]);
49                 str += nl;
50                 str += "pred_position:";
51                 for j := 0 to len(block.pred_position) do
52                         str += " " + ntos(block.pred_position[j]);
53                 eval debug(str);
           // Target of the 'goto next' taken for inactive blocks.
54 next:
55         ]
56         return true;
// do_inline: tries to inline the call instruction found at position ili of
// block bgi in ctx.
//
// Parameters:
//   ctx        - the caller's context (blocks, instrs, variables, local_types).
//   bgi        - index of the block in ctx.blocks that contains the call.
//   ili        - index of the call inside that block's instrs list.
//   call_mode  - compared here against Call_Mode_Flat; callees that are not
//                Fn_Function are only handled in that mode.
//   get_inline - callback used to obtain the callee's pcode; it is invoked as
//                get_inline~save(f, true).
//
// Returns the context and a flag. The flag is true when ctx was changed
// (the callee body was spliced in, or the result variable was newly marked as
// an option type) and false when nothing was done.
//
// Aborts with an internal error when the callee's function type is not
// Fn_Function, Fn_Record or Fn_Option, when the callee's first instruction is
// not P_Args, or when the callee's argument count does not match the call.
60 fn do_inline(ctx : context, bgi : int, ili : int, call_mode : pcode_t, get_inline : fn(function, bool) : list(pcode_t)) : (context, bool)
62         //eval dump_graph(ctx);
63         var block := ctx.blocks[bgi];
64         var igi := block.instrs[ili];
65         var ins := ctx.instrs[igi];
           // Layout of the call's params as used below: params[1] is the number of
           // results, params[2] the number of arguments, params[3 ..] starts with the
           // function reference decoded by function_load (which also returns the
           // number of entries it occupies). After it come two entries per argument,
           // followed by one entry per result.
66         var l, f := function_load(ins.params[3 .. ]);
67         l := 3 + l;
68         var call_params := ins.params[l .. l + ins.params[2] * 2];
69         var call_result := ins.params[    l + ins.params[2] * 2 ..
70                                           l + ins.params[2] * 2 + ins.params[1]];
72         //eval debug("considering inline " + ntos(f.path_idx) + " " + f.un + " " + ntos(f.fn_idx) + " in block " + ntos(bgi));
73         //var new_pc := load_optimized_pcode(f.path_idx, f.un, f.fn_idx);
74         var new_pc := get_inline~save(f, true);
           // The function type is taken from the first pcode entry, masked by Fn_Mask.
76         var ft := new_pc[0] and Fn_Mask;
78         if ft <> Fn_Function and ft <> Fn_Record and ft <> Fn_Option then
79                 abort internal("unsupported function type " + ntos(new_pc[0]));
           // Record and option callees are only processed for Call_Mode_Flat calls.
81         if ft <> Fn_Function and call_mode <> Call_Mode_Flat then
82                 return ctx, false;
           // Option callee: nothing is spliced in. The first result variable is
           // marked is_option_type; a change is reported only if it was not
           // already marked.
84         if ft = Fn_Option then [
85                 if ctx.variables[call_result[0]].is_option_type <> true then [
86                         ctx.variables[call_result[0]].is_option_type := true;
87                         return ctx, true;
88                 ]
89                 return ctx, false;
90         ]
91         //if ft <> Fn_Function then
92         //      return ctx, false;
           // Build a separate context for the callee; it is rebased and merged
           // into ctx below.
94         var new_ctx := load_function_context(new_pc);
96         //eval debug("inlining " + new_ctx.name + " into " + ctx.name);
98         //eval dump_basic_blocks(new_ctx, true);
100         //eval dump_graph(new_ctx);
            // NOTE(review): the '{ ... }' region below appears to be a disabled
            // debug dump of caller/callee argument variables and runtime types.
102         {
103         for i := 0 to ins.params[2] do [
104                 var cp := call_params[i * 2 + 1];
105                 var rt : int;
106                 if cp >= 0 then rt := ctx.variables[cp].runtime_type; else rt := -1;
107                 eval debug("caller arg " + ntos(i) + ": " + ntos(cp) + ", rt " + ntos(rt));
108                 var instr_args := new_ctx.instrs[new_ctx.blocks[0].instrs[0]];
109                 var arg := instr_args.params[i];
110                 eval debug("callee arg " + ntos(i) + ": " + ntos(arg) + ", rt " + ntos(new_ctx.variables[arg].runtime_type));
111         ]
112         }
            // Split the calling block at the call. The new tail block takes over
            // the original successors and every instruction after the call.
            // tail_block_bgi is read before the append, so it is the index the
            // tail block gets in ctx.blocks.
114         var tail_block := new_basic_block;
115         var tail_block_bgi := len(ctx.blocks);
116         tail_block.post_list := block.post_list;
            // Redirect predecessor entries of the successors from bgi to the tail.
117         for i := 0 to len(tail_block.post_list) do [
118                 var post_bgi := tail_block.post_list[i];
119                 var post_block := ctx.blocks[post_bgi];
120                 for j := 0 to len(post_block.pred_list) do [
121                         if post_block.pred_list[j] = bgi then
122                                 ctx.blocks[post_bgi].pred_list[j] := tail_block_bgi;
123                 ]
124         ]
125         tail_block.instrs := block.instrs[ ili + 1 .. ];
            // Instructions moved to the tail must record their new owning block.
126         for i := 0 to len(tail_block.instrs) do [
127                 var vgi := tail_block.instrs[i];
128                 ctx.instrs[vgi].bb := tail_block_bgi;
129         ]
130         ctx.blocks +<= tail_block;
            // The head keeps only the instructions before the call; the call
            // itself is dropped from the block.
131         ctx.blocks[bgi].instrs := ctx.blocks[bgi].instrs[ .. ili ];
            // len(ctx.blocks) is read after the tail was appended, so this is the
            // index that the callee's block 0 receives when new_ctx.blocks is
            // appended to ctx.blocks at the end of this function.
132         ctx.blocks[bgi].post_list := list(int).[ len(ctx.blocks) ];
            // Rebase local-type references inside the callee's local types by
            // the number of local types already present in the caller. Negative
            // flat type values are left untouched.
134         for i := 0 to len(new_ctx.local_types) do [
135                 var lt := new_ctx.local_types[i];
136                 if lt is flat_rec then [
137                         lt.flat_rec.non_flat_record += len(ctx.local_types);
138                         for j := 0 to len(lt.flat_rec.flat_types) do [
139                                 if lt.flat_rec.flat_types[j] >= 0 then
140                                         lt.flat_rec.flat_types[j] += len(ctx.local_types);
141                         ]
142                         new_ctx.local_types[i] := lt;
143                 ] else if lt is flat_array then [
144                         if lt.flat_array.flat_type >= 0 then
145                                 lt.flat_array.flat_type += len(ctx.local_types);
146                         new_ctx.local_types[i] := lt;
147                 ]
148         ]
            // Rebase every callee instruction so it can live in the caller's
            // tables.
150         for i := 0 to len(new_ctx.instrs) do [
                    // Clear Flag_Fused_Bin_Jmp in params[2] of binary operations.
151                 if new_ctx.instrs[i].opcode = P_BinaryOp then
152                         new_ctx.instrs[i].params[2] and= not Flag_Fused_Bin_Jmp;
153                 if new_ctx.instrs[i].opcode = P_BinaryConstOp then
154                         new_ctx.instrs[i].params[2] and= not Flag_Fused_Bin_Jmp;
                    // read_set / write_set are used as bit masks over parameter
                    // positions: each position with its bit set is offset by the
                    // caller's variable count. (Presumably 'bsr' picks a set bit
                    // and 'btr=' clears it - confirm against the language docs.)
155                 var s := new_ctx.instrs[i].read_set or new_ctx.instrs[i].write_set;
156                 while s <> 0 do [
157                         var p : int := bsr s;
158                         s btr= p;
159                         new_ctx.instrs[i].params[p] += len(ctx.variables);
160                 ]
                    // Owning block index, shifted by the caller's block count
                    // (which already includes the tail block).
161                 new_ctx.instrs[i].bb += len(ctx.blocks);
                    // lt_set is walked the same way; those positions are offset
                    // by the caller's local type count.
163                 var ls := new_ctx.instrs[i].lt_set;
164                 while ls <> 0 do [
165                         var p := bsr ls;
166                         ls btr= p;
167                         new_ctx.instrs[i].params[p] += len(ctx.local_types);
168                 ]
169         ]
            // Rebase the non-negative cross references stored in the callee's
            // variables.
171         for i := 0 to len(new_ctx.variables) do [
172                 if new_ctx.variables[i].type_index >= 0 then
173                         new_ctx.variables[i].type_index += len(ctx.variables);
174                 if new_ctx.variables[i].runtime_type >= 0 then
175                         new_ctx.variables[i].runtime_type += len(ctx.local_types);
176         ]
            // Rewire each active callee block into the caller's graph.
178         for i := 0 to len(new_ctx.blocks) do [
179                 var new_block := new_ctx.blocks[i];
180                 if not new_block.active then
181                         continue;
                    // Block 0 must start with P_Args. That instruction is removed
                    // and replaced by one P_Copy per argument, appended to the
                    // calling block bgi and to ctx.instrs.
183                 if i = 0 then [
184                         var first_instr_igi := new_block.instrs[0];
185                         var first_instr := new_ctx.instrs[first_instr_igi];
186                         if first_instr.opcode <> P_Args then
187                                 abort internal("the first instruction is not Args");
188                         new_ctx.blocks[i].instrs := new_ctx.blocks[i].instrs[1 .. ];
                            // call_params holds two entries per argument.
189                         if len(first_instr.params) * 2 <> len(call_params) then
190                                 abort internal("call mismatch when inlining " + new_ctx.name + " into " + ctx.name);
191                         for j := 0 to len(first_instr.params) do [
192                                 var cp := call_params[j * 2 + 1];
193                                 var arg := first_instr.params[j];
194                                 //eval debug("args types " + ntos(j) + ": " + ntos(cp_rt) + " " + ntos(arg_rt));
                                    // The copies grow ctx.instrs; this happens before
                                    // len(ctx.instrs) is used further down to rebase
                                    // the instruction indices of this block.
196                                 var copy_in := create_instr(P_Copy, list(pcode_t).[ arg, 0, cp ], bgi);
197                                 var copy_igi := len(ctx.instrs);
198                                 ctx.blocks[bgi].instrs +<= copy_igi;
199                                 ctx.instrs +<= copy_in;
200                         ]
                            // The entry block's only predecessor is the calling block.
201                         new_ctx.blocks[i].pred_list := list(int).[ bgi ];
202                         new_ctx.blocks[i].pred_position := list(int).[ 0 ];
203                 ] else [
                            // Other blocks only have predecessors inside the callee.
204                         for j := 0 to len(new_block.pred_list) do
205                                 new_ctx.blocks[i].pred_list[j] += len(ctx.blocks);
206                 ]
                    // Reload the block: it may have been modified above.
208                 new_block := new_ctx.blocks[i];
                    // A block ending in P_Return loses that instruction, gets one
                    // P_Copy per returned value into the call's result variables,
                    // and is linked to the tail block as its sole successor.
210                 if len_greater_than(int, new_block.instrs, 0) then [
211                         var last_instr_igi := new_block.instrs[len(new_block.instrs) - 1];
212                         var last_instr := new_ctx.instrs[last_instr_igi];
213                         if last_instr.opcode = P_Return then [
214                                 new_ctx.blocks[i].instrs := new_ctx.blocks[i].instrs[ .. len(new_block.instrs) - 1];
                                    // P_Return params are walked in pairs; the second
                                    // entry of each pair is used as the copy source.
215                                 for j := 0 to len(last_instr.params) shr 1 do [
216                                         var rp := last_instr.params[j * 2 + 1];
                                            // NOTE(review): 'cr' is not used below; the
                                            // copy reads call_result[j] directly.
217                                         var cr := call_result[j];
218                                         //eval debug("ret types " + ntos(j) + ": " + ntos(rp_rt) + " " + ntos(cr_rt));
                                            // The copy is created with its final block
                                            // index; its instruction index is callee-local
                                            // and is rebased with the rest of the block below.
220                                         var copy_in := create_instr(P_Copy, list(pcode_t).[ call_result[j], 0, rp ], len(ctx.blocks) + i);
221                                         var copy_igi := len(new_ctx.instrs);
222                                         new_ctx.blocks[i].instrs +<= copy_igi;
223                                         new_ctx.instrs +<= copy_in;
224                                 ]
225                                 new_ctx.blocks[i].post_list := list(int).[ tail_block_bgi ];
226                                 ctx.blocks[tail_block_bgi].pred_list +<= len(ctx.blocks) + i;
227                                 ctx.blocks[tail_block_bgi].pred_position +<= 0;
                                    // post_list already holds a caller index, so the
                                    // rebasing loop below is skipped.
228                                 goto done_post_list;
229                         ]
230                 ]
                    // Blocks that do not end in P_Return: shift successor indices.
232                 for j := 0 to len(new_ctx.blocks[i].post_list) do
233                         new_ctx.blocks[i].post_list[j] += len(ctx.blocks);
234 done_post_list:
                    // Drop P_Line_Info instructions from the block's instruction
                    // list. NOTE(review): the local 'ins' here shadows the
                    // function-level 'ins' (the call instruction).
236                 var instrs := empty(int);
237                 for j := 0 to len(new_ctx.blocks[i].instrs) do [
238                         var ins := new_ctx.instrs[new_ctx.blocks[i].instrs[j]];
239                         if ins.opcode <> P_Line_Info then [
240                                 instrs +<= new_ctx.blocks[i].instrs[j];
241                         ]
242                 ]
243                 new_ctx.blocks[i].instrs := instrs;
                    // Convert callee-local instruction indices into caller indices.
245                 for j := 0 to len(new_ctx.blocks[i].instrs) do
246                         new_ctx.blocks[i].instrs[j] += len(ctx.instrs);
247         ]
            // Append the rebased callee tables to the caller's tables; all offsets
            // computed above were taken from the caller's table lengths.
249         ctx.local_types += new_ctx.local_types;
250         ctx.instrs += new_ctx.instrs;
251         ctx.blocks += new_ctx.blocks;
252         ctx.variables += new_ctx.variables;
254         //eval dump_graph(ctx);
256         return ctx, true;
// inline_functions: scans the active blocks of ctx for P_Call instructions and
// inlines them through do_inline, restarting the scan until a complete pass
// makes no change.
//
// Parameters:
//   ctx        - the context to process.
//   get_inline - callback passed unchanged to do_inline.
//
// Returns the resulting context. ctx.should_retry is set to true whenever
// do_inline reports progress.
259 fn inline_functions(ctx : context, get_inline : fn(function, bool) : list(pcode_t)) : context
    // Restart point for a new full pass over the blocks.
261 again:
262         var did_something := false;
263         for bgi := 0 to len(ctx.blocks) do [
264                 var block := ctx.blocks[bgi];
265                 if not block.active then
266                         continue;
                    // The instruction count is captured once; the loop is left
                    // with 'break' as soon as do_inline reports progress.
267                 var l := len(block.instrs);
268                 for ili := 0 to l do [
269                         var igi := ctx.blocks[bgi].instrs[ili];
270                         var ins := ctx.instrs[igi];
271                         if ins.opcode = P_Call then [
272                                 //xeval debug("considering inline of " + ctx.name + " -> " + function_name(ins.params[3 .. ]) + ": " + ntos(ins.params[0]));
                                    // params[0] of a call holds its call mode.
273                                 var call_mode := ins.params[0];
                                    // Calls in Inline, Type or Flat mode are always
                                    // handed to do_inline.
274                                 if call_mode = Call_Mode_Inline or call_mode = Call_Mode_Type or call_mode = Call_Mode_Flat then [
    // Also the target of the 'goto do_inline' in the Call_Mode_Unspecified branch.
275 do_inline:
276                                         var progress : bool;
277                                         ctx, progress := do_inline(ctx, bgi, ili, call_mode, get_inline);
278                                         if progress then [
279                                                 did_something := true;
280                                                 ctx.should_retry := true;
281                                                 break;
282                                         ]
283                                 ] else if call_mode = Call_Mode_Unspecified then [
                                            // Automatic inlining: requires pc[8] <= 1 and,
                                            // presumably, a pcode list of at most 1024
                                            // entries. NOTE(review): the meaning of pc[8]
                                            // is not visible here. len_greater_than is
                                            // called without the explicit type argument
                                            // used in do_inline - confirm it is inferred.
284                                         var pc := function_pcode(ins.params[3 .. ]);
285                                         if pc[8] <= 1, not len_greater_than(pc, 1024) then [
286                                                 //xeval debug("auto-inlining " + ctx.name + " -> " + function_name(ins.params[3 .. ]) + ", " + ntos(len(pc)));
287                                                 goto do_inline;
288                                         ]
289                                         //xeval function_name(ins.params[3 .. ]);
290                                 ]
291                         ]
292                 ]
293         ]
            // Any change triggers another complete pass.
294         if did_something then goto again;
296         return ctx;