1 // SPDX-License-Identifier: GPL-2.0-or-later
3 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
6 #include <linux/module.h>
8 #include <linux/socket.h>
12 #include <xen/events.h>
13 #include <xen/grant_table.h>
15 #include <xen/xenbus.h>
16 #include <xen/interface/io/pvcalls.h>
18 #include "pvcalls-front.h"
20 #define PVCALLS_INVALID_ID UINT_MAX
21 #define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
22 #define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
23 #define PVCALLS_FRONT_MAX_SPIN 5000
25 static struct proto pvcalls_proto
= {
28 .obj_size
= sizeof(struct sock
),
31 struct pvcalls_bedata
{
32 struct xen_pvcalls_front_ring ring
;
36 struct list_head socket_mappings
;
37 spinlock_t socket_lock
;
39 wait_queue_head_t inflight_req
;
40 struct xen_pvcalls_response rsp
[PVCALLS_NR_RSP_PER_RING
];
42 /* Only one front/back connection supported. */
43 static struct xenbus_device
*pvcalls_front_dev
;
44 static atomic_t pvcalls_refcount
;
46 /* first increment refcount, then proceed */
47 #define pvcalls_enter() { \
48 atomic_inc(&pvcalls_refcount); \
51 /* first complete other operations, then decrement refcount */
52 #define pvcalls_exit() { \
53 atomic_dec(&pvcalls_refcount); \
58 struct list_head list
;
65 struct pvcalls_data_intf
*ring
;
66 struct pvcalls_data data
;
67 struct mutex in_mutex
;
68 struct mutex out_mutex
;
70 wait_queue_head_t inflight_conn_req
;
74 * Socket status, needs to be 64-bit aligned due to the
75 * test_and_* functions which have this requirement on arm64.
77 #define PVCALLS_STATUS_UNINITALIZED 0
78 #define PVCALLS_STATUS_BIND 1
79 #define PVCALLS_STATUS_LISTEN 2
80 uint8_t status
__attribute__((aligned(8)));
82 * Internal state-machine flags.
83 * Only one accept operation can be inflight for a socket.
84 * Only one poll operation can be inflight for a given socket.
85 * flags needs to be 64-bit aligned due to the test_and_*
86 * functions which have this requirement on arm64.
88 #define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
89 #define PVCALLS_FLAG_POLL_INFLIGHT 1
90 #define PVCALLS_FLAG_POLL_RET 2
91 uint8_t flags
__attribute__((aligned(8)));
92 uint32_t inflight_req_id
;
93 struct sock_mapping
*accept_map
;
94 wait_queue_head_t inflight_accept_req
;
99 static inline struct sock_mapping
*pvcalls_enter_sock(struct socket
*sock
)
101 struct sock_mapping
*map
;
103 if (!pvcalls_front_dev
||
104 dev_get_drvdata(&pvcalls_front_dev
->dev
) == NULL
)
105 return ERR_PTR(-ENOTCONN
);
107 map
= (struct sock_mapping
*)sock
->sk
->sk_send_head
;
109 return ERR_PTR(-ENOTSOCK
);
112 atomic_inc(&map
->refcount
);
116 static inline void pvcalls_exit_sock(struct socket
*sock
)
118 struct sock_mapping
*map
;
120 map
= (struct sock_mapping
*)sock
->sk
->sk_send_head
;
121 atomic_dec(&map
->refcount
);
125 static inline int get_request(struct pvcalls_bedata
*bedata
, int *req_id
)
127 *req_id
= bedata
->ring
.req_prod_pvt
& (RING_SIZE(&bedata
->ring
) - 1);
128 if (RING_FULL(&bedata
->ring
) ||
129 bedata
->rsp
[*req_id
].req_id
!= PVCALLS_INVALID_ID
)
134 static bool pvcalls_front_write_todo(struct sock_mapping
*map
)
136 struct pvcalls_data_intf
*intf
= map
->active
.ring
;
137 RING_IDX cons
, prod
, size
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
140 error
= intf
->out_error
;
141 if (error
== -ENOTCONN
)
146 cons
= intf
->out_cons
;
147 prod
= intf
->out_prod
;
148 return !!(size
- pvcalls_queued(prod
, cons
, size
));
151 static bool pvcalls_front_read_todo(struct sock_mapping
*map
)
153 struct pvcalls_data_intf
*intf
= map
->active
.ring
;
157 cons
= intf
->in_cons
;
158 prod
= intf
->in_prod
;
159 error
= intf
->in_error
;
160 return (error
!= 0 ||
161 pvcalls_queued(prod
, cons
,
162 XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
)) != 0);
165 static irqreturn_t
pvcalls_front_event_handler(int irq
, void *dev_id
)
167 struct xenbus_device
*dev
= dev_id
;
168 struct pvcalls_bedata
*bedata
;
169 struct xen_pvcalls_response
*rsp
;
171 int req_id
= 0, more
= 0, done
= 0;
177 bedata
= dev_get_drvdata(&dev
->dev
);
178 if (bedata
== NULL
) {
184 while (RING_HAS_UNCONSUMED_RESPONSES(&bedata
->ring
)) {
185 rsp
= RING_GET_RESPONSE(&bedata
->ring
, bedata
->ring
.rsp_cons
);
187 req_id
= rsp
->req_id
;
188 if (rsp
->cmd
== PVCALLS_POLL
) {
189 struct sock_mapping
*map
= (struct sock_mapping
*)(uintptr_t)
192 clear_bit(PVCALLS_FLAG_POLL_INFLIGHT
,
193 (void *)&map
->passive
.flags
);
195 * clear INFLIGHT, then set RET. It pairs with
196 * the checks at the beginning of
197 * pvcalls_front_poll_passive.
200 set_bit(PVCALLS_FLAG_POLL_RET
,
201 (void *)&map
->passive
.flags
);
203 dst
= (uint8_t *)&bedata
->rsp
[req_id
] +
205 src
= (uint8_t *)rsp
+ sizeof(rsp
->req_id
);
206 memcpy(dst
, src
, sizeof(*rsp
) - sizeof(rsp
->req_id
));
208 * First copy the rest of the data, then req_id. It is
209 * paired with the barrier when accessing bedata->rsp.
212 bedata
->rsp
[req_id
].req_id
= req_id
;
216 bedata
->ring
.rsp_cons
++;
219 RING_FINAL_CHECK_FOR_RESPONSES(&bedata
->ring
, more
);
223 wake_up(&bedata
->inflight_req
);
228 static void free_active_ring(struct sock_mapping
*map
);
230 static void pvcalls_front_destroy_active(struct pvcalls_bedata
*bedata
,
231 struct sock_mapping
*map
)
235 unbind_from_irqhandler(map
->active
.irq
, map
);
238 spin_lock(&bedata
->socket_lock
);
239 if (!list_empty(&map
->list
))
240 list_del_init(&map
->list
);
241 spin_unlock(&bedata
->socket_lock
);
244 for (i
= 0; i
< (1 << PVCALLS_RING_ORDER
); i
++)
245 gnttab_end_foreign_access(map
->active
.ring
->ref
[i
], NULL
);
246 gnttab_end_foreign_access(map
->active
.ref
, NULL
);
247 free_active_ring(map
);
250 static void pvcalls_front_free_map(struct pvcalls_bedata
*bedata
,
251 struct sock_mapping
*map
)
253 pvcalls_front_destroy_active(bedata
, map
);
258 static irqreturn_t
pvcalls_front_conn_handler(int irq
, void *sock_map
)
260 struct sock_mapping
*map
= sock_map
;
265 wake_up_interruptible(&map
->active
.inflight_conn_req
);
270 int pvcalls_front_socket(struct socket
*sock
)
272 struct pvcalls_bedata
*bedata
;
273 struct sock_mapping
*map
= NULL
;
274 struct xen_pvcalls_request
*req
;
275 int notify
, req_id
, ret
;
278 * PVCalls only supports domain AF_INET,
279 * type SOCK_STREAM and protocol 0 sockets for now.
281 * Check socket type here, AF_INET and protocol checks are done
284 if (sock
->type
!= SOCK_STREAM
)
288 if (!pvcalls_front_dev
) {
292 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
294 map
= kzalloc(sizeof(*map
), GFP_KERNEL
);
300 spin_lock(&bedata
->socket_lock
);
302 ret
= get_request(bedata
, &req_id
);
305 spin_unlock(&bedata
->socket_lock
);
311 * sock->sk->sk_send_head is not used for ip sockets: reuse the
312 * field to store a pointer to the struct sock_mapping
313 * corresponding to the socket. This way, we can easily get the
314 * struct sock_mapping from the struct socket.
316 sock
->sk
->sk_send_head
= (void *)map
;
317 list_add_tail(&map
->list
, &bedata
->socket_mappings
);
319 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
320 req
->req_id
= req_id
;
321 req
->cmd
= PVCALLS_SOCKET
;
322 req
->u
.socket
.id
= (uintptr_t) map
;
323 req
->u
.socket
.domain
= AF_INET
;
324 req
->u
.socket
.type
= SOCK_STREAM
;
325 req
->u
.socket
.protocol
= IPPROTO_IP
;
327 bedata
->ring
.req_prod_pvt
++;
328 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
329 spin_unlock(&bedata
->socket_lock
);
331 notify_remote_via_irq(bedata
->irq
);
333 wait_event(bedata
->inflight_req
,
334 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
336 /* read req_id, then the content */
338 ret
= bedata
->rsp
[req_id
].ret
;
339 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
345 static void free_active_ring(struct sock_mapping
*map
)
347 if (!map
->active
.ring
)
350 free_pages_exact(map
->active
.data
.in
,
351 PAGE_SIZE
<< map
->active
.ring
->ring_order
);
352 free_page((unsigned long)map
->active
.ring
);
355 static int alloc_active_ring(struct sock_mapping
*map
)
359 map
->active
.ring
= (struct pvcalls_data_intf
*)
360 get_zeroed_page(GFP_KERNEL
);
361 if (!map
->active
.ring
)
364 map
->active
.ring
->ring_order
= PVCALLS_RING_ORDER
;
365 bytes
= alloc_pages_exact(PAGE_SIZE
<< PVCALLS_RING_ORDER
,
366 GFP_KERNEL
| __GFP_ZERO
);
370 map
->active
.data
.in
= bytes
;
371 map
->active
.data
.out
= bytes
+
372 XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
377 free_active_ring(map
);
381 static int create_active(struct sock_mapping
*map
, evtchn_port_t
*evtchn
)
384 int ret
, irq
= -1, i
;
387 init_waitqueue_head(&map
->active
.inflight_conn_req
);
389 bytes
= map
->active
.data
.in
;
390 for (i
= 0; i
< (1 << PVCALLS_RING_ORDER
); i
++)
391 map
->active
.ring
->ref
[i
] = gnttab_grant_foreign_access(
392 pvcalls_front_dev
->otherend_id
,
393 pfn_to_gfn(virt_to_pfn(bytes
) + i
), 0);
395 map
->active
.ref
= gnttab_grant_foreign_access(
396 pvcalls_front_dev
->otherend_id
,
397 pfn_to_gfn(virt_to_pfn((void *)map
->active
.ring
)), 0);
399 ret
= xenbus_alloc_evtchn(pvcalls_front_dev
, evtchn
);
402 irq
= bind_evtchn_to_irqhandler(*evtchn
, pvcalls_front_conn_handler
,
403 0, "pvcalls-frontend", map
);
409 map
->active
.irq
= irq
;
410 map
->active_socket
= true;
411 mutex_init(&map
->active
.in_mutex
);
412 mutex_init(&map
->active
.out_mutex
);
418 xenbus_free_evtchn(pvcalls_front_dev
, *evtchn
);
422 int pvcalls_front_connect(struct socket
*sock
, struct sockaddr
*addr
,
423 int addr_len
, int flags
)
425 struct pvcalls_bedata
*bedata
;
426 struct sock_mapping
*map
= NULL
;
427 struct xen_pvcalls_request
*req
;
428 int notify
, req_id
, ret
;
429 evtchn_port_t evtchn
;
431 if (addr
->sa_family
!= AF_INET
|| sock
->type
!= SOCK_STREAM
)
434 map
= pvcalls_enter_sock(sock
);
438 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
439 ret
= alloc_active_ring(map
);
441 pvcalls_exit_sock(sock
);
444 ret
= create_active(map
, &evtchn
);
446 free_active_ring(map
);
447 pvcalls_exit_sock(sock
);
451 spin_lock(&bedata
->socket_lock
);
452 ret
= get_request(bedata
, &req_id
);
454 spin_unlock(&bedata
->socket_lock
);
455 pvcalls_front_destroy_active(NULL
, map
);
456 pvcalls_exit_sock(sock
);
460 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
461 req
->req_id
= req_id
;
462 req
->cmd
= PVCALLS_CONNECT
;
463 req
->u
.connect
.id
= (uintptr_t)map
;
464 req
->u
.connect
.len
= addr_len
;
465 req
->u
.connect
.flags
= flags
;
466 req
->u
.connect
.ref
= map
->active
.ref
;
467 req
->u
.connect
.evtchn
= evtchn
;
468 memcpy(req
->u
.connect
.addr
, addr
, sizeof(*addr
));
472 bedata
->ring
.req_prod_pvt
++;
473 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
474 spin_unlock(&bedata
->socket_lock
);
477 notify_remote_via_irq(bedata
->irq
);
479 wait_event(bedata
->inflight_req
,
480 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
482 /* read req_id, then the content */
484 ret
= bedata
->rsp
[req_id
].ret
;
485 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
486 pvcalls_exit_sock(sock
);
490 static int __write_ring(struct pvcalls_data_intf
*intf
,
491 struct pvcalls_data
*data
,
492 struct iov_iter
*msg_iter
,
495 RING_IDX cons
, prod
, size
, masked_prod
, masked_cons
;
496 RING_IDX array_size
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
499 error
= intf
->out_error
;
502 cons
= intf
->out_cons
;
503 prod
= intf
->out_prod
;
504 /* read indexes before continuing */
507 size
= pvcalls_queued(prod
, cons
, array_size
);
508 if (size
> array_size
)
510 if (size
== array_size
)
512 if (len
> array_size
- size
)
513 len
= array_size
- size
;
515 masked_prod
= pvcalls_mask(prod
, array_size
);
516 masked_cons
= pvcalls_mask(cons
, array_size
);
518 if (masked_prod
< masked_cons
) {
519 len
= copy_from_iter(data
->out
+ masked_prod
, len
, msg_iter
);
521 if (len
> array_size
- masked_prod
) {
522 int ret
= copy_from_iter(data
->out
+ masked_prod
,
523 array_size
- masked_prod
, msg_iter
);
524 if (ret
!= array_size
- masked_prod
) {
528 len
= ret
+ copy_from_iter(data
->out
, len
- ret
, msg_iter
);
530 len
= copy_from_iter(data
->out
+ masked_prod
, len
, msg_iter
);
534 /* write to ring before updating pointer */
536 intf
->out_prod
+= len
;
541 int pvcalls_front_sendmsg(struct socket
*sock
, struct msghdr
*msg
,
544 struct sock_mapping
*map
;
545 int sent
, tot_sent
= 0;
546 int count
= 0, flags
;
548 flags
= msg
->msg_flags
;
549 if (flags
& (MSG_CONFIRM
|MSG_DONTROUTE
|MSG_EOR
|MSG_OOB
))
552 map
= pvcalls_enter_sock(sock
);
556 mutex_lock(&map
->active
.out_mutex
);
557 if ((flags
& MSG_DONTWAIT
) && !pvcalls_front_write_todo(map
)) {
558 mutex_unlock(&map
->active
.out_mutex
);
559 pvcalls_exit_sock(sock
);
567 sent
= __write_ring(map
->active
.ring
,
568 &map
->active
.data
, &msg
->msg_iter
,
573 notify_remote_via_irq(map
->active
.irq
);
575 if (sent
>= 0 && len
> 0 && count
< PVCALLS_FRONT_MAX_SPIN
)
580 mutex_unlock(&map
->active
.out_mutex
);
581 pvcalls_exit_sock(sock
);
585 static int __read_ring(struct pvcalls_data_intf
*intf
,
586 struct pvcalls_data
*data
,
587 struct iov_iter
*msg_iter
,
588 size_t len
, int flags
)
590 RING_IDX cons
, prod
, size
, masked_prod
, masked_cons
;
591 RING_IDX array_size
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
594 cons
= intf
->in_cons
;
595 prod
= intf
->in_prod
;
596 error
= intf
->in_error
;
597 /* get pointers before reading from the ring */
600 size
= pvcalls_queued(prod
, cons
, array_size
);
601 masked_prod
= pvcalls_mask(prod
, array_size
);
602 masked_cons
= pvcalls_mask(cons
, array_size
);
605 return error
?: size
;
610 if (masked_prod
> masked_cons
) {
611 len
= copy_to_iter(data
->in
+ masked_cons
, len
, msg_iter
);
613 if (len
> (array_size
- masked_cons
)) {
614 int ret
= copy_to_iter(data
->in
+ masked_cons
,
615 array_size
- masked_cons
, msg_iter
);
616 if (ret
!= array_size
- masked_cons
) {
620 len
= ret
+ copy_to_iter(data
->in
, len
- ret
, msg_iter
);
622 len
= copy_to_iter(data
->in
+ masked_cons
, len
, msg_iter
);
626 /* read data from the ring before increasing the index */
628 if (!(flags
& MSG_PEEK
))
629 intf
->in_cons
+= len
;
634 int pvcalls_front_recvmsg(struct socket
*sock
, struct msghdr
*msg
, size_t len
,
638 struct sock_mapping
*map
;
640 if (flags
& (MSG_CMSG_CLOEXEC
|MSG_ERRQUEUE
|MSG_OOB
|MSG_TRUNC
))
643 map
= pvcalls_enter_sock(sock
);
647 mutex_lock(&map
->active
.in_mutex
);
648 if (len
> XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
))
649 len
= XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER
);
651 while (!(flags
& MSG_DONTWAIT
) && !pvcalls_front_read_todo(map
)) {
652 wait_event_interruptible(map
->active
.inflight_conn_req
,
653 pvcalls_front_read_todo(map
));
655 ret
= __read_ring(map
->active
.ring
, &map
->active
.data
,
656 &msg
->msg_iter
, len
, flags
);
659 notify_remote_via_irq(map
->active
.irq
);
661 ret
= (flags
& MSG_DONTWAIT
) ? -EAGAIN
: 0;
662 if (ret
== -ENOTCONN
)
665 mutex_unlock(&map
->active
.in_mutex
);
666 pvcalls_exit_sock(sock
);
670 int pvcalls_front_bind(struct socket
*sock
, struct sockaddr
*addr
, int addr_len
)
672 struct pvcalls_bedata
*bedata
;
673 struct sock_mapping
*map
= NULL
;
674 struct xen_pvcalls_request
*req
;
675 int notify
, req_id
, ret
;
677 if (addr
->sa_family
!= AF_INET
|| sock
->type
!= SOCK_STREAM
)
680 map
= pvcalls_enter_sock(sock
);
683 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
685 spin_lock(&bedata
->socket_lock
);
686 ret
= get_request(bedata
, &req_id
);
688 spin_unlock(&bedata
->socket_lock
);
689 pvcalls_exit_sock(sock
);
692 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
693 req
->req_id
= req_id
;
695 req
->cmd
= PVCALLS_BIND
;
696 req
->u
.bind
.id
= (uintptr_t)map
;
697 memcpy(req
->u
.bind
.addr
, addr
, sizeof(*addr
));
698 req
->u
.bind
.len
= addr_len
;
700 init_waitqueue_head(&map
->passive
.inflight_accept_req
);
702 map
->active_socket
= false;
704 bedata
->ring
.req_prod_pvt
++;
705 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
706 spin_unlock(&bedata
->socket_lock
);
708 notify_remote_via_irq(bedata
->irq
);
710 wait_event(bedata
->inflight_req
,
711 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
713 /* read req_id, then the content */
715 ret
= bedata
->rsp
[req_id
].ret
;
716 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
718 map
->passive
.status
= PVCALLS_STATUS_BIND
;
719 pvcalls_exit_sock(sock
);
723 int pvcalls_front_listen(struct socket
*sock
, int backlog
)
725 struct pvcalls_bedata
*bedata
;
726 struct sock_mapping
*map
;
727 struct xen_pvcalls_request
*req
;
728 int notify
, req_id
, ret
;
730 map
= pvcalls_enter_sock(sock
);
733 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
735 if (map
->passive
.status
!= PVCALLS_STATUS_BIND
) {
736 pvcalls_exit_sock(sock
);
740 spin_lock(&bedata
->socket_lock
);
741 ret
= get_request(bedata
, &req_id
);
743 spin_unlock(&bedata
->socket_lock
);
744 pvcalls_exit_sock(sock
);
747 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
748 req
->req_id
= req_id
;
749 req
->cmd
= PVCALLS_LISTEN
;
750 req
->u
.listen
.id
= (uintptr_t) map
;
751 req
->u
.listen
.backlog
= backlog
;
753 bedata
->ring
.req_prod_pvt
++;
754 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
755 spin_unlock(&bedata
->socket_lock
);
757 notify_remote_via_irq(bedata
->irq
);
759 wait_event(bedata
->inflight_req
,
760 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
762 /* read req_id, then the content */
764 ret
= bedata
->rsp
[req_id
].ret
;
765 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
767 map
->passive
.status
= PVCALLS_STATUS_LISTEN
;
768 pvcalls_exit_sock(sock
);
772 int pvcalls_front_accept(struct socket
*sock
, struct socket
*newsock
, int flags
)
774 struct pvcalls_bedata
*bedata
;
775 struct sock_mapping
*map
;
776 struct sock_mapping
*map2
= NULL
;
777 struct xen_pvcalls_request
*req
;
778 int notify
, req_id
, ret
, nonblock
;
779 evtchn_port_t evtchn
;
781 map
= pvcalls_enter_sock(sock
);
784 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
786 if (map
->passive
.status
!= PVCALLS_STATUS_LISTEN
) {
787 pvcalls_exit_sock(sock
);
791 nonblock
= flags
& SOCK_NONBLOCK
;
793 * Backend only supports 1 inflight accept request, will return
794 * errors for the others
796 if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
797 (void *)&map
->passive
.flags
)) {
798 req_id
= READ_ONCE(map
->passive
.inflight_req_id
);
799 if (req_id
!= PVCALLS_INVALID_ID
&&
800 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
) {
801 map2
= map
->passive
.accept_map
;
805 pvcalls_exit_sock(sock
);
808 if (wait_event_interruptible(map
->passive
.inflight_accept_req
,
809 !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
810 (void *)&map
->passive
.flags
))) {
811 pvcalls_exit_sock(sock
);
816 map2
= kzalloc(sizeof(*map2
), GFP_KERNEL
);
818 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
819 (void *)&map
->passive
.flags
);
820 pvcalls_exit_sock(sock
);
823 ret
= alloc_active_ring(map2
);
825 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
826 (void *)&map
->passive
.flags
);
828 pvcalls_exit_sock(sock
);
831 ret
= create_active(map2
, &evtchn
);
833 free_active_ring(map2
);
835 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
836 (void *)&map
->passive
.flags
);
837 pvcalls_exit_sock(sock
);
841 spin_lock(&bedata
->socket_lock
);
842 ret
= get_request(bedata
, &req_id
);
844 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
845 (void *)&map
->passive
.flags
);
846 spin_unlock(&bedata
->socket_lock
);
847 pvcalls_front_free_map(bedata
, map2
);
848 pvcalls_exit_sock(sock
);
852 list_add_tail(&map2
->list
, &bedata
->socket_mappings
);
854 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
855 req
->req_id
= req_id
;
856 req
->cmd
= PVCALLS_ACCEPT
;
857 req
->u
.accept
.id
= (uintptr_t) map
;
858 req
->u
.accept
.ref
= map2
->active
.ref
;
859 req
->u
.accept
.id_new
= (uintptr_t) map2
;
860 req
->u
.accept
.evtchn
= evtchn
;
861 map
->passive
.accept_map
= map2
;
863 bedata
->ring
.req_prod_pvt
++;
864 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
865 spin_unlock(&bedata
->socket_lock
);
867 notify_remote_via_irq(bedata
->irq
);
868 /* We could check if we have received a response before returning. */
870 WRITE_ONCE(map
->passive
.inflight_req_id
, req_id
);
871 pvcalls_exit_sock(sock
);
875 if (wait_event_interruptible(bedata
->inflight_req
,
876 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
)) {
877 pvcalls_exit_sock(sock
);
880 /* read req_id, then the content */
884 map2
->sock
= newsock
;
885 newsock
->sk
= sk_alloc(sock_net(sock
->sk
), PF_INET
, GFP_KERNEL
, &pvcalls_proto
, false);
887 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
888 map
->passive
.inflight_req_id
= PVCALLS_INVALID_ID
;
889 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
890 (void *)&map
->passive
.flags
);
891 pvcalls_front_free_map(bedata
, map2
);
892 pvcalls_exit_sock(sock
);
895 newsock
->sk
->sk_send_head
= (void *)map2
;
897 ret
= bedata
->rsp
[req_id
].ret
;
898 bedata
->rsp
[req_id
].req_id
= PVCALLS_INVALID_ID
;
899 map
->passive
.inflight_req_id
= PVCALLS_INVALID_ID
;
901 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
, (void *)&map
->passive
.flags
);
902 wake_up(&map
->passive
.inflight_accept_req
);
904 pvcalls_exit_sock(sock
);
908 static __poll_t
pvcalls_front_poll_passive(struct file
*file
,
909 struct pvcalls_bedata
*bedata
,
910 struct sock_mapping
*map
,
913 int notify
, req_id
, ret
;
914 struct xen_pvcalls_request
*req
;
916 if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT
,
917 (void *)&map
->passive
.flags
)) {
918 uint32_t req_id
= READ_ONCE(map
->passive
.inflight_req_id
);
920 if (req_id
!= PVCALLS_INVALID_ID
&&
921 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
)
922 return EPOLLIN
| EPOLLRDNORM
;
924 poll_wait(file
, &map
->passive
.inflight_accept_req
, wait
);
928 if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET
,
929 (void *)&map
->passive
.flags
))
930 return EPOLLIN
| EPOLLRDNORM
;
933 * First check RET, then INFLIGHT. No barriers necessary to
934 * ensure execution ordering because of the conditional
935 * instructions creating control dependencies.
938 if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT
,
939 (void *)&map
->passive
.flags
)) {
940 poll_wait(file
, &bedata
->inflight_req
, wait
);
944 spin_lock(&bedata
->socket_lock
);
945 ret
= get_request(bedata
, &req_id
);
947 spin_unlock(&bedata
->socket_lock
);
950 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
951 req
->req_id
= req_id
;
952 req
->cmd
= PVCALLS_POLL
;
953 req
->u
.poll
.id
= (uintptr_t) map
;
955 bedata
->ring
.req_prod_pvt
++;
956 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
957 spin_unlock(&bedata
->socket_lock
);
959 notify_remote_via_irq(bedata
->irq
);
961 poll_wait(file
, &bedata
->inflight_req
, wait
);
965 static __poll_t
pvcalls_front_poll_active(struct file
*file
,
966 struct pvcalls_bedata
*bedata
,
967 struct sock_mapping
*map
,
971 int32_t in_error
, out_error
;
972 struct pvcalls_data_intf
*intf
= map
->active
.ring
;
974 out_error
= intf
->out_error
;
975 in_error
= intf
->in_error
;
977 poll_wait(file
, &map
->active
.inflight_conn_req
, wait
);
978 if (pvcalls_front_write_todo(map
))
979 mask
|= EPOLLOUT
| EPOLLWRNORM
;
980 if (pvcalls_front_read_todo(map
))
981 mask
|= EPOLLIN
| EPOLLRDNORM
;
982 if (in_error
!= 0 || out_error
!= 0)
988 __poll_t
pvcalls_front_poll(struct file
*file
, struct socket
*sock
,
991 struct pvcalls_bedata
*bedata
;
992 struct sock_mapping
*map
;
995 map
= pvcalls_enter_sock(sock
);
998 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
1000 if (map
->active_socket
)
1001 ret
= pvcalls_front_poll_active(file
, bedata
, map
, wait
);
1003 ret
= pvcalls_front_poll_passive(file
, bedata
, map
, wait
);
1004 pvcalls_exit_sock(sock
);
1008 int pvcalls_front_release(struct socket
*sock
)
1010 struct pvcalls_bedata
*bedata
;
1011 struct sock_mapping
*map
;
1012 int req_id
, notify
, ret
;
1013 struct xen_pvcalls_request
*req
;
1015 if (sock
->sk
== NULL
)
1018 map
= pvcalls_enter_sock(sock
);
1020 if (PTR_ERR(map
) == -ENOTCONN
)
1025 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
1027 spin_lock(&bedata
->socket_lock
);
1028 ret
= get_request(bedata
, &req_id
);
1030 spin_unlock(&bedata
->socket_lock
);
1031 pvcalls_exit_sock(sock
);
1034 sock
->sk
->sk_send_head
= NULL
;
1036 req
= RING_GET_REQUEST(&bedata
->ring
, req_id
);
1037 req
->req_id
= req_id
;
1038 req
->cmd
= PVCALLS_RELEASE
;
1039 req
->u
.release
.id
= (uintptr_t)map
;
1041 bedata
->ring
.req_prod_pvt
++;
1042 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata
->ring
, notify
);
1043 spin_unlock(&bedata
->socket_lock
);
1045 notify_remote_via_irq(bedata
->irq
);
1047 wait_event(bedata
->inflight_req
,
1048 READ_ONCE(bedata
->rsp
[req_id
].req_id
) == req_id
);
1050 if (map
->active_socket
) {
1052 * Set in_error and wake up inflight_conn_req to force
1053 * recvmsg waiters to exit.
1055 map
->active
.ring
->in_error
= -EBADF
;
1056 wake_up_interruptible(&map
->active
.inflight_conn_req
);
1059 * We need to make sure that sendmsg/recvmsg on this socket have
1060 * not started before we've cleared sk_send_head here. The
1061 * easiest way to guarantee this is to see that no pvcalls
1062 * (other than us) is in progress on this socket.
1064 while (atomic_read(&map
->refcount
) > 1)
1067 pvcalls_front_free_map(bedata
, map
);
1069 wake_up(&bedata
->inflight_req
);
1070 wake_up(&map
->passive
.inflight_accept_req
);
1072 while (atomic_read(&map
->refcount
) > 1)
1075 spin_lock(&bedata
->socket_lock
);
1076 list_del(&map
->list
);
1077 spin_unlock(&bedata
->socket_lock
);
1078 if (READ_ONCE(map
->passive
.inflight_req_id
) != PVCALLS_INVALID_ID
&&
1079 READ_ONCE(map
->passive
.inflight_req_id
) != 0) {
1080 pvcalls_front_free_map(bedata
,
1081 map
->passive
.accept_map
);
1085 WRITE_ONCE(bedata
->rsp
[req_id
].req_id
, PVCALLS_INVALID_ID
);
1091 static const struct xenbus_device_id pvcalls_front_ids
[] = {
1096 static void pvcalls_front_remove(struct xenbus_device
*dev
)
1098 struct pvcalls_bedata
*bedata
;
1099 struct sock_mapping
*map
= NULL
, *n
;
1101 bedata
= dev_get_drvdata(&pvcalls_front_dev
->dev
);
1102 dev_set_drvdata(&dev
->dev
, NULL
);
1103 pvcalls_front_dev
= NULL
;
1104 if (bedata
->irq
>= 0)
1105 unbind_from_irqhandler(bedata
->irq
, dev
);
1107 list_for_each_entry_safe(map
, n
, &bedata
->socket_mappings
, list
) {
1108 map
->sock
->sk
->sk_send_head
= NULL
;
1109 if (map
->active_socket
) {
1110 map
->active
.ring
->in_error
= -EBADF
;
1111 wake_up_interruptible(&map
->active
.inflight_conn_req
);
1116 while (atomic_read(&pvcalls_refcount
) > 0)
1118 list_for_each_entry_safe(map
, n
, &bedata
->socket_mappings
, list
) {
1119 if (map
->active_socket
) {
1120 /* No need to lock, refcount is 0 */
1121 pvcalls_front_free_map(bedata
, map
);
1123 list_del(&map
->list
);
1127 if (bedata
->ref
!= -1)
1128 gnttab_end_foreign_access(bedata
->ref
, NULL
);
1129 kfree(bedata
->ring
.sring
);
1131 xenbus_switch_state(dev
, XenbusStateClosed
);
1134 static int pvcalls_front_probe(struct xenbus_device
*dev
,
1135 const struct xenbus_device_id
*id
)
1137 int ret
= -ENOMEM
, i
;
1138 evtchn_port_t evtchn
;
1139 unsigned int max_page_order
, function_calls
, len
;
1141 grant_ref_t gref_head
= 0;
1142 struct xenbus_transaction xbt
;
1143 struct pvcalls_bedata
*bedata
= NULL
;
1144 struct xen_pvcalls_sring
*sring
;
1146 if (pvcalls_front_dev
!= NULL
) {
1147 dev_err(&dev
->dev
, "only one PV Calls connection supported\n");
1151 versions
= xenbus_read(XBT_NIL
, dev
->otherend
, "versions", &len
);
1152 if (IS_ERR(versions
))
1153 return PTR_ERR(versions
);
1156 if (strcmp(versions
, "1")) {
1161 max_page_order
= xenbus_read_unsigned(dev
->otherend
,
1162 "max-page-order", 0);
1163 if (max_page_order
< PVCALLS_RING_ORDER
)
1165 function_calls
= xenbus_read_unsigned(dev
->otherend
,
1166 "function-calls", 0);
1167 /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
1168 if (function_calls
!= 1)
1170 pr_info("%s max-page-order is %u\n", __func__
, max_page_order
);
1172 bedata
= kzalloc(sizeof(struct pvcalls_bedata
), GFP_KERNEL
);
1176 dev_set_drvdata(&dev
->dev
, bedata
);
1177 pvcalls_front_dev
= dev
;
1178 init_waitqueue_head(&bedata
->inflight_req
);
1179 INIT_LIST_HEAD(&bedata
->socket_mappings
);
1180 spin_lock_init(&bedata
->socket_lock
);
1184 for (i
= 0; i
< PVCALLS_NR_RSP_PER_RING
; i
++)
1185 bedata
->rsp
[i
].req_id
= PVCALLS_INVALID_ID
;
1187 sring
= (struct xen_pvcalls_sring
*) __get_free_page(GFP_KERNEL
|
1191 SHARED_RING_INIT(sring
);
1192 FRONT_RING_INIT(&bedata
->ring
, sring
, XEN_PAGE_SIZE
);
1194 ret
= xenbus_alloc_evtchn(dev
, &evtchn
);
1198 bedata
->irq
= bind_evtchn_to_irqhandler(evtchn
,
1199 pvcalls_front_event_handler
,
1200 0, "pvcalls-frontend", dev
);
1201 if (bedata
->irq
< 0) {
1206 ret
= gnttab_alloc_grant_references(1, &gref_head
);
1209 ret
= gnttab_claim_grant_reference(&gref_head
);
1213 gnttab_grant_foreign_access_ref(bedata
->ref
, dev
->otherend_id
,
1214 virt_to_gfn((void *)sring
), 0);
1217 ret
= xenbus_transaction_start(&xbt
);
1219 xenbus_dev_fatal(dev
, ret
, "starting transaction");
1222 ret
= xenbus_printf(xbt
, dev
->nodename
, "version", "%u", 1);
1225 ret
= xenbus_printf(xbt
, dev
->nodename
, "ring-ref", "%d", bedata
->ref
);
1228 ret
= xenbus_printf(xbt
, dev
->nodename
, "port", "%u",
1232 ret
= xenbus_transaction_end(xbt
, 0);
1236 xenbus_dev_fatal(dev
, ret
, "completing transaction");
1239 xenbus_switch_state(dev
, XenbusStateInitialised
);
1244 xenbus_transaction_end(xbt
, 1);
1245 xenbus_dev_fatal(dev
, ret
, "writing xenstore");
1247 pvcalls_front_remove(dev
);
1251 static void pvcalls_front_changed(struct xenbus_device
*dev
,
1252 enum xenbus_state backend_state
)
1254 switch (backend_state
) {
1255 case XenbusStateReconfiguring
:
1256 case XenbusStateReconfigured
:
1257 case XenbusStateInitialising
:
1258 case XenbusStateInitialised
:
1259 case XenbusStateUnknown
:
1262 case XenbusStateInitWait
:
1265 case XenbusStateConnected
:
1266 xenbus_switch_state(dev
, XenbusStateConnected
);
1269 case XenbusStateClosed
:
1270 if (dev
->state
== XenbusStateClosed
)
1272 /* Missed the backend's CLOSING state */
1274 case XenbusStateClosing
:
1275 xenbus_frontend_closed(dev
);
1280 static struct xenbus_driver pvcalls_front_driver
= {
1281 .ids
= pvcalls_front_ids
,
1282 .probe
= pvcalls_front_probe
,
1283 .remove
= pvcalls_front_remove
,
1284 .otherend_changed
= pvcalls_front_changed
,
1285 .not_essential
= true,
1288 static int __init
pvcalls_frontend_init(void)
1293 pr_info("Initialising Xen pvcalls frontend driver\n");
1295 return xenbus_register_frontend(&pvcalls_front_driver
);
1298 module_init(pvcalls_frontend_init
);
1300 MODULE_DESCRIPTION("Xen PV Calls frontend driver");
1301 MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
1302 MODULE_LICENSE("GPL");