Create Batch Reset Hyperparameter tutorial notebook
[notebooks.git] / fmda / utils.py
blob6ea91e13b79959a4c6495100e31d8ee3f712712c
1 import numpy as np
2 from functools import singledispatch
3 import pandas as pd
4 import numbers
5 from datetime import datetime
6 import logging
7 import sys
8 import inspect
9 import yaml
10 import hashlib
11 import pickle
12 import os.path as osp
13 from urllib.parse import urlparse
14 import subprocess
# Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """
    Downloads a file from a specified URL to a destination path.

    Parameters:
    -----------
    url : str
        The URL from which to download the file.
    dest_path : str
        The destination path where the file should be saved.
    force_download : bool, optional
        If True, forces the download even if the file already exists at the destination path.
        Default is False.

    Warnings:
    ---------
    Prints a warning if the file extension of the URL does not match the destination file extension.

    Raises:
    -------
    AssertionError:
        If the download fails and the file does not exist at the destination path.

    Notes:
    ------
    This function uses the `wget` command-line tool to download the file. Ensure that `wget` is
    installed and accessible from the system's PATH.

    Prints:
    -------
    A message indicating whether the file was downloaded or if it already exists at the
    destination path.
    """
    if not osp.exists(dest_path) or force_download:
        target_extension = osp.splitext(dest_path)[1]
        url_extension = osp.splitext(urlparse(url).path)[1]
        if target_extension != url_extension:
            print("Warning: file extension from url does not match destination file extension")
        # Pass the arguments as a list with shell=False so that special
        # characters in url/dest_path cannot be interpreted by a shell
        # (fixes a command-injection risk of the previous shell=True string).
        subprocess.call(["wget", "-O", dest_path, url])
        assert osp.exists(dest_path)
        print(f"Successfully downloaded {url} to {dest_path}")
    else:
        print(f"Target data already exists at {dest_path}")
# Function to check if lists are nested, or all elements in given list are in target list
def all_items_exist(source_list, target_list):
    """
    Checks if all items from the source list exist in the target list.

    Parameters:
    -----------
    source_list : list
        The list of items to check for existence in the target list.
    target_list : list
        The list in which to check for the presence of items from the source list.

    Returns:
    --------
    bool
        True if all items in the source list are present in the target list, False otherwise.

    Example:
    --------
    >>> all_items_exist([1, 2, 3], [1, 2, 3, 4, 5])
    True
    >>> all_items_exist([1, 2, 6], [1, 2, 3, 4, 5])
    False
    """
    # Early-exit loop; membership test works for unhashable items too,
    # so no set conversion is used here.
    for candidate in source_list:
        if candidate not in target_list:
            return False
    return True
# Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    """
    Reads a YAML file and optionally retrieves a specific subkey.

    Parameters:
    -----------
    yaml_path : str
        The path to the YAML file to be read.
    subkey : str, optional
        A specific key within the YAML file to retrieve. If provided, only the value associated
        with this key will be returned. If not provided, the entire YAML file is returned as a
        dictionary. Default is None.

    Returns:
    --------
    dict or any
        The contents of the YAML file as a dictionary, or the value associated with the specified
        subkey if provided.
    """
    with open(yaml_path, 'r') as handle:
        contents = yaml.safe_load(handle)
    return contents if subkey is None else contents[subkey]
# Use to load nested fmda dictionary of cases
def load_and_fix_data(filename):
    """
    Read an FMDA training dictionary from a pickle file and clean it.

    Parameters:
    -----------
    filename : str
        Path to file with .pickle extension.

    Returns:
    --------
    dict
        FMDA dictionary with NaN values "fixed" and with `case`, `filename`,
        `title`, and `descr` keys guaranteed present for every case.
    """
    # Fixed: previous message was the broken f-string f"loading file (unknown)"
    # with no placeholder; report the actual file being loaded.
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        test_dict = pickle.load(handle)
    for case in test_dict:
        test_dict[case]['case'] = case
        test_dict[case]['filename'] = filename
        for key in test_dict[case].keys():
            var = test_dict[case][key]  # pointer to test_dict[case][key]
            if isinstance(var, np.ndarray) and (var.dtype.kind == 'f'):
                nans = np.sum(np.isnan(var))
                if nans:
                    print('WARNING: case',case,'variable',key,'shape',var.shape,'has',nans,'nan values, fixing')
                    # NOTE(review): fixnan is expected to repair NaNs in-place, but it
                    # is not defined in this module — confirm it is imported elsewhere.
                    fixnan(var)
                    nans = np.sum(np.isnan(test_dict[case][key]))
                    print('After fixing, remained',nans,'nan values')
        if 'title' not in test_dict[case].keys():
            test_dict[case]['title'] = case
        if 'descr' not in test_dict[case].keys():
            test_dict[case]['descr'] = f"{case} FMDA dictionary"
    return test_dict
# Generic helper function to read pickle files
def read_pkl(file_path):
    """
    Reads a pickle file and returns its contents.

    Parameters:
    -----------
    file_path : str
        The path to the pickle file to be read.

    Returns:
    --------
    any
        The object stored in the pickle file.

    Prints:
    -------
    A message indicating the file path being loaded.

    Notes:
    ------
    This function uses Python's `pickle` module to deserialize the contents of the file. Ensure
    that the pickle file was created in a safe and trusted environment to avoid security risks
    associated with loading arbitrary code.
    """
    with open(file_path, 'rb') as file:
        print(f"loading file {file_path}")
        d = pickle.load(file)
    return d
def logging_setup():
    """Configure root logging: INFO level, timestamped messages, written to stdout.

    Note: logging.basicConfig is a no-op if the root logger already has handlers.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        stream=sys.stdout
    )
# dtype.kind codes considered numeric: signed int, unsigned int, float, complex
numeric_kinds = {'i', 'u', 'f', 'c'}

def is_numeric_ndarray(array):
    """Return True when `array` is a numpy ndarray with a numeric dtype
    (integer, unsigned integer, float, or complex); False otherwise."""
    return isinstance(array, np.ndarray) and array.dtype.kind in numeric_kinds
def vprint(*args):
    """
    Print *args (space-separated) only when the CALLING function has a local
    variable named `verbose` that is truthy; otherwise print nothing.

    Notes:
    ------
    Inspects the caller's frame to read its local `verbose`; defaults to False
    when the caller defines no such local.
    """
    frame = inspect.currentframe()
    try:
        # .get avoids the membership-test-then-index double lookup
        verbose = frame.f_back.f_locals.get('verbose', False)
    finally:
        # Explicitly drop the frame reference to break the reference cycle,
        # as recommended by the inspect module documentation.
        del frame
    if verbose:
        # print(*args) matches the old manual join and also handles zero
        # arguments (the old code indexed args[-1] and crashed on empty args).
        print(*args)
## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    """
    Generates a unique hash string for a NumPy ndarray.

    Parameters:
    -----------
    arr : np.ndarray
        The NumPy array to be hashed.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the array.

    Notes:
    ------
    Hashes the raw byte buffer from `tobytes()`, so the result depends on
    dtype and memory layout as well as values. Performance might be bad for
    very large arrays.

    Example:
    --------
    >>> arr = np.array([1, 2, 3])
    >>> hash_value = hash_ndarray(arr)
    """
    digest = hashlib.md5()
    digest.update(arr.tobytes())
    return digest.hexdigest()
## Function for Hashing tensorflow models
def hash_weights(model):
    """
    Generates a unique hash string for the weights of a given Keras model.

    Parameters:
    -----------
    model : keras model
        The Keras model to be hashed; only `model.get_weights()` is used,
        so any object exposing that method works.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the model weights.
    """
    # Extract all weights and biases
    weights = model.get_weights()
    # Convert each weight array to a string
    weight_str = ''.join([np.array2string(w, separator=',') for w in weights])
    # Generate an MD5 hash of the combined string
    # (previous comment incorrectly said SHA-256)
    weight_hash = hashlib.md5(weight_str.encode('utf-8')).hexdigest()
    return weight_hash
## Generic function to hash dictionary of various types

@singledispatch
def hash2(x):
    """Top-level hash: built-in hash() for str, float, int, etc."""
    return hash(x)

@hash2.register(np.ndarray)
def _(x):
    """Hash a numpy array elementwise with pandas and return the integer sum."""
    return np.sum(pd.util.hash_array(x))

@hash2.register(list)
def _(x):
    """Hash a list by converting it to a tuple."""
    return hash2(tuple(x))

@hash2.register(tuple)
def _(x):
    """Hash a tuple as the sum of its elements' hashes (order-insensitive)."""
    return sum(hash2(element) for element in x)

@hash2.register(dict)
def _(x, keys=None, verbose=False):
    """Hash a dict: sum the hashes of the values for the given keys
    (or all keys, sorted, when keys is None), then hash the sum."""
    if keys is None:  # allow user input of keys to hash, otherwise hash them all
        keys = sorted(x.keys())
    total = 0
    for key in keys:
        if verbose:
            print('Hashing', key)
        total += hash2(x[key])
    return hash(total)
def print_args(func, *args, **kwargs):
    """Trace a call: print the function name and every positional and keyword
    argument, then invoke `func` with those arguments and return its result."""
    print(f"Called: {func.__name__}")
    print("Arguments:")
    for positional in args:
        print(f" {positional}")
    for name, val in kwargs.items():
        print(f" {name}={val}")
    return func(*args, **kwargs)
def print_args_test():
    """Smoke test for print_args: trace a trivial two-argument function."""
    def my_function(a, b):
        # some code here
        return a + b
    print_args(my_function, a=1, b=2)
320 import inspect
def get_item(dict,var,**kwargs):
    """Return dict[var], falling back to kwargs['default'] when the key is
    absent; log the resolved value. Raises NameError when the key is missing
    and no default was supplied.

    Note: the parameter named `dict` shadows the builtin but is kept for
    caller compatibility (keyword arguments).
    """
    if var not in dict:
        if 'default' not in kwargs:
            logging.error('Variable %s not in the dictionary and no default',var)
            raise NameError()
        value = kwargs['default']
    else:
        value = dict[var]
    logging.info('%s = %s',var,value)
    return value
def print_first(item_list,num=3,indent=0,id=None):
    """
    Print the first num items of the list followed by '...'

    :param item_list: List of items to be printed
    :param num: number of items to list
    :param indent: number of leading spaces on each printed line
    :param id: optional label printed before the items
    """
    indent_str = ' ' * indent
    if id is not None:
        print(indent_str, id)
    if len(item_list) > 0:
        # Show the element type once, then the leading items
        print(indent_str,type(item_list[0]))
        for i in range(min(num,len(item_list))):
            print(indent_str,item_list[i])
    if len(item_list) > num:
        print(indent_str,'...')
def print_dict_summary(d,indent=0,first=[],first_num=3):
    """
    Prints a summary for each array in the dictionary, showing the key and the size of the array.

    Arguments:
        d (dict): The dictionary to summarize.
        indent (int): Number of leading spaces for each printed line.
        first (list): Print the first items for any arrays with these names.
        first_num (int): How many leading items to print for keys in `first`.

    Note: the mutable default `first=[]` is never mutated here, so it is safe.
    """
    indent_str = ' ' * indent
    for key, value in d.items():
        # Check if the value is list-like using a simple method check
        if isinstance(value, dict):
            # Recurse into nested dictionaries with extra indentation
            print(f"{indent_str}{key}")
            print_dict_summary(value,first=first,indent=indent+5,first_num=first_num)
        elif isinstance(value,np.ndarray):
            if np.issubdtype(value.dtype, np.number):
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Handle non-numeric arrays differently
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str): # Check for iterable that is not a string
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str,key,":",value)
        if key in first:
            print_first(value,num=first_num,indent=indent+5)
378 from datetime import datetime
def str2time(input):
    """
    Convert a single string timestamp or a list of string timestamps to
    corresponding timezone-aware datetime object(s).

    Accepts ISO-8601 strings ending in 'Z' (replaced with '+00:00' and parsed
    with %z). Raises ValueError for any other input type.

    NOTE(review): a second `str2time` defined later in this module overrides
    this one at import time — confirm which version callers actually get.
    """
    if isinstance(input, str):
        return datetime.strptime(input.replace('Z', '+00:00'), '%Y-%m-%dT%H:%M:%S%z')
    elif isinstance(input, list):
        return [str2time(s) for s in input]
    else:
        raise ValueError("Input must be a string or a list of strings")
# interpolate linearly over nans

def filter_nan_values(t1, v1):
    """Drop NaN entries from v1 along with the matching entries of t1;
    return the filtered pair as numpy arrays."""
    keep = ~np.isnan(v1)  # mask of positions where v1 is not NaN
    return np.array(t1)[keep], np.array(v1)[keep]
def time_intp(t1, v1, t2):
    """
    Linearly interpolate the series (t1, v1) onto the times t2, skipping NaNs.

    Parameters:
    -----------
    t1 : 1D numpy array of datetime objects (must support .timestamp())
    v1 : 1D numpy array of floats, same length as t1; may contain NaNs
    t2 : 1D numpy array of datetime objects to interpolate onto

    Returns:
    --------
    numpy array of interpolated values aligned with t2, or None when the
    inputs fail validation (wrong dimensionality or mismatched lengths).
    """
    # Check if t1 v1 t2 are 1D arrays
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension %s:", v1.ndim)
        return None
    if t2.ndim != 1:
        # fixed: was `logging.errorr`, which raised AttributeError at runtime
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    # Check if t1 and v1 have the same length
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s",len(t1),len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    # Convert datetime objects to POSIX timestamps for numeric interpolation
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])
    # Interpolate using the filtered data
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')
    return v2_interpolated
def str2time(strlist):
    """Convert an iterable of 'YYYY-mm-ddTHH:MM:SSZ' strings to a numpy array
    of naive datetime objects (the trailing 'Z' is matched literally, so no
    timezone is attached).

    NOTE(review): this overrides the earlier str2time defined above in this
    module — confirm which behavior callers expect.
    """
    parsed = [datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%SZ') for stamp in strlist]
    return np.array(parsed)
def check_increment(datetime_array,id=''):
    """
    Check that a datetime sequence advances by a constant step.

    :param datetime_array: sequence of datetime objects (at least 2 expected)
    :param id: label used in log messages
    :return: the constant increment in hours when all consecutive differences
             are equal (logging an error if it is not positive), otherwise -1
    """
    # Calculate time differences between consecutive datetime values
    diffs = [b - a for a, b in zip(datetime_array[:-1], datetime_array[1:])]
    diffs_hours = np.array([diff.total_seconds()/3600 for diff in diffs])
    # Check if all time differences are exactly 1 hour
    if all(diffs_hours == diffs_hours[0]):
        logging.info('%s time array increments are %s hours',id,diffs_hours[0])
        if diffs_hours[0] <= 0 :
            # fixed log-message typo: was "increements"
            logging.error('%s time array increments are not positive',id)
        return diffs_hours[0]
    else:
        logging.info('%s time array increments are min %s max %s',id,
                     np.min(diffs_hours),np.max(diffs_hours))
        return -1