// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2017 Red Hat, Inc.
 *
 * Author: Stefan Hajnoczi <stefanha@redhat.com>
 */
18 #include <sys/epoll.h>
25 /* Install signal handlers */
/*
 * init_signals - install the signal dispositions used by the tests.
 *
 * SIGALRM is handled by sigalrm(), which is defined elsewhere; presumably it
 * is the handler behind timeout_begin()/timeout_check() (confirm).
 * SIGPIPE is ignored, so writing to a closed connection makes send(2) fail
 * with EPIPE instead of terminating the process.
 *
 * NOTE(review): this copy is damaged. Source line numbers are fused into the
 * code, and the function braces and the `};` closing the sigaction
 * initializer are missing (possibly other initializer members too). Restore
 * from version control before building.
 */
26 void init_signals(void)
28 struct sigaction act
= {
29 .sa_handler
= sigalrm
,
32 sigaction(SIGALRM
, &act
, NULL
);
33 signal(SIGPIPE
, SIG_IGN
);
/*
 * parse_uint - parse @str as a base-10 unsigned integer.
 * @str:     text to parse; every character must be consumed by strtoul().
 * @err_str: label used in the diagnostic (e.g. "CID" or "port").
 *
 * The input is rejected when strtoul() reports an error through errno or
 * stops before the terminating NUL. In that case the message
 * "malformed <err_str> \"<str>\"" is printed to stderr.
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and several pieces are not visible: the declarations of endptr/n, any
 * `errno = 0` reset before strtoul(), the failure exit and the return.
 * Restore from version control before building.
 * NOTE(review): strtoul() returns unsigned long but this function returns
 * unsigned int. Unless the missing lines contain a range check, values
 * above UINT_MAX are truncated silently. Consider rejecting them.
 */
36 static unsigned int parse_uint(const char *str
, const char *err_str
)
42 n
= strtoul(str
, &endptr
, 10);
/* Reject conversion errors and any trailing non-numeric characters. */
43 if (errno
|| *endptr
!= '\0') {
44 fprintf(stderr
, "malformed %s \"%s\"\n", err_str
, str
);
/*
 * Parse a CID in string representation.
 * @str: decimal text to parse.
 *
 * Delegates to parse_uint(), which labels any diagnostic with "CID".
 * Returns the parsed value.
 */
unsigned int parse_cid(const char *str)
{
	return parse_uint(str, "CID");
}
/*
 * Parse a port in string representation.
 * @str: decimal text to parse.
 *
 * Delegates to parse_uint(), which labels any diagnostic with "port".
 * Returns the parsed value.
 */
unsigned int parse_port(const char *str)
{
	return parse_uint(str, "port");
}
62 /* Wait for the remote to close the connection */
/*
 * Registers @fd with a new epoll instance (epoll_create1(0)) for
 * EPOLLRDHUP | EPOLLHUP. It then waits for one event with a limit of
 * TIMEOUT * 1000 ms; TIMEOUT is presumably in seconds.
 *
 * "epoll_wait timed out" is printed when no event arrives, presumably when
 * epoll_wait() returns 0. The condition is not visible.
 *
 * The received event must be a hang-up reported for @fd, which is checked
 * with assert().
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are not visible:
 *   - the declarations of epollfd/nfds;
 *   - the epoll_create1() failure test and the exits;
 *   - the body of the epoll_ctl() error path;
 *   - whatever follows the asserts.
 * Presumably `ev.data.fd = fd` is also among the missing lines, because the
 * final assert relies on it. Restore from version control before building.
 * NOTE(review): the outcome is verified only with assert(), which is
 * compiled out under -DNDEBUG.
 */
63 void vsock_wait_remote_close(int fd
)
65 struct epoll_event ev
;
68 epollfd
= epoll_create1(0);
70 perror("epoll_create1");
74 ev
.events
= EPOLLRDHUP
| EPOLLHUP
;
76 if (epoll_ctl(epollfd
, EPOLL_CTL_ADD
, fd
, &ev
) == -1) {
81 nfds
= epoll_wait(epollfd
, &ev
, 1, TIMEOUT
* 1000);
88 fprintf(stderr
, "epoll_wait timed out\n");
93 assert(ev
.events
& (EPOLLRDHUP
| EPOLLHUP
));
94 assert(ev
.data
.fd
== fd
);
99 /* Bind to <bind_port>, connect to <cid, port> and return the file descriptor. */
/*
 * @cid, @port: address of the peer to connect to.
 * @bind_port:  local port. The client socket is bound to
 *              VMADDR_CID_ANY:<bind_port> before connecting.
 * @type:       socket type passed to socket(2).
 *
 * connect(2) is retried while it fails with EINTR. The timeout is armed
 * with timeout_begin(TIMEOUT) and checked with timeout_check("connect")
 * after every attempt.
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are not visible:
 *   - sa_server's .svm_cid/.svm_port initialisers (presumably cid/port);
 *   - the declarations of client_fd/ret;
 *   - the socket()/bind() error handling;
 *   - the `do {` that the `} while` implies;
 *   - everything after the connect loop.
 * Restore from version control before building.
 */
100 int vsock_bind_connect(unsigned int cid
, unsigned int port
, unsigned int bind_port
, int type
)
102 struct sockaddr_vm sa_client
= {
103 .svm_family
= AF_VSOCK
,
104 .svm_cid
= VMADDR_CID_ANY
,
105 .svm_port
= bind_port
,
107 struct sockaddr_vm sa_server
= {
108 .svm_family
= AF_VSOCK
,
115 client_fd
= socket(AF_VSOCK
, type
, 0);
121 if (bind(client_fd
, (struct sockaddr
*)&sa_client
, sizeof(sa_client
))) {
126 timeout_begin(TIMEOUT
);
128 ret
= connect(client_fd
, (struct sockaddr
*)&sa_server
, sizeof(sa_server
));
129 timeout_check("connect");
130 } while (ret
< 0 && errno
== EINTR
);
141 /* Connect to <cid, port> and return the file descriptor. */
/*
 * @type: socket type passed to socket(2). The wrappers in this file pass
 *        SOCK_STREAM or SOCK_SEQPACKET.
 *
 * Before creating the socket, this calls control_expectln("LISTENING").
 * Presumably that waits for the peer's announcement on the control channel;
 * vsock_accept() sends that line with control_writeln().
 *
 * connect(2) is retried while it fails with EINTR. The timeout is armed by
 * timeout_begin(TIMEOUT) and checked after every attempt.
 *
 * errno is saved in old_errno, presumably so that it survives cleanup on
 * the failure path (verify).
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and lines are missing. `addr` is used as addr.sa and addr.svm, so it is
 * presumably a union of struct sockaddr and struct sockaddr_vm; its
 * declaration is only partly visible. Restore from version control before
 * building.
 */
142 int vsock_connect(unsigned int cid
, unsigned int port
, int type
)
146 struct sockaddr_vm svm
;
149 .svm_family
= AF_VSOCK
,
157 control_expectln("LISTENING");
159 fd
= socket(AF_VSOCK
, type
, 0);
165 timeout_begin(TIMEOUT
);
167 ret
= connect(fd
, &addr
.sa
, sizeof(addr
.svm
));
168 timeout_check("connect");
169 } while (ret
< 0 && errno
== EINTR
);
173 int old_errno
= errno
;
/*
 * Connect to <cid, port> using a SOCK_STREAM socket.
 * Returns the result of vsock_connect(), which is the connected descriptor.
 */
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
	return vsock_connect(cid, port, SOCK_STREAM);
}
/*
 * Connect to <cid, port> using a SOCK_SEQPACKET socket.
 * Returns the result of vsock_connect(), which is the connected descriptor.
 */
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
	return vsock_connect(cid, port, SOCK_SEQPACKET);
}
192 /* Listen on <cid, port> and return the file descriptor. */
/*
 * @type: socket type passed to socket(2).
 *
 * Creates an AF_VSOCK socket, binds it and calls listen(2) with a backlog
 * of 1.
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are not visible:
 *   - the remaining `addr` initialisers;
 *   - the declarations;
 *   - the socket()/bind()/listen() error paths;
 *   - the return.
 * `addr` is used as addr.sa and addr.svm, so it is presumably a
 * sockaddr/sockaddr_vm union. Restore from version control before building.
 */
193 static int vsock_listen(unsigned int cid
, unsigned int port
, int type
)
197 struct sockaddr_vm svm
;
200 .svm_family
= AF_VSOCK
,
207 fd
= socket(AF_VSOCK
, type
, 0);
213 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
218 if (listen(fd
, 1) < 0) {
226 /* Listen on <cid, port> and return the first incoming connection. The remote
227 * address is stored to clientaddrp. clientaddrp may be NULL.
 */
/*
 * @type: socket type passed to vsock_listen().
 *
 * After calling vsock_listen(), this sends "LISTENING" with
 * control_writeln(), which tells the peer it may connect; vsock_connect()
 * expects that line.
 *
 * accept(2) is retried while it fails with EINTR. The timeout is armed by
 * timeout_begin(TIMEOUT) and checked after every attempt.
 *
 * The peer address returned by accept(2) is validated. Its length must equal
 * sizeof(clientaddr.svm) and its family must be AF_VSOCK; otherwise a
 * diagnostic is printed to stderr.
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are not visible:
 *   - the rest of the `clientaddr` declaration (presumably a union of
 *     struct sockaddr and struct sockaddr_vm);
 *   - the error handling after vsock_listen() and accept();
 *   - the NULL check on clientaddrp that the header comment implies;
 *   - the return.
 * Restore from version control before building.
 */
229 int vsock_accept(unsigned int cid
, unsigned int port
,
230 struct sockaddr_vm
*clientaddrp
, int type
)
234 struct sockaddr_vm svm
;
236 socklen_t clientaddr_len
= sizeof(clientaddr
.svm
);
237 int fd
, client_fd
, old_errno
;
239 fd
= vsock_listen(cid
, port
, type
);
241 control_writeln("LISTENING");
243 timeout_begin(TIMEOUT
);
245 client_fd
= accept(fd
, &clientaddr
.sa
, &clientaddr_len
);
246 timeout_check("accept");
247 } while (client_fd
< 0 && errno
== EINTR
);
257 if (clientaddr_len
!= sizeof(clientaddr
.svm
)) {
258 fprintf(stderr
, "unexpected addrlen from accept(2), %zu\n",
259 (size_t)clientaddr_len
);
262 if (clientaddr
.sa
.sa_family
!= AF_VSOCK
) {
263 fprintf(stderr
, "expected AF_VSOCK from accept(2), got %d\n",
264 clientaddr
.sa
.sa_family
);
269 *clientaddrp
= clientaddr
.svm
;
/*
 * Accept one SOCK_STREAM connection on <cid, port> via vsock_accept().
 * @clientaddrp: receives the peer address; may be NULL.
 * Returns the result of vsock_accept(), which is the accepted descriptor.
 */
int vsock_stream_accept(unsigned int cid, unsigned int port,
			struct sockaddr_vm *clientaddrp)
{
	return vsock_accept(cid, port, clientaddrp, SOCK_STREAM);
}
/*
 * Listen on <cid, port> using a SOCK_STREAM socket via vsock_listen().
 * Returns the listening descriptor.
 */
int vsock_stream_listen(unsigned int cid, unsigned int port)
{
	return vsock_listen(cid, port, SOCK_STREAM);
}
/*
 * Accept one SOCK_SEQPACKET connection on <cid, port> via vsock_accept().
 * @clientaddrp: receives the peer address; may be NULL.
 * Returns the result of vsock_accept(), which is the accepted descriptor.
 */
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
			   struct sockaddr_vm *clientaddrp)
{
	return vsock_accept(cid, port, clientaddrp, SOCK_SEQPACKET);
}
290 /* Transmit bytes from a buffer and check the return value.
293 * <0 Negative errno (for testing errors)
295 * >0 Success (bytes successfully written)
 */
/*
 * send_buf - send @len bytes from @buf on @fd using send(2) with @flags.
 *
 * Partial writes are continued from buf + nwritten until @len bytes have
 * been written. The loop presumably stops early on EOF (ret == 0) or on an
 * error other than EINTR. The timeout is armed once with
 * timeout_begin(TIMEOUT) and checked after every send(2).
 *
 * When @expected_ret < 0, the send is expected to fail. The message
 * "bogus send(2) return value" is reported on an unexpected result, and
 * errno is compared with -@expected_ret.
 *
 * Otherwise nwritten must equal @expected_ret. On mismatch, the code reports
 * "unexpected EOF while sending bytes" or "bogus send(2) bytes written".
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are not visible:
 *   - the declaration of ret;
 *   - the `do {`;
 *   - the nwritten update (presumably `nwritten += ret`);
 *   - several if bodies and the exits.
 * Restore from version control before building.
 * NOTE(review): `buf + nwritten` does arithmetic on a const void pointer,
 * which is a GNU C extension. ISO C needs a cast to const char * first.
 */
297 void send_buf(int fd
, const void *buf
, size_t len
, int flags
,
298 ssize_t expected_ret
)
300 ssize_t nwritten
= 0;
303 timeout_begin(TIMEOUT
);
305 ret
= send(fd
, buf
+ nwritten
, len
- nwritten
, flags
);
306 timeout_check("send");
308 if (ret
== 0 || (ret
< 0 && errno
!= EINTR
))
312 } while (nwritten
< len
);
315 if (expected_ret
< 0) {
317 fprintf(stderr
, "bogus send(2) return value %zd (expected %zd)\n",
321 if (errno
!= -expected_ret
) {
333 if (nwritten
!= expected_ret
) {
335 fprintf(stderr
, "unexpected EOF while sending bytes\n");
337 fprintf(stderr
, "bogus send(2) bytes written %zd (expected %zd)\n",
338 nwritten
, expected_ret
);
343 /* Receive bytes in a buffer and check the return value.
346 * <0 Negative errno (for testing errors)
348 * >0 Success (bytes successfully read)
 */
/*
 * recv_buf - receive up to @len bytes into @buf from @fd using recv(2)
 * with @flags.
 *
 * Reads continue at buf + nread until @len bytes have arrived. The loop
 * presumably stops early on EOF (ret == 0) or on an error other than EINTR.
 * The timeout is armed with timeout_begin(TIMEOUT) and checked after every
 * recv(2).
 *
 * When @expected_ret < 0, the receive is expected to fail. The message
 * "bogus recv(2) return value" is reported on an unexpected result, and
 * errno is compared with -@expected_ret.
 *
 * Otherwise nread must equal @expected_ret. On mismatch, the code reports
 * "unexpected EOF while receiving bytes" or "bogus recv(2) bytes read".
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are not visible:
 *   - the declarations of nread/ret;
 *   - the `do {`;
 *   - the nread update;
 *   - several if bodies and the exits.
 * Restore from version control before building.
 * NOTE(review): `buf + nread` does arithmetic on a void pointer, which is a
 * GNU C extension. ISO C needs a cast to char * first.
 */
350 void recv_buf(int fd
, void *buf
, size_t len
, int flags
, ssize_t expected_ret
)
355 timeout_begin(TIMEOUT
);
357 ret
= recv(fd
, buf
+ nread
, len
- nread
, flags
);
358 timeout_check("recv");
360 if (ret
== 0 || (ret
< 0 && errno
!= EINTR
))
364 } while (nread
< len
);
367 if (expected_ret
< 0) {
369 fprintf(stderr
, "bogus recv(2) return value %zd (expected %zd)\n",
373 if (errno
!= -expected_ret
) {
385 if (nread
!= expected_ret
) {
387 fprintf(stderr
, "unexpected EOF while receiving bytes\n");
389 fprintf(stderr
, "bogus recv(2) bytes read %zd (expected %zd)\n",
390 nread
, expected_ret
);
/*
 * Transmit one byte ('A') on @fd and check the result.
 *
 * @fd:           socket to send on.
 * @expected_ret: passed through to send_buf():
 *                <0  the send is expected to fail with
 *                    errno == -expected_ret (for testing errors);
 *                >0  number of bytes expected to be written (1 on success).
 * @flags:        flags for send(2).
 *
 * send_buf() reports any mismatch.
 */
void send_byte(int fd, int expected_ret, int flags)
{
	const uint8_t byte = 'A';

	send_buf(fd, &byte, sizeof(byte), flags, expected_ret);
}
409 /* Receive one byte and check the return value.
412 * <0 Negative errno (for testing errors)
 */
/*
 * recv_byte - receive a single byte from @fd with recv(2) @flags through
 * recv_buf(), which also checks the result against @expected_ret.
 *
 * "unexpected byte read" is reported when the received byte is not the
 * expected one. The expected byte is presumably 'A', the byte send_byte()
 * transmits; the comparison itself is not visible.
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are missing:
 *   - the declaration of byte;
 *   - the condition guarding the "unexpected byte read" message;
 *   - the exit and the braces.
 * Restore from version control before building.
 */
416 void recv_byte(int fd
, int expected_ret
, int flags
)
420 recv_buf(fd
, &byte
, sizeof(byte
), flags
, expected_ret
);
423 fprintf(stderr
, "unexpected byte read %c\n", byte
);
428 /* Run test cases. The program terminates if a failure occurs. */
/*
 * @test_cases: array terminated by an entry whose .name is NULL.
 * @opts:       opts->mode selects the side to run. run_client is used when
 *              it is TEST_MODE_CLIENT; presumably run_server otherwise.
 *
 * For each test, the index and name are printed. Both sides then send
 * "SKIP" or "NEXT" over the control channel and read the peer's line. The
 * test is skipped when either side asked to skip. Otherwise the peer's line
 * is compared with "NEXT" via control_cmpln(..., true); the `true`
 * presumably makes a mismatch fatal (confirm).
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and lines are missing. The numbering gaps 449->451 and 466->468
 * presumably hide the `else` keywords of both if/else pairs. Without them,
 * "SKIP" and "NEXT" would both be sent, and run would always end up as
 * run_server. Also not visible: the declarations of i/line, the skip-branch
 * body, and the call to run(). Restore from version control before
 * building.
 */
429 void run_tests(const struct test_case
*test_cases
,
430 const struct test_opts
*opts
)
434 for (i
= 0; test_cases
[i
].name
; i
++) {
435 void (*run
)(const struct test_opts
*opts
);
438 printf("%d - %s...", i
, test_cases
[i
].name
);
441 /* Full barrier before executing the next test. This
442 * ensures that client and server are executing the
443 * same test case. In particular, it means whoever is
444 * faster will not see the peer still executing the
445 * last test. This is important because port numbers
446 * can be used by multiple test cases.
 */
448 if (test_cases
[i
].skip
)
449 control_writeln("SKIP");
451 control_writeln("NEXT");
453 line
= control_readln();
454 if (control_cmpln(line
, "SKIP", false) || test_cases
[i
].skip
) {
462 control_cmpln(line
, "NEXT", true);
465 if (opts
->mode
== TEST_MODE_CLIENT
)
466 run
= test_cases
[i
].run_client
;
468 run
= test_cases
[i
].run_server
;
/*
 * list_tests - print the header "ID\tTest name", then one "<index>\t<name>"
 * line per entry of @test_cases. Printing stops at the first entry whose
 * .name is NULL.
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are missing:
 *   - the declaration of i;
 *   - the braces;
 *   - whatever follows the loop (source lines 485-488).
 * Restore from version control before building.
 */
477 void list_tests(const struct test_case
*test_cases
)
481 printf("ID\tTest name\n");
483 for (i
= 0; test_cases
[i
].name
; i
++)
484 printf("%d\t%s\n", i
, test_cases
[i
].name
);
/*
 * skip_test - mark one entry of @test_cases as skipped by setting its .skip.
 *
 * @test_id_str is a decimal index into @test_cases, which holds
 * @test_cases_len entries. The ID is rejected with "malformed test ID" when
 * strtoul() fails or leaves trailing characters. It is rejected with
 * "test ID (...) larger than the max allowed (...)" when it is not below
 * @test_cases_len.
 *
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are missing:
 *   - the endptr declaration;
 *   - any `errno = 0` reset;
 *   - the exits and the braces.
 * Restore from version control before building.
 * NOTE(review): `test_cases_len - 1` is a size_t but is printed with %lu.
 * Use %zu, or cast to unsigned long, so the argument matches the format.
 */
489 void skip_test(struct test_case
*test_cases
, size_t test_cases_len
,
490 const char *test_id_str
)
492 unsigned long test_id
;
496 test_id
= strtoul(test_id_str
, &endptr
, 10);
497 if (errno
|| *endptr
!= '\0') {
498 fprintf(stderr
, "malformed test ID \"%s\"\n", test_id_str
);
502 if (test_id
>= test_cases_len
) {
503 fprintf(stderr
, "test ID (%lu) larger than the max allowed (%lu)\n",
504 test_id
, test_cases_len
- 1);
508 test_cases
[test_id
].skip
= true;
/*
 * hash_djb2 - djb2 hash of the bytes at @data.
 *
 * The hash starts at 5381. Each step computes hash * 33 + byte, written as
 * ((hash << 5) + hash) + byte. The arithmetic is on unsigned long, so it
 * wraps rather than overflowing.
 *
 * NOTE(review): this copy is damaged. The following are not visible: the
 * loop header (presumably over i in [0, @len)), the declaration of i, and
 * the return.
 * NOTE(review): the (unsigned char *) cast drops the const qualifier of
 * @data. Casting to (const unsigned char *) would keep it.
 */
511 unsigned long hash_djb2(const void *data
, size_t len
)
513 unsigned long hash
= 5381;
517 hash
= ((hash
<< 5) + hash
) + ((unsigned char *)data
)[i
];
/*
 * iovec_bytes - sum of iov[i].iov_len over the first @iovnum entries of
 * @iov.
 *
 * NOTE(review): this copy is damaged. The declarations of bytes/i, the
 * braces and the return (presumably `return bytes;`) are not visible.
 */
524 size_t iovec_bytes(const struct iovec
*iov
, size_t iovnum
)
529 for (bytes
= 0, i
= 0; i
< iovnum
; i
++)
530 bytes
+= iov
[i
].iov_len
;
/*
 * iovec_hash_djb2 - djb2 hash of the data described by the first @iovnum
 * entries of @iov.
 *
 * The buffers are concatenated into a malloc()ed temporary whose size is
 * iovec_bytes(iov, iovnum), and that temporary is hashed with hash_djb2().
 * The result therefore equals hash_djb2() over the concatenated bytes.
 *
 * NOTE(review): this copy is damaged. The following are not visible:
 *   - the declarations;
 *   - the malloc() failure handling (source lines 546-550);
 *   - everything after the hash (presumably free(tmp) and the return).
 * Restore from version control before building.
 * NOTE(review): if iov_bytes is 0, malloc(0) may legally return NULL.
 * Check that the missing failure test does not treat that as an error.
 * NOTE(review): djb2 processes bytes sequentially. The temporary copy could
 * be avoided with a hash_djb2() variant that takes the running hash as its
 * seed, folding each iov_base buffer in turn.
 */
535 unsigned long iovec_hash_djb2(const struct iovec
*iov
, size_t iovnum
)
543 iov_bytes
= iovec_bytes(iov
, iovnum
);
545 tmp
= malloc(iov_bytes
);
551 for (offs
= 0, i
= 0; i
< iovnum
; i
++) {
552 memcpy(tmp
+ offs
, iov
[i
].iov_base
, iov
[i
].iov_len
);
553 offs
+= iov
[i
].iov_len
;
556 hash
= hash_djb2(tmp
, iov_bytes
);
562 /* Allocates and returns new 'struct iovec *' according to the pattern
563 * in the 'test_iovec'. For each element in the 'test_iovec' it
564 * allocates new element in the resulting 'iovec'. 'iov_len'
565 * of the new element is copied from 'test_iovec'. 'iov_base' is
566 * allocated depending on the 'iov_base' of 'test_iovec':
568 * 'iov_base' == NULL -> valid buf: mmap('iov_len').
570 * 'iov_base' == MAP_FAILED -> invalid buf:
571 * mmap('iov_len'), then munmap('iov_len').
572 * 'iov_base' still contains result of
575 * 'iov_base' == number -> unaligned valid buf:
576 * mmap('iov_len') + number.
578 * 'iovnum' is number of elements in 'test_iovec'.
580 * Returns new 'iovec' or calls 'exit()' on error.
 */
/*
 * NOTE(review): this copy is damaged. Line numbers are fused into the code,
 * and the following are missing:
 *   - the declarations;
 *   - the malloc() failure check;
 *   - mmap()'s fd/offset arguments (source line 599);
 *   - the bodies of the error paths;
 *   - the statement after the MAP_FAILED test in the fill loop (presumably
 *     `continue;`);
 *   - the return.
 * Restore from version control before building.
 * NOTE(review): `iovec[i].iov_base += ...` does arithmetic on void *
 * (struct iovec's iov_base), which is a GNU C extension.
 * NOTE(review): malloc(sizeof(*iovec) * iovnum) with an int count has no
 * guard against overflow or a negative count.
 * calloc(iovnum, sizeof(*iovec)) would at least check the multiplication.
 */
582 struct iovec
*alloc_test_iovec(const struct iovec
*test_iovec
, int iovnum
)
587 iovec
= malloc(sizeof(*iovec
) * iovnum
);
/* Map every element, including "invalid" ones, which are unmapped below. */
593 for (i
= 0; i
< iovnum
; i
++) {
594 iovec
[i
].iov_len
= test_iovec
[i
].iov_len
;
596 iovec
[i
].iov_base
= mmap(NULL
, iovec
[i
].iov_len
,
597 PROT_READ
| PROT_WRITE
,
598 MAP_PRIVATE
| MAP_ANONYMOUS
| MAP_POPULATE
,
600 if (iovec
[i
].iov_base
== MAP_FAILED
) {
/* Apply the offset requested by the pattern (the "number" case above). */
605 if (test_iovec
[i
].iov_base
!= MAP_FAILED
)
606 iovec
[i
].iov_base
+= (uintptr_t)test_iovec
[i
].iov_base
;
609 /* Unmap "invalid" elements. */
610 for (i
= 0; i
< iovnum
; i
++) {
611 if (test_iovec
[i
].iov_base
== MAP_FAILED
) {
612 if (munmap(iovec
[i
].iov_base
, iovec
[i
].iov_len
)) {
/* Fill the buffers with rand() bytes; MAP_FAILED entries are presumably skipped. */
619 for (i
= 0; i
< iovnum
; i
++) {
622 if (test_iovec
[i
].iov_base
== MAP_FAILED
)
625 for (j
= 0; j
< iovec
[i
].iov_len
; j
++)
626 ((uint8_t *)iovec
[i
].iov_base
)[j
] = rand() & 0xff;
632 /* Frees 'iovec *', previously allocated by 'alloc_test_iovec()'.
633 * On error calls 'exit()'.
635 void free_test_iovec(const struct iovec
*test_iovec
,
636 struct iovec
*iovec
, int iovnum
)
640 for (i
= 0; i
< iovnum
; i
++) {
641 if (test_iovec
[i
].iov_base
!= MAP_FAILED
) {
642 if (test_iovec
[i
].iov_base
)
643 iovec
[i
].iov_base
-= (uintptr_t)test_iovec
[i
].iov_base
;
645 if (munmap(iovec
[i
].iov_base
, iovec
[i
].iov_len
)) {