1 // SPDX-License-Identifier: GPL-2.0-only
5 * Copyright (C) 2017 Red Hat, Inc.
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
17 #include <sys/epoll.h>
23 /* Install signal handlers */
24 void init_signals(void)
26 struct sigaction act
= {
27 .sa_handler
= sigalrm
,
30 sigaction(SIGALRM
, &act
, NULL
);
31 signal(SIGPIPE
, SIG_IGN
);
34 /* Parse a CID in string representation */
35 unsigned int parse_cid(const char *str
)
41 n
= strtoul(str
, &endptr
, 10);
42 if (errno
|| *endptr
!= '\0') {
43 fprintf(stderr
, "malformed CID \"%s\"\n", str
);
49 /* Wait for the remote to close the connection */
50 void vsock_wait_remote_close(int fd
)
52 struct epoll_event ev
;
55 epollfd
= epoll_create1(0);
57 perror("epoll_create1");
61 ev
.events
= EPOLLRDHUP
| EPOLLHUP
;
63 if (epoll_ctl(epollfd
, EPOLL_CTL_ADD
, fd
, &ev
) == -1) {
68 nfds
= epoll_wait(epollfd
, &ev
, 1, TIMEOUT
* 1000);
75 fprintf(stderr
, "epoll_wait timed out\n");
80 assert(ev
.events
& (EPOLLRDHUP
| EPOLLHUP
));
81 assert(ev
.data
.fd
== fd
);
86 /* Connect to <cid, port> and return the file descriptor. */
87 int vsock_stream_connect(unsigned int cid
, unsigned int port
)
91 struct sockaddr_vm svm
;
94 .svm_family
= AF_VSOCK
,
102 control_expectln("LISTENING");
104 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
106 timeout_begin(TIMEOUT
);
108 ret
= connect(fd
, &addr
.sa
, sizeof(addr
.svm
));
109 timeout_check("connect");
110 } while (ret
< 0 && errno
== EINTR
);
114 int old_errno
= errno
;
123 /* Listen on <cid, port> and return the first incoming connection. The remote
124 * address is stored to clientaddrp. clientaddrp may be NULL.
126 int vsock_stream_accept(unsigned int cid
, unsigned int port
,
127 struct sockaddr_vm
*clientaddrp
)
131 struct sockaddr_vm svm
;
134 .svm_family
= AF_VSOCK
,
141 struct sockaddr_vm svm
;
143 socklen_t clientaddr_len
= sizeof(clientaddr
.svm
);
148 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
150 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
155 if (listen(fd
, 1) < 0) {
160 control_writeln("LISTENING");
162 timeout_begin(TIMEOUT
);
164 client_fd
= accept(fd
, &clientaddr
.sa
, &clientaddr_len
);
165 timeout_check("accept");
166 } while (client_fd
< 0 && errno
== EINTR
);
176 if (clientaddr_len
!= sizeof(clientaddr
.svm
)) {
177 fprintf(stderr
, "unexpected addrlen from accept(2), %zu\n",
178 (size_t)clientaddr_len
);
181 if (clientaddr
.sa
.sa_family
!= AF_VSOCK
) {
182 fprintf(stderr
, "expected AF_VSOCK from accept(2), got %d\n",
183 clientaddr
.sa
.sa_family
);
188 *clientaddrp
= clientaddr
.svm
;
192 /* Transmit one byte and check the return value.
195 * <0 Negative errno (for testing errors)
199 void send_byte(int fd
, int expected_ret
, int flags
)
201 const uint8_t byte
= 'A';
204 timeout_begin(TIMEOUT
);
206 nwritten
= send(fd
, &byte
, sizeof(byte
), flags
);
207 timeout_check("write");
208 } while (nwritten
< 0 && errno
== EINTR
);
211 if (expected_ret
< 0) {
212 if (nwritten
!= -1) {
213 fprintf(stderr
, "bogus send(2) return value %zd\n",
217 if (errno
!= -expected_ret
) {
229 if (expected_ret
== 0)
232 fprintf(stderr
, "unexpected EOF while sending byte\n");
235 if (nwritten
!= sizeof(byte
)) {
236 fprintf(stderr
, "bogus send(2) return value %zd\n", nwritten
);
241 /* Receive one byte and check the return value.
244 * <0 Negative errno (for testing errors)
248 void recv_byte(int fd
, int expected_ret
, int flags
)
253 timeout_begin(TIMEOUT
);
255 nread
= recv(fd
, &byte
, sizeof(byte
), flags
);
256 timeout_check("read");
257 } while (nread
< 0 && errno
== EINTR
);
260 if (expected_ret
< 0) {
262 fprintf(stderr
, "bogus recv(2) return value %zd\n",
266 if (errno
!= -expected_ret
) {
278 if (expected_ret
== 0)
281 fprintf(stderr
, "unexpected EOF while receiving byte\n");
284 if (nread
!= sizeof(byte
)) {
285 fprintf(stderr
, "bogus recv(2) return value %zd\n", nread
);
289 fprintf(stderr
, "unexpected byte read %c\n", byte
);
294 /* Run test cases. The program terminates if a failure occurs. */
295 void run_tests(const struct test_case
*test_cases
,
296 const struct test_opts
*opts
)
300 for (i
= 0; test_cases
[i
].name
; i
++) {
301 void (*run
)(const struct test_opts
*opts
);
304 printf("%d - %s...", i
, test_cases
[i
].name
);
307 /* Full barrier before executing the next test. This
308 * ensures that client and server are executing the
309 * same test case. In particular, it means whoever is
310 * faster will not see the peer still executing the
311 * last test. This is important because port numbers
312 * can be used by multiple test cases.
314 if (test_cases
[i
].skip
)
315 control_writeln("SKIP");
317 control_writeln("NEXT");
319 line
= control_readln();
320 if (control_cmpln(line
, "SKIP", false) || test_cases
[i
].skip
) {
328 control_cmpln(line
, "NEXT", true);
331 if (opts
->mode
== TEST_MODE_CLIENT
)
332 run
= test_cases
[i
].run_client
;
334 run
= test_cases
[i
].run_server
;
343 void list_tests(const struct test_case
*test_cases
)
347 printf("ID\tTest name\n");
349 for (i
= 0; test_cases
[i
].name
; i
++)
350 printf("%d\t%s\n", i
, test_cases
[i
].name
);
355 void skip_test(struct test_case
*test_cases
, size_t test_cases_len
,
356 const char *test_id_str
)
358 unsigned long test_id
;
362 test_id
= strtoul(test_id_str
, &endptr
, 10);
363 if (errno
|| *endptr
!= '\0') {
364 fprintf(stderr
, "malformed test ID \"%s\"\n", test_id_str
);
368 if (test_id
>= test_cases_len
) {
369 fprintf(stderr
, "test ID (%lu) larger than the max allowed (%lu)\n",
370 test_id
, test_cases_len
- 1);
374 test_cases
[test_id
].skip
= true;