## Set of Functions to process and format fuel moisture model inputs
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import json
import random
import subprocess
import os.path as osp
import numpy as np
from numpy.random import rand
import tensorflow as tf
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from moisture_models import model_decay, model_moisture
from datetime import datetime, timedelta
from utils import is_numeric_ndarray, hash2

def compare_dicts(dict1, dict2, keys):
    # Return True if the two dictionaries agree on all of the given keys
    for key in keys:
        if dict1.get(key) != dict2.get(key):
            return False
    return True
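
# Example usage (a sketch; the dictionaries and key list below are hypothetical):
#   d1 = {'hours': 240, 'h2': 120, 'title': 'Case A'}
#   d2 = {'hours': 240, 'h2': 100, 'title': 'Case A'}
#   compare_dicts(d1, d2, ['hours', 'title'])   # True,  both agree on these keys
#   compare_dicts(d1, d2, ['hours', 'h2'])      # False, 'h2' differs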

items = '_items_' # dictionary key to keep list of items in

def check_data_array(dat, hours, a, s):
    # Print a summary of array dat[a] with description s and check its length against hours
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
        if hours is not None and len(ar) != hours:
            print('len(%s) = %i does not equal hours = %i' % (a, len(ar), hours))
    else:
        print(a + ' not present')

def check_data_scalar(dat, a):
    # Print the value and type of scalar item dat[a]
    if a in dat:
        print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    else:
        print(a + ' not present')

def check_data(dat, case=True, name=None):
    # Print a summary of an fmda dictionary; case=True checks the standard named fields
    dat[items] = list(dat.keys()) # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat,'filename')
        check_data_scalar(dat,'title')
        check_data_scalar(dat,'note')
        check_data_scalar(dat,'hours')
        check_data_scalar(dat,'h2')
        check_data_scalar(dat,'case')
        hours = dat.get('hours')
        check_data_array(dat,hours,'E','equilibrium (%)')
        check_data_array(dat,hours,'Ed','drying equilibrium (%)')
        check_data_array(dat,hours,'Ew','wetting equilibrium (%)')
        check_data_array(dat,hours,'Ec','equilibrium correction (%)')
        check_data_array(dat,hours,'rain','rain intensity (mm/h)')
        check_data_array(dat,hours,'fm','RAWS fuel moisture data (%)')
        check_data_array(dat,hours,'m','fuel moisture estimate (%)')
    else:
        print('items:',dat[items])
        for a in dat[items].copy():
            ar = dat[a]
            if dat[a] is None or np.isscalar(dat[a]):
                check_data_scalar(dat,a)
            elif is_numeric_ndarray(ar):
                print("array", a, "shape",ar.shape,"min",np.min(ar),
                      "max",np.max(ar),"hash",hash2(ar),"type",type(ar))
            elif isinstance(ar, tf.Tensor):
                print("array", a, "shape",ar.shape,"min",np.min(ar),
                      "max",np.max(ar),"type",type(ar))
            else:
                print('%s = %s' % (a,dat[a]),' ',type(dat[a]))
    del dat[items] # clean up
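
# Example usage (a sketch; assumes a dictionary produced by synthetic_data below,
# or any fmda dictionary with the standard fields):
#   dat = synthetic_data(days=10)
#   check_data(dat)                  # check the standard named fields
#   check_data(dat, case=False)      # generic summary of every item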

# Note: the project structure has moved towards pickle files, so these json functions may no longer be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as a json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    # dic: (dict) dictionary to write
    # filename: (str) output json filename, expects a ".json" file extension
    print('writing ',filename)
    new = {}
    for i in dic:
        if type(dic[i]) is np.ndarray:
            new[i] = dic[i].tolist() # because numpy.ndarray is not json serializable
        else:
            new[i] = dic[i]
        # print('i',type(new[i]))
    new['filename'] = filename
    print('Hash: ', hash2(new))
    json.dump(new, open(filename,'w'), indent=4)

def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expects a ".json" string
    # Returns: (dict) the file contents, with lists converted back to numpy arrays
    print('reading ',filename)
    dic = json.load(open(filename,'r'))
    new = {}
    for i in dic:
        if type(dic[i]) is list:
            new[i] = np.array(dic[i]) # convert lists back to numpy arrays
        else:
            new[i] = dic[i]
    print('Hash: ', hash2(new))
    return new
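
# Example round trip (a sketch; 'test.json' is a hypothetical filename):
#   dat = synthetic_data(days=10)
#   to_json(dat, 'test.json')        # ndarrays are converted to lists on write
#   dat2 = from_json('test.json')    # lists are converted back to ndarrays on read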

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    hours = days*24       # total number of simulated hours
    h2 = int(hours/2)     # end of the training period
    hour = np.array(range(hours))
    day = np.array(range(hours))/24.

    # artificial equilibrium data
    E = 100.0*np.power(np.sin(np.pi*day),power) # diurnal curve
    E = E + DeltaE                              # optional equilibrium shift
    # fuel moisture from the decay model, with optional process noise
    m_f = np.zeros(hours)
    m_f[0] = 0.1 # initial FMC
    for t in range(hours-1):
        m_f[t+1] = max(0.,model_decay(m_f[t],E[t]) + random.gauss(0,process_noise))
    # synthetic observations: modeled moisture plus Gaussian measurement noise
    data = m_f + np.random.normal(loc=0,scale=data_noise,size=hours)

    return E, m_f, data, hour, h2, DeltaE
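
# Example usage (a sketch; keyword values are illustrative only):
#   E, m_f, data, hour, h2, DeltaE = create_synthetic_data(days=10, data_noise=0.05)
#   plt.plot(hour, E, label='equilibrium')
#   plt.plot(hour, data, label='noisy observations')
#   plt.legend()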

# The functions below input or output a dictionary with all model data and variables

def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    hours = days*24       # total number of simulated hours
    h2 = int(hours/2)     # end of the training period
    hour = np.array(range(hours))
    day = np.array(range(hours))/24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi*day),power) # diurnal curve between 0 and 1
    E = Emin+(Emax - Emin)*E
    E = E + DeltaE
    Ed = E + 0.5              # drying equilibrium, half a percent above E
    Ew = np.maximum(E-0.5,0)  # wetting equilibrium, half a percent below E
    # rain falls in random hours with probability p_rain, intensity up to max_rain (mm/h)
    rain = np.multiply(rand(hours) < p_rain, rand(hours)*max_rain)
    # fuel moisture from the moisture model, with optional process noise
    fm = np.zeros(hours)
    fm[0] = 0.1 # initial FMC
    for t in range(hours-1):
        fm[t+1] = max(0.,model_moisture(fm[t],Ed[t],Ew[t],rain[t]) + random.gauss(0,process_noise))
    # add Gaussian measurement noise to get synthetic observations
    fm = fm + np.random.normal(loc=0,scale=data_noise,size=hours)
    dat = {'E':E,'Ew':Ew,'Ed':Ed,'fm':fm,'hours':hours,'h2':h2,'DeltaE':DeltaE,
           'rain':rain,'title':'Synthetic data'}
    return dat
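
# Example usage (a sketch; keyword values are illustrative only):
#   dat = synthetic_data(days=30, p_rain=0.05, max_rain=5.0)
#   dat['fm'].shape      # hourly noisy fuel moisture observations, one per hour
#   dat['h2']            # suggested end of the training period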

def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data: plot dat[name] over hours hmin:hmax if the item is present
    if name not in dat:
        return
    if hmin is None:
        hmin = 0
    if hmax is None:
        hmax = len(dat[name])
    hour = np.array(range(hmin,hmax))
    if type=='plot':
        plt.plot(hour,dat[name][hmin:hmax],linestyle=linestyle,c=c,label=label, alpha=alpha)
    elif type=='scatter':
        plt.scatter(hour,dat[name][hmin:hmax],linestyle=linestyle,c=c,label=label, alpha=alpha)

# Lookup table of plotting styles for common features
plot_styles = {
    'Ed': {'color': '#EF847C', 'linestyle': '--', 'alpha':.8, 'label': 'drying EQ'},
    'Ew': {'color': '#7CCCEF', 'linestyle': '--', 'alpha':.8, 'label': 'wetting EQ'},
    'rain': {'color': 'b', 'linestyle': '-', 'alpha':.9, 'label': 'Rain'}
}

def plot_feature(x, y, feature_name):
    # Plot a single feature using its entry in the plot_styles lookup table
    style = plot_styles.get(feature_name, {})
    plt.plot(x, y, **style)

def plot_features(hmin,hmax,dat,linestyle,c,label,alpha=1):
    hour = np.array(range(hmin,hmax))
    for feat in dat.features_list:
        i = dat.all_features_list.index(feat) # index of main data
        if feat in plot_styles.keys():
            plot_feature(x=hour, y=dat['X'][:,i][hmin:hmax], feature_name=feat)

def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None,
              hmin=0, hmax=None, xlabel=None, ylabel=None):
    # Plot fmda dictionary of data and model if present
    # Inputs:
    # dat: FMDA dictionary
    # plot_period: (str) 'all' to plot every hour, 'predict' to plot only the forecast period
    # dat = copy.deepcopy(dat0)
    if hmax is None:
        hmax = dat['hours']
    else:
        hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16,4))

    plot_one(hmin,hmax,dat,'y',linestyle='-',c='#468a29',label='FM Observed')
    plot_one(hmin,hmax,dat,'m',linestyle='-',c='k',label='FM Model')
    plot_features(hmin,hmax,dat,linestyle='-',c='k',label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim() # used to format annotations
        plot_y0 = np.max([hmin, test_ind]) # Used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]), xytext=(plot_y0, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)', xy=((hmin+plot_y0)/2, yy[1]), xytext=((hmin+plot_y0)/2, yy[1]+1), ha='right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]), xytext=(hmax, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)', xy=(hmax-(hmax-test_ind)/2, yy[1]),
                     xytext=(hmax-(hmax-test_ind)/2, yy[1]+1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    if title is not None:
        t = title
    else:
        t = dat.get('title','')
    # print('title',type(t),t)
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat) + ')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if ylabel is None:
        if 'rain' in dat:
            plt.ylabel('FM (%) / Rain (mm/h)')
        else:
            plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")

def rmse(a, b):
    # Root mean squared error of two arrays
    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))

def rmse_skip_nan(x, y):
    # RMSE over the entries where both x and y are valid (not NaN)
    mask = ~np.isnan(x) & ~np.isnan(y)
    if np.count_nonzero(mask):
        return np.sqrt(np.mean((x[mask] - y[mask]) ** 2))
    else:
        return np.nan

def rmse_str(a, b):
    # Return the RMSE of two arrays as a formatted string
    rmse = rmse_skip_nan(a, b)
    return "RMSE " + "{:.3f}".format(rmse)

def rmse_data_str(dat, predict=True, hours=None, test_ind=None):
    # Return RMSE for a model object in a formatted string. Used within plotting
    # Inputs:
    # dat: (dict) fmda dictionary
    # predict: (bool) whether to return prediction-period RMSE. Default True
    # hours: (int) total number of modeled time periods
    # test_ind: (int) start of test period
    # Return: (str) formatted RMSE value
    if hours is None:
        hours = dat.get('hours')
    if test_ind is None:
        if 'test_ind' in dat:
            test_ind = dat['test_ind']
    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'][test_ind:hours], dat['y'].flatten()[test_ind:hours])
        else:
            return rmse_str(dat['m'], dat['y'].flatten())
    return ''

# Calculate mean absolute error
def mae(a, b):
    return np.mean(np.abs(a - b))

def rmse_data(dat, hours=None, h2=None, simulation='m', measurements='fm'):
    # Compute and print training, prediction, and overall RMSE for a model run
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']
    case = dat.get('case', dat.get('title', ''))
    m = dat[simulation]
    fm = dat[measurements]

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    all = rmse(m[:hours], fm[:hours])
    print(case,'Training 1 to',h2,'hours RMSE: ' + str(np.round(train, 4)))
    print(case,'Prediction',h2+1,'to',hours,'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train':train, 'predict':predict, 'all':all}
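
# Example usage (a sketch; assumes dat holds a model run 'm' and observations 'fm',
# with 'hours' and 'h2' set as in synthetic_data above):
#   errs = rmse_data(dat)
#   errs['train'], errs['predict'], errs['all']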

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        subprocess.call(f"wget -P {data_dir} {osp.join(base_url, filename)}", shell=True)
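
# Example usage (a sketch; the filename below is hypothetical):
#   get_file('test_case.pkl')   # downloads into ./data if not already present locally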