Roll src/third_party/WebKit f36d5e0:68b67cd (svn 193299:193303)
[chromium-blink-merge.git] / remoting / host / gnubby_auth_handler_posix.cc
blob13c8f41c7e12f9344e47ac21df749af8c4865ade
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/host/gnubby_auth_handler_posix.h"
7 #include <unistd.h>
8 #include <utility>
10 #include "base/bind.h"
11 #include "base/files/file_util.h"
12 #include "base/json/json_reader.h"
13 #include "base/json/json_writer.h"
14 #include "base/lazy_instance.h"
15 #include "base/stl_util.h"
16 #include "base/values.h"
17 #include "net/socket/unix_domain_listen_socket_posix.h"
18 #include "remoting/base/logging.h"
19 #include "remoting/host/gnubby_socket.h"
20 #include "remoting/proto/control.pb.h"
21 #include "remoting/protocol/client_stub.h"
23 namespace remoting {
25 namespace {
// Strings used by the JSON payload of gnubby-auth extension messages.

// ExtensionMessage type carried by every gnubby-auth message.
const char kGnubbyAuthMessage[] = "gnubby-auth";

// Dictionary keys.
const char kMessageType[] = "type";
const char kConnectionId[] = "connectionId";
const char kControlOption[] = "option";
const char kDataPayload[] = "data";

// Recognized values of the "type" key.
const char kControlMessage[] = "control";
const char kDataMessage[] = "data";
const char kErrorMessage[] = "error";

// The only accepted value of the "option" key in a control message.
const char kGnubbyAuthV1[] = "auth-v1";
37 // The name of the socket to listen for gnubby requests on.
38 base::LazyInstance<base::FilePath>::Leaky g_gnubby_socket_name =
39 LAZY_INSTANCE_INITIALIZER;
41 // STL predicate to match by a StreamListenSocket pointer.
42 class CompareSocket {
43 public:
44 explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {}
46 bool operator()(const std::pair<int, GnubbySocket*> element) const {
47 return element.second->IsSocket(socket_);
50 private:
51 net::StreamListenSocket* socket_;
54 // Socket authentication function that only allows connections from callers with
55 // the current uid.
56 bool MatchUid(const net::UnixDomainServerSocket::Credentials& credentials) {
57 bool allowed = credentials.user_id == getuid();
58 if (!allowed)
59 HOST_LOG << "Refused socket connection from uid " << credentials.user_id;
60 return allowed;
// Returns the command code of a gnubby message: its first byte as a value in
// the range 0-255, or -1 if |data| is empty.
// The byte is read through unsigned char so that codes of 0x80 and above are
// not sign-extended on platforms where char is signed. The return type is int
// so that the "empty" sentinel really is -1 rather than UINT_MAX.
int GetCommandCode(const std::string& data) {
  return data.empty() ? -1 : static_cast<unsigned char>(data[0]);
}
69 // Creates a string of byte data from a ListValue of numbers. Returns true if
70 // all of the list elements are numbers.
71 bool ConvertListValueToString(base::ListValue* bytes, std::string* out) {
72 out->clear();
74 unsigned int byte_count = bytes->GetSize();
75 if (byte_count != 0) {
76 out->reserve(byte_count);
77 for (unsigned int i = 0; i < byte_count; i++) {
78 int value;
79 if (!bytes->GetInteger(i, &value))
80 return false;
81 out->push_back(static_cast<char>(value));
84 return true;
87 } // namespace
89 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
90 protocol::ClientStub* client_stub)
91 : client_stub_(client_stub), last_connection_id_(0) {
92 DCHECK(client_stub_);
95 GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
96 STLDeleteValues(&active_sockets_);
99 // static
100 scoped_ptr<GnubbyAuthHandler> GnubbyAuthHandler::Create(
101 protocol::ClientStub* client_stub) {
102 return make_scoped_ptr(new GnubbyAuthHandlerPosix(client_stub));
105 // static
106 void GnubbyAuthHandler::SetGnubbySocketName(
107 const base::FilePath& gnubby_socket_name) {
108 g_gnubby_socket_name.Get() = gnubby_socket_name;
111 void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) {
112 DCHECK(CalledOnValidThread());
114 scoped_ptr<base::Value> value(base::JSONReader::Read(message));
115 base::DictionaryValue* client_message;
116 if (value && value->GetAsDictionary(&client_message)) {
117 std::string type;
118 if (!client_message->GetString(kMessageType, &type)) {
119 LOG(ERROR) << "Invalid gnubby-auth message";
120 return;
123 if (type == kControlMessage) {
124 std::string option;
125 if (client_message->GetString(kControlOption, &option) &&
126 option == kGnubbyAuthV1) {
127 CreateAuthorizationSocket();
128 } else {
129 LOG(ERROR) << "Invalid gnubby-auth control option";
131 } else if (type == kDataMessage) {
132 ActiveSockets::iterator iter = GetSocketForMessage(client_message);
133 if (iter != active_sockets_.end()) {
134 base::ListValue* bytes;
135 std::string response;
136 if (client_message->GetList(kDataPayload, &bytes) &&
137 ConvertListValueToString(bytes, &response)) {
138 HOST_LOG << "Sending gnubby response: " << GetCommandCode(response);
139 iter->second->SendResponse(response);
140 } else {
141 LOG(ERROR) << "Invalid gnubby data";
142 SendErrorAndCloseActiveSocket(iter);
144 } else {
145 LOG(ERROR) << "Unknown gnubby-auth data connection";
147 } else if (type == kErrorMessage) {
148 ActiveSockets::iterator iter = GetSocketForMessage(client_message);
149 if (iter != active_sockets_.end()) {
150 HOST_LOG << "Sending gnubby error";
151 SendErrorAndCloseActiveSocket(iter);
152 } else {
153 LOG(ERROR) << "Unknown gnubby-auth error connection";
155 } else {
156 LOG(ERROR) << "Unknown gnubby-auth message type: " << type;
161 void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
162 int connection_id,
163 const std::string& data) const {
164 DCHECK(CalledOnValidThread());
166 base::DictionaryValue request;
167 request.SetString(kMessageType, kDataMessage);
168 request.SetInteger(kConnectionId, connection_id);
170 base::ListValue* bytes = new base::ListValue();
171 for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) {
172 bytes->AppendInteger(static_cast<unsigned char>(*i));
174 request.Set(kDataPayload, bytes);
176 std::string request_json;
177 if (!base::JSONWriter::Write(&request, &request_json)) {
178 LOG(ERROR) << "Failed to create request json";
179 return;
182 protocol::ExtensionMessage message;
183 message.set_type(kGnubbyAuthMessage);
184 message.set_data(request_json);
186 client_stub_->DeliverHostMessage(message);
189 bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting(
190 net::StreamListenSocket* socket) const {
191 return std::find_if(active_sockets_.begin(),
192 active_sockets_.end(),
193 CompareSocket(socket)) != active_sockets_.end();
196 int GnubbyAuthHandlerPosix::GetConnectionIdForTesting(
197 net::StreamListenSocket* socket) const {
198 ActiveSockets::const_iterator iter = std::find_if(
199 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
200 return iter->first;
203 GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting(
204 net::StreamListenSocket* socket) const {
205 ActiveSockets::const_iterator iter = std::find_if(
206 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
207 return iter->second;
210 void GnubbyAuthHandlerPosix::DidAccept(
211 net::StreamListenSocket* server,
212 scoped_ptr<net::StreamListenSocket> socket) {
213 DCHECK(CalledOnValidThread());
215 int connection_id = ++last_connection_id_;
216 active_sockets_[connection_id] =
217 new GnubbySocket(socket.Pass(),
218 base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut,
219 base::Unretained(this),
220 connection_id));
223 void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket,
224 const char* data,
225 int len) {
226 DCHECK(CalledOnValidThread());
228 ActiveSockets::iterator iter = std::find_if(
229 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
230 if (iter != active_sockets_.end()) {
231 GnubbySocket* gnubby_socket = iter->second;
232 gnubby_socket->AddRequestData(data, len);
233 if (gnubby_socket->IsRequestTooLarge()) {
234 SendErrorAndCloseActiveSocket(iter);
235 } else if (gnubby_socket->IsRequestComplete()) {
236 std::string request_data;
237 gnubby_socket->GetAndClearRequestData(&request_data);
238 ProcessGnubbyRequest(iter->first, request_data);
240 } else {
241 LOG(ERROR) << "Received data for unknown connection";
245 void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) {
246 DCHECK(CalledOnValidThread());
248 ActiveSockets::iterator iter = std::find_if(
249 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
250 if (iter != active_sockets_.end()) {
251 delete iter->second;
252 active_sockets_.erase(iter);
256 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
257 DCHECK(CalledOnValidThread());
259 if (!g_gnubby_socket_name.Get().empty()) {
260 // If the file already exists, a socket in use error is returned.
261 base::DeleteFile(g_gnubby_socket_name.Get(), false);
263 HOST_LOG << "Listening for gnubby requests on "
264 << g_gnubby_socket_name.Get().value();
266 auth_socket_ = net::deprecated::UnixDomainListenSocket::CreateAndListen(
267 g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid));
268 if (!auth_socket_.get()) {
269 LOG(ERROR) << "Failed to open socket for gnubby requests";
271 } else {
272 HOST_LOG << "No gnubby socket name specified";
276 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
277 int connection_id,
278 const std::string& request_data) {
279 HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data);
280 DeliverHostDataMessage(connection_id, request_data);
283 GnubbyAuthHandlerPosix::ActiveSockets::iterator
284 GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) {
285 int connection_id;
286 if (message->GetInteger(kConnectionId, &connection_id)) {
287 return active_sockets_.find(connection_id);
289 return active_sockets_.end();
292 void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
293 const ActiveSockets::iterator& iter) {
294 iter->second->SendSshError();
296 delete iter->second;
297 active_sockets_.erase(iter);
300 void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) {
301 HOST_LOG << "Gnubby request timed out";
302 ActiveSockets::iterator iter = active_sockets_.find(connection_id);
303 if (iter != active_sockets_.end())
304 SendErrorAndCloseActiveSocket(iter);
307 } // namespace remoting