2 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 #include <sys/types.h>
18 #include <sys/queue.h>
19 #include <sys/socket.h>
34 #include "got_error.h"
/* Element count of a true array; not meaningful when applied to a pointer. */
42 #define nitems(_a) (sizeof((_a)) / sizeof((_a)[0]))
/*
 * Bookkeeping for one accepted client connection, linked into a
 * gotd_listen_clients hash bucket through `entry`. The other members
 * this file uses (id, fd, euid) are declared on lines not shown here.
 */
45 struct gotd_listen_client
{
46 STAILQ_ENTRY(gotd_listen_client
) entry
;
/*
 * Hash table of connected clients. A client lives in bucket
 * client_hash(id) % nitems(gotd_listen_clients). clients_hash_key is
 * filled with random bytes in listen_main.
 */
51 STAILQ_HEAD(gotd_listen_clients
, gotd_listen_client
);
53 static struct gotd_listen_clients gotd_listen_clients
[GOTD_CLIENT_TABLE_SIZE
];
54 static SIPHASH_KEY clients_hash_key
;
/*
 * Client count compared against GOTD_MAXCLIENTS in gotd_accept.
 * NOTE(review): presumably this is also the `counter` argument passed
 * to accept_reserve (that argument is on an elided line) - confirm.
 */
55 static volatile int listen_client_cnt
;
/*
 * Open-connection count for one effective uid, linked into a
 * gotd_client_uids hash bucket through `entry`. The members euid and
 * nconnections used elsewhere are declared on lines not shown here.
 */
58 struct gotd_uid_connection_counter
{
59 STAILQ_ENTRY(gotd_uid_connection_counter
) entry
;
/*
 * Hash table of per-uid connection counters. A counter lives in bucket
 * uid_hash(euid) % nitems(gotd_client_uids). uid_hash_key is filled
 * with random bytes in listen_main.
 */
63 STAILQ_HEAD(gotd_client_uids
, gotd_uid_connection_counter
);
64 static struct gotd_client_uids gotd_client_uids
[GOTD_CLIENT_TABLE_SIZE
];
65 static SIPHASH_KEY uid_hash_key
;
/*
 * Members of the gotd_listen state; the enclosing declaration (and the
 * title/pid/fd members assigned in listen_main) are on lines not shown.
 *   iev:   event on the listening socket gotd_listen.fd.
 *   pause: backoff timer; gotd_accept_paused re-adds iev.ev when it fires.
 *   connection_limits / nconnection_limits: per-uid limits handed to
 *          listen_main; the array is freed in listen_shutdown.
 */
71 struct gotd_imsgev iev
;
72 struct gotd_imsgev pause
;
73 struct gotd_uid_connection_limit
*connection_limits
;
74 size_t nconnection_limits
;
/* Forward declaration; the definition appears later in this file. */
79 static void listen_shutdown(void);
/*
 * libevent signal callback, registered in listen_main for SIGINT,
 * SIGTERM, SIGHUP and SIGUSR1. The per-signal handling is on lines not
 * shown; an unexpected signal ends in fatalx().
 */
82 listen_sighdlr(int sig
, short event
, void *arg
)
85 * Normal signal handler rules don't apply because libevent
100 fatalx("unexpected signal");
/*
 * Keyed SipHash-2-4 of a client id. The key is randomized per process in
 * listen_main, so bucket placement cannot be predicted from outside.
 * Callers reduce the result modulo nitems(gotd_listen_clients).
 */
105 client_hash(uint32_t client_id
)
107 return SipHash24(&clients_hash_key
, &client_id
, sizeof(client_id
));
/*
 * Insert a client at the head of its hash bucket. No duplicate check is
 * done here; uniqueness of client->id is expected from the id generator.
 */
111 add_client(struct gotd_listen_client
*client
)
113 uint64_t slot
= client_hash(client
->id
) % nitems(gotd_listen_clients
);
114 STAILQ_INSERT_HEAD(&gotd_listen_clients
[slot
], client
, entry
);
/*
 * Look up a client by id by scanning only the bucket its id hashes to.
 * The return statements are on elided lines. NOTE(review): presumably
 * returns the matching client, or NULL when none is found (recv_disconnect
 * and the id generator treat it that way) - confirm.
 */
118 static struct gotd_listen_client
*
119 find_client(uint32_t client_id
)
122 struct gotd_listen_client
*c
;
124 slot
= client_hash(client_id
) % nitems(gotd_listen_clients
);
125 STAILQ_FOREACH(c
, &gotd_listen_clients
[slot
], entry
) {
126 if (c
->id
== client_id
)
/*
 * Tail of the client-id generator loop (the function header is not
 * visible; presumably get_client_id, which gotd_accept calls). The loop
 * retries while the candidate id is already in use or is zero, so 0 is
 * never handed out as an id.
 */
141 duplicate
= (find_client(id
) != NULL
);
142 } while (duplicate
|| id
== 0);
/*
 * Body of uid_hash (header not visible): keyed SipHash-2-4 of an
 * effective uid, using uid_hash_key, which is randomized in listen_main.
 */
150 return SipHash24(&uid_hash_key
, &euid
, sizeof(euid
));
/* Insert a per-uid counter at the head of its gotd_client_uids bucket. */
154 add_uid_connection_counter(struct gotd_uid_connection_counter
*counter
)
156 uint64_t slot
= uid_hash(counter
->euid
) % nitems(gotd_client_uids
);
157 STAILQ_INSERT_HEAD(&gotd_client_uids
[slot
], counter
, entry
);
/*
 * Unlink a per-uid counter from its bucket. Memory is not released here;
 * the caller owns the counter after removal.
 */
161 remove_uid_connection_counter(struct gotd_uid_connection_counter
*counter
)
163 uint64_t slot
= uid_hash(counter
->euid
) % nitems(gotd_client_uids
);
164 STAILQ_REMOVE(&gotd_client_uids
[slot
], counter
,
165 gotd_uid_connection_counter
, entry
);
/*
 * Look up the connection counter for an effective uid by scanning its
 * bucket. The loop body and returns are on elided lines; gotd_accept
 * treats a NULL result as "no counter yet for this uid".
 */
168 static struct gotd_uid_connection_counter
*
169 find_uid_connection_counter(uid_t euid
)
172 struct gotd_uid_connection_counter
*c
;
174 slot
= uid_hash(euid
) % nitems(gotd_client_uids
);
175 STAILQ_FOREACH(c
, &gotd_client_uids
[slot
], entry
) {
/*
 * Tear down one client: unlink it from the client table, decrement the
 * per-uid connection count (unlinking the counter once it reaches 0),
 * and close the client's descriptor.
 * Returns an error from got_error_from_errno("close") if close(2) fails.
 */
183 static const struct got_error
*
184 disconnect(struct gotd_listen_client
*client
)
186 struct gotd_uid_connection_counter
*counter
;
190 log_debug("client on fd %d disconnecting", client
->fd
);
192 slot
= client_hash(client
->id
) % nitems(gotd_listen_clients
);
193 STAILQ_REMOVE(&gotd_listen_clients
[slot
], client
,
194 gotd_listen_client
, entry
);
196 counter
= find_uid_connection_counter(client
->euid
);
/*
 * NOTE(review): counter is dereferenced below without a visible NULL
 * check; confirm the elided line 197 guards against
 * find_uid_connection_counter() returning NULL.
 */
198 if (counter
->nconnections
> 0)
199 counter
->nconnections
--;
200 if (counter
->nconnections
== 0) {
201 remove_uid_connection_counter(counter
);
/* Copy the fd first; the client structure may be released before close. */
206 client_fd
= client
->fd
;
210 if (close(client_fd
) == -1)
211 return got_error_from_errno("close");
/*
 * accept4(2) wrapper that refuses to accept when the process is close to
 * its descriptor limit. The estimate is: descriptors already open, plus
 * `reserve`, plus GOTD_FD_NEEDED for each in-flight connection including
 * the new one (*counter + 1). Accepted sockets are created non-blocking
 * and close-on-exec. The failure return and errno for the refusal path
 * are on elided lines.
 */
217 accept_reserve(int fd
, struct sockaddr
*addr
, socklen_t
*addrlen
,
218 int reserve
, volatile int *counter
)
222 if (getdtablecount() + reserve
+
223 ((*counter
+ 1) * GOTD_FD_NEEDED
) >= getdtablesize()) {
224 log_debug("inflight fds exceeded");
229 if ((ret
= accept4(fd
, addr
, addrlen
,
230 SOCK_NONBLOCK
| SOCK_CLOEXEC
)) > -1) {
/*
 * Timer callback for gotd_listen.pause (set up in listen_main). When the
 * accept backoff expires, re-add the listening socket event. The callback
 * arguments are unused in the visible lines.
 */
238 gotd_accept_paused(int fd
, short event
, void *arg
)
240 event_add(&gotd_listen
.iev
.ev
, NULL
);
/*
 * Read callback for the listening socket. The registration in listen_main
 * has its callback argument on an elided line - presumably this function.
 * Visible steps:
 *  - re-add the listen event and handle EV_TIMEOUT (body elided);
 *  - accept via accept_reserve() with GOTD_FD_RESERVE; on the (elided)
 *    failure paths, drop the listen event and arm the backoff timer;
 *  - enforce GOTD_MAXCLIENTS and read peer credentials with getpeereid();
 *  - create or bump the per-euid connection counter, enforcing a per-uid
 *    limit (GOTD_MAX_CONN_PER_UID unless gotd_find_uid_connection_limit
 *    supplies one);
 *  - allocate a client with a fresh id and send GOTD_IMSG_CONNECT over
 *    `iev`, passing socket `s` along with the message.
 */
244 gotd_accept(int fd
, short event
, void *arg
)
246 struct gotd_imsgev
*iev
= arg
;
247 struct sockaddr_storage ss
;
248 struct timeval backoff
;
251 struct gotd_listen_client
*client
= NULL
;
252 struct gotd_uid_connection_counter
*counter
= NULL
;
253 struct gotd_imsg_connect iconn
;
260 if (event_add(&gotd_listen
.iev
.ev
, NULL
) == -1) {
261 log_warn("event_add");
264 if (event
& EV_TIMEOUT
)
269 /* Other backoff conditions apart from EMFILE/ENFILE? */
270 s
= accept_reserve(fd
, (struct sockaddr
*)&ss
, &len
, GOTD_FD_RESERVE
,
/*
 * Stop accepting until the backoff timer fires and
 * gotd_accept_paused re-adds the listen event.
 */
280 event_del(&gotd_listen
.iev
.ev
);
281 evtimer_add(&gotd_listen
.pause
.ev
, &backoff
);
289 if (listen_client_cnt
>= GOTD_MAXCLIENTS
)
/*
 * NOTE(review): the log message below says "getpeerid" but the call is
 * getpeereid(2); consider fixing the string.
 */
292 if (getpeereid(s
, &euid
, &egid
) == -1) {
293 log_warn("getpeerid");
297 counter
= find_uid_connection_counter(euid
);
/*
 * First connection for this uid: start a counter at 1.
 * NOTE(review): in the visible lines this path does not consult the
 * per-uid limit, so a configured limit of 0 would not block the first
 * connection - confirm this is intended.
 */
298 if (counter
== NULL
) {
299 counter
= calloc(1, sizeof(*counter
));
300 if (counter
== NULL
) {
301 log_warn("%s: calloc", __func__
);
304 counter
->euid
= euid
;
305 counter
->nconnections
= 1;
306 add_uid_connection_counter(counter
);
/* Existing counter: the branch head is on an elided line. */
308 int max_connections
= GOTD_MAX_CONN_PER_UID
;
309 struct gotd_uid_connection_limit
*limit
;
311 limit
= gotd_find_uid_connection_limit(
312 gotd_listen
.connection_limits
,
313 gotd_listen
.nconnection_limits
, euid
);
/*
 * NOTE(review): limit is dereferenced below with no visible NULL check;
 * confirm the elided line 314 guards the case where no limit is
 * configured for this uid.
 */
315 max_connections
= limit
->max_connections
;
/*
 * NOTE(review): uid_t/gid_t are unsigned on OpenBSD; "%d" in this and
 * the later log_debug format should probably be "%u" (or cast the args).
 */
317 if (counter
->nconnections
>= max_connections
) {
318 log_warnx("maximum connections exceeded for uid %d",
322 counter
->nconnections
++;
325 client
= calloc(1, sizeof(*client
));
326 if (client
== NULL
) {
327 log_warn("%s: calloc", __func__
);
330 client
->id
= get_client_id();
335 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__
,
336 client
->fd
, euid
, egid
);
338 memset(&iconn
, 0, sizeof(iconn
));
339 iconn
.client_id
= client
->id
;
344 log_warn("%s: dup", __func__
);
347 if (gotd_imsg_compose_event(iev
, GOTD_IMSG_CONNECT
, PROC_LISTEN
, s
,
348 &iconn
, sizeof(iconn
)) == -1) {
349 log_warn("imsg compose CONNECT");
/*
 * Handle a GOTD_IMSG_DISCONNECT message: check that the payload is
 * exactly one struct gotd_imsg_disconnect, look up the client by id and
 * disconnect it.
 * Returns GOT_ERR_PRIVSEP_LEN on a payload size mismatch, and
 * GOT_ERR_CLIENT_ID presumably when the id is unknown (that condition is
 * on an elided line); otherwise returns the result of disconnect().
 */
362 static const struct got_error
*
363 recv_disconnect(struct imsg
*imsg
)
365 struct gotd_imsg_disconnect idisconnect
;
367 struct gotd_listen_client
*client
= NULL
;
369 datalen
= imsg
->hdr
.len
- IMSG_HEADER_SIZE
;
370 if (datalen
!= sizeof(idisconnect
))
371 return got_error(GOT_ERR_PRIVSEP_LEN
);
372 memcpy(&idisconnect
, imsg
->data
, sizeof(idisconnect
));
374 log_debug("client disconnecting");
376 client
= find_client(idisconnect
.client_id
);
378 return got_error(GOT_ERR_CLIENT_ID
);
380 return disconnect(client
);
/*
 * Event handler for the message pipe set up in listen_main on
 * GOTD_FILENO_MSG_PIPE.
 *  - On EV_READ it reads pending imsgs; on EV_WRITE it flushes queued
 *    output. Errors other than EAGAIN are fatal.
 *  - It then drains messages: GOTD_IMSG_DISCONNECT goes to
 *    recv_disconnect(), and any other type is only logged.
 *  - It re-arms itself via gotd_imsg_event_add().
 *  - When the pipe is closed (read/write returned 0), it leaves the
 *    event loop.
 */
384 listen_dispatch(int fd
, short event
, void *arg
)
386 const struct got_error
*err
= NULL
;
387 struct gotd_imsgev
*iev
= arg
;
388 struct imsgbuf
*ibuf
= &iev
->ibuf
;
393 if (event
& EV_READ
) {
394 if ((n
= imsg_read(ibuf
)) == -1 && errno
!= EAGAIN
)
395 fatal("imsg_read error");
396 if (n
== 0) /* Connection closed. */
400 if (event
& EV_WRITE
) {
401 n
= msgbuf_write(&ibuf
->w
);
402 if (n
== -1 && errno
!= EAGAIN
)
403 fatal("msgbuf_write");
404 if (n
== 0) /* Connection closed. */
409 if ((n
= imsg_get(ibuf
, &imsg
)) == -1)
410 fatal("%s: imsg_get", __func__
);
411 if (n
== 0) /* No more messages. */
414 switch (imsg
.hdr
.type
) {
415 case GOTD_IMSG_DISCONNECT
:
416 err
= recv_disconnect(&imsg
);
/* NOTE(review): presumably guarded by `if (err)` on elided line 417. */
418 log_warnx("disconnect: %s", err
->msg
);
421 log_debug("unexpected imsg %d", imsg
.hdr
.type
);
429 gotd_imsg_event_add(iev
);
431 /* This pipe is dead. Remove its event handler */
433 event_loopexit(NULL
);
/*
 * Entry point of the listen process.
 *  - Randomizes the client and uid hash keys.
 *  - Records the process title, pid, listening socket and per-uid
 *    connection limits in gotd_listen. The limits array is later freed by
 *    listen_shutdown, so ownership passes to this process state.
 *  - Installs libevent handlers for SIGINT/SIGTERM/SIGHUP/SIGUSR1 and
 *    ignores SIGPIPE.
 *  - Registers listen_dispatch on GOTD_FILENO_MSG_PIPE, a persistent read
 *    event on the listening socket, and the accept backoff timer.
 * Running the event loop is presumably on elided lines.
 */
438 listen_main(const char *title
, int gotd_socket
,
439 struct gotd_uid_connection_limit
*connection_limits
,
440 size_t nconnection_limits
)
442 struct gotd_imsgev iev
;
443 struct event evsigint
, evsigterm
, evsighup
, evsigusr1
;
/* Per-process random keys keep hash bucket placement unpredictable. */
445 arc4random_buf(&clients_hash_key
, sizeof(clients_hash_key
));
446 arc4random_buf(&uid_hash_key
, sizeof(uid_hash_key
));
448 gotd_listen
.title
= title
;
449 gotd_listen
.pid
= getpid();
450 gotd_listen
.fd
= gotd_socket
;
451 gotd_listen
.connection_limits
= connection_limits
;
452 gotd_listen
.nconnection_limits
= nconnection_limits
;
454 signal_set(&evsigint
, SIGINT
, listen_sighdlr
, NULL
);
455 signal_set(&evsigterm
, SIGTERM
, listen_sighdlr
, NULL
);
456 signal_set(&evsighup
, SIGHUP
, listen_sighdlr
, NULL
);
457 signal_set(&evsigusr1
, SIGUSR1
, listen_sighdlr
, NULL
);
458 signal(SIGPIPE
, SIG_IGN
);
460 signal_add(&evsigint
, NULL
);
461 signal_add(&evsigterm
, NULL
);
462 signal_add(&evsighup
, NULL
);
463 signal_add(&evsigusr1
, NULL
);
465 imsg_init(&iev
.ibuf
, GOTD_FILENO_MSG_PIPE
);
466 iev
.handler
= listen_dispatch
;
467 iev
.events
= EV_READ
;
468 iev
.handler_arg
= NULL
;
469 event_set(&iev
.ev
, iev
.ibuf
.fd
, EV_READ
, listen_dispatch
, &iev
);
470 if (event_add(&iev
.ev
, NULL
) == -1)
/* Persistent read event: fires for each pending connection. */
473 event_set(&gotd_listen
.iev
.ev
, gotd_listen
.fd
, EV_READ
| EV_PERSIST
,
475 if (event_add(&gotd_listen
.iev
.ev
, NULL
))
/* Backoff timer armed by gotd_accept when it pauses accepting. */
477 evtimer_set(&gotd_listen
.pause
.ev
, gotd_accept_paused
, NULL
);
/*
 * Release listen-process resources: free the per-uid connection limits
 * array stored by listen_main, and close the listening socket if it is
 * still open. Any remaining steps are on elided lines.
 */
485 listen_shutdown(void)
487 log_debug("shutting down");
489 free(gotd_listen
.connection_limits
);
490 if (gotd_listen
.fd
!= -1)
491 close(gotd_listen
.fd
);