update isl for support for recent clangs
[ppn.git] / eqv.cc
blob44b74cd2b07e2b398c2424006ce2b831e6f5710d
1 #include <limits.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <algorithm>
6 #include <map>
7 #include <set>
8 #include <vector>
9 #include <string>
10 #include <iostream>
11 #include <sstream>
13 #include <isa/yaml.h>
14 #include <isa/pdg.h>
16 #include <isl/set.h>
17 #include <isl/map.h>
18 #include <isl/constraint.h>
19 #include <isl/union_set.h>
20 #include "eqv_options.h"
21 #include "dependence_graph.h"
/* Lists of operation names with special algebraic properties.
 * Entries are compared against computation operations with strcmp.
 */
struct ops {
	/* Operations whose nested calls may be flattened. */
	std::vector<const char *> associative;
	/* Operations whose arguments may be matched in any order. */
	std::vector<const char *> commutative;
};
27 int ops_init(void *user)
29 struct ops **ops = (struct ops **)user;
30 *ops = new struct ops;
31 return 0;
/* Delete the "struct ops" stored in the "struct ops *" that
 * "user" points to.
 */
void ops_clear(void *user)
{
	struct ops *owned = *static_cast<struct ops **>(user);

	delete owned;
}
39 /* For each basic map in relation, add an edge to comp with
40 * given source and pos.
42 static void add_split_edge(computation *comp, computation *source,
43 int pos, isl_map *relation, std::vector<computation **> *missing,
44 enum edge::type type = edge::normal)
46 edge *e = new edge;
47 e->source = source;
48 if (!source) {
49 assert(missing);
50 missing->push_back(&e->source);
52 e->pos = pos;
53 e->relation = relation;
54 e->type = type;
55 comp->edges.push_back(e);
58 static unsigned update_out_dim(pdg::PDG *pdg, unsigned out_dim)
60 for (int i = 0; i < pdg->arrays.size(); ++i) {
61 pdg::array *array = pdg->arrays[i];
62 if (array->type != pdg::array::output)
63 continue;
64 if (array->dims.size() + 1 > out_dim)
65 out_dim = array->dims.size() + 1;
68 return out_dim;
/* Is "op" one of the operation names listed in "associative"? */
static bool is_associative(const std::vector<const char *> &associative,
	const char *op)
{
	for (size_t k = 0; k < associative.size(); ++k) {
		if (strcmp(associative[k], op) == 0)
			return true;
	}
	return false;
}
82 /* Return a computation that has an associative operation
83 * that takes a different computation with the same operation
84 * as one of its arguments. Also return the edge from the
85 * first to the second computation in *e.
87 * If no such computation exists, then return NULL.
89 computation *dependence_graph::associative_node(edge **e,
90 const std::vector<const char *> &associative)
92 for (int i = 0; i < vertices.size(); ++i) {
93 computation *comp = vertices[i];
94 if (!is_associative(associative, comp->operation))
95 continue;
96 for (int j = 0; j < comp->edges.size(); ++j) {
97 computation *other = comp->edges[j]->source;
98 if (comp->has_same_source(other))
99 continue;
100 if (strcmp(comp->operation, other->operation))
101 continue;
102 *e = comp->edges[j];
103 return comp;
106 return NULL;
109 /* Splice the source of edge edge into the position of edge e,
110 * removing all other edges with the same positions and bumping
111 * up edges with a greater position.
113 void dependence_graph::splice(computation *comp, edge *e)
115 computation *source = e->source;
116 int pos = e->pos;
117 isl_map *map = isl_map_copy(e->relation);
118 std::vector<struct edge *> new_edges;
120 for (int i = 0; i < comp->edges.size(); ++i) {
121 if (comp->edges[i]->pos == pos) {
122 delete comp->edges[i];
123 continue;
125 if (comp->edges[i]->pos > pos)
126 comp->edges[i]->pos += source->arity - 1;
127 new_edges.push_back(comp->edges[i]);
129 for (int i = 0; i < source->edges.size(); ++i) {
130 edge *old_e = source->edges[i];
131 edge *e = new edge;
132 e->source = old_e->source;
133 e->pos = old_e->pos + pos;
134 e->relation = isl_map_apply_range(
135 isl_map_copy(map),
136 isl_map_copy(old_e->relation));
137 new_edges.push_back(e);
139 isl_map_free(map);
140 comp->edges = new_edges;
141 comp->arity += source->arity - 1;
144 /* Split computation comp into a part that always takes the edge e
145 * and one that never takes that edge.
146 * The second is returned and the initial computation is modified
147 * to match the first.
149 * We need to be careful not to destroy e, as it is still used
150 * by the calling method.
152 computation *dependence_graph::split_comp(computation *comp, edge *e)
154 computation *dup = new computation;
155 dup->original = comp->original ? comp->original : comp;
156 dup->operation = strdup(comp->operation);
157 dup->arity = comp->arity;
158 dup->location = comp->location;
160 isl_map *map = isl_map_copy(e->relation);
161 map = isl_map_reverse(map);
162 isl_set *dom = isl_set_apply(isl_set_copy(e->source->domain), map);
163 dup->domain = isl_set_subtract(isl_set_copy(comp->domain),
164 isl_set_copy(dom));
165 comp->domain = isl_set_intersect(comp->domain, dom);
167 std::vector<struct edge *> old_edges = comp->edges;
168 comp->edges.clear();
170 for (int i = 0; i < old_edges.size(); ++i) {
171 if (old_edges[i] == e) {
172 comp->edges.push_back(e);
173 continue;
176 edge *e = old_edges[i];
177 isl_map *map, *map_dup;
178 map = isl_map_copy(e->relation);
179 map_dup = isl_map_copy(map);
181 map = isl_map_intersect_domain(map, isl_set_copy(comp->domain));
182 map_dup = isl_map_intersect_domain(map_dup,
183 isl_set_copy(dup->domain));
184 add_split_edge(comp, e->source, e->pos, map, NULL);
185 add_split_edge(dup, e->source, e->pos, map_dup, NULL);
186 delete e;
189 vertices.push_back(dup);
190 return dup;
193 /* If any edge from comp points to comp_orig, then split it
194 * into two edges, one still pointing to comp_orig and the
195 * other pointing to comp_dup.
196 * comp_orig and comp_dup are assumed to have disjoint domains
197 * and the edge relations are adjusted according to these domains.
199 void dependence_graph::split_edges(computation *comp,
200 computation *comp_orig, computation *comp_dup)
202 std::vector<struct edge *> old_edges = comp->edges;
203 comp->edges.clear();
205 for (int i = 0; i < old_edges.size(); ++i) {
206 edge *e = old_edges[i];
208 if (e->source != comp_orig) {
209 comp->edges.push_back(e);
210 continue;
213 isl_map *map_orig, *map_dup;
214 map_orig = isl_map_copy(e->relation);
215 map_dup = isl_map_copy(map_orig);
216 map_orig = isl_map_intersect_range(map_orig,
217 isl_set_copy(comp_orig->domain));
218 map_dup = isl_map_intersect_range(map_dup,
219 isl_set_copy(comp_dup->domain));
220 add_split_edge(comp, comp_orig, e->pos, map_orig, NULL);
221 add_split_edge(comp, comp_dup, e->pos, map_dup, NULL);
222 delete e;
226 void dependence_graph::split_edges(computation *comp_orig,
227 computation *comp_dup)
229 split_edges(out, comp_orig, comp_dup);
230 for (int i = 0; i < vertices.size(); ++i)
231 split_edges(vertices[i], comp_orig, comp_dup);
234 /* Replace all nested calls of an associative operator,
235 * by a call with the nested call spliced into the first call.
237 * For each nested call we find, we first check if there are
238 * any other edges with the same position. If not, then
239 * the nested call is performed for each iteration of the computation
240 * and we can simplify splice the nested call.
242 * Otherwise, we first create a duplicate computation for the iterations
243 * that do not take the found edge and adjust the edges of both computations
244 * to their domains. This meand that the edge corresponding to the
245 * nested call will no longer appear in the duplicated computation
246 * and the other edges with the same position will no longer appear
247 * in the original computation.
248 * Then we splice the nested call in the original computation.
249 * Finally, we split all edges that pointed to the original computation
250 * into two edges, one going to the original computation and one
251 * going to the duplicated computation.
253 void dependence_graph::flatten_associative_operators(
254 const std::vector<const char *> &associative)
256 edge *e;
257 computation *comp;
259 while ((comp = associative_node(&e, associative)) != NULL) {
260 computation *comp_dup = NULL;
261 int j;
262 for (j = 0; j < comp->edges.size(); ++j) {
263 if (comp->edges[j] == e)
264 continue;
265 if (comp->edges[j]->pos == e->pos)
266 break;
268 if (j != comp->edges.size())
269 comp_dup = split_comp(comp, e);
270 splice(comp, e);
271 if (comp_dup)
272 split_edges(comp, comp_dup);
276 struct eq_node {
277 private:
278 isl_map *got;
279 isl_map *lost;
280 public:
281 computation *comp[2];
282 isl_map *want;
283 isl_map *need;
284 std::set<eq_node *> assumed;
285 unsigned closed : 1;
286 unsigned narrowing : 1;
287 unsigned widening : 1;
288 unsigned invalidated : 1;
289 unsigned reset : 1;
290 std::string trace;
291 isl_map *lost_sample;
293 eq_node(computation *c1, computation *c2,
294 isl_map *w) : want(w), closed(0),
295 widening(0), narrowing(0), invalidated(0), reset(0),
296 need(NULL), got(NULL), lost(NULL), lost_sample(NULL) {
297 comp[0] = c1;
298 comp[1] = c2;
300 bool is_still_valid();
301 void collect_open_assumed(std::set<eq_node *> &c);
302 ~eq_node() {
303 isl_map_free(want);
304 isl_map_free(need);
305 isl_map_free(got);
306 isl_map_free(lost);
307 isl_map_free(lost_sample);
309 void compute_got() {
310 assert(lost);
311 got = isl_map_copy(want);
312 got = isl_map_subtract(got, isl_map_copy(lost));
314 isl_map *get_got() {
315 if (!got)
316 compute_got();
317 return isl_map_copy(got);
319 isl_map *peek_got() {
320 if (!got)
321 compute_got();
322 return got;
324 void compute_lost() {
325 assert(got);
326 lost = isl_map_copy(want);
327 lost = isl_map_subtract(lost, isl_map_copy(got));
329 isl_map *get_lost() {
330 if (!lost)
331 compute_lost();
332 return isl_map_copy(lost);
334 isl_map *peek_lost() {
335 if (!lost)
336 compute_lost();
337 return lost;
339 void set_got(isl_map *got) {
340 isl_map_free(this->lost);
341 isl_map_free(this->got);
342 this->lost = NULL;
343 this->got = got;
345 void set_lost(isl_map *lost) {
346 isl_map_free(this->lost);
347 isl_map_free(this->got);
348 this->got = NULL;
349 this->lost = lost;
351 std::string to_string();
354 std::string eq_node::to_string()
356 std::ostringstream strm;
357 strm << comp[0]->location << "," << comp[0]->operation
358 << "/" << comp[0]->arity;
359 strm << " <-> ";
360 strm << comp[1]->location << "," << comp[1]->operation
361 << "/" << comp[1]->arity;
362 strm << std::endl;
363 return strm.str();
366 bool eq_node::is_still_valid()
368 if (invalidated)
369 return 0;
371 std::set<eq_node *>::iterator i;
372 for (i = assumed.begin(); i != assumed.end(); ++i) {
373 assert(*i != this);
374 if ((*i)->reset ||
375 ((*i)->closed && !(*i)->is_still_valid())) {
376 invalidated = 1;
377 return 0;
380 return 1;
383 void eq_node::collect_open_assumed(std::set<eq_node *> &c)
385 std::set<eq_node *>::iterator i;
386 for (i = assumed.begin(); i != assumed.end(); ++i) {
387 if ((*i)->closed)
388 (*i)->collect_open_assumed(c);
389 else
390 c.insert(*i);
394 /* A comp_pair contains all the edges that have the same pair
395 * of computations.
397 struct comp_pair {
398 computation *comp[2];
400 std::vector<eq_node *> nodes;
402 eq_node *tabled(eq_node *node);
403 eq_node *last_ancestor(eq_node *node);
404 ~comp_pair();
407 comp_pair::~comp_pair()
409 std::vector<eq_node *>::iterator i;
410 for (i = nodes.begin(); i != nodes.end(); ++i)
411 delete *i;
414 eq_node *comp_pair::tabled(eq_node *node)
416 std::vector<eq_node *>::iterator i;
418 for (i = nodes.begin(); i != nodes.end(); ++i) {
419 if (*i == node)
420 continue;
421 if (!(*i)->closed)
422 continue;
423 if (!(*i)->is_still_valid())
424 continue;
425 int is_subset;
426 is_subset = isl_map_is_subset(node->want, (*i)->want);
427 assert(is_subset >= 0);
428 if (!is_subset)
429 continue;
430 return *i;
432 return NULL;
435 eq_node *comp_pair::last_ancestor(eq_node *node)
437 std::vector<eq_node *>::reverse_iterator i;
438 for (i = nodes.rbegin(); i != nodes.rend(); ++i) {
439 if (*i == node)
440 continue;
441 if ((*i)->closed)
442 continue;
443 return *i;
445 return NULL;
448 typedef std::pair<computation *, computation *> computation_pair;
449 typedef std::map<computation_pair, comp_pair *> c2p_t;
451 struct equivalence_checker {
452 c2p_t c2p;
453 struct options *options;
454 int widenings;
455 int narrowings;
457 equivalence_checker(struct options *o) :
458 options(o), widenings(0), narrowings(0) {}
459 comp_pair *get_comp_pair(eq_node *node);
460 void handle(eq_node *node);
461 void dismiss(eq_node *node);
462 void handle_propagation(eq_node *node);
463 void handle_copy(eq_node *node, int i);
464 void handle_widening(eq_node *node, eq_node *a);
465 void handle_with_ancestor(eq_node *node, eq_node *a);
466 isl_map *lost_at_pos(eq_node *node, int pos1, int pos2);
467 void handle_propagation_same(eq_node *node);
468 void handle_propagation_comm(eq_node *node);
469 void handle_narrowing(eq_node *node);
471 isl_map *lost_from_propagation(eq_node *parent,
472 eq_node *node, edge *e1, edge *e2);
474 void init_trace(eq_node *node);
475 void extend_trace_propagation(eq_node *node, eq_node *child,
476 edge *e1, edge *e2);
477 void extend_trace_widening(eq_node *node, eq_node *child);
479 bool is_commutative(const char *op);
481 ~equivalence_checker();
484 equivalence_checker::~equivalence_checker()
486 c2p_t::iterator i;
488 for (i = c2p.begin(); i != c2p.end(); ++i)
489 delete (*i).second;
492 bool equivalence_checker::is_commutative(const char *op)
494 std::vector<const char *>::iterator iter;
496 for (iter = options->ops->commutative.begin();
497 iter != options->ops->commutative.end(); ++iter)
498 if (!strcmp(*iter, op))
499 return 1;
500 return 0;
503 comp_pair *equivalence_checker::get_comp_pair(eq_node *node)
505 c2p_t::iterator i;
506 comp_pair *cp;
508 i = c2p.find(computation_pair(node->comp[0], node->comp[1]));
509 if (i == c2p.end()) {
510 cp = new comp_pair;
511 c2p[computation_pair(node->comp[0], node->comp[1])] = cp;
512 } else
513 cp = (*i).second;
514 return cp;
517 void equivalence_checker::dismiss(eq_node *node)
519 /* unclosed nodes weren't added to c2p->nodes */
520 if (node && !node->closed)
521 delete node;
524 void equivalence_checker::init_trace(eq_node *node)
526 isl_ctx *ctx;
527 isl_printer *prn;
529 if (!options->trace_error)
530 return;
531 if (isl_map_is_empty(node->peek_lost()))
532 return;
533 node->trace = node->to_string();
534 std::cerr << node->trace;
535 isl_map_free(node->lost_sample);
536 node->lost_sample = isl_map_from_basic_map(
537 isl_map_sample(node->get_lost()));
538 ctx = isl_map_get_ctx(node->lost_sample);
539 prn = isl_printer_to_file(ctx, stderr);
540 prn = isl_printer_print_map(prn, node->lost_sample);
541 prn = isl_printer_end_line(prn);
542 isl_printer_free(prn);
545 void equivalence_checker::extend_trace_propagation(eq_node *node,
546 eq_node *child, edge *e1, edge *e2)
548 isl_ctx *ctx;
549 isl_printer *prn;
551 if (!options->trace_error)
552 return;
553 if (!child->lost_sample)
554 return;
555 if (node->lost_sample)
556 return;
557 node->trace = child->trace + node->to_string();
558 std::cerr << node->trace;
559 node->lost_sample = isl_map_copy(child->lost_sample);
560 if (e1)
561 node->lost_sample = isl_map_apply_domain(node->lost_sample,
562 isl_map_reverse(
563 isl_map_copy(e1->relation)));
564 if (e2)
565 node->lost_sample = isl_map_apply_range(node->lost_sample,
566 isl_map_reverse(
567 isl_map_copy(e2->relation)));
568 node->lost_sample = isl_map_intersect(node->lost_sample,
569 isl_map_copy(node->want));
570 node->lost_sample = isl_map_from_basic_map(
571 isl_map_sample(node->lost_sample));
572 ctx = isl_map_get_ctx(node->lost_sample);
573 prn = isl_printer_to_file(ctx, stderr);
574 prn = isl_printer_print_map(prn, node->lost_sample);
575 prn = isl_printer_end_line(prn);
576 isl_printer_free(prn);
579 void equivalence_checker::extend_trace_widening(eq_node *node, eq_node *child)
581 isl_ctx *ctx;
582 isl_printer *prn;
584 if (!options->trace_error)
585 return;
586 if (!child->lost_sample)
587 return;
588 if (isl_map_is_empty(node->peek_lost()))
589 return;
590 node->trace = child->trace + node->to_string();
591 std::cerr << node->trace;
592 node->lost_sample = isl_map_from_basic_map(
593 isl_map_sample(node->get_lost()));
594 ctx = isl_map_get_ctx(node->lost_sample);
595 prn = isl_printer_to_file(ctx, stderr);
596 prn = isl_printer_print_map(prn, node->lost_sample);
597 prn = isl_printer_end_line(prn);
598 isl_printer_free(prn);
601 /* Check for which subset (got) of the want relation equivelance
602 * holds for the pair of computations in the equivalence node.
604 * We first handle the easy cases: empty want, input computations
605 * and tabled nodes.
607 * If the current node is not a narrowing or a widening node
608 * and we can find an ancestor with the same pair of computations,
609 * then we will try to apply induction or widening in handle_with_ancestor.
610 * However, if the requested relation (want) of the ancestor is a strict
611 * subset of that of the current node, then we have already applied
612 * widening on an intermediate node (with a different pair of computations)
613 * so we shouldn't apply widening again (and we can't apply induction
614 * because the relation of the ancestor is a strict subset).
616 * In all other cases we try to apply propagation in handle_propagation.
618 void equivalence_checker::handle(eq_node *node)
620 computation *comp1 = node->comp[0];
621 computation *comp2 = node->comp[1];
623 if (!comp1->is_copy() && !comp2->is_copy() &&
624 (strcmp(comp1->operation, comp2->operation) ||
625 comp1->arity != comp2->arity)) {
626 node->set_lost(isl_map_copy(node->want));
627 init_trace(node);
628 return;
631 eq_node *s = NULL;
632 eq_node *a = NULL;
634 isl_map *want = node->want;
635 node->need = isl_map_empty(isl_map_get_space(want));
636 int empty = isl_map_is_empty(want);
637 assert(empty >= 0);
638 if (empty) {
639 node->set_lost(isl_map_empty(isl_map_get_space(want)));
640 return;
642 if (node->comp[0]->is_input() && node->comp[1]->is_input()) {
643 if (strcmp(node->comp[0]->operation, node->comp[1]->operation)) {
644 node->set_lost(isl_map_copy(want));
645 return;
647 node->set_lost(isl_map_subtract(
648 isl_map_copy(node->want),
649 isl_map_identity(isl_map_get_space(want))));
650 return;
653 comp_pair *cp = get_comp_pair(node);
654 if ((s = cp->tabled(node)) != NULL) {
655 node->set_lost(isl_map_intersect(s->get_lost(),
656 isl_map_copy(node->want)));
657 s->collect_open_assumed(node->assumed);
658 return;
661 cp->nodes.push_back(node);
663 if (!node->narrowing && !node->widening &&
664 (a = cp->last_ancestor(node)) != NULL) {
665 int is_subset;
666 is_subset = isl_map_is_strict_subset(a->want, node->want);
667 assert(is_subset >= 0);
668 if (is_subset)
669 handle_propagation(node);
670 else
671 handle_with_ancestor(node, a);
672 } else
673 handle_propagation(node);
674 node->closed = 1;
677 /* Check if we can apply propagation to prove equivalence of the given node.
678 * First check if we need to apply copy propagation and if not
679 * check if the operations are the same and apply "regular" propagation.
681 * After we get back, we need to check that any induction hypotheses
682 * we have used in the process of proving the node hold.
683 * If not, we replace the obtained relation (got) by that
684 * of a narrowing node in handle_narrowin.
686 void equivalence_checker::handle_propagation(eq_node *node)
688 if (node->comp[0]->is_copy())
689 handle_copy(node, 0);
690 else if (node->comp[1]->is_copy())
691 handle_copy(node, 1);
692 else
693 handle_propagation_same(node);
694 node->assumed.erase(node);
695 isl_map *lost = node->get_lost();
696 lost = isl_map_intersect(lost, isl_map_copy(node->need));
697 int is_empty = isl_map_is_empty(lost);
698 isl_map_free(lost);
699 assert(is_empty >= 0);
700 if (!is_empty)
701 handle_narrowing(node);
704 /* When applying the mappings on expansion edges to both sides of
705 * an equivalence relation, each element in the original equivalence
706 * relation is mapped to many elements on both sides. For example,
707 * if the original equivalence relation has i R i' for i = i', then the new
708 * equivalence relation may have (i,j) R (i',j') for i = i' and
709 * 0 <= j,j' <= 10, expressing that all elements of the row read by
710 * iteration i should be equal to all elements of the row read by i'.
711 * Instead, we want to express that each individual element read by i'
712 * should be equal to the corresponding element read by i'.
713 * In the example, we need to introduce the equality j = j'.
715 * We first check that the number of dimensions added by the expansions
716 * is the same in both programs. Then we construct an identity relation
717 * between this number of dimensions, lift it to the space of the
718 * new equivalence relation and take the intersection.
720 * We have to be careful, though, that we don't loose any array elements
721 * by taking the intersection. In particular, the expansion maps
722 * a given read operation to iterations of an extended domain that
723 * reads all the individual array elements. The read operation is
724 * only equivalent to some other read operation if all the reads
725 * of the individual array elements are equivalent.
726 * We therefore need to make sure that adding the equalities does
727 * not remove any of the reads in the extended domain.
728 * We do this by projecting out the extra dimensions on one side
729 * from both "want" and "want \cap eq". The resulting maps should
730 * be the same. We do this check for both sides separately.
732 * If either of the above tests fails, then we simply return
733 * the original over-ambitious want.
735 static isl_map *expansion_want(isl_map *want, edge *e1, edge *e2)
737 unsigned s_dim_1 = isl_map_n_out(e1->relation) -
738 isl_map_n_in(e1->relation);
739 unsigned s_dim_2 = isl_map_n_out(e2->relation) -
740 isl_map_n_in(e2->relation);
741 if (s_dim_1 != s_dim_2)
742 return want;
744 isl_space *dim = isl_map_get_space(e1->relation);
745 dim = isl_space_drop_dims(dim, isl_dim_in,
746 0, isl_space_dim(dim, isl_dim_in));
747 dim = isl_space_drop_dims(dim, isl_dim_out,
748 0, isl_space_dim(dim, isl_dim_out));
749 dim = isl_space_add_dims(dim, isl_dim_in, s_dim_1);
750 dim = isl_space_add_dims(dim, isl_dim_out, s_dim_1);
751 isl_basic_map *id = isl_basic_map_identity(dim);
753 dim = isl_space_range(isl_map_get_space(e1->relation));
754 isl_basic_map *s_1 = isl_basic_map_identity(isl_space_map_from_set(dim));
755 s_1 = isl_basic_map_remove_dims(s_1, isl_dim_in, 0,
756 isl_basic_map_n_in(s_1) - s_dim_1);
757 id = isl_basic_map_apply_domain(id, s_1);
759 dim = isl_space_range(isl_map_get_space(e2->relation));
760 isl_basic_map *s_2 = isl_basic_map_identity(isl_space_map_from_set(dim));
761 s_2 = isl_basic_map_remove_dims(s_2, isl_dim_in, 0,
762 isl_basic_map_n_in(s_2) - s_dim_2);
763 id = isl_basic_map_apply_range(id, s_2);
765 bool unmatched = false;
766 isl_map *matched_want;
767 isl_map *proj_want, *proj_matched;
768 matched_want = isl_map_intersect(isl_map_copy(want),
769 isl_map_from_basic_map(id));
771 proj_want = isl_map_remove_dims(isl_map_copy(want),
772 isl_dim_in, isl_map_n_in(want) - s_dim_1, s_dim_1);
773 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
774 isl_dim_in, isl_map_n_in(want) - s_dim_1, s_dim_1);
775 if (!isl_map_is_equal(proj_want, proj_matched))
776 unmatched = true;
777 isl_map_free(proj_want);
778 isl_map_free(proj_matched);
780 proj_want = isl_map_remove_dims(isl_map_copy(want),
781 isl_dim_out, isl_map_n_out(want) - s_dim_2, s_dim_2);
782 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
783 isl_dim_out, isl_map_n_out(want) - s_dim_2, s_dim_2);
784 if (!isl_map_is_equal(proj_want, proj_matched))
785 unmatched = true;
786 isl_map_free(proj_want);
787 isl_map_free(proj_matched);
789 if (unmatched) {
790 isl_map_free(matched_want);
791 return matched_want;
793 isl_map_free(want);
794 return matched_want;
797 static isl_map *propagation_want(eq_node *node, edge *e1, edge *e2)
799 isl_map *new_want;
801 new_want = isl_map_copy(node->want);
802 if (e1)
803 new_want = isl_map_apply_domain(new_want,
804 isl_map_copy(e1->relation));
805 if (e2)
806 new_want = isl_map_apply_range(new_want,
807 isl_map_copy(e2->relation));
809 if (e1 && e2 && e1->type == edge::expansion)
810 new_want = expansion_want(new_want, e1, e2);
812 return isl_map_detect_equalities(new_want);
815 static eq_node *propagation_node(eq_node *node, edge *e1, edge *e2)
817 isl_map *new_want;
818 computation *comp1 = e1 ? e1->source : node->comp[0];
819 computation *comp2 = e2 ? e2->source : node->comp[1];
821 new_want = propagation_want(node, e1, e2);
822 eq_node *child = new eq_node(comp1, comp2, new_want);
823 return child;
826 static isl_map *set_to_empty(isl_map *map)
828 isl_space *space;
829 space = isl_map_get_space(map);
830 isl_map_free(map);
831 return isl_map_empty(space);
834 /* Propagate "lost" from child back to parent.
836 isl_map *equivalence_checker::lost_from_propagation(eq_node *parent,
837 eq_node *node, edge *e1, edge *e2)
839 isl_map *new_lost;
840 new_lost = node->get_lost();
842 if (e1)
843 new_lost = isl_map_apply_domain(new_lost,
844 isl_map_reverse(isl_map_copy(e1->relation)));
845 if (e2)
846 new_lost = isl_map_apply_range(new_lost,
847 isl_map_reverse(isl_map_copy(e2->relation)));
849 new_lost = isl_map_intersect(new_lost, isl_map_copy(parent->want));
851 extend_trace_propagation(parent, node, e1, e2);
853 return new_lost;
856 void equivalence_checker::handle_copy(eq_node *node, int i)
858 std::vector<edge *>::iterator iter;
859 computation *copy = node->comp[i];
860 isl_map *want = node->want;
861 isl_map *lost;
863 lost = isl_map_empty(isl_map_get_space(want));
865 for (iter = copy->edges.begin(); iter != copy->edges.end(); ++iter) {
866 edge *e = *iter;
867 eq_node *child;
868 if (i == 0)
869 child = propagation_node(node, e, NULL);
870 else
871 child = propagation_node(node, NULL, e);
873 if (!child)
874 continue;
876 handle(child);
878 isl_map *new_lost;
879 if (i == 0)
880 new_lost = lost_from_propagation(node, child, e, NULL);
881 else
882 new_lost = lost_from_propagation(node, child, NULL, e);
883 dismiss(child);
885 lost = isl_map_union_disjoint(lost, new_lost);
887 node->set_lost(lost);
890 /* Compute and return the part of "want" that is lost when propagating
891 * over all edges of argument position pos1 in the first program
892 * and argument position pos2 in the second program.
893 * Since the domains of the dependence mappings on edges with the same
894 * argument position partition the domain of the computation (and are
895 * therefore disjoint), we simply need to take the disjoint union
896 * of all losts over all pairs of edges.
898 isl_map *equivalence_checker::lost_at_pos(eq_node *node, int pos1, int pos2)
900 isl_map *lost;
901 lost = isl_map_empty(isl_map_get_space(node->want));
903 for (int i = 0; i < node->comp[0]->edges.size(); ++i) {
904 edge *e1 = node->comp[0]->edges[i];
905 if (e1->pos != pos1)
906 continue;
907 for (int j = 0; j < node->comp[1]->edges.size(); ++j) {
908 edge *e2 = node->comp[1]->edges[j];
909 if (e2->pos != pos2)
910 continue;
912 eq_node *child;
913 isl_map *new_lost = NULL;
915 child = propagation_node(node, e1, e2);
916 handle(child);
917 new_lost = lost_from_propagation(node, child, e1, e2);
919 lost = isl_map_union_disjoint(lost, new_lost);
921 std::set<eq_node *>::iterator k;
922 for (k = child->assumed.begin();
923 k != child->assumed.end(); ++k)
924 node->assumed.insert(*k);
925 dismiss(child);
928 return lost;
931 /* Compute the lost that results from propagation on a pair of
932 * computations with a commutative operation.
934 * We first compute the losts for each pair of argument positions
935 * and store the result in lost[pos1][pos2].
936 * Then we perform a backtracking search over all permutations
937 * of the arguments. For each permutation, we compute the lost
938 * relation as the union of the losts over all arguments.
939 * The final lost is the intersection of all these losts over all
940 * permutations.
942 void equivalence_checker::handle_propagation_comm(eq_node *node)
944 int trace_error = options->trace_error;
945 trace_error = 0;
947 unsigned r = node->comp[0]->arity;
949 std::vector<std::vector<isl_map *> > lost;
951 for (int i = 0; i < r; ++i) {
952 std::vector<isl_map *> row;
953 for (int j = 0; j < r; ++j) {
954 isl_map *pos_lost;
955 pos_lost = lost_at_pos(node, i, j);
956 row.push_back(pos_lost);
958 lost.push_back(row);
961 int level;
962 std::vector<int> perm;
963 std::vector<isl_map *> lost_at;
964 for (level = 0; level < r; ++level) {
965 perm.push_back(0);
966 lost_at.push_back(NULL);
969 isl_map *total_lost;
970 total_lost = isl_map_copy(node->want);
972 level = 0;
973 while (level >= 0) {
974 if (perm[level] == r) {
975 perm[level] = 0;
976 --level;
977 if (level >= 0)
978 ++perm[level];
979 continue;
981 int l;
982 for (l = 0; l < level; ++l)
983 if (perm[l] == perm[level])
984 break;
985 if (l != level) {
986 ++perm[level];
987 continue;
990 isl_map_free(lost_at[level]);
991 lost_at[level] = isl_map_copy(lost[level][perm[level]]);
992 if (level != 0)
993 lost_at[level] = isl_map_union(lost_at[level],
994 isl_map_copy(lost_at[level - 1]));
996 if (level < r - 1) {
997 ++level;
998 continue;
1001 lost_at[level] = isl_map_coalesce(lost_at[level]);
1002 total_lost = isl_map_intersect(total_lost,
1003 isl_map_copy(lost_at[level]));
1004 ++perm[level];
1007 for (int i = 0; i < r; ++i)
1008 isl_map_free(lost_at[i]);
1010 for (int i = 0; i < r; ++i)
1011 for (int j = 0; j < r; ++j)
1012 isl_map_free(lost[i][j]);
1014 node->set_lost(total_lost);
1016 options->trace_error = trace_error;
1017 init_trace(node);
1020 /* Compute the lost that results from propagation on a pair of
1021 * computations.
1023 * First, for functions of zero arity (i.e., constants), equivalence
1024 * always holds and the lost relation is empty.
1025 * For commutative operations, the computation is delegated
1026 * to handle_propagation_comm.
1027 * Otherwise, we simply take the union of the losts over each
1028 * argument position (always taking the same argument position
1029 * in both programs).
1031 void equivalence_checker::handle_propagation_same(eq_node *node)
1033 unsigned r = node->comp[0]->arity;
1034 if (r == 0) {
1035 node->set_lost(isl_map_empty(isl_map_get_space(node->want)));
1036 return;
1038 if (is_commutative(node->comp[0]->operation)) {
1039 handle_propagation_comm(node);
1040 return;
1043 isl_map *lost;
1044 lost = isl_map_empty(isl_map_get_space(node->want));
1045 for (int i = 0; i < r; ++i) {
1046 isl_map *pos_lost;
1047 pos_lost = lost_at_pos(node, i, i);
1048 lost = isl_map_union(lost, pos_lost);
1050 lost = isl_map_coalesce(lost);
1051 node->set_lost(lost);
1054 void equivalence_checker::handle_widening(eq_node *node, eq_node *a)
1056 isl_map *wants;
1057 isl_map *aff;
1059 wants = isl_map_union(isl_map_copy(node->want), isl_map_copy(a->want));
1060 aff = isl_map_from_basic_map(isl_map_affine_hull(wants));
1061 aff = isl_map_intersect_domain(aff,
1062 isl_set_copy(node->comp[0]->domain));
1063 aff = isl_map_intersect_range(aff,
1064 isl_set_copy(node->comp[1]->domain));
1066 eq_node *child = new eq_node(node->comp[0], node->comp[1], aff);
1067 child->widening = 1;
1068 widenings++;
1069 handle(child);
1070 node->set_lost(isl_map_intersect(child->get_lost(),
1071 isl_map_copy(node->want)));
1072 extend_trace_widening(node, child);
1073 std::set<eq_node *>::iterator i;
1074 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1075 node->assumed.insert(*i);
1076 dismiss(child);
1079 /* Perform induction, if possible, and widening if we have to.
1081 void equivalence_checker::handle_with_ancestor(eq_node *node, eq_node *a)
1083 if (a->narrowing || isl_map_is_subset(node->want, a->want)) {
1084 isl_map *need;
1085 need = isl_map_intersect(isl_map_copy(node->want),
1086 isl_map_copy(a->want));
1087 node->set_lost(isl_map_subtract(isl_map_copy(node->want),
1088 isl_map_copy(a->want)));
1089 node->assumed.insert(a);
1090 a->need = isl_map_union(a->need, need);
1091 } else {
1092 handle_widening(node, a);
1096 struct narrowing_data {
1097 eq_node *node;
1098 equivalence_checker *ec;
1099 isl_map *new_got;
1102 static isl_stat basic_handle_narrowing(__isl_take isl_basic_map *bmap,
1103 void *user)
1105 narrowing_data *data = (narrowing_data *)user;
1107 eq_node *child;
1108 child = new eq_node(data->node->comp[0], data->node->comp[1],
1109 isl_map_from_basic_map(bmap));
1110 child->narrowing = 1;
1111 data->ec->narrowings++;
1112 data->ec->handle(child);
1113 data->new_got = isl_map_union(data->new_got, child->get_got());
1115 std::set<eq_node *>::iterator i;
1116 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1117 data->node->assumed.insert(*i);
1118 data->ec->dismiss(child);
1120 return isl_stat_ok;
1123 /* Construct and handle narrowing nodes for the given node.
1125 * If the node itself was already a narrowing node, then we
1126 * simply return the empty relation.
1127 * Otherwise, we consider each basic relation in the obtained
1128 * relation, construct a new node with that basic relation as
1129 * requested relation and take the union of all obtained relations.
1131 void equivalence_checker::handle_narrowing(eq_node *node)
1133 node->reset = 1;
1134 node->assumed.clear();
1136 isl_map *want = node->want;
1137 isl_map *new_got = isl_map_empty(isl_map_get_space(want));
1138 if (!options->narrowing || node->narrowing) {
1139 node->set_got(new_got);
1140 return;
1143 narrowing_data data = { node, this, new_got };
1144 isl_map_foreach_basic_map(node->peek_got(), &basic_handle_narrowing,
1145 &data);
1146 node->set_got(data.new_got);
1149 static __isl_give isl_union_set *update_result(__isl_take isl_union_set *res,
1150 const char *array_name, isl_map *map, int first, int n)
1152 if (isl_map_is_empty(map)) {
1153 isl_map_free(map);
1154 return res;
1157 isl_set *range = isl_map_range(map);
1158 range = isl_set_remove_dims(range, isl_dim_set, first, n);
1159 range = isl_set_coalesce(range);
1160 range = isl_set_set_tuple_name(range, array_name);
1161 res = isl_union_set_add_set(res, range);
1163 return res;
1166 struct check_equivalence_data {
1167 dependence_graph *dg1;
1168 dependence_graph *dg2;
1169 equivalence_checker *ec;
1170 isl_map *got;
1171 isl_map *lost;
1174 static isl_stat basic_check(__isl_take isl_basic_map *bmap, void *user)
1176 check_equivalence_data *data = (check_equivalence_data *)user;
1178 eq_node *root = new eq_node(data->dg1->out, data->dg2->out,
1179 isl_map_from_basic_map(bmap));
1180 data->ec->handle(root);
1181 data->got = isl_map_union_disjoint(data->got, root->get_got());
1182 data->lost = isl_map_union_disjoint(data->lost, root->get_lost());
1183 data->ec->dismiss(root);
1185 return isl_stat_ok;
1188 static int check_equivalence_array(isl_ctx *ctx, equivalence_checker *ec,
1189 dependence_graph *dg1, dependence_graph *dg2, int array1, int array2,
1190 isl_union_set **proved, isl_union_set **not_proved)
1192 const char *array_name = dg1->output_arrays[array1];
1193 isl_set *out1 = isl_set_copy(dg1->out->domain);
1194 isl_set *out2 = isl_set_copy(dg2->out->domain);
1195 unsigned dim1 = isl_set_n_dim(out1);
1196 unsigned dim2 = isl_set_n_dim(out2);
1197 unsigned array_dim = dg1->output_array_dims[array1];
1198 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1199 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1200 out1 = isl_set_remove_dims(out1, isl_dim_set, dim1 - 1, 1);
1201 out2 = isl_set_remove_dims(out2, isl_dim_set, dim2 - 1, 1);
1202 int equal = isl_set_is_equal(out1, out2);
1203 isl_set_free(out1);
1204 isl_set_free(out2);
1205 assert(equal >= 0);
1206 if (!equal) {
1207 fprintf(stderr, "different output domains for array %s\n",
1208 array_name);
1209 return -1;
1212 out1 = isl_set_copy(dg1->out->domain);
1213 out2 = isl_set_copy(dg2->out->domain);
1214 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1215 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1216 out1 = isl_set_coalesce(out1);
1217 out2 = isl_set_coalesce(out2);
1218 isl_space *dim = isl_space_map_from_set(isl_set_get_space(out2));
1219 isl_map *id = isl_map_identity(dim);
1220 id = isl_map_remove_dims(id, isl_dim_in, isl_map_n_in(id) - 1, 1);
1221 id = isl_map_apply_domain(id, isl_map_copy(id));
1222 id = isl_map_intersect_domain(id, out1);
1223 id = isl_map_intersect_range(id, out2);
1225 isl_map *got = isl_map_empty(isl_map_get_space(id));
1226 isl_map *lost = isl_map_copy(got);
1227 check_equivalence_data data = { dg1, dg2, ec, got, lost };
1228 isl_map_foreach_basic_map(id, &basic_check, &data);
1229 isl_map_free(id);
1230 *proved = update_result(*proved,
1231 array_name, data.got, array_dim, dim2 - array_dim);
1232 *not_proved = update_result(*not_proved,
1233 array_name, data.lost, array_dim, dim2 - array_dim);
1234 return 0;
1237 /* The input arrays of the two programs are supposed to be the same,
1238 * so they should at least have the same dimension. Make sure
1239 * this is true, because we depend on it later on.
1241 static int check_input_arrays(dependence_graph *dg1, dependence_graph *dg2)
1243 for (int i = 0; i < dg1->input_computations.size(); ++i)
1244 for (int j = 0; j < dg2->input_computations.size(); ++j) {
1245 if (strcmp(dg1->input_computations[i]->operation,
1246 dg2->input_computations[j]->operation))
1247 continue;
1248 if (dg1->input_computations[i]->dim ==
1249 dg2->input_computations[j]->dim)
1250 continue;
1251 fprintf(stderr,
1252 "input arrays \"%s\" do not have the same dimension\n",
1253 dg1->input_computations[i]->operation);
1254 return -1;
1257 return 0;
1260 static __isl_give isl_union_set *add_array(__isl_take isl_union_set *only,
1261 dependence_graph *dg, int array)
1263 isl_space *dim;
1264 isl_set *set;
1266 dim = isl_union_set_get_space(only);
1267 dim = isl_space_add_dims(dim, isl_dim_set, dg->output_array_dims[array]);
1268 dim = isl_space_set_tuple_name(dim, isl_dim_set, dg->output_arrays[array]);
1269 set = isl_set_universe(dim);
1270 only = isl_union_set_add_set(only, set);
1272 return only;
1275 static void print_results(const char *str, __isl_keep isl_union_set *only)
1277 isl_printer *prn;
1279 if (isl_union_set_is_empty(only))
1280 return;
1282 fprintf(stdout, "%s: '", str);
1283 prn = isl_printer_to_file(isl_union_set_get_ctx(only), stdout);
1284 prn = isl_printer_print_union_set(prn, only);
1285 isl_printer_free(prn);
1286 fprintf(stdout, "'\n");
1289 static int check_equivalence(isl_ctx *ctx,
1290 dependence_graph *dg1, dependence_graph *dg2, options *options)
1292 isl_space *dim;
1293 isl_set *context;
1294 isl_union_set *only1, *only2, *proved, *not_proved;
1295 dg1->flatten_associative_operators(options->ops->associative);
1296 dg2->flatten_associative_operators(options->ops->associative);
1297 equivalence_checker ec(options);
1298 int i1 = 0, i2 = 0;
1300 if (check_input_arrays(dg1, dg2))
1301 return -1;
1303 dim = isl_space_set_alloc(ctx, 0, 0);
1304 proved = isl_union_set_empty(isl_space_copy(dim));
1305 not_proved = isl_union_set_empty(isl_space_copy(dim));
1306 only1 = isl_union_set_empty(isl_space_copy(dim));
1307 only2 = isl_union_set_empty(dim);
1309 while (i1 < dg1->output_arrays.size() || i2 < dg2->output_arrays.size()) {
1310 int cmp;
1311 cmp = i1 == dg1->output_arrays.size() ? 1 :
1312 i2 == dg2->output_arrays.size() ? -1 :
1313 strcmp(dg1->output_arrays[i1], dg2->output_arrays[i2]);
1314 if (cmp < 0) {
1315 only1 = add_array(only1, dg1, i1);
1316 ++i1;
1317 } else if (cmp > 0) {
1318 only2 = add_array(only2, dg2, i2);
1319 ++i2;
1320 } else {
1321 check_equivalence_array(ctx, &ec, dg1, dg2, i1, i2,
1322 &proved, &not_proved);
1323 ++i1;
1324 ++i2;
1328 context = isl_set_union(isl_set_copy(dg1->context),
1329 isl_set_copy(dg2->context));
1330 proved = isl_union_set_gist_params(proved, isl_set_copy(context));
1331 not_proved = isl_union_set_gist_params(not_proved, context);
1333 print_results("Equivalence proved", proved);
1334 print_results("Equivalence NOT proved", not_proved);
1335 print_results("Only in program 1", only1);
1336 print_results("Only in program 2", only2);
1338 isl_union_set_free(proved);
1339 isl_union_set_free(not_proved);
1340 isl_union_set_free(only1);
1341 isl_union_set_free(only2);
1343 if (options->print_stats) {
1344 fprintf(stderr, "widenings: %d\n", ec.widenings);
1345 if (options->narrowing)
1346 fprintf(stderr, "narrowings: %d\n", ec.narrowings);
1349 return 0;
1352 static void dump_vertex(FILE *out, computation *comp)
1354 fprintf(out, "ND_%p [label = \"%d,%s/%d\"];\n",
1355 comp, comp->location, comp->operation, comp->arity);
1356 for (int i = 0; i < comp->edges.size(); ++i)
1357 fprintf(out, "ND_%p -> ND_%p%s;\n",
1358 comp, comp->edges[i]->source,
1359 comp->edges[i]->type == edge::expansion ?
1360 " [color=\"blue\"]" : "");
1363 static void dump_graph(FILE *out, dependence_graph *dg)
1365 fprintf(out, "digraph dummy {\n");
1366 dump_vertex(out, dg->out);
1367 for (int i = 0; i < dg->vertices.size(); ++i)
1368 dump_vertex(out, dg->vertices[i]);
1369 fprintf(out, "}\n");
1372 static void dump_graphs(dependence_graph **dg, struct options *options)
1374 int i;
1375 char path[PATH_MAX];
1377 if (!options->dump_graphs)
1378 return;
1380 for (i = 0; i < 2; ++i) {
1381 FILE *out;
1382 int s;
1383 s = snprintf(path, sizeof(path), "%s.dot", options->program[i]);
1384 assert(s < sizeof(path));
1385 out = fopen(path, "w");
1386 assert(out);
1387 dump_graph(out, dg[i]);
1388 fclose(out);
1392 void parse_ops(struct options *options) {
1393 char *tok;
1395 if (options->associative) {
1396 tok = strtok(options->associative, ",");
1397 options->ops->associative.push_back(strdup(tok));
1398 while ((tok = strtok(NULL, ",")) != NULL)
1399 options->ops->associative.push_back(strdup(tok));
1402 if (options->commutative) {
1403 tok = strtok(options->commutative, ",");
1404 options->ops->commutative.push_back(strdup(tok));
1405 while ((tok = strtok(NULL, ",")) != NULL)
1406 options->ops->commutative.push_back(strdup(tok));
1410 int main(int argc, char *argv[])
1412 struct options *options = options_new_with_defaults();
1413 struct isl_ctx *ctx;
1414 isl_set *context = NULL;
1415 dependence_graph *dg[2];
1417 argc = options_parse(options, argc, argv, ISL_ARG_ALL);
1418 parse_ops(options);
1420 ctx = isl_ctx_alloc_with_options(&options_args, options);
1421 if (!ctx) {
1422 fprintf(stderr, "Unable to allocate ctx\n");
1423 return -1;
1426 if (options->context)
1427 context = isl_set_read_from_str(ctx, options->context);
1429 pdg::PDG *pdg[2];
1430 unsigned out_dim = 0;
1431 for (int i = 0; i < 2; ++i) {
1432 FILE *in;
1433 in = fopen(options->program[i], "r");
1434 if (!in) {
1435 fprintf(stderr, "Unable to open %s\n", options->program[i]);
1436 return -1;
1438 pdg[i] = yaml::Load<pdg::PDG>(in, ctx);
1439 fclose(in);
1440 if (!pdg[i]) {
1441 fprintf(stderr, "Unable to read %s\n", options->program[i]);
1442 return -1;
1444 out_dim = update_out_dim(pdg[i], out_dim);
1446 if (context &&
1447 pdg[i]->params.size() != isl_set_dim(context, isl_dim_param)) {
1448 fprintf(stdout,
1449 "Parameter dimension mismatch; context ignored\n");
1450 isl_set_free(context);
1451 context = NULL;
1455 for (int i = 0; i < 2; ++i) {
1456 dg[i] = pdg_to_dg(pdg[i], out_dim, isl_set_copy(context));
1457 pdg[i]->free();
1458 delete pdg[i];
1461 dump_graphs(dg, options);
1463 int res = check_equivalence(ctx, dg[0], dg[1], options);
1465 for (int i = 0; i < 2; ++i)
1466 delete dg[i];
1467 isl_set_free(context);
1468 isl_ctx_free(ctx);
1470 return res;