Update other ML models, handle params better
[notebooks.git] / fmda / rnn_workshop.ipynb
blobee582dffa085871261651d81d1be74e3e921793e
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.1 exploration: trying to make the models work better"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import os\n",
20     "import os.path as osp\n",
21     "import numpy as np\n",
22     "import pandas as pd\n",
23     "import tensorflow as tf\n",
24     "import matplotlib.pyplot as plt\n",
25     "import sys\n",
26     "# Local modules\n",
27     "sys.path.append('..')\n",
28     "import reproducibility\n",
29     "import pandas as pd\n",
30     "from utils import print_dict_summary\n",
31     "from data_funcs import rmse, process_train_dict\n",
32     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from tensorflow.keras.callbacks import Callback\n",
35     "from utils import hash2\n",
36     "import copy\n",
37     "import logging\n",
38     "import pickle\n",
39     "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights\n",
40     "import yaml\n",
41     "import copy"
42    ]
43   },
44   {
45    "cell_type": "code",
46    "execution_count": null,
47    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
48    "metadata": {},
49    "outputs": [],
50    "source": [
51     "logging_setup()"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
57    "metadata": {},
58    "source": [
59     "## Tests"
60    ]
61   },
62   {
63    "cell_type": "markdown",
64    "id": "6322f0bc-107d-40a5-96dc-804495085a99",
65    "metadata": {},
66    "source": [
67     "## Test Other ML"
68    ]
69   },
70   {
71    "cell_type": "code",
72    "execution_count": null,
73    "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
74    "metadata": {},
75    "outputs": [],
76    "source": [
77     "params = read_yml(\"params.yaml\", subkey='xgb')\n",
78     "params"
79    ]
80   },
81   {
82    "cell_type": "code",
83    "execution_count": null,
84    "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
85    "metadata": {},
86    "outputs": [],
87    "source": [
88     "dat = read_pkl(\"data/train.pkl\")"
89    ]
90   },
91   {
92    "cell_type": "code",
93    "execution_count": null,
94    "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
95    "metadata": {},
96    "outputs": [],
97    "source": [
98     "cases = [*dat.keys()]"
99    ]
100   },
101   {
102    "cell_type": "code",
103    "execution_count": null,
104    "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
105    "metadata": {},
106    "outputs": [],
107    "source": [
108     "rnn_dat = RNNData(dat[cases[0]], params['scaler'], params['features_list'])\n",
109     "rnn_dat.train_test_split(\n",
110     "    time_fracs = [.8, .1, .1]\n",
111     ")\n",
112     "rnn_dat.scale_data()"
113    ]
114   },
115   {
116    "cell_type": "code",
117    "execution_count": null,
118    "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
119    "metadata": {},
120    "outputs": [],
121    "source": [
122     "from moisture_models import XGB, RF, LM"
123    ]
124   },
125   {
126    "cell_type": "code",
127    "execution_count": null,
128    "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
129    "metadata": {},
130    "outputs": [],
131    "source": [
132     "mod = XGB(params)"
133    ]
134   },
135   {
136    "cell_type": "code",
137    "execution_count": null,
138    "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
139    "metadata": {},
140    "outputs": [],
141    "source": [
142     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
143    ]
144   },
145   {
146    "cell_type": "code",
147    "execution_count": null,
148    "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
149    "metadata": {},
150    "outputs": [],
151    "source": [
152     "preds = mod.predict(rnn_dat.X_test)"
153    ]
154   },
155   {
156    "cell_type": "code",
157    "execution_count": null,
158    "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
159    "metadata": {},
160    "outputs": [],
161    "source": [
162     "rmse(preds, rnn_dat.y_test)"
163    ]
164   },
165   {
166    "cell_type": "code",
167    "execution_count": null,
168    "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
169    "metadata": {},
170    "outputs": [],
171    "source": [
172     "plt.plot(rnn_dat.y_test)\n",
173     "plt.plot(preds)"
174    ]
175   },
176   {
177    "cell_type": "code",
178    "execution_count": null,
179    "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
180    "metadata": {},
181    "outputs": [],
182    "source": [
183     "params = read_yml(\"params.yaml\", subkey='rf')\n",
184     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
185     "rnn_dat.train_test_split(\n",
186     "    time_fracs = [.8, .1, .1]\n",
187     ")"
188    ]
189   },
190   {
191    "cell_type": "code",
192    "execution_count": null,
193    "id": "cafe711a-20cb-4bd3-a4bc-4995a843a021",
194    "metadata": {},
195    "outputs": [],
196    "source": [
197     "import importlib\n",
198     "import moisture_models\n",
199     "importlib.reload(moisture_models)"
200    ]
201   },
202   {
203    "cell_type": "code",
204    "execution_count": null,
205    "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
206    "metadata": {},
207    "outputs": [],
208    "source": [
209     "params"
210    ]
211   },
212   {
213    "cell_type": "code",
214    "execution_count": null,
215    "id": "fafe76e5-0212-4bd1-a058-535935a08780",
216    "metadata": {},
217    "outputs": [],
218    "source": [
219     "mod2 = moisture_models.RF(params)  # class from the reloaded module; the bare RF name still points at the pre-reload class\n",
220     "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
221     "preds2 = mod2.predict(rnn_dat.X_test)\n",
222     "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
223     "plt.plot(rnn_dat.y_test)\n",
224     "plt.plot(preds2)"
225    ]
226   },
227   {
228    "cell_type": "code",
229    "execution_count": null,
230    "id": "c0ab4244-996c-49af-bf4a-8b0c47b0b6db",
231    "metadata": {},
232    "outputs": [],
233    "source": [
234     "from moisture_models import RF\n",
235     "mod2 = RF(params)"
236    ]
237   },
238   {
239    "cell_type": "code",
240    "execution_count": null,
241    "id": "aa6c33fd-db35-4c77-9eee-fdb39a934959",
242    "metadata": {},
243    "outputs": [],
244    "source": []
245   },
246   {
247    "cell_type": "code",
248    "execution_count": null,
249    "id": "c5598bfe-2d87-4d23-869e-aff127782462",
250    "metadata": {},
251    "outputs": [],
252    "source": [
253     "params = read_yml(\"params.yaml\", subkey='lm')\n",
254     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
255     "rnn_dat.train_test_split(\n",
256     "    time_fracs = [.8, .1, .1]\n",
257     ")\n",
258     "mod = LM(params)"
259    ]
260   },
261   {
262    "cell_type": "code",
263    "execution_count": null,
264    "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
265    "metadata": {},
266    "outputs": [],
267    "source": [
268     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
269     "preds = mod.predict(rnn_dat.X_test)\n",
270     "print(rmse(preds, rnn_dat.y_test.flatten()))  # score the LM predictions, not the RF preds2"
271    ]
272   },
273   {
274    "cell_type": "code",
275    "execution_count": null,
276    "id": "8496a32a-8269-4d6b-953e-7f33fe626789",
277    "metadata": {},
278    "outputs": [],
279    "source": []
280   },
281   {
282    "cell_type": "code",
283    "execution_count": null,
284    "id": "75ce8bf3-6efb-4dc7-b895-def92f6ce6b4",
285    "metadata": {},
286    "outputs": [],
287    "source": []
288   },
289   {
290    "cell_type": "markdown",
291    "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
292    "metadata": {},
293    "source": [
294     "## RNN"
295    ]
296   },
297   {
298    "cell_type": "code",
299    "execution_count": null,
300    "id": "96fe971b-c6d3-45ee-94ee-e4f426735d56",
301    "metadata": {},
302    "outputs": [],
303    "source": [
304     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))\n",
305     "params.update({\n",
306     "    'features_list': ['Ed', 'Ew', 'solar', 'wind', 'rain']\n",
307     "})"
308    ]
309   },
310   {
311    "cell_type": "code",
312    "execution_count": null,
313    "id": "5a55e8e7-1869-43fc-9bc6-09bd4f5a8d76",
314    "metadata": {},
315    "outputs": [],
316    "source": [
317     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
318     "rnn_dat2.train_test_split(\n",
319     "    time_fracs = [.8, .1, .1]\n",
320     ")\n",
321     "rnn_dat2.scale_data()\n",
322     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
323    ]
324   },
325   {
326    "cell_type": "code",
327    "execution_count": null,
328    "id": "aaec14ac-c6a6-4fcd-ad8e-d28143b92623",
329    "metadata": {},
330    "outputs": [],
331    "source": [
332     "reproducibility.set_seed()\n",
333     "rnn = RNN(params)\n",
334     "m, errs = rnn.run_model(rnn_dat2, plot_period=\"predict\")"
335    ]
336   },
337   {
338    "cell_type": "code",
339    "execution_count": null,
340    "id": "c79ed028-ba60-4db5-9864-d3b2c01e09c3",
341    "metadata": {},
342    "outputs": [],
343    "source": []
344   },
345   {
346    "cell_type": "code",
347    "execution_count": null,
348    "id": "3e609a7c-52ea-486e-8a9f-0192b3e41e13",
349    "metadata": {},
350    "outputs": [],
351    "source": []
352   },
353   {
354    "cell_type": "code",
355    "execution_count": null,
356    "id": "d23be7cd-0883-46e3-a573-1e19167f0fd6",
357    "metadata": {},
358    "outputs": [],
359    "source": []
360   },
361   {
362    "cell_type": "code",
363    "execution_count": null,
364    "id": "975ed5a5-1f5a-4def-996d-bf374096e6c7",
365    "metadata": {},
366    "outputs": [],
367    "source": []
368   },
369   {
370    "cell_type": "markdown",
371    "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
372    "metadata": {},
373    "source": [
374     "## Physics-Initialized RNN"
375    ]
376   },
377   {
378    "cell_type": "code",
379    "execution_count": null,
380    "id": "5488628e-4552-4909-83e9-413fd6878bdd",
381    "metadata": {},
382    "outputs": [],
383    "source": [
384     "params.update({\n",
385     "    'epochs':100,\n",
386     "    'dense_layers': 0,\n",
387     "    'activation': ['relu', 'relu'],\n",
388     "    'phys_initialize': False,\n",
389     "    'dropout': [0,0]\n",
390     "})"
391    ]
392   },
393   {
394    "cell_type": "code",
395    "execution_count": null,
396    "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
397    "metadata": {},
398    "outputs": [],
399    "source": [
400     "reproducibility.set_seed()\n",
401     "rnn = RNN(params)\n",
402     "m, errs = rnn.run_model(rnn_dat2)  # rnn_dat2 is the scaled, batch-reshaped RNN data; rnn_dat is the unscaled LM data"
403    ]
404   },
405   {
406    "cell_type": "code",
407    "execution_count": null,
408    "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
409    "metadata": {},
410    "outputs": [],
411    "source": [
412     "rnn.model_train.summary()"
413    ]
414   },
415   {
416    "cell_type": "code",
417    "execution_count": null,
418    "id": "918a8bf0-638b-4b4b-82fe-c6a1965a72dd",
419    "metadata": {},
420    "outputs": [],
421    "source": []
422   },
423   {
424    "cell_type": "code",
425    "execution_count": null,
426    "id": "0aab34c7-8a09-480a-9d3e-619f7cf82b34",
427    "metadata": {},
428    "outputs": [],
429    "source": [
430     "params.update({\n",
431     "    'phys_initialize': True,\n",
432     "    'scaler': None, # TODO\n",
433     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
434     "    'activation': ['linear', 'linear'], # TODO tanh, relu the same\n",
435     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
436     "})"
437    ]
438   },
439   {
440    "cell_type": "code",
441    "execution_count": null,
442    "id": "ab549075-f71f-42ad-b36f-3d1e90247e33",
443    "metadata": {},
444    "outputs": [],
445    "source": [
446     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
447     "rnn_dat2.train_test_split(\n",
448     "    time_fracs = [.8, .1, .1]\n",
449     ")\n",
450     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
451    ]
452   },
453   {
454    "cell_type": "code",
455    "execution_count": null,
456    "id": "195f337a-ac8a-4471-8226-94863b9385e2",
457    "metadata": {},
458    "outputs": [],
459    "source": [
460     "import importlib\n",
461     "import moisture_rnn\n",
462     "importlib.reload(moisture_rnn)\n",
463     "from moisture_rnn import RNN, RNNData"
464    ]
465   },
466   {
467    "cell_type": "code",
468    "execution_count": null,
469    "id": "9395d147-17a5-44ba-aaa2-a213ffde062b",
470    "metadata": {
471     "scrolled": true
472    },
473    "outputs": [],
474    "source": [
475     "reproducibility.set_seed()\n",
476     "\n",
477     "rnn = RNN(params)"
478    ]
479   },
480   {
481    "cell_type": "code",
482    "execution_count": null,
483    "id": "d3eebe8a-ff12-454b-81b6-6a138924f127",
484    "metadata": {},
485    "outputs": [],
486    "source": [
487     "m, errs = rnn.run_model(rnn_dat2)"
488    ]
489   },
490   {
491    "cell_type": "code",
492    "execution_count": null,
493    "id": "bcbb0159-74c5-4f56-9d69-d85a58ddbd1a",
494    "metadata": {},
495    "outputs": [],
496    "source": [
497     "rnn.model_predict.get_weights()"
498    ]
499   },
500   {
501    "cell_type": "code",
502    "execution_count": null,
503    "id": "c25f741a-6280-4cf2-8017-e56672236fdb",
504    "metadata": {},
505    "outputs": [],
506    "source": []
507   },
508   {
509    "cell_type": "code",
510    "execution_count": null,
511    "id": "e8ed2b03-6123-4bdf-9e26-ef2ce4951663",
512    "metadata": {},
513    "outputs": [],
514    "source": [
515     "params['rnn_units']"
516    ]
517   },
518   {
519    "cell_type": "code",
520    "execution_count": null,
521    "id": "e44302bf-af49-4140-ae31-54f7c88a6735",
522    "metadata": {},
523    "outputs": [],
524    "source": [
525     "params.update({\n",
526     "    'phys_initialize': True,\n",
527     "    'scaler': None, # TODO\n",
528     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
529     "    'activation': ['relu', 'relu'], # TODO tanh, relu the same\n",
530     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
531     "})"
532    ]
533   },
534   {
535    "cell_type": "code",
536    "execution_count": null,
537    "id": "9a8ac32d-551c-43e8-988e-a3b13e6d9cd9",
538    "metadata": {},
539    "outputs": [],
540    "source": [
541     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
542     "rnn_dat2.train_test_split(\n",
543     "    time_fracs = [.8, .1, .1]\n",
544     ")\n",
545     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
546    ]
547   },
548   {
549    "cell_type": "code",
550    "execution_count": null,
551    "id": "ff727da8-38fb-4fda-999b-f712b98de0df",
552    "metadata": {
553     "scrolled": true
554    },
555    "outputs": [],
556    "source": [
557     "reproducibility.set_seed()\n",
558     "\n",
559     "rnn = RNN(params)\n",
560     "m, errs = rnn.run_model(rnn_dat2)"
561    ]
562   },
563   {
564    "cell_type": "code",
565    "execution_count": null,
566    "id": "b165074c-ea88-4b4d-8e41-6b6f22b4d221",
567    "metadata": {},
568    "outputs": [],
569    "source": []
570   },
571   {
572    "cell_type": "code",
573    "execution_count": null,
574    "id": "aa5cd4e6-4441-4c77-a086-e9edefbeb83b",
575    "metadata": {},
576    "outputs": [],
577    "source": []
578   },
579   {
580    "cell_type": "code",
581    "execution_count": null,
582    "id": "7bd1e05b-5cd8-48b4-8469-4842313d6097",
583    "metadata": {},
584    "outputs": [],
585    "source": []
586   },
587   {
588    "cell_type": "code",
589    "execution_count": null,
590    "id": "b399346d-20b8-4c97-898a-606a4be98065",
591    "metadata": {},
592    "outputs": [],
593    "source": []
594   },
595   {
596    "cell_type": "code",
597    "execution_count": null,
598    "id": "521285e6-6b6a-4d23-b688-9eb84b8eab68",
599    "metadata": {},
600    "outputs": [],
601    "source": []
602   },
603   {
604    "cell_type": "code",
605    "execution_count": null,
606    "id": "12c66af1-54fd-4398-8ee2-36eeb937c40d",
607    "metadata": {},
608    "outputs": [],
609    "source": []
610   },
611   {
612    "cell_type": "code",
613    "execution_count": null,
614    "id": "eb21fb8e-05c6-4a39-bdf1-4a57067c786d",
615    "metadata": {},
616    "outputs": [],
617    "source": []
618   },
619   {
620    "cell_type": "code",
621    "execution_count": null,
622    "id": "628a9105-ca06-44c4-ad00-13808e2f4773",
623    "metadata": {},
624    "outputs": [],
625    "source": []
626   },
627   {
628    "cell_type": "code",
629    "execution_count": null,
630    "id": "37fdbb3a-3e83-4541-93b2-982b6d4cbe93",
631    "metadata": {},
632    "outputs": [],
633    "source": []
634   },
635   {
636    "cell_type": "code",
637    "execution_count": null,
638    "id": "a592a4c9-cb3b-4174-8eaa-02afd00a1897",
639    "metadata": {},
640    "outputs": [],
641    "source": []
642   },
643   {
644    "cell_type": "code",
645    "execution_count": null,
646    "id": "3832fb05-417c-4648-8e2e-7748c06b3768",
647    "metadata": {},
648    "outputs": [],
649    "source": []
650   },
651   {
652    "cell_type": "markdown",
653    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
654    "metadata": {},
655    "source": [
656     "## LSTM\n",
657     "\n",
658     "**TODO:** fix the LSTM cells below."
659    ]
660   },
661   {
662    "cell_type": "code",
663    "execution_count": null,
664    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
665    "metadata": {},
666    "outputs": [],
667    "source": [
668     "import importlib \n",
669     "import moisture_rnn\n",
670     "importlib.reload(moisture_rnn)\n",
671     "from moisture_rnn import RNN_LSTM"
672    ]
673   },
674   {
675    "cell_type": "code",
676    "execution_count": null,
677    "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
678    "metadata": {},
679    "outputs": [],
680    "source": [
681     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
682     "params = RNNParams(params)"
683    ]
684   },
685   {
686    "cell_type": "code",
687    "execution_count": null,
688    "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
689    "metadata": {},
690    "outputs": [],
691    "source": [
692     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
693     "rnn_dat.train_test_split(\n",
694     "    time_fracs = [.8, .1, .1]\n",
695     ")\n",
696     "rnn_dat.scale_data()\n",
697     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
698    ]
699   },
700   {
701    "cell_type": "code",
702    "execution_count": null,
703    "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
704    "metadata": {},
705    "outputs": [],
706    "source": [
707     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
708     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
709     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
710     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
711     "reproducibility.set_seed(123)\n",
712     "lstm = RNN_LSTM(params)\n",
713     "\n",
714     "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
715     "                    batch_size = params['batch_size'], epochs=params['epochs'], \n",
716     "                    callbacks = [ResetStatesCallback(params),\n",
717     "                                EarlyStoppingCallback(patience = 15)],\n",
718     "                   validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
719     "              "
720    ]
721   },
722   {
723    "cell_type": "code",
724    "execution_count": null,
725    "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
726    "metadata": {},
727    "outputs": [],
728    "source": []
729   },
730   {
731    "cell_type": "code",
732    "execution_count": null,
733    "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
734    "metadata": {},
735    "outputs": [],
736    "source": []
737   },
738   {
739    "cell_type": "code",
740    "execution_count": null,
741    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
742    "metadata": {},
743    "outputs": [],
744    "source": [
745     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
746     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
747     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours,\n",
748     "              'early_stopping_patience': 25})\n",
749     "reproducibility.set_seed(123)\n",
750     "lstm = RNN_LSTM(params)\n",
751     "m, errs = lstm.run_model(rnn_dat)"
752    ]
753   },
754   {
755    "cell_type": "code",
756    "execution_count": null,
757    "id": "f60a24c6-9a67-45aa-bc5c-8818aa0ca049",
758    "metadata": {},
759    "outputs": [],
760    "source": []
761   },
762   {
763    "cell_type": "code",
764    "execution_count": null,
765    "id": "00910bd2-f050-438c-ab3b-c793b83cb5f5",
766    "metadata": {},
767    "outputs": [],
768    "source": [
769     "rnn_dat.spatial"
770    ]
771   },
772   {
773    "cell_type": "code",
774    "execution_count": null,
775    "id": "236b33e3-e864-4453-be16-cf07338c4105",
776    "metadata": {},
777    "outputs": [],
778    "source": [
779     "params = RNNParams(read_yml(\"params.yaml\", subkey='lstm'))\n",
780     "params"
781    ]
782   },
783   {
784    "cell_type": "code",
785    "execution_count": null,
786    "id": "fe2a484c-dc99-45a9-89fc-2f451bd719b5",
787    "metadata": {},
788    "outputs": [],
789    "source": [
790     "train = read_pkl(\"data/train.pkl\")"
791    ]
792   },
793   {
794    "cell_type": "code",
795    "execution_count": null,
796    "id": "07bfac87-a6d4-4dcc-8d11-adf83eafab76",
797    "metadata": {},
798    "outputs": [],
799    "source": [
800     "from itertools import islice\n",
801     "train = {k: train[k] for k in islice(train, 100)}"
802    ]
803   },
804   {
805    "cell_type": "code",
806    "execution_count": null,
807    "id": "4e26099b-f760-4047-afec-9e751d24b7a6",
808    "metadata": {},
809    "outputs": [],
810    "source": [
811     "from data_funcs import combine_nested\n",
812     "rnn_dat_sp = RNNData(\n",
813     "    combine_nested(train), # input dictionary\n",
814     "    scaler=\"standard\",  # data scaling type\n",
815     "    features_list = params['features_list'] # features for predicting outcome\n",
816     ")\n",
817     "\n",
818     "\n",
819     "rnn_dat_sp.train_test_split(   \n",
820     "    time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
821     "    space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
822     ")\n",
823     "rnn_dat_sp.scale_data()\n",
824     "\n",
825     "rnn_dat_sp.batch_reshape(\n",
826     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
827     "    batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
828     ")"
829    ]
830   },
831   {
832    "cell_type": "code",
833    "execution_count": null,
834    "id": "10738795-c83b-4da3-88ba-09278caa35f8",
835    "metadata": {},
836    "outputs": [],
837    "source": [
838     "params.update({\n",
839     "    'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
840     "})"
841    ]
842   },
843   {
844    "cell_type": "code",
845    "execution_count": null,
846    "id": "9c5d45cc-bcf0-4b6c-9c51-c4c790a2d9a5",
847    "metadata": {},
848    "outputs": [],
849    "source": [
850     "rnn_sp = RNN_LSTM(params)\n",
851     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
852    ]
853   },
854   {
855    "cell_type": "code",
856    "execution_count": null,
857    "id": "ee332ccf-4e4a-4f66-b4d6-c079dbdb1411",
858    "metadata": {},
859    "outputs": [],
860    "source": [
861     "errs.mean()"
862    ]
863   },
864   {
865    "cell_type": "code",
866    "execution_count": null,
867    "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
868    "metadata": {},
869    "outputs": [],
870    "source": []
871   }
872  ],
873  "metadata": {
874   "kernelspec": {
875    "display_name": "Python 3 (ipykernel)",
876    "language": "python",
877    "name": "python3"
878   },
879   "language_info": {
880    "codemirror_mode": {
881     "name": "ipython",
882     "version": 3
883    },
884    "file_extension": ".py",
885    "mimetype": "text/x-python",
886    "name": "python",
887    "nbconvert_exporter": "python",
888    "pygments_lexer": "ipython3",
889    "version": "3.12.6"
890   }
891  },
892  "nbformat": 4,
893  "nbformat_minor": 5