1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
7 #include "base/logging.h"
8 #include "base/memory/singleton.h"
9 #include "base/message_loop/message_loop_proxy.h"
10 #include "base/stl_util.h"
11 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
12 #include "net/dns/dns_protocol.h"
13 #include "net/dns/record_rdata.h"
15 namespace local_discovery
{
// How long, in milliseconds, LocalDomainResolverImpl keeps waiting for the
// second address family after the first of its two transactions finishes
// (only when ADDRESS_FAMILY_UNSPECIFIED was requested).
// TODO(noamsml): Make this configurable through the LocalDomainResolver
// interface.
const int kLocalDomainSecondAddressTimeoutMs = 100;

// Delay before ServiceWatcherImpl's first network requery; each later requery
// doubles the delay.
const int kInitialRequeryTimeSeconds = 1;
// Largest delay that is still scheduled, i.e. the delay before the last
// requery.
const int kMaxRequeryTimeSeconds = 2;
26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
27 net::MDnsClient
* mdns_client
) : mdns_client_(mdns_client
) {
30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
33 scoped_ptr
<ServiceWatcher
> ServiceDiscoveryClientImpl::CreateServiceWatcher(
34 const std::string
& service_type
,
35 const ServiceWatcher::UpdatedCallback
& callback
) {
36 return scoped_ptr
<ServiceWatcher
>(new ServiceWatcherImpl(
37 service_type
, callback
, mdns_client_
));
40 scoped_ptr
<ServiceResolver
> ServiceDiscoveryClientImpl::CreateServiceResolver(
41 const std::string
& service_name
,
42 const ServiceResolver::ResolveCompleteCallback
& callback
) {
43 return scoped_ptr
<ServiceResolver
>(new ServiceResolverImpl(
44 service_name
, callback
, mdns_client_
));
47 scoped_ptr
<LocalDomainResolver
>
48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
49 const std::string
& domain
,
50 net::AddressFamily address_family
,
51 const LocalDomainResolver::IPAddressCallback
& callback
) {
52 return scoped_ptr
<LocalDomainResolver
>(new LocalDomainResolverImpl(
53 domain
, address_family
, callback
, mdns_client_
));
56 ServiceWatcherImpl::ServiceWatcherImpl(
57 const std::string
& service_type
,
58 const ServiceWatcher::UpdatedCallback
& callback
,
59 net::MDnsClient
* mdns_client
)
60 : service_type_(service_type
), callback_(callback
), started_(false),
61 mdns_client_(mdns_client
) {
64 void ServiceWatcherImpl::Start() {
66 listener_
= mdns_client_
->CreateListener(
67 net::dns_protocol::kTypePTR
, service_type_
, this);
68 started_
= listener_
->Start();
73 ServiceWatcherImpl::~ServiceWatcherImpl() {
76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update
) {
80 SendQuery(kInitialRequeryTimeSeconds
, force_update
);
83 void ServiceWatcherImpl::ReadCachedServices() {
85 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
89 bool ServiceWatcherImpl::CreateTransaction(
90 bool network
, bool cache
, bool force_refresh
,
91 scoped_ptr
<net::MDnsTransaction
>* transaction
) {
92 int transaction_flags
= 0;
94 transaction_flags
|= net::MDnsTransaction::QUERY_NETWORK
;
97 transaction_flags
|= net::MDnsTransaction::QUERY_CACHE
;
99 // TODO(noamsml): Add flag for force_refresh when supported.
101 if (transaction_flags
) {
102 *transaction
= mdns_client_
->CreateTransaction(
103 net::dns_protocol::kTypePTR
, service_type_
, transaction_flags
,
104 base::Bind(&ServiceWatcherImpl::OnTransactionResponse
,
105 base::Unretained(this), transaction
));
106 return (*transaction
)->Start();
112 std::string
ServiceWatcherImpl::GetServiceType() const {
113 return listener_
->GetName();
116 void ServiceWatcherImpl::OnRecordUpdate(
117 net::MDnsListener::UpdateType update
,
118 const net::RecordParsed
* record
) {
120 if (record
->type() == net::dns_protocol::kTypePTR
) {
121 DCHECK(record
->name() == GetServiceType());
122 const net::PtrRecordRdata
* rdata
= record
->rdata
<net::PtrRecordRdata
>();
125 case net::MDnsListener::RECORD_ADDED
:
126 AddService(rdata
->ptrdomain());
128 case net::MDnsListener::RECORD_CHANGED
:
131 case net::MDnsListener::RECORD_REMOVED
:
132 RemoveService(rdata
->ptrdomain());
136 DCHECK(record
->type() == net::dns_protocol::kTypeSRV
||
137 record
->type() == net::dns_protocol::kTypeTXT
);
138 DCHECK(services_
.find(record
->name()) != services_
.end());
140 DeferUpdate(UPDATE_CHANGED
, record
->name());
144 void ServiceWatcherImpl::OnCachePurged() {
145 // Not yet implemented.
148 void ServiceWatcherImpl::OnTransactionResponse(
149 scoped_ptr
<net::MDnsTransaction
>* transaction
,
150 net::MDnsTransaction::Result result
,
151 const net::RecordParsed
* record
) {
153 if (result
== net::MDnsTransaction::RESULT_RECORD
) {
154 const net::PtrRecordRdata
* rdata
= record
->rdata
<net::PtrRecordRdata
>();
156 AddService(rdata
->ptrdomain());
157 } else if (result
== net::MDnsTransaction::RESULT_DONE
) {
158 transaction
->reset();
161 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
162 // record for PTR records on any name.
165 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
166 const std::string
& service_name
,
167 ServiceWatcherImpl
* watcher
,
168 net::MDnsClient
* mdns_client
) : update_pending_(false) {
169 srv_listener_
= mdns_client
->CreateListener(
170 net::dns_protocol::kTypeSRV
, service_name
, watcher
);
171 txt_listener_
= mdns_client
->CreateListener(
172 net::dns_protocol::kTypeTXT
, service_name
, watcher
);
175 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
178 bool ServiceWatcherImpl::ServiceListeners::Start() {
179 if (!srv_listener_
->Start())
181 return txt_listener_
->Start();
184 void ServiceWatcherImpl::AddService(const std::string
& service
) {
186 std::pair
<ServiceListenersMap::iterator
, bool> found
= services_
.insert(
187 make_pair(service
, linked_ptr
<ServiceListeners
>(NULL
)));
188 if (found
.second
) { // Newly inserted.
189 found
.first
->second
= linked_ptr
<ServiceListeners
>(
190 new ServiceListeners(service
, this, mdns_client_
));
191 bool success
= found
.first
->second
->Start();
193 DeferUpdate(UPDATE_ADDED
, service
);
199 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type
,
200 const std::string
& service_name
) {
201 ServiceListenersMap::iterator found
= services_
.find(service_name
);
203 if (found
!= services_
.end() && !found
->second
->update_pending()) {
204 found
->second
->set_update_pending(true);
205 base::MessageLoop::current()->PostTask(
207 base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate
, AsWeakPtr(),
208 update_type
, service_name
));
212 void ServiceWatcherImpl::DeliverDeferredUpdate(
213 ServiceWatcher::UpdateType update_type
, const std::string
& service_name
) {
214 ServiceListenersMap::iterator found
= services_
.find(service_name
);
216 if (found
!= services_
.end()) {
217 found
->second
->set_update_pending(false);
218 if (!callback_
.is_null())
219 callback_
.Run(update_type
, service_name
);
223 void ServiceWatcherImpl::RemoveService(const std::string
& service
) {
225 ServiceListenersMap::iterator found
= services_
.find(service
);
226 if (found
!= services_
.end()) {
227 services_
.erase(found
);
228 if (!callback_
.is_null())
229 callback_
.Run(UPDATE_REMOVED
, service
);
233 void ServiceWatcherImpl::OnNsecRecord(const std::string
& name
,
235 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
239 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds
) {
240 if (timeout_seconds
<= kMaxRequeryTimeSeconds
) {
241 base::MessageLoop::current()->PostDelayedTask(
243 base::Bind(&ServiceWatcherImpl::SendQuery
,
245 timeout_seconds
* 2 /*next_timeout_seconds*/,
246 false /*force_update*/),
247 base::TimeDelta::FromSeconds(timeout_seconds
));
251 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds
,
253 CreateTransaction(true /*network*/, false /*cache*/, force_update
,
254 &transaction_network_
);
255 ScheduleQuery(next_timeout_seconds
);
258 ServiceResolverImpl::ServiceResolverImpl(
259 const std::string
& service_name
,
260 const ResolveCompleteCallback
& callback
,
261 net::MDnsClient
* mdns_client
)
262 : service_name_(service_name
), callback_(callback
),
263 metadata_resolved_(false), address_resolved_(false),
264 mdns_client_(mdns_client
) {
267 void ServiceResolverImpl::StartResolving() {
268 address_resolved_
= false;
269 metadata_resolved_
= false;
270 service_staging_
= ServiceDescription();
271 service_staging_
.service_name
= service_name_
;
273 if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
274 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT
);
278 ServiceResolverImpl::~ServiceResolverImpl() {
281 bool ServiceResolverImpl::CreateTxtTransaction() {
282 txt_transaction_
= mdns_client_
->CreateTransaction(
283 net::dns_protocol::kTypeTXT
, service_name_
,
284 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
|
285 net::MDnsTransaction::QUERY_NETWORK
,
286 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse
,
288 return txt_transaction_
->Start();
291 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in
292 void ServiceResolverImpl::CreateATransaction() {
293 a_transaction_
= mdns_client_
->CreateTransaction(
294 net::dns_protocol::kTypeA
,
295 service_staging_
.address
.host(),
296 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
,
297 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse
,
299 a_transaction_
->Start();
302 bool ServiceResolverImpl::CreateSrvTransaction() {
303 srv_transaction_
= mdns_client_
->CreateTransaction(
304 net::dns_protocol::kTypeSRV
, service_name_
,
305 net::MDnsTransaction::SINGLE_RESULT
| net::MDnsTransaction::QUERY_CACHE
|
306 net::MDnsTransaction::QUERY_NETWORK
,
307 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse
,
309 return srv_transaction_
->Start();
312 std::string
ServiceResolverImpl::GetName() const {
313 return service_name_
;
316 void ServiceResolverImpl::SrvRecordTransactionResponse(
317 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
318 srv_transaction_
.reset();
319 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
321 service_staging_
.address
= RecordToAddress(record
);
322 service_staging_
.last_seen
= record
->time_created();
323 CreateATransaction();
325 ServiceNotFound(MDnsStatusToRequestStatus(status
));
329 void ServiceResolverImpl::TxtRecordTransactionResponse(
330 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
331 txt_transaction_
.reset();
332 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
334 service_staging_
.metadata
= RecordToMetadata(record
);
336 service_staging_
.metadata
= std::vector
<std::string
>();
339 metadata_resolved_
= true;
340 AlertCallbackIfReady();
343 void ServiceResolverImpl::ARecordTransactionResponse(
344 net::MDnsTransaction::Result status
, const net::RecordParsed
* record
) {
345 a_transaction_
.reset();
347 if (status
== net::MDnsTransaction::RESULT_RECORD
) {
349 service_staging_
.ip_address
= RecordToIPAddress(record
);
351 service_staging_
.ip_address
= net::IPAddressNumber();
354 address_resolved_
= true;
355 AlertCallbackIfReady();
358 void ServiceResolverImpl::AlertCallbackIfReady() {
359 if (metadata_resolved_
&& address_resolved_
) {
360 txt_transaction_
.reset();
361 srv_transaction_
.reset();
362 a_transaction_
.reset();
363 if (!callback_
.is_null())
364 callback_
.Run(STATUS_SUCCESS
, service_staging_
);
368 void ServiceResolverImpl::ServiceNotFound(
369 ServiceResolver::RequestStatus status
) {
370 txt_transaction_
.reset();
371 srv_transaction_
.reset();
372 a_transaction_
.reset();
373 if (!callback_
.is_null())
374 callback_
.Run(status
, ServiceDescription());
377 ServiceResolver::RequestStatus
ServiceResolverImpl::MDnsStatusToRequestStatus(
378 net::MDnsTransaction::Result status
) const {
380 case net::MDnsTransaction::RESULT_RECORD
:
381 return ServiceResolver::STATUS_SUCCESS
;
382 case net::MDnsTransaction::RESULT_NO_RESULTS
:
383 return ServiceResolver::STATUS_REQUEST_TIMEOUT
;
384 case net::MDnsTransaction::RESULT_NSEC
:
385 return ServiceResolver::STATUS_KNOWN_NONEXISTENT
;
386 case net::MDnsTransaction::RESULT_DONE
: // Pass through.
389 return ServiceResolver::STATUS_REQUEST_TIMEOUT
;
393 const std::vector
<std::string
>& ServiceResolverImpl::RecordToMetadata(
394 const net::RecordParsed
* record
) const {
395 DCHECK(record
->type() == net::dns_protocol::kTypeTXT
);
396 const net::TxtRecordRdata
* txt_rdata
= record
->rdata
<net::TxtRecordRdata
>();
398 return txt_rdata
->texts();
401 net::HostPortPair
ServiceResolverImpl::RecordToAddress(
402 const net::RecordParsed
* record
) const {
403 DCHECK(record
->type() == net::dns_protocol::kTypeSRV
);
404 const net::SrvRecordRdata
* srv_rdata
= record
->rdata
<net::SrvRecordRdata
>();
406 return net::HostPortPair(srv_rdata
->target(), srv_rdata
->port());
409 const net::IPAddressNumber
& ServiceResolverImpl::RecordToIPAddress(
410 const net::RecordParsed
* record
) const {
411 DCHECK(record
->type() == net::dns_protocol::kTypeA
);
412 const net::ARecordRdata
* a_rdata
= record
->rdata
<net::ARecordRdata
>();
414 return a_rdata
->address();
417 LocalDomainResolverImpl::LocalDomainResolverImpl(
418 const std::string
& domain
,
419 net::AddressFamily address_family
,
420 const IPAddressCallback
& callback
,
421 net::MDnsClient
* mdns_client
)
422 : domain_(domain
), address_family_(address_family
), callback_(callback
),
423 transactions_finished_(0), mdns_client_(mdns_client
) {
426 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
427 timeout_callback_
.Cancel();
430 void LocalDomainResolverImpl::Start() {
431 if (address_family_
== net::ADDRESS_FAMILY_IPV4
||
432 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
433 transaction_a_
= CreateTransaction(net::dns_protocol::kTypeA
);
434 transaction_a_
->Start();
437 if (address_family_
== net::ADDRESS_FAMILY_IPV6
||
438 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
439 transaction_aaaa_
= CreateTransaction(net::dns_protocol::kTypeAAAA
);
440 transaction_aaaa_
->Start();
444 scoped_ptr
<net::MDnsTransaction
> LocalDomainResolverImpl::CreateTransaction(
446 return mdns_client_
->CreateTransaction(
447 type
, domain_
, net::MDnsTransaction::SINGLE_RESULT
|
448 net::MDnsTransaction::QUERY_CACHE
|
449 net::MDnsTransaction::QUERY_NETWORK
,
450 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete
,
451 base::Unretained(this)));
454 void LocalDomainResolverImpl::OnTransactionComplete(
455 net::MDnsTransaction::Result result
, const net::RecordParsed
* record
) {
456 transactions_finished_
++;
458 if (result
== net::MDnsTransaction::RESULT_RECORD
) {
459 if (record
->type() == net::dns_protocol::kTypeA
) {
460 const net::ARecordRdata
* rdata
= record
->rdata
<net::ARecordRdata
>();
461 address_ipv4_
= rdata
->address();
463 DCHECK_EQ(net::dns_protocol::kTypeAAAA
, record
->type());
464 const net::AAAARecordRdata
* rdata
= record
->rdata
<net::AAAARecordRdata
>();
465 address_ipv6_
= rdata
->address();
469 if (transactions_finished_
== 1 &&
470 address_family_
== net::ADDRESS_FAMILY_UNSPECIFIED
) {
471 timeout_callback_
.Reset(base::Bind(
472 &LocalDomainResolverImpl::SendResolvedAddresses
,
473 base::Unretained(this)));
475 base::MessageLoop::current()->PostDelayedTask(
477 timeout_callback_
.callback(),
478 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs
));
479 } else if (transactions_finished_
== 2
480 || address_family_
!= net::ADDRESS_FAMILY_UNSPECIFIED
) {
481 SendResolvedAddresses();
485 bool LocalDomainResolverImpl::IsSuccess() {
486 return !address_ipv4_
.empty() || !address_ipv6_
.empty();
489 void LocalDomainResolverImpl::SendResolvedAddresses() {
490 transaction_a_
.reset();
491 transaction_aaaa_
.reset();
492 timeout_callback_
.Cancel();
493 callback_
.Run(IsSuccess(), address_ipv4_
, address_ipv6_
);
496 } // namespace local_discovery