fix sessions and CE oracles
[why3.git] / bench / test_mlw_printer
blobacc362d4223d196388bb4473156b48f55ea314ce
1 #!/usr/bin/env python3
2 import sys
3 import sexpdata
4 import math
5 import os
6 from subprocess import Popen, PIPE, DEVNULL
8 debug = os.environ.get("DEBUG") != None
10 loads_kwargs = dict(nil=None, true=None)  # Don't convert nil or #t in s-exp
12 class CommandFailed(Exception):
13     def __init__(self, msg):
14         self.msg = msg
16 class NotEqual(Exception):
17     def __init__(self, path, sexp0, sexp1):
18         self.path = path
19         self.sexp0 = sexp0
20         self.sexp1 = sexp1
22 # Read a whyml file as a s-exp
23 def read(why3, filename):
24     p = Popen([why3, "pp", "--output=sexp", filename], stdout=PIPE, stderr=DEVNULL, encoding='utf8')
25     s, _ = p.communicate()
26     if p.returncode == 0:
27         return sexpdata.loads(s, **loads_kwargs) if s else None
28     else:
29         raise CommandFailed("cannot print s-expr for original file (returncode={})"
30                             .format(p.returncode))
32 # Pretty-print a whyml file and read the result as a s-exp
33 def print_and_read(why3, filename):
34     p1 = Popen([why3, "pp", "--output=mlw", filename], stdout=PIPE, stderr=DEVNULL, encoding='utf8')
35     p2 = Popen([why3, "pp", "--output=sexp", "-"], stdin=p1.stdout, stdout=PIPE, stderr=DEVNULL, encoding='utf8')
36     s, _ = p2.communicate()
37     if p2.returncode == 0:
38         return sexpdata.loads(s, **loads_kwargs) if s else None
39     else:
40         raise CommandFailed("cannot print s-expr for pretty-printed output (returncode={})"
41                             .format(p2.returncode))
43 def is_location(sexp):
44     try:
45         return [type(x) for x in sexp] == [str, int, int, int]
46     except:
47         return False
49 IGNORE_ID_ATTRS = [
50     "W:unused_variable:N", "extraction:array_make", "extraction:array",
51     "induction", "mlw:reference_var", "infer", "useraxiom", "W:non_conservative_extension:N",
52     "model_trace:flag", "model_trace:first_val", "model_trace:sec_val", "model_trace:TEMP_NAME",
53     "model", "W:unmodified_variable:N"
56 def keep_id_attr(at):
57     try:
58         # (ATstr ((attr_string "...") (attr_tag N)))
59         variant, fields = at
60         field0, field1 = fields
61         field, value = field0
62         if variant.value() == 'ATstr' and field.value() == 'attr_string':
63             return value not in IGNORE_ID_ATTRS
64         else:
65             return True
66     except:
67         return True
69 def ignore_id_attrs(sexp):
70     try:
71         # (id_ats (at ...))
72         field, value = sexp
73         if field.value() == 'id_ats':
74             ats = [at for at in value if keep_id_attr(at)]
75             return [sexp[0], ats]
76         else:
77             return sexp
78     except:
79         return sexp
81 # Test for sexp (<field_name> _)
82 def is_field(sexp, field_name):
83     try:
84         return len(sexp) == 2 and sexp[0].value() == field_name
85     except:
86         return False
88 def assert_equal(path, sexp0, sexp1):
89     if sexp0 == sexp1:
90         return
91     if is_location(sexp0) and is_location(sexp1):
92         return # Don't bother with locations
93     if is_field(sexp0, "attr_tag") and is_field(sexp1, "attr_tag"):
94         return # Don't bother with tags
95     if type(sexp0) == float and math.isnan(sexp0) and type(sexp1) == float and math.isnan(sexp1):
96         return # nan != nan
97     if type(sexp0) == list and type(sexp1) == list:
98         while True: # Ignore additional parentheses
99             try:
100                 if sexp0[0].value() == "PTparen" and sexp1[0].value() != "PTparen":
101                     path = path+[1]
102                     sexp0 = sexp0[1]
103                 elif sexp1[0].value() == "PTparen" and sexp0[0].value() != "PTparen":
104                     sexp1 = sexp1[1]
105                 elif sexp0[0].value() == "Pparen" and sexp1[0].value() != "Pparen":
106                     path = path+[1]
107                     sexp0 = sexp0[1][0][1]
108                 elif sexp1[0].value() == "Pparen" and sexp0[0].value() != "Pparen":
109                     sexp1 = sexp1[1][0][1]
110                 elif sexp0[0].value() == "Ptuple" and len(sexp0[1]) == 1 and sexp1[0].value() != "Ptuple":
111                     path = path+[1]
112                     sexp0 = sexp0[1][0][1]
113                 elif sexp1[0].value() == "Ptuple" and len(sexp1[1]) == 1 and sexp0[0].value() != "Ptuple":
114                     sexp1 = sexp1[1][0][1]
115                 elif ((sexp0[0].value() == "Tinfix" and sexp1[0].value() == "Tinnfix") or
116                       (sexp0[0].value() == "Tinnfix" and sexp1[0].value() == "Tinfix") or
117                       (sexp0[0].value() == "Tbinop" and sexp1[0].value() == "Tbinnop") or
118                       (sexp0[0].value() == "Tbinnop" and sexp1[0].value() == "Tbinop") or
119                       (sexp0[0].value() == "Einfix" and sexp1[0].value() == "Einnfix") or
120                       (sexp0[0].value() == "Einnfix" and sexp1[0].value() == "Einfix")):
121                     sexp0 = sexp0[1:]
122                     sexp1 = sexp1[1:]
123                 else:
124                     break
125             except AttributeError:
126                 break
127         sexp0 = ignore_id_attrs(sexp0)
128         sexp1 = ignore_id_attrs(sexp1)
129         if len(sexp0) > len(sexp1):
130             raise NotEqual(path, sexp0, sexp1)
131         if len(sexp0) < len(sexp1):
132             raise NotEqual(path, sexp0, sexp1)
133         for i, (s0, s1) in enumerate(zip(sexp0, sexp1)):
134             assert_equal(path+[i], s0, s1)
135     else:
136         raise NotEqual(path, sexp0, sexp1)
138 def trace(path, sexp, sexp1):
139     if path == []:
140         return [sexpdata.Symbol("ERROR"),
141                 [sexpdata.Symbol("EXPECTED"), sexp],
142                 [sexpdata.Symbol("FOUND"), sexp1]]
143     elif type(sexp) == list:
144         return [trace(path[1:], sexp[i], sexp1)
145                 if i == path[0] else sexp[i]
146                 for i, x in enumerate(sexp)]
148 def test(why3, filename):
149     print("  {}: ".format(filename), end='', flush=True)
150     try:
151         sexp0 = read(why3, filename)
152         sexp1 = print_and_read(why3, filename)
153         assert_equal([], sexp0, sexp1)
154         print("ok")
155         return True
156     except NotEqual as e:
157         print("FAILED")
158         if debug:
159             sexpdata.dump(trace(e.path, sexp0, e.sexp1) or "NO TRACE", sys.stdout)
160         return False
161     except CommandFailed as e:
162         print("COMMAND FAILED:", e.msg)
163         return False
165 def main():
166     why3 = sys.argv[1]
167     files = sys.argv[2:]
168     res = all(test(why3, f) for f in files)
169     exit(0 if res else 1)
171 if __name__ == "__main__":
172     main()