1 // SPDX-License-Identifier: GPL-2.0
3 #include <net/strparser.h>
6 #include <net/espintcp.h>
7 #include <linux/skmsg.h>
8 #include <net/inet_common.h>
9 #if IS_ENABLED(CONFIG_IPV6)
10 #include <net/ipv6_stubs.h>
13 static void handle_nonesp(struct espintcp_ctx
*ctx
, struct sk_buff
*skb
,
16 if (atomic_read(&sk
->sk_rmem_alloc
) >= sk
->sk_rcvbuf
||
17 !sk_rmem_schedule(sk
, skb
, skb
->truesize
)) {
18 XFRM_INC_STATS(sock_net(sk
), LINUX_MIB_XFRMINERROR
);
23 skb_set_owner_r(skb
, sk
);
25 memset(skb
->cb
, 0, sizeof(skb
->cb
));
26 skb_queue_tail(&ctx
->ike_queue
, skb
);
27 ctx
->saved_data_ready(sk
);
30 static void handle_esp(struct sk_buff
*skb
, struct sock
*sk
)
32 struct tcp_skb_cb
*tcp_cb
= (struct tcp_skb_cb
*)skb
->cb
;
34 skb_reset_transport_header(skb
);
36 /* restore IP CB, we need at least IP6CB->nhoff */
37 memmove(skb
->cb
, &tcp_cb
->header
, sizeof(tcp_cb
->header
));
40 skb
->dev
= dev_get_by_index_rcu(sock_net(sk
), skb
->skb_iif
);
42 #if IS_ENABLED(CONFIG_IPV6)
43 if (sk
->sk_family
== AF_INET6
)
44 ipv6_stub
->xfrm6_rcv_encap(skb
, IPPROTO_ESP
, 0, TCP_ENCAP_ESPINTCP
);
47 xfrm4_rcv_encap(skb
, IPPROTO_ESP
, 0, TCP_ENCAP_ESPINTCP
);
52 static void espintcp_rcv(struct strparser
*strp
, struct sk_buff
*skb
)
54 struct espintcp_ctx
*ctx
= container_of(strp
, struct espintcp_ctx
,
56 struct strp_msg
*rxm
= strp_msg(skb
);
57 int len
= rxm
->full_len
- 2;
61 /* keepalive packet? */
62 if (unlikely(len
== 1)) {
65 err
= skb_copy_bits(skb
, rxm
->offset
+ 2, &data
, 1);
67 XFRM_INC_STATS(sock_net(strp
->sk
), LINUX_MIB_XFRMINHDRERROR
);
78 /* drop other short messages */
79 if (unlikely(len
<= sizeof(nonesp_marker
))) {
80 XFRM_INC_STATS(sock_net(strp
->sk
), LINUX_MIB_XFRMINHDRERROR
);
85 err
= skb_copy_bits(skb
, rxm
->offset
+ 2, &nonesp_marker
,
86 sizeof(nonesp_marker
));
88 XFRM_INC_STATS(sock_net(strp
->sk
), LINUX_MIB_XFRMINHDRERROR
);
93 /* remove header, leave non-ESP marker/SPI */
94 if (!__pskb_pull(skb
, rxm
->offset
+ 2)) {
95 XFRM_INC_STATS(sock_net(strp
->sk
), LINUX_MIB_XFRMINERROR
);
100 if (pskb_trim(skb
, rxm
->full_len
- 2) != 0) {
101 XFRM_INC_STATS(sock_net(strp
->sk
), LINUX_MIB_XFRMINERROR
);
106 if (nonesp_marker
== 0)
107 handle_nonesp(ctx
, skb
, strp
->sk
);
109 handle_esp(skb
, strp
->sk
);
112 static int espintcp_parse(struct strparser
*strp
, struct sk_buff
*skb
)
114 struct strp_msg
*rxm
= strp_msg(skb
);
119 if (skb
->len
< rxm
->offset
+ 2)
122 err
= skb_copy_bits(skb
, rxm
->offset
, &blen
, sizeof(blen
));
126 len
= be16_to_cpu(blen
);
133 static int espintcp_recvmsg(struct sock
*sk
, struct msghdr
*msg
, size_t len
,
134 int nonblock
, int flags
, int *addr_len
)
136 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
142 flags
|= nonblock
? MSG_DONTWAIT
: 0;
144 skb
= __skb_recv_datagram(sk
, &ctx
->ike_queue
, flags
, &off
, &err
);
146 if (err
== -EAGAIN
&& sk
->sk_shutdown
& RCV_SHUTDOWN
)
152 if (copied
> skb
->len
)
154 else if (copied
< skb
->len
)
155 msg
->msg_flags
|= MSG_TRUNC
;
157 err
= skb_copy_datagram_msg(skb
, 0, msg
, copied
);
163 if (flags
& MSG_TRUNC
)
169 int espintcp_queue_out(struct sock
*sk
, struct sk_buff
*skb
)
171 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
173 if (skb_queue_len(&ctx
->out_queue
) >= netdev_max_backlog
)
176 __skb_queue_tail(&ctx
->out_queue
, skb
);
180 EXPORT_SYMBOL_GPL(espintcp_queue_out
);
/*
 * Every espintcp message starts with a 16-bit big-endian length that counts
 * itself, so the total on-wire size tops out at 0xffff and the largest
 * payload that can be framed is that minus the 2-byte prefix.
 */
#define MAX_ESPINTCP_MSG (0xffff - 2)
185 static int espintcp_sendskb_locked(struct sock
*sk
, struct espintcp_msg
*emsg
,
191 ret
= skb_send_sock_locked(sk
, emsg
->skb
,
192 emsg
->offset
, emsg
->len
);
198 } while (emsg
->len
> 0);
200 kfree_skb(emsg
->skb
);
201 memset(emsg
, 0, sizeof(*emsg
));
206 static int espintcp_sendskmsg_locked(struct sock
*sk
,
207 struct espintcp_msg
*emsg
, int flags
)
209 struct sk_msg
*skmsg
= &emsg
->skmsg
;
210 struct scatterlist
*sg
;
214 flags
|= MSG_SENDPAGE_NOTLAST
;
215 sg
= &skmsg
->sg
.data
[skmsg
->sg
.start
];
217 size_t size
= sg
->length
- emsg
->offset
;
218 int offset
= sg
->offset
+ emsg
->offset
;
224 flags
&= ~MSG_SENDPAGE_NOTLAST
;
228 ret
= do_tcp_sendpages(sk
, p
, offset
, size
, flags
);
230 emsg
->offset
= offset
- sg
->offset
;
231 skmsg
->sg
.start
+= done
;
243 sk_mem_uncharge(sk
, sg
->length
);
247 memset(emsg
, 0, sizeof(*emsg
));
252 static int espintcp_push_msgs(struct sock
*sk
, int flags
)
254 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
255 struct espintcp_msg
*emsg
= &ctx
->partial
;
266 err
= espintcp_sendskb_locked(sk
, emsg
, flags
);
268 err
= espintcp_sendskmsg_locked(sk
, emsg
, flags
);
269 if (err
== -EAGAIN
) {
271 return flags
& MSG_DONTWAIT
? -EAGAIN
: 0;
274 memset(emsg
, 0, sizeof(*emsg
));
281 int espintcp_push_skb(struct sock
*sk
, struct sk_buff
*skb
)
283 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
284 struct espintcp_msg
*emsg
= &ctx
->partial
;
288 if (sk
->sk_state
!= TCP_ESTABLISHED
) {
293 offset
= skb_transport_offset(skb
);
294 len
= skb
->len
- offset
;
296 espintcp_push_msgs(sk
, 0);
303 skb_set_owner_w(skb
, sk
);
305 emsg
->offset
= offset
;
309 espintcp_push_msgs(sk
, 0);
313 EXPORT_SYMBOL_GPL(espintcp_push_skb
);
315 static int espintcp_sendmsg(struct sock
*sk
, struct msghdr
*msg
, size_t size
)
317 long timeo
= sock_sndtimeo(sk
, msg
->msg_flags
& MSG_DONTWAIT
);
318 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
319 struct espintcp_msg
*emsg
= &ctx
->partial
;
320 struct iov_iter pfx_iter
;
321 struct kvec pfx_iov
= {};
322 size_t msglen
= size
+ 2;
326 if (msg
->msg_flags
& ~MSG_DONTWAIT
)
329 if (size
> MAX_ESPINTCP_MSG
)
332 if (msg
->msg_controllen
)
337 err
= espintcp_push_msgs(sk
, msg
->msg_flags
& MSG_DONTWAIT
);
339 if (err
!= -EAGAIN
|| !(msg
->msg_flags
& MSG_DONTWAIT
))
344 sk_msg_init(&emsg
->skmsg
);
346 /* only -ENOMEM is possible since we don't coalesce */
347 err
= sk_msg_alloc(sk
, &emsg
->skmsg
, msglen
, 0);
351 err
= sk_stream_wait_memory(sk
, &timeo
);
356 *((__be16
*)buf
) = cpu_to_be16(msglen
);
357 pfx_iov
.iov_base
= buf
;
358 pfx_iov
.iov_len
= sizeof(buf
);
359 iov_iter_kvec(&pfx_iter
, WRITE
, &pfx_iov
, 1, pfx_iov
.iov_len
);
361 err
= sk_msg_memcopy_from_iter(sk
, &pfx_iter
, &emsg
->skmsg
,
366 err
= sk_msg_memcopy_from_iter(sk
, &msg
->msg_iter
, &emsg
->skmsg
, size
);
370 end
= emsg
->skmsg
.sg
.end
;
372 sk_msg_iter_var_prev(end
);
373 sg_mark_end(sk_msg_elem(&emsg
->skmsg
, end
));
375 tcp_rate_check_app_limited(sk
);
377 err
= espintcp_push_msgs(sk
, msg
->msg_flags
& MSG_DONTWAIT
);
378 /* this message could be partially sent, keep it */
385 sk_msg_free(sk
, &emsg
->skmsg
);
386 memset(emsg
, 0, sizeof(*emsg
));
392 static struct proto espintcp_prot __ro_after_init
;
393 static struct proto_ops espintcp_ops __ro_after_init
;
394 static struct proto espintcp6_prot
;
395 static struct proto_ops espintcp6_ops
;
396 static DEFINE_MUTEX(tcpv6_prot_mutex
);
398 static void espintcp_data_ready(struct sock
*sk
)
400 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
402 strp_data_ready(&ctx
->strp
);
405 static void espintcp_tx_work(struct work_struct
*work
)
407 struct espintcp_ctx
*ctx
= container_of(work
,
408 struct espintcp_ctx
, work
);
409 struct sock
*sk
= ctx
->strp
.sk
;
412 if (!ctx
->tx_running
)
413 espintcp_push_msgs(sk
, 0);
417 static void espintcp_write_space(struct sock
*sk
)
419 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
421 schedule_work(&ctx
->work
);
422 ctx
->saved_write_space(sk
);
425 static void espintcp_destruct(struct sock
*sk
)
427 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
429 ctx
->saved_destruct(sk
);
433 bool tcp_is_ulp_esp(struct sock
*sk
)
435 return sk
->sk_prot
== &espintcp_prot
|| sk
->sk_prot
== &espintcp6_prot
;
437 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp
);
/*
 * Forward declaration: espintcp_init_sk() needs to build the IPv6 proto/ops
 * copies on demand, but build_protos() itself is defined later in the file.
 */
static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);
443 static int espintcp_init_sk(struct sock
*sk
)
445 struct inet_connection_sock
*icsk
= inet_csk(sk
);
446 struct strp_callbacks cb
= {
447 .rcv_msg
= espintcp_rcv
,
448 .parse_msg
= espintcp_parse
,
450 struct espintcp_ctx
*ctx
;
453 /* sockmap is not compatible with espintcp */
454 if (sk
->sk_user_data
)
457 ctx
= kzalloc(sizeof(*ctx
), GFP_KERNEL
);
461 err
= strp_init(&ctx
->strp
, sk
, &cb
);
467 strp_check_rcv(&ctx
->strp
);
468 skb_queue_head_init(&ctx
->ike_queue
);
469 skb_queue_head_init(&ctx
->out_queue
);
471 if (sk
->sk_family
== AF_INET
) {
472 sk
->sk_prot
= &espintcp_prot
;
473 sk
->sk_socket
->ops
= &espintcp_ops
;
475 mutex_lock(&tcpv6_prot_mutex
);
476 if (!espintcp6_prot
.recvmsg
)
477 build_protos(&espintcp6_prot
, &espintcp6_ops
, sk
->sk_prot
, sk
->sk_socket
->ops
);
478 mutex_unlock(&tcpv6_prot_mutex
);
480 sk
->sk_prot
= &espintcp6_prot
;
481 sk
->sk_socket
->ops
= &espintcp6_ops
;
483 ctx
->saved_data_ready
= sk
->sk_data_ready
;
484 ctx
->saved_write_space
= sk
->sk_write_space
;
485 ctx
->saved_destruct
= sk
->sk_destruct
;
486 sk
->sk_data_ready
= espintcp_data_ready
;
487 sk
->sk_write_space
= espintcp_write_space
;
488 sk
->sk_destruct
= espintcp_destruct
;
489 rcu_assign_pointer(icsk
->icsk_ulp_data
, ctx
);
490 INIT_WORK(&ctx
->work
, espintcp_tx_work
);
492 /* avoid using task_frag */
493 sk
->sk_allocation
= GFP_ATOMIC
;
502 static void espintcp_release(struct sock
*sk
)
504 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
505 struct sk_buff_head queue
;
508 __skb_queue_head_init(&queue
);
509 skb_queue_splice_init(&ctx
->out_queue
, &queue
);
511 while ((skb
= __skb_dequeue(&queue
)))
512 espintcp_push_skb(sk
, skb
);
517 static void espintcp_close(struct sock
*sk
, long timeout
)
519 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
520 struct espintcp_msg
*emsg
= &ctx
->partial
;
522 strp_stop(&ctx
->strp
);
524 sk
->sk_prot
= &tcp_prot
;
527 cancel_work_sync(&ctx
->work
);
528 strp_done(&ctx
->strp
);
530 skb_queue_purge(&ctx
->out_queue
);
531 skb_queue_purge(&ctx
->ike_queue
);
535 kfree_skb(emsg
->skb
);
537 sk_msg_free(sk
, &emsg
->skmsg
);
540 tcp_close(sk
, timeout
);
543 static __poll_t
espintcp_poll(struct file
*file
, struct socket
*sock
,
546 __poll_t mask
= datagram_poll(file
, sock
, wait
);
547 struct sock
*sk
= sock
->sk
;
548 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
550 if (!skb_queue_empty(&ctx
->ike_queue
))
551 mask
|= EPOLLIN
| EPOLLRDNORM
;
556 static void build_protos(struct proto
*espintcp_prot
,
557 struct proto_ops
*espintcp_ops
,
558 const struct proto
*orig_prot
,
559 const struct proto_ops
*orig_ops
)
561 memcpy(espintcp_prot
, orig_prot
, sizeof(struct proto
));
562 memcpy(espintcp_ops
, orig_ops
, sizeof(struct proto_ops
));
563 espintcp_prot
->sendmsg
= espintcp_sendmsg
;
564 espintcp_prot
->recvmsg
= espintcp_recvmsg
;
565 espintcp_prot
->close
= espintcp_close
;
566 espintcp_prot
->release_cb
= espintcp_release
;
567 espintcp_ops
->poll
= espintcp_poll
;
570 static struct tcp_ulp_ops espintcp_ulp __read_mostly
= {
572 .owner
= THIS_MODULE
,
573 .init
= espintcp_init_sk
,
576 void __init
espintcp_init(void)
578 build_protos(&espintcp_prot
, &espintcp_ops
, &tcp_prot
, &inet_stream_ops
);
580 tcp_register_ulp(&espintcp_ulp
);