1 {
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.1 run RNN with Spatial Training\n",
9     "\n",
10     "This notebook is intended to set up a test where the RNN is run serial by location and compared to the spatial training scheme. Additionally, the ODE model with the augmented KF will be run as a comparison, but note that the RNN models will be predicting entirely without knowledge of the heldout locations, while the augmented KF will be run directly on the test locations.\n"
11    ]
12   },
13   {
14    "cell_type": "markdown",
15    "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
16    "metadata": {},
17    "source": [
18     "## Environment Setup"
19    ]
20   },
21   {
22    "cell_type": "code",
23    "execution_count": null,
24    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
25    "metadata": {},
26    "outputs": [],
27    "source": [
28     "import numpy as np\n",
29     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
30     "import pickle\n",
31     "import logging\n",
32     "import os.path as osp\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from moisture_rnn import RNNParams, RNNData, RNN \n",
35     "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict\n",
36     "from moisture_rnn import RNN\n",
37     "import reproducibility\n",
38     "from data_funcs import rmse, to_json, combine_nested, process_train_dict\n",
39     "from moisture_models import run_augmented_kf\n",
40     "import copy\n",
41     "import pandas as pd\n",
42     "import matplotlib.pyplot as plt\n",
43     "import yaml\n",
44     "import time"
45    ]
46   },
47   {
48    "cell_type": "code",
49    "execution_count": null,
50    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
51    "metadata": {},
52    "outputs": [],
53    "source": [
54     "logging_setup()"
55    ]
56   },
57   {
58    "cell_type": "code",
59    "execution_count": null,
60    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
61    "metadata": {},
62    "outputs": [],
63    "source": [
64     "retrieve_url(\n",
65     "    url = \"https://demo.openwfm.org/web/data/fmda/dicts/fmda_nw_202401-05_f05.pkl\", \n",
66     "    dest_path = \"data/fmda_nw_202401-05_f05.pkl\")"
67    ]
68   },
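  {
   "cell_type": "markdown",
   "id": "md-retrieve-url-sketch",
   "metadata": {},
   "source": [
    "The `retrieve_url` call above is assumed to download the file only when `dest_path` does not already exist. Below is a minimal sketch of that pattern using only the standard library; the name `retrieve_url_sketch` and its behavior are assumptions, not the actual `utils.retrieve_url` implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-retrieve-url-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of a download-if-missing helper (assumption; not the actual utils.retrieve_url)\n",
    "import os.path as osp\n",
    "import urllib.request\n",
    "\n",
    "def retrieve_url_sketch(url, dest_path, force=False):\n",
    "    # Download url to dest_path only if the file is missing (or force=True)\n",
    "    if force or not osp.exists(dest_path):\n",
    "        urllib.request.urlretrieve(url, dest_path)\n",
    "    return dest_path"
   ]
  },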
69   {
70    "cell_type": "code",
71    "execution_count": null,
72    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
73    "metadata": {},
74    "outputs": [],
75    "source": [
76     "file_paths = ['data/fmda_nw_202401-05_f05.pkl']"
77    ]
78   },
79   {
80    "cell_type": "code",
81    "execution_count": null,
82    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
83    "metadata": {},
84    "outputs": [],
85    "source": [
86     "# read/write control\n",
87     "train_file='data/train.pkl'\n",
88     "train_create=True   # if false, read\n",
89     "train_write=False\n",
90     "train_read=False"
91    ]
92   },
93   {
94    "cell_type": "code",
95    "execution_count": null,
96    "id": "604388de-11ab-45c3-9f0d-80bdff0cca60",
97    "metadata": {},
98    "outputs": [],
99    "source": [
100     "# Params used for data filtering\n",
101     "params_data = read_yml(\"params_data.yaml\") \n",
102     "params_data"
103    ]
104   },
105   {
106    "cell_type": "code",
107    "execution_count": null,
108    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
109    "metadata": {},
110    "outputs": [],
111    "source": [
112     "# Params used for setting up RNN\n",
113     "params = read_yml(\"params.yaml\", subkey='rnn') \n",
114     "params"
115    ]
116   },
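  {
   "cell_type": "markdown",
   "id": "md-read-yml-sketch",
   "metadata": {},
   "source": [
    "`read_yml` with `subkey='rnn'` presumably loads `params.yaml` and returns only its `rnn` block. Below is a hedged sketch of that behavior using PyYAML directly; everything beyond the top-level `rnn` key is an assumption about the file layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-read-yml-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: load params.yaml with PyYAML and select the 'rnn' block\n",
    "# (mirrors what read_yml(..., subkey='rnn') is assumed to do)\n",
    "with open('params.yaml', 'r') as f:\n",
    "    all_params = yaml.safe_load(f)\n",
    "all_params['rnn']"
   ]
  },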
117   {
118    "cell_type": "code",
119    "execution_count": null,
120    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
121    "metadata": {
122     "scrolled": true
123    },
124    "outputs": [],
125    "source": [
126     "if train_create:\n",
127     "    logging.info('creating the training cases from files %s',file_paths)\n",
128     "    # osp.join works on windows too, joins paths using \\ or /\n",
129     "    train = process_train_dict(file_paths, atm_dict=\"HRRR\", params_data = params_data, verbose=True)\n",
130     "if train_write:\n",
131     "    with open(train_file, 'wb') as file:\n",
132     "        logging.info('Writing the rain cases into file %s',train_file)\n",
133     "        pickle.dump(train, file)\n",
134     "if train_read:\n",
135     "    logging.info('Reading the train cases from file %s',train_file)\n",
136     "    train = read_pkl(train_file)"
137    ]
138   },
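  {
   "cell_type": "markdown",
   "id": "md-inspect-train-case",
   "metadata": {},
   "source": [
    "Quick look at the structure of one training case. `print_dict_summary` is already imported from `utils`; this assumes it accepts a single case dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-inspect-train-case",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one case from the nested train dictionary\n",
    "case = next(iter(train))\n",
    "print(case)\n",
    "print_dict_summary(train[case])"
   ]
  },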
139   {
140    "cell_type": "code",
141    "execution_count": null,
142    "id": "23cd60c0-9865-4314-9a96-948c3d400c08",
143    "metadata": {},
144    "outputs": [],
145    "source": [
146     "from itertools import islice\n",
147     "train = {k: train[k] for k in islice(train, 250)}"
148    ]
149   },
150   {
151    "cell_type": "markdown",
152    "id": "efc10cdc-f18b-4781-84da-b8e2eef39981",
153    "metadata": {},
154    "source": [
155     "## Setup Validation Runs"
156    ]
157   },
158   {
159    "cell_type": "markdown",
160    "id": "2d9cd5c5-87ed-41f9-b36c-e0c18d58c841",
161    "metadata": {},
162    "source": [
163     "The following parameters will be used for both serial and spatial models."
164    ]
165   },
166   {
167    "cell_type": "code",
168    "execution_count": null,
169    "id": "66f40c9f-c1c2-4b12-bf14-2ada8c26113d",
170    "metadata": {},
171    "outputs": [],
172    "source": [
173     "params = RNNParams(params)\n",
174     "params.update({'epochs': 200, \n",
175     "               'learning_rate': 0.001,\n",
176     "               'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
177     "               'recurrent_layers': 2, 'recurrent_units': 30, \n",
178     "               'dense_layers': 2, 'dense_units': 30,\n",
179     "               'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
180     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
181     "               'bmin': 20, # Lower bound of hidden state batch reset, \n",
182     "               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
183     "               'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
184     "               'timesteps': 12\n",
185     "              })"
186    ]
187   },
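  {
   "cell_type": "markdown",
   "id": "md-batch-schedule-sketch",
   "metadata": {},
   "source": [
    "The `exp` batch schedule presumably grows the hidden-state reset length from `bmin` to `bmax` over the epochs. Below is a hedged sketch of one way such a schedule could look, purely to illustrate the two bounds; it is not the `moisture_rnn` implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-batch-schedule-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch of an exponential schedule from bmin to bmax over the epochs\n",
    "# (assumption about how batch_schedule_type='exp' might behave; not the library code)\n",
    "def exp_batch_schedule(bmin, bmax, epochs):\n",
    "    # Geometric progression from bmin to bmax, rounded to whole hours\n",
    "    frac = np.arange(epochs) / max(epochs - 1, 1)\n",
    "    return np.round(bmin * (bmax / bmin) ** frac).astype(int)\n",
    "\n",
    "sched = exp_batch_schedule(params['bmin'], params['bmax'], params['epochs'])\n",
    "sched[:5], sched[-5:]"
   ]
  },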
188   {
189    "cell_type": "code",
190    "execution_count": null,
191    "id": "36823193-b93c-421e-b699-8c1ae5719309",
192    "metadata": {},
193    "outputs": [],
194    "source": [
195     "reproducibility.set_seed(123)"
196    ]
197   },
198   {
199    "cell_type": "markdown",
200    "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3",
201    "metadata": {},
202    "source": [
203     "## Spatial Data Training\n",
204     "\n",
205     "This method combines the training timeseries data into a single 3-d array, with timeseries at the same location arranged appropriately in the right order for a given `batch_size` hyperparameter. The hidden state of the recurrent layers are set up reset when the location changes. "
206    ]
207   },
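  {
   "cell_type": "markdown",
   "id": "md-spatial-batch-sketch",
   "metadata": {},
   "source": [
    "Toy illustration of the layout described above (a sketch of the idea only; the real work is done by `combine_nested` and `RNNData.batch_reshape` below): each location's timeseries is cut into `timesteps`-long sequences, which are then stacked by location."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-spatial-batch-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch: 3 locations x 12 hours x 2 features cut into sequences of length 4\n",
    "# (illustration of the spatial batching idea, not the actual implementation)\n",
    "n_loc, hours, n_feat, timesteps_demo = 3, 12, 2, 4\n",
    "X_locs = [np.arange(hours * n_feat).reshape(hours, n_feat) + 100 * i for i in range(n_loc)]\n",
    "seqs = np.stack([x.reshape(hours // timesteps_demo, timesteps_demo, n_feat) for x in X_locs])\n",
    "seqs.shape  # (locations, sequences per location, timesteps, features)"
   ]
  },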
208   {
209    "cell_type": "code",
210    "execution_count": null,
211    "id": "3b5371a9-c1e8-4df5-b360-210746f7cd52",
212    "metadata": {},
213    "outputs": [],
214    "source": [
215     "# Start timer for code \n",
216     "start_time = time.time()"
217    ]
218   },
219   {
220    "cell_type": "code",
221    "execution_count": null,
222    "id": "faf93470-b55f-4770-9fa9-3288a2f13fcc",
223    "metadata": {},
224    "outputs": [],
225    "source": [
226     "# Combine Nested Dictionary into Spatial Data\n",
227     "train_sp = Dict(combine_nested(train))"
228    ]
229   },
230   {
231    "cell_type": "code",
232    "execution_count": null,
233    "id": "c0c7f5fb-4c33-45f8-9a2e-38c9ab1cd4e3",
234    "metadata": {},
235    "outputs": [],
236    "source": [
237     "rnn_dat_sp = RNNData(\n",
238     "    train_sp, # input dictionary\n",
239     "    scaler=\"standard\",  # data scaling type\n",
240     "    features_list = params['features_list'] # features for predicting outcome\n",
241     ")\n",
242     "\n",
243     "\n",
244     "rnn_dat_sp.train_test_split(   \n",
245     "    time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
246     "    space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
247     ")\n",
248     "rnn_dat_sp.scale_data()\n",
249     "\n",
250     "rnn_dat_sp.batch_reshape(\n",
251     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
252     "    batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
253     ")"
254    ]
255   },
256   {
257    "cell_type": "code",
258    "execution_count": null,
259    "id": "7431bc95-d384-40fd-a622-bbc0ee68e5cd",
260    "metadata": {},
261    "outputs": [],
262    "source": [
263     "# Update Params specific to spatial training\n",
264     "params.update({\n",
265     "    'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
266     "})"
267    ]
268   },
269   {
270    "cell_type": "code",
271    "execution_count": null,
272    "id": "4bc11474-fed8-47f2-b9cf-dfdda0d3d3b2",
273    "metadata": {},
274    "outputs": [],
275    "source": [
276     "rnn_sp = RNN(params)\n",
277     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
278    ]
279   },
280   {
281    "cell_type": "code",
282    "execution_count": null,
283    "id": "704ad662-d81a-488d-be3d-e90bf775a5b8",
284    "metadata": {},
285    "outputs": [],
286    "source": [
287     "errs.mean()"
288    ]
289   },
290   {
291    "cell_type": "code",
292    "execution_count": null,
293    "id": "d53571e3-b6cf-49aa-9848-e3c77053283d",
294    "metadata": {},
295    "outputs": [],
296    "source": [
297     "# End Timer\n",
298     "end_time = time.time()\n",
299     "\n",
300     "# Calculate Code Runtime\n",
301     "elapsed_time_sp = end_time - start_time\n",
302     "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")"
303    ]
304   },
305   {
306    "cell_type": "markdown",
307    "id": "7d8292a2-418c-48ed-aff7-ccbe98b046d3",
308    "metadata": {},
309    "source": [
310     "## Run ODE + KF "
311    ]
312   },
313   {
314    "cell_type": "code",
315    "execution_count": null,
316    "id": "cca12d8c-c0e1-4df4-b2ca-20440485f2f3",
317    "metadata": {},
318    "outputs": [],
319    "source": [
320     "# Get timeseries IDs from previous RNNData object\n",
321     "test_cases = rnn_dat_sp.loc['test_locs']\n",
322     "print(len(test_cases))"
323    ]
324   },
325   {
326    "cell_type": "code",
327    "execution_count": null,
328    "id": "997f2534-7e77-45b3-93bf-d988837dfc0b",
329    "metadata": {},
330    "outputs": [],
331    "source": [
332     "test_ind = rnn_dat_sp.test_ind # Time index for test period start\n",
333     "print(test_ind)"
334    ]
335   },
336   {
337    "cell_type": "code",
338    "execution_count": null,
339    "id": "1e4ffc68-c775-41c6-ac42-f49c76824b43",
340    "metadata": {
341     "scrolled": true
342    },
343    "outputs": [],
344    "source": [
345     "outputs_kf = {}\n",
346     "for case in test_cases:\n",
347     "    print(\"~\"*50)\n",
348     "    print(case)\n",
349     "    # Run Augmented KF\n",
350     "    print('Running Augmented KF')\n",
351     "    train[case]['h2'] = test_ind\n",
352     "    train[case]['scale_fm'] = 1\n",
353     "    m, Ec = run_augmented_kf(train[case])\n",
354     "    y = train[case]['y']        \n",
355     "    train[case]['m_kf'] = m\n",
356     "    print(f\"KF RMSE: {rmse(m[test_ind:],y[test_ind:])}\")\n",
357     "    outputs_kf[case] = {'case':case, 'errs': rmse(m[test_ind:],y[test_ind:])}"
358    ]
359   },
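  {
   "cell_type": "markdown",
   "id": "md-kf-plot",
   "metadata": {},
   "source": [
    "Optional visual check of the KF fit for one held-out case; `m_kf` and `y` were stored in the loop above, `test_ind` marks the start of the test period, and matplotlib is already imported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-kf-plot",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot observed fuel moisture vs. augmented KF estimate for the first test case\n",
    "case = test_cases[0]\n",
    "plt.figure(figsize=(10, 4))\n",
    "plt.plot(train[case]['y'], label='Observed')\n",
    "plt.plot(train[case]['m_kf'], label='Augmented KF')\n",
    "plt.axvline(test_ind, color='k', linestyle='--', label='Test period start')\n",
    "plt.title(case)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },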
360   {
361    "cell_type": "code",
362    "execution_count": null,
363    "id": "57b19ec5-23f6-44ec-9f71-16d4d69aec68",
364    "metadata": {},
365    "outputs": [],
366    "source": [
367     "df_kf = pd.DataFrame.from_dict(outputs_kf).transpose()\n",
368     "df_kf.head()"
369    ]
370   },
371   {
372    "cell_type": "code",
373    "execution_count": null,
374    "id": "25a9d2fe-83f7-4ef3-a04b-14c970b6e2ba",
375    "metadata": {},
376    "outputs": [],
377    "source": [
378     "df_kf.errs.mean()"
379    ]
380   },
381   {
382    "cell_type": "markdown",
383    "id": "f616bbf8-d89e-4c5b-9e47-59f02246b6f2",
384    "metadata": {},
385    "source": [
386     "## Serial Training\n",
387     "\n",
388     "This method initializes a RNN and uses successive `.fit` calls to train the model one location at a time. This is the naive approach to training a RNN on multiple timeseries, and is used as a baseline to see whether the spatial training scheme improves things."
389    ]
390   },
391   {
392    "cell_type": "code",
393    "execution_count": null,
394    "id": "6fa20e9f-604a-4938-ab68-b71fbb7326df",
395    "metadata": {},
396    "outputs": [],
397    "source": [
398     "# Start timer for code \n",
399     "start_time = time.time()"
400    ]
401   },
402   {
403    "cell_type": "code",
404    "execution_count": null,
405    "id": "f033e78c-a506-4508-a23c-8e6574014872",
406    "metadata": {},
407    "outputs": [],
408    "source": [
409     "# Update Params specific to Serial training\n",
410     "params.update({\n",
411     "    'loc_batch_reset': None, # Used to reset hidden state when location changes for a given batch\n",
412     "    'epochs': 1 # less epochs since fit will be run multiple times over locations\n",
413     "})"
414    ]
415   },
416   {
417    "cell_type": "code",
418    "execution_count": null,
419    "id": "ff1788ec-081b-403f-bcfa-b625f0e3dbe1",
420    "metadata": {},
421    "outputs": [],
422    "source": [
423     "train_cases = rnn_dat_sp.loc['train_locs']\n",
424     "test_cases = rnn_dat_sp.loc['test_locs']"
425    ]
426   },
427   {
428    "cell_type": "code",
429    "execution_count": null,
430    "id": "8a2af45e-e81b-421f-b940-e8779177dd5d",
431    "metadata": {},
432    "outputs": [],
433    "source": [
434     "# Initialize Model with first train case\n",
435     "rnn_dat = RNNData(train[train_cases[0]], params['scaler'], params['features_list'])\n",
436     "rnn_dat.train_test_split(\n",
437     "    time_fracs = [.8, .1, .1]\n",
438     ")\n",
439     "rnn_dat.scale_data()\n",
440     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
441    ]
442   },
443   {
444    "cell_type": "code",
445    "execution_count": null,
446    "id": "ac6fecc2-f614-4506-b5f9-05a6eca3b62e",
447    "metadata": {},
448    "outputs": [],
449    "source": [
450     "reproducibility.set_seed()\n",
451     "rnn = RNN(params)"
452    ]
453   },
454   {
455    "cell_type": "code",
456    "execution_count": null,
457    "id": "79b5af30-7d52-410c-9595-e89e9756fd38",
458    "metadata": {
459     "scrolled": true
460    },
461    "outputs": [],
462    "source": [
463     "# Train\n",
464     "for case in train_cases:\n",
465     "    print(\"~\"*50)\n",
466     "    print(f\"Training with Case {case}\")\n",
467     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
468     "    rnn_dat_temp.train_test_split(\n",
469     "        time_fracs = [.8, .1, .1]\n",
470     "    )\n",
471     "    rnn_dat_temp.scale_data()\n",
472     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
473     "    rnn.fit(rnn_dat_temp['X_train'], rnn_dat_temp['y_train'],\n",
474     "           validation_data=(rnn_dat_temp['X_val'], rnn_dat_temp['y_val']))    "
475    ]
476   },
477   {
478    "cell_type": "code",
479    "execution_count": null,
480    "id": "03d716b4-0ff5-4b80-a241-440543ba9b46",
481    "metadata": {
482     "scrolled": true
483    },
484    "outputs": [],
485    "source": [
486     "# Predict\n",
487     "outputs_rnn_serial = {}\n",
488     "test_ind = rnn_dat.test_ind\n",
489     "for i, case in enumerate(test_cases):\n",
490     "    print(\"~\"*50)\n",
491     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
492     "    rnn_dat_temp.train_test_split(\n",
493     "        time_fracs = [.8, .1, .1]\n",
494     "    )\n",
495     "    rnn_dat_temp.scale_data()\n",
496     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])    \n",
497     "    X_temp = rnn_dat_temp.scale_all_X()\n",
498     "    X_temp = X_temp.reshape(1, X_temp.shape[0], X_temp.shape[1])\n",
499     "    m = rnn.predict(X_temp).flatten()\n",
500     "    outputs_rnn_serial[case] = {'case':case, 'errs': rmse(m[test_ind:], rnn_dat.y_test)}"
501    ]
502   },
503   {
504    "cell_type": "code",
505    "execution_count": null,
506    "id": "e5a80bae-fe1a-4ec9-b9ac-31d540eaba40",
507    "metadata": {},
508    "outputs": [],
509    "source": [
510     "df_rnn_serial = pd.DataFrame.from_dict(outputs_rnn_serial).transpose()\n",
511     "df_rnn_serial.head()"
512    ]
513   },
514   {
515    "cell_type": "code",
516    "execution_count": null,
517    "id": "0c5b866e-c2bf-4bc1-8f6f-3ba8a9448d07",
518    "metadata": {},
519    "outputs": [],
520    "source": [
521     "df_rnn_serial.errs.mean()"
522    ]
523   },
524   {
525    "cell_type": "code",
526    "execution_count": null,
527    "id": "f5a364cb-01bf-49ad-a704-5aa3c9564967",
528    "metadata": {},
529    "outputs": [],
530    "source": [
531     "# End Timer\n",
532     "end_time = time.time()\n",
533     "\n",
534     "# Calculate Code Runtime\n",
535     "elapsed_time_ser = end_time - start_time\n",
536     "print(f\"Serial Training Elapsed time: {elapsed_time_ser:.4f} seconds\")"
537    ]
538   },
539   {
540    "cell_type": "markdown",
541    "id": "86795281-f8ea-4141-81ea-c53fae830e80",
542    "metadata": {},
543    "source": [
544     "## Compare"
545    ]
546   },
547   {
548    "cell_type": "code",
549    "execution_count": null,
550    "id": "508a6392-49bc-4471-ad8e-814f60119283",
551    "metadata": {},
552    "outputs": [],
553    "source": [
554     "print(f\"Total Test Cases: {len(test_cases)}\")\n",
555     "print(f\"Total Test Hours: {rnn_dat_temp.y_test.shape[0]}\")"
556    ]
557   },
558   {
559    "cell_type": "code",
560    "execution_count": null,
561    "id": "73e8ca05-d17b-4e72-8def-fa77664e7bb0",
562    "metadata": {},
563    "outputs": [],
564    "source": [
565     "print(f\"Spatial Training RMSE: {errs.mean()}\")\n",
566     "print(f\"Serial Training RMSE: {df_rnn_serial.errs.mean()}\")\n",
567     "print(f\"Augmented KF RMSE: {df_kf.errs.mean()}\")"
568    ]
569   },
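  {
   "cell_type": "markdown",
   "id": "md-rmse-summary",
   "metadata": {},
   "source": [
    "The same comparison collected into a small summary table, using only values computed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-rmse-summary",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary table of test RMSE by method\n",
    "pd.DataFrame({\n",
    "    'Method': ['Spatial RNN', 'Serial RNN', 'Augmented KF'],\n",
    "    'Test RMSE': [errs.mean(), df_rnn_serial.errs.mean(), df_kf.errs.mean()]\n",
    "})"
   ]
  },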
570   {
571    "cell_type": "code",
572    "execution_count": null,
573    "id": "a73d22ee-707b-44a3-80ab-ad6e671731cf",
574    "metadata": {},
575    "outputs": [],
576    "source": []
577   },
578   {
579    "cell_type": "code",
580    "execution_count": null,
581    "id": "272bfb32-e8e2-49dd-8f90-4b5b09c3a2a2",
582    "metadata": {},
583    "outputs": [],
584    "source": [
585     "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")\n",
586     "print(f\"Serial Training Elapsed time: {elapsed_time_ser:.4f} seconds\")"
587    ]
588   },
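  {
   "cell_type": "markdown",
   "id": "md-runtime-plot",
   "metadata": {},
   "source": [
    "A simple bar chart of the two training runtimes computed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-runtime-plot",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar chart of training runtimes\n",
    "plt.figure(figsize=(5, 3))\n",
    "plt.bar(['Spatial', 'Serial'], [elapsed_time_sp, elapsed_time_ser])\n",
    "plt.ylabel('Seconds')\n",
    "plt.title('Training runtime')\n",
    "plt.show()"
   ]
  },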
589   {
590    "cell_type": "code",
591    "execution_count": null,
592    "id": "38ab08fb-ac97-45be-8907-6f9cd124243b",
593    "metadata": {},
594    "outputs": [],
595    "source": []
596   }
597  ],
598  "metadata": {
599   "kernelspec": {
600    "display_name": "Python 3 (ipykernel)",
601    "language": "python",
602    "name": "python3"
603   },
604   "language_info": {
605    "codemirror_mode": {
606     "name": "ipython",
607     "version": 3
608    },
609    "file_extension": ".py",
610    "mimetype": "text/x-python",
611    "name": "python",
612    "nbconvert_exporter": "python",
613    "pygments_lexer": "ipython3",
614    "version": "3.12.5"
615   }
616  },
617  "nbformat": 4,
618  "nbformat_minor": 5