Change next_proto member type.
[chromium-blink-merge.git] / tools / android / forwarder2 / socket.cc
blob05dfcbd16a94ecbb0a179b9fb0621799803c9554
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "tools/android/forwarder2/socket.h"
7 #include <arpa/inet.h>
8 #include <fcntl.h>
9 #include <netdb.h>
10 #include <netinet/in.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
namespace {

// Timeout value handed to WaitForEvent() when the caller wants to block
// indefinitely: only strictly positive timeouts are turned into a timeval.
const int kNoTimeout = -1;

// How long Connect() waits for a pending connection to complete.
const int kConnectTimeOut = 10;  // Seconds.

// Returns true when |family| is one of the two IP address families
// (AF_INET or AF_INET6) and false for anything else.
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
32 namespace forwarder2 {
34 bool Socket::BindUnix(const std::string& path) {
35 errno = 0;
36 if (!InitUnixSocket(path) || !BindAndListen()) {
37 Close();
38 return false;
40 return true;
43 bool Socket::BindTcp(const std::string& host, int port) {
44 errno = 0;
45 if (!InitTcpSocket(host, port) || !BindAndListen()) {
46 Close();
47 return false;
49 return true;
52 bool Socket::ConnectUnix(const std::string& path) {
53 errno = 0;
54 if (!InitUnixSocket(path) || !Connect()) {
55 Close();
56 return false;
58 return true;
61 bool Socket::ConnectTcp(const std::string& host, int port) {
62 errno = 0;
63 if (!InitTcpSocket(host, port) || !Connect()) {
64 Close();
65 return false;
67 return true;
70 Socket::Socket()
71 : socket_(-1),
72 port_(0),
73 socket_error_(false),
74 family_(AF_INET),
75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76 addr_len_(sizeof(sockaddr)) {
77 memset(&addr_, 0, sizeof(addr_));
80 Socket::~Socket() {
81 Close();
84 void Socket::Shutdown() {
85 if (!IsClosed()) {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
90 void Socket::Close() {
91 if (!IsClosed()) {
92 CloseFD(socket_);
93 socket_ = -1;
97 bool Socket::InitSocketInternal() {
98 socket_ = socket(family_, SOCK_STREAM, 0);
99 if (socket_ < 0) {
100 PLOG(ERROR) << "socket";
101 return false;
103 tools::DisableNagle(socket_);
104 int reuse_addr = 1;
105 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
106 sizeof(reuse_addr));
107 if (!SetNonBlocking())
108 return false;
109 return true;
112 bool Socket::SetNonBlocking() {
113 const int flags = fcntl(socket_, F_GETFL);
114 if (flags < 0) {
115 PLOG(ERROR) << "fcntl";
116 return false;
118 if (flags & O_NONBLOCK)
119 return true;
120 if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
121 PLOG(ERROR) << "fcntl";
122 return false;
124 return true;
127 bool Socket::InitUnixSocket(const std::string& path) {
128 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
129 // For abstract sockets we need one extra byte for the leading zero.
130 if (path.size() + 2 /* '\0' */ > kPathMax) {
131 LOG(ERROR) << "The provided path is too big to create a unix "
132 << "domain socket: " << path;
133 return false;
135 family_ = PF_UNIX;
136 addr_.addr_un.sun_family = family_;
137 // Copied from net/socket/unix_domain_socket_posix.cc
138 // Convert the path given into abstract socket name. It must start with
139 // the '\0' character, so we are adding it. |addr_len| must specify the
140 // length of the structure exactly, as potentially the socket name may
141 // have '\0' characters embedded (although we don't support this).
142 // Note that addr_.addr_un.sun_path is already zero initialized.
143 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
144 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
145 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
146 return InitSocketInternal();
149 bool Socket::InitTcpSocket(const std::string& host, int port) {
150 port_ = port;
151 if (host.empty()) {
152 // Use localhost: INADDR_LOOPBACK
153 family_ = AF_INET;
154 addr_.addr4.sin_family = family_;
155 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
156 } else if (!Resolve(host)) {
157 return false;
159 CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
160 if (family_ == AF_INET) {
161 addr_.addr4.sin_port = htons(port_);
162 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
163 addr_len_ = sizeof(addr_.addr4);
164 } else if (family_ == AF_INET6) {
165 addr_.addr6.sin6_port = htons(port_);
166 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
167 addr_len_ = sizeof(addr_.addr6);
169 return InitSocketInternal();
172 bool Socket::BindAndListen() {
173 errno = 0;
174 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
175 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
176 PLOG(ERROR) << "bind/listen";
177 SetSocketError();
178 return false;
180 if (port_ == 0 && FamilyIsTCP(family_)) {
181 SockAddr addr;
182 memset(&addr, 0, sizeof(addr));
183 socklen_t addrlen = 0;
184 sockaddr* addr_ptr = NULL;
185 uint16* port_ptr = NULL;
186 if (family_ == AF_INET) {
187 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
188 port_ptr = &addr.addr4.sin_port;
189 addrlen = sizeof(addr.addr4);
190 } else if (family_ == AF_INET6) {
191 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
192 port_ptr = &addr.addr6.sin6_port;
193 addrlen = sizeof(addr.addr6);
195 errno = 0;
196 if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
197 PLOG(ERROR) << "getsockname";
198 SetSocketError();
199 return false;
201 port_ = ntohs(*port_ptr);
203 return true;
206 bool Socket::Accept(Socket* new_socket) {
207 DCHECK(new_socket != NULL);
208 if (!WaitForEvent(READ, kNoTimeout)) {
209 SetSocketError();
210 return false;
212 errno = 0;
213 int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
214 if (new_socket_fd < 0) {
215 SetSocketError();
216 return false;
218 tools::DisableNagle(new_socket_fd);
219 new_socket->socket_ = new_socket_fd;
220 if (!new_socket->SetNonBlocking())
221 return false;
222 return true;
225 bool Socket::Connect() {
226 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
227 errno = 0;
228 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
229 errno != EINPROGRESS) {
230 SetSocketError();
231 return false;
233 // Wait for connection to complete, or receive a notification.
234 if (!WaitForEvent(WRITE, kConnectTimeOut)) {
235 SetSocketError();
236 return false;
238 int socket_errno;
239 socklen_t opt_len = sizeof(socket_errno);
240 if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
241 PLOG(ERROR) << "getsockopt()";
242 SetSocketError();
243 return false;
245 if (socket_errno != 0) {
246 LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
247 SetSocketError();
248 return false;
250 return true;
253 bool Socket::Resolve(const std::string& host) {
254 struct addrinfo hints;
255 struct addrinfo* res;
256 memset(&hints, 0, sizeof(hints));
257 hints.ai_family = AF_UNSPEC;
258 hints.ai_socktype = SOCK_STREAM;
259 hints.ai_flags |= AI_CANONNAME;
261 int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
262 if (errcode != 0) {
263 errno = 0;
264 SetSocketError();
265 freeaddrinfo(res);
266 return false;
268 family_ = res->ai_family;
269 switch (res->ai_family) {
270 case AF_INET:
271 memcpy(&addr_.addr4,
272 reinterpret_cast<sockaddr_in*>(res->ai_addr),
273 sizeof(sockaddr_in));
274 break;
275 case AF_INET6:
276 memcpy(&addr_.addr6,
277 reinterpret_cast<sockaddr_in6*>(res->ai_addr),
278 sizeof(sockaddr_in6));
279 break;
281 freeaddrinfo(res);
282 return true;
285 int Socket::GetPort() {
286 if (!FamilyIsTCP(family_)) {
287 LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
288 return 0;
290 return port_;
293 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
294 size_t bytes_read = 0;
295 int ret = 1;
296 while (bytes_read < num_bytes && ret > 0) {
297 ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
298 if (ret >= 0)
299 bytes_read += ret;
301 return bytes_read;
304 void Socket::SetSocketError() {
305 socket_error_ = true;
306 DCHECK_NE(EAGAIN, errno);
307 DCHECK_NE(EWOULDBLOCK, errno);
308 Close();
311 int Socket::Read(void* buffer, size_t buffer_size) {
312 if (!WaitForEvent(READ, kNoTimeout)) {
313 SetSocketError();
314 return 0;
316 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
317 if (ret < 0) {
318 PLOG(ERROR) << "read";
319 SetSocketError();
321 return ret;
324 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
325 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
326 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
327 if (ret < 0) {
328 PLOG(ERROR) << "read";
329 SetSocketError();
331 return ret;
334 int Socket::Write(const void* buffer, size_t count) {
335 if (!WaitForEvent(WRITE, kNoTimeout)) {
336 SetSocketError();
337 return 0;
339 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
340 if (ret < 0) {
341 PLOG(ERROR) << "send";
342 SetSocketError();
344 return ret;
347 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
348 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
349 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
350 if (ret < 0) {
351 PLOG(ERROR) << "send";
352 SetSocketError();
354 return ret;
357 int Socket::WriteString(const std::string& buffer) {
358 return WriteNumBytes(buffer.c_str(), buffer.size());
361 void Socket::AddEventFd(int event_fd) {
362 Event event;
363 event.fd = event_fd;
364 event.was_fired = false;
365 events_.push_back(event);
368 bool Socket::DidReceiveEventOnFd(int fd) const {
369 for (size_t i = 0; i < events_.size(); ++i)
370 if (events_[i].fd == fd)
371 return events_[i].was_fired;
372 return false;
375 bool Socket::DidReceiveEvent() const {
376 for (size_t i = 0; i < events_.size(); ++i)
377 if (events_[i].was_fired)
378 return true;
379 return false;
382 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
383 size_t bytes_written = 0;
384 int ret = 1;
385 while (bytes_written < num_bytes && ret > 0) {
386 ret = Write(static_cast<const char*>(buffer) + bytes_written,
387 num_bytes - bytes_written);
388 if (ret >= 0)
389 bytes_written += ret;
391 return bytes_written;
394 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
395 if (socket_ == -1)
396 return true;
397 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
398 fd_set read_fds;
399 fd_set write_fds;
400 FD_ZERO(&read_fds);
401 FD_ZERO(&write_fds);
402 if (type == READ)
403 FD_SET(socket_, &read_fds);
404 else
405 FD_SET(socket_, &write_fds);
406 for (size_t i = 0; i < events_.size(); ++i)
407 FD_SET(events_[i].fd, &read_fds);
408 timeval tv = {};
409 timeval* tv_ptr = NULL;
410 if (timeout_secs > 0) {
411 tv.tv_sec = timeout_secs;
412 tv.tv_usec = 0;
413 tv_ptr = &tv;
415 int max_fd = socket_;
416 for (size_t i = 0; i < events_.size(); ++i)
417 if (events_[i].fd > max_fd)
418 max_fd = events_[i].fd;
419 if (HANDLE_EINTR(
420 select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
421 PLOG(ERROR) << "select";
422 return false;
424 bool event_was_fired = false;
425 for (size_t i = 0; i < events_.size(); ++i) {
426 if (FD_ISSET(events_[i].fd, &read_fds)) {
427 events_[i].was_fired = true;
428 event_was_fired = true;
431 return !event_was_fired;
434 // static
435 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
436 Socket socket;
437 if (!socket.ConnectUnix(path))
438 return -1;
439 ucred ucred;
440 socklen_t len = sizeof(ucred);
441 if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
442 CHECK_NE(ENOPROTOOPT, errno);
443 return -1;
445 return ucred.pid;
448 } // namespace forwarder2