[mlir] Attempt to resolve edge cases in PassPipeline textual format (#118877)
[llvm-project.git] / mlir / tools / mlir-tblgen / OpInterfacesGen.cpp
blob1f1b1d9a3403918e2b5618e91696ec5ea14c335b
1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// OpInterfacesGen generates declarations, definitions, and documentation for
// attribute, operation, and type interfaces.
11 //===----------------------------------------------------------------------===//
13 #include "DocGenUtilities.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Interfaces.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 #include "llvm/TableGen/TableGenBackend.h"
25 using namespace mlir;
26 using llvm::Record;
27 using llvm::RecordKeeper;
28 using mlir::tblgen::Interface;
29 using mlir::tblgen::InterfaceMethod;
30 using mlir::tblgen::OpInterface;
32 /// Emit a string corresponding to a C++ type, followed by a space if necessary.
33 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
34 type = type.trim();
35 os << type;
36 if (type.back() != '&' && type.back() != '*')
37 os << " ";
38 return os;
41 /// Emit the method name and argument list for the given method. If 'addThisArg'
42 /// is true, then an argument is added to the beginning of the argument list for
43 /// the concrete value.
44 static void emitMethodNameAndArgs(const InterfaceMethod &method,
45 raw_ostream &os, StringRef valueType,
46 bool addThisArg, bool addConst) {
47 os << method.getName() << '(';
48 if (addThisArg) {
49 if (addConst)
50 os << "const ";
51 os << "const Concept *impl, ";
52 emitCPPType(valueType, os)
53 << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
55 llvm::interleaveComma(method.getArguments(), os,
56 [&](const InterfaceMethod::Argument &arg) {
57 os << arg.type << " " << arg.name;
58 });
59 os << ')';
60 if (addConst)
61 os << " const";
64 /// Get an array of all OpInterface definitions but exclude those subclassing
65 /// "DeclareOpInterfaceMethods".
66 static std::vector<const Record *>
67 getAllInterfaceDefinitions(const RecordKeeper &records, StringRef name) {
68 std::vector<const Record *> defs =
69 records.getAllDerivedDefinitions((name + "Interface").str());
71 std::string declareName = ("Declare" + name + "InterfaceMethods").str();
72 llvm::erase_if(defs, [&](const Record *def) {
73 // Ignore any "declare methods" interfaces.
74 if (def->isSubClassOf(declareName))
75 return true;
76 // Ignore interfaces defined outside of the top-level file.
77 return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
78 llvm::SrcMgr.getMainFileID();
79 });
80 return defs;
83 namespace {
84 /// This struct is the base generator used when processing tablegen interfaces.
85 class InterfaceGenerator {
86 public:
87 bool emitInterfaceDefs();
88 bool emitInterfaceDecls();
89 bool emitInterfaceDocs();
91 protected:
92 InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
93 : defs(std::move(defs)), os(os) {}
95 void emitConceptDecl(const Interface &interface);
96 void emitModelDecl(const Interface &interface);
97 void emitModelMethodsDef(const Interface &interface);
98 void emitTraitDecl(const Interface &interface, StringRef interfaceName,
99 StringRef interfaceTraitsName);
100 void emitInterfaceDecl(const Interface &interface);
102 /// The set of interface records to emit.
103 std::vector<const Record *> defs;
104 // The stream to emit to.
105 raw_ostream &os;
106 /// The C++ value type of the interface, e.g. Operation*.
107 StringRef valueType;
108 /// The C++ base interface type.
109 StringRef interfaceBaseType;
110 /// The name of the typename for the value template.
111 StringRef valueTemplate;
112 /// The name of the substituion variable for the value.
113 StringRef substVar;
114 /// The format context to use for methods.
115 tblgen::FmtContext nonStaticMethodFmt;
116 tblgen::FmtContext traitMethodFmt;
117 tblgen::FmtContext extraDeclsFmt;
120 /// A specialized generator for attribute interfaces.
121 struct AttrInterfaceGenerator : public InterfaceGenerator {
122 AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
123 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
124 valueType = "::mlir::Attribute";
125 interfaceBaseType = "AttributeInterface";
126 valueTemplate = "ConcreteAttr";
127 substVar = "_attr";
128 StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
129 nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
130 traitMethodFmt.addSubst(substVar,
131 "(*static_cast<const ConcreteAttr *>(this))");
132 extraDeclsFmt.addSubst(substVar, "(*this)");
135 /// A specialized generator for operation interfaces.
136 struct OpInterfaceGenerator : public InterfaceGenerator {
137 OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
138 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
139 valueType = "::mlir::Operation *";
140 interfaceBaseType = "OpInterface";
141 valueTemplate = "ConcreteOp";
142 substVar = "_op";
143 StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
144 nonStaticMethodFmt.addSubst("_this", "impl")
145 .addSubst(substVar, castCode)
146 .withSelf(castCode);
147 traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
148 extraDeclsFmt.addSubst(substVar, "(*this)");
151 /// A specialized generator for type interfaces.
152 struct TypeInterfaceGenerator : public InterfaceGenerator {
153 TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
154 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
155 valueType = "::mlir::Type";
156 interfaceBaseType = "TypeInterface";
157 valueTemplate = "ConcreteType";
158 substVar = "_type";
159 StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
160 nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
161 traitMethodFmt.addSubst(substVar,
162 "(*static_cast<const ConcreteType *>(this))");
163 extraDeclsFmt.addSubst(substVar, "(*this)");
166 } // namespace
168 //===----------------------------------------------------------------------===//
169 // GEN: Interface definitions
170 //===----------------------------------------------------------------------===//
172 static void emitInterfaceMethodDoc(const InterfaceMethod &method,
173 raw_ostream &os, StringRef prefix = "") {
174 if (std::optional<StringRef> description = method.getDescription())
175 tblgen::emitDescriptionComment(*description, os, prefix);
177 static void emitInterfaceDefMethods(StringRef interfaceQualName,
178 const Interface &interface,
179 StringRef valueType, const Twine &implValue,
180 raw_ostream &os, bool isOpInterface) {
181 for (auto &method : interface.getMethods()) {
182 emitInterfaceMethodDoc(method, os);
183 emitCPPType(method.getReturnType(), os);
184 os << interfaceQualName << "::";
185 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
186 /*addConst=*/!isOpInterface);
188 // Forward to the method on the concrete operation type.
189 os << " {\n return " << implValue << "->" << method.getName() << '(';
190 if (!method.isStatic()) {
191 os << implValue << ", ";
192 os << (isOpInterface ? "getOperation()" : "*this");
193 os << (method.arg_empty() ? "" : ", ");
195 llvm::interleaveComma(
196 method.getArguments(), os,
197 [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
198 os << ");\n }\n";
202 static void emitInterfaceDef(const Interface &interface, StringRef valueType,
203 raw_ostream &os) {
204 std::string interfaceQualNameStr = interface.getFullyQualifiedName();
205 StringRef interfaceQualName = interfaceQualNameStr;
206 interfaceQualName.consume_front("::");
208 // Insert the method definitions.
209 bool isOpInterface = isa<OpInterface>(interface);
210 emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()",
211 os, isOpInterface);
213 // Insert the method definitions for base classes.
214 for (auto &base : interface.getBaseInterfaces()) {
215 emitInterfaceDefMethods(interfaceQualName, base, valueType,
216 "getImpl()->impl" + base.getName(), os,
217 isOpInterface);
221 bool InterfaceGenerator::emitInterfaceDefs() {
222 llvm::emitSourceFileHeader("Interface Definitions", os);
224 for (const auto *def : defs)
225 emitInterfaceDef(Interface(def), valueType, os);
226 return false;
229 //===----------------------------------------------------------------------===//
230 // GEN: Interface declarations
231 //===----------------------------------------------------------------------===//
233 void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
234 os << " struct Concept {\n";
236 // Insert each of the pure virtual concept methods.
237 os << " /// The methods defined by the interface.\n";
238 for (auto &method : interface.getMethods()) {
239 os << " ";
240 emitCPPType(method.getReturnType(), os);
241 os << "(*" << method.getName() << ")(";
242 if (!method.isStatic()) {
243 os << "const Concept *impl, ";
244 emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
246 llvm::interleaveComma(
247 method.getArguments(), os,
248 [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
249 os << ");\n";
252 // Insert a field containing a concept for each of the base interfaces.
253 auto baseInterfaces = interface.getBaseInterfaces();
254 if (!baseInterfaces.empty()) {
255 os << " /// The base classes of this interface.\n";
256 for (const auto &base : interface.getBaseInterfaces()) {
257 os << " const " << base.getFullyQualifiedName() << "::Concept *impl"
258 << base.getName() << " = nullptr;\n";
261 // Define an "initialize" method that allows for the initialization of the
262 // base class concepts.
263 os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
264 "&interfaceMap) {\n";
265 std::string interfaceQualName = interface.getFullyQualifiedName();
266 for (const auto &base : interface.getBaseInterfaces()) {
267 StringRef baseName = base.getName();
268 std::string baseQualName = base.getFullyQualifiedName();
269 os << " impl" << baseName << " = interfaceMap.lookup<"
270 << baseQualName << ">();\n"
271 << " assert(impl" << baseName << " && \"`" << interfaceQualName
272 << "` expected its base interface `" << baseQualName
273 << "` to be registered\");\n";
275 os << " }\n";
278 os << " };\n";
281 void InterfaceGenerator::emitModelDecl(const Interface &interface) {
282 // Emit the basic model and the fallback model.
283 for (const char *modelClass : {"Model", "FallbackModel"}) {
284 os << " template<typename " << valueTemplate << ">\n";
285 os << " class " << modelClass << " : public Concept {\n public:\n";
286 os << " using Interface = " << interface.getFullyQualifiedName()
287 << ";\n";
288 os << " " << modelClass << "() : Concept{";
289 llvm::interleaveComma(
290 interface.getMethods(), os,
291 [&](const InterfaceMethod &method) { os << method.getName(); });
292 os << "} {}\n\n";
294 // Insert each of the virtual method overrides.
295 for (auto &method : interface.getMethods()) {
296 emitCPPType(method.getReturnType(), os << " static inline ");
297 emitMethodNameAndArgs(method, os, valueType,
298 /*addThisArg=*/!method.isStatic(),
299 /*addConst=*/false);
300 os << ";\n";
302 os << " };\n";
305 // Emit the template for the external model.
306 os << " template<typename ConcreteModel, typename " << valueTemplate
307 << ">\n";
308 os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n";
309 os << " public:\n";
310 os << " using ConcreteEntity = " << valueTemplate << ";\n";
312 // Emit declarations for methods that have default implementations. Other
313 // methods are expected to be implemented by the concrete derived model.
314 for (auto &method : interface.getMethods()) {
315 if (!method.getDefaultImplementation())
316 continue;
317 os << " ";
318 if (method.isStatic())
319 os << "static ";
320 emitCPPType(method.getReturnType(), os);
321 os << method.getName() << "(";
322 if (!method.isStatic()) {
323 emitCPPType(valueType, os);
324 os << "tablegen_opaque_val";
325 if (!method.arg_empty())
326 os << ", ";
328 llvm::interleaveComma(method.getArguments(), os,
329 [&](const InterfaceMethod::Argument &arg) {
330 emitCPPType(arg.type, os);
331 os << arg.name;
333 os << ")";
334 if (!method.isStatic())
335 os << " const";
336 os << ";\n";
338 os << " };\n";
341 void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
342 llvm::SmallVector<StringRef, 2> namespaces;
343 llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
344 for (StringRef ns : namespaces)
345 os << "namespace " << ns << " {\n";
347 for (auto &method : interface.getMethods()) {
348 os << "template<typename " << valueTemplate << ">\n";
349 emitCPPType(method.getReturnType(), os);
350 os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
351 << valueTemplate << ">::";
352 emitMethodNameAndArgs(method, os, valueType,
353 /*addThisArg=*/!method.isStatic(),
354 /*addConst=*/false);
355 os << " {\n ";
357 // Check for a provided body to the function.
358 if (std::optional<StringRef> body = method.getBody()) {
359 if (method.isStatic())
360 os << body->trim();
361 else
362 os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
363 os << "\n}\n";
364 continue;
367 // Forward to the method on the concrete operation type.
368 if (method.isStatic())
369 os << "return " << valueTemplate << "::";
370 else
371 os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
373 // Add the arguments to the call.
374 os << method.getName() << '(';
375 llvm::interleaveComma(
376 method.getArguments(), os,
377 [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
378 os << ");\n}\n";
381 for (auto &method : interface.getMethods()) {
382 os << "template<typename " << valueTemplate << ">\n";
383 emitCPPType(method.getReturnType(), os);
384 os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
385 << valueTemplate << ">::";
386 emitMethodNameAndArgs(method, os, valueType,
387 /*addThisArg=*/!method.isStatic(),
388 /*addConst=*/false);
389 os << " {\n ";
391 // Forward to the method on the concrete Model implementation.
392 if (method.isStatic())
393 os << "return " << valueTemplate << "::";
394 else
395 os << "return static_cast<const " << valueTemplate << " *>(impl)->";
397 // Add the arguments to the call.
398 os << method.getName() << '(';
399 if (!method.isStatic())
400 os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
401 llvm::interleaveComma(
402 method.getArguments(), os,
403 [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
404 os << ");\n}\n";
407 // Emit default implementations for the external model.
408 for (auto &method : interface.getMethods()) {
409 if (!method.getDefaultImplementation())
410 continue;
411 os << "template<typename ConcreteModel, typename " << valueTemplate
412 << ">\n";
413 emitCPPType(method.getReturnType(), os);
414 os << "detail::" << interface.getName()
415 << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
416 << ">::";
418 os << method.getName() << "(";
419 if (!method.isStatic()) {
420 emitCPPType(valueType, os);
421 os << "tablegen_opaque_val";
422 if (!method.arg_empty())
423 os << ", ";
425 llvm::interleaveComma(method.getArguments(), os,
426 [&](const InterfaceMethod::Argument &arg) {
427 emitCPPType(arg.type, os);
428 os << arg.name;
430 os << ")";
431 if (!method.isStatic())
432 os << " const";
434 os << " {\n";
436 // Use the empty context for static methods.
437 tblgen::FmtContext ctx;
438 os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
439 method.isStatic() ? &ctx : &nonStaticMethodFmt);
440 os << "\n}\n";
443 for (StringRef ns : llvm::reverse(namespaces))
444 os << "} // namespace " << ns << "\n";
447 void InterfaceGenerator::emitTraitDecl(const Interface &interface,
448 StringRef interfaceName,
449 StringRef interfaceTraitsName) {
450 os << llvm::formatv(" template <typename {3}>\n"
451 " struct {0}Trait : public ::mlir::{2}<{0},"
452 " detail::{1}>::Trait<{3}> {{\n",
453 interfaceName, interfaceTraitsName, interfaceBaseType,
454 valueTemplate);
456 // Insert the default implementation for any methods.
457 bool isOpInterface = isa<OpInterface>(interface);
458 for (auto &method : interface.getMethods()) {
459 // Flag interface methods named verifyTrait.
460 if (method.getName() == "verifyTrait")
461 PrintFatalError(
462 formatv("'verifyTrait' method cannot be specified as interface "
463 "method for '{0}'; use the 'verify' field instead",
464 interfaceName));
465 auto defaultImpl = method.getDefaultImplementation();
466 if (!defaultImpl)
467 continue;
469 emitInterfaceMethodDoc(method, os, " ");
470 os << " " << (method.isStatic() ? "static " : "");
471 emitCPPType(method.getReturnType(), os);
472 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
473 /*addConst=*/!isOpInterface && !method.isStatic());
474 os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
475 << "\n }\n";
478 if (auto verify = interface.getVerify()) {
479 assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
481 tblgen::FmtContext verifyCtx;
482 verifyCtx.addSubst("_op", "op");
483 os << llvm::formatv(
484 " static ::llvm::LogicalResult {0}(::mlir::Operation *op) ",
485 (interface.verifyWithRegions() ? "verifyRegionTrait"
486 : "verifyTrait"))
487 << "{\n " << tblgen::tgfmt(verify->trim(), &verifyCtx)
488 << "\n }\n";
490 if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
491 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
492 if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration())
493 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
495 os << " };\n";
498 static void emitInterfaceDeclMethods(const Interface &interface,
499 raw_ostream &os, StringRef valueType,
500 bool isOpInterface,
501 tblgen::FmtContext &extraDeclsFmt) {
502 for (auto &method : interface.getMethods()) {
503 emitInterfaceMethodDoc(method, os, " ");
504 emitCPPType(method.getReturnType(), os << " ");
505 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
506 /*addConst=*/!isOpInterface);
507 os << ";\n";
510 // Emit any extra declarations.
511 if (std::optional<StringRef> extraDecls =
512 interface.getExtraClassDeclaration())
513 os << extraDecls->rtrim() << "\n";
514 if (std::optional<StringRef> extraDecls =
515 interface.getExtraSharedClassDeclaration())
516 os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
519 void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
520 llvm::SmallVector<StringRef, 2> namespaces;
521 llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
522 for (StringRef ns : namespaces)
523 os << "namespace " << ns << " {\n";
525 StringRef interfaceName = interface.getName();
526 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
528 // Emit a forward declaration of the interface class so that it becomes usable
529 // in the signature of its methods.
530 os << "class " << interfaceName << ";\n";
532 // Emit the traits struct containing the concept and model declarations.
533 os << "namespace detail {\n"
534 << "struct " << interfaceTraitsName << " {\n";
535 emitConceptDecl(interface);
536 emitModelDecl(interface);
537 os << "};\n";
539 // Emit the derived trait for the interface.
540 os << "template <typename " << valueTemplate << ">\n";
541 os << "struct " << interface.getName() << "Trait;\n";
543 os << "\n} // namespace detail\n";
545 // Emit the main interface class declaration.
546 os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
547 "public:\n"
548 " using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
549 interfaceName, interfaceName, interfaceTraitsName,
550 interfaceBaseType);
552 // Emit a utility wrapper trait class.
553 os << llvm::formatv(" template <typename {1}>\n"
554 " struct Trait : public detail::{0}Trait<{1}> {{};\n",
555 interfaceName, valueTemplate);
557 // Insert the method declarations.
558 bool isOpInterface = isa<OpInterface>(interface);
559 emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
560 extraDeclsFmt);
562 // Insert the method declarations for base classes.
563 for (auto &base : interface.getBaseInterfaces()) {
564 std::string baseQualName = base.getFullyQualifiedName();
565 os << " //"
566 "===---------------------------------------------------------------"
567 "-===//\n"
568 << " // Inherited from " << baseQualName << "\n"
569 << " //"
570 "===---------------------------------------------------------------"
571 "-===//\n\n";
573 // Allow implicit conversion to the base interface.
574 os << " operator " << baseQualName << " () const {\n"
575 << " if (!*this) return nullptr;\n"
576 << " return " << baseQualName << "(*this, getImpl()->impl"
577 << base.getName() << ");\n"
578 << " }\n\n";
580 // Inherit the base interface's methods.
581 emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt);
584 // Emit classof code if necessary.
585 if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
586 auto extraClassOfFmt = tblgen::FmtContext();
587 extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance");
588 os << " static bool classof(" << valueType << " base) {\n"
589 << " auto* interface = getInterfaceFor(base);\n"
590 << " if (!interface)\n"
591 " return false;\n"
592 " " << interfaceName << " odsInterfaceInstance(base, interface);\n"
593 << " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
594 << "\n }\n";
597 os << "};\n";
599 os << "namespace detail {\n";
600 emitTraitDecl(interface, interfaceName, interfaceTraitsName);
601 os << "}// namespace detail\n";
603 for (StringRef ns : llvm::reverse(namespaces))
604 os << "} // namespace " << ns << "\n";
607 bool InterfaceGenerator::emitInterfaceDecls() {
608 llvm::emitSourceFileHeader("Interface Declarations", os);
609 // Sort according to ID, so defs are emitted in the order in which they appear
610 // in the Tablegen file.
611 std::vector<const Record *> sortedDefs(defs);
612 llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
613 return lhs->getID() < rhs->getID();
615 for (const Record *def : sortedDefs)
616 emitInterfaceDecl(Interface(def));
617 for (const Record *def : sortedDefs)
618 emitModelMethodsDef(Interface(def));
619 return false;
622 //===----------------------------------------------------------------------===//
623 // GEN: Interface documentation
624 //===----------------------------------------------------------------------===//
626 static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
627 Interface interface(&interfaceDef);
629 // Emit the interface name followed by the description.
630 os << "## " << interface.getName() << " (`" << interfaceDef.getName()
631 << "`)\n\n";
632 if (auto description = interface.getDescription())
633 mlir::tblgen::emitDescription(*description, os);
635 // Emit the methods required by the interface.
636 os << "\n### Methods:\n";
637 for (const auto &method : interface.getMethods()) {
638 // Emit the method name.
639 os << "#### `" << method.getName() << "`\n\n```c++\n";
641 // Emit the method signature.
642 if (method.isStatic())
643 os << "static ";
644 emitCPPType(method.getReturnType(), os) << method.getName() << '(';
645 llvm::interleaveComma(method.getArguments(), os,
646 [&](const InterfaceMethod::Argument &arg) {
647 emitCPPType(arg.type, os) << arg.name;
649 os << ");\n```\n";
651 // Emit the description.
652 if (auto description = method.getDescription())
653 mlir::tblgen::emitDescription(*description, os);
655 // If the body is not provided, this method must be provided by the user.
656 if (!method.getBody())
657 os << "\nNOTE: This method *must* be implemented by the user.";
659 os << "\n\n";
663 bool InterfaceGenerator::emitInterfaceDocs() {
664 os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
665 os << "# " << interfaceBaseType << " definitions\n";
667 for (const auto *def : defs)
668 emitInterfaceDoc(*def, os);
669 return false;
672 //===----------------------------------------------------------------------===//
673 // GEN: Interface registration hooks
674 //===----------------------------------------------------------------------===//
676 namespace {
677 template <typename GeneratorT>
678 struct InterfaceGenRegistration {
679 InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
680 : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
681 genDefArg(("gen-" + genArg + "-interface-defs").str()),
682 genDocArg(("gen-" + genArg + "-interface-docs").str()),
683 genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
684 genDefDesc(("Generate " + genDesc + " interface definitions").str()),
685 genDocDesc(("Generate " + genDesc + " interface documentation").str()),
686 genDecls(genDeclArg, genDeclDesc,
687 [](const RecordKeeper &records, raw_ostream &os) {
688 return GeneratorT(records, os).emitInterfaceDecls();
690 genDefs(genDefArg, genDefDesc,
691 [](const RecordKeeper &records, raw_ostream &os) {
692 return GeneratorT(records, os).emitInterfaceDefs();
694 genDocs(genDocArg, genDocDesc,
695 [](const RecordKeeper &records, raw_ostream &os) {
696 return GeneratorT(records, os).emitInterfaceDocs();
697 }) {}
699 std::string genDeclArg, genDefArg, genDocArg;
700 std::string genDeclDesc, genDefDesc, genDocDesc;
701 mlir::GenRegistration genDecls, genDefs, genDocs;
703 } // namespace
705 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
706 "attribute");
707 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
708 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");