// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
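
/* Report pending psock ingress data so stream readability checks can see
 * the socket as readable even when nothing sits on sk_receive_queue.
 */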
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}
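
/* Sleep until either the psock ingress list or the regular receive queue
 * has data, or the timeout expires.
 */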
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}
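
/* Copy queued ingress messages into the user iterator. Memory accounting
 * is released as scatterlist elements drain; with MSG_PEEK the data is
 * copied but left queued.
 */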
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
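
/* recvmsg replacement installed on psock-enabled sockets: fall back to
 * tcp_recvmsg() when no psock is attached or data already sits on the
 * receive queue, otherwise drain the psock ingress list.
 */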
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}
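
/* Move up to apply_bytes of @msg onto the local psock ingress list,
 * charging the receiving socket for the memory.
 */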
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk->sk_data_ready(sk);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}
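
/* Push scatterlist pages out through do_tcp_sendpages(), trimming each
 * element as it is consumed and optionally uncharging send memory.
 */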
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off  = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		ret = do_tcp_sendpages(sk, page, off, size, flags);
		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off  += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}
static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
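
/* Run the attached msg verdict program (unless a verdict is already
 * cached) and act on it: pass to the local TCP stack, redirect to
 * another socket, or drop, honouring cork and apply_bytes state.
 */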
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend;
	int ret;

more_data:
	if (psock->eval == __SK_NONE)
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= tosend;
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}
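
/* sendmsg replacement: copy user data into an sk_msg and hand it to
 * tcp_bpf_send_verdict(), waiting for socket memory as needed.
 */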
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	sk_psock_cork_free(psock);
	__sk_psock_purge_ingress_msg(psock);
	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}
static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}
static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}
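
/* Proto override table, indexed by address family and by whether a msg
 * parser program is attached (TX callbacks needed) or not.
 */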
enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE]			= *base;
	prot[TCP_BPF_BASE].unhash		= tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close		= tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;

	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
}
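
/* tcpv6_prot may not be visible at init time (e.g. IPv6 built as a
 * module), so the IPv6 table is rebuilt lazily from the socket's
 * current ops the first time it is seen.
 */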
static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}
static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);
static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}
static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}
static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg  == tcp_recvmsg &&
	       ops->sendmsg  == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}
int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}