Process code
[notebooks.git] / fmda / rnn_workshop.ipynb
blob878daaf179e924096043f0b15cd536c10a02f0d5
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.2 exploration: trying to make the RNN work better"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import os\n",
20     "import os.path as osp\n",
21     "import numpy as np\n",
22     "import pandas as pd\n",
23     "import tensorflow as tf\n",
24     "import matplotlib.pyplot as plt\n",
25     "import sys\n",
26     "# Local modules\n",
27     "sys.path.append('..')\n",
28     "import reproducibility\n",
29     "import pandas as pd\n",
30     "from utils import print_dict_summary\n",
31     "from data_funcs import rmse, build_train_dict, combine_nested, subset_by_features\n",
32     "# from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, rnn_data_wrap\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from tensorflow.keras.callbacks import Callback\n",
35     "from utils import hash2\n",
36     "import copy\n",
37     "import logging\n",
38     "import pickle\n",
39     "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights, str2time\n",
40     "import yaml\n",
41     "import copy\n",
42     "import time"
43    ]
44   },
45   {
46    "cell_type": "code",
47    "execution_count": null,
48    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
49    "metadata": {},
50    "outputs": [],
51    "source": [
52     "logging_setup()"
53    ]
54   },
55   {
56    "cell_type": "markdown",
57    "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
58    "metadata": {},
59    "source": [
60     "## Test Data"
61    ]
62   },
63   {
64    "cell_type": "code",
65    "execution_count": null,
66    "id": "3efed1fa-9cda-4934-8a6c-edcf179c8755",
67    "metadata": {},
68    "outputs": [],
69    "source": [
70     "file_paths = ['data/fmda_rocky_202403-05_f05.pkl']"
71    ]
72   },
73   {
74    "cell_type": "code",
75    "execution_count": null,
76    "id": "28fd3746-1861-4afa-ab7e-ac449fbed322",
77    "metadata": {},
78    "outputs": [],
79    "source": [
80     "# Params used for data filtering\n",
81     "params_data = read_yml(\"params_data.yaml\") \n",
82     "params_data"
83    ]
84   },
85   {
86    "cell_type": "code",
87    "execution_count": null,
88    "id": "c45cb8ef-41fc-4bf7-b506-dad5fd24abb3",
89    "metadata": {},
90    "outputs": [],
91    "source": [
92     "dat = read_pkl(file_paths[0])"
93    ]
94   },
95   {
96    "cell_type": "code",
97    "execution_count": null,
98    "id": "3c960d69-4f8a-4abb-a5d9-ed6cf98f899b",
99    "metadata": {},
100    "outputs": [],
101    "source": [
102     "import importlib\n",
103     "import data_funcs\n",
104     "importlib.reload(data_funcs)\n",
105     "from data_funcs import build_train_dict"
106    ]
107   },
108   {
109    "cell_type": "code",
110    "execution_count": null,
111    "id": "369cd913-85cb-4855-a80c-817d84637852",
112    "metadata": {},
113    "outputs": [],
114    "source": [
115     "params_data.update({'hours': None})"
116    ]
117   },
118   {
119    "cell_type": "code",
120    "execution_count": null,
121    "id": "8cdc2ce8-45b4-4caa-81d9-646271ff2e97",
122    "metadata": {
123     "scrolled": true
124    },
125    "outputs": [],
126    "source": [
127     "train3 = build_train_dict(file_paths, params_data, spatial=False, forecast_step=3, drop_na=True)\n"
128    ]
129   },
130   {
131    "cell_type": "code",
132    "execution_count": null,
133    "id": "3c4548ae-caa4-4bc4-9122-9f24e7e59ef7",
134    "metadata": {},
135    "outputs": [],
136    "source": []
137   },
138   {
139    "cell_type": "code",
140    "execution_count": null,
141    "id": "3dbb6f24-4435-47b3-90c6-6176582b0d4c",
142    "metadata": {},
143    "outputs": [],
144    "source": []
145   },
146   {
147    "cell_type": "markdown",
148    "id": "6322f0bc-107d-40a5-96dc-804495085a99",
149    "metadata": {
150     "jp-MarkdownHeadingCollapsed": true
151    },
152    "source": [
153     "## Test Other ML Models"
154    ]
155   },
156   {
157    "cell_type": "code",
158    "execution_count": null,
159    "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
160    "metadata": {},
161    "outputs": [],
162    "source": [
163     "params = read_yml(\"params.yaml\", subkey='xgb')\n",
164     "params"
165    ]
166   },
167   {
168    "cell_type": "code",
169    "execution_count": null,
170    "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
171    "metadata": {},
172    "outputs": [],
173    "source": [
174     "dat = read_pkl(\"data/train.pkl\")"
175    ]
176   },
177   {
178    "cell_type": "code",
179    "execution_count": null,
180    "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
181    "metadata": {},
182    "outputs": [],
183    "source": [
184     "cases = [*dat.keys()]"
185    ]
186   },
187   {
188    "cell_type": "code",
189    "execution_count": null,
190    "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
191    "metadata": {},
192    "outputs": [],
193    "source": [
194     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
195     "rnn_dat.train_test_split(\n",
196     "    time_fracs = [.8, .1, .1]\n",
197     ")\n",
198     "rnn_dat.scale_data()"
199    ]
200   },
201   {
202    "cell_type": "code",
203    "execution_count": null,
204    "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
205    "metadata": {},
206    "outputs": [],
207    "source": [
208     "from moisture_models import XGB, RF, LM"
209    ]
210   },
211   {
212    "cell_type": "code",
213    "execution_count": null,
214    "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
215    "metadata": {},
216    "outputs": [],
217    "source": [
218     "mod = XGB(params)"
219    ]
220   },
221   {
222    "cell_type": "code",
223    "execution_count": null,
224    "id": "cae9a20d-1caf-45aa-a9c4-aef21b65d9c8",
225    "metadata": {},
226    "outputs": [],
227    "source": [
228     "mod.params"
229    ]
230   },
231   {
232    "cell_type": "code",
233    "execution_count": null,
234    "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
235    "metadata": {},
236    "outputs": [],
237    "source": [
238     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
239    ]
240   },
241   {
242    "cell_type": "code",
243    "execution_count": null,
244    "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
245    "metadata": {},
246    "outputs": [],
247    "source": [
248     "preds = mod.predict(rnn_dat.X_test)"
249    ]
250   },
251   {
252    "cell_type": "code",
253    "execution_count": null,
254    "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
255    "metadata": {},
256    "outputs": [],
257    "source": [
258     "rmse(preds, rnn_dat.y_test)"
259    ]
260   },
261   {
262    "cell_type": "code",
263    "execution_count": null,
264    "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
265    "metadata": {},
266    "outputs": [],
267    "source": [
268     "plt.plot(rnn_dat.y_test)\n",
269     "plt.plot(preds)"
270    ]
271   },
272   {
273    "cell_type": "code",
274    "execution_count": null,
275    "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
276    "metadata": {},
277    "outputs": [],
278    "source": [
279     "params = read_yml(\"params.yaml\", subkey='rf')\n",
280     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
281     "rnn_dat.train_test_split(\n",
282     "    time_fracs = [.8, .1, .1]\n",
283     ")"
284    ]
285   },
286   {
287    "cell_type": "code",
288    "execution_count": null,
289    "id": "cafe711a-20cb-4bd3-a4bc-4995a843a021",
290    "metadata": {},
291    "outputs": [],
292    "source": [
293     "import importlib\n",
294     "import moisture_models\n",
295     "importlib.reload(moisture_models)"
296    ]
297   },
298   {
299    "cell_type": "code",
300    "execution_count": null,
301    "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
302    "metadata": {},
303    "outputs": [],
304    "source": [
305     "params"
306    ]
307   },
308   {
309    "cell_type": "code",
310    "execution_count": null,
311    "id": "fafe76e5-0212-4bd1-a058-535935a08780",
312    "metadata": {},
313    "outputs": [],
314    "source": [
315     "mod2 = RF(params)\n",
316     "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
317     "preds2 = mod2.predict(rnn_dat.X_test)\n",
318     "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
319     "plt.plot(rnn_dat.y_test)\n",
320     "plt.plot(preds2)"
321    ]
322   },
323   {
324    "cell_type": "code",
325    "execution_count": null,
326    "id": "c0ab4244-996c-49af-bf4a-8b0c47b0b6db",
327    "metadata": {},
328    "outputs": [],
329    "source": [
330     "from moisture_models import RF\n",
331     "mod2 = RF(params)"
332    ]
333   },
334   {
335    "cell_type": "code",
336    "execution_count": null,
337    "id": "aa6c33fd-db35-4c77-9eee-fdb39a934959",
338    "metadata": {},
339    "outputs": [],
340    "source": []
341   },
342   {
343    "cell_type": "code",
344    "execution_count": null,
345    "id": "c5598bfe-2d87-4d23-869e-aff127782462",
346    "metadata": {},
347    "outputs": [],
348    "source": [
349     "params = read_yml(\"params.yaml\", subkey='lm')\n",
350     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
351     "rnn_dat.train_test_split(\n",
352     "    time_fracs = [.8, .1, .1]\n",
353     ")\n",
354     "mod = LM(params)"
355    ]
356   },
357   {
358    "cell_type": "code",
359    "execution_count": null,
360    "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
361    "metadata": {},
362    "outputs": [],
363    "source": [
364     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
365     "preds = mod.predict(rnn_dat.X_test)\n",
366     "print(rmse(preds2, rnn_dat.y_test.flatten()))"
367    ]
368   },
369   {
370    "cell_type": "code",
371    "execution_count": null,
372    "id": "8496a32a-8269-4d6b-953e-7f33fe626789",
373    "metadata": {},
374    "outputs": [],
375    "source": []
376   },
377   {
378    "cell_type": "code",
379    "execution_count": null,
380    "id": "75ce8bf3-6efb-4dc7-b895-def92f6ce6b4",
381    "metadata": {},
382    "outputs": [],
383    "source": []
384   },
385   {
386    "cell_type": "markdown",
387    "id": "d6e089d9-e466-45bb-80f2-15c563ae21ad",
388    "metadata": {},
389    "source": [
390     "## Class RNN "
391    ]
392   },
393   {
394    "cell_type": "code",
395    "execution_count": null,
396    "id": "3d5792a1-53e3-4099-8630-1bd5e3f52dcc",
397    "metadata": {},
398    "outputs": [],
399    "source": [
400     "from tensorflow.keras import layers,models"
401    ]
402   },
403   {
404    "cell_type": "code",
405    "execution_count": null,
406    "id": "0962428e-1124-4e1f-8500-d02b26640204",
407    "metadata": {},
408    "outputs": [],
409    "source": [
410     "import importlib\n",
411     "import moisture_rnn\n",
412     "importlib.reload(moisture_rnn)\n",
413     "from moisture_rnn import RNN, RNNParams"
414    ]
415   },
416   {
417    "cell_type": "code",
418    "execution_count": null,
419    "id": "a14f9c76-93eb-4b13-a11d-6ccb38285335",
420    "metadata": {},
421    "outputs": [],
422    "source": [
423     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))"
424    ]
425   },
426   {
427    "cell_type": "code",
428    "execution_count": null,
429    "id": "ed3dd798-6a40-4e90-b40b-accabe49fb35",
430    "metadata": {},
431    "outputs": [],
432    "source": [
433     "params.update({\n",
434     "    'hidden_layers': ['lstm', 'conv1d', 'dense'],\n",
435     "    'hidden_units': [32, 32, 16],\n",
436     "    'hidden_activation': ['tanh', 'relu', 'relu'],\n",
437     "    'return_sequences': True\n",
438     "})"
439    ]
440   },
441   {
442    "cell_type": "code",
443    "execution_count": null,
444    "id": "e559d0d7-5847-4fd0-81e4-7d3ca92147dd",
445    "metadata": {},
446    "outputs": [],
447    "source": [
448     "import importlib\n",
449     "import moisture_rnn\n",
450     "importlib.reload(moisture_rnn)\n",
451     "from moisture_rnn import RNN, rnn_data_wrap"
452    ]
453   },
454   {
455    "cell_type": "code",
456    "execution_count": null,
457    "id": "7c1627f9-f011-4159-98a2-1b5973929e71",
458    "metadata": {},
459    "outputs": [],
460    "source": [
461     "reproducibility.set_seed()\n",
462     "mod = RNN(params)"
463    ]
464   },
465   {
466    "cell_type": "code",
467    "execution_count": null,
468    "id": "5dbc66c0-ccb5-46c2-a073-1fa7a5be750a",
469    "metadata": {},
470    "outputs": [],
471    "source": [
472     "mod.model_train.summary()"
473    ]
474   },
475   {
476    "cell_type": "code",
477    "execution_count": null,
478    "id": "882c5872-a017-4d9c-90be-88e692dd33e8",
479    "metadata": {},
480    "outputs": [],
481    "source": [
482     "mod.model_predict.summary()"
483    ]
484   },
485   {
486    "cell_type": "code",
487    "execution_count": null,
488    "id": "30498201-3798-484d-922f-974909b195af",
489    "metadata": {},
490    "outputs": [],
491    "source": [
492     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
493    ]
494   },
495   {
496    "cell_type": "code",
497    "execution_count": null,
498    "id": "e213ffd7-d26c-41ce-8e2b-b17368fdd7a8",
499    "metadata": {},
500    "outputs": [],
501    "source": [
502     "params.update({\n",
503     "    'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
504     "})"
505    ]
506   },
507   {
508    "cell_type": "code",
509    "execution_count": null,
510    "id": "74e599b6-7f4d-4175-a5f1-de892e72ebd4",
511    "metadata": {},
512    "outputs": [],
513    "source": [
514     "m, errs = mod.run_model(rnn_dat)"
515    ]
516   },
517   {
518    "cell_type": "code",
519    "execution_count": null,
520    "id": "f894d203-d277-48f3-bb57-a610f162361f",
521    "metadata": {},
522    "outputs": [],
523    "source": [
524     "errs.mean()"
525    ]
526   },
527   {
528    "cell_type": "code",
529    "execution_count": null,
530    "id": "b875ea70-41f9-4550-982b-88380ad1b5a0",
531    "metadata": {},
532    "outputs": [],
533    "source": [
534     "params"
535    ]
536   },
537   {
538    "cell_type": "markdown",
539    "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
540    "metadata": {},
541    "source": [
542     "## RNN"
543    ]
544   },
545   {
546    "cell_type": "code",
547    "execution_count": null,
548    "id": "8c1894e3-5283-4e5e-83ae-9c386836a990",
549    "metadata": {},
550    "outputs": [],
551    "source": [
552     "import importlib \n",
553     "import moisture_rnn\n",
554     "importlib.reload(moisture_rnn)\n",
555     "from moisture_rnn import RNN"
556    ]
557   },
558   {
559    "cell_type": "code",
560    "execution_count": null,
561    "id": "aa1b690f-edaa-4c97-893c-ec9a3a615ce1",
562    "metadata": {},
563    "outputs": [],
564    "source": [
565     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
566     "params = RNNParams(params)\n",
567     "params.update({\n",
568     "    'dense_layers': 2,\n",
569     "    'dense_units': 32\n",
570     "})"
571    ]
572   },
573   {
574    "cell_type": "code",
575    "execution_count": null,
576    "id": "054ab015-4e41-4255-8b1a-843b61e3d21d",
577    "metadata": {},
578    "outputs": [],
579    "source": [
580     "params.update({'batch_schedule_type': 'step'})"
581    ]
582   },
583   {
584    "cell_type": "code",
585    "execution_count": null,
586    "id": "fa38f35a-d367-4df8-b2d3-7691ff4b0cf4",
587    "metadata": {},
588    "outputs": [],
589    "source": [
590     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)\n",
591     "reproducibility.set_seed(123)\n",
592     "rnn = RNN(params)"
593    ]
594   },
595   {
596    "cell_type": "code",
597    "execution_count": null,
598    "id": "27d11b75-89e9-43a9-8801-7be7fb845b09",
599    "metadata": {},
600    "outputs": [],
601    "source": [
602     "rnn.model_train.summary()"
603    ]
604   },
605   {
606    "cell_type": "code",
607    "execution_count": null,
608    "id": "b9a0b3fb-aaab-4948-b6e6-824e9dcb92a7",
609    "metadata": {},
610    "outputs": [],
611    "source": [
612     "rnn.model_predict.summary()"
613    ]
614   },
615   {
616    "cell_type": "code",
617    "execution_count": null,
618    "id": "ade176b9-2844-43b6-b85e-5bb30414aa35",
619    "metadata": {},
620    "outputs": [],
621    "source": [
622     "rnn.params"
623    ]
624   },
625   {
626    "cell_type": "code",
627    "execution_count": null,
628    "id": "5945e6c1-6b3a-4b7d-ade2-b5788860ef18",
629    "metadata": {},
630    "outputs": [],
631    "source": [
632     "rnn.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, validation_data=(rnn_dat.X_val, rnn_dat.y_val), \n",
633     "                    verbose=True, epochs=20)"
634    ]
635   },
636   {
637    "cell_type": "code",
638    "execution_count": null,
639    "id": "2d123b2b-047e-4a04-b49e-6629cc22edc6",
640    "metadata": {},
641    "outputs": [],
642    "source": [
643     "rnn.model_predict.set_weights(rnn.model_train.get_weights())"
644    ]
645   },
646   {
647    "cell_type": "code",
648    "execution_count": null,
649    "id": "db57df64-d2ac-4b91-bbfc-71a5834ddf41",
650    "metadata": {},
651    "outputs": [],
652    "source": [
653     "rnn.model_predict.summary()"
654    ]
655   },
656   {
657    "cell_type": "code",
658    "execution_count": null,
659    "id": "0466887f-9833-4a6a-a0c7-a4d56f207d33",
660    "metadata": {},
661    "outputs": [],
662    "source": [
663     "rnn_dat.X_test.shape"
664    ]
665   },
666   {
667    "cell_type": "code",
668    "execution_count": null,
669    "id": "1d3e630c-db69-4603-962e-95c576b45ac9",
670    "metadata": {},
671    "outputs": [],
672    "source": [
673     "preds = rnn.model_predict.predict(rnn_dat.X_test)"
674    ]
675   },
676   {
677    "cell_type": "code",
678    "execution_count": null,
679    "id": "8b8228a9-5b6d-4de1-8968-d40277edacd2",
680    "metadata": {},
681    "outputs": [],
682    "source": [
683     "preds.shape"
684    ]
685   },
686   {
687    "cell_type": "code",
688    "execution_count": null,
689    "id": "8b001dd8-ffd7-4fd1-bf11-413515ddc488",
690    "metadata": {},
691    "outputs": [],
692    "source": [
693     "rnn_dat.X_test.shape"
694    ]
695   },
696   {
697    "cell_type": "code",
698    "execution_count": null,
699    "id": "f96c6dbf-6ca8-451e-abc4-b68b8116871b",
700    "metadata": {},
701    "outputs": [],
702    "source": [
703     "squared_diff = np.square(preds - rnn_dat.y_test)\n",
704     "mse = np.mean(squared_diff, axis=(1, 2))\n",
705     "errs = np.sqrt(mse)\n",
706     "errs.mean()"
707    ]
708   },
709   {
710    "cell_type": "markdown",
711    "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
712    "metadata": {},
713    "source": [
714     "## Phys Initialized"
715    ]
716   },
717   {
718    "cell_type": "code",
719    "execution_count": null,
720    "id": "5488628e-4552-4909-83e9-413fd6878bdd",
721    "metadata": {},
722    "outputs": [],
723    "source": [
724     "params.update({\n",
725     "    'epochs':100,\n",
726     "    'dense_layers': 0,\n",
727     "    'activation': ['relu', 'relu'],\n",
728     "    'phys_initialize': False,\n",
729     "    'dropout': [0,0],\n",
730     "    'space_fracs': [.8, .1, .1],\n",
731     "    'scaler': None\n",
732     "})"
733    ]
734   },
735   {
736    "cell_type": "code",
737    "execution_count": null,
738    "id": "ab7db7d6-949e-457d-90b9-22d9c5aa4739",
739    "metadata": {},
740    "outputs": [],
741    "source": [
742     "import importlib\n",
743     "import moisture_rnn\n",
744     "importlib.reload(moisture_rnn)\n",
745     "from moisture_rnn import rnn_data_wrap"
746    ]
747   },
748   {
749    "cell_type": "code",
750    "execution_count": null,
751    "id": "d26cf1b2-2fad-409d-888f-4921b0ae4ba8",
752    "metadata": {},
753    "outputs": [],
754    "source": [
755     "params['scaler'] is None"
756    ]
757   },
758   {
759    "cell_type": "code",
760    "execution_count": null,
761    "id": "1c4627bc-0f90-44e6-9103-2efe5c5f439d",
762    "metadata": {},
763    "outputs": [],
764    "source": [
765     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
766    ]
767   },
768   {
769    "cell_type": "code",
770    "execution_count": null,
771    "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
772    "metadata": {},
773    "outputs": [],
774    "source": [
775     "reproducibility.set_seed()\n",
776     "rnn = RNN(params)\n",
777     "m, errs = rnn.run_model(rnn_dat)"
778    ]
779   },
780   {
781    "cell_type": "code",
782    "execution_count": null,
783    "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
784    "metadata": {},
785    "outputs": [],
786    "source": [
787     "rnn.model_train.summary()"
788    ]
789   },
790   {
791    "cell_type": "code",
792    "execution_count": null,
793    "id": "918a8bf0-638b-4b4b-82fe-c6a1965a72dd",
794    "metadata": {},
795    "outputs": [],
796    "source": [
797     "errs.mean()"
798    ]
799   },
800   {
801    "cell_type": "code",
802    "execution_count": null,
803    "id": "37fdbb3a-3e83-4541-93b2-982b6d4cbe93",
804    "metadata": {
805     "scrolled": true
806    },
807    "outputs": [],
808    "source": [
809     "rnn_dat.X_train[:,:,0].mean()"
810    ]
811   },
812   {
813    "cell_type": "code",
814    "execution_count": null,
815    "id": "7ca41db1-72aa-44b6-b9dd-058735336ab3",
816    "metadata": {},
817    "outputs": [],
818    "source": []
819   },
820   {
821    "cell_type": "code",
822    "execution_count": null,
823    "id": "a592a4c9-cb3b-4174-8eaa-02afd00a1897",
824    "metadata": {},
825    "outputs": [],
826    "source": [
827     "rnn_dat['features_list']"
828    ]
829   },
830   {
831    "cell_type": "code",
832    "execution_count": null,
833    "id": "3832fb05-417c-4648-8e2e-7748c06b3768",
834    "metadata": {},
835    "outputs": [],
836    "source": []
837   },
838   {
839    "cell_type": "markdown",
840    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
841    "metadata": {},
842    "source": [
843     "## LSTM"
844    ]
845   },
846   {
847    "cell_type": "code",
848    "execution_count": null,
849    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
850    "metadata": {},
851    "outputs": [],
852    "source": [
853     "import importlib \n",
854     "import moisture_rnn\n",
855     "importlib.reload(moisture_rnn)\n",
856     "from moisture_rnn import RNN_LSTM"
857    ]
858   },
859   {
860    "cell_type": "code",
861    "execution_count": null,
862    "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
863    "metadata": {},
864    "outputs": [],
865    "source": [
866     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
867     "params = RNNParams(params)"
868    ]
869   },
870   {
871    "cell_type": "code",
872    "execution_count": null,
873    "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
874    "metadata": {},
875    "outputs": [],
876    "source": [
877     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
878    ]
879   },
880   {
881    "cell_type": "code",
882    "execution_count": null,
883    "id": "57bb5708-7be9-4474-abb4-3b7ff4bf79df",
884    "metadata": {},
885    "outputs": [],
886    "source": [
887     "params.update({\n",
888     "    'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
889     "})"
890    ]
891   },
892   {
893    "cell_type": "code",
894    "execution_count": null,
895    "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
896    "metadata": {},
897    "outputs": [],
898    "source": [
899     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
900     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
901     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
902     "              'batch_schedule_type':'step', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
903     "reproducibility.set_seed(123)\n",
904     "lstm = RNN_LSTM(params)\n",
905     "\n",
906     "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
907     "                    batch_size = params['batch_size'], epochs=params['epochs'], \n",
908     "                    callbacks = [ResetStatesCallback(params),\n",
909     "                                EarlyStoppingCallback(patience = 15)],\n",
910     "                   validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
911     "              "
912    ]
913   },
914   {
915    "cell_type": "code",
916    "execution_count": null,
917    "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
918    "metadata": {},
919    "outputs": [],
920    "source": []
921   },
922   {
923    "cell_type": "code",
924    "execution_count": null,
925    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
926    "metadata": {},
927    "outputs": [],
928    "source": [
929     "params = RNNParams(read_yml(\"params.yaml\", subkey=\"lstm\"))\n",
930     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
931     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
932     "              'batch_schedule_type':'step', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
933     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)\n",
934     "params.update({\n",
935     "    'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
936     "})\n",
937     "reproducibility.set_seed(123)\n",
938     "lstm = RNN_LSTM(params)\n",
939     "m, errs = lstm.run_model(rnn_dat)"
940    ]
941   },
942   {
943    "cell_type": "code",
944    "execution_count": null,
945    "id": "be46a2dc-bf5c-4893-a1ee-a1682566f7a2",
946    "metadata": {},
947    "outputs": [],
948    "source": [
949     "errs.mean()"
950    ]
951   },
952   {
953    "cell_type": "code",
954    "execution_count": null,
955    "id": "0f319f37-7d13-41fd-95fa-66dbdfeab588",
956    "metadata": {},
957    "outputs": [],
958    "source": []
959   },
960   {
961    "cell_type": "code",
962    "execution_count": null,
963    "id": "b1252b08-62b9-4d24-add2-0f87d15b0ff2",
964    "metadata": {},
965    "outputs": [],
966    "source": [
967     "params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
968     "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
969    ]
970   },
971   {
972    "cell_type": "code",
973    "execution_count": null,
974    "id": "9281540b-eb26-4923-883b-1b31d8347634",
975    "metadata": {},
976    "outputs": [],
977    "source": [
978     "reproducibility.set_seed(123)\n",
979     "rnn = RNN(params)\n",
980     "m, errs = rnn.run_model(rnn_dat)"
981    ]
982   },
983   {
984    "cell_type": "code",
985    "execution_count": null,
986    "id": "8a0269b4-d6b7-4f20-8386-69814d7acaa3",
987    "metadata": {},
988    "outputs": [],
989    "source": [
990     "errs.mean()"
991    ]
992   },
993   {
994    "cell_type": "code",
995    "execution_count": null,
996    "id": "10b44de3-a0e9-49e4-9e03-873d69580c07",
997    "metadata": {},
998    "outputs": [],
999    "source": []
1000   },
1001   {
1002    "cell_type": "code",
1003    "execution_count": null,
1004    "id": "27f4fee4-7fce-49c5-a455-97a90b754c13",
1005    "metadata": {},
1006    "outputs": [],
1007    "source": []
1008   },
1009   {
1010    "cell_type": "code",
1011    "execution_count": null,
1012    "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
1013    "metadata": {},
1014    "outputs": [],
1015    "source": []
1016   }
1017  ],
1018  "metadata": {
1019   "kernelspec": {
1020    "display_name": "Python 3 (ipykernel)",
1021    "language": "python",
1022    "name": "python3"
1023   },
1024   "language_info": {
1025    "codemirror_mode": {
1026     "name": "ipython",
1027     "version": 3
1028    },
1029    "file_extension": ".py",
1030    "mimetype": "text/x-python",
1031    "name": "python",
1032    "nbconvert_exporter": "python",
1033    "pygments_lexer": "ipython3",
1034    "version": "3.12.5"
1035   }
1036  },
1037  "nbformat": 4,
1038  "nbformat_minor": 5