treewide: remove redundant IS_ERR() before error code check
[linux/fpc-iii.git] / drivers / net / wireguard / noise.c
blobd71c8db68a8ceb233a78b873dbaf3371b9d0d295
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 */
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
/* This implements Noise_IKpsk2:
 *
 * <- s
 * ******
 * -> e, es, s, ss, {t}
 * <- e, ee, se, psk, {}
 */
28 static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
29 static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
30 static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
31 static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
32 static atomic64_t keypair_counter = ATOMIC64_INIT(0);
34 void __init wg_noise_init(void)
36 struct blake2s_state blake;
38 blake2s(handshake_init_chaining_key, handshake_name, NULL,
39 NOISE_HASH_LEN, sizeof(handshake_name), 0);
40 blake2s_init(&blake, NOISE_HASH_LEN);
41 blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42 blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43 blake2s_final(&blake, handshake_init_hash);
46 /* Must hold peer->handshake.static_identity->lock */
47 bool wg_noise_precompute_static_static(struct wg_peer *peer)
49 bool ret = true;
51 down_write(&peer->handshake.lock);
52 if (peer->handshake.static_identity->has_identity)
53 ret = curve25519(
54 peer->handshake.precomputed_static_static,
55 peer->handshake.static_identity->static_private,
56 peer->handshake.remote_static);
57 else
58 memset(peer->handshake.precomputed_static_static, 0,
59 NOISE_PUBLIC_KEY_LEN);
60 up_write(&peer->handshake.lock);
61 return ret;
64 bool wg_noise_handshake_init(struct noise_handshake *handshake,
65 struct noise_static_identity *static_identity,
66 const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
67 const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
68 struct wg_peer *peer)
70 memset(handshake, 0, sizeof(*handshake));
71 init_rwsem(&handshake->lock);
72 handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
73 handshake->entry.peer = peer;
74 memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
75 if (peer_preshared_key)
76 memcpy(handshake->preshared_key, peer_preshared_key,
77 NOISE_SYMMETRIC_KEY_LEN);
78 handshake->static_identity = static_identity;
79 handshake->state = HANDSHAKE_ZEROED;
80 return wg_noise_precompute_static_static(peer);
83 static void handshake_zero(struct noise_handshake *handshake)
85 memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
86 memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
87 memset(&handshake->hash, 0, NOISE_HASH_LEN);
88 memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
89 handshake->remote_index = 0;
90 handshake->state = HANDSHAKE_ZEROED;
93 void wg_noise_handshake_clear(struct noise_handshake *handshake)
95 wg_index_hashtable_remove(
96 handshake->entry.peer->device->index_hashtable,
97 &handshake->entry);
98 down_write(&handshake->lock);
99 handshake_zero(handshake);
100 up_write(&handshake->lock);
101 wg_index_hashtable_remove(
102 handshake->entry.peer->device->index_hashtable,
103 &handshake->entry);
106 static struct noise_keypair *keypair_create(struct wg_peer *peer)
108 struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
110 if (unlikely(!keypair))
111 return NULL;
112 keypair->internal_id = atomic64_inc_return(&keypair_counter);
113 keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
114 keypair->entry.peer = peer;
115 kref_init(&keypair->refcount);
116 return keypair;
119 static void keypair_free_rcu(struct rcu_head *rcu)
121 kzfree(container_of(rcu, struct noise_keypair, rcu));
124 static void keypair_free_kref(struct kref *kref)
126 struct noise_keypair *keypair =
127 container_of(kref, struct noise_keypair, refcount);
129 net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
130 keypair->entry.peer->device->dev->name,
131 keypair->internal_id,
132 keypair->entry.peer->internal_id);
133 wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
134 &keypair->entry);
135 call_rcu(&keypair->rcu, keypair_free_rcu);
138 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
140 if (unlikely(!keypair))
141 return;
142 if (unlikely(unreference_now))
143 wg_index_hashtable_remove(
144 keypair->entry.peer->device->index_hashtable,
145 &keypair->entry);
146 kref_put(&keypair->refcount, keypair_free_kref);
149 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
151 RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
152 "Taking noise keypair reference without holding the RCU BH read lock");
153 if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
154 return NULL;
155 return keypair;
158 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
160 struct noise_keypair *old;
162 spin_lock_bh(&keypairs->keypair_update_lock);
164 /* We zero the next_keypair before zeroing the others, so that
165 * wg_noise_received_with_keypair returns early before subsequent ones
166 * are zeroed.
168 old = rcu_dereference_protected(keypairs->next_keypair,
169 lockdep_is_held(&keypairs->keypair_update_lock));
170 RCU_INIT_POINTER(keypairs->next_keypair, NULL);
171 wg_noise_keypair_put(old, true);
173 old = rcu_dereference_protected(keypairs->previous_keypair,
174 lockdep_is_held(&keypairs->keypair_update_lock));
175 RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
176 wg_noise_keypair_put(old, true);
178 old = rcu_dereference_protected(keypairs->current_keypair,
179 lockdep_is_held(&keypairs->keypair_update_lock));
180 RCU_INIT_POINTER(keypairs->current_keypair, NULL);
181 wg_noise_keypair_put(old, true);
183 spin_unlock_bh(&keypairs->keypair_update_lock);
186 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
188 struct noise_keypair *keypair;
190 wg_noise_handshake_clear(&peer->handshake);
191 wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
193 spin_lock_bh(&peer->keypairs.keypair_update_lock);
194 keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
195 lockdep_is_held(&peer->keypairs.keypair_update_lock));
196 if (keypair)
197 keypair->sending.is_valid = false;
198 keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
199 lockdep_is_held(&peer->keypairs.keypair_update_lock));
200 if (keypair)
201 keypair->sending.is_valid = false;
202 spin_unlock_bh(&peer->keypairs.keypair_update_lock);
205 static void add_new_keypair(struct noise_keypairs *keypairs,
206 struct noise_keypair *new_keypair)
208 struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
210 spin_lock_bh(&keypairs->keypair_update_lock);
211 previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
212 lockdep_is_held(&keypairs->keypair_update_lock));
213 next_keypair = rcu_dereference_protected(keypairs->next_keypair,
214 lockdep_is_held(&keypairs->keypair_update_lock));
215 current_keypair = rcu_dereference_protected(keypairs->current_keypair,
216 lockdep_is_held(&keypairs->keypair_update_lock));
217 if (new_keypair->i_am_the_initiator) {
218 /* If we're the initiator, it means we've sent a handshake, and
219 * received a confirmation response, which means this new
220 * keypair can now be used.
222 if (next_keypair) {
223 /* If there already was a next keypair pending, we
224 * demote it to be the previous keypair, and free the
225 * existing current. Note that this means KCI can result
226 * in this transition. It would perhaps be more sound to
227 * always just get rid of the unused next keypair
228 * instead of putting it in the previous slot, but this
229 * might be a bit less robust. Something to think about
230 * for the future.
232 RCU_INIT_POINTER(keypairs->next_keypair, NULL);
233 rcu_assign_pointer(keypairs->previous_keypair,
234 next_keypair);
235 wg_noise_keypair_put(current_keypair, true);
236 } else /* If there wasn't an existing next keypair, we replace
237 * the previous with the current one.
239 rcu_assign_pointer(keypairs->previous_keypair,
240 current_keypair);
241 /* At this point we can get rid of the old previous keypair, and
242 * set up the new keypair.
244 wg_noise_keypair_put(previous_keypair, true);
245 rcu_assign_pointer(keypairs->current_keypair, new_keypair);
246 } else {
247 /* If we're the responder, it means we can't use the new keypair
248 * until we receive confirmation via the first data packet, so
249 * we get rid of the existing previous one, the possibly
250 * existing next one, and slide in the new next one.
252 rcu_assign_pointer(keypairs->next_keypair, new_keypair);
253 wg_noise_keypair_put(next_keypair, true);
254 RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
255 wg_noise_keypair_put(previous_keypair, true);
257 spin_unlock_bh(&keypairs->keypair_update_lock);
260 bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
261 struct noise_keypair *received_keypair)
263 struct noise_keypair *old_keypair;
264 bool key_is_new;
266 /* We first check without taking the spinlock. */
267 key_is_new = received_keypair ==
268 rcu_access_pointer(keypairs->next_keypair);
269 if (likely(!key_is_new))
270 return false;
272 spin_lock_bh(&keypairs->keypair_update_lock);
273 /* After locking, we double check that things didn't change from
274 * beneath us.
276 if (unlikely(received_keypair !=
277 rcu_dereference_protected(keypairs->next_keypair,
278 lockdep_is_held(&keypairs->keypair_update_lock)))) {
279 spin_unlock_bh(&keypairs->keypair_update_lock);
280 return false;
283 /* When we've finally received the confirmation, we slide the next
284 * into the current, the current into the previous, and get rid of
285 * the old previous.
287 old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
288 lockdep_is_held(&keypairs->keypair_update_lock));
289 rcu_assign_pointer(keypairs->previous_keypair,
290 rcu_dereference_protected(keypairs->current_keypair,
291 lockdep_is_held(&keypairs->keypair_update_lock)));
292 wg_noise_keypair_put(old_keypair, true);
293 rcu_assign_pointer(keypairs->current_keypair, received_keypair);
294 RCU_INIT_POINTER(keypairs->next_keypair, NULL);
296 spin_unlock_bh(&keypairs->keypair_update_lock);
297 return true;
300 /* Must hold static_identity->lock */
301 void wg_noise_set_static_identity_private_key(
302 struct noise_static_identity *static_identity,
303 const u8 private_key[NOISE_PUBLIC_KEY_LEN])
305 memcpy(static_identity->static_private, private_key,
306 NOISE_PUBLIC_KEY_LEN);
307 curve25519_clamp_secret(static_identity->static_private);
308 static_identity->has_identity = curve25519_generate_public(
309 static_identity->static_public, private_key);
312 /* This is Hugo Krawczyk's HKDF:
313 * - https://eprint.iacr.org/2010/264.pdf
314 * - https://tools.ietf.org/html/rfc5869
316 static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
317 size_t first_len, size_t second_len, size_t third_len,
318 size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
320 u8 output[BLAKE2S_HASH_SIZE + 1];
321 u8 secret[BLAKE2S_HASH_SIZE];
323 WARN_ON(IS_ENABLED(DEBUG) &&
324 (first_len > BLAKE2S_HASH_SIZE ||
325 second_len > BLAKE2S_HASH_SIZE ||
326 third_len > BLAKE2S_HASH_SIZE ||
327 ((second_len || second_dst || third_len || third_dst) &&
328 (!first_len || !first_dst)) ||
329 ((third_len || third_dst) && (!second_len || !second_dst))));
331 /* Extract entropy from data into secret */
332 blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
334 if (!first_dst || !first_len)
335 goto out;
337 /* Expand first key: key = secret, data = 0x1 */
338 output[0] = 1;
339 blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
340 memcpy(first_dst, output, first_len);
342 if (!second_dst || !second_len)
343 goto out;
345 /* Expand second key: key = secret, data = first-key || 0x2 */
346 output[BLAKE2S_HASH_SIZE] = 2;
347 blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
348 BLAKE2S_HASH_SIZE);
349 memcpy(second_dst, output, second_len);
351 if (!third_dst || !third_len)
352 goto out;
354 /* Expand third key: key = secret, data = second-key || 0x3 */
355 output[BLAKE2S_HASH_SIZE] = 3;
356 blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
357 BLAKE2S_HASH_SIZE);
358 memcpy(third_dst, output, third_len);
360 out:
361 /* Clear sensitive data from stack */
362 memzero_explicit(secret, BLAKE2S_HASH_SIZE);
363 memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
366 static void symmetric_key_init(struct noise_symmetric_key *key)
368 spin_lock_init(&key->counter.receive.lock);
369 atomic64_set(&key->counter.counter, 0);
370 memset(key->counter.receive.backtrack, 0,
371 sizeof(key->counter.receive.backtrack));
372 key->birthdate = ktime_get_coarse_boottime_ns();
373 key->is_valid = true;
376 static void derive_keys(struct noise_symmetric_key *first_dst,
377 struct noise_symmetric_key *second_dst,
378 const u8 chaining_key[NOISE_HASH_LEN])
380 kdf(first_dst->key, second_dst->key, NULL, NULL,
381 NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
382 chaining_key);
383 symmetric_key_init(first_dst);
384 symmetric_key_init(second_dst);
387 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
388 u8 key[NOISE_SYMMETRIC_KEY_LEN],
389 const u8 private[NOISE_PUBLIC_KEY_LEN],
390 const u8 public[NOISE_PUBLIC_KEY_LEN])
392 u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
394 if (unlikely(!curve25519(dh_calculation, private, public)))
395 return false;
396 kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
397 NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
398 memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
399 return true;
402 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
404 struct blake2s_state blake;
406 blake2s_init(&blake, NOISE_HASH_LEN);
407 blake2s_update(&blake, hash, NOISE_HASH_LEN);
408 blake2s_update(&blake, src, src_len);
409 blake2s_final(&blake, hash);
412 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
413 u8 key[NOISE_SYMMETRIC_KEY_LEN],
414 const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
416 u8 temp_hash[NOISE_HASH_LEN];
418 kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
419 NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
420 mix_hash(hash, temp_hash, NOISE_HASH_LEN);
421 memzero_explicit(temp_hash, NOISE_HASH_LEN);
424 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
425 u8 hash[NOISE_HASH_LEN],
426 const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
428 memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
429 memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
430 mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
433 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
434 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
435 u8 hash[NOISE_HASH_LEN])
437 chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
438 NOISE_HASH_LEN,
439 0 /* Always zero for Noise_IK */, key);
440 mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
443 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
444 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
445 u8 hash[NOISE_HASH_LEN])
447 if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
448 hash, NOISE_HASH_LEN,
449 0 /* Always zero for Noise_IK */, key))
450 return false;
451 mix_hash(hash, src_ciphertext, src_len);
452 return true;
455 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
456 const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
457 u8 chaining_key[NOISE_HASH_LEN],
458 u8 hash[NOISE_HASH_LEN])
460 if (ephemeral_dst != ephemeral_src)
461 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
462 mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
463 kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
464 NOISE_PUBLIC_KEY_LEN, chaining_key);
467 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
469 struct timespec64 now;
471 ktime_get_real_ts64(&now);
473 /* In order to prevent some sort of infoleak from precise timers, we
474 * round down the nanoseconds part to the closest rounded-down power of
475 * two to the maximum initiations per second allowed anyway by the
476 * implementation.
478 now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
479 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
481 /* https://cr.yp.to/libtai/tai64.html */
482 *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
483 *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
486 bool
487 wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
488 struct noise_handshake *handshake)
490 u8 timestamp[NOISE_TIMESTAMP_LEN];
491 u8 key[NOISE_SYMMETRIC_KEY_LEN];
492 bool ret = false;
494 /* We need to wait for crng _before_ taking any locks, since
495 * curve25519_generate_secret uses get_random_bytes_wait.
497 wait_for_random_bytes();
499 down_read(&handshake->static_identity->lock);
500 down_write(&handshake->lock);
502 if (unlikely(!handshake->static_identity->has_identity))
503 goto out;
505 dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
507 handshake_init(handshake->chaining_key, handshake->hash,
508 handshake->remote_static);
510 /* e */
511 curve25519_generate_secret(handshake->ephemeral_private);
512 if (!curve25519_generate_public(dst->unencrypted_ephemeral,
513 handshake->ephemeral_private))
514 goto out;
515 message_ephemeral(dst->unencrypted_ephemeral,
516 dst->unencrypted_ephemeral, handshake->chaining_key,
517 handshake->hash);
519 /* es */
520 if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
521 handshake->remote_static))
522 goto out;
524 /* s */
525 message_encrypt(dst->encrypted_static,
526 handshake->static_identity->static_public,
527 NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
529 /* ss */
530 kdf(handshake->chaining_key, key, NULL,
531 handshake->precomputed_static_static, NOISE_HASH_LEN,
532 NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
533 handshake->chaining_key);
535 /* {t} */
536 tai64n_now(timestamp);
537 message_encrypt(dst->encrypted_timestamp, timestamp,
538 NOISE_TIMESTAMP_LEN, key, handshake->hash);
540 dst->sender_index = wg_index_hashtable_insert(
541 handshake->entry.peer->device->index_hashtable,
542 &handshake->entry);
544 handshake->state = HANDSHAKE_CREATED_INITIATION;
545 ret = true;
547 out:
548 up_write(&handshake->lock);
549 up_read(&handshake->static_identity->lock);
550 memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
551 return ret;
554 struct wg_peer *
555 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
556 struct wg_device *wg)
558 struct wg_peer *peer = NULL, *ret_peer = NULL;
559 struct noise_handshake *handshake;
560 bool replay_attack, flood_attack;
561 u8 key[NOISE_SYMMETRIC_KEY_LEN];
562 u8 chaining_key[NOISE_HASH_LEN];
563 u8 hash[NOISE_HASH_LEN];
564 u8 s[NOISE_PUBLIC_KEY_LEN];
565 u8 e[NOISE_PUBLIC_KEY_LEN];
566 u8 t[NOISE_TIMESTAMP_LEN];
567 u64 initiation_consumption;
569 down_read(&wg->static_identity.lock);
570 if (unlikely(!wg->static_identity.has_identity))
571 goto out;
573 handshake_init(chaining_key, hash, wg->static_identity.static_public);
575 /* e */
576 message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
578 /* es */
579 if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
580 goto out;
582 /* s */
583 if (!message_decrypt(s, src->encrypted_static,
584 sizeof(src->encrypted_static), key, hash))
585 goto out;
587 /* Lookup which peer we're actually talking to */
588 peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
589 if (!peer)
590 goto out;
591 handshake = &peer->handshake;
593 /* ss */
594 kdf(chaining_key, key, NULL, handshake->precomputed_static_static,
595 NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
596 chaining_key);
598 /* {t} */
599 if (!message_decrypt(t, src->encrypted_timestamp,
600 sizeof(src->encrypted_timestamp), key, hash))
601 goto out;
603 down_read(&handshake->lock);
604 replay_attack = memcmp(t, handshake->latest_timestamp,
605 NOISE_TIMESTAMP_LEN) <= 0;
606 flood_attack = (s64)handshake->last_initiation_consumption +
607 NSEC_PER_SEC / INITIATIONS_PER_SECOND >
608 (s64)ktime_get_coarse_boottime_ns();
609 up_read(&handshake->lock);
610 if (replay_attack || flood_attack)
611 goto out;
613 /* Success! Copy everything to peer */
614 down_write(&handshake->lock);
615 memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
616 if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
617 memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
618 memcpy(handshake->hash, hash, NOISE_HASH_LEN);
619 memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
620 handshake->remote_index = src->sender_index;
621 if ((s64)(handshake->last_initiation_consumption -
622 (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
623 handshake->last_initiation_consumption = initiation_consumption;
624 handshake->state = HANDSHAKE_CONSUMED_INITIATION;
625 up_write(&handshake->lock);
626 ret_peer = peer;
628 out:
629 memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
630 memzero_explicit(hash, NOISE_HASH_LEN);
631 memzero_explicit(chaining_key, NOISE_HASH_LEN);
632 up_read(&wg->static_identity.lock);
633 if (!ret_peer)
634 wg_peer_put(peer);
635 return ret_peer;
638 bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
639 struct noise_handshake *handshake)
641 u8 key[NOISE_SYMMETRIC_KEY_LEN];
642 bool ret = false;
644 /* We need to wait for crng _before_ taking any locks, since
645 * curve25519_generate_secret uses get_random_bytes_wait.
647 wait_for_random_bytes();
649 down_read(&handshake->static_identity->lock);
650 down_write(&handshake->lock);
652 if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
653 goto out;
655 dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
656 dst->receiver_index = handshake->remote_index;
658 /* e */
659 curve25519_generate_secret(handshake->ephemeral_private);
660 if (!curve25519_generate_public(dst->unencrypted_ephemeral,
661 handshake->ephemeral_private))
662 goto out;
663 message_ephemeral(dst->unencrypted_ephemeral,
664 dst->unencrypted_ephemeral, handshake->chaining_key,
665 handshake->hash);
667 /* ee */
668 if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
669 handshake->remote_ephemeral))
670 goto out;
672 /* se */
673 if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
674 handshake->remote_static))
675 goto out;
677 /* psk */
678 mix_psk(handshake->chaining_key, handshake->hash, key,
679 handshake->preshared_key);
681 /* {} */
682 message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
684 dst->sender_index = wg_index_hashtable_insert(
685 handshake->entry.peer->device->index_hashtable,
686 &handshake->entry);
688 handshake->state = HANDSHAKE_CREATED_RESPONSE;
689 ret = true;
691 out:
692 up_write(&handshake->lock);
693 up_read(&handshake->static_identity->lock);
694 memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
695 return ret;
698 struct wg_peer *
699 wg_noise_handshake_consume_response(struct message_handshake_response *src,
700 struct wg_device *wg)
702 enum noise_handshake_state state = HANDSHAKE_ZEROED;
703 struct wg_peer *peer = NULL, *ret_peer = NULL;
704 struct noise_handshake *handshake;
705 u8 key[NOISE_SYMMETRIC_KEY_LEN];
706 u8 hash[NOISE_HASH_LEN];
707 u8 chaining_key[NOISE_HASH_LEN];
708 u8 e[NOISE_PUBLIC_KEY_LEN];
709 u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
710 u8 static_private[NOISE_PUBLIC_KEY_LEN];
712 down_read(&wg->static_identity.lock);
714 if (unlikely(!wg->static_identity.has_identity))
715 goto out;
717 handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
718 wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
719 src->receiver_index, &peer);
720 if (unlikely(!handshake))
721 goto out;
723 down_read(&handshake->lock);
724 state = handshake->state;
725 memcpy(hash, handshake->hash, NOISE_HASH_LEN);
726 memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
727 memcpy(ephemeral_private, handshake->ephemeral_private,
728 NOISE_PUBLIC_KEY_LEN);
729 up_read(&handshake->lock);
731 if (state != HANDSHAKE_CREATED_INITIATION)
732 goto fail;
734 /* e */
735 message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
737 /* ee */
738 if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
739 goto fail;
741 /* se */
742 if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
743 goto fail;
745 /* psk */
746 mix_psk(chaining_key, hash, key, handshake->preshared_key);
748 /* {} */
749 if (!message_decrypt(NULL, src->encrypted_nothing,
750 sizeof(src->encrypted_nothing), key, hash))
751 goto fail;
753 /* Success! Copy everything to peer */
754 down_write(&handshake->lock);
755 /* It's important to check that the state is still the same, while we
756 * have an exclusive lock.
758 if (handshake->state != state) {
759 up_write(&handshake->lock);
760 goto fail;
762 memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
763 memcpy(handshake->hash, hash, NOISE_HASH_LEN);
764 memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
765 handshake->remote_index = src->sender_index;
766 handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
767 up_write(&handshake->lock);
768 ret_peer = peer;
769 goto out;
771 fail:
772 wg_peer_put(peer);
773 out:
774 memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
775 memzero_explicit(hash, NOISE_HASH_LEN);
776 memzero_explicit(chaining_key, NOISE_HASH_LEN);
777 memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
778 memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
779 up_read(&wg->static_identity.lock);
780 return ret_peer;
783 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
784 struct noise_keypairs *keypairs)
786 struct noise_keypair *new_keypair;
787 bool ret = false;
789 down_write(&handshake->lock);
790 if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
791 handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
792 goto out;
794 new_keypair = keypair_create(handshake->entry.peer);
795 if (!new_keypair)
796 goto out;
797 new_keypair->i_am_the_initiator = handshake->state ==
798 HANDSHAKE_CONSUMED_RESPONSE;
799 new_keypair->remote_index = handshake->remote_index;
801 if (new_keypair->i_am_the_initiator)
802 derive_keys(&new_keypair->sending, &new_keypair->receiving,
803 handshake->chaining_key);
804 else
805 derive_keys(&new_keypair->receiving, &new_keypair->sending,
806 handshake->chaining_key);
808 handshake_zero(handshake);
809 rcu_read_lock_bh();
810 if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
811 handshake)->is_dead))) {
812 add_new_keypair(keypairs, new_keypair);
813 net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
814 handshake->entry.peer->device->dev->name,
815 new_keypair->internal_id,
816 handshake->entry.peer->internal_id);
817 ret = wg_index_hashtable_replace(
818 handshake->entry.peer->device->index_hashtable,
819 &handshake->entry, &new_keypair->entry);
820 } else {
821 kzfree(new_keypair);
823 rcu_read_unlock_bh();
825 out:
826 up_write(&handshake->lock);
827 return ret;