1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/mdns_client_impl.h"
10 #include "base/message_loop/message_loop_proxy.h"
11 #include "base/stl_util.h"
12 #include "base/time/default_clock.h"
13 #include "base/time/time.h"
14 #include "net/base/dns_util.h"
15 #include "net/base/net_errors.h"
16 #include "net/base/net_log.h"
17 #include "net/base/rand_callback.h"
18 #include "net/dns/dns_protocol.h"
19 #include "net/dns/record_rdata.h"
20 #include "net/udp/datagram_socket.h"
22 // TODO(gene): Remove this temporary method of disabling NSEC support once it
23 // becomes clear whether this feature should be
24 // supported. http://crbug.com/255232
31 const unsigned MDnsTransactionTimeoutSeconds
= 3;
32 // The fractions of the record's original TTL after which an active listener
33 // (one that had |SetActiveRefresh(true)| called) will send a query to refresh
34 // its cache. This happens both at 85% of the original TTL and again at 95% of
36 const double kListenerRefreshRatio1
= 0.85;
37 const double kListenerRefreshRatio2
= 0.95;
41 void MDnsSocketFactoryImpl::CreateSockets(
42 ScopedVector
<DatagramServerSocket
>* sockets
) {
43 InterfaceIndexFamilyList
interfaces(GetMDnsInterfacesToBind());
44 for (size_t i
= 0; i
< interfaces
.size(); ++i
) {
45 DCHECK(interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV4
||
46 interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV6
);
47 scoped_ptr
<DatagramServerSocket
> socket(
48 CreateAndBindMDnsSocket(interfaces
[i
].second
, interfaces
[i
].first
));
50 sockets
->push_back(socket
.release());
54 MDnsConnection::SocketHandler::SocketHandler(
55 scoped_ptr
<DatagramServerSocket
> socket
,
56 MDnsConnection
* connection
)
57 : socket_(socket
.Pass()),
58 connection_(connection
),
59 response_(dns_protocol::kMaxMulticastSize
),
60 send_in_progress_(false) {
63 MDnsConnection::SocketHandler::~SocketHandler() {
66 int MDnsConnection::SocketHandler::Start() {
68 int rv
= socket_
->GetLocalAddress(&end_point
);
71 DCHECK(end_point
.GetFamily() == ADDRESS_FAMILY_IPV4
||
72 end_point
.GetFamily() == ADDRESS_FAMILY_IPV6
);
73 multicast_addr_
= GetMDnsIPEndPoint(end_point
.GetFamily());
77 int MDnsConnection::SocketHandler::DoLoop(int rv
) {
80 connection_
->OnDatagramReceived(&response_
, recv_addr_
, rv
);
82 rv
= socket_
->RecvFrom(
83 response_
.io_buffer(),
84 response_
.io_buffer()->size(),
86 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived
,
87 base::Unretained(this)));
90 if (rv
!= ERR_IO_PENDING
)
96 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv
) {
101 connection_
->PostOnError(this, rv
);
104 void MDnsConnection::SocketHandler::Send(const scoped_refptr
<IOBuffer
>& buffer
,
106 if (send_in_progress_
) {
107 send_queue_
.push(std::make_pair(buffer
, size
));
110 int rv
= socket_
->SendTo(buffer
.get(),
113 base::Bind(&MDnsConnection::SocketHandler::SendDone
,
114 base::Unretained(this)));
115 if (rv
== ERR_IO_PENDING
) {
116 send_in_progress_
= true;
117 } else if (rv
< OK
) {
118 connection_
->PostOnError(this, rv
);
122 void MDnsConnection::SocketHandler::SendDone(int rv
) {
123 DCHECK(send_in_progress_
);
124 send_in_progress_
= false;
126 connection_
->PostOnError(this, rv
);
127 while (!send_in_progress_
&& !send_queue_
.empty()) {
128 std::pair
<scoped_refptr
<IOBuffer
>, unsigned> buffer
= send_queue_
.front();
130 Send(buffer
.first
, buffer
.second
);
134 MDnsConnection::MDnsConnection(MDnsConnection::Delegate
* delegate
)
135 : delegate_(delegate
), weak_ptr_factory_(this) {
138 MDnsConnection::~MDnsConnection() {
141 bool MDnsConnection::Init(MDnsSocketFactory
* socket_factory
) {
142 ScopedVector
<DatagramServerSocket
> sockets
;
143 socket_factory
->CreateSockets(&sockets
);
145 for (size_t i
= 0; i
< sockets
.size(); ++i
) {
146 socket_handlers_
.push_back(
147 new MDnsConnection::SocketHandler(make_scoped_ptr(sockets
[i
]), this));
149 sockets
.weak_clear();
151 // All unbound sockets need to be bound before processing untrusted input.
152 // This is done for security reasons, so that an attacker can't get an unbound
154 for (size_t i
= 0; i
< socket_handlers_
.size();) {
155 int rv
= socket_handlers_
[i
]->Start();
157 socket_handlers_
.erase(socket_handlers_
.begin() + i
);
158 VLOG(1) << "Start failed, socket=" << i
<< ", error=" << rv
;
163 VLOG(1) << "Sockets ready:" << socket_handlers_
.size();
164 return !socket_handlers_
.empty();
167 void MDnsConnection::Send(const scoped_refptr
<IOBuffer
>& buffer
,
169 for (size_t i
= 0; i
< socket_handlers_
.size(); ++i
)
170 socket_handlers_
[i
]->Send(buffer
, size
);
173 void MDnsConnection::PostOnError(SocketHandler
* loop
, int rv
) {
174 VLOG(1) << "Socket error. id="
175 << std::find(socket_handlers_
.begin(), socket_handlers_
.end(), loop
) -
176 socket_handlers_
.begin() << ", error=" << rv
;
177 // Post to allow deletion of this object by delegate.
178 base::MessageLoop::current()->PostTask(
180 base::Bind(&MDnsConnection::OnError
, weak_ptr_factory_
.GetWeakPtr(), rv
));
183 void MDnsConnection::OnError(int rv
) {
184 // TODO(noamsml): Specific handling of intermittent errors that can be handled
185 // in the connection.
186 delegate_
->OnConnectionError(rv
);
189 void MDnsConnection::OnDatagramReceived(
190 DnsResponse
* response
,
191 const IPEndPoint
& recv_addr
,
193 // TODO(noamsml): More sophisticated error handling.
194 DCHECK_GT(bytes_read
, 0);
195 delegate_
->HandlePacket(response
, bytes_read
);
198 MDnsClientImpl::Core::Core(MDnsClientImpl
* client
)
199 : client_(client
), connection_(new MDnsConnection(this)) {
202 MDnsClientImpl::Core::~Core() {
203 STLDeleteValues(&listeners_
);
206 bool MDnsClientImpl::Core::Init(MDnsSocketFactory
* socket_factory
) {
207 return connection_
->Init(socket_factory
);
210 bool MDnsClientImpl::Core::SendQuery(uint16 rrtype
, std::string name
) {
211 std::string name_dns
;
212 if (!DNSDomainFromDot(name
, &name_dns
))
215 DnsQuery
query(0, name_dns
, rrtype
);
216 query
.set_flags(0); // Remove the RD flag from the query. It is unneeded.
218 connection_
->Send(query
.io_buffer(), query
.io_buffer()->size());
222 void MDnsClientImpl::Core::HandlePacket(DnsResponse
* response
,
225 // Note: We store cache keys rather than record pointers to avoid
226 // erroneous behavior in case a packet contains multiple exclusive
227 // records with the same type and name.
228 std::map
<MDnsCache::Key
, MDnsCache::UpdateType
> update_keys
;
230 if (!response
->InitParseWithoutQuery(bytes_read
)) {
231 DVLOG(1) << "Could not understand an mDNS packet.";
232 return; // Message is unreadable.
235 // TODO(noamsml): duplicate query suppression.
236 if (!(response
->flags() & dns_protocol::kFlagResponse
))
237 return; // Message is a query. ignore it.
239 DnsRecordParser parser
= response
->Parser();
240 unsigned answer_count
= response
->answer_count() +
241 response
->additional_answer_count();
243 for (unsigned i
= 0; i
< answer_count
; i
++) {
244 offset
= parser
.GetOffset();
245 scoped_ptr
<const RecordParsed
> record
= RecordParsed::CreateFrom(
246 &parser
, base::Time::Now());
249 DVLOG(1) << "Could not understand an mDNS record.";
251 if (offset
== parser
.GetOffset()) {
252 DVLOG(1) << "Abandoned parsing the rest of the packet.";
253 return; // The parser did not advance, abort reading the packet.
255 continue; // We may be able to extract other records from the packet.
259 if ((record
->klass() & dns_protocol::kMDnsClassMask
) !=
260 dns_protocol::kClassIN
) {
261 DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
262 continue; // Ignore all records not in the IN class.
265 MDnsCache::Key update_key
= MDnsCache::Key::CreateFor(record
.get());
266 MDnsCache::UpdateType update
= cache_
.UpdateDnsRecord(record
.Pass());
268 // Cleanup time may have changed.
269 ScheduleCleanup(cache_
.next_expiration());
271 update_keys
.insert(std::make_pair(update_key
, update
));
274 for (std::map
<MDnsCache::Key
, MDnsCache::UpdateType
>::iterator i
=
275 update_keys
.begin(); i
!= update_keys
.end(); i
++) {
276 const RecordParsed
* record
= cache_
.LookupKey(i
->first
);
280 if (record
->type() == dns_protocol::kTypeNSEC
) {
281 #if defined(ENABLE_NSEC)
282 NotifyNsecRecord(record
);
285 AlertListeners(i
->second
, ListenerKey(record
->name(), record
->type()),
291 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed
* record
) {
292 DCHECK_EQ(dns_protocol::kTypeNSEC
, record
->type());
293 const NsecRecordRdata
* rdata
= record
->rdata
<NsecRecordRdata
>();
296 // Remove all cached records matching the nonexistent RR types.
297 std::vector
<const RecordParsed
*> records_to_remove
;
299 cache_
.FindDnsRecords(0, record
->name(), &records_to_remove
,
302 for (std::vector
<const RecordParsed
*>::iterator i
= records_to_remove
.begin();
303 i
!= records_to_remove
.end(); i
++) {
304 if ((*i
)->type() == dns_protocol::kTypeNSEC
)
306 if (!rdata
->GetBit((*i
)->type())) {
307 scoped_ptr
<const RecordParsed
> record_removed
= cache_
.RemoveRecord((*i
));
308 DCHECK(record_removed
);
309 OnRecordRemoved(record_removed
.get());
313 // Alert all listeners waiting for the nonexistent RR types.
314 ListenerMap::iterator i
=
315 listeners_
.upper_bound(ListenerKey(record
->name(), 0));
316 for (; i
!= listeners_
.end() && i
->first
.first
== record
->name(); i
++) {
317 if (!rdata
->GetBit(i
->first
.second
)) {
318 FOR_EACH_OBSERVER(MDnsListenerImpl
, *i
->second
, AlertNsecRecord());
323 void MDnsClientImpl::Core::OnConnectionError(int error
) {
324 // TODO(noamsml): On connection error, recreate connection and flush cache.
327 void MDnsClientImpl::Core::AlertListeners(
328 MDnsCache::UpdateType update_type
,
329 const ListenerKey
& key
,
330 const RecordParsed
* record
) {
331 ListenerMap::iterator listener_map_iterator
= listeners_
.find(key
);
332 if (listener_map_iterator
== listeners_
.end()) return;
334 FOR_EACH_OBSERVER(MDnsListenerImpl
, *listener_map_iterator
->second
,
335 HandleRecordUpdate(update_type
, record
));
338 void MDnsClientImpl::Core::AddListener(
339 MDnsListenerImpl
* listener
) {
340 ListenerKey
key(listener
->GetName(), listener
->GetType());
341 std::pair
<ListenerMap::iterator
, bool> observer_insert_result
=
343 make_pair(key
, static_cast<ObserverList
<MDnsListenerImpl
>*>(NULL
)));
345 // If an equivalent key does not exist, actually create the observer list.
346 if (observer_insert_result
.second
)
347 observer_insert_result
.first
->second
= new ObserverList
<MDnsListenerImpl
>();
349 ObserverList
<MDnsListenerImpl
>* observer_list
=
350 observer_insert_result
.first
->second
;
352 observer_list
->AddObserver(listener
);
355 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl
* listener
) {
356 ListenerKey
key(listener
->GetName(), listener
->GetType());
357 ListenerMap::iterator observer_list_iterator
= listeners_
.find(key
);
359 DCHECK(observer_list_iterator
!= listeners_
.end());
360 DCHECK(observer_list_iterator
->second
->HasObserver(listener
));
362 observer_list_iterator
->second
->RemoveObserver(listener
);
364 // Remove the observer list from the map if it is empty
365 if (!observer_list_iterator
->second
->might_have_observers()) {
366 // Schedule the actual removal for later in case the listener removal
367 // happens while iterating over the observer list.
368 base::MessageLoop::current()->PostTask(
369 FROM_HERE
, base::Bind(
370 &MDnsClientImpl::Core::CleanupObserverList
, AsWeakPtr(), key
));
374 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey
& key
) {
375 ListenerMap::iterator found
= listeners_
.find(key
);
376 if (found
!= listeners_
.end() && !found
->second
->might_have_observers()) {
377 delete found
->second
;
378 listeners_
.erase(found
);
382 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup
) {
383 // Cleanup is already scheduled, no need to do anything.
384 if (cleanup
== scheduled_cleanup_
) return;
385 scheduled_cleanup_
= cleanup
;
387 // This cancels the previously scheduled cleanup.
388 cleanup_callback_
.Reset(base::Bind(
389 &MDnsClientImpl::Core::DoCleanup
, base::Unretained(this)));
391 // If |cleanup| is empty, then no cleanup necessary.
392 if (cleanup
!= base::Time()) {
393 base::MessageLoop::current()->PostDelayedTask(
395 cleanup_callback_
.callback(),
396 cleanup
- base::Time::Now());
400 void MDnsClientImpl::Core::DoCleanup() {
401 cache_
.CleanupRecords(base::Time::Now(), base::Bind(
402 &MDnsClientImpl::Core::OnRecordRemoved
, base::Unretained(this)));
404 ScheduleCleanup(cache_
.next_expiration());
407 void MDnsClientImpl::Core::OnRecordRemoved(
408 const RecordParsed
* record
) {
409 AlertListeners(MDnsCache::RecordRemoved
,
410 ListenerKey(record
->name(), record
->type()), record
);
413 void MDnsClientImpl::Core::QueryCache(
414 uint16 rrtype
, const std::string
& name
,
415 std::vector
<const RecordParsed
*>* records
) const {
416 cache_
.FindDnsRecords(rrtype
, name
, records
, base::Time::Now());
419 MDnsClientImpl::MDnsClientImpl() {
422 MDnsClientImpl::~MDnsClientImpl() {
425 bool MDnsClientImpl::StartListening(MDnsSocketFactory
* socket_factory
) {
426 DCHECK(!core_
.get());
427 core_
.reset(new Core(this));
428 if (!core_
->Init(socket_factory
)) {
435 void MDnsClientImpl::StopListening() {
439 bool MDnsClientImpl::IsListening() const {
440 return core_
.get() != NULL
;
443 scoped_ptr
<MDnsListener
> MDnsClientImpl::CreateListener(
445 const std::string
& name
,
446 MDnsListener::Delegate
* delegate
) {
447 return scoped_ptr
<net::MDnsListener
>(
448 new MDnsListenerImpl(rrtype
, name
, delegate
, this));
451 scoped_ptr
<MDnsTransaction
> MDnsClientImpl::CreateTransaction(
453 const std::string
& name
,
455 const MDnsTransaction::ResultCallback
& callback
) {
456 return scoped_ptr
<MDnsTransaction
>(
457 new MDnsTransactionImpl(rrtype
, name
, flags
, callback
, this));
460 MDnsListenerImpl::MDnsListenerImpl(
462 const std::string
& name
,
463 MDnsListener::Delegate
* delegate
,
464 MDnsClientImpl
* client
)
465 : rrtype_(rrtype
), name_(name
), client_(client
), delegate_(delegate
),
466 started_(false), active_refresh_(false) {
469 MDnsListenerImpl::~MDnsListenerImpl() {
471 DCHECK(client_
->core());
472 client_
->core()->RemoveListener(this);
476 bool MDnsListenerImpl::Start() {
481 DCHECK(client_
->core());
482 client_
->core()->AddListener(this);
487 void MDnsListenerImpl::SetActiveRefresh(bool active_refresh
) {
488 active_refresh_
= active_refresh
;
491 if (!active_refresh_
) {
492 next_refresh_
.Cancel();
493 } else if (last_update_
!= base::Time()) {
494 ScheduleNextRefresh();
499 const std::string
& MDnsListenerImpl::GetName() const {
503 uint16
MDnsListenerImpl::GetType() const {
507 void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type
,
508 const RecordParsed
* record
) {
511 if (update_type
!= MDnsCache::RecordRemoved
) {
512 ttl_
= record
->ttl();
513 last_update_
= record
->time_created();
515 ScheduleNextRefresh();
518 if (update_type
!= MDnsCache::NoChange
) {
519 MDnsListener::UpdateType update_external
;
521 switch (update_type
) {
522 case MDnsCache::RecordAdded
:
523 update_external
= MDnsListener::RECORD_ADDED
;
525 case MDnsCache::RecordChanged
:
526 update_external
= MDnsListener::RECORD_CHANGED
;
528 case MDnsCache::RecordRemoved
:
529 update_external
= MDnsListener::RECORD_REMOVED
;
531 case MDnsCache::NoChange
:
534 // Dummy assignment to suppress compiler warning.
535 update_external
= MDnsListener::RECORD_CHANGED
;
539 delegate_
->OnRecordUpdate(update_external
, record
);
543 void MDnsListenerImpl::AlertNsecRecord() {
545 delegate_
->OnNsecRecord(name_
, rrtype_
);
548 void MDnsListenerImpl::ScheduleNextRefresh() {
549 DCHECK(last_update_
!= base::Time());
551 if (!active_refresh_
)
554 // A zero TTL is a goodbye packet and should not be refreshed.
556 next_refresh_
.Cancel();
560 next_refresh_
.Reset(base::Bind(&MDnsListenerImpl::DoRefresh
,
563 // Schedule refreshes at both 85% and 95% of the original TTL. These will both
564 // be canceled and rescheduled if the record's TTL is updated due to a
565 // response being received.
566 base::Time next_refresh1
= last_update_
+ base::TimeDelta::FromMilliseconds(
567 static_cast<int>(base::Time::kMillisecondsPerSecond
*
568 kListenerRefreshRatio1
* ttl_
));
570 base::Time next_refresh2
= last_update_
+ base::TimeDelta::FromMilliseconds(
571 static_cast<int>(base::Time::kMillisecondsPerSecond
*
572 kListenerRefreshRatio2
* ttl_
));
574 base::MessageLoop::current()->PostDelayedTask(
576 next_refresh_
.callback(),
577 next_refresh1
- base::Time::Now());
579 base::MessageLoop::current()->PostDelayedTask(
581 next_refresh_
.callback(),
582 next_refresh2
- base::Time::Now());
585 void MDnsListenerImpl::DoRefresh() {
586 client_
->core()->SendQuery(rrtype_
, name_
);
589 MDnsTransactionImpl::MDnsTransactionImpl(
591 const std::string
& name
,
593 const MDnsTransaction::ResultCallback
& callback
,
594 MDnsClientImpl
* client
)
595 : rrtype_(rrtype
), name_(name
), callback_(callback
), client_(client
),
596 started_(false), flags_(flags
) {
597 DCHECK((flags_
& MDnsTransaction::FLAG_MASK
) == flags_
);
598 DCHECK(flags_
& MDnsTransaction::QUERY_CACHE
||
599 flags_
& MDnsTransaction::QUERY_NETWORK
);
602 MDnsTransactionImpl::~MDnsTransactionImpl() {
606 bool MDnsTransactionImpl::Start() {
610 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
611 if (flags_
& MDnsTransaction::QUERY_CACHE
) {
612 ServeRecordsFromCache();
614 if (!weak_this
|| !is_active()) return true;
617 if (flags_
& MDnsTransaction::QUERY_NETWORK
) {
618 return QueryAndListen();
621 // If this is a cache only query, signal that the transaction is over
623 SignalTransactionOver();
627 const std::string
& MDnsTransactionImpl::GetName() const {
631 uint16
MDnsTransactionImpl::GetType() const {
635 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed
* record
) {
637 OnRecordUpdate(MDnsListener::RECORD_ADDED
, record
);
640 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result
,
641 const RecordParsed
* record
) {
643 if (!is_active()) return;
645 // Ensure callback is run after touching all class state, so that
646 // the callback can delete the transaction.
647 MDnsTransaction::ResultCallback callback
= callback_
;
649 // Reset the transaction if it expects a single result, or if the result
650 // is a final one (everything except for a record).
651 if (flags_
& MDnsTransaction::SINGLE_RESULT
||
652 result
!= MDnsTransaction::RESULT_RECORD
) {
656 callback
.Run(result
, record
);
659 void MDnsTransactionImpl::Reset() {
665 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update
,
666 const RecordParsed
* record
) {
668 if (update
== MDnsListener::RECORD_ADDED
||
669 update
== MDnsListener::RECORD_CHANGED
)
670 TriggerCallback(MDnsTransaction::RESULT_RECORD
, record
);
673 void MDnsTransactionImpl::SignalTransactionOver() {
675 if (flags_
& MDnsTransaction::SINGLE_RESULT
) {
676 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS
, NULL
);
678 TriggerCallback(MDnsTransaction::RESULT_DONE
, NULL
);
682 void MDnsTransactionImpl::ServeRecordsFromCache() {
683 std::vector
<const RecordParsed
*> records
;
684 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
686 if (client_
->core()) {
687 client_
->core()->QueryCache(rrtype_
, name_
, &records
);
688 for (std::vector
<const RecordParsed
*>::iterator i
= records
.begin();
689 i
!= records
.end() && weak_this
; ++i
) {
690 weak_this
->TriggerCallback(MDnsTransaction::RESULT_RECORD
, *i
);
693 #if defined(ENABLE_NSEC)
694 if (records
.empty()) {
696 client_
->core()->QueryCache(dns_protocol::kTypeNSEC
, name_
, &records
);
697 if (!records
.empty()) {
698 const NsecRecordRdata
* rdata
=
699 records
.front()->rdata
<NsecRecordRdata
>();
701 if (!rdata
->GetBit(rrtype_
))
702 weak_this
->TriggerCallback(MDnsTransaction::RESULT_NSEC
, NULL
);
709 bool MDnsTransactionImpl::QueryAndListen() {
710 listener_
= client_
->CreateListener(rrtype_
, name_
, this);
711 if (!listener_
->Start())
714 DCHECK(client_
->core());
715 if (!client_
->core()->SendQuery(rrtype_
, name_
))
718 timeout_
.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver
,
720 base::MessageLoop::current()->PostDelayedTask(
723 base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds
));
728 void MDnsTransactionImpl::OnNsecRecord(const std::string
& name
, unsigned type
) {
729 TriggerCallback(RESULT_NSEC
, NULL
);
732 void MDnsTransactionImpl::OnCachePurged() {
733 // TODO(noamsml): Cache purge situations not yet implemented