Update README.md
[KisSync.git] / src / io / ioserver.js
blobf23a932a897c206a2ee18982220b96f9d8ad95bf
1 import sio from 'socket.io';
2 import db from '../database';
3 import User from '../user';
4 import Server from '../server';
5 import Config from '../config';
6 const cookieParser = require("cookie-parser")(Config.get("http.cookie-secret"));
7 import typecheck from 'json-typecheck';
8 import { isTorExit } from '../tor';
9 import session from '../session';
10 import { verifyIPSessionCookie } from '../web/middleware/ipsessioncookie';
11 import Promise from 'bluebird';
12 const verifySession = Promise.promisify(session.verifySession);
13 const getAliases = Promise.promisify(db.getAliases);
14 import { CachingGlobalBanlist } from './globalban';
15 import proxyaddr from 'proxy-addr';
16 import { Counter, Gauge } from 'prom-client';
17 import { TokenBucket } from '../util/token-bucket';
18 import http from 'http';
20 const LOGGER = require('@calzoneman/jsli')('ioserver');
22 const rateLimitExceeded = new Counter({
23     name: 'cytube_socketio_rate_limited_total',
24     help: 'Number of socket.io connections rejected due to exceeding rate limit'
25 });
26 const connLimitExceeded = new Counter({
27     name: 'cytube_socketio_conn_limited_total',
28     help: 'Number of socket.io connections rejected due to exceeding conn limit'
29 });
30 const authFailureCount = new Counter({
31     name: 'cytube_socketio_auth_error_total',
32     help: 'Number of failed authentications from session middleware'
33 });
35 class IOServer {
36     constructor(options = {
37         proxyTrustFn: proxyaddr.compile('127.0.0.1')
38     }) {
39         ({
40             proxyTrustFn: this.proxyTrustFn
41         } = options);
43         this.ipThrottle = new Map();
44         this.ipCount = new Map();
45     }
47     // Map proxied sockets to the real IP address via X-Forwarded-For
48     // If the resulting address is a known Tor exit, flag it as such
49     ipProxyMiddleware(socket, next) {
50         if (!socket.context) socket.context = {};
52         try {
53             socket.handshake.connection = {
54                 remoteAddress: socket.handshake.address
55             };
57             socket.context.ipAddress = proxyaddr(
58                 socket.handshake,
59                 this.proxyTrustFn
60             );
62             if (!socket.context.ipAddress) {
63                 throw new Error(
64                     `Assertion failed: unexpected IP ${socket.context.ipAddress}`
65                 );
66             }
67         } catch (error) {
68             LOGGER.warn('Rejecting socket - proxyaddr failed: %s', error);
69             next(new Error('Could not determine IP address'));
70             return;
71         }
73         if (isTorExit(socket.context.ipAddress)) {
74             socket.context.torConnection = true;
75         }
77         next();
78     }
80     // Reject global banned IP addresses
81     ipBanMiddleware(socket, next) {
82         if (isIPGlobalBanned(socket.context.ipAddress)) {
83             LOGGER.info('Rejecting %s - banned',
84                     socket.context.ipAddress);
85             next(new Error('You are banned from the server'));
86             return;
87         }
89         next();
90     }
92     // Rate limit connection attempts by IP address
93     ipThrottleMiddleware(socket, next) {
94         if (!this.ipThrottle.has(socket.context.ipAddress)) {
95             this.ipThrottle.set(socket.context.ipAddress, new TokenBucket(5, 0.1));
96         }
98         const bucket = this.ipThrottle.get(socket.context.ipAddress);
99         if (bucket.throttle()) {
100             rateLimitExceeded.inc(1);
101             LOGGER.info('Rejecting %s - exceeded connection rate limit',
102                     socket.context.ipAddress);
103             next(new Error('Rate limit exceeded'));
104             return;
105         }
107         next();
108     }
110     checkIPLimit(socket) {
111         const ip = socket.context.ipAddress;
112         const count = this.ipCount.get(ip) || 0;
113         if (count >= Config.get('io.ip-connection-limit')) {
114             connLimitExceeded.inc(1);
115             LOGGER.info(
116                 'Rejecting %s - exceeded connection count limit',
117                 ip
118             );
119             socket.emit('kick', {
120                 reason: 'Too many connections from your IP address'
121             });
122             socket.disconnect(true);
123             return false;
124         }
126         this.ipCount.set(ip, count + 1);
127         socket.once('disconnect', () => {
128             const newCount = (this.ipCount.get(ip) || 1) - 1;
130             if (newCount === 0) {
131                 this.ipCount.delete(ip);
132             } else {
133                 this.ipCount.set(ip, newCount);
134             }
135         });
137         return true;
138     }
140     // Parse cookies
141     cookieParsingMiddleware(socket, next) {
142         const req = socket.handshake;
143         if (req.headers.cookie) {
144             cookieParser(req, null, () => next());
145         } else {
146             req.cookies = {};
147             req.signedCookies = {};
148             next();
149         }
150     }
152     // Determine session age from ip-session cookie
153     // (Used for restricting chat)
154     ipSessionCookieMiddleware(socket, next) {
155         const cookie = socket.handshake.signedCookies['ip-session'];
156         if (!cookie) {
157             socket.context.ipSessionFirstSeen = new Date();
158             next();
159             return;
160         }
162         const sessionMatch = verifyIPSessionCookie(socket.context.ipAddress, cookie);
163         if (sessionMatch) {
164             socket.context.ipSessionFirstSeen = sessionMatch.date;
165         } else {
166             socket.context.ipSessionFirstSeen = new Date();
167         }
168         next();
169     }
171     // Match login cookie against the DB, look up aliases
172     authUserMiddleware(socket, next) {
173         socket.context.aliases = [];
175         const promises = [];
176         const auth = socket.handshake.signedCookies.auth;
177         if (auth) {
178             promises.push(verifySession(auth).then(user => {
179                 socket.context.user = Object.assign({}, user);
180             }).catch(_error => {
181                 authFailureCount.inc(1);
182                 LOGGER.warn('Unable to verify session for %s - ignoring auth',
183                         socket.context.ipAddress);
184             }));
185         }
187         promises.push(getAliases(socket.context.ipAddress).then(aliases => {
188             socket.context.aliases = aliases;
189         }).catch(_error => {
190             LOGGER.warn('Unable to load aliases for %s',
191                     socket.context.ipAddress);
192         }));
194         Promise.all(promises).then(() => next());
195     }
197     handleConnection(socket) {
198         if (!this.checkIPLimit(socket)) {
199             //return;
200         }
202         patchTypecheckedFunctions(socket);
203         patchSocketMetrics(socket);
205         this.setRateLimiter(socket);
207         emitMetrics(socket);
209         LOGGER.info('Accepted socket from %s', socket.context.ipAddress);
210         socket.once('disconnect', (reason, reasonDetail) => {
211             LOGGER.info(
212                 '%s disconnected (%s%s)',
213                 socket.context.ipAddress,
214                 reason,
215                 reasonDetail ? ` - ${reasonDetail}` : ''
216             );
217         });
219         const user = new User(socket, socket.context.ipAddress, socket.context.user);
220         if (socket.context.user) {
221             db.recordVisit(socket.context.ipAddress, user.getName());
222         }
224         const announcement = Server.getServer().announcement;
225         if (announcement !== null) {
226             socket.emit('announcement', announcement);
227         }
228     }
230     setRateLimiter(socket) {
231         const refillRate = () => Config.get('io.throttle.in-rate-limit');
232         const capacity = () => Config.get('io.throttle.bucket-capacity');
234         socket._inRateLimit = new TokenBucket(capacity, refillRate);
236         socket.on('cytube:count-event', () => {
237             if (socket._inRateLimit.throttle()) {
238                 LOGGER.warn(
239                     'Kicking client %s: exceeded in-rate-limit of %d',
240                     socket.context.ipAddress,
241                     refillRate()
242                 );
244                 socket.emit('kick', { reason: 'Rate limit exceeded' });
245                 socket.disconnect();
246             }
247         });
248     }
250     initSocketIO() {
251         const io = this.io = sio.instance = sio();
252         io.use(this.ipProxyMiddleware.bind(this));
253         io.use(this.ipBanMiddleware.bind(this));
254         io.use(this.ipThrottleMiddleware.bind(this));
255         io.use(this.cookieParsingMiddleware.bind(this));
256         io.use(this.ipSessionCookieMiddleware.bind(this));
257         io.use(this.authUserMiddleware.bind(this));
258         io.on('connection', this.handleConnection.bind(this));
259     }
261     bindTo(servers) {
262         if (!this.io) {
263             throw new Error('Cannot bind: socket.io has not been initialized yet');
264         }
266         const engineOpts = {
267             /*
268              * Set ping timeout to 2 minutes to avoid spurious reconnects
269              * during transient network issues.  The default of 20 seconds
270              * is too aggressive.
271              *
272              * https://github.com/calzoneman/sync/issues/780
273              */
274             pingTimeout: 120000,
276             /*
277              * Per `ws` docs: "Note that Node.js has a variety of issues with
278              * high-performance compression, where increased concurrency,
279              * especially on Linux, can lead to catastrophic memory
280              * fragmentation and slow performance."
281              *
282              * CyTube's frames are ordinarily quite small, so there's not much
283              * point in compressing them.
284              */
285             perMessageDeflate: false,
286             httpCompression: false,
288             maxHttpBufferSize: 1 << 20,
290             /*
291              * Enable legacy support for socket.io v2 clients (e.g., bots)
292              */
293             allowEIO3: true,
295             cors: {
296                 origin: getCorsAllowCallback(),
297                 credentials: true // enable cookies for auth
298             }
299         };
301         servers.forEach(server => {
302             this.io.attach(server, engineOpts);
303         });
304     }
307 const incomingEventCount = new Counter({
308     name: 'cytube_socketio_incoming_events_total',
309     help: 'Number of received socket.io events from clients'
311 const outgoingPacketCount = new Counter({
312     name: 'cytube_socketio_outgoing_packets_total',
313     help: 'Number of outgoing socket.io packets to clients'
315 function patchSocketMetrics(sock) {
316     const emit = require('events').EventEmitter.prototype.emit;
318     sock.onAny(() => {
319         incomingEventCount.inc(1);
320         emit.call(sock, 'cytube:count-event');
321     });
323     let packet = sock.packet;
324     sock.packet = function patchedPacket() {
325         packet.apply(this, arguments);
326         outgoingPacketCount.inc(1);
327     }.bind(sock);
330 /* TODO: remove this crap */
331 /* Addendum 2021-08-14: socket.io v4 supports middleware, maybe move type validation to that */
332 function patchTypecheckedFunctions(sock) {
333     sock.typecheckedOn = function typecheckedOn(msg, template, cb) {
334         this.on(msg, (data, ack) => {
335             typecheck(data, template, (err, data) => {
336                 if (err) {
337                     this.emit("errorMsg", {
338                         msg: "Unexpected error for message " + msg + ": " + err.message
339                     });
340                 } else {
341                     cb(data, ack);
342                 }
343             });
344         });
345     }.bind(sock);
347     sock.typecheckedOnce = function typecheckedOnce(msg, template, cb) {
348         this.once(msg, data => {
349             typecheck(data, template, (err, data) => {
350                 if (err) {
351                     this.emit("errorMsg", {
352                         msg: "Unexpected error for message " + msg + ": " + err.message
353                     });
354                 } else {
355                     cb(data);
356                 }
357             });
358         });
359     }.bind(sock);
362 let globalIPBanlist = null;
363 function isIPGlobalBanned(ip) {
364     if (globalIPBanlist === null) {
365         globalIPBanlist = new CachingGlobalBanlist(db.getGlobalBanDB());
366         globalIPBanlist.refreshCache();
367         globalIPBanlist.startCacheTimer(60 * 1000);
368     }
370     return globalIPBanlist.isIPGlobalBanned(ip);
373 const promSocketCount = new Gauge({
374     name: 'cytube_sockets_num_connected',
375     help: 'Gauge of connected socket.io clients',
376     labelNames: ['transport']
378 const promSocketAccept = new Counter({
379     name: 'cytube_sockets_accepts_total',
380     help: 'Counter for number of connections accepted.  Excludes rejected connections.'
382 const promSocketDisconnect = new Counter({
383     name: 'cytube_sockets_disconnects_total',
384     help: 'Counter for number of connections disconnected.'
386 const promSocketReconnect = new Counter({
387     name: 'cytube_sockets_reconnects_total',
388     help: 'Counter for number of reconnects detected.'
390 function emitMetrics(sock) {
391     try {
392         let closed = false;
393         let transportName = sock.conn.transport.name;
394         promSocketCount.inc({ transport: transportName });
395         promSocketAccept.inc(1);
397         sock.conn.on('upgrade', () => {
398             try {
399                 let newTransport = sock.conn.transport.name;
400                 // Sanity check
401                 if (!closed && newTransport !== transportName) {
402                     promSocketCount.dec({ transport: transportName });
403                     transportName = newTransport;
404                     promSocketCount.inc({ transport: transportName });
405                 }
406             } catch (error) {
407                 LOGGER.error('Error emitting transport upgrade metrics for socket (ip=%s): %s',
408                         sock.context.ipAddress, error.stack);
409             }
410         });
412         sock.once('disconnect', () => {
413             try {
414                 closed = true;
415                 promSocketCount.dec({ transport: transportName });
416                 promSocketDisconnect.inc(1);
417             } catch (error) {
418                 LOGGER.error('Error emitting disconnect metrics for socket (ip=%s): %s',
419                         sock.context.ipAddress, error.stack);
420             }
421         });
423         sock.once('reportReconnect', () => {
424             try {
425                 promSocketReconnect.inc(1, new Date());
426             } catch (error) {
427                 LOGGER.error('Error emitting reconnect metrics for socket (ip=%s): %s',
428                         sock.context.ipAddress, error.stack);
429             }
430         });
431     } catch (error) {
432         LOGGER.error('Error emitting metrics for socket (ip=%s): %s',
433                 sock.context.ipAddress, error.stack);
434     }
437 let instance = null;
439 module.exports = {
440     init: function (srv, webConfig) {
441         if (instance !== null) {
442             throw new Error('ioserver.init: already initialized');
443         }
445         const ioServer = instance = new IOServer({
446             proxyTrustFn: proxyaddr.compile(webConfig.getTrustedProxies())
447         });
449         ioServer.initSocketIO();
451         const uniqueListenAddresses = new Set();
452         const servers = [];
454         Config.get("listen").forEach(function (bind) {
455             if (!bind.io) {
456                 return;
457             }
459             const id = bind.ip + ":" + bind.port;
460             if (uniqueListenAddresses.has(id)) {
461                 LOGGER.warn("Ignoring duplicate listen address %s", id);
462                 return;
463             }
465             if (srv.servers.hasOwnProperty(id)) {
466                 servers.push(srv.servers[id]);
467             } else {
468                 const server = http.createServer().listen(bind.port, bind.ip);
469                 servers.push(server);
470                 server.on("error", error => {
471                     if (error.code === "EADDRINUSE") {
472                         LOGGER.fatal(
473                             "Could not bind %s: address already in use.  Check " +
474                             "whether another application has already bound this " +
475                             "port, or whether another instance of this server " +
476                             "is running.",
477                             id
478                         );
479                         process.exit(1);
480                     }
481                 });
482             }
484             uniqueListenAddresses.add(id);
485         });
487         ioServer.bindTo(servers);
488     },
490     IOServer: IOServer
493 /* Clean out old rate limiters */
494 setInterval(function () {
495     if (instance == null) return;
497     let cleaned = 0;
498     const keys = instance.ipThrottle.keys();
499     for (const key of keys) {
500         if (instance.ipThrottle.get(key).lastRefill < Date.now() - 60000) {
501             const bucket = instance.ipThrottle.delete(key);
502             for (const k in bucket) delete bucket[k];
503             cleaned++;
504         }
505     }
507     if (cleaned > 0) {
508         LOGGER.info('Cleaned up %d stale IP throttle token buckets', cleaned);
509     }
510 }, 5 * 60 * 1000);
512 function getCorsAllowCallback() {
513     let origins = Array.prototype.slice.call(Config.get('io.cors.allowed-origins'));
515     origins = origins.concat([
516         Config.get('io.domain'),
517         Config.get('https.domain')
518     ]);
520     return function corsOriginAllowed(origin, callback) {
521         if (!origin) {
522             // Non-browser clients might not care about Origin, allow these.
523             callback(null, true);
524             return;
525         }
527         // Different ports are technically cross-origin; a distinction that does not matter to CyTube.
528         origin = origin.replace(/:\d+$/, '');
530         if (origins.includes(origin)) {
531             callback(null, true);
532         } else {
533             LOGGER.warn('Rejecting origin "%s"; allowed origins are %j', origin, origins);
534             callback(new Error('Invalid origin'));
535         }
536     };