// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "allowedips.h"
#include "peer.h"

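/* The trie stores keys as host-native words rather than raw big-endian
 * bytes, so that common_bits() below can feed whole words to fls()/fls64().
 * swap_endian() performs that conversion; e.g. the IPv4 bytes c0 a8 00 00
 * (192.168.0.0) become the native u32 0xc0a80000.
 */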
static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
	if (bits == 32) {
		*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
	} else if (bits == 128) {
		((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
		((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
	}
}

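/* Records the prefix and precomputes where the first bit *below* the prefix
 * lives: bit_at_a is the byte index and bit_at_b the bit position within
 * that byte. Because swap_endian() reversed the bytes within each native
 * word on little-endian machines, bit_at_a is XOR-adjusted ((bits / 8 - 1)
 * % 8 is 3 for a u32 and 7 for u64 halves) to land on the same logical byte.
 */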
static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
				 u8 cidr, u8 bits)
{
	node->cidr = cidr;
	node->bit_at_a = cidr / 8U;
#ifdef __LITTLE_ENDIAN
	node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
	node->bit_at_b = 7U - (cidr % 8U);
	node->bitlen = bits;
	memcpy(node->bits, src, bits / 8U);
}

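/* CHOOSE_NODE() selects which child to descend into by testing that single
 * bit of the key. Worked example: for an IPv4 /21 node, bit_at_a = 21 / 8 = 2
 * and bit_at_b = 7 - (21 % 8) = 2 (before the little-endian adjustment), so
 * the macro extracts bit 21 of the address -- the first bit not covered by
 * the /21 prefix.
 */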
#define CHOOSE_NODE(parent, key) \
	parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]

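/* The traversals below avoid recursion by using an explicit stack. Since
 * every level of the trie consumes at least one key bit, its depth is
 * bounded by the key length (at most 128 bits for IPv6); DEBUG builds WARN
 * if a stack ever approaches overflow.
 */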
static void push_rcu(struct allowedips_node **stack,
		     struct allowedips_node __rcu *p, unsigned int *len)
{
	if (rcu_access_pointer(p)) {
		WARN_ON(IS_ENABLED(DEBUG) && *len >= 128);
		stack[(*len)++] = rcu_dereference_raw(p);
	}
}

static void root_free_rcu(struct rcu_head *rcu)
{
	struct allowedips_node *node, *stack[128] = {
		container_of(rcu, struct allowedips_node, rcu) };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		kfree(node);
	}
}

static void root_remove_peer_lists(struct allowedips_node *root)
{
	struct allowedips_node *node, *stack[128] = { root };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		if (rcu_access_pointer(node->peer))
			list_del(&node->peer_list);
	}
}

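/* Removes every allowed IP entry that routes to @peer. The walk is an
 * iterative DFS that keeps parent slots on the stack so that, once a node's
 * peer reference is cleared, a node left with at most one child can be
 * spliced out of the trie. Frees are deferred through kfree_rcu() so
 * concurrent RCU readers never see freed memory.
 */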
static void walk_remove_by_peer(struct allowedips_node __rcu **top,
				struct wg_peer *peer, struct mutex *lock)
{
#define REF(p) rcu_access_pointer(p)
#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
#define PUSH(p) ({                                                   \
		WARN_ON(IS_ENABLED(DEBUG) && len >= 128);            \
		stack[len++] = p;                                    \
	})

	struct allowedips_node __rcu **stack[128], **nptr;
	struct allowedips_node *node, *prev;
	unsigned int len;

	if (unlikely(!peer || !REF(*top)))
		return;

	for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
		nptr = stack[len - 1];
		node = DEREF(nptr);
		if (!node) {
			--len;
			continue;
		}
		if (!prev || REF(prev->bit[0]) == node ||
		    REF(prev->bit[1]) == node) {
			if (REF(node->bit[0]))
				PUSH(&node->bit[0]);
			else if (REF(node->bit[1]))
				PUSH(&node->bit[1]);
		} else if (REF(node->bit[0]) == prev) {
			if (REF(node->bit[1]))
				PUSH(&node->bit[1]);
		} else {
			if (rcu_dereference_protected(node->peer,
				lockdep_is_held(lock)) == peer) {
				RCU_INIT_POINTER(node->peer, NULL);
				list_del_init(&node->peer_list);
				if (!node->bit[0] || !node->bit[1]) {
					rcu_assign_pointer(*nptr, DEREF(
					       &node->bit[!REF(node->bit[0])]));
					kfree_rcu(node, rcu);
					node = DEREF(nptr);
				}
			}
			--len;
		}
	}

#undef REF
#undef DEREF
#undef PUSH
}

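/* common_bits() returns how many leading bits of @key match the node's
 * stored key: the width minus the position of the highest differing bit.
 * E.g. for two u32 keys that first differ at bit 9 (counting from the MSB),
 * the XOR's highest set bit is bit 22 from the LSB, fls() returns 23, and
 * 32 - 23 = 9 matching leading bits.
 */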
static unsigned int fls128(u64 a, u64 b)
{
	return a ? fls64(a) + 64U : fls64(b);
}

static u8 common_bits(const struct allowedips_node *node, const u8 *key,
		      u8 bits)
{
	if (bits == 32)
		return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
	else if (bits == 128)
		return 128U - fls128(
			*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
			*(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
	return 0;
}

static bool prefix_matches(const struct allowedips_node *node, const u8 *key,
			   u8 bits)
{
	/* This could be much faster if it actually just compared the common
	 * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
	 * the rest, but it turns out that common_bits is already super fast on
	 * modern processors, even taking into account the unfortunate bswap.
	 * So, we just inline it like this instead.
	 */
	return common_bits(node, key, bits) >= node->cidr;
}

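/* Longest-prefix match: descend while the current node's prefix still covers
 * the key, remembering the deepest node that actually has a peer attached
 * (interior nodes created only as glue by add() carry no peer).
 */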
static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
					 const u8 *key)
{
	struct allowedips_node *node = trie, *found = NULL;

	while (node && prefix_matches(node, key, bits)) {
		if (rcu_access_pointer(node->peer))
			found = node;
		if (node->cidr == bits)
			break;
		node = rcu_dereference_bh(CHOOSE_NODE(node, key));
	}
	return found;
}

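/* A lookup can race with peer destruction: wg_peer_get_maybe_zero() refuses
 * to take a reference once the peer's refcount has already hit zero, in
 * which case the trie walk is simply retried under the same RCU read lock.
 */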
/* Returns a strong reference to a peer */
static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits,
			      const void *be_ip)
{
	/* Aligned so it can be passed to fls/fls64 */
	u8 ip[16] __aligned(__alignof(u64));
	struct allowedips_node *node;
	struct wg_peer *peer = NULL;

	swap_endian(ip, be_ip, bits);

	rcu_read_lock_bh();
retry:
	node = find_node(rcu_dereference_bh(root), bits, ip);
	if (node) {
		peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer));
		if (!peer)
			goto retry;
	}
	rcu_read_unlock_bh();
	return peer;
}

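/* Finds where a new /cidr entry belongs: walks down to the deepest existing
 * node whose prefix contains @key and whose cidr is no larger than the
 * requested one, storing it in *rnode. Returns true if that node's cidr
 * matches exactly, i.e. the entry already exists and only its peer needs
 * updating.
 */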
static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
			   u8 cidr, u8 bits, struct allowedips_node **rnode,
			   struct mutex *lock)
{
	struct allowedips_node *node = rcu_dereference_protected(trie,
						lockdep_is_held(lock));
	struct allowedips_node *parent = NULL;
	bool exact = false;

	while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
		parent = node;
		if (parent->cidr == cidr) {
			exact = true;
			break;
		}
		node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
						 lockdep_is_held(lock));
	}
	*rnode = parent;
	return exact;
}

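/* Insertion has three cases: an empty trie gets the new node as root; an
 * exact match just repoints the existing node's peer; otherwise a new node
 * is allocated and linked in, possibly together with a peer-less glue node
 * whose prefix is the common prefix of the new key and the subtree it
 * displaces.
 */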
static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
	       u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *parent, *down, *newnode;

	if (unlikely(cidr > bits || !peer))
		return -EINVAL;

	if (!rcu_access_pointer(*trie)) {
		node = kzalloc(sizeof(*node), GFP_KERNEL);
		if (unlikely(!node))
			return -ENOMEM;
		RCU_INIT_POINTER(node->peer, peer);
		list_add_tail(&node->peer_list, &peer->allowedips_list);
		copy_and_assign_cidr(node, key, cidr, bits);
		rcu_assign_pointer(*trie, node);
		return 0;
	}
	if (node_placement(*trie, key, cidr, bits, &node, lock)) {
		rcu_assign_pointer(node->peer, peer);
		list_move_tail(&node->peer_list, &peer->allowedips_list);
		return 0;
	}

	newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
	if (unlikely(!newnode))
		return -ENOMEM;
	RCU_INIT_POINTER(newnode->peer, peer);
	list_add_tail(&newnode->peer_list, &peer->allowedips_list);
	copy_and_assign_cidr(newnode, key, cidr, bits);

	if (!node) {
		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
	} else {
		down = rcu_dereference_protected(CHOOSE_NODE(node, key),
						 lockdep_is_held(lock));
		if (!down) {
			rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
			return 0;
		}
	}
	cidr = min(cidr, common_bits(down, key, bits));
	parent = node;

	if (newnode->cidr == cidr) {
		rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
		if (!parent)
			rcu_assign_pointer(*trie, newnode);
		else
			rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
					   newnode);
	} else {
		node = kzalloc(sizeof(*node), GFP_KERNEL);
		if (unlikely(!node)) {
			/* newnode is already on the peer's list; unlink it
			 * before freeing so no dangling entry remains.
			 */
			list_del(&newnode->peer_list);
			kfree(newnode);
			return -ENOMEM;
		}
		INIT_LIST_HEAD(&node->peer_list);
		copy_and_assign_cidr(node, newnode->bits, cidr, bits);

		rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
		rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
		if (!parent)
			rcu_assign_pointer(*trie, node);
		else
			rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
					   node);
	}
	return 0;
}

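/* All mutations below are serialized by the caller-supplied mutex, while
 * lookups run under rcu_read_lock_bh(). The seq counter is bumped on every
 * mutation so that long-running walkers (e.g. the netlink dump path) can
 * detect that the trie changed under them.
 */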
void wg_allowedips_init(struct allowedips *table)
{
	table->root4 = table->root6 = NULL;
	table->seq = 1;
}

void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
{
	struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;

	++table->seq;
	RCU_INIT_POINTER(table->root4, NULL);
	RCU_INIT_POINTER(table->root6, NULL);
	if (rcu_access_pointer(old4)) {
		struct allowedips_node *node = rcu_dereference_protected(old4,
							lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
	if (rcu_access_pointer(old6)) {
		struct allowedips_node *node = rcu_dereference_protected(old6,
							lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
}

int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls */
	u8 key[4] __aligned(__alignof(u32));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 32);
	return add(&table->root4, 32, key, cidr, peer, lock);
}

int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls64 */
	u8 key[16] __aligned(__alignof(u64));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 128);
	return add(&table->root6, 128, key, cidr, peer, lock);
}

void wg_allowedips_remove_by_peer(struct allowedips *table,
				  struct wg_peer *peer, struct mutex *lock)
{
	++table->seq;
	walk_remove_by_peer(&table->root4, peer, lock);
	walk_remove_by_peer(&table->root6, peer, lock);
}

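/* Converts a node back to wire format for reporting. Bits beyond the prefix
 * are cleared: full bytes via memset(), and the final partial byte with the
 * mask ~0U << (-cidr % 8U), which relies on unsigned wrap-around; e.g. for
 * /21, -21 % 8U == 3, giving mask 0xf8 to keep the top five bits of byte 2.
 */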
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
{
	const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);

	swap_endian(ip, node->bits, node->bitlen);
	memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
	if (node->cidr)
		ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);

	*cidr = node->cidr;
	return node->bitlen == 32 ? AF_INET : AF_INET6;
}

/* Returns a strong reference to a peer */
struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
	return NULL;
}

/* Returns a strong reference to a peer */
struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
	return NULL;
}

#include "selftest/allowedips.c"