1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
7 #include "base/logging.h"
8 #include "base/memory/singleton.h"
9 #include "base/message_loop/message_loop_proxy.h"
10 #include "base/stl_util.h"
11 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
12 #include "net/dns/dns_protocol.h"
13 #include "net/dns/record_rdata.h"
15 namespace local_discovery
{
// Timing constants used by the implementations below:
//  - kLocalDomainSecondAddressTimeoutMs: in LocalDomainResolverImpl, once the
//    first of the A/AAAA transactions finishes (ADDRESS_FAMILY_UNSPECIFIED
//    only), results are sent after this many milliseconds even if the second
//    transaction has not finished.
//  - kInitialRequeryTimeSeconds: the first timeout DiscoverNewServices passes
//    to SendQuery.
//  - kMaxRequeryTimeSeconds: ScheduleQuery only schedules another query while
//    the timeout is <= this value; each re-query doubles the timeout.
18 // TODO(noamsml): Make this configurable through the LocalDomainResolver
20 const int kLocalDomainSecondAddressTimeoutMs
= 100;
22 const int kInitialRequeryTimeSeconds
= 1;
23 const int kMaxRequeryTimeSeconds
= 2; // Time for last requery
// Stores the raw MDnsClient pointer that every factory method below hands to
// the objects it creates. Nothing here takes ownership.
// NOTE(review): presumably the caller keeps |mdns_client| alive for the
// lifetime of this object and of everything it creates — confirm.
26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
27 net::MDnsClient
* mdns_client
) : mdns_client_(mdns_client
) {
// Destructor. Its body is not part of this listing.
30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
// Factory: returns a new ServiceWatcherImpl that watches |service_type| and
// reports changes through |callback|, using this client's MDnsClient.
// The watcher is only constructed here; the caller must Start() it.
33 scoped_ptr
<ServiceWatcher
> ServiceDiscoveryClientImpl::CreateServiceWatcher(
34 const std::string
& service_type
,
35 const ServiceWatcher::UpdatedCallback
& callback
) {
36 return scoped_ptr
<ServiceWatcher
>(new ServiceWatcherImpl(
37 service_type
, callback
, mdns_client_
));
// Factory: returns a new ServiceResolverImpl for the service instance
// |service_name|. The result is delivered through |callback|.
// Resolution begins only when the caller invokes StartResolving().
40 scoped_ptr
<ServiceResolver
> ServiceDiscoveryClientImpl::CreateServiceResolver(
41 const std::string
& service_name
,
42 const ServiceResolver::ResolveCompleteCallback
& callback
) {
43 return scoped_ptr
<ServiceResolver
>(new ServiceResolverImpl(
44 service_name
, callback
, mdns_client_
));
// Factory: returns a new LocalDomainResolverImpl that resolves |domain| to
// addresses of |address_family| (A, AAAA, or both for UNSPECIFIED).
// The result is delivered through |callback|. The caller must Start() it.
47 scoped_ptr
<LocalDomainResolver
>
48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
49 const std::string
& domain
,
50 net::AddressFamily address_family
,
51 const LocalDomainResolver::IPAddressCallback
& callback
) {
52 return scoped_ptr
<LocalDomainResolver
>(new LocalDomainResolverImpl(
53 domain
, address_family
, callback
, mdns_client_
));
// Creates an idle watcher for |service_type|: not started, and with active
// refresh of discovered services turned off. No mDNS objects are created
// until Start().
56 ServiceWatcherImpl::ServiceWatcherImpl(
57 const std::string
& service_type
,
58 const ServiceWatcher::UpdatedCallback
& callback
,
59 net::MDnsClient
* mdns_client
)
60 : service_type_(service_type
), callback_(callback
), started_(false),
61 actively_refresh_services_(false), mdns_client_(mdns_client
) {
// Creates a PTR listener on |service_type_| with this watcher as its
// delegate, and starts it. |started_| records whether the listener started.
// NOTE(review): the lines that follow in the full file (if any) are not part
// of this listing — confirm what happens after a successful start.
64 void ServiceWatcherImpl::Start() {
66 listener_
= mdns_client_
->CreateListener(
67 net::dns_protocol::kTypePTR
, service_type_
, this);
68 started_
= listener_
->Start();
// Destructor. Its body is not part of this listing.
73 ServiceWatcherImpl::~ServiceWatcherImpl() {
// Starts a round of network PTR queries, beginning with
// kInitialRequeryTimeSeconds. SendQuery re-schedules itself through
// ScheduleQuery with a doubling timeout.
// NOTE(review): any statements between the signature and SendQuery are not
// part of this listing — confirm how |force_update| is handled there.
76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update
) {
80 SendQuery(kInitialRequeryTimeSeconds
, force_update
);
// Turns active refresh on or off for every known service.
// The flag is stored so that services added later (AddService) get the same
// setting, and it is pushed to each existing ServiceListeners entry.
83 void ServiceWatcherImpl::SetActivelyRefreshServices(
84 bool actively_refresh_services
) {
86 actively_refresh_services_
= actively_refresh_services
;
88 for (ServiceListenersMap::iterator i
= services_
.begin();
89 i
!= services_
.end(); i
++) {
90 i
->second
->SetActiveRefresh(actively_refresh_services
);
// Issues a cache-only PTR transaction (no network, no forced refresh) so that
// services already in the mDNS cache are reported.
// NOTE(review): the last argument of this CreateTransaction call (the
// transaction out-slot) is not part of this listing.
94 void ServiceWatcherImpl::ReadCachedServices() {
96 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
// Creates and starts a PTR transaction for |service_type_| and stores it in
// |*transaction|.
// Returns the result of Start(). |force_refresh| is currently unused (see the
// TODO below).
// The transaction's callback is bound to base::Unretained(this) plus the same
// |transaction| slot, so OnTransactionResponse can reset that slot when the
// transaction is done.
// NOTE(review): as listed, both QUERY_NETWORK and QUERY_CACHE are OR'd in
// unconditionally. The `if (network)` / `if (cache)` guards, and the
// fall-through return, are presumably on lines missing from this listing —
// confirm.
100 bool ServiceWatcherImpl::CreateTransaction(
101 bool network
, bool cache
, bool force_refresh
,
102 scoped_ptr
<net::MDnsTransaction
>* transaction
) {
103 int transaction_flags
= 0;
105 transaction_flags
|= net::MDnsTransaction::QUERY_NETWORK
;
108 transaction_flags
|= net::MDnsTransaction::QUERY_CACHE
;
110 // TODO(noamsml): Add flag for force_refresh when supported.
112 if (transaction_flags
) {
113 *transaction
= mdns_client_
->CreateTransaction(
114 net::dns_protocol::kTypePTR
, service_type_
, transaction_flags
,
115 base::Bind(&ServiceWatcherImpl::OnTransactionResponse
,
116 base::Unretained(this), transaction
));
117 return (*transaction
)->Start();
// Returns the name the PTR listener was created with.
// It dereferences |listener_|, so it is only valid after Start() has created
// the listener.
123 std::string
ServiceWatcherImpl::GetServiceType() const {
124 return listener_
->GetName();
// Listener callback for record updates.
// - PTR records (on the service type): ADDED adds the pointed-to service;
//   REMOVED clears its PTR flag via RemovePTR.
// - SRV/TXT records belong to an already-known service. SRV add/remove updates
//   that service's SRV flag. Every change except a first-seen SRV (ADDED) is
//   reported as a deferred UPDATE_CHANGED.
// NOTE(review): the `switch (update)` header, the `break`s between the `case`
// labels, and the `else` that separates the PTR and SRV/TXT paths are not
// present in this listing — confirm against the full file.
127 void ServiceWatcherImpl::OnRecordUpdate(
128 net::MDnsListener::UpdateType update
,
129 const net::RecordParsed
* record
) {
131 if (record
->type() == net::dns_protocol::kTypePTR
) {
132 DCHECK(record
->name() == GetServiceType());
133 const net::PtrRecordRdata
* rdata
= record
->rdata
<net::PtrRecordRdata
>();
// PTR rdata names the service instance; that name is the services_ key.
136 case net::MDnsListener::RECORD_ADDED
:
137 AddService(rdata
->ptrdomain());
139 case net::MDnsListener::RECORD_CHANGED
:
142 case net::MDnsListener::RECORD_REMOVED
:
143 RemovePTR(rdata
->ptrdomain());
// Non-PTR path: only SRV/TXT on a service we already track are expected.
147 DCHECK(record
->type() == net::dns_protocol::kTypeSRV
||
148 record
->type() == net::dns_protocol::kTypeTXT
);
149 DCHECK(services_
.find(record
->name()) != services_
.end());
151 if (record
->type() == net::dns_protocol::kTypeSRV
) {
152 if (update
== net::MDnsListener::RECORD_REMOVED
) {
153 RemoveSRV(record
->name());
154 } else if (update
== net::MDnsListener::RECORD_ADDED
) {
155 AddSRV(record
->name());
159 // If this is the first time we see an SRV record, do not send
160 // an UPDATE_CHANGED.
161 if (record
->type() != net::dns_protocol::kTypeSRV
||
162 update
!= net::MDnsListener::RECORD_ADDED
) {
163 DeferUpdate(UPDATE_CHANGED
, record
->name());
// Cache-purge notification. Intentionally a no-op for now.
168 void ServiceWatcherImpl::OnCachePurged() {
169 // Not yet implemented.
// Callback for PTR transactions created by CreateTransaction.
// |transaction| is the owning slot the transaction was stored in.
// - RESULT_RECORD: each PTR answer adds the named service.
// - RESULT_DONE: the transaction is destroyed by resetting its slot.
// Other results are ignored (see the NSEC note below).
172 void ServiceWatcherImpl::OnTransactionResponse(
173 scoped_ptr
<net::MDnsTransaction
>* transaction
,
174 net::MDnsTransaction::Result result
,
175 const net::RecordParsed
* record
) {
177 if (result
== net::MDnsTransaction::RESULT_RECORD
) {
178 const net::PtrRecordRdata
* rdata
= record
->rdata
<net::PtrRecordRdata
>();
180 AddService(rdata
->ptrdomain());
181 } else if (result
== net::MDnsTransaction::RESULT_DONE
) {
182 transaction
->reset();
185 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
186 // record for PTR records on any name.
// Per-service state: SRV and TXT listeners on |service_name| whose updates go
// to |watcher|, plus the PTR/SRV presence flags.
// The listeners are created here but not started (see Start()).
// has_ptr_ starts true, presumably because entries are only created after a
// PTR record named the service (AddService).
189 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
190 const std::string
& service_name
,
191 ServiceWatcherImpl
* watcher
,
192 net::MDnsClient
* mdns_client
)
193 : service_name_(service_name
), mdns_client_(mdns_client
),
194 update_pending_(false), has_ptr_(true), has_srv_(false) {
195 srv_listener_
= mdns_client
->CreateListener(
196 net::dns_protocol::kTypeSRV
, service_name
, watcher
);
197 txt_listener_
= mdns_client
->CreateListener(
198 net::dns_protocol::kTypeTXT
, service_name
, watcher
);
// Destructor. Its body is not part of this listing.
201 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
// Starts the SRV listener, then the TXT listener.
// Returns the TXT listener's Start() result.
// NOTE(review): the statement controlled by the `if` below (presumably an
// early `return false;` when the SRV listener fails) is not part of this
// listing. As listed, the `if` would govern the `return` instead — confirm.
204 bool ServiceWatcherImpl::ServiceListeners::Start() {
205 if (!srv_listener_
->Start())
207 return txt_listener_
->Start();
// Forwards the active-refresh setting to the SRV listener only (the TXT
// listener is not touched here).
// When enabling refresh for a service with no SRV record yet, it issues a
// single-result SRV query (cache + network) whose answer lands in
// OnSRVRecord. When disabling, any pending SRV transaction is dropped.
210 void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
211 bool active_refresh
) {
212 srv_listener_
->SetActiveRefresh(active_refresh
);
214 if (active_refresh
&& !has_srv_
) {
216 srv_transaction_
= mdns_client_
->CreateTransaction(
217 net::dns_protocol::kTypeSRV
, service_name_
,
218 net::MDnsTransaction::SINGLE_RESULT
|
219 net::MDnsTransaction::QUERY_CACHE
| net::MDnsTransaction::QUERY_NETWORK
,
220 base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord
,
221 base::Unretained(this)));
222 srv_transaction_
->Start();
223 } else if (!active_refresh
) {
224 srv_transaction_
.reset();
// Result of the SRV query issued by SetActiveRefresh.
// |result| is ignored; only whether a record was returned matters.
228 void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
229 net::MDnsTransaction::Result result
,
230 const net::RecordParsed
* record
) {
231 set_has_srv(record
!= NULL
);
// Records whether the service currently has an SRV record and drops any
// pending SRV transaction.
// NOTE(review): no store of |has_srv| into has_srv_ is visible in this
// listing. Presumably it sits on the omitted lines before the reset —
// confirm.
234 void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv
) {
237 srv_transaction_
.reset();
// Registers |service| (a PTR target).
// A NULL placeholder is inserted first, so a single map operation both finds
// and inserts. For a new entry it then:
// - creates the per-service listeners,
// - starts them,
// - applies the current active-refresh mode,
// - schedules a deferred UPDATE_ADDED.
// In every case the entry's PTR flag is set.
// NOTE(review): |success| is not read in the visible lines; any handling of a
// failed Start() is on lines missing from this listing — confirm.
240 void ServiceWatcherImpl::AddService(const std::string
& service
) {
242 std::pair
<ServiceListenersMap::iterator
, bool> found
= services_
.insert(
243 make_pair(service
, linked_ptr
<ServiceListeners
>(NULL
)));
245 if (found
.second
) { // Newly inserted.
246 found
.first
->second
= linked_ptr
<ServiceListeners
>(
247 new ServiceListeners(service
, this, mdns_client_
));
248 bool success
= found
.first
->second
->Start();
249 found
.first
->second
->SetActiveRefresh(actively_refresh_services_
);
250 DeferUpdate(UPDATE_ADDED
, service
);
255 found
.first
->second
->set_has_ptr(true);
// Marks a known |service| as having an SRV record.
// Unknown services are ignored.
258 void ServiceWatcherImpl::AddSRV(const std::string
& service
) {
261 ServiceListenersMap::iterator found
= services_
.find(service
);
262 if (found
!= services_
.end()) {
263 found
->second
->set_has_srv(true);
// Posts DeliverDeferredUpdate(update_type, service_name) to the current
// message loop, bound through AsWeakPtr().
// Updates are coalesced: while one update is pending for a service, further
// calls are dropped, so only the first |update_type| is delivered.
// Unknown services are ignored.
// NOTE(review): the first PostTask argument (presumably FROM_HERE) is not
// part of this listing.
267 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type
,
268 const std::string
& service_name
) {
269 ServiceListenersMap::iterator found
= services_
.find(service_name
);
271 if (found
!= services_
.end() && !found
->second
->update_pending()) {
272 found
->second
->set_update_pending(true);
273 base::MessageLoop::current()->PostTask(
275 base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate
, AsWeakPtr(),
276 update_type
, service_name
));
// Task posted by DeferUpdate.
// If the service still exists, it clears the pending flag and runs the
// client callback (when one is set). If the service was removed in the
// meantime, the update is silently dropped.
280 void ServiceWatcherImpl::DeliverDeferredUpdate(
281 ServiceWatcher::UpdateType update_type
, const std::string
& service_name
) {
282 ServiceListenersMap::iterator found
= services_
.find(service_name
);
284 if (found
!= services_
.end()) {
285 found
->second
->set_update_pending(false);
286 if (!callback_
.is_null())
287 callback_
.Run(update_type
, service_name
);
// Clears the PTR flag of |service|.
// When the service has neither a PTR nor an SRV record left, its entry is
// erased and UPDATE_REMOVED is reported immediately (not deferred).
// |service| is used after the erase, so it must not refer to the map's own
// key string.
291 void ServiceWatcherImpl::RemovePTR(const std::string
& service
) {
294 ServiceListenersMap::iterator found
= services_
.find(service
);
295 if (found
!= services_
.end()) {
296 found
->second
->set_has_ptr(false);
298 if (!found
->second
->has_ptr_or_srv()) {
299 services_
.erase(found
);
300 if (!callback_
.is_null())
301 callback_
.Run(UPDATE_REMOVED
, service
);
// Clears the SRV flag of |service|.
// This mirrors RemovePTR: when neither a PTR nor an SRV record remains, the
// entry is erased and UPDATE_REMOVED is reported immediately.
// |service| must not alias the map key, since it is used after the erase.
306 void ServiceWatcherImpl::RemoveSRV(const std::string
& service
) {
309 ServiceListenersMap::iterator found
= services_
.find(service
);
310 if (found
!= services_
.end()) {
311 found
->second
->set_has_srv(false);
313 if (!found
->second
->has_ptr_or_srv()) {
314 services_
.erase(found
);
315 if (!callback_
.is_null())
316 callback_
.Run(UPDATE_REMOVED
, service
);
// NSEC notification. Intentionally ignored (see the comment in the body).
// NOTE(review): the remaining parameter(s) and the rest of the comment are not
// part of this listing.
321 void ServiceWatcherImpl::OnNsecRecord(const std::string
& name
,
323 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
// Schedules SendQuery to run after |timeout_seconds|, with a doubled next
// timeout and force_update == false.
// Scheduling stops once |timeout_seconds| exceeds kMaxRequeryTimeSeconds.
// With the constants in this file (initial 1, max 2), DiscoverNewServices
// therefore queries at roughly t = 0 s, 1 s and 3 s.
// NOTE(review): the PostDelayedTask location argument and the bound receiver
// for SendQuery are on lines not present in this listing.
327 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds
) {
328 if (timeout_seconds
<= kMaxRequeryTimeSeconds
) {
329 base::MessageLoop::current()->PostDelayedTask(
331 base::Bind(&ServiceWatcherImpl::SendQuery
,
333 timeout_seconds
* 2 /*next_timeout_seconds*/,
334 false /*force_update*/),
335 base::TimeDelta::FromSeconds(timeout_seconds
));
// Sends a network-only PTR query, stored in transaction_network_ (replacing
// any previous one), then schedules the next query via ScheduleQuery.
// NOTE(review): the second parameter (|force_update|, used below) is declared
// on a line not present in this listing.
339 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds
,
341 CreateTransaction(true /*network*/, false /*cache*/, force_update
,
342 &transaction_network_
);
343 ScheduleQuery(next_timeout_seconds
);
// Creates an idle resolver for |service_name|, with metadata and address both
// marked unresolved. Work starts in StartResolving().
346 ServiceResolverImpl::ServiceResolverImpl(
347 const std::string
& service_name
,
348 const ResolveCompleteCallback
& callback
,
349 net::MDnsClient
* mdns_client
)
350 : service_name_(service_name
), callback_(callback
),
351 metadata_resolved_(false), address_resolved_(false),
352 mdns_client_(mdns_client
) {
// Resets the staged result and starts the TXT and SRV lookups.
// If the TXT transaction fails to start, the SRV transaction is never
// attempted (short-circuit ||). Either start failure is reported as
// STATUS_REQUEST_TIMEOUT.
355 void ServiceResolverImpl::StartResolving() {
356 address_resolved_
= false;
357 metadata_resolved_
= false;
358 service_staging_
= ServiceDescription();
359 service_staging_
.service_name
= service_name_
;
361 if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
362 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT
);
// Destructor. Its body is not part of this listing.
366 ServiceResolverImpl::~ServiceResolverImpl() {
// Creates and starts a single-result TXT lookup (cache + network) for the
// service. Its answer goes to TxtRecordTransactionResponse.
// Returns whether the transaction started.
// NOTE(review): the bound receiver argument of base::Bind is on a line not
// present in this listing.
369 bool ServiceResolverImpl::CreateTxtTransaction() {
370 txt_transaction_
= mdns_client_
->CreateTransaction(
371 net::dns_protocol::kTypeTXT
, service_name_
,
372 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
|
373 net::MDnsTransaction::QUERY_NETWORK
,
374 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse
,
376 return txt_transaction_
->Start();
379 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in
// NOTE(review): the sentence above is cut off in the source; its intended
// ending is unknown — presumably it explains why the A lookup below queries
// only the cache (QUERY_CACHE, no QUERY_NETWORK). Confirm and complete it.
// Starts a single-result, cache-only A lookup for the SRV target host staged
// in service_staging_.address. Its answer goes to ARecordTransactionResponse.
// The Start() result is not checked.
// NOTE(review): the bound receiver argument of base::Bind is on a line not
// present in this listing.
380 void ServiceResolverImpl::CreateATransaction() {
381 a_transaction_
= mdns_client_
->CreateTransaction(
382 net::dns_protocol::kTypeA
,
383 service_staging_
.address
.host(),
384 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
,
385 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse
,
387 a_transaction_
->Start();
// Creates and starts a single-result SRV lookup (cache + network) for the
// service. Its answer goes to SrvRecordTransactionResponse.
// Returns whether the transaction started.
// NOTE(review): the bound receiver argument of base::Bind is on a line not
// present in this listing.
390 bool ServiceResolverImpl::CreateSrvTransaction() {
391 srv_transaction_
= mdns_client_
->CreateTransaction(
392 net::dns_protocol::kTypeSRV
, service_name_
,
393 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
|
394 net::MDnsTransaction::QUERY_NETWORK
,
395 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse
,
397 return srv_transaction_
->Start();
// Returns the service instance name this resolver was created for.
400 std::string
ServiceResolverImpl::GetName() const {
401 return service_name_
;
// SRV answer handler. The SRV transaction is dropped first; note that this is
// the transaction delivering this callback.
// On a record, it stages host:port and the record's creation time, then
// starts the A lookup. Otherwise the whole resolution fails with the mapped
// status.
// NOTE(review): the `} else {` that separates the two paths is not present in
// this listing.
404 void ServiceResolverImpl::SrvRecordTransactionResponse(
405 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
406 srv_transaction_
.reset();
407 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
409 service_staging_
.address
= RecordToAddress(record
);
410 service_staging_
.last_seen
= record
->time_created();
411 CreateATransaction();
413 ServiceNotFound(MDnsStatusToRequestStatus(status
));
// TXT answer handler. A missing TXT record is not an error: metadata becomes
// an empty vector.
// In both cases metadata counts as resolved, and the callback fires if the
// address is also ready.
// NOTE(review): the `} else {` before the empty-vector assignment is not
// present in this listing.
417 void ServiceResolverImpl::TxtRecordTransactionResponse(
418 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
419 txt_transaction_
.reset();
420 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
422 service_staging_
.metadata
= RecordToMetadata(record
);
424 service_staging_
.metadata
= std::vector
<std::string
>();
427 metadata_resolved_
= true;
428 AlertCallbackIfReady();
// A-record answer handler. Without a record, ip_address is left empty, but
// the address is still marked resolved, so the callback still reports
// STATUS_SUCCESS.
// NOTE(review): the `} else {` before the empty-address assignment is not
// present in this listing.
431 void ServiceResolverImpl::ARecordTransactionResponse(
432 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
433 a_transaction_
.reset();
435 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
437 service_staging_
.ip_address
= RecordToIPAddress(record
);
439 service_staging_
.ip_address
= net::IPAddressNumber();
442 address_resolved_
= true;
443 AlertCallbackIfReady();
// Once both metadata and address are resolved, it drops all transactions and
// reports STATUS_SUCCESS with the staged description (when a callback is set).
446 void ServiceResolverImpl::AlertCallbackIfReady() {
447 if (metadata_resolved_
&& address_resolved_
) {
448 txt_transaction_
.reset();
449 srv_transaction_
.reset();
450 a_transaction_
.reset();
451 if (!callback_
.is_null())
452 callback_
.Run(STATUS_SUCCESS
, service_staging_
);
// Failure path: drops all transactions and reports |status| with an empty
// ServiceDescription (when a callback is set).
456 void ServiceResolverImpl::ServiceNotFound(
457 ServiceResolver::RequestStatus status
) {
458 txt_transaction_
.reset();
459 srv_transaction_
.reset();
460 a_transaction_
.reset();
461 if (!callback_
.is_null())
462 callback_
.Run(status
, ServiceDescription());
// Maps an mDNS transaction result to a resolver status:
// - RECORD -> SUCCESS
// - NO_RESULTS -> REQUEST_TIMEOUT
// - NSEC -> KNOWN_NONEXISTENT
// - DONE (and, as listed, anything after it) -> REQUEST_TIMEOUT
// NOTE(review): the `switch (status) {` header and the lines between the
// RESULT_DONE label and the final return are not present in this listing.
465 ServiceResolver::RequestStatus
ServiceResolverImpl::MDnsStatusToRequestStatus(
466 net::MDnsTransaction::Result status
) const {
468 case net::MDnsTransaction::RESULT_RECORD
:
469 return ServiceResolver::STATUS_SUCCESS
;
470 case net::MDnsTransaction::RESULT_NO_RESULTS
:
471 return ServiceResolver::STATUS_REQUEST_TIMEOUT
;
472 case net::MDnsTransaction::RESULT_NSEC
:
473 return ServiceResolver::STATUS_KNOWN_NONEXISTENT
;
474 case net::MDnsTransaction::RESULT_DONE
: // Pass through.
477 return ServiceResolver::STATUS_REQUEST_TIMEOUT
;
// Returns the TXT strings of |record|, which must be a TXT record.
// The returned reference points into |record|'s rdata and is only valid while
// |record| is.
// NOTE(review): this assumes TxtRecordRdata::texts() returns a reference. If
// it returned by value, this would dangle — confirm.
481 const std::vector
<std::string
>& ServiceResolverImpl::RecordToMetadata(
482 const net::RecordParsed
* record
) const {
483 DCHECK(record
->type() == net::dns_protocol::kTypeTXT
);
484 const net::TxtRecordRdata
* txt_rdata
= record
->rdata
<net::TxtRecordRdata
>();
486 return txt_rdata
->texts();
// Converts an SRV record into a host:port pair made of the SRV target and
// port. |record| must be an SRV record.
489 net::HostPortPair
ServiceResolverImpl::RecordToAddress(
490 const net::RecordParsed
* record
) const {
491 DCHECK(record
->type() == net::dns_protocol::kTypeSRV
);
492 const net::SrvRecordRdata
* srv_rdata
= record
->rdata
<net::SrvRecordRdata
>();
494 return net::HostPortPair(srv_rdata
->target(), srv_rdata
->port());
// Returns the IPv4 address of |record|, which must be an A record.
// The reference points into |record|'s rdata and is only valid while |record|
// is.
// NOTE(review): this assumes ARecordRdata::address() returns a reference —
// confirm.
497 const net::IPAddressNumber
& ServiceResolverImpl::RecordToIPAddress(
498 const net::RecordParsed
* record
) const {
499 DCHECK(record
->type() == net::dns_protocol::kTypeA
);
500 const net::ARecordRdata
* a_rdata
= record
->rdata
<net::ARecordRdata
>();
502 return a_rdata
->address();
// Creates an idle resolver for |domain| with no transactions finished yet.
// Work starts in Start().
505 LocalDomainResolverImpl::LocalDomainResolverImpl(
506 const std::string
& domain
,
507 net::AddressFamily address_family
,
508 const IPAddressCallback
& callback
,
509 net::MDnsClient
* mdns_client
)
510 : domain_(domain
), address_family_(address_family
), callback_(callback
),
511 transactions_finished_(0), mdns_client_(mdns_client
) {
// Cancels the pending second-address timeout. That callback is bound with
// base::Unretained(this), so it must not run after destruction.
514 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
515 timeout_callback_
.Cancel();
// Starts an A lookup (IPV4 or UNSPECIFIED) and/or an AAAA lookup (IPV6 or
// UNSPECIFIED) for |domain_|. Start() results are not checked.
// NOTE(review): as listed, any other address family starts no transaction, so
// |callback_| would never run — confirm against the omitted lines.
518 void LocalDomainResolverImpl::Start() {
519 if (address_family_
== net::ADDRESS_FAMILY_IPV4
||
520 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
521 transaction_a_
= CreateTransaction(net::dns_protocol::kTypeA
);
522 transaction_a_
->Start();
525 if (address_family_
== net::ADDRESS_FAMILY_IPV6
||
526 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
527 transaction_aaaa_
= CreateTransaction(net::dns_protocol::kTypeAAAA
);
528 transaction_aaaa_
->Start();
// Creates, but does not start, a single-result lookup (cache + network) of
// record |type| for |domain_|, reporting to OnTransactionComplete.
// NOTE(review): the |type| parameter is declared on a line not present in
// this listing.
532 scoped_ptr
<net::MDnsTransaction
> LocalDomainResolverImpl::CreateTransaction(
534 return mdns_client_
->CreateTransaction(
535 type
, domain_
, net::MDnsTransaction::SINGLE_RESULT
|
536 net::MDnsTransaction::QUERY_CACHE
|
537 net::MDnsTransaction::QUERY_NETWORK
,
538 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete
,
539 base::Unretained(this)));
// Shared completion handler for the A and AAAA transactions.
// Each completion is counted, and any returned address is stored.
// - UNSPECIFIED family: after the first completion, results are sent once
//   kLocalDomainSecondAddressTimeoutMs passes, or earlier when the second
//   transaction completes.
// - Single family: results are sent immediately.
// NOTE(review): the `} else {` before the AAAA branch and the PostDelayedTask
// location argument are on lines not present in this listing.
542 void LocalDomainResolverImpl::OnTransactionComplete(
543 net::MDnsTransaction::Result result
, const net::RecordParsed
* record
) {
544 transactions_finished_
++;
546 if (result
== net::MDnsTransaction::RESULT_RECORD
) {
547 if (record
->type() == net::dns_protocol::kTypeA
) {
548 const net::ARecordRdata
* rdata
= record
->rdata
<net::ARecordRdata
>();
549 address_ipv4_
= rdata
->address();
551 DCHECK_EQ(net::dns_protocol::kTypeAAAA
, record
->type());
552 const net::AAAARecordRdata
* rdata
= record
->rdata
<net::AAAARecordRdata
>();
553 address_ipv6_
= rdata
->address();
// First of two expected answers: give the other one a short grace period.
557 if (transactions_finished_
== 1 &&
558 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
559 timeout_callback_
.Reset(base::Bind(
560 &LocalDomainResolverImpl::SendResolvedAddresses
,
561 base::Unretained(this)));
563 base::MessageLoop::current()->PostDelayedTask(
565 timeout_callback_
.callback(),
566 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs
));
567 } else if (transactions_finished_
== 2
568 || address_family_
!= net::ADDRESS_FAMILY_UNSPECIFIED
) {
569 SendResolvedAddresses();
// Resolution succeeded if at least one of the IPv4 or IPv6 addresses is
// non-empty.
573 bool LocalDomainResolverImpl::IsSuccess() {
574 return !address_ipv4_
.empty() || !address_ipv6_
.empty();
// Final step: drops both transactions, cancels the pending timeout, and
// reports the success flag plus both addresses (empty when not resolved).
// Unlike ServiceWatcherImpl and ServiceResolverImpl, |callback_| is run
// without an is_null() check, so a null callback here is presumably a caller
// error.
577 void LocalDomainResolverImpl::SendResolvedAddresses() {
578 transaction_a_
.reset();
579 transaction_aaaa_
.reset();
580 timeout_callback_
.Cancel();
581 callback_
.Run(IsSuccess(), address_ipv4_
, address_ipv6_
);
584 } // namespace local_discovery