2 from functools
import singledispatch
5 from datetime
import datetime
13 from urllib
.parse
import urlparse
16 from itertools
import islice
class Dict(dict):
    """
    A dictionary that allows member access to its keys.

    Keys that are `range` or `tuple` objects act as aliases: looking up an
    item contained in such a key returns that key's value, and `keys()`
    expands them into their individual members.
    """

    def __init__(self, d):
        """
        Updates itself with d.
        """
        self.update(d)

    def __getattr__(self, item):
        # Attribute reads fall through to dictionary lookup.
        return self[item]

    def __setattr__(self, item, value):
        # Attribute writes go straight into the dictionary.
        self[item] = value

    def __getitem__(self, item):
        if item in self:
            return super().__getitem__(item)
        # NOTE(review): interior of the original was truncated; the alias
        # fallback below is reconstructed from the visible range/tuple logic.
        for key in self:
            if isinstance(key, (range, tuple)) and item in key:
                return super().__getitem__(key)
        raise KeyError(item)

    def keys(self):
        # Expand range/tuple alias keys into their individual members.
        if any(isinstance(key, (range, tuple)) for key in self):
            expanded = []
            for key in self:
                if isinstance(key, (range, tuple)):
                    expanded.extend(key)
                else:
                    expanded.append(key)
            return expanded
        return super().keys()
def save(obj, file):
    """
    Serialize obj to the given path using dill.

    NOTE(review): the `def` line was truncated in the reviewed source; the
    name and signature are inferred from the visible body — confirm against
    version history.
    """
    with open(file, 'wb') as output:
        dill.dump(obj, output)
def load(file):
    """
    Deserialize and return an object previously saved with dill.

    NOTE(review): the `def` line was truncated in the reviewed source; the
    name and signature are inferred from the visible body.
    """
    with open(file, 'rb') as handle:
        obj = dill.load(handle)
    return obj
# Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """
    Downloads a file from a specified URL to a destination path.

    Parameters:
    -----------
    url : str
        The URL from which to download the file.
    dest_path : str
        The destination path where the file should be saved.
    force_download : bool, optional
        If True, forces the download even if the file already exists at the destination path.

    Warns:
    ------
    Prints a warning if the file extension of the URL does not match the destination file extension.

    Raises:
    -------
    AssertionError
        If the download fails and the file does not exist at the destination path.

    Notes:
    ------
    This function uses the `wget` command-line tool to download the file. Ensure that `wget` is
    installed and accessible from the system's PATH.
    """
    if not osp.exists(dest_path) or force_download:
        target_extension = osp.splitext(dest_path)[1]
        url_extension = osp.splitext(urlparse(url).path)[1]
        if target_extension != url_extension:
            print("Warning: file extension from url does not match destination file extension")
        # Fix: the original built a shell command with an f-string and
        # shell=True, which allows shell injection through url/dest_path and
        # breaks on paths with spaces. The argument-list form is safe.
        subprocess.call(["wget", "-O", dest_path, url])
        # Kept as assert to preserve the documented AssertionError contract.
        assert osp.exists(dest_path)
        print(f"Successfully downloaded {url} to {dest_path}")
    else:
        print(f"Target data already exists at {dest_path}")
# Function to check if lists are nested, or all elements in given list are in target list
def all_items_exist(source_list, target_list):
    """
    Checks if all items from the source list exist in the target list.

    Parameters:
    -----------
    source_list : list
        The list of items to check for existence in the target list.
    target_list : list
        The list in which to check for the presence of items from the source list.

    Returns:
    --------
    bool
        True if all items in the source list are present in the target list, False otherwise.

    Examples:
    ---------
    >>> all_items_exist([1, 2, 3], [1, 2, 3, 4, 5])
    True
    >>> all_items_exist([1, 2, 6], [1, 2, 3, 4, 5])
    False
    """
    # Early-exit loop instead of the original all(...) generator.
    for candidate in source_list:
        if candidate not in target_list:
            return False
    return True
# Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    """
    Reads a YAML file and optionally retrieves a specific subkey.

    Parameters:
    -----------
    yaml_path : str
        The path to the YAML file to be read.
    subkey : str, optional
        A specific key within the YAML file to retrieve. If provided, only the value associated
        with this key will be returned. If not provided, the entire YAML file is returned as a
        dictionary. Default is None.

    Returns:
    --------
    The contents of the YAML file as a dictionary, or the value associated with the specified
    subkey.
    """
    with open(yaml_path, 'r') as file:
        contents = yaml.safe_load(file)
        if subkey is not None:
            contents = contents[subkey]
    return contents
# Use to load nested fmda dictionary of cases
def load_and_fix_data(filename):
    # Given path to FMDA training dictionary, read and return cleaned dictionary
    # Inputs:
    # filename: (str) path to file with .pickle extension
    # Returns:
    # FMDA dictionary with NA values "fixed"
    #
    # Fix: the original printed a literal "(unknown)" inside an f-string;
    # report the actual file being loaded.
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        # NOTE(review): pickle.load executes arbitrary code — only use on
        # trusted files.
        test_dict = pickle.load(handle)
        for case in test_dict:
            test_dict[case]['case'] = case
            test_dict[case]['filename'] = filename
            for key in test_dict[case].keys():
                var = test_dict[case][key]  # pointer to test_dict[case][key]
                # Only float ndarrays can hold NaNs worth fixing.
                if isinstance(var, np.ndarray) and (var.dtype.kind == 'f'):
                    nans = np.sum(np.isnan(var))
                    if nans:
                        print('WARNING: case', case, 'variable', key, 'shape', var.shape, 'has', nans, 'nan values, fixing')
                        # NOTE(review): the repair call was truncated in the
                        # reviewed source; `fixnan` (defined elsewhere in this
                        # module) is assumed to fix in place — TODO confirm.
                        fixnan(var)
                        nans = np.sum(np.isnan(test_dict[case][key]))
                        print('After fixing, remained', nans, 'nan values')
            if not 'title' in test_dict[case].keys():
                test_dict[case]['title'] = case
            if not 'descr' in test_dict[case].keys():
                test_dict[case]['descr'] = f"{case} FMDA dictionary"
    return test_dict
# Generic helper function to read pickle files
def read_pkl(file_path):
    """
    Reads a pickle file and returns its contents.

    Parameters:
    -----------
    file_path : str
        The path to the pickle file to be read.

    Returns:
    --------
    The object stored in the pickle file.

    Prints:
    -------
    A message indicating the file path being loaded.

    Notes:
    ------
    This function uses Python's `pickle` module to deserialize the contents of the file. Ensure
    that the pickle file was created in a safe and trusted environment to avoid security risks
    associated with loading arbitrary code.
    """
    with open(file_path, 'rb') as file:
        print(f"loading file {file_path}")
        contents = pickle.load(file)
    return contents
# Module-wide logging configuration.
# NOTE(review): the basicConfig call was truncated in the reviewed source;
# only the format= argument is visible — confirm level/datefmt against
# version history.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# dtype.kind codes treated as numeric:
# signed int, unsigned int, float, complex.
numeric_kinds = {'i', 'u', 'f', 'c'}
def is_numeric_ndarray(array):
    """
    Return True if `array` is a NumPy ndarray with a numeric dtype.

    Numeric dtypes are those whose `dtype.kind` is in the module-level
    `numeric_kinds` set ('i', 'u', 'f', 'c').
    """
    if isinstance(array, np.ndarray):
        return array.dtype.kind in numeric_kinds
    # Explicit False for non-arrays (the visible original fell off the end of
    # the if, implicitly returning None; False keeps the same truthiness).
    return False
def vprint(*args):
    """
    Print `args` only when the calling frame has a truthy local `verbose`.

    NOTE(review): the `def` line and parts of the body were truncated in the
    reviewed source; the name/signature are inferred from the visible frame
    inspection and `args` slicing — confirm against version history.
    """
    frame = inspect.currentframe()
    # Read `verbose` out of the caller's local variables, defaulting to False.
    if 'verbose' in frame.f_back.f_locals:
        verbose = frame.f_back.f_locals['verbose']
    else:
        verbose = False
    if verbose:
        # All but the last argument space-separated, then the last one with a
        # newline.
        for s in args[:(len(args) - 1)]:
            print(s, end=' ')
        print(args[-1])
## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    """
    Generates a unique hash string for a NumPy ndarray.

    Parameters:
    -----------
    arr : np.ndarray, or list of np.ndarray
        The NumPy array to be hashed.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the array.

    Notes:
    ------
    This function first converts the NumPy array to a bytes string using the `tobytes()` method,
    and then computes the MD5 hash of this bytes string. Performance might be bad for very large arrays.

    Examples:
    ---------
    >>> arr = np.array([1, 2, 3])
    >>> hash_value = hash_ndarray(arr)
    """
    # If input is list, attempt to concatenate and then hash
    # (idiom fix: was `type(arr) == list`).
    if isinstance(arr, list):
        # NOTE(review): the concatenation line was truncated in the reviewed
        # source; np.vstack is the assumed combiner — TODO confirm.
        arr = np.vstack(arr)
    # Convert the array to a bytes string
    arr_bytes = arr.tobytes()
    # Use hashlib to generate a unique hash
    hash_obj = hashlib.md5(arr_bytes)
    return hash_obj.hexdigest()
## Function for Hashing tensorflow models
def hash_weights(model):
    """
    Generates a unique hash string for the weights of a given Keras model.

    Parameters:
    -----------
    model : A keras model
        The Keras model to be hashed. Anything exposing `get_weights()`
        returning a list of ndarrays works.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the model weights.
    """
    # Extract all weights and biases
    weights = model.get_weights()
    # Convert each weight array to a string
    weight_str = ''.join([np.array2string(w, separator=',') for w in weights])
    # Generate an MD5 hash of the combined string.
    # (Comment fix: the original comment said SHA-256, but the code uses md5.)
    weight_hash = hashlib.md5(weight_str.encode('utf-8')).hexdigest()
    return weight_hash
## Generic function to hash dictionary of various types

@singledispatch
def hash2(x):
    ## Top level hash function with built-in hash function for str, float, int, etc
    return hash(x)

@hash2.register(np.ndarray)
def _(x):
    ## Hash numpy array, hash array with pandas and return integer sum
    # return hash(x.tobytes())
    return np.sum(pd.util.hash_array(x))

@hash2.register(list)
def _(x):
    ## Hash list, convert to tuple (lists are unhashable)
    return hash2(tuple(x))

@hash2.register(tuple)
def _(x):
    # Sum the element hashes via dispatch.
    # NOTE(review): body interior was truncated in the reviewed source;
    # accumulator reconstructed from the visible index loop — confirm.
    total = 0
    for i in range(len(x)):
        total += hash2(x[i])
    return total

@hash2.register(dict)
def _(x, keys=None, verbose=False):
    ## Hash dict, loop through keys and hash each element via dispatch. Return hashed integer sum of hashes
    total = 0  # return value integer
    if keys is None:  # allow user input of keys to hash, otherwise hash them all
        # NOTE(review): key collection was truncated in the reviewed source;
        # hashing all keys is assumed — confirm.
        keys = list(x.keys())
    for key in keys:
        if verbose:
            print('Hashing', key)
        total += hash2(x[key])
    return total
def print_args(func, *args, **kwargs):
    # wrapper to trace function call and arguments, then forward the call
    print(f"Called: {func.__name__}")
    # NOTE(review): the positional-argument reporting lines were truncated in
    # the reviewed source; reconstructed to mirror the visible kwargs loop.
    for value in args:
        print(f"  {value}")
    for key, value in kwargs.items():
        print(f"  {key}={value}")
    return func(*args, **kwargs)
def print_args_test():
    # Smoke test for print_args using a trivial local function.
    def my_function(a, b):
        # NOTE(review): body was truncated in the reviewed source — assumed to
        # simply combine its arguments; confirm against version history.
        return a + b
    print_args(my_function, a=1, b=2)
def get_item(dict, var, **kwargs):
    """
    Look up `var` in `dict`, falling back to kwargs['default'] when given.

    Logs the resolved value at INFO level; logs an error and raises when the
    variable is missing and no default was supplied.

    NOTE(review): parameter name `dict` shadows the builtin but is kept —
    parameter names are part of the public signature.
    """
    if var in dict:
        value = dict[var]
    elif 'default' in kwargs:
        value = kwargs['default']
    else:
        logging.error('Variable %s not in the dictionary and no default', var)
        # NOTE(review): the raise statement was truncated in the reviewed
        # source; KeyError is assumed — confirm the exception type callers
        # catch.
        raise KeyError(var)
    logging.info('%s = %s', var, value)
    return value
def print_first(item_list, num=3, indent=0, id=None):
    """
    Print the first num items of the list followed by '...'

    :param item_list: List of items to be printed
    :param num: number of items to list
    :param indent: number of leading spaces on each printed line
    :param id: optional label printed before the items
    """
    indent_str = ' ' * indent
    # NOTE(review): the guard around the label print was truncated in the
    # reviewed source; printing only when a label is supplied is assumed.
    if id is not None:
        print(indent_str, id)
    if len(item_list) > 0:
        print(indent_str, type(item_list[0]))
    for i in range(min(num, len(item_list))):
        print(indent_str, item_list[i])
    if len(item_list) > num:
        print(indent_str, '...')
def print_dict_summary(d, indent=0, first=None, first_num=3):
    """
    Prints a summary for each array in the dictionary, showing the key and the size of the array.

    Arguments:
     d (dict): The dictionary to summarize.
     indent (int): Number of leading spaces on each printed line.
     first (list): Print the first items for any arrays with these names.
     first_num (int): How many leading items to print for keys listed in `first`.
    """
    # Fix: `first=[]` was a shared mutable default argument; normalize a
    # fresh list from None instead (backward-compatible for all callers).
    if first is None:
        first = []
    indent_str = ' ' * indent
    for key, value in d.items():
        # Check if the value is list-like using a simple method check
        if isinstance(value, dict):
            print(f"{indent_str}{key}")
            # Recurse into nested dictionaries with a deeper indent.
            print_dict_summary(value, first=first, indent=indent + 5, first_num=first_num)
        elif isinstance(value, np.ndarray):
            if np.issubdtype(value.dtype, np.number):
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Handle non-numeric arrays differently
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):  # Check for iterable that is not a string
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str, key, ":", value)
        # NOTE(review): the condition guarding this call was truncated in the
        # reviewed source; membership in `first` is assumed — confirm.
        if key in first:
            print_first(value, num=first_num, indent=indent + 5)
from datetime import datetime

def str2time(input):
    """
    Convert a single string timestamp or a list of string timestamps to corresponding datetime object(s).

    NOTE(review): a second, incompatible `str2time` is defined later in this
    module and overrides this one at import time (that one returns naive
    datetimes) — the two should be consolidated.
    """
    if isinstance(input, str):
        # 'Z' suffix is rewritten to a numeric offset so %z yields a
        # timezone-aware datetime.
        return datetime.strptime(input.replace('Z', '+00:00'), '%Y-%m-%dT%H:%M:%S%z')
    if isinstance(input, list):
        return [str2time(s) for s in input]
    raise ValueError("Input must be a string or a list of strings")
# interpolate linearly over nans

def filter_nan_values(t1, v1):
    """
    Drop NaN entries from v1 along with the matching times in t1.

    Returns the pair (t1_filtered, v1_filtered) as NumPy arrays.
    """
    keep = ~np.isnan(v1)  # Indices where v1 is not NaN
    return np.array(t1)[keep], np.array(v1)[keep]
def time_intp(t1, v1, t2):
    """
    Linearly interpolate v1 (sampled at datetime array t1) onto the datetime
    array t2, ignoring NaN values in v1.

    Returns the interpolated NumPy array, or None when the inputs are not 1-D
    arrays of matching length (errors are logged rather than raised).
    """
    # Check if t1 v1 t2 are 1D arrays
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension %s:", v1.ndim)
        return None
    if t2.ndim != 1:
        # Fix: the original called `logging.errorr`, which raised
        # AttributeError at runtime instead of logging the problem.
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    # Check if t1 and v1 have the same length
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s", len(t1), len(v1))
        return None
    # NOTE(review): the early `return None` lines after each logged error were
    # truncated in the reviewed source and are reconstructed — confirm.
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    # Convert datetime objects to timestamps
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])
    # Interpolate using the filtered data
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')
    return v2_interpolated
def str2time(strlist):
    # Convert array of strings to array of (naive) datetime objects.
    # NOTE(review): this overrides the earlier `str2time` defined above in
    # this module; that one returns timezone-aware datetimes — consolidate.
    parsed = [datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ') for s in strlist]
    return np.array(parsed)
def check_increment(datetime_array, id=''):
    """
    Log and return the increment between consecutive datetimes.

    When all increments are equal, returns that increment in hours (logging an
    error if it is not positive); otherwise logs the min/max spread.
    """
    # Calculate time differences between consecutive datetime values
    diffs = [b - a for a, b in zip(datetime_array[:-1], datetime_array[1:])]
    diffs_hours = np.array([diff.total_seconds() / 3600 for diff in diffs])
    # Check if all time differences are exactly 1 hour
    if all(diffs_hours == diffs_hours[0]):
        logging.info('%s time array increments are %s hours', id, diffs_hours[0])
        if diffs_hours[0] <= 0:
            # Fix: log-message typo "increements" corrected.
            logging.error('%s time array increments are not positive', id)
        return diffs_hours[0]
    else:
        logging.info('%s time array increments are min %s max %s', id,
                     np.min(diffs_hours), np.max(diffs_hours))
        # NOTE(review): the tail of this function was truncated in the
        # reviewed source; -1 is the assumed "not uniform" sentinel — confirm
        # against version history before relying on it.
        return -1