/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
37 #include "gromacs/mdtypes/checkpointdata.h"
42 #include <gmock/gmock.h>
43 #include <gtest/gtest.h>
45 #include "gromacs/fileio/gmxfio.h"
46 #include "gromacs/fileio/gmxfio_xdr.h"
47 #include "gromacs/utility/fatalerror.h"
48 #include "gromacs/utility/inmemoryserializer.h"
50 #include "testutils/testfilemanager.h"
58 * \ingroup module_modularsimulator
59 * \brief Struct allowing to check if type is vector of serializable data
63 struct IsVectorOfSerializableType
65 static bool const value
= false;
68 struct IsVectorOfSerializableType
<std::vector
<T
>>
70 static bool const value
= IsSerializableType
<T
>::value
;
75 * \brief Unified looping over test data
77 * This class allows to write a loop over test data as
78 * for (const auto& value : TestValues::testValueGenerator<type>())
79 * where type can be any of std::string, int, int64_t, bool, float, double,
80 * std::vector<[std::string, int, int64_6, float, double]>, or tensor.
86 * \brief Helper class allowing to loop over test values
87 * \tparam T type of value
90 class TestValueGenerator
97 explicit Iterator(const T
* ptr
) : ptr_(ptr
) {}
98 Iterator
operator++();
99 bool operator!=(const Iterator
& other
) const { return ptr_
!= other
.ptr_
; }
100 const T
& operator*() const { return *ptr_
; }
106 Iterator
begin() const;
107 Iterator
end() const;
111 * \brief Static function returning a TestValueGenerator of type T
112 * \tparam T type of values generated
113 * \return TestValueGenerator<T>
116 static TestValueGenerator
<T
> testValueGenerator()
118 static const TestValueGenerator
<T
> testValueGenerator
;
119 return testValueGenerator
;
124 static const std::vector
<T
>& getTestVector();
127 static std::enable_if_t
<IsSerializableType
<T
>::value
&& !std::is_same
<T
, bool>::value
, const T
*>
130 static std::enable_if_t
<IsVectorOfSerializableType
<T
>::value
, const T
*> getBeginPointer();
132 static std::enable_if_t
<std::is_same
<T
, bool>::value
, const T
*> getBeginPointer();
134 static std::enable_if_t
<std::is_same
<T
, tensor
>::value
, const T
*> getBeginPointer();
137 static std::enable_if_t
<IsSerializableType
<T
>::value
&& !std::is_same
<T
, bool>::value
, const T
*>
140 static std::enable_if_t
<IsVectorOfSerializableType
<T
>::value
, const T
*> getEndPointer();
142 static std::enable_if_t
<std::is_same
<T
, bool>::value
, const T
*> getEndPointer();
144 static std::enable_if_t
<std::is_same
<T
, tensor
>::value
, const T
*> getEndPointer();
147 static std::enable_if_t
<IsSerializableType
<T
>::value
&& !std::is_same
<T
, bool>::value
, void>
148 increment(const T
** ptr
);
150 static std::enable_if_t
<IsVectorOfSerializableType
<T
>::value
, void> increment(const T
** ptr
);
152 static std::enable_if_t
<std::is_same
<T
, bool>::value
, void> increment(const T
** ptr
);
154 static std::enable_if_t
<std::is_same
<T
, tensor
>::value
, void> increment(const T
** ptr
);
156 constexpr static bool testTrue
= true;
157 constexpr static bool testFalse
= false;
158 constexpr static tensor testTensor1
= { { 1.6234, 2.4632, 3.1112 },
159 { 4.66234, 5.9678, 6.088 },
160 { 7.00001, 8.43535, 9.11233 } };
162 constexpr static tensor testTensor2
= { { 1, GMX_DOUBLE_EPS
, 3 },
163 { GMX_DOUBLE_MIN
, 5, 6 },
164 { 7, 8, GMX_DOUBLE_MAX
} };
166 constexpr static tensor testTensor2
= { { 1, GMX_FLOAT_EPS
, 3 },
167 { GMX_FLOAT_MIN
, 5, 6 },
168 { 7, 8, GMX_FLOAT_MAX
} };
172 // Begin implementations of TestValues methods
174 const std::vector
<std::string
>& TestValues::getTestVector()
176 static const std::vector
<std::string
> testStrings({ "Test string\nwith newlines\n", "" });
180 const std::vector
<int>& TestValues::getTestVector()
182 static const std::vector
<int> testInts({ { 3, INT_MAX
, INT_MIN
} });
186 const std::vector
<int64_t>& TestValues::getTestVector()
188 static const std::vector
<int64_t> testInt64s({ -7, LLONG_MAX
, LLONG_MIN
});
192 const std::vector
<float>& TestValues::getTestVector()
194 static const std::vector
<float> testFloats({ 33.9, GMX_FLOAT_MAX
, GMX_FLOAT_MIN
, GMX_FLOAT_EPS
});
198 const std::vector
<double>& TestValues::getTestVector()
200 static const std::vector
<double> testDoubles({ -123.45, GMX_DOUBLE_MAX
, GMX_DOUBLE_MIN
, GMX_DOUBLE_EPS
});
205 std::enable_if_t
<IsSerializableType
<T
>::value
&& !std::is_same
<T
, bool>::value
, const T
*> TestValues::getBeginPointer()
207 return getTestVector
<T
>().data();
210 std::enable_if_t
<IsVectorOfSerializableType
<T
>::value
, const T
*> TestValues::getBeginPointer()
212 return &getTestVector
<typename
T::value_type
>();
215 std::enable_if_t
<std::is_same
<T
, bool>::value
, const T
*> TestValues::getBeginPointer()
220 std::enable_if_t
<std::is_same
<T
, tensor
>::value
, const T
*> TestValues::getBeginPointer()
226 std::enable_if_t
<IsSerializableType
<T
>::value
&& !std::is_same
<T
, bool>::value
, const T
*> TestValues::getEndPointer()
228 return getTestVector
<T
>().data() + getTestVector
<T
>().size();
231 std::enable_if_t
<IsVectorOfSerializableType
<T
>::value
, const T
*> TestValues::getEndPointer()
233 return &getTestVector
<typename
T::value_type
>() + 1;
236 std::enable_if_t
<std::is_same
<T
, bool>::value
, const T
*> TestValues::getEndPointer()
241 std::enable_if_t
<std::is_same
<T
, tensor
>::value
, const T
*> TestValues::getEndPointer()
247 std::enable_if_t
<IsSerializableType
<T
>::value
&& !std::is_same
<T
, bool>::value
, void>
248 TestValues::increment(const T
** ptr
)
253 std::enable_if_t
<IsVectorOfSerializableType
<T
>::value
, void> TestValues::increment(const T
** ptr
)
258 std::enable_if_t
<std::is_same
<T
, bool>::value
, void> TestValues::increment(const T
** ptr
)
260 *ptr
= (*ptr
== &testTrue
) ? &testFalse
: nullptr;
263 std::enable_if_t
<std::is_same
<T
, tensor
>::value
, void> TestValues::increment(const T
** ptr
)
265 *ptr
= (*ptr
== &testTensor1
) ? &testTensor2
: nullptr;
269 typename
TestValues::TestValueGenerator
<T
>::Iterator
TestValues::TestValueGenerator
<T
>::begin() const
271 return TestValues::TestValueGenerator
<T
>::Iterator(getBeginPointer
<T
>());
275 typename
TestValues::TestValueGenerator
<T
>::Iterator
TestValues::TestValueGenerator
<T
>::end() const
277 return TestValues::TestValueGenerator
<T
>::Iterator(getEndPointer
<T
>());
281 typename
TestValues::TestValueGenerator
<T
>::Iterator
TestValues::TestValueGenerator
<T
>::Iterator::operator++()
283 TestValues::increment(&ptr_
);
286 // End implementations of TestValues methods
288 //! Write scalar input to CheckpointData
290 typename
std::enable_if_t
<IsSerializableType
<T
>::value
, void>
291 writeInput(const std::string
& key
, const T
& inputValue
, WriteCheckpointData
* checkpointData
)
293 checkpointData
->scalar(key
, &inputValue
);
295 //! Read scalar from CheckpointData and test if equal to input
297 typename
std::enable_if_t
<IsSerializableType
<T
>::value
, void>
298 testOutput(const std::string
& key
, const T
& inputValue
, ReadCheckpointData
* checkpointData
)
301 checkpointData
->scalar(key
, &outputValue
);
302 EXPECT_EQ(inputValue
, outputValue
);
304 //! Write vector input to CheckpointData
306 void writeInput(const std::string
& key
, const std::vector
<T
>& inputVector
, WriteCheckpointData
* checkpointData
)
308 checkpointData
->arrayRef(key
, makeConstArrayRef(inputVector
));
310 //! Read vector from CheckpointData and test if equal to input
312 void testOutput(const std::string
& key
, const std::vector
<T
>& inputVector
, ReadCheckpointData
* checkpointData
)
314 std::vector
<T
> outputVector
;
315 outputVector
.resize(inputVector
.size());
316 checkpointData
->arrayRef(key
, makeArrayRef(outputVector
));
317 EXPECT_THAT(outputVector
, ::testing::ContainerEq(inputVector
));
319 //! Write tensor input to CheckpointData
320 void writeInput(const std::string
& key
, const tensor inputTensor
, WriteCheckpointData
* checkpointData
)
322 checkpointData
->tensor(key
, inputTensor
);
324 //! Read tensor from CheckpointData and test if equal to input
325 void testOutput(const std::string
& key
, const tensor inputTensor
, ReadCheckpointData
* checkpointData
)
327 tensor outputTensor
= { { 0 } };
328 checkpointData
->tensor(key
, outputTensor
);
329 std::array
<std::array
<real
, 3>, 3> inputTensorArray
= {
330 { { inputTensor
[XX
][XX
], inputTensor
[XX
][YY
], inputTensor
[XX
][ZZ
] },
331 { inputTensor
[YY
][XX
], inputTensor
[YY
][YY
], inputTensor
[YY
][ZZ
] },
332 { inputTensor
[ZZ
][XX
], inputTensor
[ZZ
][YY
], inputTensor
[ZZ
][ZZ
] } }
334 std::array
<std::array
<real
, 3>, 3> outputTensorArray
= {
335 { { outputTensor
[XX
][XX
], outputTensor
[XX
][YY
], outputTensor
[XX
][ZZ
] },
336 { outputTensor
[YY
][XX
], outputTensor
[YY
][YY
], outputTensor
[YY
][ZZ
] },
337 { outputTensor
[ZZ
][XX
], outputTensor
[ZZ
][YY
], outputTensor
[ZZ
][ZZ
] } }
339 EXPECT_THAT(outputTensorArray
, ::testing::ContainerEq(inputTensorArray
));
343 * \brief CheckpointData test fixture
345 * Test whether input is equal to output, either with a single data type
346 * or with a combination of three data types.
348 class CheckpointDataTest
: public ::testing::Test
351 using WriteFunction
= std::function
<void(WriteCheckpointData
*)>;
352 using TestFunction
= std::function
<void(ReadCheckpointData
*)>;
354 // List of functions to write values to checkpoint
355 std::vector
<WriteFunction
> writeFunctions_
;
356 // List of functions to test read checkpoint object
357 std::vector
<TestFunction
> testFunctions_
;
359 // Add values to write / test lists
363 for (const auto& inputValue
: TestValues::testValueGenerator
<T
>())
365 std::string key
= "value" + std::to_string(writeFunctions_
.size());
366 writeFunctions_
.emplace_back([key
, inputValue
](WriteCheckpointData
* checkpointData
) {
367 writeInput(key
, inputValue
, checkpointData
);
369 testFunctions_
.emplace_back([key
, inputValue
](ReadCheckpointData
* checkpointData
) {
370 testOutput(key
, inputValue
, checkpointData
);
375 /* This shuffles the write and test functions (as writing and reading can happen
376 * if different orders), then writes all data to a CheckpointData object.
377 * The CheckpointData object is serialized to file, and then re-read into
378 * a new CheckpointData object. The test functions are then used to assert
379 * that all data is present in the new object.
383 /* Randomize order of writing and testing - checkpoint data can be
384 * accessed in any order!
385 * Having the same order makes this reproducible, so at least for now we're
386 * ok seeding the rng with default value and silencing clang-tidy */
387 // NOLINTNEXTLINE(cert-msc51-cpp)
388 auto rng
= std::default_random_engine
{};
389 std::shuffle(std::begin(writeFunctions_
), std::end(writeFunctions_
), rng
);
390 std::shuffle(std::begin(testFunctions_
), std::end(testFunctions_
), rng
);
392 // Write value to CheckpointData & serialize
394 WriteCheckpointDataHolder writeCheckpointDataHolder
;
395 auto writeCheckpointData
= writeCheckpointDataHolder
.checkpointData("test");
396 for (const auto& writeFunction
: writeFunctions_
)
398 writeFunction(&writeCheckpointData
);
401 auto* file
= gmx_fio_open(filename_
.c_str(), "w");
402 FileIOXdrSerializer
serializer(file
);
403 writeCheckpointDataHolder
.serialize(&serializer
);
407 // Deserialize values and test against reference
409 auto* file
= gmx_fio_open(filename_
.c_str(), "r");
410 FileIOXdrSerializer
deserializer(file
);
412 ReadCheckpointDataHolder readCheckpointDataHolder
;
413 readCheckpointDataHolder
.deserialize(&deserializer
);
416 auto readCheckpointData
= readCheckpointDataHolder
.checkpointData("test");
417 for (const auto& testFunction
: testFunctions_
)
419 testFunction(&readCheckpointData
);
423 // Object can be reused
424 writeFunctions_
.clear();
425 testFunctions_
.clear();
428 // The different functions to add test values - in a list to simplify looping over them
429 std::vector
<std::function
<void()>> addTestValueFunctions_
= {
430 [this]() { addTestValues
<std::string
>(); },
431 [this]() { addTestValues
<int>(); },
432 [this]() { addTestValues
<int64_t>(); },
433 [this]() { addTestValues
<bool>(); },
434 [this]() { addTestValues
<float>(); },
435 [this]() { addTestValues
<double>(); },
436 [this]() { addTestValues
<std::vector
<std::string
>>(); },
437 [this]() { addTestValues
<std::vector
<int>>(); },
438 [this]() { addTestValues
<std::vector
<int64_t>>(); },
439 [this]() { addTestValues
<std::vector
<float>>(); },
440 [this]() { addTestValues
<std::vector
<double>>(); },
441 [this]() { addTestValues
<tensor
>(); }
444 // The types we're testing - for scoped trace output only
445 std::vector
<std::string
> testingTypes_
= { "std::string",
451 "std::vector<std::string>",
453 "std::vector<int64_t>",
454 "std::vector<float>",
455 "std::vector<double>",
458 // We'll need a temporary file to write / read our dummy checkpoint to
459 TestFileManager fileManager_
;
460 std::string filename_
= fileManager_
.getTemporaryFilePath("test.cpt");
463 TEST_F(CheckpointDataTest
, SingleDataTest
)
465 // Test data types separately
466 const int numTypes
= addTestValueFunctions_
.size();
467 for (int type
= 0; type
< numTypes
; ++type
)
469 SCOPED_TRACE("Using test values of type " + testingTypes_
[type
]);
470 addTestValueFunctions_
[type
]();
475 TEST_F(CheckpointDataTest
, MultiDataTest
)
477 // All combinations of 2 different data types
478 const int numTypes
= addTestValueFunctions_
.size();
479 for (int type1
= 0; type1
< numTypes
; ++type1
)
481 for (int type2
= type1
; type2
< numTypes
; ++type2
)
483 SCOPED_TRACE("Using test values of type " + testingTypes_
[type1
] + " and "
484 + testingTypes_
[type2
]);
485 addTestValueFunctions_
[type1
]();
486 addTestValueFunctions_
[type2
]();
493 } // namespace gmx::test