Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / remoting / host / it2me / it2me_native_messaging_host_unittest.cc
blob: fa0b088aa88dd8bb4046e90ddd84c93b29b088d1
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/host/it2me/it2me_native_messaging_host.h"
7 #include "base/basictypes.h"
8 #include "base/compiler_specific.h"
9 #include "base/json/json_reader.h"
10 #include "base/json/json_writer.h"
11 #include "base/message_loop/message_loop.h"
12 #include "base/run_loop.h"
13 #include "base/stl_util.h"
14 #include "base/strings/stringize_macros.h"
15 #include "base/values.h"
16 #include "net/base/file_stream.h"
17 #include "net/base/net_util.h"
18 #include "remoting/base/auto_thread_task_runner.h"
19 #include "remoting/host/chromoting_host_context.h"
20 #include "remoting/host/native_messaging/native_messaging_pipe.h"
21 #include "remoting/host/native_messaging/pipe_messaging_channel.h"
22 #include "remoting/host/policy_watcher.h"
23 #include "remoting/host/setup/test_util.h"
24 #include "testing/gtest/include/gtest/gtest.h"
26 namespace remoting {
28 namespace {
// Canned values reported by MockIt2MeHost::Connect() and checked by
// VerifyConnectResponses().
const int kTestAccessCodeLifetimeInSeconds = 666;
const char kTestAccessCode[] = "888888";
const char kTestClientUsername[] = "some_user@gmail.com";
34 void VerifyId(scoped_ptr<base::DictionaryValue> response, int expected_value) {
35 ASSERT_TRUE(response);
37 int value;
38 EXPECT_TRUE(response->GetInteger("id", &value));
39 EXPECT_EQ(expected_value, value);
42 void VerifyStringProperty(scoped_ptr<base::DictionaryValue> response,
43 const std::string& name,
44 const std::string& expected_value) {
45 ASSERT_TRUE(response);
47 std::string value;
48 EXPECT_TRUE(response->GetString(name, &value));
49 EXPECT_EQ(expected_value, value);
52 // Verity the values of the "type" and "id" properties
53 void VerifyCommonProperties(scoped_ptr<base::DictionaryValue> response,
54 const std::string& type,
55 int id) {
56 ASSERT_TRUE(response);
58 std::string string_value;
59 EXPECT_TRUE(response->GetString("type", &string_value));
60 EXPECT_EQ(type, string_value);
62 int int_value;
63 EXPECT_TRUE(response->GetInteger("id", &int_value));
64 EXPECT_EQ(id, int_value);
67 } // namespace
69 class MockIt2MeHost : public It2MeHost {
70 public:
71 MockIt2MeHost(scoped_ptr<ChromotingHostContext> context,
72 scoped_ptr<PolicyWatcher> policy_watcher,
73 base::WeakPtr<It2MeHost::Observer> observer,
74 const XmppSignalStrategy::XmppServerConfig& xmpp_server_config,
75 const std::string& directory_bot_jid)
76 : It2MeHost(context.Pass(),
77 policy_watcher.Pass(),
78 nullptr,
79 observer,
80 xmpp_server_config,
81 directory_bot_jid) {}
83 // It2MeHost overrides
84 void Connect() override;
85 void Disconnect() override;
86 void RequestNatPolicy() override;
88 private:
89 ~MockIt2MeHost() override {}
91 void RunSetState(It2MeHostState state);
93 DISALLOW_COPY_AND_ASSIGN(MockIt2MeHost);
96 void MockIt2MeHost::Connect() {
97 if (!host_context()->ui_task_runner()->BelongsToCurrentThread()) {
98 DCHECK(task_runner()->BelongsToCurrentThread());
99 host_context()->ui_task_runner()->PostTask(
100 FROM_HERE, base::Bind(&MockIt2MeHost::Connect, this));
101 return;
104 RunSetState(kStarting);
105 RunSetState(kRequestedAccessCode);
107 std::string access_code(kTestAccessCode);
108 base::TimeDelta lifetime =
109 base::TimeDelta::FromSeconds(kTestAccessCodeLifetimeInSeconds);
110 task_runner()->PostTask(FROM_HERE,
111 base::Bind(&It2MeHost::Observer::OnStoreAccessCode,
112 observer(),
113 access_code,
114 lifetime));
116 RunSetState(kReceivedAccessCode);
118 std::string client_username(kTestClientUsername);
119 task_runner()->PostTask(
120 FROM_HERE,
121 base::Bind(&It2MeHost::Observer::OnClientAuthenticated,
122 observer(),
123 client_username));
125 RunSetState(kConnected);
128 void MockIt2MeHost::Disconnect() {
129 if (!host_context()->network_task_runner()->BelongsToCurrentThread()) {
130 DCHECK(task_runner()->BelongsToCurrentThread());
131 host_context()->network_task_runner()->PostTask(
132 FROM_HERE, base::Bind(&MockIt2MeHost::Disconnect, this));
133 return;
136 RunSetState(kDisconnecting);
137 RunSetState(kDisconnected);
140 void MockIt2MeHost::RequestNatPolicy() {}
142 void MockIt2MeHost::RunSetState(It2MeHostState state) {
143 if (!host_context()->network_task_runner()->BelongsToCurrentThread()) {
144 host_context()->network_task_runner()->PostTask(
145 FROM_HERE, base::Bind(&It2MeHost::SetStateForTesting, this, state, ""));
146 } else {
147 SetStateForTesting(state, "");
151 class MockIt2MeHostFactory : public It2MeHostFactory {
152 public:
153 MockIt2MeHostFactory() : It2MeHostFactory() {}
154 scoped_refptr<It2MeHost> CreateIt2MeHost(
155 scoped_ptr<ChromotingHostContext> context,
156 base::WeakPtr<It2MeHost::Observer> observer,
157 const XmppSignalStrategy::XmppServerConfig& xmpp_server_config,
158 const std::string& directory_bot_jid) override {
159 return new MockIt2MeHost(context.Pass(), nullptr, observer,
160 xmpp_server_config, directory_bot_jid);
163 private:
164 DISALLOW_COPY_AND_ASSIGN(MockIt2MeHostFactory);
165 }; // MockIt2MeHostFactory
167 class It2MeNativeMessagingHostTest : public testing::Test {
168 public:
169 It2MeNativeMessagingHostTest() {}
170 ~It2MeNativeMessagingHostTest() override {}
172 void SetUp() override;
173 void TearDown() override;
175 protected:
176 scoped_ptr<base::DictionaryValue> ReadMessageFromOutputPipe();
177 void WriteMessageToInputPipe(const base::Value& message);
179 void VerifyHelloResponse(int request_id);
180 void VerifyErrorResponse();
181 void VerifyConnectResponses(int request_id);
182 void VerifyDisconnectResponses(int request_id);
184 // The Host process should shut down when it receives a malformed request.
185 // This is tested by sending a known-good request, followed by |message|,
186 // followed by the known-good request again. The response file should only
187 // contain a single response from the first good request.
188 void TestBadRequest(const base::Value& message, bool expect_error_response);
189 void TestConnect();
191 private:
192 void StartHost();
193 void ExitTest();
195 // Each test creates two unidirectional pipes: "input" and "output".
196 // It2MeNativeMessagingHost reads from input_read_file and writes to
197 // output_write_file. The unittest supplies data to input_write_handle, and
198 // verifies output from output_read_handle.
200 // unittest -> [input] -> It2MeNativeMessagingHost -> [output] -> unittest
201 base::File input_write_file_;
202 base::File output_read_file_;
204 // Message loop of the test thread.
205 scoped_ptr<base::MessageLoop> test_message_loop_;
206 scoped_ptr<base::RunLoop> test_run_loop_;
208 scoped_ptr<base::Thread> host_thread_;
209 scoped_ptr<base::RunLoop> host_run_loop_;
211 // Task runner of the host thread.
212 scoped_refptr<AutoThreadTaskRunner> host_task_runner_;
213 scoped_ptr<remoting::NativeMessagingPipe> pipe_;
215 DISALLOW_COPY_AND_ASSIGN(It2MeNativeMessagingHostTest);
218 void It2MeNativeMessagingHostTest::SetUp() {
219 test_message_loop_.reset(new base::MessageLoop());
220 test_run_loop_.reset(new base::RunLoop());
222 // Run the host on a dedicated thread.
223 host_thread_.reset(new base::Thread("host_thread"));
224 host_thread_->Start();
226 host_task_runner_ = new AutoThreadTaskRunner(
227 host_thread_->task_runner(),
228 base::Bind(&It2MeNativeMessagingHostTest::ExitTest,
229 base::Unretained(this)));
231 host_task_runner_->PostTask(
232 FROM_HERE,
233 base::Bind(&It2MeNativeMessagingHostTest::StartHost,
234 base::Unretained(this)));
236 // Wait until the host finishes starting.
237 test_run_loop_->Run();
240 void It2MeNativeMessagingHostTest::TearDown() {
241 // Release reference to AutoThreadTaskRunner, so the host thread can be shut
242 // down.
243 host_task_runner_ = nullptr;
245 // Closing the write-end of the input will send an EOF to the native
246 // messaging reader. This will trigger a host shutdown.
247 input_write_file_.Close();
249 // Start a new RunLoop and Wait until the host finishes shutting down.
250 test_run_loop_.reset(new base::RunLoop());
251 test_run_loop_->Run();
253 // Verify there are no more message in the output pipe.
254 scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
255 EXPECT_FALSE(response);
257 // The It2MeNativeMessagingHost dtor closes the handles that are passed to it.
258 // So the only handle left to close is |output_read_file_|.
259 output_read_file_.Close();
262 scoped_ptr<base::DictionaryValue>
263 It2MeNativeMessagingHostTest::ReadMessageFromOutputPipe() {
264 uint32 length;
265 int read_result = output_read_file_.ReadAtCurrentPos(
266 reinterpret_cast<char*>(&length), sizeof(length));
267 if (read_result != sizeof(length)) {
268 // The output pipe has been closed, return an empty message.
269 return nullptr;
272 std::string message_json(length, '\0');
273 read_result = output_read_file_.ReadAtCurrentPos(
274 string_as_array(&message_json), length);
275 if (read_result != static_cast<int>(length)) {
276 LOG(ERROR) << "Message size (" << read_result
277 << ") doesn't match the header (" << length << ").";
278 return nullptr;
281 scoped_ptr<base::Value> message = base::JSONReader::Read(message_json);
282 if (!message || !message->IsType(base::Value::TYPE_DICTIONARY)) {
283 LOG(ERROR) << "Malformed message:" << message_json;
284 return nullptr;
287 return make_scoped_ptr(
288 static_cast<base::DictionaryValue*>(message.release()));
291 void It2MeNativeMessagingHostTest::WriteMessageToInputPipe(
292 const base::Value& message) {
293 std::string message_json;
294 base::JSONWriter::Write(message, &message_json);
296 uint32 length = message_json.length();
297 input_write_file_.WriteAtCurrentPos(reinterpret_cast<char*>(&length),
298 sizeof(length));
299 input_write_file_.WriteAtCurrentPos(message_json.data(), length);
302 void It2MeNativeMessagingHostTest::VerifyHelloResponse(int request_id) {
303 scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
304 VerifyCommonProperties(response.Pass(), "helloResponse", request_id);
307 void It2MeNativeMessagingHostTest::VerifyErrorResponse() {
308 scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
309 VerifyStringProperty(response.Pass(), "type", "error");
312 void It2MeNativeMessagingHostTest::VerifyConnectResponses(int request_id) {
313 bool connect_response_received = false;
314 bool starting_received = false;
315 bool requestedAccessCode_received = false;
316 bool receivedAccessCode_received = false;
317 bool connected_received = false;
319 // We expect a total of 5 messages: 1 connectResponse and 4 hostStateChanged.
320 for (int i = 0; i < 5; ++i) {
321 scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
322 ASSERT_TRUE(response);
324 std::string type;
325 ASSERT_TRUE(response->GetString("type", &type));
327 if (type == "connectResponse") {
328 EXPECT_FALSE(connect_response_received);
329 connect_response_received = true;
330 VerifyId(response.Pass(), request_id);
331 } else if (type == "hostStateChanged") {
332 std::string state;
333 ASSERT_TRUE(response->GetString("state", &state));
335 std::string value;
336 if (state == It2MeNativeMessagingHost::HostStateToString(kStarting)) {
337 EXPECT_FALSE(starting_received);
338 starting_received = true;
339 } else if (state == It2MeNativeMessagingHost::HostStateToString(
340 kRequestedAccessCode)) {
341 EXPECT_FALSE(requestedAccessCode_received);
342 requestedAccessCode_received = true;
343 } else if (state == It2MeNativeMessagingHost::HostStateToString(
344 kReceivedAccessCode)) {
345 EXPECT_FALSE(receivedAccessCode_received);
346 receivedAccessCode_received = true;
348 EXPECT_TRUE(response->GetString("accessCode", &value));
349 EXPECT_EQ(kTestAccessCode, value);
351 int accessCodeLifetime;
352 EXPECT_TRUE(
353 response->GetInteger("accessCodeLifetime", &accessCodeLifetime));
354 EXPECT_EQ(kTestAccessCodeLifetimeInSeconds, accessCodeLifetime);
355 } else if (state ==
356 It2MeNativeMessagingHost::HostStateToString(kConnected)) {
357 EXPECT_FALSE(connected_received);
358 connected_received = true;
360 EXPECT_TRUE(response->GetString("client", &value));
361 EXPECT_EQ(kTestClientUsername, value);
362 } else {
363 ADD_FAILURE() << "Unexpected host state: " << state;
365 } else {
366 ADD_FAILURE() << "Unexpected message type: " << type;
371 void It2MeNativeMessagingHostTest::VerifyDisconnectResponses(int request_id) {
372 bool disconnect_response_received = false;
373 bool disconnecting_received = false;
374 bool disconnected_received = false;
376 // We expect a total of 3 messages: 1 connectResponse and 2 hostStateChanged.
377 for (int i = 0; i < 3; ++i) {
378 scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
379 ASSERT_TRUE(response);
381 std::string type;
382 ASSERT_TRUE(response->GetString("type", &type));
384 if (type == "disconnectResponse") {
385 EXPECT_FALSE(disconnect_response_received);
386 disconnect_response_received = true;
387 VerifyId(response.Pass(), request_id);
388 } else if (type == "hostStateChanged") {
389 std::string state;
390 ASSERT_TRUE(response->GetString("state", &state));
391 if (state ==
392 It2MeNativeMessagingHost::HostStateToString(kDisconnecting)) {
393 EXPECT_FALSE(disconnecting_received);
394 disconnecting_received = true;
395 } else if (state ==
396 It2MeNativeMessagingHost::HostStateToString(kDisconnected)) {
397 EXPECT_FALSE(disconnected_received);
398 disconnected_received = true;
399 } else {
400 ADD_FAILURE() << "Unexpected host state: " << state;
402 } else {
403 ADD_FAILURE() << "Unexpected message type: " << type;
408 void It2MeNativeMessagingHostTest::TestBadRequest(const base::Value& message,
409 bool expect_error_response) {
410 base::DictionaryValue good_message;
411 good_message.SetString("type", "hello");
412 good_message.SetInteger("id", 1);
414 WriteMessageToInputPipe(good_message);
415 WriteMessageToInputPipe(message);
416 WriteMessageToInputPipe(good_message);
418 VerifyHelloResponse(1);
420 if (expect_error_response)
421 VerifyErrorResponse();
423 scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
424 EXPECT_FALSE(response);
427 void It2MeNativeMessagingHostTest::StartHost() {
428 DCHECK(host_task_runner_->RunsTasksOnCurrentThread());
430 base::File input_read_file;
431 base::File output_write_file;
433 ASSERT_TRUE(MakePipe(&input_read_file, &input_write_file_));
434 ASSERT_TRUE(MakePipe(&output_read_file_, &output_write_file));
436 pipe_.reset(new NativeMessagingPipe());
438 scoped_ptr<extensions::NativeMessagingChannel> channel(
439 new PipeMessagingChannel(input_read_file.Pass(),
440 output_write_file.Pass()));
442 // Creating a native messaging host with a mock It2MeHostFactory.
443 scoped_ptr<extensions::NativeMessageHost> it2me_host(
444 new It2MeNativeMessagingHost(
445 ChromotingHostContext::Create(host_task_runner_),
446 make_scoped_ptr(new MockIt2MeHostFactory())));
447 it2me_host->Start(pipe_.get());
449 pipe_->Start(it2me_host.Pass(), channel.Pass());
451 // Notify the test that the host has finished starting up.
452 test_message_loop_->task_runner()->PostTask(
453 FROM_HERE, test_run_loop_->QuitClosure());
456 void It2MeNativeMessagingHostTest::ExitTest() {
457 if (!test_message_loop_->task_runner()->RunsTasksOnCurrentThread()) {
458 test_message_loop_->task_runner()->PostTask(
459 FROM_HERE,
460 base::Bind(&It2MeNativeMessagingHostTest::ExitTest,
461 base::Unretained(this)));
462 return;
464 test_run_loop_->Quit();
467 void It2MeNativeMessagingHostTest::TestConnect() {
468 base::DictionaryValue connect_message;
469 int next_id = 0;
471 // Send the "connect" request.
472 connect_message.SetInteger("id", ++next_id);
473 connect_message.SetString("type", "connect");
474 connect_message.SetString("xmppServerAddress", "talk.google.com:5222");
475 connect_message.SetBoolean("xmppServerUseTls", true);
476 connect_message.SetString("directoryBotJid", "remoting@bot.talk.google.com");
477 connect_message.SetString("userName", "chromo.pyauto@gmail.com");
478 connect_message.SetString("authServiceWithToken", "oauth2:sometoken");
479 WriteMessageToInputPipe(connect_message);
481 VerifyConnectResponses(next_id);
483 base::DictionaryValue disconnect_message;
484 disconnect_message.SetInteger("id", ++next_id);
485 disconnect_message.SetString("type", "disconnect");
486 WriteMessageToInputPipe(disconnect_message);
488 VerifyDisconnectResponses(next_id);
491 // Test hello request.
492 TEST_F(It2MeNativeMessagingHostTest, Hello) {
493 int next_id = 0;
494 base::DictionaryValue message;
495 message.SetInteger("id", ++next_id);
496 message.SetString("type", "hello");
497 WriteMessageToInputPipe(message);
499 VerifyHelloResponse(next_id);
502 // Verify that response ID matches request ID.
503 TEST_F(It2MeNativeMessagingHostTest, Id) {
504 base::DictionaryValue message;
505 message.SetString("type", "hello");
506 WriteMessageToInputPipe(message);
507 message.SetString("id", "42");
508 WriteMessageToInputPipe(message);
510 scoped_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe();
511 EXPECT_TRUE(response);
512 std::string value;
513 EXPECT_FALSE(response->GetString("id", &value));
515 response = ReadMessageFromOutputPipe();
516 EXPECT_TRUE(response);
517 EXPECT_TRUE(response->GetString("id", &value));
518 EXPECT_EQ("42", value);
521 TEST_F(It2MeNativeMessagingHostTest, Connect) {
522 // A new It2MeHost instance is created for every it2me session. The native
523 // messaging host, on the other hand, is long lived. This test verifies
524 // multiple It2Me host startup and shutdowns.
525 for (int i = 0; i < 3; ++i)
526 TestConnect();
529 // Verify non-Dictionary requests are rejected.
530 TEST_F(It2MeNativeMessagingHostTest, WrongFormat) {
531 base::ListValue message;
532 // No "error" response will be sent for non-Dictionary messages.
533 TestBadRequest(message, false);
536 // Verify requests with no type are rejected.
537 TEST_F(It2MeNativeMessagingHostTest, MissingType) {
538 base::DictionaryValue message;
539 TestBadRequest(message, true);
542 // Verify rejection if type is unrecognized.
543 TEST_F(It2MeNativeMessagingHostTest, InvalidType) {
544 base::DictionaryValue message;
545 message.SetString("type", "xxx");
546 TestBadRequest(message, true);
549 } // namespace remoting