Re-land: C++ readability review
[chromium-blink-merge.git] / net / quic / test_tools / crypto_test_utils.cc
blob: f402256e112710ce767e5cb7205275c80de8309b
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/quic/test_tools/crypto_test_utils.h"
7 #include "net/quic/crypto/channel_id.h"
8 #include "net/quic/crypto/common_cert_set.h"
9 #include "net/quic/crypto/crypto_handshake.h"
10 #include "net/quic/crypto/quic_crypto_server_config.h"
11 #include "net/quic/crypto/quic_decrypter.h"
12 #include "net/quic/crypto/quic_encrypter.h"
13 #include "net/quic/crypto/quic_random.h"
14 #include "net/quic/quic_clock.h"
15 #include "net/quic/quic_crypto_client_stream.h"
16 #include "net/quic/quic_crypto_server_stream.h"
17 #include "net/quic/quic_crypto_stream.h"
18 #include "net/quic/quic_server_id.h"
19 #include "net/quic/test_tools/quic_connection_peer.h"
20 #include "net/quic/test_tools/quic_framer_peer.h"
21 #include "net/quic/test_tools/quic_test_utils.h"
22 #include "net/quic/test_tools/simple_quic_framer.h"
24 using base::StringPiece;
25 using std::make_pair;
26 using std::pair;
27 using std::string;
28 using std::vector;
30 namespace net {
31 namespace test {
33 namespace {
35 const char kServerHostname[] = "test.example.com";
36 const uint16 kServerPort = 80;
38 // CryptoFramerVisitor is a framer visitor that records handshake messages.
39 class CryptoFramerVisitor : public CryptoFramerVisitorInterface {
40 public:
41 CryptoFramerVisitor()
42 : error_(false) {
45 void OnError(CryptoFramer* framer) override { error_ = true; }
47 void OnHandshakeMessage(const CryptoHandshakeMessage& message) override {
48 messages_.push_back(message);
51 bool error() const {
52 return error_;
55 const vector<CryptoHandshakeMessage>& messages() const {
56 return messages_;
59 private:
60 bool error_;
61 vector<CryptoHandshakeMessage> messages_;
64 // MovePackets parses crypto handshake messages from packet number
65 // |*inout_packet_index| through to the last packet (or until a packet fails to
66 // decrypt) and has |dest_stream| process them. |*inout_packet_index| is updated
67 // with an index one greater than the last packet processed.
68 void MovePackets(PacketSavingConnection* source_conn,
69 size_t *inout_packet_index,
70 QuicCryptoStream* dest_stream,
71 PacketSavingConnection* dest_conn) {
72 SimpleQuicFramer framer(source_conn->supported_versions());
73 CryptoFramer crypto_framer;
74 CryptoFramerVisitor crypto_visitor;
76 // In order to properly test the code we need to perform encryption and
77 // decryption so that the crypters latch when expected. The crypters are in
78 // |dest_conn|, but we don't want to try and use them there. Instead we swap
79 // them into |framer|, perform the decryption with them, and then swap them
80 // back.
81 QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer());
83 crypto_framer.set_visitor(&crypto_visitor);
85 size_t index = *inout_packet_index;
86 for (; index < source_conn->encrypted_packets_.size(); index++) {
87 if (!framer.ProcessPacket(*source_conn->encrypted_packets_[index])) {
88 // The framer will be unable to decrypt forward-secure packets sent after
89 // the handshake is complete. Don't treat them as handshake packets.
90 break;
93 for (vector<QuicStreamFrame>::const_iterator
94 i = framer.stream_frames().begin();
95 i != framer.stream_frames().end(); ++i) {
96 scoped_ptr<string> frame_data(i->GetDataAsString());
97 ASSERT_TRUE(crypto_framer.ProcessInput(*frame_data));
98 ASSERT_FALSE(crypto_visitor.error());
101 *inout_packet_index = index;
103 QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer());
105 ASSERT_EQ(0u, crypto_framer.InputBytesRemaining());
107 for (vector<CryptoHandshakeMessage>::const_iterator
108 i = crypto_visitor.messages().begin();
109 i != crypto_visitor.messages().end(); ++i) {
110 dest_stream->OnHandshakeMessage(*i);
114 // HexChar parses |c| as a hex character. If valid, it sets |*value| to the
115 // value of the hex character and returns true. Otherwise it returns false.
116 bool HexChar(char c, uint8* value) {
117 if (c >= '0' && c <= '9') {
118 *value = c - '0';
119 return true;
121 if (c >= 'a' && c <= 'f') {
122 *value = c - 'a' + 10;
123 return true;
125 if (c >= 'A' && c <= 'F') {
126 *value = c - 'A' + 10;
127 return true;
129 return false;
132 // A ChannelIDSource that works in asynchronous mode unless the |callback|
133 // argument to GetChannelIDKey is nullptr.
134 class AsyncTestChannelIDSource : public ChannelIDSource,
135 public CryptoTestUtils::CallbackSource {
136 public:
137 // Takes ownership of |sync_source|, a synchronous ChannelIDSource.
138 explicit AsyncTestChannelIDSource(ChannelIDSource* sync_source)
139 : sync_source_(sync_source) {}
140 ~AsyncTestChannelIDSource() override {}
142 // ChannelIDSource implementation.
143 QuicAsyncStatus GetChannelIDKey(const string& hostname,
144 scoped_ptr<ChannelIDKey>* channel_id_key,
145 ChannelIDSourceCallback* callback) override {
146 // Synchronous mode.
147 if (!callback) {
148 return sync_source_->GetChannelIDKey(hostname, channel_id_key, nullptr);
151 // Asynchronous mode.
152 QuicAsyncStatus status =
153 sync_source_->GetChannelIDKey(hostname, &channel_id_key_, nullptr);
154 if (status != QUIC_SUCCESS) {
155 return QUIC_FAILURE;
157 callback_.reset(callback);
158 return QUIC_PENDING;
161 // CallbackSource implementation.
162 void RunPendingCallbacks() override {
163 if (callback_.get()) {
164 callback_->Run(&channel_id_key_);
165 callback_.reset();
169 private:
170 scoped_ptr<ChannelIDSource> sync_source_;
171 scoped_ptr<ChannelIDSourceCallback> callback_;
172 scoped_ptr<ChannelIDKey> channel_id_key_;
175 } // anonymous namespace
177 CryptoTestUtils::FakeClientOptions::FakeClientOptions()
178 : dont_verify_certs(false),
179 channel_id_enabled(false),
180 channel_id_source_async(false) {
183 // static
184 int CryptoTestUtils::HandshakeWithFakeServer(
185 PacketSavingConnection* client_conn,
186 QuicCryptoClientStream* client) {
187 PacketSavingConnection* server_conn = new PacketSavingConnection(
188 Perspective::IS_SERVER, client_conn->supported_versions());
189 TestSession server_session(server_conn, DefaultQuicConfig());
190 server_session.InitializeSession();
191 QuicCryptoServerConfig crypto_config(QuicCryptoServerConfig::TESTING,
192 QuicRandom::GetInstance());
194 SetupCryptoServerConfigForTest(
195 server_session.connection()->clock(),
196 server_session.connection()->random_generator(),
197 server_session.config(), &crypto_config);
199 QuicCryptoServerStream server(&crypto_config, &server_session);
200 server_session.SetCryptoStream(&server);
202 // The client's handshake must have been started already.
203 CHECK_NE(0u, client_conn->encrypted_packets_.size());
205 CommunicateHandshakeMessages(client_conn, client, server_conn, &server);
207 CompareClientAndServerKeys(client, &server);
209 return client->num_sent_client_hellos();
212 // static
213 int CryptoTestUtils::HandshakeWithFakeClient(
214 PacketSavingConnection* server_conn,
215 QuicCryptoServerStream* server,
216 const FakeClientOptions& options) {
217 PacketSavingConnection* client_conn =
218 new PacketSavingConnection(Perspective::IS_CLIENT);
219 // Advance the time, because timers do not like uninitialized times.
220 client_conn->AdvanceTime(QuicTime::Delta::FromSeconds(1));
221 TestClientSession client_session(client_conn, DefaultQuicConfig());
222 QuicCryptoClientConfig crypto_config;
224 if (!options.dont_verify_certs) {
225 // TODO(wtc): replace this with ProofVerifierForTesting() when we have
226 // a working ProofSourceForTesting().
227 crypto_config.SetProofVerifier(FakeProofVerifierForTesting());
229 bool is_https = false;
230 AsyncTestChannelIDSource* async_channel_id_source = nullptr;
231 if (options.channel_id_enabled) {
232 is_https = true;
234 ChannelIDSource* source = ChannelIDSourceForTesting();
235 if (options.channel_id_source_async) {
236 async_channel_id_source = new AsyncTestChannelIDSource(source);
237 source = async_channel_id_source;
239 crypto_config.SetChannelIDSource(source);
241 QuicServerId server_id(kServerHostname, kServerPort, is_https,
242 PRIVACY_MODE_DISABLED);
243 QuicCryptoClientStream client(server_id, &client_session,
244 ProofVerifyContextForTesting(),
245 &crypto_config);
246 client_session.SetCryptoStream(&client);
248 client.CryptoConnect();
249 CHECK_EQ(1u, client_conn->encrypted_packets_.size());
251 CommunicateHandshakeMessagesAndRunCallbacks(
252 client_conn, &client, server_conn, server, async_channel_id_source);
254 CompareClientAndServerKeys(&client, server);
256 if (options.channel_id_enabled) {
257 scoped_ptr<ChannelIDKey> channel_id_key;
258 QuicAsyncStatus status = crypto_config.channel_id_source()->GetChannelIDKey(
259 kServerHostname, &channel_id_key, nullptr);
260 EXPECT_EQ(QUIC_SUCCESS, status);
261 EXPECT_EQ(channel_id_key->SerializeKey(),
262 server->crypto_negotiated_params().channel_id);
263 EXPECT_EQ(options.channel_id_source_async,
264 client.WasChannelIDSourceCallbackRun());
267 return client.num_sent_client_hellos();
270 // static
271 void CryptoTestUtils::SetupCryptoServerConfigForTest(
272 const QuicClock* clock,
273 QuicRandom* rand,
274 QuicConfig* config,
275 QuicCryptoServerConfig* crypto_config) {
276 QuicCryptoServerConfig::ConfigOptions options;
277 options.channel_id_enabled = true;
278 scoped_ptr<CryptoHandshakeMessage> scfg(
279 crypto_config->AddDefaultConfig(rand, clock, options));
282 // static
283 void CryptoTestUtils::CommunicateHandshakeMessages(
284 PacketSavingConnection* a_conn,
285 QuicCryptoStream* a,
286 PacketSavingConnection* b_conn,
287 QuicCryptoStream* b) {
288 CommunicateHandshakeMessagesAndRunCallbacks(a_conn, a, b_conn, b, nullptr);
291 // static
292 void CryptoTestUtils::CommunicateHandshakeMessagesAndRunCallbacks(
293 PacketSavingConnection* a_conn,
294 QuicCryptoStream* a,
295 PacketSavingConnection* b_conn,
296 QuicCryptoStream* b,
297 CallbackSource* callback_source) {
298 size_t a_i = 0, b_i = 0;
299 while (!a->handshake_confirmed()) {
300 ASSERT_GT(a_conn->encrypted_packets_.size(), a_i);
301 VLOG(1) << "Processing " << a_conn->encrypted_packets_.size() - a_i
302 << " packets a->b";
303 MovePackets(a_conn, &a_i, b, b_conn);
304 if (callback_source) {
305 callback_source->RunPendingCallbacks();
308 ASSERT_GT(b_conn->encrypted_packets_.size(), b_i);
309 VLOG(1) << "Processing " << b_conn->encrypted_packets_.size() - b_i
310 << " packets b->a";
311 MovePackets(b_conn, &b_i, a, a_conn);
312 if (callback_source) {
313 callback_source->RunPendingCallbacks();
318 // static
319 pair<size_t, size_t> CryptoTestUtils::AdvanceHandshake(
320 PacketSavingConnection* a_conn,
321 QuicCryptoStream* a,
322 size_t a_i,
323 PacketSavingConnection* b_conn,
324 QuicCryptoStream* b,
325 size_t b_i) {
326 VLOG(1) << "Processing " << a_conn->encrypted_packets_.size() - a_i
327 << " packets a->b";
328 MovePackets(a_conn, &a_i, b, b_conn);
330 VLOG(1) << "Processing " << b_conn->encrypted_packets_.size() - b_i
331 << " packets b->a";
332 if (b_conn->encrypted_packets_.size() - b_i == 2) {
333 VLOG(1) << "here";
335 MovePackets(b_conn, &b_i, a, a_conn);
337 return std::make_pair(a_i, b_i);
340 // static
341 string CryptoTestUtils::GetValueForTag(const CryptoHandshakeMessage& message,
342 QuicTag tag) {
343 QuicTagValueMap::const_iterator it = message.tag_value_map().find(tag);
344 if (it == message.tag_value_map().end()) {
345 return string();
347 return it->second;
350 class MockCommonCertSets : public CommonCertSets {
351 public:
352 MockCommonCertSets(StringPiece cert, uint64 hash, uint32 index)
353 : cert_(cert.as_string()),
354 hash_(hash),
355 index_(index) {
358 StringPiece GetCommonHashes() const override {
359 CHECK(false) << "not implemented";
360 return StringPiece();
363 StringPiece GetCert(uint64 hash, uint32 index) const override {
364 if (hash == hash_ && index == index_) {
365 return cert_;
367 return StringPiece();
370 bool MatchCert(StringPiece cert,
371 StringPiece common_set_hashes,
372 uint64* out_hash,
373 uint32* out_index) const override {
374 if (cert != cert_) {
375 return false;
378 if (common_set_hashes.size() % sizeof(uint64) != 0) {
379 return false;
381 bool client_has_set = false;
382 for (size_t i = 0; i < common_set_hashes.size(); i += sizeof(uint64)) {
383 uint64 hash;
384 memcpy(&hash, common_set_hashes.data() + i, sizeof(hash));
385 if (hash == hash_) {
386 client_has_set = true;
387 break;
391 if (!client_has_set) {
392 return false;
395 *out_hash = hash_;
396 *out_index = index_;
397 return true;
400 private:
401 const string cert_;
402 const uint64 hash_;
403 const uint32 index_;
406 CommonCertSets* CryptoTestUtils::MockCommonCertSets(StringPiece cert,
407 uint64 hash,
408 uint32 index) {
409 return new class MockCommonCertSets(cert, hash, index);
412 void CryptoTestUtils::CompareClientAndServerKeys(
413 QuicCryptoClientStream* client,
414 QuicCryptoServerStream* server) {
415 QuicFramer* client_framer =
416 QuicConnectionPeer::GetFramer(client->session()->connection());
417 QuicFramer* server_framer =
418 QuicConnectionPeer::GetFramer(server->session()->connection());
419 const QuicEncrypter* client_encrypter(
420 QuicFramerPeer::GetEncrypter(client_framer, ENCRYPTION_INITIAL));
421 const QuicDecrypter* client_decrypter(
422 client->session()->connection()->decrypter());
423 const QuicEncrypter* client_forward_secure_encrypter(
424 QuicFramerPeer::GetEncrypter(client_framer, ENCRYPTION_FORWARD_SECURE));
425 const QuicDecrypter* client_forward_secure_decrypter(
426 client->session()->connection()->alternative_decrypter());
427 const QuicEncrypter* server_encrypter(
428 QuicFramerPeer::GetEncrypter(server_framer, ENCRYPTION_INITIAL));
429 const QuicDecrypter* server_decrypter(
430 server->session()->connection()->decrypter());
431 const QuicEncrypter* server_forward_secure_encrypter(
432 QuicFramerPeer::GetEncrypter(server_framer, ENCRYPTION_FORWARD_SECURE));
433 const QuicDecrypter* server_forward_secure_decrypter(
434 server->session()->connection()->alternative_decrypter());
436 StringPiece client_encrypter_key = client_encrypter->GetKey();
437 StringPiece client_encrypter_iv = client_encrypter->GetNoncePrefix();
438 StringPiece client_decrypter_key = client_decrypter->GetKey();
439 StringPiece client_decrypter_iv = client_decrypter->GetNoncePrefix();
440 StringPiece client_forward_secure_encrypter_key =
441 client_forward_secure_encrypter->GetKey();
442 StringPiece client_forward_secure_encrypter_iv =
443 client_forward_secure_encrypter->GetNoncePrefix();
444 StringPiece client_forward_secure_decrypter_key =
445 client_forward_secure_decrypter->GetKey();
446 StringPiece client_forward_secure_decrypter_iv =
447 client_forward_secure_decrypter->GetNoncePrefix();
448 StringPiece server_encrypter_key = server_encrypter->GetKey();
449 StringPiece server_encrypter_iv = server_encrypter->GetNoncePrefix();
450 StringPiece server_decrypter_key = server_decrypter->GetKey();
451 StringPiece server_decrypter_iv = server_decrypter->GetNoncePrefix();
452 StringPiece server_forward_secure_encrypter_key =
453 server_forward_secure_encrypter->GetKey();
454 StringPiece server_forward_secure_encrypter_iv =
455 server_forward_secure_encrypter->GetNoncePrefix();
456 StringPiece server_forward_secure_decrypter_key =
457 server_forward_secure_decrypter->GetKey();
458 StringPiece server_forward_secure_decrypter_iv =
459 server_forward_secure_decrypter->GetNoncePrefix();
461 StringPiece client_subkey_secret =
462 client->crypto_negotiated_params().subkey_secret;
463 StringPiece server_subkey_secret =
464 server->crypto_negotiated_params().subkey_secret;
467 const char kSampleLabel[] = "label";
468 const char kSampleContext[] = "context";
469 const size_t kSampleOutputLength = 32;
470 string client_key_extraction;
471 string server_key_extraction;
472 EXPECT_TRUE(client->ExportKeyingMaterial(kSampleLabel,
473 kSampleContext,
474 kSampleOutputLength,
475 &client_key_extraction));
476 EXPECT_TRUE(server->ExportKeyingMaterial(kSampleLabel,
477 kSampleContext,
478 kSampleOutputLength,
479 &server_key_extraction));
481 CompareCharArraysWithHexError("client write key",
482 client_encrypter_key.data(),
483 client_encrypter_key.length(),
484 server_decrypter_key.data(),
485 server_decrypter_key.length());
486 CompareCharArraysWithHexError("client write IV",
487 client_encrypter_iv.data(),
488 client_encrypter_iv.length(),
489 server_decrypter_iv.data(),
490 server_decrypter_iv.length());
491 CompareCharArraysWithHexError("server write key",
492 server_encrypter_key.data(),
493 server_encrypter_key.length(),
494 client_decrypter_key.data(),
495 client_decrypter_key.length());
496 CompareCharArraysWithHexError("server write IV",
497 server_encrypter_iv.data(),
498 server_encrypter_iv.length(),
499 client_decrypter_iv.data(),
500 client_decrypter_iv.length());
501 CompareCharArraysWithHexError("client forward secure write key",
502 client_forward_secure_encrypter_key.data(),
503 client_forward_secure_encrypter_key.length(),
504 server_forward_secure_decrypter_key.data(),
505 server_forward_secure_decrypter_key.length());
506 CompareCharArraysWithHexError("client forward secure write IV",
507 client_forward_secure_encrypter_iv.data(),
508 client_forward_secure_encrypter_iv.length(),
509 server_forward_secure_decrypter_iv.data(),
510 server_forward_secure_decrypter_iv.length());
511 CompareCharArraysWithHexError("server forward secure write key",
512 server_forward_secure_encrypter_key.data(),
513 server_forward_secure_encrypter_key.length(),
514 client_forward_secure_decrypter_key.data(),
515 client_forward_secure_decrypter_key.length());
516 CompareCharArraysWithHexError("server forward secure write IV",
517 server_forward_secure_encrypter_iv.data(),
518 server_forward_secure_encrypter_iv.length(),
519 client_forward_secure_decrypter_iv.data(),
520 client_forward_secure_decrypter_iv.length());
521 CompareCharArraysWithHexError("subkey secret",
522 client_subkey_secret.data(),
523 client_subkey_secret.length(),
524 server_subkey_secret.data(),
525 server_subkey_secret.length());
526 CompareCharArraysWithHexError("sample key extraction",
527 client_key_extraction.data(),
528 client_key_extraction.length(),
529 server_key_extraction.data(),
530 server_key_extraction.length());
533 // static
534 QuicTag CryptoTestUtils::ParseTag(const char* tagstr) {
535 const size_t len = strlen(tagstr);
536 CHECK_NE(0u, len);
538 QuicTag tag = 0;
540 if (tagstr[0] == '#') {
541 CHECK_EQ(static_cast<size_t>(1 + 2*4), len);
542 tagstr++;
544 for (size_t i = 0; i < 8; i++) {
545 tag <<= 4;
547 uint8 v = 0;
548 CHECK(HexChar(tagstr[i], &v));
549 tag |= v;
552 return tag;
555 CHECK_LE(len, 4u);
556 for (size_t i = 0; i < 4; i++) {
557 tag >>= 8;
558 if (i < len) {
559 tag |= static_cast<uint32>(tagstr[i]) << 24;
563 return tag;
566 // static
567 CryptoHandshakeMessage CryptoTestUtils::Message(const char* message_tag, ...) {
568 va_list ap;
569 va_start(ap, message_tag);
571 CryptoHandshakeMessage message = BuildMessage(message_tag, ap);
572 va_end(ap);
573 return message;
576 // static
577 CryptoHandshakeMessage CryptoTestUtils::BuildMessage(const char* message_tag,
578 va_list ap) {
579 CryptoHandshakeMessage msg;
580 msg.set_tag(ParseTag(message_tag));
582 for (;;) {
583 const char* tagstr = va_arg(ap, const char*);
584 if (tagstr == nullptr) {
585 break;
588 if (tagstr[0] == '$') {
589 // Special value.
590 const char* const special = tagstr + 1;
591 if (strcmp(special, "padding") == 0) {
592 const int min_bytes = va_arg(ap, int);
593 msg.set_minimum_size(min_bytes);
594 } else {
595 CHECK(false) << "Unknown special value: " << special;
598 continue;
601 const QuicTag tag = ParseTag(tagstr);
602 const char* valuestr = va_arg(ap, const char*);
604 size_t len = strlen(valuestr);
605 if (len > 0 && valuestr[0] == '#') {
606 valuestr++;
607 len--;
609 CHECK_EQ(0u, len % 2);
610 scoped_ptr<uint8[]> buf(new uint8[len/2]);
612 for (size_t i = 0; i < len/2; i++) {
613 uint8 v = 0;
614 CHECK(HexChar(valuestr[i*2], &v));
615 buf[i] = v << 4;
616 CHECK(HexChar(valuestr[i*2 + 1], &v));
617 buf[i] |= v;
620 msg.SetStringPiece(
621 tag, StringPiece(reinterpret_cast<char*>(buf.get()), len/2));
622 continue;
625 msg.SetStringPiece(tag, valuestr);
628 // The CryptoHandshakeMessage needs to be serialized and parsed to ensure
629 // that any padding is included.
630 scoped_ptr<QuicData> bytes(CryptoFramer::ConstructHandshakeMessage(msg));
631 scoped_ptr<CryptoHandshakeMessage> parsed(
632 CryptoFramer::ParseMessage(bytes->AsStringPiece()));
633 CHECK(parsed.get());
635 return *parsed;
638 } // namespace test
639 } // namespace net