1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/mdns_client_impl.h"
8 #include "base/message_loop/message_loop_proxy.h"
9 #include "base/stl_util.h"
10 #include "base/time/default_clock.h"
11 #include "net/base/dns_util.h"
12 #include "net/base/net_errors.h"
13 #include "net/base/net_log.h"
14 #include "net/base/rand_callback.h"
15 #include "net/dns/dns_protocol.h"
16 #include "net/dns/record_rdata.h"
17 #include "net/udp/datagram_socket.h"
19 // TODO(gene): Remove this temporary method of disabling NSEC support once it
20 // becomes clear whether this feature should be
21 // supported. http://crbug.com/255232
28 const unsigned MDnsTransactionTimeoutSeconds
= 3;
32 void MDnsSocketFactoryImpl::CreateSockets(
33 ScopedVector
<DatagramServerSocket
>* sockets
) {
34 InterfaceIndexFamilyList
interfaces(GetMDnsInterfacesToBind());
35 for (size_t i
= 0; i
< interfaces
.size(); ++i
) {
36 DCHECK(interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV4
||
37 interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV6
);
38 scoped_ptr
<DatagramServerSocket
> socket(
39 CreateAndBindMDnsSocket(interfaces
[i
].second
, interfaces
[i
].first
));
41 sockets
->push_back(socket
.release());
45 MDnsConnection::SocketHandler::SocketHandler(
46 scoped_ptr
<DatagramServerSocket
> socket
,
47 MDnsConnection
* connection
)
48 : socket_(socket
.Pass()),
49 connection_(connection
),
50 response_(dns_protocol::kMaxMulticastSize
) {
53 MDnsConnection::SocketHandler::~SocketHandler() {
56 int MDnsConnection::SocketHandler::Start() {
58 int rv
= socket_
->GetLocalAddress(&end_point
);
61 DCHECK(end_point
.GetFamily() == ADDRESS_FAMILY_IPV4
||
62 end_point
.GetFamily() == ADDRESS_FAMILY_IPV6
);
63 multicast_addr_
= GetMDnsIPEndPoint(end_point
.GetFamily());
67 int MDnsConnection::SocketHandler::DoLoop(int rv
) {
70 connection_
->OnDatagramReceived(&response_
, recv_addr_
, rv
);
72 rv
= socket_
->RecvFrom(
73 response_
.io_buffer(),
74 response_
.io_buffer()->size(),
76 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived
,
77 base::Unretained(this)));
80 if (rv
!= ERR_IO_PENDING
)
86 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv
) {
91 connection_
->OnError(this, rv
);
94 int MDnsConnection::SocketHandler::Send(IOBuffer
* buffer
, unsigned size
) {
95 return socket_
->SendTo(buffer
, size
, multicast_addr_
,
96 base::Bind(&MDnsConnection::SocketHandler::SendDone
,
97 base::Unretained(this) ));
100 void MDnsConnection::SocketHandler::SendDone(int rv
) {
101 // TODO(noamsml): Retry logic.
104 MDnsConnection::MDnsConnection(MDnsConnection::Delegate
* delegate
) :
105 delegate_(delegate
) {
108 MDnsConnection::~MDnsConnection() {
111 bool MDnsConnection::Init(MDnsSocketFactory
* socket_factory
) {
112 ScopedVector
<DatagramServerSocket
> sockets
;
113 socket_factory
->CreateSockets(&sockets
);
115 for (size_t i
= 0; i
< sockets
.size(); ++i
) {
116 socket_handlers_
.push_back(
117 new MDnsConnection::SocketHandler(make_scoped_ptr(sockets
[i
]), this));
119 sockets
.weak_clear();
121 // All unbound sockets need to be bound before processing untrusted input.
122 // This is done for security reasons, so that an attacker can't get an unbound
124 for (size_t i
= 0; i
< socket_handlers_
.size();) {
125 int rv
= socket_handlers_
[i
]->Start();
127 socket_handlers_
.erase(socket_handlers_
.begin() + i
);
128 VLOG(1) << "Start failed, socket=" << i
<< ", error=" << rv
;
133 VLOG(1) << "Sockets ready:" << socket_handlers_
.size();
134 return !socket_handlers_
.empty();
137 bool MDnsConnection::Send(IOBuffer
* buffer
, unsigned size
) {
138 bool success
= false;
139 for (size_t i
= 0; i
< socket_handlers_
.size(); ++i
) {
140 int rv
= socket_handlers_
[i
]->Send(buffer
, size
);
141 if (rv
>= OK
|| rv
== ERR_IO_PENDING
) {
144 VLOG(1) << "Send failed, socket=" << i
<< ", error=" << rv
;
150 void MDnsConnection::OnError(SocketHandler
* loop
,
152 // TODO(noamsml): Specific handling of intermittent errors that can be handled
153 // in the connection.
154 delegate_
->OnConnectionError(error
);
157 void MDnsConnection::OnDatagramReceived(
158 DnsResponse
* response
,
159 const IPEndPoint
& recv_addr
,
161 // TODO(noamsml): More sophisticated error handling.
162 DCHECK_GT(bytes_read
, 0);
163 delegate_
->HandlePacket(response
, bytes_read
);
166 MDnsClientImpl::Core::Core(MDnsClientImpl
* client
)
167 : client_(client
), connection_(new MDnsConnection(this)) {
170 MDnsClientImpl::Core::~Core() {
171 STLDeleteValues(&listeners_
);
174 bool MDnsClientImpl::Core::Init(MDnsSocketFactory
* socket_factory
) {
175 return connection_
->Init(socket_factory
);
178 bool MDnsClientImpl::Core::SendQuery(uint16 rrtype
, std::string name
) {
179 std::string name_dns
;
180 if (!DNSDomainFromDot(name
, &name_dns
))
183 DnsQuery
query(0, name_dns
, rrtype
);
184 query
.set_flags(0); // Remove the RD flag from the query. It is unneeded.
186 return connection_
->Send(query
.io_buffer(), query
.io_buffer()->size());
189 void MDnsClientImpl::Core::HandlePacket(DnsResponse
* response
,
192 // Note: We store cache keys rather than record pointers to avoid
193 // erroneous behavior in case a packet contains multiple exclusive
194 // records with the same type and name.
195 std::map
<MDnsCache::Key
, MDnsListener::UpdateType
> update_keys
;
197 if (!response
->InitParseWithoutQuery(bytes_read
)) {
198 LOG(WARNING
) << "Could not understand an mDNS packet.";
199 return; // Message is unreadable.
202 // TODO(noamsml): duplicate query suppression.
203 if (!(response
->flags() & dns_protocol::kFlagResponse
))
204 return; // Message is a query. ignore it.
206 DnsRecordParser parser
= response
->Parser();
207 unsigned answer_count
= response
->answer_count() +
208 response
->additional_answer_count();
210 for (unsigned i
= 0; i
< answer_count
; i
++) {
211 offset
= parser
.GetOffset();
212 scoped_ptr
<const RecordParsed
> record
= RecordParsed::CreateFrom(
213 &parser
, base::Time::Now());
216 LOG(WARNING
) << "Could not understand an mDNS record.";
218 if (offset
== parser
.GetOffset()) {
219 LOG(WARNING
) << "Abandoned parsing the rest of the packet.";
220 return; // The parser did not advance, abort reading the packet.
222 continue; // We may be able to extract other records from the packet.
226 if ((record
->klass() & dns_protocol::kMDnsClassMask
) !=
227 dns_protocol::kClassIN
) {
228 LOG(WARNING
) << "Received an mDNS record with non-IN class. Ignoring.";
229 continue; // Ignore all records not in the IN class.
232 MDnsCache::Key update_key
= MDnsCache::Key::CreateFor(record
.get());
233 MDnsCache::UpdateType update
= cache_
.UpdateDnsRecord(record
.Pass());
235 // Cleanup time may have changed.
236 ScheduleCleanup(cache_
.next_expiration());
238 if (update
!= MDnsCache::NoChange
) {
239 MDnsListener::UpdateType update_external
;
242 case MDnsCache::RecordAdded
:
243 update_external
= MDnsListener::RECORD_ADDED
;
245 case MDnsCache::RecordChanged
:
246 update_external
= MDnsListener::RECORD_CHANGED
;
248 case MDnsCache::NoChange
:
251 // Dummy assignment to suppress compiler warning.
252 update_external
= MDnsListener::RECORD_CHANGED
;
256 update_keys
.insert(std::make_pair(update_key
, update_external
));
260 for (std::map
<MDnsCache::Key
, MDnsListener::UpdateType
>::iterator i
=
261 update_keys
.begin(); i
!= update_keys
.end(); i
++) {
262 const RecordParsed
* record
= cache_
.LookupKey(i
->first
);
266 if (record
->type() == dns_protocol::kTypeNSEC
) {
267 #if defined(ENABLE_NSEC)
268 NotifyNsecRecord(record
);
271 AlertListeners(i
->second
, ListenerKey(record
->name(), record
->type()),
277 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed
* record
) {
278 DCHECK_EQ(dns_protocol::kTypeNSEC
, record
->type());
279 const NsecRecordRdata
* rdata
= record
->rdata
<NsecRecordRdata
>();
282 // Remove all cached records matching the nonexistent RR types.
283 std::vector
<const RecordParsed
*> records_to_remove
;
285 cache_
.FindDnsRecords(0, record
->name(), &records_to_remove
,
288 for (std::vector
<const RecordParsed
*>::iterator i
= records_to_remove
.begin();
289 i
!= records_to_remove
.end(); i
++) {
290 if ((*i
)->type() == dns_protocol::kTypeNSEC
)
292 if (!rdata
->GetBit((*i
)->type())) {
293 scoped_ptr
<const RecordParsed
> record_removed
= cache_
.RemoveRecord((*i
));
294 DCHECK(record_removed
);
295 OnRecordRemoved(record_removed
.get());
299 // Alert all listeners waiting for the nonexistent RR types.
300 ListenerMap::iterator i
=
301 listeners_
.upper_bound(ListenerKey(record
->name(), 0));
302 for (; i
!= listeners_
.end() && i
->first
.first
== record
->name(); i
++) {
303 if (!rdata
->GetBit(i
->first
.second
)) {
304 FOR_EACH_OBSERVER(MDnsListenerImpl
, *i
->second
, AlertNsecRecord());
309 void MDnsClientImpl::Core::OnConnectionError(int error
) {
310 // TODO(noamsml): On connection error, recreate connection and flush cache.
313 void MDnsClientImpl::Core::AlertListeners(
314 MDnsListener::UpdateType update_type
,
315 const ListenerKey
& key
,
316 const RecordParsed
* record
) {
317 ListenerMap::iterator listener_map_iterator
= listeners_
.find(key
);
318 if (listener_map_iterator
== listeners_
.end()) return;
320 FOR_EACH_OBSERVER(MDnsListenerImpl
, *listener_map_iterator
->second
,
321 AlertDelegate(update_type
, record
));
324 void MDnsClientImpl::Core::AddListener(
325 MDnsListenerImpl
* listener
) {
326 ListenerKey
key(listener
->GetName(), listener
->GetType());
327 std::pair
<ListenerMap::iterator
, bool> observer_insert_result
=
329 make_pair(key
, static_cast<ObserverList
<MDnsListenerImpl
>*>(NULL
)));
331 // If an equivalent key does not exist, actually create the observer list.
332 if (observer_insert_result
.second
)
333 observer_insert_result
.first
->second
= new ObserverList
<MDnsListenerImpl
>();
335 ObserverList
<MDnsListenerImpl
>* observer_list
=
336 observer_insert_result
.first
->second
;
338 observer_list
->AddObserver(listener
);
341 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl
* listener
) {
342 ListenerKey
key(listener
->GetName(), listener
->GetType());
343 ListenerMap::iterator observer_list_iterator
= listeners_
.find(key
);
345 DCHECK(observer_list_iterator
!= listeners_
.end());
346 DCHECK(observer_list_iterator
->second
->HasObserver(listener
));
348 observer_list_iterator
->second
->RemoveObserver(listener
);
350 // Remove the observer list from the map if it is empty
351 if (!observer_list_iterator
->second
->might_have_observers()) {
352 // Schedule the actual removal for later in case the listener removal
353 // happens while iterating over the observer list.
354 base::MessageLoop::current()->PostTask(
355 FROM_HERE
, base::Bind(
356 &MDnsClientImpl::Core::CleanupObserverList
, AsWeakPtr(), key
));
360 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey
& key
) {
361 ListenerMap::iterator found
= listeners_
.find(key
);
362 if (found
!= listeners_
.end() && !found
->second
->might_have_observers()) {
363 delete found
->second
;
364 listeners_
.erase(found
);
368 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup
) {
369 // Cleanup is already scheduled, no need to do anything.
370 if (cleanup
== scheduled_cleanup_
) return;
371 scheduled_cleanup_
= cleanup
;
373 // This cancels the previously scheduled cleanup.
374 cleanup_callback_
.Reset(base::Bind(
375 &MDnsClientImpl::Core::DoCleanup
, base::Unretained(this)));
377 // If |cleanup| is empty, then no cleanup necessary.
378 if (cleanup
!= base::Time()) {
379 base::MessageLoop::current()->PostDelayedTask(
381 cleanup_callback_
.callback(),
382 cleanup
- base::Time::Now());
386 void MDnsClientImpl::Core::DoCleanup() {
387 cache_
.CleanupRecords(base::Time::Now(), base::Bind(
388 &MDnsClientImpl::Core::OnRecordRemoved
, base::Unretained(this)));
390 ScheduleCleanup(cache_
.next_expiration());
393 void MDnsClientImpl::Core::OnRecordRemoved(
394 const RecordParsed
* record
) {
395 AlertListeners(MDnsListener::RECORD_REMOVED
,
396 ListenerKey(record
->name(), record
->type()), record
);
399 void MDnsClientImpl::Core::QueryCache(
400 uint16 rrtype
, const std::string
& name
,
401 std::vector
<const RecordParsed
*>* records
) const {
402 cache_
.FindDnsRecords(rrtype
, name
, records
, base::Time::Now());
405 MDnsClientImpl::MDnsClientImpl() {
408 MDnsClientImpl::~MDnsClientImpl() {
411 bool MDnsClientImpl::StartListening(MDnsSocketFactory
* socket_factory
) {
412 DCHECK(!core_
.get());
413 core_
.reset(new Core(this));
414 if (!core_
->Init(socket_factory
)) {
421 void MDnsClientImpl::StopListening() {
425 bool MDnsClientImpl::IsListening() const {
426 return core_
.get() != NULL
;
429 scoped_ptr
<MDnsListener
> MDnsClientImpl::CreateListener(
431 const std::string
& name
,
432 MDnsListener::Delegate
* delegate
) {
433 return scoped_ptr
<net::MDnsListener
>(
434 new MDnsListenerImpl(rrtype
, name
, delegate
, this));
437 scoped_ptr
<MDnsTransaction
> MDnsClientImpl::CreateTransaction(
439 const std::string
& name
,
441 const MDnsTransaction::ResultCallback
& callback
) {
442 return scoped_ptr
<MDnsTransaction
>(
443 new MDnsTransactionImpl(rrtype
, name
, flags
, callback
, this));
446 MDnsListenerImpl::MDnsListenerImpl(
448 const std::string
& name
,
449 MDnsListener::Delegate
* delegate
,
450 MDnsClientImpl
* client
)
451 : rrtype_(rrtype
), name_(name
), client_(client
), delegate_(delegate
),
455 bool MDnsListenerImpl::Start() {
460 DCHECK(client_
->core());
461 client_
->core()->AddListener(this);
466 MDnsListenerImpl::~MDnsListenerImpl() {
468 DCHECK(client_
->core());
469 client_
->core()->RemoveListener(this);
473 const std::string
& MDnsListenerImpl::GetName() const {
477 uint16
MDnsListenerImpl::GetType() const {
481 void MDnsListenerImpl::AlertDelegate(MDnsListener::UpdateType update_type
,
482 const RecordParsed
* record
) {
484 delegate_
->OnRecordUpdate(update_type
, record
);
487 void MDnsListenerImpl::AlertNsecRecord() {
489 delegate_
->OnNsecRecord(name_
, rrtype_
);
492 MDnsTransactionImpl::MDnsTransactionImpl(
494 const std::string
& name
,
496 const MDnsTransaction::ResultCallback
& callback
,
497 MDnsClientImpl
* client
)
498 : rrtype_(rrtype
), name_(name
), callback_(callback
), client_(client
),
499 started_(false), flags_(flags
) {
500 DCHECK((flags_
& MDnsTransaction::FLAG_MASK
) == flags_
);
501 DCHECK(flags_
& MDnsTransaction::QUERY_CACHE
||
502 flags_
& MDnsTransaction::QUERY_NETWORK
);
505 MDnsTransactionImpl::~MDnsTransactionImpl() {
509 bool MDnsTransactionImpl::Start() {
513 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
514 if (flags_
& MDnsTransaction::QUERY_CACHE
) {
515 ServeRecordsFromCache();
517 if (!weak_this
|| !is_active()) return true;
520 if (flags_
& MDnsTransaction::QUERY_NETWORK
) {
521 return QueryAndListen();
524 // If this is a cache only query, signal that the transaction is over
526 SignalTransactionOver();
530 const std::string
& MDnsTransactionImpl::GetName() const {
534 uint16
MDnsTransactionImpl::GetType() const {
538 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed
* record
) {
540 OnRecordUpdate(MDnsListener::RECORD_ADDED
, record
);
543 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result
,
544 const RecordParsed
* record
) {
546 if (!is_active()) return;
548 // Ensure callback is run after touching all class state, so that
549 // the callback can delete the transaction.
550 MDnsTransaction::ResultCallback callback
= callback_
;
552 // Reset the transaction if it expects a single result, or if the result
553 // is a final one (everything except for a record).
554 if (flags_
& MDnsTransaction::SINGLE_RESULT
||
555 result
!= MDnsTransaction::RESULT_RECORD
) {
559 callback
.Run(result
, record
);
562 void MDnsTransactionImpl::Reset() {
568 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update
,
569 const RecordParsed
* record
) {
571 if (update
== MDnsListener::RECORD_ADDED
||
572 update
== MDnsListener::RECORD_CHANGED
)
573 TriggerCallback(MDnsTransaction::RESULT_RECORD
, record
);
576 void MDnsTransactionImpl::SignalTransactionOver() {
578 if (flags_
& MDnsTransaction::SINGLE_RESULT
) {
579 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS
, NULL
);
581 TriggerCallback(MDnsTransaction::RESULT_DONE
, NULL
);
585 void MDnsTransactionImpl::ServeRecordsFromCache() {
586 std::vector
<const RecordParsed
*> records
;
587 base::WeakPtr
<MDnsTransactionImpl
> weak_this
= AsWeakPtr();
589 if (client_
->core()) {
590 client_
->core()->QueryCache(rrtype_
, name_
, &records
);
591 for (std::vector
<const RecordParsed
*>::iterator i
= records
.begin();
592 i
!= records
.end() && weak_this
; ++i
) {
593 weak_this
->TriggerCallback(MDnsTransaction::RESULT_RECORD
, *i
);
596 #if defined(ENABLE_NSEC)
597 if (records
.empty()) {
599 client_
->core()->QueryCache(dns_protocol::kTypeNSEC
, name_
, &records
);
600 if (!records
.empty()) {
601 const NsecRecordRdata
* rdata
=
602 records
.front()->rdata
<NsecRecordRdata
>();
604 if (!rdata
->GetBit(rrtype_
))
605 weak_this
->TriggerCallback(MDnsTransaction::RESULT_NSEC
, NULL
);
612 bool MDnsTransactionImpl::QueryAndListen() {
613 listener_
= client_
->CreateListener(rrtype_
, name_
, this);
614 if (!listener_
->Start())
617 DCHECK(client_
->core());
618 if (!client_
->core()->SendQuery(rrtype_
, name_
))
621 timeout_
.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver
,
623 base::MessageLoop::current()->PostDelayedTask(
626 base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds
));
631 void MDnsTransactionImpl::OnNsecRecord(const std::string
& name
, unsigned type
) {
632 TriggerCallback(RESULT_NSEC
, NULL
);
635 void MDnsTransactionImpl::OnCachePurged() {
636 // TODO(noamsml): Cache purge situations not yet implemented