1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Pass/PassRegistry.h"
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
25 using namespace detail
;
27 /// Static mapping of all of the registered passes.
28 static llvm::ManagedStatic
<llvm::StringMap
<PassInfo
>> passRegistry
;
30 /// A mapping of the above pass registry entries to the corresponding TypeID
31 /// of the pass that they generate.
32 static llvm::ManagedStatic
<llvm::StringMap
<TypeID
>> passRegistryTypeIDs
;
34 /// Static mapping of all of the registered pass pipelines.
35 static llvm::ManagedStatic
<llvm::StringMap
<PassPipelineInfo
>>
38 /// Utility to create a default registry function from a pass instance.
39 static PassRegistryFunction
40 buildDefaultRegistryFn(const PassAllocatorFunction
&allocator
) {
41 return [=](OpPassManager
&pm
, StringRef options
,
42 function_ref
<LogicalResult(const Twine
&)> errorHandler
) {
43 std::unique_ptr
<Pass
> pass
= allocator();
44 LogicalResult result
= pass
->initializeOptions(options
, errorHandler
);
46 std::optional
<StringRef
> pmOpName
= pm
.getOpName();
47 std::optional
<StringRef
> passOpName
= pass
->getOpName();
48 if ((pm
.getNesting() == OpPassManager::Nesting::Explicit
) && pmOpName
&&
49 passOpName
&& *pmOpName
!= *passOpName
) {
50 return errorHandler(llvm::Twine("Can't add pass '") + pass
->getName() +
51 "' restricted to '" + *pass
->getOpName() +
52 "' on a PassManager intended to run on '" +
53 pm
.getOpAnchorName() + "', did you intend to nest?");
55 pm
.addPass(std::move(pass
));
60 /// Utility to print the help string for a specific option.
61 static void printOptionHelp(StringRef arg
, StringRef desc
, size_t indent
,
62 size_t descIndent
, bool isTopLevel
) {
63 size_t numSpaces
= descIndent
- indent
- 4;
64 llvm::outs().indent(indent
)
65 << "--" << llvm::left_justify(arg
, numSpaces
) << "- " << desc
<< '\n';
68 //===----------------------------------------------------------------------===//
70 //===----------------------------------------------------------------------===//
72 /// Prints the passes that were previously registered and stored in passRegistry
73 void mlir::printRegisteredPasses() {
75 for (auto &entry
: *passRegistry
)
76 maxWidth
= std::max(maxWidth
, entry
.second
.getOptionWidth() + 4);
78 // Functor used to print the ordered entries of a registration map.
79 auto printOrderedEntries
= [&](StringRef header
, auto &map
) {
80 llvm::SmallVector
<PassRegistryEntry
*, 32> orderedEntries
;
82 orderedEntries
.push_back(&kv
.second
);
84 orderedEntries
.begin(), orderedEntries
.end(),
85 [](PassRegistryEntry
*const *lhs
, PassRegistryEntry
*const *rhs
) {
86 return (*lhs
)->getPassArgument().compare((*rhs
)->getPassArgument());
89 llvm::outs().indent(0) << header
<< ":\n";
90 for (PassRegistryEntry
*entry
: orderedEntries
)
91 entry
->printHelpStr(/*indent=*/2, maxWidth
);
94 // Print the available passes.
95 printOrderedEntries("Passes", *passRegistry
);
98 /// Print the help information for this pass. This includes the argument,
99 /// description, and any pass options. `descIndent` is the indent that the
100 /// descriptions should be aligned.
101 void PassRegistryEntry::printHelpStr(size_t indent
, size_t descIndent
) const {
102 printOptionHelp(getPassArgument(), getPassDescription(), indent
, descIndent
,
103 /*isTopLevel=*/true);
104 // If this entry has options, print the help for those as well.
105 optHandler([=](const PassOptions
&options
) {
106 options
.printHelp(indent
, descIndent
);
110 /// Return the maximum width required when printing the options of this
112 size_t PassRegistryEntry::getOptionWidth() const {
114 optHandler([&](const PassOptions
&options
) mutable {
115 maxLen
= options
.getOptionWidth() + 2;
120 //===----------------------------------------------------------------------===//
122 //===----------------------------------------------------------------------===//
124 void mlir::registerPassPipeline(
125 StringRef arg
, StringRef description
, const PassRegistryFunction
&function
,
126 std::function
<void(function_ref
<void(const PassOptions
&)>)> optHandler
) {
127 PassPipelineInfo
pipelineInfo(arg
, description
, function
,
128 std::move(optHandler
));
129 bool inserted
= passPipelineRegistry
->try_emplace(arg
, pipelineInfo
).second
;
132 report_fatal_error("Pass pipeline " + arg
+ " registered multiple times");
137 //===----------------------------------------------------------------------===//
139 //===----------------------------------------------------------------------===//
141 PassInfo::PassInfo(StringRef arg
, StringRef description
,
142 const PassAllocatorFunction
&allocator
)
144 arg
, description
, buildDefaultRegistryFn(allocator
),
145 // Use a temporary pass to provide an options instance.
146 [=](function_ref
<void(const PassOptions
&)> optHandler
) {
147 optHandler(allocator()->passOptions
);
150 void mlir::registerPass(const PassAllocatorFunction
&function
) {
151 std::unique_ptr
<Pass
> pass
= function();
152 StringRef arg
= pass
->getArgument();
154 llvm::report_fatal_error(llvm::Twine("Trying to register '") +
156 "' pass that does not override `getArgument()`");
157 StringRef description
= pass
->getDescription();
158 PassInfo
passInfo(arg
, description
, function
);
159 passRegistry
->try_emplace(arg
, passInfo
);
161 // Verify that the registered pass has the same ID as any registered to this
163 TypeID entryTypeID
= pass
->getTypeID();
164 auto it
= passRegistryTypeIDs
->try_emplace(arg
, entryTypeID
).first
;
165 if (it
->second
!= entryTypeID
)
166 llvm::report_fatal_error(
167 "pass allocator creates a different pass than previously "
168 "registered for pass " +
172 /// Returns the pass info for the specified pass argument or null if unknown.
173 const PassInfo
*mlir::PassInfo::lookup(StringRef passArg
) {
174 auto it
= passRegistry
->find(passArg
);
175 return it
== passRegistry
->end() ? nullptr : &it
->second
;
178 /// Returns the pass pipeline info for the specified pass pipeline argument or
180 const PassPipelineInfo
*mlir::PassPipelineInfo::lookup(StringRef pipelineArg
) {
181 auto it
= passPipelineRegistry
->find(pipelineArg
);
182 return it
== passPipelineRegistry
->end() ? nullptr : &it
->second
;
185 //===----------------------------------------------------------------------===//
187 //===----------------------------------------------------------------------===//
189 /// Extract an argument from 'options' and update it to point after the arg.
190 /// Returns the cleaned argument string.
191 static StringRef
extractArgAndUpdateOptions(StringRef
&options
,
193 StringRef str
= options
.take_front(argSize
).trim();
194 options
= options
.drop_front(argSize
).ltrim();
196 // Early exit if there's no escape sequence.
200 const auto escapePairs
= {std::make_pair('\'', '\''),
201 std::make_pair('"', '"'), std::make_pair('{', '}')};
202 for (const auto &escape
: escapePairs
) {
203 if (str
.front() == escape
.first
&& str
.back() == escape
.second
) {
204 // Drop the escape characters and trim.
205 str
= str
.drop_front().drop_back().trim();
206 // Don't process additional escape sequences.
214 LogicalResult
detail::pass_options::parseCommaSeparatedList(
215 llvm::cl::Option
&opt
, StringRef argName
, StringRef optionStr
,
216 function_ref
<LogicalResult(StringRef
)> elementParseFn
) {
217 // Functor used for finding a character in a string, and skipping over
218 // various "range" characters.
219 llvm::unique_function
<size_t(StringRef
, size_t, char)> findChar
=
220 [&](StringRef str
, size_t index
, char c
) -> size_t {
221 for (size_t i
= index
, e
= str
.size(); i
< e
; ++i
) {
224 // Check for various range characters.
226 i
= findChar(str
, i
+ 1, '}');
227 else if (str
[i
] == '(')
228 i
= findChar(str
, i
+ 1, ')');
229 else if (str
[i
] == '[')
230 i
= findChar(str
, i
+ 1, ']');
231 else if (str
[i
] == '\"')
232 i
= str
.find_first_of('\"', i
+ 1);
233 else if (str
[i
] == '\'')
234 i
= str
.find_first_of('\'', i
+ 1);
236 return StringRef::npos
;
239 size_t nextElePos
= findChar(optionStr
, 0, ',');
240 while (nextElePos
!= StringRef::npos
) {
241 // Process the portion before the comma.
243 elementParseFn(extractArgAndUpdateOptions(optionStr
, nextElePos
))))
246 // Drop the leading ','
247 optionStr
= optionStr
.drop_front();
248 nextElePos
= findChar(optionStr
, 0, ',');
250 return elementParseFn(
251 extractArgAndUpdateOptions(optionStr
, optionStr
.size()));
254 /// Out of line virtual function to provide home for the class.
255 void detail::PassOptions::OptionBase::anchor() {}
257 /// Copy the option values from 'other'.
258 void detail::PassOptions::copyOptionValuesFrom(const PassOptions
&other
) {
259 assert(options
.size() == other
.options
.size());
262 for (auto optionsIt
: llvm::zip(options
, other
.options
))
263 std::get
<0>(optionsIt
)->copyValueFrom(*std::get
<1>(optionsIt
));
266 /// Parse in the next argument from the given options string. Returns a tuple
267 /// containing [the key of the option, the value of the option, updated
268 /// `options` string pointing after the parsed option].
269 static std::tuple
<StringRef
, StringRef
, StringRef
>
270 parseNextArg(StringRef options
) {
271 // Try to process the given punctuation, properly escaping any contained
273 auto tryProcessPunct
= [&](size_t ¤tPos
, char punct
) {
274 if (options
[currentPos
] != punct
)
276 size_t nextIt
= options
.find_first_of(punct
, currentPos
+ 1);
277 if (nextIt
!= StringRef::npos
)
282 // Parse the argument name of the option.
284 for (size_t argEndIt
= 0, optionsE
= options
.size();; ++argEndIt
) {
285 // Check for the end of the full option.
286 if (argEndIt
== optionsE
|| options
[argEndIt
] == ' ') {
287 argName
= extractArgAndUpdateOptions(options
, argEndIt
);
288 return std::make_tuple(argName
, StringRef(), options
);
291 // Check for the end of the name and the start of the value.
292 if (options
[argEndIt
] == '=') {
293 argName
= extractArgAndUpdateOptions(options
, argEndIt
);
294 options
= options
.drop_front();
299 // Parse the value of the option.
300 for (size_t argEndIt
= 0, optionsE
= options
.size();; ++argEndIt
) {
301 // Handle the end of the options string.
302 if (argEndIt
== optionsE
|| options
[argEndIt
] == ' ') {
303 StringRef value
= extractArgAndUpdateOptions(options
, argEndIt
);
304 return std::make_tuple(argName
, value
, options
);
307 // Skip over escaped sequences.
308 char c
= options
[argEndIt
];
309 if (tryProcessPunct(argEndIt
, '\'') || tryProcessPunct(argEndIt
, '"'))
311 // '{...}' is used to specify options to passes, properly escape it so
312 // that we don't accidentally split any nested options.
314 size_t braceCount
= 1;
315 for (++argEndIt
; argEndIt
!= optionsE
; ++argEndIt
) {
316 // Allow nested punctuation.
317 if (tryProcessPunct(argEndIt
, '\'') || tryProcessPunct(argEndIt
, '"'))
319 if (options
[argEndIt
] == '{')
321 else if (options
[argEndIt
] == '}' && --braceCount
== 0)
324 // Account for the increment at the top of the loop.
328 llvm_unreachable("unexpected control flow in pass option parsing");
331 LogicalResult
detail::PassOptions::parseFromString(StringRef options
,
332 raw_ostream
&errorStream
) {
333 // NOTE: `options` is modified in place to always refer to the unprocessed
334 // part of the string.
335 while (!options
.empty()) {
336 StringRef key
, value
;
337 std::tie(key
, value
, options
) = parseNextArg(options
);
341 auto it
= OptionsMap
.find(key
);
342 if (it
== OptionsMap
.end()) {
343 errorStream
<< "<Pass-Options-Parser>: no such option " << key
<< "\n";
346 if (llvm::cl::ProvidePositionalOption(it
->second
, value
, 0))
353 /// Print the options held by this struct in a form that can be parsed via
354 /// 'parseFromString'.
355 void detail::PassOptions::print(raw_ostream
&os
) const {
356 // If there are no options, there is nothing left to do.
357 if (OptionsMap
.empty())
360 // Sort the options to make the ordering deterministic.
361 SmallVector
<OptionBase
*, 4> orderedOps(options
.begin(), options
.end());
362 auto compareOptionArgs
= [](OptionBase
*const *lhs
, OptionBase
*const *rhs
) {
363 return (*lhs
)->getArgStr().compare((*rhs
)->getArgStr());
365 llvm::array_pod_sort(orderedOps
.begin(), orderedOps
.end(), compareOptionArgs
);
367 // Interleave the options with ' '.
370 orderedOps
, os
, [&](OptionBase
*option
) { option
->print(os
); }, " ");
374 /// Print the help string for the options held by this struct. `descIndent` is
375 /// the indent within the stream that the descriptions should be aligned.
376 void detail::PassOptions::printHelp(size_t indent
, size_t descIndent
) const {
377 // Sort the options to make the ordering deterministic.
378 SmallVector
<OptionBase
*, 4> orderedOps(options
.begin(), options
.end());
379 auto compareOptionArgs
= [](OptionBase
*const *lhs
, OptionBase
*const *rhs
) {
380 return (*lhs
)->getArgStr().compare((*rhs
)->getArgStr());
382 llvm::array_pod_sort(orderedOps
.begin(), orderedOps
.end(), compareOptionArgs
);
383 for (OptionBase
*option
: orderedOps
) {
384 // TODO: printOptionInfo assumes a specific indent and will
385 // print options with values with incorrect indentation. We should add
386 // support to llvm::cl::Option for passing in a base indent to use when
388 llvm::outs().indent(indent
);
389 option
->getOption()->printOptionInfo(descIndent
- indent
);
393 /// Return the maximum width required when printing the help string.
394 size_t detail::PassOptions::getOptionWidth() const {
396 for (auto *option
: options
)
397 max
= std::max(max
, option
->getOption()->getOptionWidth());
401 //===----------------------------------------------------------------------===//
403 //===----------------------------------------------------------------------===//
405 //===----------------------------------------------------------------------===//
406 // OpPassManager: OptionValue
408 llvm::cl::OptionValue
<OpPassManager
>::OptionValue() = default;
409 llvm::cl::OptionValue
<OpPassManager
>::OptionValue(
410 const mlir::OpPassManager
&value
) {
413 llvm::cl::OptionValue
<OpPassManager
>::OptionValue(
414 const llvm::cl::OptionValue
<mlir::OpPassManager
> &rhs
) {
416 setValue(rhs
.getValue());
418 llvm::cl::OptionValue
<OpPassManager
> &
419 llvm::cl::OptionValue
<OpPassManager
>::operator=(
420 const mlir::OpPassManager
&rhs
) {
425 llvm::cl::OptionValue
<OpPassManager
>::~OptionValue
<OpPassManager
>() = default;
427 void llvm::cl::OptionValue
<OpPassManager
>::setValue(
428 const OpPassManager
&newValue
) {
432 value
= std::make_unique
<mlir::OpPassManager
>(newValue
);
434 void llvm::cl::OptionValue
<OpPassManager
>::setValue(StringRef pipelineStr
) {
435 FailureOr
<OpPassManager
> pipeline
= parsePassPipeline(pipelineStr
);
436 assert(succeeded(pipeline
) && "invalid pass pipeline");
440 bool llvm::cl::OptionValue
<OpPassManager
>::compare(
441 const mlir::OpPassManager
&rhs
) const {
442 std::string lhsStr
, rhsStr
;
444 raw_string_ostream
lhsStream(lhsStr
);
445 value
->printAsTextualPipeline(lhsStream
);
447 raw_string_ostream
rhsStream(rhsStr
);
448 rhs
.printAsTextualPipeline(rhsStream
);
451 // Use the textual format for pipeline comparisons.
452 return lhsStr
== rhsStr
;
455 void llvm::cl::OptionValue
<OpPassManager
>::anchor() {}
457 //===----------------------------------------------------------------------===//
458 // OpPassManager: Parser
462 template class basic_parser
<OpPassManager
>;
466 bool llvm::cl::parser
<OpPassManager
>::parse(Option
&, StringRef
, StringRef arg
,
467 ParsedPassManager
&value
) {
468 FailureOr
<OpPassManager
> pipeline
= parsePassPipeline(arg
);
469 if (failed(pipeline
))
471 value
.value
= std::make_unique
<OpPassManager
>(std::move(*pipeline
));
475 void llvm::cl::parser
<OpPassManager
>::print(raw_ostream
&os
,
476 const OpPassManager
&value
) {
477 value
.printAsTextualPipeline(os
);
480 void llvm::cl::parser
<OpPassManager
>::printOptionDiff(
481 const Option
&opt
, OpPassManager
&pm
, const OptVal
&defaultValue
,
482 size_t globalWidth
) const {
483 printOptionName(opt
, globalWidth
);
485 pm
.printAsTextualPipeline(outs());
487 if (defaultValue
.hasValue()) {
488 outs().indent(2) << " (default: ";
489 defaultValue
.getValue().printAsTextualPipeline(outs());
495 void llvm::cl::parser
<OpPassManager
>::anchor() {}
497 llvm::cl::parser
<OpPassManager
>::ParsedPassManager::ParsedPassManager() =
499 llvm::cl::parser
<OpPassManager
>::ParsedPassManager::ParsedPassManager(
500 ParsedPassManager
&&) = default;
501 llvm::cl::parser
<OpPassManager
>::ParsedPassManager::~ParsedPassManager() =
504 //===----------------------------------------------------------------------===//
505 // TextualPassPipeline Parser
506 //===----------------------------------------------------------------------===//
509 /// This class represents a textual description of a pass pipeline.
510 class TextualPipeline
{
512 /// Try to initialize this pipeline with the given pipeline text.
513 /// `errorStream` is the output stream to emit errors to.
514 LogicalResult
initialize(StringRef text
, raw_ostream
&errorStream
);
516 /// Add the internal pipeline elements to the provided pass manager.
518 addToPipeline(OpPassManager
&pm
,
519 function_ref
<LogicalResult(const Twine
&)> errorHandler
) const;
522 /// A functor used to emit errors found during pipeline handling. The first
523 /// parameter corresponds to the raw location within the pipeline string. This
524 /// should always return failure.
525 using ErrorHandlerT
= function_ref
<LogicalResult(const char *, Twine
)>;
527 /// A struct to capture parsed pass pipeline names.
529 /// A pipeline is defined as a series of names, each of which may in itself
530 /// recursively contain a nested pipeline. A name is either the name of a pass
531 /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
532 /// the name is the name of a pass, the InnerPipeline is empty, since passes
533 /// cannot contain inner pipelines.
534 struct PipelineElement
{
535 PipelineElement(StringRef name
) : name(name
) {}
539 const PassRegistryEntry
*registryEntry
= nullptr;
540 std::vector
<PipelineElement
> innerPipeline
;
543 /// Parse the given pipeline text into the internal pipeline vector. This
544 /// function only parses the structure of the pipeline, and does not resolve
546 LogicalResult
parsePipelineText(StringRef text
, ErrorHandlerT errorHandler
);
548 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
549 /// the corresponding registry entry.
551 resolvePipelineElements(MutableArrayRef
<PipelineElement
> elements
,
552 ErrorHandlerT errorHandler
);
554 /// Resolve a single element of the pipeline.
555 LogicalResult
resolvePipelineElement(PipelineElement
&element
,
556 ErrorHandlerT errorHandler
);
558 /// Add the given pipeline elements to the provided pass manager.
560 addToPipeline(ArrayRef
<PipelineElement
> elements
, OpPassManager
&pm
,
561 function_ref
<LogicalResult(const Twine
&)> errorHandler
) const;
563 std::vector
<PipelineElement
> pipeline
;
568 /// Try to initialize this pipeline with the given pipeline text. An option is
569 /// given to enable accurate error reporting.
570 LogicalResult
TextualPipeline::initialize(StringRef text
,
571 raw_ostream
&errorStream
) {
575 // Build a source manager to use for error reporting.
576 llvm::SourceMgr pipelineMgr
;
577 pipelineMgr
.AddNewSourceBuffer(
578 llvm::MemoryBuffer::getMemBuffer(text
, "MLIR Textual PassPipeline Parser",
579 /*RequiresNullTerminator=*/false),
581 auto errorHandler
= [&](const char *rawLoc
, Twine msg
) {
582 pipelineMgr
.PrintMessage(errorStream
, SMLoc::getFromPointer(rawLoc
),
583 llvm::SourceMgr::DK_Error
, msg
);
587 // Parse the provided pipeline string.
588 if (failed(parsePipelineText(text
, errorHandler
)))
590 return resolvePipelineElements(pipeline
, errorHandler
);
593 /// Add the internal pipeline elements to the provided pass manager.
594 LogicalResult
TextualPipeline::addToPipeline(
596 function_ref
<LogicalResult(const Twine
&)> errorHandler
) const {
597 // Temporarily disable implicit nesting while we append to the pipeline. We
598 // want the created pipeline to exactly match the parsed text pipeline, so
599 // it's preferrable to just error out if implicit nesting would be required.
600 OpPassManager::Nesting nesting
= pm
.getNesting();
601 pm
.setNesting(OpPassManager::Nesting::Explicit
);
602 auto restore
= llvm::make_scope_exit([&]() { pm
.setNesting(nesting
); });
604 return addToPipeline(pipeline
, pm
, errorHandler
);
607 /// Parse the given pipeline text into the internal pipeline vector. This
608 /// function only parses the structure of the pipeline, and does not resolve
610 LogicalResult
TextualPipeline::parsePipelineText(StringRef text
,
611 ErrorHandlerT errorHandler
) {
612 SmallVector
<std::vector
<PipelineElement
> *, 4> pipelineStack
= {&pipeline
};
614 std::vector
<PipelineElement
> &pipeline
= *pipelineStack
.back();
615 size_t pos
= text
.find_first_of(",(){");
616 pipeline
.emplace_back(/*name=*/text
.substr(0, pos
).trim());
618 // If we have a single terminating name, we're done.
619 if (pos
== StringRef::npos
)
622 text
= text
.substr(pos
);
625 // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
627 text
= text
.substr(1);
629 // Skip over everything until the closing '}' and store as options.
630 size_t close
= StringRef::npos
;
631 for (unsigned i
= 0, e
= text
.size(), braceCount
= 1; i
< e
; ++i
) {
632 if (text
[i
] == '{') {
636 if (text
[i
] == '}' && --braceCount
== 0) {
642 // Check to see if a closing options brace was found.
643 if (close
== StringRef::npos
) {
645 /*rawLoc=*/text
.data() - 1,
646 "missing closing '}' while processing pass options");
648 pipeline
.back().options
= text
.substr(0, close
);
649 text
= text
.substr(close
+ 1);
651 // Consume space characters that an user might add for readability.
654 // Skip checking for '(' because nested pipelines cannot have options.
655 } else if (sep
== '(') {
656 text
= text
.substr(1);
658 // Push the inner pipeline onto the stack to continue processing.
659 pipelineStack
.push_back(&pipeline
.back().innerPipeline
);
663 // When handling the close parenthesis, we greedily consume them to avoid
664 // empty strings in the pipeline.
665 while (text
.consume_front(")")) {
666 // If we try to pop the outer pipeline we have unbalanced parentheses.
667 if (pipelineStack
.size() == 1)
668 return errorHandler(/*rawLoc=*/text
.data() - 1,
669 "encountered extra closing ')' creating unbalanced "
670 "parentheses while parsing pipeline");
672 pipelineStack
.pop_back();
673 // Consume space characters that an user might add for readability.
677 // Check if we've finished parsing.
681 // Otherwise, the end of an inner pipeline always has to be followed by
682 // a comma, and then we can continue.
683 if (!text
.consume_front(","))
684 return errorHandler(text
.data(), "expected ',' after parsing pipeline");
687 // Check for unbalanced parentheses.
688 if (pipelineStack
.size() > 1)
691 "encountered unbalanced parentheses while parsing pipeline");
693 assert(pipelineStack
.back() == &pipeline
&&
694 "wrong pipeline at the bottom of the stack");
698 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
699 /// the corresponding registry entry.
700 LogicalResult
TextualPipeline::resolvePipelineElements(
701 MutableArrayRef
<PipelineElement
> elements
, ErrorHandlerT errorHandler
) {
702 for (auto &elt
: elements
)
703 if (failed(resolvePipelineElement(elt
, errorHandler
)))
708 /// Resolve a single element of the pipeline.
710 TextualPipeline::resolvePipelineElement(PipelineElement
&element
,
711 ErrorHandlerT errorHandler
) {
712 // If the inner pipeline of this element is not empty, this is an operation
714 if (!element
.innerPipeline
.empty())
715 return resolvePipelineElements(element
.innerPipeline
, errorHandler
);
717 // Otherwise, this must be a pass or pass pipeline.
718 // Check to see if a pipeline was registered with this name.
719 if ((element
.registryEntry
= PassPipelineInfo::lookup(element
.name
)))
722 // If not, then this must be a specific pass name.
723 if ((element
.registryEntry
= PassInfo::lookup(element
.name
)))
726 // Emit an error for the unknown pass.
727 auto *rawLoc
= element
.name
.data();
728 return errorHandler(rawLoc
, "'" + element
.name
+
729 "' does not refer to a "
730 "registered pass or pass pipeline");
733 /// Add the given pipeline elements to the provided pass manager.
734 LogicalResult
TextualPipeline::addToPipeline(
735 ArrayRef
<PipelineElement
> elements
, OpPassManager
&pm
,
736 function_ref
<LogicalResult(const Twine
&)> errorHandler
) const {
737 for (auto &elt
: elements
) {
738 if (elt
.registryEntry
) {
739 if (failed(elt
.registryEntry
->addToPipeline(pm
, elt
.options
,
741 return errorHandler("failed to add `" + elt
.name
+ "` with options `" +
744 } else if (failed(addToPipeline(elt
.innerPipeline
, pm
.nest(elt
.name
),
746 return errorHandler("failed to add `" + elt
.name
+ "` with options `" +
747 elt
.options
+ "` to inner pipeline");
753 LogicalResult
mlir::parsePassPipeline(StringRef pipeline
, OpPassManager
&pm
,
754 raw_ostream
&errorStream
) {
755 TextualPipeline pipelineParser
;
756 if (failed(pipelineParser
.initialize(pipeline
, errorStream
)))
758 auto errorHandler
= [&](Twine msg
) {
759 errorStream
<< msg
<< "\n";
762 if (failed(pipelineParser
.addToPipeline(pm
, errorHandler
)))
767 FailureOr
<OpPassManager
> mlir::parsePassPipeline(StringRef pipeline
,
768 raw_ostream
&errorStream
) {
769 pipeline
= pipeline
.trim();
770 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
771 size_t pipelineStart
= pipeline
.find_first_of('(');
772 if (pipelineStart
== 0 || pipelineStart
== StringRef::npos
||
773 !pipeline
.consume_back(")")) {
774 errorStream
<< "expected pass pipeline to be wrapped with the anchor "
775 "operation type, e.g. 'builtin.module(...)'";
779 StringRef opName
= pipeline
.take_front(pipelineStart
).rtrim();
780 OpPassManager
pm(opName
);
781 if (failed(parsePassPipeline(pipeline
.drop_front(1 + pipelineStart
), pm
,
787 //===----------------------------------------------------------------------===//
789 //===----------------------------------------------------------------------===//
792 /// This struct represents the possible data entries in a parsed pass pipeline
795 PassArgData() = default;
796 PassArgData(const PassRegistryEntry
*registryEntry
)
797 : registryEntry(registryEntry
) {}
799 /// This field is used when the parsed option corresponds to a registered pass
800 /// or pass pipeline.
801 const PassRegistryEntry
*registryEntry
{nullptr};
803 /// This field is set when instance specific pass options have been provided
804 /// on the command line.
811 /// Define a valid OptionValue for the command line pass argument.
813 struct OptionValue
<PassArgData
> final
814 : OptionValueBase
<PassArgData
, /*isClass=*/true> {
815 OptionValue(const PassArgData
&value
) { this->setValue(value
); }
816 OptionValue() = default;
817 void anchor() override
{}
819 bool hasValue() const { return true; }
820 const PassArgData
&getValue() const { return value
; }
821 void setValue(const PassArgData
&value
) { this->value
= value
; }
830 /// The name for the command line option used for parsing the textual pass
832 #define PASS_PIPELINE_ARG "pass-pipeline"
834 /// Adds command line option for each registered pass or pass pipeline, as well
835 /// as textual pass pipelines.
836 struct PassNameParser
: public llvm::cl::parser
<PassArgData
> {
837 PassNameParser(llvm::cl::Option
&opt
) : llvm::cl::parser
<PassArgData
>(opt
) {}
840 void printOptionInfo(const llvm::cl::Option
&opt
,
841 size_t globalWidth
) const override
;
842 size_t getOptionWidth(const llvm::cl::Option
&opt
) const override
;
843 bool parse(llvm::cl::Option
&opt
, StringRef argName
, StringRef arg
,
846 /// If true, this parser only parses entries that correspond to a concrete
847 /// pass registry entry, and does not include pipeline entries or the options
848 /// for pass entries.
849 bool passNamesOnly
= false;
853 void PassNameParser::initialize() {
854 llvm::cl::parser
<PassArgData
>::initialize();
856 /// Add the pass entries.
857 for (const auto &kv
: *passRegistry
) {
858 addLiteralOption(kv
.second
.getPassArgument(), &kv
.second
,
859 kv
.second
.getPassDescription());
861 /// Add the pass pipeline entries.
862 if (!passNamesOnly
) {
863 for (const auto &kv
: *passPipelineRegistry
) {
864 addLiteralOption(kv
.second
.getPassArgument(), &kv
.second
,
865 kv
.second
.getPassDescription());
870 void PassNameParser::printOptionInfo(const llvm::cl::Option
&opt
,
871 size_t globalWidth
) const {
872 // If this parser is just parsing pass names, print a simplified option
875 llvm::outs() << " --" << opt
.ArgStr
<< "=<pass-arg>";
876 opt
.printHelpStr(opt
.HelpStr
, globalWidth
, opt
.ArgStr
.size() + 18);
880 // Print the information for the top-level option.
881 if (opt
.hasArgStr()) {
882 llvm::outs() << " --" << opt
.ArgStr
;
883 opt
.printHelpStr(opt
.HelpStr
, globalWidth
, opt
.ArgStr
.size() + 7);
885 llvm::outs() << " " << opt
.HelpStr
<< '\n';
888 // Functor used to print the ordered entries of a registration map.
889 auto printOrderedEntries
= [&](StringRef header
, auto &map
) {
890 llvm::SmallVector
<PassRegistryEntry
*, 32> orderedEntries
;
892 orderedEntries
.push_back(&kv
.second
);
893 llvm::array_pod_sort(
894 orderedEntries
.begin(), orderedEntries
.end(),
895 [](PassRegistryEntry
*const *lhs
, PassRegistryEntry
*const *rhs
) {
896 return (*lhs
)->getPassArgument().compare((*rhs
)->getPassArgument());
899 llvm::outs().indent(4) << header
<< ":\n";
900 for (PassRegistryEntry
*entry
: orderedEntries
)
901 entry
->printHelpStr(/*indent=*/6, globalWidth
);
904 // Print the available passes.
905 printOrderedEntries("Passes", *passRegistry
);
907 // Print the available pass pipelines.
908 if (!passPipelineRegistry
->empty())
909 printOrderedEntries("Pass Pipelines", *passPipelineRegistry
);
912 size_t PassNameParser::getOptionWidth(const llvm::cl::Option
&opt
) const {
913 size_t maxWidth
= llvm::cl::parser
<PassArgData
>::getOptionWidth(opt
) + 2;
915 // Check for any wider pass or pipeline options.
916 for (auto &entry
: *passRegistry
)
917 maxWidth
= std::max(maxWidth
, entry
.second
.getOptionWidth() + 4);
918 for (auto &entry
: *passPipelineRegistry
)
919 maxWidth
= std::max(maxWidth
, entry
.second
.getOptionWidth() + 4);
923 bool PassNameParser::parse(llvm::cl::Option
&opt
, StringRef argName
,
924 StringRef arg
, PassArgData
&value
) {
925 if (llvm::cl::parser
<PassArgData
>::parse(opt
, argName
, arg
, value
))
931 //===----------------------------------------------------------------------===//
932 // PassPipelineCLParser
933 //===----------------------------------------------------------------------===//
937 struct PassPipelineCLParserImpl
{
938 PassPipelineCLParserImpl(StringRef arg
, StringRef description
,
940 : passList(arg
, llvm::cl::desc(description
)) {
941 passList
.getParser().passNamesOnly
= passNamesOnly
;
942 passList
.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional
);
945 /// Returns true if the given pass registry entry was registered at the
946 /// top-level of the parser, i.e. not within an explicit textual pipeline.
947 bool contains(const PassRegistryEntry
*entry
) const {
948 return llvm::any_of(passList
, [&](const PassArgData
&data
) {
949 return data
.registryEntry
== entry
;
953 /// The set of passes and pass pipelines to run.
954 llvm::cl::list
<PassArgData
, bool, PassNameParser
> passList
;
956 } // namespace detail
959 /// Construct a pass pipeline parser with the given command line description.
960 PassPipelineCLParser::PassPipelineCLParser(StringRef arg
, StringRef description
)
961 : impl(std::make_unique
<detail::PassPipelineCLParserImpl
>(
962 arg
, description
, /*passNamesOnly=*/false)),
965 llvm::cl::desc("Textual description of the pass pipeline to run")) {}
967 PassPipelineCLParser::PassPipelineCLParser(StringRef arg
, StringRef description
,
969 : PassPipelineCLParser(arg
, description
) {
970 passPipelineAlias
.emplace(alias
,
971 llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG
),
972 llvm::cl::aliasopt(passPipeline
));
975 PassPipelineCLParser::~PassPipelineCLParser() = default;
977 /// Returns true if this parser contains any valid options to add.
978 bool PassPipelineCLParser::hasAnyOccurrences() const {
979 return passPipeline
.getNumOccurrences() != 0 ||
980 impl
->passList
.getNumOccurrences() != 0;
983 /// Returns true if the given pass registry entry was registered at the
984 /// top-level of the parser, i.e. not within an explicit textual pipeline.
985 bool PassPipelineCLParser::contains(const PassRegistryEntry
*entry
) const {
986 return impl
->contains(entry
);
989 /// Adds the passes defined by this parser entry to the given pass manager.
990 LogicalResult
PassPipelineCLParser::addToPipeline(
992 function_ref
<LogicalResult(const Twine
&)> errorHandler
) const {
993 if (passPipeline
.getNumOccurrences()) {
994 if (impl
->passList
.getNumOccurrences())
996 "'-" PASS_PIPELINE_ARG
997 "' option can't be used with individual pass options");
999 llvm::raw_string_ostream
os(errMsg
);
1000 FailureOr
<OpPassManager
> parsed
= parsePassPipeline(passPipeline
, os
);
1002 return errorHandler(errMsg
);
1003 pm
= std::move(*parsed
);
1007 for (auto &passIt
: impl
->passList
) {
1008 if (failed(passIt
.registryEntry
->addToPipeline(pm
, passIt
.options
,
1015 //===----------------------------------------------------------------------===//
1018 /// Construct a pass pipeline parser with the given command line description.
1019 PassNameCLParser::PassNameCLParser(StringRef arg
, StringRef description
)
1020 : impl(std::make_unique
<detail::PassPipelineCLParserImpl
>(
1021 arg
, description
, /*passNamesOnly=*/true)) {
1022 impl
->passList
.setMiscFlag(llvm::cl::CommaSeparated
);
1024 PassNameCLParser::~PassNameCLParser() = default;
1026 /// Returns true if this parser contains any valid options to add.
1027 bool PassNameCLParser::hasAnyOccurrences() const {
1028 return impl
->passList
.getNumOccurrences() != 0;
1031 /// Returns true if the given pass registry entry was registered at the
1032 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1033 bool PassNameCLParser::contains(const PassRegistryEntry
*entry
) const {
1034 return impl
->contains(entry
);