[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / PassGen.cpp
blobde159d144ffbb413d096c5ecacd839c0ea35d29e
//===- PassGen.cpp - MLIR pass registration generator ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PassGen uses the description of passes to generate base classes for passes
10 // and command line registration.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Pass.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
22 using namespace mlir;
23 using namespace mlir::tblgen;
/// Command-line category grouping the options that belong to the
/// `-gen-pass-decls` generator.
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");

/// `-name=<group>`: the pass group name; it is spliced into the name of the
/// generated `register<Group>Passes()` function.
static llvm::cl::opt<std::string>
    groupName("name", llvm::cl::desc("The name of this group of passes"),
              llvm::cl::cat(passGenCat));
30 /// Extract the list of passes from the TableGen records.
31 static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
32 std::vector<Pass> passes;
34 for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
35 passes.emplace_back(def);
37 return passes;
/// Banner emitted in front of the code generated for a single pass.
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
46 //===----------------------------------------------------------------------===//
47 // GEN: Pass registration generation
48 //===----------------------------------------------------------------------===//
/// The code snippet used to generate a pass registration.
///
/// Two functions are emitted per pass: `register<Pass>()` and the legacy
/// spelling `register<Pass>Pass()`. Both register the same constructor call.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}
)";
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the function header is part of this template: the caller appends one
/// `register<Pass>();` statement per pass and then the closing brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
85 /// Emits the definition of the struct to be used to control the pass options.
86 static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
87 StringRef passName = pass.getDef()->getName();
88 ArrayRef<PassOption> options = pass.getOptions();
90 // Emit the struct only if the pass has at least one option.
91 if (options.empty())
92 return;
94 os << llvm::formatv("struct {0}Options {{\n", passName);
96 for (const PassOption &opt : options) {
97 std::string type = opt.getType().str();
99 if (opt.isListOption())
100 type = "::llvm::ArrayRef<" + type + ">";
102 os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
104 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
105 os << " = " << defaultVal;
107 os << ";\n";
110 os << "};\n";
113 static std::string getPassDeclVarName(const Pass &pass) {
114 return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
117 /// Emit the code to be included in the public header of the pass.
118 static void emitPassDecls(const Pass &pass, raw_ostream &os) {
119 StringRef passName = pass.getDef()->getName();
120 std::string enableVarName = getPassDeclVarName(pass);
122 os << "#ifdef " << enableVarName << "\n";
123 emitPassOptionsStruct(pass, os);
125 if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
126 // Default constructor declaration.
127 os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
129 // Declaration of the constructor with options.
130 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
131 os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const "
132 "{0}Options &options);\n",
133 passName);
136 os << "#undef " << enableVarName << "\n";
137 os << "#endif // " << enableVarName << "\n";
140 /// Emit the code for registering each of the given passes with the global
141 /// PassRegistry.
142 static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
143 os << "#ifdef GEN_PASS_REGISTRATION\n";
145 for (const Pass &pass : passes) {
146 std::string constructorCall;
147 if (StringRef constructor = pass.getConstructor(); !constructor.empty())
148 constructorCall = constructor.str();
149 else
150 constructorCall =
151 llvm::formatv("create{0}()", pass.getDef()->getName()).str();
153 os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
154 constructorCall);
157 os << llvm::formatv(passGroupRegistrationCode, groupName);
159 for (const Pass &pass : passes)
160 os << " register" << pass.getDef()->getName() << "();\n";
162 os << "}\n";
163 os << "#undef GEN_PASS_REGISTRATION\n";
164 os << "#endif // GEN_PASS_REGISTRATION\n";
167 //===----------------------------------------------------------------------===//
168 // GEN: Pass base class generation
169 //===----------------------------------------------------------------------===//
/// The code snippet used to generate the start of a pass base class.
///
/// The snippet deliberately leaves the class open: the caller appends the
/// remaining members and the closing `};`.
///
/// Every literal opening brace is written as `{{` so that only `{N}` sequences
/// are treated as replacement fields by `llvm::formatv`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary (description) of the pass.
/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {{
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {{
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override {{ return "{2}"; }

  ::llvm::StringRef getDescription() const override {{ return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {{
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override {{ return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {{
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the generated `getDependentDialects` body.
///
/// {0}: The C++ class name of the dependent dialect.
const char *const dialectRegistrationTemplate = R"(
  registry.insert<{0}>();
)";
/// Forward declaration, inside `namespace impl`, of the generated default
/// constructor function of a pass.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";

/// Forward declaration, inside `namespace impl`, of the generated constructor
/// function that takes the pass options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
} // namespace impl
)";

/// Friend definition, placed in the generated base class, of the default
/// constructor function; it instantiates `DerivedT`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}() {{
    return std::make_unique<DerivedT>();
  }
)";

/// Friend definition, placed in the generated base class, of the constructor
/// function that forwards an options struct to `DerivedT`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
    return std::make_unique<DerivedT>(options);
  }
)";

/// Definition of the public `create<Pass>()` function, which forwards to the
/// `impl::` function of the same name.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
  return impl::create{0}();
}
)";

/// Definition of the public `create<Pass>(options)` function, which forwards
/// to the `impl::` function of the same name.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
  return impl::create{0}(options);
}
)";
264 /// Emit the declarations for each of the pass options.
265 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
266 for (const PassOption &opt : pass.getOptions()) {
267 os.indent(2) << "::mlir::Pass::"
268 << (opt.isListOption() ? "ListOption" : "Option");
270 os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
271 opt.getType(), opt.getCppVariableName(),
272 opt.getArgument(), opt.getDescription());
273 if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
274 os << ", ::llvm::cl::init(" << defaultVal << ")";
275 if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
276 os << ", " << *additionalFlags;
277 os << "};\n";
281 /// Emit the declarations for each of the pass statistics.
282 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
283 for (const PassStatistic &stat : pass.getStatistics()) {
284 os << llvm::formatv(
285 " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
286 stat.getCppVariableName(), stat.getName(), stat.getDescription());
290 /// Emit the code to be used in the implementation of the pass.
291 static void emitPassDefs(const Pass &pass, raw_ostream &os) {
292 StringRef passName = pass.getDef()->getName();
293 std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
294 bool emitDefaultConstructors = pass.getConstructor().empty();
295 bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
297 os << "#ifdef " << enableVarName << "\n";
299 if (emitDefaultConstructors) {
300 os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName);
302 if (emitDefaultConstructorWithOptions)
303 os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate,
304 passName);
307 std::string dependentDialectRegistrations;
309 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
310 for (StringRef dependentDialect : pass.getDependentDialects())
311 dialectsOs << llvm::formatv(dialectRegistrationTemplate,
312 dependentDialect);
315 os << "namespace impl {\n";
316 os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(),
317 pass.getArgument(), pass.getSummary(),
318 dependentDialectRegistrations);
320 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
321 os.indent(2) << llvm::formatv(
322 "{0}Base(const {0}Options &options) : {0}Base() {{\n", passName);
324 for (const PassOption &opt : pass.getOptions())
325 os.indent(4) << llvm::formatv("{0} = options.{0};\n",
326 opt.getCppVariableName());
328 os.indent(2) << "}\n";
331 // Protected content
332 os << "protected:\n";
333 emitPassOptionDecls(pass, os);
334 emitPassStatisticDecls(pass, os);
336 // Private content
337 os << "private:\n";
339 if (emitDefaultConstructors) {
340 os << llvm::formatv(friendDefaultConstructorDefTemplate, passName);
342 if (!pass.getOptions().empty())
343 os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate,
344 passName);
347 os << "};\n";
348 os << "} // namespace impl\n";
350 if (emitDefaultConstructors) {
351 os << llvm::formatv(defaultConstructorDefTemplate, passName);
353 if (emitDefaultConstructorWithOptions)
354 os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName);
357 os << "#undef " << enableVarName << "\n";
358 os << "#endif // " << enableVarName << "\n";
361 static void emitPass(const Pass &pass, raw_ostream &os) {
362 StringRef passName = pass.getDef()->getName();
363 os << llvm::formatv(passHeader, passName);
365 emitPassDecls(pass, os);
366 emitPassDefs(pass, os);
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
//
// The snippet ends right after `protected:`; the caller appends the option
// and statistic members and the closing `};`.
//
// Every literal opening brace is written as `{{` so that only `{N}` sequences
// are treated as replacement fields by `llvm::formatv`.
//
// {0}: The def name of the pass record.
// {1}: The base class for the pass.
// {2}: The command line argument for the pass.
// {3}: The summary (description) of the pass.
// {4}: The dependent dialects registration.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {{
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {{
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override {{ return "{2}"; }

  ::llvm::StringRef getDescription() const override {{ return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {{
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override {{ return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {{
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
418 // TODO: Drop old pass declarations.
419 /// Emit a backward-compatible declaration of the pass base class.
420 static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
421 StringRef defName = pass.getDef()->getName();
422 std::string dependentDialectRegistrations;
424 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
425 for (StringRef dependentDialect : pass.getDependentDialects())
426 dialectsOs << llvm::formatv(dialectRegistrationTemplate,
427 dependentDialect);
429 os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
430 pass.getArgument(), pass.getSummary(),
431 dependentDialectRegistrations);
432 emitPassOptionDecls(pass, os);
433 emitPassStatisticDecls(pass, os);
434 os << "};\n";
437 static void emitPasses(const llvm::RecordKeeper &recordKeeper,
438 raw_ostream &os) {
439 std::vector<Pass> passes = getPasses(recordKeeper);
440 os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
442 os << "\n";
443 os << "#ifdef GEN_PASS_DECL\n";
444 os << "// Generate declarations for all passes.\n";
445 for (const Pass &pass : passes)
446 os << "#define " << getPassDeclVarName(pass) << "\n";
447 os << "#undef GEN_PASS_DECL\n";
448 os << "#endif // GEN_PASS_DECL\n";
450 for (const Pass &pass : passes)
451 emitPass(pass, os);
453 emitRegistrations(passes, os);
455 // TODO: Drop old pass declarations.
456 // Emit the old code until all the passes have switched to the new design.
457 os << "// Deprecated. Please use the new per-pass macros.\n";
458 os << "#ifdef GEN_PASS_CLASSES\n";
459 for (const Pass &pass : passes)
460 emitOldPassDecl(pass, os);
461 os << "#undef GEN_PASS_CLASSES\n";
462 os << "#endif // GEN_PASS_CLASSES\n";
465 static mlir::GenRegistration
466 genPassDecls("gen-pass-decls", "Generate pass declarations",
467 [](const llvm::RecordKeeper &records, raw_ostream &os) {
468 emitPasses(records, os);
469 return false;