1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/mdns_client_impl.h"
10 #include "base/message_loop/message_loop_proxy.h"
11 #include "base/stl_util.h"
12 #include "base/time/default_clock.h"
13 #include "base/time/time.h"
14 #include "net/base/dns_util.h"
15 #include "net/base/net_errors.h"
16 #include "net/base/net_log.h"
17 #include "net/base/rand_callback.h"
18 #include "net/dns/dns_protocol.h"
19 #include "net/dns/record_rdata.h"
20 #include "net/udp/datagram_socket.h"
22 // TODO(gene): Remove this temporary method of disabling NSEC support once it
23 // becomes clear whether this feature should be
24 // supported. http://crbug.com/255232
// Delay, in seconds, passed to base::TimeDelta::FromSeconds() for the delayed
// task that MDnsTransactionImpl::QueryAndListen() posts after binding its
// timeout_ callback to SignalTransactionOver().
31 const unsigned MDnsTransactionTimeoutSeconds
= 3;
32 // The fractions of the record's original TTL after which an active listener
33 // (one that had |SetActiveRefresh(true)| called) will send a query to refresh
34 // its cache. This happens both at 85% of the original TTL and again at 95% of
// the original TTL.
// Both ratios are multiplied by ttl_ in MDnsListenerImpl::ScheduleNextRefresh().
36 const double kListenerRefreshRatio1
= 0.85;
37 const double kListenerRefreshRatio2
= 0.95;
// Creates one mDNS socket for each entry returned by GetMDnsInterfacesToBind()
// and appends the released raw pointer to |sockets|.
// Each entry's .second is DCHECKed to be ADDRESS_FAMILY_IPV4 or
// ADDRESS_FAMILY_IPV6 and is passed, together with .first, to
// CreateAndBindMDnsSocket().
// NOTE(review): the embedded numbering skips a line right before the
// push_back, so a guard on |socket| may be absent from this extract — confirm
// against the full file.
41 void MDnsSocketFactoryImpl::CreateSockets(
42 ScopedVector
<DatagramServerSocket
>* sockets
) {
43 InterfaceIndexFamilyList
interfaces(GetMDnsInterfacesToBind());
44 for (size_t i
= 0; i
< interfaces
.size(); ++i
) {
45 DCHECK(interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV4
||
46 interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV6
);
47 scoped_ptr
<DatagramServerSocket
> socket(
48 CreateAndBindMDnsSocket(interfaces
[i
].second
, interfaces
[i
].first
));
50 sockets
->push_back(socket
.release());
// Constructs a handler for a single socket.
// |socket| is moved into socket_ via Pass(); |connection| is stored as a raw
// pointer in connection_. response_ is constructed with
// dns_protocol::kMaxMulticastSize, and the handler starts with no send in
// progress.
54 MDnsConnection::SocketHandler::SocketHandler(
55 scoped_ptr
<DatagramServerSocket
> socket
,
56 MDnsConnection
* connection
)
57 : socket_(socket
.Pass()),
58 connection_(connection
),
59 response_(dns_protocol::kMaxMulticastSize
),
60 send_in_progress_(false) {
// Destructor. No statements of its body are visible in this extract.
63 MDnsConnection::SocketHandler::~SocketHandler() {
// Queries the socket's local address, DCHECKs that its family is IPv4 or IPv6,
// and stores the matching endpoint from GetMDnsIPEndPoint() in multicast_addr_.
// NOTE(review): the declaration of |end_point|, any check of |rv| and the
// return statement are absent from this extract — confirm against the full
// file.
66 int MDnsConnection::SocketHandler::Start() {
68 int rv
= socket_
->GetLocalAddress(&end_point
);
71 DCHECK(end_point
.GetFamily() == ADDRESS_FAMILY_IPV4
||
72 end_point
.GetFamily() == ADDRESS_FAMILY_IPV6
);
73 multicast_addr_
= GetMDnsIPEndPoint(end_point
.GetFamily());
// Receive step for this socket.
// Passes response_, recv_addr_ and |rv| to the connection's
// OnDatagramReceived(), then issues RecvFrom() into response_'s IO buffer with
// SocketHandler::OnDatagramReceived as the completion callback.
// The callback is bound with base::Unretained(this), i.e. it holds a raw
// pointer to this handler.
// NOTE(review): the looping construct, the conditions guarding these calls and
// the return statements are absent from this extract — confirm against the
// full file.
77 int MDnsConnection::SocketHandler::DoLoop(int rv
) {
80 connection_
->OnDatagramReceived(&response_
, recv_addr_
, rv
);
82 rv
= socket_
->RecvFrom(
83 response_
.io_buffer(),
84 response_
.io_buffer()->size(),
86 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived
,
87 base::Unretained(this)));
90 if (rv
!= ERR_IO_PENDING
)
// Completion callback that DoLoop() binds for RecvFrom().
// The only visible statement reports |rv| to the connection through
// PostOnError(this, rv).
// NOTE(review): the lines preceding that call are absent from this extract, so
// the condition under which the error is reported cannot be seen here.
96 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv
) {
101 connection_
->PostOnError(this, rv
);
// Sends |buffer| on this handler's socket, or queues it if a send is pending.
// While send_in_progress_ is set, the (buffer, size) pair is pushed onto
// send_queue_. Otherwise SendTo() is called with SendDone as the completion
// callback (bound with base::Unretained(this)): ERR_IO_PENDING marks a send as
// in progress, and any other result below OK is reported via PostOnError().
// NOTE(review): the rest of the parameter list (which declares |size|), the
// statement after the queue push and the remaining SendTo() arguments are
// absent from this extract.
104 void MDnsConnection::SocketHandler::Send(const scoped_refptr
<IOBuffer
>& buffer
,
106 if (send_in_progress_
) {
107 send_queue_
.push(std::make_pair(buffer
, size
));
110 int rv
= socket_
->SendTo(buffer
.get(),
113 base::Bind(&MDnsConnection::SocketHandler::SendDone
,
114 base::Unretained(this)));
115 if (rv
== ERR_IO_PENDING
) {
116 send_in_progress_
= true;
117 } else if (rv
< OK
) {
118 connection_
->PostOnError(this, rv
);
// Completion callback that Send() binds for SendTo().
// Clears send_in_progress_ (DCHECKing it was set), then keeps calling Send()
// with the front entry of send_queue_ while no send is pending and the queue
// is not empty.
// NOTE(review): the condition guarding PostOnError() and the statement that
// removes the front entry from send_queue_ are absent from this extract —
// confirm against the full file.
122 void MDnsConnection::SocketHandler::SendDone(int rv
) {
123 DCHECK(send_in_progress_
);
124 send_in_progress_
= false;
126 connection_
->PostOnError(this, rv
);
127 while (!send_in_progress_
&& !send_queue_
.empty()) {
128 std::pair
<scoped_refptr
<IOBuffer
>, unsigned> buffer
= send_queue_
.front();
130 Send(buffer
.first
, buffer
.second
);
// Stores |delegate| as a raw pointer in delegate_ and initializes
// weak_ptr_factory_ with this object; PostOnError() later takes weak pointers
// from that factory.
134 MDnsConnection::MDnsConnection(MDnsConnection::Delegate
* delegate
)
135 : delegate_(delegate
), weak_ptr_factory_(this) {
// Destructor. No statements of its body are visible in this extract.
138 MDnsConnection::~MDnsConnection() {
// Creates the sockets through |socket_factory| and starts a handler for each.
// Every socket is wrapped with make_scoped_ptr() and handed to a new
// SocketHandler stored in socket_handlers_; sockets.weak_clear() is then
// called (presumably so the ScopedVector does not delete pointers the handlers
// now hold — confirm).
// Each handler is then started; the visible code erases a handler and logs
// "Start failed", and the loop header has no increment.
// Returns true when at least one handler remains in socket_handlers_.
// NOTE(review): the condition on |rv| and the index increment are absent from
// this extract.
141 bool MDnsConnection::Init(MDnsSocketFactory
* socket_factory
) {
142 ScopedVector
<DatagramServerSocket
> sockets
;
143 socket_factory
->CreateSockets(&sockets
);
145 for (size_t i
= 0; i
< sockets
.size(); ++i
) {
146 socket_handlers_
.push_back(
147 new MDnsConnection::SocketHandler(make_scoped_ptr(sockets
[i
]), this));
149 sockets
.weak_clear();
151 // All unbound sockets need to be bound before processing untrusted input.
152 // This is done for security reasons, so that an attacker can't get an unbound
// socket (presumably; the end of this sentence is absent from the extract).
154 for (size_t i
= 0; i
< socket_handlers_
.size();) {
155 int rv
= socket_handlers_
[i
]->Start();
157 socket_handlers_
.erase(socket_handlers_
.begin() + i
);
158 VLOG(1) << "Start failed, socket=" << i
<< ", error=" << rv
;
163 VLOG(1) << "Sockets ready:" << socket_handlers_
.size();
164 return !socket_handlers_
.empty();
// Forwards |buffer| and |size| to Send() on every handler in socket_handlers_.
// NOTE(review): the rest of the parameter list (which declares |size|) is
// absent from this extract.
167 void MDnsConnection::Send(const scoped_refptr
<IOBuffer
>& buffer
,
169 for (size_t i
= 0; i
< socket_handlers_
.size(); ++i
)
170 socket_handlers_
[i
]->Send(buffer
, size
);
// Logs a socket error and defers its delivery to OnError().
// The logged id is the position of |loop| within socket_handlers_, found with
// std::find. OnError(rv) is posted to the current message loop bound to a weak
// pointer from weak_ptr_factory_ rather than called directly; the inline
// comment gives the reason (the delegate may delete this object).
// NOTE(review): the first PostTask() argument is absent from this extract.
173 void MDnsConnection::PostOnError(SocketHandler
* loop
, int rv
) {
174 VLOG(1) << "Socket error. id="
175 << std::find(socket_handlers_
.begin(), socket_handlers_
.end(), loop
) -
176 socket_handlers_
.begin() << ", error=" << rv
;
177 // Post to allow deletion of this object by delegate.
178 base::MessageLoop::current()->PostTask(
180 base::Bind(&MDnsConnection::OnError
, weak_ptr_factory_
.GetWeakPtr(), rv
));
// Task body posted by PostOnError(): forwards |rv| to the delegate's
// OnConnectionError().
183 void MDnsConnection::OnError(int rv
) {
184 // TODO(noamsml): Specific handling of intermittent errors that can be handled
185 // in the connection.
186 delegate_
->OnConnectionError(rv
);
// Called by a SocketHandler with a received datagram.
// DCHECKs that bytes_read is positive and hands |response| and bytes_read to
// the delegate's HandlePacket(). |recv_addr| is not used in the visible code.
// NOTE(review): the parameter that declares bytes_read is absent from this
// extract.
189 void MDnsConnection::OnDatagramReceived(
190 DnsResponse
* response
,
191 const IPEndPoint
& recv_addr
,
193 // TODO(noamsml): More sophisticated error handling.
194 DCHECK_GT(bytes_read
, 0);
195 delegate_
->HandlePacket(response
, bytes_read
);
// Creates the MDnsConnection held in connection_, passing this Core as the
// connection's delegate.
198 MDnsClientImpl::Core::Core() : connection_(new MDnsConnection(this)) {
// Deletes the values stored in listeners_ via STLDeleteValues(); these are the
// observer lists that AddListener() allocates with new.
201 MDnsClientImpl::Core::~Core() {
202 STLDeleteValues(&listeners_
);
// Initializes the underlying connection with |socket_factory| and returns the
// connection's Init() result.
205 bool MDnsClientImpl::Core::Init(MDnsSocketFactory
* socket_factory
) {
206 return connection_
->Init(socket_factory
);
// Builds a DNS query for |name| / |rrtype| and sends it over the connection.
// |name| is converted with DNSDomainFromDot(); the query is constructed with
// id 0 and its flags are cleared before its IO buffer is passed to
// connection_->Send().
// NOTE(review): the statement taken when the name conversion fails and the
// final return are absent from this extract.
209 bool MDnsClientImpl::Core::SendQuery(uint16 rrtype
, std::string name
) {
210 std::string name_dns
;
211 if (!DNSDomainFromDot(name
, &name_dns
))
214 DnsQuery
query(0, name_dns
, rrtype
);
215 query
.set_flags(0); // Remove the RD flag from the query. It is unneeded.
217 connection_
->Send(query
.io_buffer(), query
.io_buffer()->size());
// Parses a received mDNS packet and folds its records into cache_.
// Visible flow:
//  - return if the packet cannot be parsed, or if it lacks kFlagResponse
//    (i.e. it is a query);
//  - parse answer_count() + additional_answer_count() records; give up on the
//    whole packet if the parser offset did not advance, otherwise continue
//    with the next record;
//  - skip records whose class (masked with kMDnsClassMask) is not kClassIN;
//  - update the cache, reschedule cleanup, and remember each record's cache
//    key together with the resulting update type;
//  - for each remembered key, look the record up again; NSEC records are
//    passed to NotifyNsecRecord() when ENABLE_NSEC is defined, and
//    AlertListeners() is called with the update type and a (name, type) key.
// NOTE(review): part of the signature (the bytes_read parameter), the
// declaration of |offset|, the null checks on |record| and the preprocessor /
// else structure around the NSEC branch are absent from this extract —
// confirm against the full file.
221 void MDnsClientImpl::Core::HandlePacket(DnsResponse
* response
,
224 // Note: We store cache keys rather than record pointers to avoid
225 // erroneous behavior in case a packet contains multiple exclusive
226 // records with the same type and name.
227 std::map
<MDnsCache::Key
, MDnsCache::UpdateType
> update_keys
;
229 if (!response
->InitParseWithoutQuery(bytes_read
)) {
230 DVLOG(1) << "Could not understand an mDNS packet.";
231 return; // Message is unreadable.
234 // TODO(noamsml): duplicate query suppression.
235 if (!(response
->flags() & dns_protocol::kFlagResponse
))
236 return; // Message is a query. ignore it.
238 DnsRecordParser parser
= response
->Parser();
239 unsigned answer_count
= response
->answer_count() +
240 response
->additional_answer_count();
242 for (unsigned i
= 0; i
< answer_count
; i
++) {
243 offset
= parser
.GetOffset();
244 scoped_ptr
<const RecordParsed
> record
= RecordParsed::CreateFrom(
245 &parser
, base::Time::Now());
248 DVLOG(1) << "Could not understand an mDNS record.";
250 if (offset
== parser
.GetOffset()) {
251 DVLOG(1) << "Abandoned parsing the rest of the packet.";
252 return; // The parser did not advance, abort reading the packet.
254 continue; // We may be able to extract other records from the packet.
258 if ((record
->klass() & dns_protocol::kMDnsClassMask
) !=
259 dns_protocol::kClassIN
) {
260 DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
261 continue; // Ignore all records not in the IN class.
264 MDnsCache::Key update_key
= MDnsCache::Key::CreateFor(record
.get());
265 MDnsCache::UpdateType update
= cache_
.UpdateDnsRecord(record
.Pass());
267 // Cleanup time may have changed.
268 ScheduleCleanup(cache_
.next_expiration());
270 update_keys
.insert(std::make_pair(update_key
, update
));
273 for (std::map
<MDnsCache::Key
, MDnsCache::UpdateType
>::iterator i
=
274 update_keys
.begin(); i
!= update_keys
.end(); i
++) {
275 const RecordParsed
* record
= cache_
.LookupKey(i
->first
);
279 if (record
->type() == dns_protocol::kTypeNSEC
) {
280 #if defined(ENABLE_NSEC)
281 NotifyNsecRecord(record
);
284 AlertListeners(i
->second
, ListenerKey(record
->name(), record
->type()),
// Applies an NSEC record: evicts cached records of types the NSEC bitmap does
// not list, then notifies listeners waiting on those types.
// |record| must be of type kTypeNSEC (DCHECKed).
// Step 1: collect all cached records for record->name() (type argument 0) and,
// for each one whose type bit is not set in the NSEC rdata, remove it from the
// cache and report it through OnRecordRemoved().
// Step 2: walk listeners_ from upper_bound((name, 0)) while the key's name
// matches, and call AlertNsecRecord() on the observers of every key whose type
// bit is not set.
// NOTE(review): the last FindDnsRecords() argument and the statement that
// follows the NSEC-type check inside the first loop are absent from this
// extract.
290 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed
* record
) {
291 DCHECK_EQ(dns_protocol::kTypeNSEC
, record
->type());
292 const NsecRecordRdata
* rdata
= record
->rdata
<NsecRecordRdata
>();
295 // Remove all cached records matching the nonexistent RR types.
296 std::vector
<const RecordParsed
*> records_to_remove
;
298 cache_
.FindDnsRecords(0, record
->name(), &records_to_remove
,
301 for (std::vector
<const RecordParsed
*>::iterator i
= records_to_remove
.begin();
302 i
!= records_to_remove
.end(); i
++) {
303 if ((*i
)->type() == dns_protocol::kTypeNSEC
)
305 if (!rdata
->GetBit((*i
)->type())) {
306 scoped_ptr
<const RecordParsed
> record_removed
= cache_
.RemoveRecord((*i
));
307 DCHECK(record_removed
);
308 OnRecordRemoved(record_removed
.get());
312 // Alert all listeners waiting for the nonexistent RR types.
313 ListenerMap::iterator i
=
314 listeners_
.upper_bound(ListenerKey(record
->name(), 0));
315 for (; i
!= listeners_
.end() && i
->first
.first
== record
->name(); i
++) {
316 if (!rdata
->GetBit(i
->first
.second
)) {
317 FOR_EACH_OBSERVER(MDnsListenerImpl
, *i
->second
, AlertNsecRecord());
// Delegate hook that MDnsConnection::OnError() invokes with the error code.
// Only a TODO is visible in the body; no handling of |error| appears in this
// extract.
322 void MDnsClientImpl::Core::OnConnectionError(int error
) {
323 // TODO(noamsml): On connection error, recreate connection and flush cache.
// Delivers a cache update to every listener registered under |key|.
// Returns without doing anything when listeners_ has no entry for |key|;
// otherwise calls HandleRecordUpdate(update_type, record) on each observer in
// that entry's list.
326 void MDnsClientImpl::Core::AlertListeners(
327 MDnsCache::UpdateType update_type
,
328 const ListenerKey
& key
,
329 const RecordParsed
* record
) {
330 ListenerMap::iterator listener_map_iterator
= listeners_
.find(key
);
331 if (listener_map_iterator
== listeners_
.end()) return;
333 FOR_EACH_OBSERVER(MDnsListenerImpl
, *listener_map_iterator
->second
,
334 HandleRecordUpdate(update_type
, record
));
// Registers |listener| under the key (listener->GetName(), listener->GetType()).
// A (key, NULL) pair is used for the map insertion; if the key was not already
// present, a new ObserverList is allocated and stored in that slot. The
// listener is then added to the list. The list is a raw pointer owned by the
// map (deleted in CleanupObserverList() and in ~Core()).
// NOTE(review): the line with the call that consumes the make_pair() result
// is absent from this extract (presumably listeners_.insert — confirm).
337 void MDnsClientImpl::Core::AddListener(
338 MDnsListenerImpl
* listener
) {
339 ListenerKey
key(listener
->GetName(), listener
->GetType());
340 std::pair
<ListenerMap::iterator
, bool> observer_insert_result
=
342 make_pair(key
, static_cast<ObserverList
<MDnsListenerImpl
>*>(NULL
)));
344 // If an equivalent key does not exist, actually create the observer list.
345 if (observer_insert_result
.second
)
346 observer_insert_result
.first
->second
= new ObserverList
<MDnsListenerImpl
>();
348 ObserverList
<MDnsListenerImpl
>* observer_list
=
349 observer_insert_result
.first
->second
;
351 observer_list
->AddObserver(listener
);
// Unregisters |listener| from the observer list stored under its
// (name, type) key.
// DCHECKs that the key exists and that the list contains the listener. If the
// list may now be empty, its removal from the map is not done here: a task
// running CleanupObserverList(key) is posted, bound to AsWeakPtr().
354 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl
* listener
) {
355 ListenerKey
key(listener
->GetName(), listener
->GetType());
356 ListenerMap::iterator observer_list_iterator
= listeners_
.find(key
);
358 DCHECK(observer_list_iterator
!= listeners_
.end());
359 DCHECK(observer_list_iterator
->second
->HasObserver(listener
));
361 observer_list_iterator
->second
->RemoveObserver(listener
);
363 // Remove the observer list from the map if it is empty
364 if (!observer_list_iterator
->second
->might_have_observers()) {
365 // Schedule the actual removal for later in case the listener removal
366 // happens while iterating over the observer list.
367 base::MessageLoop::current()->PostTask(
368 FROM_HERE
, base::Bind(
369 &MDnsClientImpl::Core::CleanupObserverList
, AsWeakPtr(), key
));
// Deferred half of RemoveListener(): deletes the observer list stored under
// |key| and erases the map entry.
// Both conditions are re-checked at run time: the entry must still exist and
// its list must still report no observers, so a listener added in the
// meantime keeps the list alive.
373 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey
& key
) {
374 ListenerMap::iterator found
= listeners_
.find(key
);
375 if (found
!= listeners_
.end() && !found
->second
->might_have_observers()) {
376 delete found
->second
;
377 listeners_
.erase(found
);
// (Re)schedules the cache cleanup for the time |cleanup|.
// Does nothing when |cleanup| equals the already scheduled time. Otherwise it
// records the new time, resets cleanup_callback_ to a fresh DoCleanup binding
// (per the inline comment this cancels the previous one) and, unless |cleanup|
// is the null base::Time(), posts that callback with a delay of
// cleanup - Now().
// DoCleanup is bound with base::Unretained(this).
// NOTE(review): the first PostDelayedTask() argument is absent from this
// extract.
381 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup
) {
382 // Cleanup is already scheduled, no need to do anything.
383 if (cleanup
== scheduled_cleanup_
) return;
384 scheduled_cleanup_
= cleanup
;
386 // This cancels the previously scheduled cleanup.
387 cleanup_callback_
.Reset(base::Bind(
388 &MDnsClientImpl::Core::DoCleanup
, base::Unretained(this)));
390 // If |cleanup| is empty, then no cleanup necessary.
391 if (cleanup
!= base::Time()) {
392 base::MessageLoop::current()->PostDelayedTask(
394 cleanup_callback_
.callback(),
395 cleanup
- base::Time::Now());
// Scheduled cleanup task: asks the cache to clean up records as of Now(),
// passing OnRecordRemoved as the callback, then schedules the next cleanup at
// the cache's next expiration time.
399 void MDnsClientImpl::Core::DoCleanup() {
400 cache_
.CleanupRecords(base::Time::Now(), base::Bind(
401 &MDnsClientImpl::Core::OnRecordRemoved
, base::Unretained(this)));
403 ScheduleCleanup(cache_
.next_expiration());
// Notifies the listeners keyed by (record->name(), record->type()) that
// |record| was removed, using the MDnsCache::RecordRemoved update type.
406 void MDnsClientImpl::Core::OnRecordRemoved(
407 const RecordParsed
* record
) {
408 AlertListeners(MDnsCache::RecordRemoved
,
409 ListenerKey(record
->name(), record
->type()), record
);
// Looks up cached records for |rrtype| / |name| as of the current time and
// lets the cache fill |records| with the result.
412 void MDnsClientImpl::Core::QueryCache(
413 uint16 rrtype
, const std::string
& name
,
414 std::vector
<const RecordParsed
*>* records
) const {
415 cache_
.FindDnsRecords(rrtype
, name
, records
, base::Time::Now());
// Default constructor. No initializers or body statements are visible in this
// extract; core_ is set up later by StartListening().
418 MDnsClientImpl::MDnsClientImpl() {
// Destructor. No statements of its body are visible in this extract.
421 MDnsClientImpl::~MDnsClientImpl() {
// Creates the Core and initializes it with |socket_factory|.
// Must not be called while a core already exists (DCHECKed).
// NOTE(review): the handling of a failed Init() and the return statements are
// absent from this extract — confirm against the full file.
424 bool MDnsClientImpl::StartListening(MDnsSocketFactory
* socket_factory
) {
425 DCHECK(!core_
.get());
426 core_
.reset(new Core());
427 if (!core_
->Init(socket_factory
)) {
// Counterpart of StartListening(). The body is absent from this extract
// (presumably it releases core_, since IsListening() tests core_ — confirm).
434 void MDnsClientImpl::StopListening() {
// Returns true while a Core exists, i.e. core_ holds a non-NULL pointer.
438 bool MDnsClientImpl::IsListening() const {
439 return core_
.get() != NULL
;
// Creates an MDnsListenerImpl for |rrtype| / |name| that reports to |delegate|
// and points back at this client; the caller receives ownership through the
// returned scoped_ptr.
// NOTE(review): the parameter that declares |rrtype| is absent from this
// extract.
442 scoped_ptr
<MDnsListener
> MDnsClientImpl::CreateListener(
444 const std::string
& name
,
445 MDnsListener::Delegate
* delegate
) {
446 return scoped_ptr
<net::MDnsListener
>(
447 new MDnsListenerImpl(rrtype
, name
, delegate
, this));
// Creates an MDnsTransactionImpl for |rrtype| / |name| with the given |flags|
// and result |callback|, pointing back at this client; the caller receives
// ownership through the returned scoped_ptr.
// NOTE(review): the parameters that declare |rrtype| and |flags| are absent
// from this extract.
450 scoped_ptr
<MDnsTransaction
> MDnsClientImpl::CreateTransaction(
452 const std::string
& name
,
454 const MDnsTransaction::ResultCallback
& callback
) {
455 return scoped_ptr
<MDnsTransaction
>(
456 new MDnsTransactionImpl(rrtype
, name
, flags
, callback
, this));
// Stores the record type, name, owning client and delegate. The listener
// starts out not started and with active refresh disabled.
// |client| and |delegate| are kept as raw pointers.
// NOTE(review): the parameter that declares |rrtype| is absent from this
// extract.
459 MDnsListenerImpl::MDnsListenerImpl(
461 const std::string
& name
,
462 MDnsListener::Delegate
* delegate
,
463 MDnsClientImpl
* client
)
464 : rrtype_(rrtype
), name_(name
), client_(client
), delegate_(delegate
),
465 started_(false), active_refresh_(false) {
// Unregisters this listener from the client's core, which is DCHECKed to
// exist.
// NOTE(review): a line just after the signature is absent from this extract,
// so the unregistration may be conditional (for example on started_) —
// confirm.
468 MDnsListenerImpl::~MDnsListenerImpl() {
470 DCHECK(client_
->core());
471 client_
->core()->RemoveListener(this);
// Registers this listener with the client's core, which is DCHECKed to exist.
// NOTE(review): the statements before the DCHECK and the return statement are
// absent from this extract — confirm against the full file.
475 bool MDnsListenerImpl::Start() {
480 DCHECK(client_
->core());
481 client_
->core()->AddListener(this);
// Turns active refresh on or off.
// Disabling cancels the pending next_refresh_ callback. Enabling schedules the
// next refresh, but only if a record update has already been seen
// (last_update_ is not the null base::Time()).
486 void MDnsListenerImpl::SetActiveRefresh(bool active_refresh
) {
487 active_refresh_
= active_refresh
;
490 if (!active_refresh_
) {
491 next_refresh_
.Cancel();
492 } else if (last_update_
!= base::Time()) {
493 ScheduleNextRefresh();
// Accessor returning a const std::string reference. The body is absent from
// this extract (presumably it returns name_ — confirm).
498 const std::string
& MDnsListenerImpl::GetName() const {
// Accessor returning a uint16. The body is absent from this extract
// (presumably it returns rrtype_ — confirm).
502 uint16
MDnsListenerImpl::GetType() const {
// Receives a cache update for this listener's record from the core.
// For anything but a removal, ttl_ and last_update_ are refreshed from
// |record| and the next refresh is scheduled. For anything but NoChange, the
// cache update type is mapped to the matching MDnsListener::UpdateType
// (added / changed / removed) and forwarded to the delegate's OnRecordUpdate().
// NOTE(review): the statements that end each switch case and whatever sits
// between the NoChange label and the dummy assignment are absent from this
// extract.
506 void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type
,
507 const RecordParsed
* record
) {
510 if (update_type
!= MDnsCache::RecordRemoved
) {
511 ttl_
= record
->ttl();
512 last_update_
= record
->time_created();
514 ScheduleNextRefresh();
517 if (update_type
!= MDnsCache::NoChange
) {
518 MDnsListener::UpdateType update_external
;
520 switch (update_type
) {
521 case MDnsCache::RecordAdded
:
522 update_external
= MDnsListener::RECORD_ADDED
;
524 case MDnsCache::RecordChanged
:
525 update_external
= MDnsListener::RECORD_CHANGED
;
527 case MDnsCache::RecordRemoved
:
528 update_external
= MDnsListener::RECORD_REMOVED
;
530 case MDnsCache::NoChange
:
533 // Dummy assignment to suppress compiler warning.
534 update_external
= MDnsListener::RECORD_CHANGED
;
538 delegate_
->OnRecordUpdate(update_external
, record
);
// Invoked by Core::NotifyNsecRecord() on each matching listener; forwards this
// listener's name and record type to the delegate's OnNsecRecord().
542 void MDnsListenerImpl::AlertNsecRecord() {
544 delegate_
->OnNsecRecord(name_
, rrtype_
);
// Schedules two refresh queries for the record this listener tracks.
// Requires that an update has been seen (last_update_ is DCHECKed to be
// non-null). next_refresh_ is reset to a DoRefresh binding and posted twice as
// a delayed task, at last_update_ plus kListenerRefreshRatio1 * ttl_ seconds
// and at last_update_ plus kListenerRefreshRatio2 * ttl_ seconds, each
// expressed as a delay relative to Now().
// NOTE(review): the early-return statements (inactive refresh, zero TTL), the
// second Bind() argument and the first PostDelayedTask() arguments are absent
// from this extract.
// NOTE(review): the millisecond offset is computed in double and narrowed with
// static_cast<int>. If int is 32 bits and ttl_ can exceed roughly 2.5 million
// seconds (about 29 days), the product is out of int range — confirm ttl_'s
// type and range, and consider a 64-bit cast.
547 void MDnsListenerImpl::ScheduleNextRefresh() {
548 DCHECK(last_update_
!= base::Time());
550 if (!active_refresh_
)
553 // A zero TTL is a goodbye packet and should not be refreshed.
555 next_refresh_
.Cancel();
559 next_refresh_
.Reset(base::Bind(&MDnsListenerImpl::DoRefresh
,
562 // Schedule refreshes at both 85% and 95% of the original TTL. These will both
563 // be canceled and rescheduled if the record's TTL is updated due to a
564 // response being received.
565 base::Time next_refresh1
= last_update_
+ base::TimeDelta::FromMilliseconds(
566 static_cast<int>(base::Time::kMillisecondsPerSecond
*
567 kListenerRefreshRatio1
* ttl_
));
569 base::Time next_refresh2
= last_update_
+ base::TimeDelta::FromMilliseconds(
570 static_cast<int>(base::Time::kMillisecondsPerSecond
*
571 kListenerRefreshRatio2
* ttl_
));
573 base::MessageLoop::current()->PostDelayedTask(
575 next_refresh_
.callback(),
576 next_refresh1
- base::Time::Now());
578 base::MessageLoop::current()->PostDelayedTask(
580 next_refresh_
.callback(),
581 next_refresh2
- base::Time::Now());
// Refresh task bound by ScheduleNextRefresh(): sends a query for this
// listener's record type and name through the client's core. The result of
// SendQuery() is not examined, and core() is not checked for NULL here.
584 void MDnsListenerImpl::DoRefresh() {
585 client_
->core()->SendQuery(rrtype_
, name_
);
// Stores the record type, name, result callback, owning client and flags; the
// transaction starts out not started.
// DCHECKs that |flags| contains only bits inside MDnsTransaction::FLAG_MASK
// and that at least one of QUERY_CACHE or QUERY_NETWORK is set.
// NOTE(review): the parameters that declare |rrtype| and |flags| are absent
// from this extract.
588 MDnsTransactionImpl::MDnsTransactionImpl(
590 const std::string
& name
,
592 const MDnsTransaction::ResultCallback
& callback
,
593 MDnsClientImpl
* client
)
594 : rrtype_(rrtype
), name_(name
), callback_(callback
), client_(client
),
595 started_(false), flags_(flags
) {
596 DCHECK((flags_
& MDnsTransaction::FLAG_MASK
) == flags_
);
597 DCHECK(flags_
& MDnsTransaction::QUERY_CACHE
||
598 flags_
& MDnsTransaction::QUERY_NETWORK
);
// Destructor. No statements of its body are visible in this extract.
601 MDnsTransactionImpl::~MDnsTransactionImpl() {
// Runs the transaction according to flags_.
// With QUERY_CACHE, cached records are served first. A weak pointer taken
// beforehand is re-checked afterwards, and the function returns true right
// away if the transaction was destroyed or is no longer active. With
// QUERY_NETWORK the result of QueryAndListen() is returned. The remaining
// visible path calls SignalTransactionOver().
// NOTE(review): the statements before the weak pointer is taken and the final
// return are absent from this extract.
605 bool MDnsTransactionImpl::Start() {
609 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
610 if (flags_
& MDnsTransaction::QUERY_CACHE
) {
611 ServeRecordsFromCache();
613 if (!weak_this
|| !is_active()) return true;
616 if (flags_
& MDnsTransaction::QUERY_NETWORK
) {
617 return QueryAndListen();
620 // If this is a cache only query, signal that the transaction is over
622 SignalTransactionOver();
// Accessor returning a const std::string reference. The body is absent from
// this extract (presumably it returns name_ — confirm).
626 const std::string
& MDnsTransactionImpl::GetName() const {
// Accessor returning a uint16. The body is absent from this extract
// (presumably it returns rrtype_ — confirm).
630 uint16
MDnsTransactionImpl::GetType() const {
// Treats a record found in the cache like a newly added one by forwarding it
// to OnRecordUpdate() with MDnsListener::RECORD_ADDED.
634 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed
* record
) {
636 OnRecordUpdate(MDnsListener::RECORD_ADDED
, record
);
// Runs the result callback with |result| and |record| if the transaction is
// still active.
// callback_ is copied into a local before the state is touched, and the local
// copy is what gets run, so the callback may delete the transaction (see the
// inline comment).
// NOTE(review): the body of the SINGLE_RESULT / final-result branch is absent
// from this extract (the comment above it says it resets the transaction).
639 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result
,
640 const RecordParsed
* record
) {
642 if (!is_active()) return;
644 // Ensure callback is run after touching all class state, so that
645 // the callback can delete the transaction.
646 MDnsTransaction::ResultCallback callback
= callback_
;
648 // Reset the transaction if it expects a single result, or if the result
649 // is a final one (everything except for a record).
650 if (flags_
& MDnsTransaction::SINGLE_RESULT
||
651 result
!= MDnsTransaction::RESULT_RECORD
) {
655 callback
.Run(result
, record
);
// Resets the transaction. The body is absent from this extract, so what state
// it clears cannot be stated here.
658 void MDnsTransactionImpl::Reset() {
// Listener-delegate hook: reports |record| as RESULT_RECORD when the update is
// RECORD_ADDED or RECORD_CHANGED. Other update types trigger nothing in the
// visible code.
664 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update
,
665 const RecordParsed
* record
) {
667 if (update
== MDnsListener::RECORD_ADDED
||
668 update
== MDnsListener::RECORD_CHANGED
)
669 TriggerCallback(MDnsTransaction::RESULT_RECORD
, record
);
// Signals the end of the transaction through the result callback.
// Single-result transactions report RESULT_NO_RESULTS; RESULT_DONE is the
// other visible outcome.
// NOTE(review): the lines between the two TriggerCallback() calls are absent
// from this extract (presumably an else branch — confirm).
672 void MDnsTransactionImpl::SignalTransactionOver() {
674 if (flags_
& MDnsTransaction::SINGLE_RESULT
) {
675 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS
, NULL
);
677 TriggerCallback(MDnsTransaction::RESULT_DONE
, NULL
);
// Reports every cached record matching rrtype_ / name_ as RESULT_RECORD.
// Runs only when the client has a core. The loop stops early once weak_this
// becomes null, i.e. if a callback destroyed the transaction.
// When ENABLE_NSEC is defined and no records were found, the cache is queried
// for an NSEC record of the same name; RESULT_NSEC is reported if that
// record's bitmap does not have the bit for rrtype_ set.
// NOTE(review): the lines between the visible statements in the NSEC section
// are absent from this extract, and the visible NSEC call dereferences
// weak_this without a visible null check — confirm against the full file.
681 void MDnsTransactionImpl::ServeRecordsFromCache() {
682 std::vector
<const RecordParsed
*> records
;
683 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
685 if (client_
->core()) {
686 client_
->core()->QueryCache(rrtype_
, name_
, &records
);
687 for (std::vector
<const RecordParsed
*>::iterator i
= records
.begin();
688 i
!= records
.end() && weak_this
; ++i
) {
689 weak_this
->TriggerCallback(MDnsTransaction::RESULT_RECORD
, *i
);
692 #if defined(ENABLE_NSEC)
693 if (records
.empty()) {
695 client_
->core()->QueryCache(dns_protocol::kTypeNSEC
, name_
, &records
);
696 if (!records
.empty()) {
697 const NsecRecordRdata
* rdata
=
698 records
.front()->rdata
<NsecRecordRdata
>();
700 if (!rdata
->GetBit(rrtype_
))
701 weak_this
->TriggerCallback(MDnsTransaction::RESULT_NSEC
, NULL
);
// Network half of the transaction.
// Creates a listener for rrtype_ / name_ with this transaction as its delegate
// and starts it, sends a query through the client's core (DCHECKed to exist),
// then binds timeout_ to SignalTransactionOver and posts a delayed task with a
// delay of MDnsTransactionTimeoutSeconds.
// NOTE(review): the statements taken when Start() or SendQuery() fail, the
// second Bind() argument, the first PostDelayedTask() arguments and the final
// return are absent from this extract.
708 bool MDnsTransactionImpl::QueryAndListen() {
709 listener_
= client_
->CreateListener(rrtype_
, name_
, this);
710 if (!listener_
->Start())
713 DCHECK(client_
->core());
714 if (!client_
->core()->SendQuery(rrtype_
, name_
))
717 timeout_
.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver
,
719 base::MessageLoop::current()->PostDelayedTask(
722 base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds
));
// Listener-delegate hook for NSEC notifications: reports RESULT_NSEC with no
// record. |name| and |type| are not used in the visible code.
727 void MDnsTransactionImpl::OnNsecRecord(const std::string
& name
, unsigned type
) {
728 TriggerCallback(RESULT_NSEC
, NULL
);
731 void MDnsTransactionImpl::OnCachePurged() {
732 // TODO(noamsml): Cache purge situations not yet implemented