1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include "base/bind.h"
#include "base/file_util.h"
#include "base/files/file_path.h"
#include "base/path_service.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread.h"
#include "base/threading/thread_restrictions.h"
#include "ipc/unix_domain_socket_util.h"
#include "testing/gtest/include/gtest/gtest.h"
19 class SocketAcceptor
: public MessageLoopForIO::Watcher
{
21 SocketAcceptor(int fd
, base::MessageLoopProxy
* target_thread
)
23 target_thread_(target_thread
),
24 started_watching_event_(false, false),
25 accepted_event_(false, false) {
26 target_thread
->PostTask(FROM_HERE
,
27 base::Bind(&SocketAcceptor::StartWatching
, base::Unretained(this), fd
));
30 virtual ~SocketAcceptor() {
34 int server_fd() const { return server_fd_
; }
36 void WaitUntilReady() {
37 started_watching_event_
.Wait();
40 void WaitForAccept() {
41 accepted_event_
.Wait();
46 target_thread_
->PostTask(FROM_HERE
,
47 base::Bind(&SocketAcceptor::StopWatching
, base::Unretained(this),
53 void StartWatching(int fd
) {
54 watcher_
.reset(new MessageLoopForIO::FileDescriptorWatcher
);
55 MessageLoopForIO::current()->WatchFileDescriptor(
58 MessageLoopForIO::WATCH_READ
,
61 started_watching_event_
.Signal();
63 void StopWatching(MessageLoopForIO::FileDescriptorWatcher
* watcher
) {
64 watcher
->StopWatchingFileDescriptor();
67 virtual void OnFileCanReadWithoutBlocking(int fd
) OVERRIDE
{
68 ASSERT_EQ(-1, server_fd_
);
69 IPC::ServerAcceptConnection(fd
, &server_fd_
);
70 watcher_
->StopWatchingFileDescriptor();
71 accepted_event_
.Signal();
73 virtual void OnFileCanWriteWithoutBlocking(int fd
) OVERRIDE
{}
76 base::MessageLoopProxy
* target_thread_
;
77 scoped_ptr
<MessageLoopForIO::FileDescriptorWatcher
> watcher_
;
78 base::WaitableEvent started_watching_event_
;
79 base::WaitableEvent accepted_event_
;
81 DISALLOW_COPY_AND_ASSIGN(SocketAcceptor
);
84 const base::FilePath
GetChannelDir() {
85 #if defined(OS_ANDROID)
86 base::FilePath tmp_dir
;
87 PathService::Get(base::DIR_CACHE
, &tmp_dir
);
90 return base::FilePath("/var/tmp");
94 class TestUnixSocketConnection
{
96 TestUnixSocketConnection()
97 : worker_("WorkerThread"),
98 server_listen_fd_(-1),
101 socket_name_
= GetChannelDir().Append("TestSocket");
102 base::Thread::Options options
;
103 options
.message_loop_type
= MessageLoop::TYPE_IO
;
104 worker_
.StartWithOptions(options
);
107 bool CreateServerSocket() {
108 IPC::CreateServerUnixDomainSocket(socket_name_
, &server_listen_fd_
);
109 if (server_listen_fd_
< 0)
111 struct stat socket_stat
;
112 stat(socket_name_
.value().c_str(), &socket_stat
);
113 EXPECT_TRUE(S_ISSOCK(socket_stat
.st_mode
));
114 acceptor_
.reset(new SocketAcceptor(server_listen_fd_
,
115 worker_
.message_loop_proxy()));
116 acceptor_
->WaitUntilReady();
120 bool CreateClientSocket() {
121 DCHECK(server_listen_fd_
>= 0);
122 IPC::CreateClientUnixDomainSocket(socket_name_
, &client_fd_
);
125 acceptor_
->WaitForAccept();
126 server_fd_
= acceptor_
->server_fd();
127 return server_fd_
>= 0;
130 virtual ~TestUnixSocketConnection() {
135 if (server_listen_fd_
>= 0) {
136 close(server_listen_fd_
);
137 unlink(socket_name_
.value().c_str());
141 int client_fd() const { return client_fd_
; }
142 int server_fd() const { return server_fd_
; }
145 base::Thread worker_
;
146 base::FilePath socket_name_
;
147 int server_listen_fd_
;
150 scoped_ptr
<SocketAcceptor
> acceptor_
;
153 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
154 // IPC::CreateClientUnixDomainSocket can successfully connect to.
155 TEST(UnixDomainSocketUtil
, Connect
) {
156 TestUnixSocketConnection connection
;
157 ASSERT_TRUE(connection
.CreateServerSocket());
158 ASSERT_TRUE(connection
.CreateClientSocket());
161 // Ensure that messages can be sent across the resulting socket.
162 TEST(UnixDomainSocketUtil
, SendReceive
) {
163 TestUnixSocketConnection connection
;
164 ASSERT_TRUE(connection
.CreateServerSocket());
165 ASSERT_TRUE(connection
.CreateClientSocket());
167 const char buffer
[] = "Hello, server!";
168 size_t buf_len
= sizeof(buffer
);
170 HANDLE_EINTR(send(connection
.client_fd(), buffer
, buf_len
, 0));
171 ASSERT_EQ(buf_len
, sent_bytes
);
172 char recv_buf
[sizeof(buffer
)];
173 size_t received_bytes
=
174 HANDLE_EINTR(recv(connection
.server_fd(), recv_buf
, buf_len
, 0));
175 ASSERT_EQ(buf_len
, received_bytes
);
176 ASSERT_EQ(0, memcmp(recv_buf
, buffer
, buf_len
));