[AMDGPU] Make v8i16/v8f16 legal
[llvm-project.git] / mlir / utils / generate-test-checks.py
blob474f812c9c0bc181f8575f87de0654af679415d0
1 #!/usr/bin/env python3
2 """A script to generate FileCheck statements for mlir unit tests.
4 This script is a utility to add FileCheck patterns to an mlir file.
6 NOTE: The input .mlir is expected to be the output from the parser, not a
7 stripped down variant.
9 Example usage:
10 $ generate-test-checks.py foo.mlir
11 $ mlir-opt foo.mlir -transformation | generate-test-checks.py
12 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
13 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
14 $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
16 The script will heuristically generate CHECK/CHECK-LABEL commands for each line
17 within the file. By default this script will also try to insert string
18 substitution blocks for all SSA value names. If --source file is specified, the
19 script will attempt to insert the generated CHECKs to the source file by looking
20 for line positions matched by --source_delim_regex.
22 The script is designed to make adding checks to a test case fast, it is *not*
23 designed to be authoritative about what constitutes a good test!
24 """
26 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27 # See https://llvm.org/LICENSE.txt for license information.
28 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
30 import argparse
31 import os # Used to advertise this file's name ("autogenerated_note").
32 import re
33 import sys
# Banner written at the top of every generated output. ADVERT_BEGIN is
# followed by the script's path, then ADVERT_END (which starts with a newline,
# producing a blank line after the NOTE line).
ADVERT_BEGIN = '// NOTE: Assertions have been autogenerated by '
ADVERT_END = (
    '\n'
    '// The script is designed to make adding checks to\n'
    '// a test case fast, it is *not* designed to be authoritative\n'
    '// about what constitutes a good test! The CHECK should be\n'
    '// minimized and named to reflect the test intent.\n'
)
# Pattern for the SSA value name that follows a '%': either a run of digits
# (e.g. %0) or an identifier made of letters, digits and '$', '.', '_', '-'
# that does not start with a digit (e.g. %arg0).
SSA_RE_STR = '|'.join(('[0-9]+', '[a-zA-Z$._-][a-zA-Z0-9$._-]*'))
SSA_RE = re.compile(SSA_RE_STR)
# Class used to generate and manage string substitution blocks for SSA value
# names.
class SSAVariableNamer:
    """Hand out FileCheck substitution names (VAL_0, VAL_1, ...) for SSA values.

    `scopes` is a stack (list) of dicts mapping an SSA value name to the
    substitution variable assigned to it; the innermost scope is last.
    `name_counter` is the numeric suffix of the next generated name.
    """

    def __init__(self):
        self.scopes = []
        self.name_counter = 0

    def generate_name(self, ssa_name):
        """Bind the next VAL_<n> name to `ssa_name` in the innermost scope.

        Returns the new variable name. Raises IndexError if no scope has been
        pushed (the counter has already been advanced at that point).
        """
        index = self.name_counter
        self.name_counter = index + 1
        variable = 'VAL_{}'.format(index)
        innermost = self.scopes[-1]
        innermost[ssa_name] = variable
        return variable

    def push_name_scope(self):
        """Open a new, empty innermost name scope."""
        self.scopes.append(dict())

    def pop_name_scope(self):
        """Discard the innermost name scope (IndexError if none is open)."""
        del self.scopes[-1]

    def num_scopes(self):
        """Return the current nesting depth, i.e. how many scopes are open."""
        return len(self.scopes)

    def clear_counter(self):
        """Restart generated names from VAL_0."""
        self.name_counter = 0
# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
    """Rewrite the SSA value references of one IR line as FileCheck variables.

    Args:
      line_chunks: the pieces of a line that followed each '%' after the line
        was split on '%'. Each piece normally starts with an SSA value name.
      variable_namer: the SSAVariableNamer holding the currently open scopes.

    Returns:
      The rewritten pieces concatenated, right-stripped and terminated by a
      single '\\n'. The first occurrence of an SSA name becomes a capturing
      `%[[VAL_n:.*]]`; later occurrences become the use `%[[VAL_n]]`.
    """
    pieces = []
    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an SSA name (e.g. a '%' inside a
            # string attribute). Previously m.group(0) crashed here; emit the
            # text verbatim instead, since '%' is literal to FileCheck.
            pieces.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # Already captured: reference the existing FileCheck variable.
            pieces.append('%[[' + variable + ']]')
        else:
            # First sighting: define a new capturing variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append('%[[' + variable + ':.*]]')

        # Append the text following the SSA name unchanged.
        pieces.append(chunk[len(ssa_name):])

    return ''.join(pieces).rstrip() + '\n'
# Process the source file lines. The source file doesn't have to be .mlir.
def process_source_lines(source_lines, note, args):
    """Strip previously generated content from the source file and segment it.

    Args:
      source_lines: lines of the --source file, already right-stripped.
      note: the autogenerated advert this script writes at the top of its
        output. Any source line equal to one of the note's non-empty lines is
        dropped, together with blank lines directly following such a line.
      args: parsed arguments; `source_delim_regex` and `check_prefix` are read.

    Returns:
      A list of segments, each a list of newline-terminated lines. Segment 0
      holds the lines before the first delimiter; every line matching
      --source_delim_regex starts a new segment.
    """
    source_split_re = re.compile(args.source_delim_regex)
    # Only FileCheck directives for the prefix (PREFIX: or PREFIX-SUFFIX:, e.g.
    # CHECK-LABEL:) are stale output. A plain substring test would also delete
    # RUN lines such as "FileCheck %s --check-prefix=CHECK-FOO".
    check_directive_re = re.compile(
        r'(?<![\w-])' + re.escape(args.check_prefix) + r'(?:-[\w-]+)?:')
    # `note` spans several lines, so it has to be matched line by line; a
    # whole-line comparison against it can never succeed.
    note_lines = {l for l in note.split('\n') if l}

    source_segments = [[]]
    after_note_line = False
    for line in source_lines:
        # Remove the previous note, including the blank lines it introduced.
        if line in note_lines:
            after_note_line = True
            continue
        if after_note_line and not line.strip():
            continue
        after_note_line = False
        # Remove previous CHECK lines.
        if check_directive_re.search(line):
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + '\n')
    return source_segments
# Pre-process a line of input to remove any character sequences that will be
# problematic with FileCheck.
def preprocess_line(line):
    """Escape bracket sequences FileCheck would misread as variable syntax.

    A literal '[[' would open a FileCheck variable, so it is replaced by the
    regex block '{{\\[\\[}}'. A single '[' directly before '%' is escaped too,
    because the following SSA name is later rewritten to '%[[VAL_n...]]',
    which would otherwise produce the same '[[' problem. The replacements are
    applied in that order.
    """
    escapes = (
        ('[[', '{{\\[\\[}}'),
        ('[%', '{{\\[}}%'),
    )
    for needle, replacement in escapes:
        line = line.replace(needle, replacement)
    return line
def _parse_args():
    """Build the command-line interface and return validated arguments."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK', help='Prefix to use from check file.')
    parser.add_argument(
        '-o',
        '--output',
        nargs='?',
        type=argparse.FileType('w'),
        default=None)
    parser.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('r'),
        default=sys.stdin)
    parser.add_argument(
        '--source', type=str,
        help='Print each CHECK chunk before each delimiter line in the source '
        'file, respectively. The delimiter lines are identified by '
        '--source_delim_regex.')
    parser.add_argument('--source_delim_regex', type=str, default='func @')
    parser.add_argument(
        '--starts_from_scope', type=int, default=1,
        help='Omit the top specified level of content. For example, by default '
        'it omits "module {"')
    parser.add_argument('-i', '--inplace', action='store_true', default=False)

    args = parser.parse_args()
    # Validate option combinations up front with a proper usage error instead
    # of crashing later (open(None)) or relying on `assert`.
    if args.inplace:
        if not args.source:
            parser.error('-i/--inplace requires --source')
        if args.output is not None:
            parser.error('-i/--inplace cannot be combined with -o/--output')
    return args


def _generate_check_segments(input_lines, args):
    """Convert IR lines into FileCheck lines, grouped into segments.

    A new segment starts whenever a line ending in '{' is opened at nesting
    depth `args.starts_from_scope`; lines shallower than that depth are
    omitted. The first line of a segment becomes a CHECK-LABEL (with each SSA
    reference on its own CHECK-SAME line) unless it starts with '%'.
    """
    output_segments = [[]]
    # A map containing data used for naming SSA value names.
    variable_namer = SSAVariableNamer()
    for input_line in input_lines:
        if not input_line:
            continue
        lstripped_input_line = input_line.lstrip()

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        is_block = lstripped_input_line[0] == '^'
        if is_block:
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == '}':
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == '{':
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        # Each segment numbers its variables from VAL_0.
        if len(output_segments[-1]) == 0:
            variable_namer.clear_counter()

        # Preprocess the input to remove any sequences that may be problematic
        # with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at each SSA value name.
        ssa_split = input_line.split('%')

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = '// ' + args.check_prefix + ': '
            # Pad to align with the 'LABEL' statements.
            output_line += (' ' * len('-LABEL'))

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            # Output the first line chunk that does not contain an SSA name
            # for the label.
            output_line = ('// ' + args.check_prefix + '-LABEL: ' +
                           ssa_split[0] + '\n')

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += '// ' + args.check_prefix + '-SAME: '

                # Pad to align with the original position in the line.
                output_line += ' ' * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line([argument], variable_namer)

        output_segments[-1].append(output_line)
    return output_segments


def _write_output(output, autogenerated_note, output_segments, source_segments):
    """Write the note, then either interleave checks with source or list checks."""
    output.write(autogenerated_note + '\n')
    if source_segments is not None:
        # Each CHECK segment is emitted right before its source segment.
        for check_segment, source_segment in zip(output_segments,
                                                 source_segments):
            output.writelines(check_segment)
            output.writelines(source_segment)
    else:
        for segment in output_segments:
            output.write('\n')
            output.writelines(segment)
        output.write('\n')


def main():
    """Generate FileCheck lines for the input IR and write them out.

    Output goes to --source (with -i), to -o/--output, or to stdout. When
    --source is given, the generated CHECK segments are interleaved with the
    source file's segments. Exits with an error if the segment counts differ.
    """
    args = _parse_args()

    # Read the whole input up front.
    input_lines = [l.rstrip() for l in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = (ADVERT_BEGIN + 'utils/' + script_name + "\n" +
                          ADVERT_END)

    source_segments = None
    if args.source:
        with open(args.source, 'r') as source_file:
            source_segments = process_source_lines(
                [l.rstrip() for l in source_file], autogenerated_note, args)

    output_segments = _generate_check_segments(input_lines, args)

    # Validate before opening the output: with -i the output *is* the source
    # file, and opening it for writing truncates it.
    if (source_segments is not None and
            len(output_segments) != len(source_segments)):
        sys.exit('error: generated %d CHECK segments but found %d source '
                 'segments; check --source_delim_regex and --starts_from_scope'
                 % (len(output_segments), len(source_segments)))

    if args.inplace:
        output = open(args.source, 'w')
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        _write_output(output, autogenerated_note, output_segments,
                      source_segments)
    finally:
        output.close()
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()