1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/fake_authenticator.h"
7 #include "base/base64.h"
8 #include "base/callback_helpers.h"
9 #include "base/message_loop/message_loop.h"
10 #include "base/rand_util.h"
11 #include "base/strings/string_number_conversions.h"
12 #include "net/base/io_buffer.h"
13 #include "net/base/net_errors.h"
14 #include "remoting/base/constants.h"
15 #include "remoting/protocol/p2p_stream_socket.h"
16 #include "testing/gtest/include/gtest/gtest.h"
17 #include "third_party/webrtc/libjingle/xmllite/xmlelement.h"
// Constructs a fake channel authenticator for tests.
// |accept| selects the result stored in |result_|: net::OK when true,
// net::ERR_FAILED otherwise. That value is later passed to the done callback
// by CallDoneCallback().
// |async| is accepted here; how it is stored/used is not visible in this view.
// |did_read_bytes_| and |did_write_bytes_| start out false.
// NOTE(review): the remainder of the initializer list and the constructor
// body (including its braces) are not visible in this view — confirm against
// the full file.
22 FakeChannelAuthenticator::FakeChannelAuthenticator(bool accept
, bool async
)
23 : result_(accept
? net::OK
: net::ERR_FAILED
),
25 did_read_bytes_(false),
26 did_write_bytes_(false),
// Destructor.
// NOTE(review): the body/closing brace is not visible in this view.
30 FakeChannelAuthenticator::~FakeChannelAuthenticator() {
// Fake handshake over |socket|.
// Takes ownership of |socket| (moved into |socket_|) and stores
// |done_callback| in |done_callback_|. When the fake is going to reject, no
// byte is written and |did_write_bytes_| is set immediately. Otherwise a
// single zero byte is written. A one-byte read is also issued. Write and read
// completions are routed to OnAuthBytesWritten() / OnAuthBytesRead(). They
// are called directly when the socket call does not return
// net::ERR_IO_PENDING; otherwise they are delivered through callbacks bound to
// |weak_factory_| weak pointers (presumably so they are dropped if this object
// is destroyed first — confirm with base::Bind/WeakPtr semantics).
33 void FakeChannelAuthenticator::SecureAndAuthenticate(
34 scoped_ptr
<P2PStreamSocket
> socket
,
35 const DoneCallback
& done_callback
) {
36 socket_
= socket
.Pass();
39 done_callback_
= done_callback
;
41 if (result_
!= net::OK
) {
42 // Don't write anything if we are going to reject auth to make test
43 // ordering deterministic.
44 did_write_bytes_
= true;
// NOTE(review): as shown, the write below is still inside the
// `result_ != net::OK` branch, which contradicts the comment above; an
// `} else {` appears to be missing from this view — confirm.
46 scoped_refptr
<net::IOBuffer
> write_buf
= new net::IOBuffer(1);
47 write_buf
->data()[0] = 0;
48 int result
= socket_
->Write(
50 base::Bind(&FakeChannelAuthenticator::OnAuthBytesWritten
,
51 weak_factory_
.GetWeakPtr()));
52 if (result
!= net::ERR_IO_PENDING
) {
53 // This will not call the callback because |did_read_bytes_| is
54 // still set to false.
55 OnAuthBytesWritten(result
);
// Issue the one-byte read from the peer.
// NOTE(review): the return value of Read() is not captured in this view,
// yet `result` is tested below; a declaration appears to be missing.
// Confirm that the read result (not the earlier write result) is what
// gets checked.
59 scoped_refptr
<net::IOBuffer
> read_buf
= new net::IOBuffer(1);
61 socket_
->Read(read_buf
.get(), 1,
62 base::Bind(&FakeChannelAuthenticator::OnAuthBytesRead
,
63 weak_factory_
.GetWeakPtr()));
64 if (result
!= net::ERR_IO_PENDING
)
65 OnAuthBytesRead(result
);
// Completion handler for the one-byte write issued by SecureAndAuthenticate().
// Expects (gtest EXPECT_FALSE) that it has not run before, then records that
// the write has happened by setting |did_write_bytes_|.
// NOTE(review): |result| is not inspected in the visible lines, and the rest
// of the body is not visible in this view — confirm against the full file.
71 void FakeChannelAuthenticator::OnAuthBytesWritten(int result
) {
73 EXPECT_FALSE(did_write_bytes_
);
74 did_write_bytes_
= true;
// Completion handler for the one-byte read issued by SecureAndAuthenticate().
// Expects (gtest EXPECT_FALSE) that it has not run before, then records that
// the read has happened by setting |did_read_bytes_|.
// NOTE(review): |result| is not inspected in the visible lines, and the rest
// of the body is not visible in this view — confirm against the full file.
79 void FakeChannelAuthenticator::OnAuthBytesRead(int result
) {
81 EXPECT_FALSE(did_read_bytes_
);
82 did_read_bytes_
= true;
// Reports the handshake outcome: runs |done_callback_| with |result_| and
// hands over ownership of |socket_|. base::ResetAndReturn presumably clears
// the stored callback before running it, so it cannot fire twice — confirm.
// NOTE(review): as shown, the `if` directly guards the Run() call, so the
// callback would only run when |result_| != net::OK. A statement between the
// condition and the Run() call appears to be missing from this view — confirm.
87 void FakeChannelAuthenticator::CallDoneCallback() {
88 if (result_
!= net::OK
)
90 base::ResetAndReturn(&done_callback_
).Run(result_
, socket_
.Pass());
// Constructs a fake session authenticator acting as |type| (HOST or CLIENT).
// Visible initializers set |round_trips_| from the |round_trips| argument and
// |messages_till_started_| to 0.
// NOTE(review): most of the parameter list and initializer list is missing
// from this view — confirm against the full file.
93 FakeAuthenticator::FakeAuthenticator(Type type
,
98 round_trips_(round_trips
),
102 messages_till_started_(0) {
// Destructor.
// NOTE(review): the body/closing brace is not visible in this view.
105 FakeAuthenticator::~FakeAuthenticator() {
// Sets the threshold used by started(): started() reports true once the
// message count |messages_| exceeds |messages|.
108 void FakeAuthenticator::set_messages_till_started(int messages
) {
109 messages_till_started_
= messages
;
// Derives the authenticator state from the message count |messages_|
// relative to the total number of messages, |round_trips_| * 2:
//  - expects |messages_| never to exceed round_trips_ * 2;
//  - once every message has been exchanged, the result depends on |action_|
//    (the branch bodies are not visible in this view);
//  - a HOST configured with REJECT has special handling just before its final
//    message (branch body not visible in this view);
//  - otherwise the CLIENT has a message to send on even counts and the HOST
//    on odd counts (MESSAGE_READY); in all other cases WAITING_MESSAGE.
// NOTE(review): several lines of this function are missing from this view.
112 Authenticator::State
FakeAuthenticator::state() const {
113 EXPECT_LE(messages_
, round_trips_
* 2);
114 if (messages_
>= round_trips_
* 2) {
115 if (action_
== REJECT
) {
122 // Don't send the last message if this is a host that wants to
123 // reject a connection.
124 if (messages_
== round_trips_
* 2 - 1 &&
125 type_
== HOST
&& action_
== REJECT
) {
129 // We are not done yet. Process the next message.
130 if ((messages_
% 2 == 0 && type_
== CLIENT
) ||
131 (messages_
% 2 == 1 && type_
== HOST
)) {
132 return MESSAGE_READY
;
134 return WAITING_MESSAGE
;
// Returns true once more than |messages_till_started_| messages have been
// counted in |messages_|.
138 bool FakeAuthenticator::started() const {
139 return messages_
> messages_till_started_
;
// Returns the rejection reason; the fake always reports INVALID_CREDENTIALS.
// Expects to be called only when state() is REJECTED.
142 Authenticator::RejectionReason
FakeAuthenticator::rejection_reason() const {
143 EXPECT_EQ(REJECTED
, state());
144 return INVALID_CREDENTIALS
;
// Consumes an incoming authentication |message| and then runs
// |resume_callback|.
// Expects state() to be WAITING_MESSAGE and the message's "id" child (in
// kChromotingXmlNamespace) to equal the current |messages_| count in decimal.
// On a CLIENT receiving the final message (messages_ == round_trips_ * 2 - 1),
// it reads the base64 "key" child, expects it to be non-empty, and expects it
// to decode into |auth_key_|.
147 void FakeAuthenticator::ProcessMessage(const buzz::XmlElement
* message
,
148 const base::Closure
& resume_callback
) {
149 EXPECT_EQ(WAITING_MESSAGE
, state());
// NOTE(review): the declaration receiving this "id" text is not visible in
// this view, yet `id` is compared below — confirm against the full file.
151 message
->TextNamed(buzz::QName(kChromotingXmlNamespace
, "id"));
152 EXPECT_EQ(id
, base::IntToString(messages_
));
154 // On the client receive the key in the last message.
155 if (type_
== CLIENT
&& messages_
== round_trips_
* 2 - 1) {
156 std::string key_base64
=
157 message
->TextNamed(buzz::QName(kChromotingXmlNamespace
, "key"));
158 EXPECT_TRUE(!key_base64
.empty());
159 EXPECT_TRUE(base::Base64Decode(key_base64
, &auth_key_
));
// NOTE(review): the closing brace of the key branch and any update of
// |messages_| are not visible in this view.
163 resume_callback
.Run();
// Builds the next outgoing message. Expects state() to be MESSAGE_READY.
// The message is an "authentication" element in kChromotingXmlNamespace with
// an "id" child holding the current |messages_| count in decimal.
// On a HOST sending the final message (messages_ == round_trips_ * 2 - 1),
// it also generates a 16-byte random |auth_key_| and attaches it base64-encoded
// as a "key" child. The child elements are allocated with `new` and passed to
// AddElement(); presumably the parent takes ownership — confirm with
// buzz::XmlElement.
166 scoped_ptr
<buzz::XmlElement
> FakeAuthenticator::GetNextMessage() {
167 EXPECT_EQ(MESSAGE_READY
, state());
169 scoped_ptr
<buzz::XmlElement
> result(new buzz::XmlElement(
170 buzz::QName(kChromotingXmlNamespace
, "authentication")));
171 buzz::XmlElement
* id
= new buzz::XmlElement(
172 buzz::QName(kChromotingXmlNamespace
, "id"));
173 id
->AddText(base::IntToString(messages_
));
174 result
->AddElement(id
);
176 // Add authentication key in the last message sent from host to client.
177 if (type_
== HOST
&& messages_
== round_trips_
* 2 - 1) {
178 auth_key_
= base::RandBytesAsString(16);
179 buzz::XmlElement
* key
= new buzz::XmlElement(
180 buzz::QName(kChromotingXmlNamespace
, "key"));
181 std::string key_base64
;
182 base::Base64Encode(auth_key_
, &key_base64
);
183 key
->AddText(key_base64
);
184 result
->AddElement(key
);
// NOTE(review): the closing brace of the key branch and any update of
// |messages_| are not visible in this view.
188 return result
.Pass();
// Returns the session auth key. Expects state() to be ACCEPTED.
// The key is presumably |auth_key_|, which ProcessMessage() fills on a CLIENT
// and GetNextMessage() generates on a HOST.
// NOTE(review): no return statement is visible in this view — confirm
// against the full file.
191 const std::string
& FakeAuthenticator::GetAuthKey() const {
192 EXPECT_EQ(ACCEPTED
, state());
// Creates the per-channel fake authenticator. Expects state() to be ACCEPTED.
// The channel authenticator accepts unless |action_| is REJECT_CHANNEL, and
// receives this authenticator's |async_| flag.
196 scoped_ptr
<ChannelAuthenticator
>
197 FakeAuthenticator::CreateChannelAuthenticator() const {
198 EXPECT_EQ(ACCEPTED
, state());
199 return make_scoped_ptr(
200 new FakeChannelAuthenticator(action_
!= REJECT_CHANNEL
, async_
));
// Stores the settings applied to every host-side FakeAuthenticator this
// factory creates: number of round trips, the started() threshold, the
// configured action, and the async flag.
// NOTE(review): the constructor's closing brace is not visible in this view.
203 FakeHostAuthenticatorFactory::FakeHostAuthenticatorFactory(
204 int round_trips
, int messages_till_started
,
205 FakeAuthenticator::Action action
, bool async
)
206 : round_trips_(round_trips
),
207 messages_till_started_(messages_till_started
),
208 action_(action
), async_(async
) {
// Destructor.
// NOTE(review): the body/closing brace is not visible in this view.
211 FakeHostAuthenticatorFactory::~FakeHostAuthenticatorFactory() {
// Creates a HOST-type FakeAuthenticator configured with this factory's round
// trips, action, async flag, and started() threshold. Ownership is returned
// to the caller via scoped_ptr.
// |local_jid|, |remote_jid| and |first_message| are not used in the visible
// lines.
214 scoped_ptr
<Authenticator
> FakeHostAuthenticatorFactory::CreateAuthenticator(
215 const std::string
& local_jid
,
216 const std::string
& remote_jid
,
217 const buzz::XmlElement
* first_message
) {
218 FakeAuthenticator
* authenticator
= new FakeAuthenticator(
219 FakeAuthenticator::HOST
, round_trips_
, action_
, async_
);
220 authenticator
->set_messages_till_started(messages_till_started_
);
// Transfer ownership of the raw allocation into the returned scoped_ptr.
222 scoped_ptr
<Authenticator
> result(authenticator
);
223 return result
.Pass();
226 } // namespace protocol
227 } // namespace remoting