// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>
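
/* Check whether the element just before sg.end may be extended in place.
 * elem_first_coalesce is the first element this caller is allowed to merge
 * into; the two comparisons below cover the non-wrapped and wrapped cases
 * of the scatterlist ring.
 */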
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}
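
/* Back @msg with memory taken from @sk's page frag. Each chunk is either
 * coalesced into the previous scatterlist element (when
 * sk_msg_try_coalesce_ok() allows it and the page and offset line up) or
 * appended as a new element at sg.end, with the socket charged for the
 * bytes used.
 */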
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
	struct page_frag *pfrag = sk_page_frag(sk);

		struct scatterlist *sge;

		if (!sk_page_frag_refill(sk, pfrag))

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))

		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {

			if (sk_msg_full(msg)) {

			sge = &msg->sg.data[msg->sg.end];
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);

		sk_mem_charge(sk, use);

EXPORT_SYMBOL_GPL(sk_msg_alloc);
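
/* Clone scatterlist entries from @src into @dst without copying payload:
 * pages are shared and @sk is charged again for the cloned bytes. Entries
 * that continue the same page mapping are merged into the tail element of
 * @dst, otherwise sk_msg_page_add() appends a new one.
 */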
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;

		if (sge->length > off)

		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)

		sge = sk_msg_elem(src, i);

		sge_len = sge->length - off;

			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);

		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)

		sge = sk_msg_elem(src, i);

EXPORT_SYMBOL_GPL(sk_msg_clone);
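
/* Uncharge @bytes from @sk, walking forward from sg.start and shrinking the
 * elements they cover in place; no pages are released.
 */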
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
	int i = msg->sg.start;

		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;

		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);

EXPORT_SYMBOL_GPL(sk_msg_return_zero);
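
/* Uncharge up to @bytes from @sk, visiting every element from sg.start to
 * sg.end; the scatterlist itself is left untouched.
 */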
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
	int i = msg->sg.start;

		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);

		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);

EXPORT_SYMBOL_GPL(sk_msg_return);
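
/* Release a single scatterlist element: uncharge its bytes from @sk when the
 * caller asks for it (the "nocharge" paths pass false), drop the page
 * reference and clear the entry. The element length is returned so callers
 * can total up the freed bytes.
 */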
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	if (charge)
		sk_mem_uncharge(sk, len);

	put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));
static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
	struct scatterlist *sge = sk_msg_elem(msg, i);

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);

	consume_skb(msg->skb);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);
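
/* Free @bytes from the front of @msg: fully covered elements are released
 * via sk_msg_free_elem(), a partially covered element is shrunk in place and
 * only its consumed bytes are uncharged.
 */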
static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
	struct scatterlist *sge;
	u32 i = msg->sg.start;

		sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}
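
/* Trim @msg down to @len bytes from the tail: fully trimmed elements are
 * freed, the last remaining element is shortened and the trimmed bytes are
 * uncharged from @sk.
 */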
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
	int trim = msg->sg.size - len;

	sk_msg_iter_var_prev(i);

	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);

	/* If we trim data before the curr pointer, update copybreak and curr
	 * so that any future copy operations start at the new copy location.
	 * However, trimmed data that has not yet been used in a copy op does
	 * not require an update.
	 */
	if (msg->sg.curr >= i) {
		msg->sg.copybreak = msg->sg.data[i].length;

	sk_msg_iter_var_next(i);

EXPORT_SYMBOL_GPL(sk_msg_trim);
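
/* Pin up to @bytes of user pages from @from with iov_iter_get_pages() and
 * link them into @msg's scatterlist without copying. On failure the iov_iter
 * is reverted so the caller can trim @msg back to its original size.
 */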
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	maxpages = to_max_pages - num_elems;

	copied = iov_iter_get_pages(from, pages, bytes, maxpages,
				    &offset);

	iov_iter_advance(from, copied);

	msg->sg.size += copied;

		use = min_t(int, copied, PAGE_SIZE - offset);
		sg_set_page(&msg->sg.data[msg->sg.end],
			    pages[i], use, offset);
		sg_unmark_end(&msg->sg.data[msg->sg.end]);
		sk_mem_charge(sk, use);

		sk_msg_iter_next(msg, end);

	/* When zerocopy is mixed with sk_msg_*copy* operations we may have a
	 * copybreak set; in that case clear it and prefer the zerocopy
	 * remainder when possible.
	 */
	msg->sg.copybreak = 0;
	msg->sg.curr = msg->sg.end;

	/* Revert the iov_iter updates; the msg will need to use 'trim' later
	 * if it also needs to be cleared.
	 */
		iov_iter_revert(from, msg->sg.size - orig);

EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
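
/* Copy @bytes from @from into memory already attached to @msg, starting at
 * sg.curr/copybreak and advancing them as data is written. The nocache copy
 * is used when the route advertises NETIF_F_NOCACHE_COPY.
 */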
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;

		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)

			sge = sk_msg_elem(msg, i);

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);

		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);

EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
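
/* Wrap @skb in a freshly allocated sk_msg, charge the receive memory to the
 * psock's socket and queue the message on the psock ingress list, waking any
 * reader via sk_psock_data_ready().
 */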
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);

	if (!sk_rmem_schedule(sk, skb, skb->len)) {

	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {

	sk_mem_charge(sk, skb->len);

	msg->sg.size = copied;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
		return sk_psock_skb_ingress(psock, skb);

	return skb_send_sock_locked(psock->sk, skb, off, len);
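
/* Workqueue callback that drains psock->ingress_skb with the socket locked.
 * Each skb is either queued locally (ingress verdict) or sent to the socket;
 * on -EAGAIN progress is stashed in psock->work_state so a later run can
 * resume, any other error is reported and TX is disabled.
 */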
static void sk_psock_backlog(struct work_struct *work)
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);

	while ((skb = skb_dequeue(&psock->ingress_skb))) {

		ingress = tcp_skb_bpf_ingress(skb);

			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret == -EAGAIN) {

				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	release_sock(psock->sk);
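
/* Allocate and initialise a psock for @sk on NUMA node @node and publish it
 * through sk_user_data (RCU). TX starts enabled and the refcount at 1; users
 * that attach a parser program typically follow up with sk_psock_init_strp()
 * and sk_psock_start_strp().
 */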
struct sk_psock *sk_psock_init(struct sock *sk, int node)
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);

	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data(sk, psock);

EXPORT_SYMBOL_GPL(sk_psock_init);
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}
static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}
static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}
static void sk_psock_destroy_deferred(struct work_struct *gc)
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */

	/* Parser has been stopped */
	if (psock->progs.skb_parser)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

		sock_put(psock->sk_redir);
void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	rcu_assign_sk_user_data(sk, NULL);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);
	sk_psock_restore_proto(sk, psock);

	write_lock_bh(&sk->sk_callback_lock);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);
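
/* Map the raw return value of a BPF verdict program onto the internal __SK_*
 * action codes, preferring __SK_REDIRECT when a redirect target was selected
 * by the program.
 */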
static int sk_psock_map_verd(int verdict, bool redir)
		return redir ? __SK_REDIRECT : __SK_PASS;
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
	struct bpf_prog *prog;

	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {

	sk_msg_compute_data_pointers(msg);

	ret = BPF_PROG_RUN(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {

		sock_hold(psock->sk_redir);

EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
	bpf_compute_data_end_sk_skb(skb);

	ret = BPF_PROG_RUN(prog, skb);

	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to the socket.
	 */
static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}
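
/* Act on the skb verdict: __SK_PASS queues the skb on our own ingress list
 * when receive buffer space allows, __SK_REDIRECT hands it to the target
 * psock's ingress queue or send path; in both cases sk_psock_backlog() does
 * the actual delivery from the work queue.
 */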
static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
	struct sk_psock *psock_other;
	struct sock *sk_other;

		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {

		if (atomic_read(&sk_other->sk_rmem_alloc) <=
		    sk_other->sk_rcvbuf) {
			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

			tcp->bpf.flags |= BPF_F_INGRESS;
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);

		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))

		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))

		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {

			skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;

	prog = READ_ONCE(psock->progs.skb_verdict);

		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));

	sk_psock_verdict_apply(psock, skb, ret);
static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}
static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;

	prog = READ_ONCE(psock->progs.skb_parser);

		ret = sk_psock_bpf_run(psock, prog, skb);
/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
	struct sk_psock *psock;

	psock = sk_psock(sk);

		write_lock_bh(&sk->sk_callback_lock);
		strp_data_ready(&psock->parser.strp);
		write_unlock_bh(&sk->sk_callback_lock);
static void sk_psock_write_space(struct sock *sk)
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk);

	psock = sk_psock(sk);
	if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
		schedule_work(&psock->work);
	write_space = psock->saved_write_space;
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
	struct sk_psock_parser *parser = &psock->parser;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}