Update fmda_rnn_test_batch_reset_schedule.ipynb
[notebooks.git] / fmda / test_notebooks / fmda_rnn_test_batch_reset_schedule.ipynb
blob582f1bd800debf37e1322b4534e92e3b8c927d3b
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# Compare Batch Resetting Schedules\n"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "import numpy as np\n",
19     "import pickle\n",
20     "import logging\n",
21     "import os.path as osp\n",
22     "import sys\n",
23     "sys.path.append('..')\n",
24     "from moisture_rnn_pkl import pkl2train\n",
25     "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
26     "from utils import hash2, read_yml, read_pkl, retrieve_url, print_dict_summary, print_first, str2time, logging_setup\n",
27     "from moisture_rnn import RNN\n",
28     "import reproducibility\n",
29     "from data_funcs import rmse, to_json, combine_nested, build_train_dict\n",
30     "from moisture_models import run_augmented_kf\n",
31     "import copy\n",
32     "import pandas as pd\n",
33     "import matplotlib.pyplot as plt\n",
34     "import yaml\n",
35     "import time\n",
36     "import reproducibility\n",
37     "import tensorflow as tf"
38    ]
39   },
40   {
41    "cell_type": "code",
42    "execution_count": null,
43    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
44    "metadata": {},
45    "outputs": [],
46    "source": [
47     "logging_setup()"
48    ]
49   },
50   {
51    "cell_type": "code",
52    "execution_count": null,
53    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
54    "metadata": {},
55    "outputs": [],
56    "source": [
57     "filename = \"fmda_rocky_202311-202402_f05.pkl\"\n",
58     "retrieve_url(\n",
59     "    url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n",
60     "    dest_path = f\"../data/{filename}\")"
61    ]
62   },
63   {
64    "cell_type": "code",
65    "execution_count": null,
66    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
67    "metadata": {},
68    "outputs": [],
69    "source": [
70     "file_names = [filename]  # data files to train on (only the one retrieved above)\n",
71     "file_dir = '../data'  # same directory used as the download destination\n",
72     "file_paths = [osp.join(file_dir, name) for name in file_names]"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "params = RNNParams(read_yml(\"../params.yaml\", subkey='rnn'))\n",
83     "params_data = read_yml(\"../params_data.yaml\")"
84    ]
85   },
86   {
87    "cell_type": "code",
88    "execution_count": null,
89    "id": "5f3d09b4-c2fd-4556-90b7-e547431ca523",
90    "metadata": {
91     "scrolled": true
92    },
93    "outputs": [],
94    "source": [
95     "params_data.update({\n",
96     "    'hours': 720,\n",
97     "    'max_intp_time': 12,\n",
98     "    'zero_lag_threshold': 12\n",
99     "})\n",
100     "# train = process_train_dict([\"data/fmda_nw_202401-05_f05.pkl\"], params_data=params_data, verbose=True)\n",
101     "train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=True, verbose=True)"
102    ]
103   },
104   {
105    "cell_type": "code",
106    "execution_count": null,
107    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
108    "metadata": {},
109    "outputs": [],
110    "source": [
111     "# from itertools import islice\n",
112     "# train = {k: train[k] for k in islice(train, 250)}"
113    ]
114   },
115   {
116    "cell_type": "code",
117    "execution_count": null,
118    "id": "35ae0bdb-a209-429f-8116-c5e1dccafb89",
119    "metadata": {},
120    "outputs": [],
121    "source": [
122     "## params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
123     "# params.update({'epochs': 200, \n",
124     "#                'learning_rate': 0.001,\n",
125     "#                'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
126     "#                'recurrent_layers': 1, 'recurrent_units': 30, \n",
127     "#                'dense_layers': 1, 'dense_units': 30,\n",
128     "#                'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
129     "#                'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
130     "#                'bmin': 20, # Lower bound of hidden state batch reset, \n",
131     "#                'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
132     "#                'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
133     "#                'timesteps': 12,\n",
134     "#                'batch_size': 50\n",
135     "#               })"
136    ]
137   },
138   {
139    "cell_type": "markdown",
140    "id": "d6751dcc-ba4c-47d5-90d2-60f4a61e96fa",
141    "metadata": {},
142    "source": [
143     "## Handle Data"
144    ]
145   },
146   {
147    "cell_type": "code",
148    "execution_count": null,
149    "id": "adbba43e-603b-4801-8a35-35b8ccc053af",
150    "metadata": {},
151    "outputs": [],
152    "source": [
153     "rnn_dat = rnn_data_wrap(train, params)\n",
154     "\n",
155     "params.update(dict(\n",
156     "    loc_batch_reset=rnn_dat.n_seqs,  # resets the hidden state when the location changes within a batch\n",
157     "    bmax=params_data['hours'],\n",
158     "    early_stopping_patience=25,\n",
159     "))"
160    ]
161   },
162   {
163    "cell_type": "markdown",
164    "id": "703dca05-5371-409e-b0a6-c430594bb76f",
165    "metadata": {},
166    "source": [
167     "## Non-Stateful"
168    ]
169   },
170   {
171    "cell_type": "code",
172    "execution_count": null,
173    "id": "4eaaf547-1967-4325-be61-b5e8ed33141f",
174    "metadata": {},
175    "outputs": [],
176    "source": [
177     "params.update({\n",
178     "    'stateful': False,\n",
179     "    'batch_schedule_type': None\n",
180     "})"
181    ]
182   },
183   {
184    "cell_type": "code",
185    "execution_count": null,
186    "id": "bfebe87d-1bbb-48c5-9b32-836d19d16787",
187    "metadata": {},
188    "outputs": [],
189    "source": [
190     "reproducibility.set_seed(123)\n",
191     "rnn = RNN(params)\n",
192     "m0, errs0, epochs0 = rnn.run_model(rnn_dat, return_epochs=True)"
193    ]
194   },
195   {
196    "cell_type": "code",
197    "execution_count": null,
198    "id": "25581ed5-7fff-402f-b902-ed32bbcf1c0c",
199    "metadata": {},
200    "outputs": [],
201    "source": [
202     "errs0.mean()"
203    ]
204   },
205   {
206    "cell_type": "code",
207    "execution_count": null,
208    "id": "921bf523-d39d-40f7-8778-08f73f5c002d",
209    "metadata": {},
210    "outputs": [],
211    "source": [
212     "rnn.is_stateful()"
213    ]
214   },
215   {
216    "cell_type": "markdown",
217    "id": "13f04499-3048-430e-88b0-6010de1a00d5",
218    "metadata": {},
219    "source": [
220     "## No Batch Schedule (Stateful)"
221    ]
222   },
223   {
224    "cell_type": "code",
225    "execution_count": null,
226    "id": "f95c9a51-1203-4ad6-a75c-dd414821db40",
227    "metadata": {},
228    "outputs": [],
229    "source": [
230     "params.update({\n",
231     "    'stateful': True, \n",
232     "    'batch_schedule_type':None\n",
233     "})"
234    ]
235   },
236   {
237    "cell_type": "code",
238    "execution_count": null,
239    "id": "7f768878-7a3b-4cd2-a174-fce6d3039f54",
240    "metadata": {},
241    "outputs": [],
242    "source": [
243     "reproducibility.set_seed(123)\n",
244     "rnn = RNN(params)\n",
245     "m1, errs1, epochs1 = rnn.run_model(rnn_dat, return_epochs=True)"
246    ]
247   },
248   {
249    "cell_type": "markdown",
250    "id": "80026da4-e8ec-4803-b791-66110c4b10d9",
251    "metadata": {},
252    "source": [
253     "## Constant Batch Schedule (Stateful)"
254    ]
255   },
256   {
257    "cell_type": "code",
258    "execution_count": null,
259    "id": "e7335f3d-3d8a-4733-9105-faed311a7df3",
260    "metadata": {},
261    "outputs": [],
262    "source": [
263     "params.update({\n",
264     "    'stateful': True, \n",
265     "    'batch_schedule_type':'constant', \n",
266     "    'bmin': 50})"
267    ]
268   },
269   {
270    "cell_type": "code",
271    "execution_count": null,
272    "id": "e1febc02-35af-4325-ad02-4f6a2ce065fd",
273    "metadata": {},
274    "outputs": [],
275    "source": [
276     "reproducibility.set_seed(123)\n",
277     "rnn = RNN(params)\n",
278     "m2, errs2, epochs2 = rnn.run_model(rnn_dat, return_epochs=True)"
279    ]
280   },
281   {
282    "cell_type": "code",
283    "execution_count": null,
284    "id": "4ad5e3cc-a4a4-4a54-9bc3-9b8b88b454ee",
285    "metadata": {},
286    "outputs": [],
287    "source": [
288     "errs2.mean()"
289    ]
290   },
291   {
292    "cell_type": "markdown",
293    "id": "7082eeb3-5173-4226-85db-4bb5a26f67a4",
294    "metadata": {},
295    "source": [
296     "## Exp Batch Schedule (Stateful)"
297    ]
298   },
299   {
300    "cell_type": "code",
301    "execution_count": null,
302    "id": "2a428391-400d-476b-990e-c5e1d9cba7f8",
303    "metadata": {},
304    "outputs": [],
305    "source": [
306     "params.update({\n",
307     "    'stateful': True, \n",
308     "    'batch_schedule_type':'exp', \n",
309     "    'bmin': 20,\n",
310     "    'bmax': rnn_dat.hours\n",
311     "})"
312    ]
313   },
314   {
315    "cell_type": "code",
316    "execution_count": null,
317    "id": "9f77ee03-01f8-415f-9748-b986a77f2982",
318    "metadata": {},
319    "outputs": [],
320    "source": [
321     "reproducibility.set_seed(123)\n",
322     "rnn = RNN(params)\n",
323     "m3, errs3, epochs3 = rnn.run_model(rnn_dat, return_epochs=True)"
324    ]
325   },
326   {
327    "cell_type": "code",
328    "execution_count": null,
329    "id": "d2b1eaf5-b710-431a-a10e-0098f713c325",
330    "metadata": {},
331    "outputs": [],
332    "source": [
333     "errs3.mean()"
334    ]
335   },
336   {
337    "cell_type": "markdown",
338    "id": "6655ec7d-45ec-45d4-9486-de8d15e9e380",
339    "metadata": {},
340    "source": [
341     "## Results"
342    ]
343   },
344   {
345    "cell_type": "code",
346    "execution_count": null,
347    "id": "ec25686e-e41d-4478-8a51-f657547fb3b3",
348    "metadata": {},
349    "outputs": [],
350    "source": [
351     "pd.DataFrame(\n",
352     "    {\n",
353     "        \"Method\": [\"Non-Stateful\", \"Stateful - No Reset\", \"Stateful - Const Reset\", \"Stateful - Exp Reset\"],\n",
354     "        \"RMSE\": [errs.mean() for errs in (errs0, errs1, errs2, errs3)],  # one scalar per run: mean of the errors returned by run_model\n",
355     "        \"N_Epochs\": [epochs0, epochs1, epochs2, epochs3]\n",
356     "    }\n",
357     ")"
358    ]
359   },
360   {
361    "cell_type": "code",
362    "execution_count": null,
363    "id": "7dd6d484-b104-42a4-90cf-a3503cb29ca2",
364    "metadata": {},
365    "outputs": [],
366    "source": []
367   },
368   {
369    "cell_type": "code",
370    "execution_count": null,
371    "id": "2efafcb5-7ddf-4558-8d4e-972d080efd34",
372    "metadata": {},
373    "outputs": [],
374    "source": []
375   }
376  ],
377  "metadata": {
378   "kernelspec": {
379    "display_name": "Python 3 (ipykernel)",
380    "language": "python",
381    "name": "python3"
382   },
383   "language_info": {
384    "codemirror_mode": {
385     "name": "ipython",
386     "version": 3
387    },
388    "file_extension": ".py",
389    "mimetype": "text/x-python",
390    "name": "python",
391    "nbconvert_exporter": "python",
392    "pygments_lexer": "ipython3",
393    "version": "3.12.5"
394   }
395  },
396  "nbformat": 4,
397  "nbformat_minor": 5