Update data_funcs.py
fmda/data_funcs.py
## Set of Functions to process and format fuel moisture model inputs
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import numpy as np, random
from numpy.random import rand
import tensorflow as tf
import pickle, os
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from moisture_models import model_decay, model_moisture
from datetime import datetime, timedelta
from utils import is_numeric_ndarray, hash2
import json
import copy
import subprocess
import os.path as osp
from utils import Dict, str2time, check_increment, time_intp, read_pkl
import warnings
# New Dict Functions as of 2024-10-2, needs more testing
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

feature_types = {
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew'],
    # Features that require calculation. NOTE: rain is only calculated for HRRR, not RAWS
    'engineered': ['doy', 'hod', 'rain']
}
def check_feat(feat, d):
    if feat not in d:
        raise ValueError(f"Feature {feat} not found")

def get_time(d, atm="HRRR"):
    check_feat('time', d[atm])
    time = str2time(d[atm]['time'])
    return time
def get_static(d, hours):
    cols = []
    # Use all static vars, don't allow for missing
    names = feature_types['static']
    for feat in names:
        check_feat(feat, d['loc'])
        cols.append(np.full(hours, d['loc'][feat]))
    return cols, names
def get_hrrr_atm(d, fstep):
    cols = []
    # Use all names, don't allow for missing
    names = feature_types['atm'].copy()
    for feat in names:
        check_feat(feat, d["HRRR"][fstep])
        v = d["HRRR"][fstep][feat]
        cols.append(v)
    return cols, names
def calc_time_features(time):
    names = ['doy', 'hod']
    doy = np.array([dt.timetuple().tm_yday - 1 for dt in time])
    hod = time.astype('datetime64[h]').astype(int) % 24
    cols = [doy, hod]
    return cols, names

def calc_hrrr_rain(d, fstep, fprev):
    # Rain is the difference of accumulated precipitation between consecutive forecast steps
    rain = d["HRRR"][fstep]['precip_accum'] - d["HRRR"][fprev]['precip_accum']
    return rain
def get_raws_atm(d, time):
    # RAWS times may not match the requested time vector, so observations are interpolated to the input time
    time_raws = str2time(d['RAWS']['time_raws'])

    cols = []
    names = []

    # Loop through all features, including rain
    for feat in feature_types['atm'] + ['rain']:
        if feat in d['RAWS']:
            v = d['RAWS'][feat]
            v = time_intp(time_raws, v, time)
            assert len(v) == len(time), f"RAWS feature {feat} not equal length to input time: {len(v)} vs {len(time)}"
            cols.append(v)
            names.append(feat)
    return cols, names
def build_features_single(subdict, atm="HRRR", fstep=None, fprev=None):
    # cols = []
    # names = []
    # Get time variable
    time = get_time(subdict)
    # Calculate derived time variables
    tvars, tnames = calc_time_features(time)
    # Get static features, extended to the number of hours
    static_vars, static_names = get_static(subdict, hours=len(time))
    # Get atmospheric variables based on data source. HRRR requires rain calculation
    if atm == "HRRR":
        atm_vars, atm_names = get_hrrr_atm(subdict, fstep)
        rain = calc_hrrr_rain(subdict, fstep, fprev)
        atm_vars.append(rain)
        atm_names.append("rain")
    elif atm == "RAWS":
        atm_vars, atm_names = get_raws_atm(subdict, time)
    else:
        raise ValueError(f"Unrecognized atmospheric data source: {atm}")
    # Put everything together and stack
    cols = tvars + static_vars + atm_vars
    X = np.column_stack(cols)
    names = tnames + static_names + atm_names

    return X, names
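
# Illustrative usage sketch (the pickle path and case key below are hypothetical examples,
# not files shipped with this repo):
#   d = read_pkl("data/test_CA_202401.pkl")
#   X, names = build_features_single(d["CA_case_1"], atm="HRRR", fstep="f03", fprev="f02")
#   # X has shape (n_hours, len(names)); names is e.g. ['doy', 'hod', 'elev', 'lon', 'lat', 'temp', ...]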
def get_fm(d, time):
    # Target fuel moisture from RAWS, interpolated to the given time vector
    fm = d['RAWS']['fm']
    time_raws = str2time(d['RAWS']['time_raws'])
    return time_intp(time_raws, fm, time)
def build_train_dict2(input_file_paths, params_data, spatial=True, atm_source="HRRR", forecast_step=1, verbose=True):
    # TODO: the process involves multiple loops over dictionary keys. Inefficient, but the
    # functionality is isolated in separate functions that each loop, so it will be tedious to fix

    # Define forecast step. NOTE: if atm_source == "RAWS", this should be zero
    if atm_source == "RAWS":
        print("Atmospheric data source is RAWS, so forecast_step set to zero")
        forecast_step = 0
        fstep = fprev = None  # not used for RAWS data
    elif forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep = 'f' + str(forecast_step).zfill(2)
        fprev = 'f' + str(forecast_step - 1).zfill(2)
        # logging.info('Using data from step %s', fstep)
        # logging.info('Using rain as the difference of accumulated precipitation between %s and %s', fstep, fprev)
    else:
        # logging.critical('forecast_step must be integer between 1 and 99')
        raise ValueError('forecast_step must be an integer between 1 and 99')

    # Loop over input dictionary cases, extract and calculate features, then run data filters
    new_dict = {}
    for input_file_path in input_file_paths:
        if verbose:
            print("~" * 75)
            print(f"Extracting data from input file {input_file_path}")
        dict0 = read_pkl(input_file_path)
        for key in dict0:
            # if verbose:
            #     print("~"*50)
            #     print(f"Extracting data for case {key}")
            X, names = build_features_single(dict0[key], atm=atm_source, fstep=fstep, fprev=fprev)
            time = str2time(dict0[key]['HRRR']['time'])
            hrrr_increment = check_increment(time, id=key + ' HRRR.time')
            if hrrr_increment < 1:
                # logging.critical('HRRR increment is %s h must be at least 1 h', hrrr_increment)
                raise ValueError(f"HRRR increment for case {key} is {hrrr_increment} h, must be at least 1 h")
            new_dict[key] = {
                'id': key,
                'case': key,
                'filename': input_file_path,
                'loc': dict0[key]['loc'],
                'time': time,
                'X': X,
                'y': get_fm(dict0[key], time),
                'features_list': names,
                'atm_source': atm_source
            }

    # Run data filters
    # Subset timeseries into shorter stretches, discard short ones
    if verbose:
        print("~" * 75)
        print(f"Splitting timeseries into smaller portions to aid with data filtering. Input data param for max timeseries hours: {params_data['hours']}")
    new_dict = split_timeseries(new_dict, hours=params_data['hours'], verbose=verbose)
    new_dict = discard_keys_with_short_y(new_dict, hours=params_data['hours'], verbose=False)

    # Check for suspect data
    flags = flag_dict_keys(new_dict, params_data['zero_lag_threshold'], params_data['max_intp_time'],
                           max_y=params_data['max_fm'], min_y=params_data['min_fm'], verbose=verbose)

    # Remove flagged cases
    cases = list(new_dict.keys())
    flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
    remove_key_list(new_dict, flagged_cases, verbose=verbose)

    if spatial:
        new_dict = combine_nested(new_dict)

    return Dict(new_dict)
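
# Illustrative usage sketch (the file path and params_data values below are made up for
# illustration only):
#   params_data = {'hours': 720, 'zero_lag_threshold': 10, 'max_intp_time': 10,
#                  'max_fm': 90., 'min_fm': 1.}
#   dat = build_train_dict2(["data/fmda_rocky_202403.pkl"], params_data,
#                           spatial=True, atm_source="HRRR", forecast_step=3)
#   # with spatial=True, dat['X'] and dat['y'] are lists with one entry per retained case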
# Wrapper Functions to Put it all together
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# TODO: ENGINEERED TIME FEATURES:
# hod = rnn_dat.time.astype('datetime64[h]').astype(int) % 24
# doy = np.array([dt.timetuple().tm_yday - 1 for dt in rnn_dat.time])
def create_spatial_train(input_file_paths, params_data, atm_dict="HRRR", verbose=False):
    train = process_train_dict(input_file_paths, params_data=params_data, atm_dict=atm_dict, verbose=verbose)
    train_sp = Dict(combine_nested(train))
    return train_sp

def process_train_dict(input_file_paths, params_data, atm_dict="HRRR", spatial=False, verbose=False):
    if type(input_file_paths) is not list:
        raise ValueError(f"Argument `input_file_paths` must be list, received {type(input_file_paths)}")
    train = {}
    for file_path in input_file_paths:
        # Extract target and features
        di = build_train_dict(file_path, atm=atm_dict, features_all=params_data['features_all'], verbose=verbose)
        # Subset timeseries into shorter stretches
        di = split_timeseries(di, hours=params_data['hours'], verbose=verbose)
        di = discard_keys_with_short_y(di, hours=params_data['hours'], verbose=False)
        # Check for suspect data
        flags = flag_dict_keys(di, params_data['zero_lag_threshold'], params_data['max_intp_time'],
                               max_y=params_data['max_fm'], min_y=params_data['min_fm'], verbose=verbose)
        # Remove flagged cases
        cases = list(di.keys())
        flagged_cases = [element for element, flag in zip(cases, flags) if flag == 1]
        remove_key_list(di, flagged_cases, verbose=verbose)
        train.update(di)
    if spatial:
        train = combine_nested(train)

    return Dict(train)
def subset_by_features(nested_dict, input_features, verbose=True):
    """
    Subsets a nested dictionary to only include keys where all strings in input_features
    are present in the dictionary's 'features_list' subkey. Primarily used for RAWS dictionaries,
    where desired features might not be present at all ground stations.

    Parameters:
        nested_dict (dict): The nested dictionary with a 'features_list' subkey.
        input_features (list): The list of features to be checked.

    Returns:
        dict: A subset of the input dictionary with only the matching keys.
    """
    if verbose:
        print(f"Subsetting to cases with features: {input_features}")

    # Create a new dictionary to store the result
    result = {}

    # Iterate through the keys in the nested dictionary
    for key, value in nested_dict.items():
        # Check if 'features_list' key exists and all input_features are in the list
        if 'features_list' in value and all(feature in value['features_list'] for feature in input_features):
            # Add to the result if all features are present
            result[key] = value

    return result
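
# Illustrative usage sketch (the feature names are just examples):
#   raws_cases = subset_by_features(train, ['Ed', 'Ew', 'solar', 'wind'], verbose=True)
#   # keeps only the cases whose 'features_list' contains every requested feature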
# NOTE: redefines the module-level feature_types above (without the 'engineered' key);
# kept for the legacy build_train_dict below
feature_types = {
    # Static features are based on physical location, e.g. location of RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew']
}
def build_train_dict(input_file_path,
                     forecast_step=1, atm="HRRR",
                     features_all=['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'doy', 'hod', 'rain'],
                     verbose=False):
    # in:
    #   input_file_path  str - path to input pickle file, as in read_test_pkl
    #   forecast_step    int - which forecast step to take atmospheric data from (maybe 03, must be >0)
    #   atm              str - name of subdict where atmospheric vars are located
    #   features_all     list of strings - names of keys in subdicts to collect into features matrix.
    #                    Default is everything collected
    # return:
    #   train dictionary with structure
    #   {key : {'key' : key,    # copied subdict key
    #           'loc' : {...},  # copied from input dict = {key : {'loc': ... }...}
    #           'time' : time,  # datetime vector, spacing tres
    #           'X' : feat,     # features from atmosphere and location
    #           'y' : fm        # target fuel moisture from the RAWS, interpolated to time
    #           }
    #   }

    # TODO: fix this
    if 'rain' in features_all and (not features_all[-1] == 'rain'):
        raise ValueError(f"Make rain the last element of the features list (working on fix as of 24-6-24), given features list: {features_all}")

    if forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep = 'f' + str(forecast_step).zfill(2)
        fprev = 'f' + str(forecast_step - 1).zfill(2)
        # logging.info('Using data from step %s', fstep)
        # logging.info('Using rain as the difference of accumulated precipitation between %s and %s', fstep, fprev)
    else:
        # logging.critical('forecast_step must be integer between 1 and 99')
        raise ValueError('forecast_step must be an integer between 1 and 99')

    train = {}
    with open(input_file_path, 'rb') as file:
        # logging.info("loading file %s", input_file_path)
        d = pickle.load(file)
    for key in d:
        atm_dict = atm
        features_list = features_all
        # logging.info('Processing subdictionary %s', key)
        if key in train:
            pass
            # logging.warning('skipping duplicate key %s', key)
        else:
            subdict = d[key]  # subdictionary for this case
            loc = subdict['loc']
            train[key] = {
                'id': key,  # store the key inside the dictionary, subdictionary will be used separately
                'case': key,
                'filename': input_file_path,
                'loc': loc
            }
            desc = 'descr'
            if desc in subdict:
                train[key][desc] = subdict[desc]
            time_hrrr = str2time(subdict[atm_dict]['time'])
            # timekeeping
            hours = len(d[key][atm_dict]['time'])
            train[key]['hours'] = hours
            # train[key]['h2'] = hours   # not doing prediction yet
            hrrr_increment = check_increment(time_hrrr, id=key + f' {atm_dict}.time')
            # logging.info(f'{atm_dict} increment is %s h', hrrr_increment)
            if hrrr_increment < 1:
                # logging.critical('HRRR increment is %s h must be at least 1 h', hrrr_increment)
                raise ValueError(f"{atm_dict} increment for case {key} is {hrrr_increment} h, must be at least 1 h")

            # build matrix of features - assuming all the same length, if not column_stack will fail
            train[key]['time'] = time_hrrr
            # logging.info(f"Created feature matrix train[{key}]['X'] shape {train[key]['X'].shape}")
            time_raws = str2time(subdict['RAWS']['time_raws'])  # may not be the same as HRRR
            # logging.info('%s RAWS.time_raws length is %s', key, len(time_raws))
            check_increment(time_raws, id=key + ' RAWS.time_raws')
            # print_first(time_raws, num=5, id='RAWS.time_raws')

            # Set up static vars
            columns = []
            missing_features = []
            for feat in features_list:
                # For atmospheric features
                if feat in feature_types['atm']:
                    if atm_dict == "HRRR":
                        vec = subdict['HRRR'][fstep][feat]
                        columns.append(vec)
                    elif atm_dict == "RAWS":
                        if feat in subdict['RAWS'].keys():
                            vec = time_intp(time_raws, subdict['RAWS'][feat], time_hrrr)
                            columns.append(vec)
                        else:
                            missing_features.append(feat)
                # For static features, repeat to fit number of time observations
                elif feat in feature_types['static']:
                    columns.append(np.full(hours, loc[feat]))
            # Add Engineered Time features, doy and hod
            # hod = time_hrrr.astype('datetime64[h]').astype(int) % 24
            # doy = np.array([dt.timetuple().tm_yday - 1 for dt in time_hrrr])
            # columns.extend([doy, hod])

            # compute rain as difference of accumulated precipitation
            if 'rain' in features_list:
                if atm_dict == "HRRR":
                    rain = subdict[atm_dict][fstep]['precip_accum'] - subdict[atm_dict][fprev]['precip_accum']
                    # logging.info('%s rain as difference %s minus %s: min %s max %s',
                    #              key, fstep, fprev, np.min(rain), np.max(rain))
                    columns.append(rain)  # add rain feature
                elif atm_dict == "RAWS":
                    if 'rain' in subdict[atm_dict]:
                        rain = time_intp(time_raws, subdict[atm_dict]['rain'], time_hrrr)
                        columns.append(rain)  # add rain feature
                    else:
                        # logging.info('No rain data found in RAWS subdictionary %s', key)
                        missing_features.append('rain')
            else:
                missing_features.append('rain')

            train[key]['X'] = np.column_stack(columns)
            train[key]['features_list'] = [item for item in features_list if item not in missing_features]

            fm = subdict['RAWS']['fm']
            # logging.info('%s RAWS.fm length is %s', key, len(fm))
            # interpolate RAWS sensors to HRRR time and over NaNs
            train[key]['y'] = time_intp(time_raws, fm, time_hrrr)
            # TODO: check endpoint interpolation when RAWS data sparse, and bail out if not enough data

            if train[key]['y'] is None:
                pass
                # logging.error('Cannot create target matrix for %s, using None', key)
            else:
                pass
                # logging.info(f"Created target matrix train[{key}]['y'] shape {train[key]['y'].shape}")

    # logging.info('Created a "train" dictionary with %s items', len(train))

    # clean up
    keys_to_delete = []
    for key in train:
        if train[key]['X'] is None or train[key]['y'] is None:
            # logging.warning('Deleting training item %s because features X or target y are None', key)
            keys_to_delete.append(key)

    # Delete the items from the dictionary
    if len(keys_to_delete) > 0:
        for key in keys_to_delete:
            del train[key]
        # logging.warning('Deleted %s items with None for data. %s items remain in the training dictionary.',
        #                 len(keys_to_delete), len(train))

    # output
    # if output_file_path is not None:
    #     with open(output_file_path, 'wb') as file:
    #         logging.info('Writing pickle dump of the dictionary train into file %s', output_file_path)
    #         pickle.dump(train, file)

    # logging.info('pkl2train done')

    return train
def remove_key_list(d, ls, verbose=False):
    for key in ls:
        if key in d:
            if verbose:
                print(f"Removing key {key} due to data flags")
            del d[key]
def split_timeseries(dict0, hours, naming_convention="_set_", verbose=False):
    """
    Given a number of hours, splits a nested fmda dictionary into smaller stretches.
    Used primarily to aid in filtering out and removing stretches of missing data.
    """
    dict1 = {}
    for key, data in dict0.items():
        if verbose:
            print(f"Processing case: {key}")
            print(f"Length of y vector: {len(data['y'])}")
        if type(data['time'][0]) == str:
            time = str2time(data['time'])
        else:
            time = data['time']
        X_array = data['X']
        y_array = data['y']

        # Determine the start and end time for the portions of length `hours`
        start_time = time[0]
        end_time = time[-1]
        current_time = start_time
        portion_index = 1
        while current_time < end_time:
            next_time = current_time + timedelta(hours=hours)

            # Create a mask for the time range
            mask = (time >= current_time) & (time < next_time)

            # Apply the mask to extract the portions
            new_time = time[mask]
            new_X = X_array[mask]
            new_y = y_array[mask]

            # Save the portion in the new dictionary, using the naming convention for the second and later portions
            if portion_index > 1:
                new_key = f"{key}{naming_convention}{portion_index}"
            else:
                new_key = key
            dict1[new_key] = {'time': new_time, 'X': new_X, 'y': new_y}

            # Add other keys that aren't subset
            for key2 in dict0[key].keys():
                if key2 not in ['time', 'X', 'y']:
                    dict1[new_key][key2] = dict0[key][key2]
            # Update case name and id (same for now, overloaded terminology)
            dict1[new_key]['case'] = new_key
            dict1[new_key]['id'] = new_key
            dict1[new_key]['hours'] = len(dict1[new_key]['y'])

            # Move to the next portion
            current_time = next_time
            portion_index += 1
        if verbose:
            print(f"Partitions of length {hours} from case {key}: {portion_index - 1}")
    return dict1
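
# Illustrative sketch (assumes `train` is a nested fmda dict as built above; the case name and
# lengths are hypothetical):
#   train = split_timeseries(train, hours=720)
#   # a 2000-hour case "CA_case_1" becomes "CA_case_1", "CA_case_1_set_2", "CA_case_1_set_3",
#   # and the short final remainder is later removed by discard_keys_with_short_y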
def flag_lag_stretches(x, threshold, lag=1):
    """
    Used to identify stretches of data that have been interpolated for a length greater than or
    equal to the given threshold. Such stretches are not trustworthy due to extensive
    interpolation and thus should be removed from a ML training set.
    """
    lags = np.round(np.diff(x, n=lag), 8)
    zero_lag_indices = np.where(lags == 0)[0]
    current_run_length = 1
    for i in range(1, len(zero_lag_indices)):
        if zero_lag_indices[i] == zero_lag_indices[i - 1] + 1:
            current_run_length += 1
            if current_run_length > threshold:
                return True
        else:
            current_run_length = 1
    return False
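
# Toy example of the flagging logic (illustrative only):
#   flag_lag_stretches(np.array([1., 2., 2., 2., 2., 3.]), threshold=2, lag=1)  # True: run of 3 zero diffs
#   flag_lag_stretches(np.array([1., 2., 2., 3., 3., 4.]), threshold=2, lag=1)  # False: zero diffs not consecutive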
def flag_dict_keys(dict0, lag_1_threshold, lag_2_threshold, max_y, min_y, verbose=False):
    """
    Loop through the dictionary and generate a list of flags for whether the `y` variable of each case
    has suspect lag patterns. The lag_1_threshold parameter sets an upper limit on the length of constant,
    zero-lag stretches in y. The lag_2_threshold parameter sets an upper limit on the length of constant,
    linear stretches. Used to identify cases where data have been interpolated over excessively long
    periods and thus are not trustworthy for inclusion in a ML framework.
    """
    cases = list(dict0.keys())
    flags = np.zeros(len(cases))
    for i, case in enumerate(cases):
        if verbose:
            print("~" * 50)
            print(f"Case: {case}")
        y = dict0[case]['y']
        if flag_lag_stretches(y, threshold=lag_1_threshold, lag=1):
            if verbose:
                print(f"Flagging case {case} for zero lag stretches greater than param {lag_1_threshold}")
            flags[i] = 1
        if flag_lag_stretches(y, threshold=lag_2_threshold, lag=2):
            if verbose:
                print(f"Flagging case {case} for constant linear stretches greater than param {lag_2_threshold}")
            flags[i] = 1
        if np.any(y >= max_y) or np.any(y <= min_y):
            if verbose:
                print(f"Flagging case {case} for FMC outside param range {min_y, max_y}. FMC range for {case}: {y.min(), y.max()}")
            flags[i] = 1

    return flags
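
# Illustrative sketch of the filtering step (threshold values are made up; in build_train_dict2
# they come from params_data):
#   flags = flag_dict_keys(train, lag_1_threshold=10, lag_2_threshold=10, max_y=90., min_y=1.)
#   flagged = [case for case, flag in zip(train.keys(), flags) if flag == 1]
#   remove_key_list(train, flagged, verbose=True)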
def discard_keys_with_short_y(input_dict, hours, verbose=False):
    """
    Remove keys from a dictionary where the subkey `y` is shorter than the given number of hours.
    Used to remove partial sequences at the end of a timeseries after the longer timeseries has been subdivided.
    """
    max_len = max(case['y'].shape[0] for case in input_dict.values())
    if max_len < hours:
        raise ValueError(f"Given input hours of {hours}, all cases in input dictionary would be filtered out as too short. Max timeseries y length: {max_len}. Try using a larger value for `hours` in data params")

    discarded_keys = [key for key, value in input_dict.items() if len(value['y']) < hours]

    if verbose:
        print(f"Discarded keys due to y length less than {hours}: {discarded_keys}")

    filtered_dict = {key: value for key, value in input_dict.items() if key not in discarded_keys}

    return filtered_dict
# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combines input data dictionaries.

    Parameters:
    -----------
    nested_input_dict : dict
        Nested dictionary of fmda cases, one subdictionary per case.
    verbose : bool, optional
        If True, prints status messages. Default is True.
    """
    # Setup return dictionary
    d = {}
    # Use the helper function to populate the keys
    d['id'] = _combine_key(nested_input_dict, 'id')
    d['case'] = _combine_key(nested_input_dict, 'case')
    d['filename'] = _combine_key(nested_input_dict, 'filename')
    d['time'] = _combine_key(nested_input_dict, 'time')
    d['X'] = _combine_key(nested_input_dict, 'X')
    d['y'] = _combine_key(nested_input_dict, 'y')
    d['atm_source'] = _combine_key(nested_input_dict, 'atm_source')

    # Build the loc subdictionary using _combine_key for each loc key
    d['loc'] = {
        'STID': _combine_key(nested_input_dict, 'loc', 'STID'),
        'lat': _combine_key(nested_input_dict, 'loc', 'lat'),
        'lon': _combine_key(nested_input_dict, 'loc', 'lon'),
        'elev': _combine_key(nested_input_dict, 'loc', 'elev'),
        'pixel_x': _combine_key(nested_input_dict, 'loc', 'pixel_x'),
        'pixel_y': _combine_key(nested_input_dict, 'loc', 'pixel_y')
    }

    # Handle features_list separately with validation
    features_list = _combine_key(nested_input_dict, 'features_list')
    if features_list:
        first_features_list = features_list[0]
        for fl in features_list:
            if fl != first_features_list:
                warnings.warn("Different features_list found in the nested input dictionaries.")
        d['features_list'] = first_features_list

    return d

def _combine_key(nested_input_dict, key, subkey=None):
    combined_list = []
    for input_dict in nested_input_dict.values():
        if isinstance(input_dict, dict):
            try:
                if subkey:
                    combined_list.append(input_dict[key][subkey])
                else:
                    combined_list.append(input_dict[key])
            except KeyError:
                warning_message = f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries. Setting value to None."
                warnings.warn(warning_message)
                combined_list.append(None)
        else:
            raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
    return combined_list
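
# Illustrative sketch of the combined structure (case names are hypothetical):
#   combined = combine_nested({'case1': train['case1'], 'case2': train['case2']})
#   # combined['X'] is a list of two feature matrices, combined['y'] a list of two target vectors,
#   # combined['loc']['STID'] a list of two station IDs, and so on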
def compare_dicts(dict1, dict2, keys):
    # Return True if the two dictionaries agree on all the given keys
    for key in keys:
        if dict1.get(key) != dict2.get(key):
            return False
    return True
items = '_items_'  # dictionary key to keep list of items in

def check_data_array(dat, hours, a, s):
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
        if hours is not None:
            if len(ar) < hours:
                print('len(%s) = %i is less than hours = %i' % (a, len(ar), hours))
                exit(1)
    else:
        print(a + ' not present')

def check_data_scalar(dat, a):
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    else:
        print(a + ' not present')
def check_data(dat, case=True, name=None):
    dat[items] = list(dat.keys())  # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat, 'filename')
        check_data_scalar(dat, 'title')
        check_data_scalar(dat, 'note')
        check_data_scalar(dat, 'hours')
        check_data_scalar(dat, 'h2')
        check_data_scalar(dat, 'case')
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        check_data_array(dat, hours, 'E', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ed', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ew', 'wetting equilibrium (%)')
        check_data_array(dat, hours, 'Ec', 'equilibrium correction (%)')
        check_data_array(dat, hours, 'rain', 'rain intensity (mm/h)')
        check_data_array(dat, hours, 'fm', 'RAWS fuel moisture data (%)')
        check_data_array(dat, hours, 'm', 'fuel moisture estimate (%)')
    if dat[items]:
        print('items:', dat[items])
        for a in dat[items].copy():
            ar = dat[a]
            if dat[a] is None or np.isscalar(dat[a]):
                check_data_scalar(dat, a)
            elif is_numeric_ndarray(ar):
                print(type(ar))
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "hash", hash2(ar), "type", type(ar))
            elif isinstance(ar, tf.Tensor):
                print("array", a, "shape", ar.shape, "min", np.min(ar),
                      "max", np.max(ar), "type", type(ar))
            else:
                print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    del dat[items]  # clean up
# Note: the project structure has moved towards pickle files, so these json funcs might not be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    #   dic: dictionary
    #   filename: (str) output json filename, expect a ".json" file extension
    # Return: none

    print('writing ', filename)
    # check_data(dic)
    new = {}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i] = dic[i].tolist()  # because numpy.ndarray is not serializable
        else:
            new[i] = dic[i]
        # print('i', type(new[i]))
    new['filename'] = filename
    print('Hash: ', hash2(new))
    json.dump(new, open(filename, 'w'), indent=4)

def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string

    print('reading ', filename)
    dic = json.load(open(filename, 'r'))
    new = {}
    for i in dic:
        if type(dic[i]) is list:
            new[i] = np.array(dic[i])  # convert lists back to numpy arrays
        else:
            new[i] = dic[i]
    check_data(new)
    print('Hash: ', hash2(new))
    return new
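
# Illustrative round trip (the filename is arbitrary):
#   to_json({'fm': np.array([10.2, 10.5]), 'title': 'example'}, 'example.json')
#   dat = from_json('example.json')  # 'fm' comes back as a numpy array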
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.

    # artificial equilibrium data
    E = 100.0 * np.power(np.sin(np.pi * day), power)  # diurnal curve
    E = 0.05 + 0.25 * E
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1  # initial FMC
    process_noise = 0.  # NOTE: overrides the process_noise argument
    for t in range(hours - 1):
        m_f[t + 1] = max(0., model_decay(m_f[t], E[t]) + random.gauss(0, process_noise))
    data = m_f + np.random.normal(loc=0, scale=data_noise, size=hours)
    E = E + DeltaE
    return E, m_f, data, hour, h2, DeltaE
# The following returns a dictionary with all model data and variables

def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi * day), power)  # diurnal curve between 0 and 1
    E = Emin + (Emax - Emin) * E
    E = E + DeltaE
    Ed = E + 0.5
    Ew = np.maximum(E - 0.5, 0)
    rain = np.multiply(rand(hours) < p_rain, rand(hours) * max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1  # initial FMC
    # process_noise = 0.
    for t in range(hours - 1):
        fm[t + 1] = max(0., model_moisture(fm[t], Ed[t - 1], Ew[t - 1], rain[t - 1]) + random.gauss(0, process_noise))
    fm = fm + np.random.normal(loc=0, scale=data_noise, size=hours)
    dat = {'E': E, 'Ew': Ew, 'Ed': Ed, 'fm': fm, 'hours': hours, 'h2': h2, 'DeltaE': DeltaE, 'rain': rain, 'title': 'Synthetic data'}

    return dat
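
# Illustrative sketch: generate a synthetic case and print a summary of its arrays
#   dat = synthetic_data(days=10, p_rain=0.05)
#   check_data(dat)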
def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data
    if name in dat:
        h = len(dat[name])
        if hmin is None:
            hmin = 0
        if hmax is None:
            hmax = len(dat[name])
        hour = np.array(range(hmin, hmax))
        if type == 'plot':
            plt.plot(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)
        elif type == 'scatter':
            plt.scatter(hour, dat[name][hmin:hmax], linestyle=linestyle, c=c, label=label, alpha=alpha)
# Lookup table for plotting features
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha': .8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha': .8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha': .9, 'label': 'Rain'}
}

def plot_feature(x, y, feature_name):
    style = plot_styles.get(feature_name, {})
    plt.plot(x, y, **style)

def plot_features(hmin, hmax, dat, linestyle, c, label, alpha=1):
    hour = np.array(range(hmin, hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat)  # index of feature column in the main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:, i][hmin:hmax], feature_name=feat)
def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None, hmin=0, hmax=None, xlabel=None, ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    #   dat: FMDA dictionary
    #   plot_period: ('all' or 'predict') portion of the timeseries to plot
    # Returns: none

    # dat = copy.deepcopy(dat0)

    if 'hours' in dat:
        if hmax is None:
            hmax = dat['hours']
        else:
            hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16, 4))

    plot_one(hmin, hmax, dat, 'y', linestyle='-', c='#468a29', label='FM Observed')
    plot_one(hmin, hmax, dat, 'm', linestyle='-', c='k', label='FM Model')
    plot_features(hmin, hmax, dat, linestyle='-', c='k', label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim()  # used to format annotations
        plot_y0 = np.max([hmin, test_ind])  # used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]), xytext=(plot_y0, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)', xy=((hmin + plot_y0) / 2, yy[1]), xytext=((hmin + plot_y0) / 2, yy[1] + 1), ha='right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]), xytext=(hmax, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)', xy=(hmax - (hmax - test_ind) / 2, yy[1]),
                     xytext=(hmax - (hmax - test_ind) / 2, yy[1] + 1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    if title is not None:
        t = title
    elif 'title' in dat:
        t = dat['title']
        # print('title', type(t), t)
    else:
        t = ''
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat) + ')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t, y=1.1)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    elif ylabel is None:
        plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")
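
# Illustrative usage sketch (assumes `dat` is a utils.Dict carrying the keys used above:
# 'y', 'X', 'features_list', 'all_features_list', and optionally 'm' and 'test_ind'):
#   plot_data(dat, plot_period='all', create_figure=True, title2='Station XYZ')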
def rmse(a, b):
    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))

def rmse_skip_nan(x, y):
    mask = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(mask):
        return np.sqrt(np.mean((x[mask] - y[mask]) ** 2))
    else:
        return np.nan

def rmse_str(a, b):
    rmse = rmse_skip_nan(a, b)
    return "RMSE " + "{:.3f}".format(rmse)
def rmse_data_str(dat, predict=True, hours=None, test_ind=None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    #   dat: (dict) fmda dictionary
    #   predict: (bool) whether to return prediction-period RMSE. Default True
    #   hours: (int) total number of modeled time periods
    #   test_ind: (int) start of test period
    # Return: (str) formatted RMSE value

    if hours is None:
        if 'hours' in dat:
            hours = dat['hours']
    if test_ind is None:
        if 'test_ind' in dat:
            test_ind = dat['test_ind']

    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'][test_ind:hours], dat['y'].flatten()[test_ind:hours])
        else:
            return rmse_str(dat['m'], dat['y'].flatten())
    else:
        return ''
# Calculate mean absolute error (note: despite the name, this is MAE, not a percentage error)
def mape(a, b):
    return ((a - b).__abs__()).mean()

def rmse_data(dat, hours=None, h2=None, simulation='m', measurements='fm'):
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']

    m = dat[simulation]
    fm = dat[measurements]
    case = dat['case']

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    all = rmse(m[:hours], fm[:hours])
    print(case, 'Training 1 to', h2, 'hours RMSE: ' + str(np.round(train, 4)))
    print(case, 'Prediction', h2 + 1, 'to', hours, 'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train': train, 'predict': predict, 'all': all}
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def get_file(filename, data_dir='data'):
    # Check for the file locally, retrieve it with wget if not present
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        subprocess.call(f"wget -P {data_dir} {osp.join(base_url, filename)}", shell=True)