4 "cell_type": "markdown",
5 "id": "83b774b3-ef55-480a-b999-506676e49145",
8 "# Compare Batch Resetting Schedules\n"
13 "execution_count": null,
14 "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
18 "import numpy as np\n",
21 "import os.path as osp\n",
23 "sys.path.append('..')\n",
24 "from moisture_rnn_pkl import pkl2train\n",
25 "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
26 "from utils import hash2, read_yml, read_pkl, retrieve_url, print_dict_summary, print_first, str2time, logging_setup\n",
27 "from moisture_rnn import RNN\n",
28 "import reproducibility\n",
29 "from data_funcs import rmse, to_json, combine_nested, build_train_dict\n",
30 "from moisture_models import run_augmented_kf\n",
32 "import pandas as pd\n",
33 "import matplotlib.pyplot as plt\n",
36 "import reproducibility\n",
37 "import tensorflow as tf"
42 "execution_count": null,
43 "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
52 "execution_count": null,
53 "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
57 "filename = \"fmda_rocky_202403-05_f05.pkl\"\n",
59 " url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n",
60 " dest_path = f\"../data/{filename}\")"
65 "execution_count": null,
66 "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
70 "file_names=[filename]\n",
71 "file_dir='../data'\n",
72 "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
77 "execution_count": null,
78 "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
82 "params = RNNParams(read_yml(\"../params.yaml\", subkey='rnn'))\n",
83 "params_data = read_yml(\"../params_data.yaml\")"
88 "execution_count": null,
89 "id": "5f3d09b4-c2fd-4556-90b7-e547431ca523",
95 "params_data.update({\n",
97 " 'max_intp_time': 12,\n",
98 " 'zero_lag_threshold': 12\n",
100 "# train = process_train_dict([\"data/fmda_nw_202401-05_f05.pkl\"], params_data=params_data, verbose=True)\n",
101 "train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=True, verbose=True)"
106 "execution_count": null,
107 "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
111 "# from itertools import islice\n",
112 "# train = {k: train[k] for k in islice(train, 250)}"
117 "execution_count": null,
118 "id": "35ae0bdb-a209-429f-8116-c5e1dccafb89",
122 "## params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
123 "params.update({'epochs': 200, \n",
124 " 'learning_rate': 0.001,\n",
125 " 'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
126 " 'recurrent_layers': 1, 'recurrent_units': 30, \n",
127 " 'dense_layers': 1, 'dense_units': 30,\n",
128 " 'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
129 " 'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
130 " 'bmin': 20, # Lower bound of hidden state batch reset, \n",
131 " 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
132 " 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
133 " 'timesteps': 12,\n",
134 " 'batch_size': 50\n",
139 "cell_type": "markdown",
140 "id": "d6751dcc-ba4c-47d5-90d2-60f4a61e96fa",
148 "execution_count": null,
149 "id": "adbba43e-603b-4801-8a35-35b8ccc053af",
153 "rnn_dat = rnn_data_wrap(train, params)\n",
155 " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
160 "cell_type": "markdown",
161 "id": "703dca05-5371-409e-b0a6-c430594bb76f",
169 "execution_count": null,
170 "id": "4eaaf547-1967-4325-be61-b5e8ed33141f",
175 " 'stateful': False,\n",
176 " 'batch_schedule_type': None\n",
182 "execution_count": null,
183 "id": "bfebe87d-1bbb-48c5-9b32-836d19d16787",
187 "reproducibility.set_seed(123)\n",
188 "rnn = RNN(params)\n",
189 "m0, errs0, epochs0 = rnn.run_model(rnn_dat, return_epochs=True)"
194 "execution_count": null,
195 "id": "25581ed5-7fff-402f-b902-ed32bbcf1c0c",
203 "cell_type": "markdown",
204 "id": "80026da4-e8ec-4803-b791-66110c4b10d9",
207 "## Constant Batch Schedule (Stateful)"
212 "execution_count": null,
213 "id": "e7335f3d-3d8a-4733-9105-faed311a7df3",
218 " 'stateful': True, \n",
219 " 'batch_schedule_type':'constant', \n",
225 "execution_count": null,
226 "id": "e1febc02-35af-4325-ad02-4f6a2ce065fd",
230 "reproducibility.set_seed(123)\n",
231 "rnn = RNN(params)\n",
232 "m2, errs2, epochs2 = rnn.run_model(rnn_dat, return_epochs=True)"
237 "execution_count": null,
238 "id": "4ad5e3cc-a4a4-4a54-9bc3-9b8b88b454ee",
246 "cell_type": "markdown",
247 "id": "7082eeb3-5173-4226-85db-4bb5a26f67a4",
250 "## Exponential Batch Schedule (Stateful)"
255 "execution_count": null,
256 "id": "2a428391-400d-476b-990e-c5e1d9cba7f8",
261 " 'stateful': True, \n",
262 " 'batch_schedule_type':'exp', \n",
264 " 'bmax': rnn_dat.hours\n",
270 "execution_count": null,
271 "id": "9f77ee03-01f8-415f-9748-b986a77f2982",
275 "reproducibility.set_seed(123)\n",
276 "rnn = RNN(params)\n",
277 "m3, errs3, epochs3 = rnn.run_model(rnn_dat, return_epochs=True)"
282 "execution_count": null,
283 "id": "d2b1eaf5-b710-431a-a10e-0098f713c325",
292 "execution_count": null,
293 "id": "ec25686e-e41d-4478-8a51-f657547fb3b3",
300 "execution_count": null,
301 "id": "69194c5d-eb38-41d9-b48c-3496049bf44f",
309 "display_name": "Python 3 (ipykernel)",
310 "language": "python",
318 "file_extension": ".py",
319 "mimetype": "text/x-python",
321 "nbconvert_exporter": "python",
322 "pygments_lexer": "ipython3",