1 //===------- Mapping.cpp - OpenMP device runtime mapping helpers -- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 //===----------------------------------------------------------------------===//
13 #include "DeviceTypes.h"
14 #include "DeviceUtils.h"
15 #include "Interface.h"
18 #pragma omp begin declare target device_type(nohost)
20 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
27 // Forward declarations defined to be defined for AMDGCN and NVPTX.
28 LaneMaskTy
activemask();
29 LaneMaskTy
lanemaskLT();
30 LaneMaskTy
lanemaskGT();
31 uint32_t getThreadIdInWarp();
32 uint32_t getThreadIdInBlock(int32_t Dim
);
33 uint32_t getNumberOfThreadsInBlock(int32_t Dim
);
34 uint32_t getNumberOfThreadsInKernel();
35 uint32_t getBlockIdInKernel(int32_t Dim
);
36 uint32_t getNumberOfBlocksInKernel(int32_t Dim
);
37 uint32_t getWarpIdInBlock();
38 uint32_t getNumberOfWarpsInBlock();
39 uint32_t getWarpSize();
41 /// AMDGCN Implementation
44 #pragma omp begin declare variant match(device = {arch(amdgcn)})
46 uint32_t getWarpSize() { return __builtin_amdgcn_wavefrontsize(); }
48 uint32_t getNumberOfThreadsInBlock(int32_t Dim
) {
51 return __builtin_amdgcn_workgroup_size_x();
53 return __builtin_amdgcn_workgroup_size_y();
55 return __builtin_amdgcn_workgroup_size_z();
57 UNREACHABLE("Dim outside range!");
60 LaneMaskTy
activemask() { return __builtin_amdgcn_read_exec(); }
62 LaneMaskTy
lanemaskLT() {
63 uint32_t Lane
= mapping::getThreadIdInWarp();
64 int64_t Ballot
= mapping::activemask();
65 uint64_t Mask
= ((uint64_t)1 << Lane
) - (uint64_t)1;
69 LaneMaskTy
lanemaskGT() {
70 uint32_t Lane
= mapping::getThreadIdInWarp();
71 if (Lane
== (mapping::getWarpSize() - 1))
73 int64_t Ballot
= mapping::activemask();
74 uint64_t Mask
= (~((uint64_t)0)) << (Lane
+ 1);
78 uint32_t getThreadIdInWarp() {
79 return __builtin_amdgcn_mbcnt_hi(~0u, __builtin_amdgcn_mbcnt_lo(~0u, 0u));
82 uint32_t getThreadIdInBlock(int32_t Dim
) {
85 return __builtin_amdgcn_workitem_id_x();
87 return __builtin_amdgcn_workitem_id_y();
89 return __builtin_amdgcn_workitem_id_z();
91 UNREACHABLE("Dim outside range!");
94 uint32_t getNumberOfThreadsInKernel() {
95 return __builtin_amdgcn_grid_size_x() * __builtin_amdgcn_grid_size_y() *
96 __builtin_amdgcn_grid_size_z();
99 uint32_t getBlockIdInKernel(int32_t Dim
) {
102 return __builtin_amdgcn_workgroup_id_x();
104 return __builtin_amdgcn_workgroup_id_y();
106 return __builtin_amdgcn_workgroup_id_z();
108 UNREACHABLE("Dim outside range!");
111 uint32_t getNumberOfBlocksInKernel(int32_t Dim
) {
114 return __builtin_amdgcn_grid_size_x() / __builtin_amdgcn_workgroup_size_x();
116 return __builtin_amdgcn_grid_size_y() / __builtin_amdgcn_workgroup_size_y();
118 return __builtin_amdgcn_grid_size_z() / __builtin_amdgcn_workgroup_size_z();
120 UNREACHABLE("Dim outside range!");
123 uint32_t getWarpIdInBlock() {
124 return impl::getThreadIdInBlock(mapping::DIM_X
) / mapping::getWarpSize();
127 uint32_t getNumberOfWarpsInBlock() {
128 return mapping::getNumberOfThreadsInBlock() / mapping::getWarpSize();
131 #pragma omp end declare variant
134 /// NVPTX Implementation
137 #pragma omp begin declare variant match( \
138 device = {arch(nvptx, nvptx64)}, \
139 implementation = {extension(match_any)})
141 uint32_t getNumberOfThreadsInBlock(int32_t Dim
) {
144 return __nvvm_read_ptx_sreg_ntid_x();
146 return __nvvm_read_ptx_sreg_ntid_y();
148 return __nvvm_read_ptx_sreg_ntid_z();
150 UNREACHABLE("Dim outside range!");
153 uint32_t getWarpSize() { return __nvvm_read_ptx_sreg_warpsize(); }
155 LaneMaskTy
activemask() { return __nvvm_activemask(); }
157 LaneMaskTy
lanemaskLT() { return __nvvm_read_ptx_sreg_lanemask_lt(); }
159 LaneMaskTy
lanemaskGT() { return __nvvm_read_ptx_sreg_lanemask_gt(); }
161 uint32_t getThreadIdInBlock(int32_t Dim
) {
164 return __nvvm_read_ptx_sreg_tid_x();
166 return __nvvm_read_ptx_sreg_tid_y();
168 return __nvvm_read_ptx_sreg_tid_z();
170 UNREACHABLE("Dim outside range!");
173 uint32_t getThreadIdInWarp() { return __nvvm_read_ptx_sreg_laneid(); }
175 uint32_t getBlockIdInKernel(int32_t Dim
) {
178 return __nvvm_read_ptx_sreg_ctaid_x();
180 return __nvvm_read_ptx_sreg_ctaid_y();
182 return __nvvm_read_ptx_sreg_ctaid_z();
184 UNREACHABLE("Dim outside range!");
187 uint32_t getNumberOfBlocksInKernel(int32_t Dim
) {
190 return __nvvm_read_ptx_sreg_nctaid_x();
192 return __nvvm_read_ptx_sreg_nctaid_y();
194 return __nvvm_read_ptx_sreg_nctaid_z();
196 UNREACHABLE("Dim outside range!");
199 uint32_t getNumberOfThreadsInKernel() {
200 return impl::getNumberOfThreadsInBlock(0) *
201 impl::getNumberOfBlocksInKernel(0) *
202 impl::getNumberOfThreadsInBlock(1) *
203 impl::getNumberOfBlocksInKernel(1) *
204 impl::getNumberOfThreadsInBlock(2) *
205 impl::getNumberOfBlocksInKernel(2);
208 uint32_t getWarpIdInBlock() {
209 return impl::getThreadIdInBlock(mapping::DIM_X
) / mapping::getWarpSize();
212 uint32_t getNumberOfWarpsInBlock() {
213 return (mapping::getNumberOfThreadsInBlock() + mapping::getWarpSize() - 1) /
214 mapping::getWarpSize();
217 #pragma omp end declare variant
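/// Be deliberate about choosing between `mapping::` and `impl::` below. The
/// `mapping::` entry points may add checks (e.g. ASSERTs) on top of the raw
/// `impl::` queries, so calling the wrong layer can repeat assumptions or pull
/// in irrelevant ones.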
227 static bool isInLastWarp() {
228 uint32_t MainTId
= (mapping::getNumberOfThreadsInBlock() - 1) &
229 ~(mapping::getWarpSize() - 1);
230 return mapping::getThreadIdInBlock() == MainTId
;
233 bool mapping::isMainThreadInGenericMode(bool IsSPMD
) {
234 if (IsSPMD
|| icv::Level
)
237 // Check if this is the last warp in the block.
238 return isInLastWarp();
241 bool mapping::isMainThreadInGenericMode() {
242 return mapping::isMainThreadInGenericMode(mapping::isSPMDMode());
245 bool mapping::isInitialThreadInLevel0(bool IsSPMD
) {
247 return mapping::getThreadIdInBlock() == 0;
248 return isInLastWarp();
251 bool mapping::isLeaderInWarp() {
252 __kmpc_impl_lanemask_t Active
= mapping::activemask();
253 __kmpc_impl_lanemask_t LaneMaskLT
= mapping::lanemaskLT();
254 return utils::popc(Active
& LaneMaskLT
) == 0;
257 LaneMaskTy
mapping::activemask() { return impl::activemask(); }
259 LaneMaskTy
mapping::lanemaskLT() { return impl::lanemaskLT(); }
261 LaneMaskTy
mapping::lanemaskGT() { return impl::lanemaskGT(); }
263 uint32_t mapping::getThreadIdInWarp() {
264 uint32_t ThreadIdInWarp
= impl::getThreadIdInWarp();
265 ASSERT(ThreadIdInWarp
< impl::getWarpSize(), nullptr);
266 return ThreadIdInWarp
;
269 uint32_t mapping::getThreadIdInBlock(int32_t Dim
) {
270 uint32_t ThreadIdInBlock
= impl::getThreadIdInBlock(Dim
);
271 return ThreadIdInBlock
;
274 uint32_t mapping::getWarpSize() { return impl::getWarpSize(); }
276 uint32_t mapping::getMaxTeamThreads(bool IsSPMD
) {
277 uint32_t BlockSize
= mapping::getNumberOfThreadsInBlock();
278 // If we are in SPMD mode, remove one warp.
279 return BlockSize
- (!IsSPMD
* impl::getWarpSize());
281 uint32_t mapping::getMaxTeamThreads() {
282 return mapping::getMaxTeamThreads(mapping::isSPMDMode());
285 uint32_t mapping::getNumberOfThreadsInBlock(int32_t Dim
) {
286 return impl::getNumberOfThreadsInBlock(Dim
);
289 uint32_t mapping::getNumberOfThreadsInKernel() {
290 return impl::getNumberOfThreadsInKernel();
293 uint32_t mapping::getWarpIdInBlock() {
294 uint32_t WarpID
= impl::getWarpIdInBlock();
295 ASSERT(WarpID
< impl::getNumberOfWarpsInBlock(), nullptr);
299 uint32_t mapping::getBlockIdInKernel(int32_t Dim
) {
300 uint32_t BlockId
= impl::getBlockIdInKernel(Dim
);
301 ASSERT(BlockId
< impl::getNumberOfBlocksInKernel(Dim
), nullptr);
305 uint32_t mapping::getNumberOfWarpsInBlock() {
306 uint32_t NumberOfWarpsInBlocks
= impl::getNumberOfWarpsInBlock();
307 ASSERT(impl::getWarpIdInBlock() < NumberOfWarpsInBlocks
, nullptr);
308 return NumberOfWarpsInBlocks
;
311 uint32_t mapping::getNumberOfBlocksInKernel(int32_t Dim
) {
312 uint32_t NumberOfBlocks
= impl::getNumberOfBlocksInKernel(Dim
);
313 ASSERT(impl::getBlockIdInKernel(Dim
) < NumberOfBlocks
, nullptr);
314 return NumberOfBlocks
;
317 uint32_t mapping::getNumberOfProcessorElements() {
318 return static_cast<uint32_t>(config::getHardwareParallelism());
// TODO: This is a workaround for initialization coming from kernels outside of
// the TU. We will need to solve this more correctly in the future.
//
// Flag recording whether the current kernel runs in SPMD mode; it is read by
// mapping::isSPMDMode (presumably written by mapping::init — confirm). It is
// declared weak so that a definition provided by another translation unit can
// take precedence, which appears to be the point of the workaround above.
[[gnu::weak]] int SHARED(IsSPMDMode);
331 void mapping::init(bool IsSPMD
) {
332 if (mapping::isInitialThreadInLevel0(IsSPMD
))
336 bool mapping::isSPMDMode() { return IsSPMDMode
; }
338 bool mapping::isGenericMode() { return !isSPMDMode(); }
342 [[gnu::noinline
]] uint32_t __kmpc_get_hardware_thread_id_in_block() {
343 return mapping::getThreadIdInBlock();
346 [[gnu::noinline
]] uint32_t __kmpc_get_hardware_num_threads_in_block() {
347 return impl::getNumberOfThreadsInBlock(mapping::DIM_X
);
350 [[gnu::noinline
]] uint32_t __kmpc_get_warp_size() {
351 return impl::getWarpSize();
355 #define _TGT_KERNEL_LANGUAGE(NAME, MAPPER_NAME) \
356 extern "C" int ompx_##NAME(int Dim) { return mapping::MAPPER_NAME(Dim); }
358 _TGT_KERNEL_LANGUAGE(thread_id
, getThreadIdInBlock
)
359 _TGT_KERNEL_LANGUAGE(block_id
, getBlockIdInKernel
)
360 _TGT_KERNEL_LANGUAGE(block_dim
, getNumberOfThreadsInBlock
)
361 _TGT_KERNEL_LANGUAGE(grid_dim
, getNumberOfBlocksInKernel
)
364 uint64_t ompx_ballot_sync(uint64_t mask
, int pred
) {
365 return utils::ballotSync(mask
, pred
);
368 int ompx_shfl_down_sync_i(uint64_t mask
, int var
, unsigned delta
, int width
) {
369 return utils::shuffleDown(mask
, var
, delta
, width
);
372 float ompx_shfl_down_sync_f(uint64_t mask
, float var
, unsigned delta
,
374 return utils::convertViaPun
<float>(utils::shuffleDown(
375 mask
, utils::convertViaPun
<int32_t>(var
), delta
, width
));
378 long ompx_shfl_down_sync_l(uint64_t mask
, long var
, unsigned delta
, int width
) {
379 return utils::shuffleDown(mask
, var
, delta
, width
);
382 double ompx_shfl_down_sync_d(uint64_t mask
, double var
, unsigned delta
,
384 return utils::convertViaPun
<double>(utils::shuffleDown(
385 mask
, utils::convertViaPun
<int64_t>(var
), delta
, width
));
389 #pragma omp end declare target