1 # (1) Cython does not support __new__
3 # (2) what to do if we want
6 # cdef virt_func(Base a, Base b):
7 # # here we ensure that a & b are of the same type
10 # cdef class Child(Base):
11 # cdef virt_func(Child a, Child b):
16 # currently we have to do:
19 # cdef cirt_func(Child a, _Basic _b):
20 # cdef Child b = <Child>_b
23 # (3) @staticmethod for cdef methods?
25 # (4) nested cdef like in here:
38 cdef int hash_seq
(args
):
40 Hash of a sequence, that *depends* on the order of elements.
42 # make this more robust:
45 m
= hash(m
+ 1001 ^
hash(x
))
54 cdef tuple _args
# XXX tuple -> list?
64 self.hash = self._hash
()
69 return hash_seq
(self._args
)
77 def _set_rawargs
(self, args
):
87 cpdef as_coeff_rest
(self):
88 return (Integer
(1), self)
90 cpdef as_base_exp
(self):
91 return (self, Integer
(1))
93 cpdef _Basic expand
(self):
96 # NOTE: there is no __rxxx__ methods in Cython/Pyrex
108 return Mul
((x
, Pow((y
, Integer
(-1)))))
110 # FIXME we should get rid of z?
111 def __pow__
(x
, y
, z
):
115 return Mul
((Integer
(-1), x
))
120 # in subclasses, you can be sure that _equal(a, b) is called with exactly
121 # the same type, e.g.
123 # when _Add._equal(a, b) is called a and b are of ._type=ADD for sure
124 cdef int _equal
(_Basic
self, _Basic o
):
125 # by default we compare ._args
126 return self._args
== o
._args
128 cdef bint equal
(_Basic
self, _Basic o
):
129 if self._type
!= o
._type
:
132 # now we know self and o are of the same type, lets dispatch to their
134 return self._equal
(o
)
138 def __richcmp__
(_Basic x
, y
, int op
):
139 #print '__richcmp__ %s %s %i' % (x,y,op)
148 return not x
.equal
(y
)
160 cdef class _Integer
(_Basic
):
161 cdef object i
# XXX object -> pyint?
163 def __cinit__
(self, i
):
167 cdef int _hash
(self):
170 cdef int _equal
(_Integer
self, _Basic o
):
171 cdef _Integer other
= <_Integer
>o
172 return self.i
== other
.i
175 def __str__
(_Integer
self):
178 def __repr__
(_Integer
self):
179 return 'Integer(%i)' % self.i
181 # there is no __radd__ in pyrex
183 cdef _Basic a
= sympify
(_a
)
184 cdef _Basic b
= sympify
(_b
)
185 if a
._type
== INTEGER
and b
._type
== INTEGER
:
186 return Integer
( (<_Integer
>a
).i
+ (<_Integer
>b
).i
)
188 return _Basic
.__add__
(a
, b
)
190 # there is no __rmul__ in pyrex
192 cdef _Basic a
= sympify
(_a
)
193 cdef _Basic b
= sympify
(_b
)
194 if a
._type
== INTEGER
and b
._type
== INTEGER
:
195 return Integer
( (<_Integer
>a
).i
* (<_Integer
>b
).i
)
196 return _Basic
.__mul__
(a
, b
)
201 cpdef _Basic Symbol
(name
):
206 cdef class _Symbol
(_Basic
):
207 cdef object name
# XXX object -> str
209 def __cinit__
(self, name
):
213 cdef int _hash
(self):
214 return hash(self.name
)
216 cdef int _equal
(_Symbol
self, _Basic o
):
217 cdef _Symbol other
= <_Symbol
>o
218 #print 'Symbol._equal %s %s' % (self.name, other.name)
219 return self.name
== other
.name
221 def __str__
(_Symbol
self):
224 def __repr__
(_Symbol
self):
225 return 'Symbol(%s)' % self.name
230 cpdef _Basic Add
(args
):
231 args
= [sympify
(x
) for x
in args
]
232 return _Add_canonicalize
(args
)
236 cdef _Basic _Add_canonicalize
(args
):
239 # from csympy import HashTable
248 cdef _Integer num
= Integer
(0)
254 if a
._type
== INTEGER
:
258 if b
._type
== INTEGER
:
261 coeff
, key
= b
.as_coeff_rest
()
267 coeff
, key
= a
.as_coeff_rest
()
275 for a
, b
in d
.iteritems
():
276 args
.append
(Mul
((a
, b
)))
285 cdef class _Add
(_Basic
):
286 cdef object _args_set
# XXX object -> frozenset
288 def __cinit__
(_Add
self, args
):
290 self._args
= tuple(args
)
294 def freeze_args
(self):
295 #print "add is freezing"
296 if self._args_set
is None
:
297 self._args_set
= frozenset
(self._args
)
300 cdef int _equal
(_Add
self, _Basic o
):
301 cdef _Add other
= <_Add
>o
305 return self._args_set
== other
._args_set
308 def __str__
(_Basic
self):
309 cdef _Basic a
= self._args
[0]
313 for a
in self._args
[1:]:
314 s
= "%s + %s" % (s
, str(a
))
319 cdef int _hash
(self):
320 # XXX: it is surprising, but this is *not* faster:
322 #h = hash(self._args_set)
329 cpdef _Basic expand
(self):
331 for term
in self._args
:
332 r
.append
( term
.expand
() )
336 cpdef _Basic Mul
(args
):
337 args
= [sympify
(x
) for x
in args
]
338 return _Mul_canonicalize
(args
)
341 cdef _Basic _Mul_canonicalize
(args
):
344 # from csympy import HashTable
353 cdef _Integer num
= Integer
(1)
356 if a
._type
== INTEGER
:
360 if b
._type
== INTEGER
:
363 key
, coeff
= b
.as_base_exp
()
369 key
, coeff
= a
.as_base_exp
()
374 if num
.i
== 0 or len(d
)==0:
377 for a
, b
in d
.iteritems
():
378 args
.append
(Pow((a
, b
)))
389 cdef _Basic _Mul_expand_two
(_Basic a
, _Basic b
):
391 Both a and b are assumed to be expanded.
397 if a
._type
== ADD
and b
._type
== ADD
:
415 cdef class _Mul
(_Basic
):
416 cdef object _args_set
# XXX object -> frozenset
418 def __cinit__
(self, args
):
420 self._args
= tuple(args
)
421 self._args_set
= None
424 cdef int _hash
(self):
425 # in contrast to Add, here it is faster:
427 return hash(self._args_set
)
429 #a = list(self._args[:])
434 def freeze_args
(self):
435 #print "mul is freezing"
436 if self._args_set
is None
:
437 self._args_set
= frozenset
(self._args
)
441 cdef int _equal
(_Mul
self, _Basic o
):
442 cdef _Mul other
= <_Mul
>o
445 return self._args_set
== other
._args_set
448 cpdef as_coeff_rest
(self):
449 cdef _Basic a
= self._args
[0]
451 if a
._type
== INTEGER
:
452 return self.as_two_terms
()
453 return (Integer
(1), self)
456 cpdef as_two_terms
(_Mul
self):
457 cdef _Basic a0
= self._args
[0]
458 if len(self._args
) == 2:
459 return a0
, self._args
[1]
462 # XXX _Mul is ok here (like ._new_rawargs)
463 return (self._args
[0], _Mul
(self._args
[1:]))
468 cdef _Basic a
= self._args
[0]
470 if a
._type
in [ADD
, MUL
]:
472 for a
in self._args
[1:]:
473 if a
._type
in [ADD
, MUL
]:
474 s
= "%s * (%s)" % (s
, str(a
))
476 s
= "%s*%s" % (s
, str(a
))
480 cpdef _Basic expand
(self):
484 a
, b
= self.as_two_terms
()
485 r
= _Mul_expand_two
(a
, b
)
489 return _Mul_expand_two
(a
, b
)
494 cpdef _Basic
Pow(args
):
495 args
= [sympify
(x
) for x
in args
]
496 return _Pow_canonicalize
(args
)
499 cdef _Basic _Pow_canonicalize
(args
):
504 cdef _Integer b
= <_Integer
>base
505 cdef _Integer e
= <_Integer
>exp
507 if base
._type
== INTEGER
:
509 return b
# Integer(0)
511 return b
# Integer(1)
512 if exp
._type
== INTEGER
:
517 if base
._type
== POW:
518 return Pow((base
._args
[0], base
._args
[1]*exp
))
523 cdef class _Pow
(_Basic
):
525 def __cinit__
(self, args
):
527 self._args
= tuple(args
)
529 def __str__
(_Pow
self):
530 cdef _Basic b
= self._args
[0]
531 cdef _Basic e
= self._args
[1]
537 s
= "%s^(%s)" % (s
, str(e
))
539 s
= "%s^%s" % (s
, str(e
))
543 cpdef as_base_exp
(_Pow
self):
546 cpdef _Basic expand
(_Pow
self):
547 cdef _Basic _base
= self._args
[0]
548 cdef _Basic _exp
= self._args
[1]
550 # XXX please careful here - use it only after appropriate check
551 cdef _Add base
= <_Add
>_base
552 cdef _Integer exp
= <_Integer
>_exp
561 if _base
._type
== ADD
and _exp
._type
== INTEGER
:
565 d
= multinomial_coefficients
(m
, n
)
568 for powers
, coeff
in d
.iteritems
():
573 for x
, p
in zip
(base
._args
, powers
):
575 t
.append
(Pow((x
, p
)))
576 #t.append(_Pow((x, Integer(p)))) # XXX _Pow -> Pow
590 cpdef _Basic sympify
(x
):
591 if isinstance(x
, int):
595 def binomial_coefficients
(n
):
596 """Return a dictionary containing pairs {(k1,k2) : C_kn} where
597 C_kn are binomial coefficients and n=k1+k2."""
598 d
= {(0, n
):1, (n
, 0):1}
600 for k
in xrange(1, n
//2+1):
602 d
[k
, n
-k
] = d
[n
-k
, k
] = a
605 def binomial_coefficients_list
(n
):
606 """ Return a list of binomial coefficients as rows of the Pascal's
611 for k
in xrange(1, n
//2+1):
616 def multinomial_coefficients
(m
, n
, _tuple
=tuple, _zip
=zip
):
617 """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
618 where ``C_kn`` are multinomial coefficients such that
623 >>> print multinomial_coefficients(2,5)
624 {(3, 2): 10, (1, 4): 5, (2, 3): 10, (5, 0): 1, (0, 5): 1, (4, 1): 5}
626 The algorithm is based on the following result:
628 Consider a polynomial and it's ``m``-th exponent::
630 P(x) = sum_{i=0}^m p_i x^k
631 P(x)^n = sum_{k=0}^{m n} a(n,k) x^k
633 The coefficients ``a(n,k)`` can be computed using the
634 J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
635 Algorithms, The art of Computer Programming v.2, Addison
636 Wesley, Reading, 1981;]::
638 a(n,k) = 1/(k p_0) sum_{i=1}^m p_i ((n+1)i-k) a(n,k-i),
640 where ``a(n,0) = p_0^n``.
644 return binomial_coefficients
(n
)
645 symbols
= [(0,)*i
+ (1,) + (0,)*(m
-i
-1) for i
in range(m
)]
647 p0
= [_tuple
([aa
-bb
for aa
,bb
in _zip
(s
,s0
)]) for s
in symbols
]
648 r
= {_tuple
([aa
*n
for aa
in s0
]):1}
651 l
= [0] * (n
*(m
-1)+1)
653 for k
in xrange(1, n
*(m
-1)+1):
656 for i
in xrange(1, min(m
,k
+1)):
661 for t2
, c2
in l
[k
-i
]:
662 tt
= _tuple
([aa
+bb
for aa
,bb
in _zip
(t2
,t
)])
673 r1
= [(t
, c
//k
) for (t
, c
) in d
.iteritems
()]