Update moisture_rnn.py
[notebooks.git] / fmda / fmda_rnn_spatial.ipynb
blobcc740ff75b4d4b245bcc62de13ccdddc9fcc3abc
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.1 run RNN with Spatial Training\n",
9     "\n",
10     "This notebook is intended to set up a test where the RNN is trained serially, one location at a time, and compared to the spatial training scheme (only the spatial scheme is currently run below). Additionally, the ODE model with the augmented Kalman filter (KF) is run as a comparison. Note that the RNN models predict without any knowledge of the held-out test locations, while the augmented KF is run directly on the test locations.\n"
11    ]
12   },
13   {
14    "cell_type": "markdown",
15    "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
16    "metadata": {},
17    "source": [
18     "## Environment Setup"
19    ]
20   },
21   {
22    "cell_type": "code",
23    "execution_count": null,
24    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
25    "metadata": {},
26    "outputs": [],
27    "source": [
28     "# Standard library, then third-party packages, then project-local modules\n",
29     "import copy\n",
30     "import logging\n",
31     "import os.path as osp\n",
32     "import pickle\n",
33     "import time\n",
34     "import numpy as np\n",
35     "import pandas as pd\n",
36     "import matplotlib.pyplot as plt\n",
37     "import yaml\n",
38     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
39     "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict\n",
40     "from moisture_rnn_pkl import pkl2train\n",
41     "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
42     "import reproducibility\n",
43     "from data_funcs import rmse, to_json, combine_nested, build_train_dict\n",
44     "from moisture_models import run_augmented_kf"
45    ]
46   },
47   {
48    "cell_type": "code",
49    "execution_count": null,
50    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
51    "metadata": {},
52    "outputs": [],
53    "source": [
54     "logging_setup()"
55    ]
56   },
57   {
58    "cell_type": "code",
59    "execution_count": null,
60    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
61    "metadata": {},
62    "outputs": [],
63    "source": [
64     "filename = \"fmda_rocky_202403-05_f05.pkl\"\n",
65     "retrieve_url(\n",
66     "    url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n",
67     "    dest_path = f\"data/{filename}\")"
68    ]
69   },
70   {
71    "cell_type": "code",
72    "execution_count": null,
73    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
74    "metadata": {},
75    "outputs": [],
76    "source": [
77     "file_paths = [f\"data/{filename}\"]"
78    ]
79   },
80   {
81    "cell_type": "code",
82    "execution_count": null,
83    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
84    "metadata": {},
85    "outputs": [],
86    "source": [
87     "# read/write control\n",
88     "train_file='data/train.pkl'\n",
89     "train_create=True   # if false, read\n",
90     "train_write=False\n",
91     "train_read=False"
92    ]
93   },
94   {
95    "cell_type": "code",
96    "execution_count": null,
97    "id": "604388de-11ab-45c3-9f0d-80bdff0cca60",
98    "metadata": {},
99    "outputs": [],
100    "source": [
101     "# Params used for data filtering\n",
102     "params_data = read_yml(\"params_data.yaml\") \n",
103     "params_data"
104    ]
105   },
106   {
107    "cell_type": "code",
108    "execution_count": null,
109    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
110    "metadata": {},
111    "outputs": [],
112    "source": [
113     "# Params used for setting up RNN\n",
114     "params = read_yml(\"params.yaml\", subkey='rnn') \n",
115     "params"
116    ]
117   },
118   {
119    "cell_type": "code",
120    "execution_count": null,
121    "id": "be81d76c-3123-4467-982b-d2da5b1c08bd",
122    "metadata": {
123     "scrolled": true
124    },
125    "outputs": [],
126    "source": [
127     "# Build the training cases from the retrieved files unless disabled by the read/write control flags\n",
128     "if train_create:\n",
129     "    logging.info('creating the training cases from files %s', file_paths)\n",
130     "    train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=False, verbose=True)"
131    ]
132   },
133   {
134    "cell_type": "code",
135    "execution_count": null,
136    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
137    "metadata": {
138     "scrolled": true
139    },
140    "outputs": [],
141    "source": [
142     "# Optionally cache the training cases to disk, or replace them with a previously cached copy\n",
143     "if train_write:\n",
144     "    with open(train_file, 'wb') as file:\n",
145     "        logging.info('Writing the train cases into file %s', train_file)\n",
146     "        pickle.dump(train, file)\n",
147     "if train_read:\n",
148     "    logging.info('Reading the train cases from file %s', train_file)\n",
149     "    train = read_pkl(train_file)"
150    ]
151   },
152   {
153    "cell_type": "code",
154    "execution_count": null,
155    "id": "23cd60c0-9865-4314-9a96-948c3d400c08",
156    "metadata": {},
157    "outputs": [],
158    "source": [
159     "from itertools import islice\n",
160     "n_locs = 250  # keep only the first n_locs training cases to limit runtime; None keeps all of them\n",
161     "train = dict(islice(train.items(), n_locs))"
161    ]
162   },
163   {
164    "cell_type": "markdown",
165    "id": "efc10cdc-f18b-4781-84da-b8e2eef39981",
166    "metadata": {},
167    "source": [
168     "## Set Up Validation Runs"
169    ]
170   },
171   {
172    "cell_type": "markdown",
173    "id": "2d9cd5c5-87ed-41f9-b36c-e0c18d58c841",
174    "metadata": {},
175    "source": [
176     "The following parameters are intended to be shared by the serial and spatial models (only the spatial model is trained in this notebook)."
177    ]
178   },
179   {
180    "cell_type": "code",
181    "execution_count": null,
182    "id": "66f40c9f-c1c2-4b12-bf14-2ada8c26113d",
183    "metadata": {},
184    "outputs": [],
185    "source": [
186     "params = RNNParams(params)\n",
187     "params.update({'epochs': 200, \n",
188     "               'learning_rate': 0.001,\n",
189     "               'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
190     "               'rnn_layers': 2, 'recurrent_units': 30, \n",
191     "               'dense_layers': 2, 'dense_units': 30,\n",
192     "               'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
193     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
194     "               'bmin': 20, # Lower bound of hidden state batch reset, \n",
195     "               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
196     "               'features_list': ['hod', 'Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
197     "               'timesteps': 12,\n",
198     "               'space_fracs': [.8, .1, .1]\n",
199     "              })"
200    ]
201   },
202   {
203    "cell_type": "code",
204    "execution_count": null,
205    "id": "36823193-b93c-421e-b699-8c1ae5719309",
206    "metadata": {},
207    "outputs": [],
208    "source": [
209     "reproducibility.set_seed(123)"
210    ]
211   },
212   {
213    "cell_type": "markdown",
214    "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3",
215    "metadata": {},
216    "source": [
217     "## Spatial Data Training\n",
218     "\n",
219     "This method combines the training time series into a single 3-D array, with the time series from each location arranged in the correct order for a given `batch_size` hyperparameter. The hidden states of the recurrent layers are reset whenever the location changes."
220    ]
221   },
222   {
223    "cell_type": "code",
224    "execution_count": null,
225    "id": "3b5371a9-c1e8-4df5-b360-210746f7cd52",
226    "metadata": {},
227    "outputs": [],
228    "source": [
229     "# Start timer for code \n",
230     "start_time = time.time()"
231    ]
232   },
233   {
234    "cell_type": "code",
235    "execution_count": null,
236    "id": "faf93470-b55f-4770-9fa9-3288a2f13fcc",
237    "metadata": {},
238    "outputs": [],
239    "source": [
240     "# Combine Nested Dictionary into Spatial Data\n",
241     "train_sp = Dict(combine_nested(train))"
242    ]
243   },
244   {
245    "cell_type": "code",
246    "execution_count": null,
247    "id": "c0c7f5fb-4c33-45f8-9a2e-38c9ab1cd4e3",
248    "metadata": {},
249    "outputs": [],
250    "source": [
251     "# rnn_dat_sp = RNNData(\n",
252     "#     train_sp, # input dictionary\n",
253     "#     scaler=\"standard\",  # data scaling type\n",
254     "#     features_list = params['features_list'] # features for predicting outcome\n",
255     "# )\n",
256     "\n",
257     "\n",
258     "# rnn_dat_sp.train_test_split(   \n",
259     "#     time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
260     "#     space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
261     "# )\n",
262     "# rnn_dat_sp.scale_data()\n",
263     "\n",
264     "# rnn_dat_sp.batch_reshape(\n",
265     "#     timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
266     "#     batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
267     "# )"
268    ]
269   },
270   {
271    "cell_type": "code",
272    "execution_count": null,
273    "id": "af82c50e-bcc4-406d-b759-399119d1af81",
274    "metadata": {},
275    "outputs": [],
276    "source": [
277     "rnn_dat_sp = rnn_data_wrap(train, params)\n",
278     "params.update({\n",
279     "    'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
280     "})"
281    ]
282   },
283   {
284    "cell_type": "code",
285    "execution_count": null,
286    "id": "7431bc95-d384-40fd-a622-bbc0ee68e5cd",
287    "metadata": {},
288    "outputs": [],
289    "source": [
290     "# # Update Params specific to spatial training\n",
291     "# params.update({\n",
292     "#     'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
293     "# })"
294    ]
295   },
296   {
297    "cell_type": "code",
298    "execution_count": null,
299    "id": "4bc11474-fed8-47f2-b9cf-dfdda0d3d3b2",
300    "metadata": {},
301    "outputs": [],
302    "source": [
303     "rnn_sp = RNN(params)\n",
304     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
305    ]
306   },
307   {
308    "cell_type": "code",
309    "execution_count": null,
310    "id": "704ad662-d81a-488d-be3d-e90bf775a5b8",
311    "metadata": {},
312    "outputs": [],
313    "source": [
314     "errs.mean()"
315    ]
316   },
317   {
318    "cell_type": "code",
319    "execution_count": null,
320    "id": "d53571e3-b6cf-49aa-9848-e3c77053283d",
321    "metadata": {},
322    "outputs": [],
323    "source": [
324     "# End Timer\n",
325     "end_time = time.time()\n",
326     "\n",
327     "# Calculate Code Runtime\n",
328     "elapsed_time_sp = end_time - start_time\n",
329     "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")"
330    ]
331   },
332   {
333    "cell_type": "markdown",
334    "id": "7d8292a2-418c-48ed-aff7-ccbe98b046d3",
335    "metadata": {},
336    "source": [
337     "## Run ODE + KF "
338    ]
339   },
340   {
341    "cell_type": "code",
342    "execution_count": null,
343    "id": "8eaa136b-b496-4543-8970-dac46cb88df8",
344    "metadata": {},
345    "outputs": [],
346    "source": [
347     "import importlib\n",
348     "import moisture_models\n",
349     "importlib.reload(moisture_models)\n",
350     "from moisture_models import run_augmented_kf"
351    ]
352   },
353   {
354    "cell_type": "code",
355    "execution_count": null,
356    "id": "cca12d8c-c0e1-4df4-b2ca-20440485f2f3",
357    "metadata": {},
358    "outputs": [],
359    "source": [
360     "# Get timeseries IDs from previous RNNData object\n",
361     "test_cases = rnn_dat_sp.loc['test_locs']\n",
362     "print(len(test_cases))"
363    ]
364   },
365   {
366    "cell_type": "code",
367    "execution_count": null,
368    "id": "997f2534-7e77-45b3-93bf-d988837dfc0b",
369    "metadata": {},
370    "outputs": [],
371    "source": [
372     "test_ind = rnn_dat_sp.test_ind # Time index for test period start\n",
373     "print(test_ind)"
374    ]
375   },
376   {
377    "cell_type": "code",
378    "execution_count": null,
379    "id": "1e4ffc68-c775-41c6-ac42-f49c76824b43",
380    "metadata": {
381     "scrolled": true
382    },
383    "outputs": [],
384    "source": [
385     "outputs_kf = {}\n",
386     "for case in test_cases:\n",
387     "    print(\"~\"*50)\n",
388     "    print(case)\n",
389     "    # Run Augmented KF; h2 is set to the test-period start index (presumably where assimilation stops -- confirm in moisture_models)\n",
390     "    print('Running Augmented KF')\n",
391     "    train[case]['h2'] = test_ind\n",
392     "    train[case]['scale_fm'] = 1\n",
393     "    m, Ec = run_augmented_kf(train[case])\n",
394     "    y = train[case]['y']\n",
395     "    train[case]['m_kf'] = m\n",
396     "    kf_rmse = rmse(m[test_ind:], y[test_ind:])  # test-period RMSE, computed once and reused below\n",
397     "    print(f\"KF RMSE: {kf_rmse}\")\n",
398     "    outputs_kf[case] = {'case': case, 'errs': kf_rmse}"
398    ]
399   },
400   {
401    "cell_type": "code",
402    "execution_count": null,
403    "id": "57b19ec5-23f6-44ec-9f71-16d4d69aec68",
404    "metadata": {},
405    "outputs": [],
406    "source": [
407     "df_kf = pd.DataFrame.from_dict(outputs_kf).transpose()\n",
408     "df_kf.head()"
409    ]
410   },
411   {
412    "cell_type": "code",
413    "execution_count": null,
414    "id": "25a9d2fe-83f7-4ef3-a04b-14c970b6e2ba",
415    "metadata": {},
416    "outputs": [],
417    "source": [
418     "df_kf.errs.mean()"
419    ]
420   },
421   {
422    "cell_type": "markdown",
423    "id": "86795281-f8ea-4141-81ea-c53fae830e80",
424    "metadata": {},
425    "source": [
426     "## Compare"
427    ]
428   },
429   {
430    "cell_type": "code",
431    "execution_count": null,
432    "id": "508a6392-49bc-4471-ad8e-814f60119283",
433    "metadata": {},
434    "outputs": [],
435    "source": [
436     "print(f\"Total Test Cases: {len(test_cases)}\")\n",
437     "print(f\"Total Test Hours: {rnn_dat_sp.y_test.shape[0]}\")"
438    ]
439   },
440   {
441    "cell_type": "code",
442    "execution_count": null,
443    "id": "73e8ca05-d17b-4e72-8def-fa77664e7bb0",
444    "metadata": {},
445    "outputs": [],
446    "source": [
447     "print(f\"Spatial Training RMSE: {errs.mean()}\")\n",
448     "print(f\"Augmented KF RMSE: {df_kf.errs.mean()}\")"
449    ]
450   },
451   {
452    "cell_type": "code",
453    "execution_count": null,
454    "id": "272bfb32-e8e2-49dd-8f90-4b5b09c3a2a2",
455    "metadata": {},
456    "outputs": [],
457    "source": [
458     "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")"
459    ]
460   }
477  ],
478  "metadata": {
479   "kernelspec": {
480    "display_name": "Python 3 (ipykernel)",
481    "language": "python",
482    "name": "python3"
483   },
484   "language_info": {
485    "codemirror_mode": {
486     "name": "ipython",
487     "version": 3
488    },
489    "file_extension": ".py",
490    "mimetype": "text/x-python",
491    "name": "python",
492    "nbconvert_exporter": "python",
493    "pygments_lexer": "ipython3",
494    "version": "3.12.5"
495   }
496  },
497  "nbformat": 4,
498  "nbformat_minor": 5