//===----- Workshare.cpp - OpenMP workshare implementation ------ C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the KMPC interface
// for the loop construct plus other worksharing constructs that use the same
// interface as loops.
//
//===----------------------------------------------------------------------===//

#include "Workshare.h"
#include "Debug.h"
#include "Interface.h"
#include "Mapping.h"
#include "State.h"
#include "Synchronization.h"
#include "Types.h"
#include "Utils.h"

using namespace ompx;

// TODO:
struct DynamicScheduleTracker {
  int64_t Chunk;
  int64_t LoopUpperBound;
  int64_t NextLowerBound;
  int64_t Stride;
  kmp_sched_t ScheduleType;
  DynamicScheduleTracker *NextDST;
};

#define ASSERT0(...)

// used by the library for the interface with the app
#define DISPATCH_FINISHED 0
#define DISPATCH_NOTFINISHED 1

// used by dynamic scheduling
#define FINISHED 0
#define NOT_FINISHED 1
#define LAST_CHUNK 2

#pragma omp begin declare target device_type(nohost)

// TODO: This variable is a hack inherited from the old runtime.
static uint64_t SHARED(Cnt);

template <typename T, typename ST> struct omptarget_nvptx_LoopSupport {
  ////////////////////////////////////////////////////////////////////////////////
  // Loop with static scheduling with chunk

  // Generic implementation of OMP loop scheduling with static policy
  /*! \brief Calculate initial bounds for static loop and stride
   *  @param[in] loc location in code of the call (not used here)
   *  @param[in] global_tid global thread id
   *  @param[in] schedtype type of scheduling (see omptarget-nvptx.h)
   *  @param[in] plastiter pointer to last iteration
   *  @param[in,out] pointer to loop lower bound. It will contain the value of
   *  the lower bound of the first chunk
   *  @param[in,out] pointer to loop upper bound. It will contain the value of
   *  the upper bound of the first chunk
   *  @param[in,out] pointer to loop stride. It will contain the value of the
   *  stride between two successive chunks executed by the same thread
   *  @param[in] loop increment bump
   *  @param[in] chunk size
   */

  // helper function for static chunk
  static void ForStaticChunk(int &last, T &lb, T &ub, ST &stride, ST chunk,
                             T entityId, T numberOfEntities) {
    // Each thread executes multiple chunks, all of the same size, except
    // the last one.
    // Distance between two successive chunks:
    stride = numberOfEntities * chunk;
    lb = lb + entityId * chunk;
    T inputUb = ub;
    ub = lb + chunk - 1; // Clang uses i <= ub
    // Say ub' is the beginning of the last chunk. Then whoever has a
    // lower bound plus a multiple of the increment equal to ub' is
    // the last one.
    T beginingLastChunk = inputUb - (inputUb % chunk);
    last = ((beginingLastChunk - lb) % stride) == 0;
  }
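
  // Editor's worked example (hypothetical values, for illustration only): with
  // lb = 0, ub = 99, chunk = 10 and numberOfEntities = 4 threads, entityId = 1
  // gets stride = 40, lb = 10, ub = 19, i.e. it later executes the chunks
  // [10,19], [50,59] and [90,99]. The last chunk starts at
  // 99 - (99 % 10) = 90, and (90 - 10) % 40 == 0, so this thread also reports
  // last = 1; for entityId = 0 the test gives (90 - 0) % 40 == 10 != 0, hence
  // last = 0.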

  ////////////////////////////////////////////////////////////////////////////////
  // Loop with static scheduling without chunk

  // helper function for static no chunk
  static void ForStaticNoChunk(int &last, T &lb, T &ub, ST &stride, ST &chunk,
                               T entityId, T numberOfEntities) {
    // No chunk size specified. Each thread or warp gets at most one
    // chunk; chunks are all of nearly equal size.
    T loopSize = ub - lb + 1;

    chunk = loopSize / numberOfEntities;
    T leftOver = loopSize - chunk * numberOfEntities;

    if (entityId < leftOver) {
      chunk++;
      lb = lb + entityId * chunk;
    } else {
      lb = lb + entityId * chunk + leftOver;
    }

    T inputUb = ub;
    ub = lb + chunk - 1; // Clang uses i <= ub
    last = lb <= inputUb && inputUb <= ub;
    stride = loopSize; // make sure we only do 1 chunk per warp
  }
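
  // Editor's worked example (hypothetical values): for lb = 0, ub = 102 and
  // numberOfEntities = 4, loopSize = 103, chunk = 25 and leftOver = 3.
  // Entities 0..2 each take an enlarged chunk of 26 iterations ([0,25],
  // [26,51], [52,77]); entity 3 keeps chunk = 25 and gets lb = 3 * 25 + 3 = 78,
  // ub = 102, so it owns the original upper bound and reports last = 1.
  // stride = 103 guarantees a second chunk is never taken.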

  ////////////////////////////////////////////////////////////////////////////////
  // Support for Static Init

  static void for_static_init(int32_t, int32_t schedtype, int32_t *plastiter,
                              T *plower, T *pupper, ST *pstride, ST chunk,
                              bool IsSPMDExecutionMode) {
    int32_t gtid = omp_get_thread_num();
    int numberOfActiveOMPThreads = omp_get_num_threads();

    // All warps in excess of the maximum requested do not execute the loop.
    ASSERT0(LT_FUSSY, gtid < numberOfActiveOMPThreads,
            "current thread is not needed here; error");

    // copy
    int lastiter = 0;
    T lb = *plower;
    T ub = *pupper;
    ST stride = *pstride;

    // init
    switch (SCHEDULE_WITHOUT_MODIFIERS(schedtype)) {
    case kmp_sched_static_chunk: {
      if (chunk > 0) {
        ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_static_balanced_chunk: {
      if (chunk > 0) {
        // round up to make sure the chunk is enough to cover all iterations
        T tripCount = ub - lb + 1; // +1 because ub is inclusive
        T span = (tripCount + numberOfActiveOMPThreads - 1) /
                 numberOfActiveOMPThreads;
        // perform chunk adjustment
        chunk = (span + chunk - 1) & ~(chunk - 1);

        ASSERT0(LT_FUSSY, ub >= lb, "ub must be >= lb.");
        T oldUb = ub;
        ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
        if (ub > oldUb)
          ub = oldUb;
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_static_nochunk: {
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
      break;
    }
    case kmp_sched_distr_static_chunk: {
      if (chunk > 0) {
        ForStaticChunk(lastiter, lb, ub, stride, chunk, omp_get_team_num(),
                       omp_get_num_teams());
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_distr_static_nochunk: {
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, omp_get_team_num(),
                       omp_get_num_teams());
      break;
    }
    case kmp_sched_distr_static_chunk_sched_static_chunkone: {
      ForStaticChunk(lastiter, lb, ub, stride, chunk,
                     numberOfActiveOMPThreads * omp_get_team_num() + gtid,
                     omp_get_num_teams() * numberOfActiveOMPThreads);
      break;
    }
    default: {
      // ASSERT(LT_FUSSY, 0, "unknown schedtype %d", (int)schedtype);
      ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                     numberOfActiveOMPThreads);
      break;
    }
    }
    // copy back
    *plastiter = lastiter;
    *plower = lb;
    *pupper = ub;
    *pstride = stride;
  }
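
  // Editor's note on the kmp_sched_static_balanced_chunk adjustment above
  // (hypothetical values, for illustration only): with tripCount = 100, 8
  // active threads and a requested chunk of 4, span = ceil(100 / 8) = 13 and
  // chunk = (13 + 4 - 1) & ~3 = 16, so 8 chunks of 16 cover all 100 iterations
  // and the final ub is clamped back to the original upper bound. Note that
  // the bit-mask rounding assumes the requested chunk is a power of two.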

  ////////////////////////////////////////////////////////////////////////////////
  // Support for dispatch Init

  static int OrderedSchedule(kmp_sched_t schedule) {
    return schedule >= kmp_sched_ordered_first &&
           schedule <= kmp_sched_ordered_last;
  }

  static void dispatch_init(IdentTy *loc, int32_t threadId,
                            kmp_sched_t schedule, T lb, T ub, ST st, ST chunk,
                            DynamicScheduleTracker *DST) {
    int tid = mapping::getThreadIdInBlock();
    T tnum = omp_get_num_threads();
    T tripCount = ub - lb + 1; // +1 because ub is inclusive
    ASSERT0(LT_FUSSY, threadId < tnum,
            "current thread is not needed here; error");

    /* Currently just ignore the monotonic and non-monotonic modifiers
     * (the compiler isn't producing them yet anyway).
     * When it is we'll want to look at them somewhere here and use that
     * information to add to our schedule choice. We shouldn't need to pass
     * them on, they merely affect which schedule we can legally choose for
     * various dynamic cases. (In particular, whether or not a stealing scheme
     * is legal).
     */
    schedule = SCHEDULE_WITHOUT_MODIFIERS(schedule);

    // Process schedule.
    if (tnum == 1 || tripCount <= 1 || OrderedSchedule(schedule)) {
      if (OrderedSchedule(schedule))
        __kmpc_barrier(loc, threadId);
      schedule = kmp_sched_static_chunk;
      chunk = tripCount; // one thread gets the whole loop
    } else if (schedule == kmp_sched_runtime) {
      // process runtime
      omp_sched_t rtSched;
      int ChunkInt;
      omp_get_schedule(&rtSched, &ChunkInt);
      chunk = ChunkInt;
      switch (rtSched) {
      case omp_sched_static: {
        if (chunk > 0)
          schedule = kmp_sched_static_chunk;
        else
          schedule = kmp_sched_static_nochunk;
        break;
      }
      case omp_sched_auto: {
        schedule = kmp_sched_static_chunk;
        chunk = 1;
        break;
      }
      case omp_sched_dynamic:
      case omp_sched_guided: {
        schedule = kmp_sched_dynamic;
        break;
      }
      }
    } else if (schedule == kmp_sched_auto) {
      schedule = kmp_sched_static_chunk;
      chunk = 1;
    } else {
      // ASSERT(LT_FUSSY,
      //        schedule == kmp_sched_dynamic || schedule == kmp_sched_guided,
      //        "unknown schedule %d & chunk %lld\n", (int)schedule,
      //        (long long)chunk);
    }

    // init schedules
    if (schedule == kmp_sched_static_chunk) {
      ASSERT0(LT_FUSSY, chunk > 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      ForStaticChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_static_balanced_chunk) {
      ASSERT0(LT_FUSSY, chunk > 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      // round up to make sure the chunk is enough to cover all iterations
      T span = (tripCount + tnum - 1) / tnum;
      // perform chunk adjustment
      chunk = (span + chunk - 1) & ~(chunk - 1);

      T oldUb = ub;
      ForStaticChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      ASSERT0(LT_FUSSY, ub >= lb, "ub must be >= lb.");
      if (ub > oldUb)
        ub = oldUb;
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_static_nochunk) {
      ASSERT0(LT_FUSSY, chunk == 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_dynamic || schedule == kmp_sched_guided) {
      // save data
      DST->ScheduleType = schedule;
      if (chunk < 1)
        chunk = 1;
      DST->Chunk = chunk;
      DST->LoopUpperBound = ub;
      DST->NextLowerBound = lb;
      __kmpc_barrier(loc, threadId);
      if (tid == 0) {
        Cnt = 0;
        fence::team(atomic::seq_cst);
      }
      __kmpc_barrier(loc, threadId);
    }
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Support for dispatch next

  static uint64_t NextIter() {
    __kmpc_impl_lanemask_t active = mapping::activemask();
    uint32_t leader = utils::ffs(active) - 1;
    uint32_t change = utils::popc(active);
    __kmpc_impl_lanemask_t lane_mask_lt = mapping::lanemaskLT();
    unsigned int rank = utils::popc(active & lane_mask_lt);
    uint64_t warp_res = 0;
    if (rank == 0) {
      warp_res = atomic::add(&Cnt, change, atomic::seq_cst);
    }
    warp_res = utils::shuffle(active, warp_res, leader);
    return warp_res + rank;
  }
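
  // Editor's worked example (hypothetical values): suppose lanes 0-3 are
  // active, so active = 0xF, leader = 0, change = 4, and lane 2 computes
  // rank = 2. Only the leader (rank 0) performs the atomic add; if Cnt was 12,
  // the add returns 12 and bumps Cnt to 16. The value 12 is shuffled to all
  // active lanes, so lane 2 returns 12 + 2 = 14: one atomic operation hands
  // out one distinct iteration index per active lane.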

  static int DynamicNextChunk(T &lb, T &ub, T chunkSize, T loopLowerBound,
                              T loopUpperBound) {
    T N = NextIter();
    lb = loopLowerBound + N * chunkSize;
    ub = lb + chunkSize - 1; // Clang uses i <= ub

    // 3 result cases:
    //  a. lb and ub < loopUpperBound --> NOT_FINISHED
    //  b. lb < loopUpperBound and ub >= loopUpperBound: last chunk -->
    //     NOT_FINISHED
    //  c. lb and ub >= loopUpperBound: empty chunk --> FINISHED
    // a.
    if (lb <= loopUpperBound && ub < loopUpperBound) {
      return NOT_FINISHED;
    }
    // b.
    if (lb <= loopUpperBound) {
      ub = loopUpperBound;
      return LAST_CHUNK;
    }
    // c. if we are here, we are in case 'c'
    lb = loopUpperBound + 2;
    ub = loopUpperBound + 1;
    return FINISHED;
  }
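
  // Editor's worked example (hypothetical values): with loopLowerBound = 0,
  // loopUpperBound = 99 and chunkSize = 10, N = 5 yields [50,59] and
  // NOT_FINISHED; N = 9 yields lb = 90, ub = 99, which is reported as
  // LAST_CHUNK with ub clamped to 99; N = 10 yields lb = 100 > 99, so the
  // bounds are set to an empty range and FINISHED is returned.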

  static int dispatch_next(IdentTy *loc, int32_t gtid, int32_t *plast,
                           T *plower, T *pupper, ST *pstride,
                           DynamicScheduleTracker *DST) {
    // ID of a thread in its own warp

    // automatically selects thread or warp ID based on selected implementation
    ASSERT0(LT_FUSSY, gtid < omp_get_num_threads(),
            "current thread is not needed here; error");
    // retrieve schedule
    kmp_sched_t schedule = DST->ScheduleType;

    // xxx reduce to one
    if (schedule == kmp_sched_static_chunk ||
        schedule == kmp_sched_static_nochunk) {
      T myLb = DST->NextLowerBound;
      T ub = DST->LoopUpperBound;
      // finished?
      if (myLb > ub) {
        return DISPATCH_FINISHED;
      }
      // not finished, save current bounds
      ST chunk = DST->Chunk;
      *plower = myLb;
      T myUb = myLb + chunk - 1; // Clang uses i <= ub
      if (myUb > ub)
        myUb = ub;
      *pupper = myUb;
      *plast = (int32_t)(myUb == ub);

      // increment next lower bound by the stride
      ST stride = DST->Stride;
      DST->NextLowerBound = myLb + stride;
      return DISPATCH_NOTFINISHED;
    }
    ASSERT0(LT_FUSSY,
            schedule == kmp_sched_dynamic || schedule == kmp_sched_guided,
            "bad sched");
    T myLb, myUb;
    int finished = DynamicNextChunk(myLb, myUb, DST->Chunk, DST->NextLowerBound,
                                    DST->LoopUpperBound);

    if (finished == FINISHED)
      return DISPATCH_FINISHED;

    // not finished (either not finished or last chunk)
    *plast = (int32_t)(finished == LAST_CHUNK);
    *plower = myLb;
    *pupper = myUb;
    *pstride = 1;

    return DISPATCH_NOTFINISHED;
  }

  static void dispatch_fini() {
    // nothing
  }

  ////////////////////////////////////////////////////////////////////////////////
  // end of template class that encapsulates all the helper functions
  ////////////////////////////////////////////////////////////////////////////////
};

////////////////////////////////////////////////////////////////////////////////
// KMP interface implementation (dyn loops)
////////////////////////////////////////////////////////////////////////////////

// TODO: Expand the dispatch API to take a DST pointer which can then be
// allocated properly without malloc.
// For now, each team will contain an LDS pointer (ThreadDST) to a global array
// of references to the DST structs allocated (in global memory) for each thread
// in the team. The global memory array is allocated during the init phase if it
// was not allocated already and will be deallocated when the dispatch phase
// ends:
//
//  __kmpc_dispatch_init
//
//  ** Dispatch loop **
//
//  __kmpc_dispatch_deinit
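//
// Editor's sketch of that sequence from the caller's side (illustrative only;
// the real calls are emitted by the compiler and the 4/4u/8/8u suffix depends
// on the loop's induction type):
//
//   int32_t last, lb, ub, st;
//   __kmpc_dispatch_init_4(loc, tid, schedule, LB, UB, Stride, Chunk);
//   while (__kmpc_dispatch_next_4(loc, tid, &last, &lb, &ub, &st))
//     for (int32_t i = lb; i <= ub; ++i) // Clang uses i <= ub
//       body(i);
//   __kmpc_dispatch_deinit(loc, tid);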

static DynamicScheduleTracker **SHARED(ThreadDST);

// Create a new DST, link the current one, and define the new as current.
static DynamicScheduleTracker *pushDST() {
  int32_t ThreadIndex = mapping::getThreadIdInBlock();
  // Each block will allocate an array of pointers to DST structs. The array is
  // equal in length to the number of threads in that block.
  if (!ThreadDST) {
    // Allocate global memory array of pointers to DST structs:
    if (mapping::isMainThreadInGenericMode() || ThreadIndex == 0)
      ThreadDST = static_cast<DynamicScheduleTracker **>(
          memory::allocGlobal(mapping::getNumberOfThreadsInBlock() *
                                  sizeof(DynamicScheduleTracker *),
                              "new ThreadDST array"));
    synchronize::threads(atomic::seq_cst);

    // Initialize the array pointers:
    ThreadDST[ThreadIndex] = nullptr;
  }

  // Create a DST struct for the current thread:
  DynamicScheduleTracker *NewDST = static_cast<DynamicScheduleTracker *>(
      memory::allocGlobal(sizeof(DynamicScheduleTracker), "new DST"));
  *NewDST = DynamicScheduleTracker({0});

  // Add the new DST struct to the array of DST structs:
  NewDST->NextDST = ThreadDST[ThreadIndex];
  ThreadDST[ThreadIndex] = NewDST;
  return NewDST;
}

// Return the current DST.
static DynamicScheduleTracker *peekDST() {
  return ThreadDST[mapping::getThreadIdInBlock()];
}

// Pop the current DST and restore the last one.
static void popDST() {
  int32_t ThreadIndex = mapping::getThreadIdInBlock();
  DynamicScheduleTracker *CurrentDST = ThreadDST[ThreadIndex];
  DynamicScheduleTracker *OldDST = CurrentDST->NextDST;
  memory::freeGlobal(CurrentDST, "remove DST");
  ThreadDST[ThreadIndex] = OldDST;

  // Check if we need to deallocate the global array. Ensure all threads
  // in the block have finished deallocating the individual DSTs.
  synchronize::threads(atomic::seq_cst);
  if (!ThreadDST[ThreadIndex] && !ThreadIndex) {
    memory::freeGlobal(ThreadDST, "remove ThreadDST array");
    ThreadDST = nullptr;
  }
  synchronize::threads(atomic::seq_cst);
}

void workshare::init(bool IsSPMD) {
  if (mapping::isInitialThreadInLevel0(IsSPMD))
    ThreadDST = nullptr;
}

extern "C" {

// init
void __kmpc_dispatch_init_4(IdentTy *loc, int32_t tid, int32_t schedule,
                            int32_t lb, int32_t ub, int32_t st, int32_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_4u(IdentTy *loc, int32_t tid, int32_t schedule,
                             uint32_t lb, uint32_t ub, int32_t st,
                             int32_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_8(IdentTy *loc, int32_t tid, int32_t schedule,
                            int64_t lb, int64_t ub, int64_t st, int64_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_8u(IdentTy *loc, int32_t tid, int32_t schedule,
                             uint64_t lb, uint64_t ub, int64_t st,
                             int64_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

// next
int __kmpc_dispatch_next_4(IdentTy *loc, int32_t tid, int32_t *p_last,
                           int32_t *p_lb, int32_t *p_ub, int32_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_4u(IdentTy *loc, int32_t tid, int32_t *p_last,
                            uint32_t *p_lb, uint32_t *p_ub, int32_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_8(IdentTy *loc, int32_t tid, int32_t *p_last,
                           int64_t *p_lb, int64_t *p_ub, int64_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_8u(IdentTy *loc, int32_t tid, int32_t *p_last,
                            uint64_t *p_lb, uint64_t *p_ub, int64_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

// fini
void __kmpc_dispatch_fini_4(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_fini();
}

void __kmpc_dispatch_fini_4u(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_fini();
}

void __kmpc_dispatch_fini_8(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_fini();
}

void __kmpc_dispatch_fini_8u(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_fini();
}

// deinit
void __kmpc_dispatch_deinit(IdentTy *loc, int32_t tid) { popDST(); }

////////////////////////////////////////////////////////////////////////////////
// KMP interface implementation (static loops)
////////////////////////////////////////////////////////////////////////////////

void __kmpc_for_static_init_4(IdentTy *loc, int32_t global_tid,
                              int32_t schedtype, int32_t *plastiter,
                              int32_t *plower, int32_t *pupper,
                              int32_t *pstride, int32_t incr, int32_t chunk) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_4u(IdentTy *loc, int32_t global_tid,
                               int32_t schedtype, int32_t *plastiter,
                               uint32_t *plower, uint32_t *pupper,
                               int32_t *pstride, int32_t incr, int32_t chunk) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_8(IdentTy *loc, int32_t global_tid,
                              int32_t schedtype, int32_t *plastiter,
                              int64_t *plower, int64_t *pupper,
                              int64_t *pstride, int64_t incr, int64_t chunk) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_8u(IdentTy *loc, int32_t global_tid,
                               int32_t schedtype, int32_t *plastiter,
                               uint64_t *plower, uint64_t *pupper,
                               int64_t *pstride, int64_t incr, int64_t chunk) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_4(IdentTy *loc, int32_t global_tid,
                                     int32_t schedtype, int32_t *plastiter,
                                     int32_t *plower, int32_t *pupper,
                                     int32_t *pstride, int32_t incr,
                                     int32_t chunk) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_4u(IdentTy *loc, int32_t global_tid,
                                      int32_t schedtype, int32_t *plastiter,
                                      uint32_t *plower, uint32_t *pupper,
                                      int32_t *pstride, int32_t incr,
                                      int32_t chunk) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_8(IdentTy *loc, int32_t global_tid,
                                     int32_t schedtype, int32_t *plastiter,
                                     int64_t *plower, int64_t *pupper,
                                     int64_t *pstride, int64_t incr,
                                     int64_t chunk) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_8u(IdentTy *loc, int32_t global_tid,
                                      int32_t schedtype, int32_t *plastiter,
                                      uint64_t *plower, uint64_t *pupper,
                                      int64_t *pstride, int64_t incr,
                                      int64_t chunk) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_fini(IdentTy *loc, int32_t global_tid) {}

void __kmpc_distribute_static_fini(IdentTy *loc, int32_t global_tid) {}
}

namespace ompx {

/// Helper class to hide the generic loop nest and provide the template argument
/// throughout.
template <typename Ty> class StaticLoopChunker {

  /// Generic loop nest that handles block and/or thread distribution in the
  /// absence of user specified chunk sizes. This implicitly picks a block chunk
  /// size equal to the number of threads in the block and a thread chunk size
  /// equal to one. In contrast to the chunked version we can get away with a
  /// single loop in this case.
  static void NormalizedLoopNestNoChunk(void (*LoopBody)(Ty, void *), void *Arg,
                                        Ty NumBlocks, Ty BId, Ty NumThreads,
                                        Ty TId, Ty NumIters,
                                        bool OneIterationPerThread) {
    Ty KernelIteration = NumBlocks * NumThreads;

    // Start index in the normalized space.
    Ty IV = BId * NumThreads + TId;
    ASSERT(IV >= 0, "Bad index");

    // Cover the entire iteration space; assumptions in the caller might allow
    // us to simplify this loop to a conditional.
    if (IV < NumIters) {
      do {

        // Execute the loop body.
        LoopBody(IV, Arg);

        // Every thread executed one block and thread chunk now.
        IV += KernelIteration;

        if (OneIterationPerThread)
          return;

      } while (IV < NumIters);
    }
  }
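
  // Editor's worked example (hypothetical values): with NumBlocks = 2 and
  // NumThreads = 3 the kernel-wide stride is 6. The thread with BId = 1 and
  // TId = 2 starts at IV = 5 and, for NumIters = 14, executes iterations 5
  // and 11; together the six threads cover 0..13 exactly once.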

  /// Generic loop nest that handles block and/or thread distribution in the
  /// presence of user specified chunk sizes (for at least one of them).
  static void NormalizedLoopNestChunked(void (*LoopBody)(Ty, void *), void *Arg,
                                        Ty BlockChunk, Ty NumBlocks, Ty BId,
                                        Ty ThreadChunk, Ty NumThreads, Ty TId,
                                        Ty NumIters,
                                        bool OneIterationPerThread) {
    Ty KernelIteration = NumBlocks * BlockChunk;

    // Start index in the chunked space.
    Ty IV = BId * BlockChunk + TId;
    ASSERT(IV >= 0, "Bad index");

    // Cover the entire iteration space; assumptions in the caller might allow
    // us to simplify this loop to a conditional.
    do {

      Ty BlockChunkLeft =
          BlockChunk >= TId * ThreadChunk ? BlockChunk - TId * ThreadChunk : 0;
      Ty ThreadChunkLeft =
          ThreadChunk <= BlockChunkLeft ? ThreadChunk : BlockChunkLeft;

      while (ThreadChunkLeft--) {

        // Given the blocking it's hard to keep track of what to execute.
        if (IV >= NumIters)
          return;

        // Execute the loop body.
        LoopBody(IV, Arg);

        if (OneIterationPerThread)
          return;

        ++IV;
      }

      IV += KernelIteration;

    } while (IV < NumIters);
  }

public:
  /// Worksharing `for`-loop.
  static void For(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
                  Ty NumIters, Ty NumThreads, Ty ThreadChunk) {
    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(ThreadChunk >= 0, "Bad thread chunk");

    // All threads need to participate, but we don't know if we are in a
    // parallel at all or if the user might have used a `num_threads` clause
    // on the parallel and reduced the number compared to the block size.
    // Since nested parallels are possible too we need to get the thread id
    // from the `omp` getter and not the mapping directly.
    Ty TId = omp_get_thread_num();

    // There are no blocks involved here.
    Ty BlockChunk = 0;
    Ty NumBlocks = 1;
    Ty BId = 0;

    // If the thread chunk is not specified we pick a default now.
    if (ThreadChunk == 0)
      ThreadChunk = 1;

    // If we know we have more threads than iterations we can indicate that to
    // avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeThreadsOversubscription()) {
      ASSERT(NumThreads >= NumIters, "Broken assumption");
      OneIterationPerThread = true;
    }

    if (ThreadChunk != 1)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);
  }

  /// Worksharing `distribute`-loop.
  static void Distribute(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
                         Ty NumIters, Ty BlockChunk) {
    ASSERT(icv::Level == 0, "Bad distribute");
    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");

    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(BlockChunk >= 0, "Bad block chunk");

    // There are no threads involved here.
    Ty ThreadChunk = 0;
    Ty NumThreads = 1;
    Ty TId = 0;
    ASSERT(TId == mapping::getThreadIdInBlock(), "Bad thread id");

    // All teams need to participate.
    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
    Ty BId = mapping::getBlockIdInKernel();

    // If the block chunk is not specified we pick a default now.
    if (BlockChunk == 0)
      BlockChunk = NumThreads;

    // If we know we have more blocks than iterations we can indicate that to
    // avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeTeamsOversubscription()) {
      ASSERT(NumBlocks >= NumIters, "Broken assumption");
      OneIterationPerThread = true;
    }

    if (BlockChunk != NumThreads)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);

    ASSERT(icv::Level == 0, "Bad distribute");
    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");
  }

  /// Worksharing `distribute parallel for`-loop.
  static void DistributeFor(IdentTy *Loc, void (*LoopBody)(Ty, void *),
                            void *Arg, Ty NumIters, Ty NumThreads,
                            Ty BlockChunk, Ty ThreadChunk) {
    ASSERT(icv::Level == 1, "Bad distribute");
    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");

    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(BlockChunk >= 0, "Bad block chunk");
    ASSERT(ThreadChunk >= 0, "Bad thread chunk");

    // All threads need to participate, but the user might have used a
    // `num_threads` clause on the parallel and reduced the number compared to
    // the block size.
    Ty TId = mapping::getThreadIdInBlock();

    // All teams need to participate.
    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
    Ty BId = mapping::getBlockIdInKernel();

    // If the block chunk is not specified we pick a default now.
    if (BlockChunk == 0)
      BlockChunk = NumThreads;

    // If the thread chunk is not specified we pick a default now.
    if (ThreadChunk == 0)
      ThreadChunk = 1;

    // If we know we have more threads (across all blocks) than iterations we
    // can indicate that to avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeTeamsOversubscription() &
        config::getAssumeThreadsOversubscription()) {
      OneIterationPerThread = true;
      ASSERT(NumBlocks * NumThreads >= NumIters, "Broken assumption");
    }

    if (BlockChunk != NumThreads || ThreadChunk != 1)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);

    ASSERT(icv::Level == 1, "Bad distribute");
    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
  }
};

} // namespace ompx

#define OMP_LOOP_ENTRY(BW, TY)                                                 \
  [[gnu::flatten, clang::always_inline]] void                                  \
      __kmpc_distribute_for_static_loop##BW(                                   \
          IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,       \
          TY num_threads, TY block_chunk, TY thread_chunk) {                   \
    ompx::StaticLoopChunker<TY>::DistributeFor(                                \
        loc, fn, arg, num_iters + 1, num_threads, block_chunk, thread_chunk);  \
  }                                                                            \
  [[gnu::flatten, clang::always_inline]] void                                  \
      __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *),  \
                                        void *arg, TY num_iters,               \
                                        TY block_chunk) {                      \
    ompx::StaticLoopChunker<TY>::Distribute(loc, fn, arg, num_iters + 1,       \
                                            block_chunk);                      \
  }                                                                            \
  [[gnu::flatten, clang::always_inline]] void __kmpc_for_static_loop##BW(      \
      IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,           \
      TY num_threads, TY thread_chunk) {                                       \
    ompx::StaticLoopChunker<TY>::For(loc, fn, arg, num_iters + 1, num_threads, \
                                     thread_chunk);                            \
  }
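
// Editor's note (illustrative): for the `_4` instantiation below, the macro
// emits, among others, an entry point roughly equivalent to
//
//   void __kmpc_for_static_loop_4(IdentTy *loc, void (*fn)(int32_t, void *),
//                                 void *arg, int32_t num_iters,
//                                 int32_t num_threads, int32_t thread_chunk) {
//     ompx::StaticLoopChunker<int32_t>::For(loc, fn, arg, num_iters + 1,
//                                           num_threads, thread_chunk);
//   }
//
// plus the matching __kmpc_distribute_static_loop_4 and
// __kmpc_distribute_for_static_loop_4 symbols.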

extern "C" {
OMP_LOOP_ENTRY(_4, int32_t)
OMP_LOOP_ENTRY(_4u, uint32_t)
OMP_LOOP_ENTRY(_8, int64_t)
OMP_LOOP_ENTRY(_8u, uint64_t)
}

#pragma omp end declare target