net/xfrm/espintcp.c
// SPDX-License-Identifier: GPL-2.0
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>
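/* ESP in TCP encapsulation (RFC 8229): every message on the TCP stream
 * is prefixed with a 2-byte big-endian length that includes the length
 * field itself. The 4 bytes following the length distinguish the two
 * message types: an all-zero "non-ESP marker" means an IKE message,
 * anything else is the (non-zero) SPI of an ESP packet. For example, a
 * 100-byte ESP packet goes on the wire as the 2-byte length value 102
 * followed by the packet itself, which starts with its SPI.
 */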
static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
			  struct sock *sk)
{
	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
		kfree_skb(skb);
		return;
	}

	skb_set_owner_r(skb, sk);

	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&ctx->ike_queue, skb);
	ctx->saved_data_ready(sk);
}
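/* Hand a reassembled ESP message to the IPsec (xfrm) receive path. */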
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	skb_reset_transport_header(skb);
	memset(skb->cb, 0, sizeof(skb->cb));

	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
	xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}
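/* strparser .rcv_msg callback: called with one complete length-delimited
 * message; strip the length field and dispatch on the non-ESP marker.
 */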
static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	u32 nonesp_marker;
	int err;

	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		kfree_skb(skb);
		return;
	}

	/* remove header, leave non-ESP marker/SPI */
	if (!__pskb_pull(skb, rxm->offset + 2)) {
		kfree_skb(skb);
		return;
	}

	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}
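/* strparser .parse_msg callback: return the full record length (length
 * field included) once it can be read, 0 if more data is needed, or a
 * negative error to abort. The minimum valid length is 6: the 2-byte
 * length field plus the 4-byte non-ESP marker or SPI.
 */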
static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 6)
		return -EINVAL;

	return len;
}
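/* recvmsg() on an espintcp socket delivers only the queued IKE
 * messages; ESP packets are consumed by the kernel and never reach
 * userspace.
 */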
static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int nonblock, int flags, int *addr_len)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff *skb;
	int err = 0;
	int copied;
	int off = 0;

	flags |= nonblock ? MSG_DONTWAIT : 0;

	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, NULL, &off, &err);
	if (!skb)
		return err;

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		return err;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);

	return copied;
}
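/* Queue an ESP skb built by the xfrm output path while the socket is
 * busy; the queue is drained into the TCP socket when the socket lock
 * is released, via espintcp_release().
 */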
int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
		return -ENOBUFS;

	__skb_queue_tail(&ctx->out_queue, skb);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);
/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
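/* TX side: one message at a time is staged in ctx->partial, either as
 * an skb (ESP packets pushed by the stack) or as an sk_msg (IKE
 * messages built by sendmsg()), and then written to the TCP socket.
 */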
static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
				   int flags)
{
	do {
		int ret;

		ret = skb_send_sock_locked(sk, emsg->skb,
					   emsg->offset, emsg->len);
		if (ret < 0)
			return ret;

		emsg->len -= ret;
		emsg->offset += ret;
	} while (emsg->len > 0);

	kfree_skb(emsg->skb);
	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct sk_msg *skmsg = &emsg->skmsg;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	flags |= MSG_SENDPAGE_NOTLAST;
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		emsg->offset = 0;

		if (sg_is_last(sg))
			flags &= ~MSG_SENDPAGE_NOTLAST;

		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret < 0) {
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
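/* Write out whatever is staged in ctx->partial; ctx->tx_running
 * prevents two pushes from interleaving on the same socket.
 */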
static int espintcp_push_msgs(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, 0);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, 0);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		return 0;
	}
	if (!err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}
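/* Stage a single ESP skb for transmission on the TCP socket; it is
 * dropped with -ENOBUFS if a previous message is still partially
 * unsent.
 */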
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	espintcp_push_msgs(sk);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	espintcp_push_msgs(sk);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);
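/* sendmsg() carries IKE messages from userspace: allocate an sk_msg,
 * prepend the 2-byte length header, copy in the payload, and push.
 */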
static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;
	char buf[2] = {0};
	int err, end;

	if (msg->msg_flags)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	err = espintcp_push_msgs(sk);
	if (err < 0) {
		err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* only -ENOMEM is possible since we don't coalesce */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk);
	/* this message could be partially sent, keep it */
	if (err < 0)
		goto unlock;
	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}
static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	strp_data_ready(&ctx->strp);
}
static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk);
	release_sock(sk);
}
static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}
static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	kfree(ctx);
}
bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
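/* ULP init, run when the "espintcp" ULP is attached to a TCP socket.
 * A minimal userspace sketch (assuming the IKE daemon owns the
 * connected TCP socket fd):
 *
 *	setsockopt(fd, IPPROTO_TCP, TCP_ULP, "espintcp",
 *		   sizeof("espintcp"));
 *
 * After that, sendmsg()/recvmsg() on fd carry IKE messages, while ESP
 * packets are handled entirely inside the kernel.
 */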
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	__sk_dst_reset(sk);

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);
	sk->sk_prot = &espintcp_prot;
	sk->sk_socket->ops = &espintcp_ops;
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;

	return 0;

free:
	kfree(ctx);
	return err;
}
static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}
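/* close(): stop the strparser, restore the plain TCP proto, free
 * anything still queued, then hand off to tcp_close().
 */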
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}
static __poll_t espintcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct sock *sk = sock->sk;
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (!skb_queue_empty(&ctx->ike_queue))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}
static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};
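/* Build espintcp_prot/espintcp_ops as copies of plain TCP with the
 * entry points above overridden, and register the "espintcp" ULP.
 */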
void __init espintcp_init(void)
{
	memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot));
	memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops));
	espintcp_prot.sendmsg = espintcp_sendmsg;
	espintcp_prot.recvmsg = espintcp_recvmsg;
	espintcp_prot.close = espintcp_close;
	espintcp_prot.release_cb = espintcp_release;
	espintcp_ops.poll = espintcp_poll;

	tcp_register_ulp(&espintcp_ulp);
}