1 // SPDX-License-Identifier: GPL-2.0
2 #include <test_progs.h>
3 #include "cgroup_helpers.h"
/*
 * Arbitrary socket-option level used by this test. Options at this level are
 * written by start_server() via setsockopt() and read back by
 * verify_sockopt() via getsockopt().
 * NOTE(review): presumably these calls are handled by the cgroup/getsockopt
 * and cgroup/setsockopt BPF programs that run_test() attaches — confirm
 * against the BPF object source.
 */
5 #define SOL_CUSTOM 0xdeadbeef
/*
 * Option numbers under SOL_CUSTOM. start_server() sets all three on the
 * listening socket. server_thread() expects INHERIT1/INHERIT2 to read back
 * as 1 on the accept()ed socket and LISTENER to read back as 0.
 */
6 #define CUSTOM_INHERIT1 0
7 #define CUSTOM_INHERIT2 1
8 #define CUSTOM_LISTENER 2
/*
 * connect_to_server - open an IPv4 TCP client socket and connect it to the
 * address that @server_fd is bound to.
 * @server_fd: an already-bound server socket; its local address is read with
 *             getsockname() and used as the connect() target.
 *
 * NOTE(review): this copy of the function is missing lines (the fd
 * declaration, the error-path returns/closes and the final return). Confirm
 * that the client fd is returned on success and closed on the failure paths.
 */
10 static int connect_to_server(int server_fd
)
12 struct sockaddr_storage addr
;
13 socklen_t len
= sizeof(addr
);
16 fd
= socket(AF_INET
, SOCK_STREAM
, 0);
18 log_err("Failed to create client socket");
/* Reuse the server socket's own bound address as the destination. */
22 if (getsockname(server_fd
, (struct sockaddr
*)&addr
, &len
)) {
23 log_err("Failed to get server addr");
27 if (connect(fd
, (const struct sockaddr
*)&addr
, len
) < 0) {
28 log_err("Fail to connect to server");
/*
 * verify_sockopt - read option @optname at level SOL_CUSTOM from @fd and
 * compare it with @expected.
 * @fd:       socket to query
 * @optname:  one of CUSTOM_INHERIT1 / CUSTOM_INHERIT2 / CUSTOM_LISTENER
 * @msg:      label printed in the diagnostic and error output
 * @expected: value the option should hold
 *
 * Callers accumulate the result with "err +=" and wrap it in CHECK_FAIL().
 * Presumably it returns 0 on a match and nonzero on a getsockopt() failure
 * or a mismatch. TODO confirm: the return statements, and the declarations
 * of buf/optlen, are missing from this copy.
 */
39 static int verify_sockopt(int fd
, int optname
, const char *msg
, char expected
)
45 err
= getsockopt(fd
, SOL_CUSTOM
, optname
, &buf
, &optlen
);
47 log_err("%s: failed to call getsockopt", msg
);
/* Always print the observed vs. expected value to aid debugging. */
51 printf("%s %d: got=0x%x ? expected=0x%x\n", msg
, optname
, buf
, expected
);
53 if (buf
!= expected
) {
54 log_err("%s: unexpected getsockopt value %d != %d", msg
,
/*
 * Startup handshake between run_test() and server_thread(). The server
 * thread signals server_started while holding server_started_mtx, and
 * run_test() waits on it before connecting.
 * NOTE(review): no "started" flag accompanies this condition variable. A
 * signal sent before the waiter is blocked in pthread_cond_wait() is
 * therefore lost, and spurious wakeups are not filtered out.
 */
62 static pthread_mutex_t server_started_mtx
= PTHREAD_MUTEX_INITIALIZER
;
63 static pthread_cond_t server_started
= PTHREAD_COND_INITIALIZER
;
/*
 * server_thread - server side of the inheritance test.
 * @arg: run_test() passes &server_fd here. Presumably it is dereferenced
 *       into the listening fd (that line is missing from this copy).
 *
 * Steps:
 *  1. Signals server_started so run_test() can proceed.
 *  2. Checks that all three SOL_CUSTOM options read back as 1 on the
 *     listening socket.
 *  3. Accepts one client and checks that CUSTOM_INHERIT1/2 were inherited
 *     (== 1) while CUSTOM_LISTENER was not (== 0).
 *
 * Returns the accumulated verify_sockopt() results cast to void *.
 * run_test() decodes this with (int)(long).
 */
65 static void *server_thread(void *arg
)
67 struct sockaddr_storage addr
;
68 socklen_t len
= sizeof(addr
);
/*
 * NOTE(review): the signal is sent before err is checked. run_test() is
 * therefore released even when the preceding setup (presumably listen())
 * failed.
 */
75 pthread_mutex_lock(&server_started_mtx
);
76 pthread_cond_signal(&server_started
);
77 pthread_mutex_unlock(&server_started_mtx
);
/*
 * NOTE(review): "listed" in the message below is presumably a typo for
 * "listen". It is left as-is because it is runtime output.
 */
79 if (CHECK_FAIL(err
< 0)) {
80 perror("Failed to listed on socket");
/* Every option was set to 1 by start_server(), so all read back as 1 here. */
84 err
+= verify_sockopt(fd
, CUSTOM_INHERIT1
, "listen", 1);
85 err
+= verify_sockopt(fd
, CUSTOM_INHERIT2
, "listen", 1);
86 err
+= verify_sockopt(fd
, CUSTOM_LISTENER
, "listen", 1);
88 client_fd
= accept(fd
, (struct sockaddr
*)&addr
, &len
);
89 if (CHECK_FAIL(client_fd
< 0)) {
90 perror("Failed to accept client");
/* The accepted child socket should inherit only the two INHERIT options. */
94 err
+= verify_sockopt(client_fd
, CUSTOM_INHERIT1
, "accept", 1);
95 err
+= verify_sockopt(client_fd
, CUSTOM_INHERIT2
, "accept", 1);
96 err
+= verify_sockopt(client_fd
, CUSTOM_LISTENER
, "accept", 0);
100 return (void *)(long)err
;
/*
 * start_server - create the IPv4 TCP server socket used by the test.
 *
 * Sets every SOL_CUSTOM option from CUSTOM_INHERIT1 through CUSTOM_LISTENER
 * to the 1-byte value in buf, then binds the socket to the loopback address.
 * The options are set before bind(), so they are already present on the
 * socket that server_thread() later inspects.
 *
 * NOTE(review): the sin_port initializer (if any), the declaration and value
 * of buf (presumably 0x01, since server_thread() expects 1), the error-path
 * cleanup and the return value are missing from this copy. TODO confirm.
 */
103 static int start_server(void)
105 struct sockaddr_in addr
= {
106 .sin_family
= AF_INET
,
107 .sin_addr
.s_addr
= htonl(INADDR_LOOPBACK
),
114 fd
= socket(AF_INET
, SOCK_STREAM
, 0);
116 log_err("Failed to create server socket");
/* Iterate over all option numbers; this relies on them being contiguous. */
120 for (i
= CUSTOM_INHERIT1
; i
<= CUSTOM_LISTENER
; i
++) {
122 err
= setsockopt(fd
, SOL_CUSTOM
, i
, &buf
, 1);
124 log_err("Failed to call setsockopt(%d)", i
);
130 if (bind(fd
, (const struct sockaddr
*)&addr
, sizeof(addr
)) < 0) {
131 log_err("Failed to bind socket");
/*
 * prog_attach - attach the BPF program with section title @title from @obj
 * to the cgroup @cgroup_fd.
 * @obj:       loaded BPF object
 * @cgroup_fd: cgroup to attach to
 * @title:     section name, e.g. "cgroup/getsockopt". It is used both to
 *             derive the program/attach type and to look up the program.
 *
 * NOTE(review): the remaining bpf_prog_attach() arguments (presumably
 * attach_type and flags) and the return statements are missing from this
 * copy. bpf_object__find_program_by_title() is deprecated in later libbpf
 * releases; consider a name- or section-based lookup when porting.
 */
139 static int prog_attach(struct bpf_object
*obj
, int cgroup_fd
, const char *title
)
141 enum bpf_attach_type attach_type
;
142 enum bpf_prog_type prog_type
;
143 struct bpf_program
*prog
;
/* Map the section name to its program and attach type. */
146 err
= libbpf_prog_type_by_name(title
, &prog_type
, &attach_type
);
/* NOTE(review): "deduct" is presumably meant as "deduce"; runtime string left unchanged. */
148 log_err("Failed to deduct types for %s BPF program", title
);
152 prog
= bpf_object__find_program_by_title(obj
, title
);
154 log_err("Failed to find %s BPF program", title
);
158 err
= bpf_prog_attach(bpf_program__fd(prog
), cgroup_fd
,
161 log_err("Failed to attach %s BPF program", title
);
/*
 * run_test - end-to-end sockopt inheritance check inside @cgroup_fd.
 *
 * Steps:
 *  1. Loads ./sockopt_inherit.o.
 *  2. Attaches its cgroup/getsockopt and cgroup/setsockopt programs.
 *  3. Starts the server socket and spawns server_thread().
 *  4. Waits for the thread's startup signal, then connects a client.
 *  5. Checks that none of the SOL_CUSTOM options are set (== 0) on the
 *     client socket.
 *  6. Joins the server thread and decodes its result into err.
 *
 * NOTE(review): the error checks after bpf_prog_load_xattr()/prog_attach(),
 * the cleanup labels and the closes are partly missing from this copy.
 * bpf_prog_load_xattr() is deprecated in later libbpf releases.
 */
168 static void run_test(int cgroup_fd
)
170 struct bpf_prog_load_attr attr
= {
171 .file
= "./sockopt_inherit.o",
173 int server_fd
= -1, client_fd
;
174 struct bpf_object
*obj
;
180 err
= bpf_prog_load_xattr(&attr
, &obj
, &ignored
);
184 err
= prog_attach(obj
, cgroup_fd
, "cgroup/getsockopt");
186 goto close_bpf_object
;
188 err
= prog_attach(obj
, cgroup_fd
, "cgroup/setsockopt");
190 goto close_bpf_object
;
192 server_fd
= start_server();
193 if (CHECK_FAIL(server_fd
< 0))
194 goto close_bpf_object
;
196 if (CHECK_FAIL(pthread_create(&tid
, NULL
, server_thread
,
197 (void *)&server_fd
)))
198 goto close_server_fd
;
/*
 * NOTE(review): possible hang (lost wakeup). The mutex is taken only after
 * pthread_create() has returned, and pthread_cond_wait() is not guarded by a
 * predicate. If server_thread() reaches pthread_cond_signal() before this
 * thread blocks below, the signal is lost and this wait never returns.
 * Suggested fix: lock server_started_mtx before pthread_create(), or add a
 * "started" flag that is checked in a while loop around the wait.
 */
200 pthread_mutex_lock(&server_started_mtx
);
201 pthread_cond_wait(&server_started
, &server_started_mtx
);
202 pthread_mutex_unlock(&server_started_mtx
);
204 client_fd
= connect_to_server(server_fd
);
205 if (CHECK_FAIL(client_fd
< 0))
206 goto close_server_fd
;
/* The client socket never had the options set, so all must read 0. */
208 CHECK_FAIL(verify_sockopt(client_fd
, CUSTOM_INHERIT1
, "connect", 0));
209 CHECK_FAIL(verify_sockopt(client_fd
, CUSTOM_INHERIT2
, "connect", 0));
210 CHECK_FAIL(verify_sockopt(client_fd
, CUSTOM_LISTENER
, "connect", 0));
212 pthread_join(tid
, &server_err
);
/* server_thread() returns its mismatch count encoded as (void *)(long). */
214 err
= (int)(long)server_err
;
222 bpf_object__close(obj
);
225 void test_sockopt_inherit(void)
229 cgroup_fd
= test__join_cgroup("/sockopt_inherit");
230 if (CHECK_FAIL(cgroup_fd
< 0))