treewide: remove redundant IS_ERR() before error code check
[linux/fpc-iii.git] / tools / testing / vsock / vsock_diag_test.c
blobcec6f5a738e1e4325c0f06809bf410ad7f3d4806
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * vsock_diag_test - vsock_diag.ko test suite
5 * Copyright (C) 2017 Red Hat, Inc.
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <sys/stat.h>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
25 #include "timeout.h"
26 #include "control.h"
27 #include "util.h"
29 /* Per-socket status */
30 struct vsock_stat {
31 struct list_head list;
32 struct vsock_diag_msg msg;
/*
 * Return a printable label for a socket type as reported in vdiag_type.
 * Types other than SOCK_STREAM and SOCK_DGRAM map to "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	if (type == SOCK_STREAM)
		return "STREAM";
	if (type == SOCK_DGRAM)
		return "DGRAM";

	return "INVALID TYPE";
}
/*
 * Return a printable label for a socket state as reported in vdiag_state
 * (compared against the TCP_* constants). Unlisted values map to
 * "INVALID STATE".
 */
static const char *sock_state_str(int state)
{
	static const struct {
		int state;
		const char *label;
	} labels[] = {
		{ TCP_CLOSE,       "UNCONNECTED" },
		{ TCP_SYN_SENT,    "CONNECTING" },
		{ TCP_ESTABLISHED, "CONNECTED" },
		{ TCP_CLOSING,     "DISCONNECTING" },
		{ TCP_LISTEN,      "LISTEN" },
	};
	size_t i;

	for (i = 0; i < sizeof(labels) / sizeof(labels[0]); i++) {
		if (labels[i].state == state)
			return labels[i].label;
	}

	return "INVALID STATE";
}
/*
 * Return a printable label for the vdiag_shutdown bits: bit 0 is rendered as
 * RCV_SHUTDOWN and bit 1 as SEND_SHUTDOWN. Zero and any value outside 0..3
 * are rendered as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const labels[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 0 || shutdown > 3)
		return "0";

	return labels[shutdown];
}
/*
 * Write "<cid>:<port>" to @fp. A wildcard CID (VMADDR_CID_ANY) or port
 * (VMADDR_PORT_ANY) is written as "*". No trailing newline is emitted.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputs("*", fp);
}
92 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
94 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
95 fprintf(fp, " ");
96 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
97 fprintf(fp, " %s %s %s %u\n",
98 sock_type_str(st->msg.vdiag_type),
99 sock_state_str(st->msg.vdiag_state),
100 sock_shutdown_str(st->msg.vdiag_shutdown),
101 st->msg.vdiag_ino);
104 static void print_vsock_stats(FILE *fp, struct list_head *head)
106 struct vsock_stat *st;
108 list_for_each_entry(st, head, list)
109 print_vsock_stat(fp, st);
112 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
114 struct vsock_stat *st;
115 struct stat stat;
117 if (fstat(fd, &stat) < 0) {
118 perror("fstat");
119 exit(EXIT_FAILURE);
122 list_for_each_entry(st, head, list)
123 if (st->msg.vdiag_ino == stat.st_ino)
124 return st;
126 fprintf(stderr, "cannot find fd %d\n", fd);
127 exit(EXIT_FAILURE);
/*
 * Fail the test unless @head is empty. On failure, the unexpected sockets
 * are dumped to stderr and the process exits with EXIT_FAILURE, the same
 * status used by check_num_sockets() and check_socket_state().
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	exit(EXIT_FAILURE);
}
139 static void check_num_sockets(struct list_head *head, int expected)
141 struct list_head *node;
142 int n = 0;
144 list_for_each(node, head)
145 n++;
147 if (n != expected) {
148 fprintf(stderr, "expected %d sockets, found %d\n",
149 expected, n);
150 print_vsock_stats(stderr, head);
151 exit(EXIT_FAILURE);
155 static void check_socket_state(struct vsock_stat *st, __u8 state)
157 if (st->msg.vdiag_state != state) {
158 fprintf(stderr, "expected socket state %#x, got %#x\n",
159 state, st->msg.vdiag_state);
160 exit(EXIT_FAILURE);
/*
 * Send a SOCK_DIAG_BY_FAMILY dump request (NLM_F_REQUEST | NLM_F_DUMP) for
 * the AF_VSOCK family over the netlink socket @fd. sendmsg() is retried on
 * EINTR; any other failure exits the process.
 */
static void send_req(int fd)
{
	struct sockaddr_nl dst = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			/* all bits set; presumably a state filter that matches every state */
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec iov = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr msg = {
		.msg_name = &dst,
		.msg_namelen = sizeof(dst),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t sent;

	do {
		sent = sendmsg(fd, &msg, 0);
	} while (sent < 0 && errno == EINTR);

	if (sent < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
/*
 * Receive one datagram from the netlink socket @fd into @buf (at most @len
 * bytes) and return its length. recvmsg() is retried on EINTR; any other
 * failure exits the process, so the return value is never negative.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl src = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr msg = {
		.msg_name = &src,
		.msg_namelen = sizeof(src),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};

	for (;;) {
		ssize_t received = recvmsg(fd, &msg, 0);

		if (received >= 0)
			return received;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
236 static void add_vsock_stat(struct list_head *sockets,
237 const struct vsock_diag_msg *resp)
239 struct vsock_stat *st;
241 st = malloc(sizeof(*st));
242 if (!st) {
243 perror("malloc");
244 exit(EXIT_FAILURE);
247 st->msg = *resp;
248 list_add_tail(&st->list, sockets);
/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends the dump request built by
 * send_req(), and appends one entry per SOCK_DIAG_BY_FAMILY reply to
 * @sockets. Reading stops at NLMSG_DONE or when recv_resp() returns 0.
 * A netlink error, an unexpected message type, or a truncated
 * header/payload exits the process.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* NOTE(review): presumably long[] so the buffer is aligned for nlmsghdr */
	long buf[8192 / sizeof(long)];
	int finished = 0;
	int nl_fd;

	nl_fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (nl_fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(nl_fd);

	while (!finished) {
		const struct nlmsghdr *nlh = (const struct nlmsghdr *)buf;
		ssize_t len = recv_resp(nl_fd, buf, sizeof(buf));

		if (len == 0)
			break;

		/* recv_resp() never returns a negative length, so the cast is exact */
		if (len < (ssize_t)sizeof(*nlh)) {
			fprintf(stderr, "short read of %zd bytes\n", len);
			exit(EXIT_FAILURE);
		}

		for (; NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
			if (nlh->nlmsg_type == NLMSG_DONE) {
				finished = 1;
				break;
			}

			if (nlh->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(nlh);

				if (nlh->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) {
					fprintf(stderr, "NLMSG_ERROR\n");
				} else {
					/* the kernel reports a negated errno value */
					errno = -err->error;
					perror("NLMSG_ERROR");
				}
				exit(EXIT_FAILURE);
			}

			if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					nlh->nlmsg_type);
				exit(EXIT_FAILURE);
			}

			if (nlh->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(nlh));
		}
	}

	close(nl_fd);
}
319 static void free_sock_stat(struct list_head *sockets)
321 struct vsock_stat *st;
322 struct vsock_stat *next;
324 list_for_each_entry_safe(st, next, sockets, list)
325 free(st);
328 static void test_no_sockets(const struct test_opts *opts)
330 LIST_HEAD(sockets);
332 read_vsock_stat(&sockets);
334 check_no_sockets(&sockets);
336 free_sock_stat(&sockets);
339 static void test_listen_socket_server(const struct test_opts *opts)
341 union {
342 struct sockaddr sa;
343 struct sockaddr_vm svm;
344 } addr = {
345 .svm = {
346 .svm_family = AF_VSOCK,
347 .svm_port = 1234,
348 .svm_cid = VMADDR_CID_ANY,
351 LIST_HEAD(sockets);
352 struct vsock_stat *st;
353 int fd;
355 fd = socket(AF_VSOCK, SOCK_STREAM, 0);
357 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358 perror("bind");
359 exit(EXIT_FAILURE);
362 if (listen(fd, 1) < 0) {
363 perror("listen");
364 exit(EXIT_FAILURE);
367 read_vsock_stat(&sockets);
369 check_num_sockets(&sockets, 1);
370 st = find_vsock_stat(&sockets, fd);
371 check_socket_state(st, TCP_LISTEN);
373 close(fd);
374 free_sock_stat(&sockets);
377 static void test_connect_client(const struct test_opts *opts)
379 int fd;
380 LIST_HEAD(sockets);
381 struct vsock_stat *st;
383 fd = vsock_stream_connect(opts->peer_cid, 1234);
384 if (fd < 0) {
385 perror("connect");
386 exit(EXIT_FAILURE);
389 read_vsock_stat(&sockets);
391 check_num_sockets(&sockets, 1);
392 st = find_vsock_stat(&sockets, fd);
393 check_socket_state(st, TCP_ESTABLISHED);
395 control_expectln("DONE");
396 control_writeln("DONE");
398 close(fd);
399 free_sock_stat(&sockets);
402 static void test_connect_server(const struct test_opts *opts)
404 struct vsock_stat *st;
405 LIST_HEAD(sockets);
406 int client_fd;
408 client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
409 if (client_fd < 0) {
410 perror("accept");
411 exit(EXIT_FAILURE);
414 read_vsock_stat(&sockets);
416 check_num_sockets(&sockets, 1);
417 st = find_vsock_stat(&sockets, client_fd);
418 check_socket_state(st, TCP_ESTABLISHED);
420 control_writeln("DONE");
421 control_expectln("DONE");
423 close(client_fd);
424 free_sock_stat(&sockets);
427 static struct test_case test_cases[] = {
429 .name = "No sockets",
430 .run_server = test_no_sockets,
433 .name = "Listen socket",
434 .run_server = test_listen_socket_server,
437 .name = "Connect",
438 .run_client = test_connect_client,
439 .run_server = test_connect_server,
/* No short options are accepted; every option is long-form only. */
static const char optstring[] = "";

/*
 * Long options understood by main(). The .val of each entry is the case
 * label main() switches on. getopt_long() requires the trailing all-zero entry.
 */
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "list",         .has_arg = no_argument,       .val = 'l' },
	{ .name = "skip",         .has_arg = required_argument, .val = 's' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{ 0 },	/* terminator */
};
/* Print the command-line help to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	/* The text contains no '%', so fputs() writes exactly what fprintf() would. */
	fputs("Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
	      "\n"
	      " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
	      " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
	      "\n"
	      "Run vsock_diag.ko tests. Must be launched in both\n"
	      "guest and host. One side must use --mode=client and\n"
	      "the other side must use --mode=server.\n"
	      "\n"
	      "A TCP control socket connection is used to coordinate tests\n"
	      "between the client and the server. The server requires a\n"
	      "listen address and the client requires an address to\n"
	      "connect to.\n"
	      "\n"
	      "The CID of the other side must be given with --peer-cid=<cid>.\n"
	      "\n"
	      "Options:\n"
	      " --help This help message\n"
	      " --control-host <host> Server IP address to connect to\n"
	      " --control-port <port> Server port to listen on/connect to\n"
	      " --mode client|server Server or client mode\n"
	      " --peer-cid <cid> CID of the other side\n"
	      " --list List of tests that will be executed\n"
	      " --skip <test_id> Test ID to skip;\n"
	      " use multiple --skip options to skip more tests\n",
	      stderr);
	exit(EXIT_FAILURE);
}
515 int main(int argc, char **argv)
517 const char *control_host = NULL;
518 const char *control_port = NULL;
519 struct test_opts opts = {
520 .mode = TEST_MODE_UNSET,
521 .peer_cid = VMADDR_CID_ANY,
524 init_signals();
526 for (;;) {
527 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
529 if (opt == -1)
530 break;
532 switch (opt) {
533 case 'H':
534 control_host = optarg;
535 break;
536 case 'm':
537 if (strcmp(optarg, "client") == 0)
538 opts.mode = TEST_MODE_CLIENT;
539 else if (strcmp(optarg, "server") == 0)
540 opts.mode = TEST_MODE_SERVER;
541 else {
542 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
543 return EXIT_FAILURE;
545 break;
546 case 'p':
547 opts.peer_cid = parse_cid(optarg);
548 break;
549 case 'P':
550 control_port = optarg;
551 break;
552 case 'l':
553 list_tests(test_cases);
554 break;
555 case 's':
556 skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
557 optarg);
558 break;
559 case '?':
560 default:
561 usage();
565 if (!control_port)
566 usage();
567 if (opts.mode == TEST_MODE_UNSET)
568 usage();
569 if (opts.peer_cid == VMADDR_CID_ANY)
570 usage();
572 if (!control_host) {
573 if (opts.mode != TEST_MODE_SERVER)
574 usage();
575 control_host = "0.0.0.0";
578 control_init(control_host, control_port,
579 opts.mode == TEST_MODE_SERVER);
581 run_tests(test_cases, &opts);
583 control_cleanup();
584 return EXIT_SUCCESS;