// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
        if (msg->sg.end > msg->sg.start &&
            elem_first_coalesce < msg->sg.end)
                return true;

        if (msg->sg.end < msg->sg.start &&
            (elem_first_coalesce > msg->sg.start ||
             elem_first_coalesce < msg->sg.end))
                return true;

        return false;
}
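/* Note: msg->sg.data[] is used as a ring; sg.start and sg.end wrap at
 * MAX_MSG_FRAGS, so sg.end < sg.start is a valid (wrapped) state. That is
 * why sk_msg_try_coalesce_ok() checks both orderings before allowing new
 * bytes to be merged into the element just before sg.end.
 */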
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
                 int elem_first_coalesce)
{
        struct page_frag *pfrag = sk_page_frag(sk);
        int ret = 0;

        len -= msg->sg.size;
        while (len > 0) {
                struct scatterlist *sge;
                u32 orig_offset;
                int use, i;

                if (!sk_page_frag_refill(sk, pfrag))
                        return -ENOMEM;

                orig_offset = pfrag->offset;
                use = min_t(int, len, pfrag->size - orig_offset);
                if (!sk_wmem_schedule(sk, use))
                        return -ENOMEM;

                i = msg->sg.end;
                sk_msg_iter_var_prev(i);
                sge = &msg->sg.data[i];

                if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
                    sg_page(sge) == pfrag->page &&
                    sge->offset + sge->length == orig_offset) {
                        sge->length += use;
                } else {
                        if (sk_msg_full(msg)) {
                                ret = -ENOSPC;
                                break;
                        }

                        sge = &msg->sg.data[msg->sg.end];
                        sg_unmark_end(sge);
                        sg_set_page(sge, pfrag->page, use, orig_offset);
                        get_page(pfrag->page);
                        sk_msg_iter_next(msg, end);
                }

                sk_mem_charge(sk, use);
                msg->sg.size += use;
                pfrag->offset += use;
                len -= use;
        }

        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);
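/* Illustrative usage (a sketch, not a call site in this file): a sender
 * typically reserves space and then copies user data into it, roughly
 *
 *	err = sk_msg_alloc(sk, msg, bytes, first_coalesce);
 *	if (!err)
 *		err = sk_msg_memcopy_from_iter(sk, from, msg, bytes);
 *
 * where 'from' is the sendmsg() iov_iter and 'first_coalesce' marks the
 * first ring element that may legally be coalesced into.
 */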
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
                 u32 off, u32 len)
{
        int i = src->sg.start;
        struct scatterlist *sge = sk_msg_elem(src, i);
        u32 sge_len, sge_off;

        if (sk_msg_full(dst))
                return -ENOSPC;

        while (off) {
                if (sge->length > off)
                        break;
                off -= sge->length;
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && off)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        while (len) {
                sge_len = sge->length - off;
                sge_off = sge->offset + off;
                if (sge_len > len)
                        sge_len = len;
                off = 0;
                len -= sge_len;
                sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
                sk_mem_charge(sk, sge_len);
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && len)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);
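/* Note: sk_msg_clone() does not copy payload; it links the pages backing
 * src's bytes in [off, off + len) into dst via sk_msg_page_add() and
 * charges them to sk, returning -ENOSPC if dst is full or src ends before
 * the requested range is covered.
 */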
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = sk_msg_elem(msg, i);

                if (bytes < sge->length) {
                        sge->length -= bytes;
                        sge->offset += bytes;
                        sk_mem_uncharge(sk, bytes);
                        break;
                }

                sk_mem_uncharge(sk, sge->length);
                bytes -= sge->length;
                sge->length = 0;
                sge->offset = 0;
                sk_msg_iter_var_next(i);
        } while (bytes && i != msg->sg.end);
        msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = &msg->sg.data[i];
                int uncharge = (bytes < sge->length) ? bytes : sge->length;

                sk_mem_uncharge(sk, uncharge);
                bytes -= uncharge;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);
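/* Note: both helpers above walk from sg.start releasing the socket memory
 * charge for 'bytes'. sk_msg_return_zero() additionally zeroes the fully
 * consumed elements and advances sg.start, while sk_msg_return() only
 * drops the accounting and leaves the ring untouched.
 */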
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
                            bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        u32 len = sge->length;

        if (charge)
                sk_mem_uncharge(sk, len);
        if (!msg->skb)
                put_page(sg_page(sge));
        memset(sge, 0, sizeof(*sge));
        return len;
}
static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
                         bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        int freed = 0;

        while (msg->sg.size) {
                msg->sg.size -= sge->length;
                freed += sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, msg->sg.size);
                sge = sk_msg_elem(msg, i);
        }

        consume_skb(msg->skb);
        sk_msg_init(msg);
        return freed;
}
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);
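/* Note: the *_nocharge variants pass charge == false so that
 * sk_msg_free_elem() skips sk_mem_uncharge(); they are meant for data
 * whose memory was never charged (or was already uncharged) against this
 * socket's accounting.
 */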
static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes, bool charge)
{
        struct scatterlist *sge;
        u32 i = msg->sg.start;

        while (bytes) {
                sge = sk_msg_elem(msg, i);
                if (!sge->length)
                        break;
                if (bytes < sge->length) {
                        if (charge)
                                sk_mem_uncharge(sk, bytes);
                        sge->length -= bytes;
                        sge->offset += bytes;
                        msg->sg.size -= bytes;
                        break;
                }

                msg->sg.size -= sge->length;
                bytes -= sge->length;
                sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, bytes);
        }

        msg->sg.start = i;
}
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, false);
}
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
        int trim = msg->sg.size - len;
        u32 i = msg->sg.end;

        if (trim <= 0) {
                WARN_ON(trim < 0);
                return;
        }

        sk_msg_iter_var_prev(i);
        msg->sg.size = len;
        while (msg->sg.data[i].length &&
               trim >= msg->sg.data[i].length) {
                trim -= msg->sg.data[i].length;
                sk_msg_free_elem(sk, msg, i, true);
                sk_msg_iter_var_prev(i);
                if (!trim)
                        goto out;
        }

        msg->sg.data[i].length -= trim;
        sk_mem_uncharge(sk, trim);
out:
        /* If we trim data before the curr pointer, update copybreak and curr
         * so that any future copy operations start at the new copy location.
         * However, trimmed data that has not yet been used in a copy op
         * does not require an update.
         */
        if (msg->sg.curr >= i) {
                msg->sg.curr = i;
                msg->sg.copybreak = msg->sg.data[i].length;
        }
        sk_msg_iter_var_next(i);
        msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);
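/* Illustrative recovery pattern (a sketch, not a call site in this file):
 * when filling a msg fails partway, callers can roll the msg back to its
 * previous size, roughly
 *
 *	orig_size = msg->sg.size;
 *	err = sk_msg_zerocopy_from_iter(sk, from, msg, bytes);
 *	if (err)
 *		sk_msg_trim(sk, msg, orig_size);
 */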
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              struct sk_msg *msg, u32 bytes)
{
        int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
        const int to_max_pages = MAX_MSG_FRAGS;
        struct page *pages[MAX_MSG_FRAGS];
        ssize_t orig, copied, use, offset;

        orig = msg->sg.size;
        while (bytes > 0) {
                i = 0;
                maxpages = to_max_pages - num_elems;
                if (maxpages == 0) {
                        ret = -EFAULT;
                        goto out;
                }

                copied = iov_iter_get_pages(from, pages, bytes, maxpages,
                                            &offset);
                if (copied <= 0) {
                        ret = -EFAULT;
                        goto out;
                }

                iov_iter_advance(from, copied);
                bytes -= copied;
                msg->sg.size += copied;

                while (copied) {
                        use = min_t(int, copied, PAGE_SIZE - offset);
                        sg_set_page(&msg->sg.data[msg->sg.end],
                                    pages[i], use, offset);
                        sg_unmark_end(&msg->sg.data[msg->sg.end]);
                        sk_mem_charge(sk, use);

                        offset = 0;
                        copied -= use;
                        sk_msg_iter_next(msg, end);
                        num_elems++;
                        i++;
                }
                /* When zerocopy is mixed with sk_msg_*copy* operations we
                 * may have a copybreak set; in that case clear it and prefer
                 * the zerocopy remainder when possible.
                 */
                msg->sg.copybreak = 0;
                msg->sg.curr = msg->sg.end;
        }
out:
        /* Revert iov_iter updates; the msg will need to use sk_msg_trim()
         * later if it also needs to be cleared.
         */
        if (ret)
                iov_iter_revert(from, msg->sg.size - orig);
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
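/* Note: sk_msg_zerocopy_from_iter() above takes references on the user
 * pages and links them straight into msg->sg.data[], while
 * sk_msg_memcopy_from_iter() below copies bytes into buffer space that
 * already backs the msg (typically reserved via sk_msg_alloc()), resuming
 * at sg.curr/sg.copybreak.
 */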
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             struct sk_msg *msg, u32 bytes)
{
        int ret = -ENOSPC, i = msg->sg.curr;
        struct scatterlist *sge;
        u32 copy, buf_size;
        void *to;

        do {
                sge = sk_msg_elem(msg, i);
                /* This is possible if a trim operation shrunk the buffer */
                if (msg->sg.copybreak >= sge->length) {
                        msg->sg.copybreak = 0;
                        sk_msg_iter_var_next(i);
                        if (i == msg->sg.end)
                                break;
                        sge = sk_msg_elem(msg, i);
                }

                buf_size = sge->length - msg->sg.copybreak;
                copy = (buf_size > bytes) ? bytes : buf_size;
                to = sg_virt(sge) + msg->sg.copybreak;
                msg->sg.copybreak += copy;
                if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
                        ret = copy_from_iter_nocache(to, copy, from);
                else
                        ret = copy_from_iter(to, copy, from);
                if (ret != copy) {
                        msg->sg.copybreak -= copy;
                        return -EFAULT;
                }

                bytes -= copy;
                if (!bytes)
                        break;
                msg->sg.copybreak = 0;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
        msg->sg.curr = i;
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
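/* Note: sg.curr indexes the element currently being filled and
 * sg.copybreak is the byte offset within it; sk_msg_memcopy_from_iter()
 * resumes there, and sk_msg_trim() resets both when data before curr is
 * trimmed away.
 */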
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
        struct sock *sk = psock->sk;
        int copied = 0, num_sge;
        struct sk_msg *msg;

        msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
        if (unlikely(!msg))
                return -EAGAIN;
        if (!sk_rmem_schedule(sk, skb, skb->len)) {
                kfree(msg);
                return -EAGAIN;
        }

        sk_msg_init(msg);
        num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
        if (unlikely(num_sge < 0)) {
                kfree(msg);
                return num_sge;
        }

        sk_mem_charge(sk, skb->len);
        copied = skb->len;
        msg->sg.start = 0;
        msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
        msg->skb = skb;

        sk_psock_queue_msg(psock, msg);
        sk->sk_data_ready(sk);
        return copied;
}
static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
                               u32 off, u32 len, bool ingress)
{
        if (ingress)
                return sk_psock_skb_ingress(psock, skb);
        else
                return skb_send_sock_locked(psock->sk, skb, off, len);
}
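/* Note: for an ingress verdict the skb is repackaged as an sk_msg and
 * queued on the local psock (sk_psock_skb_ingress()), while egress
 * redirects are transmitted with skb_send_sock_locked(), which requires
 * the socket lock held by sk_psock_backlog() below.
 */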
static void sk_psock_backlog(struct work_struct *work)
{
        struct sk_psock *psock = container_of(work, struct sk_psock, work);
        struct sk_psock_work_state *state = &psock->work_state;
        struct sk_buff *skb;
        bool ingress;
        u32 len, off;
        int ret;

        /* Lock sock to avoid losing sk_socket during loop. */
        lock_sock(psock->sk);
        if (state->skb) {
                skb = state->skb;
                len = state->len;
                off = state->off;
                state->skb = NULL;
                goto start;
        }

        while ((skb = skb_dequeue(&psock->ingress_skb))) {
                len = skb->len;
                off = 0;
start:
                ingress = tcp_skb_bpf_ingress(skb);
                do {
                        ret = -EIO;
                        if (likely(psock->sk->sk_socket))
                                ret = sk_psock_handle_skb(psock, skb, off,
                                                          len, ingress);
                        if (ret <= 0) {
                                if (ret == -EAGAIN) {
                                        state->skb = skb;
                                        state->len = len;
                                        state->off = off;
                                        goto end;
                                }
                                /* Hard errors break pipe and stop xmit. */
                                sk_psock_report_error(psock, ret ? -ret : EPIPE);
                                sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
                                goto end;
                        }
                        off += ret;
                        len -= ret;
                } while (len);

                if (!ingress)
                        kfree_skb(skb);
        }
end:
        release_sock(psock->sk);
}
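/* Note: the backlog runs under lock_sock(). A partial send (-EAGAIN) saves
 * skb/len/off in psock->work_state so the next work run resumes at 'start';
 * any other error is reported via sk_psock_report_error() and clears
 * SK_PSOCK_TX_ENABLED, stopping further transmission.
 */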
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
        struct sk_psock *psock = kzalloc_node(sizeof(*psock),
                                              GFP_ATOMIC | __GFP_NOWARN,
                                              node);
        if (!psock)
                return NULL;

        psock->sk = sk;
        psock->eval = __SK_NONE;

        INIT_LIST_HEAD(&psock->link);
        spin_lock_init(&psock->link_lock);

        INIT_WORK(&psock->work, sk_psock_backlog);
        INIT_LIST_HEAD(&psock->ingress_msg);
        skb_queue_head_init(&psock->ingress_skb);

        sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
        refcount_set(&psock->refcnt, 1);

        rcu_assign_sk_user_data(sk, psock);
        sock_hold(sk);

        return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);
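/* Lifecycle sketch (hedged summary of this file): sk_psock_init() creates
 * the psock with refcnt 1 and publishes it through sk_user_data;
 * sk_psock_drop() below unpublishes it and defers the real teardown to
 * sk_psock_destroy() (RCU callback) and sk_psock_destroy_deferred()
 * (workqueue), so in-flight RCU readers and queued work finish first.
 */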
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
        struct sk_psock_link *link;

        spin_lock_bh(&psock->link_lock);
        link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
                                        list);
        if (link)
                list_del(&link->list);
        spin_unlock_bh(&psock->link_lock);
        return link;
}
void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
        struct sk_msg *msg, *tmp;

        list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
                list_del(&msg->list);
                sk_msg_free(psock->sk, msg);
                kfree(msg);
        }
}
static void sk_psock_zap_ingress(struct sk_psock *psock)
{
        __skb_queue_purge(&psock->ingress_skb);
        __sk_psock_purge_ingress_msg(psock);
}
static void sk_psock_link_destroy(struct sk_psock *psock)
{
        struct sk_psock_link *link, *tmp;

        list_for_each_entry_safe(link, tmp, &psock->link, list) {
                list_del(&link->list);
                sk_psock_free_link(link);
        }
}
static void sk_psock_destroy_deferred(struct work_struct *gc)
{
        struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

        /* No sk_callback_lock since already detached. */
        if (psock->parser.enabled)
                strp_done(&psock->parser.strp);

        cancel_work_sync(&psock->work);

        psock_progs_drop(&psock->progs);

        sk_psock_link_destroy(psock);
        sk_psock_cork_free(psock);
        sk_psock_zap_ingress(psock);

        if (psock->sk_redir)
                sock_put(psock->sk_redir);
        sock_put(psock->sk);
        kfree(psock);
}
void sk_psock_destroy(struct rcu_head *rcu)
{
        struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

        INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
        schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
        rcu_assign_sk_user_data(sk, NULL);
        sk_psock_cork_free(psock);
        sk_psock_restore_proto(sk, psock);

        write_lock_bh(&sk->sk_callback_lock);
        if (psock->progs.skb_parser)
                sk_psock_stop_strp(sk, psock);
        write_unlock_bh(&sk->sk_callback_lock);
        sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

        call_rcu_sched(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);
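/* Note: teardown order matters here: the psock is unlinked from
 * sk_user_data and the original proto callbacks restored before the
 * strparser is stopped under sk_callback_lock, and only then is the free
 * deferred through call_rcu_sched() so concurrent sk_psock() readers see
 * either a valid psock or NULL.
 */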
static int sk_psock_map_verd(int verdict, bool redir)
{
        switch (verdict) {
        case SK_PASS:
                return redir ? __SK_REDIRECT : __SK_PASS;
        case SK_DROP:
        default:
                break;
        }

        return __SK_DROP;
}
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
                         struct sk_msg *msg)
{
        struct bpf_prog *prog;
        int ret;

        preempt_disable();
        rcu_read_lock();
        prog = READ_ONCE(psock->progs.msg_parser);
        if (unlikely(!prog)) {
                ret = __SK_PASS;
                goto out;
        }

        sk_msg_compute_data_pointers(msg);
        ret = BPF_PROG_RUN(prog, msg);
        ret = sk_psock_map_verd(ret, msg->sk_redir);
        psock->apply_bytes = msg->apply_bytes;
        if (ret == __SK_REDIRECT) {
                if (psock->sk_redir)
                        sock_put(psock->sk_redir);
                psock->sk_redir = msg->sk_redir;
                if (!psock->sk_redir) {
                        ret = __SK_DROP;
                        goto out;
                }
                sock_hold(psock->sk_redir);
        }
out:
        rcu_read_unlock();
        preempt_enable();
        return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
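/* Note: the msg_parser program returns SK_PASS or SK_DROP and may pick a
 * redirect target (e.g. via bpf_msg_redirect_map()); sk_psock_map_verd()
 * folds that into __SK_PASS/__SK_REDIRECT/__SK_DROP, and on redirect
 * psock->sk_redir holds a socket reference until the next verdict or
 * psock teardown.
 */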
static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
                            struct sk_buff *skb)
{
        int ret;

        skb->sk = psock->sk;
        bpf_compute_data_end_sk_skb(skb);
        preempt_disable();
        ret = BPF_PROG_RUN(prog, skb);
        preempt_enable();
        /* strparser clones the skb before handing it to an upper layer,
         * meaning skb_orphan has been called. We NULL sk on the way out
         * to ensure we don't trigger a BUG_ON() in skb/sk operations
         * later and because we are not charging the memory of this skb
         * to any socket yet.
         */
        skb->sk = NULL;
        return ret;
}
static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
        struct sk_psock_parser *parser;

        parser = container_of(strp, struct sk_psock_parser, strp);
        return container_of(parser, struct sk_psock, parser);
}
static void sk_psock_verdict_apply(struct sk_psock *psock,
                                   struct sk_buff *skb, int verdict)
{
        struct sk_psock *psock_other;
        struct sock *sk_other;
        bool ingress;

        switch (verdict) {
        case __SK_REDIRECT:
                sk_other = tcp_skb_bpf_redirect_fetch(skb);
                if (unlikely(!sk_other))
                        goto out_free;
                psock_other = sk_psock(sk_other);
                if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
                    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
                        goto out_free;
                ingress = tcp_skb_bpf_ingress(skb);
                if ((!ingress && sock_writeable(sk_other)) ||
                    (ingress &&
                     atomic_read(&sk_other->sk_rmem_alloc) <=
                     sk_other->sk_rcvbuf)) {
                        if (!ingress)
                                skb_set_owner_w(skb, sk_other);
                        skb_queue_tail(&psock_other->ingress_skb, skb);
                        schedule_work(&psock_other->work);
                        break;
                }
                /* fall-through */
        case __SK_DROP:
                /* fall-through */
        default:
out_free:
                kfree_skb(skb);
        }
}
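/* Note: a redirect verdict only queues the skb onto the target psock's
 * ingress_skb backlog when the target still looks able to accept it
 * (writeable for egress, receive-buffer headroom for ingress); otherwise,
 * and for __SK_DROP, the skb is freed.
 */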
static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock = sk_psock_from_strp(strp);
        struct bpf_prog *prog;
        int ret = __SK_DROP;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_verdict);
        if (likely(prog)) {
                skb_orphan(skb);
                tcp_skb_bpf_redirect_clear(skb);
                ret = sk_psock_bpf_run(psock, prog, skb);
                ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
        }
        rcu_read_unlock();
        sk_psock_verdict_apply(psock, skb, ret);
}
static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
        return err;
}
static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock = sk_psock_from_strp(strp);
        struct bpf_prog *prog;
        int ret = skb->len;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_parser);
        if (likely(prog))
                ret = sk_psock_bpf_run(psock, prog, skb);
        rcu_read_unlock();

        return ret;
}
/* Called with socket lock held. */
static void sk_psock_data_ready(struct sock *sk)
{
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock)) {
                write_lock_bh(&sk->sk_callback_lock);
                strp_data_ready(&psock->parser.strp);
                write_unlock_bh(&sk->sk_callback_lock);
        }
        rcu_read_unlock();
}
static void sk_psock_write_space(struct sock *sk)
{
        struct sk_psock *psock;
        void (*write_space)(struct sock *sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
                schedule_work(&psock->work);
        write_space = psock->saved_write_space;
        rcu_read_unlock();
        write_space(sk);
}
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
        static const struct strp_callbacks cb = {
                .rcv_msg	= sk_psock_strp_read,
                .read_sock_done	= sk_psock_strp_read_done,
                .parse_msg	= sk_psock_strp_parse,
        };

        psock->parser.enabled = false;
        return strp_init(&psock->parser.strp, sk, &cb);
}
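/* Note: strp_init() leaves the parser disabled; sk_psock_start_strp()
 * below saves the socket's original sk_data_ready and swaps in
 * sk_psock_data_ready()/sk_psock_write_space(), and sk_psock_stop_strp()
 * restores the saved callback and stops the strparser.
 */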
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_parser *parser = &psock->parser;

        if (parser->enabled)
                return;

        parser->saved_data_ready = sk->sk_data_ready;
        sk->sk_data_ready = sk_psock_data_ready;
        sk->sk_write_space = sk_psock_write_space;
        parser->enabled = true;
}
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_parser *parser = &psock->parser;

        if (!parser->enabled)
                return;

        sk->sk_data_ready = parser->saved_data_ready;
        parser->saved_data_ready = NULL;
        strp_stop(&parser->strp);
        parser->enabled = false;
}