1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "MLIRServer.h"
11 #include "mlir/AsmParser/AsmParser.h"
12 #include "mlir/AsmParser/AsmParserState.h"
13 #include "mlir/AsmParser/CodeComplete.h"
14 #include "mlir/Bytecode/BytecodeWriter.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Interfaces/FunctionInterfaces.h"
17 #include "mlir/Parser/Parser.h"
18 #include "mlir/Support/ToolUtilities.h"
19 #include "mlir/Tools/lsp-server-support/Logging.h"
20 #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Base64.h"
23 #include "llvm/Support/SourceMgr.h"
28 /// Returns the range of a lexical token given an SMLoc corresponding to the
29 /// start of a token location. The range is computed heuristically, and
30 /// supports identifier-like tokens, strings, etc.
31 static SMRange
convertTokenLocToRange(SMLoc loc
) {
32 return lsp::convertTokenLocToRange(loc
, "$-.");
35 /// Returns a language server location from the given MLIR file location.
36 /// `uriScheme` is the scheme to use when building new uris.
37 static std::optional
<lsp::Location
> getLocationFromLoc(StringRef uriScheme
,
39 llvm::Expected
<lsp::URIForFile
> sourceURI
=
40 lsp::URIForFile::fromFile(loc
.getFilename(), uriScheme
);
42 lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
44 llvm::toString(sourceURI
.takeError()));
48 lsp::Position position
;
49 position
.line
= loc
.getLine() - 1;
50 position
.character
= loc
.getColumn() ? loc
.getColumn() - 1 : 0;
51 return lsp::Location
{*sourceURI
, lsp::Range(position
)};
54 /// Returns a language server location from the given MLIR location, or
55 /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56 /// when building new uris. `uri` is an optional additional filter that, when
57 /// present, is used to filter sub locations that do not share the same uri.
58 static std::optional
<lsp::Location
>
59 getLocationFromLoc(llvm::SourceMgr
&sourceMgr
, Location loc
,
60 StringRef uriScheme
, const lsp::URIForFile
*uri
= nullptr) {
61 std::optional
<lsp::Location
> location
;
62 loc
->walk([&](Location nestedLoc
) {
63 FileLineColLoc fileLoc
= dyn_cast
<FileLineColLoc
>(nestedLoc
);
65 return WalkResult::advance();
67 std::optional
<lsp::Location
> sourceLoc
=
68 getLocationFromLoc(uriScheme
, fileLoc
);
69 if (sourceLoc
&& (!uri
|| sourceLoc
->uri
== *uri
)) {
70 location
= *sourceLoc
;
71 SMLoc loc
= sourceMgr
.FindLocForLineAndColumn(
72 sourceMgr
.getMainFileID(), fileLoc
.getLine(), fileLoc
.getColumn());
74 // Use range of potential identifier starting at location, else length 1
76 location
->range
.end
.character
+= 1;
77 if (std::optional
<SMRange
> range
= convertTokenLocToRange(loc
)) {
78 auto lineCol
= sourceMgr
.getLineAndColumn(range
->End
);
79 location
->range
.end
.character
=
80 std::max(fileLoc
.getColumn() + 1, lineCol
.second
- 1);
82 return WalkResult::interrupt();
84 return WalkResult::advance();
89 /// Collect all of the locations from the given MLIR location that are not
90 /// contained within the given URI.
91 static void collectLocationsFromLoc(Location loc
,
92 std::vector
<lsp::Location
> &locations
,
93 const lsp::URIForFile
&uri
) {
94 SetVector
<Location
> visitedLocs
;
95 loc
->walk([&](Location nestedLoc
) {
96 FileLineColLoc fileLoc
= dyn_cast
<FileLineColLoc
>(nestedLoc
);
97 if (!fileLoc
|| !visitedLocs
.insert(nestedLoc
))
98 return WalkResult::advance();
100 std::optional
<lsp::Location
> sourceLoc
=
101 getLocationFromLoc(uri
.scheme(), fileLoc
);
102 if (sourceLoc
&& sourceLoc
->uri
!= uri
)
103 locations
.push_back(*sourceLoc
);
104 return WalkResult::advance();
108 /// Returns true if the given range contains the given source location. Note
109 /// that this has slightly different behavior than SMRange because it is
110 /// inclusive of the end location.
111 static bool contains(SMRange range
, SMLoc loc
) {
112 return range
.Start
.getPointer() <= loc
.getPointer() &&
113 loc
.getPointer() <= range
.End
.getPointer();
116 /// Returns true if the given location is contained by the definition or one of
117 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118 /// the range within `def` that the provided `loc` overlapped with.
119 static bool isDefOrUse(const AsmParserState::SMDefinition
&def
, SMLoc loc
,
120 SMRange
*overlappedRange
= nullptr) {
121 // Check the main definition.
122 if (contains(def
.loc
, loc
)) {
124 *overlappedRange
= def
.loc
;
129 const auto *useIt
= llvm::find_if(
130 def
.uses
, [&](const SMRange
&range
) { return contains(range
, loc
); });
131 if (useIt
!= def
.uses
.end()) {
133 *overlappedRange
= *useIt
;
139 /// Given a location pointing to a result, return the result number it refers
140 /// to or std::nullopt if it refers to all of the results.
141 static std::optional
<unsigned> getResultNumberFromLoc(SMLoc loc
) {
142 // Skip all of the identifier characters.
143 auto isIdentifierChar
= [](char c
) {
144 return isalnum(c
) || c
== '%' || c
== '$' || c
== '.' || c
== '_' ||
147 const char *curPtr
= loc
.getPointer();
148 while (isIdentifierChar(*curPtr
))
151 // Check to see if this location indexes into the result group, via `#`. If it
152 // doesn't, we can't extract a sub result number.
156 // Compute the sub result number from the remaining portion of the string.
157 const char *numberStart
= ++curPtr
;
158 while (llvm::isDigit(*curPtr
))
160 StringRef
numberStr(numberStart
, curPtr
- numberStart
);
161 unsigned resultNumber
= 0;
162 return numberStr
.consumeInteger(10, resultNumber
) ? std::optional
<unsigned>()
166 /// Given a source location range, return the text covered by the given range.
167 /// If the range is invalid, returns std::nullopt.
168 static std::optional
<StringRef
> getTextFromRange(SMRange range
) {
169 if (!range
.isValid())
171 const char *startPtr
= range
.Start
.getPointer();
172 return StringRef(startPtr
, range
.End
.getPointer() - startPtr
);
175 /// Given a block, return its position in its parent region.
176 static unsigned getBlockNumber(Block
*block
) {
177 return std::distance(block
->getParent()->begin(), block
->getIterator());
180 /// Given a block and source location, print the source name of the block to the
181 /// given output stream.
182 static void printDefBlockName(raw_ostream
&os
, Block
*block
, SMRange loc
= {}) {
183 // Try to extract a name from the source location.
184 std::optional
<StringRef
> text
= getTextFromRange(loc
);
185 if (text
&& text
->starts_with("^")) {
190 // Otherwise, we don't have a name so print the block number.
191 os
<< "<Block #" << getBlockNumber(block
) << ">";
193 static void printDefBlockName(raw_ostream
&os
,
194 const AsmParserState::BlockDefinition
&def
) {
195 printDefBlockName(os
, def
.block
, def
.definition
.loc
);
198 /// Convert the given MLIR diagnostic to the LSP form.
199 static lsp::Diagnostic
getLspDiagnoticFromDiag(llvm::SourceMgr
&sourceMgr
,
201 const lsp::URIForFile
&uri
) {
202 lsp::Diagnostic lspDiag
;
203 lspDiag
.source
= "mlir";
205 // Note: Right now all of the diagnostics are treated as parser issues, but
206 // some are parser and some are verifier.
207 lspDiag
.category
= "Parse Error";
209 // Try to grab a file location for this diagnostic.
210 // TODO: For simplicity, we just grab the first one. It may be likely that we
211 // will need a more interesting heuristic here.'
212 StringRef uriScheme
= uri
.scheme();
213 std::optional
<lsp::Location
> lspLocation
=
214 getLocationFromLoc(sourceMgr
, diag
.getLocation(), uriScheme
, &uri
);
216 lspDiag
.range
= lspLocation
->range
;
218 // Convert the severity for the diagnostic.
219 switch (diag
.getSeverity()) {
220 case DiagnosticSeverity::Note
:
221 llvm_unreachable("expected notes to be handled separately");
222 case DiagnosticSeverity::Warning
:
223 lspDiag
.severity
= lsp::DiagnosticSeverity::Warning
;
225 case DiagnosticSeverity::Error
:
226 lspDiag
.severity
= lsp::DiagnosticSeverity::Error
;
228 case DiagnosticSeverity::Remark
:
229 lspDiag
.severity
= lsp::DiagnosticSeverity::Information
;
232 lspDiag
.message
= diag
.str();
234 // Attach any notes to the main diagnostic as related information.
235 std::vector
<lsp::DiagnosticRelatedInformation
> relatedDiags
;
236 for (Diagnostic
¬e
: diag
.getNotes()) {
237 lsp::Location noteLoc
;
238 if (std::optional
<lsp::Location
> loc
=
239 getLocationFromLoc(sourceMgr
, note
.getLocation(), uriScheme
))
243 relatedDiags
.emplace_back(noteLoc
, note
.str());
245 if (!relatedDiags
.empty())
246 lspDiag
.relatedInformation
= std::move(relatedDiags
);
251 //===----------------------------------------------------------------------===//
253 //===----------------------------------------------------------------------===//
256 /// This class represents all of the information pertaining to a specific MLIR
258 struct MLIRDocument
{
259 MLIRDocument(MLIRContext
&context
, const lsp::URIForFile
&uri
,
260 StringRef contents
, std::vector
<lsp::Diagnostic
> &diagnostics
);
261 MLIRDocument(const MLIRDocument
&) = delete;
262 MLIRDocument
&operator=(const MLIRDocument
&) = delete;
264 //===--------------------------------------------------------------------===//
265 // Definitions and References
266 //===--------------------------------------------------------------------===//
268 void getLocationsOf(const lsp::URIForFile
&uri
, const lsp::Position
&defPos
,
269 std::vector
<lsp::Location
> &locations
);
270 void findReferencesOf(const lsp::URIForFile
&uri
, const lsp::Position
&pos
,
271 std::vector
<lsp::Location
> &references
);
273 //===--------------------------------------------------------------------===//
275 //===--------------------------------------------------------------------===//
277 std::optional
<lsp::Hover
> findHover(const lsp::URIForFile
&uri
,
278 const lsp::Position
&hoverPos
);
279 std::optional
<lsp::Hover
>
280 buildHoverForOperation(SMRange hoverRange
,
281 const AsmParserState::OperationDefinition
&op
);
282 lsp::Hover
buildHoverForOperationResult(SMRange hoverRange
, Operation
*op
,
283 unsigned resultStart
,
284 unsigned resultEnd
, SMLoc posLoc
);
285 lsp::Hover
buildHoverForBlock(SMRange hoverRange
,
286 const AsmParserState::BlockDefinition
&block
);
288 buildHoverForBlockArgument(SMRange hoverRange
, BlockArgument arg
,
289 const AsmParserState::BlockDefinition
&block
);
291 lsp::Hover
buildHoverForAttributeAlias(
292 SMRange hoverRange
, const AsmParserState::AttributeAliasDefinition
&attr
);
294 buildHoverForTypeAlias(SMRange hoverRange
,
295 const AsmParserState::TypeAliasDefinition
&type
);
297 //===--------------------------------------------------------------------===//
299 //===--------------------------------------------------------------------===//
301 void findDocumentSymbols(std::vector
<lsp::DocumentSymbol
> &symbols
);
302 void findDocumentSymbols(Operation
*op
,
303 std::vector
<lsp::DocumentSymbol
> &symbols
);
305 //===--------------------------------------------------------------------===//
307 //===--------------------------------------------------------------------===//
309 lsp::CompletionList
getCodeCompletion(const lsp::URIForFile
&uri
,
310 const lsp::Position
&completePos
,
311 const DialectRegistry
®istry
);
313 //===--------------------------------------------------------------------===//
315 //===--------------------------------------------------------------------===//
317 void getCodeActionForDiagnostic(const lsp::URIForFile
&uri
,
318 lsp::Position
&pos
, StringRef severity
,
320 std::vector
<lsp::TextEdit
> &edits
);
322 //===--------------------------------------------------------------------===//
324 //===--------------------------------------------------------------------===//
326 llvm::Expected
<lsp::MLIRConvertBytecodeResult
> convertToBytecode();
328 //===--------------------------------------------------------------------===//
330 //===--------------------------------------------------------------------===//
332 /// The high level parser state used to find definitions and references within
334 AsmParserState asmState
;
336 /// The container for the IR parsed from the input file.
339 /// A collection of external resources, which we want to propagate up to the
341 FallbackAsmResourceMap fallbackResourceMap
;
343 /// The source manager containing the contents of the input file.
344 llvm::SourceMgr sourceMgr
;
348 MLIRDocument::MLIRDocument(MLIRContext
&context
, const lsp::URIForFile
&uri
,
350 std::vector
<lsp::Diagnostic
> &diagnostics
) {
351 ScopedDiagnosticHandler
handler(&context
, [&](Diagnostic
&diag
) {
352 diagnostics
.push_back(getLspDiagnoticFromDiag(sourceMgr
, diag
, uri
));
355 // Try to parsed the given IR string.
356 auto memBuffer
= llvm::MemoryBuffer::getMemBufferCopy(contents
, uri
.file());
358 lsp::Logger::error("Failed to create memory buffer for file", uri
.file());
362 ParserConfig
config(&context
, /*verifyAfterParse=*/true,
363 &fallbackResourceMap
);
364 sourceMgr
.AddNewSourceBuffer(std::move(memBuffer
), SMLoc());
365 if (failed(parseAsmSourceFile(sourceMgr
, &parsedIR
, config
, &asmState
))) {
366 // If parsing failed, clear out any of the current state.
368 asmState
= AsmParserState();
369 fallbackResourceMap
= FallbackAsmResourceMap();
374 //===----------------------------------------------------------------------===//
375 // MLIRDocument: Definitions and References
376 //===----------------------------------------------------------------------===//
378 void MLIRDocument::getLocationsOf(const lsp::URIForFile
&uri
,
379 const lsp::Position
&defPos
,
380 std::vector
<lsp::Location
> &locations
) {
381 SMLoc posLoc
= defPos
.getAsSMLoc(sourceMgr
);
383 // Functor used to check if an SM definition contains the position.
384 auto containsPosition
= [&](const AsmParserState::SMDefinition
&def
) {
385 if (!isDefOrUse(def
, posLoc
))
387 locations
.emplace_back(uri
, sourceMgr
, def
.loc
);
391 // Check all definitions related to operations.
392 for (const AsmParserState::OperationDefinition
&op
: asmState
.getOpDefs()) {
393 if (contains(op
.loc
, posLoc
))
394 return collectLocationsFromLoc(op
.op
->getLoc(), locations
, uri
);
395 for (const auto &result
: op
.resultGroups
)
396 if (containsPosition(result
.definition
))
397 return collectLocationsFromLoc(op
.op
->getLoc(), locations
, uri
);
398 for (const auto &symUse
: op
.symbolUses
) {
399 if (contains(symUse
, posLoc
)) {
400 locations
.emplace_back(uri
, sourceMgr
, op
.loc
);
401 return collectLocationsFromLoc(op
.op
->getLoc(), locations
, uri
);
406 // Check all definitions related to blocks.
407 for (const AsmParserState::BlockDefinition
&block
: asmState
.getBlockDefs()) {
408 if (containsPosition(block
.definition
))
410 for (const AsmParserState::SMDefinition
&arg
: block
.arguments
)
411 if (containsPosition(arg
))
415 // Check all alias definitions.
416 for (const AsmParserState::AttributeAliasDefinition
&attr
:
417 asmState
.getAttributeAliasDefs()) {
418 if (containsPosition(attr
.definition
))
421 for (const AsmParserState::TypeAliasDefinition
&type
:
422 asmState
.getTypeAliasDefs()) {
423 if (containsPosition(type
.definition
))
428 void MLIRDocument::findReferencesOf(const lsp::URIForFile
&uri
,
429 const lsp::Position
&pos
,
430 std::vector
<lsp::Location
> &references
) {
431 // Functor used to append all of the definitions/uses of the given SM
432 // definition to the reference list.
433 auto appendSMDef
= [&](const AsmParserState::SMDefinition
&def
) {
434 references
.emplace_back(uri
, sourceMgr
, def
.loc
);
435 for (const SMRange
&use
: def
.uses
)
436 references
.emplace_back(uri
, sourceMgr
, use
);
439 SMLoc posLoc
= pos
.getAsSMLoc(sourceMgr
);
441 // Check all definitions related to operations.
442 for (const AsmParserState::OperationDefinition
&op
: asmState
.getOpDefs()) {
443 if (contains(op
.loc
, posLoc
)) {
444 for (const auto &result
: op
.resultGroups
)
445 appendSMDef(result
.definition
);
446 for (const auto &symUse
: op
.symbolUses
)
447 if (contains(symUse
, posLoc
))
448 references
.emplace_back(uri
, sourceMgr
, symUse
);
451 for (const auto &result
: op
.resultGroups
)
452 if (isDefOrUse(result
.definition
, posLoc
))
453 return appendSMDef(result
.definition
);
454 for (const auto &symUse
: op
.symbolUses
) {
455 if (!contains(symUse
, posLoc
))
457 for (const auto &symUse
: op
.symbolUses
)
458 references
.emplace_back(uri
, sourceMgr
, symUse
);
463 // Check all definitions related to blocks.
464 for (const AsmParserState::BlockDefinition
&block
: asmState
.getBlockDefs()) {
465 if (isDefOrUse(block
.definition
, posLoc
))
466 return appendSMDef(block
.definition
);
468 for (const AsmParserState::SMDefinition
&arg
: block
.arguments
)
469 if (isDefOrUse(arg
, posLoc
))
470 return appendSMDef(arg
);
473 // Check all alias definitions.
474 for (const AsmParserState::AttributeAliasDefinition
&attr
:
475 asmState
.getAttributeAliasDefs()) {
476 if (isDefOrUse(attr
.definition
, posLoc
))
477 return appendSMDef(attr
.definition
);
479 for (const AsmParserState::TypeAliasDefinition
&type
:
480 asmState
.getTypeAliasDefs()) {
481 if (isDefOrUse(type
.definition
, posLoc
))
482 return appendSMDef(type
.definition
);
486 //===----------------------------------------------------------------------===//
487 // MLIRDocument: Hover
488 //===----------------------------------------------------------------------===//
490 std::optional
<lsp::Hover
>
491 MLIRDocument::findHover(const lsp::URIForFile
&uri
,
492 const lsp::Position
&hoverPos
) {
493 SMLoc posLoc
= hoverPos
.getAsSMLoc(sourceMgr
);
496 // Check for Hovers on operations and results.
497 for (const AsmParserState::OperationDefinition
&op
: asmState
.getOpDefs()) {
498 // Check if the position points at this operation.
499 if (contains(op
.loc
, posLoc
))
500 return buildHoverForOperation(op
.loc
, op
);
502 // Check if the position points at the symbol name.
503 for (auto &use
: op
.symbolUses
)
504 if (contains(use
, posLoc
))
505 return buildHoverForOperation(use
, op
);
507 // Check if the position points at a result group.
508 for (unsigned i
= 0, e
= op
.resultGroups
.size(); i
< e
; ++i
) {
509 const auto &result
= op
.resultGroups
[i
];
510 if (!isDefOrUse(result
.definition
, posLoc
, &hoverRange
))
513 // Get the range of results covered by the over position.
514 unsigned resultStart
= result
.startIndex
;
515 unsigned resultEnd
= (i
== e
- 1) ? op
.op
->getNumResults()
516 : op
.resultGroups
[i
+ 1].startIndex
;
517 return buildHoverForOperationResult(hoverRange
, op
.op
, resultStart
,
522 // Check to see if the hover is over a block argument.
523 for (const AsmParserState::BlockDefinition
&block
: asmState
.getBlockDefs()) {
524 if (isDefOrUse(block
.definition
, posLoc
, &hoverRange
))
525 return buildHoverForBlock(hoverRange
, block
);
527 for (const auto &arg
: llvm::enumerate(block
.arguments
)) {
528 if (!isDefOrUse(arg
.value(), posLoc
, &hoverRange
))
531 return buildHoverForBlockArgument(
532 hoverRange
, block
.block
->getArgument(arg
.index()), block
);
536 // Check to see if the hover is over an alias.
537 for (const AsmParserState::AttributeAliasDefinition
&attr
:
538 asmState
.getAttributeAliasDefs()) {
539 if (isDefOrUse(attr
.definition
, posLoc
, &hoverRange
))
540 return buildHoverForAttributeAlias(hoverRange
, attr
);
542 for (const AsmParserState::TypeAliasDefinition
&type
:
543 asmState
.getTypeAliasDefs()) {
544 if (isDefOrUse(type
.definition
, posLoc
, &hoverRange
))
545 return buildHoverForTypeAlias(hoverRange
, type
);
551 std::optional
<lsp::Hover
> MLIRDocument::buildHoverForOperation(
552 SMRange hoverRange
, const AsmParserState::OperationDefinition
&op
) {
553 lsp::Hover
hover(lsp::Range(sourceMgr
, hoverRange
));
554 llvm::raw_string_ostream
os(hover
.contents
.value
);
556 // Add the operation name to the hover.
557 os
<< "\"" << op
.op
->getName() << "\"";
558 if (SymbolOpInterface symbol
= dyn_cast
<SymbolOpInterface
>(op
.op
))
559 os
<< " : " << symbol
.getVisibility() << " @" << symbol
.getName() << "";
562 os
<< "Generic Form:\n\n```mlir\n";
564 op
.op
->print(os
, OpPrintingFlags()
565 .printGenericOpForm()
566 .elideLargeElementsAttrs()
573 lsp::Hover
MLIRDocument::buildHoverForOperationResult(SMRange hoverRange
,
575 unsigned resultStart
,
578 lsp::Hover
hover(lsp::Range(sourceMgr
, hoverRange
));
579 llvm::raw_string_ostream
os(hover
.contents
.value
);
581 // Add the parent operation name to the hover.
582 os
<< "Operation: \"" << op
->getName() << "\"\n\n";
584 // Check to see if the location points to a specific result within the
586 if (std::optional
<unsigned> resultNumber
= getResultNumberFromLoc(posLoc
)) {
587 if ((resultStart
+ *resultNumber
) < resultEnd
) {
588 resultStart
+= *resultNumber
;
589 resultEnd
= resultStart
+ 1;
593 // Add the range of results and their types to the hover info.
594 if ((resultStart
+ 1) == resultEnd
) {
595 os
<< "Result #" << resultStart
<< "\n\n"
596 << "Type: `" << op
->getResult(resultStart
).getType() << "`\n\n";
598 os
<< "Result #[" << resultStart
<< ", " << (resultEnd
- 1) << "]\n\n"
600 llvm::interleaveComma(
601 op
->getResults().slice(resultStart
, resultEnd
), os
,
602 [&](Value result
) { os
<< "`" << result
.getType() << "`"; });
609 MLIRDocument::buildHoverForBlock(SMRange hoverRange
,
610 const AsmParserState::BlockDefinition
&block
) {
611 lsp::Hover
hover(lsp::Range(sourceMgr
, hoverRange
));
612 llvm::raw_string_ostream
os(hover
.contents
.value
);
614 // Print the given block to the hover output stream.
615 auto printBlockToHover
= [&](Block
*newBlock
) {
616 if (const auto *def
= asmState
.getBlockDef(newBlock
))
617 printDefBlockName(os
, *def
);
619 printDefBlockName(os
, newBlock
);
622 // Display the parent operation, block number, predecessors, and successors.
623 os
<< "Operation: \"" << block
.block
->getParentOp()->getName() << "\"\n\n"
624 << "Block #" << getBlockNumber(block
.block
) << "\n\n";
625 if (!block
.block
->hasNoPredecessors()) {
626 os
<< "Predecessors: ";
627 llvm::interleaveComma(block
.block
->getPredecessors(), os
,
631 if (!block
.block
->hasNoSuccessors()) {
632 os
<< "Successors: ";
633 llvm::interleaveComma(block
.block
->getSuccessors(), os
, printBlockToHover
);
640 lsp::Hover
MLIRDocument::buildHoverForBlockArgument(
641 SMRange hoverRange
, BlockArgument arg
,
642 const AsmParserState::BlockDefinition
&block
) {
643 lsp::Hover
hover(lsp::Range(sourceMgr
, hoverRange
));
644 llvm::raw_string_ostream
os(hover
.contents
.value
);
646 // Display the parent operation, block, the argument number, and the type.
647 os
<< "Operation: \"" << block
.block
->getParentOp()->getName() << "\"\n\n"
649 printDefBlockName(os
, block
);
650 os
<< "\n\nArgument #" << arg
.getArgNumber() << "\n\n"
651 << "Type: `" << arg
.getType() << "`\n\n";
656 lsp::Hover
MLIRDocument::buildHoverForAttributeAlias(
657 SMRange hoverRange
, const AsmParserState::AttributeAliasDefinition
&attr
) {
658 lsp::Hover
hover(lsp::Range(sourceMgr
, hoverRange
));
659 llvm::raw_string_ostream
os(hover
.contents
.value
);
661 os
<< "Attribute Alias: \"" << attr
.name
<< "\n\n";
662 os
<< "Value: ```mlir\n" << attr
.value
<< "\n```\n\n";
667 lsp::Hover
MLIRDocument::buildHoverForTypeAlias(
668 SMRange hoverRange
, const AsmParserState::TypeAliasDefinition
&type
) {
669 lsp::Hover
hover(lsp::Range(sourceMgr
, hoverRange
));
670 llvm::raw_string_ostream
os(hover
.contents
.value
);
672 os
<< "Type Alias: \"" << type
.name
<< "\n\n";
673 os
<< "Value: ```mlir\n" << type
.value
<< "\n```\n\n";
678 //===----------------------------------------------------------------------===//
679 // MLIRDocument: Document Symbols
680 //===----------------------------------------------------------------------===//
682 void MLIRDocument::findDocumentSymbols(
683 std::vector
<lsp::DocumentSymbol
> &symbols
) {
684 for (Operation
&op
: parsedIR
)
685 findDocumentSymbols(&op
, symbols
);
688 void MLIRDocument::findDocumentSymbols(
689 Operation
*op
, std::vector
<lsp::DocumentSymbol
> &symbols
) {
690 std::vector
<lsp::DocumentSymbol
> *childSymbols
= &symbols
;
692 // Check for the source information of this operation.
693 if (const AsmParserState::OperationDefinition
*def
= asmState
.getOpDef(op
)) {
694 // If this operation defines a symbol, record it.
695 if (SymbolOpInterface symbol
= dyn_cast
<SymbolOpInterface
>(op
)) {
696 symbols
.emplace_back(symbol
.getName(),
697 isa
<FunctionOpInterface
>(op
)
698 ? lsp::SymbolKind::Function
699 : lsp::SymbolKind::Class
,
700 lsp::Range(sourceMgr
, def
->scopeLoc
),
701 lsp::Range(sourceMgr
, def
->loc
));
702 childSymbols
= &symbols
.back().children
;
704 } else if (op
->hasTrait
<OpTrait::SymbolTable
>()) {
705 // Otherwise, if this is a symbol table push an anonymous document symbol.
706 symbols
.emplace_back("<" + op
->getName().getStringRef() + ">",
707 lsp::SymbolKind::Namespace
,
708 lsp::Range(sourceMgr
, def
->scopeLoc
),
709 lsp::Range(sourceMgr
, def
->loc
));
710 childSymbols
= &symbols
.back().children
;
714 // Recurse into the regions of this operation.
715 if (!op
->getNumRegions())
717 for (Region
®ion
: op
->getRegions())
718 for (Operation
&childOp
: region
.getOps())
719 findDocumentSymbols(&childOp
, *childSymbols
);
722 //===----------------------------------------------------------------------===//
723 // MLIRDocument: Code Completion
724 //===----------------------------------------------------------------------===//
727 class LSPCodeCompleteContext
: public AsmParserCodeCompleteContext
{
729 LSPCodeCompleteContext(SMLoc completeLoc
, lsp::CompletionList
&completionList
,
731 : AsmParserCodeCompleteContext(completeLoc
),
732 completionList(completionList
), ctx(ctx
) {}
734 /// Signal code completion for a dialect name, with an optional prefix.
735 void completeDialectName(StringRef prefix
) final
{
736 for (StringRef dialect
: ctx
->getAvailableDialects()) {
737 lsp::CompletionItem
item(prefix
+ dialect
,
738 lsp::CompletionItemKind::Module
,
740 item
.detail
= "dialect";
741 completionList
.items
.emplace_back(item
);
744 using AsmParserCodeCompleteContext::completeDialectName
;
746 /// Signal code completion for an operation name within the given dialect.
747 void completeOperationName(StringRef dialectName
) final
{
748 Dialect
*dialect
= ctx
->getOrLoadDialect(dialectName
);
752 for (const auto &op
: ctx
->getRegisteredOperations()) {
753 if (&op
.getDialect() != dialect
)
756 lsp::CompletionItem
item(
757 op
.getStringRef().drop_front(dialectName
.size() + 1),
758 lsp::CompletionItemKind::Field
,
760 item
.detail
= "operation";
761 completionList
.items
.emplace_back(item
);
765 /// Append the given SSA value as a code completion result for SSA value
767 void appendSSAValueCompletion(StringRef name
, std::string typeData
) final
{
768 // Check if we need to insert the `%` or not.
769 bool stripPrefix
= getCodeCompleteLoc().getPointer()[-1] == '%';
771 lsp::CompletionItem
item(name
, lsp::CompletionItemKind::Variable
);
773 item
.insertText
= name
.drop_front(1).str();
774 item
.detail
= std::move(typeData
);
775 completionList
.items
.emplace_back(item
);
778 /// Append the given block as a code completion result for block name
780 void appendBlockCompletion(StringRef name
) final
{
781 // Check if we need to insert the `^` or not.
782 bool stripPrefix
= getCodeCompleteLoc().getPointer()[-1] == '^';
784 lsp::CompletionItem
item(name
, lsp::CompletionItemKind::Field
);
786 item
.insertText
= name
.drop_front(1).str();
787 completionList
.items
.emplace_back(item
);
790 /// Signal a completion for the given expected token.
791 void completeExpectedTokens(ArrayRef
<StringRef
> tokens
, bool optional
) final
{
792 for (StringRef token
: tokens
) {
793 lsp::CompletionItem
item(token
, lsp::CompletionItemKind::Keyword
,
795 item
.detail
= optional
? "optional" : "";
796 completionList
.items
.emplace_back(item
);
800 /// Signal a completion for an attribute.
801 void completeAttribute(const llvm::StringMap
<Attribute
> &aliases
) override
{
802 appendSimpleCompletions({"affine_set", "affine_map", "dense",
803 "dense_resource", "false", "loc", "sparse", "true",
805 lsp::CompletionItemKind::Field
,
808 completeDialectName("#");
809 completeAliases(aliases
, "#");
811 void completeDialectAttributeOrAlias(
812 const llvm::StringMap
<Attribute
> &aliases
) override
{
813 completeDialectName();
814 completeAliases(aliases
);
817 /// Signal a completion for a type.
818 void completeType(const llvm::StringMap
<Type
> &aliases
) override
{
819 // Handle the various builtin types.
820 appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
821 "bf16", "f16", "f32", "f64", "f80", "f128",
823 lsp::CompletionItemKind::Field
,
826 // Handle the builtin integer types.
827 for (StringRef type
: {"i", "si", "ui"}) {
828 lsp::CompletionItem
item(type
+ "<N>", lsp::CompletionItemKind::Field
,
830 item
.insertText
= type
.str();
831 completionList
.items
.emplace_back(item
);
834 // Insert completions for dialect types and aliases.
835 completeDialectName("!");
836 completeAliases(aliases
, "!");
839 completeDialectTypeOrAlias(const llvm::StringMap
<Type
> &aliases
) override
{
840 completeDialectName();
841 completeAliases(aliases
);
844 /// Add completion results for the given set of aliases.
845 template <typename T
>
846 void completeAliases(const llvm::StringMap
<T
> &aliases
,
847 StringRef prefix
= "") {
848 for (const auto &alias
: aliases
) {
849 lsp::CompletionItem
item(prefix
+ alias
.getKey(),
850 lsp::CompletionItemKind::Field
,
852 llvm::raw_string_ostream(item
.detail
) << "alias: " << alias
.getValue();
853 completionList
.items
.emplace_back(item
);
857 /// Add a set of simple completions that all have the same kind.
858 void appendSimpleCompletions(ArrayRef
<StringRef
> completions
,
859 lsp::CompletionItemKind kind
,
860 StringRef sortText
= "") {
861 for (StringRef completion
: completions
)
862 completionList
.items
.emplace_back(completion
, kind
, sortText
);
866 lsp::CompletionList
&completionList
;
872 MLIRDocument::getCodeCompletion(const lsp::URIForFile
&uri
,
873 const lsp::Position
&completePos
,
874 const DialectRegistry
®istry
) {
875 SMLoc posLoc
= completePos
.getAsSMLoc(sourceMgr
);
876 if (!posLoc
.isValid())
877 return lsp::CompletionList();
879 // To perform code completion, we run another parse of the module with the
880 // code completion context provided.
881 MLIRContext
tmpContext(registry
, MLIRContext::Threading::DISABLED
);
882 tmpContext
.allowUnregisteredDialects();
883 lsp::CompletionList completionList
;
884 LSPCodeCompleteContext
lspCompleteContext(posLoc
, completionList
,
888 AsmParserState tmpState
;
889 (void)parseAsmSourceFile(sourceMgr
, &tmpIR
, &tmpContext
, &tmpState
,
890 &lspCompleteContext
);
891 return completionList
;
894 //===----------------------------------------------------------------------===//
895 // MLIRDocument: Code Action
896 //===----------------------------------------------------------------------===//
898 void MLIRDocument::getCodeActionForDiagnostic(
899 const lsp::URIForFile
&uri
, lsp::Position
&pos
, StringRef severity
,
900 StringRef message
, std::vector
<lsp::TextEdit
> &edits
) {
901 // Ignore diagnostics that print the current operation. These are always
902 // enabled for the language server, but not generally during normal
903 // parsing/verification.
904 if (message
.starts_with("see current operation: "))
907 // Get the start of the line containing the diagnostic.
908 const auto &buffer
= sourceMgr
.getBufferInfo(sourceMgr
.getMainFileID());
909 const char *lineStart
= buffer
.getPointerForLineNumber(pos
.line
+ 1);
912 StringRef
line(lineStart
, pos
.character
);
914 // Add a text edit for adding an expected-* diagnostic check for this
917 edit
.range
= lsp::Range(lsp::Position(pos
.line
, 0));
919 // Use the indent of the current line for the expected-* diagnostic.
920 size_t indent
= line
.find_first_not_of(' ');
921 if (indent
== StringRef::npos
)
922 indent
= line
.size();
924 edit
.newText
.append(indent
, ' ');
925 llvm::raw_string_ostream(edit
.newText
)
926 << "// expected-" << severity
<< " @below {{" << message
<< "}}\n";
927 edits
.emplace_back(std::move(edit
));
//===----------------------------------------------------------------------===//
// MLIRDocument: Bytecode
//===----------------------------------------------------------------------===//
934 llvm::Expected
<lsp::MLIRConvertBytecodeResult
>
935 MLIRDocument::convertToBytecode() {
936 // TODO: We currently require a single top-level operation, but this could
937 // conceptually be relaxed.
938 if (!llvm::hasSingleElement(parsedIR
)) {
939 if (parsedIR
.empty()) {
940 return llvm::make_error
<lsp::LSPError
>(
941 "expected a single and valid top-level operation, please ensure "
942 "there are no errors",
943 lsp::ErrorCode::RequestFailed
);
945 return llvm::make_error
<lsp::LSPError
>(
946 "expected a single top-level operation", lsp::ErrorCode::RequestFailed
);
949 lsp::MLIRConvertBytecodeResult result
;
951 BytecodeWriterConfig
writerConfig(fallbackResourceMap
);
953 std::string rawBytecodeBuffer
;
954 llvm::raw_string_ostream
os(rawBytecodeBuffer
);
955 // No desired bytecode version set, so no need to check for error.
956 (void)writeBytecodeToFile(&parsedIR
.front(), os
, writerConfig
);
957 result
.output
= llvm::encodeBase64(rawBytecodeBuffer
);
//===----------------------------------------------------------------------===//
// MLIRTextFileChunk
//===----------------------------------------------------------------------===//
967 /// This class represents a single chunk of an MLIR text file.
968 struct MLIRTextFileChunk
{
969 MLIRTextFileChunk(MLIRContext
&context
, uint64_t lineOffset
,
970 const lsp::URIForFile
&uri
, StringRef contents
,
971 std::vector
<lsp::Diagnostic
> &diagnostics
)
972 : lineOffset(lineOffset
), document(context
, uri
, contents
, diagnostics
) {}
974 /// Adjust the line number of the given range to anchor at the beginning of
975 /// the file, instead of the beginning of this chunk.
976 void adjustLocForChunkOffset(lsp::Range
&range
) {
977 adjustLocForChunkOffset(range
.start
);
978 adjustLocForChunkOffset(range
.end
);
980 /// Adjust the line number of the given position to anchor at the beginning of
981 /// the file, instead of the beginning of this chunk.
982 void adjustLocForChunkOffset(lsp::Position
&pos
) { pos
.line
+= lineOffset
; }
984 /// The line offset of this chunk from the beginning of the file.
986 /// The document referred to by this chunk.
987 MLIRDocument document
;
//===----------------------------------------------------------------------===//
// MLIRTextFile
//===----------------------------------------------------------------------===//
996 /// This class represents a text file containing one or more MLIR documents.
999 MLIRTextFile(const lsp::URIForFile
&uri
, StringRef fileContents
,
1000 int64_t version
, DialectRegistry
®istry
,
1001 std::vector
<lsp::Diagnostic
> &diagnostics
);
1003 /// Return the current version of this text file.
1004 int64_t getVersion() const { return version
; }
1006 //===--------------------------------------------------------------------===//
1008 //===--------------------------------------------------------------------===//
1010 void getLocationsOf(const lsp::URIForFile
&uri
, lsp::Position defPos
,
1011 std::vector
<lsp::Location
> &locations
);
1012 void findReferencesOf(const lsp::URIForFile
&uri
, lsp::Position pos
,
1013 std::vector
<lsp::Location
> &references
);
1014 std::optional
<lsp::Hover
> findHover(const lsp::URIForFile
&uri
,
1015 lsp::Position hoverPos
);
1016 void findDocumentSymbols(std::vector
<lsp::DocumentSymbol
> &symbols
);
1017 lsp::CompletionList
getCodeCompletion(const lsp::URIForFile
&uri
,
1018 lsp::Position completePos
);
1019 void getCodeActions(const lsp::URIForFile
&uri
, const lsp::Range
&pos
,
1020 const lsp::CodeActionContext
&context
,
1021 std::vector
<lsp::CodeAction
> &actions
);
1022 llvm::Expected
<lsp::MLIRConvertBytecodeResult
> convertToBytecode();
1025 /// Find the MLIR document that contains the given position, and update the
1026 /// position to be anchored at the start of the found chunk instead of the
1027 /// beginning of the file.
1028 MLIRTextFileChunk
&getChunkFor(lsp::Position
&pos
);
1030 /// The context used to hold the state contained by the parsed document.
1031 MLIRContext context
;
1033 /// The full string contents of the file.
1034 std::string contents
;
1036 /// The version of this file.
1039 /// The number of lines in the file.
1040 int64_t totalNumLines
= 0;
1042 /// The chunks of this file. The order of these chunks is the order in which
1043 /// they appear in the text file.
1044 std::vector
<std::unique_ptr
<MLIRTextFileChunk
>> chunks
;
1048 MLIRTextFile::MLIRTextFile(const lsp::URIForFile
&uri
, StringRef fileContents
,
1049 int64_t version
, DialectRegistry
®istry
,
1050 std::vector
<lsp::Diagnostic
> &diagnostics
)
1051 : context(registry
, MLIRContext::Threading::DISABLED
),
1052 contents(fileContents
.str()), version(version
) {
1053 context
.allowUnregisteredDialects();
1055 // Split the file into separate MLIR documents.
1056 SmallVector
<StringRef
, 8> subContents
;
1057 StringRef(contents
).split(subContents
, kDefaultSplitMarker
);
1058 chunks
.emplace_back(std::make_unique
<MLIRTextFileChunk
>(
1059 context
, /*lineOffset=*/0, uri
, subContents
.front(), diagnostics
));
1061 uint64_t lineOffset
= subContents
.front().count('\n');
1062 for (StringRef docContents
: llvm::drop_begin(subContents
)) {
1063 unsigned currentNumDiags
= diagnostics
.size();
1064 auto chunk
= std::make_unique
<MLIRTextFileChunk
>(context
, lineOffset
, uri
,
1065 docContents
, diagnostics
);
1066 lineOffset
+= docContents
.count('\n');
1068 // Adjust locations used in diagnostics to account for the offset from the
1069 // beginning of the file.
1070 for (lsp::Diagnostic
&diag
:
1071 llvm::drop_begin(diagnostics
, currentNumDiags
)) {
1072 chunk
->adjustLocForChunkOffset(diag
.range
);
1074 if (!diag
.relatedInformation
)
1076 for (auto &it
: *diag
.relatedInformation
)
1077 if (it
.location
.uri
== uri
)
1078 chunk
->adjustLocForChunkOffset(it
.location
.range
);
1080 chunks
.emplace_back(std::move(chunk
));
1082 totalNumLines
= lineOffset
;
1085 void MLIRTextFile::getLocationsOf(const lsp::URIForFile
&uri
,
1086 lsp::Position defPos
,
1087 std::vector
<lsp::Location
> &locations
) {
1088 MLIRTextFileChunk
&chunk
= getChunkFor(defPos
);
1089 chunk
.document
.getLocationsOf(uri
, defPos
, locations
);
1091 // Adjust any locations within this file for the offset of this chunk.
1092 if (chunk
.lineOffset
== 0)
1094 for (lsp::Location
&loc
: locations
)
1096 chunk
.adjustLocForChunkOffset(loc
.range
);
1099 void MLIRTextFile::findReferencesOf(const lsp::URIForFile
&uri
,
1101 std::vector
<lsp::Location
> &references
) {
1102 MLIRTextFileChunk
&chunk
= getChunkFor(pos
);
1103 chunk
.document
.findReferencesOf(uri
, pos
, references
);
1105 // Adjust any locations within this file for the offset of this chunk.
1106 if (chunk
.lineOffset
== 0)
1108 for (lsp::Location
&loc
: references
)
1110 chunk
.adjustLocForChunkOffset(loc
.range
);
1113 std::optional
<lsp::Hover
> MLIRTextFile::findHover(const lsp::URIForFile
&uri
,
1114 lsp::Position hoverPos
) {
1115 MLIRTextFileChunk
&chunk
= getChunkFor(hoverPos
);
1116 std::optional
<lsp::Hover
> hoverInfo
= chunk
.document
.findHover(uri
, hoverPos
);
1118 // Adjust any locations within this file for the offset of this chunk.
1119 if (chunk
.lineOffset
!= 0 && hoverInfo
&& hoverInfo
->range
)
1120 chunk
.adjustLocForChunkOffset(*hoverInfo
->range
);
1124 void MLIRTextFile::findDocumentSymbols(
1125 std::vector
<lsp::DocumentSymbol
> &symbols
) {
1126 if (chunks
.size() == 1)
1127 return chunks
.front()->document
.findDocumentSymbols(symbols
);
1129 // If there are multiple chunks in this file, we create top-level symbols for
1131 for (unsigned i
= 0, e
= chunks
.size(); i
< e
; ++i
) {
1132 MLIRTextFileChunk
&chunk
= *chunks
[i
];
1133 lsp::Position
startPos(chunk
.lineOffset
);
1134 lsp::Position
endPos((i
== e
- 1) ? totalNumLines
- 1
1135 : chunks
[i
+ 1]->lineOffset
);
1136 lsp::DocumentSymbol
symbol("<file-split-" + Twine(i
) + ">",
1137 lsp::SymbolKind::Namespace
,
1138 /*range=*/lsp::Range(startPos
, endPos
),
1139 /*selectionRange=*/lsp::Range(startPos
));
1140 chunk
.document
.findDocumentSymbols(symbol
.children
);
1142 // Fixup the locations of document symbols within this chunk.
1144 SmallVector
<lsp::DocumentSymbol
*> symbolsToFix
;
1145 for (lsp::DocumentSymbol
&childSymbol
: symbol
.children
)
1146 symbolsToFix
.push_back(&childSymbol
);
1148 while (!symbolsToFix
.empty()) {
1149 lsp::DocumentSymbol
*symbol
= symbolsToFix
.pop_back_val();
1150 chunk
.adjustLocForChunkOffset(symbol
->range
);
1151 chunk
.adjustLocForChunkOffset(symbol
->selectionRange
);
1153 for (lsp::DocumentSymbol
&childSymbol
: symbol
->children
)
1154 symbolsToFix
.push_back(&childSymbol
);
1158 // Push the symbol for this chunk.
1159 symbols
.emplace_back(std::move(symbol
));
1163 lsp::CompletionList
MLIRTextFile::getCodeCompletion(const lsp::URIForFile
&uri
,
1164 lsp::Position completePos
) {
1165 MLIRTextFileChunk
&chunk
= getChunkFor(completePos
);
1166 lsp::CompletionList completionList
= chunk
.document
.getCodeCompletion(
1167 uri
, completePos
, context
.getDialectRegistry());
1169 // Adjust any completion locations.
1170 for (lsp::CompletionItem
&item
: completionList
.items
) {
1172 chunk
.adjustLocForChunkOffset(item
.textEdit
->range
);
1173 for (lsp::TextEdit
&edit
: item
.additionalTextEdits
)
1174 chunk
.adjustLocForChunkOffset(edit
.range
);
1176 return completionList
;
1179 void MLIRTextFile::getCodeActions(const lsp::URIForFile
&uri
,
1180 const lsp::Range
&pos
,
1181 const lsp::CodeActionContext
&context
,
1182 std::vector
<lsp::CodeAction
> &actions
) {
1183 // Create actions for any diagnostics in this file.
1184 for (auto &diag
: context
.diagnostics
) {
1185 if (diag
.source
!= "mlir")
1187 lsp::Position diagPos
= diag
.range
.start
;
1188 MLIRTextFileChunk
&chunk
= getChunkFor(diagPos
);
1190 // Add a new code action that inserts a "expected" diagnostic check.
1191 lsp::CodeAction action
;
1192 action
.title
= "Add expected-* diagnostic checks";
1193 action
.kind
= lsp::CodeAction::kQuickFix
.str();
1196 switch (diag
.severity
) {
1197 case lsp::DiagnosticSeverity::Error
:
1200 case lsp::DiagnosticSeverity::Warning
:
1201 severity
= "warning";
1207 // Get edits for the diagnostic.
1208 std::vector
<lsp::TextEdit
> edits
;
1209 chunk
.document
.getCodeActionForDiagnostic(uri
, diagPos
, severity
,
1210 diag
.message
, edits
);
1212 // Walk the related diagnostics, this is how we encode notes.
1213 if (diag
.relatedInformation
) {
1214 for (auto ¬eDiag
: *diag
.relatedInformation
) {
1215 if (noteDiag
.location
.uri
!= uri
)
1217 diagPos
= noteDiag
.location
.range
.start
;
1218 diagPos
.line
-= chunk
.lineOffset
;
1219 chunk
.document
.getCodeActionForDiagnostic(uri
, diagPos
, "note",
1220 noteDiag
.message
, edits
);
1223 // Fixup the locations for any edits.
1224 for (lsp::TextEdit
&edit
: edits
)
1225 chunk
.adjustLocForChunkOffset(edit
.range
);
1227 action
.edit
.emplace();
1228 action
.edit
->changes
[uri
.uri().str()] = std::move(edits
);
1229 action
.diagnostics
= {diag
};
1231 actions
.emplace_back(std::move(action
));
1235 llvm::Expected
<lsp::MLIRConvertBytecodeResult
>
1236 MLIRTextFile::convertToBytecode() {
1237 // Bail out if there is more than one chunk, bytecode wants a single module.
1238 if (chunks
.size() != 1) {
1239 return llvm::make_error
<lsp::LSPError
>(
1240 "unexpected split file, please remove all `// -----`",
1241 lsp::ErrorCode::RequestFailed
);
1243 return chunks
.front()->document
.convertToBytecode();
1246 MLIRTextFileChunk
&MLIRTextFile::getChunkFor(lsp::Position
&pos
) {
1247 if (chunks
.size() == 1)
1248 return *chunks
.front();
1250 // Search for the first chunk with a greater line offset, the previous chunk
1251 // is the one that contains `pos`.
1252 auto it
= llvm::upper_bound(
1253 chunks
, pos
, [](const lsp::Position
&pos
, const auto &chunk
) {
1254 return static_cast<uint64_t>(pos
.line
) < chunk
->lineOffset
;
1256 MLIRTextFileChunk
&chunk
= it
== chunks
.end() ? *chunks
.back() : **(--it
);
1257 pos
.line
-= chunk
.lineOffset
;
//===----------------------------------------------------------------------===//
// MLIRServer::Impl
//===----------------------------------------------------------------------===//
1265 struct lsp::MLIRServer::Impl
{
1266 Impl(DialectRegistry
®istry
) : registry(registry
) {}
1268 /// The registry containing dialects that can be recognized in parsed .mlir
1270 DialectRegistry
®istry
;
1272 /// The files held by the server, mapped by their URI file name.
1273 llvm::StringMap
<std::unique_ptr
<MLIRTextFile
>> files
;
//===----------------------------------------------------------------------===//
// MLIRServer
//===----------------------------------------------------------------------===//
1280 lsp::MLIRServer::MLIRServer(DialectRegistry
®istry
)
1281 : impl(std::make_unique
<Impl
>(registry
)) {}
1282 lsp::MLIRServer::~MLIRServer() = default;
1284 void lsp::MLIRServer::addOrUpdateDocument(
1285 const URIForFile
&uri
, StringRef contents
, int64_t version
,
1286 std::vector
<Diagnostic
> &diagnostics
) {
1287 impl
->files
[uri
.file()] = std::make_unique
<MLIRTextFile
>(
1288 uri
, contents
, version
, impl
->registry
, diagnostics
);
1291 std::optional
<int64_t> lsp::MLIRServer::removeDocument(const URIForFile
&uri
) {
1292 auto it
= impl
->files
.find(uri
.file());
1293 if (it
== impl
->files
.end())
1294 return std::nullopt
;
1296 int64_t version
= it
->second
->getVersion();
1297 impl
->files
.erase(it
);
1301 void lsp::MLIRServer::getLocationsOf(const URIForFile
&uri
,
1302 const Position
&defPos
,
1303 std::vector
<Location
> &locations
) {
1304 auto fileIt
= impl
->files
.find(uri
.file());
1305 if (fileIt
!= impl
->files
.end())
1306 fileIt
->second
->getLocationsOf(uri
, defPos
, locations
);
1309 void lsp::MLIRServer::findReferencesOf(const URIForFile
&uri
,
1310 const Position
&pos
,
1311 std::vector
<Location
> &references
) {
1312 auto fileIt
= impl
->files
.find(uri
.file());
1313 if (fileIt
!= impl
->files
.end())
1314 fileIt
->second
->findReferencesOf(uri
, pos
, references
);
1317 std::optional
<lsp::Hover
> lsp::MLIRServer::findHover(const URIForFile
&uri
,
1318 const Position
&hoverPos
) {
1319 auto fileIt
= impl
->files
.find(uri
.file());
1320 if (fileIt
!= impl
->files
.end())
1321 return fileIt
->second
->findHover(uri
, hoverPos
);
1322 return std::nullopt
;
1325 void lsp::MLIRServer::findDocumentSymbols(
1326 const URIForFile
&uri
, std::vector
<DocumentSymbol
> &symbols
) {
1327 auto fileIt
= impl
->files
.find(uri
.file());
1328 if (fileIt
!= impl
->files
.end())
1329 fileIt
->second
->findDocumentSymbols(symbols
);
1333 lsp::MLIRServer::getCodeCompletion(const URIForFile
&uri
,
1334 const Position
&completePos
) {
1335 auto fileIt
= impl
->files
.find(uri
.file());
1336 if (fileIt
!= impl
->files
.end())
1337 return fileIt
->second
->getCodeCompletion(uri
, completePos
);
1338 return CompletionList();
1341 void lsp::MLIRServer::getCodeActions(const URIForFile
&uri
, const Range
&pos
,
1342 const CodeActionContext
&context
,
1343 std::vector
<CodeAction
> &actions
) {
1344 auto fileIt
= impl
->files
.find(uri
.file());
1345 if (fileIt
!= impl
->files
.end())
1346 fileIt
->second
->getCodeActions(uri
, pos
, context
, actions
);
1349 llvm::Expected
<lsp::MLIRConvertBytecodeResult
>
1350 lsp::MLIRServer::convertFromBytecode(const URIForFile
&uri
) {
1351 MLIRContext
tempContext(impl
->registry
);
1352 tempContext
.allowUnregisteredDialects();
1354 // Collect any errors during parsing.
1355 std::string errorMsg
;
1356 ScopedDiagnosticHandler
diagHandler(
1358 [&](mlir::Diagnostic
&diag
) { errorMsg
+= diag
.str() + "\n"; });
1360 // Handling for external resources, which we want to propagate up to the user.
1361 FallbackAsmResourceMap fallbackResourceMap
;
1363 // Setup the parser config.
1364 ParserConfig
parserConfig(&tempContext
, /*verifyAfterParse=*/true,
1365 &fallbackResourceMap
);
1367 // Try to parse the given source file.
1369 if (failed(parseSourceFile(uri
.file(), &parsedBlock
, parserConfig
))) {
1370 return llvm::make_error
<lsp::LSPError
>(
1371 "failed to parse bytecode source file: " + errorMsg
,
1372 lsp::ErrorCode::RequestFailed
);
1375 // TODO: We currently expect a single top-level operation, but this could
1376 // conceptually be relaxed.
1377 if (!llvm::hasSingleElement(parsedBlock
)) {
1378 return llvm::make_error
<lsp::LSPError
>(
1379 "expected bytecode to contain a single top-level operation",
1380 lsp::ErrorCode::RequestFailed
);
1383 // Print the module to a buffer.
1384 lsp::MLIRConvertBytecodeResult result
;
1386 // Extract the top-level op so that aliases get printed.
1387 // FIXME: We should be able to enable aliases without having to do this!
1388 OwningOpRef
<Operation
*> topOp
= &parsedBlock
.front();
1391 AsmState
state(*topOp
, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1392 /*locationMap=*/nullptr, &fallbackResourceMap
);
1394 llvm::raw_string_ostream
os(result
.output
);
1395 topOp
->print(os
, state
);
1397 return std::move(result
);
1400 llvm::Expected
<lsp::MLIRConvertBytecodeResult
>
1401 lsp::MLIRServer::convertToBytecode(const URIForFile
&uri
) {
1402 auto fileIt
= impl
->files
.find(uri
.file());
1403 if (fileIt
== impl
->files
.end()) {
1404 return llvm::make_error
<lsp::LSPError
>(
1405 "language server does not contain an entry for this source file",
1406 lsp::ErrorCode::RequestFailed
);
1408 return fileIt
->second
->convertToBytecode();