1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/host/gnubby_auth_handler_posix.h"
10 #include "base/bind.h"
11 #include "base/file_util.h"
12 #include "base/json/json_reader.h"
13 #include "base/json/json_writer.h"
14 #include "base/lazy_instance.h"
15 #include "base/stl_util.h"
16 #include "base/values.h"
17 #include "net/socket/unix_domain_listen_socket_posix.h"
18 #include "remoting/base/logging.h"
19 #include "remoting/host/gnubby_socket.h"
20 #include "remoting/proto/control.pb.h"
21 #include "remoting/protocol/client_stub.h"
// Values of the kMessageType field of a gnubby-auth JSON message.
const char kControlMessage[] = "control";
const char kDataMessage[] = "data";
const char kErrorMessage[] = "error";

// Dictionary keys used inside gnubby-auth JSON messages.
const char kMessageType[] = "type";
const char kConnectionId[] = "connectionId";
const char kControlOption[] = "option";
const char kDataPayload[] = "data";

// Extension message type carrying gnubby-auth payloads, and the control
// option value that asks the host to start listening for gnubby requests.
const char kGnubbyAuthMessage[] = "gnubby-auth";
const char kGnubbyAuthV1[] = "auth-v1";
37 // The name of the socket to listen for gnubby requests on.
38 base::LazyInstance
<base::FilePath
>::Leaky g_gnubby_socket_name
=
39 LAZY_INSTANCE_INITIALIZER
;
41 // STL predicate to match by a StreamListenSocket pointer.
44 explicit CompareSocket(net::StreamListenSocket
* socket
) : socket_(socket
) {}
46 bool operator()(const std::pair
<int, GnubbySocket
*> element
) const {
47 return element
.second
->IsSocket(socket_
);
51 net::StreamListenSocket
* socket_
;
54 // Socket authentication function that only allows connections from callers with
56 bool MatchUid(const net::UnixDomainServerSocket::Credentials
& credentials
) {
57 bool allowed
= credentials
.user_id
== getuid();
59 HOST_LOG
<< "Refused socket connection from uid " << credentials
.user_id
;
// Returns the command code (the first byte of |data|, as a value in the range
// 0-255) if |data| is non-empty, or -1 if |data| is empty.
int GetCommandCode(const std::string& data) {
  // Convert through unsigned char so bytes >= 0x80 are not sign-extended on
  // platforms where plain char is signed (0x83 must log as 131, not as a huge
  // or negative number).
  return data.empty() ? -1 : static_cast<unsigned char>(data[0]);
}
69 // Creates a string of byte data from a ListValue of numbers. Returns true if
70 // all of the list elements are numbers.
71 bool ConvertListValueToString(base::ListValue
* bytes
, std::string
* out
) {
74 unsigned int byte_count
= bytes
->GetSize();
75 if (byte_count
!= 0) {
76 out
->reserve(byte_count
);
77 for (unsigned int i
= 0; i
< byte_count
; i
++) {
79 if (!bytes
->GetInteger(i
, &value
))
81 out
->push_back(static_cast<char>(value
));
89 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
90 protocol::ClientStub
* client_stub
)
91 : client_stub_(client_stub
), last_connection_id_(0) {
95 GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
96 STLDeleteValues(&active_sockets_
);
100 scoped_ptr
<GnubbyAuthHandler
> GnubbyAuthHandler::Create(
101 protocol::ClientStub
* client_stub
) {
102 return scoped_ptr
<GnubbyAuthHandler
>(new GnubbyAuthHandlerPosix(client_stub
));
106 void GnubbyAuthHandler::SetGnubbySocketName(
107 const base::FilePath
& gnubby_socket_name
) {
108 g_gnubby_socket_name
.Get() = gnubby_socket_name
;
111 void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string
& message
) {
112 DCHECK(CalledOnValidThread());
114 scoped_ptr
<base::Value
> value(base::JSONReader::Read(message
));
115 base::DictionaryValue
* client_message
;
116 if (value
&& value
->GetAsDictionary(&client_message
)) {
118 if (!client_message
->GetString(kMessageType
, &type
)) {
119 LOG(ERROR
) << "Invalid gnubby-auth message";
123 if (type
== kControlMessage
) {
125 if (client_message
->GetString(kControlOption
, &option
) &&
126 option
== kGnubbyAuthV1
) {
127 CreateAuthorizationSocket();
129 LOG(ERROR
) << "Invalid gnubby-auth control option";
131 } else if (type
== kDataMessage
) {
132 ActiveSockets::iterator iter
= GetSocketForMessage(client_message
);
133 if (iter
!= active_sockets_
.end()) {
134 base::ListValue
* bytes
;
135 std::string response
;
136 if (client_message
->GetList(kDataPayload
, &bytes
) &&
137 ConvertListValueToString(bytes
, &response
)) {
138 HOST_LOG
<< "Sending gnubby response: " << GetCommandCode(response
);
139 iter
->second
->SendResponse(response
);
141 LOG(ERROR
) << "Invalid gnubby data";
142 SendErrorAndCloseActiveSocket(iter
);
145 LOG(ERROR
) << "Unknown gnubby-auth data connection";
147 } else if (type
== kErrorMessage
) {
148 ActiveSockets::iterator iter
= GetSocketForMessage(client_message
);
149 if (iter
!= active_sockets_
.end()) {
150 HOST_LOG
<< "Sending gnubby error";
151 SendErrorAndCloseActiveSocket(iter
);
153 LOG(ERROR
) << "Unknown gnubby-auth error connection";
156 LOG(ERROR
) << "Unknown gnubby-auth message type: " << type
;
161 void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
163 const std::string
& data
) const {
164 DCHECK(CalledOnValidThread());
166 base::DictionaryValue request
;
167 request
.SetString(kMessageType
, kDataMessage
);
168 request
.SetInteger(kConnectionId
, connection_id
);
170 base::ListValue
* bytes
= new base::ListValue();
171 for (std::string::const_iterator i
= data
.begin(); i
!= data
.end(); ++i
) {
172 bytes
->AppendInteger(static_cast<unsigned char>(*i
));
174 request
.Set(kDataPayload
, bytes
);
176 std::string request_json
;
177 if (!base::JSONWriter::Write(&request
, &request_json
)) {
178 LOG(ERROR
) << "Failed to create request json";
182 protocol::ExtensionMessage message
;
183 message
.set_type(kGnubbyAuthMessage
);
184 message
.set_data(request_json
);
186 client_stub_
->DeliverHostMessage(message
);
189 bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting(
190 net::StreamListenSocket
* socket
) const {
191 return std::find_if(active_sockets_
.begin(),
192 active_sockets_
.end(),
193 CompareSocket(socket
)) != active_sockets_
.end();
196 int GnubbyAuthHandlerPosix::GetConnectionIdForTesting(
197 net::StreamListenSocket
* socket
) const {
198 ActiveSockets::const_iterator iter
= std::find_if(
199 active_sockets_
.begin(), active_sockets_
.end(), CompareSocket(socket
));
203 GnubbySocket
* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting(
204 net::StreamListenSocket
* socket
) const {
205 ActiveSockets::const_iterator iter
= std::find_if(
206 active_sockets_
.begin(), active_sockets_
.end(), CompareSocket(socket
));
210 void GnubbyAuthHandlerPosix::DidAccept(
211 net::StreamListenSocket
* server
,
212 scoped_ptr
<net::StreamListenSocket
> socket
) {
213 DCHECK(CalledOnValidThread());
215 int connection_id
= ++last_connection_id_
;
216 active_sockets_
[connection_id
] =
217 new GnubbySocket(socket
.Pass(),
218 base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut
,
219 base::Unretained(this),
223 void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket
* socket
,
226 DCHECK(CalledOnValidThread());
228 ActiveSockets::iterator iter
= std::find_if(
229 active_sockets_
.begin(), active_sockets_
.end(), CompareSocket(socket
));
230 if (iter
!= active_sockets_
.end()) {
231 GnubbySocket
* gnubby_socket
= iter
->second
;
232 gnubby_socket
->AddRequestData(data
, len
);
233 if (gnubby_socket
->IsRequestTooLarge()) {
234 SendErrorAndCloseActiveSocket(iter
);
235 } else if (gnubby_socket
->IsRequestComplete()) {
236 std::string request_data
;
237 gnubby_socket
->GetAndClearRequestData(&request_data
);
238 ProcessGnubbyRequest(iter
->first
, request_data
);
241 LOG(ERROR
) << "Received data for unknown connection";
245 void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket
* socket
) {
246 DCHECK(CalledOnValidThread());
248 ActiveSockets::iterator iter
= std::find_if(
249 active_sockets_
.begin(), active_sockets_
.end(), CompareSocket(socket
));
250 if (iter
!= active_sockets_
.end()) {
252 active_sockets_
.erase(iter
);
256 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
257 DCHECK(CalledOnValidThread());
259 if (!g_gnubby_socket_name
.Get().empty()) {
260 // If the file already exists, a socket in use error is returned.
261 base::DeleteFile(g_gnubby_socket_name
.Get(), false);
263 HOST_LOG
<< "Listening for gnubby requests on "
264 << g_gnubby_socket_name
.Get().value();
266 auth_socket_
= net::deprecated::UnixDomainListenSocket::CreateAndListen(
267 g_gnubby_socket_name
.Get().value(), this, base::Bind(MatchUid
));
268 if (!auth_socket_
.get()) {
269 LOG(ERROR
) << "Failed to open socket for gnubby requests";
272 HOST_LOG
<< "No gnubby socket name specified";
276 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
278 const std::string
& request_data
) {
279 HOST_LOG
<< "Received gnubby request: " << GetCommandCode(request_data
);
280 DeliverHostDataMessage(connection_id
, request_data
);
283 GnubbyAuthHandlerPosix::ActiveSockets::iterator
284 GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue
* message
) {
286 if (message
->GetInteger(kConnectionId
, &connection_id
)) {
287 return active_sockets_
.find(connection_id
);
289 return active_sockets_
.end();
292 void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
293 const ActiveSockets::iterator
& iter
) {
294 iter
->second
->SendSshError();
297 active_sockets_
.erase(iter
);
300 void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id
) {
301 HOST_LOG
<< "Gnubby request timed out";
302 ActiveSockets::iterator iter
= active_sockets_
.find(connection_id
);
303 if (iter
!= active_sockets_
.end())
304 SendErrorAndCloseActiveSocket(iter
);
307 } // namespace remoting