[RISCV] Refactor predicates for rvv intrinsic patterns.
[llvm-project.git] / llvm / utils / update_mir_test_checks.py
blob6e3a5e9732761121328c9a2fcaceec3d24f6eecb
1 #!/usr/bin/env python3
3 """Updates FileCheck checks in MIR tests.
5 This script is a utility to update MIR based tests with new FileCheck
6 patterns.
8 The checks added by this script will cover the entire body of each
9 function it handles. Virtual registers used are given names via
10 FileCheck patterns, so if you do want to check a subset of the body it
11 should be straightforward to trim out the irrelevant parts. None of
12 the YAML metadata will be checked, other than function names, and fixedStack
13 if the --print-fixed-stack option is used.
15 If there are multiple llc commands in a test, the full set of checks
16 will be repeated for each different check pattern. Checks for patterns
17 that are common between different commands will be left as-is by
18 default, or removed if the --remove-common-prefixes flag is provided.
19 """
21 from __future__ import print_function
23 import argparse
24 import collections
25 import glob
26 import os
27 import re
28 import subprocess
29 import sys
31 from UpdateTestChecks import common
# Matches the "name: <func>" key of a MIR function document.
MIR_FUNC_NAME_RE = re.compile(r' *name: *(?P<func>[A-Za-z0-9_.-]+)')
# Matches the "body: |" key that opens the machine-instruction block.
MIR_BODY_BEGIN_RE = re.compile(r' *body: *\|')
# Matches a basic block label line such as "  bb.0.entry:".
MIR_BASIC_BLOCK_RE = re.compile(r' *bb\.[0-9]+.*:$')
# Matches a virtual register (captured as group 1, e.g. "%0"), optionally
# followed by a ":class" suffix and a parenthesised type such as "(s32)".
VREG_RE = re.compile(r'(%[0-9]+)(?::[a-z0-9_]+)?(?:\([<>a-z0-9 ]+\))?')
# Zero or more machine-instruction flags that may precede an opcode.
MI_FLAGS_STR = (
    r'(frame-setup |frame-destroy |nnan |ninf |nsz |arcp |contract |afn '
    r'|reassoc |nuw |nsw |exact |nofpexcept |nomerge )*')
# Zero or more "dead " markers that may precede a defined register.
VREG_DEF_FLAGS_STR = r'(?:dead )*'
# Matches an instruction that defines one or more virtual registers:
# "<vregs> = <flags><opcode>".  The comma-separated definitions are captured
# as 'vregs' and the opcode as 'opcode'.
VREG_DEF_RE = re.compile(
    r'^ *(?P<vregs>{2}{0}(?:, {2}{0})*) = '
    r'{1}(?P<opcode>[A-Zt][A-Za-z0-9_]+)'.format(
        VREG_RE.pattern, MI_FLAGS_STR, VREG_DEF_FLAGS_STR))
# Lines that may sit between "body: |" and the first instruction of a
# single-block function: comments, a block label, a "key:" line such as
# "liveins:", or an empty line.  The dot in "bb." is escaped so that only
# real block labels are accepted by that alternative.
MIR_PREFIX_DATA_RE = re.compile(r'^ *(;|bb\.[0-9].*: *$|[a-z]+:( |$)|$)')

# Matches the "define ... @name(" line of an LLVM IR function.
IR_FUNC_NAME_RE = re.compile(
    r'^\s*define\s+(?:internal\s+)?[^@]*@(?P<func>[A-Za-z0-9_.]+)\s*\(')
# Lines that may precede the first IR instruction: comments or empty lines.
IR_PREFIX_DATA_RE = re.compile(r'^ *(;|$)')

# Matches one whole MIR function document, from the "---" separator to the
# closing "...".  Captures the function name ('func'), the text between the
# "fixedStack:" and "stack:" keys ('fixedStack') and the text of the
# "body: |" block ('body').
MIR_FUNC_RE = re.compile(
    r'^---$'
    r'\n'
    r'^ *name: *(?P<func>[A-Za-z0-9_.-]+)$'
    r'.*?'
    r'^ *fixedStack: *(\[\])? *\n'
    r'(?P<fixedStack>.*?)\n?'
    r'^ *stack:'
    r'.*?'
    r'^ *body: *\|\n'
    r'(?P<body>.*?)\n'
    r'^\.\.\.$',
    flags=(re.M | re.S))
class LLC:
    """Callable wrapper around the command used to invoke llc."""

    def __init__(self, bin):
        # Command (path or name) that is placed in front of the arguments.
        self.bin = bin

    def __call__(self, args, ir):
        """Run the tool through the shell and return its stdout as text.

        `args` is the argument string appended to the command; `ir` is the
        path of the file that is fed to the process on stdin.  Files whose
        name ends in '.mir' get '-x mir' appended to the arguments.
        """
        if ir.endswith('.mir'):
            args = '{} -x mir'.format(args)
        command = '{} {}'.format(self.bin, args)
        with open(ir) as ir_file:
            raw = subprocess.check_output(command, shell=True, stdin=ir_file)
        text = raw.decode() if sys.version_info[0] > 2 else raw
        # Normalize CRLF line endings to Unix-style LF.
        return text.replace('\r\n', '\n')
class Run:
    """The FileCheck prefixes, llc arguments and triple of one RUN line."""

    def __init__(self, prefixes, cmd_args, triple):
        self.prefixes, self.cmd_args, self.triple = prefixes, cmd_args, triple

    def __getitem__(self, index):
        # Indexing lets a Run be unpacked like a (prefixes, cmd_args, triple)
        # sequence.
        fields = [self.prefixes, self.cmd_args, self.triple]
        return fields[index]
def log(msg, verbose=True):
    """Print `msg` to stderr unless `verbose` is false."""
    if not verbose:
        return
    print(msg, file=sys.stderr)
def find_triple_in_ir(lines, verbose=False):
    """Return group 1 of the first line matching common.TRIPLE_IR_RE.

    Returns None when no line matches.  `verbose` is accepted but unused.
    """
    matches = (common.TRIPLE_IR_RE.match(line) for line in lines)
    return next((found.group(1) for found in matches if found), None)
def build_run_list(test, run_lines, verbose=False):
    """Turn the RUN lines of `test` into a list of Run objects.

    Only lines of the form "llc ... | FileCheck ..." are kept; every other
    line is skipped with a warning that names the test file.  The triple is
    taken from the -mtriple argument, falling back to "<march>--" when only
    -march is given, and is None when neither is present.

    The prefixes of each returned Run are sorted so that prefixes shared by
    more RUN lines come first.  `verbose` is accepted but unused.
    """
    run_list = []
    # How many times each check prefix appears across the accepted RUN lines.
    prefix_counts = collections.Counter()
    for l in run_lines:
        if '|' not in l:
            common.warn('Skipping unparsable RUN line: ' + l, test_file=test)
            continue

        # The check above guarantees exactly two parts after a single split.
        llc_cmd, filecheck_cmd = [cmd.strip() for cmd in l.split('|', 1)]
        common.verify_filecheck_prefixes(filecheck_cmd)

        if not llc_cmd.startswith('llc '):
            common.warn('Skipping non-llc RUN line: {}'.format(l),
                        test_file=test)
            continue
        if not filecheck_cmd.startswith('FileCheck '):
            common.warn('Skipping non-FileChecked RUN line: {}'.format(l),
                        test_file=test)
            continue

        triple = None
        m = common.TRIPLE_ARG_RE.search(llc_cmd)
        if m:
            triple = m.group(1)
        # If we find -march but not -mtriple, use that.
        m = common.MARCH_ARG_RE.search(llc_cmd)
        if m and not triple:
            triple = '{}--'.format(m.group(1))

        # Drop the leading "llc" and the input-file placeholders; the input
        # is supplied separately when the command is run.
        cmd_args = llc_cmd[len('llc'):].strip()
        cmd_args = cmd_args.replace('< %s', '').replace('%s', '').strip()
        check_prefixes = common.get_check_prefixes(filecheck_cmd)
        prefix_counts.update(check_prefixes)

        run_list.append(Run(check_prefixes, cmd_args, triple))

    # Sort prefixes that are shared between run lines before unshared prefixes.
    # This causes us to prefer printing shared prefixes.
    for run in run_list:
        run.prefixes.sort(key=lambda prefix: -prefix_counts[prefix])

    return run_list
def find_functions_with_one_bb(lines, verbose=False):
    """Return the names of functions in `lines` with exactly one basic block.

    A function starts at a line matching MIR_FUNC_NAME_RE and its blocks are
    the following lines matching MIR_BASIC_BLOCK_RE.  `verbose` is accepted
    but unused.
    """
    single_block_funcs = []
    current = None
    block_count = 0
    for line in lines:
        name_match = MIR_FUNC_NAME_RE.match(line)
        if name_match:
            # A new function begins: settle the one we were counting.
            if block_count == 1:
                single_block_funcs.append(current)
            current, block_count = name_match.group('func'), 0
        if MIR_BASIC_BLOCK_RE.match(line):
            block_count += 1
    # Settle the last function in the file.
    if block_count == 1:
        single_block_funcs.append(current)
    return single_block_funcs
class FunctionInfo:
    """The body and fixedStack text recorded for one function."""

    def __init__(self, body, fixedStack):
        self.body, self.fixedStack = body, fixedStack

    def __eq__(self, other):
        # Equal only to another FunctionInfo carrying the same two texts.
        if isinstance(other, FunctionInfo):
            same_body = self.body == other.body
            return same_body and self.fixedStack == other.fixedStack
        return False
def build_function_info_dictionary(test, raw_tool_output, triple, prefixes,
                                   func_dict, verbose):
    """Record the functions found in `raw_tool_output` under each prefix.

    For every MIR_FUNC_RE match the body has its virtual registers replaced
    by FileCheck variables, and the resulting FunctionInfo is stored in
    func_dict[prefix][func] for each prefix.  If a different FunctionInfo is
    already stored for that prefix and function, the entry is set to None.
    `test` and `triple` are accepted but unused.
    """
    for func_match in MIR_FUNC_RE.finditer(raw_tool_output):
        func = func_match.group('func')
        fixed_stack = func_match.group('fixedStack')
        body = func_match.group('body')
        if verbose:
            log('Processing function: {}'.format(func))
            for body_line in body.splitlines():
                log(' {}'.format(body_line))

        # Vreg mangling
        vreg_names = {}
        rewritten = []
        for line in body.splitlines(keepends=True):
            def_match = VREG_DEF_RE.match(line)
            if def_match:
                opcode = def_match.group('opcode')
                for vreg in VREG_RE.finditer(def_match.group('vregs')):
                    number = vreg.group(1)
                    name = mangle_vreg(opcode, vreg_names.values())
                    vreg_names[number] = name
                    # Only the first occurrence is the definition; it
                    # becomes a FileCheck variable definition.
                    line = line.replace(
                        number, '[[{}:%[0-9]+]]'.format(name), 1)
            # Every remaining use of a known vreg becomes a variable use.
            for number, name in vreg_names.items():
                line = re.sub(r'{}\b'.format(number), '[[{}]]'.format(name),
                              line)
            rewritten.append(line)
        body = ''.join(rewritten)

        for prefix in prefixes:
            info = FunctionInfo(body, fixed_stack)
            known = func_dict[prefix]
            if func not in known:
                known[func] = info
            elif known[func] != info:
                # Conflicting output for this prefix: mark it unusable.
                known[func] = None
def add_checks_for_function(test, output_lines, run_list, func_dict, func_name,
                            single_bb, args):
    """Append the check lines for `func_name` to `output_lines`.

    For each run, the first prefix with a usable entry in func_dict is
    printed via add_check_lines, unless a prefix that was already printed is
    reached first.  A run whose prefixes all lack a usable entry produces a
    warning.  Returns `output_lines`.
    """
    printed_prefixes = set()
    for run in run_list:
        for prefix in run.prefixes:
            if prefix in printed_prefixes:
                break
            info = func_dict[prefix][func_name]
            if not info:
                continue
            printed_prefixes.add(prefix)
            log('Adding {} lines for {}'.format(prefix, func_name),
                args.verbose)
            add_check_lines(test, output_lines, prefix, func_name, single_bb,
                            info, args)
            break
        else:
            # Reached only when the loop over this run's prefixes never
            # hit a `break`.
            message = 'Found conflicting asm for function: {}'.format(
                func_name)
            common.warn(message, test_file=test)
    return output_lines
def add_check_lines(test, output_lines, prefix, func_name, single_bb,
                    func_info: FunctionInfo, args):
    """Append FileCheck lines for one function and prefix to `output_lines`.

    Emits a "<prefix>-LABEL: name: <func>" line, optionally the fixedStack
    lines (when args.print_fixed_stack is set), and then one check per body
    line.  The check comments are indented to match the first body line.
    When the body has nothing to check, a warning is issued and nothing is
    appended.
    """
    func_body = func_info.body.splitlines()
    if single_bb and func_body:
        # Don't bother checking the basic block label for a single BB.  An
        # empty body is left alone so that it reaches the warning below
        # instead of raising IndexError.
        func_body.pop(0)

    if not func_body:
        common.warn('Function has no instructions to check: {}'.format(func_name),
                    test_file=test)
        return

    first_line = func_body[0]
    indent = len(first_line) - len(first_line.lstrip(' '))
    # A check comment, indented the appropriate amount
    check = '{}; {}'.format(' ' * indent, prefix)
    next_directive = check + '-NEXT'

    output_lines.append('{}-LABEL: name: {}'.format(check, func_name))

    if args.print_fixed_stack:
        output_lines.append('{}: fixedStack:'.format(check))
        for stack_line in func_info.fixedStack.splitlines():
            output_lines.append('{}: {}'.format(next_directive, stack_line))

    first_check = True
    for func_line in func_body:
        if not func_line.strip():
            # The mir printer prints leading whitespace so we can't use CHECK-EMPTY:
            output_lines.append(check + '-NEXT: {{' + func_line + '$}}')
            continue
        # The first non-blank line uses the plain prefix; later ones -NEXT.
        filecheck_directive = check if first_check else next_directive
        first_check = False
        check_line = '{}: {}'.format(filecheck_directive, func_line[indent:]).rstrip()
        output_lines.append(check_line)
# Short names used for some common opcodes with long-ish names.
_SHORT_VREG_NAMES = {
    'IMPLICIT_DEF': 'DEF',
    'GLOBAL_VALUE': 'GV',
    'CONSTANT': 'C',
    'FCONSTANT': 'C',
    'MERGE_VALUES': 'MV',
    'UNMERGE_VALUES': 'UV',
    'INTRINSIC': 'INT',
    'INTRINSIC_W_SIDE_EFFECTS': 'INT',
    'INSERT_VECTOR_ELT': 'IVEC',
    'EXTRACT_VECTOR_ELT': 'EVEC',
    'SHUFFLE_VECTOR': 'SHUF',
}


def mangle_vreg(opcode, current_names):
    """Return a FileCheck variable name for a vreg defined by `opcode`.

    The name is derived from the opcode with a leading 'G_' and a trailing
    '_PSEUDO' removed, and some long opcodes shortened.  `current_names` is
    an iterable of names already in use; when N of them share the same base
    (ignoring trailing digits), the result is the base followed by N.
    """
    base = opcode
    # Simplify some common prefixes and suffixes
    if opcode.startswith('G_'):
        base = base[len('G_'):]
    if base.endswith('_PSEUDO'):
        # Strip the suffix itself (not everything after the first 7 chars).
        base = base[:-len('_PSEUDO')]
    # Shorten some common opcodes with long-ish names
    base = _SHORT_VREG_NAMES.get(base, base)
    # Avoid ambiguity when opcodes end in numbers
    if len(base.rstrip('0123456789')) < len(base):
        base += '_'

    # Count the names that already use this base, with or without a
    # numeric suffix.
    i = sum(1 for name in current_names if name.rstrip('0123456789') == base)
    if i:
        return '{}{}'.format(base, i)
    return base
def should_add_line_to_output(input_line, prefix_set):
    """Return False for check lines whose prefix is in `prefix_set`.

    Those are the check lines we're handling; every other line is kept.
    """
    check_match = common.CHECK_RE.match(input_line)
    is_ours = bool(check_match) and check_match.group(1) in prefix_set
    return not is_ours
def update_test_file(args, test, autogenerated_note):
    """Regenerate the FileCheck lines of the test file `test` in place.

    Runs args.llc_binary once per usable RUN line, collects the resulting
    function bodies per check prefix, then rewrites the file: existing check
    lines for the handled prefixes are dropped inside MIR documents and IR
    functions, and fresh ones are inserted for each function.  The output
    starts with `autogenerated_note` and is written back as UTF-8.  Returns
    early, without writing, when no triple is found for a run.
    """
    with open(test) as fd:
        input_lines = [raw.rstrip() for raw in fd]

    verbose = args.verbose
    triple_in_ir = find_triple_in_ir(input_lines, verbose)
    run_lines = common.find_run_lines(test, input_lines)
    run_list = build_run_list(test, run_lines, verbose)

    simple_functions = find_functions_with_one_bb(input_lines, verbose)

    func_dict = {prefix: {} for run in run_list for prefix in run.prefixes}
    for run in run_list:
        prefixes, llc_args, triple_in_cmd = run.prefixes, run.cmd_args, run.triple
        log('Extracted LLC cmd: llc {}'.format(llc_args), verbose)
        log('Extracted FileCheck prefixes: {}'.format(prefixes), verbose)

        raw_tool_output = args.llc_binary(llc_args, test)
        if not (triple_in_cmd or triple_in_ir):
            common.warn('No triple found: skipping file', test_file=test)
            return

        build_function_info_dictionary(test, raw_tool_output,
                                       triple_in_cmd or triple_in_ir,
                                       prefixes, func_dict, verbose)

    prefix_set = {prefix for run in run_list for prefix in run.prefixes}
    log('Rewriting FileCheck prefixes: {}'.format(prefix_set), verbose)

    output_lines = [autogenerated_note]

    def keep_unless_handled(line):
        # Copy the line through unless it is a check line we regenerate.
        if should_add_line_to_output(line, prefix_set):
            output_lines.append(line)

    def insert_checks(name, single_bb):
        add_checks_for_function(test, output_lines, run_list, func_dict,
                                name, single_bb=single_bb, args=args)

    state = 'toplevel'
    func_name = None
    for input_line in input_lines:
        if input_line == autogenerated_note:
            continue

        if state == 'toplevel':
            ir_match = IR_FUNC_NAME_RE.match(input_line)
            if ir_match:
                state = 'ir function prefix'
                func_name = ir_match.group('func')
            if input_line.rstrip('| \r\n') == '---':
                state = 'document'
            # Top-level lines are always copied through unchanged.
            output_lines.append(input_line)
        elif state == 'document':
            name_match = MIR_FUNC_NAME_RE.match(input_line)
            if name_match:
                state = 'mir function metadata'
                func_name = name_match.group('func')
            if input_line.strip() == '...':
                state = 'toplevel'
                func_name = None
            keep_unless_handled(input_line)
        elif state == 'mir function metadata':
            keep_unless_handled(input_line)
            if MIR_BODY_BEGIN_RE.match(input_line):
                if func_name in simple_functions:
                    # If there's only one block, put the checks inside it
                    state = 'mir function prefix'
                    continue
                state = 'mir function body'
                insert_checks(func_name, False)
        elif state == 'mir function prefix':
            if not MIR_PREFIX_DATA_RE.match(input_line):
                # First real instruction: the checks go just before it.
                state = 'mir function body'
                insert_checks(func_name, True)
            keep_unless_handled(input_line)
        elif state == 'mir function body':
            if input_line.strip() == '...':
                state = 'toplevel'
                func_name = None
            keep_unless_handled(input_line)
        elif state == 'ir function prefix':
            if not IR_PREFIX_DATA_RE.match(input_line):
                state = 'ir function body'
                insert_checks(func_name, False)
            keep_unless_handled(input_line)
        elif state == 'ir function body':
            if input_line.strip() == '}':
                state = 'toplevel'
                func_name = None
            keep_unless_handled(input_line)

    log('Writing {} lines to {}...'.format(len(output_lines), test), verbose)

    encoded_lines = ['{}\n'.format(line).encode('utf-8')
                     for line in output_lines]
    with open(test, 'wb') as fd:
        fd.writelines(encoded_lines)
def main():
    """Parse the command line and update every test file given on it.

    A failure while processing a file is reported with a warning naming the
    file and then re-raised.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--llc-binary', default='llc', type=LLC,
        help='The "llc" binary to generate the test case with')
    parser.add_argument(
        '--print-fixed-stack', action='store_true',
        help='Add check lines for fixedStack')
    parser.add_argument('tests', nargs='+')
    args = common.parse_commandline_args(parser)

    script_path = 'utils/' + os.path.basename(__file__)
    for ti in common.itertests(args.tests, parser, script_name=script_path):
        try:
            update_test_file(ti.args, ti.path, ti.test_autogenerated_note)
        except Exception:
            # Name the offending file, then let the error propagate.
            common.warn('Error processing file', test_file=ti.path)
            raise


if __name__ == '__main__':
    main()