Update instructions in containers.rst
[gromacs.git] / src / gromacs / mdtypes / tests / checkpointdata.cpp
blobf3a899d3a97ce8973ca02146ead07ef3ee312688
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2020, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
#include "gmxpre.h"

#include "gromacs/mdtypes/checkpointdata.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <limits>
#include <random>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "gromacs/fileio/gmxfio.h"
#include "gromacs/fileio/gmxfio_xdr.h"
#include "gromacs/utility/fatalerror.h"
#include "gromacs/utility/inmemoryserializer.h"

#include "testutils/testfilemanager.h"
52 namespace gmx::test
54 namespace
57 /*! \internal
58 * \ingroup module_modularsimulator
59 * \brief Struct allowing to check if type is vector of serializable data
61 //! \{
62 template<class T>
63 struct IsVectorOfSerializableType
65 static bool const value = false;
67 template<class T>
68 struct IsVectorOfSerializableType<std::vector<T>>
70 static bool const value = IsSerializableType<T>::value;
72 //! \}
74 /*! \internal
75 * \brief Unified looping over test data
77 * This class allows to write a loop over test data as
78 * for (const auto& value : TestValues::testValueGenerator<type>())
79 * where type can be any of std::string, int, int64_t, bool, float, double,
80 * std::vector<[std::string, int, int64_6, float, double]>, or tensor.
82 class TestValues
84 public:
85 /*! \internal
86 * \brief Helper class allowing to loop over test values
87 * \tparam T type of value
89 template<typename T>
90 class TestValueGenerator
92 public:
93 //! Custom iterator
94 class Iterator
96 public:
97 explicit Iterator(const T* ptr) : ptr_(ptr) {}
98 Iterator operator++();
99 bool operator!=(const Iterator& other) const { return ptr_ != other.ptr_; }
100 const T& operator*() const { return *ptr_; }
102 private:
103 const T* ptr_;
106 Iterator begin() const;
107 Iterator end() const;
110 /*! \internal
111 * \brief Static function returning a TestValueGenerator of type T
112 * \tparam T type of values generated
113 * \return TestValueGenerator<T>
115 template<typename T>
116 static TestValueGenerator<T> testValueGenerator()
118 static const TestValueGenerator<T> testValueGenerator;
119 return testValueGenerator;
122 private:
123 template<typename T>
124 static const std::vector<T>& getTestVector();
126 template<typename T>
127 static std::enable_if_t<IsSerializableType<T>::value && !std::is_same<T, bool>::value, const T*>
128 getBeginPointer();
129 template<typename T>
130 static std::enable_if_t<IsVectorOfSerializableType<T>::value, const T*> getBeginPointer();
131 template<typename T>
132 static std::enable_if_t<std::is_same<T, bool>::value, const T*> getBeginPointer();
133 template<typename T>
134 static std::enable_if_t<std::is_same<T, tensor>::value, const T*> getBeginPointer();
136 template<typename T>
137 static std::enable_if_t<IsSerializableType<T>::value && !std::is_same<T, bool>::value, const T*>
138 getEndPointer();
139 template<typename T>
140 static std::enable_if_t<IsVectorOfSerializableType<T>::value, const T*> getEndPointer();
141 template<typename T>
142 static std::enable_if_t<std::is_same<T, bool>::value, const T*> getEndPointer();
143 template<typename T>
144 static std::enable_if_t<std::is_same<T, tensor>::value, const T*> getEndPointer();
146 template<typename T>
147 static std::enable_if_t<IsSerializableType<T>::value && !std::is_same<T, bool>::value, void>
148 increment(const T** ptr);
149 template<typename T>
150 static std::enable_if_t<IsVectorOfSerializableType<T>::value, void> increment(const T** ptr);
151 template<typename T>
152 static std::enable_if_t<std::is_same<T, bool>::value, void> increment(const T** ptr);
153 template<typename T>
154 static std::enable_if_t<std::is_same<T, tensor>::value, void> increment(const T** ptr);
156 constexpr static bool testTrue = true;
157 constexpr static bool testFalse = false;
158 constexpr static tensor testTensor1 = { { 1.6234, 2.4632, 3.1112 },
159 { 4.66234, 5.9678, 6.088 },
160 { 7.00001, 8.43535, 9.11233 } };
161 #if GMX_DOUBLE
162 constexpr static tensor testTensor2 = { { 1, GMX_DOUBLE_EPS, 3 },
163 { GMX_DOUBLE_MIN, 5, 6 },
164 { 7, 8, GMX_DOUBLE_MAX } };
165 #else
166 constexpr static tensor testTensor2 = { { 1, GMX_FLOAT_EPS, 3 },
167 { GMX_FLOAT_MIN, 5, 6 },
168 { 7, 8, GMX_FLOAT_MAX } };
169 #endif
172 // Begin implementations of TestValues methods
173 template<>
174 const std::vector<std::string>& TestValues::getTestVector()
176 static const std::vector<std::string> testStrings({ "Test string\nwith newlines\n", "" });
177 return testStrings;
179 template<>
180 const std::vector<int>& TestValues::getTestVector()
182 static const std::vector<int> testInts({ { 3, INT_MAX, INT_MIN } });
183 return testInts;
185 template<>
186 const std::vector<int64_t>& TestValues::getTestVector()
188 static const std::vector<int64_t> testInt64s({ -7, LLONG_MAX, LLONG_MIN });
189 return testInt64s;
191 template<>
192 const std::vector<float>& TestValues::getTestVector()
194 static const std::vector<float> testFloats({ 33.9, GMX_FLOAT_MAX, GMX_FLOAT_MIN, GMX_FLOAT_EPS });
195 return testFloats;
197 template<>
198 const std::vector<double>& TestValues::getTestVector()
200 static const std::vector<double> testDoubles({ -123.45, GMX_DOUBLE_MAX, GMX_DOUBLE_MIN, GMX_DOUBLE_EPS });
201 return testDoubles;
204 template<typename T>
205 std::enable_if_t<IsSerializableType<T>::value && !std::is_same<T, bool>::value, const T*> TestValues::getBeginPointer()
207 return getTestVector<T>().data();
209 template<typename T>
210 std::enable_if_t<IsVectorOfSerializableType<T>::value, const T*> TestValues::getBeginPointer()
212 return &getTestVector<typename T::value_type>();
214 template<typename T>
215 std::enable_if_t<std::is_same<T, bool>::value, const T*> TestValues::getBeginPointer()
217 return &testTrue;
219 template<typename T>
220 std::enable_if_t<std::is_same<T, tensor>::value, const T*> TestValues::getBeginPointer()
222 return &testTensor1;
225 template<typename T>
226 std::enable_if_t<IsSerializableType<T>::value && !std::is_same<T, bool>::value, const T*> TestValues::getEndPointer()
228 return getTestVector<T>().data() + getTestVector<T>().size();
230 template<typename T>
231 std::enable_if_t<IsVectorOfSerializableType<T>::value, const T*> TestValues::getEndPointer()
233 return &getTestVector<typename T::value_type>() + 1;
235 template<typename T>
236 std::enable_if_t<std::is_same<T, bool>::value, const T*> TestValues::getEndPointer()
238 return nullptr;
240 template<typename T>
241 std::enable_if_t<std::is_same<T, tensor>::value, const T*> TestValues::getEndPointer()
243 return nullptr;
246 template<typename T>
247 std::enable_if_t<IsSerializableType<T>::value && !std::is_same<T, bool>::value, void>
248 TestValues::increment(const T** ptr)
250 ++(*ptr);
252 template<typename T>
253 std::enable_if_t<IsVectorOfSerializableType<T>::value, void> TestValues::increment(const T** ptr)
255 ++(*ptr);
257 template<typename T>
258 std::enable_if_t<std::is_same<T, bool>::value, void> TestValues::increment(const T** ptr)
260 *ptr = (*ptr == &testTrue) ? &testFalse : nullptr;
262 template<typename T>
263 std::enable_if_t<std::is_same<T, tensor>::value, void> TestValues::increment(const T** ptr)
265 *ptr = (*ptr == &testTensor1) ? &testTensor2 : nullptr;
268 template<typename T>
269 typename TestValues::TestValueGenerator<T>::Iterator TestValues::TestValueGenerator<T>::begin() const
271 return TestValues::TestValueGenerator<T>::Iterator(getBeginPointer<T>());
274 template<typename T>
275 typename TestValues::TestValueGenerator<T>::Iterator TestValues::TestValueGenerator<T>::end() const
277 return TestValues::TestValueGenerator<T>::Iterator(getEndPointer<T>());
280 template<typename T>
281 typename TestValues::TestValueGenerator<T>::Iterator TestValues::TestValueGenerator<T>::Iterator::operator++()
283 TestValues::increment(&ptr_);
284 return *this;
286 // End implementations of TestValues methods
288 //! Write scalar input to CheckpointData
289 template<typename T>
290 typename std::enable_if_t<IsSerializableType<T>::value, void>
291 writeInput(const std::string& key, const T& inputValue, WriteCheckpointData* checkpointData)
293 checkpointData->scalar(key, &inputValue);
295 //! Read scalar from CheckpointData and test if equal to input
296 template<typename T>
297 typename std::enable_if_t<IsSerializableType<T>::value, void>
298 testOutput(const std::string& key, const T& inputValue, ReadCheckpointData* checkpointData)
300 T outputValue;
301 checkpointData->scalar(key, &outputValue);
302 EXPECT_EQ(inputValue, outputValue);
304 //! Write vector input to CheckpointData
305 template<typename T>
306 void writeInput(const std::string& key, const std::vector<T>& inputVector, WriteCheckpointData* checkpointData)
308 checkpointData->arrayRef(key, makeConstArrayRef(inputVector));
310 //! Read vector from CheckpointData and test if equal to input
311 template<typename T>
312 void testOutput(const std::string& key, const std::vector<T>& inputVector, ReadCheckpointData* checkpointData)
314 std::vector<T> outputVector;
315 outputVector.resize(inputVector.size());
316 checkpointData->arrayRef(key, makeArrayRef(outputVector));
317 EXPECT_THAT(outputVector, ::testing::ContainerEq(inputVector));
319 //! Write tensor input to CheckpointData
320 void writeInput(const std::string& key, const tensor inputTensor, WriteCheckpointData* checkpointData)
322 checkpointData->tensor(key, inputTensor);
324 //! Read tensor from CheckpointData and test if equal to input
325 void testOutput(const std::string& key, const tensor inputTensor, ReadCheckpointData* checkpointData)
327 tensor outputTensor = { { 0 } };
328 checkpointData->tensor(key, outputTensor);
329 std::array<std::array<real, 3>, 3> inputTensorArray = {
330 { { inputTensor[XX][XX], inputTensor[XX][YY], inputTensor[XX][ZZ] },
331 { inputTensor[YY][XX], inputTensor[YY][YY], inputTensor[YY][ZZ] },
332 { inputTensor[ZZ][XX], inputTensor[ZZ][YY], inputTensor[ZZ][ZZ] } }
334 std::array<std::array<real, 3>, 3> outputTensorArray = {
335 { { outputTensor[XX][XX], outputTensor[XX][YY], outputTensor[XX][ZZ] },
336 { outputTensor[YY][XX], outputTensor[YY][YY], outputTensor[YY][ZZ] },
337 { outputTensor[ZZ][XX], outputTensor[ZZ][YY], outputTensor[ZZ][ZZ] } }
339 EXPECT_THAT(outputTensorArray, ::testing::ContainerEq(inputTensorArray));
342 /*! \internal
343 * \brief CheckpointData test fixture
345 * Test whether input is equal to output, either with a single data type
346 * or with a combination of three data types.
348 class CheckpointDataTest : public ::testing::Test
350 public:
351 using WriteFunction = std::function<void(WriteCheckpointData*)>;
352 using TestFunction = std::function<void(ReadCheckpointData*)>;
354 // List of functions to write values to checkpoint
355 std::vector<WriteFunction> writeFunctions_;
356 // List of functions to test read checkpoint object
357 std::vector<TestFunction> testFunctions_;
359 // Add values to write / test lists
360 template<typename T>
361 void addTestValues()
363 for (const auto& inputValue : TestValues::testValueGenerator<T>())
365 std::string key = "value" + std::to_string(writeFunctions_.size());
366 writeFunctions_.emplace_back([key, inputValue](WriteCheckpointData* checkpointData) {
367 writeInput(key, inputValue, checkpointData);
369 testFunctions_.emplace_back([key, inputValue](ReadCheckpointData* checkpointData) {
370 testOutput(key, inputValue, checkpointData);
375 /* This shuffles the write and test functions (as writing and reading can happen
376 * if different orders), then writes all data to a CheckpointData object.
377 * The CheckpointData object is serialized to file, and then re-read into
378 * a new CheckpointData object. The test functions are then used to assert
379 * that all data is present in the new object.
381 void test()
383 /* Randomize order of writing and testing - checkpoint data can be
384 * accessed in any order!
385 * Having the same order makes this reproducible, so at least for now we're
386 * ok seeding the rng with default value and silencing clang-tidy */
387 // NOLINTNEXTLINE(cert-msc51-cpp)
388 auto rng = std::default_random_engine{};
389 std::shuffle(std::begin(writeFunctions_), std::end(writeFunctions_), rng);
390 std::shuffle(std::begin(testFunctions_), std::end(testFunctions_), rng);
392 // Write value to CheckpointData & serialize
394 WriteCheckpointDataHolder writeCheckpointDataHolder;
395 auto writeCheckpointData = writeCheckpointDataHolder.checkpointData("test");
396 for (const auto& writeFunction : writeFunctions_)
398 writeFunction(&writeCheckpointData);
401 auto* file = gmx_fio_open(filename_.c_str(), "w");
402 FileIOXdrSerializer serializer(file);
403 writeCheckpointDataHolder.serialize(&serializer);
404 gmx_fio_close(file);
407 // Deserialize values and test against reference
409 auto* file = gmx_fio_open(filename_.c_str(), "r");
410 FileIOXdrSerializer deserializer(file);
412 ReadCheckpointDataHolder readCheckpointDataHolder;
413 readCheckpointDataHolder.deserialize(&deserializer);
414 gmx_fio_close(file);
416 auto readCheckpointData = readCheckpointDataHolder.checkpointData("test");
417 for (const auto& testFunction : testFunctions_)
419 testFunction(&readCheckpointData);
423 // Object can be reused
424 writeFunctions_.clear();
425 testFunctions_.clear();
428 // The different functions to add test values - in a list to simplify looping over them
429 std::vector<std::function<void()>> addTestValueFunctions_ = {
430 [this]() { addTestValues<std::string>(); },
431 [this]() { addTestValues<int>(); },
432 [this]() { addTestValues<int64_t>(); },
433 [this]() { addTestValues<bool>(); },
434 [this]() { addTestValues<float>(); },
435 [this]() { addTestValues<double>(); },
436 [this]() { addTestValues<std::vector<std::string>>(); },
437 [this]() { addTestValues<std::vector<int>>(); },
438 [this]() { addTestValues<std::vector<int64_t>>(); },
439 [this]() { addTestValues<std::vector<float>>(); },
440 [this]() { addTestValues<std::vector<double>>(); },
441 [this]() { addTestValues<tensor>(); }
444 // The types we're testing - for scoped trace output only
445 std::vector<std::string> testingTypes_ = { "std::string",
446 "int",
447 "int64_t",
448 "bool",
449 "float",
450 "double",
451 "std::vector<std::string>",
452 "std::vector<int>",
453 "std::vector<int64_t>",
454 "std::vector<float>",
455 "std::vector<double>",
456 "tensor" };
458 // We'll need a temporary file to write / read our dummy checkpoint to
459 TestFileManager fileManager_;
460 std::string filename_ = fileManager_.getTemporaryFilePath("test.cpt");
463 TEST_F(CheckpointDataTest, SingleDataTest)
465 // Test data types separately
466 const int numTypes = addTestValueFunctions_.size();
467 for (int type = 0; type < numTypes; ++type)
469 SCOPED_TRACE("Using test values of type " + testingTypes_[type]);
470 addTestValueFunctions_[type]();
471 test();
475 TEST_F(CheckpointDataTest, MultiDataTest)
477 // All combinations of 2 different data types
478 const int numTypes = addTestValueFunctions_.size();
479 for (int type1 = 0; type1 < numTypes; ++type1)
481 for (int type2 = type1; type2 < numTypes; ++type2)
483 SCOPED_TRACE("Using test values of type " + testingTypes_[type1] + " and "
484 + testingTypes_[type2]);
485 addTestValueFunctions_[type1]();
486 addTestValueFunctions_[type2]();
487 test();
492 } // namespace
493 } // namespace gmx::test