1 // SPDX-License-Identifier: GPL-2.0-only
3 * vsock_diag_test - vsock_diag.ko test suite
5 * Copyright (C) 2017 Red Hat, Inc.
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
/*
 * Members of struct vsock_stat (the struct's opening and closing lines are
 * not part of this listing): `list` links entries into a list_head chain and
 * `msg` holds the vsock_diag_msg describing one socket.
 */
29 /* Per-socket status */
31 struct list_head list
;
32 struct vsock_diag_msg msg
;
/*
 * Return a printable name for socket type `type`.
 * Only the fallback string is visible in this listing; the recognised
 * types are handled on lines that are not shown here.
 */
35 static const char *sock_type_str(int type
)
43 return "INVALID TYPE";
/*
 * Return a printable name for socket state `state`.
 * The visible results are "DISCONNECTING" and the fallback "INVALID STATE";
 * the remaining states are handled on lines not shown in this listing.
 */
47 static const char *sock_state_str(int state
)
57 return "DISCONNECTING";
61 return "INVALID STATE";
/*
 * Return a printable description of the shutdown value `shutdown`:
 * receive-only, send-only, or both joined with " | ".
 * The conditions selecting each string, and whatever is returned when none
 * of them applies, are on lines not shown in this listing.
 */
65 static const char *sock_shutdown_str(int shutdown
)
69 return "RCV_SHUTDOWN";
71 return "SEND_SHUTDOWN";
73 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
/*
 * Write a vsock address (cid followed by port) to `fp`.
 * `cid` is tested against VMADDR_CID_ANY and `port` against VMADDR_PORT_ANY;
 * the numeric forms are printed as "%u:" and "%u" respectively.
 * NOTE(review): what is printed for the two wildcard values is on lines not
 * shown in this listing -- presumably a placeholder; confirm in the full file.
 */
79 static void print_vsock_addr(FILE *fp
, unsigned int cid
, unsigned int port
)
81 if (cid
== VMADDR_CID_ANY
)
84 fprintf(fp
, "%u:", cid
);
86 if (port
== VMADDR_PORT_ANY
)
89 fprintf(fp
, "%u", port
);
/*
 * Print one socket entry to `fp`: the source address, the destination
 * address, then the type, state and shutdown strings followed by one
 * unsigned value and a newline.
 * The argument supplying that final "%u" is not visible in this listing.
 */
92 static void print_vsock_stat(FILE *fp
, struct vsock_stat
*st
)
94 print_vsock_addr(fp
, st
->msg
.vdiag_src_cid
, st
->msg
.vdiag_src_port
);
96 print_vsock_addr(fp
, st
->msg
.vdiag_dst_cid
, st
->msg
.vdiag_dst_port
);
97 fprintf(fp
, " %s %s %s %u\n",
98 sock_type_str(st
->msg
.vdiag_type
),
99 sock_state_str(st
->msg
.vdiag_state
),
100 sock_shutdown_str(st
->msg
.vdiag_shutdown
),
/*
 * Print every vsock_stat entry linked on `head` to `fp`, one entry per
 * print_vsock_stat() call, in list order.
 */
104 static void print_vsock_stats(FILE *fp
, struct list_head
*head
)
106 struct vsock_stat
*st
;
108 list_for_each_entry(st
, head
, list
)
109 print_vsock_stat(fp
, st
);
/*
 * Look up the list entry that corresponds to file descriptor `fd`.
 * The descriptor is fstat()ed and each entry's msg.vdiag_ino is compared
 * with the resulting st_ino. "cannot find fd %d" is written to stderr on
 * the path that follows the loop.
 * NOTE(review): the fstat() error branch, the statement executed on a match
 * and what happens after the "cannot find" message are on lines not shown
 * in this listing -- presumably the match is returned; confirm.
 */
112 static struct vsock_stat
*find_vsock_stat(struct list_head
*head
, int fd
)
114 struct vsock_stat
*st
;
117 if (fstat(fd
, &stat
) < 0) {
122 list_for_each_entry(st
, head
, list
)
123 if (st
->msg
.vdiag_ino
== stat
.st_ino
)
126 fprintf(stderr
, "cannot find fd %d\n", fd
);
/*
 * Verify that the list `head` is empty. When it is not, an
 * "expected no sockets" message and a dump of the list go to stderr.
 * Whatever follows the dump (e.g. terminating the test) is not visible in
 * this listing.
 */
130 static void check_no_sockets(struct list_head
*head
)
132 if (!list_empty(head
)) {
133 fprintf(stderr
, "expected no sockets\n");
134 print_vsock_stats(stderr
, head
);
/*
 * Verify the number of entries on `head` against `expected`.
 * The list is walked with list_for_each(); a mismatch is reported on stderr
 * as "expected %d sockets, found %d" followed by a dump of the list.
 * NOTE(review): the counter, the comparison and the fprintf arguments are on
 * lines not shown in this listing.
 */
139 static void check_num_sockets(struct list_head
*head
, int expected
)
141 struct list_head
*node
;
144 list_for_each(node
, head
)
148 fprintf(stderr
, "expected %d sockets, found %d\n",
150 print_vsock_stats(stderr
, head
);
/*
 * Verify that the entry `st` reports socket state `state`.
 * On a mismatch the expected and actual values are written to stderr in
 * "%#x" form; what happens after the message is not visible in this listing.
 */
155 static void check_socket_state(struct vsock_stat
*st
, __u8 state
)
157 if (st
->msg
.vdiag_state
!= state
) {
158 fprintf(stderr
, "expected socket state %#x, got %#x\n",
159 state
, st
->msg
.vdiag_state
);
/*
 * Send a vsock sock_diag dump request on socket `fd`.
 * Visible pieces of the request: destination address family AF_NETLINK;
 * netlink header with nlmsg_len = sizeof(req), type SOCK_DIAG_BY_FAMILY and
 * flags NLM_F_REQUEST | NLM_F_DUMP; a vsock_diag_req with sdiag_family
 * AF_VSOCK and vdiag_states set to all ones (no state filtered out).
 * The message is handed to sendmsg(); the branch taken when sendmsg()
 * returns a negative value is on lines not shown in this listing.
 */
164 static void send_req(int fd
)
166 struct sockaddr_nl nladdr
= {
167 .nl_family
= AF_NETLINK
,
171 struct vsock_diag_req vreq
;
174 .nlmsg_len
= sizeof(req
),
175 .nlmsg_type
= SOCK_DIAG_BY_FAMILY
,
176 .nlmsg_flags
= NLM_F_REQUEST
| NLM_F_DUMP
,
179 .sdiag_family
= AF_VSOCK
,
180 .vdiag_states
= ~(__u32
)0,
185 .iov_len
= sizeof(req
),
187 struct msghdr msg
= {
189 .msg_namelen
= sizeof(nladdr
),
195 if (sendmsg(fd
, &msg
, 0) < 0) {
/*
 * Receive a netlink response on `fd` with recvmsg(), repeating the call for
 * as long as it fails with errno == EINTR.
 * NOTE(review): how `buf` and `len` are attached to the msghdr, and the value
 * returned to the caller, are on lines not shown in this listing.
 */
207 static ssize_t
recv_resp(int fd
, void *buf
, size_t len
)
209 struct sockaddr_nl nladdr
= {
210 .nl_family
= AF_NETLINK
,
216 struct msghdr msg
= {
218 .msg_namelen
= sizeof(nladdr
),
225 ret
= recvmsg(fd
, &msg
, 0);
226 } while (ret
< 0 && errno
== EINTR
);
/*
 * Allocate a new vsock_stat with malloc() and append it to the tail of the
 * `sockets` list.
 * NOTE(review): the allocation-failure check and the step that stores `resp`
 * in the new entry are on lines not shown in this listing -- presumably
 * *resp is copied into st->msg; confirm.
 */
236 static void add_vsock_stat(struct list_head
*sockets
,
237 const struct vsock_diag_msg
*resp
)
239 struct vsock_stat
*st
;
241 st
= malloc(sizeof(*st
));
248 list_add_tail(&st
->list
, sockets
);
/*
 * Collect one vsock_stat entry per socket reported by the kernel and append
 * each to `sockets`.
 *
 * A NETLINK_SOCK_DIAG raw socket is opened and responses are read into an
 * 8192-byte buffer declared as an array of long. Each response is walked
 * with NLMSG_OK()/NLMSG_NEXT():
 *   - NLMSG_DONE and NLMSG_ERROR are handled specially (an error message
 *     shorter than a struct nlmsgerr is reported as a bare "NLMSG_ERROR",
 *     otherwise perror() is used);
 *   - any type other than SOCK_DIAG_BY_FAMILY is reported as unexpected;
 *   - a payload shorter than struct vsock_diag_msg is reported as
 *     "short vsock_diag_msg";
 *   - everything else is passed to add_vsock_stat().
 *
 * NOTE(review): `ret < sizeof(*h)` compares against an unsigned value; confirm
 * that negative returns from recv_resp() are handled before this test on the
 * lines not shown in this listing.
 */
252 * Read vsock stats into a list.
254 static void read_vsock_stat(struct list_head
*sockets
)
256 long buf
[8192 / sizeof(long)];
259 fd
= socket(AF_NETLINK
, SOCK_RAW
, NETLINK_SOCK_DIAG
);
268 const struct nlmsghdr
*h
;
271 ret
= recv_resp(fd
, buf
, sizeof(buf
));
274 if (ret
< sizeof(*h
)) {
275 fprintf(stderr
, "short read of %zd bytes\n", ret
);
279 h
= (struct nlmsghdr
*)buf
;
281 while (NLMSG_OK(h
, ret
)) {
282 if (h
->nlmsg_type
== NLMSG_DONE
)
285 if (h
->nlmsg_type
== NLMSG_ERROR
) {
286 const struct nlmsgerr
*err
= NLMSG_DATA(h
);
288 if (h
->nlmsg_len
< NLMSG_LENGTH(sizeof(*err
)))
289 fprintf(stderr
, "NLMSG_ERROR\n");
292 perror("NLMSG_ERROR");
298 if (h
->nlmsg_type
!= SOCK_DIAG_BY_FAMILY
) {
299 fprintf(stderr
, "unexpected nlmsg_type %#x\n",
304 NLMSG_LENGTH(sizeof(struct vsock_diag_msg
))) {
305 fprintf(stderr
, "short vsock_diag_msg\n");
309 add_vsock_stat(sockets
, NLMSG_DATA(h
));
311 h
= NLMSG_NEXT(h
, ret
);
/*
 * Walk every entry on `sockets` with list_for_each_entry_safe(), the
 * iterator variant that keeps a `next` cursor.
 * NOTE(review): the loop body is on lines not shown in this listing --
 * presumably each entry is released; confirm.
 */
319 static void free_sock_stat(struct list_head
*sockets
)
321 struct vsock_stat
*st
;
322 struct vsock_stat
*next
;
324 list_for_each_entry_safe(st
, next
, sockets
, list
)
/*
 * Test: read the current vsock socket list, check that it is empty, then
 * hand the list to free_sock_stat(). `opts` is not referenced on the
 * visible lines.
 */
328 static void test_no_sockets(const struct test_opts
*opts
)
332 read_vsock_stat(&sockets
);
334 check_no_sockets(&sockets
);
336 free_sock_stat(&sockets
);
/*
 * Test: a listening vsock socket must be reported in state TCP_LISTEN.
 * An AF_VSOCK stream socket is bound to an address with svm_cid
 * VMADDR_CID_ANY (the port is not visible in this listing) and put into
 * listen mode with a backlog of 1. The socket list is then read and must
 * contain exactly one entry, the one matching `fd`, in TCP_LISTEN.
 */
339 static void test_listen_socket_server(const struct test_opts
*opts
)
343 struct sockaddr_vm svm
;
346 .svm_family
= AF_VSOCK
,
348 .svm_cid
= VMADDR_CID_ANY
,
352 struct vsock_stat
*st
;
355 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
357 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
362 if (listen(fd
, 1) < 0) {
367 read_vsock_stat(&sockets
);
369 check_num_sockets(&sockets
, 1);
370 st
= find_vsock_stat(&sockets
, fd
);
371 check_socket_state(st
, TCP_LISTEN
);
374 free_sock_stat(&sockets
);
/*
 * Test, client side: connect a stream socket to port 1234 on
 * opts->peer_cid, then check that the socket list holds exactly one entry,
 * the one for `fd`, in state TCP_ESTABLISHED.
 * Afterwards the client waits for "DONE" on the control channel and then
 * sends "DONE" itself; the server side uses the opposite order.
 */
377 static void test_connect_client(const struct test_opts
*opts
)
381 struct vsock_stat
*st
;
383 fd
= vsock_stream_connect(opts
->peer_cid
, 1234);
389 read_vsock_stat(&sockets
);
391 check_num_sockets(&sockets
, 1);
392 st
= find_vsock_stat(&sockets
, fd
);
393 check_socket_state(st
, TCP_ESTABLISHED
);
395 control_expectln("DONE");
396 control_writeln("DONE");
399 free_sock_stat(&sockets
);
/*
 * Test, server side: accept a stream connection on port 1234 from any CID,
 * then check that the socket list holds exactly one entry, the one for
 * `client_fd`, in state TCP_ESTABLISHED.
 * Afterwards the server sends "DONE" on the control channel and then waits
 * for the peer's "DONE"; the client side uses the opposite order.
 */
402 static void test_connect_server(const struct test_opts
*opts
)
404 struct vsock_stat
*st
;
408 client_fd
= vsock_stream_accept(VMADDR_CID_ANY
, 1234, NULL
);
414 read_vsock_stat(&sockets
);
416 check_num_sockets(&sockets
, 1);
417 st
= find_vsock_stat(&sockets
, client_fd
);
418 check_socket_state(st
, TCP_ESTABLISHED
);
420 control_writeln("DONE");
421 control_expectln("DONE");
424 free_sock_stat(&sockets
);
/*
 * Table of test cases handed to list_tests(), skip_test() and run_tests().
 * Visible entries: "No sockets" and "Listen socket" provide only a
 * server-side callback; the following entry (its name is not visible in
 * this listing) pairs test_connect_client with test_connect_server.
 */
427 static struct test_case test_cases
[] = {
429 .name
= "No sockets",
430 .run_server
= test_no_sockets
,
433 .name
= "Listen socket",
434 .run_server
= test_listen_socket_server
,
438 .run_client
= test_connect_client
,
439 .run_server
= test_connect_server
,
/* Short-option string passed to getopt_long(); empty, so only long options are accepted. */
444 static const char optstring
[] = "";
/*
 * Long-option table passed to getopt_long().
 * "control-host" and "control-port" both take a required argument. The
 * names of the remaining entries are not visible in this listing; of those,
 * three take a required argument and two take none.
 */
445 static const struct option longopts
[] = {
447 .name
= "control-host",
448 .has_arg
= required_argument
,
452 .name
= "control-port",
453 .has_arg
= required_argument
,
458 .has_arg
= required_argument
,
463 .has_arg
= required_argument
,
468 .has_arg
= no_argument
,
473 .has_arg
= required_argument
,
478 .has_arg
= no_argument
,
/*
 * Write the command-line help text (synopsis, server/client examples,
 * description of the control connection and the option list) to stderr.
 * Anything done after printing is on lines not shown in this listing.
 */
484 static void usage(void)
486 fprintf(stderr
, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
488 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
489 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
491 "Run vsock_diag.ko tests. Must be launched in both\n"
492 "guest and host. One side must use --mode=client and\n"
493 "the other side must use --mode=server.\n"
495 "A TCP control socket connection is used to coordinate tests\n"
496 "between the client and the server. The server requires a\n"
497 "listen address and the client requires an address to\n"
500 "The CID of the other side must be given with --peer-cid=<cid>.\n"
503 " --help This help message\n"
504 " --control-host <host> Server IP address to connect to\n"
505 " --control-port <port> Server port to listen on/connect to\n"
506 " --mode client|server Server or client mode\n"
507 " --peer-cid <cid> CID of the other side\n"
508 " --list List of tests that will be executed\n"
509 " --skip <test_id> Test ID to skip;\n"
510 " use multiple --skip options to skip more tests\n"
/*
 * Program entry point.
 *
 * Options are parsed with getopt_long(): the control host and port are taken
 * from optarg, the mode argument must be "client" or "server", the peer CID
 * goes through parse_cid(), and there are paths that call list_tests() and
 * skip_test(). skip_test() is given ARRAY_SIZE(test_cases) - 1 --
 * presumably to exclude a terminating table entry; confirm.
 *
 * After parsing, the mode and peer CID are checked against their initial
 * values (TEST_MODE_UNSET, VMADDR_CID_ANY), control_host can be defaulted to
 * "0.0.0.0" on a path that also tests for server mode, the control channel
 * is initialised (third argument: whether this side is the server) and the
 * test table is run.
 *
 * NOTE(review): the switch labels, the statements executed when a check
 * fails and the exact condition guarding the "0.0.0.0" default are on lines
 * not shown in this listing.
 */
515 int main(int argc
, char **argv
)
517 const char *control_host
= NULL
;
518 const char *control_port
= NULL
;
519 struct test_opts opts
= {
520 .mode
= TEST_MODE_UNSET
,
521 .peer_cid
= VMADDR_CID_ANY
,
527 int opt
= getopt_long(argc
, argv
, optstring
, longopts
, NULL
);
534 control_host
= optarg
;
537 if (strcmp(optarg
, "client") == 0)
538 opts
.mode
= TEST_MODE_CLIENT
;
539 else if (strcmp(optarg
, "server") == 0)
540 opts
.mode
= TEST_MODE_SERVER
;
542 fprintf(stderr
, "--mode must be \"client\" or \"server\"\n");
547 opts
.peer_cid
= parse_cid(optarg
);
550 control_port
= optarg
;
553 list_tests(test_cases
);
556 skip_test(test_cases
, ARRAY_SIZE(test_cases
) - 1,
567 if (opts
.mode
== TEST_MODE_UNSET
)
569 if (opts
.peer_cid
== VMADDR_CID_ANY
)
573 if (opts
.mode
!= TEST_MODE_SERVER
)
575 control_host
= "0.0.0.0";
578 control_init(control_host
, control_port
,
579 opts
.mode
== TEST_MODE_SERVER
);
581 run_tests(test_cases
, &opts
);