1 # -*- coding: utf-8 -*-
2 # Copyright 2013 Google Inc. All Rights Reserved.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 """Utility classes for the parallelism framework."""
from __future__ import absolute_import

import multiprocessing
import threading
class BasicIncrementDict(object):
    """Dictionary meant for storing values for which increment is defined.

    This handles any values for which the "+" operation is defined (e.g.,
    floats, lists, etc.). This class is neither thread- nor process-safe.
    """

    def __init__(self):
        """Creates an empty increment dict backed by a plain dict."""
        self.dict = {}

    def Get(self, key, default_value=None):
        """Returns the value stored for key, or default_value if absent."""
        return self.dict.get(key, default_value)

    def Put(self, key, value):
        """Stores value under key, replacing any existing value."""
        self.dict[key] = value

    def Update(self, key, inc, default_value=0):
        """Update the stored value associated with the given key.

        Performs the equivalent of
        self.Put(key, self.Get(key, default_value) + inc).

        Args:
          key: lookup key for the value of the first operand of the "+"
              operation.
          inc: Second operand of the "+" operation.
          default_value: Default value if there is no existing value for the
              key.

        Returns:
          Incremented value.
        """
        val = self.dict.get(key, default_value) + inc
        # Store the result so that later reads observe the increment.
        self.dict[key] = val
        return val
class AtomicIncrementDict(BasicIncrementDict):
    """Dictionary meant for storing values for which increment is defined.

    This handles any values for which the "+" operation is defined (e.g.,
    floats, lists, etc.) in a thread- and process-safe way that allows for
    atomic get, put, and update.
    """

    def __init__(self, manager):  # pylint: disable=super-init-not-called
        """Initializes the dict with manager-backed storage and a lock.

        Args:
          manager: Object passed through to ThreadAndProcessSafeDict to
              create the underlying storage.
        """
        self.dict = ThreadAndProcessSafeDict(manager)
        self.lock = multiprocessing.Lock()

    def Update(self, key, inc, default_value=0):
        """Atomically update the stored value associated with the given key.

        Performs the atomic equivalent of
        self.Put(key, self.Get(key, default_value) + inc).

        Args:
          key: lookup key for the value of the first operand of the "+"
              operation.
          inc: Second operand of the "+" operation.
          default_value: Default value if there is no existing value for the
              key.

        Returns:
          Incremented value.
        """
        # The base-class Update is a read followed by a write; hold the lock
        # across both so that concurrent updates cannot interleave.
        with self.lock:
            return super(AtomicIncrementDict, self).Update(
                key, inc, default_value)
class ThreadSafeDict(object):
    """Provides a thread-safe dictionary (protected by a lock)."""

    def __init__(self):
        """Initializes the thread-safe dict."""
        self.lock = threading.Lock()
        self.dict = {}

    def __getitem__(self, key):
        """Returns the value for key; raises KeyError if key is absent."""
        with self.lock:
            return self.dict[key]

    def __setitem__(self, key, value):
        """Stores value under key."""
        with self.lock:
            self.dict[key] = value

    # pylint: disable=invalid-name
    def get(self, key, default_value=None):
        """Returns the value for key, or default_value if key is absent."""
        with self.lock:
            return self.dict.get(key, default_value)

    def delete(self, key):
        """Removes key from the dict; raises KeyError if key is absent."""
        with self.lock:
            del self.dict[key]
class ThreadAndProcessSafeDict(ThreadSafeDict):
    """Wraps a multiprocessing.Manager's proxy objects for thread-safety.

    The proxy objects returned by a manager are process-safe but not
    necessarily thread-safe, so this class simply wraps their access with a
    lock for ease of use. Since the objects are process-safe, we can use the
    more efficient threading lock set up by ThreadSafeDict.__init__.
    """

    def __init__(self, manager):
        """Initializes the thread and process safe dict.

        Args:
          manager: Multiprocessing.manager object.
        """
        super(ThreadAndProcessSafeDict, self).__init__()
        # Replace the storage created by the base class with the manager's
        # dict proxy.
        self.dict = manager.dict()