Support for osl statement_extensions
[clay.git] / source / util.c
blob6e25d82ae30b3a9ebd82d9ac79db96708bc93118
2 /*--------------------------------------------------------------------+
3 | Clay |
4 |--------------------------------------------------------------------|
5 | util.c |
6 |--------------------------------------------------------------------|
7 | First version: 03/04/2012 |
8 +--------------------------------------------------------------------+
10 +--------------------------------------------------------------------------+
11 | / __)( ) /__\ ( \/ ) |
12 | ( (__ )(__ /(__)\ \ / Chunky Loop Alteration wizardrY |
13 | \___)(____)(__)(__)(__) |
14 +--------------------------------------------------------------------------+
15 | Copyright (C) 2012 University of Paris-Sud |
16 | |
17 | This library is free software; you can redistribute it and/or modify it |
18 | under the terms of the GNU Lesser General Public License as published by |
19 | the Free Software Foundation; either version 2.1 of the License, or |
20 | (at your option) any later version. |
21 | |
22 | This library is distributed in the hope that it will be useful but |
23 | WITHOUT ANY WARRANTY; without even the implied warranty of |
24 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser |
25 | General Public License for more details. |
26 | |
27 | You should have received a copy of the GNU Lesser General Public License |
28 | along with this software; if not, write to the Free Software Foundation, |
29 | Inc., 51 Franklin Street, Fifth Floor, |
30 | Boston, MA 02110-1301 USA |
31 | |
32 | Clay, the Chunky Loop Alteration wizardrY |
33 | Written by Joel Poudroux, joel.poudroux@u-psud.fr |
34 +--------------------------------------------------------------------------*/
36 #include <stdlib.h>
37 #include <string.h>
38 #include <clay/array.h>
39 #include <clay/beta.h>
40 #include <clay/macros.h>
41 #include <clay/util.h>
42 #include <clay/errors.h>
44 #include <osl/statement.h>
45 #include <osl/body.h>
46 #include <osl/extensions/extbody.h>
47 #include <osl/extensions/scatnames.h>
48 #include <osl/scop.h>
49 #include <osl/generic.h>
50 #include <osl/util.h>
51 #include <osl/relation.h>
52 #include <osl/relation_list.h>
53 #include <osl/macros.h>
55 /**
56 * clay_util_statement_insert_inequation function:
57 * Insert a new inequation at the end of the scattering
58 * \param[in,out] statement
59 * \param[in] inequ [iter1, iter2, ..., param1, param2, ..., const]
60 * \param[in] nb_input_dims Nb input dims in the array
61 * \param[in] nb_params Nb params in the array
63 /*void clay_util_statement_insert_inequation(osl_statement_p statement,
64 clay_array_p inequ, int nb_input_dims, int nb_params) {
66 osl_relation_p scattering = statement->scattering;
67 int row = scattering->nb_rows;
68 int i, j;
69 int precision = scattering->precision;
71 // insert the inequation spliting (local dims are not in the inequation)
72 // (at the end)
73 osl_relation_insert_blank_row(scattering, row);
74 osl_int_set_si(precision, scattering->m[row], 0, 1); // type inequation
76 // affects input_dims
77 i = scattering->nb_output_dims+1;
78 for (j = 0 ; j < nb_input_dims ; j++) {
79 osl_int_set_si(precision,
80 scattering->m[row], i,
81 inequ->data[j]);
82 i++;
84 // affects parameters
85 i = 1 + scattering->nb_output_dims + scattering->nb_input_dims +
86 scattering->nb_local_dims;
87 for (; j < nb_params + nb_input_dims ; j++) {
88 osl_int_set_si(precision,
89 scattering->m[row], i,
90 inequ->data[j]);
91 i++;
93 // set the constant
94 osl_int_set_si(precision,
95 scattering->m[row], scattering->nb_columns-1,
96 inequ->data[inequ->size-1]);
102 /**
103 * clay_util_array_output_dims_pad_zero function:
104 * Pad zeros for alpha columns
105 * For example if we have [i, j], the result will be [0, i, 0, j, 0]
106 * \param[in,out] array
108 void clay_util_array_output_dims_pad_zero(clay_array_p a) {
109 int i;
110 int end = a->size*2+1;
112 for (i = a->size ; i < end ; i++)
113 clay_array_add(a, 0);
115 for (i = end-1 ; i >= 0 ; i--) {
116 a->data[i*2+1] = a->data[i];
117 a->data[i] = 0;
122 /**
123 * clay_util_statement_insert_inequation function:
124 * Insert a new inequation at the end of the scattering
125 * The list must have less or equal than 3 arrays
126 * Warning: here the output dims are complete
127 * example: if the output dims are : 0 i 0 j 0, you have
128 * to give all these numbers
129 * \param[in,out] statement
130 * \param[in] inequ {(([output, ...],) [param, ..],) [const]}
132 void clay_util_statement_set_inequation(
133 osl_statement_p statement,
134 clay_list_p inequ) {
136 osl_relation_p scattering = statement->scattering;
137 clay_array_p arr_dims = NULL, arr_params = NULL, arr_const = NULL;
138 int row = scattering->nb_rows;
139 int i, j;
140 int precision = scattering->precision;
142 // insert the inequation spliting (local dims are not in the inequation)
143 // (at the end)
144 osl_relation_insert_blank_row(scattering, row);
145 osl_int_set_si(precision, &scattering->m[row][0], 1); // type inequation
147 if (inequ->size > 3) {
148 CLAY_error("list with more than 3 arrays not supported");
149 } else if (inequ->size == 3) {
150 arr_dims = inequ->data[0];
151 arr_params = inequ->data[1];
152 arr_const = inequ->data[2];
153 } else if (inequ->size == 2) {
154 arr_params = inequ->data[0];
155 arr_const = inequ->data[1];
156 } else {
157 arr_const = inequ->data[0];
160 // affects output dims
161 if (inequ->size == 3) {
162 i = 1;
163 for (j = 0 ; j < arr_dims->size ; j++) {
164 osl_int_set_si(precision,
165 &scattering->m[row][i],
166 arr_dims->data[j]);
167 i++;
171 // affects parameters
172 if (inequ->size >= 2) {
173 i = 1 + scattering->nb_output_dims + scattering->nb_input_dims +
174 scattering->nb_local_dims;
175 for (j = 0; j < arr_params->size ; j++) {
176 osl_int_set_si(precision,
177 &scattering->m[row][i],
178 arr_params->data[j]);
179 i++;
183 // set the constant
184 if (inequ->size >= 1 && arr_const->size == 1) {
185 osl_int_set_si(precision,
186 &scattering->m[row][scattering->nb_columns-1],
187 arr_const->data[0]);
192 /**
193 * clay_util_statement_set_vector function:
194 * Set the equation on each line where the column of the output dim is
195 * different of zero
196 * \param[in,out] statement
197 * \param[in] vector {(([output, ...],) [param, ..],) [const]}
198 * \param[in] column column on the output dim
200 void clay_util_statement_set_vector(
201 osl_statement_p statement,
202 clay_list_p vector, int column) {
204 osl_relation_p scattering = statement->scattering;
205 clay_array_p arr_dims = NULL, arr_params = NULL, arr_const = NULL;
206 int i, j, k;
207 int precision = scattering->precision;
208 osl_int_p tmp;
210 tmp = osl_int_malloc(precision);
212 if (vector->size > 3) {
213 CLAY_error("list with more than 3 arrays not supported");
214 } else if (vector->size == 3) {
215 arr_dims = vector->data[0];
216 arr_params = vector->data[1];
217 arr_const = vector->data[2];
218 } else if (vector->size == 2) {
219 arr_params = vector->data[0];
220 arr_const = vector->data[1];
221 } else {
222 arr_const = vector->data[0];
225 // for each line where there is a number different from zero on the
226 // column
227 for (k = 0 ; k < scattering->nb_rows ; k++) {
228 if (!osl_int_zero(precision, scattering->m[k][1+column])) {
230 // scattering = coeff_outputdim * shifting
232 // affect output dims
233 if (vector->size >= 3) {
234 i = 1;
235 for (j = 0 ; j < arr_dims->size ; j++) {
236 osl_int_mul_si(precision,
237 &scattering->m[k][i],
238 scattering->m[k][1+column],
239 arr_dims->data[j]);
240 i++;
244 // here we add we the last value
245 // scattering += coeff_outputdim * shifting
247 // affects parameters
248 if (vector->size >= 2) {
249 i = 1 + scattering->nb_output_dims + scattering->nb_input_dims +
250 scattering->nb_local_dims;
251 for (j = 0 ; j < arr_params->size ; j++) {
252 osl_int_mul_si(precision,
253 tmp,
254 scattering->m[k][1+column],
255 arr_params->data[j]);
256 osl_int_add(precision,
257 &scattering->m[k][i],
258 scattering->m[k][i],
259 *tmp);
260 i++;
264 // set the constant
265 if (vector->size >= 1 && arr_const->size == 1) {
266 osl_int_mul_si(precision,
267 tmp,
268 scattering->m[k][1+column],
269 arr_const->data[0]);
270 osl_int_add(precision,
271 &scattering->m[k][scattering->nb_columns-1],
272 scattering->m[k][scattering->nb_columns-1],
273 *tmp);
278 osl_int_free(precision, tmp);
283 * clay_util_relation_negate_row function:
284 * Negate the line at `row' (doesn't affect the e/i column)
285 * \param[in,out] statement
286 * \param[in] row row to negate
288 void clay_util_relation_negate_row(osl_relation_p r, int row) {
289 int i;
290 int precision = r->precision;
291 for (i = 1 ; i < r->nb_columns ; i++) {
292 osl_int_oppose(precision,
293 &r->m[row][i],
294 r->m[row][i]);
296 osl_int_decrement(precision,
297 &r->m[row][r->nb_columns-1],
298 r->m[row][r->nb_columns-1]);
303 * clay_util_statement_insert function:
304 * Insert `newstatement' before `statement', and set his beta value
305 * \param[in,out] statement
306 * \param[in,out] newstatement
307 * \param[in] column column on the output dim (where we want to split)
308 * this is a `alpha column' of the 2*d+1
309 * \param[in] order new beta value
310 * \return return statement->next (so newtstatement)
312 osl_statement_p clay_util_statement_insert(osl_statement_p statement,
313 osl_statement_p newstatement,
314 int column,
315 int order) {
316 osl_relation_p scattering = statement->scattering;
318 // the current statement is after the new statement
319 int row = clay_util_relation_get_line(scattering, column);
320 osl_int_set_si(scattering->precision,
321 &scattering->m[row][scattering->nb_columns-1],
322 order);
324 // the order is not important in the statements list
325 newstatement->next = statement->next;
326 statement->next = newstatement;
327 statement = statement->next;
329 return statement;
/**
 * clay_util_string_replace function:
 * Search and replace a string with another string, in a string.
 * The matches are searched from left to right and do not overlap. An empty
 * `search' matches nothing, the result is then a copy of `string'.
 * Minor modifications from :
 * http://www.binarytides.com/blog/str_replace-for-c/
 * \param[in] search    the text to look for
 * \param[in] replace   the text written in place of each match
 * \param[in] string    the subject (not modified)
 * \return              a new string which must be freed by the caller, or
 *                      NULL if the allocation failed
 */
char* clay_util_string_replace(char *search, char *replace, char *string) {
  const size_t search_size  = strlen(search);
  const size_t replace_size = strlen(replace);
  const size_t string_size  = strlen(string);
  size_t count = 0;
  size_t chunk;
  char *ptr = NULL, *old = NULL, *dst = NULL, *new_string = NULL;

  // Count how many occurences
  // (an empty search would match forever at the same position: skip it)
  if (search_size > 0) {
    for (ptr = strstr(string, search) ; ptr != NULL ;
         ptr = strstr(ptr + search_size, search)) {
      count++;
    }
  }

  // Final size: each match removes search_size chars and adds replace_size
  new_string = malloc(string_size - count*search_size +
                      count*replace_size + 1);
  if (new_string == NULL)
    return NULL;

  // The start position
  old = string;
  dst = new_string;

  if (count > 0) {
    for (ptr = strstr(string, search) ; ptr != NULL ;
         ptr = strstr(ptr + search_size, search)) {
      // copy the text between the previous match and this one
      chunk = (size_t)(ptr - old);
      memcpy(dst, old, chunk);
      dst += chunk;

      // copy the replacement text
      memcpy(dst, replace, replace_size);
      dst += replace_size;

      // The new start position after this search match
      old = ptr + search_size;
    }
  }

  // Copy the part after the last search match (with the final '\0')
  memcpy(dst, old, string_size - (size_t)(old - string) + 1);

  return new_string;
}
382 * clay_util_scatnames_exists_iterator_iterator function:
383 * Return true if the iterator name is already in the scattering.
384 * \param[in] scattering
385 * \return
387 bool clay_util_scatnames_exists(osl_scatnames_p scatnames, char *iter) {
388 osl_strings_p names = scatnames->names;
389 if (names == NULL || names->string[0] == NULL)
390 return 0;
392 char **ptr = names->string;
394 while (*ptr != NULL) {
395 if (strcmp(*ptr, iter) == 0)
396 return 1;
397 ptr++;
400 return 0;
405 * clay_util_statement_find_iterator function:
406 * Return the index if iter is found in the original iterators list.
407 * \param[in] scop
408 * \param[in] iter name of the original iterator we want to search
409 * \return
411 int clay_util_statement_find_iterator(osl_statement_p statement, char *iter) {
412 osl_body_p body;
413 osl_extbody_p extbody = NULL;
415 extbody = osl_generic_lookup(statement->extension, OSL_URI_EXTBODY);
416 if (extbody)
417 body = extbody->body;
418 else
419 body = osl_generic_lookup(statement->extension, OSL_URI_BODY);
421 char **ptr = body->iterators->string;
422 int i = 0;
424 while (*ptr != NULL) {
425 if (strcmp(*ptr, iter) == 0)
426 return i;
427 ptr++;
428 i++;
431 return -1;
436 * clay_util_scop_export_body function:
437 * Convert each extbody to a body structure
438 * \param[in] scop
440 void clay_util_scop_export_body(osl_scop_p scop) {
441 if (scop == NULL)
442 return;
444 osl_statement_p stmt = scop->statement;
445 osl_extbody_p ebody = NULL;
446 osl_body_p body = NULL;
447 osl_generic_p gen = NULL;
449 while (stmt) {
450 ebody = osl_generic_lookup(stmt->extension, OSL_URI_EXTBODY);
451 if (ebody!=NULL) {
453 body = osl_generic_lookup(stmt->extension, OSL_URI_BODY);
454 if (body) {
455 osl_generic_remove(&stmt->extension, OSL_URI_BODY);
457 body = osl_body_clone(ebody->body);
458 gen = osl_generic_shell(body, osl_body_interface());
459 osl_generic_add(&stmt->extension, gen);
460 osl_generic_remove(&stmt->extension, OSL_URI_EXTBODY);
461 ebody=NULL;
462 body=NULL;
464 stmt = stmt->next;
/**
 * clay_util_name_sprint function:
 * Append one term of an affine expression to *dst.
 * " + " is printed first, except for the first term of the expression
 * (*print_plus is then set to 1). The term is printed as:
 *   val        if name is NULL
 *   name       if val is 1
 *   -name      if val is -1
 *   val*name   otherwise
 * \param[in,out] dst          string to extend (see osl_util_safe_strcat)
 * \param[in,out] hwm          high water mark given to osl_util_safe_strcat
 * \param[in,out] print_plus   0 if this is the first term
 * \param[in] val              coefficient (or constant if name is NULL)
 * \param[in] name             name of the variable, NULL for a constant
 */
static void clay_util_name_sprint(char **dst, int *hwm,
                                  int *print_plus, int val, char *name) {
  char number[32];

  if (*print_plus)
    osl_util_safe_strcat(dst, " + ", hwm);
  else
    *print_plus = 1;

  // constant term: only the value
  if (name == NULL) {
    snprintf(number, sizeof number, "%d", val);
    osl_util_safe_strcat(dst, number, hwm);
    return;
  }

  // prefix of the name: nothing, a minus sign or the coefficient
  switch (val) {
    case 1:
      break;
    case -1:
      osl_util_safe_strcat(dst, "-", hwm);
      break;
    default:
      snprintf(number, sizeof number, "%d*", val);
      osl_util_safe_strcat(dst, number, hwm);
      break;
  }

  osl_util_safe_strcat(dst, name, hwm);
}
497 * clay_util_body_regenerate_access function:
498 * Read the access array and re-generate the code in the body
499 * \param[in] ebody An extbody structure
500 * \param[in] access The relation to regenerate the code
501 * \param[in] index nth access (needed to access to the array start and
502 * length of the extbody structure)
504 void clay_util_body_regenerate_access(osl_extbody_p ebody,
505 osl_relation_p access,
506 int index,
507 osl_arrays_p arrays,
508 osl_scatnames_p scatnames,
509 osl_strings_p params) {
511 if (!arrays || !scatnames || !params || access->nb_output_dims == 0 ||
512 index >= ebody->nb_access)
513 return;
515 const int precision = access->precision;
516 int i, j, k, row, val, print_plus;
518 // check if there are no inequ
519 for (i = 0 ; i < access->nb_rows ; i++) {
520 if (!osl_int_zero(precision, access->m[i][0]))
521 CLAY_error("I don't know how to regenerate access with inequalities");
524 // check identity matrix in output dims
525 int n;
526 for (j = 0 ; j < access->nb_output_dims ; j++) {
527 n = 0;
528 for (i = 0 ; i < access->nb_rows ; i++)
529 if (!osl_int_zero(precision, access->m[i][j+1])) {
530 if (n >= 1)
531 CLAY_error("I don't know how to regenerate access with "
532 "dependences in output dims");
533 n++;
537 char *body = ebody->body->expression->string[0];
538 int body_len = strlen(body);
539 int start = ebody->start[index];
540 int len = ebody->length[index];
541 int is_zero; // if the line contains only zeros
543 if (start >= body_len || start + len >= body_len || (start == -1 && len == -1))
544 return;
546 char *new_body;
547 char end_body[OSL_MAX_STRING];
548 int hwm = OSL_MAX_STRING;
550 CLAY_malloc(new_body, char *, OSL_MAX_STRING * sizeof(char));
552 // copy the beginning of the body
553 if (start+1 >= OSL_MAX_STRING)
554 CLAY_error("memcpy: please recompile osl with a higher OSL_MAX_STRING");
555 memcpy(new_body, body, start);
556 new_body[start] = '\0';
558 // save the end in a buffer
559 int sz = body_len - start - len;
560 if (sz + 1 >= OSL_MAX_STRING)
561 CLAY_error("memcpy: please recompile osl with a higher OSL_MAX_STRING");
562 memcpy(end_body, body + start + len, sz);
563 end_body[sz] = '\0';
566 // copy access name string
567 val = osl_relation_get_array_id(access);
568 val = clay_util_arrays_search(arrays, val); // get the index in th array
569 osl_util_safe_strcat(&new_body, arrays->names[val], &hwm);
572 // generate each dims
573 for (i = 1 ; i < access->nb_output_dims ; i++) {
574 row = clay_util_relation_get_line(access, i);
575 if (row == -1)
576 continue;
578 osl_util_safe_strcat(&new_body, "[", &hwm);
580 is_zero = 1;
581 print_plus = 0;
582 k = 1 + access->nb_output_dims;
584 // iterators
585 for (j = 0 ; j < access->nb_input_dims ; j++, k++) {
586 val = osl_int_get_si(precision, access->m[row][k]);
587 if (val != 0) {
588 clay_util_name_sprint(&new_body,
589 &hwm,
590 &print_plus,
591 val,
592 scatnames->names->string[j*2+1]);
593 is_zero = 0;
597 // params
598 for (j = 0 ; j < access->nb_parameters ; j++, k++) {
599 val = osl_int_get_si(precision, access->m[row][k]);
600 if (val != 0) {
601 clay_util_name_sprint(&new_body,
602 &hwm,
603 &print_plus,
604 val,
605 params->string[j]);
606 is_zero = 0;
610 // const
611 val = osl_int_get_si(precision, access->m[row][k]);
612 if (val != 0 || is_zero)
613 clay_util_name_sprint(&new_body,
614 &hwm,
615 &print_plus,
616 val,
617 NULL);
619 osl_util_safe_strcat(&new_body, "]", &hwm);
622 // length of the generated access
623 ebody->length[index] = strlen(new_body) - start;
625 // concat the end
626 osl_util_safe_strcat(&new_body, end_body, &hwm);
628 // update ebody
629 free(ebody->body->expression->string[0]);
630 ebody->body->expression->string[0] = new_body;
632 // shift the start
633 int diff = ebody->length[index] - len;
634 for (i = index+1 ; i < ebody->nb_access ; i++)
635 if (ebody->start[i] != -1)
636 ebody->start[i] += diff;
641 * clay_util_arrays_search function:
642 * Search the string which corresponds to id
643 * arrays is an extension of osl
644 * \param[in] arrays An arrays osl structure
645 * \param[in] id The id to search
646 * \return Return the index in the arrays
648 int clay_util_arrays_search(osl_arrays_p arrays, unsigned int id) {
649 int i;
650 for (i = 0 ; i < arrays->nb_names ; i++) {
651 if (arrays->id[i] == id)
652 return i;
654 return -1;
659 * clay_util_foreach_access function:
660 * Execute func on each access which corresponds to access_name
661 * \param[in,out] scop
662 * \param[in] beta
663 * \param[in] access_name The id to search
664 * \param[in] func The function to execute for each access
665 * The function takes an osl_relation_list_p in
666 * parameter (the elt can be modified) and must
667 * return a define error or CLAY_SUCCESS
668 * \param[in] args args of `func'
669 * \param[in] regenerate_body If 1: after each call to func,
670 * clay_util_body_regenerate_access is also called
671 * \return Return a define error or CLAY_SUCCESS
673 int clay_util_foreach_access(osl_scop_p scop,
674 clay_array_p beta,
675 unsigned int access_name,
676 int (*func)(osl_relation_list_p, void*),
677 void *args,
678 int regenerate_body) {
680 osl_statement_p stmt = scop->statement;
681 osl_relation_list_p access;
682 osl_relation_p a;
683 osl_extbody_p ebody = NULL;
684 osl_body_p body = NULL;
685 osl_generic_p gen= NULL;
686 int count_access;
687 int found = 0;
688 int ret;
690 // TODO : global vars ?
691 osl_arrays_p arrays;
692 osl_scatnames_p scatnames;
693 osl_strings_p params;
694 arrays = osl_generic_lookup(scop->extension, OSL_URI_ARRAYS);
695 scatnames = osl_generic_lookup(scop->extension, OSL_URI_SCATNAMES);
696 params = osl_generic_lookup(scop->parameters, OSL_URI_STRINGS);
698 if (!arrays || !scatnames || !params)
699 CLAY_warning("no arrays or scatnames extension");
701 stmt = clay_beta_find(scop->statement, beta);
702 if (!stmt)
703 return CLAY_ERROR_BETA_NOT_FOUND;
705 // for each access in the beta, we search the access_name
706 while (stmt != NULL) {
707 if (clay_beta_check(stmt, beta)) {
708 access = stmt->access;
709 count_access = 0;
711 while (access) {
712 a = access->elt;
714 if (osl_relation_get_array_id(a) == access_name) {
715 found = 1;
717 ebody = osl_generic_lookup(stmt->extension, OSL_URI_EXTBODY);
718 if (ebody==NULL) {
719 CLAY_error("extbody uri not found on this statement");
720 fprintf(stderr, "%s\n",
721 ebody->body->expression->string[0]);
724 // call the function
725 ret = (*func)(access, args);
726 if (ret != CLAY_SUCCESS) {
727 fprintf(stderr, "%s\n",
728 ebody->body->expression->string[0]);
729 return ret;
732 // re-generate the body
733 if (regenerate_body) {
734 clay_util_body_regenerate_access(
735 ebody,
736 access->elt,
737 count_access,
738 arrays,
739 scatnames,
740 params);
743 //synchronize extbody with body
744 body = osl_generic_lookup(stmt->extension, OSL_URI_BODY);
745 if (body) {
746 osl_generic_remove(&stmt->extension, OSL_URI_BODY);
747 body = osl_body_clone(ebody->body);
748 gen = osl_generic_shell(body, osl_body_interface());
749 osl_generic_add(&stmt->extension, gen);
754 ebody = NULL;
755 body = NULL;
756 access = access->next;
757 count_access++;
760 stmt = stmt->next;
763 if (!found)
764 fprintf(stderr,"[Clay] Warning: access number %d not found\n", access_name);
766 return CLAY_SUCCESS;
771 * clay_util_relation_get_line function:
772 * Because the lines in the scattering matrix may have not ordered, we have to
773 * search the corresponding line. It returns the first line where the value is
774 * different from zero in the `column'. `column' is between 0 and
775 * nb_output_dims-1
776 * \param[in] relation
777 * \param[in] column Line to search
778 * \return Return the real line
780 int clay_util_relation_get_line(osl_relation_p relation, int column) {
781 if (column < 0 || column > relation->nb_output_dims)
782 return -1;
783 int i;
784 int precision = relation->precision;
785 for (i = 0 ; i < relation->nb_rows ; i++) {
786 if (!osl_int_zero(precision, relation->m[i][column+1])) {
787 break;
790 return (i == relation->nb_rows ? -1 : i );