1 // SPDX-License-Identifier: GPL-2.0-only
5 * Copyright (C) 2017 Red Hat, Inc.
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
18 #include <sys/epoll.h>
25 /* Install signal handlers */
26 void init_signals(void)
28 struct sigaction act
= {
29 .sa_handler
= sigalrm
,
32 sigaction(SIGALRM
, &act
, NULL
);
33 signal(SIGPIPE
, SIG_IGN
);
36 static unsigned int parse_uint(const char *str
, const char *err_str
)
42 n
= strtoul(str
, &endptr
, 10);
43 if (errno
|| *endptr
!= '\0') {
44 fprintf(stderr
, "malformed %s \"%s\"\n", err_str
, str
);
50 /* Parse a CID in string representation */
51 unsigned int parse_cid(const char *str
)
53 return parse_uint(str
, "CID");
56 /* Parse a port in string representation */
57 unsigned int parse_port(const char *str
)
59 return parse_uint(str
, "port");
62 /* Wait for the remote to close the connection */
63 void vsock_wait_remote_close(int fd
)
65 struct epoll_event ev
;
68 epollfd
= epoll_create1(0);
70 perror("epoll_create1");
74 ev
.events
= EPOLLRDHUP
| EPOLLHUP
;
76 if (epoll_ctl(epollfd
, EPOLL_CTL_ADD
, fd
, &ev
) == -1) {
81 nfds
= epoll_wait(epollfd
, &ev
, 1, TIMEOUT
* 1000);
88 fprintf(stderr
, "epoll_wait timed out\n");
93 assert(ev
.events
& (EPOLLRDHUP
| EPOLLHUP
));
94 assert(ev
.data
.fd
== fd
);
99 /* Create socket <type>, bind to <cid, port> and return the file descriptor. */
100 int vsock_bind(unsigned int cid
, unsigned int port
, int type
)
102 struct sockaddr_vm sa
= {
103 .svm_family
= AF_VSOCK
,
109 fd
= socket(AF_VSOCK
, type
, 0);
115 if (bind(fd
, (struct sockaddr
*)&sa
, sizeof(sa
))) {
123 int vsock_connect_fd(int fd
, unsigned int cid
, unsigned int port
)
125 struct sockaddr_vm sa
= {
126 .svm_family
= AF_VSOCK
,
132 timeout_begin(TIMEOUT
);
134 ret
= connect(fd
, (struct sockaddr
*)&sa
, sizeof(sa
));
135 timeout_check("connect");
136 } while (ret
< 0 && errno
== EINTR
);
142 /* Bind to <bind_port>, connect to <cid, port> and return the file descriptor. */
143 int vsock_bind_connect(unsigned int cid
, unsigned int port
, unsigned int bind_port
, int type
)
147 client_fd
= vsock_bind(VMADDR_CID_ANY
, bind_port
, type
);
149 if (vsock_connect_fd(client_fd
, cid
, port
)) {
157 /* Connect to <cid, port> and return the file descriptor. */
158 int vsock_connect(unsigned int cid
, unsigned int port
, int type
)
162 control_expectln("LISTENING");
164 fd
= socket(AF_VSOCK
, type
, 0);
170 if (vsock_connect_fd(fd
, cid
, port
)) {
171 int old_errno
= errno
;
181 int vsock_stream_connect(unsigned int cid
, unsigned int port
)
183 return vsock_connect(cid
, port
, SOCK_STREAM
);
186 int vsock_seqpacket_connect(unsigned int cid
, unsigned int port
)
188 return vsock_connect(cid
, port
, SOCK_SEQPACKET
);
191 /* Listen on <cid, port> and return the file descriptor. */
192 static int vsock_listen(unsigned int cid
, unsigned int port
, int type
)
196 fd
= vsock_bind(cid
, port
, type
);
198 if (listen(fd
, 1) < 0) {
206 /* Listen on <cid, port> and return the first incoming connection. The remote
207 * address is stored to clientaddrp. clientaddrp may be NULL.
209 int vsock_accept(unsigned int cid
, unsigned int port
,
210 struct sockaddr_vm
*clientaddrp
, int type
)
214 struct sockaddr_vm svm
;
216 socklen_t clientaddr_len
= sizeof(clientaddr
.svm
);
217 int fd
, client_fd
, old_errno
;
219 fd
= vsock_listen(cid
, port
, type
);
221 control_writeln("LISTENING");
223 timeout_begin(TIMEOUT
);
225 client_fd
= accept(fd
, &clientaddr
.sa
, &clientaddr_len
);
226 timeout_check("accept");
227 } while (client_fd
< 0 && errno
== EINTR
);
237 if (clientaddr_len
!= sizeof(clientaddr
.svm
)) {
238 fprintf(stderr
, "unexpected addrlen from accept(2), %zu\n",
239 (size_t)clientaddr_len
);
242 if (clientaddr
.sa
.sa_family
!= AF_VSOCK
) {
243 fprintf(stderr
, "expected AF_VSOCK from accept(2), got %d\n",
244 clientaddr
.sa
.sa_family
);
249 *clientaddrp
= clientaddr
.svm
;
253 int vsock_stream_accept(unsigned int cid
, unsigned int port
,
254 struct sockaddr_vm
*clientaddrp
)
256 return vsock_accept(cid
, port
, clientaddrp
, SOCK_STREAM
);
259 int vsock_stream_listen(unsigned int cid
, unsigned int port
)
261 return vsock_listen(cid
, port
, SOCK_STREAM
);
264 int vsock_seqpacket_accept(unsigned int cid
, unsigned int port
,
265 struct sockaddr_vm
*clientaddrp
)
267 return vsock_accept(cid
, port
, clientaddrp
, SOCK_SEQPACKET
);
270 /* Transmit bytes from a buffer and check the return value.
273 * <0 Negative errno (for testing errors)
275 * >0 Success (bytes successfully written)
277 void send_buf(int fd
, const void *buf
, size_t len
, int flags
,
278 ssize_t expected_ret
)
280 ssize_t nwritten
= 0;
283 timeout_begin(TIMEOUT
);
285 ret
= send(fd
, buf
+ nwritten
, len
- nwritten
, flags
);
286 timeout_check("send");
288 if (ret
== 0 || (ret
< 0 && errno
!= EINTR
))
292 } while (nwritten
< len
);
295 if (expected_ret
< 0) {
297 fprintf(stderr
, "bogus send(2) return value %zd (expected %zd)\n",
301 if (errno
!= -expected_ret
) {
313 if (nwritten
!= expected_ret
) {
315 fprintf(stderr
, "unexpected EOF while sending bytes\n");
317 fprintf(stderr
, "bogus send(2) bytes written %zd (expected %zd)\n",
318 nwritten
, expected_ret
);
323 /* Receive bytes in a buffer and check the return value.
326 * <0 Negative errno (for testing errors)
328 * >0 Success (bytes successfully read)
330 void recv_buf(int fd
, void *buf
, size_t len
, int flags
, ssize_t expected_ret
)
335 timeout_begin(TIMEOUT
);
337 ret
= recv(fd
, buf
+ nread
, len
- nread
, flags
);
338 timeout_check("recv");
340 if (ret
== 0 || (ret
< 0 && errno
!= EINTR
))
344 } while (nread
< len
);
347 if (expected_ret
< 0) {
349 fprintf(stderr
, "bogus recv(2) return value %zd (expected %zd)\n",
353 if (errno
!= -expected_ret
) {
365 if (nread
!= expected_ret
) {
367 fprintf(stderr
, "unexpected EOF while receiving bytes\n");
369 fprintf(stderr
, "bogus recv(2) bytes read %zd (expected %zd)\n",
370 nread
, expected_ret
);
375 /* Transmit one byte and check the return value.
378 * <0 Negative errno (for testing errors)
382 void send_byte(int fd
, int expected_ret
, int flags
)
384 static const uint8_t byte
= 'A';
386 send_buf(fd
, &byte
, sizeof(byte
), flags
, expected_ret
);
389 /* Receive one byte and check the return value.
392 * <0 Negative errno (for testing errors)
396 void recv_byte(int fd
, int expected_ret
, int flags
)
400 recv_buf(fd
, &byte
, sizeof(byte
), flags
, expected_ret
);
403 fprintf(stderr
, "unexpected byte read 0x%02x\n", byte
);
408 /* Run test cases. The program terminates if a failure occurs. */
409 void run_tests(const struct test_case
*test_cases
,
410 const struct test_opts
*opts
)
414 for (i
= 0; test_cases
[i
].name
; i
++) {
415 void (*run
)(const struct test_opts
*opts
);
418 printf("%d - %s...", i
, test_cases
[i
].name
);
421 /* Full barrier before executing the next test. This
422 * ensures that client and server are executing the
423 * same test case. In particular, it means whoever is
424 * faster will not see the peer still executing the
425 * last test. This is important because port numbers
426 * can be used by multiple test cases.
428 if (test_cases
[i
].skip
)
429 control_writeln("SKIP");
431 control_writeln("NEXT");
433 line
= control_readln();
434 if (control_cmpln(line
, "SKIP", false) || test_cases
[i
].skip
) {
442 control_cmpln(line
, "NEXT", true);
445 if (opts
->mode
== TEST_MODE_CLIENT
)
446 run
= test_cases
[i
].run_client
;
448 run
= test_cases
[i
].run_server
;
457 void list_tests(const struct test_case
*test_cases
)
461 printf("ID\tTest name\n");
463 for (i
= 0; test_cases
[i
].name
; i
++)
464 printf("%d\t%s\n", i
, test_cases
[i
].name
);
469 static unsigned long parse_test_id(const char *test_id_str
, size_t test_cases_len
)
471 unsigned long test_id
;
475 test_id
= strtoul(test_id_str
, &endptr
, 10);
476 if (errno
|| *endptr
!= '\0') {
477 fprintf(stderr
, "malformed test ID \"%s\"\n", test_id_str
);
481 if (test_id
>= test_cases_len
) {
482 fprintf(stderr
, "test ID (%lu) larger than the max allowed (%lu)\n",
483 test_id
, test_cases_len
- 1);
490 void skip_test(struct test_case
*test_cases
, size_t test_cases_len
,
491 const char *test_id_str
)
493 unsigned long test_id
= parse_test_id(test_id_str
, test_cases_len
);
494 test_cases
[test_id
].skip
= true;
497 void pick_test(struct test_case
*test_cases
, size_t test_cases_len
,
498 const char *test_id_str
)
500 static bool skip_all
= true;
501 unsigned long test_id
;
506 for (i
= 0; i
< test_cases_len
; ++i
)
507 test_cases
[i
].skip
= true;
512 test_id
= parse_test_id(test_id_str
, test_cases_len
);
513 test_cases
[test_id
].skip
= false;
516 unsigned long hash_djb2(const void *data
, size_t len
)
518 unsigned long hash
= 5381;
522 hash
= ((hash
<< 5) + hash
) + ((unsigned char *)data
)[i
];
529 size_t iovec_bytes(const struct iovec
*iov
, size_t iovnum
)
534 for (bytes
= 0, i
= 0; i
< iovnum
; i
++)
535 bytes
+= iov
[i
].iov_len
;
540 unsigned long iovec_hash_djb2(const struct iovec
*iov
, size_t iovnum
)
548 iov_bytes
= iovec_bytes(iov
, iovnum
);
550 tmp
= malloc(iov_bytes
);
556 for (offs
= 0, i
= 0; i
< iovnum
; i
++) {
557 memcpy(tmp
+ offs
, iov
[i
].iov_base
, iov
[i
].iov_len
);
558 offs
+= iov
[i
].iov_len
;
561 hash
= hash_djb2(tmp
, iov_bytes
);
567 /* Allocates and returns new 'struct iovec *' according pattern
568 * in the 'test_iovec'. For each element in the 'test_iovec' it
569 * allocates new element in the resulting 'iovec'. 'iov_len'
570 * of the new element is copied from 'test_iovec'. 'iov_base' is
571 * allocated depending on the 'iov_base' of 'test_iovec':
573 * 'iov_base' == NULL -> valid buf: mmap('iov_len').
575 * 'iov_base' == MAP_FAILED -> invalid buf:
576 * mmap('iov_len'), then munmap('iov_len').
577 * 'iov_base' still contains result of
580 * 'iov_base' == number -> unaligned valid buf:
581 * mmap('iov_len') + number.
583 * 'iovnum' is number of elements in 'test_iovec'.
585 * Returns new 'iovec' or calls 'exit()' on error.
587 struct iovec
*alloc_test_iovec(const struct iovec
*test_iovec
, int iovnum
)
592 iovec
= malloc(sizeof(*iovec
) * iovnum
);
598 for (i
= 0; i
< iovnum
; i
++) {
599 iovec
[i
].iov_len
= test_iovec
[i
].iov_len
;
601 iovec
[i
].iov_base
= mmap(NULL
, iovec
[i
].iov_len
,
602 PROT_READ
| PROT_WRITE
,
603 MAP_PRIVATE
| MAP_ANONYMOUS
| MAP_POPULATE
,
605 if (iovec
[i
].iov_base
== MAP_FAILED
) {
610 if (test_iovec
[i
].iov_base
!= MAP_FAILED
)
611 iovec
[i
].iov_base
+= (uintptr_t)test_iovec
[i
].iov_base
;
614 /* Unmap "invalid" elements. */
615 for (i
= 0; i
< iovnum
; i
++) {
616 if (test_iovec
[i
].iov_base
== MAP_FAILED
) {
617 if (munmap(iovec
[i
].iov_base
, iovec
[i
].iov_len
)) {
624 for (i
= 0; i
< iovnum
; i
++) {
627 if (test_iovec
[i
].iov_base
== MAP_FAILED
)
630 for (j
= 0; j
< iovec
[i
].iov_len
; j
++)
631 ((uint8_t *)iovec
[i
].iov_base
)[j
] = rand() & 0xff;
637 /* Frees 'iovec *', previously allocated by 'alloc_test_iovec()'.
638 * On error calls 'exit()'.
640 void free_test_iovec(const struct iovec
*test_iovec
,
641 struct iovec
*iovec
, int iovnum
)
645 for (i
= 0; i
< iovnum
; i
++) {
646 if (test_iovec
[i
].iov_base
!= MAP_FAILED
) {
647 if (test_iovec
[i
].iov_base
)
648 iovec
[i
].iov_base
-= (uintptr_t)test_iovec
[i
].iov_base
;
650 if (munmap(iovec
[i
].iov_base
, iovec
[i
].iov_len
)) {
660 /* Set "unsigned long long" socket option and check that it's indeed set */
661 void setsockopt_ull_check(int fd
, int level
, int optname
,
662 unsigned long long val
, char const *errmsg
)
664 unsigned long long chkval
;
668 err
= setsockopt(fd
, level
, optname
, &val
, sizeof(val
));
670 fprintf(stderr
, "setsockopt err: %s (%d)\n",
671 strerror(errno
), errno
);
675 chkval
= ~val
; /* just make storage != val */
676 chklen
= sizeof(chkval
);
678 err
= getsockopt(fd
, level
, optname
, &chkval
, &chklen
);
680 fprintf(stderr
, "getsockopt err: %s (%d)\n",
681 strerror(errno
), errno
);
685 if (chklen
!= sizeof(chkval
)) {
686 fprintf(stderr
, "size mismatch: set %zu got %d\n", sizeof(val
),
692 fprintf(stderr
, "value mismatch: set %llu got %llu\n", val
,
698 fprintf(stderr
, "%s val %llu\n", errmsg
, val
);
703 /* Set "int" socket option and check that it's indeed set */
704 void setsockopt_int_check(int fd
, int level
, int optname
, int val
,
711 err
= setsockopt(fd
, level
, optname
, &val
, sizeof(val
));
713 fprintf(stderr
, "setsockopt err: %s (%d)\n",
714 strerror(errno
), errno
);
718 chkval
= ~val
; /* just make storage != val */
719 chklen
= sizeof(chkval
);
721 err
= getsockopt(fd
, level
, optname
, &chkval
, &chklen
);
723 fprintf(stderr
, "getsockopt err: %s (%d)\n",
724 strerror(errno
), errno
);
728 if (chklen
!= sizeof(chkval
)) {
729 fprintf(stderr
, "size mismatch: set %zu got %d\n", sizeof(val
),
735 fprintf(stderr
, "value mismatch: set %d got %d\n", val
, chkval
);
740 fprintf(stderr
, "%s val %d\n", errmsg
, val
);
744 static void mem_invert(unsigned char *mem
, size_t size
)
748 for (i
= 0; i
< size
; i
++)
752 /* Set "timeval" socket option and check that it's indeed set */
753 void setsockopt_timeval_check(int fd
, int level
, int optname
,
754 struct timeval val
, char const *errmsg
)
756 struct timeval chkval
;
760 err
= setsockopt(fd
, level
, optname
, &val
, sizeof(val
));
762 fprintf(stderr
, "setsockopt err: %s (%d)\n",
763 strerror(errno
), errno
);
767 /* just make storage != val */
769 mem_invert((unsigned char *)&chkval
, sizeof(chkval
));
770 chklen
= sizeof(chkval
);
772 err
= getsockopt(fd
, level
, optname
, &chkval
, &chklen
);
774 fprintf(stderr
, "getsockopt err: %s (%d)\n",
775 strerror(errno
), errno
);
779 if (chklen
!= sizeof(chkval
)) {
780 fprintf(stderr
, "size mismatch: set %zu got %d\n", sizeof(val
),
785 if (memcmp(&chkval
, &val
, sizeof(val
)) != 0) {
786 fprintf(stderr
, "value mismatch: set %ld:%ld got %ld:%ld\n",
787 val
.tv_sec
, val
.tv_usec
, chkval
.tv_sec
, chkval
.tv_usec
);
792 fprintf(stderr
, "%s val %ld:%ld\n", errmsg
, val
.tv_sec
, val
.tv_usec
);
796 void enable_so_zerocopy_check(int fd
)
798 setsockopt_int_check(fd
, SOL_SOCKET
, SO_ZEROCOPY
, 1,
799 "setsockopt SO_ZEROCOPY");