Double precision AVX-256 kernels
[gromacs.git] / include / gmx_x86_avx_256.h
blob6a561fa04b5711dbd0d2a6e996327030c620bb3d
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
3 *
4 * This file is part of GROMACS.
5 * Copyright (c) 2012-
7 * Written by the Gromacs development team under coordination of
8 * David van der Spoel, Berk Hess, and Erik Lindahl.
10 * This library is free software; you can redistribute it and/or
11 * modify it under the terms of the GNU Lesser General Public License
12 * as published by the Free Software Foundation; either version 2
13 * of the License, or (at your option) any later version.
15 * To help us fund GROMACS development, we humbly ask that you cite
16 * the research papers on the package. Check out http://www.gromacs.org
18 * And Hey:
19 * Gnomes, ROck Monsters And Chili Sauce
21 #ifndef _gmx_x86_avx_256_h_
22 #define _gmx_x86_avx_256_h_
25 #include <immintrin.h>
26 #ifdef HAVE_X86INTRIN_H
27 #include <x86intrin.h> /* FMA */
28 #endif
31 #include <stdio.h>
33 #include "types/simple.h"
/* Read 32-bit integer lane `imm` (0..3, must be a compile-time constant)
 * out of an __m128i: shift the register right by imm*4 bytes so the wanted
 * lane lands in lane 0, then move lane 0 to a scalar int. */
#define gmx_mm_extract_epi32(x, imm) \
    _mm_cvtsi128_si32(_mm_srli_si128((x), (imm) * 4))

/* Pack four 1-bit selectors into a 4-bit immediate, b0 in bit 0 ... b3 in bit 3. */
#define _GMX_MM_BLEND256D(b3, b2, b1, b0) \
    ((b0) | ((b1) << 1) | ((b2) << 2) | ((b3) << 3))

/* Pack four 2-bit lane selectors into an 8-bit immediate, fp0 in bits 1:0
 * ... fp3 in bits 7:6 (same layout as _MM_SHUFFLE from xmmintrin.h). */
#define _GMX_MM_PERMUTE(fp3, fp2, fp1, fp0) \
    ((fp0) | ((fp1) << 2) | ((fp2) << 4) | ((fp3) << 6))

/* Pack four 1-bit selectors into a 4-bit immediate, fp0 in bit 0 ... fp3 in bit 3. */
#define _GMX_MM_PERMUTE256D(fp3, fp2, fp1, fp0) \
    ((fp0) | ((fp1) << 1) | ((fp2) << 2) | ((fp3) << 3))

/* Pack two 1-bit selectors into a 2-bit immediate, fp0 in bit 0, fp1 in bit 1. */
#define _GMX_MM_PERMUTE128D(fp1, fp0) \
    ((fp0) | ((fp1) << 1))
/* In-place transpose of a 2x2 matrix of doubles stored as two __m128d rows.
 * row0 and row1 must be modifiable lvalues and are evaluated more than once.
 * On exit: row0 = { row0[0], row1[0] },  row1 = { row0[1], row1[1] }.
 * The temporary uses a prefixed, non-reserved name so it cannot collide with
 * the implementation's namespace or with typical caller variables.
 */
#define GMX_MM_TRANSPOSE2_PD(row0, row1) {                              \
        __m128d gmx_mm_tr2_save0_ = (row0);                             \
        (row0) = _mm_unpacklo_pd((row0), (row1));                       \
        (row1) = _mm_unpackhi_pd(gmx_mm_tr2_save0_, (row1));            \
}

/* In-place transpose of a 4x4 matrix of doubles stored as four __m256d rows.
 * All arguments must be modifiable lvalues and are evaluated more than once.
 * Step 1 interleaves row pairs within each 128-bit lane (unpacklo/unpackhi);
 * step 2 recombines 128-bit lanes with _mm256_permute2f128_pd:
 *   0x20 takes the low lane of both operands, 0x31 the high lane of both.
 * On exit row k holds column k of the original matrix.
 */
#define GMX_MM256_FULLTRANSPOSE4_PD(row0, row1, row2, row3)                 \
{                                                                           \
    __m256d gmx_mm256_tr4_t0_, gmx_mm256_tr4_t1_;                           \
    __m256d gmx_mm256_tr4_t2_, gmx_mm256_tr4_t3_;                           \
    gmx_mm256_tr4_t0_ = _mm256_unpacklo_pd((row0), (row1));                 \
    gmx_mm256_tr4_t1_ = _mm256_unpackhi_pd((row0), (row1));                 \
    gmx_mm256_tr4_t2_ = _mm256_unpacklo_pd((row2), (row3));                 \
    gmx_mm256_tr4_t3_ = _mm256_unpackhi_pd((row2), (row3));                 \
    (row0) = _mm256_permute2f128_pd(gmx_mm256_tr4_t0_, gmx_mm256_tr4_t2_, 0x20); \
    (row1) = _mm256_permute2f128_pd(gmx_mm256_tr4_t1_, gmx_mm256_tr4_t3_, 0x20); \
    (row2) = _mm256_permute2f128_pd(gmx_mm256_tr4_t0_, gmx_mm256_tr4_t2_, 0x31); \
    (row3) = _mm256_permute2f128_pd(gmx_mm256_tr4_t1_, gmx_mm256_tr4_t3_, 0x31); \
}
/* Bit-preserving reinterpretation casts between 128-bit SIMD register types.
 * No value conversion happens; the 128 bits are reused as the new type.
 *  - MSVC / Intel: forward to the Intel-defined _mm_cast* intrinsics.
 *  - GCC-compatible: plain vector casts.
 *  - Anything else: static functions wrapping the _mm_cast* intrinsics.
 *    The previous fallback dereferenced a pointer of a different vector type
 *    (e.g. *(__m128 *)&a), which violates C's strict-aliasing rule. Any
 *    compiler that provides the AVX cast intrinsics used elsewhere in this
 *    header (_mm256_castps128_ps256, _mm256_castpd128_pd256) also provides
 *    the 128-bit _mm_cast* family.
 */
#if (defined (_MSC_VER) || defined(__INTEL_COMPILER))
#    define gmx_mm_castsi128_ps(a) _mm_castsi128_ps(a)
#    define gmx_mm_castps_si128(a) _mm_castps_si128(a)
#    define gmx_mm_castps_ps128(a) (a)
#    define gmx_mm_castsi128_pd(a) _mm_castsi128_pd(a)
#    define gmx_mm_castpd_si128(a) _mm_castpd_si128(a)
#elif defined(__GNUC__)
#    define gmx_mm_castsi128_ps(a) ((__m128)(a))
#    define gmx_mm_castps_si128(a) ((__m128i)(a))
#    define gmx_mm_castps_ps128(a) ((__m128)(a))
#    define gmx_mm_castsi128_pd(a) ((__m128d)(a))
#    define gmx_mm_castpd_si128(a) ((__m128i)(a))
#else
static __m128 gmx_mm_castsi128_ps(__m128i a)
{
    return _mm_castsi128_ps(a);
}
static __m128i gmx_mm_castps_si128(__m128 a)
{
    return _mm_castps_si128(a);
}
/* Identity cast: kept so callers can use one spelling on every compiler. */
static __m128 gmx_mm_castps_ps128(__m128 a)
{
    return a;
}
static __m128d gmx_mm_castsi128_pd(__m128i a)
{
    return _mm_castsi128_pd(a);
}
static __m128i gmx_mm_castpd_si128(__m128d a)
{
    return _mm_castpd_si128(a);
}
#endif
98 static gmx_inline __m256
99 gmx_mm256_unpack128lo_ps(__m256 xmm1, __m256 xmm2)
101 return _mm256_permute2f128_ps(xmm1,xmm2,0x20);
104 static gmx_inline __m256
105 gmx_mm256_unpack128hi_ps(__m256 xmm1, __m256 xmm2)
107 return _mm256_permute2f128_ps(xmm1,xmm2,0x31);
110 static gmx_inline __m256
111 gmx_mm256_set_m128(__m128 hi, __m128 lo)
113 return _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 0x1);
117 static __m256d
118 gmx_mm256_unpack128lo_pd(__m256d xmm1, __m256d xmm2)
120 return _mm256_permute2f128_pd(xmm1,xmm2,0x20);
123 static __m256d
124 gmx_mm256_unpack128hi_pd(__m256d xmm1, __m256d xmm2)
126 return _mm256_permute2f128_pd(xmm1,xmm2,0x31);
129 static __m256d
130 gmx_mm256_set_m128d(__m128d hi, __m128d lo)
132 return _mm256_insertf128_pd(_mm256_castpd128_pd256(lo), hi, 0x1);
/* Debug aid: print the label followed by the four float lanes of an __m128,
 * lane 0 first, in %15.10e format. */
static void
gmx_mm_printxmm_ps(const char *s, __m128 xmm)
{
    float lane[4];

    _mm_storeu_ps(lane, xmm);
    printf("%s: %15.10e %15.10e %15.10e %15.10e\n",
           s,
           lane[0], lane[1], lane[2], lane[3]);
}
/* Debug aid: print the label and the sum of the four float lanes of an
 * __m128, added left to right starting from lane 0. */
static void
gmx_mm_printxmmsum_ps(const char *s, __m128 xmm)
{
    float lane[4];

    _mm_storeu_ps(lane, xmm);
    printf("%s (sum): %15.10g\n",
           s,
           lane[0] + lane[1] + lane[2] + lane[3]);
}
/* Debug aid: print the label followed by both double lanes of an __m128d,
 * lane 0 first, in %30.20e format. */
static void
gmx_mm_printxmm_pd(const char *s, __m128d xmm)
{
    double lane[2];

    _mm_storeu_pd(lane, xmm);
    printf("%s: %30.20e %30.20e\n",
           s,
           lane[0], lane[1]);
}
/* Debug aid: print the label and the sum of the two double lanes of an
 * __m128d. */
static void
gmx_mm_printxmmsum_pd(const char *s, __m128d xmm)
{
    double lane[2];

    _mm_storeu_pd(lane, xmm);
    printf("%s (sum): %15.10g\n",
           s,
           lane[0] + lane[1]);
}
/* Debug aid: print the label (right-aligned in 10 columns) followed by the
 * four 32-bit integer lanes of an __m128i, lane 0 first. */
static void
gmx_mm_printxmm_epi32(const char *s, __m128i xmmi)
{
    int lane[4];

    _mm_storeu_si128((__m128i *) lane, xmmi);
    printf("%10s: %2d %2d %2d %2d\n",
           s,
           lane[0], lane[1], lane[2], lane[3]);
}
/* Debug aid: print the label followed by the eight float lanes of an __m256,
 * lane 0 first, in %12.7f format. */
static void
gmx_mm256_printymm_ps(const char *s, __m256 ymm)
{
    float lane[8];

    _mm256_storeu_ps(lane, ymm);
    printf("%s: %12.7f %12.7f %12.7f %12.7f %12.7f %12.7f %12.7f %12.7f\n",
           s,
           lane[0], lane[1], lane[2], lane[3],
           lane[4], lane[5], lane[6], lane[7]);
}
/* Debug aid: print the label and the sum of the eight float lanes of an
 * __m256, added left to right starting from lane 0. */
static void
gmx_mm256_printymmsum_ps(const char *s, __m256 ymm)
{
    float lane[8];

    _mm256_storeu_ps(lane, ymm);
    printf("%s (sum): %15.10g\n",
           s,
           lane[0] + lane[1] + lane[2] + lane[3]
           + lane[4] + lane[5] + lane[6] + lane[7]);
}
/* Debug aid: print the label followed by the four double lanes of an
 * __m256d, lane 0 first, in %16.12f format. */
static void
gmx_mm256_printymm_pd(const char *s, __m256d ymm)
{
    double lane[4];

    _mm256_storeu_pd(lane, ymm);
    printf("%s: %16.12f %16.12f %16.12f %16.12f\n",
           s,
           lane[0], lane[1], lane[2], lane[3]);
}
/* Debug aid: print the label and the sum of the four double lanes of an
 * __m256d, added left to right starting from lane 0. */
static void
gmx_mm256_printymmsum_pd(const char *s, __m256d ymm)
{
    double lane[4];

    _mm256_storeu_pd(lane, ymm);
    printf("%s (sum): %15.10g\n",
           s,
           lane[0] + lane[1] + lane[2] + lane[3]);
}
/* Debug aid: print the label (right-aligned in 10 columns) followed by the
 * eight 32-bit integer lanes of an __m256i, lane 0 first. */
static void
gmx_mm256_printymm_epi32(const char *s, __m256i ymmi)
{
    int lane[8];

    _mm256_storeu_si256((__m256i *) lane, ymmi);
    printf("%10s: %2d %2d %2d %2d %2d %2d %2d %2d\n",
           s,
           lane[0], lane[1], lane[2], lane[3],
           lane[4], lane[5], lane[6], lane[7]);
}
/* Check the SIMD floating-point overflow flag (bit 3 of MXCSR) and clear it.
 *
 * Returns 1 if the flag was set, after writing MXCSR back with that bit
 * cleared (the other low 16 bits are written back unchanged). Returns 0
 * and does not write MXCSR if the flag was clear.
 */
static int gmx_mm_check_and_reset_overflow(void)
{
    const int overflow_mask = 0x0008;   /* MXCSR bit 3: overflow flag */
    int       csr           = _mm_getcsr();

    if ((csr & overflow_mask) == 0)
    {
        return 0;
    }

    /* 0xFFF7 keeps every other bit of the low 16 and drops only bit 3. */
    _mm_setcsr(csr & 0xFFF7);
    return 1;
}
260 #endif /* _gmx_x86_avx_256_h_ */