# fmda/data_funcs.py
1 ## Set of Functions to process and format fuel moisture model inputs
2 #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
import copy
import json
import os
import os.path as osp
import pickle
import random
import subprocess
import warnings
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from numpy.random import rand
from sklearn.metrics import mean_squared_error

from moisture_models import model_decay, model_moisture
from utils import is_numeric_ndarray, hash2
# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combines input data dictionaries.

    Parameters:
    -----------
    nested_input_dict : dict
        Dictionary of dictionaries. Each inner dictionary is expected to
        carry the keys 'id', 'case', 'filename', 'time', 'X', 'y',
        'features_list' and a 'loc' subdictionary.
    verbose : bool, optional
        If True, prints status messages. Default is True.

    Returns:
    --------
    d : dict
        Dictionary where each key maps to the list of values collected
        from the nested input dictionaries, in iteration order.
    """
    # Setup return dictionary
    d = {}
    # Use the helper function to populate the top-level keys
    for key in ('id', 'case', 'filename', 'time', 'X', 'y'):
        d[key] = _combine_key(nested_input_dict, key)

    # Build the loc subdictionary using _combine_key for each loc key
    # (the original dict literal lost its closing brace in extraction).
    d['loc'] = {
        subkey: _combine_key(nested_input_dict, 'loc', subkey)
        for subkey in ('STID', 'lat', 'lon', 'elev', 'pixel_x', 'pixel_y')
    }

    # Handle features_list separately with validation: all cases are expected
    # to share a single features_list, so warn when any of them differ.
    features_list = _combine_key(nested_input_dict, 'features_list')
    if features_list:
        first_features_list = features_list[0]
        for fl in features_list:
            if fl != first_features_list:
                warnings.warn("Different features_list found in the nested input dictionaries.")
        d['features_list'] = first_features_list

    return d
59 def _combine_key(nested_input_dict, key, subkey=None):
60 combined_list = []
61 for input_dict in nested_input_dict.values():
62 if isinstance(input_dict, dict):
63 try:
64 if subkey:
65 combined_list.append(input_dict[key][subkey])
66 else:
67 combined_list.append(input_dict[key])
68 except KeyError:
69 raise ValueError(f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries")
70 else:
71 raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
72 return combined_list
def compare_dicts(dict1, dict2, keys):
    """Return True when dict1 and dict2 agree on every key in `keys`.

    Missing keys compare as None on both sides (dict.get default).
    """
    return all(dict1.get(key) == dict2.get(key) for key in keys)
items = '_items_' # dictionary key to keep list of items in

def check_data_array(dat,hours,a,s):
    # Print diagnostics for array entry `a` of fmda dictionary `dat`.
    # Inputs:
    #   dat: (dict) data dictionary; dat[items] is the working list of not-yet-checked keys
    #   hours: (int or None) expected minimum array length; process exits if violated
    #   a: (str) key of the array to check
    #   s: (str) human-readable description used in the printout
    if a in dat[items]:
        dat[items].remove(a)  # mark this key as handled
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
        (a,s,len(ar),min(ar),max(ar),hash2(ar),type(ar)))
        if hours is not None:
            if len(ar) < hours:
                # NOTE(review): '%a' applies ascii() to the key name; '%s' was
                # probably intended. Left as-is since it only affects the message.
                print('len(%a) = %i does not equal to hours = %i' % (a,len(ar),hours))
                exit(1)  # hard stop: array shorter than the expected number of hours
    else:
        print(a + ' not present')
def check_data_scalar(dat,a):
    """Print the value and type of scalar entry `a` of `dat`, marking it checked."""
    if a in dat[items]:
        dat[items].remove(a)
    if a not in dat:
        print(a + ' not present' )
    else:
        print('%s = %s' % (a,dat[a]),' ',type(dat[a]))
def check_data(dat,case=True,name=None):
    # Print a diagnostic summary of every entry of an fmda dictionary.
    # Inputs:
    #   dat: (dict) data dictionary; temporarily gains a dat[items] bookkeeping key
    #   case: (bool) if True, check the standard set of case scalars/arrays first
    #   name: (str or None) optional header printed before the summary
    dat[items] = list(dat.keys()) # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        # Standard scalar metadata expected for a case
        check_data_scalar(dat,'filename')
        check_data_scalar(dat,'title')
        check_data_scalar(dat,'note')
        check_data_scalar(dat,'hours')
        check_data_scalar(dat,'h2')
        check_data_scalar(dat,'case')
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        # Standard time-series arrays; each is length-checked against hours
        check_data_array(dat,hours,'E','drying equilibrium (%)')
        check_data_array(dat,hours,'Ed','drying equilibrium (%)')
        check_data_array(dat,hours,'Ew','wetting equilibrium (%)')
        check_data_array(dat,hours,'Ec','equilibrium equilibrium (%)')
        check_data_array(dat,hours,'rain','rain intensity (mm/h)')
        check_data_array(dat,hours,'fm','RAWS fuel moisture data (%)')
        check_data_array(dat,hours,'m','fuel moisture estimate (%)')
    if dat[items]:
        # Any remaining (non-standard) keys get a generic printout by type
        print('items:',dat[items])
        for a in dat[items].copy():
            ar=dat[a]
            if dat[a] is None or np.isscalar(dat[a]):
                check_data_scalar(dat,a)
            elif is_numeric_ndarray(ar):
                print(type(ar))
                print("array", a, "shape",ar.shape,"min",np.min(ar),
                "max",np.max(ar),"hash",hash2(ar),"type",type(ar))
            elif isinstance(ar, tf.Tensor):
                print("array", a, "shape",ar.shape,"min",np.min(ar),
                "max",np.max(ar),"type",type(ar))
            else:
                print('%s = %s' % (a,dat[a]),' ',type(dat[a]))
    del dat[items] # clean up
143 # Note: the project structure has moved towards pickle files, so these json funcs might not be needed
144 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic,filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    #     dic: dictionary
    #     filename: (str) output json filename, expect a ".json" file extension
    # Return: none
    print('writing ',filename)
    # check_data(dic)
    new={}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i]=dic[i].tolist() # because numpy.ndarray is not serializable
        else:
            new[i]=dic[i]
    new['filename']=filename
    print('Hash: ', hash2(new))
    # Context manager guarantees the handle is closed even if json.dump
    # raises (the original passed open() inline and relied on GC to close).
    with open(filename,'w') as f:
        json.dump(new,f,indent=4)
def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string
    # Return: (dict) with list values converted back to numpy arrays
    print('reading ',filename)
    # Context manager guarantees the handle is closed promptly
    # (the original passed open() inline and relied on GC to close).
    with open(filename,'r') as f:
        dic=json.load(f)
    new={}
    for i in dic:
        if type(dic[i]) is list:
            new[i]=np.array(dic[i]) # because ndarray is not serializable
        else:
            new[i]=dic[i]
    check_data(new)
    print('Hash: ', hash2(new))
    return new
182 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20,power=4,data_noise=0.02,process_noise=0.0,DeltaE=0.0):
    # Inputs:
    #   days: (int) number of simulated days (hours = days*24)
    #   power: (num) exponent of the diurnal sine curve
    #   data_noise: (float) std dev of observation noise added to the data
    #   process_noise: (float) std dev of per-step model noise
    #   DeltaE: (float) constant offset added to equilibrium at the end
    # Return: E, m_f, data, hour, h2, DeltaE
    hours = days*24
    h2 = int(hours/2)
    hour = np.array(range(hours))
    day = np.array(range(hours))/24.
    # artificial equilibrium data
    # Honor the `power` parameter (the exponent was previously hard-coded to 4).
    E = 100.0*np.power(np.sin(np.pi*day),power) # diurnal curve
    E = 0.05+0.25*E
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1 # initial FMC
    # NOTE: the original reassigned process_noise=0. here, silently ignoring
    # the parameter; the default value (0.0) preserves the old behavior.
    for t in range(hours-1):
        m_f[t+1] = max(0.,model_decay(m_f[t],E[t]) + random.gauss(0,process_noise) )
    data = m_f + np.random.normal(loc=0,scale=data_noise,size=hours)
    E = E + DeltaE
    return E,m_f,data,hour,h2,DeltaE
# the following input or output dictionary with all model data and variables
def synthetic_data(days=20,power=4,data_noise=0.02,process_noise=0.0,
        DeltaE=0.0,Emin=5,Emax=30,p_rain=0.01,max_rain=10.0):
    # Create a synthetic fmda dictionary: diurnal equilibria, random rain,
    # and an FMC free run driven by model_moisture.
    # Inputs:
    #   days: (int) number of simulated days (hours = days*24)
    #   power: (num) exponent of the diurnal sine curve
    #   data_noise: (float) std dev of observation noise added to fm
    #   process_noise: (float) std dev of per-step model noise
    #   DeltaE: (float) constant offset added to the equilibrium
    #   Emin, Emax: (num) range of the equilibrium curve (%)
    #   p_rain: (float) per-hour probability of rain
    #   max_rain: (float) maximum rain intensity (mm/h)
    # Return: (dict) with keys E, Ew, Ed, fm, hours, h2, DeltaE, rain, title
    hours = days*24
    h2 = int(hours/2)
    day = np.array(range(hours))/24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi*day),power) # diurnal curve betwen 0 and 1
    E = Emin+(Emax - Emin)*E
    E = E + DeltaE
    Ed=E+0.5
    Ew=np.maximum(E-0.5,0)
    rain = np.multiply(rand(hours) < p_rain, rand(hours)*max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1 # initial FMC
    for t in range(hours-1):
        # Drive each step with the current hour's forcing. The original
        # indexed [t-1], which at t=0 wrapped around to the LAST hour of the
        # series via numpy negative indexing -- an off-by-one bug.
        fm[t+1] = max(0.,model_moisture(fm[t],Ed[t],Ew[t],rain[t]) + random.gauss(0,process_noise))
    fm = fm + np.random.normal(loc=0,scale=data_noise,size=hours)
    dat = {'E':E,'Ew':Ew,'Ed':Ed,'fm':fm,'hours':hours,'h2':h2,'DeltaE':DeltaE,'rain':rain,'title':'Synthetic data'}

    return dat
def plot_one(hmin,hmax,dat,name,linestyle,c,label, alpha=1,type='plot'):
    # helper for plot_data: plot series dat[name] over hours [hmin, hmax) if present.
    # Inputs:
    #   hmin, hmax: (int or None) hour range; None means start/end of the series
    #   dat: (dict) fmda dictionary
    #   name: (str) key of the series to plot; silently skipped if absent
    #   linestyle, c, label, alpha: matplotlib styling passed through
    #   type: (str) 'plot' for a line, 'scatter' for points
    #     (parameter name shadows the builtin `type`; kept for existing callers)
    # Unused local `h = len(dat[name])` from the original was removed.
    if name in dat:
        if hmin is None:
            hmin=0
        if hmax is None:
            hmax=len(dat[name])
        hour = np.array(range(hmin,hmax))
        if type=='plot':
            plt.plot(hour,dat[name][hmin:hmax],linestyle=linestyle,c=c,label=label, alpha=alpha)
        elif type=='scatter':
            plt.scatter(hour,dat[name][hmin:hmax],linestyle=linestyle,c=c,label=label, alpha=alpha)
# Lookup table for plotting features: matplotlib keyword arguments keyed by
# feature name, unpacked into plt.plot by plot_feature.
# (The dict literal's closing brace was lost in extraction and is restored here.)
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha':.8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha':.8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha':.9, 'label': 'Rain'}
}
def plot_feature(x, y, feature_name):
    """Plot y against x using the preset style for `feature_name`.

    Unknown feature names fall back to an empty style (matplotlib defaults).
    """
    kwargs = plot_styles.get(feature_name, {})
    plt.plot(x, y, **kwargs)
def plot_features(hmin,hmax,dat,linestyle,c,label,alpha=1):
    # Plot every feature of dat that has a preset entry in plot_styles,
    # over hours [hmin, hmax).
    # NOTE(review): reads dat.features_list / dat.all_features_list as
    # attributes but indexes dat['X'] like a dict -- dat is presumably a
    # dict-like custom class here; confirm against callers.
    # linestyle, c, label, alpha are accepted but unused (styles come from plot_styles).
    hour = np.array(range(hmin,hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat) # index of main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:,i][hmin:hmax], feature_name=feat)
def plot_data(dat, plot_period='all', create_figure=False,title=None,title2=None,hmin=0,hmax=None,xlabel=None,ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    #   dat: FMDA dictionary
    #   plot_period: (str) 'all' for the full series, 'predict' for the
    #     forecast period only (requires dat['test_ind'])
    #   create_figure: (bool) if True, open a new matplotlib figure
    #   title, title2: (str or None) optional title text; title2 is appended
    #   hmin, hmax: (int) hour range to plot; hmax is capped at dat['hours']
    #   xlabel, ylabel: (str or None) axis label overrides
    # Returns: none
    # dat = copy.deepcopy(dat0)

    # Clamp hmax to the available number of hours
    if 'hours' in dat:
        if hmax is None:
            hmax = dat['hours']
        else:
            hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']  # start plotting at the forecast boundary
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16,4))

    # Observed data, model output, and any styled features
    plot_one(hmin,hmax,dat,'y',linestyle='-',c='#468a29',label='FM Observed')
    plot_one(hmin,hmax,dat,'m',linestyle='-',c='k',label='FM Model')
    plot_features(hmin,hmax,dat,linestyle='-',c='k',label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    # (train/forecast split line plus labeled arrows along the bottom).
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim() # used to format annotations
        plot_y0 = np.max([hmin, test_ind]) # Used to format annotations
        plot_y1 = np.min([hmin, test_ind])  # NOTE(review): computed but unused
        plt.annotate('', xy=(hmin, yy[0]),xytext=(plot_y0,yy[0]),
                arrowprops=dict(arrowstyle='<-', linewidth=2),
                annotation_clip=False)
        plt.annotate('(Training)',xy=((hmin+plot_y0)/2,yy[1]),xytext=((hmin+plot_y0)/2,yy[1]+1), ha = 'right',
                annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]),xytext=(hmax,yy[0]),
                arrowprops=dict(arrowstyle='<-', linewidth=2),
                annotation_clip=False)
        plt.annotate('(Forecast)',xy=(hmax-(hmax-test_ind)/2,yy[1]),
                xytext=(hmax-(hmax-test_ind)/2,yy[1]+1),
                annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Title: explicit argument wins, then dat['title'], then empty;
    # an RMSE summary is always appended in parentheses.
    if title is not None:
        t = title
    elif 'title' in dat:
        t=dat['title']
        # print('title',type(t),t)
    else:
        t=''
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat)+')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t, y=1.1)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    elif ylabel is None:
        plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")
def rmse(a, b):
    # Root mean squared error between the flattened arrays a and b.
    # Computed directly with numpy (equal to sqrt of sklearn's
    # mean_squared_error) for consistency with rmse_skip_nan.
    a = np.asarray(a).flatten()
    b = np.asarray(b).flatten()
    return np.sqrt(np.mean((a - b) ** 2))
def rmse_skip_nan(x, y):
    """RMSE over the positions where both x and y are non-NaN.

    Returns np.nan when no such position exists.
    """
    valid = ~(np.isnan(x) | np.isnan(y))
    if not np.count_nonzero(valid):
        return np.nan
    diff = x[valid] - y[valid]
    return np.sqrt(np.mean(diff * diff))
def rmse_str(a,b):
    # Return 'RMSE <value>' with the NaN-skipping RMSE formatted to 3 decimals.
    # The local was renamed from `rmse`, which shadowed the module-level
    # rmse() function inside this body.
    err = rmse_skip_nan(a,b)
    return "RMSE " + "{:.3f}".format(err)
def rmse_data_str(dat, predict=True, hours = None, test_ind = None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    #   dat: (dict) fmda dictionary
    #   predict: (bool) Whether to return prediction period RMSE. Default True
    #   hours: (int) total number of modeled time periods
    #   test_ind: (int) start of test period
    # Return: (str) RMSE value, or '' when model/observations are missing
    if hours is None and 'hours' in dat:
        hours = dat['hours']
    if test_ind is None and 'test_ind' in dat:
        test_ind = dat['test_ind']

    if not ('m' in dat and 'y' in dat):
        return ''
    # Restrict to the forecast window only when all bounds are known
    if predict and hours is not None and test_ind is not None:
        return rmse_str(dat['m'][test_ind:hours], dat['y'].flatten()[test_ind:hours])
    return rmse_str(dat['m'], dat['y'].flatten())
# Calculate mean absolute error
# NOTE(review): despite the name `mape`, this computes the plain
# (unscaled) mean absolute error, NOT the mean absolute percentage error.
# The name is kept for existing callers; documented here to avoid misuse.
def mape(a, b):
    return np.abs(a - b).mean()
def rmse_data(dat, hours = None, h2 = None, simulation='m', measurements='fm'):
    """Print and return training/prediction/overall RMSE for a case dictionary.

    Splits the series at h2: [0, h2) is training, [h2, hours) is prediction.
    Returns a dict with keys 'train', 'predict', 'all'.
    """
    hours = dat['hours'] if hours is None else hours
    h2 = dat['h2'] if h2 is None else h2

    m = dat[simulation]
    fm = dat[measurements]
    case = dat['case']

    metrics = {
        'train': rmse(m[:h2], fm[:h2]),
        'predict': rmse(m[h2:hours], fm[h2:hours]),
        'all': rmse(m[:hours], fm[:hours]),
    }
    print(case,'Training 1 to',h2,'hours RMSE: ' + str(np.round(metrics['train'], 4)))
    print(case,'Prediction',h2+1,'to',hours,'hours RMSE: ' + str(np.round(metrics['predict'], 4)))
    print(f"All predictions hash: {hash2(m)}")

    return metrics
408 #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    # Inputs:
    #   filename: (str) name of the data file
    #   data_dir: (str) local directory checked first and used as wget target
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        # Argument-list form avoids the shell-injection/quoting hazards of the
        # previous shell=True f-string command; the redundant function-local
        # `import subprocess` was removed (it is imported at module level).
        subprocess.call(["wget", "-P", data_dir, osp.join(base_url, filename)])