1 //===----------------- Server.cpp - Server Implementation -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Offloading gRPC server for remote host.
11 //===----------------------------------------------------------------------===//
17 #include "omptarget.h"
18 #include "openmp.grpc.pb.h"
19 #include "openmp.pb.h"
// WriteOptions is used by DataRetrieve's final WriteLast call.
// ShutdownPromise is defined in another translation unit; the Shutdown
// handler fulfills it with set_value().
// NOTE(review): this copy carries stray original line numbers ahead of
// statements and is missing lines, so it does not compile as-is.
21 using grpc::WriteOptions
;
23 extern std::promise
<void> ShutdownPromise
;
// Handles the Shutdown RPC: logs the request and fulfills the extern
// ShutdownPromise, presumably unblocking whichever thread waits on its future
// in order to stop the server.
// NOTE(review): the remaining parameter(s), any reply handling and the return
// statement are on lines missing from this copy; confirm against the
// canonical source before editing.
// NOTE(review): std::promise::set_value throws std::future_error if the
// promise is already satisfied, so a second Shutdown RPC would throw here.
25 Status
RemoteOffloadImpl::Shutdown(ServerContext
*Context
, const Null
*Request
,
27 SERVER_DBG("Shutting down the server")
30 ShutdownPromise
.set_value();
// Handles the RegisterLib RPC. Steps:
// - Build a fresh __tgt_bin_desc from the serialized TargetBinaryDescription
//   with unloadTargetBinaryDescription. That call also receives
//   HostToRemoteDeviceImage, presumably to record the client-image ->
//   server-image mapping (confirm).
// - Register the new description with the plugin manager's RTLs.
// - Keep it in Descriptions, keyed by the client's bin_ptr, so UnregisterLib
//   can find it later.
// NOTE(review): the return type, the trailing parameter(s) and the tail of this
// definition are on lines missing from this copy.
35 RemoteOffloadImpl::RegisterLib(ServerContext
*Context
,
36 const TargetBinaryDescription
*Description
,
38 auto Desc
= std::make_unique
<__tgt_bin_desc
>();
40 unloadTargetBinaryDescription(Description
, Desc
.get(),
41 HostToRemoteDeviceImage
);
42 PM
->RTLs
.RegisterLib(Desc
.get());
// A description already stored under this bin_ptr has its contents released.
// NOTE(review): as visible here, the new Desc then replaces the map entry
// unconditionally. Original line 47 is missing from this copy, so confirm
// whether an `else` guards the assignment below.
44 if (Descriptions
.find((void *)Description
->bin_ptr()) != Descriptions
.end())
45 freeTargetBinaryDescription(
46 Descriptions
[(void *)Description
->bin_ptr()].get());
48 Descriptions
[(void *)Description
->bin_ptr()] = std::move(Desc
);
50 SERVER_DBG("Registered library")
// Handles the UnregisterLib RPC. Request carries the bin_ptr value the library
// was registered under. A known description is:
// - unregistered from the plugin manager's RTLs;
// - released with freeTargetBinaryDescription;
// - erased from Descriptions.
// NOTE(review): the body of the not-found branch (presumably an error reply and
// early return), the success reply and the return statement are on lines
// missing from this copy.
55 Status
RemoteOffloadImpl::UnregisterLib(ServerContext
*Context
,
56 const Pointer
*Request
, I32
*Reply
) {
57 if (Descriptions
.find((void *)Request
->number()) == Descriptions
.end()) {
// Known library: tear it down. NOTE(review): Descriptions is looked up by the
// same key three more times below; reusing the iterator from find() would
// avoid the repeated lookups.
62 PM
->RTLs
.UnregisterLib(Descriptions
[(void *)Request
->number()].get());
63 freeTargetBinaryDescription(Descriptions
[(void *)Request
->number()].get());
64 Descriptions
.erase((void *)Request
->number());
66 SERVER_DBG("Unregistered library")
// Handles the IsValidBinary RPC. It maps the client's image pointer to the
// server-side __tgt_device_image via HostToRemoteDeviceImage. It then asks
// each loaded RTL's is_valid_binary whether it accepts that image.
// NOTE(review): the reply parameter (named IsValid below) and the return
// statement are on lines missing from this copy.
71 Status
RemoteOffloadImpl::IsValidBinary(ServerContext
*Context
,
72 const TargetDeviceImagePtr
*DeviceImage
,
// NOTE(review): if HostToRemoteDeviceImage is a std::map/unordered_map,
// operator[] silently inserts a null entry for an unknown pointer; find()
// would avoid that.
74 __tgt_device_image
*Image
=
75 HostToRemoteDeviceImage
[(void *)DeviceImage
->image_ptr()];
// 0 is the reply when no RTL accepts the image; any nonzero is_valid_binary
// result overwrites it. NOTE(review): a loop exit after the first match is
// presumably on a missing line; without one, the last accepting RTL wins.
77 IsValid
->set_number(0);
79 for (auto &RTL
: PM
->RTLs
.AllRTLs
)
80 if (auto Ret
= RTL
.is_valid_binary(Image
)) {
81 IsValid
->set_number(Ret
);
85 SERVER_DBG("Checked if binary (%p) is valid",
86 (void *)(DeviceImage
->image_ptr()))
// Handles the GetNumberOfDevices RPC. It loads the RTLs on first use, sums
// NumberOfDevices over every RTL in AllRTLs, and replies with the total.
// NOTE(review): a middle parameter, the declaration/initialization of Devices
// and the return statement are on lines missing from this copy.
90 Status
RemoteOffloadImpl::GetNumberOfDevices(ServerContext
*Context
,
92 I32
*NumberOfDevices
) {
// std::call_once runs RTLsTy::LoadRTLs on PM->RTLs only for the first caller;
// later RPCs skip it.
93 std::call_once(PM
->RTLs
.initFlag
, &RTLsTy::LoadRTLs
, &PM
->RTLs
);
// NOTE(review): the total is summed over AllRTLs, while mapHostRTLDeviceId
// decomposes device numbers by walking UsedRTLs. Confirm the two lists agree,
// otherwise the numbering reported here will not match the mapping.
97 for (auto &RTL
: PM
->RTLs
.AllRTLs
)
98 Devices
+= RTL
.NumberOfDevices
;
101 NumberOfDevices
->set_number(Devices
);
103 SERVER_DBG("Got number of devices")
// Handles the InitDevice RPC. DeviceNum selects the device in PM->Devices.
// mapHostRTLDeviceId converts DeviceNum into the id local to that device's RTL
// before init_device is called. The RTL's return code is sent back in Reply.
// NOTE(review): DeviceNum->number() is not bounds-checked against
// PM->Devices. The return statement is on a line missing from this copy.
107 Status
RemoteOffloadImpl::InitDevice(ServerContext
*Context
,
108 const I32
*DeviceNum
, I32
*Reply
) {
109 Reply
->set_number(PM
->Devices
[DeviceNum
->number()].RTL
->init_device(
110 mapHostRTLDeviceId(DeviceNum
->number())));
112 SERVER_DBG("Initialized device %d", DeviceNum
->number())
// Handles the InitRequires RPC. It forwards RequiresFlag to init_requires on
// every device whose RTL provides that entry point, then echoes the flag back
// in Reply.
// NOTE(review): the loop runs per device, so an RTL owning several devices
// receives the same flag several times.
// NOTE(review): Reply is an I32 while RequiresFlag is an I64, so the echoed
// value is presumably narrowed if the flags ever exceed 32 bits.
116 Status
RemoteOffloadImpl::InitRequires(ServerContext
*Context
,
117 const I64
*RequiresFlag
, I32
*Reply
) {
118 for (auto &Device
: PM
->Devices
)
119 if (Device
.RTL
->init_requires
)
120 Device
.RTL
->init_requires(RequiresFlag
->number());
121 Reply
->set_number(RequiresFlag
->number());
123 SERVER_DBG("Initialized requires for devices")
// Handles the LoadBinary RPC. Steps:
// - Map the client's image pointer to the server-side __tgt_device_image.
// - Load it on the requested device via the owning RTL's load_binary.
// - Serialize the resulting table into Reply with loadTargetTable.
// The parameter `Binary` shares its name with its type.
// NOTE(review): the tail of the debug call and the return statement are on
// lines missing from this copy.
127 Status
RemoteOffloadImpl::LoadBinary(ServerContext
*Context
,
128 const Binary
*Binary
, TargetTable
*Reply
) {
129 __tgt_device_image
*Image
=
130 HostToRemoteDeviceImage
[(void *)Binary
->image_ptr()];
// Table is not declared in this block; presumably it is a RemoteOffloadImpl
// member (confirm).
// NOTE(review): load_binary may return null on failure. Original line 134 is
// missing from this copy, so confirm Table is checked before loadTargetTable.
132 Table
= PM
->Devices
[Binary
->device_id()].RTL
->load_binary(
133 mapHostRTLDeviceId(Binary
->device_id()), Image
);
135 loadTargetTable(Table
, *Reply
, Image
);
137 SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary
->image_ptr(),
142 Status
RemoteOffloadImpl::IsDataExchangeable(ServerContext
*Context
,
143 const DevicePair
*Request
,
145 Reply
->set_number(-1);
146 if (PM
->Devices
[mapHostRTLDeviceId(Request
->src_dev_id())]
147 .RTL
->is_data_exchangable
)
148 Reply
->set_number(PM
->Devices
[mapHostRTLDeviceId(Request
->src_dev_id())]
149 .RTL
->is_data_exchangable(Request
->src_dev_id(),
150 Request
->dst_dev_id()));
152 SERVER_DBG("Checked if data exchangeable between device %d and device %d",
153 Request
->src_dev_id(), Request
->dst_dev_id())
// Handles the DataAlloc RPC. It allocates Request->size() bytes on the
// requested device through its RTL's data_alloc, passing the client's host
// pointer and TARGET_ALLOC_DEFAULT. The resulting device address is returned
// to the client as an integer.
// NOTE(review): presumably data_alloc returns null on failure, which reaches
// the client as 0. The return statement is on a line missing from this copy.
157 Status
RemoteOffloadImpl::DataAlloc(ServerContext
*Context
,
158 const AllocData
*Request
, Pointer
*Reply
) {
159 uint64_t TgtPtr
= (uint64_t)PM
->Devices
[Request
->device_id()].RTL
->data_alloc(
160 mapHostRTLDeviceId(Request
->device_id()), Request
->size(),
161 (void *)Request
->hst_ptr(), TARGET_ALLOC_DEFAULT
);
162 Reply
->set_number(TgtPtr
);
164 SERVER_DBG("Allocated at " DPxMOD
"", DPxPTR((void *)TgtPtr
))
// Handles the streaming DataSubmit RPC. Each SubmitData message carries bytes
// [start(), start() + data().size()) of a transfer whose total length is
// size().
// - Whole-buffer message (start() == 0 and the payload already covers size()
//   bytes): submitted to the device directly from the message payload.
// - Otherwise: chunks are copied into HostCopy at their start() offset, and
//   the reassembled buffer is submitted after the stream is exhausted.
// NOTE(review): the Reply parameter, the declaration of Request, the early
// return after the whole-buffer submit, the guard around the HostCopy
// allocation and its release are on lines missing from this copy.
169 Status
RemoteOffloadImpl::DataSubmit(ServerContext
*Context
,
170 ServerReader
<SubmitData
> *Reader
,
// Reassembly buffer for chunked transfers. NOTE(review): it is a raw new[]
// allocation; std::unique_ptr<uint8_t[]> would release it on every exit path.
173 uint8_t *HostCopy
= nullptr;
174 while (Reader
->Read(&Request
)) {
175 if (Request
.start() == 0 && Request
.size() == Request
.data().size()) {
176 Reader
->SendInitialMetadata();
178 Reply
->set_number(PM
->Devices
[Request
.device_id()].RTL
->data_submit(
179 mapHostRTLDeviceId(Request
.device_id()), (void *)Request
.tgt_ptr(),
180 (void *)Request
.data().data(), Request
.data().size()));
182 SERVER_DBG("Submitted %lu bytes async to (%p) on device %d",
183 Request
.data().size(), (void *)Request
.tgt_ptr(),
189 HostCopy
= new uint8_t[Request
.size()];
190 Reader
->SendInitialMetadata();
// NOTE(review): nothing checks start() + data().size() <= the size HostCopy
// was allocated with. A malformed stream can write past the buffer.
193 memcpy((void *)((char *)HostCopy
+ Request
.start()), Request
.data().data(),
194 Request
.data().size());
// Stream exhausted: submit the reassembled buffer of size() bytes.
// NOTE(review): the log below reports data().size() of the last chunk, not
// the total submitted.
197 Reply
->set_number(PM
->Devices
[Request
.device_id()].RTL
->data_submit(
198 mapHostRTLDeviceId(Request
.device_id()), (void *)Request
.tgt_ptr(),
199 HostCopy
, Request
.size()));
203 SERVER_DBG("Submitted %lu bytes to (%p) on device %d", Request
.data().size(),
204 (void *)Request
.tgt_ptr(), Request
.device_id())
// Handles the streaming DataRetrieve RPC. Steps:
// - Copy Request->size() bytes from tgt_ptr on the requested device into a
//   temporary host buffer via data_retrieve.
// - Stream the buffer back as Data messages allocated from the protobuf Arena.
//   Transfers larger than BlockSize go out as BlockSize chunks tagged with
//   their start offset; smaller ones go out as one message with WriteLast.
// NOTE(review): several lines are missing from this copy, including the arena
// reset, the Start/End advance, the branch structure and the return.
// NOTE(review): Ret from data_retrieve is not referenced in any visible line;
// presumably it is attached to the replies on a missing line (confirm).
209 Status
RemoteOffloadImpl::DataRetrieve(ServerContext
*Context
,
210 const RetrieveData
*Request
,
211 ServerWriter
<Data
> *Writer
) {
212 auto HstPtr
= std::make_unique
<char[]>(Request
->size());
214 auto Ret
= PM
->Devices
[Request
->device_id()].RTL
->data_retrieve(
215 mapHostRTLDeviceId(Request
->device_id()), HstPtr
.get(),
216 (void *)Request
->tgt_ptr(), Request
->size());
// Recycle the arena once it has grown to MaxSize; the statement doing so is on
// a line missing from this copy.
218 if (Arena
->SpaceAllocated() >= MaxSize
)
221 if (Request
->size() > BlockSize
) {
222 uint64_t Start
= 0, End
= BlockSize
;
// NOTE(review): the chunk count uses ceil((float)size / BlockSize). float
// has a 24-bit significand, so for sizes above 2^24 bytes the quotient can
// round down to an exact integer. The final partial chunk is then never
// sent. Integer arithmetic, (size + BlockSize - 1) / BlockSize, avoids this.
223 for (auto I
= 0; I
< ceil((float)Request
->size() / BlockSize
); I
++) {
224 auto *Reply
= protobuf::Arena::CreateMessage
<Data
>(Arena
.get());
226 Reply
->set_start(Start
);
227 Reply
->set_size(Request
->size());
228 Reply
->set_data((char *)HstPtr
.get() + Start
, End
- Start
);
231 if (!Writer
->Write(*Reply
)) {
// NOTE(review): this is server code but logs through CLIENT_DBG, and the
// message says "submitting" on a retrieve path.
232 CLIENT_DBG("Broken stream when submitting data")
235 SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start
,
236 End
, Request
->size(), (void *)Request
->tgt_ptr(),
237 mapHostRTLDeviceId(Request
->device_id()))
// Clamp the end of the next chunk to the transfer size.
241 if (End
>= Request
->size())
242 End
= Request
->size();
// Small transfer: a single message carrying the whole buffer.
245 auto *Reply
= protobuf::Arena::CreateMessage
<Data
>(Arena
.get());
248 Reply
->set_size(Request
->size());
249 Reply
->set_data((char *)HstPtr
.get(), Request
->size());
252 SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request
->size(),
253 (void *)Request
->tgt_ptr(),
254 mapHostRTLDeviceId(Request
->device_id()))
256 Writer
->WriteLast(*Reply
, WriteOptions());
// Handles the DataExchange RPC. When the RTL of the source device implements
// data_exchange, it asks that RTL to transfer data from src_ptr on the source
// device to dst_ptr on the destination device. Both device numbers are
// translated to RTL-local ids with mapHostRTLDeviceId, and the RTL's return
// code is the reply. Otherwise the reply is -1.
// NOTE(review): the following are on lines missing from this copy:
// - the Reply parameter;
// - the size argument of data_exchange;
// - the `else` guarding the -1 reply;
// - the start of the debug call;
// - the return statement.
262 Status
RemoteOffloadImpl::DataExchange(ServerContext
*Context
,
263 const ExchangeData
*Request
,
265 if (PM
->Devices
[Request
->src_dev_id()].RTL
->data_exchange
) {
266 int32_t Ret
= PM
->Devices
[Request
->src_dev_id()].RTL
->data_exchange(
267 mapHostRTLDeviceId(Request
->src_dev_id()), (void *)Request
->src_ptr(),
268 mapHostRTLDeviceId(Request
->dst_dev_id()), (void *)Request
->dst_ptr(),
270 Reply
->set_number(Ret
);
// Fallback reply when the RTL has no data_exchange entry point.
272 Reply
->set_number(-1);
275 "Exchanged data asynchronously from device %d (%p) to device %d (%p) of "
277 mapHostRTLDeviceId(Request
->src_dev_id()), (void *)Request
->src_ptr(),
278 mapHostRTLDeviceId(Request
->dst_dev_id()), (void *)Request
->dst_ptr(),
// Handles the DataDelete RPC. It frees the device allocation at tgt_ptr on
// the requested device via its RTL's data_delete and replies with the RTL's
// return code.
// NOTE(review): the return statement is on a line missing from this copy.
283 Status
RemoteOffloadImpl::DataDelete(ServerContext
*Context
,
284 const DeleteData
*Request
, I32
*Reply
) {
285 auto Ret
= PM
->Devices
[Request
->device_id()].RTL
->data_delete(
286 mapHostRTLDeviceId(Request
->device_id()), (void *)Request
->tgt_ptr());
287 Reply
->set_number(Ret
);
289 SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request
->tgt_ptr(),
290 mapHostRTLDeviceId(Request
->device_id()))
294 Status
RemoteOffloadImpl::RunTargetRegion(ServerContext
*Context
,
295 const TargetRegion
*Request
,
297 std::vector
<uint8_t> TgtArgs(Request
->arg_num());
298 for (auto I
= 0; I
< Request
->arg_num(); I
++)
299 TgtArgs
[I
] = (uint64_t)Request
->tgt_args()[I
];
301 std::vector
<ptrdiff_t> TgtOffsets(Request
->arg_num());
302 const auto *TgtOffsetItr
= Request
->tgt_offsets().begin();
303 for (auto I
= 0; I
< Request
->arg_num(); I
++, TgtOffsetItr
++)
304 TgtOffsets
[I
] = (ptrdiff_t)*TgtOffsetItr
;
306 void *TgtEntryPtr
= ((__tgt_offload_entry
*)Request
->tgt_entry_ptr())->addr
;
308 int32_t Ret
= PM
->Devices
[Request
->device_id()].RTL
->run_region(
309 mapHostRTLDeviceId(Request
->device_id()), TgtEntryPtr
,
310 (void **)TgtArgs
.data(), TgtOffsets
.data(), Request
->arg_num());
312 Reply
->set_number(Ret
);
314 SERVER_DBG("Ran TargetRegion on device %d with %d args",
315 mapHostRTLDeviceId(Request
->device_id()), Request
->arg_num())
// Handles the RunTargetTeamRegion RPC. Steps:
// - Unpack the kernel arguments (held in pointer-sized uint64_t slots and
//   passed on as a void ** array) and their offsets from the request.
// - Resolve the kernel address from the __tgt_offload_entry that
//   tgt_entry_ptr points at.
// - Launch the kernel through the device RTL's run_team_region with the
//   requested team count, thread limit and loop trip count.
// The RTL's return code is sent back in Reply.
// NOTE(review): the Reply parameter and the return statement are on lines
// missing from this copy.
319 Status
RemoteOffloadImpl::RunTargetTeamRegion(ServerContext
*Context
,
320 const TargetTeamRegion
*Request
,
322 std::vector
<uint64_t> TgtArgs(Request
->arg_num());
323 for (auto I
= 0; I
< Request
->arg_num(); I
++)
324 TgtArgs
[I
] = (uint64_t)Request
->tgt_args()[I
];
// NOTE(review): tgt_offsets() is assumed to hold at least arg_num() entries;
// the iterator below is not bounds-checked.
326 std::vector
<ptrdiff_t> TgtOffsets(Request
->arg_num());
327 const auto *TgtOffsetItr
= Request
->tgt_offsets().begin();
328 for (auto I
= 0; I
< Request
->arg_num(); I
++, TgtOffsetItr
++)
329 TgtOffsets
[I
] = (ptrdiff_t)*TgtOffsetItr
;
// The entry pointer is dereferenced here, so it must refer to an offload
// entry in this server's memory (presumably one published by LoadBinary).
331 void *TgtEntryPtr
= ((__tgt_offload_entry
*)Request
->tgt_entry_ptr())->addr
;
333 int32_t Ret
= PM
->Devices
[Request
->device_id()].RTL
->run_team_region(
334 mapHostRTLDeviceId(Request
->device_id()), TgtEntryPtr
,
335 (void **)TgtArgs
.data(), TgtOffsets
.data(), Request
->arg_num(),
336 Request
->team_num(), Request
->thread_limit(), Request
->loop_tripcount());
338 Reply
->set_number(Ret
);
340 SERVER_DBG("Ran TargetTeamRegion on device %d with %d args",
341 mapHostRTLDeviceId(Request
->device_id()), Request
->arg_num())
345 int32_t RemoteOffloadImpl::mapHostRTLDeviceId(int32_t RTLDeviceID
) {
346 for (auto &RTL
: PM
->RTLs
.UsedRTLs
) {
347 if (RTLDeviceID
- RTL
->NumberOfDevices
>= 0)
348 RTLDeviceID
-= RTL
->NumberOfDevices
;