[MemProf] Templatize CallStackRadixTreeBuilder (NFC) (#117014)
[llvm-project.git] / libc / src / __support / FPUtil / generic / FMA.h
blobbec312e44b1b108786cf3426bbf1b660e9e0d4f0
1 //===-- Common header for FMA implementations -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
12 #include "src/__support/CPP/bit.h"
13 #include "src/__support/CPP/limits.h"
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/BasicOperations.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/cast.h"
18 #include "src/__support/FPUtil/dyadic_float.h"
19 #include "src/__support/FPUtil/rounding_mode.h"
20 #include "src/__support/big_int.h"
21 #include "src/__support/macros/attributes.h" // LIBC_INLINE
22 #include "src/__support/macros/config.h"
23 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
25 #include "hdr/fenv_macros.h"
27 namespace LIBC_NAMESPACE_DECL {
28 namespace fputil {
29 namespace generic {
31 template <typename OutType, typename InType>
32 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
33 cpp::is_floating_point_v<InType> &&
34 sizeof(OutType) <= sizeof(InType),
35 OutType>
36 fma(InType x, InType y, InType z);
38 // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
39 // The implementation below only is only correct for the default rounding mode,
40 // round-to-nearest tie-to-even.
41 template <> LIBC_INLINE float fma<float>(float x, float y, float z) {
42 // Product is exact.
43 double prod = static_cast<double>(x) * static_cast<double>(y);
44 double z_d = static_cast<double>(z);
45 double sum = prod + z_d;
46 fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
48 if (!(bit_sum.is_inf_or_nan() || bit_sum.is_zero())) {
49 // Since the sum is computed in double precision, rounding might happen
50 // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
51 // bit_prod.exponent > bitz.exponent + 40). In that case, when we round
52 // the sum back to float, double rounding error might occur.
53 // A concrete example of this phenomenon is as follows:
54 // x = y = 1 + 2^(-12), z = 2^(-53)
55 // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
56 // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
57 // On the other hand, with the default rounding mode,
58 // double(x*y + z) = 1 + 2^(-11) + 2^(-24)
59 // and casting again to float gives us:
60 // float(double(x*y + z)) = 1 + 2^(-11).
62 // In order to correct this possible double rounding error, first we use
63 // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
64 // assuming the (default) rounding mode is round-to-the-nearest,
65 // tie-to-even. Moreover, t satisfies the condition that t < eps(sum),
66 // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
67 // occurs when computing the sum, we just need to use t to adjust (any) last
68 // bit of sum, so that the sticky bits used when rounding sum to float are
69 // correct (when it matters).
70 fputil::FPBits<double> t(
71 (bit_prod.get_biased_exponent() >= bitz.get_biased_exponent())
72 ? ((bit_sum.get_val() - bit_prod.get_val()) - bitz.get_val())
73 : ((bit_sum.get_val() - bitz.get_val()) - bit_prod.get_val()));
75 // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
76 // zero.
77 if (!t.is_zero() && ((bit_sum.get_mantissa() & 0xfff'ffffULL) == 0)) {
78 if (bit_sum.sign() != t.sign())
79 bit_sum.set_mantissa(bit_sum.get_mantissa() + 1);
80 else if (bit_sum.get_mantissa())
81 bit_sum.set_mantissa(bit_sum.get_mantissa() - 1);
85 return static_cast<float>(bit_sum.get_val());
88 namespace internal {
90 // Extract the sticky bits and shift the `mantissa` to the right by
91 // `shift_length`.
92 template <typename T>
93 LIBC_INLINE cpp::enable_if_t<is_unsigned_integral_or_big_int_v<T>, bool>
94 shift_mantissa(int shift_length, T &mant) {
95 if (shift_length >= cpp::numeric_limits<T>::digits) {
96 mant = 0;
97 return true; // prod_mant is non-zero.
99 T mask = (T(1) << shift_length) - 1;
100 bool sticky_bits = (mant & mask) != 0;
101 mant >>= shift_length;
102 return sticky_bits;
105 } // namespace internal
107 template <typename OutType, typename InType>
108 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
109 cpp::is_floating_point_v<InType> &&
110 sizeof(OutType) <= sizeof(InType),
111 OutType>
112 fma(InType x, InType y, InType z) {
113 using OutFPBits = FPBits<OutType>;
114 using OutStorageType = typename OutFPBits::StorageType;
115 using InFPBits = FPBits<InType>;
116 using InStorageType = typename InFPBits::StorageType;
118 constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1;
119 constexpr size_t PROD_LEN = 2 * IN_EXPLICIT_MANT_LEN;
120 constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil(PROD_LEN + 1);
121 using TmpResultType = UInt<TMP_RESULT_LEN>;
122 using DyadicFloat = DyadicFloat<TMP_RESULT_LEN>;
124 InFPBits x_bits(x), y_bits(y), z_bits(z);
126 if (LIBC_UNLIKELY(x_bits.is_nan() || y_bits.is_nan() || z_bits.is_nan())) {
127 if (x_bits.is_nan() || y_bits.is_nan()) {
128 if (x_bits.is_signaling_nan() || y_bits.is_signaling_nan() ||
129 z_bits.is_signaling_nan())
130 raise_except_if_required(FE_INVALID);
132 if (x_bits.is_quiet_nan()) {
133 InStorageType x_payload = x_bits.get_mantissa();
134 x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
135 return OutFPBits::quiet_nan(x_bits.sign(),
136 static_cast<OutStorageType>(x_payload))
137 .get_val();
140 if (y_bits.is_quiet_nan()) {
141 InStorageType y_payload = y_bits.get_mantissa();
142 y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
143 return OutFPBits::quiet_nan(y_bits.sign(),
144 static_cast<OutStorageType>(y_payload))
145 .get_val();
148 if (z_bits.is_quiet_nan()) {
149 InStorageType z_payload = z_bits.get_mantissa();
150 z_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
151 return OutFPBits::quiet_nan(z_bits.sign(),
152 static_cast<OutStorageType>(z_payload))
153 .get_val();
156 return OutFPBits::quiet_nan().get_val();
160 if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0))
161 return cast<OutType>(x * y + z);
163 int x_exp = 0;
164 int y_exp = 0;
165 int z_exp = 0;
167 // Denormal scaling = 2^(fraction length).
168 constexpr InStorageType IMPLICIT_MASK =
169 InFPBits::SIG_MASK - InFPBits::FRACTION_MASK;
171 constexpr InType DENORMAL_SCALING =
172 InFPBits::create_value(
173 Sign::POS, InFPBits::FRACTION_LEN + InFPBits::EXP_BIAS, IMPLICIT_MASK)
174 .get_val();
176 // Normalize denormal inputs.
177 if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) {
178 x_exp -= InFPBits::FRACTION_LEN;
179 x *= DENORMAL_SCALING;
181 if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) {
182 y_exp -= InFPBits::FRACTION_LEN;
183 y *= DENORMAL_SCALING;
185 if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) {
186 z_exp -= InFPBits::FRACTION_LEN;
187 z *= DENORMAL_SCALING;
190 x_bits = InFPBits(x);
191 y_bits = InFPBits(y);
192 z_bits = InFPBits(z);
193 const Sign z_sign = z_bits.sign();
194 Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG;
195 x_exp += x_bits.get_biased_exponent();
196 y_exp += y_bits.get_biased_exponent();
197 z_exp += z_bits.get_biased_exponent();
199 if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT ||
200 y_exp == InFPBits::MAX_BIASED_EXPONENT ||
201 z_exp == InFPBits::MAX_BIASED_EXPONENT))
202 return cast<OutType>(x * y + z);
204 // Extract mantissa and append hidden leading bits.
205 InStorageType x_mant = x_bits.get_explicit_mantissa();
206 InStorageType y_mant = y_bits.get_explicit_mantissa();
207 TmpResultType z_mant = z_bits.get_explicit_mantissa();
209 // If the exponent of the product x*y > the exponent of z, then no extra
210 // precision beside the entire product x*y is needed. On the other hand, when
211 // the exponent of z >= the exponent of the product x*y, the worst-case that
212 // we need extra precision is when there is cancellation and the most
213 // significant bit of the product is aligned exactly with the second most
214 // significant bit of z:
215 // z : 10aa...a
216 // - prod : 1bb...bb....b
217 // In that case, in order to store the exact result, we need at least
218 // (Length of prod) - (Fraction length of z)
219 // = 2*(Length of input explicit mantissa) - (Fraction length of z) bits.
220 // Overall, before aligning the mantissas and exponents, we can simply left-
221 // shift the mantissa of z by that amount. After that, it is enough to align
222 // the least significant bit, given that we keep track of the round and sticky
223 // bits after the least significant bit.
225 TmpResultType prod_mant = TmpResultType(x_mant) * y_mant;
226 int prod_lsb_exp =
227 x_exp + y_exp - (InFPBits::EXP_BIAS + 2 * InFPBits::FRACTION_LEN);
229 constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN;
230 z_mant <<= RESULT_MIN_LEN;
231 int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN);
232 bool sticky_bits = false;
233 bool z_shifted = false;
235 // Align exponents.
236 if (prod_lsb_exp < z_lsb_exp) {
237 sticky_bits = internal::shift_mantissa(z_lsb_exp - prod_lsb_exp, prod_mant);
238 prod_lsb_exp = z_lsb_exp;
239 } else if (z_lsb_exp < prod_lsb_exp) {
240 z_shifted = true;
241 sticky_bits = internal::shift_mantissa(prod_lsb_exp - z_lsb_exp, z_mant);
244 // Perform the addition:
245 // (-1)^prod_sign * prod_mant + (-1)^z_sign * z_mant.
246 // The final result will be stored in prod_sign and prod_mant.
247 if (prod_sign == z_sign) {
248 // Effectively an addition.
249 prod_mant += z_mant;
250 } else {
251 // Subtraction cases.
252 if (prod_mant >= z_mant) {
253 if (z_shifted && sticky_bits) {
254 // Add 1 more to the subtrahend so that the sticky bits remain
255 // positive. This would simplify the rounding logic.
256 ++z_mant;
258 prod_mant -= z_mant;
259 } else {
260 if (!z_shifted && sticky_bits) {
261 // Add 1 more to the subtrahend so that the sticky bits remain
262 // positive. This would simplify the rounding logic.
263 ++prod_mant;
265 prod_mant = z_mant - prod_mant;
266 prod_sign = z_sign;
270 if (prod_mant == 0) {
271 // When there is exact cancellation, i.e., x*y == -z exactly, return -0.0 if
272 // rounding downward and +0.0 for other rounding modes.
273 if (quick_get_round() == FE_DOWNWARD)
274 prod_sign = Sign::NEG;
275 else
276 prod_sign = Sign::POS;
279 DyadicFloat result(prod_sign, prod_lsb_exp - InFPBits::EXP_BIAS, prod_mant);
280 result.mantissa |= static_cast<unsigned int>(sticky_bits);
281 return result.template as<OutType, /*ShouldSignalExceptions=*/true>();
284 } // namespace generic
285 } // namespace fputil
286 } // namespace LIBC_NAMESPACE_DECL
288 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H