4 "cell_type": "markdown",
5 "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
8 "# v2.2 exploration trying to make it work better"
13 "execution_count": null,
14 "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
20 "import os.path as osp\n",
21 "import numpy as np\n",
22 "import pandas as pd\n",
23 "import tensorflow as tf\n",
24 "import matplotlib.pyplot as plt\n",
27 "sys.path.append('..')\n",
28 "import reproducibility\n",
29 "import pandas as pd\n",
30 "from utils import print_dict_summary\n",
31 "from data_funcs import rmse, build_train_dict, combine_nested, subset_by_features\n",
32 "# from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, rnn_data_wrap\n",
33 "from moisture_rnn_pkl import pkl2train\n",
34 "from tensorflow.keras.callbacks import Callback\n",
35 "from utils import hash2\n",
39 "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights, str2time\n",
47 "execution_count": null,
48 "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
56 "cell_type": "markdown",
57 "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
65 "execution_count": null,
66 "id": "3efed1fa-9cda-4934-8a6c-edcf179c8755",
70 "file_paths = ['data/fmda_rocky_202403-05_f05.pkl']"
75 "execution_count": null,
76 "id": "28fd3746-1861-4afa-ab7e-ac449fbed322",
80 "# Params used for data filtering\n",
81 "params_data = read_yml(\"params_data.yaml\") \n",
87 "execution_count": null,
88 "id": "c45cb8ef-41fc-4bf7-b506-dad5fd24abb3",
92 "dat = read_pkl(file_paths[0])"
97 "execution_count": null,
98 "id": "3c960d69-4f8a-4abb-a5d9-ed6cf98f899b",
102 "import importlib\n",
103 "import data_funcs\n",
104 "importlib.reload(data_funcs)\n",
105 "from data_funcs import build_train_dict"
110 "execution_count": null,
111 "id": "369cd913-85cb-4855-a80c-817d84637852",
115 "params_data.update({'hours': None})"
120 "execution_count": null,
121 "id": "8cdc2ce8-45b4-4caa-81d9-646271ff2e97",
127 "train3 = build_train_dict(file_paths, params_data, spatial=False, forecast_step=3, drop_na=True)\n"
132 "execution_count": null,
133 "id": "3c4548ae-caa4-4bc4-9122-9f24e7e59ef7",
140 "execution_count": null,
141 "id": "3dbb6f24-4435-47b3-90c6-6176582b0d4c",
147 "cell_type": "markdown",
148 "id": "6322f0bc-107d-40a5-96dc-804495085a99",
150 "jp-MarkdownHeadingCollapsed": true
158 "execution_count": null,
159 "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
163 "params = read_yml(\"params.yaml\", subkey='xgb')\n",
169 "execution_count": null,
170 "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
174 "dat = read_pkl(\"data/train.pkl\")"
179 "execution_count": null,
180 "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
184 "cases = [*dat.keys()]"
189 "execution_count": null,
190 "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
194 "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
195 "rnn_dat.train_test_split(\n",
196 " time_fracs = [.8, .1, .1]\n",
198 "rnn_dat.scale_data()"
203 "execution_count": null,
204 "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
208 "from moisture_models import XGB, RF, LM"
213 "execution_count": null,
214 "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
223 "execution_count": null,
224 "id": "cae9a20d-1caf-45aa-a9c4-aef21b65d9c8",
233 "execution_count": null,
234 "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
238 "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
243 "execution_count": null,
244 "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
248 "preds = mod.predict(rnn_dat.X_test)"
253 "execution_count": null,
254 "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
258 "rmse(preds, rnn_dat.y_test)"
263 "execution_count": null,
264 "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
268 "plt.plot(rnn_dat.y_test)\n",
274 "execution_count": null,
275 "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
279 "params = read_yml(\"params.yaml\", subkey='rf')\n",
280 "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
281 "rnn_dat.train_test_split(\n",
282 " time_fracs = [.8, .1, .1]\n",
288 "execution_count": null,
289 "id": "cafe711a-20cb-4bd3-a4bc-4995a843a021",
293 "import importlib\n",
294 "import moisture_models\n",
295 "importlib.reload(moisture_models)"
300 "execution_count": null,
301 "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
310 "execution_count": null,
311 "id": "fafe76e5-0212-4bd1-a058-535935a08780",
315 "mod2 = RF(params)\n",
316 "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
317 "preds2 = mod2.predict(rnn_dat.X_test)\n",
318 "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
319 "plt.plot(rnn_dat.y_test)\n",
325 "execution_count": null,
326 "id": "c0ab4244-996c-49af-bf4a-8b0c47b0b6db",
330 "from moisture_models import RF\n",
336 "execution_count": null,
337 "id": "aa6c33fd-db35-4c77-9eee-fdb39a934959",
344 "execution_count": null,
345 "id": "c5598bfe-2d87-4d23-869e-aff127782462",
349 "params = read_yml(\"params.yaml\", subkey='lm')\n",
350 "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
351 "rnn_dat.train_test_split(\n",
352 " time_fracs = [.8, .1, .1]\n",
359 "execution_count": null,
360 "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
364 "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
365 "preds = mod.predict(rnn_dat.X_test)\n",
366 "print(rmse(preds2, rnn_dat.y_test.flatten()))"
371 "execution_count": null,
372 "id": "8496a32a-8269-4d6b-953e-7f33fe626789",
379 "execution_count": null,
380 "id": "75ce8bf3-6efb-4dc7-b895-def92f6ce6b4",
386 "cell_type": "markdown",
387 "id": "d6e089d9-e466-45bb-80f2-15c563ae21ad",
395 "execution_count": null,
396 "id": "3d5792a1-53e3-4099-8630-1bd5e3f52dcc",
400 "from tensorflow.keras import layers,models"
405 "execution_count": null,
406 "id": "0962428e-1124-4e1f-8500-d02b26640204",
410 "import importlib\n",
411 "import moisture_rnn\n",
412 "importlib.reload(moisture_rnn)\n",
413 "from moisture_rnn import RNN, RNNParams"
418 "execution_count": null,
419 "id": "a14f9c76-93eb-4b13-a11d-6ccb38285335",
423 "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))"
428 "execution_count": null,
429 "id": "ed3dd798-6a40-4e90-b40b-accabe49fb35",
434 " 'hidden_layers': ['lstm', 'conv1d', 'dense'],\n",
435 " 'hidden_units': [32, 32, 16],\n",
436 " 'hidden_activation': ['tanh', 'relu', 'relu'],\n",
437 " 'return_sequences': True\n",
443 "execution_count": null,
444 "id": "e559d0d7-5847-4fd0-81e4-7d3ca92147dd",
448 "import importlib\n",
449 "import moisture_rnn\n",
450 "importlib.reload(moisture_rnn)\n",
451 "from moisture_rnn import RNN, rnn_data_wrap"
456 "execution_count": null,
457 "id": "7c1627f9-f011-4159-98a2-1b5973929e71",
461 "reproducibility.set_seed()\n",
467 "execution_count": null,
468 "id": "5dbc66c0-ccb5-46c2-a073-1fa7a5be750a",
472 "mod.model_train.summary()"
477 "execution_count": null,
478 "id": "882c5872-a017-4d9c-90be-88e692dd33e8",
482 "mod.model_predict.summary()"
487 "execution_count": null,
488 "id": "30498201-3798-484d-922f-974909b195af",
492 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
497 "execution_count": null,
498 "id": "e213ffd7-d26c-41ce-8e2b-b17368fdd7a8",
503 " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
509 "execution_count": null,
510 "id": "74e599b6-7f4d-4175-a5f1-de892e72ebd4",
514 "m, errs = mod.run_model(rnn_dat)"
519 "execution_count": null,
520 "id": "f894d203-d277-48f3-bb57-a610f162361f",
529 "execution_count": null,
530 "id": "b875ea70-41f9-4550-982b-88380ad1b5a0",
538 "cell_type": "markdown",
539 "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
547 "execution_count": null,
548 "id": "8c1894e3-5283-4e5e-83ae-9c386836a990",
552 "import importlib \n",
553 "import moisture_rnn\n",
554 "importlib.reload(moisture_rnn)\n",
555 "from moisture_rnn import RNN"
560 "execution_count": null,
561 "id": "aa1b690f-edaa-4c97-893c-ec9a3a615ce1",
565 "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
566 "params = RNNParams(params)\n",
568 " 'dense_layers': 2,\n",
569 " 'dense_units': 32\n",
575 "execution_count": null,
576 "id": "054ab015-4e41-4255-8b1a-843b61e3d21d",
580 "params.update({'batch_schedule_type': 'step'})"
585 "execution_count": null,
586 "id": "fa38f35a-d367-4df8-b2d3-7691ff4b0cf4",
590 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)\n",
591 "reproducibility.set_seed(123)\n",
597 "execution_count": null,
598 "id": "27d11b75-89e9-43a9-8801-7be7fb845b09",
602 "rnn.model_train.summary()"
607 "execution_count": null,
608 "id": "b9a0b3fb-aaab-4948-b6e6-824e9dcb92a7",
612 "rnn.model_predict.summary()"
617 "execution_count": null,
618 "id": "ade176b9-2844-43b6-b85e-5bb30414aa35",
627 "execution_count": null,
628 "id": "5945e6c1-6b3a-4b7d-ade2-b5788860ef18",
632 "rnn.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, validation_data=(rnn_dat.X_val, rnn_dat.y_val), \n",
633 " verbose=True, epochs=20)"
638 "execution_count": null,
639 "id": "2d123b2b-047e-4a04-b49e-6629cc22edc6",
643 "rnn.model_predict.set_weights(rnn.model_train.get_weights())"
648 "execution_count": null,
649 "id": "db57df64-d2ac-4b91-bbfc-71a5834ddf41",
653 "rnn.model_predict.summary()"
658 "execution_count": null,
659 "id": "0466887f-9833-4a6a-a0c7-a4d56f207d33",
663 "rnn_dat.X_test.shape"
668 "execution_count": null,
669 "id": "1d3e630c-db69-4603-962e-95c576b45ac9",
673 "preds = rnn.model_predict.predict(rnn_dat.X_test)"
678 "execution_count": null,
679 "id": "8b8228a9-5b6d-4de1-8968-d40277edacd2",
688 "execution_count": null,
689 "id": "8b001dd8-ffd7-4fd1-bf11-413515ddc488",
693 "rnn_dat.X_test.shape"
698 "execution_count": null,
699 "id": "f96c6dbf-6ca8-451e-abc4-b68b8116871b",
703 "squared_diff = np.square(preds - rnn_dat.y_test)\n",
704 "mse = np.mean(squared_diff, axis=(1, 2))\n",
705 "errs = np.sqrt(mse)\n",
710 "cell_type": "markdown",
711 "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
714 "## Phys Initialized"
719 "execution_count": null,
720 "id": "5488628e-4552-4909-83e9-413fd6878bdd",
726 " 'dense_layers': 0,\n",
727 " 'activation': ['relu', 'relu'],\n",
728 " 'phys_initialize': False,\n",
729 " 'dropout': [0,0],\n",
730 " 'space_fracs': [.8, .1, .1],\n",
737 "execution_count": null,
738 "id": "ab7db7d6-949e-457d-90b9-22d9c5aa4739",
742 "import importlib\n",
743 "import moisture_rnn\n",
744 "importlib.reload(moisture_rnn)\n",
745 "from moisture_rnn import rnn_data_wrap"
750 "execution_count": null,
751 "id": "d26cf1b2-2fad-409d-888f-4921b0ae4ba8",
755 "params['scaler'] is None"
760 "execution_count": null,
761 "id": "1c4627bc-0f90-44e6-9103-2efe5c5f439d",
765 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
770 "execution_count": null,
771 "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
775 "reproducibility.set_seed()\n",
776 "rnn = RNN(params)\n",
777 "m, errs = rnn.run_model(rnn_dat)"
782 "execution_count": null,
783 "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
787 "rnn.model_train.summary()"
792 "execution_count": null,
793 "id": "918a8bf0-638b-4b4b-82fe-c6a1965a72dd",
802 "execution_count": null,
803 "id": "37fdbb3a-3e83-4541-93b2-982b6d4cbe93",
809 "rnn_dat.X_train[:,:,0].mean()"
814 "execution_count": null,
815 "id": "7ca41db1-72aa-44b6-b9dd-058735336ab3",
822 "execution_count": null,
823 "id": "a592a4c9-cb3b-4174-8eaa-02afd00a1897",
827 "rnn_dat['features_list']"
832 "execution_count": null,
833 "id": "3832fb05-417c-4648-8e2e-7748c06b3768",
839 "cell_type": "markdown",
840 "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
848 "execution_count": null,
849 "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
853 "import importlib \n",
854 "import moisture_rnn\n",
855 "importlib.reload(moisture_rnn)\n",
856 "from moisture_rnn import RNN_LSTM"
861 "execution_count": null,
862 "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
866 "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
867 "params = RNNParams(params)"
872 "execution_count": null,
873 "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
877 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
882 "execution_count": null,
883 "id": "57bb5708-7be9-4474-abb4-3b7ff4bf79df",
888 " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
894 "execution_count": null,
895 "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
899 "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
900 "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
901 " 'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
902 " 'batch_schedule_type':'step', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
903 "reproducibility.set_seed(123)\n",
904 "lstm = RNN_LSTM(params)\n",
906 "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
907 " batch_size = params['batch_size'], epochs=params['epochs'], \n",
908 " callbacks = [ResetStatesCallback(params),\n",
909 " EarlyStoppingCallback(patience = 15)],\n",
910 " validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
916 "execution_count": null,
917 "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
924 "execution_count": null,
925 "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
929 "params = RNNParams(read_yml(\"params.yaml\", subkey=\"lstm\"))\n",
930 "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
931 " 'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
932 " 'batch_schedule_type':'step', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
933 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)\n",
935 " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
937 "reproducibility.set_seed(123)\n",
938 "lstm = RNN_LSTM(params)\n",
939 "m, errs = lstm.run_model(rnn_dat)"
944 "execution_count": null,
945 "id": "be46a2dc-bf5c-4893-a1ee-a1682566f7a2",
954 "execution_count": null,
955 "id": "0f319f37-7d13-41fd-95fa-66dbdfeab588",
962 "execution_count": null,
963 "id": "b1252b08-62b9-4d24-add2-0f87d15b0ff2",
967 "params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
968 "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
973 "execution_count": null,
974 "id": "9281540b-eb26-4923-883b-1b31d8347634",
978 "reproducibility.set_seed(123)\n",
979 "rnn = RNN(params)\n",
980 "m, errs = rnn.run_model(rnn_dat)"
985 "execution_count": null,
986 "id": "8a0269b4-d6b7-4f20-8386-69814d7acaa3",
995 "execution_count": null,
996 "id": "10b44de3-a0e9-49e4-9e03-873d69580c07",
1002 "cell_type": "code",
1003 "execution_count": null,
1004 "id": "27f4fee4-7fce-49c5-a455-97a90b754c13",
1010 "cell_type": "code",
1011 "execution_count": null,
1012 "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
1020 "display_name": "Python 3 (ipykernel)",
1021 "language": "python",
1025 "codemirror_mode": {
1029 "file_extension": ".py",
1030 "mimetype": "text/x-python",
1032 "nbconvert_exporter": "python",
1033 "pygments_lexer": "ipython3",