//===-- Common header for FMA implementations -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
12 #include "src/__support/CPP/bit.h"
13 #include "src/__support/CPP/limits.h"
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/BasicOperations.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/cast.h"
18 #include "src/__support/FPUtil/dyadic_float.h"
19 #include "src/__support/FPUtil/rounding_mode.h"
20 #include "src/__support/big_int.h"
21 #include "src/__support/macros/attributes.h" // LIBC_INLINE
22 #include "src/__support/macros/config.h"
23 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
25 #include "hdr/fenv_macros.h"
27 namespace LIBC_NAMESPACE_DECL
{
31 template <typename OutType
, typename InType
>
32 LIBC_INLINE
cpp::enable_if_t
<cpp::is_floating_point_v
<OutType
> &&
33 cpp::is_floating_point_v
<InType
> &&
34 sizeof(OutType
) <= sizeof(InType
),
36 fma(InType x
, InType y
, InType z
);
38 // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
39 // The implementation below only is only correct for the default rounding mode,
40 // round-to-nearest tie-to-even.
41 template <> LIBC_INLINE
float fma
<float>(float x
, float y
, float z
) {
43 double prod
= static_cast<double>(x
) * static_cast<double>(y
);
44 double z_d
= static_cast<double>(z
);
45 double sum
= prod
+ z_d
;
46 fputil::FPBits
<double> bit_prod(prod
), bitz(z_d
), bit_sum(sum
);
48 if (!(bit_sum
.is_inf_or_nan() || bit_sum
.is_zero())) {
49 // Since the sum is computed in double precision, rounding might happen
50 // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
51 // bit_prod.exponent > bitz.exponent + 40). In that case, when we round
52 // the sum back to float, double rounding error might occur.
53 // A concrete example of this phenomenon is as follows:
54 // x = y = 1 + 2^(-12), z = 2^(-53)
55 // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
56 // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
57 // On the other hand, with the default rounding mode,
58 // double(x*y + z) = 1 + 2^(-11) + 2^(-24)
59 // and casting again to float gives us:
60 // float(double(x*y + z)) = 1 + 2^(-11).
62 // In order to correct this possible double rounding error, first we use
63 // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
64 // assuming the (default) rounding mode is round-to-the-nearest,
65 // tie-to-even. Moreover, t satisfies the condition that t < eps(sum),
66 // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
67 // occurs when computing the sum, we just need to use t to adjust (any) last
68 // bit of sum, so that the sticky bits used when rounding sum to float are
69 // correct (when it matters).
70 fputil::FPBits
<double> t(
71 (bit_prod
.get_biased_exponent() >= bitz
.get_biased_exponent())
72 ? ((bit_sum
.get_val() - bit_prod
.get_val()) - bitz
.get_val())
73 : ((bit_sum
.get_val() - bitz
.get_val()) - bit_prod
.get_val()));
75 // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
77 if (!t
.is_zero() && ((bit_sum
.get_mantissa() & 0xfff'ffffULL
) == 0)) {
78 if (bit_sum
.sign() != t
.sign())
79 bit_sum
.set_mantissa(bit_sum
.get_mantissa() + 1);
80 else if (bit_sum
.get_mantissa())
81 bit_sum
.set_mantissa(bit_sum
.get_mantissa() - 1);
85 return static_cast<float>(bit_sum
.get_val());
90 // Extract the sticky bits and shift the `mantissa` to the right by
93 LIBC_INLINE
cpp::enable_if_t
<is_unsigned_integral_or_big_int_v
<T
>, bool>
94 shift_mantissa(int shift_length
, T
&mant
) {
95 if (shift_length
>= cpp::numeric_limits
<T
>::digits
) {
97 return true; // prod_mant is non-zero.
99 T mask
= (T(1) << shift_length
) - 1;
100 bool sticky_bits
= (mant
& mask
) != 0;
101 mant
>>= shift_length
;
105 } // namespace internal
107 template <typename OutType
, typename InType
>
108 LIBC_INLINE
cpp::enable_if_t
<cpp::is_floating_point_v
<OutType
> &&
109 cpp::is_floating_point_v
<InType
> &&
110 sizeof(OutType
) <= sizeof(InType
),
112 fma(InType x
, InType y
, InType z
) {
113 using OutFPBits
= FPBits
<OutType
>;
114 using OutStorageType
= typename
OutFPBits::StorageType
;
115 using InFPBits
= FPBits
<InType
>;
116 using InStorageType
= typename
InFPBits::StorageType
;
118 constexpr int IN_EXPLICIT_MANT_LEN
= InFPBits::FRACTION_LEN
+ 1;
119 constexpr size_t PROD_LEN
= 2 * IN_EXPLICIT_MANT_LEN
;
120 constexpr size_t TMP_RESULT_LEN
= cpp::bit_ceil(PROD_LEN
+ 1);
121 using TmpResultType
= UInt
<TMP_RESULT_LEN
>;
122 using DyadicFloat
= DyadicFloat
<TMP_RESULT_LEN
>;
124 InFPBits
x_bits(x
), y_bits(y
), z_bits(z
);
126 if (LIBC_UNLIKELY(x_bits
.is_nan() || y_bits
.is_nan() || z_bits
.is_nan())) {
127 if (x_bits
.is_nan() || y_bits
.is_nan()) {
128 if (x_bits
.is_signaling_nan() || y_bits
.is_signaling_nan() ||
129 z_bits
.is_signaling_nan())
130 raise_except_if_required(FE_INVALID
);
132 if (x_bits
.is_quiet_nan()) {
133 InStorageType x_payload
= x_bits
.get_mantissa();
134 x_payload
>>= InFPBits::FRACTION_LEN
- OutFPBits::FRACTION_LEN
;
135 return OutFPBits::quiet_nan(x_bits
.sign(),
136 static_cast<OutStorageType
>(x_payload
))
140 if (y_bits
.is_quiet_nan()) {
141 InStorageType y_payload
= y_bits
.get_mantissa();
142 y_payload
>>= InFPBits::FRACTION_LEN
- OutFPBits::FRACTION_LEN
;
143 return OutFPBits::quiet_nan(y_bits
.sign(),
144 static_cast<OutStorageType
>(y_payload
))
148 if (z_bits
.is_quiet_nan()) {
149 InStorageType z_payload
= z_bits
.get_mantissa();
150 z_payload
>>= InFPBits::FRACTION_LEN
- OutFPBits::FRACTION_LEN
;
151 return OutFPBits::quiet_nan(z_bits
.sign(),
152 static_cast<OutStorageType
>(z_payload
))
156 return OutFPBits::quiet_nan().get_val();
160 if (LIBC_UNLIKELY(x
== 0 || y
== 0 || z
== 0))
161 return cast
<OutType
>(x
* y
+ z
);
167 // Denormal scaling = 2^(fraction length).
168 constexpr InStorageType IMPLICIT_MASK
=
169 InFPBits::SIG_MASK
- InFPBits::FRACTION_MASK
;
171 constexpr InType DENORMAL_SCALING
=
172 InFPBits::create_value(
173 Sign::POS
, InFPBits::FRACTION_LEN
+ InFPBits::EXP_BIAS
, IMPLICIT_MASK
)
176 // Normalize denormal inputs.
177 if (LIBC_UNLIKELY(InFPBits(x
).is_subnormal())) {
178 x_exp
-= InFPBits::FRACTION_LEN
;
179 x
*= DENORMAL_SCALING
;
181 if (LIBC_UNLIKELY(InFPBits(y
).is_subnormal())) {
182 y_exp
-= InFPBits::FRACTION_LEN
;
183 y
*= DENORMAL_SCALING
;
185 if (LIBC_UNLIKELY(InFPBits(z
).is_subnormal())) {
186 z_exp
-= InFPBits::FRACTION_LEN
;
187 z
*= DENORMAL_SCALING
;
190 x_bits
= InFPBits(x
);
191 y_bits
= InFPBits(y
);
192 z_bits
= InFPBits(z
);
193 const Sign z_sign
= z_bits
.sign();
194 Sign prod_sign
= (x_bits
.sign() == y_bits
.sign()) ? Sign::POS
: Sign::NEG
;
195 x_exp
+= x_bits
.get_biased_exponent();
196 y_exp
+= y_bits
.get_biased_exponent();
197 z_exp
+= z_bits
.get_biased_exponent();
199 if (LIBC_UNLIKELY(x_exp
== InFPBits::MAX_BIASED_EXPONENT
||
200 y_exp
== InFPBits::MAX_BIASED_EXPONENT
||
201 z_exp
== InFPBits::MAX_BIASED_EXPONENT
))
202 return cast
<OutType
>(x
* y
+ z
);
204 // Extract mantissa and append hidden leading bits.
205 InStorageType x_mant
= x_bits
.get_explicit_mantissa();
206 InStorageType y_mant
= y_bits
.get_explicit_mantissa();
207 TmpResultType z_mant
= z_bits
.get_explicit_mantissa();
209 // If the exponent of the product x*y > the exponent of z, then no extra
210 // precision beside the entire product x*y is needed. On the other hand, when
211 // the exponent of z >= the exponent of the product x*y, the worst-case that
212 // we need extra precision is when there is cancellation and the most
213 // significant bit of the product is aligned exactly with the second most
214 // significant bit of z:
216 // - prod : 1bb...bb....b
217 // In that case, in order to store the exact result, we need at least
218 // (Length of prod) - (Fraction length of z)
219 // = 2*(Length of input explicit mantissa) - (Fraction length of z) bits.
220 // Overall, before aligning the mantissas and exponents, we can simply left-
221 // shift the mantissa of z by that amount. After that, it is enough to align
222 // the least significant bit, given that we keep track of the round and sticky
223 // bits after the least significant bit.
225 TmpResultType prod_mant
= TmpResultType(x_mant
) * y_mant
;
227 x_exp
+ y_exp
- (InFPBits::EXP_BIAS
+ 2 * InFPBits::FRACTION_LEN
);
229 constexpr int RESULT_MIN_LEN
= PROD_LEN
- InFPBits::FRACTION_LEN
;
230 z_mant
<<= RESULT_MIN_LEN
;
231 int z_lsb_exp
= z_exp
- (InFPBits::FRACTION_LEN
+ RESULT_MIN_LEN
);
232 bool sticky_bits
= false;
233 bool z_shifted
= false;
236 if (prod_lsb_exp
< z_lsb_exp
) {
237 sticky_bits
= internal::shift_mantissa(z_lsb_exp
- prod_lsb_exp
, prod_mant
);
238 prod_lsb_exp
= z_lsb_exp
;
239 } else if (z_lsb_exp
< prod_lsb_exp
) {
241 sticky_bits
= internal::shift_mantissa(prod_lsb_exp
- z_lsb_exp
, z_mant
);
244 // Perform the addition:
245 // (-1)^prod_sign * prod_mant + (-1)^z_sign * z_mant.
246 // The final result will be stored in prod_sign and prod_mant.
247 if (prod_sign
== z_sign
) {
248 // Effectively an addition.
251 // Subtraction cases.
252 if (prod_mant
>= z_mant
) {
253 if (z_shifted
&& sticky_bits
) {
254 // Add 1 more to the subtrahend so that the sticky bits remain
255 // positive. This would simplify the rounding logic.
260 if (!z_shifted
&& sticky_bits
) {
261 // Add 1 more to the subtrahend so that the sticky bits remain
262 // positive. This would simplify the rounding logic.
265 prod_mant
= z_mant
- prod_mant
;
270 if (prod_mant
== 0) {
271 // When there is exact cancellation, i.e., x*y == -z exactly, return -0.0 if
272 // rounding downward and +0.0 for other rounding modes.
273 if (quick_get_round() == FE_DOWNWARD
)
274 prod_sign
= Sign::NEG
;
276 prod_sign
= Sign::POS
;
279 DyadicFloat
result(prod_sign
, prod_lsb_exp
- InFPBits::EXP_BIAS
, prod_mant
);
280 result
.mantissa
|= static_cast<unsigned int>(sticky_bits
);
281 return result
.template as
<OutType
, /*ShouldSignalExceptions=*/true>();
284 } // namespace generic
285 } // namespace fputil
286 } // namespace LIBC_NAMESPACE_DECL
288 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H