return unit_type from verify_function, so that it can modify the context
[ajla.git] / stdlib / compiler / optimize / verify.ajla
blob0f1d1032951a346f0dd229ef81a7027e7db1ee11
1 {*
2  * Copyright (C) 2025 Mikulas Patocka
3  *
4  * This file is part of Ajla.
5  *
6  * Ajla is free software: you can redistribute it and/or modify it under the
7  * terms of the GNU General Public License as published by the Free Software
8  * Foundation, either version 3 of the License, or (at your option) any later
9  * version.
10  *
11  * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13  * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License along with
16  * Ajla. If not, see <https://www.gnu.org/licenses/>.
17  *}
19 private unit compiler.optimize.verify;
21 uses compiler.optimize.defs;
23 fn verify_function(ctx : context) : unit_type;
25 implementation
27 uses exception;
28 uses z3;
29 uses compiler.common.blob;
31 fn allocate_variable(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, v : int) : (z3_world, context, bytes);
33 fn get_z3_name(ctx : context, v : int) : bytes
35         var n := "var_" + ntos_base(v, 16);
36         if len_greater_than(ctx.variables[v].name, 0) then
37                 n += "_" + ctx.variables[v].name;
38         return n;
41 fn get_z3_type(ctx : context, v : int) : bytes
43         if ctx.variables[v].type_index = T_AlwaysFlatOption then
44                 return "Bool";
45         if ctx.variables[v].type_index <= T_Integer, ctx.variables[v].type_index >= T_Integer128 then
46                 return "Int";
47         if ctx.variables[v].type_index <= T_Real16, ctx.variables[v].type_index >= T_Real128 then
48                 return "Real";
49         return "";
52 fn assert_instruction(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, v : int) : (z3_world, context)
54         var ins := ctx.instrs[ctx.variables[v].defining_instr];
55         if ins.opcode = P_BinaryOp or ins.opcode = P_BinaryConstOp then [
56                 var t := ctx.variables[ins.params[3]].type_index;
57                 var is_bool := t = T_AlwaysFlatOption;
58                 var is_int := t <= T_Integer and t >= T_Integer128;
59                 var is_real := t <= T_Real16 and t >= T_Real128;
60                 var op_z3 : bytes;
61                 var op := ins.params[0];
62                 if op = Bin_Add                                 then op_z3 := "+";
63                 else if op = Bin_Subtract                       then op_z3 := "-";
64                 else if op = Bin_Multiply                       then op_z3 := "*";
65                 else if op = Bin_Divide_Int                     then op_z3 := "div";
66                 else if op = Bin_Divide_Real                    then op_z3 := "/";
67                 else if op = Bin_Modulo                         then op_z3 := "rem";
68                 else if op = Bin_And, is_bool                   then op_z3 := "and";
69                 else if op = Bin_Or, is_bool                    then op_z3 := "or";
70                 else if op = Bin_Xor, is_bool                   then op_z3 := "xor";
71                 else if op = Bin_Equal                          then op_z3 := "=";
72                 else if op = Bin_NotEqual                       then op_z3 := "distinct";
73                 else if op = Bin_Less, is_int or is_real        then op_z3 := "<";
74                 else if op = Bin_LessEqual, is_int or is_real   then op_z3 := "<=";
75                 else if op = Bin_Greater, is_int or is_real     then op_z3 := ">";
76                 else if op = Bin_GreaterEqual, is_int or is_real then op_z3 := ">=";
77                 else if op = Bin_LessEqual, is_bool             then op_z3 := "=>";
78                 else return ctx;
79                 var var1 var2 : bytes;
80                 //eval debug("P_BinaryOp: " + ntos(ins.params[0]) + ", " + ntos(ins.params[1]) + ", " + ntos(ins.params[2]) + ", " + ntos(ins.params[3]) + ", " + ntos(ins.params[4]) + ", " + ntos(ins.params[5]));
81                 ctx, var1 := allocate_variable(ctx, ins.params[3]);
82                 if ins.opcode = P_BinaryOp then [
83                         ctx, var2 := allocate_variable(ctx, ins.params[5]);
84                 ] else [
85                         if is_bool then
86                                 var2 := select(ins.params[4] <> 0, "false", "true");
87                         else
88                                 var2 := ntos(ins.params[4]);
89                 ]
90                 if len_greater_than(var1, 0), len_greater_than(var2, 0) then [
91                         z3_eval_smtlib2_string_noret("(assert (= " + ctx.variables[v].verifier_name + " (" + op_z3 + " " + var1 + " " + var2 + ")))");
92                 ]
93                 return ctx;
94         ]
95         if ins.opcode = P_UnaryOp then [
96                 var t := ctx.variables[v].type_index;
97                 var is_bool := t = T_AlwaysFlatOption;
98                 var is_int := t <= T_Integer and t >= T_Integer128;
99                 var is_real := t <= T_Real16 and t >= T_Real128;
100                 var op_z3 : bytes;
101                 var op := ins.params[0];
102                 if op = Un_Not, is_bool                         then op_z3 := "not";
103                 else if op = Un_Neg, is_int or is_real          then op_z3 := "-";
104                 else return ctx;
105                 var var1 : bytes;
106                 ctx, var1 := allocate_variable(ctx, ins.params[3]);
107                 if len_greater_than(var1, 0) then [
108                         z3_eval_smtlib2_string_noret("(assert (= " + ctx.variables[v].verifier_name + " (" + op_z3 + " " + var1 + ")))");
109                 ]
110                 return ctx;
111         ]
112         if ins.opcode = P_Copy then [
113                 var new_var : bytes;
114                 ctx, new_var := allocate_variable(ctx, ins.params[2]);
115                 if len_greater_than(new_var, 0) then
116                         z3_eval_smtlib2_string_noret("(assert (= " + ctx.variables[v].verifier_name + " " + new_var + "))");
117                 return ctx;
118         ]
119         if ins.opcode = P_Load_Const then [
120                 var cnst : bytes;
121                 var l := blob_to_int(ins.params[1 ..]);
122                 var t := ctx.variables[v].type_index;
123                 if t = T_AlwaysFlatOption then
124                         cnst := select(l <> 0, "false", "true");
125                 else if t <= T_Integer and t >= T_Integer128 then
126                         cnst := ntos(l);
127                 else
128                         return ctx;
129                 z3_eval_smtlib2_string_noret("(assert (= " + ctx.variables[v].verifier_name + " " + cnst + "))");
130                 return ctx;
131         ]
132         if ins.opcode = P_Return_Vars then [
133                 var new_v : int;
134                 for i := 0 to len(ins.params) do [
135                         if ins.params[i] = v then [
136                                 new_v := ctx.return_ins.params[1 + 2 * i];
137                                 goto found_new_v;
138                         ]
139                 ]
140                 abort internal("P_Return_Vars parameter not found");
141 found_new_v:
142                 var new_var : bytes;
143                 ctx, new_var := allocate_variable(ctx, new_v);
144                 if len_greater_than(new_var, 0) then
145                         z3_eval_smtlib2_string_noret("(assert (= " + ctx.variables[v].verifier_name + " " + new_var + "))");
146                 return ctx;
147         ]
148         //eval debug("opcode: " + ntos(ins.opcode) + " (" + ctx.variables[v].verifier_name + ")");
149         return ctx;
152 fn allocate_variable(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, v : int) : (z3_world, context, bytes)
154         if not len_greater_than(ctx.variables[v].verifier_name, 0) then [
155                 var t := get_z3_type(ctx, v);
156                 if not len_greater_than(t, 0) then
157                         return ctx, "";
158                 var n := get_z3_name(ctx, v);
159                 ctx.variables[v].verifier_name := n;
160                 z3_eval_smtlib2_string_noret("(declare-const " + n + " " + t + ")");
161                 ctx := assert_instruction(ctx, v);
162         ]
163         return ctx, ctx.variables[v].verifier_name;
166 fn get_cond_guards(implicit z3w : z3_world, implicit z3ctx : z3_context, ctx : context, bgi : int) : (z3_world, context, bytes)
168         var dom_set := ctx.blocks[bgi].dom;
169         var all_preds := "";
170         while dom_set <> 0 do [
171                 var dom : int := bsr dom_set;
172                 dom_set btr= dom;
173                 var bb := ctx.blocks[dom];
174                 if not bb.active then
175                         abort internal("dominator is not active");
176                 if len(bb.pred_list) = 0 then
177                         goto next_dom;
178                 var pred_string := "";
179                 for i := 0 to len(bb.pred_list) do [
180                         if bb.pred_position[i] = 2 then
181                                 goto next_dom;
182                         var bb_pred := ctx.blocks[bb.pred_list[i]];
183                         if len(bb_pred.instrs) = 0 then
184                                 goto next_dom;
185                         var lins := ctx.instrs[bb_pred.instrs[len(bb_pred.instrs) - 1]];
186                         if lins.opcode <> P_Jmp_False then
187                                 goto next_dom;
188                         var vgic := lins.params[0];
189                         var vgic_name : bytes;
190                         ctx, vgic_name := allocate_variable(ctx, vgic);
191                         if not len_greater_than(vgic_name, 0) then
192                                 goto next_dom;
193                         if len_greater_than(pred_string, 0) then [
194                                 pred_string +<= ' ';
195                         ] else [
196                                 pred_string := "(or ";
197                         ]
198                         if bb.pred_position[i] = 0 then [
199                                 pred_string += vgic_name;
200                         ] else [
201                                 pred_string += "(not " + vgic_name + ")";
202                         ]
203                 ]
204                 if not len_greater_than(pred_string, 0) then
205                         goto next_dom;
206                 pred_string +<= ')';
207                 //eval debug("pred_string: " + pred_string);
208                 if len_greater_than(all_preds, 0) then [
209                         all_preds +<= ' ';
210                 ] else [
211                         all_preds := "(and ";
212                 ]
213                 all_preds += pred_string;
214 next_dom:
215         ]
216         if not len_greater_than(all_preds, 0) then
217                 return ctx, " ";
218         all_preds +<= ')';
219         //eval debug("all_preds: " + all_preds);
220         return ctx, all_preds;
223 fn verify_function(ctx : context) : unit_type
225         var claims := "";
226         var b : bytes;
227         implicit var z3w := z3_mk_world;
228         implicit var z3ctx := z3_mk_context();
229         //eval debug("verify function " + ctx.name);
231         for bgi := 0 to len(ctx.blocks) do [
232                 if not ctx.blocks[bgi].active then
233                         continue;
234                 for ili := 0 to len(ctx.blocks[bgi].instrs) do [
235                         var igi := ctx.blocks[bgi].instrs[ili];
236                         var ins := ctx.instrs[igi];
237                         if ins.opcode = P_Return then [
238                                 ctx.return_ins := ins;
239                                 goto found_ret;
240                         ]
241                 ]
242         ]
243 found_ret:
245         for bgi := 0 to len(ctx.blocks) do [
246                 if not ctx.blocks[bgi].active then
247                         continue;
248                 var cond_guards := "";
249                 for ili := 0 to len(ctx.blocks[bgi].instrs) do [
250                         var igi := ctx.blocks[bgi].instrs[ili];
251                         var ins := ctx.instrs[igi];
252                         if ins.opcode = P_Assume then [
253                                 //eval debug("assume");
254                                 var str : bytes;
255                                 ctx, str := allocate_variable(ctx, ins.params[0]);
256                                 if len_greater_than(str, 0) then
257                                         z3_eval_smtlib2_string_noret("(assert " + str + ")");
258                         ] else if ins.opcode = P_Claim then [
259                                 //eval debug("claim");
260                                 if not len_greater_than(cond_guards, 0) then
261                                         ctx, cond_guards := get_cond_guards(ctx, bgi);
262                                 var str : bytes;
263                                 ctx, str := allocate_variable(ctx, ins.params[0]);
264                                 if len_greater_than(str, 0) then [
265                                         if cond_guards[0] <> ' ' then
266                                                 claims += " (=> " + cond_guards + " " + str + ")";
267                                         else
268                                                 claims += " " + str;
269                                 ]
270                         ]
271                 ]
272         ]
274         z3_eval_smtlib2_string_noret("(assert (not (and true" + claims + ")))");
275         b := z3_eval_smtlib2_string("(check-sat)");
276         if list_begins_with(b, "unsat") then
277                 return unit_value;
278         if list_begins_with(b, "sat") then [
279                 b := z3_eval_smtlib2_string("(get-model)");
280                 b := "Verification of function " + ctx.name + " failed:" + nl + b;
281                 return exception_make_str(unit_type, ec_async, error_compiler_error, 0, false, b);
282         ]
283         b := "Verification of function " + ctx.name + " inconclusive";
284         return exception_make_str(unit_type, ec_async, error_compiler_error, 0, false, b);