2 -- Copyright (C) 2016-2018 Kim Alvefur
4 -- This project is MIT/X11 licensed. Please see the
5 -- COPYING file in the source package for more information.
9 local t_insert
= table.insert
;
10 local t_concat
= table.concat
;
11 local setmetatable
= setmetatable
;
16 local logger
= require
"util.logger";
17 local log = logger
.init("server_epoll");
18 local socket
= require
"socket";
19 local luasec
= require
"ssl";
20 local gettime
= require
"util.time".now
;
21 local indexedbheap
= require
"util.indexedbheap";
22 local createtable
= require
"util.table".create
;
23 local inet
= require
"util.net";
24 local inet_pton
= inet
.pton
;
25 local _SOCKETINVALID
= socket
._SOCKETINVALID
or -1;
26 local new_id
= require
"util.id".medium
;
28 local poller
= require
"util.poll"
29 local EEXIST
= poller
.EEXIST
;
30 local ENOENT
= poller
.ENOENT
;
32 local poll
= assert(poller
.new());
-- Built-in defaults for this backend. The table is shaped like a metatable
-- ({ __index = {...} }) so that set_config can install it directly on a
-- user-supplied table with setmetatable(newconfig, default_config).
-- NOTE(review): several comments below are not followed by a setting in this
-- listing (the embedded numbering skips), so those entries are not shown here.
37 local default_config
= { __index
= {
38 -- If a connection is silent for this long, close it unless onreadtimeout says not to
39 read_timeout
= 14 * 60;
41 -- How long to wait for a socket to become writable after queuing data to send
44 -- How long to wait for a socket to become writable after creation
47 -- Some number possibly influencing how many pending connections can be accepted
50 -- If accepting a new incoming connection fails, wait this long before trying again
51 accept_retry_interval
= 10;
53 -- If there is still more data to read from LuaSocket's buffer, wait this long and read again
54 read_retry_delay
= 1e-06;
56 -- Size of chunks to read from sockets
59 -- Timeout used between steps in TLS handshakes
60 ssl_handshake_timeout
= 60;
62 -- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers)
67 -- Whether to kill connections in case of callback errors.
70 -- Attempt writes instantly
71 opportunistic_writes
= false;
-- Active configuration; starts out as the plain table of defaults.
73 local cfg
= default_config
.__index
;
-- Lookup table from file descriptor number to connection object.
75 local fds
= createtable(10, 0); -- FD -> conn
77 -- Timer and scheduling --
-- Heap of pending timers (util.indexedbheap). `at` inserts each timer with
-- its due time as the priority and stores the returned id on the timer.
79 local timers
= indexedbheap
.create();
-- Do-nothing function, used to disable callbacks and methods.
81 local function noop() end
-- Installed as the `close` field of every timer table built by `at`.
-- NOTE(review): the body is not visible in this listing (embedded numbering
-- jumps from 82 to 88), so its behaviour cannot be documented from here.
82 local function closetimer(t
)
-- Installed as the `reschedule` field of every timer table built by `at`.
-- Moves timer `t` to the new absolute `time` by re-prioritizing its heap
-- entry (looked up through `t.id`).
-- NOTE(review): the embedded numbering skips around the call below, so
-- other statements may be missing from this view.
88 local function reschedule(t
, time
)
90 timers
:reprioritize(t
.id
, time
);
-- Schedule callback `f` to run at the absolute time `time`.
-- The timer is a table with [1] = due time, [2] = callback, the `close` and
-- `reschedule` methods, and `id` = the handle returned by the heap insert.
-- NOTE(review): the return statement is not visible in this listing;
-- addtimer returns the result of at(), presumably the timer table — confirm.
94 local function at(time
, f
)
95 local timer
= { time
, f
, close
= closetimer
, reschedule
= reschedule
, id
= nil };
-- Register in the heap, keyed by due time, and remember the heap id.
96 timer
.id
= timers
:insert(timer
, time
);
100 -- Add relative timer
-- `timeout` is added to the current gettime() to form the absolute due time;
-- `f` is the callback. Returns whatever `at` returns.
101 local function addtimer(timeout
, f
)
102 return at(gettime() + timeout
, f
);
105 -- Run callbacks of expired timers
106 -- Return time until next timeout
-- Parameters: `next_delay` and `min_wait`; loop() passes cfg.max_wait and
-- cfg.min_wait and hands the result to poll:wait().
-- NOTE(review): the loop and branch structure around the statements below is
-- not visible in this listing (embedded numbering skips), and neither is the
-- return statement; comments describe only the visible statements.
107 local function runtimers(next_delay
, min_wait
)
108 -- Any timers at all?
109 local now
= gettime();
-- Priority (due time) of the earliest timer in the heap.
110 local peek
= timers
:peek();
114 next_delay
= peek
- now
;
118 local _
, timer
, id
= timers
:pop();
-- Run the callback protected, passing the current time.
119 local ok
, ret
= pcall(timer
[2], now
);
-- A callback that returns a number is re-armed that long after `now`.
120 if ok
and type(ret
) == "number" then
121 local next_time
= now
+ret
;
122 timer
[1] = next_time
;
123 timers
:insert(timer
, next_time
);
126 peek
= timers
:peek();
132 if next_delay
< min_wait
then
138 -- Socket handler interface
-- `interface` holds the methods shared by connection, server and fd-watch
-- objects; `interface_mt` is the metatable that routes lookups to it.
140 local interface
= {};
141 local interface_mt
= { __index
= interface
};
-- String form of a connection object, used when it is formatted with %s.
-- Always includes the FD number; adds peer and local address/port when both
-- are known, or whichever single address/port pair is known.
143 function interface_mt
:__tostring()
144 if self
.sockname
and self
.peername
then
-- Both ends known: peer address/port first, then local address/port.
145 return ("FD %d (%s, %d, %s, %d)"):format(self
:getfd(), self
.peername
, self
.peerport
, self
.sockname
, self
.sockport
);
146 elseif self
.sockname
or self
.peername
then
-- Only one end known: prefer the local name/port, else the peer's.
147 return ("FD %d (%s, %d)"):format(self
:getfd(), self
.sockname
or self
.peername
, self
.sockport
or self
.peerport
);
-- No address information at all.
149 return ("FD %d"):format(self
:getfd());
-- Log a debug-level message through this object's own logger (`self.log`).
-- `msg` is a format string; the varargs are forwarded as its arguments.
153 function interface
:debug(msg
, ...) --luacheck: ignore 212/self
154 self
.log("debug", msg
, ...);
-- Log an error-level message through this object's own logger (`self.log`).
-- `msg` is a format string; the varargs are forwarded as its arguments.
157 function interface
:error(msg
, ...) --luacheck: ignore 212/self
158 self
.log("error", msg
, ...);
161 -- Replace the listener and tell the old one
-- `listeners` becomes the new callback table; `data` is passed to the new
-- table's "attach" callback (onattach).
-- NOTE(review): the statement that notifies the old listener is not visible
-- in this listing (embedded numbering skips 163).
162 function interface
:setlistener(listeners
, data
)
164 self
.listeners
= listeners
;
165 self
:on("attach", data
);
168 -- Call a listener callback
-- `what` is the event name without the "on" prefix; the callback is looked up
-- as listeners["on"..what] and invoked via pcall with `self` and the varargs.
-- NOTE(review): return statements and the else/end structure are not visible
-- in this listing; callers such as the read timeout handler use the return
-- value of self:on(...), so a value is returned somewhere — confirm.
169 function interface
:on(what
, ...)
170 if not self
.listeners
then
171 self
:debug("Interface is missing listener callbacks");
174 local listener
= self
.listeners
["on"..what
];
176 -- self:debug("Missing listener 'on%s'", what); -- uncomment for development and debugging
-- Protected call so a failing callback does not propagate the error.
179 local ok
, err
= pcall(listener
, self
, ...);
-- cfg.fatal_errors selects between the two log messages below; what else
-- happens in the fatal branch is not visible here.
181 if cfg
.fatal_errors
then
182 self
:debug("Closing due to error calling on%s: %s", what
, err
);
185 self
:debug("Error calling on%s: %s", what
, err
);
192 -- Return the file descriptor number
-- Yields self.conn:getfd(), or _SOCKETINVALID as the fallback.
-- NOTE(review): the condition choosing between the two returns is not
-- visible in this listing (presumably whether self.conn is set — confirm).
193 function interface
:getfd()
195 return self
.conn
:getfd();
197 return _SOCKETINVALID
;
-- Return the object stored in `_server` (wrapsocket receives a `server`
-- argument), or this object itself when no `_server` is set.
200 function interface
:server()
201 return self
._server
or self
;
-- Return an address for this connection: the peer name when known,
-- otherwise the local socket name.
205 function interface
:ip()
206 return self
.peername
or self
.sockname
;
209 -- Get a port number, doesn't matter which
-- Prefers the local socket port and falls back to the peer port.
210 function interface
:port()
211 return self
.sockport
or self
.peerport
;
214 -- Get local port number
-- Returns `sockport` as set by updatenames() from getsockname(); nil if unset.
215 function interface
:clientport()
216 return self
.sockport
;
-- Return the local socket port when this object has one.
-- NOTE(review): the body of the `_server` branch is not visible in this
-- listing; presumably it delegates to the owning server — confirm.
220 function interface
:serverport()
221 if self
.sockport
then
222 return self
.sockport
;
223 elseif self
._server
then
228 -- Return underlying socket
-- NOTE(review): the body is not visible in this listing.
229 function interface
:socket()
-- Store `new_mode` as this connection's `read_size`, which onreadable()
-- passes to conn:receive() in preference to cfg.read_size.
233 function interface
:set_mode(new_mode
)
234 self
.read_size
= new_mode
;
-- Forward socket option `k` = `v` to the underlying socket when it offers a
-- `setoption` method; otherwise do nothing visible here.
-- NOTE(review): whether the result of conn:setoption is returned cannot be
-- seen from the statement as listed.
237 function interface
:setoption(k
, v
)
238 -- LuaSec doesn't expose setoption :(
239 if self
.conn
.setoption
then
240 self
.conn
:setoption(k
, v
);
244 -- Timeout for detecting dead or idle sockets
-- `t` is the timeout; nil falls back to cfg.read_timeout. Callers in this
-- file also pass `false` to cancel the timer.
-- NOTE(review): the guard that selects the cancel path and the else/end/return
-- lines are not visible in this listing.
245 function interface
:setreadtimeout(t
)
-- Cancel path: close any existing timer and forget it.
247 if self
._readtimeout
then
248 self
._readtimeout
:close();
249 self
._readtimeout
= nil;
253 t
= t
or cfg
.read_timeout
;
-- Reuse an existing timer by moving its due time...
254 if self
._readtimeout
then
255 self
._readtimeout
:reschedule(gettime() + t
);
-- ...otherwise create a new one.
257 self
._readtimeout
= addtimer(t
, function ()
-- A truthy result from the onreadtimeout listener re-arms the timer:
-- runtimers() re-inserts a timer whose callback returns a number.
258 if self
:on("readtimeout") then
259 return cfg
.read_timeout
;
261 self
:on("disconnect", "read timeout");
268 -- Timeout for detecting dead sockets
-- `t` is the timeout; nil falls back to cfg.send_timeout. Callers in this
-- file also pass `false` to cancel the timer.
-- NOTE(review): the guard that selects the cancel path and the else/end/return
-- lines are not visible in this listing.
269 function interface
:setwritetimeout(t
)
-- Cancel path: close any existing timer and forget it.
271 if self
._writetimeout
then
272 self
._writetimeout
:close();
273 self
._writetimeout
= nil;
277 t
= t
or cfg
.send_timeout
;
-- Reuse an existing timer by moving its due time...
278 if self
._writetimeout
then
279 self
._writetimeout
:reschedule(gettime() + t
);
-- ...otherwise create a new one whose expiry reports a disconnect.
281 self
._writetimeout
= addtimer(t
, function ()
282 self
:on("disconnect", "write timeout");
-- Register this object's FD with the poller.
-- `r` / `w` say whether to watch for readability / writability; nil means
-- "keep the currently stored _wantread / _wantwrite value".
-- Returns nil, "invalid fd" on the invalid-FD path.
-- NOTE(review): the guards, the fds[] bookkeeping and the other return
-- statements are not visible in this listing.
288 function interface
:add(r
, w
)
289 local fd
= self
:getfd();
291 return nil, "invalid fd";
293 if r
== nil then r
= self
._wantread
; end
294 if w
== nil then w
= self
._wantwrite
; end
295 local ok
, err
, errno
= poll
:add(fd
, r
, w
);
-- EEXIST: the poller already knows this FD, so update its flags instead.
297 if errno
== EEXIST
then
298 self
:debug("FD already registered in poller! (EEXIST)");
299 return self
:set(r
, w
); -- So try to change its flags
301 self
:debug("Could not register in poller: %s(%d)", err
, errno
);
-- Remember what is being watched for.
304 self
._wantread
, self
._wantwrite
= r
, w
;
306 self
:debug("Registered in poller");
-- Change which events the poller watches for on this object's FD.
-- `r` / `w` as in add(): nil keeps the stored _wantread / _wantwrite value.
-- Returns nil, "invalid fd" on the invalid-FD path.
-- NOTE(review): the guards and the other return statements are not visible
-- in this listing.
310 function interface
:set(r
, w
)
311 local fd
= self
:getfd();
313 return nil, "invalid fd";
315 if r
== nil then r
= self
._wantread
; end
316 if w
== nil then w
= self
._wantwrite
; end
317 local ok
, err
, errno
= poll
:set(fd
, r
, w
);
319 self
:debug("Could not update poller state: %s(%d)", err
, errno
);
-- Remember the new watch flags.
322 self
._wantread
, self
._wantwrite
= r
, w
;
-- Remove this object's FD from the poller.
-- Returns nil, "invalid fd" or nil, "unregistered fd" on the early-exit paths;
-- the latter when fds[fd] does not refer to this object.
-- NOTE(review): the invalid-FD guard, the fds[] cleanup and the final return
-- are not visible in this listing.
326 function interface
:del()
327 local fd
= self
:getfd();
329 return nil, "invalid fd";
331 if fds
[fd
] ~= self
then
332 return nil, "unregistered fd";
334 local ok
, err
, errno
= poll
:del(fd
);
-- ENOENT (FD unknown to the poller) is not treated as a failure here.
335 if not ok
and errno
~= ENOENT
then
336 self
:debug("Could not unregister: %s(%d)", err
, errno
);
339 self
._wantread
, self
._wantwrite
= nil, nil;
341 self
:debug("Unregistered from poller");
-- Set the desired read/write watch flags, dispatching to add() or set().
-- The first test checks whether neither _wantread nor _wantwrite is
-- currently set.
-- NOTE(review): the inner conditions separating the "no change", add() and
-- set() paths are not visible in this listing.
345 function interface
:setflags(r
, w
)
346 if not(self
._wantread
or self
._wantwrite
) then
348 return true; -- no change
350 return self
:add(r
, w
);
355 return self
:set(r
, w
);
358 -- Called when socket is readable
-- Reads up to self.read_size (or cfg.read_size) from the socket and delivers
-- what was read to the "incoming" listener; read errors other than "timeout"
-- are reported to the "disconnect" listener.
-- NOTE(review): most if/else/end lines and the bodies of the "wantread" and
-- "wantwrite" branches are not visible in this listing; comments describe
-- only the visible statements.
359 function interface
:onreadable()
360 local data
, err
, partial
= self
.conn
:receive(self
.read_size
or cfg
.read_size
);
363 self
:on("incoming", data
);
365 if err
== "wantread" then
368 elseif err
== "wantwrite" then
-- Deliver any partial read together with the error.
372 if partial
and partial
~= "" then
374 self
:on("incoming", partial
, err
);
376 if err
~= "timeout" then
377 self
:on("disconnect", err
);
-- A listener may have closed the connection; nothing more to do then.
382 if not self
.conn
then return; end
-- Rate limiting: cost is self._limit times the number of bytes just read,
-- compared against cfg.min_wait.
383 if self
._limit
and (data
or partial
) then
384 local cost
= self
._limit
* #(data
or partial
);
385 if cost
> cfg
.min_wait
then
386 self
:setreadtimeout(false);
-- Data still buffered in the socket object: cancel the read timeout and
-- pause for cfg.read_retry_delay.
391 if self
._wantread
and self
.conn
:dirty() then
392 self
:setreadtimeout(false);
393 self
:pausefor(cfg
.read_retry_delay
);
395 self
:setreadtimeout();
399 -- Called when socket is writable
-- Concatenates the write buffer and sends it in one call. The visible
-- statements show a full-send path (stop watching for writability, cancel the
-- write timeout, call ondrain) and a partial-send path (keep the unsent tail
-- as buffer[1] and re-arm the write timeout).
-- NOTE(review): the if/else/end lines, loop bodies and the bodies of the
-- "wantwrite"/"wantread" branches are not visible in this listing.
400 function interface
:onwritable()
402 if not self
.conn
then return; end -- could have been closed in onconnect
403 local buffer
= self
.writebuffer
;
404 local data
= t_concat(buffer
);
405 local ok
, err
, partial
= self
.conn
:send(data
);
407 self
:set(nil, false);
408 for i
= #buffer
, 1, -1 do
411 self
:setwritetimeout(false);
412 self
:ondrain(); -- Be aware of writes in ondrain
-- `partial` is used as the count of bytes already sent.
415 buffer
[1] = data
:sub(partial
+1);
416 for i
= #buffer
, 2, -1 do
420 self
:setwritetimeout();
422 if err
== "wantwrite" or err
== "timeout" then
424 elseif err
== "wantread" then
426 elseif err
~= "timeout" then
427 self
:on("disconnect", err
);
432 -- The write buffer has been successfully emptied
-- Forwards to the "drain" listener and returns its result. close() and
-- starttls() replace this method on an instance to defer their work.
433 function interface
:ondrain()
434 return self
:on("drain");
437 -- Add data to write buffer and set flag for wanting to write
-- `data` is appended to the existing buffer, or becomes the sole element of a
-- new one. Unless writes are paused (_write_lock), the write timeout is armed.
-- NOTE(review): the if/else around the buffer handling, the body of the
-- opportunistic_writes branch and the return value are not visible here.
438 function interface
:write(data
)
439 local buffer
= self
.writebuffer
;
441 t_insert(buffer
, data
);
443 self
.writebuffer
= { data
};
445 if not self
._write_lock
then
446 if cfg
.opportunistic_writes
then
450 self
:setwritetimeout();
-- `send` is an alias for `write`.
455 interface
.send
= interface
.write;
457 -- Close, possibly after writing is done
-- With data still buffered: watch for writability only, disable further
-- writes, and re-run close() from ondrain once the buffer has been flushed.
-- Otherwise: disable writes and notify the "disconnect" listener.
-- NOTE(review): the else/end lines and any teardown after the disconnect
-- notification are not visible in this listing.
458 function interface
:close()
459 if self
.writebuffer
and self
.writebuffer
[1] then
460 self
:set(false, true); -- Flush final buffer contents
461 self
.write, self
.send
= noop
, noop
; -- No more writing
462 self
:debug("Close after writing");
-- onwritable() calls self:ondrain() after a complete send.
463 self
.ondrain
= interface
.close
;
465 self
:debug("Closing now");
466 self
.write, self
.send
= noop
, noop
;
468 self
:on("disconnect");
-- Tear down this object: cancel both timeouts and replace the read/write
-- event handlers with no-ops.
-- NOTE(review): further statements are not visible in this listing (embedded
-- numbering skips before 475 and after 478).
473 function interface
:destroy()
475 self
:setwritetimeout(false);
476 self
:setreadtimeout(false);
477 self
.onreadable
= noop
;
478 self
.onwritable
= noop
;
-- NOTE(review): the body is not visible in this listing; judging by the name
-- it presumably reports whether the connection uses TLS — confirm.
486 function interface
:ssl()
-- Begin switching this connection to TLS.
-- `tls_ctx`, when given, replaces the stored TLS context. If data is still
-- buffered, the switch is deferred by re-running starttls from ondrain;
-- otherwise both event handlers are pointed at the handshake routine and the
-- FD is watched for both readability and writability.
-- NOTE(review): the else/end lines and the body of the
-- `ondrain == interface.starttls` branch are not visible in this listing.
490 function interface
:starttls(tls_ctx
)
491 if tls_ctx
then self
.tls_ctx
= tls_ctx
; end
-- Shadow the method on this instance with `false`.
492 self
.starttls
= false;
493 if self
.writebuffer
and self
.writebuffer
[1] then
494 self
:debug("Start TLS after write");
495 self
.ondrain
= interface
.starttls
;
496 self
:set(nil, true); -- make sure wantwrite is set
498 if self
.ondrain
== interface
.starttls
then
501 self
.onwritable
= interface
.tlshandskake
;
502 self
.onreadable
= interface
.tlshandskake
;
503 self
:set(true, true);
504 self
:debug("Prepared to start TLS");
-- Drive the TLS handshake; installed as both onreadable and onwritable while
-- the handshake is in progress.
-- On first use (no self._tls) the socket is wrapped with luasec.wrap using
-- self.tls_ctx. Each call then runs dohandshake() and either finishes
-- (restoring the default handlers and reporting "ssl-handshake-complete"),
-- waits for the direction the handshake asks for, or reports a disconnect.
-- NOTE(review): many if/else/end lines and statements are not visible in this
-- listing, including where self._tls and self.conn are assigned.
508 function interface
:tlshandskake()
509 self
:setwritetimeout(false);
510 self
:setreadtimeout(false);
511 if not self
._tls
then
513 self
:debug("Starting TLS now");
-- pcall guards against luasec.wrap raising an error.
515 local ok
, conn
, err
= pcall(luasec
.wrap
, self
.conn
, self
.tls_ctx
);
-- Shift the pcall results so a raised error ends up in `err`.
517 conn
, err
= ok
, conn
;
518 self
:debug("Failed to initialize TLS: %s", err
);
521 self
:on("disconnect", err
);
-- SNI: pass the explicit server name if set, otherwise the owning server's
-- non-empty `hosts` table.
528 if self
.servername
then
529 conn
:sni(self
.servername
);
530 elseif self
._server
and type(self
._server
.hosts
) == "table" and next(self
._server
.hosts
) ~= nil then
531 conn
:sni(self
._server
.hosts
, true);
536 self
.onwritable
= interface
.tlshandskake
;
537 self
.onreadable
= interface
.tlshandskake
;
540 local ok
, err
= self
.conn
:dohandshake();
542 self
:debug("TLS handshake complete");
-- Clearing the instance fields exposes the default handlers from `interface`.
543 self
.onwritable
= nil;
544 self
.onreadable
= nil;
545 self
:on("status", "ssl-handshake-complete");
546 self
:setwritetimeout();
547 self
:set(true, true);
548 elseif err
== "wantread" then
549 self
:debug("TLS handshake to wait until readable");
550 self
:set(true, false);
551 self
:setreadtimeout(cfg
.ssl_handshake_timeout
);
552 elseif err
== "wantwrite" then
553 self
:debug("TLS handshake to wait until writable");
554 self
:set(false, true);
555 self
:setwritetimeout(cfg
.ssl_handshake_timeout
);
557 self
:debug("TLS handshake error: %s", err
);
558 self
:on("disconnect", err
);
-- Build an interface object around a LuaSocket client socket.
-- `server` (optional) supplies fallback read_size / tls_ctx and tls_direct;
-- `extra` (optional) may carry `servername`. The socket is made non-blocking
-- and each object gets its own logger named "conn<id>".
-- NOTE(review): some constructor fields, the metatable argument, the guard
-- around `extra` and the return statement are not visible in this listing.
563 local function wrapsocket(client
, server
, read_size
, listeners
, tls_ctx
, extra
) -- luasocket object -> interface object
564 client
:settimeout(0);
565 local conn
= setmetatable({
569 listeners
= listeners
;
570 read_size
= read_size
or (server
and server
.read_size
);
572 tls_ctx
= tls_ctx
or (server
and server
.tls_ctx
);
573 tls_direct
= server
and server
.tls_direct
;
574 log = logger
.init(("conn%s"):format(new_id()));
579 if extra
.servername
then
580 conn
.servername
= extra
.servername
;
-- Refresh the cached peer and local address/port from the socket.
-- Both lookups go through pcall, so a socket object without these methods
-- does not raise an error here.
-- NOTE(review): the `if ok` style guards around the assignments are not
-- visible in this listing.
588 function interface
:updatenames()
589 local conn
= self
.conn
;
590 local ok
, peername
, peerport
= pcall(conn
.getpeername
, conn
);
592 self
.peername
, self
.peerport
= peername
, peerport
;
594 local ok
, sockname
, sockport
= pcall(conn
.getsockname
, conn
);
596 self
.sockname
, self
.sockport
= sockname
, sockport
;
600 -- A server interface has new incoming connections waiting
601 -- This replaces the onreadable callback
-- Accepts one connection and wraps it with this server's listeners; when the
-- server is in direct-TLS mode the new client starts TLS with the server's
-- context.
-- NOTE(review): the error guard, the non-TLS branch and any client
-- initialisation calls are not visible in this listing.
602 function interface
:onacceptable()
603 local conn
, err
= self
.conn
:accept();
-- Accept failure path: log and stop accepting for a while.
605 self
:debug("Error accepting new client: %s, server will be paused for %ds", err
, cfg
.accept_retry_interval
);
606 self
:pausefor(cfg
.accept_retry_interval
);
609 local client
= wrapsocket(conn
, self
, nil, self
.listeners
);
610 client
:debug("New connection %s on server %s", client
, self
);
612 if self
.tls_direct
then
613 client
:starttls(self
.tls_ctx
);
-- Start watching a new connection: arm the write timeout with
-- cfg.connect_timeout, then register the FD for both read and write events.
-- Returns the result of add().
620 function interface
:init()
621 self
:setwritetimeout(cfg
.connect_timeout
);
622 return self
:add(true, true);
-- Stop watching for readability; the write flag is left as it is (nil `w`).
-- Returns the result of set().
625 function interface
:pause()
626 return self
:set(false);
-- Start watching for readability again; the write flag is left as it is.
-- Returns the result of set().
629 function interface
:resume()
630 return self
:set(true);
633 -- Pause connection for some time
-- `t` is the delay handed to addtimer; `false` cancels a pending pause timer
-- without starting a new one. Any previous pause timer is closed first.
-- NOTE(review): the statements that stop and restart reading, and the body
-- of the dirty() branch, are not visible in this listing.
634 function interface
:pausefor(t
)
635 self
:debug("Pause for %fs", t
);
636 if self
._pausefor
then
637 self
._pausefor
:close();
639 if t
== false then return; end
641 self
._pausefor
= addtimer(t
, function ()
642 self
._pausefor
= nil;
-- Socket object still holds buffered data after the pause.
644 if self
.conn
and self
.conn
:dirty() then
-- NOTE(review): the body is not visible in this listing. onreadable() reads
-- self._limit for rate limiting, so this presumably derives it from `Bps`
-- (bytes per second, judging by the name) — confirm.
650 function interface
:setlimit(Bps
)
-- Suspend outgoing writes: set the write lock (checked by write()), cancel
-- the write timeout and stop watching for writability.
658 function interface
:pause_writes()
659 self
._write_lock
= true;
660 self
:setwritetimeout(false);
661 self
:set(nil, false);
-- Lift the write lock; when data is waiting in the buffer, re-arm the write
-- timeout.
-- NOTE(review): further statements in the branch are not visible in this
-- listing (presumably re-enabling the write flag — confirm).
664 function interface
:resume_writes()
665 self
._write_lock
= nil;
666 if self
.writebuffer
[1] then
667 self
:setwritetimeout();
-- Connection-established hook. It replaces itself with a no-op on the
-- instance, so it takes effect at most once per connection.
-- NOTE(review): the other statements are not visible in this listing.
673 function interface
:onconnect()
675 self
.onconnect
= noop
;
-- Create a listening server on `addr`:`port`.
-- `listeners` is the callback table given to accepted connections; `config`
-- (optional) may carry read_size, tls_ctx, tls_direct and sni_hosts.
-- Returns nil plus an error when socket.bind fails.
-- NOTE(review): some constructor fields, the metatable argument and the
-- success return are not visible in this listing.
679 local function listen(addr
, port
, listeners
, config
)
680 local conn
, err
= socket
.bind(addr
, port
, cfg
.tcp_backlog
);
681 if not conn
then return conn
, err
; end
683 local server
= setmetatable({
686 listeners
= listeners
;
687 read_size
= config
and config
.read_size
;
-- Readable on a listening socket means "connection waiting to be accepted".
688 onreadable
= interface
.onacceptable
;
689 tls_ctx
= config
and config
.tls_ctx
;
690 tls_direct
= config
and config
.tls_direct
;
691 hosts
= config
and config
.sni_hosts
;
694 log = logger
.init(("serv%s"):format(new_id()));
696 server
:debug("Server %s created", server
);
-- Watch for readability only.
697 server
:add(true, false);
-- Positional-argument wrapper around listen(): packs `read_size` and the TLS
-- settings into a config table. Supplying `tls_ctx` turns on direct TLS.
-- NOTE(review): the remaining config fields are not visible in this listing.
702 local function addserver(addr
, port
, listeners
, read_size
, tls_ctx
)
703 return listen(addr
, port
, listeners
, {
704 read_size
= read_size
;
706 tls_direct
= tls_ctx
and true or false;
-- Wrap an existing socket `conn` as a client connection.
-- `addr` / `port` are used as the peer name/port when the wrapped object has
-- no peername. Returns the failure values from client:init() on error.
-- NOTE(review): the guard around starttls and the success return are not
-- visible in this listing.
711 local function wrapclient(conn
, addr
, port
, listeners
, read_size
, tls_ctx
, extra
)
712 local client
= wrapsocket(conn
, nil, read_size
, listeners
, tls_ctx
, extra
);
713 if not client
.peername
then
714 client
.peername
, client
.peerport
= addr
, port
;
716 local ok
, err
= client
:init();
717 if not ok
then return ok
, err
; end
719 client
:starttls(tls_ctx
);
724 -- New outgoing TCP connection
-- `addr` is checked with inet_pton; `typ` names the constructor looked up in
-- the socket module. Failure returns seen here: nil, "invalid-ip";
-- nil, "invalid socket type"; or the failure values of the socket calls and
-- client:init().
-- NOTE(review): how `typ` is defaulted, the declaration of `create`, the guard
-- around starttls and the success return are not visible in this listing.
725 local function addclient(addr
, port
, listeners
, read_size
, tls_ctx
, typ
, extra
)
728 local n
= inet_pton(addr
);
729 if not n
then return nil, "invalid-ip"; end
737 create
= socket
[typ
];
739 if type(create
) ~= "function" then
740 return nil, "invalid socket type";
742 local conn
, err
= create();
743 if not conn
then return conn
, err
; end
-- Non-blocking mode before connecting.
744 local ok
, err
= conn
:settimeout(0);
745 if not ok
then return ok
, err
; end
746 local ok
, err
= conn
:setpeername(addr
, port
);
-- "timeout" from a non-blocking connect is not treated as a failure.
747 if not ok
and err
~= "timeout" then return ok
, err
; end
748 local client
= wrapsocket(conn
, nil, read_size
, listeners
, tls_ctx
, extra
)
749 local ok
, err
= client
:init();
750 if not ok
then return ok
, err
; end
752 client
:starttls(tls_ctx
);
754 client
:debug("Client %s created", client
);
-- Watch an arbitrary file descriptor for events.
-- `fd` is either a raw FD number or (per the comment below) a
-- LuaSocket-compatible object; `onreadable` / `onwritable` are the event
-- callbacks and also serve as the initial watch flags passed to add().
-- NOTE(review): the bodies of `close` and `getfd`, the non-number branch and
-- the return statement are not visible in this listing.
758 local function watchfd(fd
, onreadable
, onwritable
)
759 local conn
= setmetatable({
761 onreadable
= onreadable
;
762 onwritable
= onwritable
;
763 close
= function (self
)
767 if type(fd
) == "number" then
768 conn
.getfd
= function ()
771 -- Otherwise it'll need to be something LuaSocket-compatible
773 conn
.log = logger
.init(("fdwatch%s"):format(new_id()));
774 conn
:add(onreadable
, onwritable
);
778 -- Dump all data from one connection into another
-- Overrides `onincoming` on `from` and `ondrain` on `to` by layering new
-- listener tables over the existing ones: __index falls back to the previous
-- listeners for every other callback.
-- NOTE(review): the bodies of both callbacks and anything after the second
-- assignment are not visible in this listing.
779 local function link(from
, to
)
780 from
.listeners
= setmetatable({
781 onincoming
= function (_
, data
)
785 }, {__index
=from
.listeners
});
786 to
.listeners
= setmetatable({
787 ondrain
= function ()
790 }, {__index
=to
.listeners
});
796 -- net.adns calls this but then replaces :send so this can be a noop
-- `new_send` is accepted for API compatibility; the luacheck directive
-- marks it as intentionally unused.
797 function interface
:set_send(new_send
) -- luacheck: ignore 212
800 -- Close all connections and servers
-- Iterates over every entry in the FD -> conn table.
-- NOTE(review): the loop body is not visible in this listing.
801 local function closeall()
802 for fd
, conn
in pairs(fds
) do -- luacheck: ignore 213/fd
-- Shutdown flag read by loop(): once set, the loop ends when no FDs remain.
807 local quitting
= nil;
809 -- Signal main loop about shutdown via above upvalue
-- NOTE(review): only the assignment of "quitting" is visible; how the `quit`
-- argument selects between setting and clearing the flag is not shown.
810 local function setquitting(quit
)
812 quitting
= "quitting";
-- Main event loop. Each iteration runs due timers, waits on the poller for at
-- most the delay runtimers() returned, and looks up the connection for the
-- reported FD. Ends after one iteration when `once` is set, or when shutdown
-- was requested and no FDs remain.
-- NOTE(review): the `repeat`, the event dispatch to conn and the return
-- value are not visible in this listing.
820 local function loop(once
)
822 local t
= runtimers(cfg
.max_wait
, cfg
.min_wait
);
-- On failure `r` and `w` carry the error message and number instead
-- (see the log call below).
823 local fd
, r
, w
= poll
:wait(t
);
825 local conn
= fds
[fd
];
834 log("debug", "Removing unknown fd %d", fd
);
-- "timeout" and "signal" are expected outcomes and are not logged.
837 elseif r
~= "timeout" and r
~= "signal" then
838 log("debug", "epoll_wait error: %s[%d]", r
, w
);
840 until once
or (quitting
and next(fds
) == nil);
845 get_backend
= function () return "epoll"; end;
846 addserver
= addserver
;
847 addclient
= addclient
;
853 setquitting
= setquitting
;
854 wrapclient
= wrapclient
;
857 set_config
= function (newconfig
)
858 cfg
= setmetatable(newconfig
, default_config
);
861 -- libevent emulation
862 event
= { EV_READ
= "r", EV_WRITE
= "w", EV_READWRITE
= "rw", EV_LEAVE
= -1 };
863 addevent
= function (fd
, mode
, callback
)
864 log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead");
865 local function onevent(self
)
866 local ret
= self
:callback();
868 self
:set(false, false);
870 self
:set(mode
== "r" or mode
== "rw", mode
== "w" or mode
== "rw");
874 local conn
= setmetatable({
875 getfd
= function () return fd
; end;
877 onreadable
= onevent
;
878 onwritable
= onevent
;
879 close
= function (self
)
884 conn
.log = logger
.init(("fdwatch%d"):format(conn
:getfd()));
885 local ok
, err
= conn
:add(mode
== "r" or mode
== "rw", mode
== "w" or mode
== "rw");
886 if not ok
then return ok
, err
; end