[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / DialectGen.cpp
blobf22434f755abe3cc82179290f1fd1303d2168ab6
1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DialectGen uses the description of dialects to generate C++ definitions.
11 //===----------------------------------------------------------------------===//
13 #include "DialectGenUtilities.h"
14 #include "mlir/TableGen/Class.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Trait.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 #include "llvm/TableGen/TableGenBackend.h"
// NOTE(review): this file generates dialect code, yet the tag below says
// "opdefgen" - presumably carried over from the op definition generator.
// Confirm whether anything filters debug output on this exact string before
// renaming it.
29 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
31 using namespace mlir;
32 using namespace mlir::tblgen;
/// Command-line option category under which the dialect generator's flags are
/// grouped.
34 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
/// The 'dialect' command-line option: the name of the dialect to generate
/// code for. findDialectToGenerate reads both its occurrence count and its
/// value.
// NOTE(review): the option stores a single std::string that is compared
// against one dialect name, yet llvm::cl::CommaSeparated is set - confirm the
// modifier is intended. The variable is also not `static`, presumably so that
// other generator sources can refer to it - TODO confirm.
35 llvm::cl::opt<std::string>
36 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
37 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
39 /// Utility iterator used for filtering records for a specific dialect.
/// It is an llvm::filter_iterator over the elements of an
/// ArrayRef<llvm::Record *>, with the predicate type-erased behind a
/// std::function taking a `const llvm::Record *`.
40 namespace {
41 using DialectFilterIterator =
42 llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
43 std::function<bool(const llvm::Record *)>>;
44 } // namespace
46 /// Given a set of records for a T, filter the ones that correspond to
47 /// the given dialect.
///
/// T must be constructible from a `const llvm::Record *` and expose a
/// getDialect() whose result can be compared to a Dialect with `==`.
/// The predicate captures `dialect` by reference and is handed to the
/// iterators that are returned, so `dialect` has to stay alive for as long as
/// the returned range is in use.
48 template <typename T>
49 static iterator_range<DialectFilterIterator>
50 filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
// A record is kept when the dialect of its T wrapper equals `dialect`.
51 auto filterFn = [&](const llvm::Record *record) {
52 return T(record).getDialect() == dialect;
// The first iterator starts at records.begin(); the second one is positioned
// at records.end() and marks the end of the range.
54 return {DialectFilterIterator(records.begin(), records.end(), filterFn),
55 DialectFilterIterator(records.end(), records.end(), filterFn)};
/// Picks the dialect to generate code for out of `dialects`.
///
/// When exactly one dialect is given and the 'dialect' option was not passed,
/// that dialect is returned. Otherwise the dialect whose getName() equals the
/// value of the 'dialect' option is returned; an explicitly passed option is
/// therefore looked up even when only one dialect exists.
///
/// Returns std::nullopt, after printing a message to llvm::errs(), when
/// `dialects` is empty, when the option was not passed although the list does
/// not hold exactly one dialect, or when no dialect has the selected name.
58 std::optional<Dialect>
59 tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
60 if (dialects.empty()) {
61 llvm::errs() << "no dialect was found\n";
62 return std::nullopt;
65 // Select the dialect to gen for.
// A lone dialect needs no explicit selection.
66 if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
67 return dialects.front();
// Reaching this check without the option means several dialects are present.
69 if (selectedDialect.getNumOccurrences() == 0) {
70 llvm::errs() << "when more than 1 dialect is present, one must be selected "
71 "via '-dialect'\n";
72 return std::nullopt;
// The option was passed: look the dialect up by name.
75 const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) {
76 return dialect.getName() == selectedDialect;
77 });
78 if (dialectIt == dialects.end()) {
79 llvm::errs() << "selected dialect with '-dialect' does not exist\n";
80 return std::nullopt;
82 return *dialectIt;
85 //===----------------------------------------------------------------------===//
86 // GEN: Dialect declarations
87 //===----------------------------------------------------------------------===//
89 /// The code block for the start of a dialect class declaration.
90 ///
91 /// {0}: The name of the dialect class.
92 /// {1}: The dialect namespace.
93 /// {2}: The dialect parent class.
///
/// Expanded with llvm::formatv by emitDialectDecl, which passes either
/// "ExtensibleDialect" or "Dialect" for {2}. The class is left open here:
/// emitDialectDecl appends the optional hook declarations and then writes the
/// closing "};" itself.
94 static const char *const dialectDeclBeginStr = R"(
95 class {0} : public ::mlir::{2} {
96 explicit {0}(::mlir::MLIRContext *context);
98 void initialize();
99 friend class ::mlir::MLIRContext;
100 public:
101 ~{0}() override;
102 static constexpr ::llvm::StringLiteral getDialectNamespace() {
103 return ::llvm::StringLiteral("{1}");
107 /// Registration for a single dependent dialect: to be inserted in the
108 /// generated constructor body for each dependent dialect.
///
/// {0}: The dependent dialect, as listed by Dialect::getDependentDialects().
///
/// emitDialectDef expands this once per dependent dialect and passes the
/// concatenated text as argument {1} of dialectConstructorStr.
109 const char *const dialectRegistrationTemplate = R"(
110 getContext()->loadDialect<{0}>();
113 /// The code block for the attribute parser/printer hooks.
/// Written as-is (no formatv expansion) by emitDialectDecl when
/// Dialect::useDefaultAttributePrinterParser() is true.
114 static const char *const attrParserDecl = R"(
115 /// Parse an attribute registered to this dialect.
116 ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
117 ::mlir::Type type) const override;
119 /// Print an attribute registered to this dialect.
120 void printAttribute(::mlir::Attribute attr,
121 ::mlir::DialectAsmPrinter &os) const override;
124 /// The code block for the type parser/printer hooks.
/// Written as-is (no formatv expansion) by emitDialectDecl when
/// Dialect::useDefaultTypePrinterParser() is true.
125 static const char *const typeParserDecl = R"(
126 /// Parse a type registered to this dialect.
127 ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
129 /// Print a type registered to this dialect.
130 void printType(::mlir::Type type,
131 ::mlir::DialectAsmPrinter &os) const override;
134 /// The code block for the canonicalization pattern registration hook.
/// Written as-is by emitDialectDecl when Dialect::hasCanonicalizer() is true.
135 static const char *const canonicalizerDecl = R"(
136 /// Register canonicalization patterns.
137 void getCanonicalizationPatterns(
138 ::mlir::RewritePatternSet &results) const override;
141 /// The code block for the constant materializer hook.
/// Written as-is by emitDialectDecl when Dialect::hasConstantMaterializer()
/// is true.
142 static const char *const constantMaterializerDecl = R"(
143 /// Materialize a single constant operation from a given attribute value with
144 /// the desired resultant type.
145 ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
146 ::mlir::Attribute value,
147 ::mlir::Type type,
148 ::mlir::Location loc) override;
151 /// The code block for the operation attribute verifier hook.
/// Written as-is by emitDialectDecl when Dialect::hasOperationAttrVerify()
/// is true.
152 static const char *const opAttrVerifierDecl = R"(
153 /// Provides a hook for verifying dialect attributes attached to the given
154 /// op.
155 ::mlir::LogicalResult verifyOperationAttribute(
156 ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
159 /// The code block for the region argument attribute verifier hook.
/// Written as-is by emitDialectDecl when Dialect::hasRegionArgAttrVerify()
/// is true.
160 static const char *const regionArgAttrVerifierDecl = R"(
161 /// Provides a hook for verifying dialect attributes attached to the given
162 /// op's region argument.
163 ::mlir::LogicalResult verifyRegionArgAttribute(
164 ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
165 ::mlir::NamedAttribute attribute) override;
168 /// The code block for the region result attribute verifier hook.
/// Written as-is by emitDialectDecl when Dialect::hasRegionResultAttrVerify()
/// is true.
169 static const char *const regionResultAttrVerifierDecl = R"(
170 /// Provides a hook for verifying dialect attributes attached to the given
171 /// op's region result.
172 ::mlir::LogicalResult verifyRegionResultAttribute(
173 ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
174 ::mlir::NamedAttribute attribute) override;
177 /// The code block for the op interface fallback hook.
/// Written as-is by emitDialectDecl when
/// Dialect::hasOperationInterfaceFallback() is true.
178 static const char *const operationInterfaceFallbackDecl = R"(
179 /// Provides a hook for op interface.
180 void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID,
181 mlir::OperationName opName) override;
184 /// Generate the declaration for the given dialect class.
///
/// Writes to `os`, in this order: the class header expanded from
/// dialectDeclBeginStr, one hook declaration block per feature the dialect
/// reports, the dialect's extra class declaration text if it has any, the
/// closing "};", and finally a MLIR_DECLARE_EXPLICIT_TYPE_ID line when the
/// dialect's C++ namespace is not empty.
185 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
186 // Emit all nested namespaces.
188 NamespaceEmitter nsEmitter(os, dialect);
190 // Emit the start of the decl.
191 std::string cppName = dialect.getCppClassName();
// Extensible dialects derive from ExtensibleDialect, all others from Dialect.
192 StringRef superClassName =
193 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
// Template arguments: {0} class name, {1} dialect name, {2} parent class.
194 os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
195 superClassName);
197 // If the dialect requested the default attribute printer and parser, emit
198 // the declarations for the hooks.
199 if (dialect.useDefaultAttributePrinterParser())
200 os << attrParserDecl;
201 // If the dialect requested the default type printer and parser, emit the
202 // declarations for the hooks.
203 if (dialect.useDefaultTypePrinterParser())
204 os << typeParserDecl;
206 // Add the decls for the various features of the dialect.
207 if (dialect.hasCanonicalizer())
208 os << canonicalizerDecl;
209 if (dialect.hasConstantMaterializer())
210 os << constantMaterializerDecl;
211 if (dialect.hasOperationAttrVerify())
212 os << opAttrVerifierDecl;
213 if (dialect.hasRegionArgAttrVerify())
214 os << regionArgAttrVerifierDecl;
215 if (dialect.hasRegionResultAttrVerify())
216 os << regionResultAttrVerifierDecl;
217 if (dialect.hasOperationInterfaceFallback())
218 os << operationInterfaceFallbackDecl;
// The dialect's extra class declaration text, when present, goes last in the
// class body.
219 if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
220 os << *extraDecl;
222 // End the dialect decl.
223 os << "};\n";
// The type ID declaration names the class through its C++ namespace and is
// skipped entirely when that namespace is empty.
225 if (!dialect.getCppNamespace().empty())
226 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
227 << "::" << dialect.getCppClassName() << ")\n";
/// Entry point of the dialect declaration generator.
///
/// Emits the generated-file header, then the declaration of the dialect that
/// findDialectToGenerate selects among all records derived from "Dialect".
/// Returns false when there are no such records or after the declaration was
/// emitted, and true when no dialect could be selected.
// NOTE(review): `true` is presumably the failure value expected by the
// generator driver - confirm against mlir::GenRegistration.
230 static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
231 raw_ostream &os) {
232 emitSourceFileHeader("Dialect Declarations", os, recordKeeper);
234 auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
// Without any Dialect record only the file header is produced.
235 if (dialectDefs.empty())
236 return false;
238 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
239 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
// findDialectToGenerate has already printed why the selection failed.
240 if (!dialect)
241 return true;
242 emitDialectDecl(*dialect, os);
243 return false;
246 //===----------------------------------------------------------------------===//
247 // GEN: Dialect definitions
248 //===----------------------------------------------------------------------===//
250 /// The code block to generate a dialect constructor definition.
252 /// {0}: The name of the dialect class.
253 /// {1}: initialization code that is emitted in the ctor body before calling
254 /// initialize().
255 /// {2}: The dialect parent class.
///
/// Expanded with llvm::formatv by emitDialectDef, which passes the
/// concatenated dependent dialect registrations for {1}. The "{{" that ends
/// the initializer line is formatv's escape for a literal "{".
256 static const char *const dialectConstructorStr = R"(
257 {0}::{0}(::mlir::MLIRContext *context)
258 : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
260 initialize();
264 /// The code block to generate a default destructor definition.
266 /// {0}: The name of the dialect class.
///
/// Expanded with llvm::formatv by emitDialectDef, and only when
/// Dialect::hasNonDefaultDestructor() is false.
267 static const char *const dialectDestructorStr = R"(
268 {0}::~{0}() = default;
/// Generate the out-of-line definitions for the given dialect class.
///
/// Writes to `os`: a MLIR_DEFINE_EXPLICIT_TYPE_ID line when the dialect's C++
/// namespace is not empty, the constructor expanded from
/// dialectConstructorStr, and the defaulted destructor unless the dialect
/// reports a non-default one.
272 static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
273 std::string cppClassName = dialect.getCppClassName();
275 // Emit the TypeID explicit specializations to have a single symbol def.
276 if (!dialect.getCppNamespace().empty())
277 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
278 << "::" << cppClassName << ")\n";
280 // Emit all nested namespaces.
281 NamespaceEmitter nsEmitter(os, dialect);
283 /// Build the list of dependent dialects.
// One expansion of dialectRegistrationTemplate per dependent dialect is
// accumulated into the string through the stream.
284 std::string dependentDialectRegistrations;
286 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
287 for (StringRef dependentDialect : dialect.getDependentDialects())
288 dialectsOs << llvm::formatv(dialectRegistrationTemplate,
289 dependentDialect);
292 // Emit the constructor and destructor.
// Extensible dialects derive from ExtensibleDialect, all others from Dialect.
293 StringRef superClassName =
294 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
// Template arguments: {0} class name, {1} registrations, {2} parent class.
295 os << llvm::formatv(dialectConstructorStr, cppClassName,
296 dependentDialectRegistrations, superClassName);
// No destructor definition is generated for a dialect that reports a
// non-default destructor.
297 if (!dialect.hasNonDefaultDestructor())
298 os << llvm::formatv(dialectDestructorStr, cppClassName);
/// Entry point of the dialect definition generator.
///
/// Emits the generated-file header, then the definitions of the dialect that
/// findDialectToGenerate selects among all records derived from "Dialect".
/// Returns false when there are no such records or after the definitions were
/// emitted, and true when no dialect could be selected.
// NOTE(review): `true` is presumably the failure value expected by the
// generator driver - confirm against mlir::GenRegistration.
301 static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
302 raw_ostream &os) {
303 emitSourceFileHeader("Dialect Definitions", os, recordKeeper);
305 auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
// Without any Dialect record only the file header is produced.
306 if (dialectDefs.empty())
307 return false;
309 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
310 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
// findDialectToGenerate has already printed why the selection failed.
311 if (!dialect)
312 return true;
313 emitDialectDef(*dialect, os);
314 return false;
317 //===----------------------------------------------------------------------===//
318 // GEN: Dialect registration hooks
319 //===----------------------------------------------------------------------===//
/// Registers the "gen-dialect-decls" generator; its callback forwards the
/// record keeper and the output stream to emitDialectDecls and returns that
/// function's result.
321 static mlir::GenRegistration
322 genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
323 [](const llvm::RecordKeeper &records, raw_ostream &os) {
324 return emitDialectDecls(records, os);
/// Registers the "gen-dialect-defs" generator; its callback forwards the
/// record keeper and the output stream to emitDialectDefs and returns that
/// function's result.
327 static mlir::GenRegistration
328 genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
329 [](const llvm::RecordKeeper &records, raw_ostream &os) {
330 return emitDialectDefs(records, os);