Implemented crisscross algorithm for solving LP problems.
[sympycore.git] / sympycore / calculus / relational.py
blob3a313fae9309d9aadf8a995139b2fc308112f4b2
2 # Created in February 2008 by Fredrik Johansson.
"""Basic support for sign assumptions on Calculus expressions.

Defines the ``Assumptions`` class, which records inequalities such as
``0 < x`` and answers sign queries (positive, nonnegative, ...) with
True / False / None, plus the convenience function ``is_positive``.
"""

__docformat__ = "restructuredtext"
__all__ = ["Assumptions", "is_positive"]
10 from ..core import init_module
11 init_module.import_heads()
12 init_module.import_numbers()
14 from .algebra import Calculus
15 from ..logic import Logic, Lt, Le
@init_module
def init(module):
    """Create the module-level singletons once the core module is ready.

    Binds ``no_assumptions`` (an empty ``Assumptions``), ``globalctx``
    (a ``GlobalContext`` whose ``assumptions`` attribute starts out as
    ``no_assumptions``), and the constants ``pi`` and ``E``.
    """
    empty = Assumptions([])
    module.no_assumptions = empty
    ctx = GlobalContext()
    ctx.assumptions = empty
    module.globalctx = ctx
    # Imported here rather than at module top -- presumably to avoid an
    # import cycle with .functions (NOTE(review): confirm).
    from .functions import pi as _pi, E as _E
    module.pi, module.E = _pi, _E
def is_number(x):
    """Return True when *x* is a numeric literal.

    Accepts either a raw number (an instance of ``numbertypes``) or a
    ``Calculus`` expression whose head is ``NUMBER``.
    """
    if isinstance(x, numbertypes):
        return True
    return isinstance(x, Calculus) and x.head is NUMBER
class Assumptions:
    """A collection of sign assumptions used to decide inequalities.

    ``statements`` is an iterable whose items are ``True`` (ignored) or
    ``Logic`` relations with head ``LT``, ``LE``, ``GT`` or ``GE``.  Each
    relation is normalised to a single expression ``d``: strict relations
    give ``d > 0`` (kept in ``pos_values``), non-strict ones give
    ``d >= 0`` (kept in ``nonneg_values``).

    Sign queries (``positive``, ``nonnegative``, ``negative``,
    ``nonpositive``, ``zero``, ``nonzero`` and the comparison helpers
    ``eq``/``ne``/``lt``/``le``/``gt``/``ge``) use three-valued logic:
    ``True`` when the property is proved, ``False`` when its negation is
    proved, and ``None`` when it cannot be decided.

    Used as a context manager, an instance is installed as
    ``globalctx.assumptions`` for the duration of the ``with`` block and
    the previously active assumptions are restored on exit, so ``with``
    blocks may be nested.

    Raises ``ValueError`` for a ``False`` statement or an unsupported
    statement.
    """

    def __init__(self, statements=()):
        self.pos_values = []      # expressions d known to satisfy d > 0
        self.nonneg_values = []   # expressions d known to satisfy d >= 0
        self._saved = []          # contexts to restore in __exit__ (nesting)
        for stmt in statements:
            if stmt is True:
                continue
            if stmt is False:
                raise ValueError("got False as assumption")
            if not isinstance(stmt, Logic):
                raise ValueError("unknown assumption type: " + repr(stmt))
            head, (lhs, rhs) = stmt.pair
            if head is LT:
                self.pos_values.append(rhs - lhs)
            elif head is LE:
                self.nonneg_values.append(rhs - lhs)
            elif head is GT:
                self.pos_values.append(lhs - rhs)
            elif head is GE:
                self.nonneg_values.append(lhs - rhs)
            else:
                raise ValueError("unknown assumption: " + repr(stmt))

    def __repr__(self):
        ps = [repr(Lt(0, a)) for a in self.pos_values]
        ns = [repr(Le(0, a)) for a in self.nonneg_values]
        return "Assumptions([%s])" % ", ".join(ps + ns)

    def __enter__(self):
        # Remember what was active so nested ``with`` blocks unwind correctly.
        self._saved.append(globalctx.assumptions)
        globalctx.assumptions = self
        return self

    def __exit__(self, *args):
        if self._saved:
            globalctx.assumptions = self._saved.pop()
        else:
            # Unbalanced exit (no matching __enter__): reset to the default.
            globalctx.assumptions = no_assumptions

    def check(self, cond):
        """Decide a relation under these assumptions.

        ``cond`` may be a bool or None (returned unchanged) or a ``Logic``
        relation with head LT/LE/GT/GE.  Returns True/False/None; raises
        ``ValueError`` for anything else.
        """
        if isinstance(cond, (bool, type(None))):
            return cond
        if isinstance(cond, Logic):
            head, (lhs, rhs) = cond.pair
            if head is LT:
                return self.positive(rhs - lhs)
            if head is LE:
                return self.nonnegative(rhs - lhs)
            if head is GT:
                return self.positive(lhs - rhs)
            if head is GE:
                return self.nonnegative(lhs - rhs)
        raise ValueError(repr(cond))

    # Comparison helpers: each reduces a ? b to a sign query on a difference.
    def eq(s, a, b): return s.zero(a-b)
    def ne(s, a, b): return s.nonzero(a-b)
    def lt(s, a, b): return s.positive(b-a)
    def le(s, a, b): return s.nonnegative(b-a)
    def gt(s, a, b): return s.positive(a-b)
    def ge(s, a, b): return s.nonnegative(a-b)

    @staticmethod
    def _negate(t):
        """Three-valued NOT: None (undecided) stays None."""
        if t is None:
            return None
        return not t

    def negative(s, x):
        """Decide ``x < 0``, i.e. the negation of ``x >= 0``."""
        return s._negate(s.nonnegative(x))

    def nonpositive(s, x):
        """Decide ``x <= 0``, i.e. the negation of ``x > 0``."""
        return s._negate(s.positive(x))

    def zero(s, x):
        """Decide ``x == 0``; only a number can be proved zero here."""
        if is_number(x):
            return x == 0
        if s.positive(x) or s.negative(x):
            return False
        return None

    def nonzero(s, x):
        """Decide ``x != 0``."""
        if is_number(x):
            return x != 0
        if s.positive(x) or s.negative(x):
            return True
        return None

    @staticmethod
    def _as_add_or_mul(x):
        """Rewrite a dict/pair-stored expression as an equivalent ADD or MUL.

        TERM_COEFF becomes the MUL of its (term, coefficient) items,
        TERM_COEFF_DICT the ADD of ``t * c`` over its entries, and
        BASE_EXP_DICT the MUL of powers ``b**e`` over its entries.
        Returns None for any other head.
        """
        head = x.head
        if head is TERM_COEFF:
            # list, not map(): the result is iterated more than once
            return Calculus(MUL, [Calculus(a) for a in x.data])
        if head is TERM_COEFF_DICT:
            return Calculus(ADD, [t * c for t, c in x.data.items()])
        if head is BASE_EXP_DICT:
            # Entries are powers b**e, not products b*e (b*e would turn
            # 1/y == {y: -1} into -y).  A raw POW node is built so the
            # POW rules below apply directly.
            return Calculus(MUL, [Calculus(POW, (b, e))
                                  for b, e in x.data.items()])
        return None

    def _bounds_positive(s, x):
        """Decide ``x > 0`` from the stored bounds alone (None if undecided)."""
        ctx = no_assumptions
        if any(ctx.nonnegative(x - p) for p in s.pos_values):
            return True     # x >= p > 0
        if any(ctx.positive(x - q) for q in s.nonneg_values):
            return True     # x > q >= 0
        if any(ctx.nonpositive(x + p) for p in s.pos_values):
            return False    # x <= -p < 0
        if any(ctx.nonpositive(x + q) for q in s.nonneg_values):
            return False    # x <= -q <= 0
        return None

    def _bounds_nonnegative(s, x):
        """Decide ``x >= 0`` from the stored bounds alone (None if undecided)."""
        ctx = no_assumptions
        if any(ctx.nonnegative(x - q) for q in s.nonneg_values):
            return True     # x >= q >= 0
        if any(ctx.nonnegative(x - p) for p in s.pos_values):
            return True     # x >= p > 0
        if any(ctx.nonpositive(x + p) for p in s.pos_values):
            return False    # x <= -p < 0
        if any(ctx.negative(x + q) for q in s.nonneg_values):
            return False    # x < -q <= 0
        return None

    def positive(s, x):
        """Decide ``x > 0``: True, False (``x <= 0`` proved) or None."""
        x = Calculus.convert(x)
        rewritten = s._as_add_or_mul(x)
        if rewritten is not None:
            return s.positive(rewritten)
        head = x.head
        if head is NUMBER:
            if isinstance(x.data, realtypes):
                return x.data > 0
        elif head is ADD:
            args = x.data
            if any(s.positive(a) for a in args) and all(s.nonnegative(a) for a in args):
                return True
            if any(s.negative(a) for a in args) and all(s.nonpositive(a) for a in args):
                return False
        elif head is MUL:
            args = x.data
            if not all(s.nonzero(a) for a in args):
                return None
            negs = [s.negative(a) for a in args]
            if any(n is None for n in negs):
                # a nonzero factor of unknown sign (e.g. non-real) makes
                # the sign of the product undecidable
                return None
            # a product of nonzero reals is positive iff #negatives is even
            return sum(negs) % 2 == 0
        elif head is POW:
            b, e = x.data
            if s.positive(b) and s.positive(e):
                return True
        elif x == pi or x == E:
            return True
        return s._bounds_positive(x)

    def nonnegative(s, x):
        """Decide ``x >= 0``: True, False (``x < 0`` proved) or None."""
        x = Calculus.convert(x)
        rewritten = s._as_add_or_mul(x)
        if rewritten is not None:
            return s.nonnegative(rewritten)
        head = x.head
        if head is NUMBER:
            if isinstance(x.data, realtypes):
                return x.data >= 0
        elif head is ADD:
            args = x.data
            if all(s.nonnegative(a) for a in args):
                return True
            if all(s.negative(a) for a in args):
                return False
        elif head is MUL:
            if all(s.nonnegative(a) for a in x.data):
                return True
        elif head is POW:
            b, e = x.data
            if s.nonnegative(b) and s.nonnegative(e):
                return True
        elif x == pi or x == E:
            return True
        return s._bounds_nonnegative(x)
class GlobalContext(object):
    """Plain mutable namespace for process-wide state.

    ``init`` gives the module's instance an ``assumptions`` attribute,
    which ``Assumptions.__enter__``/``__exit__`` replace and restore.
    """
def is_positive(e):
    """Decide whether *e* > 0 under the currently active assumptions.

    Delegates to ``globalctx.assumptions.positive`` and returns its
    three-valued answer (True, False, or None when undecided).
    """
    active = globalctx.assumptions
    return active.positive(e)