1 // SPDX-License-Identifier: GPL-2.0-only
3 * vsock_diag_test - vsock_diag.ko test suite
5 * Copyright (C) 2017 Red Hat, Inc.
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
18 #include <sys/socket.h>
20 #include <sys/types.h>
21 #include <linux/list.h>
22 #include <linux/net.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netinet/tcp.h>
27 #include "../../../include/uapi/linux/vm_sockets.h"
28 #include "../../../include/uapi/linux/vm_sockets_diag.h"
39 /* Per-socket status */
41 struct list_head list
;
42 struct vsock_diag_msg msg
;
45 static const char *sock_type_str(int type
)
53 return "INVALID TYPE";
57 static const char *sock_state_str(int state
)
67 return "DISCONNECTING";
71 return "INVALID STATE";
75 static const char *sock_shutdown_str(int shutdown
)
79 return "RCV_SHUTDOWN";
81 return "SEND_SHUTDOWN";
83 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
89 static void print_vsock_addr(FILE *fp
, unsigned int cid
, unsigned int port
)
91 if (cid
== VMADDR_CID_ANY
)
94 fprintf(fp
, "%u:", cid
);
96 if (port
== VMADDR_PORT_ANY
)
99 fprintf(fp
, "%u", port
);
102 static void print_vsock_stat(FILE *fp
, struct vsock_stat
*st
)
104 print_vsock_addr(fp
, st
->msg
.vdiag_src_cid
, st
->msg
.vdiag_src_port
);
106 print_vsock_addr(fp
, st
->msg
.vdiag_dst_cid
, st
->msg
.vdiag_dst_port
);
107 fprintf(fp
, " %s %s %s %u\n",
108 sock_type_str(st
->msg
.vdiag_type
),
109 sock_state_str(st
->msg
.vdiag_state
),
110 sock_shutdown_str(st
->msg
.vdiag_shutdown
),
114 static void print_vsock_stats(FILE *fp
, struct list_head
*head
)
116 struct vsock_stat
*st
;
118 list_for_each_entry(st
, head
, list
)
119 print_vsock_stat(fp
, st
);
122 static struct vsock_stat
*find_vsock_stat(struct list_head
*head
, int fd
)
124 struct vsock_stat
*st
;
127 if (fstat(fd
, &stat
) < 0) {
132 list_for_each_entry(st
, head
, list
)
133 if (st
->msg
.vdiag_ino
== stat
.st_ino
)
136 fprintf(stderr
, "cannot find fd %d\n", fd
);
140 static void check_no_sockets(struct list_head
*head
)
142 if (!list_empty(head
)) {
143 fprintf(stderr
, "expected no sockets\n");
144 print_vsock_stats(stderr
, head
);
149 static void check_num_sockets(struct list_head
*head
, int expected
)
151 struct list_head
*node
;
154 list_for_each(node
, head
)
158 fprintf(stderr
, "expected %d sockets, found %d\n",
160 print_vsock_stats(stderr
, head
);
165 static void check_socket_state(struct vsock_stat
*st
, __u8 state
)
167 if (st
->msg
.vdiag_state
!= state
) {
168 fprintf(stderr
, "expected socket state %#x, got %#x\n",
169 state
, st
->msg
.vdiag_state
);
174 static void send_req(int fd
)
176 struct sockaddr_nl nladdr
= {
177 .nl_family
= AF_NETLINK
,
181 struct vsock_diag_req vreq
;
184 .nlmsg_len
= sizeof(req
),
185 .nlmsg_type
= SOCK_DIAG_BY_FAMILY
,
186 .nlmsg_flags
= NLM_F_REQUEST
| NLM_F_DUMP
,
189 .sdiag_family
= AF_VSOCK
,
190 .vdiag_states
= ~(__u32
)0,
195 .iov_len
= sizeof(req
),
197 struct msghdr msg
= {
199 .msg_namelen
= sizeof(nladdr
),
205 if (sendmsg(fd
, &msg
, 0) < 0) {
217 static ssize_t
recv_resp(int fd
, void *buf
, size_t len
)
219 struct sockaddr_nl nladdr
= {
220 .nl_family
= AF_NETLINK
,
226 struct msghdr msg
= {
228 .msg_namelen
= sizeof(nladdr
),
235 ret
= recvmsg(fd
, &msg
, 0);
236 } while (ret
< 0 && errno
== EINTR
);
246 static void add_vsock_stat(struct list_head
*sockets
,
247 const struct vsock_diag_msg
*resp
)
249 struct vsock_stat
*st
;
251 st
= malloc(sizeof(*st
));
258 list_add_tail(&st
->list
, sockets
);
262 * Read vsock stats into a list.
264 static void read_vsock_stat(struct list_head
*sockets
)
266 long buf
[8192 / sizeof(long)];
269 fd
= socket(AF_NETLINK
, SOCK_RAW
, NETLINK_SOCK_DIAG
);
278 const struct nlmsghdr
*h
;
281 ret
= recv_resp(fd
, buf
, sizeof(buf
));
284 if (ret
< sizeof(*h
)) {
285 fprintf(stderr
, "short read of %zd bytes\n", ret
);
289 h
= (struct nlmsghdr
*)buf
;
291 while (NLMSG_OK(h
, ret
)) {
292 if (h
->nlmsg_type
== NLMSG_DONE
)
295 if (h
->nlmsg_type
== NLMSG_ERROR
) {
296 const struct nlmsgerr
*err
= NLMSG_DATA(h
);
298 if (h
->nlmsg_len
< NLMSG_LENGTH(sizeof(*err
)))
299 fprintf(stderr
, "NLMSG_ERROR\n");
302 perror("NLMSG_ERROR");
308 if (h
->nlmsg_type
!= SOCK_DIAG_BY_FAMILY
) {
309 fprintf(stderr
, "unexpected nlmsg_type %#x\n",
314 NLMSG_LENGTH(sizeof(struct vsock_diag_msg
))) {
315 fprintf(stderr
, "short vsock_diag_msg\n");
319 add_vsock_stat(sockets
, NLMSG_DATA(h
));
321 h
= NLMSG_NEXT(h
, ret
);
329 static void free_sock_stat(struct list_head
*sockets
)
331 struct vsock_stat
*st
;
332 struct vsock_stat
*next
;
334 list_for_each_entry_safe(st
, next
, sockets
, list
)
338 static void test_no_sockets(unsigned int peer_cid
)
342 read_vsock_stat(&sockets
);
344 check_no_sockets(&sockets
);
346 free_sock_stat(&sockets
);
349 static void test_listen_socket_server(unsigned int peer_cid
)
353 struct sockaddr_vm svm
;
356 .svm_family
= AF_VSOCK
,
358 .svm_cid
= VMADDR_CID_ANY
,
362 struct vsock_stat
*st
;
365 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
367 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
372 if (listen(fd
, 1) < 0) {
377 read_vsock_stat(&sockets
);
379 check_num_sockets(&sockets
, 1);
380 st
= find_vsock_stat(&sockets
, fd
);
381 check_socket_state(st
, TCP_LISTEN
);
384 free_sock_stat(&sockets
);
387 static void test_connect_client(unsigned int peer_cid
)
391 struct sockaddr_vm svm
;
394 .svm_family
= AF_VSOCK
,
402 struct vsock_stat
*st
;
404 control_expectln("LISTENING");
406 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
408 timeout_begin(TIMEOUT
);
410 ret
= connect(fd
, &addr
.sa
, sizeof(addr
.svm
));
411 timeout_check("connect");
412 } while (ret
< 0 && errno
== EINTR
);
420 read_vsock_stat(&sockets
);
422 check_num_sockets(&sockets
, 1);
423 st
= find_vsock_stat(&sockets
, fd
);
424 check_socket_state(st
, TCP_ESTABLISHED
);
426 control_expectln("DONE");
427 control_writeln("DONE");
430 free_sock_stat(&sockets
);
433 static void test_connect_server(unsigned int peer_cid
)
437 struct sockaddr_vm svm
;
440 .svm_family
= AF_VSOCK
,
442 .svm_cid
= VMADDR_CID_ANY
,
447 struct sockaddr_vm svm
;
449 socklen_t clientaddr_len
= sizeof(clientaddr
.svm
);
451 struct vsock_stat
*st
;
455 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
457 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
462 if (listen(fd
, 1) < 0) {
467 control_writeln("LISTENING");
469 timeout_begin(TIMEOUT
);
471 client_fd
= accept(fd
, &clientaddr
.sa
, &clientaddr_len
);
472 timeout_check("accept");
473 } while (client_fd
< 0 && errno
== EINTR
);
480 if (clientaddr
.sa
.sa_family
!= AF_VSOCK
) {
481 fprintf(stderr
, "expected AF_VSOCK from accept(2), got %d\n",
482 clientaddr
.sa
.sa_family
);
485 if (clientaddr
.svm
.svm_cid
!= peer_cid
) {
486 fprintf(stderr
, "expected peer CID %u from accept(2), got %u\n",
487 peer_cid
, clientaddr
.svm
.svm_cid
);
491 read_vsock_stat(&sockets
);
493 check_num_sockets(&sockets
, 2);
494 find_vsock_stat(&sockets
, fd
);
495 st
= find_vsock_stat(&sockets
, client_fd
);
496 check_socket_state(st
, TCP_ESTABLISHED
);
498 control_writeln("DONE");
499 control_expectln("DONE");
503 free_sock_stat(&sockets
);
508 void (*run_client
)(unsigned int peer_cid
);
509 void (*run_server
)(unsigned int peer_cid
);
512 .name
= "No sockets",
513 .run_server
= test_no_sockets
,
516 .name
= "Listen socket",
517 .run_server
= test_listen_socket_server
,
521 .run_client
= test_connect_client
,
522 .run_server
= test_connect_server
,
527 static void init_signals(void)
529 struct sigaction act
= {
530 .sa_handler
= sigalrm
,
533 sigaction(SIGALRM
, &act
, NULL
);
534 signal(SIGPIPE
, SIG_IGN
);
537 static unsigned int parse_cid(const char *str
)
543 n
= strtoul(str
, &endptr
, 10);
544 if (errno
|| *endptr
!= '\0') {
545 fprintf(stderr
, "malformed CID \"%s\"\n", str
);
551 static const char optstring
[] = "";
552 static const struct option longopts
[] = {
554 .name
= "control-host",
555 .has_arg
= required_argument
,
559 .name
= "control-port",
560 .has_arg
= required_argument
,
565 .has_arg
= required_argument
,
570 .has_arg
= required_argument
,
575 .has_arg
= no_argument
,
581 static void usage(void)
583 fprintf(stderr
, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
585 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
586 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
588 "Run vsock_diag.ko tests. Must be launched in both\n"
589 "guest and host. One side must use --mode=client and\n"
590 "the other side must use --mode=server.\n"
592 "A TCP control socket connection is used to coordinate tests\n"
593 "between the client and the server. The server requires a\n"
594 "listen address and the client requires an address to\n"
597 "The CID of the other side must be given with --peer-cid=<cid>.\n");
601 int main(int argc
, char **argv
)
603 const char *control_host
= NULL
;
604 const char *control_port
= NULL
;
605 int mode
= TEST_MODE_UNSET
;
606 unsigned int peer_cid
= VMADDR_CID_ANY
;
612 int opt
= getopt_long(argc
, argv
, optstring
, longopts
, NULL
);
619 control_host
= optarg
;
622 if (strcmp(optarg
, "client") == 0)
623 mode
= TEST_MODE_CLIENT
;
624 else if (strcmp(optarg
, "server") == 0)
625 mode
= TEST_MODE_SERVER
;
627 fprintf(stderr
, "--mode must be \"client\" or \"server\"\n");
632 peer_cid
= parse_cid(optarg
);
635 control_port
= optarg
;
645 if (mode
== TEST_MODE_UNSET
)
647 if (peer_cid
== VMADDR_CID_ANY
)
651 if (mode
!= TEST_MODE_SERVER
)
653 control_host
= "0.0.0.0";
656 control_init(control_host
, control_port
, mode
== TEST_MODE_SERVER
);
658 for (i
= 0; test_cases
[i
].name
; i
++) {
659 void (*run
)(unsigned int peer_cid
);
661 printf("%s...", test_cases
[i
].name
);
664 if (mode
== TEST_MODE_CLIENT
)
665 run
= test_cases
[i
].run_client
;
667 run
= test_cases
[i
].run_server
;