make getpeername() return the original socket address from before the connection was intercepted
[hband-tools.git] / preload / autossl / autossl.c
blob62b6ca136a3f98fa2e7e8c16dc88f8caed77cbba
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <dlfcn.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <bsd/unistd.h>
#include <string.h>
#include <pthread.h>
/*
 * One record of the fd -> original destination table.
 * connect() fills a record for every socket it redirects, and getpeername()
 * answers from it, so the program sees the address it asked to connect to
 * instead of the socketpair it was handed.
 */
struct sockfd_peername {
	int fd;                      /* socket descriptor the record belongs to */
	struct sockaddr_in sockaddr; /* destination originally passed to connect() */
	socklen_t addrlen;           /* bytes of sockaddr that are valid; 0 = record released by shutdown() */
};

/* Handle type returned by the table lookup helpers. */
typedef struct sockfd_peername * sockfd_peername_ptr;
27 static sockfd_peername_ptr captured_sockets = NULL;
28 static unsigned int captured_sockets_len = 0;
29 pthread_mutex_t captured_sockets_lock;
32 sockfd_peername_ptr _get_captured_socket(int fd)
34 // find the sockfd_peername structure which hold the sockaddr for fd.
35 // even if it's no longer valid.
36 for(unsigned int idx = 0; idx < captured_sockets_len; idx++)
38 if(captured_sockets[idx].fd == fd) return &captured_sockets[idx];
40 return NULL;
43 void _note_sockfd_peername(int fd, const struct sockaddr_in * sockaddr, socklen_t addrlen)
45 // record the sockaddr for the given fd.
46 pthread_mutex_lock(&captured_sockets_lock);
47 sockfd_peername_ptr sockinfo = _get_captured_socket(fd);
48 if(sockinfo == NULL)
50 // fd is not recorded so far.
51 if(captured_sockets == NULL)
53 // captured_sockets is not allocated yet at all.
54 captured_sockets = malloc(sizeof(struct sockfd_peername));
55 if(captured_sockets == NULL) abort();
56 captured_sockets_len = 1;
57 sockinfo = captured_sockets;
59 else
61 // allocate one more slot.
62 captured_sockets = realloc(captured_sockets, sizeof(struct sockfd_peername) * (captured_sockets_len+1));
63 if(captured_sockets == NULL) abort();
64 captured_sockets_len += 1;
65 sockinfo = &captured_sockets[captured_sockets_len-1];
68 // copy fd number, socket address, and address length for later recall.
69 // write to a freshly allocated sockfd_peername structure or an old one, overwriting old data in that case.
70 // captured_sockets is a grow-only array. we don't free up space because it's mainly for programs making a few intercepted connect() calls in their lifetime.
71 sockinfo->fd = fd;
72 memcpy(&sockinfo->sockaddr, sockaddr, addrlen);
73 sockinfo->addrlen = addrlen;
74 pthread_mutex_unlock(&captured_sockets_lock);
/*
 * Report an AUTOSSL_UPGRADE_IPS entry that could not be parsed.
 * s points at the entry and len is its length; s does not need to be
 * NUL-terminated at len (the entry is usually followed by more of the list).
 */
void _autossl_ip_parse_error(const char* s, const size_t len)
{
	// The %.*s precision must be an int; passing a size_t is undefined
	// behaviour, so convert it explicitly and clamp it to INT_MAX.
	int precision = len > (size_t)INT_MAX ? INT_MAX : (int)len;
	fprintf(stderr, "autossl: failed to parse ip address '%.*s'\n", precision, s);
}
82 int connect(int sockfd, const struct sockaddr_in *orig_sockaddr, socklen_t addrlen)
84 char *upgrade_ip_str;
85 struct in_addr upgrade_ip;
86 char *upgrade_port_str;
87 in_port_t upgrade_port = 0;
89 char *tls_cmd;
90 #define IP_STR_LEN 39+1
91 #define PORT_STR_LEN 5+1
92 char connect_ip[IP_STR_LEN];
93 char connect_port[PORT_STR_LEN];
95 char *next_separator;
96 char *autossl_errno_str;
98 struct sockaddr_in to_sockaddr;
99 int (*orig_connect)(int, const struct sockaddr_in *, socklen_t) = dlsym(RTLD_NEXT, "connect");
103 // TODO suport ipv6
104 if(orig_sockaddr->sin_family != AF_INET) goto stdlib;
106 upgrade_port_str = getenv("AUTOSSL_UPGRADE_PORTS");
107 if(upgrade_port_str == NULL) goto stdlib;
110 // determine if need to intercept this connect() according to AUTOSSL_UPGRADE_PORTS
111 int upgrade_port_matched = 0;
112 next_separator = NULL;
115 next_separator = strchrnul(upgrade_port_str, ' ');
116 upgrade_port = atoi(upgrade_port_str);
117 if(upgrade_port == 0) { fprintf(stderr, "autossl: failed to parse port number(s): %s\n", upgrade_port_str); goto error_case; }
118 if(upgrade_port == ntohs(orig_sockaddr->sin_port)) upgrade_port_matched = 1;
119 upgrade_port_str = (char*)(next_separator + 1);
121 while(!upgrade_port_matched && *next_separator != '\0');
123 if(!upgrade_port_matched) goto stdlib;
125 // determine if need to intercept this connect() according to AUTOSSL_UPGRADE_IPS
126 upgrade_ip_str = getenv("AUTOSSL_UPGRADE_IPS");
127 if(upgrade_ip_str != NULL)
129 int upgrade_ip_matched = 0;
130 char *next_separator = NULL;
133 next_separator = strchrnul(upgrade_ip_str, ' ');
134 if(inet_aton(upgrade_ip_str, &upgrade_ip) == 0)
136 _autossl_ip_parse_error(upgrade_ip_str, next_separator - upgrade_ip_str);
137 goto error_case;
139 if(ntohl(orig_sockaddr->sin_addr.s_addr) == ntohl(upgrade_ip.s_addr)) {
140 upgrade_ip_matched = 1;
142 upgrade_ip_str = (char*)(next_separator + 1);
144 while(!upgrade_ip_matched && *next_separator != '\0');
146 if(!upgrade_ip_matched) goto stdlib;
149 tls_cmd = getenv("AUTOSSL_TLS_CMD");
150 if(tls_cmd == NULL) goto stdlib;
153 int sockpair[2];
154 if(socketpair(AF_UNIX, SOCK_STREAM, 0, sockpair) == -1)
156 perror("autossl: sockpair");
157 goto error_case;
160 pid_t childpid = fork();
161 if(childpid < 0)
163 perror("autossl: fork");
164 goto error_case;
166 if(childpid == 0)
168 /* save the ip and port we wanted to connect to as strings */
169 snprintf(connect_ip, IP_STR_LEN, "%s", inet_ntoa(orig_sockaddr->sin_addr));
170 snprintf(connect_port, PORT_STR_LEN, "%d", ntohs(orig_sockaddr->sin_port));
171 /* wire STDIO to the newly created socket */
172 dup2(sockpair[1], 0);
173 dup2(sockpair[1], 1);
174 /* leave stderr open */
175 /* close dangling files */
176 closefrom(3);
177 execlp(tls_cmd, tls_cmd, connect_ip, connect_port, NULL);
178 _exit(127);
181 close(sockpair[1]);
182 if(dup2(sockpair[0], sockfd) == -1)
184 perror("autossl: dup2");
185 close(sockpair[0]);
186 goto error_case;
189 if(!getenv("AUTOSSL_SILENT"))
190 fprintf(stderr, "autossl: redirecting %s:%d -> fd#%d\n", inet_ntoa(orig_sockaddr->sin_addr), ntohs(orig_sockaddr->sin_port), sockpair[0]);
192 _note_sockfd_peername(sockfd, orig_sockaddr, addrlen);
194 /* the caller closes sockfd only, not sockpair[0], so unused open
195 files may pile up eventually in long running programs. */
197 /* childpid process won't be reaped, so don't be scared on the
198 zombie processes, they will be disappear as the main program exits. */
200 return 0;
202 error_case:
203 autossl_errno_str = getenv("AUTOSSL_ERRNO");
204 if(autossl_errno_str)
206 errno = atoi(autossl_errno_str);
207 return -1;
210 stdlib:
211 return orig_connect(sockfd, orig_sockaddr, addrlen);
214 int shutdown(int sockfd, int how)
216 int (*orig_shutdown)(int, int) = dlsym(RTLD_NEXT, "shutdown");
218 // shutdown() is intercepted only to mark which fd <-> sockaddr mapping is not valid anymore.
220 pthread_mutex_lock(&captured_sockets_lock);
221 int err = orig_shutdown(sockfd, how);
222 if(err == 0)
224 // shutdown() succeeded on the fd, so invalidate its slot in captured_sockets array - if it's there.
225 sockfd_peername_ptr sockinfo = _get_captured_socket(sockfd);
226 if(sockinfo != NULL)
228 sockinfo->addrlen = 0; // mark this slot not being used anymore
231 pthread_mutex_unlock(&captured_sockets_lock);
232 return err;
235 int getpeername(int sockfd, struct sockaddr *restrict addr, socklen_t *restrict addrlen)
237 int (*orig_getpeername)(int, struct sockaddr *, socklen_t *) = dlsym(RTLD_NEXT, "getpeername");
239 // check if this fd is in captured_sockets array,
240 // if so, then answer with the recorded socket address.
242 pthread_mutex_lock(&captured_sockets_lock);
243 sockfd_peername_ptr sockinfo = _get_captured_socket(sockfd);
244 if(sockinfo != NULL && sockinfo->addrlen != 0)
246 memcpy(addr, &sockinfo->sockaddr, sockinfo->addrlen);
247 *addrlen = sockinfo->addrlen;
248 pthread_mutex_unlock(&captured_sockets_lock);
249 return 0;
251 pthread_mutex_unlock(&captured_sockets_lock);
253 return orig_getpeername(sockfd, addr, addrlen);