//===-- LSPClient.cpp - Helper for ClangdLSPServer tests ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "LSPClient.h"
#include "Protocol.h"
#include "TestFS.h"
#include "Transport.h"
#include "support/Logger.h"
#include "support/Threading.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"
#include <condition_variable>
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>
25 llvm::Expected
<llvm::json::Value
> clang::clangd::LSPClient::CallResult::take() {
26 std::unique_lock
<std::mutex
> Lock(Mu
);
27 if (!clangd::wait(Lock
, CV
, timeoutSeconds(10),
28 [this] { return Value
.has_value(); })) {
29 ADD_FAILURE() << "No result from call after 10 seconds!";
30 return llvm::json::Value(nullptr);
32 auto Res
= std::move(*Value
);
37 llvm::json::Value
LSPClient::CallResult::takeValue() {
38 auto ExpValue
= take();
40 ADD_FAILURE() << "takeValue(): " << llvm::toString(ExpValue
.takeError());
41 return llvm::json::Value(nullptr);
43 return std::move(*ExpValue
);
46 void LSPClient::CallResult::set(llvm::Expected
<llvm::json::Value
> V
) {
47 std::lock_guard
<std::mutex
> Lock(Mu
);
49 ADD_FAILURE() << "Multiple replies";
50 llvm::consumeError(V
.takeError());
57 LSPClient::CallResult::~CallResult() {
58 if (Value
&& !*Value
) {
59 ADD_FAILURE() << llvm::toString(Value
->takeError());
63 static void logBody(llvm::StringRef Method
, llvm::json::Value V
, bool Send
) {
64 // We invert <<< and >>> as the combined log is from the server's viewpoint.
65 vlog("{0} {1}: {2:2}", Send
? "<<<" : ">>>", Method
, V
);
68 class LSPClient::TransportImpl
: public Transport
{
70 std::pair
<llvm::json::Value
, CallResult
*> addCallSlot() {
71 std::lock_guard
<std::mutex
> Lock(Mu
);
72 unsigned ID
= CallResults
.size();
73 CallResults
.emplace_back();
74 return {ID
, &CallResults
.back()};
77 // A null action causes the transport to shut down.
78 void enqueue(std::function
<void(MessageHandler
&)> Action
) {
79 std::lock_guard
<std::mutex
> Lock(Mu
);
80 Actions
.push(std::move(Action
));
84 std::vector
<llvm::json::Value
> takeNotifications(llvm::StringRef Method
) {
85 std::vector
<llvm::json::Value
> Result
;
87 std::lock_guard
<std::mutex
> Lock(Mu
);
88 std::swap(Result
, Notifications
[Method
]);
94 void reply(llvm::json::Value ID
,
95 llvm::Expected
<llvm::json::Value
> V
) override
{
96 if (V
) // Nothing additional to log for error.
97 logBody("reply", *V
, /*Send=*/false);
98 std::lock_guard
<std::mutex
> Lock(Mu
);
99 if (auto I
= ID
.getAsInteger()) {
100 if (*I
>= 0 && *I
< static_cast<int64_t>(CallResults
.size())) {
101 CallResults
[*I
].set(std::move(V
));
105 ADD_FAILURE() << "Invalid reply to ID " << ID
;
106 llvm::consumeError(std::move(V
).takeError());
109 void notify(llvm::StringRef Method
, llvm::json::Value V
) override
{
110 logBody(Method
, V
, /*Send=*/false);
111 std::lock_guard
<std::mutex
> Lock(Mu
);
112 Notifications
[Method
].push_back(std::move(V
));
115 void call(llvm::StringRef Method
, llvm::json::Value Params
,
116 llvm::json::Value ID
) override
{
117 logBody(Method
, Params
, /*Send=*/false);
118 ADD_FAILURE() << "Unexpected server->client call " << Method
;
121 llvm::Error
loop(MessageHandler
&H
) override
{
122 std::unique_lock
<std::mutex
> Lock(Mu
);
124 CV
.wait(Lock
, [&] { return !Actions
.empty(); });
125 if (!Actions
.front()) // Stop!
126 return llvm::Error::success();
127 auto Action
= std::move(Actions
.front());
136 std::deque
<CallResult
> CallResults
;
137 std::queue
<std::function
<void(Transport::MessageHandler
&)>> Actions
;
138 std::condition_variable CV
;
139 llvm::StringMap
<std::vector
<llvm::json::Value
>> Notifications
;
142 LSPClient::LSPClient() : T(std::make_unique
<TransportImpl
>()) {}
143 LSPClient::~LSPClient() = default;
145 LSPClient::CallResult
&LSPClient::call(llvm::StringRef Method
,
146 llvm::json::Value Params
) {
147 auto Slot
= T
->addCallSlot();
148 T
->enqueue([ID(Slot
.first
), Method(Method
.str()),
149 Params(std::move(Params
))](Transport::MessageHandler
&H
) {
150 logBody(Method
, Params
, /*Send=*/true);
151 H
.onCall(Method
, std::move(Params
), ID
);
156 void LSPClient::notify(llvm::StringRef Method
, llvm::json::Value Params
) {
157 T
->enqueue([Method(Method
.str()),
158 Params(std::move(Params
))](Transport::MessageHandler
&H
) {
159 logBody(Method
, Params
, /*Send=*/true);
160 H
.onNotify(Method
, std::move(Params
));
164 std::vector
<llvm::json::Value
>
165 LSPClient::takeNotifications(llvm::StringRef Method
) {
166 return T
->takeNotifications(Method
);
169 void LSPClient::stop() { T
->enqueue(nullptr); }
171 Transport
&LSPClient::transport() { return *T
; }
173 using Obj
= llvm::json::Object
;
175 llvm::json::Value
LSPClient::uri(llvm::StringRef Path
) {
177 if (!llvm::sys::path::is_absolute(Path
))
178 Path
= Storage
= testPath(Path
);
179 return toJSON(URIForFile::canonicalize(Path
, Path
));
181 llvm::json::Value
LSPClient::documentID(llvm::StringRef Path
) {
182 return Obj
{{"uri", uri(Path
)}};
185 void LSPClient::didOpen(llvm::StringRef Path
, llvm::StringRef Content
) {
187 "textDocument/didOpen",
189 Obj
{{"uri", uri(Path
)}, {"text", Content
}, {"languageId", "cpp"}}}});
191 void LSPClient::didChange(llvm::StringRef Path
, llvm::StringRef Content
) {
192 notify("textDocument/didChange",
193 Obj
{{"textDocument", documentID(Path
)},
194 {"contentChanges", llvm::json::Array
{Obj
{{"text", Content
}}}}});
196 void LSPClient::didClose(llvm::StringRef Path
) {
197 notify("textDocument/didClose", Obj
{{"textDocument", documentID(Path
)}});
200 void LSPClient::sync() { call("sync", nullptr).takeValue(); }
202 std::optional
<std::vector
<llvm::json::Value
>>
203 LSPClient::diagnostics(llvm::StringRef Path
) {
205 auto Notifications
= takeNotifications("textDocument/publishDiagnostics");
206 for (const auto &Notification
: llvm::reverse(Notifications
)) {
207 if (const auto *PubDiagsParams
= Notification
.getAsObject()) {
208 auto U
= PubDiagsParams
->getString("uri");
209 auto *D
= PubDiagsParams
->getArray("diagnostics");
211 ADD_FAILURE() << "Bad PublishDiagnosticsParams: " << PubDiagsParams
;
215 return std::vector
<llvm::json::Value
>(D
->begin(), D
->end());
221 } // namespace clangd