// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};
#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
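
/* A sockmap (BPF_MAP_TYPE_SOCKMAP) is an array of sockets indexed by a u32
 * slot. struct bpf_stab above pairs that socket array with the parser/verdict
 * programs attached to the map and a lock protecting slot updates.
 */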
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (!stab->sks) {
		kfree(stab);
		return ERR_PTR(-ENOMEM);
	}

	return &stab->map;
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	bool strp_stop = false, verdict_stop = false;
	struct sk_psock_link *link, *tmp;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			if (psock->parser.enabled && stab->progs.skb_verdict)
				verdict_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop || verdict_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		if (strp_stop)
			sk_psock_stop_strp(sk, psock);
		else
			sk_psock_stop_verdict(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

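/* sock_map_link() takes a reference on each program currently attached to
 * the map (msg_parser, skb_parser, skb_verdict), creates or reuses the
 * socket's psock, switches the socket to the BPF-aware proto ops and, under
 * sk_callback_lock, starts the strparser or verdict-only receive path.
 */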
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	if (skb_verdict) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
	}

	skb_parser = READ_ONCE(progs->skb_parser);
	if (skb_parser) {
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			ret = PTR_ERR(skb_parser);
			goto out_put_skb_verdict;
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out_put_skb_parser;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_parser && READ_ONCE(psock->progs.skb_parser)) ||
		    (skb_verdict && READ_ONCE(psock->progs.skb_verdict))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock)) {
			ret = PTR_ERR(psock);
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_parser && skb_verdict && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret)
			goto out_unlock_drop;
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	} else if (!skb_parser && skb_verdict && !psock->parser.enabled) {
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		sk_psock_start_verdict(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_unlock_drop:
	write_unlock_bh(&sk->sk_callback_lock);
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out_put_skb_parser:
	if (skb_parser)
		bpf_prog_put(skb_parser);
out_put_skb_verdict:
	if (skb_verdict)
		bpf_prog_put(skb_verdict);
	return ret;
}

static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock))
			return PTR_ERR(psock);
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;

	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk);

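/* Slow path shared by the syscall and BPF-side update helpers: link the
 * socket into the map under stab->lock, honouring the BPF_NOEXIST/BPF_EXIST
 * flags and dropping the reference on any socket being replaced.
 */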
static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

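/* Small predicates deciding which sockets may be added to a map, which may
 * serve as redirect targets, and in which states an update is allowed.
 */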
static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags);

int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
			     u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	sockfd_put(sock);
	return ret;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	struct sock *sk = (struct sock *)value;
	int ret;

	if (unlikely(!sk || !sk_fullsock(sk)))
		return -EINVAL;

	if (!sock_map_sk_is_suitable(sk))
		return -EOPNOTSUPP;

	local_bh_disable();
	bh_lock_sock(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	bh_unlock_sock(sk);
	local_bh_enable();
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

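/* bpf_sk_redirect_map() is called from sk_skb verdict programs; it stashes
 * the looked-up socket and the BPF_F_INGRESS flag in the skb's control
 * block for the sk_skb datapath to act on after the program returns.
 */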
BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_map_seq_info {
	struct bpf_map *map;
	struct sock *sk;
	u32 index;
};

struct bpf_iter__sockmap {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct bpf_map *, map);
	__bpf_md_ptr(void *, key);
	__bpf_md_ptr(struct sock *, sk);
};

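/* Context passed to bpf_iter programs walking a sockmap or sockhash; key
 * and sk are left NULL on the final (post-iteration) invocation.
 */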
DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
		     struct bpf_map *map, void *key,
		     struct sock *sk)

static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
{
	if (unlikely(info->index >= info->map->max_entries))
		return NULL;

	info->sk = __sock_map_lookup_elem(info->map, info->index);

	/* can't return sk directly, since that might be NULL */
	return info;
}

static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_map_seq_stop */
	rcu_read_lock();
	return sock_map_seq_lookup_elem(info);
}

static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	++*pos;
	++info->index;

	return sock_map_seq_lookup_elem(info);
}

static int sock_map_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !v);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (v) {
		ctx.key = &info->index;
		ctx.sk = info->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_map_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_map_seq_show(seq, NULL);

	/* pairs with sock_map_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_map_seq_ops = {
	.start	= sock_map_seq_start,
	.next	= sock_map_seq_next,
	.stop	= sock_map_seq_stop,
	.show	= sock_map_seq_show,
};

static int sock_map_init_seq_private(void *priv_data,
				     struct bpf_iter_aux_info *aux)
{
	struct sock_map_seq_info *info = priv_data;

	info->map = aux->map;
	return 0;
}

static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
	.seq_ops		= &sock_map_seq_ops,
	.init_seq_private	= sock_map_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_map_seq_info),
};

static int sock_map_btf_id;
const struct bpf_map_ops sock_map_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_stab",
	.map_btf_id		= &sock_map_btf_id,
	.iter_seq_info		= &sock_map_iter_seq_info,
};

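/* BPF_MAP_TYPE_SOCKHASH: like sockmap, but keyed by an arbitrary
 * user-defined key and backed by a bucketed, RCU-protected hash table.
 */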
struct bpf_shtab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_shtab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_shtab {
	struct bpf_map map;
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_shtab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

*__sock_hash_lookup_elem(struct bpf_map
*map
, void *key
)
886 struct bpf_shtab
*htab
= container_of(map
, struct bpf_shtab
, map
);
887 u32 key_size
= map
->key_size
, hash
;
888 struct bpf_shtab_bucket
*bucket
;
889 struct bpf_shtab_elem
*elem
;
891 WARN_ON_ONCE(!rcu_read_lock_held());
893 hash
= sock_hash_bucket_hash(key
, key_size
);
894 bucket
= sock_hash_select_bucket(htab
, hash
);
895 elem
= sock_hash_lookup_elem_raw(&bucket
->head
, hash
, key
, key_size
);
897 return elem
? elem
->sk
: NULL
;
static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
						   void *key, u32 key_size,
						   u32 hash, struct sock *sk,
						   struct bpf_shtab_elem *old)
{
	struct bpf_shtab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
				   GFP_ATOMIC | __GFP_NOWARN,
				   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

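/* Insert or replace a socket under the bucket lock; an existing element is
 * unlinked only after the new one is visible to concurrent RCU readers.
 */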
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_elem *elem, *elem_new;
	struct bpf_shtab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
				     struct bpf_shtab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
					     struct bpf_shtab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_shtab *htab;
	int i, err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_shtab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_shtab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us to grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_hash_seq_info {
	struct bpf_map *map;
	struct bpf_shtab *htab;
	u32 bucket_id;
};

static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
				     struct bpf_shtab_elem *prev_elem)
{
	const struct bpf_shtab *htab = info->htab;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;

	/* try to find next elem in the same bucket */
	if (prev_elem) {
		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;

		/* no more elements, continue in the next bucket */
		info->bucket_id++;
	}

	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
		bucket = &htab->buckets[info->bucket_id];
		node = rcu_dereference(hlist_first_rcu(&bucket->head));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;
	}

	return NULL;
}

static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_hash_seq_stop */
	rcu_read_lock();
	return sock_hash_seq_find_next(info, NULL);
}

static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	++*pos;
	return sock_hash_seq_find_next(info, v);
}

static int sock_hash_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_shtab_elem *elem = v;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !elem);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (elem) {
		ctx.key = elem->key;
		ctx.sk = elem->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_hash_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_hash_seq_show(seq, NULL);

	/* pairs with sock_hash_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_hash_seq_ops = {
	.start	= sock_hash_seq_start,
	.next	= sock_hash_seq_next,
	.stop	= sock_hash_seq_stop,
	.show	= sock_hash_seq_show,
};

static int sock_hash_init_seq_private(void *priv_data,
				      struct bpf_iter_aux_info *aux)
{
	struct sock_hash_seq_info *info = priv_data;

	info->map = aux->map;
	info->htab = container_of(aux->map, struct bpf_shtab, map);
	return 0;
}

static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
	.seq_ops		= &sock_hash_seq_ops,
	.init_seq_private	= sock_hash_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_hash_seq_info),
};

static int sock_hash_map_btf_id;
const struct bpf_map_ops sock_hash_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_shtab",
	.map_btf_id		= &sock_hash_map_btf_id,
	.iter_seq_info		= &sock_hash_iter_seq_info,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_shtab, map)->progs;
	default:
		break;
	}

	return NULL;
}

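/* Attach, replace or detach a program on a map, selecting the program slot
 * from the attach type (BPF_SK_MSG_VERDICT, BPF_SK_SKB_STREAM_PARSER,
 * BPF_SK_SKB_STREAM_VERDICT).
 */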
int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->skb_parser;
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

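/* Proto callbacks installed on sockets that live in a map: drop all map
 * links first, then fall through to the callbacks saved in the psock.
 */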
void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

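/* bpf_iter plumbing: an iterator target that attaches to a sockmap or
 * sockhash via its map fd and walks it with the seq_ops defined above.
 */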
static int sock_map_iter_attach_target(struct bpf_prog *prog,
				       union bpf_iter_link_info *linfo,
				       struct bpf_iter_aux_info *aux)
{
	struct bpf_map *map;
	int err = -EINVAL;

	if (!linfo->map.map_fd)
		return -EBADF;

	map = bpf_map_get_with_uref(linfo->map.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
		goto put_map;

	if (prog->aux->max_rdonly_access > map->key_size) {
		err = -EACCES;
		goto put_map;
	}

	aux->map = map;
	return 0;

put_map:
	bpf_map_put_with_uref(map);
	return err;
}

static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
{
	bpf_map_put_with_uref(aux->map);
}

static struct bpf_iter_reg sock_map_iter_reg = {
	.target			= "sockmap",
	.attach_target		= sock_map_iter_attach_target,
	.detach_target		= sock_map_iter_detach_target,
	.show_fdinfo		= bpf_iter_map_show_fdinfo,
	.fill_link_info		= bpf_iter_map_fill_link_info,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__sockmap, key),
		  PTR_TO_RDONLY_BUF_OR_NULL },
		{ offsetof(struct bpf_iter__sockmap, sk),
		  PTR_TO_BTF_ID_OR_NULL },
	},
	.seq_info		= &sock_map_iter_seq_info,
};

static int __init bpf_sockmap_iter_init(void)
{
	sock_map_iter_reg.ctx_arg_info[1].btf_id =
		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
	return bpf_iter_reg_target(&sock_map_iter_reg);
}
late_initcall(bpf_sockmap_iter_init);