1 // SPDX-License-Identifier: GPL-2.0-or-later
3 * Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org>
4 * Copyright (C) 2018 Samsung Electronics Co., Ltd.
7 #include <linux/freezer.h>
9 #include "smb_common.h"
12 #include "connection.h"
13 #include "transport_tcp.h"
15 #define IFACE_STATE_DOWN BIT(0)
16 #define IFACE_STATE_CONFIGURED BIT(1)
18 static atomic_t active_num_conn
;
21 struct task_struct
*ksmbd_kthread
;
22 struct socket
*ksmbd_socket
;
23 struct list_head entry
;
25 struct mutex sock_release_lock
;
29 static LIST_HEAD(iface_list
);
31 static int bind_additional_ifaces
;
33 struct tcp_transport
{
34 struct ksmbd_transport transport
;
40 static const struct ksmbd_transport_ops ksmbd_tcp_transport_ops
;
42 static void tcp_stop_kthread(struct task_struct
*kthread
);
43 static struct interface
*alloc_iface(char *ifname
);
45 #define KSMBD_TRANS(t) (&(t)->transport)
46 #define TCP_TRANS(t) ((struct tcp_transport *)container_of(t, \
47 struct tcp_transport, transport))
49 static inline void ksmbd_tcp_nodelay(struct socket
*sock
)
51 tcp_sock_set_nodelay(sock
->sk
);
54 static inline void ksmbd_tcp_reuseaddr(struct socket
*sock
)
56 sock_set_reuseaddr(sock
->sk
);
59 static inline void ksmbd_tcp_rcv_timeout(struct socket
*sock
, s64 secs
)
62 if (secs
&& secs
< MAX_SCHEDULE_TIMEOUT
/ HZ
- 1)
63 sock
->sk
->sk_rcvtimeo
= secs
* HZ
;
65 sock
->sk
->sk_rcvtimeo
= MAX_SCHEDULE_TIMEOUT
;
66 release_sock(sock
->sk
);
69 static inline void ksmbd_tcp_snd_timeout(struct socket
*sock
, s64 secs
)
71 sock_set_sndtimeo(sock
->sk
, secs
);
74 static struct tcp_transport
*alloc_transport(struct socket
*client_sk
)
76 struct tcp_transport
*t
;
77 struct ksmbd_conn
*conn
;
79 t
= kzalloc(sizeof(*t
), GFP_KERNEL
);
84 conn
= ksmbd_conn_alloc();
90 conn
->transport
= KSMBD_TRANS(t
);
91 KSMBD_TRANS(t
)->conn
= conn
;
92 KSMBD_TRANS(t
)->ops
= &ksmbd_tcp_transport_ops
;
96 static void free_transport(struct tcp_transport
*t
)
98 kernel_sock_shutdown(t
->sock
, SHUT_RDWR
);
99 sock_release(t
->sock
);
102 ksmbd_conn_free(KSMBD_TRANS(t
)->conn
);
108 * kvec_array_init() - initialize a IO vector segment
109 * @new: IO vector to be initialized
110 * @iov: base IO vector
111 * @nr_segs: number of segments in base iov
112 * @bytes: total iovec length so far for read
114 * Return: Number of IO segments
116 static unsigned int kvec_array_init(struct kvec
*new, struct kvec
*iov
,
117 unsigned int nr_segs
, size_t bytes
)
121 while (bytes
|| !iov
->iov_len
) {
122 int copy
= min(bytes
, iov
->iov_len
);
126 if (iov
->iov_len
== base
) {
133 memcpy(new, iov
, sizeof(*iov
) * nr_segs
);
134 new->iov_base
+= base
;
135 new->iov_len
-= base
;
140 * get_conn_iovec() - get connection iovec for reading from socket
141 * @t: TCP transport instance
142 * @nr_segs: number of segments in iov
144 * Return: return existing or newly allocate iovec
146 static struct kvec
*get_conn_iovec(struct tcp_transport
*t
, unsigned int nr_segs
)
148 struct kvec
*new_iov
;
150 if (t
->iov
&& nr_segs
<= t
->nr_iov
)
153 /* not big enough -- allocate a new one and release the old */
154 new_iov
= kmalloc_array(nr_segs
, sizeof(*new_iov
), GFP_KERNEL
);
163 static unsigned short ksmbd_tcp_get_port(const struct sockaddr
*sa
)
165 switch (sa
->sa_family
) {
167 return ntohs(((struct sockaddr_in
*)sa
)->sin_port
);
169 return ntohs(((struct sockaddr_in6
*)sa
)->sin6_port
);
175 * ksmbd_tcp_new_connection() - create a new tcp session on mount
176 * @client_sk: socket associated with new connection
178 * whenever a new connection is requested, create a conn thread
179 * (session thread) to handle new incoming smb requests from the connection
181 * Return: 0 on success, otherwise error
183 static int ksmbd_tcp_new_connection(struct socket
*client_sk
)
185 struct sockaddr
*csin
;
187 struct tcp_transport
*t
;
188 struct task_struct
*handler
;
190 t
= alloc_transport(client_sk
);
192 sock_release(client_sk
);
196 csin
= KSMBD_TCP_PEER_SOCKADDR(KSMBD_TRANS(t
)->conn
);
197 if (kernel_getpeername(client_sk
, csin
) < 0) {
198 pr_err("client ip resolution failed\n");
203 handler
= kthread_run(ksmbd_conn_handler_loop
,
204 KSMBD_TRANS(t
)->conn
,
206 ksmbd_tcp_get_port(csin
));
207 if (IS_ERR(handler
)) {
208 pr_err("cannot start conn thread\n");
209 rc
= PTR_ERR(handler
);
220 * ksmbd_kthread_fn() - listen to new SMB connections and callback server
221 * @p: arguments to forker thread
223 * Return: 0 on success, error number otherwise
225 static int ksmbd_kthread_fn(void *p
)
227 struct socket
*client_sk
= NULL
;
228 struct interface
*iface
= (struct interface
*)p
;
231 while (!kthread_should_stop()) {
232 mutex_lock(&iface
->sock_release_lock
);
233 if (!iface
->ksmbd_socket
) {
234 mutex_unlock(&iface
->sock_release_lock
);
237 ret
= kernel_accept(iface
->ksmbd_socket
, &client_sk
,
239 mutex_unlock(&iface
->sock_release_lock
);
242 /* check for new connections every 100 msecs */
243 schedule_timeout_interruptible(HZ
/ 10);
247 if (server_conf
.max_connections
&&
248 atomic_inc_return(&active_num_conn
) >= server_conf
.max_connections
) {
249 pr_info_ratelimited("Limit the maximum number of connections(%u)\n",
250 atomic_read(&active_num_conn
));
251 atomic_dec(&active_num_conn
);
252 sock_release(client_sk
);
256 ksmbd_debug(CONN
, "connect success: accepted new connection\n");
257 client_sk
->sk
->sk_rcvtimeo
= KSMBD_TCP_RECV_TIMEOUT
;
258 client_sk
->sk
->sk_sndtimeo
= KSMBD_TCP_SEND_TIMEOUT
;
260 ksmbd_tcp_new_connection(client_sk
);
263 ksmbd_debug(CONN
, "releasing socket\n");
268 * ksmbd_tcp_run_kthread() - start forker thread
269 * @iface: pointer to struct interface
271 * start forker thread(ksmbd/0) at module init time to listen
272 * on port 445 for new SMB connection requests. It creates per connection
273 * server threads(ksmbd/x)
275 * Return: 0 on success or error number
277 static int ksmbd_tcp_run_kthread(struct interface
*iface
)
280 struct task_struct
*kthread
;
282 kthread
= kthread_run(ksmbd_kthread_fn
, (void *)iface
, "ksmbd-%s",
284 if (IS_ERR(kthread
)) {
285 rc
= PTR_ERR(kthread
);
288 iface
->ksmbd_kthread
= kthread
;
294 * ksmbd_tcp_readv() - read data from socket in given iovec
295 * @t: TCP transport instance
296 * @iov_orig: base IO vector
297 * @nr_segs: number of segments in base iov
298 * @to_read: number of bytes to read from socket
299 * @max_retries: maximum retry count
301 * Return: on success return number of bytes read from socket,
302 * otherwise return error number
304 static int ksmbd_tcp_readv(struct tcp_transport
*t
, struct kvec
*iov_orig
,
305 unsigned int nr_segs
, unsigned int to_read
,
311 struct msghdr ksmbd_msg
;
313 struct ksmbd_conn
*conn
= KSMBD_TRANS(t
)->conn
;
315 iov
= get_conn_iovec(t
, nr_segs
);
319 ksmbd_msg
.msg_control
= NULL
;
320 ksmbd_msg
.msg_controllen
= 0;
322 for (total_read
= 0; to_read
; total_read
+= length
, to_read
-= length
) {
325 if (!ksmbd_conn_alive(conn
)) {
326 total_read
= -ESHUTDOWN
;
329 segs
= kvec_array_init(iov
, iov_orig
, nr_segs
, total_read
);
331 length
= kernel_recvmsg(t
->sock
, &ksmbd_msg
,
332 iov
, segs
, to_read
, 0);
334 if (length
== -EINTR
) {
335 total_read
= -ESHUTDOWN
;
337 } else if (ksmbd_conn_need_reconnect(conn
)) {
338 total_read
= -EAGAIN
;
340 } else if (length
== -ERESTARTSYS
|| length
== -EAGAIN
) {
342 * If max_retries is negative, Allow unlimited
343 * retries to keep connection with inactive sessions.
345 if (max_retries
== 0) {
348 } else if (max_retries
> 0) {
352 usleep_range(1000, 2000);
355 } else if (length
<= 0) {
364 * ksmbd_tcp_read() - read data from socket in given buffer
365 * @t: TCP transport instance
366 * @buf: buffer to store read data from socket
367 * @to_read: number of bytes to read from socket
368 * @max_retries: number of retries if reading from socket fails
370 * Return: on success return number of bytes read from socket,
371 * otherwise return error number
373 static int ksmbd_tcp_read(struct ksmbd_transport
*t
, char *buf
,
374 unsigned int to_read
, int max_retries
)
379 iov
.iov_len
= to_read
;
381 return ksmbd_tcp_readv(TCP_TRANS(t
), &iov
, 1, to_read
, max_retries
);
384 static int ksmbd_tcp_writev(struct ksmbd_transport
*t
, struct kvec
*iov
,
385 int nvecs
, int size
, bool need_invalidate
,
386 unsigned int remote_key
)
389 struct msghdr smb_msg
= {.msg_flags
= MSG_NOSIGNAL
};
391 return kernel_sendmsg(TCP_TRANS(t
)->sock
, &smb_msg
, iov
, nvecs
, size
);
394 static void ksmbd_tcp_disconnect(struct ksmbd_transport
*t
)
396 free_transport(TCP_TRANS(t
));
397 if (server_conf
.max_connections
)
398 atomic_dec(&active_num_conn
);
401 static void tcp_destroy_socket(struct socket
*ksmbd_socket
)
408 /* set zero to timeout */
409 ksmbd_tcp_rcv_timeout(ksmbd_socket
, 0);
410 ksmbd_tcp_snd_timeout(ksmbd_socket
, 0);
412 ret
= kernel_sock_shutdown(ksmbd_socket
, SHUT_RDWR
);
414 pr_err("Failed to shutdown socket: %d\n", ret
);
415 sock_release(ksmbd_socket
);
419 * create_socket - create socket for ksmbd/0
420 * @iface: interface to bind the created socket to
422 * Return: 0 on success, error number otherwise
424 static int create_socket(struct interface
*iface
)
427 struct sockaddr_in6 sin6
;
428 struct sockaddr_in sin
;
429 struct socket
*ksmbd_socket
;
432 ret
= sock_create(PF_INET6
, SOCK_STREAM
, IPPROTO_TCP
, &ksmbd_socket
);
434 if (ret
!= -EAFNOSUPPORT
)
435 pr_err("Can't create socket for ipv6, fallback to ipv4: %d\n", ret
);
436 ret
= sock_create(PF_INET
, SOCK_STREAM
, IPPROTO_TCP
,
439 pr_err("Can't create socket for ipv4: %d\n", ret
);
443 sin
.sin_family
= PF_INET
;
444 sin
.sin_addr
.s_addr
= htonl(INADDR_ANY
);
445 sin
.sin_port
= htons(server_conf
.tcp_port
);
448 sin6
.sin6_family
= PF_INET6
;
449 sin6
.sin6_addr
= in6addr_any
;
450 sin6
.sin6_port
= htons(server_conf
.tcp_port
);
452 lock_sock(ksmbd_socket
->sk
);
453 ksmbd_socket
->sk
->sk_ipv6only
= false;
454 release_sock(ksmbd_socket
->sk
);
457 ksmbd_tcp_nodelay(ksmbd_socket
);
458 ksmbd_tcp_reuseaddr(ksmbd_socket
);
460 ret
= sock_setsockopt(ksmbd_socket
,
463 KERNEL_SOCKPTR(iface
->name
),
464 strlen(iface
->name
));
465 if (ret
!= -ENODEV
&& ret
< 0) {
466 pr_err("Failed to set SO_BINDTODEVICE: %d\n", ret
);
471 ret
= kernel_bind(ksmbd_socket
, (struct sockaddr
*)&sin
,
474 ret
= kernel_bind(ksmbd_socket
, (struct sockaddr
*)&sin6
,
477 pr_err("Failed to bind socket: %d\n", ret
);
481 ksmbd_socket
->sk
->sk_rcvtimeo
= KSMBD_TCP_RECV_TIMEOUT
;
482 ksmbd_socket
->sk
->sk_sndtimeo
= KSMBD_TCP_SEND_TIMEOUT
;
484 ret
= kernel_listen(ksmbd_socket
, KSMBD_SOCKET_BACKLOG
);
486 pr_err("Port listen() error: %d\n", ret
);
490 iface
->ksmbd_socket
= ksmbd_socket
;
491 ret
= ksmbd_tcp_run_kthread(iface
);
493 pr_err("Can't start ksmbd main kthread: %d\n", ret
);
496 iface
->state
= IFACE_STATE_CONFIGURED
;
501 tcp_destroy_socket(ksmbd_socket
);
503 iface
->ksmbd_socket
= NULL
;
507 static int ksmbd_netdev_event(struct notifier_block
*nb
, unsigned long event
,
510 struct net_device
*netdev
= netdev_notifier_info_to_dev(ptr
);
511 struct interface
*iface
;
516 if (netif_is_bridge_port(netdev
))
519 list_for_each_entry(iface
, &iface_list
, entry
) {
520 if (!strcmp(iface
->name
, netdev
->name
)) {
522 if (iface
->state
!= IFACE_STATE_DOWN
)
524 ret
= create_socket(iface
);
530 if (!found
&& bind_additional_ifaces
) {
531 iface
= alloc_iface(kstrdup(netdev
->name
, GFP_KERNEL
));
534 ret
= create_socket(iface
);
540 list_for_each_entry(iface
, &iface_list
, entry
) {
541 if (!strcmp(iface
->name
, netdev
->name
) &&
542 iface
->state
== IFACE_STATE_CONFIGURED
) {
543 tcp_stop_kthread(iface
->ksmbd_kthread
);
544 iface
->ksmbd_kthread
= NULL
;
545 mutex_lock(&iface
->sock_release_lock
);
546 tcp_destroy_socket(iface
->ksmbd_socket
);
547 iface
->ksmbd_socket
= NULL
;
548 mutex_unlock(&iface
->sock_release_lock
);
550 iface
->state
= IFACE_STATE_DOWN
;
560 static struct notifier_block ksmbd_netdev_notifier
= {
561 .notifier_call
= ksmbd_netdev_event
,
564 int ksmbd_tcp_init(void)
566 register_netdevice_notifier(&ksmbd_netdev_notifier
);
571 static void tcp_stop_kthread(struct task_struct
*kthread
)
578 ret
= kthread_stop(kthread
);
580 pr_err("failed to stop forker thread\n");
583 void ksmbd_tcp_destroy(void)
585 struct interface
*iface
, *tmp
;
587 unregister_netdevice_notifier(&ksmbd_netdev_notifier
);
589 list_for_each_entry_safe(iface
, tmp
, &iface_list
, entry
) {
590 list_del(&iface
->entry
);
596 static struct interface
*alloc_iface(char *ifname
)
598 struct interface
*iface
;
603 iface
= kzalloc(sizeof(struct interface
), GFP_KERNEL
);
609 iface
->name
= ifname
;
610 iface
->state
= IFACE_STATE_DOWN
;
611 list_add(&iface
->entry
, &iface_list
);
612 mutex_init(&iface
->sock_release_lock
);
616 int ksmbd_tcp_set_interfaces(char *ifc_list
, int ifc_list_sz
)
621 struct net_device
*netdev
;
624 for_each_netdev(&init_net
, netdev
) {
625 if (netif_is_bridge_port(netdev
))
627 if (!alloc_iface(kstrdup(netdev
->name
, GFP_KERNEL
))) {
633 bind_additional_ifaces
= 1;
637 while (ifc_list_sz
> 0) {
638 if (!alloc_iface(kstrdup(ifc_list
, GFP_KERNEL
)))
641 sz
= strlen(ifc_list
);
646 ifc_list_sz
-= (sz
+ 1);
649 bind_additional_ifaces
= 0;
654 static const struct ksmbd_transport_ops ksmbd_tcp_transport_ops
= {
655 .read
= ksmbd_tcp_read
,
656 .writev
= ksmbd_tcp_writev
,
657 .disconnect
= ksmbd_tcp_disconnect
,