Update params.yaml
[notebooks.git] / fmda / fmda_rnn_train_and_save.ipynb
blobba882fe6ab65d2f5f82cc13b2f6ff46c76c4c5a5
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.1 run RNN with Spatial Training\n",
9     "\n",
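10     "This notebook trains an RNN on RAWS data with the spatial training scheme and saves the resulting prediction model. It was originally intended to set up a test where the RNN is run serially by location and compared to the spatial training scheme. Additionally, the ODE model with the augmented KF was to be run as a comparison, but note that the RNN models predict entirely without knowledge of the held-out locations, while the augmented KF is run directly on the test locations. The serial-RNN and augmented-KF comparisons are not included in this version of the notebook.\n"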
11    ]
12   },
13   {
14    "cell_type": "markdown",
15    "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
16    "metadata": {},
17    "source": [
18     "## Environment Setup"
19    ]
20   },
21   {
22    "cell_type": "code",
23    "execution_count": null,
24    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
25    "metadata": {},
26    "outputs": [],
27    "source": [
28     "import numpy as np\n",
29     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
30     "import pickle\n",
31     "import logging\n",
32     "import os.path as osp\n",
33     "import tensorflow as tf\n",
34     "from moisture_rnn_pkl import pkl2train\n",
35     "from moisture_rnn import RNNParams, RNNData, RNN \n",
36     "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict\n",
37     "from moisture_rnn import RNN\n",
38     "import reproducibility\n",
39     "from data_funcs import rmse, to_json, combine_nested, subset_by_features, process_train_dict\n",
40     "from moisture_models import run_augmented_kf\n",
41     "import copy\n",
42     "import pandas as pd\n",
43     "import matplotlib.pyplot as plt\n",
44     "import yaml\n",
45     "import time"
46    ]
47   },
48   {
49    "cell_type": "code",
50    "execution_count": null,
51    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
52    "metadata": {},
53    "outputs": [],
54    "source": [
55     "logging_setup()"
56    ]
57   },
58   {
59    "cell_type": "code",
60    "execution_count": null,
61    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
62    "metadata": {},
63    "outputs": [],
64    "source": [
65     "raws_file = \"data/raws_rocky_202305-202405.pkl\"  # local path of the download; also the training input below\n",
66     "retrieve_url(url = \"https://demo.openwfm.org/web/data/fmda/dicts/raws_rocky_202305-202405.pkl.pkl\",  # NOTE(review): remote name ends in .pkl.pkl - confirm it is intended\n",
67     "             dest_path = raws_file)"
68    ]
69   },
70   {
71    "cell_type": "code",
72    "execution_count": null,
73    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
74    "metadata": {},
75    "outputs": [],
76    "source": [
77     "file_paths = [raws_file]  # the file retrieve_url just wrote; single source of truth for the path"
78    ]
79   },
80   {
81    "cell_type": "code",
82    "execution_count": null,
83    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
84    "metadata": {},
85    "outputs": [],
86    "source": [
87     "# read/write control for the training dictionary\n",
88     "train_file = 'data/train.pkl'\n",
89     "train_create = True   # build from file_paths; if False, set train_read = True to load train_file instead\n",
90     "train_write = False   # pickle the training dict to train_file\n",
91     "train_read = False    # load train_file (runs last, so it overrides a freshly built dict)"
92    ]
93   },
94   {
95    "cell_type": "code",
96    "execution_count": null,
97    "id": "604388de-11ab-45c3-9f0d-80bdff0cca60",
98    "metadata": {},
99    "outputs": [],
100    "source": [
101     "# Params used for data filtering (passed to process_train_dict)\n",
102     "params_data = read_yml(\"params_data.yaml\")\n",
103     "params_data"
104    ]
105   },
106   {
107    "cell_type": "code",
108    "execution_count": null,
109    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
110    "metadata": {},
111    "outputs": [],
112    "source": [
113     "# Params used for setting up RNN (the 'rnn' section of params.yaml)\n",
114     "params = read_yml(\"params.yaml\", subkey='rnn')\n",
115     "params"
116    ]
117   },
118   {
119    "cell_type": "code",
120    "execution_count": null,
121    "id": "38e6bc61-e123-4cc9-bdee-54b051bbb352",
122    "metadata": {},
123    "outputs": [],
124    "source": [
125     "feats = ['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain']  # used to subset train and as the RNN features_list"
126    ]
127   },
128   {
129    "cell_type": "code",
130    "execution_count": null,
131    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
132    "metadata": {
133     "scrolled": true
134    },
135    "outputs": [],
136    "source": [
137     "params_data.update({'hours': 1440})  # set on every path: params['bmax'] below reads params_data['hours']\n",
138     "if train_create:\n",
139     "    logging.info('creating the training cases from files %s',file_paths)\n",
140     "    train = process_train_dict(file_paths, atm_dict = \"RAWS\", params_data = params_data, verbose=True)\n",
141     "    # restrict to `feats`, then combine the nested dictionary with combine_nested\n",
142     "    train = subset_by_features(train, feats)\n",
143     "    train = combine_nested(train)\n",
144     "if train_write:\n",
145     "    with open(train_file, 'wb') as file:\n",
146     "        logging.info('Writing the train cases into file %s',train_file)\n",
147     "        pickle.dump(train, file)\n",
148     "if train_read:\n",
149     "    logging.info('Reading the train cases from file %s',train_file)\n",
150     "    train = read_pkl(train_file)"
151    ]
152   },
153   {
154    "cell_type": "code",
155    "execution_count": null,
156    "id": "23cd60c0-9865-4314-9a96-948c3d400c08",
157    "metadata": {},
158    "outputs": [],
159    "source": [
160     "from itertools import islice\n",
161     "train = dict(islice(train.items(), 500))  # keep only the first 500 cases (dict insertion order)"
162    ]
163   },
164   {
165    "cell_type": "markdown",
166    "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3",
167    "metadata": {},
168    "source": [
169     "## Spatial Data Training\n",
170     "\n",
171     "This method combines the training timeseries data into a single 3-d array, with timeseries at the same location arranged in the correct order for a given `batch_size` hyperparameter. The hidden state of the recurrent layers is reset when the location changes."
172    ]
173   },
174   {
175    "cell_type": "code",
176    "execution_count": null,
177    "id": "36823193-b93c-421e-b699-8c1ae5719309",
178    "metadata": {},
179    "outputs": [],
180    "source": [
181     "reproducibility.set_seed(123)"
182    ]
183   },
184   {
185    "cell_type": "code",
186    "execution_count": null,
187    "id": "66f40c9f-c1c2-4b12-bf14-2ada8c26113d",
188    "metadata": {},
189    "outputs": [],
190    "source": [
191     "params = RNNParams(params)\n",
192     "params.update({'epochs': 200,\n",
193     "               'learning_rate': 0.001,\n",
194     "               'activation': ['relu', 'relu'], # Activation for RNN Layers, Dense layers respectively.\n",
195     "               'recurrent_layers': 1, 'recurrent_units': 30,\n",
196     "               'dense_layers': 1, 'dense_units': 30,\n",
197     "               'early_stopping_patience': 30, # how many epochs without validation improvement to wait before stopping\n",
198     "               'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
199     "               'bmin': 20, # Lower bound of hidden state batch reset\n",
200     "               'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
201     "               'features_list': feats\n",
202     "              })"
203    ]
204   },
205   {
206    "cell_type": "code",
207    "execution_count": null,
208    "id": "faf93470-b55f-4770-9fa9-3288a2f13fcc",
209    "metadata": {},
210    "outputs": [],
211    "source": [
212     "# Wrap the training dict (combine_nested was already applied when it was built) in utils.Dict for RNNData\n",
213     "train_sp = Dict(train)"
214    ]
215   },
216   {
217    "cell_type": "code",
218    "execution_count": null,
219    "id": "c0c7f5fb-4c33-45f8-9a2e-38c9ab1cd4e3",
220    "metadata": {},
221    "outputs": [],
222    "source": [
223     "# Build the RNN data object from the spatial training dict\n",
224     "rnn_dat_sp = RNNData(\n",
225     "    train_sp,                               # input dictionary\n",
226     "    scaler=\"standard\",                     # data scaling type\n",
227     "    features_list=params['features_list'],  # features for predicting outcome\n",
228     ")\n",
229     "\n",
230     "# Split into train/val/test: time_fracs over time steps, space_fracs over timeseries\n",
231     "rnn_dat_sp.train_test_split(\n",
232     "    time_fracs=[.8, .1, .1],   # share of total time steps for train/val/test\n",
233     "    space_fracs=[.8, .1, .1],  # share of total timeseries for train/val/test\n",
234     ")\n",
235     "rnn_dat_sp.scale_data()\n",
236     "\n",
237     "# Reshape into samples of `timesteps` steps; `batch_size` samples per gradient-descent step\n",
238     "rnn_dat_sp.batch_reshape(timesteps=params['timesteps'],\n",
239     "                         batch_size=params['batch_size'])"
240    ]
241   },
242   {
243    "cell_type": "code",
244    "execution_count": null,
245    "id": "7431bc95-d384-40fd-a622-bbc0ee68e5cd",
246    "metadata": {},
247    "outputs": [],
248    "source": [
249     "params.update({\n",
250     "    # spatial training only: used to reset the hidden state when the location changes for a given batch\n",
251     "    'loc_batch_reset': rnn_dat_sp.n_seqs,\n",
252     "})"
253    ]
254   },
255   {
256    "cell_type": "code",
257    "execution_count": null,
258    "id": "4bc11474-fed8-47f2-b9cf-dfdda0d3d3b2",
259    "metadata": {},
260    "outputs": [],
261    "source": [
262     "rnn_sp = RNN(params)  # model built from the params assembled above\n",
263     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)  # errs is summarised in the next cell"
264    ]
265   },
266   {
267    "cell_type": "code",
268    "execution_count": null,
269    "id": "704ad662-d81a-488d-be3d-e90bf775a5b8",
270    "metadata": {},
271    "outputs": [],
272    "source": [
273     "errs.mean()"
274    ]
275   },
276   {
277    "cell_type": "markdown",
278    "id": "62c1b049-304e-4c90-b1d2-b9b96b9a202f",
279    "metadata": {},
280    "source": [
281     "## Save Model"
282    ]
283   },
284   {
285    "cell_type": "code",
286    "execution_count": null,
287    "id": "f333521f-c724-40bf-8c1c-32735aea52cc",
288    "metadata": {},
289    "outputs": [],
290    "source": [
291     "filename = osp.join(\"data/outputs/models\", \"model_predict_raws_rocky.keras\")\n",
292     "tf.io.gfile.makedirs(osp.dirname(filename))  # create the output directory on a fresh checkout; succeeds if it already exists\n",
293     "rnn_sp.model_predict.save(filename)"
294    ]
295   },
296   {
297    "cell_type": "markdown",
298    "id": "bc1c601f-23a9-41b0-b921-47f1340f2a47",
299    "metadata": {},
300    "source": [
301     "## Load and Check"
302    ]
303   },
304   {
305    "cell_type": "code",
306    "execution_count": null,
307    "id": "3c27b3c1-6f60-450e-82ea-18eaf012fece",
308    "metadata": {},
309    "outputs": [],
310    "source": [
311     "mod = tf.keras.models.load_model(filename)"
312    ]
313   },
314   {
315    "cell_type": "code",
316    "execution_count": null,
317    "id": "25bf5420-d681-40ec-9eb8-aed784ca4e5a",
318    "metadata": {},
319    "outputs": [],
320    "source": [
321     "from utils import hash_weights\n",
322     "assert hash_weights(mod) == hash_weights(rnn_sp.model_predict), \"reloaded weights differ from the trained model\"\n",
323     "hash_weights(mod)"
324    ]
325   },
326   {
327    "cell_type": "code",
328    "execution_count": null,
329    "id": "d773b2ab-18de-4b13-a243-b6353c57f192",
330    "metadata": {},
331    "outputs": [],
332    "source": [
333     "type(rnn_dat_sp.X_test)"
334    ]
335   },
336   {
337    "cell_type": "code",
338    "execution_count": null,
339    "id": "253ba437-c3a2-452b-b8e6-078aa17c8408",
340    "metadata": {},
341    "outputs": [],
342    "source": [
343     "X_test = np.stack(rnn_dat_sp.X_test, axis=0)\n",
344     "y_array = np.stack(rnn_dat_sp.y_test, axis=0)"
345    ]
346   },
347   {
348    "cell_type": "code",
349    "execution_count": null,
350    "id": "f4332dd8-57cd-4f5b-a864-dc72f96d72b2",
351    "metadata": {},
352    "outputs": [],
353    "source": [
354     "preds = mod.predict(X_test)\n",
355     "preds.shape"
356    ]
357   },
358   {
359    "cell_type": "code",
360    "execution_count": null,
361    "id": "4e4cd809-6701-4bd7-b4fe-37c5e35d8999",
362    "metadata": {},
363    "outputs": [],
364    "source": [
365     "np.mean(np.sqrt(np.mean(np.square(preds - y_array), axis=(1,2))))  # RMSE of each stacked test array (reduced over axes 1 and 2), then averaged"
366    ]
367   }
384  ],
385  "metadata": {
386   "kernelspec": {
387    "display_name": "Python 3 (ipykernel)",
388    "language": "python",
389    "name": "python3"
390   },
391   "language_info": {
392    "codemirror_mode": {
393     "name": "ipython",
394     "version": 3
395    },
396    "file_extension": ".py",
397    "mimetype": "text/x-python",
398    "name": "python",
399    "nbconvert_exporter": "python",
400    "pygments_lexer": "ipython3",
401    "version": "3.12.5"
402   }
403  },
404  "nbformat": 4,
405  "nbformat_minor": 5