2 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 #include "got_compat.h"
19 #include <sys/types.h>
20 #include <sys/queue.h>
21 #include <sys/socket.h>
35 #include "got_error.h"
38 #include "got_compat.h"
/* Number of elements in a true array: total size divided by element size. */
45 #define nitems(_a) (sizeof((_a)) / sizeof((_a)[0]))
/*
 * Record for one client. Records are chained into a bucket of the
 * gotd_listen_clients table through the `entry` link.
 */
48 struct gotd_listen_client
{
49 STAILQ_ENTRY(gotd_listen_client
) entry
;
/* Head type for one bucket (singly-linked tail queue) of client records. */
54 STAILQ_HEAD(gotd_listen_clients
, gotd_listen_client
);
/*
 * Client table with GOTD_CLIENT_TABLE_SIZE buckets. A bucket is selected
 * with client_hash(id) % nitems(gotd_listen_clients).
 */
56 static struct gotd_listen_clients gotd_listen_clients
[GOTD_CLIENT_TABLE_SIZE
];
/* SipHash key used by client_hash(); filled by arc4random_buf() in listen_main(). */
57 static SIPHASH_KEY clients_hash_key
;
/* Client counter compared against GOTD_MAXCLIENTS in gotd_accept(). */
58 static volatile int listen_client_cnt
;
/*
 * Connection counter for one user ID. Counters are chained into a bucket
 * of the gotd_client_uids table through the `entry` link.
 */
61 struct gotd_uid_connection_counter
{
62 STAILQ_ENTRY(gotd_uid_connection_counter
) entry
;
/* Head type for one bucket (singly-linked tail queue) of per-uid counters. */
66 STAILQ_HEAD(gotd_client_uids
, gotd_uid_connection_counter
);
/*
 * Per-uid counter table with GOTD_CLIENT_TABLE_SIZE buckets. A bucket is
 * selected with uid_hash(euid) % nitems(gotd_client_uids).
 */
67 static struct gotd_client_uids gotd_client_uids
[GOTD_CLIENT_TABLE_SIZE
];
/* SipHash key for hashing uids; filled by arc4random_buf() in listen_main(). */
68 static SIPHASH_KEY uid_hash_key
;
/*
 * NOTE(review): these appear to be members of the gotd_listen state object
 * (accessed below as gotd_listen.iev, gotd_listen.pause, ...); the enclosing
 * declaration is not visible here -- confirm.
 */
/* Event state whose `ev` is registered on the listening socket in listen_main(). */
74 struct gotd_imsgev iev
;
/* Event state whose `ev` is the timer armed in gotd_accept(); its callback is gotd_accept_paused(). */
75 struct gotd_imsgev pause
;
/* Per-uid connection limits searched by gotd_find_uid_connection_limit(); freed in listen_shutdown(). */
76 struct gotd_uid_connection_limit
*connection_limits
;
/* Passed together with connection_limits to gotd_find_uid_connection_limit(). */
77 size_t nconnection_limits
;
82 static void listen_shutdown(void);
/*
 * Signal callback installed through signal_set() in listen_main() for
 * SIGINT, SIGTERM, SIGHUP and SIGUSR1.
 *
 * NOTE(review): the fatalx("unexpected signal") call is presumably the
 * fallback for any other signal number -- confirm against the full body.
 */
85 listen_sighdlr(int sig
, short event
, void *arg
)
88 * Normal signal handler rules don't apply because libevent
103 fatalx("unexpected signal");
/*
 * Hash a client ID with SipHash24 keyed by clients_hash_key.
 * Callers reduce the result modulo nitems(gotd_listen_clients) to pick a
 * bucket.
 */
108 client_hash(uint32_t client_id
)
110 return SipHash24(&clients_hash_key
, &client_id
, sizeof(client_id
));
/*
 * Insert `client` at the head of the client-table bucket selected by
 * hashing client->id.
 */
114 add_client(struct gotd_listen_client
*client
)
116 uint64_t slot
= client_hash(client
->id
) % nitems(gotd_listen_clients
);
117 STAILQ_INSERT_HEAD(&gotd_listen_clients
[slot
], client
, entry
);
/*
 * Look up a client record by ID: pick the bucket for client_id and walk
 * it, comparing each record's id against client_id.
 * Callers compare the returned pointer against NULL to detect whether the
 * ID is in use.
 */
121 static struct gotd_listen_client
*
122 find_client(uint32_t client_id
)
125 struct gotd_listen_client
*c
;
127 slot
= client_hash(client_id
) % nitems(gotd_listen_clients
);
128 STAILQ_FOREACH(c
, &gotd_listen_clients
[slot
], entry
) {
129 if (c
->id
== client_id
)
144 duplicate
= (find_client(id
) != NULL
);
145 } while (duplicate
|| id
== 0);
153 return SipHash24(&uid_hash_key
, &euid
, sizeof(euid
));
/*
 * Insert `counter` at the head of the per-uid table bucket selected by
 * hashing counter->euid.
 */
157 add_uid_connection_counter(struct gotd_uid_connection_counter
*counter
)
159 uint64_t slot
= uid_hash(counter
->euid
) % nitems(gotd_client_uids
);
160 STAILQ_INSERT_HEAD(&gotd_client_uids
[slot
], counter
, entry
);
/*
 * Unlink `counter` from the per-uid table bucket selected by hashing
 * counter->euid. Only the list removal is visible here; disconnect() is
 * the caller in this file.
 */
164 remove_uid_connection_counter(struct gotd_uid_connection_counter
*counter
)
166 uint64_t slot
= uid_hash(counter
->euid
) % nitems(gotd_client_uids
);
167 STAILQ_REMOVE(&gotd_client_uids
[slot
], counter
,
168 gotd_uid_connection_counter
, entry
);
/*
 * Look up the connection counter for `euid` by walking the per-uid table
 * bucket selected by uid_hash(euid).
 * gotd_accept() treats a NULL result as "no counter exists yet for this uid".
 */
171 static struct gotd_uid_connection_counter
*
172 find_uid_connection_counter(uid_t euid
)
175 struct gotd_uid_connection_counter
*c
;
177 slot
= uid_hash(euid
) % nitems(gotd_client_uids
);
178 STAILQ_FOREACH(c
, &gotd_client_uids
[slot
], entry
) {
/*
 * Tear down the bookkeeping for `client`: unlink it from the client table,
 * decrement the connection counter of its euid (dropping the counter from
 * the per-uid table once it reaches zero), and close the client's fd.
 *
 * Returns an error built from errno if close() fails.
 */
186 static const struct got_error
*
187 disconnect(struct gotd_listen_client
*client
)
189 struct gotd_uid_connection_counter
*counter
;
193 log_debug("client on fd %d disconnecting", client
->fd
);
/* Unlink the record from the bucket chosen by its id. */
195 slot
= client_hash(client
->id
) % nitems(gotd_listen_clients
);
196 STAILQ_REMOVE(&gotd_listen_clients
[slot
], client
,
197 gotd_listen_client
, entry
);
199 counter
= find_uid_connection_counter(client
->euid
);
/* Guard against underflow before decrementing. */
201 if (counter
->nconnections
> 0)
202 counter
->nconnections
--;
203 if (counter
->nconnections
== 0) {
204 remove_uid_connection_counter(counter
);
/*
 * NOTE(review): the fd is copied to a local before close(), presumably
 * because the client record is released first -- confirm.
 */
209 client_fd
= client
->fd
;
213 if (close(client_fd
) == -1)
214 return got_error_from_errno("close");
/*
 * Accept a connection on `fd` only if enough file descriptors remain.
 *
 * The headroom test compares getdtablecount() + reserve +
 * (*counter + 1) * GOTD_FD_NEEDED against getdtablesize(); when the sum
 * reaches the table size, "inflight fds exceeded" is logged.
 *
 * Both accept() and accept4() calls appear below; accept4() is passed
 * SOCK_NONBLOCK | SOCK_CLOEXEC.
 * NOTE(review): the two calls are presumably alternatives selected at
 * build time -- confirm.
 */
220 accept_reserve(int fd
, struct sockaddr
*addr
, socklen_t
*addrlen
,
221 int reserve
, volatile int *counter
)
224 int sock_flags
= SOCK_NONBLOCK
;
227 sock_flags
|= SOCK_CLOEXEC
;
230 if (getdtablecount() + reserve
+
231 ((*counter
+ 1) * GOTD_FD_NEEDED
) >= getdtablesize()) {
232 log_debug("inflight fds exceeded");
237 /* TA: silence warning from GCC. */
239 ret
= accept(fd
, addr
, addrlen
);
241 ret
= accept4(fd
, addr
, addrlen
, sock_flags
);
/*
 * Callback of the pause timer (set up with evtimer_set() in listen_main()):
 * re-adds the listening socket event that gotd_accept() removed with
 * event_del() when it armed the timer.
 */
252 gotd_accept_paused(int fd
, short event
, void *arg
)
254 event_add(&gotd_listen
.iev
.ev
, NULL
);
/*
 * Event callback that accepts a new connection on the listening socket.
 *
 * Visible steps: accept via accept_reserve() with GOTD_FD_RESERVE, check
 * listen_client_cnt against GOTD_MAXCLIENTS, obtain the peer's uid/gid with
 * getpeereid(), create or update the per-uid connection counter (subject to
 * a per-uid maximum), allocate a client record with a fresh id, and compose
 * a GOTD_IMSG_CONNECT message carrying the client id.
 *
 * fd:    listening socket, passed to accept_reserve().
 * event: event flags; EV_TIMEOUT is tested explicitly.
 * arg:   the struct gotd_imsgev the CONNECT message is composed on.
 */
258 gotd_accept(int fd
, short event
, void *arg
)
260 struct gotd_imsgev
*iev
= arg
;
261 struct sockaddr_storage ss
;
262 struct timeval backoff
;
265 struct gotd_listen_client
*client
= NULL
;
266 struct gotd_uid_connection_counter
*counter
= NULL
;
267 struct gotd_imsg_connect iconn
;
274 if (event_add(&gotd_listen
.iev
.ev
, NULL
) == -1) {
275 log_warn("event_add");
278 if (event
& EV_TIMEOUT
)
283 /* Other backoff conditions apart from EMFILE/ENFILE? */
284 s
= accept_reserve(fd
, (struct sockaddr
*)&ss
, &len
, GOTD_FD_RESERVE
,
/*
 * Back off: stop watching the listening socket and arm the pause timer;
 * gotd_accept_paused() re-adds the socket event.
 */
294 event_del(&gotd_listen
.iev
.ev
);
295 evtimer_add(&gotd_listen
.pause
.ev
, &backoff
);
303 if (listen_client_cnt
>= GOTD_MAXCLIENTS
)
306 if (getpeereid(s
, &euid
, &egid
) == -1) {
/* NOTE(review): the message below reads "getpeerid" while the call is getpeereid(). */
307 log_warn("getpeerid");
/* First connection seen for this uid: create a counter starting at 1. */
311 counter
= find_uid_connection_counter(euid
);
312 if (counter
== NULL
) {
313 counter
= calloc(1, sizeof(*counter
));
314 if (counter
== NULL
) {
315 log_warn("%s: calloc", __func__
);
318 counter
->euid
= euid
;
319 counter
->nconnections
= 1;
320 add_uid_connection_counter(counter
);
/*
 * The default per-uid maximum is GOTD_MAX_CONN_PER_UID; a value is also
 * taken from the limit entry found for this euid.
 */
322 int max_connections
= GOTD_MAX_CONN_PER_UID
;
323 struct gotd_uid_connection_limit
*limit
;
325 limit
= gotd_find_uid_connection_limit(
326 gotd_listen
.connection_limits
,
327 gotd_listen
.nconnection_limits
, euid
);
329 max_connections
= limit
->max_connections
;
331 if (counter
->nconnections
>= max_connections
) {
332 log_warnx("maximum connections exceeded for uid %d",
336 counter
->nconnections
++;
339 client
= calloc(1, sizeof(*client
));
340 if (client
== NULL
) {
341 log_warn("%s: calloc", __func__
);
344 client
->id
= get_client_id();
349 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__
,
350 client
->fd
, euid
, egid
);
/* Zero the message first so only the fields set below carry data. */
352 memset(&iconn
, 0, sizeof(iconn
));
353 iconn
.client_id
= client
->id
;
358 log_warn("%s: dup", __func__
);
361 if (gotd_imsg_compose_event(iev
, GOTD_IMSG_CONNECT
, PROC_LISTEN
, s
,
362 &iconn
, sizeof(iconn
)) == -1) {
363 log_warn("imsg compose CONNECT");
/*
 * Handle a disconnect message: validate the payload size, look up the
 * client named by idisconnect.client_id and pass it to disconnect().
 *
 * Returns GOT_ERR_PRIVSEP_LEN if the payload length is not exactly
 * sizeof(struct gotd_imsg_disconnect). A GOT_ERR_CLIENT_ID return follows
 * the find_client() lookup (presumably taken when no client matches --
 * confirm). Otherwise returns the result of disconnect().
 */
376 static const struct got_error
*
377 recv_disconnect(struct imsg
*imsg
)
379 struct gotd_imsg_disconnect idisconnect
;
381 struct gotd_listen_client
*client
= NULL
;
/* Payload length is the total message length minus the imsg header. */
383 datalen
= imsg
->hdr
.len
- IMSG_HEADER_SIZE
;
384 if (datalen
!= sizeof(idisconnect
))
385 return got_error(GOT_ERR_PRIVSEP_LEN
);
386 memcpy(&idisconnect
, imsg
->data
, sizeof(idisconnect
));
388 log_debug("client disconnecting");
390 client
= find_client(idisconnect
.client_id
);
392 return got_error(GOT_ERR_CLIENT_ID
);
394 return disconnect(client
);
/*
 * Event callback for the imsg channel described by `arg` (a struct
 * gotd_imsgev).
 *
 * On EV_READ it calls imsg_read(); on EV_WRITE it calls msgbuf_write().
 * Either call failing with an errno other than EAGAIN is fatal. Queued
 * messages are then fetched with imsg_get(): GOTD_IMSG_DISCONNECT goes to
 * recv_disconnect() (an error from it is logged with log_warnx), and an
 * "unexpected imsg" debug line is logged for other types. The body ends
 * with gotd_imsg_event_add(iev) and, on the dead-pipe path,
 * event_loopexit(NULL).
 */
398 listen_dispatch(int fd
, short event
, void *arg
)
400 const struct got_error
*err
= NULL
;
401 struct gotd_imsgev
*iev
= arg
;
402 struct imsgbuf
*ibuf
= &iev
->ibuf
;
407 if (event
& EV_READ
) {
408 if ((n
= imsg_read(ibuf
)) == -1 && errno
!= EAGAIN
)
409 fatal("imsg_read error");
410 if (n
== 0) /* Connection closed. */
414 if (event
& EV_WRITE
) {
415 n
= msgbuf_write(&ibuf
->w
);
416 if (n
== -1 && errno
!= EAGAIN
)
417 fatal("msgbuf_write");
418 if (n
== 0) /* Connection closed. */
423 if ((n
= imsg_get(ibuf
, &imsg
)) == -1)
424 fatal("%s: imsg_get", __func__
);
425 if (n
== 0) /* No more messages. */
428 switch (imsg
.hdr
.type
) {
429 case GOTD_IMSG_DISCONNECT
:
430 err
= recv_disconnect(&imsg
);
432 log_warnx("disconnect: %s", err
->msg
);
435 log_debug("unexpected imsg %d", imsg
.hdr
.type
);
443 gotd_imsg_event_add(iev
);
445 /* This pipe is dead. Remove its event handler */
447 event_loopexit(NULL
);
/*
 * Set up the listen process state and event handlers.
 *
 * title:              stored in gotd_listen.title.
 * gotd_socket:        listening socket, stored in gotd_listen.fd and
 *                     registered for persistent read events.
 * connection_limits,
 * nconnection_limits: per-uid limits stored in gotd_listen for use by
 *                     gotd_accept(); the array is freed in
 *                     listen_shutdown().
 */
452 listen_main(const char *title
, int gotd_socket
,
453 struct gotd_uid_connection_limit
*connection_limits
,
454 size_t nconnection_limits
)
456 struct gotd_imsgev iev
;
457 struct event evsigint
, evsigterm
, evsighup
, evsigusr1
;
/* Fill the hash keys for the client and uid tables with random bytes. */
459 arc4random_buf(&clients_hash_key
, sizeof(clients_hash_key
));
460 arc4random_buf(&uid_hash_key
, sizeof(uid_hash_key
));
462 gotd_listen
.title
= title
;
463 gotd_listen
.pid
= getpid();
464 gotd_listen
.fd
= gotd_socket
;
465 gotd_listen
.connection_limits
= connection_limits
;
466 gotd_listen
.nconnection_limits
= nconnection_limits
;
/* Route these signals to listen_sighdlr() as events; SIGPIPE is ignored. */
468 signal_set(&evsigint
, SIGINT
, listen_sighdlr
, NULL
);
469 signal_set(&evsigterm
, SIGTERM
, listen_sighdlr
, NULL
);
470 signal_set(&evsighup
, SIGHUP
, listen_sighdlr
, NULL
);
471 signal_set(&evsigusr1
, SIGUSR1
, listen_sighdlr
, NULL
);
472 signal(SIGPIPE
, SIG_IGN
);
474 signal_add(&evsigint
, NULL
);
475 signal_add(&evsigterm
, NULL
);
476 signal_add(&evsighup
, NULL
);
477 signal_add(&evsigusr1
, NULL
);
/* imsg channel on GOTD_FILENO_MSG_PIPE, serviced by listen_dispatch(). */
479 imsg_init(&iev
.ibuf
, GOTD_FILENO_MSG_PIPE
);
480 iev
.handler
= listen_dispatch
;
481 iev
.events
= EV_READ
;
482 iev
.handler_arg
= NULL
;
483 event_set(&iev
.ev
, iev
.ibuf
.fd
, EV_READ
, listen_dispatch
, &iev
);
484 if (event_add(&iev
.ev
, NULL
) == -1)
/* Persistent read event on the listening socket. */
487 event_set(&gotd_listen
.iev
.ev
, gotd_listen
.fd
, EV_READ
| EV_PERSIST
,
489 if (event_add(&gotd_listen
.iev
.ev
, NULL
))
/* Timer used by gotd_accept() to pause accepting; fires gotd_accept_paused(). */
491 evtimer_set(&gotd_listen
.pause
.ev
, gotd_accept_paused
, NULL
);
/*
 * Release listen-process resources: log the shutdown, free the per-uid
 * connection limits array and close the listening socket if it is open
 * (fd != -1).
 */
499 listen_shutdown(void)
501 log_debug("shutting down");
503 free(gotd_listen
.connection_limits
);
504 if (gotd_listen
.fd
!= -1)
505 close(gotd_listen
.fd
);