## Set of Functions to process and format fuel moisture model inputs
## These functions are specific to the particulars of the input data, and may not be generally applicable
## Generally applicable functions should be in utils.py
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
import numpy as np, random
from numpy.random import rand
import tensorflow as tf
import pickle, os
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from moisture_models import model_decay, model_moisture
from datetime import datetime, timedelta
from utils import is_numeric_ndarray, hash2
import json
import copy
import subprocess
import os.path as osp
from utils import Dict, str2time, check_increment, time_intp, read_pkl
import warnings
# New Dict Functions as of 2024-10-2, needs more testing
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
feature_types = {
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew'],
    # Features that require calculation. NOTE: rain only calculated in HRRR, not RAWS
    'engineered': ['doy', 'hod', 'rain']
}
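
# For HRRR data, build_features_single below stacks features in the order
# time-derived, static, atmospheric (plus calculated rain), giving:
#   ['doy', 'hod', 'elev', 'lon', 'lat', 'temp', 'rh', 'wind', 'solar',
#    'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew', 'rain']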

def build_train_dict(input_file_paths, params_data, spatial=True, atm_source="HRRR", forecast_step=3, verbose=True, features_subset=None, drop_na=True):
    # TODO: process involves multiple loops over dictionary keys. Inefficient, but functionality is isolated in separate functions that run their own loops, so it will be tedious to fix

    # Define forecast step for atmospheric data
    # If atm_source is RAWS, the forecast hour is not used AND, if spatial, a features_subset input is required to avoid incompatibility when building the training data
    # For HRRR data, the previous step is hard-coded as f00, meaning rain is calculated as the difference between the forecast hour and the 0th hour
    if atm_source == "RAWS":
        print("Atmospheric data source is RAWS, so forecast_step is not used")
        forecast_step = 0 # set to 0 for future compatibility
        fstep = fprev = None
        if spatial:
            assert features_subset is not None, "For RAWS atmospheric data as the source for a spatial training set, argument features_subset cannot be None. Provide a list of features to subset the RAWS locations, otherwise there will be errors when trying to build models with certain features."
    elif atm_source == "HRRR":
        fstep = int2fstep(forecast_step)
        if forecast_step > 0:
            fprev = int2fstep(forecast_step - 1)
        else:
            fprev = "f00"

    # Extract hours value from data params since it might need to change based on forecast hour time shift
    hours = params_data['hours']
    if forecast_step > 0 and drop_na and hours is not None:
        hours = int(hours - forecast_step)
    # Loop over input dictionary cases, extract and calculate features, then run data filters
    new_dict = {}
    for input_file_path in input_file_paths:
        if verbose:
            print("~"*75)
            print(f"Extracting data from input file {input_file_path}")
        dict0 = read_pkl(input_file_path)
        for key in dict0:
            # Extract features from subdicts
            X, names = build_features_single(dict0[key], atm=atm_source, fstep=fstep, fprev=fprev)

            # Get time from HRRR (because it is always at regular intervals) and interpolate RAWS data to those times
            time = str2time(dict0[key]['HRRR']['time'])
            hrrr_increment = check_increment(time, id=key+' HRRR.time')
            if hrrr_increment < 1:
                # logging.critical('HRRR increment is %s h must be at least 1 h',hrrr_increment)
                raise ValueError(f"HRRR increment is {hrrr_increment} h, must be at least 1 h")
            time_raws = str2time(dict0[key]['RAWS']['time_raws'])
            check_increment(time_raws, id=dict0[key]['loc']['STID']+' RAWS.time_raws')

            # Extract outcome data
            y = get_fm(dict0[key], time)

            # Shift atmospheric data in time if using a forecast step
            if forecast_step > 0:
                # Get indices to shift
                atm_names = feature_types['atm'] + ['rain']
                indices = [names.index(item) for item in atm_names if item in names]
                # Shift indices forward in time to line up the forecast with the desired valid time
                X = shift_time(X, indices, forecast_step)
                if drop_na:
                    print(f"Shifted time based on forecast step {forecast_step}. Dropping NA at beginning of feature data and corresponding times of output data")
                    X = X[forecast_step:, :]
                    y = y[forecast_step:]
                    time = time[forecast_step:]

            new_dict[key] = {
                'id': key,
                'case': key,
                'filename': input_file_path,
                'loc': dict0[key]['loc'],
                'time': time,
                'X': X,
                'y': y,
                'features_list': names,
                'atm_source': atm_source,
                'forecast_step': forecast_step
            }
    # Run Data Filters
    # Subset timeseries into shorter stretches, discard short ones
    if hours is not None:
        if verbose:
            print("~"*75)
            print(f"Splitting Timeseries into smaller portions to aid with data filtering. Input data param for max timeseries hours: {hours}")
        new_dict = split_timeseries(new_dict, hours=hours, verbose=verbose)
        new_dict = discard_keys_with_short_y(new_dict, hours=hours, verbose=False)

    # Check for suspect data
    flags = flag_dict_keys(new_dict, params_data['zero_lag_threshold'], params_data['max_intp_time'], max_y=params_data['max_fm'], min_y=params_data['min_fm'], verbose=verbose)

    # Remove flagged cases
    cases = list(new_dict.keys())
    flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
    remove_key_list(new_dict, flagged_cases, verbose=verbose)

    if spatial:
        if atm_source == "HRRR":
            new_dict = combine_nested(new_dict)
        elif atm_source == "RAWS":
            new_dict = subset_by_features(new_dict, features_subset)
            new_dict = combine_nested(new_dict)

    return Dict(new_dict)

def int2fstep(forecast_step, max_hour=5):
    """
    Converts an integer forecast step into a formatted string, prefixed with 'f'
    and zero-padded to two digits. This is the format of HRRR data forecast hours.

    Parameters:
    - forecast_step (int): The forecast step to convert. Must be an integer
      between 0 and max_hour (inclusive).
    - max_hour (int, optional): The maximum allowable forecast step.
      Depends on how many forecast hours were collected for input data. Default is 5.

    Returns:
    - str: A formatted string representing the forecast step, prefixed with 'f'
      and zero-padded to two digits (e.g., 'f01', 'f02').

    Raises:
    - TypeError: If forecast_step is not an integer.
    - ValueError: If forecast_step is not between 0 and max_hour (inclusive).

    Example:
    >>> int2fstep(3)
    'f03'
    """
    if not isinstance(forecast_step, int):
        raise TypeError("forecast_step must be an integer.")
    if not (0 <= forecast_step <= max_hour):
        raise ValueError(f"forecast_step must be between 0 and {max_hour}, the largest forecast step in input data.")
    fstep = 'f' + str(forecast_step).zfill(2)
    return fstep

def check_feat(feat, d):
    if feat not in d:
        raise ValueError(f"Feature {feat} not found")

def get_time(d, atm="HRRR"):
    check_feat('time', d[atm])
    time = str2time(d[atm]['time'])
    return time

def get_static(d, hours):
    cols = []
    # Use all static vars, don't allow for missing
    names = feature_types['static']
    for feat in names:
        check_feat(feat, d['loc'])
        cols.append(np.full(hours, d['loc'][feat]))
    return cols, names

def get_hrrr_atm(d, fstep):
    cols = []
    # Use all names, don't allow for missing
    names = feature_types['atm'].copy()
    for feat in names:
        check_feat(feat, d["HRRR"][fstep])
        v = d["HRRR"][fstep][feat]
        cols.append(v)
    return cols, names

def calc_time_features(time):
    names = ['doy', 'hod']
    doy = np.array([dt.timetuple().tm_yday - 1 for dt in time])
    hod = time.astype('datetime64[h]').astype(int) % 24
    cols = [doy, hod]
    return cols, names
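
# Example (illustrative, assuming str2time returns an array of UTC datetimes):
# for hourly times starting 2023-06-01 00:00, doy starts at 151
# (tm_yday - 1, so zero-indexed day of year) and hod cycles 0, 1, ..., 23, 0, ...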

def calc_hrrr_rain(d, fstep, fprev):
    # NOTE: if fstep and fprev are both f00, this returns all zeros, which is what the f00 HRRR accumulation always is. If fprev were not hard-coded as f00 this might return nonsense, but no checks are added yet for simplicity
    rain = d["HRRR"][fstep]['precip_accum'] - d["HRRR"][fprev]['precip_accum']
    return rain
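
# Example (illustrative): with forecast_step=3, build_train_dict sets fstep='f03' and
# fprev='f02', so rain is the f03 accumulated precipitation minus the f02 accumulation,
# i.e. one hour of rain ending at the forecast valid time.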

def get_raws_atm(d, time, check=True):
    # RAWS times may not be the same as the requested time vector, so they are used to interpolate to the input time
    time_raws = str2time(d['RAWS']['time_raws'])

    cols = []
    names = []

    # Loop through all features, including rain
    for feat in feature_types['atm'] + ['rain']:
        if feat in d['RAWS']:
            v = d['RAWS'][feat]
            v = time_intp(time_raws, v, time)
            assert len(v) == len(time), f"RAWS feature {feat} not equal length to input time: {len(v)} vs {len(time)}"
            cols.append(v)
            names.append(feat)
    return cols, names

def build_features_single(subdict, atm="HRRR", fstep=None, fprev=None):
    # Get time variable
    time = get_time(subdict)
    # Calculate derived time variables
    tvars, tnames = calc_time_features(time)
    # Get static features, extended to the number of hours
    static_vars, static_names = get_static(subdict, hours=len(time))
    # Get atmospheric variables based on data source. HRRR requires rain calculation
    if atm == "HRRR":
        atm_vars, atm_names = get_hrrr_atm(subdict, fstep)
        rain = calc_hrrr_rain(subdict, fstep, fprev)
        atm_vars.append(rain)
        atm_names.append("rain")
    elif atm == "RAWS":
        atm_vars, atm_names = get_raws_atm(subdict, time)
    else:
        raise ValueError(f"Unrecognized atmospheric data source: {atm}")
    # Put everything together and stack
    cols = tvars + static_vars + atm_vars
    X = np.column_stack(cols)
    names = tnames + static_names + atm_names

    return X, names
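
# Minimal usage sketch (illustrative; `case` stands for one station subdict from a
# read_pkl input file):
#   X, names = build_features_single(case, atm="HRRR", fstep="f03", fprev="f02")
#   # X has shape (n_hours, len(names)); names = ['doy', 'hod', 'elev', ...]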

def get_fm(d, time):
    fm = d['RAWS']['fm']
    time_raws = str2time(d['RAWS']['time_raws'])
    return time_intp(time_raws, fm, time)

def shift_time(X_array, inds, forecast_step):
    """
    Shifts specified columns of a numpy ndarray forward by a given number of steps.

    Parameters:
    ----------
    X_array : numpy.ndarray
        The input 2D array where specific columns will be shifted.
    inds : list of int
        Indices of the columns within X_array to be shifted.
    forecast_step : int
        The number of positions to shift the specified columns forward.

    Returns:
    -------
    numpy.ndarray
        A modified copy of the input array with specified columns shifted forward
        by `forecast_step` and the leading positions in those columns filled with NaN.

    Example:
    -------
    >>> X_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    >>> inds = [0, 2]
    >>> shift_time(X_array, inds, 1)
    array([[nan,  2., nan],
           [ 1.,  5.,  3.],
           [ 4.,  8.,  6.],
           [ 7., 11.,  9.]])
    """
    if not isinstance(forecast_step, int) or forecast_step <= 0:
        raise ValueError("forecast_step must be an integer greater than 0.")

    X_shift = X_array.astype(float).copy()
    X_shift[:forecast_step, inds] = np.nan
    X_shift[forecast_step:, inds] = X_array[:-forecast_step, inds]
    return X_shift

def subset_by_features(nested_dict, input_features, verbose=True):
    """
    Subsets a nested dictionary to only include keys where all strings in input_features
    are present in the dictionary's 'features_list' subkey. Primarily used for RAWS dictionaries,
    where desired features might not be present at all ground stations.

    Parameters:
    nested_dict (dict): The nested dictionary with a 'features_list' subkey.
    input_features (list): The list of features to be checked.

    Returns:
    dict: A subset of the input dictionary with only the matching keys.
    """
    if verbose:
        print(f"Subsetting to cases with features: {input_features}")

    # Create a new dictionary to store the result
    result = {}

    # Iterate through the keys in the nested dictionary
    for key, value in nested_dict.items():
        # Check if 'features_list' key exists and all input_features are in the list
        if 'features_list' in value and all(feature in value['features_list'] for feature in input_features):
            # Add to the result if all features are present
            result[key] = value
        else:
            print(f"Removing {key} due to missing features")

    return result
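
# Example (illustrative): keep only RAWS cases whose features_list contains all of
# 'temp', 'rh', and 'solar':
#   d2 = subset_by_features(d, ['temp', 'rh', 'solar'])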

def remove_key_list(d, ls, verbose=False):
    for key in ls:
        if key in d:
            if verbose:
                print(f"Removing key {key} due to data flags")
            del d[key]

def split_timeseries(dict0, hours, naming_convention="_set_", verbose=False):
    """
    Given a number of hours, splits a nested fmda dictionary into smaller stretches. Used primarily to aid in filtering out and removing stretches of missing data.
    """
    cases = list(dict0.keys())
    dict1 = {}
    for key, data in dict0.items():
        if verbose:
            print(f"Processing case: {key}")
            print(f"Length of y vector: {len(data['y'])}")
        if type(data['time'][0]) == str:
            time = str2time(data['time'])
        else:
            time = data['time']
        X_array = data['X']
        y_array = data['y']

        # Determine the start and end time for the portions of length `hours`
        start_time = time[0]
        end_time = time[-1]
        current_time = start_time
        portion_index = 1
        while current_time < end_time:
            next_time = current_time + timedelta(hours=hours)

            # Create a mask for the time range
            mask = (time >= current_time) & (time < next_time)

            # Apply the mask to extract the portions
            new_time = time[mask]
            new_X = X_array[mask]
            new_y = y_array[mask]

            # Save the portions in the new dictionary, with a naming convention for the second and later portions
            if portion_index > 1:
                new_key = f"{key}{naming_convention}{portion_index}"
            else:
                new_key = key
            dict1[new_key] = {'time': new_time, 'X': new_X, 'y': new_y}

            # Add other keys that aren't subset
            for key2 in dict0[key].keys():
                if key2 not in ['time', 'X', 'y']:
                    dict1[new_key][key2] = dict0[key][key2]
            # Update case name and id (same for now, overloaded terminology)
            dict1[new_key]['case'] = new_key
            dict1[new_key]['id'] = new_key
            dict1[new_key]['hours'] = len(dict1[new_key]['y'])

            # Move to the next portion
            current_time = next_time
            portion_index += 1
        if verbose:
            print(f"Partitions of length {hours} from case {key}: {portion_index-1}")
    return dict1
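
# Example (illustrative, with a hypothetical case name): a case "STID1" with 1500 hours
# and hours=720 is split into "STID1" (720 h), "STID1_set_2" (720 h), and "STID1_set_3"
# (60 h); the short final piece is then removed by discard_keys_with_short_y.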

def flag_lag_stretches(x, threshold, lag=1):
    """
    Used to identify stretches of data that have been interpolated for a length greater than or equal to a given threshold. Such stretches are not trustworthy due to extensive interpolation and thus should be removed from a ML training set.
    """
    lags = np.round(np.diff(x, n=lag), 8)
    zero_lag_indices = np.where(lags == 0)[0]
    current_run_length = 1
    for i in range(1, len(zero_lag_indices)):
        if zero_lag_indices[i] == zero_lag_indices[i-1] + 1:
            current_run_length += 1
            if current_run_length > threshold:
                return True
        else:
            current_run_length = 1
    return False
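
# Example (illustrative): a run of 5 identical values gives 4 consecutive zero
# first-differences, which exceeds a threshold of 3:
#   >>> flag_lag_stretches(np.array([1., 2., 2., 2., 2., 2., 3.]), threshold=3, lag=1)
#   True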

def flag_dict_keys(dict0, lag_1_threshold, lag_2_threshold, max_y, min_y, verbose=False):
    """
    Loop through the dictionary and generate a list of flags for whether the `y` variable within each case has target lag patterns. The lag_1_threshold parameter sets the upper limit for the length of constant, zero-lag stretches in y. The lag_2_threshold parameter sets the upper limit for the length of constant, linear stretches. Used to identify cases where data have been interpolated for excessively long and are thus not trustworthy for inclusion in a ML framework.
    """
    cases = list(dict0.keys())
    flags = np.zeros(len(cases))
    for i, case in enumerate(cases):
        if verbose:
            print("~"*50)
            print(f"Case: {case}")
        y = dict0[case]['y']
        if flag_lag_stretches(y, threshold=lag_1_threshold, lag=1):
            if verbose:
                print(f"Flagging case {case} for zero lag stretches greater than param {lag_1_threshold}")
            flags[i] = 1
        if flag_lag_stretches(y, threshold=lag_2_threshold, lag=2):
            if verbose:
                print(f"Flagging case {case} for constant linear stretches greater than param {lag_2_threshold}")
            flags[i] = 1
        if np.any(y >= max_y) or np.any(y <= min_y):
            if verbose:
                print(f"Flagging case {case} for FMC outside param range {min_y, max_y}. FMC range for {case}: {y.min(), y.max()}")
            flags[i] = 1

    return flags

def discard_keys_with_short_y(input_dict, hours, verbose=False):
    """
    Remove keys from a dictionary where the subkey `y` is shorter than the given hours. Used to remove partial sequences at the end of timeseries after the longer timeseries has been subdivided.
    """
    max_y = max(case['y'].shape[0] for case in input_dict.values())
    if max_y < hours:
        raise ValueError(f"Given input hours of {hours}, all cases in input dictionary would be filtered out as too short. Max timeseries y length: {max_y}. Try using a larger value for `hours` in data params")

    discarded_keys = [key for key, value in input_dict.items() if len(value['y']) < hours]

    if verbose:
        print(f"Discarded keys due to y length less than {hours}: {discarded_keys}")

    filtered_dict = {key: value for key, value in input_dict.items() if key not in discarded_keys}

    return filtered_dict

# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combines input data dictionaries.

    Parameters:
    -----------
    nested_input_dict : dict
        Nested dictionary of case dictionaries, keyed by case name.
    verbose : bool, optional
        If True, prints status messages. Default is True.
    """
    # Setup return dictionary
    d = {}
    # Use the helper function to populate the keys
    d['id'] = _combine_key(nested_input_dict, 'id')
    d['case'] = _combine_key(nested_input_dict, 'case')
    d['filename'] = _combine_key(nested_input_dict, 'filename')
    d['time'] = _combine_key(nested_input_dict, 'time')
    d['X'] = _combine_key(nested_input_dict, 'X')
    d['y'] = _combine_key(nested_input_dict, 'y')
    d['atm_source'] = _combine_key(nested_input_dict, 'atm_source')
    d['forecast_step'] = _combine_key(nested_input_dict, 'forecast_step')

    # Build the loc subdictionary using _combine_key for each loc key
    d['loc'] = {
        'STID': _combine_key(nested_input_dict, 'loc', 'STID'),
        'lat': _combine_key(nested_input_dict, 'loc', 'lat'),
        'lon': _combine_key(nested_input_dict, 'loc', 'lon'),
        'elev': _combine_key(nested_input_dict, 'loc', 'elev'),
        'pixel_x': _combine_key(nested_input_dict, 'loc', 'pixel_x'),
        'pixel_y': _combine_key(nested_input_dict, 'loc', 'pixel_y')
    }

    # Handle features_list separately with validation
    features_list = _combine_key(nested_input_dict, 'features_list')
    if features_list:
        first_features_list = features_list[0]
        for fl in features_list:
            if fl != first_features_list:
                warnings.warn("Different features_list found in the nested input dictionaries.")
        d['features_list'] = first_features_list

    return d

def _combine_key(nested_input_dict, key, subkey=None):
    combined_list = []
    for input_dict in nested_input_dict.values():
        if isinstance(input_dict, dict):
            try:
                if subkey:
                    combined_list.append(input_dict[key][subkey])
                else:
                    combined_list.append(input_dict[key])
            except KeyError:
                warning_message = f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries. Setting value to None."
                warnings.warn(warning_message)
                combined_list.append(None)
        else:
            raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
    return combined_list
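
# Example (illustrative, with hypothetical station IDs): for two 720-hour cases,
#   d = combine_nested(nested_dict)
#   # len(d['X']) == 2, d['X'][0].shape == (720, n_features),
#   # d['loc']['STID'] == ['STID1', 'STID2'], aligned by case order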

def compare_dicts(dict1, dict2, keys):
    for key in keys:
        if dict1.get(key) != dict2.get(key):
            return False
    return True

items = '_items_' # dictionary key to keep list of items in
def check_data_array(dat, hours, a, s):
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
        if hours is not None:
            if len(ar) < hours:
                print('len(%s) = %i is not equal to hours = %i' % (a, len(ar), hours))
                exit(1)
    else:
        print(a + ' not present')

def check_data_scalar(dat, a):
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    else:
        print(a + ' not present')

def check_data(dat, case=True, name=None):
    dat[items] = list(dat.keys()) # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat, 'filename')
        check_data_scalar(dat, 'title')
        check_data_scalar(dat, 'note')
        check_data_scalar(dat, 'hours')
        check_data_scalar(dat, 'h2')
        check_data_scalar(dat, 'case')
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        check_data_array(dat, hours, 'E', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ed', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ew', 'wetting equilibrium (%)')
        check_data_array(dat, hours, 'Ec', 'equilibrium correction (%)')
        check_data_array(dat, hours, 'rain', 'rain intensity (mm/h)')
        check_data_array(dat, hours, 'fm', 'RAWS fuel moisture data (%)')
        check_data_array(dat, hours, 'm', 'fuel moisture estimate (%)')
    if dat[items]:
        print('items:', dat[items])
        for a in dat[items].copy():
            ar = dat[a]
            if dat[a] is None or np.isscalar(dat[a]):
                check_data_scalar(dat, a)
            elif is_numeric_ndarray(ar):
                print(type(ar))
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "hash", hash2(ar), "type", type(ar))
            elif isinstance(ar, tf.Tensor):
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "type", type(ar))
            else:
                print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    del dat[items] # clean up

# Note: the project structure has moved towards pickle files, so these json funcs might not be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    # dic: dictionary
    # filename: (str) output json filename, expects a ".json" file extension
    # Return: none

    print('writing ', filename)
    # check_data(dic)
    new = {}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i] = dic[i].tolist() # because numpy.ndarray is not json serializable
        else:
            new[i] = dic[i]
    new['filename'] = filename
    print('Hash: ', hash2(new))
    json.dump(new, open(filename, 'w'), indent=4)

def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expects a ".json" string

    print('reading ', filename)
    dic = json.load(open(filename, 'r'))
    new = {}
    for i in dic:
        if type(dic[i]) is list:
            new[i] = np.array(dic[i]) # restore arrays that were stored as lists in json
        else:
            new[i] = dic[i]
    check_data(new)
    print('Hash: ', hash2(new))
    return new

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.

    # artificial equilibrium data
    E = 100.0 * np.power(np.sin(np.pi * day), power) # diurnal curve
    E = 0.05 + 0.25 * E
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1 # initial FMC
    for t in range(hours - 1):
        m_f[t+1] = max(0., model_decay(m_f[t], E[t]) + random.gauss(0, process_noise))
    data = m_f + np.random.normal(loc=0, scale=data_noise, size=hours)
    E = E + DeltaE
    return E, m_f, data, hour, h2, DeltaE

# The following builds an output dictionary with all model data and variables
def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi * day), power) # diurnal curve between 0 and 1
    E = Emin + (Emax - Emin) * E
    E = E + DeltaE
    Ed = E + 0.5
    Ew = np.maximum(E - 0.5, 0)
    rain = np.multiply(rand(hours) < p_rain, rand(hours) * max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1 # initial FMC
    for t in range(hours - 1):
        fm[t+1] = max(0., model_moisture(fm[t], Ed[t], Ew[t], rain[t]) + random.gauss(0, process_noise))
    fm = fm + np.random.normal(loc=0, scale=data_noise, size=hours)
    dat = {'E': E, 'Ew': Ew, 'Ed': Ed, 'fm': fm, 'hours': hours, 'h2': h2, 'DeltaE': DeltaE, 'rain': rain, 'title': 'Synthetic data'}

    return dat
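
# Minimal usage sketch (illustrative):
#   dat = synthetic_data(days=10)
#   check_data(dat)   # prints a summary of the E, Ed, Ew, fm, and rain arrays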

def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data
    if name in dat:
        if hmin is None:
            hmin = 0
        if hmax is None:
            hmax = len(dat[name])
        hour = np.array(range(hmin, hmax))
        if type == 'plot':
            plt.plot(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)
        elif type == 'scatter':
            plt.scatter(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)

# Lookup table for plotting features
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha': .8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha': .8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha': .9, 'label': 'Rain'}
}
def plot_feature(x, y, feature_name):
    style = plot_styles.get(feature_name, {})
    plt.plot(x, y, **style)

def plot_features(hmin, hmax, dat, linestyle, c, label, alpha=1):
    hour = np.array(range(hmin, hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat) # index of main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:, i][hmin:hmax], feature_name=feat)

def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None, hmin=0, hmax=None, xlabel=None, ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    # dat: FMDA dictionary
    # plot_period: (str) 'all' or 'predict', which time period to plot
    # Returns: none

    if 'hours' in dat:
        if hmax is None:
            hmax = dat['hours']
        else:
            hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16, 4))

    plot_one(hmin, hmax, dat, 'y', linestyle='-', c='#468a29', label='FM Observed')
    plot_one(hmin, hmax, dat, 'm', linestyle='-', c='k', label='FM Model')
    plot_features(hmin, hmax, dat, linestyle='-', c='k', label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim() # used to format annotations
        plot_y0 = np.max([hmin, test_ind]) # used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]), xytext=(plot_y0, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)', xy=((hmin+plot_y0)/2, yy[1]), xytext=((hmin+plot_y0)/2, yy[1]+1), ha='right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]), xytext=(hmax, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)', xy=(hmax-(hmax-test_ind)/2, yy[1]),
                     xytext=(hmax-(hmax-test_ind)/2, yy[1]+1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    if title is not None:
        t = title
    elif 'title' in dat:
        t = dat['title']
    else:
        t = ''
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat) + ')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t, y=1.1)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    elif ylabel is None:
        plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")

def rmse(a, b):
    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))

def rmse_skip_nan(x, y):
    mask = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(mask):
        return np.sqrt(np.mean((x[mask] - y[mask]) ** 2))
    else:
        return np.nan
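
# Example (illustrative): NaN positions are excluded pairwise before computing RMSE
#   >>> rmse_skip_nan(np.array([2., np.nan]), np.array([4., 1.]))
#   2.0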

def rmse_str(a, b):
    return "RMSE " + "{:.3f}".format(rmse_skip_nan(a, b))

def rmse_data_str(dat, predict=True, hours=None, test_ind=None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    # dat: (dict) fmda dictionary
    # predict: (bool) whether to return prediction-period RMSE. Default True
    # hours: (int) total number of modeled time periods
    # test_ind: (int) start of test period
    # Return: (str) formatted RMSE value

    if hours is None:
        if 'hours' in dat:
            hours = dat['hours']
    if test_ind is None:
        if 'test_ind' in dat:
            test_ind = dat['test_ind']

    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'].flatten()[test_ind:hours], dat['y'].flatten()[test_ind:hours])
        else:
            return rmse_str(dat['m'].flatten(), dat['y'].flatten())
    else:
        return ''

# Calculate mean absolute error (note: despite the name, this computes MAE, not a percentage error)
def mape(a, b):
    return ((a - b).__abs__()).mean()

def rmse_data(dat, hours=None, h2=None, simulation='m', measurements='fm'):
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']

    m = dat[simulation]
    fm = dat[measurements]
    case = dat['case']

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    all_rmse = rmse(m[:hours], fm[:hours])
    print(case, 'Training 1 to', h2, 'hours RMSE: ' + str(np.round(train, 4)))
    print(case, 'Prediction', h2+1, 'to', hours, 'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train': train, 'predict': predict, 'all': all_rmse}


#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        subprocess.call(f"wget -P {data_dir} {osp.join(base_url, filename)}", shell=True)