4 "cell_type": "markdown",
5 "id": "83b774b3-ef55-480a-b999-506676e49145",
8 "# Compare Batch Resetting Schedules\n"
13 "execution_count": null,
14 "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
18 "import numpy as np\n",
21 "import os.path as osp\n",
23 "sys.path.append('..')\n",
24 "from moisture_rnn_pkl import pkl2train\n",
25 "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
26 "from utils import hash2, read_yml, read_pkl, retrieve_url, print_dict_summary, print_first, str2time, logging_setup\n",
27 "from moisture_rnn import RNN\n",
28 "import reproducibility\n",
29 "from data_funcs import rmse, to_json, combine_nested, build_train_dict\n",
30 "from moisture_models import run_augmented_kf\n",
32 "import pandas as pd\n",
33 "import matplotlib.pyplot as plt\n",
36 "import reproducibility\n",
37 "import tensorflow as tf"
42 "execution_count": null,
43 "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
52 "execution_count": null,
53 "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
57 "filename = \"fmda_rocky_202311-202402_f05.pkl\"\n",
59 " url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n",
60 " dest_path = f\"../data/{filename}\")"
65 "execution_count": null,
66 "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
70 "file_names=[filename]\n",
71 "file_dir='../data'\n",
72 "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
77 "execution_count": null,
78 "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
82 "params = RNNParams(read_yml(\"../params.yaml\", subkey='rnn'))\n",
83 "params_data = read_yml(\"../params_data.yaml\")"
88 "execution_count": null,
89 "id": "5f3d09b4-c2fd-4556-90b7-e547431ca523",
95 "params_data.update({\n",
97 " 'max_intp_time': 12,\n",
98 " 'zero_lag_threshold': 12\n",
100 "# train = process_train_dict([\"data/fmda_nw_202401-05_f05.pkl\"], params_data=params_data, verbose=True)\n",
101 "train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=True, verbose=True)"
106 "execution_count": null,
107 "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
111 "# from itertools import islice\n",
112 "# train = {k: train[k] for k in islice(train, 250)}"
117 "execution_count": null,
118 "id": "35ae0bdb-a209-429f-8116-c5e1dccafb89",
122 "## params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
123 "# params.update({'epochs': 200, \n",
124 "# 'learning_rate': 0.001,\n",
125 "# 'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
126 "# 'recurrent_layers': 1, 'recurrent_units': 30, \n",
127 "# 'dense_layers': 1, 'dense_units': 30,\n",
128 "# 'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
129 "# 'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
130 "# 'bmin': 20, # Lower bound of hidden state batch reset, \n",
131 "# 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
132 "# 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
133 "# 'timesteps': 12,\n",
134 "# 'batch_size': 50\n",
139 "cell_type": "markdown",
140 "id": "d6751dcc-ba4c-47d5-90d2-60f4a61e96fa",
148 "execution_count": null,
149 "id": "adbba43e-603b-4801-8a35-35b8ccc053af",
153 "rnn_dat = rnn_data_wrap(train, params)\n",
156 " 'loc_batch_reset': rnn_dat.n_seqs, # Used to reset hidden state when location changes for a given batch\n",
157 " 'bmax': params_data['hours'],\n",
158 " 'early_stopping_patience': 25\n",
163 "cell_type": "markdown",
164 "id": "703dca05-5371-409e-b0a6-c430594bb76f",
172 "execution_count": null,
173 "id": "4eaaf547-1967-4325-be61-b5e8ed33141f",
178 " 'stateful': False,\n",
179 " 'batch_schedule_type': None\n",
185 "execution_count": null,
186 "id": "bfebe87d-1bbb-48c5-9b32-836d19d16787",
190 "reproducibility.set_seed(123)\n",
191 "rnn = RNN(params)\n",
192 "m0, errs0, epochs0 = rnn.run_model(rnn_dat, return_epochs=True)"
197 "execution_count": null,
198 "id": "25581ed5-7fff-402f-b902-ed32bbcf1c0c",
207 "execution_count": null,
208 "id": "921bf523-d39d-40f7-8778-08f73f5c002d",
216 "cell_type": "markdown",
217 "id": "13f04499-3048-430e-88b0-6010de1a00d5",
220 "## Stateful No Batch Schedule"
225 "execution_count": null,
226 "id": "f95c9a51-1203-4ad6-a75c-dd414821db40",
231 " 'stateful': True, \n",
232 " 'batch_schedule_type':None\n",
238 "execution_count": null,
239 "id": "7f768878-7a3b-4cd2-a174-fce6d3039f54",
243 "reproducibility.set_seed(123)\n",
244 "rnn = RNN(params)\n",
245 "m1, errs1, epochs1 = rnn.run_model(rnn_dat, return_epochs=True)"
249 "cell_type": "markdown",
250 "id": "80026da4-e8ec-4803-b791-66110c4b10d9",
253 "## Constant Batch Schedule (Stateful)"
258 "execution_count": null,
259 "id": "e7335f3d-3d8a-4733-9105-faed311a7df3",
264 " 'stateful': True, \n",
265 " 'batch_schedule_type':'constant', \n",
271 "execution_count": null,
272 "id": "e1febc02-35af-4325-ad02-4f6a2ce065fd",
276 "reproducibility.set_seed(123)\n",
277 "rnn = RNN(params)\n",
278 "m2, errs2, epochs2 = rnn.run_model(rnn_dat, return_epochs=True)"
283 "execution_count": null,
284 "id": "4ad5e3cc-a4a4-4a54-9bc3-9b8b88b454ee",
292 "cell_type": "markdown",
293 "id": "7082eeb3-5173-4226-85db-4bb5a26f67a4",
296 "## Exp Batch Schedule (Stateful)"
301 "execution_count": null,
302 "id": "2a428391-400d-476b-990e-c5e1d9cba7f8",
307 " 'stateful': True, \n",
308 " 'batch_schedule_type':'exp', \n",
310 " 'bmax': rnn_dat.hours\n",
316 "execution_count": null,
317 "id": "9f77ee03-01f8-415f-9748-b986a77f2982",
321 "reproducibility.set_seed(123)\n",
322 "rnn = RNN(params)\n",
323 "m3, errs3, epochs3 = rnn.run_model(rnn_dat, return_epochs=True)"
328 "execution_count": null,
329 "id": "d2b1eaf5-b710-431a-a10e-0098f713c325",
337 "cell_type": "markdown",
338 "id": "6655ec7d-45ec-45d4-9486-de8d15e9e380",
346 "execution_count": null,
347 "id": "ec25686e-e41d-4478-8a51-f657547fb3b3",
353 " \"Method\": [\"Non-Stateful\", \"Stateful - No Reset\", \"Stateful - Const Reset\", \"Stateful - Exp Reset\"],\n",
354 " \"RMSE\": [errs0, errs1, errs2, errs3], \n",
355 " \"N_Epochs\": [epochs0, epochs1, epochs2, epochs3] \n",
362 "execution_count": null,
363 "id": "7dd6d484-b104-42a4-90cf-a3503cb29ca2",
370 "execution_count": null,
371 "id": "2efafcb5-7ddf-4558-8d4e-972d080efd34",
379 "display_name": "Python 3 (ipykernel)",
380 "language": "python",
388 "file_extension": ".py",
389 "mimetype": "text/x-python",
391 "nbconvert_exporter": "python",
392 "pygments_lexer": "ipython3",