Update .gitignore
[notebooks.git] / fmda / rnn_workshop.ipynb
blob5686dcf1f4747b00ef2a47d13aa9d7d3eb20708b
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.1 Exploration: Experiments to Improve Model Performance"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import copy\n",
20     "import importlib\n",
21     "import logging\n",
22     "import os\n",
23     "import os.path as osp\n",
24     "import pickle\n",
25     "import sys\n",
26     "import numpy as np\n",
27     "import pandas as pd\n",
28     "import tensorflow as tf\n",
29     "import matplotlib.pyplot as plt\n",
30     "import yaml\n",
31     "from tensorflow.keras.callbacks import Callback\n",
32     "\n",
33     "# Local modules (found via the parent directory added to sys.path)\n",
34     "sys.path.append('..')\n",
35     "import reproducibility\n",
36     "from data_funcs import rmse, process_train_dict\n",
37     "from moisture_models import XGB, RF, LM\n",
38     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM\n",
39     "from moisture_rnn_pkl import pkl2train\n",
40     "from utils import hash2, hash_ndarray, hash_weights, print_dict_summary\n",
41     "from utils import logging_setup, read_pkl, read_yml"
42    ]
43   },
44   {
45    "cell_type": "code",
46    "execution_count": null,
47    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
48    "metadata": {},
49    "outputs": [],
50    "source": [
51     "logging_setup()"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
57    "metadata": {},
58    "source": [
59     "## Tests"
60    ]
61   },
62   {
63    "cell_type": "markdown",
64    "id": "6322f0bc-107d-40a5-96dc-804495085a99",
65    "metadata": {
66     "jp-MarkdownHeadingCollapsed": true
67    },
68    "source": [
69     "## Test Other ML Models (XGB, RF, LM)"
70    ]
71   },
72   {
73    "cell_type": "code",
74    "execution_count": null,
75    "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
76    "metadata": {},
77    "outputs": [],
78    "source": [
79     "params = read_yml(\"params.yaml\", subkey='xgb')\n",
80     "params"
81    ]
82   },
83   {
84    "cell_type": "code",
85    "execution_count": null,
86    "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
87    "metadata": {},
88    "outputs": [],
89    "source": [
90     "dat = read_pkl(\"data/train.pkl\")"
91    ]
92   },
93   {
94    "cell_type": "code",
95    "execution_count": null,
96    "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
97    "metadata": {},
98    "outputs": [],
99    "source": [
100     "cases = [*dat.keys()]"
101    ]
102   },
103   {
104    "cell_type": "code",
105    "execution_count": null,
106    "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
107    "metadata": {},
108    "outputs": [],
109    "source": [
110     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
111     "rnn_dat.train_test_split(\n",
112     "    time_fracs = [.8, .1, .1]\n",
113     ")\n",
114     "rnn_dat.scale_data()"
115    ]
116   },
117   {
118    "cell_type": "code",
119    "execution_count": null,
120    "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
121    "metadata": {},
122    "outputs": [],
123    "source": [
124     "from moisture_models import XGB, RF, LM"
125    ]
126   },
127   {
128    "cell_type": "code",
129    "execution_count": null,
130    "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
131    "metadata": {},
132    "outputs": [],
133    "source": [
134     "mod = XGB(params)"
135    ]
136   },
137   {
138    "cell_type": "code",
139    "execution_count": null,
140    "id": "cae9a20d-1caf-45aa-a9c4-aef21b65d9c8",
141    "metadata": {},
142    "outputs": [],
143    "source": [
144     "mod.params"
145    ]
146   },
147   {
148    "cell_type": "code",
149    "execution_count": null,
150    "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
151    "metadata": {},
152    "outputs": [],
153    "source": [
154     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
155    ]
156   },
157   {
158    "cell_type": "code",
159    "execution_count": null,
160    "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
161    "metadata": {},
162    "outputs": [],
163    "source": [
164     "preds = mod.predict(rnn_dat.X_test)"
165    ]
166   },
167   {
168    "cell_type": "code",
169    "execution_count": null,
170    "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
171    "metadata": {},
172    "outputs": [],
173    "source": [
174     "rmse(np.ravel(preds), rnn_dat.y_test.flatten()) # compare 1-D vectors so (n,) vs (n,1) cannot broadcast"
175    ]
176   },
177   {
178    "cell_type": "code",
179    "execution_count": null,
180    "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
181    "metadata": {},
182    "outputs": [],
183    "source": [
184     "plt.plot(rnn_dat.y_test)\n",
185     "plt.plot(preds)"
186    ]
187   },
188   {
189    "cell_type": "code",
190    "execution_count": null,
191    "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
192    "metadata": {},
193    "outputs": [],
194    "source": [
195     "params = read_yml(\"params.yaml\", subkey='rf')\n",
196     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
197     "rnn_dat.train_test_split(\n",
198     "    time_fracs = [.8, .1, .1]\n",
199     ")"
200    ]
201   },
202   {
203    "cell_type": "code",
204    "execution_count": null,
205    "id": "cafe711a-20cb-4bd3-a4bc-4995a843a021",
206    "metadata": {},
207    "outputs": [],
208    "source": [
209     "import importlib\n",
210     "import moisture_models\n",
211     "importlib.reload(moisture_models)"
212    ]
213   },
214   {
215    "cell_type": "code",
216    "execution_count": null,
217    "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
218    "metadata": {},
219    "outputs": [],
220    "source": [
221     "params"
222    ]
223   },
224   {
225    "cell_type": "code",
226    "execution_count": null,
227    "id": "fafe76e5-0212-4bd1-a058-535935a08780",
228    "metadata": {},
229    "outputs": [],
230    "source": [
231     "mod2 = RF(params)\n",
232     "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
233     "preds2 = mod2.predict(rnn_dat.X_test)\n",
234     "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
235     "plt.plot(rnn_dat.y_test)\n",
236     "plt.plot(preds2)"
237    ]
238   },
258   {
259    "cell_type": "code",
260    "execution_count": null,
261    "id": "c5598bfe-2d87-4d23-869e-aff127782462",
262    "metadata": {},
263    "outputs": [],
264    "source": [
265     "params = read_yml(\"params.yaml\", subkey='lm')\n",
266     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
267     "rnn_dat.train_test_split(\n",
268     "    time_fracs = [.8, .1, .1]\n",
269     ")\n",
270     "mod = LM(params)"
271    ]
272   },
273   {
274    "cell_type": "code",
275    "execution_count": null,
276    "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
277    "metadata": {},
278    "outputs": [],
279    "source": [
280     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
281     "preds = mod.predict(rnn_dat.X_test)\n",
282     "print(rmse(np.ravel(preds), rnn_dat.y_test.flatten())) # score the LM's own preds (not RF's preds2); flatten both to avoid broadcasting"
283    ]
284   },
301   {
302    "cell_type": "markdown",
303    "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
304    "metadata": {},
305    "source": [
306     "## RNN"
307    ]
308   },
309   {
310    "cell_type": "code",
311    "execution_count": null,
312    "id": "8162841a-131b-4da6-a0c5-9d131a7cadf9",
313    "metadata": {},
314    "outputs": [],
315    "source": [
316     "dat = read_pkl(\"data/train.pkl\")"
317    ]
318   },
319   {
320    "cell_type": "code",
321    "execution_count": null,
322    "id": "96fe971b-c6d3-45ee-94ee-e4f426735d56",
323    "metadata": {},
324    "outputs": [],
325    "source": [
326     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))\n",
327     "params.update({\n",
328     "    'features_list': ['Ed', 'Ew', 'solar', 'wind', 'rain']\n",
329     "})"
330    ]
331   },
332   {
333    "cell_type": "code",
334    "execution_count": null,
335    "id": "5a55e8e7-1869-43fc-9bc6-09bd4f5a8d76",
336    "metadata": {},
337    "outputs": [],
338    "source": [
339     "rnn_dat2 = RNNData(dat[\"GSHPN_202401_set_2\"], params['scaler'], params['features_list'])\n",
340     "rnn_dat2.train_test_split(\n",
341     "    time_fracs = [.8, .1, .1]\n",
342     ")\n",
343     "rnn_dat2.scale_data()\n",
344     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
345    ]
346   },
347   {
348    "cell_type": "code",
349    "execution_count": null,
350    "id": "aaec14ac-c6a6-4fcd-ad8e-d28143b92623",
351    "metadata": {},
352    "outputs": [],
353    "source": [
354     "reproducibility.set_seed()\n",
355     "rnn = RNN(params)\n",
356     "m, errs = rnn.run_model(rnn_dat2, plot_period=\"predict\")"
357    ]
358   },
377   {
378    "cell_type": "code",
379    "execution_count": null,
380    "id": "7ce63a46-47bd-47d6-87a0-573a97ca2880",
381    "metadata": {},
382    "outputs": [],
383    "source": [
384     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn')) # start from fresh yaml defaults so this cell is re-runnable on its own\n",
385     "params.update({'epochs': 200, \n",
386     "               'learning_rate': 0.001,\n",
387     "               'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
388     "               'recurrent_layers': 2, 'recurrent_units': 30, \n",
389     "               'dense_layers': 2, 'dense_units': 30,\n",
390     "               'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
391     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
392     "               'bmin': 20, # Lower bound of hidden state batch reset, \n",
393     "               'bmax': rnn_dat2.hours, # Upper bound of hidden state batch reset, using max hours\n",
394     "               'features_list': ['Ed', 'Ew', 'rain', 'solar', 'wind'],\n",
395     "               'timesteps': 12\n",
396     "              })"
397    ]
398   },
399   {
400    "cell_type": "code",
401    "execution_count": null,
402    "id": "c2906b3f-6266-4bbc-9494-88ab36af58f9",
403    "metadata": {},
404    "outputs": [],
405    "source": [
406     "rnn_dat = RNNData(dat[\"FPRO3_202401_set_2\"], params['scaler'],  params['features_list'])\n",
407     "rnn_dat.train_test_split(\n",
408     "    time_fracs = [.8, .1, .1]\n",
409     ")\n",
410     "rnn_dat.scale_data()\n",
411     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
412    ]
413   },
414   {
415    "cell_type": "code",
416    "execution_count": null,
417    "id": "c79ed028-ba60-4db5-9864-d3b2c01e09c3",
418    "metadata": {},
419    "outputs": [],
420    "source": [
421     "reproducibility.set_seed()\n",
422     "rnn = RNN(params)\n",
423     "m, errs = rnn.run_model(rnn_dat)"
424    ]
425   },
426   {
427    "cell_type": "code",
428    "execution_count": null,
429    "id": "947c1581-8021-48de-81fb-2212c1c17253",
430    "metadata": {},
431    "outputs": [],
432    "source": [
433     "rnn_dat = RNNData(dat[\"GSHPN_202401_set_2\"], params['scaler'],  params['features_list'])\n",
434     "rnn_dat.train_test_split(\n",
435     "    time_fracs = [.8, .1, .1]\n",
436     ")\n",
437     "rnn_dat.scale_data()\n",
438     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
439    ]
440   },
441   {
442    "cell_type": "code",
443    "execution_count": null,
444    "id": "c134690e-599f-420e-b634-71a9201f0e83",
445    "metadata": {},
446    "outputs": [],
447    "source": [
448     "reproducibility.set_seed()\n",
449     "rnn = RNN(params)\n",
450     "m, errs = rnn.run_model(rnn_dat)"
451    ]
452   },
469   {
470    "cell_type": "markdown",
471    "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
472    "metadata": {},
473    "source": [
474     "## Phys Initialized"
475    ]
476   },
477   {
478    "cell_type": "code",
479    "execution_count": null,
480    "id": "5488628e-4552-4909-83e9-413fd6878bdd",
481    "metadata": {},
482    "outputs": [],
483    "source": [
484     "params.update({\n",
485     "    'epochs':100,\n",
486     "    'dense_layers': 0,\n",
487     "    'activation': ['relu', 'relu'],\n",
488     "    'phys_initialize': False,\n",
489     "    'dropout': [0,0]\n",
490     "})"
491    ]
492   },
493   {
494    "cell_type": "code",
495    "execution_count": null,
496    "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
497    "metadata": {},
498    "outputs": [],
499    "source": [
500     "reproducibility.set_seed()\n",
501     "rnn = RNN(params)\n",
502     "m, errs = rnn.run_model(rnn_dat)"
503    ]
504   },
505   {
506    "cell_type": "code",
507    "execution_count": null,
508    "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
509    "metadata": {},
510    "outputs": [],
511    "source": [
512     "rnn.model_train.summary()"
513    ]
514   },
523   {
524    "cell_type": "code",
525    "execution_count": null,
526    "id": "0aab34c7-8a09-480a-9d3e-619f7cf82b34",
527    "metadata": {},
528    "outputs": [],
529    "source": [
530     "params.update({\n",
531     "    'phys_initialize': True,\n",
532     "    'scaler': None, # TODO\n",
533     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
534     "    'activation': ['linear', 'linear'], # TODO tanh, relu the same\n",
535     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
536     "})"
537    ]
538   },
539   {
540    "cell_type": "code",
541    "execution_count": null,
542    "id": "ab549075-f71f-42ad-b36f-3d1e90247e33",
543    "metadata": {},
544    "outputs": [],
545    "source": [
546     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
547     "rnn_dat2.train_test_split(\n",
548     "    time_fracs = [.8, .1, .1]\n",
549     ")\n",
550     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
551    ]
552   },
553   {
554    "cell_type": "code",
555    "execution_count": null,
556    "id": "195f337a-ac8a-4471-8226-94863b9385e2",
557    "metadata": {},
558    "outputs": [],
559    "source": [
560     "import importlib\n",
561     "import moisture_rnn\n",
562     "importlib.reload(moisture_rnn)\n",
563     "from moisture_rnn import RNN, RNNData"
564    ]
565   },
566   {
567    "cell_type": "code",
568    "execution_count": null,
569    "id": "9395d147-17a5-44ba-aaa2-a213ffde062b",
570    "metadata": {
571     "scrolled": true
572    },
573    "outputs": [],
574    "source": [
575     "reproducibility.set_seed()\n",
576     "\n",
577     "rnn = RNN(params)"
578    ]
579   },
580   {
581    "cell_type": "code",
582    "execution_count": null,
583    "id": "d3eebe8a-ff12-454b-81b6-6a138924f127",
584    "metadata": {},
585    "outputs": [],
586    "source": [
587     "m, errs = rnn.run_model(rnn_dat2)"
588    ]
589   },
590   {
591    "cell_type": "code",
592    "execution_count": null,
593    "id": "bcbb0159-74c5-4f56-9d69-d85a58ddbd1a",
594    "metadata": {},
595    "outputs": [],
596    "source": [
597     "rnn.model_predict.get_weights()"
598    ]
599   },
608   {
609    "cell_type": "code",
610    "execution_count": null,
611    "id": "e8ed2b03-6123-4bdf-9e26-ef2ce4951663",
612    "metadata": {},
613    "outputs": [],
614    "source": [
615     "params['rnn_units']"
616    ]
617   },
618   {
619    "cell_type": "code",
620    "execution_count": null,
621    "id": "e44302bf-af49-4140-ae31-54f7c88a6735",
622    "metadata": {},
623    "outputs": [],
624    "source": [
625     "params.update({\n",
626     "    'phys_initialize': True,\n",
627     "    'scaler': None, # TODO\n",
628     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
629     "    'activation': ['relu', 'relu'], # TODO tanh, relu the same\n",
630     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
631     "})"
632    ]
633   },
634   {
635    "cell_type": "code",
636    "execution_count": null,
637    "id": "9a8ac32d-551c-43e8-988e-a3b13e6d9cd9",
638    "metadata": {},
639    "outputs": [],
640    "source": [
641     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
642     "rnn_dat2.train_test_split(\n",
643     "    time_fracs = [.8, .1, .1]\n",
644     ")\n",
645     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
646    ]
647   },
648   {
649    "cell_type": "code",
650    "execution_count": null,
651    "id": "ff727da8-38fb-4fda-999b-f712b98de0df",
652    "metadata": {
653     "scrolled": true
654    },
655    "outputs": [],
656    "source": [
657     "reproducibility.set_seed()\n",
658     "\n",
659     "rnn = RNN(params)\n",
660     "m, errs = rnn.run_model(rnn_dat2)"
661    ]
662   },
751   {
752    "cell_type": "markdown",
753    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
754    "metadata": {},
755    "source": [
756     "## LSTM\n",
757     "\n",
758     "**TODO:** the cells below still need to be fixed."
759    ]
760   },
761   {
762    "cell_type": "code",
763    "execution_count": null,
764    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
765    "metadata": {},
766    "outputs": [],
767    "source": [
768     "import importlib \n",
769     "import moisture_rnn\n",
770     "importlib.reload(moisture_rnn)\n",
771     "from moisture_rnn import RNN_LSTM"
772    ]
773   },
774   {
775    "cell_type": "code",
776    "execution_count": null,
777    "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
778    "metadata": {},
779    "outputs": [],
780    "source": [
781     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
782     "params = RNNParams(params)"
783    ]
784   },
785   {
786    "cell_type": "code",
787    "execution_count": null,
788    "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
789    "metadata": {},
790    "outputs": [],
791    "source": [
792     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
793     "rnn_dat.train_test_split(\n",
794     "    time_fracs = [.8, .1, .1]\n",
795     ")\n",
796     "rnn_dat.scale_data()\n",
797     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
798    ]
799   },
800   {
801    "cell_type": "code",
802    "execution_count": null,
803    "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
804    "metadata": {},
805    "outputs": [],
806    "source": [
807     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
808     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
809     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
810     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
811     "reproducibility.set_seed(123)\n",
812     "lstm = RNN_LSTM(params)\n",
813     "\n",
814     "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
815     "                    batch_size = params['batch_size'], epochs=params['epochs'], \n",
816     "                    callbacks = [ResetStatesCallback(params),\n",
817     "                                EarlyStoppingCallback(patience = 15)],\n",
818     "                   validation_data = (rnn_dat.X_val, rnn_dat.y_val))"
820    ]
821   },
838   {
839    "cell_type": "code",
840    "execution_count": null,
841    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
842    "metadata": {},
843    "outputs": [],
844    "source": [
845     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
846     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
847     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours,\n",
848     "              'early_stopping_patience': 25})\n",
849     "reproducibility.set_seed(123)\n",
850     "lstm = RNN_LSTM(params)\n",
851     "m, errs = lstm.run_model(rnn_dat)"
852    ]
853   },
862   {
863    "cell_type": "code",
864    "execution_count": null,
865    "id": "00910bd2-f050-438c-ab3b-c793b83cb5f5",
866    "metadata": {},
867    "outputs": [],
868    "source": [
869     "rnn_dat.spatial"
870    ]
871   },
872   {
873    "cell_type": "code",
874    "execution_count": null,
875    "id": "236b33e3-e864-4453-be16-cf07338c4105",
876    "metadata": {},
877    "outputs": [],
878    "source": [
879     "params = RNNParams(read_yml(\"params.yaml\", subkey='lstm'))\n",
880     "params"
881    ]
882   },
883   {
884    "cell_type": "code",
885    "execution_count": null,
886    "id": "fe2a484c-dc99-45a9-89fc-2f451bd719b5",
887    "metadata": {},
888    "outputs": [],
889    "source": [
890     "train = read_pkl(\"data/train.pkl\")"
891    ]
892   },
893   {
894    "cell_type": "code",
895    "execution_count": null,
896    "id": "07bfac87-a6d4-4dcc-8d11-adf83eafab76",
897    "metadata": {},
898    "outputs": [],
899    "source": [
900     "from itertools import islice\n",
901     "train = {k: train[k] for k in islice(train, 100)}"
902    ]
903   },
904   {
905    "cell_type": "code",
906    "execution_count": null,
907    "id": "4e26099b-f760-4047-afec-9e751d24b7a6",
908    "metadata": {},
909    "outputs": [],
910    "source": [
911     "from data_funcs import combine_nested\n",
912     "rnn_dat_sp = RNNData(\n",
913     "    combine_nested(train), # input dictionary\n",
914     "    scaler=\"standard\",  # data scaling type\n",
915     "    features_list = params['features_list'] # features for predicting outcome\n",
916     ")\n",
917     "\n",
918     "\n",
919     "rnn_dat_sp.train_test_split(   \n",
920     "    time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
921     "    space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
922     ")\n",
923     "rnn_dat_sp.scale_data()\n",
924     "\n",
925     "rnn_dat_sp.batch_reshape(\n",
926     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
927     "    batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
928     ")"
929    ]
930   },
931   {
932    "cell_type": "code",
933    "execution_count": null,
934    "id": "10738795-c83b-4da3-88ba-09278caa35f8",
935    "metadata": {},
936    "outputs": [],
937    "source": [
938     "params.update({\n",
939     "    'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
940     "})"
941    ]
942   },
943   {
944    "cell_type": "code",
945    "execution_count": null,
946    "id": "9c5d45cc-bcf0-4b6c-9c51-c4c790a2d9a5",
947    "metadata": {},
948    "outputs": [],
949    "source": [
950     "reproducibility.set_seed() # seed before building the model, like every other model fit in this notebook\n",
951     "rnn_sp = RNN_LSTM(params)\n",
952     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
953    ]
954   },
955   {
956    "cell_type": "code",
957    "execution_count": null,
958    "id": "ee332ccf-4e4a-4f66-b4d6-c079dbdb1411",
959    "metadata": {},
960    "outputs": [],
961    "source": [
962     "errs.mean()"
963    ]
964   }
972  ],
973  "metadata": {
974   "kernelspec": {
975    "display_name": "Python 3 (ipykernel)",
976    "language": "python",
977    "name": "python3"
978   },
979   "language_info": {
980    "codemirror_mode": {
981     "name": "ipython",
982     "version": 3
983    },
984    "file_extension": ".py",
985    "mimetype": "text/x-python",
986    "name": "python",
987    "nbconvert_exporter": "python",
988    "pygments_lexer": "ipython3",
989    "version": "3.12.6"
990   }
991  },
992  "nbformat": 4,
993  "nbformat_minor": 5