1 //===-- LSPClient.cpp - Helper for ClangdLSPServer tests ------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
12 #include "Transport.h"
13 #include "support/Logger.h"
14 #include "support/Threading.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/StringMap.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Support/Error.h"
19 #include "llvm/Support/JSON.h"
20 #include "llvm/Support/Path.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "gtest/gtest.h"
23 #include <condition_variable>
39 llvm::Expected
<llvm::json::Value
> clang::clangd::LSPClient::CallResult::take() {
40 std::unique_lock
<std::mutex
> Lock(Mu
);
41 static constexpr size_t TimeoutSecs
= 60;
42 if (!clangd::wait(Lock
, CV
, timeoutSeconds(TimeoutSecs
),
43 [this] { return Value
.has_value(); })) {
44 ADD_FAILURE() << "No result from call after " << TimeoutSecs
<< " seconds!";
45 return llvm::json::Value(nullptr);
47 auto Res
= std::move(*Value
);
52 llvm::json::Value
LSPClient::CallResult::takeValue() {
53 auto ExpValue
= take();
55 ADD_FAILURE() << "takeValue(): " << llvm::toString(ExpValue
.takeError());
56 return llvm::json::Value(nullptr);
58 return std::move(*ExpValue
);
61 void LSPClient::CallResult::set(llvm::Expected
<llvm::json::Value
> V
) {
62 std::lock_guard
<std::mutex
> Lock(Mu
);
64 ADD_FAILURE() << "Multiple replies";
65 llvm::consumeError(V
.takeError());
72 LSPClient::CallResult::~CallResult() {
73 if (Value
&& !*Value
) {
74 ADD_FAILURE() << llvm::toString(Value
->takeError());
78 static void logBody(llvm::StringRef Method
, llvm::json::Value V
, bool Send
) {
79 // We invert <<< and >>> as the combined log is from the server's viewpoint.
80 vlog("{0} {1}: {2:2}", Send
? "<<<" : ">>>", Method
, V
);
83 class LSPClient::TransportImpl
: public Transport
{
85 std::pair
<llvm::json::Value
, CallResult
*> addCallSlot() {
86 std::lock_guard
<std::mutex
> Lock(Mu
);
87 unsigned ID
= CallResults
.size();
88 CallResults
.emplace_back();
89 return {ID
, &CallResults
.back()};
92 // A null action causes the transport to shut down.
93 void enqueue(std::function
<void(MessageHandler
&)> Action
) {
94 std::lock_guard
<std::mutex
> Lock(Mu
);
95 Actions
.push(std::move(Action
));
99 std::vector
<llvm::json::Value
> takeNotifications(llvm::StringRef Method
) {
100 std::vector
<llvm::json::Value
> Result
;
102 std::lock_guard
<std::mutex
> Lock(Mu
);
103 std::swap(Result
, Notifications
[Method
]);
108 void expectCall(llvm::StringRef Method
) {
109 std::lock_guard
<std::mutex
> Lock(Mu
);
113 std::vector
<llvm::json::Value
> takeCallParams(llvm::StringRef Method
) {
114 std::vector
<llvm::json::Value
> Result
;
116 std::lock_guard
<std::mutex
> Lock(Mu
);
117 std::swap(Result
, Calls
[Method
]);
123 void reply(llvm::json::Value ID
,
124 llvm::Expected
<llvm::json::Value
> V
) override
{
125 if (V
) // Nothing additional to log for error.
126 logBody("reply", *V
, /*Send=*/false);
127 std::lock_guard
<std::mutex
> Lock(Mu
);
128 if (auto I
= ID
.getAsInteger()) {
129 if (*I
>= 0 && *I
< static_cast<int64_t>(CallResults
.size())) {
130 CallResults
[*I
].set(std::move(V
));
134 ADD_FAILURE() << "Invalid reply to ID " << ID
;
135 llvm::consumeError(std::move(V
).takeError());
138 void notify(llvm::StringRef Method
, llvm::json::Value V
) override
{
139 logBody(Method
, V
, /*Send=*/false);
140 std::lock_guard
<std::mutex
> Lock(Mu
);
141 Notifications
[Method
].push_back(std::move(V
));
144 void call(llvm::StringRef Method
, llvm::json::Value Params
,
145 llvm::json::Value ID
) override
{
146 logBody(Method
, Params
, /*Send=*/false);
147 std::lock_guard
<std::mutex
> Lock(Mu
);
148 if (Calls
.contains(Method
)) {
149 Calls
[Method
].push_back(std::move(Params
));
151 ADD_FAILURE() << "Unexpected server->client call " << Method
;
155 llvm::Error
loop(MessageHandler
&H
) override
{
156 std::unique_lock
<std::mutex
> Lock(Mu
);
158 CV
.wait(Lock
, [&] { return !Actions
.empty(); });
159 if (!Actions
.front()) // Stop!
160 return llvm::Error::success();
161 auto Action
= std::move(Actions
.front());
170 std::deque
<CallResult
> CallResults
;
171 std::queue
<std::function
<void(Transport::MessageHandler
&)>> Actions
;
172 std::condition_variable CV
;
173 llvm::StringMap
<std::vector
<llvm::json::Value
>> Notifications
;
174 llvm::StringMap
<std::vector
<llvm::json::Value
>> Calls
;
177 LSPClient::LSPClient() : T(std::make_unique
<TransportImpl
>()) {}
178 LSPClient::~LSPClient() = default;
180 LSPClient::CallResult
&LSPClient::call(llvm::StringRef Method
,
181 llvm::json::Value Params
) {
182 auto Slot
= T
->addCallSlot();
183 T
->enqueue([ID(Slot
.first
), Method(Method
.str()),
184 Params(std::move(Params
))](Transport::MessageHandler
&H
) {
185 logBody(Method
, Params
, /*Send=*/true);
186 H
.onCall(Method
, std::move(Params
), ID
);
191 void LSPClient::expectServerCall(llvm::StringRef Method
) {
192 T
->expectCall(Method
);
195 void LSPClient::notify(llvm::StringRef Method
, llvm::json::Value Params
) {
196 T
->enqueue([Method(Method
.str()),
197 Params(std::move(Params
))](Transport::MessageHandler
&H
) {
198 logBody(Method
, Params
, /*Send=*/true);
199 H
.onNotify(Method
, std::move(Params
));
203 std::vector
<llvm::json::Value
>
204 LSPClient::takeNotifications(llvm::StringRef Method
) {
205 return T
->takeNotifications(Method
);
208 std::vector
<llvm::json::Value
>
209 LSPClient::takeCallParams(llvm::StringRef Method
) {
210 return T
->takeCallParams(Method
);
213 void LSPClient::stop() { T
->enqueue(nullptr); }
215 Transport
&LSPClient::transport() { return *T
; }
217 using Obj
= llvm::json::Object
;
219 llvm::json::Value
LSPClient::uri(llvm::StringRef Path
) {
221 if (!llvm::sys::path::is_absolute(Path
))
222 Path
= Storage
= testPath(Path
);
223 return toJSON(URIForFile::canonicalize(Path
, Path
));
225 llvm::json::Value
LSPClient::documentID(llvm::StringRef Path
) {
226 return Obj
{{"uri", uri(Path
)}};
229 void LSPClient::didOpen(llvm::StringRef Path
, llvm::StringRef Content
) {
231 "textDocument/didOpen",
233 Obj
{{"uri", uri(Path
)}, {"text", Content
}, {"languageId", "cpp"}}}});
235 void LSPClient::didChange(llvm::StringRef Path
, llvm::StringRef Content
) {
236 notify("textDocument/didChange",
237 Obj
{{"textDocument", documentID(Path
)},
238 {"contentChanges", llvm::json::Array
{Obj
{{"text", Content
}}}}});
240 void LSPClient::didClose(llvm::StringRef Path
) {
241 notify("textDocument/didClose", Obj
{{"textDocument", documentID(Path
)}});
244 void LSPClient::sync() { call("sync", nullptr).takeValue(); }
246 std::optional
<std::vector
<llvm::json::Value
>>
247 LSPClient::diagnostics(llvm::StringRef Path
) {
249 auto Notifications
= takeNotifications("textDocument/publishDiagnostics");
250 for (const auto &Notification
: llvm::reverse(Notifications
)) {
251 if (const auto *PubDiagsParams
= Notification
.getAsObject()) {
252 auto U
= PubDiagsParams
->getString("uri");
253 auto *D
= PubDiagsParams
->getArray("diagnostics");
255 ADD_FAILURE() << "Bad PublishDiagnosticsParams: " << PubDiagsParams
;
259 return std::vector
<llvm::json::Value
>(D
->begin(), D
->end());
265 } // namespace clangd