17 #include <isl/space.h>
20 #include <isl/constraint.h>
21 #include <isl/union_set.h>
22 #include <isl/printer.h>
23 #include "eqv_options.h"
24 #include "dependence_graph.h"
27 std::vector
<const char *> associative
;
28 std::vector
<const char *> commutative
;
30 int ops_init(void *user
)
32 struct ops
**ops
= (struct ops
**)user
;
33 *ops
= new struct ops
;
36 void ops_clear(void *user
)
38 struct ops
**ops
= (struct ops
**)user
;
42 /* For each basic map in relation, add an edge to comp with
43 * given source and pos.
45 static void add_split_edge(computation
*comp
, computation
*source
,
46 int pos
, isl_map
*relation
, std::vector
<computation
**> *missing
,
47 enum edge::type type
= edge::normal
)
53 missing
->push_back(&e
->source
);
56 e
->relation
= relation
;
58 comp
->edges
.push_back(e
);
61 static unsigned update_out_dim(pdg::PDG
*pdg
, unsigned out_dim
)
63 for (int i
= 0; i
< pdg
->arrays
.size(); ++i
) {
64 pdg::array
*array
= pdg
->arrays
[i
];
65 if (array
->type
!= pdg::array::output
)
67 if (array
->dims
.size() + 1 > out_dim
)
68 out_dim
= array
->dims
.size() + 1;
74 static bool is_associative(const std::vector
<const char *> &associative
,
77 std::vector
<const char *>::const_iterator iter
;
79 for (iter
= associative
.begin(); iter
!= associative
.end(); ++iter
)
80 if (!strcmp(*iter
, op
))
85 /* Return a computation that has an associative operation
86 * that takes a different computation with the same operation
87 * as one of its arguments. Also return the edge from the
88 * first to the second computation in *e.
90 * If no such computation exists, then return NULL.
92 computation
*dependence_graph::associative_node(edge
**e
,
93 const std::vector
<const char *> &associative
)
95 for (int i
= 0; i
< vertices
.size(); ++i
) {
96 computation
*comp
= vertices
[i
];
97 if (!is_associative(associative
, comp
->operation
))
99 for (int j
= 0; j
< comp
->edges
.size(); ++j
) {
100 computation
*other
= comp
->edges
[j
]->source
;
101 if (comp
->has_same_source(other
))
103 if (strcmp(comp
->operation
, other
->operation
))
112 /* Splice the source of edge e into the position of edge e,
113 * removing all other edges with the same position and bumping
114 * up edges with a greater position.
116 void dependence_graph::splice(computation
*comp
, edge
*e
)
118 computation
*source
= e
->source
;
120 isl_map
*map
= isl_map_copy(e
->relation
);
121 std::vector
<struct edge
*> new_edges
;
123 for (int i
= 0; i
< comp
->edges
.size(); ++i
) {
124 if (comp
->edges
[i
]->pos
== pos
) {
125 delete comp
->edges
[i
];
128 if (comp
->edges
[i
]->pos
> pos
)
129 comp
->edges
[i
]->pos
+= source
->arity
- 1;
130 new_edges
.push_back(comp
->edges
[i
]);
132 for (int i
= 0; i
< source
->edges
.size(); ++i
) {
133 edge
*old_e
= source
->edges
[i
];
135 e
->source
= old_e
->source
;
136 e
->pos
= old_e
->pos
+ pos
;
137 e
->relation
= isl_map_apply_range(
139 isl_map_copy(old_e
->relation
));
140 new_edges
.push_back(e
);
143 comp
->edges
= new_edges
;
144 comp
->arity
+= source
->arity
- 1;
147 /* Split computation comp into a part that always takes the edge e
148 * and one that never takes that edge.
149 * The second is returned and the initial computation is modified
150 * to match the first.
152 * We need to be careful not to destroy e, as it is still used
153 * by the calling method.
155 computation
*dependence_graph::split_comp(computation
*comp
, edge
*e
)
157 computation
*dup
= new computation
;
158 dup
->original
= comp
->original
? comp
->original
: comp
;
159 dup
->operation
= strdup(comp
->operation
);
160 dup
->arity
= comp
->arity
;
161 dup
->location
= comp
->location
;
163 isl_map
*map
= isl_map_copy(e
->relation
);
164 map
= isl_map_reverse(map
);
165 isl_set
*dom
= isl_set_apply(isl_set_copy(e
->source
->domain
), map
);
166 dup
->domain
= isl_set_subtract(isl_set_copy(comp
->domain
),
168 comp
->domain
= isl_set_intersect(comp
->domain
, dom
);
170 std::vector
<struct edge
*> old_edges
= comp
->edges
;
173 for (int i
= 0; i
< old_edges
.size(); ++i
) {
174 if (old_edges
[i
] == e
) {
175 comp
->edges
.push_back(e
);
179 edge
*e
= old_edges
[i
];
180 isl_map
*map
, *map_dup
;
181 map
= isl_map_copy(e
->relation
);
182 map_dup
= isl_map_copy(map
);
184 map
= isl_map_intersect_domain(map
, isl_set_copy(comp
->domain
));
185 map_dup
= isl_map_intersect_domain(map_dup
,
186 isl_set_copy(dup
->domain
));
187 add_split_edge(comp
, e
->source
, e
->pos
, map
, NULL
);
188 add_split_edge(dup
, e
->source
, e
->pos
, map_dup
, NULL
);
192 vertices
.push_back(dup
);
196 /* If any edge from comp points to comp_orig, then split it
197 * into two edges, one still pointing to comp_orig and the
198 * other pointing to comp_dup.
199 * comp_orig and comp_dup are assumed to have disjoint domains
200 * and the edge relations are adjusted according to these domains.
202 void dependence_graph::split_edges(computation
*comp
,
203 computation
*comp_orig
, computation
*comp_dup
)
205 std::vector
<struct edge
*> old_edges
= comp
->edges
;
208 for (int i
= 0; i
< old_edges
.size(); ++i
) {
209 edge
*e
= old_edges
[i
];
211 if (e
->source
!= comp_orig
) {
212 comp
->edges
.push_back(e
);
216 isl_map
*map_orig
, *map_dup
;
217 map_orig
= isl_map_copy(e
->relation
);
218 map_dup
= isl_map_copy(map_orig
);
219 map_orig
= isl_map_intersect_range(map_orig
,
220 isl_set_copy(comp_orig
->domain
));
221 map_dup
= isl_map_intersect_range(map_dup
,
222 isl_set_copy(comp_dup
->domain
));
223 add_split_edge(comp
, comp_orig
, e
->pos
, map_orig
, NULL
);
224 add_split_edge(comp
, comp_dup
, e
->pos
, map_dup
, NULL
);
229 void dependence_graph::split_edges(computation
*comp_orig
,
230 computation
*comp_dup
)
232 split_edges(out
, comp_orig
, comp_dup
);
233 for (int i
= 0; i
< vertices
.size(); ++i
)
234 split_edges(vertices
[i
], comp_orig
, comp_dup
);
237 /* Replace all nested calls of an associative operator,
238 * by a call with the nested call spliced into the first call.
240 * For each nested call we find, we first check if there are
241 * any other edges with the same position. If not, then
242 * the nested call is performed for each iteration of the computation
243 * and we can simply splice the nested call.
245 * Otherwise, we first create a duplicate computation for the iterations
246 * that do not take the found edge and adjust the edges of both computations
247 * to their domains. This means that the edge corresponding to the
248 * nested call will no longer appear in the duplicated computation
249 * and the other edges with the same position will no longer appear
250 * in the original computation.
251 * Then we splice the nested call in the original computation.
252 * Finally, we split all edges that pointed to the original computation
253 * into two edges, one going to the original computation and one
254 * going to the duplicated computation.
256 void dependence_graph::flatten_associative_operators(
257 const std::vector
<const char *> &associative
)
262 while ((comp
= associative_node(&e
, associative
)) != NULL
) {
263 computation
*comp_dup
= NULL
;
265 for (j
= 0; j
< comp
->edges
.size(); ++j
) {
266 if (comp
->edges
[j
] == e
)
268 if (comp
->edges
[j
]->pos
== e
->pos
)
271 if (j
!= comp
->edges
.size())
272 comp_dup
= split_comp(comp
, e
);
275 split_edges(comp
, comp_dup
);
284 computation
*comp
[2];
287 std::set
<eq_node
*> assumed
;
289 unsigned narrowing
: 1;
290 unsigned widening
: 1;
291 unsigned invalidated
: 1;
294 isl_map
*lost_sample
;
296 eq_node(computation
*c1
, computation
*c2
,
297 isl_map
*w
) : want(w
), closed(0),
298 widening(0), narrowing(0), invalidated(0), reset(0),
299 need(NULL
), got(NULL
), lost(NULL
), lost_sample(NULL
) {
303 bool is_still_valid();
304 void collect_open_assumed(std::set
<eq_node
*> &c
);
310 isl_map_free(lost_sample
);
314 got
= isl_map_copy(want
);
315 got
= isl_map_subtract(got
, isl_map_copy(lost
));
320 return isl_map_copy(got
);
322 isl_map
*peek_got() {
327 void compute_lost() {
329 lost
= isl_map_copy(want
);
330 lost
= isl_map_subtract(lost
, isl_map_copy(got
));
332 isl_map
*get_lost() {
335 return isl_map_copy(lost
);
337 isl_map
*peek_lost() {
342 void set_got(isl_map
*got
) {
343 isl_map_free(this->lost
);
344 isl_map_free(this->got
);
348 void set_lost(isl_map
*lost
) {
349 isl_map_free(this->lost
);
350 isl_map_free(this->got
);
354 std::string
to_string();
357 std::string
eq_node::to_string()
359 std::ostringstream strm
;
360 strm
<< comp
[0]->location
<< "," << comp
[0]->operation
361 << "/" << comp
[0]->arity
;
363 strm
<< comp
[1]->location
<< "," << comp
[1]->operation
364 << "/" << comp
[1]->arity
;
369 bool eq_node::is_still_valid()
374 std::set
<eq_node
*>::iterator i
;
375 for (i
= assumed
.begin(); i
!= assumed
.end(); ++i
) {
378 ((*i
)->closed
&& !(*i
)->is_still_valid())) {
386 void eq_node::collect_open_assumed(std::set
<eq_node
*> &c
)
388 std::set
<eq_node
*>::iterator i
;
389 for (i
= assumed
.begin(); i
!= assumed
.end(); ++i
) {
391 (*i
)->collect_open_assumed(c
);
397 /* A comp_pair contains all the edges that have the same pair
401 computation
*comp
[2];
403 std::vector
<eq_node
*> nodes
;
405 eq_node
*tabled(eq_node
*node
);
406 eq_node
*last_ancestor(eq_node
*node
);
410 comp_pair::~comp_pair()
412 std::vector
<eq_node
*>::iterator i
;
413 for (i
= nodes
.begin(); i
!= nodes
.end(); ++i
)
417 eq_node
*comp_pair::tabled(eq_node
*node
)
419 std::vector
<eq_node
*>::iterator i
;
421 for (i
= nodes
.begin(); i
!= nodes
.end(); ++i
) {
426 if (!(*i
)->is_still_valid())
429 is_subset
= isl_map_is_subset(node
->want
, (*i
)->want
);
430 assert(is_subset
>= 0);
438 eq_node
*comp_pair::last_ancestor(eq_node
*node
)
440 std::vector
<eq_node
*>::reverse_iterator i
;
441 for (i
= nodes
.rbegin(); i
!= nodes
.rend(); ++i
) {
451 typedef std::pair
<computation
*, computation
*> computation_pair
;
452 typedef std::map
<computation_pair
, comp_pair
*> c2p_t
;
454 struct equivalence_checker
{
456 struct options
*options
;
460 equivalence_checker(struct options
*o
) :
461 options(o
), widenings(0), narrowings(0) {}
462 comp_pair
*get_comp_pair(eq_node
*node
);
463 void handle(eq_node
*node
);
464 void dismiss(eq_node
*node
);
465 void handle_propagation(eq_node
*node
);
466 void handle_copy(eq_node
*node
, int i
);
467 void handle_widening(eq_node
*node
, eq_node
*a
);
468 void handle_with_ancestor(eq_node
*node
, eq_node
*a
);
469 isl_map
*lost_at_pos(eq_node
*node
, int pos1
, int pos2
);
470 void handle_propagation_same(eq_node
*node
);
471 void handle_propagation_comm(eq_node
*node
);
472 void handle_narrowing(eq_node
*node
);
474 isl_map
*lost_from_propagation(eq_node
*parent
,
475 eq_node
*node
, edge
*e1
, edge
*e2
);
477 void init_trace(eq_node
*node
);
478 void extend_trace_propagation(eq_node
*node
, eq_node
*child
,
480 void extend_trace_widening(eq_node
*node
, eq_node
*child
);
482 bool is_commutative(const char *op
);
484 ~equivalence_checker();
487 equivalence_checker::~equivalence_checker()
491 for (i
= c2p
.begin(); i
!= c2p
.end(); ++i
)
495 bool equivalence_checker::is_commutative(const char *op
)
497 std::vector
<const char *>::iterator iter
;
499 for (iter
= options
->ops
->commutative
.begin();
500 iter
!= options
->ops
->commutative
.end(); ++iter
)
501 if (!strcmp(*iter
, op
))
506 comp_pair
*equivalence_checker::get_comp_pair(eq_node
*node
)
511 i
= c2p
.find(computation_pair(node
->comp
[0], node
->comp
[1]));
512 if (i
== c2p
.end()) {
514 c2p
[computation_pair(node
->comp
[0], node
->comp
[1])] = cp
;
520 void equivalence_checker::dismiss(eq_node
*node
)
522 /* unclosed nodes weren't added to c2p->nodes */
523 if (node
&& !node
->closed
)
527 void equivalence_checker::init_trace(eq_node
*node
)
532 if (!options
->trace_error
)
534 if (isl_map_is_empty(node
->peek_lost()))
536 node
->trace
= node
->to_string();
537 std::cerr
<< node
->trace
;
538 isl_map_free(node
->lost_sample
);
539 node
->lost_sample
= isl_map_from_basic_map(
540 isl_map_sample(node
->get_lost()));
541 ctx
= isl_map_get_ctx(node
->lost_sample
);
542 prn
= isl_printer_to_file(ctx
, stderr
);
543 prn
= isl_printer_print_map(prn
, node
->lost_sample
);
544 prn
= isl_printer_end_line(prn
);
545 isl_printer_free(prn
);
548 void equivalence_checker::extend_trace_propagation(eq_node
*node
,
549 eq_node
*child
, edge
*e1
, edge
*e2
)
554 if (!options
->trace_error
)
556 if (!child
->lost_sample
)
558 if (node
->lost_sample
)
560 node
->trace
= child
->trace
+ node
->to_string();
561 std::cerr
<< node
->trace
;
562 node
->lost_sample
= isl_map_copy(child
->lost_sample
);
564 node
->lost_sample
= isl_map_apply_domain(node
->lost_sample
,
566 isl_map_copy(e1
->relation
)));
568 node
->lost_sample
= isl_map_apply_range(node
->lost_sample
,
570 isl_map_copy(e2
->relation
)));
571 node
->lost_sample
= isl_map_intersect(node
->lost_sample
,
572 isl_map_copy(node
->want
));
573 node
->lost_sample
= isl_map_from_basic_map(
574 isl_map_sample(node
->lost_sample
));
575 ctx
= isl_map_get_ctx(node
->lost_sample
);
576 prn
= isl_printer_to_file(ctx
, stderr
);
577 prn
= isl_printer_print_map(prn
, node
->lost_sample
);
578 prn
= isl_printer_end_line(prn
);
579 isl_printer_free(prn
);
582 void equivalence_checker::extend_trace_widening(eq_node
*node
, eq_node
*child
)
587 if (!options
->trace_error
)
589 if (!child
->lost_sample
)
591 if (isl_map_is_empty(node
->peek_lost()))
593 node
->trace
= child
->trace
+ node
->to_string();
594 std::cerr
<< node
->trace
;
595 node
->lost_sample
= isl_map_from_basic_map(
596 isl_map_sample(node
->get_lost()));
597 ctx
= isl_map_get_ctx(node
->lost_sample
);
598 prn
= isl_printer_to_file(ctx
, stderr
);
599 prn
= isl_printer_print_map(prn
, node
->lost_sample
);
600 prn
= isl_printer_end_line(prn
);
601 isl_printer_free(prn
);
604 /* Check for which subset (got) of the want relation equivalence
605 * holds for the pair of computations in the equivalence node.
607 * We first handle the easy cases: empty want, input computations
610 * If the current node is not a narrowing or a widening node
611 * and we can find an ancestor with the same pair of computations,
612 * then we will try to apply induction or widening in handle_with_ancestor.
613 * However, if the requested relation (want) of the ancestor is a strict
614 * subset of that of the current node, then we have already applied
615 * widening on an intermediate node (with a different pair of computations)
616 * so we shouldn't apply widening again (and we can't apply induction
617 * because the relation of the ancestor is a strict subset).
619 * In all other cases we try to apply propagation in handle_propagation.
621 void equivalence_checker::handle(eq_node
*node
)
623 computation
*comp1
= node
->comp
[0];
624 computation
*comp2
= node
->comp
[1];
626 if (!comp1
->is_copy() && !comp2
->is_copy() &&
627 (strcmp(comp1
->operation
, comp2
->operation
) ||
628 comp1
->arity
!= comp2
->arity
)) {
629 node
->set_lost(isl_map_copy(node
->want
));
637 isl_map
*want
= node
->want
;
638 node
->need
= isl_map_empty(isl_map_get_space(want
));
639 int empty
= isl_map_is_empty(want
);
642 node
->set_lost(isl_map_empty(isl_map_get_space(want
)));
645 if (node
->comp
[0]->is_input() && node
->comp
[1]->is_input()) {
646 if (strcmp(node
->comp
[0]->operation
, node
->comp
[1]->operation
)) {
647 node
->set_lost(isl_map_copy(want
));
650 node
->set_lost(isl_map_subtract(
651 isl_map_copy(node
->want
),
652 isl_map_identity(isl_map_get_space(want
))));
656 comp_pair
*cp
= get_comp_pair(node
);
657 if ((s
= cp
->tabled(node
)) != NULL
) {
658 node
->set_lost(isl_map_intersect(s
->get_lost(),
659 isl_map_copy(node
->want
)));
660 s
->collect_open_assumed(node
->assumed
);
664 cp
->nodes
.push_back(node
);
666 if (!node
->narrowing
&& !node
->widening
&&
667 (a
= cp
->last_ancestor(node
)) != NULL
) {
669 is_subset
= isl_map_is_strict_subset(a
->want
, node
->want
);
670 assert(is_subset
>= 0);
672 handle_propagation(node
);
674 handle_with_ancestor(node
, a
);
676 handle_propagation(node
);
680 /* Check if we can apply propagation to prove equivalence of the given node.
681 * First check if we need to apply copy propagation and if not
682 * check if the operations are the same and apply "regular" propagation.
684 * After we get back, we need to check that any induction hypotheses
685 * we have used in the process of proving the node hold.
686 * If not, we replace the obtained relation (got) by that
687 * of a narrowing node in handle_narrowing.
689 void equivalence_checker::handle_propagation(eq_node
*node
)
691 if (node
->comp
[0]->is_copy())
692 handle_copy(node
, 0);
693 else if (node
->comp
[1]->is_copy())
694 handle_copy(node
, 1);
696 handle_propagation_same(node
);
697 node
->assumed
.erase(node
);
698 isl_map
*lost
= node
->get_lost();
699 lost
= isl_map_intersect(lost
, isl_map_copy(node
->need
));
700 int is_empty
= isl_map_is_empty(lost
);
702 assert(is_empty
>= 0);
704 handle_narrowing(node
);
707 /* When applying the mappings on expansion edges to both sides of
708 * an equivalence relation, each element in the original equivalence
709 * relation is mapped to many elements on both sides. For example,
710 * if the original equivalence relation has i R i' for i = i', then the new
711 * equivalence relation may have (i,j) R (i',j') for i = i' and
712 * 0 <= j,j' <= 10, expressing that all elements of the row read by
713 * iteration i should be equal to all elements of the row read by i'.
714 * Instead, we want to express that each individual element read by i
715 * should be equal to the corresponding element read by i'.
716 * In the example, we need to introduce the equality j = j'.
718 * We first check that the number of dimensions added by the expansions
719 * is the same in both programs. Then we construct an identity relation
720 * between this number of dimensions, lift it to the space of the
721 * new equivalence relation and take the intersection.
723 * We have to be careful, though, that we don't lose any array elements
724 * by taking the intersection. In particular, the expansion maps
725 * a given read operation to iterations of an extended domain that
726 * reads all the individual array elements. The read operation is
727 * only equivalent to some other read operation if all the reads
728 * of the individual array elements are equivalent.
729 * We therefore need to make sure that adding the equalities does
730 * not remove any of the reads in the extended domain.
731 * We do this by projecting out the extra dimensions on one side
732 * from both "want" and "want \cap eq". The resulting maps should
733 * be the same. We do this check for both sides separately.
735 * If either of the above tests fails, then we simply return
736 * the original over-ambitious want.
738 static isl_map
*expansion_want(isl_map
*want
, edge
*e1
, edge
*e2
)
740 unsigned n_in
, n_out
;
741 unsigned s_dim_1
= isl_map_dim(e1
->relation
, isl_dim_out
) -
742 isl_map_dim(e1
->relation
, isl_dim_in
);
743 unsigned s_dim_2
= isl_map_dim(e2
->relation
, isl_dim_out
) -
744 isl_map_dim(e2
->relation
, isl_dim_in
);
745 if (s_dim_1
!= s_dim_2
)
748 isl_space
*dim
= isl_map_get_space(e1
->relation
);
749 dim
= isl_space_drop_dims(dim
, isl_dim_in
,
750 0, isl_space_dim(dim
, isl_dim_in
));
751 dim
= isl_space_drop_dims(dim
, isl_dim_out
,
752 0, isl_space_dim(dim
, isl_dim_out
));
753 dim
= isl_space_add_dims(dim
, isl_dim_in
, s_dim_1
);
754 dim
= isl_space_add_dims(dim
, isl_dim_out
, s_dim_1
);
755 isl_basic_map
*id
= isl_basic_map_identity(dim
);
757 dim
= isl_space_range(isl_map_get_space(e1
->relation
));
758 isl_basic_map
*s_1
= isl_basic_map_identity(isl_space_map_from_set(dim
));
759 s_1
= isl_basic_map_remove_dims(s_1
, isl_dim_in
, 0,
760 isl_basic_map_dim(s_1
, isl_dim_in
) - s_dim_1
);
761 id
= isl_basic_map_apply_domain(id
, s_1
);
763 dim
= isl_space_range(isl_map_get_space(e2
->relation
));
764 isl_basic_map
*s_2
= isl_basic_map_identity(isl_space_map_from_set(dim
));
765 s_2
= isl_basic_map_remove_dims(s_2
, isl_dim_in
, 0,
766 isl_basic_map_dim(s_2
, isl_dim_in
) - s_dim_2
);
767 id
= isl_basic_map_apply_range(id
, s_2
);
769 bool unmatched
= false;
770 isl_map
*matched_want
;
771 isl_map
*proj_want
, *proj_matched
;
772 matched_want
= isl_map_intersect(isl_map_copy(want
),
773 isl_map_from_basic_map(id
));
775 n_in
= isl_map_dim(want
, isl_dim_in
);
776 proj_want
= isl_map_remove_dims(isl_map_copy(want
),
777 isl_dim_in
, n_in
- s_dim_1
, s_dim_1
);
778 proj_matched
= isl_map_remove_dims(isl_map_copy(matched_want
),
779 isl_dim_in
, n_in
- s_dim_1
, s_dim_1
);
780 if (!isl_map_is_equal(proj_want
, proj_matched
))
782 isl_map_free(proj_want
);
783 isl_map_free(proj_matched
);
785 n_out
= isl_map_dim(want
, isl_dim_out
);
786 proj_want
= isl_map_remove_dims(isl_map_copy(want
),
787 isl_dim_out
, n_out
- s_dim_2
, s_dim_2
);
788 proj_matched
= isl_map_remove_dims(isl_map_copy(matched_want
),
789 isl_dim_out
, n_out
- s_dim_2
, s_dim_2
);
790 if (!isl_map_is_equal(proj_want
, proj_matched
))
792 isl_map_free(proj_want
);
793 isl_map_free(proj_matched
);
796 isl_map_free(matched_want
);
803 static isl_map
*propagation_want(eq_node
*node
, edge
*e1
, edge
*e2
)
807 new_want
= isl_map_copy(node
->want
);
809 new_want
= isl_map_apply_domain(new_want
,
810 isl_map_copy(e1
->relation
));
812 new_want
= isl_map_apply_range(new_want
,
813 isl_map_copy(e2
->relation
));
815 if (e1
&& e2
&& e1
->type
== edge::expansion
)
816 new_want
= expansion_want(new_want
, e1
, e2
);
818 return isl_map_detect_equalities(new_want
);
821 static eq_node
*propagation_node(eq_node
*node
, edge
*e1
, edge
*e2
)
824 computation
*comp1
= e1
? e1
->source
: node
->comp
[0];
825 computation
*comp2
= e2
? e2
->source
: node
->comp
[1];
827 new_want
= propagation_want(node
, e1
, e2
);
828 eq_node
*child
= new eq_node(comp1
, comp2
, new_want
);
832 static isl_map
*set_to_empty(isl_map
*map
)
835 space
= isl_map_get_space(map
);
837 return isl_map_empty(space
);
840 /* Propagate "lost" from child back to parent.
842 isl_map
*equivalence_checker::lost_from_propagation(eq_node
*parent
,
843 eq_node
*node
, edge
*e1
, edge
*e2
)
846 new_lost
= node
->get_lost();
849 new_lost
= isl_map_apply_domain(new_lost
,
850 isl_map_reverse(isl_map_copy(e1
->relation
)));
852 new_lost
= isl_map_apply_range(new_lost
,
853 isl_map_reverse(isl_map_copy(e2
->relation
)));
855 new_lost
= isl_map_intersect(new_lost
, isl_map_copy(parent
->want
));
857 extend_trace_propagation(parent
, node
, e1
, e2
);
862 void equivalence_checker::handle_copy(eq_node
*node
, int i
)
864 std::vector
<edge
*>::iterator iter
;
865 computation
*copy
= node
->comp
[i
];
866 isl_map
*want
= node
->want
;
869 lost
= isl_map_empty(isl_map_get_space(want
));
871 for (iter
= copy
->edges
.begin(); iter
!= copy
->edges
.end(); ++iter
) {
875 child
= propagation_node(node
, e
, NULL
);
877 child
= propagation_node(node
, NULL
, e
);
886 new_lost
= lost_from_propagation(node
, child
, e
, NULL
);
888 new_lost
= lost_from_propagation(node
, child
, NULL
, e
);
891 lost
= isl_map_union_disjoint(lost
, new_lost
);
893 node
->set_lost(lost
);
896 /* Compute and return the part of "want" that is lost when propagating
897 * over all edges of argument position pos1 in the first program
898 * and argument position pos2 in the second program.
899 * Since the domains of the dependence mappings on edges with the same
900 * argument position partition the domain of the computation (and are
901 * therefore disjoint), we simply need to take the disjoint union
902 * of all losts over all pairs of edges.
904 isl_map
*equivalence_checker::lost_at_pos(eq_node
*node
, int pos1
, int pos2
)
907 lost
= isl_map_empty(isl_map_get_space(node
->want
));
909 for (int i
= 0; i
< node
->comp
[0]->edges
.size(); ++i
) {
910 edge
*e1
= node
->comp
[0]->edges
[i
];
913 for (int j
= 0; j
< node
->comp
[1]->edges
.size(); ++j
) {
914 edge
*e2
= node
->comp
[1]->edges
[j
];
919 isl_map
*new_lost
= NULL
;
921 child
= propagation_node(node
, e1
, e2
);
923 new_lost
= lost_from_propagation(node
, child
, e1
, e2
);
925 lost
= isl_map_union_disjoint(lost
, new_lost
);
927 std::set
<eq_node
*>::iterator k
;
928 for (k
= child
->assumed
.begin();
929 k
!= child
->assumed
.end(); ++k
)
930 node
->assumed
.insert(*k
);
937 /* Compute the lost that results from propagation on a pair of
938 * computations with a commutative operation.
940 * We first compute the losts for each pair of argument positions
941 * and store the result in lost[pos1][pos2].
942 * Then we perform a backtracking search over all permutations
943 * of the arguments. For each permutation, we compute the lost
944 * relation as the union of the losts over all arguments.
945 * The final lost is the intersection of all these losts over all
948 void equivalence_checker::handle_propagation_comm(eq_node
*node
)
950 int trace_error
= options
->trace_error
;
953 unsigned r
= node
->comp
[0]->arity
;
955 std::vector
<std::vector
<isl_map
*> > lost
;
957 for (int i
= 0; i
< r
; ++i
) {
958 std::vector
<isl_map
*> row
;
959 for (int j
= 0; j
< r
; ++j
) {
961 pos_lost
= lost_at_pos(node
, i
, j
);
962 row
.push_back(pos_lost
);
968 std::vector
<int> perm
;
969 std::vector
<isl_map
*> lost_at
;
970 for (level
= 0; level
< r
; ++level
) {
972 lost_at
.push_back(NULL
);
976 total_lost
= isl_map_copy(node
->want
);
980 if (perm
[level
] == r
) {
988 for (l
= 0; l
< level
; ++l
)
989 if (perm
[l
] == perm
[level
])
996 isl_map_free(lost_at
[level
]);
997 lost_at
[level
] = isl_map_copy(lost
[level
][perm
[level
]]);
999 lost_at
[level
] = isl_map_union(lost_at
[level
],
1000 isl_map_copy(lost_at
[level
- 1]));
1002 if (level
< r
- 1) {
1007 lost_at
[level
] = isl_map_coalesce(lost_at
[level
]);
1008 total_lost
= isl_map_intersect(total_lost
,
1009 isl_map_copy(lost_at
[level
]));
1013 for (int i
= 0; i
< r
; ++i
)
1014 isl_map_free(lost_at
[i
]);
1016 for (int i
= 0; i
< r
; ++i
)
1017 for (int j
= 0; j
< r
; ++j
)
1018 isl_map_free(lost
[i
][j
]);
1020 node
->set_lost(total_lost
);
1022 options
->trace_error
= trace_error
;
1026 /* Compute the lost that results from propagation on a pair of
1029 * First, for functions of zero arity (i.e., constants), equivalence
1030 * always holds and the lost relation is empty.
1031 * For commutative operations, the computation is delegated
1032 * to handle_propagation_comm.
1033 * Otherwise, we simply take the union of the losts over each
1034 * argument position (always taking the same argument position
1035 * in both programs).
1037 void equivalence_checker::handle_propagation_same(eq_node
*node
)
1039 unsigned r
= node
->comp
[0]->arity
;
1041 node
->set_lost(isl_map_empty(isl_map_get_space(node
->want
)));
1044 if (is_commutative(node
->comp
[0]->operation
)) {
1045 handle_propagation_comm(node
);
1050 lost
= isl_map_empty(isl_map_get_space(node
->want
));
1051 for (int i
= 0; i
< r
; ++i
) {
1053 pos_lost
= lost_at_pos(node
, i
, i
);
1054 lost
= isl_map_union(lost
, pos_lost
);
1056 lost
= isl_map_coalesce(lost
);
1057 node
->set_lost(lost
);
1060 void equivalence_checker::handle_widening(eq_node
*node
, eq_node
*a
)
1065 wants
= isl_map_union(isl_map_copy(node
->want
), isl_map_copy(a
->want
));
1066 aff
= isl_map_from_basic_map(isl_map_affine_hull(wants
));
1067 aff
= isl_map_intersect_domain(aff
,
1068 isl_set_copy(node
->comp
[0]->domain
));
1069 aff
= isl_map_intersect_range(aff
,
1070 isl_set_copy(node
->comp
[1]->domain
));
1072 eq_node
*child
= new eq_node(node
->comp
[0], node
->comp
[1], aff
);
1073 child
->widening
= 1;
1076 node
->set_lost(isl_map_intersect(child
->get_lost(),
1077 isl_map_copy(node
->want
)));
1078 extend_trace_widening(node
, child
);
1079 std::set
<eq_node
*>::iterator i
;
1080 for (i
= child
->assumed
.begin(); i
!= child
->assumed
.end(); ++i
)
1081 node
->assumed
.insert(*i
);
1085 /* Perform induction, if possible, and widening if we have to.
1087 void equivalence_checker::handle_with_ancestor(eq_node
*node
, eq_node
*a
)
1089 if (a
->narrowing
|| isl_map_is_subset(node
->want
, a
->want
)) {
1091 need
= isl_map_intersect(isl_map_copy(node
->want
),
1092 isl_map_copy(a
->want
));
1093 node
->set_lost(isl_map_subtract(isl_map_copy(node
->want
),
1094 isl_map_copy(a
->want
)));
1095 node
->assumed
.insert(a
);
1096 a
->need
= isl_map_union(a
->need
, need
);
1098 handle_widening(node
, a
);
1102 struct narrowing_data
{
1104 equivalence_checker
*ec
;
1108 static isl_stat
basic_handle_narrowing(__isl_take isl_basic_map
*bmap
,
1111 narrowing_data
*data
= (narrowing_data
*)user
;
1114 child
= new eq_node(data
->node
->comp
[0], data
->node
->comp
[1],
1115 isl_map_from_basic_map(bmap
));
1116 child
->narrowing
= 1;
1117 data
->ec
->narrowings
++;
1118 data
->ec
->handle(child
);
1119 data
->new_got
= isl_map_union(data
->new_got
, child
->get_got());
1121 std::set
<eq_node
*>::iterator i
;
1122 for (i
= child
->assumed
.begin(); i
!= child
->assumed
.end(); ++i
)
1123 data
->node
->assumed
.insert(*i
);
1124 data
->ec
->dismiss(child
);
1129 /* Construct and handle narrowing nodes for the given node.
1131 * If the node itself was already a narrowing node, then we
1132 * simply return the empty relation.
1133 * Otherwise, we consider each basic relation in the obtained
1134 * relation, construct a new node with that basic relation as
1135 * requested relation and take the union of all obtained relations.
1137 void equivalence_checker::handle_narrowing(eq_node
*node
)
1140 node
->assumed
.clear();
1142 isl_map
*want
= node
->want
;
1143 isl_map
*new_got
= isl_map_empty(isl_map_get_space(want
));
1144 if (!options
->narrowing
|| node
->narrowing
) {
1145 node
->set_got(new_got
);
1149 narrowing_data data
= { node
, this, new_got
};
1150 isl_map_foreach_basic_map(node
->peek_got(), &basic_handle_narrowing
,
1152 node
->set_got(data
.new_got
);
1155 static __isl_give isl_union_set
*update_result(__isl_take isl_union_set
*res
,
1156 const char *array_name
, isl_map
*map
, int first
, int n
)
1158 if (isl_map_is_empty(map
)) {
1163 isl_set
*range
= isl_map_range(map
);
1164 range
= isl_set_remove_dims(range
, isl_dim_set
, first
, n
);
1165 range
= isl_set_coalesce(range
);
1166 range
= isl_set_set_tuple_name(range
, array_name
);
1167 res
= isl_union_set_add_set(res
, range
);
1172 struct check_equivalence_data
{
1173 dependence_graph
*dg1
;
1174 dependence_graph
*dg2
;
1175 equivalence_checker
*ec
;
1180 static isl_stat
basic_check(__isl_take isl_basic_map
*bmap
, void *user
)
1182 check_equivalence_data
*data
= (check_equivalence_data
*)user
;
1184 eq_node
*root
= new eq_node(data
->dg1
->out
, data
->dg2
->out
,
1185 isl_map_from_basic_map(bmap
));
1186 data
->ec
->handle(root
);
1187 data
->got
= isl_map_union_disjoint(data
->got
, root
->get_got());
1188 data
->lost
= isl_map_union_disjoint(data
->lost
, root
->get_lost());
1189 data
->ec
->dismiss(root
);
/* Check the equivalence of output array "array1" of "dg1" and
 * output array "array2" of "dg2" using "ec", and add the outcome
 * to "*proved" and "*not_proved" through update_result.
 *
 * The last set dimension of each output domain is first fixed to the
 * index of the array and then removed, after which the two resulting
 * sets are compared with isl_set_is_equal.  The diagnostic
 * "different output domains for array ..." belongs to that comparison.
 *
 * Next, a map is built from the identity on the space of "out2",
 * with its domain restricted to "out1" and its range to "out2".
 * basic_check is run on each basic map of this map, accumulating
 * its results in data.got and data.lost.
 *
 * NOTE(review): several statements of this function are not visible in
 * this view (the handling of "equal", the end of the fprintf call, the
 * declaration of "n_in" and the return statements) - confirm against
 * the complete file before changing any code here.
 */
1194 static int check_equivalence_array(isl_ctx
*ctx
, equivalence_checker
*ec
,
1195 dependence_graph
*dg1
, dependence_graph
*dg2
, int array1
, int array2
,
1196 isl_union_set
**proved
, isl_union_set
**not_proved
)
/* The name and the number of dimensions are taken from the first graph. */
1199 const char *array_name
= dg1
->output_arrays
[array1
];
1200 isl_set
*out1
= isl_set_copy(dg1
->out
->domain
);
1201 isl_set
*out2
= isl_set_copy(dg2
->out
->domain
);
1202 unsigned dim1
= isl_set_n_dim(out1
);
1203 unsigned dim2
= isl_set_n_dim(out2
);
1204 unsigned array_dim
= dg1
->output_array_dims
[array1
];
/* Select the array through the last set dimension ... */
1205 out1
= isl_set_fix_dim_si(out1
, dim1
-1, array1
);
1206 out2
= isl_set_fix_dim_si(out2
, dim2
-1, array2
);
/* ... and drop that dimension so that the two sets can be compared
 * even though array1 and array2 may be different indices.
 */
1207 out1
= isl_set_remove_dims(out1
, isl_dim_set
, dim1
- 1, 1);
1208 out2
= isl_set_remove_dims(out2
, isl_dim_set
, dim2
- 1, 1);
1209 int equal
= isl_set_is_equal(out1
, out2
);
/* NOTE(review): the statements between the comparison above and the
 * diagnostic below are not visible; presumably they release
 * out1/out2 and test "equal" - confirm.
 */
1214 fprintf(stderr
, "different output domains for array %s\n",
/* Recompute the output domains, this time keeping the last dimension. */
1219 out1
= isl_set_copy(dg1
->out
->domain
);
1220 out2
= isl_set_copy(dg2
->out
->domain
);
1221 out1
= isl_set_fix_dim_si(out1
, dim1
-1, array1
);
1222 out2
= isl_set_fix_dim_si(out2
, dim2
-1, array2
);
1223 out1
= isl_set_coalesce(out1
);
1224 out2
= isl_set_coalesce(out2
);
/* Start from the identity on the space of out2, remove its last input
 * dimension and apply a copy of the result to its own domain, before
 * restricting the domain to out1 and the range to out2.
 */
1225 isl_space
*dim
= isl_space_map_from_set(isl_set_get_space(out2
));
1226 isl_map
*id
= isl_map_identity(dim
);
1227 n_in
= isl_map_dim(id
, isl_dim_in
);
1228 id
= isl_map_remove_dims(id
, isl_dim_in
, n_in
- 1, 1);
1229 id
= isl_map_apply_domain(id
, isl_map_copy(id
));
1230 id
= isl_map_intersect_domain(id
, out1
);
1231 id
= isl_map_intersect_range(id
, out2
);
/* Both accumulators start out empty in the space of "id". */
1233 isl_map
*got
= isl_map_empty(isl_map_get_space(id
));
1234 isl_map
*lost
= isl_map_copy(got
);
1235 check_equivalence_data data
= { dg1
, dg2
, ec
, got
, lost
};
1236 isl_map_foreach_basic_map(id
, &basic_check
, &data
);
/* NOTE(review): no isl_map_free(id) is visible after the traversal;
 * confirm that "id" is released in the complete file.
 */
/* update_result is asked to remove the dim2 - array_dim dimensions
 * starting at position array_dim from the range of each result.
 */
1238 *proved
= update_result(*proved
,
1239 array_name
, data
.got
, array_dim
, dim2
- array_dim
);
1240 *not_proved
= update_result(*not_proved
,
1241 array_name
, data
.lost
, array_dim
, dim2
- array_dim
);
1245 /* The input arrays of the two programs are supposed to be the same,
1246 * so they should at least have the same dimension. Make sure
1247 * this is true, because we depend on it later on.
1249 static int check_input_arrays(dependence_graph
*dg1
, dependence_graph
*dg2
)
1251 for (int i
= 0; i
< dg1
->input_computations
.size(); ++i
)
1252 for (int j
= 0; j
< dg2
->input_computations
.size(); ++j
) {
1253 if (strcmp(dg1
->input_computations
[i
]->operation
,
1254 dg2
->input_computations
[j
]->operation
))
1256 if (dg1
->input_computations
[i
]->dim
==
1257 dg2
->input_computations
[j
]->dim
)
1260 "input arrays \"%s\" do not have the same dimension\n",
1261 dg1
->input_computations
[i
]->operation
);
1268 static __isl_give isl_union_set
*add_array(__isl_take isl_union_set
*only
,
1269 dependence_graph
*dg
, int array
)
1274 dim
= isl_union_set_get_space(only
);
1275 dim
= isl_space_add_dims(dim
, isl_dim_set
, dg
->output_array_dims
[array
]);
1276 dim
= isl_space_set_tuple_name(dim
, isl_dim_set
, dg
->output_arrays
[array
]);
1277 set
= isl_set_universe(dim
);
1278 only
= isl_union_set_add_set(only
, set
);
1283 static void print_results(const char *str
, __isl_keep isl_union_set
*only
)
1287 if (isl_union_set_is_empty(only
))
1290 fprintf(stdout
, "%s: '", str
);
1291 prn
= isl_printer_to_file(isl_union_set_get_ctx(only
), stdout
);
1292 prn
= isl_printer_print_union_set(prn
, only
);
1293 isl_printer_free(prn
);
1294 fprintf(stdout
, "'\n");
1297 static int check_equivalence(isl_ctx
*ctx
,
1298 dependence_graph
*dg1
, dependence_graph
*dg2
, options
*options
)
1302 isl_union_set
*only1
, *only2
, *proved
, *not_proved
;
1303 dg1
->flatten_associative_operators(options
->ops
->associative
);
1304 dg2
->flatten_associative_operators(options
->ops
->associative
);
1305 equivalence_checker
ec(options
);
1308 if (check_input_arrays(dg1
, dg2
))
1311 dim
= isl_space_set_alloc(ctx
, 0, 0);
1312 proved
= isl_union_set_empty(isl_space_copy(dim
));
1313 not_proved
= isl_union_set_empty(isl_space_copy(dim
));
1314 only1
= isl_union_set_empty(isl_space_copy(dim
));
1315 only2
= isl_union_set_empty(dim
);
1317 while (i1
< dg1
->output_arrays
.size() || i2
< dg2
->output_arrays
.size()) {
1319 cmp
= i1
== dg1
->output_arrays
.size() ? 1 :
1320 i2
== dg2
->output_arrays
.size() ? -1 :
1321 strcmp(dg1
->output_arrays
[i1
], dg2
->output_arrays
[i2
]);
1323 only1
= add_array(only1
, dg1
, i1
);
1325 } else if (cmp
> 0) {
1326 only2
= add_array(only2
, dg2
, i2
);
1329 check_equivalence_array(ctx
, &ec
, dg1
, dg2
, i1
, i2
,
1330 &proved
, ¬_proved
);
1336 context
= isl_set_union(isl_set_copy(dg1
->context
),
1337 isl_set_copy(dg2
->context
));
1338 proved
= isl_union_set_gist_params(proved
, isl_set_copy(context
));
1339 not_proved
= isl_union_set_gist_params(not_proved
, context
);
1341 print_results("Equivalence proved", proved
);
1342 print_results("Equivalence NOT proved", not_proved
);
1343 print_results("Only in program 1", only1
);
1344 print_results("Only in program 2", only2
);
1346 isl_union_set_free(proved
);
1347 isl_union_set_free(not_proved
);
1348 isl_union_set_free(only1
);
1349 isl_union_set_free(only2
);
1351 if (options
->print_stats
) {
1352 fprintf(stderr
, "widenings: %d\n", ec
.widenings
);
1353 if (options
->narrowing
)
1354 fprintf(stderr
, "narrowings: %d\n", ec
.narrowings
);
1360 static void dump_vertex(FILE *out
, computation
*comp
)
1362 fprintf(out
, "ND_%p [label = \"%d,%s/%d\"];\n",
1363 comp
, comp
->location
, comp
->operation
, comp
->arity
);
1364 for (int i
= 0; i
< comp
->edges
.size(); ++i
)
1365 fprintf(out
, "ND_%p -> ND_%p%s;\n",
1366 comp
, comp
->edges
[i
]->source
,
1367 comp
->edges
[i
]->type
== edge::expansion
?
1368 " [color=\"blue\"]" : "");
1371 static void dump_graph(FILE *out
, dependence_graph
*dg
)
1373 fprintf(out
, "digraph dummy {\n");
1374 dump_vertex(out
, dg
->out
);
1375 for (int i
= 0; i
< dg
->vertices
.size(); ++i
)
1376 dump_vertex(out
, dg
->vertices
[i
]);
1377 fprintf(out
, "}\n");
1380 static void dump_graphs(dependence_graph
**dg
, struct options
*options
)
1383 char path
[PATH_MAX
];
1385 if (!options
->dump_graphs
)
1388 for (i
= 0; i
< 2; ++i
) {
1391 s
= snprintf(path
, sizeof(path
), "%s.dot", options
->program
[i
]);
1392 assert(s
< sizeof(path
));
1393 out
= fopen(path
, "w");
1395 dump_graph(out
, dg
[i
]);
1400 void parse_ops(struct options
*options
) {
1403 if (options
->associative
) {
1404 tok
= strtok(options
->associative
, ",");
1405 options
->ops
->associative
.push_back(strdup(tok
));
1406 while ((tok
= strtok(NULL
, ",")) != NULL
)
1407 options
->ops
->associative
.push_back(strdup(tok
));
1410 if (options
->commutative
) {
1411 tok
= strtok(options
->commutative
, ",");
1412 options
->ops
->commutative
.push_back(strdup(tok
));
1413 while ((tok
= strtok(NULL
, ",")) != NULL
)
1414 options
->ops
->commutative
.push_back(strdup(tok
));
1418 int main(int argc
, char *argv
[])
1420 struct options
*options
= options_new_with_defaults();
1421 struct isl_ctx
*ctx
;
1422 isl_set
*context
= NULL
;
1423 dependence_graph
*dg
[2];
1425 argc
= options_parse(options
, argc
, argv
, ISL_ARG_ALL
);
1428 ctx
= isl_ctx_alloc_with_options(&options_args
, options
);
1430 fprintf(stderr
, "Unable to allocate ctx\n");
1434 if (options
->context
)
1435 context
= isl_set_read_from_str(ctx
, options
->context
);
1438 unsigned out_dim
= 0;
1439 for (int i
= 0; i
< 2; ++i
) {
1441 in
= fopen(options
->program
[i
], "r");
1443 fprintf(stderr
, "Unable to open %s\n", options
->program
[i
]);
1446 pdg
[i
] = yaml::Load
<pdg::PDG
>(in
, ctx
);
1449 fprintf(stderr
, "Unable to read %s\n", options
->program
[i
]);
1452 out_dim
= update_out_dim(pdg
[i
], out_dim
);
1455 pdg
[i
]->params
.size() != isl_set_dim(context
, isl_dim_param
)) {
1457 "Parameter dimension mismatch; context ignored\n");
1458 isl_set_free(context
);
1463 for (int i
= 0; i
< 2; ++i
) {
1464 dg
[i
] = pdg_to_dg(pdg
[i
], out_dim
, isl_set_copy(context
));
1469 dump_graphs(dg
, options
);
1471 int res
= check_equivalence(ctx
, dg
[0], dg
[1], options
);
1473 for (int i
= 0; i
< 2; ++i
)
1475 isl_set_free(context
);