Make Pow able not to sympify.
[sympyx.git] / sympy_py.py
blobc923389ff3c94b08a48aa7b981abd4b21e122bc3
# Integer type tags stored in ``Basic.type`` to identify the kind of node.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """
    Combine the hashes of the items in *args* into a single hash value.

    The result depends on the order of the items, so two permutations of
    the same items generally hash differently.
    """
    # TODO: make this more robust.
    acc = 2
    for item in args:
        # '+' binds tighter than '^': offset the accumulator first, then
        # xor in the item's hash.
        acc = hash((acc + 1001) ^ hash(item))
    return acc
class Basic(object):
    """Base class of all expression nodes.

    Every node stores an integer ``type`` tag (one of the module-level
    constants such as SYMBOL or ADD), its children as the tuple ``args``,
    and a lazily computed hash cached in ``mhash``.
    """

    def __new__(cls, type, args):
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        obj.mhash = None
        return obj

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Fallback for classes without their own __str__. Without it,
        # repr() -> str() -> object.__str__ -> repr() recurses forever.
        return "%s(%s)" % (type(self).__name__,
                           ", ".join(str(a) for a in self.args))

    def __hash__(self):
        # Computed once from the children and cached in mhash.
        if self.mhash is None:
            self.mhash = hash_seq(self.args)
        return self.mhash

    @property
    def args(self):
        return self._args

    def as_coeff_rest(self):
        """Split into (Integer coefficient, remaining factor); here (1, self)."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Split into (base, exponent); here (self, 1)."""
        return (self, Integer(1))

    def expand(self):
        return self

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # Python 3 (and "from __future__ import division") dispatch '/' to the
    # true-division hooks; without these aliases x/y raises TypeError there.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        o = sympify(o)
        # Anything that is still not an expression node (None, str,
        # float, ...) has no .type; it is simply unequal.
        if not isinstance(o, Basic):
            return False
        if o.type == self.type:
            return self.args == o.args
        return False
class Integer(Basic):
    """Exact integer node; the wrapped Python integer is stored in ``self.i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Hash like the wrapped int, consistent with Integer(n) == n below.
        if self.mhash is None:
            self.mhash = hash(self.i)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # Non-node operands (None, str, float, ...) have no .type;
        # report them as unequal instead of raising AttributeError.
        if isinstance(o, Basic) and o.type == INTEGER:
            return self.i == o.i
        return False

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            # Integer + Integer folds to a single Integer immediately.
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            # Integer * Integer folds to a single Integer immediately.
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """Named symbolic variable; the name is stored in ``self.name``."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        # Symbols with equal names compare equal, so hash by name.
        if self.mhash is None:
            self.mhash = hash(self.name)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # Non-node operands (None, str, float, ...) have no .type;
        # report them as unequal instead of raising AttributeError.
        if isinstance(o, Basic) and o.type == SYMBOL:
            return self.name == o.name
        return False

    def __str__(self):
        return self.name
class Add(Basic):
    """Sum of terms.

    The canonical form built by ``canonicalize`` has at most one Integer
    term (placed first, omitted when it is 0) and merges like terms by
    adding their Integer coefficients.
    """

    def __new__(cls, args, canonicalize=True):
        # canonicalize=False wraps *args* verbatim; the caller is
        # responsible for passing already-canonical terms.
        if canonicalize == False:
            obj = Basic.__new__(cls, ADD, args)
            obj._args_set = None
            return obj
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Flatten nested sums, collect like terms and return the simplest form."""
        use_glib = 0
        if use_glib:
            # NOTE(review): d.items() below must also be supported by
            # csympy's HashTable; confirm before enabling this switch.
            from csympy import HashTable
            d = HashTable()
        else:
            d = {}
        num = Integer(0)
        for a in args:
            # Nested sums are flattened one level; everything else is a
            # single term.
            terms = a.args if a.type == ADD else (a,)
            for b in terms:
                if b.type == INTEGER:
                    num += b
                else:
                    coeff, key = b.as_coeff_rest()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        if len(d) == 0:
            return num
        args = []
        for key, coeff in d.items():
            if coeff.type == INTEGER and coeff.i == 0:
                # Like terms cancelled (e.g. x - x): drop them rather than
                # leaving a literal 0 term in the sum.
                continue
            args.append(Mul((key, coeff)))
        if num.i != 0:
            args.insert(0, num)
        if not args:
            # Everything cancelled and there is no numeric term either.
            return num
        if len(args) == 1:
            return args[0]
        return Add(args, False)

    def freeze_args(self):
        # Cache the (order-insensitive) set of terms used by __eq__.
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        # Non-node operands (None, str, float, ...) have no .type.
        if not isinstance(o, Basic) or o.type != ADD:
            return False
        self.freeze_args()
        o.freeze_args()
        return self._args_set == o._args_set

    def __str__(self):
        parts = []
        for x in self.args:
            # Parenthesize the nested sum itself, not the text built so far.
            if x.type == ADD:
                parts.append("(%s)" % x)
            else:
                parts.append(str(x))
        return " + ".join(parts)

    def __hash__(self):
        if self.mhash is None:
            # XXX: it is surprising, but hashing frozenset(self.args) is
            # *not* faster than this; sorting by hash makes the result
            # independent of term order.
            a = list(self.args[:])
            a.sort(key=hash)
            h = hash_seq(a)
            self.mhash = h
            return h
        else:
            return self.mhash

    def expand(self):
        return Add([term.expand() for term in self.args])
class Mul(Basic):
    """Product of factors.

    The canonical form built by ``canonicalize`` has at most one Integer
    factor (placed first, omitted when it is 1) and merges repeated bases
    by adding their exponents.
    """

    def __new__(cls, args, canonicalize=True):
        # canonicalize=False wraps *args* verbatim; the caller is
        # responsible for passing already-canonical factors.
        if canonicalize == False:
            obj = Basic.__new__(cls, MUL, args)
            obj._args_set = None
            return obj
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Flatten nested products, collect powers and return the simplest form."""
        use_glib = 0
        if use_glib:
            # NOTE(review): d.items() below must also be supported by
            # csympy's HashTable; confirm before enabling this switch.
            from csympy import HashTable
            d = HashTable()
        else:
            d = {}
        if len(args) == 2 and args[0].type == MUL and args[1].type == INTEGER:
            # Fast path for Mul * Integer: only the numeric coefficient
            # changes, the remaining factors are reused as-is.
            a, b = args
            if b.i == 1:
                return a
            if b.i == 0:
                return b
            if a.args[0].type == INTEGER:
                coeff = b * a.args[0]
                rest = a.args[1:]
            else:
                coeff = b
                rest = a.args
            if not rest or coeff.i == 0:
                return coeff
            if coeff.i == 1:
                # Coefficients cancelled, e.g. (-x*y) * -1: do not leave a
                # literal 1 factor behind.
                if len(rest) == 1:
                    return rest[0]
                return Mul(rest, False)
            return Mul((coeff,) + tuple(rest), False)

        num = Integer(1)
        for a in args:
            # Nested products are flattened one level; everything else is
            # a single factor.
            factors = a.args if a.type == MUL else (a,)
            for b in factors:
                if b.type == INTEGER:
                    num *= b
                else:
                    key, coeff = b.as_base_exp()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        if num.i == 0 or len(d) == 0:
            return num
        args = []
        for base, exp in d.items():
            p = Pow((base, exp))
            if p.type == INTEGER:
                # The exponents summed to 0 (e.g. x * x^-1), so Pow gave an
                # Integer; fold it into the coefficient instead of keeping
                # a stray 1 factor.
                num *= p
            else:
                args.append(p)
        if num.i == 0 or not args:
            return num
        if num.i != 1:
            args.insert(0, num)
        if len(args) == 1:
            return args[0]
        return Mul(args, False)

    def __hash__(self):
        if self.mhash is None:
            # In contrast to Add, hashing the frozenset is faster here than
            # sorting the args by hash.
            self.freeze_args()
            h = hash(self._args_set)
            self.mhash = h
            return h
        else:
            return self.mhash

    def freeze_args(self):
        # Cache the (order-insensitive) set of factors used by __eq__/__hash__.
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        # Non-node operands (None, str, float, ...) have no .type.
        if not isinstance(o, Basic) or o.type != MUL:
            return False
        self.freeze_args()
        o.freeze_args()
        return self._args_set == o._args_set

    def as_coeff_rest(self):
        """Split into (Integer coefficient, rest); (1, self) when there is none."""
        if self.args[0].type == INTEGER:
            return self.as_two_terms()
        return (Integer(1), self)

    def as_two_terms(self):
        """Split into (first factor, product of the remaining factors)."""
        args = self.args
        a0 = args[0]
        if len(args) == 2:
            return a0, args[1]
        else:
            return (a0, Mul(args[1:], False))

    def __str__(self):
        s = str(self.args[0])
        if self.args[0].type in [ADD, MUL]:
            s = "(%s)" % str(s)
        for x in self.args[1:]:
            if x.type in [ADD, MUL]:
                s = "%s * (%s)" % (s, str(x))
            else:
                s = "%s*%s" % (s, str(x))
        return s

    @classmethod
    def expand_two(cls, a, b):
        """
        Return a*b with the product distributed over any sums.

        Both a and b are assumed to be expanded.
        """
        if a.type == ADD and b.type == ADD:
            terms = []
            for x in a.args:
                for y in b.args:
                    terms.append(x*y)
            return Add(terms)
        if a.type == ADD:
            terms = []
            for x in a.args:
                terms.append(x*b)
            return Add(terms)
        if b.type == ADD:
            terms = []
            for y in b.args:
                terms.append(a*y)
            return Add(terms)
        return a*b

    def expand(self):
        a, b = self.as_two_terms()
        r = Mul.expand_two(a, b)
        if r == self:
            # Nothing distributed at this level; expand the two parts
            # first and try again.
            a = a.expand()
            b = b.expand()
            return Mul.expand_two(a, b)
        else:
            return r.expand()
class Pow(Basic):
    """Power node; ``args`` is the pair (base, exp)."""

    def __new__(cls, args, canonicalize=True, do_sympify=True):
        # canonicalize=False wraps *args* verbatim. do_sympify=False skips
        # the int -> Integer conversion when the caller already passes nodes.
        if canonicalize == False:
            obj = Basic.__new__(cls, POW, args)
            return obj
        if do_sympify:
            args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Apply the trivial simplifications and return the simplest form."""
        base, exp = args
        # Check the exponent first so that 0^0 gives 1, matching Python's
        # 0**0 and SymPy's convention.
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        if base.type == INTEGER:
            if base.i == 0:
                return base
            if base.i == 1:
                return base
        if base.type == POW:
            # (b^e1)^e2 -> b^(e1*e2)
            return Pow((base.args[0], base.args[1]*exp))
        return Pow(args, False)

    def __str__(self):
        base, exp = self.args
        s = str(base)
        # Parenthesize compound bases and negative integers so that
        # (x*y)^2 and (-1)^2 do not print as x*y^2 and -1^2.
        if base.type in (ADD, MUL, POW) or (base.type == INTEGER and base.i < 0):
            s = "(%s)" % s
        if exp.type in (ADD, MUL, POW):
            return "%s^(%s)" % (s, str(exp))
        return "%s^%s" % (s, str(exp))

    def as_base_exp(self):
        return self.args

    def expand(self):
        """Expand a positive integer power of a sum using multinomial coefficients."""
        base, exp = self.args
        # Only positive integer powers of a sum have a finite multinomial
        # expansion; (x+y)^-2 must not be rewritten as x^-2 + ... .
        if base.type == ADD and exp.type == INTEGER and exp.i > 0:
            n = exp.i
            m = len(base.args)
            d = multinomial_coefficients(m, n)
            r = []
            for powers, coeff in d.items():
                num = coeff
                factors = []
                for x, p in zip(base.args, powers):
                    if p == 0:
                        continue
                    if x.type == INTEGER:
                        # Fold numeric terms of the base into the
                        # coefficient (avoids terms like 2*1*x).
                        num *= x.i ** p
                    elif p == 1:
                        factors.append(x)
                    else:
                        factors.append(Pow((x, Integer(p)), do_sympify=False))
                if num != 1 or not factors:
                    factors.insert(0, Integer(num))
                if len(factors) == 1:
                    r.append(factors[0])
                elif any(f.type == MUL for f in factors):
                    # A factor that is itself a product (e.g. from 2*x in the
                    # base) must be flattened to stay canonical.
                    r.append(Mul(factors))
                else:
                    r.append(Mul(factors, False))
            return Add(r, False)
        return self
def sympify(x):
    """Wrap a Python int in an Integer node; return any other object unchanged."""
    return Integer(x) if isinstance(x, int) else x
def var(s):
    """
    Create a symbolic variable with the name *s*.

    INPUT:
        s -- a string, either a single variable name, or several names
             separated by whitespace and/or commas, or a list (or tuple)
             of variable names.

    NOTE: The new variable is both returned and automatically injected into
    the caller's *global* namespace. It's recommended not to use "var" in
    library code; it is better to construct Symbol objects directly.

    Returns None for an empty name list, the single Symbol for one name,
    and a tuple of Symbols otherwise.

    EXAMPLES:
    We define some symbolic variables:
        >>> var('m')
        m
        >>> var('n xx yy zz')
        (n, xx, yy, zz)
        >>> n
        n
    """
    import re
    import inspect
    frame = inspect.currentframe().f_back

    try:
        if not isinstance(s, (list, tuple)):
            # Raw string: '\s' is not a valid str escape and triggers a
            # DeprecationWarning/SyntaxWarning on Python 3.
            s = re.split(r'\s|,', s)

        res = []
        for t in s:
            # skip empty strings (produced by adjacent separators)
            if not t:
                continue
            sym = Symbol(t)
            frame.f_globals[t] = sym
            res.append(sym)

        res = tuple(res)
        if len(res) == 0:  # var('')
            res = None
        elif len(res) == 1:  # var('x')
            res = res[0]
        # otherwise var('a b ...')
        return res

    finally:
        # we should explicitly break cyclic dependencies as stated in inspect
        # doc
        del frame
def binomial_coefficients(n):
    """Return a dictionary containing pairs {(k1,k2) : C_kn} where
    C_kn are binomial coefficients and n=k1+k2.

    Only half of the row is computed explicitly; the other half follows
    from the symmetry C(n, k) == C(n, n-k).
    """
    d = {(0, n): 1, (n, 0): 1}
    a = 1
    # range (not xrange) so this runs on both Python 2 and Python 3.
    for k in range(1, n//2 + 1):
        # Update C(n, k-1) -> C(n, k); the division is exact.
        a = (a * (n - k + 1)) // k
        d[k, n - k] = d[n - k, k] = a
    return d
def binomial_coefficients_list(n):
    """Return row *n* of Pascal's triangle as a list.

    The k-th entry of the result is the binomial coefficient C(n, k).
    """
    d = [1] * (n + 1)
    a = 1
    # range (not xrange) so this runs on both Python 2 and Python 3.
    for k in range(1, n//2 + 1):
        # Update C(n, k-1) -> C(n, k); fill both symmetric positions.
        a = (a * (n - k + 1)) // k
        d[k] = d[n - k] = a
    return d
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

        >>> print(multinomial_coefficients(2, 5))
        {(3, 2): 10, (1, 4): 5, (2, 3): 10, (5, 0): 1, (0, 5): 1, (4, 1): 5}

    The algorithm is based on the following result:

    Consider a polynomial and its ``n``-th power::

        P(x) = sum_{i=0}^m p_i x^i
        P(x)^n = sum_{k=0}^{m n} a(n,k) x^k

    The coefficients ``a(n,k)`` can be computed using the
    J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
    Algorithms, The art of Computer Programming v.2, Addison
    Wesley, Reading, 1981;]::

        a(n,k) = 1/(k p_0) sum_{i=1}^m p_i ((n+1)i-k) a(n,k-i),

    where ``a(n,0) = p_0^n``.

    ``_tuple`` and ``_zip`` are bound as defaults only to speed up name
    lookup in the inner loops.
    """
    if m == 0:
        # The empty sum is 0, and 0^n is 1 only for n == 0.
        return {(): 1} if n == 0 else {}
    if m == 2:
        return binomial_coefficients(n)
    symbols = [(0,)*i + (1,) + (0,)*(m-i-1) for i in range(m)]
    s0 = symbols[0]
    p0 = [_tuple(aa-bb for aa, bb in _zip(s, s0)) for s in symbols]
    r = {_tuple(aa*n for aa in s0): 1}
    r_update = r.update
    l = [0] * (n*(m-1)+1)
    # Snapshot as a list: on Python 3, r.items() is a live view that would
    # pick up every later r_update() and corrupt the recurrence.
    l[0] = list(r.items())
    for k in range(1, n*(m-1)+1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k+1)):
            nn = (n+1)*i-k
            if not nn:
                continue
            t = p0[i]
            for t2, c2 in l[k-i]:
                tt = _tuple([aa+bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        del d[tt]
        # The division by k (the 1/k factor of the recurrence) is exact.
        r1 = [(t, c//k) for (t, c) in d.items()]
        l[k] = r1
        r_update(r1)
    return r