Create Batch Reset Hyperparameter tutorial notebook
[notebooks.git] / fmda / rnn_workshop.ipynb
blob6ceb5b0a307e5bab823cf196476ab7e2defd3bb8
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.1 exploration: experiments to make the RNN and LSTM models work better"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import copy\n",
20     "import logging\n",
21     "import os\n",
22     "import os.path as osp\n",
23     "import pickle\n",
24     "import sys\n",
25     "\n",
26     "import numpy as np\n",
27     "import pandas as pd\n",
28     "import tensorflow as tf\n",
29     "import matplotlib.pyplot as plt\n",
30     "import yaml\n",
31     "\n",
32     "# Local modules (made importable by adding the parent directory to sys.path)\n",
33     "sys.path.append('..')\n",
34     "import reproducibility\n",
35     "from utils import print_dict_summary, hash2, logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights\n",
36     "from data_funcs import rmse\n",
37     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, create_rnn_data2\n",
38     "from moisture_rnn_pkl import pkl2train\n",
39     "from tensorflow.keras.callbacks import Callback"
42    ]
43   },
44   {
45    "cell_type": "code",
46    "execution_count": null,
47    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
48    "metadata": {},
49    "outputs": [],
50    "source": [
51     "logging_setup()"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "b62f4360-e9d1-4510-bb5d-1d79a3a5ac75",
57    "metadata": {},
58    "source": [
59     "## Tests"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "91155de8-1927-4485-80c9-095fb4d8613d",
66    "metadata": {
67     "scrolled": true
68    },
69    "outputs": [],
70    "source": [
71     "# Build the training cases once and cache them in train.pkl;\n",
72     "# later runs load the cache instead of re-running pkl2train.\n",
73     "# NOTE(review): assumes the existing train.pkl was built from data_file alone; confirm.\n",
74     "data_file = 'data/test_CA_202401.pkl'\n",
75     "train_file = 'train.pkl'\n",
76     "if osp.exists(train_file):\n",
77     "    train = read_pkl(train_file)\n",
78     "else:\n",
79     "    train = pkl2train([data_file])\n",
80     "    with open(train_file, 'wb') as file:\n",
81     "        logging.info('Writing the train cases into file %s', train_file)\n",
82     "        pickle.dump(train, file)"
79    ]
80   },
81   {
82    "cell_type": "code",
83    "execution_count": null,
84    "id": "a5979f46-7e1a-4552-a92f-fc17d2ca75b4",
85    "metadata": {},
86    "outputs": [],
87    "source": [
88     "train.keys()"
89    ]
90   },
91   {
92    "cell_type": "code",
93    "execution_count": null,
94    "id": "1bc6a3c7-e068-4d83-9ab0-131eaeec289e",
95    "metadata": {},
96    "outputs": [],
97    "source": [
98     "case = \"FBFW1_202401\"\n",
99     "dat = train[case]"
100    ]
101   },
102   {
103    "cell_type": "code",
104    "execution_count": null,
105    "id": "84a4e8b6-9762-4406-b71e-bb6bf990fd9b",
106    "metadata": {},
107    "outputs": [],
108    "source": [
109     "params = read_yml(\"params.yaml\", subkey=\"rnn\")\n",
110     "params = RNNParams(params)\n",
111     "params"
112    ]
113   },
114   {
115    "cell_type": "code",
116    "execution_count": null,
117    "id": "92b7a5ed-5215-43a7-933c-5dd146241944",
118    "metadata": {},
119    "outputs": [],
120    "source": [
121     "rnn_dat = RNNData(dat, scaler = params['scaler'], features_list = params['features_list'])\n",
122     "rnn_dat.train_test_split(\n",
123     "    train_frac = .9,\n",
124     "    val_frac = .05\n",
125     ")\n",
126     "rnn_dat.scale_data()"
127    ]
128   },
129   {
130    "cell_type": "code",
131    "execution_count": null,
132    "id": "9141a034-5912-4087-aaaa-a4c221e77495",
133    "metadata": {},
134    "outputs": [],
135    "source": [
136     "import importlib\n",
137     "import moisture_rnn\n",
138     "importlib.reload(moisture_rnn)\n",
139     "from moisture_rnn import RNN"
140    ]
141   },
142   {
143    "cell_type": "code",
144    "execution_count": null,
145    "id": "2c570a67-0b67-4fe4-a536-68cf5a1f8256",
146    "metadata": {},
147    "outputs": [],
148    "source": [
149     "params.update({'epochs': 20})\n",
150     "reproducibility.set_seed(123)\n",
151     "rnn = RNN(params)\n"
152    ]
153   },
154   {
155    "cell_type": "code",
156    "execution_count": null,
157    "id": "034300ff-a71e-4807-85d7-368197962ab5",
158    "metadata": {},
159    "outputs": [],
160    "source": [
161     "m, errs = rnn.run_model(rnn_dat)"
162    ]
163   },
164   {
165    "cell_type": "code",
166    "execution_count": null,
167    "id": "1f5183e0-5e58-4ff3-9a1f-0f707241a698",
168    "metadata": {},
169    "outputs": [],
170    "source": [
171     "rnn.model_train"
172    ]
173   },
174   {
175    "cell_type": "code",
176    "execution_count": null,
177    "id": "0dc33ddc-cb10-4b0c-b7f2-af3302e28a23",
178    "metadata": {},
179    "outputs": [],
180    "source": []
181   },
182   {
183    "cell_type": "code",
184    "execution_count": null,
185    "id": "ebb4d0d3-0a9e-428e-bd56-6899bde36df7",
186    "metadata": {},
187    "outputs": [],
188    "source": []
189   },
190   {
191    "cell_type": "markdown",
192    "id": "a42cc05f-1438-459f-9a15-64276aa2f651",
193    "metadata": {},
194    "source": [
195     "## Test RNN"
196    ]
197   },
198   {
199    "cell_type": "code",
200    "execution_count": null,
201    "id": "d27c7277-c9e1-4fd1-b050-d4b6bc737822",
202    "metadata": {},
203    "outputs": [],
204    "source": [
205     "import importlib\n",
206     "import utils\n",
207     "importlib.reload(utils)\n",
208     "from utils import read_pkl"
209    ]
210   },
211   {
212    "cell_type": "code",
213    "execution_count": null,
214    "id": "7a9a5414-d33f-4f0d-b1e5-b320b22d60f5",
215    "metadata": {},
216    "outputs": [],
217    "source": [
218     "train = read_pkl(\"train.pkl\")"
219    ]
220   },
221   {
222    "cell_type": "code",
223    "execution_count": null,
224    "id": "e0147836-f6ba-4141-9c9d-7c2e5d676bc2",
225    "metadata": {},
226    "outputs": [],
227    "source": [
228     "params = read_yml(\"params.yaml\", subkey=\"rnn\")\n",
229     "params = RNNParams(params)"
230    ]
231   },
232   {
233    "cell_type": "code",
234    "execution_count": null,
235    "id": "4219f8e0-cf44-43a6-830b-fc859c3d954b",
236    "metadata": {},
237    "outputs": [],
238    "source": [
239     "params.update({'activation': ['linear', 'linear'], 'epochs':300, 'val_frac': .2, 'scaler': 'minmax', 'rnn_layers': 1, 'dense_layers': 1})"
240    ]
241   },
242   {
243    "cell_type": "code",
244    "execution_count": null,
245    "id": "722bd100-beaa-49c7-a1ab-b72765c89ebe",
246    "metadata": {},
247    "outputs": [],
248    "source": [
249     "rnn_dat = RNNData(train['NV020_202401'], scaler = params['scaler'], features_list = params['features_list'])"
250    ]
251   },
252   {
253    "cell_type": "code",
254    "execution_count": null,
255    "id": "2eb4adf9-c4eb-493c-9f62-59ba17f6da2f",
256    "metadata": {},
257    "outputs": [],
258    "source": [
259     "rnn_dat.train_test_split(\n",
260     "    train_frac = params['train_frac'],\n",
261     "    val_frac = params['val_frac']\n",
262     ")"
263    ]
264   },
265   {
266    "cell_type": "code",
267    "execution_count": null,
268    "id": "2b4fa4e1-c1b9-483a-83ac-cf0ee46662fa",
269    "metadata": {},
270    "outputs": [],
271    "source": [
272     "rnn_dat.scale_data()"
273    ]
274   },
275   {
276    "cell_type": "code",
277    "execution_count": null,
278    "id": "fc362aa4-fe28-4848-9575-bb17f72ac9fd",
279    "metadata": {},
280    "outputs": [],
281    "source": [
282     "import importlib\n",
283     "import moisture_rnn\n",
284     "importlib.reload(moisture_rnn)\n",
285     "from moisture_rnn import RNN"
286    ]
287   },
288   {
289    "cell_type": "code",
290    "execution_count": null,
291    "id": "79381a0b-2338-4a09-876e-91e50b968d3f",
292    "metadata": {},
293    "outputs": [],
294    "source": [
295     "rnn = RNN(params)"
296    ]
297   },
298   {
299    "cell_type": "code",
300    "execution_count": null,
301    "id": "155cebc4-e7f6-47ae-943e-556c3939ab95",
302    "metadata": {},
303    "outputs": [],
304    "source": [
305     "rnn.predict(rnn_dat.X_test)"
306    ]
307   },
308   {
309    "cell_type": "code",
310    "execution_count": null,
311    "id": "7d28dd8c-3d90-43ba-842e-4d4ed5deb823",
312    "metadata": {},
313    "outputs": [],
314    "source": [
315     "rnn.model_predict.summary()"
316    ]
317   },
318   {
319    "cell_type": "code",
320    "execution_count": null,
321    "id": "1f928e06-867e-4cc5-ab94-83b30b923374",
322    "metadata": {},
323    "outputs": [],
324    "source": [
325     "reproducibility.set_seed(123)\n",
326     "rnn = RNN(params)\n",
327     "m, errs = rnn.run_model(rnn_dat)"
328    ]
329   },
330   {
331    "cell_type": "code",
332    "execution_count": null,
333    "id": "4b4ca16e-7c40-4bb9-b971-dd0efa4e8a83",
334    "metadata": {},
335    "outputs": [],
336    "source": []
337   },
338   {
339    "cell_type": "code",
340    "execution_count": null,
341    "id": "54c917e9-20a6-4b8d-b6ab-04fdb0333467",
342    "metadata": {},
343    "outputs": [],
344    "source": []
345   },
346   {
347    "cell_type": "code",
348    "execution_count": null,
349    "id": "7c659050-e74e-4f07-b95a-3f7b57653061",
350    "metadata": {},
351    "outputs": [],
352    "source": []
353   },
354   {
355    "cell_type": "code",
356    "execution_count": null,
357    "id": "e86c9e4d-4ccd-4d9d-92e1-2e4299549fa4",
358    "metadata": {},
359    "outputs": [],
360    "source": []
361   },
362   {
363    "cell_type": "code",
364    "execution_count": null,
365    "id": "966c3559-740d-44d3-b98d-cc2efe63afcd",
366    "metadata": {},
367    "outputs": [],
368    "source": []
369   },
370   {
371    "cell_type": "code",
372    "execution_count": null,
373    "id": "52c1df9f-87ab-4882-aca1-e90cca7bd470",
374    "metadata": {},
375    "outputs": [],
376    "source": []
377   },
378   {
379    "cell_type": "code",
380    "execution_count": null,
381    "id": "1addf015-b2c1-42df-a769-39df81dd5d14",
382    "metadata": {},
383    "outputs": [],
384    "source": []
385   },
386   {
387    "cell_type": "code",
388    "execution_count": null,
389    "id": "3262a73f-0bd7-4e78-ade0-ca53a7da2b84",
390    "metadata": {},
391    "outputs": [],
392    "source": []
393   },
394   {
395    "cell_type": "code",
396    "execution_count": null,
397    "id": "eecfd38b-0a1a-4de3-b568-2f53ffbcc78c",
398    "metadata": {},
399    "outputs": [],
400    "source": []
401   },
402   {
403    "cell_type": "markdown",
404    "id": "2298a1a1-b72c-4c7e-bcb6-2cdefe96fe3e",
405    "metadata": {},
406    "source": [
407     "## Test Data Creation"
408    ]
409   },
410   {
411    "cell_type": "code",
412    "execution_count": null,
413    "id": "c4645246-edce-4544-9809-5ffb0760ae25",
414    "metadata": {},
415    "outputs": [],
416    "source": [
417     "import importlib\n",
418     "import moisture_rnn_pkl\n",
419     "importlib.reload(moisture_rnn_pkl)\n",
420     "from moisture_rnn_pkl import pkl2train"
421    ]
422   },
423   {
424    "cell_type": "code",
425    "execution_count": null,
426    "id": "5b662edb-7a79-4532-b0d7-2492b1ad917d",
427    "metadata": {},
428    "outputs": [],
429    "source": [
430     "file_names=['test_CA_202401.pkl', 'test_NW_202401.pkl']\n",
431     "file_dir='data'\n",
432     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
433    ]
434   },
435   {
436    "cell_type": "code",
437    "execution_count": null,
438    "id": "1185c995-e9fa-4586-96c2-44b159ccf477",
439    "metadata": {
440     "scrolled": true
441    },
442    "outputs": [],
443    "source": [
444     "train = pkl2train(file_paths)"
445    ]
446   },
447   {
448    "cell_type": "code",
449    "execution_count": null,
450    "id": "665291be-0f40-46b5-9a63-27a58965f8ca",
451    "metadata": {},
452    "outputs": [],
453    "source": [
454     "train.keys()"
455    ]
456   },
457   {
458    "cell_type": "code",
459    "execution_count": null,
460    "id": "6b61a406-eed8-4595-9c3f-4c11e1aed7c8",
461    "metadata": {},
462    "outputs": [],
463    "source": []
464   },
465   {
466    "cell_type": "code",
467    "execution_count": null,
468    "id": "e234b0f6-3cc9-46d1-926a-d825c58e3991",
469    "metadata": {},
470    "outputs": [],
471    "source": []
472   },
473   {
474    "cell_type": "code",
475    "execution_count": null,
476    "id": "7fdf595c-68e1-4e93-a5ec-d6e20e2f1bdf",
477    "metadata": {},
478    "outputs": [],
479    "source": []
480   },
481   {
482    "cell_type": "code",
483    "execution_count": null,
484    "id": "fc3e8264-da29-4261-a560-ef457f42ed70",
485    "metadata": {},
486    "outputs": [],
487    "source": []
488   },
489   {
490    "cell_type": "code",
491    "execution_count": null,
492    "id": "7deda359-1e7f-447a-97b7-576b98712a74",
493    "metadata": {},
494    "outputs": [],
495    "source": []
496   },
497   {
498    "cell_type": "code",
499    "execution_count": null,
500    "id": "7fc05c26-9a54-4863-8956-d76913128701",
501    "metadata": {},
502    "outputs": [],
503    "source": []
504   },
505   {
506    "cell_type": "markdown",
507    "id": "2afc2cf7-eab1-4a85-8632-4d306aead358",
508    "metadata": {},
509    "source": [
510     "## Test RNN with unmodified params.yaml settings"
511    ]
512   },
513   {
514    "cell_type": "code",
515    "execution_count": null,
516    "id": "bfd419f0-9092-470d-81b7-d3b45e4bdc0b",
517    "metadata": {},
518    "outputs": [],
519    "source": []
520   },
521   {
522    "cell_type": "code",
523    "execution_count": null,
524    "id": "545ece65-9f4a-4b45-b87f-ea3a23032cac",
525    "metadata": {},
526    "outputs": [],
527    "source": []
528   },
529   {
530    "cell_type": "code",
531    "execution_count": null,
532    "id": "1e9ec6f9-8598-4560-b71e-222f5b4c4968",
533    "metadata": {},
534    "outputs": [],
535    "source": []
536   },
537   {
538    "cell_type": "code",
539    "execution_count": null,
540    "id": "e2a7840d-f7e4-424d-b343-06f913f9d3f6",
541    "metadata": {},
542    "outputs": [],
543    "source": []
544   },
545   {
546    "cell_type": "code",
547    "execution_count": null,
548    "id": "52e2942b-3bed-4c3d-8082-c7069d791036",
549    "metadata": {},
550    "outputs": [],
551    "source": []
552   },
553   {
554    "cell_type": "code",
555    "execution_count": null,
556    "id": "def73f2c-5d2f-42c6-8c2d-328ac5e8db20",
557    "metadata": {},
558    "outputs": [],
559    "source": []
560   },
561   {
562    "cell_type": "code",
563    "execution_count": null,
564    "id": "888dd72a-4eef-414b-ac33-f6f4bfbefe60",
565    "metadata": {},
566    "outputs": [],
567    "source": [
568     "errs"
569    ]
570   },
571   {
572    "cell_type": "code",
573    "execution_count": null,
574    "id": "7f40cdfd-b33a-43c1-8bc4-44a0ea6817ff",
575    "metadata": {},
576    "outputs": [],
577    "source": [
578     "import importlib \n",
579     "import moisture_rnn\n",
580     "importlib.reload(moisture_rnn)\n",
581     "from moisture_rnn import RNN"
582    ]
583   },
584   {
585    "cell_type": "code",
586    "execution_count": null,
587    "id": "bdf0ba2e-f944-4c86-a20e-a59e023897cb",
588    "metadata": {},
589    "outputs": [],
590    "source": [
591     "params = read_yml(\"params.yaml\", subkey=\"rnn\")\n",
592     "params = RNNParams(params)"
593    ]
594   },
595   {
596    "cell_type": "code",
597    "execution_count": null,
598    "id": "9dbd51b0-9342-4b90-a250-0ac2c75d3066",
599    "metadata": {},
600    "outputs": [],
601    "source": [
602     "reproducibility.set_seed()\n",
603     "rnn = RNN(params)\n",
604     "m, errs = rnn.run_model(rnn_dat)"
605    ]
606   },
607   {
608    "cell_type": "code",
609    "execution_count": null,
610    "id": "c6d7d34c-dfae-4370-a398-a287790eff53",
611    "metadata": {},
612    "outputs": [],
613    "source": []
614   },
615   {
616    "cell_type": "markdown",
617    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
618    "metadata": {},
619    "source": [
620     "## LSTM\n",
621     "\n",
622     "**TODO:** fix the cells below."
623    ]
624   },
625   {
626    "cell_type": "code",
627    "execution_count": null,
628    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
629    "metadata": {},
630    "outputs": [],
631    "source": [
632     "import importlib \n",
633     "import moisture_rnn\n",
634     "importlib.reload(moisture_rnn)\n",
635     "from moisture_rnn import RNN_LSTM"
636    ]
637   },
638   {
639    "cell_type": "code",
640    "execution_count": null,
641    "id": "59480f19-3567-4b24-b6ff-d9292dc8c2ec",
642    "metadata": {},
643    "outputs": [],
644    "source": [
645     "with open(\"params.yaml\") as file:\n",
646     "    params = yaml.safe_load(file)[\"lstm\"]\n",
647     "    \n",
648     "rnn_dat2 = create_rnn_data2(train[case],params)"
649    ]
650   },
651   {
652    "cell_type": "code",
653    "execution_count": null,
654    "id": "2adff592-7aa4-4e59-a229-cad4a133297e",
655    "metadata": {},
656    "outputs": [],
657    "source": [
658     "params.update({'epochs': 10})"
659    ]
660   },
661   {
662    "cell_type": "code",
663    "execution_count": null,
664    "id": "b20539f0-eed2-44de-9269-ae8696c8e7c8",
665    "metadata": {},
666    "outputs": [],
667    "source": []
668   },
669   {
670    "cell_type": "code",
671    "execution_count": null,
672    "id": "6bfbcbb5-b631-4594-9ae5-618c4fe68e7b",
673    "metadata": {},
674    "outputs": [],
675    "source": [
676     "# params holds the \"lstm\" section of params.yaml, so build the LSTM model class\n",
677     "reproducibility.set_seed()\n",
678     "lstm = RNN_LSTM(params)\n",
679     "m, errs = lstm.run_model(rnn_dat2)"
679    ]
680   },
681   {
682    "cell_type": "code",
683    "execution_count": null,
684    "id": "dd8a9700-f479-4c11-8655-ca7b45222402",
685    "metadata": {},
686    "outputs": [],
687    "source": []
688   },
689   {
690    "cell_type": "code",
691    "execution_count": null,
692    "id": "de46c481-74a7-46cc-8334-678ad8230cce",
693    "metadata": {},
694    "outputs": [],
695    "source": [
696     "import importlib\n",
697     "importlib.reload(moisture_rnn)\n",
698     "from moisture_rnn import RNN_LSTM"
699    ]
700   },
701   {
702    "cell_type": "code",
703    "execution_count": null,
704    "id": "2b6a699a-68e8-49ef-95f2-409137502fb6",
705    "metadata": {},
706    "outputs": [],
707    "source": [
708     "with open(\"params.yaml\") as file:\n",
709     "    params = yaml.safe_load(file)[\"lstm\"]\n",
710     "\n",
711     "rnn_dat2 = create_rnn_data2(train[case],params)\n",
712     "params"
713    ]
714   },
715   {
716    "cell_type": "code",
717    "execution_count": null,
718    "id": "188c0d5d-f3f6-4a61-83b0-b21dfc5d01b7",
719    "metadata": {},
720    "outputs": [],
721    "source": [
722     "params.update({\n",
723     "    'learning_rate': 0.000001,\n",
724     "    'epochs': 10,\n",
725     "    'clipvalue':1.0\n",
726     "})"
727    ]
728   },
729   {
730    "cell_type": "code",
731    "execution_count": null,
732    "id": "6a9d612e-8cd2-40ca-a789-91c99c3d6ccd",
733    "metadata": {},
734    "outputs": [],
735    "source": [
736     "reproducibility.set_seed()\n",
737     "lstm = RNN_LSTM(params)\n",
738     "m, errs = lstm.run_model(rnn_dat2)"
739    ]
740   },
741   {
742    "cell_type": "code",
743    "execution_count": null,
744    "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
745    "metadata": {},
746    "outputs": [],
747    "source": []
748   },
749   {
750    "cell_type": "code",
751    "execution_count": null,
752    "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
753    "metadata": {},
754    "outputs": [],
755    "source": []
756   },
757   {
758    "cell_type": "code",
759    "execution_count": null,
760    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
761    "metadata": {},
762    "outputs": [],
763    "source": []
764   },
765   {
766    "cell_type": "code",
767    "execution_count": null,
768    "id": "f60a24c6-9a67-45aa-bc5c-8818aa0ca049",
769    "metadata": {},
770    "outputs": [],
771    "source": []
772   }
773  ],
774  "metadata": {
775   "kernelspec": {
776    "display_name": "Python 3 (ipykernel)",
777    "language": "python",
778    "name": "python3"
779   },
780   "language_info": {
781    "codemirror_mode": {
782     "name": "ipython",
783     "version": 3
784    },
785    "file_extension": ".py",
786    "mimetype": "text/x-python",
787    "name": "python",
788    "nbconvert_exporter": "python",
789    "pygments_lexer": "ipython3",
790    "version": "3.12.5"
791   }
792  },
793  "nbformat": 4,
794  "nbformat_minor": 5