ctdb-scripts: Improve update and listing code
[samba4-gss.git] / python / samba / tests / dns_base.py
blob5fd33ff54dc9b62804d885c2874a374b60a86cb9
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin <kai@samba.org> 2011
3 # Copyright (C) Ralph Boehme <slow@samba.org> 2016
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from samba.tests import TestCaseInTempDir
20 from samba.dcerpc import dns, dnsp
21 from samba import gensec, tests
22 from samba import credentials
23 from samba import NTSTATUSError
24 import struct
25 import samba.ndr as ndr
26 import random
27 import socket
28 import uuid
29 import time
class DNSTest(TestCaseInTempDir):
    """Base class with helpers to build, send and check DNS packets.

    Some helpers read attributes that are not set in this class:
    ``self.creds`` (get_dns_domain) and ``self.server_ip``
    (check_query_txt). These are assumed to be provided by a subclass or
    the test environment.
    """

    def setUp(self):
        super().setUp()
        # Default socket timeout for the transaction helpers; None keeps the
        # sockets blocking. An explicit per-call ``timeout`` overrides it.
        self.timeout = None

    def errstr(self, errcode):
        "Return a readable error code"
        string_codes = [
            "OK",
            "FORMERR",
            "SERVFAIL",
            "NXDOMAIN",
            "NOTIMP",
            "REFUSED",
            "YXDOMAIN",
            "YXRRSET",
            "NXRRSET",
            "NOTAUTH",
            "NOTZONE",
            "0x0B",
            "0x0C",
            "0x0D",
            "0x0E",
            "0x0F",
            # Extended TSIG/TKEY error codes (RFC 2845, RFC 2930, RFC 4635)
            "BADSIG",
            "BADKEY",
            "BADTIME",
            "BADMODE",
            "BADNAME",
            "BADALG",
            "BADTRUNC",
        ]
        if 0 <= errcode < len(string_codes):
            return string_codes[errcode]
        # Unnamed code: format it like the 0x0B..0x0F placeholders instead
        # of letting an IndexError hide the assertion that asked for it.
        return "0x%02X" % errcode

    def assert_rcode_equals(self, rcode, expected):
        "Helper function to check return code"
        self.assertEqual(rcode, expected, "Expected RCODE %s, got %s" %
                         (self.errstr(expected), self.errstr(rcode)))

    def assert_dns_rcode_equals(self, packet, rcode):
        "Helper function to check return code"
        p_errcode = packet.operation & dns.DNS_RCODE
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        "Helper function to check opcode"
        p_opcode = packet.operation & dns.DNS_OPCODE
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def assert_dns_flags_equals(self, packet, flags):
        "Helper function to check flags (operation bits outside OPCODE/RCODE)"
        p_flags = packet.operation & (~(dns.DNS_OPCODE | dns.DNS_RCODE))
        self.assertEqual(p_flags, flags, "Expected FLAGS %02x, got %02x" %
                         (flags, p_flags))

    def assert_echoed_dns_error(self, request, response, response_p, rcode):
        """Assert that ``response`` is ``request`` echoed back as an error.

        The id and opcode must match the request. The flags must be the
        request's flags plus DNS_FLAG_REPLY, and the RCODE must be ``rcode``.
        Every byte after the 4-byte id/operation header of the raw reply
        ``response_p`` must equal the packed request.
        """
        request_p = ndr.ndr_pack(request)

        self.assertEqual(response.id, request.id)
        self.assert_dns_rcode_equals(response, rcode)
        self.assert_dns_opcode_equals(response, request.operation & dns.DNS_OPCODE)
        self.assert_dns_flags_equals(
            response,
            (request.operation | dns.DNS_FLAG_REPLY) & (~(dns.DNS_OPCODE | dns.DNS_RCODE)))
        self.assertEqual(len(response_p), len(request_p))
        self.assertEqual(response_p[4:], request_p[4:])

    def make_name_packet(self, opcode, qid=None):
        """Helper creating a dns.name_packet

        :param opcode: value stored in the packet's ``operation`` field
        :param qid: query ID to use; when None a random ID in
            [0, 0xff00] is chosen
        """
        p = dns.name_packet()
        if qid is None:
            p.id = random.randint(0x0, 0xff00)
        else:
            # Honour an explicitly requested query ID.
            p.id = qid
        p.operation = opcode
        p.questions = []
        p.additional = []
        return p

    def finish_name_packet(self, packet, questions):
        "Helper to finalize a dns.name_packet"
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        "Helper creating a dns.name_question"
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def make_txt_record(self, records):
        "Build a dns.txt_record whose string list holds ``records``"
        rdata_txt = dns.txt_record()
        s_list = dnsp.string_list()
        s_list.count = len(records)
        s_list.str = records
        rdata_txt.txt = s_list
        return rdata_txt

    def get_dns_domain(self):
        "Helper to get dns domain"
        return self.creds.get_realm().lower()

    def dns_transaction_udp(self, packet, host,
                            allow_remaining=False,
                            allow_truncated=False,
                            dump=False, timeout=None):
        """send a DNS query and read the reply

        Sends ``packet`` over UDP to port 53 of ``host`` and reads one
        datagram of up to 2048 bytes.

        :param allow_remaining: let NDR unpacking ignore trailing bytes
        :param allow_truncated: pad the reply with zero bytes (and imply
            allow_remaining) so truncated replies can still be parsed
        :param dump: print hexdumps of the sent and received bytes
        :param timeout: socket timeout; defaults to ``self.timeout``
        :returns: ``(response, recv_packet)``, the unpacked dns.name_packet
            and the raw (unpadded) bytes received
        :raises AssertionError: when NDR packing/unpacking fails
        """
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) as s:
                s.settimeout(timeout)
                s.connect((host, 53))
                s.sendall(send_packet, 0)
                recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            if allow_truncated:
                # with allow_remaining
                # we add some zero bytes
                # in order to also parse truncated
                # responses
                recv_packet_p = recv_packet + 32 * b"\x00"
                allow_remaining = True
            else:
                recv_packet_p = recv_packet
            response = ndr.ndr_unpack(dns.name_packet, recv_packet_p,
                                      allow_remaining=allow_remaining)
            return (response, recv_packet)
        except RuntimeError as err:
            raise AssertionError(err) from err

    def dns_transaction_tcp(self, packet, host,
                            dump=False, timeout=None):
        """send a DNS query and read the reply, also return the raw packet

        The query is sent over TCP to port 53 of ``host`` with the 2-byte
        big-endian length prefix. Bytes are read until the full
        length-prefixed reply has arrived.

        :returns: ``(response, raw)``, the unpacked dns.name_packet and
            the reply bytes without the length prefix
        :raises AssertionError: when NDR packing/unpacking fails, when the
            server closes the connection before the reply is complete, or
            when re-packing the reply does not reproduce the received bytes
        """
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as s:
                s.settimeout(timeout)
                s.connect((host, 53))
                s.sendall(struct.pack('!H', len(send_packet)) + send_packet)

                recv_packet = b''
                while True:
                    if len(recv_packet) >= 2:
                        (length,) = struct.unpack('!H', recv_packet[0:2])
                        remaining = 2 + length - len(recv_packet)
                    else:
                        # Until the length prefix is known, ask for it plus
                        # a full 12-byte DNS header.
                        remaining = 2 + 12 - len(recv_packet)
                    if remaining <= 0:
                        break
                    chunk = s.recv(remaining, 0)
                    if not chunk:
                        # recv() returns b'' once the peer has closed.
                        raise AssertionError(
                            "DNS/TCP connection closed after %d bytes "
                            "of an incomplete reply" % len(recv_packet))
                    recv_packet += chunk
            if dump:
                print(self.hexdump(recv_packet))
            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
        except RuntimeError as err:
            raise AssertionError(err) from err

        # unpacking and packing again should produce same bytestream
        my_packet = ndr.ndr_pack(response)
        self.assertEqual(my_packet, recv_packet[2:])
        return (response, recv_packet[2:])

    def make_txt_update(self, prefix, txt_array, zone=None, ttl=900):
        """Build an UPDATE packet adding TXT ``prefix.<zone>``.

        The zone section is an SOA question for ``zone``, which defaults
        to get_dns_domain(). The update section holds one IN TXT record
        with the strings in ``txt_array`` and the given ``ttl``.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)

        name = zone or self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [u])

        r = dns.res_rec()
        r.name = "%s.%s" % (prefix, name)
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = ttl
        r.length = 0xffff
        r.rdata = self.make_txt_record(txt_array)
        updates = [r]
        p.nscount = len(updates)
        p.nsrecs = updates

        return p

    def check_query_txt(self, prefix, txt_array, zone=None):
        """Query TXT ``prefix.<zone>`` over UDP and check the answer.

        Asserts an OK RCODE and exactly one answer whose strings equal
        ``txt_array``. ``zone`` defaults to get_dns_domain().
        """
        name = "%s.%s" % (prefix, zone or self.get_dns_domain())
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        (response, _) = self.dns_transaction_udp(p, host=self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt.str, txt_array)
class DNSTKeyTest(DNSTest):
    """DNSTest with helpers for TKEY negotiation and TSIG signing/checking.

    ``self.server`` and ``self.server_ip`` are read but not set here.
    They are assumed to be provided by a subclass or the test environment.
    """

    def setUp(self):
        super().setUp()
        self.settings = {}
        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
        self.settings["target_hostname"] = self.server

        # Credentials from the USERNAME/PASSWORD environment variables,
        # restricted to Kerberos.
        self.creds = credentials.Credentials()
        self.creds.guess(self.lp_ctx)
        self.creds.set_username(tests.env_get_var_value('USERNAME'))
        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)

        # Built lazily by get_unpriv_creds()
        self.unpriv_creds = None

        # Owner name of the TXT record built by make_update_request()
        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()

    def get_unpriv_creds(self):
        """Return credentials for the unprivileged test user.

        They are built from USERNAME_UNPRIV/PASSWORD_UNPRIV on first use
        and cached in ``self.unpriv_creds``.
        """
        if self.unpriv_creds is not None:
            return self.unpriv_creds

        self.unpriv_creds = credentials.Credentials()
        self.unpriv_creds.guess(self.lp_ctx)
        self.unpriv_creds.set_username(tests.env_get_var_value('USERNAME_UNPRIV'))
        self.unpriv_creds.set_password(tests.env_get_var_value('PASSWORD_UNPRIV'))
        self.unpriv_creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)

        return self.unpriv_creds

    def tkey_trans(self, creds=None, algorithm_name="gss-tsig",
                   tkey_req_in_answers=False,
                   expected_rcode=dns.DNS_RCODE_OK):
        """Do a TKEY transaction and establish a gensec context

        :param creds: credentials to authenticate with (default self.creds)
        :param algorithm_name: TKEY algorithm name sent to the server
        :param tkey_req_in_answers: put the TKEY request record in the
            answer section instead of the additional section
        :param expected_rcode: when not DNS_RCODE_OK, assert that the server
            echoed the request back with this RCODE and return early
            without setting ``self.tkey``

        On success the key name, credentials, mech, algorithm and gensec
        context are stored in ``self.tkey``, and the TSIG on the server's
        reply is verified.
        """
        if creds is None:
            creds = self.creds

        tkey = {}
        tkey['name'] = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())
        tkey['creds'] = creds
        tkey['mech'] = 'spnego'
        tkey['algorithm'] = algorithm_name

        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(tkey['name'],
                                    dns.DNS_QTYPE_TKEY,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        r = dns.res_rec()
        r.name = tkey['name']
        r.rr_type = dns.DNS_QTYPE_TKEY
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 0
        r.length = 0xffff
        rdata = dns.tkey_record()
        rdata.algorithm = algorithm_name
        # Read the clock once so the key lifetime is exactly one hour.
        now = int(time.time())
        rdata.inception = now
        rdata.expiration = now + 60 * 60
        rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
        rdata.error = 0
        rdata.other_size = 0

        tkey['gensec'] = gensec.Security.start_client(self.settings)
        tkey['gensec'].set_credentials(creds)
        tkey['gensec'].set_target_service("dns")
        tkey['gensec'].set_target_hostname(self.server)
        tkey['gensec'].want_feature(gensec.FEATURE_SIGN)
        tkey['gensec'].start_mech_by_name(tkey['mech'])

        # First gensec leg: produce the initial client token, which is
        # carried to the server as the TKEY key data.
        (finished, client_token) = tkey['gensec'].update(b"")
        self.assertFalse(finished)

        data = list(client_token)
        rdata.key_data = data
        rdata.key_size = len(data)
        r.rdata = rdata

        additional = [r]
        if tkey_req_in_answers:
            p.ancount = 1
            p.answers = additional
        else:
            p.arcount = 1
            p.additional = additional

        (response, response_packet) = \
            self.dns_transaction_tcp(p, self.server_ip)
        if expected_rcode != dns.DNS_RCODE_OK:
            self.assert_echoed_dns_error(p, response, response_packet, expected_rcode)
            return
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        # Fail as an assertion, not an IndexError, if no answer came back.
        self.assertGreaterEqual(response.ancount, 1)

        # Second gensec leg: feed the server's token back in; the exchange
        # must complete here.
        tkey_record = response.answers[0].rdata
        server_token = bytes(tkey_record.key_data)
        (finished, _) = tkey['gensec'].update(server_token)
        self.assertTrue(finished)

        self.tkey = tkey

        self.verify_packet(response, response_packet)

    def verify_packet(self, response, response_packet, request_mac=b""):
        """Verify the single TSIG record that signs ``response``.

        The MAC is checked with the gensec context in ``self.tkey`` over
        the following data, in order:
        [length-prefixed request MAC, gss-tsig only] + response packed
        without its additional records + the TSIG variables.

        :param request_mac: MAC of the request this response answers
            (empty when there is none)
        :raises AssertionError: when gensec rejects the MAC
        """
        self.assertEqual(response.arcount, 1)
        self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)

        gss_tsig = self.tkey['algorithm'] == "gss-tsig"

        # For gss-tsig a non-empty request MAC is preceded by its
        # 16-bit big-endian length.
        request_mac_len = b""
        if len(request_mac) > 0 and gss_tsig:
            request_mac_len = struct.pack('!H', len(request_mac))

        tsig_record = response.additional[0].rdata
        mac = bytes(tsig_record.mac)

        self.assertEqual(tsig_record.original_id, response.id)
        self.assertEqual(tsig_record.mac_size, len(mac))

        # Cut off tsig record from dns response packet for MAC verification
        # and reset additional record count.
        response_copy = ndr.ndr_deepcopy(response)
        response_copy.arcount = 0
        response_packet_wo_tsig = ndr.ndr_pack(response_copy)

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = self.tkey['name']
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = tsig_record.time_prefix
        fake_tsig.time = tsig_record.time
        fake_tsig.algorithm_name = tsig_record.algorithm_name
        fake_tsig.fudge = tsig_record.fudge
        fake_tsig.error = tsig_record.error
        fake_tsig.other_size = tsig_record.other_size
        fake_tsig.other_data = tsig_record.other_data
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = request_mac_len + request_mac + response_packet_wo_tsig + fake_tsig_packet
        try:
            self.tkey['gensec'].check_packet(data, data, mac)
        except NTSTATUSError as nt:
            raise AssertionError(nt) from nt

    def sign_packet(self, packet, key_name,
                    algorithm_name="gss-tsig",
                    bad_sig=False):
        """Sign a packet, calculate a MAC and add TSIG record

        The MAC is computed by ``self.tkey['gensec']`` over the packed
        packet followed by the TSIG variables. The packet's additional
        section is replaced with a single TSIG record carrying that MAC.

        :param bad_sig: corrupt the tail of the MAC placed in the TSIG
            record (see below)
        :returns: the MAC as computed by gensec, before any corruption
        """
        packet_data = ndr.ndr_pack(packet)

        time_prefix = 0
        fudge = 300
        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = time_prefix
        fake_tsig.time = int(time.time())
        fake_tsig.algorithm_name = algorithm_name
        fake_tsig.fudge = fudge
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = packet_data + fake_tsig_packet
        mac = self.tkey['gensec'].sign_packet(data, data)
        mac_list = list(mac)
        if bad_sig:
            # Corrupt the last eight bytes: bit-flip the first and last of
            # them and overwrite the six in between with b"badmac". Each
            # position is touched only if the MAC is long enough to have it
            # (index -k needs at least k bytes).
            mac_len = len(mac_list)
            if mac_len >= 8:
                mac_list[-8] ^= 0xff
            for offset, value in zip(range(-7, -1), b"badmac"):
                if mac_len >= -offset:
                    mac_list[offset] = value
            if mac_len >= 1:
                mac_list[-1] ^= 0xff

        rdata = dns.tsig_record()
        rdata.algorithm_name = algorithm_name
        rdata.time_prefix = time_prefix
        rdata.time = fake_tsig.time
        rdata.fudge = fudge
        rdata.original_id = packet.id
        rdata.error = 0
        rdata.other_size = 0
        rdata.mac = mac_list
        rdata.mac_size = len(mac_list)

        r = dns.res_rec()
        r.name = key_name
        r.rr_type = dns.DNS_QTYPE_TSIG
        r.rr_class = dns.DNS_QCLASS_ANY
        r.ttl = 0
        r.length = 0xffff
        r.rdata = rdata

        packet.additional = [r]
        packet.arcount = 1

        return mac

    def bad_sign_packet(self, packet, key_name):
        """Add bad signature for a packet by
        bitflipping and hardcoding bytes at the end of the MAC"""

        return self.sign_packet(packet, key_name, bad_sig=True)

    def search_record(self, name):
        "Send a TXT query for ``name`` over UDP and return the reply's RCODE"
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        (response, _) = self.dns_transaction_udp(p, self.server_ip)
        return response.operation & dns.DNS_RCODE

    def make_update_request(self, delete=False):
        """Create a DNS update request

        The zone section is an SOA question for get_dns_domain(). The update
        section holds one TXT record for ``self.newrecname``: class IN with
        TTL 900 to add it, or class NONE with TTL 0 when ``delete`` is set
        (the RFC 2136 form for deleting a specific RR).
        """
        rr_class = dns.DNS_QCLASS_IN
        ttl = 900

        if delete:
            rr_class = dns.DNS_QCLASS_NONE
            ttl = 0

        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        q = self.make_name_question(self.get_dns_domain(),
                                    dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        r = dns.res_rec()
        r.name = self.newrecname
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = 0xffff
        r.rdata = self.make_txt_record(['"This is a test"'])
        updates = [r]
        p.nscount = len(updates)
        p.nsrecs = updates

        return p