1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/mdns_client_impl.h"
8 #include "base/message_loop/message_loop_proxy.h"
9 #include "base/stl_util.h"
10 #include "base/time/default_clock.h"
11 #include "net/base/dns_util.h"
12 #include "net/base/net_errors.h"
13 #include "net/base/net_log.h"
14 #include "net/base/rand_callback.h"
15 #include "net/dns/dns_protocol.h"
16 #include "net/dns/record_rdata.h"
17 #include "net/udp/datagram_socket.h"
19 // TODO(gene): Remove this temporary method of disabling NSEC support once it
20 // becomes clear whether this feature should be
21 // supported. http://crbug.com/255232
// How long (in seconds) a network-querying transaction waits before it
// signals that it is over.
const unsigned MDnsTransactionTimeoutSeconds = 3;

// The fractions of the record's original TTL after which an active listener
// (one that had |SetActiveRefresh(true)| called) will send a query to refresh
// its cache. This happens both at 85% of the original TTL and again at 95% of
// the original TTL.
const double kListenerRefreshRatio1 = 0.85;
const double kListenerRefreshRatio2 = 0.95;

// Factor for converting a TTL expressed in seconds into milliseconds.
const unsigned kMillisecondsPerSecond = 1000;
39 void MDnsSocketFactoryImpl::CreateSockets(
40 ScopedVector
<DatagramServerSocket
>* sockets
) {
41 InterfaceIndexFamilyList
interfaces(GetMDnsInterfacesToBind());
42 for (size_t i
= 0; i
< interfaces
.size(); ++i
) {
43 DCHECK(interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV4
||
44 interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV6
);
45 scoped_ptr
<DatagramServerSocket
> socket(
46 CreateAndBindMDnsSocket(interfaces
[i
].second
, interfaces
[i
].first
));
48 sockets
->push_back(socket
.release());
52 MDnsConnection::SocketHandler::SocketHandler(
53 scoped_ptr
<DatagramServerSocket
> socket
,
54 MDnsConnection
* connection
)
55 : socket_(socket
.Pass()),
56 connection_(connection
),
57 response_(dns_protocol::kMaxMulticastSize
) {
60 MDnsConnection::SocketHandler::~SocketHandler() {
63 int MDnsConnection::SocketHandler::Start() {
65 int rv
= socket_
->GetLocalAddress(&end_point
);
68 DCHECK(end_point
.GetFamily() == ADDRESS_FAMILY_IPV4
||
69 end_point
.GetFamily() == ADDRESS_FAMILY_IPV6
);
70 multicast_addr_
= GetMDnsIPEndPoint(end_point
.GetFamily());
74 int MDnsConnection::SocketHandler::DoLoop(int rv
) {
77 connection_
->OnDatagramReceived(&response_
, recv_addr_
, rv
);
79 rv
= socket_
->RecvFrom(
80 response_
.io_buffer(),
81 response_
.io_buffer()->size(),
83 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived
,
84 base::Unretained(this)));
87 if (rv
!= ERR_IO_PENDING
)
93 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv
) {
98 connection_
->OnError(this, rv
);
101 int MDnsConnection::SocketHandler::Send(IOBuffer
* buffer
, unsigned size
) {
102 return socket_
->SendTo(buffer
, size
, multicast_addr_
,
103 base::Bind(&MDnsConnection::SocketHandler::SendDone
,
104 base::Unretained(this) ));
107 void MDnsConnection::SocketHandler::SendDone(int rv
) {
108 // TODO(noamsml): Retry logic.
111 MDnsConnection::MDnsConnection(MDnsConnection::Delegate
* delegate
) :
112 delegate_(delegate
) {
115 MDnsConnection::~MDnsConnection() {
118 bool MDnsConnection::Init(MDnsSocketFactory
* socket_factory
) {
119 ScopedVector
<DatagramServerSocket
> sockets
;
120 socket_factory
->CreateSockets(&sockets
);
122 for (size_t i
= 0; i
< sockets
.size(); ++i
) {
123 socket_handlers_
.push_back(
124 new MDnsConnection::SocketHandler(make_scoped_ptr(sockets
[i
]), this));
126 sockets
.weak_clear();
128 // All unbound sockets need to be bound before processing untrusted input.
129 // This is done for security reasons, so that an attacker can't get an unbound
131 for (size_t i
= 0; i
< socket_handlers_
.size();) {
132 int rv
= socket_handlers_
[i
]->Start();
134 socket_handlers_
.erase(socket_handlers_
.begin() + i
);
135 VLOG(1) << "Start failed, socket=" << i
<< ", error=" << rv
;
140 VLOG(1) << "Sockets ready:" << socket_handlers_
.size();
141 return !socket_handlers_
.empty();
144 bool MDnsConnection::Send(IOBuffer
* buffer
, unsigned size
) {
145 bool success
= false;
146 for (size_t i
= 0; i
< socket_handlers_
.size(); ++i
) {
147 int rv
= socket_handlers_
[i
]->Send(buffer
, size
);
148 if (rv
>= OK
|| rv
== ERR_IO_PENDING
) {
151 VLOG(1) << "Send failed, socket=" << i
<< ", error=" << rv
;
157 void MDnsConnection::OnError(SocketHandler
* loop
,
159 // TODO(noamsml): Specific handling of intermittent errors that can be handled
160 // in the connection.
161 delegate_
->OnConnectionError(error
);
164 void MDnsConnection::OnDatagramReceived(
165 DnsResponse
* response
,
166 const IPEndPoint
& recv_addr
,
168 // TODO(noamsml): More sophisticated error handling.
169 DCHECK_GT(bytes_read
, 0);
170 delegate_
->HandlePacket(response
, bytes_read
);
173 MDnsClientImpl::Core::Core(MDnsClientImpl
* client
)
174 : client_(client
), connection_(new MDnsConnection(this)) {
177 MDnsClientImpl::Core::~Core() {
178 STLDeleteValues(&listeners_
);
181 bool MDnsClientImpl::Core::Init(MDnsSocketFactory
* socket_factory
) {
182 return connection_
->Init(socket_factory
);
185 bool MDnsClientImpl::Core::SendQuery(uint16 rrtype
, std::string name
) {
186 std::string name_dns
;
187 if (!DNSDomainFromDot(name
, &name_dns
))
190 DnsQuery
query(0, name_dns
, rrtype
);
191 query
.set_flags(0); // Remove the RD flag from the query. It is unneeded.
193 return connection_
->Send(query
.io_buffer(), query
.io_buffer()->size());
196 void MDnsClientImpl::Core::HandlePacket(DnsResponse
* response
,
199 // Note: We store cache keys rather than record pointers to avoid
200 // erroneous behavior in case a packet contains multiple exclusive
201 // records with the same type and name.
202 std::map
<MDnsCache::Key
, MDnsCache::UpdateType
> update_keys
;
204 if (!response
->InitParseWithoutQuery(bytes_read
)) {
205 DVLOG(1) << "Could not understand an mDNS packet.";
206 return; // Message is unreadable.
209 // TODO(noamsml): duplicate query suppression.
210 if (!(response
->flags() & dns_protocol::kFlagResponse
))
211 return; // Message is a query. ignore it.
213 DnsRecordParser parser
= response
->Parser();
214 unsigned answer_count
= response
->answer_count() +
215 response
->additional_answer_count();
217 for (unsigned i
= 0; i
< answer_count
; i
++) {
218 offset
= parser
.GetOffset();
219 scoped_ptr
<const RecordParsed
> record
= RecordParsed::CreateFrom(
220 &parser
, base::Time::Now());
223 DVLOG(1) << "Could not understand an mDNS record.";
225 if (offset
== parser
.GetOffset()) {
226 DVLOG(1) << "Abandoned parsing the rest of the packet.";
227 return; // The parser did not advance, abort reading the packet.
229 continue; // We may be able to extract other records from the packet.
233 if ((record
->klass() & dns_protocol::kMDnsClassMask
) !=
234 dns_protocol::kClassIN
) {
235 DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
236 continue; // Ignore all records not in the IN class.
239 MDnsCache::Key update_key
= MDnsCache::Key::CreateFor(record
.get());
240 MDnsCache::UpdateType update
= cache_
.UpdateDnsRecord(record
.Pass());
242 // Cleanup time may have changed.
243 ScheduleCleanup(cache_
.next_expiration());
245 update_keys
.insert(std::make_pair(update_key
, update
));
248 for (std::map
<MDnsCache::Key
, MDnsCache::UpdateType
>::iterator i
=
249 update_keys
.begin(); i
!= update_keys
.end(); i
++) {
250 const RecordParsed
* record
= cache_
.LookupKey(i
->first
);
254 if (record
->type() == dns_protocol::kTypeNSEC
) {
255 #if defined(ENABLE_NSEC)
256 NotifyNsecRecord(record
);
259 AlertListeners(i
->second
, ListenerKey(record
->name(), record
->type()),
265 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed
* record
) {
266 DCHECK_EQ(dns_protocol::kTypeNSEC
, record
->type());
267 const NsecRecordRdata
* rdata
= record
->rdata
<NsecRecordRdata
>();
270 // Remove all cached records matching the nonexistent RR types.
271 std::vector
<const RecordParsed
*> records_to_remove
;
273 cache_
.FindDnsRecords(0, record
->name(), &records_to_remove
,
276 for (std::vector
<const RecordParsed
*>::iterator i
= records_to_remove
.begin();
277 i
!= records_to_remove
.end(); i
++) {
278 if ((*i
)->type() == dns_protocol::kTypeNSEC
)
280 if (!rdata
->GetBit((*i
)->type())) {
281 scoped_ptr
<const RecordParsed
> record_removed
= cache_
.RemoveRecord((*i
));
282 DCHECK(record_removed
);
283 OnRecordRemoved(record_removed
.get());
287 // Alert all listeners waiting for the nonexistent RR types.
288 ListenerMap::iterator i
=
289 listeners_
.upper_bound(ListenerKey(record
->name(), 0));
290 for (; i
!= listeners_
.end() && i
->first
.first
== record
->name(); i
++) {
291 if (!rdata
->GetBit(i
->first
.second
)) {
292 FOR_EACH_OBSERVER(MDnsListenerImpl
, *i
->second
, AlertNsecRecord());
297 void MDnsClientImpl::Core::OnConnectionError(int error
) {
298 // TODO(noamsml): On connection error, recreate connection and flush cache.
301 void MDnsClientImpl::Core::AlertListeners(
302 MDnsCache::UpdateType update_type
,
303 const ListenerKey
& key
,
304 const RecordParsed
* record
) {
305 ListenerMap::iterator listener_map_iterator
= listeners_
.find(key
);
306 if (listener_map_iterator
== listeners_
.end()) return;
308 FOR_EACH_OBSERVER(MDnsListenerImpl
, *listener_map_iterator
->second
,
309 HandleRecordUpdate(update_type
, record
));
312 void MDnsClientImpl::Core::AddListener(
313 MDnsListenerImpl
* listener
) {
314 ListenerKey
key(listener
->GetName(), listener
->GetType());
315 std::pair
<ListenerMap::iterator
, bool> observer_insert_result
=
317 make_pair(key
, static_cast<ObserverList
<MDnsListenerImpl
>*>(NULL
)));
319 // If an equivalent key does not exist, actually create the observer list.
320 if (observer_insert_result
.second
)
321 observer_insert_result
.first
->second
= new ObserverList
<MDnsListenerImpl
>();
323 ObserverList
<MDnsListenerImpl
>* observer_list
=
324 observer_insert_result
.first
->second
;
326 observer_list
->AddObserver(listener
);
329 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl
* listener
) {
330 ListenerKey
key(listener
->GetName(), listener
->GetType());
331 ListenerMap::iterator observer_list_iterator
= listeners_
.find(key
);
333 DCHECK(observer_list_iterator
!= listeners_
.end());
334 DCHECK(observer_list_iterator
->second
->HasObserver(listener
));
336 observer_list_iterator
->second
->RemoveObserver(listener
);
338 // Remove the observer list from the map if it is empty
339 if (!observer_list_iterator
->second
->might_have_observers()) {
340 // Schedule the actual removal for later in case the listener removal
341 // happens while iterating over the observer list.
342 base::MessageLoop::current()->PostTask(
343 FROM_HERE
, base::Bind(
344 &MDnsClientImpl::Core::CleanupObserverList
, AsWeakPtr(), key
));
348 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey
& key
) {
349 ListenerMap::iterator found
= listeners_
.find(key
);
350 if (found
!= listeners_
.end() && !found
->second
->might_have_observers()) {
351 delete found
->second
;
352 listeners_
.erase(found
);
356 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup
) {
357 // Cleanup is already scheduled, no need to do anything.
358 if (cleanup
== scheduled_cleanup_
) return;
359 scheduled_cleanup_
= cleanup
;
361 // This cancels the previously scheduled cleanup.
362 cleanup_callback_
.Reset(base::Bind(
363 &MDnsClientImpl::Core::DoCleanup
, base::Unretained(this)));
365 // If |cleanup| is empty, then no cleanup necessary.
366 if (cleanup
!= base::Time()) {
367 base::MessageLoop::current()->PostDelayedTask(
369 cleanup_callback_
.callback(),
370 cleanup
- base::Time::Now());
374 void MDnsClientImpl::Core::DoCleanup() {
375 cache_
.CleanupRecords(base::Time::Now(), base::Bind(
376 &MDnsClientImpl::Core::OnRecordRemoved
, base::Unretained(this)));
378 ScheduleCleanup(cache_
.next_expiration());
381 void MDnsClientImpl::Core::OnRecordRemoved(
382 const RecordParsed
* record
) {
383 AlertListeners(MDnsCache::RecordRemoved
,
384 ListenerKey(record
->name(), record
->type()), record
);
387 void MDnsClientImpl::Core::QueryCache(
388 uint16 rrtype
, const std::string
& name
,
389 std::vector
<const RecordParsed
*>* records
) const {
390 cache_
.FindDnsRecords(rrtype
, name
, records
, base::Time::Now());
393 MDnsClientImpl::MDnsClientImpl() {
396 MDnsClientImpl::~MDnsClientImpl() {
399 bool MDnsClientImpl::StartListening(MDnsSocketFactory
* socket_factory
) {
400 DCHECK(!core_
.get());
401 core_
.reset(new Core(this));
402 if (!core_
->Init(socket_factory
)) {
409 void MDnsClientImpl::StopListening() {
413 bool MDnsClientImpl::IsListening() const {
414 return core_
.get() != NULL
;
417 scoped_ptr
<MDnsListener
> MDnsClientImpl::CreateListener(
419 const std::string
& name
,
420 MDnsListener::Delegate
* delegate
) {
421 return scoped_ptr
<net::MDnsListener
>(
422 new MDnsListenerImpl(rrtype
, name
, delegate
, this));
425 scoped_ptr
<MDnsTransaction
> MDnsClientImpl::CreateTransaction(
427 const std::string
& name
,
429 const MDnsTransaction::ResultCallback
& callback
) {
430 return scoped_ptr
<MDnsTransaction
>(
431 new MDnsTransactionImpl(rrtype
, name
, flags
, callback
, this));
434 MDnsListenerImpl::MDnsListenerImpl(
436 const std::string
& name
,
437 MDnsListener::Delegate
* delegate
,
438 MDnsClientImpl
* client
)
439 : rrtype_(rrtype
), name_(name
), client_(client
), delegate_(delegate
),
440 started_(false), active_refresh_(false) {
443 MDnsListenerImpl::~MDnsListenerImpl() {
445 DCHECK(client_
->core());
446 client_
->core()->RemoveListener(this);
450 bool MDnsListenerImpl::Start() {
455 DCHECK(client_
->core());
456 client_
->core()->AddListener(this);
461 void MDnsListenerImpl::SetActiveRefresh(bool active_refresh
) {
462 active_refresh_
= active_refresh
;
465 if (!active_refresh_
) {
466 next_refresh_
.Cancel();
467 } else if (last_update_
!= base::Time()) {
468 ScheduleNextRefresh();
473 const std::string
& MDnsListenerImpl::GetName() const {
477 uint16
MDnsListenerImpl::GetType() const {
481 void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type
,
482 const RecordParsed
* record
) {
485 if (update_type
!= MDnsCache::RecordRemoved
) {
486 ttl_
= record
->ttl();
487 last_update_
= record
->time_created();
489 ScheduleNextRefresh();
492 if (update_type
!= MDnsCache::NoChange
) {
493 MDnsListener::UpdateType update_external
;
495 switch (update_type
) {
496 case MDnsCache::RecordAdded
:
497 update_external
= MDnsListener::RECORD_ADDED
;
499 case MDnsCache::RecordChanged
:
500 update_external
= MDnsListener::RECORD_CHANGED
;
502 case MDnsCache::RecordRemoved
:
503 update_external
= MDnsListener::RECORD_REMOVED
;
505 case MDnsCache::NoChange
:
508 // Dummy assignment to suppress compiler warning.
509 update_external
= MDnsListener::RECORD_CHANGED
;
513 delegate_
->OnRecordUpdate(update_external
, record
);
517 void MDnsListenerImpl::AlertNsecRecord() {
519 delegate_
->OnNsecRecord(name_
, rrtype_
);
522 void MDnsListenerImpl::ScheduleNextRefresh() {
523 DCHECK(last_update_
!= base::Time());
525 if (!active_refresh_
)
528 // A zero TTL is a goodbye packet and should not be refreshed.
530 next_refresh_
.Cancel();
534 next_refresh_
.Reset(base::Bind(&MDnsListenerImpl::DoRefresh
,
537 // Schedule refreshes at both 85% and 95% of the original TTL. These will both
538 // be canceled and rescheduled if the record's TTL is updated due to a
539 // response being received.
540 base::Time next_refresh1
= last_update_
+ base::TimeDelta::FromMilliseconds(
541 static_cast<int>(kMillisecondsPerSecond
*
542 kListenerRefreshRatio1
* ttl_
));
544 base::Time next_refresh2
= last_update_
+ base::TimeDelta::FromMilliseconds(
545 static_cast<int>(kMillisecondsPerSecond
*
546 kListenerRefreshRatio2
* ttl_
));
548 base::MessageLoop::current()->PostDelayedTask(
550 next_refresh_
.callback(),
551 next_refresh1
- base::Time::Now());
553 base::MessageLoop::current()->PostDelayedTask(
555 next_refresh_
.callback(),
556 next_refresh2
- base::Time::Now());
559 void MDnsListenerImpl::DoRefresh() {
560 client_
->core()->SendQuery(rrtype_
, name_
);
563 MDnsTransactionImpl::MDnsTransactionImpl(
565 const std::string
& name
,
567 const MDnsTransaction::ResultCallback
& callback
,
568 MDnsClientImpl
* client
)
569 : rrtype_(rrtype
), name_(name
), callback_(callback
), client_(client
),
570 started_(false), flags_(flags
) {
571 DCHECK((flags_
& MDnsTransaction::FLAG_MASK
) == flags_
);
572 DCHECK(flags_
& MDnsTransaction::QUERY_CACHE
||
573 flags_
& MDnsTransaction::QUERY_NETWORK
);
576 MDnsTransactionImpl::~MDnsTransactionImpl() {
580 bool MDnsTransactionImpl::Start() {
584 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
585 if (flags_
& MDnsTransaction::QUERY_CACHE
) {
586 ServeRecordsFromCache();
588 if (!weak_this
|| !is_active()) return true;
591 if (flags_
& MDnsTransaction::QUERY_NETWORK
) {
592 return QueryAndListen();
595 // If this is a cache only query, signal that the transaction is over
597 SignalTransactionOver();
601 const std::string
& MDnsTransactionImpl::GetName() const {
605 uint16
MDnsTransactionImpl::GetType() const {
609 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed
* record
) {
611 OnRecordUpdate(MDnsListener::RECORD_ADDED
, record
);
614 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result
,
615 const RecordParsed
* record
) {
617 if (!is_active()) return;
619 // Ensure callback is run after touching all class state, so that
620 // the callback can delete the transaction.
621 MDnsTransaction::ResultCallback callback
= callback_
;
623 // Reset the transaction if it expects a single result, or if the result
624 // is a final one (everything except for a record).
625 if (flags_
& MDnsTransaction::SINGLE_RESULT
||
626 result
!= MDnsTransaction::RESULT_RECORD
) {
630 callback
.Run(result
, record
);
633 void MDnsTransactionImpl::Reset() {
639 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update
,
640 const RecordParsed
* record
) {
642 if (update
== MDnsListener::RECORD_ADDED
||
643 update
== MDnsListener::RECORD_CHANGED
)
644 TriggerCallback(MDnsTransaction::RESULT_RECORD
, record
);
647 void MDnsTransactionImpl::SignalTransactionOver() {
649 if (flags_
& MDnsTransaction::SINGLE_RESULT
) {
650 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS
, NULL
);
652 TriggerCallback(MDnsTransaction::RESULT_DONE
, NULL
);
656 void MDnsTransactionImpl::ServeRecordsFromCache() {
657 std::vector
<const RecordParsed
*> records
;
658 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
660 if (client_
->core()) {
661 client_
->core()->QueryCache(rrtype_
, name_
, &records
);
662 for (std::vector
<const RecordParsed
*>::iterator i
= records
.begin();
663 i
!= records
.end() && weak_this
; ++i
) {
664 weak_this
->TriggerCallback(MDnsTransaction::RESULT_RECORD
, *i
);
667 #if defined(ENABLE_NSEC)
668 if (records
.empty()) {
670 client_
->core()->QueryCache(dns_protocol::kTypeNSEC
, name_
, &records
);
671 if (!records
.empty()) {
672 const NsecRecordRdata
* rdata
=
673 records
.front()->rdata
<NsecRecordRdata
>();
675 if (!rdata
->GetBit(rrtype_
))
676 weak_this
->TriggerCallback(MDnsTransaction::RESULT_NSEC
, NULL
);
683 bool MDnsTransactionImpl::QueryAndListen() {
684 listener_
= client_
->CreateListener(rrtype_
, name_
, this);
685 if (!listener_
->Start())
688 DCHECK(client_
->core());
689 if (!client_
->core()->SendQuery(rrtype_
, name_
))
692 timeout_
.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver
,
694 base::MessageLoop::current()->PostDelayedTask(
697 base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds
));
702 void MDnsTransactionImpl::OnNsecRecord(const std::string
& name
, unsigned type
) {
703 TriggerCallback(RESULT_NSEC
, NULL
);
706 void MDnsTransactionImpl::OnCachePurged() {
707 // TODO(noamsml): Cache purge situations not yet implemented