// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

static bool tcp_bpf_stream_read(const struct sock *sk)
{
        struct sk_psock *psock;
        bool empty = true;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock))
                empty = list_empty(&psock->ingress_msg);
        rcu_read_unlock();
        return !empty;
}
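
/* tcp_bpf_stream_read() above is wired up as the ->stream_memory_read hook
 * (see tcp_bpf_rebuild_protos() below), so poll() reports the socket readable
 * when BPF-redirected data sits on psock->ingress_msg even though
 * sk_receive_queue may be empty.
 */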

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
                             int flags, long timeo, int *err)
{
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int ret = 0;

        if (!timeo)
                return ret;

        add_wait_queue(sk_sleep(sk), &wait);
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        ret = sk_wait_event(sk, &timeo,
                            !list_empty(&psock->ingress_msg) ||
                            !skb_queue_empty(&sk->sk_receive_queue), &wait);
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        remove_wait_queue(sk_sleep(sk), &wait);
        return ret;
}
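
/* tcp_bpf_wait_data() parks the reader until redirected data shows up on
 * psock->ingress_msg, regular TCP data lands on sk_receive_queue, or the
 * receive timeout expires. It is woken through the socket's data-ready
 * callback when bpf_tcp_ingress() queues a message.
 */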

int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
                      struct msghdr *msg, int len, int flags)
{
        struct iov_iter *iter = &msg->msg_iter;
        int peek = flags & MSG_PEEK;
        int i, ret, copied = 0;
        struct sk_msg *msg_rx;

        msg_rx = list_first_entry_or_null(&psock->ingress_msg,
                                          struct sk_msg, list);

        while (copied != len) {
                struct scatterlist *sge;

                if (unlikely(!msg_rx))
                        break;

                i = msg_rx->sg.start;
                do {
                        struct page *page;
                        int copy;

                        sge = sk_msg_elem(msg_rx, i);
                        copy = sge->length;
                        page = sg_page(sge);
                        if (copied + copy > len)
                                copy = len - copied;
                        ret = copy_page_to_iter(page, sge->offset, copy, iter);
                        if (ret != copy) {
                                msg_rx->sg.start = i;
                                return -EFAULT;
                        }

                        copied += copy;
                        if (likely(!peek)) {
                                sge->offset += copy;
                                sge->length -= copy;
                                sk_mem_uncharge(sk, copy);
                                msg_rx->sg.size -= copy;

                                if (!sge->length) {
                                        sk_msg_iter_var_next(i);
                                        if (!msg_rx->skb)
                                                put_page(page);
                                }
                        } else {
                                sk_msg_iter_var_next(i);
                        }

                        if (copied == len)
                                break;
                } while (i != msg_rx->sg.end);

                if (unlikely(peek)) {
                        msg_rx = list_next_entry(msg_rx, list);
                        continue;
                }

                msg_rx->sg.start = i;
                if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
                        list_del(&msg_rx->list);
                        if (msg_rx->skb)
                                consume_skb(msg_rx->skb);
                        kfree(msg_rx);
                }
                msg_rx = list_first_entry_or_null(&psock->ingress_msg,
                                                  struct sk_msg, list);
        }

        return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
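
/* __tcp_bpf_recvmsg() walks the sk_msg scatterlists queued on
 * psock->ingress_msg and copies them into the user iov. MSG_PEEK leaves the
 * queue untouched; otherwise consumed bytes are uncharged and fully drained
 * messages are unlinked and freed.
 */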

int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                    int nonblock, int flags, int *addr_len)
{
        struct sk_psock *psock;
        int copied, ret;

        if (unlikely(flags & MSG_ERRQUEUE))
                return inet_recv_error(sk, msg, len, addr_len);
        if (!skb_queue_empty(&sk->sk_receive_queue))
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
        lock_sock(sk);
msg_bytes_ready:
        copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
        if (!copied) {
                int data, err = 0;
                long timeo;

                timeo = sock_rcvtimeo(sk, nonblock);
                data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
                if (data) {
                        if (skb_queue_empty(&sk->sk_receive_queue))
                                goto msg_bytes_ready;
                        release_sock(sk);
                        sk_psock_put(sk, psock);
                        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
                }
                if (err) {
                        ret = err;
                        goto out;
                }
                copied = -EAGAIN;
        }
        ret = copied;
out:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return ret;
}
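
/* tcp_bpf_recvmsg() prefers the BPF ingress queue. It falls back to
 * inet_recv_error() for MSG_ERRQUEUE and to plain tcp_recvmsg() when the
 * regular receive queue already holds data or no psock is attached.
 */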

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
                           struct sk_msg *msg, u32 apply_bytes, int flags)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        u32 size, copied = 0;
        struct sk_msg *tmp;
        int i, ret = 0;

        tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
        if (unlikely(!tmp))
                return -ENOMEM;

        lock_sock(sk);
        tmp->sg.start = msg->sg.start;
        i = msg->sg.start;
        do {
                sge = sk_msg_elem(msg, i);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                if (!sk_wmem_schedule(sk, size)) {
                        if (!copied)
                                ret = -ENOMEM;
                        break;
                }

                sk_mem_charge(sk, size);
                sk_msg_xfer(tmp, msg, i, size);
                copied += size;
                if (sge->length)
                        get_page(sk_msg_page(tmp, i));
                sk_msg_iter_var_next(i);
                tmp->sg.end = i;
                if (apply) {
                        apply_bytes -= size;
                        if (!apply_bytes)
                                break;
                }
        } while (i != msg->sg.end);

        if (!ret) {
                msg->sg.start = i;
                msg->sg.size -= apply_bytes;
                sk_psock_queue_msg(psock, tmp);
                sk_psock_data_ready(sk, psock);
        } else {
                sk_msg_free(sk, tmp);
                kfree(tmp);
        }

        release_sock(sk);
        return ret;
}
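
/* bpf_tcp_ingress() implements the SK_REDIRECT + BPF_F_INGRESS case: the
 * sender's scatterlist entries are transferred into a fresh sk_msg, charged
 * to the target socket, queued on the target psock's ingress list, and the
 * receiver is woken through its data-ready callback.
 */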

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
                        int flags, bool uncharge)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        struct page *page;
        int size, ret = 0;
        u32 off;

        while (1) {
                bool has_tx_ulp;

                sge = sk_msg_elem(msg, msg->sg.start);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                off  = sge->offset;
                page = sg_page(sge);

                tcp_rate_check_app_limited(sk);
retry:
                has_tx_ulp = tls_sw_has_ctx_tx(sk);
                if (has_tx_ulp) {
                        flags |= MSG_SENDPAGE_NOPOLICY;
                        ret = kernel_sendpage_locked(sk,
                                                     page, off, size, flags);
                } else {
                        ret = do_tcp_sendpages(sk, page, off, size, flags);
                }

                if (ret <= 0)
                        return ret;
                if (apply)
                        apply_bytes -= ret;
                msg->sg.size -= ret;
                sge->offset += ret;
                sge->length -= ret;
                if (uncharge)
                        sk_mem_uncharge(sk, ret);
                if (ret != size) {
                        size -= ret;
                        off  += ret;
                        goto retry;
                }
                if (!sge->length) {
                        put_page(page);
                        sk_msg_iter_next(msg, start);
                        sg_init_table(sge, 1);
                        if (msg->sg.start == msg->sg.end)
                                break;
                }
                if (apply && !apply_bytes)
                        break;
        }

        return 0;
}
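
/* tcp_bpf_push() is the egress path for an sk_msg. With kTLS TX configured
 * (tls_sw_has_ctx_tx()), pages go through kernel_sendpage_locked() with
 * MSG_SENDPAGE_NOPOLICY so the TLS layer does not run the sk_msg policy a
 * second time; otherwise do_tcp_sendpages() is used directly.
 */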

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
                               u32 apply_bytes, int flags, bool uncharge)
{
        int ret;

        lock_sock(sk);
        ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
        release_sock(sk);
        return ret;
}

int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
                          u32 bytes, int flags)
{
        bool ingress = sk_msg_to_ingress(msg);
        struct sk_psock *psock = sk_psock_get(sk);
        int ret;

        if (unlikely(!psock)) {
                sk_msg_free(sk, msg);
                return 0;
        }
        ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
                        tcp_bpf_push_locked(sk, msg, bytes, flags, false);
        sk_psock_put(sk, psock);
        return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
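
/* tcp_bpf_sendmsg_redir() runs on the socket chosen by the verdict program.
 * The BPF_F_INGRESS flag recorded in the sk_msg decides whether the bytes
 * are queued on that socket's ingress list or pushed out its egress path.
 */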

static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
                                struct sk_msg *msg, int *copied, int flags)
{
        bool cork = false, enospc = sk_msg_full(msg);
        struct sock *sk_redir;
        u32 tosend, delta = 0;
        int ret;

more_data:
        if (psock->eval == __SK_NONE) {
                /* Track delta in msg size to add/subtract it on SK_DROP from
                 * returned to user copied size. This ensures user doesn't
                 * get a positive return code with msg_cut_data and SK_DROP
                 * verdict.
                 */
                delta = msg->sg.size;
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
                if (msg->sg.size < delta)
                        delta -= msg->sg.size;
                else
                        delta = 0;
        }

        if (msg->cork_bytes &&
            msg->cork_bytes > msg->sg.size && !enospc) {
                psock->cork_bytes = msg->cork_bytes - msg->sg.size;
                if (!psock->cork) {
                        psock->cork = kzalloc(sizeof(*psock->cork),
                                              GFP_ATOMIC | __GFP_NOWARN);
                        if (!psock->cork)
                                return -ENOMEM;
                }
                memcpy(psock->cork, msg, sizeof(*msg));
                return 0;
        }

        tosend = msg->sg.size;
        if (psock->apply_bytes && psock->apply_bytes < tosend)
                tosend = psock->apply_bytes;

        switch (psock->eval) {
        case __SK_PASS:
                ret = tcp_bpf_push(sk, msg, tosend, flags, true);
                if (unlikely(ret)) {
                        *copied -= sk_msg_free(sk, msg);
                        break;
                }
                sk_msg_apply_bytes(psock, tosend);
                break;
        case __SK_REDIRECT:
                sk_redir = psock->sk_redir;
                sk_msg_apply_bytes(psock, tosend);
                if (psock->cork) {
                        cork = true;
                        psock->cork = NULL;
                }
                sk_msg_return(sk, msg, tosend);
                release_sock(sk);
                ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
                lock_sock(sk);
                if (unlikely(ret < 0)) {
                        int free = sk_msg_free_nocharge(sk, msg);

                        if (!cork)
                                *copied -= free;
                }
                if (cork) {
                        sk_msg_free(sk, msg);
                        kfree(msg);
                        msg = NULL;
                        ret = 0;
                }
                break;
        case __SK_DROP:
        default:
                sk_msg_free_partial(sk, msg, tosend);
                sk_msg_apply_bytes(psock, tosend);
                *copied -= (tosend + delta);
                return -EACCES;
        }

        if (likely(!ret)) {
                if (!psock->apply_bytes) {
                        psock->eval = __SK_NONE;
                        if (psock->sk_redir) {
                                sock_put(psock->sk_redir);
                                psock->sk_redir = NULL;
                        }
                }
                if (msg &&
                    msg->sg.data[msg->sg.start].page_link &&
                    msg->sg.data[msg->sg.start].length)
                        goto more_data;
        }
        return ret;
}
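
/* The verdict consumed above comes from the attached BPF_PROG_TYPE_SK_MSG
 * program. A minimal sketch of such a program (not part of this file; the
 * "sock_map" map and the fixed byte count are illustrative assumptions):
 *
 *      SEC("sk_msg")
 *      int msg_verdict(struct sk_msg_md *msg)
 *      {
 *              __u32 key = 0;
 *
 *              bpf_msg_apply_bytes(msg, 4096);
 *              return bpf_msg_redirect_map(msg, &sock_map, key, BPF_F_INGRESS);
 *      }
 *
 * bpf_msg_apply_bytes() feeds psock->apply_bytes used for tosend above,
 * bpf_msg_cork_bytes() feeds msg->cork_bytes, and the redirect helper
 * records psock->sk_redir for the __SK_REDIRECT case.
 */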

static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        struct sk_msg tmp, *msg_tx = NULL;
        int copied = 0, err = 0;
        struct sk_psock *psock;
        long timeo;
        int flags;

        /* Don't let internal do_tcp_sendpages() flags through */
        flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
        flags |= MSG_NO_SHARED_FRAGS;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendmsg(sk, msg, size);

        lock_sock(sk);
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        while (msg_data_left(msg)) {
                bool enospc = false;
                u32 copy, osize;

                if (sk->sk_err) {
                        err = -sk->sk_err;
                        goto out_err;
                }

                copy = msg_data_left(msg);
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
                if (psock->cork) {
                        msg_tx = psock->cork;
                } else {
                        msg_tx = &tmp;
                        sk_msg_init(msg_tx);
                }

                osize = msg_tx->sg.size;
                err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
                if (err) {
                        if (err != -ENOSPC)
                                goto wait_for_memory;
                        enospc = true;
                        copy = msg_tx->sg.size - osize;
                }

                err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
                                               copy);
                if (err < 0) {
                        sk_msg_trim(sk, msg_tx, osize);
                        goto out_err;
                }

                copied += copy;
                if (psock->cork_bytes) {
                        if (size > psock->cork_bytes)
                                psock->cork_bytes = 0;
                        else
                                psock->cork_bytes -= size;
                        if (psock->cork_bytes && !enospc)
                                goto out_err;
                        /* All cork bytes are accounted, rerun the prog. */
                        psock->eval = __SK_NONE;
                        psock->cork_bytes = 0;
                }

                err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
                if (unlikely(err < 0))
                        goto out_err;
                continue;
wait_for_sndbuf:
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
                err = sk_stream_wait_memory(sk, &timeo);
                if (err) {
                        if (msg_tx && msg_tx != psock->cork)
                                sk_msg_free(sk, msg_tx);
                        goto out_err;
                }
        }
out_err:
        if (err < 0)
                err = sk_stream_error(sk, msg->msg_flags, err);
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}
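
/* tcp_bpf_sendmsg() copies user data into an sk_msg (reusing psock->cork
 * while corking is in progress) and then runs tcp_bpf_send_verdict() on it;
 * without a psock it falls straight through to tcp_sendmsg().
 */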

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
                            size_t size, int flags)
{
        struct sk_msg tmp, *msg = NULL;
        int err = 0, copied = 0;
        struct sk_psock *psock;
        bool enospc = false;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendpage(sk, page, offset, size, flags);

        lock_sock(sk);
        if (psock->cork) {
                msg = psock->cork;
        } else {
                msg = &tmp;
                sk_msg_init(msg);
        }

        /* Catch case where ring is full and sendpage is stalled. */
        if (unlikely(sk_msg_full(msg)))
                goto out_err;

        sk_msg_page_add(msg, page, size, offset);
        sk_mem_charge(sk, size);
        copied = size;
        if (sk_msg_full(msg))
                enospc = true;
        if (psock->cork_bytes) {
                if (size > psock->cork_bytes)
                        psock->cork_bytes = 0;
                else
                        psock->cork_bytes -= size;
                if (psock->cork_bytes && !enospc)
                        goto out_err;
                /* All cork bytes are accounted, rerun the prog. */
                psock->eval = __SK_NONE;
                psock->cork_bytes = 0;
        }

        err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}
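
/* tcp_bpf_sendpage() takes the caller's page by reference via
 * sk_msg_page_add() instead of copying, then reuses the same corking and
 * verdict logic as the sendmsg path.
 */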

static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_link *link;

        while ((link = sk_psock_link_pop(psock))) {
                sk_psock_unlink(sk, link);
                sk_psock_free_link(link);
        }
}

static void tcp_bpf_unhash(struct sock *sk)
{
        void (*saved_unhash)(struct sock *sk);
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                if (sk->sk_prot->unhash)
                        sk->sk_prot->unhash(sk);
                return;
        }

        saved_unhash = psock->saved_unhash;
        tcp_bpf_remove(sk, psock);
        rcu_read_unlock();
        saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
        void (*saved_close)(struct sock *sk, long timeout);
        struct sk_psock *psock;

        lock_sock(sk);
        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                release_sock(sk);
                return sk->sk_prot->close(sk, timeout);
        }

        saved_close = psock->saved_close;
        tcp_bpf_remove(sk, psock);
        rcu_read_unlock();
        release_sock(sk);
        saved_close(sk, timeout);
}

enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
        TCP_BPF_NUM_PROTS,
};

enum {
        TCP_BPF_BASE,
        TCP_BPF_TX,
        TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
                                   struct proto *base)
{
        prot[TCP_BPF_BASE]                      = *base;
        prot[TCP_BPF_BASE].unhash               = tcp_bpf_unhash;
        prot[TCP_BPF_BASE].close                = tcp_bpf_close;
        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;

        prot[TCP_BPF_TX]                        = prot[TCP_BPF_BASE];
        prot[TCP_BPF_TX].sendmsg                = tcp_bpf_sendmsg;
        prot[TCP_BPF_TX].sendpage               = tcp_bpf_sendpage;
}
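
/* tcp_bpf_prots is a small matrix: one row per address family
 * (TCP_BPF_IPV4/IPV6) and one column per configuration. TCP_BPF_BASE only
 * overrides the receive-side hooks; TCP_BPF_TX additionally overrides
 * sendmsg/sendpage and is selected when an sk_msg program is attached.
 */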

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
        if (sk->sk_family == AF_INET6 &&
            unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
                spin_lock_bh(&tcpv6_prot_lock);
                if (likely(ops != tcpv6_prot_saved)) {
                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
                        smp_store_release(&tcpv6_prot_saved, ops);
                }
                spin_unlock_bh(&tcpv6_prot_lock);
        }
}

static int __init tcp_bpf_v4_build_proto(void)
{
        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
        return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

        sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

        /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
         * or added requiring sk_prot hook updates. We keep original saved
         * hooks in this case.
         */
        sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
        /* In order to avoid retpoline, we make assumptions when we call
         * into ops if e.g. a psock is not present. Make sure they are
         * indeed valid assumptions.
         */
        return ops->recvmsg  == tcp_recvmsg &&
               ops->sendmsg  == tcp_sendmsg &&
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
        struct sk_psock *psock;

        sock_owned_by_me(sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        tcp_bpf_reinit_sk_prot(sk, psock);
        rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
        struct proto *ops = READ_ONCE(sk->sk_prot);
        struct sk_psock *psock;

        sock_owned_by_me(sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock || psock->sk_proto ||
                     tcp_bpf_assert_proto_ops(ops))) {
                rcu_read_unlock();
                return -EINVAL;
        }
        tcp_bpf_check_v6_needs_rebuild(sk, ops);
        tcp_bpf_update_sk_prot(sk, psock);
        rcu_read_unlock();
        return 0;
}
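
/* Usage sketch (not part of this file): tcp_bpf_init() runs when an
 * established TCP socket is added to a sockmap/sockhash that has an sk_msg
 * verdict program attached. From userspace, roughly (the fd variable names
 * are illustrative assumptions):
 *
 *      int map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP, sizeof(int),
 *                                  sizeof(int), 1, 0);
 *      bpf_prog_attach(msg_prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
 *      int key = 0;
 *      bpf_map_update_elem(map_fd, &key, &tcp_fd, BPF_ANY);
 *
 * After the update the socket's sk_prot points at
 * tcp_bpf_prots[family][config] and its sendmsg/recvmsg flow through the
 * hooks in this file.
 */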