1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook */
#include <stdlib.h>
#include <sys/socket.h>
#include <netinet/in.h>

#include <bpf/libbpf.h>

#include "cgroup_helpers.h"
#include "bpf_rlimit.h"
/*
 * Keys of addr_map: the slots under which the test stores the server and
 * client socket addresses (see addr_srv_idx / addr_cli_idx).
 *
 * NOTE(review): the numeric values are shared with the BPF object that
 * owns addr_map, so the enumerator order must match its definition --
 * confirm against the kernel-side program.
 */
enum bpf_addr_array_idx {
	ADDR_SRV_IDX,
	ADDR_CLI_IDX,
	__NR_BPF_ADDR_ARRAY_IDX,
};
/*
 * Keys of the socket result maps (sk_map_fd / tp_map_fd): one slot each for
 * the server socket seen on egress, the client socket seen on egress and
 * the listening socket seen on ingress.
 *
 * NOTE(review): the numeric values are shared with the BPF object that
 * fills these maps, so the enumerator order must match its definition --
 * confirm against the kernel-side program.
 */
enum bpf_result_array_idx {
	EGRESS_SRV_IDX,
	EGRESS_CLI_IDX,
	INGRESS_LISTEN_IDX,
	__NR_BPF_RESULT_ARRAY_IDX,
};
/*
 * Keys of linum_map: the slots holding the line numbers that check_result()
 * reports as egress_linum / ingress_linum in its failure messages.
 *
 * NOTE(review): the numeric values are shared with the BPF object that
 * owns linum_map, so the enumerator order must match its definition --
 * confirm against the kernel-side program.
 */
enum bpf_linum_array_idx {
	EGRESS_LINUM_IDX,
	INGRESS_LINUM_IDX,
	__NR_BPF_LINUM_ARRAY_IDX,
};
38 struct bpf_spinlock_cnt
{
39 struct bpf_spin_lock lock
;
/*
 * CHECK(condition, tag, format...) - fail the test when @condition is true.
 *
 * @condition is evaluated exactly once.  When it is non-zero the macro
 * prints "<function>(<line>):FAIL:<tag> " followed by the printf-style
 * @format message and a newline, then terminates the process.  When it is
 * zero nothing is printed and execution continues.
 *
 * Callers in this file use values right after checking them (for example
 * a map pointer after CHECK(!map, ...)), so a failing CHECK must not
 * return.
 *
 * NOTE(review): exit status -1 is assumed; confirm it is the status the
 * test harness expects.
 */
#define CHECK(condition, tag, format...) ({				\
	int __ret = !!(condition);					\
	if (__ret) {							\
		printf("%s(%d):FAIL:%s ", __func__, __LINE__, tag);	\
		printf(format);						\
		printf("\n");						\
		exit(-1);						\
	}								\
})
53 #define TEST_CGROUP "/test-bpf-sock-fields"
54 #define DATA "Hello BPF!"
55 #define DATA_LEN sizeof(DATA)
57 static struct sockaddr_in6 srv_sa6
, cli_sa6
;
58 static int sk_pkt_out_cnt10_fd
;
59 static int sk_pkt_out_cnt_fd
;
60 static int linum_map_fd
;
61 static int addr_map_fd
;
65 static __u32 addr_srv_idx
= ADDR_SRV_IDX
;
66 static __u32 addr_cli_idx
= ADDR_CLI_IDX
;
68 static __u32 egress_srv_idx
= EGRESS_SRV_IDX
;
69 static __u32 egress_cli_idx
= EGRESS_CLI_IDX
;
70 static __u32 ingress_listen_idx
= INGRESS_LISTEN_IDX
;
72 static __u32 egress_linum_idx
= EGRESS_LINUM_IDX
;
73 static __u32 ingress_linum_idx
= INGRESS_LINUM_IDX
;
/*
 * init_loopback6() - reset @sa6 to the IPv6 loopback address.
 * @sa6: address structure to initialise; must not be NULL
 *
 * Every byte is cleared first, so the port, flow info and scope id are
 * all zero on return; only the family (AF_INET6) and the address (::1)
 * are set.
 */
static void init_loopback6(struct sockaddr_in6 *sa6)
{
	memset(sa6, 0, sizeof(*sa6));
	sa6->sin6_family = AF_INET6;
	sa6->sin6_addr = in6addr_loopback;
}
82 static void print_sk(const struct bpf_sock
*sk
)
84 char src_ip4
[24], dst_ip4
[24];
85 char src_ip6
[64], dst_ip6
[64];
87 inet_ntop(AF_INET
, &sk
->src_ip4
, src_ip4
, sizeof(src_ip4
));
88 inet_ntop(AF_INET6
, &sk
->src_ip6
, src_ip6
, sizeof(src_ip6
));
89 inet_ntop(AF_INET
, &sk
->dst_ip4
, dst_ip4
, sizeof(dst_ip4
));
90 inet_ntop(AF_INET6
, &sk
->dst_ip6
, dst_ip6
, sizeof(dst_ip6
));
92 printf("state:%u bound_dev_if:%u family:%u type:%u protocol:%u mark:%u priority:%u "
93 "src_ip4:%x(%s) src_ip6:%x:%x:%x:%x(%s) src_port:%u "
94 "dst_ip4:%x(%s) dst_ip6:%x:%x:%x:%x(%s) dst_port:%u\n",
95 sk
->state
, sk
->bound_dev_if
, sk
->family
, sk
->type
, sk
->protocol
,
96 sk
->mark
, sk
->priority
,
98 sk
->src_ip6
[0], sk
->src_ip6
[1], sk
->src_ip6
[2], sk
->src_ip6
[3],
99 src_ip6
, sk
->src_port
,
100 sk
->dst_ip4
, dst_ip4
,
101 sk
->dst_ip6
[0], sk
->dst_ip6
[1], sk
->dst_ip6
[2], sk
->dst_ip6
[3],
102 dst_ip6
, ntohs(sk
->dst_port
));
105 static void print_tp(const struct bpf_tcp_sock
*tp
)
107 printf("snd_cwnd:%u srtt_us:%u rtt_min:%u snd_ssthresh:%u rcv_nxt:%u "
108 "snd_nxt:%u snd:una:%u mss_cache:%u ecn_flags:%u "
109 "rate_delivered:%u rate_interval_us:%u packets_out:%u "
110 "retrans_out:%u total_retrans:%u segs_in:%u data_segs_in:%u "
111 "segs_out:%u data_segs_out:%u lost_out:%u sacked_out:%u "
112 "bytes_received:%llu bytes_acked:%llu\n",
113 tp
->snd_cwnd
, tp
->srtt_us
, tp
->rtt_min
, tp
->snd_ssthresh
,
114 tp
->rcv_nxt
, tp
->snd_nxt
, tp
->snd_una
, tp
->mss_cache
,
115 tp
->ecn_flags
, tp
->rate_delivered
, tp
->rate_interval_us
,
116 tp
->packets_out
, tp
->retrans_out
, tp
->total_retrans
,
117 tp
->segs_in
, tp
->data_segs_in
, tp
->segs_out
,
118 tp
->data_segs_out
, tp
->lost_out
, tp
->sacked_out
,
119 tp
->bytes_received
, tp
->bytes_acked
);
/*
 * check_result() - read back the recorded socket snapshots and validate them.
 *
 * Looks up the two line-number slots in linum_map_fd, then the bpf_sock
 * (sk_map_fd) and bpf_tcp_sock (tp_map_fd) snapshots stored under the
 * egress-server, egress-client and ingress-listen keys.  The snapshots are
 * presumably written by the attached cgroup_skb programs.  Each CHECK lists
 * the conditions that make the corresponding snapshot unexpected.
 *
 * NOTE(review): some pieces this function refers to are not visible in this
 * copy (the declaration of err, the value argument of the linum lookups and
 * the trailing arguments of the "Unexpected ..." checks) -- verify against
 * the complete file.
 */
122 static void check_result(void)
124 struct bpf_tcp_sock srv_tp
, cli_tp
, listen_tp
;
125 struct bpf_sock srv_sk
, cli_sk
, listen_sk
;
126 __u32 ingress_linum
, egress_linum
;
/* Line numbers that the failure messages below report. */
129 err
= bpf_map_lookup_elem(linum_map_fd
, &egress_linum_idx
,
131 CHECK(err
== -1, "bpf_map_lookup_elem(linum_map_fd)",
132 "err:%d errno:%d", err
, errno
);
134 err
= bpf_map_lookup_elem(linum_map_fd
, &ingress_linum_idx
,
136 CHECK(err
== -1, "bpf_map_lookup_elem(linum_map_fd)",
137 "err:%d errno:%d", err
, errno
);
/* Snapshots stored under the egress/server key. */
139 err
= bpf_map_lookup_elem(sk_map_fd
, &egress_srv_idx
, &srv_sk
);
140 CHECK(err
== -1, "bpf_map_lookup_elem(sk_map_fd, &egress_srv_idx)",
141 "err:%d errno:%d", err
, errno
);
142 err
= bpf_map_lookup_elem(tp_map_fd
, &egress_srv_idx
, &srv_tp
);
143 CHECK(err
== -1, "bpf_map_lookup_elem(tp_map_fd, &egress_srv_idx)",
144 "err:%d errno:%d", err
, errno
);
/* Snapshots stored under the egress/client key. */
146 err
= bpf_map_lookup_elem(sk_map_fd
, &egress_cli_idx
, &cli_sk
);
147 CHECK(err
== -1, "bpf_map_lookup_elem(sk_map_fd, &egress_cli_idx)",
148 "err:%d errno:%d", err
, errno
);
149 err
= bpf_map_lookup_elem(tp_map_fd
, &egress_cli_idx
, &cli_tp
);
150 CHECK(err
== -1, "bpf_map_lookup_elem(tp_map_fd, &egress_cli_idx)",
151 "err:%d errno:%d", err
, errno
);
/* Snapshots stored under the ingress/listen key. */
153 err
= bpf_map_lookup_elem(sk_map_fd
, &ingress_listen_idx
, &listen_sk
);
154 CHECK(err
== -1, "bpf_map_lookup_elem(sk_map_fd, &ingress_listen_idx)",
155 "err:%d errno:%d", err
, errno
);
156 err
= bpf_map_lookup_elem(tp_map_fd
, &ingress_listen_idx
, &listen_tp
);
157 CHECK(err
== -1, "bpf_map_lookup_elem(tp_map_fd, &ingress_listen_idx)",
158 "err:%d errno:%d", err
, errno
);
/* Dump the listener snapshots so a failure below can be diagnosed. */
160 printf("listen_sk: ");
161 print_sk(&listen_sk
);
172 printf("listen_tp: ");
173 print_tp(&listen_tp
);
/*
 * Listener: must be in state 10 (presumably TCP_LISTEN -- confirm), be an
 * AF_INET6/TCP socket bound to ::1 on the server port, and have an all-zero
 * destination address.
 * NOTE(review): the condition ends in "||" with no final operand visible.
 */
184 CHECK(listen_sk
.state
!= 10 ||
185 listen_sk
.family
!= AF_INET6
||
186 listen_sk
.protocol
!= IPPROTO_TCP
||
187 memcmp(listen_sk
.src_ip6
, &in6addr_loopback
,
188 sizeof(listen_sk
.src_ip6
)) ||
189 listen_sk
.dst_ip6
[0] || listen_sk
.dst_ip6
[1] ||
190 listen_sk
.dst_ip6
[2] || listen_sk
.dst_ip6
[3] ||
191 listen_sk
.src_port
!= ntohs(srv_sa6
.sin6_port
) ||
193 "Unexpected listen_sk",
194 "Check listen_sk output. ingress_linum:%u",
/*
 * Server socket: must not be in state 10; both addresses are ::1.  src_port
 * is compared with the server port converted by ntohs(), while dst_port is
 * compared with the client's sin6_port as stored (no byte swap).
 */
197 CHECK(srv_sk
.state
== 10 ||
199 srv_sk
.family
!= AF_INET6
||
200 srv_sk
.protocol
!= IPPROTO_TCP
||
201 memcmp(srv_sk
.src_ip6
, &in6addr_loopback
,
202 sizeof(srv_sk
.src_ip6
)) ||
203 memcmp(srv_sk
.dst_ip6
, &in6addr_loopback
,
204 sizeof(srv_sk
.dst_ip6
)) ||
205 srv_sk
.src_port
!= ntohs(srv_sa6
.sin6_port
) ||
206 srv_sk
.dst_port
!= cli_sa6
.sin6_port
,
207 "Unexpected srv_sk", "Check srv_sk output. egress_linum:%u",
/* Client socket: mirror image of the server check, with the ports swapped. */
210 CHECK(cli_sk
.state
== 10 ||
212 cli_sk
.family
!= AF_INET6
||
213 cli_sk
.protocol
!= IPPROTO_TCP
||
214 memcmp(cli_sk
.src_ip6
, &in6addr_loopback
,
215 sizeof(cli_sk
.src_ip6
)) ||
216 memcmp(cli_sk
.dst_ip6
, &in6addr_loopback
,
217 sizeof(cli_sk
.dst_ip6
)) ||
218 cli_sk
.src_port
!= ntohs(cli_sa6
.sin6_port
) ||
219 cli_sk
.dst_port
!= srv_sa6
.sin6_port
,
220 "Unexpected cli_sk", "Check cli_sk output. egress_linum:%u",
/* Listener TCP state: every data/retransmit/ack counter must be zero. */
223 CHECK(listen_tp
.data_segs_out
||
224 listen_tp
.data_segs_in
||
225 listen_tp
.total_retrans
||
226 listen_tp
.bytes_acked
,
227 "Unexpected listen_tp", "Check listen_tp output. ingress_linum:%u",
/*
 * Server TCP state: exactly two data segments sent and acknowledged
 * (2 * DATA_LEN bytes), nothing received, snd_cwnd of 10, no retransmits.
 */
230 CHECK(srv_tp
.data_segs_out
!= 2 ||
231 srv_tp
.data_segs_in
||
232 srv_tp
.snd_cwnd
!= 10 ||
233 srv_tp
.total_retrans
||
234 srv_tp
.bytes_acked
!= 2 * DATA_LEN
,
235 "Unexpected srv_tp", "Check srv_tp output. egress_linum:%u",
/*
 * Client TCP state: exactly two data segments received (2 * DATA_LEN
 * bytes), nothing sent, snd_cwnd of 10, no retransmits.
 */
238 CHECK(cli_tp
.data_segs_out
||
239 cli_tp
.data_segs_in
!= 2 ||
240 cli_tp
.snd_cwnd
!= 10 ||
241 cli_tp
.total_retrans
||
242 cli_tp
.bytes_received
!= 2 * DATA_LEN
,
243 "Unexpected cli_tp", "Check cli_tp output. egress_linum:%u",
247 static void check_sk_pkt_out_cnt(int accept_fd
, int cli_fd
)
249 struct bpf_spinlock_cnt pkt_out_cnt
= {}, pkt_out_cnt10
= {};
252 pkt_out_cnt
.cnt
= ~0;
253 pkt_out_cnt10
.cnt
= ~0;
254 err
= bpf_map_lookup_elem(sk_pkt_out_cnt_fd
, &accept_fd
, &pkt_out_cnt
);
256 err
= bpf_map_lookup_elem(sk_pkt_out_cnt10_fd
, &accept_fd
,
259 /* The bpf prog only counts for fullsock and
260 * passive conneciton did not become fullsock until 3WHS
262 * The bpf prog only counted two data packet out but we
263 * specially init accept_fd's pkt_out_cnt by 2 in
264 * init_sk_storage(). Hence, 4 here.
266 CHECK(err
|| pkt_out_cnt
.cnt
!= 4 || pkt_out_cnt10
.cnt
!= 40,
267 "bpf_map_lookup_elem(sk_pkt_out_cnt, &accept_fd)",
268 "err:%d errno:%d pkt_out_cnt:%u pkt_out_cnt10:%u",
269 err
, errno
, pkt_out_cnt
.cnt
, pkt_out_cnt10
.cnt
);
271 pkt_out_cnt
.cnt
= ~0;
272 pkt_out_cnt10
.cnt
= ~0;
273 err
= bpf_map_lookup_elem(sk_pkt_out_cnt_fd
, &cli_fd
, &pkt_out_cnt
);
275 err
= bpf_map_lookup_elem(sk_pkt_out_cnt10_fd
, &cli_fd
,
277 /* Active connection is fullsock from the beginning.
278 * 1 SYN and 1 ACK during 3WHS
279 * 2 Acks on data packet.
281 * The bpf_prog initialized it to 0xeB9F.
283 CHECK(err
|| pkt_out_cnt
.cnt
!= 0xeB9F + 4 ||
284 pkt_out_cnt10
.cnt
!= 0xeB9F + 40,
285 "bpf_map_lookup_elem(sk_pkt_out_cnt, &cli_fd)",
286 "err:%d errno:%d pkt_out_cnt:%u pkt_out_cnt10:%u",
287 err
, errno
, pkt_out_cnt
.cnt
, pkt_out_cnt10
.cnt
);
/*
 * init_sk_storage() - seed the packet counters stored for one socket.
 * @sk_fd: socket whose entries are written; its fd is used as the map key
 * @pkt_out_cnt: initial counter value
 *
 * Writes a bpf_spinlock_cnt value into sk_pkt_out_cnt_fd and then into
 * sk_pkt_out_cnt10_fd, CHECKing each update for failure.
 *
 * NOTE(review): the last argument (update flags) of both
 * bpf_map_update_elem() calls and the declaration of err are not visible
 * in this copy.  check_sk_pkt_out_cnt() expects the cnt10 entry to end up
 * ten times the cnt entry, so confirm whether scnt.cnt is scaled between
 * the two updates in the complete file.
 */
290 static void init_sk_storage(int sk_fd
, __u32 pkt_out_cnt
)
292 struct bpf_spinlock_cnt scnt
= {};
295 scnt
.cnt
= pkt_out_cnt
;
/* Entry in the plain counter map. */
296 err
= bpf_map_update_elem(sk_pkt_out_cnt_fd
, &sk_fd
, &scnt
,
298 CHECK(err
, "bpf_map_update_elem(sk_pkt_out_cnt_fd)",
299 "err:%d errno:%d", err
, errno
);
/* Entry in the "cnt10" counter map. */
302 err
= bpf_map_update_elem(sk_pkt_out_cnt10_fd
, &sk_fd
, &scnt
,
304 CHECK(err
, "bpf_map_update_elem(sk_pkt_out_cnt10_fd)",
305 "err:%d errno:%d", err
, errno
);
/*
 * test() - drive one TCP connection over IPv6 loopback and check counters.
 *
 * Creates an epoll instance, a non-blocking listening socket and a
 * non-blocking client socket, both bound to ::1 with the port read back via
 * getsockname().  Both addresses are published in addr_map, the client
 * connects, the connection is accepted, the accepted socket's storage is
 * seeded with init_sk_storage(), and DATA is sent twice from the accepted
 * socket to the client.  Finally the per-socket packet counters are
 * verified with check_sk_pkt_out_cnt().
 *
 * NOTE(review): several pieces referenced here are not visible in this copy
 * (declarations of addrlen and i, the setup of ev.events, the trailing
 * arguments of some CHECK calls, and the end of the loop) -- verify against
 * the complete file.
 */
308 static void test(void)
310 int listen_fd
, cli_fd
, accept_fd
, epfd
, err
;
311 struct epoll_event ev
;
315 addrlen
= sizeof(struct sockaddr_in6
);
318 epfd
= epoll_create(1);
319 CHECK(epfd
== -1, "epoll_create()", "epfd:%d errno:%d", epfd
, errno
);
321 /* Prepare listen_fd */
322 listen_fd
= socket(AF_INET6
, SOCK_STREAM
| SOCK_NONBLOCK
, 0);
323 CHECK(listen_fd
== -1, "socket()", "listen_fd:%d errno:%d",
/* Bind to ::1 with port 0, then read back the port that was assigned. */
326 init_loopback6(&srv_sa6
);
327 err
= bind(listen_fd
, (struct sockaddr
*)&srv_sa6
, sizeof(srv_sa6
));
328 CHECK(err
, "bind(listen_fd)", "err:%d errno:%d", err
, errno
);
330 err
= getsockname(listen_fd
, (struct sockaddr
*)&srv_sa6
, &addrlen
);
331 CHECK(err
, "getsockname(listen_fd)", "err:%d errno:%d", err
, errno
);
333 err
= listen(listen_fd
, 1);
334 CHECK(err
, "listen(listen_fd)", "err:%d errno:%d", err
, errno
);
/* Prepare cli_fd the same way: bind to ::1 and read back the port. */
337 cli_fd
= socket(AF_INET6
, SOCK_STREAM
| SOCK_NONBLOCK
, 0);
338 CHECK(cli_fd
== -1, "socket()", "cli_fd:%d errno:%d", cli_fd
, errno
);
340 init_loopback6(&cli_sa6
);
341 err
= bind(cli_fd
, (struct sockaddr
*)&cli_sa6
, sizeof(cli_sa6
));
342 CHECK(err
, "bind(cli_fd)", "err:%d errno:%d", err
, errno
);
344 err
= getsockname(cli_fd
, (struct sockaddr
*)&cli_sa6
, &addrlen
);
345 CHECK(err
, "getsockname(cli_fd)", "err:%d errno:%d",
348 /* Update addr_map with srv_sa6 and cli_sa6 */
349 err
= bpf_map_update_elem(addr_map_fd
, &addr_srv_idx
, &srv_sa6
, 0);
350 CHECK(err
, "map_update", "err:%d errno:%d", err
, errno
);
352 err
= bpf_map_update_elem(addr_map_fd
, &addr_cli_idx
, &cli_sa6
, 0);
353 CHECK(err
, "map_update", "err:%d errno:%d", err
, errno
);
355 /* Connect from cli_sa6 to srv_sa6 */
356 err
= connect(cli_fd
, (struct sockaddr
*)&srv_sa6
, addrlen
);
357 printf("srv_sa6.sin6_port:%u cli_sa6.sin6_port:%u\n\n",
358 ntohs(srv_sa6
.sin6_port
), ntohs(cli_sa6
.sin6_port
));
/* cli_fd is non-blocking, so EINPROGRESS is accepted as success here. */
359 CHECK(err
&& errno
!= EINPROGRESS
,
360 "connect(cli_fd)", "err:%d errno:%d", err
, errno
);
/* Watch listen_fd so the accept below can be bounded by a timeout. */
362 ev
.data
.fd
= listen_fd
;
363 err
= epoll_ctl(epfd
, EPOLL_CTL_ADD
, listen_fd
, &ev
);
364 CHECK(err
, "epoll_ctl(EPOLL_CTL_ADD, listen_fd)", "err:%d errno:%d",
367 /* Accept the connection */
368 /* Have some timeout in accept(listen_fd). Just in case. */
369 err
= epoll_wait(epfd
, &ev
, 1, 1000);
370 CHECK(err
!= 1 || ev
.data
.fd
!= listen_fd
,
371 "epoll_wait(listen_fd)",
372 "err:%d errno:%d ev.data.fd:%d listen_fd:%d",
373 err
, errno
, ev
.data
.fd
, listen_fd
);
375 accept_fd
= accept(listen_fd
, NULL
, NULL
);
376 CHECK(accept_fd
== -1, "accept(listen_fd)", "accept_fd:%d errno:%d",
/* Watch cli_fd as well, for the receive side of the data loop. */
381 err
= epoll_ctl(epfd
, EPOLL_CTL_ADD
, cli_fd
, &ev
);
382 CHECK(err
, "epoll_ctl(EPOLL_CTL_ADD, cli_fd)", "err:%d errno:%d",
/* Seed accept_fd's counters with 2; check_sk_pkt_out_cnt() relies on it. */
385 init_sk_storage(accept_fd
, 2);
387 for (i
= 0; i
< 2; i
++) {
388 /* Send some data from accept_fd to cli_fd */
389 err
= send(accept_fd
, DATA
, DATA_LEN
, 0);
390 CHECK(err
!= DATA_LEN
, "send(accept_fd)", "err:%d errno:%d",
393 /* Have some timeout in recv(cli_fd). Just in case. */
394 err
= epoll_wait(epfd
, &ev
, 1, 1000);
395 CHECK(err
!= 1 || ev
.data
.fd
!= cli_fd
,
396 "epoll_wait(cli_fd)", "err:%d errno:%d ev.data.fd:%d cli_fd:%d",
397 err
, errno
, ev
.data
.fd
, cli_fd
);
/* Zero-length MSG_TRUNC receive; the test expects it to return 0. */
399 err
= recv(cli_fd
, NULL
, 0, MSG_TRUNC
);
400 CHECK(err
, "recv(cli_fd)", "err:%d errno:%d", err
, errno
);
/* All traffic has been exchanged; validate the per-socket counters. */
403 check_sk_pkt_out_cnt(accept_fd
, cli_fd
);
/*
 * main() - set up the cgroup and BPF programs used by the sock-fields test.
 *
 * Prepares the cgroup environment, creates and joins TEST_CGROUP, loads
 * test_sock_fields_kern.o, attaches its egress and ingress programs to the
 * cgroup, and resolves the file descriptors of the maps the checks read.
 *
 * The failure tags and messages of the two bpf_prog_attach() checks are
 * corrected here ("CPF_" -> "BPF_", "errno%d" -> "errno:%d").
 *
 * NOTE(review): parts of this function are not visible in this copy (the
 * end of the attr initializer, the declaration of map, the CHECK on
 * ingress_prog, and whatever runs between resolving the map fds and the
 * final cleanup) -- verify against the complete file.
 */
412 int main(int argc
, char **argv
)
414 struct bpf_prog_load_attr attr
= {
415 .file
= "test_sock_fields_kern.o",
416 .prog_type
= BPF_PROG_TYPE_CGROUP_SKB
,
417 .prog_flags
= BPF_F_TEST_RND_HI32
,
419 int cgroup_fd
, egress_fd
, ingress_fd
, err
;
420 struct bpf_program
*ingress_prog
;
421 struct bpf_object
*obj
;
424 err
= setup_cgroup_environment();
425 CHECK(err
, "setup_cgroup_environment()", "err:%d errno:%d",
/* Make sure the cgroup environment is torn down on every exit path. */
428 atexit(cleanup_cgroup_environment
);
430 /* Create a cgroup, get fd, and join it */
431 cgroup_fd
= create_and_get_cgroup(TEST_CGROUP
);
432 CHECK(cgroup_fd
== -1, "create_and_get_cgroup()",
433 "cgroup_fd:%d errno:%d", cgroup_fd
, errno
);
435 err
= join_cgroup(TEST_CGROUP
);
436 CHECK(err
, "join_cgroup", "err:%d errno:%d", err
, errno
);
/* Load the object; egress_fd receives the fd reported by the loader. */
438 err
= bpf_prog_load_xattr(&attr
, &obj
, &egress_fd
);
439 CHECK(err
, "bpf_prog_load_xattr()", "err:%d", err
);
/* The ingress program is located by its section title. */
441 ingress_prog
= bpf_object__find_program_by_title(obj
,
442 "cgroup_skb/ingress");
444 "bpf_object__find_program_by_title(cgroup_skb/ingress)",
446 ingress_fd
= bpf_program__fd(ingress_prog
);
/* Attach both programs to the test cgroup. */
448 err
= bpf_prog_attach(egress_fd
, cgroup_fd
, BPF_CGROUP_INET_EGRESS
, 0);
449 CHECK(err
== -1, "bpf_prog_attach(BPF_CGROUP_INET_EGRESS)",
450 "err:%d errno:%d", err
, errno
);
452 err
= bpf_prog_attach(ingress_fd
, cgroup_fd
,
453 BPF_CGROUP_INET_INGRESS
, 0);
454 CHECK(err
== -1, "bpf_prog_attach(BPF_CGROUP_INET_INGRESS)",
455 "err:%d errno:%d", err
, errno
);
/* Resolve every map by name and cache its fd in the file-scope globals. */
458 map
= bpf_object__find_map_by_name(obj
, "addr_map");
459 CHECK(!map
, "cannot find addr_map", "(null)");
460 addr_map_fd
= bpf_map__fd(map
);
462 map
= bpf_object__find_map_by_name(obj
, "sock_result_map");
463 CHECK(!map
, "cannot find sock_result_map", "(null)");
464 sk_map_fd
= bpf_map__fd(map
);
466 map
= bpf_object__find_map_by_name(obj
, "tcp_sock_result_map");
467 CHECK(!map
, "cannot find tcp_sock_result_map", "(null)");
468 tp_map_fd
= bpf_map__fd(map
);
470 map
= bpf_object__find_map_by_name(obj
, "linum_map");
471 CHECK(!map
, "cannot find linum_map", "(null)");
472 linum_map_fd
= bpf_map__fd(map
);
474 map
= bpf_object__find_map_by_name(obj
, "sk_pkt_out_cnt");
475 CHECK(!map
, "cannot find sk_pkt_out_cnt", "(null)");
476 sk_pkt_out_cnt_fd
= bpf_map__fd(map
);
478 map
= bpf_object__find_map_by_name(obj
, "sk_pkt_out_cnt10");
479 CHECK(!map
, "cannot find sk_pkt_out_cnt10", "(null)");
480 sk_pkt_out_cnt10_fd
= bpf_map__fd(map
);
/*
 * Release the object and clean up explicitly; cleanup_cgroup_environment()
 * is also registered with atexit() above, so it runs again at exit.
 */
484 bpf_object__close(obj
);
485 cleanup_cgroup_environment();