[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / utils / spirv / gen_spirv_dialect.py
blob426bfca1b4f88f376cfceef703ba6b55d7778c48
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
4 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 # See https://llvm.org/LICENSE.txt for license information.
6 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 # Script for updating SPIR-V dialect by scraping information from SPIR-V
9 # HTML and JSON specs from the Internet.
11 # For example, to define the enum attribute for SPIR-V memory model:
13 # ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
14 # --new-enum MemoryModel
16 # The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
17 # SPIR-V enum classes.
19 import itertools
20 import math
21 import re
22 import requests
23 import textwrap
24 import yaml
# Locations of the SPIR-V core specification (HTML documentation and JSON
# grammar).
SPIRV_HTML_SPEC_URL = "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html"
SPIRV_JSON_SPEC_URL = (
    "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/"
    "include/spirv/unified1/spirv.core.grammar.json"
)

# Locations of the OpenCL extended instruction set specification (HTML
# documentation and JSON grammar).
SPIRV_CL_EXT_HTML_SPEC_URL = (
    "https://www.khronos.org/registry/SPIR-V/specs/unified1/"
    "OpenCL.ExtendedInstructionSet.100.html"
)
SPIRV_CL_EXT_JSON_SPEC_URL = (
    "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/"
    "include/spirv/unified1/extinst.opencl.std.100.grammar.json"
)

# Text used to split existing .td files into their auto-generated parts: the
# separator between op definitions and the markers that open/close the enum
# and opcode sections.
AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n"
AUTOGEN_ENUM_SECTION_MARKER = (
    "enum section. Generated from SPIR-V spec; DO NOT MODIFY!"
)
AUTOGEN_OPCODE_SECTION_MARKER = "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!"
def get_spirv_doc_from_html_spec(url, settings):
    """Scrapes per-instruction documentation out of a SPIR-V HTML spec.

    Arguments:
      - url: address of the HTML spec to download; when None, the core spec at
        SPIRV_HTML_SPEC_URL is used.
      - settings: object whose `gen_cl_ops` attribute selects which page layout
        to parse.

    Returns:
      - A dict mapping from the anchor id of each instruction table (used as
        the opname) to its documentation text with the first line removed.
    """
    page_url = SPIRV_HTML_SPEC_URL if url is None else url
    html = requests.get(page_url).content

    # Imported lazily, after the download, so bs4 is only needed on this path.
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # The two layouts differ only in the heading that anchors the instruction
    # list, the class of the per-section <div>, and whether the opname anchor
    # lives directly in the table cell or in a <p> inside it.
    parse_cl_layout = bool(settings.gen_cl_ops)
    if parse_cl_layout:
        heading_tag, heading_id, section_class = "h2", "_binary_form", "sect2"
    else:
        heading_tag, heading_id, section_class = "h3", "_instructions_3", "sect3"

    anchor = soup.find(heading_tag, {"id": heading_id})
    docs = {}
    for section in anchor.parent.find_all("div", {"class": section_class}):
        for table in section.find_all("table"):
            cell = table.tbody.tr.td
            inst_html = cell if parse_cl_layout else cell.p
            opname = inst_html.a["id"]
            # Ignore the first line, which is just the opname.
            docs[opname] = inst_html.text.split("\n", 1)[1].strip()

    return docs
def get_spirv_grammar_from_json_spec(url):
    """Downloads operand kind and instruction grammar from SPIR-V JSON specs.

    The core grammar at SPIRV_JSON_SPEC_URL is always fetched because it is the
    source of the operand kinds. When `url` is given, the instructions are taken
    from that grammar instead of the core one.

    Arguments:
      - url: address of an additional JSON grammar, or None for the core
        grammar only.

    Returns:
      - A list containing all operand kinds' grammar
      - A list containing all instructions' grammar
    """
    import json

    core_grammar = json.loads(requests.get(SPIRV_JSON_SPEC_URL).content)
    if url is None:
        return core_grammar["operand_kinds"], core_grammar["instructions"]

    ext_grammar = json.loads(requests.get(url).content)
    return core_grammar["operand_kinds"], ext_grammar["instructions"]
def split_list_into_sublists(items):
    """Split the list of items into multiple sublists.

    This is to make sure the string composed from each sublist won't exceed
    80 characters. Every item is budgeted as its own length plus two extra
    characters. An item that is too long on its own is still emitted, alone in
    its sublist; no empty sublist is ever produced.

    Arguments:
      - items: a list of strings

    Returns:
      - A list of sublists that together contain all items in their original
        order.
    """
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        item_len = len(item) + 2
        # Only flush a sublist that actually holds something. Without the
        # `chunk` guard an oversized first item would flush an empty sublist.
        if chunk and chunk_len + item_len > 80:
            chunks.append(chunk)
            chunk = []
            chunk_len = 0
        chunk.append(item)
        chunk_len += item_len

    if chunk:
        chunks.append(chunk)

    return chunks
def uniquify_enum_cases(lst):
    """Prunes duplicate enum cases from the list.

    Arguments:
      - lst: List whose elements are to be uniqued. Each element is a
        (symbol, value) pair. The list is not modified.

    Returns:
      - A list with all duplicates removed. The elements are sorted according to
        value and, for each value, only one symbol is kept.
      - A map from each dropped (duplicated) symbol to the symbol kept for its
        value.
    """
    uniqued_cases = []
    duplicated_cases = {}

    # First sort according to the value. sorted() works on a copy so that the
    # caller's list keeps its order.
    cases = sorted(lst, key=lambda x: x[1])

    # Then group them according to the value
    for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
        # For each value, sort according to the enumerant symbol.
        sorted_group = sorted(groups, key=lambda x: x[0])
        # Keep the "smallest" case, which is typically the symbol without extension
        # suffix. But we have special cases that we want to fix.
        kept_index = 0
        if sorted_group[0][0] == "HlslSemanticGOOGLE":
            assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
            kept_index = 1
        case = sorted_group[kept_index]
        # Map every other symbol in the group to the kept one. The kept case is
        # chosen first so the map never points from the kept symbol to a
        # dropped one.
        for index, other in enumerate(sorted_group):
            if index != kept_index:
                duplicated_cases[other[0]] = case[0]
        uniqued_cases.append(case)

    return uniqued_cases, duplicated_cases
def toposort(dag, sort_fn):
    """Topologically sorts the given dag.

    Nodes are emitted in batches: each batch holds every node that currently has
    no incoming nodes left, ordered by `sort_fn`. The batch is then removed from
    the graph and the process repeats.

    Arguments:
      - dag: a dict mapping from a node to the set of its incoming nodes.
      - sort_fn: a key function for ordering nodes in the same batch.

    Returns:
      A list containing topologically sorted nodes.

    Raises:
      AssertionError if nodes remain but none of them is free of incoming
      nodes (reported as a cyclic dependency).
    """
    ordered = []
    remaining = dag

    while True:
        ready = {node for node, incoming in remaining.items() if not incoming}
        if not ready:
            break
        ordered.extend(sorted(ready, key=sort_fn))
        # Drop the emitted nodes, both as keys and as incoming edges.
        remaining = {
            node: incoming - ready
            for node, incoming in remaining.items()
            if node not in ready
        }

    assert not remaining, "found cyclic dependency"
    return ordered
def toposort_capabilities(all_cases, capability_mapping):
    """Returns topologically sorted capability (symbol, value) pairs.

    Arguments:
      - all_cases: all capability cases (containing symbol, value, and implied
        capabilities).
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.

    Returns:
      A list containing topologically sorted capability (symbol, value) pairs.
      Symbols that appear as keys of `capability_mapping` are left out.
    """
    name_to_value = {}
    dag = {}
    for case in all_cases:
        symbol = case["enumerant"]
        # Record the value even for duplicated symbols.
        name_to_value[symbol] = case["value"]
        if symbol not in capability_mapping:
            # The capabilities listed for this case become its incoming nodes,
            # after replacing duplicated symbols by their canonical form.
            implied = case.get("capabilities", [])
            dag[symbol] = {capability_mapping.get(c, c) for c in implied}

    # Nodes within one batch are ordered by their numeric value.
    ordered = toposort(dag, lambda name: name_to_value[name])
    # Attach the capability's value as the second component of the pair.
    return [(name, name_to_value[name]) for name in ordered]
def get_capability_mapping(operand_kinds):
    """Returns the capability mapping from duplicated cases to canonicalized ones.

    Arguments:
      - operand_kinds: all operand kinds' grammar spec

    Returns:
      - A map mapping from duplicated capability symbols to the canonicalized
        symbol chosen for SPIRVBase.td.
    """
    # Find the operand kind for capability. If several entries match, the last
    # one is used.
    capability_kind = {}
    for candidate in operand_kinds:
        if candidate["kind"] == "Capability":
            capability_kind = candidate

    symbol_value_pairs = [
        (enumerant["enumerant"], enumerant["value"])
        for enumerant in capability_kind["enumerants"]
    ]
    # Only the duplicate map is of interest here; the uniqued list is dropped.
    return uniquify_enum_cases(symbol_value_pairs)[1]
def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
    """Returns the availability specification string for the given enum case.

    Arguments:
      - enum_case: the enum case to generate availability spec for. It may contain
        'version', 'lastVersion', 'extensions', or 'capabilities'.
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.
      - for_op: bool value indicating whether this is the availability spec for an
        op itself.
      - for_cap: bool value indicating whether this is the availability spec for
        capabilities themselves.

    Returns:
      - A `let availability = [...];` string if with availability spec or
        empty string if without availability spec
    """
    assert not (for_op and for_cap), "cannot set both for_op and for_cap"

    # NOTE(review): the whitespace inside the templates below is reproduced as
    # it appears in the reviewed text, where runs of spaces may have been
    # collapsed -- confirm the indentation against the real file.
    default_min = "MinVersion<SPIRV_V_1_0>"
    default_max = "MaxVersion<SPIRV_V_1_6>"
    default_cap = "Capability<[]>"
    default_ext = "Extension<[]>"

    def version_spec(bound, version):
        # "1.3" -> "<bound><SPIRV_V_1_3>"
        return "{}<SPIRV_V_{}>".format(bound, version.replace(".", "_"))

    min_version = enum_case.get("version", "")
    if min_version == "None":
        min_version = ""
    elif min_version:
        min_version = version_spec("MinVersion", min_version)

    max_version = enum_case.get("lastVersion", "")
    if max_version:
        max_version = version_spec("MaxVersion", max_version)

    extensions = enum_case.get("extensions", [])
    if extensions:
        extensions = "Extension<[{}]>".format(", ".join(sorted(set(extensions))))
        # We need to strip the minimal version requirement if this symbol is
        # available via an extension, which means *any* SPIR-V version can support
        # it as long as the extension is provided. The grammar's 'version' field
        # under such case should be interpreted as this symbol is introduced as
        # a core symbol since the given version, rather than a minimal version
        # requirement. (For ops the default is filled back in below.)
        min_version = ""

    implies = ""
    capabilities = enum_case.get("capabilities", [])
    if capabilities:
        canonical = {capability_mapping.get(c, c) for c in capabilities}
        cap_list = ", ".join("SPIRV_C_{}".format(c) for c in sorted(canonical))
        if for_cap:
            # If this is generating the availability for capabilities, we need to
            # put the capability "requirements" in implies field because now
            # the "capabilities" field in the source grammar means so.
            capabilities = ""
            implies = "list<I32EnumAttrCase> implies = [{}];".format(cap_list)
        else:
            capabilities = "Capability<[{}]>".format(cap_list)

    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op:
        min_version = min_version or default_min
        max_version = max_version or default_max
        extensions = extensions or default_ext
        capabilities = capabilities or default_cap

    # Compose availability spec if any of the requirements is not empty.
    # For ops, because we have a default in SPIRV_Op class, omit if the spec
    # is the same.
    requirements = [
        r for r in (min_version, max_version, extensions, capabilities) if r
    ]
    matches_op_defaults = for_op and (
        min_version == default_min
        and max_version == default_max
        and capabilities == default_cap
        and extensions == default_ext
    )

    avail = ""
    if requirements and not matches_op_defaults:
        keyword = "let" if for_op else "list<Availability>"
        avail = "{} availability = [\n {}\n ];".format(
            keyword, ",\n ".join(requirements)
        )

    separator = "\n " if implies and avail else ""
    return "{}{}{}".format(implies, separator, avail)
def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
    """Generates the TableGen EnumAttr definition for the given operand kind.

    Arguments:
      - operand_kind: the operand kind's grammar; kinds without an 'enumerants'
        entry produce no definition.
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.

    Returns:
      - The operand kind's name
      - A string containing the TableGen EnumAttr definition
      Both are empty strings when the kind has no 'enumerants'.
    """
    if "enumerants" not in operand_kind:
        return "", ""

    kind_name = operand_kind["kind"]
    is_bit_enum = operand_kind["category"] == "BitEnum"
    is_capability = kind_name == "Capability"
    enumerants = operand_kind["enumerants"]
    # The acronym is made of the upper-case ASCII letters of the kind name.
    kind_acronym = "".join(c for c in kind_name if "A" <= c <= "Z")

    def case_symbol(case_name):
        # Dim is handled specially to avoid having numbers as the start of
        # symbols, which does not play well with C++ and the MLIR parser.
        if kind_name == "Dim" and case_name in ("1D", "2D", "3D"):
            return "Dim{}".format(case_name)
        return case_name

    case_by_name = {case["enumerant"]: case for case in enumerants}

    if is_capability:
        # Special treatment for capability cases: we need to sort them topologically
        # because a capability can refer to another via the 'implies' field.
        kind_cases = toposort_capabilities(enumerants, capability_mapping)
    else:
        # NOTE(review): the reviewed text has lost its indentation; the
        # de-duplication is assumed to apply to this branch only -- confirm.
        pairs = [(case["enumerant"], case["value"]) for case in enumerants]
        kind_cases, _ = uniquify_enum_cases(pairs)
    max_len = max([len(symbol) for (symbol, _) in kind_cases])

    # Generate the definition for each enum case
    case_category = "I32Bit" if is_bit_enum else "I32"
    case_template = (
        "def SPIRV_{acronym}_{case_name} {colon:>{offset}} "
        '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
    )
    case_lines = []
    for name, raw_value in kind_cases:
        # Bit enum values are parsed as hexadecimal, the others as decimal.
        number = int(raw_value, base=16) if is_bit_enum else int(raw_value)
        avail = get_availability_spec(
            case_by_name[name], capability_mapping, False, is_capability
        )
        if not is_bit_enum:
            suffix = ""
            value_part = ", {}".format(number)
        elif number == 0:
            suffix = "None"
            value_part = ""
        else:
            # Bit cases are described by the bit position, not the mask.
            suffix = "Bit"
            value_part = ", {}".format(int(math.log2(number)))

        case_lines.append(
            case_template.format(
                category=case_category,
                suffix=suffix,
                acronym=kind_acronym,
                case_name=name,
                symbol=case_symbol(name),
                case_value_part=value_part,
                avail=" {{\n {}\n}}".format(avail) if avail else ";",
                colon=":",
                # Pads so that the colons of all cases line up.
                offset=(max_len + 1 - len(name)),
            )
        )
    case_defs = "\n".join(case_lines)

    # Generate the list of enum case names, split into sublists and
    # concatenated into multiple lines, each indented by six spaces.
    case_names = [
        "SPIRV_{acronym}_{symbol}".format(acronym=kind_acronym, symbol=case[0])
        for case in kind_cases
    ]
    indent = "{:6}".format("")
    case_names = ",\n".join(
        indent + ", ".join(sublist)
        for sublist in split_list_into_sublists(case_names)
    )

    # Generate the enum attribute definition
    # NOTE(review): any leading indentation of the 2nd-4th template lines is not
    # visible in the reviewed text -- confirm against the real file.
    kind_category = "Bit" if is_bit_enum else "I32"
    enum_template = (
        "def SPIRV_{name}Attr :\n"
        'SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [\n'
        "{cases}\n"
        "]>;"
    )
    enum_attr = enum_template.format(
        name=kind_name,
        snake_name=snake_casify(kind_name),
        category=kind_category,
        cases=case_names,
    )
    return kind_name, case_defs + "\n\n" + enum_attr
def gen_opcode(instructions):
    """Generates the TableGen definition to map opname to opcode

    Arguments:
      - instructions: a non-empty list of instruction grammars, each providing
        'opname' and 'opcode'.

    Returns:
      - A string containing the TableGen SPIRV_OpCode definition
    """
    # NOTE(review): whitespace inside the templates is reproduced as it appears
    # in the reviewed text, where runs of spaces may have been collapsed.
    max_len = max([len(inst["opname"]) for inst in instructions])

    # One enum case per instruction, with the colons aligned in one column.
    case_template = (
        "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;'
    )
    case_lines = []
    names_in_enum = []
    for inst in instructions:
        opname = inst["opname"]
        case_lines.append(
            case_template.format(
                name=opname,
                value=inst["opcode"],
                colon=":",
                offset=(max_len + 1 - len(opname)),
            )
        )
        names_in_enum.append("SPIRV_OC_{name}".format(name=opname))
    opcode_str = "\n".join(case_lines)

    # The list of case names, wrapped into lines indented by six spaces.
    indent = "{:6}".format("")
    opcode_list = ",\n".join(
        indent + ", ".join(sublist)
        for sublist in split_list_into_sublists(names_in_enum)
    )

    enum_template = (
        "def SPIRV_OpcodeAttr :\n"
        ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '
        '"opcode", [\n'
        "{lst}\n"
        " ]>;"
    )
    enum_attr = enum_template.format(name="Opcode", lst=opcode_list)
    return opcode_str + "\n\n" + enum_attr
def map_cap_to_opnames(instructions):
    """Maps capabilities to instructions enabled by those capabilities

    Instructions without a 'capabilities' entry are filed under the
    placeholder key "0_core_0".

    Arguments:
      - instructions: a list containing a subset of SPIR-V instructions' grammar
    Returns:
      - A map with keys representing capabilities and values of lists of
        instructions enabled by the corresponding key
    """
    opnames_by_cap = {}

    for inst in instructions:
        for cap in inst.get("capabilities", ["0_core_0"]):
            opnames_by_cap.setdefault(cap, []).append(inst["opname"])

    return opnames_by_cap
def gen_instr_coverage_report(path, instructions):
    """Dumps to standard output a YAML report of current instruction coverage

    Only capabilities that still have at least one unsupported instruction
    appear in the report.

    Arguments:
      - path: the path to SPIRBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
    """
    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)

    # Strip the whole "def SPIRV_OC_" prefix so that the names compare equal to
    # the grammar's opnames. The slice length is derived from the prefix so the
    # two cannot drift apart.
    prefix = "def SPIRV_OC_"
    existing_opcodes = {
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    }

    # Partition the grammar into instructions already defined and the rest.
    existing_instructions = [
        inst for inst in instructions if inst["opname"] in existing_opcodes
    ]
    remaining_instructions = [
        inst for inst in instructions if inst["opname"] not in existing_opcodes
    ]

    rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
    ex_cap_to_instr = map_cap_to_opnames(existing_instructions)

    report = {}
    for cap, unsupported in rem_cap_to_instr.items():
        supported = ex_cap_to_instr.get(cap, [])
        # Coverage is the share of this capability's instructions that are
        # already defined; `unsupported` is never empty here.
        coverage = len(supported) / (len(supported) + len(unsupported))
        report[cap] = {
            "Supported Instructions": supported,
            "Unsupported Instructions": unsupported,
            "Coverage": "{}%".format(int(coverage * 100)),
        }

    print(yaml.dump(report))
def update_td_opcodes(path, instructions, filter_list):
    """Updates SPIRBase.td with new generated opcode cases.

    The file must contain the opcode section marker exactly twice; the text
    between the two occurrences is regenerated and everything else is kept.

    Arguments:
      - path: the path to SPIRBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
      - filter_list: a list containing new opnames to add. The opnames already
        present in the file are appended to this list in place.
    """
    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3

    # Extend opcode list with existing list
    prefix = "def SPIRV_OC_"
    existing_opcodes = [
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    ]
    filter_list.extend(existing_opcodes)
    wanted_opnames = set(filter_list)

    # Generate the opcode for all wanted instructions in SPIR-V, sorted by
    # opcode.
    filter_instrs = [
        inst for inst in instructions if inst["opname"] in wanted_opnames
    ]
    filter_instrs.sort(key=lambda inst: inst["opcode"])
    opcode = gen_opcode(filter_instrs)

    # Substitute the opcode section, re-emitting the marker on both sides.
    content = (
        content[0]
        + AUTOGEN_OPCODE_SECTION_MARKER
        + "\n\n"
        + opcode
        + "\n\n// End "
        + AUTOGEN_OPCODE_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)
def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Updates SPIRBase.td with new generated enum definitions.

    The file must contain the enum section marker exactly twice; the text
    between the two occurrences is regenerated and everything else is kept.

    Arguments:
      - path: the path to SPIRBase.td
      - operand_kinds: a list containing all operand kinds' grammar
      - filter_list: a list containing new enums to add. The enums already
        present in the file are appended to this list in place.
    """
    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
    assert len(content) == 3

    # Extend filter list with existing enum definitions. The slice bounds are
    # derived from the prefix/suffix so that exactly the kind name remains
    # (e.g. "def SPIRV_FooAttr" -> "Foo") and matches the grammar's kind names.
    prefix = "def SPIRV_"
    suffix = "Attr"
    existing_kinds = [
        k[len(prefix) : -len(suffix)]
        for k in re.findall(prefix + r"\w+" + suffix, content[1])
    ]
    filter_list.extend(existing_kinds)
    wanted_kinds = set(filter_list)

    capability_mapping = get_capability_mapping(operand_kinds)

    # Generate definitions for all enums in filter list
    defs = [
        gen_operand_kind_enum_attr(kind, capability_mapping)
        for kind in operand_kinds
        if kind["kind"] in wanted_kinds
    ]
    # Sort alphabetically according to enum name
    defs.sort(key=lambda enum: enum[0])
    # Only keep the definitions from now on
    # Put Capability's definition at the very beginning because capability cases
    # will be referenced later
    defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [
        enum[1] for enum in defs if enum[0] != "Capability"
    ]

    # Substitute the old section
    content = (
        content[0]
        + AUTOGEN_ENUM_SECTION_MARKER
        + "\n\n"
        + "\n\n".join(defs)
        + "\n\n// End "
        + AUTOGEN_ENUM_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)
def snake_casify(name):
    """Turns the given name to follow snake_case convention.

    An underscore is inserted in front of every upper-case ASCII letter except
    one at the very start, then the whole name is lower-cased.
    """
    pieces = []
    for position, char in enumerate(name):
        if position > 0 and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()
def map_spec_operand_to_ods_argument(operand):
    """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

    Arguments:
      - A dict containing the operand's kind, quantifier, and name. The
        quantifier is "" (exactly one), "?" (optional) or anything else
        (treated as variadic); it defaults to "" when absent.

    Returns:
      - A string containing both the type and name for the argument, in the
        form `<type>:$<name>`. The name is the snake_case form of the
        operand's name, or the lower-cased kind when no name is given.
    """
    kind = operand["kind"]
    quantifier = operand.get("quantifier", "")

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind'
    assert kind != "IdResult", 'unexpected to handle "IdResult" kind'

    unimplemented_kinds = (
        "LiteralString",
        "LiteralContextDependentNumber",
        "LiteralExtInstInteger",
        "LiteralSpecConstantOpInteger",
        "PairLiteralIntegerIdRef",
        "PairIdRefLiteralInteger",
        "PairIdRefIdRef",
    )

    if kind == "IdRef":
        by_quantifier = {"": "SPIRV_Type", "?": "Optional<SPIRV_Type>"}
        arg_type = by_quantifier.get(quantifier, "Variadic<SPIRV_Type>")
    elif kind in ("IdMemorySemantics", "IdScope"):
        # TODO: Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        assert quantifier == "", (
            "unexpected to have optional/variadic memory " "semantics or scope <id>"
        )
        # Drop the leading "Id" to obtain the attribute name.
        arg_type = "SPIRV_" + kind[2:] + "Attr"
    elif kind == "LiteralInteger":
        by_quantifier = {"": "I32Attr", "?": "OptionalAttr<I32Attr>"}
        arg_type = by_quantifier.get(quantifier, "OptionalAttr<I32ArrayAttr>")
    elif kind in unimplemented_kinds:
        assert False, '"{}" kind unimplemented'.format(kind)
    else:
        # The rest are all enum operands that we represent with op attributes.
        assert quantifier != "*", "unexpected to have variadic enum attribute"
        arg_type = "SPIRV_{}Attr".format(kind)
        if quantifier == "?":
            arg_type = "OptionalAttr<{}>".format(arg_type)

    spec_name = operand.get("name", "")
    arg_name = snake_casify(spec_name) if spec_name else kind.lower()

    return "{}:${}".format(arg_type, arg_name)
def get_description(text, appendix):
    """Generates the description for the given SPIR-V instruction.

    The result is the text, a blank line, an HTML comment marking the end of
    the auto-generated part, and then the appendix.

    Arguments:
      - text: Textual description of the operation as string.
      - appendix: Additional contents to attach in description as string,
        including IR examples, and others.

    Returns:
      - A string that corresponds to the description of the Tablegen op.
    """
    # NOTE(review): whitespace inside the template is reproduced as it appears
    # in the reviewed text, where runs of spaces may have been collapsed.
    template = "{text}\n\n <!-- End of AutoGen section -->\n{appendix}\n "
    return template.format(appendix=appendix, text=text)
def get_op_definition(
    instruction, opname, doc, existing_info, capability_mapping, settings
):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Arguments:
      - instruction: the instruction's SPIR-V JSON grammar
      - opname: the name used for the op; when it starts with "Op", the first
        two characters of the grammar's opname are dropped
      - doc: the instruction's SPIR-V HTML doc; its first line becomes the
        summary and the rest the description text
      - existing_info: a dict containing potential manually specified sections for
        this instruction
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td
      - settings: object whose `gen_cl_ops` attribute selects the definition
        header layout

    Returns:
      - A string containing the TableGen op definition
    """
    # NOTE(review): whitespace inside all templates below (and in the summary
    # width computation) is reproduced as it appears in the reviewed text, where
    # runs of spaces may have been collapsed -- confirm against the real file.
    if settings.gen_cl_ops:
        header = (
            "def SPIRV_{opname}Op : "
            'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
        )
    else:
        header = (
            "def SPIRV_{vendor_name}{opname_src}Op : "
            'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
        )
    template_parts = [
        header,
        "{{\n let summary = {summary};\n\n let description = ",
        "[{{\n{description}}}];{availability}\n",
    ]

    vendor_name = ""
    inst_category = existing_info.get("inst_category", "Op")
    if inst_category == "Op":
        # Only plain ops get generated arguments/results sections.
        template_parts.append(
            "\n let arguments = (ins{args});\n\n let results = (outs{results});\n"
        )
    elif inst_category.endswith("VendorOp"):
        vendor_name = inst_category.split("VendorOp")[0].upper()
        assert len(vendor_name) != 0, "Invalid instruction category"

    template_parts.append("{extras}}}\n")
    fmt_str = "".join(template_parts)

    opname_src = instruction["opname"]
    if opname.startswith("Op"):
        opname_src = opname_src[2:]
    if vendor_name:
        # The vendor suffix moves from the end of the op name into the prefix
        # of the definition name.
        assert opname_src.endswith(
            vendor_name
        ), "op name does not match the instruction category"
        opname_src = opname_src[: -len(vendor_name)]

    category_args = existing_info.get("category_args", "")

    # First line of the doc is the summary, the remainder the description.
    if "\n" in doc:
        summary, text = doc.split("\n", 1)
    else:
        summary, text = doc, ""
    wrapper = textwrap.TextWrapper(
        width=76, initial_indent=" ", subsequent_indent=" "
    )

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len(' let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = "[{{\n{}\n }}]".format(wrapper.fill(summary))

    # Wrap text: every non-empty line becomes its own wrapped paragraph.
    paragraphs = [wrapper.fill(line) for line in text.split("\n") if line]
    text = "\n\n".join(paragraphs)

    operands = instruction.get("operands", [])

    # Op availability
    avail = get_availability_spec(instruction, capability_mapping, True, False)
    if avail:
        avail = "\n\n {0}".format(avail)

    # Set op's result
    results = ""
    if operands and operands[0]["kind"] == "IdResultType":
        results = "\n SPIRV_Type:$result\n "
        operands = operands[1:]
    if "results" in existing_info:
        results = existing_info["results"]

    # Ignore the operand standing for the result <id>
    if operands and operands[0]["kind"] == "IdResult":
        operands = operands[1:]

    # Set op's arguments; a manually specified section wins over generation.
    arguments = existing_info.get("arguments", None)
    if arguments is None:
        arguments = ",\n ".join(
            [map_spec_operand_to_ods_argument(o) for o in operands]
        )
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = "\n {}\n ".format(arguments)

    description = existing_info.get("description", None)
    if description is None:
        # Placeholder assembly/example sections for a freshly generated op.
        assembly = (
            "\n ```\n"
            " [TODO]\n"
            " ```\n\n"
            " #### Example:\n\n"
            " ```mlir\n"
            " [TODO]\n"
            " ```"
        )
        description = get_description(text, assembly)

    return fmt_str.format(
        opname=opname,
        opname_src=opname_src,
        opcode=instruction["opcode"],
        category_args=category_args,
        inst_category=inst_category,
        vendor_name=vendor_name,
        traits=existing_info.get("traits", ""),
        summary=summary,
        description=description,
        availability=avail,
        args=arguments,
        results=results,
        extras=existing_info.get("extras", ""),
    )
def get_string_between(base, start, end):
    """Extracts a substring with a specified start and end from a string.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string itself
        if the substring wasn't found.

    Raises:
      AssertionError if `start` is found but `end` does not follow it.
    """
    _, found_start, after_start = base.partition(start)
    if not found_start:
        return "", base

    inner, found_end, remainder = after_start.partition(end)
    assert found_end, (
        'cannot find end "{end}" while extracting substring '
        "starting with {start}".format(start=start, end=end)
    )
    # NOTE(review): rstrip() treats `end` as a set of characters, so any
    # trailing run of those characters is removed from the substring as well;
    # kept as is because existing output may depend on it.
    return inner.rstrip(end), remainder
def get_string_between_nested(base, start, end):
    """Extracts a substring with a nested start and end from a string.

    Unlike a plain search for the first `end`, every further `start` seen
    inside the substring must be closed by its own `end` first.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string itself
        if the substring wasn't found.

    Raises:
      AssertionError if the closing `end` matching the first `start` is missing.
    """
    pieces = base.split(start, 1)
    if len(pieces) != 2:
        return "", pieces[0]

    # Handle nesting delimiters: scan forward while tracking how many `start`
    # delimiters are still open, and stop at the `end` that closes the first.
    rest = pieces[1]
    depth = 1
    pos = 0
    while depth > 0 and pos < len(rest):
        if rest.startswith(end, pos):
            depth -= 1
            if depth == 0:
                break
            pos += len(end)
        elif rest.startswith(start, pos):
            depth += 1
            pos += len(start)
        else:
            pos += 1

    assert pos < len(rest), (
        'cannot find end "{end}" while extracting substring '
        'starting with "{start}"'.format(start=start, end=end)
    )
    return rest[:pos], rest[pos + len(end) :]
def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    Arguments:
      - A string containing the op's TableGen definition

    Returns:
      - A dict containing potential manually specified sections, with the keys
        'opname', 'inst_category', 'category_args', 'traits', 'description',
        'arguments', 'results' and 'extras'.

    Raises:
      AssertionError if the text does not hold exactly one op definition, or
      if an expected closing delimiter is missing.
    """
    # Get opname. The slice bounds are derived from the prefix/suffix so that
    # exactly the bare name remains (e.g. "def SPIRV_FooOp" -> "Foo").
    def_prefix = "def SPIRV_"
    op_suffix = "Op"
    opname = [
        o[len(def_prefix) : -len(op_suffix)]
        for o in re.findall(def_prefix + r"\w+" + op_suffix, op_def)
    ]
    assert len(opname) == 1, "more than one ops in the same section!"
    opname = opname[0]

    # Get instruction category: the "SPIRV_<Something>Op" class named after
    # the ':' of the definition, without its "SPIRV_" prefix. The pattern needs
    # at least one character before "Op", so a plain "SPIRV_Op" base class does
    # not match and falls back to the default category "Op".
    class_prefix = "SPIRV_"
    inst_category = [
        o[len(class_prefix) :]
        for o in re.findall(class_prefix + r"\w+Op", op_def.split(":", 1)[1])
    ]
    assert len(inst_category) <= 1, "more than one ops in the same section!"
    inst_category = inst_category[0] if len(inst_category) == 1 else "Op"

    # Get category_args: whatever sits between the quoted op name and the
    # trait list inside the template parameters.
    op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">")
    opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
    category_args = rest.split("[", 1)[0]

    # Get traits
    traits, _ = get_string_between_nested(rest, "[", "]")

    # NOTE(review): whitespace inside the search strings below is reproduced as
    # it appears in the reviewed text, where runs of spaces may have been
    # collapsed -- confirm against the real file.
    # Get description
    description, rest = get_string_between(op_def, "let description = [{\n", "}];\n")

    # Get arguments
    args, rest = get_string_between(rest, " let arguments = (ins", ");\n")

    # Get results
    results, rest = get_string_between(rest, " let results = (outs", ");\n")

    # Whatever is left after the known sections is kept verbatim as extras.
    extras = rest.strip(" }\n")
    if extras:
        extras = "\n {}\n".format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        "opname": "Op{}".format(opname),
        "inst_category": inst_category,
        "category_args": category_args,
        "traits": traits,
        "description": description,
        "arguments": args,
        "results": results,
        "extras": extras,
    }
def update_td_op_definitions(
    path, instructions, docs, filter_list, inst_category, capability_mapping, settings
):
    """Updates SPIRVOps.td with newly generated op definition.

    The file at `path` is rewritten in place; nothing is returned.

    Arguments:
      - path: path to SPIRVOps.td
      - instructions: SPIR-V JSON grammar for all instructions
      - docs: SPIR-V HTML doc for all instructions
      - filter_list: a list containing new opnames to include; it is not
        modified
      - inst_category: instruction category used for ops that have no
        existing definition in the file
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.
      - settings: parsed command-line options; `gen_cl_ops` is read here and
        the object is forwarded to get_op_definition
    """
    with open(path, "r") as f:
        content = f.read()

    # Split the file into chunks, each containing one op. The first and last
    # chunks are the non-generated header and footer.
    chunks = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    header = chunks[0]
    footer = chunks[-1]
    ops = chunks[1:-1]

    # For each existing op, extract the manually-written sections out to retain
    # them when re-generating the ops.
    name_op_map = {}  # Map from opname to its existing ODS definition
    op_info_dict = {}
    for op in ops:
        info_dict = extract_td_op_info(op)
        opname = info_dict["opname"]
        name_op_map[opname] = op
        op_info_dict[opname] = info_dict

    # Regenerate both the requested ops and the ones already in the file.
    # Build the union separately so the caller's list is left untouched.
    all_opnames = sorted(set(filter_list) | set(name_op_map))

    if settings.gen_cl_ops:

        def fix_opname(src):
            return src.replace("CL", "").lower()

    else:

        def fix_opname(src):
            return src

    # Index the grammar once instead of scanning it for every op. setdefault
    # keeps the first entry for a given opname, matching a front-to-back scan.
    grammar_by_opname = {}
    for inst in instructions:
        grammar_by_opname.setdefault(inst["opname"], inst)

    op_defs = [header]
    for opname in all_opnames:
        # Find the grammar spec for this op
        fixed_opname = fix_opname(opname)
        instruction = grammar_by_opname.get(fixed_opname)
        if instruction is None:
            # This is an op added by us; use the existing ODS definition.
            op_defs.append(name_op_map[opname])
            continue

        # Only the lookup above may fall back to the existing definition;
        # errors raised while generating the new definition propagate.
        op_defs.append(
            get_op_definition(
                instruction,
                opname,
                docs[fixed_opname],
                op_info_dict.get(opname, {"inst_category": inst_category}),
                capability_mapping,
                settings,
            )
        )
    op_defs.append(footer)

    # Substitute the old op definitions
    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)

    with open(path, "w") as f:
        f.write(content)
# Command-line entry point: parse the options, fetch the SPIR-V grammar and
# apply whichever updates were requested.
if __name__ == "__main__":
    import argparse

    cli_parser = argparse.ArgumentParser(
        description="Update SPIR-V dialect definitions using SPIR-V spec"
    )

    cli_parser.add_argument(
        "--base-td-path",
        dest="base_td_path",
        type=str,
        default=None,
        help="Path to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--op-td-path",
        dest="op_td_path",
        type=str,
        default=None,
        help="Path to SPIRVOps.td",
    )

    cli_parser.add_argument(
        "--new-enum",
        dest="new_enum",
        type=str,
        default=None,
        help="SPIR-V enum to be added to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-opcodes",
        dest="new_opcodes",
        type=str,
        default=None,
        nargs="*",
        help="update SPIR-V opcodes in SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-inst",
        dest="new_inst",
        type=str,
        default=None,
        nargs="*",
        help="SPIR-V instruction to be added to ops file",
    )
    cli_parser.add_argument(
        "--inst-category",
        dest="inst_category",
        type=str,
        default="Op",
        help="SPIR-V instruction category used for choosing "
        "the TableGen base class to define this op",
    )
    cli_parser.add_argument(
        "--gen-cl-ops",
        dest="gen_cl_ops",
        help="Generate OpenCL Extended Instruction Set op",
        action="store_true",
    )
    cli_parser.set_defaults(gen_cl_ops=False)
    cli_parser.add_argument(
        "--gen-inst-coverage", dest="gen_inst_coverage", action="store_true"
    )
    cli_parser.set_defaults(gen_inst_coverage=False)

    args = cli_parser.parse_args()

    # Validate option combinations up front, before any spec is fetched.
    # cli_parser.error() prints the usage text and exits; unlike a bare
    # assert it is not stripped when Python runs with -O.
    needs_base_td = (
        args.new_enum is not None
        or args.new_opcodes is not None
        # NOTE(review): assumes the coverage report needs a real path, like
        # the other two consumers of base_td_path; confirm.
        or args.gen_inst_coverage
    )
    if needs_base_td and args.base_td_path is None:
        cli_parser.error(
            "--base-td-path is required with --new-enum, --new-opcodes "
            "and --gen-inst-coverage"
        )
    if args.new_inst is not None and args.op_td_path is None:
        cli_parser.error("--op-td-path is required with --new-inst")

    # The OpenCL extended instruction set has its own HTML and JSON specs;
    # None leaves the choice of URL to the callee.
    if args.gen_cl_ops:
        ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
        ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
    else:
        ext_html_url = None
        ext_json_url = None

    operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)

    # Define new enum attr
    if args.new_enum is not None:
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        docs = get_spirv_doc_from_html_spec(ext_html_url, args)
        capability_mapping = get_capability_mapping(operand_kinds)
        update_td_op_definitions(
            args.op_td_path,
            instructions,
            docs,
            args.new_inst,
            args.inst_category,
            capability_mapping,
            args,
        )
        print("Done. Note that this script just generates a template; ", end="")
        print("please read the spec and update traits, arguments, and ", end="")
        print("results accordingly.")

    if args.gen_inst_coverage:
        gen_instr_coverage_report(args.base_td_path, instructions)