[flang] Use object before converts in fir.dispatch (#68589)
[llvm-project.git] / polly / utils / pyscop / isl.py
blob5eaf7798e20b97bed06dcdc3f9b52664a41fcbe6
1 from ctypes import *
# Handle to the isl shared library; every ``isl.isl_*`` symbol used in this
# module is looked up on this object.
isl = CDLL("libisl.so")
class Context:
    """Python-side wrapper for a context handle returned by ``isl_ctx_alloc``.

    Every wrapper registers itself in ``Context.instances`` under its raw
    handle so that ``from_ptr`` can map a handle coming back from the library
    to the wrapper that owns it.
    """

    # Shared context, created on the first call to getDefaultInstance().
    defaultInstance = None
    # Raw handle -> Context wrapper, filled in by __init__.
    instances = {}

    def __init__(self):
        """Allocate a new context and record it in ``Context.instances``."""
        ptr = isl.isl_ctx_alloc()
        self.ptr = ptr
        Context.instances[ptr] = self

    def __del__(self):
        """Hand the context back to ``isl_ctx_free``."""
        isl.isl_ctx_free(self)

    def from_param(self):
        """Return the raw handle (ctypes argument-conversion hook)."""
        return self.ptr

    @staticmethod
    def from_ptr(ptr):
        """Return the wrapper registered for ``ptr``.

        Raises:
            KeyError: if no Context was created for that handle.
        """
        return Context.instances[ptr]

    @staticmethod
    def getDefaultInstance():
        """Return the shared default context, creating it on first use."""
        # Identity test: None is a singleton and ``==`` could be intercepted
        # by a custom __eq__.
        if Context.defaultInstance is None:
            Context.defaultInstance = Context()

        return Context.defaultInstance
class IslObject:
    """Common base for the wrappers around isl objects.

    Subclasses provide ``isl_name()``; the library symbols used here are then
    looked up as ``isl_<name>_<operation>`` (see ``get_isl_method``).
    """

    def __init__(self, string="", ctx=None, ptr=None):
        """Wrap an existing handle or parse a new object from ``string``.

        Args:
            string: textual description handed to ``isl_<name>_read_from_str``;
                only used when ``ptr`` is not given.
            ctx: Context to parse in; defaults to the shared default context.
            ptr: existing raw handle to wrap. When given, the context is
                queried from the object through ``isl_<name>_get_ctx``.
        """
        self.initialize_isl_methods()
        if ptr is not None:
            self.ptr = ptr
            self.ctx = self.get_isl_method("get_ctx")(self)
            return

        if ctx is None:
            ctx = Context.getDefaultInstance()

        # The parameter is declared as c_char_p, which only accepts byte
        # strings; encode text so callers may pass an ordinary str.
        if not isinstance(string, bytes):
            string = string.encode("utf-8")

        self.ctx = ctx
        self.ptr = self.get_isl_method("read_from_str")(ctx, string, -1)

    def __del__(self):
        """Release the wrapped object through ``isl_<name>_free``."""
        # __init__ may have raised before a handle was stored; there is then
        # nothing to free.
        if hasattr(self, "ptr"):
            self.get_isl_method("free")(self)

    def from_param(self):
        """Return the raw handle (ctypes argument-conversion hook)."""
        return self.ptr

    @property
    def context(self):
        """The Context this object belongs to."""
        return self.ctx

    def _to_text(self):
        """Print the object into a fresh string Printer and return the text."""
        p = Printer(self.ctx)
        self.to_printer(p)
        text = p.getString()
        # A c_char_p result arrives as bytes where bytes and str differ, but
        # __repr__/__str__ must return str.
        if isinstance(text, bytes) and not isinstance(text, str):
            text = text.decode("utf-8")
        return text

    def __repr__(self):
        return self._to_text()

    def __str__(self):
        return self._to_text()

    @staticmethod
    def isl_name():
        """Return the name fragment used to build library symbol names."""
        return "No isl name available"

    def initialize_isl_methods(self):
        """Declare the ctypes prototypes for this class, once per class."""
        cls = self.__class__
        # Look only at the class's own namespace so that a flag set on a base
        # class cannot suppress the setup of a subclass with another name.
        if cls.__dict__.get("initialized"):
            return

        self.get_isl_method("read_from_str").argtypes = [Context, c_char_p, c_int]
        self.get_isl_method("copy").argtypes = [cls]
        self.get_isl_method("copy").restype = c_int
        self.get_isl_method("free").argtypes = [cls]
        self.get_isl_method("get_ctx").argtypes = [cls]
        self.get_isl_method("get_ctx").restype = Context.from_ptr
        getattr(isl, "isl_printer_print_" + self.isl_name()).argtypes = [
            Printer,
            cls,
        ]
        # Set the flag last, and under the spelling the guard above tests, so
        # a failed setup is retried and a successful one is not repeated.
        cls.initialized = True

    def get_isl_method(self, name):
        """Return the library function ``isl_<isl_name>_<name>``."""
        return getattr(isl, "isl_" + self.isl_name() + "_" + name)

    def to_printer(self, printer):
        """Print this object into ``printer`` via ``isl_printer_print_<name>``."""
        getattr(isl, "isl_printer_print_" + self.isl_name())(printer, self)
class BSet(IslObject):
    """Wrapper for objects handled by the ``isl_basic_set_*`` functions."""

    @staticmethod
    def isl_name():
        """Return the symbol-name fragment for this type."""
        return "basic_set"

    @staticmethod
    def from_ptr(ptr):
        """Wrap a raw handle in a BSet; a false handle yields ``None``."""
        return BSet(ptr=ptr) if ptr else None
class Set(IslObject):
    """Wrapper for objects handled by the ``isl_set_*`` functions."""

    @staticmethod
    def isl_name():
        """Return the symbol-name fragment for this type."""
        return "set"

    @staticmethod
    def from_ptr(ptr):
        """Wrap a raw handle in a Set; a false handle yields ``None``."""
        return Set(ptr=ptr) if ptr else None
class USet(IslObject):
    """Wrapper for objects handled by the ``isl_union_set_*`` functions."""

    @staticmethod
    def isl_name():
        """Return the symbol-name fragment for this type."""
        return "union_set"

    @staticmethod
    def from_ptr(ptr):
        """Wrap a raw handle in a USet; a false handle yields ``None``."""
        return USet(ptr=ptr) if ptr else None
class BMap(IslObject):
    """Wrapper for objects handled by the ``isl_basic_map_*`` functions."""

    @staticmethod
    def isl_name():
        """Return the symbol-name fragment for this type."""
        return "basic_map"

    @staticmethod
    def from_ptr(ptr):
        """Wrap a raw handle in a BMap; a false handle yields ``None``."""
        return BMap(ptr=ptr) if ptr else None

    def __mul__(self, set):
        """Make ``bmap * set`` a shorthand for ``bmap.intersect_domain(set)``."""
        restricted = self.intersect_domain(set)
        return restricted
class Map(IslObject):
    """Wrapper for objects handled by the ``isl_map_*`` functions."""

    @staticmethod
    def isl_name():
        """Return the symbol-name fragment for this type."""
        return "map"

    @staticmethod
    def from_ptr(ptr):
        """Wrap a raw handle in a Map; a false handle yields ``None``."""
        return Map(ptr=ptr) if ptr else None

    def __mul__(self, set):
        """Make ``map * set`` a shorthand for ``map.intersect_domain(set)``."""
        restricted = self.intersect_domain(set)
        return restricted

    # Each lex_* helper passes a copy of ``dim`` (made with isl_dim_copy) to
    # the matching isl_map_lex_* function and returns that function's result.

    @staticmethod
    def lex_lt(dim):
        """Return ``isl_map_lex_lt`` applied to a copy of ``dim``."""
        return isl.isl_map_lex_lt(isl.isl_dim_copy(dim))

    @staticmethod
    def lex_le(dim):
        """Return ``isl_map_lex_le`` applied to a copy of ``dim``."""
        return isl.isl_map_lex_le(isl.isl_dim_copy(dim))

    @staticmethod
    def lex_gt(dim):
        """Return ``isl_map_lex_gt`` applied to a copy of ``dim``."""
        return isl.isl_map_lex_gt(isl.isl_dim_copy(dim))

    @staticmethod
    def lex_ge(dim):
        """Return ``isl_map_lex_ge`` applied to a copy of ``dim``."""
        return isl.isl_map_lex_ge(isl.isl_dim_copy(dim))
class UMap(IslObject):
    """Wrapper for objects handled by the ``isl_union_map_*`` functions."""

    @staticmethod
    def isl_name():
        """Return the symbol-name fragment for this type."""
        return "union_map"

    @staticmethod
    def from_ptr(ptr):
        """Wrap a raw handle in a UMap; a false handle yields ``None``."""
        return UMap(ptr=ptr) if ptr else None
class Dim(IslObject):
    """Wrapper for objects handled by the ``isl_dim_*`` functions."""

    @staticmethod
    def from_ptr(ptr):
        """Wrap a raw handle in a Dim; a false handle yields ``None``."""
        if not ptr:
            return
        return Dim(ptr=ptr)

    @staticmethod
    def isl_name():
        """Return the symbol-name fragment for this type."""
        return "dim"

    def initialize_isl_methods(self):
        """Declare the ctypes prototypes for Dim, once per class.

        Unlike the base-class version this sets up neither ``read_from_str``
        nor a printer function.
        """
        cls = self.__class__
        # Look only at the class's own namespace so a flag on a base class
        # cannot suppress this setup.
        if cls.__dict__.get("initialized"):
            return

        self.get_isl_method("copy").argtypes = [cls]
        self.get_isl_method("copy").restype = c_int
        self.get_isl_method("free").argtypes = [cls]
        self.get_isl_method("get_ctx").argtypes = [cls]
        self.get_isl_method("get_ctx").restype = Context.from_ptr
        # Set the flag last, and under the spelling the guard above tests, so
        # the setup really runs only once.
        cls.initialized = True

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Summarise the sizes reported by ``isl_dim_size``."""
        # The second argument selects which size is queried; 1, 2 and 3 are
        # the values this module uses for the parameter, input and output
        # sizes respectively.
        dimParam = isl.isl_dim_size(self, 1)
        dimIn = isl.isl_dim_size(self, 2)
        dimOut = isl.isl_dim_size(self, 3)

        # A non-zero input size is reported in the In/Out form; otherwise only
        # the output size is shown, labelled "Set".
        if dimIn:
            return "<dim In:%s, Out:%s, Param:%s>" % (dimIn, dimOut, dimParam)

        return "<dim Set:%s, Param:%s>" % (dimOut, dimParam)
class Printer:
    """Wrapper around a string printer created by ``isl_printer_to_str``."""

    # Output format codes, intended for setFormat().
    FORMAT_ISL = 0
    FORMAT_POLYLIB = 1
    FORMAT_POLYLIB_CONSTRAINTS = 2
    FORMAT_OMEGA = 3
    FORMAT_C = 4
    FORMAT_LATEX = 5
    FORMAT_EXT_POLYLIB = 6

    def __init__(self, ctx=None):
        """Create a printer in ``ctx`` (the default context when omitted)."""
        if ctx is None:
            ctx = Context.getDefaultInstance()

        self.ctx = ctx
        self.ptr = isl.isl_printer_to_str(ctx)

    def setFormat(self, format):
        """Select the output format; ``format`` is one of the FORMAT_* codes."""
        # The library call returns the printer handle to use from now on.
        self.ptr = isl.isl_printer_set_output_format(self, format)

    def from_param(self):
        """Return the raw handle (ctypes argument-conversion hook)."""
        return self.ptr

    def __del__(self):
        """Release the printer through ``isl_printer_free``."""
        isl.isl_printer_free(self)

    def getString(self):
        """Return the text printed so far as a ``str``."""
        text = isl.isl_printer_get_str(self)
        # A c_char_p result arrives as bytes where bytes and str differ;
        # decode it so callers always get text.
        if isinstance(text, bytes) and not isinstance(text, str):
            text = text.decode("utf-8")
        return text
# Table of library functions exposed as methods on the wrapper classes.
# Each row is (operation, owning class, operand types, result): the symbol
# called is ``isl_<owning class isl_name>_<operation>``, and ``result`` is
# either c_int or the wrapper class whose from_ptr() wraps the return value.
# The methods generated from this table pass copies of their operands.
functions = [
    # Unary properties
    ("is_empty", BSet, [BSet], c_int),
    ("is_empty", Set, [Set], c_int),
    ("is_empty", USet, [USet], c_int),
    ("is_empty", BMap, [BMap], c_int),
    ("is_empty", Map, [Map], c_int),
    ("is_empty", UMap, [UMap], c_int),
    # ("is_universe", Set, [Set], c_int),
    # ("is_universe", Map, [Map], c_int),
    ("is_single_valued", Map, [Map], c_int),
    ("is_bijective", Map, [Map], c_int),
    ("is_wrapping", BSet, [BSet], c_int),
    ("is_wrapping", Set, [Set], c_int),
    # Binary properties
    ("is_equal", BSet, [BSet, BSet], c_int),
    ("is_equal", Set, [Set, Set], c_int),
    ("is_equal", USet, [USet, USet], c_int),
    ("is_equal", BMap, [BMap, BMap], c_int),
    ("is_equal", Map, [Map, Map], c_int),
    ("is_equal", UMap, [UMap, UMap], c_int),
    # is_disjoint missing
    # ("is_subset", BSet, [BSet, BSet], c_int),
    ("is_subset", Set, [Set, Set], c_int),
    ("is_subset", USet, [USet, USet], c_int),
    ("is_subset", BMap, [BMap, BMap], c_int),
    ("is_subset", Map, [Map, Map], c_int),
    ("is_subset", UMap, [UMap, UMap], c_int),
    # ("is_strict_subset", BSet, [BSet, BSet], c_int),
    ("is_strict_subset", Set, [Set, Set], c_int),
    ("is_strict_subset", USet, [USet, USet], c_int),
    ("is_strict_subset", BMap, [BMap, BMap], c_int),
    ("is_strict_subset", Map, [Map, Map], c_int),
    ("is_strict_subset", UMap, [UMap, UMap], c_int),
    # Unary Operations
    ("complement", Set, [Set], Set),
    ("reverse", BMap, [BMap], BMap),
    ("reverse", Map, [Map], Map),
    ("reverse", UMap, [UMap], UMap),
    # Projection missing
    ("range", BMap, [BMap], BSet),
    ("range", Map, [Map], Set),
    ("range", UMap, [UMap], USet),
    ("domain", BMap, [BMap], BSet),
    ("domain", Map, [Map], Set),
    ("domain", UMap, [UMap], USet),
    ("identity", Set, [Set], Map),
    ("identity", USet, [USet], UMap),
    ("deltas", BMap, [BMap], BSet),
    ("deltas", Map, [Map], Set),
    ("deltas", UMap, [UMap], USet),
    ("coalesce", Set, [Set], Set),
    ("coalesce", USet, [USet], USet),
    ("coalesce", Map, [Map], Map),
    ("coalesce", UMap, [UMap], UMap),
    ("detect_equalities", BSet, [BSet], BSet),
    ("detect_equalities", Set, [Set], Set),
    ("detect_equalities", USet, [USet], USet),
    ("detect_equalities", BMap, [BMap], BMap),
    ("detect_equalities", Map, [Map], Map),
    ("detect_equalities", UMap, [UMap], UMap),
    ("convex_hull", Set, [Set], Set),
    ("convex_hull", Map, [Map], Map),
    ("simple_hull", Set, [Set], Set),
    ("simple_hull", Map, [Map], Map),
    ("affine_hull", BSet, [BSet], BSet),
    ("affine_hull", Set, [Set], BSet),
    ("affine_hull", USet, [USet], USet),
    ("affine_hull", BMap, [BMap], BMap),
    ("affine_hull", Map, [Map], BMap),
    ("affine_hull", UMap, [UMap], UMap),
    ("polyhedral_hull", Set, [Set], Set),
    ("polyhedral_hull", USet, [USet], USet),
    ("polyhedral_hull", Map, [Map], Map),
    ("polyhedral_hull", UMap, [UMap], UMap),
    # Power missing
    # Transitive closure missing
    # Reaching path lengths missing
    ("wrap", BMap, [BMap], BSet),
    ("wrap", Map, [Map], Set),
    ("wrap", UMap, [UMap], USet),
    # unwrap is looked up on the set classes, so its operand is a set.
    ("unwrap", BSet, [BSet], BMap),
    ("unwrap", Set, [Set], Map),
    ("unwrap", USet, [USet], UMap),
    ("flatten", Set, [Set], Set),
    ("flatten", Map, [Map], Map),
    ("flatten_map", Set, [Set], Map),
    # Dimension manipulation missing
    # Binary Operations
    ("intersect", BSet, [BSet, BSet], BSet),
    ("intersect", Set, [Set, Set], Set),
    ("intersect", USet, [USet, USet], USet),
    ("intersect", BMap, [BMap, BMap], BMap),
    ("intersect", Map, [Map, Map], Map),
    ("intersect", UMap, [UMap, UMap], UMap),
    ("intersect_domain", BMap, [BMap, BSet], BMap),
    ("intersect_domain", Map, [Map, Set], Map),
    ("intersect_domain", UMap, [UMap, USet], UMap),
    ("intersect_range", BMap, [BMap, BSet], BMap),
    ("intersect_range", Map, [Map, Set], Map),
    ("intersect_range", UMap, [UMap, USet], UMap),
    ("union", BSet, [BSet, BSet], Set),
    ("union", Set, [Set, Set], Set),
    ("union", USet, [USet, USet], USet),
    ("union", BMap, [BMap, BMap], Map),
    ("union", Map, [Map, Map], Map),
    ("union", UMap, [UMap, UMap], UMap),
    ("subtract", Set, [Set, Set], Set),
    ("subtract", Map, [Map, Map], Map),
    ("subtract", USet, [USet, USet], USet),
    ("subtract", UMap, [UMap, UMap], UMap),
    ("apply", BSet, [BSet, BMap], BSet),
    ("apply", Set, [Set, Map], Set),
    ("apply", USet, [USet, UMap], USet),
    ("apply_domain", BMap, [BMap, BMap], BMap),
    ("apply_domain", Map, [Map, Map], Map),
    ("apply_domain", UMap, [UMap, UMap], UMap),
    ("apply_range", BMap, [BMap, BMap], BMap),
    ("apply_range", Map, [Map, Map], Map),
    ("apply_range", UMap, [UMap, UMap], UMap),
    ("gist", BSet, [BSet, BSet], BSet),
    ("gist", Set, [Set, Set], Set),
    ("gist", USet, [USet, USet], USet),
    ("gist", BMap, [BMap, BMap], BMap),
    ("gist", Map, [Map, Map], Map),
    ("gist", UMap, [UMap, UMap], UMap),
    # Lexicographic Optimizations
    # partial_lexmin missing
    ("lexmin", BSet, [BSet], BSet),
    ("lexmin", Set, [Set], Set),
    ("lexmin", USet, [USet], USet),
    ("lexmin", BMap, [BMap], BMap),
    ("lexmin", Map, [Map], Map),
    ("lexmin", UMap, [UMap], UMap),
    ("lexmax", BSet, [BSet], BSet),
    ("lexmax", Set, [Set], Set),
    ("lexmax", USet, [USet], USet),
    ("lexmax", BMap, [BMap], BMap),
    ("lexmax", Map, [Map], Map),
    ("lexmax", UMap, [UMap], UMap),
    # Undocumented
    ("lex_lt_union_set", USet, [USet, USet], UMap),
    ("lex_le_union_set", USet, [USet, USet], UMap),
    ("lex_gt_union_set", USet, [USet, USet], UMap),
    ("lex_ge_union_set", USet, [USet, USet], UMap),
]
# Same row layout as ``functions``, but the methods generated from this table
# pass their operands to the library directly instead of copying them first.
keep_functions = [
    # Unary properties
    ("get_dim", _kind, [_kind], Dim)
    for _kind in (BSet, Set, USet, BMap, Map, UMap)
]
def addIslFunction(object, name):
    """Attach ``isl_<object.isl_name()>_<name>`` to ``object`` as a method.

    The generated method copies its operands before calling the library
    function (see islFunctionOneOp / islFunctionTwoOp). The number of operands
    is taken from the length of the function's ``argtypes``, which must have
    been set beforehand.

    Args:
        object: wrapper class that receives the new attribute.
        name: operation name; also the name of the new attribute.

    Raises:
        ValueError: if the function is declared with neither one nor two
            arguments.
    """
    functionName = "isl_" + object.isl_name() + "_" + name
    islFunction = getattr(isl, functionName)
    arity = len(islFunction.argtypes)
    if arity == 1:
        f = lambda a: islFunctionOneOp(islFunction, a)
    elif arity == 2:
        f = lambda a, b: islFunctionTwoOp(islFunction, a, b)
    else:
        raise ValueError(
            "%s: unsupported number of arguments (%d)" % (functionName, arity)
        )
    # setattr() instead of writing into object.__dict__: a class namespace can
    # be a read-only mapping that rejects item assignment.
    setattr(object, name, f)
def islFunctionOneOp(islFunction, ops):
    """Call ``islFunction`` on a copy of ``ops``.

    The copy is made with ``isl_<ops.isl_name()>_copy``.
    """
    copyName = "isl_" + ops.isl_name() + "_copy"
    duplicate = getattr(isl, copyName)(ops)
    return islFunction(duplicate)
def islFunctionTwoOp(islFunction, opOne, opTwo):
    """Call ``islFunction`` on copies of ``opOne`` and ``opTwo``.

    Each copy is made with the ``isl_<isl_name>_copy`` function of the
    operand's own type, first operand first.
    """
    firstCopyName = "isl_" + opOne.isl_name() + "_copy"
    firstCopy = getattr(isl, firstCopyName)(opOne)
    secondCopyName = "isl_" + opTwo.isl_name() + "_copy"
    secondCopy = getattr(isl, secondCopyName)(opTwo)
    return islFunction(firstCopy, secondCopy)
# Declare the prototype of every function listed in ``functions`` and attach
# it to its owning class as a method that copies its operands.
for operation, base, operands, ret in functions:
    functionName = "isl_" + base.isl_name() + "_" + operation
    islFunction = getattr(isl, functionName)

    # Operands are declared as plain ints, one per listed operand type; rows
    # with any other operand count leave argtypes untouched.
    if len(operands) in (1, 2):
        islFunction.argtypes = [c_int] * len(operands)

    # Integer results are returned as they are; anything else is wrapped by
    # the result class's from_ptr().
    islFunction.restype = ret if ret == c_int else ret.from_ptr

    addIslFunction(base, operation)
def addIslFunctionKeep(object, name):
    """Attach ``isl_<object.isl_name()>_<name>`` to ``object`` as a method.

    The generated method passes its operands to the library function directly,
    without copying them (see islFunctionOneOpKeep / islFunctionTwoOpKeep).
    The number of operands is taken from the length of the function's
    ``argtypes``, which must have been set beforehand.

    Args:
        object: wrapper class that receives the new attribute.
        name: operation name; also the name of the new attribute.

    Raises:
        ValueError: if the function is declared with neither one nor two
            arguments.
    """
    functionName = "isl_" + object.isl_name() + "_" + name
    islFunction = getattr(isl, functionName)
    arity = len(islFunction.argtypes)
    if arity == 1:
        f = lambda a: islFunctionOneOpKeep(islFunction, a)
    elif arity == 2:
        f = lambda a, b: islFunctionTwoOpKeep(islFunction, a, b)
    else:
        raise ValueError(
            "%s: unsupported number of arguments (%d)" % (functionName, arity)
        )
    # setattr() instead of writing into object.__dict__: a class namespace can
    # be a read-only mapping that rejects item assignment.
    setattr(object, name, f)
def islFunctionOneOpKeep(islFunction, ops):
    """Call ``islFunction`` on ``ops`` itself (no copy) and return the result."""
    result = islFunction(ops)
    return result
def islFunctionTwoOpKeep(islFunction, opOne, opTwo):
    """Call ``islFunction`` on both operands themselves (no copies)."""
    result = islFunction(opOne, opTwo)
    return result
# Declare the prototype of every function listed in ``keep_functions`` and
# attach it to its owning class as a method that does not copy its operands.
for operation, base, operands, ret in keep_functions:
    functionName = "isl_" + base.isl_name() + "_" + operation
    islFunction = getattr(isl, functionName)

    # Operands are declared as plain ints, one per listed operand type; rows
    # with any other operand count leave argtypes untouched.
    if len(operands) in (1, 2):
        islFunction.argtypes = [c_int] * len(operands)

    # Integer results are returned as they are; anything else is wrapped by
    # the result class's from_ptr().
    islFunction.restype = ret if ret == c_int else ret.from_ptr

    addIslFunctionKeep(base, operation)
# Explicit ctypes prototypes. Declaring a wrapper class in ``argtypes`` makes
# ctypes call that class's from_param() to obtain the raw handle; a callable
# ``restype`` is applied to the integer the library returns.

# Context
isl.isl_ctx_free.argtypes = [Context]

# Basic sets, sets and union sets
isl.isl_basic_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_set_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_set_copy.argtypes = [BSet]
isl.isl_basic_set_copy.restype = c_int
isl.isl_set_copy.argtypes = [Set]
isl.isl_set_copy.restype = c_int
isl.isl_set_free.argtypes = [Set]
isl.isl_basic_set_get_ctx.argtypes = [BSet]
isl.isl_basic_set_get_ctx.restype = Context.from_ptr
isl.isl_set_get_ctx.argtypes = [Set]
isl.isl_set_get_ctx.restype = Context.from_ptr
isl.isl_basic_set_get_dim.argtypes = [BSet]
isl.isl_basic_set_get_dim.restype = Dim.from_ptr
isl.isl_set_get_dim.argtypes = [Set]
isl.isl_set_get_dim.restype = Dim.from_ptr
isl.isl_union_set_get_dim.argtypes = [USet]
isl.isl_union_set_get_dim.restype = Dim.from_ptr

# Basic maps, maps and union maps
isl.isl_basic_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_map_read_from_str.argtypes = [Context, c_char_p, c_int]
isl.isl_basic_map_free.argtypes = [BMap]
isl.isl_map_free.argtypes = [Map]
isl.isl_basic_map_copy.argtypes = [BMap]
isl.isl_basic_map_copy.restype = c_int
isl.isl_map_copy.argtypes = [Map]
isl.isl_map_copy.restype = c_int
isl.isl_basic_map_get_ctx.argtypes = [BMap]
isl.isl_basic_map_get_ctx.restype = Context.from_ptr
isl.isl_map_get_ctx.argtypes = [Map]
isl.isl_map_get_ctx.restype = Context.from_ptr
isl.isl_basic_map_get_dim.argtypes = [BMap]
isl.isl_basic_map_get_dim.restype = Dim.from_ptr
isl.isl_map_get_dim.argtypes = [Map]
isl.isl_map_get_dim.restype = Dim.from_ptr
isl.isl_union_map_get_dim.argtypes = [UMap]
isl.isl_union_map_get_dim.restype = Dim.from_ptr

# Printer
isl.isl_printer_free.argtypes = [Printer]
isl.isl_printer_to_str.argtypes = [Context]
isl.isl_printer_print_basic_set.argtypes = [Printer, BSet]
isl.isl_printer_print_set.argtypes = [Printer, Set]
isl.isl_printer_print_basic_map.argtypes = [Printer, BMap]
isl.isl_printer_print_map.argtypes = [Printer, Map]
isl.isl_printer_get_str.argtypes = [Printer]
isl.isl_printer_get_str.restype = c_char_p
isl.isl_printer_set_output_format.argtypes = [Printer, c_int]
isl.isl_printer_set_output_format.restype = c_int

# Dimensions
isl.isl_dim_size.argtypes = [Dim, c_int]
isl.isl_dim_size.restype = c_int

# Lexicographic order maps; the argument is the integer handle produced by
# isl_dim_copy, hence c_int rather than Dim.
isl.isl_map_lex_lt.argtypes = [c_int]
isl.isl_map_lex_lt.restype = Map.from_ptr
isl.isl_map_lex_le.argtypes = [c_int]
isl.isl_map_lex_le.restype = Map.from_ptr
isl.isl_map_lex_gt.argtypes = [c_int]
isl.isl_map_lex_gt.restype = Map.from_ptr
isl.isl_map_lex_ge.argtypes = [c_int]
isl.isl_map_lex_ge.restype = Map.from_ptr

# Dependence analysis: four integer handles in, four pointers to result slots
# out (see dependences()).
isl.isl_union_map_compute_flow.argtypes = [
    c_int,
    c_int,
    c_int,
    c_int,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
]
def dependences(sink, must_source, may_source, schedule):
    """Run ``isl_union_map_compute_flow`` on copies of the four arguments.

    Args:
        sink, must_source, may_source, schedule: wrapper objects; each one is
            copied with the ``isl_<isl_name>_copy`` function of its own type
            and the copy is what the library receives.

    Returns:
        A 4-tuple ``(must_dep, may_dep, must_no_source, may_no_source)``; the
        first two entries are wrapped with ``UMap.from_ptr`` and the last two
        with ``USet.from_ptr``, so an entry is ``None`` when the library left
        the corresponding slot empty.
    """
    sink = getattr(isl, "isl_" + sink.isl_name() + "_copy")(sink)
    must_source = getattr(isl, "isl_" + must_source.isl_name() + "_copy")(must_source)
    may_source = getattr(isl, "isl_" + may_source.isl_name() + "_copy")(may_source)
    schedule = getattr(isl, "isl_" + schedule.isl_name() + "_copy")(schedule)

    # The library stores a pointer in each result slot, so the slots must be
    # pointer-sized; a c_int slot is too small wherever pointers are wider
    # than int.
    must_dep = c_void_p()
    may_dep = c_void_p()
    must_no_source = c_void_p()
    may_no_source = c_void_p()
    isl.isl_union_map_compute_flow(
        sink,
        must_source,
        may_source,
        schedule,
        byref(must_dep),
        byref(may_dep),
        byref(must_no_source),
        byref(may_no_source),
    )

    # Pass the plain values (None for an empty slot) rather than the ctypes
    # objects, matching the integer handles the other wrappers hold.
    return (
        UMap.from_ptr(must_dep.value),
        UMap.from_ptr(may_dep.value),
        USet.from_ptr(must_no_source.value),
        USet.from_ptr(may_no_source.value),
    )
# Names exported by a star import of this module.
__all__ = [
    "Set",
    "Map",
    "Printer",
    "Context",
]