# sympy_py.py -- from the sympyx.git repository
# (blob d6ac8268158ced5c2ceb11a1b60ee09d7bf59f01)
# Commit message: "Debugging prints added"
#from timeit import default_timer as clock

# Integer type tags stored in ``Basic.type``; every expression class tags
# its instances with one of these so code can dispatch without isinstance().
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """
    Fold the hashes of the elements of *args* into a single hash value.

    The fold is order-sensitive: permuting the elements generally yields a
    different result.
    """
    # make this more robust:
    acc = 2
    for element in args:
        # Note the precedence: this is ((acc + 1001) ^ hash(element)).
        mixed = (acc + 1001) ^ hash(element)
        acc = hash(mixed)
    return acc
class Basic(object):
    """Base class of every expression node.

    Each node stores an integer tag ``type`` (one of the module-level
    constants), an immutable tuple of sub-expressions ``_args`` (exposed as
    ``args``) and a lazily computed, cached hash in ``mhash``.
    """

    def __new__(cls, type, args):
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        obj.mhash = None  # filled in on first __hash__ call
        return obj

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Concrete subclasses override this.  Without a fallback, str() of a
        # plain Basic would go object.__str__ -> __repr__ -> str() and
        # recurse until RecursionError.
        return "%s%r" % (type(self).__name__, self.args)

    def __hash__(self):
        if self.mhash is None:
            self.mhash = hash_seq(self.args)
        return self.mhash

    @property
    def args(self):
        return self._args

    def as_coeff_rest(self):
        """Split into (numeric coefficient, rest); the default coefficient is 1."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Split into (base, exponent); the default exponent is 1."""
        return (self, Integer(1))

    def expand(self):
        return self

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # Python 3 dispatches "/" to __truediv__/__rtruediv__ only; without these
    # aliases x/y raises TypeError there.  On Python 2 (without
    # "from __future__ import division") __div__ is still used.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        o = sympify(o)
        # Objects that are not expressions (strings, None, floats, ...) have
        # no ``type`` attribute; they are simply unequal.
        if not isinstance(o, Basic):
            return False
        if o.type == self.type:
            return self.args == o.args
        return False
class Integer(Basic):
    """Integer constant; the Python integer value is stored in ``i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Hash like the underlying Python int, so Integer(n) and n collide.
        if self.mhash is None:
            self.mhash = hash(self.i)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # Non-expression operands (e.g. strings) have no ``type``; they are
        # unequal rather than raising AttributeError.
        if isinstance(o, Basic) and o.type == INTEGER:
            return self.i == o.i
        return False

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            # Integer + Integer folds immediately.
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            # Integer * Integer folds immediately.
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """Named symbolic variable; identity is determined by ``name`` alone."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        if self.mhash is None:
            self.mhash = hash(self.name)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # Non-expression operands (e.g. the string 'x') are unequal instead
        # of raising AttributeError on ``o.type``.
        if isinstance(o, Basic) and o.type == SYMBOL:
            return self.name == o.name
        return False

    def __str__(self):
        return self.name
class Add(Basic):
    """Sum of terms.

    ``Add(args)`` canonicalizes: nested sums are flattened, like terms are
    collected and numbers are folded into one Integer which, if non-zero, is
    stored first.  ``Add(args, False)`` stores *args* unchanged.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            obj = Basic.__new__(cls, ADD, args)
            obj._args_set = None  # built lazily by freeze_args()
            return obj
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Return the canonical sum of the expressions in *args*.

        The result is an Integer, a single term, or a non-canonicalized Add.
        """
        d = {}  # non-numeric term -> accumulated Integer coefficient
        num = Integer(0)
        for a in args:
            # Flatten one level of nesting: the terms of a sum join directly.
            terms = a.args if a.type == ADD else (a,)
            for b in terms:
                if b.type == INTEGER:
                    num += b
                else:
                    coeff, key = b.as_coeff_rest()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        # Drop terms whose coefficients cancelled (e.g. x - x); otherwise
        # they survive as stray Integer(0) entries ("0 + y").
        terms = [Mul((key, coeff)) for key, coeff in d.items()
                 if not (coeff.type == INTEGER and coeff.i == 0)]
        if num.i != 0:
            terms.insert(0, num)
        if not terms:
            return num
        if len(terms) == 1:
            return terms[0]
        return Add(terms, False)

    def freeze_args(self):
        """Cache ``frozenset(self.args)`` for order-insensitive comparison."""
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        if isinstance(o, Basic) and o.type == ADD:
            # Sums are equal regardless of the order of their terms.
            self.freeze_args()
            o.freeze_args()
            return self._args_set == o._args_set
        return False

    def __str__(self):
        parts = []
        for x in self.args:
            s = str(x)
            # Parenthesize the nested sum itself, not everything printed so far.
            if x.type == ADD:
                s = "(%s)" % s
            parts.append(s)
        return " + ".join(parts)

    def __hash__(self):
        if self.mhash is None:
            # Order-independent: hash the terms sorted by their own hashes.
            # (Hashing frozenset(self.args) was measured to be slower here.)
            self.mhash = hash_seq(sorted(self.args, key=hash))
        return self.mhash

    def expand(self):
        return Add([term.expand() for term in self.args])
class Mul(Basic):
    """Product of factors.

    ``Mul(args)`` canonicalizes: nested products are flattened, powers of
    equal bases are combined and numbers are folded into one Integer which,
    if not 1, is stored first.  ``Mul(args, False)`` stores *args* unchanged.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            obj = Basic.__new__(cls, MUL, args)
            obj._args_set = None  # built lazily by freeze_args()
            return obj
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Return the canonical product of the expressions in *args*."""
        # Fast path for (Mul, Integer), the shape Add.canonicalize produces
        # when it rebuilds coefficient*term: rescale the leading coefficient.
        if len(args) == 2 and args[0].type == MUL and args[1].type == INTEGER:
            a, b = args
            if b.i == 1:
                return a
            if b.i == 0:
                return b
            rest = a.args
            if rest[0].type == INTEGER:
                b = b * rest[0]
                rest = rest[1:]
            # A coefficient that became 1 (e.g. (-x*y) * -1) must be dropped,
            # not stored as a leading "1*".
            if b.i == 1:
                factors = rest
            else:
                factors = (b,) + rest
            if not factors:
                return b
            if len(factors) == 1:
                return factors[0]
            return Mul(factors, False)

        d = {}  # base -> accumulated exponent
        num = Integer(1)
        for a in args:
            # Flatten one level of nesting: the factors of a product join directly.
            parts = a.args if a.type == MUL else (a,)
            for b in parts:
                if b.type == INTEGER:
                    num *= b
                else:
                    key, coeff = b.as_base_exp()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        if num.i == 0 or len(d) == 0:
            return num
        factors = []
        for base, exp in d.items():
            p = Pow((base, exp))
            if p.type == INTEGER:
                # e.g. x * x**-1 -> x**0 -> 1: fold into the coefficient
                # instead of keeping a stray Integer factor.
                num *= p
            else:
                factors.append(p)
        if num.i == 0:
            return num
        if num.i != 1:
            factors.insert(0, num)
        if not factors:
            return num
        if len(factors) == 1:
            return factors[0]
        return Mul(factors, False)

    def __hash__(self):
        if self.mhash is None:
            # In contrast to Add, hashing the frozenset was measured faster
            # than sorting the factor hashes.
            self.freeze_args()
            self.mhash = hash(self._args_set)
        return self.mhash

    def freeze_args(self):
        """Cache ``frozenset(self.args)`` for order-insensitive comparison."""
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        if isinstance(o, Basic) and o.type == MUL:
            self.freeze_args()
            o.freeze_args()
            return self._args_set == o._args_set
        return False

    def as_coeff_rest(self):
        """Split off a leading Integer coefficient, if any."""
        if self.args[0].type == INTEGER:
            return self.as_two_terms()
        return (Integer(1), self)

    def as_two_terms(self):
        """Return (first factor, product of the remaining factors)."""
        args = self.args
        a0 = args[0]
        if len(args) == 2:
            return a0, args[1]
        return (a0, Mul(args[1:], False))

    def __str__(self):
        s = str(self.args[0])
        if self.args[0].type in [ADD, MUL]:
            s = "(%s)" % str(s)
        for x in self.args[1:]:
            if x.type in [ADD, MUL]:
                s = "%s * (%s)" % (s, str(x))
            else:
                s = "%s*%s" % (s, str(x))
        return s

    @classmethod
    def expand_two(cls, a, b):
        """
        Distribute the product a*b over any sums among a and b.

        Both a and b are assumed to be expanded.
        """
        if a.type == ADD and b.type == ADD:
            terms = []
            for x in a.args:
                for y in b.args:
                    terms.append(x*y)
            return Add(terms)
        if a.type == ADD:
            terms = []
            for x in a.args:
                terms.append(x*b)
            return Add(terms)
        if b.type == ADD:
            terms = []
            for y in b.args:
                terms.append(a*y)
            return Add(terms)
        return a*b

    def expand(self):
        a, b = self.as_two_terms()
        r = Mul.expand_two(a, b)
        if r == self:
            # Nothing distributed at this level: expand the factors first.
            a = a.expand()
            b = b.expand()
            return Mul.expand_two(a, b)
        else:
            return r.expand()
class Pow(Basic):
    """Power ``base^exp``; ``args`` is the pair ``(base, exp)``."""

    def __new__(cls, args, canonicalize=True, do_sympify=True):
        if canonicalize == False:
            return Basic.__new__(cls, POW, args)
        if do_sympify:
            args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Simplify x^0, x^1, 0^x, 1^x and (b^e1)^e2; otherwise build a Pow."""
        base, exp = args
        # Test the exponent first so that 0^0 gives 1 (matching Python's
        # 0 ** 0), not 0.
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        if base.type == INTEGER and (base.i == 0 or base.i == 1):
            # NOTE(review): 0^negative also returns 0 here although it is
            # undefined -- confirm whether it should raise instead.
            return base
        if base.type == POW:
            return Pow((base.args[0], base.args[1]*exp))
        return Pow(args, False)

    def __str__(self):
        base, exp = self.args
        s = str(base)
        # Without parentheses (x*y)^2 would print as the ambiguous "x*y^2".
        if base.type in (ADD, MUL, POW) or (base.type == INTEGER and base.i < 0):
            s = "(%s)" % s
        if exp.type in (ADD, MUL, POW):
            return "%s^(%s)" % (s, str(exp))
        return "%s^%s" % (s, str(exp))

    def as_base_exp(self):
        return self.args

    def expand(self):
        """Expand ``(a1 + ... + am)^n`` for a positive integer ``n``.

        The result is assembled with canonicalize=False for speed, so it is
        only fully canonical when the terms of the base are non-numeric
        atoms (e.g. distinct symbols).  Anything else is returned unchanged.
        """
        base, exp = self.args
        # Only positive exponents: the multinomial theorem does not apply
        # to (a + b)^-1, which previously "expanded" to a^-1 + b^-1.
        if base.type == ADD and exp.type == INTEGER and exp.i > 0:
            n = exp.i
            m = len(base.args)
            d = multinomial_coefficients(m, n)
            r = []
            for powers, coeff in d.items():
                if coeff == 1:
                    t = []
                else:
                    t = [Integer(coeff)]
                for x, p in zip(base.args, powers):
                    if p != 0:
                        if p == 1:
                            tt = x
                        else:
                            tt = Pow((x, Integer(p)), do_sympify=False)
                        t.append(tt)
                assert len(t) != 0
                if len(t) == 1:
                    t = t[0]
                else:
                    t = Mul(t, False)
                r.append(t)
            return Add(r, False)
        return self
def sympify(x):
    """
    Convert *x* into an expression node where possible.

    Expressions (Basic instances) are returned unchanged.  Integral numbers
    -- ``int``, ``bool``, Python 2 ``long`` and any other
    ``numbers.Integral`` -- become ``Integer``.  Anything else is returned
    unchanged.
    """
    if isinstance(x, Basic):
        return x
    if isinstance(x, int):
        # int(x) normalises bool so that True becomes Integer(1), not a node
        # that prints as "True".
        return Integer(int(x))
    # Slow path, reached only for non-int objects: Python 2 ``long`` and
    # third-party integer types would otherwise leak through and later fail
    # on ``.type``.
    import numbers
    if isinstance(x, numbers.Integral):
        return Integer(int(x))
    return x
def var(s):
    """
    Create a symbolic variable with the name *s*.

    INPUT:
        s -- a string, either a single variable name, or
             a string of names separated by whitespace and/or commas, or
             a list (or tuple) of variable names.

    OUTPUT:
        None if no names were given, a single Symbol for one name,
        otherwise a tuple of Symbols.

    NOTE: The new variable is both returned and automatically injected into
    the caller's *global* namespace. It's recommended not to use "var" in
    library code; create Symbol instances directly instead.

    EXAMPLES:
    We define some symbolic variables:
        >>> var('m')
        m
        >>> var('n xx yy zz')
        (n, xx, yy, zz)
        >>> n
        n
    """
    import re
    import inspect
    frame = inspect.currentframe().f_back

    try:
        if not isinstance(s, (list, tuple)):
            # Raw string: '\s' in a plain literal is an invalid escape
            # sequence (DeprecationWarning/SyntaxWarning on Python 3).
            s = re.split(r'\s|,', s)

        res = []

        for t in s:
            # skip empty strings (produced by adjacent separators)
            if not t:
                continue
            sym = Symbol(t)
            frame.f_globals[t] = sym
            res.append(sym)

        res = tuple(res)
        if len(res) == 0:  # var('')
            res = None
        elif len(res) == 1:  # var('x')
            res = res[0]
        # otherwise var('a b ...')
        return res

    finally:
        # we should explicitly break cyclic dependencies as stated in inspect
        # doc
        del frame
def binomial_coefficients(n):
    """Return a dictionary containing pairs {(k1,k2) : C_kn} where
    C_kn are binomial coefficients and n=k1+k2.

    Raises ValueError for negative n.
    """
    if n < 0:
        raise ValueError("n must be a non-negative integer, got %r" % (n,))
    d = {(0, n): 1, (n, 0): 1}
    a = 1
    # range instead of the Python-2-only xrange; only half the row is
    # computed and mirrored by symmetry C(n, k) == C(n, n-k).
    for k in range(1, n//2 + 1):
        # C(n, k) = C(n, k-1) * (n-k+1) / k; the division is always exact.
        a = (a * (n-k+1)) // k
        d[k, n-k] = d[n-k, k] = a
    return d
def binomial_coefficients_list(n):
    """Return row ``n`` of Pascal's triangle as a list,
    i.e. ``[C(n,0), C(n,1), ..., C(n,n)]``.

    Raises ValueError for negative n.
    """
    if n < 0:
        raise ValueError("n must be a non-negative integer, got %r" % (n,))
    d = [1] * (n+1)
    a = 1
    # range instead of the Python-2-only xrange; fill both halves by symmetry.
    for k in range(1, n//2 + 1):
        # C(n, k) = C(n, k-1) * (n-k+1) / k; the division is always exact.
        a = (a * (n-k+1)) // k
        d[k] = d[n-k] = a
    return d
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

    >>> sorted(multinomial_coefficients(2, 5).items())
    [((0, 5), 1), ((1, 4), 5), ((2, 3), 10), ((3, 2), 10), ((4, 1), 5), ((5, 0), 1)]

    The algorithm is based on the following result:

    Consider a polynomial of degree ``d`` and its ``n``-th power::

      P(x) = sum_{i=0}^d p_i x^i
      P(x)^n = sum_{k=0}^{d n} a(n,k) x^k

    The coefficients ``a(n,k)`` can be computed using the
    J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
    Algorithms, The art of Computer Programming v.2, Addison
    Wesley, Reading, 1981;]::

      a(n,k) = 1/(k p_0) sum_{i=1}^d p_i ((n+1)i-k) a(n,k-i),

    where ``a(n,0) = p_0^n``.  Here the ``m`` variables play the role of
    the ``d + 1 = m`` terms, so ``k`` runs up to ``n*(m-1)``.

    Raises ValueError for negative ``m`` or ``n``.
    """
    if m < 0 or n < 0:
        raise ValueError("m and n must be non-negative, got %r, %r" % (m, n))
    if m == 0:
        # Empty product of variables: only n == 0 has a (single, empty) term.
        if n == 0:
            return {(): 1}
        return {}
    if m == 2:
        return binomial_coefficients(n)
    symbols = [(0,)*i + (1,) + (0,)*(m-i-1) for i in range(m)]
    s0 = symbols[0]
    p0 = [_tuple(aa-bb for aa, bb in _zip(s, s0)) for s in symbols]
    r = {_tuple(aa*n for aa in s0): 1}
    r_update = r.update
    l = [0] * (n*(m-1)+1)
    # Snapshot as a list: on Python 3 r.items() is a live view that would
    # also see the entries r_update() adds below, corrupting a(n, k-i)
    # lookups with k-i == 0.
    l[0] = list(r.items())
    for k in range(1, n*(m-1)+1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k+1)):
            nn = (n+1)*i-k
            if not nn:
                continue
            t = p0[i]
            for t2, c2 in l[k-i]:
                tt = _tuple([aa+bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        # Contributions cancelled exactly; drop the entry.
                        del d[tt]
        # Exact division by k (the 1/(k p_0) factor, with p_0 == 1).
        r1 = [(t, c//k) for (t, c) in d.items()]
        l[k] = r1
        r_update(r1)
    return r