1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "base/location.h"
6 #include "base/memory/weak_ptr.h"
7 #include "base/run_loop.h"
8 #include "base/single_thread_task_runner.h"
9 #include "base/thread_task_runner_handle.h"
10 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
11 #include "net/base/net_errors.h"
12 #include "net/dns/dns_protocol.h"
13 #include "net/dns/mdns_client_impl.h"
14 #include "net/dns/mock_mdns_socket_factory.h"
15 #include "testing/gmock/include/gmock/gmock.h"
16 #include "testing/gtest/include/gtest/gtest.h"
19 using ::testing::Invoke
;
20 using ::testing::StrictMock
;
21 using ::testing::NiceMock
;
22 using ::testing::Mock
;
23 using ::testing::SaveArg
;
24 using ::testing::SetArgPointee
;
25 using ::testing::Return
;
26 using ::testing::Exactly
;
// NOTE(review): this text looks like a lossy extraction of a C++ unit test:
// every line starts with its original line number, declarations are split
// across lines, and gaps in that numbering show that some original lines
// (blank lines, data bytes, closing braces, access specifiers) are missing,
// e.g. embedded lines 29-31 between the namespace opening and the first array.
// Comments added in this file describe only what is visible; confirm against
// the real source before relying on them.
28 namespace local_discovery
{
// mDNS response with a single PTR answer for "_privet._tcp.local" whose
// RDATA begins with the label "hello".
// NOTE(review): embedded lines 44, 48 and 51-53 are absent. Presumably they
// hold the owner name's root terminator, the low 16 bits of the TTL, the last
// 2 of the 8 RDATA bytes (only 6 are visible) and the closing "};" — confirm.
32 const uint8 kSamplePacketPTR
[] = {
34 0x00, 0x00, // ID is zeroed out
35 0x81, 0x80, // Standard query response, RA, no error
36 0x00, 0x00, // No questions (for simplicity)
37 0x00, 0x01, // 1 RR (answers)
38 0x00, 0x00, // 0 authority RRs
39 0x00, 0x00, // 0 additional RRs
41 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
42 0x04, '_', 't', 'c', 'p',
43 0x05, 'l', 'o', 'c', 'a', 'l',
45 0x00, 0x0c, // TYPE is PTR.
46 0x00, 0x01, // CLASS is IN.
47 0x00, 0x00, // TTL (4 bytes) is 1 second; low 16 bits are not shown.
49 0x00, 0x08, // RDLENGTH is 8 bytes.
50 0x05, 'h', 'e', 'l', 'l', 'o',
// mDNS response with a single SRV answer for "hello._privet._tcp.local"
// pointing at port 8888 on "myhello.local".
// NOTE(review): embedded lines 62, 67, 71, 73-74 and 78-80 are absent.
// RDLENGTH says 21 bytes but only 16 are visible (2-byte port + 8-byte
// "myhello" label + 6-byte "local" label). That fits the 2-byte priority,
// the 2-byte weight and the 1-byte root terminator being among the missing
// lines, along with the TTL's low 16 bits and "};" — confirm.
54 const uint8 kSamplePacketSRV
[] = {
56 0x00, 0x00, // ID is zeroed out
57 0x81, 0x80, // Standard query response, RA, no error
58 0x00, 0x00, // No questions (for simplicity)
59 0x00, 0x01, // 1 RR (answers)
60 0x00, 0x00, // 0 authority RRs
61 0x00, 0x00, // 0 additional RRs
63 0x05, 'h', 'e', 'l', 'l', 'o',
64 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
65 0x04, '_', 't', 'c', 'p',
66 0x05, 'l', 'o', 'c', 'a', 'l',
68 0x00, 0x21, // TYPE is SRV.
69 0x00, 0x01, // CLASS is IN.
70 0x00, 0x00, // TTL (4 bytes) is 1 second; low 16 bits are not shown.
72 0x00, 0x15, // RDLENGTH is 21 bytes.
75 0x22, 0xb8, // port 8888
76 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
77 0x05, 'l', 'o', 'c', 'a', 'l',
// mDNS response with a single TXT answer for "hello._privet._tcp.local"
// whose RDATA is one length-prefixed string, "hello" (6 bytes in total).
// NOTE(review): embedded lines 94 and 98 are absent. Presumably they hold the
// owner name's root terminator and the low 16 bits of the TTL. The closing
// "};" is also absent.
// The TTL comment previously claimed 20 hours, 47 minutes, 48 seconds
// (74868 s). That value cannot be encoded when the upper 16 bits are 0x0000
// as shown here (the maximum would be 65535 s), so the claim was dropped.
// The TYPE and RDLENGTH comments were also corrected to match the bytes.
81 const uint8 kSamplePacketTXT
[] = {
83 0x00, 0x00, // ID is zeroed out
84 0x81, 0x80, // Standard query response, RA, no error
85 0x00, 0x00, // No questions (for simplicity)
86 0x00, 0x01, // 1 RR (answers)
87 0x00, 0x00, // 0 authority RRs
88 0x00, 0x00, // 0 additional RRs
90 0x05, 'h', 'e', 'l', 'l', 'o',
91 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
92 0x04, '_', 't', 'c', 'p',
93 0x05, 'l', 'o', 'c', 'a', 'l',
95 0x00, 0x10, // TYPE is TXT.
96 0x00, 0x01, // CLASS is IN.
97 0x00, 0x00, // TTL (4 bytes); upper 16 bits are 0, low 16 bits not shown.
99 0x00, 0x06, // RDLENGTH is 6 bytes.
100 0x05, 'h', 'e', 'l', 'l', 'o'
// mDNS response with two answers:
//  - the SRV record for "hello._privet._tcp.local" (port 8888, target
//    "myhello.local");
//  - an A record for "myhello.local".
// NOTE(review): embedded lines 116, 120, 122-123, 127-128, 131, 135 and
// 137-140 are absent, so the following are not shown:
//  - the name terminators and the TTL low halves;
//  - the SRV priority/weight;
//  - the 4 A-record RDATA bytes (presumably 1.2.3.4, matching
//    ServiceResolverTest's ip_address_expected_ — confirm);
//  - the closing "};".
103 const uint8 kSamplePacketSRVA
[] = {
105 0x00, 0x00, // ID is zeroed out
106 0x81, 0x80, // Standard query response, RA, no error
107 0x00, 0x00, // No questions (for simplicity)
108 0x00, 0x02, // 2 RR (answers)
109 0x00, 0x00, // 0 authority RRs
110 0x00, 0x00, // 0 additional RRs
112 0x05, 'h', 'e', 'l', 'l', 'o',
113 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
114 0x04, '_', 't', 'c', 'p',
115 0x05, 'l', 'o', 'c', 'a', 'l',
117 0x00, 0x21, // TYPE is SRV.
118 0x00, 0x01, // CLASS is IN.
119 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
121 0x00, 0x15, // RDLENGTH is 21 bytes.
124 0x22, 0xb8, // port 8888
125 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
126 0x05, 'l', 'o', 'c', 'a', 'l',
// Second answer: the A record for the SRV target "myhello.local".
129 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
130 0x05, 'l', 'o', 'c', 'a', 'l',
132 0x00, 0x01, // TYPE is A.
133 0x00, 0x01, // CLASS is IN.
134 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
136 0x00, 0x04, // RDLENGTH is 4 bytes.
// mDNS response with two PTR answers for "_privet._tcp.local". Their RDATA
// begins with the labels "gdbye" and "hello" respectively.
// NOTE(review): embedded lines 153, 157, 160-161, 165, 169 and 172-174 are
// absent. Presumably they hold the name terminators, the TTL low halves, the
// remaining RDATA bytes and "};".
// The TTL comments previously said "1 second". That contradicts the visible
// upper bytes 0x02, 0x00, which make the TTL at least 0x02000000 s (2^25 s).
141 const uint8 kSamplePacketPTR2
[] = {
143 0x00, 0x00, // ID is zeroed out
144 0x81, 0x80, // Standard query response, RA, no error
145 0x00, 0x00, // No questions (for simplicity)
146 0x00, 0x02, // 2 RR (answers)
147 0x00, 0x00, // 0 authority RRs
148 0x00, 0x00, // 0 additional RRs
150 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
151 0x04, '_', 't', 'c', 'p',
152 0x05, 'l', 'o', 'c', 'a', 'l',
154 0x00, 0x0c, // TYPE is PTR.
155 0x00, 0x01, // CLASS is IN.
156 0x02, 0x00, // TTL (4 bytes); upper 16 bits 0x0200, so at least 2^25 s.
158 0x00, 0x08, // RDLENGTH is 8 bytes.
159 0x05, 'g', 'd', 'b', 'y', 'e',
// Second answer: another PTR for "_privet._tcp.local".
162 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
163 0x04, '_', 't', 'c', 'p',
164 0x05, 'l', 'o', 'c', 'a', 'l',
166 0x00, 0x0c, // TYPE is PTR.
167 0x00, 0x01, // CLASS is IN.
168 0x02, 0x00, // TTL (4 bytes); upper 16 bits 0x0200, so at least 2^25 s.
170 0x00, 0x08, // RDLENGTH is 8 bytes.
171 0x05, 'h', 'e', 'l', 'l', 'o',
// mDNS query (no flags, no records) with one question: type SRV, class IN,
// for "hello._privet._tcp.local". The ActivelyRefreshServices test builds its
// expected outgoing packet from these bytes.
// NOTE(review): embedded line 189 (presumably the name's root terminator) and
// the closing "};" are absent.
175 const uint8 kSamplePacketQuerySRV
[] = {
177 0x00, 0x00, // ID is zeroed out
178 0x00, 0x00, // No flags.
179 0x00, 0x01, // One question.
180 0x00, 0x00, // 0 RRs (answers)
181 0x00, 0x00, // 0 authority RRs
182 0x00, 0x00, // 0 additional RRs
185 0x05, 'h', 'e', 'l', 'l', 'o',
186 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
187 0x04, '_', 't', 'c', 'p',
188 0x05, 'l', 'o', 'c', 'a', 'l',
190 0x00, 0x21, // TYPE is SRV.
191 0x00, 0x01, // CLASS is IN.
// gmock stand-in for a ServiceWatcher client. GetCallback() returns a
// ServiceWatcher::UpdatedCallback that forwards every update (update type
// plus a string; the tests pass service names such as
// "hello._privet._tcp.local") to the mocked OnServiceUpdated().
// The callback binds |this| with base::Unretained, so it does not keep the
// mock alive. The tests declare the mock before the watcher, so the watcher
// holding the callback is destroyed first.
// NOTE(review): embedded line 196 (likely "public:") and the closing lines
// 203-205 are absent.
195 class MockServiceWatcherClient
{
197 MOCK_METHOD2(OnServiceUpdated
,
198 void(ServiceWatcher::UpdateType
, const std::string
&));
200 ServiceWatcher::UpdatedCallback
GetCallback() {
201 return base::Bind(&MockServiceWatcherClient::OnServiceUpdated
,
202 base::Unretained(this));
// Fixture that wires a ServiceDiscoveryClientImpl to a net::MDnsClientImpl
// listening on a net::MockMDnsSocketFactory. Tests inject packets with
// socket_factory_.SimulateReceive() and set expectations on outgoing packets
// with OnSendTo().
// NOTE(review): embedded lines 207, 211-212, 214-215, 221, 223-226, 228-229
// and 234-235 are absent. From the visible text, the missing lines include:
//  - the end of RunFor();
//  - the header of the Stop() method that RunFor() binds (line 227 is its
//    body);
//  - access specifiers and closing braces.
206 class ServiceDiscoveryTest
: public ::testing::Test
{
208 ServiceDiscoveryTest()
209 : service_discovery_client_(&mdns_client_
) {
// Route the mDNS client's sockets through the mock factory.
210 mdns_client_
.StartListening(&socket_factory_
);
213 ~ServiceDiscoveryTest() override
{}
// Runs the current message loop until a delayed Stop() fires after
// |time_period|. Stop() is wrapped in a CancelableCallback, which cancels the
// pending task when |callback| is destroyed at the end of this function. A
// leftover Stop() therefore cannot quit a later loop run.
216 void RunFor(base::TimeDelta time_period
) {
217 base::CancelableCallback
<void()> callback(base::Bind(
218 &ServiceDiscoveryTest::Stop
, base::Unretained(this)));
219 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
220 FROM_HERE
, callback
.callback(), time_period
);
222 base::MessageLoop::current()->Run();
// Body of Stop() (its declaration is not visible here): quits the loop run
// started by RunFor().
227 base::MessageLoop::current()->Quit();
// mdns_client_ is declared before service_discovery_client_, so it is
// already constructed when the constructor passes &mdns_client_ to it.
230 net::MockMDnsSocketFactory socket_factory_
;
231 net::MDnsClientImpl mdns_client_
;
232 ServiceDiscoveryClientImpl service_discovery_client_
;
233 base::MessageLoop loop_
;
// A received PTR for "hello._privet._tcp.local" should be reported as
// UPDATE_ADDED. After the loop runs for 2 s, the same service should be
// reported as UPDATE_REMOVED; kSamplePacketPTR's TTL is documented as
// 1 second, so presumably the record has expired by then. StrictMock fails
// the test on any OnServiceUpdated call without a matching expectation.
// NOTE(review): the following embedded lines are absent:
//  - 242-244: may hide setup between creating the watcher and the first
//    EXPECT_CALL — confirm;
//  - 247 and 253: presumably finish the two EXPECT_CALL statements (a
//    cardinality clause and ";");
//  - the closing brace.
236 TEST_F(ServiceDiscoveryTest
, AddRemoveService
) {
237 StrictMock
<MockServiceWatcherClient
> delegate
;
239 scoped_ptr
<ServiceWatcher
> watcher(
240 service_discovery_client_
.CreateServiceWatcher(
241 "_privet._tcp.local", delegate
.GetCallback()));
245 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
246 "hello._privet._tcp.local"))
249 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
251 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED
,
252 "hello._privet._tcp.local"))
255 RunFor(base::TimeDelta::FromSeconds(2));
// DiscoverNewServices(false) should send two packets (OnSendTo expected
// twice). Two more sends are expected while the loop runs for 2 s.
// Presumably there is one send per IPv4/IPv6 socket, as the comment in
// ActivelyRefreshServices suggests — confirm.
// NOTE(review): the gmock wildcard "_" is used below, but no
// "using ::testing::_" is visible (embedded lines 17-18 are absent).
// Embedded lines 264-266 and the closing brace are also absent.
258 TEST_F(ServiceDiscoveryTest
, DiscoverNewServices
) {
259 StrictMock
<MockServiceWatcherClient
> delegate
;
261 scoped_ptr
<ServiceWatcher
> watcher(
262 service_discovery_client_
.CreateServiceWatcher(
263 "_privet._tcp.local", delegate
.GetCallback()));
267 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(2);
269 watcher
->DiscoverNewServices(false);
271 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(2);
273 RunFor(base::TimeDelta::FromSeconds(2));
// A PTR received before the watcher exists should still be reported as
// UPDATE_ADDED for "hello._privet._tcp.local" once pending tasks run
// (RunUntilIdle).
// NOTE(review): embedded lines 284-286 and 289-290 are absent. The
// EXPECT_CALL shown here lacks its terminating clause, and the closing brace
// is absent too.
276 TEST_F(ServiceDiscoveryTest
, ReadCachedServices
) {
277 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
279 StrictMock
<MockServiceWatcherClient
> delegate
;
281 scoped_ptr
<ServiceWatcher
> watcher(
282 service_discovery_client_
.CreateServiceWatcher(
283 "_privet._tcp.local", delegate
.GetCallback()));
287 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
288 "hello._privet._tcp.local"))
291 base::MessageLoop::current()->RunUntilIdle();
// Both services announced by the earlier kSamplePacketPTR2 ("hello" and
// "gdbye" under "_privet._tcp.local") should be reported as UPDATE_ADDED
// once pending tasks run.
// NOTE(review): embedded lines 302-304, 307-308 and 311-312 are absent. The
// two EXPECT_CALL statements lack their terminating clauses, and the closing
// brace is absent too.
295 TEST_F(ServiceDiscoveryTest
, ReadCachedServicesMultiple
) {
296 socket_factory_
.SimulateReceive(kSamplePacketPTR2
, sizeof(kSamplePacketPTR2
));
298 StrictMock
<MockServiceWatcherClient
> delegate
;
299 scoped_ptr
<ServiceWatcher
> watcher
=
300 service_discovery_client_
.CreateServiceWatcher(
301 "_privet._tcp.local", delegate
.GetCallback());
305 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
306 "hello._privet._tcp.local"))
309 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
310 "gdbye._privet._tcp.local"))
313 base::MessageLoop::current()->RunUntilIdle();
// Once "hello._privet._tcp.local" has been added via PTR, receiving its SRV
// and TXT records should be reported as UPDATE_CHANGED.
// NOTE(review): embedded lines 322-324, 327-328 and 335-336 are absent. They
// presumably hold the EXPECT_CALL cardinality clauses, which determine how
// many UPDATE_CHANGED calls are expected. The closing brace is absent too.
317 TEST_F(ServiceDiscoveryTest
, OnServiceChanged
) {
318 StrictMock
<MockServiceWatcherClient
> delegate
;
319 scoped_ptr
<ServiceWatcher
> watcher(
320 service_discovery_client_
.CreateServiceWatcher(
321 "_privet._tcp.local", delegate
.GetCallback()));
325 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
326 "hello._privet._tcp.local"))
329 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
331 base::MessageLoop::current()->RunUntilIdle();
333 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED
,
334 "hello._privet._tcp.local"))
337 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
339 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
341 base::MessageLoop::current()->RunUntilIdle();
// After the PTR add has been processed, SRV and TXT records arriving back to
// back (with no loop run between them) should be reported as UPDATE_CHANGED.
// NOTE(review): the visible statements match OnServiceChanged. What
// distinguishes the two tests is presumably in the absent cardinality clauses
// (embedded lines 354-355 and 363-364) — confirm. The closing brace is absent.
344 TEST_F(ServiceDiscoveryTest
, SinglePacket
) {
345 StrictMock
<MockServiceWatcherClient
> delegate
;
346 scoped_ptr
<ServiceWatcher
> watcher(
347 service_discovery_client_
.CreateServiceWatcher(
348 "_privet._tcp.local", delegate
.GetCallback()));
352 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
353 "hello._privet._tcp.local"))
356 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
358 // Reset the "already updated" flag.
359 base::MessageLoop::current()->RunUntilIdle();
361 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED
,
362 "hello._privet._tcp.local"))
365 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
367 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
369 base::MessageLoop::current()->RunUntilIdle();
// With active refreshing enabled, the watcher should re-issue the SRV query
// for "hello._privet._tcp.local". The expected outgoing bytes are
// kSamplePacketQuerySRV. The sequence is:
//  1. The PTR arrives and is reported as UPDATE_ADDED.
//  2. The SRV arrives.
//  3. While the loop runs for 2 s, four refresh queries are expected (the
//     inline comment attributes them to IPv4 and IPv6 at 85% and 95% of the
//     TTL).
//  4. The service is reported as UPDATE_REMOVED.
// NOTE(review): embedded lines 383-384 and 389-390 are absent, so the first
// two EXPECT_CALL statements lack their terminating clauses; so does line 401.
// The closing brace is absent too. The "(const char*)" C-style cast would read
// better as reinterpret_cast<const char*>.
372 TEST_F(ServiceDiscoveryTest
, ActivelyRefreshServices
) {
373 StrictMock
<MockServiceWatcherClient
> delegate
;
374 scoped_ptr
<ServiceWatcher
> watcher(
375 service_discovery_client_
.CreateServiceWatcher(
376 "_privet._tcp.local", delegate
.GetCallback()));
379 watcher
->SetActivelyRefreshServices(true);
381 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
382 "hello._privet._tcp.local"))
// The refresh query the watcher is expected to send, as raw bytes.
385 std::string query_packet
= std::string((const char*)(kSamplePacketQuerySRV
),
386 sizeof(kSamplePacketQuerySRV
));
388 EXPECT_CALL(socket_factory_
, OnSendTo(query_packet
))
391 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
393 base::MessageLoop::current()->RunUntilIdle();
395 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
397 EXPECT_CALL(socket_factory_
, OnSendTo(query_packet
))
398 .Times(4); // IPv4 and IPv6 at 85% and 95%
400 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED
,
401 "hello._privet._tcp.local"))
404 RunFor(base::TimeDelta::FromSeconds(2));
406 base::MessageLoop::current()->RunUntilIdle();
// Fixture for ServiceResolver tests. It extends ServiceDiscoveryTest with:
//  - a resolver for "hello._privet._tcp.local";
//  - the values the resolver tests expect: metadata {"hello"}, address
//    myhello.local:8888 and IP address bytes 1, 2, 3, 4.
// NOTE(review): embedded lines 411, 419-420, 422-424, 429-430, 437-438, 441,
// 444-445 and 451-452 are absent. In particular:
//  - Nothing visible closes the destructor opened on line 421, so the
//    resolver_ assignment on line 425 reads as destructor code here. It
//    presumably belongs to a separate method (e.g. SetUp) declared on the
//    missing lines 422-424 — confirm. If it really ran in the destructor,
//    resolver_ would still be null when the tests call StartResolving().
//  - MOCK_METHOD4 shows only three parameter types. The missing line 441
//    presumably declares the second parameter, the address string that
//    OnFinishedResolving passes — confirm.
//  - Unlike the base fixture's destructor, this one is not marked override.
410 class ServiceResolverTest
: public ServiceDiscoveryTest
{
412 ServiceResolverTest() {
413 metadata_expected_
.push_back("hello");
414 address_expected_
= net::HostPortPair("myhello.local", 8888);
415 ip_address_expected_
.push_back(1);
416 ip_address_expected_
.push_back(2);
417 ip_address_expected_
.push_back(3);
418 ip_address_expected_
.push_back(4);
421 ~ServiceResolverTest() {
// Creates the resolver under test; results are delivered to
// OnFinishedResolving() via an Unretained binding of this fixture.
425 resolver_
= service_discovery_client_
.CreateServiceResolver(
426 "hello._privet._tcp.local",
427 base::Bind(&ServiceResolverTest::OnFinishedResolving
,
428 base::Unretained(this)));
// Resolver callback. It splits the ServiceDescription into the plain values
// the mock matches: the address as a string, the metadata and the IP address.
431 void OnFinishedResolving(ServiceResolver::RequestStatus request_status
,
432 const ServiceDescription
& service_description
) {
433 OnFinishedResolvingInternal(request_status
,
434 service_description
.address
.ToString(),
435 service_description
.metadata
,
436 service_description
.ip_address
);
439 MOCK_METHOD4(OnFinishedResolvingInternal
,
440 void(ServiceResolver::RequestStatus
,
442 const std::vector
<std::string
>&,
443 const net::IPAddressNumber
&));
446 scoped_ptr
<ServiceResolver
> resolver_
;
// NOTE(review): not referenced anywhere in the visible code; possibly unused.
447 net::IPAddressNumber ip_address_
;
448 net::HostPortPair address_expected_
;
449 std::vector
<std::string
> metadata_expected_
;
450 net::IPAddressNumber ip_address_expected_
;
// Resolving with SRV and TXT records but no A record should report
// STATUS_SUCCESS with the expected address and an empty IP address. Four
// outgoing packets are expected after StartResolving().
// NOTE(review): embedded lines 461-462 and 465 are absent. They presumably
// hold the "EXPECT_CALL(*this," opener and the metadata argument, so lines
// 463-466 as shown are not a complete statement. The closing brace is absent.
453 TEST_F(ServiceResolverTest
, TxtAndSrvButNoA
) {
454 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
456 resolver_
->StartResolving();
458 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
460 base::MessageLoop::current()->RunUntilIdle();
463 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS
,
464 address_expected_
.ToString(),
466 net::IPAddressNumber()));
468 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
// Resolving with TXT, then SRV plus A records should report STATUS_SUCCESS
// with the expected address and ip_address_expected_. Four outgoing packets
// are expected after StartResolving().
// NOTE(review): embedded lines 475-476 and 479 are absent. They presumably
// hold the "EXPECT_CALL(*this," opener and the metadata argument, so lines
// 477-480 as shown are not a complete statement. The closing brace is absent.
471 TEST_F(ServiceResolverTest
, TxtSrvAndA
) {
472 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
474 resolver_
->StartResolving();
477 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS
,
478 address_expected_
.ToString(),
480 ip_address_expected_
));
482 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
484 socket_factory_
.SimulateReceive(kSamplePacketSRVA
, sizeof(kSamplePacketSRVA
));
// Resolving with only SRV plus A records (no TXT) should still report
// STATUS_SUCCESS, with the expected address, empty metadata and
// ip_address_expected_, by the end of a 4 s loop run.
// NOTE(review): the following embedded lines are absent:
//  - 491-492: presumably the "EXPECT_CALL(*this," opener, so lines 493-496
//    as shown are not a complete statement;
//  - 501: the TODO on line 500 is cut off mid-sentence;
//  - the closing brace.
487 TEST_F(ServiceResolverTest
, JustSrv
) {
488 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
490 resolver_
->StartResolving();
493 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS
,
494 address_expected_
.ToString(),
495 std::vector
<std::string
>(),
496 ip_address_expected_
));
498 socket_factory_
.SimulateReceive(kSamplePacketSRVA
, sizeof(kSamplePacketSRVA
));
500 // TODO(noamsml): When NSEC record support is added, change this to use an
502 RunFor(base::TimeDelta::FromSeconds(4));
// Resolving with no records at all should report STATUS_REQUEST_TIMEOUT
// (any address, metadata and IP) within a 4 s loop run. Four outgoing
// packets are expected after StartResolving().
// NOTE(review): the following embedded lines are absent:
//  - 514: the TODO on line 513 is cut off mid-sentence;
//  - 516-519: presumably the test's closing brace and possibly the end of an
//    inner namespace opened on the missing lines 29-31 — confirm.
505 TEST_F(ServiceResolverTest
, WithNothing
) {
506 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
508 resolver_
->StartResolving();
510 EXPECT_CALL(*this, OnFinishedResolvingInternal(
511 ServiceResolver::STATUS_REQUEST_TIMEOUT
, _
, _
, _
));
513 // TODO(noamsml): When NSEC record support is added, change this to use an
515 RunFor(base::TimeDelta::FromSeconds(4));
520 } // namespace local_discovery