1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/socks_client_socket_pool.h"
8 #include "base/bind_helpers.h"
9 #include "base/time/time.h"
10 #include "base/values.h"
11 #include "net/base/net_errors.h"
12 #include "net/socket/client_socket_factory.h"
13 #include "net/socket/client_socket_handle.h"
14 #include "net/socket/client_socket_pool_base.h"
15 #include "net/socket/socks5_client_socket.h"
16 #include "net/socket/socks_client_socket.h"
17 #include "net/socket/transport_client_socket_pool.h"
// Builds the parameters for a SOCKS connect. |proxy_server| becomes
// |transport_params_| (the transport-level params for reaching the proxy) and
// |host_port_pair| becomes |destination_|.
// |ignore_limits_| mirrors transport_params_->ignore_limits() when transport
// params are present; the listing also assigns false, presumably in an else
// branch whose line is not shown here — TODO confirm.
// |priority| is recorded on |destination_| via set_priority().
21 SOCKSSocketParams::SOCKSSocketParams(
22 const scoped_refptr
<TransportSocketParams
>& proxy_server
,
24 const HostPortPair
& host_port_pair
,
25 RequestPriority priority
)
26 : transport_params_(proxy_server
),
27 destination_(host_port_pair
),
29 if (transport_params_
.get())
30 ignore_limits_
= transport_params_
->ignore_limits();
32 ignore_limits_
= false;
33 destination_
.set_priority(priority
);
// Empty destructor: no explicit cleanup is performed here.
36 SOCKSSocketParams::~SOCKSSocketParams() {}
38 // SOCKSConnectJobs will time out after this many seconds. Note this is on
39 // top of the timeout for the transport socket.
// It is used both to reset the job timer once the transport connection
// completes and as the extra time the connect-job factory adds to the
// transport pool's ConnectionTimeout().
40 static const int kSOCKSConnectJobTimeoutInSeconds
= 30;
// Creates a job that first obtains a transport socket from |transport_pool_|
// and then performs the SOCKS handshake on top of it.
// The base ConnectJob receives |group_name|, |timeout_duration|, |delegate|
// and a BoundNetLog whose source type is NetLog::SOURCE_CONNECT_JOB.
// NOTE(review): |delegate| and |net_log| are used below, but their parameter
// declarations are not visible in this listing — confirm against the header.
42 SOCKSConnectJob::SOCKSConnectJob(
43 const std::string
& group_name
,
44 const scoped_refptr
<SOCKSSocketParams
>& socks_params
,
45 const base::TimeDelta
& timeout_duration
,
46 TransportClientSocketPool
* transport_pool
,
47 HostResolver
* host_resolver
,
50 : ConnectJob(group_name
, timeout_duration
, delegate
,
51 BoundNetLog::Make(net_log
, NetLog::SOURCE_CONNECT_JOB
)),
52 socks_params_(socks_params
),
53 transport_pool_(transport_pool
),
54 resolver_(host_resolver
),
// |callback_| resumes the state machine through OnIOComplete().
// base::Unretained means the callback does not keep |this| alive. This is
// presumably safe because the objects that run it are owned by this job
// (see the destructor comment) — TODO confirm.
55 callback_(base::Bind(&SOCKSConnectJob::OnIOComplete
,
56 base::Unretained(this))) {
// Destructor. No explicit cancellation is done here; see the note below.
59 SOCKSConnectJob::~SOCKSConnectJob() {
60 // We don't worry about cancelling the tcp socket since the destructor in
61 // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
// Reports the job's load state based on |next_state_|.
// - During the transport connect states, defers to the transport socket
//   handle. This assumes |transport_socket_handle_| is non-null in those
//   states; it is created in DoTransportConnect().
// - During the SOCKS handshake states, reports LOAD_STATE_CONNECTING.
// - Other states presumably end at the LOAD_STATE_IDLE return. The default
//   label and closing lines of the switch are not shown in this listing.
65 LoadState
SOCKSConnectJob::GetLoadState() const {
66 switch (next_state_
) {
67 case STATE_TRANSPORT_CONNECT
:
68 case STATE_TRANSPORT_CONNECT_COMPLETE
:
69 return transport_socket_handle_
->GetLoadState();
70 case STATE_SOCKS_CONNECT
:
71 case STATE_SOCKS_CONNECT_COMPLETE
:
72 return LOAD_STATE_CONNECTING
;
75 return LOAD_STATE_IDLE
;
// Completion callback for asynchronous steps. It resumes the state machine
// with |result|. Once the machine stops with anything other than
// ERR_IO_PENDING, the final result is reported to the delegate.
// Per the note below, |this| must not be used after that call.
79 void SOCKSConnectJob::OnIOComplete(int result
) {
80 int rv
= DoLoop(result
);
81 if (rv
!= ERR_IO_PENDING
)
82 NotifyDelegateOfCompletion(rv
); // Deletes |this|
// Runs the connect state machine until an operation goes asynchronous
// (ERR_IO_PENDING) or no next state has been scheduled.
// Each iteration clears |next_state_| before dispatching, so a handler must
// set it again for the loop to continue.
// Transitions set by the handlers in this file:
//   TRANSPORT_CONNECT -> TRANSPORT_CONNECT_COMPLETE -> SOCKS_CONNECT
//   -> SOCKS_CONNECT_COMPLETE.
// NOTE(review): the following are not shown in this listing:
//   - the initialisation of |rv|,
//   - the do/switch headers,
//   - the break statements,
//   - the final return.
85 int SOCKSConnectJob::DoLoop(int result
) {
86 DCHECK_NE(next_state_
, STATE_NONE
);
90 State state
= next_state_
;
91 next_state_
= STATE_NONE
;
93 case STATE_TRANSPORT_CONNECT
:
95 rv
= DoTransportConnect();
97 case STATE_TRANSPORT_CONNECT_COMPLETE
:
98 rv
= DoTransportConnectComplete(rv
);
100 case STATE_SOCKS_CONNECT
:
102 rv
= DoSOCKSConnect();
104 case STATE_SOCKS_CONNECT_COMPLETE
:
105 rv
= DoSOCKSConnectComplete(rv
);
108 NOTREACHED() << "bad state";
112 } while (rv
!= ERR_IO_PENDING
&& next_state_
!= STATE_NONE
);
// Step 1: creates a fresh ClientSocketHandle and asks |transport_pool_| for a
// transport socket. The request uses:
//   - this job's group name,
//   - the SOCKS params' transport params,
//   - the destination's priority.
// |callback_| is passed so an asynchronous completion resumes the state
// machine. The remaining Init() arguments are on lines not shown here.
117 int SOCKSConnectJob::DoTransportConnect() {
118 next_state_
= STATE_TRANSPORT_CONNECT_COMPLETE
;
119 transport_socket_handle_
.reset(new ClientSocketHandle());
120 return transport_socket_handle_
->Init(
121 group_name(), socks_params_
->transport_params(),
122 socks_params_
->destination().priority(), callback_
, transport_pool_
,
// Step 2: handles the result of the transport connect.
// A failure is reported as ERR_PROXY_CONNECTION_FAILED. The guarding
// condition is not shown in this listing; presumably it is |result| != OK.
// On success, the job timer is reset to kSOCKSConnectJobTimeoutInSeconds and
// the SOCKS handshake is scheduled next.
126 int SOCKSConnectJob::DoTransportConnectComplete(int result
) {
128 return ERR_PROXY_CONNECTION_FAILED
;
130 // Reset the timer to just the length of time allowed for SOCKS handshake
131 // so that a fast TCP connection plus a slow SOCKS failure doesn't take
132 // longer to timeout than it should.
133 ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds
));
134 next_state_
= STATE_SOCKS_CONNECT
;
// Step 3: moves ownership of the transport socket handle out of
// |transport_socket_handle_| (via release()) into a new SOCKS socket:
//   - a SOCKS5ClientSocket when the params request SOCKS v5,
//   - otherwise a SOCKSClientSocket.
// The new socket is stored in |socket_| and the SOCKS handshake is started.
// Completion re-enters OnIOComplete().
// NOTE(review): the Connect() callback is bound exactly like |callback_|
// (OnIOComplete + base::Unretained(this)); reusing |callback_| would avoid
// the duplicate bind.
138 int SOCKSConnectJob::DoSOCKSConnect() {
139 next_state_
= STATE_SOCKS_CONNECT_COMPLETE
;
141 // Add a SOCKS connection on top of the tcp socket.
142 if (socks_params_
->is_socks_v5()) {
143 socket_
.reset(new SOCKS5ClientSocket(transport_socket_handle_
.release(),
144 socks_params_
->destination()));
146 socket_
.reset(new SOCKSClientSocket(transport_socket_handle_
.release(),
147 socks_params_
->destination(),
150 return socket_
->Connect(
151 base::Bind(&SOCKSConnectJob::OnIOComplete
, base::Unretained(this)));
// Step 4: handles the SOCKS handshake result.
// On the visible failure path, the SOCKS socket is disconnected. The guarding
// condition and early return are not shown in this listing.
// Otherwise, ownership of |socket_| is released and handed to the base
// ConnectJob via set_socket().
154 int SOCKSConnectJob::DoSOCKSConnectComplete(int result
) {
156 socket_
->Disconnect();
160 set_socket(socket_
.release());
// Starts the connect sequence at STATE_TRANSPORT_CONNECT. The rest of the
// body (presumably a call to DoLoop()) is not shown in this listing.
164 int SOCKSConnectJob::ConnectInternal() {
165 next_state_
= STATE_TRANSPORT_CONNECT
;
// Returns a newly allocated SOCKSConnectJob for |group_name|. The remaining
// constructor arguments are on lines not shown in this listing.
169 ConnectJob
* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
170 const std::string
& group_name
,
171 const PoolBase::Request
& request
,
172 ConnectJob::Delegate
* delegate
) const {
173 return new SOCKSConnectJob(group_name
,
// Total connect timeout for a SOCKS job: the transport pool's connection
// timeout plus kSOCKSConnectJobTimeoutInSeconds for the SOCKS handshake.
183 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
184 return transport_pool_
->ConnectionTimeout() +
185 base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds
);
// Constructs the pool. |base_| is configured with:
//   - the socket limits,
//   - |histograms|,
//   - ClientSocketPool's unused/used idle-socket timeouts,
//   - a SOCKSConnectJobFactory built from |transport_pool|.
// The pool then registers itself with the transport pool via
// AddLayeredPool().
// NOTE(review): the |max_sockets| parameter and some trailing arguments are on
// lines not shown in this listing. The |transport_pool_| null check implied
// by the comment below is also not visible.
188 SOCKSClientSocketPool::SOCKSClientSocketPool(
190 int max_sockets_per_group
,
191 ClientSocketPoolHistograms
* histograms
,
192 HostResolver
* host_resolver
,
193 TransportClientSocketPool
* transport_pool
,
195 : transport_pool_(transport_pool
),
196 base_(max_sockets
, max_sockets_per_group
, histograms
,
197 ClientSocketPool::unused_idle_socket_timeout(),
198 ClientSocketPool::used_idle_socket_timeout(),
199 new SOCKSConnectJobFactory(transport_pool
,
202 // We should always have a |transport_pool_| except in unit tests.
204 transport_pool_
->AddLayeredPool(this);
// Unregisters this pool from |transport_pool_|, mirroring the constructor's
// AddLayeredPool() call.
207 SOCKSClientSocketPool::~SOCKSClientSocketPool() {
208 // We should always have a |transport_pool_| except in unit tests.
210 transport_pool_
->RemoveLayeredPool(this);
// Requests a socket for |group_name| and forwards the request to |base_|.
// |socket_params| is an opaque pointer that is static_cast, unchecked, to
// const scoped_refptr<SOCKSSocketParams>* and then dereferenced. Callers must
// therefore pass exactly that type and never null.
213 int SOCKSClientSocketPool::RequestSocket(
214 const std::string
& group_name
, const void* socket_params
,
215 RequestPriority priority
, ClientSocketHandle
* handle
,
216 const CompletionCallback
& callback
, const BoundNetLog
& net_log
) {
217 const scoped_refptr
<SOCKSSocketParams
>* casted_socket_params
=
218 static_cast<const scoped_refptr
<SOCKSSocketParams
>*>(socket_params
);
220 return base_
.RequestSocket(group_name
, *casted_socket_params
, priority
,
221 handle
, callback
, net_log
);
// Asks |base_| for |num_sockets| sockets in |group_name|.
// |params| is cast, unchecked, to const scoped_refptr<SOCKSSocketParams>*
// and dereferenced, with the same contract as RequestSocket().
// The declarations of |params| and |num_sockets| are on lines not shown in
// this listing.
224 void SOCKSClientSocketPool::RequestSockets(
225 const std::string
& group_name
,
228 const BoundNetLog
& net_log
) {
229 const scoped_refptr
<SOCKSSocketParams
>* casted_params
=
230 static_cast<const scoped_refptr
<SOCKSSocketParams
>*>(params
);
232 base_
.RequestSockets(group_name
, *casted_params
, num_sockets
, net_log
);
// Cancels the pending request for |handle| in |group_name| by forwarding to
// |base_|.
235 void SOCKSClientSocketPool::CancelRequest(const std::string
& group_name
,
236 ClientSocketHandle
* handle
) {
237 base_
.CancelRequest(group_name
, handle
);
// Returns |socket| (with its |id|) for |group_name| to |base_|.
240 void SOCKSClientSocketPool::ReleaseSocket(const std::string
& group_name
,
241 StreamSocket
* socket
, int id
) {
242 base_
.ReleaseSocket(group_name
, socket
, id
);
// Forwards FlushWithError(|error|) to |base_|.
245 void SOCKSClientSocketPool::FlushWithError(int error
) {
246 base_
.FlushWithError(error
);
// Reports stalled if either this pool's |base_| or the underlying transport
// pool is stalled.
// NOTE(review): |transport_pool_| is dereferenced unconditionally, although
// the constructor comment says it may be absent in unit tests.
249 bool SOCKSClientSocketPool::IsStalled() const {
250 return base_
.IsStalled() || transport_pool_
->IsStalled();
// Asks |base_| to close this pool's idle sockets.
253 void SOCKSClientSocketPool::CloseIdleSockets() {
254 base_
.CloseIdleSockets();
// Number of idle sockets, as reported by |base_|.
257 int SOCKSClientSocketPool::IdleSocketCount() const {
258 return base_
.idle_socket_count();
// Number of idle sockets in |group_name|, as reported by |base_|.
261 int SOCKSClientSocketPool::IdleSocketCountInGroup(
262 const std::string
& group_name
) const {
263 return base_
.IdleSocketCountInGroup(group_name
);
// Load state of |handle|'s request in |group_name|, as reported by |base_|.
266 LoadState
SOCKSClientSocketPool::GetLoadState(
267 const std::string
& group_name
, const ClientSocketHandle
* handle
) const {
268 return base_
.GetLoadState(group_name
, handle
);
// Registers |layered_pool| (a pool built on top of this one) with |base_|.
271 void SOCKSClientSocketPool::AddLayeredPool(LayeredPool
* layered_pool
) {
272 base_
.AddLayeredPool(layered_pool
);
// Unregisters |layered_pool| from |base_|.
275 void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool
* layered_pool
) {
276 base_
.RemoveLayeredPool(layered_pool
);
// Builds |base_|'s description of this pool. When |include_nested_pools| is
// set, the transport pool's description is appended to a list stored under
// "nested_pools".
// NOTE(review): |list| is allocated with new and handed to dict->Set(); this
// presumably transfers ownership to |dict| — confirm against
// base::DictionaryValue.
// NOTE(review): |transport_pool_| is dereferenced without a null check,
// although the constructor comment says it may be absent in unit tests.
// The final return statement is not shown in this listing.
279 base::DictionaryValue
* SOCKSClientSocketPool::GetInfoAsValue(
280 const std::string
& name
,
281 const std::string
& type
,
282 bool include_nested_pools
) const {
283 base::DictionaryValue
* dict
= base_
.GetInfoAsValue(name
, type
);
284 if (include_nested_pools
) {
285 base::ListValue
* list
= new base::ListValue();
286 list
->Append(transport_pool_
->GetInfoAsValue("transport_socket_pool",
287 "transport_socket_pool",
289 dict
->Set("nested_pools", list
);
// Connection timeout as reported by |base_|. That timeout comes from the
// connect-job factory, which adds the SOCKS handshake allowance to the
// transport pool's timeout.
294 base::TimeDelta
SOCKSClientSocketPool::ConnectionTimeout() const {
295 return base_
.ConnectionTimeout();
// Histograms object held by |base_|.
298 ClientSocketPoolHistograms
* SOCKSClientSocketPool::histograms() const {
299 return base_
.histograms();
// Tries to free one idle connection. It first tries an idle socket in this
// pool's |base_|; otherwise it delegates to
// base_.CloseOneIdleConnectionInLayeredPool().
// The statement run when the first call succeeds is not shown in this
// listing.
302 bool SOCKSClientSocketPool::CloseOneIdleConnection() {
303 if (base_
.CloseOneIdleSocket())
305 return base_
.CloseOneIdleConnectionInLayeredPool();