[mlir] Use StringRef::{starts,ends}_with (NFC)
[llvm-project.git] / mlir / lib / Tools / mlir-lsp-server / LSPServer.cpp
blob0f23366f6fe80a8afdc4605d38377c2c44fabf82
1 //===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "LSPServer.h"
10 #include "MLIRServer.h"
11 #include "Protocol.h"
12 #include "mlir/Tools/lsp-server-support/Logging.h"
13 #include "mlir/Tools/lsp-server-support/Transport.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16 #include <optional>
18 #define DEBUG_TYPE "mlir-lsp-server"
20 using namespace mlir;
21 using namespace mlir::lsp;
23 //===----------------------------------------------------------------------===//
24 // LSPServer
25 //===----------------------------------------------------------------------===//
27 namespace {
28 struct LSPServer {
29 LSPServer(MLIRServer &server) : server(server) {}
31 //===--------------------------------------------------------------------===//
32 // Initialization
34 void onInitialize(const InitializeParams &params,
35 Callback<llvm::json::Value> reply);
36 void onInitialized(const InitializedParams &params);
37 void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);
39 //===--------------------------------------------------------------------===//
40 // Document Change
42 void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
43 void onDocumentDidClose(const DidCloseTextDocumentParams &params);
44 void onDocumentDidChange(const DidChangeTextDocumentParams &params);
46 //===--------------------------------------------------------------------===//
47 // Definitions and References
49 void onGoToDefinition(const TextDocumentPositionParams &params,
50 Callback<std::vector<Location>> reply);
51 void onReference(const ReferenceParams &params,
52 Callback<std::vector<Location>> reply);
54 //===--------------------------------------------------------------------===//
55 // Hover
57 void onHover(const TextDocumentPositionParams &params,
58 Callback<std::optional<Hover>> reply);
60 //===--------------------------------------------------------------------===//
61 // Document Symbols
63 void onDocumentSymbol(const DocumentSymbolParams &params,
64 Callback<std::vector<DocumentSymbol>> reply);
66 //===--------------------------------------------------------------------===//
67 // Code Completion
69 void onCompletion(const CompletionParams &params,
70 Callback<CompletionList> reply);
72 //===--------------------------------------------------------------------===//
73 // Code Action
75 void onCodeAction(const CodeActionParams &params,
76 Callback<llvm::json::Value> reply);
78 //===--------------------------------------------------------------------===//
79 // Bytecode
81 void onConvertFromBytecode(const MLIRConvertBytecodeParams &params,
82 Callback<MLIRConvertBytecodeResult> reply);
83 void onConvertToBytecode(const MLIRConvertBytecodeParams &params,
84 Callback<MLIRConvertBytecodeResult> reply);
86 //===--------------------------------------------------------------------===//
87 // Fields
88 //===--------------------------------------------------------------------===//
90 MLIRServer &server;
92 /// An outgoing notification used to send diagnostics to the client when they
93 /// are ready to be processed.
94 OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
96 /// Used to indicate that the 'shutdown' request was received from the
97 /// Language Server client.
98 bool shutdownRequestReceived = false;
100 } // namespace
102 //===----------------------------------------------------------------------===//
103 // Initialization
105 void LSPServer::onInitialize(const InitializeParams &params,
106 Callback<llvm::json::Value> reply) {
107 // Send a response with the capabilities of this server.
108 llvm::json::Object serverCaps{
109 {"textDocumentSync",
110 llvm::json::Object{
111 {"openClose", true},
112 {"change", (int)TextDocumentSyncKind::Full},
113 {"save", true},
115 {"completionProvider",
116 llvm::json::Object{
117 {"allCommitCharacters",
119 "\t",
120 ";",
121 ",",
122 ".",
123 "=",
125 {"resolveProvider", false},
126 {"triggerCharacters",
127 {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
129 {"definitionProvider", true},
130 {"referencesProvider", true},
131 {"hoverProvider", true},
133 // For now we only support documenting symbols when the client supports
134 // hierarchical symbols.
135 {"documentSymbolProvider",
136 params.capabilities.hierarchicalDocumentSymbol},
139 // Per LSP, codeActionProvider can be either boolean or CodeActionOptions.
140 // CodeActionOptions is only valid if the client supports action literal
141 // via textDocument.codeAction.codeActionLiteralSupport.
142 serverCaps["codeActionProvider"] =
143 params.capabilities.codeActionStructure
144 ? llvm::json::Object{{"codeActionKinds",
145 {CodeAction::kQuickFix, CodeAction::kRefactor,
146 CodeAction::kInfo}}}
147 : llvm::json::Value(true);
149 llvm::json::Object result{
150 {{"serverInfo",
151 llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
152 {"capabilities", std::move(serverCaps)}}};
153 reply(std::move(result));
155 void LSPServer::onInitialized(const InitializedParams &) {}
156 void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) {
157 shutdownRequestReceived = true;
158 reply(nullptr);
161 //===----------------------------------------------------------------------===//
162 // Document Change
164 void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams &params) {
165 PublishDiagnosticsParams diagParams(params.textDocument.uri,
166 params.textDocument.version);
167 server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
168 params.textDocument.version,
169 diagParams.diagnostics);
171 // Publish any recorded diagnostics.
172 publishDiagnostics(diagParams);
174 void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams &params) {
175 std::optional<int64_t> version =
176 server.removeDocument(params.textDocument.uri);
177 if (!version)
178 return;
180 // Empty out the diagnostics shown for this document. This will clear out
181 // anything currently displayed by the client for this document (e.g. in the
182 // "Problems" pane of VSCode).
183 publishDiagnostics(
184 PublishDiagnosticsParams(params.textDocument.uri, *version));
186 void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams &params) {
187 // TODO: We currently only support full document updates, we should refactor
188 // to avoid this.
189 if (params.contentChanges.size() != 1)
190 return;
191 PublishDiagnosticsParams diagParams(params.textDocument.uri,
192 params.textDocument.version);
193 server.addOrUpdateDocument(
194 params.textDocument.uri, params.contentChanges.front().text,
195 params.textDocument.version, diagParams.diagnostics);
197 // Publish any recorded diagnostics.
198 publishDiagnostics(diagParams);
201 //===----------------------------------------------------------------------===//
202 // Definitions and References
204 void LSPServer::onGoToDefinition(const TextDocumentPositionParams &params,
205 Callback<std::vector<Location>> reply) {
206 std::vector<Location> locations;
207 server.getLocationsOf(params.textDocument.uri, params.position, locations);
208 reply(std::move(locations));
211 void LSPServer::onReference(const ReferenceParams &params,
212 Callback<std::vector<Location>> reply) {
213 std::vector<Location> locations;
214 server.findReferencesOf(params.textDocument.uri, params.position, locations);
215 reply(std::move(locations));
218 //===----------------------------------------------------------------------===//
219 // Hover
221 void LSPServer::onHover(const TextDocumentPositionParams &params,
222 Callback<std::optional<Hover>> reply) {
223 reply(server.findHover(params.textDocument.uri, params.position));
226 //===----------------------------------------------------------------------===//
227 // Document Symbols
229 void LSPServer::onDocumentSymbol(const DocumentSymbolParams &params,
230 Callback<std::vector<DocumentSymbol>> reply) {
231 std::vector<DocumentSymbol> symbols;
232 server.findDocumentSymbols(params.textDocument.uri, symbols);
233 reply(std::move(symbols));
236 //===----------------------------------------------------------------------===//
237 // Code Completion
239 void LSPServer::onCompletion(const CompletionParams &params,
240 Callback<CompletionList> reply) {
241 reply(server.getCodeCompletion(params.textDocument.uri, params.position));
244 //===----------------------------------------------------------------------===//
245 // Code Action
247 void LSPServer::onCodeAction(const CodeActionParams &params,
248 Callback<llvm::json::Value> reply) {
249 URIForFile uri = params.textDocument.uri;
251 // Check whether a particular CodeActionKind is included in the response.
252 auto isKindAllowed = [only(params.context.only)](StringRef kind) {
253 if (only.empty())
254 return true;
255 return llvm::any_of(only, [&](StringRef base) {
256 return kind.consume_front(base) &&
257 (kind.empty() || kind.starts_with("."));
261 // We provide a code action for fixes on the specified diagnostics.
262 std::vector<CodeAction> actions;
263 if (isKindAllowed(CodeAction::kQuickFix))
264 server.getCodeActions(uri, params.range.start, params.context, actions);
265 reply(std::move(actions));
268 //===----------------------------------------------------------------------===//
269 // Bytecode
271 void LSPServer::onConvertFromBytecode(
272 const MLIRConvertBytecodeParams &params,
273 Callback<MLIRConvertBytecodeResult> reply) {
274 reply(server.convertFromBytecode(params.uri));
277 void LSPServer::onConvertToBytecode(const MLIRConvertBytecodeParams &params,
278 Callback<MLIRConvertBytecodeResult> reply) {
279 reply(server.convertToBytecode(params.uri));
282 //===----------------------------------------------------------------------===//
283 // Entry point
284 //===----------------------------------------------------------------------===//
286 LogicalResult lsp::runMlirLSPServer(MLIRServer &server,
287 JSONTransport &transport) {
288 LSPServer lspServer(server);
289 MessageHandler messageHandler(transport);
291 // Initialization
292 messageHandler.method("initialize", &lspServer, &LSPServer::onInitialize);
293 messageHandler.notification("initialized", &lspServer,
294 &LSPServer::onInitialized);
295 messageHandler.method("shutdown", &lspServer, &LSPServer::onShutdown);
297 // Document Changes
298 messageHandler.notification("textDocument/didOpen", &lspServer,
299 &LSPServer::onDocumentDidOpen);
300 messageHandler.notification("textDocument/didClose", &lspServer,
301 &LSPServer::onDocumentDidClose);
302 messageHandler.notification("textDocument/didChange", &lspServer,
303 &LSPServer::onDocumentDidChange);
305 // Definitions and References
306 messageHandler.method("textDocument/definition", &lspServer,
307 &LSPServer::onGoToDefinition);
308 messageHandler.method("textDocument/references", &lspServer,
309 &LSPServer::onReference);
311 // Hover
312 messageHandler.method("textDocument/hover", &lspServer, &LSPServer::onHover);
314 // Document Symbols
315 messageHandler.method("textDocument/documentSymbol", &lspServer,
316 &LSPServer::onDocumentSymbol);
318 // Code Completion
319 messageHandler.method("textDocument/completion", &lspServer,
320 &LSPServer::onCompletion);
322 // Code Action
323 messageHandler.method("textDocument/codeAction", &lspServer,
324 &LSPServer::onCodeAction);
326 // Bytecode
327 messageHandler.method("mlir/convertFromBytecode", &lspServer,
328 &LSPServer::onConvertFromBytecode);
329 messageHandler.method("mlir/convertToBytecode", &lspServer,
330 &LSPServer::onConvertToBytecode);
332 // Diagnostics
333 lspServer.publishDiagnostics =
334 messageHandler.outgoingNotification<PublishDiagnosticsParams>(
335 "textDocument/publishDiagnostics");
337 // Run the main loop of the transport.
338 LogicalResult result = success();
339 if (llvm::Error error = transport.run(messageHandler)) {
340 Logger::error("Transport error: {0}", error);
341 llvm::consumeError(std::move(error));
342 result = failure();
343 } else {
344 result = success(lspServer.shutdownRequestReceived);
346 return result;