//===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstring>
#include <optional>

using namespace mlir;
using namespace mlir::detail;
30 //===----------------------------------------------------------------------===//
32 //===----------------------------------------------------------------------===//
34 /// Construct from an Attribute.
35 DiagnosticArgument::DiagnosticArgument(Attribute attr
)
36 : kind(DiagnosticArgumentKind::Attribute
),
37 opaqueVal(reinterpret_cast<intptr_t>(attr
.getAsOpaquePointer())) {}
39 /// Construct from a Type.
40 DiagnosticArgument::DiagnosticArgument(Type val
)
41 : kind(DiagnosticArgumentKind::Type
),
42 opaqueVal(reinterpret_cast<intptr_t>(val
.getAsOpaquePointer())) {}
44 /// Returns this argument as an Attribute.
45 Attribute
DiagnosticArgument::getAsAttribute() const {
46 assert(getKind() == DiagnosticArgumentKind::Attribute
);
47 return Attribute::getFromOpaquePointer(
48 reinterpret_cast<const void *>(opaqueVal
));
51 /// Returns this argument as a Type.
52 Type
DiagnosticArgument::getAsType() const {
53 assert(getKind() == DiagnosticArgumentKind::Type
);
54 return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal
));
57 /// Outputs this argument to a stream.
58 void DiagnosticArgument::print(raw_ostream
&os
) const {
60 case DiagnosticArgumentKind::Attribute
:
61 os
<< getAsAttribute();
63 case DiagnosticArgumentKind::Double
:
66 case DiagnosticArgumentKind::Integer
:
69 case DiagnosticArgumentKind::String
:
72 case DiagnosticArgumentKind::Type
:
73 os
<< '\'' << getAsType() << '\'';
75 case DiagnosticArgumentKind::Unsigned
:
76 os
<< getAsUnsigned();
81 //===----------------------------------------------------------------------===//
83 //===----------------------------------------------------------------------===//
85 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
86 /// stored in 'strings'.
87 static StringRef
twineToStrRef(const Twine
&val
,
88 std::vector
<std::unique_ptr
<char[]>> &strings
) {
89 // Allocate memory to hold this string.
91 auto strRef
= val
.toStringRef(data
);
95 strings
.push_back(std::unique_ptr
<char[]>(new char[strRef
.size()]));
96 memcpy(&strings
.back()[0], strRef
.data(), strRef
.size());
97 // Return a reference to the new string.
98 return StringRef(&strings
.back()[0], strRef
.size());
101 /// Stream in a Twine argument.
102 Diagnostic
&Diagnostic::operator<<(char val
) { return *this << Twine(val
); }
103 Diagnostic
&Diagnostic::operator<<(const Twine
&val
) {
104 arguments
.push_back(DiagnosticArgument(twineToStrRef(val
, strings
)));
107 Diagnostic
&Diagnostic::operator<<(Twine
&&val
) {
108 arguments
.push_back(DiagnosticArgument(twineToStrRef(val
, strings
)));
112 Diagnostic
&Diagnostic::operator<<(StringAttr val
) {
113 arguments
.push_back(DiagnosticArgument(val
));
117 /// Stream in an OperationName.
118 Diagnostic
&Diagnostic::operator<<(OperationName val
) {
119 // An OperationName is stored in the context, so we don't need to worry about
120 // the lifetime of its data.
121 arguments
.push_back(DiagnosticArgument(val
.getStringRef()));
125 /// Adjusts operation printing flags used in diagnostics for the given severity
127 static OpPrintingFlags
adjustPrintingFlags(OpPrintingFlags flags
,
128 DiagnosticSeverity severity
) {
129 flags
.useLocalScope();
130 flags
.elideLargeElementsAttrs();
131 if (severity
== DiagnosticSeverity::Error
)
132 flags
.printGenericOpForm();
136 /// Stream in an Operation.
137 Diagnostic
&Diagnostic::operator<<(Operation
&op
) {
138 return appendOp(op
, OpPrintingFlags());
141 Diagnostic
&Diagnostic::appendOp(Operation
&op
, const OpPrintingFlags
&flags
) {
143 llvm::raw_string_ostream
os(str
);
144 op
.print(os
, adjustPrintingFlags(flags
, severity
));
145 // Print on a new line for better readability if the op will be printed on
147 if (str
.find('\n') != std::string::npos
)
149 return *this << os
.str();
152 /// Stream in a Value.
153 Diagnostic
&Diagnostic::operator<<(Value val
) {
155 llvm::raw_string_ostream
os(str
);
156 val
.print(os
, adjustPrintingFlags(OpPrintingFlags(), severity
));
157 return *this << os
.str();
160 /// Outputs this diagnostic to a stream.
161 void Diagnostic::print(raw_ostream
&os
) const {
162 for (auto &arg
: getArguments())
166 /// Convert the diagnostic to a string.
167 std::string
Diagnostic::str() const {
169 llvm::raw_string_ostream
os(str
);
174 /// Attaches a note to this diagnostic. A new location may be optionally
175 /// provided, if not, then the location defaults to the one specified for this
176 /// diagnostic. Notes may not be attached to other notes.
177 Diagnostic
&Diagnostic::attachNote(std::optional
<Location
> noteLoc
) {
178 // We don't allow attaching notes to notes.
179 assert(severity
!= DiagnosticSeverity::Note
&&
180 "cannot attach a note to a note");
182 // If a location wasn't provided then reuse our location.
186 /// Append and return a new note.
188 std::make_unique
<Diagnostic
>(*noteLoc
, DiagnosticSeverity::Note
));
189 return *notes
.back();
192 /// Allow a diagnostic to be converted to 'failure'.
193 Diagnostic::operator LogicalResult() const { return failure(); }
195 //===----------------------------------------------------------------------===//
196 // InFlightDiagnostic
197 //===----------------------------------------------------------------------===//
199 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
200 /// 'success' if this is an empty diagnostic.
201 InFlightDiagnostic::operator LogicalResult() const {
202 return failure(isActive());
205 /// Reports the diagnostic to the engine.
206 void InFlightDiagnostic::report() {
207 // If this diagnostic is still inflight and it hasn't been abandoned, then
210 owner
->emit(std::move(*impl
));
216 /// Abandons this diagnostic.
217 void InFlightDiagnostic::abandon() { owner
= nullptr; }
219 //===----------------------------------------------------------------------===//
220 // DiagnosticEngineImpl
221 //===----------------------------------------------------------------------===//
225 struct DiagnosticEngineImpl
{
226 /// Emit a diagnostic using the registered issue handle if present, or with
227 /// the default behavior if not.
228 void emit(Diagnostic
&&diag
);
230 /// A mutex to ensure that diagnostics emission is thread-safe.
231 llvm::sys::SmartMutex
<true> mutex
;
233 /// These are the handlers used to report diagnostics.
234 llvm::SmallMapVector
<DiagnosticEngine::HandlerID
, DiagnosticEngine::HandlerTy
,
238 /// This is a unique identifier counter for diagnostic handlers in the
239 /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
240 DiagnosticEngine::HandlerID uniqueHandlerId
= 1;
242 } // namespace detail
245 /// Emit a diagnostic using the registered issue handle if present, or with
246 /// the default behavior if not.
247 void DiagnosticEngineImpl::emit(Diagnostic
&&diag
) {
248 llvm::sys::SmartScopedLock
<true> lock(mutex
);
250 // Try to process the given diagnostic on one of the registered handlers.
251 // Handlers are walked in reverse order, so that the most recent handler is
253 for (auto &handlerIt
: llvm::reverse(handlers
))
254 if (succeeded(handlerIt
.second(diag
)))
257 // Otherwise, if this is an error we emit it to stderr.
258 if (diag
.getSeverity() != DiagnosticSeverity::Error
)
261 auto &os
= llvm::errs();
262 if (!llvm::isa
<UnknownLoc
>(diag
.getLocation()))
263 os
<< diag
.getLocation() << ": ";
266 // The default behavior for errors is to emit them to stderr.
271 //===----------------------------------------------------------------------===//
273 //===----------------------------------------------------------------------===//
275 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
276 DiagnosticEngine::~DiagnosticEngine() = default;
278 /// Register a new handler for diagnostics to the engine. This function returns
279 /// a unique identifier for the registered handler, which can be used to
280 /// unregister this handler at a later time.
281 auto DiagnosticEngine::registerHandler(HandlerTy handler
) -> HandlerID
{
282 llvm::sys::SmartScopedLock
<true> lock(impl
->mutex
);
283 auto uniqueID
= impl
->uniqueHandlerId
++;
284 impl
->handlers
.insert({uniqueID
, std::move(handler
)});
288 /// Erase the registered diagnostic handler with the given identifier.
289 void DiagnosticEngine::eraseHandler(HandlerID handlerID
) {
290 llvm::sys::SmartScopedLock
<true> lock(impl
->mutex
);
291 impl
->handlers
.erase(handlerID
);
294 /// Emit a diagnostic using the registered issue handler if present, or with
295 /// the default behavior if not.
296 void DiagnosticEngine::emit(Diagnostic
&&diag
) {
297 assert(diag
.getSeverity() != DiagnosticSeverity::Note
&&
298 "notes should not be emitted directly");
299 impl
->emit(std::move(diag
));
302 /// Helper function used to emit a diagnostic with an optionally empty twine
303 /// message. If the message is empty, then it is not inserted into the
305 static InFlightDiagnostic
306 emitDiag(Location location
, DiagnosticSeverity severity
, const Twine
&message
) {
307 MLIRContext
*ctx
= location
->getContext();
308 auto &diagEngine
= ctx
->getDiagEngine();
309 auto diag
= diagEngine
.emit(location
, severity
);
310 if (!message
.isTriviallyEmpty())
313 // Add the stack trace as a note if necessary.
314 if (ctx
->shouldPrintStackTraceOnDiagnostic()) {
317 llvm::raw_string_ostream
stream(bt
);
318 llvm::sys::PrintStackTrace(stream
);
321 diag
.attachNote() << "diagnostic emitted with trace:\n" << bt
;
327 /// Emit an error message using this location.
328 InFlightDiagnostic
mlir::emitError(Location loc
) { return emitError(loc
, {}); }
329 InFlightDiagnostic
mlir::emitError(Location loc
, const Twine
&message
) {
330 return emitDiag(loc
, DiagnosticSeverity::Error
, message
);
333 /// Emit a warning message using this location.
334 InFlightDiagnostic
mlir::emitWarning(Location loc
) {
335 return emitWarning(loc
, {});
337 InFlightDiagnostic
mlir::emitWarning(Location loc
, const Twine
&message
) {
338 return emitDiag(loc
, DiagnosticSeverity::Warning
, message
);
341 /// Emit a remark message using this location.
342 InFlightDiagnostic
mlir::emitRemark(Location loc
) {
343 return emitRemark(loc
, {});
345 InFlightDiagnostic
mlir::emitRemark(Location loc
, const Twine
&message
) {
346 return emitDiag(loc
, DiagnosticSeverity::Remark
, message
);
349 //===----------------------------------------------------------------------===//
350 // ScopedDiagnosticHandler
351 //===----------------------------------------------------------------------===//
353 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
355 ctx
->getDiagEngine().eraseHandler(handlerID
);
358 //===----------------------------------------------------------------------===//
359 // SourceMgrDiagnosticHandler
360 //===----------------------------------------------------------------------===//
363 struct SourceMgrDiagnosticHandlerImpl
{
364 /// Return the SrcManager buffer id for the specified file, or zero if none
366 unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr
&mgr
,
367 StringRef filename
) {
368 // Check for an existing mapping to the buffer id for this file.
369 auto bufferIt
= filenameToBufId
.find(filename
);
370 if (bufferIt
!= filenameToBufId
.end())
371 return bufferIt
->second
;
373 // Look for a buffer in the manager that has this filename.
374 for (unsigned i
= 1, e
= mgr
.getNumBuffers() + 1; i
!= e
; ++i
) {
375 auto *buf
= mgr
.getMemoryBuffer(i
);
376 if (buf
->getBufferIdentifier() == filename
)
377 return filenameToBufId
[filename
] = i
;
380 // Otherwise, try to load the source file.
382 unsigned id
= mgr
.AddIncludeFile(std::string(filename
), SMLoc(), ignored
);
383 filenameToBufId
[filename
] = id
;
387 /// Mapping between file name and buffer ID's.
388 llvm::StringMap
<unsigned> filenameToBufId
;
390 } // namespace detail
393 /// Return a processable CallSiteLoc from the given location.
394 static std::optional
<CallSiteLoc
> getCallSiteLoc(Location loc
) {
395 if (dyn_cast
<NameLoc
>(loc
))
396 return getCallSiteLoc(cast
<NameLoc
>(loc
).getChildLoc());
397 if (auto callLoc
= dyn_cast
<CallSiteLoc
>(loc
))
399 if (dyn_cast
<FusedLoc
>(loc
)) {
400 for (auto subLoc
: cast
<FusedLoc
>(loc
).getLocations()) {
401 if (auto callLoc
= getCallSiteLoc(subLoc
)) {
410 /// Given a diagnostic kind, returns the LLVM DiagKind.
411 static llvm::SourceMgr::DiagKind
getDiagKind(DiagnosticSeverity kind
) {
413 case DiagnosticSeverity::Note
:
414 return llvm::SourceMgr::DK_Note
;
415 case DiagnosticSeverity::Warning
:
416 return llvm::SourceMgr::DK_Warning
;
417 case DiagnosticSeverity::Error
:
418 return llvm::SourceMgr::DK_Error
;
419 case DiagnosticSeverity::Remark
:
420 return llvm::SourceMgr::DK_Remark
;
422 llvm_unreachable("Unknown DiagnosticSeverity");
425 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
426 llvm::SourceMgr
&mgr
, MLIRContext
*ctx
, raw_ostream
&os
,
427 ShouldShowLocFn
&&shouldShowLocFn
)
428 : ScopedDiagnosticHandler(ctx
), mgr(mgr
), os(os
),
429 shouldShowLocFn(std::move(shouldShowLocFn
)),
430 impl(new SourceMgrDiagnosticHandlerImpl()) {
431 setHandler([this](Diagnostic
&diag
) { emitDiagnostic(diag
); });
434 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
435 llvm::SourceMgr
&mgr
, MLIRContext
*ctx
, ShouldShowLocFn
&&shouldShowLocFn
)
436 : SourceMgrDiagnosticHandler(mgr
, ctx
, llvm::errs(),
437 std::move(shouldShowLocFn
)) {}
439 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
441 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc
, Twine message
,
442 DiagnosticSeverity kind
,
443 bool displaySourceLine
) {
444 // Extract a file location from this loc.
445 auto fileLoc
= loc
->findInstanceOf
<FileLineColLoc
>();
447 // If one doesn't exist, then print the raw message without a source location.
450 llvm::raw_string_ostream
strOS(str
);
451 if (!llvm::isa
<UnknownLoc
>(loc
))
452 strOS
<< loc
<< ": ";
454 return mgr
.PrintMessage(os
, SMLoc(), getDiagKind(kind
), strOS
.str());
457 // Otherwise if we are displaying the source line, try to convert the file
458 // location to an SMLoc.
459 if (displaySourceLine
) {
460 auto smloc
= convertLocToSMLoc(fileLoc
);
462 return mgr
.PrintMessage(os
, smloc
, getDiagKind(kind
), message
);
465 // If the conversion was unsuccessful, create a diagnostic with the file
466 // information. We manually combine the line and column to avoid asserts in
467 // the constructor of SMDiagnostic that takes a location.
469 llvm::raw_string_ostream
locOS(locStr
);
470 locOS
<< fileLoc
.getFilename().getValue() << ":" << fileLoc
.getLine() << ":"
471 << fileLoc
.getColumn();
472 llvm::SMDiagnostic
diag(locOS
.str(), getDiagKind(kind
), message
.str());
473 diag
.print(nullptr, os
);
476 /// Emit the given diagnostic with the held source manager.
477 void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic
&diag
) {
478 SmallVector
<std::pair
<Location
, StringRef
>> locationStack
;
479 auto addLocToStack
= [&](Location loc
, StringRef locContext
) {
480 if (std::optional
<Location
> showableLoc
= findLocToShow(loc
))
481 locationStack
.emplace_back(*showableLoc
, locContext
);
484 // Add locations to display for this diagnostic.
485 Location loc
= diag
.getLocation();
486 addLocToStack(loc
, /*locContext=*/{});
488 // If the diagnostic location was a call site location, add the call stack as
490 if (auto callLoc
= getCallSiteLoc(loc
)) {
491 // Print the call stack while valid, or until the limit is reached.
492 loc
= callLoc
->getCaller();
493 for (unsigned curDepth
= 0; curDepth
< callStackLimit
; ++curDepth
) {
494 addLocToStack(loc
, "called from");
495 if ((callLoc
= getCallSiteLoc(loc
)))
496 loc
= callLoc
->getCaller();
502 // If the location stack is empty, use the initial location.
503 if (locationStack
.empty()) {
504 emitDiagnostic(diag
.getLocation(), diag
.str(), diag
.getSeverity());
506 // Otherwise, use the location stack.
508 emitDiagnostic(locationStack
.front().first
, diag
.str(), diag
.getSeverity());
509 for (auto &it
: llvm::drop_begin(locationStack
))
510 emitDiagnostic(it
.first
, it
.second
, DiagnosticSeverity::Note
);
513 // Emit each of the notes. Only display the source code if the location is
514 // different from the previous location.
515 for (auto ¬e
: diag
.getNotes()) {
516 emitDiagnostic(note
.getLocation(), note
.str(), note
.getSeverity(),
517 /*displaySourceLine=*/loc
!= note
.getLocation());
518 loc
= note
.getLocation();
522 /// Get a memory buffer for the given file, or nullptr if one is not found.
523 const llvm::MemoryBuffer
*
524 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename
) {
525 if (unsigned id
= impl
->getSourceMgrBufferIDForFile(mgr
, filename
))
526 return mgr
.getMemoryBuffer(id
);
530 std::optional
<Location
>
531 SourceMgrDiagnosticHandler::findLocToShow(Location loc
) {
532 if (!shouldShowLocFn
)
534 if (!shouldShowLocFn(loc
))
537 // Recurse into the child locations of some of location types.
538 return TypeSwitch
<LocationAttr
, std::optional
<Location
>>(loc
)
539 .Case([&](CallSiteLoc callLoc
) -> std::optional
<Location
> {
540 // We recurse into the callee of a call site, as the caller will be
541 // emitted in a different note on the main diagnostic.
542 return findLocToShow(callLoc
.getCallee());
544 .Case([&](FileLineColLoc
) -> std::optional
<Location
> { return loc
; })
545 .Case([&](FusedLoc fusedLoc
) -> std::optional
<Location
> {
546 // Fused location is unique in that we try to find a sub-location to
547 // show, rather than the top-level location itself.
548 for (Location childLoc
: fusedLoc
.getLocations())
549 if (std::optional
<Location
> showableLoc
= findLocToShow(childLoc
))
553 .Case([&](NameLoc nameLoc
) -> std::optional
<Location
> {
554 return findLocToShow(nameLoc
.getChildLoc());
556 .Case([&](OpaqueLoc opaqueLoc
) -> std::optional
<Location
> {
557 // OpaqueLoc always falls back to a different source location.
558 return findLocToShow(opaqueLoc
.getFallbackLocation());
560 .Case([](UnknownLoc
) -> std::optional
<Location
> {
561 // Prefer not to show unknown locations.
566 /// Get a memory buffer for the given file, or the main file of the source
567 /// manager if one doesn't exist. This always returns non-null.
568 SMLoc
SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc
) {
569 // The column and line may be zero to represent unknown column and/or unknown
570 /// line/column information.
571 if (loc
.getLine() == 0 || loc
.getColumn() == 0)
574 unsigned bufferId
= impl
->getSourceMgrBufferIDForFile(mgr
, loc
.getFilename());
577 return mgr
.FindLocForLineAndColumn(bufferId
, loc
.getLine(), loc
.getColumn());
580 //===----------------------------------------------------------------------===//
581 // SourceMgrDiagnosticVerifierHandler
582 //===----------------------------------------------------------------------===//
586 /// This class represents an expected output diagnostic.
587 struct ExpectedDiag
{
588 ExpectedDiag(DiagnosticSeverity kind
, unsigned lineNo
, SMLoc fileLoc
,
590 : kind(kind
), lineNo(lineNo
), fileLoc(fileLoc
), substring(substring
) {}
592 /// Emit an error at the location referenced by this diagnostic.
593 LogicalResult
emitError(raw_ostream
&os
, llvm::SourceMgr
&mgr
,
595 SMRange
range(fileLoc
, SMLoc::getFromPointer(fileLoc
.getPointer() +
597 mgr
.PrintMessage(os
, fileLoc
, llvm::SourceMgr::DK_Error
, msg
, range
);
601 /// Returns true if this diagnostic matches the given string.
602 bool match(StringRef str
) const {
603 // If this isn't a regex diagnostic, we simply check if the string was
606 return substringRegex
->match(str
);
607 return str
.contains(substring
);
610 /// Compute the regex matcher for this diagnostic, using the provided stream
611 /// and manager to emit diagnostics as necessary.
612 LogicalResult
computeRegex(raw_ostream
&os
, llvm::SourceMgr
&mgr
) {
613 std::string regexStr
;
614 llvm::raw_string_ostream
regexOS(regexStr
);
615 StringRef strToProcess
= substring
;
616 while (!strToProcess
.empty()) {
617 // Find the next regex block.
618 size_t regexIt
= strToProcess
.find("{{");
619 if (regexIt
== StringRef::npos
) {
620 regexOS
<< llvm::Regex::escape(strToProcess
);
623 regexOS
<< llvm::Regex::escape(strToProcess
.take_front(regexIt
));
624 strToProcess
= strToProcess
.drop_front(regexIt
+ 2);
626 // Find the end of the regex block.
627 size_t regexEndIt
= strToProcess
.find("}}");
628 if (regexEndIt
== StringRef::npos
)
629 return emitError(os
, mgr
, "found start of regex with no end '}}'");
630 StringRef regexStr
= strToProcess
.take_front(regexEndIt
);
632 // Validate that the regex is actually valid.
633 std::string regexError
;
634 if (!llvm::Regex(regexStr
).isValid(regexError
))
635 return emitError(os
, mgr
, "invalid regex: " + regexError
);
637 regexOS
<< '(' << regexStr
<< ')';
638 strToProcess
= strToProcess
.drop_front(regexEndIt
+ 2);
640 substringRegex
= llvm::Regex(regexOS
.str());
644 /// The severity of the diagnosic expected.
645 DiagnosticSeverity kind
;
646 /// The line number the expected diagnostic should be on.
648 /// The location of the expected diagnostic within the input file.
650 /// A flag indicating if the expected diagnostic has been matched yet.
651 bool matched
= false;
652 /// The substring that is expected to be within the diagnostic.
654 /// An optional regex matcher, if the expected diagnostic sub-string was a
656 std::optional
<llvm::Regex
> substringRegex
;
659 struct SourceMgrDiagnosticVerifierHandlerImpl
{
660 SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
662 /// Returns the expected diagnostics for the given source file.
663 std::optional
<MutableArrayRef
<ExpectedDiag
>>
664 getExpectedDiags(StringRef bufName
);
666 /// Computes the expected diagnostics for the given source buffer.
667 MutableArrayRef
<ExpectedDiag
>
668 computeExpectedDiags(raw_ostream
&os
, llvm::SourceMgr
&mgr
,
669 const llvm::MemoryBuffer
*buf
);
671 /// The current status of the verifier.
672 LogicalResult status
;
674 /// A list of expected diagnostics for each buffer of the source manager.
675 llvm::StringMap
<SmallVector
<ExpectedDiag
, 2>> expectedDiagsPerFile
;
677 /// Regex to match the expected diagnostics format.
678 llvm::Regex expected
=
679 llvm::Regex("expected-(error|note|remark|warning)(-re)? "
680 "*(@([+-][0-9]+|above|below))? *{{(.*)}}$");
682 } // namespace detail
685 /// Given a diagnostic kind, return a human readable string for it.
686 static StringRef
getDiagKindStr(DiagnosticSeverity kind
) {
688 case DiagnosticSeverity::Note
:
690 case DiagnosticSeverity::Warning
:
692 case DiagnosticSeverity::Error
:
694 case DiagnosticSeverity::Remark
:
697 llvm_unreachable("Unknown DiagnosticSeverity");
700 std::optional
<MutableArrayRef
<ExpectedDiag
>>
701 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName
) {
702 auto expectedDiags
= expectedDiagsPerFile
.find(bufName
);
703 if (expectedDiags
!= expectedDiagsPerFile
.end())
704 return MutableArrayRef
<ExpectedDiag
>(expectedDiags
->second
);
708 MutableArrayRef
<ExpectedDiag
>
709 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
710 raw_ostream
&os
, llvm::SourceMgr
&mgr
, const llvm::MemoryBuffer
*buf
) {
711 // If the buffer is invalid, return an empty list.
714 auto &expectedDiags
= expectedDiagsPerFile
[buf
->getBufferIdentifier()];
716 // The number of the last line that did not correlate to a designator.
717 unsigned lastNonDesignatorLine
= 0;
719 // The indices of designators that apply to the next non designator line.
720 SmallVector
<unsigned, 1> designatorsForNextLine
;
722 // Scan the file for expected-* designators.
723 SmallVector
<StringRef
, 100> lines
;
724 buf
->getBuffer().split(lines
, '\n');
725 for (unsigned lineNo
= 0, e
= lines
.size(); lineNo
< e
; ++lineNo
) {
726 SmallVector
<StringRef
, 4> matches
;
727 if (!expected
.match(lines
[lineNo
].rtrim(), &matches
)) {
728 // Check for designators that apply to this line.
729 if (!designatorsForNextLine
.empty()) {
730 for (unsigned diagIndex
: designatorsForNextLine
)
731 expectedDiags
[diagIndex
].lineNo
= lineNo
+ 1;
732 designatorsForNextLine
.clear();
734 lastNonDesignatorLine
= lineNo
;
738 // Point to the start of expected-*.
739 SMLoc expectedStart
= SMLoc::getFromPointer(matches
[0].data());
741 DiagnosticSeverity kind
;
742 if (matches
[1] == "error")
743 kind
= DiagnosticSeverity::Error
;
744 else if (matches
[1] == "warning")
745 kind
= DiagnosticSeverity::Warning
;
746 else if (matches
[1] == "remark")
747 kind
= DiagnosticSeverity::Remark
;
749 assert(matches
[1] == "note");
750 kind
= DiagnosticSeverity::Note
;
752 ExpectedDiag
record(kind
, lineNo
+ 1, expectedStart
, matches
[5]);
754 // Check to see if this is a regex match, i.e. it includes the `-re`.
755 if (!matches
[2].empty() && failed(record
.computeRegex(os
, mgr
))) {
760 StringRef offsetMatch
= matches
[3];
761 if (!offsetMatch
.empty()) {
762 offsetMatch
= offsetMatch
.drop_front(1);
764 // Get the integer value without the @ and +/- prefix.
765 if (offsetMatch
[0] == '+' || offsetMatch
[0] == '-') {
767 offsetMatch
.drop_front().getAsInteger(0, offset
);
769 if (offsetMatch
.front() == '+')
770 record
.lineNo
+= offset
;
772 record
.lineNo
-= offset
;
773 } else if (offsetMatch
.consume_front("above")) {
774 // If the designator applies 'above' we add it to the last non
776 record
.lineNo
= lastNonDesignatorLine
+ 1;
778 // Otherwise, this is a 'below' designator and applies to the next
779 // non-designator line.
780 assert(offsetMatch
.consume_front("below"));
781 designatorsForNextLine
.push_back(expectedDiags
.size());
783 // Set the line number to the last in the case that this designator ends
788 expectedDiags
.emplace_back(std::move(record
));
790 return expectedDiags
;
793 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
794 llvm::SourceMgr
&srcMgr
, MLIRContext
*ctx
, raw_ostream
&out
)
795 : SourceMgrDiagnosticHandler(srcMgr
, ctx
, out
),
796 impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
797 // Compute the expected diagnostics for each of the current files in the
799 for (unsigned i
= 0, e
= mgr
.getNumBuffers(); i
!= e
; ++i
)
800 (void)impl
->computeExpectedDiags(out
, mgr
, mgr
.getMemoryBuffer(i
+ 1));
802 // Register a handler to verify the diagnostics.
803 setHandler([&](Diagnostic
&diag
) {
804 // Process the main diagnostics.
807 // Process each of the notes.
808 for (auto ¬e
: diag
.getNotes())
813 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
814 llvm::SourceMgr
&srcMgr
, MLIRContext
*ctx
)
815 : SourceMgrDiagnosticVerifierHandler(srcMgr
, ctx
, llvm::errs()) {}
817 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
818 // Ensure that all expected diagnostics were handled.
822 /// Returns the status of the verifier and verifies that all expected
823 /// diagnostics were emitted. This return success if all diagnostics were
824 /// verified correctly, failure otherwise.
825 LogicalResult
SourceMgrDiagnosticVerifierHandler::verify() {
826 // Verify that all expected errors were seen.
827 for (auto &expectedDiagsPair
: impl
->expectedDiagsPerFile
) {
828 for (auto &err
: expectedDiagsPair
.second
) {
832 err
.emitError(os
, mgr
,
833 "expected " + getDiagKindStr(err
.kind
) + " \"" +
834 err
.substring
+ "\" was not produced");
837 impl
->expectedDiagsPerFile
.clear();
841 /// Process a single diagnostic.
842 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic
&diag
) {
843 auto kind
= diag
.getSeverity();
845 // Process a FileLineColLoc.
846 if (auto fileLoc
= diag
.getLocation()->findInstanceOf
<FileLineColLoc
>())
847 return process(fileLoc
, diag
.str(), kind
);
849 emitDiagnostic(diag
.getLocation(),
850 "unexpected " + getDiagKindStr(kind
) + ": " + diag
.str(),
851 DiagnosticSeverity::Error
);
852 impl
->status
= failure();
855 /// Process a FileLineColLoc diagnostic.
856 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc
,
858 DiagnosticSeverity kind
) {
859 // Get the expected diagnostics for this file.
860 auto diags
= impl
->getExpectedDiags(loc
.getFilename());
862 diags
= impl
->computeExpectedDiags(os
, mgr
,
863 getBufferForFile(loc
.getFilename()));
866 // Search for a matching expected diagnostic.
867 // If we find something that is close then emit a more specific error.
868 ExpectedDiag
*nearMiss
= nullptr;
870 // If this was an expected error, remember that we saw it and return.
871 unsigned line
= loc
.getLine();
872 for (auto &e
: *diags
) {
873 if (line
== e
.lineNo
&& e
.match(msg
)) {
874 if (e
.kind
== kind
) {
879 // If this only differs based on the diagnostic kind, then consider it
880 // to be a near miss.
885 // Otherwise, emit an error for the near miss.
887 mgr
.PrintMessage(os
, nearMiss
->fileLoc
, llvm::SourceMgr::DK_Error
,
888 "'" + getDiagKindStr(kind
) +
889 "' diagnostic emitted when expecting a '" +
890 getDiagKindStr(nearMiss
->kind
) + "'");
892 emitDiagnostic(loc
, "unexpected " + getDiagKindStr(kind
) + ": " + msg
,
893 DiagnosticSeverity::Error
);
894 impl
->status
= failure();
897 //===----------------------------------------------------------------------===//
898 // ParallelDiagnosticHandler
899 //===----------------------------------------------------------------------===//
903 struct ParallelDiagnosticHandlerImpl
: public llvm::PrettyStackTraceEntry
{
904 struct ThreadDiagnostic
{
905 ThreadDiagnostic(size_t id
, Diagnostic diag
)
906 : id(id
), diag(std::move(diag
)) {}
907 bool operator<(const ThreadDiagnostic
&rhs
) const { return id
< rhs
.id
; }
909 /// The id for this diagnostic, this is used for ordering.
910 /// Note: This id corresponds to the ordered position of the current element
911 /// being processed by a given thread.
918 ParallelDiagnosticHandlerImpl(MLIRContext
*ctx
) : context(ctx
) {
919 handlerID
= ctx
->getDiagEngine().registerHandler([this](Diagnostic
&diag
) {
920 uint64_t tid
= llvm::get_threadid();
921 llvm::sys::SmartScopedLock
<true> lock(mutex
);
923 // If this thread is not tracked, then return failure to let another
924 // handler process this diagnostic.
925 if (!threadToOrderID
.count(tid
))
928 // Append a new diagnostic.
929 diagnostics
.emplace_back(threadToOrderID
[tid
], std::move(diag
));
934 ~ParallelDiagnosticHandlerImpl() override
{
935 // Erase this handler from the context.
936 context
->getDiagEngine().eraseHandler(handlerID
);
938 // Early exit if there are no diagnostics, this is the common case.
939 if (diagnostics
.empty())
942 // Emit the diagnostics back to the context.
943 emitDiagnostics([&](Diagnostic
&diag
) {
944 return context
->getDiagEngine().emit(std::move(diag
));
948 /// Utility method to emit any held diagnostics.
949 void emitDiagnostics(llvm::function_ref
<void(Diagnostic
&)> emitFn
) const {
950 // Stable sort all of the diagnostics that were emitted. This creates a
951 // deterministic ordering for the diagnostics based upon which order id they
953 std::stable_sort(diagnostics
.begin(), diagnostics
.end());
955 // Emit each diagnostic to the context again.
956 for (ThreadDiagnostic
&diag
: diagnostics
)
960 /// Set the order id for the current thread.
961 void setOrderIDForThread(size_t orderID
) {
962 uint64_t tid
= llvm::get_threadid();
963 llvm::sys::SmartScopedLock
<true> lock(mutex
);
964 threadToOrderID
[tid
] = orderID
;
967 /// Remove the order id for the current thread.
968 void eraseOrderIDForThread() {
969 uint64_t tid
= llvm::get_threadid();
970 llvm::sys::SmartScopedLock
<true> lock(mutex
);
971 threadToOrderID
.erase(tid
);
974 /// Dump the current diagnostics that were inflight.
975 void print(raw_ostream
&os
) const override
{
976 // Early exit if there are no diagnostics, this is the common case.
977 if (diagnostics
.empty())
980 os
<< "In-Flight Diagnostics:\n";
981 emitDiagnostics([&](const Diagnostic
&diag
) {
984 // Print each diagnostic with the format:
985 // "<location>: <kind>: <msg>"
986 if (!llvm::isa
<UnknownLoc
>(diag
.getLocation()))
987 os
<< diag
.getLocation() << ": ";
988 switch (diag
.getSeverity()) {
989 case DiagnosticSeverity::Error
:
992 case DiagnosticSeverity::Warning
:
995 case DiagnosticSeverity::Note
:
998 case DiagnosticSeverity::Remark
:
1006 /// A smart mutex to lock access to the internal state.
1007 llvm::sys::SmartMutex
<true> mutex
;
1009 /// A mapping between the thread id and the current order id.
1010 DenseMap
<uint64_t, size_t> threadToOrderID
;
1012 /// An unordered list of diagnostics that were emitted.
1013 mutable std::vector
<ThreadDiagnostic
> diagnostics
;
1015 /// The unique id for the parallel handler.
1016 DiagnosticEngine::HandlerID handlerID
= 0;
1018 /// The context to emit the diagnostics to.
1019 MLIRContext
*context
;
1021 } // namespace detail
1024 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext
*ctx
)
1025 : impl(new ParallelDiagnosticHandlerImpl(ctx
)) {}
1026 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;
1028 /// Set the order id for the current thread.
1029 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID
) {
1030 impl
->setOrderIDForThread(orderID
);
1033 /// Remove the order id for the current thread. This removes the thread from
1034 /// diagnostics tracking.
1035 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
1036 impl
->eraseOrderIDForThread();