ease the proof of coincidence count
[why3.git] / plugins / strategies / forward_propagation.ml
blob6e1441f950017f2f8ddc54ad5a4c60361e6f20f7
1 (********************************************************************)
2 (* *)
3 (* The Why3 Verification Platform / The Why3 Development Team *)
4 (* Copyright 2010-2022 -- Inria - CNRS - Paris-Saclay University *)
5 (* *)
6 (* This software is distributed under the terms of the GNU Lesser *)
7 (* General Public License version 2.1, with the special exception *)
8 (* on linking described in file LICENSE. *)
9 (* *)
10 (********************************************************************)
12 open Why3
13 open Strategy
14 open Term
15 open Decl
16 open Ty
17 open Theory
18 open Ident
19 open Task
20 open Format
(* Symbols of one "unbounded float" theory (ufloat's USingle or UDouble): the
   float type, its conversion to reals, the rounded operators in both named and
   infix/prefix form, and the two error constants. [eps] is used below as the
   relative error bound of one operation and [eta] as an additive (absolute)
   error term. *)
type ufloat_symbols = {
  ufloat_type : tysymbol;
  to_real : lsymbol;
  uadd : lsymbol;
  usub : lsymbol;
  umul : lsymbol;
  udiv : lsymbol;
  uminus : lsymbol;
  udiv_exact : lsymbol;
  uadd_infix : lsymbol;
  usub_infix : lsymbol;
  umul_infix : lsymbol;
  udiv_infix : lsymbol;
  uminus_prefix : lsymbol;
  udiv_exact_infix : lsymbol;
  eps : term;
  eta : term;
}
(* All the symbols used by the strategy: operators and functions of the real
   theories (Real, RealInfix, Abs, ExpLog, Trigonometry), the symbols of the
   single and double precision ufloat theories, and the printers used to turn
   terms into strings for the generated transformations. *)
type symbols = {
  add : lsymbol;
  sub : lsymbol;
  mul : lsymbol;
  _div : lsymbol;
  minus : lsymbol;
  add_infix : lsymbol;
  sub_infix : lsymbol;
  mul_infix : lsymbol;
  div_infix : lsymbol;
  minus_infix : lsymbol;
  lt : lsymbol;
  lt_infix : lsymbol;
  le : lsymbol;
  le_infix : lsymbol;
  abs : lsymbol;
  exp : lsymbol;
  log : lsymbol;
  log2 : lsymbol;
  log10 : lsymbol;
  sin : lsymbol;
  cos : lsymbol;
  usingle_symbols : ufloat_symbols;
  udouble_symbols : ufloat_symbols;
  ident_printer : ident_printer;
  tv_printer : ident_printer;
}
(* Global symbol table. It is [None] until it is filled in (presumably by
   [init_symbols]). *)
let symbols = ref None

(* [!!r] dereferences an option reference that must already be set.
   Raises [Invalid_argument] (from [Option.get]) if it is still [None]. *)
let ( !! ) s = Option.get !s
(* A term [t] having a forward error [e] means
   [|t - e.exact| <= e.rel * e.factor + e.cst], where [|e.exact| <= e.factor]. *)
type forward_error = {
  exact : term;
  factor : term;
  rel : term;
  cst : term;
}
(* This type corresponds to the numeric info we have on a real/float term *)
type term_info = {
  (* Forward error known (or computed) for the term *)
  error : forward_error option;
  (* Error computed by the strategy, with the formula stating it and the
     strategy proving that formula *)
  computed_error : (forward_error * term * proof_tree) option;
  (* "Some (op, [x; y])" means that the term "t" is the result of the FP
     operation "op" on "x" and "y" *)
  ieee_op : (lsymbol * term list) option;
}
(* Information collected from the task and computed by the strategy *)
type info = {
  (* Numeric info on each real/float term, indexed by the term *)
  terms_info : term_info Mterm.t;
  (* fns_info is used for the sum propagation lemma *)
  fns_info : forward_error Mls.t;
  (* ls_defs is used to store logic definitions *)
  ls_defs : term Mls.t;
}
(* Strategy used for the subgoals the strategy does not handle itself *)
let default_strat () = Sdo_nothing

(* The real literal 0.0 *)
let zero =
  t_const
    (Constant.ConstReal
       (Number.real_literal ~radix:10 ~neg:false ~int:"0" ~frac:"0" ~exp:None))
    ty_real

(* The real literal 1.0 *)
let one =
  t_const
    (Constant.ConstReal
       (Number.real_literal ~radix:10 ~neg:false ~int:"1" ~frac:"0" ~exp:None))
    ty_real

(* Syntactic tests (via [t_equal]) against the literals above *)
let is_zero t = t_equal zero t
let is_one t = t_equal one t
(* Builds the real term [abs t], returning [t] itself when it is already an
   application of [abs] *)
let abs t =
  match t.t_node with
  (* Don't add an abs symbol on top of another *)
  | Tapp (ls, [ _ ]) when ls_equal !!symbols.abs ls -> t
  | _ -> fs_app !!symbols.abs [ t ] ty_real
(* Recognisers for the symbols stored in [!!symbols]. Real operators exist in
   two versions (theories Real and RealInfix). Ufloat operators exist in a
   named and an infix/prefix version, for both single and double precision. *)

(* "<" or "<=" over reals *)
let is_ineq_ls ls =
  let symbols = !!symbols in
  ls_equal ls symbols.lt || ls_equal ls symbols.le
  || ls_equal ls symbols.lt_infix
  || ls_equal ls symbols.le_infix

let is_add_ls ls = ls_equal ls !!symbols.add || ls_equal ls !!symbols.add_infix
let is_sub_ls ls = ls_equal ls !!symbols.sub || ls_equal ls !!symbols.sub_infix
let is_mul_ls ls = ls_equal ls !!symbols.mul || ls_equal ls !!symbols.mul_infix

let is_minus_ls ls =
  ls_equal ls !!symbols.minus || ls_equal ls !!symbols.minus_infix

let is_abs_ls ls = ls_equal ls !!symbols.abs

let is_to_real_ls ls =
  ls_equal ls !!symbols.usingle_symbols.to_real
  || ls_equal ls !!symbols.udouble_symbols.to_real

let is_uadd_ls ls =
  ls_equal ls !!symbols.usingle_symbols.uadd
  || ls_equal ls !!symbols.usingle_symbols.uadd_infix
  || ls_equal ls !!symbols.udouble_symbols.uadd
  || ls_equal ls !!symbols.udouble_symbols.uadd_infix

let is_usub_ls ls =
  ls_equal ls !!symbols.usingle_symbols.usub
  || ls_equal ls !!symbols.usingle_symbols.usub_infix
  || ls_equal ls !!symbols.udouble_symbols.usub
  || ls_equal ls !!symbols.udouble_symbols.usub_infix

let is_umul_ls ls =
  ls_equal ls !!symbols.usingle_symbols.umul
  || ls_equal ls !!symbols.usingle_symbols.umul_infix
  || ls_equal ls !!symbols.udouble_symbols.umul
  || ls_equal ls !!symbols.udouble_symbols.umul_infix

let is_udiv_ls ls =
  ls_equal ls !!symbols.usingle_symbols.udiv
  || ls_equal ls !!symbols.usingle_symbols.udiv_infix
  || ls_equal ls !!symbols.udouble_symbols.udiv
  || ls_equal ls !!symbols.udouble_symbols.udiv_infix

let is_uminus_ls ls =
  ls_equal ls !!symbols.usingle_symbols.uminus
  || ls_equal ls !!symbols.usingle_symbols.uminus_prefix
  || ls_equal ls !!symbols.udouble_symbols.uminus
  || ls_equal ls !!symbols.udouble_symbols.uminus_prefix

let is_uexact_div_ls ls =
  ls_equal ls !!symbols.usingle_symbols.udiv_exact
  || ls_equal ls !!symbols.usingle_symbols.udiv_exact_infix
  || ls_equal ls !!symbols.udouble_symbols.udiv_exact
  || ls_equal ls !!symbols.udouble_symbols.udiv_exact_infix

(* Any of the ufloat operators above *)
let is_uop ls =
  is_uadd_ls ls || is_usub_ls ls || is_umul_ls ls || is_udiv_ls ls
  || is_uminus_ls ls || is_uexact_div_ls ls
(* Real term constructors (built with the RealInfix symbols). The [_simp]
   variants perform purely syntactic simplifications with 0 and 1. *)

let minus t = fs_app !!symbols.minus_infix [ t ] ty_real

(* Negation that cancels a double negation *)
let minus_simp t =
  match t.t_node with
  | Tapp (ls, [ t' ]) when is_minus_ls ls -> t'
  | _ -> minus t

let add t1 t2 = fs_app !!symbols.add_infix [ t1; t2 ] ty_real

let add_simp t1 t2 =
  if is_zero t1 then
    t2
  else if is_zero t2 then
    t1
  else
    add t1 t2

let sub t1 t2 = fs_app !!symbols.sub_infix [ t1; t2 ] ty_real

let sub_simp t1 t2 =
  if is_zero t1 then
    minus_simp t2
  else if is_zero t2 then
    t1
  else
    sub t1 t2

let mul t1 t2 = fs_app !!symbols.mul_infix [ t1; t2 ] ty_real
let div t1 t2 = fs_app !!symbols.div_infix [ t1; t2 ] ty_real

(* Besides 0 and 1, rewrites [|a| * |b|] into [|a * b|] *)
let mul_simp t1 t2 =
  if is_zero t1 || is_zero t2 then
    zero
  else if is_one t1 then
    t2
  else if is_one t2 then
    t1
  else
    match (t1.t_node, t2.t_node) with
    | Tapp (ls1, [ t1 ]), Tapp (ls2, [ t2 ]) when is_abs_ls ls1 && is_abs_ls ls2
      ->
        abs (mul t1 t2)
    | _ -> mul t1 t2

let div_simp t1 t2 =
  if is_zero t1 then
    zero
  else if is_one t2 then
    t1
  else
    div t1 t2

(* Infix notations over real terms; the doubled forms use the [_simp]
   constructors *)
let ( +. ) x y = add x y
let ( -. ) x y = sub x y
let ( *. ) x y = mul x y
let ( /. ) x y = div x y
let ( ++. ) x y = add_simp x y
let ( --. ) x y = sub_simp x y
let ( **. ) x y = mul_simp x y
let ( //. ) x y = div_simp x y
let ( <=. ) x y = ps_app !!symbols.le_infix [ x; y ]
(* True when [ty] is the single or double precision ufloat type *)
let is_ty_float ty =
  match ty.ty_node with
  | Tyapp (v, []) ->
      ts_equal v !!symbols.usingle_symbols.ufloat_type
      || ts_equal v !!symbols.udouble_symbols.ufloat_type
  | _ -> false
(* Returns the ufloat symbols (single or double precision) whose float type is
   [ts]. Raises [Failure] for any other type symbol. *)
let ufloat_symbols_of_ts ts =
  let symbols = !!symbols in
  if ts_equal ts symbols.usingle_symbols.ufloat_type then
    symbols.usingle_symbols
  else if ts_equal ts symbols.udouble_symbols.ufloat_type then
    symbols.udouble_symbols
  else
    failwith (asprintf "Unsupported type %a" Pretty.print_ts ts)

(* Relative error constant of the float type [ieee_type] *)
let eps ieee_type = (ufloat_symbols_of_ts ieee_type).eps

(* Absolute error constant of the float type [ieee_type] *)
let eta ieee_type = (ufloat_symbols_of_ts ieee_type).eta

(* [to_real ieee_type t] converts the float term [t] of type [ieee_type] to a
   real term. Failure is only raised once [t] is supplied. *)
let to_real ieee_type t =
  fs_app (ufloat_symbols_of_ts ieee_type).to_real [ t ] ty_real
(* Returns the info recorded for [t], or an empty info when there is none *)
let get_info info t =
  match Mterm.find t info with
  | t_info -> t_info
  | exception Not_found ->
      { error = None; computed_error = None; ieee_op = None }

(* Errors are indexed by the float term: strip a [to_real] wrapper if any *)
let strip_to_real t =
  match t.t_node with
  | Tapp (ls, [ t' ]) when is_to_real_ls ls -> t'
  | _ -> t

(* Records [error] as the forward error of [t] *)
let add_fw_error info t error =
  let t = strip_to_real t in
  let t_info = get_info info t in
  Mterm.add t { t_info with error = Some error } info

(* Records a computed error (forward error, formula, strategy) for [t]; the
   forward error also becomes the known error of [t] *)
let add_computed_fw_error info t ((fe, _, _) as error) =
  let t = strip_to_real t in
  let t_info = get_info info t in
  Mterm.add t { t_info with computed_error = Some error; error = Some fe } info

(* Records that [t] is the result of the IEEE operation [ls] on [args] *)
let add_ieee_op info ls t args =
  let t_info = get_info info t in
  Mterm.add t { t_info with ieee_op = Some (ls, args) } info
(* Returns the type symbol of the term [t]. Asserts that [t] has a type and
   that this type is a non-parametric type application. *)
let get_ts t =
  match t.t_ty with
  | None -> assert false
  | Some ty -> (
      match ty.ty_node with
      | Tyvar _ -> assert false
      | Tyapp (ts, []) -> ts
      | _ -> assert false)
(* Return the float terms inside `t`, in left-to-right order. Note that if `t`
   is an application of type float we don't return the floats that it
   potentially contains *)
(* TODO: Don't put duplicates in list !!! *)
let rec get_floats t =
  match t.t_ty with
  | Some ty when is_ty_float ty -> [ t ]
  | _ -> (
      match t.t_node with
      (* fold_right keeps the order while avoiding the quadratic [l @ ...] *)
      | Tapp (_, tl) -> List.fold_right (fun t l -> get_floats t @ l) tl []
      | _ -> [])
(* "single" or "double" depending on the ufloat type [ts], as used in lemma
   names. Raises [Failure] for any other type. *)
let string_of_ufloat_type ts =
  if ts_equal ts !!symbols.usingle_symbols.ufloat_type then
    "single"
  else if ts_equal ts !!symbols.udouble_symbols.ufloat_type then
    "double"
  else
    failwith (asprintf "Unsupported type %a" Pretty.print_ts ts)
(* Prefix of the ufloat lemma names for operator [uop] on the float type [ts],
   e.g. "uadd_single". Raises [Failure] for unsupported types or operators. *)
let string_of_ufloat_type_and_op ts uop =
  let ty_str = string_of_ufloat_type ts in
  let uop_str =
    if is_uadd_ls uop then
      "uadd"
    else if is_usub_ls uop then
      "usub"
    else if is_umul_ls uop then
      "umul"
    else if is_uexact_div_ls uop then
      (* Exact division lemmas are named udiv_exact_<type>_... *)
      "udiv_exact"
    else
      failwith (asprintf "Unsupported uop '%a'" Pretty.print_ls uop)
  in
  uop_str ^ "_" ^ ty_str
(* Prints [t] with the printers stored in [!!symbols]. It is used to build the
   textual arguments of the generated transformations. *)
let term_to_str t =
  let module P =
    (val Pretty.create !!symbols.ident_printer !!symbols.tv_printer
           !!symbols.ident_printer
           (Ident.create_ident_printer []))
  in
  Format.asprintf "%a" P.print_term t
(* Error on an IEEE op when propagation is needed (eg. we have a forward error
   on t1 and/or on t2).
   [r] is the result of [uop] (addition, subtraction or multiplication) on [t1]
   and [t2], whose forward errors are [e1] and [e2]. [strat_for_t1] and
   [strat_for_t2] prove those errors. Returns the updated terms info, the error
   formula on [r] and the strategy proving it. *)
let combine_uop_errors info uop t1 e1 t2 e2 r strat_for_t1 strat_for_t2 =
  let ts = get_ts r in
  let eps = eps ts in
  let eta = eta ts in
  let to_real = to_real ts in
  (* Any operator that is neither an addition nor a subtraction is handled as
     a multiplication *)
  let is_add_or_sub = is_uadd_ls uop || is_usub_ls uop in
  let rel_err, rel_err_simp =
    if is_add_or_sub then
      (* Relative error for addition and subtraction *)
      (e1.rel +. e2.rel +. eps, e1.rel ++. e2.rel ++. eps)
    else
      (* Relative error for multiplication *)
      ( eps +. ((e1.rel +. e2.rel +. (e1.rel *. e2.rel)) *. (one +. eps)),
        eps ++. ((e1.rel ++. e2.rel ++. (e1.rel **. e2.rel)) **. (one ++. eps))
      )
  in
  let cst_err, cst_err_simp =
    if is_add_or_sub then
      (* Constant error for addition and subtraction *)
      ( ((one +. eps +. e2.rel) *. e1.cst) +. ((one +. eps +. e1.rel) *. e2.cst),
        ((one ++. eps ++. e2.rel) **. e1.cst)
        ++. ((one ++. eps ++. e1.rel) **. e2.cst) )
    else
      (* Constant error for multiplication *)
      ( (((e2.cst +. (e2.cst *. e1.rel)) *. e1.factor)
         +. ((e1.cst +. (e1.cst *. e2.rel)) *. e2.factor)
         +. (e1.cst *. e2.cst))
        *. (one +. eps)
        +. eta,
        (((one ++. eps) **. (e2.cst ++. (e2.cst **. e1.rel))) **. e1.factor)
        ++. (((one ++. eps) **. (e1.cst ++. (e1.cst **. e2.rel))) **. e2.factor)
        ++. ((one ++. eps) **. e1.cst **. e2.cst)
        ++. eta )
  in
  (* Exact result, and the factor bounding its absolute value *)
  let exact, factor =
    if is_uadd_ls uop then
      (e1.exact +. e2.exact, e1.factor ++. e2.factor)
    else if is_usub_ls uop then
      (e1.exact -. e2.exact, e1.factor ++. e2.factor)
    else
      (e1.exact *. e2.exact, e1.factor **. e2.factor)
  in
  let total_err =
    if is_add_or_sub then
      (rel_err *. (e1.factor +. e2.factor)) +. cst_err
    else
      (rel_err *. (e1.factor *. e2.factor)) +. cst_err
  in
  (* The simplified bound is built from the simplified components, so that it
     states exactly the error [fw_err] recorded for [r] below *)
  let total_err_simp = (rel_err_simp **. factor) ++. cst_err_simp in
  let str = string_of_ufloat_type_and_op ts uop in
  let strat =
    Sapply_trans
      ( "apply",
        [
          str ^ "_error_propagation";
          "with";
          sprintf "%s,%s" (term_to_str t1) (term_to_str t2);
        ],
        (* The first two subgoals are the errors on t1 and t2; the seven
           remaining ones are left to the default strategy *)
        strat_for_t1 :: strat_for_t2 :: List.init 7 (fun _ -> default_strat ())
      )
  in
  let f = abs (to_real r -. exact) <=. total_err in
  let fw_err = { exact; rel = rel_err_simp; factor; cst = cst_err_simp } in
  let f, strat =
    if t_equal total_err total_err_simp then
      (f, strat)
    else
      (* Prove the lemma's bound first, then the simplified one from it *)
      let f_simp = abs (to_real r -. exact) <=. total_err_simp in
      ( f_simp,
        Sapply_trans ("assert", [ term_to_str f ], [ strat; default_strat () ])
      )
  in
  let info = add_computed_fw_error info r (fw_err, f, strat) in
  (info, f, strat)
(* Error on an IEEE op when no propagation is needed (eg. we don't have a
   forward error on t1 nor t2). The error on [r = uop t1 t2] is
   [eps * |exact|], plus [eta] for a multiplication. Returns the updated terms
   info, the error formula and a strategy that does nothing. *)
let basic_uop_error info uop r t1 t2 =
  let ts = get_ts r in
  let eps = eps ts in
  let eta = eta ts in
  let to_real = to_real ts in
  let exact =
    if is_uadd_ls uop then
      to_real t1 +. to_real t2
    else if is_usub_ls uop then
      to_real t1 -. to_real t2
    else
      to_real t1 *. to_real t2
  in
  (* Only multiplication has an additional absolute error term *)
  let cst_err =
    if is_umul_ls uop then
      eta
    else
      zero
  in
  let total_err = (eps *. abs exact) ++. cst_err in
  let info =
    add_fw_error info r { exact; rel = eps; factor = abs exact; cst = cst_err }
  in
  let f = abs (to_real r -. exact) <=. total_err in
  (info, f, default_strat ())
(* Generates the formula and the strat for the forward error of
   `r = uop t1 t2` (addition, subtraction or multiplication). Returns the
   updated terms info, the formula and the strat. *)
let use_ieee_thms info uop r t1 t2 strat_for_t1 strat_for_t2 =
  let t1_info = get_info info t1 in
  let t2_info = get_info info t2 in
  match (t1_info.error, t2_info.error) with
  (* No propagation needed, we use the basic lemma *)
  | None, None -> basic_uop_error info uop r t1 t2
  (* We have an error on at least one of t1 and t2, use the propagation lemma *)
  | _ ->
      let to_real = to_real (get_ts r) in
      (* An operand without forward error is treated as exact *)
      let get_err_or_default t t_info =
        match t_info.error with
        | None ->
            { exact = to_real t; rel = zero; factor = abs (to_real t); cst = zero }
        | Some e -> e
      in
      let e1 = get_err_or_default t1 t1_info in
      let e2 = get_err_or_default t2 t2_info in
      combine_uop_errors info uop t1 e1 t2 e2 r strat_for_t1 strat_for_t2
(* Returns None if the function is unsupported by the strategy. Otherwise,
   returns its symbol as well as its arguments, with any `to_real` wrapper
   removed from them. Supported functions are exp, log, log2, log10, sin and
   cos. The first parameter is unused. *)
let get_known_fn_and_args _t x =
  let is_known_fn ls =
    ls_equal ls !!symbols.log || ls_equal ls !!symbols.exp
    || ls_equal ls !!symbols.log2
    || ls_equal ls !!symbols.log10
    || ls_equal ls !!symbols.sin || ls_equal ls !!symbols.cos
  in
  match x.t_node with
  | Tapp (ls, args) when is_known_fn ls ->
      let strip_to_real arg =
        match arg.t_node with
        | Tapp (ls, [ t ]) when is_to_real_ls ls -> t
        | _ -> arg
      in
      Some (ls, List.map strip_to_real args)
  | _ -> None
525 (* Returns the forward error formula associated with the propagation lemma of
526 the function `exact_fn`. `app_approx` is the float term approximating the
   application of `exact_fn`, and `arg_approx` is the float term approximating
   its argument. Both must already have a forward error in `info.terms_info`;
   otherwise [Option.get] raises [Invalid_argument].
   The result is [(rel_err, rel_err_simp, cst_err, cst_err_simp, factor, nb)]:
   - the relative part of the error, plain and simplified;
   - the constant part, plain and simplified with the `++.`-style constructors;
   - the factor the relative error applies to;
   - `nb`, the number of extra subgoals that [apply_fn_thm] leaves to
     [Sdo_nothing] when applying the lemma.
   Only exp, log, log2, log10, sin and cos are supported; any other symbol
   reaches [assert false]. *)
527 let get_fn_errs info exact_fn app_approx arg_approx =
(* Forward errors on the argument and on the application *)
528 let e_arg = Option.get (get_info info.terms_info arg_approx).error in
529 let e_app = Option.get (get_info info.terms_info app_approx).error in
530 if ls_equal exact_fn !!symbols.exp then
(* exp: the constant error is the one of the application. The argument's
   error bound [rel * factor + cst] becomes a relative error through
   [a = exp (rel * factor + cst)]:
   rel_err = e_app.rel + (a - 1) * (1 + e_app.rel). The factor is exp(exact). *)
531 let cst_err = e_app.cst in
532 let a =
533 fs_app exact_fn [ (e_arg.rel *. e_arg.factor) +. e_arg.cst ] ty_real
535 let a_simp =
536 fs_app exact_fn [ (e_arg.rel **. e_arg.factor) ++. e_arg.cst ] ty_real
538 let rel_err = e_app.rel +. ((a -. one) *. (one +. e_app.rel)) in
539 let rel_err_simp =
540 e_app.rel ++. ((a_simp --. one) **. (one ++. e_app.rel))
542 ( rel_err,
543 rel_err_simp,
544 cst_err,
545 cst_err,
546 fs_app exact_fn [ e_arg.exact ] ty_real,
548 else if
549 ls_equal exact_fn !!symbols.log
550 || ls_equal exact_fn !!symbols.log2
551 || ls_equal exact_fn !!symbols.log10
552 then
(* log, log2, log10: the argument's error becomes a constant error through
   [a = log_b (1 - (rel * factor + cst) / exact)]:
   cst_err = -a * (1 + e_app.rel) + e_app.cst. The relative error is the one
   of the application, and the factor is |log_b exact|. *)
553 let a =
554 fs_app exact_fn
555 [ one -. (((e_arg.rel *. e_arg.factor) +. e_arg.cst) /. e_arg.exact) ]
556 ty_real
558 let a_simp =
(* When the factor is syntactically `exact` or `abs exact`, the quotient
   (rel * factor) / exact is simplified to rel.
   NOTE(review): this is only sound when exact is positive. That presumably
   holds since exact is the argument of a logarithm — confirm. *)
560 t_equal e_arg.exact e_arg.factor
561 || t_equal (abs e_arg.exact) e_arg.factor
562 then
563 fs_app exact_fn
564 [ one --. (e_arg.rel ++. (e_arg.cst //. e_arg.exact)) ]
565 ty_real
566 else
567 fs_app exact_fn
570 --. (((e_arg.rel **. e_arg.factor) ++. e_arg.cst) //. e_arg.exact);
572 ty_real
574 let cst_err = (minus a *. (one +. e_app.rel)) +. e_app.cst in
575 let cst_err_simp =
576 (minus_simp a_simp **. (one ++. e_app.rel)) ++. e_app.cst
578 let rel_err = e_app.rel in
579 ( rel_err,
580 rel_err,
581 cst_err,
582 cst_err_simp,
583 abs (fs_app exact_fn [ e_arg.exact ] ty_real),
585 else if ls_equal exact_fn !!symbols.sin || ls_equal exact_fn !!symbols.cos
586 then
(* sin, cos: the argument's error bound, scaled by (1 + e_app.rel), is added
   to the constant error of the application. The relative error is the one of
   the application, and the factor is |sin exact| or |cos exact|. *)
587 let cst_err =
588 (((e_arg.rel *. e_arg.factor) +. e_arg.cst) *. (one +. e_app.rel))
589 +. e_app.cst
591 let cst_err_simp =
592 (((e_arg.rel **. e_arg.factor) ++. e_arg.cst) **. (one ++. e_app.rel))
593 ++. e_app.cst
595 let rel_err = e_app.rel in
596 ( rel_err,
597 rel_err,
598 cst_err,
599 cst_err_simp,
600 abs (fs_app exact_fn [ e_arg.exact ] ty_real),
602 else
(* Unreachable through use_known_thm, which only calls apply_fn_thm for the
   six symbols handled above *)
603 assert false
(* Returns the forward error formula and the strat associated with the
   application of the propagation lemma for `fn`. The argument `strat`
   corresponds to the strat that is used to prove the error on `arg_approx` (the
   argument of the function). `arg_approx` must have a forward error in
   `info.terms_info`, otherwise [Option.get] raises [Invalid_argument].
   Returns the updated terms info, the formula and the strat. *)
let apply_fn_thm info fn app_approx arg_approx strat =
  let e_arg = Option.get (get_info info.terms_info arg_approx).error in
  let to_real = to_real (get_ts app_approx) in
  let exact = fs_app fn [ e_arg.exact ] ty_real in
  let fn_str = fn.ls_name.id_string in
  let ty_str = string_of_ufloat_type (get_ts app_approx) in
  let rel_err, rel_err_simp, cst_err, cst_err_simp, app', nb =
    get_fn_errs info fn app_approx arg_approx
  in
  (* The first subgoal is the error on the argument; the [nb] others are left
     untouched *)
  let strat =
    Sapply_trans
      ( "apply",
        [
          sprintf "%s_%s_error_propagation" fn_str ty_str;
          "with";
          term_to_str arg_approx;
        ],
        strat :: List.init nb (fun _ -> Sdo_nothing) )
  in
  let total_err = (rel_err *. app') +. cst_err in
  (* Same shape as [total_err] (relative error first), so that the [t_equal]
     test below succeeds when nothing was simplified. It states exactly the
     error recorded for [app_approx]. *)
  let total_err_simp = (rel_err_simp **. app') ++. cst_err_simp in
  let left = abs (to_real app_approx -. exact) in
  let f, strat =
    if t_equal total_err total_err_simp then
      (left <=. total_err, strat)
    else
      ( left <=. total_err_simp,
        Sapply_trans
          ( "assert",
            [ term_to_str (left <=. total_err) ],
            [ strat; default_strat () ] ) )
  in
  let term_info =
    add_computed_fw_error info.terms_info app_approx
      ( { exact; rel = rel_err_simp; factor = app'; cst = cst_err_simp },
        f,
        strat )
  in
  (term_info, Some f, strat)
(* Computes the error on [app_approx], an approximation of [fn] applied to
   [args], using the propagation lemma of [fn]. [strats] are the strategies
   proving the errors on [args] (it must not be empty). When no argument has a
   forward error there is nothing to propagate, and the first strategy is
   returned with no formula. Raises [Failure] for an unsupported [fn]. *)
let use_known_thm info app_approx fn args strats =
  if
    (* Nothing to do if none of the args have a forward error *)
    not
      (List.exists
         (fun arg -> Option.is_some (get_info info.terms_info arg).error)
         args)
  then
    (info.terms_info, None, List.hd strats)
  else if
    ls_equal fn !!symbols.exp || ls_equal fn !!symbols.log
    || ls_equal fn !!symbols.log2
    || ls_equal fn !!symbols.log10
    || ls_equal fn !!symbols.sin || ls_equal fn !!symbols.cos
  then
    apply_fn_thm info fn app_approx (List.hd args) (List.hd strats)
  else
    failwith (asprintf "Unsupported fn symbol '%a'" Pretty.print_ls fn)
(* Recursively unfold the definition of `t` if it has one. Stops when we find an
   error that we can use.
   Returns the info of `t`, completed with the first IEEE operation or forward
   error found along the chain of (constant) definitions stored in
   `info.ls_defs`. A definition that is itself a ufloat operation application
   also counts as an IEEE operation. *)
let update_term_info info t =
  let t_info = get_info info.terms_info t in
  let rec recurse t' =
    let t'_info = get_info info.terms_info t' in
    match t'_info.ieee_op with
    | None -> (
        match t'_info.error with
        | None -> (
            match t'.t_node with
            | Tapp (ls, []) ->
                if Mls.mem ls info.ls_defs then
                  recurse (Mls.find ls info.ls_defs)
                else
                  t_info
            | Tapp (ls, args) when is_uop ls ->
                { t_info with ieee_op = Some (ls, args) }
            | _ -> t_info)
        | Some _ as error -> { t_info with error })
    | Some _ as ieee_op -> { t_info with ieee_op }
  in
  recurse t
(*
 * Generate error formulas recursively for a term `t` using propagation lemmas.
 * This is recursive because if `t` is an approximation of a term `u` which
 * itself is an approximation of a term `v`, we first compute a formula for the
 * approximation of `v` by `u` and we combine it with the formula we already
 * have of the approximation of `u` by `t` to get a formula relating `t` to `v`.
 *
 * Returns `(terms_info, fmla, strat)`: the updated terms info, the error
 * formula obtained for `t` if any, and the strategy proving it.
 *)
let rec get_error_fmlas info t =
  let t_info = update_term_info info t in
  (* When a formula `f` was obtained, assert it and prove it with `s`;
     otherwise just use `s` *)
  let get_strat f s =
    match f with
    | Some f ->
        Sapply_trans ("assert", [ term_to_str f ], [ s; default_strat () ])
    | None -> s
  in
  match t_info.computed_error with
  | Some (_, f, strat) -> (info.terms_info, Some f, strat)
  | None -> (
      match t_info.ieee_op with
      (* `t` is the result of the IEEE minus operation *)
      | Some (ieee_op, [ x ]) when is_uminus_ls ieee_op -> (
          let terms_info, fmla, strat_for_x = get_error_fmlas info x in
          let strat = get_strat fmla strat_for_x in
          let to_real = to_real (get_ts t) in
          match (get_info terms_info x).error with
          (* No propagation needed *)
          | None -> (terms_info, None, strat)
          (* The error doesn't change since the float minus operation is exact *)
          | Some { exact; rel; factor; cst } ->
              let exact = minus exact in
              let f = abs (to_real t -. exact) <=. (rel **. factor) ++. cst in
              let terms_info =
                add_computed_fw_error terms_info t
                  ({ exact; rel; factor; cst }, f, strat)
              in
              (terms_info, Some f, strat))
      (* `t` is the result of an IEEE addition/subtraction/multiplication *)
      | Some (ieee_op, [ t1; t2 ])
        when is_uadd_ls ieee_op || is_usub_ls ieee_op || is_umul_ls ieee_op ->
          (* Get error formulas on subterms `t1` and `t2` *)
          let terms_info, fmla1, strat_for_t1 = get_error_fmlas info t1 in
          let terms_info, fmla2, strat_for_t2 =
            get_error_fmlas { info with terms_info } t2
          in
          let strat_for_t1 = get_strat fmla1 strat_for_t1 in
          let strat_for_t2 = get_strat fmla2 strat_for_t2 in
          let terms_info, f, strat =
            use_ieee_thms terms_info ieee_op t t1 t2 strat_for_t1 strat_for_t2
          in
          (terms_info, Some f, strat)
      (* `t` is the result of an IEEE exact division *)
      | Some (ieee_op, [ t1; t2 ]) when is_uexact_div_ls ieee_op -> (
          let ts = get_ts t2 in
          let eta = eta ts in
          let to_real = to_real ts in
          let terms_info, fmla1, strat_for_t1 = get_error_fmlas info t1 in
          let strat_for_t1 = get_strat fmla1 strat_for_t1 in
          match (get_info terms_info t1).error with
          (* No forward error on the dividend: nothing to propagate *)
          | None -> (terms_info, None, strat_for_t1)
          | Some e1 ->
              let abs_t2 = abs (to_real t2) in
              let fe =
                {
                  exact = e1.exact //. to_real t2;
                  rel = e1.rel;
                  factor = e1.factor //. abs_t2;
                  cst = (e1.cst //. abs_t2) ++. eta;
                }
              in
              let err =
                (e1.rel *. (e1.factor /. abs_t2)) +. ((e1.cst /. abs_t2) +. eta)
              in
              let err_simp =
                (e1.rel **. (e1.factor //. abs_t2))
                ++. ((e1.cst //. abs_t2) ++. eta)
              in
              let left = abs (to_real t -. (e1.exact //. to_real t2)) in
              let ty_str = string_of_ufloat_type ts in
              (* The first subgoal is the error on t1; the six remaining ones
                 are left to the default strategy *)
              let strat =
                Sapply_trans
                  ( "apply",
                    [
                      "udiv_exact_" ^ ty_str ^ "_error_propagation";
                      "with";
                      term_to_str t1;
                    ],
                    strat_for_t1 :: List.init 6 (fun _ -> default_strat ()) )
              in
              (* The lemma proves `left <=. err` directly; when the simplified
                 bound differs, assert the lemma's bound first *)
              let f, strat =
                if t_equal err err_simp then
                  (left <=. err, strat)
                else
                  ( left <=. err_simp,
                    Sapply_trans
                      ( "assert",
                        [ term_to_str (left <=. err) ],
                        [ strat; default_strat () ] ) )
              in
              let terms_info =
                add_computed_fw_error terms_info t (fe, f, strat)
              in
              (terms_info, Some f, strat))
      (* Other IEEE operations are not supported *)
      | Some _ -> (info.terms_info, None, default_strat ())
      | None -> (
          match t_info.error with
          (* `t` has a forward error, we look if it is the result of the
             approximation of a known function, in which case we use the
             function's propagation lemma to compute an error bound *)
          | Some e -> (
              match get_known_fn_and_args t e.exact with
              | Some (fn, args) ->
                  (* First we compute the potential forward errors of the
                     function args *)
                  let info, strats =
                    List.fold_left
                      (fun (info, l) arg ->
                        let terms_info, f, s = get_error_fmlas info arg in
                        ({ info with terms_info }, get_strat f s :: l))
                      (info, []) args
                  in
                  use_known_thm info t fn args (List.rev strats)
              | None -> (info.terms_info, None, default_strat ()))
          | None -> (info.terms_info, None, default_strat ())))
(* Decomposes the bound `t` of an inequality `|approx - exact| <= t` into a
   forward error of the form `rel * factor + cst`, where `factor` is a subterm
   of `t` recognised by `is_match`. The second component of `parse` tells
   whether the term is (a product containing) the factor itself. *)
let parse_error is_match exact t =
  (* If we don't find a "relative" error we have an absolute error of `t` *)
  let e_default = { exact; rel = zero; factor = abs exact; cst = t } in
  let rec parse t =
    if is_match t then
      ({ exact; rel = one; factor = t; cst = zero }, true)
    else
      match t.t_node with
      | Tapp (ls, [ t1; t2 ]) when is_add_ls ls ->
          let e1, _ = parse t1 in
          if is_zero e1.rel then
            let e2, _ = parse t2 in
            if is_zero e2.rel then
              (e_default, false)
            else
              (* FIXME: we should combine the factor of t1 and factor of t2 *)
              ({ e2 with cst = e2.cst ++. t1 }, false)
          else
            ({ e1 with cst = e1.cst ++. t2 }, false)
      | Tapp (ls, [ t1; t2 ]) when is_sub_ls ls ->
          (* FIXME : should be the same as for addition *)
          let e1, _ = parse t1 in
          if is_zero e1.rel then
            (e_default, false)
          else
            ({ e1 with cst = e1.cst --. t2 }, false)
      | Tapp (ls, [ t1; t2 ]) when is_mul_ls ls ->
          let e1, is_factor = parse t1 in
          if is_zero e1.rel then
            let e2, is_factor = parse t2 in
            if is_zero e2.rel then
              (e_default, false)
            else if is_factor then
              ({ e2 with rel = e2.rel **. t1 }, true)
            else
              ({ e2 with cst = e2.cst **. t1 }, false)
          else if is_factor then
            ({ e1 with rel = e1.rel **. t2 }, true)
          else
            ({ e1 with cst = e1.cst **. t2 }, false)
      | _ -> (e_default, false)
  in
  fst (parse t)
(* Parse `|f i - exact_f i| <= C`.
 * We try to decompose `C` to see if it has the form `A (f' i) + B` where
 * |f i| <= f' i for i in a given range.
 * Used for sum error propagation.
 * A term `g i` (or `abs (g i)`) whose only argument is `i` is accepted as
 * the factor.
 *)
let parse_fn_error i exact c =
  (* `g i` is either a direct application or, for a function-typed constant,
     an application through `fs_func_app`; returns the head and its args *)
  let extract_fn_and_args fn l =
    if ls_equal fn fs_func_app then
      match (List.hd l).t_node with
      | Tapp (fn, []) -> (fn, List.tl l)
      | _ -> (fn, [])
    else
      (fn, l)
  in
  let rec is_match t =
    match t.t_node with
    | Tapp (ls, [ t' ]) when is_abs_ls ls && is_match t' -> true
    | Tapp (fn', l) -> (
        let _, args = extract_fn_and_args fn' l in
        match args with
        | [ i' ] when t_equal i i' -> true
        | _ -> false)
    | _ -> false
  in
  parse_error is_match exact c
(* Parse `|x_approx - x| <= C`.
 * We try to decompose `C` to see if it has the form `Ax' + B` where |x| <= x'.
 * `x'` is recognised when it is syntactically `x` or `abs x`.
 * This shadows the generic three-argument [parse_error], which it calls.
 *)
let parse_error x c =
  let is_match t =
    t_equal t x
    || (match t.t_node with
       | Tapp (ls, [ t' ]) when is_abs_ls ls -> t_equal t' x
       | _ -> false)
  in
  parse_error is_match x c
(* True when the term [t] is the variable [vs] *)
let is_var_equal t vs =
  match t.t_node with
  | Tvar vs' when vs_equal vs vs' -> true
  | _ -> false
(* Looks for |to_real x - exact_x| <= C (or <). Returns Some (x, exact_x, C),
   with the float term x stripped of its to_real. *)
let parse_ineq t =
  match t.t_node with
  | Tapp (ineq, [ lhs; c ]) when is_ineq_ls ineq -> (
      match lhs.t_node with
      | Tapp (abs_ls, [ diff ]) when is_abs_ls abs_ls -> (
          match diff.t_node with
          | Tapp (sub_ls, [ x; exact_x ]) when is_sub_ls sub_ls -> (
              match x.t_node with
              | Tapp (conv, [ x ]) when is_to_real_ls conv -> Some (x, exact_x, c)
              | _ -> None)
          | _ -> None)
      | _ -> None)
  | _ -> None
(* Looks for `forall i. P -> |to_real (fn i) - exact_fn i| <= A(fn' i) + B` (or
   the same formula without the hypothesis), with `i` an integer. We also match
   on a potential hypothesis P because usually there is a hypothesis on the
   bounds of i. A forward error on a function of type (int -> real) can be
   used for sum error propagation. When found, the error is recorded for `fn`
   in `info.fns_info`, with `fn'` as its factor. *)
let collect_in_quant info q =
  let vs, _, t = t_open_quant q in
  (* `fn i` is either a direct application or, for a function-typed constant,
     an application through `fs_func_app`; returns the head and its args *)
  let extract_fn_and_args fn l =
    if ls_equal fn fs_func_app then
      match (List.hd l).t_node with
      | Tapp (fn, []) -> (fn, List.tl l)
      | _ -> (fn, [])
    else
      (fn, l)
  in
  match vs with
  | [ i ] when ty_equal i.vs_ty ty_int -> (
      (* Drop the optional hypothesis P *)
      let t =
        match t.t_node with
        | Tbinop (Timplies, _, t) -> t
        | _ -> t
      in
      match parse_ineq t with
      | Some (x, exact_x, c) -> (
          match (x.t_node, exact_x.t_node) with
          | Tapp (fn, l), Tapp (exact_fn, l') -> (
              let fn, args = extract_fn_and_args fn l in
              let exact_fn, args' = extract_fn_and_args exact_fn l' in
              match (args, args') with
              | [ i' ], [ i'' ] when is_var_equal i' i && is_var_equal i'' i -> (
                  let e = parse_fn_error i' (t_app_infer exact_fn []) c in
                  match e.factor.t_node with
                  | Tapp (fn'', args) ->
                      let fn' = fst (extract_fn_and_args fn'' args) in
                      let fns_info =
                        Mls.add fn
                          { e with factor = t_app_infer fn' [] }
                          info.fns_info
                      in
                      { info with fns_info }
                  | _ -> info)
              | _ -> info)
          | _ -> info)
      | None -> info)
  | _ -> info
(* Adds to `info` the numeric information found in the formula `f`:
   - forward errors `|to_real x - exact_x| <= C`;
   - results of IEEE operations `r = uop args`;
   - quantified errors on functions.
   Conjunctions are traversed; other formulas are ignored. *)
let rec collect info f =
  match f.t_node with
  | Tbinop (Tand, f1, f2) ->
      let info = collect info f1 in
      collect info f2
  | Tapp (ls, [ _; _ ]) when is_ineq_ls ls -> (
      (* term of the form x op y where op is an inequality "<" or "<=" over
         reals. FIXME: we should handle ">" and ">=" as well *)
      match parse_ineq f with
      | Some (x, exact_x, c) ->
          let error_fmla = parse_error exact_x c in
          let terms_info = add_fw_error info.terms_info x error_fmla in
          { info with terms_info }
      | None -> info)
  (* `r = uop args` or `uop args = r` *)
  | Tapp (ls, [ t1; t2 ]) when ls_equal ls ps_equ -> (
      match (t1.t_node, t2.t_node) with
      | Tapp (ls, args), _ when is_uop ls ->
          let terms_info = add_ieee_op info.terms_info ls t2 args in
          { info with terms_info }
      | _, Tapp (ls, args) when is_uop ls ->
          let terms_info = add_ieee_op info.terms_info ls t1 args in
          { info with terms_info }
      | _ -> info)
  | Tquant (Tforall, tq) ->
      (* Collect potential error on function (for sum propagation lemma) *)
      collect_in_quant info tq
  | _ -> info
(*
 * We look for relevant axioms and lemmas in the task, and we add corresponding
 * inequalities to `info`.
 * The formulas we look for have one of the following structures :
 * - `|to_real x - exact_x| <= A` : fw error on `x`
 * - `r = t1 uop t2` : `r` is the result of an IEEE operation on `t1` and `t2`
 * - `forall (i:int). P -> |to_real (f i) - exact_f i| <= A` : fw error on `f`
 * - `forall (i:int). |to_real (f i) - exact_f i| <= A` : fw error on `f`
 *
 * We also record the logic definitions of float-valued symbols (such as
 * `r = t1 uop t2`) in `info.ls_defs`.
 *)
let collect_info info d =
  match d.d_node with
  | Dprop ((Paxiom | Plemma), _pr, f) -> collect info f
  | Dlogic defs ->
      let ls_defs =
        List.fold_left
          (fun ls_defs (ls, ls_def) ->
            match ls.ls_value with
            | Some ty when is_ty_float ty ->
                let _vsl, t = open_ls_defn ls_def in
                Mls.add ls t ls_defs
            | _ -> ls_defs)
          info.ls_defs defs
      in
      { info with ls_defs }
  | _ -> info
(* Build the [symbols] record and store it in the global [symbols] reference.

   The record combines:
   - symbols looked up in the theories real.Real, real.RealInfix, real.Abs,
     real.ExpLog and real.Trigonometry;
   - one [ufloat_symbols] record each for ufloat.USingle and ufloat.UDouble;
   - the printers taken from [printer].

   Error behaviour:
   - A missing ufloat logic symbol raises [Failure "Symbol <name> not found"].
   - The other lookups let the exception of [ns_find_ls] / [ns_find_ts]
     propagate (presumably [Not_found]).

   ufloat.USingleLemmas and ufloat.UDoubleLemmas are read as well, although
   none of their symbols are used here. *)
let init_symbols env printer =
  (* Plain lookup of an exported logic symbol of [th]. *)
  let find th name = ns_find_ls th.th_export [ name ] in
  let infix th op = find th (Ident.op_infix op) in
  let prefix th op = find th (Ident.op_prefix op) in
  let real = Env.read_theory env [ "real" ] "Real" in
  let lt = infix real "<" in
  let le = infix real "<=" in
  let real_infix = Env.read_theory env [ "real" ] "RealInfix" in
  let lt_infix = infix real_infix "<." in
  let le_infix = infix real_infix "<=." in
  let add = infix real "+" in
  let sub = infix real "-" in
  let mul = infix real "*" in
  let _div = infix real "/" in
  let minus = prefix real "-" in
  let add_infix = infix real_infix "+." in
  let sub_infix = infix real_infix "-." in
  let mul_infix = infix real_infix "*." in
  let div_infix = infix real_infix "/." in
  let minus_infix = prefix real_infix "-." in
  let abs_th = Env.read_theory env [ "real" ] "Abs" in
  let abs = find abs_th "abs" in
  let exp_log = Env.read_theory env [ "real" ] "ExpLog" in
  let exp = find exp_log "exp" in
  let log = find exp_log "log" in
  let log2 = find exp_log "log2" in
  let log10 = find exp_log "log10" in
  let trigo = Env.read_theory env [ "real" ] "Trigonometry" in
  let sin = find trigo "sin" in
  let cos = find trigo "cos" in
  let usingle = Env.read_theory env [ "ufloat" ] "USingle" in
  let udouble = Env.read_theory env [ "ufloat" ] "UDouble" in
  let usingle_lemmas = Env.read_theory env [ "ufloat" ] "USingleLemmas" in
  let udouble_lemmas = Env.read_theory env [ "ufloat" ] "UDoubleLemmas" in
  (* Collect the symbols of the ufloat theory [th], whose float type is named
     [ty]. The lemma theory argument is currently unused. *)
  let mk_ufloat_symbols th _lemmas ty =
    let lookup name =
      match ns_find_ls th.th_export [ name ] with
      | ls -> ls
      | exception Not_found ->
          failwith (Format.sprintf "Symbol %s not found" name)
    in
    {
      ufloat_type = ns_find_ts th.th_export [ ty ];
      to_real = lookup "to_real";
      uadd = lookup "uadd";
      usub = lookup "usub";
      umul = lookup "umul";
      udiv = lookup "udiv";
      uminus = lookup "uminus";
      udiv_exact = lookup "udiv_exact";
      uadd_infix = lookup (Ident.op_infix "++.");
      usub_infix = lookup (Ident.op_infix "--.");
      umul_infix = lookup (Ident.op_infix "**.");
      udiv_infix = lookup (Ident.op_infix "//.");
      uminus_prefix = lookup (Ident.op_prefix "--.");
      udiv_exact_infix = lookup (Ident.op_infix "///.");
      eps = fs_app (lookup "eps") [] ty_real;
      eta = fs_app (lookup "eta") [] ty_real;
    }
  in
  let usingle_symbols = mk_ufloat_symbols usingle usingle_lemmas "usingle" in
  let udouble_symbols = mk_ufloat_symbols udouble udouble_lemmas "udouble" in
  symbols :=
    Some
      {
        add;
        sub;
        mul;
        _div;
        minus;
        add_infix;
        sub_infix;
        mul_infix;
        div_infix;
        minus_infix;
        lt;
        lt_infix;
        le;
        le_infix;
        abs;
        exp;
        log;
        log2;
        log10;
        sin;
        cos;
        usingle_symbols;
        udouble_symbols;
        ident_printer = printer.Trans.printer;
        tv_printer = printer.Trans.aprinter;
      }
(*** letify *)
(* Factor the repeated subterms of [f] out as [let] bindings.

   Which terms are factored:
   - every typed subterm that is not atomic (not [true]/[false], a constant,
     a variable or a nullary application) and occurs more than once in [f];
   - only when all of its free variables are also free in [f]. A repeated
     subterm that mentions a variable bound inside [f] (by a quantifier,
     [let], [match] or lambda) cannot be hoisted above its binder without
     producing an ill-scoped term, so it is left in place.

   How they are factored:
   - each such term is bound to a fresh variable named "t", and its
     occurrences are replaced by that variable;
   - bindings are emitted so that smaller terms are bound outermost, which
     lets the definition of a larger shared term refer to the variables of
     the smaller shared terms it contains. *)
let letify f =
  (* Count in [acc] the occurrences of every typed non-atomic subterm of [f].
     Once a term has been counted, its own subterms are not visited again,
     so they are counted for its first occurrence only. *)
  let rec compute_subterms (acc : int Mterm.t) f : int Mterm.t =
    match f.t_node with
    | Ttrue | Tfalse | Tconst _ | Tvar _ | Tapp (_, []) -> acc
    | _ -> (
        match f.t_ty with
        | None -> t_fold compute_subterms acc f
        | Some _ -> (
            match Mterm.find f acc with
            | n -> Mterm.add f (n + 1) acc
            | exception Not_found ->
                let acc = t_fold compute_subterms acc f in
                Mterm.add f 1 acc))
  in
  let counts = compute_subterms Mterm.empty f in
  (* [t] may be let-bound at the top of [f] only if it does not mention a
     variable bound somewhere inside [f]. Opening a binder (in [t_fold])
     creates fresh variables, which are never free in [f]. *)
  let fv_f = t_freevars Mvs.empty f in
  let scoped_in_f t =
    Mvs.for_all (fun v _ -> Mvs.mem v fv_f) (t_freevars Mvs.empty t)
  in
  let m =
    Mterm.fold
      (fun t n acc ->
        if n <= 1 || not (scoped_in_f t) then acc
        else
          let id = Ident.id_fresh ?loc:t.t_loc "t" in
          (* [t_ty] is [Some _] for every key of [counts]. *)
          let vs = create_vsymbol id (Option.get t.t_ty) in
          Mterm.add t vs acc)
      counts Mterm.empty
  in
  (* Replace every shared subterm by its variable. *)
  let rec letify_rec f =
    match Mterm.find f m with
    | vs -> t_var vs
    | exception Not_found -> t_map letify_rec f
  in
  let letified_f = letify_rec f in
  (* Largest terms first: each [t_let_close] wraps the accumulator, so the
     smallest terms end up outermost. Sizes are computed once per binding
     rather than on every comparison. *)
  let bindings =
    Mterm.bindings m
    |> List.map (fun (t, vs) -> (term_size t, t, vs))
    |> List.sort (fun (s1, _, _) (s2, _, _) -> Int.compare s2 s1)
  in
  List.fold_left
    (fun acc (_, t, vs) -> t_let_close vs (t_map letify_rec t) acc)
    letified_f bindings
1191 (*** complete strategy *)
(* Entry point of the "forward_propagation" strategy.

   Which floats are considered depends on [args]:
   - empty: every float term returned by [get_floats] on the goal;
   - a single string: parsed with [naming_table] and [lang] as a
     comma-separated list of terms;
   - any other number of arguments raises [Args_wrapper.Arg_error].

   The error formula of each selected float is computed from the
   information collected in the task's hypotheses. When none can be
   produced, the result is [default_strat ()]. Otherwise the returned
   strategy proceeds as follows:
   - assert the letified conjunction [f'] of these formulas;
   - prove [f'] by first asserting the plain conjunction [f];
   - prove [f] with the strategies returned by [get_error_fmlas], through
     [split_vc] when there are several;
   - leave the remaining goals to [default_strat ()]. *)
let fw_propagation args env naming_table lang task =
  (* let naming_table = Args_wrapper.build_naming_tables task in *)
  (* [init_symbols] rebuilds the whole symbol record on every call; the
     printers it stores are then replaced by ones built from [task]. *)
  init_symbols env naming_table;
  let printer = Args_wrapper.build_naming_tables task in
  symbols :=
    Some
      {
        !!symbols with
        ident_printer = printer.Trans.printer;
        tv_printer = printer.Trans.aprinter;
      };
  (* We start by collecting infos from the hypotheses of the task *)
  let info =
    List.fold_left collect_info
      { terms_info = Mterm.empty; fns_info = Mls.empty; ls_defs = Mls.empty }
      (task_decls task)
  in
  let floats =
    match args with
    (* If no argument is given, then we perform forward error propagation for
       every ufloat term of the goal. *)
    | [] -> get_floats (task_goal_fmla task)
    | [ floats ] ->
        Args_wrapper.parse_and_type_list ~lang ~as_fmla:false floats naming_table
    | _ ->
        raise
          (Args_wrapper.Arg_error
             "this strategy expects an optional comma-separated list of terms as \
              argument")
  in
  (* For each float `x`, we try to compute a formula of the form `|x - exact_x|
     <= A x' + B` where `exact_x` is the real value which is approximated by the
     float `x` and `|exact_x| <= x'`. For this, forward error propagation is
     performed using propagation lemmas of the ufloat stdlib with the data
     contained in `info`. For each new formula created, a proof tree is
     generated with the necessary steps to prove it. [strats] is built in
     reverse order. *)
  let f, strats =
    List.fold_left
      (fun (f, l) t ->
        match get_error_fmlas info t with
        | _, None, _ -> (f, l)
        | _, Some f', s -> (t_and_simp f f', s :: l))
      (t_true, []) floats
  in
  let f' = letify f in
  (* Strategy proving [f]: a single float needs no split; otherwise [f] is a
     conjunction whose parts are proved, in order, by the per-float
     strategies. *)
  let prove_f =
    match strats with
    | [] -> None
    | [ s ] -> Some s
    | _ -> Some (Sapply_trans ("split_vc", [], List.rev strats))
  in
  match prove_f with
  | None ->
      (* Nothing to do *)
      default_strat ()
  | Some prove_f ->
      let f_strat =
        Sapply_trans ("assert", [ term_to_str f ], [ prove_f; default_strat () ])
      in
      Sapply_trans ("assert", [ term_to_str f' ], [ f_strat; default_strat () ])
1261 let () =
1262 register_strat_with_args "forward_propagation" fw_propagation
1263 ~desc:"Compute@ forward@ error@ of@ float@ computations."