[RISCV] Change func to funct in RISCVInstrInfoXqci.td. NFC (#119669)
[llvm-project.git] / libc / src / __support / FPUtil / dyadic_float.h
blob289fd01680547d40c73659e8711333e396f32106
1 //===-- A class to store high precision floating point numbers --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
12 #include "FEnvImpl.h"
13 #include "FPBits.h"
14 #include "hdr/errno_macros.h"
15 #include "hdr/fenv_macros.h"
16 #include "multiply_add.h"
17 #include "rounding_mode.h"
18 #include "src/__support/CPP/type_traits.h"
19 #include "src/__support/big_int.h"
20 #include "src/__support/macros/config.h"
21 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
22 #include "src/__support/macros/properties/types.h"
24 #include <stddef.h>
26 namespace LIBC_NAMESPACE_DECL {
27 namespace fputil {
29 // A generic class to perform computations of high precision floating points.
30 // We store the value in dyadic format, including 3 fields:
31 // sign : boolean value - false means positive, true means negative
32 // exponent: the exponent value of the least significant bit of the mantissa.
33 // mantissa: unsigned integer of length `Bits`.
34 // So the real value that is stored is:
35 // real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer)
36 // The stored data is normal if for non-zero mantissa, the leading bit is 1.
37 // The outputs of the constructors and most functions will be normalized.
38 // To simplify and improve the efficiency, many functions will assume that the
39 // inputs are normal.
40 template <size_t Bits> struct DyadicFloat {
41 using MantissaType = LIBC_NAMESPACE::UInt<Bits>;
43 Sign sign = Sign::POS;
44 int exponent = 0;
45 MantissaType mantissa = MantissaType(0);
47 LIBC_INLINE constexpr DyadicFloat() = default;
49 template <typename T, cpp::enable_if_t<cpp::is_floating_point_v<T>, int> = 0>
50 LIBC_INLINE constexpr DyadicFloat(T x) {
51 static_assert(FPBits<T>::FRACTION_LEN < Bits);
52 FPBits<T> x_bits(x);
53 sign = x_bits.sign();
54 exponent = x_bits.get_explicit_exponent() - FPBits<T>::FRACTION_LEN;
55 mantissa = MantissaType(x_bits.get_explicit_mantissa());
56 normalize();
59 LIBC_INLINE constexpr DyadicFloat(Sign s, int e, MantissaType m)
60 : sign(s), exponent(e), mantissa(m) {
61 normalize();
64 // Normalizing the mantissa, bringing the leading 1 bit to the most
65 // significant bit.
66 LIBC_INLINE constexpr DyadicFloat &normalize() {
67 if (!mantissa.is_zero()) {
68 int shift_length = cpp::countl_zero(mantissa);
69 exponent -= shift_length;
70 mantissa <<= static_cast<size_t>(shift_length);
72 return *this;
75 // Used for aligning exponents. Output might not be normalized.
76 LIBC_INLINE constexpr DyadicFloat &shift_left(unsigned shift_length) {
77 if (shift_length < Bits) {
78 exponent -= static_cast<int>(shift_length);
79 mantissa <<= shift_length;
80 } else {
81 exponent = 0;
82 mantissa = MantissaType(0);
84 return *this;
87 // Used for aligning exponents. Output might not be normalized.
88 LIBC_INLINE constexpr DyadicFloat &shift_right(unsigned shift_length) {
89 if (shift_length < Bits) {
90 exponent += static_cast<int>(shift_length);
91 mantissa >>= shift_length;
92 } else {
93 exponent = 0;
94 mantissa = MantissaType(0);
96 return *this;
99 // Assume that it is already normalized. Output the unbiased exponent.
100 LIBC_INLINE constexpr int get_unbiased_exponent() const {
101 return exponent + (Bits - 1);
104 #ifdef LIBC_TYPES_HAS_FLOAT16
105 template <typename T, bool ShouldSignalExceptions>
106 LIBC_INLINE constexpr cpp::enable_if_t<
107 cpp::is_floating_point_v<T> && (FPBits<T>::FRACTION_LEN < Bits), T>
108 generic_as() const {
109 using FPBits = FPBits<float16>;
110 using StorageType = typename FPBits::StorageType;
112 constexpr int EXTRA_FRACTION_LEN = Bits - 1 - FPBits::FRACTION_LEN;
114 if (mantissa == 0)
115 return FPBits::zero(sign).get_val();
117 int unbiased_exp = get_unbiased_exponent();
119 if (unbiased_exp + FPBits::EXP_BIAS >= FPBits::MAX_BIASED_EXPONENT) {
120 if constexpr (ShouldSignalExceptions) {
121 set_errno_if_required(ERANGE);
122 raise_except_if_required(FE_OVERFLOW | FE_INEXACT);
125 switch (quick_get_round()) {
126 case FE_TONEAREST:
127 return FPBits::inf(sign).get_val();
128 case FE_TOWARDZERO:
129 return FPBits::max_normal(sign).get_val();
130 case FE_DOWNWARD:
131 if (sign.is_pos())
132 return FPBits::max_normal(Sign::POS).get_val();
133 return FPBits::inf(Sign::NEG).get_val();
134 case FE_UPWARD:
135 if (sign.is_neg())
136 return FPBits::max_normal(Sign::NEG).get_val();
137 return FPBits::inf(Sign::POS).get_val();
138 default:
139 __builtin_unreachable();
143 StorageType out_biased_exp = 0;
144 StorageType out_mantissa = 0;
145 bool round = false;
146 bool sticky = false;
147 bool underflow = false;
149 if (unbiased_exp < -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
150 sticky = true;
151 underflow = true;
152 } else if (unbiased_exp == -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
153 round = true;
154 MantissaType sticky_mask = (MantissaType(1) << (Bits - 1)) - 1;
155 sticky = (mantissa & sticky_mask) != 0;
156 } else {
157 int extra_fraction_len = EXTRA_FRACTION_LEN;
159 if (unbiased_exp < 1 - FPBits::EXP_BIAS) {
160 underflow = true;
161 extra_fraction_len += 1 - FPBits::EXP_BIAS - unbiased_exp;
162 } else {
163 out_biased_exp =
164 static_cast<StorageType>(unbiased_exp + FPBits::EXP_BIAS);
167 MantissaType round_mask = MantissaType(1) << (extra_fraction_len - 1);
168 round = (mantissa & round_mask) != 0;
169 MantissaType sticky_mask = round_mask - 1;
170 sticky = (mantissa & sticky_mask) != 0;
172 out_mantissa = static_cast<StorageType>(mantissa >> extra_fraction_len);
175 bool lsb = (out_mantissa & 1) != 0;
177 StorageType result =
178 FPBits::create_value(sign, out_biased_exp, out_mantissa).uintval();
180 switch (quick_get_round()) {
181 case FE_TONEAREST:
182 if (round && (lsb || sticky))
183 ++result;
184 break;
185 case FE_DOWNWARD:
186 if (sign.is_neg() && (round || sticky))
187 ++result;
188 break;
189 case FE_UPWARD:
190 if (sign.is_pos() && (round || sticky))
191 ++result;
192 break;
193 default:
194 break;
197 if (ShouldSignalExceptions && (round || sticky)) {
198 int excepts = FE_INEXACT;
199 if (FPBits(result).is_inf()) {
200 set_errno_if_required(ERANGE);
201 excepts |= FE_OVERFLOW;
202 } else if (underflow) {
203 set_errno_if_required(ERANGE);
204 excepts |= FE_UNDERFLOW;
206 raise_except_if_required(excepts);
209 return FPBits(result).get_val();
211 #endif // LIBC_TYPES_HAS_FLOAT16
213 template <typename T, bool ShouldSignalExceptions,
214 typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
215 (FPBits<T>::FRACTION_LEN < Bits),
216 void>>
217 LIBC_INLINE constexpr T fast_as() const {
218 if (LIBC_UNLIKELY(mantissa.is_zero()))
219 return FPBits<T>::zero(sign).get_val();
221 // Assume that it is normalized, and output is also normal.
222 constexpr uint32_t PRECISION = FPBits<T>::FRACTION_LEN + 1;
223 using output_bits_t = typename FPBits<T>::StorageType;
224 constexpr output_bits_t IMPLICIT_MASK =
225 FPBits<T>::SIG_MASK - FPBits<T>::FRACTION_MASK;
227 int exp_hi = exponent + static_cast<int>((Bits - 1) + FPBits<T>::EXP_BIAS);
229 if (LIBC_UNLIKELY(exp_hi > 2 * FPBits<T>::EXP_BIAS)) {
230 // Results overflow.
231 T d_hi =
232 FPBits<T>::create_value(sign, 2 * FPBits<T>::EXP_BIAS, IMPLICIT_MASK)
233 .get_val();
234 // volatile prevents constant propagation that would result in infinity
235 // always being returned no matter the current rounding mode.
236 volatile T two = static_cast<T>(2.0);
237 T r = two * d_hi;
239 // TODO: Whether rounding down the absolute value to max_normal should
240 // also raise FE_OVERFLOW and set ERANGE is debatable.
241 if (ShouldSignalExceptions && FPBits<T>(r).is_inf())
242 set_errno_if_required(ERANGE);
244 return r;
247 bool denorm = false;
248 uint32_t shift = Bits - PRECISION;
249 if (LIBC_UNLIKELY(exp_hi <= 0)) {
250 // Output is denormal.
251 denorm = true;
252 shift = (Bits - PRECISION) + static_cast<uint32_t>(1 - exp_hi);
254 exp_hi = FPBits<T>::EXP_BIAS;
257 int exp_lo = exp_hi - static_cast<int>(PRECISION) - 1;
259 MantissaType m_hi =
260 shift >= MantissaType::BITS ? MantissaType(0) : mantissa >> shift;
262 T d_hi = FPBits<T>::create_value(
263 sign, static_cast<output_bits_t>(exp_hi),
264 (static_cast<output_bits_t>(m_hi) & FPBits<T>::SIG_MASK) |
265 IMPLICIT_MASK)
266 .get_val();
268 MantissaType round_mask =
269 shift > MantissaType::BITS ? 0 : MantissaType(1) << (shift - 1);
270 MantissaType sticky_mask = round_mask - MantissaType(1);
272 bool round_bit = !(mantissa & round_mask).is_zero();
273 bool sticky_bit = !(mantissa & sticky_mask).is_zero();
274 int round_and_sticky = int(round_bit) * 2 + int(sticky_bit);
276 T d_lo;
278 if (LIBC_UNLIKELY(exp_lo <= 0)) {
279 // d_lo is denormal, but the output is normal.
280 int scale_up_exponent = 1 - exp_lo;
281 T scale_up_factor =
282 FPBits<T>::create_value(Sign::POS,
283 static_cast<output_bits_t>(
284 FPBits<T>::EXP_BIAS + scale_up_exponent),
285 IMPLICIT_MASK)
286 .get_val();
287 T scale_down_factor =
288 FPBits<T>::create_value(Sign::POS,
289 static_cast<output_bits_t>(
290 FPBits<T>::EXP_BIAS - scale_up_exponent),
291 IMPLICIT_MASK)
292 .get_val();
294 d_lo = FPBits<T>::create_value(
295 sign, static_cast<output_bits_t>(exp_lo + scale_up_exponent),
296 IMPLICIT_MASK)
297 .get_val();
299 return multiply_add(d_lo, T(round_and_sticky), d_hi * scale_up_factor) *
300 scale_down_factor;
303 d_lo = FPBits<T>::create_value(sign, static_cast<output_bits_t>(exp_lo),
304 IMPLICIT_MASK)
305 .get_val();
307 // Still correct without FMA instructions if `d_lo` is not underflow.
308 T r = multiply_add(d_lo, T(round_and_sticky), d_hi);
310 if (LIBC_UNLIKELY(denorm)) {
311 // Exponent before rounding is in denormal range, simply clear the
312 // exponent field.
313 output_bits_t clear_exp = static_cast<output_bits_t>(
314 output_bits_t(exp_hi) << FPBits<T>::SIG_LEN);
315 output_bits_t r_bits = FPBits<T>(r).uintval() - clear_exp;
317 if (!(r_bits & FPBits<T>::EXP_MASK)) {
318 // Output is denormal after rounding, clear the implicit bit for 80-bit
319 // long double.
320 r_bits -= IMPLICIT_MASK;
322 // TODO: IEEE Std 754-2019 lets implementers choose whether to check for
323 // "tininess" before or after rounding for base-2 formats, as long as
324 // the same choice is made for all operations. Our choice to check after
325 // rounding might not be the same as the hardware's.
326 if (ShouldSignalExceptions && round_and_sticky) {
327 set_errno_if_required(ERANGE);
328 raise_except_if_required(FE_UNDERFLOW);
332 return FPBits<T>(r_bits).get_val();
335 return r;
338 // Assume that it is already normalized.
339 // Output is rounded correctly with respect to the current rounding mode.
340 template <typename T, bool ShouldSignalExceptions,
341 typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
342 (FPBits<T>::FRACTION_LEN < Bits),
343 void>>
344 LIBC_INLINE constexpr T as() const {
345 #if defined(LIBC_TYPES_HAS_FLOAT16) && !defined(__LIBC_USE_FLOAT16_CONVERSION)
346 if constexpr (cpp::is_same_v<T, float16>)
347 return generic_as<T, ShouldSignalExceptions>();
348 #endif
349 return fast_as<T, ShouldSignalExceptions>();
352 template <typename T,
353 typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
354 (FPBits<T>::FRACTION_LEN < Bits),
355 void>>
356 LIBC_INLINE explicit constexpr operator T() const {
357 return as<T, /*ShouldSignalExceptions=*/false>();
360 LIBC_INLINE constexpr MantissaType as_mantissa_type() const {
361 if (mantissa.is_zero())
362 return 0;
364 MantissaType new_mant = mantissa;
365 if (exponent > 0) {
366 new_mant <<= exponent;
367 } else {
368 new_mant >>= (-exponent);
371 if (sign.is_neg()) {
372 new_mant = (~new_mant) + 1;
375 return new_mant;
379 // Quick add - Add 2 dyadic floats with rounding toward 0 and then normalize the
380 // output:
381 // - Align the exponents so that:
382 // new a.exponent = new b.exponent = max(a.exponent, b.exponent)
383 // - Add or subtract the mantissas depending on the signs.
384 // - Normalize the result.
385 // The absolute errors compared to the mathematical sum is bounded by:
386 // | quick_add(a, b) - (a + b) | < MSB(a + b) * 2^(-Bits + 2),
387 // i.e., errors are up to 2 ULPs.
388 // Assume inputs are normalized (by constructors or other functions) so that we
389 // don't need to normalize the inputs again in this function. If the inputs are
390 // not normalized, the results might lose precision significantly.
391 template <size_t Bits>
392 LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
393 DyadicFloat<Bits> b) {
394 if (LIBC_UNLIKELY(a.mantissa.is_zero()))
395 return b;
396 if (LIBC_UNLIKELY(b.mantissa.is_zero()))
397 return a;
399 // Align exponents
400 if (a.exponent > b.exponent)
401 b.shift_right(static_cast<unsigned>(a.exponent - b.exponent));
402 else if (b.exponent > a.exponent)
403 a.shift_right(static_cast<unsigned>(b.exponent - a.exponent));
405 DyadicFloat<Bits> result;
407 if (a.sign == b.sign) {
408 // Addition
409 result.sign = a.sign;
410 result.exponent = a.exponent;
411 result.mantissa = a.mantissa;
412 if (result.mantissa.add_overflow(b.mantissa)) {
413 // Mantissa addition overflow.
414 result.shift_right(1);
415 result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
416 (uint64_t(1) << 63);
418 // Result is already normalized.
419 return result;
422 // Subtraction
423 if (a.mantissa >= b.mantissa) {
424 result.sign = a.sign;
425 result.exponent = a.exponent;
426 result.mantissa = a.mantissa - b.mantissa;
427 } else {
428 result.sign = b.sign;
429 result.exponent = b.exponent;
430 result.mantissa = b.mantissa - a.mantissa;
433 return result.normalize();
436 // Quick Mul - Slightly less accurate but efficient multiplication of 2 dyadic
437 // floats with rounding toward 0 and then normalize the output:
438 // result.exponent = a.exponent + b.exponent + Bits,
439 // result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
440 // ~ (full product a.mantissa * b.mantissa) >> Bits.
441 // The errors compared to the mathematical product is bounded by:
442 // 2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORD_COUNT - 1) in ULPs.
443 // Assume inputs are normalized (by constructors or other functions) so that we
444 // don't need to normalize the inputs again in this function. If the inputs are
445 // not normalized, the results might lose precision significantly.
446 template <size_t Bits>
447 LIBC_INLINE constexpr DyadicFloat<Bits> quick_mul(const DyadicFloat<Bits> &a,
448 const DyadicFloat<Bits> &b) {
449 DyadicFloat<Bits> result;
450 result.sign = (a.sign != b.sign) ? Sign::NEG : Sign::POS;
451 result.exponent = a.exponent + b.exponent + static_cast<int>(Bits);
453 if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) {
454 result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
455 // Check the leading bit directly, should be faster than using clz in
456 // normalize().
457 if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
458 63 ==
460 result.shift_left(1);
461 } else {
462 result.mantissa = (typename DyadicFloat<Bits>::MantissaType)(0);
464 return result;
467 // Simple polynomial approximation.
468 template <size_t Bits>
469 LIBC_INLINE constexpr DyadicFloat<Bits>
470 multiply_add(const DyadicFloat<Bits> &a, const DyadicFloat<Bits> &b,
471 const DyadicFloat<Bits> &c) {
472 return quick_add(c, quick_mul(a, b));
475 // Simple exponentiation implementation for printf. Only handles positive
476 // exponents, since division isn't implemented.
477 template <size_t Bits>
478 LIBC_INLINE constexpr DyadicFloat<Bits> pow_n(const DyadicFloat<Bits> &a,
479 uint32_t power) {
480 DyadicFloat<Bits> result = 1.0;
481 DyadicFloat<Bits> cur_power = a;
483 while (power > 0) {
484 if ((power % 2) > 0) {
485 result = quick_mul(result, cur_power);
487 power = power >> 1;
488 cur_power = quick_mul(cur_power, cur_power);
490 return result;
493 template <size_t Bits>
494 LIBC_INLINE constexpr DyadicFloat<Bits> mul_pow_2(const DyadicFloat<Bits> &a,
495 int32_t pow_2) {
496 DyadicFloat<Bits> result = a;
497 result.exponent += pow_2;
498 return result;
501 } // namespace fputil
502 } // namespace LIBC_NAMESPACE_DECL
504 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H