1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/authenticator_test_base.h"
7 #include "base/base64.h"
8 #include "base/files/file_path.h"
9 #include "base/files/file_util.h"
10 #include "base/test/test_timeouts.h"
11 #include "base/timer/timer.h"
12 #include "net/base/net_errors.h"
13 #include "net/base/test_data_directory.h"
14 #include "remoting/base/rsa_key_pair.h"
15 #include "remoting/protocol/authenticator.h"
16 #include "remoting/protocol/channel_authenticator.h"
17 #include "remoting/protocol/fake_stream_socket.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 #include "third_party/webrtc/libjingle/xmllite/xmlelement.h"
22 using testing::SaveArg
;
29 ACTION_P(QuitThreadOnCounter
, counter
) {
31 EXPECT_GE(*counter
, 0);
33 base::MessageLoop::current()->Quit();
38 AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {}
40 AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {}
42 AuthenticatorTestBase::AuthenticatorTestBase() {}
44 AuthenticatorTestBase::~AuthenticatorTestBase() {}
46 void AuthenticatorTestBase::SetUp() {
47 base::FilePath
certs_dir(net::GetTestCertsDirectory());
49 base::FilePath cert_path
= certs_dir
.AppendASCII("unittest.selfsigned.der");
50 ASSERT_TRUE(base::ReadFileToString(cert_path
, &host_cert_
));
52 base::FilePath key_path
= certs_dir
.AppendASCII("unittest.key.bin");
53 std::string key_string
;
54 ASSERT_TRUE(base::ReadFileToString(key_path
, &key_string
));
55 std::string key_base64
;
56 base::Base64Encode(key_string
, &key_base64
);
57 key_pair_
= RsaKeyPair::FromString(key_base64
);
58 ASSERT_TRUE(key_pair_
.get());
59 host_public_key_
= key_pair_
->GetPublicKey();
62 void AuthenticatorTestBase::RunAuthExchange() {
63 ContinueAuthExchangeWith(client_
.get(),
69 void AuthenticatorTestBase::RunHostInitiatedAuthExchange() {
70 ContinueAuthExchangeWith(host_
.get(),
77 // This function sends a message from the sender and receiver and recursively
78 // calls itself to the send the next message from the receiver to the sender
79 // untils the authentication completes.
80 void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator
* sender
,
81 Authenticator
* receiver
,
83 bool receiver_started
) {
84 scoped_ptr
<buzz::XmlElement
> message
;
85 ASSERT_NE(Authenticator::WAITING_MESSAGE
, sender
->state());
86 if (sender
->state() == Authenticator::ACCEPTED
||
87 sender
->state() == Authenticator::REJECTED
)
90 // Verify that once the started flag for either party is set to true,
91 // it should always stay true.
92 if (receiver_started
) {
93 ASSERT_TRUE(receiver
->started());
97 ASSERT_TRUE(sender
->started());
100 ASSERT_EQ(Authenticator::MESSAGE_READY
, sender
->state());
101 message
= sender
->GetNextMessage();
102 ASSERT_TRUE(message
.get());
103 ASSERT_NE(Authenticator::MESSAGE_READY
, sender
->state());
105 ASSERT_EQ(Authenticator::WAITING_MESSAGE
, receiver
->state());
106 receiver
->ProcessMessage(message
.get(), base::Bind(
107 &AuthenticatorTestBase::ContinueAuthExchangeWith
,
108 base::Unretained(receiver
), base::Unretained(sender
),
109 receiver
->started(), sender
->started()));
112 void AuthenticatorTestBase::RunChannelAuth(bool expected_fail
) {
113 client_fake_socket_
.reset(new FakeStreamSocket());
114 host_fake_socket_
.reset(new FakeStreamSocket());
115 client_fake_socket_
->PairWith(host_fake_socket_
.get());
117 client_auth_
->SecureAndAuthenticate(
118 client_fake_socket_
.Pass(),
119 base::Bind(&AuthenticatorTestBase::OnClientConnected
,
120 base::Unretained(this)));
122 host_auth_
->SecureAndAuthenticate(
123 host_fake_socket_
.Pass(),
124 base::Bind(&AuthenticatorTestBase::OnHostConnected
,
125 base::Unretained(this)));
127 // Expect two callbacks to be called - the client callback and the host
129 int callback_counter
= 2;
131 EXPECT_CALL(client_callback_
, OnDone(net::OK
))
132 .WillOnce(QuitThreadOnCounter(&callback_counter
));
134 EXPECT_CALL(host_callback_
, OnDone(net::ERR_FAILED
))
135 .WillOnce(QuitThreadOnCounter(&callback_counter
));
137 EXPECT_CALL(host_callback_
, OnDone(net::OK
))
138 .WillOnce(QuitThreadOnCounter(&callback_counter
));
141 // Ensure that .Run() does not run unbounded if the callbacks are never
143 base::Timer
shutdown_timer(false, false);
144 shutdown_timer
.Start(FROM_HERE
,
145 TestTimeouts::action_timeout(),
146 base::MessageLoop::QuitClosure());
148 shutdown_timer
.Stop();
150 testing::Mock::VerifyAndClearExpectations(&client_callback_
);
151 testing::Mock::VerifyAndClearExpectations(&host_callback_
);
153 if (!expected_fail
) {
154 ASSERT_TRUE(client_socket_
.get() != nullptr);
155 ASSERT_TRUE(host_socket_
.get() != nullptr);
159 void AuthenticatorTestBase::OnHostConnected(
161 scoped_ptr
<net::StreamSocket
> socket
) {
162 host_callback_
.OnDone(error
);
163 host_socket_
= socket
.Pass();
166 void AuthenticatorTestBase::OnClientConnected(
168 scoped_ptr
<net::StreamSocket
> socket
) {
169 client_callback_
.OnDone(error
);
170 client_socket_
= socket
.Pass();
173 } // namespace protocol
174 } // namespace remoting