1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/transport_client_socket_pool_test_util.h"
9 #include "base/logging.h"
10 #include "base/memory/weak_ptr.h"
11 #include "base/run_loop.h"
12 #include "net/base/ip_endpoint.h"
13 #include "net/base/load_timing_info.h"
14 #include "net/base/load_timing_info_test_util.h"
15 #include "net/base/net_util.h"
16 #include "net/socket/client_socket_handle.h"
17 #include "net/socket/ssl_client_socket.h"
18 #include "net/udp/datagram_client_socket.h"
19 #include "testing/gtest/include/gtest/gtest.h"
// Parses the IP literal |ip| into an IPAddressNumber using
// ParseIPLiteralToNumber(). The parse is wrapped in CHECK, so callers are
// expected to pass only well-formed literals.
// NOTE(review): the return statement and closing brace are not visible in
// this listing -- presumably |number| is returned; confirm against the
// original file.
25 IPAddressNumber
ParseIP(const std::string
& ip
) {
26 IPAddressNumber number
;
27 CHECK(ParseIPLiteralToNumber(ip
, &number
));
31 // A StreamSocket which connects synchronously and successfully.
// NOTE(review): some lines of this class (access specifiers, part of the
// initializer list, some member declarations and several method bodies) are
// not visible in this listing. The comments below describe only what the
// visible code shows.
32 class MockConnectClientSocket
: public StreamSocket
{
// Creates the socket with a BoundNetLog of source type SOURCE_SOCKET made
// from |net_log|, and with TCP FastOpen initially disabled.
34 MockConnectClientSocket(const AddressList
& addrlist
, net::NetLog
* net_log
)
37 net_log_(BoundNetLog::Make(net_log
, NetLog::SOURCE_SOCKET
)),
38 use_tcp_fastopen_(false) {}
40 // StreamSocket implementation.
// NOTE(review): the body of Connect() is not visible in this listing.
41 int Connect(const CompletionCallback
& callback
) override
{
// Connection state is tracked by |connected_|: Disconnect() clears it and
// both IsConnected() and IsConnectedAndIdle() simply report it.
45 void Disconnect() override
{ connected_
= false; }
46 bool IsConnected() const override
{ return connected_
; }
47 bool IsConnectedAndIdle() const override
{ return connected_
; }
// Reports the first entry of |addrlist_| as the peer address.
49 int GetPeerAddress(IPEndPoint
* address
) const override
{
50 *address
= addrlist_
.front();
// Fills |address| with a canned local address whose family matches the first
// entry of |addrlist_| (IPv4 via SetIPv4Address(), otherwise IPv6 via
// SetIPv6Address()).
// NOTE(review): the condition guarding the ERR_SOCKET_NOT_CONNECTED return
// is not visible here -- presumably it checks |connected_|; confirm.
53 int GetLocalAddress(IPEndPoint
* address
) const override
{
55 return ERR_SOCKET_NOT_CONNECTED
;
56 if (addrlist_
.front().GetFamily() == ADDRESS_FAMILY_IPV4
)
57 SetIPv4Address(address
);
59 SetIPv6Address(address
);
62 const BoundNetLog
& NetLog() const override
{ return net_log_
; }
// The speculation setters are no-ops, and the socket always reports itself
// as never used.
64 void SetSubresourceSpeculation() override
{}
65 void SetOmniboxSpeculation() override
{}
66 bool WasEverUsed() const override
{ return false; }
// TCP FastOpen is modelled as a plain flag: enabling sets it and
// UsingTCPFastOpen() reads it back.
67 void EnableTCPFastOpenIfSupported() override
{ use_tcp_fastopen_
= true; }
68 bool UsingTCPFastOpen() const override
{ return use_tcp_fastopen_
; }
// No NPN/SSL state is ever reported by this mock.
69 bool WasNpnNegotiated() const override
{ return false; }
70 NextProto
GetNegotiatedProtocol() const override
{ return kProtoUnknown
; }
71 bool GetSSLInfo(SSLInfo
* ssl_info
) override
{ return false; }
73 // Socket implementation.
// NOTE(review): the remaining parameters and the bodies of Read() and
// Write() are not visible in this listing.
74 int Read(IOBuffer
* buf
,
76 const CompletionCallback
& callback
) override
{
79 int Write(IOBuffer
* buf
,
81 const CompletionCallback
& callback
) override
{
// Buffer-size requests are accepted and ignored; both return OK.
84 int SetReceiveBufferSize(int32 size
) override
{ return OK
; }
85 int SetSendBufferSize(int32 size
) override
{ return OK
; }
// Addresses handed to the constructor; the first entry is used for
// GetPeerAddress() and to pick the local address family.
89 const AddressList addrlist_
;
// Set once EnableTCPFastOpenIfSupported() has been called.
91 bool use_tcp_fastopen_
;
93 DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket
);
// A StreamSocket whose Connect() fails synchronously with
// ERR_CONNECTION_FAILED and which never reports itself as connected.
// NOTE(review): some lines of this class (access specifiers and the bodies
// of Read()/Write()) are not visible in this listing.
96 class MockFailingClientSocket
: public StreamSocket
{
// Stores |addrlist|, creates a BoundNetLog of source type SOURCE_SOCKET from
// |net_log|, and starts with TCP FastOpen disabled.
98 MockFailingClientSocket(const AddressList
& addrlist
, net::NetLog
* net_log
)
99 : addrlist_(addrlist
),
100 net_log_(BoundNetLog::Make(net_log
, NetLog::SOURCE_SOCKET
)),
101 use_tcp_fastopen_(false) {}
103 // StreamSocket implementation.
// Always fails immediately; |callback| is not used by the visible code.
104 int Connect(const CompletionCallback
& callback
) override
{
105 return ERR_CONNECTION_FAILED
;
// There is no connection state to tear down.
108 void Disconnect() override
{}
110 bool IsConnected() const override
{ return false; }
111 bool IsConnectedAndIdle() const override
{ return false; }
// Address queries are answered with ERR_UNEXPECTED; |address| is left
// untouched by the visible code.
112 int GetPeerAddress(IPEndPoint
* address
) const override
{
113 return ERR_UNEXPECTED
;
115 int GetLocalAddress(IPEndPoint
* address
) const override
{
116 return ERR_UNEXPECTED
;
118 const BoundNetLog
& NetLog() const override
{ return net_log_
; }
// The speculation setters are no-ops, and the socket always reports itself
// as never used.
120 void SetSubresourceSpeculation() override
{}
121 void SetOmniboxSpeculation() override
{}
122 bool WasEverUsed() const override
{ return false; }
// TCP FastOpen is modelled as a plain flag.
123 void EnableTCPFastOpenIfSupported() override
{ use_tcp_fastopen_
= true; }
124 bool UsingTCPFastOpen() const override
{ return use_tcp_fastopen_
; }
// No NPN/SSL state is ever reported by this mock.
125 bool WasNpnNegotiated() const override
{ return false; }
126 NextProto
GetNegotiatedProtocol() const override
{ return kProtoUnknown
; }
127 bool GetSSLInfo(SSLInfo
* ssl_info
) override
{ return false; }
129 // Socket implementation.
// NOTE(review): the remaining parameters and the bodies of Read() and
// Write() are not visible in this listing.
130 int Read(IOBuffer
* buf
,
132 const CompletionCallback
& callback
) override
{
136 int Write(IOBuffer
* buf
,
138 const CompletionCallback
& callback
) override
{
// Buffer-size requests are accepted and ignored; both return OK.
141 int SetReceiveBufferSize(int32 size
) override
{ return OK
; }
142 int SetSendBufferSize(int32 size
) override
{ return OK
; }
// Addresses handed to the constructor (not read by any visible method).
145 const AddressList addrlist_
;
// Returned by NetLog().
146 BoundNetLog net_log_
;
// Set once EnableTCPFastOpenIfSupported() has been called.
147 bool use_tcp_fastopen_
;
149 DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket
);
// A StreamSocket whose Connect() stores the caller's callback and returns
// ERR_IO_PENDING. The connect result is delivered later, when the closure
// obtained from GetConnectCallback() is run.
// NOTE(review): some lines of this class (access specifiers, the
// |should_connect| parameters, a few member declarations and several method
// bodies) are not visible in this listing. The comments below describe only
// what the visible code shows.
152 class MockTriggerableClientSocket
: public StreamSocket
{
154 // |should_connect| indicates whether the socket should successfully complete
// NOTE(review): the rest of the sentence above is missing from this listing.
// Per the visible code, |should_connect| is copied into |should_connect_|,
// which later selects between OK and ERR_CONNECTION_FAILED.
156 MockTriggerableClientSocket(const AddressList
& addrlist
,
158 net::NetLog
* net_log
)
159 : should_connect_(should_connect
),
160 is_connected_(false),
162 net_log_(BoundNetLog::Make(net_log
, NetLog::SOURCE_SOCKET
)),
163 use_tcp_fastopen_(false),
164 weak_factory_(this) {}
166 // Call this method to get a closure which will trigger the connect callback
167 // when called. The closure can be called even after the socket is deleted; it
168 // will safely do nothing.
// The closure is bound to DoCallback through a weak pointer obtained from
// |weak_factory_|.
169 base::Closure
GetConnectCallback() {
170 return base::Bind(&MockTriggerableClientSocket::DoCallback
,
171 weak_factory_
.GetWeakPtr());
// Creates a socket and posts its connect closure to the current MessageLoop
// as an ordinary (non-delayed) task.
174 static scoped_ptr
<StreamSocket
> MakeMockPendingClientSocket(
175 const AddressList
& addrlist
,
177 net::NetLog
* net_log
) {
178 scoped_ptr
<MockTriggerableClientSocket
> socket(
179 new MockTriggerableClientSocket(addrlist
, should_connect
, net_log
));
180 base::MessageLoop::current()->PostTask(FROM_HERE
,
181 socket
->GetConnectCallback());
182 return socket
.Pass();
// Same as MakeMockPendingClientSocket(), except that the connect closure is
// posted as a delayed task that fires after |delay|.
185 static scoped_ptr
<StreamSocket
> MakeMockDelayedClientSocket(
186 const AddressList
& addrlist
,
188 const base::TimeDelta
& delay
,
189 net::NetLog
* net_log
) {
190 scoped_ptr
<MockTriggerableClientSocket
> socket(
191 new MockTriggerableClientSocket(addrlist
, should_connect
, net_log
));
192 base::MessageLoop::current()->PostDelayedTask(
193 FROM_HERE
, socket
->GetConnectCallback(), delay
);
194 return socket
.Pass();
// Creates a socket (with should_connect = true) but posts no task, so
// nothing in this helper ever triggers the connect callback.
197 static scoped_ptr
<StreamSocket
> MakeMockStalledClientSocket(
198 const AddressList
& addrlist
,
199 net::NetLog
* net_log
) {
200 scoped_ptr
<MockTriggerableClientSocket
> socket(
201 new MockTriggerableClientSocket(addrlist
, true, net_log
));
202 return socket
.Pass();
205 // StreamSocket implementation.
// Remembers |callback| for later completion and reports ERR_IO_PENDING. The
// DCHECK asserts that no callback has been stored yet.
206 int Connect(const CompletionCallback
& callback
) override
{
207 DCHECK(callback_
.is_null());
208 callback_
= callback
;
209 return ERR_IO_PENDING
;
// Disconnect() does nothing; in particular it does not clear
// |is_connected_|.
212 void Disconnect() override
{}
214 bool IsConnected() const override
{ return is_connected_
; }
215 bool IsConnectedAndIdle() const override
{ return is_connected_
; }
// Reports the first entry of |addrlist_| as the peer address.
216 int GetPeerAddress(IPEndPoint
* address
) const override
{
217 *address
= addrlist_
.front();
// Fills |address| with a canned local address whose family matches the first
// entry of |addrlist_| (IPv4 via SetIPv4Address(), otherwise IPv6 via
// SetIPv6Address()).
// NOTE(review): the condition guarding the ERR_SOCKET_NOT_CONNECTED return
// is not visible here -- presumably it checks |is_connected_|; confirm.
220 int GetLocalAddress(IPEndPoint
* address
) const override
{
222 return ERR_SOCKET_NOT_CONNECTED
;
223 if (addrlist_
.front().GetFamily() == ADDRESS_FAMILY_IPV4
)
224 SetIPv4Address(address
);
226 SetIPv6Address(address
);
229 const BoundNetLog
& NetLog() const override
{ return net_log_
; }
// The speculation setters are no-ops, and the socket always reports itself
// as never used.
231 void SetSubresourceSpeculation() override
{}
232 void SetOmniboxSpeculation() override
{}
233 bool WasEverUsed() const override
{ return false; }
// TCP FastOpen is modelled as a plain flag.
234 void EnableTCPFastOpenIfSupported() override
{ use_tcp_fastopen_
= true; }
235 bool UsingTCPFastOpen() const override
{ return use_tcp_fastopen_
; }
// No NPN/SSL state is ever reported by this mock.
236 bool WasNpnNegotiated() const override
{ return false; }
237 NextProto
GetNegotiatedProtocol() const override
{ return kProtoUnknown
; }
238 bool GetSSLInfo(SSLInfo
* ssl_info
) override
{ return false; }
240 // Socket implementation.
// NOTE(review): the remaining parameters and the bodies of Read() and
// Write() are not visible in this listing.
241 int Read(IOBuffer
* buf
,
243 const CompletionCallback
& callback
) override
{
247 int Write(IOBuffer
* buf
,
249 const CompletionCallback
& callback
) override
{
// Buffer-size requests are accepted and ignored; both return OK.
252 int SetReceiveBufferSize(int32 size
) override
{ return OK
; }
253 int SetSendBufferSize(int32 size
) override
{ return OK
; }
// NOTE(review): the signature enclosing the next two statements is not
// visible -- presumably this is the DoCallback method bound by
// GetConnectCallback(); confirm. The statements record the connect outcome
// in |is_connected_| and then run the stored callback with OK on success or
// ERR_CONNECTION_FAILED on failure.
257 is_connected_
= should_connect_
;
258 callback_
.Run(is_connected_
? OK
: ERR_CONNECTION_FAILED
);
// Outcome to report once the connect closure runs.
261 bool should_connect_
;
// Addresses handed to the constructor; the first entry is used for
// GetPeerAddress() and to pick the local address family.
263 const AddressList addrlist_
;
264 BoundNetLog net_log_
;
// Callback captured by Connect() and run when the connect closure fires.
265 CompletionCallback callback_
;
266 bool use_tcp_fastopen_
;
// Source of the weak pointers bound into the closures returned by
// GetConnectCallback().
268 base::WeakPtrFactory
<MockTriggerableClientSocket
> weak_factory_
;
270 DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket
);
275 void TestLoadTimingInfoConnectedReused(const ClientSocketHandle
& handle
) {
276 LoadTimingInfo load_timing_info
;
277 // Only pass true in as |is_reused|, as in general, HttpStream types should
278 // have stricter concepts of reuse than socket pools.
279 EXPECT_TRUE(handle
.GetLoadTimingInfo(true, &load_timing_info
));
281 EXPECT_TRUE(load_timing_info
.socket_reused
);
282 EXPECT_NE(NetLog::Source::kInvalidId
, load_timing_info
.socket_log_id
);
284 ExpectConnectTimingHasNoTimes(load_timing_info
.connect_timing
);
285 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info
);
288 void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle
& handle
) {
289 EXPECT_FALSE(handle
.is_reused());
291 LoadTimingInfo load_timing_info
;
292 EXPECT_TRUE(handle
.GetLoadTimingInfo(false, &load_timing_info
));
294 EXPECT_FALSE(load_timing_info
.socket_reused
);
295 EXPECT_NE(NetLog::Source::kInvalidId
, load_timing_info
.socket_log_id
);
297 ExpectConnectTimingHasTimes(load_timing_info
.connect_timing
,
298 CONNECT_TIMING_HAS_DNS_TIMES
);
299 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info
);
301 TestLoadTimingInfoConnectedReused(handle
);
304 void SetIPv4Address(IPEndPoint
* address
) {
305 *address
= IPEndPoint(ParseIP("1.1.1.1"), 80);
308 void SetIPv6Address(IPEndPoint
* address
) {
309 *address
= IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
// Constructs the factory in its default state: zero allocations so far,
// MOCK_CLIENT_SOCKET as the default socket type, no per-call type list
// (NULL list, index and limit both 0), and a delay of
// ClientSocketPool::kMaxConnectRetryIntervalMs milliseconds.
// NOTE(review): the parameter list and the first initializer are not
// visible in this listing.
312 MockTransportClientSocketFactory::MockTransportClientSocketFactory(
315 allocation_count_(0),
316 client_socket_type_(MOCK_CLIENT_SOCKET
),
317 client_socket_types_(NULL
),
318 client_socket_index_(0),
319 client_socket_index_max_(0),
320 delay_(base::TimeDelta::FromMilliseconds(
321 ClientSocketPool::kMaxConnectRetryIntervalMs
)) {}
// The destructor body is empty; no explicit cleanup is performed here.
323 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {}
// Datagram sockets are not provided by this mock factory: the visible code
// returns a null scoped_ptr regardless of the arguments.
// NOTE(review): one parameter line and part of the body are not visible in
// this listing.
325 scoped_ptr
<DatagramClientSocket
>
326 MockTransportClientSocketFactory::CreateDatagramClientSocket(
327 DatagramSocket::BindType bind_type
,
328 const RandIntCallback
& rand_int_cb
,
330 const NetLog::Source
& source
) {
332 return scoped_ptr
<DatagramClientSocket
>();
// Creates a mock stream socket for |addresses|. The kind of mock is chosen
// per call: |client_socket_type_| by default, or the next entry of
// |client_socket_types_| while entries remain. The NetLog and source
// arguments are unnamed and unused; sockets are given the factory's own
// |net_log_| instead.
// NOTE(review): several lines (the switch header, the tail of the
// MOCK_STALLED_CLIENT_SOCKET case, the return of the triggerable case and
// the context of the final return) are not visible in this listing.
335 scoped_ptr
<StreamSocket
>
336 MockTransportClientSocketFactory::CreateTransportClientSocket(
337 const AddressList
& addresses
,
338 NetLog
* /* net_log */,
339 const NetLog::Source
& /* source */) {
// Start from the default type; if a type list is installed and not yet
// exhausted, consume its next entry instead.
342 ClientSocketType type
= client_socket_type_
;
343 if (client_socket_types_
&& client_socket_index_
< client_socket_index_max_
) {
344 type
= client_socket_types_
[client_socket_index_
++];
// Socket that connects synchronously and successfully.
348 case MOCK_CLIENT_SOCKET
:
349 return scoped_ptr
<StreamSocket
>(
350 new MockConnectClientSocket(addresses
, net_log_
));
// Socket whose Connect() fails synchronously.
351 case MOCK_FAILING_CLIENT_SOCKET
:
352 return scoped_ptr
<StreamSocket
>(
353 new MockFailingClientSocket(addresses
, net_log_
));
// Pending sockets: the connect closure is posted as a task; the boolean
// argument selects success (true) or failure (false).
354 case MOCK_PENDING_CLIENT_SOCKET
:
355 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
356 addresses
, true, net_log_
);
357 case MOCK_PENDING_FAILING_CLIENT_SOCKET
:
358 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
359 addresses
, false, net_log_
);
// Delayed sockets: as above, but the closure is posted with |delay_|.
360 case MOCK_DELAYED_CLIENT_SOCKET
:
361 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
362 addresses
, true, delay_
, net_log_
);
363 case MOCK_DELAYED_FAILING_CLIENT_SOCKET
:
364 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
365 addresses
, false, delay_
, net_log_
);
// Stalled socket: created without posting any connect task.
366 case MOCK_STALLED_CLIENT_SOCKET
:
367 return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses
,
// Triggerable socket: its connect closure is queued on
// |triggerable_sockets_| rather than being posted as a task.
369 case MOCK_TRIGGERABLE_CLIENT_SOCKET
: {
370 scoped_ptr
<MockTriggerableClientSocket
> rv(
371 new MockTriggerableClientSocket(addresses
, true, net_log_
));
372 triggerable_sockets_
.push(rv
->GetConnectCallback());
373 // run_loop_quit_closure_ behaves like a condition variable. It will
374 // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
375 // don't need to worry about atomicity because this code is
// NOTE(review): the end of the sentence above is missing from this listing.
377 if (!run_loop_quit_closure_
.is_null())
378 run_loop_quit_closure_
.Run();
// NOTE(review): the context of this final return is not visible --
// presumably the default/unhandled case; confirm. It produces the same
// synchronously-connecting mock as MOCK_CLIENT_SOCKET.
383 return scoped_ptr
<StreamSocket
>(
384 new MockConnectClientSocket(addresses
, net_log_
));
// SSL sockets are not provided by this mock factory: the visible code
// returns a null scoped_ptr regardless of the arguments.
// NOTE(review): part of the body is not visible in this listing.
388 scoped_ptr
<SSLClientSocket
>
389 MockTransportClientSocketFactory::CreateSSLClientSocket(
390 scoped_ptr
<ClientSocketHandle
> transport_socket
,
391 const HostPortPair
& host_and_port
,
392 const SSLConfig
& ssl_config
,
393 const SSLClientSocketContext
& context
) {
395 return scoped_ptr
<SSLClientSocket
>();
// NOTE(review): the body of this method is not visible in this listing, so
// its behavior cannot be described from here; confirm against the original
// file.
398 void MockTransportClientSocketFactory::ClearSSLSessionCache() {
// Installs |type_list| as the per-call sequence of socket types: the list
// pointer is stored as-is (not copied), the read index is reset to 0, and
// |num_types| becomes the index limit. |num_types| is DCHECKed to be
// positive.
// NOTE(review): the line declaring the |num_types| parameter is not visible
// in this listing.
402 void MockTransportClientSocketFactory::set_client_socket_types(
403 ClientSocketType
* type_list
,
405 DCHECK_GT(num_types
, 0);
406 client_socket_types_
= type_list
;
407 client_socket_index_
= 0;
408 client_socket_index_max_
= num_types
;
412 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
413 while (triggerable_sockets_
.empty()) {
414 base::RunLoop run_loop
;
415 run_loop_quit_closure_
= run_loop
.QuitClosure();
417 run_loop_quit_closure_
.Reset();
419 base::Closure trigger
= triggerable_sockets_
.front();
420 triggerable_sockets_
.pop();