explain why we put subprocess in the dict
[0tDNS.git] / src / hourly.py
blob5af9e4d9ee0bdcdd44eea9c63dbc44467a359b0b
1 #!/bin/python3
3 from sys import argv
4 import subprocess
5 from os import path, waitpid, unlink
6 from time import gmtime, strftime, sleep
7 import re
9 # our own module used by several scripts in the project
10 from ztdnslib import start_db_connection, \
11 get_default_host_address, get_ztdns_config
# script started once per vpn connection by do_hourly_work(); it is given
# the ovpn config path, addresses and the command to run
wrapper = "/var/lib/0tdns/vpn_wrapper.sh"
# command handed over to the wrapper together with the hour and the vpn id
perform_queries = "/var/lib/0tdns/perform_queries.py"
# file created exclusively while a run is in progress (see lock_on_file())
lockfile = "/var/lib/0tdns/lockfile"
def sync_ovpn_config(cursor, vpn_id, config_path, config_hash):
    """Store the ovpn config of a vpn from the database in a local file.

    The config is selected from user_side_vpn by both the vpn id and the
    sha256 of the config, and its bytes are written to config_path
    (the file is created or overwritten).

    cursor      -- database cursor used to run the query
    vpn_id      -- id of the row in user_side_vpn
    config_path -- path of the file to write the config to
    config_hash -- expected value of ovpn_config_sha256 for that row
    """
    query_args = (vpn_id, config_hash)
    cursor.execute('''
    select ovpn_config
    from user_side_vpn
    where id = %s and ovpn_config_sha256 = %s
    ''', query_args)

    # exactly one column is selected, so the row is a 1-element sequence
    [config_contents] = cursor.fetchone()

    with open(config_path, "wb") as config_file:
        # the fetched value is expected to offer tobytes() (e.g. memoryview)
        config_file.write(config_contents.tobytes())
def get_vpn_connections(cursor, hour):
    """Return the vpns that have at least one query assigned to them.

    The result is whatever cursor.fetchall() gives for the query below:
    one (vpn_id, ovpn_config_sha256) row for each distinct vpn referenced
    from user_side_queries.

    cursor -- database cursor used to run the query
    hour   -- accepted for the callers' sake, currently not used
    """
    cursor.execute('''
    SELECT DISTINCT v.id, v.ovpn_config_sha256
    FROM user_side_queries AS q JOIN user_side_vpn AS v
    ON v.id = q.vpn_id;
    ''')
    rows = cursor.fetchall()
    return rows
def lock_on_file():
    """Try to take the lock by exclusively creating the lockfile.

    Returns True when the file got created by this call and False when
    it already existed (i.e. the lock is held).
    """
    try:
        # mode 'x' fails with FileExistsError instead of opening an
        # existing file, which makes creation double as the lock test
        open(lockfile, 'x').close()
    except FileExistsError:
        return False
    return True
def unlock_on_file():
    """Release the lock by deleting the lockfile.

    Returns True when the file got removed and False when it was
    already gone (somebody removed the lock in the meantime).
    """
    try:
        unlink(lockfile)
    except FileNotFoundError:
        return False
    return True
# matches '<first address> - <last address>'; the two dotted-quad strings
# are captured as groups (octet values are NOT range-checked here)
address_range_regex = re.compile(
    r'''
    ([\d]+\.[\d]+\.[\d]+\.[\d]+) # first IPv4 address in the range

    [\s]*-[\s]*                  # dash (with optional whitespace around)

    ([\d]+\.[\d]+\.[\d]+\.[\d]+) # last IPv4 address in the range
    ''',
    flags=re.VERBOSE,
)
# dotted-quad IPv4 address; each octet is captured as a separate group
address_regex = re.compile(r'([\d]+)\.([\d]+)\.([\d]+)\.([\d]+)')

def ip_address_to_number(address):
    """Convert a dotted-quad IPv4 address string to its integer value.

    The first octet is the most significant one, so '1.2.3.4' becomes
    1 * 256**3 + 2 * 256**2 + 3 * 256 + 4.

    Returns None when address does not start with four dot-separated
    numbers or when any of them is not a valid octet (greater than 255).
    Note that 0 is a legitimate result (for '0.0.0.0'), so callers have
    to compare against None rather than test truthiness.
    """
    match = address_regex.match(address)
    if not match:
        return None

    number = 0
    for byte in match.groups():
        byteval = int(byte)
        # an octet holds 0..255; 256 would spill into the next octet
        if byteval > 255:
            return None
        number = number * 256 + byteval
    return number
def number_to_ip_address(number):
    """Convert an integer to a dotted-quad IPv4 address string.

    Only the 4 least significant bytes of number are used; the most
    significant of them comes first in the returned string.
    """
    octets = []
    remaining = number
    for _ in range(4):
        # peel off the lowest byte, then shift the rest down
        octets.append(remaining % 256)
        remaining = remaining // 256
    # bytes were collected lowest-first, the textual form is highest-first
    return "{}.{}.{}.{}".format(*reversed(octets))
def get_available_subnetworks(count, address_ranges, logfile):
    """Carve /30 subnetworks out of the given IPv4 address ranges.

    address_ranges is a list of strings like:
        ['10.25.25.0 - 10.25.25.59', '10.25.25.120 - 10.25.25.135']
    The result is a set of /30 subnetworks; each subnetwork is represented
    by a tuple of the 2 usable addresses within it.
    E.g. for subnetwork 10.25.25.16/30 it would be
    ('10.25.25.17', '10.25.25.18'); addresses ending with .16 (subnet
    address) and .19 (broadcast in the subnet) are considered unusable
    in this case.

    count          -- upper bound on the number of elements in the result
    address_ranges -- iterable of range strings as described above
    logfile        -- object with write(); ranges that are malformed or
                      too small to hold a /30 subnetwork are reported there
                      and otherwise skipped
    """
    available_subnetworks = set()

    for address_range in address_ranges:
        start_addr_number = end_addr_number = None
        match = address_range_regex.match(address_range)
        if match:
            first_addr, last_addr = match.groups()
            start_addr_number = ip_address_to_number(first_addr)
            end_addr_number = ip_address_to_number(last_addr)

        # compare with None explicitly - 0 is the valid number of 0.0.0.0
        # and must not be mistaken for a parse failure
        if start_addr_number is None or end_addr_number is None:
            logfile.write("'{}' is not a valid address range\n"
                          .format(address_range))
            continue

        # round so that start_addr is first ip address in a /30 network
        # (multiple of 4) and end_addr is last ip address in a /30 network
        # (3 modulo 4)
        start_addr_number += -start_addr_number % 4
        end_addr_number -= (end_addr_number - 3) % 4

        if start_addr_number >= end_addr_number:
            logfile.write("address range '{}' doesn't contain any"
                          " /30 subnetworks\n".format(address_range))
            continue

        while len(available_subnetworks) < count and \
              start_addr_number < end_addr_number:
            usable_addr1 = number_to_ip_address(start_addr_number + 1)
            usable_addr2 = number_to_ip_address(start_addr_number + 2)
            available_subnetworks.add((usable_addr1, usable_addr2))
            start_addr_number += 4

    return available_subnetworks
def do_hourly_work(hour, logfile):
    """Run the queries for one hour through every handled vpn.

    Steps performed:
      1. read the 0tdns config; do nothing unless it says enabled == 'yes'
      2. fetch the vpns to handle from the database (optionally limited
         to the ids listed under 'handled_vpns' in the config)
      3. compute the /30 subnets usable for veth devices; their number
         caps how many vpn connections run in parallel
      4. make sure the ovpn config file of every vpn exists locally
      5. start one wrapper process per vpn (at most parallel_vpns at
         a time) and wait for all of them
      6. insert an 'internal failure: vpn_connection_failure' response for
         every query of a handled vpn that has no response for this hour

    hour    -- timestamp string; passed to the query script as an argument
               and to the database as the response date
    logfile -- object with write() used for progress and error messages

    The database cursor and connection are closed on every way out of the
    function once they have been opened.
    """
    # function-scope imports: these names are only needed here
    from contextlib import closing
    from os import WEXITSTATUS, WIFEXITED

    ztdns_config = get_ztdns_config()
    if ztdns_config['enabled'] != 'yes':
        logfile.write("0tdns not enabled in the config - exiting\n")
        return

    # closing() guarantees close() of the cursor and then of the connection,
    # also on the early return below and when an exception propagates
    with closing(start_db_connection(ztdns_config)) as connection, \
         closing(connection.cursor()) as cursor:

        vpns = get_vpn_connections(cursor, hour)

        handled_vpns = ztdns_config.get('handled_vpns')
        if handled_vpns:
            logfile.write("Only handling vpns of ids {}\n".format(handled_vpns))
            vpns = [vpn for vpn in vpns if vpn[0] in handled_vpns]
        else:
            # if not specified in the config, all vpns are handled
            handled_vpns = [vpn[0] for vpn in vpns]

        # we need this many subnets
        parallel_vpns = ztdns_config['parallel_vpns']
        subnets = get_available_subnetworks(parallel_vpns,
                                            ztdns_config['private_addresses'],
                                            logfile)

        if not subnets:
            logfile.write("couldn't get ANY /30 subnet of private addresses from"
                          " the 0tdns config file - exiting\n")
            return

        if len(subnets) < parallel_vpns:
            logfile.write("configuration allows running {0} parallel vpn"
                          " connections, but provided private ip addresses give"
                          " only {1} /30 subnets, which limits parallel connections"
                          " to {1}\n".format(parallel_vpns, len(subnets)))
            parallel_vpns = len(subnets)

        for vpn_id, config_hash in vpns:
            config_path = "/var/lib/0tdns/{}.ovpn".format(config_hash)
            if not path.isfile(config_path):
                logfile.write("Syncing config for vpn {} with hash {}\n"
                              .format(vpn_id, config_hash))
                sync_ovpn_config(cursor, vpn_id, config_path, config_hash)

        # map of each wrapper pid to tuple containing id of the vpn it
        # connects to, subnet (represented as tuple of addresses) it uses
        # for veth device and the Popen object of the wrapper
        pids_wrappers = {}

        def wait_for_wrapper_process():
            # block until one of OUR wrappers terminates, log how it went
            # and give its subnet back to the pool
            while True:
                pid, wait_status = waitpid(0, 0)
                # make sure it's one of our wrapper processes
                vpn_id, subnet, _ = pids_wrappers.get(pid, (None, None, None))
                if subnet:
                    break

            # waitpid() gives the encoded wait status (exit code in the high
            # byte, killing signal in the low byte), not the bare exit code
            if WIFEXITED(wait_status):
                exit_code = WEXITSTATUS(wait_status)
            else:
                exit_code = None  # did not exit normally (e.g. killed by a signal)

            if exit_code == 2:
                # this means our perform_queries.py crashed... not good
                logfile.write('performing queries through vpn {} failed\n'
                              .format(vpn_id))
            elif exit_code != 0:
                # vpn server is probably not responding
                logfile.write('connection to vpn {} failed\n'
                              .format(vpn_id))

            pids_wrappers.pop(pid)
            subnets.add(subnet)

        for vpn_id, config_hash in vpns:
            # all slots taken - wait for one to free up (this also returns
            # a subnet to the pool, so the pop() below cannot fail)
            if len(pids_wrappers) == parallel_vpns:
                wait_for_wrapper_process()

            config_path = "/var/lib/0tdns/{}.ovpn".format(config_hash)
            physical_ip = get_default_host_address(ztdns_config['host'])
            subnet = subnets.pop()
            veth_addr1, veth_addr2 = subnet
            route_through_veth = ztdns_config['host'] + "/32"
            command_in_namespace = [perform_queries, hour, str(vpn_id)]
            logfile.write("Running connection for vpn {}\n".format(vpn_id))

            # see into vpn_wrapper.sh for explanation of its arguments
            p = subprocess.Popen([wrapper, config_path, physical_ip, veth_addr1,
                                  veth_addr2, route_through_veth, str(vpn_id)] +
                                 command_in_namespace)

            # we're not actually using the subprocess object anywhere, but we
            # put it in the dict regardless to keep a reference to it -
            # otherwise python would reap the child for us and waitpid(0, 0)
            # would raise '[Errno 10] No child processes' :c
            pids_wrappers[p.pid] = (vpn_id, subnet, p)

        while pids_wrappers:
            wait_for_wrapper_process()

        # every query of a handled vpn that still has no response for this
        # hour gets a failure response recorded
        # NOTE(review): there is no explicit commit here - presumably
        # start_db_connection() enables autocommit; confirm in ztdnslib
        cursor.execute('''
        INSERT INTO user_side_responses(date, result, dns_id, service_id, vpn_id)
        (SELECT TIMESTAMP WITH TIME ZONE %s,
        'internal failure: vpn_connection_failure',
        q.dns_id, q.service_id, q.vpn_id
        FROM user_side_responses AS r RIGHT JOIN user_side_queries AS q
        ON q.service_id = r.service_id AND
        q.dns_id = r.dns_id AND
        q.vpn_id = r.vpn_id AND
        date = %s
        WHERE r.id IS NULL AND q.vpn_id = ANY(%s));
        ''', (hour, hour, handled_vpns))
# script body: take the lock, do the work for the current hour, release
# the lock; everything noteworthy gets appended to the log file
with open("/var/log/0tdns.log", "a") as logfile:
    # current UTC time rounded down to a full hour - this datetime
    # format is one of the formats accepted by postgres
    hour = strftime('%Y-%m-%d %H:00%z', gmtime())
    if lock_on_file():
        try:
            logfile.write("Running for {}\n".format(hour))
            do_hourly_work(hour, logfile)
        finally:
            # release the lock no matter how the work ended
            if not unlock_on_file():
                logfile.write("Can't remove lock - {} already deleted!\n"
                              .format(lockfile))
    else:
        # lockfile is already there - we did not get the lock
        logfile.write("Failed trying to run for {}; {} exists\n"
                      .format(hour, lockfile))