1 //===------- Utils.cpp - OpenMP device runtime utility functions -- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 //===----------------------------------------------------------------------===//
15 #include "Interface.h"
18 #pragma omp begin declare target device_type(nohost)
/// Generic fallback for the shared-memory query. This version does not look
/// at \p Ptr at all; the target-specific variants further down in this file
/// provide real address-space checks.
///
/// \param Ptr Pointer to classify (unused here).
/// \returns Always false.
bool isSharedMemPtr(const void *Ptr) {
  (void)Ptr; // Deliberately ignored by the generic version.
  return false;
}
/// Split the 64-bit value \p Val into its two 32-bit halves.
///
/// \param Val      Value to split.
/// \param LowBits  Receives bits [31:0] of \p Val. Dereferenced
///                 unconditionally, so it must point to writable storage.
/// \param HighBits Receives bits [63:32] of \p Val. Dereferenced
///                 unconditionally, so it must point to writable storage.
void Unpack(uint64_t Val, uint32_t *LowBits, uint32_t *HighBits) {
  // Truncating and shifting the fixed-width uint64_t produces both halves
  // directly, so the result does not depend on how wide `unsigned long`
  // (and therefore a `UL` literal) happens to be on the target.
  *LowBits = static_cast<uint32_t>(Val);
  *HighBits = static_cast<uint32_t>(Val >> 32);
}
/// Combine two 32-bit halves into a single 64-bit value.
///
/// \param LowBits  Becomes bits [31:0] of the result.
/// \param HighBits Becomes bits [63:32] of the result.
/// \returns The 64-bit value formed from both halves.
uint64_t Pack(uint32_t LowBits, uint32_t HighBits) {
  // Widen before shifting so the upper half is not lost.
  const uint64_t Upper = static_cast<uint64_t>(HighBits) << 32;
  const uint64_t Lower = static_cast<uint64_t>(LowBits);
  return Upper | Lower;
}
/// Forward declarations of the warp-level primitives. Their bodies are
/// supplied by the target-specific `declare variant` regions later in this
/// file.
int32_t shuffle(uint64_t Mask, int32_t Var, int32_t SrcLane);
int32_t shuffleDown(uint64_t Mask, int32_t Var, uint32_t LaneDelta,
                    int32_t Width);

uint64_t ballotSync(uint64_t Mask, int32_t Pred);
42 /// AMDGCN Implementation
45 #pragma omp begin declare variant match(device = {arch(amdgcn)})
47 int32_t shuffle(uint64_t Mask
, int32_t Var
, int32_t SrcLane
) {
48 int Width
= mapping::getWarpSize();
49 int Self
= mapping::getThreadIdInWarp();
50 int Index
= SrcLane
+ (Self
& ~(Width
- 1));
51 return __builtin_amdgcn_ds_bpermute(Index
<< 2, Var
);
54 int32_t shuffleDown(uint64_t Mask
, int32_t Var
, uint32_t LaneDelta
,
56 int Self
= mapping::getThreadIdInWarp();
57 int Index
= Self
+ LaneDelta
;
58 Index
= (int)(LaneDelta
+ (Self
& (Width
- 1))) >= Width
? Self
: Index
;
59 return __builtin_amdgcn_ds_bpermute(Index
<< 2, Var
);
62 uint64_t ballotSync(uint64_t Mask
, int32_t Pred
) {
63 return Mask
& __builtin_amdgcn_ballot_w64(Pred
);
66 bool isSharedMemPtr(const void *Ptr
) {
67 return __builtin_amdgcn_is_shared(
68 (const __attribute__((address_space(0))) void *)Ptr
);
70 #pragma omp end declare variant
73 /// NVPTX Implementation
76 #pragma omp begin declare variant match( \
77 device = {arch(nvptx, nvptx64)}, \
78 implementation = {extension(match_any)})
80 int32_t shuffle(uint64_t Mask
, int32_t Var
, int32_t SrcLane
) {
81 return __nvvm_shfl_sync_idx_i32(Mask
, Var
, SrcLane
, 0x1f);
84 int32_t shuffleDown(uint64_t Mask
, int32_t Var
, uint32_t Delta
, int32_t Width
) {
85 int32_t T
= ((mapping::getWarpSize() - Width
) << 8) | 0x1f;
86 return __nvvm_shfl_sync_down_i32(Mask
, Var
, Delta
, T
);
89 uint64_t ballotSync(uint64_t Mask
, int32_t Pred
) {
90 return __nvvm_vote_ballot_sync(static_cast<uint32_t>(Mask
), Pred
);
93 bool isSharedMemPtr(const void *Ptr
) { return __nvvm_isspacep_shared(Ptr
); }
95 #pragma omp end declare variant
99 uint64_t utils::pack(uint32_t LowBits
, uint32_t HighBits
) {
100 return impl::Pack(LowBits
, HighBits
);
103 void utils::unpack(uint64_t Val
, uint32_t &LowBits
, uint32_t &HighBits
) {
104 impl::Unpack(Val
, &LowBits
, &HighBits
);
107 int32_t utils::shuffle(uint64_t Mask
, int32_t Var
, int32_t SrcLane
) {
108 return impl::shuffle(Mask
, Var
, SrcLane
);
111 int32_t utils::shuffleDown(uint64_t Mask
, int32_t Var
, uint32_t Delta
,
113 return impl::shuffleDown(Mask
, Var
, Delta
, Width
);
116 int64_t utils::shuffleDown(uint64_t Mask
, int64_t Var
, uint32_t Delta
,
119 utils::unpack(Var
, Lo
, Hi
);
120 Hi
= impl::shuffleDown(Mask
, Hi
, Delta
, Width
);
121 Lo
= impl::shuffleDown(Mask
, Lo
, Delta
, Width
);
122 return utils::pack(Lo
, Hi
);
125 uint64_t utils::ballotSync(uint64_t Mask
, int32_t Pred
) {
126 return impl::ballotSync(Mask
, Pred
);
129 bool utils::isSharedMemPtr(void *Ptr
) { return impl::isSharedMemPtr(Ptr
); }
132 int32_t __kmpc_shuffle_int32(int32_t Val
, int16_t Delta
, int16_t SrcLane
) {
133 return impl::shuffleDown(lanes::All
, Val
, Delta
, SrcLane
);
136 int64_t __kmpc_shuffle_int64(int64_t Val
, int16_t Delta
, int16_t Width
) {
137 return utils::shuffleDown(lanes::All
, Val
, Delta
, Width
);
141 #pragma omp end declare target