// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "device.h"
#include "peer.h"
#include "socket.h"
#include "queueing.h"
#include "messages.h"

#include <linux/ctype.h>
#include <linux/net.h>
#include <linux/if_vlan.h>
#include <linux/if_ether.h>
#include <linux/inetdevice.h>
#include <net/udp_tunnel.h>
#include <net/ipv6.h>

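/* Transmit one skb to an IPv4 endpoint: build a flow from the peer's cached
 * source and destination, route it (consulting the optional per-peer dst
 * cache when one is passed), and hand it to udp_tunnel_xmit_skb(). ds becomes
 * the TOS byte of the outer header.
 */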
static int send4(struct wg_device *wg, struct sk_buff *skb,
		 struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
{
	struct flowi4 fl = {
		.saddr = endpoint->src4.s_addr,
		.daddr = endpoint->addr4.sin_addr.s_addr,
		.fl4_dport = endpoint->addr4.sin_port,
		.flowi4_mark = wg->fwmark,
		.flowi4_proto = IPPROTO_UDP
	};
	struct rtable *rt = NULL;
	struct sock *sock;
	int ret = 0;

	skb_mark_not_on_list(skb);
	skb->dev = wg->dev;
	skb->mark = wg->fwmark;

	rcu_read_lock_bh();
	sock = rcu_dereference_bh(wg->sock4);

	if (unlikely(!sock)) {
		ret = -ENONET;
		goto err;
	}

	fl.fl4_sport = inet_sk(sock)->inet_sport;

	if (cache)
		rt = dst_cache_get_ip4(cache, &fl.saddr);

	if (!rt) {
		security_sk_classify_flow(sock, flowi4_to_flowi_common(&fl));
		if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0,
						fl.saddr, RT_SCOPE_HOST))) {
			endpoint->src4.s_addr = 0;
			endpoint->src_if4 = 0;
			fl.saddr = 0;
			if (cache)
				dst_cache_reset(cache);
		}
		rt = ip_route_output_flow(sock_net(sock), &fl, sock);
		if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) &&
			     PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) &&
			     rt->dst.dev->ifindex != endpoint->src_if4)))) {
			endpoint->src4.s_addr = 0;
			endpoint->src_if4 = 0;
			fl.saddr = 0;
			if (cache)
				dst_cache_reset(cache);
			if (!IS_ERR(rt))
				ip_rt_put(rt);
			rt = ip_route_output_flow(sock_net(sock), &fl, sock);
		}
		if (unlikely(IS_ERR(rt))) {
			ret = PTR_ERR(rt);
			net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
					    wg->dev->name, &endpoint->addr, ret);
			goto err;
		}
		if (cache)
			dst_cache_set_ip4(cache, &rt->dst, fl.saddr);
	}

	skb->ignore_df = 1;
	udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
			    ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport,
			    fl.fl4_dport, false, false);
	goto out;

err:
	kfree_skb(skb);
out:
	rcu_read_unlock_bh();
	return ret;
}

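/* IPv6 counterpart of send4(). The route lookup goes through ipv6_stub so
 * this compiles and loads regardless of how IPv6 is built; with CONFIG_IPV6
 * disabled the function collapses to kfree_skb() and -EAFNOSUPPORT.
 */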
static int send6(struct wg_device *wg, struct sk_buff *skb,
		 struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
{
#if IS_ENABLED(CONFIG_IPV6)
	struct flowi6 fl = {
		.saddr = endpoint->src6,
		.daddr = endpoint->addr6.sin6_addr,
		.fl6_dport = endpoint->addr6.sin6_port,
		.flowi6_mark = wg->fwmark,
		.flowi6_oif = endpoint->addr6.sin6_scope_id,
		.flowi6_proto = IPPROTO_UDP
		/* TODO: addr->sin6_flowinfo */
	};
	struct dst_entry *dst = NULL;
	struct sock *sock;
	int ret = 0;

	skb_mark_not_on_list(skb);
	skb->dev = wg->dev;
	skb->mark = wg->fwmark;

	rcu_read_lock_bh();
	sock = rcu_dereference_bh(wg->sock6);

	if (unlikely(!sock)) {
		ret = -ENONET;
		goto err;
	}

	fl.fl6_sport = inet_sk(sock)->inet_sport;

	if (cache)
		dst = dst_cache_get_ip6(cache, &fl.saddr);

	if (!dst) {
		security_sk_classify_flow(sock, flowi6_to_flowi_common(&fl));
		if (unlikely(!ipv6_addr_any(&fl.saddr) &&
			     !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) {
			endpoint->src6 = fl.saddr = in6addr_any;
			if (cache)
				dst_cache_reset(cache);
		}
		dst = ipv6_stub->ipv6_dst_lookup_flow(sock_net(sock), sock, &fl,
						      NULL);
		if (unlikely(IS_ERR(dst))) {
			ret = PTR_ERR(dst);
			net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
					    wg->dev->name, &endpoint->addr, ret);
			goto err;
		}
		if (cache)
			dst_cache_set_ip6(cache, dst, &fl.saddr);
	}

	skb->ignore_df = 1;
	udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds,
			     ip6_dst_hoplimit(dst), 0, fl.fl6_sport,
			     fl.fl6_dport, false);
	goto out;

err:
	kfree_skb(skb);
out:
	rcu_read_unlock_bh();
	return ret;
#else
	kfree_skb(skb);
	return -EAFNOSUPPORT;
#endif
}

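/* Dispatch an skb to send4() or send6() according to the address family of
 * the peer's current endpoint, read under endpoint_lock. tx_bytes is only
 * credited on success, using the length captured before the skb is consumed.
 */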
int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds)
{
	size_t skb_len = skb->len;
	int ret = -EAFNOSUPPORT;

	read_lock_bh(&peer->endpoint_lock);
	if (peer->endpoint.addr.sa_family == AF_INET)
		ret = send4(peer->device, skb, &peer->endpoint, ds,
			    &peer->endpoint_cache);
	else if (peer->endpoint.addr.sa_family == AF_INET6)
		ret = send6(peer->device, skb, &peer->endpoint, ds,
			    &peer->endpoint_cache);
	else
		dev_kfree_skb(skb);
	if (likely(!ret))
		peer->tx_bytes += skb_len;
	read_unlock_bh(&peer->endpoint_lock);

	return ret;
}

int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *buffer,
				  size_t len, u8 ds)
{
	struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);

	if (unlikely(!skb))
		return -ENOMEM;

	skb_reserve(skb, SKB_HEADER_LEN);
	skb_set_inner_network_header(skb, 0);
	skb_put_data(skb, buffer, len);
	return wg_socket_send_skb_to_peer(peer, skb, ds);
}

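/* Reply to the transport address an skb arrived from without needing a
 * struct wg_peer, e.g. for cookie replies to not-yet-authenticated sources.
 * No dst cache is passed, since there is no peer to own one.
 */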
int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg,
					  struct sk_buff *in_skb, void *buffer,
					  size_t len)
{
	int ret = 0;
	struct sk_buff *skb;
	struct endpoint endpoint;

	if (unlikely(!in_skb))
		return -EINVAL;
	ret = wg_socket_endpoint_from_skb(&endpoint, in_skb);
	if (unlikely(ret < 0))
		return ret;

	skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);
	if (unlikely(!skb))
		return -ENOMEM;
	skb_reserve(skb, SKB_HEADER_LEN);
	skb_set_inner_network_header(skb, 0);
	skb_put_data(skb, buffer, len);

	if (endpoint.addr.sa_family == AF_INET)
		ret = send4(wg, skb, &endpoint, 0, NULL);
	else if (endpoint.addr.sa_family == AF_INET6)
		ret = send6(wg, skb, &endpoint, 0, NULL);
	/* No other possibilities if the endpoint is valid, which it is,
	 * as we checked above.
	 */

	return ret;
}

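/* Record both sides of a received packet's addressing: the remote address
 * and port, plus the local address and interface it arrived on, so replies
 * can be sent from the same address the peer originally contacted.
 */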
int wg_socket_endpoint_from_skb(struct endpoint *endpoint,
				const struct sk_buff *skb)
{
	memset(endpoint, 0, sizeof(*endpoint));
	if (skb->protocol == htons(ETH_P_IP)) {
		endpoint->addr4.sin_family = AF_INET;
		endpoint->addr4.sin_port = udp_hdr(skb)->source;
		endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr;
		endpoint->src4.s_addr = ip_hdr(skb)->daddr;
		endpoint->src_if4 = skb->skb_iif;
	} else if (IS_ENABLED(CONFIG_IPV6) && skb->protocol == htons(ETH_P_IPV6)) {
		endpoint->addr6.sin6_family = AF_INET6;
		endpoint->addr6.sin6_port = udp_hdr(skb)->source;
		endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr;
		endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(
			&ipv6_hdr(skb)->saddr, skb->skb_iif);
		endpoint->src6 = ipv6_hdr(skb)->daddr;
	} else {
		return -EINVAL;
	}
	return 0;
}

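/* Compare every identifying field of two endpoints, including the saved
 * local source address and interface, so that any change is picked up by
 * wg_socket_set_peer_endpoint(). Two unset endpoints compare as equal.
 */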
static bool endpoint_eq(const struct endpoint *a, const struct endpoint *b)
{
	return (a->addr.sa_family == AF_INET && b->addr.sa_family == AF_INET &&
		a->addr4.sin_port == b->addr4.sin_port &&
		a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr &&
		a->src4.s_addr == b->src4.s_addr && a->src_if4 == b->src_if4) ||
	       (a->addr.sa_family == AF_INET6 &&
		b->addr.sa_family == AF_INET6 &&
		a->addr6.sin6_port == b->addr6.sin6_port &&
		ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) &&
		a->addr6.sin6_scope_id == b->addr6.sin6_scope_id &&
		ipv6_addr_equal(&a->src6, &b->src6)) ||
	       unlikely(!a->addr.sa_family && !b->addr.sa_family);
}

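/* Update the peer's endpoint to a newly observed address, typically from the
 * receive path, which is what makes roaming work. The cached route is
 * dropped whenever the endpoint actually changes.
 */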
void wg_socket_set_peer_endpoint(struct wg_peer *peer,
				 const struct endpoint *endpoint)
{
	/* First we check unlocked, in order to optimize, since it's pretty rare
	 * that an endpoint will change. If we happen to be mid-write, and two
	 * CPUs wind up writing the same thing or something slightly different,
	 * it doesn't really matter much either.
	 */
	if (endpoint_eq(endpoint, &peer->endpoint))
		return;
	write_lock_bh(&peer->endpoint_lock);
	if (endpoint->addr.sa_family == AF_INET) {
		peer->endpoint.addr4 = endpoint->addr4;
		peer->endpoint.src4 = endpoint->src4;
		peer->endpoint.src_if4 = endpoint->src_if4;
	} else if (IS_ENABLED(CONFIG_IPV6) && endpoint->addr.sa_family == AF_INET6) {
		peer->endpoint.addr6 = endpoint->addr6;
		peer->endpoint.src6 = endpoint->src6;
	} else {
		goto out;
	}
	dst_cache_reset(&peer->endpoint_cache);
out:
	write_unlock_bh(&peer->endpoint_lock);
}

void wg_socket_set_peer_endpoint_from_skb(struct wg_peer *peer,
					  const struct sk_buff *skb)
{
	struct endpoint endpoint;

	if (!wg_socket_endpoint_from_skb(&endpoint, skb))
		wg_socket_set_peer_endpoint(peer, &endpoint);
}

void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer)
{
	write_lock_bh(&peer->endpoint_lock);
	memset(&peer->endpoint.src6, 0, sizeof(peer->endpoint.src6));
	dst_cache_reset_now(&peer->endpoint_cache);
	write_unlock_bh(&peer->endpoint_lock);
}

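/* encap_rcv hook installed on our UDP sockets: each datagram arriving on the
 * WireGuard port lands here in softirq context and is handed off to
 * wg_packet_receive(). Returning 0 tells the UDP layer the skb was consumed.
 */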
static int wg_receive(struct sock *sk, struct sk_buff *skb)
{
	struct wg_device *wg;

	if (unlikely(!sk))
		goto err;
	wg = sk->sk_user_data;
	if (unlikely(!wg))
		goto err;
	skb_mark_not_on_list(skb);
	wg_packet_receive(wg, skb);
	return 0;

err:
	kfree_skb(skb);
	return 0;
}

static void sock_free(struct sock *sock)
{
	if (unlikely(!sock))
		return;
	sk_clear_memalloc(sock);
	udp_tunnel_sock_release(sock->sk_socket);
}

static void set_sock_opts(struct socket *sock)
{
	sock->sk->sk_allocation = GFP_ATOMIC;
	sock->sk->sk_sndbuf = INT_MAX;
	sk_set_memalloc(sock->sk);
}

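/* Create the IPv4/IPv6 UDP socket pair in the device's creating namespace.
 * With port 0, the kernel picks an ephemeral port for IPv4 and that same
 * port is then requested for IPv6; if another socket grabs it first, both
 * sockets are torn down and the whole sequence is retried.
 */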
int wg_socket_init(struct wg_device *wg, u16 port)
{
	struct net *net;
	int ret;
	struct udp_tunnel_sock_cfg cfg = {
		.sk_user_data = wg,
		.encap_type = 1,
		.encap_rcv = wg_receive
	};
	struct socket *new4 = NULL, *new6 = NULL;
	struct udp_port_cfg port4 = {
		.family = AF_INET,
		.local_ip.s_addr = htonl(INADDR_ANY),
		.local_udp_port = htons(port),
		.use_udp_checksums = true
	};
#if IS_ENABLED(CONFIG_IPV6)
	int retries = 0;
	struct udp_port_cfg port6 = {
		.family = AF_INET6,
		.local_ip6 = IN6ADDR_ANY_INIT,
		.use_udp6_tx_checksums = true,
		.use_udp6_rx_checksums = true,
		.ipv6_v6only = true
	};
#endif

	rcu_read_lock();
	net = rcu_dereference(wg->creating_net);
	net = net ? maybe_get_net(net) : NULL;
	rcu_read_unlock();
	if (unlikely(!net))
		return -ENONET;

#if IS_ENABLED(CONFIG_IPV6)
retry:
#endif

	ret = udp_sock_create(net, &port4, &new4);
	if (ret < 0) {
		pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
		goto out;
	}
	set_sock_opts(new4);
	setup_udp_tunnel_sock(net, new4, &cfg);

#if IS_ENABLED(CONFIG_IPV6)
	if (ipv6_mod_enabled()) {
		port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
		ret = udp_sock_create(net, &port6, &new6);
		if (ret < 0) {
			udp_tunnel_sock_release(new4);
			if (ret == -EADDRINUSE && !port && retries++ < 100)
				goto retry;
			pr_err("%s: Could not create IPv6 socket\n",
			       wg->dev->name);
			goto out;
		}
		set_sock_opts(new6);
		setup_udp_tunnel_sock(net, new6, &cfg);
	}
#endif

	wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
	ret = 0;
out:
	put_net(net);
	return ret;
}

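/* Publish new UDP sockets (either may be NULL) under the update mutex, wait
 * out in-flight RCU readers with synchronize_net(), then release the old
 * pair via sock_free().
 */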
void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
		      struct sock *new6)
{
	struct sock *old4, *old6;

	mutex_lock(&wg->socket_update_lock);
	old4 = rcu_dereference_protected(wg->sock4,
				lockdep_is_held(&wg->socket_update_lock));
	old6 = rcu_dereference_protected(wg->sock6,
				lockdep_is_held(&wg->socket_update_lock));
	rcu_assign_pointer(wg->sock4, new4);
	rcu_assign_pointer(wg->sock6, new6);
	if (new4)
		wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
	mutex_unlock(&wg->socket_update_lock);
	synchronize_net();
	sock_free(old4);
	sock_free(old6);
}