// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>
#include <net/tls.h>
#include <trace/events/sock.h>
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
        if (msg->sg.end > msg->sg.start &&
            elem_first_coalesce < msg->sg.end)
                return true;

        if (msg->sg.end < msg->sg.start &&
            (elem_first_coalesce > msg->sg.start ||
             elem_first_coalesce < msg->sg.end))
                return true;

        return false;
}
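
/* Grow @msg with page-frag memory from @sk until its total size reaches
 * @len bytes, coalescing with the previous element when the new chunk is
 * contiguous with it. On allocation failure the msg is trimmed back to
 * its original size.
 */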
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
                 int elem_first_coalesce)
{
        struct page_frag *pfrag = sk_page_frag(sk);
        u32 osize = msg->sg.size;
        int ret = 0;

        len -= msg->sg.size;
        while (len > 0) {
                struct scatterlist *sge;
                u32 orig_offset;
                int use, i;

                if (!sk_page_frag_refill(sk, pfrag)) {
                        ret = -ENOMEM;
                        goto msg_trim;
                }

                orig_offset = pfrag->offset;
                use = min_t(int, len, pfrag->size - orig_offset);
                if (!sk_wmem_schedule(sk, use)) {
                        ret = -ENOMEM;
                        goto msg_trim;
                }

                i = msg->sg.end;
                sk_msg_iter_var_prev(i);
                sge = &msg->sg.data[i];

                if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
                    sg_page(sge) == pfrag->page &&
                    sge->offset + sge->length == orig_offset) {
                        sge->length += use;
                } else {
                        if (sk_msg_full(msg)) {
                                ret = -ENOSPC;
                                break;
                        }

                        sge = &msg->sg.data[msg->sg.end];
                        sg_unmark_end(sge);
                        sg_set_page(sge, pfrag->page, use, orig_offset);
                        get_page(pfrag->page);
                        sk_msg_iter_next(msg, end);
                }

                sk_mem_charge(sk, use);
                msg->sg.size += use;
                pfrag->offset += use;
                len -= use;
        }

        return ret;

msg_trim:
        sk_msg_trim(sk, msg, osize);
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);
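
/* Copy a byte range of @src into @dst by reference: the pages backing
 * @src are shared with @dst and the corresponding memory is charged to
 * @sk; no data is copied.
 */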
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
                 u32 off, u32 len)
{
        int i = src->sg.start;
        struct scatterlist *sge = sk_msg_elem(src, i);
        struct scatterlist *sgd = NULL;
        u32 sge_len, sge_off;

        while (off) {
                if (sge->length > off)
                        break;
                off -= sge->length;
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && off)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        while (len) {
                sge_len = sge->length - off;
                if (sge_len > len)
                        sge_len = len;

                if (dst->sg.end)
                        sgd = sk_msg_elem(dst, dst->sg.end - 1);

                if (sgd &&
                    (sg_page(sge) == sg_page(sgd)) &&
                    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
                        sgd->length += sge_len;
                        dst->sg.size += sge_len;
                } else if (!sk_msg_full(dst)) {
                        sge_off = sge->offset + off;
                        sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
                } else {
                        return -ENOSPC;
                }

                off = 0;
                len -= sge_len;
                sk_mem_charge(sk, sge_len);
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && len)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);
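
/* Uncharge @bytes of @msg memory from @sk. sk_msg_return_zero() also
 * zeroes the scatterlist elements it fully consumes and advances
 * msg->sg.start past them.
 */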
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = sk_msg_elem(msg, i);

                if (bytes < sge->length) {
                        sge->length -= bytes;
                        sge->offset += bytes;
                        sk_mem_uncharge(sk, bytes);
                        break;
                }

                sk_mem_uncharge(sk, sge->length);
                bytes -= sge->length;
                sge->length = 0;
                sge->offset = 0;
                sk_msg_iter_var_next(i);
        } while (bytes && i != msg->sg.end);
        msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = &msg->sg.data[i];
                int uncharge = (bytes < sge->length) ? bytes : sge->length;

                sk_mem_uncharge(sk, uncharge);
                bytes -= uncharge;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);
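
/* Free a single scatterlist element of @msg: uncharge and release its
 * page unless the backing skb owns the memory, then clear the element.
 */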
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
                            bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        u32 len = sge->length;

        /* When the skb owns the memory we free it from consume_skb path. */
        if (!msg->skb) {
                if (charge)
                        sk_mem_uncharge(sk, len);
                put_page(sg_page(sge));
        }
        memset(sge, 0, sizeof(*sge));
        return len;
}
static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
                         bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        int freed = 0;

        while (msg->sg.size) {
                msg->sg.size -= sge->length;
                freed += sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, msg->sg.size);
                sge = sk_msg_elem(msg, i);
        }
        consume_skb(msg->skb);
        sk_msg_init(msg);
        return freed;
}
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);
static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes, bool charge)
{
        struct scatterlist *sge;
        u32 i = msg->sg.start;

        while (bytes) {
                sge = sk_msg_elem(msg, i);
                if (!sge->length)
                        break;
                if (bytes < sge->length) {
                        if (charge)
                                sk_mem_uncharge(sk, bytes);
                        sge->length -= bytes;
                        sge->offset += bytes;
                        msg->sg.size -= bytes;
                        break;
                }

                msg->sg.size -= sge->length;
                bytes -= sge->length;
                sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, bytes);
        }
        msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, false);
}
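
/* Shrink @msg to @len bytes total, freeing whole trailing elements and
 * truncating the last remaining one, keeping curr/copybreak consistent.
 */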
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
        int trim = msg->sg.size - len;
        u32 i = msg->sg.end;

        if (trim <= 0) {
                WARN_ON(trim < 0);
                return;
        }

        sk_msg_iter_var_prev(i);
        msg->sg.size = len;
        while (msg->sg.data[i].length &&
               trim >= msg->sg.data[i].length) {
                trim -= msg->sg.data[i].length;
                sk_msg_free_elem(sk, msg, i, true);
                sk_msg_iter_var_prev(i);
                if (!trim)
                        goto out;
        }

        msg->sg.data[i].length -= trim;
        sk_mem_uncharge(sk, trim);
        /* Adjust copybreak if it falls into the trimmed part of last buf */
        if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
                msg->sg.copybreak = msg->sg.data[i].length;
out:
        sk_msg_iter_var_next(i);
        msg->sg.end = i;

        /* If we trim data a full sg elem before the curr pointer, update
         * copybreak and curr so that any future copy operations start at
         * the new copy location.
         * However, trimmed data that has not yet been used in a copy op
         * does not require an update.
         */
        if (!msg->sg.size) {
                msg->sg.curr = msg->sg.start;
                msg->sg.copybreak = 0;
        } else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >=
                   sk_msg_iter_dist(msg->sg.start, msg->sg.end)) {
                sk_msg_iter_var_prev(i);
                msg->sg.curr = i;
                msg->sg.copybreak = msg->sg.data[i].length;
        }
}
EXPORT_SYMBOL_GPL(sk_msg_trim);
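
/* Pin user pages from @from and map up to @bytes of them into @msg
 * without copying. On error the iterator is reverted; the caller is
 * expected to trim @msg if needed.
 */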
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              struct sk_msg *msg, u32 bytes)
{
        int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
        const int to_max_pages = MAX_MSG_FRAGS;
        struct page *pages[MAX_MSG_FRAGS];
        ssize_t orig, copied, use, offset;

        orig = msg->sg.size;
        while (bytes > 0) {
                i = 0;
                maxpages = to_max_pages - num_elems;
                if (maxpages == 0) {
                        ret = -EFAULT;
                        goto out;
                }

                copied = iov_iter_get_pages2(from, pages, bytes, maxpages,
                                             &offset);
                if (copied <= 0) {
                        ret = -EFAULT;
                        goto out;
                }

                bytes -= copied;
                msg->sg.size += copied;

                while (copied) {
                        use = min_t(int, copied, PAGE_SIZE - offset);
                        sg_set_page(&msg->sg.data[msg->sg.end],
                                    pages[i], use, offset);
                        sg_unmark_end(&msg->sg.data[msg->sg.end]);
                        sk_mem_charge(sk, use);

                        offset = 0;
                        copied -= use;
                        sk_msg_iter_next(msg, end);
                        num_elems++;
                        i++;
                }
                /* When zerocopy is mixed with sk_msg_*copy* operations we
                 * may have a copybreak set; in this case clear it and prefer
                 * the zerocopy remainder when possible.
                 */
                msg->sg.copybreak = 0;
                msg->sg.curr = msg->sg.end;
        }
out:
        /* Revert iov_iter updates; msg will need to use 'trim' later if it
         * also needs to be cleared.
         */
        if (ret)
                iov_iter_revert(from, msg->sg.size - orig);
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
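
/* Copy up to @bytes from @from into the sg elements of @msg, starting at
 * msg->sg.curr/copybreak, using the nocache copy variant when the route
 * supports it.
 */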
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             struct sk_msg *msg, u32 bytes)
{
        int ret = -ENOSPC, i = msg->sg.curr;
        u32 copy, buf_size;
        void *to;

        do {
                sge = sk_msg_elem(msg, i);
                /* This is possible if a trim operation shrunk the buffer */
                if (msg->sg.copybreak >= sge->length) {
                        msg->sg.copybreak = 0;
                        sk_msg_iter_var_next(i);
                        if (i == msg->sg.end)
                                break;
                        sge = sk_msg_elem(msg, i);
                }

                buf_size = sge->length - msg->sg.copybreak;
                copy = (buf_size > bytes) ? bytes : buf_size;
                to = sg_virt(sge) + msg->sg.copybreak;
                msg->sg.copybreak += copy;
                if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
                        ret = copy_from_iter_nocache(to, copy, from);
                else
                        ret = copy_from_iter(to, copy, from);
                if (ret != copy) {
                        ret = -EFAULT;
                        goto out;
                }
                bytes -= copy;
                if (!bytes)
                        break;
                msg->sg.copybreak = 0;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
out:
        msg->sg.curr = i;
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
/* Receive sk_msg from psock->ingress_msg to @msg. */
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
                   int len, int flags)
{
        struct iov_iter *iter = &msg->msg_iter;
        int peek = flags & MSG_PEEK;
        struct sk_msg *msg_rx;
        int i, copied = 0;

        msg_rx = sk_psock_peek_msg(psock);
        while (copied != len) {
                struct scatterlist *sge;

                if (unlikely(!msg_rx))
                        break;

                i = msg_rx->sg.start;
                do {
                        struct page *page;
                        int copy;

                        sge = sk_msg_elem(msg_rx, i);
                        copy = sge->length;
                        page = sg_page(sge);
                        if (copied + copy > len)
                                copy = len - copied;

                        copy = copy_page_to_iter(page, sge->offset, copy, iter);
                        if (!copy) {
                                copied = copied ? copied : -EFAULT;
                                goto out;
                        }

                        copied += copy;
                        if (likely(!peek)) {
                                sge->offset += copy;
                                sge->length -= copy;
                                if (!msg_rx->skb)
                                        sk_mem_uncharge(sk, copy);
                                msg_rx->sg.size -= copy;

                                if (!sge->length) {
                                        sk_msg_iter_var_next(i);
                                        if (!msg_rx->skb)
                                                put_page(page);
                                }
                        } else {
                                /* Let's not optimize the peek case: if
                                 * copy_page_to_iter didn't copy the entire
                                 * length, just break out.
                                 */
                                if (copy != sge->length)
                                        goto out;
                                sk_msg_iter_var_next(i);
                        }

                        if (copied == len)
                                break;
                } while ((i != msg_rx->sg.end) && !sg_is_last(sge));

                if (unlikely(peek)) {
                        msg_rx = sk_psock_next_msg(psock, msg_rx);
                        if (!msg_rx)
                                break;
                        continue;
                }

                msg_rx->sg.start = i;
                if (!sge->length && (i == msg_rx->sg.end || sg_is_last(sge))) {
                        msg_rx = sk_psock_dequeue_msg(psock);
                        kfree_sk_msg(msg_rx);
                }
                msg_rx = sk_psock_peek_msg(psock);
        }
out:
        return copied;
}
EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
bool sk_msg_is_readable(struct sock *sk)
{
        struct sk_psock *psock;
        bool empty = true;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock))
                empty = list_empty(&psock->ingress_msg);
        rcu_read_unlock();
        return !empty;
}
EXPORT_SYMBOL_GPL(sk_msg_is_readable);
static struct sk_msg *alloc_sk_msg(gfp_t gfp)
{
        struct sk_msg *msg;

        msg = kzalloc(sizeof(*msg), gfp | __GFP_NOWARN);
        if (unlikely(!msg))
                return NULL;
        sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
        return msg;
}
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
                                                  struct sk_buff *skb)
{
        if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf)
                return NULL;

        if (!sk_rmem_schedule(sk, skb, skb->truesize))
                return NULL;

        return alloc_sk_msg(GFP_KERNEL);
}
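
/* Map @len bytes of @skb starting at @off into @msg and queue it on the
 * psock ingress list, waking the socket. Returns the number of bytes
 * queued or a negative error (-EAGAIN if the skb should be retried).
 */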
static int sk_psock_skb_ingress_enqueue(struct sk_buff *skb,
                                        u32 off, u32 len,
                                        struct sk_psock *psock,
                                        struct sock *sk,
                                        struct sk_msg *msg)
{
        int num_sge, copied;

        num_sge = skb_to_sgvec(skb, msg->sg.data, off, len);
        if (num_sge < 0) {
                /* skb linearize may fail with ENOMEM, but let's simply try
                 * again later if this happens. Under memory pressure we
                 * don't want to drop the skb. We need to linearize the skb
                 * so that the mapping in skb_to_sgvec cannot error.
                 */
                if (skb_linearize(skb))
                        return -EAGAIN;

                num_sge = skb_to_sgvec(skb, msg->sg.data, off, len);
                if (unlikely(num_sge < 0))
                        return num_sge;
        }

        copied = len;
        msg->sg.start = 0;
        msg->sg.size = copied;
        msg->sg.end = num_sge;
        msg->skb = skb;

        sk_psock_queue_msg(psock, msg);
        sk_psock_data_ready(sk, psock);
        return copied;
}
static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb,
                                     u32 off, u32 len);

static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb,
                                u32 off, u32 len)
{
        struct sock *sk = psock->sk;
        struct sk_msg *msg;
        int err;

        /* If we are receiving on the same sock skb->sk is already assigned,
         * skip memory accounting and owner transition seeing it already set
         * correctly.
         */
        if (unlikely(skb->sk == sk))
                return sk_psock_skb_ingress_self(psock, skb, off, len);
        msg = sk_psock_create_ingress_msg(sk, skb);
        if (!msg)
                return -EAGAIN;

        /* This will transition ownership of the data from the socket where
         * the BPF program was run initiating the redirect to the socket
         * we will eventually receive this data on. The data will be released
         * from skb_consume found in __tcp_bpf_recvmsg() after it has been
         * copied into user buffers.
         */
        skb_set_owner_r(skb, sk);
        err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg);
        if (err < 0)
                kfree(msg);
        return err;
}
/* Puts an skb on the ingress queue of the socket already assigned to the
 * skb. In this case we do not need to check memory limits or skb_set_owner_r
 * because the skb is already accounted for here.
 */
static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb,
                                     u32 off, u32 len)
{
        struct sk_msg *msg = alloc_sk_msg(GFP_ATOMIC);
        struct sock *sk = psock->sk;
        int err;

        if (unlikely(!msg))
                return -EAGAIN;
        skb_set_owner_r(skb, sk);
        err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg);
        if (err < 0)
                kfree(msg);
        return err;
}
static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
                               u32 off, u32 len, bool ingress)
{
        int err = 0;

        if (!ingress) {
                if (!sock_writeable(psock->sk))
                        return -EAGAIN;
                return skb_send_sock(psock->sk, skb, off, len);
        }
        skb_get(skb);
        err = sk_psock_skb_ingress(psock, skb, off, len);
        if (err < 0)
                kfree_skb(skb);
        return err;
}
static void sk_psock_skb_state(struct sk_psock *psock,
                               struct sk_psock_work_state *state,
                               int len, int off)
{
        spin_lock_bh(&psock->ingress_lock);
        if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
                state->len = len;
                state->off = off;
        }
        spin_unlock_bh(&psock->ingress_lock);
}
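
/* Backlog worker: drain psock->ingress_skb, either sending each skb to
 * its egress socket or queueing it as ingress data, resuming from any
 * partial progress saved in work_state.
 */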
static void sk_psock_backlog(struct work_struct *work)
{
        struct delayed_work *dwork = to_delayed_work(work);
        struct sk_psock *psock = container_of(dwork, struct sk_psock, work);
        struct sk_psock_work_state *state = &psock->work_state;
        struct sk_buff *skb = NULL;
        u32 len = 0, off = 0;
        bool ingress;
        int ret;

        mutex_lock(&psock->work_mutex);
        if (unlikely(state->len)) {
                len = state->len;
                off = state->off;
        }

        while ((skb = skb_peek(&psock->ingress_skb))) {
                len = skb->len;
                off = 0;
                if (skb_bpf_strparser(skb)) {
                        struct strp_msg *stm = strp_msg(skb);

                        off = stm->offset;
                        len = stm->full_len;
                }
                ingress = skb_bpf_ingress(skb);
                skb_bpf_redirect_clear(skb);
                do {
                        ret = -EIO;
                        if (!sock_flag(psock->sk, SOCK_DEAD))
                                ret = sk_psock_handle_skb(psock, skb, off,
                                                          len, ingress);
                        if (ret <= 0) {
                                if (ret == -EAGAIN) {
                                        sk_psock_skb_state(psock, state, len, off);

                                        /* Delay slightly to prioritize any
                                         * other work that might be here.
                                         */
                                        if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
                                                schedule_delayed_work(&psock->work, 1);
                                        goto end;
                                }
                                /* Hard errors break pipe and stop xmit. */
                                sk_psock_report_error(psock, ret ? -ret : EPIPE);
                                sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
                                goto end;
                        }
                        off += ret;
                        len -= ret;
                } while (len);

                skb = skb_dequeue(&psock->ingress_skb);
                kfree_skb(skb);
        }
end:
        mutex_unlock(&psock->work_mutex);
}
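
/* Allocate and attach a psock to @sk, saving the socket's original proto
 * callbacks so they can be restored when the psock is dropped.
 */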
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
        struct sk_psock *psock;
        struct proto *prot;

        write_lock_bh(&sk->sk_callback_lock);

        if (sk_is_inet(sk) && inet_csk_has_ulp(sk)) {
                psock = ERR_PTR(-EINVAL);
                goto out;
        }

        if (sk->sk_user_data) {
                psock = ERR_PTR(-EBUSY);
                goto out;
        }

        psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node);
        if (!psock) {
                psock = ERR_PTR(-ENOMEM);
                goto out;
        }

        prot = READ_ONCE(sk->sk_prot);
        psock->sk = sk;
        psock->eval = __SK_NONE;
        psock->sk_proto = prot;
        psock->saved_unhash = prot->unhash;
        psock->saved_destroy = prot->destroy;
        psock->saved_close = prot->close;
        psock->saved_write_space = sk->sk_write_space;

        INIT_LIST_HEAD(&psock->link);
        spin_lock_init(&psock->link_lock);

        INIT_DELAYED_WORK(&psock->work, sk_psock_backlog);
        mutex_init(&psock->work_mutex);
        INIT_LIST_HEAD(&psock->ingress_msg);
        spin_lock_init(&psock->ingress_lock);
        skb_queue_head_init(&psock->ingress_skb);

        sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
        refcount_set(&psock->refcnt, 1);

        __rcu_assign_sk_user_data_with_flags(sk, psock,
                                             SK_USER_DATA_NOCOPY |
                                             SK_USER_DATA_PSOCK);
        sock_hold(sk);

out:
        write_unlock_bh(&sk->sk_callback_lock);
        return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
        struct sk_psock_link *link;

        spin_lock_bh(&psock->link_lock);
        link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
                                        list);
        if (link)
                list_del(&link->list);
        spin_unlock_bh(&psock->link_lock);
        return link;
}
static void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
        struct sk_msg *msg, *tmp;

        list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
                list_del(&msg->list);
                sk_msg_free(psock->sk, msg);
                kfree(msg);
        }
}
static void __sk_psock_zap_ingress(struct sk_psock *psock)
{
        struct sk_buff *skb;

        while ((skb = skb_dequeue(&psock->ingress_skb)) != NULL) {
                skb_bpf_redirect_clear(skb);
                sock_drop(psock->sk, skb);
        }
        __sk_psock_purge_ingress_msg(psock);
}
static void sk_psock_link_destroy(struct sk_psock *psock)
{
        struct sk_psock_link *link, *tmp;

        list_for_each_entry_safe(link, tmp, &psock->link, list) {
                list_del(&link->list);
                sk_psock_free_link(link);
        }
}
void sk_psock_stop(struct sk_psock *psock)
{
        spin_lock_bh(&psock->ingress_lock);
        sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
        sk_psock_cork_free(psock);
        spin_unlock_bh(&psock->ingress_lock);
}
static void sk_psock_done_strp(struct sk_psock *psock);

static void sk_psock_destroy(struct work_struct *work)
{
        struct sk_psock *psock = container_of(to_rcu_work(work),
                                              struct sk_psock, rwork);
        /* No sk_callback_lock since already detached. */

        sk_psock_done_strp(psock);

        cancel_delayed_work_sync(&psock->work);
        __sk_psock_zap_ingress(psock);
        mutex_destroy(&psock->work_mutex);

        psock_progs_drop(&psock->progs);

        sk_psock_link_destroy(psock);
        sk_psock_cork_free(psock);

        if (psock->sk_redir)
                sock_put(psock->sk_redir);
        if (psock->sk_pair)
                sock_put(psock->sk_pair);
        sock_put(psock->sk);
        kfree(psock);
}
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
        write_lock_bh(&sk->sk_callback_lock);
        sk_psock_restore_proto(sk, psock);
        rcu_assign_sk_user_data(sk, NULL);
        if (psock->progs.stream_parser)
                sk_psock_stop_strp(sk, psock);
        else if (psock->progs.stream_verdict || psock->progs.skb_verdict)
                sk_psock_stop_verdict(sk, psock);
        write_unlock_bh(&sk->sk_callback_lock);

        sk_psock_stop(psock);

        INIT_RCU_WORK(&psock->rwork, sk_psock_destroy);
        queue_rcu_work(system_wq, &psock->rwork);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);
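
/* Map a BPF program verdict (SK_PASS/SK_DROP) plus redirect state onto
 * the internal __SK_* action codes.
 */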
static int sk_psock_map_verd(int verdict, bool redir)
{
        switch (verdict) {
        case SK_PASS:
                return redir ? __SK_REDIRECT : __SK_PASS;
        case SK_DROP:
        default:
                break;
        }

        return __SK_DROP;
}
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
                         struct sk_msg *msg)
{
        struct bpf_prog *prog;
        int ret;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.msg_parser);
        if (unlikely(!prog)) {
                ret = __SK_PASS;
                goto out;
        }

        sk_msg_compute_data_pointers(msg);
        msg->sk = sk;
        ret = bpf_prog_run_pin_on_cpu(prog, msg);
        ret = sk_psock_map_verd(ret, msg->sk_redir);
        psock->apply_bytes = msg->apply_bytes;
        if (ret == __SK_REDIRECT) {
                if (psock->sk_redir) {
                        sock_put(psock->sk_redir);
                        psock->sk_redir = NULL;
                }
                if (!msg->sk_redir) {
                        ret = __SK_DROP;
                        goto out;
                }
                psock->redir_ingress = sk_msg_to_ingress(msg);
                psock->sk_redir = msg->sk_redir;
                sock_hold(psock->sk_redir);
        }
out:
        rcu_read_unlock();
        return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
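
/* Queue @skb on the ingress queue of the psock selected by the BPF
 * redirect, or drop it if that socket is gone or not accepting traffic.
 */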
static int sk_psock_skb_redirect(struct sk_psock *from, struct sk_buff *skb)
{
        struct sk_psock *psock_other;
        struct sock *sk_other;

        sk_other = skb_bpf_redirect_fetch(skb);
        /* This error indicates a buggy BPF program: it returned a redirect
         * return code but then didn't set a redirect interface.
         */
        if (unlikely(!sk_other)) {
                skb_bpf_redirect_clear(skb);
                sock_drop(from->sk, skb);
                return -EIO;
        }
        psock_other = sk_psock(sk_other);
        /* This error indicates the socket is being torn down or had another
         * error that caused the pipe to break. We can't send a packet on
         * a socket that is in this state so we drop the skb.
         */
        if (!psock_other || sock_flag(sk_other, SOCK_DEAD)) {
                skb_bpf_redirect_clear(skb);
                sock_drop(from->sk, skb);
                return -EIO;
        }
        spin_lock_bh(&psock_other->ingress_lock);
        if (!sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
                spin_unlock_bh(&psock_other->ingress_lock);
                skb_bpf_redirect_clear(skb);
                sock_drop(from->sk, skb);
                return -EIO;
        }

        skb_queue_tail(&psock_other->ingress_skb, skb);
        schedule_delayed_work(&psock_other->work, 0);
        spin_unlock_bh(&psock_other->ingress_lock);
        return 0;
}
static void sk_psock_tls_verdict_apply(struct sk_buff *skb,
                                       struct sk_psock *from, int verdict)
{
        switch (verdict) {
        case __SK_REDIRECT:
                sk_psock_skb_redirect(from, skb);
                break;
        case __SK_PASS:
        case __SK_DROP:
        default:
                break;
        }
}

int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb)
{
        struct bpf_prog *prog;
        int ret = __SK_PASS;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.stream_verdict);
        if (likely(prog)) {
                skb->sk = psock->sk;
                skb_dst_drop(skb);
                skb_bpf_redirect_clear(skb);
                ret = bpf_prog_run_pin_on_cpu(prog, skb);
                ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb));
                skb->sk = NULL;
        }
        sk_psock_tls_verdict_apply(skb, psock, ret);
        rcu_read_unlock();
        return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read);
static int sk_psock_verdict_apply(struct sk_psock *psock, struct sk_buff *skb,
                                  int verdict)
{
        struct sock *sk_other;
        int err = 0;
        u32 len, off;

        switch (verdict) {
        case __SK_PASS:
                err = -EIO;
                sk_other = psock->sk;
                if (sock_flag(sk_other, SOCK_DEAD) ||
                    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
                        goto out_free;

                skb_bpf_set_ingress(skb);

                /* If the queue is empty then we can submit directly
                 * into the msg queue. If it's not empty we have to
                 * queue work, otherwise we may get OOO data. Errors
                 * from sk_psock_skb_ingress will be handled by
                 * retrying later from the workqueue.
                 */
                if (skb_queue_empty(&psock->ingress_skb)) {
                        len = skb->len;
                        off = 0;
                        if (skb_bpf_strparser(skb)) {
                                struct strp_msg *stm = strp_msg(skb);

                                off = stm->offset;
                                len = stm->full_len;
                        }
                        err = sk_psock_skb_ingress_self(psock, skb, off, len);
                }
                if (err < 0) {
                        spin_lock_bh(&psock->ingress_lock);
                        if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
                                skb_queue_tail(&psock->ingress_skb, skb);
                                schedule_delayed_work(&psock->work, 0);
                                err = 0;
                        }
                        spin_unlock_bh(&psock->ingress_lock);
                        if (err < 0)
                                goto out_free;
                }
                break;
        case __SK_REDIRECT:
                tcp_eat_skb(psock->sk, skb);
                err = sk_psock_skb_redirect(psock, skb);
                break;
        case __SK_DROP:
        default:
out_free:
                skb_bpf_redirect_clear(skb);
                tcp_eat_skb(psock->sk, skb);
                sock_drop(psock->sk, skb);
        }

        return err;
}
static void sk_psock_write_space(struct sock *sk)
{
        struct sk_psock *psock;
        void (*write_space)(struct sock *sk) = NULL;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock)) {
                if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
                        schedule_delayed_work(&psock->work, 0);
                write_space = psock->saved_write_space;
        }
        rcu_read_unlock();
        if (write_space)
                write_space(sk);
}
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock;
        struct bpf_prog *prog;
        int ret = __SK_DROP;
        struct sock *sk;

        rcu_read_lock();
        sk = strp->sk;
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                sock_drop(sk, skb);
                goto out;
        }
        prog = READ_ONCE(psock->progs.stream_verdict);
        if (likely(prog)) {
                skb->sk = sk;
                skb_dst_drop(skb);
                skb_bpf_redirect_clear(skb);
                ret = bpf_prog_run_pin_on_cpu(prog, skb);
                skb_bpf_set_strparser(skb);
                ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb));
                skb->sk = NULL;
        }
        sk_psock_verdict_apply(psock, skb, ret);
out:
        rcu_read_unlock();
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
        return err;
}
static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock = container_of(strp, struct sk_psock, strp);
        struct bpf_prog *prog;
        int ret = skb->len;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.stream_parser);
        if (likely(prog)) {
                skb->sk = psock->sk;
                ret = bpf_prog_run_pin_on_cpu(prog, skb);
                skb->sk = NULL;
        }
        rcu_read_unlock();
        return ret;
}
/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
        struct sk_psock *psock;

        trace_sk_data_ready(sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock)) {
                if (tls_sw_has_ctx_rx(sk)) {
                        psock->saved_data_ready(sk);
                } else {
                        write_lock_bh(&sk->sk_callback_lock);
                        strp_data_ready(&psock->strp);
                        write_unlock_bh(&sk->sk_callback_lock);
                }
        }
        rcu_read_unlock();
}
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
        int ret;

        static const struct strp_callbacks cb = {
                .rcv_msg        = sk_psock_strp_read,
                .read_sock_done = sk_psock_strp_read_done,
                .parse_msg      = sk_psock_strp_parse,
        };

        ret = strp_init(&psock->strp, sk, &cb);
        if (!ret)
                sk_psock_set_state(psock, SK_PSOCK_RX_STRP_ENABLED);
        return ret;
}
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
        if (psock->saved_data_ready)
                return;

        psock->saved_data_ready = sk->sk_data_ready;
        sk->sk_data_ready = sk_psock_strp_data_ready;
        sk->sk_write_space = sk_psock_write_space;
}
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
        psock_set_prog(&psock->progs.stream_parser, NULL);

        if (!psock->saved_data_ready)
                return;

        sk->sk_data_ready = psock->saved_data_ready;
        psock->saved_data_ready = NULL;
        strp_stop(&psock->strp);
}
static void sk_psock_done_strp(struct sk_psock *psock)
{
        /* Parser has been stopped */
        if (sk_psock_test_state(psock, SK_PSOCK_RX_STRP_ENABLED))
                strp_done(&psock->strp);
}
#else
static void sk_psock_done_strp(struct sk_psock *psock)
{
}
#endif /* CONFIG_BPF_STREAM_PARSER */
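
/* Verdict-only receive path (no stream parser): run the verdict program
 * directly from ->read_skb() and apply the result.
 */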
static int sk_psock_verdict_recv(struct sock *sk, struct sk_buff *skb)
{
        struct sk_psock *psock;
        struct bpf_prog *prog;
        int ret = __SK_DROP;
        int len = skb->len;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                len = 0;
                tcp_eat_skb(sk, skb);
                sock_drop(sk, skb);
                goto out;
        }
        prog = READ_ONCE(psock->progs.stream_verdict);
        if (!prog)
                prog = READ_ONCE(psock->progs.skb_verdict);
        if (likely(prog)) {
                skb_dst_drop(skb);
                skb_bpf_redirect_clear(skb);
                ret = bpf_prog_run_pin_on_cpu(prog, skb);
                ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb));
        }
        ret = sk_psock_verdict_apply(psock, skb, ret);
        if (ret < 0)
                len = ret;
out:
        rcu_read_unlock();
        return len;
}
static void sk_psock_verdict_data_ready(struct sock *sk)
{
        struct socket *sock = sk->sk_socket;
        const struct proto_ops *ops;
        int copied;

        trace_sk_data_ready(sk);

        if (unlikely(!sock))
                return;
        ops = READ_ONCE(sock->ops);
        if (!ops || !ops->read_skb)
                return;
        copied = ops->read_skb(sk, sk_psock_verdict_recv);
        if (copied >= 0) {
                struct sk_psock *psock;

                rcu_read_lock();
                psock = sk_psock(sk);
                if (psock)
                        sk_psock_data_ready(sk, psock);
                rcu_read_unlock();
        }
}
void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock)
{
        if (psock->saved_data_ready)
                return;

        psock->saved_data_ready = sk->sk_data_ready;
        sk->sk_data_ready = sk_psock_verdict_data_ready;
        sk->sk_write_space = sk_psock_write_space;
}

void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock)
{
        psock_set_prog(&psock->progs.stream_verdict, NULL);
        psock_set_prog(&psock->progs.skb_verdict, NULL);

        if (!psock->saved_data_ready)
                return;

        sk->sk_data_ready = psock->saved_data_ready;
        psock->saved_data_ready = NULL;
}