CLEANUP - Delete old pickle files
[notebooks.git] / fmda / utils.py
blob779eb246b955ac9c333f7052a578a2cc970abad5
1 import numpy as np
2 from functools import singledispatch
3 import pandas as pd
4 import numbers
5 from datetime import datetime
6 import logging
7 import sys
8 import inspect
9 import yaml
10 import hashlib
11 import pickle
12 import os.path as osp
# Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """Download ``url`` to ``dest_path`` with ``wget`` unless the file already exists.

    :param url: source URL
    :param dest_path: local destination path
    :param force_download: if True, download even when ``dest_path`` exists
    :raises RuntimeError: if wget fails or ``dest_path`` is missing afterwards
    """
    # Local imports: these names were used here but never imported at file
    # level, so the download branch previously failed with NameError.
    import os
    import subprocess
    from urllib.parse import urlparse

    if osp.exists(dest_path) and not force_download:
        print(f"Target data already exists at {dest_path}")
        return

    target_extension = osp.splitext(dest_path)[1]
    url_extension = osp.splitext(urlparse(url).path)[1]
    if target_extension != url_extension:
        print("Warning: file extension from url does not match destination file extension")
    # Argument list, no shell: paths/URLs with spaces or shell metacharacters
    # are neither split nor interpreted.
    returncode = subprocess.call(["wget", "-O", dest_path, url])
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if returncode != 0 or not osp.exists(dest_path):
        # wget -O creates/truncates the output file even when the download
        # fails; remove it so a later call does not mistake it for cached data.
        if osp.exists(dest_path):
            os.remove(dest_path)
        raise RuntimeError(f"Download of {url} to {dest_path} failed")
    print(f"Successfully downloaded {url} to {dest_path}")
# Check whether every element of source_list is also present in target_list
def all_items_exist(source_list, target_list):
    """Return True if each item of ``source_list`` is contained in ``target_list``.

    Membership is tested with ``in`` on ``target_list``; an empty
    ``source_list`` yields True.
    """
    for candidate in source_list:
        if candidate not in target_list:
            return False
    return True
# Generic helper function to read yaml files
def read_yml(yaml_path, subkey=None):
    """Load a YAML file with ``yaml.safe_load``.

    :param yaml_path: path to the YAML file
    :param subkey: optional top-level key; when given, only ``data[subkey]``
        is returned
    :return: the parsed object, or the selected entry of it
    """
    with open(yaml_path, 'r') as stream:
        contents = yaml.safe_load(stream)
    return contents if subkey is None else contents[subkey]
# Use to load nested fmda dictionary of cases
def load_and_fix_data(filename):
    """Read a pickled FMDA training dictionary and patch it in place.

    For every case this records ``'case'`` (the case key) and ``'filename'``,
    reports NaNs in floating-point NumPy arrays and passes those arrays to
    ``fixnan``, and fills in ``'title'``/``'descr'`` when they are missing.

    :param filename: (str) path to a ``.pickle`` file
    :return: the loaded dictionary, modified as described above
    """
    # NOTE(review): `fixnan` is not defined or imported in this file; the NaN
    # branch raises NameError unless it is provided elsewhere. It appears to be
    # expected to repair the array in place (the result is re-checked via the
    # dict afterwards) -- confirm.
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        test_dict = pickle.load(handle)
    for case, case_data in test_dict.items():
        case_data['case'] = case
        case_data['filename'] = filename
        for key in case_data.keys():
            arr = case_data[key]
            if not (isinstance(arr, np.ndarray) and arr.dtype.kind == 'f'):
                continue
            nans = np.sum(np.isnan(arr))
            if not nans:
                continue
            print('WARNING: case',case,'variable',key,'shape',arr.shape,'has',nans,'nan values, fixing')
            fixnan(arr)
            nans = np.sum(np.isnan(case_data[key]))
            print('After fixing, remained',nans,'nan values')
        case_data.setdefault('title', case)
        case_data.setdefault('descr', f"{case} FMDA dictionary")
    return test_dict
# Generic helper function to read pickle files
def read_pkl(file_path):
    """Unpickle and return the object stored at ``file_path``.

    NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    """
    with open(file_path, 'rb') as stream:
        print(f"loading file {file_path}")
        return pickle.load(stream)
def logging_setup():
    """Configure root logging: INFO level, timestamped format, written to stdout.

    Delegates to ``logging.basicConfig``, which does nothing if the root
    logger already has handlers.
    """
    config = dict(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        stream=sys.stdout,
    )
    logging.basicConfig(**config)
# dtype.kind codes treated as numeric: signed int, unsigned int, float, complex
numeric_kinds = {'i', 'u', 'f', 'c'}

def is_numeric_ndarray(array):
    """Return True when ``array`` is a NumPy ndarray whose dtype kind is in ``numeric_kinds``.

    Non-ndarray inputs (lists, scalars, ...) return False; boolean arrays
    (kind 'b') are not considered numeric.
    """
    return isinstance(array, np.ndarray) and array.dtype.kind in numeric_kinds
def vprint(*args):
    """Print ``args`` space-separated, but only if the caller has a truthy local ``verbose``.

    The caller's local variables are inspected, so ``verbose`` does not have
    to be passed in. Called with no arguments (and verbose), an empty line is
    printed; previously this raised IndexError.
    """
    frame = inspect.currentframe()
    caller = None
    try:
        # currentframe() may return None on Python implementations without
        # stack-frame support; treat that as "not verbose".
        if frame is not None:
            caller = frame.f_back
        verbose = caller.f_locals.get('verbose', False) if caller is not None else False
    finally:
        # Release frame references promptly to avoid reference cycles, as the
        # inspect documentation recommends.
        del frame, caller
    if verbose:
        print(*args)
## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    """Return the MD5 hex digest of the array's raw element bytes (``arr.tobytes()``).

    Shape and dtype are not part of the hashed data, so arrays with identical
    bytes (e.g. a reshape of the same data) produce the same digest.
    """
    return hashlib.md5(arr.tobytes()).hexdigest()
## Function for Hashing tensorflow models
def hash_weights(model):
    """Return an MD5 hex digest of ``model.get_weights()``.

    Each weight array is rendered with ``np.array2string(..., separator=',')``,
    the renderings are concatenated, and the UTF-8 bytes are hashed with MD5.

    NOTE(review): array2string follows NumPy print options, which by default
    summarize large arrays with '...' and limit printed precision, so distinct
    weights can yield the same digest. Changing the rendering would change
    every digest this function produces, so it is left as-is here.
    """
    rendered = [np.array2string(layer, separator=',') for layer in model.get_weights()]
    combined = ''.join(rendered)
    return hashlib.md5(combined.encode('utf-8')).hexdigest()
## Generic function to hash dictionary of various types
# NOTE(review): the built-in hash() of str/bytes is salted per process
# (PYTHONHASHSEED), so results involving strings are not reproducible across
# interpreter runs. Containers combine element hashes by addition, so element
# order generally does not matter.
# NOTE(review): sums may mix Python ints with NumPy uint64 values (from
# pd.util.hash_array); NumPy's promotion/overflow rules for such mixes differ
# between versions -- verify if exact values matter.


@singledispatch
def hash2(x):
    """Fallback: Python's built-in ``hash`` (str, int, float, ...)."""
    return hash(x)


@hash2.register(np.ndarray)
def _hash2_ndarray(x):
    """Hash each element with ``pd.util.hash_array`` and return the sum."""
    return np.sum(pd.util.hash_array(x))


@hash2.register(list)
def _hash2_list(x):
    """A list hashes exactly like the equivalent tuple."""
    return hash2(tuple(x))


@hash2.register(tuple)
def _hash2_tuple(x):
    """Sum of ``hash2`` over the elements, starting from 0."""
    return sum(hash2(item) for item in x)


@hash2.register(dict)
def _hash2_dict(x, keys=None, verbose=False):
    """Sum ``hash2`` of the selected values and return ``hash`` of that sum.

    :param keys: keys to include, visited in the given order; defaults to all
        keys of ``x`` in sorted order
    :param verbose: print each key as it is hashed
    """
    if keys is None:
        keys = sorted(x.keys())
    total = 0
    for key in keys:
        if verbose:
            print('Hashing', key)
        total += hash2(x[key])
    return hash(total)
def print_args(func, *args, **kwargs):
    """Trace a call: print the function name and its arguments, then invoke it.

    Positional arguments are printed one per line, then keyword arguments as
    ``key=value``. Returns whatever ``func`` returns.
    """
    header = f"Called: {func.__name__}"
    print(header)
    print("Arguments:")
    for positional in args:
        print(f" {positional}")
    for name, val in kwargs.items():
        print(f" {name}={val}")
    result = func(*args, **kwargs)
    return result
def print_args_test():
    """Smoke test for print_args: trace a two-argument function called by keyword."""
    def my_function(a, b):
        return a + b

    call_kwargs = {'a': 1, 'b': 2}
    print_args(my_function, **call_kwargs)
178 import inspect
def get_item(dict,var,**kwargs):
    """Look up ``var`` in a dictionary, falling back to ``default`` if supplied.

    :param dict: mapping to search (parameter name kept for backward
        compatibility even though it shadows the built-in)
    :param var: key to look up
    :param kwargs: optional ``default=`` value used when ``var`` is missing
    :return: the found (or default) value, which is also logged at INFO level
    :raises NameError: when ``var`` is missing and no default was given
    """
    if var in dict:
        value = dict[var]
    elif 'default' in kwargs:
        value = kwargs['default']
    else:
        logging.error('Variable %s not in the dictionary and no default',var)
        # Carry the key in the exception so the failure is diagnosable without
        # the log; the exception type is unchanged for existing callers.
        raise NameError(f"Variable {var} not in the dictionary and no default")
    logging.info('%s = %s',var,value)
    return value
def print_first(item_list,num=3,indent=0,id=None):
    """
    Print the first ``num`` items of ``item_list``, then '...' if more remain.

    :param item_list: sequence supporting ``len`` and integer indexing
    :param num: number of items to print
    :param indent: number of spaces of indentation
    :param id: optional label printed before the items
    """
    pad = ' ' * indent
    if id is not None:
        print(pad, id)
    total = len(item_list)
    if total > 0:
        # the type of the first element stands in for the whole list
        print(pad, type(item_list[0]))
    for position in range(min(num, total)):
        print(pad, item_list[position])
    if total > num:
        print(pad, '...')
def print_dict_summary(d,indent=0,first=None,first_num=3):
    """
    Print a one-line summary for each entry of ``d``, recursing into nested dicts.

    NumPy arrays are shown with their shape plus min/max (numeric, non-empty)
    or dtype (otherwise); other non-string iterables with their length; any
    other value is printed directly.

    Arguments:
        d (dict): The dictionary to summarize.
        indent (int): Spaces of indentation for this level; nested dicts add 5.
        first (list): Keys whose values additionally get their first items
            printed via ``print_first``. Defaults to no keys.
        first_num (int): Number of items ``print_first`` shows.
    """
    # None instead of a shared mutable [] default.
    if first is None:
        first = []
    indent_str = ' ' * indent
    for key, value in d.items():
        if isinstance(value, dict):
            print(f"{indent_str}{key}")
            print_dict_summary(value,first=first,indent=indent+5,first_num=first_num)
        elif isinstance(value,np.ndarray):
            # min()/max() raise ValueError on empty arrays, so empty numeric
            # arrays fall through to the shape/dtype summary.
            if np.issubdtype(value.dtype, np.number) and value.size > 0:
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):
            # other non-string iterables: report the length only
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str,key,":",value)
        if key in first:
            print_first(value,num=first_num,indent=indent+5)
236 from datetime import datetime
def str2time(input):
    """
    Convert a single string timestamp or a list of string timestamps to corresponding datetime object(s).

    Strings are parsed with '%Y-%m-%dT%H:%M:%S%z' after replacing 'Z' with
    '+00:00', so results are timezone-aware.

    NOTE(review): a second ``str2time`` defined later in this module rebinds
    the name, so this version is not what module-level callers get. The list
    branch also recurses through the module-level name, i.e. into that later
    definition.
    """
    if isinstance(input, list):
        return [str2time(item) for item in input]
    if not isinstance(input, str):
        raise ValueError("Input must be a string or a list of strings")
    normalized = input.replace('Z', '+00:00')
    return datetime.strptime(normalized, '%Y-%m-%dT%H:%M:%S%z')
# interpolate linearly over nans

def filter_nan_values(t1, v1):
    """Drop entries where ``v1`` is NaN; return the matching ``(t1, v1)`` arrays."""
    valid_indices = ~np.isnan(v1)  # True where v1 holds a number
    t1_filtered = np.array(t1)[valid_indices]
    v1_filtered = np.array(v1)[valid_indices]
    return t1_filtered, v1_filtered


def _to_timestamp(t):
    """POSIX timestamp of a datetime, interpreting naive datetimes as UTC.

    ``datetime.timestamp()`` treats naive values as local time, which makes
    spacing depend on the machine's timezone and DST transitions.
    """
    from datetime import timezone
    if t.tzinfo is None:
        t = t.replace(tzinfo=timezone.utc)
    return t.timestamp()


def time_intp(t1, v1, t2):
    """Linearly interpolate the series ``(t1, v1)`` onto the times ``t2``.

    NaN entries of ``v1`` (and their times) are dropped before interpolating.

    :param t1: 1D array of datetime objects, sample times of ``v1``
        (np.interp expects these increasing; not checked here)
    :param v1: 1D array of float values, same length as ``t1``
    :param t2: 1D array of datetime objects to interpolate to
    :return: 1D array of interpolated values, or None (after logging an error)
        if an input is not 1D, the lengths of ``t1``/``v1`` differ, or ``v1``
        has no non-NaN values
    """
    if t1.ndim != 1:
        logging.error("Error: t1 is not a 1D array. Dimension: %s", t1.ndim)
        return None
    if v1.ndim != 1:
        logging.error("Error: v1 is not a 1D array. Dimension: %s", v1.ndim)
        return None
    if t2.ndim != 1:
        # was `logging.errorr` (typo), which raised AttributeError
        logging.error("Error: t2 is not a 1D array. Dimension: %s", t2.ndim)
        return None
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s",len(t1),len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    if len(v1_no_nan) == 0:
        # np.interp raises ValueError on an empty set of sample points
        logging.error('time_intp: v1 has no non-NaN values to interpolate')
        return None
    t1_stamps = np.array([_to_timestamp(t) for t in t1_no_nan])
    t2_stamps = np.array([_to_timestamp(t) for t in t2])
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')
    return v2_interpolated
def str2time(strlist):
    """Parse 'YYYY-MM-DDTHH:MM:SSZ' timestamps into naive datetime objects.

    :param strlist: an iterable of timestamp strings, or a single string
    :return: NumPy array of datetimes for an iterable; a single datetime for a
        single string (a bare string was previously iterated character by
        character and failed to parse)

    The trailing 'Z' is matched literally, so the results carry no tzinfo
    even though 'Z' denotes UTC.
    """
    fmt = '%Y-%m-%dT%H:%M:%SZ'
    if isinstance(strlist, str):
        return datetime.strptime(strlist, fmt)
    return np.array([datetime.strptime(dt_str, fmt) for dt_str in strlist])
def check_increment(datetime_array,id=''):
    """Check whether consecutive datetimes are evenly spaced.

    :param datetime_array: sequence of datetime objects
    :param id: label used in log messages
    :return: the common increment in hours if all consecutive differences are
        equal (an error is logged if it is not positive); -1 if the increments
        vary or fewer than two timestamps are given
    """
    if len(datetime_array) < 2:
        # No consecutive pairs; indexing diffs_hours[0] would raise IndexError.
        logging.error('%s time array has fewer than 2 values, no increment',id)
        return -1
    # Time differences between consecutive datetime values, in hours
    diffs = [b - a for a, b in zip(datetime_array[:-1], datetime_array[1:])]
    diffs_hours = np.array([diff.total_seconds()/3600 for diff in diffs])
    # Check if all time differences are exactly the same
    if all(diffs_hours == diffs_hours[0]):
        logging.info('%s time array increments are %s hours',id,diffs_hours[0])
        if diffs_hours[0] <= 0 :
            logging.error('%s time array increments are not positive',id)
        return diffs_hours[0]
    logging.info('%s time array increments are min %s max %s',id,
                 np.min(diffs_hours),np.max(diffs_hours))
    return -1