Update rnn_train_versions.ipynb
[notebooks.git] / fmda / utils.py
blob444b66c39175ecbdcfc9cc9ebc2b4b93b30704bf
1 import numpy as np
2 from functools import singledispatch
3 import pandas as pd
4 import numbers
5 from datetime import datetime
6 import logging
7 import sys
8 import inspect
9 import yaml
10 import hashlib
11 import pickle
12 import os.path as osp
13 from urllib.parse import urlparse
14 import subprocess
16 from itertools import islice
def rmse_3d(preds, y_test):
    """
    Root-mean-squared error per series for ndarrays shaped (batch_size, timesteps, features).
    The first dimension, batch_size, could denote distinct locations; the second,
    timesteps, is the length of each sequence. Returns one RMSE per batch entry.
    """
    err = preds - y_test
    # Average the squared error over timesteps and features (axes 1, 2),
    # then take the square root to get one RMSE per timeseries.
    return np.sqrt(np.mean(err ** 2, axis=(1, 2)))
class Dict(dict):
    """
    A dictionary whose keys are also accessible as attributes.

    Keys that are ranges or tuples act as grouped keys: looking up any
    member of the range/tuple returns the value stored under that group,
    and keys() expands groups into their individual members.
    """

    def __init__(self, d):
        """Initialize by copying in the mapping d."""
        self.update(d)

    def __getattr__(self, item):
        return self[item]

    def __setattr__(self, item, value):
        self[item] = value

    def __getitem__(self, item):
        # Direct hit first; otherwise search grouped (range/tuple) keys.
        if item in self:
            return super().__getitem__(item)
        for key in self:
            if isinstance(key, (range, tuple)) and item in key:
                return super().__getitem__(key)
        raise KeyError(item)

    def keys(self):
        # With no grouped keys, behave exactly like dict.keys().
        if not any(isinstance(key, (range, tuple)) for key in self):
            return super().keys()
        # Otherwise expand each grouped key into its members, in order.
        expanded = []
        for key in self:
            if isinstance(key, (range, tuple)):
                expanded.extend(key)
            else:
                expanded.append(key)
        return expanded
def save(obj, file):
    # Serialize obj to the given file path.
    # NOTE(review): `dill` is not imported in this module's visible import
    # block — confirm it is brought into scope elsewhere, otherwise this
    # raises NameError at call time.
    with open(file,'wb') as output:
        dill.dump(obj, output )
def load(file):
    # Deserialize and return the object stored at the given file path.
    # NOTE(review): relies on `dill`, which is not imported in this module's
    # visible imports — verify it is in scope at runtime.
    with open(file,'rb') as input:
        returnitem = dill.load(input)
    return returnitem
83 # Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """
    Downloads a file from a specified URL to a destination path.

    Parameters:
    -----------
    url : str
        The URL from which to download the file.
    dest_path : str
        The destination path where the file should be saved.
    force_download : bool, optional
        If True, forces the download even if the file already exists at the destination path.
        Default is False.

    Warnings:
    ---------
    Prints a warning if the file extension of the URL does not match the destination file extension.

    Raises:
    -------
    AssertionError:
        If the download fails and the file does not exist at the destination path.

    Notes:
    ------
    This function uses the `wget` command-line tool to download the file. Ensure that `wget` is
    installed and accessible from the system's PATH.

    Prints:
    -------
    A message indicating whether the file was downloaded or if it already exists at the
    destination path.
    """
    if not osp.exists(dest_path) or force_download:
        # Typo fix: message previously read "Attempting to downloaded".
        print(f"Attempting to download {url} to {dest_path}")
        target_extension = osp.splitext(dest_path)[1]
        url_extension = osp.splitext(urlparse(url).path)[1]
        if target_extension != url_extension:
            print("Warning: file extension from url does not match destination file extension")
        # Security fix: pass an argument list with shell=False so shell
        # metacharacters in url/dest_path cannot be interpreted by a shell.
        subprocess.call(["wget", "-O", dest_path, url])
        assert osp.exists(dest_path)
        print(f"Successfully downloaded {url} to {dest_path}")
    else:
        print(f"Target data already exists at {dest_path}")
129 # Function to check if lists are nested, or all elements in given list are in target list
def all_items_exist(source_list, target_list):
    """
    Checks if all items from the source list exist in the target list.

    Parameters:
    -----------
    source_list : list
        The list of items to check for existence in the target list.
    target_list : list
        The list in which to check for the presence of items from the source list.

    Returns:
    --------
    bool
        True if all items in the source list are present in the target list, False otherwise.

    Example:
    --------
    >>> all_items_exist([1, 2, 3], [1, 2, 3, 4, 5])
    True
    >>> all_items_exist([1, 2, 6], [1, 2, 3, 4, 5])
    False
    """
    for item in source_list:
        if item not in target_list:
            return False
    return True
159 # Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    """
    Reads a YAML file and optionally retrieves a specific subkey.

    Parameters:
    -----------
    yaml_path : str
        The path to the YAML file to be read.
    subkey : str, optional
        A specific key within the YAML file to retrieve. If provided, only the value
        associated with this key is returned; otherwise the entire file is returned
        as a dictionary. Default is None.

    Returns:
    --------
    dict or any
        The contents of the YAML file, or the value under the given subkey.
    """
    with open(yaml_path, 'r') as file:
        contents = yaml.safe_load(file)
    return contents if subkey is None else contents[subkey]
186 # Use to load nested fmda dictionary of cases
def load_and_fix_data(filename):
    """
    Read an FMDA training dictionary from a pickle file and clean it.

    Parameters
    ----------
    filename : str
        Path to a file with .pickle extension containing a nested dict of cases.

    Returns
    -------
    dict
        The loaded dictionary with NaN values in float arrays "fixed", and
        'case', 'filename', 'title', 'descr' fields filled in for each case.
    """
    # BUG FIX: message previously had no placeholder ("loading file (unknown)").
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        test_dict = pickle.load(handle)
    for case in test_dict:
        test_dict[case]['case'] = case
        test_dict[case]['filename'] = filename
        for key in test_dict[case].keys():
            var = test_dict[case][key]  # pointer to test_dict[case][key]
            if isinstance(var, np.ndarray) and (var.dtype.kind == 'f'):
                nans = np.sum(np.isnan(var))
                if nans:
                    print('WARNING: case', case, 'variable', key, 'shape', var.shape, 'has', nans, 'nan values, fixing')
                    # NOTE(review): fixnan is not defined in this module's
                    # visible code — presumably imported elsewhere; appears to
                    # repair NaNs in place. Confirm before relying on it.
                    fixnan(var)
                    nans = np.sum(np.isnan(test_dict[case][key]))
                    print('After fixing, remained', nans, 'nan values')
        if 'title' not in test_dict[case].keys():
            test_dict[case]['title'] = case
        if 'descr' not in test_dict[case].keys():
            test_dict[case]['descr'] = f"{case} FMDA dictionary"
    return test_dict
214 # Generic helper function to read pickle files
def read_pkl(file_path):
    """
    Reads a pickle file and returns its contents.

    Parameters:
    -----------
    file_path : str
        The path to the pickle file to be read.

    Returns:
    --------
        The object stored in the pickle file.

    Prints:
    -------
    A message indicating the file path being loaded.

    Notes:
    ------
    Uses Python's `pickle` module to deserialize the file. Only load pickle
    files created in a trusted environment — unpickling untrusted data can
    execute arbitrary code.
    """
    with open(file_path, 'rb') as f:
        print(f"loading file {file_path}")
        contents = pickle.load(f)
    return contents
def logging_setup():
    """Configure root logging: INFO level, timestamped format, output to stdout."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        stream=sys.stdout
    )
# numpy dtype kind codes considered numeric: int, unsigned, float, complex
numeric_kinds = {'i', 'u', 'f', 'c'}

def is_numeric_ndarray(array):
    """Return True iff array is a numpy ndarray with a numeric dtype."""
    return isinstance(array, np.ndarray) and array.dtype.kind in numeric_kinds
def vprint(*args):
    """
    Print args (space-separated, newline-terminated) only when the calling
    scope has a truthy local variable named `verbose`; otherwise do nothing.

    Fixes over the previous version: removed the redundant local
    `import inspect` (already imported at module level), guarded against
    an IndexError when called with no arguments while verbose, and broke
    the reference cycle created by holding the current frame.
    """
    frame = inspect.currentframe()
    try:
        # Read 'verbose' from the caller's locals, defaulting to False.
        verbose = frame.f_back.f_locals.get('verbose', False)
    finally:
        # Frames hold a cycle (frame -> f_locals -> frame); drop it promptly.
        del frame
    if verbose and args:
        # Equivalent to the original manual loop: space-separated, one newline.
        print(*args)
275 ## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    """
    Generates a unique hash string for a NumPy ndarray.

    Parameters:
    -----------
    arr : np.ndarray or list of np.ndarray
        The array to be hashed. A list of arrays is vertically stacked first.

    Returns:
    --------
        A hexadecimal string representing the MD5 hash of the array bytes.

    Notes:
    ------
    Converts the array to bytes with `tobytes()` and MD5-hashes the result.
    Performance might be bad for very large arrays. The hash depends on the
    array's dtype and memory layout, not just its values.
    """
    # A list of arrays is stacked into a single ndarray before hashing.
    if type(arr) == list:
        arr = np.vstack(arr)
    # Both branches of the original reduced to the same call, so hash directly.
    return hashlib.md5(arr.tobytes()).hexdigest()
313 ## Function for Hashing tensorflow models
def hash_weights(model):
    """
    Generates a unique hash string for the weights of a given Keras model.

    Parameters:
    -----------
    model : A keras model
        The Keras model to be hashed. Must expose `get_weights()` returning
        a list of numpy arrays.

    Returns:
    --------
        A hexadecimal string representing the MD5 hash of the model weights.
    """
    # Render every weight/bias array as text and concatenate into one string.
    weights = model.get_weights()
    combined = ''.join(np.array2string(w, separator=',') for w in weights)
    # MD5-hash the combined string representation.
    return hashlib.md5(combined.encode('utf-8')).hexdigest()
340 ## Generic function to hash dictionary of various types
@singledispatch
def hash2(x):
    """Generic hash of mixed-type values; fallback uses Python's built-in hash."""
    return hash(x)

@hash2.register(np.ndarray)
def _(x):
    """Numpy array: hash elementwise with pandas, reduce to an integer sum."""
    return np.sum(pd.util.hash_array(x))

@hash2.register(list)
def _(x):
    """List: hash as the equivalent tuple."""
    return hash2(tuple(x))

@hash2.register(tuple)
def _(x):
    """Tuple: sum of the dispatched hashes of its elements."""
    total = 0
    for element in x:
        total += hash2(element)
    return total

@hash2.register(dict)
def _(x, keys=None, verbose=False):
    """Dict: hash each value (keys sorted when not given) and hash the sum."""
    if keys is None:
        # Hash all keys in sorted order so the result is insertion-order independent.
        keys = [*x.keys()]
        keys.sort()
    total = 0
    for key in keys:
        if verbose:
            print('Hashing', key)
        total += hash2(x[key])
    return hash(total)
def print_args(func, *args, **kwargs):
    """Trace helper: print func's name and its arguments, then call it and return the result."""
    print(f"Called: {func.__name__}")
    print("Arguments:")
    for value in args:
        print(f" {value}")
    for name, value in kwargs.items():
        print(f" {name}={value}")
    return func(*args, **kwargs)
def print_args_test():
    """Smoke-test print_args with a trivial two-argument function."""
    def my_function(a, b):
        return a + b
    print_args(my_function, a=1, b=2)
393 import inspect
def get_item(dict, var, **kwargs):
    """
    Look up var in dict, falling back to kwargs['default'] when provided.

    Parameters
    ----------
    dict : mapping
        Dictionary to search. (Parameter name shadows the builtin; kept
        unchanged for backward compatibility with existing callers.)
    var : hashable
        Key to look up.
    **kwargs
        Optional 'default' value returned when var is missing.

    Returns
    -------
    The value stored under var, or the supplied default.

    Raises
    ------
    NameError
        If var is missing and no default was supplied.
    """
    if var in dict:
        value = dict[var]
    elif 'default' in kwargs:
        value = kwargs['default']
    else:
        logging.error('Variable %s not in the dictionary and no default', var)
        # BUG FIX: previously raised a bare NameError() with no message.
        raise NameError(f'Variable {var} not in the dictionary and no default')
    logging.info('%s = %s', var, value)
    return value
def print_first(item_list, num=3, indent=0, id=None):
    """
    Print the first num items of the list followed by '...'

    :param item_list: List of items to be printed
    :param num: number of items to list
    :param indent: number of leading spaces on each printed line
    :param id: optional label printed before the items
    """
    pad = ' ' * indent
    if id is not None:
        print(pad, id)
    # Show the type of the first element so the reader knows what the list holds.
    if len(item_list) > 0:
        print(pad, type(item_list[0]))
    for item in item_list[:min(num, len(item_list))]:
        print(pad, item)
    # Signal truncation when there are more items than were printed.
    if len(item_list) > num:
        print(pad, '...')
def print_dict_summary(d, indent=0, first=[], first_num=3):
    """
    Prints a summary for each entry in the dictionary: the key plus the
    size/stats of the value, recursing into nested dictionaries.

    Arguments:
        d (dict): The dictionary to summarize.
        indent (int): Indentation level for nested output.
        first (list): Print the first items for any arrays with these names.
        first_num (int): How many leading items print_first should show.
    """
    pad = ' ' * indent
    for key, value in d.items():
        if isinstance(value, dict):
            # Nested dict: print the key and recurse with deeper indentation.
            print(f"{pad}{key}")
            print_dict_summary(value, first=first, indent=indent + 5, first_num=first_num)
        elif isinstance(value, np.ndarray):
            if np.issubdtype(value.dtype, np.number):
                print(f"{pad}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Non-numeric arrays: min/max is meaningless, report dtype instead.
                print(f"{pad}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):
            # Any other non-string iterable: just report its length.
            print(f"{pad}{key}: Array of {len(value)} items")
        else:
            print(pad, key, ":", value)
        if key in first:
            print_first(value, num=first_num, indent=indent + 5)
451 from datetime import datetime
def str2time(input):
    """
    Convert a single ISO-8601 timestamp string, or a list of such strings,
    to the corresponding timezone-aware datetime object(s).

    NOTE(review): this definition is shadowed by a second `str2time`
    defined later in this module; at import time the later one wins.
    """
    if isinstance(input, str):
        # Rewrite a trailing 'Z' as an explicit UTC offset so %z can parse it.
        return datetime.strptime(input.replace('Z', '+00:00'), '%Y-%m-%dT%H:%M:%S%z')
    if isinstance(input, list):
        return [str2time(s) for s in input]
    raise ValueError("Input must be a string or a list of strings")
465 # interpolate linearly over nans
def filter_nan_values(t1, v1):
    """Drop NaN entries from v1 along with the corresponding entries of t1."""
    keep = ~np.isnan(v1)  # boolean mask: positions where v1 is not NaN
    return np.array(t1)[keep], np.array(v1)[keep]
def time_intp(t1, v1, t2):
    """
    Linearly interpolate v1 (sampled at datetimes t1) onto the datetimes t2,
    after dropping NaN entries of v1.

    Parameters
    ----------
    t1, t2 : 1D numpy arrays of datetime objects
        Source and target time grids.
    v1 : 1D numpy array of floats
        Values sampled at t1; must have the same length as t1.

    Returns
    -------
    np.ndarray or None
        Interpolated values at t2, or None when any input fails validation
        (errors are logged rather than raised).
    """
    # Validate that all three inputs are 1D arrays.
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension %s:", v1.ndim)
        return None
    if t2.ndim != 1:
        # BUG FIX: was logging.errorr, which raised AttributeError at runtime.
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    # t1 and v1 must be aligned element-for-element.
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s", len(t1), len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    # Convert datetime objects to POSIX timestamps so np.interp can work on floats.
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])
    # Interpolate using the NaN-filtered data.
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')
    return v2_interpolated
def str2time(strlist):
    """
    Convert an iterable of 'YYYY-mm-ddTHH:MM:SSZ' strings to a numpy array
    of (naive) datetime objects.

    NOTE(review): this shadows the earlier str2time defined above in this
    module; this later definition is the one in effect after import.
    """
    parsed = [datetime.strptime(dt_str, '%Y-%m-%dT%H:%M:%SZ') for dt_str in strlist]
    return np.array(parsed)
def check_increment(datetime_array, id=''):
    """
    Check the spacing of a sequence of datetimes.

    Parameters
    ----------
    datetime_array : sequence of datetime
        Times to check, in order.
    id : str, optional
        Label used in log messages.

    Returns
    -------
    float or int
        The common increment in hours when all consecutive differences are
        equal (an error is logged if it is not positive); -1 when increments
        vary; 0 when the array has fewer than two elements (new guard —
        previously this case raised IndexError).
    """
    if len(datetime_array) < 2:
        logging.warning('%s time array has fewer than 2 elements, cannot check increments', id)
        return 0
    # Time differences between consecutive datetime values, in hours.
    diffs = [b - a for a, b in zip(datetime_array[:-1], datetime_array[1:])]
    diffs_hours = np.array([diff.total_seconds() / 3600 for diff in diffs])
    # Check whether all increments equal the first one.
    if all(diffs_hours == diffs_hours[0]):
        logging.info('%s time array increments are %s hours', id, diffs_hours[0])
        if diffs_hours[0] <= 0:
            # Typo fix: message previously read "increements".
            logging.error('%s time array increments are not positive', id)
        return diffs_hours[0]
    else:
        logging.info('%s time array increments are min %s max %s', id,
                     np.min(diffs_hours), np.max(diffs_hours))
        return -1