1 //===- RPC.cpp - Interface for remote procedure calls from the GPU --------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 #include "PluginInterface.h"
13 // This header file may be present in-tree or from an LLVM installation. The
14 // installed version lives alongside the GPU headers so we do not want to
15 // include it directly.
16 #if __has_include(<gpu-none-llvm/rpc_server.h>)
17 #include <gpu-none-llvm/rpc_server.h>
18 #elif defined(LIBOMPTARGET_RPC_SUPPORT)
19 #include <rpc_server.h>
24 using namespace target
;
// Construct the RPC server for NumDevices devices. When
// LIBOMPTARGET_RPC_SUPPORT is defined this calls rpc_init(NumDevices); a
// nonzero rpc_status_t is reported through FATAL_MESSAGE, since a constructor
// has no way to return the failure (per the comment below, the intent is to
// exit).
// NOTE(review): this copy of the file is damaged: lines carry a stray copy of
// their original line number, and the numbering skips from 30 to 35, so the
// #endif and closing brace of this constructor are missing. Restore from the
// upstream source before building.
26 RPCServerTy::RPCServerTy(uint32_t NumDevices
) {
27 #ifdef LIBOMPTARGET_RPC_SUPPORT
28 // If this fails then something is catastrophically wrong, just exit.
29 if (rpc_status_t Err
= rpc_init(NumDevices
))
30 FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err
);
// Probe Image for the RPC client symbol (rpc_client_symbol_name) by reading a
// pointer-sized global from the image into ClientPtr through
// Handler.readGlobalFromImage.
// NOTE(review): this copy is missing lines (numbering skips 30 -> 35,
// 38 -> 40, and 42 -> 52): the return-type line, the declaration of
// ClientPtr, and every return statement are not visible. Presumably a failed
// read means the image does not use RPC -- confirm against upstream.
35 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy
&Device
,
36 plugin::GenericGlobalHandlerTy
&Handler
,
37 plugin::DeviceImageTy
&Image
) {
38 #ifdef LIBOMPTARGET_RPC_SUPPORT
40 plugin::GlobalTy
Global(rpc_client_symbol_name
, sizeof(void *), &ClientPtr
);
// Any read failure is consumed with llvm::consumeError rather than being
// propagated to the caller.
41 if (auto Err
= Handler
.readGlobalFromImage(Device
, Image
, Global
)) {
42 llvm::consumeError(std::move(Err
));
// Prepare Device to service RPC requests for Image:
//  1. initialize the RPC server for the device's id with a port count clamped
//     to RPC_MAXIMUM_PORT_COUNT, the device's warp size, and an allocator that
//     hands out TARGET_ALLOC_HOST memory;
//  2. register callbacks for the RPC_MALLOC opcode (allocates
//     TARGET_ALLOC_DEVICE memory) and the RPC_FREE opcode (frees the pointer
//     it is given);
//  3. look up the client symbol (rpc_client_symbol_name) on the device, read
//     the pointer stored there into ClientPtr, and copy the server's client
//     buffer (rpc_get_client_buffer / rpc_get_client_size) to that address.
// RPC status failures are turned into Errors with plugin::Plugin::error; on
// success returns Error::success().
// NOTE(review): this copy is missing lines (the embedded numbering skips, e.g.
// 60 -> 63, 71 -> 74, 78 -> 82, 95 -> 100, 106 -> 108, 110 -> 113,
// 119 -> 122), so several statements are incomplete. Restore from the
// upstream source before building.
52 Error
RPCServerTy::initDevice(plugin::GenericDeviceTy
&Device
,
53 plugin::GenericGlobalHandlerTy
&Handler
,
54 plugin::DeviceImageTy
&Image
) {
55 #ifdef LIBOMPTARGET_RPC_SUPPORT
56 uint32_t DeviceId
= Device
.getDeviceId();
// Allocator handed to rpc_server_init below. Data is the &Device passed next
// to it, and memory is allocated with TARGET_ALLOC_HOST.
57 auto Alloc
= [](uint64_t Size
, void *Data
) {
58 plugin::GenericDeviceTy
&Device
=
59 *reinterpret_cast<plugin::GenericDeviceTy
*>(Data
);
60 return Device
.allocate(Size
, nullptr, TARGET_ALLOC_HOST
);
// Port count: the device's requested count clamped to RPC_MAXIMUM_PORT_COUNT.
// NOTE(review): the line declaring NumPorts (used below) is missing here.
63 std::min(Device
.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT
);
64 if (rpc_status_t Err
= rpc_server_init(DeviceId
, NumPorts
,
65 Device
.getWarpSize(), Alloc
, &Device
))
66 return plugin::Plugin::error(
67 "Failed to initialize RPC server for device %d: %d", DeviceId
, Err
);
69 // Register a custom opcode handler to perform plugin specific allocation.
70 // FIXME: We need to make sure this uses asynchronous allocations on CUDA.
71 auto MallocHandler
= [](rpc_port_t Port
, void *Data
) {
// Per-request body: Buffer->data[0] is passed to Device.allocate as the size
// and then overwritten with the returned TARGET_ALLOC_DEVICE pointer.
// NOTE(review): the enclosing call on the port (lines 72-73) is missing here.
74 [](rpc_buffer_t
*Buffer
, void *Data
) {
75 plugin::GenericDeviceTy
&Device
=
76 *reinterpret_cast<plugin::GenericDeviceTy
*>(Data
);
77 Buffer
->data
[0] = reinterpret_cast<uintptr_t>(
78 Device
.allocate(Buffer
->data
[0], nullptr, TARGET_ALLOC_DEVICE
));
// NOTE(review): this error string and the RPC_FREE one below end in "\n",
// unlike the other plugin::Plugin::error messages here; likely unintended.
// DeviceId is uint32_t, so if the format is printf-style, %u would match it
// exactly.
82 if (rpc_status_t Err
=
83 rpc_register_callback(DeviceId
, RPC_MALLOC
, MallocHandler
, &Device
))
84 return plugin::Plugin::error(
85 "Failed to register RPC malloc handler for device %d: %d\n", DeviceId
,
88 // Register a custom opcode handler to perform plugin specific deallocation.
89 auto FreeHandler
= [](rpc_port_t Port
, void *Data
) {
// Per-request body: frees the pointer stored in Buffer->data[0] through
// Device.free (its second argument is on a line missing from this copy).
92 [](rpc_buffer_t
*Buffer
, void *Data
) {
93 plugin::GenericDeviceTy
&Device
=
94 *reinterpret_cast<plugin::GenericDeviceTy
*>(Data
);
95 Device
.free(reinterpret_cast<void *>(Buffer
->data
[0]),
100 if (rpc_status_t Err
=
101 rpc_register_callback(DeviceId
, RPC_FREE
, FreeHandler
, &Device
))
102 return plugin::Plugin::error(
103 "Failed to register RPC free handler for device %d: %d\n", DeviceId
,
106 // Get the address of the RPC client from the device.
// NOTE(review): the declaration of ClientPtr (line 107) and the start of the
// statement containing the metadata lookup (line 109) are missing here.
108 plugin::GlobalTy
ClientGlobal(rpc_client_symbol_name
, sizeof(void *));
110 Handler
.getGlobalMetadataFromDevice(Device
, Image
, ClientGlobal
))
// Read the pointer-sized value stored at the client symbol into ClientPtr.
113 if (auto Err
= Device
.dataRetrieve(&ClientPtr
, ClientGlobal
.getPtr(),
114 sizeof(void *), nullptr))
// Copy the server's client buffer to the device address read above.
117 const void *ClientBuffer
= rpc_get_client_buffer(DeviceId
);
118 if (auto Err
= Device
.dataSubmit(ClientPtr
, ClientBuffer
,
119 rpc_get_client_size(), nullptr))
122 return Error::success();
// Service RPC requests for Device by calling rpc_handle_server with its device
// id. A nonzero rpc_status_t is converted into an Error with
// plugin::Plugin::error; otherwise returns Error::success().
// NOTE(review): the numbering skips 129 -> 132, so the second format argument
// (presumably Err) and the #endif are missing from this copy.
125 Error
RPCServerTy::runServer(plugin::GenericDeviceTy
&Device
) {
126 #ifdef LIBOMPTARGET_RPC_SUPPORT
127 if (rpc_status_t Err
= rpc_handle_server(Device
.getDeviceId()))
128 return plugin::Plugin::error(
129 "Error while running RPC server on device %d: %d", Device
.getDeviceId(),
132 return Error::success();
// Shut down the RPC server for Device. rpc_server_shutdown receives Dealloc,
// which releases memory with Device.free(..., TARGET_ALLOC_HOST), the same
// kind the allocator passed to rpc_server_init in initDevice hands out.
// A nonzero rpc_status_t is converted into an Error with
// plugin::Plugin::error; otherwise returns Error::success().
// NOTE(review): the numbering skips 140 -> 142 and 146 -> 148, so the closing
// of the Dealloc lambda and the #endif are missing from this copy.
135 Error
RPCServerTy::deinitDevice(plugin::GenericDeviceTy
&Device
) {
136 #ifdef LIBOMPTARGET_RPC_SUPPORT
137 auto Dealloc
= [](void *Ptr
, void *Data
) {
138 plugin::GenericDeviceTy
&Device
=
139 *reinterpret_cast<plugin::GenericDeviceTy
*>(Data
);
140 Device
.free(Ptr
, TARGET_ALLOC_HOST
);
142 if (rpc_status_t Err
=
143 rpc_server_shutdown(Device
.getDeviceId(), Dealloc
, &Device
))
144 return plugin::Plugin::error(
145 "Failed to shut down RPC server for device %d: %d",
146 Device
.getDeviceId(), Err
);
148 return Error::success();
151 RPCServerTy::~RPCServerTy() {
152 #ifdef LIBOMPTARGET_RPC_SUPPORT