1 //===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
10 // generate the corresponding Python binding classes.
12 //===----------------------------------------------------------------------===//
13 #include "OpGenHelpers.h"
15 #include "mlir/TableGen/AttrOrTypeDef.h"
16 #include "mlir/TableGen/Attribute.h"
17 #include "mlir/TableGen/Dialect.h"
18 #include "mlir/TableGen/GenInfo.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/TableGen/Record.h"
23 using namespace mlir::tblgen
;
26 using llvm::RecordKeeper
;
/// Preamble written once at the top of every generated Python module: imports
/// the enum base classes, the attribute-builder registration decorator, and
/// binds `_ods_ir` to the native IR module.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";
39 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
40 static std::string
makePythonEnumCaseName(StringRef name
) {
41 if (isPythonReserved(name
.str()))
42 return (name
+ "_").str();
46 /// Emits the Python class for the given enum.
47 static void emitEnumClass(EnumAttr enumAttr
, raw_ostream
&os
) {
48 os
<< formatv("class {0}({1}):\n", enumAttr
.getEnumClassName(),
49 enumAttr
.isBitEnum() ? "IntFlag" : "IntEnum");
50 if (!enumAttr
.getSummary().empty())
51 os
<< formatv(" \"\"\"{0}\"\"\"\n", enumAttr
.getSummary());
54 for (const EnumAttrCase
&enumCase
: enumAttr
.getAllCases()) {
55 os
<< formatv(" {0} = {1}\n",
56 makePythonEnumCaseName(enumCase
.getSymbol()),
57 enumCase
.getValue() >= 0 ? std::to_string(enumCase
.getValue())
63 if (enumAttr
.isBitEnum()) {
64 os
<< formatv(" def __iter__(self):\n"
65 " return iter([case for case in type(self) if "
66 "(self & case) is case])\n");
67 os
<< formatv(" def __len__(self):\n"
68 " return bin(self).count(\"1\")\n");
72 os
<< formatv(" def __str__(self):\n");
73 if (enumAttr
.isBitEnum())
74 os
<< formatv(" if len(self) > 1:\n"
75 " return \"{0}\".join(map(str, self))\n",
76 enumAttr
.getDef().getValueAsString("separator"));
77 for (const EnumAttrCase
&enumCase
: enumAttr
.getAllCases()) {
78 os
<< formatv(" if self is {0}.{1}:\n", enumAttr
.getEnumClassName(),
79 makePythonEnumCaseName(enumCase
.getSymbol()));
80 os
<< formatv(" return \"{0}\"\n", enumCase
.getStr());
82 os
<< formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
83 enumAttr
.getEnumClassName());
87 /// Attempts to extract the bitwidth B from string "uintB_t" describing the
88 /// type. This bitwidth information is not readily available in ODS. Returns
89 /// `false` on success, `true` on failure.
90 static bool extractUIntBitwidth(StringRef uintType
, int64_t &bitwidth
) {
91 if (!uintType
.consume_front("uint"))
93 if (!uintType
.consume_back("_t"))
95 return uintType
.getAsInteger(/*Radix=*/10, bitwidth
);
98 /// Emits an attribute builder for the given enum attribute to support automatic
99 /// conversion between enum values and attributes in Python. Returns
100 /// `false` on success, `true` on failure.
101 static bool emitAttributeBuilder(const EnumAttr
&enumAttr
, raw_ostream
&os
) {
103 if (extractUIntBitwidth(enumAttr
.getUnderlyingType(), bitwidth
)) {
104 llvm::errs() << "failed to identify bitwidth of "
105 << enumAttr
.getUnderlyingType();
109 os
<< formatv("@register_attribute_builder(\"{0}\")\n",
110 enumAttr
.getAttrDefName());
111 os
<< formatv("def _{0}(x, context):\n", enumAttr
.getAttrDefName().lower());
112 os
<< formatv(" return "
113 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
114 "context=context), int(x))\n\n",
119 /// Emits an attribute builder for the given dialect enum attribute to support
120 /// automatic conversion between enum values and attributes in Python. Returns
121 /// `false` on success, `true` on failure.
122 static bool emitDialectEnumAttributeBuilder(StringRef attrDefName
,
123 StringRef formatString
,
125 os
<< formatv("@register_attribute_builder(\"{0}\")\n", attrDefName
);
126 os
<< formatv("def _{0}(x, context):\n", attrDefName
.lower());
127 os
<< formatv(" return "
128 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
133 /// Emits Python bindings for all enums in the record keeper. Returns
134 /// `false` on success, `true` on failure.
135 static bool emitPythonEnums(const RecordKeeper
&records
, raw_ostream
&os
) {
137 for (const Record
*it
:
138 records
.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
139 EnumAttr
enumAttr(*it
);
140 emitEnumClass(enumAttr
, os
);
141 emitAttributeBuilder(enumAttr
, os
);
143 for (const Record
*it
:
144 records
.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
145 AttrOrTypeDef
attr(&*it
);
146 if (!attr
.getMnemonic()) {
147 llvm::errs() << "enum case " << attr
148 << " needs mnemonic for python enum bindings generation";
151 StringRef mnemonic
= attr
.getMnemonic().value();
152 std::optional
<StringRef
> assemblyFormat
= attr
.getAssemblyFormat();
153 StringRef dialect
= attr
.getDialect().getName();
154 if (assemblyFormat
== "`<` $value `>`") {
155 emitDialectEnumAttributeBuilder(
157 formatv("#{0}.{1}<{{str(x)}>", dialect
, mnemonic
).str(), os
);
158 } else if (assemblyFormat
== "$value") {
159 emitDialectEnumAttributeBuilder(
161 formatv("#{0}<{1} {{str(x)}>", dialect
, mnemonic
).str(), os
);
164 << "unsupported assembly format for python enum bindings generation";
172 // Registers the enum utility generator to mlir-tblgen.
173 static mlir::GenRegistration
174 genPythonEnumBindings("gen-python-enum-bindings",
175 "Generate Python bindings for enum attributes",