1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "tools/android/forwarder2/socket.h"
10 #include <netinet/in.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
24 const int kNoTimeout
= -1;
25 const int kConnectTimeOut
= 10; // Seconds.
27 bool FamilyIsTCP(int family
) {
28 return family
== AF_INET
|| family
== AF_INET6
;
32 namespace forwarder2
{
34 bool Socket::BindUnix(const std::string
& path
) {
36 if (!InitUnixSocket(path
) || !BindAndListen()) {
43 bool Socket::BindTcp(const std::string
& host
, int port
) {
45 if (!InitTcpSocket(host
, port
) || !BindAndListen()) {
52 bool Socket::ConnectUnix(const std::string
& path
) {
54 if (!InitUnixSocket(path
) || !Connect()) {
61 bool Socket::ConnectTcp(const std::string
& host
, int port
) {
63 if (!InitTcpSocket(host
, port
) || !Connect()) {
75 addr_ptr_(reinterpret_cast<sockaddr
*>(&addr_
.addr4
)),
76 addr_len_(sizeof(sockaddr
)) {
77 memset(&addr_
, 0, sizeof(addr_
));
84 void Socket::Shutdown() {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_
, SHUT_RDWR
));
90 void Socket::Close() {
97 bool Socket::InitSocketInternal() {
98 socket_
= socket(family_
, SOCK_STREAM
, 0);
101 tools::DisableNagle(socket_
);
103 setsockopt(socket_
, SOL_SOCKET
, SO_REUSEADDR
,
104 &reuse_addr
, sizeof(reuse_addr
));
105 tools::DeferAccept(socket_
);
109 bool Socket::InitUnixSocket(const std::string
& path
) {
110 static const size_t kPathMax
= sizeof(addr_
.addr_un
.sun_path
);
111 // For abstract sockets we need one extra byte for the leading zero.
112 if (path
.size() + 2 /* '\0' */ > kPathMax
) {
113 LOG(ERROR
) << "The provided path is too big to create a unix "
114 << "domain socket: " << path
;
118 addr_
.addr_un
.sun_family
= family_
;
119 // Copied from net/socket/unix_domain_socket_posix.cc
120 // Convert the path given into abstract socket name. It must start with
121 // the '\0' character, so we are adding it. |addr_len| must specify the
122 // length of the structure exactly, as potentially the socket name may
123 // have '\0' characters embedded (although we don't support this).
124 // Note that addr_.addr_un.sun_path is already zero initialized.
125 memcpy(addr_
.addr_un
.sun_path
+ 1, path
.c_str(), path
.size());
126 addr_len_
= path
.size() + offsetof(struct sockaddr_un
, sun_path
) + 1;
127 addr_ptr_
= reinterpret_cast<sockaddr
*>(&addr_
.addr_un
);
128 return InitSocketInternal();
131 bool Socket::InitTcpSocket(const std::string
& host
, int port
) {
134 // Use localhost: INADDR_LOOPBACK
136 addr_
.addr4
.sin_family
= family_
;
137 addr_
.addr4
.sin_addr
.s_addr
= htonl(INADDR_LOOPBACK
);
138 } else if (!Resolve(host
)) {
141 CHECK(FamilyIsTCP(family_
)) << "Invalid socket family.";
142 if (family_
== AF_INET
) {
143 addr_
.addr4
.sin_port
= htons(port_
);
144 addr_ptr_
= reinterpret_cast<sockaddr
*>(&addr_
.addr4
);
145 addr_len_
= sizeof(addr_
.addr4
);
146 } else if (family_
== AF_INET6
) {
147 addr_
.addr6
.sin6_port
= htons(port_
);
148 addr_ptr_
= reinterpret_cast<sockaddr
*>(&addr_
.addr6
);
149 addr_len_
= sizeof(addr_
.addr6
);
151 return InitSocketInternal();
154 bool Socket::BindAndListen() {
156 if (HANDLE_EINTR(bind(socket_
, addr_ptr_
, addr_len_
)) < 0 ||
157 HANDLE_EINTR(listen(socket_
, 5)) < 0) {
161 if (port_
== 0 && FamilyIsTCP(family_
)) {
163 memset(&addr
, 0, sizeof(addr
));
164 socklen_t addrlen
= 0;
165 sockaddr
* addr_ptr
= NULL
;
166 uint16
* port_ptr
= NULL
;
167 if (family_
== AF_INET
) {
168 addr_ptr
= reinterpret_cast<sockaddr
*>(&addr
.addr4
);
169 port_ptr
= &addr
.addr4
.sin_port
;
170 addrlen
= sizeof(addr
.addr4
);
171 } else if (family_
== AF_INET6
) {
172 addr_ptr
= reinterpret_cast<sockaddr
*>(&addr
.addr6
);
173 port_ptr
= &addr
.addr6
.sin6_port
;
174 addrlen
= sizeof(addr
.addr6
);
177 if (getsockname(socket_
, addr_ptr
, &addrlen
) != 0) {
178 LOG(ERROR
) << "getsockname error: " << safe_strerror(errno
);;
182 port_
= ntohs(*port_ptr
);
187 bool Socket::Accept(Socket
* new_socket
) {
188 DCHECK(new_socket
!= NULL
);
189 if (!WaitForEvent(READ
, kNoTimeout
)) {
194 int new_socket_fd
= HANDLE_EINTR(accept(socket_
, NULL
, NULL
));
195 if (new_socket_fd
< 0) {
200 tools::DisableNagle(new_socket_fd
);
201 new_socket
->socket_
= new_socket_fd
;
205 bool Socket::Connect() {
206 // Set non-block because we use select for connect.
207 const int kFlags
= fcntl(socket_
, F_GETFL
);
208 DCHECK(!(kFlags
& O_NONBLOCK
));
209 fcntl(socket_
, F_SETFL
, kFlags
| O_NONBLOCK
);
211 if (HANDLE_EINTR(connect(socket_
, addr_ptr_
, addr_len_
)) < 0 &&
212 errno
!= EINPROGRESS
) {
214 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_
, F_SETFL
, kFlags
));
217 // Wait for connection to complete, or receive a notification.
218 if (!WaitForEvent(WRITE
, kConnectTimeOut
)) {
220 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_
, F_SETFL
, kFlags
));
224 socklen_t opt_len
= sizeof(socket_errno
);
225 if (!getsockopt(socket_
, SOL_SOCKET
, SO_ERROR
, &socket_errno
, &opt_len
) < 0) {
226 LOG(ERROR
) << "getsockopt(): " << safe_strerror(errno
);
228 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_
, F_SETFL
, kFlags
));
231 if (socket_errno
!= 0) {
232 LOG(ERROR
) << "Could not connect to host: " << safe_strerror(socket_errno
);
234 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_
, F_SETFL
, kFlags
));
237 fcntl(socket_
, F_SETFL
, kFlags
);
241 bool Socket::Resolve(const std::string
& host
) {
242 struct addrinfo hints
;
243 struct addrinfo
* res
;
244 memset(&hints
, 0, sizeof(hints
));
245 hints
.ai_family
= AF_UNSPEC
;
246 hints
.ai_socktype
= SOCK_STREAM
;
247 hints
.ai_flags
|= AI_CANONNAME
;
249 int errcode
= getaddrinfo(host
.c_str(), NULL
, &hints
, &res
);
254 family_
= res
->ai_family
;
255 switch (res
->ai_family
) {
258 reinterpret_cast<sockaddr_in
*>(res
->ai_addr
),
259 sizeof(sockaddr_in
));
263 reinterpret_cast<sockaddr_in6
*>(res
->ai_addr
),
264 sizeof(sockaddr_in6
));
270 int Socket::GetPort() {
271 if (!FamilyIsTCP(family_
)) {
272 LOG(ERROR
) << "Can't call GetPort() on an unix domain socket.";
278 bool Socket::IsFdInSet(const fd_set
& fds
) const {
281 return FD_ISSET(socket_
, &fds
);
284 bool Socket::AddFdToSet(fd_set
* fds
) const {
287 FD_SET(socket_
, fds
);
291 int Socket::ReadNumBytes(void* buffer
, size_t num_bytes
) {
294 while (bytes_read
< num_bytes
&& ret
> 0) {
295 ret
= Read(static_cast<char*>(buffer
) + bytes_read
, num_bytes
- bytes_read
);
302 void Socket::SetSocketError() {
303 socket_error_
= true;
304 // We never use non-blocking socket.
305 DCHECK(errno
!= EAGAIN
&& errno
!= EWOULDBLOCK
);
309 int Socket::Read(void* buffer
, size_t buffer_size
) {
310 if (!WaitForEvent(READ
, kNoTimeout
)) {
314 int ret
= HANDLE_EINTR(read(socket_
, buffer
, buffer_size
));
320 int Socket::Write(const void* buffer
, size_t count
) {
321 int ret
= HANDLE_EINTR(send(socket_
, buffer
, count
, MSG_NOSIGNAL
));
327 int Socket::WriteString(const std::string
& buffer
) {
328 return WriteNumBytes(buffer
.c_str(), buffer
.size());
331 void Socket::AddEventFd(int event_fd
) {
334 event
.was_fired
= false;
335 events_
.push_back(event
);
338 bool Socket::DidReceiveEventOnFd(int fd
) const {
339 for (size_t i
= 0; i
< events_
.size(); ++i
)
340 if (events_
[i
].fd
== fd
)
341 return events_
[i
].was_fired
;
345 bool Socket::DidReceiveEvent() const {
346 for (size_t i
= 0; i
< events_
.size(); ++i
)
347 if (events_
[i
].was_fired
)
352 int Socket::WriteNumBytes(const void* buffer
, size_t num_bytes
) {
353 int bytes_written
= 0;
355 while (bytes_written
< num_bytes
&& ret
> 0) {
356 ret
= Write(static_cast<const char*>(buffer
) + bytes_written
,
357 num_bytes
- bytes_written
);
359 bytes_written
+= ret
;
361 return bytes_written
;
364 bool Socket::WaitForEvent(EventType type
, int timeout_secs
) {
365 if (events_
.empty() || socket_
== -1)
372 FD_SET(socket_
, &read_fds
);
374 FD_SET(socket_
, &write_fds
);
375 for (size_t i
= 0; i
< events_
.size(); ++i
)
376 FD_SET(events_
[i
].fd
, &read_fds
);
378 timeval
* tv_ptr
= NULL
;
379 if (timeout_secs
> 0) {
380 tv
.tv_sec
= timeout_secs
;
384 int max_fd
= socket_
;
385 for (size_t i
= 0; i
< events_
.size(); ++i
)
386 if (events_
[i
].fd
> max_fd
)
387 max_fd
= events_
[i
].fd
;
389 select(max_fd
+ 1, &read_fds
, &write_fds
, NULL
, tv_ptr
)) <= 0) {
392 bool event_was_fired
= false;
393 for (size_t i
= 0; i
< events_
.size(); ++i
) {
394 if (FD_ISSET(events_
[i
].fd
, &read_fds
)) {
395 events_
[i
].was_fired
= true;
396 event_was_fired
= true;
399 return !event_was_fired
;
403 int Socket::GetHighestFileDescriptor(const Socket
& s1
, const Socket
& s2
) {
404 return std::max(s1
.socket_
, s2
.socket_
);
408 pid_t
Socket::GetUnixDomainSocketProcessOwner(const std::string
& path
) {
410 if (!socket
.ConnectUnix(path
))
413 socklen_t len
= sizeof(ucred
);
414 if (getsockopt(socket
.socket_
, SOL_SOCKET
, SO_PEERCRED
, &ucred
, &len
) == -1) {
415 CHECK_NE(ENOPROTOOPT
, errno
);
421 } // namespace forwarder2