Merge tag 'block-5.11-2021-01-10' of git://git.kernel.dk/linux-block
[linux/fpc-iii.git] / drivers / net / wireguard / noise.c
blobc0cfd9b36c0b594ac5ea87dca04ca075fe3da1d2
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 */
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
/* This implements Noise_IKpsk2:
 *
 * <- s
 * ******
 * -> e, es, s, ss, {t}
 * <- e, ee, se, psk, {}
 */
/* Protocol name hashed by wg_noise_init() into the initial chaining key. The
 * array is sized to exactly the 37 characters of the name, so no NUL
 * terminator is stored and sizeof(handshake_name) covers only the name.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
/* Identifier mixed into the initial hash by wg_noise_init(); likewise sized
 * to its 34 characters with no NUL terminator.
 */
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Computed once in wg_noise_init() and read-only afterwards
 * (__ro_after_init); handshake_init() copies them to start every transcript.
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Monotonic source for keypair->internal_id (printed in debug messages). */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
34 void __init wg_noise_init(void)
36 struct blake2s_state blake;
38 blake2s(handshake_init_chaining_key, handshake_name, NULL,
39 NOISE_HASH_LEN, sizeof(handshake_name), 0);
40 blake2s_init(&blake, NOISE_HASH_LEN);
41 blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42 blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43 blake2s_final(&blake, handshake_init_hash);
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
49 down_write(&peer->handshake.lock);
50 if (!peer->handshake.static_identity->has_identity ||
51 !curve25519(peer->handshake.precomputed_static_static,
52 peer->handshake.static_identity->static_private,
53 peer->handshake.remote_static))
54 memset(peer->handshake.precomputed_static_static, 0,
55 NOISE_PUBLIC_KEY_LEN);
56 up_write(&peer->handshake.lock);
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60 struct noise_static_identity *static_identity,
61 const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62 const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63 struct wg_peer *peer)
65 memset(handshake, 0, sizeof(*handshake));
66 init_rwsem(&handshake->lock);
67 handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68 handshake->entry.peer = peer;
69 memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70 if (peer_preshared_key)
71 memcpy(handshake->preshared_key, peer_preshared_key,
72 NOISE_SYMMETRIC_KEY_LEN);
73 handshake->static_identity = static_identity;
74 handshake->state = HANDSHAKE_ZEROED;
75 wg_noise_precompute_static_static(peer);
78 static void handshake_zero(struct noise_handshake *handshake)
80 memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81 memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82 memset(&handshake->hash, 0, NOISE_HASH_LEN);
83 memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84 handshake->remote_index = 0;
85 handshake->state = HANDSHAKE_ZEROED;
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
90 down_write(&handshake->lock);
91 wg_index_hashtable_remove(
92 handshake->entry.peer->device->index_hashtable,
93 &handshake->entry);
94 handshake_zero(handshake);
95 up_write(&handshake->lock);
98 static struct noise_keypair *keypair_create(struct wg_peer *peer)
100 struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
102 if (unlikely(!keypair))
103 return NULL;
104 spin_lock_init(&keypair->receiving_counter.lock);
105 keypair->internal_id = atomic64_inc_return(&keypair_counter);
106 keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
107 keypair->entry.peer = peer;
108 kref_init(&keypair->refcount);
109 return keypair;
112 static void keypair_free_rcu(struct rcu_head *rcu)
114 kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
117 static void keypair_free_kref(struct kref *kref)
119 struct noise_keypair *keypair =
120 container_of(kref, struct noise_keypair, refcount);
122 net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
123 keypair->entry.peer->device->dev->name,
124 keypair->internal_id,
125 keypair->entry.peer->internal_id);
126 wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
127 &keypair->entry);
128 call_rcu(&keypair->rcu, keypair_free_rcu);
131 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
133 if (unlikely(!keypair))
134 return;
135 if (unlikely(unreference_now))
136 wg_index_hashtable_remove(
137 keypair->entry.peer->device->index_hashtable,
138 &keypair->entry);
139 kref_put(&keypair->refcount, keypair_free_kref);
142 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
144 RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
145 "Taking noise keypair reference without holding the RCU BH read lock");
146 if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
147 return NULL;
148 return keypair;
151 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
153 struct noise_keypair *old;
155 spin_lock_bh(&keypairs->keypair_update_lock);
157 /* We zero the next_keypair before zeroing the others, so that
158 * wg_noise_received_with_keypair returns early before subsequent ones
159 * are zeroed.
161 old = rcu_dereference_protected(keypairs->next_keypair,
162 lockdep_is_held(&keypairs->keypair_update_lock));
163 RCU_INIT_POINTER(keypairs->next_keypair, NULL);
164 wg_noise_keypair_put(old, true);
166 old = rcu_dereference_protected(keypairs->previous_keypair,
167 lockdep_is_held(&keypairs->keypair_update_lock));
168 RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
169 wg_noise_keypair_put(old, true);
171 old = rcu_dereference_protected(keypairs->current_keypair,
172 lockdep_is_held(&keypairs->keypair_update_lock));
173 RCU_INIT_POINTER(keypairs->current_keypair, NULL);
174 wg_noise_keypair_put(old, true);
176 spin_unlock_bh(&keypairs->keypair_update_lock);
179 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
181 struct noise_keypair *keypair;
183 wg_noise_handshake_clear(&peer->handshake);
184 wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
186 spin_lock_bh(&peer->keypairs.keypair_update_lock);
187 keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
188 lockdep_is_held(&peer->keypairs.keypair_update_lock));
189 if (keypair)
190 keypair->sending.is_valid = false;
191 keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
192 lockdep_is_held(&peer->keypairs.keypair_update_lock));
193 if (keypair)
194 keypair->sending.is_valid = false;
195 spin_unlock_bh(&peer->keypairs.keypair_update_lock);
/* Install @new_keypair into the peer's previous/current/next slots while
 * holding keypair_update_lock.
 *
 * Where it goes depends on new_keypair->i_am_the_initiator: the initiator
 * uses it as current immediately, while the responder parks it in next
 * (see inline comments). Displaced keypairs are released with
 * wg_noise_keypair_put(..., true), which also unlinks them from the index
 * table at once.
 */
static void add_new_keypair(struct noise_keypairs *keypairs,
			    struct noise_keypair *new_keypair)
{
	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* Snapshot all three slots up front; they are reassigned below. */
	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	if (new_keypair->i_am_the_initiator) {
		/* If we're the initiator, it means we've sent a handshake, and
		 * received a confirmation response, which means this new
		 * keypair can now be used.
		 */
		if (next_keypair) {
			/* If there already was a next keypair pending, we
			 * demote it to be the previous keypair, and free the
			 * existing current. Note that this means KCI can result
			 * in this transition. It would perhaps be more sound to
			 * always just get rid of the unused next keypair
			 * instead of putting it in the previous slot, but this
			 * might be a bit less robust. Something to think about
			 * for the future.
			 */
			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
			rcu_assign_pointer(keypairs->previous_keypair,
					   next_keypair);
			wg_noise_keypair_put(current_keypair, true);
		} else /* If there wasn't an existing next keypair, we replace
			* the previous with the current one.
			*/
			rcu_assign_pointer(keypairs->previous_keypair,
					   current_keypair);
		/* At this point we can get rid of the old previous keypair, and
		 * set up the new keypair.
		 */
		wg_noise_keypair_put(previous_keypair, true);
		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
	} else {
		/* If we're the responder, it means we can't use the new keypair
		 * until we receive confirmation via the first data packet, so
		 * we get rid of the existing previous one, the possibly
		 * existing next one, and slide in the new next one.
		 * current_keypair is left in place.
		 */
		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
		wg_noise_keypair_put(next_keypair, true);
		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
		wg_noise_keypair_put(previous_keypair, true);
	}
	spin_unlock_bh(&keypairs->keypair_update_lock);
}
/* Given @received_keypair, promote it if it is the pending next_keypair
 * (the slot the responder side of add_new_keypair() fills). Promotion rotates
 * next -> current -> previous and releases the old previous.
 *
 * Returns true only when that rotation was performed; false when
 * @received_keypair is not (or is no longer) next_keypair.
 */
bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
				    struct noise_keypair *received_keypair)
{
	struct noise_keypair *old_keypair;
	bool key_is_new;

	/* We first check without taking the spinlock. */
	key_is_new = received_keypair ==
		     rcu_access_pointer(keypairs->next_keypair);
	if (likely(!key_is_new))
		return false;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* After locking, we double check that things didn't change from
	 * beneath us.
	 */
	if (unlikely(received_keypair !=
		    rcu_dereference_protected(keypairs->next_keypair,
			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
		spin_unlock_bh(&keypairs->keypair_update_lock);
		return false;
	}

	/* When we've finally received the confirmation, we slide the next
	 * into the current, the current into the previous, and get rid of
	 * the old previous.
	 */
	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	rcu_assign_pointer(keypairs->previous_keypair,
		rcu_dereference_protected(keypairs->current_keypair,
			lockdep_is_held(&keypairs->keypair_update_lock)));
	wg_noise_keypair_put(old_keypair, true);
	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);

	spin_unlock_bh(&keypairs->keypair_update_lock);
	return true;
}
293 /* Must hold static_identity->lock */
294 void wg_noise_set_static_identity_private_key(
295 struct noise_static_identity *static_identity,
296 const u8 private_key[NOISE_PUBLIC_KEY_LEN])
298 memcpy(static_identity->static_private, private_key,
299 NOISE_PUBLIC_KEY_LEN);
300 curve25519_clamp_secret(static_identity->static_private);
301 static_identity->has_identity = curve25519_generate_public(
302 static_identity->static_public, private_key);
/* This is Hugo Krawczyk's HKDF:
 *  - https://eprint.iacr.org/2010/264.pdf
 *  - https://tools.ietf.org/html/rfc5869
 *
 * Instantiated with HMAC-BLAKE2s. The extract step is
 * secret = HMAC(chaining_key, data). The expand step then produces up to
 * three outputs of at most BLAKE2S_HASH_SIZE bytes each:
 *   T1 = HMAC(secret, 0x1)
 *   T2 = HMAC(secret, T1 || 0x2)
 *   T3 = HMAC(secret, T2 || 0x3)
 *
 * Each Ti is truncated to its *_len. It is produced only if its dst is
 * non-NULL and its len is non-zero, and only if all earlier outputs were
 * produced.
 *
 * @chaining_key is fully consumed by the extract step before any dst is
 * written. Callers may therefore pass the same buffer as @first_dst and
 * @chaining_key, as mix_dh() and mix_psk() do.
 */
static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
		size_t first_len, size_t second_len, size_t third_len,
		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
	/* One hash plus one counter byte, reused for every expand step. */
	u8 output[BLAKE2S_HASH_SIZE + 1];
	u8 secret[BLAKE2S_HASH_SIZE];

	/* Debug-only sanity check of the length/dst contract described above. */
	WARN_ON(IS_ENABLED(DEBUG) &&
		(first_len > BLAKE2S_HASH_SIZE ||
		 second_len > BLAKE2S_HASH_SIZE ||
		 third_len > BLAKE2S_HASH_SIZE ||
		 ((second_len || second_dst || third_len || third_dst) &&
		  (!first_len || !first_dst)) ||
		 ((third_len || third_dst) && (!second_len || !second_dst))));

	/* Extract entropy from data into secret */
	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);

	if (!first_dst || !first_len)
		goto out;

	/* Expand first key: key = secret, data = 0x1 */
	output[0] = 1;
	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
	memcpy(first_dst, output, first_len);

	if (!second_dst || !second_len)
		goto out;

	/* Expand second key: key = secret, data = first-key || 0x2 */
	output[BLAKE2S_HASH_SIZE] = 2;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(second_dst, output, second_len);

	if (!third_dst || !third_len)
		goto out;

	/* Expand third key: key = secret, data = second-key || 0x3 */
	output[BLAKE2S_HASH_SIZE] = 3;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(third_dst, output, third_len);

out:
	/* Clear sensitive data from stack */
	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
359 static void derive_keys(struct noise_symmetric_key *first_dst,
360 struct noise_symmetric_key *second_dst,
361 const u8 chaining_key[NOISE_HASH_LEN])
363 u64 birthdate = ktime_get_coarse_boottime_ns();
364 kdf(first_dst->key, second_dst->key, NULL, NULL,
365 NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
366 chaining_key);
367 first_dst->birthdate = second_dst->birthdate = birthdate;
368 first_dst->is_valid = second_dst->is_valid = true;
371 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
372 u8 key[NOISE_SYMMETRIC_KEY_LEN],
373 const u8 private[NOISE_PUBLIC_KEY_LEN],
374 const u8 public[NOISE_PUBLIC_KEY_LEN])
376 u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
378 if (unlikely(!curve25519(dh_calculation, private, public)))
379 return false;
380 kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
381 NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
382 memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
383 return true;
386 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
387 u8 key[NOISE_SYMMETRIC_KEY_LEN],
388 const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
390 static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
391 if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
392 return false;
393 kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
394 NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
395 chaining_key);
396 return true;
399 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
401 struct blake2s_state blake;
403 blake2s_init(&blake, NOISE_HASH_LEN);
404 blake2s_update(&blake, hash, NOISE_HASH_LEN);
405 blake2s_update(&blake, src, src_len);
406 blake2s_final(&blake, hash);
409 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
410 u8 key[NOISE_SYMMETRIC_KEY_LEN],
411 const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
413 u8 temp_hash[NOISE_HASH_LEN];
415 kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
416 NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
417 mix_hash(hash, temp_hash, NOISE_HASH_LEN);
418 memzero_explicit(temp_hash, NOISE_HASH_LEN);
421 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
422 u8 hash[NOISE_HASH_LEN],
423 const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
425 memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
426 memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
427 mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
430 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
431 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
432 u8 hash[NOISE_HASH_LEN])
434 chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
435 NOISE_HASH_LEN,
436 0 /* Always zero for Noise_IK */, key);
437 mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
440 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
441 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
442 u8 hash[NOISE_HASH_LEN])
444 if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
445 hash, NOISE_HASH_LEN,
446 0 /* Always zero for Noise_IK */, key))
447 return false;
448 mix_hash(hash, src_ciphertext, src_len);
449 return true;
452 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
453 const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
454 u8 chaining_key[NOISE_HASH_LEN],
455 u8 hash[NOISE_HASH_LEN])
457 if (ephemeral_dst != ephemeral_src)
458 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
459 mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
460 kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
461 NOISE_PUBLIC_KEY_LEN, chaining_key);
464 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
466 struct timespec64 now;
468 ktime_get_real_ts64(&now);
470 /* In order to prevent some sort of infoleak from precise timers, we
471 * round down the nanoseconds part to the closest rounded-down power of
472 * two to the maximum initiations per second allowed anyway by the
473 * implementation.
475 now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
476 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
478 /* https://cr.yp.to/libtai/tai64.html */
479 *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
480 *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
/* Build the first handshake message (initiator -> responder) into @dst,
 * following the pattern e, es, s, ss, {t}.
 *
 * The transcript (chaining_key, hash) and the fresh ephemeral secret are
 * kept in @handshake for processing the response.
 *
 * Lock order: static_identity->lock (read), then handshake->lock (write).
 *
 * Returns true on success: the handshake is inserted into the device index
 * table (its index goes into dst->sender_index) and moves to
 * HANDSHAKE_CREATED_INITIATION. On failure, @dst may be partially written
 * and the state is not advanced.
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
				     struct noise_handshake *handshake)
{
	u8 timestamp[NOISE_TIMESTAMP_LEN];
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (unlikely(!handshake->static_identity->has_identity))
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

	handshake_init(handshake->chaining_key, handshake->hash,
		       handshake->remote_static);

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	/* Same buffer as src and dst: nothing is copied, only mixed. */
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* es */
	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* s */
	message_encrypt(dst->encrypted_static,
			handshake->static_identity->static_public,
			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

	/* ss */
	if (!mix_precomputed_dh(handshake->chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t} - the responder rejects timestamps not newer than the last one
	 * it accepted (replay protection).
	 */
	tai64n_now(timestamp);
	message_encrypt(dst->encrypted_timestamp, timestamp,
			NOISE_TIMESTAMP_LEN, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_INITIATION;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* The message key is secret; wipe it on every path. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
/* Process a handshake initiation received on @wg. The steps are:
 *   - decrypt the initiator's static key and look up the matching peer;
 *   - authenticate the encrypted timestamp;
 *   - reject replays (timestamp not newer than the last accepted one);
 *   - reject floods (less than NSEC_PER_SEC / INITIATIONS_PER_SECOND since
 *     the last accepted initiation).
 *
 * All crypto runs on stack copies. The peer's handshake is updated, and moved
 * to HANDSHAKE_CONSUMED_INITIATION, only after every check has passed.
 *
 * Returns the peer on success, NULL otherwise.
 * NOTE(review): the reference taken by wg_pubkey_hashtable_lookup() appears
 * to be handed to the caller on success and dropped here on failure.
 */
struct wg_peer *
wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
				      struct wg_device *wg)
{
	struct wg_peer *peer = NULL, *ret_peer = NULL;
	struct noise_handshake *handshake;
	bool replay_attack, flood_attack;
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	u8 chaining_key[NOISE_HASH_LEN];
	u8 hash[NOISE_HASH_LEN];
	u8 s[NOISE_PUBLIC_KEY_LEN];
	u8 e[NOISE_PUBLIC_KEY_LEN];
	u8 t[NOISE_TIMESTAMP_LEN];
	u64 initiation_consumption;

	down_read(&wg->static_identity.lock);
	if (unlikely(!wg->static_identity.has_identity))
		goto out;

	handshake_init(chaining_key, hash, wg->static_identity.static_public);

	/* e */
	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);

	/* es */
	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
		goto out;

	/* s */
	if (!message_decrypt(s, src->encrypted_static,
			     sizeof(src->encrypted_static), key, hash))
		goto out;

	/* Lookup which peer we're actually talking to */
	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
	if (!peer)
		goto out;
	handshake = &peer->handshake;

	/* ss */
	if (!mix_precomputed_dh(chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t} */
	if (!message_decrypt(t, src->encrypted_timestamp,
			     sizeof(src->encrypted_timestamp), key, hash))
		goto out;

	down_read(&handshake->lock);
	replay_attack = memcmp(t, handshake->latest_timestamp,
			       NOISE_TIMESTAMP_LEN) <= 0;
	flood_attack = (s64)handshake->last_initiation_consumption +
			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
		       (s64)ktime_get_coarse_boottime_ns();
	up_read(&handshake->lock);
	if (replay_attack || flood_attack)
		goto out;

	/* Success! Copy everything to peer */
	down_write(&handshake->lock);
	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
	/* The checks above held only the read lock, so a concurrent consumer
	 * may have advanced these fields since. Only ever move them forward.
	 */
	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
	handshake->remote_index = src->sender_index;
	initiation_consumption = ktime_get_coarse_boottime_ns();
	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
		handshake->last_initiation_consumption = initiation_consumption;
	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
	up_write(&handshake->lock);
	ret_peer = peer;

out:
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	memzero_explicit(hash, NOISE_HASH_LEN);
	memzero_explicit(chaining_key, NOISE_HASH_LEN);
	up_read(&wg->static_identity.lock);
	/* peer is still NULL on the early exits; presumably wg_peer_put()
	 * tolerates NULL.
	 */
	if (!ret_peer)
		wg_peer_put(peer);
	return ret_peer;
}
/* Build the second handshake message (responder -> initiator) into @dst,
 * following the pattern e, ee, se, psk, {}.
 *
 * Only proceeds when the handshake is in HANDSHAKE_CONSUMED_INITIATION. On
 * success the handshake is inserted into the index table (its index goes
 * into dst->sender_index) and moves to HANDSHAKE_CREATED_RESPONSE.
 *
 * Lock order: static_identity->lock (read), then handshake->lock (write).
 * Returns true on success.
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
					struct noise_handshake *handshake)
{
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
	dst->receiver_index = handshake->remote_index;

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* ee - key is NULL: only the chaining key is updated */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_ephemeral))
		goto out;

	/* se */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* psk - also produces the message key for the empty payload */
	mix_psk(handshake->chaining_key, handshake->hash, key,
		handshake->preshared_key);

	/* {} */
	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_RESPONSE;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
694 struct wg_peer *
695 wg_noise_handshake_consume_response(struct message_handshake_response *src,
696 struct wg_device *wg)
698 enum noise_handshake_state state = HANDSHAKE_ZEROED;
699 struct wg_peer *peer = NULL, *ret_peer = NULL;
700 struct noise_handshake *handshake;
701 u8 key[NOISE_SYMMETRIC_KEY_LEN];
702 u8 hash[NOISE_HASH_LEN];
703 u8 chaining_key[NOISE_HASH_LEN];
704 u8 e[NOISE_PUBLIC_KEY_LEN];
705 u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
706 u8 static_private[NOISE_PUBLIC_KEY_LEN];
707 u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
709 down_read(&wg->static_identity.lock);
711 if (unlikely(!wg->static_identity.has_identity))
712 goto out;
714 handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
715 wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
716 src->receiver_index, &peer);
717 if (unlikely(!handshake))
718 goto out;
720 down_read(&handshake->lock);
721 state = handshake->state;
722 memcpy(hash, handshake->hash, NOISE_HASH_LEN);
723 memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
724 memcpy(ephemeral_private, handshake->ephemeral_private,
725 NOISE_PUBLIC_KEY_LEN);
726 memcpy(preshared_key, handshake->preshared_key,
727 NOISE_SYMMETRIC_KEY_LEN);
728 up_read(&handshake->lock);
730 if (state != HANDSHAKE_CREATED_INITIATION)
731 goto fail;
733 /* e */
734 message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
736 /* ee */
737 if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
738 goto fail;
740 /* se */
741 if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
742 goto fail;
744 /* psk */
745 mix_psk(chaining_key, hash, key, preshared_key);
747 /* {} */
748 if (!message_decrypt(NULL, src->encrypted_nothing,
749 sizeof(src->encrypted_nothing), key, hash))
750 goto fail;
752 /* Success! Copy everything to peer */
753 down_write(&handshake->lock);
754 /* It's important to check that the state is still the same, while we
755 * have an exclusive lock.
757 if (handshake->state != state) {
758 up_write(&handshake->lock);
759 goto fail;
761 memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
762 memcpy(handshake->hash, hash, NOISE_HASH_LEN);
763 memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
764 handshake->remote_index = src->sender_index;
765 handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
766 up_write(&handshake->lock);
767 ret_peer = peer;
768 goto out;
770 fail:
771 wg_peer_put(peer);
772 out:
773 memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
774 memzero_explicit(hash, NOISE_HASH_LEN);
775 memzero_explicit(chaining_key, NOISE_HASH_LEN);
776 memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
777 memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
778 memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
779 up_read(&wg->static_identity.lock);
780 return ret_peer;
/* Turn a completed handshake into a new keypair and wipe the handshake.
 * A handshake is complete in HANDSHAKE_CREATED_RESPONSE (we responded) or
 * HANDSHAKE_CONSUMED_RESPONSE (we initiated).
 *
 * Unless the owning peer is marked dead, the keypair is published via
 * add_new_keypair() and takes over the handshake's index table entry.
 *
 * Returns the result of wg_index_hashtable_replace(). Returns false if the
 * state was wrong, allocation failed, or the peer is dead.
 */
bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
				      struct noise_keypairs *keypairs)
{
	struct noise_keypair *new_keypair;
	bool ret = false;

	down_write(&handshake->lock);
	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
		goto out;

	new_keypair = keypair_create(handshake->entry.peer);
	if (!new_keypair)
		goto out;
	new_keypair->i_am_the_initiator = handshake->state ==
					  HANDSHAKE_CONSUMED_RESPONSE;
	new_keypair->remote_index = handshake->remote_index;

	/* Both sides run the same KDF; swapping the destinations makes the
	 * initiator's sending key equal the responder's receiving key, and
	 * vice versa.
	 */
	if (new_keypair->i_am_the_initiator)
		derive_keys(&new_keypair->sending, &new_keypair->receiving,
			    handshake->chaining_key);
	else
		derive_keys(&new_keypair->receiving, &new_keypair->sending,
			    handshake->chaining_key);

	handshake_zero(handshake);
	rcu_read_lock_bh();
	/* Assumes @handshake is the `handshake` member embedded in a
	 * struct wg_peer (required by this container_of).
	 */
	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
					   handshake)->is_dead))) {
		add_new_keypair(keypairs, new_keypair);
		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
				    handshake->entry.peer->device->dev->name,
				    new_keypair->internal_id,
				    handshake->entry.peer->internal_id);
		ret = wg_index_hashtable_replace(
			handshake->entry.peer->device->index_hashtable,
			&handshake->entry, &new_keypair->entry);
	} else {
		/* Never published, so no other references exist: free
		 * directly (and wipe the keys).
		 */
		kfree_sensitive(new_keypair);
	}
	rcu_read_unlock_bh();

out:
	up_write(&handshake->lock);
	return ret;
}