Merge pull request #17 from openwfm/restructure
[notebooks.git] / fmda / test_notebooks / fmda_rnn_train_and_save.ipynb
blob8b396b07826659773e2a9de8eda51fe188a0ebb8
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.3 Run RNN and Save\n",
9     "\n",
10     "This notebook tests training an RNN and then saving the model object for later use.\n"
11    ]
12   },
13   {
14    "cell_type": "markdown",
15    "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
16    "metadata": {},
17    "source": [
18     "## Environment Setup"
19    ]
20   },
21   {
22    "cell_type": "code",
23    "execution_count": null,
24    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
25    "metadata": {},
26    "outputs": [],
27    "source": [
28     "import numpy as np\n",
29     "import sys\n",
30     "sys.path.append('..')\n",
31     "import pickle\n",
32     "import logging\n",
33     "import os.path as osp\n",
34     "import tensorflow as tf\n",
35     "from moisture_rnn_pkl import pkl2train\n",
36     "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
37     "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict, print_dict_summary, print_first, str2time, logging_setup\n",
38     "from moisture_rnn import RNN\n",
39     "import reproducibility\n",
40     "from data_funcs import rmse, to_json, combine_nested, subset_by_features, build_train_dict\n",
41     "from moisture_models import run_augmented_kf\n",
42     "import copy\n",
43     "import pandas as pd\n",
44     "import matplotlib.pyplot as plt\n",
45     "import yaml\n",
46     "import time"
47    ]
48   },
49   {
50    "cell_type": "code",
51    "execution_count": null,
52    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
53    "metadata": {},
54    "outputs": [],
55    "source": [
56     "logging_setup()"
57    ]
58   },
59   {
60    "cell_type": "code",
61    "execution_count": null,
62    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
63    "metadata": {},
64    "outputs": [],
65    "source": [
66     "filename = \"fmda_rocky_202403-05_f05.pkl\"\n",
67     "retrieve_url(\n",
68     "    url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n",
69     "    dest_path = f\"../data/{filename}\")"
70    ]
71   },
72   {
73    "cell_type": "code",
74    "execution_count": null,
75    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
76    "metadata": {},
77    "outputs": [],
78    "source": [
79     "file_paths = [f'../data/{filename}']"
80    ]
81   },
82   {
83    "cell_type": "code",
84    "execution_count": null,
85    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
86    "metadata": {},
87    "outputs": [],
88    "source": [
89     "# # read/write control\n",
90     "# train_file='../data/train.pkl'\n",
91     "# train_create=True   # if false, read\n",
92     "# train_write=False\n",
93     "# train_read=False"
94    ]
95   },
96   {
97    "cell_type": "code",
98    "execution_count": null,
99    "id": "604388de-11ab-45c3-9f0d-80bdff0cca60",
100    "metadata": {},
101    "outputs": [],
102    "source": [
103     "# Params used for data filtering\n",
104     "params_data = read_yml(\"../params_data.yaml\") \n",
105     "params_data"
106    ]
107   },
108   {
109    "cell_type": "code",
110    "execution_count": null,
111    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
112    "metadata": {},
113    "outputs": [],
114    "source": [
115     "# Params used for setting up RNN\n",
116     "params = read_yml(\"../params.yaml\", subkey='rnn') \n",
117     "params.update({\n",
118     "    'hidden_layers': ['dense', 'lstm', 'attention', 'dense'],\n",
119     "    'hidden_units': [64, 32, None, 32],\n",
120     "    'hidden_activation': ['relu', 'tanh', None, 'relu']\n",
121     "})"
122    ]
123   },
124   {
125    "cell_type": "code",
126    "execution_count": null,
127    "id": "38e6bc61-e123-4cc9-bdee-54b051bbb352",
128    "metadata": {},
129    "outputs": [],
130    "source": [
131     "feats = ['Ed', 'Ew', 'solar', 'wind', 'elev', 'lon', 'lat', 'rain']\n",
132     "params.update({'features_list': feats})"
133    ]
134   },
135   {
136    "cell_type": "code",
137    "execution_count": null,
138    "id": "ef84104f-9898-4cd9-be54-7c480536ee0e",
139    "metadata": {
140     "scrolled": true
141    },
142    "outputs": [],
143    "source": [
144     "train = build_train_dict(file_paths, atm_source=\"RAWS\", params_data = params_data,\n",
145     "                         features_subset = feats, spatial=False, verbose=True)\n",
146     "train = subset_by_features(train, params['features_list'])\n",
147     "train = combine_nested(train)"
148    ]
149   },
150   {
151    "cell_type": "code",
152    "execution_count": null,
153    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
154    "metadata": {
155     "scrolled": true
156    },
157    "outputs": [],
158    "source": [
159     "# if train_create:\n",
160     "#     params_data.update({'hours': 1440})\n",
161     "#     logging.info('creating the training cases from files %s',file_paths)\n",
162     "#     # osp.join works on windows too, joins paths using \\ or /\n",
163     "#     train = process_train_dict(file_paths, atm_dict = \"RAWS\", params_data = params_data, verbose=True)\n",
164     "#     train = subset_by_features(train, feats)\n",
165     "#     train = combine_nested(train)\n",
166     "# if train_write:\n",
167     "#     with open(train_file, 'wb') as file:\n",
168     "#         logging.info('Writing the rain cases into file %s',train_file)\n",
169     "#         pickle.dump(train, file)\n",
170     "# if train_read:\n",
171     "#     logging.info('Reading the train cases from file %s',train_file)\n",
172     "#     train = read_pkl(train_file)"
173    ]
174   },
175   {
176    "cell_type": "markdown",
177    "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3",
178    "metadata": {},
179    "source": [
180     "## Spatial Data Training\n",
181     "\n",
182     "This method combines the training time series data into a single 3-d array, with the time series from each location arranged in the correct order for a given `batch_size` hyperparameter. The hidden state of the recurrent layers is set up to reset when the location changes."
183    ]
184   },
185   {
186    "cell_type": "code",
187    "execution_count": null,
188    "id": "36823193-b93c-421e-b699-8c1ae5719309",
189    "metadata": {},
190    "outputs": [],
191    "source": [
192     "reproducibility.set_seed(123)"
193    ]
194   },
195   {
196    "cell_type": "code",
197    "execution_count": null,
198    "id": "66f40c9f-c1c2-4b12-bf14-2ada8c26113d",
199    "metadata": {},
200    "outputs": [],
201    "source": [
202     "params = RNNParams(params)\n",
203     "# params.update({'epochs': 200, \n",
204     "#                'learning_rate': 0.001,\n",
205     "#                'activation': ['relu', 'relu'], # Activation for RNN Layers, Dense layers respectively.\n",
206     "#                'recurrent_layers': 1, 'recurrent_units': 30, \n",
207     "#                'dense_layers': 1, 'dense_units': 30,\n",
208     "#                'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n",
209     "#                'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
210     "#                'bmin': 20, # Lower bound of hidden state batch reset, \n",
211     "#                'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
212     "#                'batch_size': 60\n",
213     "#               })"
214    ]
215   },
216   {
217    "cell_type": "code",
218    "execution_count": null,
219    "id": "82bc407d-9d26-41e3-8b58-ab3f7238e105",
220    "metadata": {},
221    "outputs": [],
222    "source": [
223     "import importlib\n",
224     "import moisture_rnn\n",
225     "importlib.reload(moisture_rnn)\n",
226     "from moisture_rnn import RNNData"
227    ]
228   },
229   {
230    "cell_type": "code",
231    "execution_count": null,
232    "id": "924549ba-ea73-4fc9-91b3-8f1f0e32e831",
233    "metadata": {},
234    "outputs": [],
235    "source": [
236     "rnn_dat_sp = rnn_data_wrap(train, params)\n",
237     "params.update({\n",
238     "    'loc_batch_reset': rnn_dat_sp.n_seqs, # Used to reset hidden state when location changes for a given batch\n",
239     "    'bmax': params_data['hours']\n",
240     "})"
241    ]
242   },
243   {
244    "cell_type": "code",
245    "execution_count": null,
246    "id": "4bc11474-fed8-47f2-b9cf-dfdda0d3d3b2",
247    "metadata": {},
248    "outputs": [],
249    "source": [
250     "rnn_sp = RNN(params)\n",
251     "m_sp, errs = rnn_sp.run_model(rnn_dat_sp)"
252    ]
253   },
254   {
255    "cell_type": "code",
256    "execution_count": null,
257    "id": "704ad662-d81a-488d-be3d-e90bf775a5b8",
258    "metadata": {},
259    "outputs": [],
260    "source": [
261     "errs.mean()"
262    ]
263   },
264   {
265    "cell_type": "markdown",
266    "id": "62c1b049-304e-4c90-b1d2-b9b96b9a202f",
267    "metadata": {},
268    "source": [
269     "## Save Model and Data"
270    ]
271   },
272   {
273    "cell_type": "code",
274    "execution_count": null,
275    "id": "f333521f-c724-40bf-8c1c-32735aea52cc",
276    "metadata": {},
277    "outputs": [],
278    "source": [
279     "outpath = \"../outputs/models\"\n",
280     "filename = osp.join(outpath, f\"model_predict_raws_rocky.keras\")\n",
281     "rnn_sp.model_predict.save(filename) # save prediction model only"
282    ]
283   },
284   {
285    "cell_type": "code",
286    "execution_count": null,
287    "id": "cf6231a9-0c7b-45ba-ac75-7fb5b6124c72",
288    "metadata": {},
289    "outputs": [],
290    "source": [
291     "with open(f\"{outpath}/rnn_data_rocky.pkl\", 'wb') as file:\n",
292     "    pickle.dump(rnn_dat_sp, file)"
293    ]
294   },
295   {
296    "cell_type": "markdown",
297    "id": "bc1c601f-23a9-41b0-b921-47f1340f2a47",
298    "metadata": {},
299    "source": [
300     "## Load and Check"
301    ]
302   },
303   {
304    "cell_type": "code",
305    "execution_count": null,
306    "id": "3c27b3c1-6f60-450e-82ea-18eaf012fece",
307    "metadata": {},
308    "outputs": [],
309    "source": [
310     "mod = tf.keras.models.load_model(filename)"
311    ]
312   },
313   {
314    "cell_type": "code",
315    "execution_count": null,
316    "id": "25bf5420-d681-40ec-9eb8-aed784ca4e5a",
317    "metadata": {},
318    "outputs": [],
319    "source": [
320     "from utils import hash_weights\n",
321     "\n",
322     "hash_weights(mod)"
323    ]
324   },
325   {
326    "cell_type": "code",
327    "execution_count": null,
328    "id": "d773b2ab-18de-4b13-a243-b6353c57f192",
329    "metadata": {},
330    "outputs": [],
331    "source": [
332     "type(rnn_dat_sp.X_test)"
333    ]
334   },
335   {
336    "cell_type": "code",
337    "execution_count": null,
338    "id": "253ba437-c3a2-452b-b8e6-078aa17c8408",
339    "metadata": {},
340    "outputs": [],
341    "source": [
342     "X_test = np.stack(rnn_dat_sp.X_test, axis=0)\n",
343     "y_array = np.stack(rnn_dat_sp.y_test, axis=0)"
344    ]
345   },
346   {
347    "cell_type": "code",
348    "execution_count": null,
349    "id": "f4332dd8-57cd-4f5b-a864-dc72f96d72b2",
350    "metadata": {},
351    "outputs": [],
352    "source": [
353     "preds = mod.predict(X_test)\n",
354     "preds.shape"
355    ]
356   },
357   {
358    "cell_type": "code",
359    "execution_count": null,
360    "id": "4e4cd809-6701-4bd7-b4fe-37c5e35d8999",
361    "metadata": {},
362    "outputs": [],
363    "source": [
364     "np.mean(np.sqrt(np.mean(np.square(preds - y_array), axis=(1,2))))"
365    ]
366   },
367   {
368    "cell_type": "code",
369    "execution_count": null,
370    "id": "4f4d80cb-edef-4720-b335-4af5a04992c3",
371    "metadata": {},
372    "outputs": [],
373    "source": []
374   },
375   {
376    "cell_type": "code",
377    "execution_count": null,
378    "id": "e9d7f913-b391-4e14-9b64-46a0a9786f4a",
379    "metadata": {},
380    "outputs": [],
381    "source": []
382   }
383  ],
384  "metadata": {
385   "kernelspec": {
386    "display_name": "Python 3 (ipykernel)",
387    "language": "python",
388    "name": "python3"
389   },
390   "language_info": {
391    "codemirror_mode": {
392     "name": "ipython",
393     "version": 3
394    },
395    "file_extension": ".py",
396    "mimetype": "text/x-python",
397    "name": "python",
398    "nbconvert_exporter": "python",
399    "pygments_lexer": "ipython3",
400    "version": "3.12.5"
401   }
402  },
403  "nbformat": 4,
404  "nbformat_minor": 5