[Workflow] Roll back some settings since they caused more issues
[llvm-project.git] / libc / utils / gpu / server / rpc_server.cpp
blob7493ed66ceecb8c3307e24250095f0631ca3c486
1 //===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "rpc_server.h"

#include "src/__support/RPC/rpc.h"
#include "src/stdio/gpu/file.h"
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
22 using namespace __llvm_libc;
24 static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
25 "Buffer size mismatch");
27 static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
28 "Incorrect maximum port count");
30 // The client needs to support different lane sizes for the SIMT model. Because
31 // of this we need to select between the possible sizes that the client can use.
32 struct Server {
33 template <uint32_t lane_size>
34 Server(std::unique_ptr<rpc::Server<lane_size>> &&server)
35 : server(std::move(server)) {}
37 rpc_status_t handle_server(
38 const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
39 const std::unordered_map<rpc_opcode_t, void *> &callback_data) {
40 rpc_status_t ret = RPC_STATUS_SUCCESS;
41 std::visit(
42 [&](auto &server) {
43 ret = handle_server(*server, callbacks, callback_data);
45 server);
46 return ret;
49 private:
50 template <uint32_t lane_size>
51 rpc_status_t handle_server(
52 rpc::Server<lane_size> &server,
53 const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
54 const std::unordered_map<rpc_opcode_t, void *> &callback_data) {
55 auto port = server.try_open();
56 if (!port)
57 return RPC_STATUS_SUCCESS;
59 switch (port->get_opcode()) {
60 case RPC_WRITE_TO_STREAM:
61 case RPC_WRITE_TO_STDERR:
62 case RPC_WRITE_TO_STDOUT: {
63 uint64_t sizes[lane_size] = {0};
64 void *strs[lane_size] = {nullptr};
65 FILE *files[lane_size] = {nullptr};
66 if (port->get_opcode() == RPC_WRITE_TO_STREAM)
67 port->recv([&](rpc::Buffer *buffer, uint32_t id) {
68 files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
69 });
70 port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
71 port->send([&](rpc::Buffer *buffer, uint32_t id) {
72 FILE *file =
73 port->get_opcode() == RPC_WRITE_TO_STDOUT
74 ? stdout
75 : (port->get_opcode() == RPC_WRITE_TO_STDERR ? stderr
76 : files[id]);
77 uint64_t ret = fwrite(strs[id], 1, sizes[id], file);
78 std::memcpy(buffer->data, &ret, sizeof(uint64_t));
79 delete[] reinterpret_cast<uint8_t *>(strs[id]);
80 });
81 break;
83 case RPC_READ_FROM_STREAM: {
84 uint64_t sizes[lane_size] = {0};
85 void *data[lane_size] = {nullptr};
86 port->recv([&](rpc::Buffer *buffer, uint32_t id) {
87 data[id] = new char[buffer->data[0]];
88 sizes[id] = fread(data[id], 1, buffer->data[0],
89 file::to_stream(buffer->data[1]));
90 });
91 port->send_n(data, sizes);
92 port->send([&](rpc::Buffer *buffer, uint32_t id) {
93 delete[] reinterpret_cast<uint8_t *>(data[id]);
94 std::memcpy(buffer->data, &sizes[id], sizeof(uint64_t));
95 });
96 break;
98 case RPC_OPEN_FILE: {
99 uint64_t sizes[lane_size] = {0};
100 void *paths[lane_size] = {nullptr};
101 port->recv_n(paths, sizes, [&](uint64_t size) { return new char[size]; });
102 port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
103 FILE *file = fopen(reinterpret_cast<char *>(paths[id]),
104 reinterpret_cast<char *>(buffer->data));
105 buffer->data[0] = reinterpret_cast<uintptr_t>(file);
107 break;
109 case RPC_CLOSE_FILE: {
110 port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
111 FILE *file = reinterpret_cast<FILE *>(buffer->data[0]);
112 buffer->data[0] = fclose(file);
114 break;
116 case RPC_EXIT: {
117 // Send a response to the client to signal that we are ready to exit.
118 port->recv_and_send([](rpc::Buffer *) {});
119 port->recv([](rpc::Buffer *buffer) {
120 int status = 0;
121 std::memcpy(&status, buffer->data, sizeof(int));
122 exit(status);
124 break;
126 case RPC_ABORT: {
127 // Send a response to the client to signal that we are ready to abort.
128 port->recv_and_send([](rpc::Buffer *) {});
129 port->recv([](rpc::Buffer *) {});
130 abort();
131 break;
133 case RPC_HOST_CALL: {
134 uint64_t sizes[lane_size] = {0};
135 void *args[lane_size] = {nullptr};
136 port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
137 port->recv([&](rpc::Buffer *buffer, uint32_t id) {
138 reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
140 port->send([&](rpc::Buffer *, uint32_t id) {
141 delete[] reinterpret_cast<uint8_t *>(args[id]);
143 break;
145 case RPC_FEOF: {
146 port->recv_and_send([](rpc::Buffer *buffer) {
147 buffer->data[0] = feof(file::to_stream(buffer->data[0]));
149 break;
151 case RPC_FERROR: {
152 port->recv_and_send([](rpc::Buffer *buffer) {
153 buffer->data[0] = ferror(file::to_stream(buffer->data[0]));
155 break;
157 case RPC_CLEARERR: {
158 port->recv_and_send([](rpc::Buffer *buffer) {
159 clearerr(file::to_stream(buffer->data[0]));
161 break;
163 case RPC_NOOP: {
164 port->recv([](rpc::Buffer *) {});
165 break;
167 default: {
168 auto handler =
169 callbacks.find(static_cast<rpc_opcode_t>(port->get_opcode()));
171 // We error out on an unhandled opcode.
172 if (handler == callbacks.end())
173 return RPC_STATUS_UNHANDLED_OPCODE;
175 // Invoke the registered callback with a reference to the port.
176 void *data =
177 callback_data.at(static_cast<rpc_opcode_t>(port->get_opcode()));
178 rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port), lane_size};
179 (handler->second)(port_ref, data);
182 port->close();
183 return RPC_STATUS_CONTINUE;
186 std::variant<std::unique_ptr<rpc::Server<1>>,
187 std::unique_ptr<rpc::Server<32>>,
188 std::unique_ptr<rpc::Server<64>>>
189 server;
192 struct Device {
193 template <typename T>
194 Device(uint32_t num_ports, void *buffer, std::unique_ptr<T> &&server)
195 : buffer(buffer), server(std::move(server)), client(num_ports, buffer) {}
196 void *buffer;
197 Server server;
198 rpc::Client client;
199 std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
200 std::unordered_map<rpc_opcode_t, void *> callback_data;
203 // A struct containing all the runtime state required to run the RPC server.
204 struct State {
205 State(uint32_t num_devices)
206 : num_devices(num_devices), devices(num_devices), reference_count(0u) {}
207 uint32_t num_devices;
208 std::vector<std::unique_ptr<Device>> devices;
209 std::atomic_uint32_t reference_count;
212 static std::mutex startup_mutex;
214 static State *state;
216 rpc_status_t rpc_init(uint32_t num_devices) {
217 std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);
218 if (!state)
219 state = new State(num_devices);
221 if (state->reference_count == std::numeric_limits<uint32_t>::max())
222 return RPC_STATUS_ERROR;
224 state->reference_count++;
226 return RPC_STATUS_SUCCESS;
229 rpc_status_t rpc_shutdown(void) {
230 if (state && state->reference_count-- == 1)
231 delete state;
233 return RPC_STATUS_SUCCESS;
236 template <uint32_t lane_size>
237 rpc_status_t server_init_impl(uint32_t device_id, uint64_t num_ports,
238 rpc_alloc_ty alloc, void *data) {
239 uint64_t size = rpc::Server<lane_size>::allocation_size(num_ports);
240 void *buffer = alloc(size, data);
242 if (!buffer)
243 return RPC_STATUS_ERROR;
245 state->devices[device_id] = std::make_unique<Device>(
246 num_ports, buffer,
247 std::make_unique<rpc::Server<lane_size>>(num_ports, buffer));
248 if (!state->devices[device_id])
249 return RPC_STATUS_ERROR;
251 return RPC_STATUS_SUCCESS;
254 rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
255 uint32_t lane_size, rpc_alloc_ty alloc,
256 void *data) {
257 if (!state)
258 return RPC_STATUS_NOT_INITIALIZED;
259 if (device_id >= state->num_devices)
260 return RPC_STATUS_OUT_OF_RANGE;
262 if (!state->devices[device_id]) {
263 switch (lane_size) {
264 case 1:
265 if (rpc_status_t err =
266 server_init_impl<1>(device_id, num_ports, alloc, data))
267 return err;
268 break;
269 case 32: {
270 if (rpc_status_t err =
271 server_init_impl<32>(device_id, num_ports, alloc, data))
272 return err;
273 break;
275 case 64:
276 if (rpc_status_t err =
277 server_init_impl<64>(device_id, num_ports, alloc, data))
278 return err;
279 break;
280 default:
281 return RPC_STATUS_INVALID_LANE_SIZE;
285 return RPC_STATUS_SUCCESS;
288 rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
289 void *data) {
290 if (!state)
291 return RPC_STATUS_NOT_INITIALIZED;
292 if (device_id >= state->num_devices)
293 return RPC_STATUS_OUT_OF_RANGE;
294 if (!state->devices[device_id])
295 return RPC_STATUS_ERROR;
297 dealloc(state->devices[device_id]->buffer, data);
298 if (state->devices[device_id])
299 state->devices[device_id].release();
301 return RPC_STATUS_SUCCESS;
304 rpc_status_t rpc_handle_server(uint32_t device_id) {
305 if (!state)
306 return RPC_STATUS_NOT_INITIALIZED;
307 if (device_id >= state->num_devices)
308 return RPC_STATUS_OUT_OF_RANGE;
309 if (!state->devices[device_id])
310 return RPC_STATUS_ERROR;
312 for (;;) {
313 auto &device = *state->devices[device_id];
314 rpc_status_t status =
315 device.server.handle_server(device.callbacks, device.callback_data);
316 if (status != RPC_STATUS_CONTINUE)
317 return status;
321 rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
322 rpc_opcode_callback_ty callback,
323 void *data) {
324 if (!state)
325 return RPC_STATUS_NOT_INITIALIZED;
326 if (device_id >= state->num_devices)
327 return RPC_STATUS_OUT_OF_RANGE;
328 if (!state->devices[device_id])
329 return RPC_STATUS_ERROR;
331 state->devices[device_id]->callbacks[opcode] = callback;
332 state->devices[device_id]->callback_data[opcode] = data;
333 return RPC_STATUS_SUCCESS;
336 const void *rpc_get_client_buffer(uint32_t device_id) {
337 if (!state || device_id >= state->num_devices || !state->devices[device_id])
338 return nullptr;
339 return &state->devices[device_id]->client;
342 uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
344 using ServerPort = std::variant<rpc::Server<1>::Port *, rpc::Server<32>::Port *,
345 rpc::Server<64>::Port *>;
347 ServerPort get_port(rpc_port_t ref) {
348 if (ref.lane_size == 1)
349 return reinterpret_cast<rpc::Server<1>::Port *>(ref.handle);
350 else if (ref.lane_size == 32)
351 return reinterpret_cast<rpc::Server<32>::Port *>(ref.handle);
352 else if (ref.lane_size == 64)
353 return reinterpret_cast<rpc::Server<64>::Port *>(ref.handle);
354 else
355 __builtin_unreachable();
358 void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
359 auto port = get_port(ref);
360 std::visit(
361 [=](auto &port) {
362 port->send([=](rpc::Buffer *buffer) {
363 callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
366 port);
369 void rpc_send_n(rpc_port_t ref, const void *const *src, uint64_t *size) {
370 auto port = get_port(ref);
371 std::visit([=](auto &port) { port->send_n(src, size); }, port);
374 void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
375 auto port = get_port(ref);
376 std::visit(
377 [=](auto &port) {
378 port->recv([=](rpc::Buffer *buffer) {
379 callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
382 port);
385 void rpc_recv_n(rpc_port_t ref, void **dst, uint64_t *size, rpc_alloc_ty alloc,
386 void *data) {
387 auto port = get_port(ref);
388 auto alloc_fn = [=](uint64_t size) { return alloc(size, data); };
389 std::visit([=](auto &port) { port->recv_n(dst, size, alloc_fn); }, port);
392 void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
393 void *data) {
394 auto port = get_port(ref);
395 std::visit(
396 [=](auto &port) {
397 port->recv_and_send([=](rpc::Buffer *buffer) {
398 callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
401 port);