Update instructions in containers.rst
[gromacs.git] / src / gromacs / statistics / statistics.cpp
blob5e1315cfeb1e2869763e98bfc0fa88518b92b220
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5 * Copyright (c) 2001-2004, The GROMACS development team.
6 * Copyright (c) 2012,2014,2015,2017,2018 by the GROMACS development team.
7 * Copyright (c) 2019,2020, by the GROMACS development team, led by
8 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9 * and including many others, as listed in the AUTHORS file in the
10 * top-level source directory and at http://www.gromacs.org.
12 * GROMACS is free software; you can redistribute it and/or
13 * modify it under the terms of the GNU Lesser General Public License
14 * as published by the Free Software Foundation; either version 2.1
15 * of the License, or (at your option) any later version.
17 * GROMACS is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty of
19 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20 * Lesser General Public License for more details.
22 * You should have received a copy of the GNU Lesser General Public
23 * License along with GROMACS; if not, see
24 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
27 * If you want to redistribute modifications to GROMACS, please
28 * consider that scientific software is very special. Version
29 * control is crucial - bugs must be traceable. We will be happy to
30 * consider code for inclusion in the official distribution, but
31 * derived work must not be called official GROMACS. Details are found
32 * in the README & COPYING files - if they are missing, get the
33 * official version at http://www.gromacs.org.
35 * To help us fund GROMACS development, we humbly ask that you cite
36 * the research papers on the package. Check out http://www.gromacs.org.
38 #include "gmxpre.h"
40 #include "statistics.h"
42 #include <cmath>
44 #include "gromacs/math/functions.h"
45 #include "gromacs/math/vec.h"
46 #include "gromacs/utility/fatalerror.h"
47 #include "gromacs/utility/real.h"
48 #include "gromacs/utility/smalloc.h"
50 static int gmx_dnint(double x)
52 return gmx::roundToInt(x);
/* Container for a set of (x, y, dx, dy) data points together with the
   statistics derived from them by gmx_stats_compute(). */
typedef struct gmx_stats
{
    /* Fit results written by gmx_stats_compute() */
    double aa;         /* slope of the line y = aa*x (zero intercept) */
    double a;          /* slope of the line y = a*x + b */
    double b;          /* intercept of the line y = a*x + b */
    double sigma_aa;   /* NOTE(review): never assigned in the visible code */
    double sigma_a;    /* uncertainty estimate for a */
    double sigma_b;    /* uncertainty estimate for b */
    double aver;       /* unweighted average of y */
    double sigma_aver; /* standard deviation of y */
    double error;      /* sigma_aver / sqrt(np) */
    double rmsd;       /* root mean square deviation between x and y */
    double Rdata;      /* correlation coefficient of the raw data */
    double Rfit;       /* correlation coefficient reported for y = a*x + b */
    double Rfitaa;     /* correlation coefficient reported for y = aa*x */
    double chi2;       /* deviation from the line y = a*x + b */
    double chi2aa;     /* deviation from the line y = aa*x */
    /* Data arrays, each holding nalloc elements of which np are in use */
    double* x;
    double* y;
    double* dx;
    double* dy;
    int     computed; /* non-zero when the cached results above are valid */
    int     np;       /* number of stored points */
    int     np_c;     /* cursor used by gmx_stats_get_point() */
    int     nalloc;   /* allocated length of the data arrays */
} gmx_stats;
64 gmx_stats_t gmx_stats_init()
66 gmx_stats* stats;
68 snew(stats, 1);
70 return static_cast<gmx_stats_t>(stats);
73 int gmx_stats_get_npoints(gmx_stats_t gstats, int* N)
75 gmx_stats* stats = static_cast<gmx_stats*>(gstats);
77 *N = stats->np;
79 return estatsOK;
82 void gmx_stats_free(gmx_stats_t gstats)
84 gmx_stats* stats = static_cast<gmx_stats*>(gstats);
86 sfree(stats->x);
87 sfree(stats->y);
88 sfree(stats->dx);
89 sfree(stats->dy);
90 sfree(stats);
93 int gmx_stats_add_point(gmx_stats_t gstats, double x, double y, double dx, double dy)
95 gmx_stats* stats = gstats;
97 if (stats->np + 1 >= stats->nalloc)
99 if (stats->nalloc == 0)
101 stats->nalloc = 1024;
103 else
105 stats->nalloc *= 2;
107 srenew(stats->x, stats->nalloc);
108 srenew(stats->y, stats->nalloc);
109 srenew(stats->dx, stats->nalloc);
110 srenew(stats->dy, stats->nalloc);
111 for (int i = stats->np; (i < stats->nalloc); i++)
113 stats->x[i] = 0;
114 stats->y[i] = 0;
115 stats->dx[i] = 0;
116 stats->dy[i] = 0;
119 stats->x[stats->np] = x;
120 stats->y[stats->np] = y;
121 stats->dx[stats->np] = dx;
122 stats->dy[stats->np] = dy;
123 stats->np++;
124 stats->computed = 0;
126 return estatsOK;
129 int gmx_stats_get_point(gmx_stats_t gstats, real* x, real* y, real* dx, real* dy, real level)
131 gmx_stats* stats = gstats;
132 int ok, outlier;
133 real rmsd, r;
135 if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
137 return ok;
139 outlier = 0;
140 while ((outlier == 0) && (stats->np_c < stats->np))
142 r = std::abs(stats->x[stats->np_c] - stats->y[stats->np_c]);
143 outlier = static_cast<int>(r > rmsd * level);
144 if (outlier)
146 if (nullptr != x)
148 *x = stats->x[stats->np_c];
150 if (nullptr != y)
152 *y = stats->y[stats->np_c];
154 if (nullptr != dx)
156 *dx = stats->dx[stats->np_c];
158 if (nullptr != dy)
160 *dy = stats->dy[stats->np_c];
163 stats->np_c++;
165 if (outlier)
167 return estatsOK;
171 stats->np_c = 0;
173 return estatsNO_POINTS;
176 int gmx_stats_add_points(gmx_stats_t gstats, int n, real* x, real* y, real* dx, real* dy)
178 for (int i = 0; (i < n); i++)
180 int ok;
181 if ((ok = gmx_stats_add_point(gstats, x[i], y[i], (nullptr != dx) ? dx[i] : 0,
182 (nullptr != dy) ? dy[i] : 0))
183 != estatsOK)
185 return ok;
188 return estatsOK;
191 static int gmx_stats_compute(gmx_stats* stats, int weight)
193 double yy, yx, xx, sx, sy, dy, chi2, chi2aa, d2;
194 double ssxx, ssyy, ssxy;
195 double w, wtot, yx_nw, sy_nw, sx_nw, yy_nw, xx_nw, dx2, dy2;
197 int N = stats->np;
199 if (stats->computed == 0)
201 if (N < 1)
203 return estatsNO_POINTS;
206 xx = xx_nw = 0;
207 yy = yy_nw = 0;
208 yx = yx_nw = 0;
209 sx = sx_nw = 0;
210 sy = sy_nw = 0;
211 wtot = 0;
212 d2 = 0;
213 for (int i = 0; (i < N); i++)
215 d2 += gmx::square(stats->x[i] - stats->y[i]);
216 if (((stats->dy[i]) != 0.0) && (weight == elsqWEIGHT_Y))
218 w = 1 / gmx::square(stats->dy[i]);
220 else
222 w = 1;
225 wtot += w;
227 xx += w * gmx::square(stats->x[i]);
228 xx_nw += gmx::square(stats->x[i]);
230 yy += w * gmx::square(stats->y[i]);
231 yy_nw += gmx::square(stats->y[i]);
233 yx += w * stats->y[i] * stats->x[i];
234 yx_nw += stats->y[i] * stats->x[i];
236 sx += w * stats->x[i];
237 sx_nw += stats->x[i];
239 sy += w * stats->y[i];
240 sy_nw += stats->y[i];
243 /* Compute average, sigma and error */
244 stats->aver = sy_nw / N;
245 stats->sigma_aver = std::sqrt(yy_nw / N - gmx::square(sy_nw / N));
246 stats->error = stats->sigma_aver / std::sqrt(static_cast<double>(N));
248 /* Compute RMSD between x and y */
249 stats->rmsd = std::sqrt(d2 / N);
251 /* Correlation coefficient for data */
252 yx_nw /= N;
253 xx_nw /= N;
254 yy_nw /= N;
255 sx_nw /= N;
256 sy_nw /= N;
257 ssxx = N * (xx_nw - gmx::square(sx_nw));
258 ssyy = N * (yy_nw - gmx::square(sy_nw));
259 ssxy = N * (yx_nw - (sx_nw * sy_nw));
260 stats->Rdata = std::sqrt(gmx::square(ssxy) / (ssxx * ssyy));
262 /* Compute straight line through datapoints, either with intercept
263 zero (result in aa) or with intercept variable (results in a
264 and b) */
265 yx = yx / wtot;
266 xx = xx / wtot;
267 sx = sx / wtot;
268 sy = sy / wtot;
270 stats->aa = (yx / xx);
271 stats->a = (yx - sx * sy) / (xx - sx * sx);
272 stats->b = (sy) - (stats->a) * (sx);
274 /* Compute chi2, deviation from a line y = ax+b. Also compute
275 chi2aa which returns the deviation from a line y = ax. */
276 chi2 = 0;
277 chi2aa = 0;
278 for (int i = 0; (i < N); i++)
280 if (stats->dy[i] > 0)
282 dy = stats->dy[i];
284 else
286 dy = 1;
288 chi2aa += gmx::square((stats->y[i] - (stats->aa * stats->x[i])) / dy);
289 chi2 += gmx::square((stats->y[i] - (stats->a * stats->x[i] + stats->b)) / dy);
291 if (N > 2)
293 stats->chi2 = std::sqrt(chi2 / (N - 2));
294 stats->chi2aa = std::sqrt(chi2aa / (N - 2));
296 /* Look up equations! */
297 dx2 = (xx - sx * sx);
298 dy2 = (yy - sy * sy);
299 stats->sigma_a = std::sqrt(stats->chi2 / ((N - 2) * dx2));
300 stats->sigma_b = stats->sigma_a * std::sqrt(xx);
301 stats->Rfit = std::abs(ssxy) / std::sqrt(ssxx * ssyy);
302 stats->Rfitaa = stats->aa * std::sqrt(dx2 / dy2);
304 else
306 stats->chi2 = 0;
307 stats->chi2aa = 0;
308 stats->sigma_a = 0;
309 stats->sigma_b = 0;
310 stats->Rfit = 0;
311 stats->Rfitaa = 0;
314 stats->computed = 1;
317 return estatsOK;
320 int gmx_stats_get_ab(gmx_stats_t gstats, int weight, real* a, real* b, real* da, real* db, real* chi2, real* Rfit)
322 gmx_stats* stats = gstats;
323 int ok;
325 if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
327 return ok;
329 if (nullptr != a)
331 *a = stats->a;
333 if (nullptr != b)
335 *b = stats->b;
337 if (nullptr != da)
339 *da = stats->sigma_a;
341 if (nullptr != db)
343 *db = stats->sigma_b;
345 if (nullptr != chi2)
347 *chi2 = stats->chi2;
349 if (nullptr != Rfit)
351 *Rfit = stats->Rfit;
354 return estatsOK;
357 int gmx_stats_get_a(gmx_stats_t gstats, int weight, real* a, real* da, real* chi2, real* Rfit)
359 gmx_stats* stats = gstats;
360 int ok;
362 if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
364 return ok;
366 if (nullptr != a)
368 *a = stats->aa;
370 if (nullptr != da)
372 *da = stats->sigma_aa;
374 if (nullptr != chi2)
376 *chi2 = stats->chi2aa;
378 if (nullptr != Rfit)
380 *Rfit = stats->Rfitaa;
383 return estatsOK;
386 int gmx_stats_get_average(gmx_stats_t gstats, real* aver)
388 gmx_stats* stats = gstats;
389 int ok;
391 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
393 return ok;
396 *aver = stats->aver;
398 return estatsOK;
401 int gmx_stats_get_ase(gmx_stats_t gstats, real* aver, real* sigma, real* error)
403 gmx_stats* stats = gstats;
404 int ok;
406 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
408 return ok;
411 if (nullptr != aver)
413 *aver = stats->aver;
415 if (nullptr != sigma)
417 *sigma = stats->sigma_aver;
419 if (nullptr != error)
421 *error = stats->error;
424 return estatsOK;
427 int gmx_stats_get_sigma(gmx_stats_t gstats, real* sigma)
429 gmx_stats* stats = gstats;
430 int ok;
432 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
434 return ok;
437 *sigma = stats->sigma_aver;
439 return estatsOK;
442 int gmx_stats_get_error(gmx_stats_t gstats, real* error)
444 gmx_stats* stats = gstats;
445 int ok;
447 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
449 return ok;
452 *error = stats->error;
454 return estatsOK;
457 int gmx_stats_get_corr_coeff(gmx_stats_t gstats, real* R)
459 gmx_stats* stats = gstats;
460 int ok;
462 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
464 return ok;
467 *R = stats->Rdata;
469 return estatsOK;
472 int gmx_stats_get_rmsd(gmx_stats_t gstats, real* rmsd)
474 gmx_stats* stats = gstats;
475 int ok;
477 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
479 return ok;
482 *rmsd = stats->rmsd;
484 return estatsOK;
487 int gmx_stats_dump_xy(gmx_stats_t gstats, FILE* fp)
489 gmx_stats* stats = gstats;
491 for (int i = 0; (i < stats->np); i++)
493 fprintf(fp, "%12g %12g %12g %12g\n", stats->x[i], stats->y[i], stats->dx[i], stats->dy[i]);
496 return estatsOK;
499 int gmx_stats_remove_outliers(gmx_stats_t gstats, double level)
501 gmx_stats* stats = gstats;
502 int iter = 1, done = 0, ok;
503 real rmsd, r;
505 while ((stats->np >= 10) && !done)
507 if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
509 return ok;
511 done = 1;
512 for (int i = 0; (i < stats->np);)
514 r = std::abs(stats->x[i] - stats->y[i]);
515 if (r > level * rmsd)
517 fprintf(stderr, "Removing outlier, iter = %d, rmsd = %g, x = %g, y = %g\n", iter,
518 rmsd, stats->x[i], stats->y[i]);
519 if (i < stats->np - 1)
521 stats->x[i] = stats->x[stats->np - 1];
522 stats->y[i] = stats->y[stats->np - 1];
523 stats->dx[i] = stats->dx[stats->np - 1];
524 stats->dy[i] = stats->dy[stats->np - 1];
526 stats->np--;
527 done = 0;
529 else
531 i++;
534 iter++;
537 return estatsOK;
540 int gmx_stats_make_histogram(gmx_stats_t gstats, real binwidth, int* nb, int ehisto, int normalized, real** x, real** y)
542 gmx_stats* stats = gstats;
543 int index = 0, nbins = *nb, *nindex;
544 double minx, maxx, maxy, miny, delta, dd, minh;
546 if (((binwidth <= 0) && (nbins <= 0)) || ((binwidth > 0) && (nbins > 0)))
548 return estatsINVALID_INPUT;
550 if (stats->np <= 2)
552 return estatsNO_POINTS;
554 minx = maxx = stats->x[0];
555 miny = maxy = stats->y[0];
556 for (int i = 1; (i < stats->np); i++)
558 miny = (stats->y[i] < miny) ? stats->y[i] : miny;
559 maxy = (stats->y[i] > maxy) ? stats->y[i] : maxy;
560 minx = (stats->x[i] < minx) ? stats->x[i] : minx;
561 maxx = (stats->x[i] > maxx) ? stats->x[i] : maxx;
563 if (ehisto == ehistoX)
565 delta = maxx - minx;
566 minh = minx;
568 else if (ehisto == ehistoY)
570 delta = maxy - miny;
571 minh = miny;
573 else
575 return estatsINVALID_INPUT;
578 if (binwidth == 0)
580 binwidth = (delta) / nbins;
582 else
584 nbins = gmx_dnint((delta) / binwidth + 0.5);
586 snew(*x, nbins);
587 snew(nindex, nbins);
588 for (int i = 0; (i < nbins); i++)
590 (*x)[i] = minh + binwidth * (i + 0.5);
592 if (normalized == 0)
594 dd = 1;
596 else
598 dd = 1.0 / (binwidth * stats->np);
601 snew(*y, nbins);
602 for (int i = 0; (i < stats->np); i++)
604 if (ehisto == ehistoY)
606 index = static_cast<int>((stats->y[i] - miny) / binwidth);
608 else if (ehisto == ehistoX)
610 index = static_cast<int>((stats->x[i] - minx) / binwidth);
612 if (index < 0)
614 index = 0;
616 if (index > nbins - 1)
618 index = nbins - 1;
620 (*y)[index] += dd;
621 nindex[index]++;
623 if (*nb == 0)
625 *nb = nbins;
627 for (int i = 0; (i < nbins); i++)
629 if (nindex[i] > 0)
631 (*y)[i] /= nindex[i];
635 sfree(nindex);
637 return estatsOK;
640 static const char* stats_error[estatsNR] = { "All well in STATS land", "No points",
641 "Not enough memory", "Invalid histogram input",
642 "Unknown error", "Not implemented yet" };
644 const char* gmx_stats_message(int estats)
646 if ((estats >= 0) && (estats < estatsNR))
648 return stats_error[estats];
650 else
652 return stats_error[estatsERROR];
656 /* Old convenience functions, should be merged with the core
657 statistics above. */
658 int lsq_y_ax(int n, real x[], real y[], real* a)
660 gmx_stats_t lsq = gmx_stats_init();
661 int ok;
662 real da, chi2, Rfit;
664 gmx_stats_add_points(lsq, n, x, y, nullptr, nullptr);
665 ok = gmx_stats_get_a(lsq, elsqWEIGHT_NONE, a, &da, &chi2, &Rfit);
666 gmx_stats_free(lsq);
668 return ok;
671 static int low_lsq_y_ax_b(int n, const real* xr, const double* xd, real yr[], real* a, real* b, real* r, real* chi2)
673 gmx_stats_t lsq = gmx_stats_init();
674 int ok;
676 for (int i = 0; (i < n); i++)
678 double pt;
680 if (xd != nullptr)
682 pt = xd[i];
684 else if (xr != nullptr)
686 pt = xr[i];
688 else
690 gmx_incons("Either xd or xr has to be non-NULL in low_lsq_y_ax_b()");
693 if ((ok = gmx_stats_add_point(lsq, pt, yr[i], 0, 0)) != estatsOK)
695 gmx_stats_free(lsq);
696 return ok;
699 ok = gmx_stats_get_ab(lsq, elsqWEIGHT_NONE, a, b, nullptr, nullptr, chi2, r);
700 gmx_stats_free(lsq);
702 return ok;
705 int lsq_y_ax_b(int n, real x[], real y[], real* a, real* b, real* r, real* chi2)
707 return low_lsq_y_ax_b(n, x, nullptr, y, a, b, r, chi2);
710 int lsq_y_ax_b_xdouble(int n, double x[], real y[], real* a, real* b, real* r, real* chi2)
712 return low_lsq_y_ax_b(n, nullptr, x, y, a, b, r, chi2);
715 int lsq_y_ax_b_error(int n, real x[], real y[], real dy[], real* a, real* b, real* da, real* db, real* r, real* chi2)
717 gmx_stats_t lsq = gmx_stats_init();
718 int ok;
720 for (int i = 0; (i < n); i++)
722 ok = gmx_stats_add_point(lsq, x[i], y[i], 0, dy[i]);
723 if (ok != estatsOK)
725 gmx_stats_free(lsq);
726 return ok;
729 ok = gmx_stats_get_ab(lsq, elsqWEIGHT_Y, a, b, da, db, chi2, r);
730 gmx_stats_free(lsq);
732 return ok;