1 # Copyright 2016-2017 Tobias Grosser
3 # Use of this software is governed by the MIT license
5 # Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich
10 # Test that isl objects can be constructed.
13 # - construction from a string
14 # - construction from an integer
15 # - static constructor without a parameter
16 # - conversion construction
17 # - construction of empty union set
19 # The tests to construct from integers and strings cover functionality that
20 # is also tested in the parameter type tests, but here the presence of
21 # multiple overloaded constructors and overload resolution is tested.
23 def test_constructors():
25 assert zero1
.is_zero()
28 assert zero2
.is_zero()
30 zero3
= isl
.val
.zero()
31 assert zero3
.is_zero()
33 bs
= isl
.basic_set("{ [1] }")
34 result
= isl
.set("{ [1] }")
36 assert s
.is_equal(result
)
38 us
= isl
.union_set("{ A[1]; B[2, 3] }")
39 empty
= isl
.union_set
.empty()
40 assert us
.is_equal(us
.union(empty
))
43 # Test integer function parameters for a particular integer value.
47 val_str
= isl
.val(str(i
))
48 assert val_int
.eq(val_str
)
51 # Test integer function parameters.
53 # Verify that extreme values and zero work.
55 def test_parameters_int():
57 test_int(-sys
.maxsize
- 1)
61 # Test isl objects parameters.
63 # Verify that isl objects can be passed as lvalue and rvalue parameters.
64 # Also verify that isl object parameters are automatically type converted if
65 # there is an inheritance relation. Finally, test function calls without
66 # any additional parameters, apart from the isl object on which
67 # the method is called.
69 def test_parameters_obj():
70 a
= isl
.set("{ [0] }")
71 b
= isl
.set("{ [1] }")
72 c
= isl
.set("{ [2] }")
73 expected
= isl
.set("{ [i] : 0 <= i <= 2 }")
76 res_lvalue_param
= tmp
.union(c
)
77 assert res_lvalue_param
.is_equal(expected
)
79 res_rvalue_param
= a
.union(b
).union(c
)
80 assert res_rvalue_param
.is_equal(expected
)
82 a2
= isl
.basic_set("{ [0] }")
87 res_only_this_param
= two
.inv()
88 assert res_only_this_param
.eq(half
)
91 # Test different kinds of parameters to be passed to functions.
93 # This includes integer and isl object parameters.
95 def test_parameters():
100 # Test that isl objects are returned correctly.
102 # This only tests that after combining two objects, the result is successfully
105 def test_return_obj():
115 # Test that integer values are returned correctly.
117 def test_return_int():
119 neg_one
= isl
.val("-1")
123 assert neg_one
.sgn() < 0
124 assert zero
.sgn() == 0
127 # Test that isl_bool values are returned correctly.
129 # In particular, check the conversion to bool in case of true and false.
131 def test_return_bool():
132 empty
= isl
.set("{ : false }")
133 univ
= isl
.set("{ : }")
135 b_true
= empty
.is_empty()
136 b_false
= univ
.is_empty()
142 # Test that strings are returned correctly.
143 # Do so by calling overloaded isl.ast_build.from_expr methods.
145 def test_return_string():
146 context
= isl
.set("[n] -> { : }")
147 build
= isl
.ast_build
.from_context(context
)
148 pw_aff
= isl
.pw_aff("[n] -> { [n] }")
149 set = isl
.set("[n] -> { : n >= 0 }")
151 expr
= build
.expr_from(pw_aff
)
152 expected_string
= "n"
153 assert expected_string
== expr
.to_C_str()
155 expr
= build
.expr_from(set)
156 expected_string
= "n >= 0"
157 assert expected_string
== expr
.to_C_str()
160 # Test that return values are handled correctly.
162 # Test that isl objects, integers, boolean values, and strings are
163 # returned correctly.
172 # A class that is used to test isl.id.user.
181 # In particular, check that the object attached to an identifier
182 # can be retrieved again.
185 id = isl
.id("test", 5)
186 id2
= isl
.id("test2")
187 id3
= isl
.id("S", S())
188 assert id.user() == 5, f
"unexpected user object {id.user()}"
189 assert id2
.user() is None, f
"unexpected user object {id2.user()}"
191 assert isinstance(s
, S
), f
"unexpected user object {s}"
192 assert s
.value
== 42, f
"unexpected user object {s}"
195 # Test that foreach functions are modeled correctly.
197 # Verify that closures are correctly called as callback of a 'foreach'
198 # function and that variables captured by the closure work correctly. Also
199 # check that the foreach function handles exceptions thrown from
200 # the closure and that it propagates the exception.
203 s
= isl
.set("{ [0]; [1]; [2] }")
210 s
.foreach_basic_set(add
)
212 assert len(list) == 3
213 assert list[0].is_subset(s
)
214 assert list[1].is_subset(s
)
215 assert list[2].is_subset(s
)
216 assert not list[0].is_equal(list[1])
217 assert not list[0].is_equal(list[2])
218 assert not list[1].is_equal(list[2])
221 raise Exception("fail")
225 s
.foreach_basic_set(fail
)
231 # Test the functionality of "foreach_scc" functions.
233 # In particular, test it on a list of elements that can be completely sorted
234 # but where two of the elements ("a" and "b") are incomparable.
236 def test_foreach_scc():
237 list = isl
.id_list(3)
238 sorted = [isl
.id_list(3)]
240 "a": isl
.map("{ [0] -> [1] }"),
241 "b": isl
.map("{ [1] -> [0] }"),
242 "c": isl
.map("{ [i = 0:1] -> [i] }"),
244 for k
, v
in data
.items():
246 id = data
["a"].space().domain().identity_multi_pw_aff_on_domain()
249 map = data
[b
.name()].apply_domain(data
[a
.name()])
250 return not map.lex_ge_at(id).is_empty()
253 assert scc
.size() == 1
254 sorted[0] = sorted[0].concat(scc
)
256 list.foreach_scc(follows
, add_single
)
257 assert sorted[0].size() == 3
258 assert sorted[0].at(0).name() == "b"
259 assert sorted[0].at(1).name() == "c"
260 assert sorted[0].at(2).name() == "a"
263 # Test the functionality of "every" functions.
265 # In particular, test the generic functionality and
266 # test that exceptions are properly propagated.
269 us
= isl
.union_set("{ A[i]; B[j] }")
274 assert not us
.every_set(is_empty
)
277 return not s
.is_empty()
279 assert us
.every_set(is_non_empty
)
282 return s
.is_subset(isl
.set("{ A[x] }"))
284 assert not us
.every_set(in_A
)
287 return not s
.is_subset(isl
.set("{ A[x] }"))
289 assert not us
.every_set(not_in_A
)
292 raise Exception("fail")
302 # Check basic construction of spaces.
305 unit
= isl
.space
.unit()
306 set_space
= unit
.add_named_tuple("A", 3)
307 map_space
= set_space
.add_named_tuple("B", 2)
309 set = isl
.set.universe(set_space
)
310 map = isl
.map.universe(map_space
)
311 assert set.is_equal(isl
.set("{ A[*,*,*] }"))
312 assert map.is_equal(isl
.map("{ A[*,*,*] -> B[*,*] }"))
315 # Construct a simple schedule tree with an outer sequence node and
316 # a single-dimensional band node in each branch, with one of them
319 def construct_schedule_tree():
320 A
= isl
.union_set("{ A[i] : 0 <= i < 10 }")
321 B
= isl
.union_set("{ B[i] : 0 <= i < 20 }")
323 node
= isl
.schedule_node
.from_domain(A
.union(B
))
326 filters
= isl
.union_set_list(A
).add(B
)
327 node
= node
.insert_sequence(filters
)
329 f_A
= isl
.multi_union_pw_aff("[ { A[i] -> [i] } ]")
332 node
= node
.insert_partial_schedule(f_A
)
333 node
= node
.member_set_coincident(0, True)
334 node
= node
.ancestor(2)
336 f_B
= isl
.multi_union_pw_aff("[ { B[i] -> [i] } ]")
339 node
= node
.insert_partial_schedule(f_B
)
340 node
= node
.ancestor(2)
342 return node
.schedule()
345 # Test basic schedule tree functionality.
347 # In particular, create a simple schedule tree and
348 # - check that the root node is a domain node
349 # - test map_descendant_bottom_up
350 # - test foreach_descendant_top_down
351 # - test every_descendant
353 def test_schedule_tree():
354 schedule
= construct_schedule_tree()
355 root
= schedule
.root()
357 assert type(root
) == isl
.schedule_node_domain
365 root
= root
.map_descendant_bottom_up(inc_count
)
369 raise Exception("fail")
374 root
.map_descendant_bottom_up(fail_map
)
385 root
.foreach_descendant_top_down(inc_count
)
394 root
.foreach_descendant_top_down(inc_count
)
397 def is_not_domain(node
):
398 return type(node
) != isl
.schedule_node_domain
400 assert root
.child(0).every_descendant(is_not_domain
)
401 assert not root
.every_descendant(is_not_domain
)
404 raise Exception("fail")
408 root
.every_descendant(fail
)
413 domain
= root
.domain()
414 filters
= [isl
.union_set("{}")]
416 def collect_filters(node
):
417 if type(node
) == isl
.schedule_node_filter
:
418 filters
[0] = filters
[0].union(node
.filter())
421 root
.every_descendant(collect_filters
)
422 assert domain
.is_equal(filters
[0])
425 # Test marking band members for unrolling.
426 # "schedule" is the schedule created by construct_schedule_tree.
427 # It schedules two statements, with 10 and 20 instances, respectively.
428 # Unrolling all band members therefore results in 30 at-domain calls
429 # by the AST generator.
431 def test_ast_build_unroll(schedule
):
432 root
= schedule
.root()
434 def mark_unroll(node
):
435 if type(node
) == isl
.schedule_node_band
:
436 node
= node
.member_set_ast_loop_unroll(0)
439 root
= root
.map_descendant_bottom_up(mark_unroll
)
440 schedule
= root
.schedule()
444 def inc_count_ast(node
, build
):
448 build
= isl
.ast_build()
449 build
= build
.set_at_each_domain(inc_count_ast
)
450 ast
= build
.node_from(schedule
)
451 assert count_ast
[0] == 30
454 # Test basic AST generation from a schedule tree.
456 # In particular, create a simple schedule tree and
457 # - generate an AST from the schedule tree
458 # - test at_each_domain
461 def test_ast_build():
462 schedule
= construct_schedule_tree()
466 def inc_count_ast(node
, build
):
470 build
= isl
.ast_build()
471 build_copy
= build
.set_at_each_domain(inc_count_ast
)
472 ast
= build
.node_from(schedule
)
473 assert count_ast
[0] == 0
475 ast
= build_copy
.node_from(schedule
)
476 assert count_ast
[0] == 2
479 ast
= build
.node_from(schedule
)
480 assert count_ast
[0] == 2
485 def fail_inc_count_ast(node
, build
):
486 count_ast_fail
[0] += 1
488 raise Exception("fail")
491 build
= isl
.ast_build()
492 build
= build
.set_at_each_domain(fail_inc_count_ast
)
495 ast
= build
.node_from(schedule
)
499 assert count_ast_fail
[0] > 0
501 build_copy
= build_copy
.set_at_each_domain(inc_count_ast
)
503 ast
= build_copy
.node_from(schedule
)
504 assert count_ast
[0] == 2
505 count_ast_fail
[0] = 0
507 ast
= build
.node_from(schedule
)
508 assert count_ast_fail
[0] == 2
510 test_ast_build_unroll(schedule
)
513 # Test basic AST expression generation from an affine expression.
515 def test_ast_build_expr():
516 pa
= isl
.pw_aff("[n] -> { [n + 1] }")
517 build
= isl
.ast_build
.from_context(pa
.domain())
519 op
= build
.expr_from(pa
)
520 assert type(op
) == isl
.ast_expr_op_add
521 assert op
.n_arg() == 2
524 # Test the isl Python interface
527 # - Object construction
528 # - Different parameter types
529 # - Different return types
531 # - Foreach functions
532 # - Foreach SCC function
537 # - AST expression generation
549 test_ast_build_expr()