2 from functools
import singledispatch
5 from datetime
import datetime
13 from urllib
.parse
import urlparse
16 from itertools
import islice
20 A dictionary that allows member access to its keys.
24 def __init__(self
, d
):
26 Updates itself with d.
30 def __getattr__(self
, item
):
33 def __setattr__(self
, item
, value
):
36 def __getitem__(self
, item
):
38 return super().__getitem
__(item
)
41 if isinstance(key
,(range,tuple)) and item
in key
:
42 return super().__getitem
__(key
)
46 if any([isinstance(key
,(range,tuple)) for key
in self
]):
49 if isinstance(key
,(range,tuple)):
59 with
open(file,'wb') as output
:
60 dill
.dump(obj
, output
)
63 with
open(file,'rb') as input:
64 returnitem
= dill
.load(input)
67 # Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """
    Downloads a file from a specified URL to a destination path.

    Parameters
    ----------
    url : str
        The URL from which to download the file.
    dest_path : str
        The destination path where the file should be saved.
    force_download : bool, optional
        If True, forces the download even if the file already exists at the destination path.
        Default is False.

    Warnings
    --------
    Prints a warning if the file extension of the URL does not match the destination file extension.

    Raises
    ------
    AssertionError
        If `wget` cannot be started, exits with a non-zero status, or the file does not
        exist at the destination path after the download attempt.

    Notes
    -----
    This function uses the `wget` command-line tool to download the file. Ensure that `wget` is
    installed and accessible from the system's PATH. The command is passed to `subprocess` as an
    argument list (no shell), so spaces or shell metacharacters in `url` or `dest_path` are not
    interpreted.

    Prints
    ------
    A message indicating whether the file was downloaded or if it already exists at the
    destination path.
    """
    # Nothing to do when the target is already present and no refresh was requested.
    if osp.exists(dest_path) and not force_download:
        print(f"Target data already exists at {dest_path}")
        return

    print(f"Attempting to downloaded {url} to {dest_path}")
    target_extension = osp.splitext(dest_path)[1]
    # Only the path component of the URL is used, so query strings do not affect the extension.
    url_extension = osp.splitext(urlparse(url).path)[1]
    if target_extension != url_extension:
        print("Warning: file extension from url does not match destination file extension")

    try:
        status = subprocess.call(["wget", "-O", dest_path, url])
    except OSError as err:
        # e.g. wget is not installed; keep the exception type callers already expect.
        raise AssertionError(f"Could not run wget to download {url}") from err

    # Explicit raise instead of `assert`, which is removed when Python runs with -O.
    # The exit status is checked because `wget -O` can leave an empty file behind on failure.
    if status != 0 or not osp.exists(dest_path):
        raise AssertionError(
            f"Download of {url} to {dest_path} failed (wget exit status {status})"
        )
    print(f"Successfully downloaded {url} to {dest_path}")
113 # Function to check if lists are nested, or all elements in given list are in target list
def all_items_exist(source_list, target_list):
    """
    Checks if all items from the source list exist in the target list.

    Parameters
    ----------
    source_list : list
        The list of items to check for existence in the target list.
    target_list : list
        The list in which to check for the presence of items from the source list.

    Returns
    -------
    bool
        True if all items in the source list are present in the target list, False otherwise.
        An empty source list yields True.

    Examples
    --------
    >>> source_list = [1, 2, 3]
    >>> target_list = [1, 2, 3, 4, 5]
    >>> all_items_exist(source_list, target_list)
    True

    >>> source_list = [1, 2, 6]
    >>> all_items_exist(source_list, target_list)
    False
    """
    # Stop at the first item that is missing from the target.
    for candidate in source_list:
        if candidate not in target_list:
            return False
    return True
143 # Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    """
    Reads a YAML file and optionally retrieves a specific subkey.

    Parameters
    ----------
    yaml_path : str
        The path to the YAML file to be read.
    subkey : str, optional
        A specific key within the YAML file to retrieve. If provided, only the value associated
        with this key will be returned. If not provided, the entire YAML file is returned as a
        dictionary. Default is None.

    Returns
    -------
    dict or any
        The contents of the YAML file as a dictionary, or the value associated with the specified
        subkey.
    """
    with open(yaml_path, 'r') as stream:
        contents = yaml.safe_load(stream)
    # Narrow to the requested entry only when a subkey was given.
    return contents if subkey is None else contents[subkey]
170 # Use to load nested fmda dictionary of cases
171 def load_and_fix_data(filename
):
172 # Given path to FMDA training dictionary, read and return cleaned dictionary
174 # filename: (str) path to file with .pickle extension
176 # FMDA dictionary with NA values "fixed"
177 print(f
"loading file {filename}")
178 with
open(filename
, 'rb') as handle
:
179 test_dict
= pickle
.load(handle
)
180 for case
in test_dict
:
181 test_dict
[case
]['case'] = case
182 test_dict
[case
]['filename'] = filename
183 for key
in test_dict
[case
].keys():
184 var
= test_dict
[case
][key
] # pointer to test_dict[case][key]
185 if isinstance(var
,np
.ndarray
) and (var
.dtype
.kind
== 'f'):
186 nans
= np
.sum(np
.isnan(var
))
188 print('WARNING: case',case
,'variable',key
,'shape',var
.shape
,'has',nans
,'nan values, fixing')
190 nans
= np
.sum(np
.isnan(test_dict
[case
][key
]))
191 print('After fixing, remained',nans
,'nan values')
192 if not 'title' in test_dict
[case
].keys():
193 test_dict
[case
]['title']=case
194 if not 'descr' in test_dict
[case
].keys():
195 test_dict
[case
]['descr']=f
"{case} FMDA dictionary"
198 # Generic helper function to read pickle files
def read_pkl(file_path):
    """
    Reads a pickle file and returns its contents.

    Parameters
    ----------
    file_path : str
        The path to the pickle file to be read.

    Returns
    -------
    any
        The object stored in the pickle file.

    Prints
    ------
    A message indicating the file path being loaded.

    Notes
    -----
    This function uses Python's `pickle` module to deserialize the contents of the file. Ensure
    that the pickle file was created in a safe and trusted environment to avoid security risks
    associated with loading arbitrary code.
    """
    with open(file_path, 'rb') as handle:
        # The message is printed only after the file has been opened successfully.
        print(f"loading file {file_path}")
        return pickle.load(handle)
232 format
='%(asctime)s - %(levelname)s - %(message)s',
# NumPy dtype.kind codes treated as numeric: signed int, unsigned int, float, complex.
numeric_kinds = {'i', 'u', 'f', 'c'}

def is_numeric_ndarray(array):
    """
    Tell whether `array` is a NumPy ndarray with a numeric dtype.

    Returns True only for np.ndarray instances whose dtype.kind is in
    `numeric_kinds`; any other object (including lists) gives False.
    """
    if not isinstance(array, np.ndarray):
        return False
    return array.dtype.kind in numeric_kinds
247 frame
= inspect
.currentframe()
248 if 'verbose' in frame
.f_back
.f_locals
:
249 verbose
= frame
.f_back
.f_locals
['verbose']
254 for s
in args
[:(len(args
)-1)]:
259 ## Function for Hashing numpy arrays
260 def hash_ndarray(arr
: np
.ndarray
) -> str:
262 Generates a unique hash string for a NumPy ndarray.
267 The NumPy array to be hashed.
272 A hexadecimal string representing the MD5 hash of the array.
276 This function first converts the NumPy array to a bytes string using the `tobytes()` method,
277 and then computes the MD5 hash of this bytes string. Performance might be bad for very large arrays.
281 >>> arr = np.array([1, 2, 3])
282 >>> hash_value = hash_ndarray(arr)
283 >>> print(hash_value)
284 '2a1dd1e1e59d0a384c26951e316cd7e6'
286 # If input is list, attempt to concatenate and then hash
287 if type(arr
) == list:
289 arr_bytes
= arr
.tobytes()
291 # Convert the array to a bytes string
292 arr_bytes
= arr
.tobytes()
293 # Use hashlib to generate a unique hash
294 hash_obj
= hashlib
.md5(arr_bytes
)
295 return hash_obj
.hexdigest()
297 ## Function for Hashing tensorflow models
def hash_weights(model):
    """
    Generates a hash string for the weights of a given Keras model.

    Parameters
    ----------
    model : A keras model
        The Keras model to be hashed. Only its `get_weights()` method is used.

    Returns
    -------
    str
        A hexadecimal string representing the MD5 hash of the model weights.

    Notes
    -----
    The hash is computed over the text rendering produced by `np.array2string`, not the
    raw bytes. NOTE(review): `np.array2string` abbreviates large arrays according to
    NumPy's print options, so large weight arrays that differ only in the abbreviated
    part may produce the same hash -- confirm this is acceptable.
    """
    # Render every weight/bias array as text and concatenate the pieces.
    rendered = [np.array2string(w, separator=',') for w in model.get_weights()]
    weight_str = ''.join(rendered)

    # MD5 digest of the UTF-8 encoded combined string.
    digest = hashlib.md5(weight_str.encode('utf-8'))
    return digest.hexdigest()
324 ## Generic function to hash dictionary of various types
327 ## Top level hash function with built-in hash function for str, float, int, etc
331 @hash2.register(np
.ndarray
)
332 ## Hash numpy array, hash array with pandas and return integer sum
334 # return hash(x.tobytes())
335 return np
.sum(pd
.util
.hash_array(x
))
337 @hash2.register(list)
338 ## Hash list, convert to tuple
340 return hash2(tuple(x
))
342 @hash2.register(tuple)
345 for i
in range(len(x
)):
349 @hash2.register(dict)
350 ## Hash dict, loop through keys and hash each element via dispatch. Return hashed integer sum of hashes
351 def _(x
, keys
= None, verbose
= False):
352 r
= 0 # return value integer
353 if keys
is None: # allow user input of keys to hash, otherwise hash them all
357 if (verbose
): print('Hashing', key
)
361 def print_args(func
, *args
, **kwargs
):
362 # wrapper to trace function call and arguments
363 print(f
"Called: {func.__name__}")
367 for key
, value
in kwargs
.items():
368 print(f
" {key}={value}")
369 return func(*args
, **kwargs
)
371 def print_args_test():
372 def my_function(a
, b
):
375 print_args(my_function
, a
=1, b
=2)
378 def get_item(dict,var
,**kwargs
):
381 elif 'default' in kwargs
:
382 value
= kwargs
['default']
384 logging
.error('Variable %s not in the dictionary and no default',var
)
386 logging
.info('%s = %s',var
,value
)
def print_first(item_list,num=3,indent=0,id=None):
    """
    Print the first num items of the list followed by '...'

    :param item_list: List of items to be printed
    :param num: number of items to list
    :param indent: number of spaces placed in front of every printed line
    :param id: optional label printed first; omitted when None
    """
    pad = ' ' * indent
    if id is not None:
        print(pad, id)

    total = len(item_list)
    if total > 0:
        # Type of the first element is shown once, as a header.
        print(pad, type(item_list[0]))

    shown = min(num, total)
    for position in range(shown):
        print(pad, item_list[position])

    # Ellipsis marks that the list has more entries than were shown.
    if total > num:
        print(pad, '...')
def print_dict_summary(d,indent=0,first=None,first_num=3):
    """
    Prints a summary for each array in the dictionary, showing the key and the size of the array.

    Arguments:
     d (dict): The dictionary to summarize.
     indent (int): Number of spaces printed in front of each line; nested
         dictionaries are indented by 5 more.
     first (list): Print the first items for any arrays with these names.
         None (the default) means no names.
     first_num (int): How many leading items to print for the names in `first`.

    """
    # None replaces the former shared mutable default `[]`.
    if first is None:
        first = []
    indent_str = ' ' * indent
    for key, value in d.items():
        if isinstance(value, dict):
            # Nested dictionary: print its key, then recurse one level deeper.
            print(f"{indent_str}{key}")
            print_dict_summary(value,first=first,indent=indent+5,first_num=first_num)
        elif isinstance(value,np.ndarray):
            # min()/max() raise on zero-size arrays, so empty arrays use the
            # shape/dtype summary instead.
            if value.size > 0 and np.issubdtype(value.dtype, np.number):
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Handle non-numeric (and empty) arrays differently
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):  # Check for iterable that is not a string
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str,key,":",value)
        if key in first:
            print_first(value,num=first_num,indent=indent+5)
435 from datetime
import datetime
439 Convert a single string timestamp or a list of string timestamps to corresponding datetime object(s).
441 if isinstance(input, str):
442 return datetime
.strptime(input.replace('Z', '+00:00'), '%Y-%m-%dT%H:%M:%S%z')
443 elif isinstance(input, list):
444 return [str2time(s
) for s
in input]
446 raise ValueError("Input must be a string or a list of strings")
449 # interpolate linearly over nans
def filter_nan_values(t1, v1):
    """
    Drop the entries where v1 is NaN, together with the matching entries of t1.

    Both inputs are converted with np.array before masking. Returns a tuple
    (t1_filtered, v1_filtered) of arrays containing only the positions where
    v1 is not NaN.
    """
    keep = np.logical_not(np.isnan(v1))  # True where v1 holds a real value
    times = np.array(t1)
    values = np.array(v1)
    return times[keep], values[keep]
def time_intp(t1, v1, t2):
    """
    Linearly interpolate values v1 given at times t1 onto times t2, skipping NaNs in v1.

    :param t1: 1D array of datetime objects (elements must provide .timestamp())
    :param v1: 1D array of values at t1, same length as t1; NaN entries are ignored
    :param t2: 1D array of datetime objects at which to interpolate
    :return: array of interpolated values at t2, or None when an input check fails
             (the failure is reported through logging.error)
    """
    # Check if t1 v1 t2 are 1D arrays
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension %s:", v1.ndim)
        return None
    if t2.ndim != 1:
        # Was `logging.errorr`, which raised AttributeError instead of reporting the problem.
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    # Check if t1 and v1 have the same length
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s",len(t1),len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    # Convert datetime objects to timestamps
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])

    # Interpolate using the filtered data
    # NOTE(review): np.interp expects increasing sample points; t1 is not sorted here -- confirm callers pass ordered times.
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')

    return v2_interpolated
def str2time(strlist):
    """
    Convert an iterable of timestamp strings to a NumPy array of datetime objects.

    Each string must match the format '%Y-%m-%dT%H:%M:%SZ' (the trailing 'Z' is
    matched literally, so the resulting datetimes carry no timezone info).
    """
    parsed = []
    for dt_str in strlist:
        parsed.append(datetime.strptime(dt_str, '%Y-%m-%dT%H:%M:%SZ'))
    return np.array(parsed)
490 def check_increment(datetime_array
,id=''):
491 # Calculate time differences between consecutive datetime values
492 diffs
= [b
- a
for a
, b
in zip(datetime_array
[:-1], datetime_array
[1:])]
493 diffs_hours
= np
.array([diff
.total_seconds()/3600 for diff
in diffs
])
494 # Check if all time differences are exactlyu 1 hour
495 if all(diffs_hours
== diffs_hours
[0]):
496 logging
.info('%s time array increments are %s hours',id,diffs_hours
[0])
497 if diffs_hours
[0] <= 0 :
498 logging
.error('%s time array increements are not positive',id)
499 return diffs_hours
[0]
501 logging
.info('%s time array increments are min %s max %s',id,
502 np
.min(diffs_hours
),np
.max(diffs_hours
))