1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
7 #include "base/location.h"
8 #include "base/logging.h"
9 #include "base/memory/singleton.h"
10 #include "base/single_thread_task_runner.h"
11 #include "base/stl_util.h"
12 #include "base/thread_task_runner_handle.h"
13 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
14 #include "net/dns/dns_protocol.h"
15 #include "net/dns/record_rdata.h"
17 namespace local_discovery
{
namespace {

// Delay LocalDomainResolverImpl waits, after the first of its A/AAAA answers
// arrives for an unspecified address family, before reporting anyway.
// TODO(noamsml): Make this configurable through the LocalDomainResolver
// interface.
const int kLocalDomainSecondAddressTimeoutMs = 100;

// Delay before the first PTR requery; each later requery doubles it.
const int kInitialRequeryTimeSeconds = 1;

// Largest delay for which ScheduleQuery() still posts a requery, i.e. the
// delay of the last requery.
const int kMaxRequeryTimeSeconds = 2;

}  // namespace
28 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
29 net::MDnsClient
* mdns_client
) : mdns_client_(mdns_client
) {
32 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
35 scoped_ptr
<ServiceWatcher
> ServiceDiscoveryClientImpl::CreateServiceWatcher(
36 const std::string
& service_type
,
37 const ServiceWatcher::UpdatedCallback
& callback
) {
38 return scoped_ptr
<ServiceWatcher
>(new ServiceWatcherImpl(
39 service_type
, callback
, mdns_client_
));
42 scoped_ptr
<ServiceResolver
> ServiceDiscoveryClientImpl::CreateServiceResolver(
43 const std::string
& service_name
,
44 const ServiceResolver::ResolveCompleteCallback
& callback
) {
45 return scoped_ptr
<ServiceResolver
>(new ServiceResolverImpl(
46 service_name
, callback
, mdns_client_
));
49 scoped_ptr
<LocalDomainResolver
>
50 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
51 const std::string
& domain
,
52 net::AddressFamily address_family
,
53 const LocalDomainResolver::IPAddressCallback
& callback
) {
54 return scoped_ptr
<LocalDomainResolver
>(new LocalDomainResolverImpl(
55 domain
, address_family
, callback
, mdns_client_
));
58 ServiceWatcherImpl::ServiceWatcherImpl(
59 const std::string
& service_type
,
60 const ServiceWatcher::UpdatedCallback
& callback
,
61 net::MDnsClient
* mdns_client
)
62 : service_type_(service_type
), callback_(callback
), started_(false),
63 actively_refresh_services_(false), mdns_client_(mdns_client
) {
66 void ServiceWatcherImpl::Start() {
68 listener_
= mdns_client_
->CreateListener(
69 net::dns_protocol::kTypePTR
, service_type_
, this);
70 started_
= listener_
->Start();
75 ServiceWatcherImpl::~ServiceWatcherImpl() {
78 void ServiceWatcherImpl::DiscoverNewServices(bool force_update
) {
82 SendQuery(kInitialRequeryTimeSeconds
, force_update
);
85 void ServiceWatcherImpl::SetActivelyRefreshServices(
86 bool actively_refresh_services
) {
88 actively_refresh_services_
= actively_refresh_services
;
90 for (ServiceListenersMap::iterator i
= services_
.begin();
91 i
!= services_
.end(); i
++) {
92 i
->second
->SetActiveRefresh(actively_refresh_services
);
96 void ServiceWatcherImpl::ReadCachedServices() {
98 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
102 bool ServiceWatcherImpl::CreateTransaction(
103 bool network
, bool cache
, bool force_refresh
,
104 scoped_ptr
<net::MDnsTransaction
>* transaction
) {
105 int transaction_flags
= 0;
107 transaction_flags
|= net::MDnsTransaction::QUERY_NETWORK
;
110 transaction_flags
|= net::MDnsTransaction::QUERY_CACHE
;
112 // TODO(noamsml): Add flag for force_refresh when supported.
114 if (transaction_flags
) {
115 *transaction
= mdns_client_
->CreateTransaction(
116 net::dns_protocol::kTypePTR
, service_type_
, transaction_flags
,
117 base::Bind(&ServiceWatcherImpl::OnTransactionResponse
,
118 base::Unretained(this), transaction
));
119 return (*transaction
)->Start();
125 std::string
ServiceWatcherImpl::GetServiceType() const {
126 return listener_
->GetName();
129 void ServiceWatcherImpl::OnRecordUpdate(
130 net::MDnsListener::UpdateType update
,
131 const net::RecordParsed
* record
) {
133 if (record
->type() == net::dns_protocol::kTypePTR
) {
134 DCHECK(record
->name() == GetServiceType());
135 const net::PtrRecordRdata
* rdata
= record
->rdata
<net::PtrRecordRdata
>();
138 case net::MDnsListener::RECORD_ADDED
:
139 AddService(rdata
->ptrdomain());
141 case net::MDnsListener::RECORD_CHANGED
:
144 case net::MDnsListener::RECORD_REMOVED
:
145 RemovePTR(rdata
->ptrdomain());
149 DCHECK(record
->type() == net::dns_protocol::kTypeSRV
||
150 record
->type() == net::dns_protocol::kTypeTXT
);
151 DCHECK(services_
.find(record
->name()) != services_
.end());
153 if (record
->type() == net::dns_protocol::kTypeSRV
) {
154 if (update
== net::MDnsListener::RECORD_REMOVED
) {
155 RemoveSRV(record
->name());
156 } else if (update
== net::MDnsListener::RECORD_ADDED
) {
157 AddSRV(record
->name());
161 // If this is the first time we see an SRV record, do not send
162 // an UPDATE_CHANGED.
163 if (record
->type() != net::dns_protocol::kTypeSRV
||
164 update
!= net::MDnsListener::RECORD_ADDED
) {
165 DeferUpdate(UPDATE_CHANGED
, record
->name());
170 void ServiceWatcherImpl::OnCachePurged() {
171 // Not yet implemented.
174 void ServiceWatcherImpl::OnTransactionResponse(
175 scoped_ptr
<net::MDnsTransaction
>* transaction
,
176 net::MDnsTransaction::Result result
,
177 const net::RecordParsed
* record
) {
179 if (result
== net::MDnsTransaction::RESULT_RECORD
) {
180 const net::PtrRecordRdata
* rdata
= record
->rdata
<net::PtrRecordRdata
>();
182 AddService(rdata
->ptrdomain());
183 } else if (result
== net::MDnsTransaction::RESULT_DONE
) {
184 transaction
->reset();
187 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
188 // record for PTR records on any name.
191 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
192 const std::string
& service_name
,
193 ServiceWatcherImpl
* watcher
,
194 net::MDnsClient
* mdns_client
)
195 : service_name_(service_name
), mdns_client_(mdns_client
),
196 update_pending_(false), has_ptr_(true), has_srv_(false) {
197 srv_listener_
= mdns_client
->CreateListener(
198 net::dns_protocol::kTypeSRV
, service_name
, watcher
);
199 txt_listener_
= mdns_client
->CreateListener(
200 net::dns_protocol::kTypeTXT
, service_name
, watcher
);
203 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
206 bool ServiceWatcherImpl::ServiceListeners::Start() {
207 if (!srv_listener_
->Start())
209 return txt_listener_
->Start();
212 void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
213 bool active_refresh
) {
214 srv_listener_
->SetActiveRefresh(active_refresh
);
216 if (active_refresh
&& !has_srv_
) {
218 srv_transaction_
= mdns_client_
->CreateTransaction(
219 net::dns_protocol::kTypeSRV
, service_name_
,
220 net::MDnsTransaction::SINGLE_RESULT
|
221 net::MDnsTransaction::QUERY_CACHE
| net::MDnsTransaction::QUERY_NETWORK
,
222 base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord
,
223 base::Unretained(this)));
224 srv_transaction_
->Start();
225 } else if (!active_refresh
) {
226 srv_transaction_
.reset();
230 void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
231 net::MDnsTransaction::Result result
,
232 const net::RecordParsed
* record
) {
233 set_has_srv(record
!= NULL
);
236 void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv
) {
239 srv_transaction_
.reset();
242 void ServiceWatcherImpl::AddService(const std::string
& service
) {
244 std::pair
<ServiceListenersMap::iterator
, bool> found
= services_
.insert(
245 make_pair(service
, linked_ptr
<ServiceListeners
>(NULL
)));
247 if (found
.second
) { // Newly inserted.
248 found
.first
->second
= linked_ptr
<ServiceListeners
>(
249 new ServiceListeners(service
, this, mdns_client_
));
250 bool success
= found
.first
->second
->Start();
251 found
.first
->second
->SetActiveRefresh(actively_refresh_services_
);
252 DeferUpdate(UPDATE_ADDED
, service
);
257 found
.first
->second
->set_has_ptr(true);
260 void ServiceWatcherImpl::AddSRV(const std::string
& service
) {
263 ServiceListenersMap::iterator found
= services_
.find(service
);
264 if (found
!= services_
.end()) {
265 found
->second
->set_has_srv(true);
269 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type
,
270 const std::string
& service_name
) {
271 ServiceListenersMap::iterator found
= services_
.find(service_name
);
273 if (found
!= services_
.end() && !found
->second
->update_pending()) {
274 found
->second
->set_update_pending(true);
275 base::ThreadTaskRunnerHandle::Get()->PostTask(
276 FROM_HERE
, base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate
,
277 AsWeakPtr(), update_type
, service_name
));
281 void ServiceWatcherImpl::DeliverDeferredUpdate(
282 ServiceWatcher::UpdateType update_type
, const std::string
& service_name
) {
283 ServiceListenersMap::iterator found
= services_
.find(service_name
);
285 if (found
!= services_
.end()) {
286 found
->second
->set_update_pending(false);
287 if (!callback_
.is_null())
288 callback_
.Run(update_type
, service_name
);
292 void ServiceWatcherImpl::RemovePTR(const std::string
& service
) {
295 ServiceListenersMap::iterator found
= services_
.find(service
);
296 if (found
!= services_
.end()) {
297 found
->second
->set_has_ptr(false);
299 if (!found
->second
->has_ptr_or_srv()) {
300 services_
.erase(found
);
301 if (!callback_
.is_null())
302 callback_
.Run(UPDATE_REMOVED
, service
);
307 void ServiceWatcherImpl::RemoveSRV(const std::string
& service
) {
310 ServiceListenersMap::iterator found
= services_
.find(service
);
311 if (found
!= services_
.end()) {
312 found
->second
->set_has_srv(false);
314 if (!found
->second
->has_ptr_or_srv()) {
315 services_
.erase(found
);
316 if (!callback_
.is_null())
317 callback_
.Run(UPDATE_REMOVED
, service
);
322 void ServiceWatcherImpl::OnNsecRecord(const std::string
& name
,
324 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
328 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds
) {
329 if (timeout_seconds
<= kMaxRequeryTimeSeconds
) {
330 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
331 FROM_HERE
, base::Bind(&ServiceWatcherImpl::SendQuery
, AsWeakPtr(),
332 timeout_seconds
* 2 /*next_timeout_seconds*/,
333 false /*force_update*/),
334 base::TimeDelta::FromSeconds(timeout_seconds
));
338 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds
,
340 CreateTransaction(true /*network*/, false /*cache*/, force_update
,
341 &transaction_network_
);
342 ScheduleQuery(next_timeout_seconds
);
345 ServiceResolverImpl::ServiceResolverImpl(
346 const std::string
& service_name
,
347 const ResolveCompleteCallback
& callback
,
348 net::MDnsClient
* mdns_client
)
349 : service_name_(service_name
), callback_(callback
),
350 metadata_resolved_(false), address_resolved_(false),
351 mdns_client_(mdns_client
) {
354 void ServiceResolverImpl::StartResolving() {
355 address_resolved_
= false;
356 metadata_resolved_
= false;
357 service_staging_
= ServiceDescription();
358 service_staging_
.service_name
= service_name_
;
360 if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
361 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT
);
365 ServiceResolverImpl::~ServiceResolverImpl() {
368 bool ServiceResolverImpl::CreateTxtTransaction() {
369 txt_transaction_
= mdns_client_
->CreateTransaction(
370 net::dns_protocol::kTypeTXT
, service_name_
,
371 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
|
372 net::MDnsTransaction::QUERY_NETWORK
,
373 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse
,
375 return txt_transaction_
->Start();
378 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in
379 void ServiceResolverImpl::CreateATransaction() {
380 a_transaction_
= mdns_client_
->CreateTransaction(
381 net::dns_protocol::kTypeA
,
382 service_staging_
.address
.host(),
383 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
,
384 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse
,
386 a_transaction_
->Start();
389 bool ServiceResolverImpl::CreateSrvTransaction() {
390 srv_transaction_
= mdns_client_
->CreateTransaction(
391 net::dns_protocol::kTypeSRV
, service_name_
,
392 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
|
393 net::MDnsTransaction::QUERY_NETWORK
,
394 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse
,
396 return srv_transaction_
->Start();
399 std::string
ServiceResolverImpl::GetName() const {
400 return service_name_
;
403 void ServiceResolverImpl::SrvRecordTransactionResponse(
404 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
405 srv_transaction_
.reset();
406 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
408 service_staging_
.address
= RecordToAddress(record
);
409 service_staging_
.last_seen
= record
->time_created();
410 CreateATransaction();
412 ServiceNotFound(MDnsStatusToRequestStatus(status
));
416 void ServiceResolverImpl::TxtRecordTransactionResponse(
417 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
418 txt_transaction_
.reset();
419 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
421 service_staging_
.metadata
= RecordToMetadata(record
);
423 service_staging_
.metadata
= std::vector
<std::string
>();
426 metadata_resolved_
= true;
427 AlertCallbackIfReady();
430 void ServiceResolverImpl::ARecordTransactionResponse(
431 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
432 a_transaction_
.reset();
434 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
436 service_staging_
.ip_address
= RecordToIPAddress(record
);
438 service_staging_
.ip_address
= net::IPAddressNumber();
441 address_resolved_
= true;
442 AlertCallbackIfReady();
445 void ServiceResolverImpl::AlertCallbackIfReady() {
446 if (metadata_resolved_
&& address_resolved_
) {
447 txt_transaction_
.reset();
448 srv_transaction_
.reset();
449 a_transaction_
.reset();
450 if (!callback_
.is_null())
451 callback_
.Run(STATUS_SUCCESS
, service_staging_
);
455 void ServiceResolverImpl::ServiceNotFound(
456 ServiceResolver::RequestStatus status
) {
457 txt_transaction_
.reset();
458 srv_transaction_
.reset();
459 a_transaction_
.reset();
460 if (!callback_
.is_null())
461 callback_
.Run(status
, ServiceDescription());
464 ServiceResolver::RequestStatus
ServiceResolverImpl::MDnsStatusToRequestStatus(
465 net::MDnsTransaction::Result status
) const {
467 case net::MDnsTransaction::RESULT_RECORD
:
468 return ServiceResolver::STATUS_SUCCESS
;
469 case net::MDnsTransaction::RESULT_NO_RESULTS
:
470 return ServiceResolver::STATUS_REQUEST_TIMEOUT
;
471 case net::MDnsTransaction::RESULT_NSEC
:
472 return ServiceResolver::STATUS_KNOWN_NONEXISTENT
;
473 case net::MDnsTransaction::RESULT_DONE
: // Pass through.
476 return ServiceResolver::STATUS_REQUEST_TIMEOUT
;
480 const std::vector
<std::string
>& ServiceResolverImpl::RecordToMetadata(
481 const net::RecordParsed
* record
) const {
482 DCHECK(record
->type() == net::dns_protocol::kTypeTXT
);
483 const net::TxtRecordRdata
* txt_rdata
= record
->rdata
<net::TxtRecordRdata
>();
485 return txt_rdata
->texts();
488 net::HostPortPair
ServiceResolverImpl::RecordToAddress(
489 const net::RecordParsed
* record
) const {
490 DCHECK(record
->type() == net::dns_protocol::kTypeSRV
);
491 const net::SrvRecordRdata
* srv_rdata
= record
->rdata
<net::SrvRecordRdata
>();
493 return net::HostPortPair(srv_rdata
->target(), srv_rdata
->port());
496 const net::IPAddressNumber
& ServiceResolverImpl::RecordToIPAddress(
497 const net::RecordParsed
* record
) const {
498 DCHECK(record
->type() == net::dns_protocol::kTypeA
);
499 const net::ARecordRdata
* a_rdata
= record
->rdata
<net::ARecordRdata
>();
501 return a_rdata
->address();
504 LocalDomainResolverImpl::LocalDomainResolverImpl(
505 const std::string
& domain
,
506 net::AddressFamily address_family
,
507 const IPAddressCallback
& callback
,
508 net::MDnsClient
* mdns_client
)
509 : domain_(domain
), address_family_(address_family
), callback_(callback
),
510 transactions_finished_(0), mdns_client_(mdns_client
) {
513 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
514 timeout_callback_
.Cancel();
517 void LocalDomainResolverImpl::Start() {
518 if (address_family_
== net::ADDRESS_FAMILY_IPV4
||
519 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
520 transaction_a_
= CreateTransaction(net::dns_protocol::kTypeA
);
521 transaction_a_
->Start();
524 if (address_family_
== net::ADDRESS_FAMILY_IPV6
||
525 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
526 transaction_aaaa_
= CreateTransaction(net::dns_protocol::kTypeAAAA
);
527 transaction_aaaa_
->Start();
531 scoped_ptr
<net::MDnsTransaction
> LocalDomainResolverImpl::CreateTransaction(
533 return mdns_client_
->CreateTransaction(
534 type
, domain_
, net::MDnsTransaction::SINGLE_RESULT
|
535 net::MDnsTransaction::QUERY_CACHE
|
536 net::MDnsTransaction::QUERY_NETWORK
,
537 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete
,
538 base::Unretained(this)));
541 void LocalDomainResolverImpl::OnTransactionComplete(
542 net::MDnsTransaction::Result result
, const net::RecordParsed
* record
) {
543 transactions_finished_
++;
545 if (result
== net::MDnsTransaction::RESULT_RECORD
) {
546 if (record
->type() == net::dns_protocol::kTypeA
) {
547 const net::ARecordRdata
* rdata
= record
->rdata
<net::ARecordRdata
>();
548 address_ipv4_
= rdata
->address();
550 DCHECK_EQ(net::dns_protocol::kTypeAAAA
, record
->type());
551 const net::AAAARecordRdata
* rdata
= record
->rdata
<net::AAAARecordRdata
>();
552 address_ipv6_
= rdata
->address();
556 if (transactions_finished_
== 1 &&
557 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
558 timeout_callback_
.Reset(base::Bind(
559 &LocalDomainResolverImpl::SendResolvedAddresses
,
560 base::Unretained(this)));
562 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
563 FROM_HERE
, timeout_callback_
.callback(),
564 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs
));
565 } else if (transactions_finished_
== 2
566 || address_family_
!= net::ADDRESS_FAMILY_UNSPECIFIED
) {
567 SendResolvedAddresses();
571 bool LocalDomainResolverImpl::IsSuccess() {
572 return !address_ipv4_
.empty() || !address_ipv6_
.empty();
575 void LocalDomainResolverImpl::SendResolvedAddresses() {
576 transaction_a_
.reset();
577 transaction_aaaa_
.reset();
578 timeout_callback_
.Cancel();
579 callback_
.Run(IsSuccess(), address_ipv4_
, address_ipv6_
);
582 } // namespace local_discovery