1 //===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "rpc_server.h"
11 #include "src/__support/RPC/rpc.h"
12 #include "src/stdio/gpu/file.h"
18 #include <unordered_map>
22 using namespace __llvm_libc
;
24 static_assert(sizeof(rpc_buffer_t
) == sizeof(rpc::Buffer
),
25 "Buffer size mismatch");
27 static_assert(RPC_MAXIMUM_PORT_COUNT
== rpc::MAX_PORT_COUNT
,
28 "Incorrect maximum port count");
30 // The client needs to support different lane sizes for the SIMT model. Because
31 // of this we need to select between the possible sizes that the client can use.
33 template <uint32_t lane_size
>
34 Server(std::unique_ptr
<rpc::Server
<lane_size
>> &&server
)
35 : server(std::move(server
)) {}
37 rpc_status_t
handle_server(
38 const std::unordered_map
<rpc_opcode_t
, rpc_opcode_callback_ty
> &callbacks
,
39 const std::unordered_map
<rpc_opcode_t
, void *> &callback_data
) {
40 rpc_status_t ret
= RPC_STATUS_SUCCESS
;
43 ret
= handle_server(*server
, callbacks
, callback_data
);
50 template <uint32_t lane_size
>
51 rpc_status_t
handle_server(
52 rpc::Server
<lane_size
> &server
,
53 const std::unordered_map
<rpc_opcode_t
, rpc_opcode_callback_ty
> &callbacks
,
54 const std::unordered_map
<rpc_opcode_t
, void *> &callback_data
) {
55 auto port
= server
.try_open();
57 return RPC_STATUS_SUCCESS
;
59 switch (port
->get_opcode()) {
60 case RPC_WRITE_TO_STREAM
:
61 case RPC_WRITE_TO_STDERR
:
62 case RPC_WRITE_TO_STDOUT
: {
63 uint64_t sizes
[lane_size
] = {0};
64 void *strs
[lane_size
] = {nullptr};
65 FILE *files
[lane_size
] = {nullptr};
66 if (port
->get_opcode() == RPC_WRITE_TO_STREAM
)
67 port
->recv([&](rpc::Buffer
*buffer
, uint32_t id
) {
68 files
[id
] = reinterpret_cast<FILE *>(buffer
->data
[0]);
70 port
->recv_n(strs
, sizes
, [&](uint64_t size
) { return new char[size
]; });
71 port
->send([&](rpc::Buffer
*buffer
, uint32_t id
) {
73 port
->get_opcode() == RPC_WRITE_TO_STDOUT
75 : (port
->get_opcode() == RPC_WRITE_TO_STDERR
? stderr
77 uint64_t ret
= fwrite(strs
[id
], 1, sizes
[id
], file
);
78 std::memcpy(buffer
->data
, &ret
, sizeof(uint64_t));
79 delete[] reinterpret_cast<uint8_t *>(strs
[id
]);
83 case RPC_READ_FROM_STREAM
: {
84 uint64_t sizes
[lane_size
] = {0};
85 void *data
[lane_size
] = {nullptr};
86 port
->recv([&](rpc::Buffer
*buffer
, uint32_t id
) {
87 data
[id
] = new char[buffer
->data
[0]];
88 sizes
[id
] = fread(data
[id
], 1, buffer
->data
[0],
89 file::to_stream(buffer
->data
[1]));
91 port
->send_n(data
, sizes
);
92 port
->send([&](rpc::Buffer
*buffer
, uint32_t id
) {
93 delete[] reinterpret_cast<uint8_t *>(data
[id
]);
94 std::memcpy(buffer
->data
, &sizes
[id
], sizeof(uint64_t));
99 uint64_t sizes
[lane_size
] = {0};
100 void *paths
[lane_size
] = {nullptr};
101 port
->recv_n(paths
, sizes
, [&](uint64_t size
) { return new char[size
]; });
102 port
->recv_and_send([&](rpc::Buffer
*buffer
, uint32_t id
) {
103 FILE *file
= fopen(reinterpret_cast<char *>(paths
[id
]),
104 reinterpret_cast<char *>(buffer
->data
));
105 buffer
->data
[0] = reinterpret_cast<uintptr_t>(file
);
109 case RPC_CLOSE_FILE
: {
110 port
->recv_and_send([&](rpc::Buffer
*buffer
, uint32_t id
) {
111 FILE *file
= reinterpret_cast<FILE *>(buffer
->data
[0]);
112 buffer
->data
[0] = fclose(file
);
117 // Send a response to the client to signal that we are ready to exit.
118 port
->recv_and_send([](rpc::Buffer
*) {});
119 port
->recv([](rpc::Buffer
*buffer
) {
121 std::memcpy(&status
, buffer
->data
, sizeof(int));
127 // Send a response to the client to signal that we are ready to abort.
128 port
->recv_and_send([](rpc::Buffer
*) {});
129 port
->recv([](rpc::Buffer
*) {});
133 case RPC_HOST_CALL
: {
134 uint64_t sizes
[lane_size
] = {0};
135 void *args
[lane_size
] = {nullptr};
136 port
->recv_n(args
, sizes
, [&](uint64_t size
) { return new char[size
]; });
137 port
->recv([&](rpc::Buffer
*buffer
, uint32_t id
) {
138 reinterpret_cast<void (*)(void *)>(buffer
->data
[0])(args
[id
]);
140 port
->send([&](rpc::Buffer
*, uint32_t id
) {
141 delete[] reinterpret_cast<uint8_t *>(args
[id
]);
146 port
->recv_and_send([](rpc::Buffer
*buffer
) {
147 buffer
->data
[0] = feof(file::to_stream(buffer
->data
[0]));
152 port
->recv_and_send([](rpc::Buffer
*buffer
) {
153 buffer
->data
[0] = ferror(file::to_stream(buffer
->data
[0]));
158 port
->recv_and_send([](rpc::Buffer
*buffer
) {
159 clearerr(file::to_stream(buffer
->data
[0]));
164 port
->recv([](rpc::Buffer
*) {});
169 callbacks
.find(static_cast<rpc_opcode_t
>(port
->get_opcode()));
171 // We error out on an unhandled opcode.
172 if (handler
== callbacks
.end())
173 return RPC_STATUS_UNHANDLED_OPCODE
;
175 // Invoke the registered callback with a reference to the port.
177 callback_data
.at(static_cast<rpc_opcode_t
>(port
->get_opcode()));
178 rpc_port_t port_ref
{reinterpret_cast<uint64_t>(&*port
), lane_size
};
179 (handler
->second
)(port_ref
, data
);
183 return RPC_STATUS_CONTINUE
;
186 std::variant
<std::unique_ptr
<rpc::Server
<1>>,
187 std::unique_ptr
<rpc::Server
<32>>,
188 std::unique_ptr
<rpc::Server
<64>>>
193 template <typename T
>
194 Device(uint32_t num_ports
, void *buffer
, std::unique_ptr
<T
> &&server
)
195 : buffer(buffer
), server(std::move(server
)), client(num_ports
, buffer
) {}
199 std::unordered_map
<rpc_opcode_t
, rpc_opcode_callback_ty
> callbacks
;
200 std::unordered_map
<rpc_opcode_t
, void *> callback_data
;
203 // A struct containing all the runtime state required to run the RPC server.
205 State(uint32_t num_devices
)
206 : num_devices(num_devices
), devices(num_devices
), reference_count(0u) {}
207 uint32_t num_devices
;
208 std::vector
<std::unique_ptr
<Device
>> devices
;
209 std::atomic_uint32_t reference_count
;
// Held (via std::scoped_lock) by rpc_init while it sets up and reference-counts
// the global state.
static std::mutex startup_mutex{};
216 rpc_status_t
rpc_init(uint32_t num_devices
) {
217 std::scoped_lock
<decltype(startup_mutex
)> lock(startup_mutex
);
219 state
= new State(num_devices
);
221 if (state
->reference_count
== std::numeric_limits
<uint32_t>::max())
222 return RPC_STATUS_ERROR
;
224 state
->reference_count
++;
226 return RPC_STATUS_SUCCESS
;
229 rpc_status_t
rpc_shutdown(void) {
230 if (state
&& state
->reference_count
-- == 1)
233 return RPC_STATUS_SUCCESS
;
236 template <uint32_t lane_size
>
237 rpc_status_t
server_init_impl(uint32_t device_id
, uint64_t num_ports
,
238 rpc_alloc_ty alloc
, void *data
) {
239 uint64_t size
= rpc::Server
<lane_size
>::allocation_size(num_ports
);
240 void *buffer
= alloc(size
, data
);
243 return RPC_STATUS_ERROR
;
245 state
->devices
[device_id
] = std::make_unique
<Device
>(
247 std::make_unique
<rpc::Server
<lane_size
>>(num_ports
, buffer
));
248 if (!state
->devices
[device_id
])
249 return RPC_STATUS_ERROR
;
251 return RPC_STATUS_SUCCESS
;
254 rpc_status_t
rpc_server_init(uint32_t device_id
, uint64_t num_ports
,
255 uint32_t lane_size
, rpc_alloc_ty alloc
,
258 return RPC_STATUS_NOT_INITIALIZED
;
259 if (device_id
>= state
->num_devices
)
260 return RPC_STATUS_OUT_OF_RANGE
;
262 if (!state
->devices
[device_id
]) {
265 if (rpc_status_t err
=
266 server_init_impl
<1>(device_id
, num_ports
, alloc
, data
))
270 if (rpc_status_t err
=
271 server_init_impl
<32>(device_id
, num_ports
, alloc
, data
))
276 if (rpc_status_t err
=
277 server_init_impl
<64>(device_id
, num_ports
, alloc
, data
))
281 return RPC_STATUS_INVALID_LANE_SIZE
;
285 return RPC_STATUS_SUCCESS
;
288 rpc_status_t
rpc_server_shutdown(uint32_t device_id
, rpc_free_ty dealloc
,
291 return RPC_STATUS_NOT_INITIALIZED
;
292 if (device_id
>= state
->num_devices
)
293 return RPC_STATUS_OUT_OF_RANGE
;
294 if (!state
->devices
[device_id
])
295 return RPC_STATUS_ERROR
;
297 dealloc(state
->devices
[device_id
]->buffer
, data
);
298 if (state
->devices
[device_id
])
299 state
->devices
[device_id
].release();
301 return RPC_STATUS_SUCCESS
;
304 rpc_status_t
rpc_handle_server(uint32_t device_id
) {
306 return RPC_STATUS_NOT_INITIALIZED
;
307 if (device_id
>= state
->num_devices
)
308 return RPC_STATUS_OUT_OF_RANGE
;
309 if (!state
->devices
[device_id
])
310 return RPC_STATUS_ERROR
;
313 auto &device
= *state
->devices
[device_id
];
314 rpc_status_t status
=
315 device
.server
.handle_server(device
.callbacks
, device
.callback_data
);
316 if (status
!= RPC_STATUS_CONTINUE
)
321 rpc_status_t
rpc_register_callback(uint32_t device_id
, rpc_opcode_t opcode
,
322 rpc_opcode_callback_ty callback
,
325 return RPC_STATUS_NOT_INITIALIZED
;
326 if (device_id
>= state
->num_devices
)
327 return RPC_STATUS_OUT_OF_RANGE
;
328 if (!state
->devices
[device_id
])
329 return RPC_STATUS_ERROR
;
331 state
->devices
[device_id
]->callbacks
[opcode
] = callback
;
332 state
->devices
[device_id
]->callback_data
[opcode
] = data
;
333 return RPC_STATUS_SUCCESS
;
336 const void *rpc_get_client_buffer(uint32_t device_id
) {
337 if (!state
|| device_id
>= state
->num_devices
|| !state
->devices
[device_id
])
339 return &state
->devices
[device_id
]->client
;
342 uint64_t rpc_get_client_size() { return sizeof(rpc::Client
); }
344 using ServerPort
= std::variant
<rpc::Server
<1>::Port
*, rpc::Server
<32>::Port
*,
345 rpc::Server
<64>::Port
*>;
347 ServerPort
get_port(rpc_port_t ref
) {
348 if (ref
.lane_size
== 1)
349 return reinterpret_cast<rpc::Server
<1>::Port
*>(ref
.handle
);
350 else if (ref
.lane_size
== 32)
351 return reinterpret_cast<rpc::Server
<32>::Port
*>(ref
.handle
);
352 else if (ref
.lane_size
== 64)
353 return reinterpret_cast<rpc::Server
<64>::Port
*>(ref
.handle
);
355 __builtin_unreachable();
358 void rpc_send(rpc_port_t ref
, rpc_port_callback_ty callback
, void *data
) {
359 auto port
= get_port(ref
);
362 port
->send([=](rpc::Buffer
*buffer
) {
363 callback(reinterpret_cast<rpc_buffer_t
*>(buffer
), data
);
369 void rpc_send_n(rpc_port_t ref
, const void *const *src
, uint64_t *size
) {
370 auto port
= get_port(ref
);
371 std::visit([=](auto &port
) { port
->send_n(src
, size
); }, port
);
374 void rpc_recv(rpc_port_t ref
, rpc_port_callback_ty callback
, void *data
) {
375 auto port
= get_port(ref
);
378 port
->recv([=](rpc::Buffer
*buffer
) {
379 callback(reinterpret_cast<rpc_buffer_t
*>(buffer
), data
);
385 void rpc_recv_n(rpc_port_t ref
, void **dst
, uint64_t *size
, rpc_alloc_ty alloc
,
387 auto port
= get_port(ref
);
388 auto alloc_fn
= [=](uint64_t size
) { return alloc(size
, data
); };
389 std::visit([=](auto &port
) { port
->recv_n(dst
, size
, alloc_fn
); }, port
);
392 void rpc_recv_and_send(rpc_port_t ref
, rpc_port_callback_ty callback
,
394 auto port
= get_port(ref
);
397 port
->recv_and_send([=](rpc::Buffer
*buffer
) {
398 callback(reinterpret_cast<rpc_buffer_t
*>(buffer
), data
);