Fix ARM NEON simd debug builds
[gromacs.git] / src / gromacs / simd / impl_arm_neon / impl_arm_neon_simd_float.h
blob7247b6c871c8a0cffcd00a5e1655f434dfb41606
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2014,2015,2016,2017, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
35 #ifndef GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
36 #define GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H
38 #include "config.h"
40 #include <cassert>
41 #include <cstddef>
42 #include <cstdint>
44 #include <arm_neon.h>
46 namespace gmx
49 class SimdFloat
51 public:
52 SimdFloat() {}
54 SimdFloat(float f) : simdInternal_(vdupq_n_f32(f)) {}
56 // Internal utility constructor to simplify return statements
57 SimdFloat(float32x4_t simd) : simdInternal_(simd) {}
59 float32x4_t simdInternal_;
62 class SimdFInt32
64 public:
65 SimdFInt32() {}
67 SimdFInt32(std::int32_t i) : simdInternal_(vdupq_n_s32(i)) {}
69 // Internal utility constructor to simplify return statements
70 SimdFInt32(int32x4_t simd) : simdInternal_(simd) {}
72 int32x4_t simdInternal_;
75 class SimdFBool
77 public:
78 SimdFBool() {}
80 SimdFBool(bool b) : simdInternal_(vdupq_n_u32( b ? 0xFFFFFFFF : 0)) {}
82 // Internal utility constructor to simplify return statements
83 SimdFBool(uint32x4_t simd) : simdInternal_(simd) {}
85 uint32x4_t simdInternal_;
88 class SimdFIBool
90 public:
91 SimdFIBool() {}
93 SimdFIBool(bool b) : simdInternal_(vdupq_n_u32( b ? 0xFFFFFFFF : 0)) {}
95 // Internal utility constructor to simplify return statements
96 SimdFIBool(uint32x4_t simd) : simdInternal_(simd) {}
98 uint32x4_t simdInternal_;
101 static inline SimdFloat gmx_simdcall
102 simdLoad(const float *m)
104 assert(std::size_t(m) % 16 == 0);
105 return {
106 vld1q_f32(m)
110 static inline void gmx_simdcall
111 store(float *m, SimdFloat a)
113 assert(std::size_t(m) % 16 == 0);
114 vst1q_f32(m, a.simdInternal_);
117 static inline SimdFloat gmx_simdcall
118 simdLoadU(const float *m)
120 return {
121 vld1q_f32(m)
125 static inline void gmx_simdcall
126 storeU(float *m, SimdFloat a)
128 vst1q_f32(m, a.simdInternal_);
131 static inline SimdFloat gmx_simdcall
132 setZeroF()
134 return {
135 vdupq_n_f32(0.0f)
139 static inline SimdFInt32 gmx_simdcall
140 simdLoadFI(const std::int32_t * m)
142 assert(std::size_t(m) % 16 == 0);
143 return {
144 vld1q_s32(m)
148 static inline void gmx_simdcall
149 store(std::int32_t * m, SimdFInt32 a)
151 assert(std::size_t(m) % 16 == 0);
152 vst1q_s32(m, a.simdInternal_);
155 static inline SimdFInt32 gmx_simdcall
156 simdLoadUFI(const std::int32_t *m)
158 return {
159 vld1q_s32(m)
163 static inline void gmx_simdcall
164 storeU(std::int32_t * m, SimdFInt32 a)
166 vst1q_s32(m, a.simdInternal_);
169 static inline SimdFInt32 gmx_simdcall
170 setZeroFI()
172 return {
173 vdupq_n_s32(0)
177 template<int index> gmx_simdcall
178 static inline std::int32_t
179 extract(SimdFInt32 a)
181 return vgetq_lane_s32(a.simdInternal_, index);
184 static inline SimdFloat gmx_simdcall
185 operator&(SimdFloat a, SimdFloat b)
187 return {
188 vreinterpretq_f32_s32(vandq_s32(vreinterpretq_s32_f32(a.simdInternal_),
189 vreinterpretq_s32_f32(b.simdInternal_)))
193 static inline SimdFloat gmx_simdcall
194 andNot(SimdFloat a, SimdFloat b)
196 return {
197 vreinterpretq_f32_s32(vbicq_s32(vreinterpretq_s32_f32(b.simdInternal_),
198 vreinterpretq_s32_f32(a.simdInternal_)))
202 static inline SimdFloat gmx_simdcall
203 operator|(SimdFloat a, SimdFloat b)
205 return {
206 vreinterpretq_f32_s32(vorrq_s32(vreinterpretq_s32_f32(a.simdInternal_),
207 vreinterpretq_s32_f32(b.simdInternal_)))
211 static inline SimdFloat gmx_simdcall
212 operator^(SimdFloat a, SimdFloat b)
214 return {
215 vreinterpretq_f32_s32(veorq_s32(vreinterpretq_s32_f32(a.simdInternal_),
216 vreinterpretq_s32_f32(b.simdInternal_)))
220 static inline SimdFloat gmx_simdcall
221 operator+(SimdFloat a, SimdFloat b)
223 return {
224 vaddq_f32(a.simdInternal_, b.simdInternal_)
228 static inline SimdFloat gmx_simdcall
229 operator-(SimdFloat a, SimdFloat b)
231 return {
232 vsubq_f32(a.simdInternal_, b.simdInternal_)
236 static inline SimdFloat gmx_simdcall
237 operator-(SimdFloat x)
239 return {
240 vnegq_f32(x.simdInternal_)
244 static inline SimdFloat gmx_simdcall
245 operator*(SimdFloat a, SimdFloat b)
247 return {
248 vmulq_f32(a.simdInternal_, b.simdInternal_)
252 // Override for Neon-Asimd
253 #if GMX_SIMD_ARM_NEON
254 static inline SimdFloat gmx_simdcall
255 fma(SimdFloat a, SimdFloat b, SimdFloat c)
257 return {
258 #ifdef __ARM_FEATURE_FMA
259 vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
260 #else
261 vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
262 #endif
266 static inline SimdFloat gmx_simdcall
267 fms(SimdFloat a, SimdFloat b, SimdFloat c)
269 return {
270 #ifdef __ARM_FEATURE_FMA
271 vnegq_f32(vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
272 #else
273 vnegq_f32(vmlsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
274 #endif
278 static inline SimdFloat gmx_simdcall
279 fnma(SimdFloat a, SimdFloat b, SimdFloat c)
281 return {
282 #ifdef __ARM_FEATURE_FMA
283 vfmsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
284 #else
285 vmlsq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_)
286 #endif
290 static inline SimdFloat gmx_simdcall
291 fnms(SimdFloat a, SimdFloat b, SimdFloat c)
293 return {
294 #ifdef __ARM_FEATURE_FMA
295 vnegq_f32(vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
296 #else
297 vnegq_f32(vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_))
298 #endif
301 #endif
303 static inline SimdFloat gmx_simdcall
304 rsqrt(SimdFloat x)
306 return {
307 vrsqrteq_f32(x.simdInternal_)
311 static inline SimdFloat gmx_simdcall
312 rsqrtIter(SimdFloat lu, SimdFloat x)
314 return {
315 vmulq_f32(lu.simdInternal_, vrsqrtsq_f32(vmulq_f32(lu.simdInternal_, lu.simdInternal_), x.simdInternal_))
319 static inline SimdFloat gmx_simdcall
320 rcp(SimdFloat x)
322 return {
323 vrecpeq_f32(x.simdInternal_)
327 static inline SimdFloat gmx_simdcall
328 rcpIter(SimdFloat lu, SimdFloat x)
330 return {
331 vmulq_f32(lu.simdInternal_, vrecpsq_f32(lu.simdInternal_, x.simdInternal_))
335 static inline SimdFloat gmx_simdcall
336 maskAdd(SimdFloat a, SimdFloat b, SimdFBool m)
338 b.simdInternal_ = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(b.simdInternal_),
339 m.simdInternal_));
341 return {
342 vaddq_f32(a.simdInternal_, b.simdInternal_)
346 static inline SimdFloat gmx_simdcall
347 maskzMul(SimdFloat a, SimdFloat b, SimdFBool m)
349 SimdFloat tmp = a * b;
351 return {
352 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp.simdInternal_),
353 m.simdInternal_))
357 static inline SimdFloat gmx_simdcall
358 maskzFma(SimdFloat a, SimdFloat b, SimdFloat c, SimdFBool m)
360 #ifdef __ARM_FEATURE_FMA
361 float32x4_t tmp = vfmaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
362 #else
363 float32x4_t tmp = vmlaq_f32(c.simdInternal_, b.simdInternal_, a.simdInternal_);
364 #endif
366 return {
367 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(tmp),
368 m.simdInternal_))
372 static inline SimdFloat gmx_simdcall
373 maskzRsqrt(SimdFloat x, SimdFBool m)
375 #ifndef NDEBUG
376 x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0f));
377 #endif
378 return {
379 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(vrsqrteq_f32(x.simdInternal_)),
380 m.simdInternal_))
384 static inline SimdFloat gmx_simdcall
385 maskzRcp(SimdFloat x, SimdFBool m)
387 #ifndef NDEBUG
388 x.simdInternal_ = vbslq_f32(m.simdInternal_, x.simdInternal_, vdupq_n_f32(1.0f));
389 #endif
390 return {
391 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(vrecpeq_f32(x.simdInternal_)),
392 m.simdInternal_))
396 static inline SimdFloat gmx_simdcall
397 abs(SimdFloat x)
399 return {
400 vabsq_f32( x.simdInternal_ )
404 static inline SimdFloat gmx_simdcall
405 max(SimdFloat a, SimdFloat b)
407 return {
408 vmaxq_f32(a.simdInternal_, b.simdInternal_)
412 static inline SimdFloat gmx_simdcall
413 min(SimdFloat a, SimdFloat b)
415 return {
416 vminq_f32(a.simdInternal_, b.simdInternal_)
420 // Round and trunc operations are defined at the end of this file, since they
421 // need to use float-to-integer and integer-to-float conversions.
423 static inline SimdFloat gmx_simdcall
424 frexp(SimdFloat value, SimdFInt32 * exponent)
426 const int32x4_t exponentMask = vdupq_n_s32(0x7F800000);
427 const int32x4_t mantissaMask = vdupq_n_s32(0x807FFFFF);
428 const int32x4_t exponentBias = vdupq_n_s32(126); // add 1 to make our definition identical to frexp()
429 const float32x4_t half = vdupq_n_f32(0.5f);
430 int32x4_t iExponent;
432 iExponent = vandq_s32(vreinterpretq_s32_f32(value.simdInternal_), exponentMask);
433 iExponent = vsubq_s32(vshrq_n_s32(iExponent, 23), exponentBias);
434 exponent->simdInternal_ = iExponent;
436 return {
437 vreinterpretq_f32_s32(vorrq_s32(vandq_s32(vreinterpretq_s32_f32(value.simdInternal_),
438 mantissaMask),
439 vreinterpretq_s32_f32(half)))
443 static inline SimdFloat gmx_simdcall
444 ldexp(SimdFloat value, SimdFInt32 exponent)
446 const int32x4_t exponentBias = vdupq_n_s32(127);
447 int32x4_t iExponent;
449 iExponent = vshlq_n_s32( vaddq_s32(exponent.simdInternal_, exponentBias), 23);
451 return {
452 vmulq_f32(value.simdInternal_, vreinterpretq_f32_s32(iExponent))
456 // Override for Neon-Asimd
457 #if GMX_SIMD_ARM_NEON
458 static inline float gmx_simdcall
459 reduce(SimdFloat a)
461 float32x4_t x = a.simdInternal_;
462 float32x4_t y = vextq_f32(x, x, 2);
464 x = vaddq_f32(x, y);
465 y = vextq_f32(x, x, 1);
466 x = vaddq_f32(x, y);
467 return vgetq_lane_f32(x, 0);
469 #endif
471 static inline SimdFBool gmx_simdcall
472 operator==(SimdFloat a, SimdFloat b)
474 return {
475 vceqq_f32(a.simdInternal_, b.simdInternal_)
479 static inline SimdFBool gmx_simdcall
480 operator!=(SimdFloat a, SimdFloat b)
482 return {
483 vmvnq_u32(vceqq_f32(a.simdInternal_, b.simdInternal_))
487 static inline SimdFBool gmx_simdcall
488 operator<(SimdFloat a, SimdFloat b)
490 return {
491 vcltq_f32(a.simdInternal_, b.simdInternal_)
495 static inline SimdFBool gmx_simdcall
496 operator<=(SimdFloat a, SimdFloat b)
498 return {
499 vcleq_f32(a.simdInternal_, b.simdInternal_)
503 static inline SimdFBool gmx_simdcall
504 testBits(SimdFloat a)
506 uint32x4_t tmp = vreinterpretq_u32_f32(a.simdInternal_);
508 return {
509 vtstq_u32(tmp, tmp)
513 static inline SimdFBool gmx_simdcall
514 operator&&(SimdFBool a, SimdFBool b)
517 return {
518 vandq_u32(a.simdInternal_, b.simdInternal_)
522 static inline SimdFBool gmx_simdcall
523 operator||(SimdFBool a, SimdFBool b)
525 return {
526 vorrq_u32(a.simdInternal_, b.simdInternal_)
530 // Override for Neon-Asimd
531 #if GMX_SIMD_ARM_NEON
532 static inline bool gmx_simdcall
533 anyTrue(SimdFBool a)
535 uint32x4_t x = a.simdInternal_;
536 uint32x4_t y = vextq_u32(x, x, 2);
538 x = vorrq_u32(x, y);
539 y = vextq_u32(x, x, 1);
540 x = vorrq_u32(x, y);
541 return (vgetq_lane_u32(x, 0) != 0);
543 #endif
545 static inline SimdFloat gmx_simdcall
546 selectByMask(SimdFloat a, SimdFBool m)
548 return {
549 vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.simdInternal_),
550 m.simdInternal_))
554 static inline SimdFloat gmx_simdcall
555 selectByNotMask(SimdFloat a, SimdFBool m)
557 return {
558 vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.simdInternal_),
559 m.simdInternal_))
563 static inline SimdFloat gmx_simdcall
564 blend(SimdFloat a, SimdFloat b, SimdFBool sel)
566 return {
567 vbslq_f32(sel.simdInternal_, b.simdInternal_, a.simdInternal_)
571 static inline SimdFInt32 gmx_simdcall
572 operator<<(SimdFInt32 a, int n)
574 return {
575 vshlq_n_s32(a.simdInternal_, n)
579 static inline SimdFInt32 gmx_simdcall
580 operator>>(SimdFInt32 a, int n)
582 return {
583 vshrq_n_s32(a.simdInternal_, n)
587 static inline SimdFInt32 gmx_simdcall
588 operator&(SimdFInt32 a, SimdFInt32 b)
590 return {
591 vandq_s32(a.simdInternal_, b.simdInternal_)
595 static inline SimdFInt32 gmx_simdcall
596 andNot(SimdFInt32 a, SimdFInt32 b)
598 return {
599 vbicq_s32(b.simdInternal_, a.simdInternal_)
603 static inline SimdFInt32 gmx_simdcall
604 operator|(SimdFInt32 a, SimdFInt32 b)
606 return {
607 vorrq_s32(a.simdInternal_, b.simdInternal_)
611 static inline SimdFInt32 gmx_simdcall
612 operator^(SimdFInt32 a, SimdFInt32 b)
614 return {
615 veorq_s32(a.simdInternal_, b.simdInternal_)
619 static inline SimdFInt32 gmx_simdcall
620 operator+(SimdFInt32 a, SimdFInt32 b)
622 return {
623 vaddq_s32(a.simdInternal_, b.simdInternal_)
627 static inline SimdFInt32 gmx_simdcall
628 operator-(SimdFInt32 a, SimdFInt32 b)
630 return {
631 vsubq_s32(a.simdInternal_, b.simdInternal_)
635 static inline SimdFInt32 gmx_simdcall
636 operator*(SimdFInt32 a, SimdFInt32 b)
638 return {
639 vmulq_s32(a.simdInternal_, b.simdInternal_)
643 static inline SimdFIBool gmx_simdcall
644 operator==(SimdFInt32 a, SimdFInt32 b)
646 return {
647 vceqq_s32(a.simdInternal_, b.simdInternal_)
651 static inline SimdFIBool gmx_simdcall
652 testBits(SimdFInt32 a)
654 return {
655 vtstq_s32(a.simdInternal_, a.simdInternal_)
659 static inline SimdFIBool gmx_simdcall
660 operator<(SimdFInt32 a, SimdFInt32 b)
662 return {
663 vcltq_s32(a.simdInternal_, b.simdInternal_)
667 static inline SimdFIBool gmx_simdcall
668 operator&&(SimdFIBool a, SimdFIBool b)
670 return {
671 vandq_u32(a.simdInternal_, b.simdInternal_)
675 static inline SimdFIBool gmx_simdcall
676 operator||(SimdFIBool a, SimdFIBool b)
678 return {
679 vorrq_u32(a.simdInternal_, b.simdInternal_)
683 // Override for Neon-Asimd
684 #if GMX_SIMD_ARM_NEON
685 static inline bool gmx_simdcall
686 anyTrue(SimdFIBool a)
688 uint32x4_t x = a.simdInternal_;
689 uint32x4_t y = vextq_u32(x, x, 2);
691 x = vorrq_u32(x, y);
692 y = vextq_u32(x, x, 1);
693 x = vorrq_u32(x, y);
694 return (vgetq_lane_u32(x, 0) != 0);
696 #endif
698 static inline SimdFInt32 gmx_simdcall
699 selectByMask(SimdFInt32 a, SimdFIBool m)
701 return {
702 vandq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_))
706 static inline SimdFInt32 gmx_simdcall
707 selectByNotMask(SimdFInt32 a, SimdFIBool m)
709 return {
710 vbicq_s32(a.simdInternal_, vreinterpretq_s32_u32(m.simdInternal_))
714 static inline SimdFInt32 gmx_simdcall
715 blend(SimdFInt32 a, SimdFInt32 b, SimdFIBool sel)
717 return {
718 vbslq_s32(sel.simdInternal_, b.simdInternal_, a.simdInternal_)
722 // Override for Neon-Asimd
723 #if GMX_SIMD_ARM_NEON
724 static inline SimdFInt32 gmx_simdcall
725 cvtR2I(SimdFloat a)
727 float32x4_t signBitOfA = vreinterpretq_f32_u32(vandq_u32(vdupq_n_u32(0x80000000), vreinterpretq_u32_f32(a.simdInternal_)));
728 float32x4_t half = vdupq_n_f32(0.5f);
729 float32x4_t corr = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(half), vreinterpretq_u32_f32(signBitOfA)));
731 return {
732 vcvtq_s32_f32(vaddq_f32(a.simdInternal_, corr))
735 #endif
737 static inline SimdFInt32 gmx_simdcall
738 cvttR2I(SimdFloat a)
740 return {
741 vcvtq_s32_f32(a.simdInternal_)
745 static inline SimdFloat gmx_simdcall
746 cvtI2R(SimdFInt32 a)
748 return {
749 vcvtq_f32_s32(a.simdInternal_)
753 static inline SimdFIBool gmx_simdcall
754 cvtB2IB(SimdFBool a)
756 return {
757 a.simdInternal_
761 static inline SimdFBool gmx_simdcall
762 cvtIB2B(SimdFIBool a)
764 return {
765 a.simdInternal_
769 // Override for Neon-Asimd
770 #if GMX_SIMD_ARM_NEON
771 static inline SimdFloat gmx_simdcall
772 round(SimdFloat x)
774 return cvtI2R(cvtR2I(x));
777 static inline SimdFloat gmx_simdcall
778 trunc(SimdFloat x)
780 return cvtI2R(cvttR2I(x));
782 #endif
784 } // namespace gmx
786 #endif // GMX_SIMD_IMPL_ARM_NEON_SIMD_FLOAT_H