LAA: improve code in getStrideFromPointer (NFC) (#124780)
[llvm-project.git] / flang / test / Evaluate / test_folding.py
blobdcd1541997c2a22f0fe9ec362d52a7897c61d607
1 #!/usr/bin/env python3
3 """This script verifies expression folding.
4 It compiles a source file with '-fdebug-dump-symbols'
5 and looks for parameter declarations to check
6 they have been folded as expected.
7 To check folding of an expression EXPR,
8 the fortran program passed to this script
9 must contain the following:
11 logical, parameter :: test_x = <compare EXPR to expected value>
13 This script will test that all parameters
14 with a name starting with "test_"
15 have been folded to .true.
16 For instance, acos folding can be tested with:
18 real(4), parameter :: res_acos = acos(0.5_4)
19 real(4), parameter :: exp_acos = 1.047
20 logical, parameter :: test_acos = abs(res_acos - exp_acos).LE.(0.001_4)
22 There are two kinds of failure:
23 - test_x is folded to .false..
24 This means the expression was folded
25 but the value is not as expected.
26 - test_x is not folded (it is neither .true. nor .false.).
27 This means the compiler could not fold the expression.
29 Parameters:
30 sys.argv[1]: a source file that contains the input and expected output
31 sys.argv[2]: the Flang frontend driver
32 sys.argv[3:]: Optional arguments to the Flang frontend driver"""
34 import os
35 import sys
36 import tempfile
37 import re
38 import subprocess
40 from difflib import unified_diff
41 from pathlib import Path
def check_args(args):
    """Verify that the number of arguments passed is correct.

    Prints a usage line and exits with status 1 when fewer than the two
    required positional arguments (Fortran source, Flang driver) are present.

    Args:
        args: the full argument vector, normally ``sys.argv``; ``args[0]``
            is the script name and ``args[3:]`` are optional extra arguments
            forwarded to the Flang frontend driver.
    """
    if len(args) < 3:
        print(f"Usage: {args[0]} <fortran-source> <flang-command> [flang-args...]")
        sys.exit(1)
def set_source(source):
    """Return the path to the Fortran test source as a ``Path``.

    Prints a message and exits with status 1 when ``source`` does not name
    an existing file (as reported by ``Path.is_file``).

    Args:
        source: path to the Fortran source file, as given on the command line.

    Returns:
        ``Path(source)``, unresolved.
    """
    path = Path(source)
    if not path.is_file():
        # Report the argument itself: this runs before any module-level
        # state about the source has been set up.
        print(f"File not found: {source}")
        sys.exit(1)
    return path
def set_executable(exe):
    """Return the Flang frontend driver location as a string.

    The argument is normalized through ``Path`` before being returned. When
    it does not name an existing file (per ``Path.is_file``), a message is
    printed and the process exits with status 1.
    """
    driver = Path(exe)
    if driver.is_file():
        return str(driver)
    print(f"Flang was not found: {exe}")
    sys.exit(1)
# Top-level driver: compile the test source, collect folded PARAMETER values
# and diagnostics, and compare them against the expectations in the source.

check_args(sys.argv)
srcdir = set_source(sys.argv[1]).resolve()
with open(srcdir, "r", encoding="utf-8") as f:
    src = f.readlines()

flang_fc1 = set_executable(sys.argv[2])
flang_fc1_args = sys.argv[3:]
# The symbol dump printed on stdout is what gets parsed below; the "init:"
# field of each PARAMETER entry holds its (folded) value.
flang_fc1_options = ["-fdebug-dump-symbols"]
LIBPGMATH = os.getenv("LIBPGMATH")
if LIBPGMATH:
    flang_fc1_options.append("-DTEST_LIBPGMATH")
    print("Assuming libpgmath support")
else:
    print("Not assuming libpgmath support")

cmd = [flang_fc1, *flang_fc1_args, *flang_fc1_options, str(srcdir)]
# Run in a scratch directory so anything the driver writes relative to its
# working directory stays out of the caller's directory.
with tempfile.TemporaryDirectory() as tmpdir:
    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
            universal_newlines=True,
            cwd=tmpdir,
        )
    except subprocess.CalledProcessError as err:
        # stderr was captured, so a bare traceback would hide the compiler's
        # diagnostics explaining why the source was rejected.
        print("Folding test failed:")
        print(f"{flang_fc1} exited with status {err.returncode}")
        print(err.stderr)
        print()
        print("FAIL")
        sys.exit(1)
symbol_dump = proc.stdout
messages = proc.stderr

# Every named constant in the dump, as "name value" where value is the text
# following "init:".
parameters = []
for line in symbol_dump.split("\n"):
    m = re.search(r"(\w*)(?=, PARAMETER).*init:(.*)", line)
    if m:
        parameters.append(f"{m.group(1)} {m.group(2)}")

# Only parameters named test_* are checks; those folded to .false._<kind>
# are failures.
tests = [param for param in parameters if param.startswith("test_")]
failed_tests = [test for test in tests if re.search(r"\.false\._.$", test)]

# Diagnostics actually emitted, normalized to "<line>: <message>".
actual_warnings = []
for line in messages.split("\n"):
    m = re.search(r"[^:]*:(\d*):\d*: (.*)", line)
    if m:
        actual_warnings.append(f"{m.group(1)}: {m.group(2)}")

# Diagnostics expected by the source: the text after each "!WARN:" comment is
# attached to the next line that is not itself a !WARN: comment. The text is
# presumably written with a leading space so "<line>:<text>" lines up with the
# normalized actual form above.
# NOTE(review): !WARN: comments after the last non-!WARN: line are never
# attached, so they are silently not checked.
expected_warnings = []
pending = []
for lineno, line in enumerate(src, 1):
    m = re.search(r"(?:!WARN:)(.*)", line)
    if m:
        pending.append(m.group(1))
        continue
    expected_warnings.extend(f"{lineno}:{text}" for text in pending)
    pending = []
passed_warnings = len(expected_warnings)

# Render mismatches between actual and expected diagnostics as a compact diff.
diff_parts = []
for line in unified_diff(actual_warnings, expected_warnings, n=0):
    line = re.sub(r"(^\-)(\d+:)", r"\nactual at \g<2>", line)
    line = re.sub(r"(^\+)(\d+:)", r"\nexpect at \g<2>", line)
    diff_parts.append(line)
warning_diffs = "".join(diff_parts)

if failed_tests or warning_diffs:
    print("Folding test failed:")
    # Print each failed test together with every parameter entry containing
    # the test's suffix, so that declaring e.g. res_x / exp_x alongside
    # test_x shows the values being compared.
    for test in failed_tests:
        m = re.match(r"test_(\w+)", test)
        if not m:
            continue
        suffix = m.group(1)
        for param in parameters:
            if suffix in param:
                print(param)
    if warning_diffs:
        print(warning_diffs)
    print()
    print("FAIL")
    sys.exit(1)
else:
    print()
    print(f"All {len(tests) + passed_warnings} tests passed")
    print("PASS")