[ARM] Basic And/Or/Xor handling for MVE predicates
[llvm-complete.git] / utils / update_mca_test_checks.py
blob87ac19b39aea181f25a68977911bd6169192fcba
1 #!/usr/bin/env python
3 """A test case update script.
5 This script is a utility to update LLVM 'llvm-mca' based test cases with new
6 FileCheck patterns.
7 """
9 import argparse
10 from collections import defaultdict
11 import glob
12 import os
13 import sys
14 import warnings
16 from UpdateTestChecks import common
COMMENT_CHAR = '#'
# Lines of an existing test that start with this prefix are treated as a stale
# advert: they are dropped when the file is rewritten and ADVERT is re-added.
ADVERT_PREFIX = COMMENT_CHAR + ' NOTE: Assertions have been autogenerated by '
ADVERT = ADVERT_PREFIX + 'utils/' + os.path.basename(__file__)
class Error(Exception):
  """Generic error that can be raised without printing a traceback.

  The script's ``__main__`` guard catches it, prints only the message and
  exits with status 1.
  """
31 def _warn(msg):
32 """ Log a user warning to stderr.
33 """
34 warnings.warn(msg, Warning, stacklevel=2)
37 def _configure_warnings(args):
38 warnings.resetwarnings()
39 if args.w:
40 warnings.simplefilter('ignore')
41 if args.Werror:
42 warnings.simplefilter('error')
45 def _showwarning(message, category, filename, lineno, file=None, line=None):
46 """ Version of warnings.showwarning that won't attempt to print out the
47 line at the location of the warning if the line text is not explicitly
48 specified.
49 """
50 if file is None:
51 file = sys.stderr
52 if line is None:
53 line = ''
54 file.write(warnings.formatwarning(message, category, filename, lineno, line))
def _parse_args():
  """Parse and validate the command line.

  The warning filters requested by -w / -Werror are installed before the
  binary name is checked, so the 'unexpected binary name' warning honours
  them.

  Returns:
    The argparse namespace.

  Raises:
    Error: if --llvm-mca-binary is given an empty string.
  """
  parser = argparse.ArgumentParser(description=__doc__)
  parser.add_argument('-v', '--verbose', action='store_true',
                      help='show verbose output')
  parser.add_argument('-w', action='store_true', help='suppress warnings')
  parser.add_argument('-Werror', action='store_true',
                      help='promote warnings to errors')
  parser.add_argument('--llvm-mca-binary', metavar='<path>',
                      default='llvm-mca',
                      help='the binary to use to generate the test case '
                           '(default: llvm-mca)')
  parser.add_argument('tests', metavar='<test-path>', nargs='+')
  parsed = parser.parse_args()

  _configure_warnings(parsed)

  binary = parsed.llvm_mca_binary
  if not binary:
    raise Error('--llvm-mca-binary value cannot be empty string')
  # Only a heuristic: the name is checked, not what the binary actually is.
  if 'llvm-mca' not in os.path.basename(binary):
    _warn('unexpected binary name: {}'.format(binary))
  return parsed
def _find_run_lines(input_lines, args):
  """Collect the RUN commands of a test file, joining continuation lines.

  Args:
    input_lines: the lines of the test file.
    args: parsed command-line namespace; only ``args.verbose`` is read.

  Returns:
    A list of RUN command strings, one per logical RUN line. A RUN line that
    ends in a backslash is merged with the RUN line that follows it.
  """
  raw_lines = [m.group(1)
               for m in (common.RUN_LINE_RE.match(l) for l in input_lines)
               if m]

  run_lines = []
  for line in raw_lines:
    # '\\' is a single backslash character, the lit continuation marker.
    # (A raw r'\\' would be two backslashes and never match a continuation.)
    if run_lines and run_lines[-1].endswith('\\'):
      run_lines[-1] = run_lines[-1].rstrip('\\') + ' ' + line
    else:
      run_lines.append(line)

  if args.verbose:
    sys.stderr.write('Found {} RUN line{}:\n'.format(
        len(run_lines), '' if len(run_lines) == 1 else 's'))
    for line in run_lines:
      sys.stderr.write(' RUN: {}\n'.format(line))

  return run_lines
def _get_run_infos(run_lines, args):
  """Split each RUN line into its FileCheck prefixes and llvm-mca arguments.

  A RUN line is used only when it has the form ``<tool> ... | FileCheck ...``.
  Here ``<tool>`` is the basename, without extension, of --llvm-mca-binary.
  Other RUN lines are skipped with a warning. The ``%s`` / ``< %s`` input
  substitution is removed from the tool arguments.

  Returns:
    A list of ``(check_prefixes, tool_args)`` tuples. ``check_prefixes``
    defaults to ``['CHECK']`` when no prefix option is found.
  """
  tool_basename = os.path.splitext(os.path.basename(args.llvm_mca_binary))[0]
  infos = []
  for run_line in run_lines:
    pieces = [piece.strip() for piece in run_line.split('|', 1)]
    if len(pieces) != 2:
      _warn('could not split tool and filecheck commands: {}'.format(run_line))
      continue
    tool_cmd, filecheck_cmd = pieces

    if not tool_cmd.startswith(tool_basename + ' '):
      _warn('skipping non-{} RUN line: {}'.format(tool_basename, run_line))
      continue
    if not filecheck_cmd.startswith('FileCheck '):
      _warn('skipping non-FileCheck RUN line: {}'.format(run_line))
      continue

    tool_args = tool_cmd[len(tool_basename):].strip()
    tool_args = tool_args.replace('< %s', '').replace('%s', '').strip()

    prefixes = []
    for match in common.CHECK_PREFIX_RE.finditer(filecheck_cmd):
      prefixes.extend(match.group(1).split(','))
    infos.append((prefixes or ['CHECK'], tool_args))

  return infos
143 def _break_down_block(block_info, common_prefix):
144 """ Given a block_info, see if we can analyze it further to let us break it
145 down by prefix per-line rather than per-block.
147 texts = block_info.keys()
148 prefixes = list(block_info.values())
149 # Split the lines from each of the incoming block_texts and zip them so that
150 # each element contains the corresponding lines from each text. E.g.
152 # block_text_1: A # line 1
153 # B # line 2
155 # block_text_2: A # line 1
156 # C # line 2
158 # would become:
160 # [(A, A), # line 1
161 # (B, C)] # line 2
163 line_tuples = list(zip(*list((text.splitlines() for text in texts))))
165 # To simplify output, we'll only proceed if the very first line of the block
166 # texts is common to each of them.
167 if len(set(line_tuples[0])) != 1:
168 return []
170 result = []
171 lresult = defaultdict(list)
172 for i, line in enumerate(line_tuples):
173 if len(set(line)) == 1:
174 # We're about to output a line with the common prefix. This is a sync
175 # point so flush any batched-up lines one prefix at a time to the output
176 # first.
177 for prefix in sorted(lresult):
178 result.extend(lresult[prefix])
179 lresult = defaultdict(list)
181 # The line is common to each block so output with the common prefix.
182 result.append((common_prefix, line[0]))
183 else:
184 # The line is not common to each block, or we don't have a common prefix.
185 # If there are no prefixes available, warn and bail out.
186 if not prefixes[0]:
187 _warn('multiple lines not disambiguated by prefixes:\n{}\n'
188 'Some blocks may be skipped entirely as a result.'.format(
189 '\n'.join(' - {}'.format(l) for l in line)))
190 return []
192 # Iterate through the line from each of the blocks and add the line with
193 # the corresponding prefix to the current batch of results so that we can
194 # later output them per-prefix.
195 for i, l in enumerate(line):
196 for prefix in prefixes[i]:
197 lresult[prefix].append((prefix, l))
199 # Flush any remaining batched-up lines one prefix at a time to the output.
200 for prefix in sorted(lresult):
201 result.extend(lresult[prefix])
202 return result
205 def _get_useful_prefix_info(run_infos):
206 """ Given the run_infos, calculate any prefixes that are common to every one,
207 and the length of the longest prefix string.
209 try:
210 all_sets = [set(s) for s in list(zip(*run_infos))[0]]
211 common_to_all = set.intersection(*all_sets)
212 longest_prefix_len = max(len(p) for p in set.union(*all_sets))
213 except IndexError:
214 common_to_all = []
215 longest_prefix_len = 0
216 else:
217 if len(common_to_all) > 1:
218 _warn('Multiple prefixes common to all RUN lines: {}'.format(
219 common_to_all))
220 if common_to_all:
221 common_to_all = sorted(common_to_all)[0]
222 return common_to_all, longest_prefix_len
def _align_matching_blocks(all_blocks, farthest_indexes):
  """ Some sub-sequences of blocks may be common to multiple lists of blocks,
      but at different indexes in each one.

      For example, in the following case, A, B, E, F, and H are common to both
      sets, but only A and B would be identified as such due to the indexes
      matching:

        index | 0 1 2 3 4 5 6
        ------+--------------
        setA  | A B C D E F H
        setB  | A B E F G H

      This function attempts to align the indexes of matching blocks by
      inserting empty blocks into the block list. With this approach, A, B, E,
      F, and H would now be able to be identified as matching blocks:

        index | 0 1 2 3 4 5 6 7
        ------+----------------
        setA  | A B C D E F   H
        setB  | A B     E F G H

      Each call shifts at most one block. After inserting empty blocks it
      returns immediately, because the farthest-index analysis is then stale.
      The caller therefore invokes it repeatedly until it returns False.

      Args:
        all_blocks: dict whose values are lists of block texts, one list per
          RUN line. The lists are modified in place. Empty strings are
          padding: they are never used as alignment anchors.
        farthest_indexes: mapping from block text to the highest index at
          which that text first occurs in any list, updated in place. It must
          yield 0 for texts not seen before (the caller in this file passes a
          defaultdict(int)).

      Returns:
        True if empty blocks were inserted, False if nothing changed.
  """
  # "Farthest block analysis": essentially, iterate over all blocks and find
  # the highest index into a block list for the first instance of each block.
  # This is relatively expensive, but we're dealing with small numbers of
  # blocks so it doesn't make a perceivable difference to user time.
  for blocks in all_blocks.values():
    for block in blocks:
      if not block:
        continue

      # index() returns the position of the FIRST occurrence of this text.
      index = blocks.index(block)

      if index > farthest_indexes[block]:
        farthest_indexes[block] = index

  # Use the results of the above analysis to identify any blocks that can be
  # shunted along to match the farthest index value.
  for blocks in all_blocks.values():
    for index, block in enumerate(blocks):
      if not block:
        continue

      changed = False
      # If the block has not already been subject to alignment (i.e. if the
      # previous block is not empty) then insert empty blocks until the index
      # matches the farthest index identified for that block. A block at
      # index 0 is never shifted.
      if (index > 0) and blocks[index - 1]:
        while(index < farthest_indexes[block]):
          blocks.insert(index, '')
          index += 1
          changed = True

      if changed:
        # Bail out. We'll need to re-do the farthest block analysis now that
        # we've inserted some blocks.
        return True

  return False
def _get_block_infos(run_infos, test_path, args, common_prefix):  # noqa
  """ For each run line, run the tool with the specified args and collect the
      output. We use the concept of 'blocks' for uniquing: the output is split
      on blank lines (the text is split on '\\n\\n'), and each piece is a
      block. For example:

        This
        is one
        block

        This is
        another block

        This is yet another block

      We then build up a 'block_infos' structure containing a dict where the
      text of each block is the key and a list of the sets of prefixes that may
      generate that particular block. This then goes through a series of
      transformations to minimise the amount of CHECK lines that need to be
      written by taking advantage of common prefixes.

      Returns:
        A mapping from block number to either a dict {block text: sorted list
        of prefixes to emit it under} or, when _break_down_block() succeeded,
        a list of (prefix, line) pairs.

      Raises:
        Error: if a non-empty block cannot be attributed to any prefix.
  """
  def _block_key(tool_args, prefixes):
    """ Get a hashable key based on the current tool_args and prefixes.
    """
    return ' '.join([tool_args] + prefixes)

  all_blocks = {}

  # A cache of the furthest-back position in any block list of the first
  # instance of each block, indexed by the block itself.
  farthest_indexes = defaultdict(int)

  # Run the tool for each run line to generate all of the blocks.
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    raw_tool_output = common.invoke_tool(args.llvm_mca_binary,
                                         tool_args,
                                         test_path)

    # Replace any lines consisting of purely whitespace with empty lines.
    raw_tool_output = '\n'.join(line if line.strip() else ''
                                for line in raw_tool_output.splitlines())

    # Split blocks, stripping all trailing whitespace, but keeping preceding
    # whitespace except for newlines so that columns will line up visually.
    all_blocks[key] = [b.lstrip('\n').rstrip()
                       for b in raw_tool_output.split('\n\n')]

  # Attempt to align matching blocks until no more changes can be made.
  made_changes = True
  while made_changes:
    made_changes = _align_matching_blocks(all_blocks, farthest_indexes)

  # Pad the lists of blocks with empty blocks so that they all have the same
  # length. The target length must be measured only now, because alignment
  # can lengthen a list by inserting empty blocks. Measuring it beforehand
  # left the lists ragged: a trailing block produced by only some RUN lines
  # was then never compared with the other RUN lines' (absent) block at that
  # position, and could end up emitted under a prefix those RUN lines also
  # check.
  max_block_len = max([0] + [len(blocks) for blocks in all_blocks.values()])
  for blocks in all_blocks.values():
    blocks.extend([''] * (max_block_len - len(blocks)))

  # Create the block_infos structure where it is a nested dict in the form of:
  # block number -> block text -> list of prefix sets
  block_infos = defaultdict(lambda: defaultdict(list))
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    for block_num, block_text in enumerate(all_blocks[key]):
      block_infos[block_num][block_text].append(set(prefixes))

  # Now go through the block_infos structure and attempt to smartly prune the
  # number of prefixes per block to the minimal set possible to output.
  for block_num in range(len(block_infos)):
    # When there are multiple block texts for a block num, remove any
    # prefixes that are common to more than one of them.
    # E.g. [ [{ALL,FOO}] , [{ALL,BAR}] ] -> [ [{FOO}] , [{BAR}] ]
    all_sets = list(block_infos[block_num].values())
    pruned_sets = []

    for i, setlist in enumerate(all_sets):
      other_set_values = set(elem
                             for j, other_setlist in enumerate(all_sets)
                             if j != i
                             for set_ in other_setlist
                             for elem in set_)
      pruned_sets.append([s - other_set_values for s in setlist])

    # pruned_sets[i] corresponds to the i-th block text: both come from the
    # same dict, whose key order is unchanged by the assignments below.
    for i, block_text in enumerate(block_infos[block_num]):
      # When a block text matches multiple sets of prefixes, try removing any
      # prefixes that aren't common to all of them.
      # E.g. [ {ALL,FOO} , {ALL,BAR} ] -> [{ALL}]
      common_values = set.intersection(*pruned_sets[i])
      if common_values:
        pruned_sets[i] = [common_values]

      # Everything should be uniqued as much as possible by now. Apply the
      # newly pruned sets to the block_infos structure.
      # If there are any blocks of text that still match multiple prefixes,
      # output a warning.
      current_set = set()
      for s in pruned_sets[i]:
        s = sorted(s)
        if s:
          current_set.add(s[0])
          if len(s) > 1:
            _warn('Multiple prefixes generating same output: {} '
                  '(discarding {})'.format(','.join(s), ','.join(s[1:])))

      if block_text and not current_set:
        raise Error(
            'block not captured by existing prefixes:\n\n{}'.format(block_text))
      block_infos[block_num][block_text] = sorted(current_set)

    # If we have multiple block_texts, try to break them down further to avoid
    # the case where we have very similar block_texts repeated after each
    # other.
    if common_prefix and len(block_infos[block_num]) > 1:
      # We'll only attempt this if each of the block_texts have the same number
      # of lines as each other.
      same_num_lines = (len(set(len(k.splitlines())
                                for k in block_infos[block_num])) == 1)
      if same_num_lines:
        breakdown = _break_down_block(block_infos[block_num], common_prefix)
        if breakdown:
          block_infos[block_num] = breakdown

  return block_infos
def _write_block(output, block, not_prefix_set, common_prefix, prefix_pad):
  """Append the check lines for one block to ``output``.

  The first written line uses ``<PREFIX>:`` and every later one uses
  ``<PREFIX>-NEXT:``. The prefix name is padded with spaces to ``prefix_pad``
  characters, and a trailing empty line is always appended.

  Args:
    output: list of output lines, appended to in place.
    block: iterable of ``(prefix, line)`` pairs.
    not_prefix_set: prefixes with a ``-NOT:`` check in the input file. Lines
      for these are skipped with a warning.
    common_prefix: prefix shared by all RUN lines (may be falsy).
    prefix_pad: width the prefix names are padded to.

  Returns:
    The set of prefixes for which at least one line was written.
  """
  written = set()
  separator = ': '
  last_prefix = None
  run_length = 0

  for prefix, line in block:
    if prefix in not_prefix_set:
      _warn('not writing for prefix {0} due to presence of "{0}-NOT:" '
            'in input file.'.format(prefix))
      continue

    # On a change of prefix, start the new run of lines with a blank line
    # (unless the output already ends with one) when the previous prefix
    # wrote two or more lines, or when either prefix is the common one.
    # run_length counts the lines written since the last prefix change.
    run_length += 1
    if prefix != last_prefix:
      needs_gap = run_length > 1 or common_prefix in (prefix, last_prefix)
      if output and output[-1] and needs_gap:
        output.append('')
      run_length = 0
      last_prefix = prefix

    written.add(prefix)
    padding = ' ' * (prefix_pad - len(prefix))
    output.append('{} {}{}{} {}'.format(COMMENT_CHAR, prefix, separator,
                                        padding, line).rstrip())
    separator = '-NEXT:'

  output.append('')
  return written
def _write_output(test_path, input_lines, prefix_list, block_infos,  # noqa
                  args, common_prefix, prefix_pad):
  """Rewrite ``test_path`` with freshly generated check lines.

  Content of ``input_lines`` is carried over, with these exceptions:
    * any old advert line is dropped;
    * lines rejected by common.should_add_line_to_output() are dropped
      (presumably the stale check lines for our prefixes; confirm in common);
    * leading whitespace is normalized and runs of empty lines are collapsed.
  Prefixes that have a ``-NOT:`` check anywhere in the file are left
  untouched: their lines are kept verbatim and no new checks are written for
  them. The new checks from ``block_infos`` are appended, and the advert is
  inserted at the top when any checks were written.

  Args:
    test_path: path of the test file to (re)write.
    input_lines: current lines of the test file.
    prefix_list: ``(check_prefixes, tool_args)`` run infos; only the prefixes
      are used.
    block_infos: per-block check information from _get_block_infos().
    args: parsed command-line namespace; only ``args.verbose`` is read.
    common_prefix: prefix shared by all RUN lines (may be falsy).
    prefix_pad: width prefix names are padded to when writing checks.

  Raises:
    Error: if a RUN-line prefix (other than a -NOT one) gets no check lines.
  """
  prefix_set = set(prefix for prefixes, _ in prefix_list
                   for prefix in prefixes)

  # Collect the -NOT prefixes up front. If they were collected while copying
  # lines, check lines of such a prefix that appear before its first -NOT
  # line would already have been dropped, and would never be regenerated.
  not_prefix_set = set()
  for input_line in input_lines:
    if (input_line.startswith(COMMENT_CHAR) and
        not input_line.startswith(ADVERT_PREFIX)):
      m = common.CHECK_RE.match(input_line)
      if m and '{}-NOT:'.format(m.group(1)) in input_line:
        not_prefix_set.add(m.group(1))

  output_lines = []
  for input_line in input_lines:
    if input_line.startswith(ADVERT_PREFIX):
      continue

    if input_line.startswith(COMMENT_CHAR):
      m = common.CHECK_RE.match(input_line)
      prefix = m.group(1) if m else None
      # Comments that are not checks for one of our prefixes, and checks for
      # a -NOT prefix, are copied through unchanged.
      if prefix not in prefix_set or prefix in not_prefix_set:
        output_lines.append(input_line)
        continue

    if common.should_add_line_to_output(input_line, prefix_set):
      # This input line goes into the output, except that whatever
      # SCRUB_LEADING_WHITESPACE_RE matches is replaced by a single space.
      input_line = common.SCRUB_LEADING_WHITESPACE_RE.sub(r' ', input_line)

      # Skip empty lines if the previous output line is also empty, or if
      # nothing has been output yet (output_lines[-1] would raise there).
      if input_line or (output_lines and output_lines[-1]):
        output_lines.append(input_line)

  # Add a blank line before the new checks if required.
  if output_lines and output_lines[-1]:
    output_lines.append('')

  output_check_lines = []
  used_prefixes = set()
  for block_num in range(len(block_infos)):
    block_info = block_infos[block_num]
    if isinstance(block_info, list):
      # The block is of the type output from _break_down_block().
      used_prefixes |= _write_block(output_check_lines,
                                    block_info,
                                    not_prefix_set,
                                    common_prefix,
                                    prefix_pad)
      continue

    # _break_down_block() was unable to do anything, so output the block
    # as-is. Rather than writing out each block as soon as we encounter it,
    # save it indexed by prefix so that we can write all of the blocks out
    # sorted by prefix at the end.
    output_blocks = defaultdict(list)
    for block_text in sorted(block_info):
      if not block_text:
        continue

      lines = block_text.split('\n')
      for prefix in block_info[block_text]:
        # Pruning should have left each prefix with at most one text here.
        assert prefix not in output_blocks
        used_prefixes |= _write_block(output_blocks[prefix],
                                      [(prefix, line) for line in lines],
                                      not_prefix_set,
                                      common_prefix,
                                      prefix_pad)

    for prefix in sorted(output_blocks):
      output_check_lines.extend(output_blocks[prefix])

  unused_prefixes = (prefix_set - not_prefix_set) - used_prefixes
  if unused_prefixes:
    raise Error('unused prefixes: {}'.format(sorted(unused_prefixes)))

  if output_check_lines:
    output_lines.insert(0, ADVERT)
    output_lines.extend(output_check_lines)

  # The file should not end with two newlines. It creates unnecessary churn.
  while output_lines and output_lines[-1] == '':
    output_lines.pop()

  if input_lines == output_lines:
    sys.stderr.write(' [unchanged]\n')
    return
  sys.stderr.write(' [{} lines total]\n'.format(len(output_lines)))

  if args.verbose:
    sys.stderr.write(
        'Writing {} lines to {}...\n\n'.format(len(output_lines), test_path))

  with open(test_path, 'wb') as f:
    f.writelines(['{}\n'.format(l).encode('utf-8') for l in output_lines])
def main():
  """Update every test file named on the command line.

  Each positional argument is expanded with glob.glob(). An argument that
  matches nothing is kept as a literal path, so it is reported as a missing
  file instead of being skipped silently (glob returns [] for nonexistent
  paths).

  Returns:
    0 on success. Problems are reported by raising Error.
  """
  args = _parse_args()

  test_paths = []
  for pattern in args.tests:
    test_paths.extend(glob.glob(pattern) or [pattern])

  for test_path in test_paths:
    sys.stderr.write('Test: {}\n'.format(test_path))

    # Call this per test. By default each warning will only be written once
    # per source location. Reset the warning filter so that now each warning
    # will be written once per source location per test.
    _configure_warnings(args)

    if args.verbose:
      sys.stderr.write(
          'Scanning for RUN lines in test file: {}\n'.format(test_path))

    if not os.path.isfile(test_path):
      raise Error('could not find test file: {}'.format(test_path))

    with open(test_path) as f:
      input_lines = [l.rstrip() for l in f]

    run_lines = _find_run_lines(input_lines, args)
    run_infos = _get_run_infos(run_lines, args)
    common_prefix, prefix_pad = _get_useful_prefix_info(run_infos)
    block_infos = _get_block_infos(run_infos, test_path, args, common_prefix)
    _write_output(test_path,
                  input_lines,
                  run_infos,
                  block_infos,
                  args,
                  common_prefix,
                  prefix_pad)

  return 0
if __name__ == '__main__':
  try:
    # Print warnings without trying to echo the source line at the warning
    # location.
    warnings.showwarning = _showwarning
    sys.exit(main())
  except Error as e:
    # Report on stderr, like every other diagnostic this script emits, and
    # without a traceback.
    sys.stderr.write('error: {}\n'.format(e))
    sys.exit(1)