[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / AttrOrTypeDefGen.cpp
blob51f24de2442b95799a0cdce48d3803debe04b436
1 //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "AttrOrTypeFormatGen.h"
10 #include "mlir/TableGen/AttrOrTypeDef.h"
11 #include "mlir/TableGen/Class.h"
12 #include "mlir/TableGen/CodeGenHelpers.h"
13 #include "mlir/TableGen/Format.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Interfaces.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/TableGenBackend.h"
21 #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
23 using namespace mlir;
24 using namespace mlir::tblgen;
26 //===----------------------------------------------------------------------===//
27 // Utility Functions
28 //===----------------------------------------------------------------------===//
30 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
31 /// specified and can only find one dialect's defs, use that.
32 static void collectAllDefs(StringRef selectedDialect,
33 std::vector<llvm::Record *> records,
34 SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
35 // Nothing to do if no defs were found.
36 if (records.empty())
37 return;
39 auto defs = llvm::map_range(
40 records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
41 if (selectedDialect.empty()) {
42 // If a dialect was not specified, ensure that all found defs belong to the
43 // same dialect.
44 if (!llvm::all_equal(llvm::map_range(
45 defs, [](const auto &def) { return def.getDialect(); }))) {
46 llvm::PrintFatalError("defs belonging to more than one dialect. Must "
47 "select one via '--(attr|type)defs-dialect'");
49 resultDefs.assign(defs.begin(), defs.end());
50 } else {
51 // Otherwise, generate the defs that belong to the selected dialect.
52 auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
53 return def.getDialect().getName().equals(selectedDialect);
54 });
55 resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
59 //===----------------------------------------------------------------------===//
60 // DefGen
61 //===----------------------------------------------------------------------===//
63 namespace {
64 class DefGen {
65 public:
66 /// Create the attribute or type class.
67 DefGen(const AttrOrTypeDef &def);
69 void emitDecl(raw_ostream &os) const {
70 if (storageCls && def.genStorageClass()) {
71 NamespaceEmitter ns(os, def.getStorageNamespace());
72 os << "struct " << def.getStorageClassName() << ";\n";
74 defCls.writeDeclTo(os);
76 void emitDef(raw_ostream &os) const {
77 if (storageCls && def.genStorageClass()) {
78 NamespaceEmitter ns(os, def.getStorageNamespace());
79 storageCls->writeDeclTo(os); // everything is inline
81 defCls.writeDefTo(os);
84 private:
85 /// Add traits from the TableGen definition to the class.
86 void createParentWithTraits();
87 /// Emit top-level declarations: using declarations and any extra class
88 /// declarations.
89 void emitTopLevelDeclarations();
90 /// Emit attribute or type builders.
91 void emitBuilders();
92 /// Emit a verifier for the def.
93 void emitVerifier();
94 /// Emit parsers and printers.
95 void emitParserPrinter();
96 /// Emit parameter accessors, if required.
97 void emitAccessors();
98 /// Emit interface methods.
99 void emitInterfaceMethods();
101 //===--------------------------------------------------------------------===//
102 // Builder Emission
104 /// Emit the default builder `Attribute::get`
105 void emitDefaultBuilder();
106 /// Emit the checked builder `Attribute::getChecked`
107 void emitCheckedBuilder();
108 /// Emit a custom builder.
109 void emitCustomBuilder(const AttrOrTypeBuilder &builder);
110 /// Emit a checked custom builder.
111 void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
113 //===--------------------------------------------------------------------===//
114 // Interface Method Emission
116 /// Emit methods for a trait.
117 void emitTraitMethods(const InterfaceTrait &trait);
118 /// Emit a trait method.
119 void emitTraitMethod(const InterfaceMethod &method);
121 //===--------------------------------------------------------------------===//
122 // Storage Class Emission
123 void emitStorageClass();
124 /// Generate the storage class constructor.
125 void emitStorageConstructor();
126 /// Emit the key type `KeyTy`.
127 void emitKeyType();
128 /// Emit the equality comparison operator.
129 void emitEquals();
130 /// Emit the key hash function.
131 void emitHashKey();
132 /// Emit the function to construct the storage class.
133 void emitConstruct();
135 //===--------------------------------------------------------------------===//
136 // Utility Function Declarations
138 /// Get the method parameters for a def builder, where the first several
139 /// parameters may be different.
140 SmallVector<MethodParameter>
141 getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
143 //===--------------------------------------------------------------------===//
144 // Class fields
146 /// The attribute or type definition.
147 const AttrOrTypeDef &def;
148 /// The list of attribute or type parameters.
149 ArrayRef<AttrOrTypeParameter> params;
150 /// The attribute or type class.
151 Class defCls;
152 /// An optional attribute or type storage class. The storage class will
153 /// exist if and only if the def has more than zero parameters.
154 std::optional<Class> storageCls;
156 /// The C++ base value of the def, either "Attribute" or "Type".
157 StringRef valueType;
158 /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
159 StringRef defType;
161 } // namespace
163 DefGen::DefGen(const AttrOrTypeDef &def)
164 : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
165 valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
166 defType(isa<AttrDef>(def) ? "Attr" : "Type") {
167 // Check that all parameters have names.
168 for (const AttrOrTypeParameter &param : def.getParameters())
169 if (param.isAnonymous())
170 llvm::PrintFatalError("all parameters must have a name");
172 // If a storage class is needed, create one.
173 if (def.getNumParameters() > 0)
174 storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
176 // Create the parent class with any indicated traits.
177 createParentWithTraits();
178 // Emit top-level declarations.
179 emitTopLevelDeclarations();
180 // Emit builders for defs with parameters
181 if (storageCls)
182 emitBuilders();
183 // Emit the verifier.
184 if (storageCls && def.genVerifyDecl())
185 emitVerifier();
186 // Emit the mnemonic, if there is one, and any associated parser and printer.
187 if (def.getMnemonic())
188 emitParserPrinter();
189 // Emit accessors
190 if (def.genAccessors())
191 emitAccessors();
192 // Emit trait interface methods
193 emitInterfaceMethods();
194 defCls.finalize();
195 // Emit a storage class if one is needed
196 if (storageCls && def.genStorageClass())
197 emitStorageClass();
200 void DefGen::createParentWithTraits() {
201 ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType));
202 defParent.addTemplateParam(def.getCppClassName());
203 defParent.addTemplateParam(def.getCppBaseClassName());
204 defParent.addTemplateParam(storageCls
205 ? strfmt("{0}::{1}", def.getStorageNamespace(),
206 def.getStorageClassName())
207 : strfmt("::mlir::{0}Storage", valueType));
208 for (auto &trait : def.getTraits()) {
209 defParent.addTemplateParam(
210 isa<NativeTrait>(&trait)
211 ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
212 : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
214 defCls.addParent(std::move(defParent));
217 /// Include declarations specified on NativeTrait
218 static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
219 SmallVector<StringRef> extraDeclarations;
220 // Include extra class declarations from NativeTrait
221 for (const auto &trait : def.getTraits()) {
222 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
223 StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
224 if (value.empty())
225 continue;
226 extraDeclarations.push_back(value);
229 if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
230 extraDeclarations.push_back(*extraDecl);
232 return llvm::join(extraDeclarations, "\n");
235 /// Extra class definitions have a `$cppClass` substitution that is to be
236 /// replaced by the C++ class name.
237 static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
238 SmallVector<StringRef> extraDefinitions;
239 // Include extra class definitions from NativeTrait
240 for (const auto &trait : def.getTraits()) {
241 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
242 StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
243 if (value.empty())
244 continue;
245 extraDefinitions.push_back(value);
248 if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
249 extraDefinitions.push_back(*extraDef);
251 FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
252 return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
255 void DefGen::emitTopLevelDeclarations() {
256 // Inherit constructors from the attribute or type class.
257 defCls.declare<VisibilityDeclaration>(Visibility::Public);
258 defCls.declare<UsingDeclaration>("Base::Base");
260 // Emit the extra declarations first in case there's a definition in there.
261 std::string extraDecl = formatExtraDeclarations(def);
262 std::string extraDef = formatExtraDefinitions(def);
263 defCls.declare<ExtraClassDeclaration>(std::move(extraDecl),
264 std::move(extraDef));
267 void DefGen::emitBuilders() {
268 if (!def.skipDefaultBuilders()) {
269 emitDefaultBuilder();
270 if (def.genVerifyDecl())
271 emitCheckedBuilder();
273 for (auto &builder : def.getBuilders()) {
274 emitCustomBuilder(builder);
275 if (def.genVerifyDecl())
276 emitCheckedCustomBuilder(builder);
280 void DefGen::emitVerifier() {
281 defCls.declare<UsingDeclaration>("Base::getChecked");
282 defCls.declareStaticMethod(
283 "::mlir::LogicalResult", "verify",
284 getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
285 "emitError"}}));
288 void DefGen::emitParserPrinter() {
289 auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
290 "::llvm::StringLiteral", "getMnemonic");
291 mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
293 // Declare the parser and printer, if needed.
294 bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
295 if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
296 return;
298 // Declare the parser.
299 SmallVector<MethodParameter> parserParams;
300 parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
301 if (isa<AttrDef>(&def))
302 parserParams.emplace_back("::mlir::Type", "odsType");
303 auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
304 hasAssemblyFormat ? Method::Static
305 : Method::StaticDeclaration,
306 std::move(parserParams));
307 // Declare the printer.
308 auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
309 Method *printer =
310 defCls.addMethod("void", "print", props,
311 MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
312 // Emit the bodies if we are using the declarative format.
313 if (hasAssemblyFormat)
314 return generateAttrOrTypeFormat(def, parser->body(), printer->body());
317 void DefGen::emitAccessors() {
318 for (auto &param : params) {
319 Method *m = defCls.addMethod(
320 param.getCppAccessorType(), param.getAccessorName(),
321 def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
322 // Generate accessor definitions only if we also generate the storage
323 // class. Otherwise, let the user define the exact accessor definition.
324 if (!def.genStorageClass())
325 continue;
326 m->body().indent() << "return getImpl()->" << param.getName() << ";";
330 void DefGen::emitInterfaceMethods() {
331 for (auto &traitDef : def.getTraits())
332 if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef))
333 if (trait->shouldDeclareMethods())
334 emitTraitMethods(*trait);
337 //===----------------------------------------------------------------------===//
338 // Builder Emission
340 SmallVector<MethodParameter>
341 DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
342 SmallVector<MethodParameter> builderParams;
343 builderParams.append(prefix.begin(), prefix.end());
344 for (auto &param : params)
345 builderParams.emplace_back(param.getCppType(), param.getName());
346 return builderParams;
349 void DefGen::emitDefaultBuilder() {
350 Method *m = defCls.addStaticMethod(
351 def.getCppClassName(), "get",
352 getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
353 MethodBody &body = m->body().indent();
354 auto scope = body.scope("return Base::get(context", ");");
355 for (const auto &param : params)
356 body << ", std::move(" << param.getName() << ")";
359 void DefGen::emitCheckedBuilder() {
360 Method *m = defCls.addStaticMethod(
361 def.getCppClassName(), "getChecked",
362 getBuilderParams(
363 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
364 {"::mlir::MLIRContext *", "context"}}));
365 MethodBody &body = m->body().indent();
366 auto scope = body.scope("return Base::getChecked(emitError, context", ");");
367 for (const auto &param : params)
368 body << ", " << param.getName();
371 static SmallVector<MethodParameter>
372 getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
373 const AttrOrTypeBuilder &builder) {
374 auto params = builder.getParameters();
375 SmallVector<MethodParameter> builderParams;
376 builderParams.append(prefix.begin(), prefix.end());
377 if (!builder.hasInferredContextParameter())
378 builderParams.emplace_back("::mlir::MLIRContext *", "context");
379 for (auto &param : params) {
380 builderParams.emplace_back(param.getCppType(), *param.getName(),
381 param.getDefaultValue());
383 return builderParams;
386 void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
387 // Don't emit a body if there isn't one.
388 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
389 StringRef returnType = def.getCppClassName();
390 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
391 returnType = *builderReturnType;
392 Method *m = defCls.addMethod(returnType, "get", props,
393 getCustomBuilderParams({}, builder));
394 if (!builder.getBody())
395 return;
397 // Format the body and emit it.
398 FmtContext ctx;
399 ctx.addSubst("_get", "Base::get");
400 if (!builder.hasInferredContextParameter())
401 ctx.addSubst("_ctxt", "context");
402 std::string bodyStr = tgfmt(*builder.getBody(), &ctx);
403 m->body().indent().getStream().printReindented(bodyStr);
/// Replace all instances of 'from' to 'to' in `str` and return the new string.
///
/// Matches are found left to right and do not overlap. The search resumes
/// just past each inserted replacement, so text brought in by `to` is never
/// examined again and the function terminates even when `to` contains
/// `from`. An empty `from` matches nothing: `str` is returned unchanged.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
  // An empty pattern would match at every position without ever advancing.
  if (from.empty())
    return str;

  size_t pos = 0;
  while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) {
    str.replace(pos, from.size(), to.data(), to.size());
    // Continue after the replacement rather than rescanning it.
    pos += to.size();
  }
  return str;
}
414 void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
415 // Don't emit a body if there isn't one.
416 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
417 StringRef returnType = def.getCppClassName();
418 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
419 returnType = *builderReturnType;
420 Method *m = defCls.addMethod(
421 returnType, "getChecked", props,
422 getCustomBuilderParams(
423 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
424 builder));
425 if (!builder.getBody())
426 return;
428 // Format the body and emit it. Replace $_get(...) with
429 // Base::getChecked(emitError, ...)
430 FmtContext ctx;
431 if (!builder.hasInferredContextParameter())
432 ctx.addSubst("_ctxt", "context");
433 std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(",
434 "Base::getChecked(emitError, ");
435 bodyStr = tgfmt(bodyStr, &ctx);
436 m->body().indent().getStream().printReindented(bodyStr);
439 //===----------------------------------------------------------------------===//
440 // Interface Method Emission
442 void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
443 // Get the set of methods that should always be declared.
444 auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
445 StringSet<> alwaysDeclared;
446 alwaysDeclared.insert(alwaysDeclaredMethods.begin(),
447 alwaysDeclaredMethods.end());
449 Interface iface = trait.getInterface(); // causes strange bugs if elided
450 for (auto &method : iface.getMethods()) {
451 // Don't declare if the method has a body. Or if the method has a default
452 // implementation and the def didn't request that it always be declared.
453 if (method.getBody() || (method.getDefaultImplementation() &&
454 !alwaysDeclared.count(method.getName())))
455 continue;
456 emitTraitMethod(method);
460 void DefGen::emitTraitMethod(const InterfaceMethod &method) {
461 // All interface methods are declaration-only.
462 auto props =
463 method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
464 SmallVector<MethodParameter> params;
465 for (auto &param : method.getArguments())
466 params.emplace_back(param.type, param.name);
467 defCls.addMethod(method.getReturnType(), method.getName(), props,
468 std::move(params));
471 //===----------------------------------------------------------------------===//
472 // Storage Class Emission
474 void DefGen::emitStorageConstructor() {
475 Constructor *ctor =
476 storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
477 for (auto &param : params) {
478 std::string movedValue = ("std::move(" + param.getName() + ")").str();
479 ctor->addMemberInitializer(param.getName(), movedValue);
483 void DefGen::emitKeyType() {
484 std::string keyType("std::tuple<");
485 llvm::raw_string_ostream os(keyType);
486 llvm::interleaveComma(params, os,
487 [&](auto &param) { os << param.getCppType(); });
488 os << '>';
489 storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str()));
491 // Add a method to construct the key type from the storage.
492 Method *m = storageCls->addConstMethod<Method::Inline>("KeyTy", "getAsKey");
493 m->body().indent() << "return KeyTy(";
494 llvm::interleaveComma(params, m->body().indent(),
495 [&](auto &param) { m->body() << param.getName(); });
496 m->body() << ");";
499 void DefGen::emitEquals() {
500 Method *eq = storageCls->addConstMethod<Method::Inline>(
501 "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
502 auto &body = eq->body().indent();
503 auto scope = body.scope("return (", ");");
504 const auto eachFn = [&](auto it) {
505 FmtContext ctx({{"_lhs", it.value().getName()},
506 {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
507 body << tgfmt(it.value().getComparator(), &ctx);
509 llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
512 void DefGen::emitHashKey() {
513 Method *hash = storageCls->addStaticInlineMethod(
514 "::llvm::hash_code", "hashKey",
515 MethodParameter("const KeyTy &", "tblgenKey"));
516 auto &body = hash->body().indent();
517 auto scope = body.scope("return ::llvm::hash_combine(", ");");
518 llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) {
519 body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
523 void DefGen::emitConstruct() {
524 Method *construct = storageCls->addMethod<Method::Inline>(
525 strfmt("{0} *", def.getStorageClassName()), "construct",
526 def.hasStorageCustomConstructor() ? Method::StaticDeclaration
527 : Method::Static,
528 MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
529 "allocator"),
530 MethodParameter("KeyTy &&", "tblgenKey"));
531 if (!def.hasStorageCustomConstructor()) {
532 auto &body = construct->body().indent();
533 for (const auto &it : llvm::enumerate(params)) {
534 body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
535 it.value().getName(), it.index());
537 // Use the parameters' custom allocator code, if provided.
538 FmtContext ctx = FmtContext().addSubst("_allocator", "allocator");
539 for (auto &param : params) {
540 if (std::optional<StringRef> allocCode = param.getAllocator()) {
541 ctx.withSelf(param.getName()).addSubst("_dst", param.getName());
542 body << tgfmt(*allocCode, &ctx) << '\n';
545 auto scope =
546 body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
547 def.getStorageClassName()),
548 ");");
549 llvm::interleaveComma(params, body, [&](auto &param) {
550 body << "std::move(" << param.getName() << ")";
555 void DefGen::emitStorageClass() {
556 // Add the appropriate parent class.
557 storageCls->addParent(strfmt("::mlir::{0}Storage", valueType));
558 // Add the constructor.
559 emitStorageConstructor();
560 // Declare the key type.
561 emitKeyType();
562 // Add the comparison method.
563 emitEquals();
564 // Emit the key hash method.
565 emitHashKey();
566 // Emit the storage constructor. Just declare it if the user wants to define
567 // it themself.
568 emitConstruct();
569 // Emit the storage class members as public, at the very end of the struct.
570 storageCls->finalize();
571 for (auto &param : params)
572 storageCls->declare<Field>(param.getCppType(), param.getName());
575 //===----------------------------------------------------------------------===//
576 // DefGenerator
577 //===----------------------------------------------------------------------===//
579 namespace {
580 /// This struct is the base generator used when processing tablegen interfaces.
581 class DefGenerator {
582 public:
583 bool emitDecls(StringRef selectedDialect);
584 bool emitDefs(StringRef selectedDialect);
586 protected:
587 DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
588 StringRef defType, StringRef valueType, bool isAttrGenerator)
589 : defRecords(std::move(defs)), os(os), defType(defType),
590 valueType(valueType), isAttrGenerator(isAttrGenerator) {}
592 /// Emit the list of def type names.
593 void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
594 /// Emit the code to dispatch between different defs during parsing/printing.
595 void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
597 /// The set of def records to emit.
598 std::vector<llvm::Record *> defRecords;
599 /// The attribute or type class to emit.
600 /// The stream to emit to.
601 raw_ostream &os;
602 /// The prefix of the tablegen def name, e.g. Attr or Type.
603 StringRef defType;
604 /// The C++ base value type of the def, e.g. Attribute or Type.
605 StringRef valueType;
606 /// Flag indicating if this generator is for Attributes. False if the
607 /// generator is for types.
608 bool isAttrGenerator;
611 /// A specialized generator for AttrDefs.
612 struct AttrDefGenerator : public DefGenerator {
613 AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
614 : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
615 "Attr", "Attribute", /*isAttrGenerator=*/true) {}
617 /// A specialized generator for TypeDefs.
618 struct TypeDefGenerator : public DefGenerator {
619 TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
620 : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
621 "Type", "Type", /*isAttrGenerator=*/false) {}
623 } // namespace
625 //===----------------------------------------------------------------------===//
626 // GEN: Declarations
627 //===----------------------------------------------------------------------===//
629 /// Print this above all the other declarations. Contains type declarations used
630 /// later on.
/// The text forward-declares AsmParser and AsmPrinter inside namespace mlir.
/// emitDecls streams it to the output before it collects any def.
631 static const char *const typeDefDeclHeader = R"(
632 namespace mlir {
633 class AsmParser;
634 class AsmPrinter;
635 } // namespace mlir
/// Emit the declarations for the defs selected by `selectedDialect` (see
/// collectAllDefs for how an empty name is handled).
///
/// The file header, the GET_<DEFTYPE>DEF_CLASSES guard and the common header
/// are written even when no def is found. Every return visible here yields
/// false.
///
/// NOTE(review): blank and brace-only lines are missing from this view, so it
/// cannot be told from here whether `nsEmitter` lives in a nested scope that
/// ends before the MLIR_DECLARE_EXPLICIT_TYPE_ID loop — confirm against the
/// full file before restructuring.
638 bool DefGenerator::emitDecls(StringRef selectedDialect) {
639 emitSourceFileHeader((defType + "Def Declarations").str(), os);
640 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
642 // Output the common "header".
643 os << typeDefDeclHeader;
645 SmallVector<AttrOrTypeDef, 16> defs;
646 collectAllDefs(selectedDialect, defRecords, defs);
647 if (defs.empty())
648 return false;
// The namespace comes from the first def's dialect.
650 NamespaceEmitter nsEmitter(os, defs.front().getDialect());
652 // Declare all the def classes first (in case they reference each other).
653 for (const AttrOrTypeDef &def : defs)
654 os << "class " << def.getCppClassName() << ";\n";
656 // Emit the declarations.
657 for (const AttrOrTypeDef &def : defs)
658 DefGen(def).emitDecl(os);
660 // Emit the TypeID explicit specializations to have a single definition for
661 // each of these.
// Defs whose dialect has an empty C++ namespace are skipped.
662 for (const AttrOrTypeDef &def : defs)
663 if (!def.getDialect().getCppNamespace().empty())
664 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
665 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
666 << ")\n";
668 return false;
671 //===----------------------------------------------------------------------===//
672 // GEN: Def List
673 //===----------------------------------------------------------------------===//
675 void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
676 IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
677 auto interleaveFn = [&](const AttrOrTypeDef &def) {
678 os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
680 llvm::interleave(defs, os, interleaveFn, ",\n");
681 os << "\n";
684 //===----------------------------------------------------------------------===//
685 // GEN: Definitions
686 //===----------------------------------------------------------------------===//
688 /// The code block for default attribute parser/printer dispatch boilerplate.
689 /// {0}: the dialect fully qualified class name.
690 /// {1}: the optional code for the dynamic attribute parser dispatch.
691 /// {2}: the optional code for the dynamic attribute printer dispatch.
/// This is an llvm::formatv template that emitDefs expands. A doubled `{{` in
/// the text stands for one literal `{` in the output. emitDefs passes empty
/// strings for {1} and {2} when the dialect is not extensible.
692 static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
693 /// Parse an attribute registered to this dialect.
694 ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
695 ::mlir::Type type) const {{
696 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
697 ::llvm::StringRef attrTag;
699 ::mlir::Attribute attr;
700 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
701 if (parseResult.has_value())
702 return attr;
705 parser.emitError(typeLoc) << "unknown attribute `"
706 << attrTag << "` in dialect `" << getNamespace() << "`";
707 return {{};
709 /// Print an attribute registered to this dialect.
710 void {0}::printAttribute(::mlir::Attribute attr,
711 ::mlir::DialectAsmPrinter &printer) const {{
712 if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
713 return;
718 /// The code block for dynamic attribute parser dispatch boilerplate.
/// For extensible dialects, emitDefs passes this text as argument {1} of
/// dialectDefaultAttrPrinterParserDispatch. The code refers to `attrTag` and
/// `parser`, both of which that template's parseAttribute provides.
719 static const char *const dialectDynamicAttrParserDispatch = R"(
721 ::mlir::Attribute genAttr;
722 auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
723 if (parseResult.has_value()) {
724 if (::mlir::succeeded(parseResult.value()))
725 return genAttr;
726 return Attribute();
731 /// The code block for dynamic attribute printer dispatch boilerplate.
/// For extensible dialects, emitDefs passes this text as argument {2} of
/// dialectDefaultAttrPrinterParserDispatch. The code refers to `attr` and
/// `printer`, the parameters of that template's printAttribute.
732 static const char *const dialectDynamicAttrPrinterDispatch = R"(
733 if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
734 return;
737 /// The code block for default type parser/printer dispatch boilerplate.
738 /// {0}: the dialect fully qualified class name.
739 /// {1}: the optional code for the dynamic type parser dispatch.
740 /// {2}: the optional code for the dynamic type printer dispatch.
/// The doubled `{{` follows the llvm::formatv escape for a literal `{`.
/// NOTE(review): the call that expands this template is outside the visible
/// part of the file; presumably it mirrors the attribute variant — confirm.
741 static const char *const dialectDefaultTypePrinterParserDispatch = R"(
742 /// Parse a type registered to this dialect.
743 ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
744 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
745 ::llvm::StringRef mnemonic;
746 ::mlir::Type genType;
747 auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
748 if (parseResult.has_value())
749 return genType;
751 parser.emitError(typeLoc) << "unknown type `"
752 << mnemonic << "` in dialect `" << getNamespace() << "`";
753 return {{};
755 /// Print a type registered to this dialect.
756 void {0}::printType(::mlir::Type type,
757 ::mlir::DialectAsmPrinter &printer) const {{
758 if (::mlir::succeeded(generatedTypePrinter(type, printer)))
759 return;
764 /// The code block for dynamic type parser dispatch boilerplate.
/// The code refers to `mnemonic`, `parser` and `genType`, which the parseType
/// function in dialectDefaultTypePrinterParserDispatch provides.
/// NOTE(review): presumably supplied as argument {1} of that template for
/// extensible dialects; the call site is not visible here — confirm.
765 static const char *const dialectDynamicTypeParserDispatch = R"(
767 auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
768 if (parseResult.has_value()) {
769 if (::mlir::succeeded(parseResult.value()))
770 return genType;
771 return ::mlir::Type();
776 /// The code block for dynamic type printer dispatch boilerplate.
/// The code refers to `type` and `printer`, the parameters of the printType
/// function in dialectDefaultTypePrinterParserDispatch.
/// NOTE(review): presumably supplied as argument {2} of that template for
/// extensible dialects; the call site is not visible here — confirm.
777 static const char *const dialectDynamicTypePrinterDispatch = R"(
778 if (::mlir::succeeded(printIfDynamicType(type, printer)))
779 return;
782 /// Emit the dialect printer/parser dispatcher. User's code should call these
783 /// functions from their dialect's print/parse methods.
///
/// Nothing is emitted when none of `defs` has a mnemonic. Otherwise two
/// StaticInline functions are built and written to `os` through an indented
/// stream:
///  - `generated<valueType>Parser(parser, mnemonic, [type,] value)`; the
///    `type` parameter is only present for the attribute generator.
///  - `generated<valueType>Printer(def, printer)`.
/// Defs without a mnemonic are left out of both functions.
///
/// The two local templates are expanded with llvm::formatv once per def:
///  - getValueForMnemonic: {0} is the qualified def class; {1} is
///    `parse(parser[, type])` when the def has a custom or declarative
///    assembly format, and `get(parser.getContext())` otherwise.
///  - printValue: {0} is the qualified def class; {1} is a `t.print(printer);`
///    line when the def has a custom or declarative assembly format, and
///    empty otherwise.
/// The parser's default case stores the unmatched keyword in `*mnemonic` and
/// returns std::nullopt; the printer's default case returns failure.
784 void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
785 if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
786 return def.getMnemonic().has_value();
787 })) {
788 return;
790 // Declare the parser.
791 SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
792 {"::llvm::StringRef *", "mnemonic"}};
793 if (isAttrGenerator)
794 params.emplace_back("::mlir::Type", "type");
795 params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
796 Method parse("::mlir::OptionalParseResult",
797 strfmt("generated{0}Parser", valueType), Method::StaticInline,
798 std::move(params));
799 // Declare the printer.
800 Method printer("::mlir::LogicalResult",
801 strfmt("generated{0}Printer", valueType), Method::StaticInline,
802 {{strfmt("::mlir::{0}", valueType), "def"},
803 {"::mlir::AsmPrinter &", "printer"}});
805 // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
806 // calling the def's parse function.
807 parse.body() << " return "
808 "::mlir::AsmParser::KeywordSwitch<::mlir::"
809 "OptionalParseResult>(parser)\n";
810 const char *const getValueForMnemonic =
811 R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
812 value = {0}::{1};
813 return ::mlir::success(!!value);
817 // The printer dispatch uses llvm::TypeSwitch to find and call the correct
818 // printer.
819 printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
820 << ", ::mlir::LogicalResult>(def)";
821 const char *const printValue = R"( .Case<{0}>([&](auto t) {{
822 printer << {0}::getMnemonic();{1}
823 return ::mlir::success();
826 for (auto &def : defs) {
827 if (!def.getMnemonic())
828 continue;
829 bool hasParserPrinterDecl =
830 def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
831 std::string defClass = strfmt(
832 "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());
834 // If the def has no parameters or parser code, invoke a normal `get`.
835 std::string parseOrGet =
836 hasParserPrinterDecl
837 ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
838 : "get(parser.getContext())";
839 parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);
841 // If the def has no parameters and no printer, just print the mnemonic.
842 StringRef printDef = "";
843 if (hasParserPrinterDecl)
844 printDef = "\nt.print(printer);";
845 printer.body() << llvm::formatv(printValue, defClass, printDef);
847 parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
848 " *mnemonic = keyword;\n"
849 " return std::nullopt;\n"
850 " });";
851 printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
853 raw_indented_ostream indentedOs(os);
854 parse.writeDeclTo(indentedOs);
855 printer.writeDeclTo(indentedOs);
858 bool DefGenerator::emitDefs(StringRef selectedDialect) {
859 emitSourceFileHeader((defType + "Def Definitions").str(), os);
861 SmallVector<AttrOrTypeDef, 16> defs;
862 collectAllDefs(selectedDialect, defRecords, defs);
863 if (defs.empty())
864 return false;
865 emitTypeDefList(defs);
867 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
868 emitParsePrintDispatch(defs);
869 for (const AttrOrTypeDef &def : defs) {
871 NamespaceEmitter ns(os, def.getDialect());
872 DefGen gen(def);
873 gen.emitDef(os);
875 // Emit the TypeID explicit specializations to have a single symbol def.
876 if (!def.getDialect().getCppNamespace().empty())
877 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
878 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
879 << ")\n";
882 Dialect firstDialect = defs.front().getDialect();
884 // Emit the default parser/printer for Attributes if the dialect asked for it.
885 if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
886 NamespaceEmitter nsEmitter(os, firstDialect);
887 if (firstDialect.isExtensible()) {
888 os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
889 firstDialect.getCppClassName(),
890 dialectDynamicAttrParserDispatch,
891 dialectDynamicAttrPrinterDispatch);
892 } else {
893 os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
894 firstDialect.getCppClassName(), "", "");
898 // Emit the default parser/printer for Types if the dialect asked for it.
899 if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
900 NamespaceEmitter nsEmitter(os, firstDialect);
901 if (firstDialect.isExtensible()) {
902 os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
903 firstDialect.getCppClassName(),
904 dialectDynamicTypeParserDispatch,
905 dialectDynamicTypePrinterDispatch);
906 } else {
907 os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
908 firstDialect.getCppClassName(), "", "");
912 return false;
915 //===----------------------------------------------------------------------===//
916 // GEN: Registration hooks
917 //===----------------------------------------------------------------------===//
919 //===----------------------------------------------------------------------===//
920 // AttrDef
922 static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
923 static llvm::cl::opt<std::string>
924 attrDialect("attrdefs-dialect",
925 llvm::cl::desc("Generate attributes for this dialect"),
926 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
928 static mlir::GenRegistration
929 genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
930 [](const llvm::RecordKeeper &records, raw_ostream &os) {
931 AttrDefGenerator generator(records, os);
932 return generator.emitDefs(attrDialect);
934 static mlir::GenRegistration
935 genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
936 [](const llvm::RecordKeeper &records, raw_ostream &os) {
937 AttrDefGenerator generator(records, os);
938 return generator.emitDecls(attrDialect);
941 //===----------------------------------------------------------------------===//
942 // TypeDef
944 static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
945 static llvm::cl::opt<std::string>
946 typeDialect("typedefs-dialect",
947 llvm::cl::desc("Generate types for this dialect"),
948 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
950 static mlir::GenRegistration
951 genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
952 [](const llvm::RecordKeeper &records, raw_ostream &os) {
953 TypeDefGenerator generator(records, os);
954 return generator.emitDefs(typeDialect);
956 static mlir::GenRegistration
957 genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
958 [](const llvm::RecordKeeper &records, raw_ostream &os) {
959 TypeDefGenerator generator(records, os);
960 return generator.emitDecls(typeDialect);