/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>

#include <net/af_vsock.h>
#include "vhost.h"
#define VHOST_VSOCK_DEFAULT_HOST_CID	2

static int vhost_transport_socket_init(struct vsock_sock *vsk,
				       struct vsock_sock *psk);

enum {
	VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};

/* Used to track all the vhost_vsock instances on the system. */
static LIST_HEAD(vhost_vsock_list);
static DEFINE_MUTEX(vhost_vsock_mutex);
struct vhost_vsock_virtqueue {
	struct vhost_virtqueue vq;
};

struct vhost_vsock {
	struct vhost_dev dev;
	/* Vhost vsock virtqueues */
	struct vhost_vsock_virtqueue vqs[VSOCK_VQ_MAX];
	/* Link to global vhost_vsock_list */
	struct list_head list;
	/* Head for pkts queued from host to guest */
	struct list_head send_pkt_list;
	/* Work item to send pkts */
	struct vhost_work send_pkt_work;
	/* Wait queue for senders blocked on the global tx buf limit */
	wait_queue_head_t queue_wait;
	/* Used for global tx buf limitation */
	u32 total_tx_buf;
	/* Guest context id this vhost_vsock instance handles */
	u32 guest_cid;
};
static u32 vhost_transport_get_local_cid(void)
{
	u32 cid = VHOST_VSOCK_DEFAULT_HOST_CID;
	return cid;
}
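
/*
 * Look up the vhost_vsock instance that serves @guest_cid, or return NULL if
 * none is registered.  Takes and releases vhost_vsock_mutex internally.
 */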
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
	struct vhost_vsock *vsock;

	mutex_lock(&vhost_vsock_mutex);
	list_for_each_entry(vsock, &vhost_vsock_list, list) {
		if (vsock->guest_cid == guest_cid) {
			mutex_unlock(&vhost_vsock_mutex);
			return vsock;
		}
	}
	mutex_unlock(&vhost_vsock_mutex);

	return NULL;
}
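
/*
 * Deliver packets queued on vsock->send_pkt_list to the guest.  This is the
 * host-to-guest path, so each packet is copied into buffers the guest posted
 * on its rx virtqueue: header first, then payload.  Senders blocked on the
 * global tx buf limit are woken once buffer space has been returned.
 */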
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
			    struct vhost_virtqueue *vq)
{
	bool added = false;

	mutex_lock(&vq->mutex);
	vhost_disable_notify(&vsock->dev, vq);
	for (;;) {
		struct virtio_vsock_pkt *pkt;
		struct iov_iter iov_iter;
		unsigned int out, in;
		struct sock *sk;
		size_t nbytes;
		size_t len;
		int head;

		if (list_empty(&vsock->send_pkt_list)) {
			vhost_enable_notify(&vsock->dev, vq);
			break;
		}

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		pr_debug("%s: head = %d\n", __func__, head);
		if (head < 0)
			break;

		if (head == vq->num) {
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);

		if (out) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
			break;
		}

		len = iov_length(&vq->iov[out], in);
		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);

		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
		if (nbytes != sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt hdr\n");
			break;
		}

		nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
		if (nbytes != pkt->len) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt buf\n");
			break;
		}

		vhost_add_used(vq, head, pkt->len); /* TODO should this be sizeof(pkt->hdr) + pkt->len? */
		added = true;

		virtio_transport_dec_tx_pkt(pkt);
		vsock->total_tx_buf -= pkt->len;

		sk = sk_vsock(pkt->trans->vsk);
		/* Release refcnt taken in vhost_transport_send_pkt */
		sock_put(sk);

		virtio_transport_free_pkt(pkt);
	}
	if (added)
		vhost_signal(&vsock->dev, vq);
	mutex_unlock(&vq->mutex);

	if (added)
		wake_up(&vsock->queue_wait);
}
static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
	struct vhost_virtqueue *vq;
	struct vhost_vsock *vsock;

	vsock = container_of(work, struct vhost_vsock, send_pkt_work);
	vq = &vsock->vqs[VSOCK_VQ_RX].vq;

	vhost_transport_do_send_pkt(vsock, vq);
}
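
/*
 * Queue a packet for delivery to the guest.  The requested length is clamped
 * twice: to the default rx buffer size and to the credit currently granted by
 * the peer.  The caller may then block (uninterruptibly) until the global tx
 * buf accounting falls below VIRTIO_VSOCK_MAX_TX_BUF_SIZE, after which the
 * packet goes on send_pkt_list and the send work item is kicked.
 */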
static int
vhost_transport_send_pkt(struct vsock_sock *vsk,
			 struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	struct virtio_transport *trans;
	struct virtio_vsock_pkt *pkt;
	struct vhost_virtqueue *vq;
	struct vhost_vsock *vsock;
	u32 pkt_len = info->pkt_len;
	DEFINE_WAIT(wait);

	src_cid = vhost_transport_get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(dst_cid);
	if (!vsock)
		return -ENODEV;

	trans = vsk->trans;
	vq = &vsock->vqs[VSOCK_VQ_RX].vq;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(trans, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	/* Respect global tx buf limitation */
	mutex_lock(&vq->mutex);
	while (pkt_len + vsock->total_tx_buf > VIRTIO_VSOCK_MAX_TX_BUF_SIZE) {
		prepare_to_wait_exclusive(&vsock->queue_wait, &wait,
					  TASK_UNINTERRUPTIBLE);
		mutex_unlock(&vq->mutex);
		schedule();
		mutex_lock(&vq->mutex);
		finish_wait(&vsock->queue_wait, &wait);
	}
	vsock->total_tx_buf += pkt_len;
	mutex_unlock(&vq->mutex);

	pkt = virtio_transport_alloc_pkt(vsk, info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		mutex_lock(&vq->mutex);
		vsock->total_tx_buf -= pkt_len;
		mutex_unlock(&vq->mutex);
		virtio_transport_put_credit(trans, pkt_len);
		return -ENOMEM;
	}

	pr_debug("%s:info->pkt_len= %d\n", __func__, pkt_len);
	/* Released in vhost_transport_do_send_pkt */
	sock_hold(&trans->vsk->sk);
	virtio_transport_inc_tx_pkt(pkt);

	/* Queue it up in vhost work */
	mutex_lock(&vq->mutex);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
	mutex_unlock(&vq->mutex);

	return pkt_len;
}
static struct virtio_transport_pkt_ops vhost_ops = {
	.send_pkt = vhost_transport_send_pkt,
};
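
/*
 * Build a virtio_vsock_pkt from the descriptors the guest placed on its tx
 * virtqueue: read the header, derive pkt->len from it, then copy the payload
 * out of the same iov.  Returns NULL on allocation failure or on a malformed
 * or oversized packet.
 */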
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
		      unsigned int out, unsigned int in)
{
	struct virtio_vsock_pkt *pkt;
	struct iov_iter iov_iter;
	size_t nbytes;
	size_t len;

	if (in != 0) {
		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
		return NULL;
	}

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	len = iov_length(vq->iov, out);
	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
	if (nbytes != sizeof(pkt->hdr)) {
		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
		       sizeof(pkt->hdr), nbytes);
		kfree(pkt);
		return NULL;
	}

	if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_DGRAM)
		pkt->len = le32_to_cpu(pkt->hdr.len) & 0xFFFF;
	else if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
		pkt->len = le32_to_cpu(pkt->hdr.len);

	/* No payload */
	if (!pkt->len)
		return pkt;

	/* The pkt is too big */
	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
	if (!pkt->buf) {
		kfree(pkt);
		return NULL;
	}

	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
	if (nbytes != pkt->len) {
		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
		       pkt->len, nbytes);
		virtio_transport_free_pkt(pkt);
		return NULL;
	}

	return pkt;
}
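
/* The control virtqueue is not used yet; the kick is only logged. */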
static void vhost_vsock_handle_ctl_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);

	pr_debug("%s vq=%p, vsock=%p\n", __func__, vq, vsock);
}
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);
	struct virtio_vsock_pkt *pkt;
	int head;
	unsigned int out, in;
	bool added = false;
	u32 len;

	mutex_lock(&vq->mutex);
	vhost_disable_notify(&vsock->dev, vq);
	for (;;) {
		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0)
			break;

		if (head == vq->num) {
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		pkt = vhost_vsock_alloc_pkt(vq, out, in);
		if (!pkt) {
			vq_err(vq, "Faulted on pkt\n");
			continue;
		}

		len = pkt->len;

		/* Only accept correctly addressed packets */
		if (le32_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
		    le32_to_cpu(pkt->hdr.dst_cid) == vhost_transport_get_local_cid())
			virtio_transport_recv_pkt(pkt);
		else
			virtio_transport_free_pkt(pkt);

		vhost_add_used(vq, head, len);
		added = true;
	}
	if (added)
		vhost_signal(&vsock->dev, vq);
	mutex_unlock(&vq->mutex);
}
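
/*
 * An rx virtqueue kick means the guest has posted fresh rx buffers, so retry
 * draining the host-to-guest packet queue.
 */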
static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);

	vhost_transport_do_send_pkt(vsock, vq);
}
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
	struct vhost_virtqueue **vqs;
	struct vhost_vsock *vsock;
	int ret;

	vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
	if (!vsock)
		return -ENOMEM;

	pr_debug("%s:vsock=%p\n", __func__, vsock);

	vqs = kmalloc(VSOCK_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		ret = -ENOMEM;
		goto out;
	}

	vqs[VSOCK_VQ_CTRL] = &vsock->vqs[VSOCK_VQ_CTRL].vq;
	vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX].vq;
	vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX].vq;
	vsock->vqs[VSOCK_VQ_CTRL].vq.handle_kick = vhost_vsock_handle_ctl_kick;
	vsock->vqs[VSOCK_VQ_TX].vq.handle_kick = vhost_vsock_handle_tx_kick;
	vsock->vqs[VSOCK_VQ_RX].vq.handle_kick = vhost_vsock_handle_rx_kick;

	vhost_dev_init(&vsock->dev, vqs, VSOCK_VQ_MAX);

	file->private_data = vsock;
	init_waitqueue_head(&vsock->queue_wait);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);

	mutex_lock(&vhost_vsock_mutex);
	list_add_tail(&vsock->list, &vhost_vsock_list);
	mutex_unlock(&vhost_vsock_mutex);
	return 0;

out:
	kfree(vsock);
	return ret;
}
static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
	int i;

	for (i = 0; i < VSOCK_VQ_MAX; i++)
		vhost_poll_flush(&vsock->vqs[i].vq.poll);
	vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
	struct vhost_vsock *vsock = file->private_data;

	mutex_lock(&vhost_vsock_mutex);
	list_del(&vsock->list);
	mutex_unlock(&vhost_vsock_mutex);

	vhost_dev_stop(&vsock->dev);
	vhost_vsock_flush(vsock);
	vhost_dev_cleanup(&vsock->dev, false);
	kfree(vsock->dev.vqs);
	kfree(vsock);
	return 0;
}
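
/*
 * Guest CIDs must be above VMADDR_CID_HOST: 0 is the hypervisor, 1 is
 * reserved, and 2 is the host itself.
 */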
static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u32 guest_cid)
{
	struct vhost_vsock *other;

	/* Refuse reserved CIDs */
	if (guest_cid <= VMADDR_CID_HOST)
		return -EINVAL;

	/* Refuse if CID is already in use */
	other = vhost_vsock_get(guest_cid);
	if (other && other != vsock)
		return -EADDRINUSE;

	mutex_lock(&vhost_vsock_mutex);
	vsock->guest_cid = guest_cid;
	pr_debug("%s:guest_cid=%d\n", __func__, guest_cid);
	mutex_unlock(&vhost_vsock_mutex);

	return 0;
}
static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
	struct vhost_virtqueue *vq;
	int i;

	if (features & ~VHOST_VSOCK_FEATURES)
		return -EOPNOTSUPP;

	mutex_lock(&vsock->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&vsock->dev)) {
		mutex_unlock(&vsock->dev.mutex);
		return -EFAULT;
	}

	for (i = 0; i < VSOCK_VQ_MAX; i++) {
		vq = &vsock->vqs[i].vq;
		mutex_lock(&vq->mutex);
		vq->acked_features = features;
		mutex_unlock(&vq->mutex);
	}
	mutex_unlock(&vsock->dev.mutex);
	return 0;
}
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
				  unsigned long arg)
{
	struct vhost_vsock *vsock = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 __user *featurep = argp;
	u32 __user *cidp = argp;
	u32 guest_cid;
	u64 features;
	int r;

	switch (ioctl) {
	case VHOST_VSOCK_SET_GUEST_CID:
		if (get_user(guest_cid, cidp))
			return -EFAULT;
		return vhost_vsock_set_cid(vsock, guest_cid);
	case VHOST_GET_FEATURES:
		features = VHOST_VSOCK_FEATURES;
		if (copy_to_user(featurep, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, featurep, sizeof(features)))
			return -EFAULT;
		return vhost_vsock_set_features(vsock, features);
	default:
		mutex_lock(&vsock->dev.mutex);
		r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
		else
			vhost_vsock_flush(vsock);
		mutex_unlock(&vsock->dev.mutex);
		return r;
	}
}
static const struct file_operations vhost_vsock_fops = {
	.owner          = THIS_MODULE,
	.open           = vhost_vsock_dev_open,
	.release        = vhost_vsock_dev_release,
	.llseek         = noop_llseek,
	.unlocked_ioctl = vhost_vsock_dev_ioctl,
};
static struct miscdevice vhost_vsock_misc = {
	.minor = MISC_DYNAMIC_MINOR,
	.name = "vhost-vsock",
	.fops = &vhost_vsock_fops,
};
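
/*
 * Sketch of how a VMM might drive this device from userspace, assuming the
 * standard vhost setup sequence (illustration only, not part of this driver):
 *
 *	int fd = open("/dev/vhost-vsock", O_RDWR);
 *	uint32_t cid = 3;			// any CID > VMADDR_CID_HOST
 *
 *	ioctl(fd, VHOST_SET_OWNER, NULL);
 *	ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &cid);
 *	// ...then VHOST_SET_FEATURES and the VHOST_SET_VRING_* calls per vq
 */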
static int
vhost_transport_socket_init(struct vsock_sock *vsk, struct vsock_sock *psk)
{
	struct virtio_transport *trans;
	int ret;

	ret = virtio_transport_do_socket_init(vsk, psk);
	if (ret)
		return ret;

	trans = vsk->trans;
	trans->ops = &vhost_ops;

	return ret;
}
static struct vsock_transport vhost_transport = {
	.get_local_cid            = vhost_transport_get_local_cid,

	.init                     = vhost_transport_socket_init,
	.destruct                 = virtio_transport_destruct,
	.release                  = virtio_transport_release,
	.connect                  = virtio_transport_connect,
	.shutdown                 = virtio_transport_shutdown,

	.dgram_enqueue            = virtio_transport_dgram_enqueue,
	.dgram_dequeue            = virtio_transport_dgram_dequeue,
	.dgram_bind               = virtio_transport_dgram_bind,
	.dgram_allow              = virtio_transport_dgram_allow,

	.stream_enqueue           = virtio_transport_stream_enqueue,
	.stream_dequeue           = virtio_transport_stream_dequeue,
	.stream_has_data          = virtio_transport_stream_has_data,
	.stream_has_space         = virtio_transport_stream_has_space,
	.stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
	.stream_is_active         = virtio_transport_stream_is_active,
	.stream_allow             = virtio_transport_stream_allow,

	.notify_poll_in           = virtio_transport_notify_poll_in,
	.notify_poll_out          = virtio_transport_notify_poll_out,
	.notify_recv_init         = virtio_transport_notify_recv_init,
	.notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
	.notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
	.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
	.notify_send_init         = virtio_transport_notify_send_init,
	.notify_send_pre_block    = virtio_transport_notify_send_pre_block,
	.notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
	.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

	.set_buffer_size          = virtio_transport_set_buffer_size,
	.set_min_buffer_size      = virtio_transport_set_min_buffer_size,
	.set_max_buffer_size      = virtio_transport_set_max_buffer_size,
	.get_buffer_size          = virtio_transport_get_buffer_size,
	.get_min_buffer_size      = virtio_transport_get_min_buffer_size,
	.get_max_buffer_size      = virtio_transport_get_max_buffer_size,
};
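
/*
 * Register with the vsock core before exposing the misc device, so the
 * transport is usable as soon as userspace can open /dev/vhost-vsock.
 */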
static int __init vhost_vsock_init(void)
{
	int ret;

	ret = vsock_core_init(&vhost_transport);
	if (ret < 0)
		return ret;
	return misc_register(&vhost_vsock_misc);
}
static void __exit vhost_vsock_exit(void)
{
	misc_deregister(&vhost_vsock_misc);
	vsock_core_exit();
}
module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");