util.x509: Return sets of services per identity
[prosody.git] / net / server_select.lua
blobe15f529896108ab3020473efe41bd1fcb9a315eb
1 --
2 -- server.lua by blastbeat of the luadch project
3 -- Re-used here under the MIT/X Consortium License
4 --
5 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
6 --
8 -- // wrapping luadch stuff // --
-- Look up a value in the global environment by name.
-- Returns whatever is stored under `what` in _G (nil when absent).
local function use( what )
	local value = _G[ what ]
	return value
end
14 local log, table_concat = require ("util.logger").init("socket"), table.concat;
-- Build one log message from a variable number of arguments.
-- Each argument goes through tostring() first: table.concat raises on nil,
-- boolean and table values, and a nil in the middle of `{...}` makes the
-- sequence length unreliable. select("#", ...) gives the true argument count.
local function concat_log_args(...)
	local parts = {}
	for i = 1, select("#", ...) do
		parts[i] = tostring((select(i, ...)))
	end
	return table_concat(parts)
end

-- Log the concatenated arguments at "debug" level.
local out_put = function (...) return log("debug", concat_log_args(...)); end
-- Log the concatenated arguments at "warn" level.
local out_error = function (...) return log("warn", concat_log_args(...)); end
18 ----------------------------------// DECLARATION //--
20 --// constants //--
22 local STAT_UNIT = 1 -- byte
24 --// lua functions //--
26 local type = use "type"
27 local pairs = use "pairs"
28 local ipairs = use "ipairs"
29 local tonumber = use "tonumber"
30 local tostring = use "tostring"
32 --// lua libs //--
34 local table = use "table"
35 local string = use "string"
36 local coroutine = use "coroutine"
38 --// lua lib methods //--
40 local math_min = math.min
41 local math_huge = math.huge
42 local table_concat = table.concat
43 local table_insert = table.insert
44 local string_sub = string.sub
45 local coroutine_wrap = coroutine.wrap
46 local coroutine_yield = coroutine.yield
48 --// extern libs //--
50 local has_luasec, luasec = pcall ( require , "ssl" )
51 local luasocket = use "socket" or require "socket"
52 local luasocket_gettime = luasocket.gettime
53 local inet = require "util.net";
54 local inet_pton = inet.pton;
56 --// extern lib methods //--
58 local ssl_wrap = ( has_luasec and luasec.wrap )
59 local socket_bind = luasocket.bind
60 local socket_select = luasocket.select
62 --// functions //--
64 local id
65 local loop
66 local stats
67 local idfalse
68 local closeall
69 local addsocket
70 local addserver
71 local listen
72 local addtimer
73 local getserver
74 local wrapserver
75 local getsettings
76 local closesocket
77 local removesocket
78 local removeserver
79 local wrapconnection
80 local changesettings
82 --// tables //--
84 local _server
85 local _readlist
86 local _timerlist
87 local _sendlist
88 local _socketlist
89 local _closelist
90 local _readtimes
91 local _writetimes
92 local _fullservers
94 --// simple data types //--
96 local _
97 local _readlistlen
98 local _sendlistlen
99 local _timerlistlen
101 local _sendtraffic
102 local _readtraffic
104 local _selecttimeout
105 local _tcpbacklog
106 local _accepretry
108 local _starttime
109 local _currenttime
111 local _maxsendlen
112 local _maxreadlen
114 local _checkinterval
115 local _sendtimeout
116 local _readtimeout
118 local _maxselectlen
119 local _maxfd
121 local _maxsslhandshake
123 ----------------------------------// DEFINITION //--
-- Initial values for the module-level state and tunable settings.
-- NOTE: listen()/getserver() index _server with "<addr>:<port>" strings, not bare port numbers.
-- NOTE: _readlist and _sendlist are used both as arrays (index -> socket) and as
-- reverse maps (socket -> index); see addsocket()/removesocket().
125 _server = { } -- key = port, value = table; list of listening servers
126 _readlist = { } -- array with sockets to read from
127 _sendlist = { } -- array with sockets to write to
128 _timerlist = { } -- array of timer functions
129 _socketlist = { } -- key = socket, value = wrapped socket (handlers)
130 _readtimes = { } -- key = handler, value = timestamp of last data reading
131 _writetimes = { } -- key = handler, value = timestamp of last data writing/sending
132 _closelist = { } -- handlers to close
133 _fullservers = { } -- servers in a paused state while there are too many clients
135 _readlistlen = 0 -- length of readlist
136 _sendlistlen = 0 -- length of sendlist
137 _timerlistlen = 0 -- length of timerlist
139 _sendtraffic = 0 -- some stats
140 _readtraffic = 0
-- The values below are the defaults reported by getsettings() and overridable via changesettings().
142 _selecttimeout = 1 -- timeout of socket.select
143 _tcpbacklog = 128 -- some kind of hint to the OS
144 _accepretry = 10 -- seconds to wait until the next attempt of a full server to accept
146 _maxsendlen = 51000 * 1024 -- max len of send buffer
147 _maxreadlen = 25000 * 1024 -- max len of read buffer
149 _checkinterval = 30 -- interval in secs to check idle clients
150 _sendtimeout = 60000 -- allowed send idle time in secs
151 _readtimeout = 14 * 60 -- allowed read idle time in secs
-- Platform-dependent limits: new sockets whose fd is >= _maxfd are rejected, and
-- servers stop accepting once a select list reaches _maxselectlen entries.
153 local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows
154 _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
155 _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
157 _maxsslhandshake = 30 -- max handshake round-trips
159 ----------------------------------// PRIVATE //--
159 ----------------------------------// PRIVATE //--
-- Wraps a bound listening socket in a server handler object.
--
-- listeners:   callback table; `onconnect` is called for every accepted
--              connection unless `ssldirect` is set, and the table itself is
--              handed to wrapconnection() for each client.
-- socket:      the listening socket (getfd/accept/close are used here).
-- ip, serverport: bind address and port; used for logging, as the `_server`
--              key, and to re-bind after a hard pause.
-- pattern, sslctx, ssldirect: forwarded unchanged to wrapconnection().
--
-- Returns the handler, or nil, "fd-too-large" when the socket's fd is not
-- below `_maxfd` (the socket is closed in that case).
wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect )

	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd())
		socket:close()
		return nil, "fd-too-large"
	end

	local connections = 0

	local dispatch = listeners.onconnect

	local accept = socket.accept

	--// public methods of the object //--

	local handler = { }

	handler.shutdown = function( ) end

	handler.ssl = function( )
		return sslctx ~= nil
	end
	handler.sslctx = function( )
		return sslctx
	end
	handler.hosts = {} -- sni

	-- Called by a client handler when it closes; frees a slot and lets a
	-- paused server accept again.
	handler.remove = function( )
		connections = connections - 1
		if handler then
			handler.resume( )
		end
	end

	-- Stops listening for good. Safe to call after a hard pause (socket
	-- already closed) and safe to call more than once.
	handler.close = function()
		if not handler then
			return
		end
		if socket then
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_socketlist[ socket ] = nil
			socket = nil
		end
		_server[ip..":"..serverport] = nil;
		-- Drop any pending accept-retry so the main loop never resumes a closed server.
		_fullservers[ handler ] = nil
		handler = nil
		--mem_free( )
		out_put "server.lua: closed server handler and removed sockets from list"
	end

	-- Stops accepting. With a truthy `hard` the listening socket is closed
	-- as well and will be re-bound by resume().
	handler.pause = function( hard )
		if not handler then -- already closed
			return
		end
		if not handler.paused then
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			if hard then
				_socketlist[ socket ] = nil
				socket:close( )
				socket = nil;
			end
			handler.paused = true;
			out_put("server.lua: server [", ip, "]:", serverport, " paused")
		end
	end

	-- Starts accepting again. Returns nil, err (and stays paused) when the
	-- socket has to be re-bound and that fails; returns nothing otherwise.
	handler.resume = function( )
		if not handler or not handler.paused then
			return
		end
		if not socket then
			local newsocket, err = socket_bind( ip, serverport, _tcpbacklog );
			if not newsocket then
				out_error( "server.lua: failed to re-bind server [", ip, "]:", serverport, ": ", tostring( err ) )
				return nil, err
			end
			if newsocket:getfd() >= _maxfd then
				out_error("server.lua: Disallowed FD number: "..newsocket:getfd())
				newsocket:close()
				return nil, "fd-too-large"
			end
			newsocket:settimeout( 0 )
			socket = newsocket
		end
		_readlistlen = addsocket(_readlist, socket, _readlistlen)
		_socketlist[ socket ] = handler
		_fullservers[ handler ] = nil
		handler.paused = false;
		out_put("server.lua: server [", ip, "]:", serverport, " resumed")
	end

	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.socket = function( )
		return socket
	end

	-- Invoked by the main loop when the listening socket is readable:
	-- accepts one client and wraps it.
	handler.readbuffer = function( )
		if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
			handler.pause( )
			_fullservers[ handler ] = _currenttime
			out_put( "server.lua: refused new client connection: server full" )
			return false
		end
		local client, err = accept( socket ) -- try to accept
		if client then
			local peer_ip, clientport = client:getpeername( )
			local conn, _, wrap_err = wrapconnection( handler, listeners, client, peer_ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket
			if wrap_err then -- error while wrapping ssl socket
				return false
			end
			connections = connections + 1
			out_put( "server.lua: accepted new client connection from ", tostring(peer_ip), ":", tostring(clientport), " to ", tostring(serverport))
			if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes
				return dispatch( conn );
			end
			return;
		elseif err then -- maybe timeout or something else
			out_put( "server.lua: error with new client connection: ", tostring(err) )
			handler.pause( )
			_fullservers[ handler ] = _currenttime
			return false
		end
	end
	return handler
end
-- Wraps a connected (client) socket in a handler object and registers it for reading.
--
-- server:      the listening handler that accepted the socket, or nil; when set it is
--              paused if the fd is too large and server.remove() is called on close.
-- listeners:   callback table; onincoming, onstatus, ondisconnect, ondrain,
--              onreadtimeout and ondetach are read here, onconnect after an
--              auto-started TLS handshake completes.
-- socket:      the connected socket.
-- ip, serverport, clientport: stored and returned by the accessor methods.
-- pattern:     read pattern passed to receive().
-- sslctx, ssldirect: when both are set and LuaSec loaded, TLS is started immediately.
-- extra:       stored as handler.extra; extra.servername is copied to handler.servername.
--
-- Returns handler, socket on success, or nil, nil, err on failure.
269 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect, extra ) -- this function wraps a client to a handler object
271 if socket:getfd() >= _maxfd then
272 out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
273 socket:close( ) -- Should we send some kind of error here?
274 if server then
275 _fullservers[ server ] = _currenttime
276 server.pause( )
278 return nil, nil, "fd-too-large"
280 socket:settimeout( 0 )
282 --// local import of socket methods //--
284 local send
285 local receive
286 local shutdown
288 --// private closures of the object //--
290 local ssl
292 local dispatch = listeners.onincoming
293 local status = listeners.onstatus
294 local disconnect = listeners.ondisconnect
295 local drain = listeners.ondrain
296 local onreadtimeout = listeners.onreadtimeout;
297 local detach = listeners.ondetach
299 local bufferqueue = { } -- buffer array
300 local bufferqueuelen = 0 -- end of buffer array
302 local toclose
303 local needtls
305 local bufferlen = 0
307 local noread = false
308 local nosend = false
310 local sendtraffic, readtraffic = 0, 0
312 local maxsendlen = _maxsendlen
313 local maxreadlen = _maxreadlen
315 --// public methods of the object //--
-- The handler IS the outgoing buffer table: integer keys hold queued data, string keys hold methods.
317 local handler = bufferqueue -- saves a table ^_^
319 handler.extra = extra
320 if extra then
321 handler.servername = extra.servername
324 handler.dispatch = function( )
325 return dispatch
327 handler.disconnect = function( )
328 return disconnect
330 handler.onreadtimeout = onreadtimeout;
-- Swap the listener callbacks: the old ondetach (if any) is notified first, then the new onattach.
332 handler.setlistener = function( self, listeners, data )
333 if detach then
334 detach(self) -- Notify listener that it is no longer responsible for this connection
336 dispatch = listeners.onincoming
337 disconnect = listeners.ondisconnect
338 status = listeners.onstatus
339 drain = listeners.ondrain
340 handler.onreadtimeout = listeners.onreadtimeout
341 detach = listeners.ondetach
342 if listeners.onattach then
343 listeners.onattach(self, data)
346 handler.getstats = function( )
347 return readtraffic, sendtraffic
349 handler.ssl = function( )
350 return ssl
352 handler.sslctx = function ( )
353 return sslctx
355 handler.send = function( _, data, i, j )
356 return send( socket, data, i, j )
-- NOTE(review): receive and shutdown take no self parameter, so calling them with ':' would
-- pass the handler as `pattern` -- confirm how callers invoke them.
358 handler.receive = function( pattern, prefix )
359 return receive( socket, pattern, prefix )
361 handler.shutdown = function( pattern )
362 return shutdown( socket, pattern )
364 handler.setoption = function (self, option, value)
365 if socket.setoption then
366 return socket:setoption(option, value);
368 return false, "setoption not implemented";
-- Close immediately, discarding any data still queued for sending.
370 handler.force_close = function ( self, err )
371 if bufferqueuelen ~= 0 then
372 out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
373 bufferqueuelen = 0;
375 return self:close(err);
-- Graceful close: tries to flush queued data first. Returns false (and closes later, once the
-- buffer drains) when data is still pending, true once the connection is fully closed.
377 handler.close = function( self, err )
378 if not handler then return true; end
379 _readlistlen = removesocket( _readlist, socket, _readlistlen )
380 _readtimes[ handler ] = nil
381 if bufferqueuelen ~= 0 then
382 handler.sendbuffer() -- Try now to send any outstanding data
383 if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
384 if handler then
385 handler.write = nil -- ... but no further writing allowed
387 toclose = true
388 return false
391 if socket then
392 _ = shutdown and shutdown( socket )
393 socket:close( )
394 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
395 _socketlist[ socket ] = nil
396 socket = nil
397 else
398 out_put "server.lua: socket already closed"
400 if handler then
401 _writetimes[ handler ] = nil
402 _closelist[ handler ] = nil
403 local _handler = handler;
404 handler = nil
405 if disconnect then
406 disconnect(_handler, err or false);
407 disconnect = nil
410 if server then
411 server.remove( )
413 out_put "server.lua: closed client handler and removed socket from list"
414 return true
416 handler.server = function ( )
417 return server
419 handler.ip = function( )
420 return ip
422 handler.serverport = function( )
423 return serverport
425 handler.clientport = function( )
426 return clientport
428 handler.port = handler.clientport -- COMPAT server_event
-- Queue data for sending. Returns false when the handler is closed or the send buffer limit is
-- exceeded (the handler is then scheduled for closing via _closelist), true otherwise.
-- NOTE(review): bufferlen is increased before the limit check, even when the data is rejected.
429 local write = function( self, data )
430 if not handler then return false end
431 bufferlen = bufferlen + #data
432 if bufferlen > maxsendlen then
433 _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
434 return false
435 elseif not nosend and socket and not _sendlist[ socket ] then
436 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
438 bufferqueuelen = bufferqueuelen + 1
439 bufferqueue[ bufferqueuelen ] = data
440 if handler then
441 _writetimes[ handler ] = _writetimes[ handler ] or _currenttime
443 return true
445 handler.write = write
446 handler.bufferqueue = function( self )
447 return bufferqueue
449 handler.socket = function( self )
450 return socket
452 handler.set_mode = function( self, new )
453 pattern = new or pattern
454 return pattern
456 handler.set_send = function ( self, newsend )
457 send = newsend or send
458 return send
-- Optionally updates the per-connection buffer limits; returns current send buffer length and both limits.
460 handler.bufferlen = function( self, readlen, sendlen )
461 maxsendlen = sendlen or maxsendlen
462 maxreadlen = readlen or maxreadlen
463 return bufferlen, maxreadlen, maxsendlen
465 handler.lock_read = function (self, switch)
466 out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" )
467 if switch == true then
468 return self:pause()
469 elseif switch == false then
470 return self:resume()
472 return noread
-- Stop reading from this connection; noread only becomes true if the socket was actually in the read list.
474 handler.pause = function (self)
475 local tmp = _readlistlen
476 _readlistlen = removesocket( _readlist, socket, _readlistlen )
477 _readtimes[ handler ] = nil
478 if _readlistlen ~= tmp then
479 noread = true
481 return noread;
483 handler.resume = function (self)
484 if noread then
485 noread = false
486 _readlistlen = addsocket(_readlist, socket, _readlistlen)
487 _readtimes[ handler ] = _currenttime
489 return noread;
491 handler.lock = function( self, switch )
492 out_error( "server.lua, lock() is deprecated" )
493 handler.lock_read (self, switch)
494 if switch == true then
495 handler.pause_writes (self)
496 elseif switch == false then
497 handler.resume_writes (self)
499 return noread, nosend
501 handler.pause_writes = function (self)
502 local tmp = _sendlistlen
503 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
504 _writetimes[ handler ] = nil
505 nosend = true
507 handler.resume_writes = function (self)
508 nosend = false
509 if bufferlen > 0 then
510 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
-- Reads once from the socket and passes the data to the onincoming listener.
-- "wantread"/"timeout" are treated as non-fatal; any other error force-closes the connection.
514 local _readbuffer = function( ) -- this function reads data
515 local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
516 if not err or (err == "wantread" or err == "timeout") then -- received something
517 local buffer = buffer or part or ""
518 local len = #buffer
519 if len > maxreadlen then
520 handler:close( "receive buffer exceeded" )
521 return false
523 local count = len * STAT_UNIT
524 readtraffic = readtraffic + count
525 _readtraffic = _readtraffic + count
526 _readtimes[ handler ] = _currenttime
527 --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
528 return dispatch( handler, buffer, err )
529 else -- connections was closed or fatal error
530 out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
531 _ = handler and handler:force_close( err )
532 return false
-- Concatenates the whole queue and sends it in one call. On full success the queue is reset
-- and pending starttls/close requests are carried out; on a partial send the unsent remainder
-- becomes the single queued entry; any other error force-closes the connection.
535 local _sendbuffer = function( ) -- this function sends data
536 local succ, err, byte, buffer, count;
537 if socket then
538 buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
539 succ, err, byte = send( socket, buffer, 1, bufferlen )
540 count = ( succ or byte or 0 ) * STAT_UNIT
541 sendtraffic = sendtraffic + count
542 _sendtraffic = _sendtraffic + count
543 for i = bufferqueuelen,1,-1 do
544 bufferqueue[ i ] = nil
546 --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
547 else
548 succ, err, count = false, "unexpected close", 0;
550 if succ then -- sending successful
551 bufferqueuelen = 0
552 bufferlen = 0
553 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
554 _writetimes[ handler ] = nil
555 if drain then
556 drain(handler)
558 _ = needtls and handler:starttls(nil)
559 _ = toclose and handler:force_close( )
560 return true
561 elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
562 buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
563 bufferqueue[ 1 ] = buffer -- insert new buffer in queue
564 bufferqueuelen = 1
565 bufferlen = bufferlen - byte
566 _writetimes[ handler ] = _currenttime
567 return true
568 else -- connection was closed during sending or fatal error
569 out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
570 _ = handler and handler:force_close( err )
571 return false
575 -- Set the sslctx
-- Also (re)creates the handshake coroutine. While a handshake is in progress that coroutine is
-- installed as both readbuffer and sendbuffer, so every select wake-up resumes it; it gives up
-- after _maxsslhandshake rounds.
576 local handshake;
577 function handler.set_sslctx(self, new_sslctx)
578 sslctx = new_sslctx;
579 local read, wrote
580 handshake = coroutine_wrap( function( client ) -- create handshake coroutine
581 local err
582 for _ = 1, _maxsslhandshake do
583 _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
584 _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
585 read, wrote = nil, nil
586 _, err = client:dohandshake( )
587 if not err then
588 out_put( "server.lua: ssl handshake done" )
589 handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
590 handler.sendbuffer = _sendbuffer
591 _ = status and status( handler, "ssl-handshake-complete" )
592 if self.autostart_ssl and listeners.onconnect then
593 listeners.onconnect(self);
594 if bufferqueuelen ~= 0 then
595 _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
598 _readlistlen = addsocket(_readlist, client, _readlistlen)
599 return true
600 else
601 if err == "wantwrite" then
602 _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
603 wrote = true
604 elseif err == "wantread" then
605 _readlistlen = addsocket(_readlist, client, _readlistlen)
606 read = true
607 else
608 break;
610 err = nil;
611 coroutine_yield( ) -- handshake not finished
614 err = "ssl handshake error: " .. ( err or "handshake too long" );
615 out_put( "server.lua: ", err );
616 _ = handler and handler:force_close(err)
617 return false, err -- handshake failed
-- starttls is only offered when LuaSec loaded. It is postponed (needtls) while unsent data is
-- queued, and removes itself from the handler once TLS has been started.
621 if has_luasec then
622 handler.starttls = function( self, _sslctx)
623 if _sslctx then
624 handler:set_sslctx(_sslctx);
626 if bufferqueuelen > 0 then
627 out_put "server.lua: we need to do tls, but delaying until send buffer empty"
628 needtls = true
629 return
631 out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
632 local oldsocket, err = socket
633 socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
635 if not socket then
636 out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
637 return nil, err -- fatal error
-- NOTE(review): nothing in this function assigns `self._server` (the server is exposed through
-- handler.server()), so the hosts-table branch below looks unreachable -- confirm.
640 if socket.sni then
641 if self.servername then
642 socket:sni(self.servername);
643 elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
644 socket:sni(self.server().hosts, true);
648 socket:settimeout( 0 )
650 -- add the new socket to our system
651 send = socket.send
652 receive = socket.receive
653 shutdown = id
654 _socketlist[ socket ] = handler
655 _readlistlen = addsocket(_readlist, socket, _readlistlen)
657 -- remove traces of the old socket
658 _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
659 _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
660 _socketlist[ oldsocket ] = nil
662 handler.starttls = nil
663 needtls = nil
665 -- Secure now (if handshake fails connection will close)
666 ssl = true
668 handler.readbuffer = handshake
669 handler.sendbuffer = handshake
670 return handshake( socket ) -- do handshake
-- Default wiring for a plain connection: regular read/send functions, registered for reading.
674 handler.readbuffer = _readbuffer
675 handler.sendbuffer = _sendbuffer
676 send = socket.send
677 receive = socket.receive
678 shutdown = ( ssl and id ) or socket.shutdown
680 _socketlist[ socket ] = handler
681 _readlistlen = addsocket(_readlist, socket, _readlistlen)
683 if sslctx and ssldirect and has_luasec then
684 out_put "server.lua: auto-starting ssl negotiation..."
685 handler.autostart_ssl = true;
686 local ok, err = handler:starttls(sslctx);
687 if ok == false then
688 return nil, nil, err
692 return handler, socket
-- No-op callback: accepts anything, returns nothing.
id = function( )
	return
end

-- Callback that always returns false.
idfalse = function( )
	local result = false
	return result
end
-- Append `socket` to `list` unless it is already a member.
-- `list` stores each socket twice: list[index] = socket and list[socket] = index.
-- `len` is the current number of sockets; the (possibly incremented) count is returned.
addsocket = function( list, socket, len )
	if list[ socket ] then
		return len;
	end
	local slot = len + 1
	list[ slot ] = socket
	list[ socket ] = slot
	return slot;
end
-- Remove `socket` from `list` ( technique copied from copas ).
-- The last array entry is moved into the freed slot so the array part stays
-- dense. Returns the new length, or `len` unchanged when `socket` is not a member.
removesocket = function( list, socket, len )
	local index = list[ socket ]
	if not index then
		return len
	end
	list[ socket ] = nil
	local tail = list[ len ]
	list[ len ] = nil
	if tail ~= socket then
		-- relocate the former last element into the hole
		list[ tail ] = index
		list[ index ] = tail
	end
	return len - 1
end
-- Forget a socket that has no handler: unregister it from the handler map
-- and from both select lists, then close it.
closesocket = function( socket )
	_socketlist[ socket ] = nil
	_readlistlen = removesocket( _readlist, socket, _readlistlen )
	_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
	socket:close( )
	--mem_free( )
end
-- Pipe everything read from `sender` into `receiver` with simple flow control:
-- once the receiver's send buffer holds `buffersize` bytes or more the sender
-- stops being read, and reading resumes when the buffer drops below that size.
--
-- Both handlers get their readbuffer/sendbuffer wrapped in place. The sender is
-- throttled with pause()/resume() rather than lock_read(), which this file
-- marks as deprecated and which logs a warning on every call.
local function link(sender, receiver, buffersize)
	local sender_locked;
	local _sendbuffer = receiver.sendbuffer;
	function receiver.sendbuffer()
		_sendbuffer();
		if sender_locked and receiver.bufferlen() < buffersize then
			sender:resume(); -- Unlock now
			sender_locked = nil;
		end
	end

	local _readbuffer = sender.readbuffer;
	function sender.readbuffer()
		_readbuffer();
		if not sender_locked and receiver.bufferlen() >= buffersize then
			sender_locked = true;
			sender:pause();
		end
	end
	-- Read whatever is available instead of waiting for a specific pattern.
	sender:set_mode("*a");
end
756 ----------------------------------// PUBLIC //--
-- Start listening on addr:port.
--
-- addr:      bind address, defaults to "*".
-- port:      number in 0..65535.
-- listeners: callback table handed to wrapserver().
-- config:    optional table; tls_ctx, tls_direct and read_size are read from it.
--
-- Returns the server handler, or nil plus an error message.
listen = function ( addr, port, listeners, config )
	addr = addr or "*"
	config = config or {}
	local sslctx, ssldirect, pattern = config.tls_ctx, config.tls_direct, config.read_size

	-- Log a failure for this address and hand it back to the caller.
	local function reject( reason )
		out_error( "server.lua, [", addr, "]:", port, ": ", reason )
		return nil, reason
	end

	-- Argument validation; the first failing check decides the message.
	if type( listeners ) ~= "table" then
		return reject( "invalid listener table" )
	end
	if type ( addr ) ~= "string" then
		return reject( "invalid address" )
	end
	if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		return reject( "invalid port" )
	end
	if _server[ addr..":"..port ] then
		return reject( "listeners on '[" .. addr .. "]:" .. port .. "' already exist" )
	end
	if sslctx and not has_luasec then
		return reject( "luasec not found" )
	end

	local server, bind_err = socket_bind( addr, port, _tcpbacklog )
	if bind_err then
		return reject( bind_err )
	end

	local handler, wrap_err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket
	if not handler then
		server:close( )
		return nil, wrap_err
	end

	-- Register the new listener with the event loop.
	server:settimeout( 0 )
	_readlistlen = addsocket(_readlist, server, _readlistlen)
	_server[ addr..":"..port ] = handler
	_socketlist[ server ] = handler
	out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
	return handler
end
-- Legacy entry point for registering a server: translates the positional
-- `pattern` / `sslctx` arguments into a listen() config table. TLS starts
-- immediately on accepted connections whenever an sslctx is supplied.
addserver = function( addr, port, listeners, pattern, sslctx )
	local config = { }
	config.read_size = pattern
	config.tls_ctx = sslctx
	config.tls_direct = not not sslctx
	return listen( addr, port, listeners, config )
end
-- Return the server handler registered for addr:port, or nil if there is none.
getserver = function ( addr, port )
	local key = addr..":"..port
	return _server[ key ];
end
-- Close and unregister the server listening on addr:port.
-- Returns true on success, or nil plus a message when no such server exists.
removeserver = function( addr, port )
	local key = addr..":"..port
	local handler = _server[ key ]
	if handler then
		handler:close( )
		_server[ key ] = nil
		return true
	end
	return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
end
-- Close every registered handler and reset the module's socket, server and
-- timer bookkeeping to an empty state.
closeall = function( )
	-- Snapshot first: closing a handler can touch _socketlist (for example a
	-- paused server being resumed when one of its clients closes), and adding
	-- keys to a table while pairs() is traversing it is undefined behaviour.
	local handlers, count = { }, 0
	for _, handler in pairs( _socketlist ) do
		count = count + 1
		handlers[ count ] = handler
	end
	for i = 1, count do
		handlers[ i ]:close( )
	end
	_readlistlen = 0
	_sendlistlen = 0
	_timerlistlen = 0
	_server = { }
	_readlist = { }
	_sendlist = { }
	_timerlist = { }
	-- The original table was given weak keys at start-up; the replacement
	-- must be weak-keyed too, otherwise sockets stay pinned after a reset.
	_socketlist = setmetatable( { }, { __mode = "k" } )
	--mem_free( )
end
-- Return a fresh table describing the current tunable settings, using the
-- same key names that changesettings() accepts.
getsettings = function( )
	local settings = { }
	settings.select_timeout = _selecttimeout
	settings.tcp_backlog = _tcpbacklog
	settings.max_send_buffer_size = _maxsendlen
	settings.max_receive_buffer_size = _maxreadlen
	settings.select_idle_check_interval = _checkinterval
	settings.send_timeout = _sendtimeout
	settings.read_timeout = _readtimeout
	settings.max_connections = _maxselectlen
	settings.max_ssl_handshake_roundtrips = _maxsslhandshake
	settings.highest_allowed_fd = _maxfd
	settings.accept_retry_interval = _accepretry
	return settings
end
-- Update tunable settings from the table `new` (same keys as getsettings()).
-- Keys that are missing or not convertible to a number leave the current
-- value untouched. Returns true, or nil plus a message when `new` is not a
-- table.
changesettings = function( new )
	if type( new ) ~= "table" then
		return nil, "invalid settings table"
	end
	_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
	_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
	_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
	_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
	_tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
	_sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
	_readtimeout = tonumber( new.read_timeout ) or _readtimeout
	_accepretry = tonumber( new.accept_retry_interval ) or _accepretry
	-- These three are compared numerically against fds, list lengths and loop
	-- counters, so they get the same tonumber() treatment as the rest; a
	-- string from a config file would otherwise raise at comparison time.
	_maxselectlen = tonumber( new.max_connections ) or _maxselectlen
	_maxsslhandshake = tonumber( new.max_ssl_handshake_roundtrips ) or _maxsslhandshake
	_maxfd = tonumber( new.highest_allowed_fd ) or _maxfd
	return true
end
-- Register a function that the main loop calls on every iteration.
-- Returns true, or nil plus a message when `listener` is not a function.
addtimer = function( listener )
	if type( listener ) == "function" then
		local slot = _timerlistlen + 1
		_timerlist[ slot ] = listener
		_timerlistlen = slot
		return true
	end
	return nil, "invalid listener function"
end
-- One-shot / repeating task scheduler built on top of addtimer().
-- `data` holds the scheduled {due_time, callback} pairs; `new_data` collects tasks added since
-- the timer last ran, so scheduling from inside a callback never modifies `data` mid-traversal.
879 local add_task do
880 local data = {};
881 local new_data = {};
-- Schedule callback(current_time) to run `delay` seconds from now. A callback that returns a
-- number is scheduled again after that many seconds.
-- If the computed due time lies in the past (negative delay) the callback runs immediately.
883 function add_task(delay, callback)
884 local current_time = luasocket_gettime();
885 delay = delay + current_time;
886 if delay >= current_time then
887 table_insert(new_data, {delay, callback});
888 else
889 local r = callback(current_time);
890 if r and type(r) == "number" then
891 return add_task(r, callback);
-- Timer driven by the main loop: merges newly added tasks, runs every task that is due and
-- returns the number of seconds until the next one (math_huge when nothing is pending).
896 addtimer(function(current_time)
897 if #new_data > 0 then
898 for _, d in pairs(new_data) do
899 table_insert(data, d);
901 new_data = {};
904 local next_time = math_huge;
905 for i, d in pairs(data) do
906 local t, callback = d[1], d[2];
907 if t <= current_time then
-- remove before calling, so the task is gone even if the callback reschedules itself
908 data[i] = nil;
909 local r = callback(current_time);
910 if type(r) == "number" then
911 add_task(r, callback);
912 next_time = math_min(next_time, r);
914 else
915 next_time = math_min(next_time, t - current_time);
918 return next_time;
919 end);
-- Report global counters: bytes read, bytes sent, and the current lengths
-- of the read list, send list and timer list (in that order).
stats = function( )
	local bytes_read, bytes_sent = _readtraffic, _sendtraffic
	return bytes_read, bytes_sent, _readlistlen, _sendlistlen, _timerlistlen
end
926 local quitting;
-- Set the flag the main loop checks after every iteration; any truthy value
-- makes loop() leave its repeat loop.
local setquitting = function( quit )
	quitting = quit;
end
-- Main event loop.
--
-- Each iteration fires the registered timers, waits in select() (at most
-- `_selecttimeout` seconds, less if a timer asks for it), services readable
-- and writable sockets, closes handlers queued in `_closelist`, enforces the
-- send/read idle timeouts every `_checkinterval` seconds and resumes full
-- servers once `_accepretry` seconds have passed.
--
-- once: when truthy, run exactly one iteration and return nothing.
-- Returns "quitting" when the loop was stopped through the quitting flag
-- (after calling closeall()).
loop = function(once) -- this is the main loop of the program
	if quitting then return "quitting"; end
	if once then quitting = "once"; end

	-- Tell the current disconnect listener (if any) why a handler is going
	-- away. A listener table without `ondisconnect` is legal, so the lookup
	-- may yield nil and must not be called blindly.
	local function notify_disconnect( handler, reason )
		local ondisconnect = handler.disconnect( )
		if ondisconnect then
			ondisconnect( handler, reason )
		end
	end

	_currenttime = luasocket_gettime( )
	repeat
		-- Fire timers
		local next_timer_time = math_huge;
		for i = 1, _timerlistlen do
			local t = _timerlist[ i ]( _currenttime ) -- fire timers
			-- Only numeric results are timeout hints; anything else would break math_min.
			if type( t ) == "number" then next_timer_time = math_min(next_timer_time, t); end
		end

		local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
		for _, socket in ipairs( read ) do -- receive data
			local handler = _socketlist[ socket ]
			if handler then
				handler.readbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
			end
		end
		for _, socket in ipairs( write ) do -- send data waiting in writequeues
			local handler = _socketlist[ socket ]
			if handler then
				handler.sendbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
			end
		end
		for handler, err in pairs( _closelist ) do
			notify_disconnect( handler, err )
			handler:force_close() -- forced disconnect
			_closelist[ handler ] = nil;
		end
		_currenttime = luasocket_gettime( )

		-- Check for socket timeouts
		if _currenttime - _starttime > _checkinterval then
			_starttime = _currenttime
			for handler, timestamp in pairs( _writetimes ) do
				if _currenttime - timestamp > _sendtimeout then
					notify_disconnect( handler, "send timeout" )
					handler:force_close() -- forced disconnect
				end
			end
			for handler, timestamp in pairs( _readtimes ) do
				if _currenttime - timestamp > _readtimeout then
					-- The listener may veto the timeout by returning true.
					if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
						notify_disconnect( handler, "read timeout" )
						handler:close( ) -- forced disconnect?
					else
						_readtimes[ handler ] = _currenttime -- reset timer
					end
				end
			end
		end

		-- Give servers that were paused for being full another chance to accept.
		for server, paused_time in pairs( _fullservers ) do
			if _currenttime - paused_time > _accepretry then
				_fullservers[ server ] = nil;
				server.resume();
			end
		end
	until quitting;
	if quitting == "once" then quitting = nil; return; end
	closeall();
	return "quitting"
end
-- Run a single iteration of the main loop.
local step = function( )
	return loop( true );
end
-- Name of the event backend implemented by this module.
local get_backend = function( )
	return "select";
end
1011 --// EXPERIMENTAL //--
-- Wrap an outgoing client socket in a connection handler.
--
-- The handler is built by wrapconnection() with no owning server, the literal
-- string "clientport" as client port, and `sslctx` used both as TLS context
-- and as the "start TLS immediately" flag. For plain connections the socket
-- is put on both select lists and, if the listeners define onconnect, that
-- callback fires the first time the socket becomes writable.
--
-- Returns handler, socket or nil, err.
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra )
	local handler, wrapped, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx, extra)
	if not handler then
		return nil, err
	end
	_socketlist[ wrapped ] = handler
	if sslctx then
		return handler, wrapped
	end
	_readlistlen = addsocket(_readlist, wrapped, _readlistlen)
	_sendlistlen = addsocket(_sendlist, wrapped, _sendlistlen)
	if listeners.onconnect then
		-- When socket is writeable, call onconnect
		local real_sendbuffer = handler.sendbuffer;
		handler.sendbuffer = function ()
			handler.sendbuffer = real_sendbuffer; -- one-shot: restore the real sender first
			listeners.onconnect(handler);
			return real_sendbuffer(); -- Send any queued outgoing data
		end
	end
	return handler, wrapped
end
-- Open an outgoing, non-blocking TCP connection and wrap it in a handler.
--
-- address:   IP address string (not a hostname: when `typ` is omitted the
--            socket family is derived from inet_pton(address)).
-- port:      number in 0..65535.
-- listeners: callback table for the new connection.
-- pattern, sslctx, extra: passed on to wrapclient().
-- typ:       optional name of the LuaSocket constructor to use ("tcp4", "tcp6", ...).
--
-- Returns handler, socket on success, or nil plus an error message.
local addclient = function( address, port, listeners, pattern, sslctx, typ, extra )
	local err
	if type( listeners ) ~= "table" then
		err = "invalid listener table"
	elseif type ( address ) ~= "string" then
		err = "invalid address"
	elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		err = "invalid port"
	elseif sslctx and not has_luasec then
		err = "luasec not found"
	end
	-- Report argument errors before touching `address`: inet_pton() must not
	-- be handed a non-string, and the more specific message should not be
	-- replaced by "invalid socket type" further down.
	if err then
		out_error( "server.lua, addclient: ", err )
		return nil, err
	end

	if not typ then
		local n = inet_pton(address);
		if not n then return nil, "invalid-ip"; end
		if #n == 16 then
			typ = "tcp6";
		elseif #n == 4 then
			typ = "tcp4";
		end
	end
	local create = luasocket[typ];
	if type( create ) ~= "function" then
		err = "invalid socket type"
		out_error( "server.lua, addclient: ", err )
		return nil, err
	end

	local client, create_err = create( )
	if create_err then
		return nil, create_err
	end
	client:settimeout( 0 )
	local ok, connect_err = client:setpeername( address, port )
	-- A non-blocking connect normally reports "timeout" while still in progress.
	if ok or connect_err == "timeout" or connect_err == "Operation already in progress" then
		return wrapclient( client, address, port, listeners, pattern, sslctx, extra )
	end
	-- The connection attempt failed outright: release the socket right away
	-- instead of leaving the descriptor to the garbage collector.
	client:close( )
	return nil, connect_err
end
-- close() method of watchfd handlers: stop watching handler.conn.
-- Only the bookkeeping is undone; the watched object itself is not closed here.
local closewatcher = function (handler)
	local watched = handler.conn;
	_socketlist[ watched ] = nil
	_readlistlen = removesocket( _readlist, watched, _readlistlen )
	_sendlistlen = removesocket( _sendlist, watched, _sendlistlen )
end;
-- setflags() method of watchfd handlers: choose which events are watched.
--
-- read / send: true  -> start watching handler.conn for that event,
--              false -> stop watching it,
--              nil   -> leave that direction unchanged.
-- The handler is (re)registered in _socketlist on every call.
local addremove = function (handler, read, send)
	local socket = handler.conn
	_socketlist[ socket ] = handler
	if read ~= nil then
		if read then
			_readlistlen = addsocket( _readlist, socket, _readlistlen )
		else
			-- Turning reads off must touch the READ list (the two removal
			-- branches used to be swapped).
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
		end
	end
	if send ~= nil then
		if send then
			_sendlistlen = addsocket( _sendlist, socket, _sendlistlen )
		else
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
		end
	end
end
-- Watch an arbitrary file descriptor (or socket-like object) with the event loop.
--
-- fd:          a number, which is wrapped in a table exposing getfd(), or an
--              object that is used as-is.
-- onreadable:  called by the loop when fd is readable (optional).
-- onwriteable: called by the loop when fd is writable (optional).
--
-- Returns a handler with `conn`, `close()` and `setflags(read, send)`.
local watchfd = function ( fd, onreadable, onwriteable )
	local conn
	if type(fd) == "number" then
		conn = { getfd = function () return fd; end }
	else
		conn = fd
	end
	local handler = { conn = conn }
	handler.readbuffer = onreadable or id
	handler.sendbuffer = onwriteable or id
	handler.close = closewatcher
	handler.setflags = addremove
	-- Only directions that actually have a callback are registered.
	addremove( handler, onreadable, onwriteable )
	return handler
end
1118 ----------------------------------// BEGIN //--
-- Module start-up: give the bookkeeping tables weak keys (so an entry does not by itself keep
-- its socket/handler key alive) and record the reference time for the periodic idle check in loop().
1120 use "setmetatable" ( _socketlist, { __mode = "k" } )
1121 use "setmetatable" ( _readtimes, { __mode = "k" } )
1122 use "setmetatable" ( _writetimes, { __mode = "k" } )
1124 _starttime = luasocket_gettime( )
-- Replace the logging function used by this module.
-- A nil/false argument leaves the logger unchanged. Returns the logger that
-- was active before the call.
local setlogger = function( new_logger )
	local previous = log;
	log = new_logger or previous;
	return previous;
end
1134 ----------------------------------// PUBLIC INTERFACE //--
1136 return {
1137 _addtimer = addtimer,
1138 add_task = add_task;
1140 addclient = addclient,
1141 wrapclient = wrapclient,
1142 watchfd = watchfd,
1144 loop = loop,
1145 link = link,
1146 step = step,
1147 stats = stats,
1148 closeall = closeall,
1149 addserver = addserver,
1150 listen = listen,
1151 getserver = getserver,
1152 setlogger = setlogger,
1153 getsettings = getsettings,
1154 setquitting = setquitting,
1155 removeserver = removeserver,
1156 get_backend = get_backend,
1157 changesettings = changesettings,