1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
12 #include "sqrt_80_bit_long_double.h"
13 #include "src/__support/CPP/bit.h" // countl_zero
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/FEnvImpl.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/cast.h"
18 #include "src/__support/FPUtil/dyadic_float.h"
19 #include "src/__support/common.h"
20 #include "src/__support/macros/config.h"
21 #include "src/__support/uint128.h"
23 #include "hdr/fenv_macros.h"
25 namespace LIBC_NAMESPACE_DECL
{
// Trait consulted by the square-root routine in this file: its 80-bit long
// double branch is selected only when VALUE is true for both the output and
// the input type. The primary template sets VALUE to false.
30 template <typename T
> struct SpecialLongDouble
{
31 static constexpr bool VALUE
= false;
// Compiled only when LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80 is defined: the
// long double specialization overrides VALUE to true.
34 #if defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
35 template <> struct SpecialLongDouble
<long double> {
36 static constexpr bool VALUE
= true;
38 #endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
// Normalization helper used for subnormal inputs (it is called from the
// is_subnormal() branch of the square-root routine).
//
// Parameters (both taken by non-const reference):
//   exponent - exponent associated with `mantissa`.
//   mantissa - value held in FPBits<T>::StorageType.
//
// The visible expression evaluates
//   countl_zero(mantissa) - (8 * sizeof(mantissa) - 1 - FRACTION_LEN),
// i.e. how many bit positions the highest set bit of `mantissa` lies below
// bit index FPBits<T>::FRACTION_LEN.
// NOTE(review): the statements that store this value and apply it to
// `exponent` / `mantissa` are not visible in this excerpt - presumably it is
// used as a left-shift amount with a matching exponent decrement; confirm
// against the full file.
41 LIBC_INLINE
void normalize(int &exponent
,
42 typename FPBits
<T
>::StorageType
&mantissa
) {
44 cpp::countl_zero(mantissa
) -
// Number of storage bits located above bit index FRACTION_LEN.
45 (8 * static_cast<int>(sizeof(mantissa
)) - 1 - FPBits
<T
>::FRACTION_LEN
);
// long double variants of normalize, chosen by the long double format macros.
50 #ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64
// With LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64 defined, the long double variant
// takes a uint64_t mantissa and simply forwards to normalize<double>.
52 LIBC_INLINE
void normalize
<long double>(int &exponent
, uint64_t &mantissa
) {
53 normalize
<double>(exponent
, mantissa
);
55 #elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
// Otherwise, unless long double is the x86 80-bit format, the mantissa is a
// UInt128 and leading zeros are counted on its two 64-bit halves separately.
57 LIBC_INLINE
void normalize
<long double>(int &exponent
, UInt128
&mantissa
) {
// Upper 64 bits of the 128-bit mantissa.
58 const uint64_t hi_bits
= static_cast<uint64_t>(mantissa
>> 64);
// Non-zero upper half: leading zeros of the upper half minus 15.
// Zero upper half: leading zeros of the lower half plus 49.
// NOTE(review): presumably 15 = 128 - 1 - 112 (a 112-bit fraction) and
// 49 = 64 - 15, mirroring the generic normalize formula - confirm. The
// statement that receives this value is not visible in this excerpt.
60 hi_bits
? (cpp::countl_zero(hi_bits
) - 15)
61 : (cpp::countl_zero(static_cast<uint64_t>(mantissa
)) + 49);
67 } // namespace internal
69 // Correctly rounded IEEE 754 SQRT for all rounding modes.
70 // Shift-and-add algorithm.
// Participates in overload resolution only when OutType and InType are both
// floating point types and sizeof(OutType) <= sizeof(InType).
71 template <typename OutType
, typename InType
>
72 LIBC_INLINE
cpp::enable_if_t
<cpp::is_floating_point_v
<OutType
> &&
73 cpp::is_floating_point_v
<InType
> &&
74 sizeof(OutType
) <= sizeof(InType
),
// Compile-time dispatch: this branch is taken only when both types are flagged
// by internal::SpecialLongDouble.
// NOTE(review): the body of this branch is not visible in this excerpt.
77 if constexpr (internal::SpecialLongDouble
<OutType
>::VALUE
&&
78 internal::SpecialLongDouble
<InType
>::VALUE
) {
79 // Special 80-bit long double.
82 // IEEE floating point formats.
83 using OutFPBits
= FPBits
<OutType
>;
84 using InFPBits
= FPBits
<InType
>;
85 using InStorageType
= typename
InFPBits::StorageType
;
// DyadicFloat is instantiated with the input storage width passed through
// cpp::bit_ceil (presumably rounding up to a power of two, like
// std::bit_ceil - confirm).
87 DyadicFloat
<cpp::bit_ceil(static_cast<size_t>(InFPBits::STORAGE_LEN
))>;
// ONE has only bit FRACTION_LEN set, i.e. the hidden-bit position of the
// input format.
89 constexpr InStorageType ONE
= InStorageType(1) << InFPBits::FRACTION_LEN
;
// Quiet NaN in the output type.
90 constexpr auto FLT_NAN
= OutFPBits::quiet_nan().get_val();
// +Inf, zeros and NaNs: in the visible lines this case returns the input
// itself, cast to OutType.
94 if (bits
== InFPBits::inf(Sign::POS
) || bits
.is_zero() || bits
.is_nan()) {
100 return cast
<OutType
>(x
);
// Remaining negative inputs.
// NOTE(review): the body of this branch is not visible here; presumably it
// yields FLT_NAN - confirm against the full file.
101 } else if (bits
.is_neg()) {
// Split the input into its exponent and mantissa fields.
106 int x_exp
= bits
.get_exponent();
107 InStorageType x_mant
= bits
.get_mantissa();
109 // Step 1a: Normalize denormal input and append hidden bit to the mantissa
110 if (bits
.is_subnormal()) {
111 ++x_exp
; // let x_exp be the correct exponent of ONE bit.
112 internal::normalize
<InType
>(x_exp
, x_mant
);
117 // Step 1b: Make sure the exponent is even.
123 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
124 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
125 // Notice that the output of sqrt is always in the normal range.
126 // To perform the shift-and-add algorithm to find y, let us denote:
127 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
128 // r(n) = 2^n ( x_mant - y(n)^2 ).
129 // That leads to the following recurrence formula:
130 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
131 // with the initial conditions: y(0) = 1, and r(0) = x - 1.
132 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
133 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
// Initial conditions from the derivation above: y(0) = 1 and r(0) = x - 1,
// both expressed in units of ONE.
135 InStorageType y
= ONE
;
136 InStorageType r
= x_mant
- ONE
;
138 // TODO: Reduce iteration count to OutFPBits::FRACTION_LEN + 2 or + 3.
// The loop starts at the bit just below ONE and runs while current_bit is
// non-zero.
// NOTE(review): the loop increment and most of the loop body are not visible
// in this excerpt.
139 for (InStorageType current_bit
= ONE
>> 1; current_bit
;
142 // 2*y(n - 1) + 2^(-n-1)
143 InStorageType tmp
= static_cast<InStorageType
>((y
<< 1) + current_bit
);
150 // We compute one more iteration in order to round correctly.
153 InStorageType tmp
= y
+ 1;
// Set the lowest bit of y whenever the remainder r is non-zero, so a
// non-zero remainder stays recorded in the value handed to DyadicFloat
// (presumably so the final conversion can round/flag inexact results -
// confirm).
160 y
|= static_cast<unsigned int>(r
!= 0);
// Build a positive DyadicFloat from y, using half of x_exp lowered by
// (FRACTION_LEN + 2) as its exponent.
162 DyadicFloat
yd(Sign::POS
, (x_exp
>> 1) - 2 - InFPBits::FRACTION_LEN
, y
);
// Convert to OutType with ShouldSignalExceptions set to true.
163 return yd
.template as
<OutType
, /*ShouldSignalExceptions=*/true>();
168 } // namespace fputil
169 } // namespace LIBC_NAMESPACE_DECL
171 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H