2 * vsock_diag_test - vsock_diag.ko test suite
4 * Copyright (C) 2017 Red Hat, Inc.
6 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 * This program is free software; you can redistribute it and/or
9 * modify it under the terms of the GNU General Public License
10 * as published by the Free Software Foundation; version 2
22 #include <sys/socket.h>
24 #include <sys/types.h>
25 #include <linux/list.h>
26 #include <linux/net.h>
27 #include <linux/netlink.h>
28 #include <linux/sock_diag.h>
29 #include <netinet/tcp.h>
31 #include "../../../include/uapi/linux/vm_sockets.h"
32 #include "../../../include/uapi/linux/vm_sockets_diag.h"
43 /* Per-socket status */
45 struct list_head list
;
46 struct vsock_diag_msg msg
;
49 static const char *sock_type_str(int type
)
57 return "INVALID TYPE";
61 static const char *sock_state_str(int state
)
71 return "DISCONNECTING";
75 return "INVALID STATE";
79 static const char *sock_shutdown_str(int shutdown
)
83 return "RCV_SHUTDOWN";
85 return "SEND_SHUTDOWN";
87 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
93 static void print_vsock_addr(FILE *fp
, unsigned int cid
, unsigned int port
)
95 if (cid
== VMADDR_CID_ANY
)
98 fprintf(fp
, "%u:", cid
);
100 if (port
== VMADDR_PORT_ANY
)
103 fprintf(fp
, "%u", port
);
106 static void print_vsock_stat(FILE *fp
, struct vsock_stat
*st
)
108 print_vsock_addr(fp
, st
->msg
.vdiag_src_cid
, st
->msg
.vdiag_src_port
);
110 print_vsock_addr(fp
, st
->msg
.vdiag_dst_cid
, st
->msg
.vdiag_dst_port
);
111 fprintf(fp
, " %s %s %s %u\n",
112 sock_type_str(st
->msg
.vdiag_type
),
113 sock_state_str(st
->msg
.vdiag_state
),
114 sock_shutdown_str(st
->msg
.vdiag_shutdown
),
118 static void print_vsock_stats(FILE *fp
, struct list_head
*head
)
120 struct vsock_stat
*st
;
122 list_for_each_entry(st
, head
, list
)
123 print_vsock_stat(fp
, st
);
126 static struct vsock_stat
*find_vsock_stat(struct list_head
*head
, int fd
)
128 struct vsock_stat
*st
;
131 if (fstat(fd
, &stat
) < 0) {
136 list_for_each_entry(st
, head
, list
)
137 if (st
->msg
.vdiag_ino
== stat
.st_ino
)
140 fprintf(stderr
, "cannot find fd %d\n", fd
);
144 static void check_no_sockets(struct list_head
*head
)
146 if (!list_empty(head
)) {
147 fprintf(stderr
, "expected no sockets\n");
148 print_vsock_stats(stderr
, head
);
153 static void check_num_sockets(struct list_head
*head
, int expected
)
155 struct list_head
*node
;
158 list_for_each(node
, head
)
162 fprintf(stderr
, "expected %d sockets, found %d\n",
164 print_vsock_stats(stderr
, head
);
169 static void check_socket_state(struct vsock_stat
*st
, __u8 state
)
171 if (st
->msg
.vdiag_state
!= state
) {
172 fprintf(stderr
, "expected socket state %#x, got %#x\n",
173 state
, st
->msg
.vdiag_state
);
178 static void send_req(int fd
)
180 struct sockaddr_nl nladdr
= {
181 .nl_family
= AF_NETLINK
,
185 struct vsock_diag_req vreq
;
188 .nlmsg_len
= sizeof(req
),
189 .nlmsg_type
= SOCK_DIAG_BY_FAMILY
,
190 .nlmsg_flags
= NLM_F_REQUEST
| NLM_F_DUMP
,
193 .sdiag_family
= AF_VSOCK
,
194 .vdiag_states
= ~(__u32
)0,
199 .iov_len
= sizeof(req
),
201 struct msghdr msg
= {
203 .msg_namelen
= sizeof(nladdr
),
209 if (sendmsg(fd
, &msg
, 0) < 0) {
221 static ssize_t
recv_resp(int fd
, void *buf
, size_t len
)
223 struct sockaddr_nl nladdr
= {
224 .nl_family
= AF_NETLINK
,
230 struct msghdr msg
= {
232 .msg_namelen
= sizeof(nladdr
),
239 ret
= recvmsg(fd
, &msg
, 0);
240 } while (ret
< 0 && errno
== EINTR
);
250 static void add_vsock_stat(struct list_head
*sockets
,
251 const struct vsock_diag_msg
*resp
)
253 struct vsock_stat
*st
;
255 st
= malloc(sizeof(*st
));
262 list_add_tail(&st
->list
, sockets
);
266 * Read vsock stats into a list.
268 static void read_vsock_stat(struct list_head
*sockets
)
270 long buf
[8192 / sizeof(long)];
273 fd
= socket(AF_NETLINK
, SOCK_RAW
, NETLINK_SOCK_DIAG
);
282 const struct nlmsghdr
*h
;
285 ret
= recv_resp(fd
, buf
, sizeof(buf
));
288 if (ret
< sizeof(*h
)) {
289 fprintf(stderr
, "short read of %zd bytes\n", ret
);
293 h
= (struct nlmsghdr
*)buf
;
295 while (NLMSG_OK(h
, ret
)) {
296 if (h
->nlmsg_type
== NLMSG_DONE
)
299 if (h
->nlmsg_type
== NLMSG_ERROR
) {
300 const struct nlmsgerr
*err
= NLMSG_DATA(h
);
302 if (h
->nlmsg_len
< NLMSG_LENGTH(sizeof(*err
)))
303 fprintf(stderr
, "NLMSG_ERROR\n");
306 perror("NLMSG_ERROR");
312 if (h
->nlmsg_type
!= SOCK_DIAG_BY_FAMILY
) {
313 fprintf(stderr
, "unexpected nlmsg_type %#x\n",
318 NLMSG_LENGTH(sizeof(struct vsock_diag_msg
))) {
319 fprintf(stderr
, "short vsock_diag_msg\n");
323 add_vsock_stat(sockets
, NLMSG_DATA(h
));
325 h
= NLMSG_NEXT(h
, ret
);
333 static void free_sock_stat(struct list_head
*sockets
)
335 struct vsock_stat
*st
;
336 struct vsock_stat
*next
;
338 list_for_each_entry_safe(st
, next
, sockets
, list
)
342 static void test_no_sockets(unsigned int peer_cid
)
346 read_vsock_stat(&sockets
);
348 check_no_sockets(&sockets
);
350 free_sock_stat(&sockets
);
353 static void test_listen_socket_server(unsigned int peer_cid
)
357 struct sockaddr_vm svm
;
360 .svm_family
= AF_VSOCK
,
362 .svm_cid
= VMADDR_CID_ANY
,
366 struct vsock_stat
*st
;
369 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
371 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
376 if (listen(fd
, 1) < 0) {
381 read_vsock_stat(&sockets
);
383 check_num_sockets(&sockets
, 1);
384 st
= find_vsock_stat(&sockets
, fd
);
385 check_socket_state(st
, TCP_LISTEN
);
388 free_sock_stat(&sockets
);
391 static void test_connect_client(unsigned int peer_cid
)
395 struct sockaddr_vm svm
;
398 .svm_family
= AF_VSOCK
,
406 struct vsock_stat
*st
;
408 control_expectln("LISTENING");
410 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
412 timeout_begin(TIMEOUT
);
414 ret
= connect(fd
, &addr
.sa
, sizeof(addr
.svm
));
415 timeout_check("connect");
416 } while (ret
< 0 && errno
== EINTR
);
424 read_vsock_stat(&sockets
);
426 check_num_sockets(&sockets
, 1);
427 st
= find_vsock_stat(&sockets
, fd
);
428 check_socket_state(st
, TCP_ESTABLISHED
);
430 control_expectln("DONE");
431 control_writeln("DONE");
434 free_sock_stat(&sockets
);
437 static void test_connect_server(unsigned int peer_cid
)
441 struct sockaddr_vm svm
;
444 .svm_family
= AF_VSOCK
,
446 .svm_cid
= VMADDR_CID_ANY
,
451 struct sockaddr_vm svm
;
453 socklen_t clientaddr_len
= sizeof(clientaddr
.svm
);
455 struct vsock_stat
*st
;
459 fd
= socket(AF_VSOCK
, SOCK_STREAM
, 0);
461 if (bind(fd
, &addr
.sa
, sizeof(addr
.svm
)) < 0) {
466 if (listen(fd
, 1) < 0) {
471 control_writeln("LISTENING");
473 timeout_begin(TIMEOUT
);
475 client_fd
= accept(fd
, &clientaddr
.sa
, &clientaddr_len
);
476 timeout_check("accept");
477 } while (client_fd
< 0 && errno
== EINTR
);
484 if (clientaddr
.sa
.sa_family
!= AF_VSOCK
) {
485 fprintf(stderr
, "expected AF_VSOCK from accept(2), got %d\n",
486 clientaddr
.sa
.sa_family
);
489 if (clientaddr
.svm
.svm_cid
!= peer_cid
) {
490 fprintf(stderr
, "expected peer CID %u from accept(2), got %u\n",
491 peer_cid
, clientaddr
.svm
.svm_cid
);
495 read_vsock_stat(&sockets
);
497 check_num_sockets(&sockets
, 2);
498 find_vsock_stat(&sockets
, fd
);
499 st
= find_vsock_stat(&sockets
, client_fd
);
500 check_socket_state(st
, TCP_ESTABLISHED
);
502 control_writeln("DONE");
503 control_expectln("DONE");
507 free_sock_stat(&sockets
);
512 void (*run_client
)(unsigned int peer_cid
);
513 void (*run_server
)(unsigned int peer_cid
);
516 .name
= "No sockets",
517 .run_server
= test_no_sockets
,
520 .name
= "Listen socket",
521 .run_server
= test_listen_socket_server
,
525 .run_client
= test_connect_client
,
526 .run_server
= test_connect_server
,
531 static void init_signals(void)
533 struct sigaction act
= {
534 .sa_handler
= sigalrm
,
537 sigaction(SIGALRM
, &act
, NULL
);
538 signal(SIGPIPE
, SIG_IGN
);
541 static unsigned int parse_cid(const char *str
)
547 n
= strtoul(str
, &endptr
, 10);
548 if (errno
|| *endptr
!= '\0') {
549 fprintf(stderr
, "malformed CID \"%s\"\n", str
);
555 static const char optstring
[] = "";
556 static const struct option longopts
[] = {
558 .name
= "control-host",
559 .has_arg
= required_argument
,
563 .name
= "control-port",
564 .has_arg
= required_argument
,
569 .has_arg
= required_argument
,
574 .has_arg
= required_argument
,
579 .has_arg
= no_argument
,
585 static void usage(void)
587 fprintf(stderr
, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
589 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
590 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
592 "Run vsock_diag.ko tests. Must be launched in both\n"
593 "guest and host. One side must use --mode=client and\n"
594 "the other side must use --mode=server.\n"
596 "A TCP control socket connection is used to coordinate tests\n"
597 "between the client and the server. The server requires a\n"
598 "listen address and the client requires an address to\n"
601 "The CID of the other side must be given with --peer-cid=<cid>.\n");
605 int main(int argc
, char **argv
)
607 const char *control_host
= NULL
;
608 const char *control_port
= NULL
;
609 int mode
= TEST_MODE_UNSET
;
610 unsigned int peer_cid
= VMADDR_CID_ANY
;
616 int opt
= getopt_long(argc
, argv
, optstring
, longopts
, NULL
);
623 control_host
= optarg
;
626 if (strcmp(optarg
, "client") == 0)
627 mode
= TEST_MODE_CLIENT
;
628 else if (strcmp(optarg
, "server") == 0)
629 mode
= TEST_MODE_SERVER
;
631 fprintf(stderr
, "--mode must be \"client\" or \"server\"\n");
636 peer_cid
= parse_cid(optarg
);
639 control_port
= optarg
;
649 if (mode
== TEST_MODE_UNSET
)
651 if (peer_cid
== VMADDR_CID_ANY
)
655 if (mode
!= TEST_MODE_SERVER
)
657 control_host
= "0.0.0.0";
660 control_init(control_host
, control_port
, mode
== TEST_MODE_SERVER
);
662 for (i
= 0; test_cases
[i
].name
; i
++) {
663 void (*run
)(unsigned int peer_cid
);
665 printf("%s...", test_cases
[i
].name
);
668 if (mode
== TEST_MODE_CLIENT
)
669 run
= test_cases
[i
].run_client
;
671 run
= test_cases
[i
].run_server
;