1 //===-- Common header for FMA implementations -------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_FMA_H
10 #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_FMA_H
12 #include "src/__support/CPP/type_traits.h"
13 #include "src/__support/FPUtil/FEnvImpl.h"
14 #include "src/__support/FPUtil/FPBits.h"
15 #include "src/__support/FPUtil/FloatProperties.h"
16 #include "src/__support/FPUtil/rounding_mode.h"
17 #include "src/__support/UInt128.h"
18 #include "src/__support/builtin_wrappers.h"
19 #include "src/__support/macros/attributes.h" // LIBC_INLINE
20 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
22 namespace __llvm_libc
{
26 template <typename T
> LIBC_INLINE T
fma(T x
, T y
, T z
);
28 // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
29 // The implementation below only is only correct for the default rounding mode,
30 // round-to-nearest tie-to-even.
31 template <> LIBC_INLINE
float fma
<float>(float x
, float y
, float z
) {
33 double prod
= static_cast<double>(x
) * static_cast<double>(y
);
34 double z_d
= static_cast<double>(z
);
35 double sum
= prod
+ z_d
;
36 fputil::FPBits
<double> bit_prod(prod
), bitz(z_d
), bit_sum(sum
);
38 if (!(bit_sum
.is_inf_or_nan() || bit_sum
.is_zero())) {
39 // Since the sum is computed in double precision, rounding might happen
40 // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
41 // bit_prod.exponent > bitz.exponent + 40). In that case, when we round
42 // the sum back to float, double rounding error might occur.
43 // A concrete example of this phenomenon is as follows:
44 // x = y = 1 + 2^(-12), z = 2^(-53)
45 // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
46 // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
47 // On the other hand, with the default rounding mode,
48 // double(x*y + z) = 1 + 2^(-11) + 2^(-24)
49 // and casting again to float gives us:
50 // float(double(x*y + z)) = 1 + 2^(-11).
52 // In order to correct this possible double rounding error, first we use
53 // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
54 // assuming the (default) rounding mode is round-to-the-nearest,
55 // tie-to-even. Moreover, t satisfies the condition that t < eps(sum),
56 // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
57 // occurs when computing the sum, we just need to use t to adjust (any) last
58 // bit of sum, so that the sticky bits used when rounding sum to float are
59 // correct (when it matters).
60 fputil::FPBits
<double> t(
61 (bit_prod
.get_unbiased_exponent() >= bitz
.get_unbiased_exponent())
62 ? ((double(bit_sum
) - double(bit_prod
)) - double(bitz
))
63 : ((double(bit_sum
) - double(bitz
)) - double(bit_prod
)));
65 // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
67 if (!t
.is_zero() && ((bit_sum
.get_mantissa() & 0xfff'ffffULL
) == 0)) {
68 if (bit_sum
.get_sign() != t
.get_sign()) {
69 bit_sum
.set_mantissa(bit_sum
.get_mantissa() + 1);
70 } else if (bit_sum
.get_mantissa()) {
71 bit_sum
.set_mantissa(bit_sum
.get_mantissa() - 1);
76 return static_cast<float>(static_cast<double>(bit_sum
));
81 // Extract the sticky bits and shift the `mantissa` to the right by
83 LIBC_INLINE
bool shift_mantissa(int shift_length
, UInt128
&mant
) {
84 if (shift_length
>= 128) {
86 return true; // prod_mant is non-zero.
88 UInt128 mask
= (UInt128(1) << shift_length
) - 1;
89 bool sticky_bits
= (mant
& mask
) != 0;
90 mant
>>= shift_length
;
94 } // namespace internal
96 template <> LIBC_INLINE
double fma
<double>(double x
, double y
, double z
) {
97 using FPBits
= fputil::FPBits
<double>;
98 using FloatProp
= fputil::FloatProperties
<double>;
100 if (LIBC_UNLIKELY(x
== 0 || y
== 0 || z
== 0)) {
108 // Normalize denormal inputs.
109 if (LIBC_UNLIKELY(FPBits(x
).get_unbiased_exponent() == 0)) {
113 if (LIBC_UNLIKELY(FPBits(y
).get_unbiased_exponent() == 0)) {
117 if (LIBC_UNLIKELY(FPBits(z
).get_unbiased_exponent() == 0)) {
122 FPBits
x_bits(x
), y_bits(y
), z_bits(z
);
123 bool x_sign
= x_bits
.get_sign();
124 bool y_sign
= y_bits
.get_sign();
125 bool z_sign
= z_bits
.get_sign();
126 bool prod_sign
= x_sign
!= y_sign
;
127 x_exp
+= x_bits
.get_unbiased_exponent();
128 y_exp
+= y_bits
.get_unbiased_exponent();
129 z_exp
+= z_bits
.get_unbiased_exponent();
131 if (LIBC_UNLIKELY(x_exp
== FPBits::MAX_EXPONENT
||
132 y_exp
== FPBits::MAX_EXPONENT
||
133 z_exp
== FPBits::MAX_EXPONENT
))
136 // Extract mantissa and append hidden leading bits.
137 UInt128 x_mant
= x_bits
.get_mantissa() | FPBits::MIN_NORMAL
;
138 UInt128 y_mant
= y_bits
.get_mantissa() | FPBits::MIN_NORMAL
;
139 UInt128 z_mant
= z_bits
.get_mantissa() | FPBits::MIN_NORMAL
;
141 // If the exponent of the product x*y > the exponent of z, then no extra
142 // precision beside the entire product x*y is needed. On the other hand, when
143 // the exponent of z >= the exponent of the product x*y, the worst-case that
144 // we need extra precision is when there is cancellation and the most
145 // significant bit of the product is aligned exactly with the second most
146 // significant bit of z:
148 // - prod : 1bb...bb....b
149 // In that case, in order to store the exact result, we need at least
150 // (Length of prod) - (MantissaLength of z) = 2*(52 + 1) - 52 = 54.
151 // Overall, before aligning the mantissas and exponents, we can simply left-
152 // shift the mantissa of z by at least 54, and left-shift the product of x*y
153 // by (that amount - 52). After that, it is enough to align the least
154 // significant bit, given that we keep track of the round and sticky bits
155 // after the least significant bit.
156 // We pick shifting z_mant by 64 bits so that technically we can simply use
157 // the original mantissa as high part when constructing 128-bit z_mant. So the
158 // mantissa of prod will be left-shifted by 64 - 54 = 10 initially.
160 UInt128 prod_mant
= x_mant
* y_mant
<< 10;
163 (FPBits::EXPONENT_BIAS
+ 2 * MantissaWidth
<double>::VALUE
+ 10);
166 int z_lsb_exp
= z_exp
- (MantissaWidth
<double>::VALUE
+ 64);
167 bool round_bit
= false;
168 bool sticky_bits
= false;
169 bool z_shifted
= false;
172 if (prod_lsb_exp
< z_lsb_exp
) {
173 sticky_bits
= internal::shift_mantissa(z_lsb_exp
- prod_lsb_exp
, prod_mant
);
174 prod_lsb_exp
= z_lsb_exp
;
175 } else if (z_lsb_exp
< prod_lsb_exp
) {
177 sticky_bits
= internal::shift_mantissa(prod_lsb_exp
- z_lsb_exp
, z_mant
);
180 // Perform the addition:
181 // (-1)^prod_sign * prod_mant + (-1)^z_sign * z_mant.
182 // The final result will be stored in prod_sign and prod_mant.
183 if (prod_sign
== z_sign
) {
184 // Effectively an addition.
187 // Subtraction cases.
188 if (prod_mant
>= z_mant
) {
189 if (z_shifted
&& sticky_bits
) {
190 // Add 1 more to the subtrahend so that the sticky bits remain
191 // positive. This would simplify the rounding logic.
196 if (!z_shifted
&& sticky_bits
) {
197 // Add 1 more to the subtrahend so that the sticky bits remain
198 // positive. This would simplify the rounding logic.
201 prod_mant
= z_mant
- prod_mant
;
207 int r_exp
= 0; // Unbiased exponent of the result
209 // Normalize the result.
210 if (prod_mant
!= 0) {
211 uint64_t prod_hi
= static_cast<uint64_t>(prod_mant
>> 64);
212 int lead_zeros
= prod_hi
213 ? unsafe_clz(prod_hi
)
214 : 64 + unsafe_clz(static_cast<uint64_t>(prod_mant
));
215 // Move the leading 1 to the most significant bit.
216 prod_mant
<<= lead_zeros
;
217 // The lower 64 bits are always sticky bits after moving the leading 1 to
218 // the most significant bit.
219 sticky_bits
|= (static_cast<uint64_t>(prod_mant
) != 0);
220 result
= static_cast<uint64_t>(prod_mant
>> 64);
221 // Change prod_lsb_exp to be the exponent of the least significant bit of
223 prod_lsb_exp
+= 64 - lead_zeros
;
224 r_exp
= prod_lsb_exp
+ 63;
227 // The result is normal. We will shift the mantissa to the right by
228 // 63 - 52 = 11 bits (from the locations of the most significant bit).
229 // Then the rounding bit will correspond to the 11th bit, and the lowest
230 // 10 bits are merged into sticky bits.
231 round_bit
= (result
& 0x0400ULL
) != 0;
232 sticky_bits
|= (result
& 0x03ffULL
) != 0;
236 // The result is smaller than 1/2 of the smallest denormal number.
237 sticky_bits
= true; // since the result is non-zero.
240 // The result is denormal.
241 uint64_t mask
= 1ULL << (11 - r_exp
);
242 round_bit
= (result
& mask
) != 0;
243 sticky_bits
|= (result
& (mask
- 1)) != 0;
245 result
>>= 12 - r_exp
;
253 // Return +0.0 when there is exact cancellation, i.e., x*y == -z exactly.
257 // Finalize the result.
258 int round_mode
= fputil::quick_get_round();
259 if (LIBC_UNLIKELY(r_exp
>= FPBits::MAX_EXPONENT
)) {
260 if ((round_mode
== FE_TOWARDZERO
) ||
261 (round_mode
== FE_UPWARD
&& prod_sign
) ||
262 (round_mode
== FE_DOWNWARD
&& !prod_sign
)) {
263 result
= FPBits::MAX_NORMAL
;
264 return prod_sign
? -cpp::bit_cast
<double>(result
)
265 : cpp::bit_cast
<double>(result
);
267 return prod_sign
? static_cast<double>(FPBits::neg_inf())
268 : static_cast<double>(FPBits::inf());
271 // Remove hidden bit and append the exponent field and sign bit.
272 result
= (result
& FloatProp::MANTISSA_MASK
) |
273 (static_cast<uint64_t>(r_exp
) << FloatProp::MANTISSA_WIDTH
);
275 result
|= FloatProp::SIGN_MASK
;
279 if (round_mode
== FE_TONEAREST
) {
280 if (round_bit
&& (sticky_bits
|| ((result
& 1) != 0)))
282 } else if ((round_mode
== FE_UPWARD
&& !prod_sign
) ||
283 (round_mode
== FE_DOWNWARD
&& prod_sign
)) {
284 if (round_bit
|| sticky_bits
)
288 return cpp::bit_cast
<double>(result
);
291 } // namespace generic
292 } // namespace fputil
293 } // namespace __llvm_libc
295 #endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_FMA_H