[CodeGen] Remove some implicit conversions of MCRegister to unsigned by using(). NFC
[llvm-project.git] / mlir / tools / mlir-tblgen / EnumsGen.cpp
blobd11aa9b27c2d86381e3750d579c4027e0461f0a3
1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumsGen generates common utility functions for enums.
11 //===----------------------------------------------------------------------===//
13 #include "FormatGen.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "llvm/ADT/BitVector.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "llvm/TableGen/Error.h"
23 #include "llvm/TableGen/Record.h"
24 #include "llvm/TableGen/TableGenBackend.h"
26 using llvm::formatv;
27 using llvm::isDigit;
28 using llvm::PrintFatalError;
29 using llvm::Record;
30 using llvm::RecordKeeper;
31 using namespace mlir;
32 using mlir::tblgen::Attribute;
33 using mlir::tblgen::EnumAttr;
34 using mlir::tblgen::EnumAttrCase;
35 using mlir::tblgen::FmtContext;
36 using mlir::tblgen::tgfmt;
38 static std::string makeIdentifier(StringRef str) {
39 if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
40 std::string newStr = std::string("_") + str.str();
41 return newStr;
43 return str.str();
46 static void emitEnumClass(const Record &enumDef, StringRef enumName,
47 StringRef underlyingType, StringRef description,
48 const std::vector<EnumAttrCase> &enumerants,
49 raw_ostream &os) {
50 os << "// " << description << "\n";
51 os << "enum class " << enumName;
53 if (!underlyingType.empty())
54 os << " : " << underlyingType;
55 os << " {\n";
57 for (const auto &enumerant : enumerants) {
58 auto symbol = makeIdentifier(enumerant.getSymbol());
59 auto value = enumerant.getValue();
60 if (value >= 0) {
61 os << formatv(" {0} = {1},\n", symbol, value);
62 } else {
63 os << formatv(" {0},\n", symbol);
66 os << "};\n\n";
69 static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
70 StringRef cppNamespace, raw_ostream &os) {
71 if (enumAttr.getUnderlyingType().empty() ||
72 enumAttr.getConstBuilderTemplate().empty())
73 return;
74 auto cases = enumAttr.getAllCases();
76 // Check which cases shouldn't be printed using a keyword.
77 llvm::BitVector nonKeywordCases(cases.size());
78 for (auto [index, caseVal] : llvm::enumerate(cases))
79 if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
80 nonKeywordCases.set(index);
82 // Generate the parser and the start of the printer for the enum.
83 const char *parsedAndPrinterStart = R"(
84 namespace mlir {
85 template <typename T, typename>
86 struct FieldParser;
88 template<>
89 struct FieldParser<{0}, {0}> {{
90 template <typename ParserT>
91 static FailureOr<{0}> parse(ParserT &parser) {{
92 // Parse the keyword/string containing the enum.
93 std::string enumKeyword;
94 auto loc = parser.getCurrentLocation();
95 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
96 return parser.emitError(loc, "expected keyword for {2}");
98 // Symbolize the keyword.
99 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
100 return *attr;
101 return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
105 /// Support for std::optional, useful in attribute/type definition where the enum is
106 /// used as:
108 /// let parameters = (ins OptionalParameter<"std::optional<TheEnumName>">:$value);
109 template<>
110 struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
111 template <typename ParserT>
112 static FailureOr<std::optional<{0}>> parse(ParserT &parser) {{
113 // Parse the keyword/string containing the enum.
114 std::string enumKeyword;
115 auto loc = parser.getCurrentLocation();
116 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
117 return std::optional<{0}>{{};
119 // Symbolize the keyword.
120 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
121 return attr;
122 return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
125 } // namespace mlir
127 namespace llvm {
128 inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
129 auto valueStr = stringifyEnum(value);
131 os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
132 enumAttr.getSummary());
134 // If all cases require a string, always wrap.
135 if (nonKeywordCases.all()) {
136 os << " return p << '\"' << valueStr << '\"';\n"
137 "}\n"
138 "} // namespace llvm\n";
139 return;
142 // If there are any cases that can't be used with a keyword, switch on the
143 // case value to determine when to print in the string form.
144 if (nonKeywordCases.any()) {
145 os << " switch (value) {\n";
146 for (auto it : llvm::enumerate(cases)) {
147 if (nonKeywordCases.test(it.index()))
148 continue;
149 StringRef symbol = it.value().getSymbol();
150 os << llvm::formatv(" case {0}::{1}:\n", qualName,
151 makeIdentifier(symbol));
153 os << " break;\n"
154 " default:\n"
155 " return p << '\"' << valueStr << '\"';\n"
156 " }\n";
158 // If this is a bit enum, conservatively print the string form if the value
159 // is not a power of two (i.e. not a single bit case) and not a known case.
160 } else if (enumAttr.isBitEnum()) {
161 // Process the known multi-bit cases that use valid keywords.
162 SmallVector<EnumAttrCase *> validMultiBitCases;
163 for (auto [index, caseVal] : llvm::enumerate(cases)) {
164 uint64_t value = caseVal.getValue();
165 if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
166 validMultiBitCases.push_back(&caseVal);
168 if (!validMultiBitCases.empty()) {
169 os << " switch (value) {\n";
170 for (EnumAttrCase *caseVal : validMultiBitCases) {
171 StringRef symbol = caseVal->getSymbol();
172 os << llvm::formatv(" case {0}::{1}:\n", qualName,
173 llvm::isDigit(symbol.front()) ? ("_" + symbol)
174 : symbol);
176 os << " return p << valueStr;\n"
177 " default:\n"
178 " break;\n"
179 " }\n";
182 // All other multi-bit cases should be printed as strings.
183 os << formatv(" auto underlyingValue = "
184 "static_cast<std::make_unsigned_t<{0}>>(value);\n",
185 qualName);
186 os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
187 " return p << '\"' << valueStr << '\"';\n";
189 os << " return p << valueStr;\n"
190 "}\n"
191 "} // namespace llvm\n";
194 static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
195 StringRef cppNamespace, raw_ostream &os) {
196 if (underlyingType.empty())
197 underlyingType =
198 std::string(formatv("std::underlying_type_t<{0}>", qualName));
200 const char *const mapInfo = R"(
201 namespace llvm {
202 template<> struct DenseMapInfo<{0}> {{
203 using StorageInfo = ::llvm::DenseMapInfo<{1}>;
205 static inline {0} getEmptyKey() {{
206 return static_cast<{0}>(StorageInfo::getEmptyKey());
209 static inline {0} getTombstoneKey() {{
210 return static_cast<{0}>(StorageInfo::getTombstoneKey());
213 static unsigned getHashValue(const {0} &val) {{
214 return StorageInfo::getHashValue(static_cast<{1}>(val));
217 static bool isEqual(const {0} &lhs, const {0} &rhs) {{
218 return lhs == rhs;
221 })";
222 os << formatv(mapInfo, qualName, underlyingType);
223 os << "\n\n";
226 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
227 EnumAttr enumAttr(enumDef);
228 StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
229 auto enumerants = enumAttr.getAllCases();
231 unsigned maxEnumVal = 0;
232 for (const auto &enumerant : enumerants) {
233 int64_t value = enumerant.getValue();
234 // Avoid generating the max value function if there is an enumerant without
235 // explicit value.
236 if (value < 0)
237 return;
239 maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
242 // Emit the function to return the max enum value
243 os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
244 os << formatv(" return {0};\n", maxEnumVal);
245 os << "}\n\n";
248 // Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt
249 // otherwise.
250 static std::optional<EnumAttrCase>
251 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
252 for (auto attrCase : cases) {
253 if (attrCase.getValue() == 0)
254 return attrCase;
256 return std::nullopt;
259 // Emits the following inline function for bit enums:
261 // inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
262 // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
263 // inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
264 // inline constexpr <enum-type> operator~(<enum-type> bits);
265 // inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
266 // inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
267 // inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
268 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
269 // bool value=true);
270 static void emitOperators(const Record &enumDef, raw_ostream &os) {
271 EnumAttr enumAttr(enumDef);
272 StringRef enumName = enumAttr.getEnumClassName();
273 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
274 int64_t validBits = enumDef.getValueAsInt("validBits");
275 const char *const operators = R"(
276 inline constexpr {0} operator|({0} a, {0} b) {{
277 return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b));
279 inline constexpr {0} operator&({0} a, {0} b) {{
280 return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b));
282 inline constexpr {0} operator^({0} a, {0} b) {{
283 return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
285 inline constexpr {0} operator~({0} bits) {{
286 // Ensure only bits that can be present in the enum are set
287 return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
289 inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
290 return (bits & bit) == bit;
292 inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
293 return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
295 inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
296 return bits & ~bit;
298 inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{
299 return value ? (bits | bit) : bitEnumClear(bits, bit);
302 os << formatv(operators, enumName, underlyingType, validBits);
305 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
306 EnumAttr enumAttr(enumDef);
307 StringRef enumName = enumAttr.getEnumClassName();
308 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
309 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
310 auto enumerants = enumAttr.getAllCases();
312 os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
313 symToStrFnRetType);
314 os << " switch (val) {\n";
315 for (const auto &enumerant : enumerants) {
316 auto symbol = enumerant.getSymbol();
317 auto str = enumerant.getStr();
318 os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
319 makeIdentifier(symbol), str);
321 os << " }\n";
322 os << " return \"\";\n";
323 os << "}\n\n";
326 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
327 EnumAttr enumAttr(enumDef);
328 StringRef enumName = enumAttr.getEnumClassName();
329 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
330 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
331 StringRef separator = enumDef.getValueAsString("separator");
332 auto enumerants = enumAttr.getAllCases();
333 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
335 os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
336 symToStrFnRetType);
338 os << formatv(" auto val = static_cast<{0}>(symbol);\n",
339 enumAttr.getUnderlyingType());
340 // If we have unknown bit set, return an empty string to signal errors.
341 int64_t validBits = enumDef.getValueAsInt("validBits");
342 os << formatv(" assert({0}u == ({0}u | val) && \"invalid bits set in bit "
343 "enum\");\n",
344 validBits);
345 if (allBitsUnsetCase) {
346 os << " // Special case for all bits unset.\n";
347 os << formatv(" if (val == 0) return \"{0}\";\n\n",
348 allBitsUnsetCase->getStr());
350 os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
352 // Add case string if the value has all case bits, and remove them to avoid
353 // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
354 const char *const formatCompareRemove = R"(
355 if ({0}u == ({0}u & val)) {{
356 strs.push_back("{1}");
357 val &= ~static_cast<{2}>({0});
360 // Add case string if the value has all case bits. Used for individual bit
361 // cases, and for groups when printBitEnumPrimaryGroups is 0.
362 const char *const formatCompare = R"(
363 if ({0}u == ({0}u & val))
364 strs.push_back("{1}");
366 // Optionally elide bits that are members of groups that will also be printed
367 // for more concise output.
368 if (enumAttr.printBitEnumPrimaryGroups()) {
369 os << " // Print bit enum groups before individual bits\n";
370 // Emit comparisons for group bit cases in reverse tablegen declaration
371 // order, removing bits for groups with all bits present.
372 for (const auto &enumerant : llvm::reverse(enumerants)) {
373 if ((enumerant.getValue() != 0) &&
374 enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
375 os << formatv(formatCompareRemove, enumerant.getValue(),
376 enumerant.getStr(), enumAttr.getUnderlyingType());
379 // Emit comparisons for individual bit cases in tablegen declaration order.
380 for (const auto &enumerant : enumerants) {
381 if ((enumerant.getValue() != 0) &&
382 enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
383 os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
385 } else {
386 // Emit comparisons for ALL nonzero cases (individual bits and groups) in
387 // tablegen declaration order.
388 for (const auto &enumerant : enumerants) {
389 if (enumerant.getValue() != 0)
390 os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
393 os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator);
395 os << "}\n\n";
398 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
399 EnumAttr enumAttr(enumDef);
400 StringRef enumName = enumAttr.getEnumClassName();
401 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
402 auto enumerants = enumAttr.getAllCases();
404 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
405 enumName, strToSymFnName);
406 os << formatv(" return ::llvm::StringSwitch<::std::optional<{0}>>(str)\n",
407 enumName);
408 for (const auto &enumerant : enumerants) {
409 auto symbol = enumerant.getSymbol();
410 auto str = enumerant.getStr();
411 os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str,
412 makeIdentifier(symbol));
414 os << " .Default(::std::nullopt);\n";
415 os << "}\n";
418 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
419 EnumAttr enumAttr(enumDef);
420 StringRef enumName = enumAttr.getEnumClassName();
421 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
422 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
423 StringRef separator = enumDef.getValueAsString("separator");
424 StringRef separatorTrimmed = separator.trim();
425 auto enumerants = enumAttr.getAllCases();
426 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
428 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
429 enumName, strToSymFnName);
431 if (allBitsUnsetCase) {
432 os << " // Special case for all bits unset.\n";
433 StringRef caseSymbol = allBitsUnsetCase->getSymbol();
434 os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName,
435 allBitsUnsetCase->getStr(), makeIdentifier(caseSymbol));
438 // Split the string to get symbols for all the bits.
439 os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
440 // Remove whitespace from the separator string when parsing.
441 os << formatv(" str.split(symbols, \"{0}\");\n\n", separatorTrimmed);
443 os << formatv(" {0} val = 0;\n", underlyingType);
444 os << " for (auto symbol : symbols) {\n";
446 // Convert each symbol to the bit ordinal and set the corresponding bit.
447 os << formatv(" auto bit = "
448 "llvm::StringSwitch<::std::optional<{0}>>(symbol.trim())\n",
449 underlyingType);
450 for (const auto &enumerant : enumerants) {
451 // Skip the special enumerant for None.
452 if (auto val = enumerant.getValue())
453 os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val);
455 os.indent(6) << ".Default(::std::nullopt);\n";
457 os << " if (bit) { val |= *bit; } else { return ::std::nullopt; }\n";
458 os << " }\n";
460 os << formatv(" return static_cast<{0}>(val);\n", enumName);
461 os << "}\n\n";
464 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
465 raw_ostream &os) {
466 EnumAttr enumAttr(enumDef);
467 StringRef enumName = enumAttr.getEnumClassName();
468 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
469 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
470 auto enumerants = enumAttr.getAllCases();
472 // Avoid generating the underlying value to symbol conversion function if
473 // there is an enumerant without explicit value.
474 if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
475 return enumerant.getValue() < 0;
477 return;
479 os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
480 underlyingToSymFnName,
481 underlyingType.empty() ? std::string("unsigned")
482 : underlyingType)
483 << " switch (value) {\n";
484 for (const auto &enumerant : enumerants) {
485 auto symbol = enumerant.getSymbol();
486 auto value = enumerant.getValue();
487 os << formatv(" case {0}: return {1}::{2};\n", value, enumName,
488 makeIdentifier(symbol));
490 os << " default: return ::std::nullopt;\n"
491 << " }\n"
492 << "}\n\n";
495 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
496 EnumAttr enumAttr(enumDef);
497 StringRef enumName = enumAttr.getEnumClassName();
498 StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
499 const Record *baseAttrDef = enumAttr.getBaseAttrClass();
500 Attribute baseAttr(baseAttrDef);
502 // Emit classof method
504 os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n",
505 attrClassName);
507 mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
508 if (baseAttrPred.isNull())
509 PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n");
511 std::string condition = baseAttrPred.getCondition();
512 FmtContext verifyCtx;
513 verifyCtx.withSelf("attr");
514 os << tgfmt(" return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx));
516 os << "}\n";
518 // Emit get method
520 os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
521 attrClassName, enumName);
523 StringRef underlyingType = enumAttr.getUnderlyingType();
525 // Assuming that it is IntegerAttr constraint
526 int64_t bitwidth = 64;
527 if (baseAttrDef->getValue("valueType")) {
528 auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
529 if (valueTypeDef->getValue("bitwidth"))
530 bitwidth = valueTypeDef->getValueAsInt("bitwidth");
533 os << formatv(" ::mlir::IntegerType intType = "
534 "::mlir::IntegerType::get(context, {0});\n",
535 bitwidth);
536 os << formatv(" ::mlir::IntegerAttr baseAttr = "
537 "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
538 underlyingType);
539 os << formatv(" return ::llvm::cast<{0}>(baseAttr);\n", attrClassName);
541 os << "}\n";
543 // Emit getValue method
545 os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
547 os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
548 enumName);
550 os << "}\n";
553 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
554 raw_ostream &os) {
555 EnumAttr enumAttr(enumDef);
556 StringRef enumName = enumAttr.getEnumClassName();
557 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
558 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
559 auto enumerants = enumAttr.getAllCases();
560 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
562 os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
563 underlyingToSymFnName, underlyingType);
564 if (allBitsUnsetCase) {
565 os << " // Special case for all bits unset.\n";
566 os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
567 makeIdentifier(allBitsUnsetCase->getSymbol()));
569 int64_t validBits = enumDef.getValueAsInt("validBits");
570 os << formatv(" if (value & ~static_cast<{0}>({1}u)) return std::nullopt;\n",
571 underlyingType, validBits);
572 os << formatv(" return static_cast<{0}>(value);\n", enumName);
573 os << "}\n";
576 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
577 EnumAttr enumAttr(enumDef);
578 StringRef enumName = enumAttr.getEnumClassName();
579 StringRef cppNamespace = enumAttr.getCppNamespace();
580 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
581 StringRef description = enumAttr.getSummary();
582 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
583 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
584 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
585 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
586 auto enumerants = enumAttr.getAllCases();
588 SmallVector<StringRef, 2> namespaces;
589 llvm::SplitString(cppNamespace, namespaces, "::");
591 for (auto ns : namespaces)
592 os << "namespace " << ns << " {\n";
594 // Emit the enum class definition
595 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
597 // Emit conversion function declarations
598 if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
599 return enumerant.getValue() >= 0;
600 })) {
601 os << formatv(
602 "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
603 underlyingType.empty() ? std::string("unsigned") : underlyingType);
605 os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
606 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
607 strToSymFnName);
609 if (enumAttr.isBitEnum()) {
610 emitOperators(enumDef, os);
611 } else {
612 emitMaxValueFn(enumDef, os);
615 // Generate a generic `stringifyEnum` function that forwards to the method
616 // specified by the user.
617 const char *const stringifyEnumStr = R"(
618 inline {0} stringifyEnum({1} enumValue) {{
619 return {2}(enumValue);
622 os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName);
624 // Generate a generic `symbolizeEnum` function that forwards to the method
625 // specified by the user.
626 const char *const symbolizeEnumStr = R"(
627 template <typename EnumType>
628 ::std::optional<EnumType> symbolizeEnum(::llvm::StringRef);
630 template <>
631 inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
632 return {1}(str);
635 os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
637 const char *const attrClassDecl = R"(
638 class {1} : public ::mlir::{2} {
639 public:
640 using ValueType = {0};
641 using ::mlir::{2}::{2};
642 static bool classof(::mlir::Attribute attr);
643 static {1} get(::mlir::MLIRContext *context, {0} val);
644 {0} getValue() const;
647 if (enumAttr.genSpecializedAttr()) {
648 StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
649 StringRef baseAttrClassName = "IntegerAttr";
650 os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
653 for (auto ns : llvm::reverse(namespaces))
654 os << "} // namespace " << ns << "\n";
656 // Generate a generic parser and printer for the enum.
657 std::string qualName =
658 std::string(formatv("{0}::{1}", cppNamespace, enumName));
659 emitParserPrinter(enumAttr, qualName, cppNamespace, os);
661 // Emit DenseMapInfo for this enum class
662 emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
665 static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
666 llvm::emitSourceFileHeader("Enum Utility Declarations", os, records);
668 for (const Record *def :
669 records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
670 emitEnumDecl(*def, os);
672 return false;
675 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
676 EnumAttr enumAttr(enumDef);
677 StringRef cppNamespace = enumAttr.getCppNamespace();
679 SmallVector<StringRef, 2> namespaces;
680 llvm::SplitString(cppNamespace, namespaces, "::");
682 for (auto ns : namespaces)
683 os << "namespace " << ns << " {\n";
685 if (enumAttr.isBitEnum()) {
686 emitSymToStrFnForBitEnum(enumDef, os);
687 emitStrToSymFnForBitEnum(enumDef, os);
688 emitUnderlyingToSymFnForBitEnum(enumDef, os);
689 } else {
690 emitSymToStrFnForIntEnum(enumDef, os);
691 emitStrToSymFnForIntEnum(enumDef, os);
692 emitUnderlyingToSymFnForIntEnum(enumDef, os);
695 if (enumAttr.genSpecializedAttr())
696 emitSpecializedAttrDef(enumDef, os);
698 for (auto ns : llvm::reverse(namespaces))
699 os << "} // namespace " << ns << "\n";
700 os << "\n";
703 static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
704 llvm::emitSourceFileHeader("Enum Utility Definitions", os, records);
706 for (const Record *def :
707 records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
708 emitEnumDef(*def, os);
710 return false;
713 // Registers the enum utility generator to mlir-tblgen.
714 static mlir::GenRegistration
715 genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
716 [](const RecordKeeper &records, raw_ostream &os) {
717 return emitEnumDecls(records, os);
720 // Registers the enum utility generator to mlir-tblgen.
721 static mlir::GenRegistration
722 genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
723 [](const RecordKeeper &records, raw_ostream &os) {
724 return emitEnumDefs(records, os);