[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / IR / Diagnostics.cpp
blob6b311a90e0de59c14ed746d00ca2e455bd2e05df
1 //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/Diagnostics.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Location.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Mutex.h"
20 #include "llvm/Support/PrettyStackTrace.h"
21 #include "llvm/Support/Regex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include <optional>
27 using namespace mlir;
28 using namespace mlir::detail;
30 //===----------------------------------------------------------------------===//
31 // DiagnosticArgument
32 //===----------------------------------------------------------------------===//
34 /// Construct from an Attribute.
35 DiagnosticArgument::DiagnosticArgument(Attribute attr)
36 : kind(DiagnosticArgumentKind::Attribute),
37 opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
39 /// Construct from a Type.
40 DiagnosticArgument::DiagnosticArgument(Type val)
41 : kind(DiagnosticArgumentKind::Type),
42 opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
44 /// Returns this argument as an Attribute.
45 Attribute DiagnosticArgument::getAsAttribute() const {
46 assert(getKind() == DiagnosticArgumentKind::Attribute);
47 return Attribute::getFromOpaquePointer(
48 reinterpret_cast<const void *>(opaqueVal));
51 /// Returns this argument as a Type.
52 Type DiagnosticArgument::getAsType() const {
53 assert(getKind() == DiagnosticArgumentKind::Type);
54 return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
57 /// Outputs this argument to a stream.
58 void DiagnosticArgument::print(raw_ostream &os) const {
59 switch (kind) {
60 case DiagnosticArgumentKind::Attribute:
61 os << getAsAttribute();
62 break;
63 case DiagnosticArgumentKind::Double:
64 os << getAsDouble();
65 break;
66 case DiagnosticArgumentKind::Integer:
67 os << getAsInteger();
68 break;
69 case DiagnosticArgumentKind::String:
70 os << getAsString();
71 break;
72 case DiagnosticArgumentKind::Type:
73 os << '\'' << getAsType() << '\'';
74 break;
75 case DiagnosticArgumentKind::Unsigned:
76 os << getAsUnsigned();
77 break;
81 //===----------------------------------------------------------------------===//
82 // Diagnostic
83 //===----------------------------------------------------------------------===//
85 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
86 /// stored in 'strings'.
87 static StringRef twineToStrRef(const Twine &val,
88 std::vector<std::unique_ptr<char[]>> &strings) {
89 // Allocate memory to hold this string.
90 SmallString<64> data;
91 auto strRef = val.toStringRef(data);
92 if (strRef.empty())
93 return strRef;
95 strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
96 memcpy(&strings.back()[0], strRef.data(), strRef.size());
97 // Return a reference to the new string.
98 return StringRef(&strings.back()[0], strRef.size());
101 /// Stream in a Twine argument.
102 Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
103 Diagnostic &Diagnostic::operator<<(const Twine &val) {
104 arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
105 return *this;
107 Diagnostic &Diagnostic::operator<<(Twine &&val) {
108 arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
109 return *this;
112 Diagnostic &Diagnostic::operator<<(StringAttr val) {
113 arguments.push_back(DiagnosticArgument(val));
114 return *this;
117 /// Stream in an OperationName.
118 Diagnostic &Diagnostic::operator<<(OperationName val) {
119 // An OperationName is stored in the context, so we don't need to worry about
120 // the lifetime of its data.
121 arguments.push_back(DiagnosticArgument(val.getStringRef()));
122 return *this;
125 /// Adjusts operation printing flags used in diagnostics for the given severity
126 /// level.
127 static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
128 DiagnosticSeverity severity) {
129 flags.useLocalScope();
130 flags.elideLargeElementsAttrs();
131 if (severity == DiagnosticSeverity::Error)
132 flags.printGenericOpForm();
133 return flags;
136 /// Stream in an Operation.
137 Diagnostic &Diagnostic::operator<<(Operation &op) {
138 return appendOp(op, OpPrintingFlags());
141 Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
142 std::string str;
143 llvm::raw_string_ostream os(str);
144 op.print(os, adjustPrintingFlags(flags, severity));
145 // Print on a new line for better readability if the op will be printed on
146 // multiple lines.
147 if (str.find('\n') != std::string::npos)
148 *this << '\n';
149 return *this << os.str();
152 /// Stream in a Value.
153 Diagnostic &Diagnostic::operator<<(Value val) {
154 std::string str;
155 llvm::raw_string_ostream os(str);
156 val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
157 return *this << os.str();
160 /// Outputs this diagnostic to a stream.
161 void Diagnostic::print(raw_ostream &os) const {
162 for (auto &arg : getArguments())
163 arg.print(os);
166 /// Convert the diagnostic to a string.
167 std::string Diagnostic::str() const {
168 std::string str;
169 llvm::raw_string_ostream os(str);
170 print(os);
171 return os.str();
174 /// Attaches a note to this diagnostic. A new location may be optionally
175 /// provided, if not, then the location defaults to the one specified for this
176 /// diagnostic. Notes may not be attached to other notes.
177 Diagnostic &Diagnostic::attachNote(std::optional<Location> noteLoc) {
178 // We don't allow attaching notes to notes.
179 assert(severity != DiagnosticSeverity::Note &&
180 "cannot attach a note to a note");
182 // If a location wasn't provided then reuse our location.
183 if (!noteLoc)
184 noteLoc = loc;
186 /// Append and return a new note.
187 notes.push_back(
188 std::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
189 return *notes.back();
192 /// Allow a diagnostic to be converted to 'failure'.
193 Diagnostic::operator LogicalResult() const { return failure(); }
195 //===----------------------------------------------------------------------===//
196 // InFlightDiagnostic
197 //===----------------------------------------------------------------------===//
199 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
200 /// 'success' if this is an empty diagnostic.
201 InFlightDiagnostic::operator LogicalResult() const {
202 return failure(isActive());
205 /// Reports the diagnostic to the engine.
206 void InFlightDiagnostic::report() {
207 // If this diagnostic is still inflight and it hasn't been abandoned, then
208 // report it.
209 if (isInFlight()) {
210 owner->emit(std::move(*impl));
211 owner = nullptr;
213 impl.reset();
216 /// Abandons this diagnostic.
217 void InFlightDiagnostic::abandon() { owner = nullptr; }
219 //===----------------------------------------------------------------------===//
220 // DiagnosticEngineImpl
221 //===----------------------------------------------------------------------===//
223 namespace mlir {
224 namespace detail {
225 struct DiagnosticEngineImpl {
226 /// Emit a diagnostic using the registered issue handle if present, or with
227 /// the default behavior if not.
228 void emit(Diagnostic &&diag);
230 /// A mutex to ensure that diagnostics emission is thread-safe.
231 llvm::sys::SmartMutex<true> mutex;
233 /// These are the handlers used to report diagnostics.
234 llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
236 handlers;
238 /// This is a unique identifier counter for diagnostic handlers in the
239 /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
240 DiagnosticEngine::HandlerID uniqueHandlerId = 1;
242 } // namespace detail
243 } // namespace mlir
245 /// Emit a diagnostic using the registered issue handle if present, or with
246 /// the default behavior if not.
247 void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
248 llvm::sys::SmartScopedLock<true> lock(mutex);
250 // Try to process the given diagnostic on one of the registered handlers.
251 // Handlers are walked in reverse order, so that the most recent handler is
252 // processed first.
253 for (auto &handlerIt : llvm::reverse(handlers))
254 if (succeeded(handlerIt.second(diag)))
255 return;
257 // Otherwise, if this is an error we emit it to stderr.
258 if (diag.getSeverity() != DiagnosticSeverity::Error)
259 return;
261 auto &os = llvm::errs();
262 if (!llvm::isa<UnknownLoc>(diag.getLocation()))
263 os << diag.getLocation() << ": ";
264 os << "error: ";
266 // The default behavior for errors is to emit them to stderr.
267 os << diag << '\n';
268 os.flush();
271 //===----------------------------------------------------------------------===//
272 // DiagnosticEngine
273 //===----------------------------------------------------------------------===//
275 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
276 DiagnosticEngine::~DiagnosticEngine() = default;
278 /// Register a new handler for diagnostics to the engine. This function returns
279 /// a unique identifier for the registered handler, which can be used to
280 /// unregister this handler at a later time.
281 auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID {
282 llvm::sys::SmartScopedLock<true> lock(impl->mutex);
283 auto uniqueID = impl->uniqueHandlerId++;
284 impl->handlers.insert({uniqueID, std::move(handler)});
285 return uniqueID;
288 /// Erase the registered diagnostic handler with the given identifier.
289 void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
290 llvm::sys::SmartScopedLock<true> lock(impl->mutex);
291 impl->handlers.erase(handlerID);
294 /// Emit a diagnostic using the registered issue handler if present, or with
295 /// the default behavior if not.
296 void DiagnosticEngine::emit(Diagnostic &&diag) {
297 assert(diag.getSeverity() != DiagnosticSeverity::Note &&
298 "notes should not be emitted directly");
299 impl->emit(std::move(diag));
302 /// Helper function used to emit a diagnostic with an optionally empty twine
303 /// message. If the message is empty, then it is not inserted into the
304 /// diagnostic.
305 static InFlightDiagnostic
306 emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
307 MLIRContext *ctx = location->getContext();
308 auto &diagEngine = ctx->getDiagEngine();
309 auto diag = diagEngine.emit(location, severity);
310 if (!message.isTriviallyEmpty())
311 diag << message;
313 // Add the stack trace as a note if necessary.
314 if (ctx->shouldPrintStackTraceOnDiagnostic()) {
315 std::string bt;
317 llvm::raw_string_ostream stream(bt);
318 llvm::sys::PrintStackTrace(stream);
320 if (!bt.empty())
321 diag.attachNote() << "diagnostic emitted with trace:\n" << bt;
324 return diag;
327 /// Emit an error message using this location.
328 InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
329 InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
330 return emitDiag(loc, DiagnosticSeverity::Error, message);
333 /// Emit a warning message using this location.
334 InFlightDiagnostic mlir::emitWarning(Location loc) {
335 return emitWarning(loc, {});
337 InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
338 return emitDiag(loc, DiagnosticSeverity::Warning, message);
341 /// Emit a remark message using this location.
342 InFlightDiagnostic mlir::emitRemark(Location loc) {
343 return emitRemark(loc, {});
345 InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
346 return emitDiag(loc, DiagnosticSeverity::Remark, message);
349 //===----------------------------------------------------------------------===//
350 // ScopedDiagnosticHandler
351 //===----------------------------------------------------------------------===//
353 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
354 if (handlerID)
355 ctx->getDiagEngine().eraseHandler(handlerID);
358 //===----------------------------------------------------------------------===//
359 // SourceMgrDiagnosticHandler
360 //===----------------------------------------------------------------------===//
361 namespace mlir {
362 namespace detail {
363 struct SourceMgrDiagnosticHandlerImpl {
364 /// Return the SrcManager buffer id for the specified file, or zero if none
365 /// can be found.
366 unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
367 StringRef filename) {
368 // Check for an existing mapping to the buffer id for this file.
369 auto bufferIt = filenameToBufId.find(filename);
370 if (bufferIt != filenameToBufId.end())
371 return bufferIt->second;
373 // Look for a buffer in the manager that has this filename.
374 for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
375 auto *buf = mgr.getMemoryBuffer(i);
376 if (buf->getBufferIdentifier() == filename)
377 return filenameToBufId[filename] = i;
380 // Otherwise, try to load the source file.
381 std::string ignored;
382 unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
383 filenameToBufId[filename] = id;
384 return id;
387 /// Mapping between file name and buffer ID's.
388 llvm::StringMap<unsigned> filenameToBufId;
390 } // namespace detail
391 } // namespace mlir
393 /// Return a processable CallSiteLoc from the given location.
394 static std::optional<CallSiteLoc> getCallSiteLoc(Location loc) {
395 if (dyn_cast<NameLoc>(loc))
396 return getCallSiteLoc(cast<NameLoc>(loc).getChildLoc());
397 if (auto callLoc = dyn_cast<CallSiteLoc>(loc))
398 return callLoc;
399 if (dyn_cast<FusedLoc>(loc)) {
400 for (auto subLoc : cast<FusedLoc>(loc).getLocations()) {
401 if (auto callLoc = getCallSiteLoc(subLoc)) {
402 return callLoc;
405 return std::nullopt;
407 return std::nullopt;
410 /// Given a diagnostic kind, returns the LLVM DiagKind.
411 static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
412 switch (kind) {
413 case DiagnosticSeverity::Note:
414 return llvm::SourceMgr::DK_Note;
415 case DiagnosticSeverity::Warning:
416 return llvm::SourceMgr::DK_Warning;
417 case DiagnosticSeverity::Error:
418 return llvm::SourceMgr::DK_Error;
419 case DiagnosticSeverity::Remark:
420 return llvm::SourceMgr::DK_Remark;
422 llvm_unreachable("Unknown DiagnosticSeverity");
425 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
426 llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
427 ShouldShowLocFn &&shouldShowLocFn)
428 : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
429 shouldShowLocFn(std::move(shouldShowLocFn)),
430 impl(new SourceMgrDiagnosticHandlerImpl()) {
431 setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
434 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
435 llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
436 : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
437 std::move(shouldShowLocFn)) {}
439 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
441 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
442 DiagnosticSeverity kind,
443 bool displaySourceLine) {
444 // Extract a file location from this loc.
445 auto fileLoc = loc->findInstanceOf<FileLineColLoc>();
447 // If one doesn't exist, then print the raw message without a source location.
448 if (!fileLoc) {
449 std::string str;
450 llvm::raw_string_ostream strOS(str);
451 if (!llvm::isa<UnknownLoc>(loc))
452 strOS << loc << ": ";
453 strOS << message;
454 return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str());
457 // Otherwise if we are displaying the source line, try to convert the file
458 // location to an SMLoc.
459 if (displaySourceLine) {
460 auto smloc = convertLocToSMLoc(fileLoc);
461 if (smloc.isValid())
462 return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
465 // If the conversion was unsuccessful, create a diagnostic with the file
466 // information. We manually combine the line and column to avoid asserts in
467 // the constructor of SMDiagnostic that takes a location.
468 std::string locStr;
469 llvm::raw_string_ostream locOS(locStr);
470 locOS << fileLoc.getFilename().getValue() << ":" << fileLoc.getLine() << ":"
471 << fileLoc.getColumn();
472 llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
473 diag.print(nullptr, os);
476 /// Emit the given diagnostic with the held source manager.
477 void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
478 SmallVector<std::pair<Location, StringRef>> locationStack;
479 auto addLocToStack = [&](Location loc, StringRef locContext) {
480 if (std::optional<Location> showableLoc = findLocToShow(loc))
481 locationStack.emplace_back(*showableLoc, locContext);
484 // Add locations to display for this diagnostic.
485 Location loc = diag.getLocation();
486 addLocToStack(loc, /*locContext=*/{});
488 // If the diagnostic location was a call site location, add the call stack as
489 // well.
490 if (auto callLoc = getCallSiteLoc(loc)) {
491 // Print the call stack while valid, or until the limit is reached.
492 loc = callLoc->getCaller();
493 for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
494 addLocToStack(loc, "called from");
495 if ((callLoc = getCallSiteLoc(loc)))
496 loc = callLoc->getCaller();
497 else
498 break;
502 // If the location stack is empty, use the initial location.
503 if (locationStack.empty()) {
504 emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());
506 // Otherwise, use the location stack.
507 } else {
508 emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
509 for (auto &it : llvm::drop_begin(locationStack))
510 emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
513 // Emit each of the notes. Only display the source code if the location is
514 // different from the previous location.
515 for (auto &note : diag.getNotes()) {
516 emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
517 /*displaySourceLine=*/loc != note.getLocation());
518 loc = note.getLocation();
522 /// Get a memory buffer for the given file, or nullptr if one is not found.
523 const llvm::MemoryBuffer *
524 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
525 if (unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename))
526 return mgr.getMemoryBuffer(id);
527 return nullptr;
530 std::optional<Location>
531 SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
532 if (!shouldShowLocFn)
533 return loc;
534 if (!shouldShowLocFn(loc))
535 return std::nullopt;
537 // Recurse into the child locations of some of location types.
538 return TypeSwitch<LocationAttr, std::optional<Location>>(loc)
539 .Case([&](CallSiteLoc callLoc) -> std::optional<Location> {
540 // We recurse into the callee of a call site, as the caller will be
541 // emitted in a different note on the main diagnostic.
542 return findLocToShow(callLoc.getCallee());
544 .Case([&](FileLineColLoc) -> std::optional<Location> { return loc; })
545 .Case([&](FusedLoc fusedLoc) -> std::optional<Location> {
546 // Fused location is unique in that we try to find a sub-location to
547 // show, rather than the top-level location itself.
548 for (Location childLoc : fusedLoc.getLocations())
549 if (std::optional<Location> showableLoc = findLocToShow(childLoc))
550 return showableLoc;
551 return std::nullopt;
553 .Case([&](NameLoc nameLoc) -> std::optional<Location> {
554 return findLocToShow(nameLoc.getChildLoc());
556 .Case([&](OpaqueLoc opaqueLoc) -> std::optional<Location> {
557 // OpaqueLoc always falls back to a different source location.
558 return findLocToShow(opaqueLoc.getFallbackLocation());
560 .Case([](UnknownLoc) -> std::optional<Location> {
561 // Prefer not to show unknown locations.
562 return std::nullopt;
566 /// Get a memory buffer for the given file, or the main file of the source
567 /// manager if one doesn't exist. This always returns non-null.
568 SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
569 // The column and line may be zero to represent unknown column and/or unknown
570 /// line/column information.
571 if (loc.getLine() == 0 || loc.getColumn() == 0)
572 return SMLoc();
574 unsigned bufferId = impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
575 if (!bufferId)
576 return SMLoc();
577 return mgr.FindLocForLineAndColumn(bufferId, loc.getLine(), loc.getColumn());
580 //===----------------------------------------------------------------------===//
581 // SourceMgrDiagnosticVerifierHandler
582 //===----------------------------------------------------------------------===//
584 namespace mlir {
585 namespace detail {
586 /// This class represents an expected output diagnostic.
587 struct ExpectedDiag {
588 ExpectedDiag(DiagnosticSeverity kind, unsigned lineNo, SMLoc fileLoc,
589 StringRef substring)
590 : kind(kind), lineNo(lineNo), fileLoc(fileLoc), substring(substring) {}
592 /// Emit an error at the location referenced by this diagnostic.
593 LogicalResult emitError(raw_ostream &os, llvm::SourceMgr &mgr,
594 const Twine &msg) {
595 SMRange range(fileLoc, SMLoc::getFromPointer(fileLoc.getPointer() +
596 substring.size()));
597 mgr.PrintMessage(os, fileLoc, llvm::SourceMgr::DK_Error, msg, range);
598 return failure();
601 /// Returns true if this diagnostic matches the given string.
602 bool match(StringRef str) const {
603 // If this isn't a regex diagnostic, we simply check if the string was
604 // contained.
605 if (substringRegex)
606 return substringRegex->match(str);
607 return str.contains(substring);
610 /// Compute the regex matcher for this diagnostic, using the provided stream
611 /// and manager to emit diagnostics as necessary.
612 LogicalResult computeRegex(raw_ostream &os, llvm::SourceMgr &mgr) {
613 std::string regexStr;
614 llvm::raw_string_ostream regexOS(regexStr);
615 StringRef strToProcess = substring;
616 while (!strToProcess.empty()) {
617 // Find the next regex block.
618 size_t regexIt = strToProcess.find("{{");
619 if (regexIt == StringRef::npos) {
620 regexOS << llvm::Regex::escape(strToProcess);
621 break;
623 regexOS << llvm::Regex::escape(strToProcess.take_front(regexIt));
624 strToProcess = strToProcess.drop_front(regexIt + 2);
626 // Find the end of the regex block.
627 size_t regexEndIt = strToProcess.find("}}");
628 if (regexEndIt == StringRef::npos)
629 return emitError(os, mgr, "found start of regex with no end '}}'");
630 StringRef regexStr = strToProcess.take_front(regexEndIt);
632 // Validate that the regex is actually valid.
633 std::string regexError;
634 if (!llvm::Regex(regexStr).isValid(regexError))
635 return emitError(os, mgr, "invalid regex: " + regexError);
637 regexOS << '(' << regexStr << ')';
638 strToProcess = strToProcess.drop_front(regexEndIt + 2);
640 substringRegex = llvm::Regex(regexOS.str());
641 return success();
644 /// The severity of the diagnosic expected.
645 DiagnosticSeverity kind;
646 /// The line number the expected diagnostic should be on.
647 unsigned lineNo;
648 /// The location of the expected diagnostic within the input file.
649 SMLoc fileLoc;
650 /// A flag indicating if the expected diagnostic has been matched yet.
651 bool matched = false;
652 /// The substring that is expected to be within the diagnostic.
653 StringRef substring;
654 /// An optional regex matcher, if the expected diagnostic sub-string was a
655 /// regex string.
656 std::optional<llvm::Regex> substringRegex;
659 struct SourceMgrDiagnosticVerifierHandlerImpl {
660 SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
662 /// Returns the expected diagnostics for the given source file.
663 std::optional<MutableArrayRef<ExpectedDiag>>
664 getExpectedDiags(StringRef bufName);
666 /// Computes the expected diagnostics for the given source buffer.
667 MutableArrayRef<ExpectedDiag>
668 computeExpectedDiags(raw_ostream &os, llvm::SourceMgr &mgr,
669 const llvm::MemoryBuffer *buf);
671 /// The current status of the verifier.
672 LogicalResult status;
674 /// A list of expected diagnostics for each buffer of the source manager.
675 llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
677 /// Regex to match the expected diagnostics format.
678 llvm::Regex expected =
679 llvm::Regex("expected-(error|note|remark|warning)(-re)? "
680 "*(@([+-][0-9]+|above|below))? *{{(.*)}}$");
682 } // namespace detail
683 } // namespace mlir
685 /// Given a diagnostic kind, return a human readable string for it.
686 static StringRef getDiagKindStr(DiagnosticSeverity kind) {
687 switch (kind) {
688 case DiagnosticSeverity::Note:
689 return "note";
690 case DiagnosticSeverity::Warning:
691 return "warning";
692 case DiagnosticSeverity::Error:
693 return "error";
694 case DiagnosticSeverity::Remark:
695 return "remark";
697 llvm_unreachable("Unknown DiagnosticSeverity");
700 std::optional<MutableArrayRef<ExpectedDiag>>
701 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
702 auto expectedDiags = expectedDiagsPerFile.find(bufName);
703 if (expectedDiags != expectedDiagsPerFile.end())
704 return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
705 return std::nullopt;
708 MutableArrayRef<ExpectedDiag>
709 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
710 raw_ostream &os, llvm::SourceMgr &mgr, const llvm::MemoryBuffer *buf) {
711 // If the buffer is invalid, return an empty list.
712 if (!buf)
713 return std::nullopt;
714 auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
716 // The number of the last line that did not correlate to a designator.
717 unsigned lastNonDesignatorLine = 0;
719 // The indices of designators that apply to the next non designator line.
720 SmallVector<unsigned, 1> designatorsForNextLine;
722 // Scan the file for expected-* designators.
723 SmallVector<StringRef, 100> lines;
724 buf->getBuffer().split(lines, '\n');
725 for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
726 SmallVector<StringRef, 4> matches;
727 if (!expected.match(lines[lineNo].rtrim(), &matches)) {
728 // Check for designators that apply to this line.
729 if (!designatorsForNextLine.empty()) {
730 for (unsigned diagIndex : designatorsForNextLine)
731 expectedDiags[diagIndex].lineNo = lineNo + 1;
732 designatorsForNextLine.clear();
734 lastNonDesignatorLine = lineNo;
735 continue;
738 // Point to the start of expected-*.
739 SMLoc expectedStart = SMLoc::getFromPointer(matches[0].data());
741 DiagnosticSeverity kind;
742 if (matches[1] == "error")
743 kind = DiagnosticSeverity::Error;
744 else if (matches[1] == "warning")
745 kind = DiagnosticSeverity::Warning;
746 else if (matches[1] == "remark")
747 kind = DiagnosticSeverity::Remark;
748 else {
749 assert(matches[1] == "note");
750 kind = DiagnosticSeverity::Note;
752 ExpectedDiag record(kind, lineNo + 1, expectedStart, matches[5]);
754 // Check to see if this is a regex match, i.e. it includes the `-re`.
755 if (!matches[2].empty() && failed(record.computeRegex(os, mgr))) {
756 status = failure();
757 continue;
760 StringRef offsetMatch = matches[3];
761 if (!offsetMatch.empty()) {
762 offsetMatch = offsetMatch.drop_front(1);
764 // Get the integer value without the @ and +/- prefix.
765 if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
766 int offset;
767 offsetMatch.drop_front().getAsInteger(0, offset);
769 if (offsetMatch.front() == '+')
770 record.lineNo += offset;
771 else
772 record.lineNo -= offset;
773 } else if (offsetMatch.consume_front("above")) {
774 // If the designator applies 'above' we add it to the last non
775 // designator line.
776 record.lineNo = lastNonDesignatorLine + 1;
777 } else {
778 // Otherwise, this is a 'below' designator and applies to the next
779 // non-designator line.
780 assert(offsetMatch.consume_front("below"));
781 designatorsForNextLine.push_back(expectedDiags.size());
783 // Set the line number to the last in the case that this designator ends
784 // up dangling.
785 record.lineNo = e;
788 expectedDiags.emplace_back(std::move(record));
790 return expectedDiags;
793 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
794 llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
795 : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
796 impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
797 // Compute the expected diagnostics for each of the current files in the
798 // source manager.
799 for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
800 (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
802 // Register a handler to verify the diagnostics.
803 setHandler([&](Diagnostic &diag) {
804 // Process the main diagnostics.
805 process(diag);
807 // Process each of the notes.
808 for (auto &note : diag.getNotes())
809 process(note);
813 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
814 llvm::SourceMgr &srcMgr, MLIRContext *ctx)
815 : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
817 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
818 // Ensure that all expected diagnostics were handled.
819 (void)verify();
822 /// Returns the status of the verifier and verifies that all expected
823 /// diagnostics were emitted. This return success if all diagnostics were
824 /// verified correctly, failure otherwise.
825 LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
826 // Verify that all expected errors were seen.
827 for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
828 for (auto &err : expectedDiagsPair.second) {
829 if (err.matched)
830 continue;
831 impl->status =
832 err.emitError(os, mgr,
833 "expected " + getDiagKindStr(err.kind) + " \"" +
834 err.substring + "\" was not produced");
837 impl->expectedDiagsPerFile.clear();
838 return impl->status;
841 /// Process a single diagnostic.
842 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
843 auto kind = diag.getSeverity();
845 // Process a FileLineColLoc.
846 if (auto fileLoc = diag.getLocation()->findInstanceOf<FileLineColLoc>())
847 return process(fileLoc, diag.str(), kind);
849 emitDiagnostic(diag.getLocation(),
850 "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
851 DiagnosticSeverity::Error);
852 impl->status = failure();
855 /// Process a FileLineColLoc diagnostic.
856 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
857 StringRef msg,
858 DiagnosticSeverity kind) {
859 // Get the expected diagnostics for this file.
860 auto diags = impl->getExpectedDiags(loc.getFilename());
861 if (!diags) {
862 diags = impl->computeExpectedDiags(os, mgr,
863 getBufferForFile(loc.getFilename()));
866 // Search for a matching expected diagnostic.
867 // If we find something that is close then emit a more specific error.
868 ExpectedDiag *nearMiss = nullptr;
870 // If this was an expected error, remember that we saw it and return.
871 unsigned line = loc.getLine();
872 for (auto &e : *diags) {
873 if (line == e.lineNo && e.match(msg)) {
874 if (e.kind == kind) {
875 e.matched = true;
876 return;
879 // If this only differs based on the diagnostic kind, then consider it
880 // to be a near miss.
881 nearMiss = &e;
885 // Otherwise, emit an error for the near miss.
886 if (nearMiss)
887 mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
888 "'" + getDiagKindStr(kind) +
889 "' diagnostic emitted when expecting a '" +
890 getDiagKindStr(nearMiss->kind) + "'");
891 else
892 emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
893 DiagnosticSeverity::Error);
894 impl->status = failure();
897 //===----------------------------------------------------------------------===//
898 // ParallelDiagnosticHandler
899 //===----------------------------------------------------------------------===//
901 namespace mlir {
902 namespace detail {
903 struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
904 struct ThreadDiagnostic {
905 ThreadDiagnostic(size_t id, Diagnostic diag)
906 : id(id), diag(std::move(diag)) {}
907 bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
909 /// The id for this diagnostic, this is used for ordering.
910 /// Note: This id corresponds to the ordered position of the current element
911 /// being processed by a given thread.
912 size_t id;
914 /// The diagnostic.
915 Diagnostic diag;
918 ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : context(ctx) {
919 handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
920 uint64_t tid = llvm::get_threadid();
921 llvm::sys::SmartScopedLock<true> lock(mutex);
923 // If this thread is not tracked, then return failure to let another
924 // handler process this diagnostic.
925 if (!threadToOrderID.count(tid))
926 return failure();
928 // Append a new diagnostic.
929 diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
930 return success();
934 ~ParallelDiagnosticHandlerImpl() override {
935 // Erase this handler from the context.
936 context->getDiagEngine().eraseHandler(handlerID);
938 // Early exit if there are no diagnostics, this is the common case.
939 if (diagnostics.empty())
940 return;
942 // Emit the diagnostics back to the context.
943 emitDiagnostics([&](Diagnostic &diag) {
944 return context->getDiagEngine().emit(std::move(diag));
948 /// Utility method to emit any held diagnostics.
949 void emitDiagnostics(llvm::function_ref<void(Diagnostic &)> emitFn) const {
950 // Stable sort all of the diagnostics that were emitted. This creates a
951 // deterministic ordering for the diagnostics based upon which order id they
952 // were emitted for.
953 std::stable_sort(diagnostics.begin(), diagnostics.end());
955 // Emit each diagnostic to the context again.
956 for (ThreadDiagnostic &diag : diagnostics)
957 emitFn(diag.diag);
960 /// Set the order id for the current thread.
961 void setOrderIDForThread(size_t orderID) {
962 uint64_t tid = llvm::get_threadid();
963 llvm::sys::SmartScopedLock<true> lock(mutex);
964 threadToOrderID[tid] = orderID;
967 /// Remove the order id for the current thread.
968 void eraseOrderIDForThread() {
969 uint64_t tid = llvm::get_threadid();
970 llvm::sys::SmartScopedLock<true> lock(mutex);
971 threadToOrderID.erase(tid);
974 /// Dump the current diagnostics that were inflight.
975 void print(raw_ostream &os) const override {
976 // Early exit if there are no diagnostics, this is the common case.
977 if (diagnostics.empty())
978 return;
980 os << "In-Flight Diagnostics:\n";
981 emitDiagnostics([&](const Diagnostic &diag) {
982 os.indent(4);
984 // Print each diagnostic with the format:
985 // "<location>: <kind>: <msg>"
986 if (!llvm::isa<UnknownLoc>(diag.getLocation()))
987 os << diag.getLocation() << ": ";
988 switch (diag.getSeverity()) {
989 case DiagnosticSeverity::Error:
990 os << "error: ";
991 break;
992 case DiagnosticSeverity::Warning:
993 os << "warning: ";
994 break;
995 case DiagnosticSeverity::Note:
996 os << "note: ";
997 break;
998 case DiagnosticSeverity::Remark:
999 os << "remark: ";
1000 break;
1002 os << diag << '\n';
1006 /// A smart mutex to lock access to the internal state.
1007 llvm::sys::SmartMutex<true> mutex;
1009 /// A mapping between the thread id and the current order id.
1010 DenseMap<uint64_t, size_t> threadToOrderID;
1012 /// An unordered list of diagnostics that were emitted.
1013 mutable std::vector<ThreadDiagnostic> diagnostics;
1015 /// The unique id for the parallel handler.
1016 DiagnosticEngine::HandlerID handlerID = 0;
1018 /// The context to emit the diagnostics to.
1019 MLIRContext *context;
1021 } // namespace detail
1022 } // namespace mlir
1024 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
1025 : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
1026 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;
1028 /// Set the order id for the current thread.
1029 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
1030 impl->setOrderIDForThread(orderID);
1033 /// Remove the order id for the current thread. This removes the thread from
1034 /// diagnostics tracking.
1035 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
1036 impl->eraseOrderIDForThread();