// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

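/* tcp_bpf_stream_read() backs the stream_memory_read proto op: it reports
 * whether the psock ingress_msg list has queued data, so that poll() sees
 * BPF-redirected bytes that never hit sk_receive_queue.
 */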
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

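/* Sleep until either the psock ingress list or the regular receive queue
 * has data, or until the receive timeout expires.
 */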
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

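/* Copy data queued on the psock ingress list into the user iov. With
 * MSG_PEEK the scatterlist elements are left in place; otherwise consumed
 * bytes are uncharged from the socket and fully drained messages freed.
 */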
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

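/* recvmsg hook installed via tcp_bpf_prots: fall back to tcp_recvmsg()
 * when no psock is attached or data sits on sk_receive_queue, otherwise
 * drain the ingress list, waiting for data if necessary.
 */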
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

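/* Redirect msg data to a socket's ingress queue: the covered scatterlist
 * entries are moved onto a newly allocated sk_msg, charged to the target
 * socket and queued on its psock ingress list.
 */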
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

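/* Transmit up to apply_bytes of msg on the egress path. If a TLS TX ULP is
 * active the locked sendpage variant is used with MSG_SENDPAGE_NOPOLICY so
 * the redirected data is not run through the BPF policy a second time.
 */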
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct page *page;
	int ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off  = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off  += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

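/* Used by the sk_msg redirect path: push msg either into the target
 * socket's ingress queue or out its egress path, depending on whether the
 * msg carries the ingress flag (sk_msg_to_ingress()).
 */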
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

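/* Run the msg verdict program (at most once per apply_bytes window) and
 * act on the result: __SK_PASS transmits locally, __SK_REDIRECT hands the
 * data to tcp_bpf_sendmsg_redir(), __SK_DROP frees it and returns -EACCES.
 */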
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		if (msg->sg.size < delta)
			delta -= msg->sg.size;
		else
			delta = 0;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

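/* sendmsg hook used when a msg_parser program is attached: copy user data
 * into an sk_msg, honour any cork_bytes requested by the program, then run
 * the result through tcp_bpf_send_verdict().
 */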
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

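/* sendpage counterpart of tcp_bpf_sendmsg(): append the page to an sk_msg,
 * apply the same cork_bytes handling, and pass it to the verdict logic.
 */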
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

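/* Drop any sockmap/sockhash links still attached to the psock before the
 * saved unhash/close callbacks are restored and invoked below.
 */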
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

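/* The BPF proto variants are kept in a small table indexed by address
 * family and config: TCP_BPF_BASE only overrides the receive-side hooks,
 * TCP_BPF_TX additionally replaces sendmsg/sendpage for msg_parser users.
 */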
enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE]			= *base;
	prot[TCP_BPF_BASE].unhash		= tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close		= tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;

	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

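/* tcp_prot is known at build time, so its BPF variant is built from an
 * initcall; the IPv6 proto may live in a module, so it is captured lazily
 * by tcp_bpf_check_v6_needs_rebuild() above.
 */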
static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg  == tcp_recvmsg &&
	       ops->sendmsg  == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

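/* tcp_bpf_reinit() and tcp_bpf_init() run with the socket lock held (see
 * sock_owned_by_me()) and switch sk->sk_prot to the matching
 * tcp_bpf_prots entry for the socket's family and attached programs.
 */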
void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}