1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/utility/local_discovery/service_discovery_message_handler.h"
9 #include "base/lazy_instance.h"
10 #include "base/location.h"
11 #include "base/single_thread_task_runner.h"
12 #include "chrome/common/local_discovery/local_discovery_messages.h"
13 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
14 #include "content/public/utility/utility_thread.h"
15 #include "net/socket/socket_descriptor.h"
16 #include "net/udp/datagram_server_socket.h"
18 namespace local_discovery
{
22 void ClosePlatformSocket(net::SocketDescriptor socket
);
24 // Sets socket factory used by |net::CreatePlatformSocket|. Implemetation
25 // keeps single socket that will be returned to the first call to
26 // |net::CreatePlatformSocket| during object lifetime.
27 class ScopedSocketFactory
: public net::PlatformSocketFactory
{
29 explicit ScopedSocketFactory(net::SocketDescriptor socket
) : socket_(socket
) {
30 net::PlatformSocketFactory::SetInstance(this);
33 ~ScopedSocketFactory() override
{
34 net::PlatformSocketFactory::SetInstance(NULL
);
35 ClosePlatformSocket(socket_
);
36 socket_
= net::kInvalidSocket
;
39 net::SocketDescriptor
CreateSocket(int family
,
41 int protocol
) override
{
42 DCHECK_EQ(type
, SOCK_DGRAM
);
43 DCHECK(family
== AF_INET
|| family
== AF_INET6
);
44 net::SocketDescriptor result
= net::kInvalidSocket
;
45 std::swap(result
, socket_
);
50 net::SocketDescriptor socket_
;
51 DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory
);
55 SocketInfo(net::SocketDescriptor socket
,
56 net::AddressFamily address_family
,
57 uint32 interface_index
)
59 address_family(address_family
),
60 interface_index(interface_index
) {
62 net::SocketDescriptor socket
;
63 net::AddressFamily address_family
;
64 uint32 interface_index
;
67 // Returns list of sockets preallocated before.
68 class PreCreatedMDnsSocketFactory
: public net::MDnsSocketFactory
{
70 PreCreatedMDnsSocketFactory() {}
71 ~PreCreatedMDnsSocketFactory() override
{
72 // Not empty if process exits too fast, before starting mDns code. If
73 // happened, destructors may crash accessing destroyed global objects.
74 sockets_
.weak_clear();
77 // net::MDnsSocketFactory implementation:
79 ScopedVector
<net::DatagramServerSocket
>* sockets
) override
{
80 sockets
->swap(sockets_
);
84 void AddSocket(const SocketInfo
& socket_info
) {
85 // Takes ownership of socket_info.socket;
86 ScopedSocketFactory
platform_factory(socket_info
.socket
);
87 scoped_ptr
<net::DatagramServerSocket
> socket(
88 net::CreateAndBindMDnsSocket(socket_info
.address_family
,
89 socket_info
.interface_index
));
91 socket
->DetachFromThread();
92 sockets_
.push_back(socket
.release());
101 ScopedVector
<net::DatagramServerSocket
> sockets_
;
103 DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory
);
106 base::LazyInstance
<PreCreatedMDnsSocketFactory
>
107 g_local_discovery_socket_factory
= LAZY_INSTANCE_INITIALIZER
;
111 void ClosePlatformSocket(net::SocketDescriptor socket
) {
112 ::closesocket(socket
);
115 void StaticInitializeSocketFactory() {
116 net::InterfaceIndexFamilyList
interfaces(net::GetMDnsInterfacesToBind());
117 for (size_t i
= 0; i
< interfaces
.size(); ++i
) {
118 DCHECK(interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV4
||
119 interfaces
[i
].second
== net::ADDRESS_FAMILY_IPV6
);
120 net::SocketDescriptor descriptor
=
121 net::CreatePlatformSocket(
122 net::ConvertAddressFamily(interfaces
[i
].second
), SOCK_DGRAM
,
124 g_local_discovery_socket_factory
.Get().AddSocket(
125 SocketInfo(descriptor
, interfaces
[i
].second
, interfaces
[i
].first
));
131 void ClosePlatformSocket(net::SocketDescriptor socket
) {
135 void StaticInitializeSocketFactory() {
140 void SendHostMessageOnUtilityThread(IPC::Message
* msg
) {
141 content::UtilityThread::Get()->Send(msg
);
144 std::string
WatcherUpdateToString(ServiceWatcher::UpdateType update
) {
146 case ServiceWatcher::UPDATE_ADDED
:
147 return "UPDATE_ADDED";
148 case ServiceWatcher::UPDATE_CHANGED
:
149 return "UPDATE_CHANGED";
150 case ServiceWatcher::UPDATE_REMOVED
:
151 return "UPDATE_REMOVED";
152 case ServiceWatcher::UPDATE_INVALIDATED
:
153 return "UPDATE_INVALIDATED";
155 return "Unknown Update";
158 std::string
ResolverStatusToString(ServiceResolver::RequestStatus status
) {
160 case ServiceResolver::STATUS_SUCCESS
:
161 return "STATUS_SUCESS";
162 case ServiceResolver::STATUS_REQUEST_TIMEOUT
:
163 return "STATUS_REQUEST_TIMEOUT";
164 case ServiceResolver::STATUS_KNOWN_NONEXISTENT
:
165 return "STATUS_KNOWN_NONEXISTENT";
167 return "Unknown Status";
172 ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
175 ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
176 DCHECK(!discovery_thread_
);
179 void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
180 StaticInitializeSocketFactory();
183 void ServiceDiscoveryMessageHandler::InitializeMdns() {
184 if (service_discovery_client_
|| mdns_client_
)
187 mdns_client_
= net::MDnsClient::CreateDefault();
189 mdns_client_
->StartListening(g_local_discovery_socket_factory
.Pointer());
190 // Close unused sockets.
191 g_local_discovery_socket_factory
.Get().Reset();
193 VLOG(1) << "Failed to start MDnsClient";
194 Send(new LocalDiscoveryHostMsg_Error());
198 service_discovery_client_
.reset(
199 new local_discovery::ServiceDiscoveryClientImpl(mdns_client_
.get()));
202 bool ServiceDiscoveryMessageHandler::InitializeThread() {
203 if (discovery_task_runner_
.get())
205 if (discovery_thread_
)
207 utility_task_runner_
= base::MessageLoop::current()->task_runner();
208 discovery_thread_
.reset(new base::Thread("ServiceDiscoveryThread"));
209 base::Thread::Options
thread_options(base::MessageLoop::TYPE_IO
, 0);
210 if (discovery_thread_
->StartWithOptions(thread_options
)) {
211 discovery_task_runner_
= discovery_thread_
->task_runner();
212 discovery_task_runner_
->PostTask(FROM_HERE
,
213 base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns
,
214 base::Unretained(this)));
216 return discovery_task_runner_
.get() != NULL
;
219 bool ServiceDiscoveryMessageHandler::OnMessageReceived(
220 const IPC::Message
& message
) {
222 IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler
, message
)
223 #if defined(OS_POSIX)
224 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets
, OnSetSockets
)
226 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher
, OnStartWatcher
)
227 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices
, OnDiscoverServices
)
228 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices
,
229 OnSetActivelyRefreshServices
)
230 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher
, OnDestroyWatcher
)
231 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService
, OnResolveService
)
232 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver
, OnDestroyResolver
)
233 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain
,
234 OnResolveLocalDomain
)
235 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver
,
236 OnDestroyLocalDomainResolver
)
237 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery
,
238 ShutdownLocalDiscovery
)
239 IPC_MESSAGE_UNHANDLED(handled
= false)
240 IPC_END_MESSAGE_MAP()
244 void ServiceDiscoveryMessageHandler::PostTask(
245 const tracked_objects::Location
& from_here
,
246 const base::Closure
& task
) {
247 if (!InitializeThread())
249 discovery_task_runner_
->PostTask(from_here
, task
);
#if defined(OS_POSIX)
// Registers every socket described in |sockets| with the global socket
// factory.
void ServiceDiscoveryMessageHandler::OnSetSockets(
    const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) {
  for (const LocalDiscoveryMsg_SocketInfo& entry : sockets) {
    g_local_discovery_socket_factory.Get().AddSocket(SocketInfo(
        entry.descriptor.fd, entry.address_family, entry.interface_index));
  }
}
#endif  // OS_POSIX
263 void ServiceDiscoveryMessageHandler::OnStartWatcher(
265 const std::string
& service_type
) {
267 base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher
,
268 base::Unretained(this), id
, service_type
));
271 void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id
,
274 base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices
,
275 base::Unretained(this), id
, force_update
));
278 void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices(
279 uint64 id
, bool actively_refresh_services
) {
282 &ServiceDiscoveryMessageHandler::SetActivelyRefreshServices
,
283 base::Unretained(this), id
, actively_refresh_services
));
286 void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id
) {
288 base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher
,
289 base::Unretained(this), id
));
292 void ServiceDiscoveryMessageHandler::OnResolveService(
294 const std::string
& service_name
) {
296 base::Bind(&ServiceDiscoveryMessageHandler::ResolveService
,
297 base::Unretained(this), id
, service_name
));
300 void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id
) {
302 base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver
,
303 base::Unretained(this), id
));
306 void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
307 uint64 id
, const std::string
& domain
,
308 net::AddressFamily address_family
) {
310 base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain
,
311 base::Unretained(this), id
, domain
, address_family
));
314 void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id
) {
317 &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver
,
318 base::Unretained(this), id
));
321 void ServiceDiscoveryMessageHandler::StartWatcher(
323 const std::string
& service_type
) {
324 VLOG(1) << "StartWatcher, id=" << id
<< ", type=" << service_type
;
325 if (!service_discovery_client_
)
327 DCHECK(!ContainsKey(service_watchers_
, id
));
328 scoped_ptr
<ServiceWatcher
> watcher(
329 service_discovery_client_
->CreateServiceWatcher(
331 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated
,
332 base::Unretained(this), id
)));
334 service_watchers_
[id
].reset(watcher
.release());
337 void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id
,
339 VLOG(1) << "DiscoverServices, id=" << id
;
340 if (!service_discovery_client_
)
342 DCHECK(ContainsKey(service_watchers_
, id
));
343 service_watchers_
[id
]->DiscoverNewServices(force_update
);
346 void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices(
348 bool actively_refresh_services
) {
349 VLOG(1) << "ActivelyRefreshServices, id=" << id
;
350 if (!service_discovery_client_
)
352 DCHECK(ContainsKey(service_watchers_
, id
));
353 service_watchers_
[id
]->SetActivelyRefreshServices(actively_refresh_services
);
356 void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id
) {
357 VLOG(1) << "DestoryWatcher, id=" << id
;
358 if (!service_discovery_client_
)
360 service_watchers_
.erase(id
);
363 void ServiceDiscoveryMessageHandler::ResolveService(
365 const std::string
& service_name
) {
366 VLOG(1) << "ResolveService, id=" << id
<< ", name=" << service_name
;
367 if (!service_discovery_client_
)
369 DCHECK(!ContainsKey(service_resolvers_
, id
));
370 scoped_ptr
<ServiceResolver
> resolver(
371 service_discovery_client_
->CreateServiceResolver(
373 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved
,
374 base::Unretained(this), id
)));
375 resolver
->StartResolving();
376 service_resolvers_
[id
].reset(resolver
.release());
379 void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id
) {
380 VLOG(1) << "DestroyResolver, id=" << id
;
381 if (!service_discovery_client_
)
383 service_resolvers_
.erase(id
);
386 void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
388 const std::string
& domain
,
389 net::AddressFamily address_family
) {
390 VLOG(1) << "ResolveLocalDomain, id=" << id
<< ", domain=" << domain
;
391 if (!service_discovery_client_
)
393 DCHECK(!ContainsKey(local_domain_resolvers_
, id
));
394 scoped_ptr
<LocalDomainResolver
> resolver(
395 service_discovery_client_
->CreateLocalDomainResolver(
396 domain
, address_family
,
397 base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved
,
398 base::Unretained(this), id
)));
400 local_domain_resolvers_
[id
].reset(resolver
.release());
403 void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id
) {
404 VLOG(1) << "DestroyLocalDomainResolver, id=" << id
;
405 if (!service_discovery_client_
)
407 local_domain_resolvers_
.erase(id
);
410 void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
411 if (!discovery_task_runner_
.get())
414 discovery_task_runner_
->PostTask(
416 base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread
,
417 base::Unretained(this)));
419 // This will wait for message loop to drain, so ShutdownOnIOThread will
420 // definitely be called.
421 discovery_thread_
.reset();
424 void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
425 VLOG(1) << "ShutdownLocalDiscovery";
426 service_watchers_
.clear();
427 service_resolvers_
.clear();
428 local_domain_resolvers_
.clear();
429 service_discovery_client_
.reset();
430 mdns_client_
.reset();
433 void ServiceDiscoveryMessageHandler::OnServiceUpdated(
435 ServiceWatcher::UpdateType update
,
436 const std::string
& name
) {
437 VLOG(1) << "OnServiceUpdated, id=" << id
438 << ", status=" << WatcherUpdateToString(update
) << ", name=" << name
;
439 DCHECK(service_discovery_client_
);
441 Send(new LocalDiscoveryHostMsg_WatcherCallback(id
, update
, name
));
444 void ServiceDiscoveryMessageHandler::OnServiceResolved(
446 ServiceResolver::RequestStatus status
,
447 const ServiceDescription
& description
) {
448 VLOG(1) << "OnServiceResolved, id=" << id
449 << ", status=" << ResolverStatusToString(status
)
450 << ", name=" << description
.service_name
;
452 DCHECK(service_discovery_client_
);
453 Send(new LocalDiscoveryHostMsg_ResolverCallback(id
, status
, description
));
456 void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
459 const net::IPAddressNumber
& address_ipv4
,
460 const net::IPAddressNumber
& address_ipv6
) {
461 VLOG(1) << "OnLocalDomainResolved, id=" << id
462 << ", IPv4=" << (address_ipv4
.empty() ? "" :
463 net::IPAddressToString(address_ipv4
))
464 << ", IPv6=" << (address_ipv6
.empty() ? "" :
465 net::IPAddressToString(address_ipv6
));
467 DCHECK(service_discovery_client_
);
468 Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
469 id
, success
, address_ipv4
, address_ipv6
));
472 void ServiceDiscoveryMessageHandler::Send(IPC::Message
* msg
) {
473 utility_task_runner_
->PostTask(FROM_HERE
,
474 base::Bind(&SendHostMessageOnUtilityThread
,
478 } // namespace local_discovery