portable: release 0.101
[got-portable.git] / gotd / listen.c
blobe2d6d8e3db728cd0c119bc1fa84eecae6dfe9605
1 /*
2 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 #include "got_compat.h"
19 #include <sys/types.h>
20 #include <sys/queue.h>
21 #include <sys/socket.h>
22 #include <sys/uio.h>
24 #include <errno.h>
25 #include <event.h>
26 #include <stdint.h>
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <string.h>
30 #include <imsg.h>
31 #include <limits.h>
32 #include <signal.h>
33 #include <unistd.h>
35 #include "got_error.h"
36 #include "got_object.h"
37 #include "got_path.h"
39 #include "got_compat.h"
41 #include "gotd.h"
42 #include "log.h"
43 #include "listen.h"
45 #ifndef nitems
46 #define nitems(_a) (sizeof((_a)) / sizeof((_a)[0]))
47 #endif
49 struct gotd_listen_client {
50 STAILQ_ENTRY(gotd_listen_client) entry;
51 uint32_t id;
52 int fd;
53 uid_t euid;
55 STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
57 static struct gotd_listen_clients gotd_listen_clients[GOTD_CLIENT_TABLE_SIZE];
58 static SIPHASH_KEY clients_hash_key;
59 static volatile int listen_client_cnt;
60 static int inflight;
62 struct gotd_uid_connection_counter {
63 STAILQ_ENTRY(gotd_uid_connection_counter) entry;
64 uid_t euid;
65 int nconnections;
67 STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
68 static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
69 static SIPHASH_KEY uid_hash_key;
71 static struct {
72 pid_t pid;
73 const char *title;
74 int fd;
75 struct gotd_imsgev iev;
76 struct gotd_imsgev pause;
77 struct gotd_uid_connection_limit *connection_limits;
78 size_t nconnection_limits;
79 } gotd_listen;
81 static int inflight;
83 static void listen_shutdown(void);
/*
 * libevent callback for the signals registered in listen_main().
 * libevent decouples delivery from the actual signal context, so normal
 * signal-handler restrictions do not apply here.
 * SIGHUP and SIGUSR1 are ignored; SIGTERM and SIGINT shut the process down;
 * any other signal is fatal.
 */
static void
listen_sighdlr(int sig, short event, void *arg)
{
	if (sig == SIGHUP || sig == SIGUSR1)
		return;

	if (sig == SIGTERM || sig == SIGINT) {
		listen_shutdown();
		/* NOTREACHED */
		return;
	}

	fatalx("unexpected signal");
}
108 static uint64_t
109 client_hash(uint32_t client_id)
111 return SipHash24(&clients_hash_key, &client_id, sizeof(client_id));
114 static void
115 add_client(struct gotd_listen_client *client)
117 uint64_t slot = client_hash(client->id) % nitems(gotd_listen_clients);
118 STAILQ_INSERT_HEAD(&gotd_listen_clients[slot], client, entry);
119 listen_client_cnt++;
122 static struct gotd_listen_client *
123 find_client(uint32_t client_id)
125 uint64_t slot;
126 struct gotd_listen_client *c;
128 slot = client_hash(client_id) % nitems(gotd_listen_clients);
129 STAILQ_FOREACH(c, &gotd_listen_clients[slot], entry) {
130 if (c->id == client_id)
131 return c;
134 return NULL;
/* Draw random ids until one is non-zero and not used by any current client. */
static uint32_t
get_client_id(void)
{
	uint32_t candidate;

	for (;;) {
		candidate = arc4random();
		if (find_client(candidate) == NULL && candidate != 0)
			return candidate;
	}
}
151 static uint64_t
152 uid_hash(uid_t euid)
154 return SipHash24(&uid_hash_key, &euid, sizeof(euid));
157 static void
158 add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
160 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
161 STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
164 static void
165 remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
167 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
168 STAILQ_REMOVE(&gotd_client_uids[slot], counter,
169 gotd_uid_connection_counter, entry);
172 static struct gotd_uid_connection_counter *
173 find_uid_connection_counter(uid_t euid)
175 uint64_t slot;
176 struct gotd_uid_connection_counter *c;
178 slot = uid_hash(euid) % nitems(gotd_client_uids);
179 STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
180 if (c->euid == euid)
181 return c;
184 return NULL;
187 static const struct got_error *
188 disconnect(struct gotd_listen_client *client)
190 struct gotd_uid_connection_counter *counter;
191 uint64_t slot;
192 int client_fd;
194 log_debug("client on fd %d disconnecting", client->fd);
196 slot = client_hash(client->id) % nitems(gotd_listen_clients);
197 STAILQ_REMOVE(&gotd_listen_clients[slot], client,
198 gotd_listen_client, entry);
200 counter = find_uid_connection_counter(client->euid);
201 if (counter) {
202 if (counter->nconnections > 0)
203 counter->nconnections--;
204 if (counter->nconnections == 0) {
205 remove_uid_connection_counter(counter);
206 free(counter);
210 client_fd = client->fd;
211 free(client);
212 inflight--;
213 listen_client_cnt--;
214 if (close(client_fd) == -1)
215 return got_error_from_errno("close");
217 return NULL;
220 static int
221 accept_reserve(int fd, struct sockaddr *addr, socklen_t *addrlen,
222 int reserve, volatile int *counter)
224 int ret;
225 int sock_flags = SOCK_NONBLOCK;
227 #ifdef SOCK_CLOEXEC
228 sock_flags |= SOCK_CLOEXEC;
229 #endif
231 if (getdtablecount() + reserve +
232 ((*counter + 1) * GOTD_FD_NEEDED) >= getdtablesize()) {
233 log_debug("inflight fds exceeded");
234 errno = EMFILE;
235 return -1;
237 #ifdef __APPLE__
238 /* TA: silence warning from GCC. */
239 (void)sock_flags;
240 ret = accept(fd, addr, addrlen);
241 #else
242 ret = accept4(fd, addr, addrlen, sock_flags);
243 #endif
245 if (ret > -1) {
246 (*counter)++;
249 return ret;
252 static void
253 gotd_accept_paused(int fd, short event, void *arg)
255 event_add(&gotd_listen.iev.ev, NULL);
258 static void
259 gotd_accept(int fd, short event, void *arg)
261 struct gotd_imsgev *iev = arg;
262 struct sockaddr_storage ss;
263 struct timeval backoff;
264 socklen_t len;
265 int s = -1;
266 struct gotd_listen_client *client = NULL;
267 struct gotd_uid_connection_counter *counter = NULL;
268 struct gotd_imsg_connect iconn;
269 uid_t euid;
270 gid_t egid;
272 backoff.tv_sec = 1;
273 backoff.tv_usec = 0;
275 if (event_add(&gotd_listen.iev.ev, NULL) == -1) {
276 log_warn("event_add");
277 return;
279 if (event & EV_TIMEOUT)
280 return;
282 len = sizeof(ss);
284 /* Other backoff conditions apart from EMFILE/ENFILE? */
285 s = accept_reserve(fd, (struct sockaddr *)&ss, &len, GOTD_FD_RESERVE,
286 &inflight);
287 if (s == -1) {
288 switch (errno) {
289 case EINTR:
290 case EWOULDBLOCK:
291 case ECONNABORTED:
292 return;
293 case EMFILE:
294 case ENFILE:
295 event_del(&gotd_listen.iev.ev);
296 evtimer_add(&gotd_listen.pause.ev, &backoff);
297 return;
298 default:
299 log_warn("accept");
300 return;
304 if (listen_client_cnt >= GOTD_MAXCLIENTS)
305 goto err;
307 if (getpeereid(s, &euid, &egid) == -1) {
308 log_warn("getpeerid");
309 goto err;
312 counter = find_uid_connection_counter(euid);
313 if (counter == NULL) {
314 counter = calloc(1, sizeof(*counter));
315 if (counter == NULL) {
316 log_warn("%s: calloc", __func__);
317 goto err;
319 counter->euid = euid;
320 counter->nconnections = 1;
321 add_uid_connection_counter(counter);
322 } else {
323 int max_connections = GOTD_MAX_CONN_PER_UID;
324 struct gotd_uid_connection_limit *limit;
326 limit = gotd_find_uid_connection_limit(
327 gotd_listen.connection_limits,
328 gotd_listen.nconnection_limits, euid);
329 if (limit)
330 max_connections = limit->max_connections;
332 if (counter->nconnections >= max_connections) {
333 log_warnx("maximum connections exceeded for uid %d",
334 euid);
335 goto err;
337 counter->nconnections++;
340 client = calloc(1, sizeof(*client));
341 if (client == NULL) {
342 log_warn("%s: calloc", __func__);
343 goto err;
345 client->id = get_client_id();
346 client->fd = s;
347 client->euid = euid;
348 s = -1;
349 add_client(client);
350 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
351 client->fd, euid, egid);
353 memset(&iconn, 0, sizeof(iconn));
354 iconn.client_id = client->id;
355 iconn.euid = euid;
356 iconn.egid = egid;
357 s = dup(client->fd);
358 if (s == -1) {
359 log_warn("%s: dup", __func__);
360 goto err;
362 if (gotd_imsg_compose_event(iev, GOTD_IMSG_CONNECT, PROC_LISTEN, s,
363 &iconn, sizeof(iconn)) == -1) {
364 log_warn("imsg compose CONNECT");
365 goto err;
368 return;
369 err:
370 inflight--;
371 if (client)
372 disconnect(client);
373 if (s != -1)
374 close(s);
377 static const struct got_error *
378 recv_disconnect(struct imsg *imsg)
380 struct gotd_imsg_disconnect idisconnect;
381 size_t datalen;
382 struct gotd_listen_client *client = NULL;
384 datalen = imsg->hdr.len - IMSG_HEADER_SIZE;
385 if (datalen != sizeof(idisconnect))
386 return got_error(GOT_ERR_PRIVSEP_LEN);
387 memcpy(&idisconnect, imsg->data, sizeof(idisconnect));
389 log_debug("client disconnecting");
391 client = find_client(idisconnect.client_id);
392 if (client == NULL)
393 return got_error(GOT_ERR_CLIENT_ID);
395 return disconnect(client);
398 static void
399 listen_dispatch(int fd, short event, void *arg)
401 const struct got_error *err = NULL;
402 struct gotd_imsgev *iev = arg;
403 struct imsgbuf *ibuf = &iev->ibuf;
404 struct imsg imsg;
405 ssize_t n;
406 int shut = 0;
408 if (event & EV_READ) {
409 if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN)
410 fatal("imsg_read error");
411 if (n == 0) /* Connection closed. */
412 shut = 1;
415 if (event & EV_WRITE) {
416 n = msgbuf_write(&ibuf->w);
417 if (n == -1 && errno != EAGAIN)
418 fatal("msgbuf_write");
419 if (n == 0) /* Connection closed. */
420 shut = 1;
423 for (;;) {
424 if ((n = imsg_get(ibuf, &imsg)) == -1)
425 fatal("%s: imsg_get", __func__);
426 if (n == 0) /* No more messages. */
427 break;
429 switch (imsg.hdr.type) {
430 case GOTD_IMSG_DISCONNECT:
431 err = recv_disconnect(&imsg);
432 if (err)
433 log_warnx("disconnect: %s", err->msg);
434 break;
435 default:
436 log_debug("unexpected imsg %d", imsg.hdr.type);
437 break;
440 imsg_free(&imsg);
443 if (!shut) {
444 gotd_imsg_event_add(iev);
445 } else {
446 /* This pipe is dead. Remove its event handler */
447 event_del(&iev->ev);
448 event_loopexit(NULL);
452 void
453 listen_main(const char *title, int gotd_socket,
454 struct gotd_uid_connection_limit *connection_limits,
455 size_t nconnection_limits)
457 struct gotd_imsgev iev;
458 struct event evsigint, evsigterm, evsighup, evsigusr1;
460 arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
461 arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
463 gotd_listen.title = title;
464 gotd_listen.pid = getpid();
465 gotd_listen.fd = gotd_socket;
466 gotd_listen.connection_limits = connection_limits;
467 gotd_listen.nconnection_limits = nconnection_limits;
469 signal_set(&evsigint, SIGINT, listen_sighdlr, NULL);
470 signal_set(&evsigterm, SIGTERM, listen_sighdlr, NULL);
471 signal_set(&evsighup, SIGHUP, listen_sighdlr, NULL);
472 signal_set(&evsigusr1, SIGUSR1, listen_sighdlr, NULL);
473 signal(SIGPIPE, SIG_IGN);
475 signal_add(&evsigint, NULL);
476 signal_add(&evsigterm, NULL);
477 signal_add(&evsighup, NULL);
478 signal_add(&evsigusr1, NULL);
480 imsg_init(&iev.ibuf, GOTD_FILENO_MSG_PIPE);
481 iev.handler = listen_dispatch;
482 iev.events = EV_READ;
483 iev.handler_arg = NULL;
484 event_set(&iev.ev, iev.ibuf.fd, EV_READ, listen_dispatch, &iev);
485 if (event_add(&iev.ev, NULL) == -1)
486 fatalx("event add");
488 event_set(&gotd_listen.iev.ev, gotd_listen.fd, EV_READ | EV_PERSIST,
489 gotd_accept, &iev);
490 if (event_add(&gotd_listen.iev.ev, NULL))
491 fatalx("event add");
492 evtimer_set(&gotd_listen.pause.ev, gotd_accept_paused, NULL);
494 event_dispatch();
496 listen_shutdown();
499 static void
500 listen_shutdown(void)
502 log_debug("%s: shutting down", gotd_listen.title);
504 free(gotd_listen.connection_limits);
505 if (gotd_listen.fd != -1)
506 close(gotd_listen.fd);
508 exit(0);