1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/devtools/device/port_forwarding_controller.h"
10 #include "base/compiler_specific.h"
11 #include "base/memory/singleton.h"
12 #include "base/message_loop/message_loop.h"
13 #include "base/prefs/pref_service.h"
14 #include "base/profiler/scoped_tracker.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/strings/string_util.h"
17 #include "base/strings/stringprintf.h"
18 #include "base/threading/non_thread_safe.h"
19 #include "chrome/browser/devtools/devtools_protocol.h"
20 #include "chrome/browser/devtools/devtools_protocol_constants.h"
21 #include "chrome/browser/profiles/profile.h"
22 #include "chrome/common/pref_names.h"
23 #include "components/keyed_service/content/browser_context_dependency_manager.h"
24 #include "content/public/browser/browser_thread.h"
25 #include "net/base/address_list.h"
26 #include "net/base/io_buffer.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/net_util.h"
29 #include "net/dns/host_resolver.h"
30 #include "net/socket/tcp_client_socket.h"
32 using content::BrowserThread
;
36 const int kBufferSize
= 16 * 1024;
40 kStatusDisconnecting
= -2,
41 kStatusConnecting
= -1,
43 // Positive values are used to count open connections.
46 namespace tethering
= ::chrome::devtools::Tethering
;
48 static const char kDevToolsRemoteBrowserTarget
[] = "/devtools/browser";
50 class SocketTunnel
: public base::NonThreadSafe
{
52 typedef base::Callback
<void(int)> CounterCallback
;
54 static void StartTunnel(const std::string
& host
,
56 const CounterCallback
& callback
,
58 scoped_ptr
<net::StreamSocket
> socket
) {
61 SocketTunnel
* tunnel
= new SocketTunnel(callback
);
62 tunnel
->Start(socket
.Pass(), host
, port
);
66 explicit SocketTunnel(const CounterCallback
& callback
)
68 pending_destruction_(false),
70 about_to_destroy_(false) {
74 void Start(scoped_ptr
<net::StreamSocket
> socket
,
75 const std::string
& host
, int port
) {
76 remote_socket_
.swap(socket
);
78 host_resolver_
= net::HostResolver::CreateDefaultResolver(NULL
);
79 net::HostResolver::RequestInfo
request_info(net::HostPortPair(host
, port
));
80 int result
= host_resolver_
->Resolve(
82 net::DEFAULT_PRIORITY
,
84 base::Bind(&SocketTunnel::OnResolved
, base::Unretained(this)),
87 if (result
!= net::ERR_IO_PENDING
)
91 void OnResolved(int result
) {
92 // TODO(vadimt): Remove ScopedTracker below once crbug.com/436634 is fixed.
93 tracked_objects::ScopedTracker
tracking_profile(
94 FROM_HERE_WITH_EXPLICIT_FUNCTION("436634 SocketTunnel::OnResolved"));
101 host_socket_
.reset(new net::TCPClientSocket(address_list_
, NULL
,
102 net::NetLog::Source()));
103 result
= host_socket_
->Connect(base::Bind(&SocketTunnel::OnConnected
,
104 base::Unretained(this)));
105 if (result
!= net::ERR_IO_PENDING
)
110 about_to_destroy_
= true;
112 host_socket_
->Disconnect();
114 remote_socket_
->Disconnect();
118 void OnConnected(int result
) {
124 ++pending_writes_
; // avoid SelfDestruct in first Pump
125 Pump(host_socket_
.get(), remote_socket_
.get());
127 if (pending_destruction_
) {
130 Pump(remote_socket_
.get(), host_socket_
.get());
134 void Pump(net::StreamSocket
* from
, net::StreamSocket
* to
) {
135 scoped_refptr
<net::IOBuffer
> buffer
= new net::IOBuffer(kBufferSize
);
136 int result
= from
->Read(
140 &SocketTunnel::OnRead
, base::Unretained(this), from
, to
, buffer
));
141 if (result
!= net::ERR_IO_PENDING
)
142 OnRead(from
, to
, buffer
, result
);
145 void OnRead(net::StreamSocket
* from
,
146 net::StreamSocket
* to
,
147 scoped_refptr
<net::IOBuffer
> buffer
,
155 scoped_refptr
<net::DrainableIOBuffer
> drainable
=
156 new net::DrainableIOBuffer(buffer
.get(), total
);
159 result
= to
->Write(drainable
.get(),
161 base::Bind(&SocketTunnel::OnWritten
,
162 base::Unretained(this),
166 if (result
!= net::ERR_IO_PENDING
)
167 OnWritten(drainable
, from
, to
, result
);
170 void OnWritten(scoped_refptr
<net::DrainableIOBuffer
> drainable
,
171 net::StreamSocket
* from
,
172 net::StreamSocket
* to
,
180 drainable
->DidConsume(result
);
181 if (drainable
->BytesRemaining() > 0) {
183 result
= to
->Write(drainable
.get(),
184 drainable
->BytesRemaining(),
185 base::Bind(&SocketTunnel::OnWritten
,
186 base::Unretained(this),
190 if (result
!= net::ERR_IO_PENDING
)
191 OnWritten(drainable
, from
, to
, result
);
195 if (pending_destruction_
) {
202 void SelfDestruct() {
203 // In case one of the connections closes, we could get here
204 // from another one due to Disconnect firing back on all
206 if (about_to_destroy_
)
208 if (pending_writes_
> 0) {
209 pending_destruction_
= true;
215 scoped_ptr
<net::StreamSocket
> remote_socket_
;
216 scoped_ptr
<net::StreamSocket
> host_socket_
;
217 scoped_ptr
<net::HostResolver
> host_resolver_
;
218 net::AddressList address_list_
;
220 bool pending_destruction_
;
221 CounterCallback callback_
;
222 bool about_to_destroy_
;
227 class PortForwardingController::Connection
228 : public AndroidDeviceManager::AndroidWebSocket::Delegate
{
230 Connection(PortForwardingController
* controller
,
231 scoped_refptr
<DevToolsAndroidBridge::RemoteBrowser
> browser
,
232 const ForwardingMap
& forwarding_map
);
233 ~Connection() override
;
235 const PortStatusMap
& GetPortStatusMap();
237 void UpdateForwardingMap(const ForwardingMap
& new_forwarding_map
);
239 scoped_refptr
<DevToolsAndroidBridge::RemoteBrowser
> browser() {
244 friend struct content::BrowserThread::DeleteOnThread
<
245 content::BrowserThread::UI
>;
246 friend class base::DeleteHelper
<Connection
>;
248 typedef std::map
<int, std::string
> ForwardingMap
;
249 typedef base::Callback
<void(PortStatus
)> CommandCallback
;
250 typedef std::map
<int, CommandCallback
> CommandCallbackMap
;
252 void SerializeChanges(const std::string
& method
,
253 const ForwardingMap
& old_map
,
254 const ForwardingMap
& new_map
);
256 void SendCommand(const std::string
& method
, int port
);
257 bool ProcessResponse(const std::string
& json
);
259 void ProcessBindResponse(int port
, PortStatus status
);
260 void ProcessUnbindResponse(int port
, PortStatus status
);
262 static void UpdateSocketCountOnHandlerThread(
263 base::WeakPtr
<Connection
> weak_connection
, int port
, int increment
);
264 void UpdateSocketCount(int port
, int increment
);
266 // DevToolsAndroidBridge::AndroidWebSocket::Delegate implementation:
267 void OnSocketOpened() override
;
268 void OnFrameRead(const std::string
& message
) override
;
269 void OnSocketClosed() override
;
271 PortForwardingController
* controller_
;
272 scoped_refptr
<DevToolsAndroidBridge::RemoteBrowser
> browser_
;
273 scoped_ptr
<AndroidDeviceManager::AndroidWebSocket
> web_socket_
;
276 ForwardingMap forwarding_map_
;
277 CommandCallbackMap pending_responses_
;
278 PortStatusMap port_status_
;
279 base::WeakPtrFactory
<Connection
> weak_factory_
;
281 DISALLOW_COPY_AND_ASSIGN(Connection
);
284 PortForwardingController::Connection::Connection(
285 PortForwardingController
* controller
,
286 scoped_refptr
<DevToolsAndroidBridge::RemoteBrowser
> browser
,
287 const ForwardingMap
& forwarding_map
)
288 : controller_(controller
),
292 forwarding_map_(forwarding_map
),
293 weak_factory_(this) {
294 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
295 controller_
->registry_
[browser
->serial()] = this;
296 scoped_refptr
<AndroidDeviceManager::Device
> device(
297 controller_
->bridge_
->FindDevice(browser
->serial()));
298 DCHECK(device
.get());
300 device
->CreateWebSocket(browser
->socket(),
301 kDevToolsRemoteBrowserTarget
, this));
304 PortForwardingController::Connection::~Connection() {
305 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
306 DCHECK(controller_
->registry_
.find(browser_
->serial()) !=
307 controller_
->registry_
.end());
308 controller_
->registry_
.erase(browser_
->serial());
311 void PortForwardingController::Connection::UpdateForwardingMap(
312 const ForwardingMap
& new_forwarding_map
) {
313 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
315 SerializeChanges(tethering::unbind::kName
,
316 new_forwarding_map
, forwarding_map_
);
317 SerializeChanges(tethering::bind::kName
,
318 forwarding_map_
, new_forwarding_map
);
320 forwarding_map_
= new_forwarding_map
;
323 void PortForwardingController::Connection::SerializeChanges(
324 const std::string
& method
,
325 const ForwardingMap
& old_map
,
326 const ForwardingMap
& new_map
) {
327 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
328 for (ForwardingMap::const_iterator
new_it(new_map
.begin());
329 new_it
!= new_map
.end(); ++new_it
) {
330 int port
= new_it
->first
;
331 const std::string
& location
= new_it
->second
;
332 ForwardingMap::const_iterator old_it
= old_map
.find(port
);
333 if (old_it
!= old_map
.end() && old_it
->second
== location
)
334 continue; // The port points to the same location in both configs, skip.
336 SendCommand(method
, port
);
340 void PortForwardingController::Connection::SendCommand(
341 const std::string
& method
, int port
) {
342 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
343 scoped_ptr
<base::DictionaryValue
> params(new base::DictionaryValue
);
344 if (method
== tethering::bind::kName
) {
345 params
->SetInteger(tethering::bind::kParamPort
, port
);
347 DCHECK_EQ(tethering::unbind::kName
, method
);
348 params
->SetInteger(tethering::unbind::kParamPort
, port
);
350 int id
= ++command_id_
;
352 if (method
== tethering::bind::kName
) {
353 pending_responses_
[id
] =
354 base::Bind(&Connection::ProcessBindResponse
,
355 base::Unretained(this), port
);
356 #if defined(DEBUG_DEVTOOLS)
357 port_status_
[port
] = kStatusConnecting
;
358 #endif // defined(DEBUG_DEVTOOLS)
360 PortStatusMap::iterator it
= port_status_
.find(port
);
361 if (it
!= port_status_
.end() && it
->second
== kStatusError
) {
362 // The bind command failed on this port, do not attempt unbind.
363 port_status_
.erase(it
);
367 pending_responses_
[id
] =
368 base::Bind(&Connection::ProcessUnbindResponse
,
369 base::Unretained(this), port
);
370 #if defined(DEBUG_DEVTOOLS)
371 port_status_
[port
] = kStatusDisconnecting
;
372 #endif // defined(DEBUG_DEVTOOLS)
375 web_socket_
->SendFrame(
376 DevToolsProtocol::SerializeCommand(id
, method
, params
.Pass()));
379 bool PortForwardingController::Connection::ProcessResponse(
380 const std::string
& message
) {
383 if (!DevToolsProtocol::ParseResponse(message
, &id
, &error_code
))
386 CommandCallbackMap::iterator it
= pending_responses_
.find(id
);
387 if (it
== pending_responses_
.end())
390 it
->second
.Run(error_code
? kStatusError
: kStatusOK
);
391 pending_responses_
.erase(it
);
395 void PortForwardingController::Connection::ProcessBindResponse(
396 int port
, PortStatus status
) {
397 port_status_
[port
] = status
;
400 void PortForwardingController::Connection::ProcessUnbindResponse(
401 int port
, PortStatus status
) {
402 PortStatusMap::iterator it
= port_status_
.find(port
);
403 if (it
== port_status_
.end())
405 if (status
== kStatusError
)
408 port_status_
.erase(it
);
412 void PortForwardingController::Connection::UpdateSocketCountOnHandlerThread(
413 base::WeakPtr
<Connection
> weak_connection
, int port
, int increment
) {
414 BrowserThread::PostTask(BrowserThread::UI
, FROM_HERE
,
415 base::Bind(&Connection::UpdateSocketCount
,
416 weak_connection
, port
, increment
));
419 void PortForwardingController::Connection::UpdateSocketCount(
420 int port
, int increment
) {
421 #if defined(DEBUG_DEVTOOLS)
422 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
423 PortStatusMap::iterator it
= port_status_
.find(port
);
424 if (it
== port_status_
.end())
426 if (it
->second
< 0 || (it
->second
== 0 && increment
< 0))
428 it
->second
+= increment
;
429 #endif // defined(DEBUG_DEVTOOLS)
432 const PortForwardingController::PortStatusMap
&
433 PortForwardingController::Connection::GetPortStatusMap() {
434 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
438 void PortForwardingController::Connection::OnSocketOpened() {
439 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
441 SerializeChanges(tethering::bind::kName
, ForwardingMap(), forwarding_map_
);
444 void PortForwardingController::Connection::OnSocketClosed() {
448 void PortForwardingController::Connection::OnFrameRead(
449 const std::string
& message
) {
450 DCHECK_CURRENTLY_ON(BrowserThread::UI
);
451 if (ProcessResponse(message
))
455 scoped_ptr
<base::DictionaryValue
> params
;
456 if (!DevToolsProtocol::ParseNotification(message
, &method
, ¶ms
))
459 if (method
!= tethering::accepted::kName
|| !params
)
463 std::string connection_id
;
464 if (!params
->GetInteger(tethering::accepted::kParamPort
, &port
) ||
465 !params
->GetString(tethering::accepted::kParamConnectionId
,
469 std::map
<int, std::string
>::iterator it
= forwarding_map_
.find(port
);
470 if (it
== forwarding_map_
.end())
473 std::string location
= it
->second
;
474 std::vector
<std::string
> tokens
;
475 Tokenize(location
, ":", &tokens
);
476 int destination_port
= 0;
477 if (tokens
.size() != 2 || !base::StringToInt(tokens
[1], &destination_port
))
479 std::string destination_host
= tokens
[0];
481 SocketTunnel::CounterCallback callback
=
482 base::Bind(&Connection::UpdateSocketCountOnHandlerThread
,
483 weak_factory_
.GetWeakPtr(), port
);
485 scoped_refptr
<AndroidDeviceManager::Device
> device(
486 controller_
->bridge_
->FindDevice(browser_
->serial()));
487 DCHECK(device
.get());
489 connection_id
.c_str(),
490 base::Bind(&SocketTunnel::StartTunnel
,
496 PortForwardingController::PortForwardingController(
498 DevToolsAndroidBridge
* bridge
)
500 pref_service_(profile
->GetPrefs()) {
501 pref_change_registrar_
.Init(pref_service_
);
502 base::Closure callback
= base::Bind(
503 &PortForwardingController::OnPrefsChange
, base::Unretained(this));
504 pref_change_registrar_
.Add(prefs::kDevToolsPortForwardingEnabled
, callback
);
505 pref_change_registrar_
.Add(prefs::kDevToolsPortForwardingConfig
, callback
);
509 PortForwardingController::~PortForwardingController() {}
511 PortForwardingController::ForwardingStatus
512 PortForwardingController::DeviceListChanged(
513 const DevToolsAndroidBridge::RemoteDevices
& devices
) {
514 ForwardingStatus status
;
515 if (forwarding_map_
.empty())
518 for (const auto& device
: devices
) {
519 if (!device
->is_connected())
521 Registry::iterator rit
= registry_
.find(device
->serial());
522 if (rit
== registry_
.end()) {
523 if (device
->browsers().size() > 0)
524 new Connection(this, device
->browsers()[0], forwarding_map_
);
526 status
.push_back(std::make_pair(rit
->second
->browser(),
527 rit
->second
->GetPortStatusMap()));
533 void PortForwardingController::OnPrefsChange() {
534 forwarding_map_
.clear();
536 if (pref_service_
->GetBoolean(prefs::kDevToolsPortForwardingEnabled
)) {
537 const base::DictionaryValue
* dict
=
538 pref_service_
->GetDictionary(prefs::kDevToolsPortForwardingConfig
);
539 for (base::DictionaryValue::Iterator
it(*dict
);
540 !it
.IsAtEnd(); it
.Advance()) {
542 std::string location
;
543 if (base::StringToInt(it
.key(), &port_num
) &&
544 dict
->GetString(it
.key(), &location
))
545 forwarding_map_
[port_num
] = location
;
549 if (!forwarding_map_
.empty()) {
552 std::vector
<Connection
*> registry_copy
;
553 for (Registry::iterator it
= registry_
.begin();
554 it
!= registry_
.end(); ++it
) {
555 registry_copy
.push_back(it
->second
);
557 STLDeleteElements(®istry_copy
);
561 void PortForwardingController::UpdateConnections() {
562 for (Registry::iterator it
= registry_
.begin(); it
!= registry_
.end(); ++it
)
563 it
->second
->UpdateForwardingMap(forwarding_map_
);