gn build: Merge r372267
[llvm-complete.git] / utils / update_mca_test_checks.py
blobbbeca1d557b7a730fb2639ef10fb7f01e35fc605
1 #!/usr/bin/env python
3 """A test case update script.
5 This script is a utility to update LLVM 'llvm-mca' based test cases with new
6 FileCheck patterns.
7 """
9 import argparse
10 from collections import defaultdict
11 import glob
12 import os
13 import sys
14 import warnings
16 from UpdateTestChecks import common
# Comment leader used by llvm-mca assembly tests for RUN and CHECK lines.
COMMENT_CHAR = '#'
# Start of the auto-generation notice; input lines beginning with this are
# dropped and a fresh notice (ADVERT) is written when checks are regenerated.
ADVERT_PREFIX = COMMENT_CHAR + ' NOTE: Assertions have been autogenerated by '
# Full notice naming this script.
ADVERT = ADVERT_PREFIX + 'utils/' + os.path.basename(__file__)
class Error(Exception):
  """Fatal, user-facing error.

  Raised for problems that should be reported as a single 'error: ...' line
  rather than with a traceback; the script's entry point catches it.
  """
def _warn(msg):
  """Emit msg as a Warning attributed to the caller of _warn().

  Whether it is shown, suppressed or raised depends on the active warning
  filters; how it is printed depends on warnings.showwarning.
  """
  warnings.warn(msg, category=Warning, stacklevel=2)
def _configure_warnings(args):
  """Reset the warning filters and re-apply the -w / -Werror options.

  'error' is installed after 'ignore', so it takes precedence when both
  flags are given.

  Args:
    args: parsed options; only args.w and args.Werror are read.
  """
  warnings.resetwarnings()
  for enabled, action in ((args.w, 'ignore'), (args.Werror, 'error')):
    if enabled:
      warnings.simplefilter(action)
def _showwarning(message, category, filename, lineno, file=None, line=None):
  """Drop-in replacement for warnings.showwarning.

  Unlike the default, it never tries to look up and print the source line
  at the warning location; a line is printed only if one is passed in.

  Args:
    message, category, filename, lineno: as for warnings.showwarning.
    file: stream to write to; sys.stderr when None.
    line: source text to include; nothing is included when None.
  """
  source_text = '' if line is None else line
  stream = sys.stderr if file is None else file
  stream.write(warnings.formatwarning(message, category, filename, lineno,
                                      source_text))
def _parse_args():
  """Build the command-line parser, parse sys.argv and sanity-check it.

  Warning filters are configured from -w / -Werror before the checks below
  run, so the binary-name warning already honours those flags.

  Returns:
    The parsed argparse.Namespace.

  Raises:
    Error: if --llvm-mca-binary is an empty string.
  """
  parser = argparse.ArgumentParser(description=__doc__)
  flag_options = (
      (('-v', '--verbose'), 'show verbose output'),
      (('-w',), 'suppress warnings'),
      (('-Werror',), 'promote warnings to errors'),
  )
  for names, help_text in flag_options:
    parser.add_argument(*names, action='store_true', help=help_text)
  parser.add_argument(
      '--llvm-mca-binary',
      metavar='<path>',
      default='llvm-mca',
      help='the binary to use to generate the test case (default: llvm-mca)')
  parser.add_argument('tests', metavar='<test-path>', nargs='+')
  parsed = parser.parse_args()

  _configure_warnings(parsed)

  mca_binary = parsed.llvm_mca_binary
  if not mca_binary:
    raise Error('--llvm-mca-binary value cannot be empty string')
  if 'llvm-mca' not in os.path.basename(mca_binary):
    _warn('unexpected binary name: {}'.format(mca_binary))

  return parsed
def _find_run_lines(input_lines, args):
  """Collect the RUN commands of a test file, joining continued lines.

  A RUN command ending in a backslash continues on the next RUN line: the
  trailing backslash(es) are removed and the two pieces are joined with a
  single space.

  Args:
    input_lines: the test file's lines (trailing whitespace already removed).
    args: parsed command-line arguments; only args.verbose is read.

  Returns:
    A list of complete RUN command strings in file order.
  """
  raw_lines = [m.group(1)
               for m in (common.RUN_LINE_RE.match(l) for l in input_lines)
               if m]

  run_lines = []
  for raw_line in raw_lines:
    # A single trailing backslash is the continuation marker. ('\\' is one
    # backslash; the raw string r'\\' would require two and never match.)
    if run_lines and run_lines[-1].endswith('\\'):
      run_lines[-1] = run_lines[-1].rstrip('\\') + ' ' + raw_line
    else:
      run_lines.append(raw_line)

  if args.verbose:
    sys.stderr.write('Found {} RUN line{}:\n'.format(
        len(run_lines), '' if len(run_lines) == 1 else 's'))
    for line in run_lines:
      sys.stderr.write(' RUN: {}\n'.format(line))

  return run_lines
def _get_run_infos(run_lines, args):
  """Extract (check_prefixes, tool_args) from each usable RUN line.

  A RUN line is used only if it has the form '<tool> ... | FileCheck ...',
  where <tool> is the basename (without extension) of --llvm-mca-binary.
  Other RUN lines are skipped with a warning.

  Args:
    run_lines: complete RUN commands, as returned by _find_run_lines().
    args: parsed command-line arguments; args.llvm_mca_binary is read.

  Returns:
    A list of (check_prefixes, tool_cmd_args) tuples. check_prefixes is the
    list of FileCheck prefixes (['CHECK'] when none are given);
    tool_cmd_args are the tool's arguments with the '%s' input removed.
  """
  # The expected tool name is the same for every RUN line.
  tool_basename = os.path.splitext(os.path.basename(args.llvm_mca_binary))[0]

  run_infos = []
  for run_line in run_lines:
    commands = [cmd.strip() for cmd in run_line.split('|', 1)]
    if len(commands) != 2:
      _warn('could not split tool and filecheck commands: {}'.format(run_line))
      continue
    tool_cmd, filecheck_cmd = commands

    if not tool_cmd.startswith(tool_basename + ' '):
      _warn('skipping non-{} RUN line: {}'.format(tool_basename, run_line))
      continue

    if not filecheck_cmd.startswith('FileCheck '):
      _warn('skipping non-FileCheck RUN line: {}'.format(run_line))
      continue

    # Validate prefixes only on RUN lines that are actually used, so skipped
    # lines do not produce spurious prefix diagnostics.
    common.verify_filecheck_prefixes(filecheck_cmd)

    tool_cmd_args = tool_cmd[len(tool_basename):].strip()
    tool_cmd_args = tool_cmd_args.replace('< %s', '').replace('%s', '').strip()

    check_prefixes = [item
                      for m in common.CHECK_PREFIX_RE.finditer(filecheck_cmd)
                      for item in m.group(1).split(',')]
    if not check_prefixes:
      check_prefixes = ['CHECK']

    run_infos.append((check_prefixes, tool_cmd_args))

  return run_infos
def _break_down_block(block_info, common_prefix):
  """Try to assign prefixes per line, rather than per block.

  block_info maps each alternative text of one block to the prefixes that
  produce it. Lines identical in every text are attributed to common_prefix.
  The remaining lines are attributed to each text's own prefixes and are
  batched between common lines so that they come out grouped by prefix
  (sorted).

  The caller only uses this when all texts have the same number of lines.
  The break-down is attempted only when the texts share their first line.

  Args:
    block_info: dict of block text -> list of prefix strings.
    common_prefix: the prefix shared by every RUN line.

  Returns:
    A list of (prefix, line) tuples in output order, or [] if the block
    cannot be broken down.
  """
  texts = list(block_info.keys())
  prefixes = list(block_info.values())
  # Split the lines from each of the incoming block texts and zip them so that
  # each element contains the corresponding lines from each text. E.g.
  #
  #   block_text_1: A   # line 1
  #                 B   # line 2
  #
  #   block_text_2: A   # line 1
  #                 C   # line 2
  #
  # would become:
  #
  #   [(A, A),   # line 1
  #    (B, C)]   # line 2
  #
  line_tuples = list(zip(*(text.splitlines() for text in texts)))

  # To simplify output, only proceed if there is at least one line and the
  # very first line is common to every text.
  if not line_tuples or len(set(line_tuples[0])) != 1:
    return []

  result = []
  lresult = defaultdict(list)
  for line in line_tuples:
    if len(set(line)) == 1:
      # We're about to output a line with the common prefix. This is a sync
      # point, so flush any batched-up lines one prefix at a time first.
      for prefix in sorted(lresult):
        result.extend(lresult[prefix])
      lresult = defaultdict(list)

      # The line is common to each block so output it with the common prefix.
      result.append((common_prefix, line[0]))
    else:
      # The line differs between texts. If any text has no prefix to check
      # its version of the line with, warn and bail out rather than silently
      # dropping that line.
      if not all(prefixes):
        _warn('multiple lines not disambiguated by prefixes:\n{}\n'
              'Some blocks may be skipped entirely as a result.'.format(
                  '\n'.join(' - {}'.format(l) for l in line)))
        return []

      # Batch each text's version of the line under each of its prefixes so
      # that they can later be output per-prefix.
      for text_prefixes, text_line in zip(prefixes, line):
        for prefix in text_prefixes:
          lresult[prefix].append((prefix, text_line))

  # Flush any remaining batched-up lines one prefix at a time.
  for prefix in sorted(lresult):
    result.extend(lresult[prefix])
  return result
def _get_useful_prefix_info(run_infos):
  """Find the prefix shared by every RUN line and the widest prefix length.

  Args:
    run_infos: sequence of (check_prefixes, tool_args) tuples.

  Returns:
    A (common_prefix, longest_prefix_len) pair. common_prefix is the
    alphabetically-first prefix present in every RUN line when there is one.
    Otherwise it is an empty container: [] when run_infos is empty, or an
    empty set when the RUN lines share no prefix. longest_prefix_len is the
    length of the longest prefix seen, or 0 when run_infos is empty. A
    warning is emitted when more than one prefix is shared by all lines.
  """
  columns = list(zip(*run_infos))
  if not columns:
    return [], 0

  prefix_sets = [set(prefixes) for prefixes in columns[0]]
  shared = set.intersection(*prefix_sets)
  widest = max(len(prefix) for prefix in set.union(*prefix_sets))

  if len(shared) > 1:
    _warn('Multiple prefixes common to all RUN lines: {}'.format(
        shared))
  if not shared:
    return shared, widest
  return min(shared), widest
def _align_matching_blocks(all_blocks, farthest_indexes):
  """Perform one alignment step on the block lists in all_blocks.

  Some sub-sequences of blocks may be common to multiple lists of blocks,
  but at different indexes in each one.

  For example, in the following case, A, B, E, F and H are common to both
  sets, but only A and B would be identified as such due to the indexes
  matching:

    index | 0 1 2 3 4 5 6
    ------+--------------
    setA  | A B C D E F H
    setB  | A B E F G H

  This function attempts to align the indexes of matching blocks by
  inserting empty blocks into the block lists. With this approach, A, B, E,
  F and H can now be identified as matching blocks:

    index | 0 1 2 3 4 5 6 7
    ------+----------------
    setA  | A B C D E F   H
    setB  | A B     E F G H

  Args:
    all_blocks: dict of key -> list of block strings; modified in place.
    farthest_indexes: defaultdict(int) of block -> farthest first-occurrence
      index seen so far; updated in place and reused between calls.

  Returns:
    True if empty blocks were inserted (call again until False is returned),
    False once no list changed.
  """
  # Farthest block analysis: for every non-empty block, record the greatest
  # index at which it first appears in any of the lists. This is quadratic,
  # but the number of blocks is small.
  for block_list in all_blocks.values():
    for candidate in filter(None, block_list):
      first_seen = block_list.index(candidate)
      farthest_indexes[candidate] = max(farthest_indexes[candidate], first_seen)

  # Shunt blocks along to their farthest index. A block that directly follows
  # an empty block has already been aligned and is left alone. After the
  # first insertion we stop, since the analysis above must be redone.
  for block_list in all_blocks.values():
    for position, candidate in enumerate(block_list):
      if not candidate or position == 0 or not block_list[position - 1]:
        continue
      shortfall = farthest_indexes[candidate] - position
      if shortfall > 0:
        block_list[position:position] = [''] * shortfall
        return True

  return False
def _get_block_infos(run_infos, test_path, args, common_prefix):  # noqa
  """Run the tool for every RUN line and group its output into blocks.

  We use the concept of 'blocks' for uniquing, where a block is a series of
  lines of text with no blank line between them. For example, this output
  contains three blocks:

    This
    is
    a
    block

    This is
    another block

    This is yet another block

  The block lists produced for the RUN lines are aligned with each other and
  padded to a common length. The prefixes for each block are then pruned to a
  minimal set, and, where possible, alternative texts of a block are broken
  down line by line (see _break_down_block()).

  Args:
    run_infos: list of (check_prefixes, tool_args) tuples.
    test_path: path of the test file passed to the tool.
    args: parsed command-line arguments; args.llvm_mca_binary is read.
    common_prefix: prefix shared by all RUN lines, or an empty container.

  Returns:
    A dict mapping block number to either a dict of block text -> sorted list
    of prefixes to check it with, or a list of (prefix, line) tuples as
    returned by _break_down_block().

  Raises:
    Error: if a non-empty block is not captured by any prefix.
  """
  def _block_key(tool_args, prefixes):
    """Get a hashable key based on the current tool_args and prefixes."""
    return ' '.join([tool_args] + prefixes)

  all_blocks = {}

  # A cache of the farthest-back position in any block list of the first
  # instance of each block, indexed by the block itself.
  farthest_indexes = defaultdict(int)

  # Run the tool for each run line to generate all of the blocks.
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    raw_tool_output = common.invoke_tool(args.llvm_mca_binary,
                                         tool_args,
                                         test_path)

    # Replace any lines consisting of purely whitespace with empty lines.
    raw_tool_output = '\n'.join(line if line.strip() else ''
                                for line in raw_tool_output.splitlines())

    # Split blocks, stripping all trailing whitespace, but keeping preceding
    # whitespace except for newlines so that columns will line up visually.
    all_blocks[key] = [b.lstrip('\n').rstrip()
                       for b in raw_tool_output.split('\n\n')]

  # Attempt to align matching blocks until no more changes can be made.
  while _align_matching_blocks(all_blocks, farthest_indexes):
    pass

  # Pad the lists of blocks with empty blocks so that they all have the same
  # length. The maximum is measured after alignment, because alignment
  # inserts empty blocks and can make a list longer than any list was before.
  max_block_len = max([0] + [len(blocks) for blocks in all_blocks.values()])
  for blocks in all_blocks.values():
    blocks.extend([''] * (max_block_len - len(blocks)))

  # Create the block_infos structure, a nested dict of the form:
  #   block number -> block text -> list of prefix sets
  block_infos = defaultdict(lambda: defaultdict(list))
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    for block_num, block_text in enumerate(all_blocks[key]):
      block_infos[block_num][block_text].append(set(prefixes))

  # Now go through the block_infos structure and attempt to smartly prune the
  # number of prefixes per block to the minimal set possible to output.
  for block_num in range(len(block_infos)):
    texts_to_sets = block_infos[block_num]

    # When there are multiple block texts for a block num, remove any
    # prefixes that are common to more than one of them.
    # E.g. [ [{ALL,FOO}] , [{ALL,BAR}] ] -> [ [{FOO}] , [{BAR}] ]
    all_sets = list(texts_to_sets.values())
    pruned_sets = []
    for i, setlist in enumerate(all_sets):
      other_set_values = set(elem
                             for j, other_setlist in enumerate(all_sets)
                             if j != i
                             for set_ in other_setlist
                             for elem in set_)
      pruned_sets.append([s - other_set_values for s in setlist])

    for i, block_text in enumerate(texts_to_sets):
      # When a block text matches multiple sets of prefixes, try removing any
      # prefixes that aren't common to all of them.
      # E.g. [ {ALL,FOO} , {ALL,BAR} ] -> [{ALL}]
      common_values = set.intersection(*pruned_sets[i])
      if common_values:
        pruned_sets[i] = [common_values]

      # Keep one prefix (the alphabetically first) per remaining set and warn
      # about any that are discarded.
      current_set = set()
      for s in pruned_sets[i]:
        s = sorted(s)
        if s:
          current_set.add(s[0])
          if len(s) > 1:
            _warn('Multiple prefixes generating same output: {} '
                  '(discarding {})'.format(','.join(s), ','.join(s[1:])))

      if block_text and not current_set:
        raise Error(
            'block not captured by existing prefixes:\n\n{}'.format(block_text))
      # Only an existing key is reassigned, so iterating is still safe.
      texts_to_sets[block_text] = sorted(current_set)

    # If we have multiple block texts with the same number of lines, try to
    # break them down further to avoid very similar texts repeated after each
    # other.
    if common_prefix and len(texts_to_sets) > 1:
      line_counts = set(len(text.splitlines()) for text in texts_to_sets)
      if len(line_counts) == 1:
        breakdown = _break_down_block(texts_to_sets, common_prefix)
        if breakdown:
          block_infos[block_num] = breakdown

  return block_infos
def _write_block(output, block, not_prefix_set, common_prefix, prefix_pad):
  """Append the check lines for one block to output.

  Each (prefix, line) pair in block becomes one comment line. The first line
  written uses '<PREFIX>: ' and every later one '<PREFIX>-NEXT:'. Prefixes
  are padded to prefix_pad characters so that the checked text lines up.

  On a change of prefix, a blank line is inserted first if output does not
  already end in a blank line and either the previous prefix's run had more
  than one line or the old or new prefix is common_prefix. A blank line
  always terminates the block.

  Pairs whose prefix is in not_prefix_set are skipped with a warning.

  Returns:
    The set of prefixes for which at least one line was written.
  """
  separator = ': '
  last_prefix = None
  run_length = 0
  written = set()

  for prefix, text in block:
    if prefix in not_prefix_set:
      _warn('not writing for prefix {0} due to presence of "{0}-NOT:" '
            'in input file.'.format(prefix))
      continue

    run_length += 1
    if prefix != last_prefix:
      wants_gap = (run_length > 1 or
                   prefix == common_prefix or
                   last_prefix == common_prefix)
      if output and output[-1] and wants_gap:
        output.append('')
      run_length = 0
      last_prefix = prefix

    written.add(prefix)
    padding = ' ' * (prefix_pad - len(prefix))
    output.append('{} {}{}{} {}'.format(COMMENT_CHAR, prefix, separator,
                                        padding, text).rstrip())
    separator = '-NEXT:'

  output.append('')
  return written
def _write_output(test_path, input_lines, prefix_list, block_infos,  # noqa
                  args, common_prefix, prefix_pad):
  """Rewrite test_path with freshly generated check lines.

  Input lines are kept except for:
    * the old advert, and
    * existing check lines for the prefixes being regenerated.
  Kept lines that are not comments for other prefixes have their leading
  whitespace (as matched by common.SCRUB_LEADING_WHITESPACE_RE) replaced by
  two spaces. Blank lines are dropped at the start of the output and
  directly after another blank line.

  When any new check lines are generated, they are appended and the advert
  is prepended. The file is only rewritten if its content changed.

  Prefixes that appear with '-NOT:' in the input are left alone: their
  existing lines are preserved and no new checks are written for them.

  Args:
    test_path: path of the test file to rewrite.
    input_lines: the file's current lines, trailing whitespace removed.
    prefix_list: the (check_prefixes, tool_args) run infos.
    block_infos: result of _get_block_infos().
    args: parsed command-line arguments; only args.verbose is read.
    common_prefix: prefix shared by all RUN lines, or an empty container.
    prefix_pad: width to which prefixes are padded.

  Raises:
    Error: if a RUN line prefix would not be used by any check line.
  """
  prefix_set = set(prefix for prefixes, _ in prefix_list
                   for prefix in prefixes)
  not_prefix_set = set()

  output_lines = []
  for input_line in input_lines:
    if input_line.startswith(ADVERT_PREFIX):
      continue

    if input_line.startswith(COMMENT_CHAR):
      m = common.CHECK_RE.match(input_line)
      prefix = m.group(1) if m else None

      if '{}-NOT:'.format(prefix) in input_line:
        not_prefix_set.add(prefix)

      if prefix not in prefix_set or prefix in not_prefix_set:
        output_lines.append(input_line)
        continue

    if not common.should_add_line_to_output(input_line, prefix_set):
      continue

    # This input line will go as-is into the output, except that leading
    # whitespace is made uniform: 2 spaces.
    input_line = common.SCRUB_LEADING_WHITESPACE_RE.sub(r'  ', input_line)

    # Skip empty lines if the previous output line is also empty. Checking
    # output_lines first avoids an IndexError when the first kept line is
    # blank.
    if input_line or (output_lines and output_lines[-1]):
      output_lines.append(input_line)

  # Add a blank line before the new checks if required.
  if output_lines and output_lines[-1]:
    output_lines.append('')

  output_check_lines = []
  used_prefixes = set()
  for block_num in range(len(block_infos)):
    block_info = block_infos[block_num]
    if isinstance(block_info, list):
      # The block is of the type output from _break_down_block().
      used_prefixes |= _write_block(output_check_lines,
                                    block_info,
                                    not_prefix_set,
                                    common_prefix,
                                    prefix_pad)
      continue

    # _break_down_block() was unable to do anything, so output each block
    # text as-is. Rather than writing out each block as soon as we encounter
    # it, save it indexed by prefix so that all of the blocks can be written
    # out sorted by prefix at the end.
    output_blocks = defaultdict(list)
    for block_text in sorted(block_info):
      if not block_text:
        continue

      lines = block_text.split('\n')
      for prefix in block_info[block_text]:
        assert prefix not in output_blocks
        used_prefixes |= _write_block(output_blocks[prefix],
                                      [(prefix, line) for line in lines],
                                      not_prefix_set,
                                      common_prefix,
                                      prefix_pad)

    for prefix in sorted(output_blocks):
      output_check_lines.extend(output_blocks[prefix])

  unused_prefixes = (prefix_set - not_prefix_set) - used_prefixes
  if unused_prefixes:
    raise Error('unused prefixes: {}'.format(sorted(unused_prefixes)))

  if output_check_lines:
    output_lines.insert(0, ADVERT)
    output_lines.extend(output_check_lines)

  # The file should not end with two newlines. It creates unnecessary churn.
  while output_lines and output_lines[-1] == '':
    output_lines.pop()

  if input_lines == output_lines:
    sys.stderr.write(' [unchanged]\n')
    return
  sys.stderr.write(' [{} lines total]\n'.format(len(output_lines)))

  if args.verbose:
    sys.stderr.write(
        'Writing {} lines to {}...\n\n'.format(len(output_lines), test_path))

  with open(test_path, 'wb') as f:
    f.writelines(['{}\n'.format(l).encode('utf-8') for l in output_lines])
def main():
  """Regenerate the check lines of every test named on the command line.

  Returns:
    0 on success. Failures are reported by raising Error.
  """
  args = _parse_args()

  # Expand glob patterns. A pattern that matches nothing is kept literally, so
  # the existence check below reports it instead of it being silently skipped.
  test_paths = []
  for pattern in args.tests:
    test_paths.extend(glob.glob(pattern) or [pattern])

  for test_path in test_paths:
    sys.stderr.write('Test: {}\n'.format(test_path))

    # Call this per test. By default each warning will only be written once
    # per source location. Reset the warning filter so that now each warning
    # will be written once per source location per test.
    _configure_warnings(args)

    if args.verbose:
      sys.stderr.write(
          'Scanning for RUN lines in test file: {}\n'.format(test_path))

    if not os.path.isfile(test_path):
      raise Error('could not find test file: {}'.format(test_path))

    with open(test_path) as f:
      input_lines = [l.rstrip() for l in f]

    run_lines = _find_run_lines(input_lines, args)
    run_infos = _get_run_infos(run_lines, args)
    common_prefix, prefix_pad = _get_useful_prefix_info(run_infos)
    block_infos = _get_block_infos(run_infos, test_path, args, common_prefix)
    _write_output(test_path,
                  input_lines,
                  run_infos,
                  block_infos,
                  args,
                  common_prefix,
                  prefix_pad)

  return 0
if __name__ == '__main__':
  try:
    warnings.showwarning = _showwarning
    sys.exit(main())
  except (Error, Warning) as e:
    # Report expected failures, including warnings promoted to errors by
    # -Werror, as a one-line diagnostic on stderr (like every other
    # diagnostic this script prints) instead of a traceback.
    sys.stderr.write('error: {}\n'.format(e))
    sys.exit(1)