Revert 268405 "Make sure that ScratchBuffer::Allocate() always r..."
[chromium-blink-merge.git] / net / socket / unix_domain_socket_posix_unittest.cc
blobb1857e62e0e19df8f1d226def89fc954d4f8f951
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include <errno.h>
6 #include <fcntl.h>
7 #include <poll.h>
8 #include <sys/socket.h>
9 #include <sys/stat.h>
10 #include <sys/time.h>
11 #include <sys/types.h>
12 #include <sys/un.h>
13 #include <unistd.h>
15 #include <cstring>
16 #include <queue>
17 #include <string>
19 #include "base/bind.h"
20 #include "base/callback.h"
21 #include "base/compiler_specific.h"
22 #include "base/file_util.h"
23 #include "base/files/file_path.h"
24 #include "base/memory/ref_counted.h"
25 #include "base/memory/scoped_ptr.h"
26 #include "base/message_loop/message_loop.h"
27 #include "base/posix/eintr_wrapper.h"
28 #include "base/synchronization/condition_variable.h"
29 #include "base/synchronization/lock.h"
30 #include "base/threading/platform_thread.h"
31 #include "base/threading/thread.h"
32 #include "net/socket/socket_descriptor.h"
33 #include "net/socket/unix_domain_socket_posix.h"
34 #include "testing/gtest/include/gtest/gtest.h"
36 using std::queue;
37 using std::string;
39 namespace net {
40 namespace {
// Name of the socket file used by the tests; MakeSocketPath() places it under
// the temporary directory.
const char kSocketFilename[] = "unix_domain_socket_for_testing";
// Path for which the tests expect socket creation to fail.
const char kInvalidSocketPath[] = "/invalid/path";
// Payload sent from the client to the server. The tests send sizeof(kMsg)
// bytes, i.e. including the terminating NUL.
const char kMsg[] = "hello";
// Events reported by the socket delegate and the auth callback to the
// EventManager, and consumed by the test bodies.
enum EventType {
  EVENT_ACCEPT,
  EVENT_AUTH_DENIED,
  EVENT_AUTH_GRANTED,
  EVENT_CLOSE,
  EVENT_LISTEN,
  EVENT_READ,
};
55 string MakeSocketPath(const string& socket_file_name) {
56 base::FilePath temp_dir;
57 base::GetTempDir(&temp_dir);
58 return temp_dir.Append(socket_file_name).value();
61 string MakeSocketPath() {
62 return MakeSocketPath(kSocketFilename);
65 class EventManager : public base::RefCounted<EventManager> {
66 public:
67 EventManager() : condition_(&mutex_) {}
69 bool HasPendingEvent() {
70 base::AutoLock lock(mutex_);
71 return !events_.empty();
74 void Notify(EventType event) {
75 base::AutoLock lock(mutex_);
76 events_.push(event);
77 condition_.Broadcast();
80 EventType WaitForEvent() {
81 base::AutoLock lock(mutex_);
82 while (events_.empty())
83 condition_.Wait();
84 EventType event = events_.front();
85 events_.pop();
86 return event;
89 private:
90 friend class base::RefCounted<EventManager>;
91 virtual ~EventManager() {}
93 queue<EventType> events_;
94 base::Lock mutex_;
95 base::ConditionVariable condition_;
98 class TestListenSocketDelegate : public StreamListenSocket::Delegate {
99 public:
100 explicit TestListenSocketDelegate(
101 const scoped_refptr<EventManager>& event_manager)
102 : event_manager_(event_manager) {}
104 virtual void DidAccept(StreamListenSocket* server,
105 scoped_ptr<StreamListenSocket> connection) OVERRIDE {
106 LOG(ERROR) << __PRETTY_FUNCTION__;
107 connection_ = connection.Pass();
108 Notify(EVENT_ACCEPT);
111 virtual void DidRead(StreamListenSocket* connection,
112 const char* data,
113 int len) OVERRIDE {
115 base::AutoLock lock(mutex_);
116 DCHECK(len);
117 data_.assign(data, len - 1);
119 Notify(EVENT_READ);
122 virtual void DidClose(StreamListenSocket* sock) OVERRIDE {
123 Notify(EVENT_CLOSE);
126 void OnListenCompleted() {
127 Notify(EVENT_LISTEN);
130 string ReceivedData() {
131 base::AutoLock lock(mutex_);
132 return data_;
135 private:
136 void Notify(EventType event) {
137 event_manager_->Notify(event);
140 const scoped_refptr<EventManager> event_manager_;
141 scoped_ptr<StreamListenSocket> connection_;
142 base::Lock mutex_;
143 string data_;
146 bool UserCanConnectCallback(
147 bool allow_user, const scoped_refptr<EventManager>& event_manager,
148 uid_t, gid_t) {
149 event_manager->Notify(
150 allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED);
151 return allow_user;
154 class UnixDomainSocketTestHelper : public testing::Test {
155 public:
156 void CreateAndListen() {
157 socket_ = UnixDomainSocket::CreateAndListen(
158 file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
159 socket_delegate_->OnListenCompleted();
162 protected:
163 UnixDomainSocketTestHelper(const string& path, bool allow_user)
164 : file_path_(path),
165 allow_user_(allow_user) {}
167 virtual void SetUp() OVERRIDE {
168 event_manager_ = new EventManager();
169 socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
170 DeleteSocketFile();
173 virtual void TearDown() OVERRIDE {
174 DeleteSocketFile();
175 socket_.reset();
176 socket_delegate_.reset();
177 event_manager_ = NULL;
180 UnixDomainSocket::AuthCallback MakeAuthCallback() {
181 return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
184 void DeleteSocketFile() {
185 ASSERT_FALSE(file_path_.empty());
186 base::DeleteFile(file_path_, false /* not recursive */);
189 SocketDescriptor CreateClientSocket() {
190 const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
191 if (sock < 0) {
192 LOG(ERROR) << "socket() error";
193 return kInvalidSocket;
195 sockaddr_un addr;
196 memset(&addr, 0, sizeof(addr));
197 addr.sun_family = AF_UNIX;
198 socklen_t addr_len;
199 strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
200 addr_len = sizeof(sockaddr_un);
201 if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
202 LOG(ERROR) << "connect() error";
203 return kInvalidSocket;
205 return sock;
208 scoped_ptr<base::Thread> CreateAndRunServerThread() {
209 base::Thread::Options options;
210 options.message_loop_type = base::MessageLoop::TYPE_IO;
211 scoped_ptr<base::Thread> thread(new base::Thread("socketio_test"));
212 thread->StartWithOptions(options);
213 thread->message_loop()->PostTask(
214 FROM_HERE,
215 base::Bind(&UnixDomainSocketTestHelper::CreateAndListen,
216 base::Unretained(this)));
217 return thread.Pass();
220 const base::FilePath file_path_;
221 const bool allow_user_;
222 scoped_refptr<EventManager> event_manager_;
223 scoped_ptr<TestListenSocketDelegate> socket_delegate_;
224 scoped_ptr<UnixDomainSocket> socket_;
227 class UnixDomainSocketTest : public UnixDomainSocketTestHelper {
228 protected:
229 UnixDomainSocketTest()
230 : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
233 class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper {
234 protected:
235 UnixDomainSocketTestWithInvalidPath()
236 : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {}
239 class UnixDomainSocketTestWithForbiddenUser
240 : public UnixDomainSocketTestHelper {
241 protected:
242 UnixDomainSocketTestWithForbiddenUser()
243 : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
246 TEST_F(UnixDomainSocketTest, CreateAndListen) {
247 CreateAndListen();
248 EXPECT_FALSE(socket_.get() == NULL);
251 TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
252 CreateAndListen();
253 EXPECT_TRUE(socket_.get() == NULL);
256 #ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
257 // Test with an invalid path to make sure that the socket is not backed by a
258 // file.
259 TEST_F(UnixDomainSocketTestWithInvalidPath,
260 CreateAndListenWithAbstractNamespace) {
261 socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
262 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
263 EXPECT_FALSE(socket_.get() == NULL);
266 TEST_F(UnixDomainSocketTest, TestFallbackName) {
267 scoped_ptr<UnixDomainSocket> existing_socket =
268 UnixDomainSocket::CreateAndListenWithAbstractNamespace(
269 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
270 EXPECT_FALSE(existing_socket.get() == NULL);
271 // First, try to bind socket with the same name with no fallback name.
272 socket_ =
273 UnixDomainSocket::CreateAndListenWithAbstractNamespace(
274 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
275 EXPECT_TRUE(socket_.get() == NULL);
276 // Now with a fallback name.
277 const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2";
278 socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
279 file_path_.value(),
280 MakeSocketPath(kFallbackSocketName),
281 socket_delegate_.get(),
282 MakeAuthCallback());
283 EXPECT_FALSE(socket_.get() == NULL);
285 #endif
287 TEST_F(UnixDomainSocketTest, TestWithClient) {
288 const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
289 EventType event = event_manager_->WaitForEvent();
290 ASSERT_EQ(EVENT_LISTEN, event);
292 // Create the client socket.
293 const SocketDescriptor sock = CreateClientSocket();
294 ASSERT_NE(kInvalidSocket, sock);
295 event = event_manager_->WaitForEvent();
296 ASSERT_EQ(EVENT_AUTH_GRANTED, event);
297 event = event_manager_->WaitForEvent();
298 ASSERT_EQ(EVENT_ACCEPT, event);
300 // Send a message from the client to the server.
301 ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
302 ASSERT_NE(-1, ret);
303 ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret));
304 event = event_manager_->WaitForEvent();
305 ASSERT_EQ(EVENT_READ, event);
306 ASSERT_EQ(kMsg, socket_delegate_->ReceivedData());
308 // Close the client socket.
309 ret = IGNORE_EINTR(close(sock));
310 event = event_manager_->WaitForEvent();
311 ASSERT_EQ(EVENT_CLOSE, event);
314 TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
315 const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
316 EventType event = event_manager_->WaitForEvent();
317 ASSERT_EQ(EVENT_LISTEN, event);
318 const SocketDescriptor sock = CreateClientSocket();
319 ASSERT_NE(kInvalidSocket, sock);
321 event = event_manager_->WaitForEvent();
322 ASSERT_EQ(EVENT_AUTH_DENIED, event);
324 // Wait until the file descriptor is closed by the server.
325 struct pollfd poll_fd;
326 poll_fd.fd = sock;
327 poll_fd.events = POLLIN;
328 poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */);
330 // Send() must fail.
331 ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
332 ASSERT_EQ(-1, ret);
333 ASSERT_EQ(EPIPE, errno);
334 ASSERT_FALSE(event_manager_->HasPendingEvent());
337 } // namespace
338 } // namespace net