Implemented crisscross algorithm for solving LP problems.
[sympycore.git] / sympycore / heads / exp_coeff_dict.py
blobb7531b21ca51da92ecec0b6b0e4da973f78d8d8f
__all__ = ['EXP_COEFF_DICT']

from .base import ArithmeticHead

from ..core import init_module, Pair, Expr
# NOTE(review): names used below without an explicit import (SYMBOL, POW,
# MUL, ADD, NUMBER, TERM_COEFF, IntegerList, inttypes, dict_add_item,
# number_div, ...) are presumably made available in this module's namespace
# by the init_module.import_* calls below -- confirm against ..core.
init_module.import_heads()
init_module.import_numbers()
init_module.import_lowlevel_operations()
@init_module
def _init(module):
    """Bind ``multinomial_coefficients`` as a global of ``module``.

    The import is performed inside this init hook instead of at the top of
    the file (presumably to avoid an import cycle with ..arithmetic --
    TODO confirm).
    """
    from ..arithmetic.number_theory import multinomial_coefficients as mc
    setattr(module, 'multinomial_coefficients', mc)
class ExpCoeffDict(ArithmeticHead):
    """
    Head for expanded polynomial-like expressions.

    The data of an expression with this head is ``Pair(variables, d)``:
    ``variables`` is a tuple of variables (strings or expressions) and ``d``
    maps an ``IntegerList`` of exponents, one entry per variable, to a
    non-zero coefficient.  The represented expression is the sum of
    ``coeff * var_1**exp_1 * ... * var_n**exp_n`` over the items of ``d``.
    """

    def is_data_ok(self, cls, data):
        """
        Return an error message if ``data`` is not valid for this head,
        otherwise return None.

        Valid data is ``Pair(tuple, dict)`` whose dict has non-zero
        coefficients and ``IntegerList`` keys.  As a side effect, the first
        key found that is not an ``IntegerList`` is converted in place before
        the error is reported.
        """
        if type(data) is not Pair:
            return 'data must be Pair instance but got %r' % (type(data).__name__)
        variables, exps_coeff_dict = data.pair
        if type(variables) is not tuple:
            return 'data[0] must be tuple but got %r' % (type(variables).__name__)
        if type(exps_coeff_dict) is not dict:
            return 'data[1] must be dict but got %r' % (type(exps_coeff_dict).__name__)
        # Iterate over a snapshot: the loop body may replace a key.
        for exps, coeff in list(exps_coeff_dict.items()):
            if not coeff:
                return 'data[1] contains exp-zero-coeff pair: (%s, %s)' % (exps, coeff)
            if type(exps) is not IntegerList:
                # Record the offending type *before* converting the key,
                # otherwise the message would always report 'IntegerList'.
                bad_type_name = type(exps).__name__
                del exps_coeff_dict[exps]
                exps_coeff_dict[IntegerList(exps)] = coeff
                return 'data[1] keys must be IntegerList instances but got %r' % (bad_type_name)

    def __repr__(self):
        return 'EXP_COEFF_DICT'

    def data_to_str_and_precedence(self, cls, data):
        """
        Delegate string conversion to ADD by rewriting every monomial as a
        TERM_COEFF of a MUL of POW factors; terms are visited in decreasing
        order of their exponent lists.
        """
        variables, exp_coeff_dict = data
        terms = []
        for exps in sorted(exp_coeff_dict, reverse=True):
            coeff = exp_coeff_dict[exps]
            if type(exps) is not IntegerList:
                # temporary hook for SPARSE_POLY head, remove if block when SPARSE_POLY is gone
                exps = IntegerList(exps)
            factors = []
            for var, exp in zip(variables, exps):
                if not isinstance(var, cls):
                    var = cls(SYMBOL, var)
                factors.append(cls(POW, (var, exp)))
            terms.append(cls(TERM_COEFF, (cls(MUL, factors), coeff)))
        return ADD.data_to_str_and_precedence(cls, terms)

    def to_lowlevel(self, cls, data, pair):
        """
        Return a lower-level form of the expression: 0 for an empty dict,
        the monomial ``coeff * prod(var**exp)`` (zero exponents skipped)
        for a single term, and ``pair`` unchanged otherwise.
        """
        variables, exp_coeff_dict = data.pair
        n = len(exp_coeff_dict)
        if n==0:
            return 0
        if n==1:
            exps, coeff = dict_get_item(exp_coeff_dict)
            if type(exps) is not IntegerList:
                # temporary hook for SPARSE_POLY head, remove if block when SPARSE_POLY is gone
                exps = IntegerList(exps)
            factors = []
            for var, exp in zip(variables, exps.data):
                if exp==0:
                    continue
                if not isinstance(var, cls):
                    var = cls(SYMBOL, var)
                if exp==1:
                    factors.append(var)
                else:
                    factors.append(cls(POW, (var, exp)))
            if not factors:
                return coeff
            term = MUL.new(cls, factors)
            return term_coeff_new(cls, (term, coeff))
        return pair

    def combine_variables(self, *seq):
        """
        Return a tuple of sorted variables combining given variables.

        Each argument may be an expression (a SYMBOL contributes its data,
        any other expression contributes itself), a string, None (ignored)
        or an arbitrarily nested tuple/list of these.

        Raises TypeError for any other kind of argument.
        """
        variables = set([])
        seq = list(seq)
        while seq:
            s = seq.pop(0)
            if isinstance(s, Expr):
                if s.head is SYMBOL:
                    variables.add(s.data)
                else:
                    variables.add(s)
            elif isinstance(s, str):
                variables.add(s)
            elif isinstance(s, (tuple, list)):
                seq.extend(s)
            elif s is None:
                pass
            else:
                # '%n' is not a valid format character: it raised ValueError
                # instead of the intended TypeError.
                raise TypeError('expected an expression or a sequence of expressions but got %r'
                                % (type(s).__name__))
        return tuple(sorted(variables))

    def make_exponent(self, expr, variables):
        """
        Return exponent list such that expr == variables ** exp_list.

        The plain integer 0 (the exponent of a constant) and an ``expr``
        that is not among ``variables`` both give the all-zero list.
        """
        n = len(variables)
        if type(expr) is int and expr == 0:
            # shortcut to return exponent of a constant expr
            return IntegerList([0] * n)
        try:
            i = list(variables).index(expr)
        except ValueError:
            # list.index raises (it never returns -1) when expr is missing.
            return IntegerList([0] * n)
        return IntegerList([0] * i + [1] + [0] * (n - i - 1))

    def to_TERM_COEFF_DICT(self, Algebra, data, expr):
        """
        Convert to TERM_COEFF_DICT form: every monomial becomes the term
        ``base_exp_dict_new(Algebra, {var: exp, ...})`` (zero exponents
        omitted) carrying the monomial's coefficient.
        """
        variables, exps_coeff_dict = data
        d = {}
        for exps, coeff in exps_coeff_dict.iteritems():
            d1 = {}
            for exp, var in zip(exps, variables):
                if exp:
                    d1[Algebra(var)] = exp
            term = base_exp_dict_new(Algebra, d1)
            # Carry the coefficient over; passing 1 here dropped it.
            term_coeff_dict_add_item(Algebra, d, term, coeff)
        return term_coeff_dict_new(Algebra, d)

    def to_EXP_COEFF_DICT(self, cls, data, expr, variables = None):
        """
        Return ``expr`` re-expressed over the sorted union of its own
        variables and ``variables`` (see combine_variables); variables that
        ``expr`` does not contain get exponent 0.  ``expr`` itself is
        returned when ``variables`` is None or equals its variables.
        """
        if variables is None:
            return expr
        evars, edata = data.pair
        if evars==variables:
            return expr
        variables = self.combine_variables(evars, variables)
        levars = list(evars)
        # For each new variable, its index in evars (or None if absent).
        index_map = []
        for v in variables:
            try:
                i = levars.index(v)
            except ValueError:
                i = None
            index_map.append(i)
        d = {}
        for exps, coeff in edata.iteritems():
            new_exps = IntegerList([(exps[i] if i is not None else 0) for i in index_map])
            d[new_exps] = coeff
        return cls(self, Pair(variables, d))

    def neg(self, cls, expr):
        """Return ``-expr``."""
        return self.commutative_mul_number(cls, expr, -1)

    def add_number(self, cls, lhs, rhs, inplace=False):
        """
        Return ``lhs + rhs`` for a number ``rhs`` by adding it to the
        coefficient of the all-zero exponent.  Modifies ``lhs`` in place
        when ``inplace`` is true and ``lhs`` is writable.
        """
        if not rhs:
            return lhs
        lvars, ldict = lhs.data.pair
        zero_exp = self.make_exponent(0, lvars)
        if inplace and lhs.is_writable:
            dict_add_item(cls, ldict, zero_exp, rhs)
            return lhs
        d = ldict.copy()
        dict_add_item(cls, d, zero_exp, rhs)
        return cls(self, Pair(lvars, d))

    def add(self, cls, lhs, rhs, inplace=False):
        """
        Return ``lhs + rhs``.

        A non-EXP_COEFF_DICT ``rhs`` is first converted using the variables
        of ``lhs``.  With equal variable tuples the dicts are added directly
        (in place when ``inplace`` is true and ``lhs`` is writable);
        otherwise both operands are re-expressed over the sorted union.
        """
        lvars, ldict = lhs.data.pair
        rhead, rdata = rhs.pair
        if rhead is not EXP_COEFF_DICT:
            rhs = rhead.to_EXP_COEFF_DICT(cls, rdata, rhs, lvars)
            rhead, rdata = rhs.pair
        rvars, rdict = rdata.pair
        if lvars == rvars:
            if inplace and lhs.is_writable:
                dict_add_dict(cls, ldict, rdict)
                return lhs
            d = ldict.copy()
            dict_add_dict(cls, d, rdict)
            return cls(self, Pair(lvars, d))
        variables = tuple(sorted(set(lvars + rvars)))
        lhs = self.to_EXP_COEFF_DICT(cls, lhs.data, lhs, variables)
        rhs = self.to_EXP_COEFF_DICT(cls, rhs.data, rhs, variables)
        d = lhs.data.data.copy()
        dict_add_dict(cls, d, rhs.data.data)
        return cls(self, Pair(variables, d))

    def sub(self, cls, lhs, rhs):
        """Return ``lhs - rhs`` computed as ``lhs + (-rhs)``."""
        return self.add(cls, lhs, -rhs)

    def commutative_mul(self, cls, lhs, rhs):
        """
        Return ``lhs * rhs`` for commuting variables.

        NUMBER operands scale the coefficients; other operands are converted
        to EXP_COEFF_DICT and multiplied over the sorted union of variables.
        Raises NotImplementedError if ``rhs`` cannot be converted.
        """
        rhead, rdata = rhs.pair
        if rhead is NUMBER:
            return self.commutative_mul_number(cls, lhs, rdata)
        lvars, ldict = lhs.data.pair
        if rhead is not EXP_COEFF_DICT:
            rhs = rhead.to_EXP_COEFF_DICT(cls, rdata, rhs, lvars)
            rhead, rdata = rhs.pair
        if rhead is EXP_COEFF_DICT:
            rvars, rdict = rdata.pair
            d = {}
            if lvars == rvars:
                exp_coeff_dict_mul_dict(cls, d, ldict, rdict)
                return cls(self, Pair(lvars, d))
            variables = tuple(sorted(set(lvars + rvars)))
            lhs = self.to_EXP_COEFF_DICT(cls, lhs.data, lhs, variables)
            rhs = self.to_EXP_COEFF_DICT(cls, rhs.data, rhs, variables)
            exp_coeff_dict_mul_dict(cls, d, lhs.data[1], rhs.data[1])
            return cls(self, Pair(variables, d))
        raise NotImplementedError(repr((self, rhs.head)))

    def commutative_mul_number(self, cls, lhs, rhs):
        """
        Return ``lhs * rhs`` for a number ``rhs`` by scaling every
        coefficient.  Multiplying by 0 gives an empty dict and multiplying
        by 1 returns ``lhs`` itself.
        """
        lvars, ldict = lhs.data.pair
        if rhs==0:
            return cls(self, Pair(lvars, {}))
        if rhs==1:
            return lhs
        d = {}
        for exps, coeff in ldict.iteritems():
            d[exps] = coeff * rhs
        return cls(self, Pair(lvars, d))

    non_commutative_mul_number = commutative_mul_number

    def combine_ncmul_exponents(self, lexps, rexps, variables):
        """
        Return exponents of non-commutative multiplication:
          variables ** ([lexps] + [rexps]) -> variables ** exps
        Exponents of neighbouring (ignoring zero exponents) equal variables
        are merged into the leftmost one.
        TODO: move the algorithm to expr.py and implement its C version.
        """
        exps = list(lexps) + list(rexps)
        n = len(exps)
        i0 = 0
        while 1:
            i = i0
            while i < n and not exps[i]:
                i += 1
            if i==n:
                break
            j = i+1
            while j < n and not exps[j]:
                j += 1
            if j==n:
                break
            if variables[i] == variables[j]:
                exps[i] += exps[j]
                exps[j] = 0
                # merging may create a zero exponent: rescan from the start
                i0 = 0
            else:
                i0 = j
        return IntegerList(exps)

    def eliminate_trivial_exponents(self, Algebra, exps_coeff_dict, variables):
        """
        Eliminate common trivial exponents (=0) from the set of exponents
        and return the corresponding non-trivial list of variables.
        Note: exps_coeff_dict will be changed in place.

        When every variable has a non-zero exponent in some key (trivially
        so when there are no variables), ``variables`` is returned and the
        dict is left untouched.
        """
        nvars = len(variables)
        has_non_zero = [False] * nvars
        remaining = nvars   # number of variables not yet seen non-zero
        if not remaining:
            # Without this guard an empty dict with no variables tripped an
            # internal assertion.
            return variables
        for exps in exps_coeff_dict:
            for i, exp in enumerate(exps):
                if exp and not has_non_zero[i]:
                    has_non_zero[i] = True
                    remaining -= 1
            if not remaining: # no common trivial exponents
                return variables
        keep = [i for i in range(nvars) if has_non_zero[i]]
        # Snapshot the items: keys are replaced while we go.
        for exps, coeff in list(exps_coeff_dict.items()):
            del exps_coeff_dict[exps]
            exps = IntegerList([exps[i] for i in keep])
            dict_add_item(Algebra, exps_coeff_dict, exps, coeff)
        return tuple([variables[i] for i in keep])

    def non_commutative_mul(self, cls, lhs, rhs):
        """
        Return ``lhs * rhs`` keeping factor order.

        The variable tuples are concatenated rather than merged, so
        ``lexp + rexp`` lines up with ``lvars + rvars``; neighbouring equal
        variables are then combined and all-zero exponent positions are
        dropped.  Raises NotImplementedError if ``rhs`` cannot be converted.
        """
        rhead, rdata = rhs.pair
        if rhead is NUMBER:
            return self.non_commutative_mul_number(cls, lhs, rdata)
        lvars, ldict = lhs.data.pair
        if rhead is not EXP_COEFF_DICT:
            rhs = rhead.to_EXP_COEFF_DICT(cls, rdata, rhs, lvars)
            rhead, rdata = rhs.pair
        if rhead is EXP_COEFF_DICT:
            rvars, rdict = rdata.pair
            # The original exponent lists are used directly, so no conversion
            # of lhs/rhs to a common variable tuple is needed here.
            variables = lvars + rvars
            d = {}
            # TODO: move the following double-loop to expr.py and implement it in C.
            for lexp, lcoeff in ldict.iteritems():
                for rexp, rcoeff in rdict.iteritems():
                    exp = self.combine_ncmul_exponents(lexp, rexp, variables)
                    dict_add_item(cls, d, exp, lcoeff * rcoeff)
            variables = self.eliminate_trivial_exponents(cls, d, variables)
            return cls(self, Pair(variables, d))
        raise NotImplementedError(repr((self, rhs.head)))

    def commutative_div(self, cls, lhs, rhs):
        """Return ``lhs / rhs`` computed as ``lhs * rhs**-1``."""
        return lhs * rhs ** -1

    def pow(self, cls, base, exp):
        """
        Return ``base ** exp``.  Integer exponents (plain or NUMBER-wrapped)
        are handled by pow_number; any other exponent makes ``base**exp`` a
        new variable of the result.
        """
        variables, exp_coeff_dict = base.data.pair
        if isinstance(exp, Expr) and exp.head is NUMBER and isinstance(exp.data, inttypes):
            exp = exp.data
        if isinstance(exp, inttypes):
            return self.pow_number(cls, base, exp)
        expr = cls(POW, (base, exp))
        variables = self.combine_variables(variables, expr)
        exps = self.make_exponent(expr, variables)
        return cls(self, Pair(variables, {exps:1}))

    def pow_number(self, cls, base, exp):
        """
        Return ``base ** exp`` for a number ``exp``.

        Integer exponents: 0 gives 1, 1 gives ``base``, exp > 1 is expanded
        with multinomial coefficients and exp < 0 of a single-term base is
        inverted term-wise.  Every other case (negative power of a sum,
        non-integer exponent) makes ``base**exp`` a new variable.
        """
        variables, exp_coeff_dict = base.data.pair
        if isinstance(exp, inttypes):
            if exp==0:
                # Use an IntegerList key like every other exponent.
                return cls(self, Pair(variables, {self.make_exponent(0, variables):1}))
            if exp==1:
                return base
            if exp>1:
                exps_coeff_list = base.data.data.items()
                m = len(variables)
                mdata = multinomial_coefficients(len(exps_coeff_list), exp)
                d = {}
                for e,c in mdata.iteritems():
                    new_exps = IntegerList([0]*m)
                    new_coeff = c
                    for e_i, (exps,coeff) in zip(e, exps_coeff_list):
                        new_exps += exps * e_i
                        new_coeff *= coeff ** e_i
                    dict_add_item(cls, d, new_exps, new_coeff)
                return cls(self, Pair(variables, d))
            # exp is a negative integer
            if len(exp_coeff_dict)==1:
                exps, coeff = dict_get_item(exp_coeff_dict)
                # (c * x**e) ** exp == (1 / c**(-exp)) * x**(e*exp); the
                # magnitude of exp must be applied, not only its sign.
                inv_coeff = number_div(cls, 1, coeff ** (-exp))
                new_exps = IntegerList([e * exp for e in exps])
                return cls(self, Pair(variables, {new_exps: inv_coeff}))
        expr = cls(POW, (base, exp))
        variables = self.combine_variables(variables, expr)
        exps = self.make_exponent(expr, variables)
        return cls(self, Pair(variables, {exps:1}))
# Shared head instance exported via __all__; ExpCoeffDict methods test
# operand heads by identity against this name (``rhead is EXP_COEFF_DICT``).
EXP_COEFF_DICT = ExpCoeffDict()