1 //===- Pass.cpp - MLIR pass registration generator ------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // PassGen uses the description of passes to generate base classes for passes
10 // and command line registration.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Pass.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
23 using namespace mlir::tblgen
;
25 using llvm::RecordKeeper
;
27 static llvm::cl::OptionCategory
passGenCat("Options for -gen-pass-decls");
28 static llvm::cl::opt
<std::string
>
29 groupName("name", llvm::cl::desc("The name of this group of passes"),
30 llvm::cl::cat(passGenCat
));
32 /// Extract the list of passes from the TableGen records.
33 static std::vector
<Pass
> getPasses(const RecordKeeper
&records
) {
34 std::vector
<Pass
> passes
;
36 for (const auto *def
: records
.getAllDerivedDefinitions("PassBase"))
37 passes
.emplace_back(def
);
42 const char *const passHeader
= R
"(
43 //===----------------------------------------------------------------------===//
45 //===----------------------------------------------------------------------===//
48 //===----------------------------------------------------------------------===//
49 // GEN: Pass registration generation
50 //===----------------------------------------------------------------------===//
52 /// The code snippet used to generate a pass registration.
54 /// {0}: The def name of the pass record.
55 /// {1}: The pass constructor call.
56 const char *const passRegistrationCode
= R
"(
57 //===----------------------------------------------------------------------===//
59 //===----------------------------------------------------------------------===//
61 inline void register{0}() {{
62 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
67 // Old registration code, kept for temporary backwards compatibility.
68 inline void register{0}Pass() {{
69 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
75 /// The code snippet used to generate a function to register all passes in a
78 /// {0}: The name of the pass group.
79 const char *const passGroupRegistrationCode
= R
"(
80 //===----------------------------------------------------------------------===//
82 //===----------------------------------------------------------------------===//
84 inline void register{0}Passes() {{
87 /// Emits the definition of the struct to be used to control the pass options.
88 static void emitPassOptionsStruct(const Pass
&pass
, raw_ostream
&os
) {
89 StringRef passName
= pass
.getDef()->getName();
90 ArrayRef
<PassOption
> options
= pass
.getOptions();
92 // Emit the struct only if the pass has at least one option.
96 os
<< formatv("struct {0}Options {{\n", passName
);
98 for (const PassOption
&opt
: options
) {
99 std::string type
= opt
.getType().str();
101 if (opt
.isListOption())
102 type
= "::llvm::SmallVector<" + type
+ ">";
104 os
.indent(2) << formatv("{0} {1}", type
, opt
.getCppVariableName());
106 if (std::optional
<StringRef
> defaultVal
= opt
.getDefaultValue())
107 os
<< " = " << defaultVal
;
115 static std::string
getPassDeclVarName(const Pass
&pass
) {
116 return "GEN_PASS_DECL_" + pass
.getDef()->getName().upper();
119 /// Emit the code to be included in the public header of the pass.
120 static void emitPassDecls(const Pass
&pass
, raw_ostream
&os
) {
121 StringRef passName
= pass
.getDef()->getName();
122 std::string enableVarName
= getPassDeclVarName(pass
);
124 os
<< "#ifdef " << enableVarName
<< "\n";
125 emitPassOptionsStruct(pass
, os
);
127 if (StringRef constructor
= pass
.getConstructor(); constructor
.empty()) {
128 // Default constructor declaration.
129 os
<< "std::unique_ptr<::mlir::Pass> create" << passName
<< "();\n";
131 // Declaration of the constructor with options.
132 if (ArrayRef
<PassOption
> options
= pass
.getOptions(); !options
.empty())
133 os
<< formatv("std::unique_ptr<::mlir::Pass> create{0}("
134 "{0}Options options);\n",
138 os
<< "#undef " << enableVarName
<< "\n";
139 os
<< "#endif // " << enableVarName
<< "\n";
142 /// Emit the code for registering each of the given passes with the global
144 static void emitRegistrations(llvm::ArrayRef
<Pass
> passes
, raw_ostream
&os
) {
145 os
<< "#ifdef GEN_PASS_REGISTRATION\n";
147 for (const Pass
&pass
: passes
) {
148 std::string constructorCall
;
149 if (StringRef constructor
= pass
.getConstructor(); !constructor
.empty())
150 constructorCall
= constructor
.str();
152 constructorCall
= formatv("create{0}()", pass
.getDef()->getName()).str();
154 os
<< formatv(passRegistrationCode
, pass
.getDef()->getName(),
158 os
<< formatv(passGroupRegistrationCode
, groupName
);
160 for (const Pass
&pass
: passes
)
161 os
<< " register" << pass
.getDef()->getName() << "();\n";
164 os
<< "#undef GEN_PASS_REGISTRATION\n";
165 os
<< "#endif // GEN_PASS_REGISTRATION\n";
168 //===----------------------------------------------------------------------===//
169 // GEN: Pass base class generation
170 //===----------------------------------------------------------------------===//
172 /// The code snippet used to generate the start of a pass base class.
174 /// {0}: The def name of the pass record.
175 /// {1}: The base class for the pass.
176 /// {2): The command line argument for the pass.
177 /// {3}: The summary for the pass.
178 /// {4}: The dependent dialects registration.
179 const char *const baseClassBegin
= R
"(
180 template <typename DerivedT>
181 class {0}Base : public {1} {
183 using Base = {0}Base;
185 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
186 {0}Base(const {0}Base &other) : {1}(other) {{}
187 {0}Base& operator=(const {0}Base &) = delete;
188 {0}Base({0}Base &&) = delete;
189 {0}Base& operator=({0}Base &&) = delete;
190 ~{0}Base() = default;
192 /// Returns the command-line argument attached to this pass.
193 static constexpr ::llvm::StringLiteral getArgumentName() {
194 return ::llvm::StringLiteral("{2}");
196 ::llvm::StringRef getArgument() const override { return "{2}"; }
198 ::llvm::StringRef getDescription() const override { return "{3}"; }
200 /// Returns the derived pass name.
201 static constexpr ::llvm::StringLiteral getPassName() {
202 return ::llvm::StringLiteral("{0}");
204 ::llvm::StringRef getName() const override { return "{0}"; }
206 /// Support isa/dyn_cast functionality for the derived pass class.
207 static bool classof(const ::mlir::Pass *pass) {{
208 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
211 /// A clone method to create a copy of this pass.
212 std::unique_ptr<::mlir::Pass> clonePass() const override {{
213 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
216 /// Return the dialect that must be loaded in the context before this pass.
217 void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
221 /// Explicitly declare the TypeID for this class. We declare an explicit private
222 /// instantiation because Pass classes should only be visible by the current
224 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
228 /// Registration for a single dependent dialect, to be inserted for each
229 /// dependent dialect in the `getDependentDialects` above.
230 const char *const dialectRegistrationTemplate
= "registry.insert<{0}>();";
232 const char *const friendDefaultConstructorDeclTemplate
= R
"(
234 std::unique_ptr<::mlir::Pass> create{0}();
238 const char *const friendDefaultConstructorWithOptionsDeclTemplate
= R
"(
240 std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
244 const char *const friendDefaultConstructorDefTemplate
= R
"(
245 friend std::unique_ptr<::mlir::Pass> create{0}() {{
246 return std::make_unique<DerivedT>();
250 const char *const friendDefaultConstructorWithOptionsDefTemplate
= R
"(
251 friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
252 return std::make_unique<DerivedT>(std::move(options));
256 const char *const defaultConstructorDefTemplate
= R
"(
257 std::unique_ptr<::mlir::Pass> create{0}() {{
258 return impl::create{0}();
262 const char *const defaultConstructorWithOptionsDefTemplate
= R
"(
263 std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
264 return impl::create{0}(std::move(options));
268 /// Emit the declarations for each of the pass options.
269 static void emitPassOptionDecls(const Pass
&pass
, raw_ostream
&os
) {
270 for (const PassOption
&opt
: pass
.getOptions()) {
271 os
.indent(2) << "::mlir::Pass::"
272 << (opt
.isListOption() ? "ListOption" : "Option");
274 os
<< formatv(R
"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
275 opt
.getType(), opt
.getCppVariableName(), opt
.getArgument(),
276 opt
.getDescription());
277 if (std::optional
<StringRef
> defaultVal
= opt
.getDefaultValue())
278 os
<< ", ::llvm::cl::init(" << defaultVal
<< ")";
279 if (std::optional
<StringRef
> additionalFlags
= opt
.getAdditionalFlags())
280 os
<< ", " << *additionalFlags
;
285 /// Emit the declarations for each of the pass statistics.
286 static void emitPassStatisticDecls(const Pass
&pass
, raw_ostream
&os
) {
287 for (const PassStatistic
&stat
: pass
.getStatistics()) {
288 os
<< formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
289 stat
.getCppVariableName(), stat
.getName(),
290 stat
.getDescription());
294 /// Emit the code to be used in the implementation of the pass.
295 static void emitPassDefs(const Pass
&pass
, raw_ostream
&os
) {
296 StringRef passName
= pass
.getDef()->getName();
297 std::string enableVarName
= "GEN_PASS_DEF_" + passName
.upper();
298 bool emitDefaultConstructors
= pass
.getConstructor().empty();
299 bool emitDefaultConstructorWithOptions
= !pass
.getOptions().empty();
301 os
<< "#ifdef " << enableVarName
<< "\n";
303 if (emitDefaultConstructors
) {
304 os
<< formatv(friendDefaultConstructorDeclTemplate
, passName
);
306 if (emitDefaultConstructorWithOptions
)
307 os
<< formatv(friendDefaultConstructorWithOptionsDeclTemplate
, passName
);
310 std::string dependentDialectRegistrations
;
312 llvm::raw_string_ostream
dialectsOs(dependentDialectRegistrations
);
314 pass
.getDependentDialects(), dialectsOs
,
315 [&](StringRef dependentDialect
) {
316 dialectsOs
<< formatv(dialectRegistrationTemplate
, dependentDialect
);
321 os
<< "namespace impl {\n";
322 os
<< formatv(baseClassBegin
, passName
, pass
.getBaseClass(),
323 pass
.getArgument(), pass
.getSummary(),
324 dependentDialectRegistrations
);
326 if (ArrayRef
<PassOption
> options
= pass
.getOptions(); !options
.empty()) {
327 os
.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n",
330 for (const PassOption
&opt
: pass
.getOptions())
331 os
.indent(4) << formatv("{0} = std::move(options.{0});\n",
332 opt
.getCppVariableName());
334 os
.indent(2) << "}\n";
338 os
<< "protected:\n";
339 emitPassOptionDecls(pass
, os
);
340 emitPassStatisticDecls(pass
, os
);
345 if (emitDefaultConstructors
) {
346 os
<< formatv(friendDefaultConstructorDefTemplate
, passName
);
348 if (!pass
.getOptions().empty())
349 os
<< formatv(friendDefaultConstructorWithOptionsDefTemplate
, passName
);
353 os
<< "} // namespace impl\n";
355 if (emitDefaultConstructors
) {
356 os
<< formatv(defaultConstructorDefTemplate
, passName
);
358 if (emitDefaultConstructorWithOptions
)
359 os
<< formatv(defaultConstructorWithOptionsDefTemplate
, passName
);
362 os
<< "#undef " << enableVarName
<< "\n";
363 os
<< "#endif // " << enableVarName
<< "\n";
366 static void emitPass(const Pass
&pass
, raw_ostream
&os
) {
367 StringRef passName
= pass
.getDef()->getName();
368 os
<< formatv(passHeader
, passName
);
370 emitPassDecls(pass
, os
);
371 emitPassDefs(pass
, os
);
374 // TODO: Drop old pass declarations.
375 // The old pass base class is being kept until all the passes have switched to
376 // the new decls/defs design.
377 const char *const oldPassDeclBegin
= R
"(
378 template <typename DerivedT>
379 class {0}Base : public {1} {
381 using Base = {0}Base;
383 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
384 {0}Base(const {0}Base &other) : {1}(other) {{}
385 {0}Base& operator=(const {0}Base &) = delete;
386 {0}Base({0}Base &&) = delete;
387 {0}Base& operator=({0}Base &&) = delete;
388 ~{0}Base() = default;
390 /// Returns the command-line argument attached to this pass.
391 static constexpr ::llvm::StringLiteral getArgumentName() {
392 return ::llvm::StringLiteral("{2}");
394 ::llvm::StringRef getArgument() const override { return "{2}"; }
396 ::llvm::StringRef getDescription() const override { return "{3}"; }
398 /// Returns the derived pass name.
399 static constexpr ::llvm::StringLiteral getPassName() {
400 return ::llvm::StringLiteral("{0}");
402 ::llvm::StringRef getName() const override { return "{0}"; }
404 /// Support isa/dyn_cast functionality for the derived pass class.
405 static bool classof(const ::mlir::Pass *pass) {{
406 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
409 /// A clone method to create a copy of this pass.
410 std::unique_ptr<::mlir::Pass> clonePass() const override {{
411 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
414 /// Register the dialects that must be loaded in the context before this pass.
415 void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
419 /// Explicitly declare the TypeID for this class. We declare an explicit private
420 /// instantiation because Pass classes should only be visible by the current
422 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
427 // TODO: Drop old pass declarations.
428 /// Emit a backward-compatible declaration of the pass base class.
429 static void emitOldPassDecl(const Pass
&pass
, raw_ostream
&os
) {
430 StringRef defName
= pass
.getDef()->getName();
431 std::string dependentDialectRegistrations
;
433 llvm::raw_string_ostream
dialectsOs(dependentDialectRegistrations
);
435 pass
.getDependentDialects(), dialectsOs
,
436 [&](StringRef dependentDialect
) {
437 dialectsOs
<< formatv(dialectRegistrationTemplate
, dependentDialect
);
441 os
<< formatv(oldPassDeclBegin
, defName
, pass
.getBaseClass(),
442 pass
.getArgument(), pass
.getSummary(),
443 dependentDialectRegistrations
);
444 emitPassOptionDecls(pass
, os
);
445 emitPassStatisticDecls(pass
, os
);
449 static void emitPasses(const RecordKeeper
&records
, raw_ostream
&os
) {
450 std::vector
<Pass
> passes
= getPasses(records
);
451 os
<< "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
454 os
<< "#ifdef GEN_PASS_DECL\n";
455 os
<< "// Generate declarations for all passes.\n";
456 for (const Pass
&pass
: passes
)
457 os
<< "#define " << getPassDeclVarName(pass
) << "\n";
458 os
<< "#undef GEN_PASS_DECL\n";
459 os
<< "#endif // GEN_PASS_DECL\n";
461 for (const Pass
&pass
: passes
)
464 emitRegistrations(passes
, os
);
466 // TODO: Drop old pass declarations.
467 // Emit the old code until all the passes have switched to the new design.
468 os
<< "// Deprecated. Please use the new per-pass macros.\n";
469 os
<< "#ifdef GEN_PASS_CLASSES\n";
470 for (const Pass
&pass
: passes
)
471 emitOldPassDecl(pass
, os
);
472 os
<< "#undef GEN_PASS_CLASSES\n";
473 os
<< "#endif // GEN_PASS_CLASSES\n";
476 static mlir::GenRegistration
477 genPassDecls("gen-pass-decls", "Generate pass declarations",
478 [](const RecordKeeper
&records
, raw_ostream
&os
) {
479 emitPasses(records
, os
);