4 "cell_type": "markdown",
5 "id": "83b774b3-ef55-480a-b999-506676e49145",
8 "# v2.3 Forecast on a grid\n",
10 "This notebook is intended to test reading in a saved, trained model and deploying it on a grid.\n"
14 "cell_type": "markdown",
15 "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
18 "## Environment Setup"
23 "execution_count": null,
24 "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
28 "import numpy as np\n",
30 "sys.path.append('..')\n",
33 "import os.path as osp\n",
34 "import tensorflow as tf\n",
35 "from moisture_rnn_pkl import pkl2train\n",
36 "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
37 "from utils import read_yml, read_pkl, retrieve_url, Dict, print_dict_summary, rmse_3d\n",
38 "from moisture_rnn import RNN\n",
39 "import reproducibility\n",
40 "from data_funcs import rmse, to_json, combine_nested, subset_by_features, build_train_dict\n",
41 "from moisture_models import run_augmented_kf\n",
43 "import pandas as pd\n",
44 "import matplotlib.pyplot as plt\n",
50 "cell_type": "markdown",
51 "id": "bc1c601f-23a9-41b0-b921-47f1340f2a47",
54 "## Load Model and Examine"
59 "execution_count": null,
60 "id": "3c27b3c1-6f60-450e-82ea-18eaf012fece",
64 "filename = \"../outputs/models/model_predict_raws_rocky.keras\"\n",
65 "mod = tf.keras.models.load_model(filename) # prediction model"
70 "execution_count": null,
71 "id": "44061a00-68c9-462f-b9c4-93e2c629a26b",
75 "# Print Model Summary\n",
81 "execution_count": null,
82 "id": "885b51fa-42f6-4542-9d7a-333ce02ad5d8",
86 "dat = read_pkl(f\"../outputs/models/rnn_data_rocky.pkl\")"
91 "execution_count": null,
92 "id": "3d329c21-b23b-4dd8-844b-dbab98ca8b02",
101 "execution_count": null,
102 "id": "4f4d80cb-edef-4720-b335-4af5a04992c3",
111 "execution_count": null,
112 "id": "bbde8943-9be2-464a-bd26-140265f5943d",
120 "cell_type": "markdown",
121 "id": "6d34d26e-18dc-49dd-8f22-52ec5cbcbda8",
129 "execution_count": null,
130 "id": "63373fa5-74d1-45dc-822f-554527202e63",
134 "preds = mod.predict(dat.X_test)\n",
136 "print(f\"{preds.shape=}\")"
141 "execution_count": null,
142 "id": "aed536af-367b-476a-9f7e-ad99b17bbd38",
146 "errs = rmse_3d(preds, dat.y_test)\n",
147 "print(f\"Test Error: {errs}\")"
151 "cell_type": "markdown",
152 "id": "7b21d063-4076-4526-a579-1536b7bd85a9",
159 "cell_type": "markdown",
160 "id": "f15e1a2f-87c3-4203-a27d-14bbaa0df291",
168 "execution_count": null,
169 "id": "da69389c-6dd8-4f41-9adb-e364a482c075",
178 "execution_count": null,
179 "id": "60ea9ac7-6534-444b-860e-ccbe2d4a86df",
183 "from utils import Dict\n",
184 "keys_to_copy = ['features_list', 'all_features_list']\n",
185 "d = Dict({key: dat[key] for key in keys_to_copy if key in dat})\n",
189 "d['id'] = dat.id[loc_ind]\n",
190 "d['X'] = dat.X[loc_ind][dat.test_ind:, :]\n",
191 "d['y'] = dat.y_test[loc_ind, :, :]\n",
192 "d['hours'] = len(d['y'])\n",
193 "d['m'] = preds[loc_ind, :, :]"
198 "execution_count": null,
199 "id": "015bff30-52a8-4b45-a22d-00cee2ea3b5f",
203 "import importlib\n",
204 "import data_funcs\n",
205 "importlib.reload(data_funcs)\n",
206 "from data_funcs import plot_data"
211 "execution_count": null,
212 "id": "4d3bd61e-586b-48f7-8218-4cf6f9f35e72",
216 "plot_data(d, title=\"RNN Prediction Error\", title2=d['id'], plot_period=\"all\")"
220 "cell_type": "markdown",
221 "id": "f7667277-cd34-4ac6-b299-6a2feb17fd11",
229 "execution_count": null,
230 "id": "5daa22df-bbed-465c-9a77-14e0c13e050f",
239 "execution_count": null,
240 "id": "dc2a6ea0-796b-4193-872f-fda5d82784b6",
244 "test_indices = [i for i, val in enumerate(dat[\"id\"]) if val in dat.loc['test_locs']]\n",
245 "x_coord_test = [dat.loc[\"pixel_x\"][i] for i in test_indices]\n",
246 "y_coord_test = [dat.loc[\"pixel_y\"][i] for i in test_indices]"
251 "execution_count": null,
252 "id": "8f8a9568-7e98-4649-b1aa-dc3555dafcfb",
261 "execution_count": null,
262 "id": "d9bb9446-9cb7-45d1-b6b6-1461f87ed15c",
271 "execution_count": null,
272 "id": "bae49a58-7ee3-47de-8348-14fb2a270fc7",
278 "sc = plt.scatter(\n",
279 " x_coord_test, y_coord_test,\n",
280 " c = preds[:, tstep, 0], # Assuming the first dimension is to be used for color\n",
281 " cmap='viridis' # Choose a colormap\n",
283 "plt.colorbar(sc) # Add colorbar for scale\n",
284 "plt.xlabel(\"X Grid Coordinate\")\n",
285 "plt.ylabel(\"Y Grid Coordinate\")\n",
291 "execution_count": null,
292 "id": "6954b184-7910-499d-a023-8adce25c9225",
301 "execution_count": null,
302 "id": "f8e5da73-3b4f-4319-aeab-3f9f6fb97f33",
306 "from matplotlib.animation import FuncAnimation, PillowWriter\n",
309 "fig, ax = plt.subplots()\n",
310 "sc = ax.scatter(x_coord_test, y_coord_test, c=preds[:, 0, 0], cmap='viridis')\n",
311 "plt.colorbar(sc)\n",
312 "ax.set_xlabel(\"X Grid Coordinate\")\n",
313 "ax.set_ylabel(\"Y Grid Coordinate\")\n",
315 "# Function to update the scatter plot at each frame\n",
316 "def update(tstep):\n",
317 " sc.set_array(preds[:, tstep, 0]) # Update the colors based on tstep\n",
318 " ax.set_title(f'Time Step: {tstep}')\n",
321 "# Number of frames (time steps)\n",
322 "num_timesteps = preds.shape[1]\n",
324 "# Create the animation\n",
325 "ani = FuncAnimation(\n",
326 " fig, update, frames=np.arange(num_timesteps), interval=200, repeat=True\n",
329 "# Save the animation as a GIF\n",
330 "ani.save(\"../outputs/animation.gif\", writer=PillowWriter(fps=5))\n"
335 "execution_count": null,
336 "id": "e5fe7d5b-2ad6-46c1-8608-53b45368583a",
340 "# from IPython.display import Image\n",
341 "# Image(filename=\"../outputs/animation.gif\")"
346 "execution_count": null,
347 "id": "17e1b6e0-024e-4330-a403-78f0415d9ce2",
354 "execution_count": null,
355 "id": "6c4c149a-b70b-4a07-b354-a1119b196363",
362 "execution_count": null,
363 "id": "d9543769-b81d-4e7e-8b15-a2d45b35941f",
370 "execution_count": null,
371 "id": "190b3a8d-b5f2-446d-b63f-bfcdb65f75ae",
379 "display_name": "Python 3 (ipykernel)",
380 "language": "python",
388 "file_extension": ".py",
389 "mimetype": "text/x-python",
391 "nbconvert_exporter": "python",
392 "pygments_lexer": "ipython3",