1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
7 #include "base/base64.h"
9 #include "base/file_util.h"
10 #include "base/files/file_path.h"
11 #include "base/message_loop/message_loop.h"
12 #include "base/path_service.h"
13 #include "base/test/test_timeouts.h"
14 #include "base/timer/timer.h"
15 #include "crypto/rsa_private_key.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/test_data_directory.h"
18 #include "remoting/base/rsa_key_pair.h"
19 #include "remoting/protocol/connection_tester.h"
20 #include "remoting/protocol/fake_session.h"
21 #include "testing/gmock/include/gmock/gmock.h"
22 #include "testing/gtest/include/gtest/gtest.h"
23 #include "third_party/webrtc/libjingle/xmllite/xmlelement.h"
26 using testing::NotNull
;
27 using testing::SaveArg
;
// Shared secret handed to both the client and the host in the success case.
const char kTestSharedSecret[] = "1234-1234-5678";

// A secret that differs from kTestSharedSecret; the failure test gives it to
// the client only, so the two ends disagree.
const char kTestSharedSecretBad[] = "0000-0000-0001";
37 class MockChannelDoneCallback
{
39 MOCK_METHOD2(OnDone
, void(net::Error error
, net::StreamSocket
* socket
));
42 ACTION_P(QuitThreadOnCounter
, counter
) {
44 EXPECT_GE(*counter
, 0);
46 base::MessageLoop::current()->Quit();
51 class SslHmacChannelAuthenticatorTest
: public testing::Test
{
53 SslHmacChannelAuthenticatorTest() {}
54 virtual ~SslHmacChannelAuthenticatorTest() {}
57 virtual void SetUp() OVERRIDE
{
58 base::FilePath
certs_dir(net::GetTestCertsDirectory());
60 base::FilePath cert_path
= certs_dir
.AppendASCII("unittest.selfsigned.der");
61 ASSERT_TRUE(base::ReadFileToString(cert_path
, &host_cert_
));
63 base::FilePath key_path
= certs_dir
.AppendASCII("unittest.key.bin");
64 std::string key_string
;
65 ASSERT_TRUE(base::ReadFileToString(key_path
, &key_string
));
66 std::string key_base64
;
67 base::Base64Encode(key_string
, &key_base64
);
68 key_pair_
= RsaKeyPair::FromString(key_base64
);
69 ASSERT_TRUE(key_pair_
.get());
72 void RunChannelAuth(bool expected_fail
) {
73 client_fake_socket_
.reset(new FakeSocket());
74 host_fake_socket_
.reset(new FakeSocket());
75 client_fake_socket_
->PairWith(host_fake_socket_
.get());
77 client_auth_
->SecureAndAuthenticate(
78 client_fake_socket_
.PassAs
<net::StreamSocket
>(),
79 base::Bind(&SslHmacChannelAuthenticatorTest::OnClientConnected
,
80 base::Unretained(this)));
82 host_auth_
->SecureAndAuthenticate(
83 host_fake_socket_
.PassAs
<net::StreamSocket
>(),
84 base::Bind(&SslHmacChannelAuthenticatorTest::OnHostConnected
,
85 base::Unretained(this)));
87 // Expect two callbacks to be called - the client callback and the host
89 int callback_counter
= 2;
92 EXPECT_CALL(client_callback_
, OnDone(net::ERR_FAILED
, NULL
))
93 .WillOnce(QuitThreadOnCounter(&callback_counter
));
94 EXPECT_CALL(host_callback_
, OnDone(net::ERR_FAILED
, NULL
))
95 .WillOnce(QuitThreadOnCounter(&callback_counter
));
97 EXPECT_CALL(client_callback_
, OnDone(net::OK
, NotNull()))
98 .WillOnce(QuitThreadOnCounter(&callback_counter
));
99 EXPECT_CALL(host_callback_
, OnDone(net::OK
, NotNull()))
100 .WillOnce(QuitThreadOnCounter(&callback_counter
));
103 // Ensure that .Run() does not run unbounded if the callbacks are never
105 base::Timer
shutdown_timer(false, false);
106 shutdown_timer
.Start(FROM_HERE
,
107 TestTimeouts::action_timeout(),
108 base::MessageLoop::QuitClosure());
112 void OnHostConnected(net::Error error
,
113 scoped_ptr
<net::StreamSocket
> socket
) {
114 host_callback_
.OnDone(error
, socket
.get());
115 host_socket_
= socket
.Pass();
118 void OnClientConnected(net::Error error
,
119 scoped_ptr
<net::StreamSocket
> socket
) {
120 client_callback_
.OnDone(error
, socket
.get());
121 client_socket_
= socket
.Pass();
124 base::MessageLoop message_loop_
;
126 scoped_refptr
<RsaKeyPair
> key_pair_
;
127 std::string host_cert_
;
128 scoped_ptr
<FakeSocket
> client_fake_socket_
;
129 scoped_ptr
<FakeSocket
> host_fake_socket_
;
130 scoped_ptr
<ChannelAuthenticator
> client_auth_
;
131 scoped_ptr
<ChannelAuthenticator
> host_auth_
;
132 MockChannelDoneCallback client_callback_
;
133 MockChannelDoneCallback host_callback_
;
134 scoped_ptr
<net::StreamSocket
> client_socket_
;
135 scoped_ptr
<net::StreamSocket
> host_socket_
;
137 DISALLOW_COPY_AND_ASSIGN(SslHmacChannelAuthenticatorTest
);
140 // Verify that a channel can be connected using a valid shared secret.
// Authenticates a client and a host that share kTestSharedSecret, checks that
// both ends received a socket, and then checks the resulting socket pair with
// StreamConnectionTester.
141 TEST_F(SslHmacChannelAuthenticatorTest
, SuccessfulAuth
) {
// Both authenticators are created with the same secret, so
// RunChannelAuth(false) expects net::OK and a socket on each end.
142 client_auth_
= SslHmacChannelAuthenticator::CreateForClient(
143 host_cert_
, kTestSharedSecret
);
144 host_auth_
= SslHmacChannelAuthenticator::CreateForHost(
145 host_cert_
, key_pair_
, kTestSharedSecret
);
147 RunChannelAuth(false);
// Both completion callbacks must have handed over a socket.
149 ASSERT_TRUE(client_socket_
.get() != NULL
);
150 ASSERT_TRUE(host_socket_
.get() != NULL
);
// NOTE(review): the remaining StreamConnectionTester constructor arguments,
// and any tester.Start() / message-loop run before CheckResults(), are not
// visible in this view. Confirm against the full file.
152 StreamConnectionTester
tester(host_socket_
.get(), client_socket_
.get(),
157 tester
.CheckResults();
160 // Verify that channels cannot be connected using an invalid shared secret.
161 TEST_F(SslHmacChannelAuthenticatorTest
, InvalidChannelSecret
) {
162 client_auth_
= SslHmacChannelAuthenticator::CreateForClient(
163 host_cert_
, kTestSharedSecretBad
);
164 host_auth_
= SslHmacChannelAuthenticator::CreateForHost(
165 host_cert_
, key_pair_
, kTestSharedSecret
);
167 RunChannelAuth(true);
169 ASSERT_TRUE(host_socket_
.get() == NULL
);
172 } // namespace protocol
173 } // namespace remoting