// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>

static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);

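/* Illustrative note, not part of the original file: 'len' is the total size
 * the sk_msg should reach rather than a delta, since sk_msg_alloc() subtracts
 * the current msg->sg.size before growing. A hypothetical caller adding
 * 'copy' bytes would therefore look roughly like:
 *
 *	err = sk_msg_alloc(sk, msg, msg->sg.size + copy, msg->sg.end - 1);
 *	if (err)
 *		return err;
 */
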
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	u32 sge_len, sge_off;

	if (sk_msg_full(dst))
		return -ENOSPC;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		sge_off = sge->offset + off;
		if (sge_len > len)
			sge_len = len;
		off = 0;
		len -= sge_len;
		sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	if (charge)
		sk_mem_uncharge(sk, len);
	if (!msg->skb)
		put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));
	return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}

	consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}

void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
out:
	/* If we trim data before the curr pointer, update copybreak and curr
	 * so that any future copy operations start at the new copy location.
	 * However, trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (msg->sg.curr >= i) {
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
	sk_msg_iter_var_next(i);
	msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);

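/* Illustrative sketch, not part of the original file: a sender that grew the
 * message but then failed to copy user data can rely on sk_msg_trim() to give
 * back the unused tail; the curr/copybreak fixup above keeps a later retry
 * copying into the right place. All names below are hypothetical.
 *
 *	static int example_append(struct sock *sk, struct sk_msg *msg,
 *				  struct iov_iter *from, u32 copy)
 *	{
 *		u32 osize = msg->sg.size;
 *		int err;
 *
 *		err = sk_msg_alloc(sk, msg, osize + copy, msg->sg.end - 1);
 *		if (err)
 *			return err;
 *		err = sk_msg_memcopy_from_iter(sk, from, msg, copy);
 *		if (err < 0)
 *			sk_msg_trim(sk, msg, osize);
 *		return err;
 *	}
 */
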
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set; in this case clear it and prefer
		 * the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates; the msg will need to use 'trim' later if
	 * it also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

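/* Illustrative sketch, not part of the original file: on failure the iov_iter
 * has been reverted above, but any scatterlist entries already linked into
 * 'msg' have not; per the comment above, the caller is expected to clear them
 * with sk_msg_trim(), e.g. (hypothetical caller state):
 *
 *	orig_size = msg->sg.size;
 *	err = sk_msg_zerocopy_from_iter(sk, from, msg, bytes);
 *	if (err)
 *		sk_msg_trim(sk, msg, orig_size);
 */
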
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				goto out;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			msg->sg.curr = i;
			return -EFAULT;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

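/* Receive path helper: wrap an skb into a freshly allocated sk_msg, charge
 * its memory to the socket's receive budget, queue it on the psock ingress
 * list and wake the receiver. Returns the number of bytes queued or a
 * negative error.
 */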
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;
	struct sk_msg *msg;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return -EAGAIN;
	if (!sk_rmem_schedule(sk, skb, skb->len)) {
		kfree(msg);
		return -EAGAIN;
	}

	sk_msg_init(msg);
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	sk_mem_charge(sk, skb->len);
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
	return copied;
}

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (ingress)
		return sk_psock_skb_ingress(psock, skb);
	else
		return skb_send_sock_locked(psock->sk, skb, off, len);
}

static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);
	if (!psock)
		return NULL;

	psock->sk = sk;
	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data(sk, psock);
	sock_hold(sk);

	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

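/* Illustrative sketch, not part of the original file: rough psock lifecycle
 * as a map implementation might drive it; names and error handling are
 * hypothetical and abridged.
 *
 *	psock = sk_psock_init(sk, NUMA_NO_NODE);
 *	if (!psock)
 *		return -ENOMEM;
 *	// ... attach BPF programs, optionally set up the strparser ...
 *	// Teardown detaches the sk user data and frees via RCU + work:
 *	sk_psock_drop(sk, psock);
 */
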
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}

static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */
	if (psock->parser.enabled)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	rcu_assign_sk_user_data(sk, NULL);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);
	sk_psock_restore_proto(sk, psock);

	write_lock_bh(&sk->sk_callback_lock);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	preempt_disable();
	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	ret = BPF_PROG_RUN(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	preempt_enable();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	int ret;

	skb->sk = psock->sk;
	bpf_compute_data_end_sk_skb(skb);
	preempt_disable();
	ret = BPF_PROG_RUN(prog, skb);
	preempt_enable();
	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to any socket yet.
	 */
	skb->sk = NULL;
	return ret;
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}

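/* Apply the (possibly redirect-mapped) verdict to the skb: __SK_PASS queues
 * it on our own ingress queue, __SK_REDIRECT hands it to the psock picked by
 * the BPF program (subject to SOCK_DEAD/TX_ENABLED and buffer-space checks),
 * and anything else drops the skb.
 */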
static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;
	bool ingress;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
			goto out_free;
		}
		if (atomic_read(&sk_other->sk_rmem_alloc) <=
		    sk_other->sk_rcvbuf) {
			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

			tcp->bpf.flags |= BPF_F_INGRESS;
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);
			break;
		}
		goto out_free;
	case __SK_REDIRECT:
		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))
			goto out_free;
		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
			goto out_free;
		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		    (ingress &&
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {
			if (!ingress)
				skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
			break;
		}
		/* fall-through */
	case __SK_DROP:
		/* fall-through */
	default:
out_free:
		kfree_skb(skb);
	}
}

static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = __SK_DROP;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb_orphan(skb);
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	rcu_read_unlock();
	sk_psock_verdict_apply(psock, skb, ret);
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = skb->len;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_parser);
	if (likely(prog))
		ret = sk_psock_bpf_run(psock, prog, skb);
	rcu_read_unlock();
	return ret;
}

/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		write_lock_bh(&sk->sk_callback_lock);
		strp_data_ready(&psock->parser.strp);
		write_unlock_bh(&sk->sk_callback_lock);
	}
	rcu_read_unlock();
}

static void sk_psock_write_space(struct sock *sk)
{
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
		schedule_work(&psock->work);
	write_space = psock->saved_write_space;
	rcu_read_unlock();
	write_space(sk);
}

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}

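/* Illustrative sketch, not part of the original file: the expected strparser
 * bring-up order for a caller, with error handling elided.
 * sk_psock_start_strp() swaps the socket callbacks, so callers typically
 * serialize it with sk_callback_lock.
 *
 *	err = sk_psock_init_strp(sk, psock);
 *	if (err)
 *		return err;
 *	write_lock_bh(&sk->sk_callback_lock);
 *	sk_psock_start_strp(sk, psock);
 *	write_unlock_bh(&sk->sk_callback_lock);
 */
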
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}