Supervised user whitelists: Cleanup
[chromium-blink-merge.git] / net / socket / ssl_server_socket_unittest.cc
blobea072b05fec373f4ab0626396f9591f8a1b6594a
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 // Connects SSL socket to FakeDataChannel. This class is just a stub.
11 // 2. FakeDataChannel
12 // Implements the actual exchange of data between two FakeSockets.
14 // Implementations of these two classes are included in this file.
16 #include "net/socket/ssl_server_socket.h"
18 #include <stdlib.h>
20 #include <queue>
22 #include "base/compiler_specific.h"
23 #include "base/files/file_path.h"
24 #include "base/files/file_util.h"
25 #include "base/message_loop/message_loop.h"
26 #include "crypto/nss_util.h"
27 #include "crypto/rsa_private_key.h"
28 #include "net/base/address_list.h"
29 #include "net/base/completion_callback.h"
30 #include "net/base/host_port_pair.h"
31 #include "net/base/io_buffer.h"
32 #include "net/base/ip_endpoint.h"
33 #include "net/base/net_errors.h"
34 #include "net/base/test_data_directory.h"
35 #include "net/cert/cert_status_flags.h"
36 #include "net/cert/mock_cert_verifier.h"
37 #include "net/cert/x509_certificate.h"
38 #include "net/http/transport_security_state.h"
39 #include "net/log/net_log.h"
40 #include "net/socket/client_socket_factory.h"
41 #include "net/socket/socket_test_util.h"
42 #include "net/socket/ssl_client_socket.h"
43 #include "net/socket/stream_socket.h"
44 #include "net/ssl/ssl_config_service.h"
45 #include "net/ssl/ssl_info.h"
46 #include "net/test/cert_test_util.h"
47 #include "testing/gtest/include/gtest/gtest.h"
48 #include "testing/platform_test.h"
50 namespace net {
52 namespace {
54 class FakeDataChannel {
55 public:
56 FakeDataChannel()
57 : read_buf_len_(0),
58 closed_(false),
59 write_called_after_close_(false),
60 weak_factory_(this) {
63 int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
64 DCHECK(read_callback_.is_null());
65 DCHECK(!read_buf_.get());
66 if (closed_)
67 return 0;
68 if (data_.empty()) {
69 read_callback_ = callback;
70 read_buf_ = buf;
71 read_buf_len_ = buf_len;
72 return ERR_IO_PENDING;
74 return PropogateData(buf, buf_len);
77 int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
78 DCHECK(write_callback_.is_null());
79 if (closed_) {
80 if (write_called_after_close_)
81 return ERR_CONNECTION_RESET;
82 write_called_after_close_ = true;
83 write_callback_ = callback;
84 base::MessageLoop::current()->PostTask(
85 FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback,
86 weak_factory_.GetWeakPtr()));
87 return ERR_IO_PENDING;
89 // This function returns synchronously, so make a copy of the buffer.
90 data_.push(new DrainableIOBuffer(
91 new StringIOBuffer(std::string(buf->data(), buf_len)),
92 buf_len));
93 base::MessageLoop::current()->PostTask(
94 FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
95 weak_factory_.GetWeakPtr()));
96 return buf_len;
99 // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
100 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
101 // after the FakeDataChannel is closed, the first Write() call completes
102 // asynchronously, which is necessary to reproduce bug 127822.
103 void Close() {
104 closed_ = true;
107 private:
108 void DoReadCallback() {
109 if (read_callback_.is_null() || data_.empty())
110 return;
112 int copied = PropogateData(read_buf_, read_buf_len_);
113 CompletionCallback callback = read_callback_;
114 read_callback_.Reset();
115 read_buf_ = NULL;
116 read_buf_len_ = 0;
117 callback.Run(copied);
120 void DoWriteCallback() {
121 if (write_callback_.is_null())
122 return;
124 CompletionCallback callback = write_callback_;
125 write_callback_.Reset();
126 callback.Run(ERR_CONNECTION_RESET);
129 int PropogateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
130 scoped_refptr<DrainableIOBuffer> buf = data_.front();
131 int copied = std::min(buf->BytesRemaining(), read_buf_len);
132 memcpy(read_buf->data(), buf->data(), copied);
133 buf->DidConsume(copied);
135 if (!buf->BytesRemaining())
136 data_.pop();
137 return copied;
140 CompletionCallback read_callback_;
141 scoped_refptr<IOBuffer> read_buf_;
142 int read_buf_len_;
144 CompletionCallback write_callback_;
146 std::queue<scoped_refptr<DrainableIOBuffer> > data_;
148 // True if Close() has been called.
149 bool closed_;
151 // Controls the completion of Write() after the FakeDataChannel is closed.
152 // After the FakeDataChannel is closed, the first Write() call completes
153 // asynchronously.
154 bool write_called_after_close_;
156 base::WeakPtrFactory<FakeDataChannel> weak_factory_;
158 DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
161 class FakeSocket : public StreamSocket {
162 public:
163 FakeSocket(FakeDataChannel* incoming_channel,
164 FakeDataChannel* outgoing_channel)
165 : incoming_(incoming_channel),
166 outgoing_(outgoing_channel) {
169 ~FakeSocket() override {}
171 int Read(IOBuffer* buf,
172 int buf_len,
173 const CompletionCallback& callback) override {
174 // Read random number of bytes.
175 buf_len = rand() % buf_len + 1;
176 return incoming_->Read(buf, buf_len, callback);
179 int Write(IOBuffer* buf,
180 int buf_len,
181 const CompletionCallback& callback) override {
182 // Write random number of bytes.
183 buf_len = rand() % buf_len + 1;
184 return outgoing_->Write(buf, buf_len, callback);
187 int SetReceiveBufferSize(int32 size) override { return OK; }
189 int SetSendBufferSize(int32 size) override { return OK; }
191 int Connect(const CompletionCallback& callback) override { return OK; }
193 void Disconnect() override {
194 incoming_->Close();
195 outgoing_->Close();
198 bool IsConnected() const override { return true; }
200 bool IsConnectedAndIdle() const override { return true; }
202 int GetPeerAddress(IPEndPoint* address) const override {
203 IPAddressNumber ip_address(kIPv4AddressSize);
204 *address = IPEndPoint(ip_address, 0 /*port*/);
205 return OK;
208 int GetLocalAddress(IPEndPoint* address) const override {
209 IPAddressNumber ip_address(4);
210 *address = IPEndPoint(ip_address, 0);
211 return OK;
214 const BoundNetLog& NetLog() const override { return net_log_; }
216 void SetSubresourceSpeculation() override {}
217 void SetOmniboxSpeculation() override {}
219 bool WasEverUsed() const override { return true; }
221 bool UsingTCPFastOpen() const override { return false; }
223 bool WasNpnNegotiated() const override { return false; }
225 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
227 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
229 private:
230 BoundNetLog net_log_;
231 FakeDataChannel* incoming_;
232 FakeDataChannel* outgoing_;
234 DISALLOW_COPY_AND_ASSIGN(FakeSocket);
237 } // namespace
239 // Verify the correctness of the test helper classes first.
240 TEST(FakeSocketTest, DataTransfer) {
241 // Establish channels between two sockets.
242 FakeDataChannel channel_1;
243 FakeDataChannel channel_2;
244 FakeSocket client(&channel_1, &channel_2);
245 FakeSocket server(&channel_2, &channel_1);
247 const char kTestData[] = "testing123";
248 const int kTestDataSize = strlen(kTestData);
249 const int kReadBufSize = 1024;
250 scoped_refptr<IOBuffer> write_buf = new StringIOBuffer(kTestData);
251 scoped_refptr<IOBuffer> read_buf = new IOBuffer(kReadBufSize);
253 // Write then read.
254 int written =
255 server.Write(write_buf.get(), kTestDataSize, CompletionCallback());
256 EXPECT_GT(written, 0);
257 EXPECT_LE(written, kTestDataSize);
259 int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback());
260 EXPECT_GT(read, 0);
261 EXPECT_LE(read, written);
262 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
264 // Read then write.
265 TestCompletionCallback callback;
266 EXPECT_EQ(ERR_IO_PENDING,
267 server.Read(read_buf.get(), kReadBufSize, callback.callback()));
269 written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback());
270 EXPECT_GT(written, 0);
271 EXPECT_LE(written, kTestDataSize);
273 read = callback.WaitForResult();
274 EXPECT_GT(read, 0);
275 EXPECT_LE(read, written);
276 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
279 class SSLServerSocketTest : public PlatformTest {
280 public:
281 SSLServerSocketTest()
282 : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
283 cert_verifier_(new MockCertVerifier()),
284 transport_security_state_(new TransportSecurityState) {
285 cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID);
288 protected:
289 void Initialize() {
290 scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle);
291 client_connection->SetSocket(
292 scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_)));
293 scoped_ptr<StreamSocket> server_socket(
294 new FakeSocket(&channel_2_, &channel_1_));
296 base::FilePath certs_dir(GetTestCertsDirectory());
298 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
299 std::string cert_der;
300 ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
302 scoped_refptr<X509Certificate> cert =
303 X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
305 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
306 std::string key_string;
307 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
308 std::vector<uint8> key_vector(
309 reinterpret_cast<const uint8*>(key_string.data()),
310 reinterpret_cast<const uint8*>(key_string.data() +
311 key_string.length()));
313 scoped_ptr<crypto::RSAPrivateKey> private_key(
314 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
316 SSLConfig ssl_config;
317 ssl_config.false_start_enabled = false;
318 ssl_config.channel_id_enabled = false;
320 // Certificate provided by the host doesn't need authority.
321 SSLConfig::CertAndStatus cert_and_status;
322 cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
323 cert_and_status.der_cert = cert_der;
324 ssl_config.allowed_bad_certs.push_back(cert_and_status);
326 HostPortPair host_and_pair("unittest", 0);
327 SSLClientSocketContext context;
328 context.cert_verifier = cert_verifier_.get();
329 context.transport_security_state = transport_security_state_.get();
330 client_socket_ =
331 socket_factory_->CreateSSLClientSocket(
332 client_connection.Pass(), host_and_pair, ssl_config, context);
333 server_socket_ = CreateSSLServerSocket(
334 server_socket.Pass(),
335 cert.get(), private_key.get(), SSLConfig());
338 FakeDataChannel channel_1_;
339 FakeDataChannel channel_2_;
340 scoped_ptr<SSLClientSocket> client_socket_;
341 scoped_ptr<SSLServerSocket> server_socket_;
342 ClientSocketFactory* socket_factory_;
343 scoped_ptr<MockCertVerifier> cert_verifier_;
344 scoped_ptr<TransportSecurityState> transport_security_state_;
347 // This test only executes creation of client and server sockets. This is to
348 // test that creation of sockets doesn't crash and have minimal code to run
349 // under valgrind in order to help debugging memory problems.
350 TEST_F(SSLServerSocketTest, Initialize) {
351 Initialize();
354 // This test executes Connect() on SSLClientSocket and Handshake() on
355 // SSLServerSocket to make sure handshaking between the two sockets is
356 // completed successfully.
357 TEST_F(SSLServerSocketTest, Handshake) {
358 Initialize();
360 TestCompletionCallback connect_callback;
361 TestCompletionCallback handshake_callback;
363 int server_ret = server_socket_->Handshake(handshake_callback.callback());
364 EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
366 int client_ret = client_socket_->Connect(connect_callback.callback());
367 EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
369 if (client_ret == ERR_IO_PENDING) {
370 EXPECT_EQ(OK, connect_callback.WaitForResult());
372 if (server_ret == ERR_IO_PENDING) {
373 EXPECT_EQ(OK, handshake_callback.WaitForResult());
376 // Make sure the cert status is expected.
377 SSLInfo ssl_info;
378 client_socket_->GetSSLInfo(&ssl_info);
379 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
382 TEST_F(SSLServerSocketTest, DataTransfer) {
383 Initialize();
385 TestCompletionCallback connect_callback;
386 TestCompletionCallback handshake_callback;
388 // Establish connection.
389 int client_ret = client_socket_->Connect(connect_callback.callback());
390 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
392 int server_ret = server_socket_->Handshake(handshake_callback.callback());
393 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
395 client_ret = connect_callback.GetResult(client_ret);
396 ASSERT_EQ(OK, client_ret);
397 server_ret = handshake_callback.GetResult(server_ret);
398 ASSERT_EQ(OK, server_ret);
400 const int kReadBufSize = 1024;
401 scoped_refptr<StringIOBuffer> write_buf =
402 new StringIOBuffer("testing123");
403 scoped_refptr<DrainableIOBuffer> read_buf =
404 new DrainableIOBuffer(new IOBuffer(kReadBufSize), kReadBufSize);
406 // Write then read.
407 TestCompletionCallback write_callback;
408 TestCompletionCallback read_callback;
409 server_ret = server_socket_->Write(
410 write_buf.get(), write_buf->size(), write_callback.callback());
411 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
412 client_ret = client_socket_->Read(
413 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
414 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
416 server_ret = write_callback.GetResult(server_ret);
417 EXPECT_GT(server_ret, 0);
418 client_ret = read_callback.GetResult(client_ret);
419 ASSERT_GT(client_ret, 0);
421 read_buf->DidConsume(client_ret);
422 while (read_buf->BytesConsumed() < write_buf->size()) {
423 client_ret = client_socket_->Read(
424 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
425 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
426 client_ret = read_callback.GetResult(client_ret);
427 ASSERT_GT(client_ret, 0);
428 read_buf->DidConsume(client_ret);
430 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
431 read_buf->SetOffset(0);
432 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
434 // Read then write.
435 write_buf = new StringIOBuffer("hello123");
436 server_ret = server_socket_->Read(
437 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
438 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
439 client_ret = client_socket_->Write(
440 write_buf.get(), write_buf->size(), write_callback.callback());
441 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
443 server_ret = read_callback.GetResult(server_ret);
444 ASSERT_GT(server_ret, 0);
445 client_ret = write_callback.GetResult(client_ret);
446 EXPECT_GT(client_ret, 0);
448 read_buf->DidConsume(server_ret);
449 while (read_buf->BytesConsumed() < write_buf->size()) {
450 server_ret = server_socket_->Read(
451 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
452 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
453 server_ret = read_callback.GetResult(server_ret);
454 ASSERT_GT(server_ret, 0);
455 read_buf->DidConsume(server_ret);
457 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
458 read_buf->SetOffset(0);
459 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
462 // A regression test for bug 127822 (http://crbug.com/127822).
463 // If the server closes the connection after the handshake is finished,
464 // the client's Write() call should not cause an infinite loop.
465 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
466 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
467 Initialize();
469 TestCompletionCallback connect_callback;
470 TestCompletionCallback handshake_callback;
472 // Establish connection.
473 int client_ret = client_socket_->Connect(connect_callback.callback());
474 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
476 int server_ret = server_socket_->Handshake(handshake_callback.callback());
477 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
479 client_ret = connect_callback.GetResult(client_ret);
480 ASSERT_EQ(OK, client_ret);
481 server_ret = handshake_callback.GetResult(server_ret);
482 ASSERT_EQ(OK, server_ret);
484 scoped_refptr<StringIOBuffer> write_buf = new StringIOBuffer("testing123");
486 // The server closes the connection. The server needs to write some
487 // data first so that the client's Read() calls from the transport
488 // socket won't return ERR_IO_PENDING. This ensures that the client
489 // will call Read() on the transport socket again.
490 TestCompletionCallback write_callback;
492 server_ret = server_socket_->Write(
493 write_buf.get(), write_buf->size(), write_callback.callback());
494 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
496 server_ret = write_callback.GetResult(server_ret);
497 EXPECT_GT(server_ret, 0);
499 server_socket_->Disconnect();
501 // The client writes some data. This should not cause an infinite loop.
502 client_ret = client_socket_->Write(
503 write_buf.get(), write_buf->size(), write_callback.callback());
504 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
506 client_ret = write_callback.GetResult(client_ret);
507 EXPECT_GT(client_ret, 0);
509 base::MessageLoop::current()->PostDelayedTask(
510 FROM_HERE, base::MessageLoop::QuitClosure(),
511 base::TimeDelta::FromMilliseconds(10));
512 base::MessageLoop::current()->Run();
515 // This test executes ExportKeyingMaterial() on the client and server sockets,
516 // after connecting them, and verifies that the results match.
517 // This test will fail if False Start is enabled (see crbug.com/90208).
518 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
519 Initialize();
521 TestCompletionCallback connect_callback;
522 TestCompletionCallback handshake_callback;
524 int client_ret = client_socket_->Connect(connect_callback.callback());
525 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
527 int server_ret = server_socket_->Handshake(handshake_callback.callback());
528 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
530 if (client_ret == ERR_IO_PENDING) {
531 ASSERT_EQ(OK, connect_callback.WaitForResult());
533 if (server_ret == ERR_IO_PENDING) {
534 ASSERT_EQ(OK, handshake_callback.WaitForResult());
537 const int kKeyingMaterialSize = 32;
538 const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";
539 const char kKeyingContext[] = "";
540 unsigned char server_out[kKeyingMaterialSize];
541 int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel,
542 false, kKeyingContext,
543 server_out, sizeof(server_out));
544 ASSERT_EQ(OK, rv);
546 unsigned char client_out[kKeyingMaterialSize];
547 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel,
548 false, kKeyingContext,
549 client_out, sizeof(client_out));
550 ASSERT_EQ(OK, rv);
551 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
553 const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
554 unsigned char client_bad[kKeyingMaterialSize];
555 rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad,
556 false, kKeyingContext,
557 client_bad, sizeof(client_bad));
558 ASSERT_EQ(rv, OK);
559 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
562 } // namespace net