//===-- llvm/Support/raw_socket_stream.cpp - Socket streams --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains raw_ostream implementations for streams to communicate
// over UNIX domain sockets.
//
//===----------------------------------------------------------------------===//
#include "llvm/Support/raw_socket_stream.h"
#include "llvm/Config/config.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"

#include <atomic>
#include <chrono>
#include <cstring>
#include <functional>
#include <optional>
#include <string>

#ifndef _WIN32
#include <poll.h>
#include <sys/socket.h>
#include <sys/un.h>
#else
#include "llvm/Support/Windows/WindowsSupport.h"
// winsock2.h must be included before afunix.h. Briefly turn off clang-format to
// avoid reordering them.
// clang-format off
#include <winsock2.h>
#include <afunix.h>
// clang-format on
#include <io.h>
#endif // _WIN32

#if defined(HAVE_UNISTD_H)
#include <unistd.h>
#endif

using namespace llvm;
#ifdef _WIN32
// Each WSABalancer performs exactly one WSAStartup in its constructor and the
// matching WSACleanup in its destructor, keeping Winsock start/cleanup calls
// balanced for as long as the object is alive.
WSABalancer::WSABalancer() {
  WSADATA StartupInfo;
  ::memset(&StartupInfo, 0, sizeof(StartupInfo));
  // Request Winsock 2.2; a failure to start Winsock is treated as fatal.
  if (WSAStartup(MAKEWORD(2, 2), &StartupInfo) != 0)
    llvm::report_fatal_error("WSAStartup failed");
}

WSABalancer::~WSABalancer() { WSACleanup(); }
#endif // _WIN32
56 static std::error_code
getLastSocketErrorCode() {
58 return std::error_code(::WSAGetLastError(), std::system_category());
60 return errnoAsErrorCode();
64 static sockaddr_un
setSocketAddr(StringRef SocketPath
) {
65 struct sockaddr_un Addr
;
66 memset(&Addr
, 0, sizeof(Addr
));
67 Addr
.sun_family
= AF_UNIX
;
68 strncpy(Addr
.sun_path
, SocketPath
.str().c_str(), sizeof(Addr
.sun_path
) - 1);
72 static Expected
<int> getSocketFD(StringRef SocketPath
) {
74 SOCKET Socket
= socket(AF_UNIX
, SOCK_STREAM
, 0);
75 if (Socket
== INVALID_SOCKET
) {
77 int Socket
= socket(AF_UNIX
, SOCK_STREAM
, 0);
80 return llvm::make_error
<StringError
>(getLastSocketErrorCode(),
81 "Create socket failed");
84 struct sockaddr_un Addr
= setSocketAddr(SocketPath
);
85 if (::connect(Socket
, (struct sockaddr
*)&Addr
, sizeof(Addr
)) == -1)
86 return llvm::make_error
<StringError
>(getLastSocketErrorCode(),
87 "Connect socket failed");
90 return _open_osfhandle(Socket
, 0);
96 ListeningSocket::ListeningSocket(int SocketFD
, StringRef SocketPath
,
98 : FD(SocketFD
), SocketPath(SocketPath
), PipeFD
{PipeFD
[0], PipeFD
[1]} {}
100 ListeningSocket::ListeningSocket(ListeningSocket
&&LS
)
101 : FD(LS
.FD
.load()), SocketPath(LS
.SocketPath
),
102 PipeFD
{LS
.PipeFD
[0], LS
.PipeFD
[1]} {
105 LS
.SocketPath
.clear();
110 Expected
<ListeningSocket
> ListeningSocket::createUnix(StringRef SocketPath
,
113 // Handle instances where the target socket address already exists and
114 // differentiate between a preexisting file with and without a bound socket
116 // ::bind will return std::errc:address_in_use if a file at the socket address
117 // already exists (e.g., the file was not properly unlinked due to a crash)
118 // even if another socket has not yet binded to that address
119 if (llvm::sys::fs::exists(SocketPath
)) {
120 Expected
<int> MaybeFD
= getSocketFD(SocketPath
);
123 // Regardless of the error, notify the caller that a file already exists
124 // at the desired socket address and that there is no bound socket at that
125 // address. The file must be removed before ::bind can use the address
126 consumeError(MaybeFD
.takeError());
127 return llvm::make_error
<StringError
>(
128 std::make_error_code(std::errc::file_exists
),
129 "Socket address unavailable");
131 ::close(std::move(*MaybeFD
));
133 // Notify caller that the provided socket address already has a bound socket
134 return llvm::make_error
<StringError
>(
135 std::make_error_code(std::errc::address_in_use
),
136 "Socket address unavailable");
141 SOCKET Socket
= socket(AF_UNIX
, SOCK_STREAM
, 0);
142 if (Socket
== INVALID_SOCKET
)
144 int Socket
= socket(AF_UNIX
, SOCK_STREAM
, 0);
147 return llvm::make_error
<StringError
>(getLastSocketErrorCode(),
148 "socket create failed");
150 struct sockaddr_un Addr
= setSocketAddr(SocketPath
);
151 if (::bind(Socket
, (struct sockaddr
*)&Addr
, sizeof(Addr
)) == -1) {
152 // Grab error code from call to ::bind before calling ::close
153 std::error_code EC
= getLastSocketErrorCode();
155 return llvm::make_error
<StringError
>(EC
, "Bind error");
158 // Mark socket as passive so incoming connections can be accepted
159 if (::listen(Socket
, MaxBacklog
) == -1)
160 return llvm::make_error
<StringError
>(getLastSocketErrorCode(),
165 // Reserve 1 byte for the pipe and use default textmode
166 if (::_pipe(PipeFD
, 1, 0) == -1)
168 if (::pipe(PipeFD
) == -1)
170 return llvm::make_error
<StringError
>(getLastSocketErrorCode(),
174 return ListeningSocket
{_open_osfhandle(Socket
, 0), SocketPath
, PipeFD
};
176 return ListeningSocket
{Socket
, SocketPath
, PipeFD
};
180 // If a file descriptor being monitored by ::poll is closed by another thread,
181 // the result is unspecified. In the case ::poll does not unblock and return,
182 // when ActiveFD is closed, you can provide another file descriptor via CancelFD
183 // that when written to will cause poll to return. Typically CancelFD is the
184 // read end of a unidirectional pipe.
186 // Timeout should be -1 to block indefinitly
188 // getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
189 static std::error_code
190 manageTimeout(const std::chrono::milliseconds
&Timeout
,
191 const std::function
<int()> &getActiveFD
,
192 const std::optional
<int> &CancelFD
= std::nullopt
) {
194 FD
[0].events
= POLLIN
;
196 SOCKET WinServerSock
= _get_osfhandle(getActiveFD());
197 FD
[0].fd
= WinServerSock
;
199 FD
[0].fd
= getActiveFD();
202 if (CancelFD
.has_value()) {
203 FD
[1].events
= POLLIN
;
204 FD
[1].fd
= CancelFD
.value();
208 // Keep track of how much time has passed in case ::poll or WSAPoll are
209 // interupted by a signal and need to be recalled
210 auto Start
= std::chrono::steady_clock::now();
211 auto RemainingTimeout
= Timeout
;
214 // If Timeout is -1 then poll should block and RemainingTimeout does not
215 // need to be recalculated
216 if (PollStatus
!= 0 && Timeout
!= std::chrono::milliseconds(-1)) {
217 auto TotalElapsedTime
=
218 std::chrono::duration_cast
<std::chrono::milliseconds
>(
219 std::chrono::steady_clock::now() - Start
);
221 if (TotalElapsedTime
>= Timeout
)
222 return std::make_error_code(std::errc::operation_would_block
);
224 RemainingTimeout
= Timeout
- TotalElapsedTime
;
227 PollStatus
= WSAPoll(FD
, FDCount
, RemainingTimeout
.count());
228 } while (PollStatus
== SOCKET_ERROR
&&
229 getLastSocketErrorCode() == std::errc::interrupted
);
231 PollStatus
= ::poll(FD
, FDCount
, RemainingTimeout
.count());
232 } while (PollStatus
== -1 &&
233 getLastSocketErrorCode() == std::errc::interrupted
);
236 // If ActiveFD equals -1 or CancelFD has data to be read then the operation
237 // has been canceled by another thread
238 if (getActiveFD() == -1 || (CancelFD
.has_value() && FD
[1].revents
& POLLIN
))
239 return std::make_error_code(std::errc::operation_canceled
);
241 if (PollStatus
== SOCKET_ERROR
)
243 if (PollStatus
== -1)
245 return getLastSocketErrorCode();
247 return std::make_error_code(std::errc::timed_out
);
248 if (FD
[0].revents
& POLLNVAL
)
249 return std::make_error_code(std::errc::bad_file_descriptor
);
250 return std::error_code();
253 Expected
<std::unique_ptr
<raw_socket_stream
>>
254 ListeningSocket::accept(const std::chrono::milliseconds
&Timeout
) {
255 auto getActiveFD
= [this]() -> int { return FD
; };
256 std::error_code TimeoutErr
= manageTimeout(Timeout
, getActiveFD
, PipeFD
[0]);
258 return llvm::make_error
<StringError
>(TimeoutErr
, "Timeout error");
262 SOCKET WinAcceptSock
= ::accept(_get_osfhandle(FD
), NULL
, NULL
);
263 AcceptFD
= _open_osfhandle(WinAcceptSock
, 0);
265 AcceptFD
= ::accept(FD
, NULL
, NULL
);
269 return llvm::make_error
<StringError
>(getLastSocketErrorCode(),
270 "Socket accept failed");
271 return std::make_unique
<raw_socket_stream
>(AcceptFD
);
274 void ListeningSocket::shutdown() {
275 int ObservedFD
= FD
.load();
277 if (ObservedFD
== -1)
280 // If FD equals ObservedFD set FD to -1; If FD doesn't equal ObservedFD then
281 // another thread is responsible for shutdown so return
282 if (!FD
.compare_exchange_strong(ObservedFD
, -1))
286 ::unlink(SocketPath
.c_str());
288 // Ensure ::poll returns if shutdown is called by a separate thread
290 ssize_t written
= ::write(PipeFD
[1], &Byte
, 1);
292 // Ignore any write() error
296 ListeningSocket::~ListeningSocket() {
299 // Close the pipe's FDs in the destructor instead of within
300 // ListeningSocket::shutdown to avoid unnecessary synchronization issues that
301 // would occur as PipeFD's values would have to be changed to -1
303 // The move constructor sets PipeFD to -1
//===----------------------------------------------------------------------===//
//  raw_socket_stream
//===----------------------------------------------------------------------===//
314 raw_socket_stream::raw_socket_stream(int SocketFD
)
315 : raw_fd_stream(SocketFD
, true) {}
317 raw_socket_stream::~raw_socket_stream() {}
319 Expected
<std::unique_ptr
<raw_socket_stream
>>
320 raw_socket_stream::createConnectedUnix(StringRef SocketPath
) {
324 Expected
<int> FD
= getSocketFD(SocketPath
);
326 return FD
.takeError();
327 return std::make_unique
<raw_socket_stream
>(*FD
);
330 ssize_t
raw_socket_stream::read(char *Ptr
, size_t Size
,
331 const std::chrono::milliseconds
&Timeout
) {
332 auto getActiveFD
= [this]() -> int { return this->get_fd(); };
333 std::error_code Err
= manageTimeout(Timeout
, getActiveFD
);
334 // Mimic raw_fd_stream::read error handling behavior
336 raw_fd_stream::error_detected(Err
);
339 return raw_fd_stream::read(Ptr
, Size
);