[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / EnumPythonBindingGen.cpp
blobf4ced0803772edbd56ae3c9da0cc5dcda6232a0e
1 //===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
10 // generate the corresponding Python binding classes.
12 //===----------------------------------------------------------------------===//
13 #include "OpGenHelpers.h"
15 #include "mlir/TableGen/AttrOrTypeDef.h"
16 #include "mlir/TableGen/Attribute.h"
17 #include "mlir/TableGen/Dialect.h"
18 #include "mlir/TableGen/GenInfo.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/TableGen/Record.h"
22 using namespace mlir;
23 using namespace mlir::tblgen;
/// Prologue written once at the top of every generated Python module: an
/// autogeneration banner followed by the imports that the emitted enum classes
/// (`IntEnum`, `IntFlag`, `auto`) and attribute builders
/// (`register_attribute_builder`, `_ods_ir`) rely on.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";
36 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
37 static std::string makePythonEnumCaseName(StringRef name) {
38 if (isPythonReserved(name.str()))
39 return (name + "_").str();
40 return name.str();
43 /// Emits the Python class for the given enum.
44 static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
45 os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
46 enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
47 if (!enumAttr.getSummary().empty())
48 os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
49 os << "\n";
51 for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
52 os << llvm::formatv(
53 " {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
54 enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
55 : "auto()");
58 os << "\n";
60 if (enumAttr.isBitEnum()) {
61 os << llvm::formatv(" def __iter__(self):\n"
62 " return iter([case for case in type(self) if "
63 "(self & case) is case])\n");
64 os << llvm::formatv(" def __len__(self):\n"
65 " return bin(self).count(\"1\")\n");
66 os << "\n";
69 os << llvm::formatv(" def __str__(self):\n");
70 if (enumAttr.isBitEnum())
71 os << llvm::formatv(" if len(self) > 1:\n"
72 " return \"{0}\".join(map(str, self))\n",
73 enumAttr.getDef().getValueAsString("separator"));
74 for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
75 os << llvm::formatv(" if self is {0}.{1}:\n",
76 enumAttr.getEnumClassName(),
77 makePythonEnumCaseName(enumCase.getSymbol()));
78 os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr());
80 os << llvm::formatv(
81 " raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
82 enumAttr.getEnumClassName());
83 os << "\n";
86 /// Attempts to extract the bitwidth B from string "uintB_t" describing the
87 /// type. This bitwidth information is not readily available in ODS. Returns
88 /// `false` on success, `true` on failure.
89 static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
90 if (!uintType.consume_front("uint"))
91 return true;
92 if (!uintType.consume_back("_t"))
93 return true;
94 return uintType.getAsInteger(/*Radix=*/10, bitwidth);
97 /// Emits an attribute builder for the given enum attribute to support automatic
98 /// conversion between enum values and attributes in Python. Returns
99 /// `false` on success, `true` on failure.
100 static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
101 int64_t bitwidth;
102 if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
103 llvm::errs() << "failed to identify bitwidth of "
104 << enumAttr.getUnderlyingType();
105 return true;
108 os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
109 enumAttr.getAttrDefName());
110 os << llvm::formatv("def _{0}(x, context):\n",
111 enumAttr.getAttrDefName().lower());
112 os << llvm::formatv(
113 " return "
114 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
115 "context=context), int(x))\n\n",
116 bitwidth);
117 return false;
120 /// Emits an attribute builder for the given dialect enum attribute to support
121 /// automatic conversion between enum values and attributes in Python. Returns
122 /// `false` on success, `true` on failure.
123 static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
124 StringRef formatString,
125 raw_ostream &os) {
126 os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
127 os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
128 os << llvm::formatv(" return "
129 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
130 formatString);
131 return false;
134 /// Emits Python bindings for all enums in the record keeper. Returns
135 /// `false` on success, `true` on failure.
136 static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
137 raw_ostream &os) {
138 os << fileHeader;
139 for (auto &it :
140 recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
141 EnumAttr enumAttr(*it);
142 emitEnumClass(enumAttr, os);
143 emitAttributeBuilder(enumAttr, os);
145 for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
146 AttrOrTypeDef attr(&*it);
147 if (!attr.getMnemonic()) {
148 llvm::errs() << "enum case " << attr
149 << " needs mnemonic for python enum bindings generation";
150 return true;
152 StringRef mnemonic = attr.getMnemonic().value();
153 std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
154 StringRef dialect = attr.getDialect().getName();
155 if (assemblyFormat == "`<` $value `>`") {
156 emitDialectEnumAttributeBuilder(
157 attr.getName(),
158 llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
159 } else if (assemblyFormat == "$value") {
160 emitDialectEnumAttributeBuilder(
161 attr.getName(),
162 llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
163 } else {
164 llvm::errs()
165 << "unsupported assembly format for python enum bindings generation";
166 return true;
170 return false;
173 // Registers the enum utility generator to mlir-tblgen.
174 static mlir::GenRegistration
175 genPythonEnumBindings("gen-python-enum-bindings",
176 "Generate Python bindings for enum attributes",
177 &emitPythonEnums);