Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Pass / Pass.cpp
blob6fd51c1e3cb53843601ee209dea51d7a3c7251f8
1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements common pass infrastructure.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Pass/Pass.h"
14 #include "PassDetail.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/Threading.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/Support/FileUtilities.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/CrashRecoveryContext.h"
26 #include "llvm/Support/Mutex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/Support/Threading.h"
29 #include "llvm/Support/ToolOutputFile.h"
30 #include <optional>
32 using namespace mlir;
33 using namespace mlir::detail;
35 //===----------------------------------------------------------------------===//
36 // PassExecutionAction
37 //===----------------------------------------------------------------------===//
39 PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits,
40 const Pass &pass)
41 : Base(irUnits), pass(pass) {}
43 void PassExecutionAction::print(raw_ostream &os) const {
44 os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag,
45 pass.getName(), getOp()->getName());
48 Operation *PassExecutionAction::getOp() const {
49 ArrayRef<IRUnit> irUnits = getContextIRUnits();
50 return irUnits.empty() ? nullptr
51 : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
54 //===----------------------------------------------------------------------===//
55 // Pass
56 //===----------------------------------------------------------------------===//
58 /// Out of line virtual method to ensure vtables and metadata are emitted to a
59 /// single .o file.
60 void Pass::anchor() {}
62 /// Attempt to initialize the options of this pass from the given string.
63 LogicalResult Pass::initializeOptions(
64 StringRef options,
65 function_ref<LogicalResult(const Twine &)> errorHandler) {
66 std::string errStr;
67 llvm::raw_string_ostream os(errStr);
68 if (failed(passOptions.parseFromString(options, os))) {
69 return errorHandler(errStr);
71 return success();
74 /// Copy the option values from 'other', which is another instance of this
75 /// pass.
76 void Pass::copyOptionValuesFrom(const Pass *other) {
77 passOptions.copyOptionValuesFrom(other->passOptions);
80 /// Prints out the pass in the textual representation of pipelines. If this is
81 /// an adaptor pass, print its pass managers.
82 void Pass::printAsTextualPipeline(raw_ostream &os) {
83 // Special case for adaptors to print its pass managers.
84 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
85 llvm::interleave(
86 adaptor->getPassManagers(),
87 [&](OpPassManager &pm) { pm.printAsTextualPipeline(os); },
88 [&] { os << ","; });
89 return;
91 // Otherwise, print the pass argument followed by its options. If the pass
92 // doesn't have an argument, print the name of the pass to give some indicator
93 // of what pass was run.
94 StringRef argument = getArgument();
95 if (!argument.empty())
96 os << argument;
97 else
98 os << "unknown<" << getName() << ">";
99 passOptions.print(os);
102 //===----------------------------------------------------------------------===//
103 // OpPassManagerImpl
104 //===----------------------------------------------------------------------===//
106 namespace mlir {
107 namespace detail {
108 struct OpPassManagerImpl {
109 OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting)
110 : name(opName.getStringRef().str()), opName(opName),
111 initializationGeneration(0), nesting(nesting) {}
112 OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
113 : name(name == OpPassManager::getAnyOpAnchorName() ? "" : name.str()),
114 initializationGeneration(0), nesting(nesting) {}
115 OpPassManagerImpl(OpPassManager::Nesting nesting)
116 : initializationGeneration(0), nesting(nesting) {}
117 OpPassManagerImpl(const OpPassManagerImpl &rhs)
118 : name(rhs.name), opName(rhs.opName),
119 initializationGeneration(rhs.initializationGeneration),
120 nesting(rhs.nesting) {
121 for (const std::unique_ptr<Pass> &pass : rhs.passes) {
122 std::unique_ptr<Pass> newPass = pass->clone();
123 newPass->threadingSibling = pass.get();
124 passes.push_back(std::move(newPass));
128 /// Merge the passes of this pass manager into the one provided.
129 void mergeInto(OpPassManagerImpl &rhs);
131 /// Nest a new operation pass manager for the given operation kind under this
132 /// pass manager.
133 OpPassManager &nest(OperationName nestedName) {
134 return nest(OpPassManager(nestedName, nesting));
136 OpPassManager &nest(StringRef nestedName) {
137 return nest(OpPassManager(nestedName, nesting));
139 OpPassManager &nestAny() { return nest(OpPassManager(nesting)); }
141 /// Nest the given pass manager under this pass manager.
142 OpPassManager &nest(OpPassManager &&nested);
144 /// Add the given pass to this pass manager. If this pass has a concrete
145 /// operation type, it must be the same type as this pass manager.
146 void addPass(std::unique_ptr<Pass> pass);
148 /// Clear the list of passes in this pass manager, other options are
149 /// preserved.
150 void clear();
152 /// Finalize the pass list in preparation for execution. This includes
153 /// coalescing adjacent pass managers when possible, verifying scheduled
154 /// passes, etc.
155 LogicalResult finalizePassList(MLIRContext *ctx);
157 /// Return the operation name of this pass manager.
158 std::optional<OperationName> getOpName(MLIRContext &context) {
159 if (!name.empty() && !opName)
160 opName = OperationName(name, &context);
161 return opName;
163 std::optional<StringRef> getOpName() const {
164 return name.empty() ? std::optional<StringRef>()
165 : std::optional<StringRef>(name);
168 /// Return the name used to anchor this pass manager. This is either the name
169 /// of an operation, or the result of `getAnyOpAnchorName()` in the case of an
170 /// op-agnostic pass manager.
171 StringRef getOpAnchorName() const {
172 return getOpName().value_or(OpPassManager::getAnyOpAnchorName());
175 /// Indicate if the current pass manager can be scheduled on the given
176 /// operation type.
177 bool canScheduleOn(MLIRContext &context, OperationName opName);
179 /// The name of the operation that passes of this pass manager operate on.
180 std::string name;
182 /// The cached OperationName (internalized in the context) for the name of the
183 /// operation that passes of this pass manager operate on.
184 std::optional<OperationName> opName;
186 /// The set of passes to run as part of this pass manager.
187 std::vector<std::unique_ptr<Pass>> passes;
189 /// The current initialization generation of this pass manager. This is used
190 /// to indicate when a pass manager should be reinitialized.
191 unsigned initializationGeneration;
193 /// Control the implicit nesting of passes that mismatch the name set for this
194 /// OpPassManager.
195 OpPassManager::Nesting nesting;
197 } // namespace detail
198 } // namespace mlir
200 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
201 assert(name == rhs.name && "merging unrelated pass managers");
202 for (auto &pass : passes)
203 rhs.passes.push_back(std::move(pass));
204 passes.clear();
207 OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) {
208 auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
209 addPass(std::unique_ptr<Pass>(adaptor));
210 return adaptor->getPassManagers().front();
213 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
214 // If this pass runs on a different operation than this pass manager, then
215 // implicitly nest a pass manager for this operation if enabled.
216 std::optional<StringRef> pmOpName = getOpName();
217 std::optional<StringRef> passOpName = pass->getOpName();
218 if (pmOpName && passOpName && *pmOpName != *passOpName) {
219 if (nesting == OpPassManager::Nesting::Implicit)
220 return nest(*passOpName).addPass(std::move(pass));
221 llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
222 "' restricted to '" + *passOpName +
223 "' on a PassManager intended to run on '" +
224 getOpAnchorName() + "', did you intend to nest?");
227 passes.emplace_back(std::move(pass));
230 void OpPassManagerImpl::clear() { passes.clear(); }
232 LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) {
233 auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) {
234 for (auto &pm : adaptor->getPassManagers())
235 if (failed(pm.getImpl().finalizePassList(ctx)))
236 return failure();
237 return success();
240 // Walk the pass list and merge adjacent adaptors.
241 OpToOpPassAdaptor *lastAdaptor = nullptr;
242 for (auto &pass : passes) {
243 // Check to see if this pass is an adaptor.
244 if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get())) {
245 // If it is the first adaptor in a possible chain, remember it and
246 // continue.
247 if (!lastAdaptor) {
248 lastAdaptor = currentAdaptor;
249 continue;
252 // Otherwise, try to merge into the existing adaptor and delete the
253 // current one. If merging fails, just remember this as the last adaptor.
254 if (succeeded(currentAdaptor->tryMergeInto(ctx, *lastAdaptor)))
255 pass.reset();
256 else
257 lastAdaptor = currentAdaptor;
258 } else if (lastAdaptor) {
259 // If this pass isn't an adaptor, finalize it and forget the last adaptor.
260 if (failed(finalizeAdaptor(lastAdaptor)))
261 return failure();
262 lastAdaptor = nullptr;
266 // If there was an adaptor at the end of the manager, finalize it as well.
267 if (lastAdaptor && failed(finalizeAdaptor(lastAdaptor)))
268 return failure();
270 // Now that the adaptors have been merged, erase any empty slots corresponding
271 // to the merged adaptors that were nulled-out in the loop above.
272 llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
274 // If this is a op-agnostic pass manager, there is nothing left to do.
275 std::optional<OperationName> rawOpName = getOpName(*ctx);
276 if (!rawOpName)
277 return success();
279 // Otherwise, verify that all of the passes are valid for the current
280 // operation anchor.
281 std::optional<RegisteredOperationName> opName =
282 rawOpName->getRegisteredInfo();
283 for (std::unique_ptr<Pass> &pass : passes) {
284 if (opName && !pass->canScheduleOn(*opName)) {
285 return emitError(UnknownLoc::get(ctx))
286 << "unable to schedule pass '" << pass->getName()
287 << "' on a PassManager intended to run on '" << getOpAnchorName()
288 << "'!";
291 return success();
294 bool OpPassManagerImpl::canScheduleOn(MLIRContext &context,
295 OperationName opName) {
296 // If this pass manager is op-specific, we simply check if the provided
297 // operation name is the same as this one.
298 std::optional<OperationName> pmOpName = getOpName(context);
299 if (pmOpName)
300 return pmOpName == opName;
302 // Otherwise, this is an op-agnostic pass manager. Check that the operation
303 // can be scheduled on all passes within the manager.
304 std::optional<RegisteredOperationName> registeredInfo =
305 opName.getRegisteredInfo();
306 if (!registeredInfo ||
307 !registeredInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
308 return false;
309 return llvm::all_of(passes, [&](const std::unique_ptr<Pass> &pass) {
310 return pass->canScheduleOn(*registeredInfo);
314 //===----------------------------------------------------------------------===//
315 // OpPassManager
316 //===----------------------------------------------------------------------===//
318 OpPassManager::OpPassManager(Nesting nesting)
319 : impl(new OpPassManagerImpl(nesting)) {}
320 OpPassManager::OpPassManager(StringRef name, Nesting nesting)
321 : impl(new OpPassManagerImpl(name, nesting)) {}
322 OpPassManager::OpPassManager(OperationName name, Nesting nesting)
323 : impl(new OpPassManagerImpl(name, nesting)) {}
324 OpPassManager::OpPassManager(OpPassManager &&rhs) { *this = std::move(rhs); }
325 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
326 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
327 impl = std::make_unique<OpPassManagerImpl>(*rhs.impl);
328 return *this;
330 OpPassManager &OpPassManager::operator=(OpPassManager &&rhs) {
331 impl = std::move(rhs.impl);
332 return *this;
335 OpPassManager::~OpPassManager() = default;
337 OpPassManager::pass_iterator OpPassManager::begin() {
338 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
340 OpPassManager::pass_iterator OpPassManager::end() {
341 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
344 OpPassManager::const_pass_iterator OpPassManager::begin() const {
345 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
347 OpPassManager::const_pass_iterator OpPassManager::end() const {
348 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
351 /// Nest a new operation pass manager for the given operation kind under this
352 /// pass manager.
353 OpPassManager &OpPassManager::nest(OperationName nestedName) {
354 return impl->nest(nestedName);
356 OpPassManager &OpPassManager::nest(StringRef nestedName) {
357 return impl->nest(nestedName);
359 OpPassManager &OpPassManager::nestAny() { return impl->nestAny(); }
361 /// Add the given pass to this pass manager. If this pass has a concrete
362 /// operation type, it must be the same type as this pass manager.
363 void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
364 impl->addPass(std::move(pass));
367 void OpPassManager::clear() { impl->clear(); }
369 /// Returns the number of passes held by this manager.
370 size_t OpPassManager::size() const { return impl->passes.size(); }
372 /// Returns the internal implementation instance.
373 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
375 /// Return the operation name that this pass manager operates on.
376 std::optional<StringRef> OpPassManager::getOpName() const {
377 return impl->getOpName();
380 /// Return the operation name that this pass manager operates on.
381 std::optional<OperationName>
382 OpPassManager::getOpName(MLIRContext &context) const {
383 return impl->getOpName(context);
386 StringRef OpPassManager::getOpAnchorName() const {
387 return impl->getOpAnchorName();
390 /// Prints out the passes of the pass manager as the textual representation
391 /// of pipelines.
392 void printAsTextualPipeline(
393 raw_ostream &os, StringRef anchorName,
394 const llvm::iterator_range<OpPassManager::pass_iterator> &passes) {
395 os << anchorName << "(";
396 llvm::interleave(
397 passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
398 [&]() { os << ","; });
399 os << ")";
401 void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
402 StringRef anchorName = getOpAnchorName();
403 ::printAsTextualPipeline(
404 os, anchorName,
405 {MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(),
406 MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()});
409 void OpPassManager::dump() {
410 llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes:\n";
411 printAsTextualPipeline(llvm::errs());
412 llvm::errs() << "\n";
415 static void registerDialectsForPipeline(const OpPassManager &pm,
416 DialectRegistry &dialects) {
417 for (const Pass &pass : pm.getPasses())
418 pass.getDependentDialects(dialects);
421 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
422 registerDialectsForPipeline(*this, dialects);
425 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
427 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
429 LogicalResult OpPassManager::initialize(MLIRContext *context,
430 unsigned newInitGeneration) {
431 if (impl->initializationGeneration == newInitGeneration)
432 return success();
433 impl->initializationGeneration = newInitGeneration;
434 for (Pass &pass : getPasses()) {
435 // If this pass isn't an adaptor, directly initialize it.
436 auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
437 if (!adaptor) {
438 if (failed(pass.initialize(context)))
439 return failure();
440 continue;
443 // Otherwise, initialize each of the adaptors pass managers.
444 for (OpPassManager &adaptorPM : adaptor->getPassManagers())
445 if (failed(adaptorPM.initialize(context, newInitGeneration)))
446 return failure();
448 return success();
451 llvm::hash_code OpPassManager::hash() {
452 llvm::hash_code hashCode{};
453 for (Pass &pass : getPasses()) {
454 // If this pass isn't an adaptor, directly hash it.
455 auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
456 if (!adaptor) {
457 hashCode = llvm::hash_combine(hashCode, &pass);
458 continue;
460 // Otherwise, hash recursively each of the adaptors pass managers.
461 for (OpPassManager &adaptorPM : adaptor->getPassManagers())
462 llvm::hash_combine(hashCode, adaptorPM.hash());
464 return hashCode;
468 //===----------------------------------------------------------------------===//
469 // OpToOpPassAdaptor
470 //===----------------------------------------------------------------------===//
472 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
473 AnalysisManager am, bool verifyPasses,
474 unsigned parentInitGeneration) {
475 std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
476 if (!opInfo)
477 return op->emitOpError()
478 << "trying to schedule a pass on an unregistered operation";
479 if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
480 return op->emitOpError() << "trying to schedule a pass on an operation not "
481 "marked as 'IsolatedFromAbove'";
482 if (!pass->canScheduleOn(*op->getName().getRegisteredInfo()))
483 return op->emitOpError()
484 << "trying to schedule a pass on an unsupported operation";
486 // Initialize the pass state with a callback for the pass to dynamically
487 // execute a pipeline on the currently visited operation.
488 PassInstrumentor *pi = am.getPassInstrumentor();
489 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
490 pass};
491 auto dynamicPipelineCallback = [&](OpPassManager &pipeline,
492 Operation *root) -> LogicalResult {
493 if (!op->isAncestor(root))
494 return root->emitOpError()
495 << "Trying to schedule a dynamic pipeline on an "
496 "operation that isn't "
497 "nested under the current operation the pass is processing";
498 assert(
499 pipeline.getImpl().canScheduleOn(*op->getContext(), root->getName()));
501 // Before running, finalize the passes held by the pipeline.
502 if (failed(pipeline.getImpl().finalizePassList(root->getContext())))
503 return failure();
505 // Initialize the user provided pipeline and execute the pipeline.
506 if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
507 return failure();
508 AnalysisManager nestedAm = root == op ? am : am.nest(root);
509 return OpToOpPassAdaptor::runPipeline(pipeline, root, nestedAm,
510 verifyPasses, parentInitGeneration,
511 pi, &parentInfo);
513 pass->passState.emplace(op, am, dynamicPipelineCallback);
515 // Instrument before the pass has run.
516 if (pi)
517 pi->runBeforePass(pass, op);
519 bool passFailed = false;
520 op->getContext()->executeAction<PassExecutionAction>(
521 [&]() {
522 // Invoke the virtual runOnOperation method.
523 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
524 adaptor->runOnOperation(verifyPasses);
525 else
526 pass->runOnOperation();
527 passFailed = pass->passState->irAndPassFailed.getInt();
529 {op}, *pass);
531 // Invalidate any non preserved analyses.
532 am.invalidate(pass->passState->preservedAnalyses);
534 // When verifyPasses is specified, we run the verifier (unless the pass
535 // failed).
536 if (!passFailed && verifyPasses) {
537 bool runVerifierNow = true;
539 // If the pass is an adaptor pass, we don't run the verifier recursively
540 // because the nested operations should have already been verified after
541 // nested passes had run.
542 bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass);
544 // Reduce compile time by avoiding running the verifier if the pass didn't
545 // change the IR since the last time the verifier was run:
547 // 1) If the pass said that it preserved all analyses then it can't have
548 // permuted the IR.
550 // We run these checks in EXPENSIVE_CHECKS mode out of caution.
551 #ifndef EXPENSIVE_CHECKS
552 runVerifierNow = !pass->passState->preservedAnalyses.isAll();
553 #endif
554 if (runVerifierNow)
555 passFailed = failed(verify(op, runVerifierRecursively));
558 // Instrument after the pass has run.
559 if (pi) {
560 if (passFailed)
561 pi->runAfterPassFailed(pass, op);
562 else
563 pi->runAfterPass(pass, op);
566 // Return if the pass signaled a failure.
567 return failure(passFailed);
570 /// Run the given operation and analysis manager on a provided op pass manager.
571 LogicalResult OpToOpPassAdaptor::runPipeline(
572 OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses,
573 unsigned parentInitGeneration, PassInstrumentor *instrumentor,
574 const PassInstrumentation::PipelineParentInfo *parentInfo) {
575 assert((!instrumentor || parentInfo) &&
576 "expected parent info if instrumentor is provided");
577 auto scopeExit = llvm::make_scope_exit([&] {
578 // Clear out any computed operation analyses. These analyses won't be used
579 // any more in this pipeline, and this helps reduce the current working set
580 // of memory. If preserving these analyses becomes important in the future
581 // we can re-evaluate this.
582 am.clear();
585 // Run the pipeline over the provided operation.
586 if (instrumentor) {
587 instrumentor->runBeforePipeline(pm.getOpName(*op->getContext()),
588 *parentInfo);
591 for (Pass &pass : pm.getPasses())
592 if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration)))
593 return failure();
595 if (instrumentor) {
596 instrumentor->runAfterPipeline(pm.getOpName(*op->getContext()),
597 *parentInfo);
599 return success();
602 /// Find an operation pass manager with the given anchor name, or nullptr if one
603 /// does not exist.
604 static OpPassManager *
605 findPassManagerWithAnchor(MutableArrayRef<OpPassManager> mgrs, StringRef name) {
606 auto *it = llvm::find_if(
607 mgrs, [&](OpPassManager &mgr) { return mgr.getOpAnchorName() == name; });
608 return it == mgrs.end() ? nullptr : &*it;
611 /// Find an operation pass manager that can operate on an operation of the given
612 /// type, or nullptr if one does not exist.
613 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
614 OperationName name,
615 MLIRContext &context) {
616 auto *it = llvm::find_if(mgrs, [&](OpPassManager &mgr) {
617 return mgr.getImpl().canScheduleOn(context, name);
619 return it == mgrs.end() ? nullptr : &*it;
622 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
623 mgrs.emplace_back(std::move(mgr));
626 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
627 for (auto &pm : mgrs)
628 pm.getDependentDialects(dialects);
631 LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
632 OpToOpPassAdaptor &rhs) {
633 // Functor used to check if a pass manager is generic, i.e. op-agnostic.
634 auto isGenericPM = [&](OpPassManager &pm) { return !pm.getOpName(); };
636 // Functor used to detect if the given generic pass manager will have a
637 // potential schedule conflict with the given `otherPMs`.
638 auto hasScheduleConflictWith = [&](OpPassManager &genericPM,
639 MutableArrayRef<OpPassManager> otherPMs) {
640 return llvm::any_of(otherPMs, [&](OpPassManager &pm) {
641 // If this is a non-generic pass manager, a conflict will arise if a
642 // non-generic pass manager's operation name can be scheduled on the
643 // generic passmanager.
644 if (std::optional<OperationName> pmOpName = pm.getOpName(*ctx))
645 return genericPM.getImpl().canScheduleOn(*ctx, *pmOpName);
646 // Otherwise, this is a generic pass manager. We current can't determine
647 // when generic pass managers can be merged, so conservatively assume they
648 // conflict.
649 return true;
653 // Check that if either adaptor has a generic pass manager, that pm is
654 // compatible within any non-generic pass managers.
656 // Check the current adaptor.
657 auto *lhsGenericPMIt = llvm::find_if(mgrs, isGenericPM);
658 if (lhsGenericPMIt != mgrs.end() &&
659 hasScheduleConflictWith(*lhsGenericPMIt, rhs.mgrs))
660 return failure();
661 // Check the rhs adaptor.
662 auto *rhsGenericPMIt = llvm::find_if(rhs.mgrs, isGenericPM);
663 if (rhsGenericPMIt != rhs.mgrs.end() &&
664 hasScheduleConflictWith(*rhsGenericPMIt, mgrs))
665 return failure();
667 for (auto &pm : mgrs) {
668 // If an existing pass manager exists, then merge the given pass manager
669 // into it.
670 if (auto *existingPM =
671 findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) {
672 pm.getImpl().mergeInto(existingPM->getImpl());
673 } else {
674 // Otherwise, add the given pass manager to the list.
675 rhs.mgrs.emplace_back(std::move(pm));
678 mgrs.clear();
680 // After coalescing, sort the pass managers within rhs by name.
681 auto compareFn = [](const OpPassManager *lhs, const OpPassManager *rhs) {
682 // Order op-specific pass managers first and op-agnostic pass managers last.
683 if (std::optional<StringRef> lhsName = lhs->getOpName()) {
684 if (std::optional<StringRef> rhsName = rhs->getOpName())
685 return lhsName->compare(*rhsName);
686 return -1; // lhs(op-specific) < rhs(op-agnostic)
688 return 1; // lhs(op-agnostic) > rhs(op-specific)
690 llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(), compareFn);
691 return success();
694 /// Returns the adaptor pass name.
695 std::string OpToOpPassAdaptor::getAdaptorName() {
696 std::string name = "Pipeline Collection : [";
697 llvm::raw_string_ostream os(name);
698 llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
699 os << '\'' << pm.getOpAnchorName() << '\'';
701 os << ']';
702 return name;
705 void OpToOpPassAdaptor::runOnOperation() {
706 llvm_unreachable(
707 "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
710 /// Run the held pipeline over all nested operations.
711 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
712 if (getContext().isMultithreadingEnabled())
713 runOnOperationAsyncImpl(verifyPasses);
714 else
715 runOnOperationImpl(verifyPasses);
718 /// Run this pass adaptor synchronously.
719 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
720 auto am = getAnalysisManager();
721 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
722 this};
723 auto *instrumentor = am.getPassInstrumentor();
724 for (auto &region : getOperation()->getRegions()) {
725 for (auto &block : region) {
726 for (auto &op : block) {
727 auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext());
728 if (!mgr)
729 continue;
731 // Run the held pipeline over the current operation.
732 unsigned initGeneration = mgr->impl->initializationGeneration;
733 if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses,
734 initGeneration, instrumentor, &parentInfo)))
735 signalPassFailure();
741 /// Utility functor that checks if the two ranges of pass managers have a size
742 /// mismatch.
743 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
744 ArrayRef<OpPassManager> rhs) {
745 return lhs.size() != rhs.size() ||
746 llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
747 [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
750 /// Run this pass adaptor synchronously.
751 void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
752 AnalysisManager am = getAnalysisManager();
753 MLIRContext *context = &getContext();
755 // Create the async executors if they haven't been created, or if the main
756 // pipeline has changed.
757 if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
758 asyncExecutors.assign(context->getThreadPool().getMaxConcurrency(), mgrs);
760 // This struct represents the information for a single operation to be
761 // scheduled on a pass manager.
762 struct OpPMInfo {
763 OpPMInfo(unsigned passManagerIdx, Operation *op, AnalysisManager am)
764 : passManagerIdx(passManagerIdx), op(op), am(am) {}
766 /// The index of the pass manager to schedule the operation on.
767 unsigned passManagerIdx;
768 /// The operation to schedule.
769 Operation *op;
770 /// The analysis manager for the operation.
771 AnalysisManager am;
774 // Run a prepass over the operation to collect the nested operations to
775 // execute over. This ensures that an analysis manager exists for each
776 // operation, as well as providing a queue of operations to execute over.
777 std::vector<OpPMInfo> opInfos;
778 DenseMap<OperationName, std::optional<unsigned>> knownOpPMIdx;
779 for (auto &region : getOperation()->getRegions()) {
780 for (Operation &op : region.getOps()) {
781 // Get the pass manager index for this operation type.
782 auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), std::nullopt);
783 if (pmIdxIt.second) {
784 if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context))
785 pmIdxIt.first->second = std::distance(mgrs.begin(), mgr);
788 // If this operation can be scheduled, add it to the list.
789 if (pmIdxIt.first->second)
790 opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op));
794 // Get the current thread for this adaptor.
795 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
796 this};
797 auto *instrumentor = am.getPassInstrumentor();
799 // An atomic failure variable for the async executors.
800 std::vector<std::atomic<bool>> activePMs(asyncExecutors.size());
801 std::fill(activePMs.begin(), activePMs.end(), false);
802 std::atomic<bool> hasFailure = false;
803 parallelForEach(context, opInfos, [&](OpPMInfo &opInfo) {
804 // Find an executor for this operation.
805 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
806 bool expectedInactive = false;
807 return isActive.compare_exchange_strong(expectedInactive, true);
809 unsigned pmIndex = it - activePMs.begin();
811 // Get the pass manager for this operation and execute it.
812 OpPassManager &pm = asyncExecutors[pmIndex][opInfo.passManagerIdx];
813 LogicalResult pipelineResult = runPipeline(
814 pm, opInfo.op, opInfo.am, verifyPasses,
815 pm.impl->initializationGeneration, instrumentor, &parentInfo);
816 if (failed(pipelineResult))
817 hasFailure.store(true);
819 // Reset the active bit for this pass manager.
820 activePMs[pmIndex].store(false);
823 // Signal a failure if any of the executors failed.
824 if (hasFailure)
825 signalPassFailure();
828 //===----------------------------------------------------------------------===//
829 // PassManager
830 //===----------------------------------------------------------------------===//
832 PassManager::PassManager(MLIRContext *ctx, StringRef operationName,
833 Nesting nesting)
834 : OpPassManager(operationName, nesting), context(ctx), passTiming(false),
835 verifyPasses(true) {}
837 PassManager::PassManager(OperationName operationName, Nesting nesting)
838 : OpPassManager(operationName, nesting),
839 context(operationName.getContext()), passTiming(false),
840 verifyPasses(true) {}
842 PassManager::~PassManager() = default;
844 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
846 /// Run the passes within this manager on the provided operation.
847 LogicalResult PassManager::run(Operation *op) {
848 MLIRContext *context = getContext();
849 std::optional<OperationName> anchorOp = getOpName(*context);
850 if (anchorOp && anchorOp != op->getName())
851 return emitError(op->getLoc())
852 << "can't run '" << getOpAnchorName() << "' pass manager on '"
853 << op->getName() << "' op";
855 // Register all dialects for the current pipeline.
856 DialectRegistry dependentDialects;
857 getDependentDialects(dependentDialects);
858 context->appendDialectRegistry(dependentDialects);
859 for (StringRef name : dependentDialects.getDialectNames())
860 context->getOrLoadDialect(name);
862 // Before running, make sure to finalize the pipeline pass list.
863 if (failed(getImpl().finalizePassList(context)))
864 return failure();
866 // Notify the context that we start running a pipeline for bookkeeping.
867 context->enterMultiThreadedExecution();
869 // Initialize all of the passes within the pass manager with a new generation.
870 llvm::hash_code newInitKey = context->getRegistryHash();
871 llvm::hash_code pipelineKey = hash();
872 if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) {
873 if (failed(initialize(context, impl->initializationGeneration + 1)))
874 return failure();
875 initializationKey = newInitKey;
876 pipelineKey = pipelineInitializationKey;
879 // Construct a top level analysis manager for the pipeline.
880 ModuleAnalysisManager am(op, instrumentor.get());
882 // If reproducer generation is enabled, run the pass manager with crash
883 // handling enabled.
884 LogicalResult result =
885 crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am);
887 // Notify the context that the run is done.
888 context->exitMultiThreadedExecution();
890 // Dump all of the pass statistics if necessary.
891 if (passStatisticsMode)
892 dumpStatistics();
893 return result;
896 /// Add the provided instrumentation to the pass manager.
897 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
898 if (!instrumentor)
899 instrumentor = std::make_unique<PassInstrumentor>();
901 instrumentor->addInstrumentation(std::move(pi));
904 LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) {
905 return OpToOpPassAdaptor::runPipeline(*this, op, am, verifyPasses,
906 impl->initializationGeneration);
909 //===----------------------------------------------------------------------===//
910 // AnalysisManager
911 //===----------------------------------------------------------------------===//
913 /// Get an analysis manager for the given operation, which must be a proper
914 /// descendant of the current operation represented by this analysis manager.
915 AnalysisManager AnalysisManager::nest(Operation *op) {
916 Operation *currentOp = impl->getOperation();
917 assert(currentOp->isProperAncestor(op) &&
918 "expected valid descendant operation");
920 // Check for the base case where the provided operation is immediately nested.
921 if (currentOp == op->getParentOp())
922 return nestImmediate(op);
924 // Otherwise, we need to collect all ancestors up to the current operation.
925 SmallVector<Operation *, 4> opAncestors;
926 do {
927 opAncestors.push_back(op);
928 op = op->getParentOp();
929 } while (op != currentOp);
931 AnalysisManager result = *this;
932 for (Operation *op : llvm::reverse(opAncestors))
933 result = result.nestImmediate(op);
934 return result;
937 /// Get an analysis manager for the given immediately nested child operation.
938 AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
939 assert(impl->getOperation() == op->getParentOp() &&
940 "expected immediate child operation");
942 auto [it, inserted] = impl->childAnalyses.try_emplace(op);
943 if (inserted)
944 it->second = std::make_unique<NestedAnalysisMap>(op, impl);
945 return {it->second.get()};
948 /// Invalidate any non preserved analyses.
949 void detail::NestedAnalysisMap::invalidate(
950 const detail::PreservedAnalyses &pa) {
951 // If all analyses were preserved, then there is nothing to do here.
952 if (pa.isAll())
953 return;
955 // Invalidate the analyses for the current operation directly.
956 analyses.invalidate(pa);
958 // If no analyses were preserved, then just simply clear out the child
959 // analysis results.
960 if (pa.isNone()) {
961 childAnalyses.clear();
962 return;
965 // Otherwise, invalidate each child analysis map.
966 SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
967 while (!mapsToInvalidate.empty()) {
968 auto *map = mapsToInvalidate.pop_back_val();
969 for (auto &analysisPair : map->childAnalyses) {
970 analysisPair.second->invalidate(pa);
971 if (!analysisPair.second->childAnalyses.empty())
972 mapsToInvalidate.push_back(analysisPair.second.get());
977 //===----------------------------------------------------------------------===//
978 // PassInstrumentation
979 //===----------------------------------------------------------------------===//
981 PassInstrumentation::~PassInstrumentation() = default;
983 void PassInstrumentation::runBeforePipeline(
984 std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
986 void PassInstrumentation::runAfterPipeline(
987 std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
989 //===----------------------------------------------------------------------===//
990 // PassInstrumentor
991 //===----------------------------------------------------------------------===//
993 namespace mlir {
994 namespace detail {
995 struct PassInstrumentorImpl {
996 /// Mutex to keep instrumentation access thread-safe.
997 llvm::sys::SmartMutex<true> mutex;
999 /// Set of registered instrumentations.
1000 std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
1002 } // namespace detail
1003 } // namespace mlir
1005 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
1006 PassInstrumentor::~PassInstrumentor() = default;
1008 /// See PassInstrumentation::runBeforePipeline for details.
1009 void PassInstrumentor::runBeforePipeline(
1010 std::optional<OperationName> name,
1011 const PassInstrumentation::PipelineParentInfo &parentInfo) {
1012 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1013 for (auto &instr : impl->instrumentations)
1014 instr->runBeforePipeline(name, parentInfo);
1017 /// See PassInstrumentation::runAfterPipeline for details.
1018 void PassInstrumentor::runAfterPipeline(
1019 std::optional<OperationName> name,
1020 const PassInstrumentation::PipelineParentInfo &parentInfo) {
1021 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1022 for (auto &instr : llvm::reverse(impl->instrumentations))
1023 instr->runAfterPipeline(name, parentInfo);
1026 /// See PassInstrumentation::runBeforePass for details.
1027 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
1028 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1029 for (auto &instr : impl->instrumentations)
1030 instr->runBeforePass(pass, op);
1033 /// See PassInstrumentation::runAfterPass for details.
1034 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
1035 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1036 for (auto &instr : llvm::reverse(impl->instrumentations))
1037 instr->runAfterPass(pass, op);
1040 /// See PassInstrumentation::runAfterPassFailed for details.
1041 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
1042 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1043 for (auto &instr : llvm::reverse(impl->instrumentations))
1044 instr->runAfterPassFailed(pass, op);
1047 /// See PassInstrumentation::runBeforeAnalysis for details.
1048 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
1049 Operation *op) {
1050 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1051 for (auto &instr : impl->instrumentations)
1052 instr->runBeforeAnalysis(name, id, op);
1055 /// See PassInstrumentation::runAfterAnalysis for details.
1056 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
1057 Operation *op) {
1058 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1059 for (auto &instr : llvm::reverse(impl->instrumentations))
1060 instr->runAfterAnalysis(name, id, op);
1063 /// Add the given instrumentation to the collection.
1064 void PassInstrumentor::addInstrumentation(
1065 std::unique_ptr<PassInstrumentation> pi) {
1066 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1067 impl->instrumentations.emplace_back(std::move(pi));