2 /*--------------------------------------------------------------------+
4 |--------------------------------------------------------------------|
6 |--------------------------------------------------------------------|
7 | First version: 03/04/2012 |
8 +--------------------------------------------------------------------+
10 +--------------------------------------------------------------------------+
11 | / __)( ) /__\ ( \/ ) |
12 | ( (__ )(__ /(__)\ \ / Chunky Loop Alteration wizardrY |
13 | \___)(____)(__)(__)(__) |
14 +--------------------------------------------------------------------------+
15 | Copyright (C) 2012 University of Paris-Sud |
17 | This library is free software; you can redistribute it and/or modify it |
18 | under the terms of the GNU Lesser General Public License as published by |
19 | the Free Software Foundation; either version 2.1 of the License, or |
20 | (at your option) any later version. |
22 | This library is distributed in the hope that it will be useful but |
23 | WITHOUT ANY WARRANTY; without even the implied warranty of |
24 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser |
25 | General Public License for more details. |
27 | You should have received a copy of the GNU Lesser General Public License |
28 | along with this software; if not, write to the Free Software Foundation, |
29 | Inc., 51 Franklin Street, Fifth Floor, |
30 | Boston, MA 02110-1301 USA |
32 | Clay, the Chunky Loop Alteration wizardrY |
33 | Written by Joel Poudroux, joel.poudroux@u-psud.fr |
34 +--------------------------------------------------------------------------*/
38 #include <clay/array.h>
39 #include <clay/beta.h>
40 #include <clay/macros.h>
41 #include <clay/util.h>
42 #include <clay/errors.h>
44 #include <osl/statement.h>
46 #include <osl/extensions/extbody.h>
47 #include <osl/extensions/scatnames.h>
49 #include <osl/generic.h>
51 #include <osl/relation.h>
52 #include <osl/relation_list.h>
53 #include <osl/macros.h>
56 * clay_util_statement_insert_inequation function:
57 * Insert a new inequation at the end of the scattering
58 * \param[in,out] statement
59 * \param[in] inequ [iter1, iter2, ..., param1, param2, ..., const]
60 * \param[in] nb_input_dims Nb input dims in the array
61 * \param[in] nb_params Nb params in the array
63 /*void clay_util_statement_insert_inequation(osl_statement_p statement,
64 clay_array_p inequ, int nb_input_dims, int nb_params) {
66 osl_relation_p scattering = statement->scattering;
67 int row = scattering->nb_rows;
69 int precision = scattering->precision;
71 // insert the inequation spliting (local dims are not in the inequation)
73 osl_relation_insert_blank_row(scattering, row);
74 osl_int_set_si(precision, scattering->m[row], 0, 1); // type inequation
77 i = scattering->nb_output_dims+1;
78 for (j = 0 ; j < nb_input_dims ; j++) {
79 osl_int_set_si(precision,
80 scattering->m[row], i,
85 i = 1 + scattering->nb_output_dims + scattering->nb_input_dims +
86 scattering->nb_local_dims;
87 for (; j < nb_params + nb_input_dims ; j++) {
88 osl_int_set_si(precision,
89 scattering->m[row], i,
94 osl_int_set_si(precision,
95 scattering->m[row], scattering->nb_columns-1,
96 inequ->data[inequ->size-1]);
103 * clay_util_array_output_dims_pad_zero function:
104 * Pad with zeros so that the values land on the alpha columns.
105 * For example, if we have [i, j], the result will be [0, i, 0, j, 0].
106 * \param[in,out] a array to pad (modified in place)
/* Pads `a' in place so that it holds size*2+1 entries, with the original
 * values moved to odd indices (e.g. [i, j] -> [0, i, 0, j, 0]).
 * NOTE(review): assuming clay_array_add appends one element and grows
 * a->size, the second loop starts at i = end-1 = 2*size and writes
 * a->data[i*2+1], i.e. up to index 4*size+1, past the 2*size+1 entries
 * the array holds. Starting from the original size-1 looks like the intent;
 * confirm against the lines of this function that are not visible here. */
108 void clay_util_array_output_dims_pad_zero(clay_array_p a) {
// final length: one zero before, between and after the original values
110 int end = a->size*2+1;
// append zeros until the array reaches `end' entries
112 for (i = a->size ; i < end ; i++)
113 clay_array_add(a, 0);
// walk backwards so that a[i] is read before any smaller index writes to it
115 for (i = end-1 ; i >= 0 ; i--) {
116 a->data[i*2+1] = a->data[i];
123 * clay_util_statement_set_inequation function:
124 * Insert a new inequation at the end of the scattering.
125 * The list must contain at most 3 arrays.
126 * Warning: here the output dims are complete
127 * example: if the output dims are : 0 i 0 j 0, you have
128 * to give all these numbers
129 * \param[in,out] statement
130 * \param[in] inequ {(([output, ...],) [param, ..],) [const]}
/* Appends one inequality row to statement->scattering, filled from the list
 * `inequ' of 1 to 3 arrays:
 *   3 arrays: [output dims], [params], [const]
 *   2 arrays: [params], [const]
 *   1 array : [const]
 * NOTE(review): the blank row is inserted and tagged as an inequality before
 * the list size is validated; whether CLAY_error aborts is not visible here.
 * NOTE(review): the one-array branch reads inequ->data[0]; presumably it
 * also runs for an empty list. Confirm callers never pass one. */
132 void clay_util_statement_set_inequation(
133 osl_statement_p statement,
136 osl_relation_p scattering = statement->scattering;
137 clay_array_p arr_dims = NULL, arr_params = NULL, arr_const = NULL;
// the new row is appended after the existing ones
138 int row = scattering->nb_rows;
140 int precision = scattering->precision;
142 // insert the inequation splitting (local dims are not in the inequation)
144 osl_relation_insert_blank_row(scattering, row);
// column 0 set to 1 tags the row as an inequality
145 osl_int_set_si(precision, &scattering->m[row][0], 1); // type inequation
// pick the arrays out of the list according to its length
147 if (inequ->size > 3) {
148 CLAY_error("list with more than 3 arrays not supported");
149 } else if (inequ->size == 3) {
150 arr_dims = inequ->data[0];
151 arr_params = inequ->data[1];
152 arr_const = inequ->data[2];
153 } else if (inequ->size == 2) {
154 arr_params = inequ->data[0];
155 arr_const = inequ->data[1];
// a single array only carries the constant
157 arr_const = inequ->data[0];
160 // affects output dims
161 if (inequ->size == 3) {
163 for (j = 0 ; j < arr_dims->size ; j++) {
164 osl_int_set_si(precision,
165 &scattering->m[row][i],
171 // affects parameters
172 if (inequ->size >= 2) {
// parameter columns follow the e/i column and all output, input and local dims
173 i = 1 + scattering->nb_output_dims + scattering->nb_input_dims +
174 scattering->nb_local_dims;
175 for (j = 0; j < arr_params->size ; j++) {
176 osl_int_set_si(precision,
177 &scattering->m[row][i],
178 arr_params->data[j]);
// the constant lives in the last column; only a one-value const array is used
184 if (inequ->size >= 1 && arr_const->size == 1) {
185 osl_int_set_si(precision,
186 &scattering->m[row][scattering->nb_columns-1],
193 * clay_util_statement_set_vector function:
194 * Set the equation on each line where the column of the output dim is
196 * \param[in,out] statement
197 * \param[in] vector {(([output, ...],) [param, ..],) [const]}
198 * \param[in] column column on the output dim
/* For every scattering row whose coefficient in output column `column' is
 * non-zero, combines that coefficient with the list `vector' of 1 to 3
 * arrays ([output dims], [params], [const], same layout as for
 * clay_util_statement_set_inequation):
 *   - output-dim entries are overwritten with coeff * value,
 *   - parameter and constant entries get coeff * value added
 *     (the product presumably goes through `tmp'; some arguments are on
 *     lines not visible here).
 * `tmp' is an osl integer allocated at the start and freed at the end. */
200 void clay_util_statement_set_vector(
201 osl_statement_p statement,
202 clay_list_p vector, int column) {
204 osl_relation_p scattering = statement->scattering;
205 clay_array_p arr_dims = NULL, arr_params = NULL, arr_const = NULL;
207 int precision = scattering->precision;
210 tmp = osl_int_malloc(precision);
// pick the arrays out of the list according to its length
212 if (vector->size > 3) {
213 CLAY_error("list with more than 3 arrays not supported");
214 } else if (vector->size == 3) {
215 arr_dims = vector->data[0];
216 arr_params = vector->data[1];
217 arr_const = vector->data[2];
218 } else if (vector->size == 2) {
219 arr_params = vector->data[0];
220 arr_const = vector->data[1];
222 arr_const = vector->data[0];
225 // for each line where there is a number different from zero on the
227 for (k = 0 ; k < scattering->nb_rows ; k++) {
// column 0 is the e/i column, so output dim `column' is at index 1+column
228 if (!osl_int_zero(precision, scattering->m[k][1+column])) {
230 // scattering = coeff_outputdim * shifting
232 // affect output dims
233 if (vector->size >= 3) {
235 for (j = 0 ; j < arr_dims->size ; j++) {
236 osl_int_mul_si(precision,
237 &scattering->m[k][i],
238 scattering->m[k][1+column],
244 // here we add the last value
245 // scattering += coeff_outputdim * shifting
247 // affects parameters
248 if (vector->size >= 2) {
// parameter columns follow the e/i column and all output, input and local dims
249 i = 1 + scattering->nb_output_dims + scattering->nb_input_dims +
250 scattering->nb_local_dims;
251 for (j = 0 ; j < arr_params->size ; j++) {
252 osl_int_mul_si(precision,
254 scattering->m[k][1+column],
255 arr_params->data[j]);
256 osl_int_add(precision,
257 &scattering->m[k][i],
// the constant (last column) is updated only for a one-value const array
265 if (vector->size >= 1 && arr_const->size == 1) {
266 osl_int_mul_si(precision,
268 scattering->m[k][1+column],
270 osl_int_add(precision,
271 &scattering->m[k][scattering->nb_columns-1],
272 scattering->m[k][scattering->nb_columns-1],
278 osl_int_free(precision, tmp);
283 * clay_util_relation_negate_row function:
284 * Negate the line at `row' (doesn't affect the e/i column)
285 * \param[in,out] r relation to modify
286 * \param[in] row row to negate
/* Rewrites row `row' of `r' in place:
 *   - the loop calls osl_int_oppose on columns 1 .. nb_columns-1, skipping
 *     column 0 (the e/i column); its arguments are on lines not visible here,
 *     presumably r->m[row][i];
 *   - the last column (the constant) is then decremented by one.
 * If the row is the inequality  e >= 0, the result is  -e - 1 >= 0, i.e.
 * e < 0 over the integers: the complement, not a plain sign flip. */
288 void clay_util_relation_negate_row(osl_relation_p r, int row) {
290 int precision = r->precision;
291 for (i = 1 ; i < r->nb_columns ; i++) {
292 osl_int_oppose(precision,
// turn the strict negation  -e > 0  into  -e - 1 >= 0
296 osl_int_decrement(precision,
297 &r->m[row][r->nb_columns-1],
298 r->m[row][r->nb_columns-1]);
303 * clay_util_statement_insert function:
304 * Insert `newstatement' before `statement', and set its beta value
305 * \param[in,out] statement
306 * \param[in,out] newstatement
307 * \param[in] column column on the output dim (where we want to split)
308 * this is a `alpha column' of the 2*d+1
309 * \param[in] order new beta value
310 * \return return statement->next (so newstatement)
/* Sets the constant (last column) of the scattering row that owns output
 * column `column' in `statement' (the value written is on a line not
 * visible here), then links `newstatement' directly after `statement' in
 * the list and advances `statement' to it.
 * NOTE(review): clay_util_relation_get_line returns -1 when no row has a
 * non-zero value in that column; the result is used as a row index below
 * without a visible check. */
312 osl_statement_p clay_util_statement_insert(osl_statement_p statement,
313 osl_statement_p newstatement,
316 osl_relation_p scattering = statement->scattering;
318 // the current statement is after the new statement
319 int row = clay_util_relation_get_line(scattering, column);
320 osl_int_set_si(scattering->precision,
321 &scattering->m[row][scattering->nb_columns-1],
324 // the order is not important in the statements list
325 newstatement->next = statement->next;
326 statement->next = newstatement;
327 statement = statement->next;
334 * clay_util_string_replace function:
335 * Search and replace a string with another string, in a string
336 * Minor modifications from :
337 * http://www.binarytides.com/blog/str_replace-for-c/
/* Builds a calloc'd copy of `string' in which every non-overlapping
 * occurrence of `search' is replaced by `replace'. The return statement is
 * on a line not visible here; presumably the caller owns the new string.
 * NOTE(review): with an empty `search' (search_size == 0), strstr keeps
 * returning the same pointer and both loops never terminate. Consider
 * rejecting an empty search string.
 * NOTE(review): the calloc result is used without a NULL check, and
 * strlen(new_string) is recomputed on every match, which is quadratic in
 * the number of matches. */
342 char* clay_util_string_replace(char *search, char *replace, char *string) {
343 char *ptr = NULL, *old = NULL, *new_string = NULL;
344 int count = 0 , search_size;
346 search_size = strlen(search);
348 // Count how many occurrences
349 for(ptr = strstr(string, search) ; ptr != NULL ;
350 ptr = strstr(ptr + search_size, search)) {
// result size = original + count * (len(replace) - len(search)) + NUL.
// The subtraction is done in size_t and wraps when `replace' is shorter;
// the modular sum still gives the right size as long as it is non-negative.
355 count = (strlen(replace) - search_size)*count + strlen(string) + 1;
// zero-filled, so strlen(new_string) starts at 0
356 new_string = calloc(count, 1);
358 // The start position
361 for(ptr = strstr(string, search) ; ptr != NULL ;
362 ptr = strstr(ptr + search_size, search)) {
363 // move ahead and copy some text from original subject, from a
// strncpy does not NUL-terminate here; the calloc'd zeros provide the terminator
365 strncpy(new_string + strlen(new_string), old, ptr - old);
367 // move ahead and copy the replacement text
368 strcpy(new_string + strlen(new_string) , replace);
370 // The new start position after this search match
371 old = ptr + search_size;
374 // Copy the part after the last search match
375 strcpy(new_string + strlen(new_string) , old);
382 * clay_util_scatnames_exists function:
383 * Return true if the iterator name is already in the scattering names.
384 * \param[in] scatnames
/* Linearly scans the NULL-terminated array scatnames->names->string for a
 * string equal to `iter'. The early return taken when the list is missing
 * or empty, and the return on a match, are on lines not visible here.
 * NOTE(review): `scatnames' itself is dereferenced without a NULL check. */
387 bool clay_util_scatnames_exists(osl_scatnames_p scatnames, char *iter) {
388 osl_strings_p names = scatnames->names;
// nothing to search: no names list, or an empty one
389 if (names == NULL || names->string[0] == NULL)
392 char **ptr = names->string;
394 while (*ptr != NULL) {
395 if (strcmp(*ptr, iter) == 0)
405 * clay_util_statement_find_iterator function:
406 * Return the index if iter is found in the original iterators list.
408 * \param[in] iter name of the original iterator we want to search
/* Looks up the statement's body (from the extbody extension, or from the
 * plain body extension; the condition choosing between them is on lines not
 * visible here), then scans its NULL-terminated iterator names for `iter'.
 * NOTE(review): no NULL check on the looked-up extbody/body is visible
 * before they are dereferenced. */
411 int clay_util_statement_find_iterator(osl_statement_p statement, char *iter) {
413 osl_extbody_p extbody = NULL;
415 extbody = osl_generic_lookup(statement->extension, OSL_URI_EXTBODY);
417 body = extbody->body;
419 body = osl_generic_lookup(statement->extension, OSL_URI_BODY);
421 char **ptr = body->iterators->string;
424 while (*ptr != NULL) {
425 if (strcmp(*ptr, iter) == 0)
436 * clay_util_scop_export_body function:
437 * Convert each extbody to a body structure
/* Starting from scop->statement, converts the statement's extbody into a
 * plain body. The OSL_URI_BODY extension is removed and replaced by a shell
 * around a clone of ebody->body, then the OSL_URI_EXTBODY extension is
 * removed. The statement loop and the conditions guarding these steps are on
 * lines not visible here. */
440 void clay_util_scop_export_body(osl_scop_p scop) {
444 osl_statement_p stmt = scop->statement;
445 osl_extbody_p ebody = NULL;
446 osl_body_p body = NULL;
447 osl_generic_p gen = NULL;
450 ebody = osl_generic_lookup(stmt->extension, OSL_URI_EXTBODY);
453 body = osl_generic_lookup(stmt->extension, OSL_URI_BODY);
// drop the old body before attaching the clone built from the extbody
455 osl_generic_remove(&stmt->extension, OSL_URI_BODY);
457 body = osl_body_clone(ebody->body);
458 gen = osl_generic_shell(body, osl_body_interface());
459 osl_generic_add(&stmt->extension, gen);
460 osl_generic_remove(&stmt->extension, OSL_URI_EXTBODY);
/* Appends one term of a linear expression to *dst through
 * osl_util_safe_strcat, which also receives the buffer size `hwm'.
 * Visible outputs:
 *   " + " separator (presumably when *print_plus is set),
 *   "<val>" alone (presumably for a term without a name),
 *   "name" alone (presumably for val == 1),
 *   "-name" for val == -1,
 *   "<val>*name" otherwise.
 * `buffer' is declared on a line not visible here; snprintf bounds each
 * write to 32 bytes.
 * NOTE(review): "void static" puts the storage-class specifier after the
 * type. C11 6.11.5 marks that placement as obsolescent; "static void" is
 * the usual order. */
469 void static clay_util_name_sprint(char **dst, int *hwm,
470 int *print_plus, int val, char *name) {
472 osl_util_safe_strcat(dst, " + ", hwm);
479 snprintf(buffer, 32, "%d", val);
480 osl_util_safe_strcat(dst, buffer, hwm);
483 osl_util_safe_strcat(dst, name, hwm);
484 } else if (val == -1) {
485 osl_util_safe_strcat(dst, "-", hwm);
486 osl_util_safe_strcat(dst, name, hwm);
488 snprintf(buffer, 32, "%d*", val);
489 osl_util_safe_strcat(dst, buffer, hwm);
490 osl_util_safe_strcat(dst, name, hwm);
497 * clay_util_body_regenerate_access function:
498 * Read the access array and re-generate the code in the body
499 * \param[in] ebody An extbody structure
500 * \param[in] access The relation to regenerate the code
501 * \param[in] index nth access (needed to look up the start and
502 * length of the access in the extbody structure)
/* Regenerates the source text of access number `index' inside the
 * extbody's body string.
 *  1. Calls CLAY_error if any row is an inequality, and checks the
 *     output-dim columns for an identity-like shape (the exact test spans
 *     lines not visible here).
 *  2. Builds a new body: the prefix [0, start), the array name, one
 *     "[...]" per output dim i >= 1 (input dims, then parameters, then
 *     presumably the constant column), then the saved suffix.
 *  3. Updates ebody->length[index], frees the old body string, and shifts
 *     the start offset of every later access (except those at -1) by the
 *     length difference.
 * NOTE(review): the result of clay_util_arrays_search indexes
 * arrays->names without a visible "not found" check.
 * NOTE(review): the early-exit test also rejects an access that ends
 * exactly at the end of the body (start + len == body_len). Confirm this
 * is intended.
 * NOTE(review): end_body is filled by memcpy of sz bytes; no visible line
 * NUL-terminates it before it is passed to osl_util_safe_strcat. */
504 void clay_util_body_regenerate_access(osl_extbody_p ebody,
505 osl_relation_p access,
508 osl_scatnames_p scatnames,
509 osl_strings_p params) {
// nothing to regenerate without the extensions or with an out-of-range index
511 if (!arrays || !scatnames || !params || access->nb_output_dims == 0 ||
512 index >= ebody->nb_access)
515 const int precision = access->precision;
516 int i, j, k, row, val, print_plus;
518 // check if there are no inequ
519 for (i = 0 ; i < access->nb_rows ; i++) {
520 if (!osl_int_zero(precision, access->m[i][0]))
521 CLAY_error("I don't know how to regenerate access with inequalities");
524 // check identity matrix in output dims
526 for (j = 0 ; j < access->nb_output_dims ; j++) {
528 for (i = 0 ; i < access->nb_rows ; i++)
529 if (!osl_int_zero(precision, access->m[i][j+1])) {
531 CLAY_error("I don't know how to regenerate access with "
532 "dependences in output dims");
537 char *body = ebody->body->expression->string[0];
538 int body_len = strlen(body);
539 int start = ebody->start[index];
540 int len = ebody->length[index];
541 int is_zero; // if the line contains only zeros
// skip accesses whose recorded position is out of range or unknown (-1/-1)
543 if (start >= body_len || start + len >= body_len || (start == -1 && len == -1))
547 char end_body[OSL_MAX_STRING];
548 int hwm = OSL_MAX_STRING;
550 CLAY_malloc(new_body, char *, OSL_MAX_STRING * sizeof(char));
552 // copy the beginning of the body
553 if (start+1 >= OSL_MAX_STRING)
554 CLAY_error("memcpy: please recompile osl with a higher OSL_MAX_STRING");
555 memcpy(new_body, body, start);
556 new_body[start] = '\0';
558 // save the end in a buffer
559 int sz = body_len - start - len;
560 if (sz + 1 >= OSL_MAX_STRING)
561 CLAY_error("memcpy: please recompile osl with a higher OSL_MAX_STRING");
562 memcpy(end_body, body + start + len, sz);
566 // copy access name string
567 val = osl_relation_get_array_id(access);
568 val = clay_util_arrays_search(arrays, val); // get the index in the array
569 osl_util_safe_strcat(&new_body, arrays->names[val], &hwm);
572 // generate each dims
// output dim 0 is the array id, so subscripts start at output dim 1
573 for (i = 1 ; i < access->nb_output_dims ; i++) {
574 row = clay_util_relation_get_line(access, i);
578 osl_util_safe_strcat(&new_body, "[", &hwm);
// first input-dim column: after the e/i column and the output dims
582 k = 1 + access->nb_output_dims;
585 for (j = 0 ; j < access->nb_input_dims ; j++, k++) {
586 val = osl_int_get_si(precision, access->m[row][k]);
588 clay_util_name_sprint(&new_body,
// iterator names sit at odd indices of the 2*d+1 scattering names
592 scatnames->names->string[j*2+1]);
598 for (j = 0 ; j < access->nb_parameters ; j++, k++) {
599 val = osl_int_get_si(precision, access->m[row][k]);
601 clay_util_name_sprint(&new_body,
611 val = osl_int_get_si(precision, access->m[row][k]);
// print the constant when non-zero, or when it is the only term of the line
612 if (val != 0 || is_zero)
613 clay_util_name_sprint(&new_body,
619 osl_util_safe_strcat(&new_body, "]", &hwm);
622 // length of the generated access
623 ebody->length[index] = strlen(new_body) - start;
626 osl_util_safe_strcat(&new_body, end_body, &hwm);
629 free(ebody->body->expression->string[0]);
630 ebody->body->expression->string[0] = new_body;
// later accesses moved by the difference between new and old lengths
633 int diff = ebody->length[index] - len;
634 for (i = index+1 ; i < ebody->nb_access ; i++)
635 if (ebody->start[i] != -1)
636 ebody->start[i] += diff;
641 * clay_util_arrays_search function:
642 * Search the string which corresponds to id
643 * arrays is an extension of osl
644 * \param[in] arrays An arrays osl structure
645 * \param[in] id The id to search
646 * \return Return the index in the arrays
/* Linear search of arrays->id[0 .. nb_names-1] for `id'. The return
 * statements (index on match, value when not found) are on lines not
 * visible here.
 * NOTE(review): `id' is unsigned int, so if arrays->id[i] is a signed int
 * it is converted to unsigned for the comparison. */
648 int clay_util_arrays_search(osl_arrays_p arrays, unsigned int id) {
650 for (i = 0 ; i < arrays->nb_names ; i++) {
651 if (arrays->id[i] == id)
659 * clay_util_foreach_access function:
660 * Execute func on each access which corresponds to access_name
661 * \param[in,out] scop
663 * \param[in] access_name The id to search
664 * \param[in] func The function to execute for each access
665 * The function takes an osl_relation_list_p in
666 * parameter (the elt can be modified) and must
667 * return a Clay error code or CLAY_SUCCESS
668 * \param[in] args args of `func'
669 * \param[in] regenerate_body If 1: after each call to func,
670 * clay_util_body_regenerate_access is also called
671 * \return Return a Clay error code or CLAY_SUCCESS
/* Walks the statements from the one found by clay_beta_find (`beta' is
 * declared on a line not visible here) and, for those passing
 * clay_beta_check, calls `func' on every access whose array id equals
 * `access_name'. If `func' fails, the statement body is printed to stderr.
 * When `regenerate_body' is set, the access text is regenerated, and the
 * statement's OSL_URI_BODY extension is replaced by a clone of the
 * extbody's body.
 * NOTE(review): the final warning prints the unsigned `access_name' with
 * "%d"; "%u" is the matching conversion.
 * NOTE(review): a missing arrays/scatnames/params extension only produces a
 * warning here, while clay_util_body_regenerate_access appears to bail out
 * early when any of them is NULL. */
673 int clay_util_foreach_access(osl_scop_p scop,
675 unsigned int access_name,
676 int (*func)(osl_relation_list_p, void*),
678 int regenerate_body) {
680 osl_statement_p stmt = scop->statement;
681 osl_relation_list_p access;
683 osl_extbody_p ebody = NULL;
684 osl_body_p body = NULL;
685 osl_generic_p gen = NULL;
690 // TODO : global vars ?
692 osl_scatnames_p scatnames;
693 osl_strings_p params;
694 arrays = osl_generic_lookup(scop->extension, OSL_URI_ARRAYS);
695 scatnames = osl_generic_lookup(scop->extension, OSL_URI_SCATNAMES);
696 params = osl_generic_lookup(scop->parameters, OSL_URI_STRINGS);
698 if (!arrays || !scatnames || !params)
699 CLAY_warning("no arrays or scatnames extension");
701 stmt = clay_beta_find(scop->statement, beta);
703 return CLAY_ERROR_BETA_NOT_FOUND;
705 // for each access in the beta, we search the access_name
706 while (stmt != NULL) {
707 if (clay_beta_check(stmt, beta)) {
708 access = stmt->access;
// `a' is presumably the relation of the current access-list element
714 if (osl_relation_get_array_id(a) == access_name) {
717 ebody = osl_generic_lookup(stmt->extension, OSL_URI_EXTBODY);
719 CLAY_error("extbody uri not found on this statement");
720 fprintf(stderr, "%s\n",
721 ebody->body->expression->string[0]);
// `func' may modify the access relation in place
725 ret = (*func)(access, args);
726 if (ret != CLAY_SUCCESS) {
727 fprintf(stderr, "%s\n",
728 ebody->body->expression->string[0]);
732 // re-generate the body
733 if (regenerate_body) {
734 clay_util_body_regenerate_access(
743 //synchronize extbody with body
744 body = osl_generic_lookup(stmt->extension, OSL_URI_BODY);
746 osl_generic_remove(&stmt->extension, OSL_URI_BODY);
747 body = osl_body_clone(ebody->body);
748 gen = osl_generic_shell(body, osl_body_interface());
749 osl_generic_add(&stmt->extension, gen);
756 access = access->next;
764 fprintf(stderr,"[Clay] Warning: access number %d not found\n", access_name);
771 * clay_util_relation_get_line function:
772 * Because the lines in the scattering matrix may not be ordered, we have to
773 * search the corresponding line. It returns the first line where the value is
774 * different from zero in the `column'. `column' is between 0 and
776 * \param[in] relation
777 * \param[in] column output dimension whose line we search for
778 * \return Return the real line
/* Returns the index of the first row of `relation' with a non-zero
 * coefficient in output dimension `column' (matrix column column+1, since
 * column 0 is the e/i column), or -1 when no row has one.
 * NOTE(review): the guard accepts column == nb_output_dims. Then column+1
 * is the first input-dim column, not an output dim; `>=' looks intended.
 * (What the guard returns is on a line not visible here.) */
780 int clay_util_relation_get_line(osl_relation_p relation, int column) {
781 if (column < 0 || column > relation->nb_output_dims)
784 int precision = relation->precision;
785 for (i = 0 ; i < relation->nb_rows ; i++) {
786 if (!osl_int_zero(precision, relation->m[i][column+1])) {
// the loop ran to completion only when no row matched
790 return (i == relation->nb_rows ? -1 : i);