1 // SPDX-License-Identifier: GPL-2.0
6 #include "../../../../../include/linux/kernel.h"
7 #include "../../../../../include/linux/stringify.h"
10 const unsigned int test_server_port
= 7010;
11 int __test_listen_socket(int backlog
, void *addr
, size_t addr_sz
)
13 int err
, sk
= socket(test_family
, SOCK_STREAM
, IPPROTO_TCP
);
17 test_error("socket()");
19 err
= setsockopt(sk
, SOL_SOCKET
, SO_BINDTODEVICE
, veth_name
,
20 strlen(veth_name
) + 1);
22 test_error("setsockopt(SO_BINDTODEVICE)");
24 if (bind(sk
, (struct sockaddr
*)addr
, addr_sz
) < 0)
27 flags
= fcntl(sk
, F_GETFL
);
28 if ((flags
< 0) || (fcntl(sk
, F_SETFL
, flags
| O_NONBLOCK
) < 0))
29 test_error("fcntl()");
31 if (listen(sk
, backlog
))
32 test_error("listen()");
37 int test_wait_fd(int sk
, time_t sec
, bool write
)
39 struct timeval tv
= { .tv_sec
= sec
};
40 struct timeval
*ptv
= NULL
;
43 socklen_t slen
= sizeof(ret
);
55 ret
= select(sk
+ 1, NULL
, &fds
, &efds
, ptv
);
57 ret
= select(sk
+ 1, &fds
, NULL
, &efds
, ptv
);
65 if (getsockopt(sk
, SOL_SOCKET
, SO_ERROR
, &ret
, &slen
))
72 int __test_connect_socket(int sk
, const char *device
,
73 void *addr
, size_t addr_sz
, time_t timeout
)
79 err
= setsockopt(sk
, SOL_SOCKET
, SO_BINDTODEVICE
, device
,
82 test_error("setsockopt(SO_BINDTODEVICE, %s)", device
);
86 err
= connect(sk
, addr
, addr_sz
);
94 flags
= fcntl(sk
, F_GETFL
);
95 if ((flags
< 0) || (fcntl(sk
, F_SETFL
, flags
| O_NONBLOCK
) < 0))
96 test_error("fcntl()");
98 if (connect(sk
, addr
, addr_sz
) < 0) {
99 if (errno
!= EINPROGRESS
) {
105 err
= test_wait_fd(sk
, timeout
, 1);
116 int __test_set_md5(int sk
, void *addr
, size_t addr_sz
, uint8_t prefix
,
117 int vrf
, const char *password
)
119 size_t pwd_len
= strlen(password
);
120 struct tcp_md5sig md5sig
= {};
122 md5sig
.tcpm_keylen
= pwd_len
;
123 memcpy(md5sig
.tcpm_key
, password
, pwd_len
);
124 md5sig
.tcpm_flags
= TCP_MD5SIG_FLAG_PREFIX
;
125 md5sig
.tcpm_prefixlen
= prefix
;
127 md5sig
.tcpm_flags
|= TCP_MD5SIG_FLAG_IFINDEX
;
128 md5sig
.tcpm_ifindex
= (uint8_t)vrf
;
130 memcpy(&md5sig
.tcpm_addr
, addr
, addr_sz
);
133 return setsockopt(sk
, IPPROTO_TCP
, TCP_MD5SIG_EXT
,
134 &md5sig
, sizeof(md5sig
));
138 int test_prepare_key_sockaddr(struct tcp_ao_add
*ao
, const char *alg
,
139 void *addr
, size_t addr_sz
, bool set_current
, bool set_rnext
,
140 uint8_t prefix
, uint8_t vrf
, uint8_t sndid
, uint8_t rcvid
,
141 uint8_t maclen
, uint8_t keyflags
,
142 uint8_t keylen
, const char *key
)
144 memset(ao
, 0, sizeof(struct tcp_ao_add
));
146 ao
->set_current
= !!set_current
;
147 ao
->set_rnext
= !!set_rnext
;
152 ao
->keyflags
= keyflags
;
156 memcpy(&ao
->addr
, addr
, addr_sz
);
158 if (strlen(alg
) > 64)
160 strncpy(ao
->alg_name
, alg
, 64);
163 (keylen
> TCP_AO_MAXKEYLEN
) ? TCP_AO_MAXKEYLEN
: keylen
);
167 static int test_get_ao_keys_nr(int sk
)
169 struct tcp_ao_getsockopt tmp
= {};
170 socklen_t tmp_sz
= sizeof(tmp
);
176 ret
= getsockopt(sk
, IPPROTO_TCP
, TCP_AO_GET_KEYS
, &tmp
, &tmp_sz
);
179 return (int)tmp
.nkeys
;
182 int test_get_one_ao(int sk
, struct tcp_ao_getsockopt
*out
,
183 void *addr
, size_t addr_sz
, uint8_t prefix
,
184 uint8_t sndid
, uint8_t rcvid
)
186 struct tcp_ao_getsockopt tmp
= {};
187 socklen_t tmp_sz
= sizeof(tmp
);
190 memcpy(&tmp
.addr
, addr
, addr_sz
);
196 ret
= getsockopt(sk
, IPPROTO_TCP
, TCP_AO_GET_KEYS
, &tmp
, &tmp_sz
);
205 int test_get_ao_info(int sk
, struct tcp_ao_info_opt
*out
)
207 socklen_t sz
= sizeof(*out
);
211 if (getsockopt(sk
, IPPROTO_TCP
, TCP_AO_INFO
, out
, &sz
))
213 if (sz
!= sizeof(*out
))
218 int test_set_ao_info(int sk
, struct tcp_ao_info_opt
*in
)
220 socklen_t sz
= sizeof(*in
);
224 if (setsockopt(sk
, IPPROTO_TCP
, TCP_AO_INFO
, in
, sz
))
229 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add
*a
,
230 const struct tcp_ao_getsockopt
*b
)
232 bool is_kdf_aes_128_cmac
= false;
233 bool is_cmac_aes
= false;
235 if (!strcmp("cmac(aes128)", a
->alg_name
)) {
236 is_kdf_aes_128_cmac
= (a
->keylen
!= 16);
240 #define __cmp_ao(member) \
242 if (b->member != a->member) { \
243 test_fail("getsockopt(): " __stringify(member) " %u != %u", \
244 b->member, a->member); \
255 } else if (b
->maclen
!= 12) {
256 test_fail("getsockopt(): expected default maclen 12, but it's %u",
260 if (!is_kdf_aes_128_cmac
) {
262 } else if (b
->keylen
!= 16) {
263 test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
268 if (!is_kdf_aes_128_cmac
&& memcmp(b
->key
, a
->key
, a
->keylen
)) {
269 test_fail("getsockopt(): returned key is different `%s' != `%s'",
273 if (memcmp(&b
->addr
, &a
->addr
, sizeof(b
->addr
))) {
274 test_fail("getsockopt(): returned address is different");
277 if (!is_cmac_aes
&& strcmp(b
->alg_name
, a
->alg_name
)) {
278 test_fail("getsockopt(): returned algorithm %s is different than %s", b
->alg_name
, a
->alg_name
);
281 if (is_cmac_aes
&& strcmp(b
->alg_name
, "cmac(aes)")) {
282 test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b
->alg_name
);
285 /* For an established-connection key rotation test, don't add a key with
286 * set_current = 1, as it's likely to change at the peer's request;
287 * rather use setsockopt(TCP_AO_INFO)
289 if (a
->set_current
!= b
->is_current
) {
290 test_fail("getsockopt(): returned key is not Current_key");
293 if (a
->set_rnext
!= b
->is_rnext
) {
294 test_fail("getsockopt(): returned key is not RNext_key");
301 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt
*a
,
302 const struct tcp_ao_info_opt
*b
)
304 /* No check for ::current_key, as it may be changed by the peer */
305 if (a
->ao_required
!= b
->ao_required
) {
306 test_fail("getsockopt(): returned ao doesn't have ao_required");
309 if (a
->accept_icmps
!= b
->accept_icmps
) {
310 test_fail("getsockopt(): returned ao doesn't accept ICMPs");
313 if (a
->set_rnext
&& a
->rnext
!= b
->rnext
) {
314 test_fail("getsockopt(): RNext KeyID has changed");
317 #define __cmp_cnt(member) \
319 if (b->member != a->member) { \
320 test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \
321 b->member, a->member); \
325 if (a
->set_counters
) {
328 __cmp_cnt(pkt_key_not_found
);
329 __cmp_cnt(pkt_ao_required
);
330 __cmp_cnt(pkt_dropped_icmp
);
336 int test_get_tcp_ao_counters(int sk
, struct tcp_ao_counters
*out
)
338 struct tcp_ao_getsockopt
*key_dump
;
339 socklen_t key_dump_sz
= sizeof(*key_dump
);
340 struct tcp_ao_info_opt info
= {};
341 bool c1
, c2
, c3
, c4
, c5
;
345 memset(out
, 0, sizeof(*out
));
349 out
->netns_ao_good
= netstat_get(ns
, "TCPAOGood", &c1
);
350 out
->netns_ao_bad
= netstat_get(ns
, "TCPAOBad", &c2
);
351 out
->netns_ao_key_not_found
= netstat_get(ns
, "TCPAOKeyNotFound", &c3
);
352 out
->netns_ao_required
= netstat_get(ns
, "TCPAORequired", &c4
);
353 out
->netns_ao_dropped_icmp
= netstat_get(ns
, "TCPAODroppedIcmps", &c5
);
355 if (c1
|| c2
|| c3
|| c4
|| c5
)
358 err
= test_get_ao_info(sk
, &info
);
363 out
->ao_info_pkt_good
= info
.pkt_good
;
364 out
->ao_info_pkt_bad
= info
.pkt_bad
;
365 out
->ao_info_pkt_key_not_found
= info
.pkt_key_not_found
;
366 out
->ao_info_pkt_ao_required
= info
.pkt_ao_required
;
367 out
->ao_info_pkt_dropped_icmp
= info
.pkt_dropped_icmp
;
370 nr_keys
= test_get_ao_keys_nr(sk
);
374 test_error("test_get_ao_keys_nr() == 0");
375 out
->nr_keys
= (size_t)nr_keys
;
376 key_dump
= calloc(nr_keys
, key_dump_sz
);
380 key_dump
[0].nkeys
= nr_keys
;
381 key_dump
[0].get_all
= 1;
382 err
= getsockopt(sk
, IPPROTO_TCP
, TCP_AO_GET_KEYS
,
383 key_dump
, &key_dump_sz
);
389 out
->key_cnts
= calloc(nr_keys
, sizeof(out
->key_cnts
[0]));
390 if (!out
->key_cnts
) {
396 out
->key_cnts
[nr_keys
].sndid
= key_dump
[nr_keys
].sndid
;
397 out
->key_cnts
[nr_keys
].rcvid
= key_dump
[nr_keys
].rcvid
;
398 out
->key_cnts
[nr_keys
].pkt_good
= key_dump
[nr_keys
].pkt_good
;
399 out
->key_cnts
[nr_keys
].pkt_bad
= key_dump
[nr_keys
].pkt_bad
;
406 int __test_tcp_ao_counters_cmp(const char *tst_name
,
407 struct tcp_ao_counters
*before
,
408 struct tcp_ao_counters
*after
,
411 #define __cmp_ao(cnt, expecting_inc) \
413 if (before->cnt > after->cnt) { \
414 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
415 tst_name ?: "", before->cnt, after->cnt); \
418 if ((before->cnt != after->cnt) != (expecting_inc)) { \
419 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
420 tst_name ?: "", (expecting_inc) ? "" : "not ", \
421 before->cnt, after->cnt); \
428 __cmp_ao(netns_ao_good
, !!(expected
& TEST_CNT_NS_GOOD
));
429 __cmp_ao(netns_ao_bad
, !!(expected
& TEST_CNT_NS_BAD
));
430 __cmp_ao(netns_ao_key_not_found
,
431 !!(expected
& TEST_CNT_NS_KEY_NOT_FOUND
));
432 __cmp_ao(netns_ao_required
, !!(expected
& TEST_CNT_NS_AO_REQUIRED
));
433 __cmp_ao(netns_ao_dropped_icmp
,
434 !!(expected
& TEST_CNT_NS_DROPPED_ICMP
));
436 __cmp_ao(ao_info_pkt_good
, !!(expected
& TEST_CNT_SOCK_GOOD
));
437 __cmp_ao(ao_info_pkt_bad
, !!(expected
& TEST_CNT_SOCK_BAD
));
438 __cmp_ao(ao_info_pkt_key_not_found
,
439 !!(expected
& TEST_CNT_SOCK_KEY_NOT_FOUND
));
440 __cmp_ao(ao_info_pkt_ao_required
, !!(expected
& TEST_CNT_SOCK_AO_REQUIRED
));
441 __cmp_ao(ao_info_pkt_dropped_icmp
,
442 !!(expected
& TEST_CNT_SOCK_DROPPED_ICMP
));
447 int test_tcp_ao_key_counters_cmp(const char *tst_name
,
448 struct tcp_ao_counters
*before
,
449 struct tcp_ao_counters
*after
,
451 int sndid
, int rcvid
)
454 #define __cmp_ao(i, cnt, expecting_inc) \
456 if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) { \
457 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
458 tst_name ?: "", before->key_cnts[i].cnt, \
459 after->key_cnts[i].cnt, \
460 before->key_cnts[i].sndid, \
461 before->key_cnts[i].rcvid); \
464 if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) { \
465 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
466 tst_name ?: "", (expecting_inc) ? "" : "not ",\
467 before->key_cnts[i].cnt, \
468 after->key_cnts[i].cnt, \
469 before->key_cnts[i].sndid, \
470 before->key_cnts[i].rcvid); \
475 if (before
->nr_keys
!= after
->nr_keys
) {
476 test_fail("%s: Keys changed on the socket %zu != %zu",
477 tst_name
, before
->nr_keys
, after
->nr_keys
);
484 if (sndid
>= 0 && before
->key_cnts
[i
].sndid
!= sndid
)
486 if (rcvid
>= 0 && before
->key_cnts
[i
].rcvid
!= rcvid
)
488 __cmp_ao(i
, pkt_good
, !!(expected
& TEST_CNT_KEY_GOOD
));
489 __cmp_ao(i
, pkt_bad
, !!(expected
& TEST_CNT_KEY_BAD
));
495 void test_tcp_ao_counters_free(struct tcp_ao_counters
*cnts
)
497 free(cnts
->key_cnts
);
500 #define TEST_BUF_SIZE 4096
501 ssize_t
test_server_run(int sk
, ssize_t quota
, time_t timeout_sec
)
506 char buf
[TEST_BUF_SIZE
];
510 ret
= test_wait_fd(sk
, timeout_sec
, 0);
514 bytes
= recv(sk
, buf
, sizeof(buf
), 0);
517 test_error("recv(): %zd", bytes
);
521 ret
= test_wait_fd(sk
, timeout_sec
, 1);
525 sent
= send(sk
, buf
, bytes
, 0);
529 test_error("send()");
531 } while (!quota
|| total
< quota
);
536 ssize_t
test_client_loop(int sk
, char *buf
, size_t buf_sz
,
537 const size_t msg_len
, time_t timeout_sec
)
543 if (setsockopt(sk
, IPPROTO_TCP
, TCP_NODELAY
, &nodelay
, sizeof(nodelay
)))
544 test_error("setsockopt(TCP_NODELAY)");
546 for (i
= 0; i
< buf_sz
; i
+= min(msg_len
, buf_sz
- i
)) {
547 size_t sent
, bytes
= min(msg_len
, buf_sz
- i
);
550 ret
= test_wait_fd(sk
, timeout_sec
, 1);
554 sent
= send(sk
, buf
+ i
, bytes
, 0);
558 test_error("send()");
564 ret
= test_wait_fd(sk
, timeout_sec
, 0);
568 got
= recv(sk
, msg
+ bytes
, sizeof(msg
) - bytes
, 0);
572 } while (bytes
< sent
);
574 test_error("recv(): %zd > %zd", bytes
, sent
);
575 if (memcmp(buf
+ i
, msg
, bytes
) != 0) {
576 test_fail("received message differs");
583 int test_client_verify(int sk
, const size_t msg_len
, const size_t nr
,
586 size_t buf_sz
= msg_len
* nr
;
587 char *buf
= alloca(buf_sz
);
590 randomize_buffer(buf
, buf_sz
);
591 ret
= test_client_loop(sk
, buf
, buf_sz
, msg_len
, timeout_sec
);
594 return ret
!= buf_sz
? -1 : 0;