1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // DialectGen uses the description of dialects to generate C++ definitions.
11 //===----------------------------------------------------------------------===//
13 #include "DialectGenUtilities.h"
14 #include "mlir/TableGen/Class.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Trait.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 #include "llvm/TableGen/TableGenBackend.h"
29 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
32 using namespace mlir::tblgen
;
34 using llvm::RecordKeeper
;
36 static llvm::cl::OptionCategory
dialectGenCat("Options for -gen-dialect-*");
37 llvm::cl::opt
<std::string
>
38 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
39 llvm::cl::cat(dialectGenCat
), llvm::cl::CommaSeparated
);
41 /// Utility iterator used for filtering records for a specific dialect.
43 using DialectFilterIterator
=
44 llvm::filter_iterator
<ArrayRef
<Record
*>::iterator
,
45 std::function
<bool(const Record
*)>>;
48 static void populateDiscardableAttributes(
49 Dialect
&dialect
, const llvm::DagInit
*discardableAttrDag
,
50 SmallVector
<std::pair
<std::string
, std::string
>> &discardableAttributes
) {
51 for (int i
: llvm::seq
<int>(0, discardableAttrDag
->getNumArgs())) {
52 const llvm::Init
*arg
= discardableAttrDag
->getArg(i
);
54 StringRef givenName
= discardableAttrDag
->getArgNameStr(i
);
55 if (givenName
.empty())
56 PrintFatalError(dialect
.getDef()->getLoc(),
57 "discardable attributes must be named");
58 discardableAttributes
.push_back(
59 {givenName
.str(), arg
->getAsUnquotedString()});
63 /// Given a set of records for a T, filter the ones that correspond to
64 /// the given dialect.
66 static iterator_range
<DialectFilterIterator
>
67 filterForDialect(ArrayRef
<Record
*> records
, Dialect
&dialect
) {
68 auto filterFn
= [&](const Record
*record
) {
69 return T(record
).getDialect() == dialect
;
71 return {DialectFilterIterator(records
.begin(), records
.end(), filterFn
),
72 DialectFilterIterator(records
.end(), records
.end(), filterFn
)};
75 std::optional
<Dialect
>
76 tblgen::findDialectToGenerate(ArrayRef
<Dialect
> dialects
) {
77 if (dialects
.empty()) {
78 llvm::errs() << "no dialect was found\n";
82 // Select the dialect to gen for.
83 if (dialects
.size() == 1 && selectedDialect
.getNumOccurrences() == 0)
84 return dialects
.front();
86 if (selectedDialect
.getNumOccurrences() == 0) {
87 llvm::errs() << "when more than 1 dialect is present, one must be selected "
92 const auto *dialectIt
= llvm::find_if(dialects
, [](const Dialect
&dialect
) {
93 return dialect
.getName() == selectedDialect
;
95 if (dialectIt
== dialects
.end()) {
96 llvm::errs() << "selected dialect with '-dialect' does not exist\n";
102 //===----------------------------------------------------------------------===//
103 // GEN: Dialect declarations
104 //===----------------------------------------------------------------------===//
106 /// The code block for the start of a dialect class declaration.
108 /// {0}: The name of the dialect class.
109 /// {1}: The dialect namespace.
110 /// {2}: The dialect parent class.
111 static const char *const dialectDeclBeginStr
= R
"(
112 class {0} : public ::mlir::{2} {
113 explicit {0}(::mlir::MLIRContext *context);
116 friend class ::mlir::MLIRContext;
119 static constexpr ::llvm::StringLiteral getDialectNamespace() {
120 return ::llvm::StringLiteral("{1}");
/// Registration snippet for a single dependent dialect.
///
/// One copy is emitted into the body of the generated dialect constructor for
/// each dependent dialect. It fills placeholder {1} of the constructor
/// definition template and therefore runs before `initialize()` is called.
///
/// {0}: The C++ class name of the dependent dialect to load.
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
129 /// The code block for the attribute parser/printer hooks.
130 static const char *const attrParserDecl
= R
"(
131 /// Parse an attribute registered to this dialect.
132 ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
133 ::mlir::Type type) const override;
135 /// Print an attribute registered to this dialect.
136 void printAttribute(::mlir::Attribute attr,
137 ::mlir::DialectAsmPrinter &os) const override;
140 /// The code block for the type parser/printer hooks.
141 static const char *const typeParserDecl
= R
"(
142 /// Parse a type registered to this dialect.
143 ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
145 /// Print a type registered to this dialect.
146 void printType(::mlir::Type type,
147 ::mlir::DialectAsmPrinter &os) const override;
150 /// The code block for the canonicalization pattern registration hook.
151 static const char *const canonicalizerDecl
= R
"(
152 /// Register canonicalization patterns.
153 void getCanonicalizationPatterns(
154 ::mlir::RewritePatternSet &results) const override;
157 /// The code block for the constant materializer hook.
158 static const char *const constantMaterializerDecl
= R
"(
159 /// Materialize a single constant operation from a given attribute value with
160 /// the desired resultant type.
161 ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
162 ::mlir::Attribute value,
164 ::mlir::Location loc) override;
167 /// The code block for the operation attribute verifier hook.
168 static const char *const opAttrVerifierDecl
= R
"(
169 /// Provides a hook for verifying dialect attributes attached to the given
171 ::llvm::LogicalResult verifyOperationAttribute(
172 ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
175 /// The code block for the region argument attribute verifier hook.
176 static const char *const regionArgAttrVerifierDecl
= R
"(
177 /// Provides a hook for verifying dialect attributes attached to the given
178 /// op's region argument.
179 ::llvm::LogicalResult verifyRegionArgAttribute(
180 ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
181 ::mlir::NamedAttribute attribute) override;
184 /// The code block for the region result attribute verifier hook.
185 static const char *const regionResultAttrVerifierDecl
= R
"(
186 /// Provides a hook for verifying dialect attributes attached to the given
187 /// op's region result.
188 ::llvm::LogicalResult verifyRegionResultAttribute(
189 ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
190 ::mlir::NamedAttribute attribute) override;
193 /// The code block for the op interface fallback hook.
194 static const char *const operationInterfaceFallbackDecl
= R
"(
195 /// Provides a hook for op interface.
196 void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID,
197 mlir::OperationName opName) override;
200 /// The code block for the discardable attribute helper.
201 static const char *const discardableAttrHelperDecl
= R
"(
202 /// Helper to manage the discardable attribute `{1}`.
203 class {0}AttrHelper {{
204 ::mlir::StringAttr name;
206 static constexpr ::llvm::StringLiteral getNameStr() {{
209 constexpr ::mlir::StringAttr getName() {{
213 {0}AttrHelper(::mlir::MLIRContext *ctx)
214 : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
216 {2} getAttr(::mlir::Operation *op) {{
217 return op->getAttrOfType<{2}>(name);
219 void setAttr(::mlir::Operation *op, {2} val) {{
220 op->setAttr(name, val);
222 bool isAttrPresent(::mlir::Operation *op) {{
223 return op->hasAttrOfType<{2}>(name);
225 void removeAttr(::mlir::Operation *op) {{
226 assert(op->hasAttrOfType<{2}>(name));
227 op->removeAttr(name);
230 {0}AttrHelper get{0}AttrHelper() {
234 {0}AttrHelper {3}AttrName;
238 /// Generate the declaration for the given dialect class.
239 static void emitDialectDecl(Dialect
&dialect
, raw_ostream
&os
) {
240 // Emit all nested namespaces.
242 NamespaceEmitter
nsEmitter(os
, dialect
);
244 // Emit the start of the decl.
245 std::string cppName
= dialect
.getCppClassName();
246 StringRef superClassName
=
247 dialect
.isExtensible() ? "ExtensibleDialect" : "Dialect";
248 os
<< llvm::formatv(dialectDeclBeginStr
, cppName
, dialect
.getName(),
251 // If the dialect requested the default attribute printer and parser, emit
252 // the declarations for the hooks.
253 if (dialect
.useDefaultAttributePrinterParser())
254 os
<< attrParserDecl
;
255 // If the dialect requested the default type printer and parser, emit the
256 // declarations for the hooks.
257 if (dialect
.useDefaultTypePrinterParser())
258 os
<< typeParserDecl
;
260 // Add the decls for the various features of the dialect.
261 if (dialect
.hasCanonicalizer())
262 os
<< canonicalizerDecl
;
263 if (dialect
.hasConstantMaterializer())
264 os
<< constantMaterializerDecl
;
265 if (dialect
.hasOperationAttrVerify())
266 os
<< opAttrVerifierDecl
;
267 if (dialect
.hasRegionArgAttrVerify())
268 os
<< regionArgAttrVerifierDecl
;
269 if (dialect
.hasRegionResultAttrVerify())
270 os
<< regionResultAttrVerifierDecl
;
271 if (dialect
.hasOperationInterfaceFallback())
272 os
<< operationInterfaceFallbackDecl
;
274 const llvm::DagInit
*discardableAttrDag
=
275 dialect
.getDiscardableAttributes();
276 SmallVector
<std::pair
<std::string
, std::string
>> discardableAttributes
;
277 populateDiscardableAttributes(dialect
, discardableAttrDag
,
278 discardableAttributes
);
280 for (const auto &attrPair
: discardableAttributes
) {
281 std::string camelNameUpper
= llvm::convertToCamelFromSnakeCase(
282 attrPair
.first
, /*capitalizeFirst=*/true);
283 std::string camelName
= llvm::convertToCamelFromSnakeCase(
284 attrPair
.first
, /*capitalizeFirst=*/false);
285 os
<< llvm::formatv(discardableAttrHelperDecl
, camelNameUpper
,
286 attrPair
.first
, attrPair
.second
, camelName
,
290 if (std::optional
<StringRef
> extraDecl
= dialect
.getExtraClassDeclaration())
293 // End the dialect decl.
296 if (!dialect
.getCppNamespace().empty())
297 os
<< "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect
.getCppNamespace()
298 << "::" << dialect
.getCppClassName() << ")\n";
301 static bool emitDialectDecls(const RecordKeeper
&records
, raw_ostream
&os
) {
302 emitSourceFileHeader("Dialect Declarations", os
, records
);
304 auto dialectDefs
= records
.getAllDerivedDefinitions("Dialect");
305 if (dialectDefs
.empty())
308 SmallVector
<Dialect
> dialects(dialectDefs
.begin(), dialectDefs
.end());
309 std::optional
<Dialect
> dialect
= findDialectToGenerate(dialects
);
312 emitDialectDecl(*dialect
, os
);
316 //===----------------------------------------------------------------------===//
317 // GEN: Dialect definitions
318 //===----------------------------------------------------------------------===//
320 /// The code block to generate a dialect constructor definition.
322 /// {0}: The name of the dialect class.
323 /// {1}: Initialization code that is emitted in the ctor body before calling
324 /// initialize(), such as dependent dialect registration.
325 /// {2}: The dialect parent class.
326 /// {3}: Extra members to initialize
327 static const char *const dialectConstructorStr
= R
"(
328 {0}::{0}(::mlir::MLIRContext *context)
329 : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
337 /// The code block to generate a default destructor definition.
339 /// {0}: The name of the dialect class.
340 static const char *const dialectDestructorStr
= R
"(
341 {0}::~{0}() = default;
345 static void emitDialectDef(Dialect
&dialect
, const RecordKeeper
&records
,
347 std::string cppClassName
= dialect
.getCppClassName();
349 // Emit the TypeID explicit specializations to have a single symbol def.
350 if (!dialect
.getCppNamespace().empty())
351 os
<< "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect
.getCppNamespace()
352 << "::" << cppClassName
<< ")\n";
354 // Emit all nested namespaces.
355 NamespaceEmitter
nsEmitter(os
, dialect
);
357 /// Build the list of dependent dialects.
358 std::string dependentDialectRegistrations
;
360 llvm::raw_string_ostream
dialectsOs(dependentDialectRegistrations
);
362 dialect
.getDependentDialects(), dialectsOs
,
363 [&](StringRef dependentDialect
) {
364 dialectsOs
<< llvm::formatv(dialectRegistrationTemplate
,
370 // Emit the constructor and destructor.
371 StringRef superClassName
=
372 dialect
.isExtensible() ? "ExtensibleDialect" : "Dialect";
374 const llvm::DagInit
*discardableAttrDag
= dialect
.getDiscardableAttributes();
375 SmallVector
<std::pair
<std::string
, std::string
>> discardableAttributes
;
376 populateDiscardableAttributes(dialect
, discardableAttrDag
,
377 discardableAttributes
);
378 std::string discardableAttributesInit
;
379 for (const auto &attrPair
: discardableAttributes
) {
380 std::string camelName
= llvm::convertToCamelFromSnakeCase(
381 attrPair
.first
, /*capitalizeFirst=*/false);
382 llvm::raw_string_ostream
os(discardableAttributesInit
);
383 os
<< ", " << camelName
<< "AttrName(context)";
386 os
<< llvm::formatv(dialectConstructorStr
, cppClassName
,
387 dependentDialectRegistrations
, superClassName
,
388 discardableAttributesInit
);
389 if (!dialect
.hasNonDefaultDestructor())
390 os
<< llvm::formatv(dialectDestructorStr
, cppClassName
);
393 static bool emitDialectDefs(const RecordKeeper
&records
, raw_ostream
&os
) {
394 emitSourceFileHeader("Dialect Definitions", os
, records
);
396 auto dialectDefs
= records
.getAllDerivedDefinitions("Dialect");
397 if (dialectDefs
.empty())
400 SmallVector
<Dialect
> dialects(dialectDefs
.begin(), dialectDefs
.end());
401 std::optional
<Dialect
> dialect
= findDialectToGenerate(dialects
);
404 emitDialectDef(*dialect
, records
, os
);
408 //===----------------------------------------------------------------------===//
409 // GEN: Dialect registration hooks
410 //===----------------------------------------------------------------------===//
412 static mlir::GenRegistration
413 genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
414 [](const RecordKeeper
&records
, raw_ostream
&os
) {
415 return emitDialectDecls(records
, os
);
418 static mlir::GenRegistration
419 genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
420 [](const RecordKeeper
&records
, raw_ostream
&os
) {
421 return emitDialectDefs(records
, os
);