/*
 * Copyright (c) 2014 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
// This version is derived from the generic fma software implementation
// (__clc_sw_fma), but avoids the use of ulong in favor of uint2. The logic has
// been updated as appropriate.
27 #include
"../../../generic/lib/math/math.h"
29 #include
<clc
/clcmacro.h
>
/* Build a uint2 that emulates a 64-bit unsigned integer: .hi holds the most
 * significant 32 bits, .lo the least significant 32 bits. */
static uint2 u2_set(uint hi, uint lo) {
  uint2 res;
  res.lo = lo;
  res.hi = hi;
  return res;
}
44 static uint2 u2_set_u
(uint val
) { return u2_set
(0, val
); }
/* Exact 32x32 -> 64-bit unsigned multiply; mul_hi supplies the upper word,
 * the plain 32-bit product supplies the lower word. */
static uint2 u2_mul(uint a, uint b) {
  uint2 res;
  res.hi = mul_hi(a, b);
  res.lo = a * b;
  return res;
}
/* Logical shift left of the emulated 64-bit value, for 0 <= shift < 64.
 * shift == 0 is special-cased because "lo >> (32 - 0)" would shift a 32-bit
 * value by 32, which is undefined. */
static uint2 u2_sll(uint2 val, uint shift) {
  if (shift == 0)
    return val;
  if (shift < 32) {
    val.hi <<= shift;
    val.hi |= val.lo >> (32 - shift);
    val.lo <<= shift;
  } else {
    /* Whole low word moves into the high word. */
    val.hi = val.lo << (shift - 32);
    val.lo = 0;
  }
  return val;
}
/* Logical shift right of the emulated 64-bit value, for 0 <= shift < 64.
 * Mirror image of u2_sll; shift == 0 is special-cased for the same
 * undefined-shift reason. */
static uint2 u2_srl(uint2 val, uint shift) {
  if (shift == 0)
    return val;
  if (shift < 32) {
    val.lo >>= shift;
    val.lo |= val.hi << (32 - shift);
    val.hi >>= shift;
  } else {
    /* Whole high word moves into the low word. */
    val.lo = val.hi >> (shift - 32);
    val.hi = 0;
  }
  return val;
}
/* OR a 32-bit value into the low word only (the fma code below uses this to
 * fold a 0/1 sticky flag into the truncated bits). */
static uint2 u2_or(uint2 a, uint b) {
  a.lo |= b;
  return a;
}
/* Bitwise AND of two emulated 64-bit values. */
static uint2 u2_and(uint2 a, uint2 b) {
  a.lo &= b.lo;
  a.hi &= b.hi;
  return a;
}
/* Emulated 64-bit addition.  hadd(a.lo, b.lo) computes (a.lo + b.lo) >> 1
 * without overflowing, so bit 31 of that result is exactly the carry out of
 * the low-word addition. */
static uint2 u2_add(uint2 a, uint2 b) {
  uint carry = (hadd(a.lo, b.lo) >> 31) & 0x1;
  a.lo += b.lo;
  a.hi += b.hi + carry;
  return a;
}
99 static uint2 u2_add_u
(uint2 a
, uint b
) { return u2_add
(a, u2_set_u
(b)); }
/* Two's-complement negation of the emulated 64-bit value: ~a + 1. */
static uint2 u2_inv(uint2 a) {
  a.lo = ~a.lo;
  a.hi = ~a.hi;
  return u2_add_u(a, 1);
}
107 static uint u2_clz
(uint2 a
) {
108 uint leading_zeroes
= clz
(a.hi
);
109 if
(leading_zeroes == 32) {
110 leading_zeroes
+= clz
(a.lo
);
112 return leading_zeroes
;
115 static bool u2_eq
(uint2 a
, uint2 b
) { return a.lo
== b.lo
&& a.hi
== b.hi
; }
117 static bool u2_zero
(uint2 a
) { return u2_eq
(a, u2_set_u
(0)); }
119 static bool u2_gt
(uint2 a
, uint2 b
) {
120 return a.hi
> b.hi ||
(a.hi
== b.hi
&& a.lo
> b.lo
);
/* Software single-precision fused multiply-add (a * b + c with a single
 * rounding), for targets without ulong: all 64-bit arithmetic uses the uint2
 * helpers above.  The operands are decomposed into sign/exponent/mantissa
 * (struct fp), the 48-bit product mantissa is formed exactly, c's mantissa is
 * aligned to it, the sum/difference is renormalized and rounded to nearest
 * even, and the result is reassembled. */
_CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
  /* special cases: NaN anywhere, or infinite a/b, defer to mad */
  if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) {
    return mad(a, b, c);
  }

  /* If only c is inf, and both a,b are regular numbers, the result is c */
  if (isinf(c)) {
    return c;
  }

  a = __clc_flush_denormal_if_not_supported(a);
  b = __clc_flush_denormal_if_not_supported(b);
  c = __clc_flush_denormal_if_not_supported(c);

  if (a == 0.0f || b == 0.0f) {
    return c;
  }

  if (c == 0) {
    return a * b;
  }

  struct fp st_a, st_b, st_c;

  /* Unbiased exponents; a zero operand is given exponent 0. */
  st_a.exponent = a == .0f ? 0 : ((as_uint(a) & 0x7f800000) >> 23) - 127;
  st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127;
  st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127;

  /* 24-bit mantissas with the implicit leading 1 made explicit. */
  st_a.mantissa = u2_set_u(a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000);
  st_b.mantissa = u2_set_u(b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000);
  st_c.mantissa = u2_set_u(c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000);

  st_a.sign = as_uint(a) & 0x80000000;
  st_b.sign = as_uint(b) & 0x80000000;
  st_c.sign = as_uint(c) & 0x80000000;

  struct fp st_mul;

  // Move the product to the highest bits to maximize precision
  // mantissa is 24 bits => product is 48 bits, 2bits non-fraction.
  // Add one bit for future addition overflow,
  // add another bit to detect subtraction underflow
  st_mul.sign = st_a.sign ^ st_b.sign;
  st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14);
  st_mul.exponent =
      !u2_zero(st_mul.mantissa) ? st_a.exponent + st_b.exponent : 0;

  // FIXME: Detecting a == 0 || b == 0 above crashed GCN isel
  if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa))
    return c;

// Mantissa is 23 fractional bits, shift it the same way as product mantissa
#define C_ADJUST 37ul

  // both exponents are bias adjusted
  int exp_diff = st_mul.exponent - st_c.exponent;

  st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST);
  uint2 cutoff_bits = u2_set_u(0);
  /* Low |exp_diff| bits set: (1 << |exp_diff|) + (-1), with -1 spelled as
   * 0xffffffffffffffff in the emulated representation. */
  uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), abs(exp_diff)),
                             u2_set(0xffffffff, 0xffffffff));
  /* Align the smaller-magnitude addend to the larger one; the bits shifted
   * out are kept in cutoff_bits to serve as sticky bits for rounding. */
  if (exp_diff > 0) {
    cutoff_bits =
        exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask);
    st_c.mantissa =
        exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff);
  } else {
    cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa
                                  : u2_and(st_mul.mantissa, cutoff_mask);
    st_mul.mantissa =
        -exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff);
  }

  struct fp st_fma;
  st_fma.sign = st_mul.sign;
  st_fma.exponent = max(st_mul.exponent, st_c.exponent);
  if (st_c.sign == st_mul.sign) {
    st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa);
  } else {
    // cutoff bits borrow one
    st_fma.mantissa =
        u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)),
               (!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent)
                    ? u2_set(0xffffffff, 0xffffffff)
                    : u2_set_u(0)));
  }

  // underflow: st_c.sign != st_mul.sign, and magnitude switches the sign
  if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) {
    st_fma.mantissa = u2_inv(st_fma.mantissa);
    st_fma.sign = st_mul.sign ^ 0x80000000;
  }

  // detect overflow/underflow
  int overflow_bits = 3 - u2_clz(st_fma.mantissa);

  // adjust exponent
  st_fma.exponent += overflow_bits;

  // handle underflow
  if (overflow_bits < 0) {
    st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits);
    overflow_bits = 0;
  }

  // rounding
  uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits),
                            u2_set(0xffffffff, 0xffffffff));
  /* Truncated fraction bits, with the earlier cutoff bits folded in as a
   * single sticky flag via u2_or. */
  uint2 trunc_bits =
      u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits));
  uint2 last_bit =
      u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
  uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits);

  // round to nearest even
  if (u2_gt(trunc_bits, grs_bits) ||
      (u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) {
    st_fma.mantissa =
        u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
  }

  // Shift mantissa back to bit 23
  st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits);

  // Detect rounding overflow
  if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) {
    ++st_fma.exponent;
    st_fma.mantissa = u2_srl(st_fma.mantissa, 1);
  }

  if (u2_zero(st_fma.mantissa)) {
    return .0f;
  }

  // Floating point range limit
  if (st_fma.exponent > 127) {
    return as_float(as_uint(INFINITY) | st_fma.sign);
  }

  // Flush denormals
  if (st_fma.exponent <= -127) {
    return as_float(st_fma.sign);
  }

  return as_float(st_fma.sign | ((st_fma.exponent + 127) << 23) |
                  ((uint)st_fma.mantissa.lo & 0x7fffff));
}
// Generate the vector overloads (float2/3/4/8/16) from the scalar fma above.
_CLC_TERNARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, float, fma, float, float, float)
275 #pragma OPENCL EXTENSION cl_khr_fp16
: enable
/* Half-precision fma via promotion to float.  NOTE(review): the float
 * product of two halves is exact (11-bit significands fit in float), but mad
 * has relaxed accuracy in OpenCL -- confirm this meets the required ULP
 * bound for half fma on this target. */
_CLC_DEF _CLC_OVERLOAD half fma(half a, half b, half c) {
  return (half)mad((float)a, (float)b, (float)c);
}
// Generate the vector overloads (half2/3/4/8/16) from the scalar half fma.
_CLC_TERNARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, half, fma, half, half, half)