Update rnn_workshop.ipynb
[notebooks.git] / fmda / rnn_workshop.ipynb
blob05b97f8a7f41d185225a9c83ab5ca8802ba24097
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.1 exploration: trying to make the RNN work better"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment: standard library\n",
19     "import copy\n",
20     "import logging\n",
21     "import os\n",
22     "import os.path as osp\n",
23     "import pickle\n",
24     "import sys\n",
25     "from itertools import islice\n",
26     "# Third-party\n",
27     "import matplotlib.pyplot as plt\n",
28     "import numpy as np\n",
29     "import pandas as pd\n",
30     "import tensorflow as tf\n",
31     "import yaml\n",
32     "from tensorflow.keras.callbacks import Callback\n",
33     "# Local modules: make the parent directory importable first\n",
34     "sys.path.append('..')\n",
35     "import reproducibility\n",
36     "from data_funcs import rmse, process_train_dict, build_train_dict2, combine_nested\n",
37     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM\n",
38     "from moisture_rnn_pkl import pkl2train\n",
39     "from utils import (print_dict_summary, hash2, logging_setup, read_yml, read_pkl,\n",
40     "                   hash_ndarray, hash_weights)"
41    ]
42   },
43   {
44    "cell_type": "code",
45    "execution_count": null,
46    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
47    "metadata": {},
48    "outputs": [],
49    "source": [
50     "logging_setup()"
51    ]
52   },
54   {
55    "cell_type": "markdown",
56    "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
57    "metadata": {},
58    "source": [
59     "## Test Data"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "3efed1fa-9cda-4934-8a6c-edcf179c8755",
66    "metadata": {},
67    "outputs": [],
68    "source": [
69     "file_paths = ['data/fmda_nw_202401-05_f05.pkl']"
70    ]
71   },
72   {
73    "cell_type": "code",
74    "execution_count": null,
75    "id": "28fd3746-1861-4afa-ab7e-ac449fbed322",
76    "metadata": {},
77    "outputs": [],
78    "source": [
79     "# Params used for data filtering\n",
80     "params_data = read_yml(\"params_data.yaml\")\n",
81     "params_data"
82    ]
83   },
84   {
85    "cell_type": "code",
86    "execution_count": null,
87    "id": "32a46674-8377-47f8-9c3a-e6b07f9505cf",
88    "metadata": {},
89    "outputs": [],
90    "source": [
91     "params = read_yml(\"params.yaml\", subkey='rnn')\n",
92     "params = RNNParams(params)\n",
93     "params.update({'epochs': 200,\n",
94     "               'learning_rate': 0.001,\n",
95     "               'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
96     "               'recurrent_layers': 2, 'recurrent_units': 30,\n",
97     "               'dense_layers': 2, 'dense_units': 30,\n",
98     "               'early_stopping_patience': 30, # how many epochs without validation improvement to wait before stopping\n",
99     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
100     "               'bmin': 20, # Lower bound of hidden state batch reset\n",
101     "               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
102     "               'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
103     "               'timesteps': 12\n",
104     "              })"
105    ]
106   },
107   {
108    "cell_type": "code",
109    "execution_count": null,
110    "id": "91466d1b-3106-4b49-8ee8-47cbf58f938c",
111    "metadata": {
112     "scrolled": true
113    },
114    "outputs": [],
115    "source": [
116     "from data_funcs import build_train_dict2\n",
117     "train2 = build_train_dict2(file_paths, params_data, spatial=False)"
118    ]
119   },
120   {
121    "cell_type": "code",
122    "execution_count": null,
123    "id": "117dcd97-fa2a-4324-9a0e-6f7723c9c0ce",
124    "metadata": {
125     "scrolled": true
126    },
127    "outputs": [],
128    "source": [
129     "train = process_train_dict(file_paths, atm_dict=\"HRRR\", params_data = params_data, verbose=True, spatial=False)"
130    ]
131   },
132   {
133    "cell_type": "code",
134    "execution_count": null,
135    "id": "ce8e3803-44c7-4602-882c-8c69530546e7",
136    "metadata": {},
137    "outputs": [],
138    "source": [
139     "key = \"PLFI1_202401\"\n",
140     "train[key]['features_list']"
141    ]
142   },
143   {
144    "cell_type": "code",
145    "execution_count": null,
146    "id": "ca760d99-5348-4437-b30d-c5fcb252a6a1",
147    "metadata": {},
148    "outputs": [],
149    "source": [
150     "train2[key]['features_list']"
151    ]
152   },
153   {
154    "cell_type": "code",
155    "execution_count": null,
156    "id": "c3557d4b-162e-45bb-b297-4192e6b977d1",
157    "metadata": {
158     "scrolled": true
159    },
160    "outputs": [],
161    "source": [
162     "for k in train:\n",
163     "    print(\"~\"*50)\n",
164     "    print(k)\n",
165     "    print(k in train2 and np.array_equal(train[k]['X'][:,-1], train2[k]['X'][:,-1]))  # False, not an error, on missing key / shape mismatch"
166    ]
167   },
168   {
169    "cell_type": "code",
170    "execution_count": null,
171    "id": "d378b9fd-c92d-4c76-b039-364fb10512d9",
172    "metadata": {},
173    "outputs": [],
174    "source": [
175     "reproducibility.set_seed(123)"
176    ]
177   },
178   {
179    "cell_type": "code",
180    "execution_count": null,
181    "id": "3ec92f91-90cc-4096-a091-9ec54870ea77",
182    "metadata": {},
183    "outputs": [],
184    "source": [
185     "from itertools import islice\n",
186     "train = {k: train[k] for k in islice(train, 250)}"
187    ]
188   },
189   {
190    "cell_type": "code",
191    "execution_count": null,
192    "id": "0fb060ed-1239-44da-8c9b-923ff2004e38",
193    "metadata": {},
194    "outputs": [],
195    "source": [
196     "from data_funcs import combine_nested"
197    ]
198   },
199   {
200    "cell_type": "code",
201    "execution_count": null,
202    "id": "377ff40a-c8f3-469b-b5f5-bd46b6dc1ae1",
203    "metadata": {},
204    "outputs": [],
205    "source": [
206     "d1 = RNNData(\n",
207     "    combine_nested(train), # input dictionary\n",
208     "    scaler=\"standard\",  # data scaling type\n",
209     "    features_list = params['features_list'] # features for predicting outcome\n",
210     ")\n",
211     "\n",
213     "d1.train_test_split(\n",
214     "    time_fracs = [.8, .1, .1], # Fraction of total time steps used for train/val/test\n",
215     "    space_fracs = [.8, .1, .1] # Fraction of total timeseries used for train/val/test\n",
216     ")\n",
217     "d1.scale_data()\n",
218     "\n",
219     "d1.batch_reshape(\n",
220     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data.\n",
221     "    batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
222     ")"
223    ]
224   },
225   {
226    "cell_type": "code",
227    "execution_count": null,
228    "id": "cc7ea493-6beb-43c6-a933-032d30b8415f",
229    "metadata": {},
230    "outputs": [],
231    "source": [
232     "# Update Params specific to spatial training\n",
233     "params.update({\n",
234     "    'loc_batch_reset': d1.n_seqs # Used to reset hidden state when location changes for a given batch\n",
235     "})"
236    ]
237   },
238   {
239    "cell_type": "code",
240    "execution_count": null,
241    "id": "11c2ec92-e51d-4015-88ed-5a5c3ea4a58f",
242    "metadata": {},
243    "outputs": [],
244    "source": [
245     "reproducibility.set_seed(123)\n",
246     "rnn_sp = RNN(params)\n",
247     "m, errs = rnn_sp.run_model(d1)"
248    ]
249   },
250   {
251    "cell_type": "code",
252    "execution_count": null,
253    "id": "7b17fa96-eb8f-455f-80c3-24b384ae65e7",
254    "metadata": {},
255    "outputs": [],
256    "source": [
257     "errs.mean()"
258    ]
259   },
268   {
269    "cell_type": "code",
270    "execution_count": null,
271    "id": "baefb163-198e-4b8c-920a-3c520eba8579",
272    "metadata": {},
273    "outputs": [],
274    "source": [
275     "from itertools import islice\n",
276     "train2 = {k: train2[k] for k in islice(train2, 250)}"
277    ]
278   },
279   {
280    "cell_type": "code",
281    "execution_count": null,
282    "id": "54b3ea85-a4ed-4b9b-8fb5-4c2669100177",
283    "metadata": {},
284    "outputs": [],
285    "source": [
286     "# Same pipeline as d1, but built from train2 (build_train_dict2) so the two data builders can be compared\n",
287     "reproducibility.set_seed(123) # same seed as before d1, in case the train/val/test split is randomized\n",
288     "d2 = RNNData(\n",
289     "    combine_nested(train2), # input dictionary\n",
290     "    scaler=\"standard\",  # data scaling type\n",
291     "    features_list = params['features_list'] # features for predicting outcome\n",
292     ")\n",
293     "\n",
294     "d2.train_test_split(\n",
295     "    time_fracs = [.8, .1, .1], # Fraction of total time steps used for train/val/test\n",
296     "    space_fracs = [.8, .1, .1] # Fraction of total timeseries used for train/val/test\n",
297     ")\n",
298     "d2.scale_data()\n",
299     "\n",
300     "d2.batch_reshape(\n",
301     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data.\n",
302     "    batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
303     ")\n",
304     "# Update Params specific to spatial training\n",
305     "params.update({\n",
306     "    'loc_batch_reset': d2.n_seqs # Used to reset hidden state when location changes for a given batch\n",
307     "})"
308    ]
309   },
310   {
311    "cell_type": "code",
312    "execution_count": null,
313    "id": "aa0b0073-7b0e-4871-9c87-8d19e1c49758",
314    "metadata": {},
315    "outputs": [],
316    "source": [
317     "reproducibility.set_seed(123)\n",
318     "rnn2 = RNN(params)\n",
319     "m2, errs2 = rnn2.run_model(d2)"
320    ]
321   },
322   {
323    "cell_type": "code",
324    "execution_count": null,
325    "id": "89cdbdb5-2eee-4412-92d2-d47ef3e3549e",
326    "metadata": {},
327    "outputs": [],
328    "source": [
329     "errs2.mean()"
330    ]
331   },
332   {
333    "cell_type": "code",
334    "execution_count": null,
335    "id": "2143ecb6-6edb-4948-8a30-698cfaceefa2",
336    "metadata": {},
337    "outputs": [],
338    "source": [
339     "nest = combine_nested(train)"
340    ]
341   },
342   {
343    "cell_type": "code",
344    "execution_count": null,
345    "id": "f8ce412b-cc9f-4273-b423-21cb53002258",
346    "metadata": {},
347    "outputs": [],
348    "source": [
349     "nest2 = combine_nested(train2)"
350    ]
351   },
352   {
353    "cell_type": "code",
354    "execution_count": null,
355    "id": "64ec6aa5-0d44-4e84-a5e7-89b2c4b4ec1a",
356    "metadata": {},
357    "outputs": [],
358    "source": [
359     "nest.keys()"
360    ]
361   },
362   {
363    "cell_type": "code",
364    "execution_count": null,
365    "id": "657e77d5-b0db-4a92-8171-0cc4e9796bb9",
366    "metadata": {},
367    "outputs": [],
368    "source": [
369     "nest2.keys()"
370    ]
371   },
469   {
470    "cell_type": "markdown",
471    "id": "6322f0bc-107d-40a5-96dc-804495085a99",
472    "metadata": {
473     "jp-MarkdownHeadingCollapsed": true
474    },
475    "source": [
476     "## Test Other ML Models (XGB, RF, LM)"
477    ]
478   },
479   {
480    "cell_type": "code",
481    "execution_count": null,
482    "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
483    "metadata": {},
484    "outputs": [],
485    "source": [
486     "params = read_yml(\"params.yaml\", subkey='xgb')\n",
487     "params"
488    ]
489   },
490   {
491    "cell_type": "code",
492    "execution_count": null,
493    "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
494    "metadata": {},
495    "outputs": [],
496    "source": [
497     "dat = read_pkl(\"data/train.pkl\")"
498    ]
499   },
500   {
501    "cell_type": "code",
502    "execution_count": null,
503    "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
504    "metadata": {},
505    "outputs": [],
506    "source": [
507     "cases = [*dat.keys()]"
508    ]
509   },
510   {
511    "cell_type": "code",
512    "execution_count": null,
513    "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
514    "metadata": {},
515    "outputs": [],
516    "source": [
517     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
518     "rnn_dat.train_test_split(\n",
519     "    time_fracs = [.8, .1, .1]\n",
520     ")\n",
521     "rnn_dat.scale_data()"
522    ]
523   },
524   {
525    "cell_type": "code",
526    "execution_count": null,
527    "id": "e79f8dc8-5cf8-4190-b4ff-e640f61bd78b",
528    "metadata": {},
529    "outputs": [],
530    "source": [
531     "from moisture_models import XGB, RF, LM"
532    ]
533   },
534   {
535    "cell_type": "code",
536    "execution_count": null,
537    "id": "b3aeb47f-261e-4e29-9eeb-67215e5628f6",
538    "metadata": {},
539    "outputs": [],
540    "source": [
541     "mod = XGB(params)"
542    ]
543   },
544   {
545    "cell_type": "code",
546    "execution_count": null,
547    "id": "cae9a20d-1caf-45aa-a9c4-aef21b65d9c8",
548    "metadata": {},
549    "outputs": [],
550    "source": [
551     "mod.params"
552    ]
553   },
554   {
555    "cell_type": "code",
556    "execution_count": null,
557    "id": "68a07b25-c586-4fc4-a3d5-c857354e7a2c",
558    "metadata": {},
559    "outputs": [],
560    "source": [
561     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)"
562    ]
563   },
564   {
565    "cell_type": "code",
566    "execution_count": null,
567    "id": "c8f88819-0a7a-4420-abb9-56a47015a4de",
568    "metadata": {},
569    "outputs": [],
570    "source": [
571     "preds = mod.predict(rnn_dat.X_test)"
572    ]
573   },
574   {
575    "cell_type": "code",
576    "execution_count": null,
577    "id": "cb7cdf14-74d6-45e4-bc1b-7d4d47dd41ac",
578    "metadata": {},
579    "outputs": [],
580    "source": [
581     "# Ravel both sides so an (n,) vs (n, 1) pair cannot broadcast to (n, n) inside rmse\n",
582     "rmse(np.ravel(preds), np.ravel(rnn_dat.y_test))"
583    ]
584   },
585   {
586    "cell_type": "code",
587    "execution_count": null,
588    "id": "74d478c7-8c01-448e-9a00-dd0e1ee8e325",
589    "metadata": {},
590    "outputs": [],
591    "source": [
592     "plt.plot(rnn_dat.y_test, label='y_test')\n",
593     "plt.plot(preds, label='XGB')\n",
594     "plt.legend();"
595    ]
596   },
597   {
598    "cell_type": "code",
599    "execution_count": null,
600    "id": "c5441014-c39a-4414-a779-95b81e1ed6a8",
601    "metadata": {},
602    "outputs": [],
603    "source": [
604     "params = read_yml(\"params.yaml\", subkey='rf')\n",
605     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
606     "rnn_dat.train_test_split(\n",
607     "    time_fracs = [.8, .1, .1]\n",
608     ")"
609    ]
610   },
611   {
612    "cell_type": "code",
613    "execution_count": null,
614    "id": "ee45f7d6-f57f-4ff6-995a-527565565f94",
615    "metadata": {},
616    "outputs": [],
617    "source": [
618     "params"
619    ]
620   },
621   {
622    "cell_type": "code",
623    "execution_count": null,
624    "id": "fafe76e5-0212-4bd1-a058-535935a08780",
625    "metadata": {},
626    "outputs": [],
627    "source": [
628     "mod2 = RF(params)\n",
629     "mod2.fit(rnn_dat.X_train, rnn_dat.y_train.flatten())\n",
630     "preds2 = mod2.predict(rnn_dat.X_test)\n",
631     "print(rmse(preds2, rnn_dat.y_test.flatten()))\n",
632     "plt.plot(rnn_dat.y_test, label='y_test')\n",
633     "plt.plot(preds2, label='RF')\n",
634     "plt.legend();"
635    ]
636   },
637   {
638    "cell_type": "code",
639    "execution_count": null,
640    "id": "c5598bfe-2d87-4d23-869e-aff127782462",
641    "metadata": {},
642    "outputs": [],
643    "source": [
644     "params = read_yml(\"params.yaml\", subkey='lm')\n",
645     "rnn_dat = RNNData(dat[cases[10]], features_list = ['Ed', 'Ew', 'solar', 'wind', 'rain'])\n",
646     "rnn_dat.train_test_split(\n",
647     "    time_fracs = [.8, .1, .1]\n",
648     ")\n",
649     "mod = LM(params)"
650    ]
651   },
652   {
653    "cell_type": "code",
654    "execution_count": null,
655    "id": "d828c15c-4078-4967-abff-c1fd15d4696d",
656    "metadata": {},
657    "outputs": [],
658    "source": [
659     "mod.fit(rnn_dat.X_train, rnn_dat.y_train)\n",
660     "preds = mod.predict(rnn_dat.X_test)\n",
661     "# Score the LM's own predictions; ravel guards against (n, 1) vs (n,) broadcasting inside rmse\n",
662     "print(rmse(np.ravel(preds), rnn_dat.y_test.flatten()))"
663    ]
664   },
708   {
709    "cell_type": "markdown",
710    "id": "282cb651-b21f-401d-94c5-9e07530a9ba8",
711    "metadata": {},
712    "source": [
713     "## RNN"
714    ]
715   },
716   {
717    "cell_type": "markdown",
718    "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
719    "metadata": {},
720    "source": [
721     "## Phys Initialized"
722    ]
723   },
724   {
725    "cell_type": "code",
726    "execution_count": null,
727    "id": "5488628e-4552-4909-83e9-413fd6878bdd",
728    "metadata": {},
729    "outputs": [],
730    "source": [
731     "# `params` still holds the 'lm' config from the section above; reload the RNN config before updating it\n",
732     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))\n",
733     "params.update({\n",
734     "    'epochs':100,\n",
735     "    'dense_layers': 0,\n",
736     "    'activation': ['relu', 'relu'],\n",
737     "    'phys_initialize': False,\n",
738     "    'dropout': [0,0]\n",
739     "})"
740    ]
741   },
742   {
743    "cell_type": "code",
744    "execution_count": null,
745    "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
746    "metadata": {},
747    "outputs": [],
748    "source": [
749     "# `rnn_dat` from the LM section is unscaled and not batch-reshaped; rebuild it for the RNN\n",
750     "rnn_dat = RNNData(dat[cases[10]], scaler=\"standard\", features_list = params['features_list'])\n",
751     "rnn_dat.train_test_split(\n",
752     "    time_fracs = [.8, .1, .1]\n",
753     ")\n",
754     "rnn_dat.scale_data()\n",
755     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
756     "\n",
757     "reproducibility.set_seed()\n",
758     "rnn = RNN(params)\n",
759     "m, errs = rnn.run_model(rnn_dat)"
760    ]
761   },
762   {
763    "cell_type": "code",
764    "execution_count": null,
765    "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
766    "metadata": {},
767    "outputs": [],
768    "source": [
769     "rnn.model_train.summary()"
770    ]
771   },
772   {
773    "cell_type": "code",
774    "execution_count": null,
775    "id": "0aab34c7-8a09-480a-9d3e-619f7cf82b34",
776    "metadata": {},
777    "outputs": [],
778    "source": [
779     "params.update({\n",
780     "    'phys_initialize': True,\n",
781     "    'scaler': None, # TODO\n",
782     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
783     "    'activation': ['linear', 'linear'], # TODO tanh, relu the same\n",
784     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
785     "})"
786    ]
787   },
788   {
789    "cell_type": "code",
790    "execution_count": null,
791    "id": "ab549075-f71f-42ad-b36f-3d1e90247e33",
792    "metadata": {},
793    "outputs": [],
794    "source": [
795     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
796     "rnn_dat2.train_test_split(\n",
797     "    time_fracs = [.8, .1, .1]\n",
798     ")\n",
799     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
800    ]
801   },
802   {
803    "cell_type": "code",
804    "execution_count": null,
805    "id": "9395d147-17a5-44ba-aaa2-a213ffde062b",
806    "metadata": {
807     "scrolled": true
808    },
809    "outputs": [],
810    "source": [
811     "reproducibility.set_seed()\n",
812     "\n",
813     "rnn = RNN(params)"
814    ]
815   },
816   {
817    "cell_type": "code",
818    "execution_count": null,
819    "id": "d3eebe8a-ff12-454b-81b6-6a138924f127",
820    "metadata": {},
821    "outputs": [],
822    "source": [
823     "m, errs = rnn.run_model(rnn_dat2)"
824    ]
825   },
826   {
827    "cell_type": "code",
828    "execution_count": null,
829    "id": "bcbb0159-74c5-4f56-9d69-d85a58ddbd1a",
830    "metadata": {},
831    "outputs": [],
832    "source": [
833     "rnn.model_predict.get_weights()"
834    ]
835   },
836   {
837    "cell_type": "code",
838    "execution_count": null,
839    "id": "e8ed2b03-6123-4bdf-9e26-ef2ce4951663",
840    "metadata": {},
841    "outputs": [],
842    "source": [
843     "# NOTE(review): the RNN config at the top of this notebook uses 'recurrent_units'; confirm which key RNNParams defines\n",
844     "params['rnn_units']"
845    ]
846   },
847   {
848    "cell_type": "code",
849    "execution_count": null,
850    "id": "e44302bf-af49-4140-ae31-54f7c88a6735",
851    "metadata": {},
852    "outputs": [],
853    "source": [
854     "params.update({\n",
855     "    'phys_initialize': True,\n",
856     "    'scaler': None, # TODO\n",
857     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
858     "    'activation': ['relu', 'relu'], # TODO tanh, relu the same\n",
859     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
860     "})"
861    ]
862   },
863   {
864    "cell_type": "code",
865    "execution_count": null,
866    "id": "9a8ac32d-551c-43e8-988e-a3b13e6d9cd9",
867    "metadata": {},
868    "outputs": [],
869    "source": [
870     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
871     "rnn_dat2.train_test_split(\n",
872     "    time_fracs = [.8, .1, .1]\n",
873     ")\n",
874     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
875    ]
876   },
877   {
878    "cell_type": "code",
879    "execution_count": null,
880    "id": "ff727da8-38fb-4fda-999b-f712b98de0df",
881    "metadata": {
882     "scrolled": true
883    },
884    "outputs": [],
885    "source": [
886     "reproducibility.set_seed()\n",
887     "\n",
888     "rnn = RNN(params)\n",
889     "m, errs = rnn.run_model(rnn_dat2)"
890    ]
891   },
1006   {
1007    "cell_type": "markdown",
1008    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
1009    "metadata": {},
1010    "source": [
1011     "## LSTM\n",
1012     "\n",
1013     "TODO: the cells below need fixing."
1014    ]
1015   },
1016   {
1017    "cell_type": "code",
1018    "execution_count": null,
1019    "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
1020    "metadata": {},
1021    "outputs": [],
1022    "source": [
1023     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
1024     "params = RNNParams(params)"
1025    ]
1026   },
1027   {
1028    "cell_type": "code",
1029    "execution_count": null,
1030    "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
1031    "metadata": {},
1032    "outputs": [],
1033    "source": [
1034     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
1035     "rnn_dat.train_test_split(\n",
1036     "    time_fracs = [.8, .1, .1]\n",
1037     ")\n",
1038     "rnn_dat.scale_data()\n",
1039     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
1040    ]
1041   },
1042   {
1043    "cell_type": "code",
1044    "execution_count": null,
1045    "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
1046    "metadata": {},
1047    "outputs": [],
1048    "source": [
1049     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
1050     "# NOTE(review): 'rnn_layers'/'rnn_units' differ from the 'recurrent_layers'/'recurrent_units' keys used at the top of this notebook; confirm which RNNParams reads\n",
1051     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
1052     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
1053     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
1054     "reproducibility.set_seed(123)\n",
1055     "lstm = RNN_LSTM(params)\n",
1056     "\n",
1057     "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train,\n",
1058     "                    batch_size = params['batch_size'], epochs=params['epochs'],\n",
1059     "                    callbacks = [ResetStatesCallback(params),\n",
1060     "                                EarlyStoppingCallback(patience = 15)],\n",
1061     "                   validation_data = (rnn_dat.X_val, rnn_dat.y_val))"
1062    ]
1063   },
1064   {
1065    "cell_type": "code",
1066    "execution_count": null,
1067    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
1068    "metadata": {},
1069    "outputs": [],
1070    "source": [
1071     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
1072     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
1073     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours,\n",
1074     "              'early_stopping_patience': 25})\n",
1075     "reproducibility.set_seed(123)\n",
1076     "lstm = RNN_LSTM(params)\n",
1077     "m, errs = lstm.run_model(rnn_dat)"
1078    ]
1079   },
1080   {
1081    "cell_type": "code",
1082    "execution_count": null,
1083    "id": "00910bd2-f050-438c-ab3b-c793b83cb5f5",
1084    "metadata": {},
1085    "outputs": [],
1086    "source": [
1087     "rnn_dat.spatial"
1088    ]
1089   },
1090   {
1091    "cell_type": "code",
1092    "execution_count": null,
1093    "id": "236b33e3-e864-4453-be16-cf07338c4105",
1094    "metadata": {},
1095    "outputs": [],
1096    "source": [
1097     "params = RNNParams(read_yml(\"params.yaml\", subkey='lstm'))\n",
1098     "params"
1099    ]
1100   },
1101   {
1102    "cell_type": "code",
1103    "execution_count": null,
1104    "id": "fe2a484c-dc99-45a9-89fc-2f451bd719b5",
1105    "metadata": {},
1106    "outputs": [],
1107    "source": [
1108     "train = read_pkl(\"data/train.pkl\")"
1109    ]
1110   },
1111   {
1112    "cell_type": "code",
1113    "execution_count": null,
1114    "id": "07bfac87-a6d4-4dcc-8d11-adf83eafab76",
1115    "metadata": {},
1116    "outputs": [],
1117    "source": [
1118     "from itertools import islice\n",
1119     "train = {k: train[k] for k in islice(train, 100)}"
1120    ]
1121   },
1159   {
1160    "cell_type": "code",
1161    "execution_count": null,
1162    "id": "4e26099b-f760-4047-afec-9e751d24b7a6",
1163    "metadata": {},
1164    "outputs": [],
1165    "source": [
1166     "from data_funcs import combine_nested\n",
1167     "rnn_dat_sp = RNNData(\n",
1168     "    combine_nested(train), # input dictionary\n",
1169     "    scaler=\"standard\",  # data scaling type\n",
1170     "    features_list = params['features_list'] # features for predicting outcome\n",
1171     ")\n",
1172     "\n",
1173     "\n",
1174     "rnn_dat_sp.train_test_split(   \n",
1175     "    time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n",
1176     "    space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n",
1177     ")\n",
1178     "rnn_dat_sp.scale_data()\n",
1179     "\n",
1180     "rnn_dat_sp.batch_reshape(\n",
1181     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
1182     "    batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
1183     ")"
1184    ]
1185   },
1186   {
1187    "cell_type": "code",
1188    "execution_count": null,
1189    "id": "10738795-c83b-4da3-88ba-09278caa35f8",
1190    "metadata": {},
1191    "outputs": [],
1192    "source": [
1193     "params.update({\n",
1194     "    'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
1195     "})"
1196    ]
1197   },
1198   {
1199    "cell_type": "code",
1200    "execution_count": null,
1201    "id": "9c5d45cc-bcf0-4b6c-9c51-c4c790a2d9a5",
1202    "metadata": {},
1203    "outputs": [],
1204    "source": [
1205     "rnn_sp = RNN_LSTM(params)\n",
1206     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
1207    ]
1208   },
1209   {
1210    "cell_type": "code",
1211    "execution_count": null,
1212    "id": "ee332ccf-4e4a-4f66-b4d6-c079dbdb1411",
1213    "metadata": {},
1214    "outputs": [],
1215    "source": [
1216     "errs.mean()"
1217    ]
1218   },
1219   {
1220    "cell_type": "code",
1221    "execution_count": null,
1222    "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
1223    "metadata": {},
1224    "outputs": [],
1225    "source": []
1226   }
1227  ],
1228  "metadata": {
1229   "kernelspec": {
1230    "display_name": "Python 3 (ipykernel)",
1231    "language": "python",
1232    "name": "python3"
1233   },
1234   "language_info": {
1235    "codemirror_mode": {
1236     "name": "ipython",
1237     "version": 3
1238    },
1239    "file_extension": ".py",
1240    "mimetype": "text/x-python",
1241    "name": "python",
1242    "nbconvert_exporter": "python",
1243    "pygments_lexer": "ipython3",
1244    "version": "3.12.5"
1245   }
1246  },
1247  "nbformat": 4,
1248  "nbformat_minor": 5