Roll src/third_party/WebKit d9c6159:8139f33 (svn 201974:201975)
[chromium-blink-merge.git] / chrome / test / chromedriver / net / adb_client_socket.cc
blob14a10ba7f1d22fc7df4140cc7a335ecbc789dfb5
1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/test/chromedriver/net/adb_client_socket.h"
7 #include "base/bind.h"
8 #include "base/compiler_specific.h"
9 #include "base/strings/string_number_conversions.h"
10 #include "base/strings/string_split.h"
11 #include "base/strings/string_util.h"
12 #include "base/strings/stringprintf.h"
13 #include "net/base/address_list.h"
14 #include "net/base/completion_callback.h"
15 #include "net/base/net_errors.h"
16 #include "net/base/net_util.h"
17 #include "net/socket/tcp_client_socket.h"
19 namespace {
// Size, in bytes, of the scratch buffer handed to every socket Read().
const int kBufferSize = 16384;

// Loopback address used to reach the ADB server (and, in DEBUG_DEVTOOLS
// builds, the desktop DevTools endpoint).
const char kLocalhost[] = "127.0.0.1";

// Four-byte status the ADB server replies with when it accepts a request.
const char kOkayResponse[] = "OKAY";

// printf-style templates for the ADB services used in this file; the single
// %s receives a device serial or an abstract socket name respectively.
const char kHostTransportCommand[] = "host:transport:%s";
const char kLocalAbstractCommand[] = "localabstract:%s";
27 typedef base::Callback<void(int, const std::string&)> CommandCallback;
28 typedef base::Callback<void(int, net::StreamSocket*)> SocketCallback;
// Frames |message| for the ADB server: a four-digit, upper-case hexadecimal
// length prefix followed by the message bytes themselves.
//
// Only the low 16 bits of the length fit in the prefix. Longer messages are
// not rejected here (AdbQuerySocket checks for that before sending).
std::string EncodeMessage(const std::string& message) {
  static const char kHexChars[] = "0123456789ABCDEF";

  // Derive the digits arithmetically. Reading the raw bytes of the size_t
  // only yields the low 16 bits on little-endian hosts.
  const size_t length = message.length() & 0xFFFF;

  std::string result;
  result.reserve(4 + message.length());
  for (int shift = 12; shift >= 0; shift -= 4)
    result.push_back(kHexChars[(length >> shift) & 0xF]);
  result.append(message);
  return result;
}
44 class AdbTransportSocket : public AdbClientSocket {
45 public:
46 AdbTransportSocket(int port,
47 const std::string& serial,
48 const std::string& socket_name,
49 const SocketCallback& callback)
50 : AdbClientSocket(port),
51 serial_(serial),
52 socket_name_(socket_name),
53 callback_(callback) {
54 Connect(base::Bind(&AdbTransportSocket::OnConnected,
55 base::Unretained(this)));
58 private:
59 ~AdbTransportSocket() {}
61 void OnConnected(int result) {
62 if (!CheckNetResultOrDie(result))
63 return;
64 SendCommand(base::StringPrintf(kHostTransportCommand, serial_.c_str()),
65 true, true, base::Bind(&AdbTransportSocket::SendLocalAbstract,
66 base::Unretained(this)));
69 void SendLocalAbstract(int result, const std::string& response) {
70 if (!CheckNetResultOrDie(result))
71 return;
72 SendCommand(base::StringPrintf(kLocalAbstractCommand, socket_name_.c_str()),
73 true, true, base::Bind(&AdbTransportSocket::OnSocketAvailable,
74 base::Unretained(this)));
77 void OnSocketAvailable(int result, const std::string& response) {
78 if (!CheckNetResultOrDie(result))
79 return;
80 callback_.Run(net::OK, socket_.release());
81 delete this;
84 bool CheckNetResultOrDie(int result) {
85 if (result >= 0)
86 return true;
87 callback_.Run(result, NULL);
88 delete this;
89 return false;
92 std::string serial_;
93 std::string socket_name_;
94 SocketCallback callback_;
97 class HttpOverAdbSocket {
98 public:
99 HttpOverAdbSocket(int port,
100 const std::string& serial,
101 const std::string& socket_name,
102 const std::string& request,
103 const CommandCallback& callback)
104 : request_(request),
105 command_callback_(callback),
106 body_pos_(0) {
107 Connect(port, serial, socket_name);
110 HttpOverAdbSocket(int port,
111 const std::string& serial,
112 const std::string& socket_name,
113 const std::string& request,
114 const SocketCallback& callback)
115 : request_(request),
116 socket_callback_(callback),
117 body_pos_(0) {
118 Connect(port, serial, socket_name);
121 private:
122 ~HttpOverAdbSocket() {
125 void Connect(int port,
126 const std::string& serial,
127 const std::string& socket_name) {
128 AdbClientSocket::TransportQuery(
129 port, serial, socket_name,
130 base::Bind(&HttpOverAdbSocket::OnSocketAvailable,
131 base::Unretained(this)));
134 void OnSocketAvailable(int result,
135 net::StreamSocket* socket) {
136 if (!CheckNetResultOrDie(result))
137 return;
139 socket_.reset(socket);
141 scoped_refptr<net::StringIOBuffer> request_buffer =
142 new net::StringIOBuffer(request_);
144 result = socket_->Write(
145 request_buffer.get(),
146 request_buffer->size(),
147 base::Bind(&HttpOverAdbSocket::ReadResponse, base::Unretained(this)));
148 if (result != net::ERR_IO_PENDING)
149 ReadResponse(result);
152 void ReadResponse(int result) {
153 if (!CheckNetResultOrDie(result))
154 return;
156 scoped_refptr<net::IOBuffer> response_buffer =
157 new net::IOBuffer(kBufferSize);
159 result = socket_->Read(response_buffer.get(),
160 kBufferSize,
161 base::Bind(&HttpOverAdbSocket::OnResponseData,
162 base::Unretained(this),
163 response_buffer,
164 -1));
165 if (result != net::ERR_IO_PENDING)
166 OnResponseData(response_buffer, -1, result);
169 void OnResponseData(scoped_refptr<net::IOBuffer> response_buffer,
170 int bytes_total,
171 int result) {
172 if (!CheckNetResultOrDie(result))
173 return;
174 if (result == 0) {
175 CheckNetResultOrDie(net::ERR_CONNECTION_CLOSED);
176 return;
179 response_ += std::string(response_buffer->data(), result);
180 int expected_length = 0;
181 if (bytes_total < 0) {
182 size_t content_pos = response_.find("Content-Length:");
183 if (content_pos != std::string::npos) {
184 size_t endline_pos = response_.find("\n", content_pos);
185 if (endline_pos != std::string::npos) {
186 std::string len = response_.substr(content_pos + 15,
187 endline_pos - content_pos - 15);
188 base::TrimWhitespace(len, base::TRIM_ALL, &len);
189 if (!base::StringToInt(len, &expected_length)) {
190 CheckNetResultOrDie(net::ERR_FAILED);
191 return;
196 body_pos_ = response_.find("\r\n\r\n");
197 if (body_pos_ != std::string::npos) {
198 body_pos_ += 4;
199 bytes_total = body_pos_ + expected_length;
203 if (bytes_total == static_cast<int>(response_.length())) {
204 if (!command_callback_.is_null())
205 command_callback_.Run(body_pos_, response_);
206 else
207 socket_callback_.Run(net::OK, socket_.release());
208 delete this;
209 return;
212 result = socket_->Read(response_buffer.get(),
213 kBufferSize,
214 base::Bind(&HttpOverAdbSocket::OnResponseData,
215 base::Unretained(this),
216 response_buffer,
217 bytes_total));
218 if (result != net::ERR_IO_PENDING)
219 OnResponseData(response_buffer, bytes_total, result);
222 bool CheckNetResultOrDie(int result) {
223 if (result >= 0)
224 return true;
225 if (!command_callback_.is_null())
226 command_callback_.Run(result, std::string());
227 else
228 socket_callback_.Run(result, NULL);
229 delete this;
230 return false;
233 scoped_ptr<net::StreamSocket> socket_;
234 std::string request_;
235 std::string response_;
236 CommandCallback command_callback_;
237 SocketCallback socket_callback_;
238 size_t body_pos_;
241 class AdbQuerySocket : AdbClientSocket {
242 public:
243 AdbQuerySocket(int port,
244 const std::string& query,
245 const CommandCallback& callback)
246 : AdbClientSocket(port),
247 current_query_(0),
248 callback_(callback) {
249 queries_ = base::SplitString(
250 query, "|", base::KEEP_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
251 if (queries_.empty()) {
252 CheckNetResultOrDie(net::ERR_INVALID_ARGUMENT);
253 return;
255 Connect(base::Bind(&AdbQuerySocket::SendNextQuery, base::Unretained(this)));
258 private:
259 ~AdbQuerySocket() {
262 void SendNextQuery(int result) {
263 if (!CheckNetResultOrDie(result))
264 return;
265 std::string query = queries_[current_query_];
266 if (query.length() > 0xFFFF) {
267 CheckNetResultOrDie(net::ERR_MSG_TOO_BIG);
268 return;
270 bool is_void = current_query_ < queries_.size() - 1;
271 // The |shell| command is a special case because it is the only command that
272 // doesn't include a length at the beginning of the data stream.
273 bool has_length = query.find("shell:") != 0;
274 SendCommand(query, is_void, has_length,
275 base::Bind(&AdbQuerySocket::OnResponse, base::Unretained(this)));
278 void OnResponse(int result, const std::string& response) {
279 if (++current_query_ < queries_.size()) {
280 SendNextQuery(net::OK);
281 } else {
282 callback_.Run(result, response);
283 delete this;
287 bool CheckNetResultOrDie(int result) {
288 if (result >= 0)
289 return true;
290 callback_.Run(result, std::string());
291 delete this;
292 return false;
295 std::vector<std::string> queries_;
296 size_t current_query_;
297 CommandCallback callback_;
300 } // namespace
302 // static
303 void AdbClientSocket::AdbQuery(int port,
304 const std::string& query,
305 const CommandCallback& callback) {
306 new AdbQuerySocket(port, query, callback);
#if defined(DEBUG_DEVTOOLS)
// Completion handler for the desktop debugging path of TransportQuery().
// Forwards the connect |result| and |socket| to |callback|. On failure the
// socket is destroyed here and NULL is forwarded instead. Error handlers in
// this file never take ownership of the socket, so passing it would leak it.
static void UseTransportQueryForDesktop(const SocketCallback& callback,
                                        net::StreamSocket* socket,
                                        int result) {
  if (result < 0) {
    delete socket;
    socket = NULL;
  }
  callback.Run(result, socket);
}
#endif  // defined(DEBUG_DEVTOOLS)
317 // static
318 void AdbClientSocket::TransportQuery(int port,
319 const std::string& serial,
320 const std::string& socket_name,
321 const SocketCallback& callback) {
322 #if defined(DEBUG_DEVTOOLS)
323 if (serial.empty()) {
324 // Use plain socket for remote debugging on Desktop (debugging purposes).
325 net::IPAddressNumber ip_number;
326 net::ParseIPLiteralToNumber(kLocalhost, &ip_number);
328 int tcp_port = 0;
329 if (!base::StringToInt(socket_name, &tcp_port))
330 tcp_port = 9222;
332 net::AddressList address_list =
333 net::AddressList::CreateFromIPAddress(ip_number, tcp_port);
334 net::TCPClientSocket* socket = new net::TCPClientSocket(
335 address_list, NULL, net::NetLog::Source());
336 socket->Connect(base::Bind(&UseTransportQueryForDesktop, callback, socket));
337 return;
339 #endif // defined(DEBUG_DEVTOOLS)
340 new AdbTransportSocket(port, serial, socket_name, callback);
343 // static
344 void AdbClientSocket::HttpQuery(int port,
345 const std::string& serial,
346 const std::string& socket_name,
347 const std::string& request_path,
348 const CommandCallback& callback) {
349 new HttpOverAdbSocket(port, serial, socket_name, request_path,
350 callback);
353 // static
354 void AdbClientSocket::HttpQuery(int port,
355 const std::string& serial,
356 const std::string& socket_name,
357 const std::string& request_path,
358 const SocketCallback& callback) {
359 new HttpOverAdbSocket(port, serial, socket_name, request_path,
360 callback);
363 AdbClientSocket::AdbClientSocket(int port)
364 : host_(kLocalhost), port_(port) {
367 AdbClientSocket::~AdbClientSocket() {
370 void AdbClientSocket::Connect(const net::CompletionCallback& callback) {
371 net::IPAddressNumber ip_number;
372 if (!net::ParseIPLiteralToNumber(host_, &ip_number)) {
373 callback.Run(net::ERR_FAILED);
374 return;
377 net::AddressList address_list =
378 net::AddressList::CreateFromIPAddress(ip_number, port_);
379 socket_.reset(new net::TCPClientSocket(address_list, NULL,
380 net::NetLog::Source()));
381 int result = socket_->Connect(callback);
382 if (result != net::ERR_IO_PENDING)
383 callback.Run(result);
386 void AdbClientSocket::SendCommand(const std::string& command,
387 bool is_void,
388 bool has_length,
389 const CommandCallback& callback) {
390 scoped_refptr<net::StringIOBuffer> request_buffer =
391 new net::StringIOBuffer(EncodeMessage(command));
392 int result = socket_->Write(request_buffer.get(),
393 request_buffer->size(),
394 base::Bind(&AdbClientSocket::ReadResponse,
395 base::Unretained(this),
396 callback,
397 is_void,
398 has_length));
399 if (result != net::ERR_IO_PENDING)
400 ReadResponse(callback, is_void, has_length, result);
403 void AdbClientSocket::ReadResponse(const CommandCallback& callback,
404 bool is_void,
405 bool has_length,
406 int result) {
407 if (result < 0) {
408 callback.Run(result, "IO error");
409 return;
411 scoped_refptr<net::IOBuffer> response_buffer =
412 new net::IOBuffer(kBufferSize);
413 result = socket_->Read(response_buffer.get(),
414 kBufferSize,
415 base::Bind(&AdbClientSocket::OnResponseStatus,
416 base::Unretained(this),
417 callback,
418 is_void,
419 has_length,
420 response_buffer));
421 if (result != net::ERR_IO_PENDING)
422 OnResponseStatus(callback, is_void, has_length, response_buffer, result);
425 void AdbClientSocket::OnResponseStatus(
426 const CommandCallback& callback,
427 bool is_void,
428 bool has_length,
429 scoped_refptr<net::IOBuffer> response_buffer,
430 int result) {
431 if (result <= 0) {
432 callback.Run(result == 0 ? net::ERR_CONNECTION_CLOSED : result,
433 "IO error");
434 return;
437 std::string data = std::string(response_buffer->data(), result);
438 if (result < 4) {
439 callback.Run(net::ERR_FAILED, "Response is too short: " + data);
440 return;
443 std::string status = data.substr(0, 4);
444 if (status != kOkayResponse) {
445 callback.Run(net::ERR_FAILED, data);
446 return;
449 data = data.substr(4);
451 if (!is_void) {
452 if (!has_length) {
453 // Payload doesn't include length, so skip straight to reading in data.
454 OnResponseData(callback, data, response_buffer, -1, 0);
455 } else if (data.length() >= 4) {
456 // We've already read the length out of the socket, so we don't need to
457 // read more yet.
458 OnResponseLength(callback, data, response_buffer, 0);
459 } else {
460 // Part or all of the length is still in the socket, so we need to read it
461 // out of the socket before parsing the length.
462 result = socket_->Read(response_buffer.get(),
463 kBufferSize,
464 base::Bind(&AdbClientSocket::OnResponseLength,
465 base::Unretained(this),
466 callback,
467 data,
468 response_buffer));
469 if (result != net::ERR_IO_PENDING)
470 OnResponseLength(callback, data, response_buffer, result);
472 } else {
473 callback.Run(net::OK, data);
477 void AdbClientSocket::OnResponseLength(
478 const CommandCallback& callback,
479 const std::string& response,
480 scoped_refptr<net::IOBuffer> response_buffer,
481 int result) {
482 if (result < 0) {
483 callback.Run(result, "IO error");
484 return;
487 std::string new_response =
488 response + std::string(response_buffer->data(), result);
489 if (new_response.length() < 4) {
490 result = socket_->Read(response_buffer.get(),
491 kBufferSize,
492 base::Bind(&AdbClientSocket::OnResponseLength,
493 base::Unretained(this),
494 callback,
495 new_response,
496 response_buffer));
497 if (result != net::ERR_IO_PENDING)
498 OnResponseLength(callback, new_response, response_buffer, result);
499 } else {
500 int payload_length = 0;
501 if (!base::HexStringToInt(new_response.substr(0, 4), &payload_length)) {
502 callback.Run(net::ERR_FAILED, new_response);
503 return;
506 new_response = new_response.substr(4);
507 int bytes_left = payload_length - new_response.length();
508 OnResponseData(callback, new_response, response_buffer, bytes_left, 0);
512 void AdbClientSocket::OnResponseData(
513 const CommandCallback& callback,
514 const std::string& response,
515 scoped_refptr<net::IOBuffer> response_buffer,
516 int bytes_left,
517 int result) {
518 if (result < 0) {
519 callback.Run(result, "IO error");
520 return;
523 bytes_left -= result;
524 std::string new_response =
525 response + std::string(response_buffer->data(), result);
526 if (bytes_left == 0) {
527 callback.Run(net::OK, new_response);
528 return;
531 // Read tail
532 result = socket_->Read(response_buffer.get(),
533 kBufferSize,
534 base::Bind(&AdbClientSocket::OnResponseData,
535 base::Unretained(this),
536 callback,
537 new_response,
538 response_buffer,
539 bytes_left));
540 if (result > 0)
541 OnResponseData(callback, new_response, response_buffer, bytes_left, result);
542 else if (result != net::ERR_IO_PENDING)
543 callback.Run(net::OK, new_response);