Update moisture_rnn.py
[notebooks.git] / fmda / rnn_workshop.ipynb
blob7a526a1c646407857386ddbada168b62fe48b7b6
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.1 Exploration: Experiments to Improve the Moisture RNN"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import os\n",
20     "import os.path as osp\n",
21     "import numpy as np\n",
22     "import pandas as pd\n",
23     "import tensorflow as tf\n",
24     "import matplotlib.pyplot as plt\n",
25     "import sys\n",
26     "# Local modules\n",
27     "sys.path.append('..')\n",
28     "import reproducibility\n",
29     "import pandas as pd\n",
30     "from utils import print_dict_summary\n",
31     "from data_funcs import rmse\n",
32     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, create_rnn_data2\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from tensorflow.keras.callbacks import Callback\n",
35     "from utils import hash2\n",
36     "import copy\n",
37     "import logging\n",
38     "import pickle\n",
39     "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights\n",
40     "import yaml\n",
41     "import copy"
42    ]
43   },
44   {
45    "cell_type": "code",
46    "execution_count": null,
47    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
48    "metadata": {},
49    "outputs": [],
50    "source": [
51     "logging_setup()"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "b8fe1011-a0cc-46a4-98b7-d82c2b22f5b0",
57    "metadata": {},
58    "source": [
59     "## Test Batch Reset"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "0df1c817-d422-4cfa-a4c5-b02549cdaffa",
66    "metadata": {
67     "scrolled": true
68    },
69    "outputs": [],
70    "source": [
71     "train = read_pkl('train.pkl')\n",
72     "train.keys()"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "4f2623ea-3504-446e-8243-f93ccce6b62e",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "import importlib\n",
83     "import moisture_rnn\n",
84     "importlib.reload(moisture_rnn)\n",
85     "from moisture_rnn import RNN, RNNData"
86    ]
87   },
88   {
89    "cell_type": "code",
90    "execution_count": null,
91    "id": "948138a6-1854-428c-b5ec-75e87c9c50e7",
92    "metadata": {},
93    "outputs": [],
94    "source": [
95     "params = read_yml(\"params.yaml\", subkey=\"rnn\")\n",
96     "params = RNNParams(params)\n",
97     "rnn_dat = RNNData(train['PLFI1_202401'], scaler=params['scaler'], features_list = params['features_list'])\n",
98     "rnn_dat.train_test_split(\n",
99     "    train_frac = .9,\n",
100     "    val_frac = .05\n",
101     ")\n",
102     "rnn_dat.scale_data()\n",
103     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
104    ]
105   },
106   {
107    "cell_type": "code",
108    "execution_count": null,
109    "id": "78255617-5511-4e54-a022-8dba9946bbe2",
110    "metadata": {},
111    "outputs": [],
112    "source": [
113     "reproducibility.set_seed()\n",
114     "params.update({'batch_schedule_type': 'exp', 'bmin': 20, 'bmax': rnn_dat.hours})\n",
115     "rnn = RNN(params)\n",
116     "m, errs = rnn.run_model(rnn_dat, plot_period=\"predict\")"
117    ]
118   },
119   {
120    "cell_type": "code",
121    "execution_count": null,
122    "id": "d3a32de3-9556-491a-9bf1-3d252762f2b7",
123    "metadata": {},
124    "outputs": [],
125    "source": []
126   },
127   {
128    "cell_type": "code",
129    "execution_count": null,
130    "id": "2a7d607c-8f29-4a18-948b-4d939ebd5a34",
131    "metadata": {},
132    "outputs": [],
133    "source": []
134   },
135   {
136    "cell_type": "code",
137    "execution_count": null,
138    "id": "552c6e02-4a2d-4f50-9d6a-7e11bdbcfffc",
139    "metadata": {},
140    "outputs": [],
141    "source": []
142   },
143   {
144    "cell_type": "code",
145    "execution_count": null,
146    "id": "5cfd0dbe-8e7d-4d9e-a21c-9001a498084c",
147    "metadata": {},
148    "outputs": [],
149    "source": [
150     "params.update({'epochs': 2, 'verbose_fit': True, 'batch_size': 32, \n",
151     "        'rnn_layers': 2, 'activation':['relu', 'relu']})\n",
152     "rnn_dat = RNNData(train['PLFI1_202401'], scaler=params['scaler'], features_list = params['features_list'])\n",
153     "rnn_dat.train_test_split(\n",
154     "    train_frac = .9,\n",
155     "    val_frac = .05\n",
156     ")\n",
157     "rnn_dat.scale_data()\n",
158     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
159     "reproducibility.set_seed()\n",
160     "rnn = RNN(params)\n",
161     "m, errs = rnn.run_model(rnn_dat, plot_period=\"predict\")"
162    ]
163   },
164   {
165    "cell_type": "code",
166    "execution_count": null,
167    "id": "db2abad4-16d4-4afc-a0d8-b2dec6b872c2",
168    "metadata": {},
169    "outputs": [],
170    "source": []
171   },
172   {
173    "cell_type": "code",
174    "execution_count": null,
175    "id": "74158e6e-c84f-4a90-9f0a-c35cb711d9ed",
176    "metadata": {},
177    "outputs": [],
178    "source": []
179   },
180   {
181    "cell_type": "code",
182    "execution_count": null,
183    "id": "c022cce2-8863-43f4-96e8-c604ba2fe8bc",
184    "metadata": {},
185    "outputs": [],
186    "source": []
187   },
188   {
189    "cell_type": "code",
190    "execution_count": null,
191    "id": "81037c9e-088f-4dd2-bbfa-5b9b2d044d80",
192    "metadata": {},
193    "outputs": [],
194    "source": []
195   },
196   {
197    "cell_type": "code",
198    "execution_count": null,
199    "id": "9aaaa2ff-6757-48a6-b03b-568d2b0d01b0",
200    "metadata": {},
201    "outputs": [],
202    "source": []
203   },
204   {
205    "cell_type": "code",
206    "execution_count": null,
207    "id": "ca8637d2-d111-4054-b174-ef913f0d9206",
208    "metadata": {},
209    "outputs": [],
210    "source": []
211   },
212   {
213    "cell_type": "code",
214    "execution_count": null,
215    "id": "12217194-a9be-49bc-99ef-6b2a107ab4f3",
216    "metadata": {},
217    "outputs": [],
218    "source": []
219   },
220   {
221    "cell_type": "code",
222    "execution_count": null,
223    "id": "7c3c428e-c628-4712-bdc7-15dce26837dd",
224    "metadata": {},
225    "outputs": [],
226    "source": []
227   },
228   {
229    "cell_type": "code",
230    "execution_count": null,
231    "id": "84fe9438-2f2e-483c-a5a4-9dd3a61f50aa",
232    "metadata": {},
233    "outputs": [],
234    "source": []
235   },
236   {
237    "cell_type": "code",
238    "execution_count": null,
239    "id": "77d90197-2b50-4621-9eee-def323ed836e",
240    "metadata": {},
241    "outputs": [],
242    "source": []
243   },
244   {
245    "cell_type": "code",
246    "execution_count": null,
247    "id": "e492ae99-ea1f-4185-9a33-41573263f2f1",
248    "metadata": {},
249    "outputs": [],
250    "source": []
251   },
252   {
253    "cell_type": "code",
254    "execution_count": null,
255    "id": "56b18e34-b50b-48a6-947a-5714b65e85cf",
256    "metadata": {},
257    "outputs": [],
258    "source": [
259     "# Compare the exponential vs. logarithmic interval schedules from bmin to bmax over `ep` epochs.\n",
260     "# Explicit import: these helpers are not among the notebook's top-level imports (NameError on a fresh kernel).\n",
261     "# NOTE(review): assumes both helpers are defined in moisture_rnn next to the 'exp' batch_schedule_type - confirm.\n",
262     "from moisture_rnn import calc_exp_intervals, calc_log_intervals\n",
263     "\n",
264     "ep = 15\n",
265     "bmin = 10\n",
266     "bmax = 500\n",
267     "xgrid = np.arange(0, ep)\n",
268     "fig, ax = plt.subplots()\n",
269     "ax.plot(xgrid, calc_exp_intervals(bmin, bmax, ep), label='exp')\n",
270     "ax.plot(xgrid, calc_log_intervals(bmin, bmax, ep), label='log')\n",
271     "ax.set(xlabel='Epoch', ylabel='Interval', title='calc_exp_intervals vs calc_log_intervals')\n",
272     "ax.legend();"
265    ]
266   },
267   {
268    "cell_type": "code",
269    "execution_count": null,
270    "id": "8ae6b9d7-f108-4071-a9fc-1b7a32b26d75",
271    "metadata": {},
272    "outputs": [],
273    "source": []
274   },
275   {
276    "cell_type": "code",
277    "execution_count": null,
278    "id": "7be5f53c-6130-4d94-986b-e9b5dce7fdae",
279    "metadata": {},
280    "outputs": [],
281    "source": []
282   },
283   {
284    "cell_type": "code",
285    "execution_count": null,
286    "id": "b0b0a959-f07c-4b62-bee2-faa993320dda",
287    "metadata": {},
288    "outputs": [],
289    "source": []
290   },
291   {
292    "cell_type": "code",
293    "execution_count": null,
294    "id": "23d94e44-cff1-4a2a-9a0e-039bf401d4e3",
295    "metadata": {},
296    "outputs": [],
297    "source": []
298   },
299   {
300    "cell_type": "code",
301    "execution_count": null,
302    "id": "98382367-820a-4aad-97da-3ea2bc895b0f",
303    "metadata": {},
304    "outputs": [],
305    "source": []
306   },
307   {
308    "cell_type": "code",
309    "execution_count": null,
310    "id": "185c6f90-fe7f-4b05-b5ef-20635051f18b",
311    "metadata": {},
312    "outputs": [],
313    "source": []
314   },
315   {
316    "cell_type": "code",
317    "execution_count": null,
318    "id": "5520cd04-72b3-4550-bc54-ad78cbf77ec0",
319    "metadata": {},
320    "outputs": [],
321    "source": []
322   },
323   {
324    "cell_type": "code",
325    "execution_count": null,
326    "id": "fd0a4fa9-603f-4662-ad4e-14df33337441",
327    "metadata": {},
328    "outputs": [],
329    "source": []
330   },
331   {
332    "cell_type": "code",
333    "execution_count": null,
334    "id": "253c667a-748d-48cc-8479-50616e043609",
335    "metadata": {},
336    "outputs": [],
337    "source": []
338   },
339   {
340    "cell_type": "code",
341    "execution_count": null,
342    "id": "90cafb80-72e9-4610-8413-ca407f03dbd0",
343    "metadata": {},
344    "outputs": [],
345    "source": []
346   },
347   {
348    "cell_type": "code",
349    "execution_count": null,
350    "id": "9ebf0ccd-554f-4a0a-87e9-36de0a62a34c",
351    "metadata": {},
352    "outputs": [],
353    "source": []
354   },
355   {
356    "cell_type": "code",
357    "execution_count": null,
358    "id": "23d8d51b-6206-471a-a792-a85f4ad89637",
359    "metadata": {},
360    "outputs": [],
361    "source": []
362   },
363   {
364    "cell_type": "code",
365    "execution_count": null,
366    "id": "d25faf9f-d00f-44a6-a43f-e4a3277aad78",
367    "metadata": {},
368    "outputs": [],
369    "source": []
370   },
371   {
372    "cell_type": "code",
373    "execution_count": null,
374    "id": "e8c47be0-92f6-454f-9e04-385c3cb41831",
375    "metadata": {},
376    "outputs": [],
377    "source": []
378   },
379   {
380    "cell_type": "code",
381    "execution_count": null,
382    "id": "c8d197e6-1959-4a6f-8d2d-c67a97a123f9",
383    "metadata": {},
384    "outputs": [],
385    "source": []
386   },
387   {
388    "cell_type": "code",
389    "execution_count": null,
390    "id": "3aa5aece-79c5-42b7-98a2-3f8a3cfb8e29",
391    "metadata": {},
392    "outputs": [],
393    "source": []
394   },
395   {
396    "cell_type": "code",
397    "execution_count": null,
398    "id": "a9f8a650-330f-493e-a158-a683e2fd872d",
399    "metadata": {},
400    "outputs": [],
401    "source": []
402   },
403   {
404    "cell_type": "markdown",
405    "id": "b62f4360-e9d1-4510-bb5d-1d79a3a5ac75",
406    "metadata": {},
407    "source": [
408     "## Test Spatial Data"
409    ]
410   },
411   {
412    "cell_type": "code",
413    "execution_count": null,
414    "id": "3a04c2d3-3bf1-451d-88bc-7b1e8701cb52",
415    "metadata": {},
416    "outputs": [],
417    "source": [
418     "train = read_pkl('train.pkl')"
419    ]
420   },
421   {
422    "cell_type": "code",
423    "execution_count": null,
424    "id": "3d416f92-995a-427f-b76d-a6125061ee98",
425    "metadata": {},
426    "outputs": [],
427    "source": [
428     "params = read_yml(\"params.yaml\", subkey=\"rnn\")\n",
429     "params = RNNParams(params)"
430    ]
431   },
432   {
433    "cell_type": "code",
434    "execution_count": null,
435    "id": "1d76cd6e-2e0e-40ae-9a58-3ed62217a33d",
436    "metadata": {},
437    "outputs": [],
438    "source": [
439     "len(train.keys())"
440    ]
441   },
442   {
443    "cell_type": "code",
444    "execution_count": null,
445    "id": "4735052a-f046-4d52-8666-ce14e4a0e276",
446    "metadata": {},
447    "outputs": [],
448    "source": [
449     "from itertools import islice\n",
450     "dat = {k: train[k] for k in islice(train, 100)}"
451    ]
452   },
453   {
454    "cell_type": "code",
455    "execution_count": null,
456    "id": "194f815f-c889-43ed-b0e8-853b1c4a8a81",
457    "metadata": {},
458    "outputs": [],
459    "source": [
460     "dat.keys()"
461    ]
462   },
463   {
464    "cell_type": "code",
465    "execution_count": null,
466    "id": "99819622-555b-4027-a644-5f75c76f7fbc",
467    "metadata": {},
468    "outputs": [],
469    "source": [
470     "from data_funcs import combine_nested\n",
471     "dd = combine_nested(dat)"
472    ]
473   },
474   {
475    "cell_type": "code",
476    "execution_count": null,
477    "id": "4a647e8a-61f3-4e5b-abb3-44d43d7e0844",
478    "metadata": {},
479    "outputs": [],
480    "source": [
481     "import importlib\n",
482     "import utils\n",
483     "importlib.reload(utils)\n",
484     "from utils import Dict"
485    ]
486   },
487   {
488    "cell_type": "code",
489    "execution_count": null,
490    "id": "6335deef-ea25-40bd-8a68-842af80cebe8",
491    "metadata": {},
492    "outputs": [],
493    "source": [
494     "dd = Dict(dd)"
495    ]
496   },
497   {
498    "cell_type": "code",
499    "execution_count": null,
500    "id": "e86c9e4d-4ccd-4d9d-92e1-2e4299549fa4",
501    "metadata": {},
502    "outputs": [],
503    "source": [
504     "import importlib\n",
505     "import moisture_rnn\n",
506     "importlib.reload(moisture_rnn)\n",
507     "from moisture_rnn import RNNData"
508    ]
509   },
510   {
511    "cell_type": "code",
512    "execution_count": null,
513    "id": "966c3559-740d-44d3-b98d-cc2efe63afcd",
514    "metadata": {},
515    "outputs": [],
516    "source": [
517     "rnn_dat = RNNData(dd, scaler=\"standard\", features_list = ['Ed', 'Ew', 'rain'])\n",
518     "rnn_dat.train_test_split(   \n",
519     "    train_frac = .9,\n",
520     "    val_frac = .05\n",
521     ")"
522    ]
523   },
524   {
525    "cell_type": "code",
526    "execution_count": null,
527    "id": "72289573-56a1-45ca-8551-b24c4c073bfd",
528    "metadata": {},
529    "outputs": [],
530    "source": [
531     "rnn_dat.scale_data()"
532    ]
533   },
534   {
535    "cell_type": "code",
536    "execution_count": null,
537    "id": "b35a8e1a-a161-42af-a595-2e1bae0fd0ba",
538    "metadata": {},
539    "outputs": [],
540    "source": [
541     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
542    ]
543   },
544   {
545    "cell_type": "code",
546    "execution_count": null,
547    "id": "c1520e93-0df1-41d1-98b3-f50edbe13b66",
548    "metadata": {},
549    "outputs": [],
550    "source": []
551   },
552   {
553    "cell_type": "code",
554    "execution_count": null,
555    "id": "fa14f0ab-07ff-4c67-bdcb-ff225610ffa2",
556    "metadata": {},
557    "outputs": [],
558    "source": [
559     "import importlib\n",
560     "import moisture_rnn\n",
561     "importlib.reload(moisture_rnn)\n",
562     "from moisture_rnn import RNN"
563    ]
564   },
565   {
566    "cell_type": "code",
567    "execution_count": null,
568    "id": "94f2030f-dbc9-4a6c-8e98-b932fe7691c7",
569    "metadata": {},
570    "outputs": [],
571    "source": [
572     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
573     "params.update({'epochs': 20, 'learning_rate': 0.0001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
574     "              'activation': ['relu', 'relu'], 'features_list': ['Ed', 'Ew', 'rain']})\n",
575     "reproducibility.set_seed(123)\n",
576     "rnn = RNN(params)\n",
577     "\n",
578     "history = rnn.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
579     "                    batch_size = params['batch_size'], epochs=params['epochs'], \n",
580     "                    callbacks = [ResetStatesCallback(params),\n",
581     "                                EarlyStoppingCallback(patience = params['early_stopping_patience'])],\n",
582     "                   validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
583     "              "
584    ]
585   },
586   {
587    "cell_type": "code",
588    "execution_count": null,
589    "id": "32c39f16-6d80-44ef-b58a-be12869cd638",
590    "metadata": {},
591    "outputs": [],
592    "source": [
593     "fig, ax = plt.subplots()\n",
594     "ax.semilogy(history.history['loss'], label='Training loss')\n",
595     "if 'val_loss' in history.history:\n",
596     "    ax.semilogy(history.history['val_loss'], label='Validation loss')\n",
597     "ax.set_ylabel('Loss')\n",
598     "ax.set_xlabel('Epoch')\n",
599     "ax.legend(loc='upper left')\n",
600     "plt.show()"
601    ]
602   },
603   {
604    "cell_type": "code",
605    "execution_count": null,
606    "id": "b5a88957-c7c7-4036-85bd-94cc4aa5c08c",
607    "metadata": {},
608    "outputs": [],
609    "source": [
610     "vpreds = rnn.model_train.predict(rnn_dat.X_val)"
611    ]
612   },
613   {
614    "cell_type": "code",
615    "execution_count": null,
616    "id": "f4f1ad13-6d23-4c3a-80f8-32d4cd7a9902",
617    "metadata": {},
618    "outputs": [],
619    "source": [
620     "vpreds.shape"
621    ]
622   },
623   {
624    "cell_type": "code",
625    "execution_count": null,
626    "id": "4e3e62bd-36f9-4ce5-befd-cecf01f13bd1",
627    "metadata": {},
628    "outputs": [],
629    "source": [
630     "rnn_dat.y_val.shape"
631    ]
632   },
633   {
634    "cell_type": "code",
635    "execution_count": null,
636    "id": "308ec7c9-a73b-4405-912a-454811a413ac",
637    "metadata": {},
638    "outputs": [],
639    "source": [
640     "from sklearn.metrics import mean_squared_error"
641    ]
642   },
643   {
644    "cell_type": "code",
645    "execution_count": null,
646    "id": "2d0cd42b-ffbd-413a-a73b-a622865c1b61",
647    "metadata": {},
648    "outputs": [],
649    "source": [
650     "# sklearn's argument order is (y_true, y_pred) and it accepts at most 2-D input; ravel pools any extra axes\n",
651     "mean_squared_error(np.ravel(rnn_dat.y_val), np.ravel(vpreds))"
651    ]
652   },
653   {
654    "cell_type": "code",
655    "execution_count": null,
656    "id": "1983a28d-f6b8-4a35-94f0-022c5ef898d2",
657    "metadata": {},
658    "outputs": [],
659    "source": [
660     "loss = tf.keras.losses.mse(rnn_dat.y_val, vpreds)\n",
661     "loss = tf.reduce_mean(loss).numpy()\n",
662     "loss"
663    ]
664   },
665   {
666    "cell_type": "code",
667    "execution_count": null,
668    "id": "28183109-605e-4392-b5ab-df79776f023a",
669    "metadata": {},
670    "outputs": [],
671    "source": [
672     "plt.scatter(vpreds, rnn_dat.y_val)"
673    ]
674   },
675   {
676    "cell_type": "code",
677    "execution_count": null,
678    "id": "c5335e7b-1d3f-4d76-8452-85b039b386ef",
679    "metadata": {},
680    "outputs": [],
681    "source": [
682     "hash_weights(rnn.model_train)"
683    ]
684   },
685   {
686    "cell_type": "code",
687    "execution_count": null,
688    "id": "1b4464b2-5ea2-4c1c-b092-9a478d5fffe4",
689    "metadata": {},
690    "outputs": [],
691    "source": [
692     "rnn.model_predict.set_weights(rnn.model_train.get_weights())"
693    ]
694   },
695   {
696    "cell_type": "code",
697    "execution_count": null,
698    "id": "c2a681d1-3402-4053-aec5-ecc1a94237b8",
699    "metadata": {},
700    "outputs": [],
701    "source": [
702     "hash_weights(rnn.model_predict)"
703    ]
704   },
705   {
706    "cell_type": "code",
707    "execution_count": null,
708    "id": "e94d82d8-5738-4d66-897c-3ac68036ec95",
709    "metadata": {},
710    "outputs": [],
711    "source": []
712   },
713   {
714    "cell_type": "code",
715    "execution_count": null,
716    "id": "cc74db07-84ed-4e33-949c-5e0548e98007",
717    "metadata": {},
718    "outputs": [],
719    "source": []
720   },
721   {
722    "cell_type": "code",
723    "execution_count": null,
724    "id": "df1309f1-6eac-4e07-b1bb-3c1a4cc9c5bf",
725    "metadata": {},
726    "outputs": [],
727    "source": []
728   },
729   {
730    "cell_type": "code",
731    "execution_count": null,
732    "id": "0166e780-3379-414b-b403-86bca3c36661",
733    "metadata": {},
734    "outputs": [],
735    "source": [
736     "preds = rnn.predict(rnn_dat.X_test[0])"
737    ]
738   },
739   {
740    "cell_type": "code",
741    "execution_count": null,
742    "id": "f878b950-b750-43db-b5e5-32723f9d0f07",
743    "metadata": {},
744    "outputs": [],
745    "source": [
746     "# Use the same test index that produced `preds` (X_test[0]); the original compared against y_test[2]\n",
747     "plt.plot(rnn_dat.y_test[0], label='y_test[0]')\n",
748     "plt.plot(preds, label='preds')\n",
749     "plt.legend();"
748    ]
749   },
750   {
751    "cell_type": "code",
752    "execution_count": null,
753    "id": "3aa8dd59-12f0-46ec-b422-dcd23d8076bf",
754    "metadata": {},
755    "outputs": [],
756    "source": []
757   },
758   {
759    "cell_type": "code",
760    "execution_count": null,
761    "id": "17e09a7a-19ac-4e82-af54-dcced5791669",
762    "metadata": {},
763    "outputs": [],
764    "source": []
765   },
766   {
767    "cell_type": "code",
768    "execution_count": null,
769    "id": "33072593-30a4-49a0-8372-67012e1213eb",
770    "metadata": {},
771    "outputs": [],
772    "source": []
773   },
774   {
775    "cell_type": "code",
776    "execution_count": null,
777    "id": "090b36f5-fd34-4cc7-a84d-f8dd0d582000",
778    "metadata": {},
779    "outputs": [],
780    "source": []
781   },
782   {
783    "cell_type": "code",
784    "execution_count": null,
785    "id": "7deda359-1e7f-447a-97b7-576b98712a74",
786    "metadata": {},
787    "outputs": [],
788    "source": [
789     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
790     "params.update({'epochs': 20, 'learning_rate': 0.0001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
791     "              'activation': ['relu', 'relu'], 'features_list': ['Ed', 'Ew', 'rain']})\n",
792     "reproducibility.set_seed(123)\n",
793     "rnn = RNN(params)"
794    ]
795   },
796   {
797    "cell_type": "code",
798    "execution_count": null,
799    "id": "37d7f239-4d47-46d0-b891-b2a8d9da8c4e",
800    "metadata": {},
801    "outputs": [],
802    "source": [
803     "m, errs = rnn.run_model(rnn_dat, plot_period=\"predict\")"
804    ]
805   },
806   {
807    "cell_type": "code",
808    "execution_count": null,
809    "id": "150754f6-9927-4188-969a-d253bb0a5b22",
810    "metadata": {},
811    "outputs": [],
812    "source": [
813     "len(rnn_dat.X)"
814    ]
815   },
816   {
817    "cell_type": "code",
818    "execution_count": null,
819    "id": "8592cce5-77fe-4804-8df2-de92f058d11f",
820    "metadata": {},
821    "outputs": [],
822    "source": [
823     "len(rnn_dat.X_test)"
824    ]
825   },
826   {
827    "cell_type": "code",
828    "execution_count": null,
829    "id": "3047a7a3-32c2-4af0-aff9-bebdb1a877c1",
830    "metadata": {},
831    "outputs": [],
832    "source": [
833     "preds0 = rnn.predict(rnn_dat.X_test[0])"
834    ]
835   },
836   {
837    "cell_type": "code",
838    "execution_count": null,
839    "id": "2880b410-35a2-4d6c-ac28-2e2f366ec3a2",
840    "metadata": {},
841    "outputs": [],
842    "source": [
843     "rmse(preds0, rnn_dat.y_test[0])"
844    ]
845   },
846   {
847    "cell_type": "code",
848    "execution_count": null,
849    "id": "87623222-5fba-4833-8873-01933e9aba88",
850    "metadata": {},
851    "outputs": [],
852    "source": [
853     "plt.plot(rnn_dat.y_test[0])\n",
854     "plt.plot(preds0)"
855    ]
856   },
857   {
858    "cell_type": "code",
859    "execution_count": null,
860    "id": "7f922046-e74f-424e-aa1c-d6d4b2eb3a46",
861    "metadata": {},
862    "outputs": [],
863    "source": []
864   },
865   {
866    "cell_type": "code",
867    "execution_count": null,
868    "id": "47a098c2-28c3-483d-b062-da1d534f7766",
869    "metadata": {},
870    "outputs": [],
871    "source": []
872   },
873   {
874    "cell_type": "code",
875    "execution_count": null,
876    "id": "4a581630-2dc0-4cdb-8647-c81d41e149bc",
877    "metadata": {},
878    "outputs": [],
879    "source": []
880   },
881   {
882    "cell_type": "code",
883    "execution_count": null,
884    "id": "36b931d4-15dc-41a8-8748-610a2406ccad",
885    "metadata": {},
886    "outputs": [],
887    "source": []
888   },
889   {
890    "cell_type": "code",
891    "execution_count": null,
892    "id": "055d98f5-4028-4822-b409-b03d437490da",
893    "metadata": {},
894    "outputs": [],
895    "source": []
896   },
897   {
898    "cell_type": "code",
899    "execution_count": null,
900    "id": "beb357ab-16dc-4c91-a121-6dfc509f4ff6",
901    "metadata": {},
902    "outputs": [],
903    "source": []
904   },
905   {
906    "cell_type": "code",
907    "execution_count": null,
908    "id": "a319b314-b156-47af-8541-f97145352e5c",
909    "metadata": {},
910    "outputs": [],
911    "source": []
912   },
913   {
914    "cell_type": "code",
915    "execution_count": null,
916    "id": "b6922358-b824-4c77-abe4-c9b605a78738",
917    "metadata": {},
918    "outputs": [],
919    "source": []
920   },
921   {
922    "cell_type": "markdown",
923    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
924    "metadata": {},
925    "source": [
926     "## LSTM\n",
927     "\n",
928     "**TODO:** the LSTM cells below do not work yet and need to be fixed."
929    ]
930   },
931   {
932    "cell_type": "code",
933    "execution_count": null,
934    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
935    "metadata": {},
936    "outputs": [],
937    "source": [
938     "import importlib \n",
939     "import moisture_rnn\n",
940     "importlib.reload(moisture_rnn)\n",
941     "from moisture_rnn import RNN_LSTM"
942    ]
943   },
944   {
945    "cell_type": "code",
946    "execution_count": null,
947    "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
948    "metadata": {},
949    "outputs": [],
950    "source": [
951     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
952     "params = RNNParams(params)"
953    ]
954   },
955   {
956    "cell_type": "code",
957    "execution_count": null,
958    "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
959    "metadata": {},
960    "outputs": [],
961    "source": [
962     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
963     "params.update({'epochs': 20, 'learning_rate': 0.0001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
964     "              'activation': ['relu', 'relu'], 'features_list': ['Ed', 'Ew', 'rain']})\n",
965     "\n",
966     "# `rnn_dat` was batch-reshaped with the RNN params' timesteps/batch_size, but fit() below uses the LSTM\n",
967     "# params' batch_size, so rebuild the data from `dd` with the LSTM params to keep the shapes consistent.\n",
968     "# NOTE(review): assumes the 'lstm' section of params.yaml defines 'timesteps' - confirm.\n",
969     "lstm_dat = RNNData(dd, scaler=\"standard\", features_list=params['features_list'])\n",
970     "lstm_dat.train_test_split(train_frac=.9, val_frac=.05)\n",
971     "lstm_dat.scale_data()\n",
972     "lstm_dat.batch_reshape(timesteps=params['timesteps'], batch_size=params['batch_size'])\n",
973     "\n",
974     "reproducibility.set_seed(123)\n",
975     "lstm = RNN_LSTM(params)\n",
976     "\n",
977     "history = lstm.model_train.fit(lstm_dat.X_train, lstm_dat.y_train,\n",
978     "                               batch_size=params['batch_size'], epochs=params['epochs'],\n",
979     "                               callbacks=[ResetStatesCallback(params),\n",
980     "                                          EarlyStoppingCallback(patience=params['early_stopping_patience'])],\n",
981     "                               validation_data=(lstm_dat.X_val, lstm_dat.y_val))"
974    ]
975   },
976   {
977    "cell_type": "code",
978    "execution_count": null,
979    "id": "de0c00e7-838f-41b6-9cc5-70594656d155",
980    "metadata": {},
981    "outputs": [],
982    "source": []
983   },
984   {
985    "cell_type": "code",
986    "execution_count": null,
987    "id": "430a2224-6798-48fa-b198-a32800f88f66",
988    "metadata": {},
989    "outputs": [],
990    "source": []
991   },
992   {
993    "cell_type": "code",
994    "execution_count": null,
995    "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
996    "metadata": {},
997    "outputs": [],
998    "source": []
999   },
1000   {
1001    "cell_type": "code",
1002    "execution_count": null,
1003    "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
1004    "metadata": {},
1005    "outputs": [],
1006    "source": []
1007   },
1008   {
1009    "cell_type": "code",
1010    "execution_count": null,
1011    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
1012    "metadata": {},
1013    "outputs": [],
1014    "source": []
1015   },
1016   {
1017    "cell_type": "code",
1018    "execution_count": null,
1019    "id": "f60a24c6-9a67-45aa-bc5c-8818aa0ca049",
1020    "metadata": {},
1021    "outputs": [],
1022    "source": []
1023   }
1024  ],
1025  "metadata": {
1026   "kernelspec": {
1027    "display_name": "Python 3 (ipykernel)",
1028    "language": "python",
1029    "name": "python3"
1030   },
1031   "language_info": {
1032    "codemirror_mode": {
1033     "name": "ipython",
1034     "version": 3
1035    },
1036    "file_extension": ".py",
1037    "mimetype": "text/x-python",
1038    "name": "python",
1039    "nbconvert_exporter": "python",
1040    "pygments_lexer": "ipython3",
1041    "version": "3.12.5"
1042   }
1043  },
1044  "nbformat": 4,
1045  "nbformat_minor": 5