Create Batch Reset Hyperparameter tutorial notebook
[notebooks.git] / fmda / data_funcs.py
blobf4919af4ac1888bff9167701ff384877d35a1aea
1 ## Set of Functions to process and format fuel moisture model inputs
2 #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 import numpy as np, random
5 from numpy.random import rand
6 import tensorflow as tf
7 import pickle, os
8 from sklearn.metrics import mean_squared_error
9 import matplotlib.pyplot as plt
10 from moisture_models import model_decay, model_moisture
11 from datetime import datetime, timedelta
12 from utils import is_numeric_ndarray, hash2
13 import json
14 import copy
15 import subprocess
16 import os.path as osp
def compare_dicts(dict1, dict2, keys):
    # Tell whether two dictionaries hold the same value for every key in keys.
    # Inputs:
    #   dict1, dict2: dictionaries to compare
    #   keys: iterable of keys to look at; other keys are ignored
    # Return: (bool) True if no listed key differs, False otherwise
    # Values are fetched with .get, so a key absent from both sides matches.
    return not any(dict1.get(k) != dict2.get(k) for k in keys)
items = '_items_'  # dictionary key to keep list of items in


def check_data_array(dat, hours, a, s):
    # Print a one-line summary of the array stored in dat[a] and check its length.
    # Inputs:
    #   dat: dictionary; dat[items] must hold the list of keys not yet reported
    #   hours: (int or None) minimum required length of the array, None to skip the check
    #   a: (str) key of the array in dat
    #   s: (str) description printed next to the key
    # Return: none
    # Side effects: removes a from dat[items]; terminates the program with
    #   exit status 1 when the array is shorter than hours.
    import sys  # sys.exit instead of the site-injected exit(), which is not always defined

    if a in dat[items]:
        dat[items].remove(a)
    if a not in dat:
        print(a + ' not present')
        return
    ar = dat[a]
    print("array %s %s length %i min %s max %s hash %s %s" %
          (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
    if hours is not None and len(ar) < hours:
        # The condition is "shorter than", so the message says so
        print('len(%s) = %i is less than hours = %i' % (a, len(ar), hours))
        sys.exit(1)
def check_data_scalar(dat, a):
    # Print the value and type of dat[a], or a notice when the key is absent.
    # Inputs:
    #   dat: dictionary; dat[items] must hold the list of keys not yet reported
    #   a: key to report
    # Return: none
    # Side effect: a is removed from dat[items] if it is listed there.
    pending = dat[items]
    if a in pending:
        pending.remove(a)
    if a not in dat:
        print(a + ' not present' )
        return
    value = dat[a]
    print('%s = %s' % (a, value), ' ', type(value))
def check_data(dat, case=True, name=None):
    # Print a summary of every entry of an fmda dictionary.
    # Inputs:
    #   dat: dictionary to report on
    #   case: (bool) if True, first report the standard scalar and array entries
    #         (filename, title, note, hours, h2, case, E, Ed, Ew, Ec, rain, fm, m)
    #   name: (str or None) heading printed before the report
    # Return: none
    # The list of keys still to be reported is kept temporarily in dat[items];
    # it is removed again on every exit path, including when an array length
    # check terminates the program.
    # NOTE(review): the nesting under `if case:` was reconstructed from a copy of
    # the file without indentation - confirm against the repository.
    dat[items] = list(dat.keys())  # add list of items to the dictionary
    try:
        if name is not None:
            print(name)
        if case:
            for key in ('filename', 'title', 'note', 'hours', 'h2', 'case'):
                check_data_scalar(dat, key)
            hours = dat['hours'] if 'hours' in dat else None
            for key, label in (
                ('E', 'drying equilibrium (%)'),
                ('Ed', 'drying equilibrium (%)'),
                ('Ew', 'wetting equilibrium (%)'),
                ('Ec', 'equilibrium equilibrium (%)'),
                ('rain', 'rain intensity (mm/h)'),
                ('fm', 'RAWS fuel moisture data (%)'),
                ('m', 'fuel moisture estimate (%)'),
            ):
                check_data_array(dat, hours, key, label)
        if dat[items]:
            # Whatever was not reported above is described by its type
            print('items:', dat[items])
            # Iterate over a copy: check_data_scalar removes entries from the list
            for a in dat[items].copy():
                ar = dat[a]
                if ar is None or np.isscalar(ar):
                    check_data_scalar(dat, a)
                elif is_numeric_ndarray(ar):
                    print(type(ar))
                    print("array", a, "shape", ar.shape, "min", np.min(ar),
                          "max", np.max(ar), "hash", hash2(ar), "type", type(ar))
                elif isinstance(ar, tf.Tensor):
                    print("array", a, "shape", ar.shape, "min", np.min(ar),
                          "max", np.max(ar), "type", type(ar))
                else:
                    print('%s = %s' % (a, ar), ' ', type(ar))
    finally:
        dat.pop(items, None)  # clean up
87 # Note: the project structure has moved towards pickle files, so these json funcs might not be needed
88 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    #   dic: dictionary
    #   filename: (str) output json filename, expect a ".json" file extension
    # Return: none
    # The written copy also carries a 'filename' entry set to filename;
    # dic itself is not modified.

    print('writing ', filename)
    # numpy.ndarray is not serializable, so arrays are stored as nested lists
    new = {key: (val.tolist() if isinstance(val, np.ndarray) else val)
           for key, val in dic.items()}
    new['filename'] = filename
    print('Hash: ', hash2(new))
    # Context manager so the file is flushed and closed even if dump fails
    with open(filename, 'w') as f:
        json.dump(new, f, indent=4)
def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string
    # Return: dictionary with the file contents; every list value is
    #         converted to a numpy array

    print('reading ', filename)
    # Context manager so the file handle is closed once parsing is done
    with open(filename, 'r') as f:
        dic = json.load(f)
    # Lists come back as ndarrays because ndarray is not serializable
    new = {key: (np.array(val) if isinstance(val, list) else val)
           for key, val in dic.items()}
    check_data(new)
    print('Hash: ', hash2(new))
    return new
126 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
128 # Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    # Simulate equilibrium and fuel moisture series with model_decay.
    # Inputs:
    #   days: (int) length of the simulation in days; one value per hour is produced
    #   power: exponent applied to the sine wave that shapes the diurnal curve
    #   data_noise: standard deviation of the Gaussian noise added to the data
    #   process_noise: standard deviation of the Gaussian noise added at each model step
    #   DeltaE: offset added to the returned equilibrium after the simulation
    # Return: tuple (E, m_f, data, hour, h2, DeltaE) where E is the shifted
    #   equilibrium, m_f the free model run, data the noisy copy of m_f,
    #   hour the hour indices and h2 half the number of hours
    # Both power and process_noise are honoured; at their default values the
    # result is the same as with the previously hard-coded 4 and 0.
    hours = days*24
    h2 = int(hours/2)
    hour = np.arange(hours)
    day = hour/24.

    # artificial equilibrium data
    E = 100.0*np.power(np.sin(np.pi*day), power)  # diurnal curve
    E = 0.05+0.25*E
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1  # initial FMC
    for t in range(hours-1):
        # clip at zero so process noise cannot make the moisture negative
        m_f[t+1] = max(0., model_decay(m_f[t], E[t]) + random.gauss(0, process_noise))
    data = m_f + np.random.normal(loc=0, scale=data_noise, size=hours)
    E = E + DeltaE  # the shift is applied only after the model run
    return E, m_f, data, hour, h2, DeltaE
148 # the following input or output dictionary with all model data and variables
def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    # Build a dictionary of synthetic equilibria, rain and fuel moisture
    # using model_moisture.
    # Inputs:
    #   days: (int) length of the simulation in days; one value per hour is produced
    #   power: exponent applied to the sine wave that shapes the diurnal curve
    #   data_noise: standard deviation of the Gaussian noise added to fm
    #   process_noise: standard deviation of the Gaussian noise added at each model step
    #   DeltaE: offset added to the equilibrium before the simulation
    #   Emin, Emax: range the diurnal equilibrium curve is scaled to
    #   p_rain: probability that a given hour has rain
    #   max_rain: upper bound of the uniformly drawn rain intensity
    # Return: dictionary with keys E, Ew, Ed, fm, hours, h2, DeltaE, rain, title
    hours = days*24
    h2 = int(hours/2)
    day = np.arange(hours)/24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi*day), power)  # diurnal curve between 0 and 1
    E = Emin+(Emax - Emin)*E
    E = E + DeltaE
    Ed = E+0.5
    Ew = np.maximum(E-0.5, 0)
    # rain only in the hours selected with probability p_rain, zero elsewhere
    rain = np.multiply(rand(hours) < p_rain, rand(hours)*max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1  # initial FMC
    for t in range(hours-1):
        # Step from hour t to t+1 with the forcing of hour t. Indexing with
        # t-1 would wrap around to the last hour at t = 0.
        fm[t+1] = max(0., model_moisture(fm[t], Ed[t], Ew[t], rain[t]) + random.gauss(0, process_noise))
    fm = fm + np.random.normal(loc=0, scale=data_noise, size=hours)
    dat = {'E': E, 'Ew': Ew, 'Ed': Ed, 'fm': fm, 'hours': hours, 'h2': h2,
           'DeltaE': DeltaE, 'rain': rain, 'title': 'Synthetic data'}

    return dat
def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data: draw the series dat[name] over hours hmin..hmax-1
    # Inputs:
    #   hmin, hmax: (int or None) bounds; None means 0 and len(dat[name])
    #   dat: dictionary holding the series
    #   name: key of the series; nothing is drawn if it is not in dat
    #   linestyle, c, label, alpha: passed through to matplotlib
    #   type: 'plot' for a line, 'scatter' for points; other values draw nothing
    # Return: none
    if name not in dat:
        return
    series = dat[name]
    n = len(series)
    lo = 0 if hmin is None else hmin
    hi = n if hmax is None else hmax
    hour = np.array(range(lo, hi))
    values = series[lo:hi]
    if type == 'plot':
        plt.plot(hour, values, linestyle=linestyle, c=c, label=label, alpha=alpha)
    elif type == 'scatter':
        plt.scatter(hour, values, linestyle=linestyle, c=c, label=label, alpha=alpha)
188 # Lookup table for plotting features
# Keyword arguments handed to plt.plot for each feature that has a style;
# features without an entry here are not drawn by plot_features.
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha':.8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha':.8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha':.9, 'label': 'Rain'}
}
def plot_feature(x, y, feature_name):
    # Draw y against x with the line style registered for feature_name.
    # Inputs:
    #   x, y: data passed to plt.plot
    #   feature_name: key into plot_styles; an unknown name plots with
    #                 matplotlib defaults
    # Return: none
    plt.plot(x, y, **plot_styles.get(feature_name, {}))
def plot_features(hmin, hmax, dat, linestyle, c, label, alpha=1):
    # Plot every feature of dat that has an entry in plot_styles.
    # Inputs:
    #   hmin, hmax: (int or None) bounds; None means 0 and len(dat['X'])
    #   dat: fmda dictionary object with attributes features_list and
    #        all_features_list, and the feature matrix in dat['X'] whose
    #        columns follow the order of all_features_list
    #   linestyle, c, label, alpha: accepted for a signature parallel to
    #        plot_one but not used; styles come from plot_styles
    # Return: none
    # A dat without features_list (for example a plain dict) is skipped,
    # the same way plot_one skips a missing series.
    features = getattr(dat, 'features_list', None)
    if not features:
        return
    if hmin is None:
        hmin = 0
    if hmax is None:
        hmax = len(dat['X'])
    hour = np.arange(hmin, hmax)
    for feat in features:
        if feat not in plot_styles:
            continue
        i = dat.all_features_list.index(feat)  # index of main data
        plot_feature(x=hour, y=dat['X'][:, i][hmin:hmax], feature_name=feat)
def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None,
              hmin=0, hmax=None, xlabel=None, ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    #   dat: FMDA dictionary
    #   plot_period: (str) "all" to plot from hmin, "predict" to start at dat['test_ind']
    #   create_figure: (bool) whether to open a new 16x4 figure first
    #   title: (str) plot title; defaults to dat['title'] if present
    #   title2: (str) text appended to the title
    #   hmin, hmax: (int) bounds; hmax is capped at dat['hours'] when that key exists
    #   xlabel, ylabel: (str) axis labels; defaults are used when None
    # Returns: none
    # Raises: ValueError for an unrecognized plot_period

    if 'hours' in dat:
        if hmax is None:
            hmax = dat['hours']
        else:
            hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16,4))

    plot_one(hmin,hmax,dat,'y',linestyle='-',c='#468a29',label='FM Observed')
    plot_one(hmin,hmax,dat,'m',linestyle='-',c='k',label='FM Model')
    plot_features(hmin,hmax,dat,linestyle='-',c='k',label='FM Model')

    test_ind = dat["test_ind"] if 'test_ind' in dat.keys() else None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        # Right end of the annotations; use the model length when no bound is known
        hend = hmax if hmax is not None else len(dat['m'])
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim()  # used to format annotations
        plot_y0 = np.max([hmin, test_ind])  # Used to format annotations
        plt.annotate('', xy=(hmin, yy[0]),xytext=(plot_y0,yy[0]),
                arrowprops=dict(arrowstyle='<-', linewidth=2),
                annotation_clip=False)
        plt.annotate('(Training)',xy=((hmin+plot_y0)/2,yy[1]),xytext=((hmin+plot_y0)/2,yy[1]+1), ha = 'right',
                annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]),xytext=(hend,yy[0]),
                arrowprops=dict(arrowstyle='<-', linewidth=2),
                annotation_clip=False)
        plt.annotate('(Forecast)',xy=(hend-(hend-test_ind)/2,yy[1]),
                xytext=(hend-(hend-test_ind)/2,yy[1]+1),
                annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    if title is not None:
        t = title
    elif 'title' in dat:
        t = dat['title']
    else:
        t = ''
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat)+')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t, y=1.1)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    # An explicit ylabel wins; otherwise the default depends on whether rain is present
    if ylabel is not None:
        plt.ylabel(ylabel)
    elif 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    else:
        plt.ylabel('Fuel moisture content (%)')
    plt.legend(loc="upper left")
def rmse(a, b):
    # Root mean squared error between two arrays.
    # Inputs: a, b: arrays providing .flatten(); both are flattened before comparison
    # Return: square root of mean_squared_error of the flattened arrays
    mse = mean_squared_error(a.flatten(), b.flatten())
    return np.sqrt(mse)
def rmse_skip_nan(x, y):
    # Root mean squared error over the positions where neither input is NaN.
    # Inputs: x, y: numpy arrays of equal shape
    # Return: RMSE over the usable positions, or NaN when there are none
    usable = ~np.isnan(x) & ~np.isnan(y)
    if not np.count_nonzero(usable):
        return np.nan
    diff = x[usable] - y[usable]
    return np.sqrt(np.mean(diff ** 2))
def rmse_str(a, b):
    # Format the NaN-skipping RMSE of a and b as text with three decimals.
    # Inputs: a, b: numpy arrays passed to rmse_skip_nan
    # Return: (str) "RMSE " followed by the value
    value = rmse_skip_nan(a, b)
    formatted = "{:.3f}".format(value)
    return "RMSE " + formatted
def rmse_data_str(dat, predict=True, hours = None, test_ind = None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    #   dat: (dict) fmda dictionary
    #   predict: (bool) Whether to return prediction period RMSE. Default True
    #   hours: (int) total number of modeled time periods; defaults to dat['hours']
    #   test_ind: (int) start of test period; defaults to dat['test_ind']
    # Return: (str) RMSE value, or '' unless dat has both 'm' and 'y'
    # The whole series is used when predict is False or a bound is unknown.

    if hours is None:
        hours = dat['hours'] if 'hours' in dat else None
    if test_ind is None:
        test_ind = dat['test_ind'] if 'test_ind' in dat else None

    if not ('m' in dat and 'y' in dat):
        return ''

    model = dat['m']
    observed = dat['y'].flatten()
    use_window = predict and hours is not None and test_ind is not None
    if use_window:
        return rmse_str(model[test_ind:hours], observed[test_ind:hours])
    return rmse_str(model, observed)
327 # Calculate mean absolute error
def mape(a, b):
    # Mean of the absolute differences between a and b.
    # Note: despite the name, no percentage scaling is applied.
    # Inputs: a, b: arrays whose difference provides .mean()
    # Return: mean absolute difference
    abs_err = abs(a - b)
    return abs_err.mean()
def rmse_data(dat, hours = None, h2 = None, simulation='m', measurements='fm'):
    # Compute and print RMSE of the simulation against the measurements
    # for the training part, the prediction part and the whole series.
    # Inputs:
    #   dat: (dict) fmda dictionary; must contain 'case'
    #   hours: (int) end of the compared series; defaults to dat['hours']
    #   h2: (int) split between training and prediction; defaults to dat['h2']
    #   simulation: key of the simulated series in dat
    #   measurements: key of the measured series in dat
    # Return: (dict) RMSE values under keys 'train', 'predict' and 'all'
    hours = dat['hours'] if hours is None else hours
    h2 = dat['h2'] if h2 is None else h2

    sim = dat[simulation]
    obs = dat[measurements]
    case = dat['case']

    scores = {
        'train': rmse(sim[:h2], obs[:h2]),
        'predict': rmse(sim[h2:hours], obs[h2:hours]),
        'all': rmse(sim[:hours], obs[:hours]),
    }
    print(case,'Training 1 to',h2,'hours RMSE:   ' + str(np.round(scores['train'], 4)))
    print(case,'Prediction',h2+1,'to',hours,'hours RMSE: ' + str(np.round(scores['predict'], 4)))
    print(f"All predictions hash: {hash2(sim)}")

    return scores
352 #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    # Inputs:
    #   filename: (str) name of the data file
    #   data_dir: (str) directory searched for the file and passed to wget -P
    # Return: none
    # Nothing is downloaded when the file exists under data_dir or at the
    # path filename itself.
    local_path = osp.join(data_dir, filename)
    if osp.exists(local_path):
        print(f"File {local_path} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        # Plain concatenation: os.path.join is for file paths, not URLs
        url = base_url + filename
        print(f"Retrieving data {url}")
        # Argument list without a shell, so names containing spaces or shell
        # metacharacters are passed to wget verbatim
        try:
            subprocess.call(["wget", "-P", data_dir, url])
        except FileNotFoundError:
            # Keep retrieval best-effort when wget is not installed
            print("wget not found, could not retrieve " + url)