Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / clang / utils / ABITest / Enumeration.py
blob005b104a337de3779522ab0a2ce608c5041e12f4
1 """Utilities for enumeration of finite and countably infinite sets.
2 """
3 from __future__ import absolute_import, division, print_function
5 ###
6 # Countable iteration
8 # Simplifies some calculations
class Aleph0(int):
    """Singleton standing in for a countably infinite size.

    The instance is an ``int`` subclass whose underlying integer value is 0,
    so every operation that must behave "infinitely" is overridden here:

    * ordering: the instance is greater than every other value and equal
      only to itself;
    * addition and non-zero multiplication/division/power absorb the other
      operand and return the instance;
    * subtraction raises ``ValueError``.
    """

    _singleton = None

    def __new__(cls):
        # Always hand back the same object so callers can test with ``is``.
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self):
        return "<aleph0>"

    def __str__(self):
        return "inf"

    # Python 3 never calls __cmp__, so without the rich comparisons below
    # the inherited int comparisons would treat the instance as plain 0.
    # Because this is a subclass of int, these methods also take priority
    # as the reflected operation when a plain int is the left operand
    # (e.g. ``5 < aleph0``).
    def __eq__(self, b):
        return b is self

    def __ne__(self, b):
        return b is not self

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return b is self

    def __gt__(self, b):
        return b is not self

    def __ge__(self, b):
        return True

    def __hash__(self):
        # Defining __eq__ would otherwise make the class unhashable.
        return int.__hash__(self)

    def __cmp__(self, b):
        # Python 2 three-way comparison, kept consistent with the rich
        # comparison methods above.
        return 0 if b is self else 1

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")

    __rsub__ = __sub__

    def __add__(self, b):
        return self

    __radd__ = __add__

    def __mul__(self, b):
        if b == 0:
            return b
        return self

    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0:
            raise ZeroDivisionError
        return self

    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    # Was misspelled "__rtuediv__", which Python never looks up.
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0:
            return 1
        return self


aleph0 = Aleph0()
def base(line):
    """Return line * (line + 1) // 2, the triangular number for `line`.

    This is the number of pairs lying on diagonals 0 .. line - 1, i.e. the
    enumeration index at which diagonal `line` starts.
    """
    successor = line + 1
    return (line * successor) // 2
def pairToN(pair):
    """Map a pair (x, y) to its index in the diagonal enumeration.

    The pair lies on diagonal x + y, at offset y along that diagonal.
    """
    first, second = pair
    diagonal = first + second
    return base(diagonal) + second
def getNthPairInfo(N):
    """Return (line, index) locating the N-th pair in diagonal order.

    `line` is the largest value with base(line) <= N and `index` is
    N - base(line), the offset of the pair along that diagonal.

    Raises ValueError if N is negative; previously a negative N silently
    produced a meaningless (1, N - 1) result.
    """
    if N < 0:
        raise ValueError("Invalid input (out of bounds)")

    # Avoid various singularities
    if N == 0:
        return (0, 0)

    # Gallop to find bounds for line: double the upper candidate until it
    # starts past N.
    lo = 1
    hi = 2
    while base(hi) <= N:
        lo = hi
        hi = lo << 1

    # Binary search for starting line, keeping base(lo) <= N < base(hi).
    while lo + 1 != hi:
        mid = (lo + hi) >> 1
        if base(mid) <= N:
            lo = mid
        else:
            hi = mid

    return lo, N - base(lo)
def getNthPair(N):
    """Return the N-th pair (x, y) of the unbounded diagonal enumeration."""
    diagonal, offset = getNthPairInfo(N)
    x = diagonal - offset
    return (x, offset)
def getNthPairBounded(N, W=aleph0, H=aleph0, useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    W and H may each be a positive integer or aleph0.  When both are
    aleph0 the plain diagonal enumeration is used.  Otherwise, with
    useDivmod set, the pair is computed with a simple divmod along the
    smaller dimension; without it a diagonal is slid across the rectangle.

    Raises ValueError if either bound is not positive, or if N is negative
    or not smaller than W * H.
    """
    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    # Reject negative N as well, matching getNthPairVariableBounds.
    if N < 0 or N >= W * H:
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W < H
    if H < W:
        x, y = getNthPairBounded(N, H, W, useDivmod=useDivmod)
        return y, x

    if useDivmod:
        return N % W, N // W

    # Conceptually we want to slide a diagonal line across a
    # rectangle. This gives more interesting results for large
    # bounds than using divmod.

    # If in lower left, just return as usual
    cornerSize = base(W)
    if N < cornerSize:
        return getNthPair(N)

    # Otherwise if in upper right, subtract from corner
    if H is not aleph0:
        M = W * H - N - 1
        if M < cornerSize:
            x, y = getNthPair(M)
            return (W - 1 - x, H - 1 - y)

    # Otherwise, compile line and index from number of times we
    # wrap.
    N = N - cornerSize
    index, offset = N % W, N // W
    # p = (W-1, 1+offset) + (-1,1)*index
    return (W - 1 - index, 1 + offset + index)
def getNthPairBoundedChecked(
    N, W=aleph0, H=aleph0, useDivmod=False, GNP=getNthPairBounded
):
    """Call GNP(N, W, H, useDivmod) and assert the pair is within bounds."""
    result = GNP(N, W, H, useDivmod)
    x, y = result
    assert 0 <= x < W
    assert 0 <= y < H
    return x, y
def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for every element.

    With useLeftToRight the elements are peeled off one at a time;
    otherwise the tuple is split in halves that are enumerated as a pair
    and expanded recursively.
    """
    if useLeftToRight:
        # Each step yields one element and the index used for the next step.
        elements = []
        for _ in range(W):
            element, N = getNthPairBounded(N, H)
            elements.append(element)
        return tuple(elements)

    if W == 0:
        return ()
    if W == 1:
        return (N,)
    if W == 2:
        return getNthPairBounded(N, H, H)

    leftWidth = W // 2
    rightWidth = W - leftWidth
    leftN, rightN = getNthPairBounded(N, H**leftWidth, H**rightWidth)
    leftPart = getNthNTuple(leftN, leftWidth, H=H, useLeftToRight=useLeftToRight)
    rightPart = getNthNTuple(rightN, rightWidth, H=H, useLeftToRight=useLeftToRight)
    return leftPart + rightPart
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Call GNT(N, W, H, useLeftToRight) and assert the tuple's shape.

    Asserts that the result has exactly W elements, each below H.
    """
    result = GNT(N, W, H, useLeftToRight)
    assert len(result) == W
    assert all(element < H for element in result)
    return result
def getNthTuple(
    N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False
):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and for y in x, 0 <=
    y < maxElement.

    Index 0 is the empty tuple.  Raises NotImplementedError when
    maxElement is finite but maxSize is aleph0.
    """
    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()

    remaining = N - 1
    if maxElement is aleph0:
        size, index = getNthPairBounded(remaining, maxSize, useDivmod=useDivmod)
    else:
        if maxSize is aleph0:
            raise NotImplementedError("Max element size without max size unhandled")
        # One bound per tuple length 1 .. maxSize: the number of tuples of
        # that length.
        bounds = [maxElement**i for i in range(1, maxSize + 1)]
        size, index = getNthPairVariableBounds(remaining, bounds)

    # `size` is zero-based, so the tuple has size + 1 elements.
    return getNthNTuple(index, size + 1, maxElement, useLeftToRight=useLeftToRight)
def getNthTupleChecked(
    N,
    maxSize=aleph0,
    maxElement=aleph0,
    useDivmod=False,
    useLeftToRight=False,
    GNT=getNthTuple,
):
    """Call GNT with the given arguments and assert the tuple's bounds.

    Asserts that the result has at most maxSize elements, each below
    maxElement.
    """
    # FIXME: maxsize is inclusive
    result = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(result) <= maxSize
    assert all(element < maxElement for element in result)
    return result
def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    Raises ValueError if bounds is empty or N is outside
    0 <= N < sum(bounds).
    """
    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # Visit the columns from the smallest bound to the largest.  Each step
    # covers a rectangle spanning all columns still active (W of them) and
    # the rows between the previous bound and the current one (H of them).
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i, index in enumerate(active):
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            # aleph0 does not support subtraction; the level is unbounded.
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W * H
        if N < levelSize:  # Found the level
            idelta, delta = getNthPairBounded(N, W, H)
            return active[i + idelta], prevLevel + delta
        N -= levelSize
        prevLevel = level

    # The range check above should make this unreachable.
    raise RuntimeError("Unexpected loop completion")
def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Call GNVP(N, bounds) and assert the pair respects the bounds."""
    result = GNVP(N, bounds)
    x, y = result
    assert 0 <= x < len(bounds)
    assert 0 <= y < bounds[x]
    return (x, y)
def testPairs():
    """Print the diagonal and divmod enumerations of a 3 x 6 rectangle.

    Each index is printed with both pairs, followed by two grids showing
    where each index landed (rows printed from highest y to lowest).
    """
    W = 3
    H = 6
    a = [[" "] * 10 for _ in range(10)]
    b = [[" "] * 10 for _ in range(10)]
    for i in range(min(W * H, 40)):
        x, y = getNthPairBounded(i, W, H)
        x2, y2 = getNthPairBounded(i, W, H, useDivmod=True)
        print(i, (x, y), (x2, y2))
        label = "%2d" % i
        a[y][x] = label
        b[y2][x2] = label

    for title, grid in (("-- a --", a), ("-- b --", b)):
        print(title)
        for ln in reversed(grid):
            # Skip rows that received no entries.
            if "".join(ln).strip():
                print(" ".join(ln))
def testPairsVB():
    """Print the first indices of a variable-bounds enumeration.

    Uses a mix of finite and aleph0 bounds, prints each index with its
    pair, then a grid showing where each index landed (rows printed from
    highest y to lowest).
    """
    bounds = [2, 2, 4, aleph0, 5, aleph0]
    # The original also built a second grid that was never used.
    a = [[" " for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds), 40)):
        x, y = getNthPairVariableBounds(i, bounds)
        print(i, (x, y))
        a[y][x] = "%2d" % i

    print("-- a --")
    for ln in a[::-1]:
        # Skip rows that received no entries.
        if "".join(ln).strip():
            print(" ".join(ln))
# Toggle to use checked versions of enumeration routines.  The checked
# wrappers captured the unchecked functions as default arguments when they
# were defined, so rebinding the public names here does not recurse.
if False:
    (
        getNthPairVariableBounds,
        getNthPairBounded,
        getNthNTuple,
        getNthTuple,
    ) = (
        getNthPairVariableBoundsChecked,
        getNthPairBoundedChecked,
        getNthNTupleChecked,
        getNthTupleChecked,
    )

# Run the demo printers when executed as a script.
if __name__ == "__main__":
    for demo in (testPairs, testPairsVB):
        demo()