## Set of Functions to process and format fuel moisture model inputs
## These functions are specific to the particulars of the input data, and may not be generally applicable
## Generally applicable functions should be in utils.py
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import numpy as np, random
from numpy.random import rand
import tensorflow as tf
import json
import warnings
import subprocess
import os.path as osp
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from moisture_models import model_decay, model_moisture
from datetime import datetime, timedelta
from utils import is_numeric_ndarray, hash2
from utils import Dict, str2time, check_increment, time_intp, read_pkl

# New Dict Functions as of 2024-10-2, needs more testing
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
feature_types = Dict({
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew'],
    # Features that require calculation. NOTE: rain only calculated in HRRR, not RAWS
    'engineered': ['doy', 'hod', 'rain']
})

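# A sketch of how these groups are consumed downstream: build_train_dict shifts
# only the atmospheric columns when a forecast step is used, e.g.
#     atm_names = feature_types['atm'] + ['rain']
#     # -> ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew', 'rain']
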
def build_train_dict(input_file_paths, params_data, spatial=True, atm_source="HRRR",
                     forecast_step=3, verbose=True, features_subset=None, drop_na=True):
    # TODO: process involves multiple loops over dictionary keys. Inefficient, but functionality is isolated in separate functions that call for loops, so it will be tedious to fix

    # Define forecast step for atmospheric data
    # If atm_source is RAWS, the forecast hour is not used, AND spatial data must have a features_subset input to avoid incompatibility when building training data
    # For HRRR data, hard coding previous step as f00, meaning rain will be calculated as diff between forecast hour and 0th hour
    if atm_source == "RAWS":
        print("Atmospheric data source is RAWS, so forecast_step is not used")
        forecast_step = 0 # set to 0 for future compatibility
        fstep = fprev = None # not used for RAWS
        if spatial:
            assert features_subset is not None, "For RAWS atmospheric data as source for spatial training set, argument features_subset cannot be None. Provide a list of features to subset the RAWS locations, otherwise there will be errors when trying to build models with certain features."
    elif atm_source == "HRRR":
        fstep = int2fstep(forecast_step)
        if forecast_step > 0:
            fprev = int2fstep(forecast_step-1)
        else:
            fprev = 'f00'

    # Extract hours value from data params since it might need to change based on forecast hour time shift
    hours = params_data['hours']
    if forecast_step > 0 and drop_na and hours is not None:
        hours = int(hours - forecast_step)

    # Loop over input dictionary cases, extract and calculate features, then run data filters
    new_dict = {}
    for input_file_path in input_file_paths:
        print(f"Extracting data from input file {input_file_path}")
        dict0 = read_pkl(input_file_path)
        for key in dict0:
            # Extract features from subdicts
            X, names = build_features_single(dict0[key], atm=atm_source, fstep=fstep, fprev=fprev)
            # Get time from HRRR (bc always regular intervals) and interpolate RAWS data to those times
            time = str2time(dict0[key]['HRRR']['time'])
            hrrr_increment = check_increment(time, id=key+' HRRR.time')
            if hrrr_increment < 1:
                # logging.critical('HRRR increment is %s h must be at least 1 h',hrrr_increment)
                raise ValueError(f'HRRR increment is {hrrr_increment} h, must be at least 1 h')
            time_raws = str2time(dict0[key]['RAWS']['time_raws'])
            check_increment(time_raws, id=dict0[key]['loc']['STID']+' RAWS.time_raws')
            # Extract outcome data
            y = get_fm(dict0[key], time)
            # Shift atmospheric data in time if using a forecast step
            if forecast_step > 0:
                # Get indices to shift
                atm_names = feature_types['atm'] + ['rain']
                indices = [names.index(item) for item in atm_names if item in names]
                # Shift indices to future in time to line up forecast with desired time
                X = shift_time(X, indices, forecast_step)
                if drop_na:
                    print(f"Shifted time based on forecast step {forecast_step}. Dropping NA at beginning of feature data and corresponding times of output data")
                    X = X[forecast_step:, :]
                    y = y[forecast_step:]
                    time = time[forecast_step:]
            new_dict[key] = {
                'filename': input_file_path,
                'loc': dict0[key]['loc'],
                'time': time,
                'X': X,
                'y': y,
                'features_list': names,
                'atm_source': atm_source,
                'forecast_step': forecast_step
            }

    # Subset timeseries into shorter stretches, discard short ones
    if hours is not None:
        print(f"Splitting Timeseries into smaller portions to aid with data filtering. Input data param for max timeseries hours: {hours}")
        new_dict = split_timeseries(new_dict, hours=hours, verbose=verbose)
        new_dict = discard_keys_with_short_y(new_dict, hours=hours, verbose=False)

    # Check for suspect data
    flags = flag_dict_keys(new_dict, params_data['zero_lag_threshold'], params_data['max_intp_time'],
                           max_y=params_data['max_fm'], min_y=params_data['min_fm'], verbose=verbose)

    # Remove flagged cases
    cases = list(new_dict.keys())
    flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
    remove_key_list(new_dict, flagged_cases, verbose=verbose)

    if spatial:
        if atm_source == "HRRR":
            new_dict = combine_nested(new_dict)
        elif atm_source == "RAWS":
            new_dict = subset_by_features(new_dict, features_subset)
            new_dict = combine_nested(new_dict)

    return Dict(new_dict)

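# Minimal usage sketch (hypothetical pickle path; the params_data keys shown are
# the ones read above, typically loaded from a data params file):
#     params_data = Dict({'hours': 720, 'zero_lag_threshold': 10, 'max_intp_time': 10,
#                         'max_fm': 90.0, 'min_fm': 1.0})
#     train = build_train_dict(['data/fmda_cases.pkl'], params_data,
#                              spatial=True, atm_source='HRRR', forecast_step=3)
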
def int2fstep(forecast_step, max_hour=5):
    """
    Converts an integer forecast step into a formatted string with a prefix 'f'
    and zero-padded to two digits. Format of HRRR data forecast hours.

    Parameters:
    - forecast_step (int): The forecast step to convert. Must be an integer
      between 0 and max_hour (inclusive).
    - max_hour (int, optional): The maximum allowable forecast step.
      Depends on how many forecast hours were collected for input data. Default is 5.

    Returns:
    - str: A formatted string representing the forecast step, prefixed with 'f'
      and zero-padded to two digits (e.g., 'f01', 'f02').

    Raises:
    - TypeError: If forecast_step is not an integer.
    - ValueError: If forecast_step is not between 0 and max_hour (inclusive).
    """
    if not isinstance(forecast_step, int):
        raise TypeError("forecast_step must be an integer.")
    if not (0 <= forecast_step <= max_hour):
        raise ValueError(f"forecast_step must be between 0 and {max_hour}, the largest forecast step in input data.")
    fstep = 'f' + str(forecast_step).zfill(2)
    return fstep

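# e.g. int2fstep(0) -> 'f00', int2fstep(3) -> 'f03'
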
def check_feat(feat, d):
    if feat not in d:
        raise ValueError(f"Feature {feat} not found")

def get_time(d, atm="HRRR"):
    check_feat('time', d[atm])
    time = str2time(d[atm]['time'])
    return time

def get_static(d, hours):
    cols = []
    # Use all static vars, don't allow for missing
    names = feature_types['static']
    for feat in names:
        check_feat(feat, d['loc'])
        cols.append(np.full(hours, d['loc'][feat]))
    return cols, names

def get_hrrr_atm(d, fstep):
    cols = []
    # Use all names, don't allow for missing
    names = feature_types['atm'].copy()
    for feat in names:
        check_feat(feat, d["HRRR"][fstep])
        v = d["HRRR"][fstep][feat]
        cols.append(v)
    return cols, names

def calc_time_features(time):
    names = ['doy', 'hod']
    doy = np.array([dt.timetuple().tm_yday - 1 for dt in time])
    hod = time.astype('datetime64[h]').astype(int) % 24
    cols = [doy, hod]
    return cols, names

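# Sketch of expected behavior, assuming an hourly datetime array from str2time:
# for times starting 2023-06-01 00:00 UTC, doy is 151 (tm_yday 152, zero-indexed)
# and hod cycles 0, 1, ..., 23, 0, ...
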
def calc_hrrr_rain(d, fstep, fprev):
    # NOTE: if fstep and fprev are both f00, this returns all zeros, which is what the f00 HRRR accumulation always is. If fprev is not hard-coded as zero this might return nonsense, but not adding any checks yet for simplicity
    rain = d["HRRR"][fstep]['precip_accum'] - d["HRRR"][fprev]['precip_accum']
    return rain

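# e.g. with fstep='f03' and fprev='f02', rain is the precipitation accumulated
# during forecast hour 3, i.e. the difference of the two accumulations
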
def get_raws_atm(d, time, check=True):
    # RAWS times may not be the same as the requested time vector, used to interpolate to input time
    time_raws = str2time(d['RAWS']['time_raws'])

    cols = []
    names = []

    # Loop through all features, including rain
    for feat in feature_types['atm'] + ['rain']:
        if feat in d['RAWS']:
            v = d['RAWS'][feat]
            v = time_intp(time_raws, v, time)
            if check:
                assert len(v) == len(time), f"RAWS feature {feat} not equal length to input time: {len(v)} vs {len(time)}"
            cols.append(v)
            names.append(feat)
    return cols, names

def build_features_single(subdict, atm="HRRR", fstep=None, fprev=None):
    time = get_time(subdict)
    # Calculate derived time variables
    tvars, tnames = calc_time_features(time)
    # Get Static Features, extends to hours
    static_vars, static_names = get_static(subdict, hours=len(time))
    # Get atmospheric variables based on data source. HRRR requires rain calculation
    if atm == "HRRR":
        atm_vars, atm_names = get_hrrr_atm(subdict, fstep)
        rain = calc_hrrr_rain(subdict, fstep, fprev)
        atm_vars.append(rain)
        atm_names.append("rain")
    elif atm == "RAWS":
        atm_vars, atm_names = get_raws_atm(subdict, time)
    else:
        raise ValueError(f"Unrecognized atmospheric data source: {atm}")
    # Put everything together and stack
    cols = tvars + static_vars + atm_vars
    X = np.column_stack(cols)
    names = tnames + static_names + atm_names
    return X, names

def get_fm(d, time):
    fm = d['RAWS']['fm']
    time_raws = str2time(d['RAWS']['time_raws'])
    return time_intp(time_raws, fm, time)

def shift_time(X_array, inds, forecast_step):
    """
    Shifts specified columns of a numpy ndarray forward by a given number of steps.

    Parameters:
    -----------
    X_array : numpy.ndarray
        The input 2D array where specific columns will be shifted.
    inds : list of int
        Indices of the columns within X_array to be shifted.
    forecast_step : int
        The number of positions to shift the specified columns forward.

    Returns:
    --------
    numpy.ndarray
        A modified copy of the input array with specified columns shifted forward
        by `forecast_step` and the leading positions in those columns filled with NaN.

    Example:
    --------
    >>> X_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    >>> inds = [0, 2]
    >>> shift_time(X_array, inds, 1)
    array([[nan,  2., nan],
           [ 1.,  5.,  3.],
           [ 4.,  8.,  6.],
           [ 7., 11.,  9.]])
    """
    if not isinstance(forecast_step, int) or forecast_step <= 0:
        raise ValueError("forecast_step must be an integer greater than 0.")

    X_shift = X_array.astype(float).copy()
    X_shift[:forecast_step, inds] = np.nan
    X_shift[forecast_step:, inds] = X_array[:-forecast_step, inds]
    return X_shift

def subset_by_features(nested_dict, input_features, verbose=True):
    """
    Subsets a nested dictionary to only include keys where all strings in the input_features
    are present in the dictionary's 'features_list' subkey. Primarily used for RAWS dictionaries
    where desired features might not be present at all ground stations.

    Parameters:
    nested_dict (dict): The nested dictionary with a 'features_list' subkey.
    input_features (list): The list of features to be checked.

    Returns:
    dict: A subset of the input dictionary with only the matching keys.
    """
    if verbose:
        print(f"Subsetting to cases with features: {input_features}")

    # Create a new dictionary to store the result
    result = {}

    # Iterate through the keys in the nested dictionary
    for key, value in nested_dict.items():
        # Check if 'features_list' key exists and all input_features are in the list
        if 'features_list' in value and all(feature in value['features_list'] for feature in input_features):
            # Add to the result if all features are present
            result[key] = value
        elif verbose:
            print(f"Removing {key} due to missing features")

    return result

def remove_key_list(d, ls, verbose=False):
    for key in ls:
        if key in d:
            if verbose:
                print(f"Removing key {key} due to data flags")
            del d[key]

def split_timeseries(dict0, hours, naming_convention="_set_", verbose=False):
    """
    Given number of hours, splits nested fmda dictionary into smaller stretches. This is used primarily to aid in filtering out and removing stretches of missing data.
    """
    cases = list(dict0.keys())
    dict1 = {}
    for key, data in dict0.items():
        if verbose:
            print(f"Processing case: {key}")
            print(f"Length of y vector: {len(data['y'])}")
        if type(data['time'][0]) == str:
            time = str2time(data['time'])
        else:
            time = data['time']
        X_array = data['X']
        y_array = data['y']

        # Determine the start and end time for the portions of length `hours`
        start_time = time[0]
        end_time = time[-1]

        current_time = start_time
        portion_index = 1
        while current_time < end_time:
            next_time = current_time + timedelta(hours=hours)

            # Create a mask for the time range
            mask = (time >= current_time) & (time < next_time)

            # Apply the mask to extract the portions
            new_time = time[mask]
            new_X = X_array[mask]
            new_y = y_array[mask]

            # Save the portions in the new dictionary with naming convention if second or more portion
            if portion_index > 1:
                new_key = f"{key}{naming_convention}{portion_index}"
            else:
                new_key = key
            dict1[new_key] = {'time': new_time, 'X': new_X, 'y': new_y}

            # Add other keys that aren't subset
            for key2 in dict0[key].keys():
                if key2 not in ['time', 'X', 'y']:
                    dict1[new_key][key2] = dict0[key][key2]
            # Update Case name and Id (same for now, overloaded terminology)
            dict1[new_key]['case'] = new_key
            dict1[new_key]['id'] = new_key
            dict1[new_key]['hours'] = len(dict1[new_key]['y'])

            # Move to the next portion
            current_time = next_time
            portion_index += 1
        if verbose:
            print(f"Partitions of length {hours} from case {key}: {portion_index-1}")
    return dict1

def flag_lag_stretches(x, threshold, lag=1):
    """
    Used to identify stretches of data that have been interpolated a length greater than or equal to given threshold. Used to identify stretches of data that are not trustworthy due to extensive interpolation and thus should be removed from a ML training set.
    """
    lags = np.round(np.diff(x, n=lag), 8)
    zero_lag_indices = np.where(lags == 0)[0]
    current_run_length = 1
    for i in range(1, len(zero_lag_indices)):
        if zero_lag_indices[i] == zero_lag_indices[i-1] + 1:
            current_run_length += 1
            if current_run_length > threshold:
                return True
        else:
            current_run_length = 1
    return False

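# Example (sketch): a constant run longer than the threshold is flagged,
#     flag_lag_stretches(np.array([1., 1., 1., 1., 2.]), threshold=2, lag=1)  # True
#     flag_lag_stretches(np.array([1., 2., 1., 2., 1.]), threshold=2, lag=1)  # False
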
def flag_dict_keys(dict0, lag_1_threshold, lag_2_threshold, max_y, min_y, verbose=False):
    """
    Loop through dictionary and generate a list of flags for whether the `y` variable within the dictionary has target lag patterns. The lag_1_threshold parameter sets the upper limit for the length of constant, zero-lag stretches in y. The lag_2_threshold parameter sets the upper limit for the length of constant, linear stretches of data. Used to identify cases of data that have been interpolated for excessively long stretches and thus are not trustworthy for inclusion in a ML framework.
    """
    cases = list(dict0.keys())
    flags = np.zeros(len(cases))
    for i, case in enumerate(cases):
        if verbose:
            print(f"Case: {case}")
        y = dict0[case]['y']
        if flag_lag_stretches(y, threshold=lag_1_threshold, lag=1):
            if verbose:
                print(f"Flagging case {case} for zero lag stretches greater than param {lag_1_threshold}")
            flags[i] = 1
        if flag_lag_stretches(y, threshold=lag_2_threshold, lag=2):
            if verbose:
                print(f"Flagging case {case} for constant linear stretches greater than param {lag_2_threshold}")
            flags[i] = 1
        if np.any(y >= max_y) or np.any(y <= min_y):
            if verbose:
                print(f"Flagging case {case} for FMC outside param range {min_y, max_y}. FMC range for {case}: {y.min(), y.max()}")
            flags[i] = 1

    return flags

def discard_keys_with_short_y(input_dict, hours, verbose=False):
    """
    Remove keys from a dictionary where the subkey `y` is less than given hours. Used to remove partial sequences at the end of timeseries after the longer timeseries has been subdivided.
    """
    max_y = max(case['y'].shape[0] for case in input_dict.values())
    if max_y < hours:
        raise ValueError(f"Given input hours of {hours}, all cases in input dictionary would be filtered out as too short. Max timeseries y length: {max_y}. Try using a larger value for `hours` in data params")

    discarded_keys = [key for key, value in input_dict.items() if len(value['y']) < hours]

    if verbose:
        print(f"Discarded keys due to y length less than {hours}: {discarded_keys}")

    filtered_dict = {key: value for key, value in input_dict.items() if key not in discarded_keys}

    return filtered_dict

# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combines input data dictionaries.

    Parameters:
    -----------
    nested_input_dict : dict
        Nested dictionary of individual fmda case dictionaries.
    verbose : bool, optional
        If True, prints status messages. Default is True.
    """
    # Setup return dictionary
    d = {}
    # Use the helper function to populate the keys
    d['id'] = _combine_key(nested_input_dict, 'id')
    d['case'] = _combine_key(nested_input_dict, 'case')
    d['filename'] = _combine_key(nested_input_dict, 'filename')
    d['time'] = _combine_key(nested_input_dict, 'time')
    d['X'] = _combine_key(nested_input_dict, 'X')
    d['y'] = _combine_key(nested_input_dict, 'y')
    d['atm_source'] = _combine_key(nested_input_dict, 'atm_source')
    d['forecast_step'] = _combine_key(nested_input_dict, 'forecast_step')

    # Build the loc subdictionary using _combine_key for each loc key
    d['loc'] = {
        'STID': _combine_key(nested_input_dict, 'loc', 'STID'),
        'lat': _combine_key(nested_input_dict, 'loc', 'lat'),
        'lon': _combine_key(nested_input_dict, 'loc', 'lon'),
        'elev': _combine_key(nested_input_dict, 'loc', 'elev'),
        'pixel_x': _combine_key(nested_input_dict, 'loc', 'pixel_x'),
        'pixel_y': _combine_key(nested_input_dict, 'loc', 'pixel_y')
    }

    # Handle features_list separately with validation
    features_list = _combine_key(nested_input_dict, 'features_list')
    if features_list:
        first_features_list = features_list[0]
        for fl in features_list:
            if fl != first_features_list:
                warnings.warn("Different features_list found in the nested input dictionaries.")
        d['features_list'] = first_features_list

    return d

def _combine_key(nested_input_dict, key, subkey=None):
    combined_list = []
    for input_dict in nested_input_dict.values():
        if isinstance(input_dict, dict):
            try:
                if subkey:
                    combined_list.append(input_dict[key][subkey])
                else:
                    combined_list.append(input_dict[key])
            except KeyError:
                warning_message = f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries. Setting value to None."
                warnings.warn(warning_message)
                combined_list.append(None)
        else:
            raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
    return combined_list

def compare_dicts(dict1, dict2, keys):
    for key in keys:
        if dict1.get(key) != dict2.get(key):
            return False
    return True

items = '_items_' # dictionary key to keep list of items in
def check_data_array(dat, hours, a, s):
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
        if hours is not None:
            if len(ar) < hours:
                print('len(%s) = %i does not equal hours = %i' % (a, len(ar), hours))
    else:
        print(a + ' not present')

def check_data_scalar(dat, a):
    if a in dat:
        print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    else:
        print(a + ' not present')

def check_data(dat, case=True, name=None):
    dat[items] = list(dat.keys()) # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat, 'filename')
        check_data_scalar(dat, 'title')
        check_data_scalar(dat, 'note')
        check_data_scalar(dat, 'hours')
        check_data_scalar(dat, 'h2')
        check_data_scalar(dat, 'case')
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        check_data_array(dat, hours, 'E', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ed', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ew', 'wetting equilibrium (%)')
        check_data_array(dat, hours, 'Ec', 'equilibrium equilibrium (%)')
        check_data_array(dat, hours, 'rain', 'rain intensity (mm/h)')
        check_data_array(dat, hours, 'fm', 'RAWS fuel moisture data (%)')
        check_data_array(dat, hours, 'm', 'fuel moisture estimate (%)')
    else:
        print('items:', dat[items])
        for a in dat[items].copy():
            ar = dat[a]
            if dat[a] is None or np.isscalar(dat[a]):
                check_data_scalar(dat, a)
            elif is_numeric_ndarray(ar):
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "hash", hash2(ar), "type", type(ar))
            elif isinstance(ar, tf.Tensor):
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "type", type(ar))
            else:
                print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    del dat[items] # clean up

# Note: the project structure has moved towards pickle files, so these json funcs might not be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    # dic: (dict) dictionary to write
    # filename: (str) output json filename, expect a ".json" file extension
    # Return: none
    print('writing ', filename)
    new = {}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i] = dic[i].tolist() # because numpy.ndarray is not serializable
        else:
            new[i] = dic[i]
        # print('i',type(new[i]))
    new['filename'] = filename
    print('Hash: ', hash2(new))
    json.dump(new, open(filename, 'w'), indent=4)

def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string
    # Returns: (dict) json file read as python dictionary
    print('reading ', filename)
    dic = json.load(open(filename, 'r'))
    new = {}
    for i in dic:
        if type(dic[i]) is list:
            new[i] = np.array(dic[i]) # because ndarray is not serializable
        else:
            new[i] = dic[i]
    print('Hash: ', hash2(new))
    return new

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.

    # artificial equilibrium data
    E = 100.0 * np.power(np.sin(np.pi * day), power) # diurnal curve
    E = 0.05 + 0.25 * E
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1 # initial FMC
    for t in range(hours - 1):
        m_f[t+1] = max(0., model_decay(m_f[t], E[t]) + random.gauss(0, process_noise))
    data = m_f + np.random.normal(loc=0, scale=data_noise, size=hours)
    E = E + DeltaE
    return E, m_f, data, hour, h2, DeltaE

# the following creates input/output dictionary with all model data and variables
def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi * day), power) # diurnal curve between 0 and 1
    E = Emin + (Emax - Emin) * E
    E = E + DeltaE
    Ed = E + 0.5
    Ew = np.maximum(E - 0.5, 0)
    rain = np.multiply(rand(hours) < p_rain, rand(hours) * max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1 # initial FMC
    for t in range(hours - 1):
        fm[t+1] = max(0., model_moisture(fm[t], Ed[t-1], Ew[t-1], rain[t-1]) + random.gauss(0, process_noise))
    fm = fm + np.random.normal(loc=0, scale=data_noise, size=hours)
    dat = {'E': E, 'Ew': Ew, 'Ed': Ed, 'fm': fm, 'hours': hours, 'h2': h2, 'DeltaE': DeltaE, 'rain': rain, 'title': 'Synthetic data'}

    return dat

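# Usage sketch: generate a synthetic case and summarize it with check_data
#     dat = synthetic_data(days=10, p_rain=0.05)
#     check_data(dat)   # prints summaries of E, Ed, Ew, rain, fm
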
def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data
    if name in dat:
        hour = np.array(range(hmin, hmax))
        if type == 'plot':
            plt.plot(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)
        elif type == 'scatter':
            plt.scatter(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)

# Lookup table for plotting features
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha': .8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha': .8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha': .9, 'label': 'Rain'}
}

def plot_feature(x, y, feature_name):
    style = plot_styles.get(feature_name, {})
    plt.plot(x, y, **style)

def plot_features(hmin, hmax, dat, linestyle, c, label, alpha=1):
    hour = np.array(range(hmin, hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat) # index of main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:, i][hmin:hmax], feature_name=feat)

def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None, hmin=0, hmax=None, xlabel=None, ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    # dat: FMDA dictionary
    # Returns: none

    # dat = copy.deepcopy(dat0)

    if hmax is None:
        hmax = dat['hours']
    else:
        hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16, 4))

    plot_one(hmin, hmax, dat, 'y', linestyle='-', c='#468a29', label='FM Observed')
    plot_one(hmin, hmax, dat, 'm', linestyle='-', c='k', label='FM Model')
    plot_features(hmin, hmax, dat, linestyle='-', c='k', label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim() # used to format annotations
        plot_y0 = np.max([hmin, test_ind]) # Used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]), xytext=(plot_y0, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)', xy=((hmin+plot_y0)/2, yy[1]), xytext=((hmin+plot_y0)/2, yy[1]+1), ha='right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]), xytext=(hmax, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)', xy=(hmax-(hmax-test_ind)/2, yy[1]),
                     xytext=(hmax-(hmax-test_ind)/2, yy[1]+1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if title is not None:
        t = title
    else:
        t = dat.get('title', '')
        # print('title',type(t),t)
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat) + ')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    elif ylabel is None:
        plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")

def rmse(a, b):
    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))

def rmse_skip_nan(x, y):
    mask = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(mask):
        return np.sqrt(np.mean((x[mask] - y[mask]) ** 2))
    else:
        return np.nan

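# e.g. rmse_skip_nan(np.array([1., np.nan]), np.array([1., 2.]))  # -> 0.0, NaN pair skipped
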
def rmse_str(a, b):
    rmse = rmse_skip_nan(a, b)
    return "RMSE " + "{:.3f}".format(rmse)

def rmse_data_str(dat, predict=True, hours=None, test_ind=None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    # dat: (dict) fmda dictionary
    # predict: (bool) Whether to return prediction period RMSE. Default True
    # hours: (int) total number of modeled time periods
    # test_ind: (int) start of test period
    # Return: (str) RMSE value

    if hours is None:
        hours = dat.get('hours')
    if 'test_ind' in dat:
        test_ind = dat['test_ind']

    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'].flatten()[test_ind:hours], dat['y'].flatten()[test_ind:hours])
        else:
            return rmse_str(dat['m'].flatten(), dat['y'].flatten())
    else:
        return ''

# Calculate mean absolute error
def mae(a, b):
    return ((a - b).__abs__()).mean()

def rmse_data(dat, hours=None, h2=None, simulation='m', measurements='fm'):
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']
    case = dat.get('case', '')
    m = dat[simulation]
    fm = dat[measurements]

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    all = rmse(m[:hours], fm[:hours])
    print(case, 'Training 1 to', h2, 'hours RMSE:   ' + str(np.round(train, 4)))
    print(case, 'Prediction', h2+1, 'to', hours, 'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train': train, 'predict': predict, 'all': all}

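# Example (sketch): evaluate a stand-in model stored under dat['m']
#     dat = synthetic_data(days=10)
#     dat['m'] = dat['fm'].copy()   # placeholder 'model' for illustration
#     rmse_data(dat)                # -> {'train': 0.0, 'predict': 0.0, 'all': 0.0}
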
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        subprocess.call(f"wget -P {data_dir} {osp.join(base_url, filename)}", shell=True)