2 """A script to generate FileCheck statements for mlir unit tests.
4 This script is a utility to add FileCheck patterns to an mlir file.
6 NOTE: The input .mlir is expected to be the output from the parser, not a
10 $ generate-test-checks.py foo.mlir
11 $ mlir-opt foo.mlir -transformation | generate-test-checks.py
12 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
13 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
14 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
16 The script will heuristically generate CHECK/CHECK-LABEL commands for each line
17 within the file. By default this script will also try to insert string
18 substitution blocks for all SSA value names. If --source file is specified, the
19 script will attempt to insert the generated CHECKs to the source file by looking
20 for line positions matched by --source_delim_regex.
22 The script is designed to make adding checks to a test case fast, it is *not*
23 designed to be authoritative about what constitutes a good test!
26 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27 # See https://llvm.org/LICENSE.txt for license information.
28 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
31 import os
# Used to advertise this file's name ("autogenerated_note").
# ADVERT_BEGIN starts the note written at the top of the generated checks;
# the script later appends 'utils/', its own file name, a newline and
# ADVERT_END to it.
# NOTE(review): the `ADVERT_END = ...` assignment is not visible here, even
# though ADVERT_END is used when building the note. The `//` lines below look
# like the body of that (triple-quoted) string literal — confirm.
35 ADVERT_BEGIN
= '// NOTE: Assertions have been autogenerated by '
37 // The script is designed to make adding checks to
38 // a test case fast, it is *not* designed to be authoritative
39 // about what constitutes a good test! The CHECK should be
40 // minimized and named to reflect the test intent.
44 # Regex command to match an SSA identifier.
# The pattern accepts either a run of decimal digits (e.g. the `0` of `%0`)
# or a name that starts with a letter or one of `$._-` and continues with
# letters, digits or `$._-` (e.g. the `arg0` of `%arg0`). It is applied with
# `SSA_RE.match` to the text that follows each '%', so it only has to match
# a prefix of that text.
# NOTE(review): `re` is not imported in the visible part of this file.
45 SSA_RE_STR
= '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
46 SSA_RE
= re
.compile(SSA_RE_STR
)
49 # Class used to generate and manage string substitution blocks for SSA value
# `scopes` is used as a stack of dicts, each mapping an SSA name to the
# FileCheck variable generated for it; `name_counter` numbers the generated
# variables.
# NOTE(review): no `__init__` is visible, so the initial values of
# `self.scopes` and `self.name_counter` cannot be confirmed from here.
51 class SSAVariableNamer
:
57 # Generate a substitution name for the given ssa value name.
# The name is 'VAL_' followed by the current counter value; the counter is
# then incremented and the mapping is recorded in the innermost (last) scope.
58 def generate_name(self
, ssa_name
):
59 variable
= 'VAL_' + str(self
.name_counter
)
60 self
.name_counter
+= 1
61 self
.scopes
[-1][ssa_name
] = variable
# NOTE(review): no `return variable` is visible, yet process_line uses the
# result of generate_name as a string; the return appears to be missing.
64 # Push a new variable name scope.
65 def push_name_scope(self
):
66 self
.scopes
.append({})
68 # Pop the last variable name scope.
69 def pop_name_scope(self
):
# NOTE(review): the body of pop_name_scope is not visible here.
72 # Return the level of nesting (number of pushed scopes).
# NOTE(review): the `def` line for this method is not visible; callers
# invoke it as `num_scopes()`.
74 return len(self
.scopes
)
77 def clear_counter(self
):
# NOTE(review): the body of clear_counter is not visible. The main loop
# calls it whenever the current output segment is still empty, presumably
# to reset `name_counter` — confirm.
81 # Process a line of input that has been split at each SSA identifier '%'.
# `line_chunks` are the pieces that followed each '%', so each is expected to
# begin with an SSA name. That name is rewritten as a FileCheck variable:
# `%[[NAME]]` when some scope already maps it, otherwise a new `%[[NAME:.*]]`
# definition obtained from `variable_namer.generate_name`. The remainder of
# the chunk is copied verbatim, and the result is right-stripped and
# terminated with a single '\n'.
82 def process_line(line_chunks
, variable_namer
):
# NOTE(review): `output_line` is appended to below, but its initial
# assignment is not visible here.
85 # Process the rest that contained an SSA value name.
86 for chunk
in line_chunks
:
87 m
= SSA_RE
.match(chunk
)
# NOTE(review): `ssa_name` is used below, but its assignment (presumably
# from `m`) is not visible. `m` is None when a chunk does not start with an
# SSA name.
90 # Check if an existing variable exists for this name.
92 for scope
in variable_namer
.scopes
:
93 variable
= scope
.get(ssa_name
)
94 if variable
is not None:
# NOTE(review): the body of this `if` (presumably `break`) is not visible;
# without it the loop would keep only the last scope's lookup result.
97 # If one exists, then output the existing name.
98 if variable
is not None:
99 output_line
+= '%[[' + variable
+ ']]'
# NOTE(review): no `else:` is visible before the branch below; confirm that
# a new variable is generated only when no existing one was found.
101 # Otherwise, generate a new variable.
102 variable
= variable_namer
.generate_name(ssa_name
)
103 output_line
+= '%[[' + variable
+ ':.*]]'
105 # Append the non named group.
106 output_line
+= chunk
[len(ssa_name
):]
108 return output_line
.rstrip() + '\n'
111 # Process the source file lines. The source file doesn't have to be .mlir.
# Returns a list of segments, each a list of '\n'-terminated lines. A new
# segment starts at every line matching `args.source_delim_regex`, so the
# first segment holds everything before the first delimiter line.
112 def process_source_lines(source_lines
, note
, args
):
113 source_split_re
= re
.compile(args
.source_delim_regex
)
115 source_segments
= [[]]
116 for line
in source_lines
:
117 # Remove previous note.
# NOTE(review): the code that removes the previous note is not visible, and
# the `note` parameter is otherwise unused here.
120 # Remove previous CHECK lines.
121 if line
.find(args
.check_prefix
) != -1:
# NOTE(review): the body of this `if` is not visible; presumably it skips the
# line (`continue`) so that old CHECK lines are dropped — confirm. Any line
# containing the prefix as a substring matches, not only `// CHECK` lines.
123 # Segment the file based on --source_delim_regex.
124 if source_split_re
.search(line
):
125 source_segments
.append([])
127 source_segments
[-1].append(line
+ '\n')
128 return source_segments
131 # Pre-process a line of input to remove any character sequences that will be
132 # problematic with FileCheck.
# '[[' is rewritten as '{{\[\[}}' and '[%' as '{{\[}}%'. This wraps the
# brackets in FileCheck regex blocks, so they are matched literally instead
# of starting a `[[...]]` variable.
# NOTE(review): no `return output_line` is visible, although the main loop
# assigns the result of preprocess_line back to its input line.
133 def preprocess_line(line
):
134 # Replace any double brackets, '[[' with escaped replacements. '[['
135 # corresponds to variable names in FileCheck.
136 output_line
= line
.replace('[[', '{{\\[\\[}}')
138 # Replace any single brackets that are followed by an SSA identifier, the
139 # identifier will be replaced by a variable; Creating the same situation as
141 output_line
= output_line
.replace('[%', '{{\\[}}%')
# Command-line interface. The visible options are:
#   --check-prefix (default 'CHECK'), --source,
#   --source_delim_regex (default 'func @'),
#   --starts_from_scope (default 1) and -i/--inplace.
# The `FileType('w')` and `FileType('r')` arguments are presumably the output
# and the input that are later read as `args.output` and `args.input`; their
# add_argument lines are only partially visible.
# NOTE(review): `argparse` is not imported in the visible part of this file.
# NOTE(review): the --source help text misspells "delimiter", and its two
# implicitly concatenated literals produce "sourcefile" (no space).
147 parser
= argparse
.ArgumentParser(
148 description
=__doc__
, formatter_class
=argparse
.RawTextHelpFormatter
)
150 '--check-prefix', default
='CHECK', help='Prefix to use from check file.')
155 type=argparse
.FileType('w'),
160 type=argparse
.FileType('r'),
163 '--source', type=str,
164 help='Print each CHECK chunk before each delimeter line in the source'
165 'file, respectively. The delimeter lines are identified by '
166 '--source_delim_regex.')
167 parser
.add_argument('--source_delim_regex', type=str, default
='func @')
169 '--starts_from_scope', type=int, default
=1,
170 help='Omit the top specified level of content. For example, by default '
171 'it omits "module {"')
172 parser
.add_argument('-i', '--inplace', action
='store_true', default
=False)
174 args
= parser
.parse_args()
# Read the input, build the autogenerated note and, when a source file is
# involved, split it into segments and choose where the output goes.
# NOTE(review): several original lines are missing in this region: the guard
# around process_source_lines, its remaining arguments, and the `if` that
# precedes the `elif` below. The exact control flow cannot be confirmed here.
176 # Open the given input file.
177 input_lines
= [l
.rstrip() for l
in args
.input]
180 # Generate a note used for the generated check file.
181 script_name
= os
.path
.basename(__file__
)
182 autogenerated_note
= (ADVERT_BEGIN
+ 'utils/' + script_name
+ "\n" + ADVERT_END
)
184 source_segments
= None
# NOTE(review): the `open(args.source, 'r')` handle below is never explicitly
# closed in the visible code; a `with` block would close it deterministically.
186 source_segments
= process_source_lines(
187 [l
.rstrip() for l
in open(args
.source
, 'r')],
# NOTE(review): validating options with `assert` is skipped under
# `python -O`; `parser.error(...)` would report this reliably.
193 assert args
.output
is None
194 output
= open(args
.source
, 'w')
195 elif args
.output
is None:
# Convert each input line into CHECK lines, grouped into output segments.
# - A name scope is pushed for each line ending in '{' and popped for each
#   line starting with '}'.
# - A new segment starts at a line ending in '{' that is seen at nesting
#   level `args.starts_from_scope`.
# - The SSA name counter is cleared when a line is about to be added to an
#   empty segment.
# - The first line of a segment (if its text before the first '%' is
#   non-empty) becomes a CHECK-LABEL line, with each SSA-bearing chunk on its
#   own CHECK-SAME line.
# - Every other line becomes a CHECK line, padded to line up with the labels.
200 output_segments
= [[]]
201 # A map containing data used for naming SSA value names.
202 variable_namer
= SSAVariableNamer()
203 for input_line
in input_lines
:
206 lstripped_input_line
= input_line
.lstrip()
208 # Lines with blocks begin with a ^. These lines have a trailing comment
209 # that needs to be stripped.
# NOTE(review): `lstripped_input_line[0]` raises IndexError on a blank line
# unless blank lines are skipped earlier (no such guard is visible).
210 is_block
= lstripped_input_line
[0] == '^'
# NOTE(review): the comment above implies this stripping is guarded by
# `is_block`, but no such `if` is visible. As shown, everything after the
# last '//' is dropped from every line, and `is_block` is never used.
212 input_line
= input_line
.rsplit('//', 1)[0].rstrip()
214 cur_level
= variable_namer
.num_scopes()
216 # If the line starts with a '}', pop the last name scope.
217 if lstripped_input_line
[0] == '}':
218 variable_namer
.pop_name_scope()
219 cur_level
= variable_namer
.num_scopes()
221 # If the line ends with a '{', push a new name scope.
# NOTE(review): `input_line[-1]` raises IndexError if the stripping above
# left the line empty (e.g. a line that is only a `//` comment).
222 if input_line
[-1] == '{':
223 variable_namer
.push_name_scope()
224 if cur_level
== args
.starts_from_scope
:
225 output_segments
.append([])
227 # Omit lines at the near top level e.g. "module {".
228 if cur_level
< args
.starts_from_scope
:
# NOTE(review): the body of this `if` is not visible; presumably it skips the
# line (`continue`) so that top-level lines are omitted — confirm.
231 if len(output_segments
[-1]) == 0:
232 variable_namer
.clear_counter()
234 # Preprocess the input to remove any sequences that may be problematic with
236 input_line
= preprocess_line(input_line
)
238 # Split the line at each SSA value name.
239 ssa_split
= input_line
.split('%')
241 # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
242 if len(output_segments
[-1]) != 0 or not ssa_split
[0]:
243 output_line
= '// ' + args
.check_prefix
+ ': '
244 # Pad to align with the 'LABEL' statements.
245 output_line
+= (' ' * len('-LABEL'))
247 # Output the first line chunk that does not contain an SSA name.
248 output_line
+= ssa_split
[0]
250 # Process the rest of the input line.
251 output_line
+= process_line(ssa_split
[1:], variable_namer
)
# NOTE(review): the `else:` that separates the CHECK branch above from the
# CHECK-LABEL branch below is not visible.
254 # Output the first line chunk that does not contain an SSA name for the
256 output_line
= '// ' + args
.check_prefix
+ '-LABEL: ' + ssa_split
[0] + '\n'
258 # Process the rest of the input line on separate check lines.
259 for argument
in ssa_split
[1:]:
260 output_line
+= '// ' + args
.check_prefix
+ '-SAME: '
262 # Pad to align with the original position in the line.
263 output_line
+= ' ' * len(ssa_split
[0])
265 # Process the rest of the line.
266 output_line
+= process_line([argument
], variable_namer
)
268 # Append the output line.
269 output_segments
[-1].append(output_line
)
# Emit the results. The autogenerated note is written first. Then, when the
# source file was split into segments, each CHECK segment is followed by its
# matching source segment; otherwise only the CHECK segments are written.
# NOTE(review): the branch conditions and the bodies of the two inner
# `for line` loops are not visible here, so the exact interleaving cannot be
# confirmed.
271 output
.write(autogenerated_note
+ '\n')
# NOTE(review): this length check uses `assert`, which is skipped under
# `python -O`; zip() would then silently drop unmatched segments.
275 assert len(output_segments
) == len(source_segments
)
276 for check_segment
, source_segment
in zip(output_segments
, source_segments
):
277 for line
in check_segment
:
279 for line
in source_segment
:
282 for segment
in output_segments
:
284 for output_line
in segment
:
285 output
.write(output_line
)
# NOTE(review): the body of this guard (presumably a call to the script's
# entry point) is not visible; as shown, the `if` has no body.
290 if __name__
== '__main__':