[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / compiler-rt / lib / orc / wrapper_function_utils.h
blobdcb6d0e6addbcda8bcbc4e7daf73a9fc17746c3e
1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
11 //===----------------------------------------------------------------------===//
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
16 #include "orc_rt/c_api.h"
17 #include "common.h"
18 #include "error.h"
19 #include "executor_address.h"
20 #include "simple_packed_serialization.h"
21 #include <type_traits>
23 namespace __orc_rt {
25 /// C++ wrapper function result: Same as CWrapperFunctionResult but
26 /// auto-releases memory.
27 class WrapperFunctionResult {
28 public:
29 /// Create a default WrapperFunctionResult.
30 WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); }
32 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
33 /// instance takes ownership of the result object and will automatically
34 /// call dispose on the result upon destruction.
35 WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {}
37 WrapperFunctionResult(const WrapperFunctionResult &) = delete;
38 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
40 WrapperFunctionResult(WrapperFunctionResult &&Other) {
41 orc_rt_CWrapperFunctionResultInit(&R);
42 std::swap(R, Other.R);
45 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
46 orc_rt_CWrapperFunctionResult Tmp;
47 orc_rt_CWrapperFunctionResultInit(&Tmp);
48 std::swap(Tmp, Other.R);
49 std::swap(R, Tmp);
50 return *this;
53 ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); }
55 /// Relinquish ownership of and return the
56 /// orc_rt_CWrapperFunctionResult.
57 orc_rt_CWrapperFunctionResult release() {
58 orc_rt_CWrapperFunctionResult Tmp;
59 orc_rt_CWrapperFunctionResultInit(&Tmp);
60 std::swap(R, Tmp);
61 return Tmp;
64 /// Get a pointer to the data contained in this instance.
65 char *data() { return orc_rt_CWrapperFunctionResultData(&R); }
67 /// Returns the size of the data contained in this instance.
68 size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); }
70 /// Returns true if this value is equivalent to a default-constructed
71 /// WrapperFunctionResult.
72 bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); }
74 /// Create a WrapperFunctionResult with the given size and return a pointer
75 /// to the underlying memory.
76 static WrapperFunctionResult allocate(size_t Size) {
77 WrapperFunctionResult R;
78 R.R = orc_rt_CWrapperFunctionResultAllocate(Size);
79 return R;
82 /// Copy from the given char range.
83 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
84 return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
87 /// Copy from the given null-terminated string (includes the null-terminator).
88 static WrapperFunctionResult copyFrom(const char *Source) {
89 return orc_rt_CreateCWrapperFunctionResultFromString(Source);
92 /// Copy from the given std::string (includes the null terminator).
93 static WrapperFunctionResult copyFrom(const std::string &Source) {
94 return copyFrom(Source.c_str());
97 /// Create an out-of-band error by copying the given string.
98 static WrapperFunctionResult createOutOfBandError(const char *Msg) {
99 return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
102 /// Create an out-of-band error by copying the given string.
103 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
104 return createOutOfBandError(Msg.c_str());
107 template <typename SPSArgListT, typename... ArgTs>
108 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
109 auto Result = allocate(SPSArgListT::size(Args...));
110 SPSOutputBuffer OB(Result.data(), Result.size());
111 if (!SPSArgListT::serialize(OB, Args...))
112 return createOutOfBandError(
113 "Error serializing arguments to blob in call");
114 return Result;
117 /// If this value is an out-of-band error then this returns the error message,
118 /// otherwise returns nullptr.
119 const char *getOutOfBandError() const {
120 return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
123 private:
124 orc_rt_CWrapperFunctionResult R;
127 namespace detail {
/// Invokes a handler with the elements of an argument tuple and returns
/// whatever the handler returns.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  /// Call Handler with std::get<Idx>(ArgTuple) for every index in the pack.
  /// decltype(auto) makes the return type exactly the type of the call
  /// expression (references are preserved).
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                             std::index_sequence<Idx...>) {
    return std::forward<HandlerT>(Handler)(std::get<Idx>(ArgTuple)...);
  }
};
138 template <> class WrapperFunctionHandlerCaller<void> {
139 public:
140 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
141 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
142 std::index_sequence<I...>) {
143 std::forward<HandlerT>(H)(std::get<I>(Args)...);
144 return SPSEmpty();
/// Primary template, used for function objects (e.g. lambdas): re-dispatch on
/// the type of the object's operator() so that one of the member-function
/// specializations below can extract the call signature.
template <typename WrapperFunctionImplT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference<WrapperFunctionImplT>::type::
                       operator()),
          ResultSerializer, SPSTagTs...> {};
155 template <typename RetT, typename... ArgTs,
156 template <typename> class ResultSerializer, typename... SPSTagTs>
157 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
158 SPSTagTs...> {
159 public:
160 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
161 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
163 template <typename HandlerT>
164 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
165 size_t ArgSize) {
166 ArgTuple Args;
167 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
168 return WrapperFunctionResult::createOutOfBandError(
169 "Could not deserialize arguments for wrapper function call");
171 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
172 std::forward<HandlerT>(H), Args, ArgIndices{});
174 return ResultSerializer<decltype(HandlerResult)>::serialize(
175 std::move(HandlerResult));
178 private:
179 template <std::size_t... I>
180 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
181 std::index_sequence<I...>) {
182 SPSInputBuffer IB(ArgData, ArgSize);
183 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
187 // Map function pointers to function types.
188 template <typename RetT, typename... ArgTs,
189 template <typename> class ResultSerializer, typename... SPSTagTs>
190 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
191 SPSTagTs...>
192 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
193 SPSTagTs...> {};
195 // Map non-const member function types to function types.
196 template <typename ClassT, typename RetT, typename... ArgTs,
197 template <typename> class ResultSerializer, typename... SPSTagTs>
198 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
199 SPSTagTs...>
200 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
201 SPSTagTs...> {};
203 // Map const member function types to function types.
204 template <typename ClassT, typename RetT, typename... ArgTs,
205 template <typename> class ResultSerializer, typename... SPSTagTs>
206 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
207 ResultSerializer, SPSTagTs...>
208 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
209 SPSTagTs...> {};
211 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
212 public:
213 static WrapperFunctionResult serialize(RetT Result) {
214 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
218 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
219 public:
220 static WrapperFunctionResult serialize(Error Err) {
221 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
222 toSPSSerializable(std::move(Err)));
226 template <typename SPSRetTagT, typename T>
227 class ResultSerializer<SPSRetTagT, Expected<T>> {
228 public:
229 static WrapperFunctionResult serialize(Expected<T> E) {
230 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
231 toSPSSerializable(std::move(E)));
235 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
236 public:
237 static void makeSafe(RetT &Result) {}
239 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
240 SPSInputBuffer IB(ArgData, ArgSize);
241 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
242 return make_error<StringError>(
243 "Error deserializing return value from blob in call");
244 return Error::success();
248 template <> class ResultDeserializer<SPSError, Error> {
249 public:
250 static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
252 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
253 SPSInputBuffer IB(ArgData, ArgSize);
254 SPSSerializableError BSE;
255 if (!SPSArgList<SPSError>::deserialize(IB, BSE))
256 return make_error<StringError>(
257 "Error deserializing return value from blob in call");
258 Err = fromSPSSerializable(std::move(BSE));
259 return Error::success();
263 template <typename SPSTagT, typename T>
264 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
265 public:
266 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
268 static Error deserialize(Expected<T> &E, const char *ArgData,
269 size_t ArgSize) {
270 SPSInputBuffer IB(ArgData, ArgSize);
271 SPSSerializableExpected<T> BSE;
272 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
273 return make_error<StringError>(
274 "Error deserializing return value from blob in call");
275 E = fromSPSSerializable(std::move(BSE));
276 return Error::success();
280 } // end namespace detail
282 template <typename SPSSignature> class WrapperFunction;
284 template <typename SPSRetTagT, typename... SPSTagTs>
285 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
286 private:
287 template <typename RetT>
288 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
290 public:
291 template <typename RetT, typename... ArgTs>
292 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
294 // RetT might be an Error or Expected value. Set the checked flag now:
295 // we don't want the user to have to check the unused result if this
296 // operation fails.
297 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
299 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
300 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
301 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
302 return make_error<StringError>("__orc_rt_jit_dispatch not set");
304 auto ArgBuffer =
305 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
306 if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
307 return make_error<StringError>(ErrMsg);
309 WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
310 &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
311 if (auto ErrMsg = ResultBuffer.getOutOfBandError())
312 return make_error<StringError>(ErrMsg);
314 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
315 Result, ResultBuffer.data(), ResultBuffer.size());
318 template <typename HandlerT>
319 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
320 HandlerT &&Handler) {
321 using WFHH =
322 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
323 ResultSerializer, SPSTagTs...>;
324 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
327 private:
328 template <typename T> static const T &makeSerializable(const T &Value) {
329 return Value;
332 static detail::SPSSerializableError makeSerializable(Error Err) {
333 return detail::toSPSSerializable(std::move(Err));
336 template <typename T>
337 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
338 return detail::toSPSSerializable(std::move(E));
342 template <typename... SPSTagTs>
343 class WrapperFunction<void(SPSTagTs...)>
344 : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
345 public:
346 template <typename... ArgTs>
347 static Error call(const void *FnTag, const ArgTs &...Args) {
348 SPSEmpty BE;
349 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
352 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
355 /// A function object that takes an ExecutorAddr as its first argument,
356 /// casts that address to a ClassT*, then calls the given method on that
357 /// pointer passing in the remaining function arguments. This utility
358 /// removes some of the boilerplate from writing wrappers for method calls.
360 /// @code{.cpp}
361 /// class MyClass {
362 /// public:
363 /// void myMethod(uint32_t, bool) { ... }
364 /// };
366 /// // SPS Method signature -- note MyClass object address as first argument.
367 /// using SPSMyMethodWrapperSignature =
368 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
370 /// WrapperFunctionResult
371 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
372 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
373 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
374 /// }
375 /// @endcode
377 template <typename RetT, typename ClassT, typename... ArgTs>
378 class MethodWrapperHandler {
379 public:
380 using MethodT = RetT (ClassT::*)(ArgTs...);
381 MethodWrapperHandler(MethodT M) : M(M) {}
382 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
383 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
386 private:
387 MethodT M;
390 /// Create a MethodWrapperHandler object from the given method pointer.
391 template <typename RetT, typename ClassT, typename... ArgTs>
392 MethodWrapperHandler<RetT, ClassT, ArgTs...>
393 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
394 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
397 /// Represents a call to a wrapper function.
398 class WrapperFunctionCall {
399 public:
400 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
401 // smallvector.
402 using ArgDataBufferType = std::vector<char>;
404 /// Create a WrapperFunctionCall using the given SPS serializer to serialize
405 /// the arguments.
406 template <typename SPSSerializer, typename... ArgTs>
407 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
408 const ArgTs &...Args) {
409 ArgDataBufferType ArgData;
410 ArgData.resize(SPSSerializer::size(Args...));
411 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
412 ArgData.size());
413 if (SPSSerializer::serialize(OB, Args...))
414 return WrapperFunctionCall(FnAddr, std::move(ArgData));
415 return make_error<StringError>("Cannot serialize arguments for "
416 "AllocActionCall");
419 WrapperFunctionCall() = default;
421 /// Create a WrapperFunctionCall from a target function and arg buffer.
422 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
423 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
425 /// Returns the address to be called.
426 const ExecutorAddr &getCallee() const { return FnAddr; }
428 /// Returns the argument data.
429 const ArgDataBufferType &getArgData() const { return ArgData; }
431 /// WrapperFunctionCalls convert to true if the callee is non-null.
432 explicit operator bool() const { return !!FnAddr; }
434 /// Run call returning raw WrapperFunctionResult.
435 WrapperFunctionResult run() const {
436 using FnTy =
437 orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
438 return WrapperFunctionResult(
439 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
442 /// Run call and deserialize result using SPS.
443 template <typename SPSRetT, typename RetT>
444 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
445 runWithSPSRet(RetT &RetVal) const {
446 auto WFR = run();
447 if (const char *ErrMsg = WFR.getOutOfBandError())
448 return make_error<StringError>(ErrMsg);
449 SPSInputBuffer IB(WFR.data(), WFR.size());
450 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
451 return make_error<StringError>("Could not deserialize result from "
452 "serialized wrapper function call");
453 return Error::success();
456 /// Overload for SPS functions returning void.
457 template <typename SPSRetT>
458 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
459 runWithSPSRet() const {
460 SPSEmpty E;
461 return runWithSPSRet<SPSEmpty>(E);
464 /// Run call and deserialize an SPSError result. SPSError returns and
465 /// deserialization failures are merged into the returned error.
466 Error runWithSPSRetErrorMerged() const {
467 detail::SPSSerializableError RetErr;
468 if (auto Err = runWithSPSRet<SPSError>(RetErr))
469 return Err;
470 return detail::fromSPSSerializable(std::move(RetErr));
473 private:
474 ExecutorAddr FnAddr;
475 std::vector<char> ArgData;
478 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
480 template <>
481 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
482 public:
483 static size_t size(const WrapperFunctionCall &WFC) {
484 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
485 WFC.getCallee(), WFC.getArgData());
488 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
489 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
490 OB, WFC.getCallee(), WFC.getArgData());
493 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
494 ExecutorAddr FnAddr;
495 WrapperFunctionCall::ArgDataBufferType ArgData;
496 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
497 return false;
498 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
499 return true;
503 } // end namespace __orc_rt
505 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H