//===- Pass.cpp - MLIR pass registration generator ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PassGen uses the description of passes to generate base classes for passes
// and command line registration.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Pass.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <optional>

using namespace mlir;
using namespace mlir::tblgen;
25 static llvm::cl::OptionCategory
passGenCat("Options for -gen-pass-decls");
26 static llvm::cl::opt
<std::string
>
27 groupName("name", llvm::cl::desc("The name of this group of passes"),
28 llvm::cl::cat(passGenCat
));
30 /// Extract the list of passes from the TableGen records.
31 static std::vector
<Pass
> getPasses(const llvm::RecordKeeper
&recordKeeper
) {
32 std::vector
<Pass
> passes
;
34 for (const auto *def
: recordKeeper
.getAllDerivedDefinitions("PassBase"))
35 passes
.emplace_back(def
);
/// Banner emitted in front of the generated code for each pass.
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
//===----------------------------------------------------------------------===//
// GEN: Pass registration generation
//===----------------------------------------------------------------------===//

/// The code snippet used to generate a pass registration.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}
)";
73 /// The code snippet used to generate a function to register all passes in a
76 /// {0}: The name of the pass group.
77 const char *const passGroupRegistrationCode
= R
"(
78 //===----------------------------------------------------------------------===//
80 //===----------------------------------------------------------------------===//
82 inline void register{0}Passes() {{
85 /// Emits the definition of the struct to be used to control the pass options.
86 static void emitPassOptionsStruct(const Pass
&pass
, raw_ostream
&os
) {
87 StringRef passName
= pass
.getDef()->getName();
88 ArrayRef
<PassOption
> options
= pass
.getOptions();
90 // Emit the struct only if the pass has at least one option.
94 os
<< llvm::formatv("struct {0}Options {{\n", passName
);
96 for (const PassOption
&opt
: options
) {
97 std::string type
= opt
.getType().str();
99 if (opt
.isListOption())
100 type
= "::llvm::ArrayRef<" + type
+ ">";
102 os
.indent(2) << llvm::formatv("{0} {1}", type
, opt
.getCppVariableName());
104 if (std::optional
<StringRef
> defaultVal
= opt
.getDefaultValue())
105 os
<< " = " << defaultVal
;
113 static std::string
getPassDeclVarName(const Pass
&pass
) {
114 return "GEN_PASS_DECL_" + pass
.getDef()->getName().upper();
117 /// Emit the code to be included in the public header of the pass.
118 static void emitPassDecls(const Pass
&pass
, raw_ostream
&os
) {
119 StringRef passName
= pass
.getDef()->getName();
120 std::string enableVarName
= getPassDeclVarName(pass
);
122 os
<< "#ifdef " << enableVarName
<< "\n";
123 emitPassOptionsStruct(pass
, os
);
125 if (StringRef constructor
= pass
.getConstructor(); constructor
.empty()) {
126 // Default constructor declaration.
127 os
<< "std::unique_ptr<::mlir::Pass> create" << passName
<< "();\n";
129 // Declaration of the constructor with options.
130 if (ArrayRef
<PassOption
> options
= pass
.getOptions(); !options
.empty())
131 os
<< llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const "
132 "{0}Options &options);\n",
136 os
<< "#undef " << enableVarName
<< "\n";
137 os
<< "#endif // " << enableVarName
<< "\n";
140 /// Emit the code for registering each of the given passes with the global
142 static void emitRegistrations(llvm::ArrayRef
<Pass
> passes
, raw_ostream
&os
) {
143 os
<< "#ifdef GEN_PASS_REGISTRATION\n";
145 for (const Pass
&pass
: passes
) {
146 std::string constructorCall
;
147 if (StringRef constructor
= pass
.getConstructor(); !constructor
.empty())
148 constructorCall
= constructor
.str();
151 llvm::formatv("create{0}()", pass
.getDef()->getName()).str();
153 os
<< llvm::formatv(passRegistrationCode
, pass
.getDef()->getName(),
157 os
<< llvm::formatv(passGroupRegistrationCode
, groupName
);
159 for (const Pass
&pass
: passes
)
160 os
<< " register" << pass
.getDef()->getName() << "();\n";
163 os
<< "#undef GEN_PASS_REGISTRATION\n";
164 os
<< "#endif // GEN_PASS_REGISTRATION\n";
//===----------------------------------------------------------------------===//
// GEN: Pass base class generation
//===----------------------------------------------------------------------===//

/// The code snippet used to generate the start of a pass base class.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary of the pass (returned by getDescription()).
/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override { return "{2}"; }

  ::llvm::StringRef getDescription() const override { return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
/// Registration for a single dependent dialect, to be inserted once per
/// dependent dialect into the generated `getDependentDialects` method.
///
/// {0}: The C++ class name of the dependent dialect.
const char *const dialectRegistrationTemplate = R"(
  registry.insert<{0}>();
)";
/// Declares the default `create{0}()` factory inside namespace `impl`, so that
/// the friend definition emitted in the base class can be named as
/// `impl::create{0}()`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";
/// Declares the default `create{0}(const {0}Options &)` factory inside
/// namespace `impl`, so that the friend definition emitted in the base class
/// can be named as `impl::create{0}(options)`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
} // namespace impl
)";
/// Friend definition, placed inside the generated base class, of the default
/// factory. It constructs the derived pass with its default constructor.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}() {{
    return std::make_unique<DerivedT>();
  }
)";
/// Friend definition, placed inside the generated base class, of the factory
/// that forwards a `{0}Options` struct to the derived pass's constructor.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
    return std::make_unique<DerivedT>(options);
  }
)";
/// Public definition of the default `create{0}()` factory. It forwards to the
/// friend defined inside namespace `impl`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
  return impl::create{0}();
}
)";
/// Public definition of the `create{0}(const {0}Options &)` factory. It
/// forwards to the friend defined inside namespace `impl`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
  return impl::create{0}(options);
}
)";
264 /// Emit the declarations for each of the pass options.
265 static void emitPassOptionDecls(const Pass
&pass
, raw_ostream
&os
) {
266 for (const PassOption
&opt
: pass
.getOptions()) {
267 os
.indent(2) << "::mlir::Pass::"
268 << (opt
.isListOption() ? "ListOption" : "Option");
270 os
<< llvm::formatv(R
"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
271 opt
.getType(), opt
.getCppVariableName(),
272 opt
.getArgument(), opt
.getDescription());
273 if (std::optional
<StringRef
> defaultVal
= opt
.getDefaultValue())
274 os
<< ", ::llvm::cl::init(" << defaultVal
<< ")";
275 if (std::optional
<StringRef
> additionalFlags
= opt
.getAdditionalFlags())
276 os
<< ", " << *additionalFlags
;
281 /// Emit the declarations for each of the pass statistics.
282 static void emitPassStatisticDecls(const Pass
&pass
, raw_ostream
&os
) {
283 for (const PassStatistic
&stat
: pass
.getStatistics()) {
285 " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
286 stat
.getCppVariableName(), stat
.getName(), stat
.getDescription());
290 /// Emit the code to be used in the implementation of the pass.
291 static void emitPassDefs(const Pass
&pass
, raw_ostream
&os
) {
292 StringRef passName
= pass
.getDef()->getName();
293 std::string enableVarName
= "GEN_PASS_DEF_" + passName
.upper();
294 bool emitDefaultConstructors
= pass
.getConstructor().empty();
295 bool emitDefaultConstructorWithOptions
= !pass
.getOptions().empty();
297 os
<< "#ifdef " << enableVarName
<< "\n";
299 if (emitDefaultConstructors
) {
300 os
<< llvm::formatv(friendDefaultConstructorDeclTemplate
, passName
);
302 if (emitDefaultConstructorWithOptions
)
303 os
<< llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate
,
307 std::string dependentDialectRegistrations
;
309 llvm::raw_string_ostream
dialectsOs(dependentDialectRegistrations
);
310 for (StringRef dependentDialect
: pass
.getDependentDialects())
311 dialectsOs
<< llvm::formatv(dialectRegistrationTemplate
,
315 os
<< "namespace impl {\n";
316 os
<< llvm::formatv(baseClassBegin
, passName
, pass
.getBaseClass(),
317 pass
.getArgument(), pass
.getSummary(),
318 dependentDialectRegistrations
);
320 if (ArrayRef
<PassOption
> options
= pass
.getOptions(); !options
.empty()) {
321 os
.indent(2) << llvm::formatv(
322 "{0}Base(const {0}Options &options) : {0}Base() {{\n", passName
);
324 for (const PassOption
&opt
: pass
.getOptions())
325 os
.indent(4) << llvm::formatv("{0} = options.{0};\n",
326 opt
.getCppVariableName());
328 os
.indent(2) << "}\n";
332 os
<< "protected:\n";
333 emitPassOptionDecls(pass
, os
);
334 emitPassStatisticDecls(pass
, os
);
339 if (emitDefaultConstructors
) {
340 os
<< llvm::formatv(friendDefaultConstructorDefTemplate
, passName
);
342 if (!pass
.getOptions().empty())
343 os
<< llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate
,
348 os
<< "} // namespace impl\n";
350 if (emitDefaultConstructors
) {
351 os
<< llvm::formatv(defaultConstructorDefTemplate
, passName
);
353 if (emitDefaultConstructorWithOptions
)
354 os
<< llvm::formatv(defaultConstructorWithOptionsDefTemplate
, passName
);
357 os
<< "#undef " << enableVarName
<< "\n";
358 os
<< "#endif // " << enableVarName
<< "\n";
361 static void emitPass(const Pass
&pass
, raw_ostream
&os
) {
362 StringRef passName
= pass
.getDef()->getName();
363 os
<< llvm::formatv(passHeader
, passName
);
365 emitPassDecls(pass
, os
);
366 emitPassDefs(pass
, os
);
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
/// The code snippet used to generate the start of a deprecated pass base
/// class. It ends by opening a `protected:` section for the option and
/// statistic members.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary of the pass (returned by getDescription()).
/// {4}: The dependent dialects registration.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override { return "{2}"; }

  ::llvm::StringRef getDescription() const override { return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
418 // TODO: Drop old pass declarations.
419 /// Emit a backward-compatible declaration of the pass base class.
420 static void emitOldPassDecl(const Pass
&pass
, raw_ostream
&os
) {
421 StringRef defName
= pass
.getDef()->getName();
422 std::string dependentDialectRegistrations
;
424 llvm::raw_string_ostream
dialectsOs(dependentDialectRegistrations
);
425 for (StringRef dependentDialect
: pass
.getDependentDialects())
426 dialectsOs
<< llvm::formatv(dialectRegistrationTemplate
,
429 os
<< llvm::formatv(oldPassDeclBegin
, defName
, pass
.getBaseClass(),
430 pass
.getArgument(), pass
.getSummary(),
431 dependentDialectRegistrations
);
432 emitPassOptionDecls(pass
, os
);
433 emitPassStatisticDecls(pass
, os
);
437 static void emitPasses(const llvm::RecordKeeper
&recordKeeper
,
439 std::vector
<Pass
> passes
= getPasses(recordKeeper
);
440 os
<< "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
443 os
<< "#ifdef GEN_PASS_DECL\n";
444 os
<< "// Generate declarations for all passes.\n";
445 for (const Pass
&pass
: passes
)
446 os
<< "#define " << getPassDeclVarName(pass
) << "\n";
447 os
<< "#undef GEN_PASS_DECL\n";
448 os
<< "#endif // GEN_PASS_DECL\n";
450 for (const Pass
&pass
: passes
)
453 emitRegistrations(passes
, os
);
455 // TODO: Drop old pass declarations.
456 // Emit the old code until all the passes have switched to the new design.
457 os
<< "// Deprecated. Please use the new per-pass macros.\n";
458 os
<< "#ifdef GEN_PASS_CLASSES\n";
459 for (const Pass
&pass
: passes
)
460 emitOldPassDecl(pass
, os
);
461 os
<< "#undef GEN_PASS_CLASSES\n";
462 os
<< "#endif // GEN_PASS_CLASSES\n";
465 static mlir::GenRegistration
466 genPassDecls("gen-pass-decls", "Generate pass declarations",
467 [](const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
468 emitPasses(records
, os
);