1 // SPDX-License-Identifier: GPL-2.0-only
3 * vsock_diag_test - vsock_diag.ko test suite
5 * Copyright (C) 2017 Red Hat, Inc.
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
29 /* Per-socket status */
31 struct list_head list
;
32 struct vsock_diag_msg msg
;
35 static const char *sock_type_str(int type
)
45 return "INVALID TYPE";
49 static const char *sock_state_str(int state
)
59 return "DISCONNECTING";
63 return "INVALID STATE";
67 static const char *sock_shutdown_str(int shutdown
)
71 return "RCV_SHUTDOWN";
73 return "SEND_SHUTDOWN";
75 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
81 static void print_vsock_addr(FILE *fp
, unsigned int cid
, unsigned int port
)
83 if (cid
== VMADDR_CID_ANY
)
86 fprintf(fp
, "%u:", cid
);
88 if (port
== VMADDR_PORT_ANY
)
91 fprintf(fp
, "%u", port
);
94 static void print_vsock_stat(FILE *fp
, struct vsock_stat
*st
)
96 print_vsock_addr(fp
, st
->msg
.vdiag_src_cid
, st
->msg
.vdiag_src_port
);
98 print_vsock_addr(fp
, st
->msg
.vdiag_dst_cid
, st
->msg
.vdiag_dst_port
);
99 fprintf(fp
, " %s %s %s %u\n",
100 sock_type_str(st
->msg
.vdiag_type
),
101 sock_state_str(st
->msg
.vdiag_state
),
102 sock_shutdown_str(st
->msg
.vdiag_shutdown
),
106 static void print_vsock_stats(FILE *fp
, struct list_head
*head
)
108 struct vsock_stat
*st
;
110 list_for_each_entry(st
, head
, list
)
111 print_vsock_stat(fp
, st
);
114 static struct vsock_stat
*find_vsock_stat(struct list_head
*head
, int fd
)
116 struct vsock_stat
*st
;
119 if (fstat(fd
, &stat
) < 0) {
124 list_for_each_entry(st
, head
, list
)
125 if (st
->msg
.vdiag_ino
== stat
.st_ino
)
128 fprintf(stderr
, "cannot find fd %d\n", fd
);
132 static void check_no_sockets(struct list_head
*head
)
134 if (!list_empty(head
)) {
135 fprintf(stderr
, "expected no sockets\n");
136 print_vsock_stats(stderr
, head
);
141 static void check_num_sockets(struct list_head
*head
, int expected
)
143 struct list_head
*node
;
146 list_for_each(node
, head
)
150 fprintf(stderr
, "expected %d sockets, found %d\n",
152 print_vsock_stats(stderr
, head
);
157 static void check_socket_state(struct vsock_stat
*st
, __u8 state
)
159 if (st
->msg
.vdiag_state
!= state
) {
160 fprintf(stderr
, "expected socket state %#x, got %#x\n",
161 state
, st
->msg
.vdiag_state
);
166 static void send_req(int fd
)
168 struct sockaddr_nl nladdr
= {
169 .nl_family
= AF_NETLINK
,
173 struct vsock_diag_req vreq
;
176 .nlmsg_len
= sizeof(req
),
177 .nlmsg_type
= SOCK_DIAG_BY_FAMILY
,
178 .nlmsg_flags
= NLM_F_REQUEST
| NLM_F_DUMP
,
181 .sdiag_family
= AF_VSOCK
,
182 .vdiag_states
= ~(__u32
)0,
187 .iov_len
= sizeof(req
),
189 struct msghdr msg
= {
191 .msg_namelen
= sizeof(nladdr
),
197 if (sendmsg(fd
, &msg
, 0) < 0) {
209 static ssize_t
recv_resp(int fd
, void *buf
, size_t len
)
211 struct sockaddr_nl nladdr
= {
212 .nl_family
= AF_NETLINK
,
218 struct msghdr msg
= {
220 .msg_namelen
= sizeof(nladdr
),
227 ret
= recvmsg(fd
, &msg
, 0);
228 } while (ret
< 0 && errno
== EINTR
);
238 static void add_vsock_stat(struct list_head
*sockets
,
239 const struct vsock_diag_msg
*resp
)
241 struct vsock_stat
*st
;
243 st
= malloc(sizeof(*st
));
250 list_add_tail(&st
->list
, sockets
);
254 * Read vsock stats into a list.
256 static void read_vsock_stat(struct list_head
*sockets
)
258 long buf
[8192 / sizeof(long)];
261 fd
= socket(AF_NETLINK
, SOCK_RAW
, NETLINK_SOCK_DIAG
);
270 const struct nlmsghdr
*h
;
273 ret
= recv_resp(fd
, buf
, sizeof(buf
));
276 if (ret
< sizeof(*h
)) {
277 fprintf(stderr
, "short read of %zd bytes\n", ret
);
281 h
= (struct nlmsghdr
*)buf
;
283 while (NLMSG_OK(h
, ret
)) {
284 if (h
->nlmsg_type
== NLMSG_DONE
)
287 if (h
->nlmsg_type
== NLMSG_ERROR
) {
288 const struct nlmsgerr
*err
= NLMSG_DATA(h
);
290 if (h
->nlmsg_len
< NLMSG_LENGTH(sizeof(*err
)))
291 fprintf(stderr
, "NLMSG_ERROR\n");
294 perror("NLMSG_ERROR");
300 if (h
->nlmsg_type
!= SOCK_DIAG_BY_FAMILY
) {
301 fprintf(stderr
, "unexpected nlmsg_type %#x\n",
306 NLMSG_LENGTH(sizeof(struct vsock_diag_msg
))) {
307 fprintf(stderr
, "short vsock_diag_msg\n");
311 add_vsock_stat(sockets
, NLMSG_DATA(h
));
313 h
= NLMSG_NEXT(h
, ret
);
321 static void free_sock_stat(struct list_head
*sockets
)
323 struct vsock_stat
*st
;
324 struct vsock_stat
*next
;
326 list_for_each_entry_safe(st
, next
, sockets
, list
)
330 static void test_no_sockets(const struct test_opts
*opts
)
334 read_vsock_stat(&sockets
);
336 check_no_sockets(&sockets
);
339 static void test_listen_socket_server(const struct test_opts
*opts
)
343 struct sockaddr_vm svm
;
346 .svm_family
= AF_VSOCK
,
347 .svm_port
= opts
->peer_port
,
348 .svm_cid
= VMADDR_CID_ANY
,
352 struct vsock_stat
*st
;
355 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
357 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
362 if (listen(fd
, 1) < 0) {
367 read_vsock_stat(&sockets
);
369 check_num_sockets(&sockets
, 1);
370 st
= find_vsock_stat(&sockets
, fd
);
371 check_socket_state(st
, TCP_LISTEN
);
374 free_sock_stat(&sockets
);
377 static void test_connect_client(const struct test_opts
*opts
)
381 struct vsock_stat
*st
;
383 fd
= vsock_stream_connect(opts
->peer_cid
, opts
->peer_port
);
389 read_vsock_stat(&sockets
);
391 check_num_sockets(&sockets
, 1);
392 st
= find_vsock_stat(&sockets
, fd
);
393 check_socket_state(st
, TCP_ESTABLISHED
);
395 control_expectln("DONE");
396 control_writeln("DONE");
399 free_sock_stat(&sockets
);
402 static void test_connect_server(const struct test_opts
*opts
)
404 struct vsock_stat
*st
;
408 client_fd
= vsock_stream_accept(VMADDR_CID_ANY
, opts
->peer_port
, NULL
);
414 read_vsock_stat(&sockets
);
416 check_num_sockets(&sockets
, 1);
417 st
= find_vsock_stat(&sockets
, client_fd
);
418 check_socket_state(st
, TCP_ESTABLISHED
);
420 control_writeln("DONE");
421 control_expectln("DONE");
424 free_sock_stat(&sockets
);
427 static struct test_case test_cases
[] = {
429 .name
= "No sockets",
430 .run_server
= test_no_sockets
,
433 .name
= "Listen socket",
434 .run_server
= test_listen_socket_server
,
438 .run_client
= test_connect_client
,
439 .run_server
= test_connect_server
,
444 static const char optstring
[] = "";
445 static const struct option longopts
[] = {
447 .name
= "control-host",
448 .has_arg
= required_argument
,
452 .name
= "control-port",
453 .has_arg
= required_argument
,
458 .has_arg
= required_argument
,
463 .has_arg
= required_argument
,
468 .has_arg
= required_argument
,
473 .has_arg
= no_argument
,
478 .has_arg
= required_argument
,
483 .has_arg
= no_argument
,
489 static void usage(void)
491 fprintf(stderr
, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
493 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
494 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
496 "Run vsock_diag.ko tests. Must be launched in both\n"
497 "guest and host. One side must use --mode=client and\n"
498 "the other side must use --mode=server.\n"
500 "A TCP control socket connection is used to coordinate tests\n"
501 "between the client and the server. The server requires a\n"
502 "listen address and the client requires an address to\n"
505 "The CID of the other side must be given with --peer-cid=<cid>.\n"
508 " --help This help message\n"
509 " --control-host <host> Server IP address to connect to\n"
510 " --control-port <port> Server port to listen on/connect to\n"
511 " --mode client|server Server or client mode\n"
512 " --peer-cid <cid> CID of the other side\n"
513 " --peer-port <port> AF_VSOCK port used for the test [default: %d]\n"
514 " --list List of tests that will be executed\n"
515 " --skip <test_id> Test ID to skip;\n"
516 " use multiple --skip options to skip more tests\n",
522 int main(int argc
, char **argv
)
524 const char *control_host
= NULL
;
525 const char *control_port
= NULL
;
526 struct test_opts opts
= {
527 .mode
= TEST_MODE_UNSET
,
528 .peer_cid
= VMADDR_CID_ANY
,
529 .peer_port
= DEFAULT_PEER_PORT
,
535 int opt
= getopt_long(argc
, argv
, optstring
, longopts
, NULL
);
542 control_host
= optarg
;
545 if (strcmp(optarg
, "client") == 0)
546 opts
.mode
= TEST_MODE_CLIENT
;
547 else if (strcmp(optarg
, "server") == 0)
548 opts
.mode
= TEST_MODE_SERVER
;
550 fprintf(stderr
, "--mode must be \"client\" or \"server\"\n");
555 opts
.peer_cid
= parse_cid(optarg
);
558 opts
.peer_port
= parse_port(optarg
);
561 control_port
= optarg
;
564 list_tests(test_cases
);
567 skip_test(test_cases
, ARRAY_SIZE(test_cases
) - 1,
578 if (opts
.mode
== TEST_MODE_UNSET
)
580 if (opts
.peer_cid
== VMADDR_CID_ANY
)
584 if (opts
.mode
!= TEST_MODE_SERVER
)
586 control_host
= "0.0.0.0";
589 control_init(control_host
, control_port
,
590 opts
.mode
== TEST_MODE_SERVER
);
592 run_tests(test_cases
, &opts
);