//===-- Utility class to test sqrt[f|l] -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "src/__support/CPP/bit.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

#include <math.h>
16 namespace mpfr
= __llvm_libc::testing::mpfr
;
18 template <typename T
> class SqrtTest
: public __llvm_libc::testing::Test
{
20 DECLARE_SPECIAL_CONSTANTS(T
)
22 static constexpr UIntType HIDDEN_BIT
=
23 UIntType(1) << __llvm_libc::fputil::MantissaWidth
<T
>::VALUE
;
26 typedef T (*SqrtFunc
)(T
);
28 void test_special_numbers(SqrtFunc func
) {
29 ASSERT_FP_EQ(aNaN
, func(aNaN
));
30 ASSERT_FP_EQ(inf
, func(inf
));
31 ASSERT_FP_EQ(aNaN
, func(neg_inf
));
32 ASSERT_FP_EQ(0.0, func(0.0));
33 ASSERT_FP_EQ(-0.0, func(-0.0));
34 ASSERT_FP_EQ(aNaN
, func(T(-1.0)));
35 ASSERT_FP_EQ(T(1.0), func(T(1.0)));
36 ASSERT_FP_EQ(T(2.0), func(T(4.0)));
37 ASSERT_FP_EQ(T(3.0), func(T(9.0)));
40 void test_denormal_values(SqrtFunc func
) {
41 for (UIntType mant
= 1; mant
< HIDDEN_BIT
; mant
<<= 1) {
42 FPBits
denormal(T(0.0));
43 denormal
.set_mantissa(mant
);
45 test_all_rounding_modes(func
, T(denormal
));
48 constexpr UIntType COUNT
= 1'000'001;
49 constexpr UIntType STEP
= HIDDEN_BIT
/ COUNT
;
50 for (UIntType i
= 0, v
= 0; i
<= COUNT
; ++i
, v
+= STEP
) {
51 T x
= __llvm_libc::cpp::bit_cast
<T
>(v
);
52 test_all_rounding_modes(func
, x
);
56 void test_normal_range(SqrtFunc func
) {
57 constexpr UIntType COUNT
= 10'000'001;
58 constexpr UIntType STEP
= UIntType(-1) / COUNT
;
59 for (UIntType i
= 0, v
= 0; i
<= COUNT
; ++i
, v
+= STEP
) {
60 T x
= __llvm_libc::cpp::bit_cast
<T
>(v
);
61 if (isnan(x
) || (x
< 0)) {
64 test_all_rounding_modes(func
, x
);
68 void test_all_rounding_modes(SqrtFunc func
, T x
) {
69 mpfr::ForceRoundingMode
r1(mpfr::RoundingMode::Nearest
);
70 EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt
, x
, func(x
), 0.5,
71 mpfr::RoundingMode::Nearest
);
73 mpfr::ForceRoundingMode
r2(mpfr::RoundingMode::Upward
);
74 EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt
, x
, func(x
), 0.5,
75 mpfr::RoundingMode::Upward
);
77 mpfr::ForceRoundingMode
r3(mpfr::RoundingMode::Downward
);
78 EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt
, x
, func(x
), 0.5,
79 mpfr::RoundingMode::Downward
);
81 mpfr::ForceRoundingMode
r4(mpfr::RoundingMode::TowardZero
);
82 EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt
, x
, func(x
), 0.5,
83 mpfr::RoundingMode::TowardZero
);
// Instantiates the SqrtTest suite for floating-point type T against the
// function `func` (expected signature: T func(T)). It defines the test cases
// SpecialNumbers, DenormalValues and NormalRange under the fixture name
// LlvmLibcSqrtTest.
#define LIST_SQRT_TESTS(T, func)                                               \
  using LlvmLibcSqrtTest = SqrtTest<T>;                                        \
  TEST_F(LlvmLibcSqrtTest, SpecialNumbers) { test_special_numbers(&func); }    \
  TEST_F(LlvmLibcSqrtTest, DenormalValues) { test_denormal_values(&func); }    \
  TEST_F(LlvmLibcSqrtTest, NormalRange) { test_normal_range(&func); }