## Set of Functions to process and format fuel moisture model inputs
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import numpy as np, random
from numpy.random import rand
import tensorflow as tf
import pickle, json, subprocess, warnings   # used by the I/O and validation functions below
import os.path as osp
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from moisture_models import model_decay, model_moisture
from datetime import datetime, timedelta
from utils import is_numeric_ndarray, hash2
from utils import Dict, str2time, check_increment, time_intp, read_pkl
# New Dict Functions as of 2024-10-2, needs more testing
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

feature_types = {
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew'],
    # Features that require calculation. NOTE: rain only calculated in HRRR, not RAWS
    'engineered': ['doy', 'hod', 'rain']
}
def check_feat(feat, d):
    # Raise an error if a required feature key is missing from the given dictionary
    if feat not in d:
        raise ValueError(f"Feature {feat} not found")
def get_time(d, atm="HRRR"):
    check_feat('time', d[atm])
    time = str2time(d[atm]['time'])
    return time
def get_static(d, hours):
    # Use all static vars, don't allow for missing
    names = feature_types['static']
    cols = []
    for feat in names:
        check_feat(feat, d['loc'])
        # Repeat the scalar static value to fill the time dimension
        cols.append(np.full(hours, d['loc'][feat]))
    return cols, names
def get_hrrr_atm(d, fstep):
    # Use all names, don't allow for missing
    names = feature_types['atm'].copy()
    cols = []
    for feat in names:
        check_feat(feat, d["HRRR"][fstep])
        v = d["HRRR"][fstep][feat]
        cols.append(v)
    return cols, names
def calc_time_features(time):
    names = ['doy', 'hod']
    doy = np.array([dt.timetuple().tm_yday - 1 for dt in time])
    hod = time.astype('datetime64[h]').astype(int) % 24
    cols = [doy, hod]
    return cols, names
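
# Example (illustrative sketch): day-of-year is zero-based and hour-of-day runs 0-23.
# The timestamps below are hypothetical; this assumes str2time returns a datetime vector
# compatible with both .timetuple() and .astype('datetime64[h]') as used above.
#   t = str2time(['2023-06-01T00:00:00Z', '2023-06-01T01:00:00Z'])
#   tvars, tnames = calc_time_features(t)
#   # tnames == ['doy', 'hod']; tvars[0] is day of year (0-based), tvars[1] is hour of day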
def calc_hrrr_rain(d, fstep, fprev):
    # Rain is the difference of accumulated precipitation between consecutive forecast steps
    rain = d["HRRR"][fstep]['precip_accum'] - d["HRRR"][fprev]['precip_accum']
    return rain
def get_raws_atm(d, time):
    # RAWS times may not be the same as requested time vector, used to interpolate to input time
    time_raws = str2time(d['RAWS']['time_raws'])
    cols = []
    names = []
    # Loop through all features, including rain
    for feat in feature_types['atm']+['rain']:
        if feat in d['RAWS'].keys():
            v = d['RAWS'][feat]
            v = time_intp(time_raws, v, time)
            assert len(v)==len(time), f"RAWS feature {feat} not equal length to input time: {len(v)} vs {len(time)}"
            cols.append(v)
            names.append(feat)
    return cols, names
def build_features_single(subdict, atm="HRRR", fstep=None, fprev=None):
    # Get time vector from the case subdictionary
    time = get_time(subdict)
    # Calculate derived time variables
    tvars, tnames = calc_time_features(time)
    # Get Static Features, extends to hours
    static_vars, static_names = get_static(subdict, hours = len(time))
    # Get atmospheric variables based on data source. HRRR requires rain calculation
    if atm == "HRRR":
        atm_vars, atm_names = get_hrrr_atm(subdict, fstep)
        rain = calc_hrrr_rain(subdict, fstep, fprev)
        atm_vars.append(rain)
        atm_names.append("rain")
    elif atm == "RAWS":
        atm_vars, atm_names = get_raws_atm(subdict, time)
    else:
        raise ValueError(f"Unrecognized atmospheric data source: {atm}")
    # Put everything together and stack
    cols = tvars + static_vars + atm_vars
    X = np.column_stack(cols)
    names = tnames + static_names + atm_names
    return X, names

def get_fm(d, time):
    # Interpolate RAWS fuel moisture observations to the input time vector
    fm = d['RAWS']['fm']
    time_raws = str2time(d['RAWS']['time_raws'])
    return time_intp(time_raws, fm, time)
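
# Example (illustrative; the pickle path, case key, and forecast steps are hypothetical):
#   d = read_pkl("data/fmda_cases.pkl")
#   case = list(d.keys())[0]
#   X, names = build_features_single(d[case], atm="HRRR", fstep="f01", fprev="f00")
#   y = get_fm(d[case], get_time(d[case]))
#   # X has shape (hours, len(names)); y is RAWS fuel moisture interpolated to the same hours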
def build_train_dict2(input_file_paths, params_data, spatial=True, atm_source="HRRR", forecast_step=1, verbose=True):
    # TODO: process involves multiple loops over dictionary keys. Inefficient, but functionality is isolated in separate functions that call for-loops, so it will be tedious to fix

    # Define Forecast Step: NOTE: if atm_source == RAWS, this should be zero
    if atm_source == "RAWS":
        print("Atmospheric data source is RAWS, so forecast_step set to zero")
        forecast_step = 0
        fstep = fprev = None
    elif forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep='f'+str(forecast_step).zfill(2)
        fprev='f'+str(forecast_step-1).zfill(2)
        # logging.info('Using data from step %s',fstep)
        # logging.info('Using rain as the difference of accumulated precipitation between %s and %s',fstep,fprev)
    else:
        # logging.critical('forecast_step must be integer between 1 and 99')
        raise ValueError('bad forecast_step')

    # Loop over input dictionary cases, extract and calculate features, then run data filters
    new_dict = {}
    for input_file_path in input_file_paths:
        if verbose:
            print(f"Extracting data from input file {input_file_path}")
        dict0 = read_pkl(input_file_path)
        for key in dict0:
            # print(f"Extracting data for case {key}")
            X, names = build_features_single(dict0[key], atm=atm_source, fstep=fstep, fprev=fprev)
            time = str2time(dict0[key]['HRRR']['time'])
            hrrr_increment = check_increment(time, id=key+' HRRR.time')
            if hrrr_increment < 1:
                # logging.critical('HRRR increment is %s h must be at least 1 h',hrrr_increment)
                raise ValueError(f'HRRR increment is {hrrr_increment} h, must be at least 1 h')
            new_dict[key] = {
                'id': key,
                'filename': input_file_path,
                'loc': dict0[key]['loc'],
                'time': time,
                'X': X,
                'y': get_fm(dict0[key], time),
                'features_list': names,
                'atm_source': atm_source
            }

    # Subset timeseries into shorter stretches, discard short ones
    if verbose:
        print(f"Splitting Timeseries into smaller portions to aid with data filtering. Input data param for max timeseries hours: {params_data['hours']}")
    new_dict = split_timeseries(new_dict, hours=params_data['hours'], verbose=verbose)
    new_dict = discard_keys_with_short_y(new_dict, hours=params_data['hours'], verbose=False)

    # Check for suspect data
    flags = flag_dict_keys(new_dict, params_data['zero_lag_threshold'], params_data['max_intp_time'], max_y = params_data['max_fm'], min_y = params_data['min_fm'], verbose=verbose)

    # Remove flagged cases
    cases = list([*new_dict.keys()])
    flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
    remove_key_list(new_dict, flagged_cases, verbose=verbose)

    # Optionally combine cases into a single spatial training dictionary
    if spatial:
        new_dict = combine_nested(new_dict)

    return Dict(new_dict)
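
# Example (illustrative sketch; file paths and parameter values are hypothetical, but the
# params_data keys shown are the ones read by build_train_dict2 above):
#   params_data = {
#       'hours': 720,                 # max length of each timeseries stretch
#       'zero_lag_threshold': 10,     # longest allowed constant stretch in y
#       'max_intp_time': 10,          # longest allowed linear (interpolated) stretch in y
#       'max_fm': 90., 'min_fm': 1.   # physically plausible FMC bounds (%)
#   }
#   train = build_train_dict2(["data/case1.pkl", "data/case2.pkl"], params_data,
#                             spatial=True, atm_source="HRRR", forecast_step=3)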
# Wrapper Functions to Put it all together
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# TODO: ENGINEERED TIME FEATURES:
# hod = rnn_dat.time.astype('datetime64[h]').astype(int) % 24
# doy = np.array([dt.timetuple().tm_yday - 1 for dt in rnn_dat.time])
def create_spatial_train(input_file_paths, params_data, atm_dict = "HRRR", verbose=False):
    train = process_train_dict(input_file_paths, params_data = params_data, verbose=verbose)
    train_sp = Dict(combine_nested(train))
    return train_sp
def process_train_dict(input_file_paths, params_data, atm_dict = "HRRR", spatial=False, verbose=False):
    if type(input_file_paths) is not list:
        raise ValueError(f"Argument `input_file_paths` must be list, received {type(input_file_paths)}")
    train = {}
    for file_path in input_file_paths:
        # Extract target and features
        di = build_train_dict(file_path, atm=atm_dict, features_all=params_data['features_all'], verbose=verbose)
        # Subset timeseries into shorter stretches
        di = split_timeseries(di, hours=params_data['hours'], verbose=verbose)
        di = discard_keys_with_short_y(di, hours=params_data['hours'], verbose=False)
        # Check for suspect data
        flags = flag_dict_keys(di, params_data['zero_lag_threshold'], params_data['max_intp_time'], max_y = params_data['max_fm'], min_y = params_data['min_fm'], verbose=verbose)
        # Remove flagged cases
        cases = list([*di.keys()])
        flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
        remove_key_list(di, flagged_cases, verbose=verbose)
        train.update(di)
    if spatial:
        train = combine_nested(train)
    return Dict(train)
def subset_by_features(nested_dict, input_features, verbose=True):
    """
    Subsets a nested dictionary to only include keys where all strings in the input_features
    are present in the dictionary's 'features_list' subkey. Primarily used for RAWS dictionaries
    where desired features might not be present at all ground stations.

    Parameters:
        nested_dict (dict): The nested dictionary with a 'features_list' subkey.
        input_features (list): The list of features to be checked.

    Returns:
        dict: A subset of the input dictionary with only the matching keys.
    """
    if verbose:
        print(f"Subsetting to cases with features: {input_features}")

    # Create a new dictionary to store the result
    result = {}

    # Iterate through the keys in the nested dictionary
    for key, value in nested_dict.items():
        # Check if 'features_list' key exists and all input_features are in the list
        if 'features_list' in value and all(feature in value['features_list'] for feature in input_features):
            # Add to the result if all features are present
            result[key] = value

    return result
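
# Example (illustrative; feature names are hypothetical and must appear in 'features_list'):
#   sub = subset_by_features(train, ['Ed', 'Ew', 'solar', 'wind'])
#   # Only cases whose 'features_list' contains every requested feature are kept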
# Older feature_types definition used by the legacy build_train_dict below
feature_types = {
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew']
}
def build_train_dict(input_file_path,
                     forecast_step=1, atm="HRRR",
                     features_all=['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'doy', 'hod', 'rain'], verbose=False):
    # in:
    #   file_path list of strings - files as in read_test_pkl
    #   forecast_step int - which forecast step to take atmospheric data from (maybe 03, must be >0).
    #   atm str - name of subdict where atmospheric vars are located
    #   features_list list of strings - names of keys in subdicts to collect into features matrix. Default is everything collected
    # return:
    #   train dictionary with structure
    #     {key : {'key' : key, # copied subdict key
    #             'loc' : {...}, # copied from in dict = {key : {'loc': ... }...}
    #             'time' : time, # datetime vector, spacing tres
    #             'X' : fm # target fuel moisture from the RAWS, interpolated to time
    #             'Y' : feat # features from atmosphere and location
    #     }}

    atm_dict = atm  # name of subdict with atmospheric data

    if 'rain' in features_all and (not features_all[-1]=='rain'):
        raise ValueError(f"Make rain in features list last element since (working on fix as of 24-6-24), given features list: {features_all}")

    if forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep='f'+str(forecast_step).zfill(2)
        fprev='f'+str(forecast_step-1).zfill(2)
        # logging.info('Using data from step %s',fstep)
        # logging.info('Using rain as the difference of accumulated precipitation between %s and %s',fstep,fprev)
    else:
        # logging.critical('forecast_step must be integer between 1 and 99')
        raise ValueError('bad forecast_step')

    train = {}
    with open(input_file_path, 'rb') as file:
        # logging.info("loading file %s", file_path)
        d = pickle.load(file)
    for key in d:
        features_list = features_all
        # logging.info('Processing subdictionary %s',key)
        if key in train:
            # logging.warning('skipping duplicate key %s',key)
            continue
        subdict = d[key]  # subdictionary for this case
        loc = subdict['loc']
        train[key] = {
            'id': key,  # store the key inside the dictionary, subdictionary will be used separately
            'case': key,
            'filename': input_file_path,
            'loc': loc
        }
        desc = 'descr'  # optional top-level description field, copied through if present
        if desc in subdict:
            train[desc] = subdict[desc]
        time_hrrr = str2time(subdict[atm_dict]['time'])
        # timekeeping
        hours = len(d[key][atm_dict]['time'])
        train[key]['hours'] = hours
        # train[key]['h2'] = hours  # not doing prediction yet
        hrrr_increment = check_increment(time_hrrr, id=key+f' {atm_dict}.time')
        # logging.info(f'{atm_dict} increment is %s h',hrrr_increment)
        if hrrr_increment < 1:
            # logging.critical('HRRR increment is %s h must be at least 1 h',hrrr_increment)
            raise ValueError(f'{atm_dict} increment is {hrrr_increment} h, must be at least 1 h')

        # build matrix of features - assuming all the same length, if not column_stack will fail
        train[key]['time'] = time_hrrr
        # logging.info(f"Created feature matrix train[{key}]['X'] shape {train[key]['X'].shape}")
        time_raws = str2time(subdict['RAWS']['time_raws'])  # may not be the same as HRRR
        # logging.info('%s RAWS.time_raws length is %s',key,len(time_raws))
        check_increment(time_raws, id=key+' RAWS.time_raws')
        # print_first(time_raws,num=5,id='RAWS.time_raws')

        # Collect feature columns, tracking any requested features not present in the data
        columns = []
        missing_features = []
        for feat in features_list:
            # For atmospheric features,
            if feat in feature_types['atm']:
                if atm_dict == "HRRR":
                    vec = subdict['HRRR'][fstep][feat]
                    columns.append(vec)
                elif atm_dict == "RAWS":
                    if feat in subdict['RAWS'].keys():
                        vec = time_intp(time_raws, subdict['RAWS'][feat], time_hrrr)
                        columns.append(vec)
                    else:
                        missing_features.append(feat)
            # For static features, repeat to fit number of time observations
            elif feat in feature_types['static']:
                columns.append(np.full(hours, loc[feat]))
        # Add Engineered Time features, doy and hod
        # hod = time_hrrr.astype('datetime64[h]').astype(int) % 24
        # doy = np.array([dt.timetuple().tm_yday - 1 for dt in time_hrrr])
        # columns.extend([doy, hod])

        # compute rain as difference of accumulated precipitation
        if 'rain' in features_list:
            if atm_dict == "HRRR":
                rain = subdict[atm_dict][fstep]['precip_accum'] - subdict[atm_dict][fprev]['precip_accum']
                # logging.info('%s rain as difference %s minus %s: min %s max %s',
                #              key,fstep,fprev,np.min(rain),np.max(rain))
                columns.append(rain)  # add rain feature
            elif atm_dict == "RAWS":
                if 'rain' in subdict[atm_dict]:
                    rain = time_intp(time_raws, subdict[atm_dict]['rain'], time_hrrr)
                    columns.append(rain)  # add rain feature
                else:
                    # logging.info('No rain data found in RAWS subdictionary %s', key)
                    missing_features.append('rain')

        train[key]['X'] = np.column_stack(columns)
        train[key]['features_list'] = [item for item in features_list if item not in missing_features]

        fm = subdict['RAWS']['fm']
        # logging.info('%s RAWS.fm length is %s',key,len(fm))
        # interpolate RAWS sensors to HRRR time and over NaNs
        train[key]['y'] = time_intp(time_raws, fm, time_hrrr)
        # TODO: check endpoint interpolation when RAWS data sparse, and bail out if not enough data

        if train[key]['y'] is None:
            pass
            # logging.error('Cannot create target matrix for %s, using None',key)
        # logging.info(f"Created target matrix train[{key}]['y'] shape {train[key]['y'].shape}")

    # logging.info('Created a "train" dictionary with %s items',len(train))

    # Find cases where features X or target y could not be built
    keys_to_delete = []
    for key in train:
        if train[key]['X'] is None or train[key]['y'] is None:
            # logging.warning('Deleting training item %s because features X or target Y are None', key)
            keys_to_delete.append(key)

    # Delete the items from the dictionary
    if len(keys_to_delete)>0:
        for key in keys_to_delete:
            del train[key]
        # logging.warning('Deleted %s items with None for data. %s items remain in the training dictionary.',
        #                 len(keys_to_delete),len(train))

    # if output_file_path is not None:
    #     with open(output_file_path, 'wb') as file:
    #         logging.info('Writing pickle dump of the dictionary train into file %s',output_file_path)
    #         pickle.dump(train, file)

    # logging.info('pkl2train done')

    return train
def remove_key_list(d, ls, verbose=False):
    # Remove each key in list `ls` from dictionary `d` in place
    for key in ls:
        if key in d:
            if verbose:
                print(f"Removing key {key} due to data flags")
            del d[key]
def split_timeseries(dict0, hours, naming_convention = "_set_", verbose=False):
    """
    Given number of hours, splits nested fmda dictionary into smaller stretches. This is used
    primarily to aid in filtering out and removing stretches of missing data.
    """
    cases = list([*dict0.keys()])
    dict1 = {}
    for key, data in dict0.items():
        if verbose:
            print(f"Processing case: {key}")
            print(f"Length of y vector: {len(data['y'])}")
        if type(data['time'][0]) == str:
            time = str2time(data['time'])
        else:
            time = data['time']
        X_array = data['X']
        y_array = data['y']

        # Determine the start and end time for the portions of length `hours`
        start_time = time[0]
        end_time = time[-1]
        current_time = start_time
        portion_index = 1
        while current_time < end_time:
            next_time = current_time + timedelta(hours=hours)

            # Create a mask for the time range
            mask = (time >= current_time) & (time < next_time)

            # Apply the mask to extract the portions
            new_time = time[mask]
            new_X = X_array[mask]
            new_y = y_array[mask]

            # Save the portions in the new dictionary with naming convention if second or more portion
            if portion_index > 1:
                new_key = f"{key}{naming_convention}{portion_index}"
            else:
                new_key = key
            dict1[new_key] = {'time': new_time, 'X': new_X, 'y': new_y}

            # Add other keys that aren't subset
            for key2 in dict0[key].keys():
                if key2 not in ['time', 'X', 'y']:
                    dict1[new_key][key2] = dict0[key][key2]
            # Update Case name and Id (same for now, overloaded terminology)
            dict1[new_key]['case'] = new_key
            dict1[new_key]['id'] = new_key
            dict1[new_key]['hours'] = len(dict1[new_key]['y'])

            # Move to the next portion
            current_time = next_time
            portion_index += 1
        if verbose:
            print(f"Partitions of length {hours} from case {key}: {portion_index-1}")
    return dict1
def flag_lag_stretches(x, threshold, lag = 1):
    """
    Used to identify stretches of data that have been interpolated a length greater than or equal
    to a given threshold. Such stretches are not trustworthy due to extensive interpolation and
    thus should be removed from a ML training set.
    """
    lags = np.round(np.diff(x, n=lag), 8)
    zero_lag_indices = np.where(lags == 0)[0]
    current_run_length = 1
    for i in range(1, len(zero_lag_indices)):
        if zero_lag_indices[i] == zero_lag_indices[i-1] + 1:
            current_run_length += 1
            if current_run_length > threshold:
                return True
        else:
            current_run_length = 1
    return False
def flag_dict_keys(dict0, lag_1_threshold, lag_2_threshold, max_y, min_y, verbose=False):
    """
    Loop through dictionary and generate a list of flags for whether the `y` variable within each case
    has suspect lag patterns. The lag_1_threshold parameter sets the upper limit for the length of
    constant, zero-lag stretches in y. The lag_2_threshold parameter sets the upper limit for the length
    of constant, linear stretches of data. Used to identify cases that have been interpolated over
    excessively long periods and thus are not trustworthy for inclusion in a ML framework.
    """
    cases = list([*dict0.keys()])
    flags = np.zeros(len(cases))
    for i, case in enumerate(cases):
        y = dict0[case]['y']
        if verbose:
            print(f"Case: {case}")
        if flag_lag_stretches(y, threshold=lag_1_threshold, lag=1):
            if verbose:
                print(f"Flagging case {case} for zero lag stretches greater than param {lag_1_threshold}")
            flags[i] = 1
        if flag_lag_stretches(y, threshold=lag_2_threshold, lag=2):
            if verbose:
                print(f"Flagging case {case} for constant linear stretches greater than param {lag_2_threshold}")
            flags[i] = 1
        if np.any(y>=max_y) or np.any(y<=min_y):
            if verbose:
                print(f"Flagging case {case} for FMC outside param range {min_y,max_y}. FMC range for {case}: {y.min(),y.max()}")
            flags[i] = 1
    return flags
def discard_keys_with_short_y(input_dict, hours, verbose=False):
    """
    Remove keys from a dictionary where the subkey `y` is shorter than the given hours. Used to remove
    partial sequences at the end of timeseries after the longer timeseries has been subdivided.
    """
    max_y = max(case['y'].shape[0] for case in input_dict.values())
    if max_y < hours:
        raise ValueError(f"Given input hours of {hours}, all cases in input dictionary would be filtered out as too short. Max timeseries y length: {max_y}. Try using a larger value for `hours` in data params")

    discarded_keys = [key for key, value in input_dict.items() if len(value['y']) < hours]

    if verbose and len(discarded_keys) > 0:
        print(f"Discarded keys due to y length less than {hours}: {discarded_keys}")

    filtered_dict = {key: value for key, value in input_dict.items() if key not in discarded_keys}

    return filtered_dict
# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combines input data dictionaries.

    Parameters:
    -----------
    nested_input_dict : dict
        Nested dictionary of cases, each with 'X', 'y', 'time', and metadata keys.
    verbose : bool, optional
        If True, prints status messages. Default is True.
    """
    # Setup return dictionary
    d = {}

    # Use the helper function to populate the keys
    d['id'] = _combine_key(nested_input_dict, 'id')
    d['case'] = _combine_key(nested_input_dict, 'case')
    d['filename'] = _combine_key(nested_input_dict, 'filename')
    d['time'] = _combine_key(nested_input_dict, 'time')
    d['X'] = _combine_key(nested_input_dict, 'X')
    d['y'] = _combine_key(nested_input_dict, 'y')
    d['atm_source'] = _combine_key(nested_input_dict, 'atm_source')

    # Build the loc subdictionary using _combine_key for each loc key
    d['loc'] = {
        'STID': _combine_key(nested_input_dict, 'loc', 'STID'),
        'lat': _combine_key(nested_input_dict, 'loc', 'lat'),
        'lon': _combine_key(nested_input_dict, 'loc', 'lon'),
        'elev': _combine_key(nested_input_dict, 'loc', 'elev'),
        'pixel_x': _combine_key(nested_input_dict, 'loc', 'pixel_x'),
        'pixel_y': _combine_key(nested_input_dict, 'loc', 'pixel_y')
    }

    # Handle features_list separately with validation
    features_list = _combine_key(nested_input_dict, 'features_list')
    if features_list:
        first_features_list = features_list[0]
        for fl in features_list:
            if fl != first_features_list:
                warnings.warn("Different features_list found in the nested input dictionaries.")
        d['features_list'] = first_features_list

    return d
def _combine_key(nested_input_dict, key, subkey=None):
    combined_list = []
    for input_dict in nested_input_dict.values():
        if isinstance(input_dict, dict):
            try:
                if subkey:
                    combined_list.append(input_dict[key][subkey])
                else:
                    combined_list.append(input_dict[key])
            except KeyError:
                warning_message = f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries. Setting value to None."
                warnings.warn(warning_message)
                combined_list.append(None)
        else:
            raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
    return combined_list
def compare_dicts(dict1, dict2, keys):
    # Check that two dictionaries agree on a given set of keys
    for key in keys:
        if dict1.get(key) != dict2.get(key):
            return False
    return True
items = '_items_' # dictionary key to keep list of items in

def check_data_array(dat,hours,a,s):
    # Print a summary of array `a` (described by string `s`) in dictionary `dat`
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a,s,len(ar),min(ar),max(ar),hash2(ar),type(ar)))
        if hours is not None:
            if len(ar) != hours:
                print('len(%a) = %i does not equal to hours = %i' % (a,len(ar),hours))
    else:
        print(a + ' not present')

def check_data_scalar(dat,a):
    # Print a scalar entry `a` of dictionary `dat` if present
    if a in dat:
        print('%s = %s' % (a,dat[a]),' ',type(dat[a]))
    else:
        print(a + ' not present' )
def check_data(dat,case=True,name=None):
    dat[items] = list(dat.keys()) # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat,'filename')
        check_data_scalar(dat,'title')
        check_data_scalar(dat,'note')
        check_data_scalar(dat,'hours')
        check_data_scalar(dat,'h2')
        check_data_scalar(dat,'case')
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        check_data_array(dat,hours,'E','drying equilibrium (%)')
        check_data_array(dat,hours,'Ed','drying equilibrium (%)')
        check_data_array(dat,hours,'Ew','wetting equilibrium (%)')
        check_data_array(dat,hours,'Ec','equilibrium (%)')
        check_data_array(dat,hours,'rain','rain intensity (mm/h)')
        check_data_array(dat,hours,'fm','RAWS fuel moisture data (%)')
        check_data_array(dat,hours,'m','fuel moisture estimate (%)')

    print('items:',dat[items])
    for a in dat[items].copy():
        ar = dat[a]
        if dat[a] is None or np.isscalar(dat[a]):
            check_data_scalar(dat,a)
        elif is_numeric_ndarray(ar):
            print("array", a, "shape",ar.shape,"min",np.min(ar),
                  "max",np.max(ar),"hash",hash2(ar),"type",type(ar))
        elif isinstance(ar, tf.Tensor):
            print("array", a, "shape",ar.shape,"min",np.min(ar),
                  "max",np.max(ar),"type",type(ar))
        else:
            print('%s = %s' % (a,dat[a]),' ',type(dat[a]))
    del dat[items] # clean up
# Note: the project structure has moved towards pickle files, so these json funcs might not be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic,filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    # dic: (dict) input dictionary
    # filename: (str) output json filename, expect a ".json" file extension
    # Return: none
    print('writing ',filename)
    new = {}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i]=dic[i].tolist() # because numpy.ndarray is not serializable
        else:
            new[i]=dic[i]
        # print('i',type(new[i]))
    new['filename']=filename
    print('Hash: ', hash2(new))
    json.dump(new,open(filename,'w'),indent=4)

def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string
    # Returns: dictionary with lists converted back to numpy arrays
    print('reading ',filename)
    dic=json.load(open(filename,'r'))
    new={}
    for i in dic:
        if type(dic[i]) is list:
            new[i]=np.array(dic[i]) # convert lists back, ndarray is not serializable
        else:
            new[i]=dic[i]
    print('Hash: ', hash2(new))
    return new
728 # Function to simulate moisture data and equilibrium for model testing
729 def create_synthetic_data(days
=20,power
=4,data_noise
=0.02,process_noise
=0.0,DeltaE
=0.0):
732 hour
= np
.array(range(hours
))
733 day
= np
.array(range(hours
))/24.
735 # artificial equilibrium data
736 E
= 100.0*np
.power(np
.sin(np
.pi
*day
),4) # diurnal curve
739 m_f
= np
.zeros(hours
)
740 m_f
[0] = 0.1 # initial FMC
742 for t
in range(hours
-1):
743 m_f
[t
+1] = max(0.,model_decay(m_f
[t
],E
[t
]) + random
.gauss(0,process_noise
) )
744 data
= m_f
+ np
.random
.normal(loc
=0,scale
=data_noise
,size
=hours
)
746 return E
,m_f
,data
,hour
,h2
,DeltaE
748 # the following input or output dictionary with all model data and variables
750 def synthetic_data(days
=20,power
=4,data_noise
=0.02,process_noise
=0.0,
751 DeltaE
=0.0,Emin
=5,Emax
=30,p_rain
=0.01,max_rain
=10.0):
754 hour
= np
.array(range(hours
))
755 day
= np
.array(range(hours
))/24.
756 # artificial equilibrium data
757 E
= np
.power(np
.sin(np
.pi
*day
),power
) # diurnal curve betwen 0 and 1
758 E
= Emin
+(Emax
- Emin
)*E
761 Ew
=np
.maximum(E
-0.5,0)
762 rain
= np
.multiply(rand(hours
) < p_rain
, rand(hours
)*max_rain
)
765 fm
[0] = 0.1 # initial FMC
767 for t
in range(hours
-1):
768 fm
[t
+1] = max(0.,model_moisture(fm
[t
],Ed
[t
-1],Ew
[t
-1],rain
[t
-1]) + random
.gauss(0,process_noise
))
769 fm
= fm
+ np
.random
.normal(loc
=0,scale
=data_noise
,size
=hours
)
770 dat
= {'E':E
,'Ew':Ew
,'Ed':Ed
,'fm':fm
,'hours':hours
,'h2':h2
,'DeltaE':DeltaE
,'rain':rain
,'title':'Synthetic data'}
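
# Example (illustrative): generate a 20-day synthetic case and summarize it.
#   dat = synthetic_data(days=20, p_rain=0.02)
#   check_data(dat)   # prints array summaries for E, Ed, Ew, rain, fm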
def plot_one(hmin,hmax,dat,name,linestyle,c,label, alpha=1,type='plot'):
    # helper for plot_data
    if name not in dat:
        return
    if hmax is None:
        hmax = len(dat[name])
    hour = np.array(range(hmin,hmax))
    if type=='plot':
        plt.plot(hour,dat[name][hmin:hmax],linestyle=linestyle,c=c,label=label, alpha=alpha)
    elif type=='scatter':
        plt.scatter(hour,dat[name][hmin:hmax],linestyle=linestyle,c=c,label=label, alpha=alpha)
# Lookup table for plotting features
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha':.8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha':.8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha':.9, 'label': 'Rain'}
}
def plot_feature(x, y, feature_name):
    style = plot_styles.get(feature_name, {})
    plt.plot(x, y, **style)
def plot_features(hmin,hmax,dat,linestyle,c,label,alpha=1):
    hour = np.array(range(hmin,hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat) # index of main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:,i][hmin:hmax], feature_name=feat)
def plot_data(dat, plot_period='all', create_figure=False,title=None,title2=None,hmin=0,hmax=None,xlabel=None,ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    # dat: FMDA dictionary
    # inverse_scale: logical, whether to inverse scale data
    # Returns: none
    # dat = copy.deepcopy(dat0)

    if hmax is None:
        hmax = dat['hours']
    else:
        hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16,4))

    plot_one(hmin,hmax,dat,'y',linestyle='-',c='#468a29',label='FM Observed')
    plot_one(hmin,hmax,dat,'m',linestyle='-',c='k',label='FM Model')
    plot_features(hmin,hmax,dat,linestyle='-',c='k',label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim() # used to format annotations
        plot_y0 = np.max([hmin, test_ind]) # Used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]),xytext=(plot_y0,yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)',xy=((hmin+plot_y0)/2,yy[1]),xytext=((hmin+plot_y0)/2,yy[1]+1), ha = 'right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]),xytext=(hmax,yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)',xy=(hmax-(hmax-test_ind)/2,yy[1]),
                     xytext=(hmax-(hmax-test_ind)/2,yy[1]+1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    if title is not None:
        t = title
    else:
        t = dat['title'] if 'title' in dat else ''
        # print('title',type(t),t)
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat)+')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    else:
        plt.ylabel('Fuel moisture content (%)')
    plt.legend(loc="upper left")
def rmse(a, b):
    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))
def rmse_skip_nan(x, y):
    mask = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(mask):
        return np.sqrt(np.mean((x[mask] - y[mask]) ** 2))
    else:
        return np.nan

def rmse_str(a, b):
    rmse = rmse_skip_nan(a,b)
    return "RMSE " + "{:.3f}".format(rmse)
def rmse_data_str(dat, predict=True, hours = None, test_ind = None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    # dat: (dict) fmda dictionary
    # predict: (bool) Whether to return prediction period RMSE. Default True
    # hours: (int) total number of modeled time periods
    # test_ind: (int) start of test period
    # Return: (str) RMSE value

    if hours is None and 'hours' in dat:
        hours = dat['hours']
    if 'test_ind' in dat:
        test_ind = dat['test_ind']

    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'][test_ind:hours],dat['y'].flatten()[test_ind:hours])
        else:
            return rmse_str(dat['m'],dat['y'].flatten())
    else:
        return ""
# Calculate mean absolute error
def mae(a, b):
    return ((a - b).__abs__()).mean()
def rmse_data(dat, hours = None, h2 = None, simulation='m', measurements='fm'):
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']
    case = dat['case'] if 'case' in dat else ''
    m = dat[simulation]
    fm = dat[measurements]

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    all = rmse(m[:hours], fm[:hours])
    print(case,'Training 1 to',h2,'hours RMSE: ' + str(np.round(train, 4)))
    print(case,'Prediction',h2+1,'to',hours,'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train':train, 'predict':predict, 'all':all}
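
# Example (illustrative; assumes `dat` contains observations 'fm' and a model run 'm'):
#   errs = rmse_data(dat)                 # prints and returns {'train':..., 'predict':..., 'all':...}
#   rmse_skip_nan(dat['m'], dat['fm'])    # RMSE ignoring hours where either series is NaN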
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        subprocess.call(f"wget -P {data_dir} {osp.join(base_url, filename)}", shell=True)
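
# Example (illustrative; the dictionary filename is hypothetical):
#   get_file("test_CA_202401.pkl")   # downloads into ./data/ from the OpenWFM demo server if needed
#   dat = read_pkl(osp.join("data", "test_CA_202401.pkl"))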