benchmarks
[sympyx.git] / test_basic.py
blobce69cfe0716f9a804db317305f0103623667bdb6
1 from sympy import Symbol, Add, Mul, Pow, Integer, SYMBOL, ADD, MUL, POW, \
2 INTEGER
def test_eq():
    """Check ``==`` and ``!=`` on symbols, sums, products, powers and integers."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
    # A second, separately constructed symbol carrying the same name as x.
    x_dup = Symbol("x")

    # Symbols: both the equality and the inequality operator are exercised.
    assert x == x
    assert not (x != x)
    assert x == x_dup
    assert not (x != x_dup)
    assert x != y

    # Sums compare equal regardless of operand order.
    assert x + y == x + y
    assert x_dup + y == x + y
    assert x + y == y + x
    assert x + y != y + z

    # Products compare equal regardless of operand order.
    assert x * y == x * y
    assert x_dup * y == x * y
    assert x * y == y * x
    assert x * y != y * z

    # Powers: swapping base and exponent gives a different expression.
    assert x ** y == x ** y
    assert x_dup ** y == x ** y
    assert x ** y != y ** x
    assert x ** y != y ** z
    assert x_dup ** y != x ** z

    # Integers compare by value.
    assert Integer(3) == Integer(3)
    assert Integer(3) != Integer(4)
def test_add():
    """Check that ``+`` ignores grouping and order and collects equal terms."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Different groupings / orderings of the same three terms are equal.
    assert (x + y) + z == x + (y + z)
    assert (z + x) + y == x + (y + z)
    # Replacing y by a second x must give a different sum.
    assert (z + x) + x != x + (y + z)

    # Repeated terms are collected into an integer multiple.
    two = Integer(2)
    assert x + x == two * x
    assert ((x + y) + z) + x == (two * x + y) + z
def test_mul():
    """Check that ``*`` ignores grouping and order and collects equal factors."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Different groupings / orderings of the same three factors are equal.
    assert (x * y) * z == x * (y * z)
    assert (z * x) * y == x * (y * z)
    # Replacing y by a second x must give a different product.
    assert (z * x) * x != x * (y * z)

    # Repeated factors are collected into an integer power.
    two = Integer(2)
    assert x * x == x ** two
    assert ((x * y) * z) * x == ((x ** two) * y) * z
def test_arit():
    """Check how the arithmetic operators are expressed via ``+``, ``*``, ``**``."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
    minus_one = Integer(-1)
    two = Integer(2)

    # NOTE(review): the assertions marked "same expression" have an identical
    # expression on both sides, so they only check that the operator runs and
    # that the result equals itself.

    # Addition.
    assert x+y == x + y  # same expression
    assert x+y+z == (x + y) + z

    # Subtraction is addition of the operand multiplied by -1.
    assert x-y == x + (minus_one * y)
    assert y-x == (minus_one * x) + y

    # Multiplication.
    assert x*y == x * y  # same expression
    assert x*y*z == z * (x * y)

    # Division is multiplication by the operand raised to -1.
    assert x/y == x * (y ** minus_one)
    assert y/x == (x ** minus_one) * y

    # Exponentiation.
    assert x**two == x ** two  # same expression

    # Unary operators.
    assert -x == minus_one * x
    assert +x == x
def test_int_conversion():
    """Check that plain Python ints used as operands behave like ``Integer``."""
    x = Symbol("x")

    # Compare against an explicit Integer operand so that the two sides are
    # built differently and a broken int conversion would be detected.
    assert x + 1 == x + Integer(1)
    assert x*1 == x
    assert x**1 == x
    assert x/2 == x * (Integer(2) ** -1)
def test_expand1():
    """Check ``expand()`` on integer powers of sums of two and three symbols."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Square and cube of a two-term sum.
    square = ((x + y)**2).expand()
    assert square == x**2 + 2*x*y + y**2

    cube = ((x + y)**3).expand()
    assert cube == x**3 + 3*x**2*y + 3*x*y**2 + y**3

    # Square of a three-term sum.
    trinomial_square = ((x + y + z)**2).expand()
    assert trinomial_square == x**2 + y**2 + z**2 + 2*x*y + 2*x*z + 2*y*z
def test_expand2():
    """Check ``expand()`` on products that mix symbols, sums and powers of sums."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # A product without any sum is left as it is.
    assert (2*x*y).expand() == 2*x*y

    # Product of two sums, and a symbol times a squared sum.
    assert ((x + y) * (x + z)).expand() == x**2 + x*y + x*z + y*z
    assert (x*(x + y)**2).expand() == x**3 + 2*x**2*y + x*y**2

    # Sum of two products that share the squared factor.
    assert (x*(x + y)**2 + z*(x + y)**2).expand() == (
        x**3 + 2*x**2*y + y**2*z + x**2*z + x*y**2 + 2*x*y*z
    )

    # Product with a numeric coefficient and a sum of products.
    assert (2*x * (y*x + y*z)).expand() == 2*x**2*y + 2*x*y*z

    # Squared sum times another sum.
    assert ((x + y)**2 * (x + z)).expand() == (
        x**3 + 2*x**2*y + y**2*z + x**2*z + x*y**2 + 2*x*y*z
    )
def test_canonicalization():
    """Check that trivial expressions reduce to their simplest form."""
    # Only x is needed; the assertions below use no other symbol.
    x = Symbol("x")

    # Additive identities and cancellation.
    assert x-x == 0
    assert x+0 == x
    assert x-0 == x

    # Multiplicative identity.
    assert x*1 == x

    # Powers with a trivial exponent or base.
    assert x**1 == x
    assert 1**x == 1
    assert 0**x == 0
def test_pow():
    """Check that nested powers and repeated products are combined."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Nested powers: the exponents are multiplied.
    assert (x**2)**3 == x**6
    assert (x**y)**3 == x**(3*y)
    # This identity does not hold for all values (for example x = -1, y = 2,
    # z = 1/2 gives 1 on the left and -1 on the right), so this assertion may
    # not be mathematically justified.
    assert (x**y)**z == x**(y*z)

    # Repeated multiplication gives an integer power.
    assert x*x == x**2
    assert x*x*x == x**3
def test_args_type():
    """Check the ``type`` tag and the ``args`` of each kind of expression."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # Sums: args are compared as sets, so their order does not matter here.
    assert (x + y).type == ADD
    assert set((x + y).args) == set((x, y))
    assert set((x + y).args) != set((x, z))

    # Products: args are compared as sets as well.
    assert (x * y * z).type == MUL
    assert set((x * y * z).args) == set((x, y, z))

    # Powers: args is exactly the (base, exponent) tuple.
    assert (x ** y).type == POW
    assert (x ** y).args == (x, y)

    # Symbols and integers have an empty args tuple.
    assert x.type == SYMBOL
    assert x.args == ()
    assert Integer(5).type == INTEGER
    assert Integer(5).args == ()

    # A difference is a sum whose second argument is the negated operand.
    assert (x - y).type == ADD
    assert set((x - y).args) == set((x, -y))
def test_hash():
    """Check that ``hash()`` agrees with equality for the tested expressions."""
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
    # A second, separately constructed symbol carrying the same name as x.
    x_dup = Symbol("x")

    # Symbols: different names hash differently, the same name hashes equally.
    assert hash(x) != hash(y)
    assert hash(x) != hash(z)
    assert hash(x) == hash(x_dup)

    # Integers hash by value.
    assert hash(Integer(3)) == hash(Integer(3))
    assert hash(Integer(3)) != hash(Integer(4))

    # Products: the hash does not depend on the order of the factors.
    assert hash(x * y) == hash(y * x)
    assert hash(x * y) == hash(y * x_dup)
    assert hash(x * y) != hash(y * z)
    assert hash(x * y * z) == hash(y * z * x)
    assert hash(x * y * z) == hash(y * z * x_dup)
def test_hash2():
    """Check that products differing only in factor order combine as equal terms."""
    x = Symbol("x")
    y = Symbol("y")
    # A second, separately constructed symbol carrying the same name as x.
    a = Symbol("x")

    # x*y and y*x are collected into one term or cancel each other.
    assert x*y+y*x == 2*x*y
    assert x*y-y*x == 0

    # The same holds when one factor is the separately constructed symbol.
    assert x*y+y*a == 2*x*y
    assert x*y-y*a == 0