update isl for support for recent versions of clang
[ppn.git] / check_channel_sizes.cc
blobf0f618893666e8372c91fc11e182d825629c4305
#include <cassert>
#include <cstring>
#include <iostream>
#include <map>
#include <vector>
#include <isl/ctx.h>
#include <isl/id.h>
#include <isl/space.h>
#include <isl/aff.h>
#include <isl/set.h>
#include <isl/map.h>
#include <isl/union_map.h>
#include <isl/ast.h>
#include <isl/ast_build.h>
#include <isl/printer.h>

#include <isa/yaml.h>
#include <isa/pdg.h>

using namespace std;
using namespace pdg;
19 struct pos_dep {
20 int pos;
21 pdg::dependence *dep;
24 static __isl_give isl_printer *print_access_pair(__isl_take isl_printer *p,
25 int pos, pdg::node *from, int from_access, pdg::node *to, int to_access)
27 p = isl_printer_print_int(p, pos);
28 p = isl_printer_print_str(p, "_");
29 p = isl_printer_print_int(p, from->nr);
30 p = isl_printer_print_str(p, "_");
31 p = isl_printer_print_int(p, from_access);
32 p = isl_printer_print_str(p, "_");
33 p = isl_printer_print_int(p, to->nr);
34 p = isl_printer_print_str(p, "_");
35 p = isl_printer_print_int(p, to_access);
36 return p;
39 static __isl_give isl_printer *print_core_domain(__isl_take isl_printer *p,
40 __isl_take isl_ast_print_options *print_options,
41 __isl_keep isl_ast_node *node, void *user)
43 pos_dep *pd;
44 int pos;
45 pdg::dependence *dep;
46 pdg::node *from;
47 pdg::node *to;
48 int from_access;
49 int to_access;
50 isl_id *id;
51 isl_ast_expr *expr, *arg;
52 const char *name;
54 expr = isl_ast_node_user_get_expr(node);
55 arg = isl_ast_expr_get_op_arg(expr, 0);
56 id = isl_ast_expr_get_id(arg);
57 name = isl_id_get_name(id);
58 pd = (pos_dep *) isl_id_get_user(id);
59 dep = pd->dep;
60 pos = pd->pos;
61 isl_id_free(id);
62 isl_ast_expr_free(arg);
63 isl_ast_expr_free(expr);
65 from = dep->from;
66 to = dep->to;
67 from_access = dep->from_access ? dep->from_access->nr
68 : from->statement->accesses.size();
69 to_access = dep->to_access ? dep->to_access->nr
70 : to->statement->accesses.size();
72 if (!strcmp(name, "write")) {
73 p = isl_printer_start_line(p);
74 p = isl_printer_print_str(p, "if (++i_");
75 p = print_access_pair(p, pos, from, from_access, to, to_access);
76 p = isl_printer_print_str(p, " > s_");
77 p = print_access_pair(p, pos, from, from_access, to, to_access);
78 p = isl_printer_print_str(p, ")");
79 p = isl_printer_end_line(p);
80 p = isl_printer_indent(p, 4);
81 p = isl_printer_start_line(p);
82 p = isl_printer_print_str(p, "s_");
83 p = print_access_pair(p, pos, from, from_access, to, to_access);
84 p = isl_printer_print_str(p, " = i_");
85 p = print_access_pair(p, pos, from, from_access, to, to_access);
86 p = isl_printer_print_str(p, ";");
87 p = isl_printer_end_line(p);
88 p = isl_printer_indent(p, -4);
89 } else {
90 p = isl_printer_start_line(p);
91 p = isl_printer_print_str(p, "--i_");
92 p = print_access_pair(p, pos, from, from_access, to, to_access);
93 p = isl_printer_print_str(p, ";");
94 p = isl_printer_end_line(p);
97 isl_ast_print_options_free(print_options);
98 return p;
101 static __isl_give isl_union_map *add_schedule(__isl_take isl_union_map *sched,
102 __isl_take isl_set *domain, int maxdim, const char *type, pos_dep *pd)
104 isl_id *id;
105 isl_map *sched_i;
106 int n;
108 sched_i = isl_set_identity(domain);
109 n = isl_map_dim(sched_i, isl_dim_out);
110 sched_i = isl_map_add_dims(sched_i, isl_dim_out, maxdim - n);
111 for (int i = n; i < maxdim; ++i)
112 sched_i = isl_map_fix_si(sched_i, isl_dim_out, i, 0);
113 id = isl_id_alloc(isl_union_map_get_ctx(sched), type, pd);
114 sched_i = isl_map_set_tuple_id(sched_i, isl_dim_in, id);
115 sched = isl_union_map_add_map(sched, sched_i);
117 return sched;
120 static __isl_give isl_printer *print_user(__isl_take isl_printer *p,
121 __isl_take isl_ast_print_options *print_options,
122 __isl_keep isl_ast_node *node, void *user)
124 PDG *pdg = (PDG *) user;
125 isl_ctx *ctx = pdg->get_isl_ctx();
126 int maxdim = 0;
127 isl_ast_build *build;
128 isl_ast_node *tree;
129 isl_union_map *sched;
130 pos_dep *data;
132 p = isl_printer_start_line(p);
133 p = isl_printer_print_str(p, "{");
134 p = isl_printer_end_line(p);
135 p = isl_printer_indent(p, 2);
137 int nparam = pdg->params.size();
139 for (int i = 0; i < pdg->nodes.size(); ++i) {
140 pdg::node *node = pdg->nodes[i];
141 if (node->prefix.size() > maxdim)
142 maxdim = node->prefix.size();
145 sched = isl_union_map_empty(isl_space_params_alloc(ctx, 0));
147 data = new pos_dep[pdg->dependences.size()];
148 std::map<pdg::dependence *,int> pos;
150 for (int i = 0; i < pdg->dependences.size(); ++i) {
151 pdg::dependence *dep = pdg->dependences[i];
152 if (dep->type == pdg::dependence::uninitialized)
153 continue;
154 if (dep->type == pdg::dependence::pn_part)
155 continue;
156 pdg::node* from = dep->from;
157 pdg::node* to = dep->to;
158 int from_access = dep->from_access ? dep->from_access->nr
159 : from->statement->accesses.size();
160 int to_access = dep->to_access ? dep->to_access->nr
161 : to->statement->accesses.size();
162 if (!dep->size)
163 continue;
164 pos[dep] = i;
166 p = isl_printer_start_line(p);
167 p = isl_printer_print_str(p, "int i_");
168 p = print_access_pair(p, i, from, from_access, to, to_access);
169 p = isl_printer_print_str(p, " = 0;");
170 p = isl_printer_end_line(p);
171 p = isl_printer_start_line(p);
172 p = isl_printer_print_str(p, "int s_");
173 p = print_access_pair(p, i, from, from_access, to, to_access);
174 p = isl_printer_print_str(p, " = 0;");
175 p = isl_printer_end_line(p);
178 for (int i = 0; i < pdg->dependences.size(); ++i) {
179 isl_set *domain;
180 pdg::dependence *dep = pdg->dependences[i];
181 pdg::dependence *container = dep;
182 if (dep->type == pdg::dependence::uninitialized)
183 continue;
184 if (dep->type == pdg::dependence::pn_union)
185 continue;
186 if (dep->type == pdg::dependence::pn_part)
187 container = dep->container;
188 if (!container->size)
189 continue;
190 data[i].pos = pos.find(container)->second;
191 data[i].dep = container;
193 isl_map *map = scatter_dependence(pdg, dep);
194 domain = isl_map_range(isl_map_copy(map));
195 sched = add_schedule(sched, domain, maxdim, "read", &data[i]);
196 domain = isl_map_domain(map);
197 sched = add_schedule(sched, domain, maxdim, "write", &data[i]);
200 build = isl_ast_build_from_context(pdg->get_context_isl_set());
201 tree = isl_ast_build_ast_from_schedule(build, sched);
202 isl_ast_build_free(build);
204 print_options = isl_ast_print_options_set_print_user(print_options,
205 &print_core_domain, NULL);
206 p = isl_ast_node_print(tree, p, print_options);
208 isl_ast_node_free(tree);
210 for (int i = 0; i < pdg->dependences.size(); ++i) {
211 pdg::dependence *dep = pdg->dependences[i];
212 if (dep->type == pdg::dependence::uninitialized)
213 continue;
214 if (dep->type == pdg::dependence::pn_part)
215 continue;
216 pdg::node* from = dep->from;
217 pdg::node* to = dep->to;
218 int from_access = dep->from_access ? dep->from_access->nr
219 : from->statement->accesses.size();
220 int to_access = dep->to_access ? dep->to_access->nr
221 : to->statement->accesses.size();
222 if (!dep->size)
223 continue;
225 p = isl_printer_start_line(p);
226 p = isl_printer_print_str(p, "assert(i_");
227 p = print_access_pair(p, i, from, from_access, to, to_access);
228 p = isl_printer_print_str(p, " == 0);");
229 p = isl_printer_end_line(p);
231 p = isl_printer_start_line(p);
232 p = isl_printer_print_str(p, "printf(\"s_");
233 p = print_access_pair(p, i, from, from_access, to, to_access);
234 p = isl_printer_print_str(p, ": %d; %d\\n\", s_");
235 p = print_access_pair(p, i, from, from_access, to, to_access);
236 p = isl_printer_print_str(p, ", (int)(");
237 p = isl_printer_print_str(p, dep->size->s->c_str());
238 p = isl_printer_print_str(p, "));");
239 p = isl_printer_end_line(p);
241 p = isl_printer_start_line(p);
242 p = isl_printer_print_str(p, "assert(s_");
243 p = print_access_pair(p, i, from, from_access, to, to_access);
244 p = isl_printer_print_str(p, " <= (int)(");
245 p = isl_printer_print_str(p, dep->size->s->c_str());
246 p = isl_printer_print_str(p, "));");
247 p = isl_printer_end_line(p);
250 p = isl_printer_indent(p, -2);
251 p = isl_printer_start_line(p);
252 p = isl_printer_print_str(p, "}");
253 p = isl_printer_end_line(p);
255 delete [] data;
257 return p;
260 static isl_stat print_macro(enum isl_ast_op_type type, void *user)
262 isl_printer **p = (isl_printer **) user;
264 if (type == isl_ast_op_fdiv_q)
265 return isl_stat_ok;
267 *p = isl_ast_op_type_print_macro(type, *p);
269 return isl_stat_ok;
272 /* Print the required macros for "node", including one for floord.
273 * We always print a macro for floord as it may also appear in the statements.
275 static __isl_give isl_printer *print_macros(
276 __isl_keep isl_ast_node *node, __isl_take isl_printer *p)
278 p = isl_ast_op_type_print_macro(isl_ast_op_fdiv_q, p);
279 if (isl_ast_node_foreach_ast_op_type(node, &print_macro, &p) < 0) {
280 isl_printer_free(p);
281 return NULL;
283 return p;
286 static void print(__isl_keep isl_ast_node *tree, PDG *pdg)
288 isl_ctx *ctx;
289 FILE *out = stdout;
290 isl_printer *p;
291 isl_ast_print_options *print_options;
293 ctx = isl_ast_node_get_ctx(tree);
295 p = isl_printer_to_file(ctx, out);
296 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
298 p = isl_printer_start_line(p);
299 p = isl_printer_print_str(p, "#include <assert.h>");
300 p = isl_printer_end_line(p);
302 p = isl_printer_start_line(p);
303 p = isl_printer_print_str(p, "#include <stdio.h>");
304 p = isl_printer_end_line(p);
306 p = print_macros(tree, p);
308 p = isl_printer_start_line(p);
309 p = isl_printer_print_str(p, "int main()");
310 p = isl_printer_end_line(p);
311 p = isl_printer_start_line(p);
312 p = isl_printer_print_str(p, "{");
313 p = isl_printer_end_line(p);
315 p = isl_printer_indent(p, 4);
317 print_options = isl_ast_print_options_alloc(ctx);
318 print_options = isl_ast_print_options_set_print_user(print_options,
319 &print_user, pdg);
320 p = isl_ast_node_print(tree, p, print_options);
322 p = isl_printer_start_line(p);
323 p = isl_printer_print_str(p, "return 0;");
324 p = isl_printer_end_line(p);
326 p = isl_printer_indent(p, -4);
328 p = isl_printer_start_line(p);
329 p = isl_printer_print_str(p, "}");
330 p = isl_printer_end_line(p);
332 isl_printer_free(p);
335 void print_program(PDG *pdg)
337 isl_ctx *ctx = pdg->get_isl_ctx();
338 isl_set *context;
339 isl_union_map *sched;
340 isl_ast_build *build;
341 isl_ast_node *tree;
342 isl_id_list *iterators;
343 unsigned nparam;
345 context = pdg->get_context_isl_set();
346 context = isl_set_from_params(context);
347 nparam = isl_set_dim(context, isl_dim_param);
348 context = isl_set_move_dims(context, isl_dim_set, 0,
349 isl_dim_param, 0, nparam);
350 sched = isl_union_map_from_map(isl_set_identity(context));
352 context = isl_set_universe(isl_union_map_get_space(sched));
354 iterators = isl_id_list_alloc(ctx, nparam);
355 for (int i = 0; i < nparam; ++i) {
356 isl_id *id = isl_id_alloc(ctx, pdg->params[i]->name->s.c_str(), NULL);
357 iterators = isl_id_list_add(iterators, id);
360 build = isl_ast_build_from_context(context);
361 build = isl_ast_build_set_iterators(build, iterators);
362 tree = isl_ast_build_ast_from_schedule(build, sched);
363 isl_ast_build_free(build);
365 print(tree, pdg);
367 isl_ast_node_free(tree);
370 int main(int argc, char * argv[])
372 PDG *pdg;
373 isl_ctx *ctx = isl_ctx_alloc();
374 pdg = PDG::Load(stdin, ctx);
376 assert(pdg->context);
378 print_program(pdg);
380 pdg->free();
381 delete pdg;
382 isl_ctx_free(ctx);
384 return 0;