1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
4 #include <linux/skmsg.h>
5 #include <linux/filter.h>
7 #include <linux/init.h>
8 #include <linux/wait.h>
9 #include <linux/util_macros.h>
11 #include <net/inet_common.h>
/* tcp_eat_skb - account for an skb that has been consumed from a TCP socket.
 * @sk:  socket the skb belongs to
 * @skb: buffer being consumed
 *
 * Advances tcp->copied_seq by skb->len, then calls tcp_rcv_space_adjust()
 * and __tcp_cleanup_rbuf() with that same length.
 *
 * NOTE(review): this listing has dropped lines (braces, the declarations
 * of 'tcp' and 'copied', and whatever follows the two guard conditions
 * below -- presumably early returns). Confirm against the complete file.
 */
14 void tcp_eat_skb(struct sock
*sk
, struct sk_buff
*skb
)
19 if (!skb
|| !skb
->len
|| !sk_is_tcp(sk
))
22 if (skb_bpf_strparser(skb
))
/* New read position: current copied_seq advanced by the skb length. */
26 copied
= tcp
->copied_seq
+ skb
->len
;
27 WRITE_ONCE(tcp
->copied_seq
, copied
);
28 tcp_rcv_space_adjust(sk
);
29 __tcp_cleanup_rbuf(sk
, skb
->len
);
/* bpf_tcp_ingress - move data from @msg into a new sk_msg queued on @psock.
 * @sk:          socket whose memory accounting is charged
 * @psock:       psock that receives the new message
 * @msg:         source message; elements are transferred out of it
 * @apply_bytes: byte budget; zero means "no budget" (see 'apply')
 *
 * Allocates a zeroed sk_msg, walks the elements of @msg up to msg->sg.end
 * transferring each with sk_msg_xfer(), then hands the new message to
 * sk_psock_queue_msg() and calls sk_psock_data_ready().
 *
 * NOTE(review): this listing has dropped lines (braces, several local
 * declarations, the allocation-failure and sk_wmem_schedule() failure
 * handling, the loop head and the return). Confirm against the full file.
 */
32 static int bpf_tcp_ingress(struct sock
*sk
, struct sk_psock
*psock
,
33 struct sk_msg
*msg
, u32 apply_bytes
)
/* A non-zero apply_bytes enables the per-element clamp below. */
35 bool apply
= apply_bytes
;
36 struct scatterlist
*sge
;
41 tmp
= kzalloc(sizeof(*tmp
), __GFP_NOWARN
| GFP_KERNEL
);
/* The new message starts at the same index as the source. */
46 tmp
->sg
.start
= msg
->sg
.start
;
49 sge
= sk_msg_elem(msg
, i
);
/* Take the whole element unless the apply_bytes budget is smaller. */
50 size
= (apply
&& apply_bytes
< sge
->length
) ?
51 apply_bytes
: sge
->length
;
52 if (!sk_wmem_schedule(sk
, size
)) {
58 sk_mem_charge(sk
, size
);
59 sk_msg_xfer(tmp
, msg
, i
, size
);
62 get_page(sk_msg_page(tmp
, i
));
63 sk_msg_iter_var_next(i
);
69 sk_msg_iter_var_prev(i
);
73 } while (i
!= msg
->sg
.end
);
77 sk_psock_queue_msg(psock
, tmp
);
78 sk_psock_data_ready(sk
, psock
);
/* tcp_bpf_push - transmit the leading element(s) of @msg on @sk.
 * @sk:          sending socket
 * @msg:         message whose element at msg->sg.start is sent
 * @apply_bytes: byte budget; zero means "no budget" (see 'apply')
 * @flags:       caller flags, OR-ed into msghdr.msg_flags
 * @uncharge:    how this gates sk_mem_uncharge() is not visible here
 *
 * Wraps the current element in a one-entry bvec iterator and passes it to
 * tcp_sendmsg_locked() with MSG_SPLICE_PAGES set.
 *
 * NOTE(review): this listing has dropped lines (braces, declarations of
 * 'bvec', 'page', 'off', 'ret', 'size', 'has_tx_ulp', the loop structure,
 * the condition guarding MSG_SENDPAGE_NOPOLICY -- presumably has_tx_ulp --
 * and the return). Confirm against the complete file.
 */
88 static int tcp_bpf_push(struct sock
*sk
, struct sk_msg
*msg
, u32 apply_bytes
,
89 int flags
, bool uncharge
)
91 struct msghdr msghdr
= {};
92 bool apply
= apply_bytes
;
93 struct scatterlist
*sge
;
102 sge
= sk_msg_elem(msg
, msg
->sg
.start
);
/* Send the whole element unless the apply_bytes budget is smaller. */
103 size
= (apply
&& apply_bytes
< sge
->length
) ?
104 apply_bytes
: sge
->length
;
108 tcp_rate_check_app_limited(sk
);
/* Caller flags plus MSG_SPLICE_PAGES go to tcp_sendmsg_locked() below. */
110 msghdr
.msg_flags
= flags
| MSG_SPLICE_PAGES
;
111 has_tx_ulp
= tls_sw_has_ctx_tx(sk
);
113 msghdr
.msg_flags
|= MSG_SENDPAGE_NOPOLICY
;
/* Sending only part of the element while start != end: set MSG_MORE. */
115 if (size
< sge
->length
&& msg
->sg
.start
!= msg
->sg
.end
)
116 msghdr
.msg_flags
|= MSG_MORE
;
/* Describe the page range as a single-entry bvec source iterator. */
118 bvec_set_page(&bvec
, page
, size
, off
);
119 iov_iter_bvec(&msghdr
.msg_iter
, ITER_SOURCE
, &bvec
, 1, size
);
120 ret
= tcp_sendmsg_locked(sk
, &msghdr
, size
);
130 sk_mem_uncharge(sk
, ret
);
138 sk_msg_iter_next(msg
, start
);
139 sg_init_table(sge
, 1);
140 if (msg
->sg
.start
== msg
->sg
.end
)
143 if (apply
&& !apply_bytes
)
/* tcp_bpf_push_locked - wrapper that forwards all arguments to tcp_bpf_push().
 *
 * NOTE(review): only the forwarding call survives in this listing. The
 * name suggests the socket lock is taken around it, but the lock/unlock
 * lines, the declaration of 'ret' and the return are not visible here --
 * confirm against the complete file.
 */
150 static int tcp_bpf_push_locked(struct sock
*sk
, struct sk_msg
*msg
,
151 u32 apply_bytes
, int flags
, bool uncharge
)
156 ret
= tcp_bpf_push(sk
, msg
, apply_bytes
, flags
, uncharge
);
/* tcp_bpf_sendmsg_redir - deliver @msg to a redirect target socket.
 * @sk:      target socket; its psock is looked up with sk_psock_get()
 * @ingress: true routes to bpf_tcp_ingress(), false to
 *           tcp_bpf_push_locked() (called with uncharge == false)
 * @msg:     message to deliver
 * @bytes:   passed through as the apply_bytes budget
 * @flags:   send flags, used on the tcp_bpf_push_locked() path only
 *
 * The psock reference taken here is dropped with sk_psock_put() after the
 * call. Exported with EXPORT_SYMBOL_GPL.
 *
 * NOTE(review): this listing has dropped lines (braces, the declaration
 * of 'ret', the action taken when no psock is found, and the return).
 */
161 int tcp_bpf_sendmsg_redir(struct sock
*sk
, bool ingress
,
162 struct sk_msg
*msg
, u32 bytes
, int flags
)
164 struct sk_psock
*psock
= sk_psock_get(sk
);
167 if (unlikely(!psock
))
170 ret
= ingress
? bpf_tcp_ingress(sk
, psock
, msg
, bytes
) :
171 tcp_bpf_push_locked(sk
, msg
, bytes
, flags
, false);
172 sk_psock_put(sk
, psock
);
175 EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir
);
177 #ifdef CONFIG_BPF_SYSCALL
/* tcp_msg_wait_data - sleep until @psock or @sk has data to read.
 *
 * Registers a wait entry on sk_sleep(sk), sets SOCKWQ_ASYNC_WAITDATA for
 * the duration of the wait, and blocks in sk_wait_event() with 'timeo'
 * until the psock ingress_msg list or sk->sk_receive_queue is non-empty.
 * The wait entry and the bit are both removed again afterwards.
 *
 * NOTE(review): the parameter list is cut off in this listing ('timeo'
 * is used but its declaration is not visible), as are the braces, the
 * declaration of 'ret', the action taken on RCV_SHUTDOWN and the return.
 */
178 static int tcp_msg_wait_data(struct sock
*sk
, struct sk_psock
*psock
,
181 DEFINE_WAIT_FUNC(wait
, woken_wake_function
);
184 if (sk
->sk_shutdown
& RCV_SHUTDOWN
)
190 add_wait_queue(sk_sleep(sk
), &wait
);
191 sk_set_bit(SOCKWQ_ASYNC_WAITDATA
, sk
);
/* Wake condition: psock ingress list or socket receive queue non-empty. */
192 ret
= sk_wait_event(sk
, &timeo
,
193 !list_empty(&psock
->ingress_msg
) ||
194 !skb_queue_empty_lockless(&sk
->sk_receive_queue
), &wait
);
195 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA
, sk
);
196 remove_wait_queue(sk_sleep(sk
), &wait
);
/* is_next_msg_fin - test whether the next queued message carries a TCP FIN.
 * @psock: psock whose first queued message is peeked
 *
 * Peeks the message with sk_psock_peek_msg(), fetches the scatterlist
 * element at its sg.start, and checks TCPHDR_FIN in the tcp_flags of the
 * message's skb when one is attached.
 *
 * NOTE(review): this listing has dropped lines (braces, the declaration
 * of 'i', the condition involving 'sge' that presumably guards the skb
 * check, and the return statements). Confirm against the complete file.
 */
200 static bool is_next_msg_fin(struct sk_psock
*psock
)
202 struct scatterlist
*sge
;
203 struct sk_msg
*msg_rx
;
206 msg_rx
= sk_psock_peek_msg(psock
);
207 i
= msg_rx
->sg
.start
;
208 sge
= sk_msg_elem(msg_rx
, i
);
210 struct sk_buff
*skb
= msg_rx
->skb
;
212 if (skb
&& TCP_SKB_CB(skb
)->tcp_flags
& TCPHDR_FIN
)
/* tcp_bpf_recvmsg_parser - recvmsg handler installed for the RX and TXRX
 * proto configurations (see tcp_bpf_rebuild_protos()).
 *
 * Visible flow: MSG_ERRQUEUE requests go to inet_recv_error(); a socket
 * without a psock falls back to tcp_recvmsg(). Otherwise data is read
 * with sk_msg_recvmsg(). An -EFAULT result is checked against
 * is_next_msg_fin(). When nothing is available the code consults
 * SOCK_DONE, sock_error(), RCV_SHUTDOWN, TCP_CLOSE, the receive timeout
 * and pending signals before waiting in tcp_msg_wait_data() and retrying
 * via the msg_bytes_ready label. Before the psock reference is dropped,
 * tcp->copied_seq is written from 'seq' and tcp_rcv_space_adjust() and
 * __tcp_cleanup_rbuf() are called.
 *
 * NOTE(review): this listing is heavily truncated. The parameter list is
 * cut off after 'sk'; braces, labels, locking, most declarations and all
 * branch bodies are missing. The three block comments inside the body
 * have also lost their closing delimiters. Restore them from the
 * complete file before relying on this text.
 */
218 static int tcp_bpf_recvmsg_parser(struct sock
*sk
,
224 int peek
= flags
& MSG_PEEK
;
225 struct sk_psock
*psock
;
226 struct tcp_sock
*tcp
;
230 if (unlikely(flags
& MSG_ERRQUEUE
))
231 return inet_recv_error(sk
, msg
, len
, addr_len
);
236 psock
= sk_psock_get(sk
);
237 if (unlikely(!psock
))
238 return tcp_recvmsg(sk
, msg
, len
, flags
, addr_len
);
242 seq
= tcp
->copied_seq
;
243 /* We may have received data on the sk_receive_queue pre-accept and
244 * then we can not use read_skb in this context because we haven't
245 * assigned a sk_socket yet so have no link to the ops. The work-around
246 * is to check the sk_receive_queue and in these cases read skbs off
247 * queue again. The read_skb hook is not running at this point because
248 * of lock_sock so we avoid having multiple runners in read_skb.
250 if (unlikely(!skb_queue_empty(&sk
->sk_receive_queue
))) {
252 /* This handles the ENOMEM errors if we both receive data
253 * pre accept and are already under memory pressure. At least
254 * let user know to retry.
256 if (unlikely(!skb_queue_empty(&sk
->sk_receive_queue
))) {
263 copied
= sk_msg_recvmsg(sk
, psock
, msg
, len
, flags
);
264 /* The typical case for EFAULT is the socket was gracefully
265 * shutdown with a FIN pkt. So check here the other case is
266 * some error on copy_page_to_iter which would be unexpected.
267 * On fin return correct return code to zero.
269 if (copied
== -EFAULT
) {
270 bool is_fin
= is_next_msg_fin(psock
);
283 if (sock_flag(sk
, SOCK_DONE
))
287 copied
= sock_error(sk
);
291 if (sk
->sk_shutdown
& RCV_SHUTDOWN
)
294 if (sk
->sk_state
== TCP_CLOSE
) {
299 timeo
= sock_rcvtimeo(sk
, flags
& MSG_DONTWAIT
);
305 if (signal_pending(current
)) {
306 copied
= sock_intr_errno(timeo
);
310 data
= tcp_msg_wait_data(sk
, psock
, timeo
);
315 if (data
&& !sk_psock_queue_empty(psock
))
316 goto msg_bytes_ready
;
321 WRITE_ONCE(tcp
->copied_seq
, seq
);
322 tcp_rcv_space_adjust(sk
);
324 __tcp_cleanup_rbuf(sk
, copied
);
328 sk_psock_put(sk
, psock
);
/* tcp_bpf_recvmsg - recvmsg handler installed for the BASE and TX proto
 * configurations (see tcp_bpf_rebuild_protos()).
 * @sk:       receiving socket
 * @msg:      destination message header
 * @len:      maximum number of bytes to receive
 * @flags:    MSG_* receive flags
 * @addr_len: forwarded to inet_recv_error() / tcp_recvmsg()
 *
 * Visible flow: MSG_ERRQUEUE requests go to inet_recv_error(). The call
 * is handed to plain tcp_recvmsg() when there is no psock, or when the
 * socket receive queue has data while the psock queue is empty.
 * Otherwise data is read with sk_msg_recvmsg(). If a wait in
 * tcp_msg_wait_data() is needed, the code either retries through
 * msg_bytes_ready or drops the psock and falls back to tcp_recvmsg().
 *
 * NOTE(review): this listing has dropped lines (braces, locking, labels,
 * declarations of 'copied', 'timeo', 'data', the conditions around the
 * wait, and the final return). Confirm against the complete file.
 */
332 static int tcp_bpf_recvmsg(struct sock
*sk
, struct msghdr
*msg
, size_t len
,
333 int flags
, int *addr_len
)
335 struct sk_psock
*psock
;
338 if (unlikely(flags
& MSG_ERRQUEUE
))
339 return inet_recv_error(sk
, msg
, len
, addr_len
);
344 psock
= sk_psock_get(sk
);
345 if (unlikely(!psock
))
346 return tcp_recvmsg(sk
, msg
, len
, flags
, addr_len
);
/* Data sits only on the regular receive queue: use the stock path. */
347 if (!skb_queue_empty(&sk
->sk_receive_queue
) &&
348 sk_psock_queue_empty(psock
)) {
349 sk_psock_put(sk
, psock
);
350 return tcp_recvmsg(sk
, msg
, len
, flags
, addr_len
);
354 copied
= sk_msg_recvmsg(sk
, psock
, msg
, len
, flags
);
359 timeo
= sock_rcvtimeo(sk
, flags
& MSG_DONTWAIT
);
360 data
= tcp_msg_wait_data(sk
, psock
, timeo
);
366 if (!sk_psock_queue_empty(psock
))
367 goto msg_bytes_ready
;
369 sk_psock_put(sk
, psock
);
370 return tcp_recvmsg(sk
, msg
, len
, flags
, addr_len
);
378 sk_psock_put(sk
, psock
);
/* tcp_bpf_send_verdict - run the msg verdict and act on its result.
 * @sk:     sending socket
 * @psock:  psock holding eval, apply_bytes, cork and redirect state
 * @msg:    message being sent
 * @copied: in/out byte count reported to the caller; reduced when data
 *          is freed or dropped
 * @flags:  send flags forwarded to the push/redirect helpers
 *
 * Visible flow: when psock->eval is __SK_NONE the verdict is obtained
 * from sk_psock_msg_verdict(), and 'delta' records how much msg->sg.size
 * changed across that call. If the message asks for more cork bytes than
 * it holds (and is not full), it is copied into a newly allocated
 * psock->cork. 'tosend' is msg->sg.size, capped by psock->apply_bytes
 * when that is set. The switch on psock->eval then either pushes with
 * tcp_bpf_push(), redirects with tcp_bpf_sendmsg_redir() using the saved
 * sk_redir / redir_ingress, or frees 'tosend' bytes and subtracts
 * tosend + delta from *copied. Once apply_bytes reaches zero, eval is
 * reset to __SK_NONE and any sk_redir reference is released.
 *
 * NOTE(review): this listing is heavily truncated. Case labels, braces,
 * locking, the declaration and assignment of 'eval' and 'ret', gotos and
 * returns are missing. The block comment that starts at "Track delta"
 * has also lost its closing delimiter. Restore these from the complete
 * file before relying on this text.
 */
382 static int tcp_bpf_send_verdict(struct sock
*sk
, struct sk_psock
*psock
,
383 struct sk_msg
*msg
, int *copied
, int flags
)
385 bool cork
= false, enospc
= sk_msg_full(msg
), redir_ingress
;
386 struct sock
*sk_redir
;
387 u32 tosend
, origsize
, sent
, delta
= 0;
392 if (psock
->eval
== __SK_NONE
) {
393 /* Track delta in msg size to add/subtract it on SK_DROP from
394 * returned to user copied size. This ensures user doesn't
395 * get a positive return code with msg_cut_data and SK_DROP
398 delta
= msg
->sg
.size
;
399 psock
->eval
= sk_psock_msg_verdict(sk
, psock
, msg
);
400 delta
-= msg
->sg
.size
;
403 if (msg
->cork_bytes
&&
404 msg
->cork_bytes
> msg
->sg
.size
&& !enospc
) {
405 psock
->cork_bytes
= msg
->cork_bytes
- msg
->sg
.size
;
407 psock
->cork
= kzalloc(sizeof(*psock
->cork
),
408 GFP_ATOMIC
| __GFP_NOWARN
);
412 memcpy(psock
->cork
, msg
, sizeof(*msg
));
416 tosend
= msg
->sg
.size
;
417 if (psock
->apply_bytes
&& psock
->apply_bytes
< tosend
)
418 tosend
= psock
->apply_bytes
;
421 switch (psock
->eval
) {
423 ret
= tcp_bpf_push(sk
, msg
, tosend
, flags
, true);
425 *copied
-= sk_msg_free(sk
, msg
);
428 sk_msg_apply_bytes(psock
, tosend
);
431 redir_ingress
= psock
->redir_ingress
;
432 sk_redir
= psock
->sk_redir
;
433 sk_msg_apply_bytes(psock
, tosend
);
434 if (!psock
->apply_bytes
) {
435 /* Clean up before releasing the sock lock. */
437 psock
->eval
= __SK_NONE
;
438 psock
->sk_redir
= NULL
;
444 sk_msg_return(sk
, msg
, tosend
);
447 origsize
= msg
->sg
.size
;
448 ret
= tcp_bpf_sendmsg_redir(sk_redir
, redir_ingress
,
450 sent
= origsize
- msg
->sg
.size
;
452 if (eval
== __SK_REDIRECT
)
456 if (unlikely(ret
< 0)) {
457 int free
= sk_msg_free_nocharge(sk
, msg
);
463 sk_msg_free(sk
, msg
);
471 sk_msg_free_partial(sk
, msg
, tosend
);
472 sk_msg_apply_bytes(psock
, tosend
);
473 *copied
-= (tosend
+ delta
);
478 if (!psock
->apply_bytes
) {
479 psock
->eval
= __SK_NONE
;
480 if (psock
->sk_redir
) {
481 sock_put(psock
->sk_redir
);
482 psock
->sk_redir
= NULL
;
486 msg
->sg
.data
[msg
->sg
.start
].page_link
&&
487 msg
->sg
.data
[msg
->sg
.start
].length
) {
488 if (eval
== __SK_REDIRECT
)
489 sk_mem_charge(sk
, tosend
- sent
);
/* tcp_bpf_sendmsg - sendmsg handler installed for the TX and TXRX proto
 * configurations (see tcp_bpf_rebuild_protos()).
 * @sk:   sending socket
 * @msg:  user message to send
 * @size: number of bytes requested
 *
 * Visible flow: a socket without a psock falls back to tcp_sendmsg().
 * Otherwise, while @msg has data left, space is allocated in 'msg_tx'
 * with sk_msg_alloc(), user data is copied in with
 * sk_msg_memcopy_from_iter() (trimmed back to 'osize' with sk_msg_trim()
 * on the failure path), the outstanding psock->cork_bytes are updated,
 * and the message goes to tcp_bpf_send_verdict(). When send-buffer
 * memory is short the code sets SOCK_NOSPACE and waits in
 * sk_stream_wait_memory(). Returns the number of bytes copied if
 * positive, else 'err'.
 *
 * NOTE(review): this listing has dropped lines (braces, locking, labels,
 * declarations of 'flags', 'timeo', 'copy', 'osize', 'enospc', several
 * conditions and else-branches). In particular, the statements following
 * "if (size > psock->cork_bytes)" appear without their else, so the text
 * as shown must not be read as sequential code.
 */
496 static int tcp_bpf_sendmsg(struct sock
*sk
, struct msghdr
*msg
, size_t size
)
498 struct sk_msg tmp
, *msg_tx
= NULL
;
499 int copied
= 0, err
= 0;
500 struct sk_psock
*psock
;
504 /* Don't let internal flags through */
505 flags
= (msg
->msg_flags
& ~MSG_SENDPAGE_DECRYPTED
);
506 flags
|= MSG_NO_SHARED_FRAGS
;
508 psock
= sk_psock_get(sk
);
509 if (unlikely(!psock
))
510 return tcp_sendmsg(sk
, msg
, size
);
513 timeo
= sock_sndtimeo(sk
, msg
->msg_flags
& MSG_DONTWAIT
);
514 while (msg_data_left(msg
)) {
523 copy
= msg_data_left(msg
);
524 if (!sk_stream_memory_free(sk
))
525 goto wait_for_sndbuf
;
527 msg_tx
= psock
->cork
;
/* Remember the size before growing so a failed copy can be undone. */
533 osize
= msg_tx
->sg
.size
;
534 err
= sk_msg_alloc(sk
, msg_tx
, msg_tx
->sg
.size
+ copy
, msg_tx
->sg
.end
- 1);
537 goto wait_for_memory
;
/* Copy only as much as the allocation actually added. */
539 copy
= msg_tx
->sg
.size
- osize
;
542 err
= sk_msg_memcopy_from_iter(sk
, &msg
->msg_iter
, msg_tx
,
545 sk_msg_trim(sk
, msg_tx
, osize
);
550 if (psock
->cork_bytes
) {
551 if (size
> psock
->cork_bytes
)
552 psock
->cork_bytes
= 0;
554 psock
->cork_bytes
-= size
;
555 if (psock
->cork_bytes
&& !enospc
)
557 /* All cork bytes are accounted, rerun the prog. */
558 psock
->eval
= __SK_NONE
;
559 psock
->cork_bytes
= 0;
562 err
= tcp_bpf_send_verdict(sk
, psock
, msg_tx
, &copied
, flags
);
563 if (unlikely(err
< 0))
567 set_bit(SOCK_NOSPACE
, &sk
->sk_socket
->flags
);
569 err
= sk_stream_wait_memory(sk
, &timeo
);
/* Free the working message unless it is the psock's cork buffer. */
571 if (msg_tx
&& msg_tx
!= psock
->cork
)
572 sk_msg_free(sk
, msg_tx
);
578 err
= sk_stream_error(sk
, msg
->msg_flags
, err
);
580 sk_psock_put(sk
, psock
);
/* Report bytes accepted when there are any, otherwise the error. */
581 return copied
> 0 ? copied
: err
;
/* The proto pointer that tcp_bpf_prots[TCP_BPF_IPV6] was last built from.
 * Read with smp_load_acquire() and published with smp_store_release()
 * in tcp_bpf_check_v6_needs_rebuild().
 */
598 static struct proto
*tcpv6_prot_saved __read_mostly
;
/* Serialises rebuilds of the IPv6 row of tcp_bpf_prots. */
599 static DEFINE_SPINLOCK(tcpv6_prot_lock
);
/* Replacement proto tables, indexed [family][config]. The index
 * constants are defined on lines not present in this listing.
 */
600 static struct proto tcp_bpf_prots
[TCP_BPF_NUM_PROTS
][TCP_BPF_NUM_CFGS
];
/* tcp_bpf_rebuild_protos - derive the four tcp_bpf proto variants.
 * @prot: array of TCP_BPF_NUM_CFGS entries to fill in
 *
 * BASE is a copy of *base with destroy, close, recvmsg and
 * sock_is_readable overridden. TX is BASE plus tcp_bpf_sendmsg. RX is
 * BASE with recvmsg switched to tcp_bpf_recvmsg_parser. TXRX is TX with
 * that same recvmsg.
 *
 * NOTE(review): the parameter list is cut off in this listing -- 'base'
 * is dereferenced but its declaration is not visible -- and the braces
 * are missing.
 */
602 static void tcp_bpf_rebuild_protos(struct proto prot
[TCP_BPF_NUM_CFGS
],
605 prot
[TCP_BPF_BASE
] = *base
;
606 prot
[TCP_BPF_BASE
].destroy
= sock_map_destroy
;
607 prot
[TCP_BPF_BASE
].close
= sock_map_close
;
608 prot
[TCP_BPF_BASE
].recvmsg
= tcp_bpf_recvmsg
;
609 prot
[TCP_BPF_BASE
].sock_is_readable
= sk_msg_is_readable
;
/* TX: BASE plus the BPF-aware sendmsg. */
611 prot
[TCP_BPF_TX
] = prot
[TCP_BPF_BASE
];
612 prot
[TCP_BPF_TX
].sendmsg
= tcp_bpf_sendmsg
;
/* RX: BASE with the parser-aware recvmsg. */
614 prot
[TCP_BPF_RX
] = prot
[TCP_BPF_BASE
];
615 prot
[TCP_BPF_RX
].recvmsg
= tcp_bpf_recvmsg_parser
;
/* TXRX: TX with the parser-aware recvmsg. */
617 prot
[TCP_BPF_TXRX
] = prot
[TCP_BPF_TX
];
618 prot
[TCP_BPF_TXRX
].recvmsg
= tcp_bpf_recvmsg_parser
;
/* tcp_bpf_check_v6_needs_rebuild - rebuild the IPv6 protos if @ops changed.
 * @ops: the proto the IPv6 variants should be derived from
 *
 * Compares @ops with tcpv6_prot_saved without the lock first. Only when
 * they differ does it take tcpv6_prot_lock, compare again, rebuild
 * tcp_bpf_prots[TCP_BPF_IPV6] from @ops and publish @ops as the new
 * saved pointer.
 *
 * NOTE(review): closing braces are missing from this listing.
 */
621 static void tcp_bpf_check_v6_needs_rebuild(struct proto
*ops
)
623 if (unlikely(ops
!= smp_load_acquire(&tcpv6_prot_saved
))) {
624 spin_lock_bh(&tcpv6_prot_lock
);
/* Re-check under the lock before rebuilding. */
625 if (likely(ops
!= tcpv6_prot_saved
)) {
626 tcp_bpf_rebuild_protos(tcp_bpf_prots
[TCP_BPF_IPV6
], ops
);
/* Publish only after the table has been rebuilt. */
627 smp_store_release(&tcpv6_prot_saved
, ops
);
629 spin_unlock_bh(&tcpv6_prot_lock
);
/* tcp_bpf_v4_build_proto - build the IPv4 proto variants from tcp_prot.
 *
 * Registered with late_initcall().
 *
 * NOTE(review): the braces and the return statement are missing from
 * this listing.
 */
633 static int __init
tcp_bpf_v4_build_proto(void)
635 tcp_bpf_rebuild_protos(tcp_bpf_prots
[TCP_BPF_IPV4
], &tcp_prot
);
638 late_initcall(tcp_bpf_v4_build_proto
);
/* tcp_bpf_assert_proto_ops - verify @ops uses the stock TCP handlers.
 * @ops: proto to inspect
 *
 * Returns 0 when ops->recvmsg is tcp_recvmsg and ops->sendmsg is
 * tcp_sendmsg, otherwise -ENOTSUPP.
 *
 * NOTE(review): the braces are missing from this listing and the block
 * comment inside the body has lost its closing delimiter.
 */
640 static int tcp_bpf_assert_proto_ops(struct proto
*ops
)
642 /* In order to avoid retpoline, we make assumptions when we call
643 * into ops if e.g. a psock is not present. Make sure they are
644 * indeed valid assumptions.
646 return ops
->recvmsg
== tcp_recvmsg
&&
647 ops
->sendmsg
== tcp_sendmsg
? 0 : -ENOTSUPP
;
/* tcp_bpf_update_proto - install or remove the tcp_bpf proto on @sk.
 * @sk:      socket to update
 * @psock:   psock holding the attached programs and the saved proto,
 *           unhash and write_space callbacks
 * @restore: presumably selects the restore path below -- the condition
 *           that tests it is not visible in this listing
 *
 * Visible flow: the table row is chosen from sk->sk_family and the
 * column from which programs are attached (msg_parser selects TX; a
 * stream or skb verdict program upgrades BASE to RX and TX to TXRX).
 * The restore path puts back the saved callbacks -- via tcp_update_ulp()
 * when the socket has a ULP, otherwise by restoring sk_write_space and
 * calling sock_replace_proto() with psock->sk_proto. The install path
 * validates and, if needed, rebuilds the IPv6 table, then calls
 * sock_replace_proto() with tcp_bpf_prots[family][config].
 * Exported with EXPORT_SYMBOL_GPL.
 *
 * NOTE(review): this listing has dropped lines (braces, the test of
 * 'restore', else-branches and return statements), and the block
 * comment starting "TLS does not have an unhash proto" has lost its
 * closing delimiter. Confirm against the complete file.
 */
650 int tcp_bpf_update_proto(struct sock
*sk
, struct sk_psock
*psock
, bool restore
)
652 int family
= sk
->sk_family
== AF_INET6
? TCP_BPF_IPV6
: TCP_BPF_IPV4
;
653 int config
= psock
->progs
.msg_parser
? TCP_BPF_TX
: TCP_BPF_BASE
;
/* A verdict program upgrades the config to its RX-capable variant. */
655 if (psock
->progs
.stream_verdict
|| psock
->progs
.skb_verdict
) {
656 config
= (config
== TCP_BPF_TX
) ? TCP_BPF_TXRX
: TCP_BPF_RX
;
660 if (inet_csk_has_ulp(sk
)) {
661 /* TLS does not have an unhash proto in SW cases,
662 * but we need to ensure we stop using the sock_map
663 * unhash routine because the associated psock is being
664 * removed. So use the original unhash handler.
666 WRITE_ONCE(sk
->sk_prot
->unhash
, psock
->saved_unhash
);
667 tcp_update_ulp(sk
, psock
->sk_proto
, psock
->saved_write_space
);
669 sk
->sk_write_space
= psock
->saved_write_space
;
670 /* Pairs with lockless read in sk_clone_lock() */
671 sock_replace_proto(sk
, psock
->sk_proto
);
676 if (sk
->sk_family
== AF_INET6
) {
677 if (tcp_bpf_assert_proto_ops(psock
->sk_proto
))
680 tcp_bpf_check_v6_needs_rebuild(psock
->sk_proto
);
683 /* Pairs with lockless read in sk_clone_lock() */
684 sock_replace_proto(sk
, &tcp_bpf_prots
[family
][config
]);
687 EXPORT_SYMBOL_GPL(tcp_bpf_update_proto
);
/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
/* tcp_bpf_clone - reset the proto of a cloned socket if needed.
 * @sk:    parent socket; the source of sk_prot_creator
 * @newsk: cloned socket whose sk_prot is checked
 *
 * When @newsk's sk_prot points inside tcp_bpf_prots, it is replaced with
 * @sk's sk_prot_creator.
 *
 * NOTE(review): the braces are missing from this listing.
 */
694 void tcp_bpf_clone(const struct sock
*sk
, struct sock
*newsk
)
696 struct proto
*prot
= newsk
->sk_prot
;
698 if (is_insidevar(prot
, tcp_bpf_prots
))
699 newsk
->sk_prot
= sk
->sk_prot_creator
;
701 #endif /* CONFIG_BPF_SYSCALL */