Disable view source for Developer Tools.
[chromium-blink-merge.git] / chrome / utility / local_discovery / service_discovery_client_impl.cc
blobebb541d932faf6f63cadfd5784c266d0aff35071
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include <utility>
7 #include "base/logging.h"
8 #include "base/memory/singleton.h"
9 #include "base/message_loop/message_loop_proxy.h"
10 #include "base/stl_util.h"
11 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
12 #include "net/dns/dns_protocol.h"
13 #include "net/dns/record_rdata.h"
15 namespace local_discovery {
namespace {

// Requery schedule used by ServiceWatcherImpl. After a network query, the
// next requery is sent kInitialRequeryTimeSeconds later. Each following delay
// doubles, and scheduling stops once the delay would exceed
// kMaxRequeryTimeSeconds (the delay before the last requery).
const int kInitialRequeryTimeSeconds = 1;
const int kMaxRequeryTimeSeconds = 2;

// When LocalDomainResolverImpl resolves both address families, this is how
// long it waits for the second answer after the first one has completed.
// TODO(noamsml): Make this configurable through the LocalDomainResolver
// interface.
const int kLocalDomainSecondAddressTimeoutMs = 100;

}  // namespace
26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
27 net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
33 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
34 const std::string& service_type,
35 const ServiceWatcher::UpdatedCallback& callback) {
36 return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
37 service_type, callback, mdns_client_));
40 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
41 const std::string& service_name,
42 const ServiceResolver::ResolveCompleteCallback& callback) {
43 return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
44 service_name, callback, mdns_client_));
47 scoped_ptr<LocalDomainResolver>
48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
49 const std::string& domain,
50 net::AddressFamily address_family,
51 const LocalDomainResolver::IPAddressCallback& callback) {
52 return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
53 domain, address_family, callback, mdns_client_));
56 ServiceWatcherImpl::ServiceWatcherImpl(
57 const std::string& service_type,
58 const ServiceWatcher::UpdatedCallback& callback,
59 net::MDnsClient* mdns_client)
60 : service_type_(service_type), callback_(callback), started_(false),
61 mdns_client_(mdns_client) {
64 void ServiceWatcherImpl::Start() {
65 DCHECK(!started_);
66 listener_ = mdns_client_->CreateListener(
67 net::dns_protocol::kTypePTR, service_type_, this);
68 started_ = listener_->Start();
69 if (started_)
70 ReadCachedServices();
73 ServiceWatcherImpl::~ServiceWatcherImpl() {
76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
77 DCHECK(started_);
78 if (force_update)
79 services_.clear();
80 SendQuery(kInitialRequeryTimeSeconds, force_update);
83 void ServiceWatcherImpl::ReadCachedServices() {
84 DCHECK(started_);
85 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
86 &transaction_cache_);
89 bool ServiceWatcherImpl::CreateTransaction(
90 bool network, bool cache, bool force_refresh,
91 scoped_ptr<net::MDnsTransaction>* transaction) {
92 int transaction_flags = 0;
93 if (network)
94 transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
96 if (cache)
97 transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
99 // TODO(noamsml): Add flag for force_refresh when supported.
101 if (transaction_flags) {
102 *transaction = mdns_client_->CreateTransaction(
103 net::dns_protocol::kTypePTR, service_type_, transaction_flags,
104 base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
105 base::Unretained(this), transaction));
106 return (*transaction)->Start();
109 return true;
112 std::string ServiceWatcherImpl::GetServiceType() const {
113 return listener_->GetName();
116 void ServiceWatcherImpl::OnRecordUpdate(
117 net::MDnsListener::UpdateType update,
118 const net::RecordParsed* record) {
119 DCHECK(started_);
120 if (record->type() == net::dns_protocol::kTypePTR) {
121 DCHECK(record->name() == GetServiceType());
122 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
124 switch (update) {
125 case net::MDnsListener::RECORD_ADDED:
126 AddService(rdata->ptrdomain());
127 break;
128 case net::MDnsListener::RECORD_CHANGED:
129 NOTREACHED();
130 break;
131 case net::MDnsListener::RECORD_REMOVED:
132 RemoveService(rdata->ptrdomain());
133 break;
135 } else {
136 DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
137 record->type() == net::dns_protocol::kTypeTXT);
138 DCHECK(services_.find(record->name()) != services_.end());
140 DeferUpdate(UPDATE_CHANGED, record->name());
144 void ServiceWatcherImpl::OnCachePurged() {
145 // Not yet implemented.
148 void ServiceWatcherImpl::OnTransactionResponse(
149 scoped_ptr<net::MDnsTransaction>* transaction,
150 net::MDnsTransaction::Result result,
151 const net::RecordParsed* record) {
152 DCHECK(started_);
153 if (result == net::MDnsTransaction::RESULT_RECORD) {
154 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
155 DCHECK(rdata);
156 AddService(rdata->ptrdomain());
157 } else if (result == net::MDnsTransaction::RESULT_DONE) {
158 transaction->reset();
161 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
162 // record for PTR records on any name.
165 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
166 const std::string& service_name,
167 ServiceWatcherImpl* watcher,
168 net::MDnsClient* mdns_client) : update_pending_(false) {
169 srv_listener_ = mdns_client->CreateListener(
170 net::dns_protocol::kTypeSRV, service_name, watcher);
171 txt_listener_ = mdns_client->CreateListener(
172 net::dns_protocol::kTypeTXT, service_name, watcher);
175 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
178 bool ServiceWatcherImpl::ServiceListeners::Start() {
179 if (!srv_listener_->Start())
180 return false;
181 return txt_listener_->Start();
184 void ServiceWatcherImpl::AddService(const std::string& service) {
185 DCHECK(started_);
186 std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
187 make_pair(service, linked_ptr<ServiceListeners>(NULL)));
188 if (found.second) { // Newly inserted.
189 found.first->second = linked_ptr<ServiceListeners>(
190 new ServiceListeners(service, this, mdns_client_));
191 bool success = found.first->second->Start();
193 DeferUpdate(UPDATE_ADDED, service);
195 DCHECK(success);
199 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
200 const std::string& service_name) {
201 ServiceListenersMap::iterator found = services_.find(service_name);
203 if (found != services_.end() && !found->second->update_pending()) {
204 found->second->set_update_pending(true);
205 base::MessageLoop::current()->PostTask(
206 FROM_HERE,
207 base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
208 update_type, service_name));
212 void ServiceWatcherImpl::DeliverDeferredUpdate(
213 ServiceWatcher::UpdateType update_type, const std::string& service_name) {
214 ServiceListenersMap::iterator found = services_.find(service_name);
216 if (found != services_.end()) {
217 found->second->set_update_pending(false);
218 if (!callback_.is_null())
219 callback_.Run(update_type, service_name);
223 void ServiceWatcherImpl::RemoveService(const std::string& service) {
224 DCHECK(started_);
225 ServiceListenersMap::iterator found = services_.find(service);
226 if (found != services_.end()) {
227 services_.erase(found);
228 if (!callback_.is_null())
229 callback_.Run(UPDATE_REMOVED, service);
233 void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
234 unsigned rrtype) {
235 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
236 // on any name.
239 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
240 if (timeout_seconds <= kMaxRequeryTimeSeconds) {
241 base::MessageLoop::current()->PostDelayedTask(
242 FROM_HERE,
243 base::Bind(&ServiceWatcherImpl::SendQuery,
244 AsWeakPtr(),
245 timeout_seconds * 2 /*next_timeout_seconds*/,
246 false /*force_update*/),
247 base::TimeDelta::FromSeconds(timeout_seconds));
251 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds,
252 bool force_update) {
253 CreateTransaction(true /*network*/, false /*cache*/, force_update,
254 &transaction_network_);
255 ScheduleQuery(next_timeout_seconds);
258 ServiceResolverImpl::ServiceResolverImpl(
259 const std::string& service_name,
260 const ResolveCompleteCallback& callback,
261 net::MDnsClient* mdns_client)
262 : service_name_(service_name), callback_(callback),
263 metadata_resolved_(false), address_resolved_(false),
264 mdns_client_(mdns_client) {
267 void ServiceResolverImpl::StartResolving() {
268 address_resolved_ = false;
269 metadata_resolved_ = false;
270 service_staging_ = ServiceDescription();
271 service_staging_.service_name = service_name_;
273 if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
274 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
278 ServiceResolverImpl::~ServiceResolverImpl() {
281 bool ServiceResolverImpl::CreateTxtTransaction() {
282 txt_transaction_ = mdns_client_->CreateTransaction(
283 net::dns_protocol::kTypeTXT, service_name_,
284 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
285 net::MDnsTransaction::QUERY_NETWORK,
286 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
287 AsWeakPtr()));
288 return txt_transaction_->Start();
291 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in
292 void ServiceResolverImpl::CreateATransaction() {
293 a_transaction_ = mdns_client_->CreateTransaction(
294 net::dns_protocol::kTypeA,
295 service_staging_.address.host(),
296 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
297 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
298 AsWeakPtr()));
299 a_transaction_->Start();
302 bool ServiceResolverImpl::CreateSrvTransaction() {
303 srv_transaction_ = mdns_client_->CreateTransaction(
304 net::dns_protocol::kTypeSRV, service_name_,
305 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
306 net::MDnsTransaction::QUERY_NETWORK,
307 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
308 AsWeakPtr()));
309 return srv_transaction_->Start();
312 std::string ServiceResolverImpl::GetName() const {
313 return service_name_;
316 void ServiceResolverImpl::SrvRecordTransactionResponse(
317 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
318 srv_transaction_.reset();
319 if (status == net::MDnsTransaction::RESULT_RECORD) {
320 DCHECK(record);
321 service_staging_.address = RecordToAddress(record);
322 service_staging_.last_seen = record->time_created();
323 CreateATransaction();
324 } else {
325 ServiceNotFound(MDnsStatusToRequestStatus(status));
329 void ServiceResolverImpl::TxtRecordTransactionResponse(
330 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
331 txt_transaction_.reset();
332 if (status == net::MDnsTransaction::RESULT_RECORD) {
333 DCHECK(record);
334 service_staging_.metadata = RecordToMetadata(record);
335 } else {
336 service_staging_.metadata = std::vector<std::string>();
339 metadata_resolved_ = true;
340 AlertCallbackIfReady();
343 void ServiceResolverImpl::ARecordTransactionResponse(
344 net::MDnsTransaction::Result status, const net::RecordParsed* record) {
345 a_transaction_.reset();
347 if (status == net::MDnsTransaction::RESULT_RECORD) {
348 DCHECK(record);
349 service_staging_.ip_address = RecordToIPAddress(record);
350 } else {
351 service_staging_.ip_address = net::IPAddressNumber();
354 address_resolved_ = true;
355 AlertCallbackIfReady();
358 void ServiceResolverImpl::AlertCallbackIfReady() {
359 if (metadata_resolved_ && address_resolved_) {
360 txt_transaction_.reset();
361 srv_transaction_.reset();
362 a_transaction_.reset();
363 if (!callback_.is_null())
364 callback_.Run(STATUS_SUCCESS, service_staging_);
368 void ServiceResolverImpl::ServiceNotFound(
369 ServiceResolver::RequestStatus status) {
370 txt_transaction_.reset();
371 srv_transaction_.reset();
372 a_transaction_.reset();
373 if (!callback_.is_null())
374 callback_.Run(status, ServiceDescription());
377 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
378 net::MDnsTransaction::Result status) const {
379 switch (status) {
380 case net::MDnsTransaction::RESULT_RECORD:
381 return ServiceResolver::STATUS_SUCCESS;
382 case net::MDnsTransaction::RESULT_NO_RESULTS:
383 return ServiceResolver::STATUS_REQUEST_TIMEOUT;
384 case net::MDnsTransaction::RESULT_NSEC:
385 return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
386 case net::MDnsTransaction::RESULT_DONE: // Pass through.
387 default:
388 NOTREACHED();
389 return ServiceResolver::STATUS_REQUEST_TIMEOUT;
393 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
394 const net::RecordParsed* record) const {
395 DCHECK(record->type() == net::dns_protocol::kTypeTXT);
396 const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
397 DCHECK(txt_rdata);
398 return txt_rdata->texts();
401 net::HostPortPair ServiceResolverImpl::RecordToAddress(
402 const net::RecordParsed* record) const {
403 DCHECK(record->type() == net::dns_protocol::kTypeSRV);
404 const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
405 DCHECK(srv_rdata);
406 return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
409 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
410 const net::RecordParsed* record) const {
411 DCHECK(record->type() == net::dns_protocol::kTypeA);
412 const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
413 DCHECK(a_rdata);
414 return a_rdata->address();
417 LocalDomainResolverImpl::LocalDomainResolverImpl(
418 const std::string& domain,
419 net::AddressFamily address_family,
420 const IPAddressCallback& callback,
421 net::MDnsClient* mdns_client)
422 : domain_(domain), address_family_(address_family), callback_(callback),
423 transactions_finished_(0), mdns_client_(mdns_client) {
426 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
427 timeout_callback_.Cancel();
430 void LocalDomainResolverImpl::Start() {
431 if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
432 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
433 transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
434 transaction_a_->Start();
437 if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
438 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
439 transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
440 transaction_aaaa_->Start();
444 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
445 uint16 type) {
446 return mdns_client_->CreateTransaction(
447 type, domain_, net::MDnsTransaction::SINGLE_RESULT |
448 net::MDnsTransaction::QUERY_CACHE |
449 net::MDnsTransaction::QUERY_NETWORK,
450 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
451 base::Unretained(this)));
454 void LocalDomainResolverImpl::OnTransactionComplete(
455 net::MDnsTransaction::Result result, const net::RecordParsed* record) {
456 transactions_finished_++;
458 if (result == net::MDnsTransaction::RESULT_RECORD) {
459 if (record->type() == net::dns_protocol::kTypeA) {
460 const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
461 address_ipv4_ = rdata->address();
462 } else {
463 DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
464 const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
465 address_ipv6_ = rdata->address();
469 if (transactions_finished_ == 1 &&
470 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
471 timeout_callback_.Reset(base::Bind(
472 &LocalDomainResolverImpl::SendResolvedAddresses,
473 base::Unretained(this)));
475 base::MessageLoop::current()->PostDelayedTask(
476 FROM_HERE,
477 timeout_callback_.callback(),
478 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
479 } else if (transactions_finished_ == 2
480 || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) {
481 SendResolvedAddresses();
485 bool LocalDomainResolverImpl::IsSuccess() {
486 return !address_ipv4_.empty() || !address_ipv6_.empty();
489 void LocalDomainResolverImpl::SendResolvedAddresses() {
490 transaction_a_.reset();
491 transaction_aaaa_.reset();
492 timeout_callback_.Cancel();
493 callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_);
496 } // namespace local_discovery