16 typedef struct kt_for_t
{
20 void (*func
)(void*,long,int);
24 static inline long steal_work(kt_for_t
*t
)
27 long k
, min
= LONG_MAX
;
28 for (i
= 0; i
< t
->n_threads
; ++i
)
29 if (min
> t
->w
[i
].i
) min
= t
->w
[i
].i
, min_i
= i
;
30 k
= __sync_fetch_and_add(&t
->w
[min_i
].i
, t
->n_threads
);
31 return k
>= t
->n
? -1 : k
;
34 static void *ktf_worker(void *data
)
36 ktf_worker_t
*w
= (ktf_worker_t
*)data
;
39 i
= __sync_fetch_and_add(&w
->i
, w
->t
->n_threads
);
40 if (i
>= w
->t
->n
) break;
41 w
->t
->func(w
->t
->data
, i
, w
- w
->t
->w
);
43 while ((i
= steal_work(w
->t
)) >= 0)
44 w
->t
->func(w
->t
->data
, i
, w
- w
->t
->w
);
48 void kt_for(int n_threads
, void (*func
)(void*,long,int), void *data
, long n
)
54 t
.func
= func
, t
.data
= data
, t
.n_threads
= n_threads
, t
.n
= n
;
55 t
.w
= (ktf_worker_t
*)alloca(n_threads
* sizeof(ktf_worker_t
));
56 tid
= (pthread_t
*)alloca(n_threads
* sizeof(pthread_t
));
57 for (i
= 0; i
< n_threads
; ++i
)
58 t
.w
[i
].t
= &t
, t
.w
[i
].i
= i
;
59 for (i
= 0; i
< n_threads
; ++i
) pthread_create(&tid
[i
], 0, ktf_worker
, &t
.w
[i
]);
60 for (i
= 0; i
< n_threads
; ++i
) pthread_join(tid
[i
], 0);
63 for (j
= 0; j
< n
; ++j
) func(data
, j
, 0);
67 /***************************
68 * kt_for with thread pool *
69 ***************************/
74 struct kt_forpool_t
*t
;
79 typedef struct kt_forpool_t
{
80 int n_threads
, n_pending
;
84 void (*func
)(void*,long,int);
86 pthread_mutex_t mutex
;
87 pthread_cond_t cv_m
, cv_s
;
90 static inline long kt_fp_steal_work(kt_forpool_t
*t
)
93 long k
, min
= LONG_MAX
;
94 for (i
= 0; i
< t
->n_threads
; ++i
)
95 if (min
> t
->w
[i
].i
) min
= t
->w
[i
].i
, min_i
= i
;
96 k
= __sync_fetch_and_add(&t
->w
[min_i
].i
, t
->n_threads
);
97 return k
>= t
->n
? -1 : k
;
100 static void *kt_fp_worker(void *data
)
102 kto_worker_t
*w
= (kto_worker_t
*)data
;
103 kt_forpool_t
*fp
= w
->t
;
107 pthread_mutex_lock(&fp
->mutex
);
108 if (--fp
->n_pending
== 0)
109 pthread_cond_signal(&fp
->cv_m
);
111 while (w
->action
== 0) pthread_cond_wait(&fp
->cv_s
, &fp
->mutex
);
113 pthread_mutex_unlock(&fp
->mutex
);
114 if (action
< 0) break;
115 for (;;) { // process jobs allocated to this worker
116 i
= __sync_fetch_and_add(&w
->i
, fp
->n_threads
);
117 if (i
>= fp
->n
) break;
118 fp
->func(fp
->data
, i
, w
- fp
->w
);
120 while ((i
= kt_fp_steal_work(fp
)) >= 0) // steal jobs allocated to other workers
121 fp
->func(fp
->data
, i
, w
- fp
->w
);
126 void *kt_forpool_init(int n_threads
)
130 fp
= (kt_forpool_t
*)calloc(1, sizeof(kt_forpool_t
));
131 fp
->n_threads
= fp
->n_pending
= n_threads
;
132 fp
->tid
= (pthread_t
*)calloc(fp
->n_threads
, sizeof(pthread_t
));
133 fp
->w
= (kto_worker_t
*)calloc(fp
->n_threads
, sizeof(kto_worker_t
));
134 for (i
= 0; i
< fp
->n_threads
; ++i
) fp
->w
[i
].t
= fp
;
135 pthread_mutex_init(&fp
->mutex
, 0);
136 pthread_cond_init(&fp
->cv_m
, 0);
137 pthread_cond_init(&fp
->cv_s
, 0);
138 for (i
= 0; i
< fp
->n_threads
; ++i
) pthread_create(&fp
->tid
[i
], 0, kt_fp_worker
, &fp
->w
[i
]);
139 pthread_mutex_lock(&fp
->mutex
);
140 while (fp
->n_pending
) pthread_cond_wait(&fp
->cv_m
, &fp
->mutex
);
141 pthread_mutex_unlock(&fp
->mutex
);
145 void kt_forpool_destroy(void *_fp
)
147 kt_forpool_t
*fp
= (kt_forpool_t
*)_fp
;
149 for (i
= 0; i
< fp
->n_threads
; ++i
) fp
->w
[i
].action
= -1;
150 pthread_cond_broadcast(&fp
->cv_s
);
151 for (i
= 0; i
< fp
->n_threads
; ++i
) pthread_join(fp
->tid
[i
], 0);
152 pthread_cond_destroy(&fp
->cv_s
);
153 pthread_cond_destroy(&fp
->cv_m
);
154 pthread_mutex_destroy(&fp
->mutex
);
155 free(fp
->w
); free(fp
->tid
); free(fp
);
158 void kt_forpool(void *_fp
, void (*func
)(void*,long,int), void *data
, long n
)
160 kt_forpool_t
*fp
= (kt_forpool_t
*)_fp
;
162 if (fp
&& fp
->n_threads
> 1) {
163 fp
->n
= n
, fp
->func
= func
, fp
->data
= data
, fp
->n_pending
= fp
->n_threads
;
164 for (i
= 0; i
< fp
->n_threads
; ++i
) fp
->w
[i
].i
= i
, fp
->w
[i
].action
= 1;
165 pthread_mutex_lock(&fp
->mutex
);
166 pthread_cond_broadcast(&fp
->cv_s
);
167 while (fp
->n_pending
) pthread_cond_wait(&fp
->cv_m
, &fp
->mutex
);
168 pthread_mutex_unlock(&fp
->mutex
);
169 } else for (i
= 0; i
< n
; ++i
) func(data
, i
, 0);
185 typedef struct ktp_t
{
187 void *(*func
)(void*, int, void*);
189 int n_workers
, n_steps
;
190 ktp_worker_t
*workers
;
191 pthread_mutex_t mutex
;
195 static void *ktp_worker(void *data
)
197 ktp_worker_t
*w
= (ktp_worker_t
*)data
;
199 while (w
->step
< p
->n_steps
) {
200 // test whether we can kick off the job with this worker
201 pthread_mutex_lock(&p
->mutex
);
204 // test whether another worker is doing the same step
205 for (i
= 0; i
< p
->n_workers
; ++i
) {
206 if (w
== &p
->workers
[i
]) continue; // ignore itself
207 if (p
->workers
[i
].step
<= w
->step
&& p
->workers
[i
].index
< w
->index
)
210 if (i
== p
->n_workers
) break; // no workers with smaller indices are doing w->step or the previous steps
211 pthread_cond_wait(&p
->cv
, &p
->mutex
);
213 pthread_mutex_unlock(&p
->mutex
);
215 // working on w->step
216 w
->data
= p
->func(p
->shared
, w
->step
, w
->step
? w
->data
: 0); // for the first step, input is NULL
218 // update step and let other workers know
219 pthread_mutex_lock(&p
->mutex
);
220 w
->step
= w
->step
== p
->n_steps
- 1 || w
->data
? (w
->step
+ 1) % p
->n_steps
: p
->n_steps
;
221 if (w
->step
== 0) w
->index
= p
->index
++;
222 pthread_cond_broadcast(&p
->cv
);
223 pthread_mutex_unlock(&p
->mutex
);
228 void kt_pipeline(int n_threads
, void *(*func
)(void*, int, void*), void *shared_data
, int n_steps
)
234 if (n_threads
< 1) n_threads
= 1;
235 aux
.n_workers
= n_threads
;
236 aux
.n_steps
= n_steps
;
238 aux
.shared
= shared_data
;
240 pthread_mutex_init(&aux
.mutex
, 0);
241 pthread_cond_init(&aux
.cv
, 0);
243 aux
.workers
= (ktp_worker_t
*)alloca(n_threads
* sizeof(ktp_worker_t
));
244 for (i
= 0; i
< n_threads
; ++i
) {
245 ktp_worker_t
*w
= &aux
.workers
[i
];
246 w
->step
= 0; w
->pl
= &aux
; w
->data
= 0;
247 w
->index
= aux
.index
++;
250 tid
= (pthread_t
*)alloca(n_threads
* sizeof(pthread_t
));
251 for (i
= 0; i
< n_threads
; ++i
) pthread_create(&tid
[i
], 0, ktp_worker
, &aux
.workers
[i
]);
252 for (i
= 0; i
< n_threads
; ++i
) pthread_join(tid
[i
], 0);
254 pthread_mutex_destroy(&aux
.mutex
);
255 pthread_cond_destroy(&aux
.cv
);