Cleanups
[notebooks.git] / fmda / test_notebooks / fmda_rnn_test_batch_reset_schedule.ipynb
blob3bcf7b5e4eb46a330b1ca31139d052189cf2dc8e
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
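8     "# Compare Hidden-State Batch Reset Schedules\n\nTrain the same base RNN configuration on `fmda_rocky_202403-05_f05.pkl` three times, calling `reproducibility.set_seed(123)` before each run:\n\n1. **Non-stateful** baseline: `stateful=False`, no batch reset schedule.\n2. **Stateful, constant** batch reset schedule: `bmin=20`.\n3. **Stateful, exponential** batch reset schedule: from `bmin=20` up to `bmax=rnn_dat.hours`.\n\nEach section ends by displaying the mean of the errors returned by `RNN.run_model`.\n"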
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "import numpy as np\n",
19     "import pickle\n",
20     "import logging\n",
21     "import os.path as osp\n",
22     "import sys\n",
23     "sys.path.append('..')\n",
24     "from moisture_rnn_pkl import pkl2train\n",
25     "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
26     "from utils import hash2, read_yml, read_pkl, retrieve_url, print_dict_summary, print_first, str2time, logging_setup\n",
27     "from moisture_rnn import RNN\n",
28     "import reproducibility\n",
29     "from data_funcs import rmse, to_json, combine_nested, build_train_dict\n",
30     "from moisture_models import run_augmented_kf\n",
31     "import copy\n",
32     "import pandas as pd\n",
33     "import matplotlib.pyplot as plt\n",
34     "import yaml\n",
35     "import time\n",
36     "import reproducibility\n",
37     "import tensorflow as tf"
38    ]
39   },
40   {
41    "cell_type": "code",
42    "execution_count": null,
43    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
44    "metadata": {},
45    "outputs": [],
46    "source": [
47     "logging_setup()"
48    ]
49   },
50   {
51    "cell_type": "code",
52    "execution_count": null,
53    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
54    "metadata": {},
55    "outputs": [],
56    "source": [
57     "file_dir = '../data'  # local data directory shared by the download and the loader below\n",
58     "filename = \"fmda_rocky_202403-05_f05.pkl\"\n",
59     "retrieve_url(\n",
60     "    url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n",
61     "    dest_path = osp.join(file_dir, filename))"
62    ]
63   },
64   {
65    "cell_type": "code",
66    "execution_count": null,
67    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
68    "metadata": {},
69    "outputs": [],
70    "source": [
71     "file_names = [filename]\n",
72     "file_paths = [osp.join(file_dir, file_name) for file_name in file_names]"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "params = RNNParams(read_yml(\"../params.yaml\", subkey='rnn'))\n",
83     "params_data = read_yml(\"../params_data.yaml\")"
84    ]
85   },
86   {
87    "cell_type": "code",
88    "execution_count": null,
89    "id": "5f3d09b4-c2fd-4556-90b7-e547431ca523",
90    "metadata": {
91     "scrolled": true
92    },
93    "outputs": [],
94    "source": [
95     "params_data.update({\n",
96     "    'hours': 2205,\n",
97     "    'max_intp_time': 12,\n",
98     "    'zero_lag_threshold': 12\n",
99     "})\n",
100     "# Build the training dictionary from the downloaded file(s), using HRRR as the atmospheric source\n",
101     "train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=True, verbose=True)"
102    ]
103   },
104   {
105    "cell_type": "code",
106    "execution_count": null,
107    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
108    "metadata": {},
109    "outputs": [],
110    "source": [
111     "N_LOCS = None  # e.g. 250 to keep only the first N locations for a quick run; None uses all\n",
112     "train = dict(list(train.items())[:N_LOCS]) if N_LOCS else train"
113    ]
114   },
115   {
116    "cell_type": "code",
117    "execution_count": null,
118    "id": "35ae0bdb-a209-429f-8116-c5e1dccafb89",
119    "metadata": {},
120    "outputs": [],
121    "source": [
122     "# Base RNN settings shared by all runs; each section below overrides 'stateful', 'batch_schedule_type' and the 'bmin'/'bmax' bounds it needs\n",
123     "params.update({'epochs': 200, \n",
124     "               'learning_rate': 0.001,\n",
125     "               'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
126     "               'recurrent_layers': 1, 'recurrent_units': 30, \n",
127     "               'dense_layers': 1, 'dense_units': 30,\n",
128     "               'early_stopping_patience': 30, # how many epochs without validation improvement to wait before stopping\n",
129     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
130     "               'bmin': 20, # Lower bound of hidden state batch reset\n",
131     "               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
132     "               'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
133     "               'timesteps': 12,\n",
134     "               'batch_size': 50\n",
135     "              })"
136    ]
137   },
138   {
139    "cell_type": "markdown",
140    "id": "d6751dcc-ba4c-47d5-90d2-60f4a61e96fa",
141    "metadata": {},
142    "source": [
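143     "## Prepare Data\n\nWrap the training dictionary for the RNN and store the number of sequences (`rnn_dat.n_seqs`) as `loc_batch_reset`, which is used to reset the hidden state when the location changes within a batch."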
144    ]
145   },
146   {
147    "cell_type": "code",
148    "execution_count": null,
149    "id": "adbba43e-603b-4801-8a35-35b8ccc053af",
150    "metadata": {},
151    "outputs": [],
152    "source": [
153     "rnn_dat = rnn_data_wrap(train, params)  # data object passed to every rnn.run_model call in the sections below\n",
154     "params.update({\n",
155     "    'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
156     "})"
157    ]
158   },
159   {
160    "cell_type": "markdown",
161    "id": "703dca05-5371-409e-b0a6-c430594bb76f",
162    "metadata": {},
163    "source": [
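164     "## Non-Stateful (Baseline)\n\nBaseline run with `stateful=False` and no batch reset schedule (`batch_schedule_type=None`)."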
165    ]
166   },
167   {
168    "cell_type": "code",
169    "execution_count": null,
170    "id": "4eaaf547-1967-4325-be61-b5e8ed33141f",
171    "metadata": {},
172    "outputs": [],
173    "source": [
174     "params.update({\n",
175     "    'stateful': False,\n",
176     "    'batch_schedule_type': None  # no hidden-state reset schedule for the non-stateful baseline\n",
177     "})"
178    ]
179   },
180   {
181    "cell_type": "code",
182    "execution_count": null,
183    "id": "bfebe87d-1bbb-48c5-9b32-836d19d16787",
184    "metadata": {},
185    "outputs": [],
186    "source": [
187     "reproducibility.set_seed(123)  # same seed before every run so the schedules are compared on equal footing\n",
188     "rnn = RNN(params)\n",
189     "m0, errs0, epochs0 = rnn.run_model(rnn_dat, return_epochs=True)"
190    ]
191   },
192   {
193    "cell_type": "code",
194    "execution_count": null,
195    "id": "25581ed5-7fff-402f-b902-ed32bbcf1c0c",
196    "metadata": {},
197    "outputs": [],
198    "source": [
199     "errs0.mean()  # mean error of the non-stateful baseline"
200    ]
201   },
202   {
203    "cell_type": "markdown",
204    "id": "80026da4-e8ec-4803-b791-66110c4b10d9",
205    "metadata": {},
206    "source": [
207     "## Constant Batch Schedule (Stateful)\n\nStateful run with `batch_schedule_type='constant'` and `bmin=20`."
208    ]
209   },
210   {
211    "cell_type": "code",
212    "execution_count": null,
213    "id": "e7335f3d-3d8a-4733-9105-faed311a7df3",
214    "metadata": {},
215    "outputs": [],
216    "source": [
217     "params.update({\n",
218     "    'stateful': True, \n",
219     "    'batch_schedule_type':'constant', \n",
220     "    'bmin': 20})  # presumably merged into params: keys not listed here keep their values from earlier cells"
221    ]
222   },
223   {
224    "cell_type": "code",
225    "execution_count": null,
226    "id": "e1febc02-35af-4325-ad02-4f6a2ce065fd",
227    "metadata": {},
228    "outputs": [],
229    "source": [
230     "reproducibility.set_seed(123)  # same seed as the other runs\n",
231     "rnn = RNN(params)\n",
232     "m2, errs2, epochs2 = rnn.run_model(rnn_dat, return_epochs=True)"
233    ]
234   },
235   {
236    "cell_type": "code",
237    "execution_count": null,
238    "id": "4ad5e3cc-a4a4-4a54-9bc3-9b8b88b454ee",
239    "metadata": {},
240    "outputs": [],
241    "source": [
242     "errs2.mean()  # mean error with the constant batch reset schedule"
243    ]
244   },
245   {
246    "cell_type": "markdown",
247    "id": "7082eeb3-5173-4226-85db-4bb5a26f67a4",
248    "metadata": {},
249    "source": [
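250     "## Exponential Batch Schedule (Stateful)\n\nStateful run with `batch_schedule_type='exp'`, bounded below by `bmin=20` and above by `bmax=rnn_dat.hours`."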
251    ]
252   },
253   {
254    "cell_type": "code",
255    "execution_count": null,
256    "id": "2a428391-400d-476b-990e-c5e1d9cba7f8",
257    "metadata": {},
258    "outputs": [],
259    "source": [
260     "params.update({\n",
261     "    'stateful': True, \n",
262     "    'batch_schedule_type':'exp', \n",
263     "    'bmin': 20,\n",
264     "    'bmax': rnn_dat.hours  # upper bound taken from the wrapped data, not from params_data['hours']\n",
265     "})"
266    ]
267   },
268   {
269    "cell_type": "code",
270    "execution_count": null,
271    "id": "9f77ee03-01f8-415f-9748-b986a77f2982",
272    "metadata": {},
273    "outputs": [],
274    "source": [
275     "reproducibility.set_seed(123)  # same seed as the other runs\n",
276     "rnn = RNN(params)\n",
277     "m3, errs3, epochs3 = rnn.run_model(rnn_dat, return_epochs=True)"
278    ]
279   },
280   {
281    "cell_type": "code",
282    "execution_count": null,
283    "id": "d2b1eaf5-b710-431a-a10e-0098f713c325",
284    "metadata": {},
285    "outputs": [],
286    "source": [
287     "errs3.mean()  # mean error with the exponential batch reset schedule"
288    ]
289   }
306  ],
307  "metadata": {
308   "kernelspec": {
309    "display_name": "Python 3 (ipykernel)",
310    "language": "python",
311    "name": "python3"
312   },
313   "language_info": {
314    "codemirror_mode": {
315     "name": "ipython",
316     "version": 3
317    },
318    "file_extension": ".py",
319    "mimetype": "text/x-python",
320    "name": "python",
321    "nbconvert_exporter": "python",
322    "pygments_lexer": "ipython3",
323    "version": "3.12.5"
324   }
325  },
326  "nbformat": 4,
327  "nbformat_minor": 5