1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
10 #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
12 #include "sqrt_80_bit_long_double.h"
13 #include "src/__support/CPP/bit.h"
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/FEnvImpl.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/PlatformDefs.h"
18 #include "src/__support/FPUtil/rounding_mode.h"
19 #include "src/__support/UInt128.h"
20 #include "src/__support/builtin_wrappers.h"
21 #include "src/__support/common.h"
23 namespace __llvm_libc
{
// Primary template of the SpecialLongDouble trait: VALUE is false for every
// type T unless a specialization overrides it. sqrt() in this file tests
// SpecialLongDouble<T>::VALUE with `if constexpr` to select its special
// 80-bit long double branch.
28 template <typename T
> struct SpecialLongDouble
{
29 static constexpr bool VALUE
= false;
// Specialization for long double, compiled only when SPECIAL_X86_LONG_DOUBLE
// is defined: it flips VALUE to true so that sqrt<long double> takes the
// branch marked "Special 80-bit long double".
32 #if defined(SPECIAL_X86_LONG_DOUBLE)
33 template <> struct SpecialLongDouble
<long double> {
34 static constexpr bool VALUE
= true;
36 #endif // SPECIAL_X86_LONG_DOUBLE
// Generic normalization helper, called by sqrt() for inputs whose unbiased
// exponent field is 0. Both arguments are taken by non-const reference.
//   exponent - exponent paired with the mantissa.
//   mantissa - mantissa bits, stored in FPBits<T>::UIntType.
// NOTE(review): only the computation of `shift` is visible in this excerpt;
// how it is applied to exponent/mantissa should be confirmed in the full file.
39 LIBC_INLINE
void normalize(int &exponent
,
40 typename FPBits
<T
>::UIntType
&mantissa
) {
// With W = 8 * sizeof(mantissa) bits, this is
//   unsafe_clz(mantissa) - (W - 1 - MantissaWidth<T>::VALUE).
// Assuming unsafe_clz returns the number of leading zero bits (TODO confirm;
// the name suggests a zero argument is not supported), that equals how many
// positions the highest set bit lies below bit index MantissaWidth<T>::VALUE.
41 const int shift
= unsafe_clz(mantissa
) -
42 (8 * sizeof(mantissa
) - 1 - MantissaWidth
<T
>::VALUE
);
// long double specialization used when LONG_DOUBLE_IS_DOUBLE is defined.
// The mantissa is a uint64_t here, and the call is forwarded unchanged to
// normalize<double>.
47 #ifdef LONG_DOUBLE_IS_DOUBLE
49 LIBC_INLINE
void normalize
<long double>(int &exponent
, uint64_t &mantissa
) {
50 normalize
<double>(exponent
, mantissa
);
// long double specialization used when neither LONG_DOUBLE_IS_DOUBLE nor
// SPECIAL_X86_LONG_DOUBLE is defined. The mantissa is a UInt128, so the
// leading-zero count is assembled from its two 64-bit halves.
52 #elif !defined(SPECIAL_X86_LONG_DOUBLE)
54 LIBC_INLINE
void normalize
<long double>(int &exponent
, UInt128
&mantissa
) {
// Upper 64 bits of the 128-bit mantissa.
55 const uint64_t hi_bits
= static_cast<uint64_t>(mantissa
>> 64);
// If the upper half is non-zero its leading-zero count is used directly;
// otherwise the count comes from the lower half plus the 64 zero bits of the
// upper half. In both cases 15 is subtracted (64 - 15 = 49).
// NOTE(review): 15 presumably equals 128 - 1 - MantissaWidth<long double>::VALUE,
// mirroring the formula in the generic normalize<T> above - TODO confirm.
56 const int shift
= hi_bits
57 ? (unsafe_clz(hi_bits
) - 15)
58 : (unsafe_clz(static_cast<uint64_t>(mantissa
)) + 49);
64 } // namespace internal
66 // Correctly rounded IEEE 754 SQRT for all rounding modes.
67 // Shift-and-add algorithm.
// The return type is cpp::enable_if_t<cpp::is_floating_point_v<T>, T>, so the
// function is only usable when T is a floating point type. The result is
// assembled as an integer bit pattern and converted to T with cpp::bit_cast.
69 LIBC_INLINE
cpp::enable_if_t
<cpp::is_floating_point_v
<T
>, T
> sqrt(T x
) {
71 if constexpr (internal::SpecialLongDouble
<T
>::VALUE
) {
72 // Special 80-bit long double.
75 // IEEE floating point formats.
76 using UIntType
= typename FPBits
<T
>::UIntType
;
// ONE is the bit at index MantissaWidth<T>::VALUE; the comments below treat
// it as the value 1 of the fixed-point mantissa (the hidden bit).
77 constexpr UIntType ONE
= UIntType(1) << MantissaWidth
<T
>::VALUE
;
// Special inputs are dispatched first. Of the branches visible here, an
// inf-or-NaN input with the sign bit set and a zero mantissa, and any other
// non-zero input with the sign bit set, yield a quiet NaN.
81 if (bits
.is_inf_or_nan()) {
82 if (bits
.get_sign() && (bits
.get_mantissa() == 0)) {
84 return FPBits
<T
>::build_quiet_nan(ONE
>> 1);
90 } else if (bits
.is_zero()) {
94 } else if (bits
.get_sign()) {
95 // sqrt( negative numbers ) = NaN
96 return FPBits
<T
>::build_quiet_nan(ONE
>> 1);
98 int x_exp
= bits
.get_exponent();
99 UIntType x_mant
= bits
.get_mantissa();
101 // Step 1a: Normalize denormal input and append hidden bit to the mantissa
102 if (bits
.get_unbiased_exponent() == 0) {
103 ++x_exp
; // let x_exp be the correct exponent of ONE bit.
104 internal::normalize
<T
>(x_exp
, x_mant
);
109 // Step 1b: Make sure the exponent is even.
115 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
116 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
117 // Notice that the output of sqrt is always in the normal range.
118 // To perform the shift-and-add algorithm to find y, let us denote:
119 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
120 // r(n) = 2^n ( x_mant - y(n)^2 ).
121 // That leads to the following recurrence formula:
122 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n) ]
123 // with the initial conditions: y(0) = 1, and r(0) = x_mant - 1.
124 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
125 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n)
// r(0) = x_mant - 1, expressed in units of ONE.
128 UIntType r
= x_mant
- ONE
;
// One pass per result bit: current_bit starts just below ONE and moves down
// to the lowest bit, so at iteration n it represents 2^(-n) in units of ONE.
130 for (UIntType current_bit
= ONE
>> 1; current_bit
; current_bit
>>= 1) {
132 UIntType tmp
= (y
<< 1) + current_bit
; // 2*y(n - 1) + 2^(-n)
139 // We compute one more iteration in order to round correctly.
140 bool lsb
= static_cast<bool>(y
& 1); // Least significant bit
141 bool rb
= false; // Round bit
// Trial value for the extra iteration, one bit position below the lowest
// stored bit of y (hence the shift by 2 and the added 1).
143 UIntType tmp
= (y
<< 2) + 1;
149 // Remove hidden bit and append the exponent field.
// The exponent of the result is half of the (even) input exponent, re-biased
// with FPBits<T>::EXPONENT_BIAS.
150 x_exp
= ((x_exp
>> 1) + FPBits
<T
>::EXPONENT_BIAS
);
// y - ONE clears the hidden bit; the biased exponent is placed directly above
// the MantissaWidth<T>::VALUE mantissa bits.
152 y
= (y
- ONE
) | (static_cast<UIntType
>(x_exp
) << MantissaWidth
<T
>::VALUE
);
154 switch (quick_get_round()) {
156 // Round to nearest, ties to even
// The condition holds when the round bit is set and either the residue is
// non-zero (the discarded part is more than half a unit in the last place)
// or the kept result is odd (an exact tie that must go to the even value).
157 if (rb
&& (lsb
|| (r
!= 0)))
// Reinterpret the assembled bit pattern as a value of type T.
166 return cpp::bit_cast
<T
>(y
);
171 } // namespace fputil
172 } // namespace __llvm_libc
174 #endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H