[OpenMP][Docs] Update OpenMP supported features table (#126292)
[llvm-project.git] / mlir / tools / mlir-tblgen / PassGen.cpp
blob4b4ac41b9effb8eaf2948d06e8074979cec01b06
1 //===- PassGen.cpp - MLIR pass registration generator ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PassGen uses the description of passes to generate base classes for passes
10 // and command line registration.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Pass.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 using llvm::formatv;
25 using llvm::RecordKeeper;
27 static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
28 static llvm::cl::opt<std::string>
29 groupName("name", llvm::cl::desc("The name of this group of passes"),
30 llvm::cl::cat(passGenCat));
32 /// Extract the list of passes from the TableGen records.
33 static std::vector<Pass> getPasses(const RecordKeeper &records) {
34 std::vector<Pass> passes;
36 for (const auto *def : records.getAllDerivedDefinitions("PassBase"))
37 passes.emplace_back(def);
39 return passes;
/// Banner comment emitted ahead of everything generated for a single pass.
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"cpp(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)cpp";
48 //===----------------------------------------------------------------------===//
49 // GEN: Pass registration generation
50 //===----------------------------------------------------------------------===//
/// The code snippet used to generate a pass registration.
///
/// Two inline functions are produced per pass: `register{0}()` and the legacy
/// spelling `register{0}Pass()`. Both register a factory lambda that returns
/// the result of the constructor call. Literal braces of the generated code
/// are written as `{{` because the snippet is consumed by `formatv`.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}
)";
/// Opening part of the function that registers every pass of a group.
///
/// The snippet deliberately stops right after the opening brace of
/// `register{0}Passes()`: the emitter appends one call per pass and then the
/// closing brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"cpp(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)cpp";
87 /// Emits the definition of the struct to be used to control the pass options.
88 static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
89 StringRef passName = pass.getDef()->getName();
90 ArrayRef<PassOption> options = pass.getOptions();
92 // Emit the struct only if the pass has at least one option.
93 if (options.empty())
94 return;
96 os << formatv("struct {0}Options {{\n", passName);
98 for (const PassOption &opt : options) {
99 std::string type = opt.getType().str();
101 if (opt.isListOption())
102 type = "::llvm::SmallVector<" + type + ">";
104 os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
106 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
107 os << " = " << defaultVal;
109 os << ";\n";
112 os << "};\n";
115 static std::string getPassDeclVarName(const Pass &pass) {
116 return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
119 /// Emit the code to be included in the public header of the pass.
120 static void emitPassDecls(const Pass &pass, raw_ostream &os) {
121 StringRef passName = pass.getDef()->getName();
122 std::string enableVarName = getPassDeclVarName(pass);
124 os << "#ifdef " << enableVarName << "\n";
125 emitPassOptionsStruct(pass, os);
127 if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
128 // Default constructor declaration.
129 os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
131 // Declaration of the constructor with options.
132 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
133 os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
134 "{0}Options options);\n",
135 passName);
138 os << "#undef " << enableVarName << "\n";
139 os << "#endif // " << enableVarName << "\n";
142 /// Emit the code for registering each of the given passes with the global
143 /// PassRegistry.
144 static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
145 os << "#ifdef GEN_PASS_REGISTRATION\n";
147 for (const Pass &pass : passes) {
148 std::string constructorCall;
149 if (StringRef constructor = pass.getConstructor(); !constructor.empty())
150 constructorCall = constructor.str();
151 else
152 constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
154 os << formatv(passRegistrationCode, pass.getDef()->getName(),
155 constructorCall);
158 os << formatv(passGroupRegistrationCode, groupName);
160 for (const Pass &pass : passes)
161 os << " register" << pass.getDef()->getName() << "();\n";
163 os << "}\n";
164 os << "#undef GEN_PASS_REGISTRATION\n";
165 os << "#endif // GEN_PASS_REGISTRATION\n";
168 //===----------------------------------------------------------------------===//
169 // GEN: Pass base class generation
170 //===----------------------------------------------------------------------===//
/// The code snippet used to generate the start of a pass base class.
///
/// The snippet is consumed by `formatv`, so every literal opening brace of the
/// generated C++ is escaped as `{{`. Unescaped braces only work while another
/// replacement field happens to follow them, which is easy to break when the
/// template is edited. The class is left open; the emitter appends the
/// remaining members and the closing `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {{
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}
  {0}Base& operator=(const {0}Base &) = delete;
  {0}Base({0}Base &&) = delete;
  {0}Base& operator=({0}Base &&) = delete;
  ~{0}Base() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {{
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override {{ return "{2}"; }

  ::llvm::StringRef getDescription() const override {{ return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {{
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override {{ return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {{
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
/// Statement registering one dependent dialect. One instance per dependent
/// dialect is spliced into the body of the generated `getDependentDialects`.
///
/// {0}: The dependent dialect, as spelled in the pass record.
const char *const dialectRegistrationTemplate = R"(registry.insert<{0}>();)";
/// Forward declaration, inside `namespace impl`, of the generated default
/// constructor function of a pass.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";
/// Forward declaration, inside `namespace impl`, of the generated constructor
/// function that takes the pass options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
} // namespace impl
)";
/// In-class friend definition of the generated default constructor function;
/// it instantiates the derived pass type.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}() {{
    return std::make_unique<DerivedT>();
  }
)";
/// In-class friend definition of the generated constructor function that
/// forwards an options struct to the derived pass type.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
    return std::make_unique<DerivedT>(std::move(options));
  }
)";
/// Definition of the public default constructor function of a pass; it
/// forwards to the function of the same name in `namespace impl`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
  return impl::create{0}();
}
)";
/// Definition of the public constructor function taking the options struct;
/// it forwards to the function of the same name in `namespace impl`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
  return impl::create{0}(std::move(options));
}
)";
268 /// Emit the declarations for each of the pass options.
269 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
270 for (const PassOption &opt : pass.getOptions()) {
271 os.indent(2) << "::mlir::Pass::"
272 << (opt.isListOption() ? "ListOption" : "Option");
274 os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
275 opt.getType(), opt.getCppVariableName(), opt.getArgument(),
276 opt.getDescription());
277 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
278 os << ", ::llvm::cl::init(" << defaultVal << ")";
279 if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
280 os << ", " << *additionalFlags;
281 os << "};\n";
285 /// Emit the declarations for each of the pass statistics.
286 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
287 for (const PassStatistic &stat : pass.getStatistics()) {
288 os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
289 stat.getCppVariableName(), stat.getName(),
290 stat.getDescription());
294 /// Emit the code to be used in the implementation of the pass.
295 static void emitPassDefs(const Pass &pass, raw_ostream &os) {
296 StringRef passName = pass.getDef()->getName();
297 std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
298 bool emitDefaultConstructors = pass.getConstructor().empty();
299 bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
301 os << "#ifdef " << enableVarName << "\n";
303 if (emitDefaultConstructors) {
304 os << formatv(friendDefaultConstructorDeclTemplate, passName);
306 if (emitDefaultConstructorWithOptions)
307 os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName);
310 std::string dependentDialectRegistrations;
312 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
313 llvm::interleave(
314 pass.getDependentDialects(), dialectsOs,
315 [&](StringRef dependentDialect) {
316 dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
318 "\n ");
321 os << "namespace impl {\n";
322 os << formatv(baseClassBegin, passName, pass.getBaseClass(),
323 pass.getArgument(), pass.getSummary(),
324 dependentDialectRegistrations);
326 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
327 os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n",
328 passName);
330 for (const PassOption &opt : pass.getOptions())
331 os.indent(4) << formatv("{0} = std::move(options.{0});\n",
332 opt.getCppVariableName());
334 os.indent(2) << "}\n";
337 // Protected content
338 os << "protected:\n";
339 emitPassOptionDecls(pass, os);
340 emitPassStatisticDecls(pass, os);
342 // Private content
343 os << "private:\n";
345 if (emitDefaultConstructors) {
346 os << formatv(friendDefaultConstructorDefTemplate, passName);
348 if (!pass.getOptions().empty())
349 os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName);
352 os << "};\n";
353 os << "} // namespace impl\n";
355 if (emitDefaultConstructors) {
356 os << formatv(defaultConstructorDefTemplate, passName);
358 if (emitDefaultConstructorWithOptions)
359 os << formatv(defaultConstructorWithOptionsDefTemplate, passName);
362 os << "#undef " << enableVarName << "\n";
363 os << "#endif // " << enableVarName << "\n";
366 static void emitPass(const Pass &pass, raw_ostream &os) {
367 StringRef passName = pass.getDef()->getName();
368 os << formatv(passHeader, passName);
370 emitPassDecls(pass, os);
371 emitPassDefs(pass, os);
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
//
// The snippet is consumed by `formatv`, so every literal opening brace of the
// generated C++ is escaped as `{{` rather than relying on a replacement field
// happening to follow it. The class is left open after `protected:`; the
// emitter appends the option and statistic members and the closing `};`.
//
// {0}: The def name of the pass record.
// {1}: The base class for the pass.
// {2}: The command line argument for the pass.
// {3}: The summary for the pass.
// {4}: The dependent dialects registration.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {{
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}
  {0}Base& operator=(const {0}Base &) = delete;
  {0}Base({0}Base &&) = delete;
  {0}Base& operator=({0}Base &&) = delete;
  ~{0}Base() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {{
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override {{ return "{2}"; }

  ::llvm::StringRef getDescription() const override {{ return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {{
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override {{ return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {{
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
427 // TODO: Drop old pass declarations.
428 /// Emit a backward-compatible declaration of the pass base class.
429 static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
430 StringRef defName = pass.getDef()->getName();
431 std::string dependentDialectRegistrations;
433 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
434 llvm::interleave(
435 pass.getDependentDialects(), dialectsOs,
436 [&](StringRef dependentDialect) {
437 dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
439 "\n ");
441 os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
442 pass.getArgument(), pass.getSummary(),
443 dependentDialectRegistrations);
444 emitPassOptionDecls(pass, os);
445 emitPassStatisticDecls(pass, os);
446 os << "};\n";
449 static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
450 std::vector<Pass> passes = getPasses(records);
451 os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
453 os << "\n";
454 os << "#ifdef GEN_PASS_DECL\n";
455 os << "// Generate declarations for all passes.\n";
456 for (const Pass &pass : passes)
457 os << "#define " << getPassDeclVarName(pass) << "\n";
458 os << "#undef GEN_PASS_DECL\n";
459 os << "#endif // GEN_PASS_DECL\n";
461 for (const Pass &pass : passes)
462 emitPass(pass, os);
464 emitRegistrations(passes, os);
466 // TODO: Drop old pass declarations.
467 // Emit the old code until all the passes have switched to the new design.
468 os << "// Deprecated. Please use the new per-pass macros.\n";
469 os << "#ifdef GEN_PASS_CLASSES\n";
470 for (const Pass &pass : passes)
471 emitOldPassDecl(pass, os);
472 os << "#undef GEN_PASS_CLASSES\n";
473 os << "#endif // GEN_PASS_CLASSES\n";
476 static mlir::GenRegistration
477 genPassDecls("gen-pass-decls", "Generate pass declarations",
478 [](const RecordKeeper &records, raw_ostream &os) {
479 emitPasses(records, os);
480 return false;