[MemProf] Templatize CallStackRadixTreeBuilder (NFC) (#117014)
[llvm-project.git] / libc / src / __support / FPUtil / generic / sqrt.h
blob497ebd145c6b425d8925c764294539e68eebd3f3
1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
12 #include "sqrt_80_bit_long_double.h"
13 #include "src/__support/CPP/bit.h" // countl_zero
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/FEnvImpl.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/cast.h"
18 #include "src/__support/FPUtil/dyadic_float.h"
19 #include "src/__support/common.h"
20 #include "src/__support/macros/config.h"
21 #include "src/__support/uint128.h"
23 #include "hdr/fenv_macros.h"
25 namespace LIBC_NAMESPACE_DECL {
26 namespace fputil {
28 namespace internal {
// Compile-time trait selecting the dedicated 80-bit extended precision code
// path. The primary template answers "no" for every type.
template <typename T> struct SpecialLongDouble {
  static constexpr bool VALUE = false;
};

#if defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
// long double is flagged as special only when the build configuration defines
// LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80.
template <> struct SpecialLongDouble<long double> {
  static constexpr bool VALUE = true;
};
#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
40 template <typename T>
41 LIBC_INLINE void normalize(int &exponent,
42 typename FPBits<T>::StorageType &mantissa) {
43 const int shift =
44 cpp::countl_zero(mantissa) -
45 (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN);
46 exponent -= shift;
47 mantissa <<= shift;
50 #ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64
51 template <>
52 LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) {
53 normalize<double>(exponent, mantissa);
55 #elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
56 template <>
57 LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
58 const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64);
59 const int shift =
60 hi_bits ? (cpp::countl_zero(hi_bits) - 15)
61 : (cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49);
62 exponent -= shift;
63 mantissa <<= shift;
65 #endif
67 } // namespace internal
69 // Correctly rounded IEEE 754 SQRT for all rounding modes.
70 // Shift-and-add algorithm.
71 template <typename OutType, typename InType>
72 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
73 cpp::is_floating_point_v<InType> &&
74 sizeof(OutType) <= sizeof(InType),
75 OutType>
76 sqrt(InType x) {
77 if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
78 internal::SpecialLongDouble<InType>::VALUE) {
79 // Special 80-bit long double.
80 return x86::sqrt(x);
81 } else {
82 // IEEE floating points formats.
83 using OutFPBits = FPBits<OutType>;
84 using InFPBits = FPBits<InType>;
85 using InStorageType = typename InFPBits::StorageType;
86 using DyadicFloat =
87 DyadicFloat<cpp::bit_ceil(static_cast<size_t>(InFPBits::STORAGE_LEN))>;
89 constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
90 constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
92 InFPBits bits(x);
94 if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
95 // sqrt(+Inf) = +Inf
96 // sqrt(+0) = +0
97 // sqrt(-0) = -0
98 // sqrt(NaN) = NaN
99 // sqrt(-NaN) = -NaN
100 return cast<OutType>(x);
101 } else if (bits.is_neg()) {
102 // sqrt(-Inf) = NaN
103 // sqrt(-x) = NaN
104 return FLT_NAN;
105 } else {
106 int x_exp = bits.get_exponent();
107 InStorageType x_mant = bits.get_mantissa();
109 // Step 1a: Normalize denormal input and append hidden bit to the mantissa
110 if (bits.is_subnormal()) {
111 ++x_exp; // let x_exp be the correct exponent of ONE bit.
112 internal::normalize<InType>(x_exp, x_mant);
113 } else {
114 x_mant |= ONE;
117 // Step 1b: Make sure the exponent is even.
118 if (x_exp & 1) {
119 --x_exp;
120 x_mant <<= 1;
123 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
124 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
125 // Notice that the output of sqrt is always in the normal range.
126 // To perform shift-and-add algorithm to find y, let denote:
127 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
128 // r(n) = 2^n ( x_mant - y(n)^2 ).
129 // That leads to the following recurrence formula:
130 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
131 // with the initial conditions: y(0) = 1, and r(0) = x - 1.
132 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
133 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
134 // 0 otherwise.
135 InStorageType y = ONE;
136 InStorageType r = x_mant - ONE;
138 // TODO: Reduce iteration count to OutFPBits::FRACTION_LEN + 2 or + 3.
139 for (InStorageType current_bit = ONE >> 1; current_bit;
140 current_bit >>= 1) {
141 r <<= 1;
142 // 2*y(n - 1) + 2^(-n-1)
143 InStorageType tmp = static_cast<InStorageType>((y << 1) + current_bit);
144 if (r >= tmp) {
145 r -= tmp;
146 y += current_bit;
150 // We compute one more iteration in order to round correctly.
151 r <<= 2;
152 y <<= 2;
153 InStorageType tmp = y + 1;
154 if (r >= tmp) {
155 r -= tmp;
156 // Rounding bit.
157 y |= 2;
159 // Sticky bit.
160 y |= static_cast<unsigned int>(r != 0);
162 DyadicFloat yd(Sign::POS, (x_exp >> 1) - 2 - InFPBits::FRACTION_LEN, y);
163 return yd.template as<OutType, /*ShouldSignalExceptions=*/true>();
168 } // namespace fputil
169 } // namespace LIBC_NAMESPACE_DECL
171 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H