verify: support phis
[ajla.git] / stdlib / compiler / optimize / verify.ajla
blob39387fe1eae2b9b4544fd99442e42d75329b4bf4
1 {*
2  * Copyright (C) 2025 Mikulas Patocka
3  *
4  * This file is part of Ajla.
5  *
6  * Ajla is free software: you can redistribute it and/or modify it under the
7  * terms of the GNU General Public License as published by the Free Software
8  * Foundation, either version 3 of the License, or (at your option) any later
9  * version.
10  *
11  * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13  * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License along with
16  * Ajla. If not, see <https://www.gnu.org/licenses/>.
17  *}
19 private unit compiler.optimize.verify;
21 uses compiler.optimize.defs;
23 fn verify_function(ctx : context) : unit_type;
25 implementation
27 uses exception;
28 uses charset;
29 uses z3;
30 uses compiler.common.blob;
32 fn allocate_variable(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, v : int) : (z3_world, context, bytes, bytes);
34 fn get_z3_name(ctx : context, v : int) : (bytes, bytes)
36         var n := "var_" + map(ntos_base(v, 16), ascii_locase);
37         var m := "valid_" + map(ntos_base(v, 16), ascii_locase);
38         if len_greater_than(ctx.variables[v].name, 0) then [
39                 n += "_" + ctx.variables[v].name;
40                 m += "_" + ctx.variables[v].name;
41         ]
42         return n, m;
45 fn get_z3_bb_name(ctx : context, bgi : int) : (context, bytes)
47         var str := "bb_" + map(ntos_base(bgi, 16), ascii_locase);
48         ctx.blocks[bgi].verifier_name := str;
49         return ctx, str;
52 fn get_z3_type(ctx : context, v : int) : bytes
54         if ctx.variables[v].type_index = T_AlwaysFlatOption then
55                 return "Bool";
56         if ctx.variables[v].type_index <= T_Integer, ctx.variables[v].type_index >= T_Integer128 then
57                 return "Int";
58         if ctx.variables[v].type_index <= T_Real16, ctx.variables[v].type_index >= T_Real128 then
59                 return "Real";
60         return "";
63 fn gen_assert(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, v : int, deps : list(bytes), ops : bytes) : z3_world
65         var str := "(assert (= " + ctx.variables[v].verifier_valid + " (and";
66         for i := 0 to len(deps) do
67                 str += " " + deps[i];
68         str += " (= " + ctx.variables[v].verifier_name + " " + ops + "))))";
69         z3_eval_smtlib2_string_noret(str);
72 fn assert_instruction(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, v : int) : (z3_world, context)
74         var ins := ctx.instrs[ctx.variables[v].defining_instr];
75         if ins.opcode = P_BinaryOp or ins.opcode = P_BinaryConstOp then [
76                 var t := ctx.variables[ins.params[3]].type_index;
77                 var is_bool := t = T_AlwaysFlatOption;
78                 var is_int := t <= T_Integer and t >= T_Integer128;
79                 var is_real := t <= T_Real16 and t >= T_Real128;
80                 var op_z3 : bytes;
81                 var op := ins.params[0];
82                 if op = Bin_Add                                 then op_z3 := "+";
83                 else if op = Bin_Subtract                       then op_z3 := "-";
84                 else if op = Bin_Multiply                       then op_z3 := "*";
85                 else if op = Bin_Divide_Int                     then op_z3 := "div";
86                 else if op = Bin_Divide_Real                    then op_z3 := "/";
87                 else if op = Bin_Modulo                         then op_z3 := "rem";
88                 else if op = Bin_And, is_bool                   then op_z3 := "and";
89                 else if op = Bin_Or, is_bool                    then op_z3 := "or";
90                 else if op = Bin_Xor, is_bool                   then op_z3 := "xor";
91                 else if op = Bin_Equal                          then op_z3 := "=";
92                 else if op = Bin_NotEqual                       then op_z3 := "distinct";
93                 else if op = Bin_Less, is_int or is_real        then op_z3 := "<";
94                 else if op = Bin_LessEqual, is_int or is_real   then op_z3 := "<=";
95                 else if op = Bin_Greater, is_int or is_real     then op_z3 := ">";
96                 else if op = Bin_GreaterEqual, is_int or is_real then op_z3 := ">=";
97                 else if op = Bin_LessEqual, is_bool             then op_z3 := "=>";
98                 else return ctx;
99                 var var1 var1v var2 var2v : bytes;
100                 //eval debug("P_BinaryOp: " + ntos(ins.params[0]) + ", " + ntos(ins.params[1]) + ", " + ntos(ins.params[2]) + ", " + ntos(ins.params[3]) + ", " + ntos(ins.params[4]) + ", " + ntos(ins.params[5]));
101                 ctx, var1, var1v := allocate_variable(ctx, ins.params[3]);
102                 if len(var1) = 0 then
103                         return ctx;
104                 if ins.opcode = P_BinaryOp then [
105                         ctx, var2, var2v := allocate_variable(ctx, ins.params[5]);
106                         if len(var2) = 0 then
107                                 return ctx;
108                         gen_assert(ctx, v, [var1v, var2v], "(" + op_z3 + " " + var1 + " " + var2 + ")");
109                 ] else [
110                         if is_bool then
111                                 var2 := select(ins.params[4] <> 0, "false", "true");
112                         else
113                                 var2 := ntos(ins.params[4]);
114                         gen_assert(ctx, v, [var1v], "(" + op_z3 + " " + var1 + " " + var2 + ")");
115                 ]
116                 return ctx;
117         ]
118         if ins.opcode = P_UnaryOp then [
119                 var t := ctx.variables[v].type_index;
120                 var is_bool := t = T_AlwaysFlatOption;
121                 var is_int := t <= T_Integer and t >= T_Integer128;
122                 var is_real := t <= T_Real16 and t >= T_Real128;
123                 var op_z3 : bytes;
124                 var op := ins.params[0];
125                 if op = Un_Not, is_bool                         then op_z3 := "not";
126                 else if op = Un_Neg, is_int or is_real          then op_z3 := "-";
127                 else return ctx;
128                 var var1 var1v : bytes;
129                 ctx, var1, var1v := allocate_variable(ctx, ins.params[3]);
130                 if len(var1) = 0 then
131                         return ctx;
132                 gen_assert(ctx, v, [var1v], "(" + op_z3 + " " + var1 + ")");
133                 return ctx;
134         ]
135         if ins.opcode = P_Copy then [
136                 var var1 var1v : bytes;
137                 ctx, var1, var1v := allocate_variable(ctx, ins.params[2]);
138                 if len(var1) = 0 then
139                         return ctx;
140                 gen_assert(ctx, v, [var1v], var1);
141                 return ctx;
142         ]
143         if ins.opcode = P_Load_Const then [
144                 var cnst : bytes;
145                 var l := blob_to_int(ins.params[1 ..]);
146                 var t := ctx.variables[v].type_index;
147                 if t = T_AlwaysFlatOption then
148                         cnst := select(l <> 0, "false", "true");
149                 else if t <= T_Integer and t >= T_Integer128 then
150                         cnst := ntos(l);
151                 else
152                         return ctx;
153                 gen_assert(ctx, v, empty(bytes), cnst);
154                 return ctx;
155         ]
156         if ins.opcode = P_Return_Vars then [
157                 var new_v : int;
158                 for i := 0 to len(ins.params) do [
159                         if ins.params[i] = v then [
160                                 new_v := ctx.return_ins.params[1 + 2 * i];
161                                 goto found_new_v;
162                         ]
163                 ]
164                 abort internal("P_Return_Vars parameter not found");
165 found_new_v:
166                 var var1 var1v : bytes;
167                 ctx, var1, var1v := allocate_variable(ctx, new_v);
168                 if len(var1) = 0 then
169                         return ctx;
170                 gen_assert(ctx, v, [var1v], var1);
171                 return ctx;
172         ]
173         //eval debug("opcode: " + ntos(ins.opcode) + " (" + ctx.variables[v].verifier_name + ")");
174         return ctx;
177 fn allocate_variable(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, v : int) : (z3_world, context, bytes, bytes)
179         if not len_greater_than(ctx.variables[v].verifier_name, 0) then [
180                 var t := get_z3_type(ctx, v);
181                 if not len_greater_than(t, 0) then
182                         return ctx, "", "";
183                 var n, m := get_z3_name(ctx, v);
184                 ctx.variables[v].verifier_name := n;
185                 ctx.variables[v].verifier_valid := m;
186                 z3_eval_smtlib2_string_noret("(declare-const " + n + " " + t + ")");
187                 z3_eval_smtlib2_string_noret("(declare-const " + m + " Bool)");
188                 ctx := assert_instruction(ctx, v);
189         ]
190         return ctx, ctx.variables[v].verifier_name, ctx.variables[v].verifier_valid;
193 fn verify_function(ctx : context) : unit_type
195         var b : bytes;
196         implicit var z3w := z3_mk_world;
197         implicit var z3ctx := z3_mk_context();
198         //eval debug("verify function " + ctx.name);
200         for bgi := 0 to len(ctx.blocks) do [
201                 if not ctx.blocks[bgi].active then
202                         continue;
203                 for ili := 0 to len(ctx.blocks[bgi].instrs) do [
204                         var igi := ctx.blocks[bgi].instrs[ili];
205                         var ins := ctx.instrs[igi];
206                         if ins.opcode = P_Return then [
207                                 ctx.return_ins := ins;
208                         ]
209                 ]
210                 var str : bytes;
211                 ctx, str := get_z3_bb_name(ctx, bgi);
212                 z3_eval_smtlib2_string_noret("(declare-const " + str + " Bool)");
213         ]
215         for bgi := 0 to len(ctx.blocks) do [
216                 if not ctx.blocks[bgi].active then
217                         continue;
218                 var assumes := "";
219                 var claims := "";
220                 var guards := "";
221                 var all_phis := "";
222                 for i := 0 to len(ctx.blocks[bgi].pred_list) do [
223                         var p := ctx.blocks[bgi].pred_list[i];
224                         if len_greater_than(ctx.blocks[p].instrs, 0) then [
225                                 var ins := ctx.instrs[ctx.blocks[p].instrs[len(ctx.blocks[p].instrs) - 1]];
226                                 if ins.opcode = P_Jmp_False then [
227                                         var pos := ctx.blocks[bgi].pred_position[i];
228                                         var var1 var1v : bytes;
229                                         ctx, var1, var1v := allocate_variable(ctx, ins.params[0]);
230                                         if pos = 0 then
231                                                 guards += " (and " + var1v + " " + var1 + ")";
232                                         else if pos = 1 then
233                                                 guards += " (and " + var1v + " (not " + var1 + "))";
234                                         else
235                                                 guards += " false";
236                                 ]
237                         ]
238                         var phis := "";
239                         for ili := 0 to len(ctx.blocks[bgi].instrs) do [
240                                 var igi := ctx.blocks[bgi].instrs[ili];
241                                 var ins := ctx.instrs[igi];
242                                 if ins.opcode <> P_Phi then
243                                         break;
244                                 var phi1 phi1v var1 var1v : bytes;
245                                 ctx, phi1, phi1v := allocate_variable(ctx, ins.params[0]);
246                                 ctx, var1, var1v := allocate_variable(ctx, ins.params[i + 1]);
247                                 phis += " (= " + phi1v + " (and " + var1v + " (= " + phi1 + " " + var1 + ")))";
248                         ]
249                         if len_greater_than(phis, 0) then [
250                                 all_phis += " (and" + phis + ")";
251                         ]
252                 ]
253                 if len_greater_than(all_phis, 0) then
254                         assumes += " (or" + all_phis + ")";
255                 if len_greater_than(guards, 0) then
256                         assumes += " (or" + guards + ")";
257                 for ili := 0 to len(ctx.blocks[bgi].instrs) do [
258                         var igi := ctx.blocks[bgi].instrs[ili];
259                         var ins := ctx.instrs[igi];
260                         if ins.opcode = P_Assume then [
261                                 var var1 var1v : bytes;
262                                 ctx, var1, var1v := allocate_variable(ctx, ins.params[0]);
263                                 if len_greater_than(var1, 0) then
264                                         assumes += " (=> " + var1v + " " + var1 + ")";
265                         ] else if ins.opcode = P_Claim then [
266                                 var var1 var1v : bytes;
267                                 ctx, var1, var1v := allocate_variable(ctx, ins.params[0]);
268                                 if len_greater_than(var1, 0) then [
269                                         claims += " (=> " + var1v + " " + var1 + ")";
270                                 ]
271                         ]
272                 ]
273                 for i := 0 to len(ctx.blocks[bgi].post_list) do [
274                         var p := ctx.blocks[bgi].post_list[i];
275                         claims += " " + ctx.blocks[p].verifier_name;
276                 ]
277                 if len(assumes) = 0 then
278                         assumes := " true";
279                 if len(claims) = 0 then
280                         claims := " true";
281                 z3_eval_smtlib2_string_noret("(assert (= " + ctx.blocks[bgi].verifier_name + " (=> (and" + assumes + ") (and" + claims + "))))");
282         ]
284         z3_eval_smtlib2_string_noret("(assert (not " + ctx.blocks[0].verifier_name + "))");
285         b := z3_eval_smtlib2_string("(check-sat)");
286         if list_begins_with(b, "unsat") then
287                 return unit_value;
288         if list_begins_with(b, "sat") then [
289                 b := z3_eval_smtlib2_string("(get-model)");
290                 b := "Verification of function " + ctx.name + " failed:" + nl + b;
291                 return exception_make_str(unit_type, ec_async, error_compiler_error, 0, false, b);
292         ]
293         b := "Verification of function " + ctx.name + " inconclusive";
294         return exception_make_str(unit_type, ec_async, error_compiler_error, 0, false, b);