New classes for spline interpolation tables
[gromacs/AngularHB.git] / src / gromacs / tables / tests / splinetable.cpp
blob44aaf2b42a35ab8fe0d8f2e489580cea4be67439
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2015,2016, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
35 /*! \internal \file
36 * \brief
37 * Tests for the quadratic and cubic spline interpolation tables.
39 * \author Erik Lindahl <erik.lindahl@gmail.com>
40 * \ingroup module_tables
42 #include "gmxpre.h"
44 #include <cmath>
46 #include <algorithm>
47 #include <functional>
48 #include <utility>
50 #include <gtest/gtest.h>
52 #include "gromacs/math/utilities.h"
53 #include "gromacs/options/basicoptions.h"
54 #include "gromacs/options/ioptionscontainer.h"
55 #include "gromacs/simd/simd.h"
56 #include "gromacs/tables/cubicsplinetable.h"
57 #include "gromacs/tables/quadraticsplinetable.h"
59 #include "testutils/testasserts.h"
60 #include "testutils/testoptions.h"
63 namespace gmx
66 namespace test
69 namespace
72 class SplineTableTestBase : public ::testing::Test
74 public:
75 static int s_testPoints_; //!< Number of points to use. Public so we can set it as option
78 int
79 SplineTableTestBase::s_testPoints_ = 100;
81 /*! \cond */
82 /*! \brief Command-line option to adjust the number of points used to test SIMD math functions. */
83 GMX_TEST_OPTIONS(SplineTableTestOptions, options)
85 options->addOption(::gmx::IntegerOption("npoints")
86 .store(&SplineTableTestBase::s_testPoints_)
87 .description("Number of points to test for spline table functions"));
89 /*! \endcond */
94 /*! \brief Test fixture for table comparision with analytical/numerical functions */
95 template <typename T>
96 class SplineTableTest : public SplineTableTestBase
98 public:
99 SplineTableTest() : tolerance_(T::defaultTolerance) {}
101 /*! \brief Set a new tolerance to be used in table function comparison
103 * \param tol New tolerance to use
105 void
106 setTolerance(real tol) { tolerance_ = tol; }
108 //! \cond internal
109 /*! \internal \brief
110 * Assertion predicate formatter for comparing table with function/derivative
112 template<int numFuncInTable = 1, int funcIndex = 0>
113 void
114 testSplineTableAgainstFunctions(const std::string &desc,
115 const std::function<double(double)> &refFunc,
116 const std::function<double(double)> &refDer,
117 const T &table,
118 const std::pair<real, real> &testRange);
119 //! \endcond
121 private:
122 real tolerance_; //!< Tolerance to use
125 template <class T>
126 template<int numFuncInTable, int funcIndex>
127 void
128 SplineTableTest<T>::testSplineTableAgainstFunctions(const std::string &desc,
129 const std::function<double(double)> &refFunc,
130 const std::function<double(double)> &refDer,
131 const T &table,
132 const std::pair<real, real> &testRange)
134 real dx = (testRange.second - testRange.first) / s_testPoints_;
136 FloatingPointTolerance funcTolerance(relativeToleranceAsFloatingPoint(0.0, tolerance_));
138 for (real x = testRange.first; x < testRange.second; x += dx)
140 real h = std::sqrt(GMX_REAL_EPS);
141 real secondDerivative = (refDer(x+h)-refDer(x))/h;
143 real testFuncValue;
144 real testDerValue;
146 table.template evaluateFunctionAndDerivative<numFuncInTable, funcIndex>(x, &testFuncValue, &testDerValue);
148 // Check that we get the same values from function/derivative-only methods
149 real tmpFunc, tmpDer;
151 table.template evaluateFunction<numFuncInTable, funcIndex>(x, &tmpFunc);
153 table.template evaluateDerivative<numFuncInTable, funcIndex>(x, &tmpDer);
155 if (testFuncValue != tmpFunc)
157 ADD_FAILURE()
158 << "Interpolation inconsistency for table " << desc << std::endl
159 << numFuncInTable << " function(s) in table, testing index " << funcIndex << std::endl
160 << "First failure at x = " << x << std::endl
161 << "Function value when evaluating function & derivative: " << testFuncValue << std::endl
162 << "Function value when evaluating only function: " << tmpFunc << std::endl;
163 return;
165 if (testDerValue != tmpDer)
167 ADD_FAILURE()
168 << "Interpolation inconsistency for table " << desc << std::endl
169 << numFuncInTable << " function(s) in table, testing index " << funcIndex << std::endl
170 << "First failure at x = " << x << std::endl
171 << "Derivative value when evaluating function & derivative: " << testDerValue << std::endl
172 << "Derivative value when evaluating only derivative: " << tmpDer << std::endl;
173 return;
176 // There are two sources of errors that we need to account for when checking the values,
177 // and we only fail the test if both of these tolerances are violated:
179 // 1) First, we have the normal relative error of the test vs. reference value. For this
180 // we use the normal testutils relative tolerance checking.
181 // However, there is an additional source of error: When we calculate the forces we
182 // use average higher derivatives over the interval to improve accuracy, but this
183 // also means we won't reproduce values at table points exactly. This is usually not
184 // an issue since the tolerances we have are much larger, but when the reference derivative
185 // value is exactly zero the relative error will be infinite. To account for this, we
186 // extract the spacing from the table and evaluate the reference derivative at a point
187 // this much larger too, and use the largest of the two values as the reference
188 // magnitude for the derivative when setting the relative tolerance.
189 // Note that according to the table function definitions, we should be allowed to evaluate
190 // it one table point beyond the range (this is done already for construction).
192 // 2) Second, due to the loss-of-accuracy when calculating the index through rtable
193 // there is an internal absolute tolerance that we can calculate.
194 // The source of this error is the subtraction eps=rtab-[rtab], which leaves an
195 // error proportional to eps_machine*rtab=eps_machine*x*tableScale.
196 // To lowest order, the term in the function and derivative values (respectively) that
197 // are proportional to eps will be the next-higher derivative multiplied by the spacing.
198 // This means the truncation error in the value is derivative*x*eps_machine, and in the
199 // derivative the error is 2nd_derivative*x*eps_machine.
201 real refFuncValue = refFunc(x);
202 real refDerValue = refDer(x);
203 real nextRefDerValue = refDer(x + table.tableSpacing());
205 real derMagnitude = std::max( std::abs(refDerValue), std::abs(nextRefDerValue));
207 // Since the reference magnitude will change over each interval we need to re-evaluate
208 // the derivative tolerance inside the loop.
209 FloatingPointTolerance derTolerance(relativeToleranceAsFloatingPoint(derMagnitude, tolerance_));
211 FloatingPointDifference funcDiff(refFuncValue, testFuncValue);
212 FloatingPointDifference derDiff(refDerValue, testDerValue);
214 real allowedAbsFuncErr = std::abs(refDerValue) * x * GMX_REAL_EPS;
215 real allowedAbsDerErr = std::abs(secondDerivative) * x * GMX_REAL_EPS;
217 if ((!funcTolerance.isWithin(funcDiff) && funcDiff.asAbsolute() > allowedAbsFuncErr) ||
218 (!derTolerance.isWithin(derDiff) && derDiff.asAbsolute() > allowedAbsDerErr))
220 ADD_FAILURE()
221 << "Failing comparison with function for table " << desc << std::endl
222 << numFuncInTable << " function(s) in table, testing index " << funcIndex << std::endl
223 << "Test range is ( " << testRange.first << " , " << testRange.second << " ) " << std::endl
224 << "Tolerance = " << tolerance_ << std::endl
225 << "First failure at x = " << x << std::endl
226 << "Reference function = " << refFuncValue << std::endl
227 << "Test table function = " << testFuncValue << std::endl
228 << "Reference derivative = " << refDerValue << std::endl
229 << "Test table derivative = " << testDerValue << std::endl;
230 return;
/*! \brief Reference function with the shape of Coulomb electrostatics
 *
 * \param r argument
 * \return 1/r
 */
double
coulombFunction(double r)
{
    const double inverseR = 1.0/r;
    return inverseR;
}
/*! \brief First derivative (not the force) of coulombFunction()
 *
 * \param r argument
 * \return -1/r^2
 */
double
coulombDerivative(double r)
{
    const double rSquared = r*r;
    return -1.0/rSquared;
}
/*! \brief Reference function with the shape of the r^-6 Lennard-Jones dispersion term
 *
 * \param r argument
 * \return r^-6
 */
double
lj6Function(double r)
{
    const double exponent = -6.0;
    return std::pow(r, exponent);
}
/*! \brief First derivative (not the force) of lj6Function()
 *
 * \param r argument
 * \return -6*r^-7
 */
double
lj6Derivative(double r)
{
    const double prefactor = -6.0;
    return prefactor*std::pow(r, -7.0);
}
/*! \brief Reference function with the shape of the r^-12 Lennard-Jones repulsion term
 *
 * \param r argument
 * \return r^-12
 */
double
lj12Function(double r)
{
    const double exponent = -12.0;
    return std::pow(r, exponent);
}
/*! \brief First derivative (not the force) of lj12Function()
 *
 * \param r argument
 * \return -12*r^-13
 */
double
lj12Derivative(double r)
{
    const double prefactor = -12.0;
    return prefactor*std::pow(r, -13.0);
}
/*! \brief The (unnormalized) sinc function
 *
 * \param r argument (not defined for r == 0)
 * \return sin(r)/r
 */
double
sincFunction(double r)
{
    const double numerator = std::sin(r);
    return numerator/r;
}
/*! \brief First derivative of sincFunction()
 *
 * \param r argument (not defined for r == 0)
 * \return (r*cos(r) - sin(r))/r^2
 */
double
sincDerivative(double r)
{
    const double numerator   = r*std::cos(r) - std::sin(r);
    const double denominator = r*r;
    return numerator/denominator;
}
/*! \brief Function for the direct-space PME correction to 1/r
 *
 * At r == 0 the limiting value 2/sqrt(pi) is returned instead of evaluating 0/0.
 *
 * \param r argument
 * \return erf(r)/r
 */
double
pmeCorrFunction(double r)
{
    return (r == 0) ? 2.0/std::sqrt(M_PI) : std::erf(r)/r;
}
/*! \brief Derivative of the direct-space PME correction to 1/r
 *
 * Derivative of erf(r)/r, i.e. (2/sqrt(pi)*exp(-r^2)*r - erf(r))/r^2.
 * At r == 0 the limiting value 0 is returned instead of evaluating 0/0.
 *
 * \param r argument
 * \return Derivative of the PME correction function.
 */
double
pmeCorrDerivative(double r)
{
    if (r == 0)
    {
        return 0;
    }
    else
    {
        // Use std::erf (the unqualified C erf is not guaranteed to be declared by
        // <cmath>) and M_PI, consistent with pmeCorrFunction().
        return (2.0*std::exp(-r*r)/std::sqrt(M_PI)*r-std::erf(r))/(r*r);
    }
}
360 /*! \brief Typed-test list. We test QuadraticSplineTable and CubicSplineTable
362 typedef ::testing::Types<QuadraticSplineTable, CubicSplineTable> SplineTableTypes;
363 TYPED_TEST_CASE(SplineTableTest, SplineTableTypes);
// Constructing a table from analytical functions with invalid arguments must
// throw, and the exception type identifies which kind of problem was detected.
TYPED_TEST(SplineTableTest, HandlesIncorrectInput)
{
    // negative range
    EXPECT_THROW_GMX(TypeParam( {{"LJ12", lj12Function, lj12Derivative}}, {-1.0, 0.0}), gmx::InvalidInputError);

    // Too small range
    EXPECT_THROW_GMX(TypeParam( {{"LJ12", lj12Function, lj12Derivative}}, {1.0, 1.00001}), gmx::InvalidInputError);

    // bad tolerance: an explicitly requested tolerance of 1e-20 is rejected
    EXPECT_THROW_GMX(TypeParam( {{"LJ12", lj12Function, lj12Derivative}}, {1.0, 2.0}, 1e-20), gmx::ToleranceError);

    // Range is so close to 0.0 that table would require >1e6 points
    EXPECT_THROW_GMX(TypeParam( {{"LJ12", lj12Function, lj12Derivative}}, {1e-4, 2.0}), gmx::ToleranceError);

    // mismatching function/derivative: function and derivative are passed in swapped order
    EXPECT_THROW_GMX(TypeParam( { {"BadLJ12", lj12Derivative, lj12Function}}, {1.0, 2.0}), gmx::InconsistentInputError);
}
#ifndef NDEBUG
// Evaluating a table outside its valid argument range must throw gmx::RangeError.
// This test is only built without NDEBUG; presumably the tables only range-check
// their arguments in debug builds (NOTE(review): confirm in the table implementation).
TYPED_TEST(SplineTableTest, CatchesOutOfRangeValues)
{
    TypeParam table( {{"LJ12", lj12Function, lj12Derivative}}, {0.2, 1.0});
    real x, func, der;

    // Argument just below zero
    x = -GMX_REAL_EPS;
    EXPECT_THROW_GMX(table.evaluateFunctionAndDerivative(x, &func, &der), gmx::RangeError);

    // Argument exactly at the upper end of the range passed to the constructor
    x = 1.0;
    EXPECT_THROW_GMX(table.evaluateFunctionAndDerivative(x, &func, &der), gmx::RangeError);
}
#endif
400 TYPED_TEST(SplineTableTest, Sinc)
402 std::pair<real, real> range(0.1, 10);
404 TypeParam sincTable( {{"Sinc", sincFunction, sincDerivative}}, range);
406 TestFixture::testSplineTableAgainstFunctions("Sinc", sincFunction, sincDerivative, sincTable, range);
410 TYPED_TEST(SplineTableTest, LJ12)
412 std::pair<real, real> range(0.2, 2.0);
414 TypeParam lj12Table( {{"LJ12", lj12Function, lj12Derivative}}, range);
416 TestFixture::testSplineTableAgainstFunctions("LJ12", lj12Function, lj12Derivative, lj12Table, range);
420 TYPED_TEST(SplineTableTest, PmeCorrection)
422 std::pair<real, real> range(0.0, 4.0);
423 real tolerance = 1e-5;
425 TypeParam pmeCorrTable( {{"PMECorr", pmeCorrFunction, pmeCorrDerivative}}, range, tolerance);
427 TestFixture::setTolerance(tolerance);
428 TestFixture::testSplineTableAgainstFunctions("PMECorr", pmeCorrFunction, pmeCorrDerivative, pmeCorrTable, range);
433 TYPED_TEST(SplineTableTest, HandlesIncorrectNumericalInput)
435 // Lengths do not match
436 std::vector<double> functionValues(10);
437 std::vector<double> derivativeValues(20);
438 EXPECT_THROW_GMX(TypeParam( {{"EmptyVectors", functionValues, derivativeValues, 0.001}},
439 {1.0, 2.0}), gmx::InconsistentInputError);
441 // Upper range is 2.0, spacing 0.1. This requires at least 21 points. Make sure we get an error for 20.
442 functionValues.resize(20);
443 derivativeValues.resize(20);
444 EXPECT_THROW_GMX(TypeParam( {{"EmptyVectors", functionValues, derivativeValues, 0.1}},
445 {1.0, 2.0}), gmx::InconsistentInputError);
447 // Create some test data
448 functionValues.clear();
449 derivativeValues.clear();
451 std::vector<double> badDerivativeValues;
452 double spacing = 1e-3;
454 for (std::size_t i = 0; i < 1001; i++)
456 double x = i * spacing;
457 double func = (x >= 0.1) ? lj12Function(x) : 0.0;
458 double der = (x >= 0.1) ? lj12Derivative(x) : 0.0;
460 functionValues.push_back(func);
461 derivativeValues.push_back(der);
462 badDerivativeValues.push_back(-der);
465 // Derivatives not consistent with function
466 EXPECT_THROW_GMX(TypeParam( {{"NumericalBadLJ12", functionValues, badDerivativeValues, spacing}},
467 {0.2, 1.0}), gmx::InconsistentInputError);
469 // Spacing 1e-3 is not sufficient for r^-12 in range [0.1,1.0]
470 // Make sure we get a tolerance error
471 EXPECT_THROW_GMX(TypeParam( {{"NumericalLJ12", functionValues, derivativeValues, spacing}},
472 {0.2, 1.0}), gmx::ToleranceError);
476 TYPED_TEST(SplineTableTest, NumericalInputPmeCorr)
478 std::pair<real, real> range(0.0, 4.0);
479 std::vector<double> functionValues;
480 std::vector<double> derivativeValues;
482 double inputSpacing = 1e-3;
483 real tolerance = 1e-5;
485 // We only need data up to the argument 4.0, but add 1% margin
486 for (std::size_t i = 0; i < range.second*1.01/inputSpacing; i++)
488 double x = i * inputSpacing;
490 functionValues.push_back(pmeCorrFunction(x));
491 derivativeValues.push_back(pmeCorrDerivative(x));
494 TypeParam pmeCorrTable( {{"NumericalPMECorr", functionValues, derivativeValues, inputSpacing}},
495 range, tolerance);
497 TestFixture::setTolerance(tolerance);
498 TestFixture::testSplineTableAgainstFunctions("NumericalPMECorr", pmeCorrFunction, pmeCorrDerivative, pmeCorrTable, range);
501 TYPED_TEST(SplineTableTest, TwoFunctions)
503 std::pair<real, real> range(0.2, 2.0);
505 TypeParam table( {{"LJ6", lj6Function, lj6Derivative}, {"LJ12", lj12Function, lj12Derivative}}, range);
507 // Test entire range for each function. This will use the method that interpolates a single function
508 TestFixture::template testSplineTableAgainstFunctions<2, 0>("LJ6", lj6Function, lj6Derivative, table, range);
509 TestFixture::template testSplineTableAgainstFunctions<2, 1>("LJ12", lj12Function, lj12Derivative, table, range);
511 // Test the methods that evaluated both functions for one value
512 real x = 0.5 * (range.first + range.second);
513 real refFunc0 = lj6Function(x);
514 real refDer0 = lj6Derivative(x);
515 real refFunc1 = lj12Function(x);
516 real refDer1 = lj12Derivative(x);
518 real tstFunc0, tstDer0, tstFunc1, tstDer1;
519 real tmpFunc0, tmpFunc1, tmpDer0, tmpDer1;
521 // test that we reproduce the reference functions
522 table.evaluateFunctionAndDerivative(x, &tstFunc0, &tstDer0, &tstFunc1, &tstDer1);
524 real funcErr0 = std::abs(tstFunc0-refFunc0) / std::abs(refFunc0);
525 real funcErr1 = std::abs(tstFunc1-refFunc1) / std::abs(refFunc1);
526 real derErr0 = std::abs(tstDer0-refDer0) / std::abs(refDer0);
527 real derErr1 = std::abs(tstDer1-refDer1) / std::abs(refDer1);
529 // Use asserts, since the following ones compare to these values.
530 ASSERT_LT(funcErr0, TypeParam::defaultTolerance);
531 ASSERT_LT(derErr0, TypeParam::defaultTolerance);
532 ASSERT_LT(funcErr1, TypeParam::defaultTolerance);
533 ASSERT_LT(derErr1, TypeParam::defaultTolerance);
535 // Test that function/derivative-only interpolation methods work
536 table.evaluateFunction(x, &tmpFunc0, &tmpFunc1);
537 table.evaluateDerivative(x, &tmpDer0, &tmpDer1);
538 EXPECT_EQ(tstFunc0, tmpFunc0);
539 EXPECT_EQ(tstFunc1, tmpFunc1);
540 EXPECT_EQ(tstDer0, tmpDer0);
542 // Test that scrambled order interpolation methods work
543 table.template evaluateFunctionAndDerivative<2, 1, 0>(x, &tstFunc1, &tstDer1, &tstFunc0, &tstDer0);
544 EXPECT_EQ(tstFunc0, tmpFunc0);
545 EXPECT_EQ(tstFunc1, tmpFunc1);
546 EXPECT_EQ(tstDer0, tmpDer0);
547 EXPECT_EQ(tstDer1, tmpDer1);
549 // Test scrambled order for function/derivative-only methods
550 table.template evaluateFunction<2, 1, 0>(x, &tmpFunc1, &tmpFunc0);
551 table.template evaluateDerivative<2, 1, 0>(x, &tmpDer1, &tmpDer0);
552 EXPECT_EQ(tstFunc0, tmpFunc0);
553 EXPECT_EQ(tstFunc1, tmpFunc1);
554 EXPECT_EQ(tstDer0, tmpDer0);
555 EXPECT_EQ(tstDer1, tmpDer1);
558 TYPED_TEST(SplineTableTest, ThreeFunctions)
560 std::pair<real, real> range(0.2, 2.0);
562 TypeParam table( {{"Coulomb", coulombFunction, coulombDerivative}, {"LJ6", lj6Function, lj6Derivative}, {"LJ12", lj12Function, lj12Derivative}}, range);
564 // Test entire range for each function
565 TestFixture::template testSplineTableAgainstFunctions<3, 0>("Coulomb", coulombFunction, coulombDerivative, table, range);
566 TestFixture::template testSplineTableAgainstFunctions<3, 1>("LJ6", lj6Function, lj6Derivative, table, range);
567 TestFixture::template testSplineTableAgainstFunctions<3, 2>("LJ12", lj12Function, lj12Derivative, table, range);
569 // Test the methods that evaluated both functions for one value
570 real x = 0.5 * (range.first + range.second);
571 real refFunc0 = coulombFunction(x);
572 real refDer0 = coulombDerivative(x);
573 real refFunc1 = lj6Function(x);
574 real refDer1 = lj6Derivative(x);
575 real refFunc2 = lj12Function(x);
576 real refDer2 = lj12Derivative(x);
578 real tstFunc0, tstDer0, tstFunc1, tstDer1, tstFunc2, tstDer2;
579 real tmpFunc0, tmpFunc1, tmpFunc2, tmpDer0, tmpDer1, tmpDer2;
581 // test that we reproduce the reference functions
582 table.evaluateFunctionAndDerivative(x, &tstFunc0, &tstDer0, &tstFunc1, &tstDer1, &tstFunc2, &tstDer2);
584 real funcErr0 = std::abs(tstFunc0-refFunc0) / std::abs(refFunc0);
585 real derErr0 = std::abs(tstDer0-refDer0) / std::abs(refDer0);
586 real funcErr1 = std::abs(tstFunc1-refFunc1) / std::abs(refFunc1);
587 real derErr1 = std::abs(tstDer1-refDer1) / std::abs(refDer1);
588 real funcErr2 = std::abs(tstFunc2-refFunc2) / std::abs(refFunc2);
589 real derErr2 = std::abs(tstDer2-refDer2) / std::abs(refDer2);
591 // Use asserts, since the following ones compare to these values.
592 ASSERT_LT(funcErr0, TypeParam::defaultTolerance);
593 ASSERT_LT(derErr0, TypeParam::defaultTolerance);
594 ASSERT_LT(funcErr1, TypeParam::defaultTolerance);
595 ASSERT_LT(derErr1, TypeParam::defaultTolerance);
596 ASSERT_LT(funcErr2, TypeParam::defaultTolerance);
597 ASSERT_LT(derErr2, TypeParam::defaultTolerance);
599 // Test that function/derivative-only interpolation methods work
600 table.evaluateFunction(x, &tmpFunc0, &tmpFunc1, &tmpFunc2);
601 table.evaluateDerivative(x, &tmpDer0, &tmpDer1, &tmpDer2);
602 EXPECT_EQ(tstFunc0, tmpFunc0);
603 EXPECT_EQ(tstFunc1, tmpFunc1);
604 EXPECT_EQ(tstFunc2, tmpFunc2);
605 EXPECT_EQ(tstDer0, tmpDer0);
606 EXPECT_EQ(tstDer1, tmpDer1);
607 EXPECT_EQ(tstDer2, tmpDer2);
609 // Test two functions out of three
610 table.template evaluateFunctionAndDerivative<3, 0, 1>(x, &tmpFunc0, &tmpDer0, &tmpFunc1, &tmpDer1);
611 EXPECT_EQ(tstFunc0, tmpFunc0);
612 EXPECT_EQ(tstFunc1, tmpFunc1);
613 EXPECT_EQ(tstDer0, tmpDer0);
614 EXPECT_EQ(tstDer1, tmpDer1);
616 // two out of three, function/derivative-only
617 table.template evaluateFunction<3, 0, 1>(x, &tmpFunc0, &tmpFunc1);
618 table.template evaluateDerivative<3, 0, 1>(x, &tmpDer0, &tmpDer1);
619 EXPECT_EQ(tstFunc0, tmpFunc0);
620 EXPECT_EQ(tstFunc1, tmpFunc1);
621 EXPECT_EQ(tstDer0, tmpDer0);
622 EXPECT_EQ(tstDer1, tmpDer1);
624 // Test that scrambled order interpolation methods work
625 table.template evaluateFunctionAndDerivative<3, 2, 1, 0>(x, &tstFunc2, &tstDer2, &tstFunc1, &tstDer1, &tstFunc0, &tstDer0);
626 EXPECT_EQ(tstFunc0, tmpFunc0);
627 EXPECT_EQ(tstFunc1, tmpFunc1);
628 EXPECT_EQ(tstFunc2, tmpFunc2);
629 EXPECT_EQ(tstDer0, tmpDer0);
630 EXPECT_EQ(tstDer1, tmpDer1);
631 EXPECT_EQ(tstDer2, tmpDer2);
633 // Test scrambled order for function/derivative-only methods
634 table.template evaluateFunction<3, 2, 1, 0>(x, &tmpFunc2, &tmpFunc1, &tmpFunc0);
635 table.template evaluateDerivative<3, 2, 1, 0>(x, &tmpDer2, &tmpDer1, &tmpDer0);
636 EXPECT_EQ(tstFunc0, tmpFunc0);
637 EXPECT_EQ(tstFunc1, tmpFunc1);
638 EXPECT_EQ(tstFunc2, tmpFunc2);
639 EXPECT_EQ(tstDer0, tmpDer0);
640 EXPECT_EQ(tstDer1, tmpDer1);
641 EXPECT_EQ(tstDer2, tmpDer2);
#if GMX_SIMD_HAVE_REAL
// SIMD evaluation of a single-function table. Lane handling is covered by the
// SIMD module tests, so only the first lane of a broadcast midpoint value is
// checked against the reference.
TYPED_TEST(SplineTableTest, Simd)
{
    const std::pair<real, real>            range(0.2, 1.0);
    TypeParam                              table( {{"LJ12", lj12Function, lj12Derivative}}, range);

    // Relative deviation of a table value from its reference value
    auto                                   relativeError = [](real value, real reference)
        {
            return std::abs(value - reference) / std::abs(reference);
        };

    const real                             x = 0.5 * (range.first + range.second);
    SimdReal                               tstFunc, tstDer;
    table.evaluateFunctionAndDerivative(SimdReal(x), &tstFunc, &tstDer);

    GMX_ALIGNED(real, GMX_SIMD_REAL_WIDTH) alignedMem[GMX_SIMD_REAL_WIDTH];

    store(alignedMem, tstFunc);
    const real funcErr = relativeError(alignedMem[0], lj12Function(x));

    store(alignedMem, tstDer);
    const real derErr = relativeError(alignedMem[0], lj12Derivative(x));

    EXPECT_LT(funcErr, TypeParam::defaultTolerance);
    EXPECT_LT(derErr, TypeParam::defaultTolerance);
}

// SIMD evaluation of a two-function table, checked in the same way as above.
TYPED_TEST(SplineTableTest, SimdTwoFunctions)
{
    const std::pair<real, real>            range(0.2, 2.0);
    TypeParam                              table( {{"LJ6", lj6Function, lj6Derivative}, {"LJ12", lj12Function, lj12Derivative}}, range);

    // Relative deviation of a table value from its reference value
    auto                                   relativeError = [](real value, real reference)
        {
            return std::abs(value - reference) / std::abs(reference);
        };

    const real                             x = 0.5 * (range.first + range.second);
    SimdReal                               tstFunc0, tstDer0;
    SimdReal                               tstFunc1, tstDer1;
    table.evaluateFunctionAndDerivative(SimdReal(x), &tstFunc0, &tstDer0, &tstFunc1, &tstDer1);

    GMX_ALIGNED(real, GMX_SIMD_REAL_WIDTH) alignedMem[GMX_SIMD_REAL_WIDTH];

    store(alignedMem, tstFunc0);
    const real funcErr0 = relativeError(alignedMem[0], lj6Function(x));

    store(alignedMem, tstDer0);
    const real derErr0 = relativeError(alignedMem[0], lj6Derivative(x));

    store(alignedMem, tstFunc1);
    const real funcErr1 = relativeError(alignedMem[0], lj12Function(x));

    store(alignedMem, tstDer1);
    const real derErr1 = relativeError(alignedMem[0], lj12Derivative(x));

    EXPECT_LT(funcErr0, TypeParam::defaultTolerance);
    EXPECT_LT(derErr0, TypeParam::defaultTolerance);
    EXPECT_LT(funcErr1, TypeParam::defaultTolerance);
    EXPECT_LT(derErr1, TypeParam::defaultTolerance);
}
#endif
#if GMX_SIMD_HAVE_REAL && !defined NDEBUG
// A SIMD argument with a single out-of-range lane must throw gmx::RangeError.
// This covers both below zero and at the upper end of the range.
TYPED_TEST(SplineTableTest, CatchesOutOfRangeValuesSimd)
{
    std::pair<real, real>                  range(0.2, 1.0);
    TypeParam                              table( {{"LJ12", lj12Function, lj12Derivative}}, range);
    SimdReal                               x, func, der;

    GMX_ALIGNED(real, GMX_SIMD_REAL_WIDTH) alignedMem[GMX_SIMD_REAL_WIDTH];

    // Start from an all-valid vector, so that the lane set below is the only
    // out-of-range one. Uninitialized lanes would be indeterminate and could
    // themselves trigger the error, hiding a missed check of the poisoned lane.
    for (int i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
    {
        alignedMem[i] = range.first;
    }
    // Make position 1 incorrect if width>=2, otherwise position 0
    alignedMem[ (GMX_SIMD_REAL_WIDTH >= 2) ? 1 : 0] = -GMX_REAL_EPS;
    x = load(alignedMem);

    EXPECT_THROW_GMX(table.evaluateFunctionAndDerivative(x, &func, &der), gmx::RangeError);

    // All lanes just below the upper limit, which are valid
    for (int i = 0; i < GMX_SIMD_REAL_WIDTH; i++)
    {
        alignedMem[i] = range.second*(1.0-GMX_REAL_EPS);
    }
    // Make position 1 incorrect if width>=2, otherwise position 0
    alignedMem[ (GMX_SIMD_REAL_WIDTH >= 2) ? 1 : 0] = range.second;
    x = load(alignedMem);

    EXPECT_THROW_GMX(table.evaluateFunctionAndDerivative(x, &func, &der), gmx::RangeError);
}
#endif
742 } // namespace
744 } // namespace test
746 } // namespace gmx