[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / EnumsGen.cpp
blobf1d7a233b66a9a6809090f381a84b2f277969610
1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumsGen generates common utility functions for enums.
11 //===----------------------------------------------------------------------===//
13 #include "FormatGen.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "llvm/ADT/BitVector.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "llvm/TableGen/Error.h"
23 #include "llvm/TableGen/Record.h"
24 #include "llvm/TableGen/TableGenBackend.h"
26 using llvm::formatv;
27 using llvm::isDigit;
28 using llvm::PrintFatalError;
29 using llvm::raw_ostream;
30 using llvm::Record;
31 using llvm::RecordKeeper;
32 using llvm::StringRef;
33 using mlir::tblgen::Attribute;
34 using mlir::tblgen::EnumAttr;
35 using mlir::tblgen::EnumAttrCase;
36 using mlir::tblgen::FmtContext;
37 using mlir::tblgen::tgfmt;
39 static std::string makeIdentifier(StringRef str) {
40 if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
41 std::string newStr = std::string("_") + str.str();
42 return newStr;
44 return str.str();
47 static void emitEnumClass(const Record &enumDef, StringRef enumName,
48 StringRef underlyingType, StringRef description,
49 const std::vector<EnumAttrCase> &enumerants,
50 raw_ostream &os) {
51 os << "// " << description << "\n";
52 os << "enum class " << enumName;
54 if (!underlyingType.empty())
55 os << " : " << underlyingType;
56 os << " {\n";
58 for (const auto &enumerant : enumerants) {
59 auto symbol = makeIdentifier(enumerant.getSymbol());
60 auto value = enumerant.getValue();
61 if (value >= 0) {
62 os << formatv(" {0} = {1},\n", symbol, value);
63 } else {
64 os << formatv(" {0},\n", symbol);
67 os << "};\n\n";
70 static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
71 StringRef cppNamespace, raw_ostream &os) {
72 if (enumAttr.getUnderlyingType().empty() ||
73 enumAttr.getConstBuilderTemplate().empty())
74 return;
75 auto cases = enumAttr.getAllCases();
77 // Check which cases shouldn't be printed using a keyword.
78 llvm::BitVector nonKeywordCases(cases.size());
79 for (auto [index, caseVal] : llvm::enumerate(cases))
80 if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
81 nonKeywordCases.set(index);
83 // Generate the parser and the start of the printer for the enum.
84 const char *parsedAndPrinterStart = R"(
85 namespace mlir {
86 template <typename T, typename>
87 struct FieldParser;
89 template<>
90 struct FieldParser<{0}, {0}> {{
91 template <typename ParserT>
92 static FailureOr<{0}> parse(ParserT &parser) {{
93 // Parse the keyword/string containing the enum.
94 std::string enumKeyword;
95 auto loc = parser.getCurrentLocation();
96 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
97 return parser.emitError(loc, "expected keyword for {2}");
99 // Symbolize the keyword.
100 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
101 return *attr;
102 return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
105 } // namespace mlir
107 namespace llvm {
108 inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
109 auto valueStr = stringifyEnum(value);
111 os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
112 enumAttr.getSummary());
114 // If all cases require a string, always wrap.
115 if (nonKeywordCases.all()) {
116 os << " return p << '\"' << valueStr << '\"';\n"
117 "}\n"
118 "} // namespace llvm\n";
119 return;
122 // If there are any cases that can't be used with a keyword, switch on the
123 // case value to determine when to print in the string form.
124 if (nonKeywordCases.any()) {
125 os << " switch (value) {\n";
126 for (auto it : llvm::enumerate(cases)) {
127 if (nonKeywordCases.test(it.index()))
128 continue;
129 StringRef symbol = it.value().getSymbol();
130 os << llvm::formatv(" case {0}::{1}:\n", qualName,
131 makeIdentifier(symbol));
133 os << " break;\n"
134 " default:\n"
135 " return p << '\"' << valueStr << '\"';\n"
136 " }\n";
138 // If this is a bit enum, conservatively print the string form if the value
139 // is not a power of two (i.e. not a single bit case) and not a known case.
140 } else if (enumAttr.isBitEnum()) {
141 // Process the known multi-bit cases that use valid keywords.
142 llvm::SmallVector<EnumAttrCase *> validMultiBitCases;
143 for (auto [index, caseVal] : llvm::enumerate(cases)) {
144 uint64_t value = caseVal.getValue();
145 if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
146 validMultiBitCases.push_back(&caseVal);
148 if (!validMultiBitCases.empty()) {
149 os << " switch (value) {\n";
150 for (EnumAttrCase *caseVal : validMultiBitCases) {
151 StringRef symbol = caseVal->getSymbol();
152 os << llvm::formatv(" case {0}::{1}:\n", qualName,
153 llvm::isDigit(symbol.front()) ? ("_" + symbol)
154 : symbol);
156 os << " return p << valueStr;\n"
157 " default:\n"
158 " break;\n"
159 " }\n";
162 // All other multi-bit cases should be printed as strings.
163 os << formatv(" auto underlyingValue = "
164 "static_cast<std::make_unsigned_t<{0}>>(value);\n",
165 qualName);
166 os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
167 " return p << '\"' << valueStr << '\"';\n";
169 os << " return p << valueStr;\n"
170 "}\n"
171 "} // namespace llvm\n";
174 static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
175 StringRef cppNamespace, raw_ostream &os) {
176 if (underlyingType.empty())
177 underlyingType =
178 std::string(formatv("std::underlying_type_t<{0}>", qualName));
180 const char *const mapInfo = R"(
181 namespace llvm {
182 template<> struct DenseMapInfo<{0}> {{
183 using StorageInfo = ::llvm::DenseMapInfo<{1}>;
185 static inline {0} getEmptyKey() {{
186 return static_cast<{0}>(StorageInfo::getEmptyKey());
189 static inline {0} getTombstoneKey() {{
190 return static_cast<{0}>(StorageInfo::getTombstoneKey());
193 static unsigned getHashValue(const {0} &val) {{
194 return StorageInfo::getHashValue(static_cast<{1}>(val));
197 static bool isEqual(const {0} &lhs, const {0} &rhs) {{
198 return lhs == rhs;
201 })";
202 os << formatv(mapInfo, qualName, underlyingType);
203 os << "\n\n";
206 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
207 EnumAttr enumAttr(enumDef);
208 StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
209 auto enumerants = enumAttr.getAllCases();
211 unsigned maxEnumVal = 0;
212 for (const auto &enumerant : enumerants) {
213 int64_t value = enumerant.getValue();
214 // Avoid generating the max value function if there is an enumerant without
215 // explicit value.
216 if (value < 0)
217 return;
219 maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
222 // Emit the function to return the max enum value
223 os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
224 os << formatv(" return {0};\n", maxEnumVal);
225 os << "}\n\n";
228 // Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt
229 // otherwise.
230 static std::optional<EnumAttrCase>
231 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
232 for (auto attrCase : cases) {
233 if (attrCase.getValue() == 0)
234 return attrCase;
236 return std::nullopt;
239 // Emits the following inline function for bit enums:
241 // inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
242 // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
243 // inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
244 // inline constexpr <enum-type> operator~(<enum-type> bits);
245 // inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
246 // inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
247 // inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
248 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
249 // bool value=true);
250 static void emitOperators(const Record &enumDef, raw_ostream &os) {
251 EnumAttr enumAttr(enumDef);
252 StringRef enumName = enumAttr.getEnumClassName();
253 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
254 int64_t validBits = enumDef.getValueAsInt("validBits");
255 const char *const operators = R"(
256 inline constexpr {0} operator|({0} a, {0} b) {{
257 return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b));
259 inline constexpr {0} operator&({0} a, {0} b) {{
260 return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b));
262 inline constexpr {0} operator^({0} a, {0} b) {{
263 return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
265 inline constexpr {0} operator~({0} bits) {{
266 // Ensure only bits that can be present in the enum are set
267 return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
269 inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
270 return (bits & bit) == bit;
272 inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
273 return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
275 inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
276 return bits & ~bit;
278 inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{
279 return value ? (bits | bit) : bitEnumClear(bits, bit);
282 os << formatv(operators, enumName, underlyingType, validBits);
285 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
286 EnumAttr enumAttr(enumDef);
287 StringRef enumName = enumAttr.getEnumClassName();
288 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
289 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
290 auto enumerants = enumAttr.getAllCases();
292 os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
293 symToStrFnRetType);
294 os << " switch (val) {\n";
295 for (const auto &enumerant : enumerants) {
296 auto symbol = enumerant.getSymbol();
297 auto str = enumerant.getStr();
298 os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
299 makeIdentifier(symbol), str);
301 os << " }\n";
302 os << " return \"\";\n";
303 os << "}\n\n";
306 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
307 EnumAttr enumAttr(enumDef);
308 StringRef enumName = enumAttr.getEnumClassName();
309 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
310 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
311 StringRef separator = enumDef.getValueAsString("separator");
312 auto enumerants = enumAttr.getAllCases();
313 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
315 os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
316 symToStrFnRetType);
318 os << formatv(" auto val = static_cast<{0}>(symbol);\n",
319 enumAttr.getUnderlyingType());
320 // If we have unknown bit set, return an empty string to signal errors.
321 int64_t validBits = enumDef.getValueAsInt("validBits");
322 os << formatv(" assert({0}u == ({0}u | val) && \"invalid bits set in bit "
323 "enum\");\n",
324 validBits);
325 if (allBitsUnsetCase) {
326 os << " // Special case for all bits unset.\n";
327 os << formatv(" if (val == 0) return \"{0}\";\n\n",
328 allBitsUnsetCase->getStr());
330 os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
332 // Add case string if the value has all case bits, and remove them to avoid
333 // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
334 const char *const formatCompareRemove = R"(
335 if ({0}u == ({0}u & val)) {{
336 strs.push_back("{1}");
337 val &= ~static_cast<{2}>({0});
340 // Add case string if the value has all case bits. Used for individual bit
341 // cases, and for groups when printBitEnumPrimaryGroups is 0.
342 const char *const formatCompare = R"(
343 if ({0}u == ({0}u & val))
344 strs.push_back("{1}");
346 // Optionally elide bits that are members of groups that will also be printed
347 // for more concise output.
348 if (enumAttr.printBitEnumPrimaryGroups()) {
349 os << " // Print bit enum groups before individual bits\n";
350 // Emit comparisons for group bit cases in reverse tablegen declaration
351 // order, removing bits for groups with all bits present.
352 for (const auto &enumerant : llvm::reverse(enumerants)) {
353 if ((enumerant.getValue() != 0) &&
354 enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
355 os << formatv(formatCompareRemove, enumerant.getValue(),
356 enumerant.getStr(), enumAttr.getUnderlyingType());
359 // Emit comparisons for individual bit cases in tablegen declaration order.
360 for (const auto &enumerant : enumerants) {
361 if ((enumerant.getValue() != 0) &&
362 enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
363 os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
365 } else {
366 // Emit comparisons for ALL nonzero cases (individual bits and groups) in
367 // tablegen declaration order.
368 for (const auto &enumerant : enumerants) {
369 if (enumerant.getValue() != 0)
370 os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
373 os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator);
375 os << "}\n\n";
378 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
379 EnumAttr enumAttr(enumDef);
380 StringRef enumName = enumAttr.getEnumClassName();
381 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
382 auto enumerants = enumAttr.getAllCases();
384 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
385 enumName, strToSymFnName);
386 os << formatv(" return ::llvm::StringSwitch<::std::optional<{0}>>(str)\n",
387 enumName);
388 for (const auto &enumerant : enumerants) {
389 auto symbol = enumerant.getSymbol();
390 auto str = enumerant.getStr();
391 os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str,
392 makeIdentifier(symbol));
394 os << " .Default(::std::nullopt);\n";
395 os << "}\n";
398 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
399 EnumAttr enumAttr(enumDef);
400 StringRef enumName = enumAttr.getEnumClassName();
401 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
402 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
403 StringRef separator = enumDef.getValueAsString("separator");
404 StringRef separatorTrimmed = separator.trim();
405 auto enumerants = enumAttr.getAllCases();
406 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
408 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
409 enumName, strToSymFnName);
411 if (allBitsUnsetCase) {
412 os << " // Special case for all bits unset.\n";
413 StringRef caseSymbol = allBitsUnsetCase->getSymbol();
414 os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName,
415 allBitsUnsetCase->getStr(), makeIdentifier(caseSymbol));
418 // Split the string to get symbols for all the bits.
419 os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
420 // Remove whitespace from the separator string when parsing.
421 os << formatv(" str.split(symbols, \"{0}\");\n\n", separatorTrimmed);
423 os << formatv(" {0} val = 0;\n", underlyingType);
424 os << " for (auto symbol : symbols) {\n";
426 // Convert each symbol to the bit ordinal and set the corresponding bit.
427 os << formatv(" auto bit = "
428 "llvm::StringSwitch<::std::optional<{0}>>(symbol.trim())\n",
429 underlyingType);
430 for (const auto &enumerant : enumerants) {
431 // Skip the special enumerant for None.
432 if (auto val = enumerant.getValue())
433 os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val);
435 os.indent(6) << ".Default(::std::nullopt);\n";
437 os << " if (bit) { val |= *bit; } else { return ::std::nullopt; }\n";
438 os << " }\n";
440 os << formatv(" return static_cast<{0}>(val);\n", enumName);
441 os << "}\n\n";
444 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
445 raw_ostream &os) {
446 EnumAttr enumAttr(enumDef);
447 StringRef enumName = enumAttr.getEnumClassName();
448 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
449 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
450 auto enumerants = enumAttr.getAllCases();
452 // Avoid generating the underlying value to symbol conversion function if
453 // there is an enumerant without explicit value.
454 if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
455 return enumerant.getValue() < 0;
457 return;
459 os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
460 underlyingToSymFnName,
461 underlyingType.empty() ? std::string("unsigned")
462 : underlyingType)
463 << " switch (value) {\n";
464 for (const auto &enumerant : enumerants) {
465 auto symbol = enumerant.getSymbol();
466 auto value = enumerant.getValue();
467 os << formatv(" case {0}: return {1}::{2};\n", value, enumName,
468 makeIdentifier(symbol));
470 os << " default: return ::std::nullopt;\n"
471 << " }\n"
472 << "}\n\n";
475 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
476 EnumAttr enumAttr(enumDef);
477 StringRef enumName = enumAttr.getEnumClassName();
478 StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
479 llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
480 Attribute baseAttr(baseAttrDef);
482 // Emit classof method
484 os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n",
485 attrClassName);
487 mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
488 if (baseAttrPred.isNull())
489 PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n");
491 std::string condition = baseAttrPred.getCondition();
492 FmtContext verifyCtx;
493 verifyCtx.withSelf("attr");
494 os << tgfmt(" return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx));
496 os << "}\n";
498 // Emit get method
500 os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
501 attrClassName, enumName);
503 StringRef underlyingType = enumAttr.getUnderlyingType();
505 // Assuming that it is IntegerAttr constraint
506 int64_t bitwidth = 64;
507 if (baseAttrDef->getValue("valueType")) {
508 auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
509 if (valueTypeDef->getValue("bitwidth"))
510 bitwidth = valueTypeDef->getValueAsInt("bitwidth");
513 os << formatv(" ::mlir::IntegerType intType = "
514 "::mlir::IntegerType::get(context, {0});\n",
515 bitwidth);
516 os << formatv(" ::mlir::IntegerAttr baseAttr = "
517 "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
518 underlyingType);
519 os << formatv(" return ::llvm::cast<{0}>(baseAttr);\n", attrClassName);
521 os << "}\n";
523 // Emit getValue method
525 os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
527 os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
528 enumName);
530 os << "}\n";
533 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
534 raw_ostream &os) {
535 EnumAttr enumAttr(enumDef);
536 StringRef enumName = enumAttr.getEnumClassName();
537 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
538 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
539 auto enumerants = enumAttr.getAllCases();
540 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
542 os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
543 underlyingToSymFnName, underlyingType);
544 if (allBitsUnsetCase) {
545 os << " // Special case for all bits unset.\n";
546 os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
547 makeIdentifier(allBitsUnsetCase->getSymbol()));
549 int64_t validBits = enumDef.getValueAsInt("validBits");
550 os << formatv(" if (value & ~static_cast<{0}>({1}u)) return std::nullopt;\n",
551 underlyingType, validBits);
552 os << formatv(" return static_cast<{0}>(value);\n", enumName);
553 os << "}\n";
556 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
557 EnumAttr enumAttr(enumDef);
558 StringRef enumName = enumAttr.getEnumClassName();
559 StringRef cppNamespace = enumAttr.getCppNamespace();
560 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
561 StringRef description = enumAttr.getSummary();
562 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
563 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
564 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
565 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
566 auto enumerants = enumAttr.getAllCases();
568 llvm::SmallVector<StringRef, 2> namespaces;
569 llvm::SplitString(cppNamespace, namespaces, "::");
571 for (auto ns : namespaces)
572 os << "namespace " << ns << " {\n";
574 // Emit the enum class definition
575 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
577 // Emit conversion function declarations
578 if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
579 return enumerant.getValue() >= 0;
580 })) {
581 os << formatv(
582 "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
583 underlyingType.empty() ? std::string("unsigned") : underlyingType);
585 os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
586 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
587 strToSymFnName);
589 if (enumAttr.isBitEnum()) {
590 emitOperators(enumDef, os);
591 } else {
592 emitMaxValueFn(enumDef, os);
595 // Generate a generic `stringifyEnum` function that forwards to the method
596 // specified by the user.
597 const char *const stringifyEnumStr = R"(
598 inline {0} stringifyEnum({1} enumValue) {{
599 return {2}(enumValue);
602 os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName);
604 // Generate a generic `symbolizeEnum` function that forwards to the method
605 // specified by the user.
606 const char *const symbolizeEnumStr = R"(
607 template <typename EnumType>
608 ::std::optional<EnumType> symbolizeEnum(::llvm::StringRef);
610 template <>
611 inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
612 return {1}(str);
615 os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
617 const char *const attrClassDecl = R"(
618 class {1} : public ::mlir::{2} {
619 public:
620 using ValueType = {0};
621 using ::mlir::{2}::{2};
622 static bool classof(::mlir::Attribute attr);
623 static {1} get(::mlir::MLIRContext *context, {0} val);
624 {0} getValue() const;
627 if (enumAttr.genSpecializedAttr()) {
628 StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
629 StringRef baseAttrClassName = "IntegerAttr";
630 os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
633 for (auto ns : llvm::reverse(namespaces))
634 os << "} // namespace " << ns << "\n";
636 // Generate a generic parser and printer for the enum.
637 std::string qualName =
638 std::string(formatv("{0}::{1}", cppNamespace, enumName));
639 emitParserPrinter(enumAttr, qualName, cppNamespace, os);
641 // Emit DenseMapInfo for this enum class
642 emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
645 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
646 llvm::emitSourceFileHeader("Enum Utility Declarations", os, recordKeeper);
648 auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
649 for (const auto *def : defs)
650 emitEnumDecl(*def, os);
652 return false;
655 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
656 EnumAttr enumAttr(enumDef);
657 StringRef cppNamespace = enumAttr.getCppNamespace();
659 llvm::SmallVector<StringRef, 2> namespaces;
660 llvm::SplitString(cppNamespace, namespaces, "::");
662 for (auto ns : namespaces)
663 os << "namespace " << ns << " {\n";
665 if (enumAttr.isBitEnum()) {
666 emitSymToStrFnForBitEnum(enumDef, os);
667 emitStrToSymFnForBitEnum(enumDef, os);
668 emitUnderlyingToSymFnForBitEnum(enumDef, os);
669 } else {
670 emitSymToStrFnForIntEnum(enumDef, os);
671 emitStrToSymFnForIntEnum(enumDef, os);
672 emitUnderlyingToSymFnForIntEnum(enumDef, os);
675 if (enumAttr.genSpecializedAttr())
676 emitSpecializedAttrDef(enumDef, os);
678 for (auto ns : llvm::reverse(namespaces))
679 os << "} // namespace " << ns << "\n";
680 os << "\n";
683 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
684 llvm::emitSourceFileHeader("Enum Utility Definitions", os, recordKeeper);
686 auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
687 for (const auto *def : defs)
688 emitEnumDef(*def, os);
690 return false;
693 // Registers the enum utility generator to mlir-tblgen.
694 static mlir::GenRegistration
695 genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
696 [](const RecordKeeper &records, raw_ostream &os) {
697 return emitEnumDecls(records, os);
700 // Registers the enum utility generator to mlir-tblgen.
701 static mlir::GenRegistration
702 genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
703 [](const RecordKeeper &records, raw_ostream &os) {
704 return emitEnumDefs(records, os);