// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>
#include <net/tls.h>
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}
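
/* Grow @msg by up to @len bytes of page_frag memory from the socket's page
 * frag allocator, coalescing with the previous element when the new range
 * is contiguous with it and coalescing is allowed, and charging the
 * socket's send-buffer accounting for every byte added.
 */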
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);
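
/* Copy an [off, off + len) window of @src into @dst without copying data:
 * page references are shared, ranges adjacent to the last element of @dst
 * are merged into it, and the cloned bytes are charged to @sk again.
 */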
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;
	u32 sge_len, sge_off;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		if (sge_len > len)
			sge_len = len;

		if (dst->sg.end)
			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		if (sgd &&
		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		} else {
			return -ENOSPC;
		}

		off = 0;
		len -= sge_len;
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);
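
/* The two helpers below walk the sg ring from sg.start and release send
 * memory accounting for up to @bytes: sk_msg_return_zero() also shrinks
 * and advances past the elements it consumes, while sk_msg_return() only
 * uncharges.
 */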
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);
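
/* Element and ring teardown: sk_msg_free_elem() drops the page reference
 * (and, when @charge is set, the memory charge) for one element unless the
 * memory is owned by an skb; __sk_msg_free() walks the whole ring and then
 * consumes any attached skb.
 */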
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	/* When the skb owns the memory we free it from consume_skb path. */
	if (!msg->skb) {
		if (charge)
			sk_mem_uncharge(sk, len);
		put_page(sg_page(sge));
	}
	memset(sge, 0, sizeof(*sge));
	return len;
}
static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	consume_skb(msg->skb);
	return freed;
}
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);
static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}
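
/* Trim @msg down to @len bytes, freeing whole trailing elements first and
 * then shrinking the last remaining one, keeping sg.curr and sg.copybreak
 * consistent with the new end of the message.
 */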
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
	/* Adjust copybreak if it falls into the trimmed part of last buf */
	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
		msg->sg.copybreak = msg->sg.data[i].length;
out:
	sk_msg_iter_var_next(i);
	msg->sg.end = i;

	/* If we trim a full sg elem before the curr pointer, update
	 * copybreak and curr so that any future copy operations
	 * start at the new copy location.
	 * However, trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (!msg->sg.size) {
		msg->sg.curr = msg->sg.start;
		msg->sg.copybreak = 0;
	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >=
		   sk_msg_iter_dist(msg->sg.start, msg->sg.end)) {
		sk_msg_iter_var_prev(i);
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
}
EXPORT_SYMBOL_GPL(sk_msg_trim);
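
/* Pin user pages from @from directly into @msg's sg ring (up to @bytes),
 * charging the socket for each mapped chunk. On failure the iov_iter is
 * reverted; the caller is expected to trim the partially built @msg.
 */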
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set; in this case clear it and prefer
		 * the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates, msg will need to use 'trim' later if it
	 * also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
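
/* Copy up to @bytes from @from into buffer space already allocated in
 * @msg, starting at sg.curr/sg.copybreak and using the non-caching copy
 * variant when the route supports NETIF_F_NOCACHE_COPY.
 */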
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			msg->sg.curr = i;
			return -EFAULT;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
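
/* Allocate an sk_msg for redirected ingress data, but only if the receive
 * buffer and receive memory accounting of @sk can absorb @skb.
 */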
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
						  struct sk_buff *skb)
{
	struct sk_msg *msg;

	if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf)
		return NULL;

	if (!sk_rmem_schedule(sk, skb, skb->truesize))
		return NULL;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return NULL;

	sk_msg_init(msg);
	return msg;
}
static int sk_psock_skb_ingress_enqueue(struct sk_buff *skb,
					struct sk_psock *psock,
					struct sock *sk,
					struct sk_msg *msg)
{
	int num_sge, copied;

	/* skb linearize may fail with ENOMEM, but let's simply try again
	 * later if this happens. Under memory pressure we don't want to
	 * drop the skb. We need to linearize the skb so that the mapping
	 * in skb_to_sgvec cannot error.
	 */
	if (skb_linearize(skb))
		return -EAGAIN;
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.size = copied;
	msg->sg.end = num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
	return copied;
}
static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb);
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	struct sk_msg *msg;

	/* If we are receiving on the same sock skb->sk is already assigned,
	 * skip memory accounting and owner transition seeing it already set
	 * correctly.
	 */
	if (unlikely(skb->sk == sk))
		return sk_psock_skb_ingress_self(psock, skb);
	msg = sk_psock_create_ingress_msg(sk, skb);
	if (unlikely(!msg))
		return -EAGAIN;

	/* This will transition ownership of the data from the socket where
	 * the BPF program was run initiating the redirect to the socket
	 * we will eventually receive this data on. The data will be released
	 * from consume_skb() found in __tcp_bpf_recvmsg() after it's been
	 * copied.
	 */
	skb_set_owner_r(skb, sk);
	return sk_psock_skb_ingress_enqueue(skb, psock, sk, msg);
}
/* Puts an skb on the ingress queue of the socket already assigned to the
 * skb. In this case we do not need to check memory limits or call
 * skb_set_owner_r() because the skb is already accounted for here.
 */
static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sk_msg *msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	struct sock *sk = psock->sk;

	if (unlikely(!msg))
		return -EAGAIN;
	sk_msg_init(msg);
	return sk_psock_skb_ingress_enqueue(skb, psock, sk, msg);
}
static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (!ingress) {
		if (!sock_writeable(psock->sk))
			return -EAGAIN;
		return skb_send_sock_locked(psock->sk, skb, off, len);
	}
	return sk_psock_skb_ingress(psock, skb);
}
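
/* Workqueue handler that drains psock->ingress_skb: each skb is either
 * re-injected into the local ingress msg queue or transmitted toward the
 * redirect target, with partial progress saved in work_state so that an
 * -EAGAIN can be retried later.
 */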
static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}
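
/* Create a psock for @sk and attach it via sk_user_data, saving the
 * original proto callbacks so they can be restored when the psock is
 * dropped. Fails if the socket already carries a ULP or user data.
 */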
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock;
	struct proto *prot;

	write_lock_bh(&sk->sk_callback_lock);

	if (inet_csk_has_ulp(sk)) {
		psock = ERR_PTR(-EINVAL);
		goto out;
	}

	if (sk->sk_user_data) {
		psock = ERR_PTR(-EBUSY);
		goto out;
	}

	psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node);
	if (!psock) {
		psock = ERR_PTR(-ENOMEM);
		goto out;
	}

	prot = READ_ONCE(sk->sk_prot);
	psock->sk = sk;
	psock->eval = __SK_NONE;
	psock->sk_proto = prot;
	psock->saved_unhash = prot->unhash;
	psock->saved_close = prot->close;
	psock->saved_write_space = sk->sk_write_space;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data_nocopy(sk, psock);
	sock_hold(sk);

out:
	write_unlock_bh(&sk->sk_callback_lock);
	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}
void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}
static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}
static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}
static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */

	/* Parser has been stopped */
	if (psock->progs.skb_parser)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}
void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	write_lock_bh(&sk->sk_callback_lock);
	sk_psock_restore_proto(sk, psock);
	rcu_assign_sk_user_data(sk, NULL);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	else if (psock->progs.skb_verdict)
		sk_psock_stop_verdict(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);
static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}
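
/* Run the attached msg_parser BPF program on @msg and translate its return
 * code into __SK_PASS/__SK_REDIRECT/__SK_DROP, taking a reference on the
 * chosen redirect socket.
 */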
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = bpf_prog_run_pin_on_cpu(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	bpf_compute_data_end_sk_skb(skb);
	return bpf_prog_run_pin_on_cpu(prog, skb);
}
static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}
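
/* Deliver @skb to the socket selected by a BPF redirect verdict by queueing
 * it on that socket's psock backlog and kicking its workqueue.
 */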
static void sk_psock_skb_redirect(struct sk_buff *skb)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;

	sk_other = tcp_skb_bpf_redirect_fetch(skb);
	/* This error indicates a buggy BPF program: it returned a redirect
	 * return code but didn't set a redirect interface.
	 */
	if (unlikely(!sk_other)) {
		kfree_skb(skb);
		return;
	}
	psock_other = sk_psock(sk_other);
	/* This error indicates the socket is being torn down or had another
	 * error that caused the pipe to break. We can't send a packet on
	 * a socket that is in this state so we drop the skb.
	 */
	if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
	    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
		kfree_skb(skb);
		return;
	}

	skb_queue_tail(&psock_other->ingress_skb, skb);
	schedule_work(&psock_other->work);
}
static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict)
{
	switch (verdict) {
	case __SK_REDIRECT:
		skb_set_owner_r(skb, sk);
		sk_psock_skb_redirect(skb);
		break;
	case __SK_PASS:
	case __SK_DROP:
	default:
		break;
	}
}
int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb)
{
	struct bpf_prog *prog;
	int ret = __SK_PASS;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		/* We skip full set_owner_r here because if we do a SK_PASS
		 * or SK_DROP we can skip skb memory accounting and use the
		 * data directly.
		 */
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	sk_psock_tls_verdict_apply(skb, psock->sk, ret);
	rcu_read_unlock();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read);
static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct tcp_skb_cb *tcp;
	struct sock *sk_other;
	int err = -EIO;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
			goto out_free;
		}

		tcp = TCP_SKB_CB(skb);
		tcp->bpf.flags |= BPF_F_INGRESS;

		/* If the queue is empty then we can submit directly
		 * into the msg queue. If it's not empty we have to
		 * queue work, otherwise we may get OOO data. Errors
		 * from sk_psock_skb_ingress are handled by retrying
		 * later from the workqueue.
		 */
		if (skb_queue_empty(&psock->ingress_skb)) {
			err = sk_psock_skb_ingress_self(psock, skb);
		}
		if (err < 0) {
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);
		}
		break;
	case __SK_REDIRECT:
		sk_psock_skb_redirect(skb);
		break;
	case __SK_DROP:
	default:
out_free:
		kfree_skb(skb);
	}
}
static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock;
	struct bpf_prog *prog;
	int ret = __SK_DROP;
	struct sock *sk;

	rcu_read_lock();
	sk = strp->sk;
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		kfree_skb(skb);
		goto out;
	}
	skb_set_owner_r(skb, sk);
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	sk_psock_verdict_apply(psock, skb, ret);
out:
	rcu_read_unlock();
}
static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}
static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = skb->len;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_parser);
	if (likely(prog))
		ret = sk_psock_bpf_run(psock, prog, skb);
	rcu_read_unlock();
	return ret;
}
/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		if (tls_sw_has_ctx_rx(sk)) {
			psock->parser.saved_data_ready(sk);
		} else {
			write_lock_bh(&sk->sk_callback_lock);
			strp_data_ready(&psock->parser.strp);
			write_unlock_bh(&sk->sk_callback_lock);
		}
	}
	rcu_read_unlock();
}
static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb,
				 unsigned int offset, size_t orig_len)
{
	struct sock *sk = (struct sock *)desc->arg.data;
	struct sk_psock *psock;
	struct bpf_prog *prog;
	int ret = __SK_DROP;
	int len = skb->len;

	/* clone here so sk_eat_skb() in tcp_read_sock does not drop our data */
	skb = skb_clone(skb, GFP_ATOMIC);
	if (!skb) {
		desc->error = -ENOMEM;
		return 0;
	}

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		len = 0;
		kfree_skb(skb);
		goto out;
	}
	skb_set_owner_r(skb, sk);
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	sk_psock_verdict_apply(psock, skb, ret);
out:
	rcu_read_unlock();
	return len;
}
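
/* ->sk_data_ready replacement used in verdict-only mode (no strparser):
 * read the socket via ->read_sock() and feed each skb to
 * sk_psock_verdict_recv().
 */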
static void sk_psock_verdict_data_ready(struct sock *sk)
{
	struct socket *sock = sk->sk_socket;
	read_descriptor_t desc;

	if (unlikely(!sock || !sock->ops || !sock->ops->read_sock))
		return;

	desc.arg.data = sk;
	desc.error = 0;
	desc.count = 1;

	sock->ops->read_sock(sk, &desc, sk_psock_verdict_recv);
}
static void sk_psock_write_space(struct sock *sk)
{
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk) = NULL;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			schedule_work(&psock->work);
		write_space = psock->saved_write_space;
	}
	rcu_read_unlock();
	if (write_space)
		write_space(sk);
}
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}
void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_verdict_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}
void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	parser->enabled = false;
}