1 // SPDX-License-Identifier: GPL-2.0-only
10 #include <linux/err.h>
12 #include <linux/in6.h>
15 #include "network_helpers.h"
/*
 * clean_errno(): text form of the current errno - the literal "None" when
 * errno is 0, otherwise strerror(errno).
 *
 * log_err(MSG, ...): writes "(FILE:LINE: errno: TEXT) " followed by MSG and
 * a newline to stderr. MSG is pasted between two string literals, so it has
 * to be a string literal itself, and it should not end in its own newline.
 * NOTE(review): the remaining lines of log_err are not visible in this
 * extract - confirm how the variadic arguments are forwarded.
 */
17 #define clean_errno() (errno == 0 ? "None" : strerror(errno))
18 #define log_err(MSG, ...) ({ \
20 fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \
21 __FILE__, __LINE__, clean_errno(), \
/*
 * Global struct ipv4_packet instance. Visible initializers: Ethernet
 * protocol ETH_P_IP and IP total length MAGIC_BYTES (both passed through
 * __bpf_constant_htons()), and IP protocol IPPROTO_TCP.
 * NOTE(review): only these three initializers are visible in this extract;
 * check the full file for further fields.
 */
26 struct ipv4_packet pkt_v4
= {
27 .eth
.h_proto
= __bpf_constant_htons(ETH_P_IP
),
29 .iph
.protocol
= IPPROTO_TCP
,
30 .iph
.tot_len
= __bpf_constant_htons(MAGIC_BYTES
),
/*
 * Global struct ipv6_packet instance. Visible initializers: Ethernet
 * protocol ETH_P_IPV6 and IPv6 payload length MAGIC_BYTES (both passed
 * through __bpf_constant_htons()), and next header IPPROTO_TCP.
 * NOTE(review): only these three initializers are visible in this extract;
 * check the full file for further fields.
 */
35 struct ipv6_packet pkt_v6
= {
36 .eth
.h_proto
= __bpf_constant_htons(ETH_P_IPV6
),
37 .iph
.nexthdr
= IPPROTO_TCP
,
38 .iph
.payload_len
= __bpf_constant_htons(MAGIC_BYTES
),
/*
 * Apply one timeout to both directions of socket fd by setting SO_RCVTIMEO
 * and SO_SNDTIMEO at level SOL_SOCKET. A failure of either setsockopt()
 * call is reported through log_err().
 *
 * The timeval starts out as 3 seconds and can be overwritten from
 * timeout_ms: whole seconds are timeout_ms / 1000 and the remaining
 * milliseconds are converted to microseconds.
 * NOTE(review): the condition guarding that overwrite and the return
 * statements are not visible in this extract - confirm in the full file.
 */
43 static int settimeo(int fd
, int timeout_ms
)
45 struct timeval timeout
= { .tv_sec
= 3 };
48 timeout
.tv_sec
= timeout_ms
/ 1000;
49 timeout
.tv_usec
= (timeout_ms
% 1000) * 1000;
52 if (setsockopt(fd
, SOL_SOCKET
, SO_RCVTIMEO
, &timeout
,
54 log_err("Failed to set SO_RCVTIMEO");
58 if (setsockopt(fd
, SOL_SOCKET
, SO_SNDTIMEO
, &timeout
,
60 log_err("Failed to set SO_SNDTIMEO");
/*
 * Close fd without disturbing errno: errno is saved before close() and
 * restored afterwards, so whatever close() sets cannot replace the error
 * code that was current when the macro was invoked.
 */
67 #define save_errno_close(fd) ({ int __save = errno; close(fd); errno = __save; })
/*
 * Create a server socket. The address is built by make_sockaddr() from
 * family, addr_str and port; the socket is created with
 * socket(family, type, 0), given timeouts via settimeo(), and bound to that
 * address. For SOCK_STREAM sockets listen() is then called with a backlog
 * of 1.
 *
 * Failures of socket(), bind() and listen() are reported through log_err().
 * save_errno_close() closes fd while keeping errno intact.
 * NOTE(review): the return statements and the jumps leading to
 * save_errno_close() are not visible in this extract - presumably it is the
 * shared error path; confirm in the full file.
 */
69 int start_server(int family
, int type
, const char *addr_str
, __u16 port
,
72 struct sockaddr_storage addr
= {};
76 if (make_sockaddr(family
, addr_str
, port
, &addr
, &len
))
79 fd
= socket(family
, type
, 0);
81 log_err("Failed to create server socket");
85 if (settimeo(fd
, timeout_ms
))
88 if (bind(fd
, (const struct sockaddr
*)&addr
, len
) < 0) {
89 log_err("Failed to bind socket");
93 if (type
== SOCK_STREAM
) {
94 if (listen(fd
, 1) < 0) {
95 log_err("Failed to listen on socket");
103 save_errno_close(fd
);
/*
 * Open a client socket towards the address server_fd is bound to and send
 * data (data_len bytes) with sendto() using MSG_FASTOPEN.
 *
 * The server address comes from getsockname(server_fd); the client socket
 * is a SOCK_STREAM socket of the same address family, with timeouts applied
 * by settimeo(). A sendto() result different from data_len is reported
 * through log_err(); that message carries no trailing newline of its own
 * because log_err() already appends one.
 * NOTE(review): the return statements and the jumps leading to
 * save_errno_close() are not visible in this extract - confirm in the full
 * file.
 */
107 int fastopen_connect(int server_fd
, const char *data
, unsigned int data_len
,
110 struct sockaddr_storage addr
;
111 socklen_t addrlen
= sizeof(addr
);
112 struct sockaddr_in
*addr_in
;
115 if (getsockname(server_fd
, (struct sockaddr
*)&addr
, &addrlen
)) {
116 log_err("Failed to get server addr");
120 addr_in
= (struct sockaddr_in
*)&addr
;
121 fd
= socket(addr_in
->sin_family
, SOCK_STREAM
, 0);
123 log_err("Failed to create client socket");
127 if (settimeo(fd
, timeout_ms
))
130 ret
= sendto(fd
, data
, data_len
, MSG_FASTOPEN
, (struct sockaddr
*)&addr
,
132 if (ret
!= data_len
) {
133 log_err("sendto(data, %u) != %d", data_len
, ret
);
140 save_errno_close(fd
);
/*
 * Connect socket fd to *addr (addrlen bytes) with connect(); a failing
 * connect() is reported through log_err().
 * NOTE(review): the remaining parameter list and the return statements are
 * not visible in this extract - confirm in the full file.
 */
144 static int connect_fd_to_addr(int fd
,
145 const struct sockaddr_storage
*addr
,
148 if (connect(fd
, (const struct sockaddr
*)addr
, addrlen
)) {
149 log_err("Failed to connect to server");
/*
 * Create a client socket matching server_fd and connect it to the address
 * server_fd is bound to.
 *
 * The socket type is read from server_fd with getsockopt(SOL_SOCKET,
 * SO_TYPE) and the address with getsockname(); the new socket uses the same
 * address family and type. Timeouts are applied with settimeo() and the
 * connection is made through connect_fd_to_addr(). Failures of the lookups
 * and of socket() are reported through log_err().
 * NOTE(review): the return statements and the jumps leading to
 * save_errno_close() are not visible in this extract - confirm in the full
 * file.
 */
156 int connect_to_fd(int server_fd
, int timeout_ms
)
158 struct sockaddr_storage addr
;
159 struct sockaddr_in
*addr_in
;
160 socklen_t addrlen
, optlen
;
163 optlen
= sizeof(type
);
164 if (getsockopt(server_fd
, SOL_SOCKET
, SO_TYPE
, &type
, &optlen
)) {
165 log_err("getsockopt(SO_TYPE)");
169 addrlen
= sizeof(addr
);
170 if (getsockname(server_fd
, (struct sockaddr
*)&addr
, &addrlen
)) {
171 log_err("Failed to get server addr");
175 addr_in
= (struct sockaddr_in
*)&addr
;
176 fd
= socket(addr_in
->sin_family
, type
, 0);
178 log_err("Failed to create client socket");
182 if (settimeo(fd
, timeout_ms
))
185 if (connect_fd_to_addr(fd
, &addr
, addrlen
))
191 save_errno_close(fd
);
/*
 * Connect an existing socket client_fd to the address server_fd is bound
 * to: timeouts are applied to client_fd with settimeo(), the server address
 * is read with getsockname(server_fd), and the connection is made through
 * connect_fd_to_addr(). A getsockname() failure is reported through
 * log_err().
 * NOTE(review): the return statements are not visible in this extract -
 * confirm in the full file.
 */
195 int connect_fd_to_fd(int client_fd
, int server_fd
, int timeout_ms
)
197 struct sockaddr_storage addr
;
198 socklen_t len
= sizeof(addr
);
200 if (settimeo(client_fd
, timeout_ms
))
203 if (getsockname(server_fd
, (struct sockaddr
*)&addr
, &len
)) {
204 log_err("Failed to get server addr");
208 if (connect_fd_to_addr(client_fd
, &addr
, len
))
214 int make_sockaddr(int family
, const char *addr_str
, __u16 port
,
215 struct sockaddr_storage
*addr
, socklen_t
*len
)
217 if (family
== AF_INET
) {
218 struct sockaddr_in
*sin
= (void *)addr
;
220 sin
->sin_family
= AF_INET
;
221 sin
->sin_port
= htons(port
);
223 inet_pton(AF_INET
, addr_str
, &sin
->sin_addr
) != 1) {
224 log_err("inet_pton(AF_INET, %s)", addr_str
);
230 } else if (family
== AF_INET6
) {
231 struct sockaddr_in6
*sin6
= (void *)addr
;
233 sin6
->sin6_family
= AF_INET6
;
234 sin6
->sin6_port
= htons(port
);
236 inet_pton(AF_INET6
, addr_str
, &sin6
->sin6_addr
) != 1) {
237 log_err("inet_pton(AF_INET6, %s)", addr_str
);
241 *len
= sizeof(*sin6
);