## Set of Functions to process and format fuel moisture model inputs
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import numpy as np, random
from numpy.random import rand
import tensorflow as tf
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from moisture_models import model_decay, model_moisture
from datetime import datetime, timedelta
from utils import is_numeric_ndarray, hash2
from utils import Dict, str2time, check_increment, time_intp
# Standard-library imports needed by functions below (pickle/json I/O, warnings, file retrieval)
import pickle, json, warnings, subprocess
import os.path as osp
def process_train_dict(input_file_paths, params_data, atm_dict="HRRR", verbose=False):
    if type(input_file_paths) is not list:
        raise ValueError(f"Argument `input_file_paths` must be list, received {type(input_file_paths)}")
    train = {}
    for file_path in input_file_paths:
        # Extract target and features
        di = build_train_dict(file_path, atm=atm_dict, features_all=params_data['features_all'], verbose=verbose)
        # Subset timeseries into shorter stretches
        di = split_timeseries(di, hours=params_data['hours'], verbose=verbose)
        di = discard_keys_with_short_y(di, hours=params_data['hours'], verbose=False)
        # Check for suspect data
        flags = flag_dict_keys(di, params_data['zero_lag_threshold'], params_data['max_intp_time'],
                               max_y=params_data['max_fm'], min_y=params_data['min_fm'], verbose=verbose)
        # Remove flagged cases
        cases = list([*di.keys()])
        flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
        remove_key_list(di, flagged_cases, verbose=verbose)
        # Accumulate cleaned cases from this file
        train.update(di)
    return train
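# Example usage (a sketch; the params_data values and pickle filename below are hypothetical
# and would normally come from a project config file):
#   params_data = Dict({'features_all': ['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain'],
#                       'hours': 720, 'zero_lag_threshold': 10, 'max_intp_time': 10,
#                       'max_fm': 90.0, 'min_fm': 1.0})
#   train = process_train_dict(["data/fmda_rocky_202305.pkl"], params_data=params_data, verbose=True)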
# Lookup of feature types used when building the feature matrix
feature_types = {
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew']
}
def build_train_dict(input_file_path,
                     forecast_step=1, atm="HRRR", features_all=['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain'], verbose=False):
    # in:
    #   input_file_path  str - path to a pickle file, as in read_test_pkl
    #   forecast_step    int - which forecast step to take atmospheric data from (maybe 03, must be >0)
    #   atm              str - name of subdict where atmospheric vars are located
    #   features_all     list of strings - names of keys in subdicts to collect into the feature matrix. Default is everything collected
    # return:
    #   train            dictionary with structure
    #                    {key : {'id' : key,     # copied subdict key
    #                            'loc' : {...},  # copied from in dict = {key : {'loc': ... }...}
    #                            'time' : time,  # datetime vector, spacing tres
    #                            'X' : feat,     # feature matrix from atmosphere and location
    #                            'y' : fm        # target fuel moisture from the RAWS, interpolated to time
    #                           }}

    if 'rain' in features_all and (not features_all[-1] == 'rain'):
        raise ValueError(f"'rain' must be the last element of the features list (working on fix as of 24-6-24), given features list: {features_all}")

    if forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep = 'f' + str(forecast_step).zfill(2)
        fprev = 'f' + str(forecast_step - 1).zfill(2)
        # logging.info('Using data from step %s',fstep)
        # logging.info('Using rain as the difference of accumulated precipitation between %s and %s',fstep,fprev)
    else:
        # logging.critical('forecast_step must be integer between 1 and 99')
        raise ValueError('forecast_step must be an integer between 1 and 99')
    with open(input_file_path, 'rb') as file:
        # logging.info("loading file %s", file_path)
        d = pickle.load(file)

    train = {}
    for key in d:
        atm_dict = atm
        features_list = features_all
        # logging.info('Processing subdictionary %s',key)
        if key in train:
            # logging.warning('skipping duplicate key %s',key)
            continue
        subdict = d[key]    # subdictionary for this case
        loc = subdict['loc']
        train[key] = {
            'id': key,      # store the key inside the dictionary, subdictionary will be used separately
            'case': key,
            'filename': input_file_path,
            'loc': loc
        }
        desc = 'descr'
        if desc in subdict:
            train[key][desc] = subdict[desc]
        time_hrrr = str2time(subdict[atm_dict]['time'])
        # timekeeping
        hours = len(d[key][atm_dict]['time'])
        train[key]['hours'] = hours
        # train[key]['h2'] = hours    # not doing prediction yet
        hrrr_increment = check_increment(time_hrrr, id=key + f' {atm_dict}.time')
        # logging.info(f'{atm_dict} increment is %s h',hrrr_increment)
        if hrrr_increment < 1:
            # logging.critical('HRRR increment is %s h must be at least 1 h',hrrr_increment)
            raise ValueError(f'{atm_dict} increment is {hrrr_increment} h, must be at least 1 h')

        # build matrix of features - assuming all the same length, if not column_stack will fail
        train[key]['time'] = time_hrrr
        # logging.info(f"Created feature matrix train[{key}]['X'] shape {train[key]['X'].shape}")
        time_raws = str2time(subdict['RAWS']['time_raws'])   # may not be the same as HRRR
        # logging.info('%s RAWS.time_raws length is %s',key,len(time_raws))
        check_increment(time_raws, id=key + ' RAWS.time_raws')
        # print_first(time_raws,num=5,id='RAWS.time_raws')
        # collect feature columns, tracking any requested features that are not available
        columns = []
        missing_features = []
        for feat in features_list:
            # For atmospheric features,
            if feat in feature_types['atm']:
                if atm_dict == "HRRR":
                    vec = subdict['HRRR'][fstep][feat]
                    columns.append(vec)
                elif atm_dict == "RAWS":
                    if feat in subdict['RAWS'].keys():
                        vec = time_intp(time_raws, subdict['RAWS'][feat], time_hrrr)
                        columns.append(vec)
                    else:
                        missing_features.append(feat)
            # For static features, repeat to fit number of time observations
            elif feat in feature_types['static']:
                columns.append(np.full(hours, loc[feat]))
        # compute rain as difference of accumulated precipitation
        if 'rain' in features_list:
            rain = None
            if atm_dict == "HRRR":
                rain = subdict[atm_dict][fstep]['precip_accum'] - subdict[atm_dict][fprev]['precip_accum']
                # logging.info('%s rain as difference %s minus %s: min %s max %s',
                #              key,fstep,fprev,np.min(rain),np.max(rain))
            elif atm_dict == "RAWS":
                if 'rain' in subdict[atm_dict]:
                    rain = time_intp(time_raws, subdict[atm_dict]['rain'], time_hrrr)
                # else:
                #     logging.info('No rain data found in RAWS subdictionary %s', key)
            if rain is not None:
                columns.append(rain)   # add rain feature
            else:
                missing_features.append('rain')
        train[key]['X'] = np.column_stack(columns)
        train[key]['features_list'] = [item for item in features_list if item not in missing_features]
        # extract target fuel moisture from the RAWS
        fm = subdict['RAWS']['fm']
        # logging.info('%s RAWS.fm length is %s',key,len(fm))
        # interpolate RAWS sensors to HRRR time and over NaNs
        train[key]['y'] = time_intp(time_raws, fm, time_hrrr)
        # TODO: check endpoint interpolation when RAWS data sparse, and bail out if not enough data

        if train[key]['y'] is None:
            # logging.error('Cannot create target matrix for %s, using None',key)
            pass
        # logging.info(f"Created target matrix train[{key}]['y'] shape {train[key]['y'].shape}")

    # logging.info('Created a "train" dictionary with %s items',len(train))

    # Clean up cases where features or target could not be created
    keys_to_delete = []
    for key in train:
        if train[key]['X'] is None or train[key]['y'] is None:
            # logging.warning('Deleting training item %s because features X or target Y are None', key)
            keys_to_delete.append(key)

    # Delete the items from the dictionary
    if len(keys_to_delete) > 0:
        for key in keys_to_delete:
            del train[key]
        # logging.warning('Deleted %s items with None for data. %s items remain in the training dictionary.',
        #                 len(keys_to_delete),len(train))

    # if output_file_path is not None:
    #     with open(output_file_path, 'wb') as file:
    #         logging.info('Writing pickle dump of the dictionary train into file %s',output_file_path)
    #         pickle.dump(train, file)

    # logging.info('pkl2train done')

    return train
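# Example access pattern for the returned dictionary (sketch; the filename and case key
# "CPTC2_202401" are hypothetical):
#   train = build_train_dict("data/fmda_rocky_202401.pkl", forecast_step=3, atm="HRRR")
#   train["CPTC2_202401"]["X"].shape        # (hours, n_features) feature matrix
#   train["CPTC2_202401"]["y"]              # RAWS fuel moisture interpolated to HRRR times
#   train["CPTC2_202401"]["features_list"]  # features actually found and stacked into X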
def remove_key_list(d, ls, verbose=False):
    # Remove each key in list `ls` from dictionary `d` in place
    for key in ls:
        if key in d:
            if verbose:
                print(f"Removing key {key} due to data flags")
            del d[key]
def split_timeseries(dict0, hours, naming_convention="_set_", verbose=False):
    """
    Given a number of hours, splits a nested fmda dictionary into smaller stretches.
    Used primarily to aid in filtering out and removing stretches of missing data.
    """
    cases = list([*dict0.keys()])
    dict1 = {}
    for key, data in dict0.items():
        if verbose:
            print(f"Processing case: {key}")
            print(f"Length of y vector: {len(data['y'])}")
        if type(data['time'][0]) == str:
            time = str2time(data['time'])
        else:
            time = data['time']
        X_array = data['X']
        y_array = data['y']

        # Determine the start and end time for the hours-length portions
        start_time = time[0]
        end_time = time[-1]
        current_time = start_time
        portion_index = 1
        while current_time < end_time:
            next_time = current_time + timedelta(hours=hours)

            # Create a mask for the time range
            mask = (time >= current_time) & (time < next_time)

            # Apply the mask to extract the portions
            new_time = time[mask]
            new_X = X_array[mask]
            new_y = y_array[mask]

            # Save the portions in the new dictionary with naming convention if second or more portion
            if portion_index > 1:
                new_key = f"{key}{naming_convention}{portion_index}"
            else:
                new_key = key
            dict1[new_key] = {'time': new_time, 'X': new_X, 'y': new_y}

            # Add other keys that aren't subset
            for key2 in dict0[key].keys():
                if key2 not in ['time', 'X', 'y']:
                    dict1[new_key][key2] = dict0[key][key2]
            # Update Case name and Id (same for now, overloaded terminology)
            dict1[new_key]['case'] = new_key
            dict1[new_key]['id'] = new_key
            dict1[new_key]['hours'] = len(dict1[new_key]['y'])

            # Move to the next portion
            current_time = next_time
            portion_index += 1
        if verbose:
            print(f"Partitions of length {hours} from case {key}: {portion_index-1}")
    return dict1
def flag_lag_stretches(x, threshold, lag=1):
    """
    Identify stretches of data that have been interpolated over a length greater than or equal
    to the given threshold. Used to find stretches that are not trustworthy due to extensive
    interpolation and thus should be removed from a ML training set.
    """
    lags = np.round(np.diff(x, n=lag), 8)
    zero_lag_indices = np.where(lags == 0)[0]
    current_run_length = 1
    for i in range(1, len(zero_lag_indices)):
        if zero_lag_indices[i] == zero_lag_indices[i-1] + 1:
            current_run_length += 1
            if current_run_length > threshold:
                return True
        else:
            current_run_length = 1
    return False
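# Worked example (sketch): with x = [1., 1., 1., 1., 2., 3.] and lag=1 the first differences are
# [0, 0, 0, 1, 1], giving a run of 3 consecutive zero lags, so flag_lag_stretches(x, threshold=2)
# returns True while flag_lag_stretches(x, threshold=3) returns False. With lag=2, constant
# *linear* stretches (zero second differences) are flagged the same way.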
def flag_dict_keys(dict0, lag_1_threshold, lag_2_threshold, max_y, min_y, verbose=False):
    """
    Loop through a dictionary and generate a flag for each case, based on whether the `y` variable
    shows suspect lag patterns. The lag_1_threshold parameter sets the upper limit on the length of
    constant, zero-lag stretches in y. The lag_2_threshold parameter sets the upper limit on the
    length of constant, linear stretches. Used to identify cases that have been interpolated over
    excessively long periods and thus are not trustworthy for inclusion in a ML framework.
    """
    cases = list([*dict0.keys()])
    flags = np.zeros(len(cases))
    for i, case in enumerate(cases):
        y = dict0[case]['y']
        if verbose:
            print(f"Case: {case}")
        if flag_lag_stretches(y, threshold=lag_1_threshold, lag=1):
            flags[i] = 1
            print(f"Flagging case {case} for zero lag stretches greater than param {lag_1_threshold}")
        if flag_lag_stretches(y, threshold=lag_2_threshold, lag=2):
            flags[i] = 1
            print(f"Flagging case {case} for constant linear stretches greater than param {lag_2_threshold}")
        if np.any(y >= max_y) or np.any(y <= min_y):
            flags[i] = 1
            print(f"Flagging case {case} for FMC outside param range {min_y,max_y}. FMC range for {case}: {y.min(),y.max()}")
    return flags
def discard_keys_with_short_y(input_dict, hours, verbose=False):
    """
    Remove keys from a dictionary where the length of subkey `y` is less than the given hours.
    Used to remove partial sequences left at the end of a timeseries after the longer timeseries
    has been subdivided.
    """
    discarded_keys = [key for key, value in input_dict.items() if len(value['y']) < hours]
    if verbose and len(discarded_keys) > 0:
        print(f"Discarded keys due to y length less than {hours}: {discarded_keys}")
    filtered_dict = {key: value for key, value in input_dict.items() if key not in discarded_keys}
    return filtered_dict
# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combines input data dictionaries.

    Parameters:
    -----------
    nested_input_dict : dict
        Nested dictionary of fmda cases.
    verbose : bool, optional
        If True, prints status messages. Default is True.
    """
    # Setup return dictionary
    d = {}

    # Use the helper function to populate the keys
    d['id'] = _combine_key(nested_input_dict, 'id')
    d['case'] = _combine_key(nested_input_dict, 'case')
    d['filename'] = _combine_key(nested_input_dict, 'filename')
    d['time'] = _combine_key(nested_input_dict, 'time')
    d['X'] = _combine_key(nested_input_dict, 'X')
    d['y'] = _combine_key(nested_input_dict, 'y')

    # Build the loc subdictionary using _combine_key for each loc key
    d['loc'] = {
        'STID': _combine_key(nested_input_dict, 'loc', 'STID'),
        'lat': _combine_key(nested_input_dict, 'loc', 'lat'),
        'lon': _combine_key(nested_input_dict, 'loc', 'lon'),
        'elev': _combine_key(nested_input_dict, 'loc', 'elev'),
        'pixel_x': _combine_key(nested_input_dict, 'loc', 'pixel_x'),
        'pixel_y': _combine_key(nested_input_dict, 'loc', 'pixel_y')
    }

    # Handle features_list separately with validation
    features_list = _combine_key(nested_input_dict, 'features_list')
    first_features_list = features_list[0]
    for fl in features_list:
        if fl != first_features_list:
            warnings.warn("Different features_list found in the nested input dictionaries.")
    d['features_list'] = first_features_list

    return d
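# Example (sketch): given a processed dictionary `train` with several station cases,
#   combined = combine_nested(train)
#   combined['X']             # list of feature matrices, one per case
#   combined['y']             # list of target vectors, one per case
#   combined['loc']['STID']   # list of station IDs in the same order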
def _combine_key(nested_input_dict, key, subkey=None):
    # Helper to collect the value of `key` (or `key[subkey]`) from each case into a list
    combined_list = []
    for input_dict in nested_input_dict.values():
        if isinstance(input_dict, dict):
            try:
                if subkey:
                    combined_list.append(input_dict[key][subkey])
                else:
                    combined_list.append(input_dict[key])
            except KeyError:
                raise ValueError(f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries")
        else:
            raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
    return combined_list
def compare_dicts(dict1, dict2, keys):
    # Return True if dict1 and dict2 agree on all of the given keys
    for key in keys:
        if dict1.get(key) != dict2.get(key):
            return False
    return True
items = '_items_'     # dictionary key to keep list of items in

def check_data_array(dat, hours, a, s):
    if a in dat[items]:
        dat[items].remove(a)
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
        if hours is not None:
            if len(ar) != hours:
                print('len(%a) = %i does not equal to hours = %i' % (a, len(ar), hours))
    else:
        print(a + ' not present')
def check_data_scalar(dat, a):
    if a in dat[items]:
        dat[items].remove(a)
        print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    else:
        print(a + ' not present')
def check_data(dat, case=True, name=None):
    dat[items] = list(dat.keys())   # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat, 'filename')
        check_data_scalar(dat, 'title')
        check_data_scalar(dat, 'note')
        check_data_scalar(dat, 'hours')
        check_data_scalar(dat, 'h2')
        check_data_scalar(dat, 'case')
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        check_data_array(dat, hours, 'E', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ed', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ew', 'wetting equilibrium (%)')
        check_data_array(dat, hours, 'Ec', 'equilibrium (%)')
        check_data_array(dat, hours, 'rain', 'rain intensity (mm/h)')
        check_data_array(dat, hours, 'fm', 'RAWS fuel moisture data (%)')
        check_data_array(dat, hours, 'm', 'fuel moisture estimate (%)')
    print('items:', dat[items])
    for a in dat[items].copy():
        ar = dat[a]
        if dat[a] is None or np.isscalar(dat[a]):
            check_data_scalar(dat, a)
        elif is_numeric_ndarray(ar):
            print("array", a, "shape", ar.shape, "min", np.min(ar),
                  "max", np.max(ar), "hash", hash2(ar), "type", type(ar))
        elif isinstance(ar, tf.Tensor):
            print("array", a, "shape", ar.shape, "min", np.min(ar),
                  "max", np.max(ar), "type", type(ar))
        else:
            print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    del dat[items]   # clean up
# Note: the project structure has moved towards pickle files, so these json funcs might not be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    # dic: (dict) dictionary to write
    # filename: (str) output json filename, expect a ".json" file extension
    # Return: none

    print('writing ', filename)
    new = {}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i] = dic[i].tolist()   # because numpy.ndarray is not serializable
        else:
            new[i] = dic[i]
        # print('i',type(new[i]))
    new['filename'] = filename
    print('Hash: ', hash2(new))
    json.dump(new, open(filename, 'w'), indent=4)
def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string

    print('reading ', filename)
    dic = json.load(open(filename, 'r'))
    new = {}
    for i in dic:
        if type(dic[i]) is list:
            new[i] = np.array(dic[i])   # convert lists back to numpy.ndarray
        else:
            new[i] = dic[i]
    print('Hash: ', hash2(new))
    return new
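# Example round trip (sketch; "case.json" is a hypothetical filename):
#   to_json({'fm': np.array([10.2, 11.0]), 'title': 'demo'}, 'case.json')
#   dat = from_json('case.json')   # 'fm' comes back as a numpy array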
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    hours = days * 24
    h2 = int(hours / 2)           # training/prediction split point
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.

    # artificial equilibrium data
    E = 100.0 * np.power(np.sin(np.pi * day), power)   # diurnal curve
    E = E + DeltaE                # shift equilibrium by DeltaE
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1                  # initial FMC
    for t in range(hours - 1):
        m_f[t+1] = max(0., model_decay(m_f[t], E[t]) + random.gauss(0, process_noise))
    data = m_f + np.random.normal(loc=0, scale=data_noise, size=hours)

    return E, m_f, data, hour, h2, DeltaE
# The following returns a dictionary with all model data and variables
def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    hours = days * 24
    h2 = int(hours / 2)           # training/prediction split point
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi * day), power)   # diurnal curve between 0 and 1
    E = Emin + (Emax - Emin) * E
    E = E + DeltaE
    Ed = E + 0.5
    Ew = np.maximum(E - 0.5, 0)
    rain = np.multiply(rand(hours) < p_rain, rand(hours) * max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1                   # initial FMC
    for t in range(hours - 1):
        fm[t+1] = max(0., model_moisture(fm[t], Ed[t-1], Ew[t-1], rain[t-1]) + random.gauss(0, process_noise))
    fm = fm + np.random.normal(loc=0, scale=data_noise, size=hours)
    dat = {'E': E, 'Ew': Ew, 'Ed': Ed, 'fm': fm, 'hours': hours, 'h2': h2,
           'DeltaE': DeltaE, 'rain': rain, 'title': 'Synthetic data'}
    return dat
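# Example usage (sketch): generate a 10-day synthetic case and inspect it with check_data:
#   dat = synthetic_data(days=10, p_rain=0.05)
#   check_data(dat)   # prints scalars and array summaries for E, Ed, Ew, rain, fm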
def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data
    if name not in dat:
        return
    if hmin is None:
        hmin = 0
    if hmax is None:
        hmax = len(dat[name])
    hour = np.array(range(hmin, hmax))
    if type == 'plot':
        plt.plot(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)
    elif type == 'scatter':
        plt.scatter(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)
# Lookup table for plotting features
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha': .8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha': .8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha': .9, 'label': 'Rain'}
}

def plot_feature(x, y, feature_name):
    # Plot a single feature using its style from the lookup table above
    style = plot_styles.get(feature_name, {})
    plt.plot(x, y, **style)
def plot_features(hmin, hmax, dat, linestyle, c, label, alpha=1):
    hour = np.array(range(hmin, hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat)   # index of main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:, i][hmin:hmax], feature_name=feat)
def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None, hmin=0, hmax=None, xlabel=None, ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    # dat: FMDA dictionary
    # plot_period: str, one of 'all' or 'predict'
    # dat = copy.deepcopy(dat0)
    if hmax is None:
        hmax = dat['hours']
    else:
        hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16, 4))

    plot_one(hmin, hmax, dat, 'y', linestyle='-', c='#468a29', label='FM Observed')
    plot_one(hmin, hmax, dat, 'm', linestyle='-', c='k', label='FM Model')
    plot_features(hmin, hmax, dat, linestyle='-', c='k', label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim()                       # used to format annotations
        plot_y0 = np.max([hmin, test_ind])    # used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]), xytext=(plot_y0, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)', xy=((hmin+plot_y0)/2, yy[1]), xytext=((hmin+plot_y0)/2, yy[1]+1), ha='right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]), xytext=(hmax, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)', xy=(hmax-(hmax-test_ind)/2, yy[1]),
                     xytext=(hmax-(hmax-test_ind)/2, yy[1]+1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if title is not None:
        t = title
    else:
        t = dat['id']
        # print('title',type(t),t)
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat) + ')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if ylabel is None:
        if 'rain' in dat:
            plt.ylabel('FM (%) / Rain (mm/h)')
        else:
            plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")
def rmse(a, b):
    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))

def rmse_skip_nan(x, y):
    mask = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(mask):
        return np.sqrt(np.mean((x[mask] - y[mask]) ** 2))
    else:
        return np.nan

def rmse_str(a, b):
    rmse = rmse_skip_nan(a, b)
    return "RMSE " + "{:.3f}".format(rmse)
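# Example (sketch): NaNs are ignored pairwise, so
#   rmse_skip_nan(np.array([1., 2., np.nan]), np.array([1., 4., 5.]))
# compares only the first two entries and returns sqrt(((1-1)**2 + (2-4)**2)/2) = sqrt(2) ≈ 1.414.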
def rmse_data_str(dat, predict=True, hours=None, test_ind=None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    # dat: (dict) fmda dictionary
    # predict: (bool) Whether to return prediction period RMSE. Default True
    # hours: (int) total number of modeled time periods
    # test_ind: (int) start of test period
    # Return: (str) RMSE value

    if hours is None and 'hours' in dat:
        hours = dat['hours']
    if test_ind is None:
        if 'test_ind' in dat:
            test_ind = dat['test_ind']

    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'][test_ind:hours], dat['y'].flatten()[test_ind:hours])
        else:
            return rmse_str(dat['m'], dat['y'].flatten())
    else:
        return ""
# Calculate mean absolute error
def mae(a, b):
    return ((a - b).__abs__()).mean()
def rmse_data(dat, hours=None, h2=None, simulation='m', measurements='fm'):
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']
    case = dat.get('case', '')
    m = dat[simulation]
    fm = dat[measurements]

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    all = rmse(m[:hours], fm[:hours])
    print(case, 'Training 1 to', h2, 'hours RMSE: ' + str(np.round(train, 4)))
    print(case, 'Prediction', h2+1, 'to', hours, 'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train': train, 'predict': predict, 'all': all}
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        subprocess.call(f"wget -P {data_dir} {osp.join(base_url, filename)}", shell=True)
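# Example usage (sketch; the filename below is hypothetical):
#   get_file("fmda_nw_202401-05_f05.pkl")   # downloads into ./data/ if not already present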