remove_all_equalities: also remove parameter equalities after main compression
[barvinok.git] / bfcounter.cc
blob56c79ff3e635a8a484e9ba3657d6b5349a1d17a1
1 #include <string.h>
2 #include <vector>
3 #include <ostream>
4 #include <iostream>
5 #include <NTL/vec_ZZ.h>
6 #include <NTL/mat_ZZ.h>
7 #include <barvinok/polylib.h>
8 #include "bfcounter.h"
9 #include "lattice_point.h"
11 using std::vector;
12 using std::cerr;
13 using std::endl;
15 static int lex_cmp(vec_ZZ& a, vec_ZZ& b)
17 assert(a.length() == b.length());
19 for (int j = 0; j < a.length(); ++j)
20 if (a[j] != b[j])
21 return a[j] < b[j] ? -1 : 1;
22 return 0;
25 void bf_base::add_term(bfc_term_base *t, vec_ZZ& num_orig, vec_ZZ& extra_num)
27 vec_ZZ num;
28 int d = num_orig.length();
29 num.SetLength(d-1);
30 for (int l = 0; l < d-1; ++l)
31 num[l] = num_orig[l+1] + extra_num[l];
33 add_term(t, num);
36 void bf_base::add_term(bfc_term_base *t, vec_ZZ& num)
38 int len = t->terms.NumRows();
39 int i, r;
40 for (i = 0; i < len; ++i) {
41 r = lex_cmp(t->terms[i], num);
42 if (r >= 0)
43 break;
45 if (i == len || r > 0) {
46 t->terms.SetDims(len+1, num.length());
47 insert_term(t, i);
48 t->terms[i] = num;
49 } else {
50 // i < len && r == 0
51 update_term(t, i);
55 bfc_term_base* bf_base::find_bfc_term(bfc_vec& v, int *powers, int len)
57 bfc_vec::iterator i;
58 for (i = v.begin(); i != v.end(); ++i) {
59 int j;
60 for (j = 0; j < len; ++j)
61 if ((*i)->powers[j] != powers[j])
62 break;
63 if (j == len)
64 return (*i);
65 if ((*i)->powers[j] > powers[j])
66 break;
69 bfc_term_base* t = new_bf_term(len);
70 v.insert(i, t);
71 memcpy(t->powers, powers, len * sizeof(int));
73 return t;
76 void bf_base::reduce(mat_ZZ& factors, bfc_vec& v, barvinok_options *options)
78 assert(v.size() > 0);
79 unsigned d = factors.NumCols();
81 if (d == lower)
82 return base(factors, v);
84 bf_reducer bfr(factors, v, this);
86 bfr.reduce(options);
88 if (bfr.vn.size() > 0)
89 reduce(bfr.nfactors, bfr.vn, options);
92 int bf_base::setup_factors(const mat_ZZ& rays, mat_ZZ& factors,
93 bfc_term_base* t, int s)
95 factors.SetDims(dim, dim);
97 int r;
99 for (r = 0; r < dim; ++r)
100 t->powers[r] = 1;
102 for (r = 0; r < dim; ++r) {
103 factors[r] = rays[r];
104 int k;
105 for (k = 0; k < dim; ++k)
106 if (factors[r][k] != 0)
107 break;
108 if (factors[r][k] < 0) {
109 factors[r] = -factors[r];
110 for (int i = 0; i < t->terms.NumRows(); ++i)
111 t->terms[i] += factors[r];
112 s = -s;
116 return s;
119 void bf_base::handle(const mat_ZZ& rays, Value *vertex, const QQ& c,
120 unsigned long det, barvinok_options *options)
122 bfc_term* t = new bfc_term(dim);
123 vector< bfc_term_base * > v;
124 v.push_back(t);
126 Matrix *points = Matrix_Alloc(det, dim);
127 Matrix* Rays = zz2matrix(rays);
128 lattice_points_fixed(vertex, vertex, Rays, Rays, points, det);
129 Matrix_Free(Rays);
130 matrix2zz(points, t->terms, points->NbRows, points->NbColumns);
131 Matrix_Free(points);
133 // the elements of factors are always lexpositive
134 mat_ZZ factors;
135 int s = setup_factors(rays, factors, t, 1);
137 t->c.SetLength(t->terms.NumRows());
139 for (int i = 0; i < t->c.length(); ++i) {
140 t->c[i].n = s * c.n;
141 t->c[i].d = c.d;
144 reduce(factors, v, options);
147 bfc_term_base* bfcounter_base::new_bf_term(int len)
149 bfc_term* t = new bfc_term(len);
150 t->c.SetLength(0);
151 return t;
154 void bfcounter_base::set_factor(bfc_term_base *t, int k, int change)
156 bfc_term* bfct = static_cast<bfc_term *>(t);
157 c = bfct->c[k];
158 if (change)
159 c.n = -c.n;
162 void bfcounter_base::set_factor(bfc_term_base *t, int k, mpq_t &f, int change)
164 bfc_term* bfct = static_cast<bfc_term *>(t);
165 value2zz(mpq_numref(f), c.n);
166 value2zz(mpq_denref(f), c.d);
167 c *= bfct->c[k];
168 if (change)
169 c.n = -c.n;
172 void bfcounter_base::set_factor(bfc_term_base *t, int k, const QQ& c_factor,
173 int change)
175 bfc_term* bfct = static_cast<bfc_term *>(t);
176 c = bfct->c[k];
177 c *= c_factor;
178 if (change)
179 c.n = -c.n;
182 void bfcounter_base::insert_term(bfc_term_base *t, int i)
184 bfc_term* bfct = static_cast<bfc_term *>(t);
185 int len = t->terms.NumRows()-1; // already increased by one
187 bfct->c.SetLength(len+1);
188 for (int j = len; j > i; --j) {
189 bfct->c[j] = bfct->c[j-1];
190 t->terms[j] = t->terms[j-1];
192 bfct->c[i] = c;
195 void bfcounter_base::update_term(bfc_term_base *t, int i)
197 bfc_term* bfct = static_cast<bfc_term *>(t);
199 bfct->c[i] += c;
202 void bf_reducer::compute_extra_num(int i)
204 clear(extra_num);
205 changes = 0;
206 no_param = 0; // r from text
207 only_param = 0; // k-r-s from text
208 total_power = 0; // k from text
210 for (int j = 0; j < nf; ++j) {
211 if (v[i]->powers[j] == 0)
212 continue;
214 total_power += v[i]->powers[j];
215 if (factors[j][0] == 0) {
216 only_param += v[i]->powers[j];
217 continue;
220 if (old2new[j] == -1)
221 no_param += v[i]->powers[j];
222 else
223 extra_num += -sign[j] * v[i]->powers[j] * nfactors[old2new[j]];
224 changes += v[i]->powers[j];
228 void bf_reducer::update_powers(const std::vector<int>& powers)
230 for (int l = 0; l < nnf; ++l)
231 npowers[l] = bpowers[l];
233 l_extra_num = extra_num;
234 l_changes = changes;
236 for (int l = 0; l < powers.size(); ++l) {
237 int n = powers[l];
238 if (n == 0)
239 continue;
240 assert(old2new[l] != -1);
242 npowers[old2new[l]] += n;
243 // interpretation of sign has been inverted
244 // since we inverted the power for specialization
245 if (sign[l] == 1) {
246 l_extra_num += n * nfactors[old2new[l]];
247 l_changes += n;
// Build the reduced factor matrix nfactors from `factors` by dropping
// column 0 and merging factors that agree, up to sign, on columns 1..d-1.
// On return, for every factor i:
//   - old2new[i] is the row of nfactors equal to +/- factors[i][1..d-1],
//     or -1 if factors[i][1..d-1] is identically zero;
//   - when old2new[i] != -1, sign[i] is the s with
//     factors[i][k] == s * nfactors[old2new[i]][k-1] for k = 1..d-1.
// Newly appended rows of nfactors have a positive first nonzero entry.
// Also allocates npowers and bpowers with nnf entries each.
// NOTE(review): these new[] buffers are presumably released elsewhere
// (destructor not visible here) - confirm.
253 void bf_reducer::compute_reduced_factors()
255 unsigned nf = factors.NumRows();
256 unsigned d = factors.NumCols();
257 nnf = 0;
258 nfactors.SetDims(nnf, d-1);
260 for (int i = 0; i < nf; ++i) {
261 int j;
262 int s = 1;
// Search the existing reduced factors for one equal to +/- factors[i][1..].
263 for (j = 0; j < nnf; ++j) {
264 int k;
// Skip leading positions where both vectors are zero.
265 for (k = 1; k < d; ++k)
266 if (factors[i][k] != 0 || nfactors[j][k-1] != 0)
267 break;
// Opposite first nonzero entries: compare against the negated row.
// NOTE(review): s is not reset between candidates j. This appears safe:
// rows of nfactors are kept with a positive first nonzero entry, so s only
// becomes -1 when factors[i][1..] itself starts with a negative entry, and
// then -1 is the right sign for any match and for a new row.
268 if (k < d && factors[i][k] == -nfactors[j][k-1])
269 s = -1;
270 for (; k < d; ++k)
271 if (factors[i][k] != s * nfactors[j][k-1])
272 break;
273 if (k == d)
274 break;
276 old2new[i] = j;
// No match: append a new reduced factor normalized to a positive first
// nonzero entry, unless columns 1..d-1 are all zero (then old2new = -1).
277 if (j == nnf) {
278 int k;
279 for (k = 1; k < d; ++k)
280 if (factors[i][k] != 0)
281 break;
282 if (k < d) {
283 if (factors[i][k] < 0)
284 s = -1;
285 nfactors.SetDims(++nnf, d-1);
286 for (int k = 1; k < d; ++k)
287 nfactors[j][k-1] = s * factors[i][k];
288 } else
289 old2new[i] = -1;
291 sign[i] = s;
// Scratch buffers for per-term powers over the reduced factors.
293 npowers = new int[nnf];
294 bpowers = new int[nnf];
// Perform one reduction step on every input term v[i].
//   - compute_reduced_factors() groups the factors by columns 1..d-1
//     into nfactors (see old2new / sign).
//   - compute_extra_num(i) classifies the powers of term i
//     (no_param, only_param, total_power, extra_num, changes).
//   - If no factor with nonzero column 0 collapses to zero
//     (no_param == 0), each numerator is added directly to a term of vn
//     over the reduced factors.
//   - Otherwise the collapsing factors are combined into a dpoly D.
//     Each numerator is either divided by D directly, or turned into a
//     dpoly_r and divided by D. The results become terms of vn, or are
//     passed to bf->cum() when bf->constant_vertex(d) is false.
// Each input term v[i] is deleted once it has been processed.
297 void bf_reducer::reduce(barvinok_options *options)
299 compute_reduced_factors();
301 Value tmp;
302 value_init(tmp);
303 for (int i = 0; i < v.size(); ++i) {
304 compute_extra_num(i);
// Case 1: no_param == 0. The locals extra_num/changes/npowers below shadow
// the members of the same name. For every factor with sign -1, the
// numerator is shifted by power * nfactor and the sign-change count is
// increased.
306 if (no_param == 0) {
307 vec_ZZ extra_num;
308 extra_num.SetLength(d-1);
309 int changes = 0;
310 int *npowers = new int[nnf];
311 for (int k = 0; k < nnf; ++k)
312 npowers[k] = 0;
313 for (int k = 0; k < nf; ++k) {
// NOTE(review): this assert also runs for factors with zero power in
// term i, which compute_extra_num ignores. A factor with old2new == -1 and
// zero power would trip it (or index npowers[-1] under NDEBUG) - confirm
// this cannot occur.
314 assert(old2new[k] != -1);
315 npowers[old2new[k]] += v[i]->powers[k];
316 if (sign[k] == -1) {
317 extra_num += v[i]->powers[k] * nfactors[old2new[k]];
318 changes += v[i]->powers[k];
322 bfc_term_base * t = bf->find_bfc_term(vn, npowers, nnf);
323 for (int k = 0; k < v[i]->terms.NumRows(); ++k) {
324 bf->set_factor(v[i], k, changes % 2);
325 bf->add_term(t, v[i]->terms[k], extra_num);
327 delete [] npowers;
328 } else {
329 // powers of "constant" part
// Case 2: no_param > 0. bpowers collects the powers of factors with a
// zero column 0. extra_num/changes here are the members set by
// compute_extra_num, and they are updated in place.
330 for (int k = 0; k < nnf; ++k)
331 bpowers[k] = 0;
332 for (int k = 0; k < nf; ++k) {
333 if (factors[k][0] != 0)
334 continue;
335 assert(old2new[k] != -1);
336 bpowers[old2new[k]] += v[i]->powers[k];
337 if (sign[k] == -1) {
338 extra_num += v[i]->powers[k] * nfactors[old2new[k]];
339 changes += v[i]->powers[k];
// D = product over collapsing factors j (old2new[j] == -1) of
// dpoly(no_param, factors[j][0], 1), repeated powers[j] times. The first
// collapsing factor with positive power initializes D; the remaining ones
// are multiplied in by the second loop.
343 int j;
344 for (j = 0; j < nf; ++j)
345 if (old2new[j] == -1 && v[i]->powers[j] > 0)
346 break;
348 zz2value(factors[j][0], tmp);
349 dpoly D(no_param, tmp, 1);
350 for (int k = 1; k < v[i]->powers[j]; ++k) {
351 dpoly fact(no_param, tmp, 1);
352 D *= fact;
354 for ( ; ++j < nf; )
355 if (old2new[j] == -1) {
356 zz2value(factors[j][0], tmp);
357 for (int k = 0; k < v[i]->powers[j]; ++k) {
358 dpoly fact(no_param, tmp, 1);
359 D *= fact;
// Case 2a: every factor either collapses or has a zero column 0, and the
// vertex is constant. Each numerator's column-0 value yields a rational
// tcount = n / D. Zero results are skipped, and the target term over
// bpowers is looked up only once a nonzero value appears.
363 if (no_param + only_param == total_power &&
364 bf->constant_vertex(d)) {
365 bfc_term_base * t = NULL;
366 for (int k = 0; k < v[i]->terms.NumRows(); ++k) {
367 zz2value(v[i]->terms[k][0], tmp);
368 dpoly n(no_param, tmp);
369 mpq_set_si(bf->tcount, 0, 1);
370 n.div(D, bf->tcount, 1);
372 if (value_zero_p(mpq_numref(bf->tcount)))
373 continue;
375 if (!t)
376 t = bf->find_bfc_term(vn, bpowers, nnf);
377 bf->set_factor(v[i], k, bf->tcount, changes % 2);
378 bf->add_term(t, v[i]->terms[k], extra_num);
380 } else {
// Case 2b: for each numerator j build a dpoly_r r. It is built either
// directly from n, or by folding in one dpoly pd(no_param-1, factors[k][0], 1)
// per unit of power of every factor k with a nonzero column 0 that maps to
// a reduced factor. q is the first factor index sharing k's reduced factor
// and sign.
381 for (int j = 0; j < v[i]->terms.NumRows(); ++j) {
382 zz2value(v[i]->terms[j][0], tmp);
383 dpoly n(no_param, tmp);
385 dpoly_r * r = 0;
386 if (no_param + only_param == total_power)
387 r = new dpoly_r(n, nf);
388 else
389 for (int k = 0; k < nf; ++k) {
390 if (v[i]->powers[k] == 0)
391 continue;
392 if (factors[k][0] == 0 || old2new[k] == -1)
393 continue;
395 zz2value(factors[k][0], tmp);
396 dpoly pd(no_param-1, tmp, 1);
398 for (int l = 0; l < v[i]->powers[k]; ++l) {
399 int q;
400 for (q = 0; q < k; ++q)
401 if (old2new[q] == old2new[k] &&
402 sign[q] == sign[k])
403 break;
405 if (r == 0)
406 r = new dpoly_r(n, pd, q, nf);
407 else {
408 dpoly_r *nr = new dpoly_r(r, pd, q, nf);
409 delete r;
410 r = nr;
// Divide r by D. With a constant vertex, every nonzero entry of the last
// term list of the quotient becomes a contribution to a term of vn, with
// powers and numerator shift computed by update_powers(). Otherwise the
// whole quotient is handed to bf->cum().
415 dpoly_r *rc = r->div(D);
416 delete r;
417 QQ factor;
418 factor.d = rc->denom;
420 if (bf->constant_vertex(d)) {
421 dpoly_r_term_list& final = rc->c[rc->len-1];
423 dpoly_r_term_list::iterator k;
424 for (k = final.begin(); k != final.end(); ++k) {
425 if ((*k)->coeff == 0)
426 continue;
428 update_powers((*k)->powers);
430 bfc_term_base * t = bf->find_bfc_term(vn, npowers, nnf);
431 factor.n = (*k)->coeff;
432 bf->set_factor(v[i], j, factor, l_changes % 2);
433 bf->add_term(t, v[i]->terms[j], l_extra_num);
435 } else
436 bf->cum(this, v[i], j, rc, options);
438 delete rc;
// Input term i is no longer needed.
442 delete v[i];
444 value_clear(tmp);