// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "queueing.h"
#include "device.h"
#include "peer.h"
#include "timers.h"
#include "messages.h"
#include "cookie.h"
#include "socket.h"

#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <net/ip_tunnels.h>
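/* Receive path overview: wg_packet_receive() validates and demultiplexes
 * incoming UDP payloads. Handshake messages are queued for
 * wg_packet_handshake_receive_worker(), while data messages flow through
 * wg_packet_consume_data() to the decryption queue and are finally delivered
 * in order by wg_packet_rx_poll() via NAPI.
 */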
/* Must be called with bh disabled. */
static void update_rx_stats(struct wg_peer *peer, size_t len)
{
	struct pcpu_sw_netstats *tstats =
		get_cpu_ptr(peer->device->dev->tstats);

	u64_stats_update_begin(&tstats->syncp);
	++tstats->rx_packets;
	tstats->rx_bytes += len;
	peer->rx_bytes += len;
	u64_stats_update_end(&tstats->syncp);
	put_cpu_ptr(tstats);
}
#define SKB_TYPE_LE32(skb) (((struct message_header *)(skb)->data)->type)
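/* Returns the expected header length for the message type claimed by the
 * packet, or 0 if the type or the skb length is not plausible.
 */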
static size_t validate_header_len(struct sk_buff *skb)
{
	if (unlikely(skb->len < sizeof(struct message_header)))
		return 0;
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) &&
	    skb->len >= MESSAGE_MINIMUM_LENGTH)
		return sizeof(struct message_data);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) &&
	    skb->len == sizeof(struct message_handshake_initiation))
		return sizeof(struct message_handshake_initiation);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) &&
	    skb->len == sizeof(struct message_handshake_response))
		return sizeof(struct message_handshake_response);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) &&
	    skb->len == sizeof(struct message_handshake_cookie))
		return sizeof(struct message_handshake_cookie);
	return 0;
}
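/* Strips the outer IP and UDP encapsulation in place and verifies that the
 * UDP length field, the skb length, and the WireGuard message header are
 * mutually consistent, so that later stages may trust skb->len.
 */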
static int prepare_skb_header(struct sk_buff *skb, struct wg_device *wg)
{
	size_t data_offset, data_len, header_len;
	struct udphdr *udp;

	if (unlikely(wg_skb_examine_untrusted_ip_hdr(skb) != skb->protocol ||
		     skb_transport_header(skb) < skb->head ||
		     (skb_transport_header(skb) + sizeof(struct udphdr)) >
			     skb_tail_pointer(skb)))
		return -EINVAL; /* Bogus IP header */
	udp = udp_hdr(skb);
	data_offset = (u8 *)udp - skb->data;
	if (unlikely(data_offset > U16_MAX ||
		     data_offset + sizeof(struct udphdr) > skb->len))
		/* Packet has offset at impossible location or isn't big enough
		 * to have UDP fields.
		 */
		return -EINVAL;
	data_len = ntohs(udp->len);
	if (unlikely(data_len < sizeof(struct udphdr) ||
		     data_len > skb->len - data_offset))
		/* UDP packet is reporting too small of a size or lying about
		 * its size.
		 */
		return -EINVAL;
	data_len -= sizeof(struct udphdr);
	data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
	if (unlikely(!pskb_may_pull(skb,
				data_offset + sizeof(struct message_header)) ||
		     pskb_trim(skb, data_len + data_offset) < 0))
		return -EINVAL;
	skb_pull(skb, data_offset);
	if (unlikely(skb->len != data_len))
		/* Final len does not agree with calculated len */
		return -EINVAL;
	header_len = validate_header_len(skb);
	if (unlikely(!header_len))
		return -EINVAL;
	__skb_push(skb, data_offset);
	if (unlikely(!pskb_may_pull(skb, data_offset + header_len)))
		return -EINVAL;
	__skb_pull(skb, data_offset);
	return 0;
}
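/* Handles a single queued handshake message: cookie replies are consumed
 * directly, while initiations and responses are first checked against the
 * cookie MAC state machine (possibly answering with a cookie reply when
 * under load) before being handed to the Noise handshake logic.
 */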
static void wg_receive_handshake_packet(struct wg_device *wg,
					struct sk_buff *skb)
{
	enum cookie_mac_state mac_state;
	struct wg_peer *peer = NULL;
	/* This is global, so that our load calculation applies to the whole
	 * system. We don't care about races with it at all.
	 */
	static u64 last_under_load;
	bool packet_needs_cookie;
	bool under_load;

	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) {
		net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n",
					wg->dev->name, skb);
		wg_cookie_message_consume(
			(struct message_handshake_cookie *)skb->data, wg);
		return;
	}

	under_load = skb_queue_len(&wg->incoming_handshakes) >=
		     MAX_QUEUED_INCOMING_HANDSHAKES / 8;
	if (under_load)
		last_under_load = ktime_get_coarse_boottime_ns();
	else if (last_under_load)
		under_load = !wg_birthdate_has_expired(last_under_load, 1);
	mac_state = wg_cookie_validate_packet(&wg->cookie_checker, skb,
					      under_load);
	if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) ||
	    (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) {
		packet_needs_cookie = false;
	} else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) {
		packet_needs_cookie = true;
	} else {
		net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n",
					wg->dev->name, skb);
		return;
	}

	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): {
		struct message_handshake_initiation *message =
			(struct message_handshake_initiation *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_initiation(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		wg_packet_send_handshake_response(peer);
		break;
	}
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): {
		struct message_handshake_response *message =
			(struct message_handshake_response *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_response(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		if (wg_noise_handshake_begin_session(&peer->handshake,
						     &peer->keypairs)) {
			wg_timers_session_derived(peer);
			wg_timers_handshake_complete(peer);
			/* Calling this function will either send any existing
			 * packets in the queue and not send a keepalive, which
			 * is the best case, or, if there's nothing in the
			 * queue, it will send a keepalive, in order to give
			 * immediate confirmation of the session.
			 */
			wg_packet_send_keepalive(peer);
		}
		break;
	}
	}

	if (unlikely(!peer)) {
		WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n");
		return;
	}

	local_bh_disable();
	update_rx_stats(peer, skb->len);
	local_bh_enable();

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);
	wg_peer_put(peer);
}
void wg_packet_handshake_receive_worker(struct work_struct *work)
{
	struct wg_device *wg = container_of(work, struct multicore_worker,
					    work)->ptr;
	struct sk_buff *skb;

	while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
		wg_receive_handshake_packet(wg, skb);
		dev_kfree_skb(skb);
		cond_resched();
	}
}
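/* If we are the initiator of the current keypair and it is approaching
 * REJECT_AFTER_TIME, opportunistically kick off a new handshake so the
 * session can be renewed before it expires.
 */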
static void keep_key_fresh(struct wg_peer *peer)
{
	struct noise_keypair *keypair;
	bool send = false;

	if (peer->sent_lastminute_handshake)
		return;

	rcu_read_lock_bh();
	keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
	if (likely(keypair && READ_ONCE(keypair->sending.is_valid)) &&
	    keypair->i_am_the_initiator &&
	    unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
			REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT)))
		send = true;
	rcu_read_unlock_bh();

	if (send) {
		peer->sent_lastminute_handshake = true;
		wg_packet_send_queued_handshake_initiation(peer, false);
	}
}
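/* Decrypts the packet in place with ChaCha20-Poly1305 using the receiving
 * key, records the transmitted counter in PACKET_CB(skb)->nonce for later
 * replay checking, and trims the authentication tag off the tail.
 */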
static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
{
	struct scatterlist sg[MAX_SKB_FRAGS + 8];
	struct sk_buff *trailer;
	unsigned int offset;
	int num_frags;

	if (unlikely(!key))
		return false;

	if (unlikely(!READ_ONCE(key->is_valid) ||
		  wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
		  key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
		WRITE_ONCE(key->is_valid, false);
		return false;
	}

	PACKET_CB(skb)->nonce =
		le64_to_cpu(((struct message_data *)skb->data)->counter);

	/* We ensure that the network header is part of the packet before we
	 * call skb_cow_data, so that there's no chance that data is removed
	 * from the skb, so that later we can extract the original endpoint.
	 */
	offset = skb->data - skb_network_header(skb);
	skb_push(skb, offset);
	num_frags = skb_cow_data(skb, 0, &trailer);
	offset += sizeof(struct message_data);
	skb_pull(skb, offset);
	if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
		return false;

	sg_init_table(sg, num_frags);
	if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
		return false;

	if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
						 PACKET_CB(skb)->nonce,
						 key->key))
		return false;

	/* Another ugly situation of pushing and pulling the header so as to
	 * keep endpoint information intact.
	 */
	skb_push(skb, offset);
	if (pskb_trim(skb, skb->len - noise_encrypted_len(0)))
		return false;
	skb_pull(skb, offset);

	return true;
}
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
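/* The window is a circular bitmap of COUNTER_BITS_TOTAL bits stored in the
 * backtrack[] words. For example, assuming BITS_PER_LONG == 64 and
 * COUNTER_BITS_TOTAL == 8192 (128 words): a received counter of 1000 becomes
 * 1001 after the increment below, maps to word (1001 >> 6) & 127 == 15 and
 * bit 1001 & 63 == 41, and test_and_set_bit() on that position rejects any
 * replay of the same counter.
 */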
static bool counter_validate(union noise_counter *counter, u64 their_counter)
{
	unsigned long index, index_current, top, i;
	bool ret = false;

	spin_lock_bh(&counter->receive.lock);
	if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
		     their_counter >= REJECT_AFTER_MESSAGES))
		goto out;

	++their_counter;

	if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
		     counter->receive.counter))
		goto out;

	index = their_counter >> ilog2(BITS_PER_LONG);

	if (likely(their_counter > counter->receive.counter)) {
		index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
		top = min_t(unsigned long, index - index_current,
			    COUNTER_BITS_TOTAL / BITS_PER_LONG);
		for (i = 1; i <= top; ++i)
			counter->receive.backtrack[(i + index_current) &
				((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
		counter->receive.counter = their_counter;
	}

	index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
	ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
				&counter->receive.backtrack[index]);

out:
	spin_unlock_bh(&counter->receive.lock);
	return ret;
}
#include "selftest/counter.c"
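/* Runs after successful decryption and replay validation: updates timers and
 * the peer's endpoint, sanity-checks the inner IPv4/IPv6 header, enforces
 * cryptokey routing via wg_allowedips_lookup_src(), and hands the plaintext
 * packet to the stack with napi_gro_receive().
 */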
static void wg_packet_consume_data_done(struct wg_peer *peer,
					struct sk_buff *skb,
					struct endpoint *endpoint)
{
	struct net_device *dev = peer->device->dev;
	unsigned int len, len_before_trim;
	struct wg_peer *routed_peer;

	wg_socket_set_peer_endpoint(peer, endpoint);

	if (unlikely(wg_noise_received_with_keypair(&peer->keypairs,
						    PACKET_CB(skb)->keypair))) {
		wg_timers_handshake_complete(peer);
		wg_packet_send_staged_packets(peer);
	}

	keep_key_fresh(peer);

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);

	/* A packet with length 0 is a keepalive packet */
	if (unlikely(!skb->len)) {
		update_rx_stats(peer, message_data_len(0));
		net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n",
				    dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		goto packet_processed;
	}

	wg_timers_data_received(peer);

	if (unlikely(skb_network_header(skb) < skb->head))
		goto dishonest_packet_size;
	if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) &&
		       (ip_hdr(skb)->version == 4 ||
			(ip_hdr(skb)->version == 6 &&
			 pskb_network_may_pull(skb, sizeof(struct ipv6hdr)))))))
		goto dishonest_packet_type;

	skb->dev = dev;
	/* We've already verified the Poly1305 auth tag, which means this packet
	 * was not modified in transit. We can therefore tell the networking
	 * stack that all checksums of every layer of encapsulation have already
	 * been checked "by the hardware" and therefore don't need to be checked
	 * again in software.
	 */
	skb->ip_summed = CHECKSUM_UNNECESSARY;
	skb->csum_level = ~0; /* All levels */
	skb->protocol = wg_skb_examine_untrusted_ip_hdr(skb);
	if (skb->protocol == htons(ETH_P_IP)) {
		len = ntohs(ip_hdr(skb)->tot_len);
		if (unlikely(len < sizeof(struct iphdr)))
			goto dishonest_packet_size;
		if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
			IP_ECN_set_ce(ip_hdr(skb));
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		len = ntohs(ipv6_hdr(skb)->payload_len) +
		      sizeof(struct ipv6hdr);
		if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
			IP6_ECN_set_ce(skb, ipv6_hdr(skb));
	} else {
		goto dishonest_packet_type;
	}

	if (unlikely(len > skb->len))
		goto dishonest_packet_size;
	len_before_trim = skb->len;
	if (unlikely(pskb_trim(skb, len)))
		goto packet_processed;

	routed_peer = wg_allowedips_lookup_src(&peer->device->peer_allowedips,
					       skb);
	wg_peer_put(routed_peer); /* We don't need the extra reference. */

	if (unlikely(routed_peer != peer))
		goto dishonest_packet_peer;

	if (unlikely(napi_gro_receive(&peer->napi, skb) == GRO_DROP)) {
		++dev->stats.rx_dropped;
		net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n",
				    dev->name, peer->internal_id,
				    &peer->endpoint.addr);
	} else {
		update_rx_stats(peer, message_data_len(len_before_trim));
	}
	return;

dishonest_packet_peer:
	net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
				dev->name, skb, peer->internal_id,
				&peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_type:
	net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_size:
	net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_length_errors;
	goto packet_processed;
packet_processed:
	dev_kfree_skb(skb);
}
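/* NAPI poll function: drains the peer's rx_queue in order, discarding
 * packets whose decryption failed or whose nonce does not pass
 * counter_validate(), and delivering the rest via
 * wg_packet_consume_data_done().
 */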
int wg_packet_rx_poll(struct napi_struct *napi, int budget)
{
	struct wg_peer *peer = container_of(napi, struct wg_peer, napi);
	struct crypt_queue *queue = &peer->rx_queue;
	struct noise_keypair *keypair;
	struct endpoint endpoint;
	enum packet_state state;
	struct sk_buff *skb;
	int work_done = 0;
	bool free;

	if (unlikely(budget <= 0))
		return 0;

	while ((skb = __ptr_ring_peek(&queue->ring)) != NULL &&
	       (state = atomic_read_acquire(&PACKET_CB(skb)->state)) !=
		       PACKET_STATE_UNCRYPTED) {
		__ptr_ring_discard_one(&queue->ring);
		peer = PACKET_PEER(skb);
		keypair = PACKET_CB(skb)->keypair;
		free = true;

		if (unlikely(state != PACKET_STATE_CRYPTED))
			goto next;

		if (unlikely(!counter_validate(&keypair->receiving.counter,
					       PACKET_CB(skb)->nonce))) {
			net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
					    peer->device->dev->name,
					    PACKET_CB(skb)->nonce,
					    keypair->receiving.counter.receive.counter);
			goto next;
		}

		if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
			goto next;

		wg_reset_packet(skb);
		wg_packet_consume_data_done(peer, skb, &endpoint);
		free = false;

next:
		wg_noise_keypair_put(keypair, false);
		wg_peer_put(peer);
		if (unlikely(free))
			dev_kfree_skb(skb);

		if (++work_done >= budget)
			break;
	}

	if (work_done < budget)
		napi_complete_done(napi, work_done);

	return work_done;
}
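/* Per-CPU workqueue function that decrypts packets from the device-wide
 * decryption queue and then marks them for in-order delivery by the peer's
 * NAPI poll routine.
 */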
void wg_packet_decrypt_worker(struct work_struct *work)
{
	struct crypt_queue *queue = container_of(work, struct multicore_worker,
						 work)->ptr;
	struct sk_buff *skb;

	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
		enum packet_state state = likely(decrypt_packet(skb,
				&PACKET_CB(skb)->keypair->receiving)) ?
				PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
		wg_queue_enqueue_per_peer_napi(skb, state);
	}
}
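/* Looks up the keypair referenced by the receiver index in the data message,
 * takes references on it and on the owning peer, and enqueues the packet for
 * parallel decryption followed by serial per-peer delivery.
 */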
static void wg_packet_consume_data(struct wg_device *wg, struct sk_buff *skb)
{
	__le32 idx = ((struct message_data *)skb->data)->key_idx;
	struct wg_peer *peer = NULL;
	int ret;

	rcu_read_lock_bh();
	PACKET_CB(skb)->keypair =
		(struct noise_keypair *)wg_index_hashtable_lookup(
			wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx,
			&peer);
	if (unlikely(!wg_noise_keypair_get(PACKET_CB(skb)->keypair)))
		goto err_keypair;

	if (unlikely(READ_ONCE(peer->is_dead)))
		goto err;

	ret = wg_queue_enqueue_per_device_and_peer(&wg->decrypt_queue,
						   &peer->rx_queue, skb,
						   wg->packet_crypt_wq,
						   &wg->decrypt_queue.last_cpu);
	if (unlikely(ret == -EPIPE))
		wg_queue_enqueue_per_peer_napi(skb, PACKET_STATE_DEAD);
	if (likely(!ret || ret == -EPIPE)) {
		rcu_read_unlock_bh();
		return;
	}
err:
	wg_noise_keypair_put(PACKET_CB(skb)->keypair, false);
err_keypair:
	rcu_read_unlock_bh();
	wg_peer_put(peer);
	dev_kfree_skb(skb);
}
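/* Entry point from the UDP socket: after header validation, handshake
 * messages are queued to the handshake workqueue (and dropped when the queue
 * is full or the RNG is not yet seeded), while data messages go straight to
 * the decryption path.
 */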
void wg_packet_receive(struct wg_device *wg, struct sk_buff *skb)
{
	if (unlikely(prepare_skb_header(skb, wg) < 0))
		goto err;
	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION):
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE):
	case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): {
		int cpu;

		if (skb_queue_len(&wg->incoming_handshakes) >
			    MAX_QUEUED_INCOMING_HANDSHAKES ||
		    unlikely(!rng_is_initialized())) {
			net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n",
						wg->dev->name, skb);
			goto err;
		}
		skb_queue_tail(&wg->incoming_handshakes, skb);
		/* Queues up a call to packet_process_queued_handshake_
		 * packets(skb):
		 */
		cpu = wg_cpumask_next_online(&wg->incoming_handshake_cpu);
		queue_work_on(cpu, wg->handshake_receive_wq,
			&per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
		break;
	}
	case cpu_to_le32(MESSAGE_DATA):
		PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
		wg_packet_consume_data(wg, skb);
		break;
	default:
		net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n",
					wg->dev->name, skb);
		goto err;
	}
	return;

err:
	dev_kfree_skb(skb);
}