// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
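/* A BPF_MAP_TYPE_SOCKMAP is backed by struct bpf_stab below: a flat,
 * spinlock-protected array of socket pointers plus the set of psock
 * programs attached to the map.
 */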
struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	spinlock_t lock;
};
#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
/* This mutex is used to
 *  - protect race between prog/link attach/detach and link prog update, and
 *  - protect race between releasing and accessing map in bpf_link.
 * A single global mutex is used since contention is expected to be low.
 */
static DEFINE_MUTEX(sockmap_mutex);
static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, struct bpf_link *link,
				u32 which);
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;

	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = bpf_map_area_alloc(sizeof(*stab), NUMA_NO_NODE);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	spin_lock_init(&stab->lock);

	stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (!stab->sks) {
		bpf_map_area_free(stab);
		return ERR_PTR(-ENOMEM);
	}

	return &stab->map;
}
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	struct bpf_map *map;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	CLASS(fd, f)(attr->target_fd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	mutex_lock(&sockmap_mutex);
	ret = sock_map_prog_update(map, prog, NULL, NULL, attr->attach_type);
	mutex_unlock(&sockmap_mutex);
	return ret;
}
int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	struct bpf_prog *prog;
	struct bpf_map *map;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	CLASS(fd, f)(attr->target_fd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	mutex_lock(&sockmap_mutex);
	ret = sock_map_prog_update(map, NULL, prog, NULL, attr->attach_type);
	mutex_unlock(&sockmap_mutex);
put_prog:
	bpf_prog_put(prog);
	return ret;
}
static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	release_sock(sk);
}
static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}
static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	bool strp_stop = false, verdict_stop = false;
	struct sk_psock_link *link, *tmp;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct sk_psock_progs *progs = sock_map_progs(map);

			if (psock->saved_data_ready && progs->stream_parser)
				strp_stop = true;
			if (psock->saved_data_ready && progs->stream_verdict)
				verdict_stop = true;
			if (psock->saved_data_ready && progs->skb_verdict)
				verdict_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop || verdict_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		if (strp_stop)
			sk_psock_stop_strp(sk, psock);
		if (verdict_stop)
			sk_psock_stop_verdict(sk, psock);

		if (psock->psock_update_sk_prot)
			psock->psock_update_sk_prot(sk, psock, false);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}
static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (psock) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}
static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	if (!sk->sk_prot->psock_update_sk_prot)
		return -EINVAL;
	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
}
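/* Look up the socket's existing psock and take a reference on it.
 * Returns ERR_PTR(-EBUSY) when the psock is owned by another user
 * (the socket's close callback is not sock_map_close) or when its
 * refcount has already dropped to zero.
 */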
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}
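/* Attach the map's current programs to @sk: grab references on the
 * programs, create the socket's psock if it does not exist yet, copy
 * the programs into it, and switch the socket over to the psock
 * protocol callbacks (strparser or verdict data_ready handlers).
 */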
static int sock_map_link(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog *stream_verdict = NULL;
	struct bpf_prog *stream_parser = NULL;
	struct bpf_prog *skb_verdict = NULL;
	struct bpf_prog *msg_parser = NULL;
	struct sk_psock *psock;
	int ret;

	stream_verdict = READ_ONCE(progs->stream_verdict);
	if (stream_verdict) {
		stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
		if (IS_ERR(stream_verdict))
			return PTR_ERR(stream_verdict);
	}

	stream_parser = READ_ONCE(progs->stream_parser);
	if (stream_parser) {
		stream_parser = bpf_prog_inc_not_zero(stream_parser);
		if (IS_ERR(stream_parser)) {
			ret = PTR_ERR(stream_parser);
			goto out_put_stream_verdict;
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out_put_stream_parser;
		}
	}

	skb_verdict = READ_ONCE(progs->skb_verdict);
	if (skb_verdict) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict)) {
			ret = PTR_ERR(skb_verdict);
			goto out_put_msg_parser;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (stream_parser && READ_ONCE(psock->progs.stream_parser)) ||
		    (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
		    (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
		    (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
		    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock)) {
			ret = PTR_ERR(psock);
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (stream_parser)
		psock_set_prog(&psock->progs.stream_parser, stream_parser);
	if (stream_verdict)
		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
	if (skb_verdict)
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);

	/* msg_* and stream_* program references are tracked in psock after
	 * this point. Reference dec and cleanup will occur through the psock
	 * destructor.
	 */
	ret = sock_map_init_proto(sk, psock);
	if (ret < 0) {
		sk_psock_put(sk, psock);
		goto out;
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (stream_parser && stream_verdict && !psock->saved_data_ready) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			sk_psock_put(sk, psock);
			goto out;
		}
		sk_psock_start_strp(sk, psock);
	} else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
		sk_psock_start_verdict(sk, psock);
	} else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
		sk_psock_start_verdict(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_progs:
	if (skb_verdict)
		bpf_prog_put(skb_verdict);
out_put_msg_parser:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out_put_stream_parser:
	if (stream_parser)
		bpf_prog_put(stream_parser);
out_put_stream_verdict:
	if (stream_verdict)
		bpf_prog_put(stream_verdict);
out:
	return ret;
}
static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			sock_hold(sk);
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
			sock_put(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	bpf_map_area_free(stab);
}
static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}
static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}
static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}
static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}
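/* Clear the slot @psk and drop the map's reference on the socket it
 * held. When @sk_test is non-NULL, the slot is only cleared if it still
 * contains that exact socket (delete-from-link path).
 */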
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	spin_unlock_bh(&stab->lock);
	return err;
}
static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}
static long sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}
static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}
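/* Insert @sk at index @idx, honouring BPF_NOEXIST/BPF_EXIST semantics.
 * The socket is tied to the map slot via a psock link so the entry can
 * be removed again when the socket is unhashed, destroyed or closed.
 */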
static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}
static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}
static bool sock_map_redirect_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return sk->sk_state != TCP_LISTEN;
	else
		return sk->sk_state == TCP_ESTABLISHED;
}
static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return !!sk->sk_prot->psock_update_sk_prot;
}
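/* Restrict map updates to sane socket states: established or listening
 * for TCP, established for stream AF_UNIX; other socket families are
 * not state-checked here.
 */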
static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	if (sk_is_stream_unix(sk))
		return (1 << sk->sk_state) & TCPF_ESTABLISHED;
	return true;
}
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags);
int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
			     u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	sockfd_put(sock);
	return ret;
}
static long sock_map_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	struct sock *sk = (struct sock *)value;
	int ret;

	if (unlikely(!sk || !sk_fullsock(sk)))
		return -EINVAL;

	if (!sock_map_sk_is_suitable(sk))
		return -EOPNOTSUPP;

	local_bh_disable();
	bh_lock_sock(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	bh_unlock_sock(sk);
	local_bh_enable();
	return ret;
}
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}
const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;
	if ((flags & BPF_F_INGRESS) && sk_is_vsock(sk))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}
const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;
	if (!(flags & BPF_F_INGRESS) && !sk_is_tcp(sk))
		return SK_DROP;
	if (sk_is_vsock(sk))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}
const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
struct sock_map_seq_info {
	struct bpf_map *map;
	struct sock *sk;
	u32 index;
};

struct bpf_iter__sockmap {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct bpf_map *, map);
	__bpf_md_ptr(void *, key);
	__bpf_md_ptr(struct sock *, sk);
};
DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
		     struct bpf_map *map, void *key,
		     struct sock *sk)
static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
{
	if (unlikely(info->index >= info->map->max_entries))
		return NULL;

	info->sk = __sock_map_lookup_elem(info->map, info->index);

	/* can't return sk directly, since that might be NULL */
	return info;
}
static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_map_seq_stop */
	rcu_read_lock();
	return sock_map_seq_lookup_elem(info);
}
static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	++*pos;
	++info->index;

	return sock_map_seq_lookup_elem(info);
}
static int sock_map_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !v);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (v) {
		ctx.key = &info->index;
		ctx.sk = info->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}
static void sock_map_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_map_seq_show(seq, NULL);

	/* pairs with sock_map_seq_start */
	rcu_read_unlock();
}
static const struct seq_operations sock_map_seq_ops = {
	.start	= sock_map_seq_start,
	.next	= sock_map_seq_next,
	.stop	= sock_map_seq_stop,
	.show	= sock_map_seq_show,
};
static int sock_map_init_seq_private(void *priv_data,
				     struct bpf_iter_aux_info *aux)
{
	struct sock_map_seq_info *info = priv_data;

	bpf_map_inc_with_uref(aux->map);
	info->map = aux->map;
	return 0;
}
static void sock_map_fini_seq_private(void *priv_data)
{
	struct sock_map_seq_info *info = priv_data;

	bpf_map_put_with_uref(info->map);
}
static u64 sock_map_mem_usage(const struct bpf_map *map)
{
	u64 usage = sizeof(struct bpf_stab);

	usage += (u64)map->max_entries * sizeof(struct sock *);
	return usage;
}
static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
	.seq_ops		= &sock_map_seq_ops,
	.init_seq_private	= sock_map_init_seq_private,
	.fini_seq_private	= sock_map_fini_seq_private,
	.seq_priv_size		= sizeof(struct sock_map_seq_info),
};
BTF_ID_LIST_SINGLE(sock_map_btf_ids, struct, bpf_stab)
const struct bpf_map_ops sock_map_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_mem_usage		= sock_map_mem_usage,
	.map_btf_id		= &sock_map_btf_ids[0],
	.iter_seq_info		= &sock_map_iter_seq_info,
};
struct bpf_shtab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_shtab_bucket {
	struct hlist_head head;
	spinlock_t lock;
};

struct bpf_shtab {
	struct bpf_map map;
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};
static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}
static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							 u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}
static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_shtab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}
static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}
static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}
static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	spin_unlock_bh(&bucket->lock);
}
static long sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	spin_unlock_bh(&bucket->lock);
	return ret;
}
static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
						   void *key, u32 key_size,
						   u32 hash, struct sock *sk,
						   struct bpf_shtab_elem *old)
{
	struct bpf_shtab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
				   GFP_ATOMIC | __GFP_NOWARN,
				   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}
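/* Insert @sk under @key, honouring BPF_NOEXIST/BPF_EXIST. The new element
 * is published at the head of its bucket before the old one (if any) is
 * unlinked, so concurrent RCU lookups always find a valid entry.
 */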
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_elem *elem, *elem_new;
	struct bpf_shtab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}
static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
				     struct bpf_shtab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
					     struct bpf_shtab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_shtab *htab;
	int i, err;

	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = bpf_map_area_alloc(sizeof(*htab), NUMA_NO_NODE);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_shtab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_shtab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	bpf_map_area_free(htab);
	return ERR_PTR(err);
}
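/* Tear down every element: sockets are held and unlinked from each
 * bucket under the bucket lock, then processed outside of it so the
 * socket lock can be taken while removing the psock link.
 */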
static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us to grab a socket ref too.
		 */
		spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
		cond_resched();
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	bpf_map_area_free(htab);
}
static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}
static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}
static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}
BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}
const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;
	if ((flags & BPF_F_INGRESS) && sk_is_vsock(sk))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}
const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;
	if (!(flags & BPF_F_INGRESS) && !sk_is_tcp(sk))
		return SK_DROP;
	if (sk_is_vsock(sk))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}
const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
struct sock_hash_seq_info {
	struct bpf_map *map;
	struct bpf_shtab *htab;
	u32 bucket_id;
};
static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
				     struct bpf_shtab_elem *prev_elem)
{
	const struct bpf_shtab *htab = info->htab;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;

	/* try to find next elem in the same bucket */
	if (prev_elem) {
		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;

		/* no more elements, continue in the next bucket */
		info->bucket_id++;
	}

	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
		bucket = &htab->buckets[info->bucket_id];
		node = rcu_dereference(hlist_first_rcu(&bucket->head));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;
	}

	return NULL;
}
static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_hash_seq_stop */
	rcu_read_lock();
	return sock_hash_seq_find_next(info, NULL);
}
static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	++*pos;
	return sock_hash_seq_find_next(info, v);
}
static int sock_hash_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_shtab_elem *elem = v;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !elem);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (elem) {
		ctx.key = elem->key;
		ctx.sk = elem->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}
static void sock_hash_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_hash_seq_show(seq, NULL);

	/* pairs with sock_hash_seq_start */
	rcu_read_unlock();
}
static const struct seq_operations sock_hash_seq_ops = {
	.start	= sock_hash_seq_start,
	.next	= sock_hash_seq_next,
	.stop	= sock_hash_seq_stop,
	.show	= sock_hash_seq_show,
};
static int sock_hash_init_seq_private(void *priv_data,
				      struct bpf_iter_aux_info *aux)
{
	struct sock_hash_seq_info *info = priv_data;

	bpf_map_inc_with_uref(aux->map);
	info->map = aux->map;
	info->htab = container_of(aux->map, struct bpf_shtab, map);
	return 0;
}
static void sock_hash_fini_seq_private(void *priv_data)
{
	struct sock_hash_seq_info *info = priv_data;

	bpf_map_put_with_uref(info->map);
}
static u64 sock_hash_mem_usage(const struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u64 usage = sizeof(*htab);

	usage += htab->buckets_num * sizeof(struct bpf_shtab_bucket);
	usage += atomic_read(&htab->count) * (u64)htab->elem_size;
	return usage;
}
static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
	.seq_ops		= &sock_hash_seq_ops,
	.init_seq_private	= sock_hash_init_seq_private,
	.fini_seq_private	= sock_hash_fini_seq_private,
	.seq_priv_size		= sizeof(struct sock_hash_seq_info),
};
BTF_ID_LIST_SINGLE(sock_hash_map_btf_ids, struct, bpf_shtab)
const struct bpf_map_ops sock_hash_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_mem_usage		= sock_hash_mem_usage,
	.map_btf_id		= &sock_hash_map_btf_ids[0],
	.iter_seq_info		= &sock_hash_iter_seq_info,
};
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_shtab, map)->progs;
	default:
		break;
	}

	return NULL;
}
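/* Map an attach type to the map's program and link slots. Stream and
 * skb verdict programs are mutually exclusive on a map.
 */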
static int sock_map_prog_link_lookup(struct bpf_map *map, struct bpf_prog ***pprog,
				     struct bpf_link ***plink, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **cur_pprog;
	struct bpf_link **cur_plink;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		cur_pprog = &progs->msg_parser;
		cur_plink = &progs->msg_parser_link;
		break;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	case BPF_SK_SKB_STREAM_PARSER:
		cur_pprog = &progs->stream_parser;
		cur_plink = &progs->stream_parser_link;
		break;
#endif
	case BPF_SK_SKB_STREAM_VERDICT:
		if (progs->skb_verdict)
			return -EBUSY;
		cur_pprog = &progs->stream_verdict;
		cur_plink = &progs->stream_verdict_link;
		break;
	case BPF_SK_SKB_VERDICT:
		if (progs->stream_verdict)
			return -EBUSY;
		cur_pprog = &progs->skb_verdict;
		cur_plink = &progs->skb_verdict_link;
		break;
	default:
		return -EOPNOTSUPP;
	}

	*pprog = cur_pprog;
	if (plink)
		*plink = cur_plink;
	return 0;
}
/* Handle the following four cases:
 * prog_attach: prog != NULL, old == NULL, link == NULL
 * prog_detach: prog == NULL, old != NULL, link == NULL
 * link_attach: prog != NULL, old == NULL, link != NULL
 * link_detach: prog == NULL, old != NULL, link != NULL
 */
static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, struct bpf_link *link,
				u32 which)
{
	struct bpf_prog **pprog;
	struct bpf_link **plink;
	int ret;

	ret = sock_map_prog_link_lookup(map, &pprog, &plink, which);
	if (ret)
		return ret;

	/* for prog_attach/prog_detach/link_attach, return error if a bpf_link
	 * exists for that prog.
	 */
	if ((!link || prog) && *plink)
		return -EBUSY;

	if (old) {
		ret = psock_replace_prog(pprog, prog, old);
		if (!ret)
			*plink = NULL;
	} else {
		psock_set_prog(pprog, prog);
		if (link)
			*plink = link;
	}

	return ret;
}
int sock_map_bpf_prog_query(const union bpf_attr *attr,
			    union bpf_attr __user *uattr)
{
	__u32 __user *prog_ids = u64_to_user_ptr(attr->query.prog_ids);
	u32 prog_cnt = 0, flags = 0;
	struct bpf_prog **pprog;
	struct bpf_prog *prog;
	struct bpf_map *map;
	u32 id = 0;
	int ret;

	if (attr->query.query_flags)
		return -EINVAL;

	CLASS(fd, f)(attr->target_fd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	rcu_read_lock();

	ret = sock_map_prog_link_lookup(map, &pprog, NULL, attr->query.attach_type);
	if (ret)
		goto end;

	prog = *pprog;
	prog_cnt = !prog ? 0 : 1;

	if (!attr->query.prog_cnt || !prog_ids || !prog_cnt)
		goto end;

	/* we do not hold the refcnt, the bpf prog may be released
	 * asynchronously and the id would be set to 0.
	 */
	id = data_race(prog->aux->id);
	if (id == 0)
		prog_cnt = 0;

end:
	rcu_read_unlock();

	if (copy_to_user(&uattr->query.attach_flags, &flags, sizeof(flags)) ||
	    (id != 0 && copy_to_user(prog_ids, &id, sizeof(u32))) ||
	    copy_to_user(&uattr->query.prog_cnt, &prog_cnt, sizeof(prog_cnt)))
		ret = -EFAULT;

	return ret;
}
static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}
static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}
void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		saved_unhash = READ_ONCE(sk->sk_prot)->unhash;
	} else {
		saved_unhash = psock->saved_unhash;
		sock_map_remove_links(sk, psock);
		rcu_read_unlock();
	}
	if (WARN_ON_ONCE(saved_unhash == sock_map_unhash))
		return;
	if (saved_unhash)
		saved_unhash(sk);
}
EXPORT_SYMBOL_GPL(sock_map_unhash);
void sock_map_destroy(struct sock *sk)
{
	void (*saved_destroy)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock_get(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		saved_destroy = READ_ONCE(sk->sk_prot)->destroy;
	} else {
		saved_destroy = psock->saved_destroy;
		sock_map_remove_links(sk, psock);
		rcu_read_unlock();
		sk_psock_stop(psock);
		sk_psock_put(sk, psock);
	}
	if (WARN_ON_ONCE(saved_destroy == sock_map_destroy))
		return;
	if (saved_destroy)
		saved_destroy(sk);
}
EXPORT_SYMBOL_GPL(sock_map_destroy);
void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		saved_close = psock->saved_close;
		sock_map_remove_links(sk, psock);
		psock = sk_psock_get(sk);
		if (unlikely(!psock))
			goto no_psock;
		rcu_read_unlock();
		sk_psock_stop(psock);
		release_sock(sk);
		cancel_delayed_work_sync(&psock->work);
		sk_psock_put(sk, psock);
	} else {
		saved_close = READ_ONCE(sk->sk_prot)->close;
no_psock:
		rcu_read_unlock();
		release_sock(sk);
	}

	/* Make sure we do not recurse. This is a bug.
	 * Leak the socket instead of crashing on a stack overflow.
	 */
	if (WARN_ON_ONCE(saved_close == sock_map_close))
		return;
	saved_close(sk, timeout);
}
EXPORT_SYMBOL_GPL(sock_map_close);
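/* bpf_link support: a sockmap link pins the target map (with uref) and
 * remembers the attach type so the program can later be updated or
 * detached through the link.
 */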
struct sockmap_link {
	struct bpf_link link;
	struct bpf_map *map;
	enum bpf_attach_type attach_type;
};
static void sock_map_link_release(struct bpf_link *link)
{
	struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);

	mutex_lock(&sockmap_mutex);
	if (!sockmap_link->map)
		goto out;

	WARN_ON_ONCE(sock_map_prog_update(sockmap_link->map, NULL, link->prog, link,
					  sockmap_link->attach_type));

	bpf_map_put_with_uref(sockmap_link->map);
	sockmap_link->map = NULL;
out:
	mutex_unlock(&sockmap_mutex);
}
static int sock_map_link_detach(struct bpf_link *link)
{
	sock_map_link_release(link);
	return 0;
}
static void sock_map_link_dealloc(struct bpf_link *link)
{
	kfree(link);
}
/* Handle the following two cases:
 * case 1: link != NULL, prog != NULL, old != NULL
 * case 2: link != NULL, prog != NULL, old == NULL
 */
static int sock_map_link_update_prog(struct bpf_link *link,
				     struct bpf_prog *prog,
				     struct bpf_prog *old)
{
	const struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);
	struct bpf_prog **pprog, *old_link_prog;
	struct bpf_link **plink;
	int ret = 0;

	mutex_lock(&sockmap_mutex);

	/* If old prog is not NULL, ensure old prog is the same as link->prog. */
	if (old && link->prog != old) {
		ret = -EPERM;
		goto out;
	}
	/* Ensure link->prog has the same type/attach_type as the new prog. */
	if (link->prog->type != prog->type ||
	    link->prog->expected_attach_type != prog->expected_attach_type) {
		ret = -EINVAL;
		goto out;
	}

	ret = sock_map_prog_link_lookup(sockmap_link->map, &pprog, &plink,
					sockmap_link->attach_type);
	if (ret)
		goto out;

	/* return error if the stored bpf_link does not match the incoming bpf_link. */
	if (link != *plink) {
		ret = -EBUSY;
		goto out;
	}

	if (old) {
		ret = psock_replace_prog(pprog, prog, old);
		if (ret)
			goto out;
	} else {
		psock_set_prog(pprog, prog);
	}

	bpf_prog_inc(prog);
	old_link_prog = xchg(&link->prog, prog);
	bpf_prog_put(old_link_prog);

out:
	mutex_unlock(&sockmap_mutex);
	return ret;
}
1788 mutex_unlock(&sockmap_mutex
);
1792 static u32
sock_map_link_get_map_id(const struct sockmap_link
*sockmap_link
)
1796 mutex_lock(&sockmap_mutex
);
1797 if (sockmap_link
->map
)
1798 map_id
= sockmap_link
->map
->id
;
1799 mutex_unlock(&sockmap_mutex
);
static int sock_map_link_fill_info(const struct bpf_link *link,
				   struct bpf_link_info *info)
{
	const struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);
	u32 map_id = sock_map_link_get_map_id(sockmap_link);

	info->sockmap.map_id = map_id;
	info->sockmap.attach_type = sockmap_link->attach_type;
	return 0;
}
static void sock_map_link_show_fdinfo(const struct bpf_link *link,
				      struct seq_file *seq)
{
	const struct sockmap_link *sockmap_link = container_of(link, struct sockmap_link, link);
	u32 map_id = sock_map_link_get_map_id(sockmap_link);

	seq_printf(seq, "map_id:\t%u\n", map_id);
	seq_printf(seq, "attach_type:\t%u\n", sockmap_link->attach_type);
}
1824 static const struct bpf_link_ops sock_map_link_ops
= {
1825 .release
= sock_map_link_release
,
1826 .dealloc
= sock_map_link_dealloc
,
1827 .detach
= sock_map_link_detach
,
1828 .update_prog
= sock_map_link_update_prog
,
1829 .fill_link_info
= sock_map_link_fill_info
,
1830 .show_fdinfo
= sock_map_link_show_fdinfo
,
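/* Create a bpf_link attaching @prog to the map given by
 * attr->link_create.target_fd at attr->link_create.attach_type.
 */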
int sock_map_link_create(const union bpf_attr *attr, struct bpf_prog *prog)
{
	struct bpf_link_primer link_primer;
	struct sockmap_link *sockmap_link;
	enum bpf_attach_type attach_type;
	struct bpf_map *map;
	int ret;

	if (attr->link_create.flags)
		return -EINVAL;

	map = bpf_map_get_with_uref(attr->link_create.target_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);
	if (map->map_type != BPF_MAP_TYPE_SOCKMAP && map->map_type != BPF_MAP_TYPE_SOCKHASH) {
		ret = -EINVAL;
		goto out;
	}

	sockmap_link = kzalloc(sizeof(*sockmap_link), GFP_USER);
	if (!sockmap_link) {
		ret = -ENOMEM;
		goto out;
	}

	attach_type = attr->link_create.attach_type;
	bpf_link_init(&sockmap_link->link, BPF_LINK_TYPE_SOCKMAP, &sock_map_link_ops, prog);
	sockmap_link->map = map;
	sockmap_link->attach_type = attach_type;

	ret = bpf_link_prime(&sockmap_link->link, &link_primer);
	if (ret) {
		kfree(sockmap_link);
		goto out;
	}

	mutex_lock(&sockmap_mutex);
	ret = sock_map_prog_update(map, prog, NULL, &sockmap_link->link, attach_type);
	mutex_unlock(&sockmap_mutex);
	if (ret) {
		bpf_link_cleanup(&link_primer);
		goto out;
	}

	/* Increase the refcnt for the prog since psock_replace_prog() and
	 * psock_set_prog() drop a reference when the old prog is replaced.
	 *
	 * Actually, we do not need to increase the refcnt for the prog since
	 * bpf_link will hold a reference. But in order to have less complexity
	 * w.r.t. replacing/setting prog, let us increase the refcnt to make
	 * things simpler.
	 */
	bpf_prog_inc(prog);

	return bpf_link_settle(&link_primer);

out:
	bpf_map_put_with_uref(map);
	return ret;
}
static int sock_map_iter_attach_target(struct bpf_prog *prog,
				       union bpf_iter_link_info *linfo,
				       struct bpf_iter_aux_info *aux)
{
	struct bpf_map *map;
	int err = -EINVAL;

	if (!linfo->map.map_fd)
		return -EBADF;

	map = bpf_map_get_with_uref(linfo->map.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
		goto put_map;

	if (prog->aux->max_rdonly_access > map->key_size) {
		err = -EACCES;
		goto put_map;
	}

	aux->map = map;
	return 0;

put_map:
	bpf_map_put_with_uref(map);
	return err;
}

static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
{
	bpf_map_put_with_uref(aux->map);
}
static struct bpf_iter_reg sock_map_iter_reg = {
	.target			= "sockmap",
	.attach_target		= sock_map_iter_attach_target,
	.detach_target		= sock_map_iter_detach_target,
	.show_fdinfo		= bpf_iter_map_show_fdinfo,
	.fill_link_info		= bpf_iter_map_fill_link_info,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__sockmap, key),
		  PTR_TO_BUF | PTR_MAYBE_NULL | MEM_RDONLY },
		{ offsetof(struct bpf_iter__sockmap, sk),
		  PTR_TO_BTF_ID_OR_NULL },
	},
	.seq_info		= &sock_map_iter_seq_info,
};
static int __init bpf_sockmap_iter_init(void)
{
	sock_map_iter_reg.ctx_arg_info[1].btf_id =
		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
	return bpf_iter_reg_target(&sock_map_iter_reg);
}
late_initcall(bpf_sockmap_iter_init);