1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com>
4 * Based off of net/unix/unix_bpf.c
8 #include <linux/module.h>
9 #include <linux/skmsg.h>
10 #include <linux/socket.h>
11 #include <linux/wait.h>
12 #include <net/af_vsock.h>
/* Evaluates to true when at least one of the following is non-empty: the
 * socket's sk_receive_queue, the psock's ingress_skb queue, or the psock's
 * ingress_msg list. __psock is expanded twice, so callers should pass an
 * expression without side effects.
 */
15 #define vsock_sk_has_data(__sk, __psock) \
16 ({ !skb_queue_empty(&(__sk)->sk_receive_queue) || \
17 !skb_queue_empty(&(__psock)->ingress_skb) || \
18 !list_empty(&(__psock)->ingress_msg); \
/* The proto most recently passed to vsock_bpf_rebuild_protos() by
 * vsock_bpf_check_needs_rebuild(); read there with smp_load_acquire() and
 * written with smp_store_release().
 */
21 static struct proto
*vsock_prot_saved __read_mostly
;
/* Held (spin_lock_bh) by vsock_bpf_check_needs_rebuild() around the rebuild
 * of vsock_bpf_prot and the update of vsock_prot_saved.
 */
22 static DEFINE_SPINLOCK(vsock_prot_lock
);
/* The proto that vsock_bpf_update_proto() installs on a socket; its close,
 * recvmsg and sock_is_readable hooks are set by vsock_bpf_rebuild_protos().
 */
23 static struct proto vsock_bpf_prot
;
/* Report whether @sk has anything to read, consulting both
 * vsock_connectible_has_data() on the vsock_sock behind @sk and the queues
 * checked by vsock_sk_has_data().
 *
 * NOTE(review): the declaration of ret and the test applied to it are not
 * visible in this listing; presumably a positive result returns true before
 * the queue check is reached -- confirm against the full file.
 */
25 static bool vsock_has_data(struct sock
*sk
, struct sk_psock
*psock
)
27 struct vsock_sock
*vsk
= vsock_sk(sk
);
/* Ask the vsock layer first. */
30 ret
= vsock_connectible_has_data(vsk
);
/* Otherwise decide from the socket receive queue and the psock queues. */
34 return vsock_sk_has_data(sk
, psock
);
/* Wait for data on @sk; @timeo is handed to wait_woken() as the timeout.
 *
 * Visible sequence: test sk_shutdown for RCV_SHUTDOWN, register a wait entry
 * (woken_wake_function) on sk_sleep(sk), set SOCKWQ_ASYNC_WAITDATA, test
 * vsock_has_data(), sleep in wait_woken() as TASK_INTERRUPTIBLE, test
 * vsock_has_data() again, then clear the flag and remove the wait entry.
 *
 * vsock_bpf_recvmsg() branches on a false return from this function.
 *
 * NOTE(review): the declaration of ret, the bodies of the early checks, the
 * condition guarding wait_woken() and the return statement are not visible
 * in this listing; presumably the last vsock_has_data() result is what gets
 * returned -- confirm against the full file.
 */
37 static bool vsock_msg_wait_data(struct sock
*sk
, struct sk_psock
*psock
, long timeo
)
41 DEFINE_WAIT_FUNC(wait
, woken_wake_function
);
/* Receive side already shut down. */
43 if (sk
->sk_shutdown
& RCV_SHUTDOWN
)
/* Register on the socket's wait queue before testing for data. */
49 add_wait_queue(sk_sleep(sk
), &wait
);
50 sk_set_bit(SOCKWQ_ASYNC_WAITDATA
, sk
);
51 ret
= vsock_has_data(sk
, psock
);
53 wait_woken(&wait
, TASK_INTERRUPTIBLE
, timeo
);
/* Re-test after the sleep. */
54 ret
= vsock_has_data(sk
, psock
);
/* Undo the flag and the wait-queue registration made above. */
56 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA
, sk
);
57 remove_wait_queue(sk_sleep(sk
), &wait
);
/* Dispatch a receive to the vsock implementation matching sk->sk_type:
 * SOCK_STREAM and SOCK_SEQPACKET go to __vsock_connectible_recvmsg(),
 * SOCK_DGRAM goes to __vsock_dgram_recvmsg(). Both are called with
 * sk->sk_socket rather than @sk, and their result is stored in err.
 *
 * NOTE(review): the declaration of err, the handling of any other socket
 * type and the return statement are not visible in this listing; presumably
 * err is returned -- confirm against the full file.
 */
61 static int __vsock_recvmsg(struct sock
*sk
, struct msghdr
*msg
, size_t len
, int flags
)
63 struct socket
*sock
= sk
->sk_socket
;
66 if (sk
->sk_type
== SOCK_STREAM
|| sk
->sk_type
== SOCK_SEQPACKET
)
67 err
= __vsock_connectible_recvmsg(sock
, msg
, len
, flags
);
68 else if (sk
->sk_type
== SOCK_DGRAM
)
69 err
= __vsock_dgram_recvmsg(sock
, msg
, len
, flags
);
/* recvmsg hook that vsock_bpf_rebuild_protos() stores in the BPF proto.
 *
 * Visible flow:
 *  - look up the psock with sk_psock_get(); one path then hands the call
 *    straight to __vsock_recvmsg();
 *  - if vsock_has_data() is true while sk_psock_queue_empty() is also true,
 *    drop the psock with sk_psock_put() and defer to __vsock_recvmsg();
 *  - otherwise read through sk_msg_recvmsg();
 *  - on the waiting path the timeout comes from sock_rcvtimeo() (non-blocking
 *    when MSG_DONTWAIT is set in @flags) and vsock_msg_wait_data() does the
 *    wait; afterwards an empty psock queue again defers to
 *    __vsock_recvmsg(), otherwise sk_msg_recvmsg() is called again;
 *  - the final path releases the psock with sk_psock_put().
 *
 * @addr_len is not referenced by any visible line.
 *
 * NOTE(review): the declaration of copied, the condition on the first
 * fallback (presumably "no psock attached"), any socket locking, the loop
 * around the wait, the action taken when vsock_msg_wait_data() returns false
 * and the final return are not visible in this listing -- confirm against
 * the full file before relying on this summary.
 */
76 static int vsock_bpf_recvmsg(struct sock
*sk
, struct msghdr
*msg
,
77 size_t len
, int flags
, int *addr_len
)
79 struct sk_psock
*psock
;
82 psock
= sk_psock_get(sk
);
84 return __vsock_recvmsg(sk
, msg
, len
, flags
);
/* Data is pending but nothing is queued on the psock: use the plain path. */
87 if (vsock_has_data(sk
, psock
) && sk_psock_queue_empty(psock
)) {
89 sk_psock_put(sk
, psock
);
90 return __vsock_recvmsg(sk
, msg
, len
, flags
);
93 copied
= sk_msg_recvmsg(sk
, psock
, msg
, len
, flags
);
95 long timeo
= sock_rcvtimeo(sk
, flags
& MSG_DONTWAIT
);
97 if (!vsock_msg_wait_data(sk
, psock
, timeo
)) {
/* After waiting, an empty psock queue again defers to the plain path. */
102 if (sk_psock_queue_empty(psock
)) {
104 sk_psock_put(sk
, psock
);
105 return __vsock_recvmsg(sk
, msg
, len
, flags
);
108 copied
= sk_msg_recvmsg(sk
, psock
, msg
, len
, flags
);
112 sk_psock_put(sk
, psock
);
/* Override three hooks of @prot: close becomes sock_map_close, recvmsg
 * becomes vsock_bpf_recvmsg and sock_is_readable becomes sk_msg_is_readable.
 *
 * NOTE(review): @base is not referenced by any visible line; presumably
 * @prot is initialised from @base before the overrides -- confirm against
 * the full file.
 */
117 static void vsock_bpf_rebuild_protos(struct proto
*prot
, const struct proto
*base
)
120 prot
->close
= sock_map_close
;
121 prot
->recvmsg
= vsock_bpf_recvmsg
;
122 prot
->sock_is_readable
= sk_msg_is_readable
;
/* Rebuild vsock_bpf_prot from @ops when @ops differs from the proto recorded
 * in vsock_prot_saved.
 *
 * The first comparison reads vsock_prot_saved with smp_load_acquire() and
 * takes no lock. On a mismatch the comparison is repeated under
 * vsock_prot_lock (spin_lock_bh()); if it still differs,
 * vsock_bpf_rebuild_protos() rewrites vsock_bpf_prot and smp_store_release()
 * then records @ops, so the hooks are written before the new pointer is
 * published.
 */
125 static void vsock_bpf_check_needs_rebuild(struct proto
*ops
)
127 /* Paired with the smp_store_release() below. */
128 if (unlikely(ops
!= smp_load_acquire(&vsock_prot_saved
))) {
129 spin_lock_bh(&vsock_prot_lock
);
130 if (likely(ops
!= vsock_prot_saved
)) {
131 vsock_bpf_rebuild_protos(&vsock_bpf_prot
, ops
);
132 /* Make sure proto function pointers are updated before publishing the
133 * pointer to the struct.
135 smp_store_release(&vsock_prot_saved
, ops
);
137 spin_unlock_bh(&vsock_prot_lock
);
/* Switch @sk between the proto saved in @psock and vsock_bpf_prot.
 *
 * One visible path puts psock->saved_write_space back into
 * sk->sk_write_space and reinstalls psock->sk_proto with
 * sock_replace_proto(). The other tests vsk->transport->read_skb, brings
 * vsock_bpf_prot up to date with psock->sk_proto through
 * vsock_bpf_check_needs_rebuild(), and installs &vsock_bpf_prot.
 *
 * NOTE(review): the test of @restore, the assignment of vsk, the outcome of
 * the read_skb check and all return values are not visible in this listing;
 * presumably the first path is the @restore case and a transport without
 * read_skb is rejected -- confirm against the full file.
 */
141 int vsock_bpf_update_proto(struct sock
*sk
, struct sk_psock
*psock
, bool restore
)
143 struct vsock_sock
*vsk
;
/* Put back the write_space callback and proto saved in the psock. */
146 sk
->sk_write_space
= psock
->saved_write_space
;
147 sock_replace_proto(sk
, psock
->sk_proto
);
155 if (!vsk
->transport
->read_skb
)
/* Make sure vsock_bpf_prot matches the socket's proto, then install it. */
158 vsock_bpf_check_needs_rebuild(psock
->sk_proto
);
159 sock_replace_proto(sk
, &vsock_bpf_prot
);
/* Init-time (__init) setup: build vsock_bpf_prot from vsock_proto through
 * vsock_bpf_rebuild_protos().
 */
163 void __init
vsock_bpf_build_proto(void)
165 vsock_bpf_rebuild_protos(&vsock_bpf_prot
, &vsock_proto
);