// SPDX-License-Identifier: GPL-2.0
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>
#include <trace/events/sock.h>
#if IS_ENABLED(CONFIG_IPV6)
#include <net/ipv6_stubs.h>
#endif
#include <net/hotdata.h>
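
/* Overview of the framing this file implements, per RFC 8229 (TCP
 * encapsulation of IKE and ESP).  Every message on the stream carries a
 * 16-bit big-endian length prefix whose value includes the two length
 * bytes themselves:
 *
 *	+-----------------+----------------------------+---------+
 *	| length (2B, BE) | non-ESP marker (4B) or SPI | payload |
 *	+-----------------+----------------------------+---------+
 *
 * A zero 32-bit "non-ESP marker" after the length means the payload is
 * an IKE message destined for the userspace IKE daemon; any other value
 * is the SPI of an ESP packet, which is handed back to the XFRM input
 * path.
 */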
static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
			  struct sock *sk)
{
	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
		XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	skb_set_owner_r(skb, sk);

	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&ctx->ike_queue, skb);
	ctx->saved_data_ready(sk);
}
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;

	skb_reset_transport_header(skb);

	/* restore IP CB, we need at least IP6CB->nhoff */
	memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));

	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	else
#endif
		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}
static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	int len = rxm->full_len - 2;
	u32 nonesp_marker;
	int err;

	/* keepalive packet? */
	if (unlikely(len == 1)) {
		u8 data;

		err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
		if (err < 0) {
			XFRM_INC_STATS(sock_net(strp->sk),
				       LINUX_MIB_XFRMINHDRERROR);
			kfree_skb(skb);
			return;
		}

		if (data == 0xff) {
			kfree_skb(skb);
			return;
		}
	}

	/* drop other short messages */
	if (unlikely(len <= sizeof(nonesp_marker))) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	/* remove header, leave non-ESP marker/SPI */
	if (!pskb_pull(skb, rxm->offset + 2)) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}
static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 2)
		return -EINVAL;

	return len;
}
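
/* The strparser contract for the parse callback: return 0 when more
 * data is needed to determine the message length, a negative errno on
 * a malformed stream, or the total message length (here, the value of
 * the 2-byte prefix, which already counts the prefix itself).
 */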
static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int flags, int *addr_len)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff *skb;
	int err = 0;
	int copied;
	int off = 0;

	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
	if (!skb) {
		if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
			return 0;
		return err;
	}

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		return err;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);

	return copied;
}
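
/* Note that recvmsg() on an espintcp socket only ever returns IKE
 * messages queued by handle_nonesp(); decapsulated ESP packets never
 * reach userspace, they are consumed by the XFRM stack.
 */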
int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (skb_queue_len(&ctx->out_queue) >=
	    READ_ONCE(net_hotdata.max_backlog))
		return -ENOBUFS;

	__skb_queue_tail(&ctx->out_queue, skb);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);
/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
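/* Worked example: the largest value the 2-byte field can carry is
 * 65535, and it must cover the field itself, so the largest payload a
 * single espintcp message can frame is 65535 - 2 = 65533 bytes.
 */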
static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
				   int flags)
{
	do {
		int ret;

		ret = skb_send_sock_locked(sk, emsg->skb,
					   emsg->offset, emsg->len);
		if (ret < 0)
			return ret;

		emsg->len -= ret;
		emsg->offset += ret;
	} while (emsg->len > 0);

	kfree_skb(emsg->skb);
	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct msghdr msghdr = {
		.msg_flags = flags | MSG_SPLICE_PAGES | MSG_MORE,
	};
	struct sk_msg *skmsg = &emsg->skmsg;
	bool more = flags & MSG_MORE;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		struct bio_vec bvec;
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		emsg->offset = 0;

		if (sg_is_last(sg) && !more)
			msghdr.msg_flags &= ~MSG_MORE;

		p = sg_page(sg);
retry:
		bvec_set_page(&bvec, p, size, offset);
		iov_iter_bvec(&msghdr.msg_iter, ITER_SOURCE, &bvec, 1, size);
		ret = tcp_sendmsg_locked(sk, &msghdr, size);
		if (ret < 0) {
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
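
/* The sk_msg path above splices the sk_msg pages directly into the TCP
 * stream via MSG_SPLICE_PAGES rather than copying them, and retries a
 * fragment whenever tcp_sendmsg_locked() makes only partial progress.
 */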
static int espintcp_push_msgs(struct sock *sk, int flags)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, flags);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, flags);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		return flags & MSG_DONTWAIT ? -EAGAIN : 0;
	}
	if (err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}
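
/* espintcp_push_msgs() may make partial progress: on -EAGAIN with a
 * blocking caller it reports success and leaves the remainder in
 * ctx->partial, to be retried from espintcp_tx_work() once the socket
 * has write space again.
 */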
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	espintcp_push_msgs(sk, 0);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	espintcp_push_msgs(sk, 0);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);
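
/* espintcp_push_skb() and espintcp_queue_out() are the entry points
 * used by the ESP output path: the esp4/esp6 code pushes the skb
 * directly when it holds the socket lock, and queues it to out_queue
 * (flushed from the release_cb below) when the socket is owned by user
 * context.
 */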
static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;
	char buf[2];
	int err, end;

	if (msg->msg_flags & ~MSG_DONTWAIT)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	if (err < 0) {
		if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
			err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* only -ENOMEM is possible since we don't coalesce */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, ITER_SOURCE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	/* this message could be partially sent, keep it */

	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}
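
/* sendmsg() transmits exactly one framed message per call: the kernel
 * prepends the 2-byte length field (size + 2) itself, while the IKE
 * daemon's payload keeps the 4-byte zero non-ESP marker, mirroring
 * what recvmsg() delivers on the receive side.
 */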
static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
static struct proto espintcp6_prot;
static struct proto_ops espintcp6_ops;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	trace_sk_data_ready(sk);

	strp_data_ready(&ctx->strp);
}
static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk, 0);
	release_sock(sk);
}
static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}
static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	ctx->saved_destruct(sk);
	kfree(ctx);
}
bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);

	if (sk->sk_family == AF_INET) {
		sk->sk_prot = &espintcp_prot;
		sk->sk_socket->ops = &espintcp_ops;
	} else {
		mutex_lock(&tcpv6_prot_mutex);
		if (!espintcp6_prot.recvmsg)
			build_protos(&espintcp6_prot, &espintcp6_ops,
				     sk->sk_prot, sk->sk_socket->ops);
		mutex_unlock(&tcpv6_prot_mutex);

		sk->sk_prot = &espintcp6_prot;
		sk->sk_socket->ops = &espintcp6_ops;
	}
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;
	sk->sk_use_task_frag = false;

	return 0;

free:
	kfree(ctx);
	return err;
}
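
/* A minimal sketch of how a userspace IKE daemon would attach this ULP
 * to an established TCP connection (setup of the matching XFRM state
 * with TCP encapsulation is out of scope here):
 *
 *	static const char ulp[] = "espintcp";
 *
 *	if (setsockopt(fd, IPPROTO_TCP, TCP_ULP, ulp, sizeof(ulp)))
 *		perror("TCP_ULP");
 *
 * espintcp_init_sk() above then runs as the ULP's init hook.
 */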
static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}
static __poll_t espintcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct sock *sk = sock->sk;
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (!skb_queue_empty(&ctx->ike_queue))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}
static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops)
{
	memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
	memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
	espintcp_prot->sendmsg = espintcp_sendmsg;
	espintcp_prot->recvmsg = espintcp_recvmsg;
	espintcp_prot->close = espintcp_close;
	espintcp_prot->release_cb = espintcp_release;
	espintcp_ops->poll = espintcp_poll;
}
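
/* build_protos() clones the original proto/proto_ops and overrides only
 * the entry points espintcp needs; everything else (bind, connect,
 * setsockopt, ...) continues to go straight to plain TCP.
 */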
static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};
void __init espintcp_init(void)
{
	build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);

	tcp_register_ulp(&espintcp_ulp);
}
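
/* espintcp_init() runs once at boot from the XFRM initialization code
 * (assumption: the exact call site depends on the kernel version) and
 * makes the "espintcp" ULP available via tcp_register_ulp().
 */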