[libc][NFC] Move aligned access implementations to separate header
[llvm-project.git] / libc / src / __support / FPUtil / dyadic_float.h
blobeb51c17abb80b828b7acec7d1c24557b0e0afe65
1 //===-- A class to store high precision floating point numbers --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H
10 #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H
12 #include "FPBits.h"
13 #include "FloatProperties.h"
14 #include "multiply_add.h"
15 #include "src/__support/CPP/type_traits.h"
16 #include "src/__support/UInt.h"
17 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
19 #include <stddef.h>
21 namespace __llvm_libc::fputil {
23 // A generic class to perform computations of high precision floating points.
24 // We store the value in dyadic format, including 3 fields:
25 // sign : boolean value - false means positive, true means negative
26 // exponent: the exponent value of the least significant bit of the mantissa.
27 // mantissa: unsigned integer of length `Bits`.
28 // So the real value that is stored is:
29 // real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer)
30 // The stored data is normal if for non-zero mantissa, the leading bit is 1.
31 // The outputs of the constructors and most functions will be normalized.
32 // To simplify and improve the efficiency, many functions will assume that the
33 // inputs are normal.
34 template <size_t Bits> struct DyadicFloat {
35 using MantissaType = __llvm_libc::cpp::UInt<Bits>;
37 bool sign = false;
38 int exponent = 0;
39 MantissaType mantissa = MantissaType(0);
41 DyadicFloat() = default;
43 template <typename T,
44 cpp::enable_if_t<cpp::is_floating_point_v<T> &&
45 (FloatProperties<T>::MANTISSA_WIDTH < Bits),
46 int> = 0>
47 DyadicFloat(T x) {
48 FPBits<T> x_bits(x);
49 sign = x_bits.get_sign();
50 exponent = x_bits.get_exponent() - FloatProperties<T>::MANTISSA_WIDTH;
51 mantissa = MantissaType(x_bits.get_explicit_mantissa());
52 normalize();
55 constexpr DyadicFloat(bool s, int e, MantissaType m)
56 : sign(s), exponent(e), mantissa(m) {
57 normalize();
60 // Normalizing the mantissa, bringing the leading 1 bit to the most
61 // significant bit.
62 constexpr DyadicFloat &normalize() {
63 if (!mantissa.is_zero()) {
64 int shift_length = static_cast<int>(mantissa.clz());
65 exponent -= shift_length;
66 mantissa.shift_left(static_cast<size_t>(shift_length));
68 return *this;
71 // Used for aligning exponents. Output might not be normalized.
72 DyadicFloat &shift_left(int shift_length) {
73 exponent -= shift_length;
74 mantissa <<= static_cast<size_t>(shift_length);
75 return *this;
78 // Used for aligning exponents. Output might not be normalized.
79 DyadicFloat &shift_right(int shift_length) {
80 exponent += shift_length;
81 mantissa >>= static_cast<size_t>(shift_length);
82 return *this;
85 // Assume that it is already normalized and output is also normal.
86 // Output is rounded correctly with respect to the current rounding mode.
87 // TODO(lntue): Test or add support for denormal output.
88 // TODO(lntue): Test or add specialization for x86 long double.
89 template <typename T, typename = cpp::enable_if_t<
90 cpp::is_floating_point_v<T> &&
91 (FloatProperties<T>::MANTISSA_WIDTH < Bits),
92 void>>
93 explicit operator T() const {
94 // TODO(lntue): Do we need to treat signed zeros properly?
95 if (mantissa.is_zero())
96 return 0.0;
98 // Assume that it is normalized, and output is also normal.
99 constexpr size_t PRECISION = FloatProperties<T>::MANTISSA_WIDTH + 1;
100 using output_bits_t = typename FPBits<T>::UIntType;
102 MantissaType m_hi(mantissa >> (Bits - PRECISION));
103 auto d_hi = FPBits<T>::create_value(
104 sign, exponent + (Bits - 1) + FloatProperties<T>::EXPONENT_BIAS,
105 output_bits_t(m_hi) & FloatProperties<T>::MANTISSA_MASK);
107 const MantissaType round_mask = MantissaType(1) << (Bits - PRECISION - 1);
108 const MantissaType sticky_mask = round_mask - MantissaType(1);
110 bool round_bit = !(mantissa & round_mask).is_zero();
111 bool sticky_bit = !(mantissa & sticky_mask).is_zero();
112 int round_and_sticky = int(round_bit) * 2 + int(sticky_bit);
113 auto d_lo = FPBits<T>::create_value(sign,
114 exponent + (Bits - PRECISION - 2) +
115 FloatProperties<T>::EXPONENT_BIAS,
116 output_bits_t(0));
118 // Still correct without FMA instructions if `d_lo` is not underflow.
119 return multiply_add(d_lo.get_val(), T(round_and_sticky), d_hi.get_val());
122 explicit operator MantissaType() const {
123 if (mantissa.is_zero())
124 return 0;
126 MantissaType new_mant = mantissa;
127 if (exponent > 0) {
128 new_mant <<= exponent;
129 } else {
130 new_mant >>= (-exponent);
133 if (sign) {
134 new_mant = (~new_mant) + 1;
137 return new_mant;
141 // Quick add - Add 2 dyadic floats with rounding toward 0 and then normalize the
142 // output:
143 // - Align the exponents so that:
144 // new a.exponent = new b.exponent = max(a.exponent, b.exponent)
145 // - Add or subtract the mantissas depending on the signs.
146 // - Normalize the result.
147 // The absolute errors compared to the mathematical sum is bounded by:
148 // | quick_add(a, b) - (a + b) | < MSB(a + b) * 2^(-Bits + 2),
149 // i.e., errors are up to 2 ULPs.
150 // Assume inputs are normalized (by constructors or other functions) so that we
151 // don't need to normalize the inputs again in this function. If the inputs are
152 // not normalized, the results might lose precision significantly.
153 template <size_t Bits>
154 constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
155 DyadicFloat<Bits> b) {
156 if (LIBC_UNLIKELY(a.mantissa.is_zero()))
157 return b;
158 if (LIBC_UNLIKELY(b.mantissa.is_zero()))
159 return a;
161 // Align exponents
162 if (a.exponent > b.exponent)
163 b.shift_right(a.exponent - b.exponent);
164 else if (b.exponent > a.exponent)
165 a.shift_right(b.exponent - a.exponent);
167 DyadicFloat<Bits> result;
169 if (a.sign == b.sign) {
170 // Addition
171 result.sign = a.sign;
172 result.exponent = a.exponent;
173 result.mantissa = a.mantissa;
174 if (result.mantissa.add(b.mantissa)) {
175 // Mantissa addition overflow.
176 result.shift_right(1);
177 result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORDCOUNT - 1] |=
178 (uint64_t(1) << 63);
180 // Result is already normalized.
181 return result;
184 // Subtraction
185 if (a.mantissa >= b.mantissa) {
186 result.sign = a.sign;
187 result.exponent = a.exponent;
188 result.mantissa = a.mantissa - b.mantissa;
189 } else {
190 result.sign = b.sign;
191 result.exponent = b.exponent;
192 result.mantissa = b.mantissa - a.mantissa;
195 return result.normalize();
198 // Quick Mul - Slightly less accurate but efficient multiplication of 2 dyadic
199 // floats with rounding toward 0 and then normalize the output:
200 // result.exponent = a.exponent + b.exponent + Bits,
201 // result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
202 // ~ (full product a.mantissa * b.mantissa) >> Bits.
203 // The errors compared to the mathematical product is bounded by:
204 // 2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORDCOUNT - 1) in ULPs.
205 // Assume inputs are normalized (by constructors or other functions) so that we
206 // don't need to normalize the inputs again in this function. If the inputs are
207 // not normalized, the results might lose precision significantly.
208 template <size_t Bits>
209 constexpr DyadicFloat<Bits> quick_mul(DyadicFloat<Bits> a,
210 DyadicFloat<Bits> b) {
211 DyadicFloat<Bits> result;
212 result.sign = (a.sign != b.sign);
213 result.exponent = a.exponent + b.exponent + int(Bits);
215 if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) {
216 result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
217 // Check the leading bit directly, should be faster than using clz in
218 // normalize().
219 if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORDCOUNT - 1] >>
220 63 ==
222 result.shift_left(1);
223 } else {
224 result.mantissa = (typename DyadicFloat<Bits>::MantissaType)(0);
226 return result;
229 // Simple exponentiation implementation for printf. Only handles positive
230 // exponents, since division isn't implemented.
231 template <size_t Bits>
232 constexpr DyadicFloat<Bits> pow_n(DyadicFloat<Bits> a, uint32_t power) {
233 DyadicFloat<Bits> result = 1.0;
234 DyadicFloat<Bits> cur_power = a;
236 while (power > 0) {
237 if ((power % 2) > 0) {
238 result = quick_mul(result, cur_power);
240 power = power >> 1;
241 cur_power = quick_mul(cur_power, cur_power);
243 return result;
246 template <size_t Bits>
247 constexpr DyadicFloat<Bits> mul_pow_2(DyadicFloat<Bits> a, int32_t pow_2) {
248 DyadicFloat<Bits> result = a;
249 result.exponent += pow_2;
250 return result;
253 } // namespace __llvm_libc::fputil
255 #endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H