BUG FIX: add needed imports to utils
[notebooks.git] / fmda / utils.py
blob 887c027f3ce77b397713b52fbf2581cc11b371e4
import numpy as np
from functools import singledispatch
import pandas as pd
import numbers
from datetime import datetime
import logging
import sys
import inspect
import yaml
import hashlib
import pickle
import os.path as osp
from urllib.parse import urlparse
import subprocess

# Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    if not osp.exists(dest_path) or force_download:
        target_extension = osp.splitext(dest_path)[1]
        url_extension = osp.splitext(urlparse(url).path)[1]
        if target_extension != url_extension:
            print("Warning: file extension from URL does not match destination file extension")
        subprocess.call(f"wget -O {dest_path} {url}", shell=True)
        assert osp.exists(dest_path)
        print(f"Successfully downloaded {url} to {dest_path}")
    else:
        print(f"Target data already exists at {dest_path}")
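
# Illustrative usage sketch (not part of the original module): the URL and
# destination path below are hypothetical. Note that retrieve_url shells out
# to wget, so it assumes wget is available on the system PATH.
def _example_retrieve_url():
    retrieve_url(
        url="https://example.com/data/fmda_case.pickle",  # hypothetical URL
        dest_path="data/fmda_case.pickle",                # hypothetical path
        force_download=False,  # pass True to re-fetch even if the file exists
    )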

# Check whether all elements of source_list appear in target_list
def all_items_exist(source_list, target_list):
    return all(item in target_list for item in source_list)

# Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    with open(yaml_path, 'r') as file:
        d = yaml.safe_load(file)
        if subkey is not None:
            d = d[subkey]
    return d
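
# Illustrative usage sketch: "params.yaml" and the "rnn" subkey are
# hypothetical names, not files shipped with this module.
def _example_read_yml():
    all_params = read_yml("params.yaml")                 # whole file as a dict
    rnn_params = read_yml("params.yaml", subkey="rnn")   # one top-level section
    print(all_params.keys(), rnn_params)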

# Use to load nested fmda dictionary of cases
def load_and_fix_data(filename):
    # Given path to FMDA training dictionary, read and return cleaned dictionary
    # Inputs:
    #   filename: (str) path to file with .pickle extension
    # Returns:
    #   FMDA dictionary with NA values "fixed"
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        test_dict = pickle.load(handle)
    for case in test_dict:
        test_dict[case]['case'] = case
        test_dict[case]['filename'] = filename
        for key in test_dict[case].keys():
            var = test_dict[case][key]  # pointer to test_dict[case][key]
            if isinstance(var, np.ndarray) and (var.dtype.kind == 'f'):
                nans = np.sum(np.isnan(var))
                if nans:
                    print('WARNING: case', case, 'variable', key, 'shape', var.shape, 'has', nans, 'nan values, fixing')
                    fixnan(var)  # NOTE: fixnan is not defined in this module; it must be imported or defined elsewhere
                    nans = np.sum(np.isnan(test_dict[case][key]))
                    print('After fixing,', nans, 'nan values remain')
        if 'title' not in test_dict[case].keys():
            test_dict[case]['title'] = case
        if 'descr' not in test_dict[case].keys():
            test_dict[case]['descr'] = f"{case} FMDA dictionary"
    return test_dict
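
# Illustrative usage sketch: the pickle path is hypothetical. The loaded
# dictionary maps case names to per-case dictionaries; 'title' and 'descr'
# are guaranteed to be set by load_and_fix_data.
def _example_load_and_fix_data():
    cases = load_and_fix_data("data/testing_dict.pickle")  # hypothetical path
    for name, case in cases.items():
        print(name, case['title'], case['descr'])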

# Generic helper function to read pickle files
def read_pkl(file_path):
    with open(file_path, 'rb') as file:
        print(f"loading file {file_path}")
        d = pickle.load(file)
    return d

def logging_setup():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        stream=sys.stdout
    )
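
# Illustrative usage sketch: call once at program start, then use the
# standard logging calls everywhere else.
def _example_logging_setup():
    logging_setup()
    logging.info("logging initialized")  # goes to stdout with a timestamp prefix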

numeric_kinds = {'i', 'u', 'f', 'c'}

def is_numeric_ndarray(array):
    if isinstance(array, np.ndarray):
        return array.dtype.kind in numeric_kinds
    else:
        return False

def vprint(*args):
    # Print only when the calling function has a local variable `verbose` set to True
    frame = inspect.currentframe()
    if 'verbose' in frame.f_back.f_locals:
        verbose = frame.f_back.f_locals['verbose']
    else:
        verbose = False

    if verbose:
        for s in args[:-1]:
            print(s, end=' ')
        print(args[-1])
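
# Illustrative usage sketch: vprint inspects the *caller's* local variables,
# so output is controlled by a local named `verbose` in the calling function.
def _example_vprint():
    verbose = True  # read by vprint via frame inspection; flip to False to silence
    vprint("step", 1, "done")  # prints "step 1 done" because verbose is True here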

## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    # Convert the array to a bytes string
    arr_bytes = arr.tobytes()
    # Use hashlib to generate a unique hash
    hash_obj = hashlib.md5(arr_bytes)
    return hash_obj.hexdigest()
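
# Illustrative usage sketch: equal arrays hash equal; changing any element
# changes the digest. MD5 is used here as a reproducibility check, not for
# security.
def _example_hash_ndarray():
    a = np.arange(6, dtype=np.float64).reshape(2, 3)
    b = a.copy()
    assert hash_ndarray(a) == hash_ndarray(b)   # identical bytes, identical hash
    b[0, 0] = -1.0
    assert hash_ndarray(a) != hash_ndarray(b)   # any change alters the digest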

## Function for Hashing tensorflow models
def hash_weights(model):
    # Extract all weights and biases
    weights = model.get_weights()

    # Convert each weight array to a string
    weight_str = ''.join([np.array2string(w, separator=',') for w in weights])

    # Generate an MD5 hash of the combined string
    weight_hash = hashlib.md5(weight_str.encode('utf-8')).hexdigest()

    return weight_hash

## Generic function to hash dictionary of various types

## Top-level hash function; uses the built-in hash for str, float, int, etc.
@singledispatch
def hash2(x):
    return hash(x)

## Hash numpy array: hash the array with pandas and return the integer sum
@hash2.register(np.ndarray)
def _(x):
    # return hash(x.tobytes())
    return np.sum(pd.util.hash_array(x))

## Hash list: convert to tuple
@hash2.register(list)
def _(x):
    return hash2(tuple(x))

## Hash tuple: sum the hashes of the elements
@hash2.register(tuple)
def _(x):
    r = 0
    for item in x:
        r += hash2(item)
    return r

## Hash dict: loop through keys and hash each element via dispatch; return hash of the integer sum
@hash2.register(dict)
def _(x, keys=None, verbose=False):
    r = 0  # return value integer
    if keys is None:  # allow user input of keys to hash, otherwise hash them all
        keys = [*x.keys()]
        keys.sort()
    for key in keys:
        if verbose:
            print('Hashing', key)
        r += hash2(x[key])
    return hash(r)
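
# Illustrative usage sketch: hash2 dispatches on type, so nested
# dict/list/array structures can be hashed with one call. Note that Python's
# built-in hash() of str is salted per process, so results involving string
# values differ across interpreter runs.
def _example_hash2():
    d = {'params': [1, 2, 3], 'data': np.arange(4), 'n': 7}
    print(hash2(d))                      # hash of the whole dictionary
    print(hash2(d, keys=['n', 'data']))  # hash only a chosen subset of keys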

def print_args(func, *args, **kwargs):
    # Wrapper to trace function call and arguments
    print(f"Called: {func.__name__}")
    print("Arguments:")
    for arg in args:
        print(f"  {arg}")
    for key, value in kwargs.items():
        print(f"  {key}={value}")
    return func(*args, **kwargs)

def print_args_test():
    def my_function(a, b):
        # some code here
        return a + b

    print_args(my_function, a=1, b=2)

def get_item(dict, var, **kwargs):
    if var in dict:
        value = dict[var]
    elif 'default' in kwargs:
        value = kwargs['default']
    else:
        logging.error('Variable %s not in the dictionary and no default', var)
        raise NameError()
    logging.info('%s = %s', var, value)
    return value
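
# Illustrative usage sketch: the dictionary and keys below are hypothetical.
def _example_get_item():
    params = {'learning_rate': 0.001}
    lr = get_item(params, 'learning_rate')           # present: returns 0.001
    epochs = get_item(params, 'epochs', default=10)  # absent: falls back to 10
    print(lr, epochs)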

def print_first(item_list, num=3, indent=0, id=None):
    """
    Print the first num items of the list followed by '...'

    :param item_list: List of items to be printed
    :param num: number of items to list
    """
    indent_str = ' ' * indent
    if id is not None:
        print(indent_str, id)
    if len(item_list) > 0:
        print(indent_str, type(item_list[0]))
    for i in range(min(num, len(item_list))):
        print(indent_str, item_list[i])
    if len(item_list) > num:
        print(indent_str, '...')

def print_dict_summary(d, indent=0, first=[], first_num=3):
    """
    Prints a summary for each array in the dictionary, showing the key and the size of the array.

    Arguments:
        d (dict): The dictionary to summarize.
        first (list): Print the first items for any arrays with these names.
    """
    indent_str = ' ' * indent
    for key, value in d.items():
        # Check if the value is list-like using a simple method check
        if isinstance(value, dict):
            print(f"{indent_str}{key}")
            print_dict_summary(value, first=first, indent=indent + 5, first_num=first_num)
        elif isinstance(value, np.ndarray):
            if np.issubdtype(value.dtype, np.number):
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Handle non-numeric arrays differently
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):  # Check for iterable that is not a string
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str, key, ":", value)
        if key in first:
            print_first(value, num=first_num, indent=indent + 5)
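
# Illustrative usage sketch: summarize a small nested dictionary and print
# the first entries of the (hypothetical) 'time' array.
def _example_print_dict_summary():
    d = {'case1': {'fm': np.random.rand(5), 'time': list(range(5))}}
    print_dict_summary(d, first=['time'], first_num=2)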

def str2time(input):
    """
    Convert a single string timestamp or a list of string timestamps to corresponding datetime object(s).
    """
    if isinstance(input, str):
        return datetime.strptime(input.replace('Z', '+00:00'), '%Y-%m-%dT%H:%M:%S%z')
    elif isinstance(input, list):
        return [str2time(s) for s in input]
    else:
        raise ValueError("Input must be a string or a list of strings")

# Interpolate linearly over nans

def filter_nan_values(t1, v1):
    # Filter out NaN values from v1 and corresponding times in t1
    valid_indices = ~np.isnan(v1)  # Indices where v1 is not NaN
    t1_filtered = np.array(t1)[valid_indices]
    v1_filtered = np.array(v1)[valid_indices]
    return t1_filtered, v1_filtered

def time_intp(t1, v1, t2):
    # Check if t1, v1, t2 are 1D arrays
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension: %s", v1.ndim)
        return None
    if t2.ndim != 1:
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    # Check if t1 and v1 have the same length
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s", len(t1), len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    # print('t1_no_nan.dtype=', t1_no_nan.dtype)
    # Convert datetime objects to timestamps
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])

    # Interpolate using the filtered data
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')

    return v2_interpolated
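
# Illustrative usage sketch: interpolate an hourly series with one NaN onto a
# half-hourly grid, building the time arrays with the str2time helper below.
def _example_time_intp():
    t1 = str2time(['2023-06-01T00:00:00Z', '2023-06-01T01:00:00Z',
                   '2023-06-01T02:00:00Z'])
    v1 = np.array([1.0, np.nan, 3.0])  # the NaN is dropped before interpolating
    t2 = str2time(['2023-06-01T00:30:00Z', '2023-06-01T01:30:00Z'])
    print(time_intp(t1, v1, t2))       # -> [1.5 2.5]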

def str2time(strlist):
    # Convert an array of strings to an array of datetime objects
    # NOTE: this redefinition shadows the str2time defined earlier in this module;
    # callers get this array-returning version.
    return np.array([datetime.strptime(dt_str, '%Y-%m-%dT%H:%M:%SZ') for dt_str in strlist])

def check_increment(datetime_array, id=''):
    # Calculate time differences between consecutive datetime values
    diffs = [b - a for a, b in zip(datetime_array[:-1], datetime_array[1:])]
    diffs_hours = np.array([diff.total_seconds() / 3600 for diff in diffs])
    # Check if all time differences are equal
    if all(diffs_hours == diffs_hours[0]):
        logging.info('%s time array increments are %s hours', id, diffs_hours[0])
        if diffs_hours[0] <= 0:
            logging.error('%s time array increments are not positive', id)
        return diffs_hours[0]
    else:
        logging.info('%s time array increments are min %s max %s', id,
                     np.min(diffs_hours), np.max(diffs_hours))
        return -1
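
# Illustrative usage sketch: a uniform hourly series returns its increment
# (1.0 here); a non-uniform series logs the min/max spacing and returns -1.
def _example_check_increment():
    times = str2time(['2023-06-01T00:00:00Z', '2023-06-01T01:00:00Z',
                      '2023-06-01T02:00:00Z'])
    print(check_increment(times, id='hourly'))  # -> 1.0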