use poll() instead of select()
[rofl0r-microsocks.git] / sockssrv.c
blob98e46221198cdd9b1a14d0277b6981cb876ae34b
1 /*
2 MicroSocks - multithreaded, small, efficient SOCKS5 server.
4 Copyright (C) 2017 rofl0r.
6 This is the successor of "rocksocks5", and it was written with
7 different goals in mind:
9 - prefer usage of standard libc functions over homegrown ones
10 - no artificial limits
11 - do not aim for minimal binary size, but for minimal source code size,
12 and maximal readability, reusability, and extensibility.
14 as a result of that, ipv4, dns, and ipv6 are supported out of the box
15 and share the same code, while rocksocks5 has several compile-time
16 defines to bring down the size of the resulting binary to extreme values
17 like 10 KB statically linked when only ipv4 support is enabled.
19 still, if optimized for size, *this* program when statically linked against
20 musl libc is not even 50 KB. that's easily usable even on the cheapest routers.
24 #define _GNU_SOURCE
25 #include <unistd.h>
26 #define _POSIX_C_SOURCE 200809L
27 #include <stdlib.h>
28 #include <string.h>
29 #include <stdio.h>
30 #include <pthread.h>
31 #include <signal.h>
32 #include <poll.h>
33 #include <arpa/inet.h>
34 #include <errno.h>
35 #include <limits.h>
36 #include "server.h"
37 #include "sblist.h"
#ifndef MAX
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#endif

/* Stack size requested for each per-client thread.
   Every expansion is fully parenthesized so the macro stays correct when it
   is used inside a larger arithmetic expression (e.g. a division). */
#ifdef PTHREAD_STACK_MIN
#define THREAD_STACK_SIZE MAX(8*1024, PTHREAD_STACK_MIN)
#else
#define THREAD_STACK_SIZE (64*1024)
#endif

/* platform-specific overrides of the generic choice above */
#if defined(__APPLE__)
#undef THREAD_STACK_SIZE
#define THREAD_STACK_SIZE (64*1024)
#elif defined(__GLIBC__) || defined(__FreeBSD__)
#undef THREAD_STACK_SIZE
#define THREAD_STACK_SIZE (32*1024)
#endif
57 static const char* auth_user;
58 static const char* auth_pass;
59 static sblist* auth_ips;
60 static pthread_rwlock_t auth_ips_lock = PTHREAD_RWLOCK_INITIALIZER;
61 static const struct server* server;
62 static union sockaddr_union bind_addr = {.v4.sin_family = AF_UNSPEC};
/* Handshake progress of one client connection (stored in struct thread). */
enum socksstate {
	SS_1_CONNECTED = 0, /* waiting for the method-selection greeting */
	SS_2_NEED_AUTH = 1, /* waiting for username/password; skipped if NO_AUTH method supported */
	SS_3_AUTHED    = 2, /* waiting for the CONNECT request */
};
/* SOCKS5 authentication method identifiers (RFC 1928, section 3). */
enum authmethod {
	AM_NO_AUTH   = 0x00, /* no authentication required */
	AM_GSSAPI    = 0x01, /* GSSAPI */
	AM_USERNAME  = 0x02, /* username/password (RFC 1929) */
	AM_INVALID   = 0xFF, /* no acceptable methods */
};
/* SOCKS5 reply field (REP) values (RFC 1928, section 6). */
enum errorcode {
	EC_SUCCESS                   = 0x00, /* succeeded */
	EC_GENERAL_FAILURE           = 0x01, /* general SOCKS server failure */
	EC_NOT_ALLOWED               = 0x02, /* connection not allowed by ruleset */
	EC_NET_UNREACHABLE           = 0x03, /* network unreachable */
	EC_HOST_UNREACHABLE          = 0x04, /* host unreachable */
	EC_CONN_REFUSED              = 0x05, /* connection refused */
	EC_TTL_EXPIRED               = 0x06, /* TTL expired */
	EC_COMMAND_NOT_SUPPORTED     = 0x07, /* command not supported */
	EC_ADDRESSTYPE_NOT_SUPPORTED = 0x08, /* address type not supported */
};
89 struct thread {
90 pthread_t pt;
91 struct client client;
92 enum socksstate state;
93 volatile int done;
/* Build with -DCONFIG_LOG=0 to compile all logging out. */
#ifndef CONFIG_LOG
#define CONFIG_LOG 1
#endif

#if !CONFIG_LOG
static void dolog(const char* fmt, ...) { }
#else
/* Log messages go to stderr, which is not line-buffered and therefore avoids
   stdio buffer allocation (malloc) that would need locking across threads.
   dprintf() is used for the same reason: it writes straight to the fd. */
#define dolog(...) dprintf(STDERR_FILENO, __VA_ARGS__)
#endif
/* Translate an errno value left by socket(), bindtoip() or connect() into
   the negated SOCKS5 reply code that connect_socks_target() returns.
   Unrecognized values are reported on stderr via perror() and mapped to a
   general failure. */
static int socks_error_from_errno(int err) {
	switch(err) {
	case ETIMEDOUT:
		return -EC_TTL_EXPIRED;
	case EPROTOTYPE:
	case EPROTONOSUPPORT:
	case EAFNOSUPPORT:
		return -EC_ADDRESSTYPE_NOT_SUPPORTED;
	case ECONNREFUSED:
		return -EC_CONN_REFUSED;
	case ENETDOWN:
	case ENETUNREACH:
		return -EC_NET_UNREACHABLE;
	case EHOSTUNREACH:
		return -EC_HOST_UNREACHABLE;
	case EBADF:
	default:
		errno = err; /* perror() reports the saved error, not a later one */
		perror("socket/connect");
		return -EC_GENERAL_FAILURE;
	}
}

/* Process a SOCKS5 request (RFC 1928, section 4) held in buf[0..n):
     VER(5) CMD RSV ATYP DST.ADDR DST.PORT
   Only CMD=1 (CONNECT) is accepted; ATYP may be 1 (IPv4), 3 (domain name,
   one length byte followed by the name) or 4 (IPv6). The destination is
   resolved with resolve(); a TCP socket is created for the first result,
   bound to bind_addr when one is configured, and connected.
   Returns the connected socket fd on success, or a negated enum errorcode
   on failure. */
static int connect_socks_target(unsigned char *buf, size_t n, struct client *client) {
	if(n < 5) return -EC_GENERAL_FAILURE;
	if(buf[0] != 5) return -EC_GENERAL_FAILURE;
	if(buf[1] != 1) return -EC_COMMAND_NOT_SUPPORTED; /* we support only CONNECT method */
	if(buf[2] != 0) return -EC_GENERAL_FAILURE; /* malformed packet */

	int af = AF_INET;
	size_t minlen = 4 + 4 + 2, l;
	char namebuf[256];
	struct addrinfo* remote;

	switch(buf[3]) {
	case 4: /* ipv6 */
		af = AF_INET6;
		minlen = 4 + 2 + 16;
		/* fall through */
	case 1: /* ipv4 */
		if(n < minlen) return -EC_GENERAL_FAILURE;
		if(namebuf != inet_ntop(af, buf+4, namebuf, sizeof namebuf))
			return -EC_GENERAL_FAILURE; /* malformed or too long addr */
		break;
	case 3: /* dns name */
		l = buf[4];
		minlen = 4 + 2 + l + 1;
		if(n < minlen) return -EC_GENERAL_FAILURE;
		memcpy(namebuf, buf+4+1, l);
		namebuf[l] = 0;
		break;
	default:
		return -EC_ADDRESSTYPE_NOT_SUPPORTED;
	}
	/* DST.PORT is the last two bytes of the request, most significant first */
	unsigned short port = (buf[minlen-2] << 8) | buf[minlen-1];
	/* there's no suitable errorcode in rfc1928 for dns lookup failure */
	if(resolve(namebuf, port, &remote)) return -EC_GENERAL_FAILURE;

	int fd = socket(remote->ai_addr->sa_family, SOCK_STREAM, 0);
	if(fd == -1)
		goto fail;
	if(SOCKADDR_UNION_AF(&bind_addr) != AF_UNSPEC && bindtoip(fd, &bind_addr) == -1)
		goto fail;
	if(connect(fd, remote->ai_addr, remote->ai_addrlen) == -1)
		goto fail;

	freeaddrinfo(remote);
	if(CONFIG_LOG) {
		char clientname[256];
		af = SOCKADDR_UNION_AF(&client->addr);
		void *ipdata = SOCKADDR_UNION_ADDRESS(&client->addr);
		inet_ntop(af, ipdata, clientname, sizeof clientname);
		dolog("client[%d] %s: connected to %s:%d\n", client->fd, clientname, namebuf, port);
	}
	return fd;

fail:
	{
		/* Capture errno before cleanup: close() and freeaddrinfo() may
		   overwrite it, which would report the wrong SOCKS error. */
		int err = errno;
		if(fd != -1) close(fd);
		freeaddrinfo(remote);
		return socks_error_from_errno(err);
	}
}
/* Return 1 when client and authedip have the same address family and the
   same address bytes (4 bytes for AF_INET, 16 otherwise, read starting at
   SOCKADDR_UNION_ADDRESS), else 0. */
static int is_authed(union sockaddr_union *client, union sockaddr_union *authedip) {
	int family = SOCKADDR_UNION_AF(authedip);
	if(family != SOCKADDR_UNION_AF(client))
		return 0;
	size_t width = (family == AF_INET) ? 4 : 16;
	const void *lhs = SOCKADDR_UNION_ADDRESS(client);
	const void *rhs = SOCKADDR_UNION_ADDRESS(authedip);
	return memcmp(lhs, rhs, width) == 0;
}
195 static int is_in_authed_list(union sockaddr_union *caddr) {
196 size_t i;
197 for(i=0;i<sblist_getsize(auth_ips);i++)
198 if(is_authed(caddr, sblist_get(auth_ips, i)))
199 return 1;
200 return 0;
203 static void add_auth_ip(union sockaddr_union *caddr) {
204 sblist_add(auth_ips, caddr);
207 static enum authmethod check_auth_method(unsigned char *buf, size_t n, struct client*client) {
208 if(buf[0] != 5) return AM_INVALID;
209 size_t idx = 1;
210 if(idx >= n ) return AM_INVALID;
211 int n_methods = buf[idx];
212 idx++;
213 while(idx < n && n_methods > 0) {
214 if(buf[idx] == AM_NO_AUTH) {
215 if(!auth_user) return AM_NO_AUTH;
216 else if(auth_ips) {
217 int authed = 0;
218 if(pthread_rwlock_rdlock(&auth_ips_lock) == 0) {
219 authed = is_in_authed_list(&client->addr);
220 pthread_rwlock_unlock(&auth_ips_lock);
222 if(authed) return AM_NO_AUTH;
224 } else if(buf[idx] == AM_USERNAME) {
225 if(auth_user) return AM_USERNAME;
227 idx++;
228 n_methods--;
230 return AM_INVALID;
233 static void send_auth_response(int fd, int version, enum authmethod meth) {
234 unsigned char buf[2];
235 buf[0] = version;
236 buf[1] = meth;
237 write(fd, buf, 2);
240 static void send_error(int fd, enum errorcode ec) {
241 /* position 4 contains ATYP, the address type, which is the same as used in the connect
242 request. we're lazy and return always IPV4 address type in errors. */
243 char buf[10] = { 5, ec, 0, 1 /*AT_IPV4*/, 0,0,0,0, 0,0 };
244 write(fd, buf, 10);
/* Read one chunk (up to 1024 bytes) from infd and write all of it to outfd,
   retrying calls interrupted by signals.
   Returns 0 on success, -1 on EOF or on any other read/write error. */
static int relay_chunk(int infd, int outfd) {
	char buf[1024];
	ssize_t n;
	do
		n = read(infd, buf, sizeof buf);
	while(n < 0 && errno == EINTR);
	if(n <= 0)
		return -1;

	ssize_t sent = 0;
	while(sent < n) {
		ssize_t m = write(outfd, buf + sent, n - sent);
		if(m < 0) {
			if(errno == EINTR)
				continue;
			return -1;
		}
		sent += m;
	}
	return 0;
}

/* Relay data in both directions between fd1 and fd2 until either side
   reaches EOF, an I/O or poll error occurs, or no data flows for 15 minutes. */
static void copyloop(int fd1, int fd2) {
	struct pollfd fds[2] = {
		[0] = {.fd = fd1, .events = POLLIN},
		[1] = {.fd = fd2, .events = POLLIN},
	};
	/* Any of these means read() on that descriptor will not block: POLLHUP
	   or POLLERR without POLLIN still yields EOF or an error, which ends the
	   relay instead of blocking on the other descriptor. */
	const short ready = POLLIN | POLLHUP | POLLERR | POLLNVAL;

	while(1) {
		/* inactive connections are reaped after 15 min to free resources.
		   usually programs send keep-alive packets so this should only happen
		   when a connection is really unused. */
		switch(poll(fds, 2, 60*15*1000)) {
		case 0:
			/* The SOCKS success reply was already sent and the stream now
			   carries application data, so just close; another SOCKS
			   reply here would corrupt the client's data. */
			return;
		case -1:
			if(errno == EINTR || errno == EAGAIN) continue;
			perror("poll");
			return;
		}
		/* Service every ready side in this round so that a busy fd1 cannot
		   starve traffic flowing from fd2. */
		for(int i = 0; i < 2; i++) {
			if((fds[i].revents & ready) && relay_chunk(fds[i].fd, fds[1 - i].fd))
				return;
		}
	}
}
279 static enum errorcode check_credentials(unsigned char* buf, size_t n) {
280 if(n < 5) return EC_GENERAL_FAILURE;
281 if(buf[0] != 1) return EC_GENERAL_FAILURE;
282 unsigned ulen, plen;
283 ulen=buf[1];
284 if(n < 2 + ulen + 2) return EC_GENERAL_FAILURE;
285 plen=buf[2+ulen];
286 if(n < 2 + ulen + 1 + plen) return EC_GENERAL_FAILURE;
287 char user[256], pass[256];
288 memcpy(user, buf+2, ulen);
289 memcpy(pass, buf+2+ulen+1, plen);
290 user[ulen] = 0;
291 pass[plen] = 0;
292 if(!strcmp(user, auth_user) && !strcmp(pass, auth_pass)) return EC_SUCCESS;
293 return EC_NOT_ALLOWED;
296 static void* clientthread(void *data) {
297 struct thread *t = data;
298 t->state = SS_1_CONNECTED;
299 unsigned char buf[1024];
300 ssize_t n;
301 int ret;
302 int remotefd = -1;
303 enum authmethod am;
304 while((n = recv(t->client.fd, buf, sizeof buf, 0)) > 0) {
305 switch(t->state) {
306 case SS_1_CONNECTED:
307 am = check_auth_method(buf, n, &t->client);
308 if(am == AM_NO_AUTH) t->state = SS_3_AUTHED;
309 else if (am == AM_USERNAME) t->state = SS_2_NEED_AUTH;
310 send_auth_response(t->client.fd, 5, am);
311 if(am == AM_INVALID) goto breakloop;
312 break;
313 case SS_2_NEED_AUTH:
314 ret = check_credentials(buf, n);
315 send_auth_response(t->client.fd, 1, ret);
316 if(ret != EC_SUCCESS)
317 goto breakloop;
318 t->state = SS_3_AUTHED;
319 if(auth_ips && !pthread_rwlock_wrlock(&auth_ips_lock)) {
320 if(!is_in_authed_list(&t->client.addr))
321 add_auth_ip(&t->client.addr);
322 pthread_rwlock_unlock(&auth_ips_lock);
324 break;
325 case SS_3_AUTHED:
326 ret = connect_socks_target(buf, n, &t->client);
327 if(ret < 0) {
328 send_error(t->client.fd, ret*-1);
329 goto breakloop;
331 remotefd = ret;
332 send_error(t->client.fd, EC_SUCCESS);
333 copyloop(t->client.fd, remotefd);
334 goto breakloop;
338 breakloop:
340 if(remotefd != -1)
341 close(remotefd);
343 close(t->client.fd);
344 t->done = 1;
346 return 0;
349 static void collect(sblist *threads) {
350 size_t i;
351 for(i=0;i<sblist_getsize(threads);) {
352 struct thread* thread = *((struct thread**)sblist_get(threads, i));
353 if(thread->done) {
354 pthread_join(thread->pt, 0);
355 sblist_delete(threads, i);
356 free(thread);
357 } else
358 i++;
/* Print command-line help to stderr (fd 2) and return 1, so that main can
   use `return usage();` as its exit status. */
static int usage(void) {
	static const char help[] =
		"MicroSocks SOCKS5 Server\n"
		"------------------------\n"
		"usage: microsocks -1 -i listenip -p port -u user -P password -b bindaddr\n"
		"all arguments are optional.\n"
		"by default listenip is 0.0.0.0 and port 1080.\n\n"
		"option -b specifies which ip outgoing connections are bound to\n"
		"option -1 activates auth_once mode: once a specific ip address\n"
		"authed successfully with user/pass, it is added to a whitelist\n"
		"and may use the proxy without auth.\n"
		"this is handy for programs like firefox that don't support\n"
		"user/pass auth. for it to work you'd basically make one connection\n"
		"with another program that supports it, and then you can use firefox too.\n";
	dprintf(2, "%s", help);
	return 1;
}
/* Overwrite the string s in place (up to its terminating NUL) with zero
   bytes, so that secrets passed as command-line arguments (username,
   password) no longer show up in top. */
static void zero_arg(char *s) {
	memset(s, 0, strlen(s));
}
386 int main(int argc, char** argv) {
387 int ch;
388 const char *listenip = "0.0.0.0";
389 unsigned port = 1080;
390 while((ch = getopt(argc, argv, ":1b:i:p:u:P:")) != -1) {
391 switch(ch) {
392 case '1':
393 auth_ips = sblist_new(sizeof(union sockaddr_union), 8);
394 break;
395 case 'b':
396 resolve_sa(optarg, 0, &bind_addr);
397 break;
398 case 'u':
399 auth_user = strdup(optarg);
400 zero_arg(optarg);
401 break;
402 case 'P':
403 auth_pass = strdup(optarg);
404 zero_arg(optarg);
405 break;
406 case 'i':
407 listenip = optarg;
408 break;
409 case 'p':
410 port = atoi(optarg);
411 break;
412 case ':':
413 dprintf(2, "error: option -%c requires an operand\n", optopt);
414 /* fall through */
415 case '?':
416 return usage();
419 if((auth_user && !auth_pass) || (!auth_user && auth_pass)) {
420 dprintf(2, "error: user and pass must be used together\n");
421 return 1;
423 if(auth_ips && !auth_pass) {
424 dprintf(2, "error: auth-once option must be used together with user/pass\n");
425 return 1;
427 signal(SIGPIPE, SIG_IGN);
428 struct server s;
429 sblist *threads = sblist_new(sizeof (struct thread*), 8);
430 if(server_setup(&s, listenip, port)) {
431 perror("server_setup");
432 return 1;
434 server = &s;
436 while(1) {
437 collect(threads);
438 struct client c;
439 struct thread *curr = malloc(sizeof (struct thread));
440 if(!curr) goto oom;
441 curr->done = 0;
442 if(server_waitclient(&s, &c)) continue;
443 curr->client = c;
444 if(!sblist_add(threads, &curr)) {
445 close(curr->client.fd);
446 free(curr);
447 oom:
448 dolog("rejecting connection due to OOM\n");
449 usleep(16); /* prevent 100% CPU usage in OOM situation */
450 continue;
452 pthread_attr_t *a = 0, attr;
453 if(pthread_attr_init(&attr) == 0) {
454 a = &attr;
455 pthread_attr_setstacksize(a, THREAD_STACK_SIZE);
457 if(pthread_create(&curr->pt, a, clientthread, curr) != 0)
458 dolog("pthread_create failed. OOM?\n");
459 if(a) pthread_attr_destroy(&attr);