1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/transport_client_socket_pool_test_util.h"
9 #include "base/logging.h"
10 #include "base/memory/weak_ptr.h"
11 #include "base/run_loop.h"
12 #include "net/base/ip_endpoint.h"
13 #include "net/base/load_timing_info.h"
14 #include "net/base/load_timing_info_test_util.h"
15 #include "net/base/net_util.h"
16 #include "net/socket/client_socket_handle.h"
17 #include "net/socket/ssl_client_socket.h"
18 #include "net/udp/datagram_client_socket.h"
19 #include "testing/gtest/include/gtest/gtest.h"
// Test helper: parses the IP literal |ip| with ParseIPLiteralToNumber() and
// CHECK-fails if it is rejected. The visible callers (SetIPv4Address and
// SetIPv6Address) pass fixed literals.
// NOTE(review): this listing is missing lines (the embedded original line
// numbers jump from 27 to 31), so the return statement and closing brace are
// not visible here; presumably it returns |number| -- confirm in the full file.
25 IPAddressNumber
ParseIP(const std::string
& ip
) {
26 IPAddressNumber number
;
27 CHECK(ParseIPLiteralToNumber(ip
, &number
));
31 // A StreamSocket which connects synchronously and successfully.
//
// Behavior visible in this listing:
//  - Disconnect() sets connected_ to false; IsConnected() and
//    IsConnectedAndIdle() both return connected_.
//  - GetPeerAddress() copies addrlist_.front() into |address|.
//  - GetLocalAddress() can return ERR_SOCKET_NOT_CONNECTED (the guarding
//    condition is not visible); it calls SetIPv4Address() when
//    addrlist_.front() is an IPv4 address and, presumably in an else branch
//    that is not visible, SetIPv6Address() otherwise.
//  - NetLog() returns the BoundNetLog made from |net_log| with
//    NetLog::SOURCE_SOCKET; the speculation setters are no-ops; WasEverUsed(),
//    UsingTCPFastOpen(), WasNpnNegotiated() and GetSSLInfo() return false;
//    SetReceiveBufferSize()/SetSendBufferSize() return OK.
//
// NOTE(review): this listing is missing lines (the embedded original line
// numbers skip, e.g. 34->37, 40->44, 52->54, 89->92). The access specifiers,
// the start of the constructor initializer list, the bodies of Connect(),
// Read(), Write() and GetNegotiatedProtocol(), several return statements,
// some member declarations (e.g. connected_, net_log_) and the closing brace
// are not visible here -- verify against the full file.
32 class MockConnectClientSocket
: public StreamSocket
{
34 MockConnectClientSocket(const AddressList
& addrlist
, net::NetLog
* net_log
)
37 net_log_(BoundNetLog::Make(net_log
, NetLog::SOURCE_SOCKET
)) {}
39 // StreamSocket implementation.
40 virtual int Connect(const CompletionCallback
& callback
) OVERRIDE
{
44 virtual void Disconnect() OVERRIDE
{ connected_
= false; }
45 virtual bool IsConnected() const OVERRIDE
{ return connected_
; }
46 virtual bool IsConnectedAndIdle() const OVERRIDE
{ return connected_
; }
48 virtual int GetPeerAddress(IPEndPoint
* address
) const OVERRIDE
{
49 *address
= addrlist_
.front();
52 virtual int GetLocalAddress(IPEndPoint
* address
) const OVERRIDE
{
54 return ERR_SOCKET_NOT_CONNECTED
;
55 if (addrlist_
.front().GetFamily() == ADDRESS_FAMILY_IPV4
)
56 SetIPv4Address(address
);
58 SetIPv6Address(address
);
61 virtual const BoundNetLog
& NetLog() const OVERRIDE
{ return net_log_
; }
63 virtual void SetSubresourceSpeculation() OVERRIDE
{}
64 virtual void SetOmniboxSpeculation() OVERRIDE
{}
65 virtual bool WasEverUsed() const OVERRIDE
{ return false; }
66 virtual bool UsingTCPFastOpen() const OVERRIDE
{ return false; }
67 virtual bool WasNpnNegotiated() const OVERRIDE
{ return false; }
68 virtual NextProto
GetNegotiatedProtocol() const OVERRIDE
{
71 virtual bool GetSSLInfo(SSLInfo
* ssl_info
) OVERRIDE
{ return false; }
73 // Socket implementation.
74 virtual int Read(IOBuffer
* buf
,
76 const CompletionCallback
& callback
) OVERRIDE
{
79 virtual int Write(IOBuffer
* buf
,
81 const CompletionCallback
& callback
) OVERRIDE
{
84 virtual int SetReceiveBufferSize(int32 size
) OVERRIDE
{ return OK
; }
85 virtual int SetSendBufferSize(int32 size
) OVERRIDE
{ return OK
; }
89 const AddressList addrlist_
;
92 DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket
);
// A StreamSocket whose connection attempt always fails: Connect() returns
// ERR_CONNECTION_FAILED directly, without going through the callback.
//  - IsConnected()/IsConnectedAndIdle() always return false; Disconnect() is a
//    no-op.
//  - GetPeerAddress()/GetLocalAddress() return ERR_UNEXPECTED.
//  - GetNegotiatedProtocol() returns kProtoUnknown; WasEverUsed(),
//    UsingTCPFastOpen(), WasNpnNegotiated() and GetSSLInfo() return false;
//    SetReceiveBufferSize()/SetSendBufferSize() return OK.
//  - addrlist_ is stored, but none of the visible methods read it.
//
// NOTE(review): this listing is missing lines (the embedded original line
// numbers skip, e.g. 95->97, 103->106, 129->131->135, 145->147). The access
// specifiers, the closing braces of several methods, the Read()/Write()
// bodies and the class's closing brace are not visible here -- verify
// against the full file.
95 class MockFailingClientSocket
: public StreamSocket
{
97 MockFailingClientSocket(const AddressList
& addrlist
, net::NetLog
* net_log
)
98 : addrlist_(addrlist
),
99 net_log_(BoundNetLog::Make(net_log
, NetLog::SOURCE_SOCKET
)) {}
101 // StreamSocket implementation.
102 virtual int Connect(const CompletionCallback
& callback
) OVERRIDE
{
103 return ERR_CONNECTION_FAILED
;
106 virtual void Disconnect() OVERRIDE
{}
108 virtual bool IsConnected() const OVERRIDE
{ return false; }
109 virtual bool IsConnectedAndIdle() const OVERRIDE
{ return false; }
110 virtual int GetPeerAddress(IPEndPoint
* address
) const OVERRIDE
{
111 return ERR_UNEXPECTED
;
113 virtual int GetLocalAddress(IPEndPoint
* address
) const OVERRIDE
{
114 return ERR_UNEXPECTED
;
116 virtual const BoundNetLog
& NetLog() const OVERRIDE
{ return net_log_
; }
118 virtual void SetSubresourceSpeculation() OVERRIDE
{}
119 virtual void SetOmniboxSpeculation() OVERRIDE
{}
120 virtual bool WasEverUsed() const OVERRIDE
{ return false; }
121 virtual bool UsingTCPFastOpen() const OVERRIDE
{ return false; }
122 virtual bool WasNpnNegotiated() const OVERRIDE
{ return false; }
123 virtual NextProto
GetNegotiatedProtocol() const OVERRIDE
{
124 return kProtoUnknown
;
126 virtual bool GetSSLInfo(SSLInfo
* ssl_info
) OVERRIDE
{ return false; }
128 // Socket implementation.
129 virtual int Read(IOBuffer
* buf
,
131 const CompletionCallback
& callback
) OVERRIDE
{
135 virtual int Write(IOBuffer
* buf
,
137 const CompletionCallback
& callback
) OVERRIDE
{
140 virtual int SetReceiveBufferSize(int32 size
) OVERRIDE
{ return OK
; }
141 virtual int SetSendBufferSize(int32 size
) OVERRIDE
{ return OK
; }
144 const AddressList addrlist_
;
145 BoundNetLog net_log_
;
147 DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket
);
// A StreamSocket whose Connect() stores |callback| and returns ERR_IO_PENDING.
// The connect completes only when the closure returned by
// GetConnectCallback() runs. That closure binds DoCallback() through a weak
// pointer. The statements at original lines 255-256 (presumably DoCallback's
// body; its signature is not visible) set is_connected_ to should_connect_
// and run the stored callback with OK or ERR_CONNECTION_FAILED.
//
// Factory helpers:
//  - MakeMockPendingClientSocket() posts the trigger to the current
//    MessageLoop with PostTask() right after constructing the socket.
//  - MakeMockDelayedClientSocket() posts the trigger with PostDelayedTask()
//    and |delay|.
//  - MakeMockStalledClientSocket() never posts the trigger, so the connect
//    stays pending unless something else runs GetConnectCallback().
//
// NOTE(review): the trigger runs callback_ without checking or resetting it.
// Because the pending/delayed helpers post the trigger at construction, this
// relies on Connect() having been called before the posted task runs --
// confirm that callers guarantee this. Disconnect() is a no-op and does not
// clear is_connected_.
//
// NOTE(review): this listing is missing lines (the embedded original line
// numbers skip, e.g. 152->154, 172->174, 183->185, 251->255). These include
// the |should_connect| parameter declarations used by the constructor and
// the Make* helpers, the DoCallback() signature, some member declarations
// and the class's closing brace -- verify against the full file.
150 class MockTriggerableClientSocket
: public StreamSocket
{
152 // |should_connect| indicates whether the socket should successfully complete
// the connect when triggered (OK); otherwise the connect fails with
// ERR_CONNECTION_FAILED.
154 MockTriggerableClientSocket(const AddressList
& addrlist
,
156 net::NetLog
* net_log
)
157 : should_connect_(should_connect
),
158 is_connected_(false),
160 net_log_(BoundNetLog::Make(net_log
, NetLog::SOURCE_SOCKET
)),
161 weak_factory_(this) {}
163 // Call this method to get a closure which will trigger the connect callback
164 // when called. The closure can be called even after the socket is deleted; it
165 // will safely do nothing.
166 base::Closure
GetConnectCallback() {
167 return base::Bind(&MockTriggerableClientSocket::DoCallback
,
168 weak_factory_
.GetWeakPtr());
171 static scoped_ptr
<StreamSocket
> MakeMockPendingClientSocket(
172 const AddressList
& addrlist
,
174 net::NetLog
* net_log
) {
175 scoped_ptr
<MockTriggerableClientSocket
> socket(
176 new MockTriggerableClientSocket(addrlist
, should_connect
, net_log
));
177 base::MessageLoop::current()->PostTask(FROM_HERE
,
178 socket
->GetConnectCallback());
179 return socket
.PassAs
<StreamSocket
>();
182 static scoped_ptr
<StreamSocket
> MakeMockDelayedClientSocket(
183 const AddressList
& addrlist
,
185 const base::TimeDelta
& delay
,
186 net::NetLog
* net_log
) {
187 scoped_ptr
<MockTriggerableClientSocket
> socket(
188 new MockTriggerableClientSocket(addrlist
, should_connect
, net_log
));
189 base::MessageLoop::current()->PostDelayedTask(
190 FROM_HERE
, socket
->GetConnectCallback(), delay
);
191 return socket
.PassAs
<StreamSocket
>();
194 static scoped_ptr
<StreamSocket
> MakeMockStalledClientSocket(
195 const AddressList
& addrlist
,
196 net::NetLog
* net_log
) {
197 scoped_ptr
<MockTriggerableClientSocket
> socket(
198 new MockTriggerableClientSocket(addrlist
, true, net_log
));
199 return socket
.PassAs
<StreamSocket
>();
202 // StreamSocket implementation.
203 virtual int Connect(const CompletionCallback
& callback
) OVERRIDE
{
204 DCHECK(callback_
.is_null());
205 callback_
= callback
;
206 return ERR_IO_PENDING
;
209 virtual void Disconnect() OVERRIDE
{}
211 virtual bool IsConnected() const OVERRIDE
{ return is_connected_
; }
212 virtual bool IsConnectedAndIdle() const OVERRIDE
{ return is_connected_
; }
213 virtual int GetPeerAddress(IPEndPoint
* address
) const OVERRIDE
{
214 *address
= addrlist_
.front();
217 virtual int GetLocalAddress(IPEndPoint
* address
) const OVERRIDE
{
219 return ERR_SOCKET_NOT_CONNECTED
;
220 if (addrlist_
.front().GetFamily() == ADDRESS_FAMILY_IPV4
)
221 SetIPv4Address(address
);
223 SetIPv6Address(address
);
226 virtual const BoundNetLog
& NetLog() const OVERRIDE
{ return net_log_
; }
228 virtual void SetSubresourceSpeculation() OVERRIDE
{}
229 virtual void SetOmniboxSpeculation() OVERRIDE
{}
230 virtual bool WasEverUsed() const OVERRIDE
{ return false; }
231 virtual bool UsingTCPFastOpen() const OVERRIDE
{ return false; }
232 virtual bool WasNpnNegotiated() const OVERRIDE
{ return false; }
233 virtual NextProto
GetNegotiatedProtocol() const OVERRIDE
{
234 return kProtoUnknown
;
236 virtual bool GetSSLInfo(SSLInfo
* ssl_info
) OVERRIDE
{ return false; }
238 // Socket implementation.
239 virtual int Read(IOBuffer
* buf
,
241 const CompletionCallback
& callback
) OVERRIDE
{
245 virtual int Write(IOBuffer
* buf
,
247 const CompletionCallback
& callback
) OVERRIDE
{
250 virtual int SetReceiveBufferSize(int32 size
) OVERRIDE
{ return OK
; }
251 virtual int SetSendBufferSize(int32 size
) OVERRIDE
{ return OK
; }
255 is_connected_
= should_connect_
;
256 callback_
.Run(is_connected_
? OK
: ERR_CONNECTION_FAILED
);
259 bool should_connect_
;
261 const AddressList addrlist_
;
262 BoundNetLog net_log_
;
263 CompletionCallback callback_
;
265 base::WeakPtrFactory
<MockTriggerableClientSocket
> weak_factory_
;
267 DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket
);
// Checks the LoadTimingInfo that |handle| reports when asked for it as a
// reused socket:
//  - GetLoadTimingInfo(true, ...) must succeed;
//  - socket_reused must be true;
//  - socket_log_id must not be NetLog::Source::kInvalidId;
//  - connect_timing is passed to ExpectConnectTimingHasNoTimes();
//  - the whole info is passed to ExpectLoadTimingHasOnlyConnectionTimes().
// NOTE(review): original lines 283-284 (presumably the closing brace) are not
// visible in this listing.
272 void TestLoadTimingInfoConnectedReused(const ClientSocketHandle
& handle
) {
273 LoadTimingInfo load_timing_info
;
274 // Only pass true in as |is_reused|, as in general, HttpStream types should
275 // have stricter concepts of reuse than socket pools.
276 EXPECT_TRUE(handle
.GetLoadTimingInfo(true, &load_timing_info
));
278 EXPECT_TRUE(load_timing_info
.socket_reused
);
279 EXPECT_NE(NetLog::Source::kInvalidId
, load_timing_info
.socket_log_id
);
281 ExpectConnectTimingHasNoTimes(load_timing_info
.connect_timing
);
282 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info
);
// Checks the LoadTimingInfo of a freshly connected (not reused) |handle|:
//  - handle.is_reused() must be false;
//  - GetLoadTimingInfo(false, ...) must succeed, with socket_reused false and
//    a socket_log_id other than NetLog::Source::kInvalidId;
//  - connect_timing is checked with ExpectConnectTimingHasTimes(...,
//    CONNECT_TIMING_HAS_DNS_TIMES);
//  - the whole info is checked with ExpectLoadTimingHasOnlyConnectionTimes().
// Finally the same handle is re-checked with
// TestLoadTimingInfoConnectedReused(), which queries it with |is_reused| true.
// NOTE(review): the closing brace (original line 299) is not visible in this
// listing.
285 void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle
& handle
) {
286 EXPECT_FALSE(handle
.is_reused());
288 LoadTimingInfo load_timing_info
;
289 EXPECT_TRUE(handle
.GetLoadTimingInfo(false, &load_timing_info
));
291 EXPECT_FALSE(load_timing_info
.socket_reused
);
292 EXPECT_NE(NetLog::Source::kInvalidId
, load_timing_info
.socket_log_id
);
294 ExpectConnectTimingHasTimes(load_timing_info
.connect_timing
,
295 CONNECT_TIMING_HAS_DNS_TIMES
);
296 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info
);
298 TestLoadTimingInfoConnectedReused(handle
);
// Overwrites |address| with the fixed endpoint 1.1.1.1:80. The mock sockets'
// GetLocalAddress() uses it for IPv4 peers.
// NOTE(review): the closing brace (original line 303) is not visible in this
// listing.
301 void SetIPv4Address(IPEndPoint
* address
) {
302 *address
= IPEndPoint(ParseIP("1.1.1.1"), 80);
// Overwrites |address| with the fixed endpoint [1:abcd::3:4:ff]:80. The mock
// sockets' GetLocalAddress() uses it for non-IPv4 peers.
// NOTE(review): the closing brace (original line 307) is not visible in this
// listing.
305 void SetIPv6Address(IPEndPoint
* address
) {
306 *address
= IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
// The factory starts in this state:
//  - allocation_count_ is 0;
//  - MOCK_CLIENT_SOCKET is the default socket type;
//  - no per-call type list is installed and both list indices are 0;
//  - delay_ (used for the MOCK_DELAYED_* socket types) is
//    ClientSocketPool::kMaxConnectRetryIntervalMs milliseconds.
// NOTE(review): original lines 310-311 (the constructor's parameter list and
// the start of its initializer list) are not visible in this listing.
309 MockTransportClientSocketFactory::MockTransportClientSocketFactory(
312 allocation_count_(0),
313 client_socket_type_(MOCK_CLIENT_SOCKET
),
314 client_socket_types_(NULL
),
315 client_socket_index_(0),
316 client_socket_index_max_(0),
317 delay_(base::TimeDelta::FromMilliseconds(
318 ClientSocketPool::kMaxConnectRetryIntervalMs
)) {}
// Empty destructor; members are destroyed by their own destructors.
320 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {}
// This mock does not create datagram sockets: it always returns an empty
// scoped_ptr and ignores its arguments.
// NOTE(review): original lines 326 and 328 (a parameter, and whatever
// precedes the return) are not visible in this listing.
322 scoped_ptr
<DatagramClientSocket
>
323 MockTransportClientSocketFactory::CreateDatagramClientSocket(
324 DatagramSocket::BindType bind_type
,
325 const RandIntCallback
& rand_int_cb
,
327 const NetLog::Source
& source
) {
329 return scoped_ptr
<DatagramClientSocket
>();
// Creates a mock StreamSocket for |addresses|.
//
// Choosing the socket type: the type is client_socket_type_. If a list was
// installed with set_client_socket_types() and entries remain, the next
// entry is used instead.
//
// Mapping from type to socket:
//  - MOCK_CLIENT_SOCKET: MockConnectClientSocket.
//  - MOCK_FAILING_CLIENT_SOCKET: MockFailingClientSocket.
//  - MOCK_PENDING_* / MOCK_DELAYED_*: a MockTriggerableClientSocket that
//    succeeds or fails; the DELAYED variants wait delay_.
//  - MOCK_STALLED_CLIENT_SOCKET: a stalled MockTriggerableClientSocket.
//  - MOCK_TRIGGERABLE_CLIENT_SOCKET: the socket's connect trigger is pushed
//    onto triggerable_sockets_, and run_loop_quit_closure_ is run if set so
//    that a waiting WaitForTriggerableSocketCreation() wakes up.
//  - The final visible branch returns a MockConnectClientSocket.
//
// Every socket is created with the factory's net_log_; the |net_log| and
// |source| arguments are unused.
//
// NOTE(review): this listing is missing lines (the embedded original line
// numbers skip, e.g. 336->339, 341->345, 364->366, 372->374, 376->380). These
// include the switch header, the last argument of the
// MOCK_STALLED_CLIENT_SOCKET call, the end of the run_loop_quit_closure_
// comment and the label of the final (presumably default) branch -- verify
// against the full file.
332 scoped_ptr
<StreamSocket
>
333 MockTransportClientSocketFactory::CreateTransportClientSocket(
334 const AddressList
& addresses
,
335 NetLog
* /* net_log */,
336 const NetLog::Source
& /* source */) {
339 ClientSocketType type
= client_socket_type_
;
340 if (client_socket_types_
&& client_socket_index_
< client_socket_index_max_
) {
341 type
= client_socket_types_
[client_socket_index_
++];
345 case MOCK_CLIENT_SOCKET
:
346 return scoped_ptr
<StreamSocket
>(
347 new MockConnectClientSocket(addresses
, net_log_
));
348 case MOCK_FAILING_CLIENT_SOCKET
:
349 return scoped_ptr
<StreamSocket
>(
350 new MockFailingClientSocket(addresses
, net_log_
));
351 case MOCK_PENDING_CLIENT_SOCKET
:
352 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
353 addresses
, true, net_log_
);
354 case MOCK_PENDING_FAILING_CLIENT_SOCKET
:
355 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
356 addresses
, false, net_log_
);
357 case MOCK_DELAYED_CLIENT_SOCKET
:
358 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
359 addresses
, true, delay_
, net_log_
);
360 case MOCK_DELAYED_FAILING_CLIENT_SOCKET
:
361 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
362 addresses
, false, delay_
, net_log_
);
363 case MOCK_STALLED_CLIENT_SOCKET
:
364 return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses
,
366 case MOCK_TRIGGERABLE_CLIENT_SOCKET
: {
367 scoped_ptr
<MockTriggerableClientSocket
> rv(
368 new MockTriggerableClientSocket(addresses
, true, net_log_
));
369 triggerable_sockets_
.push(rv
->GetConnectCallback());
370 // run_loop_quit_closure_ behaves like a condition variable. It will
371 // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
372 // don't need to worry about atomicity because this code is
374 if (!run_loop_quit_closure_
.is_null())
375 run_loop_quit_closure_
.Run();
376 return rv
.PassAs
<StreamSocket
>();
380 return scoped_ptr
<StreamSocket
>(
381 new MockConnectClientSocket(addresses
, net_log_
));
// This mock does not create SSL sockets: it always returns an empty
// scoped_ptr. |transport_socket| is taken by value and not stored, so it is
// released when this function returns (assuming the usual scoped_ptr
// ownership semantics).
// NOTE(review): original line 391 (whatever precedes the return) is not
// visible in this listing.
385 scoped_ptr
<SSLClientSocket
>
386 MockTransportClientSocketFactory::CreateSSLClientSocket(
387 scoped_ptr
<ClientSocketHandle
> transport_socket
,
388 const HostPortPair
& host_and_port
,
389 const SSLConfig
& ssl_config
,
390 const SSLClientSocketContext
& context
) {
392 return scoped_ptr
<SSLClientSocket
>();
// NOTE(review): only the signature of this method is visible in this listing
// (original lines 396-398 are missing), so its behavior cannot be documented
// from here -- check the full file.
395 void MockTransportClientSocketFactory::ClearSSLSessionCache() {
// Installs |type_list| (|num_types| entries; DCHECKed to be > 0) as the
// sequence of socket types used by successive CreateTransportClientSocket()
// calls. Once the list is exhausted, client_socket_type_ is used again.
// Only the pointer is stored, not a copy, so |type_list| must outlive its use
// by this factory.
// NOTE(review): original line 401 (presumably the |num_types| parameter and
// the opening brace) and the closing brace are not visible in this listing.
399 void MockTransportClientSocketFactory::set_client_socket_types(
400 ClientSocketType
* type_list
,
402 DCHECK_GT(num_types
, 0);
403 client_socket_types_
= type_list
;
404 client_socket_index_
= 0;
405 client_socket_index_max_
= num_types
;
// Waits until triggerable_sockets_ is non-empty, then takes the front trigger
// closure and pops it.
// While waiting, each loop iteration:
//  1. creates a base::RunLoop;
//  2. publishes its QuitClosure in run_loop_quit_closure_, so that
//     CreateTransportClientSocket() can wake this loop;
//  3. presumably runs the RunLoop (original line 413 is not visible);
//  4. resets run_loop_quit_closure_.
// NOTE(review): the return type (original line 408), line 413 and the rest
// of the function after line 417 are not visible in this listing -- verify
// what is done with |trigger| in the full file.
409 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
410 while (triggerable_sockets_
.empty()) {
411 base::RunLoop run_loop
;
412 run_loop_quit_closure_
= run_loop
.QuitClosure();
414 run_loop_quit_closure_
.Reset();
416 base::Closure trigger
= triggerable_sockets_
.front();
417 triggerable_sockets_
.pop();