1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Verifier.h"
14 #include "mlir/Parser/Parser.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Support/FileUtilities.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/CrashRecoveryContext.h"
22 #include "llvm/Support/ManagedStatic.h"
23 #include "llvm/Support/Mutex.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/Support/Threading.h"
26 #include "llvm/Support/ToolOutputFile.h"
29 using namespace mlir::detail
;
31 //===----------------------------------------------------------------------===//
32 // RecoveryReproducerContext
33 //===----------------------------------------------------------------------===//
37 /// This class contains all of the context for generating a recovery reproducer.
38 /// Each recovery context is registered globally to allow for generating
39 /// reproducers when a signal is raised, such as a segfault.
/// Each context keeps its own snapshot of the IR: the constructor stores a
/// clone of the operation it is given and the destructor erases that clone.
40 struct RecoveryReproducerContext
{
/// Create a context for `op` and the pipeline described by `passPipelineStr`.
/// `streamFactory` is stored by reference, so it must outlive this context.
41 RecoveryReproducerContext(std::string passPipelineStr
, Operation
*op
,
42 ReproducerStreamFactory
&streamFactory
,
// NOTE(review): the rest of this parameter list is not visible in this view;
// the out-of-line definition takes a trailing `bool verifyPasses` - confirm.
44 ~RecoveryReproducerContext();
46 /// Generate a reproducer with the current context.
/// A status message describing the outcome is appended to `description`.
47 void generate(std::string
&description
);
49 /// Disable this reproducer context. This prevents the context from generating
50 /// a reproducer in the event of a crash.
53 /// Enable a previously disabled reproducer context.
57 /// This function is invoked in the event of a crash.
/// It generates a reproducer for every context currently in `reproducerSet`.
58 static void crashHandler(void *);
60 /// Register a signal handler to run in the event of a crash.
/// The handler is only registered once, however often this is called.
61 static void registerSignalHandler();
63 /// The textual description of the currently executing pipeline.
64 std::string pipelineElements
;
66 /// The MLIR operation representing the IR before the crash.
/// This is a clone owned by the context, not the operation being transformed.
67 Operation
*preCrashOperation
;
69 /// The factory for the reproducer output stream to use when generating the
/// reproducer.
71 ReproducerStreamFactory
&streamFactory
;
73 /// Various pass manager and context flags.
77 /// The current set of active reproducer contexts. This is used in the event
78 /// of a crash. This is not thread_local as the pass manager may produce any
79 /// number of child threads. This uses a set to allow for multiple MLIR pass
80 /// managers to be running at the same time.
/// `reproducerMutex` is taken by enable() and disable() while they update
/// the set.
81 static llvm::ManagedStatic
<llvm::sys::SmartMutex
<true>> reproducerMutex
;
82 static llvm::ManagedStatic
<
83 llvm::SmallSetVector
<RecoveryReproducerContext
*, 1>>
/// Out-of-line definitions of the static members of RecoveryReproducerContext:
/// the mutex locked by enable()/disable(), and the set of contexts that
/// crashHandler() walks.
89 llvm::ManagedStatic
<llvm::sys::SmartMutex
<true>>
90 RecoveryReproducerContext::reproducerMutex
;
91 llvm::ManagedStatic
<llvm::SmallSetVector
<RecoveryReproducerContext
*, 1>>
92 RecoveryReproducerContext::reproducerSet
;
/// Snapshot everything needed to emit a reproducer later: the pipeline string
/// is moved in, `op` is cloned so the cached IR is independent of later
/// changes to `op`, and the threading flag is sampled from the context of
/// `op` at construction time.
94 RecoveryReproducerContext::RecoveryReproducerContext(
95 std::string passPipelineStr
, Operation
*op
,
96 ReproducerStreamFactory
&streamFactory
, bool verifyPasses
)
97 : pipelineElements(std::move(passPipelineStr
)),
98 preCrashOperation(op
->clone()), streamFactory(streamFactory
),
// `disableThreads` is the negation of the context's multithreading setting.
99 disableThreads(!op
->getContext()->isMultithreadingEnabled()),
100 verifyPasses(verifyPasses
) {
/// Release the IR snapshot that the constructor created with `clone()`.
104 RecoveryReproducerContext::~RecoveryReproducerContext() {
105 // Erase the cloned preCrash IR that we cached.
106 preCrashOperation
->erase();
/// Write a reproducer for `op` to a stream obtained from `factory`, appending
/// a status message to `description`.
///
/// The printed IR carries an `mlir_reproducer` resource with three entries:
/// `pipeline` (`pipelineElements` wrapped in the name of `op`),
/// `disable_threading` and `verify_each`.
110 static void appendReproducer(std::string
&description
, Operation
*op
,
111 const ReproducerStreamFactory
&factory
,
112 const std::string
&pipelineElements
,
113 bool disableThreads
, bool verifyPasses
) {
// Everything streamed into `descOS` is appended to the caller's string.
114 llvm::raw_string_ostream
descOS(description
);
116 // Try to create a new output stream for this crash reproducer.
118 std::unique_ptr
<ReproducerStream
> stream
= factory(error
);
// The factory receives `error` by reference; its text is forwarded to the
// caller as part of the description.
120 descOS
<< "failed to create output stream: " << error
;
123 descOS
<< "reproducer generated at `" << stream
->description() << "`";
// The recorded pipeline is anchored on the name of the operation being
// printed: `<op-name>(<pipelineElements>)`.
125 std::string pipeline
=
126 (op
->getName().getStringRef() + "(" + pipelineElements
+ ")").str();
// Register a printer for the `mlir_reproducer` resource key so that the
// pipeline and flags are emitted when `op` is printed with `state` below.
128 state
.attachResourcePrinter(
129 "mlir_reproducer", [&](Operation
*op
, AsmResourceBuilder
&builder
) {
130 builder
.buildString("pipeline", pipeline
);
131 builder
.buildBool("disable_threading", disableThreads
);
132 builder
.buildBool("verify_each", verifyPasses
);
135 // Output the .mlir module.
136 op
->print(stream
->os(), state
);
/// Emit a reproducer for the cached pre-crash IR by forwarding this context's
/// stored pipeline string, stream factory and flags to appendReproducer().
/// The resulting status message is appended to `description`.
139 void RecoveryReproducerContext::generate(std::string
&description
) {
140 appendReproducer(description
, preCrashOperation
, streamFactory
,
141 pipelineElements
, disableThreads
, verifyPasses
);
/// Remove this context from the global set of active contexts. When the set
/// becomes empty, llvm::CrashRecoveryContext is disabled again.
144 void RecoveryReproducerContext::disable() {
// The set is shared between all contexts, so it is updated under the mutex.
145 llvm::sys::SmartScopedLock
<true> lock(*reproducerMutex
);
146 reproducerSet
->remove(this);
147 if (reproducerSet
->empty())
148 llvm::CrashRecoveryContext::Disable();
/// Add this context to the global set of active contexts. If the set was
/// empty beforehand, llvm::CrashRecoveryContext is enabled first.
151 void RecoveryReproducerContext::enable() {
// The set is shared between all contexts, so it is updated under the mutex.
152 llvm::sys::SmartScopedLock
<true> lock(*reproducerMutex
);
153 if (reproducerSet
->empty())
154 llvm::CrashRecoveryContext::Enable();
// Called on every enable(); the function itself guards against registering
// the handler more than once.
155 registerSignalHandler();
156 reproducerSet
->insert(this);
/// Signal callback: generate a reproducer for every registered context and
/// emit an error at the location of each context's cached operation. The
/// `void *` cookie is unused.
// NOTE(review): `reproducerSet` is read here without taking `reproducerMutex`,
// presumably because this runs from a signal handler - confirm this is
// intentional.
159 void RecoveryReproducerContext::crashHandler(void *) {
160 // Walk the current stack of contexts and generate a reproducer for each one.
161 // We can't know for certain which one was the cause, so we need to generate
162 // a reproducer for all of them.
163 for (RecoveryReproducerContext
*context
: *reproducerSet
) {
164 std::string description
;
165 context
->generate(description
);
167 // Emit an error using information only available within the context.
168 emitError(context
->preCrashOperation
->getLoc())
169 << "A signal was caught while processing the MLIR module:"
170 << description
<< "; marking pass as failed";
/// Install crashHandler() through llvm::sys::AddSignalHandler with a null
/// cookie.
174 void RecoveryReproducerContext::registerSignalHandler() {
175 // Ensure that the handler is only registered once.
// The registration is the left operand of a comma expression used to
// initialize a function-local static, so it runs only the first time this
// function is called; the stored value itself is always `false`.
176 static bool registered
=
177 (llvm::sys::AddSignalHandler(crashHandler
, nullptr), false);
181 //===----------------------------------------------------------------------===//
182 // PassCrashReproducerGenerator
183 //===----------------------------------------------------------------------===//
/// Private state of PassCrashReproducerGenerator.
185 struct PassCrashReproducerGenerator::Impl
{
/// The factory is taken by reference but copied into the member below.
186 Impl(ReproducerStreamFactory
&streamFactory
, bool localReproducer
)
187 : streamFactory(streamFactory
), localReproducer(localReproducer
) {}
189 /// The factory to use when generating a crash reproducer.
/// Stored by value; the recovery contexts created by the generator are handed
/// this member and keep a reference to it.
190 ReproducerStreamFactory streamFactory
;
192 /// Flag indicating if reproducer generation should be localized to the
/// most recently executing pass rather than to the whole pipeline.
194 bool localReproducer
= false;
196 /// A record of all of the currently active reproducer contexts.
197 SmallVector
<std::unique_ptr
<RecoveryReproducerContext
>> activeContexts
;
199 /// The set of all currently running passes. Every pass is recorded here when
200 /// a reproducer is prepared for it, whether or not `localReproducer` is set.
/// When `localReproducer` is true each pass also gets its own recovery
/// context, and finalize() expects one entry here per active context.
201 SetVector
<std::pair
<Pass
*, Operation
*>> runningPasses
;
203 /// Various pass manager flags that get emitted when generating a reproducer.
204 bool pmFlagVerifyPasses
= false;
/// Create a generator that writes reproducers to streams produced by
/// `streamFactory`. `localReproducer` selects per-pass (local) reproducers
/// instead of a single reproducer for the whole pipeline. All state lives in
/// the heap-allocated Impl.
207 PassCrashReproducerGenerator::PassCrashReproducerGenerator(
208 ReproducerStreamFactory
&streamFactory
, bool localReproducer
)
209 : impl(std::make_unique
<Impl
>(streamFactory
, localReproducer
)) {}
/// Defaulted destructor, defined in this file where Impl is a complete type.
210 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
/// Prepare the generator for a run of `passes` on `op`.
///
/// Enables llvm::CrashRecoveryContext, records the verifier flag for later
/// reproducers and, unless local reproducers were requested, creates the single
/// reproducer context covering the whole pipeline. Local reproducers require
/// multithreading to be disabled on the context of `op` (asserted).
212 void PassCrashReproducerGenerator::initialize(
213 iterator_range
<PassManager::pass_iterator
> passes
, Operation
*op
,
214 bool pmFlagVerifyPasses
) {
215 assert((!impl
->localReproducer
||
216 !op
->getContext()->isMultithreadingEnabled()) &&
217 "expected multi-threading to be disabled when generating a local "
220 llvm::CrashRecoveryContext::Enable();
221 impl
->pmFlagVerifyPasses
= pmFlagVerifyPasses
;
223 // If we aren't generating a local reproducer, prepare a reproducer for the
224 // given top-level operation.
225 if (!impl
->localReproducer
)
226 prepareReproducerFor(passes
, op
);
/// Append a short description of a running pass to the diagnostic `os`, in
/// the form: `<pass-name>` on '<op-name>' operation. If the operation
/// implements SymbolOpInterface, `: @<symbol-name>` is appended as well.
230 formatPassOpReproducerMessage(Diagnostic
&os
,
231 std::pair
<Pass
*, Operation
*> passOpPair
) {
232 os
<< "`" << passOpPair
.first
->getName() << "` on "
233 << "'" << passOpPair
.second
->getName() << "' operation";
// Name the symbol when there is one, to tell apart ops of the same kind.
234 if (SymbolOpInterface symbol
= dyn_cast
<SymbolOpInterface
>(passOpPair
.second
))
235 os
<< ": @" << symbol
.getName();
/// Finish a pass manager run and, if it failed, emit a reproducer.
///
/// `rootOp` supplies the location for the emitted error and `executionResult`
/// is the result of the run. Nothing is generated when there are no active
/// contexts or when the run succeeded. Otherwise an error is emitted with a
/// note that names the running pass(es) and carries the reproducer status
/// message. In global mode the single active context is used and all running
/// passes are listed; in local mode the most recent context and running pass
/// are used. The recorded contexts and passes are cleared afterwards.
238 void PassCrashReproducerGenerator::finalize(Operation
*rootOp
,
239 LogicalResult executionResult
) {
240 // Don't generate a reproducer if we have no active contexts.
241 if (impl
->activeContexts
.empty())
244 // If the pass manager execution succeeded, we don't generate any reproducers.
245 if (succeeded(executionResult
))
246 return impl
->activeContexts
.clear();
248 InFlightDiagnostic diag
= emitError(rootOp
->getLoc())
249 << "Failures have been detected while "
250 "processing an MLIR pass pipeline";
252 // If we are generating a global reproducer, we include all of the running
253 // passes in the error message for the only active context.
254 if (!impl
->localReproducer
) {
255 assert(impl
->activeContexts
.size() == 1 && "expected one active context");
257 // Generate the reproducer.
258 std::string description
;
259 impl
->activeContexts
.front()->generate(description
);
261 // Emit an error to the user.
// The note is bound by reference so the streamed text lands on `diag`.
262 Diagnostic
&note
= diag
.attachNote() << "Pipeline failed while executing [";
263 llvm::interleaveComma(impl
->runningPasses
, note
,
264 [&](const std::pair
<Pass
*, Operation
*> &value
) {
265 formatPassOpReproducerMessage(note
, value
);
267 note
<< "]: " << description
;
268 impl
->runningPasses
.clear();
269 impl
->activeContexts
.clear();
273 // If we were generating a local reproducer, we generate a reproducer for the
274 // most recently executing pass using the matching entry from `runningPasses`
275 // to generate a localized diagnostic message.
276 assert(impl
->activeContexts
.size() == impl
->runningPasses
.size() &&
277 "expected running passes to match active contexts");
279 // Generate the reproducer.
280 RecoveryReproducerContext
&reproducerContext
= *impl
->activeContexts
.back();
281 std::string description
;
282 reproducerContext
.generate(description
);
284 // Emit an error to the user.
// The note is bound by reference so the streamed text lands on `diag`.
285 Diagnostic
&note
= diag
.attachNote() << "Pipeline failed while executing ";
286 formatPassOpReproducerMessage(note
, impl
->runningPasses
.back());
287 note
<< ": " << description
;
289 impl
->activeContexts
.clear();
290 impl
->runningPasses
.clear();
/// Record that `pass` is about to run and, when local reproducers are
/// enabled, create a dedicated recovery context for it.
///
/// The pass is always added to `runningPasses`. In local mode the currently
/// active context (if any) is disabled, and a new context is pushed whose
/// pipeline string nests the pass inside the names collected in `scopes`.
293 void PassCrashReproducerGenerator::prepareReproducerFor(Pass
*pass
,
295 // If not tracking local reproducers, we simply remember that this pass is
// running.
297 impl
->runningPasses
.insert(std::make_pair(pass
, op
));
298 if (!impl
->localReproducer
)
301 // Disable the current pass recovery context, if there is one. This may happen
302 // in the case of dynamic pass pipelines.
303 if (!impl
->activeContexts
.empty())
304 impl
->activeContexts
.back()->disable();
306 // Collect all of the parent scopes of this operation.
307 SmallVector
<OperationName
> scopes
;
308 while (Operation
*parentOp
= op
->getParentOp()) {
309 scopes
.push_back(op
->getName());
313 // Emit a pass pipeline string for the current pass running on the current
// operation.
316 llvm::raw_string_ostream
passOS(passStr
);
// The scopes are printed in reverse order of collection, each followed by an
// opening parenthesis, ahead of the textual form of the pass itself.
317 for (OperationName scope
: llvm::reverse(scopes
))
318 passOS
<< scope
<< "(";
319 pass
->printAsTextualPipeline(passOS
);
// One iteration per collected scope (presumably emitting the matching closing
// parentheses - the loop body is not visible in this view).
320 for (unsigned i
= 0, e
= scopes
.size(); i
< e
; ++i
)
323 impl
->activeContexts
.push_back(std::make_unique
<RecoveryReproducerContext
>(
324 passStr
, op
, impl
->streamFactory
, impl
->pmFlagVerifyPasses
));
/// Create a recovery context covering the whole list of `passes` running on
/// `op`. The pipeline string is the comma-separated textual form of each pass.
326 void PassCrashReproducerGenerator::prepareReproducerFor(
327 iterator_range
<PassManager::pass_iterator
> passes
, Operation
*op
) {
329 llvm::raw_string_ostream
passOS(passStr
);
330 llvm::interleaveComma(
331 passes
, passOS
, [&](Pass
&pass
) { pass
.printAsTextualPipeline(passOS
); });
// The new context clones `op` and keeps a reference to the generator's
// stream factory.
333 impl
->activeContexts
.push_back(std::make_unique
<RecoveryReproducerContext
>(
334 passStr
, op
, impl
->streamFactory
, impl
->pmFlagVerifyPasses
));
/// Undo prepareReproducerFor() for a pass that has finished: the
/// (pass, op) pair is removed from `runningPasses` and, in local mode, the
/// most recent recovery context is dropped and the previous one (if any) is
/// re-enabled.
337 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass
*pass
,
339 // We only pop the active context if we are tracking local reproducers.
340 impl
->runningPasses
.remove(std::make_pair(pass
, op
));
341 if (impl
->localReproducer
) {
342 impl
->activeContexts
.pop_back();
344 // Re-enable the previous pass recovery context, if there was one. This may
345 // happen in the case of dynamic pass pipelines.
346 if (!impl
->activeContexts
.empty())
347 impl
->activeContexts
.back()->enable();
351 //===----------------------------------------------------------------------===//
352 // CrashReproducerInstrumentation
353 //===----------------------------------------------------------------------===//
/// Pass instrumentation that forwards pass lifecycle events to a
/// PassCrashReproducerGenerator. The generator is held by reference and must
/// outlive the instrumentation.
356 struct CrashReproducerInstrumentation
: public PassInstrumentation
{
357 CrashReproducerInstrumentation(PassCrashReproducerGenerator
&generator
)
358 : generator(generator
) {}
359 ~CrashReproducerInstrumentation() override
= default;
/// Register `pass` with the generator before it runs. OpToOpPassAdaptor
/// instances are not reported.
361 void runBeforePass(Pass
*pass
, Operation
*op
) override
{
362 if (!isa
<OpToOpPassAdaptor
>(pass
))
363 generator
.prepareReproducerFor(pass
, op
);
/// Unregister `pass` from the generator after it ran. OpToOpPassAdaptor
/// instances are not reported.
366 void runAfterPass(Pass
*pass
, Operation
*op
) override
{
367 if (!isa
<OpToOpPassAdaptor
>(pass
))
368 generator
.removeLastReproducerFor(pass
, op
);
/// Finalize the generator with a failure result, using `op` for the
/// diagnostic location, and remember that a failure was already handled.
371 void runAfterPassFailed(Pass
*pass
, Operation
*op
) override
{
372 // Only generate one reproducer per crash reproducer instrumentation.
376 alreadyFailed
= true;
377 generator
.finalize(op
, /*executionResult=*/failure());
381 /// The generator used to create crash reproducers.
382 PassCrashReproducerGenerator
&generator
;
/// Set once a pass failure has been reported to the generator.
383 bool alreadyFailed
= false;
387 //===----------------------------------------------------------------------===//
388 // FileReproducerStream
389 //===----------------------------------------------------------------------===//
392 /// This class represents a default instance of mlir::ReproducerStream
393 /// that is backed by a file.
/// The stream owns the ToolOutputFile it is constructed with.
394 struct FileReproducerStream
: public mlir::ReproducerStream
{
395 FileReproducerStream(std::unique_ptr
<llvm::ToolOutputFile
> outputFile
)
396 : outputFile(std::move(outputFile
)) {}
/// Calls keep() on the output file when the stream is destroyed.
397 ~FileReproducerStream() override
{ outputFile
->keep(); }
399 /// Returns a description of the reproducer stream.
/// For this stream it is the filename of the output file.
400 StringRef
description() override
{ return outputFile
->getFilename(); }
402 /// Returns the stream on which to output the reproducer.
403 raw_ostream
&os() override
{ return outputFile
->os(); }
406 /// ToolOutputFile corresponding to opened `filename`.
407 std::unique_ptr
<llvm::ToolOutputFile
> outputFile
= nullptr;
411 //===----------------------------------------------------------------------===//
413 //===----------------------------------------------------------------------===//
/// Run the pass pipeline on `op` with crash reproducer generation active.
///
/// The generator is initialized first, the passes are then run through
/// llvm::CrashRecoveryContext::RunSafelyOnThread, and finally the generator is
/// finalized with the result. Returns the result of running the passes.
415 LogicalResult
PassManager::runWithCrashRecovery(Operation
*op
,
416 AnalysisManager am
) {
417 crashReproGenerator
->initialize(getPasses(), op
, verifyPasses
);
419 // Safely invoke the passes within a recovery context.
// The result starts out as failure and is only overwritten when the lambda
// assigns the value returned by runPasses.
420 LogicalResult passManagerResult
= failure();
421 llvm::CrashRecoveryContext recoveryContext
;
422 recoveryContext
.RunSafelyOnThread(
423 [&] { passManagerResult
= runPasses(op
, am
); });
424 crashReproGenerator
->finalize(op
, passManagerResult
);
425 return passManagerResult
;
/// Build a stream factory that opens `outputFile` through
/// mlir::openOutputFile and wraps the result in a FileReproducerStream. The
/// factory receives an `error` string by reference; mlir::openOutputFile is
/// given its address, and the text is prefixed with
/// "Failed to create reproducer stream: ".
428 static ReproducerStreamFactory
429 makeReproducerStreamFactory(StringRef outputFile
) {
430 // Capture the filename by value in case outputFile is out of scope when
// the returned factory is invoked.
432 std::string filename
= outputFile
.str();
433 return [filename
](std::string
&error
) -> std::unique_ptr
<ReproducerStream
> {
434 std::unique_ptr
<llvm::ToolOutputFile
> outputFile
=
435 mlir::openOutputFile(filename
, &error
);
437 error
= "Failed to create reproducer stream: " + error
;
440 return std::make_unique
<FileReproducerStream
>(std::move(outputFile
));
/// Declaration only; the definition is not in this view. makeReproducer()
/// calls it to print `passes`, together with `anchorName`, to `os`.
444 void printAsTextualPipeline(
445 raw_ostream
&os
, StringRef anchorName
,
446 const llvm::iterator_range
<OpPassManager::pass_iterator
> &passes
);
/// Write a reproducer for `op` and the given `passes` to the file named
/// `outputFile`. The pipeline text is produced by ::printAsTextualPipeline and
/// the reproducer itself by appendReproducer(), which fills `description`.
// NOTE(review): appendReproducer() wraps the pipeline string it receives in
// `<op-name>(...)`. If ::printAsTextualPipeline already emits the
// `anchorName(...)` wrapper, the recorded pipeline would be nested twice -
// confirm against its definition.
448 std::string
mlir::makeReproducer(
449 StringRef anchorName
,
450 const llvm::iterator_range
<OpPassManager::pass_iterator
> &passes
,
451 Operation
*op
, StringRef outputFile
, bool disableThreads
,
454 std::string description
;
455 std::string pipelineStr
;
456 llvm::raw_string_ostream
passOS(pipelineStr
);
457 ::printAsTextualPipeline(passOS
, anchorName
, passes
);
458 appendReproducer(description
, op
, makeReproducerStreamFactory(outputFile
),
459 pipelineStr
, disableThreads
, verifyPasses
);
/// Convenience overload: enable crash reproducer generation with reproducers
/// written to the file named `outputFile`. Delegates to the factory-based
/// overload using a file-backed stream factory.
463 void PassManager::enableCrashReproducerGeneration(StringRef outputFile
,
464 bool genLocalReproducer
) {
465 enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile
),
/// Enable crash reproducer generation using `factory` to create the output
/// streams. `genLocalReproducer` requests per-pass reproducers.
///
/// Asserts that no generator was set up before. Requesting local reproducers
/// while multithreading is enabled on the context is a fatal error.
469 void PassManager::enableCrashReproducerGeneration(
470 ReproducerStreamFactory factory
, bool genLocalReproducer
) {
471 assert(!crashReproGenerator
&&
472 "crash reproducer has already been initialized");
473 if (genLocalReproducer
&& getContext()->isMultithreadingEnabled())
474 llvm::report_fatal_error(
475 "Local crash reproduction can't be setup on a "
476 "pass-manager without disabling multi-threading first.");
478 crashReproGenerator
= std::make_unique
<PassCrashReproducerGenerator
>(
479 factory
, genLocalReproducer
);
// The instrumentation created here only keeps a reference to the generator
// owned by this pass manager.
481 std::make_unique
<CrashReproducerInstrumentation
>(*crashReproGenerator
));
484 //===----------------------------------------------------------------------===//
486 //===----------------------------------------------------------------------===//
/// Register a parser for the `mlir_reproducer` resource on `config`.
///
/// The recognized keys are `pipeline` (string), `disable_threading` (bool) and
/// `verify_each` (bool); a successfully parsed value is stored in the matching
/// member of this object. Any other key produces the error
/// "unknown 'mlir_reproducer' resource key '<key>'".
488 void PassReproducerOptions::attachResourceParser(ParserConfig
&config
) {
// The callback captures `this`; presumably this object must stay alive for as
// long as `config` may invoke the parser - confirm against callers.
489 auto parseFn
= [this](AsmParsedResourceEntry
&entry
) -> LogicalResult
{
490 if (entry
.getKey() == "pipeline") {
491 FailureOr
<std::string
> value
= entry
.parseAsString();
492 if (succeeded(value
))
493 this->pipeline
= std::move(*value
);
496 if (entry
.getKey() == "disable_threading") {
497 FailureOr
<bool> value
= entry
.parseAsBool();
498 if (succeeded(value
))
499 this->disableThreading
= *value
;
502 if (entry
.getKey() == "verify_each") {
503 FailureOr
<bool> value
= entry
.parseAsBool();
504 if (succeeded(value
))
505 this->verifyEach
= *value
;
508 return entry
.emitError() << "unknown 'mlir_reproducer' resource key '"
509 << entry
.getKey() << "'";
511 config
.attachResourceParser("mlir_reproducer", parseFn
);
514 LogicalResult
PassReproducerOptions::apply(PassManager
&pm
) const {
515 if (pipeline
.has_value()) {
516 FailureOr
<OpPassManager
> reproPm
= parsePassPipeline(*pipeline
);
519 static_cast<OpPassManager
&>(pm
) = std::move(*reproPm
);
522 if (disableThreading
.has_value())
523 pm
.getContext()->disableMultithreading(*disableThreading
);
525 if (verifyEach
.has_value())
526 pm
.enableVerifier(*verifyEach
);