git-svn-id: svn://svn.icms.temple.edu/lammps-ro/trunk@16053 f3b2605a-c512-4ea7-a41b...
[lammps.git] / tools / i-pi / ipi / utils / depend.py
blob1ca361f781109db2754ec28c8638c491c2edfcbc
1 """Contains the classes that are used to define the dependency network.
3 Copyright (C) 2013, Joshua More and Michele Ceriotti
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, either version 3 of the License, or
8 (at your option) any later version.
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
15 You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
19 The classes defined in this module overload the standard __get__ and __set__
20 routines of the numpy ndarray class and standard library object class so that
21 they automatically keep track of whether anything they depend on has been
22 altered, and so only recalculate their value when necessary.
24 Basic quantities that depend on nothing else can be manually altered in the
25 usual way, all other quantities are updated automatically and cannot be changed
26 directly.
28 The exceptions to this are synchronized properties, which are in effect
29 multiple basic quantities all related to each other, for example the bead and
30 normal mode representations of the positions and momenta. In this case any of
31 the representations can be set manually, and all the other representations
32 must keep in step.
34 For a more detailed discussion, see the reference manual.
36 Classes:
37 depend_base: Base depend class with the generic methods and attributes.
38 depend_value: Depend class for scalar objects.
39 depend_array: Depend class for arrays.
40 synchronizer: Class that holds the different objects that are related to each
41 other and keeps track of which property has been set manually.
   dobject: An extension of the standard library object that overloads
      __getattribute__ and __setattr__, so that we can use the
      standard syntax for setting and getting the depend object,
      i.e. foo = value, not foo.set(value).
47 Functions:
48 dget: Gets the dependencies of a depend object.
49 dset: Sets the dependencies of a depend object.
50 depstrip: Used on a depend_array object, to access its value without
51 needing the depend machinery, and so much more quickly. Must not be used
52 if the value of the array is to be changed.
53 depcopy: Copies the dependencies from one object to another
54 deppipe: Used to make two objects be synchronized to the same value.
55 """
# Public names exported by ``from depend import *``.
__all__ = [
    'depend_base', 'depend_value', 'depend_array', 'synchronizer', 'dobject',
    'dget', 'dset', 'depstrip', 'depcopy', 'deppipe',
]
60 import numpy as np
61 from ipi.utils.messages import verbosity, warning
class synchronizer(object):
    """Class to implement synched objects.

    Holds the objects used to keep two or more objects in step with each other.
    This is shared between all the synched objects.

    Attributes:
        synced: A dictionary containing all the synched objects, of the form
            {"name": depend object}.
        manual: A string containing the name of the object being manually
            changed, or None if no object has been registered yet.
    """

    def __init__(self, deps=None):
        """Initialises synchronizer.

        Args:
            deps: Optional dictionary giving the synched objects of the form
                {"name": depend object}. It is stored by reference, not copied.
                If omitted, a fresh empty dictionary is created for this
                instance.
        """

        self.synced = dict() if deps is None else deps
        self.manual = None
# TODO put some error checks in the init to make sure that the object is
# initialized from consistent synchro and func states
class depend_base(object):
    """Base class for dependency handling.

    Builds the majority of the machinery required for the different depend
    objects. Contains functions to add dependencies and dependants, the
    tainting mechanism by which information about which objects have been
    updated is passed around the dependency network, and the manual and
    automatic update functions to check that depend objects with functions
    are not manually updated and that synchronized objects are kept in step
    with the one manually changed.

    Attributes:
        _tainted: An array containing one boolean, which is True if one of the
            dependencies has been changed since the last time the value was
            cached. It is always updated in place (``_tainted[:] = ...``), so
            objects holding the same array share a single flag.
        _func: A callable giving the method of calculating the value, if
            required. For synchronized objects, a dictionary of the form
            {"name": callable}. None otherwise.
        _name: The name of the depend base object.
        _synchro: A synchronizer object to deal with synched objects, if
            required. None otherwise.
        _dependants: A list containing all objects dependent on self.
    """

    def __init__(self, name, synchro=None, func=None, dependants=None, dependencies=None, tainted=None):
        """Initialises depend_base.

        An unusual initialisation routine, as it has to be able to deal with
        the depend array mechanism for returning slices as new depend arrays:
        a slice may be created with an existing (possibly untainted) tainted
        flag, in which case no tainting is performed here.

        When the tainted flag is set, tainting is propagated to the
        dependants. Self only stays tainted if it has a function to compute
        its value (and, when synchronized, is not the synchronizer's
        manually-set object); a primitive value without a function is
        considered up to date.

        Args:
            name: A string giving the name of self.
            tainted: An optional array containing one boolean which is True if
                one of the dependencies has been changed. Defaults to a new
                array [True].
            func: An optional argument that can be specified either by a
                callable, or for synchronized values a dictionary of the form
                {"name": callable}; where "name" is one of the other
                synched objects and the callable computes the value of self
                from the object "name".
            synchro: An optional synchronizer object.
            dependants: An optional list containing objects that depend on
                self.
            dependencies: An optional list containing objects that self
                depends upon.
        """

        # Must exist before the dependency loop below: add_dependant() taints
        # self, and taint() iterates over self._dependants.
        self._dependants = []
        if tainted is None:
            tainted = np.array([True], bool)
        if dependants is None:
            dependants = []
        if dependencies is None:
            dependencies = []
        self._tainted = tainted
        self._func = func
        self._name = name

        self.add_synchro(synchro)

        for item in dependencies:
            item.add_dependant(self, tainted)

        # The real dependants are attached only after the dependencies have
        # been registered.
        self._dependants = dependants

        # Don't taint self if the object is a primitive one. However, do
        # propagate tainting to dependants if required.
        if tainted:
            if self._func is None:
                self.taint(taintme=False)
            else:
                self.taint(taintme=tainted)

    def add_synchro(self, synchro=None):
        """Links depend object to a synchronizer.

        If self is not yet registered in the synchronizer, it is added to
        synchro.synced and marked as the manually-set object.
        """

        self._synchro = synchro
        if self._synchro is not None and self._name not in self._synchro.synced:
            self._synchro.synced[self._name] = self
            self._synchro.manual = self._name

    def add_dependant(self, newdep, tainted=True):
        """Adds a dependant property.

        Args:
            newdep: The depend object to be added to the dependency list.
            tainted: A boolean that decides whether newdep should be tainted.
                True by default.
        """

        self._dependants.append(newdep)
        if tainted:
            newdep.taint(taintme=True)

    def add_dependency(self, newdep, tainted=True):
        """Adds a dependency.

        Args:
            newdep: The depend object self now depends upon.
            tainted: A boolean that decides whether self should
                be tainted. True by default.
        """

        newdep._dependants.append(self)
        if tainted:
            self.taint(taintme=True)

    def taint(self, taintme=True):
        """Recursively sets tainted flag on dependent objects.

        The main function dealing with the dependencies. Taints all objects
        further down the dependency tree until either all objects have been
        tainted, or it reaches only objects that have already been tainted.
        Note that in the case of a dependency loop the initial setting of
        _tainted to True prevents an infinite loop occurring.

        Also, in the case of a synchro object, the manually set quantity is
        not tainted, as it is assumed that synchro objects only depend on each
        other.

        Args:
            taintme: A boolean giving whether self should be tainted at the
                end. True by default.
        """

        # Mark self first so that dependency loops terminate.
        self._tainted[:] = True
        for item in self._dependants:
            if not item._tainted[0]:
                item.taint()
        if self._synchro is not None:
            for v in self._synchro.synced.values():
                if (not v._tainted[0]) and (v is not self):
                    v.taint(taintme=True)
            self._tainted[:] = (taintme and (not self._name == self._synchro.manual))
        else:
            self._tainted[:] = taintme

    def tainted(self):
        """Returns tainted flag."""

        return self._tainted[0]

    def update_auto(self):
        """Automatic update routine.

        Updates the value when get has been called and self has been tainted.
        Synchronized objects are recomputed from the manually-set object;
        otherwise the value is recomputed with _func. If neither applies, a
        warning is emitted and nothing is updated.
        """

        if self._synchro is not None:
            if not self._name == self._synchro.manual:
                self.set(self._func[self._synchro.manual](), manual=False)
            else:
                warning(self._name + " probably shouldn't be tainted (synchro)", verbosity.low)
        elif self._func is not None:
            self.set(self._func(), manual=False)
        else:
            warning(self._name + " probably shouldn't be tainted (value)", verbosity.low)

    def update_man(self):
        """Manual update routine.

        Updates the value when the value has been manually set. Also raises an
        exception if a calculated quantity has been manually set. Also starts
        the tainting routine.

        Raises:
            NameError: If a calculated quantity has been manually set.
        """

        if self._synchro is not None:
            # Self becomes the reference for all synchronized objects.
            self._synchro.manual = self._name
            for v in self._synchro.synced.values():
                v.taint(taintme=True)
            self._tainted[:] = False
        elif self._func is not None:
            raise NameError("Cannot set manually the value of the automatically-computed property <" + self._name + ">")
        else:
            self.taint(taintme=False)

    def set(self, value, manual=False):
        """Dummy setting routine."""

        pass

    def get(self):
        """Dummy getting routine."""

        pass
class depend_value(depend_base):
    """Dependency-aware holder for a single (non-array) value.

    Attributes:
        _value: The currently cached value.
    """

    def __init__(self, name, value=None, synchro=None, func=None, dependants=None, dependencies=None, tainted=None):
        """Initialises depend_value.

        Args:
            name: A string giving the name of self.
            value: The initial value of the object. Optional.
            tainted: An optional array giving the tainted flag. Default is
                [True].
            func: An optional argument that can be specified either by a
                callable, or for synchronized values a dictionary of the form
                {"name": callable}; where "name" is one of the other synched
                objects.
            synchro: An optional synchronizer object.
            dependants: An optional list containing objects that depend on
                self.
            dependencies: An optional list containing objects that self
                depends upon.
        """

        # The value is stored before delegating to depend_base.__init__().
        self._value = value
        super(depend_value, self).__init__(name, synchro, func, dependants, dependencies, tainted)

    def get(self):
        """Return the value, recomputing it first when the flag is tainted."""

        is_stale = self._tainted[0]
        if is_stale:
            self.update_auto()
            self.taint(taintme=False)
        return self._value

    def __get__(self, instance, owner):
        """Descriptor read: delegates to get()."""

        return self.get()

    def set(self, value, manual=True):
        """Store a new value and taint everything that depends on it.

        When manual is True, update_man() is also run, which refuses manual
        writes to computed quantities and keeps synchronized objects in step.
        """

        self._value = value
        self.taint(taintme=False)
        if not manual:
            return
        self.update_man()

    def __set__(self, instance, value):
        """Descriptor write: delegates to set() as a manual update."""

        self.set(value)
class depend_array(np.ndarray, depend_base):
    """Array class for dependency handling.

    Differs from depend_value as arrays handle getting items in a different
    way to scalar quantities, and as there needs to be support for slicing an
    array. Initialisation is also done in a different way for ndarrays.

    Attributes:
        _bval: The base deparray storage space. Equal to depstrip(self) unless
            self is a slice (or reshape) of another depend_array.
    """

    def __new__(cls, value, name, synchro=None, func=None, dependants=None, dependencies=None, tainted=None, base=None):
        """Creates a new array from a template.

        Called whenever a new instance of depend_array is created. Casts the
        array base into an appropriate form before passing it to
        __array_finalize__().

        Args:
            See __init__().
        """

        obj = np.asarray(value).view(cls)
        return obj

    def __init__(self, value, name, synchro=None, func=None, dependants=None, dependencies=None, tainted=None, base=None):
        """Initialises depend_array.

        Note that this is only called when a new array is created by an
        explicit constructor.

        Args:
            name: A string giving the name of self.
            value: The (numpy) array to serve as the memory base.
            tainted: An optional array giving the tainted flag. Default is
                [True].
            func: An optional argument that can be specified either by a
                callable, or for synchronized values a dictionary of the form
                {"name": callable}; where "name" is one of the other synched
                objects.
            synchro: An optional synchronizer object.
            dependants: An optional list containing objects that depend on
                self.
            dependencies: An optional list containing objects that self
                depends upon.
            base: An optional array used as the base storage space (_bval);
                passed when self is a slice or reshape of another
                depend_array.
        """

        super(depend_array, self).__init__(name, synchro, func, dependants, dependencies, tainted)

        if base is not None:
            self._bval = base
        elif isinstance(value, np.ndarray):
            self._bval = value
        else:
            # value is not an array (e.g. a list): np.asarray() in __new__
            # built a new array, so value does not share memory with self.
            # Writing into it would never reach self, so use a plain view
            # of self's own storage instead.
            self._bval = depstrip(self)

    def copy(self, order='C', maskna=None):
        """Wrapper for numpy copy mechanism.

        Sets a flag that __array_finalize__() consumes, so that the copy is
        not linked to the dependencies of self and uses its own storage as
        _bval. maskna is accepted for backward compatibility and ignored.
        """

        # Sets a flag and hands control to the numpy copy
        self._fcopy = True
        return super(depend_array, self).copy(order)

    def __array_finalize__(self, obj):
        """Deals with properly creating some arrays.

        In the case where a function acting on a depend array returns a
        ndarray, this casts it into the correct form and gives it the depend
        machinery for other methods to be able to act upon it. New
        depend_arrays will next be passed to __init__() to be properly
        initialized, but some ways of creating arrays do not call __new__() or
        __init__(), so need to be initialized.
        """

        depend_base.__init__(self, name="")

        if type(obj) is depend_array:
            # We are in a view cast or in new from template. Unfortunately
            # there is no sure way to tell (or so it seems). Hence we need to
            # handle special cases, and hope we are in a view cast otherwise.
            if hasattr(obj, "_fcopy"):
                del obj._fcopy  # removes the "copy flag"
                self._bval = depstrip(self)
            else:
                # Assumes we are in view cast, so copy over the attributes
                # from the parent object. Typical case: when transpose is
                # performed as a view.
                super(depend_array, self).__init__(obj._name, obj._synchro, obj._func, obj._dependants, None, obj._tainted)
                self._bval = obj._bval
        else:
            # Most likely we came here on the way to init.
            # Just sets a defaults for safety
            self._bval = depstrip(self)

    def __array_prepare__(self, arr, context=None):
        """Prepare output array for ufunc.

        Depending on the context we try to understand if we are doing an
        in-place operation (in which case we want to keep the return value a
        deparray) or we are generating a new array as a result of the ufunc.
        In this case there is no way to know if dependencies should be copied,
        so we strip and return a ndarray.
        """

        if context is None or len(context) < 2 or type(context[0]) is not np.ufunc:
            # It is not clear what we should do. If in doubt, strip
            # dependencies.
            return np.ndarray.__array_prepare__(self.view(np.ndarray), arr.view(np.ndarray), context)
        elif len(context[1]) > context[0].nin and context[0].nout > 0:
            # We are being called by a ufunc with a output argument, which is
            # being actually used. Most likely, something like an increment,
            # so we pass on a deparray
            return super(depend_array, self).__array_prepare__(arr, context)
        else:
            # Apparently we are generating a new array. We have no way of
            # knowing its dependencies, so we'd better return a ndarray view!
            return np.ndarray.__array_prepare__(self.view(np.ndarray), arr.view(np.ndarray), context)

    def __array_wrap__(self, arr, context=None):
        """Wraps up output array from ufunc.

        See docstring of __array_prepare__().
        """

        if context is None or len(context) < 2 or type(context[0]) is not np.ufunc:
            return np.ndarray.__array_wrap__(self.view(np.ndarray), arr.view(np.ndarray), context)
        elif len(context[1]) > context[0].nin and context[0].nout > 0:
            return super(depend_array, self).__array_wrap__(arr, context)
        else:
            return np.ndarray.__array_wrap__(self.view(np.ndarray), arr.view(np.ndarray), context)

    # whenever possible in compound operations just return a regular ndarray
    __array_priority__ = -1.0

    def reshape(self, newshape):
        """Changes the shape of the base array.

        Args:
            newshape: A tuple giving the desired shape of the new array.

        Returns:
            A depend_array with the dimensions given by newshape, sharing the
            dependencies, tainted flag and base storage of self.
        """

        return depend_array(depstrip(self).reshape(newshape), name=self._name, synchro=self._synchro,
                            func=self._func, dependants=self._dependants, tainted=self._tainted,
                            base=self._bval)

    def flatten(self):
        """Makes the base array one dimensional.

        Returns:
            A flattened array.
        """

        return self.reshape(self.size)

    @staticmethod
    def __scalarindex(index, depth=1):
        """Checks if an index points at a scalar value.

        Used so that looking up one item in an array returns a scalar, whereas
        looking up a slice of the array returns a new array with the same
        dependencies as the original, so that changing the slice also taints
        the global array.

        Arguments:
            index: the index to be checked.
            depth: the rank of the array which is being accessed. Default
                value is 1.

        Returns:
            A logical stating whether a __get__ instruction based
            on index would return a scalar.
        """

        if np.isscalar(index) and depth <= 1:
            return True
        elif isinstance(index, tuple) and len(index) == depth:
            # if the index is a tuple check it does not contain slices
            for i in index:
                if not np.isscalar(i):
                    return False
            return True
        return False

    def __getitem__(self, index):
        """Returns value[index], after recalculating if necessary.

        Overwrites the standard method of getting value, so that value is
        recalculated if tainted. Scalar slices are returned as an ndarray,
        so without depend machinery. If you need a "scalar depend" which
        behaves as a slice, just take a one-element slice, e.g.
        b = a[7, 1:2].

        Args:
            index: A slice variable giving the appropriate slice to be read.
        """

        if self._tainted[0]:
            self.update_auto()
            self.taint(taintme=False)

        if self.__scalarindex(index, self.ndim):
            return depstrip(self)[index]
        else:
            return depend_array(depstrip(self)[index], name=self._name, synchro=self._synchro,
                                func=self._func, dependants=self._dependants, tainted=self._tainted,
                                base=self._bval)

    def __getslice__(self, i, j):
        """Overwrites standard get function."""

        return self.__getitem__(slice(i, j, None))

    def get(self):
        """Alternative to standard get function.

        Returns self, after recalculating the value if it is tainted.
        """

        # __get__ ignores its (instance, owner) arguments, but both must be
        # supplied; the previous call passed only one and raised TypeError.
        return self.__get__(None, None)

    def __get__(self, instance, owner):
        """Overwrites standard get function."""

        # It is worth duplicating this code that is also used in __getitem__
        # as this is called most of the time, and we avoid creating a load of
        # copies pointing to the same depend_array
        if self._tainted[0]:
            self.update_auto()
            self.taint(taintme=False)

        return self

    def __setitem__(self, index, value, manual=True):
        """Alters value[index] and taints dependencies.

        Overwrites the standard method of setting value, so that dependent
        quantities are tainted, and so we check that computed quantities are
        not manually updated.

        Args:
            index: A slice variable giving the appropriate slice to be read.
            value: The new value of the slice.
            manual: Optional boolean giving whether the value has been changed
                manually. True by default.

        Raises:
            IndexError: If an automatic (manual=False) update does not span
                the whole array.
        """

        self.taint(taintme=False)
        if manual:
            depstrip(self)[index] = value
            self.update_man()
        elif index == slice(None, None, None):
            # Automatic updates write into the base storage, so that a slice
            # being recomputed updates its whole parent.
            self._bval[index] = value
        else:
            raise IndexError("Automatically computed arrays should span the whole parent")

    def __setslice__(self, i, j, value):
        """Overwrites standard set function."""

        return self.__setitem__(slice(i, j), value)

    def set(self, value, manual=True):
        """Alternative to standard set function.

        Args:
            See __setitem__().
        """

        self.__setitem__(slice(None, None), value=value, manual=manual)

    def __set__(self, instance, value):
        """Overwrites standard set function."""

        self.__setitem__(slice(None, None), value=value)
# np.dot and other numpy.linalg functions have the nasty habit to
# view cast to generate the output. Since we don't want to pass on
# dependencies to the result of these functions, and we can't use
# the ufunc mechanism to demote the class type to ndarray, we must
# overwrite np.dot and other similar functions.
# BEGINS NUMPY FUNCTIONS OVERRIDE
# ** np.dot
__dp_dot = np.dot


def dep_dot(da, db, out=None):
    """Replacement for np.dot that never propagates dependencies.

    Both operands are stripped with depstrip() before calling the original
    np.dot, so the result is computed on plain ndarrays and does not inherit
    the depend machinery of its inputs.

    Args:
        da: First operand (depend_array or any array-like).
        db: Second operand (depend_array or any array-like).
        out: Optional output array, forwarded unchanged to np.dot.

    Returns:
        The dot product, as returned by the original np.dot.
    """

    a = depstrip(da)
    b = depstrip(db)
    # Use the stripped operands: passing da/db here would let np.dot view
    # cast its result back into a depend_array.
    return __dp_dot(a, b, out)


np.dot = dep_dot
# ENDS NUMPY FUNCTIONS OVERRIDE
def dget(obj, member):
    """Fetch the raw object stored under ``member`` in ``obj``.

    Reads the instance dictionary directly, so the stored depend object itself
    is returned rather than the value its __get__() would produce through
    normal attribute access.

    Args:
        obj: A user defined class instance.
        member: A string giving the name of an attribute of obj.

    Raises:
        KeyError: If member is not stored in obj's instance dictionary.

    Returns:
        The object stored as obj.member.
    """

    instance_attrs = obj.__dict__
    return instance_attrs[member]
def dset(obj, member, value, name=None):
    """Takes an object and sets one of its attributes.

    Necessary for editing any depend object, and should be used for
    initialising them as well, as often initialization occurs more than once,
    with the second time effectively being an edit. The value is written
    directly into the instance dictionary, bypassing any __set__() of an
    existing depend object.

    Args:
        obj: A user defined class instance.
        member: A string giving the name of an attribute of obj. It does not
            need to exist already.
        value: The new value of member.
        name: Optional new name; if given, it is assigned to value._name.

    Raises:
        AttributeError: If name is given and value does not accept a _name
            attribute.
    """

    store = obj.__dict__
    store[member] = value
    if name is not None:
        store[member]._name = name
def depstrip(da):
    """Removes dependencies from a depend_array.

    Takes a depend_array and returns its value as a ndarray, effectively
    stripping the dependencies from the ndarray. This speeds up a lot of
    calculations involving these arrays. Must only be used if the value of
    the array is not going to be changed.

    Args:
        da: A depend_array, or any other object.

    Returns:
        If da is a depend_array, a plain ndarray view sharing its memory (no
        copy is made); otherwise da itself, unchanged.
    """

    # Only bother to strip dependencies if the array actually IS a
    # depend_array.
    # NOTE: the tainted flag is not checked here; it is assumed (not
    # verified) that callers have already refreshed the array.
    if isinstance(da, depend_array):
        return da.view(np.ndarray)
    return da
def deppipe(objfrom, memberfrom, objto, memberto):
    """Make one depend object mirror the value of another.

    The target depend object gets a function that simply reads the source,
    and is registered as a dependant of the source. As a result, it is
    tainted whenever the source changes and recomputes its value from it on
    the next read. Used for attributes such as temperature that are used in
    many different modules, and so need different depend objects in each,
    but which should all have the same value.

    Args:
        objfrom: An object containing memberfrom.
        memberfrom: The name of the base depend object.
        objto: An object containing memberto.
        memberto: The name of the depend object that should be equal to
            memberfrom.
    """

    source = dget(objfrom, memberfrom)
    target = dget(objto, memberto)

    def _follow_source():
        return source.get()

    target._func = _follow_source
    target.add_dependency(source)
def depcopy(objfrom, memberfrom, objto, memberto):
    """Make one depend object share the dependency state of another.

    The target is given the same dependants list, synchronizer, tainted flag
    and function as the source (by reference, not copied). If the source has
    a base storage (_bval), the target shares that too.

    Args:
        objfrom: An object containing memberfrom.
        memberfrom: The name of the depend object to copy from.
        objto: An object containing memberto.
        memberto: The name of the depend object to copy to.
    """

    source = dget(objfrom, memberfrom)
    target = dget(objto, memberto)

    target._dependants = source._dependants
    # add_synchro() first assigns target._synchro, then registers target in
    # the synchronizer if it is not already there.
    target.add_synchro(source._synchro)
    target._tainted = source._tainted
    target._func = source._func

    if hasattr(source, "_bval"):
        target._bval = source._bval
class dobject(object):
    """Base class enabling plain attribute syntax for depend objects.

    Subclasses can read and write depend objects stored as attributes with the
    usual ``obj.attr`` / ``obj.attr = value`` notation. Reads go through the
    stored object's own __get__(), and writes go through its __set__().
    """

    def __getattribute__(self, name):
        """Overrides the standard __getattribute__().

        Looks the attribute up normally. If the result defines __get__(),
        returns ``result.__get__(self, self.__class__)`` instead of the
        object itself.
        """

        found = object.__getattribute__(self, name)
        if not hasattr(found, '__get__'):
            return found
        return found.__get__(self, self.__class__)

    def __setattr__(self, name, value):
        """Overrides the standard __setattr__().

        If the attribute already exists and defines __set__(), the assignment
        is delegated to it. Otherwise the attribute is stored normally.
        """

        missing = object()
        try:
            current = object.__getattribute__(self, name)
        except AttributeError:
            current = missing
        if current is not missing and hasattr(current, '__set__'):
            return current.__set__(self, value)
        return object.__setattr__(self, name, value)