1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
#include "chrome/utility/local_discovery/service_discovery_message_handler.h"

#include <algorithm>

#include "base/lazy_instance.h"
#include "chrome/common/local_discovery/local_discovery_messages.h"
#include "chrome/utility/local_discovery/service_discovery_client_impl.h"
#include "content/public/utility/utility_thread.h"
#include "net/socket/socket_descriptor.h"
#include "net/udp/datagram_server_socket.h"
16 namespace local_discovery
{
20 void ClosePlatformSocket(net::SocketDescriptor socket
);
22 // Sets socket factory used by |net::CreatePlatformSocket|. Implemetation
23 // keeps single socket that will be returned to the first call to
24 // |net::CreatePlatformSocket| during object lifetime.
25 class ScopedSocketFactory
: public net::PlatformSocketFactory
{
27 explicit ScopedSocketFactory(net::SocketDescriptor socket
) : socket_(socket
) {
28 net::PlatformSocketFactory::SetInstance(this);
31 virtual ~ScopedSocketFactory() {
32 net::PlatformSocketFactory::SetInstance(NULL
);
33 ClosePlatformSocket(socket_
);
34 socket_
= net::kInvalidSocket
;
37 virtual net::SocketDescriptor
CreateSocket(int family
, int type
,
38 int protocol
) OVERRIDE
{
39 DCHECK_EQ(type
, SOCK_DGRAM
);
40 DCHECK(family
== AF_INET
|| family
== AF_INET6
);
41 net::SocketDescriptor result
= net::kInvalidSocket
;
42 std::swap(result
, socket_
);
47 net::SocketDescriptor socket_
;
48 DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory
);
52 SocketInfo(net::SocketDescriptor socket
,
53 net::AddressFamily address_family
,
54 uint32 interface_index
)
56 address_family(address_family
),
57 interface_index(interface_index
) {
59 net::SocketDescriptor socket
;
60 net::AddressFamily address_family
;
61 uint32 interface_index
;
64 // Returns list of sockets preallocated before.
65 class PreCreatedMDnsSocketFactory
: public net::MDnsSocketFactory
{
67 PreCreatedMDnsSocketFactory() {}
68 virtual ~PreCreatedMDnsSocketFactory() {
72 // net::MDnsSocketFactory implementation:
73 virtual void CreateSockets(
74 ScopedVector
<net::DatagramServerSocket
>* sockets
) OVERRIDE
{
75 for (size_t i
= 0; i
< sockets_
.size(); ++i
) {
76 // Takes ownership of sockets_[i].socket;
77 ScopedSocketFactory
platform_factory(sockets_
[i
].socket
);
78 scoped_ptr
<net::DatagramServerSocket
> socket(
79 net::CreateAndBindMDnsSocket(sockets_
[i
].address_family
,
80 sockets_
[i
].interface_index
));
82 sockets
->push_back(socket
.release());
87 void AddSocket(const SocketInfo
& socket
) {
88 sockets_
.push_back(socket
);
92 for (size_t i
= 0; i
< sockets_
.size(); ++i
) {
93 if (sockets_
[i
].socket
!= net::kInvalidSocket
)
94 ClosePlatformSocket(sockets_
[i
].socket
);
100 std::vector
<SocketInfo
> sockets_
;
102 DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory
);
105 base::LazyInstance
<PreCreatedMDnsSocketFactory
>
106 g_local_discovery_socket_factory
= LAZY_INSTANCE_INITIALIZER
;
110 void ClosePlatformSocket(net::SocketDescriptor socket
) {
111 ::closesocket(socket
);
114 void StaticInitializeSocketFactory() {
115 net::InterfaceIndexFamilyList
interfaces(net::GetMDnsInterfacesToBind());
116 for (size_t i
= 0; i
< interfaces
.size(); ++i
) {
117 DCHECK(interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV4
||
118 interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV6
);
119 net::SocketDescriptor descriptor
=
120 net::CreatePlatformSocket(
121 net::ConvertAddressFamily(interfaces
[i
].second
), SOCK_DGRAM
,
123 g_local_discovery_socket_factory
.Get().AddSocket(
124 SocketInfo(descriptor
, interfaces
[i
].second
, interfaces
[i
].first
));
130 void ClosePlatformSocket(net::SocketDescriptor socket
) {
134 void StaticInitializeSocketFactory() {
139 void SendHostMessageOnUtilityThread(IPC::Message
* msg
) {
140 content::UtilityThread::Get()->Send(msg
);
143 std::string
WatcherUpdateToString(ServiceWatcher::UpdateType update
) {
145 case ServiceWatcher::UPDATE_ADDED
:
146 return "UPDATE_ADDED";
147 case ServiceWatcher::UPDATE_CHANGED
:
148 return "UPDATE_CHANGED";
149 case ServiceWatcher::UPDATE_REMOVED
:
150 return "UPDATE_REMOVED";
151 case ServiceWatcher::UPDATE_INVALIDATED
:
152 return "UPDATE_INVALIDATED";
154 return "Unknown Update";
157 std::string
ResolverStatusToString(ServiceResolver::RequestStatus status
) {
159 case ServiceResolver::STATUS_SUCCESS
:
160 return "STATUS_SUCESS";
161 case ServiceResolver::STATUS_REQUEST_TIMEOUT
:
162 return "STATUS_REQUEST_TIMEOUT";
163 case ServiceResolver::STATUS_KNOWN_NONEXISTENT
:
164 return "STATUS_KNOWN_NONEXISTENT";
166 return "Unknown Status";
171 ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
174 ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
175 DCHECK(!discovery_thread_
);
178 void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
179 StaticInitializeSocketFactory();
182 void ServiceDiscoveryMessageHandler::InitializeMdns() {
183 if (service_discovery_client_
|| mdns_client_
)
186 mdns_client_
= net::MDnsClient::CreateDefault();
188 mdns_client_
->StartListening(g_local_discovery_socket_factory
.Pointer());
189 // Close unused sockets.
190 g_local_discovery_socket_factory
.Get().Reset();
192 VLOG(1) << "Failed to start MDnsClient";
193 Send(new LocalDiscoveryHostMsg_Error());
197 service_discovery_client_
.reset(
198 new local_discovery::ServiceDiscoveryClientImpl(mdns_client_
.get()));
201 bool ServiceDiscoveryMessageHandler::InitializeThread() {
202 if (discovery_task_runner_
)
204 if (discovery_thread_
)
206 utility_task_runner_
= base::MessageLoop::current()->message_loop_proxy();
207 discovery_thread_
.reset(new base::Thread("ServiceDiscoveryThread"));
208 base::Thread::Options
thread_options(base::MessageLoop::TYPE_IO
, 0);
209 if (discovery_thread_
->StartWithOptions(thread_options
)) {
210 discovery_task_runner_
= discovery_thread_
->message_loop_proxy();
211 discovery_task_runner_
->PostTask(FROM_HERE
,
212 base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns
,
213 base::Unretained(this)));
215 return discovery_task_runner_
!= NULL
;
218 bool ServiceDiscoveryMessageHandler::OnMessageReceived(
219 const IPC::Message
& message
) {
221 IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler
, message
)
222 #if defined(OS_POSIX)
223 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets
, OnSetSockets
)
225 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher
, OnStartWatcher
)
226 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices
, OnDiscoverServices
)
227 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher
, OnDestroyWatcher
)
228 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService
, OnResolveService
)
229 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver
, OnDestroyResolver
)
230 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain
,
231 OnResolveLocalDomain
)
232 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver
,
233 OnDestroyLocalDomainResolver
)
234 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery
,
235 ShutdownLocalDiscovery
)
236 IPC_MESSAGE_UNHANDLED(handled
= false)
237 IPC_END_MESSAGE_MAP()
241 void ServiceDiscoveryMessageHandler::PostTask(
242 const tracked_objects::Location
& from_here
,
243 const base::Closure
& task
) {
244 if (!InitializeThread())
246 discovery_task_runner_
->PostTask(from_here
, task
);
#if defined(OS_POSIX)
// Stores sockets supplied over IPC in the global factory, which takes
// ownership of the descriptors; they are consumed when InitializeMdns()
// starts the mDNS client, or closed by Reset().
void ServiceDiscoveryMessageHandler::OnSetSockets(
    const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) {
  for (size_t i = 0; i < sockets.size(); ++i) {
    g_local_discovery_socket_factory.Get().AddSocket(
        SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family,
                   sockets[i].interface_index));
  }
}
#endif  // OS_POSIX
260 void ServiceDiscoveryMessageHandler::OnStartWatcher(
262 const std::string
& service_type
) {
264 base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher
,
265 base::Unretained(this), id
, service_type
));
268 void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id
,
271 base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices
,
272 base::Unretained(this), id
, force_update
));
275 void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id
) {
277 base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher
,
278 base::Unretained(this), id
));
281 void ServiceDiscoveryMessageHandler::OnResolveService(
283 const std::string
& service_name
) {
285 base::Bind(&ServiceDiscoveryMessageHandler::ResolveService
,
286 base::Unretained(this), id
, service_name
));
289 void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id
) {
291 base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver
,
292 base::Unretained(this), id
));
295 void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
296 uint64 id
, const std::string
& domain
,
297 net::AddressFamily address_family
) {
299 base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain
,
300 base::Unretained(this), id
, domain
, address_family
));
303 void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id
) {
306 &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver
,
307 base::Unretained(this), id
));
310 void ServiceDiscoveryMessageHandler::StartWatcher(
312 const std::string
& service_type
) {
313 VLOG(1) << "StartWatcher, id=" << id
<< ", type=" << service_type
;
314 if (!service_discovery_client_
)
316 DCHECK(!ContainsKey(service_watchers_
, id
));
317 scoped_ptr
<ServiceWatcher
> watcher(
318 service_discovery_client_
->CreateServiceWatcher(
320 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated
,
321 base::Unretained(this), id
)));
323 service_watchers_
[id
].reset(watcher
.release());
326 void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id
,
328 VLOG(1) << "DiscoverServices, id=" << id
;
329 if (!service_discovery_client_
)
331 DCHECK(ContainsKey(service_watchers_
, id
));
332 service_watchers_
[id
]->DiscoverNewServices(force_update
);
335 void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id
) {
336 VLOG(1) << "DestoryWatcher, id=" << id
;
337 if (!service_discovery_client_
)
339 service_watchers_
.erase(id
);
342 void ServiceDiscoveryMessageHandler::ResolveService(
344 const std::string
& service_name
) {
345 VLOG(1) << "ResolveService, id=" << id
<< ", name=" << service_name
;
346 if (!service_discovery_client_
)
348 DCHECK(!ContainsKey(service_resolvers_
, id
));
349 scoped_ptr
<ServiceResolver
> resolver(
350 service_discovery_client_
->CreateServiceResolver(
352 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved
,
353 base::Unretained(this), id
)));
354 resolver
->StartResolving();
355 service_resolvers_
[id
].reset(resolver
.release());
358 void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id
) {
359 VLOG(1) << "DestroyResolver, id=" << id
;
360 if (!service_discovery_client_
)
362 service_resolvers_
.erase(id
);
365 void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
367 const std::string
& domain
,
368 net::AddressFamily address_family
) {
369 VLOG(1) << "ResolveLocalDomain, id=" << id
<< ", domain=" << domain
;
370 if (!service_discovery_client_
)
372 DCHECK(!ContainsKey(local_domain_resolvers_
, id
));
373 scoped_ptr
<LocalDomainResolver
> resolver(
374 service_discovery_client_
->CreateLocalDomainResolver(
375 domain
, address_family
,
376 base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved
,
377 base::Unretained(this), id
)));
379 local_domain_resolvers_
[id
].reset(resolver
.release());
382 void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id
) {
383 VLOG(1) << "DestroyLocalDomainResolver, id=" << id
;
384 if (!service_discovery_client_
)
386 local_domain_resolvers_
.erase(id
);
389 void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
390 if (!discovery_task_runner_
)
393 discovery_task_runner_
->PostTask(
395 base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread
,
396 base::Unretained(this)));
398 // This will wait for message loop to drain, so ShutdownOnIOThread will
399 // definitely be called.
400 discovery_thread_
.reset();
403 void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
404 VLOG(1) << "ShutdownLocalDiscovery";
405 service_watchers_
.clear();
406 service_resolvers_
.clear();
407 local_domain_resolvers_
.clear();
408 service_discovery_client_
.reset();
409 mdns_client_
.reset();
412 void ServiceDiscoveryMessageHandler::OnServiceUpdated(
414 ServiceWatcher::UpdateType update
,
415 const std::string
& name
) {
416 VLOG(1) << "OnServiceUpdated, id=" << id
417 << ", status=" << WatcherUpdateToString(update
) << ", name=" << name
;
418 DCHECK(service_discovery_client_
);
420 Send(new LocalDiscoveryHostMsg_WatcherCallback(id
, update
, name
));
423 void ServiceDiscoveryMessageHandler::OnServiceResolved(
425 ServiceResolver::RequestStatus status
,
426 const ServiceDescription
& description
) {
427 VLOG(1) << "OnServiceResolved, id=" << id
428 << ", status=" << ResolverStatusToString(status
)
429 << ", name=" << description
.service_name
;
431 DCHECK(service_discovery_client_
);
432 Send(new LocalDiscoveryHostMsg_ResolverCallback(id
, status
, description
));
435 void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
438 const net::IPAddressNumber
& address_ipv4
,
439 const net::IPAddressNumber
& address_ipv6
) {
440 VLOG(1) << "OnLocalDomainResolved, id=" << id
441 << ", IPv4=" << (address_ipv4
.empty() ? "" :
442 net::IPAddressToString(address_ipv4
))
443 << ", IPv6=" << (address_ipv6
.empty() ? "" :
444 net::IPAddressToString(address_ipv6
));
446 DCHECK(service_discovery_client_
);
447 Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
448 id
, success
, address_ipv4
, address_ipv6
));
451 void ServiceDiscoveryMessageHandler::Send(IPC::Message
* msg
) {
452 utility_task_runner_
->PostTask(FROM_HERE
,
453 base::Bind(&SendHostMessageOnUtilityThread
,
457 } // namespace local_discovery