da.cc: fix typos in comments
[ppn.git] / eqv.cc
blob9cdd6694e65235a35a2c7e2b0b307daa9894b1da
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <string>
#include <sstream>

#include <isa/yaml.h>
#include <isa/pdg.h>

#include <isl/set.h>
#include <isl/map.h>
#include <isl/constraint.h>
#include <isl/union_set.h>
#include "eqv_options.h"
#include "dependence_graph.h"
/* Lists of operator names that are to be treated as associative
 * and as commutative, respectively.
 */
struct ops {
	std::vector<const char *> associative;
	std::vector<const char *> commutative;
};

/* Allocate a fresh, empty "struct ops" and store it in the
 * "struct ops *" that "user" points to.  Always returns 0.
 */
int ops_init(void *user)
{
	struct ops **slot = static_cast<struct ops **>(user);
	struct ops *fresh = new struct ops;

	*slot = fresh;
	return 0;
}

/* Free the "struct ops" stored in the "struct ops *" that "user"
 * points to.
 */
void ops_clear(void *user)
{
	struct ops **slot = static_cast<struct ops **>(user);
	struct ops *victim = *slot;

	delete victim;
}
38 /* For each basic map in relation, add an edge to comp with
39 * given source and pos.
41 static void add_split_edge(computation *comp, computation *source,
42 int pos, isl_map *relation, std::vector<computation **> *missing,
43 enum edge::type type = edge::normal)
45 edge *e = new edge;
46 e->source = source;
47 if (!source) {
48 assert(missing);
49 missing->push_back(&e->source);
51 e->pos = pos;
52 e->relation = relation;
53 e->type = type;
54 comp->edges.push_back(e);
57 static unsigned update_out_dim(pdg::PDG *pdg, unsigned out_dim)
59 for (int i = 0; i < pdg->arrays.size(); ++i) {
60 pdg::array *array = pdg->arrays[i];
61 if (array->type != pdg::array::output)
62 continue;
63 if (array->dims.size() + 1 > out_dim)
64 out_dim = array->dims.size() + 1;
67 return out_dim;
/* Return true if "op" is equal (as a C string) to one of the operator
 * names in "associative".
 */
static bool is_associative(const std::vector<const char *> &associative,
	const char *op)
{
	for (size_t k = 0; k < associative.size(); ++k) {
		if (strcmp(associative[k], op) == 0)
			return true;
	}
	return false;
}
81 /* Return a computation that has an associative operation
82 * that takes a different computation with the same operation
83 * as one of its arguments. Also return the edge from the
84 * first to the second computation in *e.
86 * If no such computation exists, then return NULL.
88 computation *dependence_graph::associative_node(edge **e,
89 const std::vector<const char *> &associative)
91 for (int i = 0; i < vertices.size(); ++i) {
92 computation *comp = vertices[i];
93 if (!is_associative(associative, comp->operation))
94 continue;
95 for (int j = 0; j < comp->edges.size(); ++j) {
96 computation *other = comp->edges[j]->source;
97 if (comp->has_same_source(other))
98 continue;
99 if (strcmp(comp->operation, other->operation))
100 continue;
101 *e = comp->edges[j];
102 return comp;
105 return NULL;
108 /* Splice the source of edge edge into the position of edge e,
109 * removing all other edges with the same positions and bumping
110 * up edges with a greater position.
112 void dependence_graph::splice(computation *comp, edge *e)
114 computation *source = e->source;
115 int pos = e->pos;
116 isl_map *map = isl_map_copy(e->relation);
117 std::vector<struct edge *> new_edges;
119 for (int i = 0; i < comp->edges.size(); ++i) {
120 if (comp->edges[i]->pos == pos) {
121 delete comp->edges[i];
122 continue;
124 if (comp->edges[i]->pos > pos)
125 comp->edges[i]->pos += source->arity - 1;
126 new_edges.push_back(comp->edges[i]);
128 for (int i = 0; i < source->edges.size(); ++i) {
129 edge *old_e = source->edges[i];
130 edge *e = new edge;
131 e->source = old_e->source;
132 e->pos = old_e->pos + pos;
133 e->relation = isl_map_apply_range(
134 isl_map_copy(map),
135 isl_map_copy(old_e->relation));
136 new_edges.push_back(e);
138 isl_map_free(map);
139 comp->edges = new_edges;
140 comp->arity += source->arity - 1;
143 /* Split computation comp into a part that always takes the edge e
144 * and one that never takes that edge.
145 * The second is returned and the initial computation is modified
146 * to match the first.
148 * We need to be careful not to destroy e, as it is still used
149 * by the calling method.
151 computation *dependence_graph::split_comp(computation *comp, edge *e)
153 computation *dup = new computation;
154 dup->original = comp->original ? comp->original : comp;
155 dup->operation = strdup(comp->operation);
156 dup->arity = comp->arity;
157 dup->location = comp->location;
159 isl_map *map = isl_map_copy(e->relation);
160 map = isl_map_reverse(map);
161 isl_set *dom = isl_set_apply(isl_set_copy(e->source->domain), map);
162 dup->domain = isl_set_subtract(isl_set_copy(comp->domain),
163 isl_set_copy(dom));
164 comp->domain = isl_set_intersect(comp->domain, dom);
166 std::vector<struct edge *> old_edges = comp->edges;
167 comp->edges.clear();
169 for (int i = 0; i < old_edges.size(); ++i) {
170 if (old_edges[i] == e) {
171 comp->edges.push_back(e);
172 continue;
175 edge *e = old_edges[i];
176 isl_map *map, *map_dup;
177 map = isl_map_copy(e->relation);
178 map_dup = isl_map_copy(map);
180 map = isl_map_intersect_domain(map, isl_set_copy(comp->domain));
181 map_dup = isl_map_intersect_domain(map_dup,
182 isl_set_copy(dup->domain));
183 add_split_edge(comp, e->source, e->pos, map, NULL);
184 add_split_edge(dup, e->source, e->pos, map_dup, NULL);
185 delete e;
188 vertices.push_back(dup);
189 return dup;
192 /* If any edge from comp points to comp_orig, then split it
193 * into two edges, one still pointing to comp_orig and the
194 * other pointing to comp_dup.
195 * comp_orig and comp_dup are assumed to have disjoint domains
196 * and the edge relations are adjusted according to these domains.
198 void dependence_graph::split_edges(computation *comp,
199 computation *comp_orig, computation *comp_dup)
201 std::vector<struct edge *> old_edges = comp->edges;
202 comp->edges.clear();
204 for (int i = 0; i < old_edges.size(); ++i) {
205 edge *e = old_edges[i];
207 if (e->source != comp_orig) {
208 comp->edges.push_back(e);
209 continue;
212 isl_map *map_orig, *map_dup;
213 map_orig = isl_map_copy(e->relation);
214 map_dup = isl_map_copy(map_orig);
215 map_orig = isl_map_intersect_range(map_orig,
216 isl_set_copy(comp_orig->domain));
217 map_dup = isl_map_intersect_range(map_dup,
218 isl_set_copy(comp_dup->domain));
219 add_split_edge(comp, comp_orig, e->pos, map_orig, NULL);
220 add_split_edge(comp, comp_dup, e->pos, map_dup, NULL);
221 delete e;
225 void dependence_graph::split_edges(computation *comp_orig,
226 computation *comp_dup)
228 split_edges(out, comp_orig, comp_dup);
229 for (int i = 0; i < vertices.size(); ++i)
230 split_edges(vertices[i], comp_orig, comp_dup);
233 /* Replace all nested calls of an associative operator,
234 * by a call with the nested call spliced into the first call.
236 * For each nested call we find, we first check if there are
237 * any other edges with the same position. If not, then
238 * the nested call is performed for each iteration of the computation
239 * and we can simplify splice the nested call.
241 * Otherwise, we first create a duplicate computation for the iterations
242 * that do not take the found edge and adjust the edges of both computations
243 * to their domains. This meand that the edge corresponding to the
244 * nested call will no longer appear in the duplicated computation
245 * and the other edges with the same position will no longer appear
246 * in the original computation.
247 * Then we splice the nested call in the original computation.
248 * Finally, we split all edges that pointed to the original computation
249 * into two edges, one going to the original computation and one
250 * going to the duplicated computation.
252 void dependence_graph::flatten_associative_operators(
253 const std::vector<const char *> &associative)
255 edge *e;
256 computation *comp;
258 while ((comp = associative_node(&e, associative)) != NULL) {
259 computation *comp_dup = NULL;
260 int j;
261 for (j = 0; j < comp->edges.size(); ++j) {
262 if (comp->edges[j] == e)
263 continue;
264 if (comp->edges[j]->pos == e->pos)
265 break;
267 if (j != comp->edges.size())
268 comp_dup = split_comp(comp, e);
269 splice(comp, e);
270 if (comp_dup)
271 split_edges(comp, comp_dup);
275 struct eq_node {
276 private:
277 isl_map *got;
278 isl_map *lost;
279 public:
280 computation *comp[2];
281 isl_map *want;
282 isl_map *need;
283 std::set<eq_node *> assumed;
284 unsigned closed : 1;
285 unsigned narrowing : 1;
286 unsigned widening : 1;
287 unsigned invalidated : 1;
288 unsigned reset : 1;
289 std::string trace;
290 isl_map *lost_sample;
292 eq_node(computation *c1, computation *c2,
293 isl_map *w) : want(w), closed(0),
294 widening(0), narrowing(0), invalidated(0), reset(0),
295 need(NULL), got(NULL), lost(NULL), lost_sample(NULL) {
296 comp[0] = c1;
297 comp[1] = c2;
299 bool is_still_valid();
300 void collect_open_assumed(std::set<eq_node *> &c);
301 ~eq_node() {
302 isl_map_free(want);
303 isl_map_free(need);
304 isl_map_free(got);
305 isl_map_free(lost);
306 isl_map_free(lost_sample);
308 void compute_got() {
309 assert(lost);
310 got = isl_map_copy(want);
311 got = isl_map_subtract(got, isl_map_copy(lost));
313 isl_map *get_got() {
314 if (!got)
315 compute_got();
316 return isl_map_copy(got);
318 isl_map *peek_got() {
319 if (!got)
320 compute_got();
321 return got;
323 void compute_lost() {
324 assert(got);
325 lost = isl_map_copy(want);
326 lost = isl_map_subtract(lost, isl_map_copy(got));
328 isl_map *get_lost() {
329 if (!lost)
330 compute_lost();
331 return isl_map_copy(lost);
333 isl_map *peek_lost() {
334 if (!lost)
335 compute_lost();
336 return lost;
338 void set_got(isl_map *got) {
339 isl_map_free(this->lost);
340 isl_map_free(this->got);
341 this->lost = NULL;
342 this->got = got;
344 void set_lost(isl_map *lost) {
345 isl_map_free(this->lost);
346 isl_map_free(this->got);
347 this->got = NULL;
348 this->lost = lost;
350 std::string to_string();
353 std::string eq_node::to_string()
355 std::ostringstream strm;
356 strm << comp[0]->location << "," << comp[0]->operation
357 << "/" << comp[0]->arity;
358 strm << " <-> ";
359 strm << comp[1]->location << "," << comp[1]->operation
360 << "/" << comp[1]->arity;
361 strm << std::endl;
362 return strm.str();
365 bool eq_node::is_still_valid()
367 if (invalidated)
368 return 0;
370 std::set<eq_node *>::iterator i;
371 for (i = assumed.begin(); i != assumed.end(); ++i) {
372 assert(*i != this);
373 if ((*i)->reset ||
374 ((*i)->closed && !(*i)->is_still_valid())) {
375 invalidated = 1;
376 return 0;
379 return 1;
382 void eq_node::collect_open_assumed(std::set<eq_node *> &c)
384 std::set<eq_node *>::iterator i;
385 for (i = assumed.begin(); i != assumed.end(); ++i) {
386 if ((*i)->closed)
387 (*i)->collect_open_assumed(c);
388 else
389 c.insert(*i);
393 /* A comp_pair contains all the edges that have the same pair
394 * of computations.
396 struct comp_pair {
397 computation *comp[2];
399 std::vector<eq_node *> nodes;
401 eq_node *tabled(eq_node *node);
402 eq_node *last_ancestor(eq_node *node);
403 ~comp_pair();
406 comp_pair::~comp_pair()
408 std::vector<eq_node *>::iterator i;
409 for (i = nodes.begin(); i != nodes.end(); ++i)
410 delete *i;
413 eq_node *comp_pair::tabled(eq_node *node)
415 std::vector<eq_node *>::iterator i;
417 for (i = nodes.begin(); i != nodes.end(); ++i) {
418 if (*i == node)
419 continue;
420 if (!(*i)->closed)
421 continue;
422 if (!(*i)->is_still_valid())
423 continue;
424 int is_subset;
425 is_subset = isl_map_is_subset(node->want, (*i)->want);
426 assert(is_subset >= 0);
427 if (!is_subset)
428 continue;
429 return *i;
431 return NULL;
434 eq_node *comp_pair::last_ancestor(eq_node *node)
436 std::vector<eq_node *>::reverse_iterator i;
437 for (i = nodes.rbegin(); i != nodes.rend(); ++i) {
438 if (*i == node)
439 continue;
440 if ((*i)->closed)
441 continue;
442 return *i;
444 return NULL;
447 typedef std::pair<computation *, computation *> computation_pair;
448 typedef std::map<computation_pair, comp_pair *> c2p_t;
450 struct equivalence_checker {
451 c2p_t c2p;
452 struct options *options;
453 int widenings;
454 int narrowings;
456 equivalence_checker(struct options *o) :
457 options(o), widenings(0), narrowings(0) {}
458 comp_pair *get_comp_pair(eq_node *node);
459 void handle(eq_node *node);
460 void dismiss(eq_node *node);
461 void handle_propagation(eq_node *node);
462 void handle_copy(eq_node *node, int i);
463 void handle_widening(eq_node *node, eq_node *a);
464 void handle_with_ancestor(eq_node *node, eq_node *a);
465 isl_map *lost_at_pos(eq_node *node, int pos1, int pos2);
466 void handle_propagation_same(eq_node *node);
467 void handle_propagation_comm(eq_node *node);
468 void handle_narrowing(eq_node *node);
470 isl_map *lost_from_propagation(eq_node *parent,
471 eq_node *node, edge *e1, edge *e2);
473 void init_trace(eq_node *node);
474 void extend_trace_propagation(eq_node *node, eq_node *child,
475 edge *e1, edge *e2);
476 void extend_trace_widening(eq_node *node, eq_node *child);
478 bool is_commutative(const char *op);
480 ~equivalence_checker();
483 equivalence_checker::~equivalence_checker()
485 c2p_t::iterator i;
487 for (i = c2p.begin(); i != c2p.end(); ++i)
488 delete (*i).second;
491 bool equivalence_checker::is_commutative(const char *op)
493 std::vector<const char *>::iterator iter;
495 for (iter = options->ops->commutative.begin();
496 iter != options->ops->commutative.end(); ++iter)
497 if (!strcmp(*iter, op))
498 return 1;
499 return 0;
502 comp_pair *equivalence_checker::get_comp_pair(eq_node *node)
504 c2p_t::iterator i;
505 comp_pair *cp;
507 i = c2p.find(computation_pair(node->comp[0], node->comp[1]));
508 if (i == c2p.end()) {
509 cp = new comp_pair;
510 c2p[computation_pair(node->comp[0], node->comp[1])] = cp;
511 } else
512 cp = (*i).second;
513 return cp;
516 void equivalence_checker::dismiss(eq_node *node)
518 /* unclosed nodes weren't added to c2p->nodes */
519 if (node && !node->closed)
520 delete node;
523 void equivalence_checker::init_trace(eq_node *node)
525 isl_ctx *ctx;
526 isl_printer *prn;
528 if (!options->trace_error)
529 return;
530 if (isl_map_is_empty(node->peek_lost()))
531 return;
532 node->trace = node->to_string();
533 std::cerr << node->trace;
534 isl_map_free(node->lost_sample);
535 node->lost_sample = isl_map_from_basic_map(
536 isl_map_sample(node->get_lost()));
537 ctx = isl_map_get_ctx(node->lost_sample);
538 prn = isl_printer_to_file(ctx, stderr);
539 prn = isl_printer_print_map(prn, node->lost_sample);
540 prn = isl_printer_end_line(prn);
541 isl_printer_free(prn);
544 void equivalence_checker::extend_trace_propagation(eq_node *node,
545 eq_node *child, edge *e1, edge *e2)
547 isl_ctx *ctx;
548 isl_printer *prn;
550 if (!options->trace_error)
551 return;
552 if (!child->lost_sample)
553 return;
554 if (node->lost_sample)
555 return;
556 node->trace = child->trace + node->to_string();
557 std::cerr << node->trace;
558 node->lost_sample = isl_map_copy(child->lost_sample);
559 if (e1)
560 node->lost_sample = isl_map_apply_domain(node->lost_sample,
561 isl_map_reverse(
562 isl_map_copy(e1->relation)));
563 if (e2)
564 node->lost_sample = isl_map_apply_range(node->lost_sample,
565 isl_map_reverse(
566 isl_map_copy(e2->relation)));
567 node->lost_sample = isl_map_intersect(node->lost_sample,
568 isl_map_copy(node->want));
569 node->lost_sample = isl_map_from_basic_map(
570 isl_map_sample(node->lost_sample));
571 ctx = isl_map_get_ctx(node->lost_sample);
572 prn = isl_printer_to_file(ctx, stderr);
573 prn = isl_printer_print_map(prn, node->lost_sample);
574 prn = isl_printer_end_line(prn);
575 isl_printer_free(prn);
578 void equivalence_checker::extend_trace_widening(eq_node *node, eq_node *child)
580 isl_ctx *ctx;
581 isl_printer *prn;
583 if (!options->trace_error)
584 return;
585 if (!child->lost_sample)
586 return;
587 if (isl_map_is_empty(node->peek_lost()))
588 return;
589 node->trace = child->trace + node->to_string();
590 std::cerr << node->trace;
591 node->lost_sample = isl_map_from_basic_map(
592 isl_map_sample(node->get_lost()));
593 ctx = isl_map_get_ctx(node->lost_sample);
594 prn = isl_printer_to_file(ctx, stderr);
595 prn = isl_printer_print_map(prn, node->lost_sample);
596 prn = isl_printer_end_line(prn);
597 isl_printer_free(prn);
600 /* Check for which subset (got) of the want relation equivelance
601 * holds for the pair of computations in the equivalence node.
603 * We first handle the easy cases: empty want, input computations
604 * and tabled nodes.
606 * If the current node is not a narrowing or a widening node
607 * and we can find an ancestor with the same pair of computations,
608 * then we will try to apply induction or widening in handle_with_ancestor.
609 * However, if the requested relation (want) of the ancestor is a strict
610 * subset of that of the current node, then we have already applied
611 * widening on an intermediate node (with a different pair of computations)
612 * so we shouldn't apply widening again (and we can't apply induction
613 * because the relation of the ancestor is a strict subset).
615 * In all other cases we try to apply propagation in handle_propagation.
617 void equivalence_checker::handle(eq_node *node)
619 computation *comp1 = node->comp[0];
620 computation *comp2 = node->comp[1];
622 if (!comp1->is_copy() && !comp2->is_copy() &&
623 (strcmp(comp1->operation, comp2->operation) ||
624 comp1->arity != comp2->arity)) {
625 node->set_lost(isl_map_copy(node->want));
626 init_trace(node);
627 return;
630 eq_node *s = NULL;
631 eq_node *a = NULL;
633 isl_map *want = node->want;
634 node->need = isl_map_empty_like(want);
635 int empty = isl_map_is_empty(want);
636 assert(empty >= 0);
637 if (empty) {
638 node->set_lost(isl_map_empty_like(want));
639 return;
641 if (node->comp[0]->is_input() && node->comp[1]->is_input()) {
642 if (strcmp(node->comp[0]->operation, node->comp[1]->operation)) {
643 node->set_lost(isl_map_copy(want));
644 return;
646 node->set_lost(isl_map_subtract(
647 isl_map_copy(node->want), isl_map_identity_like(want)));
648 return;
651 comp_pair *cp = get_comp_pair(node);
652 if ((s = cp->tabled(node)) != NULL) {
653 node->set_lost(isl_map_intersect(s->get_lost(),
654 isl_map_copy(node->want)));
655 s->collect_open_assumed(node->assumed);
656 return;
659 cp->nodes.push_back(node);
661 if (!node->narrowing && !node->widening &&
662 (a = cp->last_ancestor(node)) != NULL) {
663 int is_subset;
664 is_subset = isl_map_is_strict_subset(a->want, node->want);
665 assert(is_subset >= 0);
666 if (is_subset)
667 handle_propagation(node);
668 else
669 handle_with_ancestor(node, a);
670 } else
671 handle_propagation(node);
672 node->closed = 1;
675 /* Check if we can apply propagation to prove equivalence of the given node.
676 * First check if we need to apply copy propagation and if not
677 * check if the operations are the same and apply "regular" propagation.
679 * After we get back, we need to check that any induction hypotheses
680 * we have used in the process of proving the node hold.
681 * If not, we replace the obtained relation (got) by that
682 * of a narrowing node in handle_narrowin.
684 void equivalence_checker::handle_propagation(eq_node *node)
686 if (node->comp[0]->is_copy())
687 handle_copy(node, 0);
688 else if (node->comp[1]->is_copy())
689 handle_copy(node, 1);
690 else
691 handle_propagation_same(node);
692 node->assumed.erase(node);
693 isl_map *lost = node->get_lost();
694 lost = isl_map_intersect(lost, isl_map_copy(node->need));
695 int is_empty = isl_map_is_empty(lost);
696 isl_map_free(lost);
697 assert(is_empty >= 0);
698 if (!is_empty)
699 handle_narrowing(node);
702 /* When applying the mappings on expansion edges to both sides of
703 * an equivalence relation, each element in the original equivalence
704 * relation is mapped to many elements on both sides. For example,
705 * if the original equivalence relation has i R i' for i = i', then the new
706 * equivalence relation may have (i,j) R (i',j') for i = i' and
707 * 0 <= j,j' <= 10, expressing that all elements of the row read by
708 * iteration i should be equal to all elements of the row read by i'.
709 * Instead, we want to express that each individual element read by i'
710 * should be equal to the corresponding element read by i'.
711 * In the example, we need to introduce the equality j = j'.
713 * We first check that the number of dimensions added by the expansions
714 * is the same in both programs. Then we construct an identity relation
715 * between this number of dimensions, lift it to the space of the
716 * new equivalence relation and take the intersection.
718 * We have to be careful, though, that we don't loose any array elements
719 * by taking the intersection. In particular, the expansion maps
720 * a given read operation to iterations of an extended domain that
721 * reads all the individual array elements. The read operation is
722 * only equivalent to some other read operation if all the reads
723 * of the individual array elements are equivalent.
724 * We therefore need to make sure that adding the equalities does
725 * not remove any of the reads in the extended domain.
726 * We do this by projecting out the extra dimensions on one side
727 * from both "want" and "want \cap eq". The resulting maps should
728 * be the same. We do this check for both sides separately.
730 * If either of the above tests fails, then we simply return
731 * the original over-ambitious want.
733 static isl_map *expansion_want(isl_map *want, edge *e1, edge *e2)
735 unsigned s_dim_1 = isl_map_n_out(e1->relation) -
736 isl_map_n_in(e1->relation);
737 unsigned s_dim_2 = isl_map_n_out(e2->relation) -
738 isl_map_n_in(e2->relation);
739 if (s_dim_1 != s_dim_2)
740 return want;
742 isl_space *dim = isl_map_get_space(e1->relation);
743 dim = isl_space_drop_dims(dim, isl_dim_in,
744 0, isl_space_dim(dim, isl_dim_in));
745 dim = isl_space_drop_dims(dim, isl_dim_out,
746 0, isl_space_dim(dim, isl_dim_out));
747 dim = isl_space_add_dims(dim, isl_dim_in, s_dim_1);
748 dim = isl_space_add_dims(dim, isl_dim_out, s_dim_1);
749 isl_basic_map *id = isl_basic_map_identity(dim);
751 dim = isl_space_range(isl_map_get_space(e1->relation));
752 isl_basic_map *s_1 = isl_basic_map_identity(isl_space_map_from_set(dim));
753 s_1 = isl_basic_map_remove_dims(s_1, isl_dim_in, 0,
754 isl_basic_map_n_in(s_1) - s_dim_1);
755 id = isl_basic_map_apply_domain(id, s_1);
757 dim = isl_space_range(isl_map_get_space(e2->relation));
758 isl_basic_map *s_2 = isl_basic_map_identity(isl_space_map_from_set(dim));
759 s_2 = isl_basic_map_remove_dims(s_2, isl_dim_in, 0,
760 isl_basic_map_n_in(s_2) - s_dim_2);
761 id = isl_basic_map_apply_range(id, s_2);
763 bool unmatched = false;
764 isl_map *matched_want;
765 isl_map *proj_want, *proj_matched;
766 matched_want = isl_map_intersect(isl_map_copy(want),
767 isl_map_from_basic_map(id));
769 proj_want = isl_map_remove_dims(isl_map_copy(want),
770 isl_dim_in, isl_map_n_in(want) - s_dim_1, s_dim_1);
771 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
772 isl_dim_in, isl_map_n_in(want) - s_dim_1, s_dim_1);
773 if (!isl_map_is_equal(proj_want, proj_matched))
774 unmatched = true;
775 isl_map_free(proj_want);
776 isl_map_free(proj_matched);
778 proj_want = isl_map_remove_dims(isl_map_copy(want),
779 isl_dim_out, isl_map_n_out(want) - s_dim_2, s_dim_2);
780 proj_matched = isl_map_remove_dims(isl_map_copy(matched_want),
781 isl_dim_out, isl_map_n_out(want) - s_dim_2, s_dim_2);
782 if (!isl_map_is_equal(proj_want, proj_matched))
783 unmatched = true;
784 isl_map_free(proj_want);
785 isl_map_free(proj_matched);
787 if (unmatched) {
788 isl_map_free(matched_want);
789 return matched_want;
791 isl_map_free(want);
792 return matched_want;
795 static isl_map *propagation_want(eq_node *node, edge *e1, edge *e2)
797 isl_map *new_want;
799 new_want = isl_map_copy(node->want);
800 if (e1)
801 new_want = isl_map_apply_domain(new_want,
802 isl_map_copy(e1->relation));
803 if (e2)
804 new_want = isl_map_apply_range(new_want,
805 isl_map_copy(e2->relation));
807 if (e1 && e2 && e1->type == edge::expansion)
808 new_want = expansion_want(new_want, e1, e2);
810 return isl_map_detect_equalities(new_want);
813 static eq_node *propagation_node(eq_node *node, edge *e1, edge *e2)
815 isl_map *new_want;
816 computation *comp1 = e1 ? e1->source : node->comp[0];
817 computation *comp2 = e2 ? e2->source : node->comp[1];
819 new_want = propagation_want(node, e1, e2);
820 eq_node *child = new eq_node(comp1, comp2, new_want);
821 return child;
824 static isl_map *set_to_empty(isl_map *map)
826 isl_map *empty;
827 empty = isl_map_empty_like(map);
828 isl_map_free(map);
829 return empty;
832 /* Propagate "lost" from child back to parent.
834 isl_map *equivalence_checker::lost_from_propagation(eq_node *parent,
835 eq_node *node, edge *e1, edge *e2)
837 isl_map *new_lost;
838 new_lost = node->get_lost();
840 if (e1)
841 new_lost = isl_map_apply_domain(new_lost,
842 isl_map_reverse(isl_map_copy(e1->relation)));
843 if (e2)
844 new_lost = isl_map_apply_range(new_lost,
845 isl_map_reverse(isl_map_copy(e2->relation)));
847 new_lost = isl_map_intersect(new_lost, isl_map_copy(parent->want));
849 extend_trace_propagation(parent, node, e1, e2);
851 return new_lost;
854 void equivalence_checker::handle_copy(eq_node *node, int i)
856 std::vector<edge *>::iterator iter;
857 computation *copy = node->comp[i];
858 isl_map *want = node->want;
859 isl_map *lost;
861 lost = isl_map_empty_like(want);
863 for (iter = copy->edges.begin(); iter != copy->edges.end(); ++iter) {
864 edge *e = *iter;
865 eq_node *child;
866 if (i == 0)
867 child = propagation_node(node, e, NULL);
868 else
869 child = propagation_node(node, NULL, e);
871 if (!child)
872 continue;
874 handle(child);
876 isl_map *new_lost;
877 if (i == 0)
878 new_lost = lost_from_propagation(node, child, e, NULL);
879 else
880 new_lost = lost_from_propagation(node, child, NULL, e);
881 dismiss(child);
883 lost = isl_map_union_disjoint(lost, new_lost);
885 node->set_lost(lost);
888 /* Compute and return the part of "want" that is lost when propagating
889 * over all edges of argument position pos1 in the first program
890 * and argument position pos2 in the second program.
891 * Since the domains of the dependence mappings on edges with the same
892 * argument position partition the domain of the computation (and are
893 * therefore disjoint), we simply need to take the disjoint union
894 * of all losts over all pairs of edges.
896 isl_map *equivalence_checker::lost_at_pos(eq_node *node, int pos1, int pos2)
898 isl_map *lost;
899 lost = isl_map_empty_like(node->want);
901 for (int i = 0; i < node->comp[0]->edges.size(); ++i) {
902 edge *e1 = node->comp[0]->edges[i];
903 if (e1->pos != pos1)
904 continue;
905 for (int j = 0; j < node->comp[1]->edges.size(); ++j) {
906 edge *e2 = node->comp[1]->edges[j];
907 if (e2->pos != pos2)
908 continue;
910 eq_node *child;
911 isl_map *new_lost = NULL;
913 child = propagation_node(node, e1, e2);
914 handle(child);
915 new_lost = lost_from_propagation(node, child, e1, e2);
917 lost = isl_map_union_disjoint(lost, new_lost);
919 std::set<eq_node *>::iterator k;
920 for (k = child->assumed.begin();
921 k != child->assumed.end(); ++k)
922 node->assumed.insert(*k);
923 dismiss(child);
926 return lost;
929 /* Compute the lost that results from propagation on a pair of
930 * computations with a commutative operation.
932 * We first compute the losts for each pair of argument positions
933 * and store the result in lost[pos1][pos2].
934 * Then we perform a backtracking search over all permutations
935 * of the arguments. For each permutation, we compute the lost
936 * relation as the union of the losts over all arguments.
937 * The final lost is the intersection of all these losts over all
938 * permutations.
940 void equivalence_checker::handle_propagation_comm(eq_node *node)
942 int trace_error = options->trace_error;
943 trace_error = 0;
945 unsigned r = node->comp[0]->arity;
947 std::vector<std::vector<isl_map *> > lost;
949 for (int i = 0; i < r; ++i) {
950 std::vector<isl_map *> row;
951 for (int j = 0; j < r; ++j) {
952 isl_map *pos_lost;
953 pos_lost = lost_at_pos(node, i, j);
954 row.push_back(pos_lost);
956 lost.push_back(row);
959 int level;
960 std::vector<int> perm;
961 std::vector<isl_map *> lost_at;
962 for (level = 0; level < r; ++level) {
963 perm.push_back(0);
964 lost_at.push_back(NULL);
967 isl_map *total_lost;
968 total_lost = isl_map_copy(node->want);
970 level = 0;
971 while (level >= 0) {
972 if (perm[level] == r) {
973 perm[level] = 0;
974 --level;
975 if (level >= 0)
976 ++perm[level];
977 continue;
979 int l;
980 for (l = 0; l < level; ++l)
981 if (perm[l] == perm[level])
982 break;
983 if (l != level) {
984 ++perm[level];
985 continue;
988 isl_map_free(lost_at[level]);
989 lost_at[level] = isl_map_copy(lost[level][perm[level]]);
990 if (level != 0)
991 lost_at[level] = isl_map_union(lost_at[level],
992 isl_map_copy(lost_at[level - 1]));
994 if (level < r - 1) {
995 ++level;
996 continue;
999 lost_at[level] = isl_map_coalesce(lost_at[level]);
1000 total_lost = isl_map_intersect(total_lost,
1001 isl_map_copy(lost_at[level]));
1002 ++perm[level];
1005 for (int i = 0; i < r; ++i)
1006 isl_map_free(lost_at[i]);
1008 for (int i = 0; i < r; ++i)
1009 for (int j = 0; j < r; ++j)
1010 isl_map_free(lost[i][j]);
1012 node->set_lost(total_lost);
1014 options->trace_error = trace_error;
1015 init_trace(node);
1018 /* Compute the lost that results from propagation on a pair of
1019 * computations.
1021 * First, for functions of zero arity (i.e., constants), equivalence
1022 * always holds and the lost relation is empty.
1023 * For commutative operations, the computation is delegated
1024 * to handle_propagation_comm.
1025 * Otherwise, we simply take the union of the losts over each
1026 * argument position (always taking the same argument position
1027 * in both programs).
1029 void equivalence_checker::handle_propagation_same(eq_node *node)
1031 unsigned r = node->comp[0]->arity;
1032 if (r == 0) {
1033 node->set_lost(isl_map_empty_like(node->want));
1034 return;
1036 if (is_commutative(node->comp[0]->operation)) {
1037 handle_propagation_comm(node);
1038 return;
1041 isl_map *lost;
1042 lost = isl_map_empty_like(node->want);
1043 for (int i = 0; i < r; ++i) {
1044 isl_map *pos_lost;
1045 pos_lost = lost_at_pos(node, i, i);
1046 lost = isl_map_union(lost, pos_lost);
1048 lost = isl_map_coalesce(lost);
1049 node->set_lost(lost);
1052 void equivalence_checker::handle_widening(eq_node *node, eq_node *a)
1054 isl_map *wants;
1055 isl_map *aff;
1057 wants = isl_map_union(isl_map_copy(node->want), isl_map_copy(a->want));
1058 aff = isl_map_from_basic_map(isl_map_affine_hull(wants));
1059 aff = isl_map_intersect_domain(aff,
1060 isl_set_copy(node->comp[0]->domain));
1061 aff = isl_map_intersect_range(aff,
1062 isl_set_copy(node->comp[1]->domain));
1064 eq_node *child = new eq_node(node->comp[0], node->comp[1], aff);
1065 child->widening = 1;
1066 widenings++;
1067 handle(child);
1068 node->set_lost(isl_map_intersect(child->get_lost(),
1069 isl_map_copy(node->want)));
1070 extend_trace_widening(node, child);
1071 std::set<eq_node *>::iterator i;
1072 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1073 node->assumed.insert(*i);
1074 dismiss(child);
1077 /* Perform induction, if possible, and widening if we have to.
1079 void equivalence_checker::handle_with_ancestor(eq_node *node, eq_node *a)
1081 if (a->narrowing || isl_map_is_subset(node->want, a->want)) {
1082 isl_map *need;
1083 need = isl_map_intersect(isl_map_copy(node->want),
1084 isl_map_copy(a->want));
1085 node->set_lost(isl_map_subtract(isl_map_copy(node->want),
1086 isl_map_copy(a->want)));
1087 node->assumed.insert(a);
1088 a->need = isl_map_union(a->need, need);
1089 } else {
1090 handle_widening(node, a);
1094 struct narrowing_data {
1095 eq_node *node;
1096 equivalence_checker *ec;
1097 isl_map *new_got;
1100 static int basic_handle_narrowing(__isl_take isl_basic_map *bmap, void *user)
1102 narrowing_data *data = (narrowing_data *)user;
1104 eq_node *child;
1105 child = new eq_node(data->node->comp[0], data->node->comp[1],
1106 isl_map_from_basic_map(bmap));
1107 child->narrowing = 1;
1108 data->ec->narrowings++;
1109 data->ec->handle(child);
1110 data->new_got = isl_map_union(data->new_got, child->get_got());
1112 std::set<eq_node *>::iterator i;
1113 for (i = child->assumed.begin(); i != child->assumed.end(); ++i)
1114 data->node->assumed.insert(*i);
1115 data->ec->dismiss(child);
1117 return 0;
1120 /* Construct and handle narrowing nodes for the given node.
1122 * If the node itself was already a narrowing node, then we
1123 * simply return the empty relation.
1124 * Otherwise, we consider each basic relation in the obtained
1125 * relation, construct a new node with that basic relation as
1126 * requested relation and take the union of all obtained relations.
1128 void equivalence_checker::handle_narrowing(eq_node *node)
1130 node->reset = 1;
1131 node->assumed.clear();
1133 isl_map *want = node->want;
1134 isl_map *new_got = isl_map_empty_like(want);
1135 if (!options->narrowing || node->narrowing) {
1136 node->set_got(new_got);
1137 return;
1140 narrowing_data data = { node, this, new_got };
1141 isl_map_foreach_basic_map(node->peek_got(), &basic_handle_narrowing,
1142 &data);
1143 node->set_got(data.new_got);
1146 static __isl_give isl_union_set *update_result(__isl_take isl_union_set *res,
1147 const char *array_name, isl_map *map, int first, int n)
1149 if (isl_map_is_empty(map)) {
1150 isl_map_free(map);
1151 return res;
1154 isl_set *range = isl_map_range(map);
1155 range = isl_set_remove_dims(range, isl_dim_set, first, n);
1156 range = isl_set_coalesce(range);
1157 range = isl_set_set_tuple_name(range, array_name);
1158 res = isl_union_set_add_set(res, range);
1160 return res;
1163 struct check_equivalence_data {
1164 dependence_graph *dg1;
1165 dependence_graph *dg2;
1166 equivalence_checker *ec;
1167 isl_map *got;
1168 isl_map *lost;
1171 static int basic_check(__isl_take isl_basic_map *bmap, void *user)
1173 check_equivalence_data *data = (check_equivalence_data *)user;
1175 eq_node *root = new eq_node(data->dg1->out, data->dg2->out,
1176 isl_map_from_basic_map(bmap));
1177 data->ec->handle(root);
1178 data->got = isl_map_union_disjoint(data->got, root->get_got());
1179 data->lost = isl_map_union_disjoint(data->lost, root->get_lost());
1180 data->ec->dismiss(root);
1182 return 0;
1185 static int check_equivalence_array(isl_ctx *ctx, equivalence_checker *ec,
1186 dependence_graph *dg1, dependence_graph *dg2, int array1, int array2,
1187 isl_union_set **proved, isl_union_set **not_proved)
1189 const char *array_name = dg1->output_arrays[array1];
1190 isl_set *out1 = isl_set_copy(dg1->out->domain);
1191 isl_set *out2 = isl_set_copy(dg2->out->domain);
1192 unsigned dim1 = isl_set_n_dim(out1);
1193 unsigned dim2 = isl_set_n_dim(out2);
1194 unsigned array_dim = dg1->output_array_dims[array1];
1195 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1196 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1197 out1 = isl_set_remove_dims(out1, isl_dim_set, dim1 - 1, 1);
1198 out2 = isl_set_remove_dims(out2, isl_dim_set, dim2 - 1, 1);
1199 int equal = isl_set_is_equal(out1, out2);
1200 isl_set_free(out1);
1201 isl_set_free(out2);
1202 assert(equal >= 0);
1203 if (!equal) {
1204 fprintf(stderr, "different output domains for array %s\n",
1205 array_name);
1206 return -1;
1209 out1 = isl_set_copy(dg1->out->domain);
1210 out2 = isl_set_copy(dg2->out->domain);
1211 out1 = isl_set_fix_dim_si(out1, dim1-1, array1);
1212 out2 = isl_set_fix_dim_si(out2, dim2-1, array2);
1213 out1 = isl_set_coalesce(out1);
1214 out2 = isl_set_coalesce(out2);
1215 isl_space *dim = isl_space_map_from_set(isl_set_get_space(out2));
1216 isl_map *id = isl_map_identity(dim);
1217 id = isl_map_remove_dims(id, isl_dim_in, isl_map_n_in(id) - 1, 1);
1218 id = isl_map_apply_domain(id, isl_map_copy(id));
1219 id = isl_map_intersect_domain(id, out1);
1220 id = isl_map_intersect_range(id, out2);
1222 isl_map *got = isl_map_empty_like(id);
1223 isl_map *lost = isl_map_empty_like(id);
1224 check_equivalence_data data = { dg1, dg2, ec, got, lost };
1225 isl_map_foreach_basic_map(id, &basic_check, &data);
1226 isl_map_free(id);
1227 *proved = update_result(*proved,
1228 array_name, data.got, array_dim, dim2 - array_dim);
1229 *not_proved = update_result(*not_proved,
1230 array_name, data.lost, array_dim, dim2 - array_dim);
1231 return 0;
1234 /* The input arrays of the two programs are supposed to be the same,
1235 * so they should at least have the same dimension. Make sure
1236 * this is true, because we depend on it later on.
1238 static int check_input_arrays(dependence_graph *dg1, dependence_graph *dg2)
1240 for (int i = 0; i < dg1->input_computations.size(); ++i)
1241 for (int j = 0; j < dg2->input_computations.size(); ++j) {
1242 if (strcmp(dg1->input_computations[i]->operation,
1243 dg2->input_computations[j]->operation))
1244 continue;
1245 if (dg1->input_computations[i]->dim ==
1246 dg2->input_computations[j]->dim)
1247 continue;
1248 fprintf(stderr,
1249 "input arrays \"%s\" do not have the same dimension\n",
1250 dg1->input_computations[i]->operation);
1251 return -1;
1254 return 0;
1257 static __isl_give isl_union_set *add_array(__isl_take isl_union_set *only,
1258 dependence_graph *dg, int array)
1260 isl_space *dim;
1261 isl_set *set;
1263 dim = isl_union_set_get_space(only);
1264 dim = isl_space_add_dims(dim, isl_dim_set, dg->output_array_dims[array]);
1265 dim = isl_space_set_tuple_name(dim, isl_dim_set, dg->output_arrays[array]);
1266 set = isl_set_universe(dim);
1267 only = isl_union_set_add_set(only, set);
1269 return only;
1272 static void print_results(const char *str, __isl_keep isl_union_set *only)
1274 isl_printer *prn;
1276 if (isl_union_set_is_empty(only))
1277 return;
1279 fprintf(stdout, "%s: '", str);
1280 prn = isl_printer_to_file(isl_union_set_get_ctx(only), stdout);
1281 prn = isl_printer_print_union_set(prn, only);
1282 isl_printer_free(prn);
1283 fprintf(stdout, "'\n");
1286 static int check_equivalence(isl_ctx *ctx,
1287 dependence_graph *dg1, dependence_graph *dg2, options *options)
1289 isl_space *dim;
1290 isl_set *context;
1291 isl_union_set *only1, *only2, *proved, *not_proved;
1292 dg1->flatten_associative_operators(options->ops->associative);
1293 dg2->flatten_associative_operators(options->ops->associative);
1294 equivalence_checker ec(options);
1295 int i1 = 0, i2 = 0;
1297 if (check_input_arrays(dg1, dg2))
1298 return -1;
1300 dim = isl_space_set_alloc(ctx, 0, 0);
1301 proved = isl_union_set_empty(isl_space_copy(dim));
1302 not_proved = isl_union_set_empty(isl_space_copy(dim));
1303 only1 = isl_union_set_empty(isl_space_copy(dim));
1304 only2 = isl_union_set_empty(dim);
1306 while (i1 < dg1->output_arrays.size() || i2 < dg2->output_arrays.size()) {
1307 int cmp;
1308 cmp = i1 == dg1->output_arrays.size() ? 1 :
1309 i2 == dg2->output_arrays.size() ? -1 :
1310 strcmp(dg1->output_arrays[i1], dg2->output_arrays[i2]);
1311 if (cmp < 0) {
1312 only1 = add_array(only1, dg1, i1);
1313 ++i1;
1314 } else if (cmp > 0) {
1315 only2 = add_array(only2, dg2, i2);
1316 ++i2;
1317 } else {
1318 check_equivalence_array(ctx, &ec, dg1, dg2, i1, i2,
1319 &proved, &not_proved);
1320 ++i1;
1321 ++i2;
1325 context = isl_set_union(isl_set_copy(dg1->context),
1326 isl_set_copy(dg2->context));
1327 proved = isl_union_set_gist_params(proved, isl_set_copy(context));
1328 not_proved = isl_union_set_gist_params(not_proved, context);
1330 print_results("Equivalence proved", proved);
1331 print_results("Equivalence NOT proved", not_proved);
1332 print_results("Only in program 1", only1);
1333 print_results("Only in program 2", only2);
1335 isl_union_set_free(proved);
1336 isl_union_set_free(not_proved);
1337 isl_union_set_free(only1);
1338 isl_union_set_free(only2);
1340 if (options->print_stats) {
1341 fprintf(stderr, "widenings: %d\n", ec.widenings);
1342 if (options->narrowing)
1343 fprintf(stderr, "narrowings: %d\n", ec.narrowings);
1346 return 0;
1349 static void dump_vertex(FILE *out, computation *comp)
1351 fprintf(out, "ND_%p [label = \"%d,%s/%d\"];\n",
1352 comp, comp->location, comp->operation, comp->arity);
1353 for (int i = 0; i < comp->edges.size(); ++i)
1354 fprintf(out, "ND_%p -> ND_%p%s;\n",
1355 comp, comp->edges[i]->source,
1356 comp->edges[i]->type == edge::expansion ?
1357 " [color=\"blue\"]" : "");
1360 static void dump_graph(FILE *out, dependence_graph *dg)
1362 fprintf(out, "digraph dummy {\n");
1363 dump_vertex(out, dg->out);
1364 for (int i = 0; i < dg->vertices.size(); ++i)
1365 dump_vertex(out, dg->vertices[i]);
1366 fprintf(out, "}\n");
1369 static void dump_graphs(dependence_graph **dg, struct options *options)
1371 int i;
1372 char path[PATH_MAX];
1374 if (!options->dump_graphs)
1375 return;
1377 for (i = 0; i < 2; ++i) {
1378 FILE *out;
1379 int s;
1380 s = snprintf(path, sizeof(path), "%s.dot", options->program[i]);
1381 assert(s < sizeof(path));
1382 out = fopen(path, "w");
1383 assert(out);
1384 dump_graph(out, dg[i]);
1385 fclose(out);
1389 void parse_ops(struct options *options) {
1390 char *tok;
1392 if (options->associative) {
1393 tok = strtok(options->associative, ",");
1394 options->ops->associative.push_back(strdup(tok));
1395 while ((tok = strtok(NULL, ",")) != NULL)
1396 options->ops->associative.push_back(strdup(tok));
1399 if (options->commutative) {
1400 tok = strtok(options->commutative, ",");
1401 options->ops->commutative.push_back(strdup(tok));
1402 while ((tok = strtok(NULL, ",")) != NULL)
1403 options->ops->commutative.push_back(strdup(tok));
1407 int main(int argc, char *argv[])
1409 struct options *options = options_new_with_defaults();
1410 struct isl_ctx *ctx;
1411 isl_set *context = NULL;
1412 dependence_graph *dg[2];
1414 argc = options_parse(options, argc, argv, ISL_ARG_ALL);
1415 parse_ops(options);
1417 ctx = isl_ctx_alloc_with_options(&options_args, options);
1418 if (!ctx) {
1419 fprintf(stderr, "Unable to allocate ctx\n");
1420 return -1;
1423 if (options->context)
1424 context = isl_set_read_from_str(ctx, options->context);
1426 pdg::PDG *pdg[2];
1427 unsigned out_dim = 0;
1428 for (int i = 0; i < 2; ++i) {
1429 FILE *in;
1430 in = fopen(options->program[i], "r");
1431 if (!in) {
1432 fprintf(stderr, "Unable to open %s\n", options->program[i]);
1433 return -1;
1435 pdg[i] = yaml::Load<pdg::PDG>(in, ctx);
1436 fclose(in);
1437 if (!pdg[i]) {
1438 fprintf(stderr, "Unable to read %s\n", options->program[i]);
1439 return -1;
1441 out_dim = update_out_dim(pdg[i], out_dim);
1443 if (context &&
1444 pdg[i]->params.size() != isl_set_dim(context, isl_dim_param)) {
1445 fprintf(stdout,
1446 "Parameter dimension mismatch; context ignored\n");
1447 isl_set_free(context);
1448 context = NULL;
1452 for (int i = 0; i < 2; ++i) {
1453 dg[i] = pdg_to_dg(pdg[i], out_dim, isl_set_copy(context));
1454 pdg[i]->free();
1455 delete pdg[i];
1458 dump_graphs(dg, options);
1460 int res = check_equivalence(ctx, dg[0], dg[1], options);
1462 for (int i = 0; i < 2; ++i)
1463 delete dg[i];
1464 isl_set_free(context);
1465 isl_ctx_free(ctx);
1467 return res;