When Retrier succeeds, record errors it encountered.
[chromium-blink-merge.git] / tools / android / forwarder2 / socket.cc
blob242af4048bdbe91fa57e632d0988f1ffaa98ccf3
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "tools/android/forwarder2/socket.h"
7 #include <arpa/inet.h>
8 #include <fcntl.h>
9 #include <netdb.h>
10 #include <netinet/in.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
namespace {

// Timeout value for WaitForEvent() that means "no deadline": select() is
// given a NULL timeval and blocks until a descriptor becomes ready.
const int kNoTimeout = -1;
// How long Connect() waits for the connection to become writable.
const int kConnectTimeOut = 10;  // Seconds.

// Returns true when |family| is one of the IP address families (IPv4/IPv6),
// i.e. the families this file treats as TCP sockets.
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
32 namespace forwarder2 {
34 bool Socket::BindUnix(const std::string& path) {
35 errno = 0;
36 if (!InitUnixSocket(path) || !BindAndListen()) {
37 Close();
38 return false;
40 return true;
43 bool Socket::BindTcp(const std::string& host, int port) {
44 errno = 0;
45 if (!InitTcpSocket(host, port) || !BindAndListen()) {
46 Close();
47 return false;
49 return true;
52 bool Socket::ConnectUnix(const std::string& path) {
53 errno = 0;
54 if (!InitUnixSocket(path) || !Connect()) {
55 Close();
56 return false;
58 return true;
61 bool Socket::ConnectTcp(const std::string& host, int port) {
62 errno = 0;
63 if (!InitTcpSocket(host, port) || !Connect()) {
64 Close();
65 return false;
67 return true;
70 Socket::Socket()
71 : socket_(-1),
72 port_(0),
73 socket_error_(false),
74 family_(AF_INET),
75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76 addr_len_(sizeof(sockaddr)) {
77 memset(&addr_, 0, sizeof(addr_));
80 Socket::~Socket() {
81 Close();
84 void Socket::Shutdown() {
85 if (!IsClosed()) {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
90 void Socket::Close() {
91 if (!IsClosed()) {
92 CloseFD(socket_);
93 socket_ = -1;
97 bool Socket::InitSocketInternal() {
98 socket_ = socket(family_, SOCK_STREAM, 0);
99 if (socket_ < 0)
100 return false;
101 tools::DisableNagle(socket_);
102 int reuse_addr = 1;
103 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
104 &reuse_addr, sizeof(reuse_addr));
105 tools::DeferAccept(socket_);
106 return true;
109 bool Socket::InitUnixSocket(const std::string& path) {
110 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
111 // For abstract sockets we need one extra byte for the leading zero.
112 if (path.size() + 2 /* '\0' */ > kPathMax) {
113 LOG(ERROR) << "The provided path is too big to create a unix "
114 << "domain socket: " << path;
115 return false;
117 family_ = PF_UNIX;
118 addr_.addr_un.sun_family = family_;
119 // Copied from net/socket/unix_domain_socket_posix.cc
120 // Convert the path given into abstract socket name. It must start with
121 // the '\0' character, so we are adding it. |addr_len| must specify the
122 // length of the structure exactly, as potentially the socket name may
123 // have '\0' characters embedded (although we don't support this).
124 // Note that addr_.addr_un.sun_path is already zero initialized.
125 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
126 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
127 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
128 return InitSocketInternal();
131 bool Socket::InitTcpSocket(const std::string& host, int port) {
132 port_ = port;
133 if (host.empty()) {
134 // Use localhost: INADDR_LOOPBACK
135 family_ = AF_INET;
136 addr_.addr4.sin_family = family_;
137 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
138 } else if (!Resolve(host)) {
139 return false;
141 CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
142 if (family_ == AF_INET) {
143 addr_.addr4.sin_port = htons(port_);
144 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
145 addr_len_ = sizeof(addr_.addr4);
146 } else if (family_ == AF_INET6) {
147 addr_.addr6.sin6_port = htons(port_);
148 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
149 addr_len_ = sizeof(addr_.addr6);
151 return InitSocketInternal();
154 bool Socket::BindAndListen() {
155 errno = 0;
156 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
157 HANDLE_EINTR(listen(socket_, 5)) < 0) {
158 SetSocketError();
159 return false;
161 if (port_ == 0 && FamilyIsTCP(family_)) {
162 SockAddr addr;
163 memset(&addr, 0, sizeof(addr));
164 socklen_t addrlen = 0;
165 sockaddr* addr_ptr = NULL;
166 uint16* port_ptr = NULL;
167 if (family_ == AF_INET) {
168 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
169 port_ptr = &addr.addr4.sin_port;
170 addrlen = sizeof(addr.addr4);
171 } else if (family_ == AF_INET6) {
172 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
173 port_ptr = &addr.addr6.sin6_port;
174 addrlen = sizeof(addr.addr6);
176 errno = 0;
177 if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
178 LOG(ERROR) << "getsockname error: " << safe_strerror(errno);;
179 SetSocketError();
180 return false;
182 port_ = ntohs(*port_ptr);
184 return true;
187 bool Socket::Accept(Socket* new_socket) {
188 DCHECK(new_socket != NULL);
189 if (!WaitForEvent(READ, kNoTimeout)) {
190 SetSocketError();
191 return false;
193 errno = 0;
194 int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
195 if (new_socket_fd < 0) {
196 SetSocketError();
197 return false;
200 tools::DisableNagle(new_socket_fd);
201 new_socket->socket_ = new_socket_fd;
202 return true;
205 bool Socket::Connect() {
206 // Set non-block because we use select for connect.
207 const int kFlags = fcntl(socket_, F_GETFL);
208 DCHECK(!(kFlags & O_NONBLOCK));
209 fcntl(socket_, F_SETFL, kFlags | O_NONBLOCK);
210 errno = 0;
211 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
212 errno != EINPROGRESS) {
213 SetSocketError();
214 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
215 return false;
217 // Wait for connection to complete, or receive a notification.
218 if (!WaitForEvent(WRITE, kConnectTimeOut)) {
219 SetSocketError();
220 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
221 return false;
223 int socket_errno;
224 socklen_t opt_len = sizeof(socket_errno);
225 if (!getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
226 LOG(ERROR) << "getsockopt(): " << safe_strerror(errno);
227 SetSocketError();
228 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
229 return false;
231 if (socket_errno != 0) {
232 LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
233 SetSocketError();
234 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
235 return false;
237 fcntl(socket_, F_SETFL, kFlags);
238 return true;
// Resolves |host| with getaddrinfo() (any address family, stream sockets).
// The first result's family is stored in family_, and an IPv4/IPv6 result is
// copied into addr_. The getaddrinfo() result list is always released before
// returning. Returns false, after logging and SetSocketError(), when
// resolution fails.
bool Socket::Resolve(const std::string& host) {
  struct addrinfo hints;
  struct addrinfo* res = NULL;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags |= AI_CANONNAME;

  const int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
  if (errcode != 0) {
    // getaddrinfo() reports failures via its return code, not errno.
    LOG(ERROR) << "getaddrinfo(" << host << "): " << gai_strerror(errcode);
    SetSocketError();
    return false;
  }
  family_ = res->ai_family;
  switch (res->ai_family) {
    case AF_INET:
      memcpy(&addr_.addr4,
             reinterpret_cast<sockaddr_in*>(res->ai_addr),
             sizeof(sockaddr_in));
      break;
    case AF_INET6:
      memcpy(&addr_.addr6,
             reinterpret_cast<sockaddr_in6*>(res->ai_addr),
             sizeof(sockaddr_in6));
      break;
  }
  // The list is heap-allocated by getaddrinfo(); it was previously leaked.
  freeaddrinfo(res);
  return true;
}
270 int Socket::GetPort() {
271 if (!FamilyIsTCP(family_)) {
272 LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
273 return 0;
275 return port_;
278 bool Socket::IsFdInSet(const fd_set& fds) const {
279 if (IsClosed())
280 return false;
281 return FD_ISSET(socket_, &fds);
284 bool Socket::AddFdToSet(fd_set* fds) const {
285 if (IsClosed())
286 return false;
287 FD_SET(socket_, fds);
288 return true;
291 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
292 int bytes_read = 0;
293 int ret = 1;
294 while (bytes_read < num_bytes && ret > 0) {
295 ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
296 if (ret >= 0)
297 bytes_read += ret;
299 return bytes_read;
302 void Socket::SetSocketError() {
303 socket_error_ = true;
304 // We never use non-blocking socket.
305 DCHECK(errno != EAGAIN && errno != EWOULDBLOCK);
306 Close();
309 int Socket::Read(void* buffer, size_t buffer_size) {
310 if (!WaitForEvent(READ, kNoTimeout)) {
311 SetSocketError();
312 return 0;
314 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
315 if (ret < 0)
316 SetSocketError();
317 return ret;
320 int Socket::Write(const void* buffer, size_t count) {
321 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
322 if (ret < 0)
323 SetSocketError();
324 return ret;
327 int Socket::WriteString(const std::string& buffer) {
328 return WriteNumBytes(buffer.c_str(), buffer.size());
331 void Socket::AddEventFd(int event_fd) {
332 Event event;
333 event.fd = event_fd;
334 event.was_fired = false;
335 events_.push_back(event);
338 bool Socket::DidReceiveEventOnFd(int fd) const {
339 for (size_t i = 0; i < events_.size(); ++i)
340 if (events_[i].fd == fd)
341 return events_[i].was_fired;
342 return false;
345 bool Socket::DidReceiveEvent() const {
346 for (size_t i = 0; i < events_.size(); ++i)
347 if (events_[i].was_fired)
348 return true;
349 return false;
352 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
353 int bytes_written = 0;
354 int ret = 1;
355 while (bytes_written < num_bytes && ret > 0) {
356 ret = Write(static_cast<const char*>(buffer) + bytes_written,
357 num_bytes - bytes_written);
358 if (ret >= 0)
359 bytes_written += ret;
361 return bytes_written;
364 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
365 if (events_.empty() || socket_ == -1)
366 return true;
367 fd_set read_fds;
368 fd_set write_fds;
369 FD_ZERO(&read_fds);
370 FD_ZERO(&write_fds);
371 if (type == READ)
372 FD_SET(socket_, &read_fds);
373 else
374 FD_SET(socket_, &write_fds);
375 for (size_t i = 0; i < events_.size(); ++i)
376 FD_SET(events_[i].fd, &read_fds);
377 timeval tv = {};
378 timeval* tv_ptr = NULL;
379 if (timeout_secs > 0) {
380 tv.tv_sec = timeout_secs;
381 tv.tv_usec = 0;
382 tv_ptr = &tv;
384 int max_fd = socket_;
385 for (size_t i = 0; i < events_.size(); ++i)
386 if (events_[i].fd > max_fd)
387 max_fd = events_[i].fd;
388 if (HANDLE_EINTR(
389 select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
390 return false;
392 bool event_was_fired = false;
393 for (size_t i = 0; i < events_.size(); ++i) {
394 if (FD_ISSET(events_[i].fd, &read_fds)) {
395 events_[i].was_fired = true;
396 event_was_fired = true;
399 return !event_was_fired;
402 // static
403 int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) {
404 return std::max(s1.socket_, s2.socket_);
407 // static
408 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
409 Socket socket;
410 if (!socket.ConnectUnix(path))
411 return -1;
412 ucred ucred;
413 socklen_t len = sizeof(ucred);
414 if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
415 CHECK_NE(ENOPROTOOPT, errno);
416 return -1;
418 return ucred.pid;
421 } // namespace forwarder2