1 # -*- coding: utf-8 -*-
3 # Copyright Andrew Bartlett 2018
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
21 import samba
.getopt
as options
25 from samba
.auth
import system_session
26 from samba
.tests
import TestCase
29 ERRCODE_ENTRY_EXISTS
= 68
30 ERRCODE_OPERATIONS_ERROR
= 1
31 ERRCODE_INVALID_VALUE
= 21
32 ERRCODE_CLASS_VIOLATION
= 65
34 parser
= optparse
.OptionParser("{0} <host>".format(sys
.argv
[0]))
35 sambaopts
= options
.SambaOptions(parser
)
37 # use command line creds if available
38 credopts
= options
.CredentialsOptions(parser
)
39 parser
.add_option_group(credopts
)
40 parser
.add_option("-v", action
="store_true", dest
="verbose",
41 help="print successful expression outputs")
42 opts
, args
= parser
.parse_args()
48 lp
= sambaopts
.get_loadparm()
49 creds
= credopts
.get_credentials(lp
)
51 # Set properly at end of file.
58 class ComplexExpressionTests(TestCase
):
59 # Using setUpClass instead of setup because we're not modifying any
60 # records in the tests
64 cls
.samdb
= samba
.samdb
.SamDB(host
, lp
=lp
,
65 session_info
=system_session(),
68 ou_name
= "ComplexExprTest"
69 cls
.base_dn
= "OU={0},{1}".format(ou_name
, cls
.samdb
.domain_dn())
72 cls
.samdb
.delete(cls
.base_dn
, ["tree_delete:1"])
77 cls
.samdb
.create_ou(cls
.base_dn
)
78 except ldb
.LdbError
as e
:
79 if e
.args
[0] == ERRCODE_ENTRY_EXISTS
:
80 print(('test ou {ou} already exists. Delete with '
81 '"samba-tool group delete OU={ou} '
82 '--force-subtree-delete"').format(ou
=ou_name
))
85 cls
.name_template
= "testuser{0}"
88 # These fields are carefully hand-picked from the schema. They have
89 # syntax and handling appropriate for our test structure.
90 cls
.largeint_f
= "accountExpires"
91 cls
.str_f
= "accountNameHistory"
93 cls
.enum_f
= "preferredDeliveryMethod"
94 cls
.time_f
= "msTSExpireDate"
95 cls
.ranged_int_f
= "countryCode"
@classmethod
def tearDownClass(cls):
    """Remove the per-run base OU and everything created beneath it.

    unittest invokes this on the class object itself, so it must be a
    classmethod; without the decorator the call fails for lack of ``cls``.
    """
    # "tree_delete:1" asks the server to delete the whole subtree, so the
    # per-test OUs and users made by make_test_objects go with it.
    cls.samdb.delete(cls.base_dn, ["tree_delete:1"])

101 # Make test OU containing users with field=val for each val
102 def make_test_objects(self
, field
, vals
):
105 ou_dn
= "OU=testou{0},{1}".format(ou_count
, self
.base_dn
)
106 self
.samdb
.create_ou(ou_dn
)
108 ldap_objects
= [{"dn": "CN=testuser{0},{1}".format(n
, ou_dn
),
109 "name": self
.name_template
.format(n
),
110 "objectClass": "user",
114 for ldap_object
in ldap_objects
:
115 # It's useful to keep appropriate python types in the ldap_object
116 # dict but samdb's 'add' function expects strings.
117 stringed_ldap_object
= {k
: str(v
)
118 for (k
, v
) in ldap_object
.items()}
120 self
.samdb
.add(stringed_ldap_object
)
121 except ldb
.LdbError
as e
:
122 print("failed to add %s" % (stringed_ldap_object
))
125 return ou_dn
, ldap_objects
127 # Run search expr and print out time. This function should be used for
128 # almost all searching.
129 def time_ldap_search(self
, expr
, dn
):
132 start_time
= time
.time()
133 res
= self
.samdb
.search(base
=dn
,
134 scope
=ldb
.SCOPE_SUBTREE
,
136 time_taken
= time
.time() - start_time
137 except Exception as e
:
138 print("failed expr " + expr
)
140 print("{0} took {1}s".format(expr
, time_taken
))
141 return res
, time_taken
143 # Take an ldap expression and an equivalent python expression.
144 # Run and time the ldap expression and compare the result to the python
145 # expression run over a list of ldap_object dicts.
146 def assertLDAPQuery(self
, ldap_expr
, ou_dn
, py_expr
, ldap_objects
):
148 # run (and time) the LDAP search expression over the DB
149 res
, time_taken
= self
.time_ldap_search(ldap_expr
, ou_dn
)
150 results
= {str(row
.get('name')[0]) for row
in res
}
152 # build the set of expected results by evaluating the python-equivalent
153 # of the search expression over the same set of objects
154 expected_results
= set()
155 for ldap_object
in ldap_objects
:
157 final_expr
= py_expr
.format(**ldap_object
)
159 # If the format on the py_expr hits a key error, then
160 # ldap_object doesn't have the field, so it shouldn't match.
164 expected_results
.add(str(ldap_object
['name']))
166 self
.assertEqual(results
, expected_results
)
169 ldap_object_names
= {l
['name'] for l
in ldap_objects
}
170 excluded
= ldap_object_names
- results
171 excluded
= "\n ".join(excluded
) or "[NOTHING]"
172 returned
= "\n ".join(expected_results
) or "[NOTHING]"
174 print("PASS: Expression {0} took {1}s and returned:"
176 "Excluded:\n {3}\n".format(ldap_expr
,
181 # Basic integer range test
182 def test_int_range(self
, field
=None):
184 field
= field
or self
.int_f
185 ou_dn
, ldap_objects
= self
.make_test_objects(field
, range(n
))
187 expr
= "(&(%s>=%s)(%s<=%s))" % (field
, n
-1, field
, n
+1)
188 py_expr
= "%d <= {%s} <= %d" % (n
-1, field
, n
+1)
189 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
193 expr
= "(%s<=%s)" % (field
, half_n
)
194 py_expr
= "{%s} <= %d" % (field
, half_n
)
195 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
197 expr
= "(%s>=%s)" % (field
, half_n
)
198 py_expr
= "{%s} >= %d" % (field
, half_n
)
199 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
201 # Same test again for largeint and enum
def test_largeint_range(self):
    """Repeat the integer range checks against the large-integer field."""
    self.test_int_range(field=self.largeint_f)

def test_enum_range(self):
    """Repeat the integer range checks against the enumeration field."""
    self.test_int_range(field=self.enum_f)

208 # Special range test for integer field with upper and lower bounds defined.
209 # The bounds are checked on insertion, not search, so we should be able
210 # to compare to a constant that is outside bounds.
211 def test_ranged_int_range(self
):
212 field
= self
.ranged_int_f
216 vals
= list(range(ubound
-width
, ubound
))
217 ou_dn
, ldap_objects
= self
.make_test_objects(field
, vals
)
219 # Check <= value above overflow returns all vals
220 expr
= "(%s<=%d)" % (field
, ubound
+5)
221 py_expr
= "{%s} <= %d" % (field
, ubound
+5)
222 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
224 # Test range also works for time fields
225 def test_time_range(self
):
231 base_time
= 20050116175514
232 time_range
= [base_time
+ t
for t
in range(-width
, width
)]
233 time_range
= [str(t
) + ".0Z" for t
in time_range
]
234 ou_dn
, ldap_objects
= self
.make_test_objects(field
, time_range
)
236 expr
= "(%s<=%s)" % (field
, str(base_time
) + ".0Z")
237 py_expr
= 'int("{%s}"[:-3]) <= %d' % (field
, base_time
)
238 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
240 expr
= "(&(%s>=%s)(%s<=%s))" % (field
, str(base_time
-1) + ".0Z",
241 field
, str(base_time
+1) + ".0Z")
242 py_expr
= '%d <= int("{%s}"[:-3]) <= %d' % (base_time
-1,
245 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
247 # Run each comparison op on a simple test set. Time taken will be printed.
248 def test_int_single_cmp_op_speeds(self
, field
=None):
250 field
= field
or self
.int_f
251 ou_dn
, ldap_objects
= self
.make_test_objects(field
, range(n
))
253 comp_ops
= ['=', '<=', '>=']
254 py_comp_ops
= ['==', '<=', '>=']
255 exprs
= ["(%s%s%d)" % (field
, c
, n
) for c
in comp_ops
]
256 py_exprs
= ["{%s}%s%d" % (field
, c
, n
) for c
in py_comp_ops
]
258 for expr
, py_expr
in zip(exprs
, py_exprs
):
259 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
def test_largeint_single_cmp_op_speeds(self):
    """Time each single comparison operator on the large-integer field."""
    self.test_int_single_cmp_op_speeds(field=self.largeint_f)

def test_enum_single_cmp_op_speeds(self):
    """Time each single comparison operator on the enumeration field."""
    self.test_int_single_cmp_op_speeds(field=self.enum_f)

267 # Check strings are ordered using a naive ordering.
268 def test_str_ordering(self
):
272 str_range
= ['abc{0}d'.format(chr(c
)) for c
in range(a_ord
, a_ord
+n
)]
273 ou_dn
, ldap_objects
= self
.make_test_objects(field
, str_range
)
274 half_n
= int(a_ord
+ n
/2)
276 # Basic <= and >= statements
277 expr
= "(%s>=abc%s)" % (field
, chr(half_n
))
278 py_expr
= "'{%s}' >= 'abc%s'" % (field
, chr(half_n
))
279 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
281 expr
= "(%s<=abc%s)" % (field
, chr(half_n
))
282 py_expr
= "'{%s}' <= 'abc%s'" % (field
, chr(half_n
))
283 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
286 expr
= "(&(%s>=abc%s)(%s<=abc%s))" % (field
, chr(half_n
-2),
287 field
, chr(half_n
+2))
288 py_expr
= "'abc%s' <= '{%s}' <= 'abc%s'" % (chr(half_n
-2),
291 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
293 # Integers treated as string
294 expr
= "(%s>=1)" % (field
)
295 py_expr
= "'{%s}' >= '1'" % (field
)
296 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
298 # Windows returns nothing for invalid expressions. Expected fail on samba.
299 def test_invalid_expressions(self
, field
=None):
300 field
= field
or self
.int_f
302 ou_dn
, ldap_objects
= self
.make_test_objects(field
, list(range(n
)))
303 int_expressions
= ["(%s>=abc)",
307 for expr
in int_expressions
:
308 expr
= expr
% (field
)
309 self
.assertLDAPQuery(expr
, ou_dn
, "False", ldap_objects
)
def test_largeint_invalid_expressions(self):
    """Run the invalid-expression checks against the large-integer field."""
    self.test_invalid_expressions(field=self.largeint_f)

def test_enum_invalid_expressions(self):
    """Run the invalid-expression checks against the enumeration field."""
    self.test_invalid_expressions(field=self.enum_f)

def test_case_insensitive(self):
    """Check that string equality matching ignores letter case.

    Ten users get values "äbc0".."äbc9" in the string field. The field is
    then searched for "äbc1" and for "ÄbC1" (the non-ASCII first letter is
    upper-cased too). Both searches must return exactly the object whose
    value is "äbc1".
    """
    field = self.str_f
    values = ["äbc%s" % i for i in range(10)]
    ou_dn, ldap_objects = self.make_test_objects(field, values)

    # The same Python predicate is the expected result for both spellings.
    expected = '"{%s}"=="äbc1"' % (field,)
    for search_value in ("äbc1", "ÄbC1"):
        ldap_expr = "(%s=%s)" % (field, search_value)
        self.assertLDAPQuery(ldap_expr, ou_dn, expected, ldap_objects)

328 # Check negative numbers can be entered and compared
329 def test_negative_cmp(self
, field
=None):
330 field
= field
or self
.int_f
332 around_zero
= list(range(-width
, width
))
333 ou_dn
, ldap_objects
= self
.make_test_objects(field
, around_zero
)
335 expr
= "(%s>=-3)" % (field
)
336 py_expr
= "{%s} >= -3" % (field
)
337 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
def test_negative_cmp_largeint(self):
    """Run the negative-number comparison check on the large-integer field."""
    self.test_negative_cmp(field=self.largeint_f)

def test_negative_cmp_enum(self):
    """Run the negative-number comparison check on the enumeration field."""
    self.test_negative_cmp(field=self.enum_f)

345 # Check behaviour on insertion and comparison of zero-prefixed numbers.
346 # Samba errors on insertion, Windows strips the leading zeroes.
347 def test_zero_prefix(self
, field
=None):
348 field
= field
or self
.int_f
350 # Test comparison with 0-prefixed constants.
352 ou_dn
, ldap_objects
= self
.make_test_objects(field
, list(range(n
)))
354 expr
= "(%s>=00%d)" % (field
, n
/2)
355 py_expr
= "{%s} >= %d" % (field
, n
/2)
356 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
358 # Delete the test OU so we don't mix it up with the next one.
359 self
.samdb
.delete(ou_dn
, ["tree_delete:1"])
361 # Try inserting 0-prefixed numbers, check it fails.
362 zero_pref_nums
= ['00'+str(num
) for num
in range(n
)]
364 ou_dn
, ldap_objects
= self
.make_test_objects(field
, zero_pref_nums
)
365 except ldb
.LdbError
as e
:
366 if e
.args
[0] != ERRCODE_INVALID_VALUE
:
370 # Samba doesn't get this far - the exception is raised. Windows allows
371 # the insertion and removes the leading 0s as tested below.
372 # Either behaviour is fine.
373 print("LDAP allowed insertion of 0-prefixed nums for field " + field
)
375 res
= self
.samdb
.search(base
=ou_dn
,
376 scope
=ldb
.SCOPE_SUBTREE
,
377 expression
="(objectClass=user)")
378 returned_nums
= [str(r
.get(field
)[0]) for r
in res
]
379 expect
= [str(n
) for n
in range(n
)]
380 self
.assertEqual(set(returned_nums
), set(expect
))
382 expr
= "(%s>=%d)" % (field
, n
/2)
383 py_expr
= "{%s} >= %d" % (field
, n
/2)
384 for ldap_object
in ldap_objects
:
385 ldap_object
[field
] = int(ldap_object
[field
])
387 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
def test_zero_prefix_largeint(self):
    """Run the zero-prefixed number checks on the large-integer field."""
    self.test_zero_prefix(field=self.largeint_f)

def test_zero_prefix_enum(self):
    """Run the zero-prefixed number checks on the enumeration field."""
    self.test_zero_prefix(field=self.enum_f)

395 # Check integer overflow is handled as best it can be.
396 def test_int_overflow(self
, field
=None, of
=None):
397 field
= field
or self
.int_f
401 vals
= list(range(of
-width
, of
+width
))
402 ou_dn
, ldap_objects
= self
.make_test_objects(field
, vals
)
404 # Check ">=overflow" doesn't return vals past overflow
405 expr
= "(%s>=%d)" % (field
, of
-3)
406 py_expr
= "%d <= {%s} <= %d" % (of
-3, field
, of
)
407 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
409 # "<=overflow" returns everything
410 expr
= "(%s<=%d)" % (field
, of
)
412 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
414 # Values past overflow should be negative
415 expr
= "(&(%s<=%d)(%s>=0))" % (field
, of
, field
)
416 py_expr
= "{%s} <= %d" % (field
, of
)
417 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
418 expr
= "(%s<=0)" % (field
)
419 py_expr
= "{%s} >= %d" % (field
, of
+1)
420 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
422 # Get the values back out and check vals past overflow are negative.
423 res
= self
.samdb
.search(base
=ou_dn
,
424 scope
=ldb
.SCOPE_SUBTREE
,
425 expression
="(objectClass=user)")
426 returned_nums
= [str(r
.get(field
)[0]) for r
in res
]
428 # Note: range(a,b) == [a..b-1] (confusing)
429 up_to_overflow
= list(range(of
-width
, of
+1))
430 negatives
= list(range(-of
-1, -of
+width
-2))
432 expect
= [str(n
) for n
in up_to_overflow
+ negatives
]
433 self
.assertEqual(set(returned_nums
), set(expect
))
def test_enum_overflow(self):
    """Run the integer-overflow checks on the enumeration field.

    The overflow point passed in is 2**31 - 1, the largest signed 32-bit
    integer.
    """
    int32_max = (1 << 31) - 1
    self.test_int_overflow(field=self.enum_f, of=int32_max)

438 # Check cmp works on uSNChanged. We can't insert uSNChanged vals, they get
439 # added automatically so we'll just insert some objects and go with what
441 def test_usnchanged(self
):
444 # Note we can't actually set uSNChanged via LDAP (LDB ignores it),
445 # so the input val range doesn't matter here
446 ou_dn
, _
= self
.make_test_objects(field
, list(range(n
)))
448 # Get the assigned uSNChanged values
449 res
= self
.samdb
.search(base
=ou_dn
,
450 scope
=ldb
.SCOPE_SUBTREE
,
451 expression
="(objectClass=user)")
453 # Our vals got ignored so make ldap_objects from search result
454 ldap_objects
= [{'name': str(r
['name'][0]),
455 field
: int(r
[field
][0])}
458 # Get the median val and use as the number in the test search expr.
459 nums
= [l
[field
] for l
in ldap_objects
]
460 nums
= list(sorted(nums
))
461 search_num
= nums
[int(len(nums
)/2)]
463 expr
= "(&(%s<=%d)(objectClass=user))" % (field
, search_num
)
464 py_expr
= "{%s} <= %d" % (field
, search_num
)
465 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
467 expr
= "(&(%s>=%d)(objectClass=user))" % (field
, search_num
)
468 py_expr
= "{%s} >= %d" % (field
, search_num
)
469 self
.assertLDAPQuery(expr
, ou_dn
, py_expr
, ldap_objects
)
472 # If we're called independently then import subunit, get host from first
473 # arg and run. Otherwise, subunit ran us so just set host from env.
474 # We always try to run over LDAP rather than direct file, so that
475 # search timings are not impacted by opening and closing the tdb file.
476 if __name__
== "__main__":
477 from samba
.tests
.subunitrun
import TestProgram
480 if "://" not in host
:
481 if os
.path
.isfile(host
):
482 host
= "tdb://%s" % host
484 host
= "ldap://%s" % host
485 TestProgram(module
=__name__
)
487 host
= "ldap://" + os
.getenv("SERVER")