Update moisture_rnn.py
[notebooks.git] / fmda / utils.py
blobbae477fdf93bcdcca7b4e767445b757fa8ec7252
1 import numpy as np
2 from functools import singledispatch
3 import pandas as pd
4 import numbers
5 from datetime import datetime
6 import logging
7 import sys
8 import inspect
9 import yaml
10 import hashlib
11 import pickle
12 import os.path as osp
13 from urllib.parse import urlparse
14 import subprocess
16 from itertools import islice
class _MissingAttributeKey(AttributeError, KeyError):
    """
    Raised by ``Dict.__getattr__`` when the requested name is not a key.

    It subclasses AttributeError so that ``hasattr``, ``getattr(obj, name, default)``
    and the copy/pickle protocol probes treat the name as absent. It also subclasses
    KeyError so that existing ``except KeyError`` handlers around attribute access
    keep working.
    """


class Dict(dict):
    """
    A dictionary that allows member access to its keys.
    A convenience class.

    Keys may also be ``range`` or ``tuple`` objects. If an item is not itself a
    key, looking it up returns the value stored under the first range/tuple key
    that contains it, e.g. ``Dict({range(0, 3): 'a'})[1] == 'a'``.
    """

    def __init__(self, d):
        """
        Updates itself with d.
        """
        self.update(d)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. A missing key must
        # surface as AttributeError, otherwise hasattr(), getattr() with a
        # default and copy.deepcopy()/pickle (which probe optional dunders)
        # crash with KeyError.
        try:
            return self[item]
        except KeyError:
            raise _MissingAttributeKey(item) from None

    def __setattr__(self, item, value):
        # Attribute assignment stores a dictionary entry instead.
        self[item] = value

    def __getitem__(self, item):
        if item in self:
            return super().__getitem__(item)
        # Fall back to range/tuple keys that contain the requested item.
        for key in self:
            if isinstance(key, (range, tuple)) and item in key:
                return super().__getitem__(key)
        raise KeyError(item)

    def keys(self):
        """
        Return the keys, with every range/tuple key expanded into its members.

        If no range/tuple keys are present, this returns the ordinary
        ``dict.keys()`` view. Otherwise it returns a flat list.
        """
        if any(isinstance(key, (range, tuple)) for key in self):
            keys = []
            for key in self:
                if isinstance(key, (range, tuple)):
                    keys.extend(key)
                else:
                    keys.append(key)
            return keys
        return super().keys()
def save(obj, file):
    """
    Serialize ``obj`` into the file at path ``file`` (written in binary mode).

    Uses ``dill`` when it is installed, because it can serialize more object
    types than pickle. Otherwise it falls back to the standard ``pickle`` module.
    """
    # dill was referenced here without ever being imported, so every call
    # failed with NameError; import it locally and degrade to pickle.
    try:
        import dill as serializer
    except ImportError:
        import pickle as serializer
    with open(file, 'wb') as output:
        serializer.dump(obj, output)
def load(file):
    """
    Deserialize and return the object stored in the file at path ``file``.

    Uses ``dill`` when it is installed, otherwise the standard ``pickle`` module.
    Only load files from trusted sources, because unpickling can execute
    arbitrary code.
    """
    # dill was referenced here without ever being imported; import it locally
    # and fall back to pickle so plain pickles remain loadable.
    try:
        import dill as serializer
    except ImportError:
        import pickle as serializer
    with open(file, 'rb') as handle:
        return serializer.load(handle)
# Utility to retrieve files from URL
def retrieve_url(url, dest_path, force_download=False):
    """
    Downloads a file from a specified URL to a destination path.

    Parameters:
    -----------
    url : str
        The URL from which to download the file.
    dest_path : str
        The destination path where the file should be saved.
    force_download : bool, optional
        If True, forces the download even if the file already exists at the destination path.
        Default is False.

    Warnings:
    ---------
    Prints a warning if the file extension of the URL does not match the destination file extension.

    Raises:
    -------
    AssertionError:
        If wget exits with a non-zero status (including when wget is not installed)
        or the file does not exist at the destination path afterwards. On a failed
        download, any file wget left at the destination is removed, so a later call
        does not mistake it for existing data.

    Notes:
    ------
    This function uses the `wget` command-line tool to download the file. Ensure that `wget` is
    installed and accessible from the system's PATH. The command is run without a shell,
    so spaces or shell metacharacters in `url`/`dest_path` are passed through literally.

    Prints:
    -------
    A message indicating whether the file was downloaded or if it already exists at the
    destination path.
    """
    import os

    if osp.exists(dest_path) and not force_download:
        print(f"Target data already exists at {dest_path}")
        return

    target_extension = osp.splitext(dest_path)[1]
    url_extension = osp.splitext(urlparse(url).path)[1]
    if target_extension != url_extension:
        print("Warning: file extension from url does not match destination file extension")

    # Argument list instead of an interpolated shell string: no shell injection,
    # and paths/URLs containing spaces work.
    try:
        returncode = subprocess.call(["wget", "-O", dest_path, url])
    except FileNotFoundError:
        returncode = 127  # wget not on PATH; the status a shell would report

    # `wget -O` can create dest_path even when the download fails, so existence
    # alone does not prove success. Explicit raise (not assert) so the check
    # also runs under `python -O`.
    if returncode != 0 or not osp.exists(dest_path):
        if osp.exists(dest_path):
            os.remove(dest_path)
        raise AssertionError(
            f"Failed to download {url} to {dest_path} (wget exit status {returncode})"
        )
    print(f"Successfully downloaded {url} to {dest_path}")
# Check whether every element of one list also appears in another list
def all_items_exist(source_list, target_list):
    """
    Return True when each element of ``source_list`` is found in ``target_list``.

    Membership is tested with ``in`` for each element in turn. The scan stops
    at the first element that is missing. An empty ``source_list`` yields True.

    Parameters:
    -----------
    source_list : list
        Elements whose presence should be verified.
    target_list : list
        Collection searched for each element of ``source_list``.

    Returns:
    --------
    bool
        True if nothing from ``source_list`` is missing from ``target_list``, else False.

    Example:
    --------
    >>> all_items_exist([1, 2, 3], [1, 2, 3, 4, 5])
    True
    >>> all_items_exist([1, 2, 6], [1, 2, 3, 4, 5])
    False
    """
    for candidate in source_list:
        if candidate not in target_list:
            return False
    return True
# Generic helper function to read YAML files
def read_yml(yaml_path, subkey=None):
    """
    Parse a YAML file with ``yaml.safe_load``, optionally returning only one entry.

    Parameters:
    -----------
    yaml_path : str
        Location of the YAML file.
    subkey : str, optional
        If given, the loaded object is indexed with this key and only that
        value is returned. If omitted (None), the whole parsed document is returned.

    Returns:
    --------
    dict or any
        The parsed document, or the value stored under ``subkey``.
    """
    with open(yaml_path, 'r') as stream:
        loaded = yaml.safe_load(stream)
    return loaded if subkey is None else loaded[subkey]
# Load a nested FMDA dictionary of cases and patch up NaN values
def load_and_fix_data(filename):
    """
    Read a pickled FMDA training dictionary and return it after in-place cleanup.

    For every case the record gains ``'case'`` (the case key) and ``'filename'``
    (the path passed in). It also gains default ``'title'`` and ``'descr'``
    entries if those are absent. Each floating-point ndarray containing NaNs is
    passed to ``fixnan``, with warnings printed before and after. ``fixnan`` is
    not defined in this block; it is presumably expected to repair the array
    in place (TODO confirm).

    Parameters:
        filename (str): path to a .pickle file holding the case dictionary.

    Returns:
        The loaded dictionary of cases, modified as described above.

    Only use trusted files, because unpickling can execute arbitrary code.
    """
    print(f"loading file {filename}")
    with open(filename, 'rb') as handle:
        cases = pickle.load(handle)
    for case_name in cases:
        record = cases[case_name]
        record['case'] = case_name
        record['filename'] = filename
        for key in record.keys():
            value = record[key]
            is_float_array = isinstance(value, np.ndarray) and (value.dtype.kind == 'f')
            if not is_float_array:
                continue
            n_nan = np.sum(np.isnan(value))
            if not n_nan:
                continue
            print('WARNING: case', case_name, 'variable', key, 'shape', value.shape, 'has', n_nan, 'nan values, fixing')
            fixnan(value)
            n_nan = np.sum(np.isnan(record[key]))
            print('After fixing, remained', n_nan, 'nan values')
        if 'title' not in record.keys():
            record['title'] = case_name
        if 'descr' not in record.keys():
            record['descr'] = f"{case_name} FMDA dictionary"
    return cases
# Generic helper function to read pickle files
def read_pkl(file_path):
    """
    Open a pickle file and return the object it contains.

    Parameters:
    -----------
    file_path : str
        Location of the pickle file.

    Returns:
    --------
    object
        Whatever object was pickled into the file.

    Prints:
    -------
    "loading file <file_path>" once the file has been opened.

    Notes:
    ------
    Deserialization uses Python's ``pickle`` module, which can run arbitrary
    code. Only read files produced in a safe, trusted environment.
    """
    with open(file_path, 'rb') as stream:
        print(f"loading file {file_path}")
        return pickle.load(stream)
def logging_setup():
    """
    Configure the root logger via ``logging.basicConfig``.

    Sets level INFO, a ``time - level - message`` format, and output to stdout.
    ``basicConfig`` does nothing if the root logger already has handlers.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(stream=sys.stdout, format=log_format, level=logging.INFO)
# NumPy dtype.kind codes treated as numeric: signed int, unsigned int, float, complex
numeric_kinds = {'i', 'u', 'f', 'c'}


def is_numeric_ndarray(array):
    """Return True only for a NumPy ndarray whose dtype kind is listed in ``numeric_kinds``."""
    return isinstance(array, np.ndarray) and array.dtype.kind in numeric_kinds
def vprint(*args):
    """
    Print ``args`` separated by spaces, but only when the calling function has
    a local variable named ``verbose`` whose value is truthy.

    This lets verbose-aware functions call ``vprint(...)`` without passing
    ``verbose`` explicitly. When called with no arguments, an empty line is
    printed (previously this raised IndexError).
    """
    import inspect

    frame = inspect.currentframe()
    # currentframe() may return None on Python implementations without
    # stack-frame support; treat that as "not verbose".
    caller = frame.f_back if frame is not None else None
    try:
        verbose = False
        if caller is not None and 'verbose' in caller.f_locals:
            verbose = caller.f_locals['verbose']
    finally:
        # Drop frame references promptly to avoid reference cycles (see inspect docs).
        del frame, caller

    if verbose:
        # Same output as printing each arg followed by ' ' and the last with '\n'.
        print(*args)
## Function for Hashing numpy arrays
def hash_ndarray(arr: np.ndarray) -> str:
    """
    Generates an MD5 hash string for a NumPy ndarray.

    Parameters:
    -----------
    arr : np.ndarray, or list/tuple of arrays
        The NumPy array to be hashed. A list or tuple of arrays is first
        stacked with ``np.vstack``, and the stacked array is hashed.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the array.

    Notes:
    ------
    Only the raw bytes from ``arr.tobytes()`` (C order) are hashed. Shape and
    dtype are not included, so arrays with the same bytes but different shapes
    hash identically. Performance might be bad for very large arrays.

    Example:
    --------
    >>> hash_value = hash_ndarray(np.array([1, 2, 3]))
    >>> len(hash_value)
    32
    """
    # A list (or tuple) of arrays is concatenated row-wise before hashing.
    if isinstance(arr, (list, tuple)):
        arr = np.vstack(arr)
    return hashlib.md5(arr.tobytes()).hexdigest()
## Function for Hashing tensorflow models
def hash_weights(model):
    """
    Generates an MD5 hash string for the weights of a given Keras model.

    Parameters:
    -----------
    model : A keras model
        The model to be hashed. Any object whose ``get_weights()`` returns a
        list of NumPy arrays works.

    Returns:
    --------
    str
        A hexadecimal string representing the MD5 hash of the model weights.

    Notes:
    ------
    Each weight array is rendered with ``np.array2string`` and the renderings
    are concatenated before hashing.
    - Formatting options are passed explicitly (NumPy's default values), so the
      hash does not depend on global ``np.set_printoptions`` state.
    - ``threshold=sys.maxsize`` disables summarisation. By default, arrays with
      more than 1000 elements print their interior as ``...``, so weights that
      differed only there produced identical hashes.
    - Values are rendered with 8 digits of precision, so smaller differences do
      not change the hash.
    For arrays of at most 1000 elements, the hash matches the previous implementation.
    """
    # Extract all weights and biases
    weights = model.get_weights()

    # Render each weight array deterministically and in full
    weight_str = ''.join(
        np.array2string(
            w,
            max_line_width=75,
            precision=8,
            suppress_small=False,
            separator=',',
            threshold=sys.maxsize,
            floatmode='maxprec',
            sign='-',
        )
        for w in weights
    )

    # Generate an MD5 hash of the combined string
    return hashlib.md5(weight_str.encode('utf-8')).hexdigest()
## Generic function to hash dictionary of various types
## Top level hash function with built-in hash function for str, float, int, etc
@singledispatch
def hash2(x):
    """
    Hash ``x`` to an integer, dispatching on its type.

    Types without a registered implementation use the built-in ``hash``.
    ``hash`` of str/bytes is randomized per interpreter process unless
    PYTHONHASHSEED is fixed, so such results are only comparable within one
    process.
    """
    return hash(x)


## Hash numpy array: hash elements with pandas and return their integer sum
@hash2.register(np.ndarray)
def _(x):
    # ravel() lets pandas hash N-D object/string arrays too; numeric N-D
    # arrays give the same element hashes either way. Convert the np.uint64
    # sum to a Python int. Otherwise adding it to negative Python ints (e.g.
    # hash() of a str) raises OverflowError on NumPy 2, or silently promotes
    # to float64 on NumPy 1.x.
    return int(np.sum(pd.util.hash_array(x.ravel())))


## Hash list: convert to tuple
@hash2.register(list)
def _(x):
    return hash2(tuple(x))


## Hash tuple: sum of element hashes (order-insensitive)
@hash2.register(tuple)
def _(x):
    return sum(hash2(item) for item in x)


## Hash dict: loop through keys and hash each element via dispatch. Return hash of the summed hashes
@hash2.register(dict)
def _(x, keys = None, verbose = False):
    r = 0  # return value integer
    if keys is None:  # allow user input of keys to hash, otherwise hash them all
        keys = sorted(x.keys())
    for key in keys:
        if verbose:
            print('Hashing', key)
        r += hash2(x[key])
    return hash(r)
def print_args(func, *args, **kwargs):
    """
    Trace a call: print the function name and its arguments, then call it.

    Positional arguments are printed one per line, followed by keyword
    arguments as ``name=value``. Returns whatever ``func(*args, **kwargs)`` returns.
    """
    def _trace_lines():
        yield f"Called: {func.__name__}"
        yield "Arguments:"
        for positional in args:
            yield f" {positional}"
        for name, val in kwargs.items():
            yield f" {name}={val}"

    for line in _trace_lines():
        print(line)
    return func(*args, **kwargs)
def print_args_test():
    """Demonstrate ``print_args`` by tracing a two-argument addition; the result is discarded."""
    def my_function(a, b):
        return a + b

    print_args(my_function, a=1, b=2)
import inspect
def get_item(dict,var,**kwargs):
    """
    Return ``dict[var]``, or ``kwargs['default']`` when ``var`` is absent.

    The chosen value is logged at INFO level.

    Raises
    ------
    NameError
        If ``var`` is not in ``dict`` and no ``default`` keyword was given.
        NameError (rather than KeyError) is what existing callers may catch.
    """
    if var in dict:
        value = dict[var]
    elif 'default' in kwargs:
        value = kwargs['default']
    else:
        logging.error('Variable %s not in the dictionary and no default', var)
        # Carry a message so the failure is diagnosable from the traceback alone.
        raise NameError(f'Variable {var} not in the dictionary and no default')
    logging.info('%s = %s', var, value)
    return value
def print_first(item_list, num=3, indent=0, id=None):
    """
    Print the first ``num`` items of a list, followed by '...' if it is longer.

    :param item_list: indexable sequence of items to show
    :param num: number of items to list
    :param indent: number of spaces prefixed to every printed line
    :param id: optional label printed before the items
    """
    pad = ' ' * indent
    if id is not None:
        print(pad, id)
    total = len(item_list)
    if total > 0:
        # Show the type of the first element once as a header.
        print(pad, type(item_list[0]))
    for position in range(min(num, total)):
        print(pad, item_list[position])
    if total > num:
        print(pad, '...')
def print_dict_summary(d,indent=0,first=None,first_num=3):
    """
    Prints a one-line summary for each entry of the dictionary, recursing into nested dicts.

    Numeric NumPy arrays show their shape, min and max. Other or empty arrays
    show their shape and dtype. Other non-string iterables show their length.
    Everything else is printed as ``key : value``.

    Arguments:
        d (dict): The dictionary to summarize.
        indent (int): Spaces to indent this level; nested dicts are indented 5 more.
        first (list, optional): Keys whose values should additionally have their
            first items printed via print_first. Defaults to no keys.
        first_num (int): How many items print_first shows for those keys.
    """
    if first is None:  # avoid a shared mutable default argument
        first = []
    indent_str = ' ' * indent
    for key, value in d.items():
        if isinstance(value, dict):
            print(f"{indent_str}{key}")
            print_dict_summary(value, first=first, indent=indent+5, first_num=first_num)
        elif isinstance(value, np.ndarray):
            # min()/max() raise on zero-size arrays, so empty numeric arrays
            # are reported with the dtype form instead.
            if np.issubdtype(value.dtype, np.number) and value.size > 0:
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, min: {value.min()}, max: {value.max()}")
            else:
                # Handle non-numeric (or empty) arrays differently
                print(f"{indent_str}{key}: NumPy array of shape {value.shape}, type {value.dtype}")
        elif hasattr(value, "__iter__") and not isinstance(value, str):  # Check for iterable that is not a string
            print(f"{indent_str}{key}: Array of {len(value)} items")
        else:
            print(indent_str, key, ":", value)
        if key in first:
            print_first(value, num=first_num, indent=indent+5)
from datetime import datetime

def str2time(input):
    """
    Parse one ``YYYY-MM-DDTHH:MM:SS`` timestamp string (``Z`` or a numeric UTC
    offset), or a list of them, into timezone-aware datetime object(s).

    A trailing ``Z`` is rewritten to ``+00:00`` so the result carries UTC tzinfo.
    Lists are converted element by element and a list is returned. Any other
    input type raises ValueError.

    NOTE(review): a second ``def str2time`` later in this module rebinds this
    name. That later version is the one callers get, and the recursive call
    below resolves to it at run time.
    """
    if isinstance(input, list):
        return [str2time(entry) for entry in input]
    if not isinstance(input, str):
        raise ValueError("Input must be a string or a list of strings")
    normalized = input.replace('Z', '+00:00')
    return datetime.strptime(normalized, '%Y-%m-%dT%H:%M:%S%z')
# Drop NaN samples before interpolation
def filter_nan_values(t1, v1):
    """
    Remove NaN entries of ``v1`` together with the matching entries of ``t1``.

    Returns a pair of new NumPy arrays ``(times, values)`` containing only the
    positions where ``v1`` is not NaN.
    """
    values = np.array(v1)
    keep = ~np.isnan(values)
    return np.array(t1)[keep], values[keep]
def time_intp(t1, v1, t2):
    """
    Linearly interpolate the series ``(t1, v1)`` onto the times ``t2``.

    NaN entries of ``v1`` (and their times) are removed with
    ``filter_nan_values`` before interpolating. Times are converted with
    ``datetime.timestamp()``, so ``t1`` and ``t2`` must hold datetime objects.
    ``np.interp`` assumes the (filtered) ``t1`` times are increasing and does
    not check this. Callers presumably pass sorted times (TODO confirm).

    Parameters
    ----------
    t1 : numpy.ndarray
        1-D array of datetimes for the source samples.
    v1 : numpy.ndarray
        1-D array of values at ``t1``; same length as ``t1``.
    t2 : numpy.ndarray
        1-D array of datetimes to interpolate onto.

    Returns
    -------
    numpy.ndarray or None
        Interpolated values at ``t2``. Returns None (after logging an error)
        when an input is not 1-D, when ``t1`` and ``v1`` differ in length, or
        when ``v1`` has no non-NaN values.
    """
    # Check if t1 v1 t2 are 1D arrays
    for name, arr in (('t1', t1), ('v1', v1), ('t2', t2)):
        if arr.ndim != 1:
            logging.error("Error: %s is not a 1D array. Dimension: %s", name, arr.ndim)
            return None
    # Check if t1 and v1 have the same length
    if len(t1) != len(v1):
        logging.error("Error: t1 and v1 have different lengths: %s %s", len(t1), len(v1))
        return None
    t1_no_nan, v1_no_nan = filter_nan_values(t1, v1)
    if len(t1_no_nan) == 0:
        # np.interp raises ValueError on empty sample points; report it like the other input errors
        logging.error('time_intp: v1 has no non-NaN values to interpolate')
        return None
    # Convert datetime objects to timestamps
    t1_stamps = np.array([t.timestamp() for t in t1_no_nan])
    t2_stamps = np.array([t.timestamp() for t in t2])

    # Interpolate using the filtered data
    v2_interpolated = np.interp(t2_stamps, t1_stamps, v1_no_nan)
    if np.isnan(v2_interpolated).any():
        logging.error('time_intp: interpolated output contains NaN')

    return v2_interpolated
def str2time(strlist):
    """
    Convert ``YYYY-MM-DDTHH:MM:SSZ`` timestamp strings to datetime objects.

    The trailing ``Z`` is matched literally, so the returned datetimes are
    naive (no tzinfo), even though ``Z`` denotes UTC.

    Parameters
    ----------
    strlist : iterable of str, or str
        Timestamps to parse. A single string is parsed directly instead of
        being iterated character by character (which previously failed with a
        confusing ValueError).

    Returns
    -------
    numpy.ndarray of datetime (object dtype) for an iterable input, or a
    single datetime for a str input.

    Note: this definition rebinds the name ``str2time`` defined earlier in
    this module, so this is the version callers get.
    """
    fmt = '%Y-%m-%dT%H:%M:%SZ'
    if isinstance(strlist, str):
        return datetime.strptime(strlist, fmt)
    # Convert array of strings to array of datetime objects
    return np.array([datetime.strptime(dt_str, fmt) for dt_str in strlist])
def check_increment(datetime_array,id=''):
    """
    Check whether consecutive times in ``datetime_array`` are evenly spaced.

    Parameters
    ----------
    datetime_array : sequence of datetime
        Times to inspect; supports slicing and ``len``.
    id : str, optional
        Label included in log messages.

    Returns
    -------
    float or int
        The common increment in hours if all consecutive differences are
        equal; an error is logged if it is not positive. Returns -1 if the
        increments vary, or if fewer than two times are given, because no
        increment is defined then.
    """
    if len(datetime_array) < 2:
        # With 0 or 1 entries there are no differences; indexing them would raise IndexError.
        logging.error('%s time array has fewer than 2 entries, increment undefined', id)
        return -1
    # Time differences between consecutive values, in hours
    diffs_hours = np.array([(later - earlier).total_seconds() / 3600
                            for earlier, later in zip(datetime_array[:-1], datetime_array[1:])])
    step = diffs_hours[0]
    # Check whether all time differences equal the first one
    if np.all(diffs_hours == step):
        logging.info('%s time array increments are %s hours', id, step)
        if step <= 0:
            logging.error('%s time array increements are not positive', id)
        return step
    logging.info('%s time array increments are min %s max %s', id,
                 np.min(diffs_hours), np.max(diffs_hours))
    return -1