Update rnn_workshop.ipynb
[notebooks.git] / fmda / fmda_rnn_spatial.ipynb
blobd7c33e7a98159ae283b1a3e2451deed051b63c2d
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.1 run RNN with Spatial Training\n",
9     "\n",
10     "This notebook is intended to set up a test where the RNN is run serially by location and compared to the spatial training scheme. Additionally, the ODE model with the augmented KF will be run as a comparison, but note that the RNN models will be predicting entirely without knowledge of the held-out locations, while the augmented KF will be run directly on the test locations.\n"
11    ]
12   },
13   {
14    "cell_type": "markdown",
15    "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
16    "metadata": {},
17    "source": [
18     "## Environment Setup"
19    ]
20   },
21   {
22    "cell_type": "code",
23    "execution_count": null,
24    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
25    "metadata": {},
26    "outputs": [],
27    "source": [
28     "import numpy as np\n",
29     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
30     "import pickle\n",
31     "import logging\n",
32     "import os.path as osp\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from moisture_rnn import RNNParams, RNNData, RNN \n",
35     "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict\n",
36     "from moisture_rnn import RNN\n",
37     "import reproducibility\n",
38     "from data_funcs import rmse, to_json, combine_nested\n",
39     "from moisture_models import run_augmented_kf\n",
40     "import copy\n",
41     "import pandas as pd\n",
42     "import matplotlib.pyplot as plt\n",
43     "import yaml\n",
44     "import time"
45    ]
46   },
47   {
48    "cell_type": "code",
49    "execution_count": null,
50    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
51    "metadata": {},
52    "outputs": [],
53    "source": [
54     "# Set up logging (utils.logging_setup) before the logging.info calls further down\n",
55     "logging_setup()"
56    ]
57   },
58   {
59    "cell_type": "code",
60    "execution_count": null,
61    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
62    "metadata": {},
63    "outputs": [],
64    "source": [
65     "# Paths used below. They are defined before the download so the file lands where pkl2train reads it.\n",
66     "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n",
67     "file_names=['fmda_nw_202401-05_f05.pkl']\n",
68     "file_dir='data'\n",
69     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
70    ]
71   },
72   {
73    "cell_type": "code",
74    "execution_count": null,
75    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
76    "metadata": {},
77    "outputs": [],
78    "source": [
79     "# Download into file_paths[0] (under file_dir), the path pkl2train reads when train_create=True.\n",
80     "# NOTE(review): the URL is test_CA_202401.pkl but it is saved as fmda_nw_202401-05_f05.pkl -- confirm intended.\n",
81     "retrieve_url(\n",
82     "    url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\", \n",
83     "    dest_path = file_paths[0])"
84    ]
85   },
86   {
87    "cell_type": "code",
88    "execution_count": null,
89    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
90    "metadata": {},
91    "outputs": [],
92    "source": [
93     "# read/write control for the cached training cases\n",
94     "train_file='train.pkl'\n",
95     "train_create=False   # if True, build train from file_paths with pkl2train\n",
96     "train_write=False    # if True, pickle the newly created train into train_file (needs train_create)\n",
97     "train_read=True      # if True, load train from train_file"
98    ]
99   },
100   {
101    "cell_type": "code",
102    "execution_count": null,
103    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
104    "metadata": {
105     "scrolled": true
106    },
107    "outputs": [],
108    "source": [
109     "# NOTE(review): repro is not referenced by any later cell of this notebook.\n",
110     "repro = read_pkl(repro_file)\n",
111     "\n",
112     "# Fail fast on flag combinations that would leave train undefined or would pickle a stale\n",
113     "# train left over in the kernel (the KF cell below adds keys to train[case]).\n",
114     "if not (train_create or train_read):\n",
115     "    raise ValueError('Set train_create or train_read to True to obtain the training cases')\n",
116     "if train_write and not train_create:\n",
117     "    raise ValueError('train_write=True requires train_create=True')\n",
118     "\n",
119     "if train_create:\n",
120     "    logging.info('creating the training cases from files %s',file_paths)\n",
121     "    # NOTE(review): presumably unpickles the downloaded files -- only use trusted sources.\n",
122     "    train = pkl2train(file_paths)\n",
123     "if train_write:\n",
124     "    with open(train_file, 'wb') as file:\n",
125     "        logging.info('Writing the train cases into file %s',train_file)\n",
126     "        pickle.dump(train, file)\n",
127     "if train_read:\n",
128     "    logging.info('Reading the train cases from file %s',train_file)\n",
129     "    train = read_pkl(train_file)"
118    ]
119   },
120   {
121    "cell_type": "code",
122    "execution_count": null,
123    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
124    "metadata": {},
125    "outputs": [],
126    "source": [
127     "# RNN hyperparameters from the 'rnn' section of params.yaml, displayed before being wrapped in RNNParams\n",
128     "params = read_yml(\"params.yaml\", subkey='rnn')\n",
129     "params"
130    ]
131   },
132   {
133    "cell_type": "code",
134    "execution_count": null,
135    "id": "78cf4dbc-4e7d-4c6d-ac2e-0bac513f92dd",
136    "metadata": {},
137    "outputs": [],
138    "source": [
139     "# Set n_locs to an int (e.g. 100) to use only the first n_locs locations for a quick run; None uses all.\n",
140     "# The subset gets its own name so train itself stays complete for the KF comparison below.\n",
141     "n_locs = None\n",
142     "train_subset = train if n_locs is None else {k: train[k] for k in list(train)[:n_locs]}\n",
143     "dat = Dict(combine_nested(train_subset))"
144    ]
145   },
146   {
147    "cell_type": "code",
148    "execution_count": null,
149    "id": "e11e7c83-183f-48ba-abd8-a6aedff66090",
150    "metadata": {},
151    "outputs": [],
152    "source": [
153     "# Set up output dictionaries. Only outputs_kf is filled in this notebook (keyed by test location,\n",
154     "# by the KF cell); outputs_rnn_serial / outputs_rnn_spatial are not populated yet.\n",
155     "outputs_kf = {}\n",
156     "outputs_rnn_serial = {}\n",
157     "outputs_rnn_spatial = {}"
154    ]
155   },
156   {
157    "cell_type": "markdown",
158    "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3",
159    "metadata": {},
160    "source": [
161     "## Spatial Data Training"
162    ]
163   },
164   {
165    "cell_type": "code",
166    "execution_count": null,
167    "id": "c58f9f89-46d8-407c-be8b-8e5f16dbcc51",
168    "metadata": {},
169    "outputs": [],
170    "source": [
171     "# Wrap the raw params in the project's RNNParams container (rebinds params).\n",
172     "params = RNNParams(params)"
173    ]
174   },
175   {
176    "cell_type": "code",
177    "execution_count": null,
178    "id": "3b5371a9-c1e8-4df5-b360-210746f7cd52",
179    "metadata": {},
180    "outputs": [],
181    "source": [
182     "# Start timer. perf_counter is monotonic, so the interval is unaffected by system clock adjustments.\n",
183     "start_time = time.perf_counter()"
184    ]
185   },
186   {
187    "cell_type": "code",
188    "execution_count": null,
189    "id": "c0c7f5fb-4c33-45f8-9a2e-38c9ab1cd4e3",
190    "metadata": {},
191    "outputs": [],
192    "source": [
193     "# Build the RNN dataset from the combined locations, with standard scaling of the listed features.\n",
194     "rnn_dat = RNNData(dat, scaler=\"standard\", \n",
195     "                  features_list = ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat',\n",
196     "                                  'solar', 'wind'])\n",
197     "\n",
198     "# Split both in time and across locations. NOTE(review): fractions are presumably [train, val, test];\n",
199     "# the held-out test locations are read back later from rnn_dat.loc['test_locs'].\n",
200     "rnn_dat.train_test_split(\n",
201     "    time_fracs = [.9, .05, .05],\n",
202     "    space_fracs = [.6, .2, .2]\n",
203     ")\n",
204     "# NOTE(review): scaling presumably uses the training split only, so keep it after train_test_split.\n",
205     "rnn_dat.scale_data()\n",
206     "\n",
207     "rnn_dat.batch_reshape(\n",
208     "    timesteps = params['timesteps'], \n",
209     "    batch_size = params['batch_size']\n",
210     ")"
211    ]
212   },
213   {
214    "cell_type": "code",
215    "execution_count": null,
216    "id": "59ddf393-2024-4093-927f-69f135a165b8",
217    "metadata": {},
218    "outputs": [],
219    "source": [
220     "# Spatial-training hyperparameters. bmax and loc_batch_reset come from the prepared data\n",
221     "# (rnn_dat.hours, rnn_dat.n_seqs), so this cell must run after rnn_dat is built.\n",
222     "params.update({'batch_schedule_type': 'exp', 'bmin': 20, 'bmax': rnn_dat.hours,\n",
223     "               'loc_batch_reset': rnn_dat.n_seqs, \n",
224     "               'epochs': 100, 'learning_rate': 0.0001,\n",
225     "               'recurrent_layers': 2, 'recurrent_units': 40, 'dense_layers': 2, 'dense_units': 20,\n",
226     "               'features_list': rnn_dat.features_list})"
227    ]
228   },
229   {
230    "cell_type": "code",
231    "execution_count": null,
232    "id": "4bc11474-fed8-47f2-b9cf-dfdda0d3d3b2",
233    "metadata": {},
234    "outputs": [],
235    "source": [
236     "# Seed via reproducibility.set_seed before building the model so repeated runs are comparable.\n",
237     "reproducibility.set_seed(123)\n",
238     "rnn = RNN(params)\n",
239     "# NOTE(review): m is not read by any later cell; errs is summarised below.\n",
240     "m, errs = rnn.run_model(rnn_dat)"
241    ]
242   },
243   {
244    "cell_type": "code",
245    "execution_count": null,
246    "id": "704ad662-d81a-488d-be3d-e90bf775a5b8",
247    "metadata": {},
248    "outputs": [],
249    "source": [
250     "# Mean of the RNN errors from run_model (presumably one value per test location -- confirm)\n",
251     "errs.mean()"
252    ]
253   },
254   {
255    "cell_type": "code",
256    "execution_count": null,
257    "id": "d53571e3-b6cf-49aa-9848-e3c77053283d",
258    "metadata": {},
259    "outputs": [],
260    "source": [
261     "# End timer. Elapsed real time since the start-timer cell ran (includes any idle time\n",
262     "# between cells when they are executed one at a time).\n",
263     "end_time = time.perf_counter()\n",
264     "\n",
265     "# Calculate Code Runtime\n",
266     "elapsed_time = end_time - start_time\n",
267     "print(f\"Spatial Training Elapsed time: {elapsed_time:.4f} seconds\")"
257    ]
258   },
259   {
260    "cell_type": "markdown",
261    "id": "7d8292a2-418c-48ed-aff7-ccbe98b046d3",
262    "metadata": {},
263    "source": [
264     "## Run ODE + KF and Compare"
265    ]
266   },
267   {
268    "cell_type": "code",
269    "execution_count": null,
270    "id": "1e4ffc68-c775-41c6-ac42-f49c76824b43",
271    "metadata": {
272     "scrolled": true
273    },
274    "outputs": [],
275    "source": [
276     "# Run the ODE + augmented KF directly on each held-out test location (the RNN never saw these).\n",
277     "# outputs_kf is reset so re-running this cell does not keep stale entries.\n",
278     "outputs_kf = {}\n",
279     "for case in rnn_dat.loc['test_locs']:\n",
280     "    print(\"~\"*50)\n",
281     "    print(case)\n",
282     "    # Run Augmented KF\n",
283     "    print('Running Augmented KF')\n",
284     "    # NOTE(review): the KF inputs below (and the result 'm') are stored into train[case] in place.\n",
285     "    train[case]['h2'] = train[case]['hours'] // 2\n",
286     "    train[case]['scale_fm'] = 1\n",
287     "    # m_kf (not m) keeps the m returned by rnn.run_model intact; Ec is not used below.\n",
288     "    m_kf, Ec = run_augmented_kf(train[case])\n",
289     "    y = train[case]['y']\n",
290     "    train[case]['m'] = m_kf\n",
291     "    kf_rmse = rmse(m_kf, y)\n",
292     "    print(f\"KF RMSE: {kf_rmse}\")\n",
293     "    outputs_kf[case] = {'case':case, 'errs': kf_rmse}"
294    ]
295   },
296   {
297    "cell_type": "code",
298    "execution_count": null,
299    "id": "57b19ec5-23f6-44ec-9f71-16d4d69aec68",
300    "metadata": {},
301    "outputs": [],
302    "source": [
303     "# One row per test location; orient='index' keeps per-column dtypes (errs stays numeric),\n",
304     "# whereas building column-wise and transposing yields object-dtype columns.\n",
305     "df2 = pd.DataFrame.from_dict(outputs_kf, orient='index')\n",
306     "df2.head()"
300    ]
301   },
302   {
303    "cell_type": "markdown",
304    "id": "86795281-f8ea-4141-81ea-c53fae830e80",
305    "metadata": {},
306    "source": [
307     "## Compare"
308    ]
309   },
310   {
311    "cell_type": "code",
312    "execution_count": null,
313    "id": "508a6392-49bc-4471-ad8e-814f60119283",
314    "metadata": {},
315    "outputs": [],
316    "source": [
317     "# Mean KF RMSE over the held-out test locations\n",
318     "df2.errs.mean()"
319    ]
320   },
321   {
322    "cell_type": "code",
323    "execution_count": null,
324    "id": "73e8ca05-d17b-4e72-8def-fa77664e7bb0",
325    "metadata": {},
326    "outputs": [],
327    "source": [
328     "# Number of test locations evaluated by the KF (rows)\n",
329     "df2.shape"
330    ]
331   },
332   {
333    "cell_type": "code",
334    "execution_count": null,
335    "id": "104ea555-1a88-4293-b2a6-dd870fb4b1ed",
336    "metadata": {},
337    "outputs": [],
338    "source": [
339     "# Shape of the RNN errors, to check they cover the same test locations as df2\n",
340     "errs.shape"
341    ]
342   },
343   {
344    "cell_type": "code",
345    "execution_count": null,
346    "id": "dc1d5cd6-2321-43b2-ab88-7f44806dc73f",
347    "metadata": {},
348    "outputs": [],
349    "source": [
350     "# Side-by-side mean error of both approaches on the held-out test locations.\n",
351     "# NOTE(review): KF errs are per-location RMSE; confirm run_model's errs use the same metric.\n",
352     "pd.Series({'ODE + augmented KF': df2.errs.mean(), 'RNN (spatial)': errs.mean()}, name='mean error')"
348    ]
349   },
350   {
351    "cell_type": "code",
352    "execution_count": null,
353    "id": "a73d22ee-707b-44a3-80ab-ad6e671731cf",
354    "metadata": {},
355    "outputs": [],
356    "source": []
357   },
358   {
359    "cell_type": "code",
360    "execution_count": null,
361    "id": "272bfb32-e8e2-49dd-8f90-4b5b09c3a2a2",
362    "metadata": {},
363    "outputs": [],
364    "source": []
365   }
366  ],
367  "metadata": {
368   "kernelspec": {
369    "display_name": "Python 3 (ipykernel)",
370    "language": "python",
371    "name": "python3"
372   },
373   "language_info": {
374    "codemirror_mode": {
375     "name": "ipython",
376     "version": 3
377    },
378    "file_extension": ".py",
379    "mimetype": "text/x-python",
380    "name": "python",
381    "nbconvert_exporter": "python",
382    "pygments_lexer": "ipython3",
383    "version": "3.12.5"
384   }
385  },
386  "nbformat": 4,
387  "nbformat_minor": 5