// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "queueing.h"
#include "device.h"
#include "peer.h"
#include "timers.h"
#include "messages.h"
#include "cookie.h"
#include "socket.h"

#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <net/ip_tunnels.h>
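
/* Receive path overview: handshake messages are queued to a workqueue and
 * consumed one at a time, data messages are decrypted in parallel across
 * CPUs via the per-device decrypt queue, and decrypted packets are handed
 * back to the owning peer's NAPI context, where replay protection and
 * delivery to the network stack happen in order.
 */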

/* Must be called with bh disabled. */
static void update_rx_stats(struct wg_peer *peer, size_t len)
{
	struct pcpu_sw_netstats *tstats =
		get_cpu_ptr(peer->device->dev->tstats);

	u64_stats_update_begin(&tstats->syncp);
	++tstats->rx_packets;
	tstats->rx_bytes += len;
	peer->rx_bytes += len;
	u64_stats_update_end(&tstats->syncp);
	put_cpu_ptr(tstats);
}

#define SKB_TYPE_LE32(skb) (((struct message_header *)(skb)->data)->type)
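
/* The message type is transmitted little-endian, so it is compared against
 * cpu_to_le32(MESSAGE_*) constants: any byteswap is folded into the
 * constant at compile time rather than performed per packet.
 */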

static size_t validate_header_len(struct sk_buff *skb)
{
	if (unlikely(skb->len < sizeof(struct message_header)))
		return 0;
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) &&
	    skb->len >= MESSAGE_MINIMUM_LENGTH)
		return sizeof(struct message_data);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) &&
	    skb->len == sizeof(struct message_handshake_initiation))
		return sizeof(struct message_handshake_initiation);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) &&
	    skb->len == sizeof(struct message_handshake_response))
		return sizeof(struct message_handshake_response);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) &&
	    skb->len == sizeof(struct message_handshake_cookie))
		return sizeof(struct message_handshake_cookie);
	return 0;
}
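
/* prepare_skb_header() validates the outer IP/UDP framing in place: the
 * UDP length must agree with what the skb actually holds, and the payload
 * must be exactly one well-formed WireGuard message, before the skb is
 * pulled up to the message header.
 */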

static int prepare_skb_header(struct sk_buff *skb, struct wg_device *wg)
{
	size_t data_offset, data_len, header_len;
	struct udphdr *udp;

	if (unlikely(!wg_check_packet_protocol(skb) ||
		     skb_transport_header(skb) < skb->head ||
		     (skb_transport_header(skb) + sizeof(struct udphdr)) >
			     skb_tail_pointer(skb)))
		return -EINVAL; /* Bogus IP header */
	udp = udp_hdr(skb);
	data_offset = (u8 *)udp - skb->data;
	if (unlikely(data_offset > U16_MAX ||
		     data_offset + sizeof(struct udphdr) > skb->len))
		/* Packet has offset at impossible location or isn't big enough
		 * to have UDP fields.
		 */
		return -EINVAL;
	data_len = ntohs(udp->len);
	if (unlikely(data_len < sizeof(struct udphdr) ||
		     data_len > skb->len - data_offset))
		/* UDP packet is reporting too small of a size or lying about
		 * its size.
		 */
		return -EINVAL;
	data_len -= sizeof(struct udphdr);
	data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
	if (unlikely(!pskb_may_pull(skb,
				data_offset + sizeof(struct message_header)) ||
		     pskb_trim(skb, data_len + data_offset) < 0))
		return -EINVAL;
	skb_pull(skb, data_offset);
	if (unlikely(skb->len != data_len))
		/* Final len does not agree with calculated len */
		return -EINVAL;
	header_len = validate_header_len(skb);
	if (unlikely(!header_len))
		return -EINVAL;
	__skb_push(skb, data_offset);
	if (unlikely(!pskb_may_pull(skb, data_offset + header_len)))
		return -EINVAL;
	__skb_pull(skb, data_offset);
	return 0;
}
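
/* Handshake messages are checked via their MACs before any expensive
 * elliptic-curve work is done. Once the handshake queue begins to fill,
 * the device is considered under load, and initiators must first echo back
 * a cookie, which rate-limits senders that cannot receive return traffic.
 */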

static void wg_receive_handshake_packet(struct wg_device *wg,
					struct sk_buff *skb)
{
	enum cookie_mac_state mac_state;
	struct wg_peer *peer = NULL;
	/* This is global, so that our load calculation applies to the whole
	 * system. We don't care about races with it at all.
	 */
	static u64 last_under_load;
	bool packet_needs_cookie;
	bool under_load;

	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) {
		net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n",
					wg->dev->name, skb);
		wg_cookie_message_consume(
			(struct message_handshake_cookie *)skb->data, wg);
		return;
	}

	under_load = skb_queue_len(&wg->incoming_handshakes) >=
		     MAX_QUEUED_INCOMING_HANDSHAKES / 8;
	if (under_load) {
		last_under_load = ktime_get_coarse_boottime_ns();
	} else if (last_under_load) {
		under_load = !wg_birthdate_has_expired(last_under_load, 1);
		if (!under_load)
			last_under_load = 0;
	}
	mac_state = wg_cookie_validate_packet(&wg->cookie_checker, skb,
					      under_load);
	if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) ||
	    (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) {
		packet_needs_cookie = false;
	} else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) {
		packet_needs_cookie = true;
	} else {
		net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n",
					wg->dev->name, skb);
		return;
	}

	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): {
		struct message_handshake_initiation *message =
			(struct message_handshake_initiation *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_initiation(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		wg_packet_send_handshake_response(peer);
		break;
	}
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): {
		struct message_handshake_response *message =
			(struct message_handshake_response *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_response(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		if (wg_noise_handshake_begin_session(&peer->handshake,
						     &peer->keypairs)) {
			wg_timers_session_derived(peer);
			wg_timers_handshake_complete(peer);
			/* Calling this function will either send any existing
			 * packets in the queue and not send a keepalive, which
			 * is the best case, or, if there's nothing in the
			 * queue, it will send a keepalive, in order to give
			 * immediate confirmation of the session.
			 */
			wg_packet_send_keepalive(peer);
		}
		break;
	}
	}

	if (unlikely(!peer)) {
		WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n");
		return;
	}

	local_bh_disable();
	update_rx_stats(peer, skb->len);
	local_bh_enable();

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);
	wg_peer_put(peer);
}

void wg_packet_handshake_receive_worker(struct work_struct *work)
{
	struct wg_device *wg = container_of(work, struct multicore_worker,
					    work)->ptr;
	struct sk_buff *skb;

	while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
		wg_receive_handshake_packet(wg, skb);
		dev_kfree_skb(skb);
		cond_resched();
	}
}
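
/* If this side initiated the current keypair and it is approaching
 * REJECT_AFTER_TIME, fire one more handshake so the session can roll over
 * to a fresh keypair before the old one is rejected.
 */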

static void keep_key_fresh(struct wg_peer *peer)
{
	struct noise_keypair *keypair;
	bool send;

	if (peer->sent_lastminute_handshake)
		return;

	rcu_read_lock_bh();
	keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
	send = keypair && READ_ONCE(keypair->sending.is_valid) &&
	       keypair->i_am_the_initiator &&
	       wg_birthdate_has_expired(keypair->sending.birthdate,
			REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT);
	rcu_read_unlock_bh();

	if (unlikely(send)) {
		peer->sent_lastminute_handshake = true;
		wg_packet_send_queued_handshake_initiation(peer, false);
	}
}
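
/* Decryption happens in place over a scatterlist built on the stack; an
 * skb needing more scatterlist entries than the array provides is dropped
 * rather than linearized here.
 */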

static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
{
	struct scatterlist sg[MAX_SKB_FRAGS + 8];
	struct sk_buff *trailer;
	unsigned int offset;
	int num_frags;

	if (unlikely(!keypair))
		return false;

	if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
		     wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
		     keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
		WRITE_ONCE(keypair->receiving.is_valid, false);
		return false;
	}

	PACKET_CB(skb)->nonce =
		le64_to_cpu(((struct message_data *)skb->data)->counter);

	/* We ensure that the network header is part of the packet before we
	 * call skb_cow_data, so that there's no chance that data is removed
	 * from the skb, so that later we can extract the original endpoint.
	 */
	offset = skb->data - skb_network_header(skb);
	skb_push(skb, offset);
	num_frags = skb_cow_data(skb, 0, &trailer);
	offset += sizeof(struct message_data);
	skb_pull(skb, offset);
	if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
		return false;

	sg_init_table(sg, num_frags);
	if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
		return false;

	if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
						 PACKET_CB(skb)->nonce,
						 keypair->receiving.key))
		return false;

	/* Another ugly situation of pushing and pulling the header so as to
	 * keep endpoint information intact.
	 */
	skb_push(skb, offset);
	if (pskb_trim(skb, skb->len - noise_encrypted_len(0)))
		return false;
	skb_pull(skb, offset);

	return true;
}

/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
{
	unsigned long index, index_current, top, i;
	bool ret = false;

	spin_lock_bh(&counter->lock);

	if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
		     their_counter >= REJECT_AFTER_MESSAGES))
		goto out;

	++their_counter;

	if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
		     counter->counter))
		goto out;

	index = their_counter >> ilog2(BITS_PER_LONG);

	if (likely(their_counter > counter->counter)) {
		index_current = counter->counter >> ilog2(BITS_PER_LONG);
		top = min_t(unsigned long, index - index_current,
			    COUNTER_BITS_TOTAL / BITS_PER_LONG);
		for (i = 1; i <= top; ++i)
			counter->backtrack[(i + index_current) &
				((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
		counter->counter = their_counter;
	}

	index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
	ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
				&counter->backtrack[index]);

out:
	spin_unlock_bh(&counter->lock);
	return ret;
}
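
/* The backtrack array is a ring of machine words forming one sliding
 * bitmap of COUNTER_BITS_TOTAL bits. A received counter is incremented
 * first (so that zero means "nothing received yet") and then maps to word
 * (counter / BITS_PER_LONG) mod (COUNTER_BITS_TOTAL / BITS_PER_LONG), at
 * bit counter % BITS_PER_LONG; advancing the window only zeroes the words
 * that newly enter it, so no bits are ever shifted.
 */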

#ifdef DEBUG
#include "selftest/counter.c"
#endif

static void wg_packet_consume_data_done(struct wg_peer *peer,
					struct sk_buff *skb,
					struct endpoint *endpoint)
{
	struct net_device *dev = peer->device->dev;
	unsigned int len, len_before_trim;
	struct wg_peer *routed_peer;

	wg_socket_set_peer_endpoint(peer, endpoint);

	if (unlikely(wg_noise_received_with_keypair(&peer->keypairs,
						    PACKET_CB(skb)->keypair))) {
		wg_timers_handshake_complete(peer);
		wg_packet_send_staged_packets(peer);
	}

	keep_key_fresh(peer);

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);

	/* A packet with length 0 is a keepalive packet */
	if (unlikely(!skb->len)) {
		update_rx_stats(peer, message_data_len(0));
		net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n",
				    dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		goto packet_processed;
	}

	wg_timers_data_received(peer);

	if (unlikely(skb_network_header(skb) < skb->head))
		goto dishonest_packet_size;
	if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) &&
		       (ip_hdr(skb)->version == 4 ||
			(ip_hdr(skb)->version == 6 &&
			 pskb_network_may_pull(skb, sizeof(struct ipv6hdr)))))))
		goto dishonest_packet_type;

	skb->dev = dev;
	/* We've already verified the Poly1305 auth tag, which means this packet
	 * was not modified in transit. We can therefore tell the networking
	 * stack that all checksums of every layer of encapsulation have already
	 * been checked "by the hardware" and it is therefore unnecessary to
	 * check again in software.
	 */
	skb->ip_summed = CHECKSUM_UNNECESSARY;
	skb->csum_level = ~0; /* All levels */
	skb->protocol = ip_tunnel_parse_protocol(skb);
	if (skb->protocol == htons(ETH_P_IP)) {
		len = ntohs(ip_hdr(skb)->tot_len);
		if (unlikely(len < sizeof(struct iphdr)))
			goto dishonest_packet_size;
		INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ip_hdr(skb)->tos);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		len = ntohs(ipv6_hdr(skb)->payload_len) +
		      sizeof(struct ipv6hdr);
		INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ipv6_get_dsfield(ipv6_hdr(skb)));
	} else {
		goto dishonest_packet_type;
	}

	if (unlikely(len > skb->len))
		goto dishonest_packet_size;
	len_before_trim = skb->len;
	if (unlikely(pskb_trim(skb, len)))
		goto packet_processed;

	routed_peer = wg_allowedips_lookup_src(&peer->device->peer_allowedips,
					       skb);
	wg_peer_put(routed_peer); /* We don't need the extra reference. */

	if (unlikely(routed_peer != peer))
		goto dishonest_packet_peer;

	napi_gro_receive(&peer->napi, skb);
	update_rx_stats(peer, message_data_len(len_before_trim));
	return;

dishonest_packet_peer:
	net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
				dev->name, skb, peer->internal_id,
				&peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_type:
	net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_size:
	net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_length_errors;
	goto packet_processed;
packet_processed:
	dev_kfree_skb(skb);
}
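
/* NAPI poll function, run in the peer's softirq context. Packets are
 * peeked from the per-peer ring in arrival order, so the replay counter is
 * validated sequentially even though decryption itself ran in parallel on
 * other CPUs.
 */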

int wg_packet_rx_poll(struct napi_struct *napi, int budget)
{
	struct wg_peer *peer = container_of(napi, struct wg_peer, napi);
	struct crypt_queue *queue = &peer->rx_queue;
	struct noise_keypair *keypair;
	struct endpoint endpoint;
	enum packet_state state;
	struct sk_buff *skb;
	int work_done = 0;
	bool free;

	if (unlikely(budget <= 0))
		return 0;

	while ((skb = __ptr_ring_peek(&queue->ring)) != NULL &&
	       (state = atomic_read_acquire(&PACKET_CB(skb)->state)) !=
		       PACKET_STATE_UNCRYPTED) {
		__ptr_ring_discard_one(&queue->ring);
		peer = PACKET_PEER(skb);
		keypair = PACKET_CB(skb)->keypair;
		free = true;

		if (unlikely(state != PACKET_STATE_CRYPTED))
			goto next;

		if (unlikely(!counter_validate(&keypair->receiving_counter,
					       PACKET_CB(skb)->nonce))) {
			net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
					    peer->device->dev->name,
					    PACKET_CB(skb)->nonce,
					    keypair->receiving_counter.counter);
			goto next;
		}

		if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
			goto next;

		wg_reset_packet(skb, false);
		wg_packet_consume_data_done(peer, skb, &endpoint);
		free = false;

next:
		wg_noise_keypair_put(keypair, false);
		wg_peer_put(peer);
		if (unlikely(free))
			dev_kfree_skb(skb);

		if (++work_done >= budget)
			break;
	}

	if (work_done < budget)
		napi_complete_done(napi, work_done);

	return work_done;
}

void wg_packet_decrypt_worker(struct work_struct *work)
{
	struct crypt_queue *queue = container_of(work, struct multicore_worker,
						 work)->ptr;
	struct sk_buff *skb;

	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
		enum packet_state state =
			likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
				PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
		wg_queue_enqueue_per_peer_napi(skb, state);
		if (need_resched())
			cond_resched();
	}
}
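
/* Looks up the receiving keypair from the 32-bit receiver index in the
 * data message header, takes references on the keypair and its peer, and
 * fans the packet out to the decrypt workers.
 */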

static void wg_packet_consume_data(struct wg_device *wg, struct sk_buff *skb)
{
	__le32 idx = ((struct message_data *)skb->data)->key_idx;
	struct wg_peer *peer = NULL;
	int ret;

	rcu_read_lock_bh();
	PACKET_CB(skb)->keypair =
		(struct noise_keypair *)wg_index_hashtable_lookup(
			wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx,
			&peer);
	if (unlikely(!wg_noise_keypair_get(PACKET_CB(skb)->keypair)))
		goto err_keypair;

	if (unlikely(READ_ONCE(peer->is_dead)))
		goto err;

	ret = wg_queue_enqueue_per_device_and_peer(&wg->decrypt_queue,
						   &peer->rx_queue, skb,
						   wg->packet_crypt_wq,
						   &wg->decrypt_queue.last_cpu);
	if (unlikely(ret == -EPIPE))
		wg_queue_enqueue_per_peer_napi(skb, PACKET_STATE_DEAD);
	if (likely(!ret || ret == -EPIPE)) {
		rcu_read_unlock_bh();
		return;
	}
err:
	wg_noise_keypair_put(PACKET_CB(skb)->keypair, false);
err_keypair:
	rcu_read_unlock_bh();
	wg_peer_put(peer);
	dev_kfree_skb(skb);
}
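
/* Entry point from the UDP socket. After the outer framing is validated,
 * handshake messages are deferred to the handshake workqueue, while data
 * messages go straight into the decryption pipeline.
 */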

void wg_packet_receive(struct wg_device *wg, struct sk_buff *skb)
{
	if (unlikely(prepare_skb_header(skb, wg) < 0))
		goto err;
	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION):
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE):
	case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): {
		int cpu;

		if (skb_queue_len(&wg->incoming_handshakes) >
			    MAX_QUEUED_INCOMING_HANDSHAKES ||
		    unlikely(!rng_is_initialized())) {
			net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n",
						wg->dev->name, skb);
			goto err;
		}
		skb_queue_tail(&wg->incoming_handshakes, skb);
		/* Queues up a call to packet_process_queued_handshake_
		 * packets(skb):
		 */
		cpu = wg_cpumask_next_online(&wg->incoming_handshake_cpu);
		queue_work_on(cpu, wg->handshake_receive_wq,
			      &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
		break;
	}
	case cpu_to_le32(MESSAGE_DATA):
		PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
		wg_packet_consume_data(wg, skb);
		break;
	default:
		WARN(1, "Non-exhaustive parsing of packet header led to unknown packet type!\n");
		goto err;
	}
	return;

err:
	dev_kfree_skb(skb);
}