1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Verifier.h"
14 #include "mlir/Parser/Parser.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Support/FileUtilities.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/CrashRecoveryContext.h"
22 #include "llvm/Support/Mutex.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/Threading.h"
25 #include "llvm/Support/ToolOutputFile.h"
28 using namespace mlir::detail
;
30 //===----------------------------------------------------------------------===//
31 // RecoveryReproducerContext
32 //===----------------------------------------------------------------------===//
36 /// This class contains all of the context for generating a recovery reproducer.
37 /// Each recovery context is registered globally to allow for generating
38 /// reproducers when a signal is raised, such as a segfault.
39 struct RecoveryReproducerContext
{
40 RecoveryReproducerContext(std::string passPipelineStr
, Operation
*op
,
41 ReproducerStreamFactory
&streamFactory
,
43 ~RecoveryReproducerContext();
45 /// Generate a reproducer with the current context.
46 void generate(std::string
&description
);
48 /// Disable this reproducer context. This prevents the context from generating
49 /// a reproducer in the result of a crash.
52 /// Enable a previously disabled reproducer context.
56 /// This function is invoked in the event of a crash.
57 static void crashHandler(void *);
59 /// Register a signal handler to run in the event of a crash.
60 static void registerSignalHandler();
62 /// The textual description of the currently executing pipeline.
63 std::string pipelineElements
;
65 /// The MLIR operation representing the IR before the crash.
66 Operation
*preCrashOperation
;
68 /// The factory for the reproducer output stream to use when generating the
70 ReproducerStreamFactory
&streamFactory
;
72 /// Various pass manager and context flags.
76 /// The current set of active reproducer contexts. This is used in the event
77 /// of a crash. This is not thread_local as the pass manager may produce any
78 /// number of child threads. This uses a set to allow for multiple MLIR pass
79 /// managers to be running at the same time.
80 static llvm::ManagedStatic
<llvm::sys::SmartMutex
<true>> reproducerMutex
;
81 static llvm::ManagedStatic
<
82 llvm::SmallSetVector
<RecoveryReproducerContext
*, 1>>
88 llvm::ManagedStatic
<llvm::sys::SmartMutex
<true>>
89 RecoveryReproducerContext::reproducerMutex
;
90 llvm::ManagedStatic
<llvm::SmallSetVector
<RecoveryReproducerContext
*, 1>>
91 RecoveryReproducerContext::reproducerSet
;
93 RecoveryReproducerContext::RecoveryReproducerContext(
94 std::string passPipelineStr
, Operation
*op
,
95 ReproducerStreamFactory
&streamFactory
, bool verifyPasses
)
96 : pipelineElements(std::move(passPipelineStr
)),
97 preCrashOperation(op
->clone()), streamFactory(streamFactory
),
98 disableThreads(!op
->getContext()->isMultithreadingEnabled()),
99 verifyPasses(verifyPasses
) {
103 RecoveryReproducerContext::~RecoveryReproducerContext() {
104 // Erase the cloned preCrash IR that we cached.
105 preCrashOperation
->erase();
109 static void appendReproducer(std::string
&description
, Operation
*op
,
110 const ReproducerStreamFactory
&factory
,
111 const std::string
&pipelineElements
,
112 bool disableThreads
, bool verifyPasses
) {
113 llvm::raw_string_ostream
descOS(description
);
115 // Try to create a new output stream for this crash reproducer.
117 std::unique_ptr
<ReproducerStream
> stream
= factory(error
);
119 descOS
<< "failed to create output stream: " << error
;
122 descOS
<< "reproducer generated at `" << stream
->description() << "`";
124 std::string pipeline
=
125 (op
->getName().getStringRef() + "(" + pipelineElements
+ ")").str();
127 state
.attachResourcePrinter(
128 "mlir_reproducer", [&](Operation
*op
, AsmResourceBuilder
&builder
) {
129 builder
.buildString("pipeline", pipeline
);
130 builder
.buildBool("disable_threading", disableThreads
);
131 builder
.buildBool("verify_each", verifyPasses
);
134 // Output the .mlir module.
135 op
->print(stream
->os(), state
);
138 void RecoveryReproducerContext::generate(std::string
&description
) {
139 appendReproducer(description
, preCrashOperation
, streamFactory
,
140 pipelineElements
, disableThreads
, verifyPasses
);
143 void RecoveryReproducerContext::disable() {
144 llvm::sys::SmartScopedLock
<true> lock(*reproducerMutex
);
145 reproducerSet
->remove(this);
146 if (reproducerSet
->empty())
147 llvm::CrashRecoveryContext::Disable();
150 void RecoveryReproducerContext::enable() {
151 llvm::sys::SmartScopedLock
<true> lock(*reproducerMutex
);
152 if (reproducerSet
->empty())
153 llvm::CrashRecoveryContext::Enable();
154 registerSignalHandler();
155 reproducerSet
->insert(this);
158 void RecoveryReproducerContext::crashHandler(void *) {
159 // Walk the current stack of contexts and generate a reproducer for each one.
160 // We can't know for certain which one was the cause, so we need to generate
161 // a reproducer for all of them.
162 for (RecoveryReproducerContext
*context
: *reproducerSet
) {
163 std::string description
;
164 context
->generate(description
);
166 // Emit an error using information only available within the context.
167 emitError(context
->preCrashOperation
->getLoc())
168 << "A signal was caught while processing the MLIR module:"
169 << description
<< "; marking pass as failed";
173 void RecoveryReproducerContext::registerSignalHandler() {
174 // Ensure that the handler is only registered once.
175 static bool registered
=
176 (llvm::sys::AddSignalHandler(crashHandler
, nullptr), false);
180 //===----------------------------------------------------------------------===//
181 // PassCrashReproducerGenerator
182 //===----------------------------------------------------------------------===//
184 struct PassCrashReproducerGenerator::Impl
{
185 Impl(ReproducerStreamFactory
&streamFactory
, bool localReproducer
)
186 : streamFactory(streamFactory
), localReproducer(localReproducer
) {}
188 /// The factory to use when generating a crash reproducer.
189 ReproducerStreamFactory streamFactory
;
191 /// Flag indicating if reproducer generation should be localized to the
193 bool localReproducer
= false;
195 /// A record of all of the currently active reproducer contexts.
196 SmallVector
<std::unique_ptr
<RecoveryReproducerContext
>> activeContexts
;
198 /// The set of all currently running passes. Note: This is not populated when
199 /// `localReproducer` is true, as each pass will get its own recovery context.
200 SetVector
<std::pair
<Pass
*, Operation
*>> runningPasses
;
202 /// Various pass manager flags that get emitted when generating a reproducer.
203 bool pmFlagVerifyPasses
= false;
206 PassCrashReproducerGenerator::PassCrashReproducerGenerator(
207 ReproducerStreamFactory
&streamFactory
, bool localReproducer
)
208 : impl(std::make_unique
<Impl
>(streamFactory
, localReproducer
)) {}
209 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
211 void PassCrashReproducerGenerator::initialize(
212 iterator_range
<PassManager::pass_iterator
> passes
, Operation
*op
,
213 bool pmFlagVerifyPasses
) {
214 assert((!impl
->localReproducer
||
215 !op
->getContext()->isMultithreadingEnabled()) &&
216 "expected multi-threading to be disabled when generating a local "
219 llvm::CrashRecoveryContext::Enable();
220 impl
->pmFlagVerifyPasses
= pmFlagVerifyPasses
;
222 // If we aren't generating a local reproducer, prepare a reproducer for the
223 // given top-level operation.
224 if (!impl
->localReproducer
)
225 prepareReproducerFor(passes
, op
);
229 formatPassOpReproducerMessage(Diagnostic
&os
,
230 std::pair
<Pass
*, Operation
*> passOpPair
) {
231 os
<< "`" << passOpPair
.first
->getName() << "` on "
232 << "'" << passOpPair
.second
->getName() << "' operation";
233 if (SymbolOpInterface symbol
= dyn_cast
<SymbolOpInterface
>(passOpPair
.second
))
234 os
<< ": @" << symbol
.getName();
237 void PassCrashReproducerGenerator::finalize(Operation
*rootOp
,
238 LogicalResult executionResult
) {
239 // Don't generate a reproducer if we have no active contexts.
240 if (impl
->activeContexts
.empty())
243 // If the pass manager execution succeeded, we don't generate any reproducers.
244 if (succeeded(executionResult
))
245 return impl
->activeContexts
.clear();
247 InFlightDiagnostic diag
= emitError(rootOp
->getLoc())
248 << "Failures have been detected while "
249 "processing an MLIR pass pipeline";
251 // If we are generating a global reproducer, we include all of the running
252 // passes in the error message for the only active context.
253 if (!impl
->localReproducer
) {
254 assert(impl
->activeContexts
.size() == 1 && "expected one active context");
256 // Generate the reproducer.
257 std::string description
;
258 impl
->activeContexts
.front()->generate(description
);
260 // Emit an error to the user.
261 Diagnostic
¬e
= diag
.attachNote() << "Pipeline failed while executing [";
262 llvm::interleaveComma(impl
->runningPasses
, note
,
263 [&](const std::pair
<Pass
*, Operation
*> &value
) {
264 formatPassOpReproducerMessage(note
, value
);
266 note
<< "]: " << description
;
267 impl
->runningPasses
.clear();
268 impl
->activeContexts
.clear();
272 // If we were generating a local reproducer, we generate a reproducer for the
273 // most recently executing pass using the matching entry from `runningPasses`
274 // to generate a localized diagnostic message.
275 assert(impl
->activeContexts
.size() == impl
->runningPasses
.size() &&
276 "expected running passes to match active contexts");
278 // Generate the reproducer.
279 RecoveryReproducerContext
&reproducerContext
= *impl
->activeContexts
.back();
280 std::string description
;
281 reproducerContext
.generate(description
);
283 // Emit an error to the user.
284 Diagnostic
¬e
= diag
.attachNote() << "Pipeline failed while executing ";
285 formatPassOpReproducerMessage(note
, impl
->runningPasses
.back());
286 note
<< ": " << description
;
288 impl
->activeContexts
.clear();
289 impl
->runningPasses
.clear();
292 void PassCrashReproducerGenerator::prepareReproducerFor(Pass
*pass
,
294 // If not tracking local reproducers, we simply remember that this pass is
296 impl
->runningPasses
.insert(std::make_pair(pass
, op
));
297 if (!impl
->localReproducer
)
300 // Disable the current pass recovery context, if there is one. This may happen
301 // in the case of dynamic pass pipelines.
302 if (!impl
->activeContexts
.empty())
303 impl
->activeContexts
.back()->disable();
305 // Collect all of the parent scopes of this operation.
306 SmallVector
<OperationName
> scopes
;
307 while (Operation
*parentOp
= op
->getParentOp()) {
308 scopes
.push_back(op
->getName());
312 // Emit a pass pipeline string for the current pass running on the current
315 llvm::raw_string_ostream
passOS(passStr
);
316 for (OperationName scope
: llvm::reverse(scopes
))
317 passOS
<< scope
<< "(";
318 pass
->printAsTextualPipeline(passOS
);
319 for (unsigned i
= 0, e
= scopes
.size(); i
< e
; ++i
)
322 impl
->activeContexts
.push_back(std::make_unique
<RecoveryReproducerContext
>(
323 passOS
.str(), op
, impl
->streamFactory
, impl
->pmFlagVerifyPasses
));
325 void PassCrashReproducerGenerator::prepareReproducerFor(
326 iterator_range
<PassManager::pass_iterator
> passes
, Operation
*op
) {
328 llvm::raw_string_ostream
passOS(passStr
);
329 llvm::interleaveComma(
330 passes
, passOS
, [&](Pass
&pass
) { pass
.printAsTextualPipeline(passOS
); });
332 impl
->activeContexts
.push_back(std::make_unique
<RecoveryReproducerContext
>(
333 passOS
.str(), op
, impl
->streamFactory
, impl
->pmFlagVerifyPasses
));
336 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass
*pass
,
338 // We only pop the active context if we are tracking local reproducers.
339 impl
->runningPasses
.remove(std::make_pair(pass
, op
));
340 if (impl
->localReproducer
) {
341 impl
->activeContexts
.pop_back();
343 // Re-enable the previous pass recovery context, if there was one. This may
344 // happen in the case of dynamic pass pipelines.
345 if (!impl
->activeContexts
.empty())
346 impl
->activeContexts
.back()->enable();
350 //===----------------------------------------------------------------------===//
351 // CrashReproducerInstrumentation
352 //===----------------------------------------------------------------------===//
355 struct CrashReproducerInstrumentation
: public PassInstrumentation
{
356 CrashReproducerInstrumentation(PassCrashReproducerGenerator
&generator
)
357 : generator(generator
) {}
358 ~CrashReproducerInstrumentation() override
= default;
360 void runBeforePass(Pass
*pass
, Operation
*op
) override
{
361 if (!isa
<OpToOpPassAdaptor
>(pass
))
362 generator
.prepareReproducerFor(pass
, op
);
365 void runAfterPass(Pass
*pass
, Operation
*op
) override
{
366 if (!isa
<OpToOpPassAdaptor
>(pass
))
367 generator
.removeLastReproducerFor(pass
, op
);
370 void runAfterPassFailed(Pass
*pass
, Operation
*op
) override
{
371 // Only generate one reproducer per crash reproducer instrumentation.
375 alreadyFailed
= true;
376 generator
.finalize(op
, /*executionResult=*/failure());
380 /// The generator used to create crash reproducers.
381 PassCrashReproducerGenerator
&generator
;
382 bool alreadyFailed
= false;
386 //===----------------------------------------------------------------------===//
387 // FileReproducerStream
388 //===----------------------------------------------------------------------===//
391 /// This class represents a default instance of mlir::ReproducerStream
392 /// that is backed by a file.
393 struct FileReproducerStream
: public mlir::ReproducerStream
{
394 FileReproducerStream(std::unique_ptr
<llvm::ToolOutputFile
> outputFile
)
395 : outputFile(std::move(outputFile
)) {}
396 ~FileReproducerStream() override
{ outputFile
->keep(); }
398 /// Returns a description of the reproducer stream.
399 StringRef
description() override
{ return outputFile
->getFilename(); }
401 /// Returns the stream on which to output the reproducer.
402 raw_ostream
&os() override
{ return outputFile
->os(); }
405 /// ToolOutputFile corresponding to opened `filename`.
406 std::unique_ptr
<llvm::ToolOutputFile
> outputFile
= nullptr;
//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//
414 LogicalResult
PassManager::runWithCrashRecovery(Operation
*op
,
415 AnalysisManager am
) {
416 crashReproGenerator
->initialize(getPasses(), op
, verifyPasses
);
418 // Safely invoke the passes within a recovery context.
419 LogicalResult passManagerResult
= failure();
420 llvm::CrashRecoveryContext recoveryContext
;
421 recoveryContext
.RunSafelyOnThread(
422 [&] { passManagerResult
= runPasses(op
, am
); });
423 crashReproGenerator
->finalize(op
, passManagerResult
);
424 return passManagerResult
;
427 static ReproducerStreamFactory
428 makeReproducerStreamFactory(StringRef outputFile
) {
429 // Capture the filename by value in case outputFile is out of scope when
431 std::string filename
= outputFile
.str();
432 return [filename
](std::string
&error
) -> std::unique_ptr
<ReproducerStream
> {
433 std::unique_ptr
<llvm::ToolOutputFile
> outputFile
=
434 mlir::openOutputFile(filename
, &error
);
436 error
= "Failed to create reproducer stream: " + error
;
439 return std::make_unique
<FileReproducerStream
>(std::move(outputFile
));
443 void printAsTextualPipeline(
444 raw_ostream
&os
, StringRef anchorName
,
445 const llvm::iterator_range
<OpPassManager::pass_iterator
> &passes
);
447 std::string
mlir::makeReproducer(
448 StringRef anchorName
,
449 const llvm::iterator_range
<OpPassManager::pass_iterator
> &passes
,
450 Operation
*op
, StringRef outputFile
, bool disableThreads
,
453 std::string description
;
454 std::string pipelineStr
;
455 llvm::raw_string_ostream
passOS(pipelineStr
);
456 ::printAsTextualPipeline(passOS
, anchorName
, passes
);
457 appendReproducer(description
, op
, makeReproducerStreamFactory(outputFile
),
458 pipelineStr
, disableThreads
, verifyPasses
);
462 void PassManager::enableCrashReproducerGeneration(StringRef outputFile
,
463 bool genLocalReproducer
) {
464 enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile
),
468 void PassManager::enableCrashReproducerGeneration(
469 ReproducerStreamFactory factory
, bool genLocalReproducer
) {
470 assert(!crashReproGenerator
&&
471 "crash reproducer has already been initialized");
472 if (genLocalReproducer
&& getContext()->isMultithreadingEnabled())
473 llvm::report_fatal_error(
474 "Local crash reproduction can't be setup on a "
475 "pass-manager without disabling multi-threading first.");
477 crashReproGenerator
= std::make_unique
<PassCrashReproducerGenerator
>(
478 factory
, genLocalReproducer
);
480 std::make_unique
<CrashReproducerInstrumentation
>(*crashReproGenerator
));
//===----------------------------------------------------------------------===//
// PassReproducerOptions
//===----------------------------------------------------------------------===//
487 void PassReproducerOptions::attachResourceParser(ParserConfig
&config
) {
488 auto parseFn
= [this](AsmParsedResourceEntry
&entry
) -> LogicalResult
{
489 if (entry
.getKey() == "pipeline") {
490 FailureOr
<std::string
> value
= entry
.parseAsString();
491 if (succeeded(value
))
492 this->pipeline
= std::move(*value
);
495 if (entry
.getKey() == "disable_threading") {
496 FailureOr
<bool> value
= entry
.parseAsBool();
497 if (succeeded(value
))
498 this->disableThreading
= *value
;
501 if (entry
.getKey() == "verify_each") {
502 FailureOr
<bool> value
= entry
.parseAsBool();
503 if (succeeded(value
))
504 this->verifyEach
= *value
;
507 return entry
.emitError() << "unknown 'mlir_reproducer' resource key '"
508 << entry
.getKey() << "'";
510 config
.attachResourceParser("mlir_reproducer", parseFn
);
513 LogicalResult
PassReproducerOptions::apply(PassManager
&pm
) const {
514 if (pipeline
.has_value()) {
515 FailureOr
<OpPassManager
> reproPm
= parsePassPipeline(*pipeline
);
518 static_cast<OpPassManager
&>(pm
) = std::move(*reproPm
);
521 if (disableThreading
.has_value())
522 pm
.getContext()->disableMultithreading(*disableThreading
);
524 if (verifyEach
.has_value())
525 pm
.enableVerifier(*verifyEach
);