[libc] Switch to using the generic `<gpuintrin.h>` implementations (#121810)
[llvm-project.git] / libc / utils / MPFRWrapper / MPFRUtils.h
blob c7a57819f68b79c0524632fa40832b52a90cb3ec
1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
10 #define LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
12 #include "src/__support/CPP/type_traits.h"
13 #include "src/__support/macros/config.h"
14 #include "test/UnitTest/RoundingModeUtils.h"
15 #include "test/UnitTest/Test.h"
17 #include <stdint.h>
19 namespace LIBC_NAMESPACE_DECL {
20 namespace testing {
21 namespace mpfr {
// Operations whose libc results can be checked against MPFR.
//
// Each Begin*/End* pair is a sentinel bracketing one category of operations.
// is_valid_operation() tests category membership with strict comparisons
// (Begin < op < End), so the sentinels are not operations themselves and the
// order of the enumerators is significant.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind (Sqrt may
  // additionally produce a floating point type no wider than its input).
  BeginUnaryOperationsSingleOutput,
  Abs,
  Acos,
  Acosh,
  Asin,
  Asinh,
  Atan,
  Atanh,
  Cbrt,
  Ceil,
  Cos,
  Cosh,
  Cospi,
  Erf,
  Exp,
  Exp2,
  Exp2m1,
  Exp10,
  Exp10m1,
  Expm1,
  Floor,
  Log,
  Log2,
  Log10,
  Log1p,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  RoundEven,
  Sin,
  Sinpi,
  Sinh,
  Sqrt,
  Tan,
  Tanh,
  Tanpi,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // Floating point output, the first output, is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number as output. The output
  // is usually of the same type as the inputs, but any floating point output
  // type is accepted.
  BeginBinaryOperationsSingleOutput,
  Add,
  Atan2,
  Div,
  Fmod,
  Hypot,
  Mul,
  Pow,
  Sub,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point
  // number of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The first output, the floating point output, is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number as output (of the same
  // type as the inputs, or any floating point type for Fma).
  // NOTE: the misspelling in `BeginTernaryOperationsSingleOuput` is kept
  // because existing code refers to the enumerator by that name.
  BeginTernaryOperationsSingleOuput,
  Fma,
  EndTernaryOperationsSingleOutput,
};
104 using LIBC_NAMESPACE::fputil::testing::ForceRoundingMode;
105 using LIBC_NAMESPACE::fputil::testing::RoundingMode;
107 template <typename T> struct BinaryInput {
108 static_assert(
109 LIBC_NAMESPACE::cpp::is_floating_point_v<T>,
110 "Template parameter of BinaryInput must be a floating point type.");
112 using Type = T;
113 T x, y;
116 template <typename T> struct TernaryInput {
117 static_assert(
118 LIBC_NAMESPACE::cpp::is_floating_point_v<T>,
119 "Template parameter of TernaryInput must be a floating point type.");
121 using Type = T;
122 T x, y, z;
// Result pair for the two-output operation categories (e.g. Frexp, RemQuo):
// a floating point value plus an 'int'.
template <typename T> struct BinaryOutput {
  T f;   // Floating point output (the first output).
  int i; // Integer output (the second output).
};
130 namespace internal {
132 template <typename T1, typename T2>
133 struct AreMatchingBinaryInputAndBinaryOutput {
134 static constexpr bool VALUE = false;
137 template <typename T>
138 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
139 static constexpr bool VALUE = cpp::is_floating_point_v<T>;
142 template <typename T> struct IsBinaryInput {
143 static constexpr bool VALUE = false;
146 template <typename T> struct IsBinaryInput<BinaryInput<T>> {
147 static constexpr bool VALUE = true;
150 template <typename T> struct IsTernaryInput {
151 static constexpr bool VALUE = false;
154 template <typename T> struct IsTernaryInput<TernaryInput<T>> {
155 static constexpr bool VALUE = true;
158 template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
160 template <typename T>
161 struct MakeScalarInput<BinaryInput<T>> : cpp::type_identity<T> {};
163 template <typename T>
164 struct MakeScalarInput<TernaryInput<T>> : cpp::type_identity<T> {};
166 template <typename InputType, typename OutputType>
167 bool compare_unary_operation_single_output(Operation op, InputType input,
168 OutputType libc_output,
169 double ulp_tolerance,
170 RoundingMode rounding);
171 template <typename T>
172 bool compare_unary_operation_two_outputs(Operation op, T input,
173 const BinaryOutput<T> &libc_output,
174 double ulp_tolerance,
175 RoundingMode rounding);
176 template <typename T>
177 bool compare_binary_operation_two_outputs(Operation op,
178 const BinaryInput<T> &input,
179 const BinaryOutput<T> &libc_output,
180 double ulp_tolerance,
181 RoundingMode rounding);
183 template <typename InputType, typename OutputType>
184 bool compare_binary_operation_one_output(Operation op,
185 const BinaryInput<InputType> &input,
186 OutputType libc_output,
187 double ulp_tolerance,
188 RoundingMode rounding);
190 template <typename InputType, typename OutputType>
191 bool compare_ternary_operation_one_output(Operation op,
192 const TernaryInput<InputType> &input,
193 OutputType libc_output,
194 double ulp_tolerance,
195 RoundingMode rounding);
197 template <typename InputType, typename OutputType>
198 void explain_unary_operation_single_output_error(Operation op, InputType input,
199 OutputType match_value,
200 double ulp_tolerance,
201 RoundingMode rounding);
202 template <typename T>
203 void explain_unary_operation_two_outputs_error(
204 Operation op, T input, const BinaryOutput<T> &match_value,
205 double ulp_tolerance, RoundingMode rounding);
206 template <typename T>
207 void explain_binary_operation_two_outputs_error(
208 Operation op, const BinaryInput<T> &input,
209 const BinaryOutput<T> &match_value, double ulp_tolerance,
210 RoundingMode rounding);
212 template <typename InputType, typename OutputType>
213 void explain_binary_operation_one_output_error(
214 Operation op, const BinaryInput<InputType> &input, OutputType match_value,
215 double ulp_tolerance, RoundingMode rounding);
217 template <typename InputType, typename OutputType>
218 void explain_ternary_operation_one_output_error(
219 Operation op, const TernaryInput<InputType> &input, OutputType match_value,
220 double ulp_tolerance, RoundingMode rounding);
222 template <Operation op, bool silent, typename InputType, typename OutputType>
223 class MPFRMatcher : public testing::Matcher<OutputType> {
224 InputType input;
225 OutputType match_value;
226 double ulp_tolerance;
227 RoundingMode rounding;
229 public:
230 MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
231 : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}
233 bool match(OutputType libcResult) {
234 match_value = libcResult;
235 return match(input, match_value);
238 // This method is marked with NOLINT because the name `explainError` does not
239 // conform to the coding style.
240 void explainError() override { // NOLINT
241 explain_error(input, match_value);
244 // Whether the `explainError` step is skipped or not.
245 bool is_silent() const override { return silent; }
247 private:
248 template <typename InType, typename OutType>
249 bool match(InType in, OutType out) {
250 return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
251 rounding);
254 template <typename T> bool match(T in, const BinaryOutput<T> &out) {
255 return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
256 rounding);
259 template <typename T, typename U>
260 bool match(const BinaryInput<T> &in, U out) {
261 return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
262 rounding);
265 template <typename T>
266 bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
267 return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
268 rounding);
271 template <typename InType, typename OutType>
272 bool match(const TernaryInput<InType> &in, OutType out) {
273 return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
274 rounding);
277 template <typename InType, typename OutType>
278 void explain_error(InType in, OutType out) {
279 explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
280 rounding);
283 template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
284 explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
285 rounding);
288 template <typename T>
289 void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out) {
290 explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
291 rounding);
294 template <typename T, typename U>
295 void explain_error(const BinaryInput<T> &in, U out) {
296 explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
297 rounding);
300 template <typename InType, typename OutType>
301 void explain_error(const TernaryInput<InType> &in, OutType out) {
302 explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
303 rounding);
307 } // namespace internal
309 // Return true if the input and ouput types for the operation op are valid
310 // types.
311 template <Operation op, typename InputType, typename OutputType>
312 constexpr bool is_valid_operation() {
313 constexpr bool IS_NARROWING_OP =
314 (op == Operation::Sqrt && cpp::is_floating_point_v<InputType> &&
315 cpp::is_floating_point_v<OutputType> &&
316 sizeof(OutputType) <= sizeof(InputType)) ||
317 (Operation::BeginBinaryOperationsSingleOutput < op &&
318 op < Operation::EndBinaryOperationsSingleOutput &&
319 internal::IsBinaryInput<InputType>::VALUE &&
320 cpp::is_floating_point_v<
321 typename internal::MakeScalarInput<InputType>::type> &&
322 cpp::is_floating_point_v<OutputType>) ||
323 (op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
324 cpp::is_floating_point_v<
325 typename internal::MakeScalarInput<InputType>::type> &&
326 cpp::is_floating_point_v<OutputType>);
327 if (IS_NARROWING_OP)
328 return true;
329 return (Operation::BeginUnaryOperationsSingleOutput < op &&
330 op < Operation::EndUnaryOperationsSingleOutput &&
331 cpp::is_same_v<InputType, OutputType> &&
332 cpp::is_floating_point_v<InputType>) ||
333 (Operation::BeginUnaryOperationsTwoOutputs < op &&
334 op < Operation::EndUnaryOperationsTwoOutputs &&
335 cpp::is_floating_point_v<InputType> &&
336 cpp::is_same_v<OutputType, BinaryOutput<InputType>>) ||
337 (Operation::BeginBinaryOperationsSingleOutput < op &&
338 op < Operation::EndBinaryOperationsSingleOutput &&
339 cpp::is_floating_point_v<OutputType> &&
340 cpp::is_same_v<InputType, BinaryInput<OutputType>>) ||
341 (Operation::BeginBinaryOperationsTwoOutputs < op &&
342 op < Operation::EndBinaryOperationsTwoOutputs &&
343 internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
344 OutputType>::VALUE) ||
345 (Operation::BeginTernaryOperationsSingleOuput < op &&
346 op < Operation::EndTernaryOperationsSingleOutput &&
347 cpp::is_floating_point_v<OutputType> &&
348 cpp::is_same_v<InputType, TernaryInput<OutputType>>);
351 template <Operation op, typename InputType, typename OutputType>
352 __attribute__((no_sanitize("address"))) cpp::enable_if_t<
353 is_valid_operation<op, InputType, OutputType>(),
354 internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>>
355 get_mpfr_matcher(InputType input, OutputType output_unused,
356 double ulp_tolerance, RoundingMode rounding) {
357 return internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>(
358 input, ulp_tolerance, rounding);
361 template <Operation op, typename InputType, typename OutputType>
362 __attribute__((no_sanitize("address"))) cpp::enable_if_t<
363 is_valid_operation<op, InputType, OutputType>(),
364 internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>>
365 get_silent_mpfr_matcher(InputType input, OutputType output_unused,
366 double ulp_tolerance, RoundingMode rounding) {
367 return internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>(
368 input, ulp_tolerance, rounding);
371 template <typename T> T round(T x, RoundingMode mode);
373 template <typename T> bool round_to_long(T x, long &result);
374 template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
376 } // namespace mpfr
377 } // namespace testing
378 } // namespace LIBC_NAMESPACE_DECL
// Argument-count dispatch: GET_MPFR_MACRO(a1, ..., aK, IMPL5, IMPL4,
// GET_MPFR_DUMMY_ARG) expands to its sixth argument, which is IMPL5 when
// K == 5 and IMPL4 when K == 4. GET_MPFR_DUMMY_ARG is passed last only so the
// variadic part of GET_MPFR_MACRO is never empty, which avoids the compiler
// warning `gnu-zero-variadic-macro-arguments`.
#define GET_MPFR_MACRO(A1, A2, A3, A4, A5, SELECTED, ...) SELECTED
#define GET_MPFR_DUMMY_ARG(...) 0
// EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding]) checks
// `match_value` with EXPECT_THAT against the MPFR matcher for `op`. The
// 4-argument form uses RoundingMode::Nearest; GET_MPFR_MACRO picks the
// implementation from the argument count.
#define EXPECT_MPFR_MATCH_ROUNDING(OP, INPUT, RESULT, TOLERANCE,               \
                                   ROUNDING_MODE)                              \
  EXPECT_THAT(RESULT, LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(     \
                          INPUT, RESULT, TOLERANCE, ROUNDING_MODE))

#define EXPECT_MPFR_MATCH_DEFAULT(OP, INPUT, RESULT, TOLERANCE)                \
  EXPECT_THAT(RESULT, LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(     \
                          INPUT, RESULT, TOLERANCE,                            \
                          LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest))

#define EXPECT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)(__VA_ARGS__)
// TEST_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, rounding)
// evaluates to the bool returned by the MPFR matcher's match(), without going
// through EXPECT_THAT / ASSERT_THAT.
#define TEST_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,        \
                                 rounding)                                     \
  LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(input, match_value,      \
                                                      ulp_tolerance, rounding) \
      .match(match_value)

// Round-to-nearest form of TEST_MPFR_MATCH_ROUNDING; also a bool expression.
#define TEST_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)         \
  TEST_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,              \
                           LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest)

// TEST_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding]).
// Both forms dispatch to a TEST_MPFR_MATCH_* macro, so the result is always a
// bool expression. The 4-argument form must not fall back to
// EXPECT_MPFR_MATCH_DEFAULT, which expands to an EXPECT_THAT statement.
#define TEST_MPFR_MATCH(...)                                                   \
  GET_MPFR_MACRO(__VA_ARGS__, TEST_MPFR_MATCH_ROUNDING,                        \
                 TEST_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                  \
  (__VA_ARGS__)
// Checks `match_value` with EXPECT_MPFR_MATCH under each of the four rounding
// modes (nearest, upward, downward, toward zero). Each mode is forced by a
// ForceRoundingMode guard scoped to that single check, and the check is
// skipped when the guard reports `success == false`. The guard name avoids a
// leading double underscore, since such identifiers are reserved.
#define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::Nearest);                                        \
      if (libc_mpfr_rounding_guard.success) {                                  \
        EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::Nearest);                        \
      }                                                                        \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::Upward);                                         \
      if (libc_mpfr_rounding_guard.success) {                                  \
        EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::Upward);                         \
      }                                                                        \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::Downward);                                       \
      if (libc_mpfr_rounding_guard.success) {                                  \
        EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::Downward);                       \
      }                                                                        \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::TowardZero);                                     \
      if (libc_mpfr_rounding_guard.success) {                                  \
        EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::TowardZero);                     \
      }                                                                        \
    }                                                                          \
  }
// Evaluates to the bool returned by match() on a silent MPFR matcher (one
// whose is_silent() is true) for `op` on `input` under the given rounding mode.
#define TEST_MPFR_MATCH_ROUNDING_SILENTLY(OP, INPUT, RESULT, TOLERANCE,        \
                                          ROUNDING_MODE)                       \
  (LIBC_NAMESPACE::testing::mpfr::get_silent_mpfr_matcher<OP>(                 \
       INPUT, RESULT, TOLERANCE, ROUNDING_MODE)                                \
       .match(RESULT))
// ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance[, rounding]) checks
// `match_value` with ASSERT_THAT against the MPFR matcher for `op`. The
// 4-argument form uses RoundingMode::Nearest; GET_MPFR_MACRO picks the
// implementation from the argument count.
#define ASSERT_MPFR_MATCH_ROUNDING(OP, INPUT, RESULT, TOLERANCE,               \
                                   ROUNDING_MODE)                              \
  ASSERT_THAT(RESULT, LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(     \
                          INPUT, RESULT, TOLERANCE, ROUNDING_MODE))

#define ASSERT_MPFR_MATCH_DEFAULT(OP, INPUT, RESULT, TOLERANCE)                \
  ASSERT_THAT(RESULT, LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<OP>(     \
                          INPUT, RESULT, TOLERANCE,                            \
                          LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest))

#define ASSERT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)(__VA_ARGS__)
// Checks `match_value` with ASSERT_MPFR_MATCH under each of the four rounding
// modes (nearest, upward, downward, toward zero). Each mode is forced by a
// ForceRoundingMode guard scoped to that single check, and the check is
// skipped when the guard reports `success == false`. The guard name avoids a
// leading double underscore, since such identifiers are reserved.
#define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::Nearest);                                        \
      if (libc_mpfr_rounding_guard.success) {                                  \
        ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::Nearest);                        \
      }                                                                        \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::Upward);                                         \
      if (libc_mpfr_rounding_guard.success) {                                  \
        ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::Upward);                         \
      }                                                                        \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::Downward);                                       \
      if (libc_mpfr_rounding_guard.success) {                                  \
        ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::Downward);                       \
      }                                                                        \
    }                                                                          \
    {                                                                          \
      mpfr::ForceRoundingMode libc_mpfr_rounding_guard(                        \
          mpfr::RoundingMode::TowardZero);                                     \
      if (libc_mpfr_rounding_guard.success) {                                  \
        ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,               \
                          mpfr::RoundingMode::TowardZero);                     \
      }                                                                        \
    }                                                                          \
  }
487 #endif // LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H