128-bit AVX2 SIMD for AMD Ryzen
[gromacs.git] / src / gromacs / ewald / pme-solve.cpp
blobd5c0ec3b5be6b6374f76a35375445180a37fbf0a
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5 * Copyright (c) 2001-2004, The GROMACS development team.
6 * Copyright (c) 2013,2014,2015,2016,2017, by the GROMACS development team, led by
7 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8 * and including many others, as listed in the AUTHORS file in the
9 * top-level source directory and at http://www.gromacs.org.
11 * GROMACS is free software; you can redistribute it and/or
12 * modify it under the terms of the GNU Lesser General Public License
13 * as published by the Free Software Foundation; either version 2.1
14 * of the License, or (at your option) any later version.
16 * GROMACS is distributed in the hope that it will be useful,
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 * Lesser General Public License for more details.
21 * You should have received a copy of the GNU Lesser General Public
22 * License along with GROMACS; if not, see
23 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
26 * If you want to redistribute modifications to GROMACS, please
27 * consider that scientific software is very special. Version
28 * control is crucial - bugs must be traceable. We will be happy to
29 * consider code for inclusion in the official distribution, but
30 * derived work must not be called official GROMACS. Details are found
31 * in the README & COPYING files - if they are missing, get the
32 * official version at http://www.gromacs.org.
34 * To help us fund GROMACS development, we humbly ask that you cite
35 * the research papers on the package. Check out http://www.gromacs.org.
38 #include "gmxpre.h"
40 #include "pme-solve.h"
42 #include <cmath>
44 #include "gromacs/fft/parallel_3dfft.h"
45 #include "gromacs/math/units.h"
46 #include "gromacs/math/utilities.h"
47 #include "gromacs/math/vec.h"
48 #include "gromacs/simd/simd.h"
49 #include "gromacs/simd/simd_math.h"
50 #include "gromacs/utility/exceptions.h"
51 #include "gromacs/utility/smalloc.h"
53 #include "pme-internal.h"
#if GMX_SIMD_HAVE_REAL
/* Turn on arbitrary width SIMD intrinsics for PME solve.
 * When defined, the SIMD variants of calc_exponentials_q() and
 * calc_exponentials_lj() are compiled instead of the scalar ones.
 */
#    define PME_SIMD_SOLVE
#endif
60 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
62 struct pme_solve_work_t
64 /* work data for solve_pme */
65 int nalloc;
66 real * mhx;
67 real * mhy;
68 real * mhz;
69 real * m2;
70 real * denom;
71 real * tmp1_alloc;
72 real * tmp1;
73 real * tmp2;
74 real * eterm;
75 real * m2inv;
77 real energy_q;
78 matrix vir_q;
79 real energy_lj;
80 matrix vir_lj;
83 static void realloc_work(struct pme_solve_work_t *work, int nkx)
85 if (nkx > work->nalloc)
87 int simd_width, i;
89 work->nalloc = nkx;
90 srenew(work->mhx, work->nalloc);
91 srenew(work->mhy, work->nalloc);
92 srenew(work->mhz, work->nalloc);
93 srenew(work->m2, work->nalloc);
94 /* Allocate an aligned pointer for SIMD operations, including extra
95 * elements at the end for padding.
97 #ifdef PME_SIMD_SOLVE
98 simd_width = GMX_SIMD_REAL_WIDTH;
99 #else
100 /* We can use any alignment, apart from 0, so we use 4 */
101 simd_width = 4;
102 #endif
103 sfree_aligned(work->denom);
104 sfree_aligned(work->tmp1);
105 sfree_aligned(work->tmp2);
106 sfree_aligned(work->eterm);
107 snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
108 snew_aligned(work->tmp1, work->nalloc+simd_width, simd_width*sizeof(real));
109 snew_aligned(work->tmp2, work->nalloc+simd_width, simd_width*sizeof(real));
110 snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
111 srenew(work->m2inv, work->nalloc);
113 /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
114 * of simd padded elements.
116 for (i = 0; i < work->nalloc+simd_width; i++)
118 work->denom[i] = 1;
123 void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
125 int thread;
126 /* Use fft5d, order after FFT is y major, z, x minor */
128 snew(*work, nthread);
129 /* Allocate the work arrays thread local to optimize memory access */
130 #pragma omp parallel for num_threads(nthread) schedule(static)
131 for (thread = 0; thread < nthread; thread++)
135 realloc_work(&((*work)[thread]), nkx);
137 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
141 static void free_work(struct pme_solve_work_t *work)
143 if (work)
145 sfree(work->mhx);
146 sfree(work->mhy);
147 sfree(work->mhz);
148 sfree(work->m2);
149 sfree_aligned(work->denom);
150 sfree_aligned(work->tmp1);
151 sfree_aligned(work->tmp2);
152 sfree_aligned(work->eterm);
153 sfree(work->m2inv);
157 void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
159 if (*work)
161 for (int thread = 0; thread < nthread; thread++)
163 free_work(&(*work)[thread]);
166 sfree(*work);
167 *work = nullptr;
170 void get_pme_ener_vir_q(struct pme_solve_work_t *work, int nthread,
171 real *mesh_energy, matrix vir)
173 /* This function sums output over threads and should therefore
174 * only be called after thread synchronization.
176 int thread;
178 *mesh_energy = work[0].energy_q;
179 copy_mat(work[0].vir_q, vir);
181 for (thread = 1; thread < nthread; thread++)
183 *mesh_energy += work[thread].energy_q;
184 m_add(vir, work[thread].vir_q, vir);
188 void get_pme_ener_vir_lj(struct pme_solve_work_t *work, int nthread,
189 real *mesh_energy, matrix vir)
191 /* This function sums output over threads and should therefore
192 * only be called after thread synchronization.
194 int thread;
196 *mesh_energy = work[0].energy_lj;
197 copy_mat(work[0].vir_lj, vir);
199 for (thread = 1; thread < nthread; thread++)
201 *mesh_energy += work[thread].energy_lj;
202 m_add(vir, work[thread].vir_lj, vir);
206 #if defined PME_SIMD_SOLVE
207 /* Calculate exponentials through SIMD */
208 gmx_inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
211 SimdReal f_simd(f);
212 SimdReal tmp_d1, tmp_r, tmp_e;
213 int kx;
215 /* We only need to calculate from start. But since start is 0 or 1
216 * and we want to use aligned loads/stores, we always start from 0.
218 for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
220 tmp_d1 = load(d_aligned+kx);
221 tmp_r = load(r_aligned+kx);
222 tmp_r = gmx::exp(tmp_r);
223 tmp_e = f_simd / tmp_d1;
224 tmp_e = tmp_e * tmp_r;
225 store(e_aligned+kx, tmp_e);
229 #else
230 gmx_inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
232 int kx;
233 for (kx = start; kx < end; kx++)
235 d[kx] = 1.0/d[kx];
237 for (kx = start; kx < end; kx++)
239 r[kx] = std::exp(r[kx]);
241 for (kx = start; kx < end; kx++)
243 e[kx] = f*r[kx]*d[kx];
246 #endif
248 #if defined PME_SIMD_SOLVE
249 /* Calculate exponentials through SIMD */
250 gmx_inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
252 SimdReal tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
253 const SimdReal sqr_PI = sqrt(SimdReal(M_PI));
254 int kx;
255 for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
257 /* We only need to calculate from start. But since start is 0 or 1
258 * and we want to use aligned loads/stores, we always start from 0.
260 tmp_d = load(d_aligned+kx);
261 d_inv = SimdReal(1.0) / tmp_d;
262 store(d_aligned+kx, d_inv);
263 tmp_r = load(r_aligned+kx);
264 tmp_r = gmx::exp(tmp_r);
265 store(r_aligned+kx, tmp_r);
266 tmp_mk = load(factor_aligned+kx);
267 tmp_fac = sqr_PI * tmp_mk * erfc(tmp_mk);
268 store(factor_aligned+kx, tmp_fac);
271 #else
272 gmx_inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
274 int kx;
275 real mk;
276 for (kx = start; kx < end; kx++)
278 d[kx] = 1.0/d[kx];
281 for (kx = start; kx < end; kx++)
283 r[kx] = std::exp(r[kx]);
286 for (kx = start; kx < end; kx++)
288 mk = tmp2[kx];
289 tmp2[kx] = sqrt(M_PI)*mk*std::erfc(mk);
292 #endif
294 int solve_pme_yzx(struct gmx_pme_t *pme, t_complex *grid, real vol,
295 gmx_bool bEnerVir,
296 int nthread, int thread)
298 /* do recip sum over local cells in grid */
299 /* y major, z middle, x minor or continuous */
300 t_complex *p0;
301 int kx, ky, kz, maxkx, maxky;
302 int nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
303 real mx, my, mz;
304 real ewaldcoeff = pme->ewaldcoeff_q;
305 real factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
306 real ets2, struct2, vfactor, ets2vf;
307 real d1, d2, energy = 0;
308 real by, bz;
309 real virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
310 real rxx, ryx, ryy, rzx, rzy, rzz;
311 struct pme_solve_work_t *work;
312 real *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
313 real mhxk, mhyk, mhzk, m2k;
314 real corner_fac;
315 ivec complex_order;
316 ivec local_ndata, local_offset, local_size;
317 real elfac;
319 elfac = ONE_4PI_EPS0/pme->epsilon_r;
321 nx = pme->nkx;
322 ny = pme->nky;
323 nz = pme->nkz;
325 /* Dimensions should be identical for A/B grid, so we just use A here */
326 gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
327 complex_order,
328 local_ndata,
329 local_offset,
330 local_size);
332 rxx = pme->recipbox[XX][XX];
333 ryx = pme->recipbox[YY][XX];
334 ryy = pme->recipbox[YY][YY];
335 rzx = pme->recipbox[ZZ][XX];
336 rzy = pme->recipbox[ZZ][YY];
337 rzz = pme->recipbox[ZZ][ZZ];
339 maxkx = (nx+1)/2;
340 maxky = (ny+1)/2;
342 work = &pme->solve_work[thread];
343 mhx = work->mhx;
344 mhy = work->mhy;
345 mhz = work->mhz;
346 m2 = work->m2;
347 denom = work->denom;
348 tmp1 = work->tmp1;
349 eterm = work->eterm;
350 m2inv = work->m2inv;
352 iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread /nthread;
353 iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
355 for (iyz = iyz0; iyz < iyz1; iyz++)
357 iy = iyz/local_ndata[ZZ];
358 iz = iyz - iy*local_ndata[ZZ];
360 ky = iy + local_offset[YY];
362 if (ky < maxky)
364 my = ky;
366 else
368 my = (ky - ny);
371 by = M_PI*vol*pme->bsp_mod[YY][ky];
373 kz = iz + local_offset[ZZ];
375 mz = kz;
377 bz = pme->bsp_mod[ZZ][kz];
379 /* 0.5 correction for corner points */
380 corner_fac = 1;
381 if (kz == 0 || kz == (nz+1)/2)
383 corner_fac = 0.5;
386 p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
388 /* We should skip the k-space point (0,0,0) */
389 /* Note that since here x is the minor index, local_offset[XX]=0 */
390 if (local_offset[XX] > 0 || ky > 0 || kz > 0)
392 kxstart = local_offset[XX];
394 else
396 kxstart = local_offset[XX] + 1;
397 p0++;
399 kxend = local_offset[XX] + local_ndata[XX];
401 if (bEnerVir)
403 /* More expensive inner loop, especially because of the storage
404 * of the mh elements in array's.
405 * Because x is the minor grid index, all mh elements
406 * depend on kx for triclinic unit cells.
409 /* Two explicit loops to avoid a conditional inside the loop */
410 for (kx = kxstart; kx < maxkx; kx++)
412 mx = kx;
414 mhxk = mx * rxx;
415 mhyk = mx * ryx + my * ryy;
416 mhzk = mx * rzx + my * rzy + mz * rzz;
417 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
418 mhx[kx] = mhxk;
419 mhy[kx] = mhyk;
420 mhz[kx] = mhzk;
421 m2[kx] = m2k;
422 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
423 tmp1[kx] = -factor*m2k;
426 for (kx = maxkx; kx < kxend; kx++)
428 mx = (kx - nx);
430 mhxk = mx * rxx;
431 mhyk = mx * ryx + my * ryy;
432 mhzk = mx * rzx + my * rzy + mz * rzz;
433 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
434 mhx[kx] = mhxk;
435 mhy[kx] = mhyk;
436 mhz[kx] = mhzk;
437 m2[kx] = m2k;
438 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
439 tmp1[kx] = -factor*m2k;
442 for (kx = kxstart; kx < kxend; kx++)
444 m2inv[kx] = 1.0/m2[kx];
447 calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
449 for (kx = kxstart; kx < kxend; kx++, p0++)
451 d1 = p0->re;
452 d2 = p0->im;
454 p0->re = d1*eterm[kx];
455 p0->im = d2*eterm[kx];
457 struct2 = 2.0*(d1*d1+d2*d2);
459 tmp1[kx] = eterm[kx]*struct2;
462 for (kx = kxstart; kx < kxend; kx++)
464 ets2 = corner_fac*tmp1[kx];
465 vfactor = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
466 energy += ets2;
468 ets2vf = ets2*vfactor;
469 virxx += ets2vf*mhx[kx]*mhx[kx] - ets2;
470 virxy += ets2vf*mhx[kx]*mhy[kx];
471 virxz += ets2vf*mhx[kx]*mhz[kx];
472 viryy += ets2vf*mhy[kx]*mhy[kx] - ets2;
473 viryz += ets2vf*mhy[kx]*mhz[kx];
474 virzz += ets2vf*mhz[kx]*mhz[kx] - ets2;
477 else
479 /* We don't need to calculate the energy and the virial.
480 * In this case the triclinic overhead is small.
483 /* Two explicit loops to avoid a conditional inside the loop */
485 for (kx = kxstart; kx < maxkx; kx++)
487 mx = kx;
489 mhxk = mx * rxx;
490 mhyk = mx * ryx + my * ryy;
491 mhzk = mx * rzx + my * rzy + mz * rzz;
492 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
493 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
494 tmp1[kx] = -factor*m2k;
497 for (kx = maxkx; kx < kxend; kx++)
499 mx = (kx - nx);
501 mhxk = mx * rxx;
502 mhyk = mx * ryx + my * ryy;
503 mhzk = mx * rzx + my * rzy + mz * rzz;
504 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
505 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
506 tmp1[kx] = -factor*m2k;
509 calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
511 for (kx = kxstart; kx < kxend; kx++, p0++)
513 d1 = p0->re;
514 d2 = p0->im;
516 p0->re = d1*eterm[kx];
517 p0->im = d2*eterm[kx];
522 if (bEnerVir)
524 /* Update virial with local values.
525 * The virial is symmetric by definition.
526 * this virial seems ok for isotropic scaling, but I'm
527 * experiencing problems on semiisotropic membranes.
528 * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
530 work->vir_q[XX][XX] = 0.25*virxx;
531 work->vir_q[YY][YY] = 0.25*viryy;
532 work->vir_q[ZZ][ZZ] = 0.25*virzz;
533 work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
534 work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
535 work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
537 /* This energy should be corrected for a charged system */
538 work->energy_q = 0.5*energy;
541 /* Return the loop count */
542 return local_ndata[YY]*local_ndata[XX];
545 int solve_pme_lj_yzx(struct gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real vol,
546 gmx_bool bEnerVir, int nthread, int thread)
548 /* do recip sum over local cells in grid */
549 /* y major, z middle, x minor or continuous */
550 int ig, gcount;
551 int kx, ky, kz, maxkx, maxky;
552 int nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
553 real mx, my, mz;
554 real ewaldcoeff = pme->ewaldcoeff_lj;
555 real factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
556 real ets2, ets2vf;
557 real eterm, vterm, d1, d2, energy = 0;
558 real by, bz;
559 real virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
560 real rxx, ryx, ryy, rzx, rzy, rzz;
561 real *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
562 real mhxk, mhyk, mhzk, m2k;
563 struct pme_solve_work_t *work;
564 real corner_fac;
565 ivec complex_order;
566 ivec local_ndata, local_offset, local_size;
567 nx = pme->nkx;
568 ny = pme->nky;
569 nz = pme->nkz;
571 /* Dimensions should be identical for A/B grid, so we just use A here */
572 gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
573 complex_order,
574 local_ndata,
575 local_offset,
576 local_size);
577 rxx = pme->recipbox[XX][XX];
578 ryx = pme->recipbox[YY][XX];
579 ryy = pme->recipbox[YY][YY];
580 rzx = pme->recipbox[ZZ][XX];
581 rzy = pme->recipbox[ZZ][YY];
582 rzz = pme->recipbox[ZZ][ZZ];
584 maxkx = (nx+1)/2;
585 maxky = (ny+1)/2;
587 work = &pme->solve_work[thread];
588 mhx = work->mhx;
589 mhy = work->mhy;
590 mhz = work->mhz;
591 m2 = work->m2;
592 denom = work->denom;
593 tmp1 = work->tmp1;
594 tmp2 = work->tmp2;
596 iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread /nthread;
597 iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
599 for (iyz = iyz0; iyz < iyz1; iyz++)
601 iy = iyz/local_ndata[ZZ];
602 iz = iyz - iy*local_ndata[ZZ];
604 ky = iy + local_offset[YY];
606 if (ky < maxky)
608 my = ky;
610 else
612 my = (ky - ny);
615 by = 3.0*vol*pme->bsp_mod[YY][ky]
616 / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
618 kz = iz + local_offset[ZZ];
620 mz = kz;
622 bz = pme->bsp_mod[ZZ][kz];
624 /* 0.5 correction for corner points */
625 corner_fac = 1;
626 if (kz == 0 || kz == (nz+1)/2)
628 corner_fac = 0.5;
631 kxstart = local_offset[XX];
632 kxend = local_offset[XX] + local_ndata[XX];
633 if (bEnerVir)
635 /* More expensive inner loop, especially because of the
636 * storage of the mh elements in array's. Because x is the
637 * minor grid index, all mh elements depend on kx for
638 * triclinic unit cells.
641 /* Two explicit loops to avoid a conditional inside the loop */
642 for (kx = kxstart; kx < maxkx; kx++)
644 mx = kx;
646 mhxk = mx * rxx;
647 mhyk = mx * ryx + my * ryy;
648 mhzk = mx * rzx + my * rzy + mz * rzz;
649 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
650 mhx[kx] = mhxk;
651 mhy[kx] = mhyk;
652 mhz[kx] = mhzk;
653 m2[kx] = m2k;
654 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
655 tmp1[kx] = -factor*m2k;
656 tmp2[kx] = sqrt(factor*m2k);
659 for (kx = maxkx; kx < kxend; kx++)
661 mx = (kx - nx);
663 mhxk = mx * rxx;
664 mhyk = mx * ryx + my * ryy;
665 mhzk = mx * rzx + my * rzy + mz * rzz;
666 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
667 mhx[kx] = mhxk;
668 mhy[kx] = mhyk;
669 mhz[kx] = mhzk;
670 m2[kx] = m2k;
671 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
672 tmp1[kx] = -factor*m2k;
673 tmp2[kx] = sqrt(factor*m2k);
676 calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
678 for (kx = kxstart; kx < kxend; kx++)
680 m2k = factor*m2[kx];
681 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
682 + 2.0*m2k*tmp2[kx]);
683 vterm = 3.0*(-tmp1[kx] + tmp2[kx]);
684 tmp1[kx] = eterm*denom[kx];
685 tmp2[kx] = vterm*denom[kx];
688 if (!bLB)
690 t_complex *p0;
691 real struct2;
693 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
694 for (kx = kxstart; kx < kxend; kx++, p0++)
696 d1 = p0->re;
697 d2 = p0->im;
699 eterm = tmp1[kx];
700 vterm = tmp2[kx];
701 p0->re = d1*eterm;
702 p0->im = d2*eterm;
704 struct2 = 2.0*(d1*d1+d2*d2);
706 tmp1[kx] = eterm*struct2;
707 tmp2[kx] = vterm*struct2;
710 else
712 real *struct2 = denom;
713 real str2;
715 for (kx = kxstart; kx < kxend; kx++)
717 struct2[kx] = 0.0;
719 /* Due to symmetry we only need to calculate 4 of the 7 terms */
720 for (ig = 0; ig <= 3; ++ig)
722 t_complex *p0, *p1;
723 real scale;
725 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
726 p1 = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
727 scale = 2.0*lb_scale_factor_symm[ig];
728 for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
730 struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
734 for (ig = 0; ig <= 6; ++ig)
736 t_complex *p0;
738 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
739 for (kx = kxstart; kx < kxend; kx++, p0++)
741 d1 = p0->re;
742 d2 = p0->im;
744 eterm = tmp1[kx];
745 p0->re = d1*eterm;
746 p0->im = d2*eterm;
749 for (kx = kxstart; kx < kxend; kx++)
751 eterm = tmp1[kx];
752 vterm = tmp2[kx];
753 str2 = struct2[kx];
754 tmp1[kx] = eterm*str2;
755 tmp2[kx] = vterm*str2;
759 for (kx = kxstart; kx < kxend; kx++)
761 ets2 = corner_fac*tmp1[kx];
762 vterm = 2.0*factor*tmp2[kx];
763 energy += ets2;
764 ets2vf = corner_fac*vterm;
765 virxx += ets2vf*mhx[kx]*mhx[kx] - ets2;
766 virxy += ets2vf*mhx[kx]*mhy[kx];
767 virxz += ets2vf*mhx[kx]*mhz[kx];
768 viryy += ets2vf*mhy[kx]*mhy[kx] - ets2;
769 viryz += ets2vf*mhy[kx]*mhz[kx];
770 virzz += ets2vf*mhz[kx]*mhz[kx] - ets2;
773 else
775 /* We don't need to calculate the energy and the virial.
776 * In this case the triclinic overhead is small.
779 /* Two explicit loops to avoid a conditional inside the loop */
781 for (kx = kxstart; kx < maxkx; kx++)
783 mx = kx;
785 mhxk = mx * rxx;
786 mhyk = mx * ryx + my * ryy;
787 mhzk = mx * rzx + my * rzy + mz * rzz;
788 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
789 m2[kx] = m2k;
790 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
791 tmp1[kx] = -factor*m2k;
792 tmp2[kx] = sqrt(factor*m2k);
795 for (kx = maxkx; kx < kxend; kx++)
797 mx = (kx - nx);
799 mhxk = mx * rxx;
800 mhyk = mx * ryx + my * ryy;
801 mhzk = mx * rzx + my * rzy + mz * rzz;
802 m2k = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
803 m2[kx] = m2k;
804 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
805 tmp1[kx] = -factor*m2k;
806 tmp2[kx] = sqrt(factor*m2k);
809 calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
811 for (kx = kxstart; kx < kxend; kx++)
813 m2k = factor*m2[kx];
814 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
815 + 2.0*m2k*tmp2[kx]);
816 tmp1[kx] = eterm*denom[kx];
818 gcount = (bLB ? 7 : 1);
819 for (ig = 0; ig < gcount; ++ig)
821 t_complex *p0;
823 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
824 for (kx = kxstart; kx < kxend; kx++, p0++)
826 d1 = p0->re;
827 d2 = p0->im;
829 eterm = tmp1[kx];
831 p0->re = d1*eterm;
832 p0->im = d2*eterm;
837 if (bEnerVir)
839 work->vir_lj[XX][XX] = 0.25*virxx;
840 work->vir_lj[YY][YY] = 0.25*viryy;
841 work->vir_lj[ZZ][ZZ] = 0.25*virzz;
842 work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
843 work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
844 work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
846 /* This energy should be corrected for a charged system */
847 work->energy_lj = 0.5*energy;
849 /* Return the loop count */
850 return local_ndata[YY]*local_ndata[XX];