1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // DialectGen uses the description of dialects to generate C++ definitions.
11 //===----------------------------------------------------------------------===//
13 #include "DialectGenUtilities.h"
14 #include "mlir/TableGen/Class.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Trait.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 #include "llvm/TableGen/TableGenBackend.h"
29 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
32 using namespace mlir::tblgen
;
34 static llvm::cl::OptionCategory
dialectGenCat("Options for -gen-dialect-*");
35 llvm::cl::opt
<std::string
>
36 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
37 llvm::cl::cat(dialectGenCat
), llvm::cl::CommaSeparated
);
39 /// Utility iterator used for filtering records for a specific dialect.
41 using DialectFilterIterator
=
42 llvm::filter_iterator
<ArrayRef
<llvm::Record
*>::iterator
,
43 std::function
<bool(const llvm::Record
*)>>;
46 /// Given a set of records for a T, filter the ones that correspond to
47 /// the given dialect.
49 static iterator_range
<DialectFilterIterator
>
50 filterForDialect(ArrayRef
<llvm::Record
*> records
, Dialect
&dialect
) {
51 auto filterFn
= [&](const llvm::Record
*record
) {
52 return T(record
).getDialect() == dialect
;
54 return {DialectFilterIterator(records
.begin(), records
.end(), filterFn
),
55 DialectFilterIterator(records
.end(), records
.end(), filterFn
)};
58 std::optional
<Dialect
>
59 tblgen::findDialectToGenerate(ArrayRef
<Dialect
> dialects
) {
60 if (dialects
.empty()) {
61 llvm::errs() << "no dialect was found\n";
65 // Select the dialect to gen for.
66 if (dialects
.size() == 1 && selectedDialect
.getNumOccurrences() == 0)
67 return dialects
.front();
69 if (selectedDialect
.getNumOccurrences() == 0) {
70 llvm::errs() << "when more than 1 dialect is present, one must be selected "
75 const auto *dialectIt
= llvm::find_if(dialects
, [](const Dialect
&dialect
) {
76 return dialect
.getName() == selectedDialect
;
78 if (dialectIt
== dialects
.end()) {
79 llvm::errs() << "selected dialect with '-dialect' does not exist\n";
85 //===----------------------------------------------------------------------===//
86 // GEN: Dialect declarations
87 //===----------------------------------------------------------------------===//
89 /// The code block for the start of a dialect class declaration.
91 /// {0}: The name of the dialect class.
92 /// {1}: The dialect namespace.
93 /// {2}: The dialect parent class.
94 static const char *const dialectDeclBeginStr
= R
"(
95 class {0} : public ::mlir::{2} {
96 explicit {0}(::mlir::MLIRContext *context);
99 friend class ::mlir::MLIRContext;
102 static constexpr ::llvm::StringLiteral getDialectNamespace() {
103 return ::llvm::StringLiteral("{1}");
107 /// Registration for a single dependent dialect: to be inserted in the ctor
108 /// above for each dependent dialect.
109 const char *const dialectRegistrationTemplate
= R
"(
110 getContext()->loadDialect<{0}>();
113 /// The code block for the attribute parser/printer hooks.
114 static const char *const attrParserDecl
= R
"(
115 /// Parse an attribute registered to this dialect.
116 ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
117 ::mlir::Type type) const override;
119 /// Print an attribute registered to this dialect.
120 void printAttribute(::mlir::Attribute attr,
121 ::mlir::DialectAsmPrinter &os) const override;
124 /// The code block for the type parser/printer hooks.
125 static const char *const typeParserDecl
= R
"(
126 /// Parse a type registered to this dialect.
127 ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
129 /// Print a type registered to this dialect.
130 void printType(::mlir::Type type,
131 ::mlir::DialectAsmPrinter &os) const override;
134 /// The code block for the canonicalization pattern registration hook.
135 static const char *const canonicalizerDecl
= R
"(
136 /// Register canonicalization patterns.
137 void getCanonicalizationPatterns(
138 ::mlir::RewritePatternSet &results) const override;
141 /// The code block for the constant materializer hook.
142 static const char *const constantMaterializerDecl
= R
"(
143 /// Materialize a single constant operation from a given attribute value with
144 /// the desired resultant type.
145 ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
146 ::mlir::Attribute value,
148 ::mlir::Location loc) override;
151 /// The code block for the operation attribute verifier hook.
152 static const char *const opAttrVerifierDecl
= R
"(
153 /// Provides a hook for verifying dialect attributes attached to the given
155 ::mlir::LogicalResult verifyOperationAttribute(
156 ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
159 /// The code block for the region argument attribute verifier hook.
160 static const char *const regionArgAttrVerifierDecl
= R
"(
161 /// Provides a hook for verifying dialect attributes attached to the given
162 /// op's region argument.
163 ::mlir::LogicalResult verifyRegionArgAttribute(
164 ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
165 ::mlir::NamedAttribute attribute) override;
168 /// The code block for the region result attribute verifier hook.
169 static const char *const regionResultAttrVerifierDecl
= R
"(
170 /// Provides a hook for verifying dialect attributes attached to the given
171 /// op's region result.
172 ::mlir::LogicalResult verifyRegionResultAttribute(
173 ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
174 ::mlir::NamedAttribute attribute) override;
177 /// The code block for the op interface fallback hook.
178 static const char *const operationInterfaceFallbackDecl
= R
"(
179 /// Provides a hook for op interface.
180 void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID,
181 mlir::OperationName opName) override;
184 /// Generate the declaration for the given dialect class.
185 static void emitDialectDecl(Dialect
&dialect
, raw_ostream
&os
) {
186 // Emit all nested namespaces.
188 NamespaceEmitter
nsEmitter(os
, dialect
);
190 // Emit the start of the decl.
191 std::string cppName
= dialect
.getCppClassName();
192 StringRef superClassName
=
193 dialect
.isExtensible() ? "ExtensibleDialect" : "Dialect";
194 os
<< llvm::formatv(dialectDeclBeginStr
, cppName
, dialect
.getName(),
197 // If the dialect requested the default attribute printer and parser, emit
198 // the declarations for the hooks.
199 if (dialect
.useDefaultAttributePrinterParser())
200 os
<< attrParserDecl
;
201 // If the dialect requested the default type printer and parser, emit the
202 // delcarations for the hooks.
203 if (dialect
.useDefaultTypePrinterParser())
204 os
<< typeParserDecl
;
206 // Add the decls for the various features of the dialect.
207 if (dialect
.hasCanonicalizer())
208 os
<< canonicalizerDecl
;
209 if (dialect
.hasConstantMaterializer())
210 os
<< constantMaterializerDecl
;
211 if (dialect
.hasOperationAttrVerify())
212 os
<< opAttrVerifierDecl
;
213 if (dialect
.hasRegionArgAttrVerify())
214 os
<< regionArgAttrVerifierDecl
;
215 if (dialect
.hasRegionResultAttrVerify())
216 os
<< regionResultAttrVerifierDecl
;
217 if (dialect
.hasOperationInterfaceFallback())
218 os
<< operationInterfaceFallbackDecl
;
219 if (std::optional
<StringRef
> extraDecl
= dialect
.getExtraClassDeclaration())
222 // End the dialect decl.
225 if (!dialect
.getCppNamespace().empty())
226 os
<< "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect
.getCppNamespace()
227 << "::" << dialect
.getCppClassName() << ")\n";
230 static bool emitDialectDecls(const llvm::RecordKeeper
&recordKeeper
,
232 emitSourceFileHeader("Dialect Declarations", os
, recordKeeper
);
234 auto dialectDefs
= recordKeeper
.getAllDerivedDefinitions("Dialect");
235 if (dialectDefs
.empty())
238 SmallVector
<Dialect
> dialects(dialectDefs
.begin(), dialectDefs
.end());
239 std::optional
<Dialect
> dialect
= findDialectToGenerate(dialects
);
242 emitDialectDecl(*dialect
, os
);
246 //===----------------------------------------------------------------------===//
247 // GEN: Dialect definitions
248 //===----------------------------------------------------------------------===//
250 /// The code block to generate a dialect constructor definition.
252 /// {0}: The name of the dialect class.
253 /// {1}: initialization code that is emitted in the ctor body before calling
255 /// {2}: The dialect parent class.
256 static const char *const dialectConstructorStr
= R
"(
257 {0}::{0}(::mlir::MLIRContext *context)
258 : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
264 /// The code block to generate a default desturctor definition.
266 /// {0}: The name of the dialect class.
267 static const char *const dialectDestructorStr
= R
"(
268 {0}::~{0}() = default;
272 static void emitDialectDef(Dialect
&dialect
, raw_ostream
&os
) {
273 std::string cppClassName
= dialect
.getCppClassName();
275 // Emit the TypeID explicit specializations to have a single symbol def.
276 if (!dialect
.getCppNamespace().empty())
277 os
<< "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect
.getCppNamespace()
278 << "::" << cppClassName
<< ")\n";
280 // Emit all nested namespaces.
281 NamespaceEmitter
nsEmitter(os
, dialect
);
283 /// Build the list of dependent dialects.
284 std::string dependentDialectRegistrations
;
286 llvm::raw_string_ostream
dialectsOs(dependentDialectRegistrations
);
287 for (StringRef dependentDialect
: dialect
.getDependentDialects())
288 dialectsOs
<< llvm::formatv(dialectRegistrationTemplate
,
292 // Emit the constructor and destructor.
293 StringRef superClassName
=
294 dialect
.isExtensible() ? "ExtensibleDialect" : "Dialect";
295 os
<< llvm::formatv(dialectConstructorStr
, cppClassName
,
296 dependentDialectRegistrations
, superClassName
);
297 if (!dialect
.hasNonDefaultDestructor())
298 os
<< llvm::formatv(dialectDestructorStr
, cppClassName
);
301 static bool emitDialectDefs(const llvm::RecordKeeper
&recordKeeper
,
303 emitSourceFileHeader("Dialect Definitions", os
, recordKeeper
);
305 auto dialectDefs
= recordKeeper
.getAllDerivedDefinitions("Dialect");
306 if (dialectDefs
.empty())
309 SmallVector
<Dialect
> dialects(dialectDefs
.begin(), dialectDefs
.end());
310 std::optional
<Dialect
> dialect
= findDialectToGenerate(dialects
);
313 emitDialectDef(*dialect
, os
);
317 //===----------------------------------------------------------------------===//
318 // GEN: Dialect registration hooks
319 //===----------------------------------------------------------------------===//
321 static mlir::GenRegistration
322 genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
323 [](const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
324 return emitDialectDecls(records
, os
);
327 static mlir::GenRegistration
328 genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
329 [](const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
330 return emitDialectDefs(records
, os
);