// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

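/* Copy data queued on the psock ingress list into the user's iov_iter.
 * Returns the number of bytes copied; -EFAULT only if nothing at all
 * could be copied. With MSG_PEEK the queued data is left in place.
 */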
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	struct sk_msg *msg_rx;
	int i, copied = 0;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			copy = copy_page_to_iter(page, sge->offset, copy, iter);
			if (!copy)
				return copied ? copied : -EFAULT;

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				/* Don't bother optimizing the peek case: if
				 * copy_page_to_iter() didn't copy the whole
				 * element, just stop and return what we have.
				 */
				if (copy != sge->length)
					return copied;
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			if (msg_rx == list_last_entry(&psock->ingress_msg,
						      struct sk_msg, list))
				break;
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

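/* Ingress redirect: charge sk for the data, transfer up to apply_bytes
 * of scatterlist entries from msg into a freshly allocated sk_msg, and
 * queue it on psock's ingress list so the receive path can consume it.
 */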
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

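/* Transmit the sk_msg starting at msg->sg.start, limited to apply_bytes
 * when set. With a TLS TX ULP present, use kernel_sendpage_locked() and
 * MSG_SENDPAGE_NOPOLICY so the BPF policy is not applied a second time;
 * otherwise push pages out through do_tcp_sendpages().
 */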
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off  = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off  += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

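/* Same as tcp_bpf_push(), but takes the socket lock around the call. */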
static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

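/* Entry point for an SK_REDIRECT verdict: deliver msg to the target
 * socket's ingress queue or push it out on the wire, depending on the
 * BPF_F_INGRESS flag recorded in the msg.
 */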
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_STREAM_PARSER
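/* ->stream_memory_read() hook: report data as readable whenever the
 * psock ingress list is non-empty, so poll() wakes up for BPF-redirected
 * data that never lands on sk_receive_queue.
 */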
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

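/* Wait until data shows up on the psock ingress list or the socket
 * receive queue, the socket is shut down, or the timeout expires.
 * Returns non-zero when there is (or may be) something to read.
 */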
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

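/* recvmsg() replacement installed while a psock is attached. Data queued
 * by the verdict program is consumed first; when only sk_receive_queue
 * holds data (or no psock is attached) fall back to tcp_recvmsg().
 */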
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

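/* Run the SK_MSG verdict program on msg (unless a verdict is still
 * pending) and act on the result: transmit for SK_PASS, hand off to
 * tcp_bpf_sendmsg_redir() for SK_REDIRECT, free and adjust the
 * user-visible copied count for SK_DROP. Data below cork_bytes is
 * stashed in psock->cork until enough has been gathered.
 */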
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track the delta in msg size so it can be subtracted from
		 * the user-visible copied count on SK_DROP. This ensures the
		 * user doesn't see a positive return code when the program
		 * used msg_cut_data and then returned SK_DROP.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

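/* sendmsg() replacement: copy user data into an sk_msg and let
 * tcp_bpf_send_verdict() run the BPF program and act on its verdict.
 */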
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted for, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

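/* sendpage() replacement: add the page to the corked or temporary sk_msg
 * and run the same verdict path as tcp_bpf_sendmsg().
 */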
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch the case where the ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted for, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

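/* Build the proto templates: TCP_BPF_BASE overrides the receive side of
 * the native proto, TCP_BPF_TX additionally overrides sendmsg/sendpage
 * for sockets that have an SK_MSG program attached.
 */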
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE]			= *base;
	prot[TCP_BPF_BASE].unhash		= sock_map_unhash;
	prot[TCP_BPF_BASE].close		= sock_map_close;
	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;

	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
}

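/* The IPv6 proto may be provided by a module, so its tcp_bpf variants are
 * built lazily the first time they are needed and rebuilt if the saved
 * proto ops change.
 */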
static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline overhead, we make assumptions about
	 * which ops we call into when e.g. a psock is not present. Make
	 * sure those assumptions are indeed valid.
	 */
	return ops->recvmsg  == tcp_recvmsg &&
	       ops->sendmsg  == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

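/* Select the proto template matching the socket family and whether an
 * SK_MSG (msg_parser) program is attached. For IPv6, first verify that
 * the saved proto ops are the ones the fast paths assume.
 */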
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return ERR_PTR(-EINVAL);

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	return &tcp_bpf_prots[family][config];
}

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that the tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_STREAM_PARSER */