1 //===------------------ Client.h - Client Implementation ------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // gRPC Client for the remote plugin.
11 //===----------------------------------------------------------------------===//
13 #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H
14 #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H
17 #include "omptarget.h"
18 #include <google/protobuf/arena.h>
19 #include <grpcpp/grpcpp.h>
20 #include <grpcpp/security/credentials.h>
21 #include <grpcpp/support/channel_arguments.h>
27 using openmp::libomptarget::remote::RemoteOffload
;
28 using namespace RemoteOffloading
;
30 using namespace google
;
32 class RemoteOffloadClient
{
35 const uint64_t MaxSize
;
36 const int64_t BlockSize
;
38 std::unique_ptr
<RemoteOffload::Stub
> Stub
;
39 std::unique_ptr
<protobuf::Arena
> Arena
;
41 std::unique_ptr
<std::mutex
> ArenaAllocatorLock
;
43 std::map
<int32_t, std::unordered_map
<void *, void *>> RemoteEntries
;
44 std::map
<int32_t, std::unique_ptr
<__tgt_target_table
>> DevicesToTables
;
/// Generic helper declared here and defined elsewhere; its return type is
/// deduced from that definition.
///
/// NOTE(review): judging only by the parameter names, Preprocessor prepares
/// the request, Postprocessor consumes the reply, ErrorValue is what gets
/// returned on failure and CanTimeOut (default true) controls whether a
/// deadline applies — confirm against the definition.
template <class Fn1, class Fn2, class TReturn>
auto remoteCall(Fn1 Preprocessor, Fn2 Postprocessor, TReturn ErrorValue,
                bool CanTimeOut = true);
51 RemoteOffloadClient(std::shared_ptr
<Channel
> Channel
, int Timeout
,
52 uint64_t MaxSize
, int64_t BlockSize
)
53 : Timeout(Timeout
), MaxSize(MaxSize
), BlockSize(BlockSize
),
54 Stub(RemoteOffload::NewStub(Channel
)) {
55 DebugLevel
= getDebugLevel();
56 Arena
= std::make_unique
<protobuf::Arena
>();
57 ArenaAllocatorLock
= std::make_unique
<std::mutex
>();
60 RemoteOffloadClient(RemoteOffloadClient
&&C
) = default;
62 ~RemoteOffloadClient() {
63 for (auto &TableIt
: DevicesToTables
)
64 freeTargetTable(TableIt
.second
.get());
67 int32_t shutdown(void);
69 int32_t registerLib(__tgt_bin_desc
*Desc
);
70 int32_t unregisterLib(__tgt_bin_desc
*Desc
);
72 int32_t isValidBinary(__tgt_device_image
*Image
);
73 int32_t getNumberOfDevices();
75 int32_t initDevice(int32_t DeviceId
);
76 int32_t initRequires(int64_t RequiresFlags
);
78 __tgt_target_table
*loadBinary(int32_t DeviceId
, __tgt_device_image
*Image
);
80 void *dataAlloc(int32_t DeviceId
, int64_t Size
, void *HstPtr
);
81 int32_t dataDelete(int32_t DeviceId
, void *TgtPtr
);
/// Data transfer entry points, declared here and defined elsewhere.
///
/// NOTE(review): in the text this was rebuilt from, both declarations stopped
/// after the pointer arguments on a dangling comma. The trailing
/// `int64_t Size` is restored by analogy with dataExchange and with the
/// libomptarget plugin data_submit/data_retrieve entry points — confirm
/// against the definitions before relying on it.
int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr, int64_t Size);
int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
                     int64_t Size);
/// Member functions declared here and defined elsewhere.
/// NOTE(review): return-value conventions are not visible in this header;
/// confirm against the definitions.

/// Both operate on a source/destination device pair; dataExchange also takes
/// the two pointers and a byte count \p Size.
int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId);
int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId,
                     void *DstPtr, int64_t Size);

/// NOTE(review): presumably \p TgtArgs and \p TgtOffsets are parallel arrays
/// of \p ArgNum elements — confirm against the definition.
int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs,
                        ptrdiff_t *TgtOffsets, int32_t ArgNum);

/// Same arguments as runTargetRegion plus \p TeamNum, \p ThreadLimit and
/// \p LoopTripCount.
int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr,
                            void **TgtArgs, ptrdiff_t *TgtOffsets,
                            int32_t ArgNum, int32_t TeamNum,
                            int32_t ThreadLimit, uint64_t LoopTripCount);
100 class RemoteClientManager
{
102 std::vector
<RemoteOffloadClient
> Clients
;
103 std::vector
<int> Devices
;
105 std::pair
<int32_t, int32_t> mapDeviceId(int32_t DeviceId
);
109 RemoteClientManager() {
110 ClientManagerConfigTy Config
;
112 grpc::ChannelArguments ChArgs
;
113 ChArgs
.SetMaxReceiveMessageSize(-1);
114 DebugLevel
= getDebugLevel();
115 for (auto Address
: Config
.ServerAddresses
) {
116 Clients
.push_back(RemoteOffloadClient(
117 grpc::CreateChannel(Address
, grpc::InsecureChannelCredentials()),
118 Config
.Timeout
, Config
.MaxSize
, Config
.BlockSize
));
122 int32_t shutdown(void);
124 int32_t registerLib(__tgt_bin_desc
*Desc
);
125 int32_t unregisterLib(__tgt_bin_desc
*Desc
);
127 int32_t isValidBinary(__tgt_device_image
*Image
);
128 int32_t getNumberOfDevices();
130 int32_t initDevice(int32_t DeviceId
);
131 int32_t initRequires(int64_t RequiresFlags
);
133 __tgt_target_table
*loadBinary(int32_t DeviceId
, __tgt_device_image
*Image
);
135 void *dataAlloc(int32_t DeviceId
, int64_t Size
, void *HstPtr
);
136 int32_t dataDelete(int32_t DeviceId
, void *TgtPtr
);
/// Data transfer entry points, declared here and defined elsewhere.
///
/// NOTE(review): in the text this was rebuilt from, both declarations stopped
/// after the pointer arguments on a dangling comma. The trailing
/// `int64_t Size` is restored by analogy with dataExchange and with the
/// libomptarget plugin data_submit/data_retrieve entry points — confirm
/// against the definitions before relying on it.
int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr, int64_t Size);
int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
                     int64_t Size);
/// Member functions declared here and defined elsewhere.
/// NOTE(review): return-value conventions are not visible in this header;
/// confirm against the definitions.

/// Both operate on a source/destination device pair; dataExchange also takes
/// the two pointers and a byte count \p Size.
int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId);
int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId,
                     void *DstPtr, int64_t Size);

/// NOTE(review): presumably \p TgtArgs and \p TgtOffsets are parallel arrays
/// of \p ArgNum elements — confirm against the definition.
int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs,
                        ptrdiff_t *TgtOffsets, int32_t ArgNum);

/// Same arguments as runTargetRegion plus \p TeamNum, \p ThreadLimit and
/// \p LoopTripCount.
int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr,
                            void **TgtArgs, ptrdiff_t *TgtOffsets,
                            int32_t ArgNum, int32_t TeamNum,
                            int32_t ThreadLimit, uint64_t LoopTripCount);