first test fixed
[sympyx.git] / sympy.py
blob79ea6ab4ea8606fbc9cb680cd59f4f593de5e71a
# Integer tags identifying the kind of an expression node; each node keeps
# its tag in the ``type`` attribute set by ``Basic.__new__``.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """Fold the hashes of the items of *args* into a single hash value.

    The accumulator starts at 2 (which is therefore the result for an empty
    sequence) and every item is mixed in turn, so the result depends on the
    order of the items.
    """
    # TODO: make this more robust
    acc = 2
    for item in args:
        # ``+`` binds tighter than ``^``: add the offset first, then XOR
        # with the item's hash.
        mixed = (acc + 1001) ^ hash(item)
        acc = hash(mixed)
    return acc
class Basic(object):
    """Base class of every expression node.

    A node carries an integer ``type`` tag (one of the module-level constants
    such as ``ADD`` or ``MUL``) and an immutable tuple of arguments exposed
    through ``args``.  The arithmetic operators build ``Add``, ``Mul`` and
    ``Pow`` expressions out of their operands.
    """

    def __new__(cls, type, args):
        """Create a node tagged *type* whose arguments are ``tuple(args)``."""
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        return obj

    def __str__(self):
        # Generic fallback for classes that do not define their own
        # ``__str__``.  Without it ``__repr__`` -> ``str(self)`` ->
        # ``object.__str__`` -> ``__repr__`` would recurse forever.
        inner = ", ".join([str(a) for a in self.args])
        return "%s(%s)" % (self.__class__.__name__, inner)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash_seq(self.args)

    @property
    def args(self):
        """The tuple of arguments this node was created with."""
        return self._args

    def as_coeff_rest(self):
        """Return ``(coefficient, rest)``; a generic node has coefficient 1."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Return ``(base, exponent)``; a generic node has exponent 1."""
        return (self, Integer(1))

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        # x / y is represented as x * y**(-1).
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        # y / x is represented as y * x**(-1).
        return Mul((y, Pow((x, Integer(-1)))))

    # Under true division (Python 3, or ``from __future__ import division``)
    # the ``/`` operator looks up the *truediv* names instead.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        """Nodes are equal when they share the type tag and the args."""
        o = sympify(o)
        # Anything that is not an expression node (None, strings, ...) can
        # never be equal; report that instead of failing on ``o.type``.
        if not isinstance(o, Basic):
            return False
        if o.type == self.type:
            return self.args == o.args
        return False
class Integer(Basic):
    """An integer constant; the wrapped Python value is kept in ``self.i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # An Integer has no args, so the inherited args-based hash would be
        # the same for every value.  Python 3 also drops the inherited
        # ``__hash__`` from a class that overrides ``__eq__``, so it has to
        # be defined here for Integers to be usable as dict keys.
        return hash(self.i)

    def __eq__(self, o):
        """Equal to another Integer (or plain int) with the same value."""
        o = sympify(o)
        # Non-expression operands (None, strings, ...) are simply unequal.
        if isinstance(o, Basic) and o.type == INTEGER:
            return self.i == o.i
        return False

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        """Add numerically when *o* is an integer, symbolically otherwise."""
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        """Multiply numerically when *o* is an integer, symbolically otherwise."""
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """A named symbol; two symbols are equal when their names are equal."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        # Consistent with ``__eq__``: symbols with equal names hash equally.
        return hash(self.name)

    def __eq__(self, o):
        o = sympify(o)
        # Non-expression operands (None, strings, ...) are simply unequal
        # rather than failing on the missing ``type`` attribute.
        if isinstance(o, Basic) and o.type == SYMBOL:
            return self.name == o.name
        return False

    def __str__(self):
        return self.name
class Add(Basic):
    """Sum of terms.

    ``Add(args)`` sympifies and canonicalizes *args*, so the result is not
    necessarily an ``Add`` instance: a sum that collapses to a single term or
    to a number is returned as that term/number.  ``Add(args, False)`` stores
    *args* as given, without any processing.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, ADD, args)
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Combine like terms of the already-sympified *args*.

        Nested sums are flattened one level, integer terms are added up
        numerically and the coefficients of identical non-numeric terms are
        summed.  Terms whose coefficient becomes zero disappear.  Returns
        ``Integer(0)`` for an empty sum, the term itself for a one-term sum
        and an ``Add`` otherwise.
        """
        num = Integer(0)
        d = {}  # non-numeric "rest" -> accumulated Integer coefficient
        for a in args:
            terms = a.args if a.type == ADD else (a,)
            for term in terms:
                if term.type == INTEGER:
                    num += term
                    continue
                coeff, key = term.as_coeff_rest()
                if key in d:
                    d[key] += coeff
                else:
                    d[key] = coeff

        new_args = []
        for key, coeff in d.items():
            term = Mul((coeff, key))
            if term.type == INTEGER:
                # The product collapsed to a number (e.g. coefficient 0).
                num += term
            else:
                new_args.append(term)
        if num.i != 0:
            new_args.insert(0, num)

        if not new_args:
            return Integer(0)
        if len(new_args) == 1:
            return new_args[0]
        return Add(new_args, False)

    def __hash__(self):
        # ``__eq__`` ignores the order of the terms, so the hash must too.
        return hash_seq(sorted(self.args, key=hash))

    def __eq__(self, o):
        """Two sums are equal when they have the same terms in any order."""
        o = sympify(o)
        # Non-expression operands (None, strings, ...) are simply unequal.
        if not isinstance(o, Basic) or o.type != ADD:
            return False
        return sorted(self.args, key=hash) == sorted(o.args, key=hash)

    def __str__(self):
        if not self.args:
            # Only reachable for a non-canonicalized, empty sum.
            return "0"
        parts = []
        for x in self.args:
            s = str(x)
            if x.type == ADD:
                # Parenthesize just the nested sum, not the whole string.
                s = "(%s)" % s
            parts.append(s)
        return " + ".join(parts)
class Mul(Basic):
    """Product of factors.

    ``Mul(args)`` sympifies and canonicalizes *args*, so the result is not
    necessarily a ``Mul`` instance: a product that collapses to a single
    factor or to a number is returned as that factor/number.
    ``Mul(args, False)`` stores *args* as given, without any processing.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, MUL, args)
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Combine the factors of the already-sympified *args*.

        Nested products are flattened one level, integer factors are
        multiplied numerically and the exponents of identical bases are
        summed.  Returns the numeric coefficient alone when it is zero or
        when no other factor is left, the factor itself for a one-factor
        product and a ``Mul`` (numeric coefficient first) otherwise.
        """
        num = Integer(1)
        d = {}  # base -> accumulated exponent
        for a in args:
            factors = a.args if a.type == MUL else (a,)
            for factor in factors:
                if factor.type == INTEGER:
                    num *= factor
                    continue
                # as_base_exp() returns (base, exponent), in that order.
                base, exp = factor.as_base_exp()
                if base in d:
                    d[base] += exp
                else:
                    d[base] = exp
        if num.i == 0:
            return num

        new_args = []
        for base, exp in d.items():
            power = Pow((base, exp))
            if power.type == INTEGER:
                # The power collapsed to a number (e.g. exponent 0 -> 1).
                num *= power
            else:
                new_args.append(power)

        if not new_args:
            return num
        if num.i != 1:
            new_args.insert(0, num)
        if len(new_args) == 1:
            return new_args[0]
        return Mul(new_args, False)

    def __hash__(self):
        # ``__eq__`` ignores the order of the factors, so the hash must too.
        return hash_seq(sorted(self.args, key=hash))

    def __eq__(self, o):
        """Two products are equal when they have the same factors in any order."""
        o = sympify(o)
        # Non-expression operands (None, strings, ...) are simply unequal.
        if not isinstance(o, Basic) or o.type != MUL:
            return False
        return sorted(self.args, key=hash) == sorted(o.args, key=hash)

    def as_coeff_rest(self):
        """Split off a leading integer coefficient: ``2*x*y -> (2, x*y)``."""
        if self.args and self.args[0].type == INTEGER:
            return (self.args[0], Mul(self.args[1:]))
        return (Integer(1), self)

    def __str__(self):
        if not self.args:
            # Only reachable for a non-canonicalized, empty product.
            return "1"
        parts = []
        for x in self.args:
            s = str(x)
            if x.type in (ADD, MUL):
                # A sum (or nested product) must be grouped, otherwise
                # 2*(x + y) would print as "2*x + y".
                s = "(%s)" % s
            parts.append(s)
        return "*".join(parts)
class Pow(Basic):
    """A power; ``args`` is the pair ``(base, exponent)``.

    ``Pow(args)`` sympifies and canonicalizes *args*, so the result is not
    necessarily a ``Pow`` instance.  ``Pow(args, False)`` stores *args* as
    given, without any processing.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, POW, args)
        args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Simplify trivial exponents: ``b**0 -> 1`` and ``b**1 -> b``."""
        base, exp = args
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        return Pow(args, False)

    def as_base_exp(self):
        """Return the stored ``(base, exponent)`` pair.

        The inherited default would report the whole power as a base with
        exponent 1, which keeps powers of the same base from being merged.
        """
        base, exp = self.args
        return (base, exp)

    def __str__(self):
        base, exp = self.args
        base_str = str(base)
        if base.type in (ADD, MUL, POW):
            # Group a compound base: (x*y)^2 must not print as "x*y^2".
            base_str = "(%s)" % base_str
        exp_str = str(exp)
        if exp.type in (ADD, MUL):
            # Group a compound exponent: x^(2*y) must not print as "x^2*y".
            exp_str = "(%s)" % exp_str
        return "%s^%s" % (base_str, exp_str)
def sympify(x):
    """Wrap a plain Python int in ``Integer``; return anything else as is."""
    return Integer(x) if isinstance(x, int) else x