1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
#include "remoting/host/gnubby_auth_handler_posix.h"

#include <stdint.h>
#include <unistd.h>

#include <string>

#include "base/bind.h"
#include "base/files/file_util.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/stl_util.h"
#include "base/threading/thread_restrictions.h"
#include "base/values.h"
#include "net/base/net_errors.h"
#include "net/socket/unix_domain_server_socket_posix.h"
#include "remoting/base/logging.h"
#include "remoting/host/gnubby_socket.h"
#include "remoting/proto/control.pb.h"
#include "remoting/protocol/client_stub.h"
// Keys and values of the JSON messages exchanged with the client inside
// "gnubby-auth" extension messages.
const char kConnectionId[] = "connectionId";
const char kControlMessage[] = "control";
const char kControlOption[] = "option";
const char kDataMessage[] = "data";
const char kDataPayload[] = "data";
const char kErrorMessage[] = "error";
const char kGnubbyAuthMessage[] = "gnubby-auth";
const char kGnubbyAuthV1[] = "auth-v1";
const char kMessageType[] = "type";

// Default for the per-request timeout handed to every GnubbySocket the handler
// creates (overridable in tests via SetRequestTimeoutForTest()).
const int64_t kDefaultRequestTimeoutSeconds = 60;
42 // The name of the socket to listen for gnubby requests on.
43 base::LazyInstance
<base::FilePath
>::Leaky g_gnubby_socket_name
=
44 LAZY_INSTANCE_INITIALIZER
;
46 // Socket authentication function that only allows connections from callers with
48 bool MatchUid(const net::UnixDomainServerSocket::Credentials
& credentials
) {
49 bool allowed
= credentials
.user_id
== getuid();
51 HOST_LOG
<< "Refused socket connection from uid " << credentials
.user_id
;
// Returns the command code (the first byte of |data|, as a value in 0..255) if
// it exists, or static_cast<unsigned int>(-1) if |data| is empty.
unsigned int GetCommandCode(const std::string& data) {
  if (data.empty())
    return static_cast<unsigned int>(-1);
  // Convert through unsigned char so that bytes >= 0x80 are not sign-extended
  // on platforms where plain char is signed.
  return static_cast<unsigned char>(data[0]);
}
61 // Creates a string of byte data from a ListValue of numbers. Returns true if
62 // all of the list elements are numbers.
63 bool ConvertListValueToString(base::ListValue
* bytes
, std::string
* out
) {
66 unsigned int byte_count
= bytes
->GetSize();
67 if (byte_count
!= 0) {
68 out
->reserve(byte_count
);
69 for (unsigned int i
= 0; i
< byte_count
; i
++) {
71 if (!bytes
->GetInteger(i
, &value
))
73 out
->push_back(static_cast<char>(value
));
81 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
82 protocol::ClientStub
* client_stub
)
83 : client_stub_(client_stub
),
84 last_connection_id_(0),
86 base::TimeDelta::FromSeconds(kDefaultRequestTimeoutSeconds
)) {
90 GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
91 STLDeleteValues(&active_sockets_
);
95 scoped_ptr
<GnubbyAuthHandler
> GnubbyAuthHandler::Create(
96 protocol::ClientStub
* client_stub
) {
97 return make_scoped_ptr(new GnubbyAuthHandlerPosix(client_stub
));
101 void GnubbyAuthHandler::SetGnubbySocketName(
102 const base::FilePath
& gnubby_socket_name
) {
103 g_gnubby_socket_name
.Get() = gnubby_socket_name
;
106 void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string
& message
) {
107 DCHECK(CalledOnValidThread());
109 scoped_ptr
<base::Value
> value
= base::JSONReader::Read(message
);
110 base::DictionaryValue
* client_message
;
111 if (value
&& value
->GetAsDictionary(&client_message
)) {
113 if (!client_message
->GetString(kMessageType
, &type
)) {
114 LOG(ERROR
) << "Invalid gnubby-auth message";
118 if (type
== kControlMessage
) {
120 if (client_message
->GetString(kControlOption
, &option
) &&
121 option
== kGnubbyAuthV1
) {
122 CreateAuthorizationSocket();
124 LOG(ERROR
) << "Invalid gnubby-auth control option";
126 } else if (type
== kDataMessage
) {
127 ActiveSockets::iterator iter
= GetSocketForMessage(client_message
);
128 if (iter
!= active_sockets_
.end()) {
129 base::ListValue
* bytes
;
130 std::string response
;
131 if (client_message
->GetList(kDataPayload
, &bytes
) &&
132 ConvertListValueToString(bytes
, &response
)) {
133 HOST_LOG
<< "Sending gnubby response: " << GetCommandCode(response
);
134 iter
->second
->SendResponse(response
);
136 LOG(ERROR
) << "Invalid gnubby data";
137 SendErrorAndCloseActiveSocket(iter
);
140 LOG(ERROR
) << "Unknown gnubby-auth data connection";
142 } else if (type
== kErrorMessage
) {
143 ActiveSockets::iterator iter
= GetSocketForMessage(client_message
);
144 if (iter
!= active_sockets_
.end()) {
145 HOST_LOG
<< "Sending gnubby error";
146 SendErrorAndCloseActiveSocket(iter
);
148 LOG(ERROR
) << "Unknown gnubby-auth error connection";
151 LOG(ERROR
) << "Unknown gnubby-auth message type: " << type
;
156 void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
158 const std::string
& data
) const {
159 DCHECK(CalledOnValidThread());
161 base::DictionaryValue request
;
162 request
.SetString(kMessageType
, kDataMessage
);
163 request
.SetInteger(kConnectionId
, connection_id
);
165 base::ListValue
* bytes
= new base::ListValue();
166 for (std::string::const_iterator i
= data
.begin(); i
!= data
.end(); ++i
) {
167 bytes
->AppendInteger(static_cast<unsigned char>(*i
));
169 request
.Set(kDataPayload
, bytes
);
171 std::string request_json
;
172 if (!base::JSONWriter::Write(request
, &request_json
)) {
173 LOG(ERROR
) << "Failed to create request json";
177 protocol::ExtensionMessage message
;
178 message
.set_type(kGnubbyAuthMessage
);
179 message
.set_data(request_json
);
181 client_stub_
->DeliverHostMessage(message
);
184 size_t GnubbyAuthHandlerPosix::GetActiveSocketsMapSizeForTest() const {
185 return active_sockets_
.size();
188 void GnubbyAuthHandlerPosix::SetRequestTimeoutForTest(
189 const base::TimeDelta
& timeout
) {
190 request_timeout_
= timeout
;
193 void GnubbyAuthHandlerPosix::DoAccept() {
194 int result
= auth_socket_
->Accept(
196 base::Bind(&GnubbyAuthHandlerPosix::OnAccepted
, base::Unretained(this)));
197 if (result
!= net::ERR_IO_PENDING
)
201 void GnubbyAuthHandlerPosix::OnAccepted(int result
) {
202 DCHECK(CalledOnValidThread());
203 DCHECK_NE(net::ERR_IO_PENDING
, result
);
206 LOG(ERROR
) << "Error in accepting a new connection";
210 int connection_id
= ++last_connection_id_
;
211 GnubbySocket
* socket
=
212 new GnubbySocket(accept_socket_
.Pass(), request_timeout_
,
213 base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut
,
214 base::Unretained(this), connection_id
));
215 active_sockets_
[connection_id
] = socket
;
216 socket
->StartReadingRequest(
217 base::Bind(&GnubbyAuthHandlerPosix::OnReadComplete
,
218 base::Unretained(this), connection_id
));
220 // Continue accepting new connections.
224 void GnubbyAuthHandlerPosix::OnReadComplete(int connection_id
) {
225 DCHECK(CalledOnValidThread());
227 ActiveSockets::iterator iter
= active_sockets_
.find(connection_id
);
228 DCHECK(iter
!= active_sockets_
.end());
229 std::string request_data
;
230 if (!iter
->second
->GetAndClearRequestData(&request_data
)) {
231 SendErrorAndCloseActiveSocket(iter
);
234 ProcessGnubbyRequest(connection_id
, request_data
);
235 iter
->second
->StartReadingRequest(
236 base::Bind(&GnubbyAuthHandlerPosix::OnReadComplete
,
237 base::Unretained(this), connection_id
));
240 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
241 DCHECK(CalledOnValidThread());
243 if (!g_gnubby_socket_name
.Get().empty()) {
245 // DeleteFile() is a blocking operation, but so is creation of the unix
246 // socket below. Consider moving this class to a different thread if this
247 // causes any problems. See crbug.com/509807 .
248 base::ThreadRestrictions::ScopedAllowIO allow_io
;
250 // If the file already exists, a socket in use error is returned.
251 base::DeleteFile(g_gnubby_socket_name
.Get(), false);
254 HOST_LOG
<< "Listening for gnubby requests on "
255 << g_gnubby_socket_name
.Get().value();
258 new net::UnixDomainServerSocket(base::Bind(MatchUid
), false));
259 int rv
= auth_socket_
->ListenWithAddressAndPort(
260 g_gnubby_socket_name
.Get().value(), 0, 1);
262 LOG(ERROR
) << "Failed to open socket for gnubby requests";
267 HOST_LOG
<< "No gnubby socket name specified";
271 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
273 const std::string
& request_data
) {
274 HOST_LOG
<< "Received gnubby request: " << GetCommandCode(request_data
);
275 DeliverHostDataMessage(connection_id
, request_data
);
278 GnubbyAuthHandlerPosix::ActiveSockets::iterator
279 GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue
* message
) {
281 if (message
->GetInteger(kConnectionId
, &connection_id
)) {
282 return active_sockets_
.find(connection_id
);
284 return active_sockets_
.end();
287 void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
288 const ActiveSockets::iterator
& iter
) {
289 iter
->second
->SendSshError();
291 active_sockets_
.erase(iter
);
294 void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id
) {
295 HOST_LOG
<< "Gnubby request timed out";
296 ActiveSockets::iterator iter
= active_sockets_
.find(connection_id
);
297 if (iter
!= active_sockets_
.end())
298 SendErrorAndCloseActiveSocket(iter
);
301 } // namespace remoting