# fmda/moisture_rnn_pkl.py
import sys
import os
import os.path as osp
import logging
import pickle
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import reproducibility
from utils import print_dict_summary, print_first, str2time, check_increment, time_intp, hash2
from data_funcs import plot_data, rmse_data
from moisture_rnn import create_rnn_data_2, train_rnn, rnn_predict  # used by run_rnn_pkl
feature_types = {
    # Static features are based on physical location, e.g. the location of the RAWS site
    'static': ['elev', 'lon', 'lat'],
    # Atmospheric weather features come from either the RAWS subdict or HRRR
    'atm': ['temp', 'rh', 'wind', 'solar', 'soilm', 'canopyw', 'groundflux', 'Ed', 'Ew']
}
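# For example, with atm="HRRR" in pkl2train below, an 'atm' feature such as 'temp' is
# read per forecast step from subdict['HRRR'][fstep]['temp'], while a 'static' feature
# such as 'elev' is repeated from subdict['loc'] to match the number of time observations.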
def pkl2train(input_file_paths,
              forecast_step=1, atm="HRRR",
              features_all=['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain']):
    # in:
    #   input_file_paths  list of strings - files as in read_test_pkl
    #   forecast_step     int - which forecast step to take atmospheric data from (e.g. 3; must be > 0)
    #   atm               str - name of the subdict where atmospheric vars are located ("HRRR" or "RAWS")
    #   features_all      list of strings - names of keys in subdicts to collect into the feature
    #                     matrix; the default collects everything
    # return:
    #   train  dictionary with structure
    #   {key : {'id' : key,     # copied subdict key
    #           'loc' : {...},  # copied from the input dict = {key : {'loc': ...}, ...}
    #           'time' : time,  # datetime vector, HRRR spacing
    #           'X' : feat,     # feature matrix from atmosphere and location
    #           'y' : fm,       # target fuel moisture from the RAWS, interpolated to time
    #           ...}, ...}
    # TODO: fix this
    if 'rain' in features_all and features_all[-1] != 'rain':
        raise ValueError(f"Make 'rain' the last element of the features list (fix in progress as of 2024-06-24); given features list: {features_all}")
    if 0 < forecast_step < 100 and forecast_step == int(forecast_step):
        fstep = 'f' + str(forecast_step).zfill(2)
        fprev = 'f' + str(forecast_step - 1).zfill(2)
        logging.info('Using data from step %s', fstep)
        logging.info('Using rain as the difference of accumulated precipitation between %s and %s', fstep, fprev)
    else:
        logging.critical('forecast_step must be an integer between 1 and 99')
        raise ValueError('bad forecast_step')
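    # Worked example of the naming scheme above: forecast_step=3 gives fstep='f03' and
    # fprev='f02', so the hourly rain computed below is
    # subdict['HRRR']['f03']['precip_accum'] - subdict['HRRR']['f02']['precip_accum'].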
    train = {}
    for file_path in input_file_paths:
        with open(file_path, 'rb') as file:
            logging.info("loading file %s", file_path)
            d = pickle.load(file)
        for key in d:
            atm_dict = atm
            features_list = features_all
            logging.info('Processing subdictionary %s', key)
            if key in train:
                logging.warning('skipping duplicate key %s', key)
            else:
                subdict = d[key]   # subdictionary for this case
                loc = subdict['loc']
                train[key] = {
                    'id': key,   # store the key inside the dictionary; the subdictionary will be used separately
                    'case': key,
                    'filename': file_path,
                    'loc': loc
                }
                desc = 'descr'
                if desc in subdict:
                    train[key][desc] = subdict[desc]   # keep the description with this case, not at the top level
                time_hrrr = str2time(subdict[atm_dict]['time'])
                # timekeeping
                timesteps = len(d[key][atm_dict]['time'])
                hours = timesteps
                train[key]['hours'] = hours
                train[key]['h2'] = hours   # not doing prediction yet
                hrrr_increment = check_increment(time_hrrr, id=key + f' {atm_dict}.time')
                logging.info(f'{atm_dict} increment is %s h', hrrr_increment)
                if hrrr_increment < 1:
                    logging.critical('HRRR increment is %s h, must be at least 1 h', hrrr_increment)
                    raise ValueError('bad HRRR increment')
                # build the matrix of features - assuming all are the same length; if not, column_stack will fail
                train[key]['time'] = time_hrrr

                # TODO: REMOVE THIS
                scale_fm = 1
                train[key]["scale_fm"] = scale_fm
                # Set up static vars, but not for the repro case
                columns = []
                for feat in features_list:
                    # For atmospheric features, read from the chosen subdict (per forecast step for HRRR)
                    if feat in feature_types['atm']:
                        if atm_dict == "HRRR":
                            vec = subdict[atm_dict][fstep][feat]
                        elif atm_dict == "RAWS":
                            vec = subdict[atm_dict][feat]
                        else:
                            raise ValueError(f'unknown atmospheric subdict {atm_dict}')
                        if feat in ['Ed', 'Ew']:
                            vec = vec / scale_fm
                        columns.append(vec)
                    # For static features, repeat to fit the number of time observations
                    elif feat in feature_types['static']:
                        columns.append(np.full(timesteps, loc[feat]))
                # compute rain as the difference of accumulated precipitation
                if 'rain' in features_list:
                    if atm_dict == "HRRR":
                        rain = subdict[atm_dict][fstep]['precip_accum'] - subdict[atm_dict][fprev]['precip_accum']
                        logging.info('%s rain as difference %s minus %s: min %s max %s',
                                     key, fstep, fprev, np.min(rain), np.max(rain))
                        columns.append(rain)   # add rain feature
                    elif atm_dict == "RAWS":
                        if 'rain' in subdict[atm_dict]:
                            columns.append(subdict[atm_dict]['rain'])   # add rain feature
                        else:
                            logging.info('No rain data found in RAWS subdictionary %s', key)
                train[key]['X'] = np.column_stack(columns)
                train[key]['features_list'] = features_list

                logging.info(f"Created feature matrix train[{key}]['X'] shape {train[key]['X'].shape}")
                time_raws = str2time(subdict['RAWS']['time_raws'])   # may not be the same as HRRR
                logging.info('%s RAWS.time_raws length is %s', key, len(time_raws))
                check_increment(time_raws, id=key + ' RAWS.time_raws')
                # print_first(time_raws, num=5, id='RAWS.time_raws')
                fm = subdict['RAWS']['fm']
                logging.info('%s RAWS.fm length is %s', key, len(fm))
                # interpolate RAWS sensors to HRRR time and over NaNs
                y = time_intp(time_raws, fm, time_hrrr)
                # TODO: check endpoint interpolation when RAWS data are sparse, and bail out if there is not enough data

                if y is None:
                    train[key]['y'] = None
                    logging.error('Cannot create target matrix for %s, using None', key)
                else:
                    train[key]['y'] = y / scale_fm   # scale only after the None check so it cannot fail
                    logging.info(f"Created target matrix train[{key}]['y'] shape {train[key]['y'].shape}")
    logging.info('Created a "train" dictionary with %s items', len(train))
    # clean up
    keys_to_delete = []
    for key in train:
        if train[key]['X'] is None or train[key]['y'] is None:
            logging.warning('Deleting training item %s because features X or target y are None', key)
            keys_to_delete.append(key)

    # Delete the items from the dictionary
    if len(keys_to_delete) > 0:
        for key in keys_to_delete:
            del train[key]
        logging.warning('Deleted %s items with None for data. %s items remain in the training dictionary.',
                        len(keys_to_delete), len(train))
    # output
    # if output_file_path is not None:
    #     with open(output_file_path, 'wb') as file:
    #         logging.info('Writing pickle dump of the dictionary train into file %s', output_file_path)
    #         pickle.dump(train, file)

    logging.info('pkl2train done')

    return train
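# Example usage of pkl2train (a minimal sketch; the file name 'data/test_fmda.pkl' is
# hypothetical, standing in for a pickle produced as in read_test_pkl):
#
#   train = pkl2train(['data/test_fmda.pkl'], forecast_step=3)
#   case = next(iter(train))
#   print(train[case]['X'].shape)        # (hours, number of features)
#   print(train[case]['y'].shape)        # (hours,)
#   print(train[case]['features_list'])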
def run_rnn_pkl(case_data, params, title2=None):
    # analogous to run_rnn after the create_rnn_data_1 stage,
    # except it starts from pkl2train output instead
    # Inputs:
    #   case_data: (dict) one case train[case] after pkl2train();
    #              also plays the role of rnn_dat after create_rnn_data_1
    #   params: (dict) run parameters
    #   title2: (str) string to add to plot titles
    # called from: top level

    logging.info('run_rnn_pkl start')
    verbose = params['verbose']

    if title2 is None:
        title2 = case_data['id']

    reproducibility.set_seed()   # Set seed for reproducibility

    print('case_data at entry to run_rnn_pkl')
    print_dict_summary(case_data)

    # add batched x_train, y_train
    create_rnn_data_2(case_data, params)

    # train the RNN over the period and create a prediction model with optimized weights
    model_predict = train_rnn(
        case_data,
        params,
        case_data['hours']
    )

    m = rnn_predict(model_predict, params, case_data)
    case_data['m'] = m
    print(f"Model outputs hash: {hash2(m)}")
    # Plot data needs certain names
    # TODO: make plot_data specific to this context
    case_data.update({"fm": case_data["Y"] * case_data['scale_fm']})
    plot_data(case_data, title2=title2)
    plt.show()

    logging.info('run_rnn_pkl end')
    # Print and return errors
    return m, rmse_data(case_data)
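# Example top-level usage of run_rnn_pkl (a sketch; 'params.yaml' and any keys in it
# beyond 'verbose' are assumptions, since the full parameter set is defined in moisture_rnn):
#
#   train = pkl2train(['data/test_fmda.pkl'])
#   with open('params.yaml') as f:
#       params = yaml.safe_load(f)
#   case = next(iter(train))
#   m, err = run_rnn_pkl(train[case], params, title2=case)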