couple tests fixed
[sympyx.git] / sympy.py
blob7eee10e6d25ada3305c7c10bd6f27247bc815152
# Integer tags naming the kind of an expression node.  Node classes pass one
# of these to Basic.__new__, which stores it in the node's ``type`` attribute.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """Fold the hashes of the items of ``args`` into a single integer.

    The running value starts at 2 and, for every item in iteration order, is
    replaced by ``hash((value + 1001) ^ hash(item))``.  The result depends on
    the order of the items; an empty sequence yields 2.
    """
    # make this more robust:
    digest = 2
    for item in args:
        # ``+`` binds tighter than ``^``; the parentheses only spell out the
        # grouping the expression already had.
        mixed = (digest + 1001) ^ hash(item)
        digest = hash(mixed)
    return digest
class Basic(object):
    """Base class of every expression node.

    A node stores an integer ``type`` tag and its operands in ``_args``
    (readable through the ``args`` property).  The arithmetic operators build
    new nodes through the module-level ``Add``, ``Mul``, ``Pow`` and
    ``Integer`` constructors.

    ``Basic`` defines no ``__str__`` of its own; ``__repr__`` relies on the
    subclass providing one.
    """

    def __new__(cls, type, args):
        """Create a node carrying the tag ``type`` and the operands ``args``."""
        obj = object.__new__(cls)
        obj.type = type
        obj._args = args
        return obj

    def __repr__(self):
        return str(self)

    def __hash__(self):
        """Hash derived from the operands via ``hash_seq``."""
        return hash_seq(self.args)

    @property
    def args(self):
        """The operands this node was created with."""
        return self._args

    def as_coeff_rest(self):
        """Return ``(coefficient, rest)``; by default ``(Integer(1), self)``."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Return ``(base, exponent)``; by default ``(self, Integer(1))``."""
        return (self, Integer(1))

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        # Reflected operation: computes ``y - x``.
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        """``x / y`` expressed as ``x * y**-1``."""
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        """Reflected division ``y / x`` expressed as ``y * x**-1``."""
        return Mul((y, Pow((x, Integer(-1)))))

    # ``/`` dispatches to __truediv__/__rtruediv__ on Python 3 and under
    # ``from __future__ import division``; without these aliases division
    # raises TypeError there.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        # Reflected operation: computes ``y ** x``.
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x
class Integer(Basic):
    """Leaf node holding a plain number in the attribute ``i``.

    Adding or multiplying two ``Integer`` nodes folds them into a new
    ``Integer``; any other operand falls back to the generic ``Basic``
    operator.  No ``__eq__`` is defined here, so comparison with ``==`` is
    whatever ``Basic`` provides.
    """

    def __new__(cls, i):
        node = Basic.__new__(cls, INTEGER, [])
        node.i = i
        return node

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        other = sympify(o)
        if other.type != INTEGER:
            return Basic.__add__(self, other)
        return Integer(self.i + other.i)

    def __mul__(self, o):
        other = sympify(o)
        if other.type != INTEGER:
            return Basic.__mul__(self, other)
        return Integer(self.i * other.i)
class Symbol(Basic):
    """Leaf node representing a named symbol.

    Hashing and equality are both based on ``name``, so two ``Symbol``
    objects created with the same name are interchangeable as dictionary
    keys.
    """

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, o):
        """Symbols are equal when their names are equal."""
        if not isinstance(o, Symbol):
            # Let the other operand (e.g. Mul.__eq__) decide.
            return NotImplemented
        return self.name == o.name

    def __ne__(self, o):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        result = self.__eq__(o)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return self.name
class Add(Basic):
    """Sum of terms.

    ``Add(args)`` converts plain ints with ``sympify`` and canonicalises the
    terms; ``Add(args, False)`` stores ``args`` as given.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, ADD, args)
        return Add.canonicalize([sympify(x) for x in args])

    @classmethod
    def canonicalize(cls, args):
        """Collect like terms and return the resulting ``Add``.

        Each term is split with ``as_coeff_rest()``; coefficients of terms
        whose ``rest`` compares equal are summed.  Operands that are
        themselves ``Add`` nodes are flattened one level.
        """
        coeffs = {}
        for a in args:
            terms = a.args if a.type == ADD else (a,)
            for term in terms:
                coeff, rest = term.as_coeff_rest()
                if rest in coeffs:
                    coeffs[rest] += coeff
                else:
                    coeffs[rest] = coeff
        # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
        terms = [Mul((rest, coeff)) for rest, coeff in coeffs.items()]
        return Add(terms, False)

    def __str__(self):
        """Render as ``a + b + ...``, parenthesising nested ``Add`` operands."""
        parts = []
        for term in self.args:
            text = str(term)
            if term.type == ADD:
                # Wrap only the nested operand, not the text built so far.
                text = "(%s)" % text
            parts.append(text)
        if not parts:
            # An empty sum is zero; avoids IndexError on an empty node.
            return "0"
        return " + ".join(parts)
class Mul(Basic):
    """Product of factors.

    ``Mul(args)`` converts plain ints with ``sympify`` and canonicalises the
    factors; ``Mul(args, False)`` stores ``args`` as given.  Two ``Mul``
    nodes compare equal when they hold equal factors, in any order.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, MUL, args)
        return Mul.canonicalize([sympify(x) for x in args])

    @classmethod
    def canonicalize(cls, args):
        """Combine the factors in ``args`` into a canonical product.

        ``Integer`` factors are multiplied into one numeric coefficient.
        Every other factor is split with ``as_base_exp()`` and the exponents
        of equal bases are added.  Operands that are themselves ``Mul`` nodes
        are flattened one level.

        Returns ``Integer(0)`` if the coefficient is zero, the coefficient
        itself if nothing else remains, the single remaining factor if there
        is exactly one, and otherwise a ``Mul`` with the coefficient (when it
        is not 1) as first argument.
        """
        exponents = {}
        num = Integer(1)
        for a in args:
            factors = a.args if a.type == MUL else (a,)
            for factor in factors:
                if factor.type == INTEGER:
                    # Also covers the leading coefficient of a nested Mul.
                    num *= factor
                    continue
                # as_base_exp() returns (base, exponent): key on the base and
                # accumulate the exponent.
                base, exp = factor.as_base_exp()
                if base in exponents:
                    exponents[base] += exp
                else:
                    exponents[base] = exp
        if num.i == 0:
            return num
        # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
        new_args = [Pow((base, exp)) for base, exp in exponents.items()]
        if num.i != 1:
            new_args.insert(0, num)
        if not new_args:
            # Only ones were multiplied: the product is Integer(1), not an
            # empty Mul.
            return num
        if len(new_args) == 1:
            return new_args[0]
        return Mul(new_args, False)

    def __hash__(self):
        # Sort by hash so the result does not depend on factor order;
        # sorted() also accepts a tuple of args.
        return hash_seq(sorted(self.args, key=hash))

    def __eq__(self, o):
        """Equal to another ``MUL``-tagged node holding equal factors."""
        if getattr(o, "type", None) != MUL:
            return False
        return sorted(self.args, key=hash) == sorted(o.args, key=hash)

    def __ne__(self, o):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self.__eq__(o)

    def as_coeff_rest(self):
        """Split a leading ``Integer`` factor off from the rest of the product."""
        if self.args[0].type == INTEGER:
            return (self.args[0], Mul(self.args[1:]))
        return (Integer(1), self)

    def __str__(self):
        """Render as ``a*b*...``, parenthesising sum and product factors."""
        parts = []
        for factor in self.args:
            text = str(factor)
            if factor.type in (ADD, MUL):
                # Without parentheses (x + y)*z would read as x + y*z.
                text = "(%s)" % text
            parts.append(text)
        if not parts:
            # An empty product is one; avoids IndexError on an empty node.
            return "1"
        return "*".join(parts)
class Pow(Basic):
    """Power node with ``args == [base, exponent]``.

    ``Pow(args)`` converts plain ints with ``sympify`` and canonicalises;
    ``Pow(args, False)`` stores ``args`` as given.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            # Tag the node POW (it was tagged MUL, which made Mul treat a
            # power's base and exponent as two factors).
            return Basic.__new__(cls, POW, args)
        return Pow.canonicalize([sympify(x) for x in args])

    @classmethod
    def canonicalize(cls, args):
        """Simplify trivial powers.

        An ``Integer`` exponent of 0 gives ``Integer(1)`` and an ``Integer``
        exponent of 1 gives the base itself; anything else becomes a ``Pow``
        node.
        """
        base, exp = args
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        return Pow(args, False)

    def __str__(self):
        """Render as ``base^exp``, parenthesising compound operands."""
        base, exp = self.args[0], self.args[1]
        base_text = str(base)
        if base.type in (ADD, MUL, POW):
            # (x*y)^2 and (x^y)^z would otherwise read as x*y^2 and x^y^z.
            base_text = "(%s)" % base_text
        exp_text = str(exp)
        if exp.type in (ADD, MUL):
            return "%s^(%s)" % (base_text, exp_text)
        return "%s^%s" % (base_text, exp_text)
def sympify(x):
    """Wrap a plain ``int`` in ``Integer``; return any other object unchanged."""
    return Integer(x) if isinstance(x, int) else x