// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)

	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
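/* sk_msg_alloc() backs @msg with memory from the socket's page frag
 * allocator. When the last scatterlist element already ends at the current
 * page frag offset and the coalesce check above allows it, the new bytes
 * are merged into that element; otherwise a new element is added, taking a
 * page reference and charging the socket's memory accounting.
 */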
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)

	struct page_frag *pfrag = sk_page_frag(sk);

		struct scatterlist *sge;

		if (!sk_page_frag_refill(sk, pfrag))

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))

		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {

			if (sk_msg_full(msg)) {

			sge = &msg->sg.data[msg->sg.end];

			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);

		sk_mem_charge(sk, use);

EXPORT_SYMBOL_GPL(sk_msg_alloc);
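/* sk_msg_clone() copies @len bytes starting at offset @off from @src into
 * @dst without touching the data itself: matching pages are coalesced with
 * the last @dst element when they are virtually contiguous, otherwise the
 * page is added as a new element via sk_msg_page_add(), and the cloned
 * bytes are charged to the socket.
 */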
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,

	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;

		if (sge->length > off)

		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)

		sge = sk_msg_elem(src, i);

		sge_len = sge->length - off;

			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);

		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)

		sge = sk_msg_elem(src, i);

EXPORT_SYMBOL_GPL(sk_msg_clone);
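/* sk_msg_return_zero() and sk_msg_return() walk @msg from sg.start and
 * uncharge up to @bytes of socket memory. The _zero variant also advances
 * the element offset/length so the returned bytes are no longer visible in
 * the message; plain sk_msg_return() only adjusts the accounting.
 */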
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)

	int i = msg->sg.start;

		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;

		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);

EXPORT_SYMBOL_GPL(sk_msg_return_zero);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)

	int i = msg->sg.start;

		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);

		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);

EXPORT_SYMBOL_GPL(sk_msg_return);
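/* Element and message teardown: sk_msg_free_elem() uncharges (optionally)
 * and drops the page reference of a single element, while __sk_msg_free()
 * walks the whole message, frees each element and consumes any attached
 * skb. The exported sk_msg_free()/sk_msg_free_nocharge() wrappers only
 * differ in whether socket memory accounting is released.
 */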
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,

	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

		sk_mem_uncharge(sk, len);

	put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,

	struct scatterlist *sge = sk_msg_elem(msg, i);

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);

	consume_skb(msg->skb);

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)

	return __sk_msg_free(sk, msg, msg->sg.start, false);

EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)

	return __sk_msg_free(sk, msg, msg->sg.start, true);

EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)

	struct scatterlist *sge;
	u32 i = msg->sg.start;

		sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {

			sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)

	__sk_msg_free_partial(sk, msg, bytes, true);

EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,

	__sk_msg_free_partial(sk, msg, bytes, false);
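/* sk_msg_trim() shrinks @msg down to @len bytes by walking backwards from
 * sg.end, freeing whole elements that fall entirely inside the trimmed
 * region and shortening the last partially trimmed element. sg.curr and
 * sg.copybreak are pulled back as well so later copies do not land in
 * memory that no longer belongs to the message.
 */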
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)

	int trim = msg->sg.size - len;

	sk_msg_iter_var_prev(i);

	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
	/* Adjust copybreak if it falls into the trimmed part of last buf */
	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
		msg->sg.copybreak = msg->sg.data[i].length;

	sk_msg_iter_var_next(i);
	/* If we trim data a full sg elem before the curr pointer, update
	 * copybreak and curr so that any future copy operations start at
	 * the new copy location.
	 * However, trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
		msg->sg.curr = msg->sg.start;
		msg->sg.copybreak = 0;
	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >=
		   sk_msg_iter_dist(msg->sg.start, msg->sg.end)) {
		sk_msg_iter_var_prev(i);

		msg->sg.copybreak = msg->sg.data[i].length;

EXPORT_SYMBOL_GPL(sk_msg_trim);
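/* sk_msg_zerocopy_from_iter() pins the user pages referenced by @from and
 * hangs them directly off @msg's scatterlist, so no data is copied. Socket
 * memory is charged for the bytes taken, and on failure the iterator is
 * reverted so the caller can fall back or trim the message.
 */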
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)

	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

		maxpages = to_max_pages - num_elems;

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,

		iov_iter_advance(from, copied);

		msg->sg.size += copied;

			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			sk_msg_iter_next(msg, end);
	/* When zerocopy is mixed with sk_msg_*copy* operations we may have a
	 * copybreak set; in this case clear it and prefer the zerocopy
	 * remainder when possible.
	 */
	msg->sg.copybreak = 0;
	msg->sg.curr = msg->sg.end;

	/* Revert iov_iter updates, msg will need to use 'trim' later if it
	 * also needs to be cleared.
	 */
	iov_iter_revert(from, msg->sg.size - orig);

EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
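/* sk_msg_memcopy_from_iter() copies @bytes from @from into scatterlist
 * memory the caller has already reserved, resuming at sg.curr and
 * sg.copybreak and honouring NETIF_F_NOCACHE_COPY for the actual copy.
 *
 * Illustrative pairing only (a sketch of how a sendmsg-style caller might
 * combine these helpers; not a verbatim call site, and "m"/"copy" are
 * placeholder names):
 *
 *	ret = sk_msg_alloc(sk, msg, msg->sg.size + copy, msg->sg.end - 1);
 *	if (!ret)
 *		ret = sk_msg_memcopy_from_iter(sk, &m->msg_iter, msg, copy);
 */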
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)

	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;

		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)

			sge = sk_msg_elem(msg, i);

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);

			ret = copy_from_iter(to, copy, from);

		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);

EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
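/* sk_psock_skb_ingress() wraps an skb that should be delivered locally into
 * a freshly allocated sk_msg, maps its pages with skb_to_sgvec(), charges
 * the receive socket and queues the result on the psock ingress list before
 * waking any reader via sk_psock_data_ready().
 */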
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)

	struct sock *sk = psock->sk;
	int copied = 0, num_sge;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);

	if (!sk_rmem_schedule(sk, skb, skb->len)) {

	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {

	sk_mem_charge(sk, skb->len);

	msg->sg.size = copied;
	msg->sg.end = num_sge;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)

		return sk_psock_skb_ingress(psock, skb);

	return skb_send_sock_locked(psock->sk, skb, off, len);
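/* sk_psock_backlog() is the psock work function: with the socket locked it
 * drains the ingress_skb queue, feeding each skb back through
 * sk_psock_handle_skb() for local ingress or transmit. -EAGAIN results are
 * retried later; any other error is reported and TX is disabled.
 */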
static void sk_psock_backlog(struct work_struct *work)

	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);

	while ((skb = skb_dequeue(&psock->ingress_skb))) {

		ingress = tcp_skb_bpf_ingress(skb);

			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,

			if (ret == -EAGAIN) {

			/* Hard errors break pipe and stop xmit. */
			sk_psock_report_error(psock, ret ? -ret : EPIPE);
			sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	release_sock(psock->sk);
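/* sk_psock_init() allocates a psock on @node, initializes its link list,
 * ingress queues and backlog work, marks TX enabled and publishes it through
 * the socket's user data pointer with an initial reference held.
 */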
struct sk_psock *sk_psock_init(struct sock *sk, int node)

	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,

	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data_nocopy(sk, psock);

EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)

	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,

		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)

	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);

static void sk_psock_zap_ingress(struct sk_psock *psock)

	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);

static void sk_psock_link_destroy(struct sk_psock *psock)

	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);

static void sk_psock_destroy_deferred(struct work_struct *gc)

	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */

	/* Parser has been stopped */
	if (psock->progs.skb_parser)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

		sock_put(psock->sk_redir);

void sk_psock_destroy(struct rcu_head *rcu)

	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);

EXPORT_SYMBOL_GPL(sk_psock_destroy);
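/* sk_psock_drop() detaches the psock from @sk: under sk_callback_lock it
 * restores the original proto callbacks, clears sk_user_data and stops the
 * strparser if one was attached, then hands final teardown to
 * sk_psock_destroy() via call_rcu() so RCU readers can still see the psock
 * until a grace period has elapsed.
 */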
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)

	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	write_lock_bh(&sk->sk_callback_lock);
	sk_psock_restore_proto(sk, psock);
	rcu_assign_sk_user_data(sk, NULL);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);

EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)

		return redir ? __SK_REDIRECT : __SK_PASS;
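/* sk_psock_msg_verdict() runs the attached msg_parser BPF program on @msg,
 * maps the raw return code with sk_psock_map_verd() and records the apply
 * bytes budget. For __SK_REDIRECT the previously held redirect socket is
 * released and a reference is taken on the new msg->sk_redir.
 */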
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,

	struct bpf_prog *prog;

	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {

	sk_msg_compute_data_pointers(msg);

	ret = bpf_prog_run_pin_on_cpu(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {

			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {

		sock_hold(psock->sk_redir);

EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,

	bpf_compute_data_end_sk_skb(skb);
	ret = bpf_prog_run_pin_on_cpu(prog, skb);
	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to any socket yet.
	 */
static struct sk_psock *sk_psock_from_strp(struct strparser *strp)

	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
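/* sk_psock_verdict_apply() acts on the BPF verdict for an skb: a pass
 * verdict queues it on this psock's own ingress list (subject to rcvbuf
 * limits), while a redirect looks up the target socket recorded in the skb
 * and, if its psock is alive and TX-enabled, queues the skb there and kicks
 * that psock's backlog work.
 */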
static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)

	struct sk_psock *psock_other;
	struct sock *sk_other;

		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {

		if (atomic_read(&sk_other->sk_rmem_alloc) <=
		    sk_other->sk_rcvbuf) {
			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

			tcp->bpf.flags |= BPF_F_INGRESS;
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);

		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))

		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))

		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {

			skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
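/* strparser callbacks: sk_psock_strp_read() runs the skb_verdict program on
 * each parsed message and applies the result, sk_psock_strp_parse() runs the
 * skb_parser program to determine message boundaries, and
 * sk_psock_strp_data_ready() feeds newly arrived data into the strparser
 * from the socket's data_ready hook.
 */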
static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)

	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;

	prog = READ_ONCE(psock->progs.skb_verdict);

		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));

	sk_psock_verdict_apply(psock, skb, ret);

static int sk_psock_strp_read_done(struct strparser *strp, int err)

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)

	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;

	prog = READ_ONCE(psock->progs.skb_parser);

		ret = sk_psock_bpf_run(psock, prog, skb);

/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)

	struct sk_psock *psock;

	psock = sk_psock(sk);

		write_lock_bh(&sk->sk_callback_lock);
		strp_data_ready(&psock->parser.strp);
		write_unlock_bh(&sk->sk_callback_lock);

static void sk_psock_write_space(struct sock *sk)

	struct sk_psock *psock;
	void (*write_space)(struct sock *sk) = NULL;

	psock = sk_psock(sk);

		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			schedule_work(&psock->work);
		write_space = psock->saved_write_space;
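/* Strparser lifecycle: sk_psock_init_strp() registers the callbacks above,
 * sk_psock_start_strp() saves and replaces sk_data_ready/sk_write_space so
 * traffic flows through the parser, and sk_psock_stop_strp() restores the
 * saved callbacks and stops the strparser.
 *
 * Illustrative setup order only (a sketch; the real call sites live in the
 * sockmap code, not in this file):
 *
 *	psock = sk_psock_init(sk, NUMA_NO_NODE);
 *	if (psock && !sk_psock_init_strp(sk, psock))
 *		sk_psock_start_strp(sk, psock);
 */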
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)

	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)

	struct sk_psock_parser *parser = &psock->parser;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)

	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;