Revert of Add field trial testing config for Windows and a PRESUBMIT (patchset #4...
[chromium-blink-merge.git] / chrome / common / local_discovery / service_discovery_client_impl.cc
blobc86f72652261b749ef3373b2edcc5d26751490b2
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include <utility>
7 #include "base/logging.h"
8 #include "base/memory/singleton.h"
9 #include "base/message_loop/message_loop_proxy.h"
10 #include "base/stl_util.h"
11 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
12 #include "net/dns/dns_protocol.h"
13 #include "net/dns/record_rdata.h"
15 namespace local_discovery {
namespace {

// Delay, in milliseconds, that LocalDomainResolverImpl waits after the first
// address transaction of an unspecified-family lookup completes before
// reporting whatever addresses it has.
// TODO(noamsml): Make this configurable through the LocalDomainResolver
// interface.
const int kLocalDomainSecondAddressTimeoutMs = 100;

// Requery schedule used by ServiceWatcherImpl, in seconds. The first value is
// the timeout handed to the initial query; requeries keep being scheduled
// while the current timeout does not exceed the second value.
const int kInitialRequeryTimeSeconds = 1;
const int kMaxRequeryTimeSeconds = 2;  // Time for last requery

}  // namespace
26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
27 net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
33 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
34 const std::string& service_type,
35 const ServiceWatcher::UpdatedCallback& callback) {
36 return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
37 service_type, callback, mdns_client_));
40 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
41 const std::string& service_name,
42 const ServiceResolver::ResolveCompleteCallback& callback) {
43 return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
44 service_name, callback, mdns_client_));
47 scoped_ptr<LocalDomainResolver>
48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
49 const std::string& domain,
50 net::AddressFamily address_family,
51 const LocalDomainResolver::IPAddressCallback& callback) {
52 return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
53 domain, address_family, callback, mdns_client_));
56 ServiceWatcherImpl::ServiceWatcherImpl(
57 const std::string& service_type,
58 const ServiceWatcher::UpdatedCallback& callback,
59 net::MDnsClient* mdns_client)
60 : service_type_(service_type), callback_(callback), started_(false),
61 actively_refresh_services_(false), mdns_client_(mdns_client) {
64 void ServiceWatcherImpl::Start() {
65 DCHECK(!started_);
66 listener_ = mdns_client_->CreateListener(
67 net::dns_protocol::kTypePTR, service_type_, this);
68 started_ = listener_->Start();
69 if (started_)
70 ReadCachedServices();
73 ServiceWatcherImpl::~ServiceWatcherImpl() {
76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
77 DCHECK(started_);
78 if (force_update)
79 services_.clear();
80 SendQuery(kInitialRequeryTimeSeconds, force_update);
83 void ServiceWatcherImpl::SetActivelyRefreshServices(
84 bool actively_refresh_services) {
85 DCHECK(started_);
86 actively_refresh_services_ = actively_refresh_services;
88 for (ServiceListenersMap::iterator i = services_.begin();
89 i != services_.end(); i++) {
90 i->second->SetActiveRefresh(actively_refresh_services);
94 void ServiceWatcherImpl::ReadCachedServices() {
95 DCHECK(started_);
96 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
97 &transaction_cache_);
100 bool ServiceWatcherImpl::CreateTransaction(
101 bool network, bool cache, bool force_refresh,
102 scoped_ptr<net::MDnsTransaction>* transaction) {
103 int transaction_flags = 0;
104 if (network)
105 transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
107 if (cache)
108 transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
110 // TODO(noamsml): Add flag for force_refresh when supported.
112 if (transaction_flags) {
113 *transaction = mdns_client_->CreateTransaction(
114 net::dns_protocol::kTypePTR, service_type_, transaction_flags,
115 base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
116 base::Unretained(this), transaction));
117 return (*transaction)->Start();
120 return true;
123 std::string ServiceWatcherImpl::GetServiceType() const {
124 return listener_->GetName();
127 void ServiceWatcherImpl::OnRecordUpdate(
128 net::MDnsListener::UpdateType update,
129 const net::RecordParsed* record) {
130 DCHECK(started_);
131 if (record->type() == net::dns_protocol::kTypePTR) {
132 DCHECK(record->name() == GetServiceType());
133 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
135 switch (update) {
136 case net::MDnsListener::RECORD_ADDED:
137 AddService(rdata->ptrdomain());
138 break;
139 case net::MDnsListener::RECORD_CHANGED:
140 NOTREACHED();
141 break;
142 case net::MDnsListener::RECORD_REMOVED:
143 RemovePTR(rdata->ptrdomain());
144 break;
146 } else {
147 DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
148 record->type() == net::dns_protocol::kTypeTXT);
149 DCHECK(services_.find(record->name()) != services_.end());
151 if (record->type() == net::dns_protocol::kTypeSRV) {
152 if (update == net::MDnsListener::RECORD_REMOVED) {
153 RemoveSRV(record->name());
154 } else if (update == net::MDnsListener::RECORD_ADDED) {
155 AddSRV(record->name());
159 // If this is the first time we see an SRV record, do not send
160 // an UPDATE_CHANGED.
161 if (record->type() != net::dns_protocol::kTypeSRV ||
162 update != net::MDnsListener::RECORD_ADDED) {
163 DeferUpdate(UPDATE_CHANGED, record->name());
168 void ServiceWatcherImpl::OnCachePurged() {
169 // Not yet implemented.
172 void ServiceWatcherImpl::OnTransactionResponse(
173 scoped_ptr<net::MDnsTransaction>* transaction,
174 net::MDnsTransaction::Result result,
175 const net::RecordParsed* record) {
176 DCHECK(started_);
177 if (result == net::MDnsTransaction::RESULT_RECORD) {
178 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
179 DCHECK(rdata);
180 AddService(rdata->ptrdomain());
181 } else if (result == net::MDnsTransaction::RESULT_DONE) {
182 transaction->reset();
185 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
186 // record for PTR records on any name.
189 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
190 const std::string& service_name,
191 ServiceWatcherImpl* watcher,
192 net::MDnsClient* mdns_client)
193 : service_name_(service_name), mdns_client_(mdns_client),
194 update_pending_(false), has_ptr_(true), has_srv_(false) {
195 srv_listener_ = mdns_client->CreateListener(
196 net::dns_protocol::kTypeSRV, service_name, watcher);
197 txt_listener_ = mdns_client->CreateListener(
198 net::dns_protocol::kTypeTXT, service_name, watcher);
201 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
204 bool ServiceWatcherImpl::ServiceListeners::Start() {
205 if (!srv_listener_->Start())
206 return false;
207 return txt_listener_->Start();
210 void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
211 bool active_refresh) {
212 srv_listener_->SetActiveRefresh(active_refresh);
214 if (active_refresh && !has_srv_) {
215 DCHECK(has_ptr_);
216 srv_transaction_ = mdns_client_->CreateTransaction(
217 net::dns_protocol::kTypeSRV, service_name_,
218 net::MDnsTransaction::SINGLE_RESULT |
219 net::MDnsTransaction::QUERY_CACHE | net::MDnsTransaction::QUERY_NETWORK,
220 base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord,
221 base::Unretained(this)));
222 srv_transaction_->Start();
223 } else if (!active_refresh) {
224 srv_transaction_.reset();
228 void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
229 net::MDnsTransaction::Result result,
230 const net::RecordParsed* record) {
231 set_has_srv(record != NULL);
234 void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) {
235 has_srv_ = has_srv;
237 srv_transaction_.reset();
240 void ServiceWatcherImpl::AddService(const std::string& service) {
241 DCHECK(started_);
242 std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
243 make_pair(service, linked_ptr<ServiceListeners>(NULL)));
245 if (found.second) { // Newly inserted.
246 found.first->second = linked_ptr<ServiceListeners>(
247 new ServiceListeners(service, this, mdns_client_));
248 bool success = found.first->second->Start();
249 found.first->second->SetActiveRefresh(actively_refresh_services_);
250 DeferUpdate(UPDATE_ADDED, service);
252 DCHECK(success);
255 found.first->second->set_has_ptr(true);
258 void ServiceWatcherImpl::AddSRV(const std::string& service) {
259 DCHECK(started_);
261 ServiceListenersMap::iterator found = services_.find(service);
262 if (found != services_.end()) {
263 found->second->set_has_srv(true);
267 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
268 const std::string& service_name) {
269 ServiceListenersMap::iterator found = services_.find(service_name);
271 if (found != services_.end() && !found->second->update_pending()) {
272 found->second->set_update_pending(true);
273 base::MessageLoop::current()->PostTask(
274 FROM_HERE,
275 base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
276 update_type, service_name));
280 void ServiceWatcherImpl::DeliverDeferredUpdate(
281 ServiceWatcher::UpdateType update_type, const std::string& service_name) {
282 ServiceListenersMap::iterator found = services_.find(service_name);
284 if (found != services_.end()) {
285 found->second->set_update_pending(false);
286 if (!callback_.is_null())
287 callback_.Run(update_type, service_name);
291 void ServiceWatcherImpl::RemovePTR(const std::string& service) {
292 DCHECK(started_);
294 ServiceListenersMap::iterator found = services_.find(service);
295 if (found != services_.end()) {
296 found->second->set_has_ptr(false);
298 if (!found->second->has_ptr_or_srv()) {
299 services_.erase(found);
300 if (!callback_.is_null())
301 callback_.Run(UPDATE_REMOVED, service);
306 void ServiceWatcherImpl::RemoveSRV(const std::string& service) {
307 DCHECK(started_);
309 ServiceListenersMap::iterator found = services_.find(service);
310 if (found != services_.end()) {
311 found->second->set_has_srv(false);
313 if (!found->second->has_ptr_or_srv()) {
314 services_.erase(found);
315 if (!callback_.is_null())
316 callback_.Run(UPDATE_REMOVED, service);
321 void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
322 unsigned rrtype) {
323 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
324 // on any name.
327 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
328 if (timeout_seconds <= kMaxRequeryTimeSeconds) {
329 base::MessageLoop::current()->PostDelayedTask(
330 FROM_HERE,
331 base::Bind(&ServiceWatcherImpl::SendQuery,
332 AsWeakPtr(),
333 timeout_seconds * 2 /*next_timeout_seconds*/,
334 false /*force_update*/),
335 base::TimeDelta::FromSeconds(timeout_seconds));
339 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds,
340 bool force_update) {
341 CreateTransaction(true /*network*/, false /*cache*/, force_update,
342 &transaction_network_);
343 ScheduleQuery(next_timeout_seconds);
346 ServiceResolverImpl::ServiceResolverImpl(
347 const std::string& service_name,
348 const ResolveCompleteCallback& callback,
349 net::MDnsClient* mdns_client)
350 : service_name_(service_name), callback_(callback),
351 metadata_resolved_(false), address_resolved_(false),
352 mdns_client_(mdns_client) {
355 void ServiceResolverImpl::StartResolving() {
356 address_resolved_ = false;
357 metadata_resolved_ = false;
358 service_staging_ = ServiceDescription();
359 service_staging_.service_name = service_name_;
361 if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
362 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
366 ServiceResolverImpl::~ServiceResolverImpl() {
369 bool ServiceResolverImpl::CreateTxtTransaction() {
370 txt_transaction_ = mdns_client_->CreateTransaction(
371 net::dns_protocol::kTypeTXT, service_name_,
372 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
373 net::MDnsTransaction::QUERY_NETWORK,
374 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
375 AsWeakPtr()));
376 return txt_transaction_->Start();
379 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in
380 void ServiceResolverImpl::CreateATransaction() {
381 a_transaction_ = mdns_client_->CreateTransaction(
382 net::dns_protocol::kTypeA,
383 service_staging_.address.host(),
384 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
385 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
386 AsWeakPtr()));
387 a_transaction_->Start();
390 bool ServiceResolverImpl::CreateSrvTransaction() {
391 srv_transaction_ = mdns_client_->CreateTransaction(
392 net::dns_protocol::kTypeSRV, service_name_,
393 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
394 net::MDnsTransaction::QUERY_NETWORK,
395 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
396 AsWeakPtr()));
397 return srv_transaction_->Start();
400 std::string ServiceResolverImpl::GetName() const {
401 return service_name_;
404 void ServiceResolverImpl::SrvRecordTransactionResponse(
405 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
406 srv_transaction_.reset();
407 if (status == net::MDnsTransaction::RESULT_RECORD) {
408 DCHECK(record);
409 service_staging_.address = RecordToAddress(record);
410 service_staging_.last_seen = record->time_created();
411 CreateATransaction();
412 } else {
413 ServiceNotFound(MDnsStatusToRequestStatus(status));
417 void ServiceResolverImpl::TxtRecordTransactionResponse(
418 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
419 txt_transaction_.reset();
420 if (status == net::MDnsTransaction::RESULT_RECORD) {
421 DCHECK(record);
422 service_staging_.metadata = RecordToMetadata(record);
423 } else {
424 service_staging_.metadata = std::vector<std::string>();
427 metadata_resolved_ = true;
428 AlertCallbackIfReady();
431 void ServiceResolverImpl::ARecordTransactionResponse(
432 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
433 a_transaction_.reset();
435 if (status == net::MDnsTransaction::RESULT_RECORD) {
436 DCHECK(record);
437 service_staging_.ip_address = RecordToIPAddress(record);
438 } else {
439 service_staging_.ip_address = net::IPAddressNumber();
442 address_resolved_ = true;
443 AlertCallbackIfReady();
446 void ServiceResolverImpl::AlertCallbackIfReady() {
447 if (metadata_resolved_ && address_resolved_) {
448 txt_transaction_.reset();
449 srv_transaction_.reset();
450 a_transaction_.reset();
451 if (!callback_.is_null())
452 callback_.Run(STATUS_SUCCESS, service_staging_);
456 void ServiceResolverImpl::ServiceNotFound(
457 ServiceResolver::RequestStatus status) {
458 txt_transaction_.reset();
459 srv_transaction_.reset();
460 a_transaction_.reset();
461 if (!callback_.is_null())
462 callback_.Run(status, ServiceDescription());
465 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
466 net::MDnsTransaction::Result status) const {
467 switch (status) {
468 case net::MDnsTransaction::RESULT_RECORD:
469 return ServiceResolver::STATUS_SUCCESS;
470 case net::MDnsTransaction::RESULT_NO_RESULTS:
471 return ServiceResolver::STATUS_REQUEST_TIMEOUT;
472 case net::MDnsTransaction::RESULT_NSEC:
473 return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
474 case net::MDnsTransaction::RESULT_DONE: // Pass through.
475 default:
476 NOTREACHED();
477 return ServiceResolver::STATUS_REQUEST_TIMEOUT;
481 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
482 const net::RecordParsed* record) const {
483 DCHECK(record->type() == net::dns_protocol::kTypeTXT);
484 const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
485 DCHECK(txt_rdata);
486 return txt_rdata->texts();
489 net::HostPortPair ServiceResolverImpl::RecordToAddress(
490 const net::RecordParsed* record) const {
491 DCHECK(record->type() == net::dns_protocol::kTypeSRV);
492 const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
493 DCHECK(srv_rdata);
494 return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
497 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
498 const net::RecordParsed* record) const {
499 DCHECK(record->type() == net::dns_protocol::kTypeA);
500 const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
501 DCHECK(a_rdata);
502 return a_rdata->address();
505 LocalDomainResolverImpl::LocalDomainResolverImpl(
506 const std::string& domain,
507 net::AddressFamily address_family,
508 const IPAddressCallback& callback,
509 net::MDnsClient* mdns_client)
510 : domain_(domain), address_family_(address_family), callback_(callback),
511 transactions_finished_(0), mdns_client_(mdns_client) {
514 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
515 timeout_callback_.Cancel();
518 void LocalDomainResolverImpl::Start() {
519 if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
520 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
521 transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
522 transaction_a_->Start();
525 if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
526 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
527 transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
528 transaction_aaaa_->Start();
532 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
533 uint16 type) {
534 return mdns_client_->CreateTransaction(
535 type, domain_, net::MDnsTransaction::SINGLE_RESULT |
536 net::MDnsTransaction::QUERY_CACHE |
537 net::MDnsTransaction::QUERY_NETWORK,
538 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
539 base::Unretained(this)));
542 void LocalDomainResolverImpl::OnTransactionComplete(
543 net::MDnsTransaction::Result result, const net::RecordParsed* record) {
544 transactions_finished_++;
546 if (result == net::MDnsTransaction::RESULT_RECORD) {
547 if (record->type() == net::dns_protocol::kTypeA) {
548 const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
549 address_ipv4_ = rdata->address();
550 } else {
551 DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
552 const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
553 address_ipv6_ = rdata->address();
557 if (transactions_finished_ == 1 &&
558 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
559 timeout_callback_.Reset(base::Bind(
560 &LocalDomainResolverImpl::SendResolvedAddresses,
561 base::Unretained(this)));
563 base::MessageLoop::current()->PostDelayedTask(
564 FROM_HERE,
565 timeout_callback_.callback(),
566 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
567 } else if (transactions_finished_ == 2
568 || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) {
569 SendResolvedAddresses();
573 bool LocalDomainResolverImpl::IsSuccess() {
574 return !address_ipv4_.empty() || !address_ipv6_.empty();
577 void LocalDomainResolverImpl::SendResolvedAddresses() {
578 transaction_a_.reset();
579 transaction_aaaa_.reset();
580 timeout_callback_.Cancel();
581 callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_);
584 } // namespace local_discovery