[clang-repl] [codegen] Reduce the state in TBAA. NFC for static compilation. (#98138)
[llvm-project.git] / polly / lib / External / isl / isl_test_python.py
blob05bb0c8246421fc8ba90e4dbe76f7993d948c40d
1 # Copyright 2016-2017 Tobias Grosser
3 # Use of this software is governed by the MIT license
5 # Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich
7 import sys
8 import isl
10 # Test that isl objects can be constructed.
12 # This tests:
13 # - construction from a string
14 # - construction from an integer
15 # - static constructor without a parameter
16 # - conversion construction
17 # - construction of empty union set
19 # The tests to construct from integers and strings cover functionality that
20 # is also tested in the parameter type tests, but here the presence of
21 # multiple overloaded constructors and overload resolution is tested.
def test_constructors():
    """Check that isl objects can be built through each constructor flavour.

    Covers construction from a string, from an integer, through a static
    constructor without arguments, by conversion from another isl type,
    and construction of an empty union set.
    """
    # Three independent ways of obtaining the value zero.
    from_string = isl.val("0")
    assert from_string.is_zero()
    from_integer = isl.val(0)
    assert from_integer.is_zero()
    from_static = isl.val.zero()
    assert from_static.is_zero()

    # Converting a basic set yields a set describing the same elements.
    converted = isl.set(isl.basic_set("{ [1] }"))
    assert converted.is_equal(isl.set("{ [1] }"))

    # Taking the union with the empty union set changes nothing.
    statements = isl.union_set("{ A[1]; B[2, 3] }")
    nothing = isl.union_set.empty()
    assert statements.is_equal(statements.union(nothing))
43 # Test integer function parameters for a particular integer value.
def test_int(i):
    """Check that the integer "i" is handled correctly as a parameter.

    The isl value built directly from the integer has to be equal to
    the isl value built from its decimal string representation.
    """
    from_integer = isl.val(i)
    from_text = isl.val(str(i))
    assert from_integer.eq(from_text)
51 # Test integer function parameters.
53 # Verify that extreme values and zero work.
def test_parameters_int():
    """Run the integer parameter check on the extreme values and on zero."""
    for number in (sys.maxsize, -sys.maxsize - 1, 0):
        test_int(number)
61 # Test isl objects parameters.
63 # Verify that isl objects can be passed as lvalue and rvalue parameters.
64 # Also verify that isl object parameters are automatically type converted if
65 # there is an inheritance relation. Finally, test function calls without
66 # any additional parameters, apart from the isl object on which
67 # the method is called.
def test_parameters_obj():
    """Check that isl objects can be passed as function parameters.

    The argument is passed once through a named intermediate and once as
    a temporary, then a basic set is passed where a set is compared, and
    finally a method is called that takes no argument besides the object
    it is invoked on.
    """
    first = isl.set("{ [0] }")
    second = isl.set("{ [1] }")
    third = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # Intermediate result stored in a variable before being reused.
    partial = first.union(second)
    assert partial.union(third).is_equal(expected)

    # Intermediate result used directly as a temporary.
    assert first.union(second).union(third).is_equal(expected)

    # A basic set is accepted in a comparison against a set.
    as_basic_set = isl.basic_set("{ [0] }")
    assert first.is_equal(as_basic_set)

    # Method without any parameter apart from the object itself.
    inverse_of_two = isl.val(2).inv()
    assert inverse_of_two.eq(isl.val("1/2"))
91 # Test different kinds of parameters to be passed to functions.
93 # This includes integer and isl object parameters.
def test_parameters():
    """Run the parameter tests: integer parameters, then isl objects."""
    for check in (test_parameters_int, test_parameters_obj):
        check()
100 # Test that isl objects are returned correctly.
102 # This only tests that after combining two objects, the result is successfully
103 # returned.
def test_return_obj():
    """Check that the isl object produced by combining two values is returned."""
    total = isl.val("1").add(isl.val("2"))
    assert total.eq(isl.val("3"))
115 # Test that integer values are returned correctly.
def test_return_int():
    """Check that integer results are returned with the right sign."""
    sign_of_one = isl.val("1").sgn()
    sign_of_minus_one = isl.val("-1").sgn()
    sign_of_zero = isl.val("0").sgn()

    assert sign_of_one > 0
    assert sign_of_minus_one < 0
    assert sign_of_zero == 0
127 # Test that isl_bool values are returned correctly.
129 # In particular, check the conversion to bool in case of true and false.
def test_return_bool():
    """Check that boolean results convert to True and False as expected."""
    nothing = isl.set("{ : false }")
    everything = isl.set("{ : }")

    nothing_is_empty = nothing.is_empty()
    everything_is_empty = everything.is_empty()

    assert nothing_is_empty
    assert not everything_is_empty
142 # Test that strings are returned correctly.
# Do so by calling the overloaded isl.ast_build.expr_from method.
def test_return_string():
    """Check that strings are returned correctly.

    An AST expression is generated from a piecewise affine expression
    and from a set, and the C string of each is compared against the
    expected text.
    """
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    affine = isl.pw_aff("[n] -> { [n] }")
    condition = isl.set("[n] -> { : n >= 0 }")

    affine_expr = build.expr_from(affine)
    assert affine_expr.to_C_str() == "n"

    condition_expr = build.expr_from(condition)
    assert condition_expr.to_C_str() == "n >= 0"
160 # Test that return values are handled correctly.
162 # Test that isl objects, integers, boolean values, and strings are
163 # returned correctly.
def test_return():
    """Run the return value tests: objects, integers, booleans, strings."""
    checks = (
        test_return_obj,
        test_return_int,
        test_return_bool,
        test_return_string,
    )
    for check in checks:
        check()
172 # A class that is used to test isl.id.user.
class S:
    """Plain user object that gets attached to an identifier in test_user."""

    def __init__(self):
        # Marker value that test_user reads back from the retrieved object.
        self.value = 42
179 # Test isl.id.user.
181 # In particular, check that the object attached to an identifier
182 # can be retrieved again.
def test_user():
    """Check that the object attached to an identifier can be retrieved.

    Three identifiers are created: one carrying an integer, one carrying
    nothing and one carrying an instance of S.
    """
    with_integer = isl.id("test", 5)
    without_user = isl.id("test2")
    with_object = isl.id("S", S())

    assert with_integer.user() == 5, (
        f"unexpected user object {with_integer.user()}"
    )
    assert without_user.user() is None, (
        f"unexpected user object {without_user.user()}"
    )

    attached = with_object.user()
    assert isinstance(attached, S), f"unexpected user object {attached}"
    assert attached.value == 42, f"unexpected user object {attached}"
195 # Test that foreach functions are modeled correctly.
197 # Verify that closures are correctly called as callback of a 'foreach'
198 # function and that variables captured by the closure work correctly. Also
199 # check that the foreach function handles exceptions thrown from
200 # the closure and that it propagates the exception.
def test_foreach():
    """Check that "foreach" functions call their callback correctly.

    A closure that captures a local list is passed to foreach_basic_set;
    afterwards the list has to hold three pairwise distinct basic sets
    that are each a subset of the original set. A second callback raises
    an exception, which has to reach the caller of foreach_basic_set.
    """
    s = isl.set("{ [0]; [1]; [2] }")

    # Named "pieces" rather than "list" to avoid shadowing the builtin.
    pieces = []

    def add(bs):
        pieces.append(bs)

    s.foreach_basic_set(add)

    assert len(pieces) == 3
    for piece in pieces:
        assert piece.is_subset(s)
    # Compare every unordered pair: (0, 1), (0, 2), (1, 2).
    for position, first in enumerate(pieces):
        for second in pieces[position + 1:]:
            assert not first.is_equal(second)

    def fail(bs):
        raise Exception("fail")

    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception:
        # Only ordinary exceptions count; KeyboardInterrupt and SystemExit
        # are deliberately not swallowed.
        caught = True
    assert caught
231 # Test the functionality of "foreach_scc" functions.
233 # In particular, test it on a list of elements that can be completely sorted
234 # but where two of the elements ("a" and "b") are incomparable.
def test_foreach_scc():
    """Check the "foreach_scc" function on a list of three identifiers.

    The identifiers are named after the keys of "data". The "follows"
    callback decides the order of two identifiers from the maps stored
    under their names, and every component handed to "add_single" has to
    consist of a single element. The resulting order has to be
    "b", "c", "a".
    """
    ids = isl.id_list(3)
    # Single-element list so that the nested callback can replace the value.
    ordered = [isl.id_list(3)]
    data = {
        "a": isl.map("{ [0] -> [1] }"),
        "b": isl.map("{ [1] -> [0] }"),
        "c": isl.map("{ [i = 0:1] -> [i] }"),
    }
    # Only the names are needed here; the maps are looked up in "follows".
    for name in data:
        ids = ids.add(name)
    identity = data["a"].space().domain().identity_multi_pw_aff_on_domain()

    def follows(a, b):
        combined = data[b.name()].apply_domain(data[a.name()])
        return not combined.lex_ge_at(identity).is_empty()

    def add_single(scc):
        assert scc.size() == 1
        ordered[0] = ordered[0].concat(scc)

    ids.foreach_scc(follows, add_single)
    assert ordered[0].size() == 3
    assert ordered[0].at(0).name() == "b"
    assert ordered[0].at(1).name() == "c"
    assert ordered[0].at(2).name() == "a"
263 # Test the functionality of "every" functions.
265 # In particular, test the generic functionality and
266 # test that exceptions are properly propagated.
def test_every():
    """Check the "every" functions.

    every_set is called with several predicates on a union set holding
    the two sets A[i] and B[j], and then with a callback that raises,
    whose exception has to reach the caller.
    """
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()

    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()

    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(not_in_A)

    def fail(s):
        raise Exception("fail")

    caught = False
    try:
        # This must call every_set itself: a misspelled method name would
        # raise AttributeError and satisfy the check below without ever
        # invoking the callback.
        us.every_set(fail)
    except Exception:
        caught = True
    assert caught
302 # Check basic construction of spaces.
def test_space():
    """Check basic construction of spaces.

    Starting from the unit space, a set space with tuple "A" of size 3
    and a map space with an additional tuple "B" of size 2 are built,
    and the universes of both are compared against their string forms.
    """
    set_space = isl.space.unit().add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    universe_set = isl.set.universe(set_space)
    universe_map = isl.map.universe(map_space)

    assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
    assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))
315 # Construct a simple schedule tree with an outer sequence node and
316 # a single-dimensional band node in each branch, with one of them
317 # marked coincident.
def construct_schedule_tree():
    """Build the schedule tree shared by the schedule and AST tests.

    The tree has a domain made up of statements A and B, a sequence node
    with one filter per statement and, in each branch, a band with a
    single member. The band of the A branch is marked coincident.

    Returns the schedule obtained from the final node.
    """
    stmt_a = isl.union_set("{ A[i] : 0 <= i < 10 }")
    stmt_b = isl.union_set("{ B[i] : 0 <= i < 20 }")

    # Domain node at the root, then a sequence over both statements.
    position = isl.schedule_node.from_domain(stmt_a.union(stmt_b)).child(0)
    branches = isl.union_set_list(stmt_a).add(stmt_b)
    position = position.insert_sequence(branches)

    # First branch: insert the band for A, mark it, return to the sequence.
    band_a = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    position = position.child(0).child(0).insert_partial_schedule(band_a)
    position = position.member_set_coincident(0, True).ancestor(2)

    # Second branch: insert the band for B, return to the sequence.
    band_b = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    position = position.child(1).child(0).insert_partial_schedule(band_b)
    position = position.ancestor(2)

    return position.schedule()
345 # Test basic schedule tree functionality.
347 # In particular, create a simple schedule tree and
348 # - check that the root node is a domain node
349 # - test map_descendant_bottom_up
350 # - test foreach_descendant_top_down
351 # - test every_descendant
def test_schedule_tree():
    """Check basic schedule tree functionality.

    On the tree built by construct_schedule_tree:
    - check that the root node is a domain node
    - test map_descendant_bottom_up, including exception propagation
    - test foreach_descendant_top_down with callbacks returning True
      and False
    - test every_descendant, including exception propagation
    - check that the union of all filters equals the domain
    """
    schedule = construct_schedule_tree()
    root = schedule.root()

    assert type(root) == isl.schedule_node_domain

    # Single-element list so that the nested callbacks can update the count.
    count = [0]

    def count_and_keep(node):
        count[0] += 1
        return node

    root = root.map_descendant_bottom_up(count_and_keep)
    assert count[0] == 8

    def fail_map(node):
        raise Exception("fail")

    caught = False
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception:
        caught = True
    assert caught

    # A callback that returns True is expected to reach all eight nodes.
    count = [0]

    def count_and_descend(node):
        count[0] += 1
        return True

    root.foreach_descendant_top_down(count_and_descend)
    assert count[0] == 8

    # A callback that returns False is expected to be called once only.
    count = [0]

    def count_and_stop(node):
        count[0] += 1
        return False

    root.foreach_descendant_top_down(count_and_stop)
    assert count[0] == 1

    def is_not_domain(node):
        return type(node) != isl.schedule_node_domain

    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    def fail(node):
        raise Exception("fail")

    caught = False
    try:
        root.every_descendant(fail)
    except Exception:
        caught = True
    assert caught

    domain = root.domain()
    filters = [isl.union_set("{}")]

    def collect_filters(node):
        if type(node) == isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        # Always report success so that every node is visited.
        return True

    root.every_descendant(collect_filters)
    assert domain.is_equal(filters[0])
425 # Test marking band members for unrolling.
426 # "schedule" is the schedule created by construct_schedule_tree.
427 # It schedules two statements, with 10 and 20 instances, respectively.
428 # Unrolling all band members therefore results in 30 at-domain calls
429 # by the AST generator.
def test_ast_build_unroll(schedule):
    """Check marking band members for unrolling.

    "schedule" is the schedule created by construct_schedule_tree.
    Member 0 of every band node is marked for unrolling and an AST is
    generated from the result; the at-each-domain callback then has to
    have been called 30 times.
    """

    def mark_unroll(node):
        if type(node) == isl.schedule_node_band:
            return node.member_set_ast_loop_unroll(0)
        return node

    marked_root = schedule.root().map_descendant_bottom_up(mark_unroll)
    unrolled = marked_root.schedule()

    # Single-element list so that the nested callback can update the count.
    visits = [0]

    def inc_count_ast(node, build):
        visits[0] += 1
        return node

    build = isl.ast_build().set_at_each_domain(inc_count_ast)
    build.node_from(unrolled)
    assert visits[0] == 30
454 # Test basic AST generation from a schedule tree.
456 # In particular, create a simple schedule tree and
457 # - generate an AST from the schedule tree
458 # - test at_each_domain
459 # - test unrolling
def test_ast_build():
    """Check basic AST generation from a schedule tree.

    On the schedule built by construct_schedule_tree (two statements):
    - generate an AST from the schedule tree
    - test at_each_domain, checking that the callback stays with the
      build returned by set_at_each_domain and is not installed on the
      build it was called on
    - test that an exception raised in the callback reaches the caller
      and that the build remains usable afterwards
    - test unrolling
    """
    schedule = construct_schedule_tree()

    # Single-element list so that the nested callback can update the count.
    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    # The original build has no callback attached.
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    # The returned build calls the callback twice.
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    # Rebinding the name keeps the callback.
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    # Read by the nested callback at call time, so flipping it below
    # changes the behaviour of the already installed callback.
    do_fail = True
    count_ast_fail = [0]

    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            raise Exception("fail")
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        caught = True
    assert caught
    assert count_ast_fail[0] > 0
    # Replacing the callback on a second name must not affect "build".
    build_copy = build
    build_copy = build_copy.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)
513 # Test basic AST expression generation from an affine expression.
def test_ast_build_expr():
    """Check AST expression generation from an affine expression.

    The expression generated for "n + 1" has to be an addition
    operation with two arguments.
    """
    affine = isl.pw_aff("[n] -> { [n + 1] }")
    sum_expr = isl.ast_build.from_context(affine.domain()).expr_from(affine)

    assert type(sum_expr) == isl.ast_expr_op_add
    assert sum_expr.n_arg() == 2
524 # Test the isl Python interface
526 # This includes:
527 # - Object construction
528 # - Different parameter types
529 # - Different return types
530 # - isl.id.user
531 # - Foreach functions
532 # - Foreach SCC function
533 # - Every functions
534 # - Spaces
535 # - Schedule trees
536 # - AST generation
537 # - AST expression generation
# Run every test group in the order listed above.
for run_test in (
    test_constructors,
    test_parameters,
    test_return,
    test_user,
    test_foreach,
    test_foreach_scc,
    test_every,
    test_space,
    test_schedule_tree,
    test_ast_build,
    test_ast_build_expr,
):
    run_test()