1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/quic/test_tools/quic_test_utils.h"
8 #include "base/stl_util.h"
9 #include "base/strings/string_number_conversions.h"
10 #include "net/quic/crypto/crypto_framer.h"
11 #include "net/quic/crypto/crypto_handshake.h"
12 #include "net/quic/crypto/crypto_utils.h"
13 #include "net/quic/crypto/null_encrypter.h"
14 #include "net/quic/crypto/quic_decrypter.h"
15 #include "net/quic/crypto/quic_encrypter.h"
16 #include "net/quic/quic_data_writer.h"
17 #include "net/quic/quic_framer.h"
18 #include "net/quic/quic_packet_creator.h"
19 #include "net/quic/quic_utils.h"
20 #include "net/quic/test_tools/quic_connection_peer.h"
21 #include "net/spdy/spdy_frame_builder.h"
23 using base::StringPiece
;
27 using testing::AnyNumber
;
34 // No-op alarm implementation used by MockHelper.
35 class TestAlarm
: public QuicAlarm
{
37 explicit TestAlarm(QuicAlarm::Delegate
* delegate
)
38 : QuicAlarm(delegate
) {
41 void SetImpl() override
{}
42 void CancelImpl() override
{}
47 QuicAckFrame
MakeAckFrame(QuicPacketSequenceNumber largest_observed
) {
49 ack
.largest_observed
= largest_observed
;
54 QuicAckFrame
MakeAckFrameWithNackRanges(
55 size_t num_nack_ranges
, QuicPacketSequenceNumber least_unacked
) {
56 QuicAckFrame ack
= MakeAckFrame(2 * num_nack_ranges
+ least_unacked
);
57 // Add enough missing packets to get num_nack_ranges nack ranges.
58 for (QuicPacketSequenceNumber i
= 1; i
< 2 * num_nack_ranges
; i
+= 2) {
59 ack
.missing_packets
.insert(least_unacked
+ i
);
64 QuicPacket
* BuildUnsizedDataPacket(QuicFramer
* framer
,
65 const QuicPacketHeader
& header
,
66 const QuicFrames
& frames
) {
67 const size_t max_plaintext_size
= framer
->GetMaxPlaintextSize(kMaxPacketSize
);
68 size_t packet_size
= GetPacketHeaderSize(header
);
69 for (size_t i
= 0; i
< frames
.size(); ++i
) {
70 DCHECK_LE(packet_size
, max_plaintext_size
);
71 bool first_frame
= i
== 0;
72 bool last_frame
= i
== frames
.size() - 1;
73 const size_t frame_size
= framer
->GetSerializedFrameLength(
74 frames
[i
], max_plaintext_size
- packet_size
, first_frame
, last_frame
,
75 header
.is_in_fec_group
,
76 header
.public_header
.sequence_number_length
);
78 packet_size
+= frame_size
;
80 return BuildUnsizedDataPacket(framer
, header
, frames
, packet_size
);
83 QuicPacket
* BuildUnsizedDataPacket(QuicFramer
* framer
,
84 const QuicPacketHeader
& header
,
85 const QuicFrames
& frames
,
87 char* buffer
= new char[packet_size
];
88 scoped_ptr
<QuicPacket
> packet(
89 framer
->BuildDataPacket(header
, frames
, buffer
, packet_size
));
90 DCHECK(packet
.get() != nullptr);
91 // Now I have to re-construct the data packet with data ownership.
92 return new QuicPacket(buffer
, packet
->length(), true,
93 header
.public_header
.connection_id_length
,
94 header
.public_header
.version_flag
,
95 header
.public_header
.sequence_number_length
);
98 uint64
SimpleRandom::RandUint64() {
99 unsigned char hash
[base::kSHA1Length
];
100 base::SHA1HashBytes(reinterpret_cast<unsigned char*>(&seed_
), sizeof(seed_
),
102 memcpy(&seed_
, hash
, sizeof(seed_
));
106 MockFramerVisitor::MockFramerVisitor() {
107 // By default, we want to accept packets.
108 ON_CALL(*this, OnProtocolVersionMismatch(_
))
109 .WillByDefault(testing::Return(false));
111 // By default, we want to accept packets.
112 ON_CALL(*this, OnUnauthenticatedHeader(_
))
113 .WillByDefault(testing::Return(true));
115 ON_CALL(*this, OnUnauthenticatedPublicHeader(_
))
116 .WillByDefault(testing::Return(true));
118 ON_CALL(*this, OnPacketHeader(_
))
119 .WillByDefault(testing::Return(true));
121 ON_CALL(*this, OnStreamFrame(_
))
122 .WillByDefault(testing::Return(true));
124 ON_CALL(*this, OnAckFrame(_
))
125 .WillByDefault(testing::Return(true));
127 ON_CALL(*this, OnStopWaitingFrame(_
))
128 .WillByDefault(testing::Return(true));
130 ON_CALL(*this, OnPingFrame(_
))
131 .WillByDefault(testing::Return(true));
133 ON_CALL(*this, OnRstStreamFrame(_
))
134 .WillByDefault(testing::Return(true));
136 ON_CALL(*this, OnConnectionCloseFrame(_
))
137 .WillByDefault(testing::Return(true));
139 ON_CALL(*this, OnGoAwayFrame(_
))
140 .WillByDefault(testing::Return(true));
143 MockFramerVisitor::~MockFramerVisitor() {
146 bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version
) {
150 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
151 const QuicPacketPublicHeader
& header
) {
155 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
156 const QuicPacketHeader
& header
) {
160 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader
& header
) {
164 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame
& frame
) {
168 bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame
& frame
) {
172 bool NoOpFramerVisitor::OnStopWaitingFrame(
173 const QuicStopWaitingFrame
& frame
) {
177 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame
& frame
) {
181 bool NoOpFramerVisitor::OnRstStreamFrame(
182 const QuicRstStreamFrame
& frame
) {
186 bool NoOpFramerVisitor::OnConnectionCloseFrame(
187 const QuicConnectionCloseFrame
& frame
) {
191 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame
& frame
) {
195 bool NoOpFramerVisitor::OnWindowUpdateFrame(
196 const QuicWindowUpdateFrame
& frame
) {
200 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame
& frame
) {
204 MockConnectionVisitor::MockConnectionVisitor() {
207 MockConnectionVisitor::~MockConnectionVisitor() {
210 MockHelper::MockHelper() {
213 MockHelper::~MockHelper() {
216 const QuicClock
* MockHelper::GetClock() const {
220 QuicRandom
* MockHelper::GetRandomGenerator() {
221 return &random_generator_
;
224 QuicAlarm
* MockHelper::CreateAlarm(QuicAlarm::Delegate
* delegate
) {
225 return new TestAlarm(delegate
);
228 void MockHelper::AdvanceTime(QuicTime::Delta delta
) {
229 clock_
.AdvanceTime(delta
);
232 QuicPacketWriter
* NiceMockPacketWriterFactory::Create(
233 QuicConnection
* /*connection*/) const {
234 return new testing::NiceMock
<MockPacketWriter
>();
237 MockConnection::MockConnection(bool is_server
)
238 : QuicConnection(kTestConnectionId
,
239 IPEndPoint(TestPeerIPAddress(), kTestPort
),
240 new testing::NiceMock
<MockHelper
>(),
241 NiceMockPacketWriterFactory(),
242 /* owns_writer= */ true,
244 /* is_secure= */ false,
245 QuicSupportedVersions()),
249 MockConnection::MockConnection(bool is_server
, bool is_secure
)
250 : QuicConnection(kTestConnectionId
,
251 IPEndPoint(TestPeerIPAddress(), kTestPort
),
252 new testing::NiceMock
<MockHelper
>(),
253 NiceMockPacketWriterFactory(),
254 /* owns_writer= */ true,
257 QuicSupportedVersions()),
261 MockConnection::MockConnection(IPEndPoint address
,
263 : QuicConnection(kTestConnectionId
, address
,
264 new testing::NiceMock
<MockHelper
>(),
265 NiceMockPacketWriterFactory(),
266 /* owns_writer= */ true,
268 /* is_secure= */ false,
269 QuicSupportedVersions()),
273 MockConnection::MockConnection(QuicConnectionId connection_id
,
275 : QuicConnection(connection_id
,
276 IPEndPoint(TestPeerIPAddress(), kTestPort
),
277 new testing::NiceMock
<MockHelper
>(),
278 NiceMockPacketWriterFactory(),
279 /* owns_writer= */ true,
281 /* is_secure= */ false,
282 QuicSupportedVersions()),
286 MockConnection::MockConnection(bool is_server
,
287 const QuicVersionVector
& supported_versions
)
288 : QuicConnection(kTestConnectionId
,
289 IPEndPoint(TestPeerIPAddress(), kTestPort
),
290 new testing::NiceMock
<MockHelper
>(),
291 NiceMockPacketWriterFactory(),
292 /* owns_writer= */ true,
294 /* is_secure= */ false,
299 MockConnection::~MockConnection() {
302 void MockConnection::AdvanceTime(QuicTime::Delta delta
) {
303 static_cast<MockHelper
*>(helper())->AdvanceTime(delta
);
306 PacketSavingConnection::PacketSavingConnection(bool is_server
)
307 : MockConnection(is_server
) {
310 PacketSavingConnection::PacketSavingConnection(
312 const QuicVersionVector
& supported_versions
)
313 : MockConnection(is_server
, supported_versions
) {
316 PacketSavingConnection::~PacketSavingConnection() {
317 STLDeleteElements(&encrypted_packets_
);
320 void PacketSavingConnection::SendOrQueuePacket(QueuedPacket packet
) {
321 encrypted_packets_
.push_back(packet
.serialized_packet
.packet
);
322 // Transfer ownership of the packet to the SentPacketManager and the
323 // ack notifier to the AckNotifierManager.
324 sent_packet_manager_
.OnPacketSent(
325 &packet
.serialized_packet
, 0, QuicTime::Zero(), 1000,
326 NOT_RETRANSMISSION
, HAS_RETRANSMITTABLE_DATA
);
329 MockSession::MockSession(QuicConnection
* connection
)
330 : QuicSession(connection
, DefaultQuicConfig()) {
332 ON_CALL(*this, WritevData(_
, _
, _
, _
, _
, _
))
333 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
336 MockSession::~MockSession() {
339 TestSession::TestSession(QuicConnection
* connection
, const QuicConfig
& config
)
340 : QuicSession(connection
, config
),
341 crypto_stream_(nullptr) {
345 TestSession::~TestSession() {}
347 void TestSession::SetCryptoStream(QuicCryptoStream
* stream
) {
348 crypto_stream_
= stream
;
351 QuicCryptoStream
* TestSession::GetCryptoStream() {
352 return crypto_stream_
;
355 TestClientSession::TestClientSession(QuicConnection
* connection
,
356 const QuicConfig
& config
)
357 : QuicClientSessionBase(connection
, config
),
358 crypto_stream_(nullptr) {
359 EXPECT_CALL(*this, OnProofValid(_
)).Times(AnyNumber());
363 TestClientSession::~TestClientSession() {}
365 void TestClientSession::SetCryptoStream(QuicCryptoStream
* stream
) {
366 crypto_stream_
= stream
;
369 QuicCryptoStream
* TestClientSession::GetCryptoStream() {
370 return crypto_stream_
;
373 MockPacketWriter::MockPacketWriter() {
376 MockPacketWriter::~MockPacketWriter() {
379 MockSendAlgorithm::MockSendAlgorithm() {
382 MockSendAlgorithm::~MockSendAlgorithm() {
385 MockLossAlgorithm::MockLossAlgorithm() {
388 MockLossAlgorithm::~MockLossAlgorithm() {
391 MockAckNotifierDelegate::MockAckNotifierDelegate() {
394 MockAckNotifierDelegate::~MockAckNotifierDelegate() {
397 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {
400 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {
405 string
HexDumpWithMarks(const char* data
, int length
,
406 const bool* marks
, int mark_length
) {
407 static const char kHexChars
[] = "0123456789abcdef";
408 static const int kColumns
= 4;
410 const int kSizeLimit
= 1024;
411 if (length
> kSizeLimit
|| mark_length
> kSizeLimit
) {
412 LOG(ERROR
) << "Only dumping first " << kSizeLimit
<< " bytes.";
413 length
= min(length
, kSizeLimit
);
414 mark_length
= min(mark_length
, kSizeLimit
);
418 for (const char* row
= data
; length
> 0;
419 row
+= kColumns
, length
-= kColumns
) {
420 for (const char *p
= row
; p
< row
+ 4; ++p
) {
421 if (p
< row
+ length
) {
423 (marks
&& (p
- data
) < mark_length
&& marks
[p
- data
]);
424 hex
+= mark
? '*' : ' ';
425 hex
+= kHexChars
[(*p
& 0xf0) >> 4];
426 hex
+= kHexChars
[*p
& 0x0f];
427 hex
+= mark
? '*' : ' ';
434 for (const char *p
= row
; p
< row
+ 4 && p
< row
+ length
; ++p
)
435 hex
+= (*p
>= 0x20 && *p
<= 0x7f) ? (*p
) : '.';
444 IPAddressNumber
TestPeerIPAddress() { return Loopback4(); }
446 QuicVersion
QuicVersionMax() { return QuicSupportedVersions().front(); }
448 QuicVersion
QuicVersionMin() { return QuicSupportedVersions().back(); }
450 IPAddressNumber
Loopback4() {
451 IPAddressNumber addr
;
452 CHECK(ParseIPLiteralToNumber("127.0.0.1", &addr
));
456 IPAddressNumber
Loopback6() {
457 IPAddressNumber addr
;
458 CHECK(ParseIPLiteralToNumber("::1", &addr
));
462 void GenerateBody(string
* body
, int length
) {
464 body
->reserve(length
);
465 for (int i
= 0; i
< length
; ++i
) {
466 body
->append(1, static_cast<char>(32 + i
% (126 - 32)));
470 QuicEncryptedPacket
* ConstructEncryptedPacket(
471 QuicConnectionId connection_id
,
474 QuicPacketSequenceNumber sequence_number
,
475 const string
& data
) {
476 QuicPacketHeader header
;
477 header
.public_header
.connection_id
= connection_id
;
478 header
.public_header
.connection_id_length
= PACKET_8BYTE_CONNECTION_ID
;
479 header
.public_header
.version_flag
= version_flag
;
480 header
.public_header
.reset_flag
= reset_flag
;
481 header
.public_header
.sequence_number_length
= PACKET_6BYTE_SEQUENCE_NUMBER
;
482 header
.packet_sequence_number
= sequence_number
;
483 header
.entropy_flag
= false;
484 header
.entropy_hash
= 0;
485 header
.fec_flag
= false;
486 header
.is_in_fec_group
= NOT_IN_FEC_GROUP
;
487 header
.fec_group
= 0;
488 QuicStreamFrame
stream_frame(1, false, 0, MakeIOVector(data
));
489 QuicFrame
frame(&stream_frame
);
491 frames
.push_back(frame
);
492 QuicFramer
framer(QuicSupportedVersions(), QuicTime::Zero(), false);
493 scoped_ptr
<QuicPacket
> packet(
494 BuildUnsizedDataPacket(&framer
, header
, frames
));
495 EXPECT_TRUE(packet
!= nullptr);
496 QuicEncryptedPacket
* encrypted
= framer
.EncryptPacket(ENCRYPTION_NONE
,
499 EXPECT_TRUE(encrypted
!= nullptr);
503 void CompareCharArraysWithHexError(
504 const string
& description
,
506 const int actual_len
,
507 const char* expected
,
508 const int expected_len
) {
509 EXPECT_EQ(actual_len
, expected_len
);
510 const int min_len
= min(actual_len
, expected_len
);
511 const int max_len
= max(actual_len
, expected_len
);
512 scoped_ptr
<bool[]> marks(new bool[max_len
]);
513 bool identical
= (actual_len
== expected_len
);
514 for (int i
= 0; i
< min_len
; ++i
) {
515 if (actual
[i
] != expected
[i
]) {
522 for (int i
= min_len
; i
< max_len
; ++i
) {
525 if (identical
) return;
530 << HexDumpWithMarks(expected
, expected_len
, marks
.get(), max_len
)
532 << HexDumpWithMarks(actual
, actual_len
, marks
.get(), max_len
);
535 bool DecodeHexString(const base::StringPiece
& hex
, std::string
* bytes
) {
539 std::vector
<uint8
> v
;
540 if (!base::HexStringToBytes(hex
.as_string(), &v
))
543 bytes
->assign(reinterpret_cast<const char*>(&v
[0]), v
.size());
547 static QuicPacket
* ConstructPacketFromHandshakeMessage(
548 QuicConnectionId connection_id
,
549 const CryptoHandshakeMessage
& message
,
550 bool should_include_version
) {
551 CryptoFramer crypto_framer
;
552 scoped_ptr
<QuicData
> data(crypto_framer
.ConstructHandshakeMessage(message
));
553 QuicFramer
quic_framer(QuicSupportedVersions(), QuicTime::Zero(), false);
555 QuicPacketHeader header
;
556 header
.public_header
.connection_id
= connection_id
;
557 header
.public_header
.reset_flag
= false;
558 header
.public_header
.version_flag
= should_include_version
;
559 header
.packet_sequence_number
= 1;
560 header
.entropy_flag
= false;
561 header
.entropy_hash
= 0;
562 header
.fec_flag
= false;
563 header
.fec_group
= 0;
565 QuicStreamFrame
stream_frame(kCryptoStreamId
, false, 0,
566 MakeIOVector(data
->AsStringPiece()));
568 QuicFrame
frame(&stream_frame
);
570 frames
.push_back(frame
);
571 return BuildUnsizedDataPacket(&quic_framer
, header
, frames
);
574 QuicPacket
* ConstructHandshakePacket(QuicConnectionId connection_id
,
576 CryptoHandshakeMessage message
;
577 message
.set_tag(tag
);
578 return ConstructPacketFromHandshakeMessage(connection_id
, message
, false);
581 size_t GetPacketLengthForOneStream(
583 bool include_version
,
584 QuicConnectionIdLength connection_id_length
,
585 QuicSequenceNumberLength sequence_number_length
,
586 InFecGroup is_in_fec_group
,
587 size_t* payload_length
) {
589 const size_t stream_length
=
590 NullEncrypter().GetCiphertextSize(*payload_length
) +
591 QuicPacketCreator::StreamFramePacketOverhead(
592 PACKET_8BYTE_CONNECTION_ID
, include_version
,
593 sequence_number_length
, 0u, is_in_fec_group
);
594 const size_t ack_length
= NullEncrypter().GetCiphertextSize(
595 QuicFramer::GetMinAckFrameSize(
596 sequence_number_length
, PACKET_1BYTE_SEQUENCE_NUMBER
)) +
597 GetPacketHeaderSize(connection_id_length
, include_version
,
598 sequence_number_length
, is_in_fec_group
);
599 if (stream_length
< ack_length
) {
600 *payload_length
= 1 + ack_length
- stream_length
;
603 return NullEncrypter().GetCiphertextSize(*payload_length
) +
604 QuicPacketCreator::StreamFramePacketOverhead(
605 connection_id_length
, include_version
,
606 sequence_number_length
, 0u, is_in_fec_group
);
609 TestEntropyCalculator::TestEntropyCalculator() {}
611 TestEntropyCalculator::~TestEntropyCalculator() {}
613 QuicPacketEntropyHash
TestEntropyCalculator::EntropyHash(
614 QuicPacketSequenceNumber sequence_number
) const {
618 MockEntropyCalculator::MockEntropyCalculator() {}
620 MockEntropyCalculator::~MockEntropyCalculator() {}
622 QuicConfig
DefaultQuicConfig() {
624 config
.SetInitialStreamFlowControlWindowToSend(
625 kInitialStreamFlowControlWindowForTest
);
626 config
.SetInitialSessionFlowControlWindowToSend(
627 kInitialSessionFlowControlWindowForTest
);
631 QuicVersionVector
SupportedVersions(QuicVersion version
) {
632 QuicVersionVector versions
;
633 versions
.push_back(version
);
637 TestWriterFactory::TestWriterFactory() : current_writer_(nullptr) {}
638 TestWriterFactory::~TestWriterFactory() {}
640 QuicPacketWriter
* TestWriterFactory::Create(QuicServerPacketWriter
* writer
,
641 QuicConnection
* connection
) {
642 return new PerConnectionPacketWriter(this, writer
, connection
);
645 void TestWriterFactory::OnPacketSent(WriteResult result
) {
646 if (current_writer_
!= nullptr && result
.status
== WRITE_STATUS_ERROR
) {
647 current_writer_
->connection()->OnWriteError(result
.error_code
);
648 current_writer_
= nullptr;
652 void TestWriterFactory::Unregister(PerConnectionPacketWriter
* writer
) {
653 if (current_writer_
== writer
) {
654 current_writer_
= nullptr;
658 TestWriterFactory::PerConnectionPacketWriter::PerConnectionPacketWriter(
659 TestWriterFactory
* factory
,
660 QuicServerPacketWriter
* writer
,
661 QuicConnection
* connection
)
662 : QuicPerConnectionPacketWriter(writer
, connection
),
666 TestWriterFactory::PerConnectionPacketWriter::~PerConnectionPacketWriter() {
667 factory_
->Unregister(this);
670 WriteResult
TestWriterFactory::PerConnectionPacketWriter::WritePacket(
673 const IPAddressNumber
& self_address
,
674 const IPEndPoint
& peer_address
) {
675 // A DCHECK(factory_current_writer_ == nullptr) would be wrong here -- this
676 // class may be used in a setting where connection()->OnPacketSent() is called
677 // in a different way, so TestWriterFactory::OnPacketSent might never be
679 factory_
->current_writer_
= this;
680 return QuicPerConnectionPacketWriter::WritePacket(buffer
,