prune resources in MemoryCache
[chromium-blink-merge.git] / net / socket / socket_test_util.cc
blob308de2eaf6fe4de73868d8cec17f84748a570b89
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/socket_test_util.h"
7 #include <algorithm>
8 #include <vector>
10 #include "base/basictypes.h"
11 #include "base/bind.h"
12 #include "base/bind_helpers.h"
13 #include "base/callback_helpers.h"
14 #include "base/compiler_specific.h"
15 #include "base/message_loop/message_loop.h"
16 #include "base/run_loop.h"
17 #include "base/time/time.h"
18 #include "net/base/address_family.h"
19 #include "net/base/address_list.h"
20 #include "net/base/auth.h"
21 #include "net/base/load_timing_info.h"
22 #include "net/http/http_network_session.h"
23 #include "net/http/http_request_headers.h"
24 #include "net/http/http_response_headers.h"
25 #include "net/socket/client_socket_pool_histograms.h"
26 #include "net/socket/socket.h"
27 #include "net/ssl/ssl_cert_request_info.h"
28 #include "net/ssl/ssl_connection_status_flags.h"
29 #include "net/ssl/ssl_info.h"
30 #include "testing/gtest/include/gtest/gtest.h"
32 // Socket events are easier to debug if you log individual reads and writes.
33 // Enable these if locally debugging, but they are too noisy for the waterfall.
34 #if 0
35 #define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() "
36 #else
37 #define NET_TRACE(level, s) EAT_STREAM_PARAMETERS
38 #endif
40 namespace net {
42 namespace {
// Returns the uppercase hexadecimal digit for the high four bits of |x|.
inline char AsciifyHigh(char x) {
  const char high_bits = static_cast<char>((x >> 4) & 0x0F);
  if (high_bits < 0x0A)
    return high_bits + '0';
  return high_bits + ('A' - 10);
}
// Returns the uppercase hexadecimal digit for the low four bits of |x|.
inline char AsciifyLow(char x) {
  const char low_bits = static_cast<char>((x >> 0) & 0x0F);
  if (low_bits < 0x0A)
    return low_bits + '0';
  return low_bits + ('A' - 10);
}
// Returns |x| unchanged when it is a non-negative printable character, and
// '.' otherwise.
inline char Asciify(char x) {
  const bool printable = (x >= 0) && isprint(x);
  return printable ? x : '.';
}
60 void DumpData(const char* data, int data_len) {
61 if (logging::LOG_INFO < logging::GetMinLogLevel())
62 return;
63 DVLOG(1) << "Length: " << data_len;
64 const char* pfx = "Data: ";
65 if (!data || (data_len <= 0)) {
66 DVLOG(1) << pfx << "<None>";
67 } else {
68 int i;
69 for (i = 0; i <= (data_len - 4); i += 4) {
70 DVLOG(1) << pfx
71 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
72 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
73 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
74 << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
75 << " '"
76 << Asciify(data[i + 0])
77 << Asciify(data[i + 1])
78 << Asciify(data[i + 2])
79 << Asciify(data[i + 3])
80 << "'";
81 pfx = " ";
83 // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
84 switch (data_len - i) {
85 case 3:
86 DVLOG(1) << pfx
87 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
88 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
89 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
90 << " '"
91 << Asciify(data[i + 0])
92 << Asciify(data[i + 1])
93 << Asciify(data[i + 2])
94 << " '";
95 break;
96 case 2:
97 DVLOG(1) << pfx
98 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
99 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
100 << " '"
101 << Asciify(data[i + 0])
102 << Asciify(data[i + 1])
103 << " '";
104 break;
105 case 1:
106 DVLOG(1) << pfx
107 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
108 << " '"
109 << Asciify(data[i + 0])
110 << " '";
111 break;
116 template <MockReadWriteType type>
117 void DumpMockReadWrite(const MockReadWrite<type>& r) {
118 if (logging::LOG_INFO < logging::GetMinLogLevel())
119 return;
120 DVLOG(1) << "Async: " << (r.mode == ASYNC)
121 << "\nResult: " << r.result;
122 DumpData(r.data, r.data_len);
123 const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
124 DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop
125 << "\nTime: " << r.time_stamp.ToInternalValue();
128 } // namespace
130 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
131 IPAddressNumber ip;
132 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
133 peer_addr = IPEndPoint(ip, 0);
136 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
137 IPAddressNumber ip;
138 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
139 peer_addr = IPEndPoint(ip, 0);
142 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) :
143 mode(io_mode),
144 result(r),
145 peer_addr(addr) {
148 MockConnect::~MockConnect() {}
150 StaticSocketDataProvider::StaticSocketDataProvider()
151 : reads_(NULL),
152 read_index_(0),
153 read_count_(0),
154 writes_(NULL),
155 write_index_(0),
156 write_count_(0) {
159 StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads,
160 size_t reads_count,
161 MockWrite* writes,
162 size_t writes_count)
163 : reads_(reads),
164 read_index_(0),
165 read_count_(reads_count),
166 writes_(writes),
167 write_index_(0),
168 write_count_(writes_count) {
171 StaticSocketDataProvider::~StaticSocketDataProvider() {}
173 const MockRead& StaticSocketDataProvider::PeekRead() const {
174 CHECK(!at_read_eof());
175 return reads_[read_index_];
178 const MockWrite& StaticSocketDataProvider::PeekWrite() const {
179 CHECK(!at_write_eof());
180 return writes_[write_index_];
183 const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const {
184 CHECK_LT(index, read_count_);
185 return reads_[index];
188 const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const {
189 CHECK_LT(index, write_count_);
190 return writes_[index];
193 MockRead StaticSocketDataProvider::GetNextRead() {
194 CHECK(!at_read_eof());
195 reads_[read_index_].time_stamp = base::Time::Now();
196 return reads_[read_index_++];
199 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
200 if (!writes_) {
201 // Not using mock writes; succeed synchronously.
202 return MockWriteResult(SYNCHRONOUS, data.length());
204 EXPECT_FALSE(at_write_eof());
205 if (at_write_eof()) {
206 // Show what the extra write actually consists of.
207 EXPECT_EQ("<unexpected write>", data);
208 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
211 // Check that what we are writing matches the expectation.
212 // Then give the mocked return value.
213 MockWrite* w = &writes_[write_index_++];
214 w->time_stamp = base::Time::Now();
215 int result = w->result;
216 if (w->data) {
217 // Note - we can simulate a partial write here. If the expected data
218 // is a match, but shorter than the write actually written, that is legal.
219 // Example:
220 // Application writes "foobarbaz" (9 bytes)
221 // Expected write was "foo" (3 bytes)
222 // This is a success, and we return 3 to the application.
223 std::string expected_data(w->data, w->data_len);
224 EXPECT_GE(data.length(), expected_data.length());
225 std::string actual_data(data.substr(0, w->data_len));
226 EXPECT_EQ(expected_data, actual_data);
227 if (expected_data != actual_data)
228 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
229 if (result == OK)
230 result = w->data_len;
232 return MockWriteResult(w->mode, result);
235 void StaticSocketDataProvider::Reset() {
236 read_index_ = 0;
237 write_index_ = 0;
240 DynamicSocketDataProvider::DynamicSocketDataProvider()
241 : short_read_limit_(0),
242 allow_unconsumed_reads_(false) {
245 DynamicSocketDataProvider::~DynamicSocketDataProvider() {}
247 MockRead DynamicSocketDataProvider::GetNextRead() {
248 if (reads_.empty())
249 return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
250 MockRead result = reads_.front();
251 if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
252 reads_.pop_front();
253 } else {
254 result.data_len = short_read_limit_;
255 reads_.front().data += result.data_len;
256 reads_.front().data_len -= result.data_len;
258 return result;
261 void DynamicSocketDataProvider::Reset() {
262 reads_.clear();
265 void DynamicSocketDataProvider::SimulateRead(const char* data,
266 const size_t length) {
267 if (!allow_unconsumed_reads_) {
268 EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
270 reads_.push_back(MockRead(ASYNC, data, length));
273 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
274 : connect(mode, result),
275 next_proto_status(SSLClientSocket::kNextProtoUnsupported),
276 was_npn_negotiated(false),
277 protocol_negotiated(kProtoUnknown),
278 client_cert_sent(false),
279 cert_request_info(NULL),
280 channel_id_sent(false),
281 connection_status(0),
282 should_pause_on_connect(false),
283 is_in_session_cache(false) {
284 SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2,
285 &connection_status);
286 // Set to TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305
287 SSLConnectionStatusSetCipherSuite(0xcc14, &connection_status);
290 SSLSocketDataProvider::~SSLSocketDataProvider() {
293 void SSLSocketDataProvider::SetNextProto(NextProto proto) {
294 was_npn_negotiated = true;
295 next_proto_status = SSLClientSocket::kNextProtoNegotiated;
296 protocol_negotiated = proto;
297 next_proto = SSLClientSocket::NextProtoToString(proto);
300 DelayedSocketData::DelayedSocketData(
301 int write_delay, MockRead* reads, size_t reads_count,
302 MockWrite* writes, size_t writes_count)
303 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
304 write_delay_(write_delay),
305 read_in_progress_(false),
306 weak_factory_(this) {
307 DCHECK_GE(write_delay_, 0);
310 DelayedSocketData::DelayedSocketData(
311 const MockConnect& connect, int write_delay, MockRead* reads,
312 size_t reads_count, MockWrite* writes, size_t writes_count)
313 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
314 write_delay_(write_delay),
315 read_in_progress_(false),
316 weak_factory_(this) {
317 DCHECK_GE(write_delay_, 0);
318 set_connect_data(connect);
321 DelayedSocketData::~DelayedSocketData() {
324 void DelayedSocketData::ForceNextRead() {
325 DCHECK(read_in_progress_);
326 write_delay_ = 0;
327 CompleteRead();
330 MockRead DelayedSocketData::GetNextRead() {
331 MockRead out = MockRead(ASYNC, ERR_IO_PENDING);
332 if (write_delay_ <= 0)
333 out = StaticSocketDataProvider::GetNextRead();
334 read_in_progress_ = (out.result == ERR_IO_PENDING);
335 return out;
338 MockWriteResult DelayedSocketData::OnWrite(const std::string& data) {
339 MockWriteResult rv = StaticSocketDataProvider::OnWrite(data);
340 // Now that our write has completed, we can allow reads to continue.
341 if (!--write_delay_ && read_in_progress_)
342 base::MessageLoop::current()->PostDelayedTask(
343 FROM_HERE,
344 base::Bind(&DelayedSocketData::CompleteRead,
345 weak_factory_.GetWeakPtr()),
346 base::TimeDelta::FromMilliseconds(100));
347 return rv;
350 void DelayedSocketData::Reset() {
351 set_socket(NULL);
352 read_in_progress_ = false;
353 weak_factory_.InvalidateWeakPtrs();
354 StaticSocketDataProvider::Reset();
357 void DelayedSocketData::CompleteRead() {
358 if (socket() && read_in_progress_)
359 socket()->OnReadComplete(GetNextRead());
362 OrderedSocketData::OrderedSocketData(
363 MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count)
364 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
365 sequence_number_(0), loop_stop_stage_(0),
366 blocked_(false), weak_factory_(this) {
369 OrderedSocketData::OrderedSocketData(
370 const MockConnect& connect,
371 MockRead* reads, size_t reads_count,
372 MockWrite* writes, size_t writes_count)
373 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
374 sequence_number_(0), loop_stop_stage_(0),
375 blocked_(false), weak_factory_(this) {
376 set_connect_data(connect);
379 void OrderedSocketData::EndLoop() {
380 // If we've already stopped the loop, don't do it again until we've advanced
381 // to the next sequence_number.
382 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()";
383 if (loop_stop_stage_ > 0) {
384 const MockRead& next_read = StaticSocketDataProvider::PeekRead();
385 if ((next_read.sequence_number & ~MockRead::STOPLOOP) >
386 loop_stop_stage_) {
387 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
388 << ": Clearing stop index";
389 loop_stop_stage_ = 0;
390 } else {
391 return;
394 // Record the sequence_number at which we stopped the loop.
395 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
396 << ": Posting Quit at read " << read_index();
397 loop_stop_stage_ = sequence_number_;
400 MockRead OrderedSocketData::GetNextRead() {
401 weak_factory_.InvalidateWeakPtrs();
402 blocked_ = false;
403 const MockRead& next_read = StaticSocketDataProvider::PeekRead();
404 if (next_read.sequence_number & MockRead::STOPLOOP)
405 EndLoop();
406 if ((next_read.sequence_number & ~MockRead::STOPLOOP) <=
407 sequence_number_++) {
408 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
409 << ": Read " << read_index();
410 DumpMockReadWrite(next_read);
411 blocked_ = (next_read.result == ERR_IO_PENDING);
412 return StaticSocketDataProvider::GetNextRead();
414 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
415 << ": I/O Pending";
416 MockRead result = MockRead(ASYNC, ERR_IO_PENDING);
417 DumpMockReadWrite(result);
418 blocked_ = true;
419 return result;
422 MockWriteResult OrderedSocketData::OnWrite(const std::string& data) {
423 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
424 << ": Write " << write_index();
425 DumpMockReadWrite(PeekWrite());
426 ++sequence_number_;
427 if (blocked_) {
428 // TODO(willchan): This 100ms delay seems to work around some weirdness. We
429 // should probably fix the weirdness. One example is in SpdyStream,
430 // DoSendRequest() will return ERR_IO_PENDING, and there's a race. If the
431 // SYN_REPLY causes OnResponseReceived() to get called before
432 // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED().
433 base::MessageLoop::current()->PostDelayedTask(
434 FROM_HERE,
435 base::Bind(&OrderedSocketData::CompleteRead,
436 weak_factory_.GetWeakPtr()),
437 base::TimeDelta::FromMilliseconds(100));
439 return StaticSocketDataProvider::OnWrite(data);
442 void OrderedSocketData::Reset() {
443 NET_TRACE(INFO, " *** ") << "Stage "
444 << sequence_number_ << ": Reset()";
445 sequence_number_ = 0;
446 loop_stop_stage_ = 0;
447 set_socket(NULL);
448 weak_factory_.InvalidateWeakPtrs();
449 StaticSocketDataProvider::Reset();
452 void OrderedSocketData::CompleteRead() {
453 if (socket() && blocked_) {
454 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_;
455 socket()->OnReadComplete(GetNextRead());
459 OrderedSocketData::~OrderedSocketData() {}
461 DeterministicSocketData::DeterministicSocketData(MockRead* reads,
462 size_t reads_count, MockWrite* writes, size_t writes_count)
463 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
464 sequence_number_(0),
465 current_read_(),
466 current_write_(),
467 stopping_sequence_number_(0),
468 stopped_(false),
469 print_debug_(false),
470 is_running_(false) {
471 VerifyCorrectSequenceNumbers(reads, reads_count, writes, writes_count);
474 DeterministicSocketData::~DeterministicSocketData() {}
476 void DeterministicSocketData::Run() {
477 DCHECK(!is_running_);
478 is_running_ = true;
480 SetStopped(false);
481 int counter = 0;
482 // Continue to consume data until all data has run out, or the stopped_ flag
483 // has been set. Consuming data requires two separate operations -- running
484 // the tasks in the message loop, and explicitly invoking the read/write
485 // callbacks (simulating network I/O). We check our conditions between each,
486 // since they can change in either.
487 while ((!at_write_eof() || !at_read_eof()) && !stopped()) {
488 if (counter % 2 == 0)
489 base::RunLoop().RunUntilIdle();
490 if (counter % 2 == 1) {
491 InvokeCallbacks();
493 counter++;
495 // We're done consuming new data, but it is possible there are still some
496 // pending callbacks which we expect to complete before returning.
497 while (delegate_.get() &&
498 (delegate_->WritePending() || delegate_->ReadPending()) &&
499 !stopped()) {
500 InvokeCallbacks();
501 base::RunLoop().RunUntilIdle();
503 SetStopped(false);
504 is_running_ = false;
507 void DeterministicSocketData::RunFor(int steps) {
508 StopAfter(steps);
509 Run();
512 void DeterministicSocketData::SetStop(int seq) {
513 DCHECK_LT(sequence_number_, seq);
514 stopping_sequence_number_ = seq;
515 stopped_ = false;
518 void DeterministicSocketData::StopAfter(int seq) {
519 SetStop(sequence_number_ + seq);
522 MockRead DeterministicSocketData::GetNextRead() {
523 current_read_ = StaticSocketDataProvider::PeekRead();
525 // Synchronous read while stopped is an error
526 if (stopped() && current_read_.mode == SYNCHRONOUS) {
527 LOG(ERROR) << "Unable to perform synchronous IO while stopped";
528 return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
531 // Async read which will be called back in a future step.
532 if (sequence_number_ < current_read_.sequence_number) {
533 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
534 << ": I/O Pending";
535 MockRead result = MockRead(SYNCHRONOUS, ERR_IO_PENDING);
536 if (current_read_.mode == SYNCHRONOUS) {
537 LOG(ERROR) << "Unable to perform synchronous read: "
538 << current_read_.sequence_number
539 << " at stage: " << sequence_number_;
540 result = MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
542 if (print_debug_)
543 DumpMockReadWrite(result);
544 return result;
547 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
548 << ": Read " << read_index();
549 if (print_debug_)
550 DumpMockReadWrite(current_read_);
552 // Increment the sequence number if IO is complete
553 if (current_read_.mode == SYNCHRONOUS)
554 NextStep();
556 DCHECK_NE(ERR_IO_PENDING, current_read_.result);
557 StaticSocketDataProvider::GetNextRead();
559 return current_read_;
562 MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) {
563 const MockWrite& next_write = StaticSocketDataProvider::PeekWrite();
564 current_write_ = next_write;
566 // Synchronous write while stopped is an error
567 if (stopped() && next_write.mode == SYNCHRONOUS) {
568 LOG(ERROR) << "Unable to perform synchronous IO while stopped";
569 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
572 // Async write which will be called back in a future step.
573 if (sequence_number_ < next_write.sequence_number) {
574 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
575 << ": I/O Pending";
576 if (next_write.mode == SYNCHRONOUS) {
577 LOG(ERROR) << "Unable to perform synchronous write: "
578 << next_write.sequence_number << " at stage: " << sequence_number_;
579 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
581 } else {
582 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
583 << ": Write " << write_index();
586 if (print_debug_)
587 DumpMockReadWrite(next_write);
589 // Move to the next step if I/O is synchronous, since the operation will
590 // complete when this method returns.
591 if (next_write.mode == SYNCHRONOUS)
592 NextStep();
594 // This is either a sync write for this step, or an async write.
595 return StaticSocketDataProvider::OnWrite(data);
598 void DeterministicSocketData::Reset() {
599 NET_TRACE(INFO, " *** ") << "Stage "
600 << sequence_number_ << ": Reset()";
601 sequence_number_ = 0;
602 StaticSocketDataProvider::Reset();
603 NOTREACHED();
606 void DeterministicSocketData::InvokeCallbacks() {
607 if (delegate_.get() && delegate_->WritePending() &&
608 (current_write().sequence_number == sequence_number())) {
609 NextStep();
610 delegate_->CompleteWrite();
611 return;
613 if (delegate_.get() && delegate_->ReadPending() &&
614 (current_read().sequence_number == sequence_number())) {
615 NextStep();
616 delegate_->CompleteRead();
617 return;
621 void DeterministicSocketData::NextStep() {
622 // Invariant: Can never move *past* the stopping step.
623 DCHECK_LT(sequence_number_, stopping_sequence_number_);
624 sequence_number_++;
625 if (sequence_number_ == stopping_sequence_number_)
626 SetStopped(true);
629 void DeterministicSocketData::VerifyCorrectSequenceNumbers(
630 MockRead* reads, size_t reads_count,
631 MockWrite* writes, size_t writes_count) {
632 size_t read = 0;
633 size_t write = 0;
634 int expected = 0;
635 while (read < reads_count || write < writes_count) {
636 // Check to see that we have a read or write at the expected
637 // state.
638 if (read < reads_count && reads[read].sequence_number == expected) {
639 ++read;
640 ++expected;
641 continue;
643 if (write < writes_count && writes[write].sequence_number == expected) {
644 ++write;
645 ++expected;
646 continue;
648 NOTREACHED() << "Missing sequence number: " << expected;
649 return;
651 DCHECK_EQ(read, reads_count);
652 DCHECK_EQ(write, writes_count);
655 MockClientSocketFactory::MockClientSocketFactory() {}
657 MockClientSocketFactory::~MockClientSocketFactory() {}
659 void MockClientSocketFactory::AddSocketDataProvider(
660 SocketDataProvider* data) {
661 mock_data_.Add(data);
664 void MockClientSocketFactory::AddSSLSocketDataProvider(
665 SSLSocketDataProvider* data) {
666 mock_ssl_data_.Add(data);
669 void MockClientSocketFactory::ResetNextMockIndexes() {
670 mock_data_.ResetNextIndex();
671 mock_ssl_data_.ResetNextIndex();
674 scoped_ptr<DatagramClientSocket>
675 MockClientSocketFactory::CreateDatagramClientSocket(
676 DatagramSocket::BindType bind_type,
677 const RandIntCallback& rand_int_cb,
678 net::NetLog* net_log,
679 const net::NetLog::Source& source) {
680 SocketDataProvider* data_provider = mock_data_.GetNext();
681 scoped_ptr<MockUDPClientSocket> socket(
682 new MockUDPClientSocket(data_provider, net_log));
683 data_provider->set_socket(socket.get());
684 if (bind_type == DatagramSocket::RANDOM_BIND)
685 socket->set_source_port(static_cast<uint16>(rand_int_cb.Run(1025, 65535)));
686 return socket.Pass();
689 scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket(
690 const AddressList& addresses,
691 net::NetLog* net_log,
692 const net::NetLog::Source& source) {
693 SocketDataProvider* data_provider = mock_data_.GetNext();
694 scoped_ptr<MockTCPClientSocket> socket(
695 new MockTCPClientSocket(addresses, net_log, data_provider));
696 data_provider->set_socket(socket.get());
697 return socket.Pass();
700 scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
701 scoped_ptr<ClientSocketHandle> transport_socket,
702 const HostPortPair& host_and_port,
703 const SSLConfig& ssl_config,
704 const SSLClientSocketContext& context) {
705 scoped_ptr<MockSSLClientSocket> socket(
706 new MockSSLClientSocket(transport_socket.Pass(),
707 host_and_port,
708 ssl_config,
709 mock_ssl_data_.GetNext()));
710 ssl_client_sockets_.push_back(socket.get());
711 return socket.Pass();
714 void MockClientSocketFactory::ClearSSLSessionCache() {
717 const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ";
719 MockClientSocket::MockClientSocket(const BoundNetLog& net_log)
720 : connected_(false),
721 net_log_(net_log),
722 weak_factory_(this) {
723 IPAddressNumber ip;
724 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
725 peer_addr_ = IPEndPoint(ip, 0);
728 int MockClientSocket::SetReceiveBufferSize(int32 size) {
729 return OK;
732 int MockClientSocket::SetSendBufferSize(int32 size) {
733 return OK;
736 void MockClientSocket::Disconnect() {
737 connected_ = false;
740 bool MockClientSocket::IsConnected() const {
741 return connected_;
744 bool MockClientSocket::IsConnectedAndIdle() const {
745 return connected_;
748 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
749 if (!IsConnected())
750 return ERR_SOCKET_NOT_CONNECTED;
751 *address = peer_addr_;
752 return OK;
755 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
756 IPAddressNumber ip;
757 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
758 CHECK(rv);
759 *address = IPEndPoint(ip, 123);
760 return OK;
763 const BoundNetLog& MockClientSocket::NetLog() const {
764 return net_log_;
767 std::string MockClientSocket::GetSessionCacheKey() const {
768 NOTIMPLEMENTED();
769 return std::string();
772 bool MockClientSocket::InSessionCache() const {
773 NOTIMPLEMENTED();
774 return false;
777 void MockClientSocket::SetHandshakeCompletionCallback(const base::Closure& cb) {
778 NOTIMPLEMENTED();
781 void MockClientSocket::GetSSLCertRequestInfo(
782 SSLCertRequestInfo* cert_request_info) {
785 int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label,
786 bool has_context,
787 const base::StringPiece& context,
788 unsigned char* out,
789 unsigned int outlen) {
790 memset(out, 'A', outlen);
791 return OK;
794 int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) {
795 out->assign(MockClientSocket::kTlsUnique);
796 return OK;
799 ChannelIDService* MockClientSocket::GetChannelIDService() const {
800 NOTREACHED();
801 return NULL;
804 SSLClientSocket::NextProtoStatus
805 MockClientSocket::GetNextProto(std::string* proto) {
806 proto->clear();
807 return SSLClientSocket::kNextProtoUnsupported;
810 scoped_refptr<X509Certificate>
811 MockClientSocket::GetUnverifiedServerCertificateChain() const {
812 NOTREACHED();
813 return NULL;
816 MockClientSocket::~MockClientSocket() {}
818 void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback,
819 int result) {
820 base::MessageLoop::current()->PostTask(
821 FROM_HERE,
822 base::Bind(&MockClientSocket::RunCallback,
823 weak_factory_.GetWeakPtr(),
824 callback,
825 result));
828 void MockClientSocket::RunCallback(const net::CompletionCallback& callback,
829 int result) {
830 if (!callback.is_null())
831 callback.Run(result);
834 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
835 net::NetLog* net_log,
836 SocketDataProvider* data)
837 : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
838 addresses_(addresses),
839 data_(data),
840 read_offset_(0),
841 read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
842 need_read_data_(true),
843 peer_closed_connection_(false),
844 pending_buf_(NULL),
845 pending_buf_len_(0),
846 was_used_to_convey_data_(false) {
847 DCHECK(data_);
848 peer_addr_ = data->connect_data().peer_addr;
849 data_->Reset();
852 MockTCPClientSocket::~MockTCPClientSocket() {}
854 int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len,
855 const CompletionCallback& callback) {
856 if (!connected_)
857 return ERR_UNEXPECTED;
859 // If the buffer is already in use, a read is already in progress!
860 DCHECK(pending_buf_.get() == NULL);
862 // Store our async IO data.
863 pending_buf_ = buf;
864 pending_buf_len_ = buf_len;
865 pending_callback_ = callback;
867 if (need_read_data_) {
868 read_data_ = data_->GetNextRead();
869 if (read_data_.result == ERR_CONNECTION_CLOSED) {
870 // This MockRead is just a marker to instruct us to set
871 // peer_closed_connection_.
872 peer_closed_connection_ = true;
874 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
875 // This MockRead is just a marker to instruct us to set
876 // peer_closed_connection_. Skip it and get the next one.
877 read_data_ = data_->GetNextRead();
878 peer_closed_connection_ = true;
880 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
881 // to complete the async IO manually later (via OnReadComplete).
882 if (read_data_.result == ERR_IO_PENDING) {
883 // We need to be using async IO in this case.
884 DCHECK(!callback.is_null());
885 return ERR_IO_PENDING;
887 need_read_data_ = false;
890 return CompleteRead();
893 int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len,
894 const CompletionCallback& callback) {
895 DCHECK(buf);
896 DCHECK_GT(buf_len, 0);
898 if (!connected_)
899 return ERR_UNEXPECTED;
901 std::string data(buf->data(), buf_len);
902 MockWriteResult write_result = data_->OnWrite(data);
904 was_used_to_convey_data_ = true;
906 if (write_result.mode == ASYNC) {
907 RunCallbackAsync(callback, write_result.result);
908 return ERR_IO_PENDING;
911 return write_result.result;
914 int MockTCPClientSocket::Connect(const CompletionCallback& callback) {
915 if (connected_)
916 return OK;
917 connected_ = true;
918 peer_closed_connection_ = false;
919 if (data_->connect_data().mode == ASYNC) {
920 if (data_->connect_data().result == ERR_IO_PENDING)
921 pending_callback_ = callback;
922 else
923 RunCallbackAsync(callback, data_->connect_data().result);
924 return ERR_IO_PENDING;
926 return data_->connect_data().result;
929 void MockTCPClientSocket::Disconnect() {
930 MockClientSocket::Disconnect();
931 pending_callback_.Reset();
934 bool MockTCPClientSocket::IsConnected() const {
935 return connected_ && !peer_closed_connection_;
938 bool MockTCPClientSocket::IsConnectedAndIdle() const {
939 return IsConnected();
942 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
943 if (addresses_.empty())
944 return MockClientSocket::GetPeerAddress(address);
946 *address = addresses_[0];
947 return OK;
950 bool MockTCPClientSocket::WasEverUsed() const {
951 return was_used_to_convey_data_;
954 bool MockTCPClientSocket::UsingTCPFastOpen() const {
955 return false;
958 bool MockTCPClientSocket::WasNpnNegotiated() const {
959 return false;
962 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
963 return false;
966 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
967 // There must be a read pending.
968 DCHECK(pending_buf_.get());
969 // You can't complete a read with another ERR_IO_PENDING status code.
970 DCHECK_NE(ERR_IO_PENDING, data.result);
971 // Since we've been waiting for data, need_read_data_ should be true.
972 DCHECK(need_read_data_);
974 read_data_ = data;
975 need_read_data_ = false;
977 // The caller is simulating that this IO completes right now. Don't
978 // let CompleteRead() schedule a callback.
979 read_data_.mode = SYNCHRONOUS;
981 CompletionCallback callback = pending_callback_;
982 int rv = CompleteRead();
983 RunCallback(callback, rv);
986 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
987 CompletionCallback callback = pending_callback_;
988 RunCallback(callback, data.result);
991 int MockTCPClientSocket::CompleteRead() {
992 DCHECK(pending_buf_.get());
993 DCHECK(pending_buf_len_ > 0);
995 was_used_to_convey_data_ = true;
997 // Save the pending async IO data and reset our |pending_| state.
998 scoped_refptr<IOBuffer> buf = pending_buf_;
999 int buf_len = pending_buf_len_;
1000 CompletionCallback callback = pending_callback_;
1001 pending_buf_ = NULL;
1002 pending_buf_len_ = 0;
1003 pending_callback_.Reset();
1005 int result = read_data_.result;
1006 DCHECK(result != ERR_IO_PENDING);
1008 if (read_data_.data) {
1009 if (read_data_.data_len - read_offset_ > 0) {
1010 result = std::min(buf_len, read_data_.data_len - read_offset_);
1011 memcpy(buf->data(), read_data_.data + read_offset_, result);
1012 read_offset_ += result;
1013 if (read_offset_ == read_data_.data_len) {
1014 need_read_data_ = true;
1015 read_offset_ = 0;
1017 } else {
1018 result = 0; // EOF
1022 if (read_data_.mode == ASYNC) {
1023 DCHECK(!callback.is_null());
1024 RunCallbackAsync(callback, result);
1025 return ERR_IO_PENDING;
1027 return result;
1030 DeterministicSocketHelper::DeterministicSocketHelper(
1031 net::NetLog* net_log,
1032 DeterministicSocketData* data)
1033 : write_pending_(false),
1034 write_result_(0),
1035 read_data_(),
1036 read_buf_(NULL),
1037 read_buf_len_(0),
1038 read_pending_(false),
1039 data_(data),
1040 was_used_to_convey_data_(false),
1041 peer_closed_connection_(false),
1042 net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)) {
1045 DeterministicSocketHelper::~DeterministicSocketHelper() {}
1047 void DeterministicSocketHelper::CompleteWrite() {
1048 was_used_to_convey_data_ = true;
1049 write_pending_ = false;
1050 write_callback_.Run(write_result_);
1053 int DeterministicSocketHelper::CompleteRead() {
1054 DCHECK_GT(read_buf_len_, 0);
1055 DCHECK_LE(read_data_.data_len, read_buf_len_);
1056 DCHECK(read_buf_);
1058 was_used_to_convey_data_ = true;
1060 if (read_data_.result == ERR_IO_PENDING)
1061 read_data_ = data_->GetNextRead();
1062 DCHECK_NE(ERR_IO_PENDING, read_data_.result);
1063 // If read_data_.mode is ASYNC, we do not need to wait, since this is already
1064 // the callback. Therefore we don't even bother to check it.
1065 int result = read_data_.result;
1067 if (read_data_.data_len > 0) {
1068 DCHECK(read_data_.data);
1069 result = std::min(read_buf_len_, read_data_.data_len);
1070 memcpy(read_buf_->data(), read_data_.data, result);
1073 if (read_pending_) {
1074 read_pending_ = false;
1075 read_callback_.Run(result);
1078 return result;
1081 int DeterministicSocketHelper::Write(
1082 IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
1083 DCHECK(buf);
1084 DCHECK_GT(buf_len, 0);
1086 std::string data(buf->data(), buf_len);
1087 MockWriteResult write_result = data_->OnWrite(data);
1089 if (write_result.mode == ASYNC) {
1090 write_callback_ = callback;
1091 write_result_ = write_result.result;
1092 DCHECK(!write_callback_.is_null());
1093 write_pending_ = true;
1094 return ERR_IO_PENDING;
1097 was_used_to_convey_data_ = true;
1098 write_pending_ = false;
1099 return write_result.result;
1102 int DeterministicSocketHelper::Read(
1103 IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
1105 read_data_ = data_->GetNextRead();
1106 // The buffer should always be big enough to contain all the MockRead data. To
1107 // use small buffers, split the data into multiple MockReads.
1108 DCHECK_LE(read_data_.data_len, buf_len);
1110 if (read_data_.result == ERR_CONNECTION_CLOSED) {
1111 // This MockRead is just a marker to instruct us to set
1112 // peer_closed_connection_.
1113 peer_closed_connection_ = true;
1115 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
1116 // This MockRead is just a marker to instruct us to set
1117 // peer_closed_connection_. Skip it and get the next one.
1118 read_data_ = data_->GetNextRead();
1119 peer_closed_connection_ = true;
1122 read_buf_ = buf;
1123 read_buf_len_ = buf_len;
1124 read_callback_ = callback;
1126 if (read_data_.mode == ASYNC || (read_data_.result == ERR_IO_PENDING)) {
1127 read_pending_ = true;
1128 DCHECK(!read_callback_.is_null());
1129 return ERR_IO_PENDING;
1132 was_used_to_convey_data_ = true;
1133 return CompleteRead();
1136 DeterministicMockUDPClientSocket::DeterministicMockUDPClientSocket(
1137 net::NetLog* net_log,
1138 DeterministicSocketData* data)
1139 : connected_(false),
1140 helper_(net_log, data),
1141 source_port_(123) {
1144 DeterministicMockUDPClientSocket::~DeterministicMockUDPClientSocket() {}
1146 bool DeterministicMockUDPClientSocket::WritePending() const {
1147 return helper_.write_pending();
1150 bool DeterministicMockUDPClientSocket::ReadPending() const {
1151 return helper_.read_pending();
1154 void DeterministicMockUDPClientSocket::CompleteWrite() {
1155 helper_.CompleteWrite();
1158 int DeterministicMockUDPClientSocket::CompleteRead() {
1159 return helper_.CompleteRead();
1162 int DeterministicMockUDPClientSocket::Connect(const IPEndPoint& address) {
1163 if (connected_)
1164 return OK;
1165 connected_ = true;
1166 peer_address_ = address;
1167 return helper_.data()->connect_data().result;
1170 int DeterministicMockUDPClientSocket::Write(
1171 IOBuffer* buf,
1172 int buf_len,
1173 const CompletionCallback& callback) {
1174 if (!connected_)
1175 return ERR_UNEXPECTED;
1177 return helper_.Write(buf, buf_len, callback);
1180 int DeterministicMockUDPClientSocket::Read(
1181 IOBuffer* buf,
1182 int buf_len,
1183 const CompletionCallback& callback) {
1184 if (!connected_)
1185 return ERR_UNEXPECTED;
1187 return helper_.Read(buf, buf_len, callback);
1190 int DeterministicMockUDPClientSocket::SetReceiveBufferSize(int32 size) {
1191 return OK;
1194 int DeterministicMockUDPClientSocket::SetSendBufferSize(int32 size) {
1195 return OK;
1198 void DeterministicMockUDPClientSocket::Close() {
1199 connected_ = false;
1202 int DeterministicMockUDPClientSocket::GetPeerAddress(
1203 IPEndPoint* address) const {
1204 *address = peer_address_;
1205 return OK;
1208 int DeterministicMockUDPClientSocket::GetLocalAddress(
1209 IPEndPoint* address) const {
1210 IPAddressNumber ip;
1211 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
1212 CHECK(rv);
1213 *address = IPEndPoint(ip, source_port_);
1214 return OK;
1217 const BoundNetLog& DeterministicMockUDPClientSocket::NetLog() const {
1218 return helper_.net_log();
1221 void DeterministicMockUDPClientSocket::OnReadComplete(const MockRead& data) {}
1223 void DeterministicMockUDPClientSocket::OnConnectComplete(
1224 const MockConnect& data) {
1225 NOTIMPLEMENTED();
1228 DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket(
1229 net::NetLog* net_log,
1230 DeterministicSocketData* data)
1231 : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
1232 helper_(net_log, data) {
1233 peer_addr_ = data->connect_data().peer_addr;
1236 DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {}
1238 bool DeterministicMockTCPClientSocket::WritePending() const {
1239 return helper_.write_pending();
1242 bool DeterministicMockTCPClientSocket::ReadPending() const {
1243 return helper_.read_pending();
1246 void DeterministicMockTCPClientSocket::CompleteWrite() {
1247 helper_.CompleteWrite();
1250 int DeterministicMockTCPClientSocket::CompleteRead() {
1251 return helper_.CompleteRead();
1254 int DeterministicMockTCPClientSocket::Write(
1255 IOBuffer* buf,
1256 int buf_len,
1257 const CompletionCallback& callback) {
1258 if (!connected_)
1259 return ERR_UNEXPECTED;
1261 return helper_.Write(buf, buf_len, callback);
1264 int DeterministicMockTCPClientSocket::Read(
1265 IOBuffer* buf,
1266 int buf_len,
1267 const CompletionCallback& callback) {
1268 if (!connected_)
1269 return ERR_UNEXPECTED;
1271 return helper_.Read(buf, buf_len, callback);
1274 // TODO(erikchen): Support connect sequencing.
1275 int DeterministicMockTCPClientSocket::Connect(
1276 const CompletionCallback& callback) {
1277 if (connected_)
1278 return OK;
1279 connected_ = true;
1280 if (helper_.data()->connect_data().mode == ASYNC) {
1281 RunCallbackAsync(callback, helper_.data()->connect_data().result);
1282 return ERR_IO_PENDING;
1284 return helper_.data()->connect_data().result;
1287 void DeterministicMockTCPClientSocket::Disconnect() {
1288 MockClientSocket::Disconnect();
1291 bool DeterministicMockTCPClientSocket::IsConnected() const {
1292 return connected_ && !helper_.peer_closed_connection();
1295 bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const {
1296 return IsConnected();
1299 bool DeterministicMockTCPClientSocket::WasEverUsed() const {
1300 return helper_.was_used_to_convey_data();
1303 bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const {
1304 return false;
1307 bool DeterministicMockTCPClientSocket::WasNpnNegotiated() const {
1308 return false;
1311 bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1312 return false;
1315 void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
1317 void DeterministicMockTCPClientSocket::OnConnectComplete(
1318 const MockConnect& data) {}
1320 MockSSLClientSocket::MockSSLClientSocket(
1321 scoped_ptr<ClientSocketHandle> transport_socket,
1322 const HostPortPair& host_port_pair,
1323 const SSLConfig& ssl_config,
1324 SSLSocketDataProvider* data)
1325 : MockClientSocket(
1326 // Have to use the right BoundNetLog for LoadTimingInfo regression
1327 // tests.
1328 transport_socket->socket()->NetLog()),
1329 transport_(transport_socket.Pass()),
1330 host_port_pair_(host_port_pair),
1331 data_(data),
1332 is_npn_state_set_(false),
1333 new_npn_value_(false),
1334 is_protocol_negotiated_set_(false),
1335 protocol_negotiated_(kProtoUnknown),
1336 next_connect_state_(STATE_NONE),
1337 reached_connect_(false),
1338 weak_factory_(this) {
1339 DCHECK(data_);
1340 peer_addr_ = data->connect.peer_addr;
1343 MockSSLClientSocket::~MockSSLClientSocket() {
1344 Disconnect();
1347 int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len,
1348 const CompletionCallback& callback) {
1349 return transport_->socket()->Read(buf, buf_len, callback);
1352 int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len,
1353 const CompletionCallback& callback) {
1354 return transport_->socket()->Write(buf, buf_len, callback);
1357 int MockSSLClientSocket::Connect(const CompletionCallback& callback) {
1358 next_connect_state_ = STATE_SSL_CONNECT;
1359 reached_connect_ = true;
1360 int rv = DoConnectLoop(OK);
1361 if (rv == ERR_IO_PENDING)
1362 connect_callback_ = callback;
1363 return rv;
1366 void MockSSLClientSocket::Disconnect() {
1367 weak_factory_.InvalidateWeakPtrs();
1368 MockClientSocket::Disconnect();
1369 if (transport_->socket() != NULL)
1370 transport_->socket()->Disconnect();
1373 bool MockSSLClientSocket::IsConnected() const {
1374 return transport_->socket()->IsConnected() && connected_;
1377 bool MockSSLClientSocket::WasEverUsed() const {
1378 return transport_->socket()->WasEverUsed();
1381 bool MockSSLClientSocket::UsingTCPFastOpen() const {
1382 return transport_->socket()->UsingTCPFastOpen();
1385 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1386 return transport_->socket()->GetPeerAddress(address);
1389 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1390 ssl_info->Reset();
1391 ssl_info->cert = data_->cert;
1392 ssl_info->client_cert_sent = data_->client_cert_sent;
1393 ssl_info->channel_id_sent = data_->channel_id_sent;
1394 ssl_info->connection_status = data_->connection_status;
1395 return true;
1398 std::string MockSSLClientSocket::GetSessionCacheKey() const {
1399 // For the purposes of these tests, |host_and_port| will serve as the
1400 // cache key.
1401 return host_port_pair_.ToString();
1404 bool MockSSLClientSocket::InSessionCache() const {
1405 return data_->is_in_session_cache;
1408 void MockSSLClientSocket::SetHandshakeCompletionCallback(
1409 const base::Closure& cb) {
1410 handshake_completion_callback_ = cb;
1413 void MockSSLClientSocket::GetSSLCertRequestInfo(
1414 SSLCertRequestInfo* cert_request_info) {
1415 DCHECK(cert_request_info);
1416 if (data_->cert_request_info) {
1417 cert_request_info->host_and_port =
1418 data_->cert_request_info->host_and_port;
1419 cert_request_info->client_certs = data_->cert_request_info->client_certs;
1420 } else {
1421 cert_request_info->Reset();
1425 SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
1426 std::string* proto) {
1427 *proto = data_->next_proto;
1428 return data_->next_proto_status;
1431 bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) {
1432 is_npn_state_set_ = true;
1433 return new_npn_value_ = negotiated;
1436 bool MockSSLClientSocket::WasNpnNegotiated() const {
1437 if (is_npn_state_set_)
1438 return new_npn_value_;
1439 return data_->was_npn_negotiated;
1442 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1443 if (is_protocol_negotiated_set_)
1444 return protocol_negotiated_;
1445 return data_->protocol_negotiated;
1448 void MockSSLClientSocket::set_protocol_negotiated(
1449 NextProto protocol_negotiated) {
1450 is_protocol_negotiated_set_ = true;
1451 protocol_negotiated_ = protocol_negotiated;
1454 bool MockSSLClientSocket::WasChannelIDSent() const {
1455 return data_->channel_id_sent;
1458 void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) {
1459 data_->channel_id_sent = channel_id_sent;
1462 ChannelIDService* MockSSLClientSocket::GetChannelIDService() const {
1463 return data_->channel_id_service;
1466 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1467 NOTIMPLEMENTED();
1470 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1471 NOTIMPLEMENTED();
1474 void MockSSLClientSocket::RestartPausedConnect() {
1475 DCHECK(data_->should_pause_on_connect);
1476 DCHECK_EQ(next_connect_state_, STATE_SSL_CONNECT_COMPLETE);
1477 OnIOComplete(data_->connect.result);
1480 void MockSSLClientSocket::OnIOComplete(int result) {
1481 int rv = DoConnectLoop(result);
1482 if (rv != ERR_IO_PENDING)
1483 base::ResetAndReturn(&connect_callback_).Run(rv);
1486 int MockSSLClientSocket::DoConnectLoop(int result) {
1487 DCHECK_NE(next_connect_state_, STATE_NONE);
1489 int rv = result;
1490 do {
1491 ConnectState state = next_connect_state_;
1492 next_connect_state_ = STATE_NONE;
1493 switch (state) {
1494 case STATE_SSL_CONNECT:
1495 rv = DoSSLConnect();
1496 break;
1497 case STATE_SSL_CONNECT_COMPLETE:
1498 rv = DoSSLConnectComplete(rv);
1499 break;
1500 default:
1501 NOTREACHED() << "bad state";
1502 rv = ERR_UNEXPECTED;
1503 break;
1505 } while (rv != ERR_IO_PENDING && next_connect_state_ != STATE_NONE);
1507 return rv;
1510 int MockSSLClientSocket::DoSSLConnect() {
1511 next_connect_state_ = STATE_SSL_CONNECT_COMPLETE;
1513 if (data_->should_pause_on_connect)
1514 return ERR_IO_PENDING;
1516 if (data_->connect.mode == ASYNC) {
1517 base::MessageLoop::current()->PostTask(
1518 FROM_HERE,
1519 base::Bind(&MockSSLClientSocket::OnIOComplete,
1520 weak_factory_.GetWeakPtr(),
1521 data_->connect.result));
1522 return ERR_IO_PENDING;
1525 return data_->connect.result;
1528 int MockSSLClientSocket::DoSSLConnectComplete(int result) {
1529 if (result == OK)
1530 connected_ = true;
1532 if (!handshake_completion_callback_.is_null())
1533 base::ResetAndReturn(&handshake_completion_callback_).Run();
1534 return result;
1537 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1538 net::NetLog* net_log)
1539 : connected_(false),
1540 data_(data),
1541 read_offset_(0),
1542 read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1543 need_read_data_(true),
1544 source_port_(123),
1545 pending_buf_(NULL),
1546 pending_buf_len_(0),
1547 net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
1548 weak_factory_(this) {
1549 DCHECK(data_);
1550 data_->Reset();
1551 peer_addr_ = data->connect_data().peer_addr;
1554 MockUDPClientSocket::~MockUDPClientSocket() {}
1556 int MockUDPClientSocket::Read(IOBuffer* buf,
1557 int buf_len,
1558 const CompletionCallback& callback) {
1559 if (!connected_)
1560 return ERR_UNEXPECTED;
1562 // If the buffer is already in use, a read is already in progress!
1563 DCHECK(pending_buf_.get() == NULL);
1565 // Store our async IO data.
1566 pending_buf_ = buf;
1567 pending_buf_len_ = buf_len;
1568 pending_callback_ = callback;
1570 if (need_read_data_) {
1571 read_data_ = data_->GetNextRead();
1572 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1573 // to complete the async IO manually later (via OnReadComplete).
1574 if (read_data_.result == ERR_IO_PENDING) {
1575 // We need to be using async IO in this case.
1576 DCHECK(!callback.is_null());
1577 return ERR_IO_PENDING;
1579 need_read_data_ = false;
1582 return CompleteRead();
1585 int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len,
1586 const CompletionCallback& callback) {
1587 DCHECK(buf);
1588 DCHECK_GT(buf_len, 0);
1590 if (!connected_)
1591 return ERR_UNEXPECTED;
1593 std::string data(buf->data(), buf_len);
1594 MockWriteResult write_result = data_->OnWrite(data);
1596 if (write_result.mode == ASYNC) {
1597 RunCallbackAsync(callback, write_result.result);
1598 return ERR_IO_PENDING;
1600 return write_result.result;
1603 int MockUDPClientSocket::SetReceiveBufferSize(int32 size) {
1604 return OK;
1607 int MockUDPClientSocket::SetSendBufferSize(int32 size) {
1608 return OK;
1611 void MockUDPClientSocket::Close() {
1612 connected_ = false;
1615 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1616 *address = peer_addr_;
1617 return OK;
1620 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1621 IPAddressNumber ip;
1622 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
1623 CHECK(rv);
1624 *address = IPEndPoint(ip, source_port_);
1625 return OK;
1628 const BoundNetLog& MockUDPClientSocket::NetLog() const {
1629 return net_log_;
1632 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1633 connected_ = true;
1634 peer_addr_ = address;
1635 return data_->connect_data().result;
1638 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1639 // There must be a read pending.
1640 DCHECK(pending_buf_.get());
1641 // You can't complete a read with another ERR_IO_PENDING status code.
1642 DCHECK_NE(ERR_IO_PENDING, data.result);
1643 // Since we've been waiting for data, need_read_data_ should be true.
1644 DCHECK(need_read_data_);
1646 read_data_ = data;
1647 need_read_data_ = false;
1649 // The caller is simulating that this IO completes right now. Don't
1650 // let CompleteRead() schedule a callback.
1651 read_data_.mode = SYNCHRONOUS;
1653 net::CompletionCallback callback = pending_callback_;
1654 int rv = CompleteRead();
1655 RunCallback(callback, rv);
1658 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1659 NOTIMPLEMENTED();
1662 int MockUDPClientSocket::CompleteRead() {
1663 DCHECK(pending_buf_.get());
1664 DCHECK(pending_buf_len_ > 0);
1666 // Save the pending async IO data and reset our |pending_| state.
1667 scoped_refptr<IOBuffer> buf = pending_buf_;
1668 int buf_len = pending_buf_len_;
1669 CompletionCallback callback = pending_callback_;
1670 pending_buf_ = NULL;
1671 pending_buf_len_ = 0;
1672 pending_callback_.Reset();
1674 int result = read_data_.result;
1675 DCHECK(result != ERR_IO_PENDING);
1677 if (read_data_.data) {
1678 if (read_data_.data_len - read_offset_ > 0) {
1679 result = std::min(buf_len, read_data_.data_len - read_offset_);
1680 memcpy(buf->data(), read_data_.data + read_offset_, result);
1681 read_offset_ += result;
1682 if (read_offset_ == read_data_.data_len) {
1683 need_read_data_ = true;
1684 read_offset_ = 0;
1686 } else {
1687 result = 0; // EOF
1691 if (read_data_.mode == ASYNC) {
1692 DCHECK(!callback.is_null());
1693 RunCallbackAsync(callback, result);
1694 return ERR_IO_PENDING;
1696 return result;
1699 void MockUDPClientSocket::RunCallbackAsync(const CompletionCallback& callback,
1700 int result) {
1701 base::MessageLoop::current()->PostTask(
1702 FROM_HERE,
1703 base::Bind(&MockUDPClientSocket::RunCallback,
1704 weak_factory_.GetWeakPtr(),
1705 callback,
1706 result));
1709 void MockUDPClientSocket::RunCallback(const CompletionCallback& callback,
1710 int result) {
1711 if (!callback.is_null())
1712 callback.Run(result);
1715 TestSocketRequest::TestSocketRequest(
1716 std::vector<TestSocketRequest*>* request_order, size_t* completion_count)
1717 : request_order_(request_order),
1718 completion_count_(completion_count),
1719 callback_(base::Bind(&TestSocketRequest::OnComplete,
1720 base::Unretained(this))) {
1721 DCHECK(request_order);
1722 DCHECK(completion_count);
1725 TestSocketRequest::~TestSocketRequest() {
1728 void TestSocketRequest::OnComplete(int result) {
1729 SetResult(result);
1730 (*completion_count_)++;
1731 request_order_->push_back(this);
1734 // static
1735 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1737 // static
1738 const int ClientSocketPoolTest::kRequestNotFound = -2;
1740 ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
1741 ClientSocketPoolTest::~ClientSocketPoolTest() {}
1743 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1744 index--;
1745 if (index >= requests_.size())
1746 return kIndexOutOfBounds;
1748 for (size_t i = 0; i < request_order_.size(); i++)
1749 if (requests_[index] == request_order_[i])
1750 return i + 1;
1752 return kRequestNotFound;
1755 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1756 ScopedVector<TestSocketRequest>::iterator i;
1757 for (i = requests_.begin(); i != requests_.end(); ++i) {
1758 if ((*i)->handle()->is_initialized()) {
1759 if (keep_alive == NO_KEEP_ALIVE)
1760 (*i)->handle()->socket()->Disconnect();
1761 (*i)->handle()->Reset();
1762 base::RunLoop().RunUntilIdle();
1763 return true;
1766 return false;
1769 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1770 bool released_one;
1771 do {
1772 released_one = ReleaseOneConnection(keep_alive);
1773 } while (released_one);
1776 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1777 scoped_ptr<StreamSocket> socket,
1778 ClientSocketHandle* handle,
1779 const CompletionCallback& callback)
1780 : socket_(socket.Pass()),
1781 handle_(handle),
1782 user_callback_(callback) {
1785 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {}
1787 int MockTransportClientSocketPool::MockConnectJob::Connect() {
1788 int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect,
1789 base::Unretained(this)));
1790 if (rv == OK) {
1791 user_callback_.Reset();
1792 OnConnect(OK);
1794 return rv;
1797 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
1798 const ClientSocketHandle* handle) {
1799 if (handle != handle_)
1800 return false;
1801 socket_.reset();
1802 handle_ = NULL;
1803 user_callback_.Reset();
1804 return true;
1807 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
1808 if (!socket_.get())
1809 return;
1810 if (rv == OK) {
1811 handle_->SetSocket(socket_.Pass());
1813 // Needed for socket pool tests that layer other sockets on top of mock
1814 // sockets.
1815 LoadTimingInfo::ConnectTiming connect_timing;
1816 base::TimeTicks now = base::TimeTicks::Now();
1817 connect_timing.dns_start = now;
1818 connect_timing.dns_end = now;
1819 connect_timing.connect_start = now;
1820 connect_timing.connect_end = now;
1821 handle_->set_connect_timing(connect_timing);
1822 } else {
1823 socket_.reset();
1826 handle_ = NULL;
1828 if (!user_callback_.is_null()) {
1829 CompletionCallback callback = user_callback_;
1830 user_callback_.Reset();
1831 callback.Run(rv);
1835 MockTransportClientSocketPool::MockTransportClientSocketPool(
1836 int max_sockets,
1837 int max_sockets_per_group,
1838 ClientSocketPoolHistograms* histograms,
1839 ClientSocketFactory* socket_factory)
1840 : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1841 NULL, NULL, NULL),
1842 client_socket_factory_(socket_factory),
1843 last_request_priority_(DEFAULT_PRIORITY),
1844 release_count_(0),
1845 cancel_count_(0) {
1848 MockTransportClientSocketPool::~MockTransportClientSocketPool() {}
1850 int MockTransportClientSocketPool::RequestSocket(
1851 const std::string& group_name, const void* socket_params,
1852 RequestPriority priority, ClientSocketHandle* handle,
1853 const CompletionCallback& callback, const BoundNetLog& net_log) {
1854 last_request_priority_ = priority;
1855 scoped_ptr<StreamSocket> socket =
1856 client_socket_factory_->CreateTransportClientSocket(
1857 AddressList(), net_log.net_log(), net::NetLog::Source());
1858 MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback);
1859 job_list_.push_back(job);
1860 handle->set_pool_id(1);
1861 return job->Connect();
1864 void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
1865 ClientSocketHandle* handle) {
1866 std::vector<MockConnectJob*>::iterator i;
1867 for (i = job_list_.begin(); i != job_list_.end(); ++i) {
1868 if ((*i)->CancelHandle(handle)) {
1869 cancel_count_++;
1870 break;
1875 void MockTransportClientSocketPool::ReleaseSocket(
1876 const std::string& group_name,
1877 scoped_ptr<StreamSocket> socket,
1878 int id) {
1879 EXPECT_EQ(1, id);
1880 release_count_++;
1883 DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
1885 DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {}
1887 void DeterministicMockClientSocketFactory::AddSocketDataProvider(
1888 DeterministicSocketData* data) {
1889 mock_data_.Add(data);
1892 void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider(
1893 SSLSocketDataProvider* data) {
1894 mock_ssl_data_.Add(data);
1897 void DeterministicMockClientSocketFactory::ResetNextMockIndexes() {
1898 mock_data_.ResetNextIndex();
1899 mock_ssl_data_.ResetNextIndex();
1902 MockSSLClientSocket* DeterministicMockClientSocketFactory::
1903 GetMockSSLClientSocket(size_t index) const {
1904 DCHECK_LT(index, ssl_client_sockets_.size());
1905 return ssl_client_sockets_[index];
1908 scoped_ptr<DatagramClientSocket>
1909 DeterministicMockClientSocketFactory::CreateDatagramClientSocket(
1910 DatagramSocket::BindType bind_type,
1911 const RandIntCallback& rand_int_cb,
1912 net::NetLog* net_log,
1913 const NetLog::Source& source) {
1914 DeterministicSocketData* data_provider = mock_data().GetNext();
1915 scoped_ptr<DeterministicMockUDPClientSocket> socket(
1916 new DeterministicMockUDPClientSocket(net_log, data_provider));
1917 data_provider->set_delegate(socket->AsWeakPtr());
1918 udp_client_sockets().push_back(socket.get());
1919 if (bind_type == DatagramSocket::RANDOM_BIND)
1920 socket->set_source_port(static_cast<uint16>(rand_int_cb.Run(1025, 65535)));
1921 return socket.Pass();
1924 scoped_ptr<StreamSocket>
1925 DeterministicMockClientSocketFactory::CreateTransportClientSocket(
1926 const AddressList& addresses,
1927 net::NetLog* net_log,
1928 const net::NetLog::Source& source) {
1929 DeterministicSocketData* data_provider = mock_data().GetNext();
1930 scoped_ptr<DeterministicMockTCPClientSocket> socket(
1931 new DeterministicMockTCPClientSocket(net_log, data_provider));
1932 data_provider->set_delegate(socket->AsWeakPtr());
1933 tcp_client_sockets().push_back(socket.get());
1934 return socket.Pass();
1937 scoped_ptr<SSLClientSocket>
1938 DeterministicMockClientSocketFactory::CreateSSLClientSocket(
1939 scoped_ptr<ClientSocketHandle> transport_socket,
1940 const HostPortPair& host_and_port,
1941 const SSLConfig& ssl_config,
1942 const SSLClientSocketContext& context) {
1943 scoped_ptr<MockSSLClientSocket> socket(
1944 new MockSSLClientSocket(transport_socket.Pass(),
1945 host_and_port, ssl_config,
1946 mock_ssl_data_.GetNext()));
1947 ssl_client_sockets_.push_back(socket.get());
1948 return socket.Pass();
1951 void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
1954 MockSOCKSClientSocketPool::MockSOCKSClientSocketPool(
1955 int max_sockets,
1956 int max_sockets_per_group,
1957 ClientSocketPoolHistograms* histograms,
1958 TransportClientSocketPool* transport_pool)
1959 : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1960 NULL, transport_pool, NULL),
1961 transport_pool_(transport_pool) {
1964 MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {}
1966 int MockSOCKSClientSocketPool::RequestSocket(
1967 const std::string& group_name, const void* socket_params,
1968 RequestPriority priority, ClientSocketHandle* handle,
1969 const CompletionCallback& callback, const BoundNetLog& net_log) {
1970 return transport_pool_->RequestSocket(
1971 group_name, socket_params, priority, handle, callback, net_log);
1974 void MockSOCKSClientSocketPool::CancelRequest(
1975 const std::string& group_name,
1976 ClientSocketHandle* handle) {
1977 return transport_pool_->CancelRequest(group_name, handle);
1980 void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
1981 scoped_ptr<StreamSocket> socket,
1982 int id) {
1983 return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id);
1986 const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
1987 const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest);
1989 const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
1990 const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse);
1992 const char kSOCKS5OkRequest[] =
1993 { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
1994 const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest);
1996 const char kSOCKS5OkResponse[] =
1997 { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
1998 const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse);
2000 } // namespace net