1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
8 #include <sys/socket.h>
11 #include <sys/types.h>
19 #include "base/bind.h"
20 #include "base/callback.h"
21 #include "base/compiler_specific.h"
22 #include "base/file_util.h"
23 #include "base/files/file_path.h"
24 #include "base/memory/ref_counted.h"
25 #include "base/memory/scoped_ptr.h"
26 #include "base/message_loop/message_loop.h"
27 #include "base/posix/eintr_wrapper.h"
28 #include "base/synchronization/condition_variable.h"
29 #include "base/synchronization/lock.h"
30 #include "base/threading/platform_thread.h"
31 #include "base/threading/thread.h"
32 #include "net/socket/socket_descriptor.h"
33 #include "net/socket/unix_domain_socket_posix.h"
34 #include "testing/gtest/include/gtest/gtest.h"
// File name of the listening socket; MakeSocketPath() places it in the temp
// directory.
42 const char kSocketFilename
[] = "unix_domain_socket_for_testing";
// Path that UnixDomainSocket::CreateAndListen() is expected to reject
// (CreateAndListenWithInvalidPath expects a NULL socket for it).
43 const char kInvalidSocketPath
[] = "/invalid/path";
// Payload sent by the client. The tests send sizeof(kMsg) bytes, i.e.
// including the trailing NUL, which TestListenSocketDelegate::DidRead() drops
// via |len - 1|.
44 const char kMsg
[] = "hello";
// Returns |socket_file_name| appended to the directory reported by
// base::GetTempDir().
// NOTE(review): the success flag of base::GetTempDir() is ignored; on failure
// |temp_dir| presumably stays empty and a relative path is returned — confirm.
55 string
MakeSocketPath(const string
& socket_file_name
) {
56 base::FilePath temp_dir
;
57 base::GetTempDir(&temp_dir
);
58 return temp_dir
.Append(socket_file_name
).value();
// Convenience overload: the default test socket path, i.e. kSocketFilename
// inside the temp directory.
61 string
MakeSocketPath() {
62 return MakeSocketPath(kSocketFilename
);
// Queue of EventType values guarded by |mutex_| and signalled through
// |condition_|. The socket delegate and auth callback post events; the test
// body consumes them with WaitForEvent().
// NOTE(review): this uses base::RefCounted (non-thread-safe refcount), yet
// MakeAuthCallback() also takes a reference from inside CreateAndListen(),
// which runs on the server thread. That appears safe only because those
// accesses are sequenced by PostTask and thread shutdown. Consider
// base::RefCountedThreadSafe.
65 class EventManager
: public base::RefCounted
<EventManager
> {
// |condition_| is bound to |mutex_|, so waits release that lock.
67 EventManager() : condition_(&mutex_
) {}
// Returns true if at least one event is queued and not yet consumed.
69 bool HasPendingEvent() {
70 base::AutoLock
lock(mutex_
);
71 return !events_
.empty();
// Publishes |event| and wakes every waiter by broadcasting |condition_|.
// NOTE(review): expected to push |event| onto |events_| under the lock before
// broadcasting — confirm.
74 void Notify(EventType event
) {
75 base::AutoLock
lock(mutex_
);
77 condition_
.Broadcast();
// Returns the oldest queued event, looping while |events_| is empty
// (presumably waiting on |condition_|). No timeout is applied here; callers
// rely on the test runner's timeout.
// NOTE(review): the front event is presumably also popped — confirm.
80 EventType
WaitForEvent() {
81 base::AutoLock
lock(mutex_
);
82 while (events_
.empty())
84 EventType event
= events_
.front();
// base::RefCounted needs access to the destructor to delete the object.
90 friend class base::RefCounted
<EventManager
>;
91 virtual ~EventManager() {}
// Pending events in FIFO order; guarded by |mutex_|.
93 queue
<EventType
> events_
;
// Signalled whenever an event is published.
95 base::ConditionVariable condition_
;
// StreamListenSocket delegate used by the tests. It forwards each server-side
// callback to the shared EventManager and records the last received payload.
98 class TestListenSocketDelegate
: public StreamListenSocket::Delegate
{
// |event_manager| receives every event this delegate reports.
100 explicit TestListenSocketDelegate(
101 const scoped_refptr
<EventManager
>& event_manager
)
102 : event_manager_(event_manager
) {}
// Keeps the accepted |connection| alive in |connection_| and reports
// EVENT_ACCEPT.
// NOTE(review): LOG(ERROR) of __PRETTY_FUNCTION__ looks like leftover debug
// output; consider VLOG or removing it.
104 virtual void DidAccept(StreamListenSocket
* server
,
105 scoped_ptr
<StreamListenSocket
> connection
) OVERRIDE
{
106 LOG(ERROR
) << __PRETTY_FUNCTION__
;
107 connection_
= connection
.Pass();
108 Notify(EVENT_ACCEPT
);
// Stores all received bytes except the last one in |data_|, under |mutex_|.
// The client sends kMsg including its NUL terminator, which this drops.
// NOTE(review): |len - 1| assumes len >= 1; with len == 0 the length passed
// to assign() underflows — confirm a guard exists.
111 virtual void DidRead(StreamListenSocket
* connection
,
115 base::AutoLock
lock(mutex_
);
117 data_
.assign(data
, len
- 1);
// Called when |sock| is closed.
// NOTE(review): presumably reports EVENT_CLOSE, which TestWithClient waits
// for — confirm.
122 virtual void DidClose(StreamListenSocket
* sock
) OVERRIDE
{
// Reports EVENT_LISTEN; the fixture calls this right after CreateAndListen().
126 void OnListenCompleted() {
127 Notify(EVENT_LISTEN
);
// Returns the payload captured by DidRead(), reading it under |mutex_|
// (presumably a copy of |data_|).
130 string
ReceivedData() {
131 base::AutoLock
lock(mutex_
);
// Forwards |event| to the shared EventManager.
136 void Notify(EventType event
) {
137 event_manager_
->Notify(event
);
// Shared event sink; also referenced by the fixture.
140 const scoped_refptr
<EventManager
> event_manager_
;
// Most recently accepted connection, owned so it stays open.
141 scoped_ptr
<StreamListenSocket
> connection_
;
// Auth callback bound by UnixDomainSocketTestHelper::MakeAuthCallback().
// Reports EVENT_AUTH_GRANTED when |allow_user| is true and EVENT_AUTH_DENIED
// otherwise.
// NOTE(review): presumably returns |allow_user| as the authorization
// decision, and the remaining parameters identify the peer — confirm.
146 bool UserCanConnectCallback(
147 bool allow_user
, const scoped_refptr
<EventManager
>& event_manager
,
149 event_manager
->Notify(
150 allow_user
? EVENT_AUTH_GRANTED
: EVENT_AUTH_DENIED
);
// gtest fixture base. It owns the server-side UnixDomainSocket, its delegate
// and the shared EventManager. It can start a listening server on a
// background IO thread and connect a raw client socket to it.
154 class UnixDomainSocketTestHelper
: public testing::Test
{
// Creates the listening socket at |file_path_| and reports EVENT_LISTEN
// through the delegate. The tests treat a NULL |socket_| as failure.
// CreateAndRunServerThread() posts this to the server thread.
156 void CreateAndListen() {
157 socket_
= UnixDomainSocket::CreateAndListen(
158 file_path_
.value(), socket_delegate_
.get(), MakeAuthCallback());
159 socket_delegate_
->OnListenCompleted();
// |path| is where the server socket is created. |allow_user| selects whether
// UserCanConnectCallback() reports EVENT_AUTH_GRANTED or EVENT_AUTH_DENIED.
163 UnixDomainSocketTestHelper(const string
& path
, bool allow_user
)
165 allow_user_(allow_user
) {}
// Creates a fresh EventManager and a delegate bound to it for each test.
167 virtual void SetUp() OVERRIDE
{
168 event_manager_
= new EventManager();
169 socket_delegate_
.reset(new TestListenSocketDelegate(event_manager_
));
// Destroys the delegate and drops the fixture's EventManager reference.
173 virtual void TearDown() OVERRIDE
{
176 socket_delegate_
.reset();
177 event_manager_
= NULL
;
// Binds UserCanConnectCallback() to |allow_user_| and |event_manager_|. The
// bound callback holds its own EventManager reference.
180 UnixDomainSocket::AuthCallback
MakeAuthCallback() {
181 return base::Bind(&UserCanConnectCallback
, allow_user_
, event_manager_
);
// Removes any file at |file_path_|. The result of DeleteFile() is ignored,
// so a missing file is not an error.
184 void DeleteSocketFile() {
185 ASSERT_FALSE(file_path_
.empty());
186 base::DeleteFile(file_path_
, false /* not recursive */);
// Opens a PF_UNIX stream socket and connects it to |file_path_|. Returns
// kInvalidSocket (after logging) if socket() or connect() fails. The tests
// use the result as the connected client descriptor.
189 SocketDescriptor
CreateClientSocket() {
190 const SocketDescriptor sock
= CreatePlatformSocket(PF_UNIX
, SOCK_STREAM
, 0);
192 LOG(ERROR
) << "socket() error";
193 return kInvalidSocket
;
196 memset(&addr
, 0, sizeof(addr
));
197 addr
.sun_family
= AF_UNIX
;
// |addr| was zeroed, so |sun_path| stays NUL-terminated unless the path has
// at least sizeof(addr.sun_path) characters.
// NOTE(review): such a path would be truncated and unterminated here.
199 strncpy(addr
.sun_path
, file_path_
.value().c_str(), sizeof(addr
.sun_path
));
200 addr_len
= sizeof(sockaddr_un
);
201 if (connect(sock
, reinterpret_cast<sockaddr
*>(&addr
), addr_len
) != 0) {
202 LOG(ERROR
) << "connect() error";
// Release the descriptor before reporting failure; otherwise every failed
// connect() leaks it, because callers only ever see kInvalidSocket.
IGNORE_EINTR(close(sock));
203 return kInvalidSocket
;
// Starts an IO-message-loop thread and posts CreateAndListen() to it.
// base::Unretained(this) assumes the returned thread is destroyed (and so
// stopped) before this fixture; the tests keep it in a local scoped_ptr.
// NOTE(review): the result of StartWithOptions() is not checked.
208 scoped_ptr
<base::Thread
> CreateAndRunServerThread() {
209 base::Thread::Options options
;
210 options
.message_loop_type
= base::MessageLoop::TYPE_IO
;
211 scoped_ptr
<base::Thread
> thread(new base::Thread("socketio_test"));
212 thread
->StartWithOptions(options
);
213 thread
->message_loop()->PostTask(
215 base::Bind(&UnixDomainSocketTestHelper::CreateAndListen
,
216 base::Unretained(this)));
217 return thread
.Pass();
// Filesystem path of the server socket.
220 const base::FilePath file_path_;
// Authorization decision reported by UserCanConnectCallback().
221 const bool allow_user_
;
// Shared event sink for the delegate, the auth callback and the test body.
222 scoped_refptr
<EventManager
> event_manager_
;
223 scoped_ptr
<TestListenSocketDelegate
> socket_delegate_
;
// Server socket; NULL when creation failed.
224 scoped_ptr
<UnixDomainSocket
> socket_
;
// Fixture: server socket at the default temp-dir path; every client is
// authorized.
227 class UnixDomainSocketTest
: public UnixDomainSocketTestHelper
{
229 UnixDomainSocketTest()
230 : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
// Fixture: server socket at kInvalidSocketPath, where file-backed creation is
// expected to fail; clients would be authorized.
233 class UnixDomainSocketTestWithInvalidPath
: public UnixDomainSocketTestHelper
{
235 UnixDomainSocketTestWithInvalidPath()
236 : UnixDomainSocketTestHelper(kInvalidSocketPath
, true) {}
// Fixture: server socket at the default temp-dir path; every client is
// denied by the auth callback.
239 class UnixDomainSocketTestWithForbiddenUser
240 : public UnixDomainSocketTestHelper
{
242 UnixDomainSocketTestWithForbiddenUser()
243 : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
// Listening on a socket in the temp directory must produce a socket
// (presumably after calling the fixture's CreateAndListen()).
246 TEST_F(UnixDomainSocketTest
, CreateAndListen
) {
248 EXPECT_FALSE(socket_
.get() == NULL
);
// Listening on kInvalidSocketPath must fail and leave |socket_| NULL
// (presumably after calling the fixture's CreateAndListen()).
251 TEST_F(UnixDomainSocketTestWithInvalidPath
, CreateAndListenWithInvalidPath
) {
253 EXPECT_TRUE(socket_
.get() == NULL
);
// The abstract-namespace tests are compiled only when
// SOCKET_ABSTRACT_NAMESPACE_SUPPORTED is defined.
256 #ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
257 // Test with an invalid path to make sure that the socket is not backed by a
// file: creation must succeed here, even though file-backed creation fails
// for the same path (see CreateAndListenWithInvalidPath).
259 TEST_F(UnixDomainSocketTestWithInvalidPath
,
260 CreateAndListenWithAbstractNamespace
) {
261 socket_
= UnixDomainSocket::CreateAndListenWithAbstractNamespace(
262 file_path_
.value(), "", socket_delegate_
.get(), MakeAuthCallback());
263 EXPECT_FALSE(socket_
.get() == NULL
);
// Checks the fallback-name path of CreateAndListenWithAbstractNamespace():
// when the primary abstract name is taken, the fallback name is used.
266 TEST_F(UnixDomainSocketTest
, TestFallbackName
) {
// Occupy the primary abstract name first.
267 scoped_ptr
<UnixDomainSocket
> existing_socket
=
268 UnixDomainSocket::CreateAndListenWithAbstractNamespace(
269 file_path_
.value(), "", socket_delegate_
.get(), MakeAuthCallback());
270 EXPECT_FALSE(existing_socket
.get() == NULL
);
271 // First, try to bind a second socket to the same name without a fallback
// name; this must fail because |existing_socket| still holds the name.
273 UnixDomainSocket::CreateAndListenWithAbstractNamespace(
274 file_path_
.value(), "", socket_delegate_
.get(), MakeAuthCallback());
275 EXPECT_TRUE(socket_
.get() == NULL
);
276 // Now retry with a fallback name, which should succeed.
277 const char kFallbackSocketName
[] = "unix_domain_socket_for_testing_2";
278 socket_
= UnixDomainSocket::CreateAndListenWithAbstractNamespace(
280 MakeSocketPath(kFallbackSocketName
),
281 socket_delegate_
.get(),
283 EXPECT_FALSE(socket_
.get() == NULL
);
// End-to-end happy path. It expects this event sequence:
// listen -> auth granted -> accept -> read (payload checked) -> close.
// NOTE(review): if an ASSERT fails after CreateClientSocket(), |sock| is never
// closed; a scoped fd holder would avoid the leak.
287 TEST_F(UnixDomainSocketTest
, TestWithClient
) {
// The thread is stopped when |server_thread| goes out of scope.
288 const scoped_ptr
<base::Thread
> server_thread
= CreateAndRunServerThread();
289 EventType event
= event_manager_
->WaitForEvent();
290 ASSERT_EQ(EVENT_LISTEN
, event
);
292 // Create the client socket.
293 const SocketDescriptor sock
= CreateClientSocket();
294 ASSERT_NE(kInvalidSocket
, sock
);
295 event
= event_manager_
->WaitForEvent();
296 ASSERT_EQ(EVENT_AUTH_GRANTED
, event
);
297 event
= event_manager_
->WaitForEvent();
298 ASSERT_EQ(EVENT_ACCEPT
, event
);
300 // Send a message from the client to the server.
// sizeof(kMsg) includes the NUL terminator, which DidRead() strips.
301 ssize_t ret
= HANDLE_EINTR(send(sock
, kMsg
, sizeof(kMsg
), 0));
303 ASSERT_EQ(sizeof(kMsg
), static_cast<size_t>(ret
));
304 event
= event_manager_
->WaitForEvent();
305 ASSERT_EQ(EVENT_READ
, event
);
306 ASSERT_EQ(kMsg
, socket_delegate_
->ReceivedData());
308 // Close the client socket.
// NOTE(review): |ret| is assigned here but not checked afterwards; consider
// EXPECT_EQ(0, ret).
309 ret
= IGNORE_EINTR(close(sock
));
310 event
= event_manager_
->WaitForEvent();
311 ASSERT_EQ(EVENT_CLOSE
, event
);
// A denied client must see the server close its connection. After the
// denial, writing must fail with EPIPE and no further events (e.g. accept)
// may be queued.
314 TEST_F(UnixDomainSocketTestWithForbiddenUser
, TestWithForbiddenUser
) {
315 const scoped_ptr
<base::Thread
> server_thread
= CreateAndRunServerThread();
316 EventType event
= event_manager_
->WaitForEvent();
317 ASSERT_EQ(EVENT_LISTEN
, event
);
318 const SocketDescriptor sock
= CreateClientSocket();
319 ASSERT_NE(kInvalidSocket
, sock
);
321 event
= event_manager_
->WaitForEvent();
322 ASSERT_EQ(EVENT_AUTH_DENIED
, event
);
324 // Wait until the file descriptor is closed by the server.
// NOTE(review): poll_fd.fd must be set to |sock| before poll() — confirm.
// The poll() result is not checked.
325 struct pollfd poll_fd
;
327 poll_fd
.events
= POLLIN
;
328 poll(&poll_fd
, 1, -1 /* rely on GTest for timeout handling */);
// NOTE(review): send() on a stream socket whose peer has closed raises
// SIGPIPE unless SIGPIPE is ignored or MSG_NOSIGNAL is passed; confirm the
// test runner ignores SIGPIPE. |ret| is presumably checked for -1 before
// |errno| is inspected — confirm.
331 ssize_t ret
= HANDLE_EINTR(send(sock
, kMsg
, sizeof(kMsg
), 0));
333 ASSERT_EQ(EPIPE
, errno
);
// NOTE(review): |sock| is not closed in the lines shown here.
334 ASSERT_FALSE(event_manager_
->HasPendingEvent());