1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/socks_client_socket_pool.h"
8 #include "base/bind_helpers.h"
9 #include "base/time/time.h"
10 #include "base/values.h"
11 #include "net/base/net_errors.h"
12 #include "net/socket/client_socket_factory.h"
13 #include "net/socket/client_socket_handle.h"
14 #include "net/socket/client_socket_pool_base.h"
15 #include "net/socket/socks5_client_socket.h"
16 #include "net/socket/socks_client_socket.h"
17 #include "net/socket/transport_client_socket_pool.h"
// Constructs SOCKS socket params from two inputs:
//   |proxy_server|   - the transport params used to reach the proxy
//   |host_port_pair| - the destination the proxy is asked to connect to
// When transport_params_ is non-null, ignore_limits_ is copied from it. The
// `ignore_limits_ = false` assignment appears intended as the fallback for a
// null transport_params_.
// NOTE(review): no `else` is visible before that fallback. As shown, it would
// always overwrite the value copied above. The embedded source numbering skips
// lines 23, 27, 30 and 32 here, so the `else` is presumably just missing from
// this view -- confirm against the full file.
21 SOCKSSocketParams::SOCKSSocketParams(
22 const scoped_refptr
<TransportSocketParams
>& proxy_server
,
24 const HostPortPair
& host_port_pair
)
25 : transport_params_(proxy_server
),
26 destination_(host_port_pair
),
28 if (transport_params_
.get())
29 ignore_limits_
= transport_params_
->ignore_limits();
31 ignore_limits_
= false;
// Out-of-line destructor; its body is empty.
34 SOCKSSocketParams::~SOCKSSocketParams() {}
36 // SOCKSConnectJobs will time out after this many seconds. Note this is on
37 // top of the timeout for the transport socket.
// This constant is used in two places:
//   - SOCKSConnectJob::DoTransportConnectComplete() resets the job timer to
//     this value once the transport connect is done.
//   - The connect-job factory adds it to the transport pool's
//     ConnectionTimeout().
38 static const int kSOCKSConnectJobTimeoutInSeconds
= 30;
// Creates a connect job for |socks_params| that works in two stages:
//   1. Connect a transport socket to the proxy (DoTransportConnect).
//   2. Run the SOCKS handshake on top of it (DoSOCKSConnect).
// The ConnectJob base receives the group name, overall timeout, priority,
// delegate, and a BoundNetLog of type NetLog::SOURCE_CONNECT_JOB.
// callback_ binds OnIOComplete with base::Unretained(this), so the callback
// does not keep the job alive.
// NOTE(review): the `delegate` and `net_log` parameters referenced in the
// initializer list are on source lines 47-48, which are skipped in this view.
40 SOCKSConnectJob::SOCKSConnectJob(
41 const std::string
& group_name
,
42 RequestPriority priority
,
43 const scoped_refptr
<SOCKSSocketParams
>& socks_params
,
44 const base::TimeDelta
& timeout_duration
,
45 TransportClientSocketPool
* transport_pool
,
46 HostResolver
* host_resolver
,
49 : ConnectJob(group_name
, timeout_duration
, priority
, delegate
,
50 BoundNetLog::Make(net_log
, NetLog::SOURCE_CONNECT_JOB
)),
51 socks_params_(socks_params
),
52 transport_pool_(transport_pool
),
53 resolver_(host_resolver
),
54 callback_(base::Bind(&SOCKSConnectJob::OnIOComplete
,
55 base::Unretained(this))) {
// Destroys the job. As the comment inside explains, the transport socket is
// not cancelled explicitly. Cleanup is left to the scoped_ptr that owns
// transport_socket_handle_.
58 SOCKSConnectJob::~SOCKSConnectJob() {
59 // We don't worry about cancelling the tcp socket since the destructor in
60 // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
// Maps the state machine's current position to a LoadState:
//   - While the transport socket is connecting, it reports the transport
//     handle's own load state.
//   - During the SOCKS handshake, it reports LOAD_STATE_CONNECTING.
//   - Any other state yields LOAD_STATE_IDLE.
64 LoadState
SOCKSConnectJob::GetLoadState() const {
65 switch (next_state_
) {
66 case STATE_TRANSPORT_CONNECT
:
67 case STATE_TRANSPORT_CONNECT_COMPLETE
:
68 return transport_socket_handle_
->GetLoadState();
69 case STATE_SOCKS_CONNECT
:
70 case STATE_SOCKS_CONNECT_COMPLETE
:
71 return LOAD_STATE_CONNECTING
;
74 return LOAD_STATE_IDLE
;
// Completion callback for asynchronous steps; resumes DoLoop() with |result|.
// Once the loop finishes with anything other than ERR_IO_PENDING, the
// delegate is notified. Per the trailing note, that call deletes |this|, so
// no member may be touched afterwards.
78 void SOCKSConnectJob::OnIOComplete(int result
) {
79 int rv
= DoLoop(result
);
80 if (rv
!= ERR_IO_PENDING
)
81 NotifyDelegateOfCompletion(rv
); // Deletes |this|
// Runs the connect state machine. It must be entered with a pending state
// (DCHECKed against STATE_NONE).
// Each iteration does the following:
//   - Takes next_state_ and resets it to STATE_NONE, so each handler must set
//     the following state explicitly.
//   - Dispatches to the matching Do* handler. The *_COMPLETE handlers receive
//     the current rv.
// Iteration stops when a handler returns ERR_IO_PENDING or leaves next_state_
// at STATE_NONE. An unexpected state hits NOTREACHED().
// The initialization of rv, the loop/switch headers and the final return are
// on source lines skipped in this view.
84 int SOCKSConnectJob::DoLoop(int result
) {
85 DCHECK_NE(next_state_
, STATE_NONE
);
89 State state
= next_state_
;
90 next_state_
= STATE_NONE
;
92 case STATE_TRANSPORT_CONNECT
:
94 rv
= DoTransportConnect();
96 case STATE_TRANSPORT_CONNECT_COMPLETE
:
97 rv
= DoTransportConnectComplete(rv
);
99 case STATE_SOCKS_CONNECT
:
101 rv
= DoSOCKSConnect();
103 case STATE_SOCKS_CONNECT_COMPLETE
:
104 rv
= DoSOCKSConnectComplete(rv
);
107 NOTREACHED() << "bad state";
111 } while (rv
!= ERR_IO_PENDING
&& next_state_
!= STATE_NONE
);
// Starts the transport connect to the proxy. It performs these steps:
//   1. Advances to STATE_TRANSPORT_CONNECT_COMPLETE.
//   2. Creates a fresh ClientSocketHandle.
//   3. Calls Init() on it with this job's group name and socks_params_'s
//      transport params, returning Init()'s result to DoLoop().
// The remaining Init() arguments (source lines 121-125) are not visible here.
116 int SOCKSConnectJob::DoTransportConnect() {
117 next_state_
= STATE_TRANSPORT_CONNECT_COMPLETE
;
118 transport_socket_handle_
.reset(new ClientSocketHandle());
119 return transport_socket_handle_
->Init(group_name(),
120 socks_params_
->transport_params(),
// Handles the end of the transport connect. A failed transport connect is
// meant to be reported as ERR_PROXY_CONNECTION_FAILED. Otherwise the job
// timer is restarted with kSOCKSConnectJobTimeoutInSeconds (see the comment
// inside) and the state advances to STATE_SOCKS_CONNECT.
// NOTE(review): the early return appears unconditional here, which would make
// everything after it unreachable. Source line 128 (presumably the
// `if (result != OK)` guard) is skipped in this view -- confirm.
127 int SOCKSConnectJob::DoTransportConnectComplete(int result
) {
129 return ERR_PROXY_CONNECTION_FAILED
;
131 // Reset the timer to just the length of time allowed for SOCKS handshake
132 // so that a fast TCP connection plus a slow SOCKS failure doesn't take
133 // longer to timeout than it should.
134 ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds
));
135 next_state_
= STATE_SOCKS_CONNECT
;
// Starts the SOCKS handshake. It performs these steps:
//   1. Advances to STATE_SOCKS_CONNECT_COMPLETE.
//   2. Moves the connected transport handle (via Pass()) into socket_. The
//      wrapper is a SOCKS5ClientSocket when socks_params_->is_socks_v5() is
//      true, and a SOCKSClientSocket otherwise. Both are given
//      socks_params_->destination().
//   3. Calls Connect() with a callback bound to OnIOComplete.
// NOTE(review): two parts of this function are skipped in this view:
//   - the `} else {` separating the two reset() calls (source line 146);
//   - the rest of the SOCKSClientSocket arguments (lines 149-151).
139 int SOCKSConnectJob::DoSOCKSConnect() {
140 next_state_
= STATE_SOCKS_CONNECT_COMPLETE
;
142 // Add a SOCKS connection on top of the tcp socket.
143 if (socks_params_
->is_socks_v5()) {
144 socket_
.reset(new SOCKS5ClientSocket(transport_socket_handle_
.Pass(),
145 socks_params_
->destination()));
147 socket_
.reset(new SOCKSClientSocket(transport_socket_handle_
.Pass(),
148 socks_params_
->destination(),
152 return socket_
->Connect(
153 base::Bind(&SOCKSConnectJob::OnIOComplete
, base::Unretained(this)));
// Handles the end of the SOCKS handshake.
//   - If |result| is an error, socket_ is presumably disconnected and the
//     error returned.
//   - Otherwise ownership of socket_ is handed off via SetSocket().
// NOTE(review): the guard around Disconnect() and the return statements
// (source lines 157, 159-161 and 163) are skipped in this view. As shown,
// Disconnect() would run unconditionally -- confirm.
156 int SOCKSConnectJob::DoSOCKSConnectComplete(int result
) {
158 socket_
->Disconnect();
162 SetSocket(socket_
.Pass());
// Starts the job by setting the initial state to STATE_TRANSPORT_CONNECT.
// It is presumably called by the ConnectJob base. The statement that runs the
// state machine (source line 168) is not visible here.
166 int SOCKSConnectJob::ConnectInternal() {
167 next_state_
= STATE_TRANSPORT_CONNECT
;
// Creates a SOCKSConnectJob for |group_name| and returns ownership to the
// caller as a scoped_ptr<ConnectJob>. The remaining constructor arguments are
// presumably drawn from |request| and |delegate|, but the rest of this
// function is skipped in this view.
171 scoped_ptr
<ConnectJob
>
172 SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
173 const std::string
& group_name
,
174 const PoolBase::Request
& request
,
175 ConnectJob::Delegate
* delegate
) const {
176 return scoped_ptr
<ConnectJob
>(new SOCKSConnectJob(group_name
,
// The factory's timeout is the transport pool's whole connection timeout
// plus kSOCKSConnectJobTimeoutInSeconds for the SOCKS handshake.
187 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
188 return transport_pool_
->ConnectionTimeout() +
189 base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds
);
// Builds the pool around |transport_pool|.
//   - The underlying pool (presumably base_; its initializer head is skipped)
//     gets |max_sockets_per_group|, the default ClientSocketPool idle-socket
//     timeouts, and a new SOCKSConnectJobFactory.
//   - transport_pool_ is then registered with base_ as a lower-layered pool.
// NOTE(review): the comment inside implies transport_pool_ may be null in
// unit tests. The guard before AddLowerLayeredPool() (source line 207) is not
// visible in this view -- confirm it is present.
192 SOCKSClientSocketPool::SOCKSClientSocketPool(
194 int max_sockets_per_group
,
195 HostResolver
* host_resolver
,
196 TransportClientSocketPool
* transport_pool
,
198 : transport_pool_(transport_pool
),
202 max_sockets_per_group
,
203 ClientSocketPool::unused_idle_socket_timeout(),
204 ClientSocketPool::used_idle_socket_timeout(),
205 new SOCKSConnectJobFactory(transport_pool
, host_resolver
, net_log
)) {
206 // We should always have a |transport_pool_| except in unit tests.
208 base_
.AddLowerLayeredPool(transport_pool_
);
// Destructor. Its body (source lines 212-213) is not visible in this view.
211 SOCKSClientSocketPool::~SOCKSClientSocketPool() {
// ClientSocketPool entry point. |socket_params| is type-erased and is
// static_cast, without any check, to const scoped_refptr<SOCKSSocketParams>*.
// Callers must therefore pass a pointer to exactly that type. The params and
// the remaining arguments are forwarded to base_.RequestSocket().
214 int SOCKSClientSocketPool::RequestSocket(
215 const std::string
& group_name
, const void* socket_params
,
216 RequestPriority priority
, ClientSocketHandle
* handle
,
217 const CompletionCallback
& callback
, const BoundNetLog
& net_log
) {
218 const scoped_refptr
<SOCKSSocketParams
>* casted_socket_params
=
219 static_cast<const scoped_refptr
<SOCKSSocketParams
>*>(socket_params
);
221 return base_
.RequestSocket(group_name
, *casted_socket_params
, priority
,
222 handle
, callback
, net_log
);
// Requests |num_sockets| sockets for |group_name| through base_.RequestSockets().
// As in RequestSocket(), |params| is static_cast without any check to
// const scoped_refptr<SOCKSSocketParams>*. The `params` and `num_sockets`
// parameter declarations are on source lines skipped in this view.
225 void SOCKSClientSocketPool::RequestSockets(
226 const std::string
& group_name
,
229 const BoundNetLog
& net_log
) {
230 const scoped_refptr
<SOCKSSocketParams
>* casted_params
=
231 static_cast<const scoped_refptr
<SOCKSSocketParams
>*>(params
);
233 base_
.RequestSockets(group_name
, *casted_params
, num_sockets
, net_log
);
// Cancels the request associated with |handle| in |group_name| by delegating
// to base_.
236 void SOCKSClientSocketPool::CancelRequest(const std::string
& group_name
,
237 ClientSocketHandle
* handle
) {
238 base_
.CancelRequest(group_name
, handle
);
// Hands |socket| back to base_ for |group_name|. Ownership is transferred via
// Pass(), together with |id|; the `id` parameter is declared on a source line
// skipped in this view.
241 void SOCKSClientSocketPool::ReleaseSocket(const std::string
& group_name
,
242 scoped_ptr
<StreamSocket
> socket
,
244 base_
.ReleaseSocket(group_name
, socket
.Pass(), id
);
// Forwards |error| to base_.FlushWithError().
247 void SOCKSClientSocketPool::FlushWithError(int error
) {
248 base_
.FlushWithError(error
);
// Closes idle sockets by delegating to base_.CloseIdleSockets().
251 void SOCKSClientSocketPool::CloseIdleSockets() {
252 base_
.CloseIdleSockets();
// Returns the idle socket count tracked by base_. For a single group, see
// IdleSocketCountInGroup().
255 int SOCKSClientSocketPool::IdleSocketCount() const {
256 return base_
.idle_socket_count();
// Returns the idle socket count that base_ reports for |group_name|.
259 int SOCKSClientSocketPool::IdleSocketCountInGroup(
260 const std::string
& group_name
) const {
261 return base_
.IdleSocketCountInGroup(group_name
);
// Returns the load state that base_ reports for |handle| in |group_name|.
264 LoadState
SOCKSClientSocketPool::GetLoadState(
265 const std::string
& group_name
, const ClientSocketHandle
* handle
) const {
266 return base_
.GetLoadState(group_name
, handle
);
// Builds a DictionaryValue describing this pool from base_.GetInfoAsValue()
// with |name| and |type|. When |include_nested_pools| is true, a ListValue is
// set under "nested_pools". It holds the transport pool's own description,
// requested with the name and type "transport_socket_pool".
// NOTE(review): the returned raw pointer presumably transfers ownership to
// the caller. The return statement and the remaining GetInfoAsValue()
// argument (source line 278) are skipped in this view.
269 base::DictionaryValue
* SOCKSClientSocketPool::GetInfoAsValue(
270 const std::string
& name
,
271 const std::string
& type
,
272 bool include_nested_pools
) const {
273 base::DictionaryValue
* dict
= base_
.GetInfoAsValue(name
, type
);
274 if (include_nested_pools
) {
275 base::ListValue
* list
= new base::ListValue();
276 list
->Append(transport_pool_
->GetInfoAsValue("transport_socket_pool",
277 "transport_socket_pool",
279 dict
->Set("nested_pools", list
);
// Returns the connection timeout reported by base_.
284 base::TimeDelta
SOCKSClientSocketPool::ConnectionTimeout() const {
285 return base_
.ConnectionTimeout();
// Reports whether base_ considers this pool stalled.
288 bool SOCKSClientSocketPool::IsStalled() const {
289 return base_
.IsStalled();
// Registers |higher_pool| with base_ as a higher-layered pool.
292 void SOCKSClientSocketPool::AddHigherLayeredPool(
293 HigherLayeredPool
* higher_pool
) {
294 base_
.AddHigherLayeredPool(higher_pool
);
// Unregisters |higher_pool| from base_.
297 void SOCKSClientSocketPool::RemoveHigherLayeredPool(
298 HigherLayeredPool
* higher_pool
) {
299 base_
.RemoveHigherLayeredPool(higher_pool
);
// Tries to free one idle connection. It first tries base_.CloseOneIdleSocket()
// and then, presumably only if that fails,
// base_.CloseOneIdleConnectionInHigherLayeredPool().
// NOTE(review): as shown, the final return reads as the body of the `if`.
// Source line 304 (presumably `return true;`) is skipped in this view --
// confirm against the full file.
302 bool SOCKSClientSocketPool::CloseOneIdleConnection() {
303 if (base_
.CloseOneIdleSocket())
305 return base_
.CloseOneIdleConnectionInHigherLayeredPool();