Update data wrapper
[notebooks.git] / fmda / rnn_workshop.ipynb
blob421e4b02ddae824925acb58fec0bd8f9de3938c4
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.1 Exploration: Experiments to Improve Model Performance"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import os\n",
20     "import os.path as osp\n",
21     "import numpy as np\n",
22     "import pandas as pd\n",
23     "import tensorflow as tf\n",
24     "import matplotlib.pyplot as plt\n",
25     "import sys\n",
26     "# Local modules\n",
27     "sys.path.append('..')\n",
28     "import reproducibility\n",
29     "import pandas as pd\n",
30     "from utils import print_dict_summary\n",
31     "from data_funcs import rmse, build_train_dict, combine_nested, subset_by_features\n",
32     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, rnn_data_wrap\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from tensorflow.keras.callbacks import Callback\n",
35     "from utils import hash2\n",
36     "import copy\n",
37     "import logging\n",
38     "import pickle\n",
39     "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights, str2time\n",
40     "import yaml\n",
41     "import copy\n",
42     "import time"
43    ]
44   },
45   {
46    "cell_type": "code",
47    "execution_count": null,
48    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
49    "metadata": {},
50    "outputs": [],
51    "source": [
52     "logging_setup()"
53    ]
54   },
55   {
56    "cell_type": "markdown",
57    "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
58    "metadata": {},
59    "source": [
60     "## Test Data"
61    ]
62   },
63   {
64    "cell_type": "code",
65    "execution_count": null,
66    "id": "3efed1fa-9cda-4934-8a6c-edcf179c8755",
67    "metadata": {},
68    "outputs": [],
69    "source": [
70     "file_paths = ['data/fmda_rocky_202403-05_f05.pkl']"
71    ]
72   },
73   {
74    "cell_type": "code",
75    "execution_count": null,
76    "id": "28fd3746-1861-4afa-ab7e-ac449fbed322",
77    "metadata": {},
78    "outputs": [],
79    "source": [
80     "# Params used for data filtering\n",
81     "params_data = read_yml(\"params_data.yaml\") \n",
82     "params_data"
83    ]
84   },
85   {
86    "cell_type": "code",
87    "execution_count": null,
88    "id": "32a46674-8377-47f8-9c3a-e6b07f9505cf",
89    "metadata": {},
90    "outputs": [],
91    "source": [
92     "params = read_yml(\"params.yaml\", subkey='rnn') \n",
93     "params = RNNParams(params)\n",
94     "params.update({'epochs': 200, \n",
95     "               'learning_rate': 0.001,\n",
96     "               'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
97     "               'recurrent_layers': 2, 'recurrent_units': 30, \n",
98     "               'dense_layers': 2, 'dense_units': 30,\n",
99     "               'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
100     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
101     "               'bmin': 20, # Lower bound of hidden state batch reset, \n",
102     "               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
103     "               'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
104     "               'timesteps': 12\n",
105     "              })"
106    ]
107   },
108   {
109    "cell_type": "code",
110    "execution_count": null,
111    "id": "c45cb8ef-41fc-4bf7-b506-dad5fd24abb3",
112    "metadata": {},
113    "outputs": [],
114    "source": [
115     "dat = read_pkl(file_paths[0])"
116    ]
117   },
118   {
119    "cell_type": "code",
120    "execution_count": null,
121    "id": "3c960d69-4f8a-4abb-a5d9-ed6cf98f899b",
122    "metadata": {},
123    "outputs": [],
124    "source": [
125     "import importlib\n",
126     "import data_funcs\n",
127     "importlib.reload(data_funcs)\n",
128     "from data_funcs import build_train_dict"
129    ]
130   },
131   {
132    "cell_type": "code",
133    "execution_count": null,
134    "id": "369cd913-85cb-4855-a80c-817d84637852",
135    "metadata": {},
136    "outputs": [],
137    "source": [
138     "params_data.update({'hours': None})"
139    ]
140   },
141   {
142    "cell_type": "code",
143    "execution_count": null,
144    "id": "371a66ac-027d-4377-b2bb-22106707a614",
145    "metadata": {},
146    "outputs": [],
147    "source": [
148     "start_time = time.time()"
149    ]
150   },
151   {
152    "cell_type": "code",
153    "execution_count": null,
154    "id": "8cdc2ce8-45b4-4caa-81d9-646271ff2e97",
155    "metadata": {
156     "scrolled": true
157    },
158    "outputs": [],
159    "source": [
160     "train3 = build_train_dict(file_paths, params_data, spatial=False, forecast_step=3, drop_na=True)\n"
161    ]
162   },
163   {
164    "cell_type": "code",
165    "execution_count": null,
166    "id": "29e3289b-f47a-450d-8b17-2a7813fad089",
167    "metadata": {},
168    "outputs": [],
169    "source": [
170     "# End Timer: covers the build_train_dict call above (spatial=False, and it builds data, not a model)\n",
171     "end_time = time.time()\n",
172     "\n",
173     "# Calculate Code Runtime\n",
174     "elapsed_time_sp = end_time - start_time\n",
175     "print(f\"build_train_dict elapsed time: {elapsed_time_sp:.4f} seconds\")"
176    ]
177   },
178   {
179    "cell_type": "code",
180    "execution_count": null,
181    "id": "aae373a5-f0b2-4ab9-b5df-ffa52c1bc305",
182    "metadata": {},
183    "outputs": [],
184    "source": [
185     "from data_funcs import build_features_single"
186    ]
187   },
188   {
189    "cell_type": "code",
190    "execution_count": null,
191    "id": "7e01fdde-ae5e-43d0-9a0d-7d7a024f4481",
192    "metadata": {
193     "scrolled": true
194    },
195    "outputs": [],
196    "source": [
197     "dat['PLFI1_202401'].keys()"
198    ]
199   },
200   {
201    "cell_type": "code",
202    "execution_count": null,
203    "id": "765867c7-72ac-4761-9d86-2558d1f75c3a",
204    "metadata": {},
205    "outputs": [],
206    "source": [
207     "start_time = time.perf_counter()  # monotonic clock intended for measuring elapsed time\n",
208     "\n",
209     "for key in dat:  # serial baseline; return value unused, so presumably each case dict is updated in place\n",
210     "    build_features_single(dat[key], atm=\"HRRR\", fstep=\"f03\", fprev=\"f02\")\n",
211     "\n",
212     "# End Timer\n",
213     "end_time = time.perf_counter()\n",
214     "\n",
215     "# Calculate Code Runtime\n",
216     "elapsed_time_sp = end_time - start_time\n",
217     "print(f\"Serial build_features_single elapsed time: {elapsed_time_sp:.4f} seconds\")"
218    ]
219   },
220   {
221    "cell_type": "code",
222    "execution_count": null,
223    "id": "46212800-55bd-46c6-84f9-abf395f863df",
224    "metadata": {},
225    "outputs": [],
226    "source": [
227     "from multiprocessing import Process, Queue"
228    ]
229   },
230   {
231    "cell_type": "code",
232    "execution_count": null,
233    "id": "cee55474-804c-4446-ab7b-ca95aa522089",
234    "metadata": {},
235    "outputs": [],
236    "source": [
237     "keys = list(dat.keys())"
238    ]
239   },
240   {
241    "cell_type": "code",
242    "execution_count": null,
243    "id": "bfd9aa0c-4d58-4af2-b0b2-30592fe7b4f6",
244    "metadata": {},
245    "outputs": [],
246    "source": [
247     "def process_key(key):  # NOTE: runs in a worker process, so in-place edits to dat[key] never reach the parent's dat\n",
248     "    build_features_single(dat[key], atm=\"HRRR\", fstep=\"f03\", fprev=\"f02\")\n",
249     "    return key, dat[key]  # send the featurized case back so the parent can store it"
250    ]
251   },
252   {
253    "cell_type": "code",
254    "execution_count": null,
255    "id": "50836818-1961-4012-b820-9930939b3a8a",
256    "metadata": {},
257    "outputs": [],
258    "source": [
259     "from multiprocessing import Pool"
260    ]
261   },
262   {
263    "cell_type": "code",
264    "execution_count": null,
265    "id": "890b5fce-3dcc-47b9-8ab7-2651582ffdb5",
266    "metadata": {},
267    "outputs": [],
268    "source": [
269     "with Pool() as pool:  # likely needs the 'fork' start method: process_key lives in the notebook, not an importable module\n",
270     "    dat.update(pool.map(process_key, keys))  # write results back; worker-side mutation alone is lost"
271    ]
272   },
273   {
274    "cell_type": "code",
275    "execution_count": null,
276    "id": "3c4548ae-caa4-4bc4-9122-9f24e7e59ef7",
277    "metadata": {},
278    "outputs": [],
279    "source": []
280   },
281   {
282    "cell_type": "code",
283    "execution_count": null,
284    "id": "3dbb6f24-4435-47b3-90c6-6176582b0d4c",
285    "metadata": {},
286    "outputs": [],
287    "source": []
288   },
289   {
290    "cell_type": "markdown",
291    "id": "6322f0bc-107d-40a5-96dc-804495085a99",
292    "metadata": {
293     "jp-MarkdownHeadingCollapsed": true
294    },
295    "source": [
296     "## Test Other ML"
297    ]
298   },
299   {
300    "cell_type": "code",
301    "execution_count": null,
302    "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
303    "metadata": {},
304    "outputs": [],
305    "source": [
306     "params = read_yml(\"params.yaml\", subkey='xgb')\n",
307     "params"
308    ]
309   },
310   {
311    "cell_type": "code",
312    "execution_count": null,
313    "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
314    "metadata": {},
315    "outputs": [],
316    "source": [
317     "dat = read_pkl(\"data/train.pkl\")"
318    ]
319   },
320   {
321    "cell_type": "code",
322    "execution_count": null,
323    "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
324    "metadata": {},
325    "outputs": [],
326    "source": [
327     "cases = [*dat.keys()]"
328    ]
329   },
330   {
331    "cell_type": "code",
332    "execution_count": null,
333    "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
334    "metadata": {},
335    "outputs": [],
336    "source": [
337     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
338     "rnn_dat.train_test_split(\n",
339     "    time_fracs = [.8, .1, .1]\n",
340     ")\n",
341     "rnn_dat.scale_data()"
342    ]
343   },
344   {
345    "cell_type": "code",
346    "execution_count": null,
347    "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
348    "metadata": {},
349    "outputs": [],
350    "source": [
351     "from moisture_models import XGB, RF, LM"
352    ]
353   },
354   {
355    "cell_type": "code",
356    "execution_count": null,
357    "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
358    "metadata": {},
359    "outputs": [],
360    "source": [
361     "mod = XGB(params)"
362    ]
363   },
364   {
365    "cell_type": "code",
366    "execution_count": null,
367    "id": "cae9a20d-1caf-45aa-a9c4-aef21b65d9c8",
368    "metadata": {},
369    "outputs": [],
370    "source": [
371     "mod.params"
372    ]
373   },
374   {
375    "cell_type": "code",
376    "execution_count": null,
377    "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
378    "metadata": {},
379    "outputs": [],
380    "source": [
381     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
382    ]
383   },
384   {
385    "cell_type": "code",
386    "execution_count": null,
387    "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
388    "metadata": {},
389    "outputs": [],
390    "source": [
391     "preds = mod.predict(rnn_dat.X_test)"
392    ]
393   },
394   {
395    "cell_type": "code",
396    "execution_count": null,
397    "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
398    "metadata": {},
399    "outputs": [],
400    "source": [
401     "rmse(preds, rnn_dat.y_test)"
402    ]
403   },
404   {
405    "cell_type": "code",
406    "execution_count": null,
407    "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
408    "metadata": {},
409    "outputs": [],
410    "source": [
411     "plt.plot(rnn_dat.y_test)\n",
412     "plt.plot(preds)"
413    ]
414   },
415   {
416    "cell_type": "code",
417    "execution_count": null,
418    "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
419    "metadata": {},
420    "outputs": [],
421    "source": [
422     "params = read_yml(\"params.yaml\", subkey='rf')\n",
423     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
424     "rnn_dat.train_test_split(\n",
425     "    time_fracs = [.8, .1, .1]\n",
426     ")"
427    ]
428   },
429   {
430    "cell_type": "code",
431    "execution_count": null,
432    "id": "cafe711a-20cb-4bd3-a4bc-4995a843a021",
433    "metadata": {},
434    "outputs": [],
435    "source": [
436     "import importlib\n",
437     "import moisture_models\n",
438     "importlib.reload(moisture_models)"
439    ]
440   },
441   {
442    "cell_type": "code",
443    "execution_count": null,
444    "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
445    "metadata": {},
446    "outputs": [],
447    "source": [
448     "params"
449    ]
450   },
451   {
452    "cell_type": "code",
453    "execution_count": null,
454    "id": "fafe76e5-0212-4bd1-a058-535935a08780",
455    "metadata": {},
456    "outputs": [],
457    "source": [
458     "mod2 = RF(params)\n",
459     "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
460     "preds2 = mod2.predict(rnn_dat.X_test)\n",
461     "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
462     "plt.plot(rnn_dat.y_test)\n",
463     "plt.plot(preds2)"
464    ]
465   },
466   {
467    "cell_type": "code",
468    "execution_count": null,
469    "id": "c0ab4244-996c-49af-bf4a-8b0c47b0b6db",
470    "metadata": {},
471    "outputs": [],
472    "source": [
473     "from moisture_models import RF\n",
474     "mod2 = RF(params)"
475    ]
476   },
477   {
478    "cell_type": "code",
479    "execution_count": null,
480    "id": "aa6c33fd-db35-4c77-9eee-fdb39a934959",
481    "metadata": {},
482    "outputs": [],
483    "source": []
484   },
485   {
486    "cell_type": "code",
487    "execution_count": null,
488    "id": "c5598bfe-2d87-4d23-869e-aff127782462",
489    "metadata": {},
490    "outputs": [],
491    "source": [
492     "params = read_yml(\"params.yaml\", subkey='lm')\n",
493     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
494     "rnn_dat.train_test_split(\n",
495     "    time_fracs = [.8, .1, .1]\n",
496     ")\n",
497     "mod = LM(params)"
498    ]
499   },
500   {
501    "cell_type": "code",
502    "execution_count": null,
503    "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
504    "metadata": {},
505    "outputs": [],
506    "source": [
507     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
508     "preds = mod.predict(rnn_dat.X_test)\n",
509     "print(rmse(np.ravel(preds), rnn_dat.y_test.flatten()))  # score the LM's preds (was RF's preds2); ravel avoids (n,1) vs (n,) broadcasting"
510    ]
511   },
512   {
513    "cell_type": "code",
514    "execution_count": null,
515    "id": "8496a32a-8269-4d6b-953e-7f33fe626789",
516    "metadata": {},
517    "outputs": [],
518    "source": []
519   },
520   {
521    "cell_type": "code",
522    "execution_count": null,
523    "id": "75ce8bf3-6efb-4dc7-b895-def92f6ce6b4",
524    "metadata": {},
525    "outputs": [],
526    "source": []
527   },
528   {
529    "cell_type": "markdown",
530    "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
531    "metadata": {},
532    "source": [
533     "## RNN"
534    ]
535   },
536   {
537    "cell_type": "code",
538    "execution_count": null,
539    "id": "fa38f35a-d367-4df8-b2d3-7691ff4b0cf4",
540    "metadata": {},
541    "outputs": [],
542    "source": []
543   },
544   {
545    "cell_type": "markdown",
546    "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
547    "metadata": {},
548    "source": [
549     "## Phys Initialized (note: `phys_initialize` is currently set to False below)"
550    ]
551   },
552   {
553    "cell_type": "code",
554    "execution_count": null,
555    "id": "5488628e-4552-4909-83e9-413fd6878bdd",
556    "metadata": {},
557    "outputs": [],
558    "source": [
559     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))  # reload RNN params: in run-all order `params` was last the 'lm' config\n",
560     "params.update({'epochs': 100,\n",
561     "    'dense_layers': 0,\n",
562     "    'activation': ['relu', 'relu'],\n",
563     "    'phys_initialize': False,\n",
564     "    'dropout': [0,0],\n",
565     "    'space_fracs': [.8, .1, .1],\n",
566     "    'scaler': None\n",
567     "})"
568    ]
569   },
570   {
571    "cell_type": "code",
572    "execution_count": null,
573    "id": "ab7db7d6-949e-457d-90b9-22d9c5aa4739",
574    "metadata": {},
575    "outputs": [],
576    "source": [
577     "import importlib\n",
578     "import moisture_rnn\n",
579     "importlib.reload(moisture_rnn)\n",
580     "from moisture_rnn import rnn_data_wrap"
581    ]
582   },
583   {
584    "cell_type": "code",
585    "execution_count": null,
586    "id": "d26cf1b2-2fad-409d-888f-4921b0ae4ba8",
587    "metadata": {},
588    "outputs": [],
589    "source": [
590     "params['scaler'] is None"
591    ]
592   },
593   {
594    "cell_type": "code",
595    "execution_count": null,
596    "id": "1c4627bc-0f90-44e6-9103-2efe5c5f439d",
597    "metadata": {},
598    "outputs": [],
599    "source": [
600     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
601    ]
602   },
603   {
604    "cell_type": "code",
605    "execution_count": null,
606    "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
607    "metadata": {},
608    "outputs": [],
609    "source": [
610     "reproducibility.set_seed()\n",
611     "rnn = RNN(params)\n",
612     "m, errs = rnn.run_model(rnn_dat)"
613    ]
614   },
615   {
616    "cell_type": "code",
617    "execution_count": null,
618    "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
619    "metadata": {},
620    "outputs": [],
621    "source": [
622     "rnn.model_train.summary()"
623    ]
624   },
625   {
626    "cell_type": "code",
627    "execution_count": null,
628    "id": "918a8bf0-638b-4b4b-82fe-c6a1965a72dd",
629    "metadata": {},
630    "outputs": [],
631    "source": [
632     "errs.mean()"
633    ]
634   },
635   {
636    "cell_type": "code",
637    "execution_count": null,
638    "id": "37fdbb3a-3e83-4541-93b2-982b6d4cbe93",
639    "metadata": {
640     "scrolled": true
641    },
642    "outputs": [],
643    "source": [
644     "rnn_dat.X_train[:,:,0].mean()"
645    ]
646   },
647   {
648    "cell_type": "code",
649    "execution_count": null,
650    "id": "7ca41db1-72aa-44b6-b9dd-058735336ab3",
651    "metadata": {},
652    "outputs": [],
653    "source": []
654   },
655   {
656    "cell_type": "code",
657    "execution_count": null,
658    "id": "a592a4c9-cb3b-4174-8eaa-02afd00a1897",
659    "metadata": {},
660    "outputs": [],
661    "source": [
662     "rnn_dat.features_list  # attribute access, as used with rnn_data_wrap output elsewhere in this notebook"
663    ]
664   },
665   {
666    "cell_type": "code",
667    "execution_count": null,
668    "id": "3832fb05-417c-4648-8e2e-7748c06b3768",
669    "metadata": {},
670    "outputs": [],
671    "source": []
672   },
673   {
674    "cell_type": "markdown",
675    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
676    "metadata": {},
677    "source": [
678     "## LSTM"
679    ]
680   },
681   {
682    "cell_type": "code",
683    "execution_count": null,
684    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
685    "metadata": {},
686    "outputs": [],
687    "source": [
688     "import importlib \n",
689     "import moisture_rnn\n",
690     "importlib.reload(moisture_rnn)\n",
691     "from moisture_rnn import RNN_LSTM"
692    ]
693   },
694   {
695    "cell_type": "code",
696    "execution_count": null,
697    "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
698    "metadata": {},
699    "outputs": [],
700    "source": [
701     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
702     "params = RNNParams(params)"
703    ]
704   },
705   {
706    "cell_type": "code",
707    "execution_count": null,
708    "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
709    "metadata": {},
710    "outputs": [],
711    "source": [
712     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
713    ]
714   },
715   {
716    "cell_type": "code",
717    "execution_count": null,
718    "id": "57bb5708-7be9-4474-abb4-3b7ff4bf79df",
719    "metadata": {},
720    "outputs": [],
721    "source": [
722     "params.update({\n",
723     "    'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
724     "})"
725    ]
726   },
727   {
728    "cell_type": "code",
729    "execution_count": null,
730    "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
731    "metadata": {},
732    "outputs": [],
733    "source": [
734     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
735     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
736     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
737     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
738     "reproducibility.set_seed(123)\n",
739     "lstm = RNN_LSTM(params)\n",
740     "\n",
741     "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
742     "                    batch_size = params['batch_size'], epochs=params['epochs'], \n",
743     "                    callbacks = [ResetStatesCallback(params),\n",
744     "                                EarlyStoppingCallback(patience = 15)],\n",
745     "                   validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
746     "              "
747    ]
748   },
749   {
750    "cell_type": "code",
751    "execution_count": null,
752    "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
753    "metadata": {},
754    "outputs": [],
755    "source": [
756     "min(history.history['val_loss'])  # errs would be stale here (from an earlier run_model); report this fit's best validation loss"
757    ]
758   },
759   {
760    "cell_type": "code",
761    "execution_count": null,
762    "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
763    "metadata": {},
764    "outputs": [],
765    "source": []
766   },
767   {
768    "cell_type": "code",
769    "execution_count": null,
770    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
771    "metadata": {},
772    "outputs": [],
773    "source": [
774     "params = RNNParams(read_yml(\"params.yaml\", subkey=\"lstm\"))\n",
775     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)\n",
776     "# Derive data-dependent settings from the rnn_dat just built, not a stale one from an earlier cell\n",
777     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
778     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
779     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours,\n",
780     "              'early_stopping_patience': 25,\n",
781     "              'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
782     "              })\n",
783     "reproducibility.set_seed(123)\n",
784     "lstm = RNN_LSTM(params)\n",
785     "m, errs = lstm.run_model(rnn_dat)"
786    ]
787   },
788   {
789    "cell_type": "code",
790    "execution_count": null,
791    "id": "be46a2dc-bf5c-4893-a1ee-a1682566f7a2",
792    "metadata": {},
793    "outputs": [],
794    "source": [
795     "errs.mean()"
796    ]
797   },
798   {
799    "cell_type": "code",
800    "execution_count": null,
801    "id": "0f319f37-7d13-41fd-95fa-66dbdfeab588",
802    "metadata": {},
803    "outputs": [],
804    "source": []
805   },
806   {
807    "cell_type": "code",
808    "execution_count": null,
809    "id": "b1252b08-62b9-4d24-add2-0f87d15b0ff2",
810    "metadata": {},
811    "outputs": [],
812    "source": [
813     "params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
814     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
815    ]
816   },
817   {
818    "cell_type": "code",
819    "execution_count": null,
820    "id": "9281540b-eb26-4923-883b-1b31d8347634",
821    "metadata": {},
822    "outputs": [],
823    "source": [
824     "reproducibility.set_seed(123)\n",
825     "rnn = RNN(params)\n",
826     "m, errs = rnn.run_model(rnn_dat)"
827    ]
828   },
829   {
830    "cell_type": "code",
831    "execution_count": null,
832    "id": "8a0269b4-d6b7-4f20-8386-69814d7acaa3",
833    "metadata": {},
834    "outputs": [],
835    "source": [
836     "errs.mean()"
837    ]
838   },
839   {
840    "cell_type": "code",
841    "execution_count": null,
842    "id": "10b44de3-a0e9-49e4-9e03-873d69580c07",
843    "metadata": {},
844    "outputs": [],
845    "source": []
846   },
847   {
848    "cell_type": "code",
849    "execution_count": null,
850    "id": "27f4fee4-7fce-49c5-a455-97a90b754c13",
851    "metadata": {},
852    "outputs": [],
853    "source": []
854   },
855   {
856    "cell_type": "code",
857    "execution_count": null,
858    "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
859    "metadata": {},
860    "outputs": [],
861    "source": []
862   }
863  ],
864  "metadata": {
865   "kernelspec": {
866    "display_name": "Python 3 (ipykernel)",
867    "language": "python",
868    "name": "python3"
869   },
870   "language_info": {
871    "codemirror_mode": {
872     "name": "ipython",
873     "version": 3
874    },
875    "file_extension": ".py",
876    "mimetype": "text/x-python",
877    "name": "python",
878    "nbconvert_exporter": "python",
879    "pygments_lexer": "ipython3",
880    "version": "3.12.5"
881   }
882  },
883  "nbformat": 4,
884  "nbformat_minor": 5