c2pdg.cc: extract_node: extract statement name from domain instead of schedule
[ppn.git] / eqv.cc
blob4fdfa095b0faebcbf25d49c7dd879ada250dcdf4
1 #include <limits.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <algorithm>
6 #include <map>
7 #include <set>
8 #include <vector>
9 #include <string>
10 #include <iostream>
11 #include <sstream>
13 #include <isa/yaml.h>
14 #include <isa/pdg.h>
16 #include <isl/set.h>
17 #include <isl/map.h>
18 #include <isl/constraint.h>
19 #include <isl/union_set.h>
20 #include "eqv_options.h"
21 #include "dependence_graph.h"
/* Properties of operations as specified by the user.
 * Both lists hold operation names that are compared (with strcmp)
 * against the names of the operations performed by computations.
 */
struct ops {
	/* names of associative operations */
	std::vector<const char *> associative;
	/* names of commutative operations */
	std::vector<const char *> commutative;
};
27 int ops_init(void *user)
29 struct ops **ops = (struct ops **)user;
30 *ops = new struct ops;
31 return 0;
/* Delete the ops structure referenced by the "struct ops *" that
 * "user" points to.  The pointer itself is not reset.
 */
void ops_clear(void *user)
{
	struct ops *target = *(struct ops **) user;

	delete target;
}
39 /* For each basic map in relation, add an edge to comp with
40 * given source and pos.
42 static void add_split_edge(computation *comp, computation *source,
43 int pos, isl_map *relation, std::vector<computation **> *missing,
44 enum edge::type type = edge::normal)
46 edge *e = new edge;
47 e->source = source;
48 if (!source) {
49 assert(missing);
50 missing->push_back(&e->source);
52 e->pos = pos;
53 e->relation = relation;
54 e->type = type;
55 comp->edges.push_back(e);
58 static unsigned update_out_dim(pdg::PDG *pdg, unsigned out_dim)
60 for (int i = 0; i < pdg->arrays.size(); ++i) {
61 pdg::array *array = pdg->arrays[i];
62 if (array->type != pdg::array::output)
63 continue;
64 if (array->dims.size() + 1 > out_dim)
65 out_dim = array->dims.size() + 1;
68 return out_dim;
/* Is "op" (compared by string contents) one of the names
 * in "associative"?
 */
static bool is_associative(const std::vector<const char *> &associative,
	const char *op)
{
	for (size_t k = 0; k != associative.size(); ++k) {
		if (strcmp(associative[k], op) == 0)
			return true;
	}
	return false;
}
82 /* Return a computation that has an associative operation
83 * that takes a different computation with the same operation
84 * as one of its arguments. Also return the edge from the
85 * first to the second computation in *e.
87 * If no such computation exists, then return NULL.
89 computation *dependence_graph::associative_node(edge **e,
90 const std::vector<const char *> &associative)
92 for (int i = 0; i < vertices.size(); ++i) {
93 computation *comp = vertices[i];
94 if (!is_associative(associative, comp->operation))
95 continue;
96 for (int j = 0; j < comp->edges.size(); ++j) {
97 computation *other = comp->edges[j]->source;
98 if (comp->has_same_source(other))
99 continue;
100 if (strcmp(comp->operation, other->operation))
101 continue;
102 *e = comp->edges[j];
103 return comp;
106 return NULL;
109 /* Splice the source of edge edge into the position of edge e,
110 * removing all other edges with the same positions and bumping
111 * up edges with a greater position.
113 void dependence_graph::splice(computation *comp, edge *e)
115 computation *source = e->source;
116 int pos = e->pos;
117 isl_map *map = isl_map_copy(e->relation);
118 std::vector<struct edge *> new_edges;
120 for (int i = 0; i < comp->edges.size(); ++i) {
121 if (comp->edges[i]->pos == pos) {
122 delete comp->edges[i];
123 continue;
125 if (comp->edges[i]->pos > pos)
126 comp->edges[i]->pos += source->arity - 1;
127 new_edges.push_back(comp->edges[i]);
129 for (int i = 0; i < source->edges.size(); ++i) {
130 edge *old_e = source->edges[i];
131 edge *e = new edge;
132 e->source = old_e->source;
133 e->pos = old_e->pos + pos;
134 e->relation = isl_map_apply_range(
135 isl_map_copy(map),
136 isl_map_copy(old_e->relation));
137 new_edges.push_back(e);
139 isl_map_free(map);
140 comp->edges = new_edges;
141 comp->arity += source->arity - 1;
144 /* Split computation comp into a part that always takes the edge e
145 * and one that never takes that edge.
146 * The second is returned and the initial computation is modified
147 * to match the first.
149 * We need to be careful not to destroy e, as it is still used
150 * by the calling method.
152 computation *dependence_graph::split_comp(computation *comp, edge *e)
154 computation *dup = new computation;
155 dup->original = comp->original ? comp->original : comp;
156 dup->operation = strdup(comp->operation);
157 dup->arity = comp->arity;
158 dup->location = comp->location;
160 isl_map *map = isl_map_copy(e->relation);
161 map = isl_map_reverse(map);
162 isl_set *dom = isl_set_apply(isl_set_copy(e->source->domain), map);
163 dup->domain = isl_set_subtract(isl_set_copy(comp->domain),
164 isl_set_copy(dom));
165 comp->domain = isl_set_intersect(comp->domain, dom);
167 std::vector<struct edge *> old_edges = comp->edges;
168 comp->edges.clear();
170 for (int i = 0; i < old_edges.size(); ++i) {
171 if (old_edges[i] == e) {
172 comp->edges.push_back(e);
173 continue;
176 edge *e = old_edges[i];
177 isl_map *map, *map_dup;
178 map = isl_map_copy(e->relation);
179 map_dup = isl_map_copy(map);
181 map = isl_map_intersect_domain(map, isl_set_copy(comp->domain));
182 map_dup = isl_map_intersect_domain(map_dup,
183 isl_set_copy(dup->domain));
184 add_split_edge(comp, e->source, e->pos, map, NULL);
185 add_split_edge(dup, e->source, e->pos, map_dup, NULL);
186 delete e;
189 vertices.push_back(dup);
190 return dup;
193 /* If any edge from comp points to comp_orig, then split it
194 * into two edges, one still pointing to comp_orig and the
195 * other pointing to comp_dup.
196 * comp_orig and comp_dup are assumed to have disjoint domains
197 * and the edge relations are adjusted according to these domains.
199 void dependence_graph::split_edges(computation *comp,
200 computation *comp_orig, computation *comp_dup)
202 std::vector<struct edge *> old_edges = comp->edges;
203 comp->edges.clear();
205 for (int i = 0; i < old_edges.size(); ++i) {
206 edge *e = old_edges[i];
208 if (e->source != comp_orig) {
209 comp->edges.push_back(e);
210 continue;
213 isl_map *map_orig, *map_dup;
214 map_orig = isl_map_copy(e->relation);
215 map_dup = isl_map_copy(map_orig);
216 map_orig = isl_map_intersect_range(map_orig,
217 isl_set_copy(comp_orig->domain));
218 map_dup = isl_map_intersect_range(map_dup,
219 isl_set_copy(comp_dup->domain));
220 add_split_edge(comp, comp_orig, e->pos, map_orig, NULL);
221 add_split_edge(comp, comp_dup, e->pos, map_dup, NULL);
222 delete e;
226 void dependence_graph::split_edges(computation *comp_orig,
227 computation *comp_dup)
229 split_edges(out, comp_orig, comp_dup);
230 for (int i = 0; i < vertices.size(); ++i)
231 split_edges(vertices[i], comp_orig, comp_dup);
234 /* Replace all nested calls of an associative operator,
235 * by a call with the nested call spliced into the first call.
237 * For each nested call we find, we first check if there are
238 * any other edges with the same position. If not, then
239 * the nested call is performed for each iteration of the computation
240 * and we can simplify splice the nested call.
242 * Otherwise, we first create a duplicate computation for the iterations
243 * that do not take the found edge and adjust the edges of both computations
244 * to their domains. This meand that the edge corresponding to the
245 * nested call will no longer appear in the duplicated computation
246 * and the other edges with the same position will no longer appear
247 * in the original computation.
248 * Then we splice the nested call in the original computation.
249 * Finally, we split all edges that pointed to the original computation
250 * into two edges, one going to the original computation and one
251 * going to the duplicated computation.
253 void dependence_graph::flatten_associative_operators(
254 const std::vector<const char *> &associative)
256 edge *e;
257 computation *comp;
259 while ((comp = associative_node(&e, associative)) != NULL) {
260 computation *comp_dup = NULL;
261 int j;
262 for (j = 0; j < comp->edges.size(); ++j) {
263 if (comp->edges[j] == e)
264 continue;
265 if (comp->edges[j]->pos == e->pos)
266 break;
268 if (j != comp->edges.size())
269 comp_dup = split_comp(comp, e);
270 splice(comp, e);
271 if (comp_dup)
272 split_edges(comp, comp_dup);
276 struct eq_node {
277 private:
278 isl_map *got;
279 isl_map *lost;
280 public:
281 computation *comp[2];
282 isl_map *want;
283 isl_map *need;
284 std::set<eq_node *> assumed;
285 unsigned closed : 1;
286 unsigned narrowing : 1;
287 unsigned widening : 1;
288 unsigned invalidated : 1;
289 unsigned reset : 1;
290 std::string trace;
291 isl_map *lost_sample;
293 eq_node(computation *c1, computation *c2,
294 isl_map *w) : want(w), closed(0),
295 widening(0), narrowing(0), invalidated(0), reset(0),
296 need(NULL), got(NULL), lost(NULL), lost_sample(NULL) {
297 comp[0] = c1;
298 comp[1] = c2;
300 bool is_still_valid();
301 void collect_open_assumed(std::set<eq_node *> &c);
302 ~eq_node() {
303 isl_map_free(want);
304 isl_map_free(need);
305 isl_map_free(got);
306 isl_map_free(lost);
307 isl_map_free(lost_sample);
309 void compute_got() {
310 assert(lost);
311 got = isl_map_copy(want);
312 got = isl_map_subtract(got, isl_map_copy(lost));
314 isl_map *get_got() {
315 if (!got)
316 compute_got();
317 return isl_map_copy(got);
319 isl_map *peek_got() {
320 if (!got)
321 compute_got();
322 return got;
324 void compute_lost() {
325 assert(got);
326 lost = isl_map_copy(want);
327 lost = isl_map_subtract(lost, isl_map_copy(got));
329 isl_map *get_lost() {
330 if (!lost)
331 compute_lost();
332 return isl_map_copy(lost);
334 isl_map *peek_lost() {
335 if (!lost)
336 compute_lost();
337 return lost;
339 void set_got(isl_map *got) {
340 isl_map_free(this->lost);
341 isl_map_free(this->got);
342 this->lost = NULL;
343 this->got = got;
345 void set_lost(isl_map *lost) {
346 isl_map_free(this->lost);
347 isl_map_free(this->got);
348 this->got = NULL;
349 this->lost = lost;
351 std::string to_string();
354 std::string eq_node::to_string()
356 std::ostringstream strm;
357 strm << comp[0]->location << "," << comp[0]->operation
358 << "/" << comp[0]->arity;
359 strm << " <-> ";
360 strm << comp[1]->location << "," << comp[1]->operation
361 << "/" << comp[1]->arity;
362 strm << std::endl;
363 return strm.str();
366 bool eq_node::is_still_valid()
368 if (invalidated)
369 return 0;
371 std::set<eq_node *>::iterator i;
372 for (i = assumed.begin(); i != assumed.end(); ++i) {
373 assert(*i != this);
374 if ((*i)->reset ||
375 ((*i)->closed && !(*i)->is_still_valid())) {
376 invalidated = 1;
377 return 0;
380 return 1;
383 void eq_node::collect_open_assumed(std::set<eq_node *> &c)
385 std::set<eq_node *>::iterator i;
386 for (i = assumed.begin(); i != assumed.end(); ++i) {
387 if ((*i)->closed)
388 (*i)->collect_open_assumed(c);
389 else
390 c.insert(*i);
394 /* A comp_pair contains all the edges that have the same pair
395 * of computations.
397 struct comp_pair {
398 computation *comp[2];
400 std::vector<eq_node *> nodes;
402 eq_node *tabled(eq_node *node);
403 eq_node *last_ancestor(eq_node *node);
404 ~comp_pair();
407 comp_pair::~comp_pair()
409 std::vector<eq_node *>::iterator i;
410 for (i = nodes.begin(); i != nodes.end(); ++i)
411 delete *i;
414 eq_node *comp_pair::tabled(eq_node *node)
416 std::vector<eq_node *>::iterator i;
418 for (i = nodes.begin(); i != nodes.end(); ++i) {
419 if (*i == node)
420 continue;
421 if (!(*i)->closed)
422 continue;
423 if (!(*i)->is_still_valid())
424 continue;
425 int is_subset;
426 is_subset = isl_map_is_subset(node->want, (*i)->want);
427 assert(is_subset >= 0);
428 if (!is_subset)
429 continue;
430 return *i;
432 return NULL;
435 eq_node *comp_pair::last_ancestor(eq_node *node)
437 std::vector<eq_node *>::reverse_iterator i;
438 for (i = nodes.rbegin(); i != nodes.rend(); ++i) {
439 if (*i == node)
440 continue;
441 if ((*i)->closed)
442 continue;
443 return *i;
445 return NULL;
448 typedef std::pair<computation *, computation *> computation_pair;
449 typedef std::map<computation_pair, comp_pair *> c2p_t;
451 struct equivalence_checker {
452 c2p_t c2p;
453 struct options *options;
454 int widenings;
455 int narrowings;
457 equivalence_checker(struct options *o) :
458 options(o), widenings(0), narrowings(0) {}
459 comp_pair *get_comp_pair(eq_node *node);
460 void handle(eq_node *node);
461 void dismiss(eq_node *node);
462 void handle_propagation(eq_node *node);
463 void handle_copy(eq_node *node, int i);
464 void handle_widening(eq_node *node, eq_node *a);
465 void handle_with_ancestor(eq_node *node, eq_node *a);
466 isl_map *lost_at_pos(eq_node *node, int pos1, int pos2);
467 void handle_propagation_same(eq_node *node);
468 void handle_propagation_comm(eq_node *node);
469 void handle_narrowing(eq_node *node);
471 isl_map *lost_from_propagation(eq_node *parent,
472 eq_node *node, edge *e1, edge *e2);
474 void init_trace(eq_node *node);
475 void extend_trace_propagation(eq_node *node, eq_node *child,
476 edge *e1, edge *e2);
477 void extend_trace_widening(eq_node *node, eq_node *child);
479 bool is_commutative(const char *op);
481 ~equivalence_checker();
484 equivalence_checker::~equivalence_checker()
486 c2p_t::iterator i;
488 for (i = c2p.begin(); i != c2p.end(); ++i)
489 delete (*i).second;
492 bool equivalence_checker::is_commutative(const char *op)
494 std::vector<const char *>::iterator iter;
496 for (iter = options->ops->commutative.begin();
497 iter != options->ops->commutative.end(); ++iter)
498 if (!strcmp(*iter, op))
499 return 1;
500 return 0;
503 comp_pair *equivalence_checker::get_comp_pair(eq_node *node)
505 c2p_t::iterator i;
506 comp_pair *cp;
508 i = c2p.find(computation_pair(node->comp[0], node->comp[1]));
509 if (i == c2p.end()) {
510 cp = new comp_pair;
511 c2p[computation_pair(node->comp[0], node->comp[1])] = cp;
512 } else
513 cp = (*i).second;
514 return cp;
517 void equivalence_checker::dismiss(eq_node *node)
519 /* unclosed nodes weren't added to c2p->nodes */
520 if (node && !node->closed)
521 delete node;
524 void equivalence_checker::init_trace(eq_node *node)
526 isl_ctx *ctx;
527 isl_printer *prn;
529 if (!options->trace_error)
530 return;
531 if (isl_map_is_empty(node->peek_lost()))
532 return;
533 node->trace = node->to_string();
534 std::cerr << node->trace;
535 isl_map_free(node->lost_sample);
536 node->lost_sample = isl_map_from_basic_map(
537 isl_map_sample(node->get_lost()));
538 ctx = isl_map_get_ctx(node->lost_sample);
539 prn = isl_printer_to_file(ctx, stderr);
540 prn = isl_printer_print_map(prn, node->lost_sample);
541 prn = isl_printer_end_line(prn);
542 isl_printer_free(prn);
545 void equivalence_checker::extend_trace_propagation(eq_node *node,
546 eq_node *child, edge *e1, edge *e2)
548 isl_ctx *ctx;
549 isl_printer *prn;
551 if (!options->trace_error)
552 return;
553 if (!child->lost_sample)
554 return;
555 if (node->lost_sample)
556 return;
557 node->trace = child->trace + node->to_string();
558 std::cerr << node->trace;
559 node->lost_sample = isl_map_copy(child->lost_sample);
560 if (e1)
561 node->lost_sample = isl_map_apply_domain(node->lost_sample,
562 isl_map_reverse(
563 isl_map_copy(e1->relation)));
564 if (e2)
565 node->lost_sample = isl_map_apply_range(node->lost_sample,
566 isl_map_reverse(
567 isl_map_copy(e2->relation)));
568 node->lost_sample = isl_map_intersect(node->lost_sample,
569 isl_map_copy(node->want));
570 node->lost_sample = isl_map_from_basic_map(
571 isl_map_sample(node->lost_sample));
572 ctx = isl_map_get_ctx(node->lost_sample);
573 prn = isl_printer_to_file(ctx, stderr);
574 prn = isl_printer_print_map(prn, node->lost_sample);
575 prn = isl_printer_end_line(prn);
576 isl_printer_free(prn);
579 void equivalence_checker::extend_trace_widening(eq_node *node, eq_node *child)
581 isl_ctx *ctx;
582 isl_printer *prn;
584 if (!options->trace_error)
585 return;
586 if (!child->lost_sample)
587 return;
588 if (isl_map_is_empty(node->peek_lost()))
589 return;
590 node->trace = child->trace + node->to_string();
591 std::cerr << node->trace;
592 node->lost_sample = isl_map_from_basic_map(
593 isl_map_sample(node->get_lost()));
594 ctx = isl_map_get_ctx(node->lost_sample);
595 prn = isl_printer_to_file(ctx, stderr);
596 prn = isl_printer_print_map(prn, node->lost_sample);
597 prn = isl_printer_end_line(prn);
598 isl_printer_free(prn);
601 /* Check for which subset (got) of the want relation equivelance
602 * holds for the pair of computations in the equivalence node.
604 * We first handle the easy cases: empty want, input computations
605 * and tabled nodes.
607 * If the current node is not a narrowing or a widening node
608 * and we can find an ancestor with the same pair of computations,
609 * then we will try to apply induction or widening in handle_with_ancestor.
610 * However, if the requested relation (want) of the ancestor is a strict
611 * subset of that of the current node, then we have already applied
612 * widening on an intermediate node (with a different pair of computations)
613 * so we shouldn't apply widening again (and we can't apply induction
614 * because the relation of the ancestor is a strict subset).
616 * In all other cases we try to apply propagation in handle_propagation.
618 void equivalence_checker::handle(eq_node *node)
620 computation *comp1 = node->comp[0];
621 computation *comp2 = node->comp[1];
623 if (!comp1->is_copy() && !comp2->is_copy() &&
624 (strcmp(comp1->operation, comp2->operation) ||
625 comp1->arity != comp2->arity)) {
626 node->set_lost(isl_map_copy(node->want));
627 init_trace(node);
628 return;
631 eq_node *s = NULL;
632 eq_node *a = NULL;
634 isl_map *want = node->want;
635 node->need = isl_map_empty_like(want);
636 int empty = isl_map_is_empty(want);
637 assert(empty >= 0);
638 if (empty) {
639 node->set_lost(isl_map_empty_like(want));
640 return;
642 if (node->comp[0]->is_input() && node->comp[1]->is_input()) {
643 if (strcmp(node->comp[0]->operation, node->comp[1]->operation)) {
644 node->set_lost(isl_map_copy(want));
645 return;
647 node->set_lost(isl_map_subtract(
648 isl_map_copy(node->want), isl_map_identity_like(want)));
649 return;
652 comp_pair *cp = get_comp_pair(node);
653 if ((s = cp->tabled(node)) != NULL) {
654 node->set_lost(isl_map_intersect(s->get_lost(),
655 isl_map_copy(node->want)));
656 s->collect_open_assumed(node->assumed);
657 return;
660 cp->nodes.push_back(node);
662 if (!node->narrowing && !node->widening &&
663 (a = cp->last_ancestor(node)) != NULL) {
664 int is_subset;
665 is_subset = isl_map_is_strict_subset(a->want, node->want);
666 assert(is_subset >= 0);
667 if (is_subset)
668 handle_propagation(node);
669 else
670 handle_with_ancestor(node, a);
671 } else
672 handle_propagation(node);
673 node->closed = 1;
676 /* Check if we can apply propagation to prove equivalence of the given node.
677 * First check if we need to apply copy propagation and if not
678 * check if the operations are the same and apply "regular" propagation.
680 * After we get back, we need to check that any induction hypotheses
681 * we have used in the process of proving the node hold.
682 * If not, we replace the obtained relation (got) by that
683 * of a narrowing node in handle_narrowin.
685 void equivalence_checker::handle_propagation(eq_node *node)
687 if (node->comp[0]->is_copy())
688 handle_copy(node, 0);
689 else if (node->comp[1]->is_copy())
690 handle_copy(node, 1);
691 else
692 handle_propagation_same(node);
693 node->assumed.erase(node);
694 isl_map *lost = node->get_lost();
695 lost = isl_map_intersect(lost, isl_map_copy(node->need));
696 int is_empty = isl_map_is_empty(lost);
697 isl_map_free(lost);
698 assert(is_empty >= 0);
699 if (!is_empty)
700 handle_narrowing(node);
703 /* When applying the mappings on expansion edges to both sides of
704 * an equivalence relation, each element in the original equivalence
705 * relation is mapped to many elements on both sides. For example,
706 * if the original equivalence relation has i R i' for i = i', then the new
707 * equivalence relation may have (i,j) R (i',j') for i = i' and
708 * 0 <= j,j' <= 10, expressing that all elements of the row read by
709 * iteration i should be equal to all elements of the row read by i'.
710 * Instead, we want to express that each individual element read by i'
711 * should be equal to the corresponding element read by i'.
712 * In the example, we need to introduce the equality j = j'.
714 * We first check that the number of dimensions added by the expansions
715 * is the same in both programs. Then we construct an identity relation
716 * between this number of dimensions, lift it to the space of the
717 * new equivalence relation and take the intersection.
719 * We have to be careful, though, that we don't loose any array elements
720 * by taking the intersection. In particular, the expansion maps
721 * a given read operation to iterations of an extended domain that
722 * reads all the individual array elements. The read operation is
723 * only equivalent to some other read operation if all the reads
724 * of the individual array elements are equivalent.
725 * We therefore need to make sure that adding the equalities does
726 * not remove any of the reads in the extended domain.
727 * We do this by projecting out the extra dimensions on one side
728 * from both "want" and "want \cap eq". The resulting maps should
729 * be the same. We do this check for both sides separately.
731 * If either of the above tests fails, then we simply return
732 * the original over-ambitious want.
734 static isl_map *expansion_want(isl_map *want, edge *e1, edge *e2)
736 unsigned s_dim_1 = isl_map_n_out(e1->relation) -
737 isl_map_n_in(e1->relation);
738 unsigned s_dim_2 = isl_map_n_out(e2->relation) -
739 isl_map_n_in(e2->relation);
740 if (s_dim_1 != s_dim_2)
741 return want;
743 isl_space *dim = isl_map_get_space(e1->relation);
744 dim = isl_space_drop_dims(dim, isl_dim_in,
745 0, isl_space_dim(dim, isl_dim_in));
746 dim = isl_space_drop_dims(dim, isl_dim_out,
747 0, isl_space_dim(dim, isl_dim_out));
748 dim = isl_space_add_dims(dim, isl_dim_in, s_dim_1);
749 dim = isl_space_add_dims(dim, isl_dim_out, s_dim_1);
750 isl_basic_map *id = isl_basic_map_identity(dim);
752 dim = isl_space_range(isl_map_get_space(e1->relation));
753 isl_basic_map *s_1 = isl_basic_map_identity(isl_space_map_from_set(dim));
754 s_1 = isl_basic_map_remove_dims(s_1, isl_dim_in, 0,
755 isl_basic_map_n_in(s_1) - s_dim_1);
756 id = isl_basic_map_apply_domain(id, s_1);
758 dim = isl_space_range(isl_map_get_space(e2->relation));
759 isl_basic_map *s_2 = isl_basic_map_identity(isl_space_map_from_set(dim));
760 s_2 = isl_basic_map_remove_dims(s_2, isl_dim_in, 0,
761 isl_basic_map_n_in(s_2) - s_dim_2);
762 id = isl_basic_map_apply_range(id, s_2);
764 bool unmatched = false;
765 isl_map *matched_want;
766 isl_map *proj_want, *proj_matched;
767 matched_want = isl_map_intersect(isl_map_copy(want),
768 isl_map_from_basic_map(id));
770 proj_want = isl_map_remove_dims(isl_map_copy(want),
771 isl_dim_in, isl_map_n_in(want) - s_dim_1, s_dim_1);
772 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
773 isl_dim_in, isl_map_n_in(want) - s_dim_1, s_dim_1);
774 if (!isl_map_is_equal(proj_want, proj_matched))
775 unmatched = true;
776 isl_map_free(proj_want);
777 isl_map_free(proj_matched);
779 proj_want = isl_map_remove_dims(isl_map_copy(want),
780 isl_dim_out, isl_map_n_out(want) - s_dim_2, s_dim_2);
781 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
782 isl_dim_out, isl_map_n_out(want) - s_dim_2, s_dim_2);
783 if (!isl_map_is_equal(proj_want, proj_matched))
784 unmatched = true;
785 isl_map_free(proj_want);
786 isl_map_free(proj_matched);
788 if (unmatched) {
789 isl_map_free(matched_want);
790 return matched_want;
792 isl_map_free(want);
793 return matched_want;
796 static isl_map *propagation_want(eq_node *node, edge *e1, edge *e2)
798 isl_map *new_want;
800 new_want = isl_map_copy(node->want);
801 if (e1)
802 new_want = isl_map_apply_domain(new_want,
803 isl_map_copy(e1->relation));
804 if (e2)
805 new_want = isl_map_apply_range(new_want,
806 isl_map_copy(e2->relation));
808 if (e1 && e2 && e1->type == edge::expansion)
809 new_want = expansion_want(new_want, e1, e2);
811 return isl_map_detect_equalities(new_want);
814 static eq_node *propagation_node(eq_node *node, edge *e1, edge *e2)
816 isl_map *new_want;
817 computation *comp1 = e1 ? e1->source : node->comp[0];
818 computation *comp2 = e2 ? e2->source : node->comp[1];
820 new_want = propagation_want(node, e1, e2);
821 eq_node *child = new eq_node(comp1, comp2, new_want);
822 return child;
825 static isl_map *set_to_empty(isl_map *map)
827 isl_map *empty;
828 empty = isl_map_empty_like(map);
829 isl_map_free(map);
830 return empty;
833 /* Propagate "lost" from child back to parent.
835 isl_map *equivalence_checker::lost_from_propagation(eq_node *parent,
836 eq_node *node, edge *e1, edge *e2)
838 isl_map *new_lost;
839 new_lost = node->get_lost();
841 if (e1)
842 new_lost = isl_map_apply_domain(new_lost,
843 isl_map_reverse(isl_map_copy(e1->relation)));
844 if (e2)
845 new_lost = isl_map_apply_range(new_lost,
846 isl_map_reverse(isl_map_copy(e2->relation)));
848 new_lost = isl_map_intersect(new_lost, isl_map_copy(parent->want));
850 extend_trace_propagation(parent, node, e1, e2);
852 return new_lost;
855 void equivalence_checker::handle_copy(eq_node *node, int i)
857 std::vector<edge *>::iterator iter;
858 computation *copy = node->comp[i];
859 isl_map *want = node->want;
860 isl_map *lost;
862 lost = isl_map_empty_like(want);
864 for (iter = copy->edges.begin(); iter != copy->edges.end(); ++iter) {
865 edge *e = *iter;
866 eq_node *child;
867 if (i == 0)
868 child = propagation_node(node, e, NULL);
869 else
870 child = propagation_node(node, NULL, e);
872 if (!child)
873 continue;
875 handle(child);
877 isl_map *new_lost;
878 if (i == 0)
879 new_lost = lost_from_propagation(node, child, e, NULL);
880 else
881 new_lost = lost_from_propagation(node, child, NULL, e);
882 dismiss(child);
884 lost = isl_map_union_disjoint(lost, new_lost);
886 node->set_lost(lost);
889 /* Compute and return the part of "want" that is lost when propagating
890 * over all edges of argument position pos1 in the first program
891 * and argument position pos2 in the second program.
892 * Since the domains of the dependence mappings on edges with the same
893 * argument position partition the domain of the computation (and are
894 * therefore disjoint), we simply need to take the disjoint union
895 * of all losts over all pairs of edges.
897 isl_map *equivalence_checker::lost_at_pos(eq_node *node, int pos1, int pos2)
899 isl_map *lost;
900 lost = isl_map_empty_like(node->want);
902 for (int i = 0; i < node->comp[0]->edges.size(); ++i) {
903 edge *e1 = node->comp[0]->edges[i];
904 if (e1->pos != pos1)
905 continue;
906 for (int j = 0; j < node->comp[1]->edges.size(); ++j) {
907 edge *e2 = node->comp[1]->edges[j];
908 if (e2->pos != pos2)
909 continue;
911 eq_node *child;
912 isl_map *new_lost = NULL;
914 child = propagation_node(node, e1, e2);
915 handle(child);
916 new_lost = lost_from_propagation(node, child, e1, e2);
918 lost = isl_map_union_disjoint(lost, new_lost);
920 std::set<eq_node *>::iterator k;
921 for (k = child->assumed.begin();
922 k != child->assumed.end(); ++k)
923 node->assumed.insert(*k);
924 dismiss(child);
927 return lost;
930 /* Compute the lost that results from propagation on a pair of
931 * computations with a commutative operation.
933 * We first compute the losts for each pair of argument positions
934 * and store the result in lost[pos1][pos2].
935 * Then we perform a backtracking search over all permutations
936 * of the arguments. For each permutation, we compute the lost
937 * relation as the union of the losts over all arguments.
938 * The final lost is the intersection of all these losts over all
939 * permutations.
941 void equivalence_checker::handle_propagation_comm(eq_node *node)
943 int trace_error = options->trace_error;
944 trace_error = 0;
946 unsigned r = node->comp[0]->arity;
948 std::vector<std::vector<isl_map *> > lost;
950 for (int i = 0; i < r; ++i) {
951 std::vector<isl_map *> row;
952 for (int j = 0; j < r; ++j) {
953 isl_map *pos_lost;
954 pos_lost = lost_at_pos(node, i, j);
955 row.push_back(pos_lost);
957 lost.push_back(row);
960 int level;
961 std::vector<int> perm;
962 std::vector<isl_map *> lost_at;
963 for (level = 0; level < r; ++level) {
964 perm.push_back(0);
965 lost_at.push_back(NULL);
968 isl_map *total_lost;
969 total_lost = isl_map_copy(node->want);
971 level = 0;
972 while (level >= 0) {
973 if (perm[level] == r) {
974 perm[level] = 0;
975 --level;
976 if (level >= 0)
977 ++perm[level];
978 continue;
980 int l;
981 for (l = 0; l < level; ++l)
982 if (perm[l] == perm[level])
983 break;
984 if (l != level) {
985 ++perm[level];
986 continue;
989 isl_map_free(lost_at[level]);
990 lost_at[level] = isl_map_copy(lost[level][perm[level]]);
991 if (level != 0)
992 lost_at[level] = isl_map_union(lost_at[level],
993 isl_map_copy(lost_at[level - 1]));
995 if (level < r - 1) {
996 ++level;
997 continue;
1000 lost_at[level] = isl_map_coalesce(lost_at[level]);
1001 total_lost = isl_map_intersect(total_lost,
1002 isl_map_copy(lost_at[level]));
1003 ++perm[level];
1006 for (int i = 0; i < r; ++i)
1007 isl_map_free(lost_at[i]);
1009 for (int i = 0; i < r; ++i)
1010 for (int j = 0; j < r; ++j)
1011 isl_map_free(lost[i][j]);
1013 node->set_lost(total_lost);
1015 options->trace_error = trace_error;
1016 init_trace(node);
1019 /* Compute the lost that results from propagation on a pair of
1020 * computations.
1022 * First, for functions of zero arity (i.e., constants), equivalence
1023 * always holds and the lost relation is empty.
1024 * For commutative operations, the computation is delegated
1025 * to handle_propagation_comm.
1026 * Otherwise, we simply take the union of the losts over each
1027 * argument position (always taking the same argument position
1028 * in both programs).
1030 void equivalence_checker::handle_propagation_same(eq_node *node)
1032 unsigned r = node->comp[0]->arity;
1033 if (r == 0) {
1034 node->set_lost(isl_map_empty_like(node->want));
1035 return;
1037 if (is_commutative(node->comp[0]->operation)) {
1038 handle_propagation_comm(node);
1039 return;
1042 isl_map *lost;
1043 lost = isl_map_empty_like(node->want);
1044 for (int i = 0; i < r; ++i) {
1045 isl_map *pos_lost;
1046 pos_lost = lost_at_pos(node, i, i);
1047 lost = isl_map_union(lost, pos_lost);
1049 lost = isl_map_coalesce(lost);
1050 node->set_lost(lost);
1053 void equivalence_checker::handle_widening(eq_node *node, eq_node *a)
1055 isl_map *wants;
1056 isl_map *aff;
1058 wants = isl_map_union(isl_map_copy(node->want), isl_map_copy(a->want));
1059 aff = isl_map_from_basic_map(isl_map_affine_hull(wants));
1060 aff = isl_map_intersect_domain(aff,
1061 isl_set_copy(node->comp[0]->domain));
1062 aff = isl_map_intersect_range(aff,
1063 isl_set_copy(node->comp[1]->domain));
1065 eq_node *child = new eq_node(node->comp[0], node->comp[1], aff);
1066 child->widening = 1;
1067 widenings++;
1068 handle(child);
1069 node->set_lost(isl_map_intersect(child->get_lost(),
1070 isl_map_copy(node->want)));
1071 extend_trace_widening(node, child);
1072 std::set<eq_node *>::iterator i;
1073 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1074 node->assumed.insert(*i);
1075 dismiss(child);
1078 /* Perform induction, if possible, and widening if we have to.
1080 void equivalence_checker::handle_with_ancestor(eq_node *node, eq_node *a)
1082 if (a->narrowing || isl_map_is_subset(node->want, a->want)) {
1083 isl_map *need;
1084 need = isl_map_intersect(isl_map_copy(node->want),
1085 isl_map_copy(a->want));
1086 node->set_lost(isl_map_subtract(isl_map_copy(node->want),
1087 isl_map_copy(a->want)));
1088 node->assumed.insert(a);
1089 a->need = isl_map_union(a->need, need);
1090 } else {
1091 handle_widening(node, a);
1095 struct narrowing_data {
1096 eq_node *node;
1097 equivalence_checker *ec;
1098 isl_map *new_got;
1101 static int basic_handle_narrowing(__isl_take isl_basic_map *bmap, void *user)
1103 narrowing_data *data = (narrowing_data *)user;
1105 eq_node *child;
1106 child = new eq_node(data->node->comp[0], data->node->comp[1],
1107 isl_map_from_basic_map(bmap));
1108 child->narrowing = 1;
1109 data->ec->narrowings++;
1110 data->ec->handle(child);
1111 data->new_got = isl_map_union(data->new_got, child->get_got());
1113 std::set<eq_node *>::iterator i;
1114 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1115 data->node->assumed.insert(*i);
1116 data->ec->dismiss(child);
1118 return 0;
1121 /* Construct and handle narrowing nodes for the given node.
1123 * If the node itself was already a narrowing node, then we
1124 * simply return the empty relation.
1125 * Otherwise, we consider each basic relation in the obtained
1126 * relation, construct a new node with that basic relation as
1127 * requested relation and take the union of all obtained relations.
1129 void equivalence_checker::handle_narrowing(eq_node *node)
1131 node->reset = 1;
1132 node->assumed.clear();
1134 isl_map *want = node->want;
1135 isl_map *new_got = isl_map_empty_like(want);
1136 if (!options->narrowing || node->narrowing) {
1137 node->set_got(new_got);
1138 return;
1141 narrowing_data data = { node, this, new_got };
1142 isl_map_foreach_basic_map(node->peek_got(), &basic_handle_narrowing,
1143 &data);
1144 node->set_got(data.new_got);
1147 static __isl_give isl_union_set *update_result(__isl_take isl_union_set *res,
1148 const char *array_name, isl_map *map, int first, int n)
1150 if (isl_map_is_empty(map)) {
1151 isl_map_free(map);
1152 return res;
1155 isl_set *range = isl_map_range(map);
1156 range = isl_set_remove_dims(range, isl_dim_set, first, n);
1157 range = isl_set_coalesce(range);
1158 range = isl_set_set_tuple_name(range, array_name);
1159 res = isl_union_set_add_set(res, range);
1161 return res;
1164 struct check_equivalence_data {
1165 dependence_graph *dg1;
1166 dependence_graph *dg2;
1167 equivalence_checker *ec;
1168 isl_map *got;
1169 isl_map *lost;
1172 static int basic_check(__isl_take isl_basic_map *bmap, void *user)
1174 check_equivalence_data *data = (check_equivalence_data *)user;
1176 eq_node *root = new eq_node(data->dg1->out, data->dg2->out,
1177 isl_map_from_basic_map(bmap));
1178 data->ec->handle(root);
1179 data->got = isl_map_union_disjoint(data->got, root->get_got());
1180 data->lost = isl_map_union_disjoint(data->lost, root->get_lost());
1181 data->ec->dismiss(root);
1183 return 0;
1186 static int check_equivalence_array(isl_ctx *ctx, equivalence_checker *ec,
1187 dependence_graph *dg1, dependence_graph *dg2, int array1, int array2,
1188 isl_union_set **proved, isl_union_set **not_proved)
1190 const char *array_name = dg1->output_arrays[array1];
1191 isl_set *out1 = isl_set_copy(dg1->out->domain);
1192 isl_set *out2 = isl_set_copy(dg2->out->domain);
1193 unsigned dim1 = isl_set_n_dim(out1);
1194 unsigned dim2 = isl_set_n_dim(out2);
1195 unsigned array_dim = dg1->output_array_dims[array1];
1196 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1197 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1198 out1 = isl_set_remove_dims(out1, isl_dim_set, dim1 - 1, 1);
1199 out2 = isl_set_remove_dims(out2, isl_dim_set, dim2 - 1, 1);
1200 int equal = isl_set_is_equal(out1, out2);
1201 isl_set_free(out1);
1202 isl_set_free(out2);
1203 assert(equal >= 0);
1204 if (!equal) {
1205 fprintf(stderr, "different output domains for array %s\n",
1206 array_name);
1207 return -1;
1210 out1 = isl_set_copy(dg1->out->domain);
1211 out2 = isl_set_copy(dg2->out->domain);
1212 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1213 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1214 out1 = isl_set_coalesce(out1);
1215 out2 = isl_set_coalesce(out2);
1216 isl_space *dim = isl_space_map_from_set(isl_set_get_space(out2));
1217 isl_map *id = isl_map_identity(dim);
1218 id = isl_map_remove_dims(id, isl_dim_in, isl_map_n_in(id) - 1, 1);
1219 id = isl_map_apply_domain(id, isl_map_copy(id));
1220 id = isl_map_intersect_domain(id, out1);
1221 id = isl_map_intersect_range(id, out2);
1223 isl_map *got = isl_map_empty_like(id);
1224 isl_map *lost = isl_map_empty_like(id);
1225 check_equivalence_data data = { dg1, dg2, ec, got, lost };
1226 isl_map_foreach_basic_map(id, &basic_check, &data);
1227 isl_map_free(id);
1228 *proved = update_result(*proved,
1229 array_name, data.got, array_dim, dim2 - array_dim);
1230 *not_proved = update_result(*not_proved,
1231 array_name, data.lost, array_dim, dim2 - array_dim);
1232 return 0;
1235 /* The input arrays of the two programs are supposed to be the same,
1236 * so they should at least have the same dimension. Make sure
1237 * this is true, because we depend on it later on.
1239 static int check_input_arrays(dependence_graph *dg1, dependence_graph *dg2)
1241 for (int i = 0; i < dg1->input_computations.size(); ++i)
1242 for (int j = 0; j < dg2->input_computations.size(); ++j) {
1243 if (strcmp(dg1->input_computations[i]->operation,
1244 dg2->input_computations[j]->operation))
1245 continue;
1246 if (dg1->input_computations[i]->dim ==
1247 dg2->input_computations[j]->dim)
1248 continue;
1249 fprintf(stderr,
1250 "input arrays \"%s\" do not have the same dimension\n",
1251 dg1->input_computations[i]->operation);
1252 return -1;
1255 return 0;
1258 static __isl_give isl_union_set *add_array(__isl_take isl_union_set *only,
1259 dependence_graph *dg, int array)
1261 isl_space *dim;
1262 isl_set *set;
1264 dim = isl_union_set_get_space(only);
1265 dim = isl_space_add_dims(dim, isl_dim_set, dg->output_array_dims[array]);
1266 dim = isl_space_set_tuple_name(dim, isl_dim_set, dg->output_arrays[array]);
1267 set = isl_set_universe(dim);
1268 only = isl_union_set_add_set(only, set);
1270 return only;
1273 static void print_results(const char *str, __isl_keep isl_union_set *only)
1275 isl_printer *prn;
1277 if (isl_union_set_is_empty(only))
1278 return;
1280 fprintf(stdout, "%s: '", str);
1281 prn = isl_printer_to_file(isl_union_set_get_ctx(only), stdout);
1282 prn = isl_printer_print_union_set(prn, only);
1283 isl_printer_free(prn);
1284 fprintf(stdout, "'\n");
1287 static int check_equivalence(isl_ctx *ctx,
1288 dependence_graph *dg1, dependence_graph *dg2, options *options)
1290 isl_space *dim;
1291 isl_set *context;
1292 isl_union_set *only1, *only2, *proved, *not_proved;
1293 dg1->flatten_associative_operators(options->ops->associative);
1294 dg2->flatten_associative_operators(options->ops->associative);
1295 equivalence_checker ec(options);
1296 int i1 = 0, i2 = 0;
1298 if (check_input_arrays(dg1, dg2))
1299 return -1;
1301 dim = isl_space_set_alloc(ctx, 0, 0);
1302 proved = isl_union_set_empty(isl_space_copy(dim));
1303 not_proved = isl_union_set_empty(isl_space_copy(dim));
1304 only1 = isl_union_set_empty(isl_space_copy(dim));
1305 only2 = isl_union_set_empty(dim);
1307 while (i1 < dg1->output_arrays.size() || i2 < dg2->output_arrays.size()) {
1308 int cmp;
1309 cmp = i1 == dg1->output_arrays.size() ? 1 :
1310 i2 == dg2->output_arrays.size() ? -1 :
1311 strcmp(dg1->output_arrays[i1], dg2->output_arrays[i2]);
1312 if (cmp < 0) {
1313 only1 = add_array(only1, dg1, i1);
1314 ++i1;
1315 } else if (cmp > 0) {
1316 only2 = add_array(only2, dg2, i2);
1317 ++i2;
1318 } else {
1319 check_equivalence_array(ctx, &ec, dg1, dg2, i1, i2,
1320 &proved, &not_proved);
1321 ++i1;
1322 ++i2;
1326 context = isl_set_union(isl_set_copy(dg1->context),
1327 isl_set_copy(dg2->context));
1328 proved = isl_union_set_gist_params(proved, isl_set_copy(context));
1329 not_proved = isl_union_set_gist_params(not_proved, context);
1331 print_results("Equivalence proved", proved);
1332 print_results("Equivalence NOT proved", not_proved);
1333 print_results("Only in program 1", only1);
1334 print_results("Only in program 2", only2);
1336 isl_union_set_free(proved);
1337 isl_union_set_free(not_proved);
1338 isl_union_set_free(only1);
1339 isl_union_set_free(only2);
1341 if (options->print_stats) {
1342 fprintf(stderr, "widenings: %d\n", ec.widenings);
1343 if (options->narrowing)
1344 fprintf(stderr, "narrowings: %d\n", ec.narrowings);
1347 return 0;
1350 static void dump_vertex(FILE *out, computation *comp)
1352 fprintf(out, "ND_%p [label = \"%d,%s/%d\"];\n",
1353 comp, comp->location, comp->operation, comp->arity);
1354 for (int i = 0; i < comp->edges.size(); ++i)
1355 fprintf(out, "ND_%p -> ND_%p%s;\n",
1356 comp, comp->edges[i]->source,
1357 comp->edges[i]->type == edge::expansion ?
1358 " [color=\"blue\"]" : "");
1361 static void dump_graph(FILE *out, dependence_graph *dg)
1363 fprintf(out, "digraph dummy {\n");
1364 dump_vertex(out, dg->out);
1365 for (int i = 0; i < dg->vertices.size(); ++i)
1366 dump_vertex(out, dg->vertices[i]);
1367 fprintf(out, "}\n");
1370 static void dump_graphs(dependence_graph **dg, struct options *options)
1372 int i;
1373 char path[PATH_MAX];
1375 if (!options->dump_graphs)
1376 return;
1378 for (i = 0; i < 2; ++i) {
1379 FILE *out;
1380 int s;
1381 s = snprintf(path, sizeof(path), "%s.dot", options->program[i]);
1382 assert(s < sizeof(path));
1383 out = fopen(path, "w");
1384 assert(out);
1385 dump_graph(out, dg[i]);
1386 fclose(out);
1390 void parse_ops(struct options *options) {
1391 char *tok;
1393 if (options->associative) {
1394 tok = strtok(options->associative, ",");
1395 options->ops->associative.push_back(strdup(tok));
1396 while ((tok = strtok(NULL, ",")) != NULL)
1397 options->ops->associative.push_back(strdup(tok));
1400 if (options->commutative) {
1401 tok = strtok(options->commutative, ",");
1402 options->ops->commutative.push_back(strdup(tok));
1403 while ((tok = strtok(NULL, ",")) != NULL)
1404 options->ops->commutative.push_back(strdup(tok));
1408 int main(int argc, char *argv[])
1410 struct options *options = options_new_with_defaults();
1411 struct isl_ctx *ctx;
1412 isl_set *context = NULL;
1413 dependence_graph *dg[2];
1415 argc = options_parse(options, argc, argv, ISL_ARG_ALL);
1416 parse_ops(options);
1418 ctx = isl_ctx_alloc_with_options(&options_args, options);
1419 if (!ctx) {
1420 fprintf(stderr, "Unable to allocate ctx\n");
1421 return -1;
1424 if (options->context)
1425 context = isl_set_read_from_str(ctx, options->context);
1427 pdg::PDG *pdg[2];
1428 unsigned out_dim = 0;
1429 for (int i = 0; i < 2; ++i) {
1430 FILE *in;
1431 in = fopen(options->program[i], "r");
1432 if (!in) {
1433 fprintf(stderr, "Unable to open %s\n", options->program[i]);
1434 return -1;
1436 pdg[i] = yaml::Load<pdg::PDG>(in, ctx);
1437 fclose(in);
1438 if (!pdg[i]) {
1439 fprintf(stderr, "Unable to read %s\n", options->program[i]);
1440 return -1;
1442 out_dim = update_out_dim(pdg[i], out_dim);
1444 if (context &&
1445 pdg[i]->params.size() != isl_set_dim(context, isl_dim_param)) {
1446 fprintf(stdout,
1447 "Parameter dimension mismatch; context ignored\n");
1448 isl_set_free(context);
1449 context = NULL;
1453 for (int i = 0; i < 2; ++i) {
1454 dg[i] = pdg_to_dg(pdg[i], out_dim, isl_set_copy(context));
1455 pdg[i]->free();
1456 delete pdg[i];
1459 dump_graphs(dg, options);
1461 int res = check_equivalence(ctx, dg[0], dg[1], options);
1463 for (int i = 0; i < 2; ++i)
1464 delete dg[i];
1465 isl_set_free(context);
1466 isl_ctx_free(ctx);
1468 return res;