Create Batch Reset Hyperparameter tutorial notebook
[notebooks.git] / fmda / rnn_class.ipynb
blobe54c8739c531b9b144c50e36b51072ba5c6b95e2
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "ae5031bf-2e51-4caa-b83c-6cdb68926331",
6    "metadata": {},
7    "source": [
8     "# v2: Demonstration and testing of the RNN class structure"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "8b71194e-c5f8-488c-94d4-64b480805d44",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import numpy as np\n",
20     "import pandas as pd\n",
21     "import tensorflow as tf\n",
22     "import matplotlib.pyplot as plt\n",
23     "import sys\n",
24     "# Local modules\n",
25     "sys.path.append('..')\n",
26     "import reproducibility\n",
27     "from utils import print_dict_summary\n",
28     "from data_funcs import load_and_fix_data, rmse\n",
29     "from abc import ABC, abstractmethod\n",
30     "from tensorflow.keras.callbacks import Callback\n",
31     "from sklearn.metrics import mean_squared_error\n",
32     "from utils import hash2"
33    ]
34   },
35   {
36    "cell_type": "code",
37    "execution_count": null,
38    "id": "040a3833-210e-4625-b101-6f33c484b127",
39    "metadata": {},
40    "outputs": [],
41    "source": [
42     "reproducibility_file='version_control/reproducibility_dict0.pickle'\n",
43     "\n",
44     "repro={}\n",
45     "repro.update(load_and_fix_data(reproducibility_file))\n",
46     "print_dict_summary(repro)\n",
47     "\n",
48     "case = 'case11'\n",
49     "case_data=repro[case]\n",
50     "case_data[\"h2\"]=300"
51    ]
52   },
53   {
54    "cell_type": "markdown",
55    "id": "cf248365-2b11-4f86-8fa4-5d18cbf27ee8",
56    "metadata": {},
57    "source": [
58     "## Stateful Batch Training"
59    ]
60   },
61   {
62    "cell_type": "code",
63    "execution_count": null,
64    "id": "0b3675ad-f7d9-46e1-ad60-7ad5432699ec",
65    "metadata": {},
66    "outputs": [],
67    "source": [
68     "from moisture_rnn import RNN, create_rnn_data2\n",
69     "import logging\n",
70     "from utils import logging_setup\n",
71     "from moisture_rnn_pkl import pkl2train\n",
72     "logging_setup()  # NOTE(review): presumably configures logging for this session -- confirm in utils"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "8f8df5cb-64f5-4a61-9aba-d4fc44751610",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "reproducibility_file='data/reproducibility_dict2.pickle'  # NOTE: rebinds the name that earlier pointed at reproducibility_dict0.pickle\n",
83     "train = pkl2train([reproducibility_file])  # called with a list of pickle paths; the result is indexed with 'reproducibility' in a later cell"
84    ]
85   },
86   {
87    "cell_type": "code",
88    "execution_count": null,
89    "id": "024ea0f7-3574-4cc2-81d2-ad2af74b9ec9",
90    "metadata": {},
91    "outputs": [],
92    "source": [
93     "import yaml\n",
94     "\n",
95     "with open(\"params.yaml\") as file:\n",
96     "    params = yaml.safe_load(file)[\"rnn_repro\"]\n",
97     "params"
98    ]
99   },
100   {
101    "cell_type": "code",
102    "execution_count": null,
103    "id": "21ed1f73-885c-43c3-8869-0877b06f8ad5",
104    "metadata": {},
105    "outputs": [],
106    "source": [
107     "# Build the RNN data dict; later cells read its X, y, X_train and y_train entries.\n",
108     "rnn_dat = create_rnn_data2(train[\"reproducibility\"], params)"
109    ]
110   },
111   {
112    "cell_type": "code",
113    "execution_count": null,
114    "id": "41a62222-9771-4b7a-b508-b82f4b38b46c",
115    "metadata": {},
116    "outputs": [],
117    "source": [
118     "# Update Params for Reproducibility\n",
119     "\n",
120     "params.update({\n",
121     "    'epochs':200,\n",
122     "    'dropout': [0, 0], # NOTE: length must match total number of layers, default is 1 hidden recurrent layer and 1 dense output layer\n",
123     "    'recurrent_dropout': 0, # Length must match number of recurrent layers    \n",
124     "    'rnn_units': 20\n",
125     "})"
126    ]
127   },
128   {
129    "cell_type": "code",
130    "execution_count": null,
131    "id": "fc961d9e-2e66-472f-8fe4-ec0e888d1b04",
132    "metadata": {},
133    "outputs": [],
134    "source": [
135     "reproducibility.set_seed()  # called immediately before constructing the model; presumably fixes RNG state -- confirm in reproducibility.py\n",
136     "rnn = RNN(params)  # model built from the params dict as updated in the cell above"
137    ]
138   },
139   {
140    "cell_type": "code",
141    "execution_count": null,
142    "id": "080b491f-48ab-4592-8d73-2fbd5ccb82b4",
143    "metadata": {},
144    "outputs": [],
145    "source": [
146     "rnn.fit(rnn_dat[\"X_train\"], rnn_dat[\"y_train\"])\n",
147     "# Fits on the X_train / y_train entries of rnn_dat; the older rnn.fit(X_train, y_train) form relied on bare variables that the cells above never define."
148    ]
149   },
150   {
151    "cell_type": "code",
152    "execution_count": null,
153    "id": "882a1902-cdee-46bb-a8de-762465df03cd",
154    "metadata": {},
155    "outputs": [],
156    "source": [
157     "# Print hash2 of the full inputs and targets, then predict on rnn_dat['X'] and display rmse against rnn_dat['y'].\n",
158     "print(hash2(rnn_dat['X']))\n",
159     "print(hash2(rnn_dat['y']))\n",
160     "preds = rnn.predict(rnn_dat['X'])\n",
161     "rmse(preds, rnn_dat['y'])"
162    ]
163   },
164   {
165    "cell_type": "code",
166    "execution_count": null,
167    "id": "b9cae53e-2130-4cba-9fc7-901747b6349e",
168    "metadata": {},
169    "outputs": [],
170    "source": [
171     "reproducibility.set_seed()  # re-seed so this fresh model does not depend on RNG state left by earlier cells (presumed -- confirm in reproducibility.py)\n",
172     "rnn = RNN(params)  # rebinds rnn to a new model; the previously fitted one is discarded\n",
173     "rnn.run_model(rnn_dat)  # NOTE(review): takes the whole rnn_dat dict; presumably wraps the fit/predict/score steps done manually above -- confirm"
174    ]
175   },
176   {
177    "cell_type": "markdown",
178    "id": "9d61b952-dbf8-4935-b795-974d9d5d6bbf",
179    "metadata": {},
180    "source": [
181     "## Physics Initialized"
182    ]
183   },
184   {
185    "cell_type": "code",
186    "execution_count": null,
187    "id": "80466af6-a786-483d-8c61-7214a8dd0b67",
188    "metadata": {},
189    "outputs": [],
190    "source": []
191   },
192   {
193    "cell_type": "markdown",
194    "id": "fd8f1b91-80db-4086-a44d-4bcc695466e4",
195    "metadata": {},
196    "source": [
197     "---\n",
198     "\n",
199     "## New Developments"
200    ]
201   },
202   {
203    "cell_type": "markdown",
204    "id": "fb3c3a94-e3ff-4efe-98f6-f29896cc1066",
205    "metadata": {},
206    "source": [
207     "### Other Hyperparams"
208    ]
209   },
210   {
211    "cell_type": "code",
212    "execution_count": null,
213    "id": "0db3fe44-e056-49c5-88ac-00b1c43f3beb",
214    "metadata": {},
215    "outputs": [],
216    "source": [
217     "params.update({\n",
218     "    'activation': ['sigmoid', 'relu'], # Length must match total number of layers\n",
219     "    'dropout': [0.2, 0.2], # NOTE: length must match total number of layers, default is 1 hidden recurrent layer and 1 dense output layer\n",
220     "    'recurrent_dropout': 0.2, # Length must match number of recurrent layers\n",
221     "    'learning_rate': 0.003,\n",
222     "    'rnn_units': 9,\n",
223     "    'epochs': 100\n",
224     "})"
225    ]
226   },
227   {
228    "cell_type": "code",
229    "execution_count": null,
230    "id": "78ccebaf-7ea4-465a-9af9-5bc3aad7cb25",
231    "metadata": {},
232    "outputs": [],
233    "source": [
234     "reproducibility.set_seed()  # called immediately before constructing the model; presumably fixes RNG state -- confirm in reproducibility.py\n",
235     "rnn = RNN(params)  # new model using the hyperparameters set in the cell above"
236    ]
237   },
238   {
239    "cell_type": "code",
240    "execution_count": null,
241    "id": "70787b72-bbef-4774-919a-2fc725fbb6d7",
242    "metadata": {},
243    "outputs": [],
244    "source": [
245     "rnn.fit(rnn_dat[\"X_train\"], rnn_dat[\"y_train\"], verbose_fit=False)  # NOTE(review): verbose_fit=False presumably silences per-epoch output -- confirm in moisture_rnn.RNN.fit"
246    ]
247   },
248   {
249    "cell_type": "code",
250    "execution_count": null,
251    "id": "03afe3c9-480f-4e36-bd1f-f30a35927441",
252    "metadata": {},
253    "outputs": [],
254    "source": [
255     "preds = rnn.predict(rnn_dat[\"X\"])\n",
256     "# Alternative via sklearn (presumably equivalent -- confirm against data_funcs.rmse): np.sqrt(mean_squared_error(preds.flatten(), rnn_dat[\"y\"].flatten()))\n",
257     "rmse(preds, rnn_dat[\"y\"])"
258    ]
259   },
260   {
261    "cell_type": "markdown",
262    "id": "6bc0e50e-5a34-42e7-9ae0-0ac688408e3a",
263    "metadata": {},
264    "source": [
265     "### Validation Error"
266    ]
267   },
268   {
269    "cell_type": "code",
270    "execution_count": null,
271    "id": "6bb62317-707c-47a2-8a54-6c7276ae7d91",
272    "metadata": {},
273    "outputs": [],
274    "source": [
275     "params.update({\n",
276     "    'train_frac': 0.5,\n",
277     "    'val_frac': 0.1\n",
278     "})\n",
279     "rnn_dat = create_rnn_data2(train[\"reproducibility\"], params)"
280    ]
281   },
282   {
283    "cell_type": "code",
284    "execution_count": null,
285    "id": "e36bc1bd-31af-4369-9e8f-801d61a9aa25",
286    "metadata": {},
287    "outputs": [],
288    "source": [
289     "reproducibility.set_seed()  # called before constructing a fresh model (presumably fixes RNG state -- confirm in reproducibility.py)\n",
290     "rnn = RNN(params)\n",
291     "rnn.fit(rnn_dat[\"X_train\"], rnn_dat[\"y_train\"], validation_data = (rnn_dat[\"X_val\"], rnn_dat[\"y_val\"]))  # same fit call as earlier, now also passing the validation split"
292    ]
293   },
294   {
295    "cell_type": "code",
296    "execution_count": null,
297    "id": "e2fa46fd-4c08-4f92-96cf-e7ca3ff8b320",
298    "metadata": {},
299     "import importlib\n",
300     "import moisture_rnn\n",
301     "importlib.reload(moisture_rnn)  # re-execute the module so edits made on disk are picked up without restarting the kernel\n",
302     "from moisture_rnn import RNN  # rebind RNN to the class object from the reloaded module"
303     "importlib.reload(moisture_rnn)\n",
304     "from moisture_rnn import RNN"
305    ]
306   },
307   {
308    "cell_type": "code",
309    "execution_count": null,
310    "id": "e534f3a6-faf6-4ce6-9f5f-b88af5928ec6",
311    "metadata": {},
312    "outputs": [],
313    "source": [
314     "reproducibility.set_seed()  # called before constructing a fresh model (presumably fixes RNG state -- confirm in reproducibility.py)\n",
315     "rnn = RNN(params)  # uses the RNN class re-imported after the reload in the cell above\n",
316     "rnn.run_model(rnn_dat)  # NOTE(review): takes the whole rnn_dat dict, which now includes the validation entries -- confirm how run_model uses them"
317    ]
318   },
319   {
320    "cell_type": "code",
321    "execution_count": null,
322    "id": "2ee11823-a3e1-406a-85b7-3085aa685318",
323    "metadata": {},
324    "outputs": [],
325    "source": []
326   }
327  ],
328  "metadata": {
329   "kernelspec": {
330    "display_name": "Python 3 (ipykernel)",
331    "language": "python",
332    "name": "python3"
333   },
334   "language_info": {
335    "codemirror_mode": {
336     "name": "ipython",
337     "version": 3
338    },
339    "file_extension": ".py",
340    "mimetype": "text/x-python",
341    "name": "python",
342    "nbconvert_exporter": "python",
343    "pygments_lexer": "ipython3",
344    "version": "3.10.9"
345   }
346  },
347  "nbformat": 4,
348  "nbformat_minor": 5