[libc] Add platform independent floating point rounding mode checks.
[llvm-project.git] / libc / src / __support / FPUtil / generic / sqrt.h
blob 9e9896ed185fe6f4236640a04932c79341f020ee
1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
10 #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H
12 #include "sqrt_80_bit_long_double.h"
13 #include "src/__support/CPP/bit.h"
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/FEnvImpl.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/PlatformDefs.h"
18 #include "src/__support/FPUtil/rounding_mode.h"
19 #include "src/__support/UInt128.h"
20 #include "src/__support/builtin_wrappers.h"
21 #include "src/__support/common.h"
23 namespace __llvm_libc {
24 namespace fputil {
26 namespace internal {
28 template <typename T> struct SpecialLongDouble {
29 static constexpr bool VALUE = false;
32 #if defined(SPECIAL_X86_LONG_DOUBLE)
33 template <> struct SpecialLongDouble<long double> {
34 static constexpr bool VALUE = true;
36 #endif // SPECIAL_X86_LONG_DOUBLE
38 template <typename T>
39 LIBC_INLINE void normalize(int &exponent,
40 typename FPBits<T>::UIntType &mantissa) {
41 const int shift = unsafe_clz(mantissa) -
42 (8 * sizeof(mantissa) - 1 - MantissaWidth<T>::VALUE);
43 exponent -= shift;
44 mantissa <<= shift;
47 #ifdef LONG_DOUBLE_IS_DOUBLE
48 template <>
49 LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) {
50 normalize<double>(exponent, mantissa);
52 #elif !defined(SPECIAL_X86_LONG_DOUBLE)
53 template <>
54 LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
55 const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64);
56 const int shift = hi_bits
57 ? (unsafe_clz(hi_bits) - 15)
58 : (unsafe_clz(static_cast<uint64_t>(mantissa)) + 49);
59 exponent -= shift;
60 mantissa <<= shift;
62 #endif
64 } // namespace internal
66 // Correctly rounded IEEE 754 SQRT for all rounding modes.
67 // Shift-and-add algorithm.
68 template <typename T>
69 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
71 if constexpr (internal::SpecialLongDouble<T>::VALUE) {
72 // Special 80-bit long double.
73 return x86::sqrt(x);
74 } else {
75 // IEEE floating points formats.
76 using UIntType = typename FPBits<T>::UIntType;
77 constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE;
79 FPBits<T> bits(x);
81 if (bits.is_inf_or_nan()) {
82 if (bits.get_sign() && (bits.get_mantissa() == 0)) {
83 // sqrt(-Inf) = NaN
84 return FPBits<T>::build_quiet_nan(ONE >> 1);
85 } else {
86 // sqrt(NaN) = NaN
87 // sqrt(+Inf) = +Inf
88 return x;
90 } else if (bits.is_zero()) {
91 // sqrt(+0) = +0
92 // sqrt(-0) = -0
93 return x;
94 } else if (bits.get_sign()) {
95 // sqrt( negative numbers ) = NaN
96 return FPBits<T>::build_quiet_nan(ONE >> 1);
97 } else {
98 int x_exp = bits.get_exponent();
99 UIntType x_mant = bits.get_mantissa();
101 // Step 1a: Normalize denormal input and append hidden bit to the mantissa
102 if (bits.get_unbiased_exponent() == 0) {
103 ++x_exp; // let x_exp be the correct exponent of ONE bit.
104 internal::normalize<T>(x_exp, x_mant);
105 } else {
106 x_mant |= ONE;
109 // Step 1b: Make sure the exponent is even.
110 if (x_exp & 1) {
111 --x_exp;
112 x_mant <<= 1;
115 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
116 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
117 // Notice that the output of sqrt is always in the normal range.
118 // To perform shift-and-add algorithm to find y, let denote:
119 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
120 // r(n) = 2^n ( x_mant - y(n)^2 ).
121 // That leads to the following recurrence formula:
122 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
123 // with the initial conditions: y(0) = 1, and r(0) = x - 1.
124 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
125 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
126 // 0 otherwise.
127 UIntType y = ONE;
128 UIntType r = x_mant - ONE;
130 for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
131 r <<= 1;
132 UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
133 if (r >= tmp) {
134 r -= tmp;
135 y += current_bit;
139 // We compute one more iteration in order to round correctly.
140 bool lsb = y & 1; // Least significant bit
141 bool rb = false; // Round bit
142 r <<= 2;
143 UIntType tmp = (y << 2) + 1;
144 if (r >= tmp) {
145 r -= tmp;
146 rb = true;
149 // Remove hidden bit and append the exponent field.
150 x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS);
152 y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE);
154 switch (quick_get_round()) {
155 case FE_TONEAREST:
156 // Round to nearest, ties to even
157 if (rb && (lsb || (r != 0)))
158 ++y;
159 break;
160 case FE_UPWARD:
161 if (rb || (r != 0))
162 ++y;
163 break;
166 return cpp::bit_cast<T>(y);
171 } // namespace fputil
172 } // namespace __llvm_libc
174 #endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_SQRT_H