Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / chrome / common / local_discovery / service_discovery_client_impl.cc
blob733e1666f16dad36de74113aa653cf3a2ddfdc2e
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include <utility>
7 #include "base/location.h"
8 #include "base/logging.h"
9 #include "base/memory/singleton.h"
10 #include "base/single_thread_task_runner.h"
11 #include "base/stl_util.h"
12 #include "base/thread_task_runner_handle.h"
13 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
14 #include "net/dns/dns_protocol.h"
15 #include "net/dns/record_rdata.h"
17 namespace local_discovery {
namespace {

// Delay, in milliseconds, used when posting the "send what we have" task
// after the first of two address lookups has completed.
// TODO(noamsml): Make this configurable through the LocalDomainResolver
// interface.
const int kLocalDomainSecondAddressTimeoutMs = 100;

// Timeout, in seconds, handed to the first network query issued by
// DiscoverNewServices().
const int kInitialRequeryTimeSeconds = 1;

// Time for last requery: no further query is scheduled once the timeout
// exceeds this many seconds.
const int kMaxRequeryTimeSeconds = 2;

}  // namespace
28 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
29 net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
32 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
35 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
36 const std::string& service_type,
37 const ServiceWatcher::UpdatedCallback& callback) {
38 return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
39 service_type, callback, mdns_client_));
42 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
43 const std::string& service_name,
44 const ServiceResolver::ResolveCompleteCallback& callback) {
45 return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
46 service_name, callback, mdns_client_));
49 scoped_ptr<LocalDomainResolver>
50 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
51 const std::string& domain,
52 net::AddressFamily address_family,
53 const LocalDomainResolver::IPAddressCallback& callback) {
54 return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
55 domain, address_family, callback, mdns_client_));
58 ServiceWatcherImpl::ServiceWatcherImpl(
59 const std::string& service_type,
60 const ServiceWatcher::UpdatedCallback& callback,
61 net::MDnsClient* mdns_client)
62 : service_type_(service_type), callback_(callback), started_(false),
63 actively_refresh_services_(false), mdns_client_(mdns_client) {
66 void ServiceWatcherImpl::Start() {
67 DCHECK(!started_);
68 listener_ = mdns_client_->CreateListener(
69 net::dns_protocol::kTypePTR, service_type_, this);
70 started_ = listener_->Start();
71 if (started_)
72 ReadCachedServices();
75 ServiceWatcherImpl::~ServiceWatcherImpl() {
78 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
79 DCHECK(started_);
80 if (force_update)
81 services_.clear();
82 SendQuery(kInitialRequeryTimeSeconds, force_update);
85 void ServiceWatcherImpl::SetActivelyRefreshServices(
86 bool actively_refresh_services) {
87 DCHECK(started_);
88 actively_refresh_services_ = actively_refresh_services;
90 for (ServiceListenersMap::iterator i = services_.begin();
91 i != services_.end(); i++) {
92 i->second->SetActiveRefresh(actively_refresh_services);
96 void ServiceWatcherImpl::ReadCachedServices() {
97 DCHECK(started_);
98 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
99 &transaction_cache_);
102 bool ServiceWatcherImpl::CreateTransaction(
103 bool network, bool cache, bool force_refresh,
104 scoped_ptr<net::MDnsTransaction>* transaction) {
105 int transaction_flags = 0;
106 if (network)
107 transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
109 if (cache)
110 transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
112 // TODO(noamsml): Add flag for force_refresh when supported.
114 if (transaction_flags) {
115 *transaction = mdns_client_->CreateTransaction(
116 net::dns_protocol::kTypePTR, service_type_, transaction_flags,
117 base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
118 base::Unretained(this), transaction));
119 return (*transaction)->Start();
122 return true;
125 std::string ServiceWatcherImpl::GetServiceType() const {
126 return listener_->GetName();
129 void ServiceWatcherImpl::OnRecordUpdate(
130 net::MDnsListener::UpdateType update,
131 const net::RecordParsed* record) {
132 DCHECK(started_);
133 if (record->type() == net::dns_protocol::kTypePTR) {
134 DCHECK(record->name() == GetServiceType());
135 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
137 switch (update) {
138 case net::MDnsListener::RECORD_ADDED:
139 AddService(rdata->ptrdomain());
140 break;
141 case net::MDnsListener::RECORD_CHANGED:
142 NOTREACHED();
143 break;
144 case net::MDnsListener::RECORD_REMOVED:
145 RemovePTR(rdata->ptrdomain());
146 break;
148 } else {
149 DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
150 record->type() == net::dns_protocol::kTypeTXT);
151 DCHECK(services_.find(record->name()) != services_.end());
153 if (record->type() == net::dns_protocol::kTypeSRV) {
154 if (update == net::MDnsListener::RECORD_REMOVED) {
155 RemoveSRV(record->name());
156 } else if (update == net::MDnsListener::RECORD_ADDED) {
157 AddSRV(record->name());
161 // If this is the first time we see an SRV record, do not send
162 // an UPDATE_CHANGED.
163 if (record->type() != net::dns_protocol::kTypeSRV ||
164 update != net::MDnsListener::RECORD_ADDED) {
165 DeferUpdate(UPDATE_CHANGED, record->name());
170 void ServiceWatcherImpl::OnCachePurged() {
171 // Not yet implemented.
174 void ServiceWatcherImpl::OnTransactionResponse(
175 scoped_ptr<net::MDnsTransaction>* transaction,
176 net::MDnsTransaction::Result result,
177 const net::RecordParsed* record) {
178 DCHECK(started_);
179 if (result == net::MDnsTransaction::RESULT_RECORD) {
180 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
181 DCHECK(rdata);
182 AddService(rdata->ptrdomain());
183 } else if (result == net::MDnsTransaction::RESULT_DONE) {
184 transaction->reset();
187 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
188 // record for PTR records on any name.
191 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
192 const std::string& service_name,
193 ServiceWatcherImpl* watcher,
194 net::MDnsClient* mdns_client)
195 : service_name_(service_name), mdns_client_(mdns_client),
196 update_pending_(false), has_ptr_(true), has_srv_(false) {
197 srv_listener_ = mdns_client->CreateListener(
198 net::dns_protocol::kTypeSRV, service_name, watcher);
199 txt_listener_ = mdns_client->CreateListener(
200 net::dns_protocol::kTypeTXT, service_name, watcher);
203 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
206 bool ServiceWatcherImpl::ServiceListeners::Start() {
207 if (!srv_listener_->Start())
208 return false;
209 return txt_listener_->Start();
212 void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
213 bool active_refresh) {
214 srv_listener_->SetActiveRefresh(active_refresh);
216 if (active_refresh && !has_srv_) {
217 DCHECK(has_ptr_);
218 srv_transaction_ = mdns_client_->CreateTransaction(
219 net::dns_protocol::kTypeSRV, service_name_,
220 net::MDnsTransaction::SINGLE_RESULT |
221 net::MDnsTransaction::QUERY_CACHE | net::MDnsTransaction::QUERY_NETWORK,
222 base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord,
223 base::Unretained(this)));
224 srv_transaction_->Start();
225 } else if (!active_refresh) {
226 srv_transaction_.reset();
230 void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
231 net::MDnsTransaction::Result result,
232 const net::RecordParsed* record) {
233 set_has_srv(record != NULL);
236 void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) {
237 has_srv_ = has_srv;
239 srv_transaction_.reset();
242 void ServiceWatcherImpl::AddService(const std::string& service) {
243 DCHECK(started_);
244 std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
245 make_pair(service, linked_ptr<ServiceListeners>(NULL)));
247 if (found.second) { // Newly inserted.
248 found.first->second = linked_ptr<ServiceListeners>(
249 new ServiceListeners(service, this, mdns_client_));
250 bool success = found.first->second->Start();
251 found.first->second->SetActiveRefresh(actively_refresh_services_);
252 DeferUpdate(UPDATE_ADDED, service);
254 DCHECK(success);
257 found.first->second->set_has_ptr(true);
260 void ServiceWatcherImpl::AddSRV(const std::string& service) {
261 DCHECK(started_);
263 ServiceListenersMap::iterator found = services_.find(service);
264 if (found != services_.end()) {
265 found->second->set_has_srv(true);
269 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
270 const std::string& service_name) {
271 ServiceListenersMap::iterator found = services_.find(service_name);
273 if (found != services_.end() && !found->second->update_pending()) {
274 found->second->set_update_pending(true);
275 base::ThreadTaskRunnerHandle::Get()->PostTask(
276 FROM_HERE, base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate,
277 AsWeakPtr(), update_type, service_name));
281 void ServiceWatcherImpl::DeliverDeferredUpdate(
282 ServiceWatcher::UpdateType update_type, const std::string& service_name) {
283 ServiceListenersMap::iterator found = services_.find(service_name);
285 if (found != services_.end()) {
286 found->second->set_update_pending(false);
287 if (!callback_.is_null())
288 callback_.Run(update_type, service_name);
292 void ServiceWatcherImpl::RemovePTR(const std::string& service) {
293 DCHECK(started_);
295 ServiceListenersMap::iterator found = services_.find(service);
296 if (found != services_.end()) {
297 found->second->set_has_ptr(false);
299 if (!found->second->has_ptr_or_srv()) {
300 services_.erase(found);
301 if (!callback_.is_null())
302 callback_.Run(UPDATE_REMOVED, service);
307 void ServiceWatcherImpl::RemoveSRV(const std::string& service) {
308 DCHECK(started_);
310 ServiceListenersMap::iterator found = services_.find(service);
311 if (found != services_.end()) {
312 found->second->set_has_srv(false);
314 if (!found->second->has_ptr_or_srv()) {
315 services_.erase(found);
316 if (!callback_.is_null())
317 callback_.Run(UPDATE_REMOVED, service);
322 void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
323 unsigned rrtype) {
324 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
325 // on any name.
328 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
329 if (timeout_seconds <= kMaxRequeryTimeSeconds) {
330 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
331 FROM_HERE, base::Bind(&ServiceWatcherImpl::SendQuery, AsWeakPtr(),
332 timeout_seconds * 2 /*next_timeout_seconds*/,
333 false /*force_update*/),
334 base::TimeDelta::FromSeconds(timeout_seconds));
338 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds,
339 bool force_update) {
340 CreateTransaction(true /*network*/, false /*cache*/, force_update,
341 &transaction_network_);
342 ScheduleQuery(next_timeout_seconds);
345 ServiceResolverImpl::ServiceResolverImpl(
346 const std::string& service_name,
347 const ResolveCompleteCallback& callback,
348 net::MDnsClient* mdns_client)
349 : service_name_(service_name), callback_(callback),
350 metadata_resolved_(false), address_resolved_(false),
351 mdns_client_(mdns_client) {
354 void ServiceResolverImpl::StartResolving() {
355 address_resolved_ = false;
356 metadata_resolved_ = false;
357 service_staging_ = ServiceDescription();
358 service_staging_.service_name = service_name_;
360 if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
361 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
365 ServiceResolverImpl::~ServiceResolverImpl() {
368 bool ServiceResolverImpl::CreateTxtTransaction() {
369 txt_transaction_ = mdns_client_->CreateTransaction(
370 net::dns_protocol::kTypeTXT, service_name_,
371 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
372 net::MDnsTransaction::QUERY_NETWORK,
373 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
374 AsWeakPtr()));
375 return txt_transaction_->Start();
378 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in
379 void ServiceResolverImpl::CreateATransaction() {
380 a_transaction_ = mdns_client_->CreateTransaction(
381 net::dns_protocol::kTypeA,
382 service_staging_.address.host(),
383 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
384 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
385 AsWeakPtr()));
386 a_transaction_->Start();
389 bool ServiceResolverImpl::CreateSrvTransaction() {
390 srv_transaction_ = mdns_client_->CreateTransaction(
391 net::dns_protocol::kTypeSRV, service_name_,
392 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
393 net::MDnsTransaction::QUERY_NETWORK,
394 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
395 AsWeakPtr()));
396 return srv_transaction_->Start();
399 std::string ServiceResolverImpl::GetName() const {
400 return service_name_;
403 void ServiceResolverImpl::SrvRecordTransactionResponse(
404 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
405 srv_transaction_.reset();
406 if (status == net::MDnsTransaction::RESULT_RECORD) {
407 DCHECK(record);
408 service_staging_.address = RecordToAddress(record);
409 service_staging_.last_seen = record->time_created();
410 CreateATransaction();
411 } else {
412 ServiceNotFound(MDnsStatusToRequestStatus(status));
416 void ServiceResolverImpl::TxtRecordTransactionResponse(
417 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
418 txt_transaction_.reset();
419 if (status == net::MDnsTransaction::RESULT_RECORD) {
420 DCHECK(record);
421 service_staging_.metadata = RecordToMetadata(record);
422 } else {
423 service_staging_.metadata = std::vector<std::string>();
426 metadata_resolved_ = true;
427 AlertCallbackIfReady();
430 void ServiceResolverImpl::ARecordTransactionResponse(
431 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
432 a_transaction_.reset();
434 if (status == net::MDnsTransaction::RESULT_RECORD) {
435 DCHECK(record);
436 service_staging_.ip_address = RecordToIPAddress(record);
437 } else {
438 service_staging_.ip_address = net::IPAddressNumber();
441 address_resolved_ = true;
442 AlertCallbackIfReady();
445 void ServiceResolverImpl::AlertCallbackIfReady() {
446 if (metadata_resolved_ && address_resolved_) {
447 txt_transaction_.reset();
448 srv_transaction_.reset();
449 a_transaction_.reset();
450 if (!callback_.is_null())
451 callback_.Run(STATUS_SUCCESS, service_staging_);
455 void ServiceResolverImpl::ServiceNotFound(
456 ServiceResolver::RequestStatus status) {
457 txt_transaction_.reset();
458 srv_transaction_.reset();
459 a_transaction_.reset();
460 if (!callback_.is_null())
461 callback_.Run(status, ServiceDescription());
464 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
465 net::MDnsTransaction::Result status) const {
466 switch (status) {
467 case net::MDnsTransaction::RESULT_RECORD:
468 return ServiceResolver::STATUS_SUCCESS;
469 case net::MDnsTransaction::RESULT_NO_RESULTS:
470 return ServiceResolver::STATUS_REQUEST_TIMEOUT;
471 case net::MDnsTransaction::RESULT_NSEC:
472 return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
473 case net::MDnsTransaction::RESULT_DONE: // Pass through.
474 default:
475 NOTREACHED();
476 return ServiceResolver::STATUS_REQUEST_TIMEOUT;
480 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
481 const net::RecordParsed* record) const {
482 DCHECK(record->type() == net::dns_protocol::kTypeTXT);
483 const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
484 DCHECK(txt_rdata);
485 return txt_rdata->texts();
488 net::HostPortPair ServiceResolverImpl::RecordToAddress(
489 const net::RecordParsed* record) const {
490 DCHECK(record->type() == net::dns_protocol::kTypeSRV);
491 const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
492 DCHECK(srv_rdata);
493 return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
496 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
497 const net::RecordParsed* record) const {
498 DCHECK(record->type() == net::dns_protocol::kTypeA);
499 const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
500 DCHECK(a_rdata);
501 return a_rdata->address();
504 LocalDomainResolverImpl::LocalDomainResolverImpl(
505 const std::string& domain,
506 net::AddressFamily address_family,
507 const IPAddressCallback& callback,
508 net::MDnsClient* mdns_client)
509 : domain_(domain), address_family_(address_family), callback_(callback),
510 transactions_finished_(0), mdns_client_(mdns_client) {
513 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
514 timeout_callback_.Cancel();
517 void LocalDomainResolverImpl::Start() {
518 if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
519 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
520 transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
521 transaction_a_->Start();
524 if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
525 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
526 transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
527 transaction_aaaa_->Start();
531 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
532 uint16 type) {
533 return mdns_client_->CreateTransaction(
534 type, domain_, net::MDnsTransaction::SINGLE_RESULT |
535 net::MDnsTransaction::QUERY_CACHE |
536 net::MDnsTransaction::QUERY_NETWORK,
537 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
538 base::Unretained(this)));
541 void LocalDomainResolverImpl::OnTransactionComplete(
542 net::MDnsTransaction::Result result, const net::RecordParsed* record) {
543 transactions_finished_++;
545 if (result == net::MDnsTransaction::RESULT_RECORD) {
546 if (record->type() == net::dns_protocol::kTypeA) {
547 const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
548 address_ipv4_ = rdata->address();
549 } else {
550 DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
551 const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
552 address_ipv6_ = rdata->address();
556 if (transactions_finished_ == 1 &&
557 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
558 timeout_callback_.Reset(base::Bind(
559 &LocalDomainResolverImpl::SendResolvedAddresses,
560 base::Unretained(this)));
562 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
563 FROM_HERE, timeout_callback_.callback(),
564 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
565 } else if (transactions_finished_ == 2
566 || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) {
567 SendResolvedAddresses();
571 bool LocalDomainResolverImpl::IsSuccess() {
572 return !address_ipv4_.empty() || !address_ipv6_.empty();
575 void LocalDomainResolverImpl::SendResolvedAddresses() {
576 transaction_a_.reset();
577 transaction_aaaa_.reset();
578 timeout_callback_.Cancel();
579 callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_);
582 } // namespace local_discovery