// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "ratelimiter.h"
#include <linux/siphash.h>
#include <linux/mod_devicetable.h>
#include <linux/slab.h>
#include <net/ip.h>

static struct kmem_cache *entry_cache;
static hsiphash_key_t key;
static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
static DEFINE_MUTEX(init_lock);
static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
static atomic_t total_entries = ATOMIC_INIT(0);
static unsigned int max_entries, table_size;
static void wg_ratelimiter_gc_entries(struct work_struct *);
static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
static struct hlist_head *table_v4;
#if IS_ENABLED(CONFIG_IPV6)
static struct hlist_head *table_v6;
#endif

struct ratelimiter_entry {
	u64 last_time_ns, tokens, ip;
	void *net;
	spinlock_t lock;
	struct hlist_node hash;
	struct rcu_head rcu;
};

enum {
	PACKETS_PER_SECOND = 20,
	PACKETS_BURSTABLE = 5,
	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};
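
/* For illustration: PACKET_COST = NSEC_PER_SEC / 20 = 50,000,000 ns of tokens
 * per packet, and TOKEN_MAX = 5 * PACKET_COST = 250,000,000 ns, so a full
 * bucket admits a burst of 5 packets back to back before the sustained rate
 * of 20 packets per second takes over.
 */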

static void entry_free(struct rcu_head *rcu)
{
	kmem_cache_free(entry_cache,
			container_of(rcu, struct ratelimiter_entry, rcu));
	atomic_dec(&total_entries);
}

static void entry_uninit(struct ratelimiter_entry *entry)
{
	hlist_del_rcu(&entry->hash);
	call_rcu(&entry->rcu, entry_free);
}

/* Calling this function with a NULL work uninits all entries. */
static void wg_ratelimiter_gc_entries(struct work_struct *work)
{
	const u64 now = ktime_get_coarse_boottime_ns();
	struct ratelimiter_entry *entry;
	struct hlist_node *temp;
	unsigned int i;

	for (i = 0; i < table_size; ++i) {
		spin_lock(&table_lock);
		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#if IS_ENABLED(CONFIG_IPV6)
		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#endif
		spin_unlock(&table_lock);
		if (likely(work))
			cond_resched();
	}
	if (likely(work))
		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
}
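
/* The GC runs once per second (HZ) off the power-efficient workqueue and
 * reclaims entries idle for more than NSEC_PER_SEC, so the tables track only
 * recently active sources; a NULL @work (the teardown path) sweeps everything.
 */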

bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
{
	/* We only take the bottom half of the net pointer, so that we can hash
	 * 3 words in the end. This way, siphash's len param fits into the final
	 * u32, and we don't incur an extra round.
	 */
	const u32 net_word = (unsigned long)net;
	struct ratelimiter_entry *entry;
	struct hlist_head *bucket;
	u64 ip;
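
	/* Bucket selection: table_size is a power of two (enforced in
	 * wg_ratelimiter_init()), so masking the keyed hsiphash output with
	 * table_size - 1 yields a uniform index into the table.
	 */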
	if (skb->protocol == htons(ETH_P_IP)) {
		ip = (u64 __force)ip_hdr(skb)->saddr;
		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
				   (table_size - 1)];
	}
#if IS_ENABLED(CONFIG_IPV6)
	else if (skb->protocol == htons(ETH_P_IPV6)) {
		/* Only use 64 bits, so as to ratelimit the whole /64. */
		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
				   (table_size - 1)];
	}
#endif
	else
		return false;
	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, bucket, hash) {
		if (entry->net == net && entry->ip == ip) {
			u64 now, tokens;
			bool ret;
			/* Quasi-inspired by nft_limit.c, but this is actually a
			 * slightly different algorithm. Namely, we incorporate
			 * the burst as part of the maximum tokens, rather than
			 * as part of the rate.
			 */
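			/* For illustration: a steady 100 packets/sec arrival
			 * adds only 10 ms of tokens per packet against the
			 * 50 ms PACKET_COST, so once any initial burst has
			 * drained, roughly 1 in 5 packets is admitted, i.e.
			 * the sustained 20 packets/sec rate.
			 */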
			spin_lock(&entry->lock);
			now = ktime_get_coarse_boottime_ns();
			tokens = min_t(u64, TOKEN_MAX,
				       entry->tokens + now -
					       entry->last_time_ns);
			entry->last_time_ns = now;
			ret = tokens >= PACKET_COST;
			entry->tokens = ret ? tokens - PACKET_COST : tokens;
			spin_unlock(&entry->lock);
			rcu_read_unlock();
			return ret;
		}
	}
	rcu_read_unlock();

	if (atomic_inc_return(&total_entries) > max_entries)
		goto err_oom;

	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
	if (unlikely(!entry))
		goto err_oom;

	entry->net = net;
	entry->ip = ip;
	INIT_HLIST_NODE(&entry->hash);
	spin_lock_init(&entry->lock);
	entry->last_time_ns = ktime_get_coarse_boottime_ns();
	entry->tokens = TOKEN_MAX - PACKET_COST;
	spin_lock(&table_lock);
	hlist_add_head_rcu(&entry->hash, bucket);
	spin_unlock(&table_lock);
	return true;

err_oom:
	atomic_dec(&total_entries);
	return false;
}
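
/* A sketch of intended use (the caller and drop policy here are illustrative,
 * not part of this file): a receive path consults the ratelimiter once per
 * handshake packet and drops, or demands a cookie round-trip, when denied:
 *
 *	if (!wg_ratelimiter_allow(skb, dev_net(wg->dev)))
 *		goto drop;
 */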

int wg_ratelimiter_init(void)
{
	mutex_lock(&init_lock);
	if (++init_refcnt != 1)
		goto out;

	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
	if (!entry_cache)
		goto err;

	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
	 * but what it shares in common is that it uses a massive hashtable. So,
	 * we borrow their wisdom about good table sizes on different systems
	 * dependent on RAM. This calculation here comes from there.
	 */
	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
		max_t(unsigned long, 16, roundup_pow_of_two(
			(totalram_pages() << PAGE_SHIFT) /
			(1U << 14) / sizeof(struct hlist_head)));
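	/* Worked example: with an 8-byte struct hlist_head (64-bit), a machine
	 * over 1 GiB of RAM is pinned at 8192 buckets, while 256 MiB comes to
	 * 2^28 / 2^14 / 8 = 2048 buckets. max_entries below then allows 8
	 * entries per bucket on average, e.g. 65536 on the large machine.
	 */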
	max_entries = table_size * 8;

	table_v4 = kvzalloc(table_size * sizeof(*table_v4), GFP_KERNEL);
	if (unlikely(!table_v4))
		goto err_kmemcache;

#if IS_ENABLED(CONFIG_IPV6)
	table_v6 = kvzalloc(table_size * sizeof(*table_v6), GFP_KERNEL);
	if (unlikely(!table_v6)) {
		kvfree(table_v4);
		goto err_kmemcache;
	}
#endif

	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
	get_random_bytes(&key, sizeof(key));
out:
	mutex_unlock(&init_lock);
	return 0;

err_kmemcache:
	kmem_cache_destroy(entry_cache);
err:
	--init_refcnt;
	mutex_unlock(&init_lock);
	return -ENOMEM;
}
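
/* init/uninit are reference-counted: all users share the global tables and
 * hash key, so only the first wg_ratelimiter_init() allocates, and only the
 * last matching wg_ratelimiter_uninit() below tears everything down.
 */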

void wg_ratelimiter_uninit(void)
{
	mutex_lock(&init_lock);
	if (!init_refcnt || --init_refcnt)
		goto out;

	cancel_delayed_work_sync(&gc_work);
	wg_ratelimiter_gc_entries(NULL);
	rcu_barrier();
	kvfree(table_v4);
#if IS_ENABLED(CONFIG_IPV6)
	kvfree(table_v6);
#endif
	kmem_cache_destroy(entry_cache);
out:
	mutex_unlock(&init_lock);
}

#include "selftest/ratelimiter.c"