Update moisture_rnn.py
[notebooks.git] / fmda / presentations / batch_reset_param_tutorial.ipynb
blob695f056f92adb7093adc5889e409a12b82690c4b
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "6b50356b-169f-471d-92e0-8d82a2911185",
6    "metadata": {},
7    "source": [
8     "# Batch Reset Hyperparameter Tutorial"
9    ]
10   },
11   {
12    "cell_type": "markdown",
13    "id": "f6b09663-b3d5-46a5-a214-22b02467afb4",
14    "metadata": {},
15    "source": [
16     "When training data are very long, a stateful model is prone to instability since at the early iterations of training, an unreasonable hidden state is generated and propogated through many batches of training.\n",
17     "\n",
18     "We introduce the hyperparameter `batch_reset`, which resets the hidden state after a fixed number of batches. Future work will make this a schedule where the number of batches before reset is increased as the network learns and will be less subject to exploding/vanishing gradients."
19    ]
20   },
21   {
22    "cell_type": "markdown",
23    "id": "9d9f20b3-6cad-43f2-8eab-7b065b02891b",
24    "metadata": {},
25    "source": [
26     "## Environment and Data Setup"
27    ]
28   },
29   {
30    "cell_type": "code",
31    "execution_count": null,
32    "id": "23115780-950f-46ea-b1d8-72bd5f3ec3bd",
33    "metadata": {},
34    "outputs": [],
35    "source": [
36     "# Environment\n",
37     "import os\n",
38     "import os.path as osp\n",
39     "import matplotlib.pyplot as plt\n",
40     "import sys\n",
41     "import numpy as np\n",
42     "import pandas as pd\n",
43     "# Local modules\n",
44     "sys.path.append('..')\n",
45     "import reproducibility\n",
46     "from utils import print_dict_summary\n",
47     "from data_funcs import rmse\n",
48     "from moisture_rnn import RNNParams, RNNData, RNN\n",
49     "from moisture_rnn_pkl import pkl2train\n",
50     "from utils import read_yml, read_pkl\n",
51     "import yaml\n",
52     "import pickle"
53    ]
54   },
55   {
56    "cell_type": "code",
57    "execution_count": null,
58    "id": "40e68cdd-bb04-499c-8370-3cbdb3aebc46",
59    "metadata": {},
60    "outputs": [],
61    "source": [
62     "dat = read_pkl(\"batch_reset_tutorial_case.pkl\")"
63    ]
64   },
65   {
66    "cell_type": "code",
67    "execution_count": null,
68    "id": "25e64ffb-3b4c-44c7-9737-caf05303ad0d",
69    "metadata": {},
70    "outputs": [],
71    "source": [
72     "params = read_yml(\"../params.yaml\", subkey=\"rnn\")\n",
73     "params = RNNParams(params)\n",
74     "params.update({'epochs': 10})"
75    ]
76   },
77   {
78    "cell_type": "code",
79    "execution_count": null,
80    "id": "67b36002-7d78-4723-80d5-8f3a6ac5886d",
81    "metadata": {},
82    "outputs": [],
83    "source": [
84     "rnn_dat = RNNData(dat, scaler = params['scaler'], features_list = params['features_list'])\n",
85     "rnn_dat.train_test_split(\n",
86     "    train_frac = .9,\n",
87     "    val_frac = .05\n",
88     ")\n",
89     "rnn_dat.scale_data()\n",
90     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
91    ]
92   },
93   {
94    "cell_type": "markdown",
95    "id": "10eb476a-5791-459d-8ecf-901a11fee1f2",
96    "metadata": {},
97    "source": [
98     "## Train without Stateful"
99    ]
100   },
101   {
102    "cell_type": "code",
103    "execution_count": null,
104    "id": "3e32c020-12e8-42f3-8278-3e15431b042c",
105    "metadata": {},
106    "outputs": [],
107    "source": [
108     "params.update({'verbose_fit': True, 'stateful': False, 'batch_reset':9999})\n",
109     "reproducibility.set_seed(123)\n",
110     "rnn = RNN(params)\n",
111     "try:\n",
112     "    m, errs = rnn.run_model(rnn_dat)\n",
113     "except Exception as e:\n",
114     "    print(\"*\"*50)\n",
115     "    print(f\"Caught Error {e}\")\n",
116     "    print(\"*\"*50)"
117    ]
118   },
119   {
120    "cell_type": "code",
121    "execution_count": null,
122    "id": "05ab6d32-6626-4a18-ab15-c1dddf8496ed",
123    "metadata": {},
124    "outputs": [],
125    "source": [
126     "rnn.predict(rnn_dat.scale_all_X())[0:150]"
127    ]
128   },
129   {
130    "cell_type": "markdown",
131    "id": "c1172389-8c3a-4838-9c99-8bb575ba2014",
132    "metadata": {},
133    "source": [
134     "## Train with Stateful, without Batch Reset\n",
135     "\n",
136     "We turn off the parameter by setting it to a huge value."
137    ]
138   },
139   {
140    "cell_type": "code",
141    "execution_count": null,
142    "id": "81232bf4-1995-490c-a609-0f8b88cef65d",
143    "metadata": {},
144    "outputs": [],
145    "source": [
146     "params.update({'verbose_fit': True, 'stateful': True, 'batch_schedule_type':None})\n",
147     "params.update({'epochs': 30})\n",
148     "reproducibility.set_seed(123)\n",
149     "rnn = RNN(params)"
150    ]
151   },
152   {
153    "cell_type": "code",
154    "execution_count": null,
155    "id": "65e4f4af-9b02-4329-b25e-b89e3c893ba7",
156    "metadata": {},
157    "outputs": [],
158    "source": [
159     "try:\n",
160     "    m, errs = rnn.run_model(rnn_dat)\n",
161     "except Exception as e:\n",
162     "    print(\"*\"*50)\n",
163     "    print(f\"Caught Error {e}\")\n",
164     "    print(\"*\"*50)"
165    ]
166   },
167   {
168    "cell_type": "markdown",
169    "id": "92058207-1186-4389-8fbc-ad39672d5cdb",
170    "metadata": {},
171    "source": [
172     "## Train with Stateful, with Periodic Batch Reset"
173    ]
174   },
175   {
176    "cell_type": "code",
177    "execution_count": null,
178    "id": "06fb5624-6d47-452d-b41b-3c7ae2f987e3",
179    "metadata": {},
180    "outputs": [],
181    "source": [
182     "params.update({'verbose_fit': True, 'stateful': True, 'batch_schedule_type':'constant', 'bmin': 20})\n",
183     "params.update({'epochs': 30})\n",
184     "reproducibility.set_seed(123)\n",
185     "rnn = RNN(params)"
186    ]
187   },
188   {
189    "cell_type": "code",
190    "execution_count": null,
191    "id": "33938583-4823-4e84-8afd-9652d85d164a",
192    "metadata": {},
193    "outputs": [],
194    "source": [
195     "try:\n",
196     "    m, errs = rnn.run_model(rnn_dat)\n",
197     "except Exception as e:\n",
198     "    print(\"*\"*50)\n",
199     "    print(f\"Caught Error {e}\")\n",
200     "    print(\"*\"*50)"
201    ]
202   },
203   {
204    "cell_type": "markdown",
205    "id": "6f071895-ae8c-4c5a-8c5c-331fbfec1e6a",
206    "metadata": {},
207    "source": [
208     "## Batch Reset Schedules"
209    ]
210   },
211   {
212    "cell_type": "code",
213    "execution_count": null,
214    "id": "4958e7e0-1216-4392-a239-399f47917d98",
215    "metadata": {},
216    "outputs": [],
217    "source": [
218     "from moisture_rnn import calc_exp_intervals, calc_log_intervals"
219    ]
220   },
221   {
222    "cell_type": "code",
223    "execution_count": null,
224    "id": "e37db25b-a24e-4b37-abdc-ca6830afafc2",
225    "metadata": {},
226    "outputs": [],
227    "source": [
228     "epochs = 50\n",
229     "bmin = 10\n",
230     "bmax = 200\n",
231     "\n",
232     "egrid = np.arange(epochs)"
233    ]
234   },
235   {
236    "cell_type": "code",
237    "execution_count": null,
238    "id": "ea826c89-5af7-438e-8db1-f91702ae032d",
239    "metadata": {},
240    "outputs": [],
241    "source": [
242     "plt.plot(egrid, np.linspace(bmin, bmax, epochs), label='Linear Increase')\n",
243     "plt.plot(egrid, calc_exp_intervals(bmin, bmax, epochs), label='Exponential Increase')\n",
244     "plt.plot(egrid, calc_log_intervals(bmin, bmax, epochs), label='Logarithmic Increase')\n",
245     "plt.xlabel('Epoch')\n",
246     "plt.ylabel('Batch Reset Value')\n",
247     "plt.legend()\n",
248     "plt.title('Batch Reset Value vs Epoch')\n",
249     "plt.show()"
250    ]
251   },
252   {
253    "cell_type": "markdown",
254    "id": "92e06001-4b27-47e8-aa5e-03bffdb9ba03",
255    "metadata": {},
256    "source": [
257     "### Linear Schedule"
258    ]
259   },
260   {
261    "cell_type": "code",
262    "execution_count": null,
263    "id": "cf4819b7-7e8e-4da9-8dee-217eda8f274f",
264    "metadata": {},
265    "outputs": [],
266    "source": [
267     "params.update({'verbose_fit': False, 'stateful': True, \n",
268     "               'batch_schedule_type':'linear', 'bmin': 20, 'bmax': 200})\n",
269     "params.update({'epochs': 40})\n",
270     "reproducibility.set_seed(123)\n",
271     "rnn = RNN(params)\n",
272     "m, errs = rnn.run_model(rnn_dat, plot_period = \"predict\")"
273    ]
274   },
275   {
276    "cell_type": "markdown",
277    "id": "362dad1f-7584-4a04-a146-c13c2da2dc84",
278    "metadata": {},
279    "source": [
280     "### Exponential Increase"
281    ]
282   },
283   {
284    "cell_type": "code",
285    "execution_count": null,
286    "id": "d8e590af-4bdc-4e3f-bced-af2f4479e015",
287    "metadata": {},
288    "outputs": [],
289    "source": [
290     "params.update({'verbose_fit': False, 'stateful': True, \n",
291     "               'batch_schedule_type':'exp', 'bmin': 20, 'bmax': 200})\n",
292     "params.update({'epochs': 40})\n",
293     "reproducibility.set_seed(123)\n",
294     "rnn = RNN(params)\n",
295     "m, errs = rnn.run_model(rnn_dat, plot_period = \"predict\")"
296    ]
297   },
298   {
299    "cell_type": "markdown",
300    "id": "2424801b-3482-4542-9236-dbebae2c1143",
301    "metadata": {},
302    "source": [
303     "### Log Increase"
304    ]
305   },
306   {
307    "cell_type": "code",
308    "execution_count": null,
309    "id": "00d32580-6c20-4916-b241-1af281fedbd9",
310    "metadata": {},
311    "outputs": [],
312    "source": [
313     "params.update({'verbose_fit': False, 'stateful': True, \n",
314     "               'batch_schedule_type':'log', 'bmin': 20, 'bmax': 200})\n",
315     "params.update({'epochs': 40})\n",
316     "reproducibility.set_seed(123)\n",
317     "rnn = RNN(params)\n",
318     "m, errs = rnn.run_model(rnn_dat, plot_period = \"predict\")"
319    ]
320   },
321   {
322    "cell_type": "code",
323    "execution_count": null,
324    "id": "e6588375-c1bc-4b73-ad34-f8e6c2606979",
325    "metadata": {},
326    "outputs": [],
327    "source": []
328   }
329  ],
330  "metadata": {
331   "kernelspec": {
332    "display_name": "Python 3 (ipykernel)",
333    "language": "python",
334    "name": "python3"
335   },
336   "language_info": {
337    "codemirror_mode": {
338     "name": "ipython",
339     "version": 3
340    },
341    "file_extension": ".py",
342    "mimetype": "text/x-python",
343    "name": "python",
344    "nbconvert_exporter": "python",
345    "pygments_lexer": "ipython3",
346    "version": "3.12.5"
347   }
348  },
349  "nbformat": 4,
350  "nbformat_minor": 5