// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "allowedips.h"
#include "peer.h"

enum { MAX_ALLOWEDIPS_DEPTH = 129 };

static struct kmem_cache *node_cache;

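/* Convert an address from big-endian wire order into native u32/u64 words,
 * so the result can be fed to fls()/fls64() and compared with plain integer
 * operations. Illustrative example (not from this file): the IPv4 address
 * 10.0.0.1 arrives as the bytes 0a 00 00 01 and is stored as the host-order
 * value 0x0a000001.
 */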
static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
	if (bits == 32) {
		*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
	} else if (bits == 128) {
		((u64 *)dst)[0] = get_unaligned_be64(src);
		((u64 *)dst)[1] = get_unaligned_be64(src + 8);
	}
}

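/* Record the prefix and precompute where its first undetermined bit lives:
 * bit_at_a is the byte index (XOR-adjusted on little-endian, since
 * swap_endian stored the address in native word order) and bit_at_b the bit
 * within that byte. choose() uses the pair to pick a child branch.
 */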
static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
				 u8 cidr, u8 bits)
{
	node->cidr = cidr;
	node->bit_at_a = cidr / 8U;
#ifdef __LITTLE_ENDIAN
	node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
	node->bit_at_b = 7U - (cidr % 8U);
	node->bitlen = bits;
	memcpy(node->bits, src, bits / 8U);
}

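/* Return the child index (0 or 1) given by key's bit at this node's cidr
 * position. Illustrative example: for a node with cidr 24 on an IPv4 key,
 * this reads the most significant bit of the address's fourth octet.
 */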
static inline u8 choose(struct allowedips_node *node, const u8 *key)
{
	return (key[node->bit_at_a] >> node->bit_at_b) & 1;
}

static void push_rcu(struct allowedips_node **stack,
		     struct allowedips_node __rcu *p, unsigned int *len)
{
	if (rcu_access_pointer(p)) {
		if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_DEPTH))
			return;
		stack[(*len)++] = rcu_dereference_raw(p);
	}
}

static void node_free_rcu(struct rcu_head *rcu)
{
	kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu));
}

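/* Free a whole trie from RCU callback context, iteratively with an explicit
 * stack rather than by recursion; MAX_ALLOWEDIPS_DEPTH bounds the stack at
 * one node per possible cidr value, 0 through 128.
 */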
static void root_free_rcu(struct rcu_head *rcu)
{
	struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = {
		container_of(rcu, struct allowedips_node, rcu) };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		kmem_cache_free(node_cache, node);
	}
}

static void root_remove_peer_lists(struct allowedips_node *root)
{
	struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { root };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		if (rcu_access_pointer(node->peer))
			list_del(&node->peer_list);
	}
}

static unsigned int fls128(u64 a, u64 b)
{
	return a ? fls64(a) + 64U : fls64(b);
}

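/* Number of leading bits shared by node->bits and key, both already in
 * native word order; computed as the width minus find-last-set of the XOR.
 */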
static u8 common_bits(const struct allowedips_node *node, const u8 *key,
		      u8 bits)
{
	if (bits == 32)
		return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
	else if (bits == 128)
		return 128U - fls128(
			*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
			*(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
	return 0;
}

static bool prefix_matches(const struct allowedips_node *node, const u8 *key,
			   u8 bits)
{
	/* This could be much faster if it actually just compared the common
	 * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
	 * the rest, but it turns out that common_bits is already super fast on
	 * modern processors, even taking into account the unfortunate bswap.
	 * So, we just inline it like this instead.
	 */
	return common_bits(node, key, bits) >= node->cidr;
}

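/* Walk down the trie, remembering the deepest node seen that carries a
 * peer; that node is the longest-prefix match for key once the walk stops
 * matching or reaches a full-length prefix.
 */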
static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
					 const u8 *key)
{
	struct allowedips_node *node = trie, *found = NULL;

	while (node && prefix_matches(node, key, bits)) {
		if (rcu_access_pointer(node->peer))
			found = node;
		if (node->cidr == bits)
			break;
		node = rcu_dereference_bh(node->bit[choose(node, key)]);
	}
	return found;
}

/* Returns a strong reference to a peer */
static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits,
			      const void *be_ip)
{
	/* Aligned so it can be passed to fls/fls64 */
	u8 ip[16] __aligned(__alignof(u64));
	struct allowedips_node *node;
	struct wg_peer *peer = NULL;

	swap_endian(ip, be_ip, bits);

	rcu_read_lock_bh();
retry:
	node = find_node(rcu_dereference_bh(root), bits, ip);
	if (node) {
		peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer));
		if (!peer)
			goto retry;
	}
	rcu_read_unlock_bh();
	return peer;
}

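/* Find where a (key, cidr) entry belongs. Returns true with *rnode set to
 * the matching node if an entry with exactly this cidr already exists;
 * otherwise *rnode is left at the deepest existing node whose prefix
 * contains the new one (or NULL if none does).
 */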
static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
			   u8 cidr, u8 bits, struct allowedips_node **rnode,
			   struct mutex *lock)
{
	struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
	struct allowedips_node *parent = NULL;
	bool exact = false;

	while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
		parent = node;
		if (parent->cidr == cidr) {
			exact = true;
			break;
		}
		node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
	}
	*rnode = parent;
	return exact;
}

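/* Attach node under *parent and record the back-edge in parent_bit_packed:
 * the low bits encode which child slot this is (0 or 1), with the value 2
 * reserved for "parent is a trie root pointer" (see the & 3 tests in
 * wg_allowedips_remove_by_peer() below).
 */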
static inline void connect_node(struct allowedips_node __rcu **parent, u8 bit, struct allowedips_node *node)
{
	node->parent_bit_packed = (unsigned long)parent | bit;
	rcu_assign_pointer(*parent, node);
}

static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
{
	u8 bit = choose(parent, node->bits);
	connect_node(&parent->bit[bit], bit, node);
}

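/* Insert (key, cidr) -> peer while holding lock. Four cases follow: the
 * trie is empty; an exact node exists and just has its peer replaced; the
 * new node can hang directly off an existing node; or an intermediate node
 * must be allocated to join the new node with an existing subtree.
 */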
static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
	       u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *parent, *down, *newnode;

	if (unlikely(cidr > bits || !peer))
		return -EINVAL;

	if (!rcu_access_pointer(*trie)) {
		node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
		if (unlikely(!node))
			return -ENOMEM;
		RCU_INIT_POINTER(node->peer, peer);
		list_add_tail(&node->peer_list, &peer->allowedips_list);
		copy_and_assign_cidr(node, key, cidr, bits);
		connect_node(trie, 2, node);
		return 0;
	}
	if (node_placement(*trie, key, cidr, bits, &node, lock)) {
		rcu_assign_pointer(node->peer, peer);
		list_move_tail(&node->peer_list, &peer->allowedips_list);
		return 0;
	}

	newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
	if (unlikely(!newnode))
		return -ENOMEM;
	RCU_INIT_POINTER(newnode->peer, peer);
	list_add_tail(&newnode->peer_list, &peer->allowedips_list);
	copy_and_assign_cidr(newnode, key, cidr, bits);

	if (!node) {
		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
	} else {
		const u8 bit = choose(node, key);
		down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
		if (!down) {
			connect_node(&node->bit[bit], bit, newnode);
			return 0;
		}
	}
	cidr = min(cidr, common_bits(down, key, bits));
	parent = node;

	if (newnode->cidr == cidr) {
		choose_and_connect_node(newnode, down);
		if (!parent)
			connect_node(trie, 2, newnode);
		else
			choose_and_connect_node(parent, newnode);
		return 0;
	}

	node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
	if (unlikely(!node)) {
		list_del(&newnode->peer_list);
		kmem_cache_free(node_cache, newnode);
		return -ENOMEM;
	}
	INIT_LIST_HEAD(&node->peer_list);
	copy_and_assign_cidr(node, newnode->bits, cidr, bits);

	choose_and_connect_node(node, down);
	choose_and_connect_node(node, newnode);
	if (!parent)
		connect_node(trie, 2, node);
	else
		choose_and_connect_node(parent, node);
	return 0;
}

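/* A minimal usage sketch (illustrative only; actual callers live elsewhere
 * in the driver):
 *
 *	wg_allowedips_init(&table);
 *	err = wg_allowedips_insert_v4(&table, &ip, 24, peer, &lock);
 *	...
 *	peer = wg_allowedips_lookup_dst(&table, skb);
 */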
void wg_allowedips_init(struct allowedips *table)
{
	table->root4 = table->root6 = NULL;
	table->seq = 1;
}

void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
{
	struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;

	++table->seq;
	RCU_INIT_POINTER(table->root4, NULL);
	RCU_INIT_POINTER(table->root6, NULL);
	if (rcu_access_pointer(old4)) {
		struct allowedips_node *node = rcu_dereference_protected(old4,
							lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
	if (rcu_access_pointer(old6)) {
		struct allowedips_node *node = rcu_dereference_protected(old6,
							lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
}

int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls */
	u8 key[4] __aligned(__alignof(u32));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 32);
	return add(&table->root4, 32, key, cidr, peer, lock);
}

int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls64 */
	u8 key[16] __aligned(__alignof(u64));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 128);
	return add(&table->root6, 128, key, cidr, peer, lock);
}

void wg_allowedips_remove_by_peer(struct allowedips *table,
				  struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
	bool free_parent;

	if (list_empty(&peer->allowedips_list))
		return;
	++table->seq;
	list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
		list_del_init(&node->peer_list);
		RCU_INIT_POINTER(node->peer, NULL);
		if (node->bit[0] && node->bit[1])
			continue;
		child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
						  lockdep_is_held(lock));
		if (child)
			child->parent_bit_packed = node->parent_bit_packed;
		parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
		*parent_bit = child;
		parent = (void *)parent_bit -
			 offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
		free_parent = !rcu_access_pointer(node->bit[0]) &&
			      !rcu_access_pointer(node->bit[1]) &&
			      (node->parent_bit_packed & 3) <= 1 &&
			      !rcu_access_pointer(parent->peer);
		if (free_parent)
			child = rcu_dereference_protected(
					parent->bit[!(node->parent_bit_packed & 1)],
					lockdep_is_held(lock));
		call_rcu(&node->rcu, node_free_rcu);
		if (!free_parent)
			continue;
		if (child)
			child->parent_bit_packed = parent->parent_bit_packed;
		*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
		call_rcu(&parent->rcu, node_free_rcu);
	}
}

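/* Copy a node's prefix back out in big-endian form, zeroing everything past
 * its cidr; e.g. a /24 IPv4 entry reads back with a clean fourth octet.
 */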
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
{
	const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
	swap_endian(ip, node->bits, node->bitlen);
	memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
	if (node->cidr)
		ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);

	*cidr = node->cidr;
	return node->bitlen == 32 ? AF_INET : AF_INET6;
}

/* Returns a strong reference to a peer */
struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
	return NULL;
}

/* Returns a strong reference to a peer */
struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
	return NULL;
}

int __init wg_allowedips_slab_init(void)
{
	node_cache = KMEM_CACHE(allowedips_node, 0);
	return node_cache ? 0 : -ENOMEM;
}

void wg_allowedips_slab_uninit(void)
{
	rcu_barrier();
	kmem_cache_destroy(node_cache);
}

#include "selftest/allowedips.c"