1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "tools/android/forwarder2/socket.h"
10 #include <netinet/in.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
// Timeout value passed to WaitForEvent() by the blocking paths below (Accept,
// Read, Write). WaitForEvent() only sets a select() timeout when the value is
// > 0, so -1 means "wait indefinitely".
24 const int kNoTimeout
= -1;
// Timeout, in seconds, that Connect() passes to WaitForEvent() while waiting
// for a non-blocking connect() to complete.
25 const int kConnectTimeOut
= 10; // Seconds.
// Returns true if |family| is AF_INET or AF_INET6 (the TCP/IP families this
// file handles), false for any other address family.
// NOTE(review): the closing brace of this function is not present in this
// copy of the file (embedded numbering jumps from 28 to 32).
27 bool FamilyIsTCP(int family
) {
28 return family
== AF_INET
|| family
== AF_INET6
;
32 namespace forwarder2
{
// Sets up a listening Unix-domain socket for |path|. InitUnixSocket() builds
// the address and creates the socket, and BindAndListen() binds it and starts
// listening. The condition below takes the failure branch if either step fails.
// NOTE(review): the failure branch, the success return and the closing brace
// are not present in this copy (numbering jumps from 36 to 43). Presumably it
// returns false on failure and true otherwise — confirm.
34 bool Socket::BindUnix(const std::string
& path
) {
36 if (!InitUnixSocket(path
) || !BindAndListen()) {
// Sets up a listening TCP socket for |host|/|port|. InitTcpSocket() fills in
// the address and creates the socket, and BindAndListen() binds it and starts
// listening. BindAndListen() reads back the kernel-chosen port when |port_| is
// 0. The condition below takes the failure branch if either step fails.
// NOTE(review): the failure branch, the success return and the closing brace
// are not present in this copy. Presumably it returns false on failure and
// true otherwise — confirm.
43 bool Socket::BindTcp(const std::string
& host
, int port
) {
45 if (!InitTcpSocket(host
, port
) || !BindAndListen()) {
// Connects to the Unix-domain socket named by |path|. InitUnixSocket() builds
// the address and creates the socket, and Connect() performs the non-blocking
// connect with a timeout. The condition below takes the failure branch if
// either step fails.
// NOTE(review): the failure branch, the success return and the closing brace
// are not present in this copy.
52 bool Socket::ConnectUnix(const std::string
& path
) {
54 if (!InitUnixSocket(path
) || !Connect()) {
// Connects to |host|:|port| over TCP. InitTcpSocket() fills in the address and
// creates the socket, and Connect() performs the non-blocking connect with a
// timeout. The condition below takes the failure branch if either step fails.
// NOTE(review): the failure branch, the success return and the closing brace
// are not present in this copy (numbering jumps from 63 to 75).
61 bool Socket::ConnectTcp(const std::string
& host
, int port
) {
63 if (!InitTcpSocket(host
, port
) || !Connect()) {
// Tail of a member-initializer list plus body. The constructor signature and
// the earlier initializers are not present in this copy, so this is presumably
// Socket::Socket() — confirm.
// Points |addr_ptr_| at the IPv4 member of |addr_| by default, sets
// |addr_len_| to sizeof(sockaddr), and zero-fills the whole |addr_| storage.
// InitUnixSocket() relies on that zero fill for sun_path.
75 addr_ptr_(reinterpret_cast<sockaddr
*>(&addr_
.addr4
)),
76 addr_len_(sizeof(sockaddr
)) {
77 memset(&addr_
, 0, sizeof(addr_
));
// Shuts down both directions of |socket_| with shutdown(SHUT_RDWR).
// PRESERVE_ERRNO_HANDLE_EINTR comes from base/posix/eintr_wrapper.h. As its
// name suggests, it presumably retries on EINTR and leaves errno unchanged —
// confirm.
// NOTE(review): any guard around the call (line 85) and the closing brace are
// not present in this copy.
84 void Socket::Shutdown() {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_
, SHUT_RDWR
));
// Presumably closes |socket_|.
// NOTE(review): only the signature line of this function is present in this
// copy (numbering jumps from 90 to 97). Its body and closing brace are not
// visible here.
90 void Socket::Close() {
// Creates the underlying SOCK_STREAM socket for |family_| and calls
// tools::DisableNagle() on it. It then sets SO_REUSEADDR and switches the
// socket to non-blocking mode via SetNonBlocking(). A socket() failure is
// logged with PLOG(ERROR).
// NOTE(review): several lines are not present in this copy: the socket()
// error check, the |reuse_addr| declaration, the last setsockopt() argument,
// the returns and the closing brace. In the visible code the setsockopt()
// return value is discarded.
97 bool Socket::InitSocketInternal() {
98 socket_
= socket(family_
, SOCK_STREAM
, 0);
100 PLOG(ERROR
) << "socket";
103 tools::DisableNagle(socket_
);
// Allow rebinding an address that is still in TIME_WAIT.
105 setsockopt(socket_
, SOL_SOCKET
, SO_REUSEADDR
, &reuse_addr
,
107 if (!SetNonBlocking())
// Ensures O_NONBLOCK is set on |socket_|. It reads the current file status
// flags with fcntl(F_GETFL). If O_NONBLOCK is already present it (presumably)
// returns early; otherwise it adds O_NONBLOCK with fcntl(F_SETFL). Both
// fcntl() failures are logged with PLOG(ERROR).
// NOTE(review): the error checks, the return statements and the closing brace
// are not present in this copy.
112 bool Socket::SetNonBlocking() {
113 const int flags
= fcntl(socket_
, F_GETFL
);
115 PLOG(ERROR
) << "fcntl";
// Nothing to change if the socket is already non-blocking.
118 if (flags
& O_NONBLOCK
)
120 if (fcntl(socket_
, F_SETFL
, flags
| O_NONBLOCK
) < 0) {
121 PLOG(ERROR
) << "fcntl";
// Prepares |addr_| as an abstract-namespace Unix-domain address built from
// |path|: a leading '\0' followed by the bytes of |path|. It points
// |addr_ptr_| at that address, sets |addr_len_| to the exact used length, and
// creates the socket via InitSocketInternal().
// Paths where path.size() + 2 exceeds sizeof(sun_path) are rejected with an
// error log.
// NOTE(review): the memcpy() below writes path.size() bytes at offset 1, so
// path.size() + 1 <= kPathMax would suffice. The check is one byte stricter
// than needed — conservative, not unsafe.
// NOTE(review): the early return after the error log, the line that sets
// |family_| (presumably to AF_UNIX) and the closing brace are not present in
// this copy.
127 bool Socket::InitUnixSocket(const std::string
& path
) {
128 static const size_t kPathMax
= sizeof(addr_
.addr_un
.sun_path
);
129 // For abstract sockets we need one extra byte for the leading zero.
130 if (path
.size() + 2 /* '\0' */ > kPathMax
) {
131 LOG(ERROR
) << "The provided path is too big to create a unix "
132 << "domain socket: " << path
;
136 addr_
.addr_un
.sun_family
= family_
;
137 // Copied from net/socket/unix_domain_socket_posix.cc
138 // Convert the path given into abstract socket name. It must start with
139 // the '\0' character, so we are adding it. |addr_len_| must specify the
140 // length of the structure exactly, as potentially the socket name may
141 // have '\0' characters embedded (although we don't support this).
142 // Note that addr_.addr_un.sun_path is already zero initialized.
143 memcpy(addr_
.addr_un
.sun_path
+ 1, path
.c_str(), path
.size());
144 addr_len_
= path
.size() + offsetof(struct sockaddr_un
, sun_path
) + 1;
145 addr_ptr_
= reinterpret_cast<sockaddr
*>(&addr_
.addr_un
);
146 return InitSocketInternal();
// Prepares |addr_| for a TCP socket, then creates the socket via
// InitSocketInternal().
// - In one branch (marked by the "Use localhost" comment; its condition is not
//   present in this copy) the IPv4 loopback address is used.
// - Otherwise |host| is resolved with Resolve().
// Next, the CHECK aborts unless |family_| is AF_INET or AF_INET6. |port_| is
// written in network byte order into the matching IPv4/IPv6 structure, and
// |addr_ptr_|/|addr_len_| are pointed at that structure.
// Note that the port is read from the member |port_|, not the |port|
// parameter. The assignment of |port_| is presumably on a line not visible
// here — confirm.
// NOTE(review): the loopback-branch condition, the Resolve() failure branch
// and the closing braces are not present in this copy.
149 bool Socket::InitTcpSocket(const std::string
& host
, int port
) {
152 // Use localhost: INADDR_LOOPBACK
154 addr_
.addr4
.sin_family
= family_
;
155 addr_
.addr4
.sin_addr
.s_addr
= htonl(INADDR_LOOPBACK
);
156 } else if (!Resolve(host
)) {
159 CHECK(FamilyIsTCP(family_
)) << "Invalid socket family.";
160 if (family_
== AF_INET
) {
161 addr_
.addr4
.sin_port
= htons(port_
);
162 addr_ptr_
= reinterpret_cast<sockaddr
*>(&addr_
.addr4
);
163 addr_len_
= sizeof(addr_
.addr4
);
164 } else if (family_
== AF_INET6
) {
165 addr_
.addr6
.sin6_port
= htons(port_
);
166 addr_ptr_
= reinterpret_cast<sockaddr
*>(&addr_
.addr6
);
167 addr_len_
= sizeof(addr_
.addr6
);
169 return InitSocketInternal();
// Binds |socket_| to |addr_ptr_|/|addr_len_| and starts listening with a
// SOMAXCONN backlog. A failure of either call is logged as "bind/listen".
// For TCP sockets with |port_| == 0 (let the kernel pick a port), it then
// calls getsockname() and stores the chosen port, converted from network byte
// order, in |port_|.
// NOTE(review): the declaration of the local |addr|, the error-branch returns
// and the final return are not present in this copy. |port_ptr| is
// dereferenced without a NULL check. That relies on the FamilyIsTCP() guard
// limiting |family_| to the two families handled below.
172 bool Socket::BindAndListen() {
174 if (HANDLE_EINTR(bind(socket_
, addr_ptr_
, addr_len_
)) < 0 ||
175 HANDLE_EINTR(listen(socket_
, SOMAXCONN
)) < 0) {
176 PLOG(ERROR
) << "bind/listen";
// Port 0 was requested: read back the port actually assigned.
180 if (port_
== 0 && FamilyIsTCP(family_
)) {
182 memset(&addr
, 0, sizeof(addr
));
183 socklen_t addrlen
= 0;
184 sockaddr
* addr_ptr
= NULL
;
185 uint16
* port_ptr
= NULL
;
186 if (family_
== AF_INET
) {
187 addr_ptr
= reinterpret_cast<sockaddr
*>(&addr
.addr4
);
188 port_ptr
= &addr
.addr4
.sin_port
;
189 addrlen
= sizeof(addr
.addr4
);
190 } else if (family_
== AF_INET6
) {
191 addr_ptr
= reinterpret_cast<sockaddr
*>(&addr
.addr6
);
192 port_ptr
= &addr
.addr6
.sin6_port
;
193 addrlen
= sizeof(addr
.addr6
);
196 if (getsockname(socket_
, addr_ptr
, &addrlen
) != 0) {
197 PLOG(ERROR
) << "getsockname";
201 port_
= ntohs(*port_ptr
);
// Accepts one pending connection into |new_socket|, which must be non-NULL
// (DCHECK).
// - Waits with WaitForEvent(READ, kNoTimeout) for |socket_| to become
//   readable. WaitForEvent() returns false if a registered event fd fires
//   instead.
// - Calls accept() and disables Nagle on the new fd.
// - Stores the fd in |new_socket->socket_| and switches it to non-blocking
//   mode.
// NOTE(review): the WaitForEvent() failure branch, the accept() error branch
// and the returns are not present in this copy.
206 bool Socket::Accept(Socket
* new_socket
) {
207 DCHECK(new_socket
!= NULL
);
208 if (!WaitForEvent(READ
, kNoTimeout
)) {
213 int new_socket_fd
= HANDLE_EINTR(accept(socket_
, NULL
, NULL
));
214 if (new_socket_fd
< 0) {
218 tools::DisableNagle(new_socket_fd
);
219 new_socket
->socket_
= new_socket_fd
;
220 if (!new_socket
->SetNonBlocking())
// Performs a non-blocking connect() of |socket_| (which must already be
// non-blocking, DCHECK) to |addr_ptr_|/|addr_len_|.
// - Any connect() failure other than EINPROGRESS aborts the attempt.
// - Otherwise it waits up to kConnectTimeOut seconds for the socket to become
//   writable, or for a registered event fd to fire.
// - It then fetches the final status with getsockopt(SO_ERROR) and logs a
//   non-zero result with safe_strerror().
// NOTE(review): the |socket_errno| declaration, the bodies of the error
// branches and the returns are not present in this copy.
// NOTE(review): per POSIX, retrying connect() after EINTR (as HANDLE_EINTR
// does) can fail with EALREADY/EISCONN instead of EINPROGRESS. That would be
// treated as a hard failure here. EINTR is unlikely on a non-blocking socket,
// but worth confirming.
225 bool Socket::Connect() {
226 DCHECK(fcntl(socket_
, F_GETFL
) & O_NONBLOCK
);
228 if (HANDLE_EINTR(connect(socket_
, addr_ptr_
, addr_len_
)) < 0 &&
229 errno
!= EINPROGRESS
) {
233 // Wait for connection to complete, or receive a notification.
234 if (!WaitForEvent(WRITE
, kConnectTimeOut
)) {
// The outcome of an asynchronous connect() is reported through SO_ERROR.
239 socklen_t opt_len
= sizeof(socket_errno
);
240 if (getsockopt(socket_
, SOL_SOCKET
, SO_ERROR
, &socket_errno
, &opt_len
) < 0) {
241 PLOG(ERROR
) << "getsockopt()";
245 if (socket_errno
!= 0) {
246 LOG(ERROR
) << "Could not connect to host: " << safe_strerror(socket_errno
);
// Resolves |host| with getaddrinfo(), using any address family, stream
// sockets and AI_CANONNAME. It sets |family_| from the first result and
// (presumably, the destinations are not visible) copies that result's
// sockaddr_in/sockaddr_in6 into the matching member of |addr_|.
// NOTE(review): the getaddrinfo() error branch, the memcpy() destination
// lines, the freeaddrinfo() call and the returns are not present in this
// copy. |res| is not set when getaddrinfo() fails, so confirm the hidden
// error branch does not pass it to freeaddrinfo().
// NOTE(review): ai_canonname is never read in the visible lines, so
// AI_CANONNAME appears unnecessary.
253 bool Socket::Resolve(const std::string
& host
) {
254 struct addrinfo hints
;
255 struct addrinfo
* res
;
256 memset(&hints
, 0, sizeof(hints
));
257 hints
.ai_family
= AF_UNSPEC
;
258 hints
.ai_socktype
= SOCK_STREAM
;
259 hints
.ai_flags
|= AI_CANONNAME
;
261 int errcode
= getaddrinfo(host
.c_str(), NULL
, &hints
, &res
);
// Only the first returned address is used.
268 family_
= res
->ai_family
;
269 switch (res
->ai_family
) {
272 reinterpret_cast<sockaddr_in
*>(res
->ai_addr
),
273 sizeof(sockaddr_in
));
277 reinterpret_cast<sockaddr_in6
*>(res
->ai_addr
),
278 sizeof(sockaddr_in6
));
// Returns the TCP port of this socket. Logs an error when called on a
// non-TCP (Unix-domain) socket.
// NOTE(review): the return statements are not present in this copy. It
// presumably returns |port_|, and some error value for Unix sockets — confirm.
285 int Socket::GetPort() {
286 if (!FamilyIsTCP(family_
)) {
287 LOG(ERROR
) << "Can't call GetPort() on an unix domain socket.";
// Calls Read() repeatedly until |num_bytes| bytes have been stored into
// |buffer| or Read() returns <= 0.
// NOTE(review): the initialization of |ret|, the accumulation into
// |bytes_read| and the return are not present in this copy. If the size_t
// count is returned through the int return type, counts above INT_MAX would
// not fit.
293 int Socket::ReadNumBytes(void* buffer
, size_t num_bytes
) {
294 size_t bytes_read
= 0;
296 while (bytes_read
< num_bytes
&& ret
> 0) {
297 ret
= Read(static_cast<char*>(buffer
) + bytes_read
, num_bytes
- bytes_read
);
// Marks the socket as failed by setting |socket_error_| to true.
// In DCHECK-enabled builds it also asserts that errno is neither EAGAIN nor
// EWOULDBLOCK. Presumably this ensures a mere "would block" condition is never
// recorded as a socket error.
304 void Socket::SetSocketError() {
305 socket_error_
= true;
306 DCHECK_NE(EAGAIN
, errno
);
307 DCHECK_NE(EWOULDBLOCK
, errno
);
// Blocking read built on the non-blocking socket. It waits with
// WaitForEvent(READ, kNoTimeout) for |socket_| to become readable, then reads
// up to |buffer_size| bytes. HANDLE_EINTR retries on EINTR. Failures are
// logged with PLOG(ERROR).
// NOTE(review): the WaitForEvent() failure branch, the error check around the
// PLOG and the return are not present in this copy.
311 int Socket::Read(void* buffer
, size_t buffer_size
) {
312 if (!WaitForEvent(READ
, kNoTimeout
)) {
316 int ret
= HANDLE_EINTR(read(socket_
, buffer
, buffer_size
));
318 PLOG(ERROR
) << "read";
// Performs a single read() of up to |buffer_size| bytes without waiting in
// WaitForEvent(). The socket must already be non-blocking (DCHECK). Failures
// are logged with PLOG(ERROR).
// NOTE(review): the error check around the PLOG and the return are not present
// in this copy.
324 int Socket::NonBlockingRead(void* buffer
, size_t buffer_size
) {
325 DCHECK(fcntl(socket_
, F_GETFL
) & O_NONBLOCK
);
326 int ret
= HANDLE_EINTR(read(socket_
, buffer
, buffer_size
));
328 PLOG(ERROR
) << "read";
// Blocking write built on the non-blocking socket. It waits with
// WaitForEvent(WRITE, kNoTimeout) for |socket_| to become writable, then
// send()s up to |count| bytes. MSG_NOSIGNAL suppresses SIGPIPE if the peer has
// closed the connection. Failures are logged with PLOG(ERROR).
// NOTE(review): the WaitForEvent() failure branch, the error check around the
// PLOG and the return are not present in this copy.
334 int Socket::Write(const void* buffer
, size_t count
) {
335 if (!WaitForEvent(WRITE
, kNoTimeout
)) {
339 int ret
= HANDLE_EINTR(send(socket_
, buffer
, count
, MSG_NOSIGNAL
));
341 PLOG(ERROR
) << "send";
// Performs a single send() of up to |count| bytes with MSG_NOSIGNAL, without
// waiting in WaitForEvent(). The socket must already be non-blocking (DCHECK).
// Failures are logged with PLOG(ERROR).
// NOTE(review): the error check around the PLOG and the return are not present
// in this copy.
347 int Socket::NonBlockingWrite(const void* buffer
, size_t count
) {
348 DCHECK(fcntl(socket_
, F_GETFL
) & O_NONBLOCK
);
349 int ret
= HANDLE_EINTR(send(socket_
, buffer
, count
, MSG_NOSIGNAL
));
351 PLOG(ERROR
) << "send";
// Writes all buffer.size() bytes of |buffer| (no terminating NUL) through
// WriteNumBytes() and returns its result.
357 int Socket::WriteString(const std::string
& buffer
) {
358 return WriteNumBytes(buffer
.c_str(), buffer
.size());
// Registers an additional fd that WaitForEvent() also watches for readability.
// Its |was_fired| flag starts out false and is appended to |events_|.
// NOTE(review): the declaration of the local |event| and the assignment of
// |event_fd| to it are not present in this copy.
361 void Socket::AddEventFd(int event_fd
) {
364 event
.was_fired
= false;
365 events_
.push_back(event
);
// Returns the |was_fired| flag of the first registered event whose fd equals
// |fd|. Flags are set by WaitForEvent().
// NOTE(review): the fall-through return for an unregistered |fd| and the
// closing brace are not present in this copy.
368 bool Socket::DidReceiveEventOnFd(int fd
) const {
369 for (size_t i
= 0; i
< events_
.size(); ++i
)
370 if (events_
[i
].fd
== fd
)
371 return events_
[i
].was_fired
;
// Reports whether any registered event fd has its |was_fired| flag set.
// NOTE(review): the return statements are not present in this copy. It
// presumably returns true on the first fired event and false otherwise.
375 bool Socket::DidReceiveEvent() const {
376 for (size_t i
= 0; i
< events_
.size(); ++i
)
377 if (events_
[i
].was_fired
)
// Calls Write() repeatedly until |num_bytes| bytes of |buffer| have been
// written or Write() returns <= 0. Returns the number of bytes written.
// NOTE(review): the |ret| declaration and the line just before the
// accumulation (numbering skips 388) are not present in this copy. If that
// line does not guard on ret > 0, a -1 from Write() would be added to the
// size_t |bytes_written| and wrap it.
// NOTE(review): the size_t count is returned through an int, so counts above
// INT_MAX would not fit.
382 int Socket::WriteNumBytes(const void* buffer
, size_t num_bytes
) {
383 size_t bytes_written
= 0;
385 while (bytes_written
< num_bytes
&& ret
> 0) {
386 ret
= Write(static_cast<const char*>(buffer
) + bytes_written
,
387 num_bytes
- bytes_written
);
389 bytes_written
+= ret
;
391 return bytes_written
;
// Waits with select() until one of these happens:
// - |socket_| becomes readable (READ) or writable (WRITE), or
// - one of the fds registered with AddEventFd() becomes readable.
// The socket must be non-blocking (DCHECK). Every event fd found readable gets
// its |was_fired| flag set. Returns false if any event fired, true otherwise.
// On select() failure it logs "select", then takes an early exit that is not
// visible here.
// |tv.tv_sec| is only set when |timeout_secs| > 0. Otherwise |tv_ptr| stays
// NULL and select() has no timeout. So 0 behaves like kNoTimeout, not like a
// poll.
// NOTE(review): a select() timeout (return value 0) also reaches the PLOG
// below, which then reports a possibly stale errno.
// NOTE(review): fd_set can only hold fds below FD_SETSIZE. No bound check is
// visible before the FD_SET() calls.
// NOTE(review): the fd_set/timeval declarations, FD_ZERO, the |type| dispatch,
// the |tv_ptr| assignment and the select() error return are not present in
// this copy.
394 bool Socket::WaitForEvent(EventType type
, int timeout_secs
) {
397 DCHECK(fcntl(socket_
, F_GETFL
) & O_NONBLOCK
);
403 FD_SET(socket_
, &read_fds
);
405 FD_SET(socket_
, &write_fds
);
// Event fds are always watched for readability, whatever |type| is.
406 for (size_t i
= 0; i
< events_
.size(); ++i
)
407 FD_SET(events_
[i
].fd
, &read_fds
);
409 timeval
* tv_ptr
= NULL
;
410 if (timeout_secs
> 0) {
411 tv
.tv_sec
= timeout_secs
;
// select() needs one more than the highest fd in any set.
415 int max_fd
= socket_
;
416 for (size_t i
= 0; i
< events_
.size(); ++i
)
417 if (events_
[i
].fd
> max_fd
)
418 max_fd
= events_
[i
].fd
;
420 select(max_fd
+ 1, &read_fds
, &write_fds
, NULL
, tv_ptr
)) <= 0) {
421 PLOG(ERROR
) << "select";
// Record which event fds fired so DidReceiveEvent*() can report them.
424 bool event_was_fired
= false;
425 for (size_t i
= 0; i
< events_
.size(); ++i
) {
426 if (FD_ISSET(events_
[i
].fd
, &read_fds
)) {
427 events_
[i
].was_fired
= true;
428 event_was_fired
= true;
431 return !event_was_fired
;
// Looks up the process listening on the Unix-domain socket named by |path|. It
// connects a local Socket with ConnectUnix() and reads the peer credentials
// with getsockopt(SO_PEERCRED). A ConnectUnix() or getsockopt() failure leads
// to an early exit whose return value is not visible here. The CHECK aborts if
// SO_PEERCRED is unsupported (ENOPROTOOPT).
// NOTE(review): the declarations of the local |socket| and |ucred| and the
// return statements are not present in this copy. It presumably returns
// ucred.pid on success — confirm.
435 pid_t
Socket::GetUnixDomainSocketProcessOwner(const std::string
& path
) {
437 if (!socket
.ConnectUnix(path
))
440 socklen_t len
= sizeof(ucred
);
441 if (getsockopt(socket
.socket_
, SOL_SOCKET
, SO_PEERCRED
, &ucred
, &len
) == -1) {
442 CHECK_NE(ENOPROTOOPT
, errno
);
448 } // namespace forwarder2