Minor Python style clean-up
[chromium-blink-merge.git] / tools / android / forwarder2 / socket.cc
blob23ff886d01cae973866a6bc01b8c264b110a7199
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "tools/android/forwarder2/socket.h"
7 #include <arpa/inet.h>
8 #include <fcntl.h>
9 #include <netdb.h>
10 #include <netinet/in.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/posix/safe_strerror.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
namespace {

// Timeout value handed to WaitForEvent() when the caller wants to block until
// something happens (non-positive values disable the select() timeout).
const int kNoTimeout = -1;
// How long Connect() waits for the connection to complete, in seconds.
const int kConnectTimeOut = 10;

// Returns true if |family| is one of the IP address families (IPv4 or IPv6).
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
32 namespace forwarder2 {
34 bool Socket::BindUnix(const std::string& path) {
35 errno = 0;
36 if (!InitUnixSocket(path) || !BindAndListen()) {
37 Close();
38 return false;
40 return true;
43 bool Socket::BindTcp(const std::string& host, int port) {
44 errno = 0;
45 if (!InitTcpSocket(host, port) || !BindAndListen()) {
46 Close();
47 return false;
49 return true;
52 bool Socket::ConnectUnix(const std::string& path) {
53 errno = 0;
54 if (!InitUnixSocket(path) || !Connect()) {
55 Close();
56 return false;
58 return true;
61 bool Socket::ConnectTcp(const std::string& host, int port) {
62 errno = 0;
63 if (!InitTcpSocket(host, port) || !Connect()) {
64 Close();
65 return false;
67 return true;
70 Socket::Socket()
71 : socket_(-1),
72 port_(0),
73 socket_error_(false),
74 family_(AF_INET),
75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76 addr_len_(sizeof(sockaddr)) {
77 memset(&addr_, 0, sizeof(addr_));
80 Socket::~Socket() {
81 Close();
84 void Socket::Shutdown() {
85 if (!IsClosed()) {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
90 void Socket::Close() {
91 if (!IsClosed()) {
92 CloseFD(socket_);
93 socket_ = -1;
97 bool Socket::InitSocketInternal() {
98 socket_ = socket(family_, SOCK_STREAM, 0);
99 if (socket_ < 0) {
100 PLOG(ERROR) << "socket";
101 return false;
103 tools::DisableNagle(socket_);
104 int reuse_addr = 1;
105 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
106 sizeof(reuse_addr));
107 if (!SetNonBlocking())
108 return false;
109 return true;
112 bool Socket::SetNonBlocking() {
113 const int flags = fcntl(socket_, F_GETFL);
114 if (flags < 0) {
115 PLOG(ERROR) << "fcntl";
116 return false;
118 if (flags & O_NONBLOCK)
119 return true;
120 if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
121 PLOG(ERROR) << "fcntl";
122 return false;
124 return true;
127 bool Socket::InitUnixSocket(const std::string& path) {
128 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
129 // For abstract sockets we need one extra byte for the leading zero.
130 if (path.size() + 2 /* '\0' */ > kPathMax) {
131 LOG(ERROR) << "The provided path is too big to create a unix "
132 << "domain socket: " << path;
133 return false;
135 family_ = PF_UNIX;
136 addr_.addr_un.sun_family = family_;
137 // Copied from net/socket/unix_domain_socket_posix.cc
138 // Convert the path given into abstract socket name. It must start with
139 // the '\0' character, so we are adding it. |addr_len| must specify the
140 // length of the structure exactly, as potentially the socket name may
141 // have '\0' characters embedded (although we don't support this).
142 // Note that addr_.addr_un.sun_path is already zero initialized.
143 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
144 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
145 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
146 return InitSocketInternal();
149 bool Socket::InitTcpSocket(const std::string& host, int port) {
150 port_ = port;
151 if (host.empty()) {
152 // Use localhost: INADDR_LOOPBACK
153 family_ = AF_INET;
154 addr_.addr4.sin_family = family_;
155 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
156 } else if (!Resolve(host)) {
157 return false;
159 CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
160 if (family_ == AF_INET) {
161 addr_.addr4.sin_port = htons(port_);
162 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
163 addr_len_ = sizeof(addr_.addr4);
164 } else if (family_ == AF_INET6) {
165 addr_.addr6.sin6_port = htons(port_);
166 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
167 addr_len_ = sizeof(addr_.addr6);
169 return InitSocketInternal();
172 bool Socket::BindAndListen() {
173 errno = 0;
174 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
175 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
176 PLOG(ERROR) << "bind/listen";
177 SetSocketError();
178 return false;
180 if (port_ == 0 && FamilyIsTCP(family_)) {
181 SockAddr addr;
182 memset(&addr, 0, sizeof(addr));
183 socklen_t addrlen = 0;
184 sockaddr* addr_ptr = NULL;
185 uint16* port_ptr = NULL;
186 if (family_ == AF_INET) {
187 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
188 port_ptr = &addr.addr4.sin_port;
189 addrlen = sizeof(addr.addr4);
190 } else if (family_ == AF_INET6) {
191 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
192 port_ptr = &addr.addr6.sin6_port;
193 addrlen = sizeof(addr.addr6);
195 errno = 0;
196 if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
197 PLOG(ERROR) << "getsockname";
198 SetSocketError();
199 return false;
201 port_ = ntohs(*port_ptr);
203 return true;
206 bool Socket::Accept(Socket* new_socket) {
207 DCHECK(new_socket != NULL);
208 if (!WaitForEvent(READ, kNoTimeout)) {
209 SetSocketError();
210 return false;
212 errno = 0;
213 int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
214 if (new_socket_fd < 0) {
215 SetSocketError();
216 return false;
218 tools::DisableNagle(new_socket_fd);
219 new_socket->socket_ = new_socket_fd;
220 if (!new_socket->SetNonBlocking())
221 return false;
222 return true;
225 bool Socket::Connect() {
226 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
227 errno = 0;
228 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
229 errno != EINPROGRESS) {
230 SetSocketError();
231 return false;
233 // Wait for connection to complete, or receive a notification.
234 if (!WaitForEvent(WRITE, kConnectTimeOut)) {
235 SetSocketError();
236 return false;
238 int socket_errno;
239 socklen_t opt_len = sizeof(socket_errno);
240 if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
241 PLOG(ERROR) << "getsockopt()";
242 SetSocketError();
243 return false;
245 if (socket_errno != 0) {
246 LOG(ERROR) << "Could not connect to host: "
247 << base::safe_strerror(socket_errno);
248 SetSocketError();
249 return false;
251 return true;
254 bool Socket::Resolve(const std::string& host) {
255 struct addrinfo hints;
256 struct addrinfo* res;
257 memset(&hints, 0, sizeof(hints));
258 hints.ai_family = AF_UNSPEC;
259 hints.ai_socktype = SOCK_STREAM;
260 hints.ai_flags |= AI_CANONNAME;
262 int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
263 if (errcode != 0) {
264 errno = 0;
265 SetSocketError();
266 freeaddrinfo(res);
267 return false;
269 family_ = res->ai_family;
270 switch (res->ai_family) {
271 case AF_INET:
272 memcpy(&addr_.addr4,
273 reinterpret_cast<sockaddr_in*>(res->ai_addr),
274 sizeof(sockaddr_in));
275 break;
276 case AF_INET6:
277 memcpy(&addr_.addr6,
278 reinterpret_cast<sockaddr_in6*>(res->ai_addr),
279 sizeof(sockaddr_in6));
280 break;
282 freeaddrinfo(res);
283 return true;
286 int Socket::GetPort() {
287 if (!FamilyIsTCP(family_)) {
288 LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
289 return 0;
291 return port_;
294 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
295 size_t bytes_read = 0;
296 int ret = 1;
297 while (bytes_read < num_bytes && ret > 0) {
298 ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
299 if (ret >= 0)
300 bytes_read += ret;
302 return bytes_read;
305 void Socket::SetSocketError() {
306 socket_error_ = true;
307 DCHECK_NE(EAGAIN, errno);
308 DCHECK_NE(EWOULDBLOCK, errno);
309 Close();
312 int Socket::Read(void* buffer, size_t buffer_size) {
313 if (!WaitForEvent(READ, kNoTimeout)) {
314 SetSocketError();
315 return 0;
317 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
318 if (ret < 0) {
319 PLOG(ERROR) << "read";
320 SetSocketError();
322 return ret;
325 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
326 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
327 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
328 if (ret < 0) {
329 PLOG(ERROR) << "read";
330 SetSocketError();
332 return ret;
335 int Socket::Write(const void* buffer, size_t count) {
336 if (!WaitForEvent(WRITE, kNoTimeout)) {
337 SetSocketError();
338 return 0;
340 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
341 if (ret < 0) {
342 PLOG(ERROR) << "send";
343 SetSocketError();
345 return ret;
348 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
349 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
350 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
351 if (ret < 0) {
352 PLOG(ERROR) << "send";
353 SetSocketError();
355 return ret;
358 int Socket::WriteString(const std::string& buffer) {
359 return WriteNumBytes(buffer.c_str(), buffer.size());
362 void Socket::AddEventFd(int event_fd) {
363 Event event;
364 event.fd = event_fd;
365 event.was_fired = false;
366 events_.push_back(event);
369 bool Socket::DidReceiveEventOnFd(int fd) const {
370 for (size_t i = 0; i < events_.size(); ++i)
371 if (events_[i].fd == fd)
372 return events_[i].was_fired;
373 return false;
376 bool Socket::DidReceiveEvent() const {
377 for (size_t i = 0; i < events_.size(); ++i)
378 if (events_[i].was_fired)
379 return true;
380 return false;
383 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
384 size_t bytes_written = 0;
385 int ret = 1;
386 while (bytes_written < num_bytes && ret > 0) {
387 ret = Write(static_cast<const char*>(buffer) + bytes_written,
388 num_bytes - bytes_written);
389 if (ret >= 0)
390 bytes_written += ret;
392 return bytes_written;
395 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
396 if (socket_ == -1)
397 return true;
398 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
399 fd_set read_fds;
400 fd_set write_fds;
401 FD_ZERO(&read_fds);
402 FD_ZERO(&write_fds);
403 if (type == READ)
404 FD_SET(socket_, &read_fds);
405 else
406 FD_SET(socket_, &write_fds);
407 for (size_t i = 0; i < events_.size(); ++i)
408 FD_SET(events_[i].fd, &read_fds);
409 timeval tv = {};
410 timeval* tv_ptr = NULL;
411 if (timeout_secs > 0) {
412 tv.tv_sec = timeout_secs;
413 tv.tv_usec = 0;
414 tv_ptr = &tv;
416 int max_fd = socket_;
417 for (size_t i = 0; i < events_.size(); ++i)
418 if (events_[i].fd > max_fd)
419 max_fd = events_[i].fd;
420 if (HANDLE_EINTR(
421 select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
422 PLOG(ERROR) << "select";
423 return false;
425 bool event_was_fired = false;
426 for (size_t i = 0; i < events_.size(); ++i) {
427 if (FD_ISSET(events_[i].fd, &read_fds)) {
428 events_[i].was_fired = true;
429 event_was_fired = true;
432 return !event_was_fired;
435 // static
436 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
437 Socket socket;
438 if (!socket.ConnectUnix(path))
439 return -1;
440 ucred ucred;
441 socklen_t len = sizeof(ucred);
442 if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
443 CHECK_NE(ENOPROTOOPT, errno);
444 return -1;
446 return ucred.pid;
449 } // namespace forwarder2