2 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation; either version 2 of the License, or
7 * (at your option) any later version.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
15 #include <linux/module.h>
16 #include <linux/net.h>
17 #include <linux/socket.h>
21 #include <xen/events.h>
22 #include <xen/grant_table.h>
24 #include <xen/xenbus.h>
25 #include <xen/interface/io/pvcalls.h>
27 #include "pvcalls-front.h"
/*
 * Driver-wide constants.
 * PVCALLS_INVALID_ID: marks a bedata->rsp[] slot as free; get_request()
 *   refuses a slot whose req_id holds any other value.
 * PVCALLS_RING_ORDER: order of the per-socket data rings; the grant loops
 *   cover (1 << PVCALLS_RING_ORDER) pages and ring->ring_order is set to it.
 * PVCALLS_NR_RSP_PER_RING: number of entries of a one-page command ring,
 *   used to size bedata->rsp[] (one response slot per request id).
 * PVCALLS_FRONT_MAX_SPIN: cap on write attempts in pvcalls_front_sendmsg().
 */
29 #define PVCALLS_INVALID_ID UINT_MAX
30 #define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
31 #define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
32 #define PVCALLS_FRONT_MAX_SPIN 5000
/*
 * Protocol descriptor handed to sk_alloc() when pvcalls_front_accept()
 * creates the struct sock of an accepted connection; each object is a
 * plain struct sock.
 */
34 static struct proto pvcalls_proto
= {
37 .obj_size
= sizeof(struct sock
),
/*
 * Per-device frontend state. pvcalls_front_probe() allocates it and stores
 * it as the xenbus device's drvdata.
 */
40 struct pvcalls_bedata
{
/* Command ring shared with the backend. */
41 struct xen_pvcalls_front_ring ring
;
/* Every sock_mapping created by this frontend; updated under socket_lock. */
45 struct list_head socket_mappings
;
/*
 * Guards socket_mappings and is also held by every command-ring producer
 * around get_request() and RING_PUSH_REQUESTS_AND_CHECK_NOTIFY().
 */
46 spinlock_t socket_lock
;
/* Woken by pvcalls_front_event_handler() after it consumes responses. */
48 wait_queue_head_t inflight_req
;
/*
 * Copies of backend responses, indexed by req_id. A slot is free while its
 * req_id is PVCALLS_INVALID_ID; the event handler writes req_id last to
 * publish a response.
 */
49 struct xen_pvcalls_response rsp
[PVCALLS_NR_RSP_PER_RING
];
/*
 * pvcalls_front_dev is the connected xenbus device; it is NULL when there is
 * none, and pvcalls_front_remove() clears it. pvcalls_refcount counts pvcalls
 * operations in progress; pvcalls_front_remove() waits for it to reach zero
 * before freeing mappings.
 */
51 /* Only one front/back connection supported. */
52 static struct xenbus_device
*pvcalls_front_dev
;
53 static atomic_t pvcalls_refcount
;
55 /* first increment refcount, then proceed */
56 #define pvcalls_enter() { \
57 atomic_inc(&pvcalls_refcount); \
60 /* first complete other operations, then decrement refcount */
61 #define pvcalls_exit() { \
62 atomic_dec(&pvcalls_refcount); \
/*
 * Members of struct sock_mapping, the per-socket frontend state. A socket
 * reaches its mapping through sock->sk->sk_send_head.
 */
/* Link in bedata->socket_mappings. */
67 struct list_head list
;
/* Used through map->active.* by connected (active) sockets. */
74 struct pvcalls_data_intf
*ring
;
75 struct pvcalls_data data
;
/* in_mutex serializes recvmsg; out_mutex serializes sendmsg. */
76 struct mutex in_mutex
;
77 struct mutex out_mutex
;
/* Woken by pvcalls_front_conn_handler(); recvmsg and poll wait here. */
79 wait_queue_head_t inflight_conn_req
;
/*
 * Used through map->passive.* by bound/listening sockets.
 * NOTE(review): "UNINITALIZED" below is misspelled; a rename must update
 * every user of the macro.
 */
83 * Socket status, needs to be 64-bit aligned due to the
84 * test_and_* functions which have this requirement on arm64.
86 #define PVCALLS_STATUS_UNINITALIZED 0
87 #define PVCALLS_STATUS_BIND 1
88 #define PVCALLS_STATUS_LISTEN 2
89 uint8_t status
__attribute__((aligned(8)));
91 * Internal state-machine flags.
92 * Only one accept operation can be inflight for a socket.
93 * Only one poll operation can be inflight for a given socket.
94 * flags needs to be 64-bit aligned due to the test_and_*
95 * functions which have this requirement on arm64.
97 #define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
98 #define PVCALLS_FLAG_POLL_INFLIGHT 1
99 #define PVCALLS_FLAG_POLL_RET 2
100 uint8_t flags
__attribute__((aligned(8)));
/* req_id of the in-flight PVCALLS_ACCEPT, or PVCALLS_INVALID_ID once done. */
101 uint32_t inflight_req_id
;
/* Mapping prepared for the connection of the pending accept. */
102 struct sock_mapping
*accept_map
;
/* Accept callers wait here for PVCALLS_FLAG_ACCEPT_INFLIGHT to clear. */
103 wait_queue_head_t inflight_accept_req
;
/*
 * pvcalls_enter_sock - find and pin the sock_mapping attached to @sock.
 *
 * Return: ERR_PTR(-ENOTCONN) when there is no frontend device or it has no
 * drvdata, and ERR_PTR(-ENOTSOCK) on the path below (presumably when
 * sk_send_head holds no mapping). Otherwise it returns the mapping with
 * map->refcount incremented. Pair every successful call with
 * pvcalls_exit_sock().
 */
108 static inline struct sock_mapping
*pvcalls_enter_sock(struct socket
*sock
)
110 struct sock_mapping
*map
;
112 if (!pvcalls_front_dev
||
113 dev_get_drvdata(&pvcalls_front_dev
->dev
) == NULL
)
114 return ERR_PTR(-ENOTCONN
);
/* The mapping pointer is stashed in sk_send_head (socket and accept paths). */
116 map
= (struct sock_mapping
*)sock
->sk
->sk_send_head
;
118 return ERR_PTR(-ENOTSOCK
);
/* Pin the mapping; pvcalls_front_release() waits for other pins to drop. */
121 atomic_inc(&map
->refcount
);
/*
 * pvcalls_exit_sock - drop the reference taken by pvcalls_enter_sock().
 *
 * The mapping is re-read from sk_send_head, so that field must still point
 * at the mapping that was pinned.
 */
125 static inline void pvcalls_exit_sock(struct socket
*sock
)
127 struct sock_mapping
*map
;
129 map
= (struct sock_mapping
*)sock
->sk
->sk_send_head
;
130 atomic_dec(&map
->refcount
);
/*
 * get_request - choose the command-ring slot for the next request.
 *
 * *req_id is req_prod_pvt masked by the ring size; the mask relies on
 * RING_SIZE() being a power of two. The request is refused (error return)
 * when the ring is full or bedata->rsp[*req_id] still holds a response
 * nobody has consumed. Callers hold bedata->socket_lock.
 */
134 static inline int get_request(struct pvcalls_bedata
*bedata
, int *req_id
)
136 *req_id
= bedata
->ring
.req_prod_pvt
& (RING_SIZE(&bedata
->ring
) - 1);
137 if (RING_FULL(&bedata
->ring
) ||
138 bedata
->rsp
[*req_id
].req_id
!= PVCALLS_INVALID_ID
)
/*
 * pvcalls_front_write_todo - can a writer make progress on @map's out ring?
 *
 * out_error is checked first; an -ENOTCONN error short-circuits the space
 * check. Otherwise it returns true when the out ring is not completely full.
 */
143 static bool pvcalls_front_write_todo(struct sock_mapping
*map
)
145 struct pvcalls_data_intf
*intf
= map
->active
.ring
;
146 RING_IDX cons
, prod
, size
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
149 error
= intf
->out_error
;
150 if (error
== -ENOTCONN
)
155 cons
= intf
->out_cons
;
156 prod
= intf
->out_prod
;
157 return !!(size
- pvcalls_queued(prod
, cons
, size
));
/*
 * pvcalls_front_read_todo - should a reader of @map be woken?
 *
 * True when the in ring has bytes queued or in_error is non-zero.
 */
160 static bool pvcalls_front_read_todo(struct sock_mapping
*map
)
162 struct pvcalls_data_intf
*intf
= map
->active
.ring
;
166 cons
= intf
->in_cons
;
167 prod
= intf
->in_prod
;
168 error
= intf
->in_error
;
169 return (error
!= 0 ||
170 pvcalls_queued(prod
, cons
,
171 XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
)) != 0);
/*
 * pvcalls_front_event_handler - command-ring interrupt handler.
 *
 * Drains every unconsumed response. A PVCALLS_POLL response updates the
 * target mapping's passive flags: POLL_INFLIGHT is cleared, then POLL_RET
 * is set. A response stored in bedata->rsp[req_id] is written body first
 * and req_id last, so a waiter that sees a matching req_id sees the whole
 * response. After a final re-check of the ring, inflight_req waiters are
 * woken.
 *
 * NOTE(review): req_id and the poll mapping id come straight from the
 * backend's response. They are used as an array index and as a pointer,
 * and no range check is apparent. Confirm the backend is trusted here, or
 * bound req_id by PVCALLS_NR_RSP_PER_RING.
 */
174 static irqreturn_t
pvcalls_front_event_handler(int irq
, void *dev_id
)
176 struct xenbus_device
*dev
= dev_id
;
177 struct pvcalls_bedata
*bedata
;
178 struct xen_pvcalls_response
*rsp
;
180 int req_id
= 0, more
= 0, done
= 0;
186 bedata
= dev_get_drvdata(&dev
->dev
);
187 if (bedata
== NULL
) {
193 while (RING_HAS_UNCONSUMED_RESPONSES(&bedata
->ring
)) {
194 rsp
= RING_GET_RESPONSE(&bedata
->ring
, bedata
->ring
.rsp_cons
);
/* req_id doubles as the index into bedata->rsp[]. */
196 req_id
= rsp
->req_id
;
197 if (rsp
->cmd
== PVCALLS_POLL
) {
198 struct sock_mapping
*map
= (struct sock_mapping
*)(uintptr_t)
201 clear_bit(PVCALLS_FLAG_POLL_INFLIGHT
,
202 (void *)&map
->passive
.flags
);
204 * clear INFLIGHT, then set RET. It pairs with
205 * the checks at the beginning of
206 * pvcalls_front_poll_passive.
209 set_bit(PVCALLS_FLAG_POLL_RET
,
210 (void *)&map
->passive
.flags
);
/* Copy everything after the req_id field first. */
212 dst
= (uint8_t *)&bedata
->rsp
[req_id
] +
214 src
= (uint8_t *)rsp
+ sizeof(rsp
->req_id
);
215 memcpy(dst
, src
, sizeof(*rsp
) - sizeof(rsp
->req_id
));
217 * First copy the rest of the data, then req_id. It is
218 * paired with the barrier when accessing bedata->rsp.
221 bedata
->rsp
[req_id
].req_id
= req_id
;
225 bedata
->ring
.rsp_cons
++;
228 RING_FINAL_CHECK_FOR_RESPONSES(&bedata
->ring
, more
);
232 wake_up(&bedata
->inflight_req
);
/*
 * pvcalls_front_free_map - tear down an active socket mapping.
 *
 * Unbinds the per-socket IRQ and unlinks @map from bedata->socket_mappings
 * under socket_lock, if it is still linked. It then ends foreign access to
 * every data-ring grant and to the interface-page grant, and frees the
 * interface page.
 *
 * NOTE(review): only the interface page is passed to free_page() here.
 * Confirm that the data pages at map->active.data.in are released somewhere;
 * they are allocated by alloc_active_ring() and freed by free_active_ring().
 * Otherwise they leak.
 *
 * NOTE(review): gnttab_end_foreign_access(ref, 0, 0) is followed directly
 * by free_page(). Confirm the backend can no longer access the page when it
 * is freed (cf. XSA-396).
 */
237 static void pvcalls_front_free_map(struct pvcalls_bedata
*bedata
,
238 struct sock_mapping
*map
)
242 unbind_from_irqhandler(map
->active
.irq
, map
);
244 spin_lock(&bedata
->socket_lock
);
245 if (!list_empty(&map
->list
))
246 list_del_init(&map
->list
);
247 spin_unlock(&bedata
->socket_lock
);
/* Revoke the backend's access to each data page, then the interface page. */
249 for (i
= 0; i
< (1 << PVCALLS_RING_ORDER
); i
++)
250 gnttab_end_foreign_access(map
->active
.ring
->ref
[i
], 0, 0);
251 gnttab_end_foreign_access(map
->active
.ref
, 0, 0);
252 free_page((unsigned long)map
->active
.ring
);
/*
 * pvcalls_front_conn_handler - per-socket event-channel interrupt handler.
 *
 * Wakes everything waiting on map->active.inflight_conn_req, which includes
 * recvmsg and poll on active sockets.
 */
257 static irqreturn_t
pvcalls_front_conn_handler(int irq
, void *sock_map
)
259 struct sock_mapping
*map
= sock_map
;
264 wake_up_interruptible(&map
->active
.inflight_conn_req
);
/*
 * pvcalls_front_socket - create the frontend state for a new socket.
 *
 * Only the socket type is checked here; see the comment below about the
 * other checks. The function allocates a sock_mapping, stores it in
 * sock->sk->sk_send_head and links it on bedata->socket_mappings. It then
 * sends PVCALLS_SOCKET (AF_INET, SOCK_STREAM, IPPROTO_IP), using the
 * mapping pointer as the id. It sleeps uninterruptibly until the backend
 * answers, then reads the backend's ret value.
 */
269 int pvcalls_front_socket(struct socket
*sock
)
271 struct pvcalls_bedata
*bedata
;
272 struct sock_mapping
*map
= NULL
;
273 struct xen_pvcalls_request
*req
;
274 int notify
, req_id
, ret
;
277 * PVCalls only supports domain AF_INET,
278 * type SOCK_STREAM and protocol 0 sockets for now.
280 * Check socket type here, AF_INET and protocol checks are done
283 if (sock
->type
!= SOCK_STREAM
)
287 if (!pvcalls_front_dev
) {
291 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
293 map
= kzalloc(sizeof(*map
), GFP_KERNEL
);
/* socket_lock covers slot reservation, the list update and the ring push. */
299 spin_lock(&bedata
->socket_lock
);
301 ret
= get_request(bedata
, &req_id
);
304 spin_unlock(&bedata
->socket_lock
);
310 * sock->sk->sk_send_head is not used for ip sockets: reuse the
311 * field to store a pointer to the struct sock_mapping
312 * corresponding to the socket. This way, we can easily get the
313 * struct sock_mapping from the struct socket.
315 sock
->sk
->sk_send_head
= (void *)map
;
316 list_add_tail(&map
->list
, &bedata
->socket_mappings
);
/* Fill in and publish the PVCALLS_SOCKET request. */
318 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
319 req
->req_id
= req_id
;
320 req
->cmd
= PVCALLS_SOCKET
;
321 req
->u
.socket
.id
= (uintptr_t) map
;
322 req
->u
.socket
.domain
= AF_INET
;
323 req
->u
.socket
.type
= SOCK_STREAM
;
324 req
->u
.socket
.protocol
= IPPROTO_IP
;
326 bedata
->ring
.req_prod_pvt
++;
327 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
328 spin_unlock(&bedata
->socket_lock
);
330 notify_remote_via_irq(bedata
->irq
);
/* Uninterruptible wait until the event handler publishes rsp[req_id]. */
332 wait_event(bedata
->inflight_req
,
333 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
335 /* read req_id, then the content */
337 ret
= bedata
->rsp
[req_id
].ret
;
/* Mark the rsp[] slot free so get_request() can reuse it. */
338 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
/*
 * free_active_ring - free the pages allocated by alloc_active_ring().
 *
 * The NULL check below presumably returns early when the interface page was
 * never allocated. data.out points into the same allocation as data.in, so
 * only data.in is freed, using the order recorded in ring->ring_order.
 */
344 static void free_active_ring(struct sock_mapping
*map
)
346 if (!map
->active
.ring
)
349 free_pages((unsigned long)map
->active
.data
.in
,
350 map
->active
.ring
->ring_order
);
351 free_page((unsigned long)map
->active
.ring
);
/*
 * alloc_active_ring - allocate the shared pages of an active socket.
 *
 * Allocates a zeroed interface page, recording PVCALLS_RING_ORDER in
 * ring_order, and a zeroed data area. data.in is the start of the data area;
 * data.out begins XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER) bytes later. The
 * failure path at the end releases partial allocations through
 * free_active_ring(). Callers run this before taking bedata->socket_lock,
 * because these allocations use GFP_KERNEL.
 */
354 static int alloc_active_ring(struct sock_mapping
*map
)
358 map
->active
.ring
= (struct pvcalls_data_intf
*)
359 get_zeroed_page(GFP_KERNEL
);
360 if (!map
->active
.ring
)
363 map
->active
.ring
->ring_order
= PVCALLS_RING_ORDER
;
364 bytes
= (void *)__get_free_pages(GFP_KERNEL
| __GFP_ZERO
,
369 map
->active
.data
.in
= bytes
;
370 map
->active
.data
.out
= bytes
+
371 XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
376 free_active_ring(map
);
/*
 * create_active - share an active socket's rings with the backend.
 *
 * Grants the backend access to each data page and to the interface page
 * (map->active.ref). It allocates an event channel into *evtchn and binds
 * pvcalls_front_conn_handler() to it. It then records the IRQ, marks the
 * map active and initializes the in/out mutexes. The error path frees the
 * event channel.
 *
 * NOTE(review): the gnttab_grant_foreign_access() results are stored
 * without checking for a negative error (grant references exhausted).
 *
 * NOTE(review): both callers invoke this with bedata->socket_lock held,
 * and that lock is a spinlock. bind_evtchn_to_irqhandler() may sleep,
 * because request_irq allocates with GFP_KERNEL. Consider creating the
 * channel before the lock is taken.
 */
380 static int create_active(struct sock_mapping
*map
, int *evtchn
)
383 int ret
= -ENOMEM
, irq
= -1, i
;
386 init_waitqueue_head(&map
->active
.inflight_conn_req
);
/* One grant per data page, then one for the interface page. */
388 bytes
= map
->active
.data
.in
;
389 for (i
= 0; i
< (1 << PVCALLS_RING_ORDER
); i
++)
390 map
->active
.ring
->ref
[i
] = gnttab_grant_foreign_access(
391 pvcalls_front_dev
->otherend_id
,
392 pfn_to_gfn(virt_to_pfn(bytes
) + i
), 0);
394 map
->active
.ref
= gnttab_grant_foreign_access(
395 pvcalls_front_dev
->otherend_id
,
396 pfn_to_gfn(virt_to_pfn((void *)map
->active
.ring
)), 0);
398 ret
= xenbus_alloc_evtchn(pvcalls_front_dev
, evtchn
);
401 irq
= bind_evtchn_to_irqhandler(*evtchn
, pvcalls_front_conn_handler
,
402 0, "pvcalls-frontend", map
);
408 map
->active
.irq
= irq
;
409 map
->active_socket
= true;
410 mutex_init(&map
->active
.in_mutex
);
411 mutex_init(&map
->active
.out_mutex
);
417 xenbus_free_evtchn(pvcalls_front_dev
, *evtchn
);
/*
 * pvcalls_front_connect - connect @sock to @addr through the backend.
 *
 * Only AF_INET / SOCK_STREAM is accepted. The function allocates the active
 * ring. Then, under socket_lock, it reserves a request slot, creates the
 * active channel and sends PVCALLS_CONNECT with the interface-page grant
 * and the event channel. It sleeps uninterruptibly for the backend's answer
 * and reads its ret value.
 *
 * NOTE(review): create_active() runs with socket_lock held, and that lock
 * is a spinlock. create_active() calls bind_evtchn_to_irqhandler(), which
 * may sleep.
 */
421 int pvcalls_front_connect(struct socket
*sock
, struct sockaddr
*addr
,
422 int addr_len
, int flags
)
424 struct pvcalls_bedata
*bedata
;
425 struct sock_mapping
*map
= NULL
;
426 struct xen_pvcalls_request
*req
;
427 int notify
, req_id
, ret
, evtchn
;
429 if (addr
->sa_family
!= AF_INET
|| sock
->type
!= SOCK_STREAM
)
432 map
= pvcalls_enter_sock(sock
);
436 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
/* Page allocations happen before socket_lock is taken. */
437 ret
= alloc_active_ring(map
);
439 pvcalls_exit_sock(sock
);
/* socket_lock stays held through create_active() until the ring push. */
443 spin_lock(&bedata
->socket_lock
);
444 ret
= get_request(bedata
, &req_id
);
446 spin_unlock(&bedata
->socket_lock
);
447 free_active_ring(map
);
448 pvcalls_exit_sock(sock
);
451 ret
= create_active(map
, &evtchn
);
453 spin_unlock(&bedata
->socket_lock
);
454 free_active_ring(map
);
455 pvcalls_exit_sock(sock
);
459 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
460 req
->req_id
= req_id
;
461 req
->cmd
= PVCALLS_CONNECT
;
462 req
->u
.connect
.id
= (uintptr_t)map
;
463 req
->u
.connect
.len
= addr_len
;
464 req
->u
.connect
.flags
= flags
;
465 req
->u
.connect
.ref
= map
->active
.ref
;
466 req
->u
.connect
.evtchn
= evtchn
;
467 memcpy(req
->u
.connect
.addr
, addr
, sizeof(*addr
));
471 bedata
->ring
.req_prod_pvt
++;
472 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
473 spin_unlock(&bedata
->socket_lock
);
476 notify_remote_via_irq(bedata
->irq
);
478 wait_event(bedata
->inflight_req
,
479 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
481 /* read req_id, then the content */
483 ret
= bedata
->rsp
[req_id
].ret
;
484 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
485 pvcalls_exit_sock(sock
);
/*
 * __write_ring - copy data from @msg_iter into the out ring.
 *
 * Reads out_error and the out indexes before touching the data area.
 * Inconsistent indexes (queued > ring size) and a full ring are handled
 * before any copy. The copy length is clamped to the free space. The copy
 * may wrap around the end of the data array, in which case it takes two
 * copy_from_iter() calls. out_prod is advanced by the number of bytes
 * actually copied, and only after the data has been written to the ring.
 */
489 static int __write_ring(struct pvcalls_data_intf
*intf
,
490 struct pvcalls_data
*data
,
491 struct iov_iter
*msg_iter
,
494 RING_IDX cons
, prod
, size
, masked_prod
, masked_cons
;
495 RING_IDX array_size
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
498 error
= intf
->out_error
;
501 cons
= intf
->out_cons
;
502 prod
= intf
->out_prod
;
503 /* read indexes before continuing */
506 size
= pvcalls_queued(prod
, cons
, array_size
);
507 if (size
> array_size
)
509 if (size
== array_size
)
511 if (len
> array_size
- size
)
512 len
= array_size
- size
;
514 masked_prod
= pvcalls_mask(prod
, array_size
);
515 masked_cons
= pvcalls_mask(cons
, array_size
);
/* Free space is contiguous when the producer is behind the consumer. */
517 if (masked_prod
< masked_cons
) {
518 len
= copy_from_iter(data
->out
+ masked_prod
, len
, msg_iter
);
520 if (len
> array_size
- masked_prod
) {
521 int ret
= copy_from_iter(data
->out
+ masked_prod
,
522 array_size
- masked_prod
, msg_iter
);
523 if (ret
!= array_size
- masked_prod
) {
527 len
= ret
+ copy_from_iter(data
->out
, len
- ret
, msg_iter
);
529 len
= copy_from_iter(data
->out
+ masked_prod
, len
, msg_iter
);
533 /* write to ring before updating pointer */
535 intf
->out_prod
+= len
;
/*
 * pvcalls_front_sendmsg - send data on an active socket.
 *
 * MSG_CONFIRM, MSG_DONTROUTE, MSG_EOR and MSG_OOB are rejected. Writers are
 * serialized by out_mutex. With MSG_DONTWAIT and no room in the ring, the
 * call bails out early. Otherwise it writes the data with __write_ring() and
 * notifies the backend. The write is retried while data remains, as long as
 * fewer than PVCALLS_FRONT_MAX_SPIN attempts have been made.
 */
540 int pvcalls_front_sendmsg(struct socket
*sock
, struct msghdr
*msg
,
543 struct pvcalls_bedata
*bedata
;
544 struct sock_mapping
*map
;
545 int sent
, tot_sent
= 0;
546 int count
= 0, flags
;
548 flags
= msg
->msg_flags
;
549 if (flags
& (MSG_CONFIRM
|MSG_DONTROUTE
|MSG_EOR
|MSG_OOB
))
552 map
= pvcalls_enter_sock(sock
);
555 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
557 mutex_lock(&map
->active
.out_mutex
);
558 if ((flags
& MSG_DONTWAIT
) && !pvcalls_front_write_todo(map
)) {
559 mutex_unlock(&map
->active
.out_mutex
);
560 pvcalls_exit_sock(sock
);
568 sent
= __write_ring(map
->active
.ring
,
569 &map
->active
.data
, &msg
->msg_iter
,
574 notify_remote_via_irq(map
->active
.irq
);
/* Bounded retry while data remains to be written. */
576 if (sent
>= 0 && len
> 0 && count
< PVCALLS_FRONT_MAX_SPIN
)
581 mutex_unlock(&map
->active
.out_mutex
);
582 pvcalls_exit_sock(sock
);
/*
 * __read_ring - copy queued data from the in ring into @msg_iter.
 *
 * Snapshots in_cons, in_prod and in_error before touching the data area.
 * On the early path it returns "error ?: size", presumably when nothing is
 * queued. The copy may wrap around the end of the data array. Unless
 * MSG_PEEK is set, in_cons is advanced by the number of bytes copied.
 */
586 static int __read_ring(struct pvcalls_data_intf
*intf
,
587 struct pvcalls_data
*data
,
588 struct iov_iter
*msg_iter
,
589 size_t len
, int flags
)
591 RING_IDX cons
, prod
, size
, masked_prod
, masked_cons
;
592 RING_IDX array_size
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
595 cons
= intf
->in_cons
;
596 prod
= intf
->in_prod
;
597 error
= intf
->in_error
;
598 /* get pointers before reading from the ring */
601 size
= pvcalls_queued(prod
, cons
, array_size
);
602 masked_prod
= pvcalls_mask(prod
, array_size
);
603 masked_cons
= pvcalls_mask(cons
, array_size
);
606 return error
?: size
;
/* Queued data is contiguous when the producer is ahead of the consumer. */
611 if (masked_prod
> masked_cons
) {
612 len
= copy_to_iter(data
->in
+ masked_cons
, len
, msg_iter
);
614 if (len
> (array_size
- masked_cons
)) {
615 int ret
= copy_to_iter(data
->in
+ masked_cons
,
616 array_size
- masked_cons
, msg_iter
);
617 if (ret
!= array_size
- masked_cons
) {
621 len
= ret
+ copy_to_iter(data
->in
, len
- ret
, msg_iter
);
623 len
= copy_to_iter(data
->in
+ masked_cons
, len
, msg_iter
);
627 /* read data from the ring before increasing the index */
629 if (!(flags
& MSG_PEEK
))
630 intf
->in_cons
+= len
;
/*
 * pvcalls_front_recvmsg - receive data from an active socket.
 *
 * MSG_CMSG_CLOEXEC, MSG_ERRQUEUE, MSG_OOB and MSG_TRUNC are rejected.
 * Readers are serialized by in_mutex, and @len is clamped to the ring size.
 * Without MSG_DONTWAIT the caller waits on inflight_conn_req until
 * pvcalls_front_read_todo() is true. After __read_ring(), the backend is
 * notified, presumably only when data was consumed. An empty read becomes
 * -EAGAIN with MSG_DONTWAIT and 0 otherwise, and -ENOTCONN is special-cased.
 *
 * NOTE(review): the return value of wait_event_interruptible() is ignored.
 * With a signal pending it returns immediately, so the while loop can spin
 * while holding in_mutex until data or an error arrives. Consider bailing
 * out with -ERESTARTSYS instead.
 */
635 int pvcalls_front_recvmsg(struct socket
*sock
, struct msghdr
*msg
, size_t len
,
638 struct pvcalls_bedata
*bedata
;
640 struct sock_mapping
*map
;
642 if (flags
& (MSG_CMSG_CLOEXEC
|MSG_ERRQUEUE
|MSG_OOB
|MSG_TRUNC
))
645 map
= pvcalls_enter_sock(sock
);
648 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
650 mutex_lock(&map
->active
.in_mutex
);
651 if (len
> XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
))
652 len
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
654 while (!(flags
& MSG_DONTWAIT
) && !pvcalls_front_read_todo(map
)) {
655 wait_event_interruptible(map
->active
.inflight_conn_req
,
656 pvcalls_front_read_todo(map
));
658 ret
= __read_ring(map
->active
.ring
, &map
->active
.data
,
659 &msg
->msg_iter
, len
, flags
);
662 notify_remote_via_irq(map
->active
.irq
);
664 ret
= (flags
& MSG_DONTWAIT
) ? -EAGAIN
: 0;
665 if (ret
== -ENOTCONN
)
668 mutex_unlock(&map
->active
.in_mutex
);
669 pvcalls_exit_sock(sock
);
/*
 * pvcalls_front_bind - bind @sock to @addr through the backend.
 *
 * Only AF_INET / SOCK_STREAM is accepted. The function sends PVCALLS_BIND
 * with the address and turns the mapping into a passive socket: it sets
 * active_socket = false and initializes the accept wait queue. It then
 * waits uninterruptibly for the answer and moves passive.status to
 * PVCALLS_STATUS_BIND.
 *
 * NOTE(review): passive.status appears to be set regardless of the
 * backend's ret. pvcalls_front_listen() only checks status, so confirm a
 * failed bind should still advance the state.
 */
673 int pvcalls_front_bind(struct socket
*sock
, struct sockaddr
*addr
, int addr_len
)
675 struct pvcalls_bedata
*bedata
;
676 struct sock_mapping
*map
= NULL
;
677 struct xen_pvcalls_request
*req
;
678 int notify
, req_id
, ret
;
680 if (addr
->sa_family
!= AF_INET
|| sock
->type
!= SOCK_STREAM
)
683 map
= pvcalls_enter_sock(sock
);
686 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
688 spin_lock(&bedata
->socket_lock
);
689 ret
= get_request(bedata
, &req_id
);
691 spin_unlock(&bedata
->socket_lock
);
692 pvcalls_exit_sock(sock
);
695 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
696 req
->req_id
= req_id
;
698 req
->cmd
= PVCALLS_BIND
;
699 req
->u
.bind
.id
= (uintptr_t)map
;
700 memcpy(req
->u
.bind
.addr
, addr
, sizeof(*addr
));
701 req
->u
.bind
.len
= addr_len
;
/* From here on the mapping is used through map->passive. */
703 init_waitqueue_head(&map
->passive
.inflight_accept_req
);
705 map
->active_socket
= false;
707 bedata
->ring
.req_prod_pvt
++;
708 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
709 spin_unlock(&bedata
->socket_lock
);
711 notify_remote_via_irq(bedata
->irq
);
713 wait_event(bedata
->inflight_req
,
714 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
716 /* read req_id, then the content */
718 ret
= bedata
->rsp
[req_id
].ret
;
719 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
721 map
->passive
.status
= PVCALLS_STATUS_BIND
;
722 pvcalls_exit_sock(sock
);
/*
 * pvcalls_front_listen - start listening on a bound socket.
 *
 * Requires passive.status == PVCALLS_STATUS_BIND. The function sends
 * PVCALLS_LISTEN with @backlog, waits uninterruptibly for the answer and
 * moves passive.status to PVCALLS_STATUS_LISTEN.
 *
 * NOTE(review): as in pvcalls_front_bind(), the status change does not
 * appear to depend on the backend's ret.
 */
726 int pvcalls_front_listen(struct socket
*sock
, int backlog
)
728 struct pvcalls_bedata
*bedata
;
729 struct sock_mapping
*map
;
730 struct xen_pvcalls_request
*req
;
731 int notify
, req_id
, ret
;
733 map
= pvcalls_enter_sock(sock
);
736 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
738 if (map
->passive
.status
!= PVCALLS_STATUS_BIND
) {
739 pvcalls_exit_sock(sock
);
743 spin_lock(&bedata
->socket_lock
);
744 ret
= get_request(bedata
, &req_id
);
746 spin_unlock(&bedata
->socket_lock
);
747 pvcalls_exit_sock(sock
);
750 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
751 req
->req_id
= req_id
;
752 req
->cmd
= PVCALLS_LISTEN
;
753 req
->u
.listen
.id
= (uintptr_t) map
;
754 req
->u
.listen
.backlog
= backlog
;
756 bedata
->ring
.req_prod_pvt
++;
757 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
758 spin_unlock(&bedata
->socket_lock
);
760 notify_remote_via_irq(bedata
->irq
);
762 wait_event(bedata
->inflight_req
,
763 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
765 /* read req_id, then the content */
767 ret
= bedata
->rsp
[req_id
].ret
;
768 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
770 map
->passive
.status
= PVCALLS_STATUS_LISTEN
;
771 pvcalls_exit_sock(sock
);
/*
 * pvcalls_front_accept - accept a connection on a listening socket.
 *
 * Requires passive.status == PVCALLS_STATUS_LISTEN. The backend supports
 * one in-flight accept per socket, tracked by PVCALLS_FLAG_ACCEPT_INFLIGHT.
 * If an accept is already in flight and its response has arrived, the
 * pending accept_map is reused. Otherwise blocking callers wait on
 * inflight_accept_req until they can take the flag.
 *
 * Next, a new mapping (map2) and its active ring are set up. PVCALLS_ACCEPT
 * is sent with map2's ring grant and event channel, and inflight_req_id
 * records the request. Blocking callers wait interruptibly for the answer.
 * They then allocate newsock->sk, attach map2 to it, clear the in-flight
 * state and wake the other accept waiters.
 *
 * NOTE(review): create_active() runs with socket_lock held, and that lock
 * is a spinlock. create_active() may sleep in bind_evtchn_to_irqhandler().
 *
 * NOTE(review): the error paths that clear PVCALLS_FLAG_ACCEPT_INFLIGHT
 * appear not to call wake_up(&map->passive.inflight_accept_req). Concurrent
 * accept callers waiting for the flag may therefore not be woken.
 */
775 int pvcalls_front_accept(struct socket
*sock
, struct socket
*newsock
, int flags
)
777 struct pvcalls_bedata
*bedata
;
778 struct sock_mapping
*map
;
779 struct sock_mapping
*map2
= NULL
;
780 struct xen_pvcalls_request
*req
;
781 int notify
, req_id
, ret
, evtchn
, nonblock
;
783 map
= pvcalls_enter_sock(sock
);
786 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
788 if (map
->passive
.status
!= PVCALLS_STATUS_LISTEN
) {
789 pvcalls_exit_sock(sock
);
793 nonblock
= flags
& SOCK_NONBLOCK
;
795 * Backend only supports 1 inflight accept request, will return
796 * errors for the others
798 if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
799 (void *)&map
->passive
.flags
)) {
800 req_id
= READ_ONCE(map
->passive
.inflight_req_id
);
801 if (req_id
!= PVCALLS_INVALID_ID
&&
802 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
) {
803 map2
= map
->passive
.accept_map
;
807 pvcalls_exit_sock(sock
);
810 if (wait_event_interruptible(map
->passive
.inflight_accept_req
,
811 !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
812 (void *)&map
->passive
.flags
))) {
813 pvcalls_exit_sock(sock
);
/* Allocate the mapping for the connection being accepted. */
818 map2
= kzalloc(sizeof(*map2
), GFP_KERNEL
);
820 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
821 (void *)&map
->passive
.flags
);
822 pvcalls_exit_sock(sock
);
825 ret
= alloc_active_ring(map2
);
827 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
828 (void *)&map
->passive
.flags
);
830 pvcalls_exit_sock(sock
);
/* socket_lock stays held through create_active() until the ring push. */
833 spin_lock(&bedata
->socket_lock
);
834 ret
= get_request(bedata
, &req_id
);
836 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
837 (void *)&map
->passive
.flags
);
838 spin_unlock(&bedata
->socket_lock
);
839 free_active_ring(map2
);
841 pvcalls_exit_sock(sock
);
845 ret
= create_active(map2
, &evtchn
);
847 free_active_ring(map2
);
849 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
850 (void *)&map
->passive
.flags
);
851 spin_unlock(&bedata
->socket_lock
);
852 pvcalls_exit_sock(sock
);
855 list_add_tail(&map2
->list
, &bedata
->socket_mappings
);
857 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
858 req
->req_id
= req_id
;
859 req
->cmd
= PVCALLS_ACCEPT
;
860 req
->u
.accept
.id
= (uintptr_t) map
;
861 req
->u
.accept
.ref
= map2
->active
.ref
;
862 req
->u
.accept
.id_new
= (uintptr_t) map2
;
863 req
->u
.accept
.evtchn
= evtchn
;
864 map
->passive
.accept_map
= map2
;
866 bedata
->ring
.req_prod_pvt
++;
867 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
868 spin_unlock(&bedata
->socket_lock
);
870 notify_remote_via_irq(bedata
->irq
);
871 /* We could check if we have received a response before returning. */
/* Lets later accept/poll calls locate this request's response. */
873 WRITE_ONCE(map
->passive
.inflight_req_id
, req_id
);
874 pvcalls_exit_sock(sock
);
878 if (wait_event_interruptible(bedata
->inflight_req
,
879 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
)) {
880 pvcalls_exit_sock(sock
);
883 /* read req_id, then the content */
887 map2
->sock
= newsock
;
888 newsock
->sk
= sk_alloc(sock_net(sock
->sk
), PF_INET
, GFP_KERNEL
, &pvcalls_proto
, false);
890 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
891 map
->passive
.inflight_req_id
= PVCALLS_INVALID_ID
;
892 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
893 (void *)&map
->passive
.flags
);
894 pvcalls_front_free_map(bedata
, map2
);
895 pvcalls_exit_sock(sock
);
/* Attach the new mapping to the accepted socket. */
898 newsock
->sk
->sk_send_head
= (void *)map2
;
900 ret
= bedata
->rsp
[req_id
].ret
;
901 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
902 map
->passive
.inflight_req_id
= PVCALLS_INVALID_ID
;
904 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
, (void *)&map
->passive
.flags
);
905 wake_up(&map
->passive
.inflight_accept_req
);
907 pvcalls_exit_sock(sock
);
/*
 * pvcalls_front_poll_passive - poll a bound or listening socket.
 *
 * While an accept is in flight, it reports EPOLLIN | EPOLLRDNORM once that
 * accept's response has arrived, and otherwise registers on
 * inflight_accept_req. With no accept in flight:
 *  - a POLL_RET left by the event handler is consumed and reported as
 *    readable;
 *  - if a PVCALLS_POLL request is already in flight, it only registers on
 *    inflight_req;
 *  - otherwise it sends a new PVCALLS_POLL and registers on inflight_req.
 *
 * NOTE(review): the uint32_t req_id declared inside the first branch
 * shadows the function-scope int req_id.
 */
911 static __poll_t
pvcalls_front_poll_passive(struct file
*file
,
912 struct pvcalls_bedata
*bedata
,
913 struct sock_mapping
*map
,
916 int notify
, req_id
, ret
;
917 struct xen_pvcalls_request
*req
;
919 if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
920 (void *)&map
->passive
.flags
)) {
921 uint32_t req_id
= READ_ONCE(map
->passive
.inflight_req_id
);
923 if (req_id
!= PVCALLS_INVALID_ID
&&
924 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
)
925 return EPOLLIN
| EPOLLRDNORM
;
927 poll_wait(file
, &map
->passive
.inflight_accept_req
, wait
);
931 if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET
,
932 (void *)&map
->passive
.flags
))
933 return EPOLLIN
| EPOLLRDNORM
;
936 * First check RET, then INFLIGHT. No barriers necessary to
937 * ensure execution ordering because of the conditional
938 * instructions creating control dependencies.
941 if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT
,
942 (void *)&map
->passive
.flags
)) {
943 poll_wait(file
, &bedata
->inflight_req
, wait
);
947 spin_lock(&bedata
->socket_lock
);
948 ret
= get_request(bedata
, &req_id
);
950 spin_unlock(&bedata
->socket_lock
);
953 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
954 req
->req_id
= req_id
;
955 req
->cmd
= PVCALLS_POLL
;
956 req
->u
.poll
.id
= (uintptr_t) map
;
958 bedata
->ring
.req_prod_pvt
++;
959 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
960 spin_unlock(&bedata
->socket_lock
);
962 notify_remote_via_irq(bedata
->irq
);
964 poll_wait(file
, &bedata
->inflight_req
, wait
);
/*
 * pvcalls_front_poll_active - poll a connected socket.
 *
 * Registers on inflight_conn_req, then builds the mask:
 *  - EPOLLOUT | EPOLLWRNORM when the out ring has room;
 *  - EPOLLIN | EPOLLRDNORM when data or an in error is pending;
 *  - a further condition, presumably EPOLLERR, when in_error or out_error
 *    is non-zero.
 */
968 static __poll_t
pvcalls_front_poll_active(struct file
*file
,
969 struct pvcalls_bedata
*bedata
,
970 struct sock_mapping
*map
,
974 int32_t in_error
, out_error
;
975 struct pvcalls_data_intf
*intf
= map
->active
.ring
;
977 out_error
= intf
->out_error
;
978 in_error
= intf
->in_error
;
980 poll_wait(file
, &map
->active
.inflight_conn_req
, wait
);
981 if (pvcalls_front_write_todo(map
))
982 mask
|= EPOLLOUT
| EPOLLWRNORM
;
983 if (pvcalls_front_read_todo(map
))
984 mask
|= EPOLLIN
| EPOLLRDNORM
;
985 if (in_error
!= 0 || out_error
!= 0)
/*
 * pvcalls_front_poll - socket poll entry point.
 *
 * Pins the mapping, dispatches on map->active_socket to the active or the
 * passive helper, then drops the pin.
 */
991 __poll_t
pvcalls_front_poll(struct file
*file
, struct socket
*sock
,
994 struct pvcalls_bedata
*bedata
;
995 struct sock_mapping
*map
;
998 map
= pvcalls_enter_sock(sock
);
1001 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
1003 if (map
->active_socket
)
1004 ret
= pvcalls_front_poll_active(file
, bedata
, map
, wait
);
1006 ret
= pvcalls_front_poll_passive(file
, bedata
, map
, wait
);
1007 pvcalls_exit_sock(sock
);
/*
 * pvcalls_front_release - release @sock and its backend socket.
 *
 * Nothing to do when sock->sk is NULL, and an -ENOTCONN lookup (no
 * frontend) is handled specially. Otherwise sk_send_head is cleared so new
 * calls cannot find the mapping, and PVCALLS_RELEASE is sent and awaited
 * uninterruptibly.
 *
 * Active sockets then get in_error = -EBADF, which wakes recvmsg waiters.
 * The function loops until no other reference is held and frees the
 * mapping.
 *
 * Passive sockets wake the command and accept waiters and loop until no
 * other reference is held. The mapping is unlinked from socket_mappings,
 * and a pending accept_map is freed when an accept was in flight.
 *
 * Finally the rsp[] slot is released.
 *
 * NOTE(review): inflight_req_id starts at 0 because the mapping is
 * kzalloc'd, and 0 is also a valid req_id. The "!= 0" test therefore cannot
 * tell "never accepted" apart from "accept in flight with req_id 0"; the
 * latter would leak accept_map. Unless something sets inflight_req_id to
 * PVCALLS_INVALID_ID before the first accept, this ambiguity remains.
 */
1011 int pvcalls_front_release(struct socket
*sock
)
1013 struct pvcalls_bedata
*bedata
;
1014 struct sock_mapping
*map
;
1015 int req_id
, notify
, ret
;
1016 struct xen_pvcalls_request
*req
;
1018 if (sock
->sk
== NULL
)
1021 map
= pvcalls_enter_sock(sock
);
1023 if (PTR_ERR(map
) == -ENOTCONN
)
1028 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
1030 spin_lock(&bedata
->socket_lock
);
1031 ret
= get_request(bedata
, &req_id
);
1033 spin_unlock(&bedata
->socket_lock
);
1034 pvcalls_exit_sock(sock
);
/* From here on, pvcalls_enter_sock() can no longer find this mapping. */
1037 sock
->sk
->sk_send_head
= NULL
;
1039 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
1040 req
->req_id
= req_id
;
1041 req
->cmd
= PVCALLS_RELEASE
;
1042 req
->u
.release
.id
= (uintptr_t)map
;
1044 bedata
->ring
.req_prod_pvt
++;
1045 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
1046 spin_unlock(&bedata
->socket_lock
);
1048 notify_remote_via_irq(bedata
->irq
);
1050 wait_event(bedata
->inflight_req
,
1051 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
1053 if (map
->active_socket
) {
1055 * Set in_error and wake up inflight_conn_req to force
1056 * recvmsg waiters to exit.
1058 map
->active
.ring
->in_error
= -EBADF
;
1059 wake_up_interruptible(&map
->active
.inflight_conn_req
);
1062 * We need to make sure that sendmsg/recvmsg on this socket have
1063 * not started before we've cleared sk_send_head here. The
1064 * easiest way to guarantee this is to see that no pvcalls
1065 * (other than us) is in progress on this socket.
1067 while (atomic_read(&map
->refcount
) > 1)
1070 pvcalls_front_free_map(bedata
, map
);
1072 wake_up(&bedata
->inflight_req
);
1073 wake_up(&map
->passive
.inflight_accept_req
);
1075 while (atomic_read(&map
->refcount
) > 1)
1078 spin_lock(&bedata
->socket_lock
);
1079 list_del(&map
->list
);
1080 spin_unlock(&bedata
->socket_lock
);
1081 if (READ_ONCE(map
->passive
.inflight_req_id
) != PVCALLS_INVALID_ID
&&
1082 READ_ONCE(map
->passive
.inflight_req_id
) != 0) {
1083 pvcalls_front_free_map(bedata
,
1084 map
->passive
.accept_map
);
/* Give the rsp[] slot of the release request back to get_request(). */
1088 WRITE_ONCE(bedata
->rsp
[req_id
].req_id
, PVCALLS_INVALID_ID
);
/* xenbus device-id table referenced by pvcalls_front_driver.ids. */
1094 static const struct xenbus_device_id pvcalls_front_ids
[] = {
/*
 * pvcalls_front_remove - tear down the frontend connection.
 *
 * Performs these steps in order:
 *  - clears the drvdata and pvcalls_front_dev, so pvcalls_enter_sock()
 *    fails for new calls;
 *  - unbinds the command-ring IRQ;
 *  - detaches every mapping from its socket, erroring and waking the
 *    readers of active sockets;
 *  - waits for pvcalls_refcount to reach zero;
 *  - frees active mappings and unlinks the rest;
 *  - ends access to the ring grant, frees the shared ring and switches to
 *    XenbusStateClosed.
 * pvcalls_front_probe() also uses this as its error path.
 *
 * NOTE(review): bedata is fetched through pvcalls_front_dev rather than
 * @dev.
 *
 * NOTE(review): the shared ring comes from __get_free_page() in
 * pvcalls_front_probe() but is released with kfree(); free_page() is the
 * matching call.
 */
1099 static int pvcalls_front_remove(struct xenbus_device
*dev
)
1101 struct pvcalls_bedata
*bedata
;
1102 struct sock_mapping
*map
= NULL
, *n
;
1104 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
1105 dev_set_drvdata(&dev
->dev
, NULL
);
1106 pvcalls_front_dev
= NULL
;
1107 if (bedata
->irq
>= 0)
1108 unbind_from_irqhandler(bedata
->irq
, dev
);
/* Detach every socket from its mapping and kick blocked readers. */
1110 list_for_each_entry_safe(map
, n
, &bedata
->socket_mappings
, list
) {
1111 map
->sock
->sk
->sk_send_head
= NULL
;
1112 if (map
->active_socket
) {
1113 map
->active
.ring
->in_error
= -EBADF
;
1114 wake_up_interruptible(&map
->active
.inflight_conn_req
);
1119 while (atomic_read(&pvcalls_refcount
) > 0)
1121 list_for_each_entry_safe(map
, n
, &bedata
->socket_mappings
, list
) {
1122 if (map
->active_socket
) {
1123 /* No need to lock, refcount is 0 */
1124 pvcalls_front_free_map(bedata
, map
);
1126 list_del(&map
->list
);
1130 if (bedata
->ref
!= -1)
1131 gnttab_end_foreign_access(bedata
->ref
, 0, 0);
1132 kfree(bedata
->ring
.sring
);
1134 xenbus_switch_state(dev
, XenbusStateClosed
);
/*
 * pvcalls_front_probe - connect to the PV Calls backend.
 *
 * Only one connection is supported. The backend must report:
 *  - "versions" equal to "1";
 *  - "max-page-order" >= PVCALLS_RING_ORDER;
 *  - "function-calls" == 1.
 *
 * The function allocates bedata, with every rsp[] slot starting as
 * PVCALLS_INVALID_ID. It then allocates the one-page command ring, an event
 * channel bound to pvcalls_front_event_handler(), and a grant for the ring
 * page. It publishes "version", "ring-ref" and "port" in a xenstore
 * transaction and switches to XenbusStateInitialised. The error path at the
 * end aborts the transaction, reports the failure and calls
 * pvcalls_front_remove().
 */
1138 static int pvcalls_front_probe(struct xenbus_device
*dev
,
1139 const struct xenbus_device_id
*id
)
1141 int ret
= -ENOMEM
, evtchn
, i
;
1142 unsigned int max_page_order
, function_calls
, len
;
1144 grant_ref_t gref_head
= 0;
1145 struct xenbus_transaction xbt
;
1146 struct pvcalls_bedata
*bedata
= NULL
;
1147 struct xen_pvcalls_sring
*sring
;
1149 if (pvcalls_front_dev
!= NULL
) {
1150 dev_err(&dev
->dev
, "only one PV Calls connection supported\n");
/* Check the protocol version and features advertised by the backend. */
1154 versions
= xenbus_read(XBT_NIL
, dev
->otherend
, "versions", &len
);
1155 if (IS_ERR(versions
))
1156 return PTR_ERR(versions
);
1159 if (strcmp(versions
, "1")) {
1164 max_page_order
= xenbus_read_unsigned(dev
->otherend
,
1165 "max-page-order", 0);
1166 if (max_page_order
< PVCALLS_RING_ORDER
)
1168 function_calls
= xenbus_read_unsigned(dev
->otherend
,
1169 "function-calls", 0);
1170 /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
1171 if (function_calls
!= 1)
1173 pr_info("%s max-page-order is %u\n", __func__
, max_page_order
);
1175 bedata
= kzalloc(sizeof(struct pvcalls_bedata
), GFP_KERNEL
);
1179 dev_set_drvdata(&dev
->dev
, bedata
);
1180 pvcalls_front_dev
= dev
;
1181 init_waitqueue_head(&bedata
->inflight_req
);
1182 INIT_LIST_HEAD(&bedata
->socket_mappings
);
1183 spin_lock_init(&bedata
->socket_lock
);
/* Every response slot starts out free. */
1187 for (i
= 0; i
< PVCALLS_NR_RSP_PER_RING
; i
++)
1188 bedata
->rsp
[i
].req_id
= PVCALLS_INVALID_ID
;
1190 sring
= (struct xen_pvcalls_sring
*) __get_free_page(GFP_KERNEL
|
1194 SHARED_RING_INIT(sring
);
1195 FRONT_RING_INIT(&bedata
->ring
, sring
, XEN_PAGE_SIZE
);
1197 ret
= xenbus_alloc_evtchn(dev
, &evtchn
);
1201 bedata
->irq
= bind_evtchn_to_irqhandler(evtchn
,
1202 pvcalls_front_event_handler
,
1203 0, "pvcalls-frontend", dev
);
1204 if (bedata
->irq
< 0) {
1209 ret
= gnttab_alloc_grant_references(1, &gref_head
);
1212 ret
= gnttab_claim_grant_reference(&gref_head
);
1216 gnttab_grant_foreign_access_ref(bedata
->ref
, dev
->otherend_id
,
1217 virt_to_gfn((void *)sring
), 0);
/* Publish the connection details for the backend in one transaction. */
1220 ret
= xenbus_transaction_start(&xbt
);
1222 xenbus_dev_fatal(dev
, ret
, "starting transaction");
1225 ret
= xenbus_printf(xbt
, dev
->nodename
, "version", "%u", 1);
1228 ret
= xenbus_printf(xbt
, dev
->nodename
, "ring-ref", "%d", bedata
->ref
);
1231 ret
= xenbus_printf(xbt
, dev
->nodename
, "port", "%u",
1235 ret
= xenbus_transaction_end(xbt
, 0);
1239 xenbus_dev_fatal(dev
, ret
, "completing transaction");
1242 xenbus_switch_state(dev
, XenbusStateInitialised
);
1247 xenbus_transaction_end(xbt
, 1);
1248 xenbus_dev_fatal(dev
, ret
, "writing xenstore");
1250 pvcalls_front_remove(dev
);
/*
 * pvcalls_front_changed - react to backend state changes.
 *
 *  - Reconfiguring, Reconfigured, Initialising, Initialised and Unknown
 *    share one case, presumably a no-op.
 *  - InitWait has its own case.
 *  - When the backend reaches Connected, the frontend switches to
 *    Connected.
 *  - Closed (unless the frontend is already Closed) falls into the Closing
 *    case, which calls xenbus_frontend_closed().
 */
1254 static void pvcalls_front_changed(struct xenbus_device
*dev
,
1255 enum xenbus_state backend_state
)
1257 switch (backend_state
) {
1258 case XenbusStateReconfiguring
:
1259 case XenbusStateReconfigured
:
1260 case XenbusStateInitialising
:
1261 case XenbusStateInitialised
:
1262 case XenbusStateUnknown
:
1265 case XenbusStateInitWait
:
1268 case XenbusStateConnected
:
1269 xenbus_switch_state(dev
, XenbusStateConnected
);
1272 case XenbusStateClosed
:
1273 if (dev
->state
== XenbusStateClosed
)
1275 /* Missed the backend's CLOSING state */
1277 case XenbusStateClosing
:
1278 xenbus_frontend_closed(dev
);
/*
 * xenbus frontend driver: matches pvcalls_front_ids and wires up the probe,
 * remove and backend-state-change callbacks.
 */
1283 static struct xenbus_driver pvcalls_front_driver
= {
1284 .ids
= pvcalls_front_ids
,
1285 .probe
= pvcalls_front_probe
,
1286 .remove
= pvcalls_front_remove
,
1287 .otherend_changed
= pvcalls_front_changed
,
/*
 * Module init: logs a message and registers pvcalls_front_driver with
 * xenbus, returning the result of xenbus_register_frontend().
 */
1290 static int __init
pvcalls_frontend_init(void)
1295 pr_info("Initialising Xen pvcalls frontend driver\n");
1297 return xenbus_register_frontend(&pvcalls_front_driver
);
1300 module_init(pvcalls_frontend_init
);
1302 MODULE_DESCRIPTION("Xen PV Calls frontend driver");
1303 MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
1304 MODULE_LICENSE("GPL");