[libc][NFC] Move aligned access implementations to separate header
[llvm-project.git] / libc / src / __support / FPUtil / generic / FMA.h
blob 86cc40c808cd5d558b3b7838d9552c28859d1c1c
1 //===-- Common header for FMA implementations -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_FMA_H
10 #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_FMA_H
12 #include "src/__support/CPP/type_traits.h"
13 #include "src/__support/FPUtil/FEnvImpl.h"
14 #include "src/__support/FPUtil/FPBits.h"
15 #include "src/__support/FPUtil/FloatProperties.h"
16 #include "src/__support/FPUtil/rounding_mode.h"
17 #include "src/__support/UInt128.h"
18 #include "src/__support/builtin_wrappers.h"
19 #include "src/__support/macros/attributes.h" // LIBC_INLINE
20 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
22 namespace __llvm_libc {
23 namespace fputil {
24 namespace generic {
26 template <typename T> LIBC_INLINE T fma(T x, T y, T z);
28 // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
29 // The implementation below only is only correct for the default rounding mode,
30 // round-to-nearest tie-to-even.
31 template <> LIBC_INLINE float fma<float>(float x, float y, float z) {
32 // Product is exact.
33 double prod = static_cast<double>(x) * static_cast<double>(y);
34 double z_d = static_cast<double>(z);
35 double sum = prod + z_d;
36 fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
38 if (!(bit_sum.is_inf_or_nan() || bit_sum.is_zero())) {
39 // Since the sum is computed in double precision, rounding might happen
40 // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
41 // bit_prod.exponent > bitz.exponent + 40). In that case, when we round
42 // the sum back to float, double rounding error might occur.
43 // A concrete example of this phenomenon is as follows:
44 // x = y = 1 + 2^(-12), z = 2^(-53)
45 // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
46 // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
47 // On the other hand, with the default rounding mode,
48 // double(x*y + z) = 1 + 2^(-11) + 2^(-24)
49 // and casting again to float gives us:
50 // float(double(x*y + z)) = 1 + 2^(-11).
52 // In order to correct this possible double rounding error, first we use
53 // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
54 // assuming the (default) rounding mode is round-to-the-nearest,
55 // tie-to-even. Moreover, t satisfies the condition that t < eps(sum),
56 // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
57 // occurs when computing the sum, we just need to use t to adjust (any) last
58 // bit of sum, so that the sticky bits used when rounding sum to float are
59 // correct (when it matters).
60 fputil::FPBits<double> t(
61 (bit_prod.get_unbiased_exponent() >= bitz.get_unbiased_exponent())
62 ? ((double(bit_sum) - double(bit_prod)) - double(bitz))
63 : ((double(bit_sum) - double(bitz)) - double(bit_prod)));
65 // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
66 // zero.
67 if (!t.is_zero() && ((bit_sum.get_mantissa() & 0xfff'ffffULL) == 0)) {
68 if (bit_sum.get_sign() != t.get_sign()) {
69 bit_sum.set_mantissa(bit_sum.get_mantissa() + 1);
70 } else if (bit_sum.get_mantissa()) {
71 bit_sum.set_mantissa(bit_sum.get_mantissa() - 1);
76 return static_cast<float>(static_cast<double>(bit_sum));
79 namespace internal {
81 // Extract the sticky bits and shift the `mantissa` to the right by
82 // `shift_length`.
83 LIBC_INLINE bool shift_mantissa(int shift_length, UInt128 &mant) {
84 if (shift_length >= 128) {
85 mant = 0;
86 return true; // prod_mant is non-zero.
88 UInt128 mask = (UInt128(1) << shift_length) - 1;
89 bool sticky_bits = (mant & mask) != 0;
90 mant >>= shift_length;
91 return sticky_bits;
94 } // namespace internal
96 template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
97 using FPBits = fputil::FPBits<double>;
98 using FloatProp = fputil::FloatProperties<double>;
100 if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0)) {
101 return x * y + z;
104 int x_exp = 0;
105 int y_exp = 0;
106 int z_exp = 0;
108 // Normalize denormal inputs.
109 if (LIBC_UNLIKELY(FPBits(x).get_unbiased_exponent() == 0)) {
110 x_exp -= 52;
111 x *= 0x1.0p+52;
113 if (LIBC_UNLIKELY(FPBits(y).get_unbiased_exponent() == 0)) {
114 y_exp -= 52;
115 y *= 0x1.0p+52;
117 if (LIBC_UNLIKELY(FPBits(z).get_unbiased_exponent() == 0)) {
118 z_exp -= 52;
119 z *= 0x1.0p+52;
122 FPBits x_bits(x), y_bits(y), z_bits(z);
123 bool x_sign = x_bits.get_sign();
124 bool y_sign = y_bits.get_sign();
125 bool z_sign = z_bits.get_sign();
126 bool prod_sign = x_sign != y_sign;
127 x_exp += x_bits.get_unbiased_exponent();
128 y_exp += y_bits.get_unbiased_exponent();
129 z_exp += z_bits.get_unbiased_exponent();
131 if (LIBC_UNLIKELY(x_exp == FPBits::MAX_EXPONENT ||
132 y_exp == FPBits::MAX_EXPONENT ||
133 z_exp == FPBits::MAX_EXPONENT))
134 return x * y + z;
136 // Extract mantissa and append hidden leading bits.
137 UInt128 x_mant = x_bits.get_mantissa() | FPBits::MIN_NORMAL;
138 UInt128 y_mant = y_bits.get_mantissa() | FPBits::MIN_NORMAL;
139 UInt128 z_mant = z_bits.get_mantissa() | FPBits::MIN_NORMAL;
141 // If the exponent of the product x*y > the exponent of z, then no extra
142 // precision beside the entire product x*y is needed. On the other hand, when
143 // the exponent of z >= the exponent of the product x*y, the worst-case that
144 // we need extra precision is when there is cancellation and the most
145 // significant bit of the product is aligned exactly with the second most
146 // significant bit of z:
147 // z : 10aa...a
148 // - prod : 1bb...bb....b
149 // In that case, in order to store the exact result, we need at least
150 // (Length of prod) - (MantissaLength of z) = 2*(52 + 1) - 52 = 54.
151 // Overall, before aligning the mantissas and exponents, we can simply left-
152 // shift the mantissa of z by at least 54, and left-shift the product of x*y
153 // by (that amount - 52). After that, it is enough to align the least
154 // significant bit, given that we keep track of the round and sticky bits
155 // after the least significant bit.
156 // We pick shifting z_mant by 64 bits so that technically we can simply use
157 // the original mantissa as high part when constructing 128-bit z_mant. So the
158 // mantissa of prod will be left-shifted by 64 - 54 = 10 initially.
160 UInt128 prod_mant = x_mant * y_mant << 10;
161 int prod_lsb_exp =
162 x_exp + y_exp -
163 (FPBits::EXPONENT_BIAS + 2 * MantissaWidth<double>::VALUE + 10);
165 z_mant <<= 64;
166 int z_lsb_exp = z_exp - (MantissaWidth<double>::VALUE + 64);
167 bool round_bit = false;
168 bool sticky_bits = false;
169 bool z_shifted = false;
171 // Align exponents.
172 if (prod_lsb_exp < z_lsb_exp) {
173 sticky_bits = internal::shift_mantissa(z_lsb_exp - prod_lsb_exp, prod_mant);
174 prod_lsb_exp = z_lsb_exp;
175 } else if (z_lsb_exp < prod_lsb_exp) {
176 z_shifted = true;
177 sticky_bits = internal::shift_mantissa(prod_lsb_exp - z_lsb_exp, z_mant);
180 // Perform the addition:
181 // (-1)^prod_sign * prod_mant + (-1)^z_sign * z_mant.
182 // The final result will be stored in prod_sign and prod_mant.
183 if (prod_sign == z_sign) {
184 // Effectively an addition.
185 prod_mant += z_mant;
186 } else {
187 // Subtraction cases.
188 if (prod_mant >= z_mant) {
189 if (z_shifted && sticky_bits) {
190 // Add 1 more to the subtrahend so that the sticky bits remain
191 // positive. This would simplify the rounding logic.
192 ++z_mant;
194 prod_mant -= z_mant;
195 } else {
196 if (!z_shifted && sticky_bits) {
197 // Add 1 more to the subtrahend so that the sticky bits remain
198 // positive. This would simplify the rounding logic.
199 ++prod_mant;
201 prod_mant = z_mant - prod_mant;
202 prod_sign = z_sign;
206 uint64_t result = 0;
207 int r_exp = 0; // Unbiased exponent of the result
209 // Normalize the result.
210 if (prod_mant != 0) {
211 uint64_t prod_hi = static_cast<uint64_t>(prod_mant >> 64);
212 int lead_zeros = prod_hi
213 ? unsafe_clz(prod_hi)
214 : 64 + unsafe_clz(static_cast<uint64_t>(prod_mant));
215 // Move the leading 1 to the most significant bit.
216 prod_mant <<= lead_zeros;
217 // The lower 64 bits are always sticky bits after moving the leading 1 to
218 // the most significant bit.
219 sticky_bits |= (static_cast<uint64_t>(prod_mant) != 0);
220 result = static_cast<uint64_t>(prod_mant >> 64);
221 // Change prod_lsb_exp the be the exponent of the least significant bit of
222 // the result.
223 prod_lsb_exp += 64 - lead_zeros;
224 r_exp = prod_lsb_exp + 63;
226 if (r_exp > 0) {
227 // The result is normal. We will shift the mantissa to the right by
228 // 63 - 52 = 11 bits (from the locations of the most significant bit).
229 // Then the rounding bit will correspond the the 11th bit, and the lowest
230 // 10 bits are merged into sticky bits.
231 round_bit = (result & 0x0400ULL) != 0;
232 sticky_bits |= (result & 0x03ffULL) != 0;
233 result >>= 11;
234 } else {
235 if (r_exp < -52) {
236 // The result is smaller than 1/2 of the smallest denormal number.
237 sticky_bits = true; // since the result is non-zero.
238 result = 0;
239 } else {
240 // The result is denormal.
241 uint64_t mask = 1ULL << (11 - r_exp);
242 round_bit = (result & mask) != 0;
243 sticky_bits |= (result & (mask - 1)) != 0;
244 if (r_exp > -52)
245 result >>= 12 - r_exp;
246 else
247 result = 0;
250 r_exp = 0;
252 } else {
253 // Return +0.0 when there is exact cancellation, i.e., x*y == -z exactly.
254 prod_sign = false;
257 // Finalize the result.
258 int round_mode = fputil::quick_get_round();
259 if (LIBC_UNLIKELY(r_exp >= FPBits::MAX_EXPONENT)) {
260 if ((round_mode == FE_TOWARDZERO) ||
261 (round_mode == FE_UPWARD && prod_sign) ||
262 (round_mode == FE_DOWNWARD && !prod_sign)) {
263 result = FPBits::MAX_NORMAL;
264 return prod_sign ? -cpp::bit_cast<double>(result)
265 : cpp::bit_cast<double>(result);
267 return prod_sign ? static_cast<double>(FPBits::neg_inf())
268 : static_cast<double>(FPBits::inf());
271 // Remove hidden bit and append the exponent field and sign bit.
272 result = (result & FloatProp::MANTISSA_MASK) |
273 (static_cast<uint64_t>(r_exp) << FloatProp::MANTISSA_WIDTH);
274 if (prod_sign) {
275 result |= FloatProp::SIGN_MASK;
278 // Rounding.
279 if (round_mode == FE_TONEAREST) {
280 if (round_bit && (sticky_bits || ((result & 1) != 0)))
281 ++result;
282 } else if ((round_mode == FE_UPWARD && !prod_sign) ||
283 (round_mode == FE_DOWNWARD && prod_sign)) {
284 if (round_bit || sticky_bits)
285 ++result;
288 return cpp::bit_cast<double>(result);
291 } // namespace generic
292 } // namespace fputil
293 } // namespace __llvm_libc
295 #endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_GENERIC_FMA_H