Prepare required data folder for integration tests
[prosody.git] / net / server_epoll.lua
blobf48086e32b131739e7a80634fd7871e1c04679ac
1 -- Prosody IM
2 -- Copyright (C) 2016-2018 Kim Alvefur
3 --
4 -- This project is MIT/X11 licensed. Please see the
5 -- COPYING file in the source package for more information.
6 --
9 local t_insert = table.insert;
10 local t_concat = table.concat;
11 local setmetatable = setmetatable;
12 local pcall = pcall;
13 local type = type;
14 local next = next;
15 local pairs = pairs;
16 local logger = require "util.logger";
17 local log = logger.init("server_epoll");
18 local socket = require "socket";
19 local luasec = require "ssl";
20 local gettime = require "util.time".now;
21 local indexedbheap = require "util.indexedbheap";
22 local createtable = require "util.table".create;
23 local inet = require "util.net";
24 local inet_pton = inet.pton;
25 local _SOCKETINVALID = socket._SOCKETINVALID or -1;
26 local new_id = require "util.id".medium;
28 local poller = require "util.poll"
29 local EEXIST = poller.EEXIST;
30 local ENOENT = poller.ENOENT;
32 local poll = assert(poller.new());
34 local _ENV = nil;
35 -- luacheck: std none
-- Built-in defaults for every tunable. A custom configuration table is given
-- this table as its metatable (see set_config), so keys it does not set are
-- looked up in the __index table below.
local default_config = { __index = {
	-- Close a connection that has been silent for this long, unless the
	-- onreadtimeout listener says to keep it
	read_timeout = 14 * 60,

	-- How long to wait for a socket to become writable once data has been queued
	send_timeout = 180,

	-- How long to wait for a freshly created socket to become writable
	connect_timeout = 20,

	-- Some number possibly influencing how many pending connections can be accepted
	tcp_backlog = 128,

	-- How long a server is paused after failing to accept an incoming connection
	accept_retry_interval = 10,

	-- If there is still more data to read from LuaSocket's buffer, wait this long and read again
	read_retry_delay = 1e-06,

	-- Size of chunks to read from sockets
	read_size = 8192,

	-- Timeout used between steps in TLS handshakes
	ssl_handshake_timeout = 60,

	-- Upper and lower bound on how long to sleep waiting for events
	-- (adjusted for pending timers)
	max_wait = 86400,
	min_wait = 1e-06,

	-- EXPERIMENTAL
	-- Whether to kill connections in case of callback errors.
	fatal_errors = false,

	-- Attempt writes instantly
	opportunistic_writes = false,
} };
local cfg = default_config.__index;
75 local fds = createtable(10, 0); -- FD -> conn
77 -- Timer and scheduling --
79 local timers = indexedbheap.create();
-- Shared do-nothing function; assigned over callbacks and methods to disable them
local function noop()
	return;
end
-- Cancel a timer: reset its time, replace its callback with noop
-- and remove it from the timer heap.
local function closetimer(t)
	t[1], t[2] = 0, noop;
	timers:remove(t.id);
end
-- Move an existing timer to a new absolute time, both in the timer
-- object itself and in the heap that orders timers.
local function reschedule(t, time)
	local id = t.id;
	t[1] = time;
	timers:reprioritize(id, time);
end
-- Schedule f to run at the absolute time given.
-- Returns the timer object: [1] is the time, [2] the callback, and it
-- carries close() and reschedule() methods plus the heap id.
local function at(time, f)
	local timer = {
		time, f,
		close = closetimer,
		reschedule = reschedule,
		id = nil,
	};
	timer.id = timers:insert(timer, time);
	return timer;
end
-- Schedule f to run `timeout` from now; thin wrapper around at()
local function addtimer(timeout, f)
	local when = gettime() + timeout;
	return at(when, f);
end
-- Run callbacks of expired timers
-- next_delay: value to return when no pending timer shortens the wait
-- min_wait: lower bound applied to the returned delay while timers remain
-- Returns the time until the next timeout
local function runtimers(next_delay, min_wait)
	local now = gettime();
	local peek = timers:peek();
	while peek do

		if peek > now then
			-- Earliest timer is still in the future; that is how long we may sleep
			next_delay = peek - now;
			break;
		end

		local _, timer = timers:pop();
		local ok, ret = pcall(timer[2], now);
		if not ok then
			-- Report the failure instead of silently dropping it
			log("error", "Error in timer: %s", ret);
		elseif type(ret) == "number" then
			-- A callback returning a number asks to be run again that much later
			local next_time = now+ret;
			timer[1] = next_time;
			-- Store the id returned by the heap so that timer:close() and
			-- timer:reschedule() keep addressing the right entry
			timer.id = timers:insert(timer, next_time);
		end

		peek = timers:peek();
	end
	if peek == nil then
		-- No timers left at all
		return next_delay;
	end

	if next_delay < min_wait then
		return min_wait;
	end
	return next_delay;
end
138 -- Socket handler interface
140 local interface = {};
141 local interface_mt = { __index = interface };
-- Human-readable description: the FD number followed by whichever of the
-- peer and local address/port pairs are known.
function interface_mt:__tostring()
	local fd = self:getfd();
	local sockname, peername = self.sockname, self.peername;
	if sockname and peername then
		return ("FD %d (%s, %d, %s, %d)"):format(fd, peername, self.peerport, sockname, self.sockport);
	end
	if sockname or peername then
		return ("FD %d (%s, %d)"):format(fd, sockname or peername, self.sockport or self.peerport);
	end
	return ("FD %d"):format(fd);
end
152 interface.log = log;
-- Log a message at "debug" level through this object's own logger
function interface:debug(msg, ...) --luacheck: ignore 212/self
	local logfn = self.log;
	logfn("debug", msg, ...);
end
-- Log a message at "error" level through this object's own logger
function interface:error(msg, ...) --luacheck: ignore 212/self
	local logfn = self.log;
	logfn("error", msg, ...);
end
-- Swap in a new set of listener callbacks.
-- The old listeners get "detach" first; the new ones get "attach" with `data`.
function interface:setlistener(listeners, data)
	self:on("detach"); -- still dispatched to the previous listeners
	self.listeners = listeners;
	self:on("attach", data);
end
-- Invoke the listener callback named "on"..what, protected by pcall.
-- Returns the callback's first return value on success, or nil, err if it
-- raised an error. Returns nothing when there are no listeners or no such
-- callback. With cfg.fatal_errors set, an error also destroys the connection.
function interface:on(what, ...)
	local listeners = self.listeners;
	if not listeners then
		self:debug("Interface is missing listener callbacks");
		return;
	end
	local callback = listeners["on"..what];
	if not callback then
		-- self:debug("Missing listener 'on%s'", what); -- uncomment for development and debugging
		return;
	end
	local ok, ret = pcall(callback, self, ...);
	if ok then
		return ret;
	end
	-- From here on `ret` holds the error raised by the callback
	if cfg.fatal_errors then
		self:debug("Closing due to error calling on%s: %s", what, ret);
		self:destroy();
	else
		self:debug("Error calling on%s: %s", what, ret);
	end
	return nil, ret;
end
-- File descriptor number of the underlying socket,
-- or _SOCKETINVALID once there is no socket.
function interface:getfd()
	local conn = self.conn;
	if not conn then
		return _SOCKETINVALID;
	end
	return conn:getfd();
end
-- The server object this connection came from, or the object itself if it has none
function interface:server()
	local server = self._server;
	if server then
		return server;
	end
	return self;
end
-- Get an address: the peer's if known, otherwise the local one
function interface:ip()
	local peername = self.peername;
	if peername then
		return peername;
	end
	return self.sockname;
end
-- Get a port number: the local one if known, otherwise the peer's
function interface:port()
	local sockport = self.sockport;
	if sockport then
		return sockport;
	end
	return self.peerport;
end
-- Get the local port number (the sockport field)
function interface:clientport()
	local sockport = self.sockport;
	return sockport;
end
-- Get the port of the serving side: this object's own local port if it
-- is known, otherwise the port reported by the server that accepted it.
-- Returns nil when neither is available.
function interface:serverport()
	if self.sockport then
		return self.sockport;
	elseif self._server then
		-- The result must be returned; without `return` this branch always gave nil
		return self._server:port();
	end
end
-- Return the wrapped socket object (the conn field)
function interface:socket()
	local conn = self.conn;
	return conn;
end
-- Set the pattern/size passed to receive() for this connection;
-- onreadable() falls back to cfg.read_size when it is unset.
function interface:set_mode(new_mode)
	self.read_size = new_mode;
end
-- Forward a socket option to the underlying socket, if it supports that
function interface:setoption(k, v)
	local conn = self.conn;
	-- LuaSec doesn't expose setoption :(
	if conn.setoption then
		conn:setoption(k, v);
	end
end
-- Timeout for detecting dead or idle sockets
-- t == false cancels the timer, nil means cfg.read_timeout; anything else
-- (re)arms the timer to fire that long from now.
function interface:setreadtimeout(t)
	local timer = self._readtimeout;
	if t == false then
		if timer then
			timer:close();
			self._readtimeout = nil;
		end
		return
	end
	t = t or cfg.read_timeout;
	if timer then
		timer:reschedule(gettime() + t);
		return;
	end
	self._readtimeout = addtimer(t, function ()
		if self:on("readtimeout") then
			-- Listener returned a true value: keep the connection and
			-- have the timer run again after the default timeout
			return cfg.read_timeout;
		end
		self:on("disconnect", "read timeout");
		self:destroy();
	end);
end
-- Timeout for detecting sockets that never become writable
-- t == false cancels the timer, nil means cfg.send_timeout; anything else
-- (re)arms the timer to fire that long from now.
function interface:setwritetimeout(t)
	local timer = self._writetimeout;
	if t == false then
		if timer then
			timer:close();
			self._writetimeout = nil;
		end
		return
	end
	t = t or cfg.send_timeout;
	if timer then
		timer:reschedule(gettime() + t);
		return;
	end
	self._writetimeout = addtimer(t, function ()
		self:on("disconnect", "write timeout");
		self:destroy();
	end);
end
-- Register this object's FD with the poller.
-- r, w: whether to watch for readability / writability; nil keeps the
-- currently remembered flag. Returns true, or a false value and an error.
function interface:add(r, w)
	local fd = self:getfd();
	if fd < 0 then
		return nil, "invalid fd";
	end
	if r == nil then r = self._wantread; end
	if w == nil then w = self._wantwrite; end
	local ok, err, errno = poll:add(fd, r, w);
	if ok then
		self._wantread, self._wantwrite = r, w;
		fds[fd] = self;
		self:debug("Registered in poller");
		return true;
	end
	if errno == EEXIST then
		self:debug("FD already registered in poller! (EEXIST)");
		return self:set(r, w); -- So try to change its flags
	end
	self:debug("Could not register in poller: %s(%d)", err, errno);
	return ok, err;
end
-- Change which events the poller watches for on this FD.
-- r, w: new read / write interest; nil keeps the currently remembered flag.
-- Returns true, or a false value and an error.
function interface:set(r, w)
	local fd = self:getfd();
	if fd < 0 then
		return nil, "invalid fd";
	end
	if r == nil then r = self._wantread; end
	if w == nil then w = self._wantwrite; end
	local ok, err, errno = poll:set(fd, r, w);
	if ok then
		self._wantread, self._wantwrite = r, w;
		return true;
	end
	self:debug("Could not update poller state: %s(%d)", err, errno);
	return ok, err;
end
-- Remove this object's FD from the poller and from the FD -> conn map.
-- Returns true, or a false value and an error. A poller error of ENOENT
-- is treated as success.
function interface:del()
	local fd = self:getfd();
	if fd < 0 then
		return nil, "invalid fd";
	end
	if fds[fd] ~= self then
		return nil, "unregistered fd";
	end
	local ok, err, errno = poll:del(fd);
	if not (ok or errno == ENOENT) then
		self:debug("Could not unregister: %s(%d)", err, errno);
		return ok, err;
	end
	self._wantread, self._wantwrite = nil, nil;
	fds[fd] = nil;
	self:debug("Unregistered from poller");
	return true;
end
-- Set read/write interest, choosing between add(), set() and del()
-- depending on whether any interest is currently remembered and whether
-- any is being requested.
function interface:setflags(r, w)
	local watching = self._wantread or self._wantwrite;
	local wanted = r or w;
	if not watching then
		if not wanted then
			return true; -- no change
		end
		return self:add(r, w);
	end
	if not wanted then
		return self:del();
	end
	return self:set(r, w);
end
-- Called when socket is readable
-- Reads one chunk, hands it to the "incoming" listener, then decides whether
-- to pause (rate limit, or more data already buffered) or re-arm the read timeout.
function interface:onreadable()
	local data, err, partial = self.conn:receive(self.read_size or cfg.read_size);
	if data then
		self:onconnect();
		self:on("incoming", data);
	else
		-- "wantread"/"wantwrite" adjust the poller flags and are then
		-- handled like a plain timeout
		if err == "wantread" then
			self:set(true, nil);
			err = "timeout";
		elseif err == "wantwrite" then
			self:set(nil, true);
			err = "timeout";
		end
		if partial and partial ~= "" then
			self:onconnect();
			self:on("incoming", partial, err);
		end
		if err ~= "timeout" then
			self:on("disconnect", err);
			self:destroy()
			return;
		end
	end
	if not self.conn then return; end -- destroyed by a listener callback
	local received = data or partial;
	if self._limit and received then
		-- _limit is set by setlimit() to 1/Bps, so this is bytes divided by rate
		local cost = self._limit * #received;
		if cost > cfg.min_wait then
			self:setreadtimeout(false);
			self:pausefor(cost);
			return;
		end
	end
	if self._wantread and self.conn:dirty() then
		-- More data is already buffered; come back for it shortly
		self:setreadtimeout(false);
		self:pausefor(cfg.read_retry_delay);
	else
		self:setreadtimeout();
	end
end
-- Called when socket is writable
-- Tries to send the whole write buffer in one go, keeping whatever could not
-- be sent as the buffer's single remaining entry.
function interface:onwritable()
	self:onconnect();
	if not self.conn then return; end -- could have been closed in onconnect
	local buffer = self.writebuffer;
	local data = t_concat(buffer);
	local ok, err, partial = self.conn:send(data);
	if ok then
		-- Everything went out: stop watching for writability and empty the buffer
		self:set(nil, false);
		for i = #buffer, 1, -1 do
			buffer[i] = nil;
		end
		self:setwritetimeout(false);
		self:ondrain(); -- Be aware of writes in ondrain
		return;
	end
	if partial then
		-- Keep only the unsent tail of the data
		buffer[1] = data:sub(partial+1);
		for i = #buffer, 2, -1 do
			buffer[i] = nil;
		end
		self:set(nil, true);
		self:setwritetimeout();
	end
	if err == "wantwrite" or err == "timeout" then
		self:set(nil, true);
	elseif err == "wantread" then
		self:set(true, nil);
	else
		-- Any other error ends the connection
		self:on("disconnect", err);
		self:destroy();
	end
end
-- Called by onwritable() once the whole write buffer has been sent;
-- forwards to the "drain" listener callback and returns what on() returns.
function interface:ondrain()
	return self:on("drain");
end
-- Queue data for sending and, unless writes are paused, ask to be told when
-- the socket is writable (or try to send right away when
-- cfg.opportunistic_writes is set). Returns the length of `data`.
function interface:write(data)
	local buffer = self.writebuffer;
	if buffer then
		t_insert(buffer, data);
	else
		self.writebuffer = { data };
	end
	if self._write_lock then
		-- Writes are paused; data stays queued until resume_writes()
		return #data;
	end
	if cfg.opportunistic_writes then
		self:onwritable();
		return #data;
	end
	self:setwritetimeout();
	self:set(nil, true);
	return #data;
end
interface.send = interface.write;
-- Close the connection. If data is still queued, only stop accepting new
-- writes and finish closing from ondrain once the buffer has been flushed.
function interface:close()
	local buffer = self.writebuffer;
	if buffer and buffer[1] then
		self:set(false, true); -- Flush final buffer contents
		self.write, self.send = noop, noop; -- No more writing
		self:debug("Close after writing");
		self.ondrain = interface.close; -- run close() again when drained
		return;
	end
	self:debug("Closing now");
	self.write, self.send = noop, noop;
	self.close = noop;
	self:on("disconnect");
	self:destroy();
end
-- Tear the connection down immediately: unregister from the poller, cancel
-- both timeouts, turn the event handlers and lifecycle methods into no-ops
-- and close the underlying socket.
function interface:destroy()
	self:del();
	self:setwritetimeout(false);
	self:setreadtimeout(false);
	self.onreadable, self.onwritable = noop, noop;
	self.destroy, self.close = noop, noop;
	self.on = noop; -- no listener callbacks fire after this point
	self.conn:close();
	self.conn = nil;
end
-- Returns the _tls flag, which tlshandskake() sets to true when it starts TLS
function interface:ssl()
	local tls = self._tls;
	return tls;
end
-- Arrange for TLS to be started on this connection.
-- tls_ctx: optional context replacing the stored one.
-- If data is still queued, the switch is postponed until the buffer drains.
function interface:starttls(tls_ctx)
	if tls_ctx then self.tls_ctx = tls_ctx; end
	self.starttls = false;
	local buffer = self.writebuffer;
	if buffer and buffer[1] then
		self:debug("Start TLS after write");
		self.ondrain = interface.starttls;
		self:set(nil, true); -- make sure wantwrite is set
		return;
	end
	if self.ondrain == interface.starttls then
		-- This call was the postponed one; drop the hook
		self.ondrain = nil;
	end
	-- Both readable and writable events now drive the handshake
	self.onwritable = interface.tlshandskake;
	self.onreadable = interface.tlshandskake;
	self:set(true, true);
	self:debug("Prepared to start TLS");
end
-- Drives the TLS handshake; installed as both onreadable and onwritable.
-- The first call wraps the plain socket with LuaSec and re-registers it;
-- subsequent calls step the handshake until it completes or fails.
function interface:tlshandskake()
	self:setwritetimeout(false);
	self:setreadtimeout(false);
	if not self._tls then
		self._tls = true;
		self:debug("Starting TLS now");
		self:del();
		local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
		if not ok then
			-- wrap() raised: shift the results so that conn is false and err the message
			conn, err = ok, conn;
			self:debug("Failed to initialize TLS: %s", err);
		end
		if not conn then
			self:on("disconnect", err);
			self:destroy();
			return conn, err;
		end
		conn:settimeout(0);
		self.conn = conn;
		if conn.sni then
			if self.servername then
				conn:sni(self.servername);
			else
				local hosts = self._server and self._server.hosts;
				if type(hosts) == "table" and next(hosts) ~= nil then
					conn:sni(hosts, true);
				end
			end
		end
		self:on("starttls");
		self.ondrain = nil;
		self.onwritable = interface.tlshandskake;
		self.onreadable = interface.tlshandskake;
		-- Register the wrapped socket; the handshake proceeds on the next event
		return self:init();
	end
	local ok, err = self.conn:dohandshake();
	if ok then
		self:debug("TLS handshake complete");
		-- Remove the per-object overrides so the regular handlers apply again
		self.onwritable = nil;
		self.onreadable = nil;
		self:on("status", "ssl-handshake-complete");
		self:setwritetimeout();
		self:set(true, true);
	elseif err == "wantread" then
		self:debug("TLS handshake to wait until readable");
		self:set(true, false);
		self:setreadtimeout(cfg.ssl_handshake_timeout);
	elseif err == "wantwrite" then
		self:debug("TLS handshake to wait until writable");
		self:set(false, true);
		self:setwritetimeout(cfg.ssl_handshake_timeout);
	else
		self:debug("TLS handshake error: %s", err);
		self:on("disconnect", err);
		self:destroy();
	end
end
-- Turn a LuaSocket object into an interface object.
-- client: the socket; server: the accepting server object, if any, from
-- which read_size, tls_ctx and tls_direct are inherited; extra: optional
-- table, of which only the servername field is used here.
local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object
	client:settimeout(0);
	local conn = setmetatable({
		conn = client,
		_server = server,
		created = gettime(),
		listeners = listeners,
		read_size = read_size or (server and server.read_size),
		writebuffer = {},
		tls_ctx = tls_ctx or (server and server.tls_ctx),
		tls_direct = server and server.tls_direct,
		log = logger.init(("conn%s"):format(new_id())),
		extra = extra,
	}, interface_mt);

	if extra and extra.servername then
		conn.servername = extra.servername;
	end

	conn:updatenames();
	return conn;
end
-- Refresh the cached peer and local address/port from the socket.
-- Fields are only overwritten when the lookup actually produced a name, so
-- values set earlier (e.g. by wrapclient or listen) survive a failed lookup.
function interface:updatenames()
	local conn = self.conn;
	-- pcall guards against conn lacking the method; a failed lookup is
	-- expected to come back as nil plus an error message (LuaSocket
	-- convention -- TODO confirm), which must not be stored as name and port
	local peer_ok, peername, peerport = pcall(conn.getpeername, conn);
	if peer_ok and peername then
		self.peername, self.peerport = peername, peerport;
	end
	local sock_ok, sockname, sockport = pcall(conn.getsockname, conn);
	if sock_ok and sockname then
		self.sockname, self.sockport = sockname, sockport;
	end
end
-- A server interface has new incoming connections waiting
-- This replaces the onreadable callback on server objects (see listen)
function interface:onacceptable()
	local sock, err = self.conn:accept();
	if not sock then
		self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
		self:pausefor(cfg.accept_retry_interval);
		return;
	end
	local client = wrapsocket(sock, self, nil, self.listeners);
	client:debug("New connection %s on server %s", client, self);
	client:init();
	if not self.tls_direct then
		client:onconnect();
		return;
	end
	-- Direct TLS: the handshake starts before the connection is announced
	client:starttls(self.tls_ctx);
end
-- Initialization: arm the connect timeout and register with the poller
-- for both readable and writable events. Returns what add() returns.
function interface:init()
	local connect_timeout = cfg.connect_timeout;
	self:setwritetimeout(connect_timeout);
	return self:add(true, true);
end
-- Stop watching for readability; write interest is left as it is
function interface:pause()
	local want_read = false;
	return self:set(want_read);
end
-- Start watching for readability again; write interest is left as it is
function interface:resume()
	local want_read = true;
	return self:set(want_read);
end
-- Pause connection for some time
-- t: how long to stop reading for, or false to only cancel a pending pause timer
function interface:pausefor(t)
	if self._pausefor then
		self._pausefor:close();
		-- Forget the closed timer so it is never closed a second time
		self._pausefor = nil;
	end
	if t == false then return; end
	-- Logged only for real pauses; `false` is not a valid argument for %f
	self:debug("Pause for %fs", t);
	self:set(false);
	self._pausefor = addtimer(t, function ()
		self._pausefor = nil;
		self:set(true);
		if self.conn and self.conn:dirty() then
			-- Data was already buffered while paused; process it now
			self:onreadable();
		end
	end);
end
-- Set a read rate limit. The reciprocal of Bps is stored, which
-- onreadable() multiplies by the number of bytes read to get a pause time.
-- A value of zero or less removes the limit.
function interface:setlimit(Bps)
	local per_byte = nil;
	if Bps > 0 then
		per_byte = 1/Bps;
	end
	self._limit = per_byte;
end
-- Hold back outgoing data: set the write lock checked by write(),
-- cancel the write timeout and stop watching for writability.
function interface:pause_writes()
	self._write_lock = true;
	self:setwritetimeout(false);
	local want_write = false;
	self:set(nil, want_write);
end
-- Release the write lock set by pause_writes() and, if data was queued
-- in the meantime, arm the write timeout and watch for writability again.
function interface:resume_writes()
	self._write_lock = nil;
	-- Not every interface object has a write buffer (write() creates it
	-- on demand), so check for it the same way close() and starttls() do
	local buffer = self.writebuffer;
	if buffer and buffer[1] then
		self:setwritetimeout();
		self:set(nil, true);
	end
end
-- Connected!
-- Refreshes the cached addresses and fires the "connect" listener callback.
-- It replaces itself with noop first, so the callback runs once per connection.
function interface:onconnect()
	self:updatenames();
	self.onconnect = noop; -- later calls on this object do nothing
	self:on("connect");
end
-- Create a listening socket and register it with the poller.
-- config: optional table; read_size, tls_ctx, tls_direct and sni_hosts are read from it.
-- Returns the server object, or whatever socket.bind returned on failure.
local function listen(addr, port, listeners, config)
	local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
	if not conn then
		return conn, err;
	end
	conn:settimeout(0);
	local server = setmetatable({
		conn = conn,
		created = gettime(),
		listeners = listeners,
		read_size = config and config.read_size,
		onreadable = interface.onacceptable, -- readable means a client is waiting
		tls_ctx = config and config.tls_ctx,
		tls_direct = config and config.tls_direct,
		hosts = config and config.sni_hosts,
		sockname = addr,
		sockport = port,
		log = logger.init(("serv%s"):format(new_id())),
	}, interface_mt);
	server:debug("Server %s created", server);
	server:add(true, false);
	return server;
end
-- COMPAT
-- Positional-argument wrapper around listen(); supplying a tls_ctx turns on direct TLS
local function addserver(addr, port, listeners, read_size, tls_ctx)
	local config = {
		read_size = read_size,
		tls_ctx = tls_ctx,
		tls_direct = tls_ctx and true or false,
	};
	return listen(addr, port, listeners, config);
end
-- COMPAT
-- Wrap an existing socket as a client connection and register it.
-- addr/port are used as peer name and port when the socket does not report any.
-- Returns the client, or a false value and an error if registration fails.
local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
	if not client.peername then
		client.peername, client.peerport = addr, port;
	end
	local ok, err = client:init();
	if not ok then
		return ok, err;
	end
	if tls_ctx then
		client:starttls(tls_ctx);
	end
	return client;
end
-- New outgoing TCP connection
-- typ: name of a socket constructor in the socket module; when omitted it
-- is derived from the address ("tcp6" for 16-byte addresses, else "tcp4").
-- Returns the client object and the raw socket, or a false value and an
-- error. Failure paths release the socket before returning.
local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
	local create;
	if not typ then
		local n = inet_pton(addr);
		if not n then return nil, "invalid-ip"; end
		if #n == 16 then
			typ = "tcp6";
		else
			typ = "tcp4";
		end
	end
	if typ then
		create = socket[typ];
	end
	if type(create) ~= "function" then
		return nil, "invalid socket type";
	end
	local conn, err = create();
	if not conn then return conn, err; end
	local ok;
	ok, err = conn:settimeout(0);
	if not ok then
		conn:close(); -- don't leave the socket for the garbage collector
		return ok, err;
	end
	ok, err = conn:setpeername(addr, port);
	-- "timeout" is not treated as a failure here; the connection attempt
	-- is followed up through the poller after init()
	if not ok and err ~= "timeout" then
		conn:close();
		return ok, err;
	end
	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
	ok, err = client:init();
	if not ok then
		-- init() already armed the connect timeout; destroy() cancels it and
		-- closes the socket, so nothing fires later for a client never returned
		client:destroy();
		return ok, err;
	end
	if tls_ctx then
		client:starttls(tls_ctx);
	end
	client:debug("Client %s created", client);
	return client, conn;
end
-- Watch an arbitrary file descriptor.
-- fd: a number, or an object providing getfd() itself;
-- onreadable / onwritable: handlers, also passed to add() as the read/write
-- interest flags. Returns the watcher object, whose close() unregisters it.
local function watchfd(fd, onreadable, onwritable)
	local function close(self)
		self:del();
	end
	local conn = setmetatable({
		conn = fd,
		onreadable = onreadable,
		onwritable = onwritable,
		close = close,
	}, interface_mt);
	if type(fd) == "number" then
		conn.getfd = function ()
			return fd;
		end;
		-- Otherwise it'll need to be something LuaSocket-compatible
	end
	conn.log = logger.init(("fdwatch%s"):format(new_id()));
	conn:add(onreadable, onwritable);
	return conn;
end;
-- Dump all data from one connection into another
-- Each chunk arriving on `from` pauses it and is written to `to`; `from` is
-- resumed when `to` reports that its buffer drained. Other listener
-- callbacks keep falling through to the previous listeners.
local function link(from, to)
	local function forward(_, data)
		from:pause();
		to:write(data);
	end
	local function drained()
		from:resume();
	end
	from.listeners = setmetatable({ onincoming = forward }, { __index = from.listeners });
	to.listeners = setmetatable({ ondrain = drained }, { __index = to.listeners });
	from:set(true, nil);
	to:set(nil, true);
end
-- COMPAT
-- Intentionally does nothing: net.adns calls this but then replaces :send itself
function interface:set_send(new_send) -- luacheck: ignore 212
	return;
end
-- Call close() on every object currently registered in the FD map
-- (connections and servers alike)
local function closeall()
	for _, conn in pairs(fds) do
		conn:close();
	end
end
807 local quitting = nil;
-- Tell the main loop to shut down, or cancel that request.
-- A true value sets the `quitting` upvalue and closes everything;
-- a false value clears it again.
local function setquitting(quit)
	if not quit then
		quitting = nil;
		return;
	end
	quitting = "quitting";
	closeall();
end
-- Main loop
-- Each round runs due timers, waits for one poller event and dispatches it.
-- once: when true, do a single round only. Otherwise keep going until
-- shutdown was requested and no FDs remain. Returns the `quitting` value.
local function loop(once)
	repeat
		local wait_time = runtimers(cfg.max_wait, cfg.min_wait);
		local fd, r, w = poll:wait(wait_time);
		if fd then
			local conn = fds[fd];
			if not conn then
				log("debug", "Removing unknown fd %d", fd);
				poll:del(fd);
			else
				if r then
					conn:onreadable();
				end
				if w then
					conn:onwritable();
				end
			end
		elseif r ~= "timeout" and r ~= "signal" then
			-- Without an fd, r and w carry the error message and number
			log("debug", "epoll_wait error: %s[%d]", r, w);
		end
	until once or (quitting and next(fds) == nil);
	return quitting;
end
844 return {
845 get_backend = function () return "epoll"; end;
846 addserver = addserver;
847 addclient = addclient;
848 add_task = addtimer;
849 listen = listen;
850 at = at;
851 loop = loop;
852 closeall = closeall;
853 setquitting = setquitting;
854 wrapclient = wrapclient;
855 watchfd = watchfd;
856 link = link;
857 set_config = function (newconfig)
858 cfg = setmetatable(newconfig, default_config);
859 end;
861 -- libevent emulation
862 event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
863 addevent = function (fd, mode, callback)
864 log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead");
865 local function onevent(self)
866 local ret = self:callback();
867 if ret == -1 then
868 self:set(false, false);
869 elseif ret then
870 self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
874 local conn = setmetatable({
875 getfd = function () return fd; end;
876 callback = callback;
877 onreadable = onevent;
878 onwritable = onevent;
879 close = function (self)
880 self:del();
881 fds[fd] = nil;
882 end;
883 }, interface_mt);
884 conn.log = logger.init(("fdwatch%d"):format(conn:getfd()));
885 local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
886 if not ok then return ok, err; end
887 return conn;
888 end;