BUG FIX: add needed imports to utils
[notebooks.git] / fmda / rnn_workshop.ipynb
blobc23b6db6a41a02279d47f55612ab0a69efdbbbc3
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# v2.1 exploration: trying to make the RNN work better"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import os\n",
20     "import os.path as osp\n",
21     "import numpy as np\n",
22     "import pandas as pd\n",
23     "import tensorflow as tf\n",
24     "import matplotlib.pyplot as plt\n",
25     "import sys\n",
26     "# Local modules\n",
27     "sys.path.append('..')\n",
28     "import reproducibility\n",
29     "import pandas as pd\n",
30     "from utils import print_dict_summary\n",
31     "from data_funcs import rmse\n",
32     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM, create_rnn_data2\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from tensorflow.keras.callbacks import Callback\n",
35     "from utils import hash2\n",
36     "import copy\n",
37     "import logging\n",
38     "import pickle\n",
39     "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights\n",
40     "import yaml\n",
41     "import copy"
42    ]
43   },
44   {
45    "cell_type": "code",
46    "execution_count": null,
47    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
48    "metadata": {},
49    "outputs": [],
50    "source": [
51     "logging_setup()"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "a42cc05f-1438-459f-9a15-64276aa2f651",
57    "metadata": {},
58    "source": [
59     "## Test New Repro File"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "d27c7277-c9e1-4fd1-b050-d4b6bc737822",
66    "metadata": {},
67    "outputs": [],
68    "source": [
69     "import importlib\n",
70     "import utils\n",
71     "importlib.reload(utils)\n",
72     "from utils import read_pkl"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "7a9a5414-d33f-4f0d-b1e5-b320b22d60f5",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "train = read_pkl(\"data/train.pkl\")"
83    ]
84   },
85   {
86    "cell_type": "code",
87    "execution_count": null,
88    "id": "e0147836-f6ba-4141-9c9d-7c2e5d676bc2",
89    "metadata": {},
90    "outputs": [],
91    "source": [
92     "params = read_yml(\"params.yaml\", subkey=\"rnn\")\n",
93     "params = RNNParams(params)"
94    ]
95   },
96   {
97    "cell_type": "code",
98    "execution_count": null,
99    "id": "4219f8e0-cf44-43a6-830b-fc859c3d954b",
100    "metadata": {},
101    "outputs": [],
102    "source": [
103     "params.update({'activation': ['linear', 'linear'], 'epochs':500, 'val_frac': .2, 'scaler': 'minmax'})"
104    ]
105   },
106   {
107    "cell_type": "code",
108    "execution_count": null,
109    "id": "722bd100-beaa-49c7-a1ab-b72765c89ebe",
110    "metadata": {},
111    "outputs": [],
112    "source": [
113     "rnn_dat = RNNData(train['CRVC1_202401'], scaler = params['scaler'], features_list = params['features_list'])"
114    ]
115   },
116   {
117    "cell_type": "code",
118    "execution_count": null,
119    "id": "2eb4adf9-c4eb-493c-9f62-59ba17f6da2f",
120    "metadata": {},
121    "outputs": [],
122    "source": [
123     "rnn_dat.train_test_split(\n",
124     "    train_frac = params['train_frac'],\n",
125     "    val_frac = params['val_frac']\n",
126     ")"
127    ]
128   },
129   {
130    "cell_type": "code",
131    "execution_count": null,
132    "id": "2b4fa4e1-c1b9-483a-83ac-cf0ee46662fa",
133    "metadata": {},
134    "outputs": [],
135    "source": [
136     "rnn_dat.scale_data()"
137    ]
138   },
139   {
140    "cell_type": "code",
141    "execution_count": null,
142    "id": "1f928e06-867e-4cc5-ab94-83b30b923374",
143    "metadata": {},
144    "outputs": [],
145    "source": [
146     "reproducibility.set_seed(123)\n",
147     "rnn = RNN(params)\n",
148     "m, errs = rnn.run_model(rnn_dat)"
149    ]
150   },
151   {
152    "cell_type": "code",
153    "execution_count": null,
154    "id": "b00a6796-e8c4-41f9-a5ca-f92d0c977b6c",
155    "metadata": {},
156    "outputs": [],
157    "source": []
158   },
159   {
160    "cell_type": "code",
161    "execution_count": null,
162    "id": "d5079a15-23ac-4099-9a3c-816531ca6fc7",
163    "metadata": {},
164    "outputs": [],
165    "source": []
166   },
167   {
168    "cell_type": "code",
169    "execution_count": null,
170    "id": "a3e8f1a9-43cd-4d8e-8f02-efb1865b3035",
171    "metadata": {},
172    "outputs": [],
173    "source": []
174   },
175   {
176    "cell_type": "markdown",
177    "id": "2298a1a1-b72c-4c7e-bcb6-2cdefe96fe3e",
178    "metadata": {},
179    "source": [
180     "## Test Data Creation"
181    ]
182   },
183   {
184    "cell_type": "code",
185    "execution_count": null,
186    "id": "c4645246-edce-4544-9809-5ffb0760ae25",
187    "metadata": {},
188    "outputs": [],
189    "source": [
190     "import importlib\n",
191     "import moisture_rnn_pkl\n",
192     "importlib.reload(moisture_rnn_pkl)\n",
193     "from moisture_rnn_pkl import pkl2train"
194    ]
195   },
196   {
197    "cell_type": "code",
198    "execution_count": null,
199    "id": "5b662edb-7a79-4532-b0d7-2492b1ad917d",
200    "metadata": {},
201    "outputs": [],
202    "source": [
203     "file_names=['test_CA_202401.pkl', 'test_NW_202401.pkl']\n",
204     "file_dir='data'\n",
205     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
206    ]
207   },
208   {
209    "cell_type": "code",
210    "execution_count": null,
211    "id": "1185c995-e9fa-4586-96c2-44b159ccf477",
212    "metadata": {
213     "scrolled": true
214    },
215    "outputs": [],
216    "source": [
217     "train = pkl2train(file_paths)"
218    ]
219   },
220   {
221    "cell_type": "code",
222    "execution_count": null,
223    "id": "665291be-0f40-46b5-9a63-27a58965f8ca",
224    "metadata": {},
225    "outputs": [],
226    "source": [
227     "train.keys()"
228    ]
229   },
230   {
231    "cell_type": "code",
232    "execution_count": null,
233    "id": "6b61a406-eed8-4595-9c3f-4c11e1aed7c8",
234    "metadata": {},
235    "outputs": [],
236    "source": []
237   },
238   {
239    "cell_type": "code",
240    "execution_count": null,
241    "id": "e234b0f6-3cc9-46d1-926a-d825c58e3991",
242    "metadata": {},
243    "outputs": [],
244    "source": []
245   },
246   {
247    "cell_type": "code",
248    "execution_count": null,
249    "id": "7fdf595c-68e1-4e93-a5ec-d6e20e2f1bdf",
250    "metadata": {},
251    "outputs": [],
252    "source": []
253   },
254   {
255    "cell_type": "code",
256    "execution_count": null,
257    "id": "fc3e8264-da29-4261-a560-ef457f42ed70",
258    "metadata": {},
259    "outputs": [],
260    "source": []
261   },
262   {
263    "cell_type": "code",
264    "execution_count": null,
265    "id": "7deda359-1e7f-447a-97b7-576b98712a74",
266    "metadata": {},
267    "outputs": [],
268    "source": []
269   },
270   {
271    "cell_type": "code",
272    "execution_count": null,
273    "id": "7fc05c26-9a54-4863-8956-d76913128701",
274    "metadata": {},
275    "outputs": [],
276    "source": []
277   },
278   {
279    "cell_type": "markdown",
280    "id": "2afc2cf7-eab1-4a85-8632-4d306aead358",
281    "metadata": {},
282    "source": [
283     "## Test RNN"
284    ]
285   },
286   {
287    "cell_type": "code",
288    "execution_count": null,
289    "id": "bfd419f0-9092-470d-81b7-d3b45e4bdc0b",
290    "metadata": {},
291    "outputs": [],
292    "source": []
293   },
294   {
295    "cell_type": "code",
296    "execution_count": null,
297    "id": "545ece65-9f4a-4b45-b87f-ea3a23032cac",
298    "metadata": {},
299    "outputs": [],
300    "source": []
301   },
302   {
303    "cell_type": "code",
304    "execution_count": null,
305    "id": "1e9ec6f9-8598-4560-b71e-222f5b4c4968",
306    "metadata": {},
307    "outputs": [],
308    "source": []
309   },
310   {
311    "cell_type": "code",
312    "execution_count": null,
313    "id": "e2a7840d-f7e4-424d-b343-06f913f9d3f6",
314    "metadata": {},
315    "outputs": [],
316    "source": []
317   },
318   {
319    "cell_type": "code",
320    "execution_count": null,
321    "id": "52e2942b-3bed-4c3d-8082-c7069d791036",
322    "metadata": {},
323    "outputs": [],
324    "source": []
325   },
326   {
327    "cell_type": "code",
328    "execution_count": null,
329    "id": "def73f2c-5d2f-42c6-8c2d-328ac5e8db20",
330    "metadata": {},
331    "outputs": [],
332    "source": []
333   },
334   {
335    "cell_type": "code",
336    "execution_count": null,
337    "id": "888dd72a-4eef-414b-ac33-f6f4bfbefe60",
338    "metadata": {},
339    "outputs": [],
340    "source": [
341     "errs"
342    ]
343   },
344   {
345    "cell_type": "code",
346    "execution_count": null,
347    "id": "7f40cdfd-b33a-43c1-8bc4-44a0ea6817ff",
348    "metadata": {},
349    "outputs": [],
350    "source": [
351     "import importlib \n",
352     "import moisture_rnn\n",
353     "importlib.reload(moisture_rnn)\n",
354     "from moisture_rnn import RNN"
355    ]
356   },
357   {
358    "cell_type": "code",
359    "execution_count": null,
360    "id": "bdf0ba2e-f944-4c86-a20e-a59e023897cb",
361    "metadata": {},
362    "outputs": [],
363    "source": [
364     "params = read_yml(\"params.yaml\", subkey=\"rnn\")\n",
365     "params = RNNParams(params)"
366    ]
367   },
368   {
369    "cell_type": "code",
370    "execution_count": null,
371    "id": "9dbd51b0-9342-4b90-a250-0ac2c75d3066",
372    "metadata": {},
373    "outputs": [],
374    "source": [
375     "reproducibility.set_seed()\n",
376     "rnn = RNN(params)\n",
377     "m, errs = rnn.run_model(rnn_dat)"
378    ]
379   },
380   {
381    "cell_type": "code",
382    "execution_count": null,
383    "id": "c6d7d34c-dfae-4370-a398-a287790eff53",
384    "metadata": {},
385    "outputs": [],
386    "source": []
387   },
388   {
389    "cell_type": "markdown",
390    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
391    "metadata": {},
392    "source": [
393     "## LSTM\n",
394     "\n",
395     "TODO: FIX BELOW"
396    ]
397   },
398   {
399    "cell_type": "code",
400    "execution_count": null,
401    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
402    "metadata": {},
403    "outputs": [],
404    "source": [
405     "import importlib \n",
406     "import moisture_rnn\n",
407     "importlib.reload(moisture_rnn)\n",
408     "from moisture_rnn import RNN_LSTM"
409    ]
410   },
411   {
412    "cell_type": "code",
413    "execution_count": null,
414    "id": "59480f19-3567-4b24-b6ff-d9292dc8c2ec",
415    "metadata": {},
416    "outputs": [],
417    "source": [
418     "with open(\"params.yaml\") as file:\n",
419     "    params = yaml.safe_load(file)[\"lstm\"]\n",
420     "case = 'CRVC1_202401'  # same case key as the RNN test above; confirm it is in train.keys()\n",
421     "rnn_dat2 = create_rnn_data2(train[case], params)"
422    ]
423   },
424   {
425    "cell_type": "code",
426    "execution_count": null,
427    "id": "2adff592-7aa4-4e59-a229-cad4a133297e",
428    "metadata": {},
429    "outputs": [],
430    "source": [
431     "params.update({'epochs': 10})"
432    ]
433   },
434   {
435    "cell_type": "code",
436    "execution_count": null,
437    "id": "b20539f0-eed2-44de-9269-ae8696c8e7c8",
438    "metadata": {},
439    "outputs": [],
440    "source": []
441   },
442   {
443    "cell_type": "code",
444    "execution_count": null,
445    "id": "6bfbcbb5-b631-4594-9ae5-618c4fe68e7b",
446    "metadata": {},
447    "outputs": [],
448    "source": [
449     "reproducibility.set_seed()\n",
450     "lstm = RNN_LSTM(params)\n",
451     "m, errs = lstm.run_model(rnn_dat2)"
452    ]
453   },
454   {
455    "cell_type": "code",
456    "execution_count": null,
457    "id": "dd8a9700-f479-4c11-8655-ca7b45222402",
458    "metadata": {},
459    "outputs": [],
460    "source": []
461   },
462   {
463    "cell_type": "code",
464    "execution_count": null,
465    "id": "de46c481-74a7-46cc-8334-678ad8230cce",
466    "metadata": {},
467    "outputs": [],
468    "source": [
469     "import importlib\n",
470     "importlib.reload(moisture_rnn)\n",
471     "from moisture_rnn import RNN_LSTM"
472    ]
473   },
474   {
475    "cell_type": "code",
476    "execution_count": null,
477    "id": "2b6a699a-68e8-49ef-95f2-409137502fb6",
478    "metadata": {},
479    "outputs": [],
480    "source": [
481     "with open(\"params.yaml\") as file:\n",
482     "    params = yaml.safe_load(file)[\"lstm\"]\n",
483     "case = 'CRVC1_202401'  # same case key as the RNN test above; confirm it is in train.keys()\n",
484     "rnn_dat2 = create_rnn_data2(train[case], params)\n",
485     "params"
486    ]
487   },
488   {
489    "cell_type": "code",
490    "execution_count": null,
491    "id": "188c0d5d-f3f6-4a61-83b0-b21dfc5d01b7",
492    "metadata": {},
493    "outputs": [],
494    "source": [
495     "params.update({\n",
496     "    'learning_rate': 0.000001,\n",
497     "    'epochs': 10,\n",
498     "    'clipvalue':1.0\n",
499     "})"
500    ]
501   },
502   {
503    "cell_type": "code",
504    "execution_count": null,
505    "id": "6a9d612e-8cd2-40ca-a789-91c99c3d6ccd",
506    "metadata": {},
507    "outputs": [],
508    "source": [
509     "reproducibility.set_seed()\n",
510     "lstm = RNN_LSTM(params)\n",
511     "m, errs = lstm.run_model(rnn_dat2)"
512    ]
513   },
514   {
515    "cell_type": "code",
516    "execution_count": null,
517    "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
518    "metadata": {},
519    "outputs": [],
520    "source": []
521   },
522   {
523    "cell_type": "code",
524    "execution_count": null,
525    "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
526    "metadata": {},
527    "outputs": [],
528    "source": []
529   },
530   {
531    "cell_type": "code",
532    "execution_count": null,
533    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
534    "metadata": {},
535    "outputs": [],
536    "source": []
537   },
538   {
539    "cell_type": "code",
540    "execution_count": null,
541    "id": "f60a24c6-9a67-45aa-bc5c-8818aa0ca049",
542    "metadata": {},
543    "outputs": [],
544    "source": []
545   }
546  ],
547  "metadata": {
548   "kernelspec": {
549    "display_name": "Python 3 (ipykernel)",
550    "language": "python",
551    "name": "python3"
552   },
553   "language_info": {
554    "codemirror_mode": {
555     "name": "ipython",
556     "version": 3
557    },
558    "file_extension": ".py",
559    "mimetype": "text/x-python",
560    "name": "python",
561    "nbconvert_exporter": "python",
562    "pygments_lexer": "ipython3",
563    "version": "3.12.5"
564   }
565  },
566  "nbformat": 4,
567  "nbformat_minor": 5