"""A script to generate FileCheck statements for mlir unit tests.

This script is a utility to add FileCheck patterns to an mlir file.

NOTE: The input .mlir is expected to be the output from the parser, not a
stripped down variant.

Example usage:
$ generate-test-checks.py foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'

The script will heuristically generate CHECK/CHECK-LABEL commands for each line
within the file. By default this script will also try to insert string
substitution blocks for all SSA value names. If --source file is specified, the
script will attempt to insert the generated CHECKs into the source file by
looking for line positions matched by --source_delim_regex.

The script is designed to make adding checks to a test case fast; it is *not*
designed to be authoritative about what constitutes a good test!
"""
26 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27 # See https://llvm.org/LICENSE.txt for license information.
28 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import argparse
import os  # Used to advertise this file's name ("autogenerated_note").
import re
import sys
# Header written at the top of every generated check file. The name of this
# script is inserted between ADVERT_BEGIN and ADVERT_END.
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.
"""
# Regex command to match an SSA identifier (the part following the '%').
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regex matching the left-hand side of an assignment, e.g. "%0, %res =".
# SSA_RE_STR is an alternation, so it must be wrapped in a non-capturing group
# before being prefixed with '%'. Otherwise the pattern means "'%' + digits OR
# a bare identifier", which rejects named results such as "%arg0 =" and
# accepts non-SSA text such as "foo =". Capture group numbering is unchanged.
SSA_RESULTS_STR = (
    r'\s*(%(?:' + SSA_RE_STR + r'))(\s*,\s*(%(?:' + SSA_RE_STR + r')))*\s*='
)
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching attributes. Group 1 is the attribute name including the '#'.
ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the left-hand side of an attribute definition, e.g. "#map ="
ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)
# Class used to generate and manage string substitution blocks for SSA value
# names.
class VariableNamer:
    """Hands out FileCheck variable names for SSA values, tracked per scope.

    `scopes` is a stack of dicts mapping a source SSA name to the FileCheck
    variable name generated for it; the last entry is the innermost scope.
    """

    def __init__(self, variable_names):
        """Create a namer.

        Args:
            variable_names: comma separated names to hand out first, in order.
                An empty entry selects the default "VAL_<n>" name.
        """
        # Both attributes are read by the other methods, so they have to
        # exist before the first scope is pushed or the first name generated.
        self.scopes = []
        self.name_counter = 0

        # Number of variable names to still generate in parent scope
        self.generate_in_parent_scope_left = 0

        # Parse variable names
        self.variable_names = [name.upper() for name in variable_names.split(',')]
        self.used_variable_names = set()

    def generate_in_parent_scope(self, n):
        """Generate the following 'n' variable names in the parent scope."""
        self.generate_in_parent_scope_left = n

    def generate_name(self, source_variable_name):
        """Generate a substitution name for the given ssa value name.

        Args:
            source_variable_name: SSA name as it appears in the input.

        Returns:
            The FileCheck variable name recorded for the SSA value.

        Raises:
            RuntimeError: if the target scope does not exist or if the name
                has already been handed out since the last `clear_names()`.
        """
        # Scope where variable name is saved
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        # A negative index would silently wrap around to another scope.
        if scope < 0:
            raise RuntimeError(
                source_variable_name + ': no scope available for variable name')

        # Compute variable name
        variable_name = self.variable_names.pop(0) if self.variable_names else ''
        if variable_name == '':
            variable_name = "VAL_" + str(self.name_counter)
            self.name_counter += 1

        # Save variable
        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ': duplicate variable name')
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)

        return variable_name

    def push_name_scope(self):
        """Push a new variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the last variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the level of nesting (number of pushed scopes)."""
        return len(self.scopes)

    def clear_names(self):
        """Reset the counter and used variable names."""
        self.name_counter = 0
        self.used_variable_names = set()
class AttributeNamer:
    """Hands out FileCheck variable names for attribute aliases.

    `map` associates a source attribute name (as matched in the input) with
    the FileCheck name generated for it.
    """

    def __init__(self, attribute_names):
        """Create a namer.

        Args:
            attribute_names: comma separated names to hand out first, in
                order. An empty entry selects the default "ATTR_<n>" name.
        """
        self.name_counter = 0
        self.attribute_names = [name.upper() for name in attribute_names.split(',')]
        # Read by generate_name() and get_name(), so it must exist up front.
        self.map = {}
        self.used_attribute_names = set()

    def generate_name(self, source_attribute_name):
        """Generate a substitution name for the given attribute name.

        Returns:
            The generated FileCheck name, prefixed with '$'.

        Raises:
            RuntimeError: if the generated name has already been handed out.
        """
        # Compute FileCheck name
        attribute_name = self.attribute_names.pop(0) if self.attribute_names else ''
        if attribute_name == '':
            attribute_name = "ATTR_" + str(self.name_counter)
            self.name_counter += 1

        # Prepend global symbol
        attribute_name = '$' + attribute_name

        # Save attribute
        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ': duplicate attribute name')
        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    def get_name(self, source_attribute_name):
        """Get the saved substitution name for the given attribute name.

        If no name has been generated for the given attribute yet, the
        placeholder '?' is returned.
        """
        return self.map.get(source_attribute_name, '?')
def get_num_ssa_results(input_line):
    """Count the SSA results defined at the start of a line.

    The line is expected to look like `%0, %1, ... = ...`; the count is the
    number of '%' characters in the part matched by SSA_RESULTS_RE.
    Returns 0 if there are no results.
    """
    lhs = SSA_RESULTS_RE.match(input_line)
    if lhs is None:
        return 0
    return lhs.group().count('%')
def process_line(line_chunks, variable_namer):
    """Process a line of input that has been split at each SSA identifier '%'.

    Args:
        line_chunks: the pieces that followed each '%' in the line; each one
            is expected to start with an SSA value name.
        variable_namer: namer whose `scopes` are searched for an existing
            FileCheck variable and which generates new names on demand.

    Returns:
        The processed text with SSA names replaced by FileCheck variable
        uses/definitions, trailing whitespace stripped, terminated by a
        newline.
    """
    pieces = []

    # Process the rest that contained an SSA value name.
    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an SSA name; keep the text verbatim
            # instead of generating a variable for an empty name.
            pieces.append("%" + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # If one exists, then output the existing name.
            pieces.append("%[[" + variable + "]]")
        else:
            # Otherwise, generate a new variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append("%[[" + variable + ":.*]]")

        # Append the non named group.
        pieces.append(chunk[len(ssa_name):])

    return "".join(pieces).rstrip() + "\n"
def process_source_lines(source_lines, note, args):
    """Process the source file lines. The source file doesn't have to be .mlir.

    Args:
        source_lines: lines of the source file without trailing whitespace.
        note: the autogenerated note; source lines equal to one of its
            non-empty lines are dropped.
        args: parsed arguments providing `source_delim_regex` and
            `check_prefix`.

    Returns:
        A list of segments (lists of newline-terminated lines). A new segment
        starts at every line matching `args.source_delim_regex`.
    """
    source_split_re = re.compile(args.source_delim_regex)

    # Compare whole lines so that blank lines and lines that merely happen to
    # be a substring of the note are kept.
    note_lines = {
        note_line.rstrip() for note_line in note.splitlines() if note_line.strip()
    }

    source_segments = [[]]
    for line in source_lines:
        # Remove previous note.
        if line in note_lines:
            continue
        # Remove previous CHECK lines.
        if args.check_prefix in line:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + "\n")
    return source_segments
def process_attribute_definition(line, attribute_namer, output):
    """Emit a CHECK line for `line` if it is an attribute definition.

    Lines that do not match ATTR_DEF_RE are ignored, so this can be called on
    every input line.

    Args:
        line: one line of the input.
        attribute_namer: namer used to generate the FileCheck name for the
            defined attribute.
        output: writable file object that receives the CHECK line.
    """
    m = ATTR_DEF_RE.match(line)
    if m is None:
        return

    attribute_name = attribute_namer.generate_name(m.group(1))
    # Keep everything after the matched "#name =" prefix untouched.
    check_line = (
        '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n'
    )
    output.write(check_line)
def process_attribute_references(line, attribute_namer):
    """Replace attribute references in `line` with FileCheck variable uses.

    Every match of ATTR_RE becomes `#[[<name>]]`, where <name> is what
    `attribute_namer.get_name` returns for the matched attribute; all other
    text is copied through unchanged.

    Returns:
        The rewritten line.
    """

    def substitute(match):
        return '#[[' + attribute_namer.get_name(match.group(1)) + ']]'

    return ATTR_RE.sub(substitute, line)
def preprocess_line(line):
    """Escape character sequences in `line` that are problematic with FileCheck.

    Returns:
        The line with the problematic bracket sequences wrapped in FileCheck
        regex blocks.
    """
    # Replace any double brackets, '[[' with escaped replacements. '[['
    # corresponds to variable names in FileCheck.
    output_line = line.replace("[[", "{{\\[\\[}}")

    # Replace any single brackets that are followed by an SSA identifier, the
    # identifier will be replaced by a variable; creating the same situation
    # as above.
    output_line = output_line.replace("[%", "{{\\[}}%")

    return output_line
def _build_arg_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--check-prefix", default="CHECK", help="Prefix to use from check file."
    )
    parser.add_argument(
        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
    )
    parser.add_argument(
        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')",
    )
    parser.add_argument(
        "--attribute_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')",
    )
    return parser


def main():
    """Generate CHECK lines for the input and write them to the output.

    The input is read from the positional file argument (stdin by default).
    The result goes to stdout, to --output, or back into --source when
    --inplace is given.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate option combinations up front; unlike an assert this is not
    # stripped when running with -O.
    if args.inplace:
        if not args.source:
            parser.error("--inplace requires --source")
        if args.output is not None:
            parser.error("--inplace cannot be combined with --output")

    # Open the given input file.
    input_lines = [line.rstrip() for line in args.input]
    if args.input is not sys.stdin:
        args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    # The source has to be read completely before it is possibly reopened for
    # writing below.
    source_segments = None
    if args.source:
        with open(args.source, "r") as source_file:
            source_lines = [line.rstrip() for line in source_file]
        source_segments = process_source_lines(source_lines, autogenerated_note, args)

    if args.inplace:
        output = open(args.source, "w")
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output_segments = [[]]

        # Namers
        variable_namer = VariableNamer(args.variable_names)
        attribute_namer = AttributeNamer(args.attribute_names)

        # Process lines
        for input_line in input_lines:
            # Empty lines carry no checks; the indexing below requires at
            # least one character.
            if not input_line:
                continue

            # Check if this is an attribute definition and process it
            process_attribute_definition(input_line, attribute_namer, output)

            # Lines with blocks begin with a ^. These lines have a trailing comment
            # that needs to be stripped.
            lstripped_input_line = input_line.lstrip()
            is_block = lstripped_input_line[0] == "^"
            if is_block:
                input_line = input_line.rsplit("//", 1)[0].rstrip()

            cur_level = variable_namer.num_scopes()

            # If the line starts with a '}', pop the last name scope.
            if lstripped_input_line[0] == "}":
                variable_namer.pop_name_scope()
                cur_level = variable_namer.num_scopes()

            # If the line ends with a '{', push a new name scope.
            if input_line[-1] == "{":
                variable_namer.push_name_scope()
                if cur_level == args.starts_from_scope:
                    output_segments.append([])

                # Result SSA values must still be pushed to parent scope
                num_ssa_results = get_num_ssa_results(input_line)
                variable_namer.generate_in_parent_scope(num_ssa_results)

            # Omit lines at the near top level e.g. "module {".
            if cur_level < args.starts_from_scope:
                continue

            if len(output_segments[-1]) == 0:
                variable_namer.clear_names()

            # Preprocess the input to remove any sequences that may be problematic with
            # FileCheck.
            input_line = preprocess_line(input_line)

            # Process uses of attributes in this line
            input_line = process_attribute_references(input_line, attribute_namer)

            # Split the line at each SSA value name.
            ssa_split = input_line.split("%")

            # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
            if len(output_segments[-1]) != 0 or not ssa_split[0]:
                output_line = "// " + args.check_prefix + ": "
                # Pad to align with the 'LABEL' statements.
                output_line += " " * len("-LABEL")

                # Output the first line chunk that does not contain an SSA name.
                output_line += ssa_split[0]

                # Process the rest of the input line.
                output_line += process_line(ssa_split[1:], variable_namer)
            else:
                # Output the first line chunk that does not contain an SSA name for the
                # label.
                output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"

                # Process the rest of the input line on separate check lines.
                for argument in ssa_split[1:]:
                    output_line += "// " + args.check_prefix + "-SAME: "

                    # Pad to align with the original position in the line.
                    output_line += " " * len(ssa_split[0])

                    # Process the rest of the line.
                    output_line += process_line([argument], variable_namer)

            # Append the output line.
            output_segments[-1].append(output_line)

        output.write(autogenerated_note + "\n")

        # Write the output.
        if source_segments:
            if len(output_segments) != len(source_segments):
                raise RuntimeError(
                    "number of generated check segments ("
                    + str(len(output_segments))
                    + ") does not match the number of source segments ("
                    + str(len(source_segments))
                    + "); check --source_delim_regex"
                )
            for check_segment, source_segment in zip(output_segments, source_segments):
                for line in check_segment:
                    output.write(line)
                for line in source_segment:
                    output.write(line)
        else:
            for segment in output_segments:
                # Separate the segments with a blank line.
                output.write("\n")
                for output_line in segment:
                    output.write(output_line)
            output.write("\n")
    finally:
        # Only close files this script owns; stdout is left to the interpreter.
        if output is not sys.stdout:
            output.close()


if __name__ == "__main__":
    main()