[X86] Better handling of impossibly large stack frames (#124217)
[llvm-project.git] / libclc / generic / lib / math / clc_rootn.cl
blob70ae02ac2370c90efaf4703c46f1187a723d5b45
/*
 * Copyright (c) 2014 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
23 #include <clc/clc.h>
24 #include <clc/clcmacro.h>
25 #include <clc/math/clc_fabs.h>
26 #include <clc/math/clc_mad.h>
28 #include "config.h"
29 #include "math.h"
30 #include "tables.h"
// rootn(x, ny) = x^(1/ny) is computed like pow, using log and exp:
// x^y = exp(y * log(x)), with y = 1/ny.
//
// We take care not to lose precision in the intermediate steps.
//
// When computing log, calculate it in splits:
//
// r = f * (p_inv_head + p_inv_tail)
// r = rh + rt
//
// Calculate the log polynomial using r; in the final addition, do
// poly = poly + ((rh - r) + rt)
//
// lth = -r
// ltt = ((xexp * log2_t) - poly) + logT
// lt = lth + ltt
//
// lh = (xexp * log2_h) + logH
// l = lh + lt
//
// Calculate the final log answer as gh and gt:
// gh = l & higher-half bits
// gt = (((ltt - (lt - lth)) + ((lh - l) + lt)) + (l - gh))
//
// yh = y & higher-half bits
// yt = y - yh
// (Because y = 1/ny is itself rounded, the code obtains yt more accurately as
// (1 - ny * yh) / ny, with ny split into a head and a tail.)
//
// Before entering the computation of exp:
// vs = ((yt*gt + yt*gh) + yh*gt)
// v = vs + yh*gh
// vt = ((yh*gh - v) + vs)
//
// In the calculation of exp, add vt to the r that is used for poly.
// At the end of exp, do
// ((((expT * poly) + expT) + expH*poly) + expH)
// __clc_rootn(x, ny): the real ny-th root of x, i.e. x^(1/ny), in single
// precision.
//
// |x|^(1/ny) is evaluated as exp((1/ny) * log|x|). Both log|x| and the product
// (1/ny) * log|x| are carried as head/tail pairs so that little precision is
// lost before the table-driven exp. The sign (negative x with odd ny) and the
// special inputs (zeros, infinities, NaN, negative x with even ny, ny == 0)
// are patched into the result bit pattern at the end.
_CLC_DEF _CLC_OVERLOAD float __clc_rootn(float x, int ny) {
  float y = MATH_RECIP((float)ny);

  int ix = as_int(x);
  int ax = ix & EXSIGNBIT_SP32; // bits of |x|
  int xpos = ix == ax;          // sign bit of x is clear

  int iy = as_int(y);
  int ay = iy & EXSIGNBIT_SP32;
  int ypos = iy == ay; // sign bit of y = 1/ny is clear

  // Extra precise log calculation
  // First handle the case that x is close to 1: log(1 - r) for small r.
  float r = 1.0f - as_float(ax);
  int near1 = __clc_fabs(r) < 0x1.0p-4f;
  float r2 = r * r;

  // Horner evaluation of 1/3 + r/4 + r^2/5 + r^3/6 + r^4/7
  // (coefficients 0x1.24924ap-3f = 1/7 ... 0x1.555556p-2f = 1/3).
  float poly = __clc_mad(
      r,
      __clc_mad(r,
                __clc_mad(r, __clc_mad(r, 0x1.24924ap-3f, 0x1.555556p-3f),
                          0x1.99999ap-3f),
                0x1.000000p-2f),
      0x1.555556p-2f);

  // poly = r^3/3 + r^4/4 + ... + r^7/7
  poly *= r2 * r;

  // log(1 - r) = -r - r^2/2 - poly, split into head/tail pieces.
  float lth_near1 = -r2 * 0.5f;
  float ltt_near1 = -poly;
  float lt_near1 = lth_near1 + ltt_near1;
  float lh_near1 = -r;
  float l_near1 = lh_near1 + lt_near1;

  // Computations for x not near 1
  int m = (int)(ax >> EXPSHIFTBITS_SP32) - EXPBIAS_SP32;
  float mf = (float)m;
  // Subnormal x (m == -127): renormalize via (1.mantissa - 1.0f) and rebias.
  int ixs = as_int(as_float(ax | 0x3f800000) - 1.0f);
  float mfs = (float)((ixs >> EXPSHIFTBITS_SP32) - 253);
  int c = m == -127;
  int ixn = c ? ixs : ax;
  float mfn = c ? mfs : mf;

  // Top 7 mantissa bits, rounded using the next bit: selects the table entry.
  int indx = (ixn & 0x007f0000) + ((ixn & 0x00008000) << 1);

  // F - Y
  float f = as_float(0x3f000000 | indx) -
            as_float(0x3f000000 | (ixn & MANTBITS_SP32));

  indx = indx >> 16;
  float2 tv = USE_TABLE(log_inv_tbl_ep, indx);
  float rh = f * tv.s0;
  float rt = f * tv.s1;
  r = rh + rt;

  poly = __clc_mad(r, __clc_mad(r, 0x1.0p-2f, 0x1.555556p-2f), 0x1.0p-1f) *
         (r * r);
  poly += (rh - r) + rt;

  const float LOG2_HEAD = 0x1.62e000p-1f;  // 0.693115234
  const float LOG2_TAIL = 0x1.0bfbe8p-15f; // 0.0000319461833
  tv = USE_TABLE(loge_tbl, indx);
  float lth = -r;
  float ltt = __clc_mad(mfn, LOG2_TAIL, -poly) + tv.s1;
  float lt = lth + ltt;
  float lh = __clc_mad(mfn, LOG2_HEAD, tv.s0);
  float l = lh + lt;

  // Select near 1 or not
  lth = near1 ? lth_near1 : lth;
  ltt = near1 ? ltt_near1 : ltt;
  lt = near1 ? lt_near1 : lt;
  lh = near1 ? lh_near1 : lh;
  l = near1 ? l_near1 : l;

  // log|x| = gh + gt, with gh holding only the upper mantissa bits.
  float gh = as_float(as_int(l) & 0xfffff000);
  float gt = ((ltt - (lt - lth)) + ((lh - l) + lt)) + (l - gh);

  float yh = as_float(iy & 0xfffff000);

  // Tail of 1/ny: yt = (1 - ny * yh) / ny, with ny split as fnyh + fnyt.
  float fny = (float)ny;
  float fnyh = as_float(as_int(fny) & 0xfffff000);
  float fnyt = (float)(ny - (int)fnyh);
  float yt = MATH_DIVIDE(__clc_mad(-fnyt, yh, __clc_mad(-fnyh, yh, 1.0f)), fny);

  // (yh + yt) * (gh + gt) as the head/tail pair (ylogx, ylogx_t).
  float ylogx_s = __clc_mad(gt, yh, __clc_mad(gh, yt, yt * gt));
  float ylogx = __clc_mad(yh, gh, ylogx_s);
  float ylogx_t = __clc_mad(yh, gh, -ylogx) + ylogx_s;

  // Extra precise exp of ylogx
  const float R_64_BY_LOG2 = 0x1.715476p+6f; // 64/log2 : 92.332482616893657
  int n = convert_int(ylogx * R_64_BY_LOG2);
  float nf = (float)n;

  int j = n & 0x3f; // table index: fractional 64ths of the power of two
  m = n >> 6;       // integer power of two
  int m2 = m << EXPSHIFTBITS_SP32;

  // log2/64 lead: 0.0108032227
  const float R_LOG2_BY_64_LD = 0x1.620000p-7f;
  // log2/64 tail: 0.0000272020388
  const float R_LOG2_BY_64_TL = 0x1.c85fdep-16f;
  r = __clc_mad(nf, -R_LOG2_BY_64_TL, __clc_mad(nf, -R_LOG2_BY_64_LD, ylogx)) +
      ylogx_t;

  // Truncated Taylor series for e^r - 1: r + r^2/2 + r^3/6 + r^4/24
  poly = __clc_mad(__clc_mad(__clc_mad(r, 0x1.555556p-5f, 0x1.555556p-3f), r,
                             0x1.000000p-1f),
                   r * r, r);

  tv = USE_TABLE(exp_tbl_ep, j);

  // (tv.s0 + tv.s1) * (1 + poly)
  float expylogx =
      __clc_mad(tv.s0, poly, __clc_mad(tv.s1, poly, tv.s1)) + tv.s0;
  // Scaling by 2^m: multiply for subnormal results, exponent add otherwise.
  float sexpylogx = __clc_fp32_subnormals_supported()
                        ? expylogx * as_float(0x1 << (m + 149))
                        : 0.0f;

  float texpylogx = as_float(as_int(expylogx) + m2);
  expylogx = m < -125 ? sexpylogx : texpylogx;

  // Magnitude overflows to +Inf if (ylogx + ylogx_t) > 128*log2; the sign for
  // negative x is applied below.
  expylogx = ((ylogx > 0x1.62e430p+6f) |
              ((ylogx == 0x1.62e430p+6f) & (ylogx_t > -0x1.05c610p-22f)))
                 ? as_float(PINFBITPATT_SP32)
                 : expylogx;

  // Result is 0 if ylogx < -149*log2
  expylogx = ylogx < -0x1.9d1da0p+6f ? 0.0f : expylogx;

  // Classify the root index by parity (ny is always an integer, so the
  // pow-style "not an integer" class 0 never occurs):
  //   inty = 1 means ny is odd,
  //   inty = 2 means ny is even.
  int inty = 2 - (ny & 1);

  // An odd root of a negative number is negative.
  float signval = as_float((as_uint(expylogx) ^ SIGNBIT_SP32));
  expylogx = ((inty == 1) & !xpos) ? signval : expylogx;
  int ret = as_int(expylogx);

  // Corner case handling; later assignments take precedence.
  // Negative x with an even root index: NaN.
  ret = (!xpos & (inty == 2)) ? QNANBITPATT_SP32 : ret;
  // x == +-0
  int xinf = xpos ? PINFBITPATT_SP32 : NINFBITPATT_SP32;
  ret = ((ax == 0) & !ypos & (inty == 1)) ? xinf : ret;
  ret = ((ax == 0) & !ypos & (inty == 2)) ? PINFBITPATT_SP32 : ret;
  ret = ((ax == 0) & ypos & (inty == 2)) ? 0 : ret;
  int xzero = xpos ? 0 : 0x80000000;
  ret = ((ax == 0) & ypos & (inty == 1)) ? xzero : ret;
  // x == -Inf with an odd root index
  ret =
      ((ix == NINFBITPATT_SP32) & ypos & (inty == 1)) ? NINFBITPATT_SP32 : ret;
  ret = ((ix == NINFBITPATT_SP32) & !ypos & (inty == 1)) ? 0x80000000 : ret;
  // x == +Inf
  ret = ((ix == PINFBITPATT_SP32) & !ypos) ? 0 : ret;
  ret = ((ix == PINFBITPATT_SP32) & ypos) ? PINFBITPATT_SP32 : ret;
  // NaN x is returned unchanged.
  ret = ax > PINFBITPATT_SP32 ? ix : ret;
  // The 0-th root is undefined.
  ret = ny == 0 ? QNANBITPATT_SP32 : ret;

  return as_float(ret);
}

_CLC_BINARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, float, __clc_rootn, float, int)
#ifdef cl_khr_fp64
// __clc_rootn(x, ny): the real ny-th root of x, i.e. x^(1/ny), in double
// precision.
//
// log|x| is computed in extended precision as a head/tail pair (H, T) and
// multiplied by 1/ny, which is also split into a head/tail pair, yielding
// (v, vt). A table-driven exp of (v, vt) gives the magnitude. The sign and the
// special inputs are then patched into the result bit pattern.
_CLC_DEF _CLC_OVERLOAD double __clc_rootn(double x, int ny) {
  const double real_log2_tail = 5.76999904754328540596e-08;
  const double real_log2_lead = 6.93147122859954833984e-01;

  double dny = (double)ny;
  double y = 1.0 / dny;

  long ux = as_long(x);
  long ax = ux & (~SIGNBIT_DP64); // bits of |x|
  int xpos = ax == ux;            // sign bit of x is clear

  long uy = as_long(y);
  long ay = uy & (~SIGNBIT_DP64);
  int ypos = ay == uy; // sign bit of y = 1/ny is clear

  // Extended precision log. On exit, v + vt approximates (1/ny) * log|x|.
  double v, vt;
  {
    int iexp = (int)(ax >> 52) - 1023;
    int mask_exp_1023 = iexp == -1023;
    double xexp = (double)iexp;
    long mantissa = ax & 0x000FFFFFFFFFFFFFL;

    // Subnormal x: renormalize via (1.mantissa - 1.0) and rebias the exponent.
    long temp_ux = as_long(as_double(0x3ff0000000000000L | mantissa) - 1.0);
    iexp = ((temp_ux & 0x7FF0000000000000L) >> 52) - 2045;
    double xexp1 = (double)iexp;
    long mantissa1 = temp_ux & 0x000FFFFFFFFFFFFFL;

    xexp = mask_exp_1023 ? xexp1 : xexp;
    mantissa = mask_exp_1023 ? mantissa1 : mantissa;

    // Top 8 mantissa bits, rounded using the next bit: selects the table entry.
    long rax = (mantissa & 0x000ff00000000000) +
               ((mantissa & 0x0000080000000000) << 1);
    int index = rax >> 44;

    // Y = 0.5 * 1.mantissa in [0.5, 1); F is Y rounded to the table grid.
    double F = as_double(rax | 0x3FE0000000000000L);
    double Y = as_double(mantissa | 0x3FE0000000000000L);
    double f = F - Y;
    double2 tv = USE_TABLE(log_f_inv_tbl, index);
    double log_h = tv.s0;
    double log_t = tv.s1;
    // r = f / F, split into a short head r1 and a tail r2.
    double f_inv = (log_h + log_t) * f;
    double r1 = as_double(as_long(f_inv) & 0xfffffffff8000000L);
    double r2 = fma(-F, r1, f) * (log_h + log_t);
    double r = r1 + r2;

    // r^3/3 + r^4/4 + r^5/5 + r^6/6 + r^7/7
    double poly = fma(
        r, fma(r, fma(r, fma(r, 1.0 / 7.0, 1.0 / 6.0), 1.0 / 5.0), 1.0 / 4.0),
        1.0 / 3.0);
    poly = poly * r * r * r;

    // Add the low-order terms of r + r^2/2; the head r1 + r1^2/2 is poly0h.
    double hr1r1 = 0.5 * r1 * r1;
    double poly0h = r1 + hr1r1;
    double poly0t = r1 - poly0h + hr1r1;
    poly = fma(r1, r2, fma(0.5 * r2, r2, poly)) + r2 + poly0t;

    tv = USE_TABLE(powlog_tbl, index);
    log_h = tv.s0;
    log_t = tv.s1;

    double resT_t = fma(xexp, real_log2_tail, +log_t) - poly;
    double resT = resT_t - poly0h;
    double resH = fma(xexp, real_log2_lead, log_h);
    double resT_h = poly0h;

    // log|x| = H + T, with H holding only the upper mantissa bits.
    double H = resT + resH;
    double H_h = as_double(as_long(H) & 0xfffffffff8000000L);
    double T = (resH - H + resT) + (resT_t - (resT + resT_h)) + (H - H_h);
    H = H_h;

    double y_head = as_double(uy & 0xfffffffff8000000L);

    // Tail of 1/ny: (1 - ny * y_head) / ny, with ny split as fnyh + fnyt.
    double fnyh = as_double(as_long(dny) & 0xfffffffffff00000);
    double fnyt = (double)(ny - (int)fnyh);
    double y_tail = fma(-fnyt, y_head, fma(-fnyh, y_head, 1.0)) / dny;

    // (y_head + y_tail) * (H + T) as the head/tail pair (v, vt).
    double temp = fma(y_tail, H, fma(y_head, T, y_tail * T));
    v = fma(y_head, H, temp);
    vt = fma(y_head, H, -v) + temp;
  }

  // Now calculate exp of (v,vt)
  double expv;
  {
    const double max_exp_arg = 709.782712893384;
    const double min_exp_arg = -745.1332191019411;
    const double sixtyfour_by_lnof2 = 92.33248261689366;
    const double lnof2_by_64_head = 0.010830424260348081;
    const double lnof2_by_64_tail = -4.359010638708991e-10;

    double temp = v * sixtyfour_by_lnof2;
    int n = (int)temp;
    double dn = (double)n;
    int j = n & 0x0000003f; // table index: fractional 64ths of the power of two
    int m = n >> 6;         // integer power of two

    double2 tv = USE_TABLE(two_to_jby64_ep_tbl, j);
    double f1 = tv.s0;
    double f2 = tv.s1;
    double f = f1 + f2;

    // Reduced argument r = v - n * ln2/64 + vt.
    double r1 = fma(dn, -lnof2_by_64_head, v);
    double r2 = dn * lnof2_by_64_tail;
    double r = (r1 + r2) + vt;

    // e^r - 1 = r + r^2 * q, q ~= 1/2 + r/6 + r^2/24 + r^3/120 + r^4/720.
    double q = fma(
        r,
        fma(r,
            fma(r,
                fma(r, 1.38889490863777199667e-03, 8.33336798434219616221e-03),
                4.16666666662260795726e-02),
            1.66666666665260878863e-01),
        5.00000000000000008883e-01);
    q = fma(r * r, q, r);

    // (f1 + f2) * (1 + q) * 2^m
    expv = fma(f, q, f2) + f1;
    expv = ldexp(expv, m);

    expv = v > max_exp_arg ? as_double(0x7FF0000000000000L) : expv;
    expv = v < min_exp_arg ? 0.0 : expv;
  }

  // Classify the root index by parity (ny is always an integer, so the
  // pow-style "not an integer" class 0 never occurs):
  //   inty = 1 means ny is odd,
  //   inty = 2 means ny is even.
  int inty = 2 - (ny & 1);

  // An odd root of a negative number is negative.
  expv *= ((inty == 1) & !xpos) ? -1.0 : 1.0;

  long ret = as_long(expv);

  // Now all the edge cases; later assignments take precedence.
  // Negative x with an even root index: NaN.
  ret = (!xpos & (inty == 2)) ? QNANBITPATT_DP64 : ret;
  // x == +-0
  long xinf = xpos ? PINFBITPATT_DP64 : NINFBITPATT_DP64;
  ret = ((ax == 0L) & !ypos & (inty == 1)) ? xinf : ret;
  ret = ((ax == 0L) & !ypos & (inty == 2)) ? PINFBITPATT_DP64 : ret;
  ret = ((ax == 0L) & ypos & (inty == 2)) ? 0L : ret;
  long xzero = xpos ? 0L : 0x8000000000000000L;
  ret = ((ax == 0L) & ypos & (inty == 1)) ? xzero : ret;
  // x == -Inf with an odd root index
  ret =
      ((ux == NINFBITPATT_DP64) & ypos & (inty == 1)) ? NINFBITPATT_DP64 : ret;
  ret = ((ux == NINFBITPATT_DP64) & !ypos & (inty == 1)) ? 0x8000000000000000L
                                                         : ret;
  // x == +Inf
  ret = ((ux == PINFBITPATT_DP64) & !ypos) ? 0L : ret;
  ret = ((ux == PINFBITPATT_DP64) & ypos) ? PINFBITPATT_DP64 : ret;
  // NaN x is returned unchanged.
  ret = ax > PINFBITPATT_DP64 ? ux : ret;
  // The 0-th root is undefined.
  ret = ny == 0 ? QNANBITPATT_DP64 : ret;
  return as_double(ret);
}

_CLC_BINARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, double, __clc_rootn, double, int)
#endif
#ifdef cl_khr_fp16

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Half-precision rootn: the ny-th root of x (here the integer root index is
// named y) is evaluated by the float overload and converted back to half.
_CLC_OVERLOAD _CLC_DEF half __clc_rootn(half x, int y) {
  return (half)__clc_rootn((float)x, y);
}

_CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, half, __clc_rootn, half, int)

#endif