Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / utils / update_mca_test_checks.py
blob 486cb66b827f3145ecfd270d47b5818346dc0bc7
1 #!/usr/bin/env python3
3 """A test case update script.
5 This script is a utility to update LLVM 'llvm-mca' based test cases with new
6 FileCheck patterns.
7 """
9 import argparse
10 from collections import defaultdict
11 import glob
12 import os
13 import sys
14 import warnings
16 from UpdateTestChecks import common
# Comment character of llvm-mca assembly tests; every CHECK line starts with it.
COMMENT_CHAR = "#"
# Start of the banner line this script puts at the top of a regenerated test.
# Input lines beginning with it are dropped, so the banner is never duplicated.
ADVERT_PREFIX = COMMENT_CHAR + " NOTE: Assertions have been autogenerated by "
# The complete banner line, naming this script as utils/<file name>.
ADVERT = ADVERT_PREFIX + "utils/" + os.path.basename(__file__)
class Error(Exception):
    """Fatal, user-facing error.

    The script's ``__main__`` handler catches this type and reports it as a
    one-line ``error:`` message with exit status 1 instead of a traceback.
    """
30 def _warn(msg):
31 """Log a user warning to stderr."""
32 warnings.warn(msg, Warning, stacklevel=2)
35 def _configure_warnings(args):
36 warnings.resetwarnings()
37 if args.w:
38 warnings.simplefilter("ignore")
39 if args.Werror:
40 warnings.simplefilter("error")
43 def _showwarning(message, category, filename, lineno, file=None, line=None):
44 """Version of warnings.showwarning that won't attempt to print out the
45 line at the location of the warning if the line text is not explicitly
46 specified.
47 """
48 if file is None:
49 file = sys.stderr
50 if line is None:
51 line = ""
52 file.write(warnings.formatwarning(message, category, filename, lineno, line))
55 def _get_parser():
56 parser = argparse.ArgumentParser(description=__doc__)
57 parser.add_argument("-w", action="store_true", help="suppress warnings")
58 parser.add_argument(
59 "-Werror", action="store_true", help="promote warnings to errors"
61 parser.add_argument(
62 "--llvm-mca-binary",
63 metavar="<path>",
64 default="llvm-mca",
65 help="the binary to use to generate the test case " "(default: llvm-mca)",
67 parser.add_argument("tests", metavar="<test-path>", nargs="+")
68 return parser
70 def _get_run_infos(run_lines, args):
71 run_infos = []
72 for run_line in run_lines:
73 try:
74 (tool_cmd, filecheck_cmd) = tuple(
75 [cmd.strip() for cmd in run_line.split("|", 1)]
77 except ValueError:
78 _warn("could not split tool and filecheck commands: {}".format(run_line))
79 continue
81 common.verify_filecheck_prefixes(filecheck_cmd)
82 tool_basename = os.path.splitext(os.path.basename(args.llvm_mca_binary))[0]
84 if not tool_cmd.startswith(tool_basename + " "):
85 _warn("skipping non-{} RUN line: {}".format(tool_basename, run_line))
86 continue
88 if not filecheck_cmd.startswith("FileCheck "):
89 _warn("skipping non-FileCheck RUN line: {}".format(run_line))
90 continue
92 tool_cmd_args = tool_cmd[len(tool_basename) :].strip()
93 tool_cmd_args = tool_cmd_args.replace("< %s", "").replace("%s", "").strip()
95 check_prefixes = common.get_check_prefixes(filecheck_cmd)
97 run_infos.append((check_prefixes, tool_cmd_args))
99 return run_infos
102 def _break_down_block(block_info, common_prefix):
103 """Given a block_info, see if we can analyze it further to let us break it
104 down by prefix per-line rather than per-block.
106 texts = block_info.keys()
107 prefixes = list(block_info.values())
108 # Split the lines from each of the incoming block_texts and zip them so that
109 # each element contains the corresponding lines from each text. E.g.
111 # block_text_1: A # line 1
112 # B # line 2
114 # block_text_2: A # line 1
115 # C # line 2
117 # would become:
119 # [(A, A), # line 1
120 # (B, C)] # line 2
122 line_tuples = list(zip(*list((text.splitlines() for text in texts))))
124 # To simplify output, we'll only proceed if the very first line of the block
125 # texts is common to each of them.
126 if len(set(line_tuples[0])) != 1:
127 return []
129 result = []
130 lresult = defaultdict(list)
131 for i, line in enumerate(line_tuples):
132 if len(set(line)) == 1:
133 # We're about to output a line with the common prefix. This is a sync
134 # point so flush any batched-up lines one prefix at a time to the output
135 # first.
136 for prefix in sorted(lresult):
137 result.extend(lresult[prefix])
138 lresult = defaultdict(list)
140 # The line is common to each block so output with the common prefix.
141 result.append((common_prefix, line[0]))
142 else:
143 # The line is not common to each block, or we don't have a common prefix.
144 # If there are no prefixes available, warn and bail out.
145 if not prefixes[0]:
146 _warn(
147 "multiple lines not disambiguated by prefixes:\n{}\n"
148 "Some blocks may be skipped entirely as a result.".format(
149 "\n".join(" - {}".format(l) for l in line)
152 return []
154 # Iterate through the line from each of the blocks and add the line with
155 # the corresponding prefix to the current batch of results so that we can
156 # later output them per-prefix.
157 for i, l in enumerate(line):
158 for prefix in prefixes[i]:
159 lresult[prefix].append((prefix, l))
161 # Flush any remaining batched-up lines one prefix at a time to the output.
162 for prefix in sorted(lresult):
163 result.extend(lresult[prefix])
164 return result
167 def _get_useful_prefix_info(run_infos):
168 """Given the run_infos, calculate any prefixes that are common to every one,
169 and the length of the longest prefix string.
171 try:
172 all_sets = [set(s) for s in list(zip(*run_infos))[0]]
173 common_to_all = set.intersection(*all_sets)
174 longest_prefix_len = max(len(p) for p in set.union(*all_sets))
175 except IndexError:
176 common_to_all = []
177 longest_prefix_len = 0
178 else:
179 if len(common_to_all) > 1:
180 _warn("Multiple prefixes common to all RUN lines: {}".format(common_to_all))
181 if common_to_all:
182 common_to_all = sorted(common_to_all)[0]
183 return common_to_all, longest_prefix_len
def _align_matching_blocks(all_blocks, farthest_indexes):
    """Some sub-sequences of blocks may be common to multiple lists of blocks,
    but at different indexes in each one.

    For example, in the following case, A,B,E,F, and H are common to both
    sets, but only A and B would be identified as such due to the indexes
    matching:

    index | 0 1 2 3 4 5 6
    ------+--------------
    setA  | A B C D E F H
    setB  | A B E F G H

    This function attempts to align the indexes of matching blocks by
    inserting empty blocks into the block list. With this approach, A, B, E,
    F, and H would now be able to be identified as matching blocks:

    index | 0 1 2 3 4 5 6 7
    ------+----------------
    setA  | A B C D E F   H
    setB  | A B     E F G H

    Args:
        all_blocks: dict whose values are lists of block texts. The lists are
            modified in place by inserting empty-string blocks.
        farthest_indexes: mapping from block text to the highest index seen
            for that block's first occurrence in any list. It is read before
            being written for unseen blocks, so it must supply a default for
            missing keys (the caller passes ``defaultdict(int)``). It is
            updated in place and carried across calls.

    Returns:
        True if at least one empty block was inserted. The function stops
        after the first list it changes, so the caller should call again
        until False is returned.
    """
    # "Farthest block analysis": essentially, iterate over all blocks and find
    # the highest index into a block list for the first instance of each block.
    # This is relatively expensive, but we're dealing with small numbers of
    # blocks so it doesn't make a perceivable difference to user time.
    for blocks in all_blocks.values():
        for block in blocks:
            # Empty blocks are padding and are never aligned.
            if not block:
                continue

            # index() returns the position of the first instance of the block.
            index = blocks.index(block)

            if index > farthest_indexes[block]:
                farthest_indexes[block] = index

    # Use the results of the above analysis to identify any blocks that can be
    # shunted along to match the farthest index value.
    for blocks in all_blocks.values():
        for index, block in enumerate(blocks):
            if not block:
                continue

            changed = False
            # If the block has not already been subject to alignment (i.e. if the
            # previous block is not empty) then insert empty blocks until the index
            # matches the farthest index identified for that block.
            # Note: a block at index 0 is never shunted.
            if (index > 0) and blocks[index - 1]:
                while index < farthest_indexes[block]:
                    blocks.insert(index, "")
                    index += 1
                    changed = True

            if changed:
                # Bail out. We'll need to re-do the farthest block analysis now that
                # we've inserted some blocks (this also avoids continuing to
                # enumerate a list that was just mutated).
                return True

    return False
def _get_block_infos(run_infos, test_path, args, common_prefix):  # noqa
    """For each run line, run the tool with the specified args and collect the
    output. We use the concept of 'blocks' for uniquing, where a block is
    a series of lines of text with no more than one newline character between
    each one. For example:

    This
    is
    one
    block

    This is
    another block

    This is yet another block

    We then build up a 'block_infos' structure containing a dict where the
    text of each block is the key and a list of the sets of prefixes that may
    generate that particular block. This then goes through a series of
    transformations to minimise the amount of CHECK lines that need to be
    written by taking advantage of common prefixes.

    Args:
        run_infos: list of ``(check_prefixes, tool_args)`` pairs.
        test_path: test file passed to the tool as input.
        args: parsed options; ``args.llvm_mca_binary`` is the tool to invoke.
        common_prefix: prefix shared by every RUN line, or a falsy value.

    Returns:
        A mapping from block number to either a dict of
        ``{block_text: sorted list of prefixes}`` or, when the block could be
        broken down per line, a list of ``(prefix, line)`` tuples.

    Raises:
        Error: if a non-empty block cannot be attributed to any prefix.
    """

    def _block_key(tool_args, prefixes):
        """Get a hashable key based on the current tool_args and prefixes."""
        return " ".join([tool_args] + prefixes)

    all_blocks = {}

    # A cache of the furthest-back position in any block list of the first
    # instance of each block, indexed by the block itself.
    farthest_indexes = defaultdict(int)

    # Run the tool for each run line to generate all of the blocks.
    for prefixes, tool_args in run_infos:
        key = _block_key(tool_args, prefixes)
        raw_tool_output = common.invoke_tool(args.llvm_mca_binary, tool_args, test_path)

        # Replace any lines consisting of purely whitespace with empty lines.
        raw_tool_output = "\n".join(
            line if line.strip() else "" for line in raw_tool_output.splitlines()
        )

        # Split blocks, stripping all trailing whitespace, but keeping preceding
        # whitespace except for newlines so that columns will line up visually.
        all_blocks[key] = [
            b.lstrip("\n").rstrip() for b in raw_tool_output.split("\n\n")
        ]

    # Attempt to align matching blocks until no more changes can be made.
    made_changes = True
    while made_changes:
        made_changes = _align_matching_blocks(all_blocks, farthest_indexes)

    # Alignment inserts empty blocks, so the longest list is only known now.
    # Measuring it before alignment could leave some lists shorter than
    # others, and their trailing blocks would then be attributed only to the
    # runs that produced them.
    max_block_len = max((len(blocks) for blocks in all_blocks.values()), default=0)

    # Pad the lists of blocks with empty blocks so that they are all the same
    # length.
    for blocks in all_blocks.values():
        blocks.extend([""] * (max_block_len - len(blocks)))

    # Create the block_infos structure where it is a nested dict in the form of:
    # block number -> block text -> list of prefix sets
    block_infos = defaultdict(lambda: defaultdict(list))
    for prefixes, tool_args in run_infos:
        key = _block_key(tool_args, prefixes)
        for block_num, block_text in enumerate(all_blocks[key]):
            block_infos[block_num][block_text].append(set(prefixes))

    # Now go through the block_infos structure and attempt to smartly prune the
    # number of prefixes per block to the minimal set possible to output.
    for block_num in range(len(block_infos)):
        text_to_sets = block_infos[block_num]

        # When there are multiple block texts for a block num, remove any
        # prefixes that are common to more than one of them.
        # E.g. [ [{ALL,FOO}] , [{ALL,BAR}] ] -> [ [{FOO}] , [{BAR}] ]
        all_sets = list(text_to_sets.values())
        pruned_sets = []
        for i, setlist in enumerate(all_sets):
            other_set_values = {
                elem
                for j, setlist2 in enumerate(all_sets)
                if j != i
                for set_ in setlist2
                for elem in set_
            }
            pruned_sets.append([s - other_set_values for s in setlist])

        for i, block_text in enumerate(text_to_sets):
            # When a block text matches multiple sets of prefixes, try removing
            # any prefixes that aren't common to all of them.
            # E.g. [ {ALL,FOO} , {ALL,BAR} ] -> [{ALL}]
            common_values = set.intersection(*pruned_sets[i])
            if common_values:
                pruned_sets[i] = [common_values]

            # Everything should be uniqued as much as possible by now. Apply the
            # newly pruned sets to the block_infos structure.
            # If there are any blocks of text that still match multiple
            # prefixes, keep the alphabetically first one and warn.
            current_set = set()
            for s in pruned_sets[i]:
                s = sorted(s)
                if s:
                    current_set.add(s[0])
                    if len(s) > 1:
                        _warn(
                            "Multiple prefixes generating same output: {} "
                            "(discarding {})".format(",".join(s), ",".join(s[1:]))
                        )

            if block_text and not current_set:
                raise Error(
                    "block not captured by existing prefixes:\n\n{}".format(block_text)
                )
            # Replacing the value of an existing key does not resize the dict,
            # so this is safe while iterating over its keys.
            text_to_sets[block_text] = sorted(current_set)

        # If we have multiple block_texts, try to break them down further to
        # avoid the case where we have very similar block_texts repeated after
        # each other.
        if common_prefix and len(text_to_sets) > 1:
            # We'll only attempt this if each of the block_texts have the same
            # number of lines as each other.
            line_counts = {len(text.splitlines()) for text in text_to_sets}
            if len(line_counts) == 1:
                breakdown = _break_down_block(text_to_sets, common_prefix)
                if breakdown:
                    block_infos[block_num] = breakdown

    return block_infos
def _write_block(output, block, not_prefix_set, common_prefix, prefix_pad):
    """Append the CHECK lines for one block to *output*.

    Args:
        output: list of output lines, extended in place.
        block: iterable of ``(prefix, line)`` pairs to emit, in order.
        not_prefix_set: prefixes that have a ``<PREFIX>-NOT:`` line in the
            input file; their pairs are skipped with a warning.
        common_prefix: the prefix shared by every RUN line, if any.
        prefix_pad: width each prefix is padded to, so the lines align.

    The first emitted line uses ``<PREFIX>:`` and every later one uses
    ``<PREFIX>-NEXT:``. When the prefix changes and *output* is non-empty and
    does not already end in a blank line, a blank separator is inserted first
    if the previous prefix wrote several lines in a row or if either the
    current or previous prefix is *common_prefix*. A blank line always
    terminates the block.

    Returns:
        The set of prefixes for which at least one line was written.
    """
    suffix = ": "
    last_prefix = None
    since_switch = 0
    written = set()

    for prefix, line in block:
        if prefix in not_prefix_set:
            _warn(
                'not writing for prefix {0} due to presence of "{0}-NOT:" '
                "in input file.".format(prefix)
            )
            continue

        # Lines seen since the last prefix switch, including this one.
        since_switch += 1
        if prefix != last_prefix:
            touches_common = prefix == common_prefix or last_prefix == common_prefix
            if (since_switch > 1 or touches_common) and output and output[-1]:
                output.append("")
            since_switch = 0
            last_prefix = prefix

        padding = " " * (prefix_pad - len(prefix))
        output.append(
            "{} {}{}{} {}".format(COMMENT_CHAR, prefix, suffix, padding, line).rstrip()
        )
        written.add(prefix)
        suffix = "-NEXT:"

    output.append("")
    return written
def _write_output(
    test_path,
    input_lines,
    prefix_list,
    block_infos,  # noqa
    args,
    common_prefix,
    prefix_pad,
):
    """Rewrite *test_path* with freshly generated CHECK lines.

    Args:
        test_path: the test file to rewrite.
        input_lines: the file's current lines, without line terminators.
        prefix_list: ``(check_prefixes, tool_args)`` pairs for the RUN lines.
        block_infos: per-block prefix information from ``_get_block_infos``.
        args: unused here; kept for interface compatibility.
        common_prefix: prefix shared by every RUN line, or a falsy value.
        prefix_pad: width prefixes are padded to in CHECK lines.

    Any existing autogenerated banner is dropped. Comment lines whose prefix
    is not one of ours, or that belong to a prefix with a ``-NOT:`` line, are
    kept verbatim. Other lines are filtered through
    ``common.should_add_line_to_output``, which presumably drops stale CHECK
    lines. The new CHECK lines are then appended after a blank line, and the
    file is written as UTF-8 only if its content changed.

    Raises:
        Error: if some RUN-line prefix ends up with no CHECK lines.
    """
    prefix_set = {prefix for prefixes, _ in prefix_list for prefix in prefixes}
    not_prefix_set = set()

    output_lines = []
    for input_line in input_lines:
        if input_line.startswith(ADVERT_PREFIX):
            continue

        if input_line.startswith(COMMENT_CHAR):
            m = common.CHECK_RE.match(input_line)
            prefix = m.group(1) if m else None

            if "{}-NOT:".format(prefix) in input_line:
                not_prefix_set.add(prefix)

            if prefix not in prefix_set or prefix in not_prefix_set:
                output_lines.append(input_line)
                continue

        if common.should_add_line_to_output(input_line, prefix_set):
            # This input line of the function body will go as-is into the output.
            # Except make leading whitespace uniform: 2 spaces.
            input_line = common.SCRUB_LEADING_WHITESPACE_RE.sub(r"  ", input_line)

            # Skip empty lines if the previous output line is also empty, or if
            # nothing has been output yet (avoids indexing an empty list).
            if input_line or (output_lines and output_lines[-1]):
                output_lines.append(input_line)

    # Add a blank line before the new checks if required.
    if output_lines and output_lines[-1]:
        output_lines.append("")

    output_check_lines = []
    used_prefixes = set()
    for block_num in range(len(block_infos)):
        block_info = block_infos[block_num]
        if isinstance(block_info, list):
            # The block is of the type output from _break_down_block().
            used_prefixes |= _write_block(
                output_check_lines,
                block_info,
                not_prefix_set,
                common_prefix,
                prefix_pad,
            )
            continue

        # _break_down_block() was unable to do anything, so output the block
        # as-is. Rather than writing out each block as soon as we encounter it,
        # save it indexed by prefix so that we can write all of the blocks out
        # sorted by prefix at the end.
        output_blocks = defaultdict(list)
        for block_text in sorted(block_info):
            if not block_text:
                continue

            lines = block_text.split("\n")
            for prefix in block_info[block_text]:
                # Pruning leaves each prefix attached to at most one text.
                assert prefix not in output_blocks
                used_prefixes |= _write_block(
                    output_blocks[prefix],
                    [(prefix, line) for line in lines],
                    not_prefix_set,
                    common_prefix,
                    prefix_pad,
                )

        for prefix in sorted(output_blocks):
            output_check_lines.extend(output_blocks[prefix])

    unused_prefixes = (prefix_set - not_prefix_set) - used_prefixes
    if unused_prefixes:
        raise Error("unused prefixes: {}".format(sorted(unused_prefixes)))

    if output_check_lines:
        output_lines.insert(0, ADVERT)
        output_lines.extend(output_check_lines)

    # The file should not end with two newlines. It creates unnecessary churn.
    while output_lines and output_lines[-1] == "":
        output_lines.pop()

    if input_lines == output_lines:
        sys.stderr.write(" [unchanged]\n")
        return
    sys.stderr.write(" [{} lines total]\n".format(len(output_lines)))

    common.debug("Writing", len(output_lines), "lines to", test_path, "..\n\n")

    with open(test_path, "wb") as f:
        f.writelines(["{}\n".format(l).encode("utf-8") for l in output_lines])
def update_test_file(args, test_path, autogenerated_note):
    """Regenerate the llvm-mca CHECK lines of one test file in place.

    Args:
        args: parsed command-line options for this test.
        test_path: path of the test file to update.
        autogenerated_note: currently unused by this function.
    """
    sys.stderr.write("Test: {}\n".format(test_path))

    # Call this per test. By default each warning will only be written once
    # per source location. Reset the warning filter so that now each warning
    # will be written once per source location per test.
    _configure_warnings(args)

    # Decode as UTF-8: _write_output always re-encodes the file as UTF-8, so
    # decoding with the locale's default encoding could corrupt non-ASCII text.
    with open(test_path, encoding="utf-8") as f:
        input_lines = [l.rstrip() for l in f]

    run_lines = common.find_run_lines(test_path, input_lines)
    run_infos = _get_run_infos(run_lines, args)
    common_prefix, prefix_pad = _get_useful_prefix_info(run_infos)
    block_infos = _get_block_infos(run_infos, test_path, args, common_prefix)
    _write_output(
        test_path,
        input_lines,
        run_infos,
        block_infos,
        args,
        common_prefix,
        prefix_pad,
    )
def main():
    """Update every test named on the command line.

    Returns:
        0 once every test has been processed.

    Raises:
        Error: if ``--llvm-mca-binary`` is empty.
        Any exception raised while updating a test is re-raised after
        reporting which file it came from.
    """
    parser = _get_parser()
    args = common.parse_commandline_args(parser)

    mca_binary = args.llvm_mca_binary
    if not mca_binary:
        raise Error("--llvm-mca-binary value cannot be empty string")
    if "llvm-mca" not in os.path.basename(mca_binary):
        _warn("unexpected binary name: {}".format(mca_binary))

    script_name = "utils/" + os.path.basename(__file__)
    for test_info in common.itertests(args.tests, parser, script_name=script_name):
        try:
            update_test_file(
                test_info.args, test_info.path, test_info.test_autogenerated_note
            )
        except Exception:
            common.warn("Error processing file", test_file=test_info.path)
            raise
    return 0
if __name__ == "__main__":
    try:
        # Print warnings without echoing the source line at their location.
        warnings.showwarning = _showwarning
        sys.exit(main())
    except Error as e:
        # Report user-facing errors without a traceback. They go to stderr,
        # like the warnings and per-test progress messages.
        sys.stderr.write("error: {}\n".format(e))
        sys.exit(1)