"""A script to generate FileCheck statements for mlir unit tests.

This script is a utility to add FileCheck patterns to an mlir file.

NOTE: The input .mlir is expected to be the output from the parser, not a
stripped down variant.

Example usage:
$ generate-test-checks.py foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'

The script will heuristically generate CHECK/CHECK-LABEL commands for each line
within the file. By default this script will also try to insert string
substitution blocks for all SSA value names. If --source file is specified, the
script will attempt to insert the generated CHECKs to the source file by looking
for line positions matched by --source_delim_regex.

The script is designed to make adding checks to a test case fast, it is *not*
designed to be authoritative about what constitutes a good test!
"""

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import argparse
import os
import re
import sys
# Prefix of the note that advertises this file's name (see the
# "autogenerated_note" built by the script entry point).
ADVERT = '// NOTE: Assertions have been autogenerated by '

# Regex command to match an SSA identifier: either a purely numeric name or a
# name starting with a letter or one of `$._-`. The pattern is applied to the
# text that follows a '%', so it does not include the '%' itself.
SSA_RE_STR = r'[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)
# Class used to generate and manage string substitution blocks for SSA value
# names.
class SSAVariableNamer:
    """Hands out FileCheck variable names ('VAL_<n>') for SSA value names.

    Names are recorded per scope: `scopes` is a stack of dicts mapping an SSA
    name to the variable generated for it, and `name_counter` is the index
    used for the next generated variable.
    """

    def __init__(self):
        # NOTE(review): the constructor is not visible in the source this was
        # rebuilt from; an empty scope stack and a zero counter are assumed
        # from how the attributes are used by the methods below -- confirm.
        self.scopes = []
        self.name_counter = 0

    # Generate a substitution name for the given ssa value name.
    def generate_name(self, ssa_name):
        """Create, record in the innermost scope, and return a new variable.

        Requires at least one pushed scope (indexes `self.scopes[-1]`).
        """
        variable = 'VAL_' + str(self.name_counter)
        self.name_counter += 1
        self.scopes[-1][ssa_name] = variable
        return variable

    # Push a new variable name scope.
    def push_name_scope(self):
        """Open a new, empty innermost scope."""
        self.scopes.append({})

    # Pop the last variable name scope.
    def pop_name_scope(self):
        """Discard the innermost scope and the names recorded in it."""
        self.scopes.pop()

    # Return the level of nesting (number of pushed scopes).
    def num_scopes(self):
        """Return how many scopes are currently open."""
        return len(self.scopes)

    def clear_counter(self):
        """Restart variable numbering so the next name is 'VAL_0'."""
        self.name_counter = 0
# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
    """Rewrite SSA value names in `line_chunks` as FileCheck variables.

    Args:
      line_chunks: the pieces of a line that followed each '%' (the text
        before the first '%' is handled by the caller).
      variable_namer: object exposing `scopes` (iterable of dicts mapping SSA
        names to variable names) and `generate_name(ssa_name)`.

    Returns:
      The rewritten text, right-stripped and terminated by a newline. A known
      name becomes '%[[VAR]]'; a new name becomes the definition
      '%[[VAR:.*]]'.
    """
    pieces = []

    # Process the rest that contained an SSA value name.
    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an identifier (e.g. a literal '%'
            # inside a string attribute); keep the text as it was instead of
            # failing on `m.group`.
            pieces.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # If one exists, then output the existing name.
            pieces.append('%[[' + variable + ']]')
        else:
            # Otherwise, generate a new variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append('%[[' + variable + ':.*]]')

        # Append the non named group.
        pieces.append(chunk[len(ssa_name):])

    return ''.join(pieces).rstrip() + '\n'
# Process the source file lines. The source file doesn't have to be .mlir.
def process_source_lines(source_lines, note, args):
    """Split the source file into segments at each delimiter line.

    Args:
      source_lines: lines of the source file, without trailing newlines.
      note: the autogenerated note line; occurrences are dropped so the note
        is not duplicated when the file is regenerated.
      args: object with `source_delim_regex` (pattern marking the start of a
        new segment) and `check_prefix` (any line containing it is dropped).

    Returns:
      A list of segments, each a list of newline-terminated lines. The first
      segment holds everything before the first delimiter line and may be
      empty; every later segment starts with its delimiter line.
    """
    source_split_re = re.compile(args.source_delim_regex)

    source_segments = [[]]
    for line in source_lines:
        # Remove previous note.
        # NOTE(review): the exact comparison is not visible in the source this
        # was rebuilt from; whole-line equality with `note` is assumed.
        if line == note:
            continue
        # Remove previous CHECK lines.
        if args.check_prefix in line:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + '\n')
    return source_segments
# Pre-process a line of input to remove any character sequences that will be
# problematic with FileCheck.
def preprocess_line(line):
    """Return `line` with bracket sequences FileCheck would misread escaped.

    The escapes wrap the literal brackets in a FileCheck regex block
    ('{{...}}') so they are matched verbatim.
    """
    # Replace any double brackets, '[[' with escaped replacements. '[['
    # corresponds to variable names in FileCheck.
    output_line = line.replace('[[', '{{\\[\\[}}')

    # Replace any single brackets that are followed by an SSA identifier, the
    # identifier will be replaced by a variable; creating the same situation as
    # above.
    output_line = output_line.replace('[%', '{{\\[}}%')

    return output_line
def main():
    """Command-line entry point: turn MLIR text into FileCheck lines.

    Reads the input, generates one CHECK / CHECK-LABEL / CHECK-SAME entry per
    retained line (SSA names replaced by FileCheck variables), and writes the
    result either on its own or interleaved with the segments of --source.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK', help='Prefix to use from check file.')
    # NOTE(review): the flag spellings and defaults of the next two arguments
    # are not visible in the source this was rebuilt from; they are assumed
    # from the `args.output` / `args.input` uses below and from the usage
    # examples in the module docstring -- confirm.
    parser.add_argument(
        '-o', '--output', nargs='?', type=argparse.FileType('w'), default=None)
    parser.add_argument(
        'input', nargs='?', type=argparse.FileType('r'), default=sys.stdin)
    parser.add_argument(
        '--source', type=str,
        help='Print each CHECK chunk before each delimeter line in the source '
        'file, respectively. The delimeter lines are identified by '
        '--source_delim_regex.')
    parser.add_argument('--source_delim_regex', type=str, default='func @')
    parser.add_argument(
        '--starts_from_scope', type=int, default=1,
        help='Omit the top specified level of content. For example, by default '
        'it omits "module {"')
    parser.add_argument('-i', '--inplace', action='store_true', default=False)

    args = parser.parse_args()

    # --inplace rewrites the --source file, so it needs one and cannot also be
    # given a separate output. Report this as a usage error rather than an
    # assertion (which is stripped under -O) or a crash in open(None, 'w').
    if args.inplace:
        if not args.source:
            parser.error('--inplace requires --source')
        if args.output is not None:
            parser.error('--inplace cannot be combined with --output')

    # Open the given input file.
    input_lines = [l.rstrip() for l in args.input]
    if args.input is not sys.stdin:
        args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = (ADVERT + 'utils/' + script_name)

    # Read the source completely (and close it) before any output is opened:
    # with --inplace the output is this same file.
    source_segments = None
    if args.source:
        with open(args.source, 'r') as source_file:
            source_segments = process_source_lines(
                [l.rstrip() for l in source_file], autogenerated_note, args)

    output_segments = [[]]
    # A map containing data used for naming SSA value names.
    variable_namer = SSAVariableNamer()
    for input_line in input_lines:
        # NOTE(review): skipping blank lines is assumed; the indexing of
        # character [0] / [-1] below requires a non-empty line.
        if not input_line:
            continue
        lstripped_input_line = input_line.lstrip()

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        is_block = lstripped_input_line[0] == '^'
        if is_block:
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == '}':
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == '{':
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        # Each segment numbers its variables from scratch.
        if not output_segments[-1]:
            variable_namer.clear_counter()

        # Preprocess the input to remove any sequences that may be problematic
        # with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at each SSA value name.
        ssa_split = input_line.split('%')

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if output_segments[-1] or not ssa_split[0]:
            output_line = '// ' + args.check_prefix + ': '
            # Pad to align with the 'LABEL' statements.
            output_line += (' ' * len('-LABEL'))

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            # Output the first line chunk that does not contain an SSA name for
            # the label.
            output_line = ('// ' + args.check_prefix + '-LABEL: ' +
                           ssa_split[0] + '\n')

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += '// ' + args.check_prefix + '-SAME: '

                # Pad to align with the original position in the line.
                output_line += ' ' * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line([argument], variable_namer)

        # Append the output line.
        output_segments[-1].append(output_line)

    # Validate before the output is opened: with --inplace, opening the source
    # for writing truncates it, so failing afterwards would destroy the file.
    if source_segments and len(output_segments) != len(source_segments):
        sys.exit('error: generated %d check segment(s) but found %d source '
                 'segment(s); check --source_delim_regex and '
                 '--starts_from_scope' %
                 (len(output_segments), len(source_segments)))

    if args.inplace:
        output = open(args.source, 'w')
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output.write(autogenerated_note + '\n')

        # Write the output.
        if source_segments:
            # Each group of checks goes right before its source segment.
            for check_segment, source_segment in zip(output_segments,
                                                     source_segments):
                output.writelines(check_segment)
                output.writelines(source_segment)
        else:
            # NOTE(review): the blank-line separators written around segments
            # are assumed; those statements are not visible in the source this
            # was rebuilt from.
            for segment in output_segments:
                output.write('\n')
                output.writelines(segment)
            output.write('\n')
    finally:
        if output is not sys.stdout:
            output.close()


if __name__ == '__main__':
    main()