[Frontend] Remove unused includes (NFC) (#116927)
[llvm-project.git] / mlir / utils / spirv / gen_spirv_dialect.py
blob99ed3489b4cbda691ddf23a9c448280444b0f875
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
4 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 # See https://llvm.org/LICENSE.txt for license information.
6 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 # Script for updating SPIR-V dialect by scraping information from SPIR-V
9 # HTML and JSON specs from the Internet.
11 # For example, to define the enum attribute for SPIR-V memory model:
13 # ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
14 # --new-enum MemoryModel
16 # The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
17 # SPIR-V enum classes.
19 import itertools
20 import math
21 import re
22 import requests
23 import textwrap
24 import yaml
# Unified SPIR-V HTML specification, scraped for instruction documentation.
SPIRV_HTML_SPEC_URL = (
    "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html"
)
# Machine-readable core grammar (operand kinds and instructions).
SPIRV_JSON_SPEC_URL = (
    "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/"
    "include/spirv/unified1/spirv.core.grammar.json"
)

# HTML documentation and JSON grammar of the OpenCL extended instruction set.
SPIRV_CL_EXT_HTML_SPEC_URL = (
    "https://www.khronos.org/registry/SPIR-V/specs/unified1/"
    "OpenCL.ExtendedInstructionSet.100.html"
)
SPIRV_CL_EXT_JSON_SPEC_URL = (
    "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/"
    "include/spirv/unified1/extinst.opencl.std.100.grammar.json"
)

# Text separating consecutive op definitions in the ops .td file.
AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n"

# Markers that open and close (prefixed with "// End ") the generated
# sections of the base .td file.
_AUTOGEN_MARKER_TAIL = "Generated from SPIR-V spec; DO NOT MODIFY!"
AUTOGEN_ENUM_SECTION_MARKER = "enum section. " + _AUTOGEN_MARKER_TAIL
AUTOGEN_OPCODE_SECTION_MARKER = "opcode section. " + _AUTOGEN_MARKER_TAIL
def get_spirv_doc_from_html_spec(url, settings):
    """Extracts instruction documentation from a SPIR-V HTML spec.

    Arguments:
    - url: URL of the HTML spec to scrape; None means the core SPIR-V spec
      (SPIRV_HTML_SPEC_URL).
    - settings: parsed command-line options; `settings.gen_cl_ops` selects the
      page layout of the OpenCL extended instruction set spec.

    Returns:
    - A dict mapping from instruction opname (the anchor id used in the spec)
      to its documentation text.

    Raises:
    - requests.HTTPError if the spec cannot be downloaded.
    - AssertionError if the expected instruction section is missing.
    """
    if url is None:
        url = SPIRV_HTML_SPEC_URL

    # Bound the wait so an unresponsive server cannot hang the script, and
    # fail loudly on HTTP errors instead of scraping an error page.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    spec = response.content

    from bs4 import BeautifulSoup

    spirv = BeautifulSoup(spec, "html.parser")

    # The two specs differ in where the instruction tables live and in whether
    # the instruction text is read from the table cell itself (OpenCL) or from
    # the cell's first <p> (core).
    if settings.gen_cl_ops:
        heading, anchor_id, section_class = "h2", "_binary_form", "sect2"
    else:
        heading, anchor_id, section_class = "h3", "_instructions_3", "sect3"

    section_anchor = spirv.find(heading, {"id": anchor_id})
    assert section_anchor is not None, 'cannot find section "{}" in {}'.format(
        anchor_id, url
    )

    doc = {}
    for section in section_anchor.parent.find_all("div", {"class": section_class}):
        for table in section.find_all("table"):
            inst_html = table.tbody.tr.td
            if not settings.gen_cl_ops:
                inst_html = inst_html.p
            opname = inst_html.a["id"]
            # Ignore the first line, which is just the opname.
            doc[opname] = inst_html.text.split("\n", 1)[1].strip()

    return doc
def get_spirv_grammar_from_json_spec(url):
    """Extracts operand kind and instruction grammar from SPIR-V JSON specs.

    The core grammar (SPIRV_JSON_SPEC_URL) is always downloaded because it
    provides the operand kinds.

    Arguments:
    - url: URL of an extended instruction set grammar, or None to take the
      instructions from the core grammar.

    Returns:
    - A list containing all operand kinds' grammar (from the core grammar)
    - A list containing all instructions' grammar (from `url` when given,
      otherwise from the core grammar)

    Raises:
    - requests.HTTPError if a grammar cannot be downloaded.
    """
    import json

    def fetch_json(spec_url):
        # Bounded wait plus an explicit HTTP status check, so failures surface
        # as clear errors rather than a hang or a JSON decoding error.
        response = requests.get(spec_url, timeout=60)
        response.raise_for_status()
        return json.loads(response.content)

    spirv = fetch_json(SPIRV_JSON_SPEC_URL)
    if url is None:
        return spirv["operand_kinds"], spirv["instructions"]

    spirv_ext = fetch_json(url)
    return spirv["operand_kinds"], spirv_ext["instructions"]
def split_list_into_sublists(items, max_len=80):
    """Split the list of items into multiple sublists.

    This keeps the string composed from each sublist (items joined with ", ")
    within `max_len` characters, not counting any indentation the caller adds.
    An item that alone exceeds the limit gets a sublist of its own.

    Arguments:
    - items: a list of strings
    - max_len: the character budget per sublist; defaults to 80

    Returns:
    - A list of non-empty sublists that together hold all items, in order
    """
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        # Each item is budgeted together with its ", " separator.
        item_len = len(item) + 2
        # Only flush a non-empty chunk; flushing an empty one (when the very
        # first item is oversized) would emit a blank line in the output.
        if chunk and chunk_len + item_len > max_len:
            chunks.append(chunk)
            chunk = []
            chunk_len = 0
        chunk.append(item)
        chunk_len += item_len

    if chunk:
        chunks.append(chunk)

    return chunks
def toposort(dag, sort_fn):
    """Topologically sorts the given dag.

    Nodes are emitted in batches: each batch holds every node whose incoming
    nodes have all been emitted already, ordered within the batch by `sort_fn`.

    Arguments:
    - dag: a dict mapping from a node to the set of its incoming nodes.
    - sort_fn: a key function for sorting nodes in the same batch.

    Returns:
    A list containing topologically sorted nodes.

    Raises:
    AssertionError if some node can never be emitted (a cycle, or an incoming
    node that is not itself a key of `dag`).
    """
    ordered = []
    pending = dag
    while True:
        ready = {node for node, incoming in pending.items() if not incoming}
        if not ready:
            break
        ordered.extend(sorted(ready, key=sort_fn))
        # Drop the emitted nodes and erase them from the remaining edges.
        pending = {
            node: incoming - ready
            for node, incoming in pending.items()
            if node not in ready
        }
    assert not pending, "found cyclic dependency"
    return ordered
def toposort_capabilities(all_cases):
    """Returns topologically sorted capability (symbol, value) pairs.

    A capability is placed after every capability listed in its
    'capabilities' field; within each batch, capabilities are ordered by value.

    Arguments:
    - all_cases: all capability cases (containing symbol, value, and implied
      capabilities).

    Returns:
    A list containing topologically sorted capability (symbol, value) pairs.
    """
    value_of = {}
    implied = {}
    for case in all_cases:
        symbol = case["enumerant"]
        value_of[symbol] = case["value"]
        implied[symbol] = set(case.get("capabilities", []))

    ordered_symbols = toposort(implied, lambda symbol: value_of[symbol])
    # Pair every capability with its value again.
    return [(symbol, value_of[symbol]) for symbol in ordered_symbols]
def get_availability_spec(enum_case, for_op, for_cap):
    """Returns the availability specification string for the given enum case.

    Arguments:
    - enum_case: the enum case to generate availability spec for. It may contain
      'version', 'lastVersion', 'extensions', or 'capabilities'.
    - for_op: bool value indicating whether this is the availability spec for an
      op itself.
    - for_cap: bool value indicating whether this is the availability spec for
      capabilities themselves.

    Returns:
    - A `let availability = [...];` (ops) or `list<Availability> availability
      = [...];` (enum cases) string when there is something to emit. For
      capabilities it may be preceded by a
      `list<I32EnumAttrCase> implies = [...];` line. Otherwise an empty string.
    """
    assert not (for_op and for_cap), "cannot set both for_op and for_cap"

    DEFAULT_MIN_VERSION = "MinVersion<SPIRV_V_1_0>"
    DEFAULT_MAX_VERSION = "MaxVersion<SPIRV_V_1_6>"
    DEFAULT_CAP = "Capability<[]>"
    DEFAULT_EXT = "Extension<[]>"

    # Minimal version; a literal "None" in the grammar means no requirement.
    min_version = enum_case.get("version", "")
    if min_version == "None":
        min_version = ""
    elif min_version:
        min_version = "MinVersion<SPIRV_V_{}>".format(min_version.replace(".", "_"))
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not min_version:
        min_version = DEFAULT_MIN_VERSION

    max_version = enum_case.get("lastVersion", "")
    if max_version:
        max_version = "MaxVersion<SPIRV_V_{}>".format(max_version.replace(".", "_"))
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not max_version:
        max_version = DEFAULT_MAX_VERSION

    exts = enum_case.get("extensions", [])
    if exts:
        exts = "Extension<[{}]>".format(", ".join(sorted(set(exts))))
        # We need to strip the minimal version requirement if this symbol is
        # available via an extension, which means *any* SPIR-V version can
        # support it as long as the extension is provided. The grammar's
        # 'version' field under such case should be interpreted as this symbol
        # is introduced as a core symbol since the given version, rather than
        # a minimal version requirement.
        min_version = DEFAULT_MIN_VERSION if for_op else ""
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not exts:
        exts = DEFAULT_EXT

    caps = enum_case.get("capabilities", [])
    implies = ""
    if caps:
        # Deduplicate, sort, and refer to the generated capability cases.
        prefixed_caps = ["SPIRV_C_{}".format(c) for c in sorted(set(caps))]
        if for_cap:
            # If this is generating the availability for capabilities, we need
            # to put the capability "requirements" in implies field because now
            # the "capabilities" field in the source grammar means so.
            caps = ""
            implies = "list<I32EnumAttrCase> implies = [{}];".format(
                ", ".join(prefixed_caps)
            )
        else:
            caps = "Capability<[{}]>".format(", ".join(prefixed_caps))
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not caps:
        caps = DEFAULT_CAP

    avail = ""
    # Compose availability spec if any of the requirements is not empty.
    # For ops, because we have a default in SPIRV_Op class, omit if the spec
    # is the same.
    if (min_version or max_version or caps or exts) and not (
        for_op
        and min_version == DEFAULT_MIN_VERSION
        and max_version == DEFAULT_MAX_VERSION
        and caps == DEFAULT_CAP
        and exts == DEFAULT_EXT
    ):
        joined_spec = ",\n ".join(
            [e for e in [min_version, max_version, exts, caps] if e]
        )
        avail = "{} availability = [\n {}\n ];".format(
            "let" if for_op else "list<Availability>", joined_spec
        )

    return "{}{}{}".format(implies, "\n " if implies and avail else "", avail)
def gen_operand_kind_enum_attr(operand_kind):
    """Generates the TableGen EnumAttr definition for the given operand kind.

    Arguments:
    - operand_kind: one entry of the SPIR-V JSON grammar's operand kinds

    Returns:
    - The operand kind's name
    - A string containing one case definition per enumerant, followed by the
      TableGen EnumAttr definition
    Both are empty strings when the operand kind has no "enumerants".
    """
    if "enumerants" not in operand_kind:
        return "", ""

    kind_name = operand_kind["kind"]
    enumerants = operand_kind["enumerants"]
    is_bit_enum = operand_kind["category"] == "BitEnum"
    is_capability = kind_name == "Capability"
    # Uppercase ASCII letters of the kind name, e.g. "MemoryModel" -> "MM".
    acronym = "".join(ch for ch in kind_name if "A" <= ch <= "Z")

    def symbol_for(case_name):
        # Dim's 1D/2D/3D cases would otherwise start with a digit, which does
        # not play well with C++ and the MLIR parser.
        if kind_name == "Dim" and case_name in ("1D", "2D", "3D"):
            return "Dim" + case_name
        return case_name

    cases_by_name = {case["enumerant"]: case for case in enumerants}

    if is_capability:
        # A capability can refer to the ones it implies, so the cases are
        # ordered topologically.
        ordered_cases = toposort_capabilities(enumerants)
    else:
        ordered_cases = [(case["enumerant"], case["value"]) for case in enumerants]

    widest = max(len(symbol) for symbol, _ in ordered_cases)
    case_template = (
        "def SPIRV_{acronym}_{case_name} {colon:>{offset}} "
        '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
    )
    category = "I32Bit" if is_bit_enum else "I32"

    case_lines = []
    for case_name, raw_value in ordered_cases:
        if is_bit_enum:
            bit_value = int(raw_value, base=16)
            if bit_value == 0:
                suffix = "None"
                value_part = ""
            else:
                # Bit cases are declared by bit position.
                suffix = "Bit"
                value_part = ", {}".format(int(math.log2(bit_value)))
        else:
            suffix = ""
            value_part = ", {}".format(int(raw_value))
        availability = get_availability_spec(
            cases_by_name[case_name], False, is_capability
        )
        case_lines.append(
            case_template.format(
                category=category,
                suffix=suffix,
                acronym=acronym,
                case_name=case_name,
                symbol=symbol_for(case_name),
                case_value_part=value_part,
                avail=" {{\n {}\n}}".format(availability) if availability else ";",
                colon=":",
                # Pad so the ':' columns line up across all cases.
                offset=widest + 1 - len(case_name),
            )
        )

    qualified_names = [
        "SPIRV_{}_{}".format(acronym, name) for name, _ in ordered_cases
    ]
    rows = [
        " " * 6 + ", ".join(chunk)
        for chunk in split_list_into_sublists(qualified_names)
    ]

    enum_attr = (
        "def SPIRV_{name}Attr :\n"
        'SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [\n'
        "{cases}\n"
        "]>;"
    ).format(
        name=kind_name,
        snake_name=snake_casify(kind_name),
        category="Bit" if is_bit_enum else "I32",
        cases=",\n".join(rows),
    )
    return kind_name, "\n".join(case_lines) + "\n\n" + enum_attr
def gen_opcode(instructions):
    """Generates the TableGen definitions that map opnames to opcodes.

    Arguments:
    - instructions: the instructions' grammar, in the order to emit them

    Returns:
    - A string containing one SPIRV_OC_* case per instruction followed by the
      TableGen SPIRV_OpcodeAttr definition
    """
    widest = max(len(inst["opname"]) for inst in instructions)

    case_lines = []
    for inst in instructions:
        opname = inst["opname"]
        case_lines.append(
            'def SPIRV_OC_{name} {colon:>{offset}} I32EnumAttrCase<"{name}", {value}>;'.format(
                name=opname,
                value=inst["opcode"],
                colon=":",
                # Pad so the ':' columns line up across all cases.
                offset=widest + 1 - len(opname),
            )
        )

    case_refs = ["SPIRV_OC_{}".format(inst["opname"]) for inst in instructions]
    rows = [
        " " * 6 + ", ".join(chunk) for chunk in split_list_into_sublists(case_refs)
    ]
    enum_attr = (
        "def SPIRV_OpcodeAttr :\n"
        ' SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [\n'
        + ",\n".join(rows)
        + "\n ]>;"
    )
    return "\n".join(case_lines) + "\n\n" + enum_attr
def map_cap_to_opnames(instructions):
    """Groups instruction opnames by the capabilities that enable them.

    Instructions without a 'capabilities' entry are filed under the
    placeholder key "0_core_0". An instruction listing several capabilities
    appears under each of them.

    Arguments:
    - instructions: a list containing a subset of SPIR-V instructions' grammar
    Returns:
    - A dict from capability to the list of opnames it enables, in the order
      the instructions were given
    """
    by_cap = {}
    for inst in instructions:
        for cap in inst.get("capabilities", ["0_core_0"]):
            by_cap.setdefault(cap, []).append(inst["opname"])
    return by_cap
def gen_instr_coverage_report(path, instructions):
    """Dumps to standard output a YAML report of current instruction coverage.

    An instruction counts as supported when its opcode is defined in the
    generated opcode section of SPIRVBase.td. For every capability that still
    has unsupported instructions, the report lists the supported and
    unsupported instructions and the percentage supported. Capabilities whose
    instructions are all supported do not appear.

    Arguments:
    - path: the path to SPIRVBase.td
    - instructions: a list containing all SPIR-V instructions' grammar
    """
    with open(path, "r") as f:
        content = f.read()

    # Expect [before, generated opcode section, after]: the marker opens the
    # section and closes it again.
    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3, "cannot find the generated opcode section in {}".format(
        path
    )

    prefix = "def SPIRV_OC_"
    # A set, since membership is tested once per instruction below.
    existing_opcodes = {
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    }

    existing_instructions = [
        inst for inst in instructions if inst["opname"] in existing_opcodes
    ]
    remaining_instructions = [
        inst for inst in instructions if inst["opname"] not in existing_opcodes
    ]

    rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
    ex_cap_to_instr = map_cap_to_opnames(existing_instructions)

    report = {}
    for cap, unsupported in rem_cap_to_instr.items():
        supported = ex_cap_to_instr.get(cap, [])
        # `unsupported` is never empty here, so the division is safe.
        coverage = len(supported) / (len(supported) + len(unsupported))
        report[cap] = {
            "Supported Instructions": supported,
            "Unsupported Instructions": unsupported,
            "Coverage": "{}%".format(int(coverage * 100)),
        }

    print(yaml.dump(report))
def update_td_opcodes(path, instructions, filter_list):
    """Updates SPIRVBase.td with new generated opcode cases.

    The generated opcode section is rewritten to hold the opcodes already
    defined there plus the ones in `filter_list`, sorted by opcode value.

    Arguments:
    - path: the path to SPIRVBase.td
    - instructions: a list containing all SPIR-V instructions' grammar
    - filter_list: a list containing new opnames to add (left unmodified)
    """
    with open(path, "r") as f:
        content = f.read()

    # Expect [before, generated section, after]: the marker opens the section
    # and closes it again.
    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3

    # Keep every opcode already defined in the section.
    prefix = "def SPIRV_OC_"
    existing_opcodes = [
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    ]
    # Build a new set instead of extending `filter_list`, which belongs to the
    # caller.
    wanted_opnames = set(filter_list)
    wanted_opnames.update(existing_opcodes)

    # Generate the opcode for all selected instructions, sorted by opcode.
    filter_instrs = [
        inst for inst in instructions if inst["opname"] in wanted_opnames
    ]
    filter_instrs.sort(key=lambda inst: inst["opcode"])
    opcode = gen_opcode(filter_instrs)

    # Substitute the opcode section.
    content = (
        content[0]
        + AUTOGEN_OPCODE_SECTION_MARKER
        + "\n\n"
        + opcode
        + "\n\n// End "
        + AUTOGEN_OPCODE_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)
def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Updates SPIRVBase.td with new generated enum definitions.

    The generated enum section is rewritten to hold the enums already defined
    there plus the ones in `filter_list`. Capability comes first and the rest
    follow in alphabetical order.

    Arguments:
    - path: the path to SPIRVBase.td
    - operand_kinds: a list containing all operand kinds' grammar
    - filter_list: a list containing new enums to add (left unmodified)
    """
    with open(path, "r") as f:
        content = f.read()

    # Expect [before, generated section, after]: the marker opens the section
    # and closes it again.
    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
    assert len(content) == 3

    # Keep every enum already defined in the section.
    prefix = "def SPIRV_"
    suffix = "Attr"
    existing_kinds = [
        k[len(prefix) : -len(suffix)]
        for k in re.findall(prefix + r"\w+" + suffix, content[1])
    ]
    # Build a new set instead of extending `filter_list`, which belongs to the
    # caller.
    wanted_kinds = set(filter_list)
    wanted_kinds.update(existing_kinds)

    # Generate definitions for all selected enums.
    defs = [
        gen_operand_kind_enum_attr(kind)
        for kind in operand_kinds
        if kind["kind"] in wanted_kinds
    ]
    # Kinds without enumerants yield ("", ""); emitting them would only add
    # stray blank lines to the section.
    defs = [enum for enum in defs if enum[1]]
    # Sort alphabetically according to enum name.
    defs.sort(key=lambda enum: enum[0])
    # Only keep the definitions from now on. Put Capability's definition at
    # the very beginning because capability cases will be referenced later.
    defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [
        enum[1] for enum in defs if enum[0] != "Capability"
    ]

    # Substitute the old section.
    content = (
        content[0]
        + AUTOGEN_ENUM_SECTION_MARKER
        + "\n\n"
        + "\n\n".join(defs)
        + "\n\n// End "
        + AUTOGEN_ENUM_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)
def snake_casify(name):
    """Turns the given name to follow snake_case convention.

    An underscore is inserted before every ASCII uppercase letter except a
    leading one, then the whole string is lowercased, e.g. "MemoryModel" ->
    "memory_model" and "FPRoundingMode" -> "f_p_rounding_mode".
    """
    pieces = []
    for position, char in enumerate(name):
        if position > 0 and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()
def map_spec_operand_to_ods_argument(operand):
    """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

    Arguments:
    - operand: a dict containing the operand's kind, and optionally its
      quantifier ("", "?" or "*") and name

    Returns:
    - A string "<ODS type>:$<name>". The name is the snake_cased operand name,
      or the lowercased kind when the operand has no name.
    """
    kind = operand["kind"]
    quantifier = operand.get("quantifier", "")

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind'
    assert kind != "IdResult", 'unexpected to handle "IdResult" kind'

    unimplemented_kinds = (
        "LiteralString",
        "LiteralContextDependentNumber",
        "LiteralExtInstInteger",
        "LiteralSpecConstantOpInteger",
        "PairLiteralIntegerIdRef",
        "PairIdRefLiteralInteger",
        "PairIdRefIdRef",
    )

    if kind == "IdRef":
        # Any quantifier other than "" / "?" is treated as variadic.
        arg_type = {"": "SPIRV_Type", "?": "Optional<SPIRV_Type>"}.get(
            quantifier, "Variadic<SPIRV_Type>"
        )
    elif kind in ("IdMemorySemantics", "IdScope"):
        # TODO: Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        assert quantifier == "", (
            "unexpected to have optional/variadic memory semantics or scope <id>"
        )
        # Drop the "Id" prefix: IdScope -> SPIRV_ScopeAttr.
        arg_type = "SPIRV_{}Attr".format(kind[2:])
    elif kind == "LiteralInteger":
        arg_type = {"": "I32Attr", "?": "OptionalAttr<I32Attr>"}.get(
            quantifier, "OptionalAttr<I32ArrayAttr>"
        )
    elif kind in unimplemented_kinds:
        assert False, '"{}" kind unimplemented'.format(kind)
    else:
        # The rest are all enum operands that we represent with op attributes.
        assert quantifier != "*", "unexpected to have variadic enum attribute"
        arg_type = "SPIRV_{}Attr".format(kind)
        if quantifier == "?":
            arg_type = "OptionalAttr<{}>".format(arg_type)

    name = operand.get("name", "")
    name = snake_casify(name) if name else kind.lower()

    return "{}:${}".format(arg_type, name)
def get_description(text, appendix):
    """Generates the description for the given SPIR-V instruction.

    The generated text is followed by an end-of-autogen marker comment and
    then the appendix.

    Arguments:
    - text: Textual description of the operation as string.
    - appendix: Additional contents to attach in description as string,
      including IR examples, and others.

    Returns:
    - A string that corresponds to the description of the TableGen op.
    """
    return "{0}\n\n <!-- End of AutoGen section -->\n{1}\n ".format(text, appendix)
def get_op_definition(
    instruction, opname, doc, existing_info, settings
):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Manually specified sections found in `existing_info` are reused as-is;
    the remaining sections are generated from the grammar and the HTML doc.

    Arguments:
    - instruction: the instruction's SPIR-V JSON grammar
    - opname: the op's name as listed for the ops file; it can differ from
      instruction["opname"] when generating OpenCL ops
    - doc: the instruction's SPIR-V HTML doc; its first line is used as the
      summary and the rest as the description text
    - existing_info: a dict containing potential manually specified sections for
      this instruction ("inst_category", "category_args", "traits",
      "description", "arguments", "results", "extras")
    - settings: parsed command-line options; `settings.gen_cl_ops` selects the
      OpenCL extended-instruction template

    Returns:
    - A string containing the TableGen op definition
    """
    # Select the definition header template. The OpenCL template leaves a
    # "<<Insert result type>>" placeholder to be filled in by hand.
    if settings.gen_cl_ops:
        fmt_str = (
            "def SPIRV_{opname}Op : "
            'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
            "{{\n let summary = {summary};\n\n let description = "
            "[{{\n{description}}}];{availability}\n"
        )
    else:
        fmt_str = (
            "def SPIRV_{vendor_name}{opname_src}Op : "
            'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
            "{{\n let summary = {summary};\n\n let description = "
            "[{{\n{description}}}];{availability}\n"
        )

    vendor_name = ""
    inst_category = existing_info.get("inst_category", "Op")
    if inst_category == "Op":
        # Only the plain "Op" category gets generated arguments/results.
        fmt_str += (
            "\n let arguments = (ins{args});\n\n" " let results = (outs{results});\n"
        )
    elif inst_category.endswith("VendorOp"):
        # For a "<Vendor>VendorOp" category, the upper-cased <Vendor> must be
        # a suffix of the opname; it is stripped from the op name below.
        vendor_name = inst_category.split("VendorOp")[0].upper()
        assert len(vendor_name) != 0, "Invalid instruction category"

    fmt_str += "{extras}" "}}\n"

    # Drop the "Op" prefix from the grammar's opname.
    # NOTE(review): the prefix test is on `opname` while the slice applies to
    # instruction["opname"]; the two are equal for non-OpenCL ops (the caller
    # looks the instruction up by the unchanged name) — confirm this is
    # intended for OpenCL ops.
    opname_src = instruction["opname"]
    if opname.startswith("Op"):
        opname_src = opname_src[2:]
    if len(vendor_name) > 0:
        assert opname_src.endswith(
            vendor_name
        ), "op name does not match the instruction category"
        opname_src = opname_src[: -len(vendor_name)]

    category_args = existing_info.get("category_args", "")

    # The first doc line is the summary; the remainder is the description.
    if "\n" in doc:
        summary, text = doc.split("\n", 1)
    else:
        summary = doc
        text = ""
    wrapper = textwrap.TextWrapper(
        width=76, initial_indent=" ", subsequent_indent=" "
    )

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len(' let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = "[{{\n{}\n }}]".format(wrapper.fill(summary))

    # Wrap text: every non-empty doc line becomes its own wrapped paragraph.
    text = text.split("\n")
    text = [wrapper.fill(line) for line in text if line]
    text = "\n\n".join(text)

    operands = instruction.get("operands", [])

    # Op availability
    avail = get_availability_spec(instruction, True, False)
    if avail:
        avail = "\n\n {0}".format(avail)

    # Set op's result: a leading IdResultType operand yields a single
    # SPIRV_Type result; manually written results take precedence.
    results = ""
    if len(operands) > 0 and operands[0]["kind"] == "IdResultType":
        results = "\n SPIRV_Type:$result\n "
        operands = operands[1:]
    if "results" in existing_info:
        results = existing_info["results"]

    # Ignore the operand standing for the result <id>
    if len(operands) > 0 and operands[0]["kind"] == "IdResult":
        operands = operands[1:]

    # Set op's arguments: reuse manually written ones, otherwise derive them
    # from the remaining grammar operands.
    arguments = existing_info.get("arguments", None)
    if arguments is None:
        arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
        arguments = ",\n ".join(arguments)
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = "\n {}\n ".format(arguments)

    # Without an existing description, generate one from the doc text plus a
    # TODO syntax/example appendix.
    description = existing_info.get("description", None)
    if description is None:
        assembly = (
            "\n ```\n"
            " [TODO]\n"
            " ```\n\n"
            " #### Example:\n\n"
            " ```mlir\n"
            " [TODO]\n"
            " ```"
        )
        description = get_description(text, assembly)

    return fmt_str.format(
        opname=opname,
        opname_src=opname_src,
        opcode=instruction["opcode"],
        category_args=category_args,
        inst_category=inst_category,
        vendor_name=vendor_name,
        traits=existing_info.get("traits", ""),
        summary=summary,
        description=description,
        availability=avail,
        args=arguments,
        results=results,
        extras=existing_info.get("extras", ""),
    )
def get_string_between(base, start, end):
    """Extracts a substring with a specified start and end from a string.

    Only the first occurrence of `start`, and the first `end` after it, are
    considered; nesting is not handled.

    Arguments:
    - base: string to extract from.
    - start: string to use as the start of the substring.
    - end: string to use as the end of the substring.

    Returns:
    - The substring strictly between `start` and `end` if found, otherwise ""
    - The part of the base after end of the substring. Is the base string itself
      if the substring wasn't found.

    Raises:
    - AssertionError if `start` is found but no `end` follows it.
    """
    before, found_start, after_start = base.partition(start)
    if not found_start:
        return "", before

    inner, found_end, rest = after_start.partition(end)
    assert found_end, (
        'cannot find end "{end}" while extracting substring '
        "starting with {start}".format(start=start, end=end)
    )
    # Return the text verbatim. The previous `rstrip(end)` treated `end` as a
    # set of characters and also dropped trailing text characters such as "}"
    # or "\n" that merely occur in the delimiter.
    return inner, rest
def get_string_between_nested(base, start, end):
    """Extracts a substring with a nested start and end from a string.

    Nested `start`/`end` pairs inside the substring are skipped, so the
    substring ends at the `end` that balances the first `start`.

    Arguments:
    - base: string to extract from.
    - start: string to use as the start of the substring.
    - end: string to use as the end of the substring.

    Returns:
    - The substring if found, otherwise ""
    - The part of the base after end of the substring. Is the base string itself
      if the substring wasn't found.
    """
    head, found, tail = base.partition(start)
    if not found:
        return "", head

    depth = 1
    pos = 0
    while pos < len(tail):
        # `end` is checked before `start` at each position.
        if tail.startswith(end, pos):
            depth -= 1
            if depth == 0:
                break
            pos += len(end)
        elif tail.startswith(start, pos):
            depth += 1
            pos += len(start)
        else:
            pos += 1

    assert pos < len(tail), (
        'cannot find end "{end}" while extracting substring '
        'starting with "{start}"'.format(start=start, end=end)
    )
    return tail[:pos], tail[pos + len(end) :]
def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    Arguments:
    - op_def: a string containing the op's TableGen definition

    Returns:
    - A dict with keys "opname" (prefixed with "Op"), "inst_category",
      "category_args", "traits", "description", "arguments", "results" and
      "extras"
    """
    # Op name: the <name> in "def SPIRV_<name>Op".
    def_prefix = "def SPIRV_"
    op_suffix = "Op"
    names = [
        match[len(def_prefix) : -len(op_suffix)]
        for match in re.findall(def_prefix + r"\w+" + op_suffix, op_def)
    ]
    assert len(names) == 1, "more than one ops in the same section!"
    opname = names[0]

    # Instruction category: a "SPIRV_<category>Op" name after the first ":",
    # or plain "Op" when there is none.
    base_prefix = "SPIRV_"
    after_colon = op_def.split(":", 1)[1]
    categories = [
        match[len(base_prefix) :]
        for match in re.findall(base_prefix + r"\w+Op\b", after_colon)
    ]
    assert len(categories) <= 1, "more than one ops in the same section!"
    if len(categories) == 1:
        inst_category = categories[0]
    else:
        inst_category = "Op"

    # Base-class template parameters: "<opname>"<category_args>, [<traits>]
    template_params, _ = get_string_between_nested(op_def, "<", ">")
    _, after_opstring = get_string_between(template_params, '"', '"')
    category_args = after_opstring.split("[", 1)[0].rsplit(",", 1)[0]
    traits, _ = get_string_between_nested(after_opstring, "[", "]")

    # Body sections; each is searched in what follows the previous one (or
    # in the whole text when the previous one is absent).
    description, remainder = get_string_between(
        op_def, "let description = [{\n", "}];\n"
    )
    args, remainder = get_string_between(remainder, " let arguments = (ins", ");\n")
    results, remainder = get_string_between(
        remainder, " let results = (outs", ");\n"
    )

    # Whatever is left, with surrounding spaces, newlines and '}' removed.
    extras = remainder.strip(" }\n")
    if extras:
        extras = "\n {}\n".format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        "opname": "Op{}".format(opname),
        "inst_category": inst_category,
        "category_args": category_args,
        "traits": traits,
        "description": description,
        "arguments": args,
        "results": results,
        "extras": extras,
    }
def update_td_op_definitions(
    path, instructions, docs, filter_list, inst_category, settings
):
    """Updates SPIRVOps.td with newly generated op definitions.

    Every op already in the file is regenerated (keeping its manually written
    sections), the ops in `filter_list` are added, and all ops are written
    back to `path` sorted by opname. Ops that are not in the grammar keep
    their existing definition unchanged.

    Arguments:
    - path: path to SPIRVOps.td
    - instructions: SPIR-V JSON grammar for all instructions
    - docs: SPIR-V HTML doc for all instructions
    - filter_list: a list containing new opnames to include (left unmodified)
    - inst_category: the instruction category used for newly added ops
    - settings: parsed command-line options (`settings.gen_cl_ops`)
    """
    with open(path, "r") as f:
        content = f.read()

    # Split the file into chunks, each containing one op, framed by a header
    # chunk and a footer chunk.
    ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    # Without any separator, header and footer would be the same chunk and the
    # whole file would be written back twice.
    assert len(ops) >= 2, "cannot find op separator in {}".format(path)
    header = ops[0]
    footer = ops[-1]
    ops = ops[1:-1]

    # For each existing op, extract the manually-written sections out to retain
    # them when re-generating the ops.
    name_op_map = {}  # Map from opname to its existing ODS definition
    op_info_dict = {}  # Map from opname to its extracted sections
    for op in ops:
        info_dict = extract_td_op_info(op)
        opname = info_dict["opname"]
        name_op_map[opname] = op
        op_info_dict[opname] = info_dict
    # Combine into a new list instead of appending to the caller's list.
    opnames = sorted(set(filter_list) | set(name_op_map))

    # Index the grammar by opname once, keeping the first entry per opname.
    opname_to_inst = {}
    for inst in instructions:
        opname_to_inst.setdefault(inst["opname"], inst)

    if settings.gen_cl_ops:
        # Map ops-file names to OpenCL grammar opnames: drop "CL", lowercase.
        def fix_opname(src):
            return src.replace("CL", "").lower()

    else:

        def fix_opname(src):
            return src

    op_defs = []
    for opname in opnames:
        fixed_opname = fix_opname(opname)
        instruction = opname_to_inst.get(fixed_opname)
        if instruction is None:
            # This is an op added by us; use the existing ODS definition.
            assert opname in name_op_map, (
                'cannot find "{}" in the SPIR-V grammar or in {}'.format(opname, path)
            )
            op_defs.append(name_op_map[opname])
            continue
        op_defs.append(
            get_op_definition(
                instruction,
                opname,
                docs[fixed_opname],
                op_info_dict.get(opname, {"inst_category": inst_category}),
                settings,
            )
        )

    # Substitute the old op definitions
    op_defs = [header] + op_defs + [footer]
    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)

    with open(path, "w") as f:
        f.write(content)
if __name__ == "__main__":
    import argparse

    cli_parser = argparse.ArgumentParser(
        description="Update SPIR-V dialect definitions using SPIR-V spec"
    )

    cli_parser.add_argument(
        "--base-td-path",
        dest="base_td_path",
        type=str,
        default=None,
        help="Path to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--op-td-path",
        dest="op_td_path",
        type=str,
        default=None,
        help="Path to SPIRVOps.td",
    )

    cli_parser.add_argument(
        "--new-enum",
        dest="new_enum",
        type=str,
        default=None,
        help="SPIR-V enum to be added to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-opcodes",
        dest="new_opcodes",
        type=str,
        default=None,
        nargs="*",
        help="update SPIR-V opcodes in SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-inst",
        dest="new_inst",
        type=str,
        default=None,
        nargs="*",
        help="SPIR-V instruction to be added to ops file",
    )
    cli_parser.add_argument(
        "--inst-category",
        dest="inst_category",
        type=str,
        default="Op",
        help="SPIR-V instruction category used for choosing "
        "the TableGen base class to define this op",
    )
    cli_parser.add_argument(
        "--gen-cl-ops",
        dest="gen_cl_ops",
        help="Generate OpenCL Extended Instruction Set op",
        action="store_true",
    )
    cli_parser.set_defaults(gen_cl_ops=False)
    cli_parser.add_argument(
        "--gen-inst-coverage",
        dest="gen_inst_coverage",
        action="store_true",
        help="Print a YAML instruction coverage report (needs --base-td-path)",
    )
    cli_parser.set_defaults(gen_inst_coverage=False)

    args = cli_parser.parse_args()

    # Validate path requirements before any network access. Report problems
    # through argparse rather than `assert`, which `python -O` strips.
    needs_base_td = (
        args.new_enum is not None
        or args.new_opcodes is not None
        or args.gen_inst_coverage
    )
    if needs_base_td and args.base_td_path is None:
        cli_parser.error(
            "--base-td-path is required by --new-enum, --new-opcodes and "
            "--gen-inst-coverage"
        )
    if args.new_inst is not None and args.op_td_path is None:
        cli_parser.error("--op-td-path is required by --new-inst")

    if args.gen_cl_ops:
        ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
        ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
    else:
        ext_html_url = None
        ext_json_url = None

    operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)

    # Define new enum attr
    if args.new_enum is not None:
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        docs = get_spirv_doc_from_html_spec(ext_html_url, args)
        update_td_op_definitions(
            args.op_td_path,
            instructions,
            docs,
            args.new_inst,
            args.inst_category,
            args,
        )
        print("Done. Note that this script just generates a template; ", end="")
        print("please read the spec and update traits, arguments, and ", end="")
        print("results accordingly.")

    if args.gen_inst_coverage:
        gen_instr_coverage_report(args.base_td_path, instructions)