[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / tools / mlir-tblgen / DialectGen.cpp
blob414cad5e1dcc2e798472004ccec67c37852cbba7
1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DialectGen uses the description of dialects to generate C++ definitions.
11 //===----------------------------------------------------------------------===//
13 #include "DialectGenUtilities.h"
14 #include "mlir/TableGen/Class.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Trait.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 #include "llvm/TableGen/TableGenBackend.h"
29 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
31 using namespace mlir;
32 using namespace mlir::tblgen;
33 using llvm::Record;
34 using llvm::RecordKeeper;
36 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
37 llvm::cl::opt<std::string>
38 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
39 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
41 /// Utility iterator used for filtering records for a specific dialect.
42 namespace {
43 using DialectFilterIterator =
44 llvm::filter_iterator<ArrayRef<Record *>::iterator,
45 std::function<bool(const Record *)>>;
46 } // namespace
48 static void populateDiscardableAttributes(
49 Dialect &dialect, const llvm::DagInit *discardableAttrDag,
50 SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
51 for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
52 const llvm::Init *arg = discardableAttrDag->getArg(i);
54 StringRef givenName = discardableAttrDag->getArgNameStr(i);
55 if (givenName.empty())
56 PrintFatalError(dialect.getDef()->getLoc(),
57 "discardable attributes must be named");
58 discardableAttributes.push_back(
59 {givenName.str(), arg->getAsUnquotedString()});
63 /// Given a set of records for a T, filter the ones that correspond to
64 /// the given dialect.
65 template <typename T>
66 static iterator_range<DialectFilterIterator>
67 filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
68 auto filterFn = [&](const Record *record) {
69 return T(record).getDialect() == dialect;
71 return {DialectFilterIterator(records.begin(), records.end(), filterFn),
72 DialectFilterIterator(records.end(), records.end(), filterFn)};
75 std::optional<Dialect>
76 tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
77 if (dialects.empty()) {
78 llvm::errs() << "no dialect was found\n";
79 return std::nullopt;
82 // Select the dialect to gen for.
83 if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
84 return dialects.front();
86 if (selectedDialect.getNumOccurrences() == 0) {
87 llvm::errs() << "when more than 1 dialect is present, one must be selected "
88 "via '-dialect'\n";
89 return std::nullopt;
92 const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) {
93 return dialect.getName() == selectedDialect;
94 });
95 if (dialectIt == dialects.end()) {
96 llvm::errs() << "selected dialect with '-dialect' does not exist\n";
97 return std::nullopt;
99 return *dialectIt;
//===----------------------------------------------------------------------===//
// GEN: Dialect declarations
//===----------------------------------------------------------------------===//

/// Opening of the generated dialect class declaration: the class head, the
/// private constructor and initialize() (reachable from MLIRContext through
/// the friend declaration), the destructor and getDialectNamespace(). The
/// class body is left open so hook declarations can be appended after it.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
static constexpr const char *dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {
    return ::llvm::StringLiteral("{1}");
  }
)";
/// Statement that loads a single dependent dialect. One instance per
/// dependent dialect is spliced into the `{1}` slot of dialectConstructorStr,
/// i.e. into the generated constructor body ahead of initialize().
///
/// {0}: The C++ class name of the dependent dialect.
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
/// Declarations of the attribute parse/print hooks; appended to the dialect
/// class when the dialect uses the default attribute printer and parser.
static constexpr const char *attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";
/// Declarations of the type parse/print hooks; appended to the dialect class
/// when the dialect uses the default type printer and parser.
static constexpr const char *typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";
/// Declaration of the dialect-level canonicalization pattern hook; appended
/// when the dialect has a canonicalizer.
static constexpr const char *canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";
/// Declaration of the constant-materialization hook; appended when the
/// dialect has a constant materializer.
static constexpr const char *constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";
/// Declaration of the hook that verifies dialect attributes attached to an
/// operation; appended when the dialect requests operation attribute
/// verification.
static constexpr const char *opAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op.
    ::llvm::LogicalResult verifyOperationAttribute(
        ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";
/// Declaration of the hook that verifies dialect attributes attached to a
/// region argument; appended when the dialect requests region argument
/// attribute verification.
static constexpr const char *regionArgAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region argument.
    ::llvm::LogicalResult verifyRegionArgAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
        ::mlir::NamedAttribute attribute) override;
)";
/// Declaration of the hook that verifies dialect attributes attached to a
/// region result; appended when the dialect requests region result attribute
/// verification.
static constexpr const char *regionResultAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region result.
    ::llvm::LogicalResult verifyRegionResultAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
        ::mlir::NamedAttribute attribute) override;
)";
/// The code block for the op interface fallback hook; appended when the
/// dialect has an operation interface fallback.
///
/// The MLIR types are spelled with a leading `::` like in every other hook:
/// the declaration is emitted inside the dialect's own C++ namespaces, where
/// an unqualified `mlir::` would resolve to any nested `mlir` namespace first.
static const char *const operationInterfaceFallbackDecl = R"(
    /// Provides a hook for op interface.
    void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                      ::mlir::OperationName opName) override;
)";
/// The code block for the discardable attribute helper.
///
/// Emits, inside the dialect class, a `{0}AttrHelper` class that wraps the
/// attribute name `{4}.{1}` with typed get/set/has/remove accessors, a public
/// `get{0}AttrHelper()` accessor, and a private `{3}AttrName` member (the
/// member is initialized from the MLIRContext by the generated constructor).
///
/// {0}: The attribute name in UpperCamelCase.
/// {1}: The attribute name as written in the dialect definition.
/// {2}: The C++ attribute type.
/// {3}: The attribute name in lowerCamelCase.
/// {4}: The dialect name.
///
/// Every literal `{` is escaped as `{{` for formatv, including the body of
/// `get{0}AttrHelper()`, instead of relying on formatv treating an unmatched
/// `{` as literal text.
static const char *const discardableAttrHelperDecl = R"(
    /// Helper to manage the discardable attribute `{1}`.
    class {0}AttrHelper {{
      ::mlir::StringAttr name;
    public:
      static constexpr ::llvm::StringLiteral getNameStr() {{
        return "{4}.{1}";
      }
      constexpr ::mlir::StringAttr getName() {{
        return name;
      }

      {0}AttrHelper(::mlir::MLIRContext *ctx)
        : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

      {2} getAttr(::mlir::Operation *op) {{
        return op->getAttrOfType<{2}>(name);
      }
      void setAttr(::mlir::Operation *op, {2} val) {{
        op->setAttr(name, val);
      }
      bool isAttrPresent(::mlir::Operation *op) {{
        return op->hasAttrOfType<{2}>(name);
      }
      void removeAttr(::mlir::Operation *op) {{
        assert(op->hasAttrOfType<{2}>(name));
        op->removeAttr(name);
      }
    };
    {0}AttrHelper get{0}AttrHelper() {{
      return {3}AttrName;
    }
  private:
    {0}AttrHelper {3}AttrName;
  public:
)";
238 /// Generate the declaration for the given dialect class.
239 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
240 // Emit all nested namespaces.
242 NamespaceEmitter nsEmitter(os, dialect);
244 // Emit the start of the decl.
245 std::string cppName = dialect.getCppClassName();
246 StringRef superClassName =
247 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
248 os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
249 superClassName);
251 // If the dialect requested the default attribute printer and parser, emit
252 // the declarations for the hooks.
253 if (dialect.useDefaultAttributePrinterParser())
254 os << attrParserDecl;
255 // If the dialect requested the default type printer and parser, emit the
256 // delcarations for the hooks.
257 if (dialect.useDefaultTypePrinterParser())
258 os << typeParserDecl;
260 // Add the decls for the various features of the dialect.
261 if (dialect.hasCanonicalizer())
262 os << canonicalizerDecl;
263 if (dialect.hasConstantMaterializer())
264 os << constantMaterializerDecl;
265 if (dialect.hasOperationAttrVerify())
266 os << opAttrVerifierDecl;
267 if (dialect.hasRegionArgAttrVerify())
268 os << regionArgAttrVerifierDecl;
269 if (dialect.hasRegionResultAttrVerify())
270 os << regionResultAttrVerifierDecl;
271 if (dialect.hasOperationInterfaceFallback())
272 os << operationInterfaceFallbackDecl;
274 const llvm::DagInit *discardableAttrDag =
275 dialect.getDiscardableAttributes();
276 SmallVector<std::pair<std::string, std::string>> discardableAttributes;
277 populateDiscardableAttributes(dialect, discardableAttrDag,
278 discardableAttributes);
280 for (const auto &attrPair : discardableAttributes) {
281 std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
282 attrPair.first, /*capitalizeFirst=*/true);
283 std::string camelName = llvm::convertToCamelFromSnakeCase(
284 attrPair.first, /*capitalizeFirst=*/false);
285 os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
286 attrPair.first, attrPair.second, camelName,
287 dialect.getName());
290 if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
291 os << *extraDecl;
293 // End the dialect decl.
294 os << "};\n";
296 if (!dialect.getCppNamespace().empty())
297 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
298 << "::" << dialect.getCppClassName() << ")\n";
301 static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) {
302 emitSourceFileHeader("Dialect Declarations", os, records);
304 auto dialectDefs = records.getAllDerivedDefinitions("Dialect");
305 if (dialectDefs.empty())
306 return false;
308 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
309 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
310 if (!dialect)
311 return true;
312 emitDialectDecl(*dialect, os);
313 return false;
//===----------------------------------------------------------------------===//
// GEN: Dialect definitions
//===----------------------------------------------------------------------===//

/// Out-of-line definition of the dialect constructor: initializes the parent
/// dialect class with the namespace, context and TypeID, then any extra
/// members, runs the inserted initialization code and finally initialize().
///
/// {0}: The name of the dialect class.
/// {1}: Initialization code that is emitted in the ctor body before calling
///      initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
/// {3}: Extra members to initialize
static constexpr const char *dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
    {3}
{{
  {1}
  initialize();
}
)";
/// Defaulted out-of-line destructor, written unless the dialect declares a
/// non-default destructor.
///
/// {0}: The name of the dialect class.
static constexpr const char *dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
345 static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
346 raw_ostream &os) {
347 std::string cppClassName = dialect.getCppClassName();
349 // Emit the TypeID explicit specializations to have a single symbol def.
350 if (!dialect.getCppNamespace().empty())
351 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
352 << "::" << cppClassName << ")\n";
354 // Emit all nested namespaces.
355 NamespaceEmitter nsEmitter(os, dialect);
357 /// Build the list of dependent dialects.
358 std::string dependentDialectRegistrations;
360 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
361 llvm::interleave(
362 dialect.getDependentDialects(), dialectsOs,
363 [&](StringRef dependentDialect) {
364 dialectsOs << llvm::formatv(dialectRegistrationTemplate,
365 dependentDialect);
367 "\n ");
370 // Emit the constructor and destructor.
371 StringRef superClassName =
372 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
374 const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
375 SmallVector<std::pair<std::string, std::string>> discardableAttributes;
376 populateDiscardableAttributes(dialect, discardableAttrDag,
377 discardableAttributes);
378 std::string discardableAttributesInit;
379 for (const auto &attrPair : discardableAttributes) {
380 std::string camelName = llvm::convertToCamelFromSnakeCase(
381 attrPair.first, /*capitalizeFirst=*/false);
382 llvm::raw_string_ostream os(discardableAttributesInit);
383 os << ", " << camelName << "AttrName(context)";
386 os << llvm::formatv(dialectConstructorStr, cppClassName,
387 dependentDialectRegistrations, superClassName,
388 discardableAttributesInit);
389 if (!dialect.hasNonDefaultDestructor())
390 os << llvm::formatv(dialectDestructorStr, cppClassName);
393 static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {
394 emitSourceFileHeader("Dialect Definitions", os, records);
396 auto dialectDefs = records.getAllDerivedDefinitions("Dialect");
397 if (dialectDefs.empty())
398 return false;
400 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
401 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
402 if (!dialect)
403 return true;
404 emitDialectDef(*dialect, records, os);
405 return false;
408 //===----------------------------------------------------------------------===//
409 // GEN: Dialect registration hooks
410 //===----------------------------------------------------------------------===//
412 static mlir::GenRegistration
413 genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
414 [](const RecordKeeper &records, raw_ostream &os) {
415 return emitDialectDecls(records, os);
418 static mlir::GenRegistration
419 genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
420 [](const RecordKeeper &records, raw_ostream &os) {
421 return emitDialectDefs(records, os);