## Set of Functions to process and format fuel moisture model inputs
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import numpy as np, random
from numpy.random import rand
import tensorflow as tf
import pickle, os
import warnings
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from moisture_models import model_decay, model_moisture
from datetime import datetime, timedelta
from utils import is_numeric_ndarray, hash2
import json
import copy
import subprocess
import os.path as osp
from utils import Dict, str2time, check_increment, time_intp

def process_train_dict(input_file_paths, params_data, atm_dict="HRRR", verbose=False):
    if type(input_file_paths) is not list:
        raise ValueError(f"Argument `input_file_paths` must be list, received {type(input_file_paths)}")
    train = {}
    for file_path in input_file_paths:
        # Extract target and features
        di = build_train_dict(file_path, atm=atm_dict, features_all=params_data['features_all'], verbose=verbose)
        # Subset timeseries into shorter stretches
        di = split_timeseries(di, hours=params_data['hours'], verbose=verbose)
        di = discard_keys_with_short_y(di, hours=params_data['hours'], verbose=False)
        # Check for suspect data
        flags = flag_dict_keys(di, params_data['zero_lag_threshold'], params_data['max_intp_time'], max_y=params_data['max_fm'], min_y=params_data['min_fm'], verbose=verbose)
        # Remove flagged cases
        cases = list([*di.keys()])
        flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
        remove_key_list(di, flagged_cases, verbose=verbose)
        train.update(di)
    return train
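
# Example usage (a hedged sketch, not part of the pipeline): the pickle path and the
# parameter values below are hypothetical placeholders; only the required params_data
# keys are taken from the function above.
#
#   params_data = {
#       'hours': 720,                   # length of subdivided timeseries
#       'zero_lag_threshold': 10,       # passed to flag_dict_keys as lag_1_threshold
#       'max_intp_time': 10,            # passed to flag_dict_keys as lag_2_threshold
#       'max_fm': 90.0, 'min_fm': 1.0,  # allowed FMC range
#       'features_all': ['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain']
#   }
#   train = process_train_dict(["data/fmda_example.pkl"], params_data, atm_dict="HRRR", verbose=True)
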
feature_types = {
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew']
}

def build_train_dict(input_file_path,
                     forecast_step=1, atm="HRRR",
                     features_all=['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain'],
                     verbose=False):
    # in:
    #   input_file_path  str - path to a pickle file as in read_test_pkl
    #   forecast_step    int - which forecast step to take atmospheric data from (e.g. f03), must be > 0
    #   atm              str - name of subdict where atmospheric vars are located
    #   features_all     list of strings - names of keys in subdicts to collect into the feature matrix.
    #                    Default is everything collected
    # return:
    #   train            dictionary with structure
    #                    {key : {'id' : key,      # copied subdict key
    #                            'loc' : {...},   # copied from input dict = {key : {'loc': ...}, ...}
    #                            'time' : time,   # datetime vector, HRRR spacing
    #                            'X' : feat,      # features from atmosphere and location, on the HRRR time axis
    #                            'y' : fm,        # target fuel moisture from the RAWS, interpolated to time
    #                            ...}}

    # TODO: fix this
    if 'rain' in features_all and (not features_all[-1] == 'rain'):
        raise ValueError(f"Make rain the last element of the features list (working on fix as of 24-6-24), given features list: {features_all}")

    if forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep = 'f' + str(forecast_step).zfill(2)
        fprev = 'f' + str(forecast_step - 1).zfill(2)
        # logging.info('Using data from step %s', fstep)
        # logging.info('Using rain as the difference of accumulated precipitation between %s and %s', fstep, fprev)
    else:
        # logging.critical('forecast_step must be integer between 1 and 99')
        raise ValueError('bad forecast_step')

    train = {}
    with open(input_file_path, 'rb') as file:
        # logging.info("loading file %s", input_file_path)
        d = pickle.load(file)
    for key in d:
        atm_dict = atm
        features_list = features_all
        # logging.info('Processing subdictionary %s', key)
        if key in train:
            pass
            # logging.warning('skipping duplicate key %s', key)
        else:
            subdict = d[key]  # subdictionary for this case
            loc = subdict['loc']
            train[key] = {
                'id': key,  # store the key inside the dictionary, subdictionary will be used separately
                'case': key,
                'filename': input_file_path,
                'loc': loc
            }
            desc = 'descr'
            if desc in subdict:
                train[key][desc] = subdict[desc]
            time_hrrr = str2time(subdict[atm_dict]['time'])
            # timekeeping
            hours = len(d[key][atm_dict]['time'])
            train[key]['hours'] = hours
            # train[key]['h2'] = hours  # not doing prediction yet
            hrrr_increment = check_increment(time_hrrr, id=key + f' {atm_dict}.time')
            # logging.info(f'{atm_dict} increment is %s h', hrrr_increment)
            if hrrr_increment < 1:
                # logging.critical('HRRR increment is %s h must be at least 1 h', hrrr_increment)
                raise ValueError(f'{atm_dict} increment is {hrrr_increment} h, must be at least 1 h')

            # build matrix of features - assuming all the same length, if not column_stack will fail
            train[key]['time'] = time_hrrr
            # logging.info(f"Created feature matrix train[{key}]['X'] shape {train[key]['X'].shape}")
            time_raws = str2time(subdict['RAWS']['time_raws'])  # may not be the same as HRRR
            # logging.info('%s RAWS.time_raws length is %s', key, len(time_raws))
            check_increment(time_raws, id=key + ' RAWS.time_raws')
            # print_first(time_raws, num=5, id='RAWS.time_raws')

            # Set up static vars
            columns = []
            missing_features = []
            for feat in features_list:
                # For atmospheric features,
                if feat in feature_types['atm']:
                    if atm_dict == "HRRR":
                        vec = subdict['HRRR'][fstep][feat]
                        columns.append(vec)
                    elif atm_dict == "RAWS":
                        if feat in subdict['RAWS'].keys():
                            vec = time_intp(time_raws, subdict['RAWS'][feat], time_hrrr)
                            columns.append(vec)
                        else:
                            missing_features.append(feat)
                # For static features, repeat to fit number of time observations
                elif feat in feature_types['static']:
                    columns.append(np.full(hours, loc[feat]))

            # compute rain as difference of accumulated precipitation
            if 'rain' in features_list:
                if atm_dict == "HRRR":
                    rain = subdict[atm_dict][fstep]['precip_accum'] - subdict[atm_dict][fprev]['precip_accum']
                    # logging.info('%s rain as difference %s minus %s: min %s max %s',
                    #              key, fstep, fprev, np.min(rain), np.max(rain))
                    columns.append(rain)  # add rain feature
                elif atm_dict == "RAWS":
                    if 'rain' in subdict[atm_dict]:
                        rain = time_intp(time_raws, subdict[atm_dict]['rain'], time_hrrr)
                        columns.append(rain)  # add rain feature
                    else:
                        # logging.info('No rain data found in RAWS subdictionary %s', key)
                        missing_features.append('rain')

            train[key]['X'] = np.column_stack(columns)
            train[key]['features_list'] = [item for item in features_list if item not in missing_features]

            fm = subdict['RAWS']['fm']
            # logging.info('%s RAWS.fm length is %s', key, len(fm))
            # interpolate RAWS sensors to HRRR time and over NaNs
            train[key]['y'] = time_intp(time_raws, fm, time_hrrr)
            # TODO: check endpoint interpolation when RAWS data sparse, and bail out if not enough data

            if train[key]['y'] is None:
                pass
                # logging.error('Cannot create target matrix for %s, using None', key)
            else:
                pass
                # logging.info(f"Created target matrix train[{key}]['y'] shape {train[key]['y'].shape}")

    # logging.info('Created a "train" dictionary with %s items', len(train))

    # clean up
    keys_to_delete = []
    for key in train:
        if train[key]['X'] is None or train[key]['y'] is None:
            # logging.warning('Deleting training item %s because features X or target y are None', key)
            keys_to_delete.append(key)

    # Delete the items from the dictionary
    if len(keys_to_delete) > 0:
        for key in keys_to_delete:
            del train[key]
        # logging.warning('Deleted %s items with None for data. %s items remain in the training dictionary.',
        #                 len(keys_to_delete), len(train))

    # output
    # if output_file_path is not None:
    #     with open(output_file_path, 'wb') as file:
    #         logging.info('Writing pickle dump of the dictionary train into file %s', output_file_path)
    #         pickle.dump(train, file)

    # logging.info('pkl2train done')

    return train
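
# Example usage (hedged sketch; the pickle filename below is a hypothetical placeholder):
#
#   di = build_train_dict("data/fmda_example.pkl", forecast_step=3, atm="HRRR",
#                         features_all=['Ed', 'Ew', 'solar', 'wind', 'rain'], verbose=True)
#   case = [*di.keys()][0]
#   di[case]['X'].shape        # (hours, n_features), on the HRRR time axis
#   di[case]['y'].shape        # (hours,), RAWS fuel moisture interpolated to the same axis
#   di[case]['features_list']  # features actually collected (missing ones dropped)
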
def remove_key_list(d, ls, verbose=False):
    for key in ls:
        if key in d:
            if verbose:
                print(f"Removing key {key} due to data flags")
            del d[key]

def split_timeseries(dict0, hours, naming_convention="_set_", verbose=False):
    """
    Given a number of hours, splits a nested fmda dictionary into smaller stretches. This is used primarily to aid in filtering out and removing stretches of missing data.
    """
    cases = list([*dict0.keys()])
    dict1 = {}
    for key, data in dict0.items():
        if verbose:
            print(f"Processing case: {key}")
            print(f"Length of y vector: {len(data['y'])}")
        if type(data['time'][0]) == str:
            time = str2time(data['time'])
        else:
            time = data['time']
        X_array = data['X']
        y_array = data['y']

        # Determine the start and end time for the portions of length `hours`
        start_time = time[0]
        end_time = time[-1]
        current_time = start_time
        portion_index = 1
        while current_time < end_time:
            next_time = current_time + timedelta(hours=hours)

            # Create a mask for the time range
            mask = (time >= current_time) & (time < next_time)

            # Apply the mask to extract the portions
            new_time = time[mask]
            new_X = X_array[mask]
            new_y = y_array[mask]

            # Save the portions in the new dictionary with naming convention if second or more portion
            if portion_index > 1:
                new_key = f"{key}{naming_convention}{portion_index}"
            else:
                new_key = key
            dict1[new_key] = {'time': new_time, 'X': new_X, 'y': new_y}

            # Add other keys that aren't subset
            for key2 in dict0[key].keys():
                if key2 not in ['time', 'X', 'y']:
                    dict1[new_key][key2] = dict0[key][key2]
            # Update case name and id (same for now, overloaded terminology)
            dict1[new_key]['case'] = new_key
            dict1[new_key]['id'] = new_key
            dict1[new_key]['hours'] = len(dict1[new_key]['y'])

            # Move to the next portion
            current_time = next_time
            portion_index += 1
        if verbose:
            print(f"Partitions of length {hours} from case {key}: {portion_index - 1}")
    return dict1
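
# Example (hedged sketch): a case "XYZ1" with 1800 hours of data and hours=720 produces
# keys "XYZ1", "XYZ1_set_2", "XYZ1_set_3" of lengths 720, 720, and 360; the short final
# stretch can then be dropped with discard_keys_with_short_y below.
#
#   di = split_timeseries(di, hours=720, verbose=True)
#   di = discard_keys_with_short_y(di, hours=720)
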
def flag_lag_stretches(x, threshold, lag=1):
    """
    Used to identify stretches of data that have been interpolated over a length greater than the given threshold. Such stretches are not trustworthy due to extensive interpolation and thus should be removed from a ML training set.
    """
    lags = np.round(np.diff(x, n=lag), 8)
    zero_lag_indices = np.where(lags == 0)[0]
    current_run_length = 1
    for i in range(1, len(zero_lag_indices)):
        if zero_lag_indices[i] == zero_lag_indices[i-1] + 1:
            current_run_length += 1
            if current_run_length > threshold:
                return True
        else:
            current_run_length = 1
    return False
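
# Example (hedged sketch): four identical consecutive values give a run of three zero
# first-differences, so the series is flagged when threshold is 2 but not when it is 3.
#
#   y = np.array([10., 10., 10., 10., 12., 11.])
#   flag_lag_stretches(y, threshold=2, lag=1)   # True
#   flag_lag_stretches(y, threshold=3, lag=1)   # False
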
def flag_dict_keys(dict0, lag_1_threshold, lag_2_threshold, max_y, min_y, verbose=False):
    """
    Loop through the dictionary and generate a list of flags for whether the `y` variable within each case has suspect lag patterns. The lag_1_threshold parameter sets the upper limit on the length of constant, zero-lag stretches in y. The lag_2_threshold parameter sets the upper limit on the length of constant, linear stretches of data. Used to identify cases that have been interpolated over excessively long periods and thus are not trustworthy for inclusion in a ML framework.
    """
    cases = list([*dict0.keys()])
    flags = np.zeros(len(cases))
    for i, case in enumerate(cases):
        if verbose:
            print("~"*50)
            print(f"Case: {case}")
        y = dict0[case]['y']
        if flag_lag_stretches(y, threshold=lag_1_threshold, lag=1):
            if verbose:
                print(f"Flagging case {case} for zero lag stretches greater than param {lag_1_threshold}")
            flags[i] = 1
        if flag_lag_stretches(y, threshold=lag_2_threshold, lag=2):
            if verbose:
                print(f"Flagging case {case} for constant linear stretches greater than param {lag_2_threshold}")
            flags[i] = 1
        if np.any(y >= max_y) or np.any(y <= min_y):
            if verbose:
                print(f"Flagging case {case} for FMC outside param range {min_y, max_y}. FMC range for {case}: {y.min(), y.max()}")
            flags[i] = 1

    return flags
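
# Example usage (hedged sketch; threshold values are placeholders), mirroring the
# filtering done in process_train_dict above:
#
#   flags = flag_dict_keys(train, lag_1_threshold=10, lag_2_threshold=10,
#                          max_y=90.0, min_y=1.0, verbose=True)
#   cases = [*train.keys()]
#   flagged = [case for case, flag in zip(cases, flags) if flag == 1]
#   remove_key_list(train, flagged, verbose=True)
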
def discard_keys_with_short_y(input_dict, hours, verbose=False):
    """
    Remove keys from a dictionary where the subkey `y` is shorter than the given number of hours. Used to remove partial sequences at the end of a timeseries after the longer timeseries has been subdivided.
    """
    discarded_keys = [key for key, value in input_dict.items() if len(value['y']) < hours]

    if verbose:
        print(f"Discarded keys due to y length less than {hours}: {discarded_keys}")

    filtered_dict = {key: value for key, value in input_dict.items() if key not in discarded_keys}

    return filtered_dict

# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combines input data dictionaries.

    Parameters:
    -----------
    nested_input_dict : dict
        Nested dictionary of the form {case : {'id': ..., 'X': ..., 'y': ..., 'loc': ..., ...}}, e.g. the train dict built above.
    verbose : bool, optional
        If True, prints status messages. Default is True.
    """
    # Setup return dictionary
    d = {}
    # Use the helper function to populate the keys
    d['id'] = _combine_key(nested_input_dict, 'id')
    d['case'] = _combine_key(nested_input_dict, 'case')
    d['filename'] = _combine_key(nested_input_dict, 'filename')
    d['time'] = _combine_key(nested_input_dict, 'time')
    d['X'] = _combine_key(nested_input_dict, 'X')
    d['y'] = _combine_key(nested_input_dict, 'y')

    # Build the loc subdictionary using _combine_key for each loc key
    d['loc'] = {
        'STID': _combine_key(nested_input_dict, 'loc', 'STID'),
        'lat': _combine_key(nested_input_dict, 'loc', 'lat'),
        'lon': _combine_key(nested_input_dict, 'loc', 'lon'),
        'elev': _combine_key(nested_input_dict, 'loc', 'elev'),
        'pixel_x': _combine_key(nested_input_dict, 'loc', 'pixel_x'),
        'pixel_y': _combine_key(nested_input_dict, 'loc', 'pixel_y')
    }

    # Handle features_list separately with validation
    features_list = _combine_key(nested_input_dict, 'features_list')
    if features_list:
        first_features_list = features_list[0]
        for fl in features_list:
            if fl != first_features_list:
                warnings.warn("Different features_list found in the nested input dictionaries.")
        d['features_list'] = first_features_list

    return d
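
# Example usage (hedged sketch): combine per-case training dictionaries into flat lists,
# assuming every case carries the 'loc' subkeys referenced above (STID, lat, lon, elev,
# pixel_x, pixel_y).
#
#   dat = combine_nested(train)
#   len(dat['y']) == len(train)   # one target vector per case
#   dat['loc']['STID']            # list of station IDs, one per case
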
def _combine_key(nested_input_dict, key, subkey=None):
    combined_list = []
    for input_dict in nested_input_dict.values():
        if isinstance(input_dict, dict):
            try:
                if subkey:
                    combined_list.append(input_dict[key][subkey])
                else:
                    combined_list.append(input_dict[key])
            except KeyError:
                raise ValueError(f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries")
        else:
            raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
    return combined_list

def compare_dicts(dict1, dict2, keys):
    for key in keys:
        if dict1.get(key) != dict2.get(key):
            return False
    return True

items = '_items_'  # dictionary key to keep list of items in

def check_data_array(dat, hours, a, s):
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
        if hours is not None:
            if len(ar) < hours:
                print('len(%s) = %i is less than hours = %i' % (a, len(ar), hours))
                exit(1)
    else:
        print(a + ' not present')

def check_data_scalar(dat, a):
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    else:
        print(a + ' not present')

def check_data(dat, case=True, name=None):
    dat[items] = list(dat.keys())  # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat, 'filename')
        check_data_scalar(dat, 'title')
        check_data_scalar(dat, 'note')
        check_data_scalar(dat, 'hours')
        check_data_scalar(dat, 'h2')
        check_data_scalar(dat, 'case')
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        check_data_array(dat, hours, 'E', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ed', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ew', 'wetting equilibrium (%)')
        check_data_array(dat, hours, 'Ec', 'equilibrium equilibrium (%)')
        check_data_array(dat, hours, 'rain', 'rain intensity (mm/h)')
        check_data_array(dat, hours, 'fm', 'RAWS fuel moisture data (%)')
        check_data_array(dat, hours, 'm', 'fuel moisture estimate (%)')
    if dat[items]:
        print('items:', dat[items])
        for a in dat[items].copy():
            ar = dat[a]
            if dat[a] is None or np.isscalar(dat[a]):
                check_data_scalar(dat, a)
            elif is_numeric_ndarray(ar):
                print(type(ar))
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "hash", hash2(ar), "type", type(ar))
            elif isinstance(ar, tf.Tensor):
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "type", type(ar))
            else:
                print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    del dat[items]  # clean up

# Note: the project structure has moved towards pickle files, so these json funcs might not be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    #   dic: dictionary
    #   filename: (str) output json filename, expect a ".json" file extension
    # Return: none

    print('writing ', filename)
    # check_data(dic)
    new = {}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i] = dic[i].tolist()  # because numpy.ndarray is not serializable
        else:
            new[i] = dic[i]
        # print('i', type(new[i]))
    new['filename'] = filename
    print('Hash: ', hash2(new))
    json.dump(new, open(filename, 'w'), indent=4)

def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string
    # Return: dictionary, with lists converted back to numpy arrays

    print('reading ', filename)
    dic = json.load(open(filename, 'r'))
    new = {}
    for i in dic:
        if type(dic[i]) is list:
            new[i] = np.array(dic[i])  # convert lists back to numpy arrays
        else:
            new[i] = dic[i]
    check_data(new)
    print('Hash: ', hash2(new))
    return new
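
# Example round trip (hedged sketch; 'demo.json' is a hypothetical filename):
#
#   to_json({'fm': np.array([10.3, 11.2, 12.0]), 'title': 'demo case'}, 'demo.json')
#   d = from_json('demo.json')
#   type(d['fm'])   # numpy.ndarray, converted back from the JSON list
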
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.

    # artificial equilibrium data
    E = 100.0 * np.power(np.sin(np.pi * day), power)  # diurnal curve
    E = 0.05 + 0.25 * E
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1  # initial FMC
    process_noise = 0.
    for t in range(hours - 1):
        m_f[t + 1] = max(0., model_decay(m_f[t], E[t]) + random.gauss(0, process_noise))
    data = m_f + np.random.normal(loc=0, scale=data_noise, size=hours)
    E = E + DeltaE
    return E, m_f, data, hour, h2, DeltaE

# the following functions input or output a dictionary with all model data and variables

def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi * day), power)  # diurnal curve between 0 and 1
    E = Emin + (Emax - Emin) * E
    E = E + DeltaE
    Ed = E + 0.5
    Ew = np.maximum(E - 0.5, 0)
    rain = np.multiply(rand(hours) < p_rain, rand(hours) * max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1  # initial FMC
    # process_noise = 0.
    for t in range(hours - 1):
        fm[t + 1] = max(0., model_moisture(fm[t], Ed[t - 1], Ew[t - 1], rain[t - 1]) + random.gauss(0, process_noise))
    fm = fm + np.random.normal(loc=0, scale=data_noise, size=hours)
    dat = {'E': E, 'Ew': Ew, 'Ed': Ed, 'fm': fm, 'hours': hours, 'h2': h2, 'DeltaE': DeltaE, 'rain': rain, 'title': 'Synthetic data'}

    return dat
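
# Example (hedged sketch): 30 synthetic days gives 720 hourly values; the returned
# dictionary holds the equilibria, rain, and noisy fuel moisture built above.
#
#   dat = synthetic_data(days=30, p_rain=0.05, max_rain=5.0)
#   dat['hours'], dat['h2']   # (720, 360)
#   dat['fm'].shape           # (720,)
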
def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data
    if name in dat:
        h = len(dat[name])
        if hmin is None:
            hmin = 0
        if hmax is None:
            hmax = len(dat[name])
        hour = np.array(range(hmin, hmax))
        if type == 'plot':
            plt.plot(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)
        elif type == 'scatter':
            plt.scatter(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)

# Lookup table for plotting features
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha': .8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha': .8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha': .9, 'label': 'Rain'}
}

def plot_feature(x, y, feature_name):
    style = plot_styles.get(feature_name, {})
    plt.plot(x, y, **style)

def plot_features(hmin, hmax, dat, linestyle, c, label, alpha=1):
    hour = np.array(range(hmin, hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat)  # index of feature column in the main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:, i][hmin:hmax], feature_name=feat)

def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None, hmin=0, hmax=None, xlabel=None, ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    #   dat: FMDA dictionary
    #   plot_period: ('all' or 'predict') portion of the time series to plot
    # Returns: none

    # dat = copy.deepcopy(dat0)
    if 'hours' in dat:
        if hmax is None:
            hmax = dat['hours']
        else:
            hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16, 4))

    plot_one(hmin, hmax, dat, 'y', linestyle='-', c='#468a29', label='FM Observed')
    plot_one(hmin, hmax, dat, 'm', linestyle='-', c='k', label='FM Model')
    plot_features(hmin, hmax, dat, linestyle='-', c='k', label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim()  # used to format annotations
        plot_y0 = np.max([hmin, test_ind])  # Used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]), xytext=(plot_y0, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)', xy=((hmin + plot_y0) / 2, yy[1]), xytext=((hmin + plot_y0) / 2, yy[1] + 1), ha='right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]), xytext=(hmax, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)', xy=(hmax - (hmax - test_ind) / 2, yy[1]),
                     xytext=(hmax - (hmax - test_ind) / 2, yy[1] + 1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    if title is not None:
        t = title
    elif 'title' in dat:
        t = dat['title']
        # print('title', type(t), t)
    else:
        t = ''
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat) + ')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t, y=1.1)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    elif ylabel is None:
        plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")

def rmse(a, b):
    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))

def rmse_skip_nan(x, y):
    mask = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(mask):
        return np.sqrt(np.mean((x[mask] - y[mask]) ** 2))
    else:
        return np.nan
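
# Example (hedged sketch): NaN pairs are dropped before computing the error.
#
#   rmse_skip_nan(np.array([4., 0., np.nan]), np.array([1., 3., 7.]))   # 3.0
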
def rmse_str(a, b):
    rmse = rmse_skip_nan(a, b)
    return "RMSE " + "{:.3f}".format(rmse)

def rmse_data_str(dat, predict=True, hours=None, test_ind=None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    #   dat: (dict) fmda dictionary
    #   predict: (bool) Whether to return prediction period RMSE. Default True
    #   hours: (int) total number of modeled time periods
    #   test_ind: (int) start of test period
    # Return: (str) RMSE value

    if hours is None:
        if 'hours' in dat:
            hours = dat['hours']
    if test_ind is None:
        if 'test_ind' in dat:
            test_ind = dat['test_ind']

    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'][test_ind:hours], dat['y'].flatten()[test_ind:hours])
        else:
            return rmse_str(dat['m'], dat['y'].flatten())
    else:
        return ''

# Calculate mean absolute error
def mape(a, b):
    return ((a - b).__abs__()).mean()

def rmse_data(dat, hours=None, h2=None, simulation='m', measurements='fm'):
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']

    m = dat[simulation]
    fm = dat[measurements]
    case = dat['case']

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    all = rmse(m[:hours], fm[:hours])
    print(case, 'Training 1 to', h2, 'hours RMSE: ' + str(np.round(train, 4)))
    print(case, 'Prediction', h2 + 1, 'to', hours, 'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train': train, 'predict': predict, 'all': all}

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        subprocess.call(f"wget -P {data_dir} {osp.join(base_url, filename)}", shell=True)
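
# Example usage (hedged sketch; the filename is a hypothetical placeholder):
#
#   get_file("fmda_example.pkl")   # uses data/fmda_example.pkl if present,
#                                  # otherwise wgets it from base_url above into data/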