Update rnn_workshop.ipynb
[notebooks.git] / fmda / rnn_workshop.ipynb
blob8690fbeac66351f2dab0fc80d4e13b945310f36b
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "244c2fb0-4339-476c-a2db-a641e124e25a",
6    "metadata": {},
7    "source": [
8     "# RNN workshop (v2.1): exploration aimed at making the model work better"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import os\n",
20     "import os.path as osp\n",
21     "import numpy as np\n",
22     "import pandas as pd\n",
23     "import tensorflow as tf\n",
24     "import matplotlib.pyplot as plt\n",
25     "import sys\n",
26     "# Local modules\n",
27     "sys.path.append('..')\n",
28     "import reproducibility\n",
29     "import pandas as pd\n",
30     "from utils import print_dict_summary\n",
31     "from data_funcs import rmse, process_train_dict\n",
32     "from moisture_rnn import RNNParams, RNNData, RNN, RNN_LSTM\n",
33     "from moisture_rnn_pkl import pkl2train\n",
34     "from tensorflow.keras.callbacks import Callback\n",
35     "from utils import hash2\n",
36     "import copy\n",
37     "import logging\n",
38     "import pickle\n",
39     "from utils import logging_setup, read_yml, read_pkl, hash_ndarray, hash_weights\n",
40     "import yaml\n",
41     "import copy"
42    ]
43   },
44   {
45    "cell_type": "code",
46    "execution_count": null,
47    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
48    "metadata": {},
49    "outputs": [],
50    "source": [
51     "logging_setup()"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "fae67b50-f916-45a7-bcc7-61995ba39449",
57    "metadata": {},
58    "source": [
59     "## Tests"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "12992b9a-407f-4131-ac61-e1dc338386bf",
66    "metadata": {},
67    "outputs": [],
68    "source": [
69     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))\n",
70     "params"
71    ]
72   },
73   {
74    "cell_type": "code",
75    "execution_count": null,
76    "id": "f214fdf8-bb76-4912-8f8c-5d0c8c1230c2",
77    "metadata": {},
78    "outputs": [],
79    "source": [
80     "dat = read_pkl(\"data/train.pkl\")"
81    ]
82   },
83   {
84    "cell_type": "code",
85    "execution_count": null,
86    "id": "888b7805-15f6-4c09-a05b-7aed7d253f6e",
87    "metadata": {},
88    "outputs": [],
89    "source": [
90     "cases = [*dat.keys()]"
91    ]
92   },
93   {
94    "cell_type": "code",
95    "execution_count": null,
96    "id": "375055d8-c070-4639-9561-e47d3f21f1f8",
97    "metadata": {},
98    "outputs": [],
99    "source": [
100     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
101     "rnn_dat.train_test_split(\n",
102     "    time_fracs = [.8, .1, .1]\n",
103     ")\n",
104     "rnn_dat.scale_data()\n",
105     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
106    ]
107   },
108   {
109    "cell_type": "code",
110    "execution_count": null,
111    "id": "5488628e-4552-4909-83e9-413fd6878bdd",
112    "metadata": {},
113    "outputs": [],
114    "source": [
115     "params.update({\n",
116     "    'epochs':100,\n",
117     "    'dense_layers': 0,\n",
118     "    'activation': ['relu', 'relu'],\n",
119     "    'phys_initialize': False,\n",
120     "    'dropout': [0,0]\n",
121     "})"
122    ]
123   },
124   {
125    "cell_type": "code",
126    "execution_count": null,
127    "id": "56bdf26c-07e7-4e4a-a567-af7dd0f564d9",
128    "metadata": {},
129    "outputs": [],
130    "source": [
131     "reproducibility.set_seed()\n",
132     "rnn = RNN(params)\n",
133     "m, errs = rnn.run_model(rnn_dat)"
134    ]
135   },
136   {
137    "cell_type": "code",
138    "execution_count": null,
139    "id": "01227b79-98f3-4931-bdfc-ff08afa8be5f",
140    "metadata": {},
141    "outputs": [],
142    "source": [
143     "rnn.model_train.summary()"
144    ]
145   },
146   {
147    "cell_type": "code",
148    "execution_count": null,
149    "id": "918a8bf0-638b-4b4b-82fe-c6a1965a72dd",
150    "metadata": {},
151    "outputs": [],
152    "source": []
153   },
154   {
155    "cell_type": "code",
156    "execution_count": null,
157    "id": "0aab34c7-8a09-480a-9d3e-619f7cf82b34",
158    "metadata": {},
159    "outputs": [],
160    "source": [
161     "params.update({\n",
162     "    'phys_initialize': True,\n",
163     "    'scaler': None, # TODO\n",
164     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
165     "    'activation': ['linear', 'linear'], # TODO tanh, relu the same\n",
166     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
167     "})"
168    ]
169   },
170   {
171    "cell_type": "code",
172    "execution_count": null,
173    "id": "ab549075-f71f-42ad-b36f-3d1e90247e33",
174    "metadata": {},
175    "outputs": [],
176    "source": [
177     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
178     "rnn_dat2.train_test_split(\n",
179     "    time_fracs = [.8, .1, .1]\n",
180     ")\n",
181     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
182    ]
183   },
184   {
185    "cell_type": "code",
186    "execution_count": null,
187    "id": "195f337a-ac8a-4471-8226-94863b9385e2",
188    "metadata": {},
189    "outputs": [],
190    "source": [
191     "import importlib\n",
192     "import moisture_rnn\n",
193     "importlib.reload(moisture_rnn)\n",
194     "from moisture_rnn import RNN, RNNData"
195    ]
196   },
197   {
198    "cell_type": "code",
199    "execution_count": null,
200    "id": "9395d147-17a5-44ba-aaa2-a213ffde062b",
201    "metadata": {
202     "scrolled": true
203    },
204    "outputs": [],
205    "source": [
206     "reproducibility.set_seed()\n",
207     "\n",
208     "rnn = RNN(params)"
209    ]
210   },
211   {
212    "cell_type": "code",
213    "execution_count": null,
214    "id": "d3eebe8a-ff12-454b-81b6-6a138924f127",
215    "metadata": {},
216    "outputs": [],
217    "source": [
218     "m, errs = rnn.run_model(rnn_dat2)"
219    ]
220   },
221   {
222    "cell_type": "code",
223    "execution_count": null,
224    "id": "bcbb0159-74c5-4f56-9d69-d85a58ddbd1a",
225    "metadata": {},
226    "outputs": [],
227    "source": [
228     "rnn.model_predict.get_weights()"
229    ]
230   },
231   {
232    "cell_type": "code",
233    "execution_count": null,
234    "id": "c25f741a-6280-4cf2-8017-e56672236fdb",
235    "metadata": {},
236    "outputs": [],
237    "source": []
238   },
239   {
240    "cell_type": "code",
241    "execution_count": null,
242    "id": "e8ed2b03-6123-4bdf-9e26-ef2ce4951663",
243    "metadata": {},
244    "outputs": [],
245    "source": [
246     "params['rnn_units']"
247    ]
248   },
249   {
250    "cell_type": "code",
251    "execution_count": null,
252    "id": "e44302bf-af49-4140-ae31-54f7c88a6735",
253    "metadata": {},
254    "outputs": [],
255    "source": [
256     "params.update({\n",
257     "    'phys_initialize': True,\n",
258     "    'scaler': None, # TODO\n",
259     "    'dense_layers': 0, # NOT including single Dense output layer which is hard-coded\n",
260     "    'activation': ['relu', 'relu'], # TODO tanh, relu the same\n",
261     "    'batch_schedule_type': None # Hopefully this isn't a necessity like before, but maybe it will help\n",
262     "})"
263    ]
264   },
265   {
266    "cell_type": "code",
267    "execution_count": null,
268    "id": "9a8ac32d-551c-43e8-988e-a3b13e6d9cd9",
269    "metadata": {},
270    "outputs": [],
271    "source": [
272     "rnn_dat2 = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
273     "rnn_dat2.train_test_split(\n",
274     "    time_fracs = [.8, .1, .1]\n",
275     ")\n",
276     "rnn_dat2.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
277    ]
278   },
279   {
280    "cell_type": "code",
281    "execution_count": null,
282    "id": "ff727da8-38fb-4fda-999b-f712b98de0df",
283    "metadata": {
284     "scrolled": true
285    },
286    "outputs": [],
287    "source": [
288     "reproducibility.set_seed()\n",
289     "\n",
290     "rnn = RNN(params)\n",
291     "m, errs = rnn.run_model(rnn_dat2)"
292    ]
293   },
294   {
295    "cell_type": "code",
296    "execution_count": null,
297    "id": "b165074c-ea88-4b4d-8e41-6b6f22b4d221",
298    "metadata": {},
299    "outputs": [],
300    "source": []
301   },
302   {
303    "cell_type": "code",
304    "execution_count": null,
305    "id": "aa5cd4e6-4441-4c77-a086-e9edefbeb83b",
306    "metadata": {},
307    "outputs": [],
308    "source": []
309   },
310   {
311    "cell_type": "code",
312    "execution_count": null,
313    "id": "7bd1e05b-5cd8-48b4-8469-4842313d6097",
314    "metadata": {},
315    "outputs": [],
316    "source": []
317   },
318   {
319    "cell_type": "code",
320    "execution_count": null,
321    "id": "b399346d-20b8-4c97-898a-606a4be98065",
322    "metadata": {},
323    "outputs": [],
324    "source": []
325   },
326   {
327    "cell_type": "code",
328    "execution_count": null,
329    "id": "521285e6-6b6a-4d23-b688-9eb84b8eab68",
330    "metadata": {},
331    "outputs": [],
332    "source": []
333   },
334   {
335    "cell_type": "code",
336    "execution_count": null,
337    "id": "12c66af1-54fd-4398-8ee2-36eeb937c40d",
338    "metadata": {},
339    "outputs": [],
340    "source": []
341   },
342   {
343    "cell_type": "code",
344    "execution_count": null,
345    "id": "eb21fb8e-05c6-4a39-bdf1-4a57067c786d",
346    "metadata": {},
347    "outputs": [],
348    "source": []
349   },
350   {
351    "cell_type": "code",
352    "execution_count": null,
353    "id": "628a9105-ca06-44c4-ad00-13808e2f4773",
354    "metadata": {},
355    "outputs": [],
356    "source": []
357   },
358   {
359    "cell_type": "code",
360    "execution_count": null,
361    "id": "37fdbb3a-3e83-4541-93b2-982b6d4cbe93",
362    "metadata": {},
363    "outputs": [],
364    "source": []
365   },
366   {
367    "cell_type": "code",
368    "execution_count": null,
369    "id": "a592a4c9-cb3b-4174-8eaa-02afd00a1897",
370    "metadata": {},
371    "outputs": [],
372    "source": []
373   },
374   {
375    "cell_type": "code",
376    "execution_count": null,
377    "id": "3832fb05-417c-4648-8e2e-7748c06b3768",
378    "metadata": {},
379    "outputs": [],
380    "source": []
381   },
382   {
383    "cell_type": "markdown",
384    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
385    "metadata": {},
386    "source": [
387     "## LSTM\n",
388     "\n",
389     "**TODO:** the code in this section needs fixing."
390    ]
391   },
392   {
393    "cell_type": "code",
394    "execution_count": null,
395    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
396    "metadata": {},
397    "outputs": [],
398    "source": [
399     "import importlib \n",
400     "import moisture_rnn\n",
401     "importlib.reload(moisture_rnn)\n",
402     "from moisture_rnn import RNN_LSTM"
403    ]
404   },
405   {
406    "cell_type": "code",
407    "execution_count": null,
408    "id": "0f6ba896-e3be-4a9f-8a42-3df64aff7d63",
409    "metadata": {},
410    "outputs": [],
411    "source": [
412     "params = read_yml(\"params.yaml\", subkey=\"lstm\")\n",
413     "params = RNNParams(params)"
414    ]
415   },
416   {
417    "cell_type": "code",
418    "execution_count": null,
419    "id": "a4cf567e-d623-4e14-b578-eed88b80d04e",
420    "metadata": {},
421    "outputs": [],
422    "source": [
423     "rnn_dat = RNNData(dat[cases[10]], params['scaler'], params['features_list'])\n",
424     "rnn_dat.train_test_split(\n",
425     "    time_fracs = [.8, .1, .1]\n",
426     ")\n",
427     "rnn_dat.scale_data()\n",
428     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
429    ]
430   },
431   {
432    "cell_type": "code",
433    "execution_count": null,
434    "id": "0157a6bc-3a99-4b87-a42c-ab770d19ae37",
435    "metadata": {},
436    "outputs": [],
437    "source": [
438     "from moisture_rnn import ResetStatesCallback, EarlyStoppingCallback\n",
439     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
440     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
441     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours})\n",
442     "reproducibility.set_seed(123)\n",
443     "lstm = RNN_LSTM(params)\n",
444     "\n",
445     "history = lstm.model_train.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
446     "                    batch_size = params['batch_size'], epochs=params['epochs'], \n",
447     "                    callbacks = [ResetStatesCallback(params),\n",
448     "                                EarlyStoppingCallback(patience = 15)],\n",
449     "                   validation_data = (rnn_dat.X_val, rnn_dat.y_val))\n",
450     "              "
451    ]
452   },
453   {
454    "cell_type": "code",
455    "execution_count": null,
456    "id": "ec95e7d4-6d57-441b-b673-f10625ee5dec",
457    "metadata": {},
458    "outputs": [],
459    "source": []
460   },
461   {
462    "cell_type": "code",
463    "execution_count": null,
464    "id": "9b3c8d8d-ea50-44ea-8c0c-414e07cd01ac",
465    "metadata": {},
466    "outputs": [],
467    "source": []
468   },
469   {
470    "cell_type": "code",
471    "execution_count": null,
472    "id": "03063e3c-e8f4-451d-b0cf-25bd965cd9d6",
473    "metadata": {},
474    "outputs": [],
475    "source": [
476     "params.update({'epochs': 50, 'learning_rate': 0.001, 'verbose_fit': True, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n",
477     "              'activation': ['tanh', 'tanh'], 'features_list': rnn_dat.features_list,\n",
478     "              'batch_schedule_type':'exp', 'bmin': 10, 'bmax':rnn_dat.hours,\n",
479     "              'early_stopping_patience': 25})\n",
480     "reproducibility.set_seed(123)\n",
481     "lstm = RNN_LSTM(params)\n",
482     "m, errs = lstm.run_model(rnn_dat)"
483    ]
484   },
485   {
486    "cell_type": "code",
487    "execution_count": null,
488    "id": "f60a24c6-9a67-45aa-bc5c-8818aa0ca049",
489    "metadata": {},
490    "outputs": [],
491    "source": []
492   },
493   {
494    "cell_type": "code",
495    "execution_count": null,
496    "id": "00910bd2-f050-438c-ab3b-c793b83cb5f5",
497    "metadata": {},
498    "outputs": [],
499    "source": [
500     "rnn_dat.spatial"
501    ]
502   },
503   {
504    "cell_type": "code",
505    "execution_count": null,
506    "id": "236b33e3-e864-4453-be16-cf07338c4105",
507    "metadata": {},
508    "outputs": [],
509    "source": [
510     "params = RNNParams(read_yml(\"params.yaml\", subkey='lstm'))\n",
511     "params"
512    ]
513   },
514   {
515    "cell_type": "code",
516    "execution_count": null,
517    "id": "fe2a484c-dc99-45a9-89fc-2f451bd719b5",
518    "metadata": {},
519    "outputs": [],
520    "source": [
521     "train = read_pkl(\"data/train.pkl\")"
522    ]
523   },
524   {
525    "cell_type": "code",
526    "execution_count": null,
527    "id": "07bfac87-a6d4-4dcc-8d11-adf83eafab76",
528    "metadata": {},
529    "outputs": [],
530    "source": [
531     "from itertools import islice\n",
532     "train = {k: train[k] for k in islice(train, 100)}"
533    ]
534   },
535   {
536    "cell_type": "code",
537    "execution_count": null,
538    "id": "4e26099b-f760-4047-afec-9e751d24b7a6",
539    "metadata": {},
540    "outputs": [],
541    "source": [
542     "from data_funcs import combine_nested\n",
543     "rnn_dat_sp = RNNData(\n",
544     "    combine_nested(train), # input dictionary\n",
545     "    scaler=\"standard\",  # data scaling type\n",
546     "    features_list = params['features_list'] # features for predicting outcome\n",
547     ")\n",
548     "\n",
549     "\n",
550     "rnn_dat_sp.train_test_split(   \n",
551     "    time_fracs = [.8, .1, .1], # Fraction of total time steps used for train/val/test\n",
552     "    space_fracs = [.8, .1, .1] # Fraction of total time series used for train/val/test\n",
553     ")\n",
554     "rnn_dat_sp.scale_data()\n",
555     "\n",
556     "rnn_dat_sp.batch_reshape(\n",
557     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
558     "    batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n",
559     ")"
560    ]
561   },
562   {
563    "cell_type": "code",
564    "execution_count": null,
565    "id": "10738795-c83b-4da3-88ba-09278caa35f8",
566    "metadata": {},
567    "outputs": [],
568    "source": [
569     "params.update({\n",
570     "    'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n",
571     "})"
572    ]
573   },
574   {
575    "cell_type": "code",
576    "execution_count": null,
577    "id": "9c5d45cc-bcf0-4b6c-9c51-c4c790a2d9a5",
578    "metadata": {},
579    "outputs": [],
580    "source": [
581     "reproducibility.set_seed(123)  # seed before building the model, as the other training cells do\nrnn_sp = RNN_LSTM(params)\n",
582     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
583    ]
584   },
585   {
586    "cell_type": "code",
587    "execution_count": null,
588    "id": "ee332ccf-4e4a-4f66-b4d6-c079dbdb1411",
589    "metadata": {},
590    "outputs": [],
591    "source": [
592     "errs.mean()"
593    ]
594   },
595   {
596    "cell_type": "code",
597    "execution_count": null,
598    "id": "739d4b26-641e-47b2-a90a-67cd32215d05",
599    "metadata": {},
600    "outputs": [],
601    "source": []
602   }
603  ],
604  "metadata": {
605   "kernelspec": {
606    "display_name": "Python 3 (ipykernel)",
607    "language": "python",
608    "name": "python3"
609   },
610   "language_info": {
611    "codemirror_mode": {
612     "name": "ipython",
613     "version": 3
614    },
615    "file_extension": ".py",
616    "mimetype": "text/x-python",
617    "name": "python",
618    "nbconvert_exporter": "python",
619    "pygments_lexer": "ipython3",
620    "version": "3.12.5"
621   }
622  },
623  "nbformat": 4,
624  "nbformat_minor": 5