Update moisture_rnn_pkl.py
[notebooks.git] / fmda / moisture_rnn_pkl.py
import sys
import logging
from utils import print_dict_summary, print_first, str2time, check_increment, time_intp
import pickle
import os.path as osp
import pandas as pd
import numpy as np
import reproducibility
from moisture_rnn import create_rnn_data_2, train_rnn, rnn_predict
from data_funcs import plot_data, rmse_data
import matplotlib.pyplot as plt

# run this from test-pkl2train.ipynb
def pkl2train(input_file_paths, output_file_path='train.pkl', forecast_step=1):
    # in:
    #     input_file_paths   list of strings - files as in read_test_pkl
    #     output_file_path   str - where to pickle the result, or None to skip writing
    #     forecast_step      int - which forecast step to take atmospheric data from (e.g. 3 for 'f03'), must be > 0
    # return:
    #     train  dictionary with structure
    #     {key : {'id' : key,      # copied subdict key
    #             'loc' : {...},   # copied from input dict = {key : {'loc': ...}, ...}
    #             'time' : time,   # datetime vector, HRRR spacing
    #             'X' : feat,      # features from atmosphere and location
    #             'Y' : fm,        # target fuel moisture from the RAWS, interpolated to time
    #             ...}, ...}

    if forecast_step > 0 and forecast_step < 100 and forecast_step == int(forecast_step):
        fstep = 'f' + str(forecast_step).zfill(2)
        fprev = 'f' + str(forecast_step - 1).zfill(2)
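        # e.g. forecast_step=1 gives fstep='f01' and fprev='f00'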
        logging.info('Using data from step %s', fstep)
        logging.info('Using rain as the difference of accumulated precipitation between %s and %s', fstep, fprev)
    else:
        logging.critical('forecast_step must be an integer between 1 and 99')
        raise ValueError('bad forecast_step')
    train = {}
    for file_path in input_file_paths:
        with open(file_path, 'rb') as file:
            logging.info('loading file %s', file_path)
            d = pickle.load(file)
        for key in d:
            logging.info('Processing subdictionary %s', key)
            if key in train:
                logging.warning('skipping duplicate key %s', key)
            else:
                subdict = d[key]   # subdictionary for this case
                loc = subdict['loc']
                train[key] = {
                    'id': key,   # store the key inside the dictionary; the subdictionary will be used separately
                    'case': key,
                    'filename': file_path,
                    'loc': loc
                }
                desc = 'descr'
                if desc in subdict:
                    train[key][desc] = subdict[desc]
                time_hrrr = str2time(subdict['HRRR']['time'])
                # timekeeping
                timesteps = len(time_hrrr)
                hours = timesteps
                train[key]['hours'] = hours
                train[key]['h2'] = hours   # not doing prediction yet
                hrrr_increment = check_increment(time_hrrr, id=key + ' HRRR.time')
                logging.info('HRRR increment is %s h', hrrr_increment)
                if hrrr_increment < 1:
                    logging.critical('HRRR increment is %s h, must be at least 1 h', hrrr_increment)
                    raise ValueError('HRRR increment must be at least 1 h')
                # build matrix of features - assuming all the same length; if not, column_stack will fail
                train[key]['time'] = time_hrrr

                columns = []
                # location features, constant in time, come first
                columns.append(np.full(timesteps, loc['elev']))
                columns.append(np.full(timesteps, loc['lon']))
                columns.append(np.full(timesteps, loc['lat']))
                for i in ["rh", "wind", "solar", "soilm", "groundflux", "Ed", "Ew"]:
                    columns.append(subdict['HRRR'][fstep][i])   # add variables from the HRRR forecast step
                # compute rain as the difference of accumulated precipitation
                rain = subdict['HRRR'][fstep]['precip_accum'] - subdict['HRRR'][fprev]['precip_accum']
                logging.info('%s rain as difference %s minus %s: min %s max %s',
                             key, fstep, fprev, np.min(rain), np.max(rain))
                columns.append(rain)   # add rain feature
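                # resulting column order in X: elev, lon, lat, rh, wind, solar, soilm, groundflux, Ed, Ew, rain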
                train[key]['X'] = np.column_stack(columns)
                logging.info(f"Created feature matrix train[{key}]['X'] shape {train[key]['X'].shape}")
                time_raws = str2time(subdict['RAWS']['time_raws'])   # may not be the same as HRRR
                logging.info('%s RAWS.time_raws length is %s', key, len(time_raws))
                check_increment(time_raws, id=key + ' RAWS.time_raws')
                # print_first(time_raws, num=5, id='RAWS.time_raws')
                fm = subdict['RAWS']['fm']
                logging.info('%s RAWS.fm length is %s', key, len(fm))
                # interpolate RAWS sensors to HRRR time and over NaNs
                train[key]['Y'] = time_intp(time_raws, fm, time_hrrr)
                # TODO: check endpoint interpolation when RAWS data sparse, and bail out if not enough data
                if train[key]['Y'] is None:
                    logging.error('Cannot create target matrix for %s, using None', key)
                else:
                    logging.info(f"Created target matrix train[{key}]['Y'] shape {train[key]['Y'].shape}")

    logging.info('Created a "train" dictionary with %s items', len(train))

    # clean up
    keys_to_delete = []
    for key in train:
        if train[key]['X'] is None or train[key]['Y'] is None:
            logging.warning('Deleting training item %s because features X or target Y are None', key)
            keys_to_delete.append(key)

    # delete the items from the dictionary
    if len(keys_to_delete) > 0:
        for key in keys_to_delete:
            del train[key]
        logging.warning('Deleted %s items with None for data. %s items remain in the training dictionary.',
                        len(keys_to_delete), len(train))

    # output
    if output_file_path is not None:
        with open(output_file_path, 'wb') as file:
            logging.info('Writing pickle dump of the dictionary train into file %s', output_file_path)
            pickle.dump(train, file)

    logging.info('pkl2train done')
    return train
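
# Example usage (a sketch for illustration; the pickle file names are hypothetical):
#
#   train = pkl2train(['test_CA.pkl', 'test_NW.pkl'], output_file_path='train.pkl', forecast_step=1)
#   case = next(iter(train))
#   print(train[case]['X'].shape, train[case]['Y'].shape)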
def run_rnn_pkl(case_data, params, title2=None):
    # analogous to run_rnn after the create_rnn_data_1 stage,
    # but run after pkl2train instead
    # Inputs:
    #     case_data: (dict) one case train[case] after pkl2train();
    #                also plays the role of rnn_dat after create_rnn_data_1
    #     params: (dict) RNN parameters, must include 'verbose'
    #     title2: (str) string to add to plot titles
    # called from: top level

    logging.info('run_rnn_pkl start')
    verbose = params['verbose']

    if title2 is None:
        title2 = case_data['id']

    reproducibility.set_seed()   # set seed for reproducibility

    print('case_data at entry to run_rnn_pkl')
    print_dict_summary(case_data)

    # add batched x_train, y_train
    create_rnn_data_2(case_data, params)

    # train the rnn over the period, create prediction model with optimized weights
    model_predict = train_rnn(
        case_data,
        params,
        case_data['hours']
    )

    m = rnn_predict(model_predict, params, case_data)
    case_data['m'] = m

    plot_data(case_data, title2=title2)
    plt.show()

    logging.info('run_rnn_pkl end')
    # return m, rmse_data(case_data)   # do not have a "measurements" field
    return m
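
# Example usage (a sketch for illustration; 'train.pkl' is assumed to come from
# pkl2train above, the case key is hypothetical, and params must supply the keys
# that moisture_rnn expects, e.g. 'verbose'):
#
#   with open('train.pkl', 'rb') as f:
#       train = pickle.load(f)
#   m = run_rnn_pkl(train['case_name'], params)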