update isl for change in lexicographic optimization
[ppn.git] / eqv.cc
blob106e9e10a1ef6fb6d945b15c7d23d357232885de
1 #include <limits.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <algorithm>
6 #include <map>
7 #include <set>
8 #include <vector>
9 #include <string>
10 #include <iostream>
11 #include <sstream>
13 #include <isa/yaml.h>
14 #include <isa/pdg.h>
16 #include <isl/ctx.h>
17 #include <isl/space.h>
18 #include <isl/set.h>
19 #include <isl/map.h>
20 #include <isl/constraint.h>
21 #include <isl/union_set.h>
22 #include <isl/printer.h>
23 #include "eqv_options.h"
24 #include "dependence_graph.h"
/* Lists of operator names with special algebraic properties:
 * those that are to be treated as associative and
 * those that are to be treated as commutative.
 */
struct ops {
	std::vector<const char *> associative, commutative;
};
30 int ops_init(void *user)
32 struct ops **ops = (struct ops **)user;
33 *ops = new struct ops;
34 return 0;
/* Destroy the "struct ops" stored in the pointer that "user" points to.
 * "user" is interpreted as a "struct ops **".
 */
void ops_clear(void *user)
{
	struct ops **slot = static_cast<struct ops **>(user);

	delete *slot;
}
42 /* For each basic map in relation, add an edge to comp with
43 * given source and pos.
45 static void add_split_edge(computation *comp, computation *source,
46 int pos, isl_map *relation, std::vector<computation **> *missing,
47 enum edge::type type = edge::normal)
49 edge *e = new edge;
50 e->source = source;
51 if (!source) {
52 assert(missing);
53 missing->push_back(&e->source);
55 e->pos = pos;
56 e->relation = relation;
57 e->type = type;
58 comp->edges.push_back(e);
61 static unsigned update_out_dim(pdg::PDG *pdg, unsigned out_dim)
63 for (int i = 0; i < pdg->arrays.size(); ++i) {
64 pdg::array *array = pdg->arrays[i];
65 if (array->type != pdg::array::output)
66 continue;
67 if (array->dims.size() + 1 > out_dim)
68 out_dim = array->dims.size() + 1;
71 return out_dim;
/* Does the name "op" appear (compared as a string) in the list
 * "associative"?
 */
static bool is_associative(const std::vector<const char *> &associative,
	const char *op)
{
	for (size_t k = 0; k < associative.size(); ++k)
		if (strcmp(associative[k], op) == 0)
			return true;

	return false;
}
85 /* Return a computation that has an associative operation
86 * that takes a different computation with the same operation
87 * as one of its arguments. Also return the edge from the
88 * first to the second computation in *e.
90 * If no such computation exists, then return NULL.
92 computation *dependence_graph::associative_node(edge **e,
93 const std::vector<const char *> &associative)
95 for (int i = 0; i < vertices.size(); ++i) {
96 computation *comp = vertices[i];
97 if (!is_associative(associative, comp->operation))
98 continue;
99 for (int j = 0; j < comp->edges.size(); ++j) {
100 computation *other = comp->edges[j]->source;
101 if (comp->has_same_source(other))
102 continue;
103 if (strcmp(comp->operation, other->operation))
104 continue;
105 *e = comp->edges[j];
106 return comp;
109 return NULL;
112 /* Splice the source of edge edge into the position of edge e,
113 * removing all other edges with the same positions and bumping
114 * up edges with a greater position.
116 void dependence_graph::splice(computation *comp, edge *e)
118 computation *source = e->source;
119 int pos = e->pos;
120 isl_map *map = isl_map_copy(e->relation);
121 std::vector<struct edge *> new_edges;
123 for (int i = 0; i < comp->edges.size(); ++i) {
124 if (comp->edges[i]->pos == pos) {
125 delete comp->edges[i];
126 continue;
128 if (comp->edges[i]->pos > pos)
129 comp->edges[i]->pos += source->arity - 1;
130 new_edges.push_back(comp->edges[i]);
132 for (int i = 0; i < source->edges.size(); ++i) {
133 edge *old_e = source->edges[i];
134 edge *e = new edge;
135 e->source = old_e->source;
136 e->pos = old_e->pos + pos;
137 e->relation = isl_map_apply_range(
138 isl_map_copy(map),
139 isl_map_copy(old_e->relation));
140 new_edges.push_back(e);
142 isl_map_free(map);
143 comp->edges = new_edges;
144 comp->arity += source->arity - 1;
147 /* Split computation comp into a part that always takes the edge e
148 * and one that never takes that edge.
149 * The second is returned and the initial computation is modified
150 * to match the first.
152 * We need to be careful not to destroy e, as it is still used
153 * by the calling method.
155 computation *dependence_graph::split_comp(computation *comp, edge *e)
157 computation *dup = new computation;
158 dup->original = comp->original ? comp->original : comp;
159 dup->operation = strdup(comp->operation);
160 dup->arity = comp->arity;
161 dup->location = comp->location;
163 isl_map *map = isl_map_copy(e->relation);
164 map = isl_map_reverse(map);
165 isl_set *dom = isl_set_apply(isl_set_copy(e->source->domain), map);
166 dup->domain = isl_set_subtract(isl_set_copy(comp->domain),
167 isl_set_copy(dom));
168 comp->domain = isl_set_intersect(comp->domain, dom);
170 std::vector<struct edge *> old_edges = comp->edges;
171 comp->edges.clear();
173 for (int i = 0; i < old_edges.size(); ++i) {
174 if (old_edges[i] == e) {
175 comp->edges.push_back(e);
176 continue;
179 edge *e = old_edges[i];
180 isl_map *map, *map_dup;
181 map = isl_map_copy(e->relation);
182 map_dup = isl_map_copy(map);
184 map = isl_map_intersect_domain(map, isl_set_copy(comp->domain));
185 map_dup = isl_map_intersect_domain(map_dup,
186 isl_set_copy(dup->domain));
187 add_split_edge(comp, e->source, e->pos, map, NULL);
188 add_split_edge(dup, e->source, e->pos, map_dup, NULL);
189 delete e;
192 vertices.push_back(dup);
193 return dup;
196 /* If any edge from comp points to comp_orig, then split it
197 * into two edges, one still pointing to comp_orig and the
198 * other pointing to comp_dup.
199 * comp_orig and comp_dup are assumed to have disjoint domains
200 * and the edge relations are adjusted according to these domains.
202 void dependence_graph::split_edges(computation *comp,
203 computation *comp_orig, computation *comp_dup)
205 std::vector<struct edge *> old_edges = comp->edges;
206 comp->edges.clear();
208 for (int i = 0; i < old_edges.size(); ++i) {
209 edge *e = old_edges[i];
211 if (e->source != comp_orig) {
212 comp->edges.push_back(e);
213 continue;
216 isl_map *map_orig, *map_dup;
217 map_orig = isl_map_copy(e->relation);
218 map_dup = isl_map_copy(map_orig);
219 map_orig = isl_map_intersect_range(map_orig,
220 isl_set_copy(comp_orig->domain));
221 map_dup = isl_map_intersect_range(map_dup,
222 isl_set_copy(comp_dup->domain));
223 add_split_edge(comp, comp_orig, e->pos, map_orig, NULL);
224 add_split_edge(comp, comp_dup, e->pos, map_dup, NULL);
225 delete e;
229 void dependence_graph::split_edges(computation *comp_orig,
230 computation *comp_dup)
232 split_edges(out, comp_orig, comp_dup);
233 for (int i = 0; i < vertices.size(); ++i)
234 split_edges(vertices[i], comp_orig, comp_dup);
237 /* Replace all nested calls of an associative operator,
238 * by a call with the nested call spliced into the first call.
240 * For each nested call we find, we first check if there are
241 * any other edges with the same position. If not, then
242 * the nested call is performed for each iteration of the computation
243 * and we can simplify splice the nested call.
245 * Otherwise, we first create a duplicate computation for the iterations
246 * that do not take the found edge and adjust the edges of both computations
247 * to their domains. This meand that the edge corresponding to the
248 * nested call will no longer appear in the duplicated computation
249 * and the other edges with the same position will no longer appear
250 * in the original computation.
251 * Then we splice the nested call in the original computation.
252 * Finally, we split all edges that pointed to the original computation
253 * into two edges, one going to the original computation and one
254 * going to the duplicated computation.
256 void dependence_graph::flatten_associative_operators(
257 const std::vector<const char *> &associative)
259 edge *e;
260 computation *comp;
262 while ((comp = associative_node(&e, associative)) != NULL) {
263 computation *comp_dup = NULL;
264 int j;
265 for (j = 0; j < comp->edges.size(); ++j) {
266 if (comp->edges[j] == e)
267 continue;
268 if (comp->edges[j]->pos == e->pos)
269 break;
271 if (j != comp->edges.size())
272 comp_dup = split_comp(comp, e);
273 splice(comp, e);
274 if (comp_dup)
275 split_edges(comp, comp_dup);
279 struct eq_node {
280 private:
281 isl_map *got;
282 isl_map *lost;
283 public:
284 computation *comp[2];
285 isl_map *want;
286 isl_map *need;
287 std::set<eq_node *> assumed;
288 unsigned closed : 1;
289 unsigned narrowing : 1;
290 unsigned widening : 1;
291 unsigned invalidated : 1;
292 unsigned reset : 1;
293 std::string trace;
294 isl_map *lost_sample;
296 eq_node(computation *c1, computation *c2,
297 isl_map *w) : want(w), closed(0),
298 widening(0), narrowing(0), invalidated(0), reset(0),
299 need(NULL), got(NULL), lost(NULL), lost_sample(NULL) {
300 comp[0] = c1;
301 comp[1] = c2;
303 bool is_still_valid();
304 void collect_open_assumed(std::set<eq_node *> &c);
305 ~eq_node() {
306 isl_map_free(want);
307 isl_map_free(need);
308 isl_map_free(got);
309 isl_map_free(lost);
310 isl_map_free(lost_sample);
312 void compute_got() {
313 assert(lost);
314 got = isl_map_copy(want);
315 got = isl_map_subtract(got, isl_map_copy(lost));
317 isl_map *get_got() {
318 if (!got)
319 compute_got();
320 return isl_map_copy(got);
322 isl_map *peek_got() {
323 if (!got)
324 compute_got();
325 return got;
327 void compute_lost() {
328 assert(got);
329 lost = isl_map_copy(want);
330 lost = isl_map_subtract(lost, isl_map_copy(got));
332 isl_map *get_lost() {
333 if (!lost)
334 compute_lost();
335 return isl_map_copy(lost);
337 isl_map *peek_lost() {
338 if (!lost)
339 compute_lost();
340 return lost;
342 void set_got(isl_map *got) {
343 isl_map_free(this->lost);
344 isl_map_free(this->got);
345 this->lost = NULL;
346 this->got = got;
348 void set_lost(isl_map *lost) {
349 isl_map_free(this->lost);
350 isl_map_free(this->got);
351 this->got = NULL;
352 this->lost = lost;
354 std::string to_string();
357 std::string eq_node::to_string()
359 std::ostringstream strm;
360 strm << comp[0]->location << "," << comp[0]->operation
361 << "/" << comp[0]->arity;
362 strm << " <-> ";
363 strm << comp[1]->location << "," << comp[1]->operation
364 << "/" << comp[1]->arity;
365 strm << std::endl;
366 return strm.str();
369 bool eq_node::is_still_valid()
371 if (invalidated)
372 return 0;
374 std::set<eq_node *>::iterator i;
375 for (i = assumed.begin(); i != assumed.end(); ++i) {
376 assert(*i != this);
377 if ((*i)->reset ||
378 ((*i)->closed && !(*i)->is_still_valid())) {
379 invalidated = 1;
380 return 0;
383 return 1;
386 void eq_node::collect_open_assumed(std::set<eq_node *> &c)
388 std::set<eq_node *>::iterator i;
389 for (i = assumed.begin(); i != assumed.end(); ++i) {
390 if ((*i)->closed)
391 (*i)->collect_open_assumed(c);
392 else
393 c.insert(*i);
397 /* A comp_pair contains all the edges that have the same pair
398 * of computations.
400 struct comp_pair {
401 computation *comp[2];
403 std::vector<eq_node *> nodes;
405 eq_node *tabled(eq_node *node);
406 eq_node *last_ancestor(eq_node *node);
407 ~comp_pair();
410 comp_pair::~comp_pair()
412 std::vector<eq_node *>::iterator i;
413 for (i = nodes.begin(); i != nodes.end(); ++i)
414 delete *i;
417 eq_node *comp_pair::tabled(eq_node *node)
419 std::vector<eq_node *>::iterator i;
421 for (i = nodes.begin(); i != nodes.end(); ++i) {
422 if (*i == node)
423 continue;
424 if (!(*i)->closed)
425 continue;
426 if (!(*i)->is_still_valid())
427 continue;
428 int is_subset;
429 is_subset = isl_map_is_subset(node->want, (*i)->want);
430 assert(is_subset >= 0);
431 if (!is_subset)
432 continue;
433 return *i;
435 return NULL;
438 eq_node *comp_pair::last_ancestor(eq_node *node)
440 std::vector<eq_node *>::reverse_iterator i;
441 for (i = nodes.rbegin(); i != nodes.rend(); ++i) {
442 if (*i == node)
443 continue;
444 if ((*i)->closed)
445 continue;
446 return *i;
448 return NULL;
451 typedef std::pair<computation *, computation *> computation_pair;
452 typedef std::map<computation_pair, comp_pair *> c2p_t;
454 struct equivalence_checker {
455 c2p_t c2p;
456 struct options *options;
457 int widenings;
458 int narrowings;
460 equivalence_checker(struct options *o) :
461 options(o), widenings(0), narrowings(0) {}
462 comp_pair *get_comp_pair(eq_node *node);
463 void handle(eq_node *node);
464 void dismiss(eq_node *node);
465 void handle_propagation(eq_node *node);
466 void handle_copy(eq_node *node, int i);
467 void handle_widening(eq_node *node, eq_node *a);
468 void handle_with_ancestor(eq_node *node, eq_node *a);
469 isl_map *lost_at_pos(eq_node *node, int pos1, int pos2);
470 void handle_propagation_same(eq_node *node);
471 void handle_propagation_comm(eq_node *node);
472 void handle_narrowing(eq_node *node);
474 isl_map *lost_from_propagation(eq_node *parent,
475 eq_node *node, edge *e1, edge *e2);
477 void init_trace(eq_node *node);
478 void extend_trace_propagation(eq_node *node, eq_node *child,
479 edge *e1, edge *e2);
480 void extend_trace_widening(eq_node *node, eq_node *child);
482 bool is_commutative(const char *op);
484 ~equivalence_checker();
487 equivalence_checker::~equivalence_checker()
489 c2p_t::iterator i;
491 for (i = c2p.begin(); i != c2p.end(); ++i)
492 delete (*i).second;
495 bool equivalence_checker::is_commutative(const char *op)
497 std::vector<const char *>::iterator iter;
499 for (iter = options->ops->commutative.begin();
500 iter != options->ops->commutative.end(); ++iter)
501 if (!strcmp(*iter, op))
502 return 1;
503 return 0;
506 comp_pair *equivalence_checker::get_comp_pair(eq_node *node)
508 c2p_t::iterator i;
509 comp_pair *cp;
511 i = c2p.find(computation_pair(node->comp[0], node->comp[1]));
512 if (i == c2p.end()) {
513 cp = new comp_pair;
514 c2p[computation_pair(node->comp[0], node->comp[1])] = cp;
515 } else
516 cp = (*i).second;
517 return cp;
520 void equivalence_checker::dismiss(eq_node *node)
522 /* unclosed nodes weren't added to c2p->nodes */
523 if (node && !node->closed)
524 delete node;
527 void equivalence_checker::init_trace(eq_node *node)
529 isl_ctx *ctx;
530 isl_printer *prn;
532 if (!options->trace_error)
533 return;
534 if (isl_map_is_empty(node->peek_lost()))
535 return;
536 node->trace = node->to_string();
537 std::cerr << node->trace;
538 isl_map_free(node->lost_sample);
539 node->lost_sample = isl_map_from_basic_map(
540 isl_map_sample(node->get_lost()));
541 ctx = isl_map_get_ctx(node->lost_sample);
542 prn = isl_printer_to_file(ctx, stderr);
543 prn = isl_printer_print_map(prn, node->lost_sample);
544 prn = isl_printer_end_line(prn);
545 isl_printer_free(prn);
548 void equivalence_checker::extend_trace_propagation(eq_node *node,
549 eq_node *child, edge *e1, edge *e2)
551 isl_ctx *ctx;
552 isl_printer *prn;
554 if (!options->trace_error)
555 return;
556 if (!child->lost_sample)
557 return;
558 if (node->lost_sample)
559 return;
560 node->trace = child->trace + node->to_string();
561 std::cerr << node->trace;
562 node->lost_sample = isl_map_copy(child->lost_sample);
563 if (e1)
564 node->lost_sample = isl_map_apply_domain(node->lost_sample,
565 isl_map_reverse(
566 isl_map_copy(e1->relation)));
567 if (e2)
568 node->lost_sample = isl_map_apply_range(node->lost_sample,
569 isl_map_reverse(
570 isl_map_copy(e2->relation)));
571 node->lost_sample = isl_map_intersect(node->lost_sample,
572 isl_map_copy(node->want));
573 node->lost_sample = isl_map_from_basic_map(
574 isl_map_sample(node->lost_sample));
575 ctx = isl_map_get_ctx(node->lost_sample);
576 prn = isl_printer_to_file(ctx, stderr);
577 prn = isl_printer_print_map(prn, node->lost_sample);
578 prn = isl_printer_end_line(prn);
579 isl_printer_free(prn);
582 void equivalence_checker::extend_trace_widening(eq_node *node, eq_node *child)
584 isl_ctx *ctx;
585 isl_printer *prn;
587 if (!options->trace_error)
588 return;
589 if (!child->lost_sample)
590 return;
591 if (isl_map_is_empty(node->peek_lost()))
592 return;
593 node->trace = child->trace + node->to_string();
594 std::cerr << node->trace;
595 node->lost_sample = isl_map_from_basic_map(
596 isl_map_sample(node->get_lost()));
597 ctx = isl_map_get_ctx(node->lost_sample);
598 prn = isl_printer_to_file(ctx, stderr);
599 prn = isl_printer_print_map(prn, node->lost_sample);
600 prn = isl_printer_end_line(prn);
601 isl_printer_free(prn);
604 /* Check for which subset (got) of the want relation equivelance
605 * holds for the pair of computations in the equivalence node.
607 * We first handle the easy cases: empty want, input computations
608 * and tabled nodes.
610 * If the current node is not a narrowing or a widening node
611 * and we can find an ancestor with the same pair of computations,
612 * then we will try to apply induction or widening in handle_with_ancestor.
613 * However, if the requested relation (want) of the ancestor is a strict
614 * subset of that of the current node, then we have already applied
615 * widening on an intermediate node (with a different pair of computations)
616 * so we shouldn't apply widening again (and we can't apply induction
617 * because the relation of the ancestor is a strict subset).
619 * In all other cases we try to apply propagation in handle_propagation.
621 void equivalence_checker::handle(eq_node *node)
623 computation *comp1 = node->comp[0];
624 computation *comp2 = node->comp[1];
626 if (!comp1->is_copy() && !comp2->is_copy() &&
627 (strcmp(comp1->operation, comp2->operation) ||
628 comp1->arity != comp2->arity)) {
629 node->set_lost(isl_map_copy(node->want));
630 init_trace(node);
631 return;
634 eq_node *s = NULL;
635 eq_node *a = NULL;
637 isl_map *want = node->want;
638 node->need = isl_map_empty(isl_map_get_space(want));
639 int empty = isl_map_is_empty(want);
640 assert(empty >= 0);
641 if (empty) {
642 node->set_lost(isl_map_empty(isl_map_get_space(want)));
643 return;
645 if (node->comp[0]->is_input() && node->comp[1]->is_input()) {
646 if (strcmp(node->comp[0]->operation, node->comp[1]->operation)) {
647 node->set_lost(isl_map_copy(want));
648 return;
650 node->set_lost(isl_map_subtract(
651 isl_map_copy(node->want),
652 isl_map_identity(isl_map_get_space(want))));
653 return;
656 comp_pair *cp = get_comp_pair(node);
657 if ((s = cp->tabled(node)) != NULL) {
658 node->set_lost(isl_map_intersect(s->get_lost(),
659 isl_map_copy(node->want)));
660 s->collect_open_assumed(node->assumed);
661 return;
664 cp->nodes.push_back(node);
666 if (!node->narrowing && !node->widening &&
667 (a = cp->last_ancestor(node)) != NULL) {
668 int is_subset;
669 is_subset = isl_map_is_strict_subset(a->want, node->want);
670 assert(is_subset >= 0);
671 if (is_subset)
672 handle_propagation(node);
673 else
674 handle_with_ancestor(node, a);
675 } else
676 handle_propagation(node);
677 node->closed = 1;
680 /* Check if we can apply propagation to prove equivalence of the given node.
681 * First check if we need to apply copy propagation and if not
682 * check if the operations are the same and apply "regular" propagation.
684 * After we get back, we need to check that any induction hypotheses
685 * we have used in the process of proving the node hold.
686 * If not, we replace the obtained relation (got) by that
687 * of a narrowing node in handle_narrowin.
689 void equivalence_checker::handle_propagation(eq_node *node)
691 if (node->comp[0]->is_copy())
692 handle_copy(node, 0);
693 else if (node->comp[1]->is_copy())
694 handle_copy(node, 1);
695 else
696 handle_propagation_same(node);
697 node->assumed.erase(node);
698 isl_map *lost = node->get_lost();
699 lost = isl_map_intersect(lost, isl_map_copy(node->need));
700 int is_empty = isl_map_is_empty(lost);
701 isl_map_free(lost);
702 assert(is_empty >= 0);
703 if (!is_empty)
704 handle_narrowing(node);
707 /* When applying the mappings on expansion edges to both sides of
708 * an equivalence relation, each element in the original equivalence
709 * relation is mapped to many elements on both sides. For example,
710 * if the original equivalence relation has i R i' for i = i', then the new
711 * equivalence relation may have (i,j) R (i',j') for i = i' and
712 * 0 <= j,j' <= 10, expressing that all elements of the row read by
713 * iteration i should be equal to all elements of the row read by i'.
714 * Instead, we want to express that each individual element read by i'
715 * should be equal to the corresponding element read by i'.
716 * In the example, we need to introduce the equality j = j'.
718 * We first check that the number of dimensions added by the expansions
719 * is the same in both programs. Then we construct an identity relation
720 * between this number of dimensions, lift it to the space of the
721 * new equivalence relation and take the intersection.
723 * We have to be careful, though, that we don't loose any array elements
724 * by taking the intersection. In particular, the expansion maps
725 * a given read operation to iterations of an extended domain that
726 * reads all the individual array elements. The read operation is
727 * only equivalent to some other read operation if all the reads
728 * of the individual array elements are equivalent.
729 * We therefore need to make sure that adding the equalities does
730 * not remove any of the reads in the extended domain.
731 * We do this by projecting out the extra dimensions on one side
732 * from both "want" and "want \cap eq". The resulting maps should
733 * be the same. We do this check for both sides separately.
735 * If either of the above tests fails, then we simply return
736 * the original over-ambitious want.
738 static isl_map *expansion_want(isl_map *want, edge *e1, edge *e2)
740 unsigned n_in, n_out;
741 unsigned s_dim_1 = isl_map_dim(e1->relation, isl_dim_out) -
742 isl_map_dim(e1->relation, isl_dim_in);
743 unsigned s_dim_2 = isl_map_dim(e2->relation, isl_dim_out) -
744 isl_map_dim(e2->relation, isl_dim_in);
745 if (s_dim_1 != s_dim_2)
746 return want;
748 isl_space *dim = isl_map_get_space(e1->relation);
749 dim = isl_space_drop_dims(dim, isl_dim_in,
750 0, isl_space_dim(dim, isl_dim_in));
751 dim = isl_space_drop_dims(dim, isl_dim_out,
752 0, isl_space_dim(dim, isl_dim_out));
753 dim = isl_space_add_dims(dim, isl_dim_in, s_dim_1);
754 dim = isl_space_add_dims(dim, isl_dim_out, s_dim_1);
755 isl_basic_map *id = isl_basic_map_identity(dim);
757 dim = isl_space_range(isl_map_get_space(e1->relation));
758 isl_basic_map *s_1 = isl_basic_map_identity(isl_space_map_from_set(dim));
759 s_1 = isl_basic_map_remove_dims(s_1, isl_dim_in, 0,
760 isl_basic_map_dim(s_1, isl_dim_in) - s_dim_1);
761 id = isl_basic_map_apply_domain(id, s_1);
763 dim = isl_space_range(isl_map_get_space(e2->relation));
764 isl_basic_map *s_2 = isl_basic_map_identity(isl_space_map_from_set(dim));
765 s_2 = isl_basic_map_remove_dims(s_2, isl_dim_in, 0,
766 isl_basic_map_dim(s_2, isl_dim_in) - s_dim_2);
767 id = isl_basic_map_apply_range(id, s_2);
769 bool unmatched = false;
770 isl_map *matched_want;
771 isl_map *proj_want, *proj_matched;
772 matched_want = isl_map_intersect(isl_map_copy(want),
773 isl_map_from_basic_map(id));
775 n_in = isl_map_dim(want, isl_dim_in);
776 proj_want = isl_map_remove_dims(isl_map_copy(want),
777 isl_dim_in, n_in - s_dim_1, s_dim_1);
778 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
779 isl_dim_in, n_in - s_dim_1, s_dim_1);
780 if (!isl_map_is_equal(proj_want, proj_matched))
781 unmatched = true;
782 isl_map_free(proj_want);
783 isl_map_free(proj_matched);
785 n_out = isl_map_dim(want, isl_dim_out);
786 proj_want = isl_map_remove_dims(isl_map_copy(want),
787 isl_dim_out, n_out - s_dim_2, s_dim_2);
788 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
789 isl_dim_out, n_out - s_dim_2, s_dim_2);
790 if (!isl_map_is_equal(proj_want, proj_matched))
791 unmatched = true;
792 isl_map_free(proj_want);
793 isl_map_free(proj_matched);
795 if (unmatched) {
796 isl_map_free(matched_want);
797 return matched_want;
799 isl_map_free(want);
800 return matched_want;
803 static isl_map *propagation_want(eq_node *node, edge *e1, edge *e2)
805 isl_map *new_want;
807 new_want = isl_map_copy(node->want);
808 if (e1)
809 new_want = isl_map_apply_domain(new_want,
810 isl_map_copy(e1->relation));
811 if (e2)
812 new_want = isl_map_apply_range(new_want,
813 isl_map_copy(e2->relation));
815 if (e1 && e2 && e1->type == edge::expansion)
816 new_want = expansion_want(new_want, e1, e2);
818 return isl_map_detect_equalities(new_want);
821 static eq_node *propagation_node(eq_node *node, edge *e1, edge *e2)
823 isl_map *new_want;
824 computation *comp1 = e1 ? e1->source : node->comp[0];
825 computation *comp2 = e2 ? e2->source : node->comp[1];
827 new_want = propagation_want(node, e1, e2);
828 eq_node *child = new eq_node(comp1, comp2, new_want);
829 return child;
832 static isl_map *set_to_empty(isl_map *map)
834 isl_space *space;
835 space = isl_map_get_space(map);
836 isl_map_free(map);
837 return isl_map_empty(space);
840 /* Propagate "lost" from child back to parent.
842 isl_map *equivalence_checker::lost_from_propagation(eq_node *parent,
843 eq_node *node, edge *e1, edge *e2)
845 isl_map *new_lost;
846 new_lost = node->get_lost();
848 if (e1)
849 new_lost = isl_map_apply_domain(new_lost,
850 isl_map_reverse(isl_map_copy(e1->relation)));
851 if (e2)
852 new_lost = isl_map_apply_range(new_lost,
853 isl_map_reverse(isl_map_copy(e2->relation)));
855 new_lost = isl_map_intersect(new_lost, isl_map_copy(parent->want));
857 extend_trace_propagation(parent, node, e1, e2);
859 return new_lost;
862 void equivalence_checker::handle_copy(eq_node *node, int i)
864 std::vector<edge *>::iterator iter;
865 computation *copy = node->comp[i];
866 isl_map *want = node->want;
867 isl_map *lost;
869 lost = isl_map_empty(isl_map_get_space(want));
871 for (iter = copy->edges.begin(); iter != copy->edges.end(); ++iter) {
872 edge *e = *iter;
873 eq_node *child;
874 if (i == 0)
875 child = propagation_node(node, e, NULL);
876 else
877 child = propagation_node(node, NULL, e);
879 if (!child)
880 continue;
882 handle(child);
884 isl_map *new_lost;
885 if (i == 0)
886 new_lost = lost_from_propagation(node, child, e, NULL);
887 else
888 new_lost = lost_from_propagation(node, child, NULL, e);
889 dismiss(child);
891 lost = isl_map_union_disjoint(lost, new_lost);
893 node->set_lost(lost);
896 /* Compute and return the part of "want" that is lost when propagating
897 * over all edges of argument position pos1 in the first program
898 * and argument position pos2 in the second program.
899 * Since the domains of the dependence mappings on edges with the same
900 * argument position partition the domain of the computation (and are
901 * therefore disjoint), we simply need to take the disjoint union
902 * of all losts over all pairs of edges.
904 isl_map *equivalence_checker::lost_at_pos(eq_node *node, int pos1, int pos2)
906 isl_map *lost;
907 lost = isl_map_empty(isl_map_get_space(node->want));
909 for (int i = 0; i < node->comp[0]->edges.size(); ++i) {
910 edge *e1 = node->comp[0]->edges[i];
911 if (e1->pos != pos1)
912 continue;
913 for (int j = 0; j < node->comp[1]->edges.size(); ++j) {
914 edge *e2 = node->comp[1]->edges[j];
915 if (e2->pos != pos2)
916 continue;
918 eq_node *child;
919 isl_map *new_lost = NULL;
921 child = propagation_node(node, e1, e2);
922 handle(child);
923 new_lost = lost_from_propagation(node, child, e1, e2);
925 lost = isl_map_union_disjoint(lost, new_lost);
927 std::set<eq_node *>::iterator k;
928 for (k = child->assumed.begin();
929 k != child->assumed.end(); ++k)
930 node->assumed.insert(*k);
931 dismiss(child);
934 return lost;
937 /* Compute the lost that results from propagation on a pair of
938 * computations with a commutative operation.
940 * We first compute the losts for each pair of argument positions
941 * and store the result in lost[pos1][pos2].
942 * Then we perform a backtracking search over all permutations
943 * of the arguments. For each permutation, we compute the lost
944 * relation as the union of the losts over all arguments.
945 * The final lost is the intersection of all these losts over all
946 * permutations.
948 void equivalence_checker::handle_propagation_comm(eq_node *node)
950 int trace_error = options->trace_error;
951 trace_error = 0;
953 unsigned r = node->comp[0]->arity;
955 std::vector<std::vector<isl_map *> > lost;
957 for (int i = 0; i < r; ++i) {
958 std::vector<isl_map *> row;
959 for (int j = 0; j < r; ++j) {
960 isl_map *pos_lost;
961 pos_lost = lost_at_pos(node, i, j);
962 row.push_back(pos_lost);
964 lost.push_back(row);
967 int level;
968 std::vector<int> perm;
969 std::vector<isl_map *> lost_at;
970 for (level = 0; level < r; ++level) {
971 perm.push_back(0);
972 lost_at.push_back(NULL);
975 isl_map *total_lost;
976 total_lost = isl_map_copy(node->want);
978 level = 0;
979 while (level >= 0) {
980 if (perm[level] == r) {
981 perm[level] = 0;
982 --level;
983 if (level >= 0)
984 ++perm[level];
985 continue;
987 int l;
988 for (l = 0; l < level; ++l)
989 if (perm[l] == perm[level])
990 break;
991 if (l != level) {
992 ++perm[level];
993 continue;
996 isl_map_free(lost_at[level]);
997 lost_at[level] = isl_map_copy(lost[level][perm[level]]);
998 if (level != 0)
999 lost_at[level] = isl_map_union(lost_at[level],
1000 isl_map_copy(lost_at[level - 1]));
1002 if (level < r - 1) {
1003 ++level;
1004 continue;
1007 lost_at[level] = isl_map_coalesce(lost_at[level]);
1008 total_lost = isl_map_intersect(total_lost,
1009 isl_map_copy(lost_at[level]));
1010 ++perm[level];
1013 for (int i = 0; i < r; ++i)
1014 isl_map_free(lost_at[i]);
1016 for (int i = 0; i < r; ++i)
1017 for (int j = 0; j < r; ++j)
1018 isl_map_free(lost[i][j]);
1020 node->set_lost(total_lost);
1022 options->trace_error = trace_error;
1023 init_trace(node);
1026 /* Compute the lost that results from propagation on a pair of
1027 * computations.
1029 * First, for functions of zero arity (i.e., constants), equivalence
1030 * always holds and the lost relation is empty.
1031 * For commutative operations, the computation is delegated
1032 * to handle_propagation_comm.
1033 * Otherwise, we simply take the union of the losts over each
1034 * argument position (always taking the same argument position
1035 * in both programs).
1037 void equivalence_checker::handle_propagation_same(eq_node *node)
1039 unsigned r = node->comp[0]->arity;
1040 if (r == 0) {
1041 node->set_lost(isl_map_empty(isl_map_get_space(node->want)));
1042 return;
1044 if (is_commutative(node->comp[0]->operation)) {
1045 handle_propagation_comm(node);
1046 return;
1049 isl_map *lost;
1050 lost = isl_map_empty(isl_map_get_space(node->want));
1051 for (int i = 0; i < r; ++i) {
1052 isl_map *pos_lost;
1053 pos_lost = lost_at_pos(node, i, i);
1054 lost = isl_map_union(lost, pos_lost);
1056 lost = isl_map_coalesce(lost);
1057 node->set_lost(lost);
1060 void equivalence_checker::handle_widening(eq_node *node, eq_node *a)
1062 isl_map *wants;
1063 isl_map *aff;
1065 wants = isl_map_union(isl_map_copy(node->want), isl_map_copy(a->want));
1066 aff = isl_map_from_basic_map(isl_map_affine_hull(wants));
1067 aff = isl_map_intersect_domain(aff,
1068 isl_set_copy(node->comp[0]->domain));
1069 aff = isl_map_intersect_range(aff,
1070 isl_set_copy(node->comp[1]->domain));
1072 eq_node *child = new eq_node(node->comp[0], node->comp[1], aff);
1073 child->widening = 1;
1074 widenings++;
1075 handle(child);
1076 node->set_lost(isl_map_intersect(child->get_lost(),
1077 isl_map_copy(node->want)));
1078 extend_trace_widening(node, child);
1079 std::set<eq_node *>::iterator i;
1080 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1081 node->assumed.insert(*i);
1082 dismiss(child);
1085 /* Perform induction, if possible, and widening if we have to.
1087 void equivalence_checker::handle_with_ancestor(eq_node *node, eq_node *a)
1089 if (a->narrowing || isl_map_is_subset(node->want, a->want)) {
1090 isl_map *need;
1091 need = isl_map_intersect(isl_map_copy(node->want),
1092 isl_map_copy(a->want));
1093 node->set_lost(isl_map_subtract(isl_map_copy(node->want),
1094 isl_map_copy(a->want)));
1095 node->assumed.insert(a);
1096 a->need = isl_map_union(a->need, need);
1097 } else {
1098 handle_widening(node, a);
1102 struct narrowing_data {
1103 eq_node *node;
1104 equivalence_checker *ec;
1105 isl_map *new_got;
1108 static isl_stat basic_handle_narrowing(__isl_take isl_basic_map *bmap,
1109 void *user)
1111 narrowing_data *data = (narrowing_data *)user;
1113 eq_node *child;
1114 child = new eq_node(data->node->comp[0], data->node->comp[1],
1115 isl_map_from_basic_map(bmap));
1116 child->narrowing = 1;
1117 data->ec->narrowings++;
1118 data->ec->handle(child);
1119 data->new_got = isl_map_union(data->new_got, child->get_got());
1121 std::set<eq_node *>::iterator i;
1122 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1123 data->node->assumed.insert(*i);
1124 data->ec->dismiss(child);
1126 return isl_stat_ok;
1129 /* Construct and handle narrowing nodes for the given node.
1131 * If the node itself was already a narrowing node, then we
1132 * simply return the empty relation.
1133 * Otherwise, we consider each basic relation in the obtained
1134 * relation, construct a new node with that basic relation as
1135 * requested relation and take the union of all obtained relations.
1137 void equivalence_checker::handle_narrowing(eq_node *node)
1139 node->reset = 1;
1140 node->assumed.clear();
1142 isl_map *want = node->want;
1143 isl_map *new_got = isl_map_empty(isl_map_get_space(want));
1144 if (!options->narrowing || node->narrowing) {
1145 node->set_got(new_got);
1146 return;
1149 narrowing_data data = { node, this, new_got };
1150 isl_map_foreach_basic_map(node->peek_got(), &basic_handle_narrowing,
1151 &data);
1152 node->set_got(data.new_got);
1155 static __isl_give isl_union_set *update_result(__isl_take isl_union_set *res,
1156 const char *array_name, isl_map *map, int first, int n)
1158 if (isl_map_is_empty(map)) {
1159 isl_map_free(map);
1160 return res;
1163 isl_set *range = isl_map_range(map);
1164 range = isl_set_remove_dims(range, isl_dim_set, first, n);
1165 range = isl_set_coalesce(range);
1166 range = isl_set_set_tuple_name(range, array_name);
1167 res = isl_union_set_add_set(res, range);
1169 return res;
1172 struct check_equivalence_data {
1173 dependence_graph *dg1;
1174 dependence_graph *dg2;
1175 equivalence_checker *ec;
1176 isl_map *got;
1177 isl_map *lost;
1180 static isl_stat basic_check(__isl_take isl_basic_map *bmap, void *user)
1182 check_equivalence_data *data = (check_equivalence_data *)user;
1184 eq_node *root = new eq_node(data->dg1->out, data->dg2->out,
1185 isl_map_from_basic_map(bmap));
1186 data->ec->handle(root);
1187 data->got = isl_map_union_disjoint(data->got, root->get_got());
1188 data->lost = isl_map_union_disjoint(data->lost, root->get_lost());
1189 data->ec->dismiss(root);
1191 return isl_stat_ok;
1194 static int check_equivalence_array(isl_ctx *ctx, equivalence_checker *ec,
1195 dependence_graph *dg1, dependence_graph *dg2, int array1, int array2,
1196 isl_union_set **proved, isl_union_set **not_proved)
1198 unsigned n_in;
1199 const char *array_name = dg1->output_arrays[array1];
1200 isl_set *out1 = isl_set_copy(dg1->out->domain);
1201 isl_set *out2 = isl_set_copy(dg2->out->domain);
1202 unsigned dim1 = isl_set_n_dim(out1);
1203 unsigned dim2 = isl_set_n_dim(out2);
1204 unsigned array_dim = dg1->output_array_dims[array1];
1205 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1206 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1207 out1 = isl_set_remove_dims(out1, isl_dim_set, dim1 - 1, 1);
1208 out2 = isl_set_remove_dims(out2, isl_dim_set, dim2 - 1, 1);
1209 int equal = isl_set_is_equal(out1, out2);
1210 isl_set_free(out1);
1211 isl_set_free(out2);
1212 assert(equal >= 0);
1213 if (!equal) {
1214 fprintf(stderr, "different output domains for array %s\n",
1215 array_name);
1216 return -1;
1219 out1 = isl_set_copy(dg1->out->domain);
1220 out2 = isl_set_copy(dg2->out->domain);
1221 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1222 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1223 out1 = isl_set_coalesce(out1);
1224 out2 = isl_set_coalesce(out2);
1225 isl_space *dim = isl_space_map_from_set(isl_set_get_space(out2));
1226 isl_map *id = isl_map_identity(dim);
1227 n_in = isl_map_dim(id, isl_dim_in);
1228 id = isl_map_remove_dims(id, isl_dim_in, n_in - 1, 1);
1229 id = isl_map_apply_domain(id, isl_map_copy(id));
1230 id = isl_map_intersect_domain(id, out1);
1231 id = isl_map_intersect_range(id, out2);
1233 isl_map *got = isl_map_empty(isl_map_get_space(id));
1234 isl_map *lost = isl_map_copy(got);
1235 check_equivalence_data data = { dg1, dg2, ec, got, lost };
1236 isl_map_foreach_basic_map(id, &basic_check, &data);
1237 isl_map_free(id);
1238 *proved = update_result(*proved,
1239 array_name, data.got, array_dim, dim2 - array_dim);
1240 *not_proved = update_result(*not_proved,
1241 array_name, data.lost, array_dim, dim2 - array_dim);
1242 return 0;
1245 /* The input arrays of the two programs are supposed to be the same,
1246 * so they should at least have the same dimension. Make sure
1247 * this is true, because we depend on it later on.
1249 static int check_input_arrays(dependence_graph *dg1, dependence_graph *dg2)
1251 for (int i = 0; i < dg1->input_computations.size(); ++i)
1252 for (int j = 0; j < dg2->input_computations.size(); ++j) {
1253 if (strcmp(dg1->input_computations[i]->operation,
1254 dg2->input_computations[j]->operation))
1255 continue;
1256 if (dg1->input_computations[i]->dim ==
1257 dg2->input_computations[j]->dim)
1258 continue;
1259 fprintf(stderr,
1260 "input arrays \"%s\" do not have the same dimension\n",
1261 dg1->input_computations[i]->operation);
1262 return -1;
1265 return 0;
1268 static __isl_give isl_union_set *add_array(__isl_take isl_union_set *only,
1269 dependence_graph *dg, int array)
1271 isl_space *dim;
1272 isl_set *set;
1274 dim = isl_union_set_get_space(only);
1275 dim = isl_space_add_dims(dim, isl_dim_set, dg->output_array_dims[array]);
1276 dim = isl_space_set_tuple_name(dim, isl_dim_set, dg->output_arrays[array]);
1277 set = isl_set_universe(dim);
1278 only = isl_union_set_add_set(only, set);
1280 return only;
1283 static void print_results(const char *str, __isl_keep isl_union_set *only)
1285 isl_printer *prn;
1287 if (isl_union_set_is_empty(only))
1288 return;
1290 fprintf(stdout, "%s: '", str);
1291 prn = isl_printer_to_file(isl_union_set_get_ctx(only), stdout);
1292 prn = isl_printer_print_union_set(prn, only);
1293 isl_printer_free(prn);
1294 fprintf(stdout, "'\n");
1297 static int check_equivalence(isl_ctx *ctx,
1298 dependence_graph *dg1, dependence_graph *dg2, options *options)
1300 isl_space *dim;
1301 isl_set *context;
1302 isl_union_set *only1, *only2, *proved, *not_proved;
1303 dg1->flatten_associative_operators(options->ops->associative);
1304 dg2->flatten_associative_operators(options->ops->associative);
1305 equivalence_checker ec(options);
1306 int i1 = 0, i2 = 0;
1308 if (check_input_arrays(dg1, dg2))
1309 return -1;
1311 dim = isl_space_set_alloc(ctx, 0, 0);
1312 proved = isl_union_set_empty(isl_space_copy(dim));
1313 not_proved = isl_union_set_empty(isl_space_copy(dim));
1314 only1 = isl_union_set_empty(isl_space_copy(dim));
1315 only2 = isl_union_set_empty(dim);
1317 while (i1 < dg1->output_arrays.size() || i2 < dg2->output_arrays.size()) {
1318 int cmp;
1319 cmp = i1 == dg1->output_arrays.size() ? 1 :
1320 i2 == dg2->output_arrays.size() ? -1 :
1321 strcmp(dg1->output_arrays[i1], dg2->output_arrays[i2]);
1322 if (cmp < 0) {
1323 only1 = add_array(only1, dg1, i1);
1324 ++i1;
1325 } else if (cmp > 0) {
1326 only2 = add_array(only2, dg2, i2);
1327 ++i2;
1328 } else {
1329 check_equivalence_array(ctx, &ec, dg1, dg2, i1, i2,
1330 &proved, &not_proved);
1331 ++i1;
1332 ++i2;
1336 context = isl_set_union(isl_set_copy(dg1->context),
1337 isl_set_copy(dg2->context));
1338 proved = isl_union_set_gist_params(proved, isl_set_copy(context));
1339 not_proved = isl_union_set_gist_params(not_proved, context);
1341 print_results("Equivalence proved", proved);
1342 print_results("Equivalence NOT proved", not_proved);
1343 print_results("Only in program 1", only1);
1344 print_results("Only in program 2", only2);
1346 isl_union_set_free(proved);
1347 isl_union_set_free(not_proved);
1348 isl_union_set_free(only1);
1349 isl_union_set_free(only2);
1351 if (options->print_stats) {
1352 fprintf(stderr, "widenings: %d\n", ec.widenings);
1353 if (options->narrowing)
1354 fprintf(stderr, "narrowings: %d\n", ec.narrowings);
1357 return 0;
1360 static void dump_vertex(FILE *out, computation *comp)
1362 fprintf(out, "ND_%p [label = \"%d,%s/%d\"];\n",
1363 comp, comp->location, comp->operation, comp->arity);
1364 for (int i = 0; i < comp->edges.size(); ++i)
1365 fprintf(out, "ND_%p -> ND_%p%s;\n",
1366 comp, comp->edges[i]->source,
1367 comp->edges[i]->type == edge::expansion ?
1368 " [color=\"blue\"]" : "");
1371 static void dump_graph(FILE *out, dependence_graph *dg)
1373 fprintf(out, "digraph dummy {\n");
1374 dump_vertex(out, dg->out);
1375 for (int i = 0; i < dg->vertices.size(); ++i)
1376 dump_vertex(out, dg->vertices[i]);
1377 fprintf(out, "}\n");
1380 static void dump_graphs(dependence_graph **dg, struct options *options)
1382 int i;
1383 char path[PATH_MAX];
1385 if (!options->dump_graphs)
1386 return;
1388 for (i = 0; i < 2; ++i) {
1389 FILE *out;
1390 int s;
1391 s = snprintf(path, sizeof(path), "%s.dot", options->program[i]);
1392 assert(s < sizeof(path));
1393 out = fopen(path, "w");
1394 assert(out);
1395 dump_graph(out, dg[i]);
1396 fclose(out);
1400 void parse_ops(struct options *options) {
1401 char *tok;
1403 if (options->associative) {
1404 tok = strtok(options->associative, ",");
1405 options->ops->associative.push_back(strdup(tok));
1406 while ((tok = strtok(NULL, ",")) != NULL)
1407 options->ops->associative.push_back(strdup(tok));
1410 if (options->commutative) {
1411 tok = strtok(options->commutative, ",");
1412 options->ops->commutative.push_back(strdup(tok));
1413 while ((tok = strtok(NULL, ",")) != NULL)
1414 options->ops->commutative.push_back(strdup(tok));
1418 int main(int argc, char *argv[])
1420 struct options *options = options_new_with_defaults();
1421 struct isl_ctx *ctx;
1422 isl_set *context = NULL;
1423 dependence_graph *dg[2];
1425 argc = options_parse(options, argc, argv, ISL_ARG_ALL);
1426 parse_ops(options);
1428 ctx = isl_ctx_alloc_with_options(&options_args, options);
1429 if (!ctx) {
1430 fprintf(stderr, "Unable to allocate ctx\n");
1431 return -1;
1434 if (options->context)
1435 context = isl_set_read_from_str(ctx, options->context);
1437 pdg::PDG *pdg[2];
1438 unsigned out_dim = 0;
1439 for (int i = 0; i < 2; ++i) {
1440 FILE *in;
1441 in = fopen(options->program[i], "r");
1442 if (!in) {
1443 fprintf(stderr, "Unable to open %s\n", options->program[i]);
1444 return -1;
1446 pdg[i] = yaml::Load<pdg::PDG>(in, ctx);
1447 fclose(in);
1448 if (!pdg[i]) {
1449 fprintf(stderr, "Unable to read %s\n", options->program[i]);
1450 return -1;
1452 out_dim = update_out_dim(pdg[i], out_dim);
1454 if (context &&
1455 pdg[i]->params.size() != isl_set_dim(context, isl_dim_param)) {
1456 fprintf(stdout,
1457 "Parameter dimension mismatch; context ignored\n");
1458 isl_set_free(context);
1459 context = NULL;
1463 for (int i = 0; i < 2; ++i) {
1464 dg[i] = pdg_to_dg(pdg[i], out_dim, isl_set_copy(context));
1465 pdg[i]->free();
1466 delete pdg[i];
1469 dump_graphs(dg, options);
1471 int res = check_equivalence(ctx, dg[0], dg[1], options);
1473 for (int i = 0; i < 2; ++i)
1474 delete dg[i];
1475 isl_set_free(context);
1476 isl_ctx_free(ctx);
1478 return res;