1 // SPDX-License-Identifier: GPL-2.0
3 #include <net/strparser.h>
6 #include <net/espintcp.h>
7 #include <linux/skmsg.h>
8 #include <net/inet_common.h>
/*
 * handle_nonesp() - deliver a non-ESP (IKE) message to the userspace
 * keying daemon via this socket's IKE receive queue.
 *
 * The skb is charged against the socket's receive memory, its control
 * block is cleared, and it is appended to ctx->ike_queue; the socket's
 * saved (pre-espintcp) ->sk_data_ready callback then wakes the reader.
 *
 * NOTE(review): this extract is fragmented — the trailing parameter
 * (presumably "struct sock *sk"), the braces, and the drop path taken
 * when the rmem check below fails are not visible here; confirm against
 * the full file.
 */
10 static void handle_nonesp(struct espintcp_ctx
*ctx
, struct sk_buff
*skb
,
/* drop when the receive buffer is full or rmem cannot be charged */
13 if (atomic_read(&sk
->sk_rmem_alloc
) >= sk
->sk_rcvbuf
||
14 !sk_rmem_schedule(sk
, skb
, skb
->truesize
)) {
/* account the skb's truesize to this socket's receive memory */
19 skb_set_owner_r(skb
, sk
);
21 memset(skb
->cb
, 0, sizeof(skb
->cb
));
22 skb_queue_tail(&ctx
->ike_queue
, skb
);
/* wake the reader through the original data_ready callback */
23 ctx
->saved_data_ready(sk
);
/*
 * handle_esp() - feed an ESP record extracted from the TCP stream into
 * the IPv4 xfrm input path with encapsulation type TCP_ENCAP_ESPINTCP.
 *
 * NOTE(review): braces and some interior lines (likely the RCU
 * read-side lock around the device lookup) are missing from this
 * extract — confirm against the full file.
 */
26 static void handle_esp(struct sk_buff
*skb
, struct sock
*sk
)
28 skb_reset_transport_header(skb
);
29 memset(skb
->cb
, 0, sizeof(skb
->cb
));
/* restore skb->dev from the recorded incoming interface index */
32 skb
->dev
= dev_get_by_index_rcu(sock_net(sk
), skb
->skb_iif
);
34 xfrm4_rcv_encap(skb
, IPPROTO_ESP
, 0, TCP_ENCAP_ESPINTCP
);
/*
 * espintcp_rcv() - strparser rcv_msg callback: demultiplex one complete
 * length-prefixed record into either the IKE path or the ESP path.
 *
 * Peeks the 32-bit word following the 2-byte length prefix: an all-zero
 * value is the non-ESP marker (IKE traffic), anything else is an ESP
 * SPI.  The 2-byte length header is pulled off and the record trimmed
 * to its advertised length before dispatch.
 *
 * NOTE(review): this extract is fragmented — local declarations
 * (nonesp_marker, err), the error/free branches of the checks below and
 * several braces are not visible; confirm against the full file.
 */
39 static void espintcp_rcv(struct strparser
*strp
, struct sk_buff
*skb
)
41 struct espintcp_ctx
*ctx
= container_of(strp
, struct espintcp_ctx
,
43 struct strp_msg
*rxm
= strp_msg(skb
);
/* peek the 4 bytes after the length prefix: 0 => non-ESP marker */
47 err
= skb_copy_bits(skb
, rxm
->offset
+ 2, &nonesp_marker
,
48 sizeof(nonesp_marker
));
54 /* remove header, leave non-ESP marker/SPI */
55 if (!__pskb_pull(skb
, rxm
->offset
+ 2)) {
/* trim to the length advertised in the 2-byte prefix */
60 if (pskb_trim(skb
, rxm
->full_len
- 2) != 0) {
65 if (nonesp_marker
== 0)
66 handle_nonesp(ctx
, skb
, strp
->sk
);
68 handle_esp(skb
, strp
->sk
);
/*
 * espintcp_parse() - strparser parse_msg callback: return the total
 * length of the next record in the TCP stream.
 *
 * Each record starts with a 2-byte big-endian length that includes the
 * length field itself.
 *
 * NOTE(review): local declarations (blen, len, err), the return
 * statements and any validation of the decoded length are missing from
 * this extract — confirm against the full file.
 */
71 static int espintcp_parse(struct strparser
*strp
, struct sk_buff
*skb
)
73 struct strp_msg
*rxm
= strp_msg(skb
);
/* need at least the 2-byte length prefix before parsing */
78 if (skb
->len
< rxm
->offset
+ 2)
81 err
= skb_copy_bits(skb
, rxm
->offset
, &blen
, sizeof(blen
));
/* length prefix is big-endian on the wire */
85 len
= be16_to_cpu(blen
);
/*
 * espintcp_recvmsg() - ->recvmsg replacement: hand queued IKE messages
 * (those placed on ctx->ike_queue by handle_nonesp()) to userspace,
 * datagram-style.
 *
 * Copies at most the caller's buffer size and sets MSG_TRUNC in
 * msg->msg_flags when the message was longer than what was copied.
 *
 * NOTE(review): fragmented extract — local declarations (skb, off, err,
 * copied), the error paths and the return statements are not visible;
 * confirm against the full file.
 */
92 static int espintcp_recvmsg(struct sock
*sk
, struct msghdr
*msg
, size_t len
,
93 int nonblock
, int flags
, int *addr_len
)
95 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
/* fold the legacy nonblock argument into the flags word */
101 flags
|= nonblock
? MSG_DONTWAIT
: 0;
103 skb
= __skb_recv_datagram(sk
, &ctx
->ike_queue
, flags
, NULL
, &off
, &err
);
/* clamp the copy to the message size; flag truncation otherwise */
108 if (copied
> skb
->len
)
110 else if (copied
< skb
->len
)
111 msg
->msg_flags
|= MSG_TRUNC
;
113 err
= skb_copy_datagram_msg(skb
, 0, msg
, copied
);
119 if (flags
& MSG_TRUNC
)
/*
 * espintcp_queue_out() - queue an outgoing skb on the per-socket
 * out_queue for later transmission (flushed by espintcp_release()).
 *
 * The queue depth is bounded by netdev_max_backlog.
 *
 * NOTE(review): fragmented extract — braces and the return statements
 * (including the error returned when the backlog limit is hit) are not
 * visible; confirm against the full file.
 */
125 int espintcp_queue_out(struct sock
*sk
, struct sk_buff
*skb
)
127 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
/* refuse to queue beyond the global backlog limit */
129 if (skb_queue_len(&ctx
->out_queue
) >= netdev_max_backlog
)
132 __skb_queue_tail(&ctx
->out_queue
, skb
);
136 EXPORT_SYMBOL_GPL(espintcp_queue_out
);
138 /* espintcp length field is 2B and length includes the length field's size */
/* i.e. the largest payload is 65535 minus the 2-byte header itself */
139 #define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
/*
 * espintcp_sendskb_locked() - transmit emsg->skb on the TCP socket
 * (caller holds the socket lock), looping until emsg->len reaches zero,
 * then free the skb and clear the pending-message slot.
 *
 * NOTE(review): fragmented extract — the trailing parameter(s), the
 * "do {" opening of the loop, the error handling for a failed
 * skb_send_sock_locked() call and the return statement are not visible;
 * confirm against the full file.
 */
141 static int espintcp_sendskb_locked(struct sock
*sk
, struct espintcp_msg
*emsg
,
147 ret
= skb_send_sock_locked(sk
, emsg
->skb
,
148 emsg
->offset
, emsg
->len
);
154 } while (emsg
->len
> 0);
/* fully sent: release the skb and reset the partial-send state */
156 kfree_skb(emsg
->skb
);
157 memset(emsg
, 0, sizeof(*emsg
));
/*
 * espintcp_sendskmsg_locked() - transmit the scatterlist of a pending
 * sk_msg on the TCP socket (caller holds the socket lock), advancing
 * emsg->offset / skmsg->sg.start across partial sends so the send can
 * be resumed later; on completion the memory charge is released and the
 * pending-message slot cleared.
 *
 * MSG_SENDPAGE_NOTLAST is kept set while more sg entries follow and
 * cleared for the final page so TCP can push the data out.
 *
 * NOTE(review): heavily fragmented extract — the loop structure, the
 * page pointer ("p") lookup, the "done" accounting, the error paths and
 * the return statements are not visible; confirm against the full file.
 */
162 static int espintcp_sendskmsg_locked(struct sock
*sk
,
163 struct espintcp_msg
*emsg
, int flags
)
165 struct sk_msg
*skmsg
= &emsg
->skmsg
;
166 struct scatterlist
*sg
;
170 flags
|= MSG_SENDPAGE_NOTLAST
;
/* resume from the first not-yet-fully-sent sg entry */
171 sg
= &skmsg
->sg
.data
[skmsg
->sg
.start
];
173 size_t size
= sg
->length
- emsg
->offset
;
174 int offset
= sg
->offset
+ emsg
->offset
;
/* last fragment: allow TCP to push immediately */
180 flags
&= ~MSG_SENDPAGE_NOTLAST
;
184 ret
= do_tcp_sendpages(sk
, p
, offset
, size
, flags
);
/* record how far into the current sg entry we got */
186 emsg
->offset
= offset
- sg
->offset
;
187 skmsg
->sg
.start
+= done
;
199 sk_mem_uncharge(sk
, sg
->length
);
203 memset(emsg
, 0, sizeof(*emsg
));
/*
 * espintcp_push_msgs() - try to (re)send the socket's pending partial
 * message, dispatching to the skb or sk_msg sender depending on how the
 * message was staged.
 *
 * -EAGAIN from the sender means the message is still partially pending
 * and must be kept; otherwise the slot is cleared.
 *
 * NOTE(review): fragmented extract — the branch structure selecting
 * between the two senders, the tx_running bookkeeping and the return
 * statements are not visible; confirm against the full file.
 */
208 static int espintcp_push_msgs(struct sock
*sk
)
210 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
211 struct espintcp_msg
*emsg
= &ctx
->partial
;
222 err
= espintcp_sendskb_locked(sk
, emsg
, 0);
224 err
= espintcp_sendskmsg_locked(sk
, emsg
, 0);
/* partial send: keep the message for a later retry */
225 if (err
== -EAGAIN
) {
230 memset(emsg
, 0, sizeof(*emsg
));
/*
 * espintcp_push_skb() - stage an ESP skb for transmission on the TCP
 * socket and try to push it out.
 *
 * Only valid on an established connection.  Any previously pending
 * partial message is flushed first; the skb is then charged to the
 * socket's write memory, recorded in ctx->partial and pushed.
 *
 * NOTE(review): fragmented extract — the error returns (non-established
 * state, busy partial slot), the emsg->skb/len assignments and the
 * final return are not visible; confirm against the full file.
 */
237 int espintcp_push_skb(struct sock
*sk
, struct sk_buff
*skb
)
239 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
240 struct espintcp_msg
*emsg
= &ctx
->partial
;
244 if (sk
->sk_state
!= TCP_ESTABLISHED
) {
/* payload starts at the transport header */
249 offset
= skb_transport_offset(skb
);
250 len
= skb
->len
- offset
;
/* flush anything already pending before staging this skb */
252 espintcp_push_msgs(sk
);
259 skb_set_owner_w(skb
, sk
);
261 emsg
->offset
= offset
;
265 espintcp_push_msgs(sk
);
269 EXPORT_SYMBOL_GPL(espintcp_push_skb
);
/*
 * espintcp_sendmsg() - ->sendmsg replacement: wrap a userspace IKE
 * message in the espintcp framing (2-byte big-endian length prefix that
 * includes itself) and transmit it.
 *
 * The message is staged into ctx->partial as an sk_msg: memory is
 * allocated (waiting for write space if needed), the length prefix is
 * copied in via a kvec iterator, then the user payload follows, the
 * final sg entry is marked, and espintcp_push_msgs() sends it.  A
 * partially sent message is kept in ctx->partial for later retry.
 *
 * Rejects messages larger than MAX_ESPINTCP_MSG and any use of control
 * messages (msg_controllen != 0).
 *
 * NOTE(review): heavily fragmented extract — lock_sock/release_sock,
 * the "buf" declaration for the 2-byte prefix, most error branches and
 * the return statements are not visible; confirm against the full file.
 */
271 static int espintcp_sendmsg(struct sock
*sk
, struct msghdr
*msg
, size_t size
)
273 long timeo
= sock_sndtimeo(sk
, msg
->msg_flags
& MSG_DONTWAIT
);
274 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
275 struct espintcp_msg
*emsg
= &ctx
->partial
;
276 struct iov_iter pfx_iter
;
277 struct kvec pfx_iov
= {};
/* on-the-wire length = payload + 2-byte prefix */
278 size_t msglen
= size
+ 2;
285 if (size
> MAX_ESPINTCP_MSG
)
/* control messages are not supported on espintcp sockets */
288 if (msg
->msg_controllen
)
/* flush any previously pending partial message first */
293 err
= espintcp_push_msgs(sk
);
299 sk_msg_init(&emsg
->skmsg
);
301 /* only -ENOMEM is possible since we don't coalesce */
302 err
= sk_msg_alloc(sk
, &emsg
->skmsg
, msglen
, 0);
/* no memory right now: wait for write space, then retry */
306 err
= sk_stream_wait_memory(sk
, &timeo
);
/* build the 2-byte big-endian length prefix */
311 *((__be16
*)buf
) = cpu_to_be16(msglen
);
312 pfx_iov
.iov_base
= buf
;
313 pfx_iov
.iov_len
= sizeof(buf
);
314 iov_iter_kvec(&pfx_iter
, WRITE
, &pfx_iov
, 1, pfx_iov
.iov_len
);
/* copy prefix, then the user payload, into the staged sk_msg */
316 err
= sk_msg_memcopy_from_iter(sk
, &pfx_iter
, &emsg
->skmsg
,
321 err
= sk_msg_memcopy_from_iter(sk
, &msg
->msg_iter
, &emsg
->skmsg
, size
);
/* terminate the scatterlist at the last used entry */
325 end
= emsg
->skmsg
.sg
.end
;
327 sk_msg_iter_var_prev(end
);
328 sg_mark_end(sk_msg_elem(&emsg
->skmsg
, end
));
330 tcp_rate_check_app_limited(sk
);
332 err
= espintcp_push_msgs(sk
);
333 /* this message could be partially sent, keep it */
/* hard failure path: drop the staged message entirely */
341 sk_msg_free(sk
, &emsg
->skmsg
);
342 memset(emsg
, 0, sizeof(*emsg
));
/* Clones of tcp_prot / inet_stream_ops with espintcp overrides;
 * populated once at boot by espintcp_init(). */
348 static struct proto espintcp_prot __ro_after_init
;
349 static struct proto_ops espintcp_ops __ro_after_init
;
351 static void espintcp_data_ready(struct sock
*sk
)
353 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
355 strp_data_ready(&ctx
->strp
);
/*
 * espintcp_tx_work() - deferred TX worker (scheduled from
 * espintcp_write_space()): retry sending the pending partial message.
 *
 * NOTE(review): fragmented extract — the lines between the locals and
 * the tx_running check (presumably lock_sock/release_sock around the
 * push) and the closing brace are not visible; confirm against the full
 * file.  The tx_running test guards against racing with a direct
 * sender.
 */
358 static void espintcp_tx_work(struct work_struct
*work
)
360 struct espintcp_ctx
*ctx
= container_of(work
,
361 struct espintcp_ctx
, work
);
362 struct sock
*sk
= ctx
->strp
.sk
;
365 if (!ctx
->tx_running
)
366 espintcp_push_msgs(sk
);
370 static void espintcp_write_space(struct sock
*sk
)
372 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
374 schedule_work(&ctx
->work
);
375 ctx
->saved_write_space(sk
);
/*
 * espintcp_destruct() - ->sk_destruct callback installed by
 * espintcp_init_sk().
 *
 * NOTE(review): the body after the ctx lookup (presumably freeing ctx)
 * is missing from this extract — confirm against the full file.
 */
378 static void espintcp_destruct(struct sock
*sk
)
380 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
385 bool tcp_is_ulp_esp(struct sock
*sk
)
387 return sk
->sk_prot
== &espintcp_prot
;
389 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp
);
/*
 * espintcp_init_sk() - tcp_ulp_ops ->init: convert a plain TCP socket
 * into an espintcp socket.
 *
 * Allocates the per-socket context, attaches a stream parser with the
 * espintcp rcv/parse callbacks, initialises the IKE and outgoing
 * queues, swaps in the espintcp proto/proto_ops, saves and replaces the
 * data_ready/write_space callbacks, installs the destructor, publishes
 * the context via icsk_ulp_data and sets up the TX worker.
 *
 * NOTE(review): fragmented extract — the error returns (sockmap in use,
 * allocation or strp_init failure) and their unwinding, plus the final
 * "return 0", are not visible; confirm against the full file.
 */
391 static int espintcp_init_sk(struct sock
*sk
)
393 struct inet_connection_sock
*icsk
= inet_csk(sk
);
394 struct strp_callbacks cb
= {
395 .rcv_msg
= espintcp_rcv
,
396 .parse_msg
= espintcp_parse
,
398 struct espintcp_ctx
*ctx
;
401 /* sockmap is not compatible with espintcp */
402 if (sk
->sk_user_data
)
405 ctx
= kzalloc(sizeof(*ctx
), GFP_KERNEL
);
409 err
= strp_init(&ctx
->strp
, sk
, &cb
);
/* parse anything already sitting in the receive queue */
415 strp_check_rcv(&ctx
->strp
);
416 skb_queue_head_init(&ctx
->ike_queue
);
417 skb_queue_head_init(&ctx
->out_queue
);
/* swap in the espintcp proto/ops and hook the socket callbacks */
418 sk
->sk_prot
= &espintcp_prot
;
419 sk
->sk_socket
->ops
= &espintcp_ops
;
420 ctx
->saved_data_ready
= sk
->sk_data_ready
;
421 ctx
->saved_write_space
= sk
->sk_write_space
;
422 sk
->sk_data_ready
= espintcp_data_ready
;
423 sk
->sk_write_space
= espintcp_write_space
;
424 sk
->sk_destruct
= espintcp_destruct
;
425 rcu_assign_pointer(icsk
->icsk_ulp_data
, ctx
);
426 INIT_WORK(&ctx
->work
, espintcp_tx_work
);
428 /* avoid using task_frag */
429 sk
->sk_allocation
= GFP_ATOMIC
;
/*
 * espintcp_release() - ->release_cb: flush skbs that were queued on
 * ctx->out_queue (by espintcp_queue_out()) while the socket was owned,
 * pushing each one onto the TCP stream via espintcp_push_skb().
 *
 * The queue is spliced to a private list first so new entries can keep
 * arriving while the old ones are sent.
 *
 * NOTE(review): fragmented extract — the "struct sk_buff *skb" local
 * and the closing brace are not visible; confirm against the full file.
 */
438 static void espintcp_release(struct sock
*sk
)
440 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
441 struct sk_buff_head queue
;
444 __skb_queue_head_init(&queue
);
445 skb_queue_splice_init(&ctx
->out_queue
, &queue
);
447 while ((skb
= __skb_dequeue(&queue
)))
448 espintcp_push_skb(sk
, skb
);
/*
 * espintcp_close() - ->close replacement: tear down the espintcp state,
 * then hand off to the regular tcp_close().
 *
 * Stops and finishes the stream parser, restores the original tcp_prot,
 * cancels the TX worker, purges both the outgoing and IKE queues, and
 * releases any pending partial message (skb or sk_msg form).
 *
 * NOTE(review): fragmented extract — the conditionals selecting between
 * kfree_skb()/sk_msg_free() for the partial message and several
 * surrounding lines are not visible; confirm against the full file.
 */
453 static void espintcp_close(struct sock
*sk
, long timeout
)
455 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
456 struct espintcp_msg
*emsg
= &ctx
->partial
;
458 strp_stop(&ctx
->strp
);
/* revert to plain TCP so tcp_close() below sees a normal socket */
460 sk
->sk_prot
= &tcp_prot
;
463 cancel_work_sync(&ctx
->work
);
464 strp_done(&ctx
->strp
);
466 skb_queue_purge(&ctx
->out_queue
);
467 skb_queue_purge(&ctx
->ike_queue
);
/* release whichever form the pending partial message took */
471 kfree_skb(emsg
->skb
);
473 sk_msg_free(sk
, &emsg
->skmsg
);
476 tcp_close(sk
, timeout
);
/*
 * espintcp_poll() - ->poll replacement: start from datagram_poll() and
 * additionally signal readability while IKE messages are waiting on
 * ctx->ike_queue.
 *
 * NOTE(review): fragmented extract — the tail of the signature
 * (presumably "poll_table *wait)"), the opening brace and the
 * "return mask;" are not visible; confirm against the full file.
 */
479 static __poll_t
espintcp_poll(struct file
*file
, struct socket
*sock
,
482 __poll_t mask
= datagram_poll(file
, sock
, wait
);
483 struct sock
*sk
= sock
->sk
;
484 struct espintcp_ctx
*ctx
= espintcp_getctx(sk
);
486 if (!skb_queue_empty(&ctx
->ike_queue
))
487 mask
|= EPOLLIN
| EPOLLRDNORM
;
/*
 * TCP ULP descriptor: espintcp_init_sk() is invoked when userspace
 * attaches this ULP to a TCP socket.
 *
 * NOTE(review): the ".name" initializer (original line 493) and the
 * closing "};" are missing from this extract — confirm against the
 * full file.
 */
492 static struct tcp_ulp_ops espintcp_ulp __read_mostly
= {
494 .owner
= THIS_MODULE
,
495 .init
= espintcp_init_sk
,
498 void __init
espintcp_init(void)
500 memcpy(&espintcp_prot
, &tcp_prot
, sizeof(tcp_prot
));
501 memcpy(&espintcp_ops
, &inet_stream_ops
, sizeof(inet_stream_ops
));
502 espintcp_prot
.sendmsg
= espintcp_sendmsg
;
503 espintcp_prot
.recvmsg
= espintcp_recvmsg
;
504 espintcp_prot
.close
= espintcp_close
;
505 espintcp_prot
.release_cb
= espintcp_release
;
506 espintcp_ops
.poll
= espintcp_poll
;
508 tcp_register_ulp(&espintcp_ulp
);