[RISCV] Refactor predicates for rvv intrinsic patterns.
[llvm-project.git] / llvm / utils / update_mca_test_checks.py
blobdb4511dac7c028c0e94810726d068c7b98efa6ce
1 #!/usr/bin/env python3
3 """A test case update script.
5 This script is a utility to update LLVM 'llvm-mca' based test cases with new
6 FileCheck patterns.
7 """
9 import argparse
10 from collections import defaultdict
11 import glob
12 import os
13 import sys
14 import warnings
16 from UpdateTestChecks import common
# Comment leader used in llvm-mca test files.
COMMENT_CHAR = '#'

# The advert is written at the top of every autogenerated test.  Existing
# lines beginning with ADVERT_PREFIX are stripped from the input before the
# checks are regenerated, so stale adverts never accumulate.
ADVERT_PREFIX = '{} NOTE: Assertions have been autogenerated by '.format(
    COMMENT_CHAR)
ADVERT = '{}utils/{}'.format(ADVERT_PREFIX, os.path.basename(__file__))
class Error(Exception):
  """Script-level error, reported to the user without a traceback."""
  pass
31 def _warn(msg):
32 """ Log a user warning to stderr.
33 """
34 warnings.warn(msg, Warning, stacklevel=2)
37 def _configure_warnings(args):
38 warnings.resetwarnings()
39 if args.w:
40 warnings.simplefilter('ignore')
41 if args.Werror:
42 warnings.simplefilter('error')
45 def _showwarning(message, category, filename, lineno, file=None, line=None):
46 """ Version of warnings.showwarning that won't attempt to print out the
47 line at the location of the warning if the line text is not explicitly
48 specified.
49 """
50 if file is None:
51 file = sys.stderr
52 if line is None:
53 line = ''
54 file.write(warnings.formatwarning(message, category, filename, lineno, line))
def _parse_args():
  """Parse and validate the command line.

  Returns the parsed argument namespace.  Raises Error when the
  --llvm-mca-binary value is empty, and warns (without failing) when the
  binary name does not contain 'llvm-mca'.
  """
  parser = argparse.ArgumentParser(description=__doc__)
  parser.add_argument('-w', action='store_true', help='suppress warnings')
  parser.add_argument('-Werror', action='store_true',
                      help='promote warnings to errors')
  parser.add_argument('--llvm-mca-binary', metavar='<path>',
                      default='llvm-mca',
                      help='the binary to use to generate the test case '
                           '(default: llvm-mca)')
  parser.add_argument('tests', metavar='<test-path>', nargs='+')
  args = common.parse_commandline_args(parser)

  _configure_warnings(args)

  if not args.llvm_mca_binary:
    raise Error('--llvm-mca-binary value cannot be empty string')

  if 'llvm-mca' not in os.path.basename(args.llvm_mca_binary):
    _warn('unexpected binary name: {}'.format(args.llvm_mca_binary))

  return args
def _get_run_infos(run_lines, args):
  """Split each RUN line into a (check_prefixes, tool_cmd_args) tuple.

  RUN lines that cannot be split into a tool command piped into FileCheck,
  or that do not invoke the expected tool or FileCheck, are warned about
  and skipped rather than treated as errors.
  """
  # The tool name is derived from the binary path; it is invariant across
  # RUN lines, so compute it once rather than per iteration.
  tool_basename = os.path.splitext(os.path.basename(args.llvm_mca_binary))[0]

  run_infos = []
  for run_line in run_lines:
    try:
      (tool_cmd, filecheck_cmd) = tuple([cmd.strip()
                                         for cmd in run_line.split('|', 1)])
    except ValueError:
      _warn('could not split tool and filecheck commands: {}'.format(run_line))
      continue

    common.verify_filecheck_prefixes(filecheck_cmd)

    if not tool_cmd.startswith(tool_basename + ' '):
      _warn('skipping non-{} RUN line: {}'.format(tool_basename, run_line))
      continue

    if not filecheck_cmd.startswith('FileCheck '):
      _warn('skipping non-FileCheck RUN line: {}'.format(run_line))
      continue

    # Strip the tool name and the %s / '< %s' input placeholders; only the
    # remaining flags distinguish one RUN line from another.
    tool_cmd_args = tool_cmd[len(tool_basename):].strip()
    tool_cmd_args = tool_cmd_args.replace('< %s', '').replace('%s', '').strip()

    check_prefixes = common.get_check_prefixes(filecheck_cmd)
    run_infos.append((check_prefixes, tool_cmd_args))

  return run_infos
117 def _break_down_block(block_info, common_prefix):
118 """ Given a block_info, see if we can analyze it further to let us break it
119 down by prefix per-line rather than per-block.
121 texts = block_info.keys()
122 prefixes = list(block_info.values())
123 # Split the lines from each of the incoming block_texts and zip them so that
124 # each element contains the corresponding lines from each text. E.g.
126 # block_text_1: A # line 1
127 # B # line 2
129 # block_text_2: A # line 1
130 # C # line 2
132 # would become:
134 # [(A, A), # line 1
135 # (B, C)] # line 2
137 line_tuples = list(zip(*list((text.splitlines() for text in texts))))
139 # To simplify output, we'll only proceed if the very first line of the block
140 # texts is common to each of them.
141 if len(set(line_tuples[0])) != 1:
142 return []
144 result = []
145 lresult = defaultdict(list)
146 for i, line in enumerate(line_tuples):
147 if len(set(line)) == 1:
148 # We're about to output a line with the common prefix. This is a sync
149 # point so flush any batched-up lines one prefix at a time to the output
150 # first.
151 for prefix in sorted(lresult):
152 result.extend(lresult[prefix])
153 lresult = defaultdict(list)
155 # The line is common to each block so output with the common prefix.
156 result.append((common_prefix, line[0]))
157 else:
158 # The line is not common to each block, or we don't have a common prefix.
159 # If there are no prefixes available, warn and bail out.
160 if not prefixes[0]:
161 _warn('multiple lines not disambiguated by prefixes:\n{}\n'
162 'Some blocks may be skipped entirely as a result.'.format(
163 '\n'.join(' - {}'.format(l) for l in line)))
164 return []
166 # Iterate through the line from each of the blocks and add the line with
167 # the corresponding prefix to the current batch of results so that we can
168 # later output them per-prefix.
169 for i, l in enumerate(line):
170 for prefix in prefixes[i]:
171 lresult[prefix].append((prefix, l))
173 # Flush any remaining batched-up lines one prefix at a time to the output.
174 for prefix in sorted(lresult):
175 result.extend(lresult[prefix])
176 return result
179 def _get_useful_prefix_info(run_infos):
180 """ Given the run_infos, calculate any prefixes that are common to every one,
181 and the length of the longest prefix string.
183 try:
184 all_sets = [set(s) for s in list(zip(*run_infos))[0]]
185 common_to_all = set.intersection(*all_sets)
186 longest_prefix_len = max(len(p) for p in set.union(*all_sets))
187 except IndexError:
188 common_to_all = []
189 longest_prefix_len = 0
190 else:
191 if len(common_to_all) > 1:
192 _warn('Multiple prefixes common to all RUN lines: {}'.format(
193 common_to_all))
194 if common_to_all:
195 common_to_all = sorted(common_to_all)[0]
196 return common_to_all, longest_prefix_len
199 def _align_matching_blocks(all_blocks, farthest_indexes):
200 """ Some sub-sequences of blocks may be common to multiple lists of blocks,
201 but at different indexes in each one.
203 For example, in the following case, A,B,E,F, and H are common to both
204 sets, but only A and B would be identified as such due to the indexes
205 matching:
207 index | 0 1 2 3 4 5 6
208 ------+--------------
209 setA | A B C D E F H
210 setB | A B E F G H
212 This function attempts to align the indexes of matching blocks by
213 inserting empty blocks into the block list. With this approach, A, B, E,
214 F, and H would now be able to be identified as matching blocks:
216 index | 0 1 2 3 4 5 6 7
217 ------+----------------
218 setA | A B C D E F H
219 setB | A B E F G H
222 # "Farthest block analysis": essentially, iterate over all blocks and find
223 # the highest index into a block list for the first instance of each block.
224 # This is relatively expensive, but we're dealing with small numbers of
225 # blocks so it doesn't make a perceivable difference to user time.
226 for blocks in all_blocks.values():
227 for block in blocks:
228 if not block:
229 continue
231 index = blocks.index(block)
233 if index > farthest_indexes[block]:
234 farthest_indexes[block] = index
236 # Use the results of the above analysis to identify any blocks that can be
237 # shunted along to match the farthest index value.
238 for blocks in all_blocks.values():
239 for index, block in enumerate(blocks):
240 if not block:
241 continue
243 changed = False
244 # If the block has not already been subject to alignment (i.e. if the
245 # previous block is not empty) then insert empty blocks until the index
246 # matches the farthest index identified for that block.
247 if (index > 0) and blocks[index - 1]:
248 while(index < farthest_indexes[block]):
249 blocks.insert(index, '')
250 index += 1
251 changed = True
253 if changed:
254 # Bail out. We'll need to re-do the farthest block analysis now that
255 # we've inserted some blocks.
256 return True
258 return False
def _get_block_infos(run_infos, test_path, args, common_prefix):  # noqa
  """ For each run line, run the tool with the specified args and collect the
      output. We use the concept of 'blocks' for uniquing, where a block is
      a series of lines of text with no more than one newline character between
      each one. For example:

      This is
      one block

      This is
      another block

      This is yet another block

      We then build up a 'block_infos' structure containing a dict where the
      text of each block is the key and a list of the sets of prefixes that may
      generate that particular block. This then goes through a series of
      transformations to minimise the amount of CHECK lines that need to be
      written by taking advantage of common prefixes.
  """

  def _block_key(tool_args, prefixes):
    """ Get a hashable key based on the current tool_args and prefixes.
    """
    return ' '.join([tool_args] + prefixes)

  all_blocks = {}
  max_block_len = 0

  # A cache of the furthest-back position in any block list of the first
  # instance of each block, indexed by the block itself.
  farthest_indexes = defaultdict(int)

  # Run the tool for each run line to generate all of the blocks.
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    raw_tool_output = common.invoke_tool(args.llvm_mca_binary,
                                         tool_args,
                                         test_path)

    # Replace any lines consisting of purely whitespace with empty lines.
    raw_tool_output = '\n'.join(line if line.strip() else ''
                                for line in raw_tool_output.splitlines())

    # Split blocks, stripping all trailing whitespace, but keeping preceding
    # whitespace except for newlines so that columns will line up visually.
    all_blocks[key] = [b.lstrip('\n').rstrip()
                       for b in raw_tool_output.split('\n\n')]
    max_block_len = max(max_block_len, len(all_blocks[key]))

  # Attempt to align matching blocks until no more changes can be made.
  made_changes = True
  while made_changes:
    made_changes = _align_matching_blocks(all_blocks, farthest_indexes)

  # If necessary, pad the lists of blocks with empty blocks so that they are
  # all the same length.
  for key in all_blocks:
    len_to_pad = max_block_len - len(all_blocks[key])
    all_blocks[key] += [''] * len_to_pad

  # Create the block_infos structure where it is a nested dict in the form of:
  #   block number -> block text -> list of prefix sets
  block_infos = defaultdict(lambda: defaultdict(list))
  for prefixes, tool_args in run_infos:
    key = _block_key(tool_args, prefixes)
    for block_num, block_text in enumerate(all_blocks[key]):
      block_infos[block_num][block_text].append(set(prefixes))

  # Now go through the block_infos structure and attempt to smartly prune the
  # number of prefixes per block to the minimal set possible to output.
  for block_num in range(len(block_infos)):
    # When there are multiple block texts for a block num, remove any
    # prefixes that are common to more than one of them.
    # E.g. [ [{ALL,FOO}] , [{ALL,BAR}] ] -> [ [{FOO}] , [{BAR}] ]
    all_sets = list(block_infos[block_num].values())
    pruned_sets = []

    for i, setlist in enumerate(all_sets):
      other_set_values = set([elem for j, setlist2 in enumerate(all_sets)
                              for set_ in setlist2 for elem in set_
                              if i != j])
      pruned_sets.append([s - other_set_values for s in setlist])

    for i, block_text in enumerate(block_infos[block_num]):

      # When a block text matches multiple sets of prefixes, try removing any
      # prefixes that aren't common to all of them.
      # E.g. [ {ALL,FOO} , {ALL,BAR} ] -> [{ALL}]
      common_values = set.intersection(*pruned_sets[i])
      if common_values:
        pruned_sets[i] = [common_values]

      # Everything should be uniqued as much as possible by now. Apply the
      # newly pruned sets to the block_infos structure.
      # If there are any blocks of text that still match multiple prefixes,
      # output a warning.
      current_set = set()
      for s in pruned_sets[i]:
        s = sorted(list(s))
        if s:
          current_set.add(s[0])
          if len(s) > 1:
            _warn('Multiple prefixes generating same output: {} '
                  '(discarding {})'.format(','.join(s), ','.join(s[1:])))

      if block_text and not current_set:
        raise Error(
            'block not captured by existing prefixes:\n\n{}'.format(block_text))
      block_infos[block_num][block_text] = sorted(list(current_set))

    # If we have multiple block_texts, try to break them down further to avoid
    # the case where we have very similar block_texts repeated after each
    # other.
    if common_prefix and len(block_infos[block_num]) > 1:
      # We'll only attempt this if each of the block_texts have the same number
      # of lines as each other.
      same_num_lines = (len(set(len(k.splitlines())
                                for k in block_infos[block_num].keys())) == 1)
      if same_num_lines:
        breakdown = _break_down_block(block_infos[block_num], common_prefix)
        if breakdown:
          block_infos[block_num] = breakdown

  return block_infos
def _write_block(output, block, not_prefix_set, common_prefix, prefix_pad):
  """ Write one block of (prefix, line) pairs into output, padding each
      prefix so the check lines align.  Returns the set of prefixes that
      were actually written.
  """
  check_suffix = ': '
  last_prefix = None
  lines_for_prefix = 0
  written = set()

  for prefix, text in block:
    if prefix in not_prefix_set:
      _warn('not writing for prefix {0} due to presence of "{0}-NOT:" '
            'in input file.'.format(prefix))
      continue

    # If the previous line isn't already blank and we're writing more than one
    # line for the current prefix, output a blank line first, unless either
    # the current or previous prefix is common to all.
    lines_for_prefix += 1
    if prefix != last_prefix:
      if output and output[-1]:
        if lines_for_prefix > 1 or common_prefix in (prefix, last_prefix):
          output.append('')
      lines_for_prefix = 0
      last_prefix = prefix

    written.add(prefix)
    padding = ' ' * (prefix_pad - len(prefix))
    output.append('{} {}{}{} {}'.format(COMMENT_CHAR,
                                        prefix,
                                        check_suffix,
                                        padding,
                                        text).rstrip())
    # Subsequent lines of the same block must be contiguous in the output.
    check_suffix = '-NEXT:'

  output.append('')
  return written
def _write_output(test_path, input_lines, prefix_list, block_infos,  # noqa
                  args, common_prefix, prefix_pad):
  """ Rewrite test_path: preserve the non-autogenerated content of
      input_lines, then append fresh check lines derived from block_infos.

      Raises Error if any prefix from prefix_list would generate no checks.
      Writes nothing when the regenerated file is identical to the input.
  """
  prefix_set = set([prefix for prefixes, _ in prefix_list
                    for prefix in prefixes])
  not_prefix_set = set()

  output_lines = []
  for input_line in input_lines:
    # Drop any previously generated advert; a new one is inserted below.
    if input_line.startswith(ADVERT_PREFIX):
      continue

    if input_line.startswith(COMMENT_CHAR):
      m = common.CHECK_RE.match(input_line)
      try:
        prefix = m.group(1)
      except AttributeError:
        # Not a CHECK-style comment.
        prefix = None

      if '{}-NOT:'.format(prefix) in input_line:
        not_prefix_set.add(prefix)

      # Keep comment lines that are not our own check lines (foreign
      # prefixes, or prefixes the user explicitly suppressed with -NOT).
      if prefix not in prefix_set or prefix in not_prefix_set:
        output_lines.append(input_line)
        continue

    if common.should_add_line_to_output(input_line, prefix_set):
      # This input line of the function body will go as-is into the output.
      # Except make leading whitespace uniform: 2 spaces.
      input_line = common.SCRUB_LEADING_WHITESPACE_RE.sub(r'  ', input_line)

      # Skip empty lines if the previous output line is also empty (guard the
      # index: the very first kept line may be blank while output is empty).
      if input_line or (output_lines and output_lines[-1]):
        output_lines.append(input_line)

  # Add a blank line before the new checks if required.
  if len(output_lines) > 0 and output_lines[-1]:
    output_lines.append('')

  output_check_lines = []
  used_prefixes = set()
  for block_num in range(len(block_infos)):
    if type(block_infos[block_num]) is list:
      # The block is of the type output from _break_down_block().
      used_prefixes |= _write_block(output_check_lines,
                                    block_infos[block_num],
                                    not_prefix_set,
                                    common_prefix,
                                    prefix_pad)
    else:
      # _break_down_block() was unable to do anything so output the block
      # as-is.

      # Rather than writing out each block as soon we encounter it, save it
      # indexed by prefix so that we can write all of the blocks out sorted by
      # prefix at the end.
      output_blocks = defaultdict(list)

      for block_text in sorted(block_infos[block_num]):

        if not block_text:
          continue

        lines = block_text.split('\n')
        for prefix in block_infos[block_num][block_text]:
          assert prefix not in output_blocks
          used_prefixes |= _write_block(output_blocks[prefix],
                                        [(prefix, line) for line in lines],
                                        not_prefix_set,
                                        common_prefix,
                                        prefix_pad)

      for prefix in sorted(output_blocks):
        output_check_lines.extend(output_blocks[prefix])

  unused_prefixes = (prefix_set - not_prefix_set) - used_prefixes
  if unused_prefixes:
    raise Error('unused prefixes: {}'.format(sorted(unused_prefixes)))

  if output_check_lines:
    output_lines.insert(0, ADVERT)
    output_lines.extend(output_check_lines)

  # The file should not end with two newlines. It creates unnecessary churn.
  while len(output_lines) > 0 and output_lines[-1] == '':
    output_lines.pop()

  if input_lines == output_lines:
    sys.stderr.write(' [unchanged]\n')
    return
  sys.stderr.write(' [{} lines total]\n'.format(len(output_lines)))

  common.debug('Writing', len(output_lines), 'lines to', test_path, '..\n\n')

  with open(test_path, 'wb') as f:
    f.writelines(['{}\n'.format(l).encode('utf-8') for l in output_lines])
def main():
  """Expand each test-path glob and regenerate the checks in every matched
  file.  Returns 0 on success; raises Error on any per-test failure.
  """
  args = _parse_args()
  matched_paths = [path
                   for pattern in args.tests
                   for path in glob.glob(pattern)]
  for test_path in matched_paths:
    sys.stderr.write('Test: {}\n'.format(test_path))

    # Call this per test. By default each warning will only be written once
    # per source location. Reset the warning filter so that now each warning
    # will be written once per source location per test.
    _configure_warnings(args)

    if not os.path.isfile(test_path):
      raise Error('could not find test file: {}'.format(test_path))

    with open(test_path) as f:
      input_lines = [l.rstrip() for l in f]

    run_lines = common.find_run_lines(test_path, input_lines)
    run_infos = _get_run_infos(run_lines, args)
    common_prefix, prefix_pad = _get_useful_prefix_info(run_infos)
    block_infos = _get_block_infos(run_infos, test_path, args, common_prefix)
    _write_output(test_path,
                  input_lines,
                  run_infos,
                  block_infos,
                  args,
                  common_prefix,
                  prefix_pad)

  return 0
if __name__ == '__main__':
  try:
    # Install the custom printer so warnings without explicit line text do
    # not echo a source line (see _showwarning above).
    warnings.showwarning = _showwarning
    sys.exit(main())
  except Error as e:
    # Script-level errors are reported plainly, without a traceback.
    sys.stdout.write('error: {}\n'.format(e))
    sys.exit(1)