1 //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "AttrOrTypeFormatGen.h"
10 #include "mlir/TableGen/AttrOrTypeDef.h"
11 #include "mlir/TableGen/Class.h"
12 #include "mlir/TableGen/CodeGenHelpers.h"
13 #include "mlir/TableGen/Format.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Interfaces.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/TableGenBackend.h"
21 #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
24 using namespace mlir::tblgen
;
26 //===----------------------------------------------------------------------===//
28 //===----------------------------------------------------------------------===//
30 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
31 /// specified and can only find one dialect's defs, use that.
32 static void collectAllDefs(StringRef selectedDialect
,
33 std::vector
<llvm::Record
*> records
,
34 SmallVectorImpl
<AttrOrTypeDef
> &resultDefs
) {
35 // Nothing to do if no defs were found.
39 auto defs
= llvm::map_range(
40 records
, [&](const llvm::Record
*rec
) { return AttrOrTypeDef(rec
); });
41 if (selectedDialect
.empty()) {
42 // If a dialect was not specified, ensure that all found defs belong to the
44 if (!llvm::all_equal(llvm::map_range(
45 defs
, [](const auto &def
) { return def
.getDialect(); }))) {
46 llvm::PrintFatalError("defs belonging to more than one dialect. Must "
47 "select one via '--(attr|type)defs-dialect'");
49 resultDefs
.assign(defs
.begin(), defs
.end());
51 // Otherwise, generate the defs that belong to the selected dialect.
52 auto dialectDefs
= llvm::make_filter_range(defs
, [&](const auto &def
) {
53 return def
.getDialect().getName().equals(selectedDialect
);
55 resultDefs
.assign(dialectDefs
.begin(), dialectDefs
.end());
59 //===----------------------------------------------------------------------===//
61 //===----------------------------------------------------------------------===//
66 /// Create the attribute or type class.
67 DefGen(const AttrOrTypeDef
&def
);
69 void emitDecl(raw_ostream
&os
) const {
71 NamespaceEmitter
ns(os
, def
.getStorageNamespace());
72 os
<< "struct " << def
.getStorageClassName() << ";\n";
74 defCls
.writeDeclTo(os
);
76 void emitDef(raw_ostream
&os
) const {
77 if (storageCls
&& def
.genStorageClass()) {
78 NamespaceEmitter
ns(os
, def
.getStorageNamespace());
79 storageCls
->writeDeclTo(os
); // everything is inline
81 defCls
.writeDefTo(os
);
85 /// Add traits from the TableGen definition to the class.
86 void createParentWithTraits();
87 /// Emit top-level declarations: using declarations and any extra class
89 void emitTopLevelDeclarations();
90 /// Emit attribute or type builders.
92 /// Emit a verifier for the def.
94 /// Emit parsers and printers.
95 void emitParserPrinter();
96 /// Emit parameter accessors, if required.
98 /// Emit interface methods.
99 void emitInterfaceMethods();
101 //===--------------------------------------------------------------------===//
104 /// Emit the default builder `Attribute::get`
105 void emitDefaultBuilder();
106 /// Emit the checked builder `Attribute::getChecked`
107 void emitCheckedBuilder();
108 /// Emit a custom builder.
109 void emitCustomBuilder(const AttrOrTypeBuilder
&builder
);
110 /// Emit a checked custom builder.
111 void emitCheckedCustomBuilder(const AttrOrTypeBuilder
&builder
);
113 //===--------------------------------------------------------------------===//
114 // Interface Method Emission
116 /// Emit methods for a trait.
117 void emitTraitMethods(const InterfaceTrait
&trait
);
118 /// Emit a trait method.
119 void emitTraitMethod(const InterfaceMethod
&method
);
121 //===--------------------------------------------------------------------===//
122 // Storage Class Emission
123 void emitStorageClass();
124 /// Generate the storage class constructor.
125 void emitStorageConstructor();
126 /// Emit the key type `KeyTy`.
128 /// Emit the equality comparison operator.
130 /// Emit the key hash function.
132 /// Emit the function to construct the storage class.
133 void emitConstruct();
135 //===--------------------------------------------------------------------===//
136 // Utility Function Declarations
138 /// Get the method parameters for a def builder, where the first several
139 /// parameters may be different.
140 SmallVector
<MethodParameter
>
141 getBuilderParams(std::initializer_list
<MethodParameter
> prefix
) const;
143 //===--------------------------------------------------------------------===//
146 /// The attribute or type definition.
147 const AttrOrTypeDef
&def
;
148 /// The list of attribute or type parameters.
149 ArrayRef
<AttrOrTypeParameter
> params
;
150 /// The attribute or type class.
152 /// An optional attribute or type storage class. The storage class will
153 /// exist if and only if the def has more than zero parameters.
154 std::optional
<Class
> storageCls
;
156 /// The C++ base value of the def, either "Attribute" or "Type".
158 /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
163 DefGen::DefGen(const AttrOrTypeDef
&def
)
164 : def(def
), params(def
.getParameters()), defCls(def
.getCppClassName()),
165 valueType(isa
<AttrDef
>(def
) ? "Attribute" : "Type"),
166 defType(isa
<AttrDef
>(def
) ? "Attr" : "Type") {
167 // Check that all parameters have names.
168 for (const AttrOrTypeParameter
¶m
: def
.getParameters())
169 if (param
.isAnonymous())
170 llvm::PrintFatalError("all parameters must have a name");
172 // If a storage class is needed, create one.
173 if (def
.getNumParameters() > 0)
174 storageCls
.emplace(def
.getStorageClassName(), /*isStruct=*/true);
176 // Create the parent class with any indicated traits.
177 createParentWithTraits();
178 // Emit top-level declarations.
179 emitTopLevelDeclarations();
180 // Emit builders for defs with parameters
183 // Emit the verifier.
184 if (storageCls
&& def
.genVerifyDecl())
186 // Emit the mnemonic, if there is one, and any associated parser and printer.
187 if (def
.getMnemonic())
190 if (def
.genAccessors())
192 // Emit trait interface methods
193 emitInterfaceMethods();
195 // Emit a storage class if one is needed
196 if (storageCls
&& def
.genStorageClass())
200 void DefGen::createParentWithTraits() {
201 ParentClass
defParent(strfmt("::mlir::{0}::{1}Base", valueType
, defType
));
202 defParent
.addTemplateParam(def
.getCppClassName());
203 defParent
.addTemplateParam(def
.getCppBaseClassName());
204 defParent
.addTemplateParam(storageCls
205 ? strfmt("{0}::{1}", def
.getStorageNamespace(),
206 def
.getStorageClassName())
207 : strfmt("::mlir::{0}Storage", valueType
));
208 for (auto &trait
: def
.getTraits()) {
209 defParent
.addTemplateParam(
210 isa
<NativeTrait
>(&trait
)
211 ? cast
<NativeTrait
>(&trait
)->getFullyQualifiedTraitName()
212 : cast
<InterfaceTrait
>(&trait
)->getFullyQualifiedTraitName());
214 defCls
.addParent(std::move(defParent
));
217 /// Extra class definitions have a `$cppClass` substitution that is to be
218 /// replaced by the C++ class name.
219 static std::string
formatExtraDefinitions(const AttrOrTypeDef
&def
) {
220 if (std::optional
<StringRef
> extraDef
= def
.getExtraDefs()) {
221 FmtContext ctx
= FmtContext().addSubst("cppClass", def
.getCppClassName());
222 return tgfmt(*extraDef
, &ctx
).str();
227 void DefGen::emitTopLevelDeclarations() {
228 // Inherit constructors from the attribute or type class.
229 defCls
.declare
<VisibilityDeclaration
>(Visibility::Public
);
230 defCls
.declare
<UsingDeclaration
>("Base::Base");
232 // Emit the extra declarations first in case there's a definition in there.
233 std::optional
<StringRef
> extraDecl
= def
.getExtraDecls();
234 std::string extraDef
= formatExtraDefinitions(def
);
235 defCls
.declare
<ExtraClassDeclaration
>(extraDecl
? *extraDecl
: "",
236 std::move(extraDef
));
239 void DefGen::emitBuilders() {
240 if (!def
.skipDefaultBuilders()) {
241 emitDefaultBuilder();
242 if (def
.genVerifyDecl())
243 emitCheckedBuilder();
245 for (auto &builder
: def
.getBuilders()) {
246 emitCustomBuilder(builder
);
247 if (def
.genVerifyDecl())
248 emitCheckedCustomBuilder(builder
);
252 void DefGen::emitVerifier() {
253 defCls
.declare
<UsingDeclaration
>("Base::getChecked");
254 defCls
.declareStaticMethod(
255 "::mlir::LogicalResult", "verify",
256 getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
260 void DefGen::emitParserPrinter() {
261 auto *mnemonic
= defCls
.addStaticMethod
<Method::Constexpr
>(
262 "::llvm::StringLiteral", "getMnemonic");
263 mnemonic
->body().indent() << strfmt("return {\"{0}\"};", *def
.getMnemonic());
265 // Declare the parser and printer, if needed.
266 bool hasAssemblyFormat
= def
.getAssemblyFormat().has_value();
267 if (!def
.hasCustomAssemblyFormat() && !hasAssemblyFormat
)
270 // Declare the parser.
271 SmallVector
<MethodParameter
> parserParams
;
272 parserParams
.emplace_back("::mlir::AsmParser &", "odsParser");
273 if (isa
<AttrDef
>(&def
))
274 parserParams
.emplace_back("::mlir::Type", "odsType");
275 auto *parser
= defCls
.addMethod(strfmt("::mlir::{0}", valueType
), "parse",
276 hasAssemblyFormat
? Method::Static
277 : Method::StaticDeclaration
,
278 std::move(parserParams
));
279 // Declare the printer.
280 auto props
= hasAssemblyFormat
? Method::Const
: Method::ConstDeclaration
;
282 defCls
.addMethod("void", "print", props
,
283 MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
284 // Emit the bodies if we are using the declarative format.
285 if (hasAssemblyFormat
)
286 return generateAttrOrTypeFormat(def
, parser
->body(), printer
->body());
289 void DefGen::emitAccessors() {
290 for (auto ¶m
: params
) {
291 Method
*m
= defCls
.addMethod(
292 param
.getCppAccessorType(), param
.getAccessorName(),
293 def
.genStorageClass() ? Method::Const
: Method::ConstDeclaration
);
294 // Generate accessor definitions only if we also generate the storage
295 // class. Otherwise, let the user define the exact accessor definition.
296 if (!def
.genStorageClass())
298 m
->body().indent() << "return getImpl()->" << param
.getName() << ";";
302 void DefGen::emitInterfaceMethods() {
303 for (auto &traitDef
: def
.getTraits())
304 if (auto *trait
= dyn_cast
<InterfaceTrait
>(&traitDef
))
305 if (trait
->shouldDeclareMethods())
306 emitTraitMethods(*trait
);
309 //===----------------------------------------------------------------------===//
312 SmallVector
<MethodParameter
>
313 DefGen::getBuilderParams(std::initializer_list
<MethodParameter
> prefix
) const {
314 SmallVector
<MethodParameter
> builderParams
;
315 builderParams
.append(prefix
.begin(), prefix
.end());
316 for (auto ¶m
: params
)
317 builderParams
.emplace_back(param
.getCppType(), param
.getName());
318 return builderParams
;
321 void DefGen::emitDefaultBuilder() {
322 Method
*m
= defCls
.addStaticMethod(
323 def
.getCppClassName(), "get",
324 getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
325 MethodBody
&body
= m
->body().indent();
326 auto scope
= body
.scope("return Base::get(context", ");");
327 for (const auto ¶m
: params
)
328 body
<< ", " << param
.getName();
331 void DefGen::emitCheckedBuilder() {
332 Method
*m
= defCls
.addStaticMethod(
333 def
.getCppClassName(), "getChecked",
335 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
336 {"::mlir::MLIRContext *", "context"}}));
337 MethodBody
&body
= m
->body().indent();
338 auto scope
= body
.scope("return Base::getChecked(emitError, context", ");");
339 for (const auto ¶m
: params
)
340 body
<< ", " << param
.getName();
343 static SmallVector
<MethodParameter
>
344 getCustomBuilderParams(std::initializer_list
<MethodParameter
> prefix
,
345 const AttrOrTypeBuilder
&builder
) {
346 auto params
= builder
.getParameters();
347 SmallVector
<MethodParameter
> builderParams
;
348 builderParams
.append(prefix
.begin(), prefix
.end());
349 if (!builder
.hasInferredContextParameter())
350 builderParams
.emplace_back("::mlir::MLIRContext *", "context");
351 for (auto ¶m
: params
) {
352 builderParams
.emplace_back(param
.getCppType(), *param
.getName(),
353 param
.getDefaultValue());
355 return builderParams
;
358 void DefGen::emitCustomBuilder(const AttrOrTypeBuilder
&builder
) {
359 // Don't emit a body if there isn't one.
360 auto props
= builder
.getBody() ? Method::Static
: Method::StaticDeclaration
;
361 StringRef returnType
= def
.getCppClassName();
362 if (std::optional
<StringRef
> builderReturnType
= builder
.getReturnType())
363 returnType
= *builderReturnType
;
364 Method
*m
= defCls
.addMethod(returnType
, "get", props
,
365 getCustomBuilderParams({}, builder
));
366 if (!builder
.getBody())
369 // Format the body and emit it.
371 ctx
.addSubst("_get", "Base::get");
372 if (!builder
.hasInferredContextParameter())
373 ctx
.addSubst("_ctxt", "context");
374 std::string bodyStr
= tgfmt(*builder
.getBody(), &ctx
);
375 m
->body().indent().getStream().printReindented(bodyStr
);
/// Replace all instances of `from` with `to` in `str` and return the new
/// string.
///
/// Occurrences are matched left to right and do not overlap. Text inserted by
/// a replacement is never rescanned, so the function terminates even when
/// `to` itself contains `from`. An empty `from` matches nothing and `str` is
/// returned unchanged.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
  // An empty pattern would match at every position and never make progress.
  if (from.empty())
    return str;

  size_t pos = 0;
  while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) {
    str.replace(pos, from.size(), to.data(), to.size());
    // Resume the search after the inserted text.
    pos += to.size();
  }
  return str;
}
386 void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder
&builder
) {
387 // Don't emit a body if there isn't one.
388 auto props
= builder
.getBody() ? Method::Static
: Method::StaticDeclaration
;
389 StringRef returnType
= def
.getCppClassName();
390 if (std::optional
<StringRef
> builderReturnType
= builder
.getReturnType())
391 returnType
= *builderReturnType
;
392 Method
*m
= defCls
.addMethod(
393 returnType
, "getChecked", props
,
394 getCustomBuilderParams(
395 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
397 if (!builder
.getBody())
400 // Format the body and emit it. Replace $_get(...) with
401 // Base::getChecked(emitError, ...)
403 if (!builder
.hasInferredContextParameter())
404 ctx
.addSubst("_ctxt", "context");
405 std::string bodyStr
= replaceInStr(builder
.getBody()->str(), "$_get(",
406 "Base::getChecked(emitError, ");
407 bodyStr
= tgfmt(bodyStr
, &ctx
);
408 m
->body().indent().getStream().printReindented(bodyStr
);
411 //===----------------------------------------------------------------------===//
412 // Interface Method Emission
414 void DefGen::emitTraitMethods(const InterfaceTrait
&trait
) {
415 // Get the set of methods that should always be declared.
416 auto alwaysDeclaredMethods
= trait
.getAlwaysDeclaredMethods();
417 StringSet
<> alwaysDeclared
;
418 alwaysDeclared
.insert(alwaysDeclaredMethods
.begin(),
419 alwaysDeclaredMethods
.end());
421 Interface iface
= trait
.getInterface(); // causes strange bugs if elided
422 for (auto &method
: iface
.getMethods()) {
423 // Don't declare if the method has a body. Or if the method has a default
424 // implementation and the def didn't request that it always be declared.
425 if (method
.getBody() || (method
.getDefaultImplementation() &&
426 !alwaysDeclared
.count(method
.getName())))
428 emitTraitMethod(method
);
432 void DefGen::emitTraitMethod(const InterfaceMethod
&method
) {
433 // All interface methods are declaration-only.
435 method
.isStatic() ? Method::StaticDeclaration
: Method::ConstDeclaration
;
436 SmallVector
<MethodParameter
> params
;
437 for (auto ¶m
: method
.getArguments())
438 params
.emplace_back(param
.type
, param
.name
);
439 defCls
.addMethod(method
.getReturnType(), method
.getName(), props
,
443 //===----------------------------------------------------------------------===//
444 // Storage Class Emission
446 void DefGen::emitStorageConstructor() {
448 storageCls
->addConstructor
<Method::Inline
>(getBuilderParams({}));
449 for (auto ¶m
: params
)
450 ctor
->addMemberInitializer(param
.getName(), param
.getName());
453 void DefGen::emitKeyType() {
454 std::string
keyType("std::tuple<");
455 llvm::raw_string_ostream
os(keyType
);
456 llvm::interleaveComma(params
, os
,
457 [&](auto ¶m
) { os
<< param
.getCppType(); });
459 storageCls
->declare
<UsingDeclaration
>("KeyTy", std::move(os
.str()));
461 // Add a method to construct the key type from the storage.
462 Method
*m
= storageCls
->addConstMethod
<Method::Inline
>("KeyTy", "getAsKey");
463 m
->body().indent() << "return KeyTy(";
464 llvm::interleaveComma(params
, m
->body().indent(),
465 [&](auto ¶m
) { m
->body() << param
.getName(); });
469 void DefGen::emitEquals() {
470 Method
*eq
= storageCls
->addConstMethod
<Method::Inline
>(
471 "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
472 auto &body
= eq
->body().indent();
473 auto scope
= body
.scope("return (", ");");
474 const auto eachFn
= [&](auto it
) {
475 FmtContext
ctx({{"_lhs", it
.value().getName()},
476 {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it
.index())}});
477 body
<< tgfmt(it
.value().getComparator(), &ctx
);
479 llvm::interleave(llvm::enumerate(params
), body
, eachFn
, ") && (");
482 void DefGen::emitHashKey() {
483 Method
*hash
= storageCls
->addStaticInlineMethod(
484 "::llvm::hash_code", "hashKey",
485 MethodParameter("const KeyTy &", "tblgenKey"));
486 auto &body
= hash
->body().indent();
487 auto scope
= body
.scope("return ::llvm::hash_combine(", ");");
488 llvm::interleaveComma(llvm::enumerate(params
), body
, [&](auto it
) {
489 body
<< llvm::formatv("std::get<{0}>(tblgenKey)", it
.index());
493 void DefGen::emitConstruct() {
494 Method
*construct
= storageCls
->addMethod
<Method::Inline
>(
495 strfmt("{0} *", def
.getStorageClassName()), "construct",
496 def
.hasStorageCustomConstructor() ? Method::StaticDeclaration
498 MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType
),
500 MethodParameter("const KeyTy &", "tblgenKey"));
501 if (!def
.hasStorageCustomConstructor()) {
502 auto &body
= construct
->body().indent();
503 for (const auto &it
: llvm::enumerate(params
)) {
504 body
<< formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
505 it
.value().getName(), it
.index());
507 // Use the parameters' custom allocator code, if provided.
508 FmtContext ctx
= FmtContext().addSubst("_allocator", "allocator");
509 for (auto ¶m
: params
) {
510 if (std::optional
<StringRef
> allocCode
= param
.getAllocator()) {
511 ctx
.withSelf(param
.getName()).addSubst("_dst", param
.getName());
512 body
<< tgfmt(*allocCode
, &ctx
) << '\n';
516 body
.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
517 def
.getStorageClassName()),
519 llvm::interleaveComma(params
, body
,
520 [&](auto ¶m
) { body
<< param
.getName(); });
524 void DefGen::emitStorageClass() {
525 // Add the appropriate parent class.
526 storageCls
->addParent(strfmt("::mlir::{0}Storage", valueType
));
527 // Add the constructor.
528 emitStorageConstructor();
529 // Declare the key type.
531 // Add the comparison method.
533 // Emit the key hash method.
535 // Emit the storage constructor. Just declare it if the user wants to define
538 // Emit the storage class members as public, at the very end of the struct.
539 storageCls
->finalize();
540 for (auto ¶m
: params
)
541 storageCls
->declare
<Field
>(param
.getCppType(), param
.getName());
544 //===----------------------------------------------------------------------===//
546 //===----------------------------------------------------------------------===//
549 /// This struct is the base generator used when processing tablegen interfaces.
552 bool emitDecls(StringRef selectedDialect
);
553 bool emitDefs(StringRef selectedDialect
);
556 DefGenerator(std::vector
<llvm::Record
*> &&defs
, raw_ostream
&os
,
557 StringRef defType
, StringRef valueType
, bool isAttrGenerator
)
558 : defRecords(std::move(defs
)), os(os
), defType(defType
),
559 valueType(valueType
), isAttrGenerator(isAttrGenerator
) {}
561 /// Emit the list of def type names.
562 void emitTypeDefList(ArrayRef
<AttrOrTypeDef
> defs
);
563 /// Emit the code to dispatch between different defs during parsing/printing.
564 void emitParsePrintDispatch(ArrayRef
<AttrOrTypeDef
> defs
);
566 /// The set of def records to emit.
567 std::vector
<llvm::Record
*> defRecords
;
568 /// The attribute or type class to emit.
569 /// The stream to emit to.
571 /// The prefix of the tablegen def name, e.g. Attr or Type.
573 /// The C++ base value type of the def, e.g. Attribute or Type.
575 /// Flag indicating if this generator is for Attributes. False if the
576 /// generator is for types.
577 bool isAttrGenerator
;
580 /// A specialized generator for AttrDefs.
581 struct AttrDefGenerator
: public DefGenerator
{
582 AttrDefGenerator(const llvm::RecordKeeper
&records
, raw_ostream
&os
)
583 : DefGenerator(records
.getAllDerivedDefinitionsIfDefined("AttrDef"), os
,
584 "Attr", "Attribute", /*isAttrGenerator=*/true) {}
586 /// A specialized generator for TypeDefs.
587 struct TypeDefGenerator
: public DefGenerator
{
588 TypeDefGenerator(const llvm::RecordKeeper
&records
, raw_ostream
&os
)
589 : DefGenerator(records
.getAllDerivedDefinitionsIfDefined("TypeDef"), os
,
590 "Type", "Type", /*isAttrGenerator=*/false) {}
594 //===----------------------------------------------------------------------===//
596 //===----------------------------------------------------------------------===//
598 /// Print this above all the other declarations. Contains type declarations used
600 static const char *const typeDefDeclHeader
= R
"(
607 bool DefGenerator::emitDecls(StringRef selectedDialect
) {
608 emitSourceFileHeader((defType
+ "Def Declarations").str(), os
);
609 IfDefScope
scope("GET_" + defType
.upper() + "DEF_CLASSES", os
);
611 // Output the common "header".
612 os
<< typeDefDeclHeader
;
614 SmallVector
<AttrOrTypeDef
, 16> defs
;
615 collectAllDefs(selectedDialect
, defRecords
, defs
);
619 NamespaceEmitter
nsEmitter(os
, defs
.front().getDialect());
621 // Declare all the def classes first (in case they reference each other).
622 for (const AttrOrTypeDef
&def
: defs
)
623 os
<< "class " << def
.getCppClassName() << ";\n";
625 // Emit the declarations.
626 for (const AttrOrTypeDef
&def
: defs
)
627 DefGen(def
).emitDecl(os
);
629 // Emit the TypeID explicit specializations to have a single definition for
631 for (const AttrOrTypeDef
&def
: defs
)
632 if (!def
.getDialect().getCppNamespace().empty())
633 os
<< "MLIR_DECLARE_EXPLICIT_TYPE_ID("
634 << def
.getDialect().getCppNamespace() << "::" << def
.getCppClassName()
640 //===----------------------------------------------------------------------===//
642 //===----------------------------------------------------------------------===//
644 void DefGenerator::emitTypeDefList(ArrayRef
<AttrOrTypeDef
> defs
) {
645 IfDefScope
scope("GET_" + defType
.upper() + "DEF_LIST", os
);
646 auto interleaveFn
= [&](const AttrOrTypeDef
&def
) {
647 os
<< def
.getDialect().getCppNamespace() << "::" << def
.getCppClassName();
649 llvm::interleave(defs
, os
, interleaveFn
, ",\n");
653 //===----------------------------------------------------------------------===//
655 //===----------------------------------------------------------------------===//
657 /// The code block for default attribute parser/printer dispatch boilerplate.
658 /// {0}: the dialect fully qualified class name.
659 /// {1}: the optional code for the dynamic attribute parser dispatch.
660 /// {2}: the optional code for the dynamic attribute printer dispatch.
661 static const char *const dialectDefaultAttrPrinterParserDispatch
= R
"(
662 /// Parse an attribute registered to this dialect.
663 ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
664 ::mlir::Type type) const {{
665 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
666 ::llvm::StringRef attrTag;
668 ::mlir::Attribute attr;
669 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
670 if (parseResult.has_value())
674 parser.emitError(typeLoc) << "unknown attribute `
"
675 << attrTag << "` in dialect `
" << getNamespace() << "`
";
678 /// Print an attribute registered to this dialect.
679 void {0}::printAttribute(::mlir::Attribute attr,
680 ::mlir::DialectAsmPrinter &printer) const {{
681 if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
687 /// The code block for dynamic attribute parser dispatch boilerplate.
688 static const char *const dialectDynamicAttrParserDispatch
= R
"(
690 ::mlir::Attribute genAttr;
691 auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
692 if (parseResult.has_value()) {
693 if (::mlir::succeeded(parseResult.value()))
/// The code block for dynamic attribute printer dispatch boilerplate.
701 static const char *const dialectDynamicAttrPrinterDispatch
= R
"(
702 if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
706 /// The code block for default type parser/printer dispatch boilerplate.
707 /// {0}: the dialect fully qualified class name.
708 /// {1}: the optional code for the dynamic type parser dispatch.
709 /// {2}: the optional code for the dynamic type printer dispatch.
710 static const char *const dialectDefaultTypePrinterParserDispatch
= R
"(
711 /// Parse a type registered to this dialect.
712 ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
713 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
714 ::llvm::StringRef mnemonic;
715 ::mlir::Type genType;
716 auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
717 if (parseResult.has_value())
720 parser.emitError(typeLoc) << "unknown type `
"
721 << mnemonic << "` in dialect `
" << getNamespace() << "`
";
724 /// Print a type registered to this dialect.
725 void {0}::printType(::mlir::Type type,
726 ::mlir::DialectAsmPrinter &printer) const {{
727 if (::mlir::succeeded(generatedTypePrinter(type, printer)))
733 /// The code block for dynamic type parser dispatch boilerplate.
734 static const char *const dialectDynamicTypeParserDispatch
= R
"(
736 auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
737 if (parseResult.hasValue()) {
738 if (::mlir::succeeded(parseResult.getValue()))
745 /// The code block for dynamic type printer dispatch boilerplate.
746 static const char *const dialectDynamicTypePrinterDispatch
= R
"(
747 if (::mlir::succeeded(printIfDynamicType(type, printer)))
751 /// Emit the dialect printer/parser dispatcher. User's code should call these
752 /// functions from their dialect's print/parse methods.
753 void DefGenerator::emitParsePrintDispatch(ArrayRef
<AttrOrTypeDef
> defs
) {
754 if (llvm::none_of(defs
, [](const AttrOrTypeDef
&def
) {
755 return def
.getMnemonic().has_value();
759 // Declare the parser.
760 SmallVector
<MethodParameter
> params
= {{"::mlir::AsmParser &", "parser"},
761 {"::llvm::StringRef *", "mnemonic"}};
763 params
.emplace_back("::mlir::Type", "type");
764 params
.emplace_back(strfmt("::mlir::{0} &", valueType
), "value");
765 Method
parse("::mlir::OptionalParseResult",
766 strfmt("generated{0}Parser", valueType
), Method::StaticInline
,
768 // Declare the printer.
769 Method
printer("::mlir::LogicalResult",
770 strfmt("generated{0}Printer", valueType
), Method::StaticInline
,
771 {{strfmt("::mlir::{0}", valueType
), "def"},
772 {"::mlir::AsmPrinter &", "printer"}});
774 // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
775 // calling the def's parse function.
776 parse
.body() << " return "
777 "::mlir::AsmParser::KeywordSwitch<::mlir::"
778 "OptionalParseResult>(parser)\n";
779 const char *const getValueForMnemonic
=
780 R
"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
782 return ::mlir::success(!!value);
786 // The printer dispatch uses llvm::TypeSwitch to find and call the correct
788 printer
.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
789 << ", ::mlir::LogicalResult>(def)";
790 const char *const printValue
= R
"( .Case<{0}>([&](auto t) {{
791 printer << {0}::getMnemonic();{1}
792 return ::mlir::success();
795 for (auto &def
: defs
) {
796 if (!def
.getMnemonic())
798 bool hasParserPrinterDecl
=
799 def
.hasCustomAssemblyFormat() || def
.getAssemblyFormat();
800 std::string defClass
= strfmt(
801 "{0}::{1}", def
.getDialect().getCppNamespace(), def
.getCppClassName());
803 // If the def has no parameters or parser code, invoke a normal `get`.
804 std::string parseOrGet
=
806 ? strfmt("parse(parser{0})", isAttrGenerator
? ", type" : "")
807 : "get(parser.getContext())";
808 parse
.body() << llvm::formatv(getValueForMnemonic
, defClass
, parseOrGet
);
810 // If the def has no parameters and no printer, just print the mnemonic.
811 StringRef printDef
= "";
812 if (hasParserPrinterDecl
)
813 printDef
= "\nt.print(printer);";
814 printer
.body() << llvm::formatv(printValue
, defClass
, printDef
);
816 parse
.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
817 " *mnemonic = keyword;\n"
818 " return std::nullopt;\n"
820 printer
.body() << " .Default([](auto) { return ::mlir::failure(); });";
822 raw_indented_ostream
indentedOs(os
);
823 parse
.writeDeclTo(indentedOs
);
824 printer
.writeDeclTo(indentedOs
);
827 bool DefGenerator::emitDefs(StringRef selectedDialect
) {
828 emitSourceFileHeader((defType
+ "Def Definitions").str(), os
);
830 SmallVector
<AttrOrTypeDef
, 16> defs
;
831 collectAllDefs(selectedDialect
, defRecords
, defs
);
834 emitTypeDefList(defs
);
836 IfDefScope
scope("GET_" + defType
.upper() + "DEF_CLASSES", os
);
837 emitParsePrintDispatch(defs
);
838 for (const AttrOrTypeDef
&def
: defs
) {
840 NamespaceEmitter
ns(os
, def
.getDialect());
844 // Emit the TypeID explicit specializations to have a single symbol def.
845 if (!def
.getDialect().getCppNamespace().empty())
846 os
<< "MLIR_DEFINE_EXPLICIT_TYPE_ID("
847 << def
.getDialect().getCppNamespace() << "::" << def
.getCppClassName()
851 Dialect firstDialect
= defs
.front().getDialect();
853 // Emit the default parser/printer for Attributes if the dialect asked for it.
854 if (isAttrGenerator
&& firstDialect
.useDefaultAttributePrinterParser()) {
855 NamespaceEmitter
nsEmitter(os
, firstDialect
);
856 if (firstDialect
.isExtensible()) {
857 os
<< llvm::formatv(dialectDefaultAttrPrinterParserDispatch
,
858 firstDialect
.getCppClassName(),
859 dialectDynamicAttrParserDispatch
,
860 dialectDynamicAttrPrinterDispatch
);
862 os
<< llvm::formatv(dialectDefaultAttrPrinterParserDispatch
,
863 firstDialect
.getCppClassName(), "", "");
867 // Emit the default parser/printer for Types if the dialect asked for it.
868 if (!isAttrGenerator
&& firstDialect
.useDefaultTypePrinterParser()) {
869 NamespaceEmitter
nsEmitter(os
, firstDialect
);
870 if (firstDialect
.isExtensible()) {
871 os
<< llvm::formatv(dialectDefaultTypePrinterParserDispatch
,
872 firstDialect
.getCppClassName(),
873 dialectDynamicTypeParserDispatch
,
874 dialectDynamicTypePrinterDispatch
);
876 os
<< llvm::formatv(dialectDefaultTypePrinterParserDispatch
,
877 firstDialect
.getCppClassName(), "", "");
884 //===----------------------------------------------------------------------===//
885 // GEN: Registration hooks
886 //===----------------------------------------------------------------------===//
888 //===----------------------------------------------------------------------===//
891 static llvm::cl::OptionCategory
attrdefGenCat("Options for -gen-attrdef-*");
892 static llvm::cl::opt
<std::string
>
893 attrDialect("attrdefs-dialect",
894 llvm::cl::desc("Generate attributes for this dialect"),
895 llvm::cl::cat(attrdefGenCat
), llvm::cl::CommaSeparated
);
897 static mlir::GenRegistration
898 genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
899 [](const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
900 AttrDefGenerator
generator(records
, os
);
901 return generator
.emitDefs(attrDialect
);
903 static mlir::GenRegistration
904 genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
905 [](const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
906 AttrDefGenerator
generator(records
, os
);
907 return generator
.emitDecls(attrDialect
);
910 //===----------------------------------------------------------------------===//
913 static llvm::cl::OptionCategory
typedefGenCat("Options for -gen-typedef-*");
914 static llvm::cl::opt
<std::string
>
915 typeDialect("typedefs-dialect",
916 llvm::cl::desc("Generate types for this dialect"),
917 llvm::cl::cat(typedefGenCat
), llvm::cl::CommaSeparated
);
919 static mlir::GenRegistration
920 genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
921 [](const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
922 TypeDefGenerator
generator(records
, os
);
923 return generator
.emitDefs(typeDialect
);
925 static mlir::GenRegistration
926 genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
927 [](const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
928 TypeDefGenerator
generator(records
, os
);
929 return generator
.emitDecls(typeDialect
);