1 //===-- MPCUtils.h ----------------------------------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
10 #define LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H
12 #include "src/__support/CPP/type_traits.h"
13 #include "src/__support/complex_type.h"
14 #include "src/__support/macros/config.h"
15 #include "src/__support/macros/properties/complex_types.h"
16 #include "src/__support/macros/properties/types.h"
17 #include "test/UnitTest/RoundingModeUtils.h"
18 #include "test/UnitTest/Test.h"
22 namespace LIBC_NAMESPACE_DECL
{
26 enum class Operation
{
27 // Operations which take a single complex floating point number as input
28 // and produce a single floating point number as output which has the same
29 // floating point type as the real/imaginary part of the input.
30 BeginUnaryOperationsSingleOutputDifferentOutputType
,
33 EndUnaryOperationsSingleOutputDifferentOutputType
,
// Operations which take a single complex floating point number as input
// and produce a single complex floating point number of the same kind as
// output.
38 BeginUnaryOperationsSingleOutputSameOutputType
,
55 EndUnaryOperationsSingleOutputSameOutputType
,
// Operations which take two complex floating point numbers as input
// and produce a single complex floating point number of the same kind as
// output.
60 BeginBinaryOperationsSingleOutput
,
62 EndBinaryOperationsSingleOutput
,
65 using LIBC_NAMESPACE::fputil::testing::RoundingMode
;
67 template <typename T
> struct BinaryInput
{
68 static_assert(LIBC_NAMESPACE::cpp::is_complex_v
<T
>,
69 "Template parameter of BinaryInput must be a complex floating "
78 template <typename InputType
, typename OutputType
>
79 bool compare_unary_operation_single_output_same_type(Operation op
,
81 OutputType libc_output
,
83 RoundingMode rounding
);
85 template <typename InputType
, typename OutputType
>
86 bool compare_unary_operation_single_output_different_type(
87 Operation op
, InputType input
, OutputType libc_output
, double ulp_tolerance
,
88 RoundingMode rounding
);
90 template <typename InputType
, typename OutputType
>
91 bool compare_binary_operation_one_output(Operation op
,
92 const BinaryInput
<InputType
> &input
,
93 OutputType libc_output
,
95 RoundingMode rounding
);
97 template <typename InputType
, typename OutputType
>
98 void explain_unary_operation_single_output_same_type_error(
99 Operation op
, InputType input
, OutputType match_value
, double ulp_tolerance
,
100 RoundingMode rounding
);
102 template <typename InputType
, typename OutputType
>
103 void explain_unary_operation_single_output_different_type_error(
104 Operation op
, InputType input
, OutputType match_value
, double ulp_tolerance
,
105 RoundingMode rounding
);
107 template <typename InputType
, typename OutputType
>
108 void explain_binary_operation_one_output_error(
109 Operation op
, const BinaryInput
<InputType
> &input
, OutputType match_value
,
110 double ulp_tolerance
, RoundingMode rounding
);
112 template <Operation op
, typename InputType
, typename OutputType
>
113 class MPCMatcher
: public testing::Matcher
<OutputType
> {
116 OutputType match_value
;
117 double ulp_tolerance
;
118 RoundingMode rounding
;
121 MPCMatcher(InputType testInput
, double ulp_tolerance
, RoundingMode rounding
)
122 : input(testInput
), ulp_tolerance(ulp_tolerance
), rounding(rounding
) {}
124 bool match(OutputType libcResult
) {
125 match_value
= libcResult
;
126 return match(input
, match_value
);
129 void explainError() override
{ // NOLINT
130 explain_error(input
, match_value
);
134 template <typename InType
, typename OutType
>
135 bool match(InType in
, OutType out
) {
136 if (cpp::is_same_v
<InType
, OutType
>) {
137 return compare_unary_operation_single_output_same_type(
138 op
, in
, out
, ulp_tolerance
, rounding
);
140 return compare_unary_operation_single_output_different_type(
141 op
, in
, out
, ulp_tolerance
, rounding
);
145 template <typename T
, typename U
>
146 bool match(const BinaryInput
<T
> &in
, U out
) {
147 return compare_binary_operation_one_output(op
, in
, out
, ulp_tolerance
,
151 template <typename InType
, typename OutType
>
152 void explain_error(InType in
, OutType out
) {
153 if (cpp::is_same_v
<InType
, OutType
>) {
154 explain_unary_operation_single_output_same_type_error(
155 op
, in
, out
, ulp_tolerance
, rounding
);
157 explain_unary_operation_single_output_different_type_error(
158 op
, in
, out
, ulp_tolerance
, rounding
);
162 template <typename T
, typename U
>
163 void explain_error(const BinaryInput
<T
> &in
, U out
) {
164 explain_binary_operation_one_output_error(op
, in
, out
, ulp_tolerance
,
169 } // namespace internal
// Returns true if InputType and OutputType are valid input/output types for
// the operation op.
173 template <Operation op
, typename InputType
, typename OutputType
>
174 constexpr bool is_valid_operation() {
175 return (Operation::BeginBinaryOperationsSingleOutput
< op
&&
176 op
< Operation::EndBinaryOperationsSingleOutput
&&
177 cpp::is_complex_type_same
<InputType
, OutputType
> &&
178 cpp::is_complex_v
<InputType
>) ||
179 (Operation::BeginUnaryOperationsSingleOutputSameOutputType
< op
&&
180 op
< Operation::EndUnaryOperationsSingleOutputSameOutputType
&&
181 cpp::is_complex_type_same
<InputType
, OutputType
> &&
182 cpp::is_complex_v
<InputType
>) ||
183 (Operation::BeginUnaryOperationsSingleOutputDifferentOutputType
< op
&&
184 op
< Operation::EndUnaryOperationsSingleOutputDifferentOutputType
&&
185 cpp::is_same_v
<make_real_t
<InputType
>, OutputType
> &&
186 cpp::is_complex_v
<InputType
>);
189 template <Operation op
, typename InputType
, typename OutputType
>
190 cpp::enable_if_t
<is_valid_operation
<op
, InputType
, OutputType
>(),
191 internal::MPCMatcher
<op
, InputType
, OutputType
>>
192 get_mpc_matcher(InputType input
, [[maybe_unused
]] OutputType output
,
193 double ulp_tolerance
, RoundingMode rounding
) {
194 return internal::MPCMatcher
<op
, InputType
, OutputType
>(input
, ulp_tolerance
,
199 } // namespace testing
200 } // namespace LIBC_NAMESPACE_DECL
// Checks, via EXPECT_THAT, that `match_value` (the libc result of operation
// `op` on `input`) is accepted by the MPC matcher within `ulp_tolerance`,
// using RoundingMode::Nearest. ASSERT_MPC_MATCH_DEFAULT is the ASSERT_THAT
// counterpart.
#define EXPECT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
  EXPECT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))
208 #define EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
210 EXPECT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \
211 input, match_value, ulp_tolerance, rounding))
213 #define EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \
214 ulp_tolerance, rounding) \
216 MPCRND::ForceRoundingMode __r(rounding); \
218 EXPECT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
223 #define EXPECT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \
225 namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \
226 for (int i = 0; i < 4; i++) { \
227 MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i); \
228 EXPECT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \
229 ulp_tolerance, r_mode); \
233 #define TEST_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
235 LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(input, match_value, \
236 ulp_tolerance, rounding) \
// Checks, via ASSERT_THAT, that `match_value` (the libc result of operation
// `op` on `input`) is accepted by the MPC matcher within `ulp_tolerance`,
// using RoundingMode::Nearest. EXPECT_MPC_MATCH_DEFAULT is the EXPECT_THAT
// counterpart.
#define ASSERT_MPC_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)        \
  ASSERT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>(               \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::fputil::testing::RoundingMode::Nearest))
245 #define ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
247 ASSERT_THAT(match_value, LIBC_NAMESPACE::testing::mpc::get_mpc_matcher<op>( \
248 input, match_value, ulp_tolerance, rounding))
250 #define ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \
251 ulp_tolerance, rounding) \
253 MPCRND::ForceRoundingMode __r(rounding); \
255 ASSERT_MPC_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
260 #define ASSERT_MPC_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \
262 namespace MPCRND = LIBC_NAMESPACE::fputil::testing; \
263 for (int i = 0; i < 4; i++) { \
264 MPCRND::RoundingMode r_mode = static_cast<MPCRND::RoundingMode>(i); \
265 ASSERT_MPC_MATCH_ALL_ROUNDING_HELPER(op, input, match_value, \
266 ulp_tolerance, r_mode); \
270 #endif // LLVM_LIBC_UTILS_MPCWRAPPER_MPCUTILS_H