1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/host/gnubby_auth_handler_posix.h"
10 #include "base/bind.h"
11 #include "base/files/file_util.h"
12 #include "base/json/json_reader.h"
13 #include "base/json/json_writer.h"
14 #include "base/lazy_instance.h"
15 #include "base/logging.h"
16 #include "base/stl_util.h"
17 #include "base/values.h"
18 #include "net/base/net_errors.h"
19 #include "net/socket/unix_domain_server_socket_posix.h"
20 #include "remoting/base/logging.h"
21 #include "remoting/host/gnubby_socket.h"
22 #include "remoting/proto/control.pb.h"
23 #include "remoting/protocol/client_stub.h"
// Keys and values of the JSON messages exchanged with the client. Note that
// the same string, data, serves both as a message-type value (kDataMessage)
// and as the key of the payload list (kDataPayload).

// Key of the integer connection id carried by data and error messages.
29 const char kConnectionId
[] = "connectionId";
// Message-type value selecting the control branch of DeliverClientMessage().
30 const char kControlMessage
[] = "control";
// Key of the option string read from control messages.
31 const char kControlOption
[] = "option";
// Message-type value for data messages; compared in DeliverClientMessage() and
// written by DeliverHostDataMessage().
32 const char kDataMessage
[] = "data";
// Key of the list of integers that holds the payload bytes of a data message.
33 const char kDataPayload
[] = "data";
// Message-type value selecting the error branch of DeliverClientMessage().
34 const char kErrorMessage
[] = "error";
// Type set on the protocol::ExtensionMessage built by DeliverHostDataMessage().
35 const char kGnubbyAuthMessage
[] = "gnubby-auth";
// The only control option value that DeliverClientMessage() accepts.
36 const char kGnubbyAuthV1
[] = "auth-v1";
// Key of the string that holds the type of a message.
37 const char kMessageType
[] = "type";
// Default request timeout in seconds; the constructor converts it with
// base::TimeDelta::FromSeconds().
39 const int64 kDefaultRequestTimeoutSeconds
= 60;
41 // The name of the socket to listen for gnubby requests on.
// Written by GnubbyAuthHandler::SetGnubbySocketName() and read by
// CreateAuthorizationSocket(), which only tries to listen when the path is
// non-empty.
42 base::LazyInstance
<base::FilePath
>::Leaky g_gnubby_socket_name
=
43 LAZY_INSTANCE_INITIALIZER
;
45 // Socket authentication function that only allows connections from callers with
// the same user id as this process: |credentials.user_id| is compared against
// getuid(). CreateAuthorizationSocket() binds this function and passes it to
// the net::UnixDomainServerSocket it creates.
47 bool MatchUid(const net::UnixDomainServerSocket::Credentials
& credentials
) {
48 bool allowed
= credentials
.user_id
== getuid();
// NOTE(review): the condition guarding this log line and the return statement
// are not visible in this view; presumably the refusal is logged only when
// |allowed| is false and |allowed| is what gets returned - confirm.
50 HOST_LOG
<< "Refused socket connection from uid " << credentials
.user_id
;
// Returns the command code (the first byte of |data|, as a value in the range
// 0-255), or -1 converted to unsigned int when |data| is empty.
//
// The byte is widened through unsigned char first: plain char may be signed,
// in which case casting a byte >= 0x80 straight to unsigned int would
// sign-extend it into a huge value (e.g. 0xFFFFFF81 instead of 0x81).
unsigned int GetCommandCode(const std::string& data) {
  if (data.empty())
    return static_cast<unsigned int>(-1);
  return static_cast<unsigned char>(data[0]);
}
60 // Creates a string of byte data from a ListValue of numbers. Returns true if
61 // all of the list elements are numbers.
// Each element is read with ListValue::GetInteger() and narrowed to a single
// char with static_cast, so |out| receives one byte per list element.
// NOTE(review): the declaration of |value| and the return statements are not
// visible in this view, and no code shown here rejects integers that do not fit
// in a byte - confirm whether that is intended.
62 bool ConvertListValueToString(base::ListValue
* bytes
, std::string
* out
) {
65 unsigned int byte_count
= bytes
->GetSize();
66 if (byte_count
!= 0) {
// Size the output once up front rather than growing it per element.
67 out
->reserve(byte_count
);
68 for (unsigned int i
= 0; i
< byte_count
; i
++) {
// Presumably GetInteger() fails for an element that is not an integer; the
// statement taken on failure is not visible in this view.
70 if (!bytes
->GetInteger(i
, &value
))
72 out
->push_back(static_cast<char>(value
));
// Constructs a handler that delivers host messages through |client_stub| (see
// DeliverHostDataMessage()). The connection counter starts at 0; OnAccepted()
// pre-increments it, so the first accepted connection gets id 1.
// NOTE(review): the member initialized from kDefaultRequestTimeoutSeconds is
// not visible in this view (presumably request_timeout_) - confirm.
80 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
81 protocol::ClientStub
* client_stub
)
82 : client_stub_(client_stub
),
83 last_connection_id_(0),
85 base::TimeDelta::FromSeconds(kDefaultRequestTimeoutSeconds
)) {
// Cleans up the GnubbySocket objects still stored in |active_sockets_|. They
// are allocated with new in OnAccepted() and held in the map as raw pointers;
// STLDeleteValues presumably deletes each mapped value - confirm.
89 GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
90 STLDeleteValues(&active_sockets_
);
// Factory for the POSIX implementation: returns a new GnubbyAuthHandlerPosix
// for |client_stub|, wrapped in a scoped_ptr<GnubbyAuthHandler>.
94 scoped_ptr
<GnubbyAuthHandler
> GnubbyAuthHandler::Create(
95 protocol::ClientStub
* client_stub
) {
96 return make_scoped_ptr(new GnubbyAuthHandlerPosix(client_stub
));
// Stores |gnubby_socket_name| in the file-level g_gnubby_socket_name, which
// CreateAuthorizationSocket() later uses as the path to listen on.
100 void GnubbyAuthHandler::SetGnubbySocketName(
101 const base::FilePath
& gnubby_socket_name
) {
102 g_gnubby_socket_name
.Get() = gnubby_socket_name
;
// Handles a JSON message received from the client. The message must parse to a
// dictionary whose kMessageType string selects one of three branches:
//   kControlMessage - if the kControlOption string equals kGnubbyAuthV1, calls
//                     CreateAuthorizationSocket().
//   kDataMessage    - looks up the socket for the connection id in the message,
//                     converts the kDataPayload list to a byte string and
//                     passes it to SendResponse() on that socket.
//   kErrorMessage   - looks up the socket for the connection id in the message
//                     and calls SendErrorAndCloseActiveSocket() on it.
// Problems are reported with LOG(ERROR).
// NOTE(review): several lines of this function (the declarations of |type| and
// |option|, the else lines and the closing braces) are not visible in this
// view, so the pairing of each error log with its condition should be confirmed
// against the full file.
105 void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string
& message
) {
106 DCHECK(CalledOnValidThread());
// Parse the message; nothing below runs unless the result is a dictionary.
108 scoped_ptr
<base::Value
> value
= base::JSONReader::Read(message
);
109 base::DictionaryValue
* client_message
;
110 if (value
&& value
->GetAsDictionary(&client_message
)) {
// Every message needs a type string.
112 if (!client_message
->GetString(kMessageType
, &type
)) {
113 LOG(ERROR
) << "Invalid gnubby-auth message";
// Control message: the only option acted on is kGnubbyAuthV1.
117 if (type
== kControlMessage
) {
119 if (client_message
->GetString(kControlOption
, &option
) &&
120 option
== kGnubbyAuthV1
) {
121 CreateAuthorizationSocket();
123 LOG(ERROR
) << "Invalid gnubby-auth control option";
// Data message: forward the payload bytes to the socket that made the request.
125 } else if (type
== kDataMessage
) {
126 ActiveSockets::iterator iter
= GetSocketForMessage(client_message
);
127 if (iter
!= active_sockets_
.end()) {
128 base::ListValue
* bytes
;
129 std::string response
;
130 if (client_message
->GetList(kDataPayload
, &bytes
) &&
131 ConvertListValueToString(bytes
, &response
)) {
132 HOST_LOG
<< "Sending gnubby response: " << GetCommandCode(response
);
133 iter
->second
->SendResponse(response
);
// Presumably reached when the payload is missing or malformed: the connection
// is failed and removed from |active_sockets_|.
135 LOG(ERROR
) << "Invalid gnubby data";
136 SendErrorAndCloseActiveSocket(iter
);
139 LOG(ERROR
) << "Unknown gnubby-auth data connection";
// Error message: fail the matching connection, if there is one.
141 } else if (type
== kErrorMessage
) {
142 ActiveSockets::iterator iter
= GetSocketForMessage(client_message
);
143 if (iter
!= active_sockets_
.end()) {
144 HOST_LOG
<< "Sending gnubby error";
145 SendErrorAndCloseActiveSocket(iter
);
147 LOG(ERROR
) << "Unknown gnubby-auth error connection";
150 LOG(ERROR
) << "Unknown gnubby-auth message type: " << type
;
// Sends |data| to the client as a gnubby-auth extension message. The payload is
// a JSON dictionary holding the message type (kDataMessage), the connection id
// and the bytes of |data| as a list of integers.
// NOTE(review): the parameter line that declares |connection_id| is not visible
// in this view.
155 void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
157 const std::string
& data
) const {
158 DCHECK(CalledOnValidThread());
160 base::DictionaryValue request
;
161 request
.SetString(kMessageType
, kDataMessage
);
162 request
.SetInteger(kConnectionId
, connection_id
);
// Each byte goes through unsigned char so the list holds values from 0 to 255
// rather than negative numbers for bytes with the high bit set.
164 base::ListValue
* bytes
= new base::ListValue();
165 for (std::string::const_iterator i
= data
.begin(); i
!= data
.end(); ++i
) {
166 bytes
->AppendInteger(static_cast<unsigned char>(*i
));
// Presumably Set() takes ownership of |bytes|, since nothing visible here
// deletes it - confirm.
168 request
.Set(kDataPayload
, bytes
);
170 std::string request_json
;
171 if (!base::JSONWriter::Write(request
, &request_json
)) {
172 LOG(ERROR
) << "Failed to create request json";
// Wrap the JSON text in an extension message and hand it to the client stub.
176 protocol::ExtensionMessage message
;
177 message
.set_type(kGnubbyAuthMessage
);
178 message
.set_data(request_json
);
180 client_stub_
->DeliverHostMessage(message
);
// Test hook: returns the number of entries currently in |active_sockets_|.
183 size_t GnubbyAuthHandlerPosix::GetActiveSocketsMapSizeForTest() const {
184 return active_sockets_
.size();
// Test hook: replaces |request_timeout_|, the timeout that OnAccepted() passes
// to each GnubbySocket it creates afterwards.
187 void GnubbyAuthHandlerPosix::SetRequestTimeoutForTest(
188 const base::TimeDelta
& timeout
) {
189 request_timeout_
= timeout
;
// Asks |auth_socket_| to accept a connection, with OnAccepted() bound as the
// callback. base::Unretained(this) means the callback holds a raw pointer to
// this handler.
// NOTE(review): the first argument of Accept() and the statement controlled by
// the final if are not visible in this view; presumably a result other than
// net::ERR_IO_PENDING is passed straight to OnAccepted() - confirm.
192 void GnubbyAuthHandlerPosix::DoAccept() {
193 int result
= auth_socket_
->Accept(
195 base::Bind(&GnubbyAuthHandlerPosix::OnAccepted
, base::Unretained(this)));
196 if (result
!= net::ERR_IO_PENDING
)
// Called with the result of an accept. Wraps the accepted socket in a new
// GnubbySocket, registers it in |active_sockets_| under a fresh connection id
// and starts reading the first request from it.
// NOTE(review): the condition that guards the error log and the statement that
// follows the final comment are not visible in this view; presumably the
// function returns after logging a failure and ends by calling DoAccept() -
// confirm.
200 void GnubbyAuthHandlerPosix::OnAccepted(int result
) {
201 DCHECK(CalledOnValidThread());
202 DCHECK_NE(net::ERR_IO_PENDING
, result
);
205 LOG(ERROR
) << "Error in accepting a new connection";
// Ids are handed out by pre-incrementing the counter.
209 int connection_id
= ++last_connection_id_
;
// The socket is allocated with new and stored in the map as a raw pointer. Its
// timeout callback is RequestTimedOut() bound to this connection id.
210 GnubbySocket
* socket
=
211 new GnubbySocket(accept_socket_
.Pass(), request_timeout_
,
212 base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut
,
213 base::Unretained(this), connection_id
));
214 active_sockets_
[connection_id
] = socket
;
// OnReadComplete() is bound to this connection id as the read callback.
215 socket
->StartReadingRequest(
216 base::Bind(&GnubbyAuthHandlerPosix::OnReadComplete
,
217 base::Unretained(this), connection_id
));
219 // Continue accepting new connections.
// Read callback for the socket registered under |connection_id|. Takes the
// request data from the socket, forwards it with ProcessGnubbyRequest() and
// then starts reading the next request on the same socket.
// NOTE(review): the end of the failure branch is not visible in this view.
// SendErrorAndCloseActiveSocket() erases |iter| from the map, so that branch
// has to leave the function before |iter| is used again below - confirm.
223 void GnubbyAuthHandlerPosix::OnReadComplete(int connection_id
) {
224 DCHECK(CalledOnValidThread());
// The entry is only checked with DCHECK before |iter| is dereferenced.
226 ActiveSockets::iterator iter
= active_sockets_
.find(connection_id
);
227 DCHECK(iter
!= active_sockets_
.end());
228 std::string request_data
;
229 if (!iter
->second
->GetAndClearRequestData(&request_data
)) {
230 SendErrorAndCloseActiveSocket(iter
);
233 ProcessGnubbyRequest(connection_id
, request_data
);
// Re-arm the read with this same callback.
234 iter
->second
->StartReadingRequest(
235 base::Bind(&GnubbyAuthHandlerPosix::OnReadComplete
,
236 base::Unretained(this), connection_id
));
// Starts listening for gnubby requests on the path held in
// g_gnubby_socket_name. When a path is configured, any existing file at that
// path is deleted first, a net::UnixDomainServerSocket that authenticates with
// MatchUid is created, and ListenWithAddressAndPort() is called on
// |auth_socket_|.
// NOTE(review): the statement that receives the new socket (presumably a reset
// of |auth_socket_|), the check of |rv|, whatever follows a successful listen
// (presumably DoAccept()) and the else line before the final log are not
// visible in this view - confirm.
239 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
240 DCHECK(CalledOnValidThread());
242 if (!g_gnubby_socket_name
.Get().empty()) {
243 // If the file already exists, a socket in use error is returned.
244 base::DeleteFile(g_gnubby_socket_name
.Get(), false);
246 HOST_LOG
<< "Listening for gnubby requests on "
247 << g_gnubby_socket_name
.Get().value();
250 new net::UnixDomainServerSocket(base::Bind(MatchUid
), false));
251 int rv
= auth_socket_
->ListenWithAddressAndPort(
252 g_gnubby_socket_name
.Get().value(), 0, 1);
254 LOG(ERROR
) << "Failed to open socket for gnubby requests";
// Presumably the branch taken when no socket path has been set.
259 HOST_LOG
<< "No gnubby socket name specified";
// Logs the command code of |request_data| and forwards the request to the
// client with DeliverHostDataMessage().
// NOTE(review): the parameter line that declares |connection_id| is not visible
// in this view.
263 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
265 const std::string
& request_data
) {
266 HOST_LOG
<< "Received gnubby request: " << GetCommandCode(request_data
);
267 DeliverHostDataMessage(connection_id
, request_data
);
// Returns the |active_sockets_| entry for the connection id stored under
// kConnectionId in |message|. The result is active_sockets_.end() when the id
// cannot be read as an integer, and also when no entry has that id (that is
// what find() returns).
// NOTE(review): the declaration of |connection_id| is not visible in this view.
270 GnubbyAuthHandlerPosix::ActiveSockets::iterator
271 GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue
* message
) {
273 if (message
->GetInteger(kConnectionId
, &connection_id
)) {
274 return active_sockets_
.find(connection_id
);
276 return active_sockets_
.end();
// Calls SendSshError() on the socket that |iter| refers to and removes the
// entry from |active_sockets_|. |iter| must be a valid iterator into
// |active_sockets_| (it is dereferenced without a check) and is invalid after
// this call.
// NOTE(review): no delete of iter->second is visible between these two
// statements in this view. The map stores raw pointers allocated with new in
// OnAccepted(), so confirm that the socket is freed before the entry is erased.
279 void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
280 const ActiveSockets::iterator
& iter
) {
281 iter
->second
->SendSshError();
283 active_sockets_
.erase(iter
);
// Timeout callback bound in OnAccepted() for the socket registered under
// |connection_id|. If the connection is still in |active_sockets_| it is failed
// and removed with SendErrorAndCloseActiveSocket(); otherwise only the log line
// is emitted.
286 void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id
) {
287 HOST_LOG
<< "Gnubby request timed out";
288 ActiveSockets::iterator iter
= active_sockets_
.find(connection_id
);
289 if (iter
!= active_sockets_
.end())
290 SendErrorAndCloseActiveSocket(iter
);
293 } // namespace remoting