[InstCombine] Signed saturation patterns
[llvm-core.git] / include / llvm / ExecutionEngine / Orc / RPCUtils.h
blobee9c2cc69c30e0023cc3f8a89687f25be70d912b
1 //===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utilities to support construction of simple RPC APIs.
11 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
12 // programmers, high performance, low memory overhead, and efficient use of the
13 // communications channel.
15 //===----------------------------------------------------------------------===//
17 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
18 #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
20 #include <map>
21 #include <thread>
22 #include <vector>
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ExecutionEngine/Orc/OrcError.h"
26 #include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
27 #include "llvm/Support/MSVCErrorWorkarounds.h"
29 #include <future>
31 namespace llvm {
32 namespace orc {
33 namespace rpc {
35 /// Base class of all fatal RPC errors (those that necessarily result in the
36 /// termination of the RPC session).
///
/// The fatal error templates later in this header (BadFunctionCall and
/// InvalidSequenceNumberForResponse) name this class as the parent argument of
/// their ErrorInfo base.
37 class RPCFatalError : public ErrorInfo<RPCFatalError> {
38 public:
// NOTE(review): presumably the per-class identity tag that ErrorInfo expects;
// confirm against llvm/Support/Error.h. It is only declared here.
39 static char ID;
42 /// ConnectionClosed is returned from RPC operations if the RPC connection
43 /// has already been closed due to either an error or graceful disconnection.
///
/// It derives directly from ErrorInfo, not from RPCFatalError.
44 class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45 public:
46 static char ID;
// Both overrides are declarations only; their bodies are not in this part of
// the file.
47 std::error_code convertToErrorCode() const override;
48 void log(raw_ostream &OS) const override;
51 /// BadFunctionCall is returned from handleOne when the remote makes a call with
52 /// an unrecognized function id.
53 ///
54 /// This error is fatal because Orc RPC needs to know how to parse a function
55 /// call to know where the next call starts, and if it doesn't recognize the
56 /// function id it cannot parse the call.
57 template <typename FnIdT, typename SeqNoT>
58 class BadFunctionCall
59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60 public:
61 static char ID;
63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
66 std::error_code convertToErrorCode() const override {
67 return orcError(OrcErrorCode::UnexpectedRPCCall);
70 void log(raw_ostream &OS) const override {
71 OS << "Call to invalid RPC function id '" << FnId << "' with "
72 "sequence number " << SeqNo;
75 private:
76 FnIdT FnId;
77 SeqNoT SeqNo;
80 template <typename FnIdT, typename SeqNoT>
81 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
83 /// InvalidSequenceNumberForResponse is returned from handleOne when a response
84 /// call arrives with a sequence number that doesn't correspond to any in-flight
85 /// function call.
86 ///
87 /// This error is fatal because Orc RPC needs to know how to parse the rest of
88 /// the response call to know where the next call starts, and if it doesn't have
89 /// a result parser for this sequence number it can't do that.
90 template <typename SeqNoT>
91 class InvalidSequenceNumberForResponse
92 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93 public:
94 static char ID;
96 InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97 : SeqNo(std::move(SeqNo)) {}
99 std::error_code convertToErrorCode() const override {
100 return orcError(OrcErrorCode::UnexpectedRPCCall);
103 void log(raw_ostream &OS) const override {
104 OS << "Response has unknown sequence number " << SeqNo;
106 private:
107 SeqNoT SeqNo;
110 template <typename SeqNoT>
111 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
113 /// This non-fatal error will be passed to asynchronous result handlers in place
114 /// of a result if the connection goes down before a result returns, or if the
115 /// function to be called cannot be negotiated with the remote.
///
/// ResponseHandler::createAbandonedResponseError(), later in this header,
/// creates instances of this error via make_error<ResponseAbandoned>().
116 class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
117 public:
118 static char ID;
// Both overrides are declarations only; their bodies are not in this part of
// the file.
120 std::error_code convertToErrorCode() const override;
121 void log(raw_ostream &OS) const override;
124 /// This error is returned if the remote does not have a handler installed for
125 /// the given RPC function.
///
/// It stores a signature string, exposed read-only through getSignature().
126 class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
127 public:
128 static char ID;
// The constructor and both overrides are declarations only; their bodies are
// not in this part of the file.
130 CouldNotNegotiate(std::string Signature);
131 std::error_code convertToErrorCode() const override;
132 void log(raw_ostream &OS) const override;
// Returns a reference to the stored signature; valid only while this error
// object is alive.
133 const std::string &getSignature() const { return Signature; }
134 private:
// NOTE(review): presumably the prototype string of the function that could
// not be negotiated -- confirm against the constructor's call sites.
135 std::string Signature;
138 template <typename DerivedFunc, typename FnT> class Function;
140 // RPC Function class.
141 // DerivedFunc should be a user defined class with a static 'getName()' method
142 // returning a const char* representing the function's name.
143 template <typename DerivedFunc, typename RetT, typename... ArgTs>
144 class Function<DerivedFunc, RetT(ArgTs...)> {
145 public:
146 /// User defined function type.
147 using Type = RetT(ArgTs...);
149 /// Return type.
150 using ReturnType = RetT;
152 /// Returns the full function prototype as a string.
153 static const char *getPrototype() {
154 static std::string Name = [] {
155 std::string Name;
156 raw_string_ostream(Name)
157 << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158 << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
159 return Name;
160 }();
161 return Name.data();
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
/// static T getInvalidId():
///   Returns a reserved id used to represent missing functions during
///   autonegotiation.
///
/// static T getResponseId():
///   Returns a reserved id used to send function responses (return values).
///
/// static T getNegotiateId():
///   Returns a reserved id for the negotiate function, which is used to
///   negotiate ids for user defined functions.
///
/// template <typename Func> T allocate():
///   Allocates a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default implementation of RPCFunctionIdAllocator for integral id types.
///
/// Ids 0, 1 and 2 are reserved (invalid, response, negotiate); allocate()
/// hands out consecutive ids starting from 3.
template <typename T>
class RPCFunctionIdAllocator<T,
                             std::enable_if_t<std::is_integral<T>::value>> {
  // The id that the next call to allocate() will return.
  T NextId = 3;

public:
  static T getInvalidId() { return T(0); }
  static T getResponseId() { return T(1); }
  static T getNegotiateId() { return T(2); }

  /// Returns the next unused id. Func only distinguishes call sites; it does
  /// not influence the value returned.
  template <typename Func> T allocate() {
    T Fresh = NextId;
    ++NextId;
    return Fresh;
  }
};
200 namespace detail {
/// Maps a function type onto a std::tuple of its decayed argument types.
template <typename T> class FunctionArgsTuple;

/// For a signature RetT(ArgTs...), Type is std::tuple<decayed ArgTs...>:
/// references and top-level cv-qualifiers are stripped from every argument
/// type, so the tuple can hold argument values.
template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay already removes references, so no separate remove_reference
  // step is needed.
  using Type = std::tuple<std::decay_t<ArgTs>...>;
};
212 // ResultTraits provides typedefs and utilities specific to the return type
213 // of functions.
214 template <typename RetT> class ResultTraits {
215 public:
216 // The return type wrapped in llvm::Expected.
217 using ErrorReturnType = Expected<RetT>;
219 #ifdef _MSC_VER
220 // The ErrorReturnType wrapped in a std::promise.
221 using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
223 // The ErrorReturnType wrapped in a std::future.
224 using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
225 #else
226 // The ErrorReturnType wrapped in a std::promise.
227 using ReturnPromiseType = std::promise<ErrorReturnType>;
229 // The ErrorReturnType wrapped in a std::future.
230 using ReturnFutureType = std::future<ErrorReturnType>;
231 #endif
233 // Create a 'blank' value of the ErrorReturnType, ready and safe to
234 // overwrite.
235 static ErrorReturnType createBlankErrorReturnValue() {
236 return ErrorReturnType(RetT());
239 // Consume an abandoned ErrorReturnType.
240 static void consumeAbandoned(ErrorReturnType RetOrErr) {
241 consumeError(RetOrErr.takeError());
245 // ResultTraits specialization for void functions.
246 template <> class ResultTraits<void> {
247 public:
248 // For void functions, ErrorReturnType is llvm::Error.
249 using ErrorReturnType = Error;
251 #ifdef _MSC_VER
252 // The ErrorReturnType wrapped in a std::promise.
253 using ReturnPromiseType = std::promise<MSVCPError>;
255 // The ErrorReturnType wrapped in a std::future.
256 using ReturnFutureType = std::future<MSVCPError>;
257 #else
258 // The ErrorReturnType wrapped in a std::promise.
259 using ReturnPromiseType = std::promise<ErrorReturnType>;
261 // The ErrorReturnType wrapped in a std::future.
262 using ReturnFutureType = std::future<ErrorReturnType>;
263 #endif
265 // Create a 'blank' value of the ErrorReturnType, ready and safe to
266 // overwrite.
267 static ErrorReturnType createBlankErrorReturnValue() {
268 return ErrorReturnType::success();
271 // Consume an abandoned ErrorReturnType.
272 static void consumeAbandoned(ErrorReturnType Err) {
273 consumeError(std::move(Err));
277 // ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
278 // handlers for void RPC functions to return either void (in which case they
279 // implicitly succeed) or Error (in which case their error return is
280 // propagated). See usage in HandlerTraits::runHandlerHelper.
// NOTE(review): no 'runHandlerHelper' is visible in this part of the file;
// the reference above may be stale (HandlerTraits::run looks like the
// intended target) -- confirm.
281 template <> class ResultTraits<Error> : public ResultTraits<void> {};
283 // ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
284 // handlers for RPC functions returning a T to return either a T (in which
285 // case they implicitly succeed) or Expected<T> (in which case their error
286 // return is propagated). See usage in HandlerTraits::runHandlerHelper.
// NOTE(review): no 'runHandlerHelper' is visible in this part of the file;
// the reference above may be stale (HandlerTraits::run looks like the
// intended target) -- confirm.
287 template <typename RetT>
288 class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
290 // Determines whether an RPC function's declared return type supports
291 // error return values. 'value' is false in this primary template.
292 template <typename T>
293 class SupportsErrorReturn {
294 public:
295 static const bool value = false;
// Error return type: supports error returns.
298 template <>
299 class SupportsErrorReturn<Error> {
300 public:
301 static const bool value = true;
// Expected<T> return type: supports error returns.
304 template <typename T>
305 class SupportsErrorReturn<Expected<T>> {
306 public:
307 static const bool value = true;
310 // RespondHelper packages return values based on whether or not the declared
311 // RPC function return type supports error returns.
312 template <bool FuncSupportsErrorReturn>
313 class RespondHelper;
315 // RespondHelper specialization for functions that support error returns.
316 template <>
317 class RespondHelper<true> {
318 public:
320 // Send Expected<T>.
321 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
322 typename FunctionIdT, typename SequenceNumberT>
323 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
324 SequenceNumberT SeqNo,
325 Expected<HandlerRetT> ResultOrErr) {
326 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
327 return ResultOrErr.takeError();
329 // Open the response message.
330 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
331 return Err;
333 // Serialize the result.
334 if (auto Err =
335 SerializationTraits<ChannelT, WireRetT,
336 Expected<HandlerRetT>>::serialize(
337 C, std::move(ResultOrErr)))
338 return Err;
340 // Close the response message.
341 if (auto Err = C.endSendMessage())
342 return Err;
343 return C.send();
346 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
347 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
348 SequenceNumberT SeqNo, Error Err) {
349 if (Err && Err.isA<RPCFatalError>())
350 return Err;
351 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
352 return Err2;
353 if (auto Err2 = serializeSeq(C, std::move(Err)))
354 return Err2;
355 if (auto Err2 = C.endSendMessage())
356 return Err2;
357 return C.send();
362 // RespondHelper specialization for functions that do not support error returns.
363 template <>
364 class RespondHelper<false> {
365 public:
367 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
368 typename FunctionIdT, typename SequenceNumberT>
369 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
370 SequenceNumberT SeqNo,
371 Expected<HandlerRetT> ResultOrErr) {
372 if (auto Err = ResultOrErr.takeError())
373 return Err;
375 // Open the response message.
376 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
377 return Err;
379 // Serialize the result.
380 if (auto Err =
381 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
382 C, *ResultOrErr))
383 return Err;
385 // End the response message.
386 if (auto Err = C.endSendMessage())
387 return Err;
389 return C.send();
392 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
393 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
394 SequenceNumberT SeqNo, Error Err) {
395 if (Err)
396 return Err;
397 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
398 return Err2;
399 if (auto Err2 = C.endSendMessage())
400 return Err2;
401 return C.send();
407 // Send a response of the given wire return type (WireRetT) over the
408 // channel, with the given sequence number.
409 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
410 typename FunctionIdT, typename SequenceNumberT>
411 Error respond(ChannelT &C, const FunctionIdT &ResponseId,
412 SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
413 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
414 template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
417 // Send an empty response message on the given channel to indicate that
418 // the handler ran.
419 template <typename WireRetT, typename ChannelT, typename FunctionIdT,
420 typename SequenceNumberT>
421 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
422 Error Err) {
423 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
424 sendResult(C, ResponseId, SeqNo, std::move(Err));
427 // Converts a given type to the equivalent error return type.
// Primary template: a plain value type T is wrapped as Expected<T>.
428 template <typename T> class WrappedHandlerReturn {
429 public:
430 using Type = Expected<T>;
// Expected<T> is already an error return type and maps to itself.
433 template <typename T> class WrappedHandlerReturn<Expected<T>> {
434 public:
435 using Type = Expected<T>;
// void maps to Error.
438 template <> class WrappedHandlerReturn<void> {
439 public:
440 using Type = Error;
// Error maps to itself.
443 template <> class WrappedHandlerReturn<Error> {
444 public:
445 using Type = Error;
// ErrorSuccess maps to Error.
448 template <> class WrappedHandlerReturn<ErrorSuccess> {
449 public:
450 using Type = Error;
453 // Traits class that strips the response function from the list of handler
454 // arguments.
// In every specialization 'Type' is Error(ArgTs...), i.e. the handler
// signature without its leading responder parameter, and 'ResultType' is the
// argument type that the responder accepts.
455 template <typename FnT> class AsyncHandlerTraits;
// Handler returning Error whose responder takes an Expected<ResultT>.
457 template <typename ResultT, typename... ArgTs>
458 class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
459 public:
460 using Type = Error(ArgTs...);
461 using ResultType = Expected<ResultT>;
// Handler returning Error whose responder takes an Error.
464 template <typename... ArgTs>
465 class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
466 public:
467 using Type = Error(ArgTs...);
468 using ResultType = Error;
// Handler returning ErrorSuccess whose responder takes an Error.
471 template <typename... ArgTs>
472 class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
473 public:
474 using Type = Error(ArgTs...);
475 using ResultType = Error;
// Handler returning void whose responder takes an Error.
478 template <typename... ArgTs>
479 class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
480 public:
481 using Type = Error(ArgTs...);
482 using ResultType = Error;
// Any other responder parameter type: decay it (std::decay) and defer to the
// specialization selected for the decayed type.
485 template <typename ResponseHandlerT, typename... ArgTs>
486 class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
487 public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
488 ArgTs...)> {};
490 // This template class provides utilities related to RPC function handlers.
491 // The base case applies to non-function types (the template class is
492 // specialized for function types) and inherits from the appropriate
493 // specialization for the given non-function type's call operator.
// The call operator is found via decltype(&T::operator()) after removing any
// reference from HandlerT, so HandlerT must have exactly one, non-template
// operator() for this base case to compile.
494 template <typename HandlerT>
495 class HandlerTraits : public HandlerTraits<decltype(
496 &std::remove_reference<HandlerT>::type::operator())> {
499 // Traits for handlers with a given function type.
// RetT is the handler's return type and ArgTs its parameter types. All
// members are static.
500 template <typename RetT, typename... ArgTs>
501 class HandlerTraits<RetT(ArgTs...)> {
502 public:
503 // Function type of the handler.
504 using Type = RetT(ArgTs...);
506 // Return type of the handler.
507 using ReturnType = RetT;
509 // Call the given handler with the given arguments.
// The arguments are supplied as a tuple; unpackAndRunHelper moves each
// element out of it, so the tuple's elements are moved-from afterwards. The
// result has the handler's return type mapped through WrappedHandlerReturn.
510 template <typename HandlerT, typename... TArgTs>
511 static typename WrappedHandlerReturn<RetT>::Type
512 unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
513 return unpackAndRunHelper(Handler, Args,
514 std::index_sequence_for<TArgTs...>());
517 // Call the given handler with the given arguments.
// Asynchronous form: Responder is passed to the handler as its first
// argument, ahead of the unpacked tuple elements.
518 template <typename HandlerT, typename ResponderT, typename... TArgTs>
519 static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
520 std::tuple<TArgTs...> &Args) {
521 return unpackAndRunAsyncHelper(Handler, Responder, Args,
522 std::index_sequence_for<TArgTs...>());
525 // Call the given handler with the given arguments.
// This overload is enabled only when the handler returns void: the handler is
// invoked and Error::success() is returned on its behalf.
526 template <typename HandlerT>
527 static typename std::enable_if<
528 std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
529 Error>::type
530 run(HandlerT &Handler, ArgTs &&... Args) {
531 Handler(std::move(Args)...);
532 return Error::success();
// This overload is enabled only when the handler returns non-void: the
// handler's own return value is passed straight through.
535 template <typename HandlerT, typename... TArgTs>
536 static typename std::enable_if<
537 !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
538 typename HandlerTraits<HandlerT>::ReturnType>::type
539 run(HandlerT &Handler, TArgTs... Args) {
540 return Handler(std::move(Args)...);
543 // Serialize arguments to the channel.
// The wire types are this handler's ArgTs; CArgs are the concrete values
// supplied by the caller.
544 template <typename ChannelT, typename... CArgTs>
545 static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
546 return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
549 // Deserialize arguments from the channel.
// Each element of Args receives one deserialized value.
550 template <typename ChannelT, typename... CArgTs>
551 static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
552 return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
555 private:
// Expands the tuple into individual lvalue references for
// SequenceSerialization::deserialize. The unnamed-in-effect parameter '_'
// exists only so that Indexes can be deduced.
556 template <typename ChannelT, typename... CArgTs, size_t... Indexes>
557 static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
558 std::index_sequence<Indexes...> _) {
559 return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
560 C, std::get<Indexes>(Args)...);
// Moves every tuple element into a call to run(); the index_sequence
// parameter exists only so that Indexes can be deduced.
563 template <typename HandlerT, typename ArgTuple, size_t... Indexes>
564 static typename WrappedHandlerReturn<
565 typename HandlerTraits<HandlerT>::ReturnType>::Type
566 unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
567 std::index_sequence<Indexes...>) {
568 return run(Handler, std::move(std::get<Indexes>(Args))...);
// As unpackAndRunHelper, but Responder is inserted as the first argument
// passed to run().
571 template <typename HandlerT, typename ResponderT, typename ArgTuple,
572 size_t... Indexes>
573 static typename WrappedHandlerReturn<
574 typename HandlerTraits<HandlerT>::ReturnType>::Type
575 unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
576 ArgTuple &Args, std::index_sequence<Indexes...>) {
577 return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
// The three specializations below add no members of their own: each strips
// the pointer (and class / const qualification) and inherits everything from
// HandlerTraits<RetT(ArgTs...)>.
581 // Handler traits for free functions.
582 template <typename RetT, typename... ArgTs>
583 class HandlerTraits<RetT(*)(ArgTs...)>
584 : public HandlerTraits<RetT(ArgTs...)> {};
586 // Handler traits for class methods (especially call operators for lambdas).
587 template <typename Class, typename RetT, typename... ArgTs>
588 class HandlerTraits<RetT (Class::*)(ArgTs...)>
589 : public HandlerTraits<RetT(ArgTs...)> {};
591 // Handler traits for const class methods (especially call operators for
592 // lambdas).
593 template <typename Class, typename RetT, typename... ArgTs>
594 class HandlerTraits<RetT (Class::*)(ArgTs...) const>
595 : public HandlerTraits<RetT(ArgTs...)> {};
597 // Utility to peel the Expected wrapper off a response handler's argument type.
// 'ArgType' is the handler's parameter type as declared; the Expected<ArgT>
// forms additionally expose the wrapped type as 'UnwrappedArgType'.
598 template <typename HandlerT> class ResponseHandlerArg;
// Handler of type Error(Expected<ArgT>).
600 template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
601 public:
602 using ArgType = Expected<ArgT>;
603 using UnwrappedArgType = ArgT;
// Handler of type ErrorSuccess(Expected<ArgT>).
606 template <typename ArgT>
607 class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
608 public:
609 using ArgType = Expected<ArgT>;
610 using UnwrappedArgType = ArgT;
// Handler of type Error(Error): there is nothing to unwrap, so no
// UnwrappedArgType is provided.
613 template <> class ResponseHandlerArg<Error(Error)> {
614 public:
615 using ArgType = Error;
// Handler of type ErrorSuccess(Error): likewise no UnwrappedArgType.
618 template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
619 public:
620 using ArgType = Error;
623 // ResponseHandler represents a handler for a not-yet-received function call
624 // result.
625 template <typename ChannelT> class ResponseHandler {
626 public:
627 virtual ~ResponseHandler() {}
629 // Reads the function result off the wire and acts on it. The meaning of
630 // "act" will depend on how this method is implemented in any given
631 // ResponseHandler subclass but could, for example, mean running a
632 // user-specified handler or setting a promise value.
633 virtual Error handleResponse(ChannelT &C) = 0;
635 // Abandons this outstanding result.
636 virtual void abandon() = 0;
638 // Create an error instance representing an abandoned response.
639 static Error createAbandonedResponseError() {
640 return make_error<ResponseAbandoned>();
644 // ResponseHandler subclass for RPC functions with non-void returns.
645 template <typename ChannelT, typename FuncRetT, typename HandlerT>
646 class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
647 public:
648 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
650 // Handle the result by deserializing it from the channel then passing it
651 // to the user defined handler.
652 Error handleResponse(ChannelT &C) override {
653 using UnwrappedArgType = typename ResponseHandlerArg<
654 typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
655 UnwrappedArgType Result;
656 if (auto Err =
657 SerializationTraits<ChannelT, FuncRetT,
658 UnwrappedArgType>::deserialize(C, Result))
659 return Err;
660 if (auto Err = C.endReceiveMessage())
661 return Err;
662 return Handler(std::move(Result));
665 // Abandon this response by calling the handler with an 'abandoned response'
666 // error.
667 void abandon() override {
668 if (auto Err = Handler(this->createAbandonedResponseError())) {
669 // Handlers should not fail when passed an abandoned response error.
670 report_fatal_error(std::move(Err));
674 private:
675 HandlerT Handler;
678 // ResponseHandler subclass for RPC functions with void returns.
679 template <typename ChannelT, typename HandlerT>
680 class ResponseHandlerImpl<ChannelT, void, HandlerT>
681 : public ResponseHandler<ChannelT> {
682 public:
683 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
685 // Handle the result (no actual value, just a notification that the function
686 // has completed on the remote end) by calling the user-defined handler with
687 // Error::success().
688 Error handleResponse(ChannelT &C) override {
689 if (auto Err = C.endReceiveMessage())
690 return Err;
691 return Handler(Error::success());
694 // Abandon this response by calling the handler with an 'abandoned response'
695 // error.
696 void abandon() override {
697 if (auto Err = Handler(this->createAbandonedResponseError())) {
698 // Handlers should not fail when passed an abandoned response error.
699 report_fatal_error(std::move(Err));
703 private:
704 HandlerT Handler;
707 template <typename ChannelT, typename FuncRetT, typename HandlerT>
708 class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
709 : public ResponseHandler<ChannelT> {
710 public:
711 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
713 // Handle the result by deserializing it from the channel then passing it
714 // to the user defined handler.
715 Error handleResponse(ChannelT &C) override {
716 using HandlerArgType = typename ResponseHandlerArg<
717 typename HandlerTraits<HandlerT>::Type>::ArgType;
718 HandlerArgType Result((typename HandlerArgType::value_type()));
720 if (auto Err =
721 SerializationTraits<ChannelT, Expected<FuncRetT>,
722 HandlerArgType>::deserialize(C, Result))
723 return Err;
724 if (auto Err = C.endReceiveMessage())
725 return Err;
726 return Handler(std::move(Result));
729 // Abandon this response by calling the handler with an 'abandoned response'
730 // error.
731 void abandon() override {
732 if (auto Err = Handler(this->createAbandonedResponseError())) {
733 // Handlers should not fail when passed an abandoned response error.
734 report_fatal_error(std::move(Err));
738 private:
739 HandlerT Handler;
742 template <typename ChannelT, typename HandlerT>
743 class ResponseHandlerImpl<ChannelT, Error, HandlerT>
744 : public ResponseHandler<ChannelT> {
745 public:
746 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
748 // Handle the result by deserializing it from the channel then passing it
749 // to the user defined handler.
750 Error handleResponse(ChannelT &C) override {
751 Error Result = Error::success();
752 if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
753 C, Result)) {
754 consumeError(std::move(Result));
755 return Err;
757 if (auto Err = C.endReceiveMessage()) {
758 consumeError(std::move(Result));
759 return Err;
761 return Handler(std::move(Result));
764 // Abandon this response by calling the handler with an 'abandoned response'
765 // error.
766 void abandon() override {
767 if (auto Err = Handler(this->createAbandonedResponseError())) {
768 // Handlers should not fail when passed an abandoned response error.
769 report_fatal_error(std::move(Err));
773 private:
774 HandlerT Handler;
777 // Create a ResponseHandler from a given user handler.
778 template <typename ChannelT, typename FuncRetT, typename HandlerT>
779 std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
780 return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
781 std::move(H));
// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
//
// The wrapper stores a reference to Instance, so the instance must outlive
// the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}

  // Invokes the wrapped method on the stored instance and returns its result.
  //
  // Arguments are forwarded with std::forward<ArgTs>, not std::move. For
  // by-value and const-reference parameters the two are equivalent, but for
  // an lvalue-reference parameter (ArgTs = T &) std::move would produce an
  // rvalue that cannot bind to the method's T & parameter.
  RetT operator()(ArgTs &&... Args) {
    return (Instance.*Method)(std::forward<ArgTs>(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};
801 // Helper that provides a Functor for deserializing arguments.
802 template <typename... ArgTs> class ReadArgs {
803 public:
804 Error operator()() { return Error::success(); }
807 template <typename ArgT, typename... ArgTs>
808 class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
809 public:
810 ReadArgs(ArgT &Arg, ArgTs &... Args)
811 : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
813 Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
814 this->Arg = std::move(ArgVal);
815 return ReadArgs<ArgTs...>::operator()(ArgVals...);
818 private:
819 ArgT &Arg;
// Hands out sequence numbers and recycles released ones. Every public method
// takes the internal mutex for its whole duration.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Reset, making all sequence numbers available again.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Get the next available sequence number. The most recently released number
  // is re-used first; a fresh number is minted only when none are free.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    return NextSequenceNumber++;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  // The next never-before-issued sequence number.
  SequenceNumberT NextSequenceNumber = 0;
  // Released numbers awaiting re-use; used as a stack.
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
// Checks that predicate P holds for each corresponding pair of element types
// taken from the T1 and T2 tuples. Only tuples of equal length are handled;
// no specialization matches tuples whose lengths differ.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Two empty tuples: nothing left to check, so the predicate holds.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Non-empty tuples: test the head pair, then recurse on the tails.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using Tail = RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static const bool value = P<T, U>::value && Tail::value;
};
875 template <template <class, class> class P, typename T1Sig, typename T2Sig>
876 class RPCArgTypeCheck {
877 public:
878 using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
879 using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
881 static_assert(std::tuple_size<T1Tuple>::value >=
882 std::tuple_size<T2Tuple>::value,
883 "Too many arguments to RPC call");
884 static_assert(std::tuple_size<T1Tuple>::value <=
885 std::tuple_size<T2Tuple>::value,
886 "Too few arguments to RPC call");
888 static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
891 template <typename ChannelT, typename WireT, typename ConcreteT>
892 class CanSerialize {
893 private:
894 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
896 template <typename T>
897 static std::true_type
898 check(typename std::enable_if<
899 std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
900 std::declval<const ConcreteT &>())),
901 Error>::value,
902 void *>::type);
904 template <typename> static std::false_type check(...);
906 public:
907 static const bool value = decltype(check<S>(0))::value;
910 template <typename ChannelT, typename WireT, typename ConcreteT>
911 class CanDeserialize {
912 private:
913 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
915 template <typename T>
916 static std::true_type
917 check(typename std::enable_if<
918 std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
919 std::declval<ConcreteT &>())),
920 Error>::value,
921 void *>::type);
923 template <typename> static std::false_type check(...);
925 public:
926 static const bool value = decltype(check<S>(0))::value;
929 /// Contains primitive utilities for defining, calling and handling calls to
930 /// remote procedures. ChannelT is a bidirectional stream conforming to the
931 /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
932 /// identifier type that must be serializable on ChannelT, and SequenceNumberT
933 /// is an integral type that will be used to number in-flight function calls.
935 /// These utilities support the construction of very primitive RPC utilities.
936 /// Their intent is to ensure correct serialization and deserialization of
937 /// procedure arguments, and to keep the client and server's view of the API in
938 /// sync.
939 template <typename ImplT, typename ChannelT, typename FunctionIdT,
940 typename SequenceNumberT>
941 class RPCEndpointBase {
942 protected:
943 class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
944 public:
945 static const char *getName() { return "__orc_rpc$invalid"; }
948 class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
949 public:
950 static const char *getName() { return "__orc_rpc$response"; }
953 class OrcRPCNegotiate
954 : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
955 public:
956 static const char *getName() { return "__orc_rpc$negotiate"; }
959 // Helper predicate for testing for the presence of SerializeTraits
960 // serializers.
961 template <typename WireT, typename ConcreteT>
962 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
963 public:
964 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
966 static_assert(value, "Missing serializer for argument (Can't serialize the "
967 "first template type argument of CanSerializeCheck "
968 "from the second)");
971 // Helper predicate for testing for the presence of SerializeTraits
972 // deserializers.
973 template <typename WireT, typename ConcreteT>
974 class CanDeserializeCheck
975 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
976 public:
977 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
979 static_assert(value, "Missing deserializer for argument (Can't deserialize "
980 "the second template type argument of "
981 "CanDeserializeCheck from the first)");
984 public:
985 /// Construct an RPC instance on a channel.
986 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
987 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
988 // Hold ResponseId in a special variable, since we expect Response to be
989 // called relatively frequently, and want to avoid the map lookup.
990 ResponseId = FnIdAllocator.getResponseId();
991 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
993 // Register the negotiate function id and handler.
994 auto NegotiateId = FnIdAllocator.getNegotiateId();
995 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
996 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
997 [this](const std::string &Name) { return handleNegotiate(Name); });
1001 /// Negotiate a function id for Func with the other end of the channel.
1002 template <typename Func> Error negotiateFunction(bool Retry = false) {
1003 return getRemoteFunctionId<Func>(true, Retry).takeError();
1006 /// Append a call Func, does not call send on the channel.
1007 /// The first argument specifies a user-defined handler to be run when the
1008 /// function returns. The handler should take an Expected<Func::ReturnType>,
1009 /// or an Error (if Func::ReturnType is void). The handler will be called
1010 /// with an error if the return value is abandoned due to a channel error.
1011 template <typename Func, typename HandlerT, typename... ArgTs>
1012 Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
1014 static_assert(
1015 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1016 void(ArgTs...)>::value,
1017 "");
1019 // Look up the function ID.
1020 FunctionIdT FnId;
1021 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1022 FnId = *FnIdOrErr;
1023 else {
1024 // Negotiation failed. Notify the handler then return the negotiate-failed
1025 // error.
1026 cantFail(Handler(make_error<ResponseAbandoned>()));
1027 return FnIdOrErr.takeError();
1030 SequenceNumberT SeqNo; // initialized in locked scope below.
1032 // Lock the pending responses map and sequence number manager.
1033 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1035 // Allocate a sequence number.
1036 SeqNo = SequenceNumberMgr.getSequenceNumber();
1037 assert(!PendingResponses.count(SeqNo) &&
1038 "Sequence number already allocated");
1040 // Install the user handler.
1041 PendingResponses[SeqNo] =
1042 detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1043 std::move(Handler));
1046 // Open the function call message.
1047 if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1048 abandonPendingResponses();
1049 return Err;
1052 // Serialize the call arguments.
1053 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1054 C, Args...)) {
1055 abandonPendingResponses();
1056 return Err;
1059 // Close the function call messagee.
1060 if (auto Err = C.endSendMessage()) {
1061 abandonPendingResponses();
1062 return Err;
1065 return Error::success();
1068 Error sendAppendedCalls() { return C.send(); };
1070 template <typename Func, typename HandlerT, typename... ArgTs>
1071 Error callAsync(HandlerT Handler, const ArgTs &... Args) {
1072 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1073 return Err;
1074 return C.send();
1077 /// Handle one incoming call.
1078 Error handleOne() {
1079 FunctionIdT FnId;
1080 SequenceNumberT SeqNo;
1081 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1082 abandonPendingResponses();
1083 return Err;
1085 if (FnId == ResponseId)
1086 return handleResponse(SeqNo);
1087 auto I = Handlers.find(FnId);
1088 if (I != Handlers.end())
1089 return I->second(C, SeqNo);
1091 // else: No handler found. Report error to client?
1092 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1093 SeqNo);
1096 /// Helper for handling setter procedures - this method returns a functor that
1097 /// sets the variables referred to by Args... to values deserialized from the
1098 /// channel.
1099 /// E.g.
1101 /// typedef Function<0, bool, int> Func1;
1103 /// ...
1104 /// bool B;
1105 /// int I;
1106 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1107 /// /* Handle Args */ ;
1109 template <typename... ArgTs>
1110 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
1111 return detail::ReadArgs<ArgTs...>(Args...);
1114 /// Abandon all outstanding result handlers.
1116 /// This will call all currently registered result handlers to receive an
1117 /// "abandoned" error as their argument. This is used internally by the RPC
1118 /// in error situations, but can also be called directly by clients who are
1119 /// disconnecting from the remote and don't or can't expect responses to their
1120 /// outstanding calls. (Especially for outstanding blocking calls, calling
1121 /// this function may be necessary to avoid dead threads).
1122 void abandonPendingResponses() {
1123 // Lock the pending responses map and sequence number manager.
1124 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1126 for (auto &KV : PendingResponses)
1127 KV.second->abandon();
1128 PendingResponses.clear();
1129 SequenceNumberMgr.reset();
1132 /// Remove the handler for the given function.
1133 /// A handler must currently be registered for this function.
1134 template <typename Func>
1135 void removeHandler() {
1136 auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1137 assert(IdItr != LocalFunctionIds.end() &&
1138 "Function does not have a registered handler");
1139 auto HandlerItr = Handlers.find(IdItr->second);
1140 assert(HandlerItr != Handlers.end() &&
1141 "Function does not have a registered handler");
1142 Handlers.erase(HandlerItr);
1145 /// Clear all handlers.
1146 void clearHandlers() {
1147 Handlers.clear();
1150 protected:
1152 FunctionIdT getInvalidFunctionId() const {
1153 return FnIdAllocator.getInvalidId();
1156 /// Add the given handler to the handler map and make it available for
1157 /// autonegotiation and execution.
1158 template <typename Func, typename HandlerT>
1159 void addHandlerImpl(HandlerT Handler) {
1161 static_assert(detail::RPCArgTypeCheck<
1162 CanDeserializeCheck, typename Func::Type,
1163 typename detail::HandlerTraits<HandlerT>::Type>::value,
1164 "");
1166 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1167 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1168 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1171 template <typename Func, typename HandlerT>
1172 void addAsyncHandlerImpl(HandlerT Handler) {
1174 static_assert(detail::RPCArgTypeCheck<
1175 CanDeserializeCheck, typename Func::Type,
1176 typename detail::AsyncHandlerTraits<
1177 typename detail::HandlerTraits<HandlerT>::Type
1178 >::Type>::value,
1179 "");
1181 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1182 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1183 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1186 Error handleResponse(SequenceNumberT SeqNo) {
1187 using Handler = typename decltype(PendingResponses)::mapped_type;
1188 Handler PRHandler;
1191 // Lock the pending responses map and sequence number manager.
1192 std::unique_lock<std::mutex> Lock(ResponsesMutex);
1193 auto I = PendingResponses.find(SeqNo);
1195 if (I != PendingResponses.end()) {
1196 PRHandler = std::move(I->second);
1197 PendingResponses.erase(I);
1198 SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1199 } else {
1200 // Unlock the pending results map to prevent recursive lock.
1201 Lock.unlock();
1202 abandonPendingResponses();
1203 return make_error<
1204 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
1208 assert(PRHandler &&
1209 "If we didn't find a response handler we should have bailed out");
1211 if (auto Err = PRHandler->handleResponse(C)) {
1212 abandonPendingResponses();
1213 return Err;
1216 return Error::success();
1219 FunctionIdT handleNegotiate(const std::string &Name) {
1220 auto I = LocalFunctionIds.find(Name);
1221 if (I == LocalFunctionIds.end())
1222 return getInvalidFunctionId();
1223 return I->second;
1226 // Find the remote FunctionId for the given function.
1227 template <typename Func>
1228 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1229 bool NegotiateIfInvalid) {
1230 bool DoNegotiate;
1232 // Check if we already have a function id...
1233 auto I = RemoteFunctionIds.find(Func::getPrototype());
1234 if (I != RemoteFunctionIds.end()) {
1235 // If it's valid there's nothing left to do.
1236 if (I->second != getInvalidFunctionId())
1237 return I->second;
1238 DoNegotiate = NegotiateIfInvalid;
1239 } else
1240 DoNegotiate = NegotiateIfNotInMap;
1242 // We don't have a function id for Func yet, but we're allowed to try to
1243 // negotiate one.
1244 if (DoNegotiate) {
1245 auto &Impl = static_cast<ImplT &>(*this);
1246 if (auto RemoteIdOrErr =
1247 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1248 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1249 if (*RemoteIdOrErr == getInvalidFunctionId())
1250 return make_error<CouldNotNegotiate>(Func::getPrototype());
1251 return *RemoteIdOrErr;
1252 } else
1253 return RemoteIdOrErr.takeError();
1256 // No key was available in the map and we weren't allowed to try to
1257 // negotiate one, so return an unknown function error.
1258 return make_error<CouldNotNegotiate>(Func::getPrototype());
1261 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1263 // Wrap the given user handler in the necessary argument-deserialization code,
1264 // result-serialization code, and call to the launch policy (if present).
1265 template <typename Func, typename HandlerT>
1266 WrappedHandlerFn wrapHandler(HandlerT Handler) {
1267 return [this, Handler](ChannelT &Channel,
1268 SequenceNumberT SeqNo) mutable -> Error {
1269 // Start by deserializing the arguments.
1270 using ArgsTuple =
1271 typename detail::FunctionArgsTuple<
1272 typename detail::HandlerTraits<HandlerT>::Type>::Type;
1273 auto Args = std::make_shared<ArgsTuple>();
1275 if (auto Err =
1276 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1277 Channel, *Args))
1278 return Err;
1280 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1281 // for RPCArgs. Void cast RPCArgs to work around this for now.
1282 // FIXME: Remove this workaround once we can assume a working GCC version.
1283 (void)Args;
1285 // End receieve message, unlocking the channel for reading.
1286 if (auto Err = Channel.endReceiveMessage())
1287 return Err;
1289 using HTraits = detail::HandlerTraits<HandlerT>;
1290 using FuncReturn = typename Func::ReturnType;
1291 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1292 HTraits::unpackAndRun(Handler, *Args));
1296 // Wrap the given user handler in the necessary argument-deserialization code,
1297 // result-serialization code, and call to the launch policy (if present).
1298 template <typename Func, typename HandlerT>
1299 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1300 return [this, Handler](ChannelT &Channel,
1301 SequenceNumberT SeqNo) mutable -> Error {
1302 // Start by deserializing the arguments.
1303 using AHTraits = detail::AsyncHandlerTraits<
1304 typename detail::HandlerTraits<HandlerT>::Type>;
1305 using ArgsTuple =
1306 typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
1307 auto Args = std::make_shared<ArgsTuple>();
1309 if (auto Err =
1310 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1311 Channel, *Args))
1312 return Err;
1314 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1315 // for RPCArgs. Void cast RPCArgs to work around this for now.
1316 // FIXME: Remove this workaround once we can assume a working GCC version.
1317 (void)Args;
1319 // End receieve message, unlocking the channel for reading.
1320 if (auto Err = Channel.endReceiveMessage())
1321 return Err;
1323 using HTraits = detail::HandlerTraits<HandlerT>;
1324 using FuncReturn = typename Func::ReturnType;
1325 auto Responder =
1326 [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1327 return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1328 std::move(RetVal));
1331 return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1335 ChannelT &C;
1337 bool LazyAutoNegotiation;
1339 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1341 FunctionIdT ResponseId;
1342 std::map<std::string, FunctionIdT> LocalFunctionIds;
1343 std::map<const char *, FunctionIdT> RemoteFunctionIds;
1345 std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1347 std::mutex ResponsesMutex;
1348 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1349 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1350 PendingResponses;
1353 } // end namespace detail
1355 template <typename ChannelT, typename FunctionIdT = uint32_t,
1356 typename SequenceNumberT = uint32_t>
1357 class MultiThreadedRPCEndpoint
1358 : public detail::RPCEndpointBase<
1359 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1360 ChannelT, FunctionIdT, SequenceNumberT> {
1361 private:
1362 using BaseClass =
1363 detail::RPCEndpointBase<
1364 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1365 ChannelT, FunctionIdT, SequenceNumberT>;
1367 public:
1368 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1369 : BaseClass(C, LazyAutoNegotiation) {}
1371 /// Add a handler for the given RPC function.
1372 /// This installs the given handler functor for the given RPC Function, and
1373 /// makes the RPC function available for negotiation/calling from the remote.
1374 template <typename Func, typename HandlerT>
1375 void addHandler(HandlerT Handler) {
1376 return this->template addHandlerImpl<Func>(std::move(Handler));
1379 /// Add a class-method as a handler.
1380 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1381 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1382 addHandler<Func>(
1383 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1386 template <typename Func, typename HandlerT>
1387 void addAsyncHandler(HandlerT Handler) {
1388 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1391 /// Add a class-method as a handler.
1392 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1393 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1394 addAsyncHandler<Func>(
1395 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1398 /// Return type for non-blocking call primitives.
1399 template <typename Func>
1400 using NonBlockingCallResult = typename detail::ResultTraits<
1401 typename Func::ReturnType>::ReturnFutureType;
1403 /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1404 /// of a future result and the sequence number assigned to the result.
1406 /// This utility function is primarily used for single-threaded mode support,
1407 /// where the sequence number can be used to wait for the corresponding
1408 /// result. In multi-threaded mode the appendCallNB method, which does not
1409 /// return the sequence numeber, should be preferred.
1410 template <typename Func, typename... ArgTs>
1411 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1412 using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1413 using ErrorReturn = typename RTraits::ErrorReturnType;
1414 using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1416 ErrorReturnPromise Promise;
1417 auto FutureResult = Promise.get_future();
1419 if (auto Err = this->template appendCallAsync<Func>(
1420 [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable {
1421 Promise.set_value(std::move(RetOrErr));
1422 return Error::success();
1424 Args...)) {
1425 RTraits::consumeAbandoned(FutureResult.get());
1426 return std::move(Err);
1428 return std::move(FutureResult);
1431 /// The same as appendCallNBWithSeq, except that it calls C.send() to
1432 /// flush the channel after serializing the call.
1433 template <typename Func, typename... ArgTs>
1434 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1435 auto Result = appendCallNB<Func>(Args...);
1436 if (!Result)
1437 return Result;
1438 if (auto Err = this->C.send()) {
1439 this->abandonPendingResponses();
1440 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1441 std::move(Result->get()));
1442 return std::move(Err);
1444 return Result;
1447 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1448 /// for void functions or an Expected<T> for functions returning a T.
1450 /// This function is for use in threaded code where another thread is
1451 /// handling responses and incoming calls.
1452 template <typename Func, typename... ArgTs,
1453 typename AltRetT = typename Func::ReturnType>
1454 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1455 callB(const ArgTs &... Args) {
1456 if (auto FutureResOrErr = callNB<Func>(Args...))
1457 return FutureResOrErr->get();
1458 else
1459 return FutureResOrErr.takeError();
1462 /// Handle incoming RPC calls.
1463 Error handlerLoop() {
1464 while (true)
1465 if (auto Err = this->handleOne())
1466 return Err;
1467 return Error::success();
1471 template <typename ChannelT, typename FunctionIdT = uint32_t,
1472 typename SequenceNumberT = uint32_t>
1473 class SingleThreadedRPCEndpoint
1474 : public detail::RPCEndpointBase<
1475 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1476 ChannelT, FunctionIdT, SequenceNumberT> {
1477 private:
1478 using BaseClass =
1479 detail::RPCEndpointBase<
1480 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1481 ChannelT, FunctionIdT, SequenceNumberT>;
1483 public:
1484 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1485 : BaseClass(C, LazyAutoNegotiation) {}
1487 template <typename Func, typename HandlerT>
1488 void addHandler(HandlerT Handler) {
1489 return this->template addHandlerImpl<Func>(std::move(Handler));
1492 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1493 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1494 addHandler<Func>(
1495 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1498 template <typename Func, typename HandlerT>
1499 void addAsyncHandler(HandlerT Handler) {
1500 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1503 /// Add a class-method as a handler.
1504 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1505 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1506 addAsyncHandler<Func>(
1507 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1510 template <typename Func, typename... ArgTs,
1511 typename AltRetT = typename Func::ReturnType>
1512 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1513 callB(const ArgTs &... Args) {
1514 bool ReceivedResponse = false;
1515 using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1516 auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1518 // We have to 'Check' result (which we know is in a success state at this
1519 // point) so that it can be overwritten in the async handler.
1520 (void)!!Result;
1522 if (auto Err = this->template appendCallAsync<Func>(
1523 [&](ResultType R) {
1524 Result = std::move(R);
1525 ReceivedResponse = true;
1526 return Error::success();
1528 Args...)) {
1529 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1530 std::move(Result));
1531 return std::move(Err);
1534 if (auto Err = this->C.send()) {
1535 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1536 std::move(Result));
1537 return std::move(Err);
1540 while (!ReceivedResponse) {
1541 if (auto Err = this->handleOne()) {
1542 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1543 std::move(Result));
1544 return std::move(Err);
1548 return Result;
1552 /// Asynchronous dispatch for a function on an RPC endpoint.
1553 template <typename RPCClass, typename Func>
1554 class RPCAsyncDispatch {
1555 public:
1556 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1558 template <typename HandlerT, typename... ArgTs>
1559 Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1560 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1563 private:
1564 RPCClass &Endpoint;
1567 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1568 template <typename Func, typename RPCEndpointT>
1569 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1570 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1573 /// Allows a set of asynchrounous calls to be dispatched, and then
1574 /// waited on as a group.
1575 class ParallelCallGroup {
1576 public:
1578 ParallelCallGroup() = default;
1579 ParallelCallGroup(const ParallelCallGroup &) = delete;
1580 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1582 /// Make as asynchronous call.
1583 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
1584 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1585 const ArgTs &... Args) {
1586 // Increment the count of outstanding calls. This has to happen before
1587 // we invoke the call, as the handler may (depending on scheduling)
1588 // be run immediately on another thread, and we don't want the decrement
1589 // in the wrapped handler below to run before the increment.
1591 std::unique_lock<std::mutex> Lock(M);
1592 ++NumOutstandingCalls;
1595 // Wrap the user handler in a lambda that will decrement the
1596 // outstanding calls count, then poke the condition variable.
1597 using ArgType = typename detail::ResponseHandlerArg<
1598 typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1599 auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
1600 auto Err = Handler(std::move(Arg));
1601 std::unique_lock<std::mutex> Lock(M);
1602 --NumOutstandingCalls;
1603 CV.notify_all();
1604 return Err;
1607 return AsyncDispatch(std::move(WrappedHandler), Args...);
1610 /// Blocks until all calls have been completed and their return value
1611 /// handlers run.
1612 void wait() {
1613 std::unique_lock<std::mutex> Lock(M);
1614 while (NumOutstandingCalls > 0)
1615 CV.wait(Lock);
1618 private:
1619 std::mutex M;
1620 std::condition_variable CV;
1621 uint32_t NumOutstandingCalls = 0;
1624 /// Convenience class for grouping RPC Functions into APIs that can be
1625 /// negotiated as a block.
1627 template <typename... Funcs>
1628 class APICalls {
1629 public:
1631 /// Test whether this API contains Function F.
1632 template <typename F>
1633 class Contains {
1634 public:
1635 static const bool value = false;
1638 /// Negotiate all functions in this API.
1639 template <typename RPCEndpoint>
1640 static Error negotiate(RPCEndpoint &R) {
1641 return Error::success();
1645 template <typename Func, typename... Funcs>
1646 class APICalls<Func, Funcs...> {
1647 public:
1649 template <typename F>
1650 class Contains {
1651 public:
1652 static const bool value = std::is_same<F, Func>::value |
1653 APICalls<Funcs...>::template Contains<F>::value;
1656 template <typename RPCEndpoint>
1657 static Error negotiate(RPCEndpoint &R) {
1658 if (auto Err = R.template negotiateFunction<Func>())
1659 return Err;
1660 return APICalls<Funcs...>::negotiate(R);
1665 template <typename... InnerFuncs, typename... Funcs>
1666 class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1667 public:
1669 template <typename F>
1670 class Contains {
1671 public:
1672 static const bool value =
1673 APICalls<InnerFuncs...>::template Contains<F>::value |
1674 APICalls<Funcs...>::template Contains<F>::value;
1677 template <typename RPCEndpoint>
1678 static Error negotiate(RPCEndpoint &R) {
1679 if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1680 return Err;
1681 return APICalls<Funcs...>::negotiate(R);
1686 } // end namespace rpc
1687 } // end namespace orc
1688 } // end namespace llvm
1690 #endif