[TySan] Don't report globals with incomplete types. (#121922)
[llvm-project.git] / mlir / lib / Tools / mlir-lsp-server / MLIRServer.cpp
blob4e19274c3da40749971d9451e7f443ec92fce1fe
1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "MLIRServer.h"
10 #include "Protocol.h"
11 #include "mlir/AsmParser/AsmParser.h"
12 #include "mlir/AsmParser/AsmParserState.h"
13 #include "mlir/AsmParser/CodeComplete.h"
14 #include "mlir/Bytecode/BytecodeWriter.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Interfaces/FunctionInterfaces.h"
17 #include "mlir/Parser/Parser.h"
18 #include "mlir/Support/ToolUtilities.h"
19 #include "mlir/Tools/lsp-server-support/Logging.h"
20 #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Base64.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include <optional>
26 using namespace mlir;
28 /// Returns the range of a lexical token given a SMLoc corresponding to the
29 /// start of an token location. The range is computed heuristically, and
30 /// supports identifier-like tokens, strings, etc.
31 static SMRange convertTokenLocToRange(SMLoc loc) {
32 return lsp::convertTokenLocToRange(loc, "$-.");
35 /// Returns a language server location from the given MLIR file location.
36 /// `uriScheme` is the scheme to use when building new uris.
37 static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38 FileLineColLoc loc) {
39 llvm::Expected<lsp::URIForFile> sourceURI =
40 lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
41 if (!sourceURI) {
42 lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43 loc.getFilename(),
44 llvm::toString(sourceURI.takeError()));
45 return std::nullopt;
48 lsp::Position position;
49 position.line = loc.getLine() - 1;
50 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51 return lsp::Location{*sourceURI, lsp::Range(position)};
54 /// Returns a language server location from the given MLIR location, or
55 /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56 /// when building new uris. `uri` is an optional additional filter that, when
57 /// present, is used to filter sub locations that do not share the same uri.
58 static std::optional<lsp::Location>
59 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60 StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
61 std::optional<lsp::Location> location;
62 loc->walk([&](Location nestedLoc) {
63 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64 if (!fileLoc)
65 return WalkResult::advance();
67 std::optional<lsp::Location> sourceLoc =
68 getLocationFromLoc(uriScheme, fileLoc);
69 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70 location = *sourceLoc;
71 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72 sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
74 // Use range of potential identifier starting at location, else length 1
75 // range.
76 location->range.end.character += 1;
77 if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78 auto lineCol = sourceMgr.getLineAndColumn(range->End);
79 location->range.end.character =
80 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
82 return WalkResult::interrupt();
84 return WalkResult::advance();
85 });
86 return location;
89 /// Collect all of the locations from the given MLIR location that are not
90 /// contained within the given URI.
91 static void collectLocationsFromLoc(Location loc,
92 std::vector<lsp::Location> &locations,
93 const lsp::URIForFile &uri) {
94 SetVector<Location> visitedLocs;
95 loc->walk([&](Location nestedLoc) {
96 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97 if (!fileLoc || !visitedLocs.insert(nestedLoc))
98 return WalkResult::advance();
100 std::optional<lsp::Location> sourceLoc =
101 getLocationFromLoc(uri.scheme(), fileLoc);
102 if (sourceLoc && sourceLoc->uri != uri)
103 locations.push_back(*sourceLoc);
104 return WalkResult::advance();
108 /// Returns true if the given range contains the given source location. Note
109 /// that this has slightly different behavior than SMRange because it is
110 /// inclusive of the end location.
111 static bool contains(SMRange range, SMLoc loc) {
112 return range.Start.getPointer() <= loc.getPointer() &&
113 loc.getPointer() <= range.End.getPointer();
116 /// Returns true if the given location is contained by the definition or one of
117 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118 /// the range within `def` that the provided `loc` overlapped with.
119 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
120 SMRange *overlappedRange = nullptr) {
121 // Check the main definition.
122 if (contains(def.loc, loc)) {
123 if (overlappedRange)
124 *overlappedRange = def.loc;
125 return true;
128 // Check the uses.
129 const auto *useIt = llvm::find_if(
130 def.uses, [&](const SMRange &range) { return contains(range, loc); });
131 if (useIt != def.uses.end()) {
132 if (overlappedRange)
133 *overlappedRange = *useIt;
134 return true;
136 return false;
139 /// Given a location pointing to a result, return the result number it refers
140 /// to or std::nullopt if it refers to all of the results.
141 static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
142 // Skip all of the identifier characters.
143 auto isIdentifierChar = [](char c) {
144 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
145 c == '-';
147 const char *curPtr = loc.getPointer();
148 while (isIdentifierChar(*curPtr))
149 ++curPtr;
151 // Check to see if this location indexes into the result group, via `#`. If it
152 // doesn't, we can't extract a sub result number.
153 if (*curPtr != '#')
154 return std::nullopt;
156 // Compute the sub result number from the remaining portion of the string.
157 const char *numberStart = ++curPtr;
158 while (llvm::isDigit(*curPtr))
159 ++curPtr;
160 StringRef numberStr(numberStart, curPtr - numberStart);
161 unsigned resultNumber = 0;
162 return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
163 : resultNumber;
166 /// Given a source location range, return the text covered by the given range.
167 /// If the range is invalid, returns std::nullopt.
168 static std::optional<StringRef> getTextFromRange(SMRange range) {
169 if (!range.isValid())
170 return std::nullopt;
171 const char *startPtr = range.Start.getPointer();
172 return StringRef(startPtr, range.End.getPointer() - startPtr);
175 /// Given a block, return its position in its parent region.
176 static unsigned getBlockNumber(Block *block) {
177 return std::distance(block->getParent()->begin(), block->getIterator());
180 /// Given a block and source location, print the source name of the block to the
181 /// given output stream.
182 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
183 // Try to extract a name from the source location.
184 std::optional<StringRef> text = getTextFromRange(loc);
185 if (text && text->starts_with("^")) {
186 os << *text;
187 return;
190 // Otherwise, we don't have a name so print the block number.
191 os << "<Block #" << getBlockNumber(block) << ">";
193 static void printDefBlockName(raw_ostream &os,
194 const AsmParserState::BlockDefinition &def) {
195 printDefBlockName(os, def.block, def.definition.loc);
198 /// Convert the given MLIR diagnostic to the LSP form.
199 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
200 Diagnostic &diag,
201 const lsp::URIForFile &uri) {
202 lsp::Diagnostic lspDiag;
203 lspDiag.source = "mlir";
205 // Note: Right now all of the diagnostics are treated as parser issues, but
206 // some are parser and some are verifier.
207 lspDiag.category = "Parse Error";
209 // Try to grab a file location for this diagnostic.
210 // TODO: For simplicity, we just grab the first one. It may be likely that we
211 // will need a more interesting heuristic here.'
212 StringRef uriScheme = uri.scheme();
213 std::optional<lsp::Location> lspLocation =
214 getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
215 if (lspLocation)
216 lspDiag.range = lspLocation->range;
218 // Convert the severity for the diagnostic.
219 switch (diag.getSeverity()) {
220 case DiagnosticSeverity::Note:
221 llvm_unreachable("expected notes to be handled separately");
222 case DiagnosticSeverity::Warning:
223 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
224 break;
225 case DiagnosticSeverity::Error:
226 lspDiag.severity = lsp::DiagnosticSeverity::Error;
227 break;
228 case DiagnosticSeverity::Remark:
229 lspDiag.severity = lsp::DiagnosticSeverity::Information;
230 break;
232 lspDiag.message = diag.str();
234 // Attach any notes to the main diagnostic as related information.
235 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
236 for (Diagnostic &note : diag.getNotes()) {
237 lsp::Location noteLoc;
238 if (std::optional<lsp::Location> loc =
239 getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
240 noteLoc = *loc;
241 else
242 noteLoc.uri = uri;
243 relatedDiags.emplace_back(noteLoc, note.str());
245 if (!relatedDiags.empty())
246 lspDiag.relatedInformation = std::move(relatedDiags);
248 return lspDiag;
251 //===----------------------------------------------------------------------===//
252 // MLIRDocument
253 //===----------------------------------------------------------------------===//
255 namespace {
256 /// This class represents all of the information pertaining to a specific MLIR
257 /// document.
258 struct MLIRDocument {
259 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
260 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
261 MLIRDocument(const MLIRDocument &) = delete;
262 MLIRDocument &operator=(const MLIRDocument &) = delete;
264 //===--------------------------------------------------------------------===//
265 // Definitions and References
266 //===--------------------------------------------------------------------===//
268 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
269 std::vector<lsp::Location> &locations);
270 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
271 std::vector<lsp::Location> &references);
273 //===--------------------------------------------------------------------===//
274 // Hover
275 //===--------------------------------------------------------------------===//
277 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
278 const lsp::Position &hoverPos);
279 std::optional<lsp::Hover>
280 buildHoverForOperation(SMRange hoverRange,
281 const AsmParserState::OperationDefinition &op);
282 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
283 unsigned resultStart,
284 unsigned resultEnd, SMLoc posLoc);
285 lsp::Hover buildHoverForBlock(SMRange hoverRange,
286 const AsmParserState::BlockDefinition &block);
287 lsp::Hover
288 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
289 const AsmParserState::BlockDefinition &block);
291 lsp::Hover buildHoverForAttributeAlias(
292 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
293 lsp::Hover
294 buildHoverForTypeAlias(SMRange hoverRange,
295 const AsmParserState::TypeAliasDefinition &type);
297 //===--------------------------------------------------------------------===//
298 // Document Symbols
299 //===--------------------------------------------------------------------===//
301 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
302 void findDocumentSymbols(Operation *op,
303 std::vector<lsp::DocumentSymbol> &symbols);
305 //===--------------------------------------------------------------------===//
306 // Code Completion
307 //===--------------------------------------------------------------------===//
309 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
310 const lsp::Position &completePos,
311 const DialectRegistry &registry);
313 //===--------------------------------------------------------------------===//
314 // Code Action
315 //===--------------------------------------------------------------------===//
317 void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
318 lsp::Position &pos, StringRef severity,
319 StringRef message,
320 std::vector<lsp::TextEdit> &edits);
322 //===--------------------------------------------------------------------===//
323 // Bytecode
324 //===--------------------------------------------------------------------===//
326 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
328 //===--------------------------------------------------------------------===//
329 // Fields
330 //===--------------------------------------------------------------------===//
332 /// The high level parser state used to find definitions and references within
333 /// the source file.
334 AsmParserState asmState;
336 /// The container for the IR parsed from the input file.
337 Block parsedIR;
339 /// A collection of external resources, which we want to propagate up to the
340 /// user.
341 FallbackAsmResourceMap fallbackResourceMap;
343 /// The source manager containing the contents of the input file.
344 llvm::SourceMgr sourceMgr;
346 } // namespace
348 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
349 StringRef contents,
350 std::vector<lsp::Diagnostic> &diagnostics) {
351 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
352 diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
355 // Try to parsed the given IR string.
356 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
357 if (!memBuffer) {
358 lsp::Logger::error("Failed to create memory buffer for file", uri.file());
359 return;
362 ParserConfig config(&context, /*verifyAfterParse=*/true,
363 &fallbackResourceMap);
364 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
365 if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
366 // If parsing failed, clear out any of the current state.
367 parsedIR.clear();
368 asmState = AsmParserState();
369 fallbackResourceMap = FallbackAsmResourceMap();
370 return;
374 //===----------------------------------------------------------------------===//
375 // MLIRDocument: Definitions and References
376 //===----------------------------------------------------------------------===//
378 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
379 const lsp::Position &defPos,
380 std::vector<lsp::Location> &locations) {
381 SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
383 // Functor used to check if an SM definition contains the position.
384 auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
385 if (!isDefOrUse(def, posLoc))
386 return false;
387 locations.emplace_back(uri, sourceMgr, def.loc);
388 return true;
391 // Check all definitions related to operations.
392 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
393 if (contains(op.loc, posLoc))
394 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
395 for (const auto &result : op.resultGroups)
396 if (containsPosition(result.definition))
397 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
398 for (const auto &symUse : op.symbolUses) {
399 if (contains(symUse, posLoc)) {
400 locations.emplace_back(uri, sourceMgr, op.loc);
401 return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
406 // Check all definitions related to blocks.
407 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
408 if (containsPosition(block.definition))
409 return;
410 for (const AsmParserState::SMDefinition &arg : block.arguments)
411 if (containsPosition(arg))
412 return;
415 // Check all alias definitions.
416 for (const AsmParserState::AttributeAliasDefinition &attr :
417 asmState.getAttributeAliasDefs()) {
418 if (containsPosition(attr.definition))
419 return;
421 for (const AsmParserState::TypeAliasDefinition &type :
422 asmState.getTypeAliasDefs()) {
423 if (containsPosition(type.definition))
424 return;
428 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
429 const lsp::Position &pos,
430 std::vector<lsp::Location> &references) {
431 // Functor used to append all of the definitions/uses of the given SM
432 // definition to the reference list.
433 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
434 references.emplace_back(uri, sourceMgr, def.loc);
435 for (const SMRange &use : def.uses)
436 references.emplace_back(uri, sourceMgr, use);
439 SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
441 // Check all definitions related to operations.
442 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
443 if (contains(op.loc, posLoc)) {
444 for (const auto &result : op.resultGroups)
445 appendSMDef(result.definition);
446 for (const auto &symUse : op.symbolUses)
447 if (contains(symUse, posLoc))
448 references.emplace_back(uri, sourceMgr, symUse);
449 return;
451 for (const auto &result : op.resultGroups)
452 if (isDefOrUse(result.definition, posLoc))
453 return appendSMDef(result.definition);
454 for (const auto &symUse : op.symbolUses) {
455 if (!contains(symUse, posLoc))
456 continue;
457 for (const auto &symUse : op.symbolUses)
458 references.emplace_back(uri, sourceMgr, symUse);
459 return;
463 // Check all definitions related to blocks.
464 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
465 if (isDefOrUse(block.definition, posLoc))
466 return appendSMDef(block.definition);
468 for (const AsmParserState::SMDefinition &arg : block.arguments)
469 if (isDefOrUse(arg, posLoc))
470 return appendSMDef(arg);
473 // Check all alias definitions.
474 for (const AsmParserState::AttributeAliasDefinition &attr :
475 asmState.getAttributeAliasDefs()) {
476 if (isDefOrUse(attr.definition, posLoc))
477 return appendSMDef(attr.definition);
479 for (const AsmParserState::TypeAliasDefinition &type :
480 asmState.getTypeAliasDefs()) {
481 if (isDefOrUse(type.definition, posLoc))
482 return appendSMDef(type.definition);
486 //===----------------------------------------------------------------------===//
487 // MLIRDocument: Hover
488 //===----------------------------------------------------------------------===//
490 std::optional<lsp::Hover>
491 MLIRDocument::findHover(const lsp::URIForFile &uri,
492 const lsp::Position &hoverPos) {
493 SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
494 SMRange hoverRange;
496 // Check for Hovers on operations and results.
497 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
498 // Check if the position points at this operation.
499 if (contains(op.loc, posLoc))
500 return buildHoverForOperation(op.loc, op);
502 // Check if the position points at the symbol name.
503 for (auto &use : op.symbolUses)
504 if (contains(use, posLoc))
505 return buildHoverForOperation(use, op);
507 // Check if the position points at a result group.
508 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
509 const auto &result = op.resultGroups[i];
510 if (!isDefOrUse(result.definition, posLoc, &hoverRange))
511 continue;
513 // Get the range of results covered by the over position.
514 unsigned resultStart = result.startIndex;
515 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
516 : op.resultGroups[i + 1].startIndex;
517 return buildHoverForOperationResult(hoverRange, op.op, resultStart,
518 resultEnd, posLoc);
522 // Check to see if the hover is over a block argument.
523 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
524 if (isDefOrUse(block.definition, posLoc, &hoverRange))
525 return buildHoverForBlock(hoverRange, block);
527 for (const auto &arg : llvm::enumerate(block.arguments)) {
528 if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
529 continue;
531 return buildHoverForBlockArgument(
532 hoverRange, block.block->getArgument(arg.index()), block);
536 // Check to see if the hover is over an alias.
537 for (const AsmParserState::AttributeAliasDefinition &attr :
538 asmState.getAttributeAliasDefs()) {
539 if (isDefOrUse(attr.definition, posLoc, &hoverRange))
540 return buildHoverForAttributeAlias(hoverRange, attr);
542 for (const AsmParserState::TypeAliasDefinition &type :
543 asmState.getTypeAliasDefs()) {
544 if (isDefOrUse(type.definition, posLoc, &hoverRange))
545 return buildHoverForTypeAlias(hoverRange, type);
548 return std::nullopt;
551 std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
552 SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
553 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
554 llvm::raw_string_ostream os(hover.contents.value);
556 // Add the operation name to the hover.
557 os << "\"" << op.op->getName() << "\"";
558 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
559 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
560 os << "\n\n";
562 os << "Generic Form:\n\n```mlir\n";
564 op.op->print(os, OpPrintingFlags()
565 .printGenericOpForm()
566 .elideLargeElementsAttrs()
567 .skipRegions());
568 os << "\n```\n";
570 return hover;
573 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
574 Operation *op,
575 unsigned resultStart,
576 unsigned resultEnd,
577 SMLoc posLoc) {
578 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579 llvm::raw_string_ostream os(hover.contents.value);
581 // Add the parent operation name to the hover.
582 os << "Operation: \"" << op->getName() << "\"\n\n";
584 // Check to see if the location points to a specific result within the
585 // group.
586 if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
587 if ((resultStart + *resultNumber) < resultEnd) {
588 resultStart += *resultNumber;
589 resultEnd = resultStart + 1;
593 // Add the range of results and their types to the hover info.
594 if ((resultStart + 1) == resultEnd) {
595 os << "Result #" << resultStart << "\n\n"
596 << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
597 } else {
598 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
599 << "Types: ";
600 llvm::interleaveComma(
601 op->getResults().slice(resultStart, resultEnd), os,
602 [&](Value result) { os << "`" << result.getType() << "`"; });
605 return hover;
608 lsp::Hover
609 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
610 const AsmParserState::BlockDefinition &block) {
611 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
612 llvm::raw_string_ostream os(hover.contents.value);
614 // Print the given block to the hover output stream.
615 auto printBlockToHover = [&](Block *newBlock) {
616 if (const auto *def = asmState.getBlockDef(newBlock))
617 printDefBlockName(os, *def);
618 else
619 printDefBlockName(os, newBlock);
622 // Display the parent operation, block number, predecessors, and successors.
623 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
624 << "Block #" << getBlockNumber(block.block) << "\n\n";
625 if (!block.block->hasNoPredecessors()) {
626 os << "Predecessors: ";
627 llvm::interleaveComma(block.block->getPredecessors(), os,
628 printBlockToHover);
629 os << "\n\n";
631 if (!block.block->hasNoSuccessors()) {
632 os << "Successors: ";
633 llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
634 os << "\n\n";
637 return hover;
640 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
641 SMRange hoverRange, BlockArgument arg,
642 const AsmParserState::BlockDefinition &block) {
643 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
644 llvm::raw_string_ostream os(hover.contents.value);
646 // Display the parent operation, block, the argument number, and the type.
647 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
648 << "Block: ";
649 printDefBlockName(os, block);
650 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
651 << "Type: `" << arg.getType() << "`\n\n";
653 return hover;
656 lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
657 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
658 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
659 llvm::raw_string_ostream os(hover.contents.value);
661 os << "Attribute Alias: \"" << attr.name << "\n\n";
662 os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
664 return hover;
667 lsp::Hover MLIRDocument::buildHoverForTypeAlias(
668 SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
669 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
670 llvm::raw_string_ostream os(hover.contents.value);
672 os << "Type Alias: \"" << type.name << "\n\n";
673 os << "Value: ```mlir\n" << type.value << "\n```\n\n";
675 return hover;
678 //===----------------------------------------------------------------------===//
679 // MLIRDocument: Document Symbols
680 //===----------------------------------------------------------------------===//
682 void MLIRDocument::findDocumentSymbols(
683 std::vector<lsp::DocumentSymbol> &symbols) {
684 for (Operation &op : parsedIR)
685 findDocumentSymbols(&op, symbols);
688 void MLIRDocument::findDocumentSymbols(
689 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
690 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
692 // Check for the source information of this operation.
693 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
694 // If this operation defines a symbol, record it.
695 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
696 symbols.emplace_back(symbol.getName(),
697 isa<FunctionOpInterface>(op)
698 ? lsp::SymbolKind::Function
699 : lsp::SymbolKind::Class,
700 lsp::Range(sourceMgr, def->scopeLoc),
701 lsp::Range(sourceMgr, def->loc));
702 childSymbols = &symbols.back().children;
704 } else if (op->hasTrait<OpTrait::SymbolTable>()) {
705 // Otherwise, if this is a symbol table push an anonymous document symbol.
706 symbols.emplace_back("<" + op->getName().getStringRef() + ">",
707 lsp::SymbolKind::Namespace,
708 lsp::Range(sourceMgr, def->scopeLoc),
709 lsp::Range(sourceMgr, def->loc));
710 childSymbols = &symbols.back().children;
714 // Recurse into the regions of this operation.
715 if (!op->getNumRegions())
716 return;
717 for (Region &region : op->getRegions())
718 for (Operation &childOp : region.getOps())
719 findDocumentSymbols(&childOp, *childSymbols);
722 //===----------------------------------------------------------------------===//
723 // MLIRDocument: Code Completion
724 //===----------------------------------------------------------------------===//
726 namespace {
727 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
728 public:
729 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
730 MLIRContext *ctx)
731 : AsmParserCodeCompleteContext(completeLoc),
732 completionList(completionList), ctx(ctx) {}
734 /// Signal code completion for a dialect name, with an optional prefix.
735 void completeDialectName(StringRef prefix) final {
736 for (StringRef dialect : ctx->getAvailableDialects()) {
737 lsp::CompletionItem item(prefix + dialect,
738 lsp::CompletionItemKind::Module,
739 /*sortText=*/"3");
740 item.detail = "dialect";
741 completionList.items.emplace_back(item);
744 using AsmParserCodeCompleteContext::completeDialectName;
746 /// Signal code completion for an operation name within the given dialect.
747 void completeOperationName(StringRef dialectName) final {
748 Dialect *dialect = ctx->getOrLoadDialect(dialectName);
749 if (!dialect)
750 return;
752 for (const auto &op : ctx->getRegisteredOperations()) {
753 if (&op.getDialect() != dialect)
754 continue;
756 lsp::CompletionItem item(
757 op.getStringRef().drop_front(dialectName.size() + 1),
758 lsp::CompletionItemKind::Field,
759 /*sortText=*/"1");
760 item.detail = "operation";
761 completionList.items.emplace_back(item);
765 /// Append the given SSA value as a code completion result for SSA value
766 /// completions.
767 void appendSSAValueCompletion(StringRef name, std::string typeData) final {
768 // Check if we need to insert the `%` or not.
769 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
771 lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
772 if (stripPrefix)
773 item.insertText = name.drop_front(1).str();
774 item.detail = std::move(typeData);
775 completionList.items.emplace_back(item);
778 /// Append the given block as a code completion result for block name
779 /// completions.
780 void appendBlockCompletion(StringRef name) final {
781 // Check if we need to insert the `^` or not.
782 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
784 lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
785 if (stripPrefix)
786 item.insertText = name.drop_front(1).str();
787 completionList.items.emplace_back(item);
790 /// Signal a completion for the given expected token.
791 void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
792 for (StringRef token : tokens) {
793 lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
794 /*sortText=*/"0");
795 item.detail = optional ? "optional" : "";
796 completionList.items.emplace_back(item);
800 /// Signal a completion for an attribute.
801 void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
802 appendSimpleCompletions({"affine_set", "affine_map", "dense",
803 "dense_resource", "false", "loc", "sparse", "true",
804 "unit"},
805 lsp::CompletionItemKind::Field,
806 /*sortText=*/"1");
808 completeDialectName("#");
809 completeAliases(aliases, "#");
811 void completeDialectAttributeOrAlias(
812 const llvm::StringMap<Attribute> &aliases) override {
813 completeDialectName();
814 completeAliases(aliases);
817 /// Signal a completion for a type.
818 void completeType(const llvm::StringMap<Type> &aliases) override {
819 // Handle the various builtin types.
820 appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
821 "bf16", "f16", "f32", "f64", "f80", "f128",
822 "index", "none"},
823 lsp::CompletionItemKind::Field,
824 /*sortText=*/"1");
826 // Handle the builtin integer types.
827 for (StringRef type : {"i", "si", "ui"}) {
828 lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field,
829 /*sortText=*/"1");
830 item.insertText = type.str();
831 completionList.items.emplace_back(item);
834 // Insert completions for dialect types and aliases.
835 completeDialectName("!");
836 completeAliases(aliases, "!");
838 void
839 completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
840 completeDialectName();
841 completeAliases(aliases);
844 /// Add completion results for the given set of aliases.
845 template <typename T>
846 void completeAliases(const llvm::StringMap<T> &aliases,
847 StringRef prefix = "") {
848 for (const auto &alias : aliases) {
849 lsp::CompletionItem item(prefix + alias.getKey(),
850 lsp::CompletionItemKind::Field,
851 /*sortText=*/"2");
852 llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
853 completionList.items.emplace_back(item);
857 /// Add a set of simple completions that all have the same kind.
858 void appendSimpleCompletions(ArrayRef<StringRef> completions,
859 lsp::CompletionItemKind kind,
860 StringRef sortText = "") {
861 for (StringRef completion : completions)
862 completionList.items.emplace_back(completion, kind, sortText);
865 private:
866 lsp::CompletionList &completionList;
867 MLIRContext *ctx;
869 } // namespace
871 lsp::CompletionList
872 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
873 const lsp::Position &completePos,
874 const DialectRegistry &registry) {
875 SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
876 if (!posLoc.isValid())
877 return lsp::CompletionList();
879 // To perform code completion, we run another parse of the module with the
880 // code completion context provided.
881 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
882 tmpContext.allowUnregisteredDialects();
883 lsp::CompletionList completionList;
884 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
885 &tmpContext);
887 Block tmpIR;
888 AsmParserState tmpState;
889 (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
890 &lspCompleteContext);
891 return completionList;
894 //===----------------------------------------------------------------------===//
895 // MLIRDocument: Code Action
896 //===----------------------------------------------------------------------===//
898 void MLIRDocument::getCodeActionForDiagnostic(
899 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
900 StringRef message, std::vector<lsp::TextEdit> &edits) {
901 // Ignore diagnostics that print the current operation. These are always
902 // enabled for the language server, but not generally during normal
903 // parsing/verification.
904 if (message.starts_with("see current operation: "))
905 return;
907 // Get the start of the line containing the diagnostic.
908 const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
909 const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
910 if (!lineStart)
911 return;
912 StringRef line(lineStart, pos.character);
914 // Add a text edit for adding an expected-* diagnostic check for this
915 // diagnostic.
916 lsp::TextEdit edit;
917 edit.range = lsp::Range(lsp::Position(pos.line, 0));
919 // Use the indent of the current line for the expected-* diagnostic.
920 size_t indent = line.find_first_not_of(' ');
921 if (indent == StringRef::npos)
922 indent = line.size();
924 edit.newText.append(indent, ' ');
925 llvm::raw_string_ostream(edit.newText)
926 << "// expected-" << severity << " @below {{" << message << "}}\n";
927 edits.emplace_back(std::move(edit));
930 //===----------------------------------------------------------------------===//
931 // MLIRDocument: Bytecode
932 //===----------------------------------------------------------------------===//
934 llvm::Expected<lsp::MLIRConvertBytecodeResult>
935 MLIRDocument::convertToBytecode() {
936 // TODO: We currently require a single top-level operation, but this could
937 // conceptually be relaxed.
938 if (!llvm::hasSingleElement(parsedIR)) {
939 if (parsedIR.empty()) {
940 return llvm::make_error<lsp::LSPError>(
941 "expected a single and valid top-level operation, please ensure "
942 "there are no errors",
943 lsp::ErrorCode::RequestFailed);
945 return llvm::make_error<lsp::LSPError>(
946 "expected a single top-level operation", lsp::ErrorCode::RequestFailed);
949 lsp::MLIRConvertBytecodeResult result;
951 BytecodeWriterConfig writerConfig(fallbackResourceMap);
953 std::string rawBytecodeBuffer;
954 llvm::raw_string_ostream os(rawBytecodeBuffer);
955 // No desired bytecode version set, so no need to check for error.
956 (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
957 result.output = llvm::encodeBase64(rawBytecodeBuffer);
959 return result;
962 //===----------------------------------------------------------------------===//
963 // MLIRTextFileChunk
964 //===----------------------------------------------------------------------===//
966 namespace {
967 /// This class represents a single chunk of an MLIR text file.
968 struct MLIRTextFileChunk {
969 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
970 const lsp::URIForFile &uri, StringRef contents,
971 std::vector<lsp::Diagnostic> &diagnostics)
972 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
974 /// Adjust the line number of the given range to anchor at the beginning of
975 /// the file, instead of the beginning of this chunk.
976 void adjustLocForChunkOffset(lsp::Range &range) {
977 adjustLocForChunkOffset(range.start);
978 adjustLocForChunkOffset(range.end);
980 /// Adjust the line number of the given position to anchor at the beginning of
981 /// the file, instead of the beginning of this chunk.
982 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
984 /// The line offset of this chunk from the beginning of the file.
985 uint64_t lineOffset;
986 /// The document referred to by this chunk.
987 MLIRDocument document;
989 } // namespace
991 //===----------------------------------------------------------------------===//
992 // MLIRTextFile
993 //===----------------------------------------------------------------------===//
995 namespace {
996 /// This class represents a text file containing one or more MLIR documents.
997 class MLIRTextFile {
998 public:
999 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1000 int64_t version, DialectRegistry &registry,
1001 std::vector<lsp::Diagnostic> &diagnostics);
1003 /// Return the current version of this text file.
1004 int64_t getVersion() const { return version; }
1006 //===--------------------------------------------------------------------===//
1007 // LSP Queries
1008 //===--------------------------------------------------------------------===//
1010 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1011 std::vector<lsp::Location> &locations);
1012 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1013 std::vector<lsp::Location> &references);
1014 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1015 lsp::Position hoverPos);
1016 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1017 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1018 lsp::Position completePos);
1019 void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1020 const lsp::CodeActionContext &context,
1021 std::vector<lsp::CodeAction> &actions);
1022 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1024 private:
1025 /// Find the MLIR document that contains the given position, and update the
1026 /// position to be anchored at the start of the found chunk instead of the
1027 /// beginning of the file.
1028 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1030 /// The context used to hold the state contained by the parsed document.
1031 MLIRContext context;
1033 /// The full string contents of the file.
1034 std::string contents;
1036 /// The version of this file.
1037 int64_t version;
1039 /// The number of lines in the file.
1040 int64_t totalNumLines = 0;
1042 /// The chunks of this file. The order of these chunks is the order in which
1043 /// they appear in the text file.
1044 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1046 } // namespace
1048 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1049 int64_t version, DialectRegistry &registry,
1050 std::vector<lsp::Diagnostic> &diagnostics)
1051 : context(registry, MLIRContext::Threading::DISABLED),
1052 contents(fileContents.str()), version(version) {
1053 context.allowUnregisteredDialects();
1055 // Split the file into separate MLIR documents.
1056 SmallVector<StringRef, 8> subContents;
1057 StringRef(contents).split(subContents, kDefaultSplitMarker);
1058 chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1059 context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
1061 uint64_t lineOffset = subContents.front().count('\n');
1062 for (StringRef docContents : llvm::drop_begin(subContents)) {
1063 unsigned currentNumDiags = diagnostics.size();
1064 auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1065 docContents, diagnostics);
1066 lineOffset += docContents.count('\n');
1068 // Adjust locations used in diagnostics to account for the offset from the
1069 // beginning of the file.
1070 for (lsp::Diagnostic &diag :
1071 llvm::drop_begin(diagnostics, currentNumDiags)) {
1072 chunk->adjustLocForChunkOffset(diag.range);
1074 if (!diag.relatedInformation)
1075 continue;
1076 for (auto &it : *diag.relatedInformation)
1077 if (it.location.uri == uri)
1078 chunk->adjustLocForChunkOffset(it.location.range);
1080 chunks.emplace_back(std::move(chunk));
1082 totalNumLines = lineOffset;
1085 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1086 lsp::Position defPos,
1087 std::vector<lsp::Location> &locations) {
1088 MLIRTextFileChunk &chunk = getChunkFor(defPos);
1089 chunk.document.getLocationsOf(uri, defPos, locations);
1091 // Adjust any locations within this file for the offset of this chunk.
1092 if (chunk.lineOffset == 0)
1093 return;
1094 for (lsp::Location &loc : locations)
1095 if (loc.uri == uri)
1096 chunk.adjustLocForChunkOffset(loc.range);
1099 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1100 lsp::Position pos,
1101 std::vector<lsp::Location> &references) {
1102 MLIRTextFileChunk &chunk = getChunkFor(pos);
1103 chunk.document.findReferencesOf(uri, pos, references);
1105 // Adjust any locations within this file for the offset of this chunk.
1106 if (chunk.lineOffset == 0)
1107 return;
1108 for (lsp::Location &loc : references)
1109 if (loc.uri == uri)
1110 chunk.adjustLocForChunkOffset(loc.range);
1113 std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1114 lsp::Position hoverPos) {
1115 MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1116 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1118 // Adjust any locations within this file for the offset of this chunk.
1119 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1120 chunk.adjustLocForChunkOffset(*hoverInfo->range);
1121 return hoverInfo;
1124 void MLIRTextFile::findDocumentSymbols(
1125 std::vector<lsp::DocumentSymbol> &symbols) {
1126 if (chunks.size() == 1)
1127 return chunks.front()->document.findDocumentSymbols(symbols);
1129 // If there are multiple chunks in this file, we create top-level symbols for
1130 // each chunk.
1131 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1132 MLIRTextFileChunk &chunk = *chunks[i];
1133 lsp::Position startPos(chunk.lineOffset);
1134 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1135 : chunks[i + 1]->lineOffset);
1136 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1137 lsp::SymbolKind::Namespace,
1138 /*range=*/lsp::Range(startPos, endPos),
1139 /*selectionRange=*/lsp::Range(startPos));
1140 chunk.document.findDocumentSymbols(symbol.children);
1142 // Fixup the locations of document symbols within this chunk.
1143 if (i != 0) {
1144 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1145 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1146 symbolsToFix.push_back(&childSymbol);
1148 while (!symbolsToFix.empty()) {
1149 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1150 chunk.adjustLocForChunkOffset(symbol->range);
1151 chunk.adjustLocForChunkOffset(symbol->selectionRange);
1153 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1154 symbolsToFix.push_back(&childSymbol);
1158 // Push the symbol for this chunk.
1159 symbols.emplace_back(std::move(symbol));
1163 lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1164 lsp::Position completePos) {
1165 MLIRTextFileChunk &chunk = getChunkFor(completePos);
1166 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1167 uri, completePos, context.getDialectRegistry());
1169 // Adjust any completion locations.
1170 for (lsp::CompletionItem &item : completionList.items) {
1171 if (item.textEdit)
1172 chunk.adjustLocForChunkOffset(item.textEdit->range);
1173 for (lsp::TextEdit &edit : item.additionalTextEdits)
1174 chunk.adjustLocForChunkOffset(edit.range);
1176 return completionList;
1179 void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1180 const lsp::Range &pos,
1181 const lsp::CodeActionContext &context,
1182 std::vector<lsp::CodeAction> &actions) {
1183 // Create actions for any diagnostics in this file.
1184 for (auto &diag : context.diagnostics) {
1185 if (diag.source != "mlir")
1186 continue;
1187 lsp::Position diagPos = diag.range.start;
1188 MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1190 // Add a new code action that inserts a "expected" diagnostic check.
1191 lsp::CodeAction action;
1192 action.title = "Add expected-* diagnostic checks";
1193 action.kind = lsp::CodeAction::kQuickFix.str();
1195 StringRef severity;
1196 switch (diag.severity) {
1197 case lsp::DiagnosticSeverity::Error:
1198 severity = "error";
1199 break;
1200 case lsp::DiagnosticSeverity::Warning:
1201 severity = "warning";
1202 break;
1203 default:
1204 continue;
1207 // Get edits for the diagnostic.
1208 std::vector<lsp::TextEdit> edits;
1209 chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1210 diag.message, edits);
1212 // Walk the related diagnostics, this is how we encode notes.
1213 if (diag.relatedInformation) {
1214 for (auto &noteDiag : *diag.relatedInformation) {
1215 if (noteDiag.location.uri != uri)
1216 continue;
1217 diagPos = noteDiag.location.range.start;
1218 diagPos.line -= chunk.lineOffset;
1219 chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1220 noteDiag.message, edits);
1223 // Fixup the locations for any edits.
1224 for (lsp::TextEdit &edit : edits)
1225 chunk.adjustLocForChunkOffset(edit.range);
1227 action.edit.emplace();
1228 action.edit->changes[uri.uri().str()] = std::move(edits);
1229 action.diagnostics = {diag};
1231 actions.emplace_back(std::move(action));
1235 llvm::Expected<lsp::MLIRConvertBytecodeResult>
1236 MLIRTextFile::convertToBytecode() {
1237 // Bail out if there is more than one chunk, bytecode wants a single module.
1238 if (chunks.size() != 1) {
1239 return llvm::make_error<lsp::LSPError>(
1240 "unexpected split file, please remove all `// -----`",
1241 lsp::ErrorCode::RequestFailed);
1243 return chunks.front()->document.convertToBytecode();
1246 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1247 if (chunks.size() == 1)
1248 return *chunks.front();
1250 // Search for the first chunk with a greater line offset, the previous chunk
1251 // is the one that contains `pos`.
1252 auto it = llvm::upper_bound(
1253 chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1254 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1256 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1257 pos.line -= chunk.lineOffset;
1258 return chunk;
1261 //===----------------------------------------------------------------------===//
1262 // MLIRServer::Impl
1263 //===----------------------------------------------------------------------===//
1265 struct lsp::MLIRServer::Impl {
1266 Impl(DialectRegistry &registry) : registry(registry) {}
1268 /// The registry containing dialects that can be recognized in parsed .mlir
1269 /// files.
1270 DialectRegistry &registry;
1272 /// The files held by the server, mapped by their URI file name.
1273 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1276 //===----------------------------------------------------------------------===//
1277 // MLIRServer
1278 //===----------------------------------------------------------------------===//
1280 lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
1281 : impl(std::make_unique<Impl>(registry)) {}
1282 lsp::MLIRServer::~MLIRServer() = default;
1284 void lsp::MLIRServer::addOrUpdateDocument(
1285 const URIForFile &uri, StringRef contents, int64_t version,
1286 std::vector<Diagnostic> &diagnostics) {
1287 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288 uri, contents, version, impl->registry, diagnostics);
1291 std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1292 auto it = impl->files.find(uri.file());
1293 if (it == impl->files.end())
1294 return std::nullopt;
1296 int64_t version = it->second->getVersion();
1297 impl->files.erase(it);
1298 return version;
1301 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
1302 const Position &defPos,
1303 std::vector<Location> &locations) {
1304 auto fileIt = impl->files.find(uri.file());
1305 if (fileIt != impl->files.end())
1306 fileIt->second->getLocationsOf(uri, defPos, locations);
1309 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
1310 const Position &pos,
1311 std::vector<Location> &references) {
1312 auto fileIt = impl->files.find(uri.file());
1313 if (fileIt != impl->files.end())
1314 fileIt->second->findReferencesOf(uri, pos, references);
1317 std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1318 const Position &hoverPos) {
1319 auto fileIt = impl->files.find(uri.file());
1320 if (fileIt != impl->files.end())
1321 return fileIt->second->findHover(uri, hoverPos);
1322 return std::nullopt;
1325 void lsp::MLIRServer::findDocumentSymbols(
1326 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1327 auto fileIt = impl->files.find(uri.file());
1328 if (fileIt != impl->files.end())
1329 fileIt->second->findDocumentSymbols(symbols);
1332 lsp::CompletionList
1333 lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
1334 const Position &completePos) {
1335 auto fileIt = impl->files.find(uri.file());
1336 if (fileIt != impl->files.end())
1337 return fileIt->second->getCodeCompletion(uri, completePos);
1338 return CompletionList();
1341 void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1342 const CodeActionContext &context,
1343 std::vector<CodeAction> &actions) {
1344 auto fileIt = impl->files.find(uri.file());
1345 if (fileIt != impl->files.end())
1346 fileIt->second->getCodeActions(uri, pos, context, actions);
1349 llvm::Expected<lsp::MLIRConvertBytecodeResult>
1350 lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
1351 MLIRContext tempContext(impl->registry);
1352 tempContext.allowUnregisteredDialects();
1354 // Collect any errors during parsing.
1355 std::string errorMsg;
1356 ScopedDiagnosticHandler diagHandler(
1357 &tempContext,
1358 [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1360 // Handling for external resources, which we want to propagate up to the user.
1361 FallbackAsmResourceMap fallbackResourceMap;
1363 // Setup the parser config.
1364 ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1365 &fallbackResourceMap);
1367 // Try to parse the given source file.
1368 Block parsedBlock;
1369 if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1370 return llvm::make_error<lsp::LSPError>(
1371 "failed to parse bytecode source file: " + errorMsg,
1372 lsp::ErrorCode::RequestFailed);
1375 // TODO: We currently expect a single top-level operation, but this could
1376 // conceptually be relaxed.
1377 if (!llvm::hasSingleElement(parsedBlock)) {
1378 return llvm::make_error<lsp::LSPError>(
1379 "expected bytecode to contain a single top-level operation",
1380 lsp::ErrorCode::RequestFailed);
1383 // Print the module to a buffer.
1384 lsp::MLIRConvertBytecodeResult result;
1386 // Extract the top-level op so that aliases get printed.
1387 // FIXME: We should be able to enable aliases without having to do this!
1388 OwningOpRef<Operation *> topOp = &parsedBlock.front();
1389 topOp->remove();
1391 AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1392 /*locationMap=*/nullptr, &fallbackResourceMap);
1394 llvm::raw_string_ostream os(result.output);
1395 topOp->print(os, state);
1397 return std::move(result);
1400 llvm::Expected<lsp::MLIRConvertBytecodeResult>
1401 lsp::MLIRServer::convertToBytecode(const URIForFile &uri) {
1402 auto fileIt = impl->files.find(uri.file());
1403 if (fileIt == impl->files.end()) {
1404 return llvm::make_error<lsp::LSPError>(
1405 "language server does not contain an entry for this source file",
1406 lsp::ErrorCode::RequestFailed);
1408 return fileIt->second->convertToBytecode();