clever expansion implemented
[sympyx.git] / test_basic.py
blob538433e1b5ac50d58faae12fd536478f126bfae0
1 from sympy import Symbol, Add, Mul, Pow, Integer, SYMBOL, ADD, MUL, POW, \
2 INTEGER
def test_eq():
    """Check == and != on symbols, sums, products, powers and integers."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
    # An independently constructed symbol sharing x's name.
    x_twin = Symbol("x")

    # Symbols: same name -> equal, different name -> unequal.
    assert x == x
    assert not x != x
    assert x == x_twin
    assert not x != x_twin
    assert x != y

    # Sums compare equal regardless of operand order.
    assert x + y == x + y
    assert x_twin + y == x + y
    assert x + y == y + x
    assert x + y != y + z

    # Products compare equal regardless of operand order.
    assert x * y == x * y
    assert x_twin * y == x * y
    assert x * y == y * x
    assert x * y != y * z

    # Powers are not commutative: base and exponent must both match.
    assert x ** y == x ** y
    assert x_twin ** y == x ** y
    assert x ** y != y ** x
    assert x ** y != y ** z
    assert x_twin ** y != x ** z

    # Integers compare by value.
    assert Integer(3) == Integer(3)
    assert Integer(3) != Integer(4)
def test_add():
    """Check that addition is associative/commutative and merges like terms."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Grouping and ordering of summands must not matter.
    assert (x + y) + z == x + (y + z)
    assert (z + x) + y == x + (y + z)
    assert (z + x) + x != x + (y + z)

    # A repeated summand becomes a coefficient.
    assert x + x == Integer(2) * x
    assert ((x + y) + z) + x == (Integer(2) * x + y) + z

    # Plain Python ints mixed into a sum are combined (1 + 5 -> 6).
    first = 1 + z + x + y * x + 5
    second = 1 + x * y + 5 + x + z
    assert first == second
    assert second == 6 + x + z + x * y
def test_mul():
    """Check that multiplication is associative/commutative and merges powers."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Grouping and ordering of factors must not matter.
    assert (x * y) * z == x * (y * z)
    assert (z * x) * y == x * (y * z)
    assert (z * x) * x != x * (y * z)

    # A repeated factor becomes an exponent.
    assert x * x == x ** Integer(2)
    assert ((x * y) * z) * x == ((x ** Integer(2)) * y) * z

    # Plain Python ints mixed into a product are combined (2 * 5 -> 10).
    lhs = 2 * z * x * y ** x * 5
    rhs = 2 * y ** x * 5 * x * z
    assert lhs == rhs
    assert rhs == 10 * x * z * y ** x
def test_arit():
    """Check how the arithmetic operators map onto Add/Mul/Pow forms."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # NOTE(review): the first assertion of the addition, multiplication and
    # power groups compares an expression with an identical one, so it only
    # checks that the operator works at all.

    # Addition and subtraction (a - b is a + (-1) * b).
    assert x + y == x + y
    assert x + y + z == (x + y) + z
    assert x - y == x + (Integer(-1) * y)
    assert y - x == (Integer(-1) * x) + y

    # Multiplication and division (a / b is a * b ** -1).
    assert x * y == x * y
    assert x * y * z == z * (x * y)
    assert x / y == x * (y ** Integer(-1))
    assert y / x == (x ** Integer(-1)) * y

    # Powers.
    assert x ** Integer(2) == x ** Integer(2)

    # Unary operators.
    assert -x == Integer(-1) * x
    assert +x == x
def test_int_conversion():
    """Check that plain Python ints used as operands behave like Integer."""
    x = Symbol("x")
    # Compare against an explicit Integer(1) so the int operand's conversion
    # is actually exercised; comparing ``x + 1`` with ``x + 1`` again would
    # be a tautology that passes no matter how 1 is handled.
    assert x + 1 == x + Integer(1)
    # Multiplying by 1 and raising to the power 1 are identities.
    assert x * 1 == x
    assert x ** 1 == x
    # Division by an int is multiplication by its Integer reciprocal.
    assert x / 2 == x * (Integer(2) ** -1)
def test_expand1():
    """Check that expand() applies the multinomial theorem to powers of sums."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    square = ((x + y) ** 2).expand()
    assert square == x ** 2 + 2 * x * y + y ** 2

    cube = ((x + y) ** 3).expand()
    assert cube == x ** 3 + 3 * x ** 2 * y + 3 * x * y ** 2 + y ** 3

    trinomial_square = ((x + y + z) ** 2).expand()
    assert trinomial_square == (
        x ** 2 + y ** 2 + z ** 2 + 2 * x * y + 2 * x * z + 2 * y * z
    )
def test_expand2():
    """Check that expand() distributes products over sums (including powers)."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Nothing to distribute: the product is returned unchanged.
    assert (2 * x * y).expand() == 2 * x * y

    assert ((x + y) * (x + z)).expand() == x ** 2 + x * y + x * z + y * z
    assert (x * (x + y) ** 2).expand() == x ** 3 + 2 * x ** 2 * y + x * y ** 2

    sum_of_products = (x * (x + y) ** 2 + z * (x + y) ** 2).expand()
    assert sum_of_products == (
        x ** 3 + 2 * x ** 2 * y + y ** 2 * z + x ** 2 * z + x * y ** 2
        + 2 * x * y * z
    )

    assert (2 * x * (y * x + y * z)).expand() == 2 * x ** 2 * y + 2 * x * y * z

    power_times_sum = ((x + y) ** 2 * (x + z)).expand()
    assert power_times_sum == (
        x ** 3 + 2 * x ** 2 * y + y ** 2 * z + x ** 2 * z + x * y ** 2
        + 2 * x * y * z
    )
def test_canonicalization():
    """Check that trivial operations with 0 and 1 simplify on construction."""
    # Only x is needed; the unused y and z symbols were dropped.
    x = Symbol("x")

    assert x - x == 0
    assert x * 1 == x
    assert x + 0 == x
    assert x - 0 == x
    assert x ** 1 == x
    assert 1 ** x == 1
    # NOTE(review): 0**x -> 0 is only valid for positive x (0**0 is 1);
    # this assertion pins the library's current simplification rule.
    assert 0 ** x == 0
def test_pow():
    """Check that nested powers combine and repeated products become powers."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    assert (x ** 2) ** 3 == x ** 6
    assert (x ** y) ** 3 == x ** (3 * y)
    # (x**y)**z == x**(y*z) is only guaranteed for positive real x and real
    # y; in general it fails (e.g. ((-1)**2)**(1/2) is 1, but (-1)**1 is -1).
    # The assertion records the library's current, unconditional rule.
    assert (x ** y) ** z == x ** (y * z)

    assert x * x == x ** 2
    assert x * x * x == x ** 3
def test_args_type():
    """Check the ``type`` tag and ``args`` exposed by each expression kind."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    total = x + y
    assert total.type == ADD
    assert set(total.args) == {x, y}
    assert set(total.args) != {x, z}

    product = x * y * z
    assert product.type == MUL
    assert set(product.args) == {x, y, z}

    power = x ** y
    assert power.type == POW
    # Power arguments are ordered: (base, exponent).
    assert power.args == (x, y)

    # Atoms carry no arguments.
    assert x.type == SYMBOL
    assert x.args == ()
    five = Integer(5)
    assert five.type == INTEGER
    assert five.args == ()

    # Subtraction is represented as an Add with a negated term.
    difference = x - y
    assert difference.type == ADD
    assert set(difference.args) == {x, -y}
def test_hash():
    """Check that equal expressions hash equally and distinct ones differ here."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
    x_twin = Symbol("x")

    # Symbols hash by name.
    assert hash(x) != hash(y)
    assert hash(x) != hash(z)
    assert hash(x) == hash(x_twin)

    # Integers hash by value.
    assert hash(Integer(3)) == hash(Integer(3))
    assert hash(Integer(3)) != hash(Integer(4))

    # Product hashes do not depend on factor order.
    assert hash(x * y) == hash(y * x)
    assert hash(x * y) == hash(y * x_twin)
    assert hash(x * y) != hash(y * z)
    assert hash(x * y * z) == hash(y * z * x)
    assert hash(x * y * z) == hash(y * z * x_twin)
def test_hash2():
    """Check that commuted products are recognised as like terms in sums."""
    # z was never used by these assertions, so it is no longer created.
    x = Symbol("x")
    y = Symbol("y")
    # An independently constructed symbol sharing x's name.
    a = Symbol("x")

    assert x * y + y * x == 2 * x * y
    assert x * y - y * x == 0
    assert x * y + y * a == 2 * x * y
    assert x * y - y * a == 0
def test_hash3():
    """Check that equal expressions built in different orders share a dict key."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
    first = z + x + y * x
    second = x * y + x + z

    # Dict membership requires both equal hashes and __eq__ to agree.
    table = {first: 1}
    assert first in table
    assert second in table