1 //===- llvm/ExecutionEngine/Orc/RPCSerialization.h --------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
10 #define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
13 #include "llvm/Support/thread.h"
/// RPCTypeNameSequence is an empty tag type carrying a pack of types. It is a
/// utility for rendering sequences of types to a string by rendering each
/// type, separated by ", " (see the operator<< overloads that take it).
template <typename... ArgTs> class RPCTypeNameSequence {};
32 /// Render an empty TypeNameSequence to an ostream.
33 template <typename OStream
>
34 OStream
&operator<<(OStream
&OS
, const RPCTypeNameSequence
<> &V
) {
38 /// Render a TypeNameSequence of a single type to an ostream.
39 template <typename OStream
, typename ArgT
>
40 OStream
&operator<<(OStream
&OS
, const RPCTypeNameSequence
<ArgT
> &V
) {
41 OS
<< RPCTypeName
<ArgT
>::getName();
45 /// Render a TypeNameSequence of more than one type to an ostream.
46 template <typename OStream
, typename ArgT1
, typename ArgT2
, typename
... ArgTs
>
48 operator<<(OStream
&OS
, const RPCTypeNameSequence
<ArgT1
, ArgT2
, ArgTs
...> &V
) {
49 OS
<< RPCTypeName
<ArgT1
>::getName() << ", "
50 << RPCTypeNameSequence
<ArgT2
, ArgTs
...>();
55 class RPCTypeName
<void> {
57 static const char* getName() { return "void"; }
61 class RPCTypeName
<int8_t> {
63 static const char* getName() { return "int8_t"; }
67 class RPCTypeName
<uint8_t> {
69 static const char* getName() { return "uint8_t"; }
73 class RPCTypeName
<int16_t> {
75 static const char* getName() { return "int16_t"; }
79 class RPCTypeName
<uint16_t> {
81 static const char* getName() { return "uint16_t"; }
85 class RPCTypeName
<int32_t> {
87 static const char* getName() { return "int32_t"; }
91 class RPCTypeName
<uint32_t> {
93 static const char* getName() { return "uint32_t"; }
97 class RPCTypeName
<int64_t> {
99 static const char* getName() { return "int64_t"; }
103 class RPCTypeName
<uint64_t> {
105 static const char* getName() { return "uint64_t"; }
109 class RPCTypeName
<bool> {
111 static const char* getName() { return "bool"; }
115 class RPCTypeName
<std::string
> {
117 static const char* getName() { return "std::string"; }
121 class RPCTypeName
<Error
> {
123 static const char* getName() { return "Error"; }
126 template <typename T
>
127 class RPCTypeName
<Expected
<T
>> {
129 static const char* getName() {
130 static std::string Name
= [] {
132 raw_string_ostream(Name
) << "Expected<"
133 << RPCTypeNameSequence
<T
>()
141 template <typename T1
, typename T2
>
142 class RPCTypeName
<std::pair
<T1
, T2
>> {
144 static const char* getName() {
145 static std::string Name
= [] {
147 raw_string_ostream(Name
) << "std::pair<" << RPCTypeNameSequence
<T1
, T2
>()
155 template <typename
... ArgTs
>
156 class RPCTypeName
<std::tuple
<ArgTs
...>> {
158 static const char* getName() {
159 static std::string Name
= [] {
161 raw_string_ostream(Name
) << "std::tuple<"
162 << RPCTypeNameSequence
<ArgTs
...>() << ">";
169 template <typename T
>
170 class RPCTypeName
<std::vector
<T
>> {
172 static const char*getName() {
173 static std::string Name
= [] {
175 raw_string_ostream(Name
) << "std::vector<" << RPCTypeName
<T
>::getName()
183 template <typename T
> class RPCTypeName
<std::set
<T
>> {
185 static const char *getName() {
186 static std::string Name
= [] {
188 raw_string_ostream(Name
)
189 << "std::set<" << RPCTypeName
<T
>::getName() << ">";
196 template <typename K
, typename V
> class RPCTypeName
<std::map
<K
, V
>> {
198 static const char *getName() {
199 static std::string Name
= [] {
201 raw_string_ostream(Name
)
202 << "std::map<" << RPCTypeNameSequence
<K
, V
>() << ">";
/// The SerializationTraits<ChannelT, WireType, ConcreteType> class describes
/// how to serialize and deserialize an instance of ConcreteType, as the
/// over-the-wire type WireType, to/from an abstract channel of type ChannelT.
/// ConcreteType defaults to WireType.
///
/// Specializations of this class should provide the following functions:
///
///   @code{.cpp}
///
///   static const char* getName();
///   static Error serialize(ChannelT&, const T&);
///   static Error deserialize(ChannelT&, T&);
///
///   @endcode
///
/// The fourth (unnamed, defaulted to void) template argument of
/// SerializationTraits is intended to support SFINAE. E.g.:
///
///   @code{.cpp}
///
///   class MyVirtualChannel { ... };
///
///   template <typename DerivedChannelT>
///   class SerializationTraits<DerivedChannelT, bool, bool,
///         typename std::enable_if<
///           std::is_base_of<MyVirtualChannel, DerivedChannelT>::value
///         >::type> {
///   public:
///     static const char* getName() { ... };
///   };
///
///   @endcode
template <typename ChannelT, typename WireType,
          typename ConcreteType = WireType, typename = void>
class SerializationTraits;
245 template <typename ChannelT
>
246 class SequenceTraits
{
248 static Error
emitSeparator(ChannelT
&C
) { return Error::success(); }
249 static Error
consumeSeparator(ChannelT
&C
) { return Error::success(); }
/// Utility class for serializing sequences of values of varying types.
/// Specializations of this class contain 'serialize' and 'deserialize' methods
/// for the given channel. The ArgTs... list will determine the "over-the-wire"
/// types to be serialized. The serialize and deserialize methods take a list
/// CArgTs... ("caller arg types") which must be the same length as ArgTs...,
/// but may be different types from ArgTs, provided that for each CArgT there
/// is a SerializationTraits specialization
/// SerializationTraits<ChannelT, ArgT, CArgT> with methods that can serialize
/// the caller argument to the over-the-wire value.
template <typename ChannelT, typename... ArgTs>
class SequenceSerialization;
264 template <typename ChannelT
>
265 class SequenceSerialization
<ChannelT
> {
267 static Error
serialize(ChannelT
&C
) { return Error::success(); }
268 static Error
deserialize(ChannelT
&C
) { return Error::success(); }
271 template <typename ChannelT
, typename ArgT
>
272 class SequenceSerialization
<ChannelT
, ArgT
> {
275 template <typename CArgT
>
276 static Error
serialize(ChannelT
&C
, CArgT
&&CArg
) {
277 return SerializationTraits
<ChannelT
, ArgT
,
278 typename
std::decay
<CArgT
>::type
>::
279 serialize(C
, std::forward
<CArgT
>(CArg
));
282 template <typename CArgT
>
283 static Error
deserialize(ChannelT
&C
, CArgT
&CArg
) {
284 return SerializationTraits
<ChannelT
, ArgT
, CArgT
>::deserialize(C
, CArg
);
288 template <typename ChannelT
, typename ArgT
, typename
... ArgTs
>
289 class SequenceSerialization
<ChannelT
, ArgT
, ArgTs
...> {
292 template <typename CArgT
, typename
... CArgTs
>
293 static Error
serialize(ChannelT
&C
, CArgT
&&CArg
,
294 CArgTs
&&... CArgs
) {
296 SerializationTraits
<ChannelT
, ArgT
, typename
std::decay
<CArgT
>::type
>::
297 serialize(C
, std::forward
<CArgT
>(CArg
)))
299 if (auto Err
= SequenceTraits
<ChannelT
>::emitSeparator(C
))
301 return SequenceSerialization
<ChannelT
, ArgTs
...>::
302 serialize(C
, std::forward
<CArgTs
>(CArgs
)...);
305 template <typename CArgT
, typename
... CArgTs
>
306 static Error
deserialize(ChannelT
&C
, CArgT
&CArg
,
309 SerializationTraits
<ChannelT
, ArgT
, CArgT
>::deserialize(C
, CArg
))
311 if (auto Err
= SequenceTraits
<ChannelT
>::consumeSeparator(C
))
313 return SequenceSerialization
<ChannelT
, ArgTs
...>::deserialize(C
, CArgs
...);
317 template <typename ChannelT
, typename
... ArgTs
>
318 Error
serializeSeq(ChannelT
&C
, ArgTs
&&... Args
) {
319 return SequenceSerialization
<ChannelT
, typename
std::decay
<ArgTs
>::type
...>::
320 serialize(C
, std::forward
<ArgTs
>(Args
)...);
323 template <typename ChannelT
, typename
... ArgTs
>
324 Error
deserializeSeq(ChannelT
&C
, ArgTs
&... Args
) {
325 return SequenceSerialization
<ChannelT
, ArgTs
...>::deserialize(C
, Args
...);
328 template <typename ChannelT
>
329 class SerializationTraits
<ChannelT
, Error
> {
332 using WrappedErrorSerializer
=
333 std::function
<Error(ChannelT
&C
, const ErrorInfoBase
&)>;
335 using WrappedErrorDeserializer
=
336 std::function
<Error(ChannelT
&C
, Error
&Err
)>;
338 template <typename ErrorInfoT
, typename SerializeFtor
,
339 typename DeserializeFtor
>
340 static void registerErrorType(std::string Name
, SerializeFtor Serialize
,
341 DeserializeFtor Deserialize
) {
342 assert(!Name
.empty() &&
343 "The empty string is reserved for the Success value");
345 const std::string
*KeyName
= nullptr;
347 // We're abusing the stability of std::map here: We take a reference to the
348 // key of the deserializers map to save us from duplicating the string in
349 // the serializer. This should be changed to use a stringpool if we switch
350 // to a map type that may move keys in memory.
351 std::lock_guard
<std::recursive_mutex
> Lock(DeserializersMutex
);
353 Deserializers
.insert(Deserializers
.begin(),
354 std::make_pair(std::move(Name
),
355 std::move(Deserialize
)));
360 assert(KeyName
!= nullptr && "No keyname pointer");
361 std::lock_guard
<std::recursive_mutex
> Lock(SerializersMutex
);
362 Serializers
[ErrorInfoT::classID()] =
363 [KeyName
, Serialize
= std::move(Serialize
)](
364 ChannelT
&C
, const ErrorInfoBase
&EIB
) -> Error
{
365 assert(EIB
.dynamicClassID() == ErrorInfoT::classID() &&
366 "Serializer called for wrong error type");
367 if (auto Err
= serializeSeq(C
, *KeyName
))
369 return Serialize(C
, static_cast<const ErrorInfoT
&>(EIB
));
374 static Error
serialize(ChannelT
&C
, Error
&&Err
) {
375 std::lock_guard
<std::recursive_mutex
> Lock(SerializersMutex
);
378 return serializeSeq(C
, std::string());
380 return handleErrors(std::move(Err
),
381 [&C
](const ErrorInfoBase
&EIB
) {
382 auto SI
= Serializers
.find(EIB
.dynamicClassID());
383 if (SI
== Serializers
.end())
384 return serializeAsStringError(C
, EIB
);
385 return (SI
->second
)(C
, EIB
);
389 static Error
deserialize(ChannelT
&C
, Error
&Err
) {
390 std::lock_guard
<std::recursive_mutex
> Lock(DeserializersMutex
);
393 if (auto Err
= deserializeSeq(C
, Key
))
397 ErrorAsOutParameter
EAO(&Err
);
398 Err
= Error::success();
399 return Error::success();
402 auto DI
= Deserializers
.find(Key
);
403 assert(DI
!= Deserializers
.end() && "No deserializer for error type");
404 return (DI
->second
)(C
, Err
);
409 static Error
serializeAsStringError(ChannelT
&C
, const ErrorInfoBase
&EIB
) {
412 raw_string_ostream
ErrMsgStream(ErrMsg
);
413 EIB
.log(ErrMsgStream
);
415 return serialize(C
, make_error
<StringError
>(std::move(ErrMsg
),
416 inconvertibleErrorCode()));
419 static std::recursive_mutex SerializersMutex
;
420 static std::recursive_mutex DeserializersMutex
;
421 static std::map
<const void*, WrappedErrorSerializer
> Serializers
;
422 static std::map
<std::string
, WrappedErrorDeserializer
> Deserializers
;
425 template <typename ChannelT
>
426 std::recursive_mutex SerializationTraits
<ChannelT
, Error
>::SerializersMutex
;
428 template <typename ChannelT
>
429 std::recursive_mutex SerializationTraits
<ChannelT
, Error
>::DeserializersMutex
;
431 template <typename ChannelT
>
432 std::map
<const void*,
433 typename SerializationTraits
<ChannelT
, Error
>::WrappedErrorSerializer
>
434 SerializationTraits
<ChannelT
, Error
>::Serializers
;
436 template <typename ChannelT
>
437 std::map
<std::string
,
438 typename SerializationTraits
<ChannelT
, Error
>::WrappedErrorDeserializer
>
439 SerializationTraits
<ChannelT
, Error
>::Deserializers
;
441 /// Registers a serializer and deserializer for the given error type on the
442 /// given channel type.
443 template <typename ChannelT
, typename ErrorInfoT
, typename SerializeFtor
,
444 typename DeserializeFtor
>
445 void registerErrorSerialization(std::string Name
, SerializeFtor
&&Serialize
,
446 DeserializeFtor
&&Deserialize
) {
447 SerializationTraits
<ChannelT
, Error
>::template registerErrorType
<ErrorInfoT
>(
449 std::forward
<SerializeFtor
>(Serialize
),
450 std::forward
<DeserializeFtor
>(Deserialize
));
453 /// Registers serialization/deserialization for StringError.
454 template <typename ChannelT
>
455 void registerStringError() {
456 static bool AlreadyRegistered
= false;
457 if (!AlreadyRegistered
) {
458 registerErrorSerialization
<ChannelT
, StringError
>(
460 [](ChannelT
&C
, const StringError
&SE
) {
461 return serializeSeq(C
, SE
.getMessage());
463 [](ChannelT
&C
, Error
&Err
) -> Error
{
464 ErrorAsOutParameter
EAO(&Err
);
466 if (auto E2
= deserializeSeq(C
, Msg
))
469 make_error
<StringError
>(std::move(Msg
),
471 OrcErrorCode::UnknownErrorCodeFromRemote
));
472 return Error::success();
474 AlreadyRegistered
= true;
478 /// SerializationTraits for Expected<T1> from an Expected<T2>.
479 template <typename ChannelT
, typename T1
, typename T2
>
480 class SerializationTraits
<ChannelT
, Expected
<T1
>, Expected
<T2
>> {
483 static Error
serialize(ChannelT
&C
, Expected
<T2
> &&ValOrErr
) {
485 if (auto Err
= serializeSeq(C
, true))
487 return SerializationTraits
<ChannelT
, T1
, T2
>::serialize(C
, *ValOrErr
);
489 if (auto Err
= serializeSeq(C
, false))
491 return serializeSeq(C
, ValOrErr
.takeError());
494 static Error
deserialize(ChannelT
&C
, Expected
<T2
> &ValOrErr
) {
495 ExpectedAsOutParameter
<T2
> EAO(&ValOrErr
);
497 if (auto Err
= deserializeSeq(C
, HasValue
))
500 return SerializationTraits
<ChannelT
, T1
, T2
>::deserialize(C
, *ValOrErr
);
501 Error Err
= Error::success();
502 if (auto E2
= deserializeSeq(C
, Err
))
504 ValOrErr
= std::move(Err
);
505 return Error::success();
509 /// SerializationTraits for Expected<T1> from a T2.
510 template <typename ChannelT
, typename T1
, typename T2
>
511 class SerializationTraits
<ChannelT
, Expected
<T1
>, T2
> {
514 static Error
serialize(ChannelT
&C
, T2
&&Val
) {
515 return serializeSeq(C
, Expected
<T2
>(std::forward
<T2
>(Val
)));
519 /// SerializationTraits for Expected<T1> from an Error.
520 template <typename ChannelT
, typename T
>
521 class SerializationTraits
<ChannelT
, Expected
<T
>, Error
> {
524 static Error
serialize(ChannelT
&C
, Error
&&Err
) {
525 return serializeSeq(C
, Expected
<T
>(std::move(Err
)));
529 /// SerializationTraits default specialization for std::pair.
530 template <typename ChannelT
, typename T1
, typename T2
, typename T3
, typename T4
>
531 class SerializationTraits
<ChannelT
, std::pair
<T1
, T2
>, std::pair
<T3
, T4
>> {
533 static Error
serialize(ChannelT
&C
, const std::pair
<T3
, T4
> &V
) {
534 if (auto Err
= SerializationTraits
<ChannelT
, T1
, T3
>::serialize(C
, V
.first
))
536 return SerializationTraits
<ChannelT
, T2
, T4
>::serialize(C
, V
.second
);
539 static Error
deserialize(ChannelT
&C
, std::pair
<T3
, T4
> &V
) {
541 SerializationTraits
<ChannelT
, T1
, T3
>::deserialize(C
, V
.first
))
543 return SerializationTraits
<ChannelT
, T2
, T4
>::deserialize(C
, V
.second
);
547 /// SerializationTraits default specialization for std::tuple.
548 template <typename ChannelT
, typename
... ArgTs
>
549 class SerializationTraits
<ChannelT
, std::tuple
<ArgTs
...>> {
552 /// RPC channel serialization for std::tuple.
553 static Error
serialize(ChannelT
&C
, const std::tuple
<ArgTs
...> &V
) {
554 return serializeTupleHelper(C
, V
, std::index_sequence_for
<ArgTs
...>());
557 /// RPC channel deserialization for std::tuple.
558 static Error
deserialize(ChannelT
&C
, std::tuple
<ArgTs
...> &V
) {
559 return deserializeTupleHelper(C
, V
, std::index_sequence_for
<ArgTs
...>());
563 // Serialization helper for std::tuple.
564 template <size_t... Is
>
565 static Error
serializeTupleHelper(ChannelT
&C
, const std::tuple
<ArgTs
...> &V
,
566 std::index_sequence
<Is
...> _
) {
567 return serializeSeq(C
, std::get
<Is
>(V
)...);
570 // Serialization helper for std::tuple.
571 template <size_t... Is
>
572 static Error
deserializeTupleHelper(ChannelT
&C
, std::tuple
<ArgTs
...> &V
,
573 std::index_sequence
<Is
...> _
) {
574 return deserializeSeq(C
, std::get
<Is
>(V
)...);
578 /// SerializationTraits default specialization for std::vector.
579 template <typename ChannelT
, typename T
>
580 class SerializationTraits
<ChannelT
, std::vector
<T
>> {
583 /// Serialize a std::vector<T> from std::vector<T>.
584 static Error
serialize(ChannelT
&C
, const std::vector
<T
> &V
) {
585 if (auto Err
= serializeSeq(C
, static_cast<uint64_t>(V
.size())))
588 for (const auto &E
: V
)
589 if (auto Err
= serializeSeq(C
, E
))
592 return Error::success();
595 /// Deserialize a std::vector<T> to a std::vector<T>.
596 static Error
deserialize(ChannelT
&C
, std::vector
<T
> &V
) {
598 "Expected default-constructed vector to deserialize into");
601 if (auto Err
= deserializeSeq(C
, Count
))
606 if (auto Err
= deserializeSeq(C
, E
))
609 return Error::success();
613 template <typename ChannelT
, typename T
, typename T2
>
614 class SerializationTraits
<ChannelT
, std::set
<T
>, std::set
<T2
>> {
616 /// Serialize a std::set<T> from std::set<T2>.
617 static Error
serialize(ChannelT
&C
, const std::set
<T2
> &S
) {
618 if (auto Err
= serializeSeq(C
, static_cast<uint64_t>(S
.size())))
621 for (const auto &E
: S
)
622 if (auto Err
= SerializationTraits
<ChannelT
, T
, T2
>::serialize(C
, E
))
625 return Error::success();
628 /// Deserialize a std::set<T> to a std::set<T>.
629 static Error
deserialize(ChannelT
&C
, std::set
<T2
> &S
) {
630 assert(S
.empty() && "Expected default-constructed set to deserialize into");
633 if (auto Err
= deserializeSeq(C
, Count
))
636 while (Count
-- != 0) {
638 if (auto Err
= SerializationTraits
<ChannelT
, T
, T2
>::deserialize(C
, Val
))
641 auto Added
= S
.insert(Val
).second
;
643 return make_error
<StringError
>("Duplicate element in deserialized set",
644 orcError(OrcErrorCode::UnknownORCError
));
647 return Error::success();
651 template <typename ChannelT
, typename K
, typename V
, typename K2
, typename V2
>
652 class SerializationTraits
<ChannelT
, std::map
<K
, V
>, std::map
<K2
, V2
>> {
654 /// Serialize a std::map<K, V> from std::map<K2, V2>.
655 static Error
serialize(ChannelT
&C
, const std::map
<K2
, V2
> &M
) {
656 if (auto Err
= serializeSeq(C
, static_cast<uint64_t>(M
.size())))
659 for (const auto &E
: M
) {
661 SerializationTraits
<ChannelT
, K
, K2
>::serialize(C
, E
.first
))
664 SerializationTraits
<ChannelT
, V
, V2
>::serialize(C
, E
.second
))
668 return Error::success();
671 /// Deserialize a std::map<K, V> to a std::map<K, V>.
672 static Error
deserialize(ChannelT
&C
, std::map
<K2
, V2
> &M
) {
673 assert(M
.empty() && "Expected default-constructed map to deserialize into");
676 if (auto Err
= deserializeSeq(C
, Count
))
679 while (Count
-- != 0) {
680 std::pair
<K2
, V2
> Val
;
682 SerializationTraits
<ChannelT
, K
, K2
>::deserialize(C
, Val
.first
))
686 SerializationTraits
<ChannelT
, V
, V2
>::deserialize(C
, Val
.second
))
689 auto Added
= M
.insert(Val
).second
;
691 return make_error
<StringError
>("Duplicate element in deserialized map",
692 orcError(OrcErrorCode::UnknownORCError
));
695 return Error::success();
699 } // end namespace rpc
700 } // end namespace orc
701 } // end namespace llvm
703 #endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H