[Frontend] Remove unused includes (NFC) (#116927)
[llvm-project.git] / mlir / utils / generate-test-checks.py
blob8faa425beace1d7a98c24827bdc152692d46d993
1 #!/usr/bin/env python3
2 """A script to generate FileCheck statements for mlir unit tests.
4 This script is a utility to add FileCheck patterns to an mlir file.
6 NOTE: The input .mlir is expected to be the output from the parser, not a
7 stripped down variant.
9 Example usage:
10 $ generate-test-checks.py foo.mlir
11 $ mlir-opt foo.mlir -transformation | generate-test-checks.py
12 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
13 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
14 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
16 The script will heuristically generate CHECK/CHECK-LABEL commands for each line
17 within the file. By default this script will also try to insert string
18 substitution blocks for all SSA value names. If --source file is specified, the
19 script will attempt to insert the generated CHECKs to the source file by looking
20 for line positions matched by --source_delim_regex.
22 The script is designed to make adding checks to a test case fast, it is *not*
23 designed to be authoritative about what constitutes a good test!
24 """
26 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27 # See https://llvm.org/LICENSE.txt for license information.
28 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
30 import argparse
31 import os # Used to advertise this file's name ("autogenerated_note").
32 import re
33 import sys
# Header and trailer of the note that is written at the top of generated output.
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.
"""


# Regex command to match an SSA identifier (without the leading '%').
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regex matching the left-hand side of an assignment, e.g. '%0, %res:2 ='.
# SSA_RE_STR contains a top-level '|', so it has to be wrapped in a
# non-capturing group; otherwise the '%' only binds to the numeric alternative
# and named results such as '%res =' are never recognized. An optional ':N'
# result-pack suffix is accepted after each name (outside the capturing
# groups, so group numbering is unchanged).
SSA_RESULTS_STR = (
    r'\s*(%(?:' + SSA_RE_STR + r'))(?::[0-9]+)?'
    r'(\s*,\s*(%(?:' + SSA_RE_STR + r'))(?::[0-9]+)?)*\s*='
)
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching attributes
ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the left-hand side of an attribute definition
ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)
# Class used to generate and manage string substitution blocks for SSA value
# names.
class VariableNamer:
    # `variable_names` is a comma-separated string of user-chosen names that
    # are handed out in order. Empty (or whitespace-only) entries mean "use the
    # default VAL_<n> name" for that position.
    def __init__(self, variable_names):
        # Stack of dicts mapping an SSA value name to its substitution name.
        self.scopes = []
        self.name_counter = 0

        # Number of variable names to still generate in parent scope
        self.generate_in_parent_scope_left = 0

        # Parse variable names. Surrounding whitespace is stripped so that
        # "DIM, SUM" does not produce the invalid FileCheck name " SUM".
        self.variable_names = [
            name.strip().upper() for name in variable_names.split(',')
        ]
        self.used_variable_names = set()

    # Generate the following 'n' variable names in the parent scope.
    def generate_in_parent_scope(self, n):
        self.generate_in_parent_scope_left = n

    # Generate a substitution name for the given ssa value name and record it
    # in the current scope (or the parent scope, see generate_in_parent_scope).
    # Raises RuntimeError if the resulting name has already been handed out
    # since the last clear_names().
    def generate_name(self, source_variable_name):

        # Compute variable name: next user-provided name if any, otherwise a
        # numbered default.
        variable_name = self.variable_names.pop(0) if self.variable_names else ''
        if variable_name == '':
            variable_name = "VAL_" + str(self.name_counter)
            self.name_counter += 1

        # Scope where variable name is saved
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        # A negative index would silently wrap around to the wrong scope.
        assert scope >= 0, "no name scope available for " + repr(source_variable_name)

        # Save variable
        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ': duplicate variable name')
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)

        return variable_name

    # Push a new variable name scope.
    def push_name_scope(self):
        self.scopes.append({})

    # Pop the last variable name scope.
    def pop_name_scope(self):
        self.scopes.pop()

    # Return the level of nesting (number of pushed scopes).
    def num_scopes(self):
        return len(self.scopes)

    # Reset the counter and used variable names. Existing scope contents and
    # the remaining user-provided names are left untouched.
    def clear_names(self):
        self.name_counter = 0
        self.used_variable_names = set()
# Class used to generate and manage FileCheck substitution names for attribute
# aliases ('#name').
class AttributeNamer:

    # `attribute_names` is a comma-separated string of user-chosen names that
    # are handed out in order. Empty (or whitespace-only) entries mean "use the
    # default ATTR_<n> name" for that position.
    def __init__(self, attribute_names):
        self.name_counter = 0
        # Surrounding whitespace is stripped so that "MAP0, MAP1" does not
        # produce the invalid FileCheck name "$ MAP1".
        self.attribute_names = [
            name.strip().upper() for name in attribute_names.split(',')
        ]
        # Maps a source attribute name to its generated substitution name.
        self.map = {}
        self.used_attribute_names = set()

    # Generate a substitution name for the given attribute name.
    # Raises RuntimeError if the resulting name has already been handed out.
    def generate_name(self, source_attribute_name):

        # Compute FileCheck name
        attribute_name = self.attribute_names.pop(0) if self.attribute_names else ''
        if attribute_name == '':
            attribute_name = "ATTR_" + str(self.name_counter)
            self.name_counter += 1

        # Prepend global symbol
        attribute_name = '$' + attribute_name

        # Save attribute
        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ': duplicate attribute name')
        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    # Get the saved substitution name for the given attribute name. If no name
    # has been generated for the given attribute yet, the placeholder '?' is
    # returned. Generated names always start with '$', so '?' can never
    # collide with a real name.
    def get_name(self, source_attribute_name):
        return self.map.get(source_attribute_name, '?')
# Count the SSA results defined on the left-hand side of a line shaped like
#   %0, %1, ... = ...
# A line that does not start with such an assignment yields 0.
def get_num_ssa_results(input_line):
    lhs = SSA_RESULTS_RE.match(input_line)
    if lhs is None:
        return 0
    # Every result on the left-hand side contributes exactly one '%'.
    return lhs.group(0).count('%')
# Process a line of input that has been split at each SSA identifier '%'.
# Each chunk is the text that followed one '%'. SSA names are replaced by
# FileCheck substitution blocks: '%[[NAME:.*]]' the first time a name is seen
# and '%[[NAME]]' afterwards. Returns the rewritten text, right-stripped and
# terminated by a newline.
def process_line(line_chunks, variable_namer):
    pieces = []

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an SSA identifier (e.g. a percent
            # sign inside a string attribute). Emit it literally instead of
            # registering a variable for the empty name.
            pieces.append("%" + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # If one exists, then output the existing name.
            pieces.append("%[[" + variable + "]]")
        else:
            # Otherwise, generate a new variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append("%[[" + variable + ":.*]]")

        # Append the non named group.
        pieces.append(chunk[len(ssa_name):])

    return "".join(pieces).rstrip() + "\n"
# Process the source file lines. The source file doesn't have to be .mlir.
#
# `source_lines` are the lines of the source file without line terminators,
# `note` is the (possibly multi-line) autogenerated note, and `args` provides
# `source_delim_regex` and `check_prefix`. Returns a list of segments, each a
# list of newline-terminated lines; a new segment starts at every line that
# matches `args.source_delim_regex`.
def process_source_lines(source_lines, note, args):
    source_split_re = re.compile(args.source_delim_regex)

    source_lines = list(source_lines)
    # The note spans several lines, so it has to be matched as a run of
    # consecutive source lines; comparing a single line against the whole
    # note can never succeed.
    note_lines = [note_line.rstrip() for note_line in note.splitlines()]
    note_len = len(note_lines)

    source_segments = [[]]
    index = 0
    num_lines = len(source_lines)
    while index < num_lines:
        # Remove previous note.
        if (
            note_len
            and source_lines[index].rstrip() == note_lines[0]
            and [s.rstrip() for s in source_lines[index:index + note_len]]
            == note_lines
        ):
            index += note_len
            # Also drop the single blank separator line that is written
            # right after the note, so reruns do not accumulate blank lines.
            if index < num_lines and not source_lines[index].strip():
                index += 1
            continue

        line = source_lines[index]
        index += 1

        # Remove previous CHECK lines (any line mentioning the prefix).
        if args.check_prefix in line:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + "\n")
    return source_segments
# If `line` is an attribute definition ('#name = ...'), generate a
# substitution name for the attribute and write a matching CHECK line to
# `output`. Lines that are not attribute definitions are ignored.
def process_attribute_definition(line, attribute_namer, output):
    definition = ATTR_DEF_RE.match(line)
    if definition is None:
        return

    check_name = attribute_namer.generate_name(definition.group(1))
    # Everything after the matched '#name =' prefix is copied verbatim.
    remainder = line[definition.end():]
    output.write(
        ''.join(['// CHECK: #[[', check_name, ':.+]] =', remainder, '\n'])
    )
# Replace every attribute reference in `line` for which a substitution name
# has been generated with the FileCheck use '#[[NAME]]'. References to
# attributes without a generated name (e.g. dialect attributes such as
# '#dialect.attr<...>') are left untouched; wrapping the '?' placeholder
# returned by the namer would produce the invalid pattern '#[[?]]'.
def process_attribute_references(line, attribute_namer):

    def substitute(match):
        name = attribute_namer.get_name(match.group(1))
        # '?' (or an empty/None result) means no name is known.
        if not name or name == '?':
            return match.group(0)
        return '#[[' + name + ']]'

    return ATTR_RE.sub(substitute, line)
# Escape the character sequences in a line of input that FileCheck would
# otherwise misinterpret.
def preprocess_line(line):
    # '[[' starts a variable block in FileCheck, so a literal '[[' is emitted
    # as an escaped regex instead.
    escaped = "{{\\[\\[}}".join(line.split("[["))

    # A '[' directly followed by an SSA identifier turns into '[[' once the
    # identifier is replaced by a variable, which recreates the problem above;
    # escape that bracket as well.
    return "{{\\[}}%".join(escaped.split("[%"))
# Script entry point: parse the command line, turn every input line into a
# CHECK/CHECK-LABEL/CHECK-SAME line and write the result either on its own or
# interleaved with the segments of the --source file.
def main():
    # Local import: only needed to buffer attribute CHECK lines below.
    import io

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--check-prefix", default="CHECK", help="Prefix to use from check file."
    )
    parser.add_argument(
        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
    )
    parser.add_argument(
        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names",
        type=str,
        default='',
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')")
    parser.add_argument(
        "--attribute_names",
        type=str,
        default='',
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')")

    args = parser.parse_args()

    # Validate option combinations up front instead of asserting (asserts are
    # stripped under -O) or crashing later in open().
    if args.inplace:
        if args.output is not None:
            parser.error("--inplace cannot be combined with --output")
        if not args.source:
            parser.error("--inplace requires --source")

    # Read the given input file.
    input_lines = [line.rstrip() for line in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        with open(args.source, "r") as source_file:
            source_lines = [line.rstrip() for line in source_file]
        source_segments = process_source_lines(
            source_lines, autogenerated_note, args
        )

    # CHECK lines for attribute definitions are collected here and emitted
    # first when the output is finally written. Nothing is written (and the
    # in-place target is not truncated) until all processing has succeeded.
    attribute_checks = io.StringIO()

    output_segments = [[]]

    # Namers
    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    # Process lines
    for input_line in input_lines:
        if not input_line:
            continue

        # Check if this is an attribute definition and process it
        process_attribute_definition(input_line, attribute_namer, attribute_checks)

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        lstripped_input_line = input_line.lstrip()
        is_block = lstripped_input_line[0] == "^"
        if is_block:
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == "}":
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == "{":
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

            # Result SSA values must still be pushed to parent scope
            num_ssa_results = get_num_ssa_results(input_line)
            variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        # Each output segment starts with fresh default names.
        if not output_segments[-1]:
            variable_namer.clear_names()

        # Preprocess the input to remove any sequences that may be problematic
        # with FileCheck.
        input_line = preprocess_line(input_line)

        # Process uses of attributes in this line
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at the each SSA value name.
        ssa_split = input_line.split("%")

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if output_segments[-1] or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad to align with the 'LABEL' statements.
            output_line += " " * len("-LABEL")

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)

        else:
            # Output the first line chunk that does not contain an SSA name for
            # the label.
            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += "// " + args.check_prefix + "-SAME: "

                # Pad to align with the original position in the line.
                output_line += " " * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line([argument], variable_namer)

        # Append the output line.
        output_segments[-1].append(output_line)

    # The CHECK segments are zipped with the source segments below, so a
    # count mismatch would silently drop content. Fail before any output file
    # is opened so that an in-place run never destroys the source.
    if source_segments is not None and len(output_segments) != len(source_segments):
        sys.exit(
            "error: generated %d CHECK segment(s) but --source was split into "
            "%d segment(s); check --source_delim_regex and --starts_from_scope"
            % (len(output_segments), len(source_segments))
        )

    if args.inplace:
        output = open(args.source, "w")
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        # Attribute definition checks come first, followed by the note.
        output.write(attribute_checks.getvalue())
        output.write(autogenerated_note + "\n")

        # Write the output.
        if source_segments is not None:
            for check_segment, source_segment in zip(output_segments, source_segments):
                output.writelines(check_segment)
                output.writelines(source_segment)
        else:
            for segment in output_segments:
                output.write("\n")
                output.writelines(segment)
            output.write("\n")
    finally:
        output.close()
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()