1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file is a part of the ORC runtime support library.
11 //===----------------------------------------------------------------------===//
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
16 #include "orc_rt/c_api.h"
19 #include "executor_address.h"
20 #include "simple_packed_serialization.h"
21 #include <type_traits>
25 /// C++ wrapper function result: Same as CWrapperFunctionResult but
26 /// auto-releases memory.
27 class WrapperFunctionResult
{
29 /// Create a default WrapperFunctionResult.
30 WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R
); }
32 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
33 /// instance takes ownership of the result object and will automatically
34 /// call dispose on the result upon destruction.
35 WrapperFunctionResult(orc_rt_CWrapperFunctionResult R
) : R(R
) {}
37 WrapperFunctionResult(const WrapperFunctionResult
&) = delete;
38 WrapperFunctionResult
&operator=(const WrapperFunctionResult
&) = delete;
40 WrapperFunctionResult(WrapperFunctionResult
&&Other
) {
41 orc_rt_CWrapperFunctionResultInit(&R
);
42 std::swap(R
, Other
.R
);
45 WrapperFunctionResult
&operator=(WrapperFunctionResult
&&Other
) {
46 orc_rt_CWrapperFunctionResult Tmp
;
47 orc_rt_CWrapperFunctionResultInit(&Tmp
);
48 std::swap(Tmp
, Other
.R
);
53 ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R
); }
55 /// Relinquish ownership of and return the
56 /// orc_rt_CWrapperFunctionResult.
57 orc_rt_CWrapperFunctionResult
release() {
58 orc_rt_CWrapperFunctionResult Tmp
;
59 orc_rt_CWrapperFunctionResultInit(&Tmp
);
64 /// Get a pointer to the data contained in this instance.
65 char *data() { return orc_rt_CWrapperFunctionResultData(&R
); }
67 /// Returns the size of the data contained in this instance.
68 size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R
); }
70 /// Returns true if this value is equivalent to a default-constructed
71 /// WrapperFunctionResult.
72 bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R
); }
74 /// Create a WrapperFunctionResult with the given size and return a pointer
75 /// to the underlying memory.
76 static WrapperFunctionResult
allocate(size_t Size
) {
77 WrapperFunctionResult R
;
78 R
.R
= orc_rt_CWrapperFunctionResultAllocate(Size
);
82 /// Copy from the given char range.
83 static WrapperFunctionResult
copyFrom(const char *Source
, size_t Size
) {
84 return orc_rt_CreateCWrapperFunctionResultFromRange(Source
, Size
);
87 /// Copy from the given null-terminated string (includes the null-terminator).
88 static WrapperFunctionResult
copyFrom(const char *Source
) {
89 return orc_rt_CreateCWrapperFunctionResultFromString(Source
);
92 /// Copy from the given std::string (includes the null terminator).
93 static WrapperFunctionResult
copyFrom(const std::string
&Source
) {
94 return copyFrom(Source
.c_str());
97 /// Create an out-of-band error by copying the given string.
98 static WrapperFunctionResult
createOutOfBandError(const char *Msg
) {
99 return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg
);
102 /// Create an out-of-band error by copying the given string.
103 static WrapperFunctionResult
createOutOfBandError(const std::string
&Msg
) {
104 return createOutOfBandError(Msg
.c_str());
107 template <typename SPSArgListT
, typename
... ArgTs
>
108 static WrapperFunctionResult
fromSPSArgs(const ArgTs
&...Args
) {
109 auto Result
= allocate(SPSArgListT::size(Args
...));
110 SPSOutputBuffer
OB(Result
.data(), Result
.size());
111 if (!SPSArgListT::serialize(OB
, Args
...))
112 return createOutOfBandError(
113 "Error serializing arguments to blob in call");
117 /// If this value is an out-of-band error then this returns the error message,
118 /// otherwise returns nullptr.
119 const char *getOutOfBandError() const {
120 return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R
);
124 orc_rt_CWrapperFunctionResult R
;
/// Invokes a handler with the elements of an argument tuple selected by an
/// index sequence, and returns whatever the handler returns. Tuple elements
/// are passed as lvalues, so the handler may bind them by reference.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename HandlerT, typename ArgTupleT, std::size_t... I>
  static decltype(auto) call(HandlerT &&H, ArgTupleT &Args,
                             std::index_sequence<I...>) {
    return std::forward<HandlerT>(H)(std::get<I>(Args)...);
  }
};
138 template <> class WrapperFunctionHandlerCaller
<void> {
140 template <typename HandlerT
, typename ArgTupleT
, std::size_t... I
>
141 static SPSEmpty
call(HandlerT
&&H
, ArgTupleT
&Args
,
142 std::index_sequence
<I
...>) {
143 std::forward
<HandlerT
>(H
)(std::get
<I
>(Args
)...);
/// Primary template, used for function objects: the call signature is
/// recovered from the type of the object's operator(), and all work is
/// delegated to the member-function-pointer specialization for that type.
template <typename ImplT, template <typename> class ResultSerializerT,
          typename... SPSTags>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<ImplT>::operator()),
          ResultSerializerT, SPSTags...> {};
155 template <typename RetT
, typename
... ArgTs
,
156 template <typename
> class ResultSerializer
, typename
... SPSTagTs
>
157 class WrapperFunctionHandlerHelper
<RetT(ArgTs
...), ResultSerializer
,
160 using ArgTuple
= std::tuple
<std::decay_t
<ArgTs
>...>;
161 using ArgIndices
= std::make_index_sequence
<std::tuple_size
<ArgTuple
>::value
>;
163 template <typename HandlerT
>
164 static WrapperFunctionResult
apply(HandlerT
&&H
, const char *ArgData
,
167 if (!deserialize(ArgData
, ArgSize
, Args
, ArgIndices
{}))
168 return WrapperFunctionResult::createOutOfBandError(
169 "Could not deserialize arguments for wrapper function call");
171 auto HandlerResult
= WrapperFunctionHandlerCaller
<RetT
>::call(
172 std::forward
<HandlerT
>(H
), Args
, ArgIndices
{});
174 return ResultSerializer
<decltype(HandlerResult
)>::serialize(
175 std::move(HandlerResult
));
179 template <std::size_t... I
>
180 static bool deserialize(const char *ArgData
, size_t ArgSize
, ArgTuple
&Args
,
181 std::index_sequence
<I
...>) {
182 SPSInputBuffer
IB(ArgData
, ArgSize
);
183 return SPSArgList
<SPSTagTs
...>::deserialize(IB
, std::get
<I
>(Args
)...);
187 // Map function pointers to function types.
188 template <typename RetT
, typename
... ArgTs
,
189 template <typename
> class ResultSerializer
, typename
... SPSTagTs
>
190 class WrapperFunctionHandlerHelper
<RetT (*)(ArgTs
...), ResultSerializer
,
192 : public WrapperFunctionHandlerHelper
<RetT(ArgTs
...), ResultSerializer
,
195 // Map non-const member function types to function types.
196 template <typename ClassT
, typename RetT
, typename
... ArgTs
,
197 template <typename
> class ResultSerializer
, typename
... SPSTagTs
>
198 class WrapperFunctionHandlerHelper
<RetT (ClassT::*)(ArgTs
...), ResultSerializer
,
200 : public WrapperFunctionHandlerHelper
<RetT(ArgTs
...), ResultSerializer
,
203 // Map const member function types to function types.
204 template <typename ClassT
, typename RetT
, typename
... ArgTs
,
205 template <typename
> class ResultSerializer
, typename
... SPSTagTs
>
206 class WrapperFunctionHandlerHelper
<RetT (ClassT::*)(ArgTs
...) const,
207 ResultSerializer
, SPSTagTs
...>
208 : public WrapperFunctionHandlerHelper
<RetT(ArgTs
...), ResultSerializer
,
211 template <typename SPSRetTagT
, typename RetT
> class ResultSerializer
{
213 static WrapperFunctionResult
serialize(RetT Result
) {
214 return WrapperFunctionResult::fromSPSArgs
<SPSArgList
<SPSRetTagT
>>(Result
);
218 template <typename SPSRetTagT
> class ResultSerializer
<SPSRetTagT
, Error
> {
220 static WrapperFunctionResult
serialize(Error Err
) {
221 return WrapperFunctionResult::fromSPSArgs
<SPSArgList
<SPSRetTagT
>>(
222 toSPSSerializable(std::move(Err
)));
226 template <typename SPSRetTagT
, typename T
>
227 class ResultSerializer
<SPSRetTagT
, Expected
<T
>> {
229 static WrapperFunctionResult
serialize(Expected
<T
> E
) {
230 return WrapperFunctionResult::fromSPSArgs
<SPSArgList
<SPSRetTagT
>>(
231 toSPSSerializable(std::move(E
)));
235 template <typename SPSRetTagT
, typename RetT
> class ResultDeserializer
{
237 static void makeSafe(RetT
&Result
) {}
239 static Error
deserialize(RetT
&Result
, const char *ArgData
, size_t ArgSize
) {
240 SPSInputBuffer
IB(ArgData
, ArgSize
);
241 if (!SPSArgList
<SPSRetTagT
>::deserialize(IB
, Result
))
242 return make_error
<StringError
>(
243 "Error deserializing return value from blob in call");
244 return Error::success();
248 template <> class ResultDeserializer
<SPSError
, Error
> {
250 static void makeSafe(Error
&Err
) { cantFail(std::move(Err
)); }
252 static Error
deserialize(Error
&Err
, const char *ArgData
, size_t ArgSize
) {
253 SPSInputBuffer
IB(ArgData
, ArgSize
);
254 SPSSerializableError BSE
;
255 if (!SPSArgList
<SPSError
>::deserialize(IB
, BSE
))
256 return make_error
<StringError
>(
257 "Error deserializing return value from blob in call");
258 Err
= fromSPSSerializable(std::move(BSE
));
259 return Error::success();
263 template <typename SPSTagT
, typename T
>
264 class ResultDeserializer
<SPSExpected
<SPSTagT
>, Expected
<T
>> {
266 static void makeSafe(Expected
<T
> &E
) { cantFail(E
.takeError()); }
268 static Error
deserialize(Expected
<T
> &E
, const char *ArgData
,
270 SPSInputBuffer
IB(ArgData
, ArgSize
);
271 SPSSerializableExpected
<T
> BSE
;
272 if (!SPSArgList
<SPSExpected
<SPSTagT
>>::deserialize(IB
, BSE
))
273 return make_error
<StringError
>(
274 "Error deserializing return value from blob in call");
275 E
= fromSPSSerializable(std::move(BSE
));
276 return Error::success();
280 } // end namespace detail
282 template <typename SPSSignature
> class WrapperFunction
;
284 template <typename SPSRetTagT
, typename
... SPSTagTs
>
285 class WrapperFunction
<SPSRetTagT(SPSTagTs
...)> {
287 template <typename RetT
>
288 using ResultSerializer
= detail::ResultSerializer
<SPSRetTagT
, RetT
>;
291 template <typename RetT
, typename
... ArgTs
>
292 static Error
call(const void *FnTag
, RetT
&Result
, const ArgTs
&...Args
) {
294 // RetT might be an Error or Expected value. Set the checked flag now:
295 // we don't want the user to have to check the unused result if this
297 detail::ResultDeserializer
<SPSRetTagT
, RetT
>::makeSafe(Result
);
299 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx
))
300 return make_error
<StringError
>("__orc_rt_jit_dispatch_ctx not set");
301 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch
))
302 return make_error
<StringError
>("__orc_rt_jit_dispatch not set");
305 WrapperFunctionResult::fromSPSArgs
<SPSArgList
<SPSTagTs
...>>(Args
...);
306 if (const char *ErrMsg
= ArgBuffer
.getOutOfBandError())
307 return make_error
<StringError
>(ErrMsg
);
309 WrapperFunctionResult ResultBuffer
= __orc_rt_jit_dispatch(
310 &__orc_rt_jit_dispatch_ctx
, FnTag
, ArgBuffer
.data(), ArgBuffer
.size());
311 if (auto ErrMsg
= ResultBuffer
.getOutOfBandError())
312 return make_error
<StringError
>(ErrMsg
);
314 return detail::ResultDeserializer
<SPSRetTagT
, RetT
>::deserialize(
315 Result
, ResultBuffer
.data(), ResultBuffer
.size());
318 template <typename HandlerT
>
319 static WrapperFunctionResult
handle(const char *ArgData
, size_t ArgSize
,
320 HandlerT
&&Handler
) {
322 detail::WrapperFunctionHandlerHelper
<std::remove_reference_t
<HandlerT
>,
323 ResultSerializer
, SPSTagTs
...>;
324 return WFHH::apply(std::forward
<HandlerT
>(Handler
), ArgData
, ArgSize
);
328 template <typename T
> static const T
&makeSerializable(const T
&Value
) {
332 static detail::SPSSerializableError
makeSerializable(Error Err
) {
333 return detail::toSPSSerializable(std::move(Err
));
336 template <typename T
>
337 static detail::SPSSerializableExpected
<T
> makeSerializable(Expected
<T
> E
) {
338 return detail::toSPSSerializable(std::move(E
));
342 template <typename
... SPSTagTs
>
343 class WrapperFunction
<void(SPSTagTs
...)>
344 : private WrapperFunction
<SPSEmpty(SPSTagTs
...)> {
346 template <typename
... ArgTs
>
347 static Error
call(const void *FnTag
, const ArgTs
&...Args
) {
349 return WrapperFunction
<SPSEmpty(SPSTagTs
...)>::call(FnTag
, BE
, Args
...);
352 using WrapperFunction
<SPSEmpty(SPSTagTs
...)>::handle
;
355 /// A function object that takes an ExecutorAddr as its first argument,
356 /// casts that address to a ClassT*, then calls the given method on that
357 /// pointer passing in the remaining function arguments. This utility
358 /// removes some of the boilerplate from writing wrappers for method calls.
363 /// void myMethod(uint32_t, bool) { ... }
366 /// // SPS Method signature -- note MyClass object address as first argument.
367 /// using SPSMyMethodWrapperSignature =
368 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
370 /// WrapperFunctionResult
371 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
372 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
373 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
377 template <typename RetT
, typename ClassT
, typename
... ArgTs
>
378 class MethodWrapperHandler
{
380 using MethodT
= RetT (ClassT::*)(ArgTs
...);
381 MethodWrapperHandler(MethodT M
) : M(M
) {}
382 RetT
operator()(ExecutorAddr ObjAddr
, ArgTs
&...Args
) {
383 return (ObjAddr
.toPtr
<ClassT
*>()->*M
)(std::forward
<ArgTs
>(Args
)...);
390 /// Create a MethodWrapperHandler object from the given method pointer.
391 template <typename RetT
, typename ClassT
, typename
... ArgTs
>
392 MethodWrapperHandler
<RetT
, ClassT
, ArgTs
...>
393 makeMethodWrapperHandler(RetT (ClassT::*Method
)(ArgTs
...)) {
394 return MethodWrapperHandler
<RetT
, ClassT
, ArgTs
...>(Method
);
397 /// Represents a call to a wrapper function.
398 class WrapperFunctionCall
{
400 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
402 using ArgDataBufferType
= std::vector
<char>;
404 /// Create a WrapperFunctionCall using the given SPS serializer to serialize
406 template <typename SPSSerializer
, typename
... ArgTs
>
407 static Expected
<WrapperFunctionCall
> Create(ExecutorAddr FnAddr
,
408 const ArgTs
&...Args
) {
409 ArgDataBufferType ArgData
;
410 ArgData
.resize(SPSSerializer::size(Args
...));
411 SPSOutputBuffer
OB(ArgData
.empty() ? nullptr : ArgData
.data(),
413 if (SPSSerializer::serialize(OB
, Args
...))
414 return WrapperFunctionCall(FnAddr
, std::move(ArgData
));
415 return make_error
<StringError
>("Cannot serialize arguments for "
419 WrapperFunctionCall() = default;
421 /// Create a WrapperFunctionCall from a target function and arg buffer.
422 WrapperFunctionCall(ExecutorAddr FnAddr
, ArgDataBufferType ArgData
)
423 : FnAddr(FnAddr
), ArgData(std::move(ArgData
)) {}
425 /// Returns the address to be called.
426 const ExecutorAddr
&getCallee() const { return FnAddr
; }
428 /// Returns the argument data.
429 const ArgDataBufferType
&getArgData() const { return ArgData
; }
431 /// WrapperFunctionCalls convert to true if the callee is non-null.
432 explicit operator bool() const { return !!FnAddr
; }
434 /// Run call returning raw WrapperFunctionResult.
435 WrapperFunctionResult
run() const {
437 orc_rt_CWrapperFunctionResult(const char *ArgData
, size_t ArgSize
);
438 return WrapperFunctionResult(
439 FnAddr
.toPtr
<FnTy
*>()(ArgData
.data(), ArgData
.size()));
442 /// Run call and deserialize result using SPS.
443 template <typename SPSRetT
, typename RetT
>
444 std::enable_if_t
<!std::is_same
<SPSRetT
, void>::value
, Error
>
445 runWithSPSRet(RetT
&RetVal
) const {
447 if (const char *ErrMsg
= WFR
.getOutOfBandError())
448 return make_error
<StringError
>(ErrMsg
);
449 SPSInputBuffer
IB(WFR
.data(), WFR
.size());
450 if (!SPSSerializationTraits
<SPSRetT
, RetT
>::deserialize(IB
, RetVal
))
451 return make_error
<StringError
>("Could not deserialize result from "
452 "serialized wrapper function call");
453 return Error::success();
456 /// Overload for SPS functions returning void.
457 template <typename SPSRetT
>
458 std::enable_if_t
<std::is_same
<SPSRetT
, void>::value
, Error
>
459 runWithSPSRet() const {
461 return runWithSPSRet
<SPSEmpty
>(E
);
464 /// Run call and deserialize an SPSError result. SPSError returns and
465 /// deserialization failures are merged into the returned error.
466 Error
runWithSPSRetErrorMerged() const {
467 detail::SPSSerializableError RetErr
;
468 if (auto Err
= runWithSPSRet
<SPSError
>(RetErr
))
470 return detail::fromSPSSerializable(std::move(RetErr
));
475 std::vector
<char> ArgData
;
478 using SPSWrapperFunctionCall
= SPSTuple
<SPSExecutorAddr
, SPSSequence
<char>>;
481 class SPSSerializationTraits
<SPSWrapperFunctionCall
, WrapperFunctionCall
> {
483 static size_t size(const WrapperFunctionCall
&WFC
) {
484 return SPSArgList
<SPSExecutorAddr
, SPSSequence
<char>>::size(
485 WFC
.getCallee(), WFC
.getArgData());
488 static bool serialize(SPSOutputBuffer
&OB
, const WrapperFunctionCall
&WFC
) {
489 return SPSArgList
<SPSExecutorAddr
, SPSSequence
<char>>::serialize(
490 OB
, WFC
.getCallee(), WFC
.getArgData());
493 static bool deserialize(SPSInputBuffer
&IB
, WrapperFunctionCall
&WFC
) {
495 WrapperFunctionCall::ArgDataBufferType ArgData
;
496 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB
, FnAddr
, ArgData
))
498 WFC
= WrapperFunctionCall(FnAddr
, std::move(ArgData
));
503 } // end namespace __orc_rt
505 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H