2 from functools
import singledispatch
5 from datetime
import datetime
13 from urllib
.parse
import urlparse
16 from itertools
import islice
def rmse_3d(preds, y_test):
    """
    Calculate RMSE for ndarrays structured as (batch_size, timesteps, features).
    The first dimension, batch_size, could denote distinct locations. The second, timesteps, is length of sequence

    Parameters
    ----------
    preds : np.ndarray
        Predicted values, shape (batch_size, timesteps, features).
    y_test : np.ndarray
        Observed values, same shape as preds (broadcastable shapes also work).

    Returns
    -------
    np.ndarray
        One RMSE value per batch element, shape (batch_size,).
    """
    squared_diff = np.square(preds - y_test)
    # Mean squared error along the timesteps and dimensions (axis 1 and 2)
    mse = np.mean(squared_diff, axis=(1, 2))
    # Root mean squared error (RMSE) for each timeseries
    # NOTE(review): the final sqrt/return was truncated in the reviewed view; np.sqrt(mse)
    # is the standard definition of RMSE given the computation above.
    return np.sqrt(mse)
class Dict(dict):
    """
    A dictionary that allows member access to its keys.

    Keys may also be `range` or `tuple` "group" keys: item lookup falls back to any
    group key that contains the requested item, and keys() expands group keys into
    their members.  NOTE(review): reconstructed from a mangled view — confirm the
    group-key fallback semantics against callers.
    """

    def __init__(self, d):
        """
        Updates itself with d.
        """
        self.update(d)

    def __getattr__(self, item):
        # Attribute read delegates to item lookup.
        # NOTE(review): a missing attribute raises KeyError (not AttributeError) by design here.
        return self[item]

    def __setattr__(self, item, value):
        # Attribute assignment stores a dictionary entry instead of an instance attribute
        self[item] = value

    def __getitem__(self, item):
        if item in self:
            return super().__getitem__(item)
        # Fall back to any range/tuple group key containing the item
        for key in self:
            if isinstance(key, (range, tuple)) and item in key:
                return super().__getitem__(key)
        raise KeyError(item)

    def keys(self):
        # If any group keys are present, expand them into their individual members
        if any(isinstance(key, (range, tuple)) for key in self):
            expanded = []
            for key in self:
                if isinstance(key, (range, tuple)):
                    expanded.extend(key)
                else:
                    expanded.append(key)
            return expanded
        return super().keys()
def save(obj, file):
    """
    Serialize obj to the given path using dill.

    NOTE(review): the def header was truncated in the reviewed view — the name/signature
    `save(obj, file)` is reconstructed from the body; confirm against callers.
    """
    with open(file, 'wb') as output:
        dill.dump(obj, output)
def load(file):
    """
    Deserialize and return the object stored at the given path using dill.

    NOTE(review): header/return truncated in the reviewed view — signature reconstructed.
    """
    # Local renamed from `input` (shadowed the builtin)
    with open(file, 'rb') as fin:
        returnitem = dill.load(fin)
    return returnitem
# Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """
    Downloads a file from a specified URL to a destination path.

    Parameters:
    -----------
    url : str
        The URL from which to download the file.
    dest_path : str
        The destination path where the file should be saved.
    force_download : bool, optional
        If True, forces the download even if the file already exists at the destination path.

    Warnings:
    ---------
    Prints a warning if the file extension of the URL does not match the destination file extension.

    Raises:
    -------
    AssertionError
        If the download fails and the file does not exist at the destination path.

    Notes:
    ------
    This function uses the `wget` command-line tool to download the file. Ensure that `wget` is
    installed and accessible from the system's PATH.

    Prints:
    -------
    A message indicating whether the file was downloaded or if it already exists at the
    destination path.
    """
    if not osp.exists(dest_path) or force_download:
        print(f"Attempting to download {url} to {dest_path}")
        target_extension = osp.splitext(dest_path)[1]
        url_extension = osp.splitext(urlparse(url).path)[1]
        if target_extension != url_extension:
            print("Warning: file extension from url does not match destination file extension")
        # Pass argv as a list with the default shell=False: the previous f-string +
        # shell=True form allowed shell injection through url/dest_path.
        subprocess.call(["wget", "-O", dest_path, url])
        assert osp.exists(dest_path)
        print(f"Successfully downloaded {url} to {dest_path}")
    else:
        print(f"Target data already exists at {dest_path}")
# Function to check if lists are nested, or all elements in given list are in target list
def all_items_exist(source_list, target_list):
    """
    Checks if all items from the source list exist in the target list.

    Parameters:
    -----------
    source_list : list
        The list of items to check for existence in the target list.
    target_list : list
        The list in which to check for the presence of items from the source list.

    Returns:
    --------
    bool
        True if all items in the source list are present in the target list, False otherwise.

    Examples:
    ---------
    >>> source_list = [1, 2, 3]
    >>> target_list = [1, 2, 3, 4, 5]
    >>> all_items_exist(source_list, target_list)
    True
    >>> source_list = [1, 2, 6]
    >>> all_items_exist(source_list, target_list)
    False
    """
    # Short-circuit on the first element not found in the target
    for element in source_list:
        if element not in target_list:
            return False
    return True
# Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    """
    Reads a YAML file and optionally retrieves a specific subkey.

    Parameters:
    -----------
    yaml_path : str
        The path to the YAML file to be read.
    subkey : str, optional
        A specific key within the YAML file to retrieve. If provided, only the value associated
        with this key will be returned. If not provided, the entire YAML file is returned as a
        dictionary. Default is None.

    Returns:
    --------
    dict or any
        The contents of the YAML file as a dictionary, or the value associated with the specified
        subkey.
    """
    with open(yaml_path, 'r') as file:
        d = yaml.safe_load(file)
    # NOTE(review): the subkey branch/return were truncated in the reviewed view;
    # reconstructed as a plain indexed lookup — confirm against callers.
    if subkey is not None:
        d = d[subkey]
    return d
# Use to load nested fmda dictionary of cases
def load_and_fix_data(filename):
    # Given path to FMDA training dictionary, read and return cleaned dictionary
    # Inputs:
    # filename: (str) path to file with .pickle extension
    # Returns:
    # FMDA dictionary with NA values "fixed"
    # BUG FIX: the message was an f-string with no placeholder; report the actual path.
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        # NOTE(review): pickle.load on untrusted files can execute arbitrary code —
        # only load dictionaries produced in a trusted environment.
        test_dict = pickle.load(handle)
        for case in test_dict:
            test_dict[case]['case'] = case
            test_dict[case]['filename'] = filename
            for key in test_dict[case].keys():
                var = test_dict[case][key]  # pointer to test_dict[case][key]
                # Only float ndarrays can carry NaNs worth fixing
                if isinstance(var, np.ndarray) and (var.dtype.kind == 'f'):
                    nans = np.sum(np.isnan(var))
                    if nans:
                        print('WARNING: case', case, 'variable', key, 'shape', var.shape, 'has', nans, 'nan values, fixing')
                        # NOTE(review): the fix step was truncated in the reviewed view;
                        # `var` is a reference into test_dict, so an in-place fix via the
                        # module-level fixnan helper is assumed — confirm.
                        fixnan(var)
                        nans = np.sum(np.isnan(test_dict[case][key]))
                        print('After fixing, remained', nans, 'nan values')
            if not 'title' in test_dict[case].keys():
                test_dict[case]['title'] = case
            if not 'descr' in test_dict[case].keys():
                test_dict[case]['descr'] = f"{case} FMDA dictionary"
    return test_dict
# Generic helper function to read pickle files
def read_pkl(file_path):
    """
    Reads a pickle file and returns its contents.

    Parameters:
    -----------
    file_path : str
        The path to the pickle file to be read.

    Returns:
    --------
    object
        The object stored in the pickle file.

    Prints:
    -------
    A message indicating the file path being loaded.

    Notes:
    ------
    This function uses Python's `pickle` module to deserialize the contents of the file. Ensure
    that the pickle file was created in a safe and trusted environment to avoid security risks
    associated with loading arbitrary code.
    """
    with open(file_path, 'rb') as file:
        print(f"loading file {file_path}")
        d = pickle.load(file)
    # NOTE(review): the return was truncated in the reviewed view; returning the
    # deserialized object is the only consistent completion.
    return d
# Configure module-wide logging: timestamped, leveled messages.
# NOTE(review): only the `format` argument was visible in the reviewed view;
# level=logging.INFO is reconstructed — confirm against the original file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)

# numpy dtype.kind codes treated as numeric: signed int, unsigned int, float, complex
numeric_kinds = {'i', 'u', 'f', 'c'}
def is_numeric_ndarray(array):
    """
    Return True iff `array` is a NumPy ndarray with a numeric dtype.

    Numeric dtype kinds are the module-level `numeric_kinds` set
    ('i', 'u', 'f', 'c').  Non-arrays return False explicitly
    (previously the non-array path fell through to an implicit None).
    """
    if isinstance(array, np.ndarray):
        return array.dtype.kind in numeric_kinds
    return False
def vprint(*args):
    """
    Print args (space-separated) only when the CALLER has a truthy local `verbose`.

    NOTE(review): the def header was truncated in the reviewed view — the name
    `vprint(*args)` is reconstructed from the body; confirm against callers.
    Inspects the caller's frame; defaults to quiet when no `verbose` local exists.
    """
    frame = inspect.currentframe()
    if 'verbose' in frame.f_back.f_locals:
        verbose = frame.f_back.f_locals['verbose']
    else:
        verbose = False
    # Guard on args so vprint() with no arguments cannot raise IndexError
    if verbose and args:
        for s in args[:(len(args) - 1)]:
            print(s, end=' ')
        print(args[-1])
## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    """
    Generates a unique hash string for a NumPy ndarray.

    Parameters:
    -----------
    arr : np.ndarray
        The NumPy array to be hashed.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the array.

    Notes:
    ------
    This function first converts the NumPy array to a bytes string using the `tobytes()` method,
    and then computes the MD5 hash of this bytes string. Performance might be bad for very large arrays.

    Examples:
    --------
    >>> arr = np.array([1, 2, 3])
    >>> hash_value = hash_ndarray(arr)
    >>> print(hash_value)
    '2a1dd1e1e59d0a384c26951e316cd7e6'
    """
    # If input is list, attempt to concatenate and then hash
    # (isinstance instead of type()==list; the duplicated tobytes() branches are merged)
    if isinstance(arr, list):
        arr = np.vstack(arr)
    # Convert the array to a bytes string
    arr_bytes = arr.tobytes()
    # Use hashlib to generate a unique hash
    hash_obj = hashlib.md5(arr_bytes)
    return hash_obj.hexdigest()
## Function for Hashing tensorflow models
def hash_weights(model):
    """
    Generates a unique hash string for a the weights of a given Keras model.

    Parameters:
    -----------
    model : A keras model
        The Keras model to be hashed.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the model weights.
    """
    # Extract all weights and biases
    weights = model.get_weights()

    # Convert each weight array to a string
    weight_str = ''.join([np.array2string(w, separator=',') for w in weights])

    # Generate a hash of the combined string
    # NOTE(review): the comment in the original said SHA-256 but the code uses MD5 — kept MD5.
    weight_hash = hashlib.md5(weight_str.encode('utf-8')).hexdigest()
    # The final return was truncated in the reviewed view; returning the digest
    # is the only consistent completion.
    return weight_hash
## Generic function to hash dictionary of various types
## Top level hash function with built-in hash function for str, float, int, etc
# NOTE(review): the base @singledispatch definition was truncated in the reviewed
# view; `return hash(x)` is reconstructed from the comment — confirm.
@singledispatch
def hash2(x):
    """Dispatch hub: hash x with the builtin hash; overloads below handle containers/arrays."""
    return hash(x)

@hash2.register(np.ndarray)
## Hash numpy array, hash array with pandas and return integer sum
def _(x):
    # return hash(x.tobytes())
    return np.sum(pd.util.hash_array(x))

@hash2.register(list)
## Hash list, convert to tuple
def _(x):
    return hash2(tuple(x))

@hash2.register(tuple)
def _(x):
    # Sum of element hashes (order-insensitive); loop body reconstructed from fragments
    r = 0
    for i in range(len(x)):
        r += hash2(x[i])
    return r

@hash2.register(dict)
## Hash dict, loop through keys and hash each element via dispatch. Return hashed integer sum of hashes
def _(x, keys = None, verbose = False):
    r = 0  # return value integer
    if keys is None:  # allow user input of keys to hash, otherwise hash them all
        keys = x.keys()
    for key in keys:
        if (verbose): print('Hashing', key)
        r += hash2(x[key])
    return r
def print_args(func, *args, **kwargs):
    # wrapper to trace function call and arguments
    print(f"Called: {func.__name__}")
    # NOTE(review): the positional-argument printing lines were truncated in the
    # reviewed view and are reconstructed symmetrically to the kwargs loop below.
    print("Arguments:")
    for arg in args:
        print(f"  {arg}")
    for key, value in kwargs.items():
        print(f"  {key}={value}")
    # Call through and propagate the wrapped function's return value
    return func(*args, **kwargs)
def print_args_test():
    # Smoke test for print_args: trace a trivial function call with keyword args.
    def my_function(a, b):
        # body intentionally trivial; only the tracing output matters here
        pass
    print_args(my_function, a=1, b=2)
def get_item(dict, var, **kwargs):
    """
    Return dict[var], falling back to kwargs['default'] when the key is absent.

    Logs the retrieved value at INFO level; logs an error and raises KeyError when
    the key is absent and no default was given.
    NOTE(review): the key-hit and error branches were truncated in the reviewed
    view and are reconstructed — confirm the raise against callers.
    (The parameter name `dict` shadows the builtin but is kept for interface
    compatibility with keyword callers.)
    """
    if var in dict:
        value = dict[var]
    elif 'default' in kwargs:
        value = kwargs['default']
    else:
        logging.error('Variable %s not in the dictionary and no default', var)
        raise KeyError(var)
    logging.info('%s = %s', var, value)
    return value
def print_first(item_list, num=3, indent=0, id=None):
    """
    Print the first num items of the list followed by '...'

    :param item_list: List of items to be printed
    :param num: number of items to list
    :param indent: number of spaces of indentation
    :param id: optional label printed before the items
    """
    pad = ' ' * indent
    if id is not None:
        print(pad, id)
    if len(item_list) > 0:
        # Show the type of the elements once, before listing them
        print(pad, type(item_list[0]))
    for entry in item_list[:min(num, len(item_list))]:
        print(pad, entry)
    if len(item_list) > num:
        print(pad, '...')
def print_dict_summary(d, indent=0, first=None, first_num=3):
    """
    Prints a summary for each array in the dictionary, showing the key and the size of the array.

    Arguments:
     d (dict): The dictionary to summarize.
     indent (int): Number of spaces of indentation for this level.
     first (list): Print the first items for any arrays with these names
     first_num (int): How many leading items to print for keys listed in `first`.
    """
    # None-sentinel replaces the mutable default `first=[]` (backward compatible)
    if first is None:
        first = []
    indent_str = ' ' * indent
    for key, value in d.items():
        # Check if the value is list-like using a simple method check
        if isinstance(value, dict):
            print(f"{indent_str}{key}")
            # Recurse into nested dictionaries with extra indentation
            print_dict_summary(value, first=first, indent=indent + 5, first_num=first_num)
        elif isinstance(value, np.ndarray):
            if np.issubdtype(value.dtype, np.number):
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Handle non-numeric arrays differently
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):  # Check for iterable that is not a string
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str, key, ":", value)
        # NOTE(review): the guard around print_first was truncated in the reviewed
        # view; `key in first` is reconstructed from the docstring — confirm.
        if key in first:
            print_first(value, num=first_num, indent=indent + 5)
451 from datetime
import datetime
def str2time(input):
    """
    Convert a single string timestamp or a list of string timestamps to corresponding datetime object(s).

    Accepts ISO-8601 strings with a trailing 'Z' (converted to '+00:00' so the
    parsed datetimes are timezone-aware). Lists are converted element-wise.

    NOTE(review): the def header was truncated in the reviewed view — name/signature
    reconstructed. A second, different `str2time` is defined later in this file and
    shadows this one at import time; consider renaming one of them.
    (The parameter name `input` shadows the builtin but is kept for compatibility.)
    """
    if isinstance(input, str):
        return datetime.strptime(input.replace('Z', '+00:00'), '%Y-%m-%dT%H:%M:%S%z')
    elif isinstance(input, list):
        return [str2time(s) for s in input]
    else:
        raise ValueError("Input must be a string or a list of strings")
465 # interpolate linearly over nans
def filter_nan_values(t1, v1):
    """Drop NaN entries of v1 along with the matching times in t1; returns the filtered pair."""
    keep_mask = ~np.isnan(v1)  # True where v1 holds a real value
    times = np.array(t1)[keep_mask]
    values = np.array(v1)[keep_mask]
    return times, values
def time_intp(t1, v1, t2):
    """
    Linearly interpolate values v1 (observed at datetimes t1) onto the datetime grid t2.

    NaN entries of v1 are dropped (with their times) before interpolating.
    Returns the interpolated 1D array, or None on invalid input (logged as errors).
    NOTE(review): the `return None` guard lines were truncated in the reviewed view
    and are reconstructed — confirm against callers.
    """
    # Check if t1 v1 t2 are 1D arrays
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension %s:", v1.ndim)
        return None
    if t2.ndim != 1:
        # BUG FIX: was logging.errorr(...), which raises AttributeError at runtime
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    # Check if t1 and v1 have the same length
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s", len(t1), len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    # print('t1_no_nan.dtype=',t1_no_nan.dtype)
    # Convert datetime objects to timestamps
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])

    # Interpolate using the filtered data
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')

    return v2_interpolated
def str2time(strlist):
    # Convert an iterable of 'YYYY-mm-ddTHH:MM:SSZ' strings to a NumPy array of (naive) datetimes.
    # NOTE(review): this redefines the str2time declared earlier in the file and shadows it.
    parsed = [datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%SZ') for stamp in strlist]
    return np.array(parsed)
def check_increment(datetime_array, id=''):
    """
    Check the spacing of a datetime array; log it and return the increment in hours.

    Returns the common increment (hours) when all consecutive differences are equal
    (logging an error if it is not positive); otherwise logs the min/max increments
    and returns -1.  NOTE(review): the tail of this function was beyond the reviewed
    view; the -1 sentinel is reconstructed — confirm against callers.
    (The parameter name `id` shadows the builtin but is kept for interface compatibility.)
    """
    # Guard: fewer than two timestamps means no increments to check
    if len(datetime_array) < 2:
        logging.error('%s time array has fewer than 2 elements, no increments to check', id)
        return -1
    # Calculate time differences between consecutive datetime values
    diffs = [b - a for a, b in zip(datetime_array[:-1], datetime_array[1:])]
    diffs_hours = np.array([diff.total_seconds() / 3600 for diff in diffs])
    # Check if all time differences are exactly 1 hour
    if all(diffs_hours == diffs_hours[0]):
        logging.info('%s time array increments are %s hours', id, diffs_hours[0])
        if diffs_hours[0] <= 0:
            # BUG FIX: message typo "increements" corrected
            logging.error('%s time array increments are not positive', id)
        return diffs_hours[0]
    else:
        logging.info('%s time array increments are min %s max %s', id,
                     np.min(diffs_hours), np.max(diffs_hours))
        return -1