2 * Copyright (C) 2024, 2025 Mikulas Patocka
4 * This file is part of Ajla.
6 * Ajla is free software: you can redistribute it and/or modify it under the
7 * terms of the GNU General Public License as published by the Free Software
8 * Foundation, either version 3 of the License, or (at your option) any later
11 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License along with
16 * Ajla. If not, see <https://www.gnu.org/licenses/>.
19 private unit compiler.optimize.inline;
21 uses compiler.optimize.defs;
23 fn inline_functions(ctx : context, get_inline : fn(function, bool) : list(pcode_t)) : context;
27 uses compiler.optimize.utils;
// dump_graph: debugging helper that walks every entry of ctx.blocks and builds
// a textual description of the control-flow graph, one string per block.
// For each block the string starts with "******** block <index>" and then
// receives the block's post_list (successor block indices), pred_list
// (predecessor block indices) and pred_position entries, each separated by a
// space. The only references to it visible here are commented-out
// `//eval dump_graph(...)` calls in do_inline.
30 {fn dump_graph(ctx : context) : bool
32 for i := 0 to len(ctx.blocks) do [
33 var str := empty(byte);
34 var block := ctx.blocks[i];
35 str += "******** block " + ntos(i);
// Blocks whose `active` flag is false are handled separately.
36 if not block.active then [
// Successor block indices.
43 for j := 0 to len(block.post_list) do
44 str += " " + ntos(block.post_list[j]);
// Predecessor block indices.
47 for j := 0 to len(block.pred_list) do
48 str += " " + ntos(block.pred_list[j]);
// pred_position has its own labelled section in the output.
50 str += "pred_position:";
51 for j := 0 to len(block.pred_position) do
52 str += " " + ntos(block.pred_position[j]);
// do_inline: processes one call instruction of `ctx` and, for a function
// callee, splices the callee's blocks, instructions, variables and local types
// into the caller's context.
//   ctx        - context of the calling function; the updated context is returned
//   bgi        - index into ctx.blocks of the block that holds the call
//   ili        - index into that block's `instrs` list selecting the call
//   call_mode  - call mode of the call (compared with Call_Mode_Flat below)
//   get_inline - callback that returns the pcode of the callee
// Returns the context and a bool; on the Fn_Option path the bool tells whether
// any flag of the result variable was changed.
60 fn do_inline(ctx : context, bgi : int, ili : int, call_mode : pcode_t, get_inline : fn(function, bool) : list(pcode_t)) : (context, bool)
62 //eval dump_graph(ctx);
// Locate the call: block -> global instruction index -> instruction.
63 var block := ctx.blocks[bgi];
64 var igi := block.instrs[ili];
65 var ins := ctx.instrs[igi];
// Decode the callee reference that starts at params[3].
// NOTE(review): `l` is used below as an offset into ins.params although
// function_load received ins.params[3 .. ] — confirm how `l` is based.
66 var l, f := function_load(ins.params[3 .. ]);
// ins.params[2] is the argument count; every argument occupies two pcode
// words and the second word of each pair is the caller's variable index.
68 var call_params := ins.params[l .. l + ins.params[2] * 2];
// ins.params[1] is the number of results; the result variables follow the
// arguments.
69 var call_result := ins.params[ l + ins.params[2] * 2 ..
70 l + ins.params[2] * 2 + ins.params[1]];
72 //eval debug("considering inline " + ntos(f.path_idx) + " " + f.un + " " + ntos(f.fn_idx) + " in block " + ntos(bgi));
73 //var new_pc := load_optimized_pcode(f.path_idx, f.un, f.fn_idx);
// Fetch the callee's pcode through the supplied callback.
74 var new_pc := get_inline~save(f, true);
// The first pcode word masked with Fn_Mask gives the kind of callee.
76 var ft := new_pc[0] and Fn_Mask;
// Only functions, records and options are accepted; anything else is an
// internal error.
78 if ft <> Fn_Function and ft <> Fn_Record and ft <> Fn_Option then
79 abort internal("unsupported function type " + ntos(new_pc[0]));
// Guard for a non-function callee combined with a call mode other than
// Call_Mode_Flat.
81 if ft <> Fn_Function and call_mode <> Call_Mode_Flat then
// Option callee: nothing is spliced in. The Fn_IsFlatOption and
// Fn_AlwaysFlatOption bits of new_pc[0] are propagated to the first result
// variable, and the returned bool reports whether either flag was newly set.
84 if ft = Fn_Option then [
85 var did_something := false;
86 if new_pc[0] bt bsf Fn_IsFlatOption, ctx.variables[call_result[0]].is_option_type <> true then [
87 ctx.variables[call_result[0]].is_option_type := true;
88 did_something := true;
90 if new_pc[0] bt bsf Fn_AlwaysFlatOption, ctx.variables[call_result[0]].always_flat_option <> true then [
91 ctx.variables[call_result[0]].always_flat_option := true;
92 did_something := true;
94 return ctx, did_something;
96 //if ft <> Fn_Function then
// Build a separate context for the callee from its pcode.
99 var new_ctx := load_function_context(new_pc);
101 //eval debug("inlining " + new_ctx.name + " into " + ctx.name);
103 //eval dump_basic_blocks(new_ctx, true);
105 //eval dump_graph(new_ctx);
// Debug output: for every argument print the caller's variable with its
// runtime type and the matching parameter of the callee's first instruction.
108 for i := 0 to ins.params[2] do [
109 var cp := call_params[i * 2 + 1];
111 if cp >= 0 then rt := ctx.variables[cp].runtime_type; else rt := -1;
112 eval debug("caller arg " + ntos(i) + ": " + ntos(cp) + ", rt " + ntos(rt));
113 var instr_args := new_ctx.instrs[new_ctx.blocks[0].instrs[0]];
114 var arg := instr_args.params[i];
115 eval debug("callee arg " + ntos(i) + ": " + ntos(arg) + ", rt " + ntos(new_ctx.variables[arg].runtime_type));
// Split the calling block at the call: a new tail block, which will get index
// len(ctx.blocks), inherits the block's successors.
119 var tail_block := new_basic_block;
120 var tail_block_bgi := len(ctx.blocks);
121 tail_block.post_list := block.post_list;
// Successors that listed bgi as a predecessor now list the tail block instead.
122 for i := 0 to len(tail_block.post_list) do [
123 var post_bgi := tail_block.post_list[i];
124 var post_block := ctx.blocks[post_bgi];
125 for j := 0 to len(post_block.pred_list) do [
126 if post_block.pred_list[j] = bgi then
127 ctx.blocks[post_bgi].pred_list[j] := tail_block_bgi;
// Everything after the call moves to the tail block; the moved instructions
// must point at their new block.
130 tail_block.instrs := block.instrs[ ili + 1 .. ];
131 for i := 0 to len(tail_block.instrs) do [
132 var vgi := tail_block.instrs[i];
133 ctx.instrs[vgi].bb := tail_block_bgi;
135 ctx.blocks +<= tail_block;
// The calling block keeps only the instructions before the call (the call
// itself is dropped). Its single successor is len(ctx.blocks), the index the
// first callee block receives when new_ctx.blocks is appended at the end.
136 ctx.blocks[bgi].instrs := ctx.blocks[bgi].instrs[ .. ili ];
137 ctx.blocks[bgi].post_list := list(int).[ len(ctx.blocks) ];
// Rebase local-type references inside the callee's flat record and flat array
// types by len(ctx.local_types), because the callee's local types are appended
// after the caller's. Negative values are left untouched.
139 for i := 0 to len(new_ctx.local_types) do [
140 var lt := new_ctx.local_types[i];
141 if lt is flat_rec then [
142 lt.flat_rec.non_flat_record += len(ctx.local_types);
143 for j := 0 to len(lt.flat_rec.flat_types) do [
144 if lt.flat_rec.flat_types[j] >= 0 then
145 lt.flat_rec.flat_types[j] += len(ctx.local_types);
147 new_ctx.local_types[i] := lt;
148 ] else if lt is flat_array then [
149 if lt.flat_array.flat_type >= 0 then
150 lt.flat_array.flat_type += len(ctx.local_types);
151 new_ctx.local_types[i] := lt;
// Per-instruction fixups of the callee.
155 for i := 0 to len(new_ctx.instrs) do [
// Clear Flag_Fused_Bin_Jmp in params[2] of binary operations.
156 if new_ctx.instrs[i].opcode = P_BinaryOp then
157 new_ctx.instrs[i].params[2] and= not Flag_Fused_Bin_Jmp;
158 if new_ctx.instrs[i].opcode = P_BinaryConstOp then
159 new_ctx.instrs[i].params[2] and= not Flag_Fused_Bin_Jmp;
// P_Claim becomes P_Assume in the inlined body.
160 if new_ctx.instrs[i].opcode = P_Claim then
161 new_ctx.instrs[i].opcode := P_Assume;
// Parameter positions taken from the bits of read_set/write_set are shifted
// by len(ctx.variables).
162 var s := new_ctx.instrs[i].read_set or new_ctx.instrs[i].write_set;
164 var p : int := bsr s;
166 new_ctx.instrs[i].params[p] += len(ctx.variables);
// The owning block index is rebased by len(ctx.blocks).
168 new_ctx.instrs[i].bb += len(ctx.blocks);
// lt_set is used for parameters that are shifted by len(ctx.local_types).
170 var ls := new_ctx.instrs[i].lt_set;
174 new_ctx.instrs[i].params[p] += len(ctx.local_types);
// Rebase the per-variable references of the callee; negative values are kept.
178 for i := 0 to len(new_ctx.variables) do [
179 if new_ctx.variables[i].type_index >= 0 then
180 new_ctx.variables[i].type_index += len(ctx.variables);
181 if new_ctx.variables[i].runtime_type >= 0 then
182 new_ctx.variables[i].runtime_type += len(ctx.local_types);
// Per-block fixups of the callee.
185 for i := 0 to len(new_ctx.blocks) do [
186 var new_block := new_ctx.blocks[i];
187 if not new_block.active then
// Entry handling: the first instruction has to be P_Args. It is removed and
// replaced by P_Copy instructions, appended to the calling block, that copy
// each caller argument variable into the corresponding callee parameter.
191 var first_instr_igi := new_block.instrs[0];
192 var first_instr := new_ctx.instrs[first_instr_igi];
193 if first_instr.opcode <> P_Args then
194 abort internal("the first instruction is not Args");
195 new_ctx.blocks[i].instrs := new_ctx.blocks[i].instrs[1 .. ];
// call_params holds two words per argument, so the counts must agree.
196 if len(first_instr.params) * 2 <> len(call_params) then
197 abort internal("call mismatch when inlining " + new_ctx.name + " into " + ctx.name);
198 for j := 0 to len(first_instr.params) do [
199 var cp := call_params[j * 2 + 1];
200 var arg := first_instr.params[j];
201 //eval debug("args types " + ntos(j) + ": " + ntos(cp_rt) + " " + ntos(arg_rt));
// The copy is added to ctx directly (block bgi), not to new_ctx.
203 var copy_in := create_instr(P_Copy, list(pcode_t).[ arg, 0, cp ], bgi);
204 var copy_igi := len(ctx.instrs);
205 ctx.blocks[bgi].instrs +<= copy_igi;
206 ctx.instrs +<= copy_in;
// The calling block becomes the only predecessor of this block.
208 new_ctx.blocks[i].pred_list := list(int).[ bgi ];
209 new_ctx.blocks[i].pred_position := list(int).[ 0 ];
// Predecessor indices that refer to callee blocks are rebased by
// len(ctx.blocks).
211 for j := 0 to len(new_block.pred_list) do
212 new_ctx.blocks[i].pred_list[j] += len(ctx.blocks);
// Re-read the block, since new_ctx.blocks[i] was modified above.
215 new_block := new_ctx.blocks[i];
// Exit handling: a trailing P_Return is removed and replaced by P_Copy
// instructions that copy each returned variable (second word of each pair in
// the return's params) into the caller's result variable; the block is then
// linked to the tail block.
217 if len_greater_than(int, new_block.instrs, 0) then [
218 var last_instr_igi := new_block.instrs[len(new_block.instrs) - 1];
219 var last_instr := new_ctx.instrs[last_instr_igi];
220 if last_instr.opcode = P_Return then [
221 new_ctx.blocks[i].instrs := new_ctx.blocks[i].instrs[ .. len(new_block.instrs) - 1];
222 for j := 0 to len(last_instr.params) shr 1 do [
223 var rp := last_instr.params[j * 2 + 1];
224 var cr := call_result[j];
225 //eval debug("ret types " + ntos(j) + ": " + ntos(rp_rt) + " " + ntos(cr_rt));
// The copy belongs to new_ctx; its block index is already expressed in the
// caller's numbering (len(ctx.blocks) + i).
227 var copy_in := create_instr(P_Copy, list(pcode_t).[ call_result[j], 0, rp ], len(ctx.blocks) + i);
228 var copy_igi := len(new_ctx.instrs);
229 new_ctx.blocks[i].instrs +<= copy_igi;
230 new_ctx.instrs +<= copy_in;
232 new_ctx.blocks[i].post_list := list(int).[ tail_block_bgi ];
233 ctx.blocks[tail_block_bgi].pred_list +<= len(ctx.blocks) + i;
234 ctx.blocks[tail_block_bgi].pred_position +<= 0;
// Successor indices that refer to callee blocks are rebased by
// len(ctx.blocks).
239 for j := 0 to len(new_ctx.blocks[i].post_list) do
240 new_ctx.blocks[i].post_list[j] += len(ctx.blocks);
// Drop P_Line_Info instructions from the block's instruction list.
243 var instrs := empty(int);
244 for j := 0 to len(new_ctx.blocks[i].instrs) do [
245 var ins := new_ctx.instrs[new_ctx.blocks[i].instrs[j]];
246 if ins.opcode <> P_Line_Info then [
247 instrs +<= new_ctx.blocks[i].instrs[j];
250 new_ctx.blocks[i].instrs := instrs;
// Instruction indices are rebased by len(ctx.instrs), which already includes
// the argument copies appended to ctx.instrs above.
252 for j := 0 to len(new_ctx.blocks[i].instrs) do
253 new_ctx.blocks[i].instrs[j] += len(ctx.instrs);
// Append the rebased callee data to the caller's context.
256 ctx.local_types += new_ctx.local_types;
257 ctx.instrs += new_ctx.instrs;
258 ctx.blocks += new_ctx.blocks;
259 ctx.variables += new_ctx.variables;
261 //eval dump_graph(ctx);
266 fn inline_functions(ctx : context, get_inline : fn(function, bool) : list(pcode_t)) : context
269 var did_something := false;
270 for bgi := 0 to len(ctx.blocks) do [
271 var block := ctx.blocks[bgi];
272 if not block.active then
274 var l := len(block.instrs);
275 for ili := 0 to l do [
276 var igi := ctx.blocks[bgi].instrs[ili];
277 var ins := ctx.instrs[igi];
278 if ins.opcode = P_Call then [
279 //xeval debug("considering inline of " + ctx.name + " -> " + function_name(ins.params[3 .. ]) + ": " + ntos(ins.params[0]));
280 var call_mode := ins.params[0];
281 if call_mode = Call_Mode_Inline or call_mode = Call_Mode_Flat then [
284 ctx, progress := do_inline(ctx, bgi, ili, call_mode, get_inline);
286 did_something := true;
287 ctx.should_retry := true;
290 ] else if call_mode = Call_Mode_Unspecified then [
291 var pc := function_pcode(ins.params[3 .. ]);
292 if pc[0] bt bsf Fn_AutoInline, not len_greater_than(pc, 1024) then [
293 //xeval debug("auto-inlining " + ctx.name + " -> " + function_name(ins.params[3 .. ]) + ", " + ntos(len(pc)));
296 //xeval function_name(ins.params[3 .. ]);
301 if did_something then