Merge pull request #17 from openwfm/restructure
[notebooks.git] / fmda / test_notebooks / fmda_forecast_model.ipynb
blob5f69c2cf43c96e4a2695ee9dcdeba49bf6eab328
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.3 Forecast on a grid\n",
9     "\n",
10     "This notebook is intended to test reading in a saved, trained model and deploying it on a grid.\n"
11    ]
12   },
13   {
14    "cell_type": "markdown",
15    "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
16    "metadata": {},
17    "source": [
18     "## Environment Setup"
19    ]
20   },
21   {
22    "cell_type": "code",
23    "execution_count": null,
24    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
25    "metadata": {},
26    "outputs": [],
27    "source": [
28     "import numpy as np\n",
29     "import sys\n",
30     "sys.path.append('..')\n",
31     "import pickle\n",
32     "import logging\n",
33     "import os.path as osp\n",
34     "import tensorflow as tf\n",
35     "from moisture_rnn_pkl import pkl2train\n",
36     "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
37     "from utils import read_yml, read_pkl, retrieve_url, Dict, print_dict_summary, rmse_3d\n",
38     "from moisture_rnn import RNN\n",
39     "import reproducibility\n",
40     "from data_funcs import rmse, to_json, combine_nested, subset_by_features, build_train_dict\n",
41     "from moisture_models import run_augmented_kf\n",
42     "import copy\n",
43     "import pandas as pd\n",
44     "import matplotlib.pyplot as plt\n",
45     "import yaml\n",
46     "import time"
47    ]
48   },
49   {
50    "cell_type": "markdown",
51    "id": "bc1c601f-23a9-41b0-b921-47f1340f2a47",
52    "metadata": {},
53    "source": [
54     "## Load Model and Examine"
55    ]
56   },
57   {
58    "cell_type": "code",
59    "execution_count": null,
60    "id": "3c27b3c1-6f60-450e-82ea-18eaf012fece",
61    "metadata": {},
62    "outputs": [],
63    "source": [
64     "filename = \"../outputs/models/model_predict_raws_rocky.keras\"\n",
65     "mod = tf.keras.models.load_model(filename) # prediction model"
66    ]
67   },
68   {
69    "cell_type": "code",
70    "execution_count": null,
71    "id": "44061a00-68c9-462f-b9c4-93e2c629a26b",
72    "metadata": {},
73    "outputs": [],
74    "source": [
75     "# Print Model Summary\n",
76     "mod.summary()"
77    ]
78   },
79   {
80    "cell_type": "code",
81    "execution_count": null,
82    "id": "885b51fa-42f6-4542-9d7a-333ce02ad5d8",
83    "metadata": {},
84    "outputs": [],
85    "source": [
86     "dat = read_pkl(f\"../outputs/models/rnn_data_rocky.pkl\")"
87    ]
88   },
89   {
90    "cell_type": "code",
91    "execution_count": null,
92    "id": "3d329c21-b23b-4dd8-844b-dbab98ca8b02",
93    "metadata": {},
94    "outputs": [],
95    "source": [
96     "type(dat)"
97    ]
98   },
99   {
100    "cell_type": "code",
101    "execution_count": null,
102    "id": "4f4d80cb-edef-4720-b335-4af5a04992c3",
103    "metadata": {},
104    "outputs": [],
105    "source": [
106     "dat.keys()"
107    ]
108   },
109   {
110    "cell_type": "code",
111    "execution_count": null,
112    "id": "bbde8943-9be2-464a-bd26-140265f5943d",
113    "metadata": {},
114    "outputs": [],
115    "source": [
116     "dat.X_test.shape"
117    ]
118   },
119   {
120    "cell_type": "markdown",
121    "id": "6d34d26e-18dc-49dd-8f22-52ec5cbcbda8",
122    "metadata": {},
123    "source": [
124     "## Predict"
125    ]
126   },
127   {
128    "cell_type": "code",
129    "execution_count": null,
130    "id": "63373fa5-74d1-45dc-822f-554527202e63",
131    "metadata": {},
132    "outputs": [],
133    "source": [
134     "preds = mod.predict(dat.X_test)\n",
135     "\n",
136     "print(f\"{preds.shape=}\")"
137    ]
138   },
139   {
140    "cell_type": "code",
141    "execution_count": null,
142    "id": "aed536af-367b-476a-9f7e-ad99b17bbd38",
143    "metadata": {},
144    "outputs": [],
145    "source": [
146     "errs = rmse_3d(preds, dat.y_test)\n",
147     "print(f\"Test Error: {errs}\")"
148    ]
149   },
150   {
151    "cell_type": "markdown",
152    "id": "7b21d063-4076-4526-a579-1536b7bd85a9",
153    "metadata": {},
154    "source": [
155     "## Plot"
156    ]
157   },
158   {
159    "cell_type": "markdown",
160    "id": "f15e1a2f-87c3-4203-a27d-14bbaa0df291",
161    "metadata": {},
162    "source": [
163     "### Single Plot"
164    ]
165   },
166   {
167    "cell_type": "code",
168    "execution_count": null,
169    "id": "da69389c-6dd8-4f41-9adb-e364a482c075",
170    "metadata": {},
171    "outputs": [],
172    "source": [
173     "dat.keys()"
174    ]
175   },
176   {
177    "cell_type": "code",
178    "execution_count": null,
179    "id": "60ea9ac7-6534-444b-860e-ccbe2d4a86df",
180    "metadata": {},
181    "outputs": [],
182    "source": [
183     "from utils import Dict\n",
184     "keys_to_copy = ['features_list', 'all_features_list']\n",
185     "d = Dict({key: dat[key] for key in keys_to_copy if key in dat})\n",
186     "\n",
187     "loc_ind = 23\n",
188     "\n",
189     "d['id'] = dat.id[loc_ind]\n",
190     "d['X'] = dat.X[loc_ind][dat.test_ind:, :]\n",
191     "d['y'] = dat.y_test[loc_ind, :, :]\n",
192     "d['hours'] = len(d['y'])\n",
193     "d['m'] = preds[loc_ind, :, :]"
194    ]
195   },
196   {
197    "cell_type": "code",
198    "execution_count": null,
199    "id": "015bff30-52a8-4b45-a22d-00cee2ea3b5f",
200    "metadata": {},
201    "outputs": [],
202    "source": [
203     "import importlib\n",
204     "import data_funcs\n",
205     "importlib.reload(data_funcs)\n",
206     "from data_funcs import plot_data"
207    ]
208   },
209   {
210    "cell_type": "code",
211    "execution_count": null,
212    "id": "4d3bd61e-586b-48f7-8218-4cf6f9f35e72",
213    "metadata": {},
214    "outputs": [],
215    "source": [
216     "plot_data(d, title=\"RNN Prediction Error\", title2=d['id'], plot_period=\"all\")"
217    ]
218   },
219   {
220    "cell_type": "markdown",
221    "id": "f7667277-cd34-4ac6-b299-6a2feb17fd11",
222    "metadata": {},
223    "source": [
224     "## Plot Grid"
225    ]
226   },
227   {
228    "cell_type": "code",
229    "execution_count": null,
230    "id": "5daa22df-bbed-465c-9a77-14e0c13e050f",
231    "metadata": {},
232    "outputs": [],
233    "source": [
234     "dat.loc.keys()"
235    ]
236   },
237   {
238    "cell_type": "code",
239    "execution_count": null,
240    "id": "dc2a6ea0-796b-4193-872f-fda5d82784b6",
241    "metadata": {},
242    "outputs": [],
243    "source": [
244     "test_indices = [i for i, val in enumerate(dat[\"id\"]) if val in dat.loc['test_locs']]\n",
245     "x_coord_test = [dat.loc[\"pixel_x\"][i] for i in test_indices]\n",
246     "y_coord_test = [dat.loc[\"pixel_y\"][i] for i in test_indices]"
247    ]
248   },
249   {
250    "cell_type": "code",
251    "execution_count": null,
252    "id": "8f8a9568-7e98-4649-b1aa-dc3555dafcfb",
253    "metadata": {},
254    "outputs": [],
255    "source": [
256     "len(x_coord_test)"
257    ]
258   },
259   {
260    "cell_type": "code",
261    "execution_count": null,
262    "id": "d9bb9446-9cb7-45d1-b6b6-1461f87ed15c",
263    "metadata": {},
264    "outputs": [],
265    "source": [
266     "preds.shape"
267    ]
268   },
269   {
270    "cell_type": "code",
271    "execution_count": null,
272    "id": "bae49a58-7ee3-47de-8348-14fb2a270fc7",
273    "metadata": {},
274    "outputs": [],
275    "source": [
276     "tstep = 0\n",
277     "\n",
278     "sc = plt.scatter(\n",
279     "    x_coord_test, y_coord_test,\n",
280     "    c = preds[:, tstep, 0],  # color by index 0 of the last axis of preds at time step `tstep`\n",
281     "    cmap='viridis'  # Choose a colormap\n",
282     ")\n",
283     "plt.colorbar(sc)  # Add colorbar for scale\n",
284     "plt.xlabel(\"X Grid Coordinate\")\n",
285     "plt.ylabel(\"Y Grid Coordinate\")\n",
286     "plt.show()"
287    ]
288   },
289   {
290    "cell_type": "code",
291    "execution_count": null,
292    "id": "6954b184-7910-499d-a023-8adce25c9225",
293    "metadata": {},
294    "outputs": [],
295    "source": [
296     "%matplotlib inline"
297    ]
298   },
299   {
300    "cell_type": "code",
301    "execution_count": null,
302    "id": "f8e5da73-3b4f-4319-aeab-3f9f6fb97f33",
303    "metadata": {},
304    "outputs": [],
305    "source": [
306     "from matplotlib.animation import FuncAnimation, PillowWriter\n",
307     "\n",
308     "\n",
309     "fig, ax = plt.subplots()\n",
310     "sc = ax.scatter(x_coord_test, y_coord_test, c=preds[:, 0, 0], cmap='viridis')\n",
311     "plt.colorbar(sc)\n",
312     "ax.set_xlabel(\"X Grid Coordinate\")\n",
313     "ax.set_ylabel(\"Y Grid Coordinate\")\n",
314     "\n",
315     "# Function to update the scatter plot at each frame\n",
316     "def update(tstep):\n",
317     "    sc.set_array(preds[:, tstep, 0])  # Update the colors based on tstep\n",
318     "    ax.set_title(f'Time Step: {tstep}')\n",
319     "    return sc,\n",
320     "\n",
321     "# Number of frames (time steps)\n",
322     "num_timesteps = preds.shape[1]\n",
323     "\n",
324     "# Create the animation\n",
325     "ani = FuncAnimation(\n",
326     "    fig, update, frames=np.arange(num_timesteps), interval=200, repeat=True\n",
327     ")\n",
328     "\n",
329     "# Save the animation as a GIF\n",
330     "ani.save(\"../outputs/animation.gif\", writer=PillowWriter(fps=5))\n"
331    ]
332   },
333   {
334    "cell_type": "code",
335    "execution_count": null,
336    "id": "e5fe7d5b-2ad6-46c1-8608-53b45368583a",
337    "metadata": {},
338    "outputs": [],
339    "source": [
340     "# from IPython.display import Image\n",
341     "# Image(filename=\"../outputs/animation.gif\")"
342    ]
343   },
344   {
345    "cell_type": "code",
346    "execution_count": null,
347    "id": "17e1b6e0-024e-4330-a403-78f0415d9ce2",
348    "metadata": {},
349    "outputs": [],
350    "source": []
351   },
352   {
353    "cell_type": "code",
354    "execution_count": null,
355    "id": "6c4c149a-b70b-4a07-b354-a1119b196363",
356    "metadata": {},
357    "outputs": [],
358    "source": []
359   },
360   {
361    "cell_type": "code",
362    "execution_count": null,
363    "id": "d9543769-b81d-4e7e-8b15-a2d45b35941f",
364    "metadata": {},
365    "outputs": [],
366    "source": []
367   },
368   {
369    "cell_type": "code",
370    "execution_count": null,
371    "id": "190b3a8d-b5f2-446d-b63f-bfcdb65f75ae",
372    "metadata": {},
373    "outputs": [],
374    "source": []
375   }
376  ],
377  "metadata": {
378   "kernelspec": {
379    "display_name": "Python 3 (ipykernel)",
380    "language": "python",
381    "name": "python3"
382   },
383   "language_info": {
384    "codemirror_mode": {
385     "name": "ipython",
386     "version": 3
387    },
388    "file_extension": ".py",
389    "mimetype": "text/x-python",
390    "name": "python",
391    "nbconvert_exporter": "python",
392    "pygments_lexer": "ipython3",
393    "version": "3.12.5"
394   }
395  },
396  "nbformat": 4,
397  "nbformat_minor": 5