Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Pass / PassRegistry.cpp
blob fe842755958418e930abe7c8d18d689b02838a9c
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Pass/PassRegistry.h"
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
21 #include <optional>
22 #include <utility>
24 using namespace mlir;
25 using namespace detail;
27 /// Static mapping of all of the registered passes.
28 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
30 /// A mapping of the above pass registry entries to the corresponding TypeID
31 /// of the pass that they generate.
32 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
34 /// Static mapping of all of the registered pass pipelines.
35 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
36 passPipelineRegistry;
38 /// Utility to create a default registry function from a pass instance.
39 static PassRegistryFunction
40 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
41 return [=](OpPassManager &pm, StringRef options,
42 function_ref<LogicalResult(const Twine &)> errorHandler) {
43 std::unique_ptr<Pass> pass = allocator();
44 LogicalResult result = pass->initializeOptions(options, errorHandler);
46 std::optional<StringRef> pmOpName = pm.getOpName();
47 std::optional<StringRef> passOpName = pass->getOpName();
48 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
49 passOpName && *pmOpName != *passOpName) {
50 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
51 "' restricted to '" + *pass->getOpName() +
52 "' on a PassManager intended to run on '" +
53 pm.getOpAnchorName() + "', did you intend to nest?");
55 pm.addPass(std::move(pass));
56 return result;
60 /// Utility to print the help string for a specific option.
61 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
62 size_t descIndent, bool isTopLevel) {
63 size_t numSpaces = descIndent - indent - 4;
64 llvm::outs().indent(indent)
65 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
68 //===----------------------------------------------------------------------===//
69 // PassRegistry
70 //===----------------------------------------------------------------------===//
72 /// Prints the passes that were previously registered and stored in passRegistry
73 void mlir::printRegisteredPasses() {
74 size_t maxWidth = 0;
75 for (auto &entry : *passRegistry)
76 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
78 // Functor used to print the ordered entries of a registration map.
79 auto printOrderedEntries = [&](StringRef header, auto &map) {
80 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
81 for (auto &kv : map)
82 orderedEntries.push_back(&kv.second);
83 llvm::array_pod_sort(
84 orderedEntries.begin(), orderedEntries.end(),
85 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
86 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
87 });
89 llvm::outs().indent(0) << header << ":\n";
90 for (PassRegistryEntry *entry : orderedEntries)
91 entry->printHelpStr(/*indent=*/2, maxWidth);
94 // Print the available passes.
95 printOrderedEntries("Passes", *passRegistry);
98 /// Print the help information for this pass. This includes the argument,
99 /// description, and any pass options. `descIndent` is the indent that the
100 /// descriptions should be aligned.
101 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
102 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
103 /*isTopLevel=*/true);
104 // If this entry has options, print the help for those as well.
105 optHandler([=](const PassOptions &options) {
106 options.printHelp(indent, descIndent);
110 /// Return the maximum width required when printing the options of this
111 /// entry.
112 size_t PassRegistryEntry::getOptionWidth() const {
113 size_t maxLen = 0;
114 optHandler([&](const PassOptions &options) mutable {
115 maxLen = options.getOptionWidth() + 2;
117 return maxLen;
120 //===----------------------------------------------------------------------===//
121 // PassPipelineInfo
122 //===----------------------------------------------------------------------===//
124 void mlir::registerPassPipeline(
125 StringRef arg, StringRef description, const PassRegistryFunction &function,
126 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
127 PassPipelineInfo pipelineInfo(arg, description, function,
128 std::move(optHandler));
129 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
130 #ifndef NDEBUG
131 if (!inserted)
132 report_fatal_error("Pass pipeline " + arg + " registered multiple times");
133 #endif
134 (void)inserted;
137 //===----------------------------------------------------------------------===//
138 // PassInfo
139 //===----------------------------------------------------------------------===//
141 PassInfo::PassInfo(StringRef arg, StringRef description,
142 const PassAllocatorFunction &allocator)
143 : PassRegistryEntry(
144 arg, description, buildDefaultRegistryFn(allocator),
145 // Use a temporary pass to provide an options instance.
146 [=](function_ref<void(const PassOptions &)> optHandler) {
147 optHandler(allocator()->passOptions);
148 }) {}
150 void mlir::registerPass(const PassAllocatorFunction &function) {
151 std::unique_ptr<Pass> pass = function();
152 StringRef arg = pass->getArgument();
153 if (arg.empty())
154 llvm::report_fatal_error(llvm::Twine("Trying to register '") +
155 pass->getName() +
156 "' pass that does not override `getArgument()`");
157 StringRef description = pass->getDescription();
158 PassInfo passInfo(arg, description, function);
159 passRegistry->try_emplace(arg, passInfo);
161 // Verify that the registered pass has the same ID as any registered to this
162 // arg before it.
163 TypeID entryTypeID = pass->getTypeID();
164 auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
165 if (it->second != entryTypeID)
166 llvm::report_fatal_error(
167 "pass allocator creates a different pass than previously "
168 "registered for pass " +
169 arg);
172 /// Returns the pass info for the specified pass argument or null if unknown.
173 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
174 auto it = passRegistry->find(passArg);
175 return it == passRegistry->end() ? nullptr : &it->second;
178 /// Returns the pass pipeline info for the specified pass pipeline argument or
179 /// null if unknown.
180 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
181 auto it = passPipelineRegistry->find(pipelineArg);
182 return it == passPipelineRegistry->end() ? nullptr : &it->second;
185 //===----------------------------------------------------------------------===//
186 // PassOptions
187 //===----------------------------------------------------------------------===//
189 /// Extract an argument from 'options' and update it to point after the arg.
190 /// Returns the cleaned argument string.
191 static StringRef extractArgAndUpdateOptions(StringRef &options,
192 size_t argSize) {
193 StringRef str = options.take_front(argSize).trim();
194 options = options.drop_front(argSize).ltrim();
196 // Early exit if there's no escape sequence.
197 if (str.size() <= 2)
198 return str;
200 const auto escapePairs = {std::make_pair('\'', '\''),
201 std::make_pair('"', '"'), std::make_pair('{', '}')};
202 for (const auto &escape : escapePairs) {
203 if (str.front() == escape.first && str.back() == escape.second) {
204 // Drop the escape characters and trim.
205 str = str.drop_front().drop_back().trim();
206 // Don't process additional escape sequences.
207 break;
211 return str;
214 LogicalResult detail::pass_options::parseCommaSeparatedList(
215 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
216 function_ref<LogicalResult(StringRef)> elementParseFn) {
217 // Functor used for finding a character in a string, and skipping over
218 // various "range" characters.
219 llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
220 [&](StringRef str, size_t index, char c) -> size_t {
221 for (size_t i = index, e = str.size(); i < e; ++i) {
222 if (str[i] == c)
223 return i;
224 // Check for various range characters.
225 if (str[i] == '{')
226 i = findChar(str, i + 1, '}');
227 else if (str[i] == '(')
228 i = findChar(str, i + 1, ')');
229 else if (str[i] == '[')
230 i = findChar(str, i + 1, ']');
231 else if (str[i] == '\"')
232 i = str.find_first_of('\"', i + 1);
233 else if (str[i] == '\'')
234 i = str.find_first_of('\'', i + 1);
236 return StringRef::npos;
239 size_t nextElePos = findChar(optionStr, 0, ',');
240 while (nextElePos != StringRef::npos) {
241 // Process the portion before the comma.
242 if (failed(
243 elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
244 return failure();
246 // Drop the leading ','
247 optionStr = optionStr.drop_front();
248 nextElePos = findChar(optionStr, 0, ',');
250 return elementParseFn(
251 extractArgAndUpdateOptions(optionStr, optionStr.size()));
254 /// Out of line virtual function to provide home for the class.
255 void detail::PassOptions::OptionBase::anchor() {}
257 /// Copy the option values from 'other'.
258 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
259 assert(options.size() == other.options.size());
260 if (options.empty())
261 return;
262 for (auto optionsIt : llvm::zip(options, other.options))
263 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
266 /// Parse in the next argument from the given options string. Returns a tuple
267 /// containing [the key of the option, the value of the option, updated
268 /// `options` string pointing after the parsed option].
269 static std::tuple<StringRef, StringRef, StringRef>
270 parseNextArg(StringRef options) {
271 // Try to process the given punctuation, properly escaping any contained
272 // characters.
273 auto tryProcessPunct = [&](size_t &currentPos, char punct) {
274 if (options[currentPos] != punct)
275 return false;
276 size_t nextIt = options.find_first_of(punct, currentPos + 1);
277 if (nextIt != StringRef::npos)
278 currentPos = nextIt;
279 return true;
282 // Parse the argument name of the option.
283 StringRef argName;
284 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
285 // Check for the end of the full option.
286 if (argEndIt == optionsE || options[argEndIt] == ' ') {
287 argName = extractArgAndUpdateOptions(options, argEndIt);
288 return std::make_tuple(argName, StringRef(), options);
291 // Check for the end of the name and the start of the value.
292 if (options[argEndIt] == '=') {
293 argName = extractArgAndUpdateOptions(options, argEndIt);
294 options = options.drop_front();
295 break;
299 // Parse the value of the option.
300 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
301 // Handle the end of the options string.
302 if (argEndIt == optionsE || options[argEndIt] == ' ') {
303 StringRef value = extractArgAndUpdateOptions(options, argEndIt);
304 return std::make_tuple(argName, value, options);
307 // Skip over escaped sequences.
308 char c = options[argEndIt];
309 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
310 continue;
311 // '{...}' is used to specify options to passes, properly escape it so
312 // that we don't accidentally split any nested options.
313 if (c == '{') {
314 size_t braceCount = 1;
315 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
316 // Allow nested punctuation.
317 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
318 continue;
319 if (options[argEndIt] == '{')
320 ++braceCount;
321 else if (options[argEndIt] == '}' && --braceCount == 0)
322 break;
324 // Account for the increment at the top of the loop.
325 --argEndIt;
328 llvm_unreachable("unexpected control flow in pass option parsing");
331 LogicalResult detail::PassOptions::parseFromString(StringRef options,
332 raw_ostream &errorStream) {
333 // NOTE: `options` is modified in place to always refer to the unprocessed
334 // part of the string.
335 while (!options.empty()) {
336 StringRef key, value;
337 std::tie(key, value, options) = parseNextArg(options);
338 if (key.empty())
339 continue;
341 auto it = OptionsMap.find(key);
342 if (it == OptionsMap.end()) {
343 errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
344 return failure();
346 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
347 return failure();
350 return success();
353 /// Print the options held by this struct in a form that can be parsed via
354 /// 'parseFromString'.
355 void detail::PassOptions::print(raw_ostream &os) const {
356 // If there are no options, there is nothing left to do.
357 if (OptionsMap.empty())
358 return;
360 // Sort the options to make the ordering deterministic.
361 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
362 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
363 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
365 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
367 // Interleave the options with ' '.
368 os << '{';
369 llvm::interleave(
370 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
371 os << '}';
374 /// Print the help string for the options held by this struct. `descIndent` is
375 /// the indent within the stream that the descriptions should be aligned.
376 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
377 // Sort the options to make the ordering deterministic.
378 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
379 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
380 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
382 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
383 for (OptionBase *option : orderedOps) {
384 // TODO: printOptionInfo assumes a specific indent and will
385 // print options with values with incorrect indentation. We should add
386 // support to llvm::cl::Option for passing in a base indent to use when
387 // printing.
388 llvm::outs().indent(indent);
389 option->getOption()->printOptionInfo(descIndent - indent);
393 /// Return the maximum width required when printing the help string.
394 size_t detail::PassOptions::getOptionWidth() const {
395 size_t max = 0;
396 for (auto *option : options)
397 max = std::max(max, option->getOption()->getOptionWidth());
398 return max;
401 //===----------------------------------------------------------------------===//
402 // MLIR Options
403 //===----------------------------------------------------------------------===//
405 //===----------------------------------------------------------------------===//
406 // OpPassManager: OptionValue
408 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
409 llvm::cl::OptionValue<OpPassManager>::OptionValue(
410 const mlir::OpPassManager &value) {
411 setValue(value);
413 llvm::cl::OptionValue<OpPassManager>::OptionValue(
414 const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) {
415 if (rhs.hasValue())
416 setValue(rhs.getValue());
418 llvm::cl::OptionValue<OpPassManager> &
419 llvm::cl::OptionValue<OpPassManager>::operator=(
420 const mlir::OpPassManager &rhs) {
421 setValue(rhs);
422 return *this;
425 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
427 void llvm::cl::OptionValue<OpPassManager>::setValue(
428 const OpPassManager &newValue) {
429 if (hasValue())
430 *value = newValue;
431 else
432 value = std::make_unique<mlir::OpPassManager>(newValue);
434 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
435 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
436 assert(succeeded(pipeline) && "invalid pass pipeline");
437 setValue(*pipeline);
440 bool llvm::cl::OptionValue<OpPassManager>::compare(
441 const mlir::OpPassManager &rhs) const {
442 std::string lhsStr, rhsStr;
444 raw_string_ostream lhsStream(lhsStr);
445 value->printAsTextualPipeline(lhsStream);
447 raw_string_ostream rhsStream(rhsStr);
448 rhs.printAsTextualPipeline(rhsStream);
451 // Use the textual format for pipeline comparisons.
452 return lhsStr == rhsStr;
455 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
457 //===----------------------------------------------------------------------===//
458 // OpPassManager: Parser
460 namespace llvm {
461 namespace cl {
462 template class basic_parser<OpPassManager>;
463 } // namespace cl
464 } // namespace llvm
466 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
467 ParsedPassManager &value) {
468 FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
469 if (failed(pipeline))
470 return true;
471 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
472 return false;
475 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
476 const OpPassManager &value) {
477 value.printAsTextualPipeline(os);
480 void llvm::cl::parser<OpPassManager>::printOptionDiff(
481 const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
482 size_t globalWidth) const {
483 printOptionName(opt, globalWidth);
484 outs() << "= ";
485 pm.printAsTextualPipeline(outs());
487 if (defaultValue.hasValue()) {
488 outs().indent(2) << " (default: ";
489 defaultValue.getValue().printAsTextualPipeline(outs());
490 outs() << ")";
492 outs() << "\n";
495 void llvm::cl::parser<OpPassManager>::anchor() {}
497 llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
498 default;
499 llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
500 ParsedPassManager &&) = default;
501 llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
502 default;
504 //===----------------------------------------------------------------------===//
505 // TextualPassPipeline Parser
506 //===----------------------------------------------------------------------===//
508 namespace {
509 /// This class represents a textual description of a pass pipeline.
510 class TextualPipeline {
511 public:
512 /// Try to initialize this pipeline with the given pipeline text.
513 /// `errorStream` is the output stream to emit errors to.
514 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
516 /// Add the internal pipeline elements to the provided pass manager.
517 LogicalResult
518 addToPipeline(OpPassManager &pm,
519 function_ref<LogicalResult(const Twine &)> errorHandler) const;
521 private:
522 /// A functor used to emit errors found during pipeline handling. The first
523 /// parameter corresponds to the raw location within the pipeline string. This
524 /// should always return failure.
525 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
527 /// A struct to capture parsed pass pipeline names.
529 /// A pipeline is defined as a series of names, each of which may in itself
530 /// recursively contain a nested pipeline. A name is either the name of a pass
531 /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
532 /// the name is the name of a pass, the InnerPipeline is empty, since passes
533 /// cannot contain inner pipelines.
534 struct PipelineElement {
535 PipelineElement(StringRef name) : name(name) {}
537 StringRef name;
538 StringRef options;
539 const PassRegistryEntry *registryEntry = nullptr;
540 std::vector<PipelineElement> innerPipeline;
543 /// Parse the given pipeline text into the internal pipeline vector. This
544 /// function only parses the structure of the pipeline, and does not resolve
545 /// its elements.
546 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
548 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
549 /// the corresponding registry entry.
550 LogicalResult
551 resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
552 ErrorHandlerT errorHandler);
554 /// Resolve a single element of the pipeline.
555 LogicalResult resolvePipelineElement(PipelineElement &element,
556 ErrorHandlerT errorHandler);
558 /// Add the given pipeline elements to the provided pass manager.
559 LogicalResult
560 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
561 function_ref<LogicalResult(const Twine &)> errorHandler) const;
563 std::vector<PipelineElement> pipeline;
566 } // namespace
568 /// Try to initialize this pipeline with the given pipeline text. An option is
569 /// given to enable accurate error reporting.
570 LogicalResult TextualPipeline::initialize(StringRef text,
571 raw_ostream &errorStream) {
572 if (text.empty())
573 return success();
575 // Build a source manager to use for error reporting.
576 llvm::SourceMgr pipelineMgr;
577 pipelineMgr.AddNewSourceBuffer(
578 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
579 /*RequiresNullTerminator=*/false),
580 SMLoc());
581 auto errorHandler = [&](const char *rawLoc, Twine msg) {
582 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
583 llvm::SourceMgr::DK_Error, msg);
584 return failure();
587 // Parse the provided pipeline string.
588 if (failed(parsePipelineText(text, errorHandler)))
589 return failure();
590 return resolvePipelineElements(pipeline, errorHandler);
593 /// Add the internal pipeline elements to the provided pass manager.
594 LogicalResult TextualPipeline::addToPipeline(
595 OpPassManager &pm,
596 function_ref<LogicalResult(const Twine &)> errorHandler) const {
597 // Temporarily disable implicit nesting while we append to the pipeline. We
598 // want the created pipeline to exactly match the parsed text pipeline, so
599 // it's preferrable to just error out if implicit nesting would be required.
600 OpPassManager::Nesting nesting = pm.getNesting();
601 pm.setNesting(OpPassManager::Nesting::Explicit);
602 auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
604 return addToPipeline(pipeline, pm, errorHandler);
607 /// Parse the given pipeline text into the internal pipeline vector. This
608 /// function only parses the structure of the pipeline, and does not resolve
609 /// its elements.
610 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
611 ErrorHandlerT errorHandler) {
612 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
613 for (;;) {
614 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
615 size_t pos = text.find_first_of(",(){");
616 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
618 // If we have a single terminating name, we're done.
619 if (pos == StringRef::npos)
620 break;
622 text = text.substr(pos);
623 char sep = text[0];
625 // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
626 if (sep == '{') {
627 text = text.substr(1);
629 // Skip over everything until the closing '}' and store as options.
630 size_t close = StringRef::npos;
631 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
632 if (text[i] == '{') {
633 ++braceCount;
634 continue;
636 if (text[i] == '}' && --braceCount == 0) {
637 close = i;
638 break;
642 // Check to see if a closing options brace was found.
643 if (close == StringRef::npos) {
644 return errorHandler(
645 /*rawLoc=*/text.data() - 1,
646 "missing closing '}' while processing pass options");
648 pipeline.back().options = text.substr(0, close);
649 text = text.substr(close + 1);
651 // Consume space characters that an user might add for readability.
652 text = text.ltrim();
654 // Skip checking for '(' because nested pipelines cannot have options.
655 } else if (sep == '(') {
656 text = text.substr(1);
658 // Push the inner pipeline onto the stack to continue processing.
659 pipelineStack.push_back(&pipeline.back().innerPipeline);
660 continue;
663 // When handling the close parenthesis, we greedily consume them to avoid
664 // empty strings in the pipeline.
665 while (text.consume_front(")")) {
666 // If we try to pop the outer pipeline we have unbalanced parentheses.
667 if (pipelineStack.size() == 1)
668 return errorHandler(/*rawLoc=*/text.data() - 1,
669 "encountered extra closing ')' creating unbalanced "
670 "parentheses while parsing pipeline");
672 pipelineStack.pop_back();
673 // Consume space characters that an user might add for readability.
674 text = text.ltrim();
677 // Check if we've finished parsing.
678 if (text.empty())
679 break;
681 // Otherwise, the end of an inner pipeline always has to be followed by
682 // a comma, and then we can continue.
683 if (!text.consume_front(","))
684 return errorHandler(text.data(), "expected ',' after parsing pipeline");
687 // Check for unbalanced parentheses.
688 if (pipelineStack.size() > 1)
689 return errorHandler(
690 text.data(),
691 "encountered unbalanced parentheses while parsing pipeline");
693 assert(pipelineStack.back() == &pipeline &&
694 "wrong pipeline at the bottom of the stack");
695 return success();
698 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
699 /// the corresponding registry entry.
700 LogicalResult TextualPipeline::resolvePipelineElements(
701 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
702 for (auto &elt : elements)
703 if (failed(resolvePipelineElement(elt, errorHandler)))
704 return failure();
705 return success();
708 /// Resolve a single element of the pipeline.
709 LogicalResult
710 TextualPipeline::resolvePipelineElement(PipelineElement &element,
711 ErrorHandlerT errorHandler) {
712 // If the inner pipeline of this element is not empty, this is an operation
713 // pipeline.
714 if (!element.innerPipeline.empty())
715 return resolvePipelineElements(element.innerPipeline, errorHandler);
717 // Otherwise, this must be a pass or pass pipeline.
718 // Check to see if a pipeline was registered with this name.
719 if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
720 return success();
722 // If not, then this must be a specific pass name.
723 if ((element.registryEntry = PassInfo::lookup(element.name)))
724 return success();
726 // Emit an error for the unknown pass.
727 auto *rawLoc = element.name.data();
728 return errorHandler(rawLoc, "'" + element.name +
729 "' does not refer to a "
730 "registered pass or pass pipeline");
733 /// Add the given pipeline elements to the provided pass manager.
734 LogicalResult TextualPipeline::addToPipeline(
735 ArrayRef<PipelineElement> elements, OpPassManager &pm,
736 function_ref<LogicalResult(const Twine &)> errorHandler) const {
737 for (auto &elt : elements) {
738 if (elt.registryEntry) {
739 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
740 errorHandler))) {
741 return errorHandler("failed to add `" + elt.name + "` with options `" +
742 elt.options + "`");
744 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
745 errorHandler))) {
746 return errorHandler("failed to add `" + elt.name + "` with options `" +
747 elt.options + "` to inner pipeline");
750 return success();
753 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
754 raw_ostream &errorStream) {
755 TextualPipeline pipelineParser;
756 if (failed(pipelineParser.initialize(pipeline, errorStream)))
757 return failure();
758 auto errorHandler = [&](Twine msg) {
759 errorStream << msg << "\n";
760 return failure();
762 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
763 return failure();
764 return success();
767 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
768 raw_ostream &errorStream) {
769 pipeline = pipeline.trim();
770 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
771 size_t pipelineStart = pipeline.find_first_of('(');
772 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
773 !pipeline.consume_back(")")) {
774 errorStream << "expected pass pipeline to be wrapped with the anchor "
775 "operation type, e.g. 'builtin.module(...)'";
776 return failure();
779 StringRef opName = pipeline.take_front(pipelineStart).rtrim();
780 OpPassManager pm(opName);
781 if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
782 errorStream)))
783 return failure();
784 return pm;
787 //===----------------------------------------------------------------------===//
788 // PassNameParser
789 //===----------------------------------------------------------------------===//
791 namespace {
792 /// This struct represents the possible data entries in a parsed pass pipeline
793 /// list.
794 struct PassArgData {
795 PassArgData() = default;
796 PassArgData(const PassRegistryEntry *registryEntry)
797 : registryEntry(registryEntry) {}
799 /// This field is used when the parsed option corresponds to a registered pass
800 /// or pass pipeline.
801 const PassRegistryEntry *registryEntry{nullptr};
803 /// This field is set when instance specific pass options have been provided
804 /// on the command line.
805 StringRef options;
807 } // namespace
809 namespace llvm {
810 namespace cl {
811 /// Define a valid OptionValue for the command line pass argument.
812 template <>
813 struct OptionValue<PassArgData> final
814 : OptionValueBase<PassArgData, /*isClass=*/true> {
815 OptionValue(const PassArgData &value) { this->setValue(value); }
816 OptionValue() = default;
817 void anchor() override {}
819 bool hasValue() const { return true; }
820 const PassArgData &getValue() const { return value; }
821 void setValue(const PassArgData &value) { this->value = value; }
823 PassArgData value;
825 } // namespace cl
826 } // namespace llvm
828 namespace {
830 /// The name for the command line option used for parsing the textual pass
831 /// pipeline.
832 #define PASS_PIPELINE_ARG "pass-pipeline"
834 /// Adds command line option for each registered pass or pass pipeline, as well
835 /// as textual pass pipelines.
836 struct PassNameParser : public llvm::cl::parser<PassArgData> {
837 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
839 void initialize();
840 void printOptionInfo(const llvm::cl::Option &opt,
841 size_t globalWidth) const override;
842 size_t getOptionWidth(const llvm::cl::Option &opt) const override;
843 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
844 PassArgData &value);
846 /// If true, this parser only parses entries that correspond to a concrete
847 /// pass registry entry, and does not include pipeline entries or the options
848 /// for pass entries.
849 bool passNamesOnly = false;
851 } // namespace
853 void PassNameParser::initialize() {
854 llvm::cl::parser<PassArgData>::initialize();
856 /// Add the pass entries.
857 for (const auto &kv : *passRegistry) {
858 addLiteralOption(kv.second.getPassArgument(), &kv.second,
859 kv.second.getPassDescription());
861 /// Add the pass pipeline entries.
862 if (!passNamesOnly) {
863 for (const auto &kv : *passPipelineRegistry) {
864 addLiteralOption(kv.second.getPassArgument(), &kv.second,
865 kv.second.getPassDescription());
870 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
871 size_t globalWidth) const {
872 // If this parser is just parsing pass names, print a simplified option
873 // string.
874 if (passNamesOnly) {
875 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
876 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
877 return;
880 // Print the information for the top-level option.
881 if (opt.hasArgStr()) {
882 llvm::outs() << " --" << opt.ArgStr;
883 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
884 } else {
885 llvm::outs() << " " << opt.HelpStr << '\n';
888 // Functor used to print the ordered entries of a registration map.
889 auto printOrderedEntries = [&](StringRef header, auto &map) {
890 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
891 for (auto &kv : map)
892 orderedEntries.push_back(&kv.second);
893 llvm::array_pod_sort(
894 orderedEntries.begin(), orderedEntries.end(),
895 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
896 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
899 llvm::outs().indent(4) << header << ":\n";
900 for (PassRegistryEntry *entry : orderedEntries)
901 entry->printHelpStr(/*indent=*/6, globalWidth);
904 // Print the available passes.
905 printOrderedEntries("Passes", *passRegistry);
907 // Print the available pass pipelines.
908 if (!passPipelineRegistry->empty())
909 printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
912 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
913 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
915 // Check for any wider pass or pipeline options.
916 for (auto &entry : *passRegistry)
917 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
918 for (auto &entry : *passPipelineRegistry)
919 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
920 return maxWidth;
923 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
924 StringRef arg, PassArgData &value) {
925 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
926 return true;
927 value.options = arg;
928 return false;
931 //===----------------------------------------------------------------------===//
932 // PassPipelineCLParser
933 //===----------------------------------------------------------------------===//
935 namespace mlir {
936 namespace detail {
937 struct PassPipelineCLParserImpl {
938 PassPipelineCLParserImpl(StringRef arg, StringRef description,
939 bool passNamesOnly)
940 : passList(arg, llvm::cl::desc(description)) {
941 passList.getParser().passNamesOnly = passNamesOnly;
942 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
945 /// Returns true if the given pass registry entry was registered at the
946 /// top-level of the parser, i.e. not within an explicit textual pipeline.
947 bool contains(const PassRegistryEntry *entry) const {
948 return llvm::any_of(passList, [&](const PassArgData &data) {
949 return data.registryEntry == entry;
953 /// The set of passes and pass pipelines to run.
954 llvm::cl::list<PassArgData, bool, PassNameParser> passList;
956 } // namespace detail
957 } // namespace mlir
959 /// Construct a pass pipeline parser with the given command line description.
960 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
961 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
962 arg, description, /*passNamesOnly=*/false)),
963 passPipeline(
964 PASS_PIPELINE_ARG,
965 llvm::cl::desc("Textual description of the pass pipeline to run")) {}
967 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
968 StringRef alias)
969 : PassPipelineCLParser(arg, description) {
970 passPipelineAlias.emplace(alias,
971 llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
972 llvm::cl::aliasopt(passPipeline));
975 PassPipelineCLParser::~PassPipelineCLParser() = default;
977 /// Returns true if this parser contains any valid options to add.
978 bool PassPipelineCLParser::hasAnyOccurrences() const {
979 return passPipeline.getNumOccurrences() != 0 ||
980 impl->passList.getNumOccurrences() != 0;
983 /// Returns true if the given pass registry entry was registered at the
984 /// top-level of the parser, i.e. not within an explicit textual pipeline.
985 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
986 return impl->contains(entry);
989 /// Adds the passes defined by this parser entry to the given pass manager.
990 LogicalResult PassPipelineCLParser::addToPipeline(
991 OpPassManager &pm,
992 function_ref<LogicalResult(const Twine &)> errorHandler) const {
993 if (passPipeline.getNumOccurrences()) {
994 if (impl->passList.getNumOccurrences())
995 return errorHandler(
996 "'-" PASS_PIPELINE_ARG
997 "' option can't be used with individual pass options");
998 std::string errMsg;
999 llvm::raw_string_ostream os(errMsg);
1000 FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
1001 if (failed(parsed))
1002 return errorHandler(errMsg);
1003 pm = std::move(*parsed);
1004 return success();
1007 for (auto &passIt : impl->passList) {
1008 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1009 errorHandler)))
1010 return failure();
1012 return success();
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//
1018 /// Construct a pass pipeline parser with the given command line description.
1019 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1020 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1021 arg, description, /*passNamesOnly=*/true)) {
1022 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1024 PassNameCLParser::~PassNameCLParser() = default;
1026 /// Returns true if this parser contains any valid options to add.
1027 bool PassNameCLParser::hasAnyOccurrences() const {
1028 return impl->passList.getNumOccurrences() != 0;
1031 /// Returns true if the given pass registry entry was registered at the
1032 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1033 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
1034 return impl->contains(entry);