4 "cell_type": "markdown",
5 "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
8 "# v2.1 exploration: trying to make it work better"
13 "execution_count": null,
14 "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
20 "import os.path as osp\n",
21 "import numpy as np\n",
22 "import pandas as pd\n",
23 "import tensorflow as tf\n",
24 "import matplotlib.pyplot as plt\n",
27 "sys.path.append('..')\n",
28 "import reproducibility\n",
29 "import pandas as pd\n",
30 "from utils import print_dict_summary\n",
31 "from data_funcs import rmse, build_train_dict, combine_nested, subset_by_features\n",
32 "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, rnn_data_wrap\n",
33 "from moisture_rnn_pkl import pkl2train\n",
34 "from tensorflow.keras.callbacks import Callback\n",
35 "from utils import hash2\n",
39 "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights, str2time\n",
47 "execution_count": null,
48 "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
56 "cell_type": "markdown",
57 "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
65 "execution_count": null,
66 "id": "3efed1fa-9cda-4934-8a6c-edcf179c8755",
70 "file_paths = ['data/fmda_rocky_202403-05_f05.pkl']"
75 "execution_count": null,
76 "id": "28fd3746-1861-4afa-ab7e-ac449fbed322",
80 "# Params used for data filtering\n",
81 "params_data = read_yml(\"params_data.yaml\") \n",
87 "execution_count": null,
88 "id": "32a46674-8377-47f8-9c3a-e6b07f9505cf",
92 "params = read_yml(\"params.yaml\", subkey='rnn') \n",
93 "params = RNNParams(params)\n",
94 "params.update({'epochs': 200, \n",
95 " 'learning_rate': 0.001,\n",
96 " 'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
97 " 'recurrent_layers': 2, 'recurrent_units': 30, \n",
98 " 'dense_layers': 2, 'dense_units': 30,\n",
99 " 'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
100 " 'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
101 " 'bmin': 20, # Lower bound of hidden state batch reset, \n",
102 " 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
103 " 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
104 " 'timesteps': 12\n",
110 "execution_count": null,
111 "id": "c45cb8ef-41fc-4bf7-b506-dad5fd24abb3",
115 "dat = read_pkl(file_paths[0])"
120 "execution_count": null,
121 "id": "3c960d69-4f8a-4abb-a5d9-ed6cf98f899b",
125 "import importlib\n",
126 "import data_funcs\n",
127 "importlib.reload(data_funcs)\n",
128 "from data_funcs import build_train_dict"
133 "execution_count": null,
134 "id": "369cd913-85cb-4855-a80c-817d84637852",
138 "params_data.update({'hours': None})"
143 "execution_count": null,
144 "id": "371a66ac-027d-4377-b2bb-22106707a614",
148 "start_time = time.time()"
153 "execution_count": null,
154 "id": "8cdc2ce8-45b4-4caa-81d9-646271ff2e97",
160 "train3 = build_train_dict(file_paths, params_data, spatial=False, forecast_step=3, drop_na=True)\n"
165 "execution_count": null,
166 "id": "29e3289b-f47a-450d-8b17-2a7813fad089",
171 "end_time = time.time()\n",
173 "# Wall-clock time of the build_train_dict call above (start_time is set in an earlier cell)\n",
174 "elapsed_time_sp = end_time - start_time\n",
175 "print(f\"build_train_dict elapsed time: {elapsed_time_sp:.4f} seconds\")"
180 "execution_count": null,
181 "id": "aae373a5-f0b2-4ab9-b5df-ffa52c1bc305",
185 "from data_funcs import build_features_single"
190 "execution_count": null,
191 "id": "7e01fdde-ae5e-43d0-9a0d-7d7a024f4481",
197 "dat['PLFI1_202401'].keys()"
202 "execution_count": null,
203 "id": "765867c7-72ac-4761-9d86-2558d1f75c3a",
207 "start_time = time.time()\n",
210 " build_features_single(dat[key], atm=\"HRRR\", fstep=\"f03\", fprev=\"f02\")\n",
213 "end_time = time.time()\n",
215 "# Wall-clock time of the build_features_single calls in this cell\n",
216 "elapsed_time_sp = end_time - start_time\n",
217 "print(f\"build_features_single elapsed time: {elapsed_time_sp:.4f} seconds\")"
222 "execution_count": null,
223 "id": "46212800-55bd-46c6-84f9-abf395f863df",
227 "from multiprocessing import Process, Queue"
232 "execution_count": null,
233 "id": "cee55474-804c-4446-ab7b-ca95aa522089",
237 "keys = list(dat.keys())"
242 "execution_count": null,
243 "id": "bfd9aa0c-4d58-4af2-b0b2-30592fe7b4f6",
247 "def process_key(key):\n",
248 " build_features_single(dat[key], atm=\"HRRR\", fstep=\"f03\", fprev=\"f02\")"
253 "execution_count": null,
254 "id": "50836818-1961-4012-b820-9930939b3a8a",
258 "from multiprocessing import Pool"
263 "execution_count": null,
264 "id": "890b5fce-3dcc-47b9-8ab7-2651582ffdb5",
268 "if __name__ == '__main__':\n",
269 " with Pool() as pool:\n",
270 " pool.map(process_key, keys)"
275 "execution_count": null,
276 "id": "3c4548ae-caa4-4bc4-9122-9f24e7e59ef7",
283 "execution_count": null,
284 "id": "3dbb6f24-4435-47b3-90c6-6176582b0d4c",
290 "cell_type": "markdown",
291 "id": "6322f0bc-107d-40a5-96dc-804495085a99",
293 "jp-MarkdownHeadingCollapsed": true
301 "execution_count": null,
302 "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
306 "params = read_yml(\"params.yaml\", subkey='xgb')\n",
312 "execution_count": null,
313 "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
317 "dat = read_pkl(\"data/train.pkl\")"
322 "execution_count": null,
323 "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
327 "cases = [*dat.keys()]"
332 "execution_count": null,
333 "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
337 "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
338 "rnn_dat.train_test_split(\n",
339 " time_fracs = [.8, .1, .1]\n",
341 "rnn_dat.scale_data()"
346 "execution_count": null,
347 "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
351 "from moisture_models import XGB, RF, LM"
356 "execution_count": null,
357 "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
366 "execution_count": null,
367 "id": "cae9a20d-1caf-45aa-a9c4-aef21b65d9c8",
376 "execution_count": null,
377 "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
381 "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
386 "execution_count": null,
387 "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
391 "preds = mod.predict(rnn_dat.X_test)"
396 "execution_count": null,
397 "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
401 "rmse(preds, rnn_dat.y_test)"
406 "execution_count": null,
407 "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
411 "plt.plot(rnn_dat.y_test)\n",
417 "execution_count": null,
418 "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
422 "params = read_yml(\"params.yaml\", subkey='rf')\n",
423 "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
424 "rnn_dat.train_test_split(\n",
425 " time_fracs = [.8, .1, .1]\n",
431 "execution_count": null,
432 "id": "cafe711a-20cb-4bd3-a4bc-4995a843a021",
436 "import importlib\n",
437 "import moisture_models\n",
438 "importlib.reload(moisture_models)"
443 "execution_count": null,
444 "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
453 "execution_count": null,
454 "id": "fafe76e5-0212-4bd1-a058-535935a08780",
458 "mod2 = RF(params)\n",
459 "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
460 "preds2 = mod2.predict(rnn_dat.X_test)\n",
461 "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
462 "plt.plot(rnn_dat.y_test)\n",
468 "execution_count": null,
469 "id": "c0ab4244-996c-49af-bf4a-8b0c47b0b6db",
473 "from moisture_models import RF\n",
479 "execution_count": null,
480 "id": "aa6c33fd-db35-4c77-9eee-fdb39a934959",
487 "execution_count": null,
488 "id": "c5598bfe-2d87-4d23-869e-aff127782462",
492 "params = read_yml(\"params.yaml\", subkey='lm')\n",
493 "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
494 "rnn_dat.train_test_split(\n",
495 " time_fracs = [.8, .1, .1]\n",
502 "execution_count": null,
503 "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
507 "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
508 "preds = mod.predict(rnn_dat.X_test)\n",
509 "print(rmse(np.ravel(preds), rnn_dat.y_test.flatten()))  # score this cell's preds (not preds2 from the RF cell); ravel guards against a 2-D prediction array broadcasting against the 1-D target"
514 "execution_count": null,
515 "id": "8496a32a-8269-4d6b-953e-7f33fe626789",
522 "execution_count": null,
523 "id": "75ce8bf3-6efb-4dc7-b895-def92f6ce6b4",
529 "cell_type": "markdown",
530 "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
538 "execution_count": null,
539 "id": "fa38f35a-d367-4df8-b2d3-7691ff4b0cf4",
545 "cell_type": "markdown",
546 "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
549 "## Phys Initialized"
554 "execution_count": null,
555 "id": "5488628e-4552-4909-83e9-413fd6878bdd",
561 " 'dense_layers': 0,\n",
562 " 'activation': ['relu', 'relu'],\n",
563 " 'phys_initialize': False,\n",
564 " 'dropout': [0,0],\n",
565 " 'space_fracs': [.8, .1, .1],\n",
572 "execution_count": null,
573 "id": "ab7db7d6-949e-457d-90b9-22d9c5aa4739",
577 "import importlib\n",
578 "import moisture_rnn\n",
579 "importlib.reload(moisture_rnn)\n",
580 "from moisture_rnn import rnn_data_wrap"
585 "execution_count": null,
586 "id": "d26cf1b2-2fad-409d-888f-4921b0ae4ba8",
590 "params['scaler'] is None"
595 "execution_count": null,
596 "id": "1c4627bc-0f90-44e6-9103-2efe5c5f439d",
600 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
605 "execution_count": null,
606 "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
610 "reproducibility.set_seed()\n",
611 "rnn = RNN(params)\n",
612 "m, errs = rnn.run_model(rnn_dat)"
617 "execution_count": null,
618 "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
622 "rnn.model_train.summary()"
627 "execution_count": null,
628 "id": "918a8bf0-638b-4b4b-82fe-c6a1965a72dd",
637 "execution_count": null,
638 "id": "37fdbb3a-3e83-4541-93b2-982b6d4cbe93",
644 "rnn_dat.X_train[:,:,0].mean()"
649 "execution_count": null,
650 "id": "7ca41db1-72aa-44b6-b9dd-058735336ab3",
657 "execution_count": null,
658 "id": "a592a4c9-cb3b-4174-8eaa-02afd00a1897",
662 "rnn_dat['features_list']"
667 "execution_count": null,
668 "id": "3832fb05-417c-4648-8e2e-7748c06b3768",
674 "cell_type": "markdown",
675 "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
683 "execution_count": null,
684 "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
688 "import importlib \n",
689 "import moisture_rnn\n",
690 "importlib.reload(moisture_rnn)\n",
691 "from moisture_rnn import RNN_LSTM"
696 "execution_count": null,
697 "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
701 "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
702 "params = RNNParams(params)"
707 "execution_count": null,
708 "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
712 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
717 "execution_count": null,
718 "id": "57bb5708-7be9-4474-abb4-3b7ff4bf79df",
723 " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
729 "execution_count": null,
730 "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
734 "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
735 "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
736 " 'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
737 " 'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
738 "reproducibility.set_seed(123)\n",
739 "lstm = RNN_LSTM(params)\n",
741 "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
742 " batch_size = params['batch_size'], epochs=params['epochs'], \n",
743 " callbacks = [ResetStatesCallback(params),\n",
744 " EarlyStoppingCallback(patience = 15)],\n",
745 " validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
751 "execution_count": null,
752 "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
761 "execution_count": null,
762 "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
769 "execution_count": null,
770 "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
774 "params = RNNParams(read_yml(\"params.yaml\", subkey=\"lstm\"))\n",
775 "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
776 " 'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
777 " 'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours,\n",
778 " 'early_stopping_patience': 25})\n",
779 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)\n",
781 " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
783 "reproducibility.set_seed(123)\n",
784 "lstm = RNN_LSTM(params)\n",
785 "m, errs = lstm.run_model(rnn_dat)"
790 "execution_count": null,
791 "id": "be46a2dc-bf5c-4893-a1ee-a1682566f7a2",
800 "execution_count": null,
801 "id": "0f319f37-7d13-41fd-95fa-66dbdfeab588",
808 "execution_count": null,
809 "id": "b1252b08-62b9-4d24-add2-0f87d15b0ff2",
813 "params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
814 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
819 "execution_count": null,
820 "id": "9281540b-eb26-4923-883b-1b31d8347634",
824 "reproducibility.set_seed(123)\n",
825 "rnn = RNN(params)\n",
826 "m, errs = rnn.run_model(rnn_dat)"
831 "execution_count": null,
832 "id": "8a0269b4-d6b7-4f20-8386-69814d7acaa3",
841 "execution_count": null,
842 "id": "10b44de3-a0e9-49e4-9e03-873d69580c07",
849 "execution_count": null,
850 "id": "27f4fee4-7fce-49c5-a455-97a90b754c13",
857 "execution_count": null,
858 "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
866 "display_name": "Python 3 (ipykernel)",
867 "language": "python",
875 "file_extension": ".py",
876 "mimetype": "text/x-python",
878 "nbconvert_exporter": "python",
879 "pygments_lexer": "ipython3",