1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "base/memory/weak_ptr.h"
6 #include "base/run_loop.h"
7 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
8 #include "net/base/net_errors.h"
9 #include "net/dns/dns_protocol.h"
10 #include "net/dns/mdns_client_impl.h"
11 #include "net/dns/mock_mdns_socket_factory.h"
12 #include "testing/gmock/include/gmock/gmock.h"
13 #include "testing/gtest/include/gtest/gtest.h"
16 using ::testing::Invoke
;
17 using ::testing::StrictMock
;
18 using ::testing::NiceMock
;
19 using ::testing::Mock
;
20 using ::testing::SaveArg
;
21 using ::testing::SetArgPointee
;
22 using ::testing::Return
;
23 using ::testing::Exactly
;
25 namespace local_discovery
{
// Sample DNS response bytes carrying a single PTR answer for
// "_privet._tcp.local"; the visible RDATA starts with the label "hello".
29 const uint8 kSamplePacketPTR
[] = {
31 0x00, 0x00, // ID is zeroed out
32 0x81, 0x80, // Standard query response, RA, no error
33 0x00, 0x00, // No questions (for simplicity)
34 0x00, 0x01, // 1 RR (answers)
35 0x00, 0x00, // 0 authority RRs
36 0x00, 0x00, // 0 additional RRs
38 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
39 0x04, '_', 't', 'c', 'p',
40 0x05, 'l', 'o', 'c', 'a', 'l',
42 0x00, 0x0c, // TYPE is PTR.
43 0x00, 0x01, // CLASS is IN.
44 0x00, 0x00, // TTL (4 bytes) is 1 second.
46 0x00, 0x08, // RDLENGTH is 8 bytes.
47 0x05, 'h', 'e', 'l', 'l', 'o',
// Sample DNS response bytes carrying a single SRV answer for
// "hello._privet._tcp.local" that points at "myhello.local", port 8888.
51 const uint8 kSamplePacketSRV
[] = {
53 0x00, 0x00, // ID is zeroed out
54 0x81, 0x80, // Standard query response, RA, no error
55 0x00, 0x00, // No questions (for simplicity)
56 0x00, 0x01, // 1 RR (answers)
57 0x00, 0x00, // 0 authority RRs
58 0x00, 0x00, // 0 additional RRs
60 0x05, 'h', 'e', 'l', 'l', 'o',
61 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
62 0x04, '_', 't', 'c', 'p',
63 0x05, 'l', 'o', 'c', 'a', 'l',
65 0x00, 0x21, // TYPE is SRV.
66 0x00, 0x01, // CLASS is IN.
67 0x00, 0x00, // TTL (4 bytes) is 1 second.
69 0x00, 0x15, // RDLENGTH is 21 bytes.
72 0x22, 0xb8, // port 8888
73 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
74 0x05, 'l', 'o', 'c', 'a', 'l',
// Sample DNS response bytes carrying a single TXT answer for
// "hello._privet._tcp.local" whose RDATA is the one string "hello".
78 const uint8 kSamplePacketTXT
[] = {
80 0x00, 0x00, // ID is zeroed out
81 0x81, 0x80, // Standard query response, RA, no error
82 0x00, 0x00, // No questions (for simplicity)
83 0x00, 0x01, // 1 RR (answers)
84 0x00, 0x00, // 0 authority RRs
85 0x00, 0x00, // 0 additional RRs
87 0x05, 'h', 'e', 'l', 'l', 'o',
88 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
89 0x04, '_', 't', 'c', 'p',
90 0x05, 'l', 'o', 'c', 'a', 'l',
92 0x00, 0x10, // TYPE is TXT (16).
93 0x00, 0x01, // CLASS is IN.
94 0x00, 0x00, // TTL (4 bytes), upper half. NOTE(review): a zero upper half caps the TTL at 65535 s, so it cannot be "20 hours, 47 minutes, 48 seconds" — confirm value.
96 0x00, 0x06, // RDLENGTH is 6 bytes.
97 0x05, 'h', 'e', 'l', 'l', 'o'
// Sample DNS response bytes carrying two answers: an SRV record for
// "hello._privet._tcp.local" (target "myhello.local", port 8888) followed by
// an A record for "myhello.local".
100 const uint8 kSamplePacketSRVA
[] = {
102 0x00, 0x00, // ID is zeroed out
103 0x81, 0x80, // Standard query response, RA, no error
104 0x00, 0x00, // No questions (for simplicity)
105 0x00, 0x02, // 2 RR (answers)
106 0x00, 0x00, // 0 authority RRs
107 0x00, 0x00, // 0 additional RRs
109 0x05, 'h', 'e', 'l', 'l', 'o',
110 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
111 0x04, '_', 't', 'c', 'p',
112 0x05, 'l', 'o', 'c', 'a', 'l',
114 0x00, 0x21, // TYPE is SRV.
115 0x00, 0x01, // CLASS is IN.
116 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
118 0x00, 0x15, // RDLENGTH is 21 bytes.
121 0x22, 0xb8, // port 8888
122 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
123 0x05, 'l', 'o', 'c', 'a', 'l',
126 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
127 0x05, 'l', 'o', 'c', 'a', 'l',
129 0x00, 0x01, // TYPE is A.
130 0x00, 0x01, // CLASS is IN.
131 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
133 0x00, 0x04, // RDLENGTH is 4 bytes.
// Sample DNS response bytes carrying two PTR answers for
// "_privet._tcp.local"; the visible RDATA of the first starts with the label
// "gdbye" and that of the second with "hello".
138 const uint8 kSamplePacketPTR2
[] = {
140 0x00, 0x00, // ID is zeroed out
141 0x81, 0x80, // Standard query response, RA, no error
142 0x00, 0x00, // No questions (for simplicity)
143 0x00, 0x02, // 2 RR (answers)
144 0x00, 0x00, // 0 authority RRs
145 0x00, 0x00, // 0 additional RRs
147 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
148 0x04, '_', 't', 'c', 'p',
149 0x05, 'l', 'o', 'c', 'a', 'l',
151 0x00, 0x0c, // TYPE is PTR.
152 0x00, 0x01, // CLASS is IN.
153 0x02, 0x00, // TTL (4 bytes), upper half. NOTE(review): 0x02 in the top byte makes the TTL far larger than the "1 second" previously stated — confirm intent.
155 0x00, 0x08, // RDLENGTH is 8 bytes.
156 0x05, 'g', 'd', 'b', 'y', 'e',
159 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
160 0x04, '_', 't', 'c', 'p',
161 0x05, 'l', 'o', 'c', 'a', 'l',
163 0x00, 0x0c, // TYPE is PTR.
164 0x00, 0x01, // CLASS is IN.
165 0x02, 0x00, // TTL (4 bytes), upper half. NOTE(review): 0x02 in the top byte makes the TTL far larger than the "1 second" previously stated — confirm intent.
167 0x00, 0x08, // RDLENGTH is 8 bytes.
168 0x05, 'h', 'e', 'l', 'l', 'o',
// Sample DNS query bytes: one question asking for the SRV record of
// "hello._privet._tcp.local", with no answer, authority or additional records.
172 const uint8 kSamplePacketQuerySRV
[] = {
174 0x00, 0x00, // ID is zeroed out
175 0x00, 0x00, // No flags.
176 0x00, 0x01, // One question.
177 0x00, 0x00, // 0 RRs (answers)
178 0x00, 0x00, // 0 authority RRs
179 0x00, 0x00, // 0 additional RRs
182 0x05, 'h', 'e', 'l', 'l', 'o',
183 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
184 0x04, '_', 't', 'c', 'p',
185 0x05, 'l', 'o', 'c', 'a', 'l',
187 0x00, 0x21, // TYPE is SRV.
188 0x00, 0x01, // CLASS is IN.
// gMock stand-in for the receiver of ServiceWatcher updates. Tests set
// expectations on OnServiceUpdated() and hand GetCallback() to the watcher.
192 class MockServiceWatcherClient
{
194 MOCK_METHOD2(OnServiceUpdated
,
195 void(ServiceWatcher::UpdateType
, const std::string
&));
// Returns a callback bound to this mock's OnServiceUpdated(). The binding uses
// base::Unretained(this), so the mock must outlive every use of the callback.
197 ServiceWatcher::UpdatedCallback
GetCallback() {
198 return base::Bind(&MockServiceWatcherClient::OnServiceUpdated
,
199 base::Unretained(this));
// Test fixture that wires a ServiceDiscoveryClientImpl to an MDnsClientImpl
// listening on a mock socket factory, and owns the message loop used by the
// tests.
203 class ServiceDiscoveryTest
: public ::testing::Test
{
// Hands the mDNS client to the discovery client, then starts the mDNS client
// listening on the mock socket factory.
205 ServiceDiscoveryTest()
206 : service_discovery_client_(&mdns_client_
) {
207 mdns_client_
.StartListening(&socket_factory_
);
210 ~ServiceDiscoveryTest() override
{}
// Runs the current message loop after posting a delayed task, bound to
// ServiceDiscoveryTest::Stop, that fires after |time_period|.
213 void RunFor(base::TimeDelta time_period
) {
214 base::CancelableCallback
<void()> callback(base::Bind(
215 &ServiceDiscoveryTest::Stop
, base::Unretained(this)));
216 base::MessageLoop::current()->PostDelayedTask(
217 FROM_HERE
, callback
.callback(), time_period
);
219 base::MessageLoop::current()->Run();
// NOTE(review): the header of the method containing this Quit() call is not
// visible here; presumably it is Stop(), the target of the task posted by
// RunFor() — confirm.
224 base::MessageLoop::current()->Quit();
227 net::MockMDnsSocketFactory socket_factory_
;
228 net::MDnsClientImpl mdns_client_
;
229 ServiceDiscoveryClientImpl service_discovery_client_
;
230 base::MessageLoop loop_
;
// Watches "_privet._tcp.local", feeds in the single-PTR packet and expects an
// UPDATE_ADDED for "hello._privet._tcp.local"; then expects an UPDATE_REMOVED
// for the same name while the loop runs for two seconds.
// NOTE(review): the call-count clause of each EXPECT_CALL is not visible here.
233 TEST_F(ServiceDiscoveryTest
, AddRemoveService
) {
234 StrictMock
<MockServiceWatcherClient
> delegate
;
236 scoped_ptr
<ServiceWatcher
> watcher(
237 service_discovery_client_
.CreateServiceWatcher(
238 "_privet._tcp.local", delegate
.GetCallback()));
242 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
243 "hello._privet._tcp.local"))
246 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
248 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED
,
249 "hello._privet._tcp.local"))
252 RunFor(base::TimeDelta::FromSeconds(2));
// Expects two sends on the mock socket factory around the
// DiscoverNewServices(false) call, and two more while the loop runs for two
// seconds afterwards.
255 TEST_F(ServiceDiscoveryTest
, DiscoverNewServices
) {
256 StrictMock
<MockServiceWatcherClient
> delegate
;
258 scoped_ptr
<ServiceWatcher
> watcher(
259 service_discovery_client_
.CreateServiceWatcher(
260 "_privet._tcp.local", delegate
.GetCallback()));
264 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(2);
266 watcher
->DiscoverNewServices(false);
268 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(2);
270 RunFor(base::TimeDelta::FromSeconds(2));
// Feeds in the single-PTR packet before the watcher exists, then creates the
// watcher and expects an UPDATE_ADDED for "hello._privet._tcp.local" once the
// loop has been drained.
273 TEST_F(ServiceDiscoveryTest
, ReadCachedServices
) {
274 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
276 StrictMock
<MockServiceWatcherClient
> delegate
;
278 scoped_ptr
<ServiceWatcher
> watcher(
279 service_discovery_client_
.CreateServiceWatcher(
280 "_privet._tcp.local", delegate
.GetCallback()));
284 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
285 "hello._privet._tcp.local"))
288 base::MessageLoop::current()->RunUntilIdle();
// Same shape as the single-record cached case, but with the two-PTR packet:
// the packet is received before the watcher is created, and UPDATE_ADDED is
// expected for both "hello" and "gdbye" under "_privet._tcp.local".
292 TEST_F(ServiceDiscoveryTest
, ReadCachedServicesMultiple
) {
293 socket_factory_
.SimulateReceive(kSamplePacketPTR2
, sizeof(kSamplePacketPTR2
));
295 StrictMock
<MockServiceWatcherClient
> delegate
;
296 scoped_ptr
<ServiceWatcher
> watcher
=
297 service_discovery_client_
.CreateServiceWatcher(
298 "_privet._tcp.local", delegate
.GetCallback());
302 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
303 "hello._privet._tcp.local"))
306 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
307 "gdbye._privet._tcp.local"))
310 base::MessageLoop::current()->RunUntilIdle();
// After the PTR packet produces the expected UPDATE_ADDED, feeds in the SRV
// and TXT packets for the same service and expects an UPDATE_CHANGED for
// "hello._privet._tcp.local".
// NOTE(review): the call-count clause of each EXPECT_CALL is not visible here.
314 TEST_F(ServiceDiscoveryTest
, OnServiceChanged
) {
315 StrictMock
<MockServiceWatcherClient
> delegate
;
316 scoped_ptr
<ServiceWatcher
> watcher(
317 service_discovery_client_
.CreateServiceWatcher(
318 "_privet._tcp.local", delegate
.GetCallback()));
322 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
323 "hello._privet._tcp.local"))
326 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
328 base::MessageLoop::current()->RunUntilIdle();
330 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED
,
331 "hello._privet._tcp.local"))
334 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
336 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
338 base::MessageLoop::current()->RunUntilIdle();
// Feeds in the PTR packet and drains the loop, then feeds in the SRV and TXT
// packets back to back before draining again, with one UPDATE_CHANGED
// expectation covering the pair.
// NOTE(review): the visible statements match those of OnServiceChanged; any
// distinguishing call-count clauses are not visible here.
341 TEST_F(ServiceDiscoveryTest
, SinglePacket
) {
342 StrictMock
<MockServiceWatcherClient
> delegate
;
343 scoped_ptr
<ServiceWatcher
> watcher(
344 service_discovery_client_
.CreateServiceWatcher(
345 "_privet._tcp.local", delegate
.GetCallback()));
349 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
350 "hello._privet._tcp.local"))
353 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
355 // Reset the "already updated" flag.
356 base::MessageLoop::current()->RunUntilIdle();
358 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED
,
359 "hello._privet._tcp.local"))
362 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
364 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
366 base::MessageLoop::current()->RunUntilIdle();
// With SetActivelyRefreshServices(true), expects the watcher to send the SRV
// query packet for "hello._privet._tcp.local": once set up before the PTR
// packet arrives, and four more times after the SRV packet arrives. Finally
// expects an UPDATE_REMOVED while the loop runs for two seconds.
// NOTE(review): the call-count clause of the first OnSendTo expectation and of
// the delegate expectations is not visible here.
369 TEST_F(ServiceDiscoveryTest
, ActivelyRefreshServices
) {
370 StrictMock
<MockServiceWatcherClient
> delegate
;
371 scoped_ptr
<ServiceWatcher
> watcher(
372 service_discovery_client_
.CreateServiceWatcher(
373 "_privet._tcp.local", delegate
.GetCallback()));
376 watcher
->SetActivelyRefreshServices(true);
378 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED
,
379 "hello._privet._tcp.local"))
// The expected outgoing bytes are the raw contents of kSamplePacketQuerySRV.
382 std::string query_packet
= std::string((const char*)(kSamplePacketQuerySRV
),
383 sizeof(kSamplePacketQuerySRV
));
385 EXPECT_CALL(socket_factory_
, OnSendTo(query_packet
))
388 socket_factory_
.SimulateReceive(kSamplePacketPTR
, sizeof(kSamplePacketPTR
));
390 base::MessageLoop::current()->RunUntilIdle();
392 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
394 EXPECT_CALL(socket_factory_
, OnSendTo(query_packet
))
395 .Times(4); // IPv4 and IPv6 at 85% and 95%
397 EXPECT_CALL(delegate
, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED
,
398 "hello._privet._tcp.local"))
401 RunFor(base::TimeDelta::FromSeconds(2));
403 base::MessageLoop::current()->RunUntilIdle();
// Fixture for ServiceResolver tests. It holds the expected resolution results
// and a resolver for "hello._privet._tcp.local" whose completion callback is
// forwarded to the mocked OnFinishedResolvingInternal().
407 class ServiceResolverTest
: public ServiceDiscoveryTest
{
// Fills in the expected values: metadata {"hello"}, address
// "myhello.local":8888, and the IP address bytes 1, 2, 3, 4.
409 ServiceResolverTest() {
410 metadata_expected_
.push_back("hello");
411 address_expected_
= net::HostPortPair("myhello.local", 8888);
412 ip_address_expected_
.push_back(1);
413 ip_address_expected_
.push_back(2);
414 ip_address_expected_
.push_back(3);
415 ip_address_expected_
.push_back(4);
418 ~ServiceResolverTest() {
// NOTE(review): the header of the method that creates resolver_ is not visible
// here; presumably this is fixture set-up rather than destructor code —
// confirm. The callback is bound with base::Unretained(this).
422 resolver_
= service_discovery_client_
.CreateServiceResolver(
423 "hello._privet._tcp.local",
424 base::Bind(&ServiceResolverTest::OnFinishedResolving
,
425 base::Unretained(this)));
// Unpacks the ServiceDescription into the address string, metadata and IP
// address, and passes them with the status to the mocked method so tests can
// match on each field.
428 void OnFinishedResolving(ServiceResolver::RequestStatus request_status
,
429 const ServiceDescription
& service_description
) {
430 OnFinishedResolvingInternal(request_status
,
431 service_description
.address
.ToString(),
432 service_description
.metadata
,
433 service_description
.ip_address
);
436 MOCK_METHOD4(OnFinishedResolvingInternal
,
437 void(ServiceResolver::RequestStatus
,
439 const std::vector
<std::string
>&,
440 const net::IPAddressNumber
&));
443 scoped_ptr
<ServiceResolver
> resolver_
;
// NOTE(review): ip_address_ is not referenced by any visible code — confirm
// whether it is still needed.
444 net::IPAddressNumber ip_address_
;
445 net::HostPortPair address_expected_
;
446 std::vector
<std::string
> metadata_expected_
;
447 net::IPAddressNumber ip_address_expected_
;
// Starts resolving, feeds in the SRV packet, and then the TXT packet. Expects
// four sends and a STATUS_SUCCESS result carrying the expected address and an
// empty IP address.
// NOTE(review): the opening of the result expectation and its metadata
// argument are not visible here.
450 TEST_F(ServiceResolverTest
, TxtAndSrvButNoA
) {
451 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
453 resolver_
->StartResolving();
455 socket_factory_
.SimulateReceive(kSamplePacketSRV
, sizeof(kSamplePacketSRV
));
457 base::MessageLoop::current()->RunUntilIdle();
460 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS
,
461 address_expected_
.ToString(),
463 net::IPAddressNumber()));
465 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
// Starts resolving, then feeds in the TXT packet followed by the combined
// SRV+A packet. Expects four sends and a STATUS_SUCCESS result carrying the
// expected address and the expected IP address.
// NOTE(review): the opening of the result expectation and its metadata
// argument are not visible here.
468 TEST_F(ServiceResolverTest
, TxtSrvAndA
) {
469 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
471 resolver_
->StartResolving();
474 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS
,
475 address_expected_
.ToString(),
477 ip_address_expected_
));
479 socket_factory_
.SimulateReceive(kSamplePacketTXT
, sizeof(kSamplePacketTXT
));
481 socket_factory_
.SimulateReceive(kSamplePacketSRVA
, sizeof(kSamplePacketSRVA
));
// Starts resolving and feeds in only the combined SRV+A packet (no TXT), then
// runs the loop for four seconds. Expects four sends and a STATUS_SUCCESS
// result with the expected address, empty metadata and the expected IP
// address.
// NOTE(review): the opening of the result expectation is not visible here.
484 TEST_F(ServiceResolverTest
, JustSrv
) {
485 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
487 resolver_
->StartResolving();
490 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS
,
491 address_expected_
.ToString(),
492 std::vector
<std::string
>(),
493 ip_address_expected_
));
495 socket_factory_
.SimulateReceive(kSamplePacketSRVA
, sizeof(kSamplePacketSRVA
));
497 // TODO(noamsml): When NSEC record support is added, change this to use an
499 RunFor(base::TimeDelta::FromSeconds(4));
// Starts resolving without feeding in any packet and runs the loop for four
// seconds. Expects four sends and a STATUS_REQUEST_TIMEOUT result, with the
// remaining result arguments unconstrained.
502 TEST_F(ServiceResolverTest
, WithNothing
) {
503 EXPECT_CALL(socket_factory_
, OnSendTo(_
)).Times(4);
505 resolver_
->StartResolving();
507 EXPECT_CALL(*this, OnFinishedResolvingInternal(
508 ServiceResolver::STATUS_REQUEST_TIMEOUT
, _
, _
, _
));
510 // TODO(noamsml): When NSEC record support is added, change this to use an
512 RunFor(base::TimeDelta::FromSeconds(4));
517 } // namespace local_discovery