1 ## Set of Functions to process and format fuel moisture model inputs
2 #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 import numpy
as np
, random
5 from numpy
.random
import rand
6 import tensorflow
as tf
8 from sklearn
.metrics
import mean_squared_error
9 import matplotlib
.pyplot
as plt
10 from moisture_models
import model_decay
, model_moisture
11 from datetime
import datetime
, timedelta
12 from utils
import is_numeric_ndarray
, hash2
# Utility to combine nested fmda dictionaries
def combine_nested(nested_input_dict, verbose=True):
    """
    Combine a dictionary of fmda case dictionaries into a single dictionary.

    Parameters
    ----------
    nested_input_dict : dict
        Dictionary whose values are individual fmda case dictionaries, each
        expected to carry the keys 'id', 'case', 'filename', 'time', 'X',
        'y', 'loc' and 'features_list'.
    verbose : bool, optional
        If True, prints status messages. Default is True.

    Returns
    -------
    dict
        Dictionary with the same keys, where each value is a list collecting
        the corresponding entries from all input cases.
    """
    # Setup return dictionary
    d = {}
    # Use the helper function to populate the keys
    d['id'] = _combine_key(nested_input_dict, 'id')
    d['case'] = _combine_key(nested_input_dict, 'case')
    d['filename'] = _combine_key(nested_input_dict, 'filename')
    d['time'] = _combine_key(nested_input_dict, 'time')
    d['X'] = _combine_key(nested_input_dict, 'X')
    d['y'] = _combine_key(nested_input_dict, 'y')

    # Build the loc subdictionary using _combine_key for each loc key
    d['loc'] = {
        'STID': _combine_key(nested_input_dict, 'loc', 'STID'),
        'lat': _combine_key(nested_input_dict, 'loc', 'lat'),
        'lon': _combine_key(nested_input_dict, 'loc', 'lon'),
        'elev': _combine_key(nested_input_dict, 'loc', 'elev'),
        'pixel_x': _combine_key(nested_input_dict, 'loc', 'pixel_x'),
        'pixel_y': _combine_key(nested_input_dict, 'loc', 'pixel_y')
    }

    # Handle features_list separately with validation: every case must use
    # the same feature set; warn (but continue) if any case disagrees.
    features_list = _combine_key(nested_input_dict, 'features_list')
    first_features_list = features_list[0]
    for fl in features_list:
        if fl != first_features_list:
            warnings.warn("Different features_list found in the nested input dictionaries.")
    d['features_list'] = first_features_list

    return d
def _combine_key(nested_input_dict, key, subkey=None):
    """
    Collect `input_dict[key]` (or `input_dict[key][subkey]` when a truthy
    subkey is given) from every value of `nested_input_dict`, in order.

    Raises
    ------
    ValueError
        If any entry is not a dict, or is missing the requested key/subkey.
    """
    combined_list = []
    for input_dict in nested_input_dict.values():
        # Guard clause: every entry must itself be a dictionary.
        if not isinstance(input_dict, dict):
            raise ValueError(f"Expected a dictionary, but got {type(input_dict)}")
        try:
            value = input_dict[key][subkey] if subkey else input_dict[key]
        except KeyError:
            raise ValueError(f"Missing expected key: '{key}'{f' or subkey: {subkey}' if subkey else ''} in one of the input dictionaries")
        combined_list.append(value)
    return combined_list
def compare_dicts(dict1, dict2, keys):
    """Return True when dict1 and dict2 agree (via .get) on every key in keys."""
    return all(dict1.get(key) == dict2.get(key) for key in keys)
# Module-level sentinel key: check_data() temporarily stores the list of the
# dictionary's keys under this key and deletes it again when finished.
items = '_items_' # dictionary key to keep list of items in
def check_data_array(dat, hours, a, s):
    """
    Print a summary (length, min, max, hash, type) of array dat[a].

    dat: fmda data dictionary
    hours: expected number of hourly entries, or None to skip the length check
    a: key of the array within dat
    s: human-readable description used in the printout
    """
    # Mark this key as handled so check_data() does not report it again.
    # NOTE(review): this bookkeeping block is reconstructed — confirm.
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        ar = dat[a]
        print("array %s %s length %i min %s max %s hash %s %s" %
              (a, s, len(ar), min(ar), max(ar), hash2(ar), type(ar)))
        if hours is not None:
            if len(ar) < hours:
                # NOTE(review): message says "does not equal" but the test is
                # len < hours — confirm the intended semantics.
                print('len(%a) = %i does not equal to hours = %i' % (a, len(ar), hours))
                exit(1)  # NOTE(review): hard abort reconstructed — confirm
    else:
        print(a + ' not present')
def check_data_scalar(dat, a):
    """
    Print the value and type of scalar entry dat[a], or a note when missing.

    dat: fmda data dictionary
    a: key to report
    """
    # Mark this key as handled so check_data() does not report it again.
    # NOTE(review): this bookkeeping block is reconstructed — confirm.
    if a in dat[items]:
        dat[items].remove(a)
    if a in dat:
        print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    else:
        print(a + ' not present' )
def check_data(dat, case=True, name=None):
    """
    Print a diagnostic report of every entry in an fmda data dictionary.

    dat: fmda data dictionary; temporarily gains the module-level `items` key
         (list of its own keys) which is removed again before returning
    case: if True, report the standard named case fields and arrays first
    name: optional heading printed before the report
    """
    dat[items] = list(dat.keys())  # add list of items to the dictionary
    if name is not None:
        print(name)
    if case:
        check_data_scalar(dat, 'filename')
        check_data_scalar(dat, 'title')
        check_data_scalar(dat, 'note')
        check_data_scalar(dat, 'hours')
        check_data_scalar(dat, 'h2')
        check_data_scalar(dat, 'case')
        # Arrays are length-checked against 'hours' when it is available.
        if 'hours' in dat:
            hours = dat['hours']
        else:
            hours = None
        check_data_array(dat, hours, 'E', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ed', 'drying equilibrium (%)')
        check_data_array(dat, hours, 'Ew', 'wetting equilibrium (%)')
        check_data_array(dat, hours, 'Ec', 'equilibrium equilibrium (%)')
        check_data_array(dat, hours, 'rain', 'rain intensity (mm/h)')
        check_data_array(dat, hours, 'fm', 'RAWS fuel moisture data (%)')
        check_data_array(dat, hours, 'm', 'fuel moisture estimate (%)')
    print('items:', dat[items])
    # Report everything not already covered above. Iterate over a copy since
    # the helpers mutate dat[items] as they go.
    for a in dat[items].copy():
        ar = dat[a]
        if dat[a] is None or np.isscalar(dat[a]):
            check_data_scalar(dat, a)
        elif is_numeric_ndarray(ar):
            print("array", a, "shape", ar.shape, "min", np.min(ar),
                  "max", np.max(ar), "hash", hash2(ar), "type", type(ar))
        elif isinstance(ar, tf.Tensor):
            print("array", a, "shape", ar.shape, "min", np.min(ar),
                  "max", np.max(ar), "type", type(ar))
        else:
            print('%s = %s' % (a, dat[a]), ' ', type(dat[a]))
    del dat[items]  # clean up
# Note: the project structure has moved towards pickle files, so these json funcs might not be needed
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def to_json(dic, filename):
    # Write given dictionary as json file.
    # This utility is used because the typical method fails on numpy.ndarray
    # Inputs:
    # dic: (dict) dictionary to serialize
    # filename: (str) output json filename, expect a ".json" file extension
    # Return: none
    print('writing ', filename)
    new = {}
    for i in dic:
        # FIX: isinstance instead of `type(...) is` (handles ndarray subclasses)
        if isinstance(dic[i], np.ndarray):
            new[i] = dic[i].tolist()  # because numpy.ndarray is not serializable
        else:
            new[i] = dic[i]
        # print('i',type(new[i]))
    new['filename'] = filename
    print('Hash: ', hash2(new))
    # FIX: use a context manager — the previous json.dump(new, open(filename,'w'))
    # leaked the file handle (never closed, data possibly unflushed on error).
    with open(filename, 'w') as f:
        json.dump(new, f, indent=4)
def from_json(filename):
    # Read json file given a filename
    # Inputs: filename (str) expect a ".json" string
    # Returns: dictionary with lists converted back to numpy arrays
    print('reading ', filename)
    # FIX: use a context manager — json.load(open(filename,'r')) leaked the
    # file handle.
    with open(filename, 'r') as f:
        dic = json.load(f)
    new = {}
    for i in dic:
        # FIX: isinstance instead of `type(...) is list`
        if isinstance(dic[i], list):
            new[i] = np.array(dic[i])  # because ndarray is not serializable
        else:
            new[i] = dic[i]
    print('Hash: ', hash2(new))
    return new
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Function to simulate moisture data and equilibrium for model testing
def create_synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0, DeltaE=0.0):
    """
    Simulate hourly equilibrium and fuel-moisture series for model testing.

    days: simulation length in days (24 hourly steps per day)
    power: exponent of the diurnal sine curve. Default 4.
    data_noise: std dev of observation noise added to the returned data
    process_noise: std dev of noise added inside the decay recursion
    DeltaE: constant shift added to the equilibrium before returning
    Returns (E, m_f, data, hour, h2, DeltaE)
    """
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.

    # artificial equilibrium data
    # FIX: the exponent was hard-coded to 4, silently ignoring the `power`
    # parameter. Default is unchanged, so existing callers see identical output.
    E = 100.0 * np.power(np.sin(np.pi * day), power)  # diurnal curve
    E = 0.05 + 0.25 * E
    # FMC free run
    m_f = np.zeros(hours)
    m_f[0] = 0.1  # initial FMC
    for t in range(hours - 1):
        m_f[t+1] = max(0., model_decay(m_f[t], E[t]) + random.gauss(0, process_noise))
    data = m_f + np.random.normal(loc=0, scale=data_noise, size=hours)
    E = E + DeltaE
    return E, m_f, data, hour, h2, DeltaE
# the following input or output dictionary with all model data and variables
def synthetic_data(days=20, power=4, data_noise=0.02, process_noise=0.0,
                   DeltaE=0.0, Emin=5, Emax=30, p_rain=0.01, max_rain=10.0):
    """
    Simulate a synthetic fmda case dictionary with equilibria, rain and FMC.

    days: simulation length in days (24 hourly steps per day)
    power: exponent of the diurnal sine curve
    data_noise: std dev of observation noise added to the moisture series
    process_noise: std dev of noise added inside the moisture recursion
    DeltaE: constant shift applied to the equilibrium curve
    Emin, Emax: range the diurnal equilibrium curve is scaled into
    p_rain: per-hour probability of rain
    max_rain: maximum hourly rain intensity (mm/h)
    Returns dict with keys 'E','Ew','Ed','fm','hours','h2','DeltaE','rain','title'
    """
    hours = days * 24
    h2 = int(hours / 2)
    hour = np.array(range(hours))
    day = np.array(range(hours)) / 24.
    # artificial equilibrium data
    E = np.power(np.sin(np.pi * day), power)  # diurnal curve betwen 0 and 1
    E = Emin + (Emax - Emin) * E
    E = E + DeltaE
    # NOTE(review): drying/wetting offsets reconstructed from Ew — confirm Ed.
    Ed = E + 0.5
    Ew = np.maximum(E - 0.5, 0)
    # rain falls with probability p_rain each hour, intensity uniform in [0, max_rain)
    rain = np.multiply(rand(hours) < p_rain, rand(hours) * max_rain)
    # FMC free run
    fm = np.zeros(hours)
    fm[0] = 0.1  # initial FMC
    for t in range(hours - 1):
        # NOTE(review): index [t-1] wraps to the last element at t=0 — confirm intended.
        fm[t+1] = max(0., model_moisture(fm[t], Ed[t-1], Ew[t-1], rain[t-1]) + random.gauss(0, process_noise))
    fm = fm + np.random.normal(loc=0, scale=data_noise, size=hours)
    dat = {'E': E, 'Ew': Ew, 'Ed': Ed, 'fm': fm, 'hours': hours, 'h2': h2, 'DeltaE': DeltaE, 'rain': rain, 'title': 'Synthetic data'}
    return dat
def plot_one(hmin, hmax, dat, name, linestyle, c, label, alpha=1, type='plot'):
    # helper for plot_data: draw dat[name] over the hour range [hmin, hmax).
    # `type` selects 'plot' (line) or 'scatter'; the parameter name shadows
    # the builtin but is kept for keyword-argument compatibility.
    if name not in dat:
        return
    if hmin is None:
        hmin = 0
    if hmax is None:
        hmax = len(dat[name])
    hour = np.array(range(hmin, hmax))
    series = dat[name][hmin:hmax]
    if type == 'plot':
        plt.plot(hour, series, linestyle=linestyle, c=c, label=label, alpha=alpha)
    elif type == 'scatter':
        plt.scatter(hour, series, linestyle=linestyle, c=c, label=label, alpha=alpha)
# Lookup table mapping feature keys to their matplotlib style kwargs.
plot_styles = {
    'Ed':   {'color': '#EF847C', 'linestyle': '--', 'alpha': .8, 'label': 'drying EQ'},
    'Ew':   {'color': '#7CCCEF', 'linestyle': '--', 'alpha': .8, 'label': 'wetting EQ'},
    'rain': {'color': 'b',       'linestyle': '-',  'alpha': .9, 'label': 'Rain'}
}
def plot_feature(x, y, feature_name):
    """Plot one feature series, applying its lookup style when available."""
    # Unknown features fall back to matplotlib defaults (empty kwargs).
    plt.plot(x, y, **plot_styles.get(feature_name, {}))
def plot_features(hmin, hmax, dat, linestyle, c, label, alpha=1):
    """Plot every styled feature column of dat['X'] over the range [hmin, hmax)."""
    hour = np.array(range(hmin, hmax))
    for feat in dat.features_list:
        col = dat.all_features_list.index(feat)  # index of main data
        if feat in plot_styles:
            plot_feature(x=hour, y=dat['X'][:, col][hmin:hmax], feature_name=feat)
def plot_data(dat, plot_period='all', create_figure=False, title=None, title2=None, hmin=0, hmax=None, xlabel=None, ylabel=None):
    """
    Plot an fmda dictionary of data and model if present.

    dat: FMDA dictionary with 'hours' plus optional 'y', 'm', 'test_ind', 'rain'
    plot_period: 'all' plots the whole horizon; 'predict' starts at dat['test_ind']
    create_figure: whether to open a new matplotlib figure
    title, title2: optional title text pieces
    hmin, hmax: hour range to plot; hmax is clipped to dat['hours']
    xlabel, ylabel: optional axis labels (defaults chosen from the data)
    """
    # dat = copy.deepcopy(dat0)
    if hmax is None:
        hmax = dat['hours']
    else:
        hmax = min(hmax, dat['hours'])
    if plot_period == "all":
        pass
    elif plot_period == "predict":
        assert "test_ind" in dat.keys()
        hmin = dat['test_ind']
    else:
        raise ValueError(f"unrecognized time period for plotting plot_period: {plot_period}")

    if create_figure:
        plt.figure(figsize=(16, 4))

    plot_one(hmin, hmax, dat, 'y', linestyle='-', c='#468a29', label='FM Observed')
    plot_one(hmin, hmax, dat, 'm', linestyle='-', c='k', label='FM Model')
    plot_features(hmin, hmax, dat, linestyle='-', c='k', label='FM Model')

    if 'test_ind' in dat.keys():
        test_ind = dat["test_ind"]
    else:
        test_ind = None
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Note: the code within the tildes here makes a more complex, annotated plot
    if (test_ind is not None) and ('m' in dat.keys()):
        plt.axvline(test_ind, linestyle=':', c='k', alpha=.8)
        yy = plt.ylim()  # used to format annotations
        plot_y0 = np.max([hmin, test_ind])  # Used to format annotations
        plot_y1 = np.min([hmin, test_ind])
        plt.annotate('', xy=(hmin, yy[0]), xytext=(plot_y0, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Training)', xy=((hmin+plot_y0)/2, yy[1]), xytext=((hmin+plot_y0)/2, yy[1]+1), ha='right',
                     annotation_clip=False, alpha=.8)
        plt.annotate('', xy=(plot_y0, yy[0]), xytext=(hmax, yy[0]),
                     arrowprops=dict(arrowstyle='<-', linewidth=2),
                     annotation_clip=False)
        plt.annotate('(Forecast)', xy=(hmax-(hmax-test_ind)/2, yy[1]),
                     xytext=(hmax-(hmax-test_ind)/2, yy[1]+1),
                     annotation_clip=False, alpha=.8)
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Assemble the title: explicit title, optional suffix, then RMSE summary.
    if title is not None:
        t = title
    else:
        t = dat['id']  # NOTE(review): reconstructed default title source — confirm
        # print('title',type(t),t)
    if title2 is not None:
        t = t + ' ' + title2
    t = t + ' (' + rmse_data_str(dat) + ')'
    if plot_period == "predict":
        t = t + " - Forecast Period"
    plt.title(t, y=1.1)

    if xlabel is None:
        plt.xlabel('Time (hours)')
    else:
        plt.xlabel(xlabel)
    if 'rain' in dat:
        plt.ylabel('FM (%) / Rain (mm/h)')
    elif ylabel is None:
        plt.ylabel('Fuel moisture content (%)')
    else:
        plt.ylabel(ylabel)
    plt.legend(loc="upper left")
def rmse(a, b):
    """
    Root-mean-squared error between two equal-size arrays (flattened).

    IMPROVEMENT: computes directly with numpy instead of routing through
    sklearn.metrics.mean_squared_error — identical result for equal-length
    numeric arrays, without the heavyweight dependency.
    """
    a = np.asarray(a).flatten()
    b = np.asarray(b).flatten()
    return np.sqrt(np.mean((a - b) ** 2))
def rmse_skip_nan(x, y):
    """RMSE over positions where neither x nor y is NaN; NaN if no such position."""
    # De Morgan form of ~isnan(x) & ~isnan(y)
    valid = ~(np.isnan(x) | np.isnan(y))
    if not np.count_nonzero(valid):
        return np.nan
    diff = x[valid] - y[valid]
    return np.sqrt(np.mean(diff * diff))
def rmse_str(a, b):
    """Return the NaN-tolerant RMSE of a and b formatted as 'RMSE x.xxx'."""
    # local renamed from `rmse`, which shadowed the module-level rmse()
    value = rmse_skip_nan(a, b)
    return f"RMSE {value:.3f}"
def rmse_data_str(dat, predict=True, hours=None, test_ind=None):
    # Return RMSE for model object in formatted string. Used within plotting
    # Inputs:
    # dat: (dict) fmda dictionary
    # predict: (bool) Whether to return prediction period RMSE. Default True
    # hours: (int) total number of modeled time periods
    # test_ind: (int) start of test period
    # Return: (str) RMSE value
    # Fill hours/test_ind from the dictionary when not supplied explicitly.
    if hours is None:
        if 'hours' in dat:
            hours = dat['hours']
    if test_ind is None:
        if 'test_ind' in dat:
            test_ind = dat['test_ind']
    if 'm' in dat and 'y' in dat:
        if predict and hours is not None and test_ind is not None:
            return rmse_str(dat['m'][test_ind:hours], dat['y'].flatten()[test_ind:hours])
        return rmse_str(dat['m'], dat['y'].flatten())
    # NOTE(review): fallback when no model/observation present — reconstructed.
    return ''
# Calculate mean absolute error
def mape(a, b):
    """
    Return the mean absolute error between a and b.

    NOTE(review): despite the name, this computes MAE (mean absolute error,
    as the comment above says), not mean absolute percentage error.
    IMPROVEMENT: np.abs(...) replaces the direct dunder call (a-b).__abs__().
    """
    return np.abs(a - b).mean()
def rmse_data(dat, hours=None, h2=None, simulation='m', measurements='fm'):
    """
    Print and return train/prediction/overall RMSE for a model run.

    dat: fmda dictionary holding the simulated and measured series plus
         'hours', 'h2' and 'case' metadata
    hours: total modeled periods (defaults to dat['hours'])
    h2: end of the training period (defaults to dat['h2'])
    simulation: dict key of the modeled series. Default 'm'
    measurements: dict key of the observed series. Default 'fm'
    Returns dict with 'train', 'predict' and 'all' RMSE values.
    """
    # NOTE(review): default-filling of hours/h2/case reconstructed — confirm.
    if hours is None:
        hours = dat['hours']
    if h2 is None:
        h2 = dat['h2']
    m = dat[simulation]
    fm = dat[measurements]
    case = dat['case']

    train = rmse(m[:h2], fm[:h2])
    predict = rmse(m[h2:hours], fm[h2:hours])
    overall = rmse(m[:hours], fm[:hours])  # local renamed from `all` (shadowed builtin)

    print(case, 'Training 1 to', h2, 'hours RMSE: ' + str(np.round(train, 4)))
    print(case, 'Prediction', h2+1, 'to', hours, 'hours RMSE: ' + str(np.round(predict, 4)))
    print(f"All predictions hash: {hash2(m)}")

    return {'train': train, 'predict': predict, 'all': overall}
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def get_file(filename, data_dir='data'):
    # Check for file locally, retrieve with wget if not
    # Inputs:
    # filename: (str) name of the data file on the OpenWFM demo server
    # data_dir: (str) directory checked for / receiving the file. Default 'data'
    if osp.exists(osp.join(data_dir, filename)):
        print(f"File {osp.join(data_dir, filename)} exists locally")
    elif not osp.exists(filename):
        base_url = "https://demo.openwfm.org/web/data/fmda/dicts/"
        print(f"Retrieving data {osp.join(base_url, filename)}")
        # SECURITY FIX: pass argv as a list with shell=False — the previous
        # f-string + shell=True form allowed shell injection through
        # `filename`/`data_dir`.
        subprocess.call(["wget", "-P", data_dir, osp.join(base_url, filename)])