1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
8 #include <sys/socket.h>
11 #include <sys/types.h>
19 #include "base/bind.h"
20 #include "base/callback.h"
21 #include "base/compiler_specific.h"
22 #include "base/file_util.h"
23 #include "base/files/file_path.h"
24 #include "base/memory/ref_counted.h"
25 #include "base/memory/scoped_ptr.h"
26 #include "base/message_loop.h"
27 #include "base/posix/eintr_wrapper.h"
28 #include "base/synchronization/condition_variable.h"
29 #include "base/synchronization/lock.h"
30 #include "base/threading/platform_thread.h"
31 #include "base/threading/thread.h"
32 #include "net/socket/unix_domain_socket_posix.h"
33 #include "testing/gtest/include/gtest/gtest.h"
// Base name of the test socket; MakeSocketPath() appends it to the temp dir.
41 const char kSocketFilename
[] = "unix_domain_socket_for_testing";
// Path used by UnixDomainSocketTestWithInvalidPath, where listening is
// expected to fail (CreateAndListenWithInvalidPath expects a NULL socket).
// NOTE(review): assumes "/invalid" does not exist on the test machine.
42 const char kInvalidSocketPath
[] = "/invalid/path";
// Payload the client sends. sizeof(kMsg) includes the trailing NUL, and
// TestListenSocketDelegate::DidRead() stores only |len - 1| bytes, so the
// stored data compares equal to kMsg.
43 const char kMsg
[] = "hello";
// Returns the socket path used by the valid-path fixtures: kSocketFilename
// appended to the directory reported by file_util::GetTempDir().
// NOTE(review): any result of GetTempDir() is ignored; if it can fail,
// |temp_dir| would stay empty and a bare relative name would be returned.
54 string
MakeSocketPath() {
55 base::FilePath temp_dir
;
56 file_util::GetTempDir(&temp_dir
);
57 return temp_dir
.Append(kSocketFilename
).value();
// Queue of EventType values shared across threads. Every method below takes
// |mutex_|. Producers call Notify(); the test thread blocks in WaitForEvent().
// It is ref-counted because the fixture, the socket delegate, and the bound
// auth callback each hold a scoped_refptr to the same instance.
60 class EventManager
: public base::RefCounted
<EventManager
> {
// |condition_| is bound to |mutex_|, so waits release that lock.
62 EventManager() : condition_(&mutex_
) {}
// Returns true if the queue still holds at least one unconsumed event.
64 bool HasPendingEvent() {
65 base::AutoLock
lock(mutex_
);
66 return !events_
.empty();
// Reports |event| from any thread and broadcasts on |condition_|, so every
// thread blocked in WaitForEvent() re-checks the queue.
69 void Notify(EventType event
) {
70 base::AutoLock
lock(mutex_
);
72 condition_
.Broadcast();
// Blocks while the queue is empty, then takes the event at its front (the
// oldest one). The loop re-checks emptiness to tolerate spurious wakeups.
75 EventType
WaitForEvent() {
76 base::AutoLock
lock(mutex_
);
77 while (events_
.empty())
79 EventType event
= events_
.front();
// Destruction goes through base::RefCounted when the last reference drops.
85 friend class base::RefCounted
<EventManager
>;
86 virtual ~EventManager() {}
// Events not yet consumed by WaitForEvent(), oldest first.
88 queue
<EventType
> events_
;
// Signalled by Notify(); constructed over |mutex_|.
90 base::ConditionVariable condition_
;
// StreamListenSocket delegate used by the tests. It keeps the accepted
// connection and the last data read, and forwards socket lifecycle
// notifications to the shared EventManager.
93 class TestListenSocketDelegate
: public StreamListenSocket::Delegate
{
95 explicit TestListenSocketDelegate(
96 const scoped_refptr
<EventManager
>& event_manager
)
97 : event_manager_(event_manager
) {}
// Stores |connection| (taking a reference) and reports EVENT_ACCEPT.
// |server| is unused.
99 virtual void DidAccept(StreamListenSocket
* server
,
100 StreamListenSocket
* connection
) OVERRIDE
{
// NOTE(review): logs at ERROR severity purely as a trace; a verbose log
// level would avoid noisy error output in passing runs.
101 LOG(ERROR
) << __PRETTY_FUNCTION__
;
102 connection_
= connection
;
103 Notify(EVENT_ACCEPT
);
// Under |mutex_|, copies all but the last received byte into |data_|. The
// client sends sizeof(kMsg), so the dropped byte is kMsg's NUL terminator.
// NOTE(review): assumes |len| >= 1; a zero-length read would underflow.
106 virtual void DidRead(StreamListenSocket
* connection
,
110 base::AutoLock
lock(mutex_
);
112 data_
.assign(data
, len
- 1);
// Connection-closed callback. TestWithClient waits for EVENT_CLOSE after
// closing the client, so presumably this reports it. Confirm.
117 virtual void DidClose(StreamListenSocket
* sock
) OVERRIDE
{
// Called by the fixture right after UnixDomainSocket::CreateAndListen()
// returns; reports EVENT_LISTEN.
121 void OnListenCompleted() {
122 Notify(EVENT_LISTEN
);
// Under |mutex_|, returns the data most recently captured by DidRead().
// TestWithClient compares it against kMsg.
125 string
ReceivedData() {
126 base::AutoLock
lock(mutex_
);
// Forwards |event| to the shared EventManager.
131 void Notify(EventType event
) {
132 event_manager_
->Notify(event
);
135 const scoped_refptr
<EventManager
> event_manager_
;
// Reference to the most recently accepted connection, set by DidAccept().
136 scoped_refptr
<StreamListenSocket
> connection_
;
// Auth callback bound by UnixDomainSocketTestHelper::MakeAuthCallback().
// It reports EVENT_AUTH_GRANTED when |allow_user| is true and
// EVENT_AUTH_DENIED otherwise.
// NOTE(review): presumably it also returns |allow_user| so the socket accepts
// or rejects the peer. Confirm.
141 bool UserCanConnectCallback(
142 bool allow_user
, const scoped_refptr
<EventManager
>& event_manager
,
144 event_manager
->Notify(
145 allow_user
? EVENT_AUTH_GRANTED
: EVENT_AUTH_DENIED
);
// Base fixture. It owns the EventManager and the socket delegate, and creates
// the listening UnixDomainSocket at |file_path_|. It also offers helpers to
// run the server on an IO thread and to connect a raw client socket to it.
149 class UnixDomainSocketTestHelper
: public testing::Test
{
// Creates the listening socket into |socket_|, which is NULL on failure
// (see CreateAndListenWithInvalidPath). It then reports EVENT_LISTEN.
// EVENT_LISTEN is reported even when creation failed.
151 void CreateAndListen() {
152 socket_
= UnixDomainSocket::CreateAndListen(
153 file_path_
.value(), socket_delegate_
.get(), MakeAuthCallback());
154 socket_delegate_
->OnListenCompleted();
// |path| is the socket path to listen on. |allow_user| selects whether the
// auth callback grants (true) or denies (false) connecting peers.
158 UnixDomainSocketTestHelper(const string
& path
, bool allow_user
)
160 allow_user_(allow_user
) {}
// Creates a fresh EventManager and a delegate that reports into it.
162 virtual void SetUp() OVERRIDE
{
163 event_manager_
= new EventManager();
164 socket_delegate_
.reset(new TestListenSocketDelegate(event_manager_
));
// Destroys the delegate and drops the fixture's EventManager reference.
168 virtual void TearDown() OVERRIDE
{
171 socket_delegate_
.reset();
172 event_manager_
= NULL
;
// Binds |allow_user_| and |event_manager_| into UserCanConnectCallback.
175 UnixDomainSocket::AuthCallback
MakeAuthCallback() {
176 return base::Bind(&UserCanConnectCallback
, allow_user_
, event_manager_
);
// Deletes the socket file non-recursively. It fails the current test if
// |file_path_| is empty. NOTE(review): any result of file_util::Delete() is
// ignored.
179 void DeleteSocketFile() {
180 ASSERT_FALSE(file_path_
.empty());
181 file_util::Delete(file_path_
, false /* not recursive */);
// Opens a PF_UNIX stream socket and connects it to |file_path_|. Returns
// StreamListenSocket::kInvalidSocket if socket() or connect() fails.
// Otherwise the tests use the result as a connected client descriptor.
184 SocketDescriptor
CreateClientSocket() {
185 const SocketDescriptor sock
= socket(PF_UNIX
, SOCK_STREAM
, 0);
187 LOG(ERROR
) << "socket() error";
188 return StreamListenSocket::kInvalidSocket
;
191 memset(&addr
, 0, sizeof(addr
));
192 addr
.sun_family
= AF_UNIX
;
// NOTE(review): strncpy() leaves sun_path without a NUL terminator when the
// path is sizeof(sun_path) bytes or longer, and longer paths are silently
// truncated. Consider rejecting over-long paths explicitly.
194 strncpy(addr
.sun_path
, file_path_
.value().c_str(), sizeof(addr
.sun_path
));
195 addr_len
= sizeof(sockaddr_un
);
196 if (connect(sock
, reinterpret_cast<sockaddr
*>(&addr
), addr_len
) != 0) {
// NOTE(review): |sock| is not closed on this error path, so the descriptor
// leaks when connect() fails.
197 LOG(ERROR
) << "connect() error";
198 return StreamListenSocket::kInvalidSocket
;
// Starts a base::Thread with an IO message loop, then posts CreateAndListen()
// to it so that the listening socket is created on that thread. Returns
// ownership of the running thread.
// NOTE(review): base::Unretained(this) assumes the fixture outlives the
// posted task. The tests keep the returned thread in a test-local scoped_ptr.
203 scoped_ptr
<base::Thread
> CreateAndRunServerThread() {
204 base::Thread::Options options
;
205 options
.message_loop_type
= MessageLoop::TYPE_IO
;
206 scoped_ptr
<base::Thread
> thread(new base::Thread("socketio_test"));
207 thread
->StartWithOptions(options
);
208 thread
->message_loop()->PostTask(
210 base::Bind(&UnixDomainSocketTestHelper::CreateAndListen
,
211 base::Unretained(this)));
212 return thread
.Pass();
// Path (or abstract name, in the abstract-namespace test) to listen on.
215 const base::FilePath file_path_
;
// Passed to UserCanConnectCallback via MakeAuthCallback().
216 const bool allow_user_
;
217 scoped_refptr
<EventManager
> event_manager_
;
218 scoped_ptr
<TestListenSocketDelegate
> socket_delegate_
;
// Listening socket; NULL if creation failed.
219 scoped_refptr
<UnixDomainSocket
> socket_
;
// Fixture that listens on MakeSocketPath() and lets connecting users in.
222 class UnixDomainSocketTest
: public UnixDomainSocketTestHelper
{
224 UnixDomainSocketTest()
225 : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
// Fixture that targets kInvalidSocketPath, with connecting users allowed.
// Used to check that filesystem-backed listening fails on that path.
228 class UnixDomainSocketTestWithInvalidPath
: public UnixDomainSocketTestHelper
{
230 UnixDomainSocketTestWithInvalidPath()
231 : UnixDomainSocketTestHelper(kInvalidSocketPath
, true) {}
// Fixture that listens on MakeSocketPath() and has the auth callback deny
// every connecting user.
234 class UnixDomainSocketTestWithForbiddenUser
235 : public UnixDomainSocketTestHelper
{
237 UnixDomainSocketTestWithForbiddenUser()
238 : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
// Listening on a valid temp-dir path must produce a non-NULL socket.
241 TEST_F(UnixDomainSocketTest
, CreateAndListen
) {
243 EXPECT_FALSE(socket_
.get() == NULL
);
// Listening on kInvalidSocketPath must fail, leaving |socket_| NULL.
246 TEST_F(UnixDomainSocketTestWithInvalidPath
, CreateAndListenWithInvalidPath
) {
248 EXPECT_TRUE(socket_
.get() == NULL
);
251 #ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
252 // Test with an invalid path to make sure that the socket is not backed by a
// file. With the abstract namespace, creation from kInvalidSocketPath is
// expected to succeed, even though the same path makes the filesystem-backed
// CreateAndListen() fail (see CreateAndListenWithInvalidPath).
254 TEST_F(UnixDomainSocketTestWithInvalidPath
,
255 CreateAndListenWithAbstractNamespace
) {
256 socket_
= UnixDomainSocket::CreateAndListenWithAbstractNamespace(
257 file_path_
.value(), socket_delegate_
.get(), MakeAuthCallback());
258 EXPECT_FALSE(socket_
.get() == NULL
);
// End-to-end check with an allowed user. It runs the server on an IO thread,
// connects a client, and verifies the event sequence LISTEN, AUTH_GRANTED,
// ACCEPT, READ (with the payload intact) and finally CLOSE.
// NOTE(review): if an ASSERT fails after the client connects, the test returns
// early and |sock| is never closed.
262 TEST_F(UnixDomainSocketTest
, TestWithClient
) {
263 const scoped_ptr
<base::Thread
> server_thread
= CreateAndRunServerThread();
264 EventType event
= event_manager_
->WaitForEvent();
265 ASSERT_EQ(EVENT_LISTEN
, event
);
267 // Create the client socket.
268 const SocketDescriptor sock
= CreateClientSocket();
269 ASSERT_NE(StreamListenSocket::kInvalidSocket
, sock
);
270 event
= event_manager_
->WaitForEvent();
271 ASSERT_EQ(EVENT_AUTH_GRANTED
, event
);
272 event
= event_manager_
->WaitForEvent();
273 ASSERT_EQ(EVENT_ACCEPT
, event
);
275 // Send a message from the client to the server.
// sizeof(kMsg) includes the NUL terminator. The delegate drops that last
// byte, so ReceivedData() is compared to kMsg below.
276 ssize_t ret
= HANDLE_EINTR(send(sock
, kMsg
, sizeof(kMsg
), 0));
278 ASSERT_EQ(sizeof(kMsg
), static_cast<size_t>(ret
));
279 event
= event_manager_
->WaitForEvent();
280 ASSERT_EQ(EVENT_READ
, event
);
281 ASSERT_EQ(kMsg
, socket_delegate_
->ReceivedData());
283 // Close the client socket.
// NOTE(review): retrying close() on EINTR via HANDLE_EINTR is unsafe on
// Linux. The descriptor is already released and may have been reused by
// another thread. |ret| is also not checked after this assignment.
284 ret
= HANDLE_EINTR(close(sock
));
285 event
= event_manager_
->WaitForEvent();
286 ASSERT_EQ(EVENT_CLOSE
, event
);
// With the auth callback denying the peer, the connection is reported as
// EVENT_AUTH_DENIED and the server closes it. A later send() must then fail
// with EPIPE, and no further events may be queued.
// NOTE(review): |sock| does not appear to be closed before the test returns.
// Confirm the descriptor is released.
289 TEST_F(UnixDomainSocketTestWithForbiddenUser
, TestWithForbiddenUser
) {
290 const scoped_ptr
<base::Thread
> server_thread
= CreateAndRunServerThread();
291 EventType event
= event_manager_
->WaitForEvent();
292 ASSERT_EQ(EVENT_LISTEN
, event
);
293 const SocketDescriptor sock
= CreateClientSocket();
294 ASSERT_NE(StreamListenSocket::kInvalidSocket
, sock
);
296 event
= event_manager_
->WaitForEvent();
297 ASSERT_EQ(EVENT_AUTH_DENIED
, event
);
299 // Wait until the file descriptor is closed by the server.
300 struct pollfd poll_fd
;
302 poll_fd
.events
= POLLIN
;
// NOTE(review): poll()'s return value is ignored and the call is not
// EINTR-wrapped. An interrupted poll would fall through before the server
// has closed the connection.
303 poll(&poll_fd
, 1, -1 /* rely on GTest for timeout handling */);
// NOTE(review): send() on a stream whose peer has closed also raises SIGPIPE
// unless the signal is ignored or MSG_NOSIGNAL is passed. Confirm the test
// runner ignores SIGPIPE; otherwise the process is killed before the EPIPE
// check.
306 ssize_t ret
= HANDLE_EINTR(send(sock
, kMsg
, sizeof(kMsg
), 0));
308 ASSERT_EQ(EPIPE
, errno
);
309 ASSERT_FALSE(event_manager_
->HasPendingEvent());