4 "cell_type": "markdown",
5 "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
8 "# v2.1 exploration: comparing data pipeline versions and tuning RNN training"
13 "execution_count": null,
14 "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
20 "import os.path as osp\n",
21 "import numpy as np\n",
22 "import pandas as pd\n",
23 "import tensorflow as tf\n",
24 "import matplotlib.pyplot as plt\n",
27 "sys.path.append('..')\n",
28 "import reproducibility\n",
30 "from utils import print_dict_summary\n",
31 "from data_funcs import rmse, process_train_dict\n",
32 "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM\n",
33 "from moisture_rnn_pkl import pkl2train\n",
34 "from tensorflow.keras.callbacks import Callback\n",
35 "from utils import hash2\n",
39 "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights\n",
46 "execution_count": null,
47 "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
55 "cell_type": "markdown",
56 "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
64 "execution_count": null,
65 "id": "3efed1fa-9cda-4934-8a6c-edcf179c8755",
69 "file_paths = ['data/fmda_nw_202401-05_f05.pkl']"
74 "execution_count": null,
75 "id": "28fd3746-1861-4afa-ab7e-ac449fbed322",
79 "# Params used for data filtering\n",
80 "params_data = read_yml(\"params_data.yaml\") \n",
86 "execution_count": null,
87 "id": "32a46674-8377-47f8-9c3a-e6b07f9505cf",
91 "params = read_yml(\"params.yaml\", subkey='rnn') \n",
92 "params = RNNParams(params)\n",
93 "params.update({'epochs': 200, \n",
94 " 'learning_rate': 0.001,\n",
95 " 'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
96 " 'recurrent_layers': 2, 'recurrent_units': 30, \n",
97 " 'dense_layers': 2, 'dense_units': 30,\n",
98 " 'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
99 " 'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
100 " 'bmin': 20, # Lower bound of hidden state batch reset, \n",
101 " 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
102 " 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
103 " 'timesteps': 12\n",
109 "execution_count": null,
110 "id": "91466d1b-3106-4b49-8ee8-47cbf58f938c",
116 "from data_funcs import build_train_dict2\n",
117 "train2 = build_train_dict2(file_paths, params_data, spatial=False)"
122 "execution_count": null,
123 "id": "117dcd97-fa2a-4324-9a0e-6f7723c9c0ce",
129 "train = process_train_dict(file_paths, atm_dict=\"HRRR\", params_data = params_data, verbose=True, spatial=False)"
134 "execution_count": null,
135 "id": "ce8e3803-44c7-4602-882c-8c69530546e7",
139 "key = \"PLFI1_202401\"\n",
140 "train[key]['features_list']"
145 "execution_count": null,
146 "id": "ca760d99-5348-4437-b30d-c5fcb252a6a1",
150 "train2[key]['features_list']"
155 "execution_count": null,
156 "id": "c3557d4b-162e-45bb-b297-4192e6b977d1",
163 " print(\"~\"*50)\n",
165 " print(np.all(train[k]['X'][:,-1] == train2[k]['X'][:,-1]))"
170 "execution_count": null,
171 "id": "d378b9fd-c92d-4c76-b039-364fb10512d9",
175 "reproducibility.set_seed(123)"
180 "execution_count": null,
181 "id": "3ec92f91-90cc-4096-a091-9ec54870ea77",
185 "from itertools import islice\n",
186 "train = {k: train[k] for k in islice(train, 250)}"
191 "execution_count": null,
192 "id": "0fb060ed-1239-44da-8c9b-923ff2004e38",
196 "from data_funcs import combine_nested"
201 "execution_count": null,
202 "id": "377ff40a-c8f3-469b-b5f5-bd46b6dc1ae1",
207 " combine_nested(train), # input dictionary\n",
208 " scaler=\"standard\", # data scaling type\n",
209 " features_list = params['features_list'] # features for predicting outcome\n",
213 "d1.train_test_split( \n",
214 " time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
215 " space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
219 "d1.batch_reshape(\n",
220 " timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
221 " batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
227 "execution_count": null,
228 "id": "cc7ea493-6beb-43c6-a933-032d30b8415f",
232 "# Update Params specific to spatial training\n",
234 " 'loc_batch_reset': d1.n_seqs # Used to reset hidden state when location changes for a given batch\n",
240 "execution_count": null,
241 "id": "11c2ec92-e51d-4015-88ed-5a5c3ea4a58f",
245 "reproducibility.set_seed(123)\n",
246 "rnn_sp = RNN(params)\n",
247 "m, errs = rnn_sp.run_model(d1)"
252 "execution_count": null,
253 "id": "7b17fa96-eb8f-455f-80c3-24b384ae65e7",
262 "execution_count": null,
263 "id": "5403b69a-b0c2-45d1-be52-d232fcfbe7d9",
270 "execution_count": null,
271 "id": "baefb163-198e-4b8c-920a-3c520eba8579",
275 "from itertools import islice\n",
276 "train2 = {k: train2[k] for k in islice(train2, 250)}"
281 "execution_count": null,
282 "id": "54b3ea85-a4ed-4b9b-8fb5-4c2669100177",
287 "    combine_nested(train2), # input dictionary\n",
288 " scaler=\"standard\", # data scaling type\n",
289 " features_list = params['features_list'] # features for predicting outcome\n",
293 "d2.train_test_split( \n",
294 " time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
295 " space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
299 "d2.batch_reshape(\n",
300 " timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
301 " batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
303 "# Update Params specific to spatial training\n",
305 " 'loc_batch_reset': d2.n_seqs # Used to reset hidden state when location changes for a given batch\n",
311 "execution_count": null,
312 "id": "aa0b0073-7b0e-4871-9c87-8d19e1c49758",
316 "reproducibility.set_seed(123)\n",
317 "rnn2 = RNN(params)\n",
318 "m2, errs2 = rnn2.run_model(d2)"
323 "execution_count": null,
324 "id": "89cdbdb5-2eee-4412-92d2-d47ef3e3549e",
333 "execution_count": null,
334 "id": "6b045231-710c-452a-bfc4-214d5e148cd8",
341 "execution_count": null,
342 "id": "05184bee-a561-4541-a3e6-7dd63cb491f7",
349 "execution_count": null,
350 "id": "b722325b-af18-402d-acda-daf4586e6bbc",
357 "execution_count": null,
358 "id": "2685870d-6b05-4228-97f9-e017b2a4d1ee",
365 "execution_count": null,
366 "id": "641aa6cd-80cf-4d62-a13c-6a7de4644778",
373 "execution_count": null,
374 "id": "c110fa06-1eb4-4f24-aca4-fd852e5297c5",
378 "from data_funcs import combine_nested"
383 "execution_count": null,
384 "id": "2143ecb6-6edb-4948-8a30-698cfaceefa2",
388 "nest = combine_nested(train)"
393 "execution_count": null,
394 "id": "f8ce412b-cc9f-4273-b423-21cb53002258",
398 "nest2 = combine_nested(train2)"
403 "execution_count": null,
404 "id": "64ec6aa5-0d44-4e84-a5e7-89b2c4b4ec1a",
413 "execution_count": null,
414 "id": "657e77d5-b0db-4a92-8171-0cc4e9796bb9",
423 "execution_count": null,
424 "id": "f4b2cd20-cac7-4870-8b51-2cf0726ff286",
431 "execution_count": null,
432 "id": "04809b38-61af-47eb-bca5-aed80167e0ec",
439 "execution_count": null,
440 "id": "8404317e-f13e-4758-82cf-07549ee9efc1",
447 "execution_count": null,
448 "id": "e73de0d3-b57b-41e4-8ea1-fe7d4ac69c9b",
455 "execution_count": null,
456 "id": "3582f92a-bf5b-45b7-b8ae-ea50f7ae46cd",
463 "execution_count": null,
464 "id": "a54246b4-f093-4c4f-be6b-dbe9d7a8a3fd",
470 "cell_type": "markdown",
471 "id": "6322f0bc-107d-40a5-96dc-804495085a99",
473 "jp-MarkdownHeadingCollapsed": true
481 "execution_count": null,
482 "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
486 "params = read_yml(\"params.yaml\", subkey='xgb')\n",
492 "execution_count": null,
493 "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
497 "dat = read_pkl(\"data/train.pkl\")"
502 "execution_count": null,
503 "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
507 "cases = [*dat.keys()]"
512 "execution_count": null,
513 "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
517 "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
518 "rnn_dat.train_test_split(\n",
519 " time_fracs = [.8, .1, .1]\n",
521 "rnn_dat.scale_data()"
526 "execution_count": null,
527 "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
531 "from moisture_models import XGB, RF, LM"
536 "execution_count": null,
537 "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
546 "execution_count": null,
547 "id": "cae9a20d-1caf-45aa-a9c4-aef21b65d9c8",
556 "execution_count": null,
557 "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
561 "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
566 "execution_count": null,
567 "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
571 "preds = mod.predict(rnn_dat.X_test)"
576 "execution_count": null,
577 "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
581 "rmse(preds, rnn_dat.y_test)"
586 "execution_count": null,
587 "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
591 "plt.plot(rnn_dat.y_test)\n",
597 "execution_count": null,
598 "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
602 "params = read_yml(\"params.yaml\", subkey='rf')\n",
603 "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
604 "rnn_dat.train_test_split(\n",
605 " time_fracs = [.8, .1, .1]\n",
611 "execution_count": null,
612 "id": "cafe711a-20cb-4bd3-a4bc-4995a843a021",
616 "import importlib\n",
617 "import moisture_models\n",
618 "importlib.reload(moisture_models)"
623 "execution_count": null,
624 "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
633 "execution_count": null,
634 "id": "fafe76e5-0212-4bd1-a058-535935a08780",
638 "mod2 = RF(params)\n",
639 "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
640 "preds2 = mod2.predict(rnn_dat.X_test)\n",
641 "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
642 "plt.plot(rnn_dat.y_test)\n",
648 "execution_count": null,
649 "id": "c0ab4244-996c-49af-bf4a-8b0c47b0b6db",
653 "from moisture_models import RF\n",
659 "execution_count": null,
660 "id": "aa6c33fd-db35-4c77-9eee-fdb39a934959",
667 "execution_count": null,
668 "id": "c5598bfe-2d87-4d23-869e-aff127782462",
672 "params = read_yml(\"params.yaml\", subkey='lm')\n",
673 "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
674 "rnn_dat.train_test_split(\n",
675 " time_fracs = [.8, .1, .1]\n",
682 "execution_count": null,
683 "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
687 "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
688 "preds = mod.predict(rnn_dat.X_test)\n",
689 "print(rmse(preds, rnn_dat.y_test.flatten()))"
694 "execution_count": null,
695 "id": "8496a32a-8269-4d6b-953e-7f33fe626789",
702 "execution_count": null,
703 "id": "75ce8bf3-6efb-4dc7-b895-def92f6ce6b4",
709 "cell_type": "markdown",
710 "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
718 "execution_count": null,
719 "id": "fa38f35a-d367-4df8-b2d3-7691ff4b0cf4",
725 "cell_type": "markdown",
726 "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
729 "## Phys Initialized"
734 "execution_count": null,
735 "id": "5488628e-4552-4909-83e9-413fd6878bdd",
741 " 'dense_layers': 0,\n",
742 " 'activation': ['relu', 'relu'],\n",
743 " 'phys_initialize': False,\n",
744 " 'dropout': [0,0]\n",
750 "execution_count": null,
751 "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
755 "reproducibility.set_seed()\n",
756 "rnn = RNN(params)\n",
757 "m, errs = rnn.run_model(rnn_dat)"
762 "execution_count": null,
763 "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
767 "rnn.model_train.summary()"
772 "execution_count": null,
773 "id": "918a8bf0-638b-4b4b-82fe-c6a1965a72dd",
780 "execution_count": null,
781 "id": "0aab34c7-8a09-480a-9d3e-619f7cf82b34",
786 " 'phys_initialize': True,\n",
787 " 'scaler': None, # TODO\n",
788 " 'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
789 " 'activation': ['linear', 'linear'], # TODO tanh, relu the same\n",
790 " 'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
796 "execution_count": null,
797 "id": "ab549075-f71f-42ad-b36f-3d1e90247e33",
801 "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
802 "rnn_dat2.train_test_split(\n",
803 " time_fracs = [.8, .1, .1]\n",
805 "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
810 "execution_count": null,
811 "id": "195f337a-ac8a-4471-8226-94863b9385e2",
815 "import importlib\n",
816 "import moisture_rnn\n",
817 "importlib.reload(moisture_rnn)\n",
818 "from moisture_rnn import RNN, RNNData"
823 "execution_count": null,
824 "id": "9395d147-17a5-44ba-aaa2-a213ffde062b",
830 "reproducibility.set_seed()\n",
837 "execution_count": null,
838 "id": "d3eebe8a-ff12-454b-81b6-6a138924f127",
842 "m, errs = rnn.run_model(rnn_dat2)"
847 "execution_count": null,
848 "id": "bcbb0159-74c5-4f56-9d69-d85a58ddbd1a",
852 "rnn.model_predict.get_weights()"
857 "execution_count": null,
858 "id": "c25f741a-6280-4cf2-8017-e56672236fdb",
865 "execution_count": null,
866 "id": "e8ed2b03-6123-4bdf-9e26-ef2ce4951663",
870 "params['rnn_units']"
875 "execution_count": null,
876 "id": "e44302bf-af49-4140-ae31-54f7c88a6735",
881 " 'phys_initialize': True,\n",
882 " 'scaler': None, # TODO\n",
883 " 'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
884 " 'activation': ['relu', 'relu'], # TODO tanh, relu the same\n",
885 " 'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
891 "execution_count": null,
892 "id": "9a8ac32d-551c-43e8-988e-a3b13e6d9cd9",
896 "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
897 "rnn_dat2.train_test_split(\n",
898 " time_fracs = [.8, .1, .1]\n",
900 "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
905 "execution_count": null,
906 "id": "ff727da8-38fb-4fda-999b-f712b98de0df",
912 "reproducibility.set_seed()\n",
914 "rnn = RNN(params)\n",
915 "m, errs = rnn.run_model(rnn_dat2)"
920 "execution_count": null,
921 "id": "b165074c-ea88-4b4d-8e41-6b6f22b4d221",
928 "execution_count": null,
929 "id": "aa5cd4e6-4441-4c77-a086-e9edefbeb83b",
936 "execution_count": null,
937 "id": "7bd1e05b-5cd8-48b4-8469-4842313d6097",
944 "execution_count": null,
945 "id": "b399346d-20b8-4c97-898a-606a4be98065",
952 "execution_count": null,
953 "id": "521285e6-6b6a-4d23-b688-9eb84b8eab68",
960 "execution_count": null,
961 "id": "12c66af1-54fd-4398-8ee2-36eeb937c40d",
968 "execution_count": null,
969 "id": "eb21fb8e-05c6-4a39-bdf1-4a57067c786d",
976 "execution_count": null,
977 "id": "628a9105-ca06-44c4-ad00-13808e2f4773",
984 "execution_count": null,
985 "id": "37fdbb3a-3e83-4541-93b2-982b6d4cbe93",
992 "execution_count": null,
993 "id": "a592a4c9-cb3b-4174-8eaa-02afd00a1897",
1000 "execution_count": null,
1001 "id": "3832fb05-417c-4648-8e2e-7748c06b3768",
1007 "cell_type": "markdown",
1008 "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
1017 "cell_type": "code",
1018 "execution_count": null,
1019 "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
1023 "import importlib \n",
1024 "import moisture_rnn\n",
1025 "importlib.reload(moisture_rnn)\n",
1026 "from moisture_rnn import RNN_LSTM"
1030 "cell_type": "code",
1031 "execution_count": null,
1032 "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
1036 "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
1037 "params = RNNParams(params)"
1041 "cell_type": "code",
1042 "execution_count": null,
1043 "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
1047 "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
1048 "rnn_dat.train_test_split(\n",
1049 " time_fracs = [.8, .1, .1]\n",
1051 "rnn_dat.scale_data()\n",
1052 "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
1056 "cell_type": "code",
1057 "execution_count": null,
1058 "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
1062 "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
1063 "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
1064 " 'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
1065 " 'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
1066 "reproducibility.set_seed(123)\n",
1067 "lstm = RNN_LSTM(params)\n",
1069 "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
1070 " batch_size = params['batch_size'], epochs=params['epochs'], \n",
1071 " callbacks = [ResetStatesCallback(params),\n",
1072 " EarlyStoppingCallback(patience = 15)],\n",
1073 " validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
1078 "cell_type": "code",
1079 "execution_count": null,
1080 "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
1086 "cell_type": "code",
1087 "execution_count": null,
1088 "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
1094 "cell_type": "code",
1095 "execution_count": null,
1096 "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
1100 "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
1101 " 'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
1102 " 'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours,\n",
1103 " 'early_stopping_patience': 25})\n",
1104 "reproducibility.set_seed(123)\n",
1105 "lstm = RNN_LSTM(params)\n",
1106 "m, errs = lstm.run_model(rnn_dat)"
1110 "cell_type": "code",
1111 "execution_count": null,
1112 "id": "f60a24c6-9a67-45aa-bc5c-8818aa0ca049",
1118 "cell_type": "code",
1119 "execution_count": null,
1120 "id": "00910bd2-f050-438c-ab3b-c793b83cb5f5",
1128 "cell_type": "code",
1129 "execution_count": null,
1130 "id": "236b33e3-e864-4453-be16-cf07338c4105",
1134 "params = RNNParams(read_yml(\"params.yaml\", subkey='lstm'))\n",
1139 "cell_type": "code",
1140 "execution_count": null,
1141 "id": "fe2a484c-dc99-45a9-89fc-2f451bd719b5",
1145 "train = read_pkl(\"data/train.pkl\")"
1149 "cell_type": "code",
1150 "execution_count": null,
1151 "id": "07bfac87-a6d4-4dcc-8d11-adf83eafab76",
1155 "from itertools import islice\n",
1156 "train = {k: train[k] for k in islice(train, 100)}"
1160 "cell_type": "code",
1161 "execution_count": null,
1162 "id": "4e26099b-f760-4047-afec-9e751d24b7a6",
1166 "from data_funcs import combine_nested\n",
1167 "rnn_dat_sp = RNNData(\n",
1168 " combine_nested(train), # input dictionary\n",
1169 " scaler=\"standard\", # data scaling type\n",
1170 " features_list = params['features_list'] # features for predicting outcome\n",
1174 "rnn_dat_sp.train_test_split( \n",
1175 " time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
1176 " space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
1178 "rnn_dat_sp.scale_data()\n",
1180 "rnn_dat_sp.batch_reshape(\n",
1181 " timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
1182 " batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
1187 "cell_type": "code",
1188 "execution_count": null,
1189 "id": "10738795-c83b-4da3-88ba-09278caa35f8",
1193 "params.update({\n",
1194 " 'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
1199 "cell_type": "code",
1200 "execution_count": null,
1201 "id": "9c5d45cc-bcf0-4b6c-9c51-c4c790a2d9a5",
1205 "rnn_sp = RNN_LSTM(params)\n",
1206 "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
1210 "cell_type": "code",
1211 "execution_count": null,
1212 "id": "ee332ccf-4e4a-4f66-b4d6-c079dbdb1411",
1220 "cell_type": "code",
1221 "execution_count": null,
1222 "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
1230 "display_name": "Python 3 (ipykernel)",
1231 "language": "python",
1235 "codemirror_mode": {
1239 "file_extension": ".py",
1240 "mimetype": "text/x-python",
1242 "nbconvert_exporter": "python",
1243 "pygments_lexer": "ipython3",