4 "cell_type": "markdown",
5 "id": "83b774b3-ef55-480a-b999-506676e49145",
8 "# v2.1 Run RNN with Spatial Training\n",
10 "This notebook sets up a test in which the RNN is trained serially, one location at a time, and compared with the spatial training scheme. The ODE model with the augmented Kalman filter (KF) is also run as a comparison. Note that the RNN models predict entirely without knowledge of the held-out test locations, while the augmented KF is run directly on the test locations.\n"
14 "cell_type": "markdown",
15 "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
18 "## Environment Setup"
23 "execution_count": null,
24 "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
28 "import numpy as np\n",
29 "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
32 "import os.path as osp\n",
33 "from moisture_rnn_pkl import pkl2train\n",
34 "from moisture_rnn import RNNParams, RNNData, RNN\n",
35 "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict\n",
37 "import reproducibility\n",
38 "from data_funcs import rmse, to_json, combine_nested, process_train_dict\n",
39 "from moisture_models import run_augmented_kf\n",
41 "import pandas as pd\n",
42 "import matplotlib.pyplot as plt\n",
49 "execution_count": null,
50 "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
59 "execution_count": null,
60 "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
65 " url = \"https://demo.openwfm.org/web/data/fmda/dicts/fmda_nw_202401-05_f05.pkl\", \n",
66 " dest_path = \"data/fmda_nw_202401-05_f05.pkl\")"
71 "execution_count": null,
72 "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
76 "file_paths = ['data/fmda_nw_202401-05_f05.pkl']"
81 "execution_count": null,
82 "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
86 "# read/write control\n",
87 "train_file='data/train.pkl'\n",
88 "train_create=True # if false, read\n",
89 "train_write=False\n",
95 "execution_count": null,
96 "id": "604388de-11ab-45c3-9f0d-80bdff0cca60",
100 "# Params used for data filtering\n",
101 "params_data = read_yml(\"params_data.yaml\") \n",
107 "execution_count": null,
108 "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
112 "# Params used for setting up RNN\n",
113 "params = read_yml(\"params.yaml\", subkey='rnn') \n",
119 "execution_count": null,
120 "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
126 "if train_create:\n",
127 " logging.info('creating the training cases from files %s',file_paths)\n",
128 " # build the training cases from file_paths, applying the filters in params_data\n",
129 " train = process_train_dict(file_paths, atm_dict=\"HRRR\", params_data = params_data, verbose=True)\n",
131 " with open(train_file, 'wb') as file:\n",
132 " logging.info('Writing the train cases into file %s',train_file)\n",
133 " pickle.dump(train, file)\n",
135 " logging.info('Reading the train cases from file %s',train_file)\n",
136 " train = read_pkl(train_file)"
141 "execution_count": null,
142 "id": "23cd60c0-9865-4314-9a96-948c3d400c08",
146 "from itertools import islice\n",
147 "train = {k: train[k] for k in islice(train, 250)}"
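The comprehension above keeps only the first 250 locations. Python dictionaries preserve insertion order (3.7+), so `islice` over a dict yields its first keys in the order they were added. A minimal sketch on a toy dictionary; the `toy` data and station IDs are hypothetical:

```python
from itertools import islice

# Hypothetical nested dictionary keyed by station ID, standing in for `train`
toy = {f"STN{i}": {"y": [i]} for i in range(5)}

# Keep only the first 3 entries, in insertion order
subset = {k: toy[k] for k in islice(toy, 3)}
print(list(subset))  # ['STN0', 'STN1', 'STN2']
```

The subset shares its inner dictionaries with the original, so later in-place changes such as adding keys to `subset[k]` also show up in `toy[k]`.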
151 "cell_type": "markdown",
152 "id": "efc10cdc-f18b-4781-84da-b8e2eef39981",
155 "## Set Up Validation Runs"
159 "cell_type": "markdown",
160 "id": "2d9cd5c5-87ed-41f9-b36c-e0c18d58c841",
163 "The following parameters will be used for both serial and spatial models."
168 "execution_count": null,
169 "id": "66f40c9f-c1c2-4b12-bf14-2ada8c26113d",
173 "params = RNNParams(params)\n",
174 "params.update({'epochs': 200, \n",
175 " 'learning_rate': 0.001,\n",
176 " 'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
177 " 'recurrent_layers': 2, 'recurrent_units': 30, \n",
178 " 'dense_layers': 2, 'dense_units': 30,\n",
179 " 'early_stopping_patience': 30, # how many epochs without validation improvement to wait before stopping\n",
180 " 'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
181 " 'bmin': 20, # Lower bound of hidden state batch reset\n",
182 " 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
183 " 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
184 " 'timesteps': 12\n",
190 "execution_count": null,
191 "id": "36823193-b93c-421e-b699-8c1ae5719309",
195 "reproducibility.set_seed(123)"
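The `batch_schedule_type`, `bmin`, and `bmax` parameters set above control how often the hidden state is reset during training. The exact schedule lives in `moisture_rnn` and is not shown in this notebook. Purely for illustration, the sketch below assumes an exponential (geometric) ramp of the reset interval from `bmin` to `bmax` over the epochs; `exp_reset_schedule` is a hypothetical helper, and 720 merely stands in for `params_data['hours']`, whose real value comes from `params_data.yaml`.

```python
import numpy as np

def exp_reset_schedule(bmin, bmax, epochs):
    """Hypothetical: hidden-state reset interval per epoch, growing geometrically bmin -> bmax."""
    # np.geomspace spaces values evenly on a log scale, endpoints included
    sched = np.geomspace(bmin, bmax, num=epochs)
    return np.round(sched).astype(int)

sched = exp_reset_schedule(20, 720, epochs=200)  # 720 is a stand-in for params_data['hours']
print(sched[0], sched[-1])  # 20 720
```

A geometric ramp spends more epochs at short reset intervals than a linear one would, which is one reason to prefer it; whether `moisture_rnn` uses this exact form is an assumption.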
199 "cell_type": "markdown",
200 "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3",
203 "## Spatial Data Training\n",
205 "This method combines the training timeseries data into a single 3-D array, with timeseries from the same location ordered correctly for a given `batch_size` hyperparameter. The hidden state of the recurrent layers is reset when the location changes."
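One common way to realize such an ordering for a stateful RNN is sketched below, assuming every location has the same number of hours. Each location is cut into `n_seqs` sequences of length `timesteps`, and locations are grouped `batch_size` at a time so that slot `b` of consecutive batches holds consecutive sequences of the same location. The hidden state then only needs a reset every `n_seqs` batches, when a new group of locations starts. `arrange_stateful` is a hypothetical helper, not the `RNNData.batch_reshape` implementation, and it drops locations or hours that do not fill a complete group.

```python
import numpy as np

def arrange_stateful(X, timesteps, batch_size):
    """Hypothetical helper. X has shape (n_locs, n_hours, n_features).

    Returns samples of shape (n_samples, timesteps, n_features), ordered so that
    slot b of batch k+1 continues slot b of batch k, plus n_seqs, the number of
    batches after which the locations change.
    """
    n_locs, n_hours, n_feat = X.shape
    n_seqs = n_hours // timesteps        # sequences per location
    n_groups = n_locs // batch_size      # complete groups of locations
    X = X[: n_groups * batch_size, : n_seqs * timesteps, :]
    # axes: (group, location in group, sequence, timestep, feature)
    X = X.reshape(n_groups, batch_size, n_seqs, timesteps, n_feat)
    # move the sequence axis ahead of the location axis: (group, seq, loc, t, f)
    X = X.transpose(0, 2, 1, 3, 4)
    return X.reshape(-1, timesteps, n_feat), n_seqs

# Toy data with one feature: value = 1000 * location + hour
n_locs, n_hours = 4, 6
X = (1000 * np.arange(n_locs)[:, None] + np.arange(n_hours)[None, :])[..., None]
samples, n_seqs = arrange_stateful(X, timesteps=2, batch_size=2)
```

In the toy run, slot 0 of the first three batches starts at hours 0, 2, 4 of location 0 and slot 1 at the same hours of location 1 (values 1000, 1002, 1004); the fourth batch switches to locations 2 and 3, which is where a hidden-state reset would be needed.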
210 "execution_count": null,
211 "id": "3b5371a9-c1e8-4df5-b360-210746f7cd52",
215 "# Start timer for code \n",
216 "start_time = time.time()"
221 "execution_count": null,
222 "id": "faf93470-b55f-4770-9fa9-3288a2f13fcc",
226 "# Combine Nested Dictionary into Spatial Data\n",
227 "train_sp = Dict(combine_nested(train))"
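`combine_nested` is defined in `data_funcs` and is not shown here. As a rough illustration of the idea, the sketch below stacks per-location arrays from a nested dictionary into location-indexed arrays. The key names `'X'` and `'y'` and the helper `stack_locations` are hypothetical, and the sketch assumes every location covers the same hours.

```python
import numpy as np

def stack_locations(nested):
    """Hypothetical: {loc: {'X': (hours, feats), 'y': (hours,)}} -> stacked arrays."""
    locs = list(nested)  # keep insertion order so rows map back to locations
    X = np.stack([nested[k]['X'] for k in locs])  # (n_locs, hours, feats)
    y = np.stack([nested[k]['y'] for k in locs])  # (n_locs, hours)
    return {'loc': locs, 'X': X, 'y': y}

toy = {f"STN{i}": {'X': np.full((5, 3), i), 'y': np.arange(5) + i} for i in range(2)}
sp = stack_locations(toy)
print(sp['X'].shape, sp['y'].shape)  # (2, 5, 3) (2, 5)
```

`np.stack` raises an error if the per-location arrays differ in shape, which is a cheap check that all locations cover the same period.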
232 "execution_count": null,
233 "id": "c0c7f5fb-4c33-45f8-9a2e-38c9ab1cd4e3",
237 "rnn_dat_sp = RNNData(\n",
238 " train_sp, # input dictionary\n",
239 " scaler=\"standard\", # data scaling type\n",
240 " features_list = params['features_list'] # features for predicting outcome\n",
244 "rnn_dat_sp.train_test_split( \n",
245 " time_fracs = [.8, .1, .1], # Fraction of total time steps used for train/val/test\n",
246 " space_fracs = [.8, .1, .1] # Fraction of total timeseries used for train/val/test\n",
248 "rnn_dat_sp.scale_data()\n",
250 "rnn_dat_sp.batch_reshape(\n",
251 " timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
252 " batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
258 "execution_count": null,
259 "id": "7431bc95-d384-40fd-a622-bbc0ee68e5cd",
263 "# Update Params specific to spatial training\n",
265 " 'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
271 "execution_count": null,
272 "id": "4bc11474-fed8-47f2-b9cf-dfdda0d3d3b2",
276 "rnn_sp = RNN(params)\n",
277 "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
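With `loc_batch_reset` set to `rnn_dat_sp.n_seqs`, the hidden state is reset each time the batch counter reaches a multiple of `n_seqs`, that is, whenever a new group of locations begins. A minimal sketch of that bookkeeping, assuming resets happen just before the batch is processed; `reset_points` is hypothetical:

```python
def reset_points(n_batches, loc_batch_reset):
    """Hypothetical: indices of batches before which the hidden state is reset."""
    if loc_batch_reset is None:
        return []  # no location-based resets, as in the serial setup (loc_batch_reset=None)
    return [k for k in range(n_batches) if k > 0 and k % loc_batch_reset == 0]

print(reset_points(12, 3))     # [3, 6, 9]
print(reset_points(12, None))  # []
```

In practice this check would be driven from a per-batch training callback; how `RNN.run_model` wires it up is not shown in this notebook.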
282 "execution_count": null,
283 "id": "704ad662-d81a-488d-be3d-e90bf775a5b8",
292 "execution_count": null,
293 "id": "d53571e3-b6cf-49aa-9848-e3c77053283d",
298 "end_time = time.time()\n",
300 "# Calculate Code Runtime\n",
301 "elapsed_time_sp = end_time - start_time\n",
302 "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")"
306 "cell_type": "markdown",
307 "id": "7d8292a2-418c-48ed-aff7-ccbe98b046d3",
315 "execution_count": null,
316 "id": "cca12d8c-c0e1-4df4-b2ca-20440485f2f3",
320 "# Get timeseries IDs from previous RNNData object\n",
321 "test_cases = rnn_dat_sp.loc['test_locs']\n",
322 "print(len(test_cases))"
327 "execution_count": null,
328 "id": "997f2534-7e77-45b3-93bf-d988837dfc0b",
332 "test_ind = rnn_dat_sp.test_ind # Time index for test period start\n",
338 "execution_count": null,
339 "id": "1e4ffc68-c775-41c6-ac42-f49c76824b43",
346 "for case in test_cases:\n",
347 " print(\"~\"*50)\n",
349 " # Run Augmented KF\n",
350 " print('Running Augmented KF')\n",
351 " train[case]['h2'] = test_ind\n",
352 " train[case]['scale_fm'] = 1\n",
353 " m, Ec = run_augmented_kf(train[case])\n",
354 " y = train[case]['y'] \n",
355 " train[case]['m_kf'] = m\n",
356 " print(f\"KF RMSE: {rmse(m[test_ind:],y[test_ind:])}\")\n",
357 " outputs_kf[case] = {'case':case, 'errs': rmse(m[test_ind:],y[test_ind:])}"
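The KF errors above are computed only over the test period, `m[test_ind:]` against `y[test_ind:]`. `rmse` comes from `data_funcs`; the sketch below assumes it is the standard root-mean-square error, and `rmse_sketch` is a stand-in name with toy values.

```python
import numpy as np

def rmse_sketch(a, b):
    """Root-mean-square error between two equal-length arrays (assumed to match data_funcs.rmse)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return np.sqrt(np.mean((a - b) ** 2))

m = np.array([10.0, 11.0, 12.0, 13.0])  # toy model output
y = np.array([10.0, 11.0, 14.0, 11.0])  # toy observations
test_ind = 2
print(rmse_sketch(m[test_ind:], y[test_ind:]))  # 2.0
```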
362 "execution_count": null,
363 "id": "57b19ec5-23f6-44ec-9f71-16d4d69aec68",
367 "df_kf = pd.DataFrame.from_dict(outputs_kf).transpose()\n",
373 "execution_count": null,
374 "id": "25a9d2fe-83f7-4ef3-a04b-14c970b6e2ba",
382 "cell_type": "markdown",
383 "id": "f616bbf8-d89e-4c5b-9e47-59f02246b6f2",
386 "## Serial Training\n",
388 "This method initializes an RNN and uses successive `.fit` calls to train the model one location at a time. This is the naive approach to training an RNN on multiple timeseries, and it serves as a baseline for judging whether the spatial training scheme improves test RMSE and training time."
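The key property of serial training is that each `.fit` call resumes from the weights left by the previous call, so the model sees the locations one after another rather than mixed together. The toy sketch below illustrates this with a hypothetical linear model trained by plain gradient descent (no RNN, no Keras); `ToyModel` and the synthetic locations are made up for illustration.

```python
import numpy as np

class ToyModel:
    """Hypothetical linear model y ~ X @ w whose fit() continues from the current weights."""

    def __init__(self, n_features, learning_rate=0.1):
        self.w = np.zeros(n_features)
        self.lr = learning_rate

    def fit(self, X, y, epochs=1):
        for _ in range(epochs):
            grad = 2 * X.T @ (X @ self.w - y) / len(y)  # gradient of mean squared error
            self.w -= self.lr * grad                    # weights persist across calls
        return self

    def predict(self, X):
        return X @ self.w

rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0])
locations = {}
for name in ["loc1", "loc2", "loc3"]:
    X = rng.normal(size=(100, 2))
    locations[name] = (X, X @ true_w)

model = ToyModel(n_features=2)
for name, (X, y) in locations.items():  # one location at a time, as in the serial loop
    model.fit(X, y, epochs=50)
```

After the three calls the weights approach `true_w` because each call picks up where the previous one stopped; re-creating the model inside the loop would instead discard everything learned from earlier locations.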
393 "execution_count": null,
394 "id": "6fa20e9f-604a-4938-ab68-b71fbb7326df",
398 "# Start timer for code \n",
399 "start_time = time.time()"
404 "execution_count": null,
405 "id": "f033e78c-a506-4508-a23c-8e6574014872",
409 "# Update Params specific to Serial training\n",
411 " 'loc_batch_reset': None, # Used to reset hidden state when location changes for a given batch\n",
412 " 'epochs': 1 # fewer epochs since fit will be run multiple times over locations\n",
418 "execution_count": null,
419 "id": "ff1788ec-081b-403f-bcfa-b625f0e3dbe1",
423 "train_cases = rnn_dat_sp.loc['train_locs']\n",
424 "test_cases = rnn_dat_sp.loc['test_locs']"
429 "execution_count": null,
430 "id": "8a2af45e-e81b-421f-b940-e8779177dd5d",
434 "# Initialize Model with first train case\n",
435 "rnn_dat = RNNData(train[train_cases[0]], params['scaler'], params['features_list'])\n",
436 "rnn_dat.train_test_split(\n",
437 " time_fracs = [.8, .1, .1]\n",
439 "rnn_dat.scale_data()\n",
440 "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
445 "execution_count": null,
446 "id": "ac6fecc2-f614-4506-b5f9-05a6eca3b62e",
450 "reproducibility.set_seed()\n",
456 "execution_count": null,
457 "id": "79b5af30-7d52-410c-9595-e89e9756fd38",
464 "for case in train_cases:\n",
465 " print(\"~\"*50)\n",
466 " print(f\"Training with Case {case}\")\n",
467 " rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
468 " rnn_dat_temp.train_test_split(\n",
469 " time_fracs = [.8, .1, .1]\n",
471 " rnn_dat_temp.scale_data()\n",
472 " rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
473 " rnn.fit(rnn_dat_temp['X_train'], rnn_dat_temp['y_train'],\n",
474 " validation_data=(rnn_dat_temp['X_val'], rnn_dat_temp['y_val'])) "
479 "execution_count": null,
480 "id": "03d716b4-0ff5-4b80-a241-440543ba9b46",
487 "outputs_rnn_serial = {}\n",
488 "test_ind = rnn_dat.test_ind\n",
489 "for i, case in enumerate(test_cases):\n",
490 " print(\"~\"*50)\n",
491 " rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
492 " rnn_dat_temp.train_test_split(\n",
493 " time_fracs = [.8, .1, .1]\n",
495 " rnn_dat_temp.scale_data()\n",
496 " rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size']) \n",
497 " X_temp = rnn_dat_temp.scale_all_X()\n",
498 " X_temp = X_temp.reshape(1, X_temp.shape[0], X_temp.shape[1])\n",
499 " m = rnn.predict(X_temp).flatten()\n",
500 " outputs_rnn_serial[case] = {'case':case, 'errs': rmse(m[test_ind:], rnn_dat_temp.y_test)}"
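Recurrent layers in Keras take input shaped `(batch, timesteps, features)`, so the loop above reshapes one location's full scaled series into a batch of one before calling `predict`. The NumPy sketch below only traces the shapes; the sizes are made up, and it assumes the model returns one prediction per hour.

```python
import numpy as np

n_hours, n_features = 720, 8                     # hypothetical sizes
X_all = np.zeros((n_hours, n_features))          # stand-in for scale_all_X() output
X_batch = X_all.reshape(1, n_hours, n_features)  # add a leading batch axis
# np.expand_dims(X_all, 0) or X_all[np.newaxis, ...] give the same shape

# Assuming predict(...) returns (1, n_hours, 1), flatten() gives one value per hour
pred = np.zeros((1, n_hours, 1)).flatten()
test_ind = int(0.9 * n_hours)                    # hypothetical start of the test period
print(pred[test_ind:].shape)  # (72,)
```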
505 "execution_count": null,
506 "id": "e5a80bae-fe1a-4ec9-b9ac-31d540eaba40",
510 "df_rnn_serial = pd.DataFrame.from_dict(outputs_rnn_serial).transpose()\n",
511 "df_rnn_serial.head()"
516 "execution_count": null,
517 "id": "0c5b866e-c2bf-4bc1-8f6f-3ba8a9448d07",
521 "df_rnn_serial.errs.mean()"
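The `outputs_*` dictionaries are turned into tables with `pd.DataFrame.from_dict(...).transpose()`. With the default `orient='columns'`, the outer keys become columns, so the transpose is what puts one case per row; the resulting columns have `object` dtype. A toy example with made-up case IDs and errors:

```python
import pandas as pd

# Hypothetical case IDs and RMSE values
outputs = {
    'CASE_A': {'case': 'CASE_A', 'errs': 1.5},
    'CASE_B': {'case': 'CASE_B', 'errs': 2.5},
}

# Default orient='columns': outer keys become columns, so transpose to get one row per case
df = pd.DataFrame.from_dict(outputs).transpose()
print(df['errs'].astype(float).mean())  # 2.0

# Same layout in one step; 'errs' keeps a numeric dtype
df_alt = pd.DataFrame.from_dict(outputs, orient='index')
print(df_alt['errs'].mean())  # 2.0
```

`orient='index'` avoids the object-dtype columns that the transpose produces, which makes later numeric summaries less surprising.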
526 "execution_count": null,
527 "id": "f5a364cb-01bf-49ad-a704-5aa3c9564967",
532 "end_time = time.time()\n",
534 "# Calculate Code Runtime\n",
535 "elapsed_time_ser = end_time - start_time\n",
536 "print(f\"Serial Training Elapsed time: {elapsed_time_ser:.4f} seconds\")"
540 "cell_type": "markdown",
541 "id": "86795281-f8ea-4141-81ea-c53fae830e80",
549 "execution_count": null,
550 "id": "508a6392-49bc-4471-ad8e-814f60119283",
554 "print(f\"Total Test Cases: {len(test_cases)}\")\n",
555 "print(f\"Total Test Hours: {rnn_dat_temp.y_test.shape[0]}\")"
560 "execution_count": null,
561 "id": "73e8ca05-d17b-4e72-8def-fa77664e7bb0",
565 "print(f\"Spatial Training RMSE: {errs.mean()}\")\n",
566 "print(f\"Serial Training RMSE: {df_rnn_serial.errs.mean()}\")\n",
567 "print(f\"Augmented KF RMSE: {df_kf.errs.mean()}\")"
572 "execution_count": null,
573 "id": "a73d22ee-707b-44a3-80ab-ad6e671731cf",
580 "execution_count": null,
581 "id": "272bfb32-e8e2-49dd-8f90-4b5b09c3a2a2",
585 "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")\n",
586 "print(f\"Serial Training Elapsed time: {elapsed_time_ser:.4f} seconds\")"
591 "execution_count": null,
592 "id": "38ab08fb-ac97-45be-8907-6f9cd124243b",
600 "display_name": "Python 3 (ipykernel)",
601 "language": "python",
609 "file_extension": ".py",
610 "mimetype": "text/x-python",
612 "nbconvert_exporter": "python",
613 "pygments_lexer": "ipython3",