2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017,2019,2020, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 #ifndef GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
37 #define GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H
44 #include <immintrin.h>
46 #include "gromacs/math/utilities.h"
56 SimdFloat(float f
) : simdInternal_(_mm512_set1_ps(f
)) {}
58 // Internal utility constructor to simplify return statements
59 SimdFloat(__m512 simd
) : simdInternal_(simd
) {}
69 SimdFInt32(std::int32_t i
) : simdInternal_(_mm512_set1_epi32(i
)) {}
71 // Internal utility constructor to simplify return statements
72 SimdFInt32(__m512i simd
) : simdInternal_(simd
) {}
74 __m512i simdInternal_
;
82 SimdFBool(bool b
) : simdInternal_(_mm512_int2mask(b
? 0xFFFF : 0)) {}
84 // Internal utility constructor to simplify return statements
85 SimdFBool(__mmask16 simd
) : simdInternal_(simd
) {}
87 __mmask16 simdInternal_
;
95 SimdFIBool(bool b
) : simdInternal_(_mm512_int2mask(b
? 0xFFFF : 0)) {}
97 // Internal utility constructor to simplify return statements
98 SimdFIBool(__mmask16 simd
) : simdInternal_(simd
) {}
100 __mmask16 simdInternal_
;
103 static inline SimdFloat gmx_simdcall
simdLoad(const float* m
, SimdFloatTag
= {})
105 assert(std::size_t(m
) % 64 == 0);
106 return { _mm512_load_ps(m
) };
109 static inline void gmx_simdcall
store(float* m
, SimdFloat a
)
111 assert(std::size_t(m
) % 64 == 0);
112 _mm512_store_ps(m
, a
.simdInternal_
);
115 static inline SimdFloat gmx_simdcall
simdLoadU(const float* m
, SimdFloatTag
= {})
117 return { _mm512_loadunpackhi_ps(_mm512_loadunpacklo_ps(_mm512_undefined_ps(), m
), m
+ 16) };
120 static inline void gmx_simdcall
storeU(float* m
, SimdFloat a
)
122 _mm512_packstorelo_ps(m
, a
.simdInternal_
);
123 _mm512_packstorehi_ps(m
+ 16, a
.simdInternal_
);
126 static inline SimdFloat gmx_simdcall
setZeroF()
128 return { _mm512_setzero_ps() };
131 static inline SimdFInt32 gmx_simdcall
simdLoad(const std::int32_t* m
, SimdFInt32Tag
)
133 assert(std::size_t(m
) % 64 == 0);
134 return { _mm512_load_epi32(m
) };
137 static inline void gmx_simdcall
store(std::int32_t* m
, SimdFInt32 a
)
139 assert(std::size_t(m
) % 64 == 0);
140 _mm512_store_epi32(m
, a
.simdInternal_
);
143 static inline SimdFInt32 gmx_simdcall
simdLoadU(const std::int32_t* m
, SimdFInt32Tag
)
145 return { _mm512_loadunpackhi_epi32(_mm512_loadunpacklo_epi32(_mm512_undefined_epi32(), m
), m
+ 16) };
148 static inline void gmx_simdcall
storeU(std::int32_t* m
, SimdFInt32 a
)
150 _mm512_packstorelo_epi32(m
, a
.simdInternal_
);
151 _mm512_packstorehi_epi32(m
+ 16, a
.simdInternal_
);
154 static inline SimdFInt32 gmx_simdcall
setZeroFI()
156 return { _mm512_setzero_si512() };
161 static inline std::int32_t gmx_simdcall
extract(SimdFInt32 a
)
164 _mm512_mask_packstorelo_epi32(&r
, _mm512_mask2int(1 << index
), a
.simdInternal_
);
168 static inline SimdFloat gmx_simdcall
operator&(SimdFloat a
, SimdFloat b
)
170 return { _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a
.simdInternal_
),
171 _mm512_castps_si512(b
.simdInternal_
))) };
174 static inline SimdFloat gmx_simdcall
andNot(SimdFloat a
, SimdFloat b
)
176 return { _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(a
.simdInternal_
),
177 _mm512_castps_si512(b
.simdInternal_
))) };
180 static inline SimdFloat gmx_simdcall
operator|(SimdFloat a
, SimdFloat b
)
182 return { _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(a
.simdInternal_
),
183 _mm512_castps_si512(b
.simdInternal_
))) };
186 static inline SimdFloat gmx_simdcall
operator^(SimdFloat a
, SimdFloat b
)
188 return { _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a
.simdInternal_
),
189 _mm512_castps_si512(b
.simdInternal_
))) };
192 static inline SimdFloat gmx_simdcall
operator+(SimdFloat a
, SimdFloat b
)
194 return { _mm512_add_ps(a
.simdInternal_
, b
.simdInternal_
) };
197 static inline SimdFloat gmx_simdcall
operator-(SimdFloat a
, SimdFloat b
)
199 return { _mm512_sub_ps(a
.simdInternal_
, b
.simdInternal_
) };
202 static inline SimdFloat gmx_simdcall
operator-(SimdFloat x
)
204 return { _mm512_addn_ps(x
.simdInternal_
, _mm512_setzero_ps()) };
207 static inline SimdFloat gmx_simdcall
operator*(SimdFloat a
, SimdFloat b
)
209 return { _mm512_mul_ps(a
.simdInternal_
, b
.simdInternal_
) };
212 static inline SimdFloat gmx_simdcall
fma(SimdFloat a
, SimdFloat b
, SimdFloat c
)
214 return { _mm512_fmadd_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
) };
217 static inline SimdFloat gmx_simdcall
fms(SimdFloat a
, SimdFloat b
, SimdFloat c
)
219 return { _mm512_fmsub_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
) };
222 static inline SimdFloat gmx_simdcall
fnma(SimdFloat a
, SimdFloat b
, SimdFloat c
)
224 return { _mm512_fnmadd_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
) };
227 static inline SimdFloat gmx_simdcall
fnms(SimdFloat a
, SimdFloat b
, SimdFloat c
)
229 return { _mm512_fnmsub_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
) };
232 static inline SimdFloat gmx_simdcall
rsqrt(SimdFloat x
)
234 return { _mm512_rsqrt23_ps(x
.simdInternal_
) };
237 static inline SimdFloat gmx_simdcall
rcp(SimdFloat x
)
239 return { _mm512_rcp23_ps(x
.simdInternal_
) };
242 static inline SimdFloat gmx_simdcall
maskAdd(SimdFloat a
, SimdFloat b
, SimdFBool m
)
244 return { _mm512_mask_add_ps(a
.simdInternal_
, m
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
) };
247 static inline SimdFloat gmx_simdcall
maskzMul(SimdFloat a
, SimdFloat b
, SimdFBool m
)
249 return { _mm512_mask_mul_ps(_mm512_setzero_ps(), m
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
) };
252 static inline SimdFloat gmx_simdcall
maskzFma(SimdFloat a
, SimdFloat b
, SimdFloat c
, SimdFBool m
)
254 return { _mm512_mask_mov_ps(_mm512_setzero_ps(), m
.simdInternal_
,
255 _mm512_fmadd_ps(a
.simdInternal_
, b
.simdInternal_
, c
.simdInternal_
)) };
258 static inline SimdFloat gmx_simdcall
maskzRsqrt(SimdFloat x
, SimdFBool m
)
260 return { _mm512_mask_rsqrt23_ps(_mm512_setzero_ps(), m
.simdInternal_
, x
.simdInternal_
) };
263 static inline SimdFloat gmx_simdcall
maskzRcp(SimdFloat x
, SimdFBool m
)
265 return { _mm512_mask_rcp23_ps(_mm512_setzero_ps(), m
.simdInternal_
, x
.simdInternal_
) };
268 static inline SimdFloat gmx_simdcall
abs(SimdFloat x
)
270 return { _mm512_castsi512_ps(_mm512_andnot_epi32(_mm512_castps_si512(_mm512_set1_ps(GMX_FLOAT_NEGZERO
)),
271 _mm512_castps_si512(x
.simdInternal_
))) };
274 static inline SimdFloat gmx_simdcall
max(SimdFloat a
, SimdFloat b
)
276 return { _mm512_gmax_ps(a
.simdInternal_
, b
.simdInternal_
) };
279 static inline SimdFloat gmx_simdcall
min(SimdFloat a
, SimdFloat b
)
281 return { _mm512_gmin_ps(a
.simdInternal_
, b
.simdInternal_
) };
284 static inline SimdFloat gmx_simdcall
round(SimdFloat x
)
286 return { _mm512_round_ps(x
.simdInternal_
, _MM_FROUND_TO_NEAREST_INT
, _MM_EXPADJ_NONE
) };
289 static inline SimdFloat gmx_simdcall
trunc(SimdFloat x
)
291 return { _mm512_round_ps(x
.simdInternal_
, _MM_FROUND_TO_ZERO
, _MM_EXPADJ_NONE
) };
294 template<MathOptimization opt
= MathOptimization::Safe
>
295 static inline SimdFloat gmx_simdcall
frexp(SimdFloat value
, SimdFInt32
* exponent
)
301 if (opt
== MathOptimization::Safe
)
303 // For the safe branch, we use the masked operations to only assign results if the
304 // input value was nonzero, and otherwise set exponent to 0, and the fraction to the input (+-0).
305 __mmask16 valueIsNonZero
=
306 _mm512_cmp_ps_mask(_mm512_setzero_ps(), value
.simdInternal_
, _CMP_NEQ_OQ
);
307 rExponent
= _mm512_mask_getexp_ps(_mm512_setzero_ps(), valueIsNonZero
, value
.simdInternal_
);
308 iExponent
= _mm512_cvtfxpnt_round_adjustps_epi32(rExponent
, _MM_FROUND_TO_NEAREST_INT
,
310 iExponent
= _mm512_mask_add_epi32(iExponent
, valueIsNonZero
, iExponent
, _mm512_set1_epi32(1));
312 // Set result to input value when the latter is +-0
313 result
= _mm512_mask_getmant_ps(value
.simdInternal_
, valueIsNonZero
, value
.simdInternal_
,
314 _MM_MANT_NORM_p5_1
, _MM_MANT_SIGN_src
);
318 rExponent
= _mm512_getexp_ps(value
.simdInternal_
);
319 iExponent
= _mm512_cvtfxpnt_round_adjustps_epi32(rExponent
, _MM_FROUND_TO_NEAREST_INT
,
321 iExponent
= _mm512_add_epi32(iExponent
, _mm512_set1_epi32(1));
322 result
= _mm512_getmant_ps(value
.simdInternal_
, _MM_MANT_NORM_p5_1
, _MM_MANT_SIGN_src
);
325 exponent
->simdInternal_
= iExponent
;
330 template<MathOptimization opt
= MathOptimization::Safe
>
331 static inline SimdFloat gmx_simdcall
ldexp(SimdFloat value
, SimdFInt32 exponent
)
333 const __m512i exponentBias
= _mm512_set1_epi32(127);
334 __m512i iExponent
= _mm512_add_epi32(exponent
.simdInternal_
, exponentBias
);
336 if (opt
== MathOptimization::Safe
)
338 // Make sure biased argument is not negative
339 iExponent
= _mm512_max_epi32(iExponent
, _mm512_setzero_epi32());
342 iExponent
= _mm512_slli_epi32(iExponent
, 23);
344 return { _mm512_mul_ps(value
.simdInternal_
, _mm512_castsi512_ps(iExponent
)) };
347 static inline float gmx_simdcall
reduce(SimdFloat a
)
349 return _mm512_reduce_add_ps(a
.simdInternal_
);
// Picky, picky, picky:
// icc-16 complains about "Illegal value of immediate argument to intrinsic"
// unless the comparison predicates used below are
// 1) Ordered-quiet for ==
// 2) Unordered-quiet for !=
// 3) Ordered-signaling for < and <=
359 static inline SimdFBool gmx_simdcall
operator==(SimdFloat a
, SimdFloat b
)
361 return { _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_EQ_OQ
) };
364 static inline SimdFBool gmx_simdcall
operator!=(SimdFloat a
, SimdFloat b
)
366 return { _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_NEQ_UQ
) };
369 static inline SimdFBool gmx_simdcall
operator<(SimdFloat a
, SimdFloat b
)
371 return { _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_LT_OS
) };
374 static inline SimdFBool gmx_simdcall
operator<=(SimdFloat a
, SimdFloat b
)
376 return { _mm512_cmp_ps_mask(a
.simdInternal_
, b
.simdInternal_
, _CMP_LE_OS
) };
379 static inline SimdFBool gmx_simdcall
testBits(SimdFloat a
)
381 return { _mm512_test_epi32_mask(_mm512_castps_si512(a
.simdInternal_
),
382 _mm512_castps_si512(a
.simdInternal_
)) };
385 static inline SimdFBool gmx_simdcall
operator&&(SimdFBool a
, SimdFBool b
)
387 return { _mm512_kand(a
.simdInternal_
, b
.simdInternal_
) };
390 static inline SimdFBool gmx_simdcall
operator||(SimdFBool a
, SimdFBool b
)
392 return { _mm512_kor(a
.simdInternal_
, b
.simdInternal_
) };
395 static inline bool gmx_simdcall
anyTrue(SimdFBool a
)
397 return _mm512_mask2int(a
.simdInternal_
) != 0;
400 static inline SimdFloat gmx_simdcall
selectByMask(SimdFloat a
, SimdFBool m
)
402 return { _mm512_mask_mov_ps(_mm512_setzero_ps(), m
.simdInternal_
, a
.simdInternal_
) };
405 static inline SimdFloat gmx_simdcall
selectByNotMask(SimdFloat a
, SimdFBool m
)
407 return { _mm512_mask_mov_ps(a
.simdInternal_
, m
.simdInternal_
, _mm512_setzero_ps()) };
410 static inline SimdFloat gmx_simdcall
blend(SimdFloat a
, SimdFloat b
, SimdFBool sel
)
412 return { _mm512_mask_blend_ps(sel
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
) };
415 static inline SimdFInt32 gmx_simdcall
operator&(SimdFInt32 a
, SimdFInt32 b
)
417 return { _mm512_and_epi32(a
.simdInternal_
, b
.simdInternal_
) };
420 static inline SimdFInt32 gmx_simdcall
andNot(SimdFInt32 a
, SimdFInt32 b
)
422 return { _mm512_andnot_epi32(a
.simdInternal_
, b
.simdInternal_
) };
425 static inline SimdFInt32 gmx_simdcall
operator|(SimdFInt32 a
, SimdFInt32 b
)
427 return { _mm512_or_epi32(a
.simdInternal_
, b
.simdInternal_
) };
430 static inline SimdFInt32 gmx_simdcall
operator^(SimdFInt32 a
, SimdFInt32 b
)
432 return { _mm512_xor_epi32(a
.simdInternal_
, b
.simdInternal_
) };
435 static inline SimdFInt32 gmx_simdcall
operator+(SimdFInt32 a
, SimdFInt32 b
)
437 return { _mm512_add_epi32(a
.simdInternal_
, b
.simdInternal_
) };
440 static inline SimdFInt32 gmx_simdcall
operator-(SimdFInt32 a
, SimdFInt32 b
)
442 return { _mm512_sub_epi32(a
.simdInternal_
, b
.simdInternal_
) };
445 static inline SimdFInt32 gmx_simdcall
operator*(SimdFInt32 a
, SimdFInt32 b
)
447 return { _mm512_mullo_epi32(a
.simdInternal_
, b
.simdInternal_
) };
450 static inline SimdFIBool gmx_simdcall
operator==(SimdFInt32 a
, SimdFInt32 b
)
452 return { _mm512_cmp_epi32_mask(a
.simdInternal_
, b
.simdInternal_
, _MM_CMPINT_EQ
) };
455 static inline SimdFIBool gmx_simdcall
testBits(SimdFInt32 a
)
457 return { _mm512_test_epi32_mask(a
.simdInternal_
, a
.simdInternal_
) };
460 static inline SimdFIBool gmx_simdcall
operator<(SimdFInt32 a
, SimdFInt32 b
)
462 return { _mm512_cmp_epi32_mask(a
.simdInternal_
, b
.simdInternal_
, _MM_CMPINT_LT
) };
465 static inline SimdFIBool gmx_simdcall
operator&&(SimdFIBool a
, SimdFIBool b
)
467 return { _mm512_kand(a
.simdInternal_
, b
.simdInternal_
) };
470 static inline SimdFIBool gmx_simdcall
operator||(SimdFIBool a
, SimdFIBool b
)
472 return { _mm512_kor(a
.simdInternal_
, b
.simdInternal_
) };
475 static inline bool gmx_simdcall
anyTrue(SimdFIBool a
)
477 return _mm512_mask2int(a
.simdInternal_
) != 0;
480 static inline SimdFInt32 gmx_simdcall
selectByMask(SimdFInt32 a
, SimdFIBool m
)
482 return { _mm512_mask_mov_epi32(_mm512_setzero_epi32(), m
.simdInternal_
, a
.simdInternal_
) };
485 static inline SimdFInt32 gmx_simdcall
selectByNotMask(SimdFInt32 a
, SimdFIBool m
)
487 return { _mm512_mask_mov_epi32(a
.simdInternal_
, m
.simdInternal_
, _mm512_setzero_epi32()) };
490 static inline SimdFInt32 gmx_simdcall
blend(SimdFInt32 a
, SimdFInt32 b
, SimdFIBool sel
)
492 return { _mm512_mask_blend_epi32(sel
.simdInternal_
, a
.simdInternal_
, b
.simdInternal_
) };
495 static inline SimdFInt32 gmx_simdcall
cvtR2I(SimdFloat a
)
497 return { _mm512_cvtfxpnt_round_adjustps_epi32(a
.simdInternal_
, _MM_FROUND_TO_NEAREST_INT
,
501 static inline SimdFInt32 gmx_simdcall
cvttR2I(SimdFloat a
)
503 return { _mm512_cvtfxpnt_round_adjustps_epi32(a
.simdInternal_
, _MM_FROUND_TO_ZERO
, _MM_EXPADJ_NONE
) };
506 static inline SimdFloat gmx_simdcall
cvtI2R(SimdFInt32 a
)
508 return { _mm512_cvtfxpnt_round_adjustepi32_ps(a
.simdInternal_
, _MM_FROUND_TO_NEAREST_INT
,
512 static inline SimdFIBool gmx_simdcall
cvtB2IB(SimdFBool a
)
514 return { a
.simdInternal_
};
517 static inline SimdFBool gmx_simdcall
cvtIB2B(SimdFIBool a
)
519 return { a
.simdInternal_
};
523 template<MathOptimization opt
= MathOptimization::Safe
>
524 static inline SimdFloat gmx_simdcall
exp2(SimdFloat x
)
526 return { _mm512_exp223_ps(_mm512_cvtfxpnt_round_adjustps_epi32(
527 x
.simdInternal_
, _MM_ROUND_MODE_NEAREST
, _MM_EXPADJ_24
)) };
530 template<MathOptimization opt
= MathOptimization::Safe
>
531 static inline SimdFloat gmx_simdcall
exp(SimdFloat x
)
533 const __m512 argscale
= _mm512_set1_ps(1.44269504088896341F
);
534 const __m512 invargscale
= _mm512_set1_ps(-0.69314718055994528623F
);
536 if (opt
== MathOptimization::Safe
)
538 // Set the limit to gurantee flush to zero
539 const SimdFloat
smallArgLimit(-88.f
);
540 // Since we multiply the argument by 1.44, for the safe version we need to make
541 // sure this doesn't result in overflow
542 x
= max(x
, smallArgLimit
);
545 __m512 xscaled
= _mm512_mul_ps(x
.simdInternal_
, argscale
);
546 __m512 r
= _mm512_exp223_ps(
547 _mm512_cvtfxpnt_round_adjustps_epi32(xscaled
, _MM_ROUND_MODE_NEAREST
, _MM_EXPADJ_24
));
549 // exp2a23_ps provides 23 bits of accuracy, but we ruin some of that with our argument
550 // scaling. To correct this, we find the difference between the scaled argument and
551 // the true one (extended precision arithmetics does not appear to be necessary to
552 // fulfill our accuracy requirements) and then multiply by the exponent of this
553 // correction since exp(a+b)=exp(a)*exp(b).
554 // Note that this only adds two instructions (and maybe some constant loads).
556 // find the difference
557 x
= _mm512_fmadd_ps(invargscale
, xscaled
, x
.simdInternal_
);
558 // x will now be a _very_ small number, so approximate exp(x)=1+x.
559 // We should thus apply the correction as r'=r*(1+x)=r+r*x
560 r
= _mm512_fmadd_ps(r
, x
.simdInternal_
, r
);
564 static inline SimdFloat gmx_simdcall
log(SimdFloat x
)
566 return { _mm512_mul_ps(_mm512_set1_ps(0.693147180559945286226764F
),
567 _mm512_log2ae23_ps(x
.simdInternal_
)) };
572 #endif // GMX_SIMD_IMPL_X86_MIC_SIMD_FLOAT_H