1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
6 #define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
10 #include "base/basictypes.h"
11 #include "base/memory/ref_counted.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/time/time.h"
14 #include "base/timer/timer.h"
15 #include "net/base/host_port_pair.h"
16 #include "net/dns/host_resolver.h"
17 #include "net/dns/single_request_host_resolver.h"
18 #include "net/socket/client_socket_pool.h"
19 #include "net/socket/client_socket_pool_base.h"
20 #include "net/socket/client_socket_pool_histograms.h"
// Forward declaration. The declarations visible in this header only use
// ClientSocketFactory through pointers.
24 class ClientSocketFactory
;
// Callback type that receives an AddressList and a BoundNetLog and returns an
// int. TransportSocketParams documents when it is invoked and how a non-OK
// return value is treated.
26 typedef base::Callback
<int(const AddressList
&, const BoundNetLog
& net_log
)>
27 OnHostResolutionCallback
;
// Bundles the values a transport connect takes from its requester: the host
// resolver request info (destination()), the ignore-limits flag and the
// callback to run once host resolution finishes. Reference counted through
// base::RefCounted, which is befriended below.
29 class NET_EXPORT_PRIVATE TransportSocketParams
30 : public base::RefCounted
<TransportSocketParams
> {
32 // |host_resolution_callback| will be invoked after the hostname is
33 // resolved. If |host_resolution_callback| does not return OK, then the
34 // connection will be aborted with that value.
35 TransportSocketParams(
36 const HostPortPair
& host_port_pair
,
37 RequestPriority priority
,
38 bool disable_resolver_cache
,
40 const OnHostResolutionCallback
& host_resolution_callback
);
// Read-only accessors for the values held by this object.
42 const HostResolver::RequestInfo
& destination() const { return destination_
; }
43 bool ignore_limits() const { return ignore_limits_
; }
44 const OnHostResolutionCallback
& host_resolution_callback() const {
45 return host_resolution_callback_
;
// base::RefCounted is a friend of this class; the destructor is declared
// right after the friend declaration.
49 friend class base::RefCounted
<TransportSocketParams
>;
50 ~TransportSocketParams();
// Takes |priority| and |disable_resolver_cache|, the same names the
// constructor receives. NOTE(review): presumably applies them to
// |destination_|; the definition is not in this header -- confirm in the .cc.
52 void Initialize(RequestPriority priority
, bool disable_resolver_cache
);
// Request info returned by destination().
54 HostResolver::RequestInfo destination_
;
// Declared const, so it cannot be reassigned after construction.
56 const OnHostResolutionCallback host_resolution_callback_
;
58 DISALLOW_COPY_AND_ASSIGN(TransportSocketParams
);
61 // TransportConnectJob handles the host resolution necessary for socket creation
62 // and the transport (likely TCP) connect. TransportConnectJob also has fallback
63 // logic for IPv6 connect() timeouts (which may happen due to networks / routers
64 // with broken IPv6 support). Those timeouts take 20s, so rather than make the
65 // user wait 20s for the timeout to fire, we use a fallback timer
66 // (kIPv6FallbackTimerInMs) and start a connect() to an IPv4 address if the
67 // timer fires. Then we race the IPv4 connect() against the IPv6 connect()
68 // (which has a head start) and return the one that completes first to the
// socket pool.
69 class NET_EXPORT_PRIVATE TransportConnectJob
: public ConnectJob
{
// Receives the group name, the ref-counted params, a timeout duration, the
// client socket factory and the host resolver.
71 TransportConnectJob(const std::string
& group_name
,
72 const scoped_refptr
<TransportSocketParams
>& params
,
73 base::TimeDelta timeout_duration
,
74 ClientSocketFactory
* client_socket_factory
,
75 HostResolver
* host_resolver
,
78 virtual ~TransportConnectJob();
80 // ConnectJob methods.
81 virtual LoadState
GetLoadState() const OVERRIDE
;
83 // Rolls |addrlist| forward until the first IPv4 address, if any.
84 // WARNING: this method should only be used to implement the prefer-IPv4 hack.
85 static void MakeAddressListStartWithIPv4(AddressList
* addrlist
);
// Delay of the fallback timer described in the class comment; the "InMs"
// suffix indicates milliseconds. Declared without an initializer here, so the
// value is defined out of line.
87 static const int kIPv6FallbackTimerInMs
;
// States of the connect state machine.
92 STATE_RESOLVE_HOST_COMPLETE
,
93 STATE_TRANSPORT_CONNECT
,
94 STATE_TRANSPORT_CONNECT_COMPLETE
,
// NOTE(review): presumably the completion callback for asynchronous steps
// that resumes the state machine with |result| -- confirm in the .cc.
98 void OnIOComplete(int result
);
100 // Runs the state transition loop.
101 int DoLoop(int result
);
// Handlers for the individual states of the state machine.
104 int DoResolveHostComplete(int result
);
105 int DoTransportConnect();
106 int DoTransportConnectComplete(int result
);
108 // Not part of the state machine.
109 void DoIPv6FallbackTransportConnect();
110 void DoIPv6FallbackTransportConnectComplete(int result
);
112 // Begins the host resolution and the TCP connect. Returns OK on success
113 // and ERR_IO_PENDING if it cannot immediately service the request.
114 // Otherwise, it returns a net error code.
115 virtual int ConnectInternal() OVERRIDE
;
// Ref-counted request parameters.
117 scoped_refptr
<TransportSocketParams
> params_
;
// Const pointer: cannot be re-seated after construction.
118 ClientSocketFactory
* const client_socket_factory_
;
119 SingleRequestHostResolver resolver_
;
120 AddressList addresses_
;
// Judging by the names, the members below hold the sockets for the main and
// fallback connect attempts plus the bookkeeping (address list, start time,
// one-shot timer) for the fallback described in the class comment.
123 scoped_ptr
<StreamSocket
> transport_socket_
;
125 scoped_ptr
<StreamSocket
> fallback_transport_socket_
;
126 scoped_ptr
<AddressList
> fallback_addresses_
;
127 base::TimeTicks fallback_connect_start_time_
;
128 base::OneShotTimer
<TransportConnectJob
> fallback_timer_
;
130 DISALLOW_COPY_AND_ASSIGN(TransportConnectJob
);
// ClientSocketPool implementation whose pool base is parameterized by
// TransportSocketParams (see the PoolBase typedef below). The nested
// TransportConnectJobFactory is the ConnectJobFactory implementation for it.
133 class NET_EXPORT_PRIVATE TransportClientSocketPool
: public ClientSocketPool
{
135 TransportClientSocketPool(
137 int max_sockets_per_group
,
138 ClientSocketPoolHistograms
* histograms
,
139 HostResolver
* host_resolver
,
140 ClientSocketFactory
* client_socket_factory
,
143 virtual ~TransportClientSocketPool();
145 // ClientSocketPool implementation.
// |resolve_info| is passed as an untyped pointer. NOTE(review): presumably it
// points at the scoped_refptr<TransportSocketParams> for the request --
// confirm against the definition in the .cc.
146 virtual int RequestSocket(const std::string
& group_name
,
147 const void* resolve_info
,
148 RequestPriority priority
,
149 ClientSocketHandle
* handle
,
150 const CompletionCallback
& callback
,
151 const BoundNetLog
& net_log
) OVERRIDE
;
152 virtual void RequestSockets(const std::string
& group_name
,
155 const BoundNetLog
& net_log
) OVERRIDE
;
156 virtual void CancelRequest(const std::string
& group_name
,
157 ClientSocketHandle
* handle
) OVERRIDE
;
158 virtual void ReleaseSocket(const std::string
& group_name
,
159 StreamSocket
* socket
,
161 virtual void FlushWithError(int error
) OVERRIDE
;
162 virtual bool IsStalled() const OVERRIDE
;
163 virtual void CloseIdleSockets() OVERRIDE
;
164 virtual int IdleSocketCount() const OVERRIDE
;
165 virtual int IdleSocketCountInGroup(
166 const std::string
& group_name
) const OVERRIDE
;
167 virtual LoadState
GetLoadState(
168 const std::string
& group_name
,
169 const ClientSocketHandle
* handle
) const OVERRIDE
;
170 virtual void AddLayeredPool(LayeredPool
* layered_pool
) OVERRIDE
;
171 virtual void RemoveLayeredPool(LayeredPool
* layered_pool
) OVERRIDE
;
172 virtual base::DictionaryValue
* GetInfoAsValue(
173 const std::string
& name
,
174 const std::string
& type
,
175 bool include_nested_pools
) const OVERRIDE
;
176 virtual base::TimeDelta
ConnectionTimeout() const OVERRIDE
;
177 virtual ClientSocketPoolHistograms
* histograms() const OVERRIDE
;
// Pool base template instantiated with this pool's params type.
180 typedef ClientSocketPoolBase
<TransportSocketParams
> PoolBase
;
// PoolBase::ConnectJobFactory implementation. NewConnectJob() returns a
// ConnectJob for a group name, request and delegate. The factory keeps the
// socket factory and host resolver it is constructed with (see the
// initializer list).
182 class TransportConnectJobFactory
183 : public PoolBase::ConnectJobFactory
{
185 TransportConnectJobFactory(ClientSocketFactory
* client_socket_factory
,
186 HostResolver
* host_resolver
,
188 : client_socket_factory_(client_socket_factory
),
189 host_resolver_(host_resolver
),
192 virtual ~TransportConnectJobFactory() {}
194 // ClientSocketPoolBase::ConnectJobFactory methods.
196 virtual ConnectJob
* NewConnectJob(
197 const std::string
& group_name
,
198 const PoolBase::Request
& request
,
199 ConnectJob::Delegate
* delegate
) const OVERRIDE
;
201 virtual base::TimeDelta
ConnectionTimeout() const OVERRIDE
;
// Both pointers are const: they are fixed at construction.
204 ClientSocketFactory
* const client_socket_factory_
;
205 HostResolver
* const host_resolver_
;
208 DISALLOW_COPY_AND_ASSIGN(TransportConnectJobFactory
);
213 DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool
);
// Pairs TransportClientSocketPool with TransportSocketParams.
// NOTE(review): going by the macro name this registers the params type for
// the pool; the macro is defined outside this file, so confirm its exact
// effect there.
216 REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool
,
217 TransportSocketParams
);
221 #endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_