[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-tblgen / AttrOrTypeDefGen.cpp
blob8cc8314418104c01943369a630bd9bab5b85005e
1 //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "AttrOrTypeFormatGen.h"
10 #include "mlir/TableGen/AttrOrTypeDef.h"
11 #include "mlir/TableGen/Class.h"
12 #include "mlir/TableGen/CodeGenHelpers.h"
13 #include "mlir/TableGen/Format.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Interfaces.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/TableGenBackend.h"
21 #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
23 using namespace mlir;
24 using namespace mlir::tblgen;
26 //===----------------------------------------------------------------------===//
27 // Utility Functions
28 //===----------------------------------------------------------------------===//
30 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
31 /// specified and can only find one dialect's defs, use that.
32 static void collectAllDefs(StringRef selectedDialect,
33 std::vector<llvm::Record *> records,
34 SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
35 // Nothing to do if no defs were found.
36 if (records.empty())
37 return;
39 auto defs = llvm::map_range(
40 records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
41 if (selectedDialect.empty()) {
42 // If a dialect was not specified, ensure that all found defs belong to the
43 // same dialect.
44 if (!llvm::all_equal(llvm::map_range(
45 defs, [](const auto &def) { return def.getDialect(); }))) {
46 llvm::PrintFatalError("defs belonging to more than one dialect. Must "
47 "select one via '--(attr|type)defs-dialect'");
49 resultDefs.assign(defs.begin(), defs.end());
50 } else {
51 // Otherwise, generate the defs that belong to the selected dialect.
52 auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
53 return def.getDialect().getName() == selectedDialect;
54 });
55 resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
59 //===----------------------------------------------------------------------===//
60 // DefGen
61 //===----------------------------------------------------------------------===//
63 namespace {
64 class DefGen {
65 public:
66 /// Create the attribute or type class.
67 DefGen(const AttrOrTypeDef &def);
69 void emitDecl(raw_ostream &os) const {
70 if (storageCls && def.genStorageClass()) {
71 NamespaceEmitter ns(os, def.getStorageNamespace());
72 os << "struct " << def.getStorageClassName() << ";\n";
74 defCls.writeDeclTo(os);
76 void emitDef(raw_ostream &os) const {
77 if (storageCls && def.genStorageClass()) {
78 NamespaceEmitter ns(os, def.getStorageNamespace());
79 storageCls->writeDeclTo(os); // everything is inline
81 defCls.writeDefTo(os);
84 private:
85 /// Add traits from the TableGen definition to the class.
86 void createParentWithTraits();
87 /// Emit top-level declarations: using declarations and any extra class
88 /// declarations.
89 void emitTopLevelDeclarations();
90 /// Emit the function that returns the type or attribute name.
91 void emitName();
92 /// Emit the dialect name as a static member variable.
93 void emitDialectName();
94 /// Emit attribute or type builders.
95 void emitBuilders();
96 /// Emit a verifier for the def.
97 void emitVerifier();
98 /// Emit parsers and printers.
99 void emitParserPrinter();
100 /// Emit parameter accessors, if required.
101 void emitAccessors();
102 /// Emit interface methods.
103 void emitInterfaceMethods();
105 //===--------------------------------------------------------------------===//
106 // Builder Emission
108 /// Emit the default builder `Attribute::get`
109 void emitDefaultBuilder();
110 /// Emit the checked builder `Attribute::getChecked`
111 void emitCheckedBuilder();
112 /// Emit a custom builder.
113 void emitCustomBuilder(const AttrOrTypeBuilder &builder);
114 /// Emit a checked custom builder.
115 void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
117 //===--------------------------------------------------------------------===//
118 // Interface Method Emission
120 /// Emit methods for a trait.
121 void emitTraitMethods(const InterfaceTrait &trait);
122 /// Emit a trait method.
123 void emitTraitMethod(const InterfaceMethod &method);
125 //===--------------------------------------------------------------------===//
126 // Storage Class Emission
127 void emitStorageClass();
128 /// Generate the storage class constructor.
129 void emitStorageConstructor();
130 /// Emit the key type `KeyTy`.
131 void emitKeyType();
132 /// Emit the equality comparison operator.
133 void emitEquals();
134 /// Emit the key hash function.
135 void emitHashKey();
136 /// Emit the function to construct the storage class.
137 void emitConstruct();
139 //===--------------------------------------------------------------------===//
140 // Utility Function Declarations
142 /// Get the method parameters for a def builder, where the first several
143 /// parameters may be different.
144 SmallVector<MethodParameter>
145 getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
147 //===--------------------------------------------------------------------===//
148 // Class fields
150 /// The attribute or type definition.
151 const AttrOrTypeDef &def;
152 /// The list of attribute or type parameters.
153 ArrayRef<AttrOrTypeParameter> params;
154 /// The attribute or type class.
155 Class defCls;
156 /// An optional attribute or type storage class. The storage class will
157 /// exist if and only if the def has more than zero parameters.
158 std::optional<Class> storageCls;
160 /// The C++ base value of the def, either "Attribute" or "Type".
161 StringRef valueType;
162 /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
163 StringRef defType;
165 } // namespace
167 DefGen::DefGen(const AttrOrTypeDef &def)
168 : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
169 valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
170 defType(isa<AttrDef>(def) ? "Attr" : "Type") {
171 // Check that all parameters have names.
172 for (const AttrOrTypeParameter &param : def.getParameters())
173 if (param.isAnonymous())
174 llvm::PrintFatalError("all parameters must have a name");
176 // If a storage class is needed, create one.
177 if (def.getNumParameters() > 0)
178 storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
180 // Create the parent class with any indicated traits.
181 createParentWithTraits();
182 // Emit top-level declarations.
183 emitTopLevelDeclarations();
184 // Emit builders for defs with parameters
185 if (storageCls)
186 emitBuilders();
187 // Emit the type name.
188 emitName();
189 // Emit the dialect name.
190 emitDialectName();
191 // Emit the verifier.
192 if (storageCls && def.genVerifyDecl())
193 emitVerifier();
194 // Emit the mnemonic, if there is one, and any associated parser and printer.
195 if (def.getMnemonic())
196 emitParserPrinter();
197 // Emit accessors
198 if (def.genAccessors())
199 emitAccessors();
200 // Emit trait interface methods
201 emitInterfaceMethods();
202 defCls.finalize();
203 // Emit a storage class if one is needed
204 if (storageCls && def.genStorageClass())
205 emitStorageClass();
208 void DefGen::createParentWithTraits() {
209 ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType));
210 defParent.addTemplateParam(def.getCppClassName());
211 defParent.addTemplateParam(def.getCppBaseClassName());
212 defParent.addTemplateParam(storageCls
213 ? strfmt("{0}::{1}", def.getStorageNamespace(),
214 def.getStorageClassName())
215 : strfmt("::mlir::{0}Storage", valueType));
216 for (auto &trait : def.getTraits()) {
217 defParent.addTemplateParam(
218 isa<NativeTrait>(&trait)
219 ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
220 : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
222 defCls.addParent(std::move(defParent));
225 /// Include declarations specified on NativeTrait
226 static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
227 SmallVector<StringRef> extraDeclarations;
228 // Include extra class declarations from NativeTrait
229 for (const auto &trait : def.getTraits()) {
230 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
231 StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
232 if (value.empty())
233 continue;
234 extraDeclarations.push_back(value);
237 if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
238 extraDeclarations.push_back(*extraDecl);
240 return llvm::join(extraDeclarations, "\n");
243 /// Extra class definitions have a `$cppClass` substitution that is to be
244 /// replaced by the C++ class name.
245 static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
246 SmallVector<StringRef> extraDefinitions;
247 // Include extra class definitions from NativeTrait
248 for (const auto &trait : def.getTraits()) {
249 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
250 StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
251 if (value.empty())
252 continue;
253 extraDefinitions.push_back(value);
256 if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
257 extraDefinitions.push_back(*extraDef);
259 FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
260 return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
263 void DefGen::emitTopLevelDeclarations() {
264 // Inherit constructors from the attribute or type class.
265 defCls.declare<VisibilityDeclaration>(Visibility::Public);
266 defCls.declare<UsingDeclaration>("Base::Base");
268 // Emit the extra declarations first in case there's a definition in there.
269 std::string extraDecl = formatExtraDeclarations(def);
270 std::string extraDef = formatExtraDefinitions(def);
271 defCls.declare<ExtraClassDeclaration>(std::move(extraDecl),
272 std::move(extraDef));
275 void DefGen::emitName() {
276 StringRef name;
277 if (auto *attrDef = dyn_cast<AttrDef>(&def)) {
278 name = attrDef->getAttrName();
279 } else {
280 auto *typeDef = cast<TypeDef>(&def);
281 name = typeDef->getTypeName();
283 std::string nameDecl =
284 strfmt("static constexpr ::llvm::StringLiteral name = \"{0}\";\n", name);
285 defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
288 void DefGen::emitDialectName() {
289 std::string decl =
290 strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
291 def.getDialect().getName());
292 defCls.declare<ExtraClassDeclaration>(std::move(decl));
295 void DefGen::emitBuilders() {
296 if (!def.skipDefaultBuilders()) {
297 emitDefaultBuilder();
298 if (def.genVerifyDecl())
299 emitCheckedBuilder();
301 for (auto &builder : def.getBuilders()) {
302 emitCustomBuilder(builder);
303 if (def.genVerifyDecl())
304 emitCheckedCustomBuilder(builder);
308 void DefGen::emitVerifier() {
309 defCls.declare<UsingDeclaration>("Base::getChecked");
310 defCls.declareStaticMethod(
311 "::llvm::LogicalResult", "verify",
312 getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
313 "emitError"}}));
316 void DefGen::emitParserPrinter() {
317 auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
318 "::llvm::StringLiteral", "getMnemonic");
319 mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
321 // Declare the parser and printer, if needed.
322 bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
323 if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
324 return;
326 // Declare the parser.
327 SmallVector<MethodParameter> parserParams;
328 parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
329 if (isa<AttrDef>(&def))
330 parserParams.emplace_back("::mlir::Type", "odsType");
331 auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
332 hasAssemblyFormat ? Method::Static
333 : Method::StaticDeclaration,
334 std::move(parserParams));
335 // Declare the printer.
336 auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
337 Method *printer =
338 defCls.addMethod("void", "print", props,
339 MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
340 // Emit the bodies if we are using the declarative format.
341 if (hasAssemblyFormat)
342 return generateAttrOrTypeFormat(def, parser->body(), printer->body());
345 void DefGen::emitAccessors() {
346 for (auto &param : params) {
347 Method *m = defCls.addMethod(
348 param.getCppAccessorType(), param.getAccessorName(),
349 def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
350 // Generate accessor definitions only if we also generate the storage
351 // class. Otherwise, let the user define the exact accessor definition.
352 if (!def.genStorageClass())
353 continue;
354 m->body().indent() << "return getImpl()->" << param.getName() << ";";
358 void DefGen::emitInterfaceMethods() {
359 for (auto &traitDef : def.getTraits())
360 if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef))
361 if (trait->shouldDeclareMethods())
362 emitTraitMethods(*trait);
365 //===----------------------------------------------------------------------===//
366 // Builder Emission
368 SmallVector<MethodParameter>
369 DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
370 SmallVector<MethodParameter> builderParams;
371 builderParams.append(prefix.begin(), prefix.end());
372 for (auto &param : params)
373 builderParams.emplace_back(param.getCppType(), param.getName());
374 return builderParams;
377 void DefGen::emitDefaultBuilder() {
378 Method *m = defCls.addStaticMethod(
379 def.getCppClassName(), "get",
380 getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
381 MethodBody &body = m->body().indent();
382 auto scope = body.scope("return Base::get(context", ");");
383 for (const auto &param : params)
384 body << ", std::move(" << param.getName() << ")";
387 void DefGen::emitCheckedBuilder() {
388 Method *m = defCls.addStaticMethod(
389 def.getCppClassName(), "getChecked",
390 getBuilderParams(
391 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
392 {"::mlir::MLIRContext *", "context"}}));
393 MethodBody &body = m->body().indent();
394 auto scope = body.scope("return Base::getChecked(emitError, context", ");");
395 for (const auto &param : params)
396 body << ", " << param.getName();
399 static SmallVector<MethodParameter>
400 getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
401 const AttrOrTypeBuilder &builder) {
402 auto params = builder.getParameters();
403 SmallVector<MethodParameter> builderParams;
404 builderParams.append(prefix.begin(), prefix.end());
405 if (!builder.hasInferredContextParameter())
406 builderParams.emplace_back("::mlir::MLIRContext *", "context");
407 for (auto &param : params) {
408 builderParams.emplace_back(param.getCppType(), *param.getName(),
409 param.getDefaultValue());
411 return builderParams;
414 void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
415 // Don't emit a body if there isn't one.
416 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
417 StringRef returnType = def.getCppClassName();
418 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
419 returnType = *builderReturnType;
420 Method *m = defCls.addMethod(returnType, "get", props,
421 getCustomBuilderParams({}, builder));
422 if (!builder.getBody())
423 return;
425 // Format the body and emit it.
426 FmtContext ctx;
427 ctx.addSubst("_get", "Base::get");
428 if (!builder.hasInferredContextParameter())
429 ctx.addSubst("_ctxt", "context");
430 std::string bodyStr = tgfmt(*builder.getBody(), &ctx);
431 m->body().indent().getStream().printReindented(bodyStr);
/// Replace every non-overlapping occurrence of `from` in `str` with `to` and
/// return the new string.
///
/// The search resumes right after each inserted copy of `to`, so replacement
/// text is never rescanned and the loop terminates even when `to` contains
/// `from`. An empty `from` leaves `str` unchanged.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
  // An empty pattern matches at every position and could never be consumed.
  if (from.empty())
    return str;

  size_t pos = 0;
  while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) {
    str.replace(pos, from.size(), to.data(), to.size());
    // Skip over the text just inserted so it is not matched again.
    pos += to.size();
  }
  return str;
}
442 void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
443 // Don't emit a body if there isn't one.
444 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
445 StringRef returnType = def.getCppClassName();
446 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
447 returnType = *builderReturnType;
448 Method *m = defCls.addMethod(
449 returnType, "getChecked", props,
450 getCustomBuilderParams(
451 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
452 builder));
453 if (!builder.getBody())
454 return;
456 // Format the body and emit it. Replace $_get(...) with
457 // Base::getChecked(emitError, ...)
458 FmtContext ctx;
459 if (!builder.hasInferredContextParameter())
460 ctx.addSubst("_ctxt", "context");
461 std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(",
462 "Base::getChecked(emitError, ");
463 bodyStr = tgfmt(bodyStr, &ctx);
464 m->body().indent().getStream().printReindented(bodyStr);
467 //===----------------------------------------------------------------------===//
468 // Interface Method Emission
470 void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
471 // Get the set of methods that should always be declared.
472 auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
473 StringSet<> alwaysDeclared;
474 alwaysDeclared.insert(alwaysDeclaredMethods.begin(),
475 alwaysDeclaredMethods.end());
477 Interface iface = trait.getInterface(); // causes strange bugs if elided
478 for (auto &method : iface.getMethods()) {
479 // Don't declare if the method has a body. Or if the method has a default
480 // implementation and the def didn't request that it always be declared.
481 if (method.getBody() || (method.getDefaultImplementation() &&
482 !alwaysDeclared.count(method.getName())))
483 continue;
484 emitTraitMethod(method);
488 void DefGen::emitTraitMethod(const InterfaceMethod &method) {
489 // All interface methods are declaration-only.
490 auto props =
491 method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
492 SmallVector<MethodParameter> params;
493 for (auto &param : method.getArguments())
494 params.emplace_back(param.type, param.name);
495 defCls.addMethod(method.getReturnType(), method.getName(), props,
496 std::move(params));
499 //===----------------------------------------------------------------------===//
500 // Storage Class Emission
502 void DefGen::emitStorageConstructor() {
503 Constructor *ctor =
504 storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
505 for (auto &param : params) {
506 std::string movedValue = ("std::move(" + param.getName() + ")").str();
507 ctor->addMemberInitializer(param.getName(), movedValue);
511 void DefGen::emitKeyType() {
512 std::string keyType("std::tuple<");
513 llvm::raw_string_ostream os(keyType);
514 llvm::interleaveComma(params, os,
515 [&](auto &param) { os << param.getCppType(); });
516 os << '>';
517 storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str()));
519 // Add a method to construct the key type from the storage.
520 Method *m = storageCls->addConstMethod<Method::Inline>("KeyTy", "getAsKey");
521 m->body().indent() << "return KeyTy(";
522 llvm::interleaveComma(params, m->body().indent(),
523 [&](auto &param) { m->body() << param.getName(); });
524 m->body() << ");";
527 void DefGen::emitEquals() {
528 Method *eq = storageCls->addConstMethod<Method::Inline>(
529 "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
530 auto &body = eq->body().indent();
531 auto scope = body.scope("return (", ");");
532 const auto eachFn = [&](auto it) {
533 FmtContext ctx({{"_lhs", it.value().getName()},
534 {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
535 body << tgfmt(it.value().getComparator(), &ctx);
537 llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
540 void DefGen::emitHashKey() {
541 Method *hash = storageCls->addStaticInlineMethod(
542 "::llvm::hash_code", "hashKey",
543 MethodParameter("const KeyTy &", "tblgenKey"));
544 auto &body = hash->body().indent();
545 auto scope = body.scope("return ::llvm::hash_combine(", ");");
546 llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) {
547 body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
551 void DefGen::emitConstruct() {
552 Method *construct = storageCls->addMethod<Method::Inline>(
553 strfmt("{0} *", def.getStorageClassName()), "construct",
554 def.hasStorageCustomConstructor() ? Method::StaticDeclaration
555 : Method::Static,
556 MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
557 "allocator"),
558 MethodParameter("KeyTy &&", "tblgenKey"));
559 if (!def.hasStorageCustomConstructor()) {
560 auto &body = construct->body().indent();
561 for (const auto &it : llvm::enumerate(params)) {
562 body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
563 it.value().getName(), it.index());
565 // Use the parameters' custom allocator code, if provided.
566 FmtContext ctx = FmtContext().addSubst("_allocator", "allocator");
567 for (auto &param : params) {
568 if (std::optional<StringRef> allocCode = param.getAllocator()) {
569 ctx.withSelf(param.getName()).addSubst("_dst", param.getName());
570 body << tgfmt(*allocCode, &ctx) << '\n';
573 auto scope =
574 body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
575 def.getStorageClassName()),
576 ");");
577 llvm::interleaveComma(params, body, [&](auto &param) {
578 body << "std::move(" << param.getName() << ")";
583 void DefGen::emitStorageClass() {
584 // Add the appropriate parent class.
585 storageCls->addParent(strfmt("::mlir::{0}Storage", valueType));
586 // Add the constructor.
587 emitStorageConstructor();
588 // Declare the key type.
589 emitKeyType();
590 // Add the comparison method.
591 emitEquals();
592 // Emit the key hash method.
593 emitHashKey();
594 // Emit the storage constructor. Just declare it if the user wants to define
595 // it themself.
596 emitConstruct();
597 // Emit the storage class members as public, at the very end of the struct.
598 storageCls->finalize();
599 for (auto &param : params)
600 storageCls->declare<Field>(param.getCppType(), param.getName());
603 //===----------------------------------------------------------------------===//
604 // DefGenerator
605 //===----------------------------------------------------------------------===//
607 namespace {
608 /// This struct is the base generator used when processing tablegen interfaces.
609 class DefGenerator {
610 public:
611 bool emitDecls(StringRef selectedDialect);
612 bool emitDefs(StringRef selectedDialect);
614 protected:
615 DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
616 StringRef defType, StringRef valueType, bool isAttrGenerator)
617 : defRecords(std::move(defs)), os(os), defType(defType),
618 valueType(valueType), isAttrGenerator(isAttrGenerator) {
619 // Sort by occurrence in file.
620 llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) {
621 return lhs->getID() < rhs->getID();
625 /// Emit the list of def type names.
626 void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
627 /// Emit the code to dispatch between different defs during parsing/printing.
628 void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
630 /// The set of def records to emit.
631 std::vector<llvm::Record *> defRecords;
632 /// The attribute or type class to emit.
633 /// The stream to emit to.
634 raw_ostream &os;
635 /// The prefix of the tablegen def name, e.g. Attr or Type.
636 StringRef defType;
637 /// The C++ base value type of the def, e.g. Attribute or Type.
638 StringRef valueType;
639 /// Flag indicating if this generator is for Attributes. False if the
640 /// generator is for types.
641 bool isAttrGenerator;
644 /// A specialized generator for AttrDefs.
645 struct AttrDefGenerator : public DefGenerator {
646 AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
647 : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
648 "Attr", "Attribute", /*isAttrGenerator=*/true) {}
650 /// A specialized generator for TypeDefs.
651 struct TypeDefGenerator : public DefGenerator {
652 TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
653 : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
654 "Type", "Type", /*isAttrGenerator=*/false) {}
656 } // namespace
658 //===----------------------------------------------------------------------===//
659 // GEN: Declarations
660 //===----------------------------------------------------------------------===//
662 /// Text written ahead of all the other generated declarations. It contains
663 /// the forward declarations of the `mlir` classes that are used later on.
664 static const char *const typeDefDeclHeader = R"(
665 namespace mlir {
666 class AsmParser;
667 class AsmPrinter;
668 } // namespace mlir
671 bool DefGenerator::emitDecls(StringRef selectedDialect) {
672 emitSourceFileHeader((defType + "Def Declarations").str(), os);
673 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
675 // Output the common "header".
676 os << typeDefDeclHeader;
678 SmallVector<AttrOrTypeDef, 16> defs;
679 collectAllDefs(selectedDialect, defRecords, defs);
680 if (defs.empty())
681 return false;
683 NamespaceEmitter nsEmitter(os, defs.front().getDialect());
685 // Declare all the def classes first (in case they reference each other).
686 for (const AttrOrTypeDef &def : defs)
687 os << "class " << def.getCppClassName() << ";\n";
689 // Emit the declarations.
690 for (const AttrOrTypeDef &def : defs)
691 DefGen(def).emitDecl(os);
693 // Emit the TypeID explicit specializations to have a single definition for
694 // each of these.
695 for (const AttrOrTypeDef &def : defs)
696 if (!def.getDialect().getCppNamespace().empty())
697 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
698 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
699 << ")\n";
701 return false;
704 //===----------------------------------------------------------------------===//
705 // GEN: Def List
706 //===----------------------------------------------------------------------===//
708 void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
709 IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
710 auto interleaveFn = [&](const AttrOrTypeDef &def) {
711 os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
713 llvm::interleave(defs, os, interleaveFn, ",\n");
714 os << "\n";
717 //===----------------------------------------------------------------------===//
718 // GEN: Definitions
719 //===----------------------------------------------------------------------===//
721 /// The code block for default attribute parser/printer dispatch boilerplate.
722 /// {0}: the dialect fully qualified class name.
723 /// {1}: the optional code for the dynamic attribute parser dispatch.
724 /// {2}: the optional code for the dynamic attribute printer dispatch.
/// Braces meant to appear literally in the output are written doubled (`{{`).
725 static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
726 /// Parse an attribute registered to this dialect.
727 ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
728 ::mlir::Type type) const {{
729 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
730 ::llvm::StringRef attrTag;
732 ::mlir::Attribute attr;
733 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
734 if (parseResult.has_value())
735 return attr;
738 parser.emitError(typeLoc) << "unknown attribute `"
739 << attrTag << "` in dialect `" << getNamespace() << "`";
740 return {{};
742 /// Print an attribute registered to this dialect.
743 void {0}::printAttribute(::mlir::Attribute attr,
744 ::mlir::DialectAsmPrinter &printer) const {{
745 if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
746 return;
751 /// The code block for dynamic attribute parser dispatch boilerplate.
/// It tries `parseOptionalDynamicAttr` on the tag and returns the parsed
/// attribute on success.
752 static const char *const dialectDynamicAttrParserDispatch = R"(
754 ::mlir::Attribute genAttr;
755 auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
756 if (parseResult.has_value()) {
757 if (::mlir::succeeded(parseResult.value()))
758 return genAttr;
759 return Attribute();
764 /// The code block for dynamic attribute printer dispatch boilerplate.
/// It returns early once `printIfDynamicAttr` succeeds.
765 static const char *const dialectDynamicAttrPrinterDispatch = R"(
766 if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
767 return;
770 /// The code block for default type parser/printer dispatch boilerplate.
771 /// {0}: the dialect fully qualified class name.
772 /// {1}: the optional code for the dynamic type parser dispatch.
773 /// {2}: the optional code for the dynamic type printer dispatch.
/// Braces meant to appear literally in the output are written doubled (`{{`).
774 static const char *const dialectDefaultTypePrinterParserDispatch = R"(
775 /// Parse a type registered to this dialect.
776 ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
777 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
778 ::llvm::StringRef mnemonic;
779 ::mlir::Type genType;
780 auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
781 if (parseResult.has_value())
782 return genType;
784 parser.emitError(typeLoc) << "unknown type `"
785 << mnemonic << "` in dialect `" << getNamespace() << "`";
786 return {{};
788 /// Print a type registered to this dialect.
789 void {0}::printType(::mlir::Type type,
790 ::mlir::DialectAsmPrinter &printer) const {{
791 if (::mlir::succeeded(generatedTypePrinter(type, printer)))
792 return;
797 /// The code block for dynamic type parser dispatch boilerplate.
/// It tries `parseOptionalDynamicType` on the mnemonic and returns the parsed
/// type on success.
798 static const char *const dialectDynamicTypeParserDispatch = R"(
800 auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
801 if (parseResult.has_value()) {
802 if (::mlir::succeeded(parseResult.value()))
803 return genType;
804 return ::mlir::Type();
809 /// The code block for dynamic type printer dispatch boilerplate.
/// It returns early once `printIfDynamicType` succeeds.
810 static const char *const dialectDynamicTypePrinterDispatch = R"(
811 if (::mlir::succeeded(printIfDynamicType(type, printer)))
812 return;
815 /// Emit the dialect printer/parser dispatcher. User's code should call these
816 /// functions from their dialect's print/parse methods.
///
/// Two static inline functions are written to `os`:
/// `generated<valueType>Parser` and `generated<valueType>Printer`. Nothing is
/// emitted when none of `defs` has a mnemonic.
817 void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
// Without any mnemonic there is nothing to dispatch on.
818 if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
819 return def.getMnemonic().has_value();
820 })) {
821 return;
823 // Declare the parser.
824 SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
825 {"::llvm::StringRef *", "mnemonic"}};
// Attribute parsers additionally receive the expected type.
826 if (isAttrGenerator)
827 params.emplace_back("::mlir::Type", "type");
828 params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
829 Method parse("::mlir::OptionalParseResult",
830 strfmt("generated{0}Parser", valueType), Method::StaticInline,
831 std::move(params));
832 // Declare the printer.
833 Method printer("::llvm::LogicalResult",
834 strfmt("generated{0}Printer", valueType), Method::StaticInline,
835 {{strfmt("::mlir::{0}", valueType), "def"},
836 {"::mlir::AsmPrinter &", "printer"}});
838 // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
839 // calling the def's parse function.
840 parse.body() << " return "
841 "::mlir::AsmParser::KeywordSwitch<::mlir::"
842 "OptionalParseResult>(parser)\n";
843 const char *const getValueForMnemonic =
844 R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
845 value = {0}::{1};
846 return ::mlir::success(!!value);
850 // The printer dispatch uses llvm::TypeSwitch to find and call the correct
851 // printer.
852 printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
853 << ", ::llvm::LogicalResult>(def)";
854 const char *const printValue = R"( .Case<{0}>([&](auto t) {{
855 printer << {0}::getMnemonic();{1}
856 return ::mlir::success();
// Add one parser case and one printer case per def that has a mnemonic.
859 for (auto &def : defs) {
860 if (!def.getMnemonic())
861 continue;
862 bool hasParserPrinterDecl =
863 def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
864 std::string defClass = strfmt(
865 "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());
867 // If the def has no parameters or parser code, invoke a normal `get`.
868 std::string parseOrGet =
869 hasParserPrinterDecl
870 ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
871 : "get(parser.getContext())";
872 parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);
874 // If the def has no parameters and no printer, just print the mnemonic.
875 StringRef printDef = "";
876 if (hasParserPrinterDecl)
877 printDef = "\nt.print(printer);";
878 printer.body() << llvm::formatv(printValue, defClass, printDef);
// Unknown keywords are reported back through `mnemonic`; unknown values make
// the printer fail.
880 parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
881 " *mnemonic = keyword;\n"
882 " return std::nullopt;\n"
883 " });";
884 printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
// Write both functions through an indenting stream wrapped around `os`.
886 raw_indented_ostream indentedOs(os);
887 parse.writeDeclTo(indentedOs);
888 printer.writeDeclTo(indentedOs);
891 bool DefGenerator::emitDefs(StringRef selectedDialect) {
892 emitSourceFileHeader((defType + "Def Definitions").str(), os);
894 SmallVector<AttrOrTypeDef, 16> defs;
895 collectAllDefs(selectedDialect, defRecords, defs);
896 if (defs.empty())
897 return false;
898 emitTypeDefList(defs);
900 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
901 emitParsePrintDispatch(defs);
902 for (const AttrOrTypeDef &def : defs) {
904 NamespaceEmitter ns(os, def.getDialect());
905 DefGen gen(def);
906 gen.emitDef(os);
908 // Emit the TypeID explicit specializations to have a single symbol def.
909 if (!def.getDialect().getCppNamespace().empty())
910 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
911 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
912 << ")\n";
915 Dialect firstDialect = defs.front().getDialect();
917 // Emit the default parser/printer for Attributes if the dialect asked for it.
918 if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
919 NamespaceEmitter nsEmitter(os, firstDialect);
920 if (firstDialect.isExtensible()) {
921 os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
922 firstDialect.getCppClassName(),
923 dialectDynamicAttrParserDispatch,
924 dialectDynamicAttrPrinterDispatch);
925 } else {
926 os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
927 firstDialect.getCppClassName(), "", "");
931 // Emit the default parser/printer for Types if the dialect asked for it.
932 if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
933 NamespaceEmitter nsEmitter(os, firstDialect);
934 if (firstDialect.isExtensible()) {
935 os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
936 firstDialect.getCppClassName(),
937 dialectDynamicTypeParserDispatch,
938 dialectDynamicTypePrinterDispatch);
939 } else {
940 os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
941 firstDialect.getCppClassName(), "", "");
945 return false;
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// AttrDef
955 static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
956 static llvm::cl::opt<std::string>
957 attrDialect("attrdefs-dialect",
958 llvm::cl::desc("Generate attributes for this dialect"),
959 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
961 static mlir::GenRegistration
962 genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
963 [](const llvm::RecordKeeper &records, raw_ostream &os) {
964 AttrDefGenerator generator(records, os);
965 return generator.emitDefs(attrDialect);
967 static mlir::GenRegistration
968 genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
969 [](const llvm::RecordKeeper &records, raw_ostream &os) {
970 AttrDefGenerator generator(records, os);
971 return generator.emitDecls(attrDialect);
//===----------------------------------------------------------------------===//
// TypeDef
977 static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
978 static llvm::cl::opt<std::string>
979 typeDialect("typedefs-dialect",
980 llvm::cl::desc("Generate types for this dialect"),
981 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
983 static mlir::GenRegistration
984 genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
985 [](const llvm::RecordKeeper &records, raw_ostream &os) {
986 TypeDefGenerator generator(records, os);
987 return generator.emitDefs(typeDialect);
989 static mlir::GenRegistration
990 genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
991 [](const llvm::RecordKeeper &records, raw_ostream &os) {
992 TypeDefGenerator generator(records, os);
993 return generator.emitDecls(typeDialect);