add license
[0tDNS.git] / src / hourly.py
blobca9fc2f8a45caac9b3cb2fb92adedcfede79e09d
1 #!/bin/python3
3 import subprocess
4 from os import path, waitpid, unlink, WEXITSTATUS, chown
5 from time import gmtime, strftime, sleep
6 import re
7 from pathlib import Path
8 from pwd import getpwnam
10 import psycopg2
12 # our own module used by several scripts in the project
13 from ztdnslib import start_db_connection, \
14 get_default_host_address, get_ztdns_config, log, set_loghour, logfile
# Absolute paths used by this script:
#  - wrapper: executable started (via subprocess) once per vpn connection,
#  - perform_queries: program passed to the wrapper as the command it runs,
#  - lockfile: file whose exclusive creation prevents overlapping runs.
wrapper, perform_queries, lockfile = (
    "/var/lib/0tdns/vpn_wrapper.sh",
    "/var/lib/0tdns/perform_queries.py",
    "/var/lib/0tdns/lockfile",
)
def sync_ovpn_config(cursor, vpn_id, config_path, config_hash):
    """Fetch one VPN's OpenVPN config from the database and save it to disk.

    The row is selected by both ``vpn_id`` and ``config_hash`` so only the
    config version the caller expects is written.

    The contents are written to ``config_path + '.tmp'`` first and then
    renamed over ``config_path``. The rename replaces the target in a single
    step. Callers decide whether to sync by checking whether ``config_path``
    exists, so a half-written file left by an interrupted run would otherwise
    never be fixed.

    :param cursor: open DB-API (psycopg2) cursor
    :param vpn_id: id of the row in ``user_side_vpn``
    :param config_path: destination file for the config
    :param config_hash: expected value of ``ovpn_config_sha256``
    :return: True if the file was written; False if no row matched (e.g.
        the config changed in the database after the VPN list was read)
    """
    from os import replace

    cursor.execute('''
    select ovpn_config
    from user_side_vpn
    where id = %s and ovpn_config_sha256 = %s
    ''', (vpn_id, config_hash))

    row = cursor.fetchone()
    if row is None:
        log('no config for vpn {} with hash {} found in the database'
            .format(vpn_id, config_hash))
        return False

    (config_contents,) = row
    tmp_path = config_path + '.tmp'
    # bytes() accepts both the memoryview psycopg2 returns for bytea
    # and plain bytes
    with open(tmp_path, "wb") as config_file:
        config_file.write(bytes(config_contents))
    replace(tmp_path, config_path)
    return True
def get_vpn_connections(cursor, hour):
    """List the VPNs that have at least one query to perform.

    :param cursor: open DB-API cursor
    :param hour: currently not used by the query
    :return: rows of ``(vpn_id, ovpn_config_sha256)``, one per distinct VPN
        referenced by ``user_side_queries``
    """
    vpns_with_queries_sql = '''
    SELECT DISTINCT v.id, v.ovpn_config_sha256
    FROM user_side_queries AS q JOIN user_side_vpn AS v
    ON v.id = q.vpn_id;
    '''
    cursor.execute(vpns_with_queries_sql)
    rows = cursor.fetchall()
    return rows
def lock_on_file():
    """Try to take the run lock by exclusively creating ``lockfile``.

    Mode ``'x'`` makes ``open`` fail when the file already exists.

    :return: True if the lock was acquired, False if ``lockfile`` already
        existed
    """
    try:
        lock_handle = open(lockfile, 'x')
    except FileExistsError:
        return False
    lock_handle.close()
    return True
def unlock_on_file():
    """Release the run lock by deleting ``lockfile``.

    :return: True if the file was removed, False if it no longer existed
        (i.e. the lock got removed in the meantime)
    """
    try:
        unlink(lockfile)
    except FileNotFoundError:
        return False
    else:
        return True
# Matches one configured address range such as '10.25.25.0 - 10.25.25.59'.
# Group 1 is the first address and group 2 the last one.
# The pattern is anchored at the end (surrounding whitespace allowed). That
# way an entry with trailing junk, e.g. '10.0.0.0 - 10.0.0.255/24', is
# rejected instead of being silently truncated to a shorter range.
# [0-9] is used instead of \d because in str patterns \d also matches
# non-ASCII digits.
address_range_regex = re.compile(r'''
    [\s]*
    ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)  # first IPv4 address in the range
    [\s]*-[\s]*                       # dash (with optional whitespace around)
    ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)  # last IPv4 address in the range
    [\s]*$
    ''', re.VERBOSE)

# Matches a single dotted-quad IPv4 address; groups 1-4 are its octets.
address_regex = re.compile(r'([0-9]+)\.([0-9]+)\.([0-9]+)\.([0-9]+)$')
def ip_address_to_number(address):
    """Convert a dotted-quad IPv4 address string to its integer value.

    :param address: string like '10.25.25.17'
    :return: the address as an int (e.g. '10.0.0.1' -> 167772161), or None
        if ``address`` is not exactly four dot-separated octets in 0..255.
        '0.0.0.0' yields 0, so callers must test the result with
        ``is None`` rather than for truthiness.
    """
    # fullmatch: reject trailing garbage such as '1.2.3.4.5'
    match = address_regex.fullmatch(address)
    if not match:
        return None
    number = 0
    for octet in match.groups():
        octet_value = int(octet)
        # an octet holds 0..255 (the old '> 256' check let 256 through)
        if octet_value > 255:
            return None
        number = number * 256 + octet_value
    return number
def number_to_ip_address(number):
    """Render the low 32 bits of ``number`` as a dotted-quad IPv4 string.

    Inverse of ip_address_to_number for values in 0..2**32-1; higher bits
    are discarded.
    """
    octets = []
    for _ in range(4):
        # peel off the least significant octet first
        number, octet = divmod(number, 256)
        octets.append(octet)
    return ".".join(str(octet) for octet in reversed(octets))
def _parse_address_range(address_range):
    """Return (first, last) address numbers of an 'a.b.c.d - e.f.g.h' string,
    or None if it cannot be parsed."""
    match = address_range_regex.match(address_range)
    if not match:
        return None
    first = ip_address_to_number(match.group(1))
    last = ip_address_to_number(match.group(2))
    # compare with None: '0.0.0.0' converts to 0, which is falsy but valid
    if first is None or last is None:
        return None
    return first, last


def get_available_subnetworks(count, address_ranges):
    """Carve /30 subnetworks out of a list of IPv4 address ranges.

    Accepts ranges like
    ``['10.25.25.0 - 10.25.25.59', '10.25.25.120 - 10.25.25.135']``
    and returns a set of /30 subnetworks. Each subnetwork is represented by
    a tuple of its 2 usable addresses; e.g. for 10.25.25.16/30 it is
    ``('10.25.25.17', '10.25.25.18')``. The addresses ending with .16
    (subnet address) and .19 (broadcast address) are considered unusable.

    Only /30 subnetworks lying entirely inside a range are used. Ranges that
    cannot be parsed or contain no whole /30 subnetwork are logged and
    skipped.

    :param count: maximum number of subnetworks to return
    :param address_ranges: iterable of 'first - last' address range strings
    :return: set of at most ``count`` ``(usable_addr1, usable_addr2)`` tuples
    """
    available_subnetworks = set()

    for address_range in address_ranges:
        bounds = _parse_address_range(address_range)
        if bounds is None:
            log("'{}' is not a valid address range".format(address_range))
            continue

        first, last = bounds
        # round up so subnet_start is the first address of a /30 network and
        # round down so last is the last address of a /30 network
        subnet_start = (first + 3) // 4 * 4
        last = (last + 1) // 4 * 4 - 1

        if subnet_start >= last:
            log("address range '{}' doesn't contain any /30 subnetworks"
                .format(address_range))
            continue

        while len(available_subnetworks) < count and subnet_start < last:
            usable_addr1 = number_to_ip_address(subnet_start + 1)
            usable_addr2 = number_to_ip_address(subnet_start + 2)
            available_subnetworks.add((usable_addr1, usable_addr2))
            subnet_start += 4

    return available_subnetworks
def do_hourly_work(hour):
    """Run the queries of all handled VPNs for one hour.

    Steps:
    1. return early unless the config's ``enabled`` is 'yes' or True,
    2. list the VPNs that have queries (restricted to the config's
       ``handled_vpns`` ids when that is set),
    3. carve up to ``parallel_vpns`` /30 subnets out of the config's
       ``private_addresses``, one per concurrently running wrapper,
    4. download missing OpenVPN configs to /var/lib/0tdns/<sha256>.ovpn,
    5. run ``wrapper`` for each VPN, with ``perform_queries`` as the command
       it executes, at most ``parallel_vpns`` at a time. For every query of
       a VPN whose wrapper does not exit with status 0, a failure result is
       inserted into ``user_side_responses``.

    The database cursor and connection are closed on every exit path,
    including when an exception escapes.

    :param hour: hour being processed, as a timestamp string; passed on to
        ``perform_queries`` and used as the date of failure results
    """
    ztdns_config = get_ztdns_config()
    if ztdns_config['enabled'] not in ['yes', True]:
        log('0tdns not enabled in the config - exiting')
        return

    connection = start_db_connection(ztdns_config)
    try:
        cursor = connection.cursor()
        try:
            _run_vpn_wrappers(hour, ztdns_config, cursor)
        finally:
            cursor.close()
    finally:
        connection.close()


def _run_vpn_wrappers(hour, ztdns_config, cursor):
    """Sync configs and run one wrapper per handled VPN, in parallel batches."""
    from os import WIFSIGNALED, WTERMSIG

    vpns = get_vpn_connections(cursor, hour)

    handled_vpns = ztdns_config.get('handled_vpns')
    if handled_vpns:
        log('Only handling vpns of ids {}'.format(handled_vpns))
        vpns = [vpn for vpn in vpns if vpn[0] in handled_vpns]
    # if not specified in the config, all vpns are handled

    parallel_vpns = ztdns_config['parallel_vpns']  # we need this many subnets
    subnets = get_available_subnetworks(parallel_vpns,
                                        ztdns_config['private_addresses'])

    if not subnets:
        log("couldn't get ANY /30 subnet of private"
            " addresses from the 0tdns config file - exiting")
        return

    if len(subnets) < parallel_vpns:
        log('configuration allows running {0} parallel vpn connections, but'
            ' provided private ip addresses give only {1} /30 subnets, which'
            ' limits parallel connections to {1}'
            .format(parallel_vpns, len(subnets)))
        parallel_vpns = len(subnets)

    for vpn_id, config_hash in vpns:
        config_path = "/var/lib/0tdns/{}.ovpn".format(config_hash)
        if not path.isfile(config_path):
            log('Syncing config for vpn {} with hash {}'
                .format(vpn_id, config_hash))
            sync_ovpn_config(cursor, vpn_id, config_path, config_hash)

    # map of each wrapper pid to a tuple of: id of the vpn it connects to,
    # subnet (tuple of addresses) it uses for the veth device, Popen object
    pids_wrappers = {}

    def wait_for_wrapper_process():
        # waitpid(0, 0) reaps any child in our process group; keep waiting
        # until the reaped child is one of our wrappers
        while True:
            pid, wait_status = waitpid(0, 0)
            vpn_id, subnet, _ = pids_wrappers.get(pid, (None, None, None))
            if subnet:
                break

        # waitpid() is called without WUNTRACED, so the child either exited
        # or was killed by a signal. WEXITSTATUS() of a signal-killed
        # process is 0, so that case must be checked first, otherwise it
        # would be mistaken for success.
        if WIFSIGNALED(wait_status):
            log('wrapper for vpn {} was killed by signal {}'
                .format(vpn_id, WTERMSIG(wait_status)))
            result_info = 'internal failure: process_crash'
        else:
            exit_status = WEXITSTATUS(wait_status)  # read man waitpid if wondering
            if exit_status == 0:
                result_info = None
            elif exit_status == 2:
                # this means our perform_queries.py crashed... not good
                log('performing queries through vpn {} failed'.format(vpn_id))
                result_info = 'internal failure: process_crash'
            else:
                # vpn server is probably not responding
                log('connection to vpn {} failed'.format(vpn_id))
                result_info = 'internal failure: vpn_connection_failure'

        if result_info is not None:
            # NOTE(review): nothing here commits, and after an IntegrityError
            # a non-autocommit postgres transaction stays aborted; presumably
            # start_db_connection enables autocommit - confirm.
            try:
                cursor.execute('''
                INSERT INTO user_side_responses
                (date, result, dns_id, service_id, vpn_id)
                (SELECT TIMESTAMP WITH TIME ZONE %s, %s,
                dns_id, service_id, vpn_id
                FROM user_side_queries
                WHERE vpn_id = %s);
                ''', (hour, result_info, vpn_id))
            except psycopg2.IntegrityError:
                log('results already exist for vpn {}'.format(vpn_id))

        pids_wrappers.pop(pid)
        subnets.add(subnet)  # the subnet can be reused by the next wrapper

    for vpn_id, config_hash in vpns:
        if len(pids_wrappers) == parallel_vpns:
            wait_for_wrapper_process()

        config_path = "/var/lib/0tdns/{}.ovpn".format(config_hash)
        physical_ip = get_default_host_address(ztdns_config['host'])
        subnet = subnets.pop()
        veth_addr1, veth_addr2 = subnet
        route_through_veth = ztdns_config['host'] + "/32"
        command_in_namespace = [perform_queries, hour, str(vpn_id)]
        log('Running connection for vpn {}'.format(vpn_id))

        # see into vpn_wrapper.sh for explanation of its arguments
        p = subprocess.Popen([wrapper, config_path, physical_ip, veth_addr1,
                              veth_addr2, route_through_veth, str(vpn_id)] +
                             command_in_namespace)

        # we're not actually using the subprocess object anywhere, but we
        # put it in the dict regardless to keep a reference to it - otherwise
        # python would reap the child for us and waitpid(0, 0) would raise
        # '[Errno 10] No child processes' :c
        pids_wrappers[p.pid] = (vpn_id, subnet, p)

    while pids_wrappers:
        wait_for_wrapper_process()
def prepare_logging(hour):
    """Set up logging for this run.

    Makes ``log()`` prepend messages with ``hour``, creates ``logfile`` if it
    does not exist yet and makes the '0tdns' user its owner so that user can
    write to it.
    """
    set_loghour(hour)
    log_path = Path(logfile)
    log_path.touch()
    ztdns_uid = getpwnam('0tdns').pw_uid
    chown(logfile, ztdns_uid, -1)  # -1: leave the group unchanged
# Round down to an hour - this datetime format is one of the formats
# accepted by postgres. gmtime() is always UTC, so the offset is written
# literally instead of via '%z', whose handling of a struct_time is
# platform dependent.
hour = strftime('%Y-%m-%d %H:00+0000', gmtime())
prepare_logging(hour)

if not lock_on_file():
    log('Failed trying to run for {}; {} exists'.format(hour, lockfile))
else:
    try:
        log('Running for {}'.format(hour))
        do_hourly_work(hour)
    except Exception:
        # record the failure in our own log file too, not only on stderr
        import traceback
        log('Run for {} failed:\n{}'.format(hour, traceback.format_exc()))
        raise
    finally:
        if not unlock_on_file():
            log("Can't remove lock - {} already deleted!".format(lockfile))