1 //===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
10 // generate the corresponding Python binding classes.
12 //===----------------------------------------------------------------------===//
#include "OpGenHelpers.h"

#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;
/// Text emitted at the top of every generated Python module: the
/// autogeneration notice followed by the imports that the emitted enum
/// classes (`IntEnum`, `IntFlag`, `auto`) and attribute builders
/// (`register_attribute_builder`, `_ods_ir`) rely on.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";
36 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
37 static std::string
makePythonEnumCaseName(StringRef name
) {
38 if (isPythonReserved(name
.str()))
39 return (name
+ "_").str();
43 /// Emits the Python class for the given enum.
44 static void emitEnumClass(EnumAttr enumAttr
, raw_ostream
&os
) {
45 os
<< llvm::formatv("class {0}({1}):\n", enumAttr
.getEnumClassName(),
46 enumAttr
.isBitEnum() ? "IntFlag" : "IntEnum");
47 if (!enumAttr
.getSummary().empty())
48 os
<< llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr
.getSummary());
51 for (const EnumAttrCase
&enumCase
: enumAttr
.getAllCases()) {
53 " {0} = {1}\n", makePythonEnumCaseName(enumCase
.getSymbol()),
54 enumCase
.getValue() >= 0 ? std::to_string(enumCase
.getValue())
60 if (enumAttr
.isBitEnum()) {
61 os
<< llvm::formatv(" def __iter__(self):\n"
62 " return iter([case for case in type(self) if "
63 "(self & case) is case])\n");
64 os
<< llvm::formatv(" def __len__(self):\n"
65 " return bin(self).count(\"1\")\n");
69 os
<< llvm::formatv(" def __str__(self):\n");
70 if (enumAttr
.isBitEnum())
71 os
<< llvm::formatv(" if len(self) > 1:\n"
72 " return \"{0}\".join(map(str, self))\n",
73 enumAttr
.getDef().getValueAsString("separator"));
74 for (const EnumAttrCase
&enumCase
: enumAttr
.getAllCases()) {
75 os
<< llvm::formatv(" if self is {0}.{1}:\n",
76 enumAttr
.getEnumClassName(),
77 makePythonEnumCaseName(enumCase
.getSymbol()));
78 os
<< llvm::formatv(" return \"{0}\"\n", enumCase
.getStr());
81 " raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
82 enumAttr
.getEnumClassName());
86 /// Attempts to extract the bitwidth B from string "uintB_t" describing the
87 /// type. This bitwidth information is not readily available in ODS. Returns
88 /// `false` on success, `true` on failure.
89 static bool extractUIntBitwidth(StringRef uintType
, int64_t &bitwidth
) {
90 if (!uintType
.consume_front("uint"))
92 if (!uintType
.consume_back("_t"))
94 return uintType
.getAsInteger(/*Radix=*/10, bitwidth
);
97 /// Emits an attribute builder for the given enum attribute to support automatic
98 /// conversion between enum values and attributes in Python. Returns
99 /// `false` on success, `true` on failure.
100 static bool emitAttributeBuilder(const EnumAttr
&enumAttr
, raw_ostream
&os
) {
102 if (extractUIntBitwidth(enumAttr
.getUnderlyingType(), bitwidth
)) {
103 llvm::errs() << "failed to identify bitwidth of "
104 << enumAttr
.getUnderlyingType();
108 os
<< llvm::formatv("@register_attribute_builder(\"{0}\")\n",
109 enumAttr
.getAttrDefName());
110 os
<< llvm::formatv("def _{0}(x, context):\n",
111 enumAttr
.getAttrDefName().lower());
114 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
115 "context=context), int(x))\n\n",
120 /// Emits an attribute builder for the given dialect enum attribute to support
121 /// automatic conversion between enum values and attributes in Python. Returns
122 /// `false` on success, `true` on failure.
123 static bool emitDialectEnumAttributeBuilder(StringRef attrDefName
,
124 StringRef formatString
,
126 os
<< llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName
);
127 os
<< llvm::formatv("def _{0}(x, context):\n", attrDefName
.lower());
128 os
<< llvm::formatv(" return "
129 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
134 /// Emits Python bindings for all enums in the record keeper. Returns
135 /// `false` on success, `true` on failure.
136 static bool emitPythonEnums(const llvm::RecordKeeper
&recordKeeper
,
140 recordKeeper
.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
141 EnumAttr
enumAttr(*it
);
142 emitEnumClass(enumAttr
, os
);
143 emitAttributeBuilder(enumAttr
, os
);
145 for (auto &it
: recordKeeper
.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
146 AttrOrTypeDef
attr(&*it
);
147 if (!attr
.getMnemonic()) {
148 llvm::errs() << "enum case " << attr
149 << " needs mnemonic for python enum bindings generation";
152 StringRef mnemonic
= attr
.getMnemonic().value();
153 std::optional
<StringRef
> assemblyFormat
= attr
.getAssemblyFormat();
154 StringRef dialect
= attr
.getDialect().getName();
155 if (assemblyFormat
== "`<` $value `>`") {
156 emitDialectEnumAttributeBuilder(
158 llvm::formatv("#{0}.{1}<{{str(x)}>", dialect
, mnemonic
).str(), os
);
159 } else if (assemblyFormat
== "$value") {
160 emitDialectEnumAttributeBuilder(
162 llvm::formatv("#{0}<{1} {{str(x)}>", dialect
, mnemonic
).str(), os
);
165 << "unsupported assembly format for python enum bindings generation";
173 // Registers the enum utility generator to mlir-tblgen.
174 static mlir::GenRegistration
175 genPythonEnumBindings("gen-python-enum-bindings",
176 "Generate Python bindings for enum attributes",