update isl for fix in isl_schedule interface
[barvinok.git] / bfcounter.cc
blob8d6a9e4986d79adbae645d7b3f3af5c86b712a22
1 #include <vector>
2 #include "bfcounter.h"
3 #include "lattice_point.h"
5 using std::vector;
6 using std::cerr;
7 using std::endl;
9 static int lex_cmp(vec_ZZ& a, vec_ZZ& b)
11 assert(a.length() == b.length());
13 for (int j = 0; j < a.length(); ++j)
14 if (a[j] != b[j])
15 return a[j] < b[j] ? -1 : 1;
16 return 0;
19 void bf_base::add_term(bfc_term_base *t, vec_ZZ& num_orig, vec_ZZ& extra_num)
21 vec_ZZ num;
22 int d = num_orig.length();
23 num.SetLength(d-1);
24 for (int l = 0; l < d-1; ++l)
25 num[l] = num_orig[l+1] + extra_num[l];
27 add_term(t, num);
30 void bf_base::add_term(bfc_term_base *t, vec_ZZ& num)
32 int len = t->terms.NumRows();
33 int i, r;
34 for (i = 0; i < len; ++i) {
35 r = lex_cmp(t->terms[i], num);
36 if (r >= 0)
37 break;
39 if (i == len || r > 0) {
40 t->terms.SetDims(len+1, num.length());
41 insert_term(t, i);
42 t->terms[i] = num;
43 } else {
44 // i < len && r == 0
45 update_term(t, i);
49 bfc_term_base* bf_base::find_bfc_term(bfc_vec& v, int *powers, int len)
51 bfc_vec::iterator i;
52 for (i = v.begin(); i != v.end(); ++i) {
53 int j;
54 for (j = 0; j < len; ++j)
55 if ((*i)->powers[j] != powers[j])
56 break;
57 if (j == len)
58 return (*i);
59 if ((*i)->powers[j] > powers[j])
60 break;
63 bfc_term_base* t = new_bf_term(len);
64 v.insert(i, t);
65 memcpy(t->powers, powers, len * sizeof(int));
67 return t;
70 void bf_base::reduce(mat_ZZ& factors, bfc_vec& v, barvinok_options *options)
72 assert(v.size() > 0);
73 unsigned nf = factors.NumRows();
74 unsigned d = factors.NumCols();
76 if (d == lower)
77 return base(factors, v);
79 bf_reducer bfr(factors, v, this);
81 bfr.reduce(options);
83 if (bfr.vn.size() > 0)
84 reduce(bfr.nfactors, bfr.vn, options);
87 int bf_base::setup_factors(const mat_ZZ& rays, mat_ZZ& factors,
88 bfc_term_base* t, int s)
90 factors.SetDims(dim, dim);
92 int r;
94 for (r = 0; r < dim; ++r)
95 t->powers[r] = 1;
97 for (r = 0; r < dim; ++r) {
98 factors[r] = rays[r];
99 int k;
100 for (k = 0; k < dim; ++k)
101 if (factors[r][k] != 0)
102 break;
103 if (factors[r][k] < 0) {
104 factors[r] = -factors[r];
105 for (int i = 0; i < t->terms.NumRows(); ++i)
106 t->terms[i] += factors[r];
107 s = -s;
111 return s;
114 void bf_base::handle(const mat_ZZ& rays, Value *vertex, const QQ& c,
115 unsigned long det, barvinok_options *options)
117 bfc_term* t = new bfc_term(dim);
118 vector< bfc_term_base * > v;
119 v.push_back(t);
121 Matrix *points = Matrix_Alloc(det, dim);
122 Matrix* Rays = zz2matrix(rays);
123 lattice_points_fixed(vertex, vertex, Rays, Rays, points, det);
124 Matrix_Free(Rays);
125 matrix2zz(points, t->terms, points->NbRows, points->NbColumns);
126 Matrix_Free(points);
128 // the elements of factors are always lexpositive
129 mat_ZZ factors;
130 int s = setup_factors(rays, factors, t, 1);
132 t->c.SetLength(t->terms.NumRows());
134 for (int i = 0; i < t->c.length(); ++i) {
135 t->c[i].n = s * c.n;
136 t->c[i].d = c.d;
139 reduce(factors, v, options);
142 bfc_term_base* bfcounter_base::new_bf_term(int len)
144 bfc_term* t = new bfc_term(len);
145 t->c.SetLength(0);
146 return t;
149 void bfcounter_base::set_factor(bfc_term_base *t, int k, int change)
151 bfc_term* bfct = static_cast<bfc_term *>(t);
152 c = bfct->c[k];
153 if (change)
154 c.n = -c.n;
157 void bfcounter_base::set_factor(bfc_term_base *t, int k, mpq_t &f, int change)
159 bfc_term* bfct = static_cast<bfc_term *>(t);
160 value2zz(mpq_numref(f), c.n);
161 value2zz(mpq_denref(f), c.d);
162 c *= bfct->c[k];
163 if (change)
164 c.n = -c.n;
167 void bfcounter_base::set_factor(bfc_term_base *t, int k, const QQ& c_factor,
168 int change)
170 bfc_term* bfct = static_cast<bfc_term *>(t);
171 c = bfct->c[k];
172 c *= c_factor;
173 if (change)
174 c.n = -c.n;
177 void bfcounter_base::insert_term(bfc_term_base *t, int i)
179 bfc_term* bfct = static_cast<bfc_term *>(t);
180 int len = t->terms.NumRows()-1; // already increased by one
182 bfct->c.SetLength(len+1);
183 for (int j = len; j > i; --j) {
184 bfct->c[j] = bfct->c[j-1];
185 t->terms[j] = t->terms[j-1];
187 bfct->c[i] = c;
190 void bfcounter_base::update_term(bfc_term_base *t, int i)
192 bfc_term* bfct = static_cast<bfc_term *>(t);
194 bfct->c[i] += c;
197 void bf_reducer::compute_extra_num(int i)
199 clear(extra_num);
200 changes = 0;
201 no_param = 0; // r from text
202 only_param = 0; // k-r-s from text
203 total_power = 0; // k from text
205 for (int j = 0; j < nf; ++j) {
206 if (v[i]->powers[j] == 0)
207 continue;
209 total_power += v[i]->powers[j];
210 if (factors[j][0] == 0) {
211 only_param += v[i]->powers[j];
212 continue;
215 if (old2new[j] == -1)
216 no_param += v[i]->powers[j];
217 else
218 extra_num += -sign[j] * v[i]->powers[j] * nfactors[old2new[j]];
219 changes += v[i]->powers[j];
223 void bf_reducer::update_powers(const std::vector<int>& powers)
225 for (int l = 0; l < nnf; ++l)
226 npowers[l] = bpowers[l];
228 l_extra_num = extra_num;
229 l_changes = changes;
231 for (int l = 0; l < powers.size(); ++l) {
232 int n = powers[l];
233 if (n == 0)
234 continue;
235 assert(old2new[l] != -1);
237 npowers[old2new[l]] += n;
238 // interpretation of sign has been inverted
239 // since we inverted the power for specialization
240 if (sign[l] == 1) {
241 l_extra_num += n * nfactors[old2new[l]];
242 l_changes += n;
248 void bf_reducer::compute_reduced_factors()
250 unsigned nf = factors.NumRows();
251 unsigned d = factors.NumCols();
252 nnf = 0;
253 nfactors.SetDims(nnf, d-1);
255 for (int i = 0; i < nf; ++i) {
256 int j;
257 int s = 1;
258 for (j = 0; j < nnf; ++j) {
259 int k;
260 for (k = 1; k < d; ++k)
261 if (factors[i][k] != 0 || nfactors[j][k-1] != 0)
262 break;
263 if (k < d && factors[i][k] == -nfactors[j][k-1])
264 s = -1;
265 for (; k < d; ++k)
266 if (factors[i][k] != s * nfactors[j][k-1])
267 break;
268 if (k == d)
269 break;
271 old2new[i] = j;
272 if (j == nnf) {
273 int k;
274 for (k = 1; k < d; ++k)
275 if (factors[i][k] != 0)
276 break;
277 if (k < d) {
278 if (factors[i][k] < 0)
279 s = -1;
280 nfactors.SetDims(++nnf, d-1);
281 for (int k = 1; k < d; ++k)
282 nfactors[j][k-1] = s * factors[i][k];
283 } else
284 old2new[i] = -1;
286 sign[i] = s;
288 npowers = new int[nnf];
289 bpowers = new int[nnf];
292 void bf_reducer::reduce(barvinok_options *options)
294 compute_reduced_factors();
296 Value tmp;
297 value_init(tmp);
298 for (int i = 0; i < v.size(); ++i) {
299 compute_extra_num(i);
301 if (no_param == 0) {
302 vec_ZZ extra_num;
303 extra_num.SetLength(d-1);
304 int changes = 0;
305 int *npowers = new int[nnf];
306 for (int k = 0; k < nnf; ++k)
307 npowers[k] = 0;
308 for (int k = 0; k < nf; ++k) {
309 assert(old2new[k] != -1);
310 npowers[old2new[k]] += v[i]->powers[k];
311 if (sign[k] == -1) {
312 extra_num += v[i]->powers[k] * nfactors[old2new[k]];
313 changes += v[i]->powers[k];
317 bfc_term_base * t = bf->find_bfc_term(vn, npowers, nnf);
318 for (int k = 0; k < v[i]->terms.NumRows(); ++k) {
319 bf->set_factor(v[i], k, changes % 2);
320 bf->add_term(t, v[i]->terms[k], extra_num);
322 delete [] npowers;
323 } else {
324 // powers of "constant" part
325 for (int k = 0; k < nnf; ++k)
326 bpowers[k] = 0;
327 for (int k = 0; k < nf; ++k) {
328 if (factors[k][0] != 0)
329 continue;
330 assert(old2new[k] != -1);
331 bpowers[old2new[k]] += v[i]->powers[k];
332 if (sign[k] == -1) {
333 extra_num += v[i]->powers[k] * nfactors[old2new[k]];
334 changes += v[i]->powers[k];
338 int j;
339 for (j = 0; j < nf; ++j)
340 if (old2new[j] == -1 && v[i]->powers[j] > 0)
341 break;
343 zz2value(factors[j][0], tmp);
344 dpoly D(no_param, tmp, 1);
345 for (int k = 1; k < v[i]->powers[j]; ++k) {
346 dpoly fact(no_param, tmp, 1);
347 D *= fact;
349 for ( ; ++j < nf; )
350 if (old2new[j] == -1) {
351 zz2value(factors[j][0], tmp);
352 for (int k = 0; k < v[i]->powers[j]; ++k) {
353 dpoly fact(no_param, tmp, 1);
354 D *= fact;
358 if (no_param + only_param == total_power &&
359 bf->constant_vertex(d)) {
360 bfc_term_base * t = NULL;
361 vec_ZZ num;
362 num.SetLength(d-1);
363 ZZ cn;
364 ZZ cd;
365 for (int k = 0; k < v[i]->terms.NumRows(); ++k) {
366 zz2value(v[i]->terms[k][0], tmp);
367 dpoly n(no_param, tmp);
368 mpq_set_si(bf->tcount, 0, 1);
369 n.div(D, bf->tcount, 1);
371 if (value_zero_p(mpq_numref(bf->tcount)))
372 continue;
374 if (!t)
375 t = bf->find_bfc_term(vn, bpowers, nnf);
376 bf->set_factor(v[i], k, bf->tcount, changes % 2);
377 bf->add_term(t, v[i]->terms[k], extra_num);
379 } else {
380 for (int j = 0; j < v[i]->terms.NumRows(); ++j) {
381 zz2value(v[i]->terms[j][0], tmp);
382 dpoly n(no_param, tmp);
384 dpoly_r * r = 0;
385 if (no_param + only_param == total_power)
386 r = new dpoly_r(n, nf);
387 else
388 for (int k = 0; k < nf; ++k) {
389 if (v[i]->powers[k] == 0)
390 continue;
391 if (factors[k][0] == 0 || old2new[k] == -1)
392 continue;
394 zz2value(factors[k][0], tmp);
395 dpoly pd(no_param-1, tmp, 1);
397 for (int l = 0; l < v[i]->powers[k]; ++l) {
398 int q;
399 for (q = 0; q < k; ++q)
400 if (old2new[q] == old2new[k] &&
401 sign[q] == sign[k])
402 break;
404 if (r == 0)
405 r = new dpoly_r(n, pd, q, nf);
406 else {
407 dpoly_r *nr = new dpoly_r(r, pd, q, nf);
408 delete r;
409 r = nr;
414 dpoly_r *rc = r->div(D);
415 delete r;
416 QQ factor;
417 factor.d = rc->denom;
419 if (bf->constant_vertex(d)) {
420 dpoly_r_term_list& final = rc->c[rc->len-1];
422 dpoly_r_term_list::iterator k;
423 for (k = final.begin(); k != final.end(); ++k) {
424 if ((*k)->coeff == 0)
425 continue;
427 update_powers((*k)->powers);
429 bfc_term_base * t = bf->find_bfc_term(vn, npowers, nnf);
430 factor.n = (*k)->coeff;
431 bf->set_factor(v[i], j, factor, l_changes % 2);
432 bf->add_term(t, v[i]->terms[j], l_extra_num);
434 } else
435 bf->cum(this, v[i], j, rc, options);
437 delete rc;
441 delete v[i];
443 value_clear(tmp);