ctdb-scripts: Improve update and listing code
[samba4-gss.git] / python / samba / tests / krb5 / spn_tests.py
blob5bcc0bde3699c0732334a670cdf89be6110dfdb2
1 #!/usr/bin/env python3
2 # Unix SMB/CIFS implementation.
3 # Copyright (C) Stefan Metzmacher 2020
4 # Copyright (C) 2020 Catalyst.Net Ltd
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
20 import sys
21 import os
23 sys.path.insert(0, "bin/python")
24 os.environ["PYTHONUNBUFFERED"] = "1"
26 from samba.tests import DynamicTestCase
28 import ldb
30 from samba.tests.krb5.kdc_base_test import KDCBaseTest
31 from samba.tests.krb5.raw_testcase import KerberosCredentials
32 from samba.tests.krb5.rfc4120_constants import (
33 AES256_CTS_HMAC_SHA1_96,
34 ARCFOUR_HMAC_MD5,
35 KDC_ERR_S_PRINCIPAL_UNKNOWN,
36 NT_PRINCIPAL,
# Module-level debug toggles; SpnTests.setUp() copies them into
# self.do_asn1_print and self.do_hexdump. Both start out disabled.
global_asn1_print = global_hexdump = False
@DynamicTestCase
class SpnTests(KDCBaseTest):
    """TGS-REQ tests for several SPN forms against several account types.

    For each account type in ``test_account_types`` and each SPN template in
    ``test_spns``, a dynamic test is registered. That test requests a service
    ticket for the formatted SPN. It then expects either a valid TGS reply or
    KDC_ERR_S_PRINCIPAL_UNKNOWN (see _test_spn_with_args()).
    """

    # Account types to exercise, keyed by the suffix used in test names.
    test_account_types = {
        'computer': KDCBaseTest.AccountType.COMPUTER,
        'server': KDCBaseTest.AccountType.SERVER,
        'rodc': KDCBaseTest.AccountType.RODC
    }
    # SPN templates, keyed by the prefix used in test names. The account
    # placeholder is written as '{{account}}' so that it survives the first
    # str.format() pass (which fills in the domain names) as '{account}'.
    test_spns = {
        '2_part': 'ldap/{{account}}',
        '3_part_our_domain': 'ldap/{{account}}/{netbios_domain_name}',
        '3_part_our_realm': 'ldap/{{account}}/{dns_domain_name}',
        '3_part_not_our_realm': 'ldap/{{account}}/test',
        '3_part_instance': 'ldap/{{account}}:test/{dns_domain_name}'
    }

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Mock RODC credentials are created lazily by _get_creds() and then
        # stored on the class, so every test in the class reuses them.
        cls._mock_rodc_creds = None

    @classmethod
    def setUpDynamicTestCases(cls):
        """Register one 'test_spn' dynamic test per (account type, SPN) pair.

        Test suffixes have the form '<spn_name>_spn_<account_type_name>'.
        The arguments passed along are (account_type, spn_template). They are
        presumably delivered to _test_spn_with_args() by the DynamicTestCase
        machinery (NOTE(review): dispatch lives outside this file).
        """
        for account_type_name, account_type in cls.test_account_types.items():
            for spn_name, spn in cls.test_spns.items():
                tname = f'{spn_name}_spn_{account_type_name}'
                targs = (account_type, spn)
                cls.generate_dynamic_test('test_spn', tname, *targs)

    def _test_spn_with_args(self, account_type, spn):
        """Request a service ticket for ``spn`` on an ``account_type`` account.

        :param account_type: a KDCBaseTest.AccountType member
        :param spn: an SPN template taken from ``test_spns``

        Expected results:
        - A 3-part SPN whose last component is our NetBIOS domain name or our
          DNS domain name fails with KDC_ERR_S_PRINCIPAL_UNKNOWN, unless the
          target is a server or RODC account.
        - Every other combination yields a TGS reply checked by
          generic_check_kdc_rep.
        """
        target_creds = self._get_creds(account_type)
        spn = self._format_spn(spn, target_creds)

        sname = self.PrincipalName_create(name_type=NT_PRINCIPAL,
                                          names=spn.split('/'))

        client_creds = self.get_client_creds()
        tgt = self.get_tgt(client_creds)

        samdb = self.get_samdb()
        netbios_domain_name = samdb.domain_netbios_name()
        dns_domain_name = samdb.domain_dns_name()

        subkey = self.RandomKey(tgt.session_key.etype)

        etypes = (AES256_CTS_HMAC_SHA1_96, ARCFOUR_HMAC_MD5,)

        # Server accounts decrypt the ticket with an explicitly chosen AES256
        # key. For other accounts, etype=None presumably lets the helper
        # pick a default.
        if account_type is self.AccountType.SERVER:
            ticket_etype = AES256_CTS_HMAC_SHA1_96
        else:
            ticket_etype = None
        decryption_key = self.TicketDecryptionKey_from_creds(
            target_creds, etype=ticket_etype)

        # Which combinations are expected to fail: 3-part SPNs qualified
        # with our own domain name, when the target is neither a server nor
        # an RODC account.
        is_domain_qualified = (
            spn.count('/') > 1
            and spn.endswith((netbios_domain_name, dns_domain_name)))
        is_server_or_rodc = (account_type is self.AccountType.SERVER
                             or account_type is self.AccountType.RODC)

        if is_domain_qualified and not is_server_or_rodc:
            expected_error_mode = KDC_ERR_S_PRINCIPAL_UNKNOWN
            check_error_fn = self.generic_check_kdc_error
            check_rep_fn = None
        else:
            expected_error_mode = 0
            check_error_fn = None
            check_rep_fn = self.generic_check_kdc_rep

        kdc_exchange_dict = self.tgs_exchange_dict(
            expected_crealm=tgt.crealm,
            expected_cname=tgt.cname,
            expected_srealm=tgt.srealm,
            expected_sname=sname,
            ticket_decryption_key=decryption_key,
            check_rep_fn=check_rep_fn,
            check_error_fn=check_error_fn,
            check_kdc_private_fn=self.generic_check_kdc_private,
            expected_error_mode=expected_error_mode,
            tgt=tgt,
            authenticator_subkey=subkey,
            kdc_options='0',
            expect_edata=False)

        self._generic_kdc_exchange(kdc_exchange_dict,
                                   cname=None,
                                   realm=tgt.srealm,
                                   sname=sname,
                                   etypes=etypes)

    def setUp(self):
        super().setUp()
        # Debug-output toggles, taken from the module-level flags.
        self.do_asn1_print = global_asn1_print
        self.do_hexdump = global_hexdump

    def _format_spns(self, spns, creds=None):
        """Format every SPN template in ``spns`` with _format_spn().

        The result is a tuple rather than a lazy ``map`` object, for two
        reasons:
        - A map can be iterated only once.
        - A map hashes by identity. Stored in the ``opts`` passed to
          get_cached_creds(), it would make every call look like a distinct
          set of options.

        NOTE(review): this assumes get_cached_creds() keys its cache on those
        options. Confirm in kdc_base_test.
        """
        return tuple(self._format_spn(spn, creds) for spn in spns)

    def _format_spn(self, spn, creds=None):
        """Expand an SPN template from ``test_spns``.

        The first str.format() pass fills in {netbios_domain_name} and
        {dns_domain_name}, and turns '{{account}}' into '{account}'.

        If ``creds`` is given, a second pass replaces '{account}' with the
        credentials' user name. Otherwise the '{account}' placeholder is left
        in the result for a later pass to fill in.
        """
        samdb = self.get_samdb()

        spn = spn.format(netbios_domain_name=samdb.domain_netbios_name(),
                         dns_domain_name=samdb.domain_dns_name())

        if creds is not None:
            account_name = creds.get_username()
            spn = spn.format(account=account_name)

        return spn

    def _get_creds(self, account_type):
        """Return credentials for a target account of ``account_type``.

        All ``test_spns`` templates (with '{account}' still unfilled) are
        requested for the account:
        - For RODC, _get_mock_rodc_creds() registers them. Its result is
          built once and cached on the class.
        - For other types, they are passed as the 'spn' option of
          get_cached_creds().
        """
        spns = self._format_spns(self.test_spns.values())

        if account_type is self.AccountType.RODC:
            creds = self._mock_rodc_creds
            if creds is None:
                creds = self._get_mock_rodc_creds(spns)
                type(self)._mock_rodc_creds = creds
        else:
            creds = self.get_cached_creds(
                account_type=account_type,
                opts={
                    'spn': spns
                })

        return creds

    def _get_mock_rodc_creds(self, spns):
        """Add ``spns`` to the mock RODC account and build its credentials.

        Steps:
        1. Fill each SPN's '{account}' placeholder with the RODC's name.
        2. Append the SPN to the context's SPN list if it is not already
           present.
        3. Write the merged list back to servicePrincipalName with a
           replace modification.

        The returned KerberosCredentials carry the account's password, DN,
        SPNs, current msDS-KeyVersionNumber and keys.
        """
        rodc_ctx = self.get_mock_rodc_ctx()

        for spn in spns:
            spn = spn.format(account=rodc_ctx.myname)
            if spn not in rodc_ctx.SPNs:
                rodc_ctx.SPNs.append(spn)

        samdb = self.get_samdb()
        rodc_dn = ldb.Dn(samdb, rodc_ctx.acct_dn)

        msg = ldb.Message(rodc_dn)
        msg['servicePrincipalName'] = ldb.MessageElement(
            rodc_ctx.SPNs,
            ldb.FLAG_MOD_REPLACE,
            'servicePrincipalName')
        samdb.modify(msg)

        creds = KerberosCredentials()
        creds.guess(self.get_lp())
        creds.set_realm(rodc_ctx.realm.upper())
        creds.set_domain(rodc_ctx.domain_name)
        creds.set_password(rodc_ctx.acct_pass)
        creds.set_username(rodc_ctx.myname)
        creds.set_workstation(rodc_ctx.samname)
        creds.set_dn(rodc_dn)
        creds.set_spn(rodc_ctx.SPNs)

        # Read the current key version number from the directory so the
        # credentials match the account's keys.
        res = samdb.search(base=rodc_dn,
                           scope=ldb.SCOPE_BASE,
                           attrs=['msDS-KeyVersionNumber'])
        kvno = int(res[0].get('msDS-KeyVersionNumber', idx=0))
        creds.set_kvno(kvno)

        keys = self.get_keys(creds)
        self.creds_set_keys(creds, keys)

        return creds
if __name__ == "__main__":
    import unittest

    # Reset the debug toggles that SpnTests.setUp() reads. Setting them to
    # True here presumably enables ASN.1 printing / hex dumps when the file
    # is run directly.
    global_asn1_print = global_hexdump = False
    unittest.main()