// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
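/* This file implements the BPF_MAP_TYPE_SOCKMAP and BPF_MAP_TYPE_SOCKHASH
 * map types. Both hold references to established TCP sockets and carry a
 * set of attached parser/verdict programs (struct sk_psock_progs) that the
 * sk_skb and sk_msg redirect helpers below consult.
 */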
struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};
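/* Only NUMA node selection and rdonly/wronly access flags may be set when
 * creating a sockmap or sockhash; sock_map_alloc() and sock_hash_alloc()
 * reject anything outside this mask.
 */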
#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(-ENOMEM);
}
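/* Userspace view (illustrative sketch only, not part of this file): the map
 * is created through the bpf() syscall, e.g. with a recent libbpf:
 *
 *	int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, "sock_map",
 *				    sizeof(__u32), sizeof(__u32),
 *				    1024, NULL);
 *
 * key_size and value_size must both be 4, matching the checks above.
 */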
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}
int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}
static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}
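/* Every map slot that holds a socket is tracked by an sk_psock_link hanging
 * off the socket's psock. The helpers below add and remove those links so
 * that map teardown and socket close can find and clear each other's state.
 */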
static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}
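/* Remove the link matching @link_raw; if the psock has an active strparser
 * backed by this map's skb_parser program, the strparser is stopped under
 * sk_callback_lock.
 */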
static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}
static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}
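/* sock_map_link() is the common attach step for both map types: it takes
 * references on the map's msg/parser/verdict programs, creates or reuses
 * the socket's sk_psock, installs the programs, and starts the strparser
 * when both skb programs are present.
 */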
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}
static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}
static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}
static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}
static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}
static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}
static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}
static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}
static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}
static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}
static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}
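/* Userspace update path (illustrative sketch, not part of this file): an
 * established TCP socket is inserted by passing its fd as the map value,
 * e.g.
 *
 *	__u32 idx = 0, sock_fd = ...;
 *	bpf_map_update_elem(map_fd, &idx, &sock_fd, BPF_ANY);
 *
 * sockfd_lookup() below resolves that fd back to the kernel socket.
 */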
static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}
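/* BPF program side (illustrative sketch): a BPF_PROG_TYPE_SOCK_OPS program
 * may add its own socket on the active/passive established callbacks, e.g.
 *
 *	bpf_sock_map_update(skops, &sock_map, &key, BPF_NOEXIST);
 *
 * sock_map_op_okay() above restricts which sock_ops callbacks may do so.
 */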
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}
const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
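/* BPF program side (illustrative sketch): an SK_SKB stream verdict program
 * typically ends with
 *
 *	return bpf_sk_redirect_map(skb, &sock_map, idx, 0);
 *
 * The chosen socket is stashed in the skb control block here and consumed
 * by the sk_skb datapath; BPF_F_INGRESS is the only flag accepted.
 */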
BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}
const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}
const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};
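/* The remainder of the file implements BPF_MAP_TYPE_SOCKHASH: the same
 * socket handling as the sockmap above, but keyed by an arbitrary key of up
 * to MAX_BPF_STACK bytes, hashed with jhash into a bucket array.
 */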
struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};
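/* buckets_num is rounded up to a power of two in sock_hash_alloc(), so a
 * bucket can be selected with a simple mask of the jhash value.
 */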
static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}
static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}
static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}
static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}
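/* Lookups run under RCU and elements are freed via kfree_rcu(), so a reader
 * traversing a bucket chain never sees an element's memory reused before
 * the grace period ends.
 */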
static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}
static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}
static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}
static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}
static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}
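/* Iteration order follows the hash buckets, not insertion order: the next
 * key is the following element in the same bucket chain, otherwise the
 * first element of a later bucket.
 */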
static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}
	err = bpf_map_charge_init(&htab->map.memory, cost);
	if (err)
		goto free_htab;

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		bpf_map_charge_finish(&htab->map.memory);
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}
static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us to grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}
static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}
BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}
const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}
const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}
const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};
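/* Userspace attach path (illustrative sketch): parser/verdict programs are
 * attached to the map fd rather than to a socket, e.g. with libbpf:
 *
 *	bpf_prog_attach(prog_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *
 * which reaches sock_map_prog_update() below via sock_map_get_from_fd().
 */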
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}
	return NULL;
}
int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->skb_parser;
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}