[CFGDiff] Refactor Succ/Pred maps.
[llvm-project.git] / mlir / utils / generate-test-checks.py
blobf6197554eaa793dc4e133a9b277651b8beba08f6
1 #!/usr/bin/env python3
2 """A script to generate FileCheck statements for mlir unit tests.
4 This script is a utility to add FileCheck patterns to an mlir file.
6 NOTE: The input .mlir is expected to be the output from the parser, not a
7 stripped down variant.
9 Example usage:
10 $ generate-test-checks.py foo.mlir
11 $ mlir-opt foo.mlir -transformation | generate-test-checks.py
12 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
13 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
14 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
16 The script will heuristically generate CHECK/CHECK-LABEL commands for each line
17 within the file. By default this script will also try to insert string
18 substitution blocks for all SSA value names. If --source file is specified, the
19 script will attempt to insert the generated CHECKs to the source file by looking
20 for line positions matched by --source_delim_regex.
22 The script is designed to make adding checks to a test case fast, it is *not*
23 designed to be authoritative about what constitutes a good test!
24 """
26 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27 # See https://llvm.org/LICENSE.txt for license information.
28 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
30 import argparse
31 import os # Used to advertise this file's name ("autogenerated_note").
32 import re
33 import sys
# Leading text of the banner line written at the top of generated output; the
# script's own path is appended to it.
ADVERT = "// NOTE: Assertions have been autogenerated by "

# Pattern for the name that follows the '%' of an SSA identifier: either a run
# of digits, or one of [a-zA-Z$._-] followed by any number of those characters
# or digits.
SSA_RE_STR = r"[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)
class SSAVariableNamer:
    """Generates and tracks string substitution names for SSA value names.

    Generated names have the form ``VAL_<n>``. Each one is recorded in a stack
    of scopes, where a scope is a dict mapping an SSA value name to the
    substitution name created for it.
    """

    def __init__(self):
        # Stack of {ssa_name: substitution_name} dicts; the last entry is the
        # most recently pushed scope.
        self.scopes = []
        # Numeric suffix that the next generated name will receive.
        self.name_counter = 0

    def generate_name(self, ssa_name):
        """Create a new substitution name and bind it to ``ssa_name``.

        The binding is stored in the most recently pushed scope, so at least
        one scope must exist; with no scope pushed this raises IndexError.
        Returns the generated substitution name.
        """
        index = self.name_counter
        self.name_counter = index + 1
        substitution = 'VAL_{}'.format(index)
        innermost = self.scopes[-1]
        innermost[ssa_name] = substitution
        return substitution

    def push_name_scope(self):
        """Open a new, empty variable name scope."""
        self.scopes.append(dict())

    def pop_name_scope(self):
        """Discard the most recently pushed variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the nesting level, i.e. how many scopes are currently open."""
        return len(self.scopes)

    def clear_counter(self):
        """Restart name numbering so the next generated name is ``VAL_0``."""
        self.name_counter = 0
def process_line(line_chunks, variable_namer):
    """Rewrite the '%'-separated tail of a line using FileCheck variables.

    Args:
        line_chunks: The pieces that followed each '%' when the input line
            was split on '%'. Each piece normally starts with an SSA value
            name.
        variable_namer: The SSAVariableNamer whose scopes are searched for an
            existing substitution and which generates new ones.

    Returns:
        The rewritten text with trailing whitespace removed and a single
        newline appended. A name already known in any scope is emitted as
        ``%[[VAR]]``; a new name is emitted as the definition
        ``%[[VAR:.*]]``.
    """
    pieces = []

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an SSA name (for example it was
            # the last character of a quoted string). Emit it unchanged
            # rather than failing on a missing match object.
            pieces.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name. Scopes are
        # searched in the order they were pushed and the first hit wins.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # If one exists, then output the existing name.
            pieces.append('%[[' + variable + ']]')
        else:
            # Otherwise, generate a new variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append('%[[' + variable + ':.*]]')

        # Append the text that followed the SSA name.
        pieces.append(chunk[len(ssa_name):])

    return ''.join(pieces).rstrip() + '\n'
def process_source_lines(source_lines, note, args):
    """Split the lines of the source file into segments.

    The source file doesn't have to be .mlir.

    Args:
        source_lines: Lines of the source file, without trailing newlines.
        note: The autogenerated banner line; a line equal to it is dropped.
        args: Parsed options. ``args.source_delim_regex`` is the pattern that
            starts a new segment and ``args.check_prefix`` is the text whose
            presence anywhere in a line causes that line to be dropped.

    Returns:
        A list of segments, each a list of lines with '\\n' appended. The
        first segment holds everything before the first delimiter line, and
        every delimiter line is the first line of its own segment.
    """
    delimiter = re.compile(args.source_delim_regex)

    segments = [[]]
    for text in source_lines:
        # Drop the previous note and any previous CHECK lines.
        if text == note or args.check_prefix in text:
            continue

        # A delimiter line opens a new segment and becomes its first line.
        if delimiter.search(text):
            segments.append([])

        segments[-1].append(text + '\n')

    return segments
def preprocess_line(line):
    """Escape character sequences in ``line`` that are problematic for FileCheck.

    Returns a copy of ``line`` with the replacements below applied in order.
    """
    escapes = (
        # '[[' corresponds to variable names in FileCheck, so escape any
        # double brackets.
        ('[[', '{{\\[\\[}}'),
        # A single bracket followed by an SSA identifier: the identifier
        # will be replaced by a variable, creating the same situation as
        # above.
        ('[%', '{{\\[}}%'),
    )

    escaped_line = line
    for needle, replacement in escapes:
        escaped_line = escaped_line.replace(needle, replacement)
    return escaped_line
def main():
    """Command-line entry point: turn MLIR text into FileCheck lines.

    Reads the input (a file argument or stdin), generates one check line per
    non-empty input line that is nested at least ``--starts_from_scope``
    levels deep, and writes the result to stdout, to ``--output``, or, with
    ``-i``, back into the ``--source`` file. When ``--source`` is given, each
    generated segment is written in front of the matching source segment.

    Exits with an error message when ``-i`` is used without ``--source`` or
    together with ``-o``, or when the number of generated segments does not
    match the number of source segments.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK',
        help='Prefix to use from check file.')
    parser.add_argument(
        '-o',
        '--output',
        nargs='?',
        type=argparse.FileType('w'),
        default=None)
    parser.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('r'),
        default=sys.stdin)
    parser.add_argument(
        '--source', type=str,
        help='Print each CHECK chunk before each delimeter line in the source '
             'file, respectively. The delimeter lines are identified by '
             '--source_delim_regex.')
    parser.add_argument('--source_delim_regex', type=str, default='func @')
    parser.add_argument(
        '--starts_from_scope', type=int, default=1,
        help='Omit the top specified level of content. For example, by default '
             'it omits "module {"')
    parser.add_argument('-i', '--inplace', action='store_true', default=False)

    args = parser.parse_args()

    # Validate option combinations up front with a proper usage error
    # instead of an assert (stripped under -O) or a crash in open().
    if args.inplace:
        if not args.source:
            parser.error('-i/--inplace requires --source')
        if args.output is not None:
            parser.error('-i/--inplace cannot be combined with -o/--output')

    # Read the given input file.
    input_lines = [l.rstrip() for l in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT + 'utils/' + script_name

    source_segments = None
    if args.source:
        # Read the source eagerly and close it; with -i the same path is
        # reopened for writing further down.
        with open(args.source, 'r') as source_file:
            source_lines = [l.rstrip() for l in source_file]
        source_segments = process_source_lines(
            source_lines, autogenerated_note, args)

    output_segments = [[]]
    # Holds the data used for naming SSA value names.
    variable_namer = SSAVariableNamer()
    for input_line in input_lines:
        if not input_line:
            continue
        lstripped_input_line = input_line.lstrip()

        # Lines with blocks begin with a ^. These lines have a trailing
        # comment that needs to be stripped.
        if lstripped_input_line.startswith('^'):
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line.startswith('}'):
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope. A scope opened
        # exactly at the starting level begins a new output segment.
        if input_line.endswith('{'):
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        # Variable numbering restarts at the first line of every segment.
        if not output_segments[-1]:
            variable_namer.clear_counter()

        # Preprocess the input to remove any sequences that may be
        # problematic with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at the each SSA value name.
        ssa_split = input_line.split('%')

        # The first line of a segment becomes 'CHECK-LABEL' (unless it
        # starts with an SSA value); every other line uses 'CHECK:'.
        if output_segments[-1] or not ssa_split[0]:
            output_line = '// ' + args.check_prefix + ': '
            # Pad to align with the 'LABEL' statements.
            output_line += ' ' * len('-LABEL')

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            # Output the first line chunk that does not contain an SSA name
            # for the label.
            output_line = ('// ' + args.check_prefix + '-LABEL: ' +
                           ssa_split[0] + '\n')

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += '// ' + args.check_prefix + '-SAME: '

                # Pad to align with the original position in the line.
                output_line += ' ' * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line([argument], variable_namer)

        # Append the output line.
        output_segments[-1].append(output_line)

    # Check the segment counts BEFORE the output is opened: with -i the
    # output is the source file itself, and opening it truncates it.
    if source_segments and len(output_segments) != len(source_segments):
        sys.exit(
            'error: generated %d check segment(s) but the source file has %d '
            'segment(s) delimited by --source_delim_regex'
            % (len(output_segments), len(source_segments)))

    if args.inplace:
        output = open(args.source, 'w')
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output.write(autogenerated_note + '\n')

        # Write the output.
        if source_segments:
            for check_segment, source_segment in zip(output_segments,
                                                     source_segments):
                output.writelines(check_segment)
                output.writelines(source_segment)
        else:
            for segment in output_segments:
                output.write('\n')
                output.writelines(segment)
            output.write('\n')
    finally:
        # Close files this run owns, but leave the process-wide stdout open.
        if output is not sys.stdout:
            output.close()
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()