Update params.yaml
[notebooks.git] / fmda / fmda_rnn_test_batch_reset_schedule.ipynb
blobf46413b2808d144993fbd4dea25541d3c2559f21
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
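8     "# Compare Hidden-State Batch Reset Schedules\n\nTrains the same RNN on the same data split, with the same random seed, under three settings - non-stateful, stateful with a constant reset schedule, and stateful with an exponential reset schedule - and reports the mean of the errors returned by `run_model` for each.\n"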
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "import numpy as np\n",
19     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
20     "import pickle\n",
21     "import logging\n",
22     "import os.path as osp\n",
23     "from moisture_rnn_pkl import pkl2train\n",
24     "from moisture_rnn import RNNParams, RNNData, RNN \n",
25     "from utils import hash2, read_yml, read_pkl, retrieve_url\n",
26     "from moisture_rnn import RNN\n",
27     "import reproducibility\n",
28     "from data_funcs import rmse, to_json, combine_nested, process_train_dict\n",
29     "from moisture_models import run_augmented_kf\n",
30     "import copy\n",
31     "import pandas as pd\n",
32     "import matplotlib.pyplot as plt\n",
33     "import yaml\n",
34     "import time\n",
35     "import reproducibility\n",
36     "import tensorflow as tf"
37    ]
38   },
39   {
40    "cell_type": "code",
41    "execution_count": null,
42    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
43    "metadata": {},
44    "outputs": [],
45    "source": [
46     "logging_setup()  # project logging configuration (utils.logging_setup)"
47    ]
48   },
49   {
50    "cell_type": "code",
51    "execution_count": null,
52    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
53    "metadata": {},
54    "outputs": [],
55    "source": [
56     "retrieve_url(  # save into data/, the directory the loading cells below read from\n",
57     "    url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\",  # NOTE(review): CA test file stored under the fmda_nw name - confirm intended\n",
58     "    dest_path = \"data/fmda_nw_202401-05_f05.pkl\")"
59    ]
60   },
61   {
62    "cell_type": "code",
63    "execution_count": null,
64    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
65    "metadata": {},
66    "outputs": [],
67    "source": [
68     "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"  # NOTE(review): not referenced anywhere else in this notebook\n",
69     "file_names=['fmda_nw_202401-05_f05.pkl']  # data files to load, relative to file_dir\n",
70     "file_dir='data'\n",
71     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]  # file_dir joined with each file name"
72    ]
73   },
74   {
75    "cell_type": "code",
76    "execution_count": null,
77    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
78    "metadata": {},
79    "outputs": [],
80    "source": [
81     "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))  # RNN hyperparameters from the 'rnn' subkey of params.yaml\n",
82     "params_data = read_yml(\"params_data.yaml\")  # data-processing settings; params_data['hours'] also feeds bmax in the hyperparameter cell"
83    ]
84   },
85   {
86    "cell_type": "code",
87    "execution_count": null,
88    "id": "5f3d09b4-c2fd-4556-90b7-e547431ca523",
89    "metadata": {
90     "scrolled": true
91    },
92    "outputs": [],
93    "source": [
94     "params_data = read_yml(\"params_data.yaml\")  # fresh copy, so re-running this cell is idempotent\n",
95     "params_data.update({  # these overrides must reach process_train_dict (and the bmax setting below)\n",
96     "    'hours': 3000,\n",
97     "    'max_intp_time': 24,\n",
98     "    'zero_lag_threshold': 24\n",
99     "})\n",
100     "train = process_train_dict(file_paths, params_data=params_data, verbose=True)  # file_paths holds data/fmda_nw_202401-05_f05.pkl"
101    ]
102   },
103   {
104    "cell_type": "code",
105    "execution_count": null,
106    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
107    "metadata": {},
108    "outputs": [],
109    "source": [
110     "from itertools import islice\n",
111     "train = {k: train[k] for k in islice(train, 250)}  # keep only the first 250 entries of train, in insertion order (presumably to limit runtime)"
112    ]
113   },
114   {
115    "cell_type": "code",
116    "execution_count": null,
117    "id": "35ae0bdb-a209-429f-8116-c5e1dccafb89",
118    "metadata": {},
119    "outputs": [],
120    "source": [
121     "# Shared hyperparameter overrides for every run below; later cells change only stateful, the batch-schedule keys (batch_schedule_type, bmin, bmax) and loc_batch_reset\n",
122     "params.update({'epochs': 200, \n",
123     "               'learning_rate': 0.001,\n",
124     "               'activation': ['tanh', 'tanh'], # Activation for RNN layers, Dense layers respectively\n",
125     "               'recurrent_layers': 1, 'recurrent_units': 30, \n",
126     "               'dense_layers': 1, 'dense_units': 30,\n",
127     "               'early_stopping_patience': 30, # epochs without validation improvement to wait before stopping\n",
128     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule (overridden per run below)\n",
129     "               'bmin': 20, # Lower bound of hidden state batch reset\n",
130     "               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
131     "               'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
132     "               'timesteps': 12\n",
133     "              })"
134    ]
135   },
136   {
137    "cell_type": "markdown",
138    "id": "d6751dcc-ba4c-47d5-90d2-60f4a61e96fa",
139    "metadata": {},
140    "source": [
141     "## Prepare Data: Split, Scale, and Batch"
142    ]
143   },
144   {
145    "cell_type": "code",
146    "execution_count": null,
147    "id": "cc23e05f-282d-424b-9a50-6e32d9ae4095",
148    "metadata": {},
149    "outputs": [],
150    "source": [
151     "train_sp = combine_nested(train)  # merge the per-location entries of train into one input dict (see data_funcs.combine_nested)\n",
152     "rnn_dat = RNNData(\n",
153     "    train_sp, # input dictionary\n",
154     "    scaler=\"standard\",  # data scaling type\n",
155     "    features_list = params['features_list'] # features for predicting outcome\n",
156     ")\n",
157     "# Split time steps and locations into train/val/test, then scale the features\n",
158     "\n",
159     "rnn_dat.train_test_split(   \n",
160     "    time_fracs = [.9, .05, .05], # Fractions of total time steps used for train/val/test\n",
161     "    space_fracs = [.40, .30, .30] # Fractions of total timeseries (locations) used for train/val/test\n",
162     ")\n",
163     "rnn_dat.scale_data()\n",
164     "# Reshape into batches of batch_size sequences, each timesteps long\n",
165     "rnn_dat.batch_reshape(\n",
166     "    timesteps = params['timesteps'], # Sequence length (timesteps) of each RNN input sample\n",
167     "    batch_size = params['batch_size'], # Number of samples of length timesteps for a single round of grad. descent\n",
168     "    start_times = np.zeros(len(rnn_dat.loc['train_locs']))  # one start time of 0 per training location\n",
169     ")\n",
170     "\n",
171     "params.update({\n",
172     "    'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
173     "})"
174    ]
175   },
176   {
177    "cell_type": "markdown",
178    "id": "703dca05-5371-409e-b0a6-c430594bb76f",
179    "metadata": {},
180    "source": [
181     "## Non-Stateful"
182    ]
183   },
184   {
185    "cell_type": "code",
186    "execution_count": null,
187    "id": "4eaaf547-1967-4325-be61-b5e8ed33141f",
188    "metadata": {},
189    "outputs": [],
190    "source": [
191     "params.update({  # baseline: stateful=False and no batch reset schedule\n",
192     "    'stateful': False,\n",
193     "    'batch_schedule_type': None\n",
194     "})"
195    ]
196   },
197   {
198    "cell_type": "code",
199    "execution_count": null,
200    "id": "bfebe87d-1bbb-48c5-9b32-836d19d16787",
201    "metadata": {},
202    "outputs": [],
203    "source": [
204     "reproducibility.set_seed(123)  # same seed before each of the three runs so they are directly comparable\n",
205     "rnn = RNN(params)  # fresh model built from the current params\n",
206     "m0, errs0, epochs0 = rnn.run_model(rnn_dat, return_epochs=True)  # return_epochs=True yields a third value, kept as epochs0"
207    ]
208   },
209   {
210    "cell_type": "code",
211    "execution_count": null,
212    "id": "25581ed5-7fff-402f-b902-ed32bbcf1c0c",
213    "metadata": {},
214    "outputs": [],
215    "source": [
216     "errs0.mean()  # mean of the errors returned by run_model (non-stateful baseline)"
217    ]
218   },
219   {
220    "cell_type": "markdown",
221    "id": "80026da4-e8ec-4803-b791-66110c4b10d9",
222    "metadata": {},
223    "source": [
224     "## Constant Batch Schedule (Stateful)"
225    ]
226   },
227   {
228    "cell_type": "code",
229    "execution_count": null,
230    "id": "e7335f3d-3d8a-4733-9105-faed311a7df3",
231    "metadata": {},
232    "outputs": [],
233    "source": [
234     "params.update({  # stateful RNN with a constant hidden-state reset schedule\n",
235     "    'stateful': True, \n",
236     "    'batch_schedule_type':'constant', \n",
237     "    'bmin': 20})  # NOTE(review): bmax keeps the value set earlier - confirm the constant schedule ignores it"
238    ]
239   },
240   {
241    "cell_type": "code",
242    "execution_count": null,
243    "id": "e1febc02-35af-4325-ad02-4f6a2ce065fd",
244    "metadata": {},
245    "outputs": [],
246    "source": [
247     "reproducibility.set_seed(123)  # same seed as the other runs for a like-for-like comparison\n",
248     "rnn = RNN(params)  # fresh model built from the current params\n",
249     "m2, errs2, epochs2 = rnn.run_model(rnn_dat, return_epochs=True)  # return_epochs=True yields a third value, kept as epochs2"
250    ]
251   },
252   {
253    "cell_type": "code",
254    "execution_count": null,
255    "id": "4ad5e3cc-a4a4-4a54-9bc3-9b8b88b454ee",
256    "metadata": {},
257    "outputs": [],
258    "source": [
259     "errs2.mean()  # mean of the errors returned by run_model (constant schedule)"
260    ]
261   },
262   {
263    "cell_type": "markdown",
264    "id": "7082eeb3-5173-4226-85db-4bb5a26f67a4",
265    "metadata": {},
266    "source": [
267     "## Exp Batch Schedule (Stateful)"
268    ]
269   },
270   {
271    "cell_type": "code",
272    "execution_count": null,
273    "id": "2a428391-400d-476b-990e-c5e1d9cba7f8",
274    "metadata": {},
275    "outputs": [],
276    "source": [
277     "params.update({  # stateful RNN, 'exp' reset schedule (presumably exponential) between bmin and bmax\n",
278     "    'stateful': True, \n",
279     "    'batch_schedule_type':'exp', \n",
280     "    'bmin': 20,\n",
281     "    'bmax': rnn_dat.hours  # replaces the params_data-based bmax set in the hyperparameter cell\n",
282     "})"
283    ]
284   },
285   {
286    "cell_type": "code",
287    "execution_count": null,
288    "id": "9f77ee03-01f8-415f-9748-b986a77f2982",
289    "metadata": {},
290    "outputs": [],
291    "source": [
292     "reproducibility.set_seed(123)  # same seed as the other runs for a like-for-like comparison\n",
293     "rnn = RNN(params)  # fresh model built from the current params\n",
294     "m3, errs3, epochs3 = rnn.run_model(rnn_dat, return_epochs=True)  # return_epochs=True yields a third value, kept as epochs3"
295    ]
296   },
297   {
298    "cell_type": "code",
299    "execution_count": null,
300    "id": "d2b1eaf5-b710-431a-a10e-0098f713c325",
301    "metadata": {},
302    "outputs": [],
303    "source": [
304     "errs3.mean()  # mean of the errors returned by run_model (exp schedule)"
305    ]
306   }
307  ],
308  "metadata": {
309   "kernelspec": {
310    "display_name": "Python 3 (ipykernel)",
311    "language": "python",
312    "name": "python3"
313   },
314   "language_info": {
315    "codemirror_mode": {
316     "name": "ipython",
317     "version": 3
318    },
319    "file_extension": ".py",
320    "mimetype": "text/x-python",
321    "name": "python",
322    "nbconvert_exporter": "python",
323    "pygments_lexer": "ipython3",
324    "version": "3.12.5"
325   }
326  },
327  "nbformat": 4,
328  "nbformat_minor": 5