# Unix SMB/CIFS implementation.
# Copyright (C) Kai Blin <kai@samba.org> 2011
# Copyright (C) Ralph Boehme <slow@samba.org> 2016
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import random
import socket
import struct
import time
import uuid

import samba.ndr as ndr
from samba import NTSTATUSError
from samba import credentials
from samba import gensec, tests
from samba.dcerpc import dns, dnsp
from samba.tests import TestCaseInTempDir
class DNSTest(TestCaseInTempDir):
    """Helpers for building DNS packets, sending them over UDP or TCP and
    asserting on the replies.

    NOTE(review): ``self.creds`` (read by get_dns_domain) and
    ``self.server_ip`` (read by check_query_txt) are not assigned in this
    class; presumably subclasses provide them -- confirm.
    """

    # Readable names for RCODE values, indexed by the numeric code.
    _RCODE_NAMES = (
        "OK",
        "FORMERR",
        "SERVFAIL",
        "NXDOMAIN",
        "NOTIMP",
        "REFUSED",
        "YXDOMAIN",
        "YXRRSET",
        "NXRRSET",
        "NOTAUTH",
        "NOTZONE",
        "0x0B",
        "0x0C",
        "0x0D",
        "0x0E",
        "0x0F",
        "BADSIG",
        "BADKEY",
    )

    def setUp(self):
        super(DNSTest, self).setUp()
        # Default socket timeout for the transaction helpers; None passed to
        # socket.settimeout() selects blocking mode.
        self.timeout = None

    def errstr(self, errcode):
        """Return a readable error code.

        Codes without a name in the table are rendered as hex (e.g. "0x12")
        instead of raising: this helper is used to format assertion
        messages, and an IndexError there would hide the real failure.
        """
        if 0 <= errcode < len(self._RCODE_NAMES):
            return self._RCODE_NAMES[errcode]
        return "0x%02X" % errcode

    def assert_rcode_equals(self, rcode, expected):
        "Helper function to check return code"
        self.assertEqual(rcode, expected, "Expected RCODE %s, got %s" %
                         (self.errstr(expected), self.errstr(rcode)))

    def assert_dns_rcode_equals(self, packet, rcode):
        "Helper function to check the RCODE bits of packet.operation"
        p_errcode = packet.operation & dns.DNS_RCODE
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        "Helper function to check the OPCODE bits of packet.operation"
        p_opcode = packet.operation & dns.DNS_OPCODE
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def assert_dns_flags_equals(self, packet, flags):
        "Helper function to check the flag bits (all but OPCODE and RCODE)"
        p_flags = packet.operation & (~(dns.DNS_OPCODE | dns.DNS_RCODE))
        self.assertEqual(p_flags, flags, "Expected FLAGS %02x, got %02x" %
                         (flags, p_flags))

    def assert_echoed_dns_error(self, request, response, response_p, rcode):
        """Assert that ``response`` is ``request`` echoed back as an error.

        Checks the same id, the request's opcode, the given ``rcode``, the
        request's flags plus DNS_FLAG_REPLY, and that the raw reply
        ``response_p`` has the request's length and identical bytes after
        the first four (which hold the fields checked above).
        """
        request_p = ndr.ndr_pack(request)

        self.assertEqual(response.id, request.id)
        self.assert_dns_rcode_equals(response, rcode)
        self.assert_dns_opcode_equals(response,
                                      request.operation & dns.DNS_OPCODE)
        self.assert_dns_flags_equals(
            response,
            (request.operation | dns.DNS_FLAG_REPLY) &
            (~(dns.DNS_OPCODE | dns.DNS_RCODE)))
        self.assertEqual(len(response_p), len(request_p))
        self.assertEqual(response_p[4:], request_p[4:])

    def make_name_packet(self, opcode, qid=None):
        """Helper creating a dns.name_packet.

        :param opcode: value stored in the packet's ``operation`` field.
        :param qid: query id to use; a random id is chosen when None.
        """
        p = dns.name_packet()
        if qid is None:
            qid = random.randint(0x0, 0xff00)
        p.id = qid
        p.operation = opcode
        p.questions = []
        p.additional = []
        return p

    def finish_name_packet(self, packet, questions):
        "Helper to finalize a dns.name_packet"
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        "Helper creating a dns.name_question"
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def make_txt_record(self, records):
        "Helper creating a dns.txt_record holding the strings in records"
        rdata_txt = dns.txt_record()
        s_list = dnsp.string_list()
        s_list.count = len(records)
        s_list.str = records
        rdata_txt.txt = s_list
        return rdata_txt

    def get_dns_domain(self):
        "Helper to get dns domain (the lower-cased realm of self.creds)"
        return self.creds.get_realm().lower()

    def dns_transaction_udp(self, packet, host,
                            allow_remaining=False,
                            allow_truncated=False,
                            dump=False, timeout=None):
        """Send a DNS query over UDP (port 53) and read the reply.

        :param allow_remaining: passed to ndr_unpack; tolerate trailing data.
        :param allow_truncated: pad the reply with zero bytes (and force
            allow_remaining) so a truncated reply can still be parsed.
        :param dump: print hexdumps of the sent and received packets.
        :param timeout: socket timeout; defaults to ``self.timeout``.
        :returns: ``(response, raw_reply_bytes)``.
        :raises AssertionError: when NDR packing/unpacking raises
            RuntimeError.
        """
        if timeout is None:
            timeout = self.timeout
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            s.sendall(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            if allow_truncated:
                # with allow_remaining we add some zero bytes in order to
                # also parse truncated responses
                recv_packet_p = recv_packet + 32 * b"\x00"
                allow_remaining = True
            else:
                recv_packet_p = recv_packet
            response = ndr.ndr_unpack(dns.name_packet, recv_packet_p,
                                      allow_remaining=allow_remaining)
            return (response, recv_packet)
        except RuntimeError as err:
            raise AssertionError(err) from err
        finally:
            # Single close point for both the success and error paths.
            if s is not None:
                s.close()

    def dns_transaction_tcp(self, packet, host,
                            dump=False, timeout=None):
        """Send a DNS query over TCP (port 53) and read the reply.

        :param dump: print hexdumps of the sent and received packets.
        :param timeout: socket timeout; defaults to ``self.timeout``.
        :returns: ``(response, raw_reply_bytes)`` where the raw bytes
            exclude the 2-byte length prefix.
        :raises AssertionError: when NDR packing/unpacking raises
            RuntimeError, or the server closes the connection before a
            complete reply has arrived.
        """
        if timeout is None:
            timeout = self.timeout
        s = None
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            # The message is sent prefixed by its length as a big-endian
            # 16-bit integer.
            tcp_packet = struct.pack('!H', len(send_packet))
            tcp_packet += send_packet
            s.sendall(tcp_packet)

            recv_packet = b''
            while True:
                if len(recv_packet) >= 2:
                    length, = struct.unpack('!H', recv_packet[0:2])
                    remaining = 2 + length
                else:
                    # Length prefix not read yet: ask for up to the maximum.
                    remaining = 2 + 0xffff
                remaining -= len(recv_packet)
                if remaining <= 0:
                    break
                chunk = s.recv(remaining, 0)
                if not chunk:
                    # recv() returns b"" once the peer has closed the
                    # connection; looping again would never make progress.
                    raise AssertionError(
                        "DNS server closed TCP connection after %d bytes"
                        % len(recv_packet))
                recv_packet += chunk
            if dump:
                print(self.hexdump(recv_packet))
            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
        except RuntimeError as err:
            raise AssertionError(err) from err
        finally:
            if s is not None:
                s.close()

        # unpacking and packing again should produce same bytestream
        my_packet = ndr.ndr_pack(response)
        self.assertEqual(my_packet, recv_packet[2:])
        return (response, recv_packet[2:])

    def make_txt_update(self, prefix, txt_array, zone=None, ttl=900):
        """Build an UPDATE packet carrying one IN TXT record.

        The zone section names ``zone`` (default: get_dns_domain()) with an
        SOA/IN question; the update section holds the record
        ``prefix.zone`` with ``txt_array`` as its strings and TTL ``ttl``.
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)

        name = zone or self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [u])

        r = dns.res_rec()
        r.name = "%s.%s" % (prefix, name)
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = ttl
        r.length = 0xffff
        r.rdata = self.make_txt_record(txt_array)
        updates = [r]
        p.nscount = len(updates)
        p.nsrecs = updates

        return p

    def check_query_txt(self, prefix, txt_array, zone=None):
        """Query TXT for ``prefix.zone`` over UDP and assert the reply has
        RCODE OK and exactly one answer whose strings equal ``txt_array``."""
        name = "%s.%s" % (prefix, zone or self.get_dns_domain())
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        (response, response_packet) =\
            self.dns_transaction_udp(p, host=self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt.str, txt_array)
class DNSTKeyTest(DNSTest):
    """DNS tests that establish a gensec context through a TKEY exchange
    and use it to sign and verify packets with TSIG records.

    NOTE(review): ``self.server`` and ``self.server_ip`` are read here but
    not assigned in this class; presumably a subclass or the test
    environment provides them -- confirm.
    """

    def setUp(self):
        super(DNSTKeyTest, self).setUp()
        self.settings = {}
        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
        self.settings["target_hostname"] = self.server

        # Credentials of the (privileged) test user; Kerberos is required.
        self.creds = credentials.Credentials()
        self.creds.guess(self.lp_ctx)
        self.creds.set_username(tests.env_get_var_value('USERNAME'))
        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)

        # Created lazily by get_unpriv_creds().
        self.unpriv_creds = None

        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()

    def get_unpriv_creds(self):
        """Return credentials for the unprivileged test user.

        Built on first use from USERNAME_UNPRIV / PASSWORD_UNPRIV and cached
        on ``self.unpriv_creds``. The cache is only filled once the object
        is fully configured.
        """
        if self.unpriv_creds is None:
            creds = credentials.Credentials()
            creds.guess(self.lp_ctx)
            creds.set_username(tests.env_get_var_value('USERNAME_UNPRIV'))
            creds.set_password(tests.env_get_var_value('PASSWORD_UNPRIV'))
            creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
            self.unpriv_creds = creds
        return self.unpriv_creds

    def tkey_trans(self, creds=None, algorithm_name="gss-tsig",
                   tkey_req_in_answers=False,
                   expected_rcode=dns.DNS_RCODE_OK):
        """Do a TKEY transaction and establish a gensec context.

        A GSSAPI-mode TKEY record is sent over TCP. It goes in the
        additional section, or in the answer section when
        ``tkey_req_in_answers`` is true.

        If ``expected_rcode`` is not DNS_RCODE_OK, the method asserts that
        the server echoed the request back with that rcode and returns
        without storing a context. Otherwise it finishes the gensec
        exchange, stores the context data in ``self.tkey`` and verifies the
        signed reply with verify_packet().

        :param creds: credentials to authenticate with (``self.creds`` if
            None).
        :param algorithm_name: TKEY algorithm name to request.
        """
        if creds is None:
            creds = self.creds

        mech = 'spnego'

        tkey = {}
        tkey['name'] = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())
        tkey['creds'] = creds
        tkey['mech'] = mech
        tkey['algorithm'] = algorithm_name

        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(tkey['name'],
                                    dns.DNS_QTYPE_TKEY,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        r = dns.res_rec()
        r.name = tkey['name']
        r.rr_type = dns.DNS_QTYPE_TKEY
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 0
        r.length = 0xffff
        rdata = dns.tkey_record()
        rdata.algorithm = algorithm_name
        # Read the clock once so expiration is exactly one hour after
        # inception (two time() calls could straddle a second boundary).
        now = int(time.time())
        rdata.inception = now
        rdata.expiration = now + 60 * 60
        rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
        rdata.error = 0
        rdata.other_size = 0

        tkey['gensec'] = gensec.Security.start_client(self.settings)
        tkey['gensec'].set_credentials(creds)
        tkey['gensec'].set_target_service("dns")
        tkey['gensec'].set_target_hostname(self.server)
        tkey['gensec'].want_feature(gensec.FEATURE_SIGN)
        tkey['gensec'].start_mech_by_name(tkey['mech'])

        # First client leg: start from an empty input token; the exchange
        # must not be complete yet.
        client_to_server = b""
        (finished, server_to_client) = tkey['gensec'].update(client_to_server)
        self.assertFalse(finished)

        data = list(server_to_client)
        rdata.key_data = data
        rdata.key_size = len(data)
        r.rdata = rdata

        additional = [r]
        if tkey_req_in_answers:
            p.ancount = 1
            p.answers = additional
        else:
            p.arcount = 1
            p.additional = additional

        (response, response_packet) =\
            self.dns_transaction_tcp(p, self.server_ip)
        if expected_rcode != dns.DNS_RCODE_OK:
            self.assert_echoed_dns_error(p, response, response_packet,
                                         expected_rcode)
            return
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        # Feed the server's token back in; this must finish the exchange.
        tkey_record = response.answers[0].rdata
        server_to_client = bytes(tkey_record.key_data)
        (finished, client_to_server) = tkey['gensec'].update(server_to_client)
        self.assertTrue(finished)

        self.tkey = tkey

        self.verify_packet(response, response_packet)

    def verify_packet(self, response, response_packet, request_mac=b""):
        """Verify the MAC of the single TSIG record in ``response``.

        The signed data is rebuilt from these parts and checked with the
        ``self.tkey`` gensec context:

        * the optional request MAC,
        * the response packed without its TSIG record (arcount forced to
          0),
        * a dns.fake_tsig_rec holding the TSIG record's fields.

        :param request_mac: MAC of the request this response answers. When
            it is non-empty and the algorithm is "gss-tsig", it is preceded
            by its length as a big-endian 16-bit integer.
        :raises AssertionError: when gensec rejects the MAC (NTSTATUSError).
        """
        self.assertEqual(response.arcount, 1)
        self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)

        gss_tsig = self.tkey['algorithm'] == "gss-tsig"

        request_mac_len = b""
        if len(request_mac) > 0 and gss_tsig:
            request_mac_len = struct.pack('!H', len(request_mac))

        tsig_record = response.additional[0].rdata
        mac = bytes(tsig_record.mac)

        self.assertEqual(tsig_record.original_id, response.id)
        self.assertEqual(tsig_record.mac_size, len(mac))

        # Cut off tsig record from dns response packet for MAC verification
        # and reset additional record count.
        response_copy = ndr.ndr_deepcopy(response)
        response_copy.arcount = 0
        response_packet_wo_tsig = ndr.ndr_pack(response_copy)

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = self.tkey['name']
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = tsig_record.time_prefix
        fake_tsig.time = tsig_record.time
        fake_tsig.algorithm_name = tsig_record.algorithm_name
        fake_tsig.fudge = tsig_record.fudge
        fake_tsig.error = tsig_record.error
        fake_tsig.other_size = tsig_record.other_size
        fake_tsig.other_data = tsig_record.other_data
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = (request_mac_len + request_mac + response_packet_wo_tsig +
                fake_tsig_packet)
        try:
            self.tkey['gensec'].check_packet(data, data, mac)
        except NTSTATUSError as nt:
            raise AssertionError(nt) from nt

    def sign_packet(self, packet, key_name,
                    algorithm_name="gss-tsig",
                    bad_sig=False):
        """Sign a packet, calculate a MAC and add TSIG record.

        The MAC is computed with ``self.tkey['gensec']`` over the packed
        packet followed by a packed dns.fake_tsig_rec. A TSIG record
        carrying that MAC replaces ``packet.additional`` (arcount = 1).

        :param bad_sig: corrupt the MAC placed in the TSIG record.
        :returns: the MAC as returned by gensec, before any corruption.
        """
        packet_data = ndr.ndr_pack(packet)

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = 0
        fake_tsig.time = int(time.time())
        fake_tsig.algorithm_name = algorithm_name
        fake_tsig.fudge = 300
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = packet_data + fake_tsig_packet
        mac = self.tkey['gensec'].sign_packet(data, data)
        mac_list = [x if isinstance(x, int) else ord(x) for x in list(mac)]
        if bad_sig:
            # Flip the 8th-from-last and the last byte and overwrite the six
            # bytes between them with "badmac". A position N from the end is
            # only touched when the MAC is longer than N bytes.
            if len(mac_list) > 8:
                mac_list[-8] ^= 0xff
            for offset, char in zip(range(7, 1, -1), "badmac"):
                if len(mac_list) > offset:
                    mac_list[-offset] = ord(char)
            if len(mac_list) > 1:
                mac_list[-1] ^= 0xff

        rdata = dns.tsig_record()
        rdata.algorithm_name = algorithm_name
        rdata.time_prefix = 0
        rdata.time = fake_tsig.time
        rdata.fudge = 300
        rdata.original_id = packet.id
        rdata.error = 0
        rdata.other_size = 0
        rdata.mac = mac_list
        rdata.mac_size = len(mac_list)

        r = dns.res_rec()
        r.name = key_name
        r.rr_type = dns.DNS_QTYPE_TSIG
        r.rr_class = dns.DNS_QCLASS_ANY
        r.ttl = 0
        r.length = 0xffff
        r.rdata = rdata

        packet.additional = [r]
        packet.arcount = 1

        return mac

    def bad_sign_packet(self, packet, key_name):
        """Add bad signature for a packet by
        bitflipping and hardcoding bytes at the end of the MAC"""

        return self.sign_packet(packet, key_name, bad_sig=True)

    def search_record(self, name):
        """Query TXT records for ``name`` over UDP and return the reply's
        RCODE (the low four bits of its operation field)."""
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])
        (response, response_packet) =\
            self.dns_transaction_udp(p, self.server_ip)
        return response.operation & 0x000F
493 def make_update_request(self
, delete
=False):
494 "Create a DNS update request"
496 rr_class
= dns
.DNS_QCLASS_IN
500 rr_class
= dns
.DNS_QCLASS_NONE
503 p
= self
.make_name_packet(dns
.DNS_OPCODE_UPDATE
)
504 q
= self
.make_name_question(self
.get_dns_domain(),
509 self
.finish_name_packet(p
, questions
)
513 r
.name
= self
.newrecname
514 r
.rr_type
= dns
.DNS_QTYPE_TXT
515 r
.rr_class
= rr_class
518 rdata
= self
.make_txt_record(['"This is a test"'])
521 p
.nscount
= len(updates
)