Replace automatic rdtscp checks with boolean option
[gromacs.git] / python_packaging / sample_restraint / src / pythonmodule / export_plugin.cpp
bloba2f2e393321495d6ebb53fbed48a949dfdf7ebba
1 /*! \file
2 * \brief Provide Python bindings and helper functions for setting up restraint potentials.
4 * There is currently a lot of boilerplate here that will be generalized and removed in a future version.
5 * In the meantime, follow the example of EnsembleRestraint to create the proper helper functions
6 * and instantiate the necessary templates.
8 * \author M. Eric Irrgang <ericirrgang@gmail.com>
9 */
11 #include "export_plugin.h"
13 #include <cassert>
15 #include <memory>
17 #include "gmxapi/exceptions.h"
18 #include "gmxapi/md.h"
19 #include "gmxapi/md/mdmodule.h"
20 #include "gmxapi/gmxapi.h"
22 #include "ensemblepotential.h"
24 // Make a convenient alias to save some typing...
25 namespace py = pybind11;
27 ////////////////////////////////
28 // Begin PyRestraint static code
29 /*!
30 * \brief Templated wrapper to use in Python bindings.
32 * Boilerplate
34 * Mix-in from below. Adds a bind behavior, a getModule() method to get a gmxapi::MDModule adapter,
35 * and a create() method that assures a single shared_ptr record for an object that may sometimes
36 * be referred to by a raw pointer and/or have shared_from_this called.
37 * \tparam T class implementing gmx::IRestraintPotential
40 template<class T>
41 class PyRestraint : public T, public std::enable_shared_from_this<PyRestraint<T>>
43 public:
44 void bind(py::object object);
46 using T::name;
48 /*!
49 * \brief
51 * T must either derive from gmxapi::MDModule or provide a template specialization for
52 * PyRestraint<T>::getModule(). If T derives from gmxapi::MDModule, we can keep a weak pointer
53 * to ourself and generate a shared_ptr on request, but std::enable_shared_from_this already
54 * does that, so we use it when we can.
55 * \return
57 std::shared_ptr<gmxapi::MDModule> getModule();
59 /*!
60 * \brief Factory function to get a managed pointer to a new restraint.
62 * \tparam ArgsT
63 * \param args
64 * \return
66 template<typename ... ArgsT>
67 static std::shared_ptr<PyRestraint<T>> create(ArgsT... args)
69 auto newRestraint = std::make_shared<PyRestraint<T>>(args...);
70 return newRestraint;
73 template<typename ... ArgsT>
74 explicit PyRestraint(ArgsT... args) :
75 T{args...}
80 /*!
81 * \brief Implement the gmxapi binding protocol for restraints.
83 * All restraints will use this same code automatically.
85 * \tparam T restraint class exported below.
86 * \param object Python Capsule object to allow binding with a simple C API.
88 template<class T>
89 void PyRestraint<T>::bind(py::object object)
91 PyObject * capsule = object.ptr();
92 if (PyCapsule_IsValid(capsule,
93 gmxapi::MDHolder::api_name))
95 auto holder = static_cast<gmxapi::MDHolder*>(PyCapsule_GetPointer(capsule,
96 gmxapi::MDHolder::api_name));
97 auto workSpec = holder->getSpec();
98 std::cout << this->name() << " received " << holder->name();
99 std::cout << " containing spec of size ";
100 std::cout << workSpec->getModules().size();
101 std::cout << std::endl;
103 auto module = getModule();
104 workSpec->addModule(module);
106 else
108 throw gmxapi::ProtocolError("bind method requires a python capsule as input");
111 // end PyRestraint static code
112 //////////////////////////////
116 * \brief Interact with the restraint framework and gmxapi when launching a simulation.
118 * This should be generalized and removed from here. Unfortunately, some things need to be
119 * standardized first. If a potential follows the example of EnsembleRestraint or HarmonicRestraint,
120 * the template specializations below can be mimicked to give GROMACS access to the potential.
122 * \tparam T class implementing the gmxapi::MDModule interface.
123 * \return shared ownership of a T object via the gmxapi::MDModule interface.
125 // If T is derived from gmxapi::MDModule, create a default-constructed std::shared_ptr<T>
126 // \todo Need a better default that can call a shared_from_this()
127 template<class T>
128 std::shared_ptr<gmxapi::MDModule> PyRestraint<T>::getModule()
130 auto module = std::make_shared<typename std::enable_if<std::is_base_of<gmxapi::MDModule, T>::value, T>::type>();
131 return module;
135 template<>
136 std::shared_ptr<gmxapi::MDModule> PyRestraint<plugin::RestraintModule<plugin::EnsembleRestraint>>::getModule()
138 return shared_from_this();
140 //////////////////////////////////////////////////////////////////////////////////////////
141 // New restraints mimicking EnsembleRestraint should specialize getModule() here as above.
142 //////////////////////////////////////////////////////////////////////////////////////////
////////////////////
// Begin MyRestraint
/*!
 * \brief No-op restraint class for testing and demonstration.
 */
class MyRestraint
{
public:
    //! Python docstring for this restraint (defined out of line).
    static const char* docstring;

    //! Name reported when a PyRestraint<MyRestraint> is bound.
    static std::string name()
    {
        return std::string{"MyRestraint"};
    }
};
160 template<>
161 std::shared_ptr<gmxapi::MDModule> PyRestraint<MyRestraint>::getModule()
163 auto module = std::make_shared<gmxapi::MDModule>();
164 return module;
168 // Raw string will have line breaks and indentation as written between the delimiters.
169 const char* MyRestraint::docstring =
170 R"rawdelimiter(Some sort of custom potential.
171 )rawdelimiter";
172 // end MyRestraint
173 //////////////////
176 class EnsembleRestraintBuilder
178 public:
179 explicit EnsembleRestraintBuilder(py::object element)
181 name_ = py::cast<std::string>(element.attr("name"));
182 assert(!name_.empty());
184 // It looks like we need some boilerplate exceptions for plugins so we have something to
185 // raise if the element is invalid.
186 assert(py::hasattr(element,
187 "params"));
189 // Params attribute should be a Python list
190 py::dict parameter_dict = element.attr("params");
191 // \todo Check for the presence of these dictionary keys to avoid hard-to-diagnose error.
193 // Get positional parameters.
194 py::list sites = parameter_dict["sites"];
195 for (auto&& site : sites)
197 siteIndices_.emplace_back(py::cast<int>(site));
200 auto nbins = py::cast<size_t>(parameter_dict["nbins"]);
201 auto binWidth = py::cast<double>(parameter_dict["binWidth"]);
202 auto minDist = py::cast<double>(parameter_dict["min_dist"]);
203 auto maxDist = pybind11::cast<double>(parameter_dict["max_dist"]);
204 auto experimental = pybind11::cast<std::vector<double>>(parameter_dict["experimental"]);
205 auto nSamples = pybind11::cast<unsigned int>(parameter_dict["nsamples"]);
206 auto samplePeriod = pybind11::cast<double>(parameter_dict["sample_period"]);
207 auto nWindows = pybind11::cast<unsigned int>(parameter_dict["nwindows"]);
208 auto k = pybind11::cast<double>(parameter_dict["k"]);
209 auto sigma = pybind11::cast<double>(parameter_dict["sigma"]);
211 auto params = plugin::makeEnsembleParams(nbins,
212 binWidth,
213 minDist,
214 maxDist,
215 experimental,
216 nSamples,
217 samplePeriod,
218 nWindows,
220 sigma);
221 params_ = std::move(*params);
223 // Note that if we want to grab a reference to the Context or its communicator, we can get it
224 // here through element.workspec._context. We need a more general API solution, but this code is
225 // in the Python bindings code, so we know we are in a Python Context.
226 assert(py::hasattr(element,
227 "workspec"));
228 auto workspec = element.attr("workspec");
229 assert(py::hasattr(workspec,
230 "_context"));
231 context_ = workspec.attr("_context");
235 * \brief Add node(s) to graph for the work element.
237 * \param graph networkx.DiGraph object still evolving in gmx.context.
239 * \todo This may not follow the latest graph building protocol as described.
241 void build(py::object graph)
243 if (!subscriber_)
245 return;
247 else
249 if (!py::hasattr(subscriber_, "potential")) throw gmxapi::ProtocolError("Invalid subscriber");
252 // Restraints do not currently add any new nodes to the graph, so we
253 // mark this standard 'graph' argument unused.
254 (void) graph;
256 // Temporarily subvert things to get quick-and-dirty solution for testing.
257 // Need to capture Python communicator and pybind syntax in closure so EnsembleResources
258 // can just call with matrix arguments.
260 // This can be replaced with a subscription and delayed until launch, if necessary.
261 if (!py::hasattr(context_, "ensemble_update"))
263 throw gmxapi::ProtocolError("context does not have 'ensemble_update'.");
265 // make a local copy of the Python object so we can capture it in the lambda
266 auto update = context_.attr("ensemble_update");
267 // Make a callable with standardizeable signature.
268 const std::string name{name_};
269 auto functor = [update, name](const plugin::Matrix<double>& send,
270 plugin::Matrix<double>* receive) {
271 update(send,
272 receive,
273 py::str(name));
276 // To use a reduce function on the Python side, we need to provide it with a Python buffer-like object,
277 // so we will create one here. Note: it looks like the SharedData element will be useful after all.
278 auto resources = std::make_shared<plugin::Resources>(std::move(functor));
280 auto potential = PyRestraint<plugin::RestraintModule<plugin::EnsembleRestraint>>::create(name_,
281 siteIndices_,
282 params_,
283 resources);
285 auto subscriber = subscriber_;
286 py::list potentialList = subscriber.attr("potential");
287 potentialList.append(potential);
292 * \brief Accept subscription of an MD task.
294 * \param subscriber Python object with a 'potential' attribute that is a Python list.
296 * During build, an object is added to the subscriber's self.potential, which is then bound with
297 * system.add_potential(potential) during the subscriber's launch()
299 void addSubscriber(py::object subscriber)
301 assert(py::hasattr(subscriber,
302 "potential"));
303 subscriber_ = subscriber;
306 py::object subscriber_;
307 py::object context_;
308 std::vector<int> siteIndices_;
310 plugin::ensemble_input_param_type params_;
312 std::string name_;
315 namespace {
318 * \brief Factory function to create a new builder for use during Session launch.
320 * \param element WorkElement provided through Context
321 * \return ownership of new builder object
323 std::unique_ptr<EnsembleRestraintBuilder> createEnsembleBuilder(const py::object& element)
325 using std::make_unique;
326 auto builder = make_unique<EnsembleRestraintBuilder>(element);
327 return builder;
333 ////////////////////////////////////////////////////////////////////////////////////////////
334 // New potentials modeled after EnsembleRestraint should define a Builder class and define a
335 // factory function here, following the previous two examples. The factory function should be
336 // exposed to Python following the examples near the end of the PYBIND11_MODULE block.
337 ////////////////////////////////////////////////////////////////////////////////////////////
340 //////////////////////////////////////////////////////////////////////////////////////////////////
341 // The PYBIND11_MODULE block uses the pybind11 framework (ref https://github.com/pybind/pybind11 )
342 // to generate Python bindings to the C++ code elsewhere in this repository. A copy of the pybind11
343 // source code is included with this repository. Use syntax from the examples below when exposing
344 // a new potential, along with its builder and parameters structure. In future releases, there will
345 // be less code to include elsewhere, but more syntax in the block below to define and export the
346 // interface to a plugin. pybind11 is not required to write a GROMACS extension module or for
347 // compatibility with the ``gmx`` module provided with gmxapi. It is sufficient to implement the
348 // various protocols, C API and Python function names, but we do not provide example code
349 // for other Python bindings frameworks.
350 //////////////////////////////////////////////////////////////////////////////////////////////////
352 // The first argument is the name of the module when importing to Python. This should be the same as the name specified
353 // as the OUTPUT_NAME for the shared object library in the CMakeLists.txt file. The second argument, 'm', can be anything
354 // but it might as well be short since we use it to refer to aspects of the module we are defining.
355 PYBIND11_MODULE(myplugin, m) {
356 m.doc() = "sample plugin"; // This will be the text of the module's docstring.
358 // Matrix utility class (temporary). Borrowed from http://pybind11.readthedocs.io/en/master/advanced/pycpp/numpy.html#arrays
359 py::class_<plugin::Matrix<double>, std::shared_ptr<plugin::Matrix<double>>>(m,
360 "Matrix",
361 py::buffer_protocol())
362 .def_buffer([](plugin::Matrix<double>& matrix) -> py::buffer_info {
363 return py::buffer_info(
364 matrix.data(), /* Pointer to buffer */
365 sizeof(double), /* Size of one scalar */
366 py::format_descriptor<double>::format(), /* Python struct-style format descriptor */
367 2, /* Number of dimensions */
368 {matrix.rows(), matrix.cols()}, /* Buffer dimensions */
369 {sizeof(double) * matrix.cols(), /* Strides (in bytes) for each index */
370 sizeof(double)}
374 //////////////////////////////////////////////////////////////////////////
375 // Begin EnsembleRestraint
377 // Define Builder to be returned from ensemble_restraint Python function defined further down.
378 pybind11::class_<EnsembleRestraintBuilder> ensembleBuilder(m,
379 "EnsembleBuilder");
380 ensembleBuilder.def("add_subscriber",
381 &EnsembleRestraintBuilder::addSubscriber);
382 ensembleBuilder.def("build",
383 &EnsembleRestraintBuilder::build);
385 // Get more concise name for the template instantiation...
386 using PyEnsemble = PyRestraint<plugin::RestraintModule<plugin::EnsembleRestraint>>;
388 // Export a Python class for our parameters struct
389 py::class_<plugin::EnsembleRestraint::input_param_type> ensembleParams(m, "EnsembleRestraintParams");
390 m.def("make_ensemble_params",
391 &plugin::makeEnsembleParams);
393 // API object to build.
394 py::class_<PyEnsemble, std::shared_ptr<PyEnsemble>> ensemble(m, "EnsembleRestraint");
395 // EnsembleRestraint can only be created via builder for now.
396 ensemble.def("bind",
397 &PyEnsemble::bind,
398 "Implement binding protocol");
400 * To implement gmxapi_workspec_1_0, the module needs a function that a Context can import that
401 * produces a builder that translates workspec elements for session launching. The object returned
402 * by our function needs to have an add_subscriber(other_builder) method and a build(graph) method.
403 * The build() method returns None or a launcher. A launcher has a signature like launch(rank) and
404 * returns None or a runner.
407 // Generate the name operation that will be used to specify elements of Work in gmxapi workflows.
408 // WorkElements will then have namespace: "myplugin" and operation: "ensemble_restraint"
409 m.def("ensemble_restraint",
410 [](const py::object element) { return createEnsembleBuilder(element); });
412 // End EnsembleRestraint
413 ///////////////////////////////////////////////////////////////////////////