ctdb-scripts: Improve update and listing code
[samba4-gss.git] / selftest / target / dns_hub.py
blob5f8d52463035137dbef72b01e3a08d7dd19cf215
1 #!/usr/bin/env python3
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Volker Lendecke 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 # Used by selftest to proxy DNS queries to the correct testenv DC.
20 # See selftest/target/README for more details.
21 # Based on the EchoServer example from python docs
23 import threading
24 import sys
25 import select
26 import socket
27 import collections
28 import time
29 from samba.dcerpc import dns
30 import samba.ndr as ndr
32 import socketserver
# Die at once on control-C (SIGINT) instead of raising KeyboardInterrupt.
# The serving threads started by main() are ordinary (non-daemon) threads,
# so an exception in the main thread alone would leave the process waiting
# on them; the default signal disposition simply terminates the process.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)

# Short alias for socketserver, used by the handler and server classes.
sserver = socketserver

# Seconds to wait on a forwarded DNS query (socket timeout); also the basis
# for the "slow request" debug threshold in DnsHandler.handle().
DNS_REQUEST_TIMEOUT = 10
class DnsHandler(sserver.BaseRequestHandler):
    """Handle one DNS query received by a UDP server.

    Depending on the queried name, the query is either answered locally
    (no reply, SERVFAIL or NXDOMAIN, for names the tests use on purpose)
    or forwarded to the address mapped to the matching realm in
    ``self.server.realm_to_ip_mappings``, and the reply is relayed back.
    """

    # Reverse lookup tables (numeric value -> constant name) built from the
    # DNS_QTYPE_* / DNS_RCODE_* constants of samba.dcerpc.dns; only used to
    # make log messages readable.
    dns_qtype_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_QTYPE_')}
    dns_rcode_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_RCODE_')}

    # Largest possible UDP payload: recv() on a datagram socket discards
    # whatever does not fit the buffer, which would corrupt a large reply.
    _MAX_UDP_PACKET = 65535

    def dns_qtype_string(self, qtype):
        """Return a readable name for *qtype*.

        Unknown codes yield a placeholder instead of raising, because this
        is called while logging, before the reply has been sent.
        """
        return self.dns_qtype_strings.get(qtype, "UNKNOWN_QTYPE(%s)" % qtype)

    def dns_rcode_string(self, rcode):
        """Return a readable name for *rcode* (placeholder if unknown)."""
        return self.dns_rcode_strings.get(rcode, "UNKNOWN_RCODE(%s)" % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send *packet* to numeric address *host*, port 53/UDP; return the reply.

        :param packet: dns.name_packet to send.
        :param host: numeric IP address string of the server to query.
        :returns: the reply unpacked as a dns.name_packet.
        :raises OSError: on socket errors, including socket.timeout after
            DNS_REQUEST_TIMEOUT seconds; the error is logged, then re-raised.
        """
        flags = (socket.AddressInfo.AI_NUMERICHOST |
                 socket.AddressInfo.AI_NUMERICSERV |
                 socket.AddressInfo.AI_PASSIVE)
        addr_info = socket.getaddrinfo(host, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, socktype, _, _, sockaddr = addr_info[0]

        send_packet = ndr.ndr_pack(packet)
        try:
            with socket.socket(family, socktype, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                s.connect(sockaddr)
                s.sendall(send_packet, 0)
                recv_packet = s.recv(self._MAX_UDP_PACKET, 0)
        except socket.error as err:
            # Log the exception itself: err.errno is None for timeouts.
            print("dns_hub: Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def get_pdc_ipv4_addr(self, lookup_name):
        """Map *lookup_name* to the address of the testenv DC for its realm.

        Returns the address mapped to the longest realm that *lookup_name*
        ends with, or None when no realm matches.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings

        # Longest realm first, so a sub-domain realm wins over its parent.
        for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
            if lookup_name.endswith(realm):
                return realm_to_ip_mappings[realm]
        return None

    def forwarder(self, name):
        """Decide how a query for *name* is answered (case-insensitive).

        Returns 'ignore' (send no reply), 'fail' (reply SERVFAIL),
        'torture' (reply NXDOMAIN), the address mapped to the matching
        realm (forward the query), or None when no realm matches (reply
        NXDOMAIN).
        """
        lname = name.lower()

        # Special cases used by tests (e.g. dns_forwarder.py).
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        # Ignore the last two characters to catch torture100, torture101, ...
        if lname.endswith("torture1", 0, len(lname) - 2):
            return 'torture'
        if lname.endswith(('_none_.example.com',
                           'torturedom.samba.example.com')):
            return 'torture'

        # Otherwise use the testenv PDC matching the requested realm.
        return self.get_pdc_ipv4_addr(lname)

    @staticmethod
    def _local_reply(packet, rcode):
        """Turn query *packet* into a reply carrying *rcode* (in place)."""
        packet.operation |= dns.DNS_FLAG_REPLY
        packet.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        packet.operation |= rcode
        return packet

    def handle(self):
        """Answer one UDP query, replying locally or relaying the DC's reply."""
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            # Deliberately send no reply at all.
            return
        elif forwarder == 'fail':
            # Leave response as None: a SERVFAIL reply is built below.
            pass
        elif forwarder in ['torture', None]:
            response = self._local_reply(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except OSError as err:
                print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
                      (forwarder, name, err))

        if response is None:
            response = self._local_reply(query, dns.DNS_RCODE_SERVFAIL)

        send_packet = ndr.ndr_pack(response)

        tdiff = time.monotonic() - start
        # Only log requests that took more than a fifth of the timeout.
        if tdiff > (DNS_REQUEST_TIMEOUT / 5):
            errcode = response.operation & dns.DNS_RCODE
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
class server_thread(threading.Thread):
    """Thread that runs one socketserver's serve_forever() loop.

    The thread is named after the constructor's *name* argument, which is
    also used as the tag in the progress messages it prints.
    """

    def __init__(self, server, name):
        super().__init__(name=name)
        # The socketserver instance driven by this thread.
        self.server = server

    def run(self):
        """Serve requests until the server is shut down."""
        label = self.name
        print("dns_hub[%s]: before serve_forever()" % label)
        self.server.serve_forever()
        print("dns_hub[%s]: after serve_forever()" % label)

    def stop(self):
        """Shut the server down, then close its socket.

        Per the socketserver documentation, shutdown() waits for
        serve_forever() to return, so call this from another thread.
        """
        srv = self.server
        print("dns_hub[%s]: before shutdown()" % self.name)
        srv.shutdown()
        print("dns_hub[%s]: after shutdown()" % self.name)
        srv.server_close()
class UDPV4Server(sserver.UDPServer):
    """UDP server listening on an IPv4 (AF_INET) address."""
    address_family = socket.AddressFamily.AF_INET
class UDPV6Server(sserver.UDPServer):
    """UDP server listening on an IPv6 (AF_INET6) address."""
    address_family = socket.AddressFamily.AF_INET6
def main():
    """Run the DNS hub until stdin has input or is closed, or TIMEOUT expires.

    Command line::

        dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]

    TIMEOUT is in seconds. Each LISTENADDRESS is a numeric IP address; a UDP
    server on port 53 is started for each one in its own thread. Each
    MAPPING is REALM=IP and tells DnsHandler where to forward queries for
    names in REALM. Exits with status 1 on bad arguments, 0 otherwise.
    """
    if len(sys.argv) < 4:
        print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
        sys.exit(1)

    # poll() takes milliseconds; a 32-bit int can't hold more than this.
    timeout = min(int(sys.argv[1]) * 1000, 2**31 - 1)
    # The listen addresses come in as a comma-separated string.
    listenaddresses = sys.argv[2].split(',')
    # The realm-to-IP mappings come in as a comma-separated REALM=IP list;
    # convert them back into the dictionary DnsHandler uses.
    realm_mappings = collections.OrderedDict()
    for mapping in sys.argv[3].split(','):
        realm, sep, ip = mapping.partition('=')
        if not sep:
            print("dns_hub: invalid MAPPING '%s', expected REALM=IP" % mapping)
            sys.exit(1)
        realm_mappings[realm] = ip

    def prepare_server_thread(listenaddress, realm_mappings):
        """Create (but do not start) a UDP server thread for *listenaddress*."""
        flags = (socket.AddressInfo.AI_NUMERICHOST |
                 socket.AddressInfo.AI_NUMERICSERV |
                 socket.AddressInfo.AI_PASSIVE)
        addr_info = socket.getaddrinfo(listenaddress, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, _, _, _, sockaddr = addr_info[0]
        if family == socket.AddressFamily.AF_INET6:
            server_class = UDPV6Server
        else:
            server_class = UDPV4Server
        server = server_class(sockaddr, DnsHandler)

        # DnsHandler reads the forwarding targets from the server object.
        server.realm_to_ip_mappings = realm_mappings
        return server_thread(server, name="UDP[%s]" % listenaddress)

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in realm_mappings.items():
        print(" {0} ==> {1}".format(realm, ip))

    print("dns_hub will listen on the following UDP addresses:")
    threads = []
    for listenaddress in listenaddresses:
        print(" %s" % listenaddress)
        threads.append(prepare_server_thread(listenaddress, realm_mappings))

    for t in threads:
        t.start()

    # Block until stdin becomes readable (or hangs up) or the timeout hits.
    p = select.poll()
    p.register(sys.stdin.fileno(), select.POLLIN)
    p.poll(timeout)
    print("dns_hub: after poll()")

    for t in threads:
        t.stop()
    for t in threads:
        t.join()
    print("dns_hub: before exit()")
    sys.exit(0)
# Script entry point: start the hub as soon as this file is executed.
main()