Remove old repro files
[notebooks.git] / fmda / utils.py
blobda957d7fbef37f635141dd69613ed8e0024a8780
1 import numpy as np
2 from functools import singledispatch
3 import pandas as pd
4 import numbers
5 from datetime import datetime
6 import logging
7 import sys
8 import inspect
9 import yaml
10 import hashlib
11 import pickle
12 import os.path as osp
13 from urllib.parse import urlparse
14 import subprocess
16 from itertools import islice
class Dict(dict):
    """
    A dictionary whose keys are also reachable as attributes.

    Additionally supports composite keys: a key that is a ``range`` or a
    ``tuple`` answers lookups for any of its members.
    """

    def __init__(self, d):
        """Initialize by copying the mapping ``d`` into self."""
        self.update(d)

    def __getattr__(self, item):
        # Attribute reads fall through to item lookup (KeyError on miss).
        return self[item]

    def __setattr__(self, item, value):
        # Attribute writes store a dictionary entry.
        self[item] = value

    def __getitem__(self, item):
        # Exact key first; otherwise scan composite (range/tuple) keys that
        # contain the requested item.
        if item in self:
            return super().__getitem__(item)
        for key in self:
            if isinstance(key, (range, tuple)) and item in key:
                return super().__getitem__(key)
        raise KeyError(item)

    def keys(self):
        # When composite keys are present, expand them into their members so
        # the reported keys match what __getitem__ accepts. Note this returns
        # a list in that case, not a dict view.
        if not any(isinstance(key, (range, tuple)) for key in self):
            return super().keys()
        expanded = []
        for key in self:
            if isinstance(key, (range, tuple)):
                expanded.extend(key)
            else:
                expanded.append(key)
        return expanded
def save(obj, file):
    """
    Serialize ``obj`` to the file at path ``file``.

    Parameters:
    -----------
    obj : any picklable object
        The object to serialize.
    file : str
        Destination file path (opened in binary write mode).

    BUG FIX: the original called ``dill.dump``, but ``dill`` is never
    imported anywhere in this module, so the call raised NameError.
    The stdlib ``pickle`` (already imported at the top of this file)
    is used instead.
    """
    with open(file, 'wb') as output:
        pickle.dump(obj, output)
def load(file):
    """
    Deserialize and return the object stored at path ``file``.

    Parameters:
    -----------
    file : str
        Source file path (opened in binary read mode).

    BUG FIX: the original called ``dill.load``, but ``dill`` is never
    imported anywhere in this module, so the call raised NameError.
    The stdlib ``pickle`` (already imported at the top of this file)
    is used instead.
    """
    with open(file, 'rb') as fh:
        returnitem = pickle.load(fh)
    return returnitem
# Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """
    Downloads a file from a specified URL to a destination path.

    Parameters:
    -----------
    url : str
        The URL from which to download the file.
    dest_path : str
        The destination path where the file should be saved.
    force_download : bool, optional
        If True, forces the download even if the file already exists at the
        destination path. Default is False.

    Warnings:
    ---------
    Prints a warning if the file extension of the URL does not match the
    destination file extension.

    Raises:
    -------
    AssertionError:
        If the download fails and the file does not exist at the destination path.

    Notes:
    ------
    This function uses the `wget` command-line tool to download the file.
    Ensure that `wget` is installed and accessible from the system's PATH.

    Prints:
    -------
    A message indicating whether the file was downloaded or if it already
    exists at the destination path.
    """
    if not osp.exists(dest_path) or force_download:
        # typo fix: was "Attempting to downloaded"
        print(f"Attempting to download {url} to {dest_path}")
        target_extension = osp.splitext(dest_path)[1]
        url_extension = osp.splitext(urlparse(url).path)[1]
        if target_extension != url_extension:
            print("Warning: file extension from url does not match destination file extension")
        # Argument-list form (shell=False) avoids shell quoting/injection
        # problems with URLs containing &, spaces, etc.
        subprocess.call(["wget", "-O", dest_path, url])
        assert osp.exists(dest_path)
        print(f"Successfully downloaded {url} to {dest_path}")
    else:
        print(f"Target data already exists at {dest_path}")
# Function to check if lists are nested, or all elements in given list are in target list
def all_items_exist(source_list, target_list):
    """
    Checks if all items from the source list exist in the target list.

    Parameters:
    -----------
    source_list : list
        The list of items to check for existence in the target list.
    target_list : list
        The list in which to check for the presence of items from the source list.

    Returns:
    --------
    bool
        True if all items in the source list are present in the target list,
        False otherwise.

    Example:
    --------
    >>> all_items_exist([1, 2, 3], [1, 2, 3, 4, 5])
    True
    >>> all_items_exist([1, 2, 6], [1, 2, 3, 4, 5])
    False
    """
    missing = [item for item in source_list if item not in target_list]
    return not missing
# Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    """
    Reads a YAML file and optionally retrieves a specific subkey.

    Parameters:
    -----------
    yaml_path : str
        The path to the YAML file to be read.
    subkey : str, optional
        A specific key within the YAML file to retrieve. If provided, only
        the value associated with this key is returned; otherwise the whole
        file is returned as a dictionary. Default is None.

    Returns:
    --------
    dict or any
        The contents of the YAML file as a dictionary, or the value
        associated with the specified subkey if provided.
    """
    with open(yaml_path, 'r') as fh:
        contents = yaml.safe_load(fh)
    return contents if subkey is None else contents[subkey]
# Use to load nested fmda dictionary of cases
def load_and_fix_data(filename):
    """
    Read an FMDA training dictionary from a pickle file and return it cleaned.

    Each case entry is tagged with its own 'case' name and the source
    'filename'; NaN values in float numpy arrays are repaired in place via
    fixnan; missing 'title'/'descr' fields are filled with defaults.

    Parameters:
    -----------
    filename : str
        Path to file with .pickle extension.

    Returns:
    --------
    dict
        FMDA dictionary with NA values "fixed".
    """
    # BUG FIX: original printed the literal "(unknown)" — the f-string had
    # no placeholder; it clearly intends to report the file being loaded.
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        test_dict = pickle.load(handle)
        for case in test_dict:
            test_dict[case]['case'] = case
            test_dict[case]['filename'] = filename
            for key in test_dict[case].keys():
                var = test_dict[case][key]  # pointer to test_dict[case][key]
                if isinstance(var, np.ndarray) and (var.dtype.kind == 'f'):
                    nans = np.sum(np.isnan(var))
                    if nans:
                        print('WARNING: case', case, 'variable', key, 'shape', var.shape, 'has', nans, 'nan values, fixing')
                        # NOTE(review): fixnan is not defined in this module —
                        # confirm it is provided elsewhere at runtime.
                        fixnan(var)
                        nans = np.sum(np.isnan(test_dict[case][key]))
                        print('After fixing, remained', nans, 'nan values')
            if not 'title' in test_dict[case].keys():
                test_dict[case]['title'] = case
            if not 'descr' in test_dict[case].keys():
                test_dict[case]['descr'] = f"{case} FMDA dictionary"
    return test_dict
# Generic helper function to read pickle files
def read_pkl(file_path):
    """
    Reads a pickle file and returns its contents.

    Parameters:
    -----------
    file_path : str
        The path to the pickle file to be read.

    Returns:
    --------
    The object stored in the pickle file.

    Prints:
    -------
    A message indicating the file path being loaded.

    Notes:
    ------
    Deserialization uses Python's `pickle` module; only load files created
    in a trusted environment, since unpickling can execute arbitrary code.
    """
    with open(file_path, 'rb') as fh:
        print(f"loading file {file_path}")
        return pickle.load(fh)
def logging_setup():
    """Configure root logging: INFO level, timestamped messages to stdout."""
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )
# numpy dtype.kind codes treated as numeric: signed int, unsigned int,
# float, and complex.
numeric_kinds = {'i', 'u', 'f', 'c'}

def is_numeric_ndarray(array):
    """Return True iff ``array`` is a numpy ndarray with a numeric dtype."""
    return isinstance(array, np.ndarray) and array.dtype.kind in numeric_kinds
def vprint(*args):
    """
    Print ``args`` space-separated, but only when the *caller's* local scope
    contains a truthy variable named ``verbose``; otherwise print nothing.
    """
    import inspect

    # Peek at the calling frame's locals to honor its `verbose` flag.
    caller_frame = inspect.currentframe().f_back
    verbose = caller_frame.f_locals.get('verbose', False)

    if verbose:
        for item in args[:-1]:
            print(item, end=' ')
        print(args[-1])
## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    """
    Generates a unique hash string for a NumPy ndarray.

    Parameters:
    -----------
    arr : np.ndarray or list of np.ndarray
        The NumPy array to be hashed. A list of arrays is first stacked
        into a single array with np.vstack.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the array.

    Notes:
    ------
    The array is converted to bytes with `tobytes()` and the MD5 hash of
    those bytes is returned. Performance might be bad for very large arrays.
    The hash depends on dtype and memory layout, not just the values.

    Example:
    --------
    >>> arr = np.array([1, 2, 3])
    >>> hash_value = hash_ndarray(arr)  # hex digest; value is dtype/platform dependent
    """
    # If input is a list, stack it into one array first.
    # (idiom fix: was `type(arr) == list`; isinstance also accepts list
    # subclasses. Both branches previously duplicated the tobytes() call.)
    if isinstance(arr, list):
        arr = np.vstack(arr)
    return hashlib.md5(arr.tobytes()).hexdigest()
## Function for Hashing tensorflow models
def hash_weights(model):
    """
    Generates a unique hash string for the weights of a given Keras model.

    Parameters:
    -----------
    model : A keras model
        The Keras model to be hashed. Only `model.get_weights()` is used,
        so any object exposing that method works.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the model weights.
    """
    # Extract all weights and biases
    weights = model.get_weights()

    # Convert each weight array to a string
    # NOTE(review): np.array2string summarizes large arrays with '...'
    # by default, so very large weight tensors may not be fully hashed —
    # confirm whether that is acceptable for this use.
    weight_str = ''.join([np.array2string(w, separator=',') for w in weights])

    # Generate an MD5 hash of the combined string
    # (comment fix: previous comment claimed SHA-256, but the code has
    # always computed MD5)
    weight_hash = hashlib.md5(weight_str.encode('utf-8')).hexdigest()

    return weight_hash
## Generic function to hash dictionary of various types

@singledispatch
def hash2(x):
    """Dispatch hub: plain scalars (str, int, float, ...) use built-in hash."""
    return hash(x)

@hash2.register(np.ndarray)
def _(x):
    """Numpy arrays: per-element pandas hash_array, folded into an integer sum."""
    hashed = pd.util.hash_array(x)
    return np.sum(hashed)

@hash2.register(list)
def _(x):
    """Lists: hashed as the equivalent tuple."""
    return hash2(tuple(x))

@hash2.register(tuple)
def _(x):
    """Tuples: sum of the element hashes (each dispatched through hash2)."""
    total = 0
    for element in x:
        total += hash2(element)
    return total

@hash2.register(dict)
def _(x, keys=None, verbose=False):
    """
    Dicts: hash of the summed hashes of the selected values.

    ``keys`` restricts hashing to a subset of keys; when omitted, all keys
    are hashed in sorted order. ``verbose`` prints each key as it is hashed.
    """
    acc = 0  # integer accumulator for the per-key hashes
    if keys is None:  # allow user input of keys to hash, otherwise hash them all
        keys = sorted(x.keys())
    for key in keys:
        if verbose:
            print('Hashing', key)
        acc += hash2(x[key])
    return hash(acc)
def print_args(func, *args, **kwargs):
    """
    Trace helper: print the callee's name and every argument, then call it.

    Returns whatever ``func(*args, **kwargs)`` returns.
    """
    print(f"Called: {func.__name__}")
    print("Arguments:")
    for positional in args:
        print(f" {positional}")
    for name, value in kwargs.items():
        print(f" {name}={value}")
    return func(*args, **kwargs)
def print_args_test():
    """Smoke-test print_args by tracing a call to a trivial adder function."""
    def my_function(a, b):
        # some code here
        return a + b
    print_args(my_function, a=1, b=2)
377 import inspect
def get_item(dict, var, **kwargs):
    """
    Look up ``var`` in ``dict``, falling back to ``kwargs['default']`` if given.

    Logs the resolved value at INFO level. Raises NameError (after logging an
    error) when the key is missing and no default was supplied.
    """
    if var in dict:
        value = dict[var]
    else:
        if 'default' not in kwargs:
            logging.error('Variable %s not in the dictionary and no default', var)
            raise NameError()
        value = kwargs['default']
    logging.info('%s = %s', var, value)
    return value
def print_first(item_list, num=3, indent=0, id=None):
    """
    Print the first num items of the list followed by '...'

    :param item_list: List of items to be printed
    :param num: number of items to list
    :param indent: number of leading spaces on each printed line
    :param id: optional label printed before the items
    """
    pad = ' ' * indent
    if id is not None:
        print(pad, id)
    if len(item_list) > 0:
        # Show the element type once, taken from the first item.
        print(pad, type(item_list[0]))
    for position in range(min(num, len(item_list))):
        print(pad, item_list[position])
    if len(item_list) > num:
        print(pad, '...')
def print_dict_summary(d, indent=0, first=None, first_num=3):
    """
    Prints a summary for each entry of the dictionary: nested dicts recurse,
    numpy arrays report shape plus min/max (numeric) or dtype (non-numeric),
    other non-string iterables report their length, and scalars print as-is.

    Arguments:
        d (dict): The dictionary to summarize.
        indent (int): Number of leading spaces per printed line.
        first (list): Print the first items for any arrays with these names.
        first_num (int): How many leading items to print for keys in `first`.
    """
    # BUG FIX: `first` previously defaulted to a mutable `[]`
    # (shared across calls); use the None-sentinel idiom instead.
    if first is None:
        first = []
    indent_str = ' ' * indent
    for key, value in d.items():
        # Check if the value is list-like using a simple method check
        if isinstance(value, dict):
            print(f"{indent_str}{key}")
            print_dict_summary(value, first=first, indent=indent + 5, first_num=first_num)
        elif isinstance(value, np.ndarray):
            if np.issubdtype(value.dtype, np.number):
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Handle non-numeric arrays differently (no meaningful min/max)
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):  # iterable that is not a string
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str, key, ":", value)
        if key in first:
            print_first(value, num=first_num, indent=indent + 5)
435 from datetime import datetime
def str2time(input):
    """
    Convert a single string timestamp or a list of string timestamps to the
    corresponding datetime object(s). A trailing 'Z' is rewritten to '+00:00'
    so the '%z' directive yields a UTC-aware datetime.

    NOTE(review): this definition is shadowed by a later ``str2time`` further
    down this module, so only that later one is visible after import —
    confirm which is intended.
    """
    if isinstance(input, list):
        return [str2time(s) for s in input]
    if isinstance(input, str):
        return datetime.strptime(input.replace('Z', '+00:00'), '%Y-%m-%dT%H:%M:%S%z')
    raise ValueError("Input must be a string or a list of strings")
# interpolate linearly over nans

def filter_nan_values(t1, v1):
    """Drop NaN entries of v1 together with the matching times in t1."""
    keep = ~np.isnan(v1)  # mask of positions where v1 is not NaN
    return np.array(t1)[keep], np.array(v1)[keep]
def time_intp(t1, v1, t2):
    """
    Linearly interpolate values v1 (sampled at datetimes t1) onto datetimes t2.

    NaN entries in v1 (and the matching t1 entries) are dropped before
    interpolating. Returns the interpolated 1D numpy array, or None when any
    input fails validation (errors are logged).
    """
    # Check if t1 v1 t2 are 1D arrays
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension %s:", v1.ndim)
        return None
    if t2.ndim != 1:
        # BUG FIX: was `logging.errorr(...)`, which raised AttributeError
        # instead of logging on this validation path.
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    # Check if t1 and v1 have the same length
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s", len(t1), len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    # Convert datetime objects to POSIX timestamps for np.interp
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])
    # Interpolate using the filtered data
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')
    return v2_interpolated
def str2time(strlist):
    """Convert an array of 'YYYY-mm-ddTHH:MM:SSZ' strings to a numpy array of
    (naive) datetime objects."""
    parsed = [datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ') for s in strlist]
    return np.array(parsed)
def check_increment(datetime_array, id=''):
    """
    Check that consecutive entries of a datetime array are evenly spaced.

    Returns the constant increment in hours when the spacing is uniform
    (logging an error if that increment is not positive), or -1 when the
    increments vary. Spacing information is logged at INFO level, tagged
    with `id`.
    """
    # Calculate time differences between consecutive datetime values
    diffs = [b - a for a, b in zip(datetime_array[:-1], datetime_array[1:])]
    diffs_hours = np.array([diff.total_seconds() / 3600 for diff in diffs])
    # Check if all time differences are exactly equal to the first one
    # (typo fixes: "exactlyu" in this comment, "increements" in the log below)
    if all(diffs_hours == diffs_hours[0]):
        logging.info('%s time array increments are %s hours', id, diffs_hours[0])
        if diffs_hours[0] <= 0:
            logging.error('%s time array increments are not positive', id)
        return diffs_hours[0]
    else:
        logging.info('%s time array increments are min %s max %s', id,
                     np.min(diffs_hours), np.max(diffs_hours))
        return -1