import pickle
import logging
import numpy as np
from utils import print_dict_summary, print_first, str2time, check_increment, time_intp
from moisture_rnn import create_rnn_data_2, train_rnn, rnn_predict
from data_funcs import plot_data, rmse_data
import matplotlib.pyplot as plt
import reproducibility
# run this from test-pkl2train.ipynb
def pkl2train(input_file_paths, output_file_path='train.pkl', forecast_step=1):
    # in:
    #   input_file_paths  list of strings - files as in read_test_pkl
    #   output_file_path  string - pickle file to write the train dictionary to, or None to skip writing
    #   forecast_step     int - which forecast step to take atmospheric data from (maybe 03, must be >0).
    # return:
    #   train  dictionary with structure
    #          {key : {'id' : key,     # copied subdict key
    #                  'loc' : {...},  # copied from in dict = {key : {'loc': ... }...}
    #                  'time' : time,  # datetime vector, spacing tres
    #                  'X' : feat,     # features from atmosphere and location
    #                  'Y' : fm}}      # target fuel moisture from the RAWS, interpolated to time
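    # e.g. train[key]['X'] comes out (hours, 11) after the column_stack below and
    # train[key]['Y'] is a length-hours vector (shapes inferred from this code,
    # not from the upstream pickle format)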

    if forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep = 'f' + str(forecast_step).zfill(2)
        fprev = 'f' + str(forecast_step - 1).zfill(2)
        logging.info('Using data from step %s', fstep)
        logging.info('Using rain as the difference of accumulated precipitation between %s and %s', fstep, fprev)
    else:
        logging.critical('forecast_step must be integer between 1 and 99')
        raise ValueError('bad forecast_step')
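    # e.g. forecast_step=3 gives fstep='f03' and fprev='f02' (zfill pads to two
    # digits), so rain below is accumulated precipitation at f03 minus at f02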

    train = {}
    for file_path in input_file_paths:
        with open(file_path, 'rb') as file:
            logging.info('loading file %s', file_path)
            d = pickle.load(file)
        for key in d:
            logging.info('Processing subdictionary %s', key)
            if key in train:
                logging.warning('skipping duplicate key %s', key)
            else:
                subdict = d[key]    # subdictionary for this case
                loc = subdict['loc']
                train[key] = {
                    'id': key,    # store the key inside the dictionary, subdictionary will be used separately
                    'filename': file_path,
                    'loc': loc,
                }
                desc = 'descr'
                if desc in subdict:
                    train[key][desc] = subdict[desc]   # optional description, stored per case
                time_hrrr = str2time(subdict['HRRR']['time'])
                timesteps = len(time_hrrr)
                hours = timesteps
                train[key]['hours'] = hours
                train[key]['h2'] = hours     # not doing prediction yet
                hrrr_increment = check_increment(time_hrrr, id=key + ' HRRR.time')
                logging.info('HRRR increment is %s h', hrrr_increment)
                if hrrr_increment < 1:
                    logging.critical('HRRR increment is %s h, must be at least 1 h', hrrr_increment)
                    raise ValueError('bad HRRR increment')

                # build matrix of features - assuming all the same length, if not column_stack will fail
                train[key]['time'] = time_hrrr
                columns = []
                # location as features constant in time come first
                columns.append(np.full(timesteps, loc['elev']))
                columns.append(np.full(timesteps, loc['lon']))
                columns.append(np.full(timesteps, loc['lat']))
                for i in ["rh", "wind", "solar", "soilm", "groundflux", "Ed", "Ew"]:
                    columns.append(subdict['HRRR'][fstep][i])   # add variables from HRRR forecast steps
                # compute rain as difference of accumulated precipitation
                rain = subdict['HRRR'][fstep]['precip_accum'] - subdict['HRRR'][fprev]['precip_accum']
                logging.info('%s rain as difference %s minus %s: min %s max %s',
                             key, fstep, fprev, np.min(rain), np.max(rain))
                columns.append(rain)   # add rain feature
                train[key]['X'] = np.column_stack(columns)
                logging.info(f"Created feature matrix train[{key}]['X'] shape {train[key]['X'].shape}")
                time_raws = str2time(subdict['RAWS']['time_raws'])   # may not be the same as HRRR
                logging.info('%s RAWS.time_raws length is %s', key, len(time_raws))
                check_increment(time_raws, id=key + ' RAWS.time_raws')
                # print_first(time_raws,num=5,id='RAWS.time_raws')
                fm = subdict['RAWS']['fm']
                logging.info('%s RAWS.fm length is %s', key, len(fm))
                # interpolate RAWS sensors to HRRR time and over NaNs
                train[key]['Y'] = time_intp(time_raws, fm, time_hrrr)
                # TODO: check endpoint interpolation when RAWS data sparse, and bail out if not enough data

                if train[key]['Y'] is None:
                    logging.error('Cannot create target matrix for %s, using None', key)
                else:
                    logging.info(f"Created target matrix train[{key}]['Y'] shape {train[key]['Y'].shape}")

    logging.info('Created a "train" dictionary with %s items', len(train))

    keys_to_delete = []
    for key in train:
        if train[key]['X'] is None or train[key]['Y'] is None:
            logging.warning('Deleting training item %s because features X or target Y are None', key)
            keys_to_delete.append(key)

    # Delete the items from the dictionary
    if len(keys_to_delete) > 0:
        for key in keys_to_delete:
            del train[key]
        logging.warning('Deleted %s items with None for data. %s items remain in the training dictionary.',
                        len(keys_to_delete), len(train))

    if output_file_path is not None:
        with open(output_file_path, 'wb') as file:
            logging.info('Writing pickle dump of the dictionary train into file %s', output_file_path)
            pickle.dump(train, file)

    logging.info('pkl2train done')
    return train
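
# a minimal usage sketch (the pickle file names are hypothetical):
#   train = pkl2train(['case1.pkl', 'case2.pkl'], output_file_path='train.pkl', forecast_step=3)
#   print_dict_summary(train)
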
def run_rnn_pkl(case_data, params, title2=None):
    # analogous to run_rnn after the create_rnn_data_1 stage
    # instead, run this after pkl2train
    # in:
    #   case_data: (dict) one case train[case] after pkl2train()
    #              also plays the role of rnn_dat after create_rnn_data_1
    #   params:    (dict) run parameters, must contain 'verbose'
    #   title2:    (str) string to add to plot titles
    # called from: top level

    logging.info('run_rnn_pkl start')
    verbose = params['verbose']

    if title2 is None:
        title2 = case_data['id']

    reproducibility.set_seed()   # Set seed for reproducibility

    if verbose:
        print('case_data at entry to run_rnn_pkl')
        print_dict_summary(case_data)

    # add batched x_train, y_train
    create_rnn_data_2(case_data, params)

    # train the rnn over period, create prediction model with optimized weights
    # (the train_rnn arguments below are assumed by analogy with rnn_predict:
    # the case dict, the params, and the length of the training period)
    model_predict = train_rnn(
        case_data,
        params,
        case_data['hours']
    )

    m = rnn_predict(model_predict, params, case_data)
    case_data['m'] = m   # store prediction in the case dict so plot_data can show it

    plot_data(case_data, title2=title2)
    plt.show()

    logging.info('run_rnn_pkl end')
    # return m, rmse_data(case_data) # do not have a "measurements" field
    return m
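
# a minimal usage sketch (the case key and params contents are hypothetical):
#   train = pkl2train(['case1.pkl'])
#   params = {'verbose': True}   # plus whatever create_rnn_data_2 and train_rnn expect
#   m = run_rnn_pkl(train['case1'], params)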