1 (********************************************************************)
3 (* The Why3 Verification Platform / The Why3 Development Team *)
4 (* Copyright 2010-2022 -- Inria - CNRS - Paris-Saclay University *)
6 (* This software is distributed under the terms of the GNU Lesser *)
7 (* General Public License version 2.1, with the special exception *)
8 (* on linking described in file LICENSE. *)
10 (********************************************************************)
22 type ufloat_symbols
= {
23 ufloat_type
: tysymbol
;
35 uminus_prefix
: lsymbol
;
36 udiv_exact_infix
: lsymbol
;
51 minus_infix
: lsymbol
;
63 usingle_symbols
: ufloat_symbols
;
64 udouble_symbols
: ufloat_symbols
;
65 ident_printer
: ident_printer
;
66 tv_printer
: ident_printer
;
(* Global table of the Why3 symbols used by this strategy. It starts out
   empty; [fw_propagation] calls [init_symbols] before anything reads it
   (NOTE(review): presumably [init_symbols] stores its result here — the
   assignment itself is not visible in this chunk). *)
let symbols = ref None

(* [!!r] returns the value held by the option reference [r].
   Raises [Invalid_argument "option is None"] when [r] still holds [None],
   exactly like [Option.get]. *)
let ( !! ) s =
  match !s with
  | Some v -> v
  | None -> invalid_arg "option is None"
(* A term 't' having a forward error 'e' means "|t - e.exact| <= e.rel *
   e.factor + e.cst", where |e.exact| <= e.factor *)
74 type forward_error
= {
81 (* This type corresponds to the numeric info we have on a real/float term *)
83 error
: forward_error
option;
84 computed_error
: (forward_error
* term
* proof_tree
) option;
86 * "Some (op, [x; y])" means that the term "t" is the result of the FP operation
89 ieee_op
: (lsymbol
* term list
) option;
93 terms_info
: term_info
Mterm.t
;
94 (* fns_info is used for the sum propagation lemma *)
95 fns_info
: forward_error
Mls.t
;
96 (* ls_defs is used to store logic definitions *)
(* Default strategy plugged into the generated proof trees wherever no
   specific step is required: always [Sdo_nothing]. *)
let default_strat = fun () -> Sdo_nothing
105 (Number.real_literal ~radix
:10 ~neg
:false ~
int:"0" ~frac
:"0" ~exp
:None
))
111 (Number.real_literal ~radix
:10 ~neg
:false ~
int:"1" ~frac
:"0" ~exp
:None
))
(* [is_zero t] tells whether [t] is equal, in the sense of [t_equal], to the
   real constant [zero]. *)
let is_zero = fun t -> t_equal zero t

(* [is_one t] tells whether [t] is equal, in the sense of [t_equal], to the
   real constant [one]. *)
let is_one = fun t -> t_equal one t
119 (* Don't add an abs symbol on top of another *)
120 | Tapp
(ls
, [ _
]) when ls_equal
!!symbols.abs ls
-> t
121 | _
-> fs_app
!!symbols.abs [ t
] ty_real
124 let symbols = !!symbols in
125 ls_equal ls
symbols.lt
|| ls_equal ls
symbols.le
126 || ls_equal ls
symbols.lt_infix
127 || ls_equal ls
symbols.le_infix
(* Recognisers for real addition, subtraction and multiplication. Each accepts
   both the operator from theory [real.Real] and its counterpart from
   [real.RealInfix]. *)
let is_add_ls ls =
  let syms = !!symbols in
  List.exists (ls_equal ls) [ syms.add; syms.add_infix ]

let is_sub_ls ls =
  let syms = !!symbols in
  List.exists (ls_equal ls) [ syms.sub; syms.sub_infix ]

let is_mul_ls ls =
  let syms = !!symbols in
  List.exists (ls_equal ls) [ syms.mul; syms.mul_infix ]
134 ls_equal ls
!!symbols.minus
|| ls_equal ls
!!symbols.minus_infix
(* Recogniser for the [abs] symbol of theory [real.Abs]. *)
let is_abs_ls ls =
  let abs_ls = (!!symbols).abs in
  ls_equal ls abs_ls
(* Recogniser for the [to_real] conversion of either ufloat format
   ([USingle] or [UDouble]). *)
let is_to_real_ls ls =
  let syms = !!symbols in
  List.exists (ls_equal ls)
    [ syms.usingle_symbols.to_real; syms.udouble_symbols.to_real ]
143 ls_equal ls
!!symbols.usingle_symbols
.uadd
144 || ls_equal ls
!!symbols.usingle_symbols
.uadd_infix
145 || ls_equal ls
!!symbols.udouble_symbols
.uadd
146 || ls_equal ls
!!symbols.udouble_symbols
.uadd_infix
149 ls_equal ls
!!symbols.usingle_symbols
.usub
150 || ls_equal ls
!!symbols.usingle_symbols
.usub_infix
151 || ls_equal ls
!!symbols.udouble_symbols
.usub
152 || ls_equal ls
!!symbols.udouble_symbols
.usub_infix
155 ls_equal ls
!!symbols.usingle_symbols
.umul
156 || ls_equal ls
!!symbols.usingle_symbols
.umul_infix
157 || ls_equal ls
!!symbols.udouble_symbols
.umul
158 || ls_equal ls
!!symbols.udouble_symbols
.umul_infix
161 ls_equal ls
!!symbols.usingle_symbols
.udiv
162 || ls_equal ls
!!symbols.usingle_symbols
.udiv_infix
163 || ls_equal ls
!!symbols.udouble_symbols
.udiv
164 || ls_equal ls
!!symbols.udouble_symbols
.udiv_infix
(* Recogniser for ufloat negation, in function form ([uminus]) or as the
   prefix operator [--.], for both the single and double formats. *)
let is_uminus_ls ls =
  let syms = !!symbols in
  let single = syms.usingle_symbols and double = syms.udouble_symbols in
  List.exists (ls_equal ls)
    [ single.uminus; single.uminus_prefix; double.uminus; double.uminus_prefix ]
(* Recogniser for ufloat exact division, in function form ([udiv_exact]) or as
   the infix operator [///.], for both the single and double formats. *)
let is_uexact_div_ls ls =
  let syms = !!symbols in
  let single = syms.usingle_symbols and double = syms.udouble_symbols in
  List.exists (ls_equal ls)
    [
      single.udiv_exact;
      single.udiv_exact_infix;
      double.udiv_exact;
      double.udiv_exact_infix;
    ]
179 is_uadd_ls ls
|| is_usub_ls ls
|| is_umul_ls ls
|| is_udiv_ls ls
180 || is_uminus_ls ls
|| is_uexact_div_ls ls
(* Builds the real negation of [t] using the prefix operator [-.] of
   [real.RealInfix]. *)
let minus t =
  let neg = (!!symbols).minus_infix in
  fs_app neg [ t ] ty_real
186 | Tapp
(ls
, [ t'
]) when is_minus_ls ls
-> t'
(* Builds the real term [t1 +. t2] (operator from [real.RealInfix]), with no
   simplification. *)
let add t1 t2 =
  let plus = (!!symbols).add_infix in
  fs_app plus [ t1; t2 ] ty_real
194 else if is_zero t2
then
(* Builds the real term [t1 -. t2] (operator from [real.RealInfix]), with no
   simplification. *)
let sub t1 t2 =
  let minus_op = (!!symbols).sub_infix in
  fs_app minus_op [ t1; t2 ] ty_real
204 else if is_zero t2
then
(* Builds the real term [t1 *. t2] (operator from [real.RealInfix]), with no
   simplification. *)
let mul t1 t2 =
  let times = (!!symbols).mul_infix in
  fs_app times [ t1; t2 ] ty_real

(* Builds the real term [t1 /. t2] (operator from [real.RealInfix]), with no
   simplification. *)
let div t1 t2 =
  let divide = (!!symbols).div_infix in
  fs_app divide [ t1; t2 ] ty_real
213 if is_zero t1
|| is_zero t2
then
215 else if is_one t1
then
217 else if is_one t2
then
220 match (t1
.t_node
, t2
.t_node
) with
221 | Tapp
(ls1
, [ t1
]), Tapp
(ls2
, [ t2
]) when is_abs_ls ls1
&& is_abs_ls ls2
229 else if is_one t2
then
(* Infix shorthands for building real terms. The plain operators build the
   term as is; the doubled ones ([++.], [--.], [**.], [//.]) go through the
   [_simp] builders. [+.], [-.], [*.] and [/.] shadow the Stdlib float
   operators from here on. *)
let ( +. ) = add
let ( -. ) = sub
let ( *. ) = mul
let ( /. ) = div
let ( ++. ) = add_simp
let ( --. ) = sub_simp
let ( **. ) = mul_simp
let ( //. ) = div_simp

(* [x <=. y] builds the real comparison with [<=.] from [real.RealInfix].
   Kept eta-expanded so that [!!symbols] is read at call time, not when this
   module is initialised. *)
let ( <=. ) x y =
  let le = (!!symbols).le_infix in
  ps_app le [ x; y ]
245 match ty
.ty_node
with
248 ts_equal v
!!symbols.usingle_symbols
.ufloat_type
249 || ts_equal v
!!symbols.udouble_symbols
.ufloat_type
257 if ts_equal ieee_type
!!symbols.usingle_symbols
.ufloat_type
then
258 !!symbols.usingle_symbols
.eps
259 else if ts_equal ieee_type
!!symbols.udouble_symbols
.ufloat_type
then
260 !!symbols.udouble_symbols
.eps
262 failwith
(asprintf
"Unsupported type %a" Pretty.print_ts ieee_type
)
265 if ts_equal ieee_type
!!symbols.usingle_symbols
.ufloat_type
then
266 !!symbols.usingle_symbols
.eta
267 else if ts_equal ieee_type
!!symbols.udouble_symbols
.ufloat_type
then
268 !!symbols.udouble_symbols
.eta
270 failwith
(asprintf
"Unsupported type %a" Pretty.print_ts ieee_type
)
272 let to_real ieee_type t
=
274 if ts_equal ieee_type
!!symbols.usingle_symbols
.ufloat_type
then
275 !!symbols.usingle_symbols
.to_real
276 else if ts_equal ieee_type
!!symbols.udouble_symbols
.ufloat_type
then
277 !!symbols.udouble_symbols
.to_real
279 failwith
(asprintf
"Unsupported type %a" Pretty.print_ts ieee_type
)
281 fs_app
to_real [ t
] ty_real
(* [get_info info t] returns the information recorded for [t] in [info]. When
   [t] has no entry, it returns a record with every field set to [None]. *)
let get_info info t =
  match Mterm.find t info with
  | t_info -> t_info
  | exception Not_found ->
      { error = None; computed_error = None; ieee_op = None }
287 let add_fw_error info t error
=
290 | Tapp
(ls
, [ t ]) when is_to_real_ls ls
-> t
293 let t_info = get_info info
t in
294 let t_info = { t_info with error
= Some error
} in
295 Mterm.add t t_info info
297 let add_computed_fw_error info
t ((fe
, _
, _
) as error
) =
300 | Tapp
(ls
, [ t ]) when is_to_real_ls ls
-> t
303 let t_info = get_info info
t in
304 let t_info = { t_info with computed_error
= Some error
; error
= Some fe
} in
305 Mterm.add t t_info info
(* [add_ieee_op info ls t args] records that [t] is the result of the IEEE
   operation [ls] applied to [args]. The other fields of [t]'s entry are left
   unchanged. Returns the updated map. *)
let add_ieee_op info ls t args =
  let updated = { (get_info info t) with ieee_op = Some (ls, args) } in
  Mterm.add t updated info
314 | None
-> assert false
316 match ty
.ty_node
with
317 | Tyvar _
-> assert false
318 | Tyapp
(ts
, []) -> ts
321 (* Return the float terms inside `t`. Note that if `t` is an application of type
322 float we don't return the floats that it potentially contains *)
323 (* TODO: Don't put duplicates in list !!! *)
324 let rec get_floats t =
326 | Some ty
when is_ty_float ty
-> [ t ]
329 | Tapp
(_
, tl
) -> List.fold_left
(fun l
t -> l
@ get_floats t) [] tl
332 let string_of_ufloat_type ts
=
333 if ts_equal ts
!!symbols.usingle_symbols
.ufloat_type
then
335 else if ts_equal ts
!!symbols.udouble_symbols
.ufloat_type
then
338 failwith
(asprintf
"Unsupported type %a" Pretty.print_ts ts
)
340 let string_of_ufloat_type_and_op ts uop
=
341 let ty_str = string_of_ufloat_type ts
in
343 if is_uadd_ls uop
then
345 else if is_usub_ls uop
then
347 else if is_umul_ls uop
then
349 else if is_uexact_div_ls uop
then
352 failwith
(asprintf
"Unsupported uop '%a'" Pretty.print_ls uop
)
354 uop_str ^
"_" ^
ty_str
358 (val Pretty.create
!!symbols.ident_printer
!!symbols.tv_printer
359 !!symbols.ident_printer
360 (Ident.create_ident_printer
[]))
362 Format.asprintf
"%a" P.print_term
t
(* Error on an IEEE op when propagation is needed (i.e. we have a forward error
   on t1 and/or on t2) *)
366 let combine_uop_errors info uop t1 e1 t2 e2 r strat_for_t1 strat_for_t2
=
370 let to_real = to_real ts in
371 let rel_err, rel_err_simp
=
372 if is_uadd_ls uop
|| is_usub_ls uop
then
373 (* Relative error for addition and sustraction *)
374 (e1
.rel
+. e2
.rel
+. eps, e1
.rel
++. e2
.rel
++. eps)
376 (* Relative error for multiplication *)
377 ( eps +. ((e1
.rel
+. e2
.rel
+. (e1
.rel
*. e2
.rel
)) *. (one +. eps)),
378 eps ++. ((e1
.rel
++. e2
.rel
++. (e1
.rel
**. e2
.rel
)) **. (one ++. eps))
381 let cst_err, cst_err_simp
=
382 if is_uadd_ls uop
|| is_usub_ls uop
then
383 (* Constant error for addition and sustraction *)
384 ( ((one +. eps +. e2
.rel
) *. e1
.cst
) +. ((one +. eps +. e1
.rel
) *. e2
.cst
),
385 ((one ++. eps ++. e2
.rel
) **. e1
.cst
)
386 ++. ((one ++. eps ++. e1
.rel
) **. e2
.cst
) )
388 (* Constant error for multiplication *)
389 ( (((e2
.cst
+. (e2
.cst
*. e1
.rel
)) *. e1
.factor
)
390 +. ((e1
.cst
+. (e1
.cst
*. e2
.rel
)) *. e2
.factor
)
391 +. (e1
.cst
*. e2
.cst
))
394 (((one ++. eps) **. (e2
.cst
++. (e2
.cst
**. e1
.rel
))) **. e1
.factor
)
395 ++. (((one ++. eps) **. (e1
.cst
++. (e1
.cst
**. e2
.rel
))) **. e2
.factor
)
396 ++. ((one ++. eps) **. e1
.cst
**. e2
.cst
)
400 if is_uadd_ls uop
|| is_usub_ls uop
then
401 (rel_err *. (e1
.factor
+. e2
.factor
)) +. cst_err
403 (rel_err *. (e1
.factor
*. e2
.factor
)) +. cst_err
406 if is_uadd_ls uop
|| is_usub_ls uop
then
407 (rel_err **. (e1
.factor
++. e2
.factor
)) ++. cst_err
409 (rel_err **. e1
.factor
**. e2
.factor
) ++. cst_err
411 let exact, exact_simp
=
412 if is_uadd_ls uop
then
413 (e1
.exact +. e2
.exact, e1
.factor
++. e2
.factor
)
414 else if is_usub_ls uop
then
415 (e1
.exact -. e2
.exact, e1
.factor
++. e2
.factor
)
417 (e1
.exact *. e2
.exact, e1
.factor
**. e2
.factor
)
419 let str = string_of_ufloat_type_and_op ts uop
in
424 str ^
"_error_propagation";
426 sprintf
"%s,%s" (term_to_str t1
) (term_to_str t2
);
440 let f = abs (to_real r
-. exact) <=. total_err in
442 { exact; rel
= rel_err_simp
; factor
= exact_simp
; cst
= cst_err_simp
}
445 if t_equal
total_err total_err_simp then
448 let f_simp = abs (to_real r
-. exact) <=. total_err_simp in
450 Sapply_trans
("assert", [ term_to_str f ], [ strat; default_strat () ])
453 let info = add_computed_fw_error info r
(fw_err, f, strat) in
(* Error on an IEEE op when no propagation is needed (i.e. we have a forward
   error on neither t1 nor t2) *)
458 let basic_uop_error info uop r t1 t2
=
462 let to_real = to_real ts in
464 if is_uadd_ls uop
then
465 to_real t1
+. to_real t2
466 else if is_usub_ls uop
then
467 to_real t1
-. to_real t2
469 to_real t1
*. to_real t2
472 if is_umul_ls uop
then
477 let total_err = (eps *. abs exact) ++. cst_err in
479 add_fw_error info r
{ exact; rel
= eps; factor
= abs exact; cst
= cst_err }
481 let f = abs (to_real r
-. exact) <=. total_err in
482 (info, f, default_strat ())
484 (* Generates the formula and the strat for the forward error of `r = uop t1
486 let use_ieee_thms info uop r t1 t2 strat_for_t1 strat_for_t2
=
487 let t1_info = get_info info t1
in
488 let t2_info = get_info info t2
in
489 match (t1_info.error
, t2_info.error
) with
490 (* No propagation needed, we use the basic lemma *)
491 | None
, None
-> basic_uop_error info uop r t1 t2
492 (* We have an error on at least one of t1 and t2, use the propagation lemma *)
494 let to_real = to_real (get_ts r
) in
495 let get_err_or_default t t_info =
496 match t_info.error
with
498 { exact = to_real t; rel
= zero; factor
= abs (to_real t); cst
= zero }
501 let e1 = get_err_or_default t1
t1_info in
502 let e2 = get_err_or_default t2
t2_info in
503 combine_uop_errors info uop t1
e1 t2
e2 r strat_for_t1 strat_for_t2
505 (* Returns None if the function is unsupported by the strategy. Otherwise,
506 returns its symbol as well as its arguments. *)
507 let get_known_fn_and_args _t x
=
510 when ls_equal ls
!!symbols.log
|| ls_equal ls
!!symbols.exp
511 || ls_equal ls
!!symbols.log2
512 || ls_equal ls
!!symbols.log10
513 || ls_equal ls
!!symbols.sin
|| ls_equal ls
!!symbols.cos
->
517 match arg
.t_node
with
518 | Tapp
(ls
, [ t ]) when is_to_real_ls ls
-> t
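(* Returns the forward error formula associated with the propagation lemma of
   the function `exact_fn`. `app_approx` is the float term approximating the
   application of `exact_fn`, and `arg_approx` is the float approximation of
   its argument. *)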
527 let get_fn_errs info exact_fn app_approx arg_approx
=
528 let e_arg = Option.get
(get_info info.terms_info arg_approx
).error
in
529 let e_app = Option.get
(get_info info.terms_info app_approx
).error
in
530 if ls_equal exact_fn
!!symbols.exp
then
531 let cst_err = e_app.cst
in
533 fs_app exact_fn
[ (e_arg.rel
*. e_arg.factor
) +. e_arg.cst
] ty_real
536 fs_app exact_fn
[ (e_arg.rel
**. e_arg.factor
) ++. e_arg.cst
] ty_real
538 let rel_err = e_app.rel
+. ((a -. one) *. (one +. e_app.rel
)) in
540 e_app.rel
++. ((a_simp --. one) **. (one ++. e_app.rel
))
546 fs_app exact_fn
[ e_arg.exact ] ty_real
,
549 ls_equal exact_fn
!!symbols.log
550 || ls_equal exact_fn
!!symbols.log2
551 || ls_equal exact_fn
!!symbols.log10
555 [ one -. (((e_arg.rel
*. e_arg.factor
) +. e_arg.cst
) /. e_arg.exact) ]
560 t_equal
e_arg.exact e_arg.factor
561 || t_equal
(abs e_arg.exact) e_arg.factor
564 [ one --. (e_arg.rel
++. (e_arg.cst
//. e_arg.exact)) ]
570 --. (((e_arg.rel
**. e_arg.factor
) ++. e_arg.cst
) //. e_arg.exact);
574 let cst_err = (minus a *. (one +. e_app.rel
)) +. e_app.cst
in
576 (minus_simp a_simp **. (one ++. e_app.rel
)) ++. e_app.cst
578 let rel_err = e_app.rel
in
583 abs (fs_app exact_fn
[ e_arg.exact ] ty_real
),
585 else if ls_equal exact_fn
!!symbols.sin
|| ls_equal exact_fn
!!symbols.cos
588 (((e_arg.rel
*. e_arg.factor
) +. e_arg.cst
) *. (one +. e_app.rel
))
592 (((e_arg.rel
**. e_arg.factor
) ++. e_arg.cst
) **. (one ++. e_app.rel
))
595 let rel_err = e_app.rel
in
600 abs (fs_app exact_fn
[ e_arg.exact ] ty_real
),
605 (* Returns the forward error formula and the strat associated with the
606 application of the propagation lemma for `fn`. The argument `strat`
607 corresponds to the strat that is used to prove the error on `arg_approx` (the
608 argument of the function) *)
609 let apply_fn_thm info fn app_approx arg_approx
strat =
610 let e_arg = Option.get
(get_info info.terms_info arg_approx
).error
in
611 let to_real = to_real (get_ts app_approx
) in
612 let exact = fs_app fn
[ e_arg.exact ] ty_real
in
613 let fn_str = fn
.ls_name
.id_string
in
614 let ty_str = string_of_ufloat_type (get_ts app_approx
) in
615 let rel_err, rel_err_simp, cst_err, cst_err_simp, app'
, nb
=
616 get_fn_errs info fn app_approx arg_approx
622 sprintf
"%s_%s_error_propagation" fn_str ty_str;
624 sprintf
"%s" (term_to_str arg_approx
);
626 [ strat ] @ List.init nb
(fun _
-> Sdo_nothing
) )
628 let total_err = (rel_err *. app'
) +. cst_err in
629 let total_err_simp = (app'
**. rel_err_simp) ++. cst_err_simp in
630 let left = abs (to_real app_approx
-. exact) in
632 if t_equal
total_err total_err_simp then
633 (left <=. total_err, strat)
635 ( left <=. total_err_simp,
638 [ term_to_str (left <=. total_err) ],
639 [ strat; default_strat () ] ) )
642 add_computed_fw_error info.terms_info app_approx
643 ( { exact; rel
= rel_err_simp; factor
= app'
; cst
= cst_err_simp },
647 (term_info, Some
f, strat)
649 let use_known_thm info app_approx fn
args strats
=
651 (* Nothing to do if none of the args have a forward error *)
655 match (get_info info.terms_info arg
).error
with
660 (info.terms_info
, None
, List.hd strats
)
662 ls_equal fn
!!symbols.exp
|| ls_equal fn
!!symbols.log
663 || ls_equal fn
!!symbols.log2
664 || ls_equal fn
!!symbols.log10
665 || ls_equal fn
!!symbols.sin
|| ls_equal fn
!!symbols.cos
667 apply_fn_thm info fn app_approx
(List.hd
args) (List.hd strats
)
669 failwith
(asprintf
"Unsupported fn symbol '%a'" Pretty.print_ls fn
)
671 (* Recursively unfold the definition of `t` if it has one. Stops when we find an
672 error that we can use *)
673 let update_term_info info t =
674 let t_info = get_info info.terms_info
t in
676 let t'_info
= get_info info.terms_info
t'
in
677 match t'_info
.ieee_op
with
679 match t'_info
.error
with
683 if Mls.mem ls
info.ls_defs
then
684 recurse (Mls.find ls
info.ls_defs
)
687 | Tapp
(ls
, args) when is_uop ls
->
688 { t_info with ieee_op
= Some
(ls
, args) }
690 | Some _
as error
-> { t_info with error
})
691 | Some _
as ieee_op
-> { t_info with ieee_op
}
696 * Generate error formulas recursively for a term `t` using propagation lemmas.
697 * This is recursive because if `t` is an approximation of a term `u` which
698 * itself is an approximation of a term `v`, we first compute a formula for the
699 * approximation of `v` by `u` and we combine it with the formula we already
700 * have of the approximation of `u` by `t` to get a formula relating `t` to `v`.
702 let rec get_error_fmlas info t =
703 let t_info = update_term_info info t in
707 Sapply_trans
("assert", [ term_to_str f ], [ s
; default_strat () ])
710 match t_info.computed_error
with
711 | Some
(_
, f, strat) -> (info.terms_info
, Some
f, strat)
713 match t_info.ieee_op
with
714 (* `t` is the result of the IEEE minus operation *)
715 | Some
(ieee_op
, [ x
]) when is_uminus_ls ieee_op
-> (
716 let terms_info, fmla
, strat_for_x
= get_error_fmlas info x
in
717 let strat = get_strat fmla strat_for_x
in
719 let to_real = to_real ts in
720 let x_info = get_info terms_info x
in
721 match x_info.error
with
722 (* No propagation needed *)
724 (terms_info, None
, strat)
725 (* The error doesn't change since the float minus operation is exact *)
726 | Some
{ exact; rel
; factor
; cst
} ->
727 let exact = minus exact in
728 let f = abs (to_real t -. exact) <=. (rel
**. factor
) ++. cst
in
730 add_computed_fw_error terms_info t
731 ({ exact; rel
; factor
; cst
}, f, strat)
733 (terms_info, Some
f, strat))
734 (* `t` is the result of an IEEE addition/sustraction/multiplication *)
735 | Some
(ieee_op
, [ t1
; t2
])
736 when is_uadd_ls ieee_op
|| is_usub_ls ieee_op
|| is_umul_ls ieee_op
->
737 (* Get error formulas on subterms `t1` and `t2` *)
738 let terms_info, fmla1
, strat_for_t1
= get_error_fmlas info t1
in
739 let terms_info, fmla2
, strat_for_t2
=
740 get_error_fmlas { info with terms_info } t2
742 let strat_for_t1 = get_strat fmla1
strat_for_t1 in
743 let strat_for_t2 = get_strat fmla2
strat_for_t2 in
744 let terms_info, f, strats
=
745 use_ieee_thms terms_info ieee_op
t t1 t2
strat_for_t1 strat_for_t2
747 (terms_info, Some
f, strats
)
748 (* `t` is the result of an other IEEE operation *)
749 | Some
(ieee_op
, [ t1
; t2
]) when is_uexact_div_ls ieee_op
->
750 let ts = get_ts t2
in
752 let to_real = to_real ts in
753 let terms_info, fmla1
, strat_for_t1 = get_error_fmlas info t1
in
754 let strat_for_t1 = get_strat fmla1
strat_for_t1 in
755 let t1_info = get_info terms_info t1
in
756 let e1 = Option.get
t1_info.error
in
759 exact = e1.exact //. to_real t2
;
761 factor
= e1.factor
//. abs (to_real t2
);
762 cst
= (e1.cst
//. abs (to_real t2
)) ++. eta;
766 (e1.rel
*. (e1.factor
/. abs (to_real t2
)))
767 +. ((e1.cst
/. abs (to_real t2
)) +. eta)
770 (e1.rel
**. (e1.factor
//. abs (to_real t2
)))
771 ++. ((e1.cst
//. abs (to_real t2
)) ++. eta)
773 let left = abs (to_real t -. (e1.exact //. to_real t2
)) in
774 let s = string_of_ufloat_type ts in
779 "udiv_exact_" ^
s ^
"_error_propagation";
781 sprintf
"%s" (term_to_str t1
);
795 ("assert", [ term_to_str (left <=. err) ], [ strat; default_strat () ])
798 if t_equal
err err_simp then
803 ("assert", [ term_to_str (left <=. err) ], [ s; default_strat () ])
806 let term_info = add_computed_fw_error terms_info t (fe, f, strat) in
807 (term_info, Some
f, strat)
808 | Some _
-> (info.terms_info, None
, default_strat ())
810 match t_info.error
with
811 (* `t` has a forward error, we look if it is the result of the
812 approximation of a known function, in which case we use the function's
813 propagation lemma to compute an error bound *)
815 match get_known_fn_and_args t e
.exact with
817 (* First we compute the potential forward errors of the function
822 let terms_info, f, s = get_error_fmlas info t in
828 ("assert", [ term_to_str f ], [ s; default_strat () ])
830 ({ info with terms_info }, s :: l
))
833 use_known_thm info t fn
args (List.rev strats
)
834 | None
-> (info.terms_info, None
, default_strat ()))
835 | None
-> (info.terms_info, None
, default_strat ())))
837 let parse_error is_match
exact t =
838 (* If it we don't find a "relative" error we have an absolute error of `t` *)
839 let e_default = { exact; rel
= zero; factor
= abs exact; cst
= t } in
842 ({ exact; rel
= one; factor
= t; cst
= zero }, true)
845 | Tapp
(ls
, [ t1
; t2
]) when is_add_ls ls
->
846 let e1, _
= parse t1
in
847 if is_zero e1.rel
then
848 let e2, _
= parse t2
in
849 if is_zero e2.rel
then
852 (* FIXME: we should combine the factor of t1 and factor of t2 *)
853 ({ e2 with cst
= e2.cst
++. t1
}, false)
855 ({ e1 with cst
= e1.cst
++. t2
}, false)
856 | Tapp
(ls
, [ t1
; t2
]) when is_sub_ls ls
->
857 (* FIXME : should be the same as for addition *)
858 let e1, _
= parse t1
in
859 if is_zero e1.rel
then
862 ({ e1 with cst
= e1.cst
--. t2
}, false)
863 | Tapp
(ls
, [ t1
; t2
]) when is_mul_ls ls
->
864 let e1, is_factor
= parse t1
in
865 if is_zero e1.rel
then
866 let e2, is_factor
= parse t2
in
867 if is_zero e2.rel
then
869 else if is_factor
then
870 ({ e2 with rel
= e2.rel
**. t1
}, true)
872 ({ e2 with cst
= e2.cst
**. t1
}, false)
873 else if is_factor
then
874 ({ e1 with rel
= e1.rel
**. t2
}, true)
876 ({ e1 with cst
= e1.cst
**. t2
}, false)
877 | _
-> (e_default, false)
881 (* Parse `|f i - exact_f i| <= C`.
882 * We try to decompose `C` to see if it has the form `A (f' i) + B` where
883 * |f i| <= f' i for i in a given range.
884 * Used for sum error propagation.
886 let parse_fn_error i
exact c
=
888 let extract_fn_and_args fn l
=
889 if ls_equal fn fs_func_app
then
890 match (List.hd l
).t_node
with
891 | Tapp
(fn
, []) -> (fn
, List.tl l
)
897 | Tapp
(ls
, [ t'
]) when is_abs_ls ls
&& is_match t'
-> true
899 let _, args = extract_fn_and_args fn' l
in
901 | [ i'
] when t_equal i i'
-> true
905 parse_error is_match exact c
907 (* Parse `|x_approx - x| <= C`.
908 * We try to decompose `C` to see if it has the form `Ax' + B` where |x| <= x' *)
909 let parse_error x c
=
910 let _is_sum, _f
, _i
, _j
=
912 | _ -> (false, None
, None
, None
)
918 | Tapp
(ls
, [ t'
]) when is_abs_ls ls
-> (
926 parse_error is_match x c
928 let is_var_equal t vs
=
930 | Tvar vs'
when vs_equal vs vs'
-> true
933 (* Looks for |to_real x - exact_x| <= C. *)
936 | Tapp
(ls
, [ t; c
]) when is_ineq_ls ls
-> (
938 | Tapp
(ls
, [ t ]) when is_abs_ls ls
-> (
940 | Tapp
(ls
, [ x
; exact_x
]) when is_sub_ls ls
-> (
942 | Tapp
(ls
, [ x
]) when is_to_real_ls ls
-> Some
(x
, exact_x
, c
)
(* Looks for `forall i. P -> |to_real (fn i) - exact_fn i| <= A(fn' i) + B` (or
   the same formula without the hypothesis), with `i` an integer. We also match
   on a potential hypothesis P because usually there is a hypothesis on the
   bounds of i. A forward error on a function of type (int -> real) can be
   used for sum error propagation *)
953 let collect_in_quant info q
=
954 let vs, _, t = t_open_quant q
in
956 | [ i
] when ty_equal i
.vs_ty ty_int
-> (
959 | Tbinop
(Timplies
, _, t) -> t
962 match parse_ineq t with
963 | Some
(x
, exact_x
, c
) -> (
964 match (x
.t_node
, exact_x
.t_node
) with
965 | Tapp
(fn
, l
), Tapp
(exact_fn
, l'
) -> (
966 let extract_fn_and_args fn l
=
967 if ls_equal fn fs_func_app
then
968 match (List.hd l
).t_node
with
969 | Tapp
(fn
, []) -> (fn
, List.tl l
)
974 let fn, args = extract_fn_and_args fn l
in
975 let exact_fn, args'
= extract_fn_and_args exact_fn l'
in
976 match (args, args'
) with
977 | [ i'
], [ i''
] when is_var_equal i' i
&& is_var_equal i'' i
-> (
978 let e = parse_fn_error i'
(t_app_infer
exact_fn []) c
in
979 match e.factor
.t_node
with
980 | Tapp
(fn''
, args) ->
981 let fn'
= fst
(extract_fn_and_args fn''
args) in
983 Mls.add fn { e with factor
= t_app_infer
fn'
[] } info.fns_info
985 { info with fns_info }
992 let rec collect info f =
994 | Tbinop
(Tand
, f1
, f2
) ->
995 let info = collect info f1
in
997 | Tapp
(ls
, [ _; _ ]) when is_ineq_ls ls
-> (
998 (* term of the form x op y where op is an inequality "<" or "<=" over reals
999 FIXME: we should handle ">" and ">=" as well *)
1000 match parse_ineq f with
1001 | Some
(x
, exact_x
, c
) ->
1002 let error_fmla = parse_error exact_x c
in
1003 let terms_info = add_fw_error info.terms_info x
error_fmla in
1004 { info with terms_info }
1006 (* `r = uop args` or `uop args = r` *)
1007 | Tapp
(ls
, [ t1
; t2
]) when ls_equal ls ps_equ
-> (
1008 match t1
.t_node
with
1009 | Tapp
(ls
, args) when is_uop ls
->
1010 let terms_info = add_ieee_op info.terms_info ls t2
args in
1011 { info with terms_info }
1013 match t2
.t_node
with
1014 | Tapp
(ls
, args) when is_uop ls
->
1015 let terms_info = add_ieee_op info.terms_info ls t1
args in
1016 { info with terms_info }
1018 | Tquant
(Tforall
, tq
) ->
1019 (* Collect potential error on function (for sum propagation lemma) *)
1020 collect_in_quant info tq
1024 * We look for relevant axioms in the task, and we add corresponding
1025 * inequalities to `info`.
1026 * The formulas we look for have one of the following structures :
1027 * - `|to_real x - exact_x| <= A` : fw error on `x`
1028 * - `r = t1 uop t2` : `r` is the result of an IEEE operation on `t1` and `t2`
1029 * - `forall (i:int). P -> |to_real (f i) - exact_f i| <= A` : fw error on `f`
1030 * - `forall (i:int). |to_real (f i) - exact_f i| <= A` : fw error on `f`
1032 * We also look for definitions of the form `r = t1 uop t2`
1034 let collect_info info d
=
1036 | Dprop
(kind
, _pr
, f) when kind
= Paxiom
|| kind
= Plemma
-> collect info f
1040 (fun info (ls
, ls_def
) ->
1041 match ls
.ls_value
with
1042 | Some ty
when is_ty_float ty
->
1043 let _vsl, t = open_ls_defn ls_def
in
1048 { info with ls_defs }
1051 let init_symbols env printer
=
1052 let real = Env.read_theory env
[ "real" ] "Real" in
1053 let lt = ns_find_ls
real.th_export
[ Ident.op_infix
"<" ] in
1054 let le = ns_find_ls
real.th_export
[ Ident.op_infix
"<=" ] in
1055 let real_infix = Env.read_theory env
[ "real" ] "RealInfix" in
1056 let lt_infix = ns_find_ls
real_infix.th_export
[ Ident.op_infix
"<." ] in
1057 let le_infix = ns_find_ls
real_infix.th_export
[ Ident.op_infix
"<=." ] in
1058 let add = ns_find_ls
real.th_export
[ Ident.op_infix
"+" ] in
1059 let sub = ns_find_ls
real.th_export
[ Ident.op_infix
"-" ] in
1060 let mul = ns_find_ls
real.th_export
[ Ident.op_infix
"*" ] in
1061 let _div = ns_find_ls
real.th_export
[ Ident.op_infix
"/" ] in
1062 let minus = ns_find_ls
real.th_export
[ Ident.op_prefix
"-" ] in
1063 let add_infix = ns_find_ls
real_infix.th_export
[ Ident.op_infix
"+." ] in
1064 let sub_infix = ns_find_ls
real_infix.th_export
[ Ident.op_infix
"-." ] in
1065 let mul_infix = ns_find_ls
real_infix.th_export
[ Ident.op_infix
"*." ] in
1066 let div_infix = ns_find_ls
real_infix.th_export
[ Ident.op_infix
"/." ] in
1067 let minus_infix = ns_find_ls
real_infix.th_export
[ Ident.op_prefix
"-." ] in
1068 let real_abs = Env.read_theory env
[ "real" ] "Abs" in
1069 let abs = ns_find_ls
real_abs.th_export
[ "abs" ] in
1070 let exp_log_th = Env.read_theory env
[ "real" ] "ExpLog" in
1071 let exp = ns_find_ls
exp_log_th.th_export
[ "exp" ] in
1072 let log = ns_find_ls
exp_log_th.th_export
[ "log" ] in
1073 let log2 = ns_find_ls
exp_log_th.th_export
[ "log2" ] in
1074 let log10 = ns_find_ls
exp_log_th.th_export
[ "log10" ] in
1075 let trigo_th = Env.read_theory env
[ "real" ] "Trigonometry" in
1076 let sin = ns_find_ls
trigo_th.th_export
[ "sin" ] in
1077 let cos = ns_find_ls
trigo_th.th_export
[ "cos" ] in
1078 let usingle = Env.read_theory env
[ "ufloat" ] "USingle" in
1079 let udouble = Env.read_theory env
[ "ufloat" ] "UDouble" in
1080 let usingle_lemmas = Env.read_theory env
[ "ufloat" ] "USingleLemmas" in
1081 let udouble_lemmas = Env.read_theory env
[ "ufloat" ] "UDoubleLemmas" in
1082 let mk_ufloat_symbols th _th_lemmas ty
=
1084 try ns_find_ls th
.th_export
[ s ] with
1085 | Not_found
-> failwith
(Format.sprintf
"Symbol %s not found" s)
1088 ufloat_type
= ns_find_ts th
.th_export
[ ty
];
1089 to_real = f th
"to_real";
1094 uminus
= f th
"uminus";
1095 udiv_exact
= f th
"udiv_exact";
1096 uadd_infix
= f th
(Ident.op_infix
"++.");
1097 usub_infix
= f th
(Ident.op_infix
"--.");
1098 umul_infix
= f th
(Ident.op_infix
"**.");
1099 udiv_infix
= f th
(Ident.op_infix
"//.");
1100 uminus_prefix
= f th
(Ident.op_prefix
"--.");
1101 udiv_exact_infix
= f th
(Ident.op_infix
"///.");
1102 eps = fs_app
(f th
"eps") [] ty_real
;
1103 eta = fs_app
(f th
"eta") [] ty_real
;
1106 let usingle_symbols = mk_ufloat_symbols usingle usingle_lemmas "usingle" in
1107 let udouble_symbols = mk_ufloat_symbols udouble udouble_lemmas "udouble" in
1134 ident_printer
= printer
.Trans.printer
;
1135 tv_printer
= printer
.Trans.aprinter
;
1141 let rec compute_subterms (acc
: int Mterm.t) f : int Mterm.t =
1151 | None
-> t_fold
compute_subterms acc
f
1154 let n = Mterm.find
f acc
in
1155 Mterm.add f (n + 1) acc
1158 let acc = t_fold
compute_subterms acc f in
1161 let m = compute_subterms Mterm.empty
f in
1168 let id = Ident.id_fresh ?loc
:t.t_loc
"t" in
1169 let vs = create_vsymbol
id (Option.get
t.t_ty
) in
1173 let rec letify_rec f =
1175 let vs = Mterm.find
f m in
1178 | Not_found
-> t_map
letify_rec f
1180 let letified_f = letify_rec f in
1181 let l = Mterm.bindings
m in
1183 List.sort
(fun (t1
, _) (t2
, _) -> compare
(term_size t2
) (term_size t1
)) l
1187 let t = t_map
letify_rec t in
1188 t_let_close
vs t acc)
1191 (*** complete strategy *)
1193 let fw_propagation args env naming_table lang task
=
1194 (* let naming_table = Args_wrapper.build_naming_tables task in *)
1195 init_symbols env
naming_table;
1196 let printer = Args_wrapper.build_naming_tables task
in
1197 (* Update the printer at each call, but not the symbols *)
1202 ident_printer
= printer.Trans.printer;
1203 tv_printer
= printer.Trans.aprinter
;
1205 (* We start by collecting infos from the hypotheses of the task *)
1207 List.fold_left
collect_info
1208 { terms_info = Mterm.empty
; fns_info = Mls.empty
; ls_defs = Mls.empty
}
1213 (* If no argument is given, then we perform forward error propagation for
1214 every ufloat term of the goal. *)
1216 let goal = task_goal_fmla task
in
1219 Args_wrapper.parse_and_type_list ~lang ~as_fmla
:false floats naming_table
1222 (Args_wrapper.Arg_error
1223 "this strategy expects an optional comma-separated list of terms as \
1226 (* For each float `x`, we try to compute a formula of the form `|x - exact_x|
1227 <= A x' + B` where `exact_x` is the real value which is approximated by the
1228 float `x` and `|exact_x| <= x'`. For this, forward error propagation is
1229 performed using propagation lemmas of the ufloat stdlib with the data
1230 contained in `info`. For each new formula created, a proof tree is
1231 generated with the necessary steps to prove it. *)
1235 match get_error_fmlas info t with
1236 | _, None
, _ -> (f, l)
1237 | _, Some
f'
, s -> (t_and_simp
f f'
, s :: l))
1240 let f'
= letify f in
1241 if List.length strats
= 0 then
1244 else if List.length strats
= 1 then
1245 (* We only have an assertion on one float so no need to use split *)
1247 Sapply_trans
("assert", [ term_to_str f ], strats
@ [ default_strat () ])
1249 Sapply_trans
("assert", [ term_to_str f'
], [ f_strat; default_strat () ])
1251 (* We assert a conjunction of formulas, one for each float in the goal for
1252 which we can use propagation lemmas. We have one strat for each of this
1253 goal so we split our assertions and prove each subgoal with the
1254 corresponding strat *)
1255 let s = Sapply_trans
("split_vc", [], List.rev strats
) in
1257 Sapply_trans
("assert", [ term_to_str f ], [ s; default_strat () ])
1259 Sapply_trans
("assert", [ term_to_str f'
], [ f_strat; default_strat () ])
1262 register_strat_with_args
"forward_propagation" fw_propagation
1263 ~desc
:"Compute@ forward@ error@ of@ float@ computations."