Update rnn_train_versions.ipynb
[notebooks.git] / fmda / rnn_code_tutorial.ipynb
blobebf5dde85e19294df09001d1b60067a23a9d4a31
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "7f39c039-5ee5-4b46-bf8f-1ae289db8d17",
6    "metadata": {},
7    "source": [
8     "# v2.3 run RNN Class with Spatial Training\n",
9     "\n",
10     "This notebook serves as a guide for using the RNN code in this project. It walks through the core functionality for the data pre-processing, setting up model hyperparameters, structuring data to feed into RNN, and evaluating prediction error with spatiotemporal cross-validation. "
11    ]
12   },
13   {
14    "cell_type": "markdown",
15    "id": "1e98fcc9-3079-45d1-aece-d656d70a4244",
16    "metadata": {},
17    "source": [
18     "## Setup\n",
19     "\n",
20     "We will import certain functions at code cells in relevant sections for clarity, but everything used will be included in this setup cell."
21    ]
22   },
23   {
24    "cell_type": "code",
25    "execution_count": null,
26    "id": "31369263-1526-4117-b25d-c3ed71d298b0",
27    "metadata": {},
28    "outputs": [],
29    "source": [
30     "import numpy as np\n",
31     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
32     "import pickle\n",
33     "import logging\n",
34     "import os.path as osp\n",
35     "from moisture_rnn_pkl import pkl2train\n",
36     "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n",
37     "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict, print_dict_summary\n",
38     "from moisture_rnn import RNN\n",
39     "import reproducibility\n",
40     "from data_funcs import rmse, to_json, combine_nested, build_train_dict\n",
41     "from moisture_models import run_augmented_kf\n",
42     "import copy\n",
43     "import pandas as pd\n",
44     "import matplotlib.pyplot as plt\n",
45     "import yaml\n",
46     "import time"
47    ]
48   },
49   {
50    "cell_type": "code",
51    "execution_count": null,
52    "id": "2c14b5e7-5f22-45d2-8bcd-73a94f9b25e8",
53    "metadata": {},
54    "outputs": [],
55    "source": [
56     "from IPython.display import Markdown, display\n",
57     "\n",
58     "# Helper function to make documentation a little prettier\n",
59     "def print_markdown_docstring(func):\n",
60     "    display(Markdown(f\"```python\\n{func.__doc__}\\n```\"))"
61    ]
62   },
63   {
64    "cell_type": "markdown",
65    "id": "cede04ca-f1ae-411e-b014-e35493c8b9c9",
66    "metadata": {},
67    "source": [
68     "## Acquiring Data\n",
69     "\n",
70     "The expected format of the input data for this project is in the form of nested dictionaries with a particular structure. These dictionaries are produced by the process `build_fmda_dicts` within the `wrfxpy` branch `develop-72-jh`. These files are staged remotely as `pickle` files on the OpenWFM Demo site. The data consist of ground-based observations from RAWS stations and atmospheric data from the HRRR weather model interpolated to the location of the RAWS site. These data were collected by specifying a time period and a spatial bounding box, and all RAWS with FMC sensors were collected within those bounds and time frame.\n",
71     "\n",
72     "<mark>NOTE: as of 2024-10-22 the wrfxpy code is still needs to be merged with the latest changed from Angel. The code that makes fmda dictionaries shouldn't depend much on other changes within wrfxpy</mark>\n",
73     "\n",
74     "The first step is just to retrieve the files. The method is called `retrieve_url`, and lives in a python module `utils`. The `utils` functions are meant to apply to a general context, not anything specific to this project. It uses a method that calls `wget` as a subprocesses and saves to a target directory if the file doesn't already exist. You can force it to download with a function argument. The function documentation is printed below, then it is called using f-strings to make the code more concise."
75    ]
76   },
77   {
78    "cell_type": "code",
79    "execution_count": null,
80    "id": "e32267c9-e5ef-475d-a5ec-00e7212996e4",
81    "metadata": {},
82    "outputs": [],
83    "source": [
84     "print_markdown_docstring(retrieve_url)"
85    ]
86   },
87   {
88    "cell_type": "code",
89    "execution_count": null,
90    "id": "aa00ee6d-1d13-46fc-a942-64578dfe5b7d",
91    "metadata": {},
92    "outputs": [],
93    "source": [
94     "filename = \"fmda_rocky_202403-05_f05.pkl\"\n",
95     "retrieve_url(\n",
96     "    url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n",
97     "    dest_path = f\"data/{filename}\")"
98    ]
99   },
100   {
101    "cell_type": "markdown",
102    "id": "8ced2117-f36b-427b-86ae-0dca1e9cdccf",
103    "metadata": {},
104    "source": [
105     "### Exploring the Nested Dictionary Structure \n",
106     "\n",
107     "The data dictionaries have the following structure:\n",
108     "\n",
109     "* Top level keys are RAWS station IDs and some additional string related to the time period.\n",
110     "* For each of the RAWS sites, there are 3 subdictionaries consisting of different types of data that pertain to that location.\n",
111     "    - A `loc` subdirectory that consists of static information about the physical location of the RAWS site. This includes station ID name, longitude, latitude, elevation, and two grid coordinates named \"pixel_x\" and \"pixel_y\" <mark>This will be renamed to \"grid_coordinate\" in the future</mark>. These correspond to the transformation of the lon/lat coordinates from the RAWS site onto the regular HRRR grid.\n",
112     "    - A `RAWS` subdirectory that includes at least FMC observations and the associated times returned by Synoptic. These times may not line up perfectly with the requested regular hours. In addition to the FMC data, any available ground-based sensor data for variables relevant to FMC were collected. These data are intended to be used as validation for the accuracy of the interpolated HRRR data.\n",
113     "    - A `HRRR` subdirectory that includes atmospheric variables relevant to FMC. The formatted table below shows the variables used by this project, where band numbers come from [NOAA documentation](https://www.nco.ncep.noaa.gov/pmb/products/hrrr/hrrr.t00z.wrfprsf00.grib2.shtml). <mark>More variables will be collected in the future</mark>. The HRRR subdirectory is organized into forecast hours. Each forecast hour subdirectory should have all the same information, just at different times from the HRRR forecast. "
114    ]
115   },
116   {
117    "cell_type": "code",
118    "execution_count": null,
119    "id": "8bebd077-2690-4b91-8779-a1223a5c91dc",
120    "metadata": {},
121    "outputs": [],
122    "source": [
123     "dat = read_pkl(f\"data/{filename}\")\n",
124     "\n",
125     "# Print top level keys, each corresponds to a RAWS site\n",
126     "dat.keys()"
127    ]
128   },
129   {
130    "cell_type": "code",
131    "execution_count": null,
132    "id": "e0dda99c-077c-4fdd-a5ae-0c3095c2057f",
133    "metadata": {},
134    "outputs": [],
135    "source": [
136     "# Check structure within \n",
137     "dat['CPTC2_202403'].keys()"
138    ]
139   },
140   {
141    "cell_type": "code",
142    "execution_count": null,
143    "id": "e5d9e2e8-90ff-4597-8dd9-4a8f3cec2303",
144    "metadata": {},
145    "outputs": [],
146    "source": [
147     "print_dict_summary(dat['CPTC2_202403'])"
148    ]
149   },
150   {
151    "cell_type": "code",
152    "execution_count": null,
153    "id": "f9b9dafc-1020-4973-84e4-321a903441b1",
154    "metadata": {},
155    "outputs": [],
156    "source": [
157     "# Print dataframe used to organize HRRR band retrievals\n",
158     "band_df_hrrr = pd.DataFrame({\n",
159     "    'Band': [616, 620, 624, 628, 629, 661, 561, 612, 643],\n",
160     "    'hrrr_name': ['TMP', 'RH', \"WIND\", 'PRATE', 'APCP',\n",
161     "                  'DSWRF', 'SOILW', 'CNWAT', 'GFLUX'],\n",
162     "    'dict_name': [\"temp\", \"rh\", \"wind\", \"rain\", \"precip_accum\",\n",
163     "                 \"solar\", \"soilm\", \"canopyw\", \"groundflux\"],\n",
164     "    'descr': ['2m Temperature [K]', \n",
165     "              '2m Relative Humidity [%]', \n",
166     "              '10m Wind Speed [m/s]'\n",
167     "              'surface Precip. Rate [kg/m^2/s]',\n",
168     "              'surface Total Precipitation [kg/m^2]',\n",
169     "              'surface Downward Short-Wave Radiation Flux [W/m^2]',\n",
170     "              'surface Total Precipitation [kg/m^2]',\n",
171     "              '0.0m below ground Volumetric Soil Moisture Content [Fraction]',\n",
172     "              'Plant Canopy Surface Water [kg/m^2]',\n",
173     "              'surface Ground Heat Flux [W/m^2]']\n",
174     "})\n",
175     "\n",
176     "band_df_hrrr"
177    ]
178   },
179   {
180    "cell_type": "markdown",
181    "id": "191d9cb9-bbbd-4a7d-8413-8b508e7be052",
182    "metadata": {},
183    "source": [
184     "## Data Processing - Reading and Cleaning Data\n",
185     "\n",
186     "The `build_train_dict` function reads the previously described dictionary and processes it in a few ways. The function lives in the `data_funcs` python module, which is intended to include code that is specific to the particular formatting decisions of this project. The `build_train_dict` function can receive some important parameters that control how it processes the data:\n",
187     "\n",
188     "* `params_data`: this is a configuration file. An example is saved internally in this project as `params_data.yaml`. This file includes hyperparameters related to data filtering. These hyperparameters control how suspect data is flagged and filtered.\n",
189     "* `atm_source`: this specifies the subdictionary source for the atmospheric data. Currently this is one of \"HRRR\" or \"RAWS\".\n",
190     "* `forecast_hour`: this specifies which HRRR forecast hour should be used. At the 0th hour, the HRRR weather model is very smooth and there is no accumulated precipitation yet. Within `wrfxpy`, the 3rd forecast hour is used.\n",
191     "* `spatial`: controls whether or not the separate locations are combined into a single dictionary or not. The reason not to do it is if you want to analyze timeseries at single locations more easily, perhaps to run the ODE+KF physical model of FMC.\n",
192     "\n",
193     "The `build_train_dict` function performs the following operations:\n",
194     "\n",
195     "* Reads a list of file names\n",
196     "* Extracts FMC and all possible modeling variables. This includes\n",
197     "    * Extracting static variables, like elevation, and extending them by the number of timeseries hours to fit a tabular data format for machine learning.\n",
198     "    * Calculates derived features like hour of day and day of year.\n",
199     "    * Calculates hourly precipitation (mm/hr) from accumulated precipitation.\n",
200     "* Temporally interpolate RAWS data, including FMC, to make it line up in time with the HRRR data. The HRRR data is always on a regular hourly interval, but the RAWS data can have missing data or return values not exactly on the hour requested.\n",
201     "* Shift the atmospheric data by the given `forecast_hour`. So if you want to build a timeseries at 3pm using the 3hr HRRR forecast data, you would start your data with the 3hr forecast from noon.\n",
202     "* Perform a series of data filtering steps:\n",
203     "    * If specified, the total timeseries within the input dictioanry is broken up into chunks of a specified number of `hours`. This makes the data filtering much easier, since we want continuous timeseries for training the RNN models, and if chunks of data are missing in time from the RAWS data it is easier to break the whole timeseries into smaller pieces and filter out the bad ones.\n",
204     "    * Physically reasonable min and max values for various variables are applied as filters\n",
205     "    * Two main parameters control what is fully excluded from the training data:\n",
206     "        * `max_intp_time`: this is the maximum number of hours that is allowed for temporal interpolation. Any RAWS site with a longer stretch of missing data will be flagged and removed.\n",
207     "        *  `zero_lag_threshold`: this is the maximum number of hours where there can be zero change in a variable before it is flagged as a broken sensor and values are set to NaN for that period.\n",
208     "        *  NOTE: since this is training data for a model where ample input data is available, we will air on the side of aggressively filtering out suspect data since more can always be collected if volume is an issue. It is possible that sensors break nonrandomly, maybe more missing data in a particular season of the year. This merits further study. "
209    ]
210   },
211   {
212    "cell_type": "code",
213    "execution_count": null,
214    "id": "8a9f18e3-be33-4b29-bcda-6f7c7bb72e6a",
215    "metadata": {},
216    "outputs": [],
217    "source": [
218     "params_data = read_yml(\"params_data.yaml\") \n",
219     "params_data"
220    ]
221   },
222   {
223    "cell_type": "code",
224    "execution_count": null,
225    "id": "dc3ea077-2475-414d-a440-fcc3678f1348",
226    "metadata": {},
227    "outputs": [],
228    "source": [
229     "from data_funcs import build_train_dict\n",
230     "\n",
231     "file_paths = f\"data/{filename}\""
232    ]
233   },
234   {
235    "cell_type": "code",
236    "execution_count": null,
237    "id": "61ed37a2-7fe1-464e-8b4f-13dfad131311",
238    "metadata": {
239     "scrolled": true
240    },
241    "outputs": [],
242    "source": [
243     "train = build_train_dict(\n",
244     "    input_file_paths = [f\"data/{filename}\"], \n",
245     "    atm_source=\"HRRR\", \n",
246     "    params_data = params_data, \n",
247     "    forecast_step = 3,\n",
248     "    spatial=True, \n",
249     "    verbose=True\n",
250     ")"
251    ]
252   },
253   {
254    "cell_type": "code",
255    "execution_count": null,
256    "id": "d8fb4171-d6d5-484c-890b-cc60ec70de69",
257    "metadata": {},
258    "outputs": [],
259    "source": [
260     "# Print Data dictionary keys at the end of the process\n",
261     "train.keys()"
262    ]
263   },
264   {
265    "cell_type": "markdown",
266    "id": "109c976a-121c-44b0-b00f-e7170d92c37c",
267    "metadata": {},
268    "source": [
269     "## RNN Parameters Custom Classes\n",
270     "\n",
271     "This project utilizes a few custom classes. The `RNNParams` custom class is used to make modeling easier and provide checks to avoid errors from incompatible models and data. It takes a dictionary as an input. Dictionaries are used since it easily works with the structure of a json file or a yaml file, two commonly used file formats for storing parameter lists. The parameters includes a number of hyperparameters related to model architecture and data formatting decisions. The `RNNParams` object is needed to pre-process data for the RNN model since it specifies things like percentages of data to be used for train/validation/test. To use this custom class, you read in the yaml file, perhaps using the `read_yml` utility function in this project, and create an RNNParams object out of it.\n",
272     "\n",
273     "These are some of the required elements of an input dictionary to make a RNNParams object and the internal checks associated with them:\n",
274     "\n",
275     "* `features_list`: list of features by name intended to be used for modeling. See `features_list` from the previously processed object `train` for a list of all possible feature names.\n",
276     "    * Internally, a value `n_features` is calculated as the length of this list. This can only be done internally, and changing the features list automatically changes `n_features` to avoid the situation where there is any mismatch.\n",
277     "* `batch_size`, `timesteps`: these parameters control the input data shape. They must be integers.\n",
278     "    * Along with `features_list`, these parameters control the input layer to the model. The input data to the model will thus be `(batch_size, timesteps, n_features)`\n",
279     "* `hidden_layers`, `hidden_units`, `hidden_activation`: each are lists that control hidden layer specifications. Internal checks are run that they must be the same length. Layer type is one of \"rnn\" (simple RNN layer), \"lstm\", \"attention\", \"dropout\", or \"dense\". The units specifies the number of cells, and should be None for attention and dropout layers. The activation type is one of tensorflows recognized activations, including 'linear', 'tanh', and 'sigmoid'. Similarly, the activation type should be None for attention and dropout layers \n",
280     "* `output_layer`, `output_activation`, `output_dimension`: Currently it is a dense layer with 1 cell and linear activation. This is typical for a regression problem where the outcome is a single continuous scalar. Adding to output_dimenision would require changing the target data structure, but this could be done if you desire outputting multiple time steps or values from multiple locations.\n",
281     "* `return_sequences`: whether or not the final recurrent layer should return an entire sequence or only the last time step in the input sequence. This is a tunable hyperparameter. Typically, False leads to better predictions for sequence-to-scalar networks, but True is likely required for sequence-to-sequence networks (not tested yet).\n",
282     "* `time_fracs`, `space_fracs`: these are lists that control the percentage of data used for cross-validation train/validation/test splits. Each must be a list of 3 floats that sum up to 1. `time_fracs` partitions the data based on time, so train must proceed validaton in time, and validation proceeds test in time. `space_fracs` randomly samples physical locations. A physical location should only be included in one of train/validation/test sets."
283    ]
284   },
285   {
286    "cell_type": "code",
287    "execution_count": null,
288    "id": "b5972811-7da8-4a32-85a2-72b86589fc44",
289    "metadata": {},
290    "outputs": [],
291    "source": [
292     "from moisture_rnn import RNNParams"
293    ]
294   },
295   {
296    "cell_type": "code",
297    "execution_count": null,
298    "id": "69b31fc6-f8c6-4d9d-94f8-ebc5e0c8cbb1",
299    "metadata": {},
300    "outputs": [],
301    "source": [
302     "file = read_yml(\"params.yaml\", subkey = \"rnn\")\n",
303     "params = RNNParams(file)\n",
304     "params.update({\n",
305     "    'learning_rate': 0.0001\n",
306     "}) # update some params here for illustrative purposes"
307    ]
308   },
309   {
310    "cell_type": "markdown",
311    "id": "6889df7a-0537-4b21-9d6f-ac20ff21f0ce",
312    "metadata": {},
313    "source": [
314     "## RNN Data Custom Class\n",
315     "\n",
316     "Using the input dictionary and the parameters discussed previously, we create a custom class `RNNData` which controls data scaling and restructuring. The important methods for this class are:\n",
317     "\n",
318     "* `train_test_split`: this splits into train/validation/test sets based on both space and time. This should be run before scaling data, so that only training data is used to scale test data to avoid data leakage. NOTE: the main data `X` and `y` are still organized as lists of ndarrays at this point. This is to make handling spatial locations easier, but it might be desirable to switch these to higher dimensional arrays or pandas dataframes.\n",
319     "* `scale_data`: this applies the given scaler, either MinMax or Standard (Gaussian). The scaler is fit on the training data and applied to the validation and test data.\n",
320     "* `batch_reshape`: this method combines the list of input and target data into 3-d arrays, based on the format `(batch_size, timesteps, n_features)`. This method utilizes a data structuring technique that allows for stateful RNN models to be trained with data from multiple timeseries. For more inforamtion see FMDA with Recurrent Neural Netowrks document, chapter XX <mark> add link </mark>\n",
321     "* `print_hashes`: this runs a utility `hash_ndarray` on all internal data in the object. This data produces a unique string for the data object. "
322    ]
323   },
324   {
325    "cell_type": "code",
326    "execution_count": null,
327    "id": "dbe7edb9-7a67-4369-99c2-17cebf60a126",
328    "metadata": {},
329    "outputs": [],
330    "source": [
331     "from moisture_rnn import RNNData"
332    ]
333   },
334   {
335    "cell_type": "code",
336    "execution_count": null,
337    "id": "7084e589-95ef-41cd-b417-daa15bc7a3ac",
338    "metadata": {},
339    "outputs": [],
340    "source": [
341     "# Set random seeds, affects random sample of locations\n",
342     "reproducibility.set_seed(123)\n",
343     "\n",
344     "rnn_dat = RNNData(\n",
345     "    train, # input dictionary\n",
346     "    scaler=params['scaler'],  # data scaling type\n",
347     "    features_list = params['features_list'] # features for predicting outcome\n",
348     ")"
349    ]
350   },
351   {
352    "cell_type": "code",
353    "execution_count": null,
354    "id": "51fc9042-05f0-4b94-af69-9fa25d4581b0",
355    "metadata": {},
356    "outputs": [],
357    "source": [
358     "rnn_dat.train_test_split(   \n",
359     "    time_fracs = params['time_fracs'], # Percent of total time steps used for train/val/test\n",
360     "    space_fracs = params['space_fracs'] # Percent of total timeseries used for train/val/test\n",
361     ")"
362    ]
363   },
364   {
365    "cell_type": "code",
366    "execution_count": null,
367    "id": "4660ddcf-940e-4c27-8d4d-d8b5fe6da90e",
368    "metadata": {},
369    "outputs": [],
370    "source": [
371     "rnn_dat.scale_data()"
372    ]
373   },
374   {
375    "cell_type": "code",
376    "execution_count": null,
377    "id": "343af419-7640-45e1-a637-5b802e21b56c",
378    "metadata": {},
379    "outputs": [],
380    "source": [
381     "rnn_dat.batch_reshape(\n",
382     "    timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
383     "    batch_size = params['batch_size'], # Number of samples of length timesteps for a single round of grad. descent\n",
384     "    start_times = np.zeros(len(rnn_dat.loc['train_locs']))\n",
385     ")    "
386    ]
387   },
388   {
389    "cell_type": "code",
390    "execution_count": null,
391    "id": "085102a2-9dd1-4f50-b5c6-63af8777d748",
392    "metadata": {},
393    "outputs": [],
394    "source": [
395     "rnn_dat.print_hashes()"
396    ]
397   },
398   {
399    "cell_type": "markdown",
400    "id": "85215951-d9be-4742-b215-ee43655ecd9f",
401    "metadata": {},
402    "source": [
403     "## RNN Model Class"
404    ]
405   },
406   {
407    "cell_type": "markdown",
408    "id": "c0164048-a823-405c-8fa7-586b1afbfbf1",
409    "metadata": {},
410    "source": [
411     "### Building a Model\n",
412     "\n",
413     "The `RNN` custom class is used to streamline building a model with different layers and handling training and predicting easier. It requires a `RNNParams` object as an input to initialize. Several processes call a utility `hash_weights` which produces a unique hash value for model weights, which is a list a ndarrays. \n",
414     "\n",
415     "On initialization, the `RNNParams` object builds and compiles two neural networks based on the input hyperparameters. One network is used when calling `.fit`, which we will call the \"training network\". The training network has restrictions on the input data shape to be `(batch_size, timesteps, n_features)`. After fitting, the weights are copied over into another neural network, called the \"prediction network\", which is identical except for the input shape is related to be `(None, None, n_features)`. The two networks are used since certain training schemes, particularly stateful, require consistent batch size across samples. But when it comes time for prediction, we want to be able to predict at an arbitrary number of locations and an arbitrary number of timesteps. That is the purpose of the prediction network. But the prediction network is not intended to be used for training, it always just receives it's weights copied over from the training. For more infomation on train versus prediction networks, see Geron 2019 chapter 16 <mark> add cite </mark>. To illustrate this method we will redefine some parameters and examine the resulting networks.\n",
416     "\n",
417     "To run `.fit`, you must set the random seed using the `reproducibility.py` module, which collects all the various types of random seeds that need setting in this project. In this project, tensorflow is configured to run deterministically to ensure reproducibility, so the random seed must be set or tensorflow throws errors."
418    ]
419   },
420   {
421    "cell_type": "code",
422    "execution_count": null,
423    "id": "051eab2b-5764-478b-98fe-696b69e015af",
424    "metadata": {},
425    "outputs": [],
426    "source": [
427     "from moisture_rnn import RNN, rnn_data_wrap\n",
428     "import reproducibility"
429    ]
430   },
431   {
432    "cell_type": "code",
433    "execution_count": null,
434    "id": "615a251f-ffbc-4470-8902-5b99536e3d5a",
435    "metadata": {},
436    "outputs": [],
437    "source": [
438     "params.update({\n",
439     "    'hidden_layers': ['dense', 'lstm', 'dense', 'dropout'],\n",
440     "    'hidden_units': [64, 32, 16, None],\n",
441     "    'hidden_activation': ['relu', 'tanh', 'relu', None],\n",
442     "    'return_sequences': False\n",
443     "})"
444    ]
445   },
446   {
447    "cell_type": "code",
448    "execution_count": null,
449    "id": "f6562cbe-a092-481e-a177-3e869adc9ce7",
450    "metadata": {},
451    "outputs": [],
452    "source": [
453     "reproducibility.set_seed(123)\n",
454     "model = RNN(params)"
455    ]
456   },
457   {
458    "cell_type": "code",
459    "execution_count": null,
460    "id": "ee330213-e88c-4494-93d4-487d4b0bdfcd",
461    "metadata": {},
462    "outputs": [],
463    "source": [
464     "model.model_train.summary()"
465    ]
466   },
467   {
468    "cell_type": "code",
469    "execution_count": null,
470    "id": "ed972463-21e8-463e-85f2-c93cbdf86f96",
471    "metadata": {},
472    "outputs": [],
473    "source": [
474     "model.model_predict.summary()"
475    ]
476   },
477   {
478    "cell_type": "markdown",
479    "id": "05b963d8-f300-4986-ad52-4e8bd7c0dc5c",
480    "metadata": {},
481    "source": [
482     "Notice how in the training model, since we set `return_sequences` to False, the output shape loses a dimension. The final dense layer outputs a single value for each sample in the batch, so output shape of `(batch_size, 1)`. For the prediction model, each layer accepts None for the first two dimensions. In practice, we use this to predict at a certain number of locations for an arbitrary number of timesteps. But in both cases, the number of trainable parameters are the same. This shows is the utility of using two separate models: we can leverage sophisticated training mechanisms that restrict the input data type, but then copy these weights over to a more flexible network that is easier to use for forecasting.\n",
483     "\n",
484     "<mark> Question for Jan: </mark> help me understand the linear algebra of why this works and why it's the same number of parameters."
485    ]
486   },
487   {
488    "cell_type": "markdown",
489    "id": "6c426a9c-47e7-4b1f-a505-5e473333bb9f",
490    "metadata": {},
491    "source": [
492     "### Running the Model\n",
493     "\n",
494     "Internally, the `RNN` class has a `.fit` and a `.predict` method that access the relevant internal models. The fit method also sets up certain callbacks used to control things about the training, such as early stopping based on validation data error. Additionally, the fit method automatically sets the weights of the prediction model at the end.\n",
495     "\n",
496     "We call `.fit` below. Note that this method will access internal hyperparameters that were used to initialize the object, such as the number of epochs and the batch size."
497    ]
498   },
499   {
500    "cell_type": "code",
501    "execution_count": null,
502    "id": "a612440f-482b-40cf-9c57-e0af9a1ba30f",
503    "metadata": {},
504    "outputs": [],
505    "source": [
506     "test_epochs = model.fit(\n",
507     "    rnn_dat.X_train, \n",
508     "    rnn_dat.y_train,\n",
509     "    validation_data = (rnn_dat.X_val, rnn_dat.y_val),\n",
510     "    plot_history = True, # plots train/validation loss by epoch\n",
511     "    verbose_fit = True, # prints epochs updates\n",
512     "    return_epochs = True # returns the epoch used for early stopping. Used for hyperparam tuning\n",
513     ")\n",
514     "\n",
515     "print(f\"{test_epochs=}\")"
516    ]
517   },
518   {
519    "cell_type": "markdown",
520    "id": "e5e90844-83b2-4b05-a137-8bb83465561c",
521    "metadata": {},
522    "source": [
523     "Next, we demonstrate here how the fitted training model weights are identical to the prediction model weights. Then, we predict new values using the prediction model. The shape of the test data will be `(n_locations, n_times, n_features)`. This mimics the formatting before, but for the training model the `batch_size` and `timesteps` were tunable hyperparameters. Here `n_locations` and `n_times` could be any integer values and are determined by the user based on their forecasting needs."
524    ]
525   },
526   {
527    "cell_type": "code",
528    "execution_count": null,
529    "id": "21466e5c-a605-4783-94a4-df15acf6e2b1",
530    "metadata": {},
531    "outputs": [],
532    "source": [
533     "from utils import hash_weights\n",
534     "\n",
535     "print(f\"Fitted Training Model Weights Hash: {hash_weights(model.model_train)}\")\n",
536     "print(f\"Prediction Model Weights Hash:      {hash_weights(model.model_predict)}\")"
537    ]
538   },
539   {
540    "cell_type": "code",
541    "execution_count": null,
542    "id": "45d73dab-8441-4fdc-9fae-c470657cb21c",
543    "metadata": {},
544    "outputs": [],
545    "source": [
546     "# Show test data format, (n_loc, n_times, n_features)\n",
547     "print(f\"Number of Locations in Test Set:   {len(rnn_dat.loc['test_locs'])}\")\n",
548     "print(f\"Number of Features used in Model:  {model.params['n_features']}\")\n",
549     "\n",
550     "print(f\"X_test shape:                      {rnn_dat.X_test.shape}\")\n",
551     "print(f\"y_test shape:                      {rnn_dat.y_test.shape}\")\n"
552    ]
553   },
554   {
555    "cell_type": "code",
556    "execution_count": null,
557    "id": "635c2cc5-edcc-4772-b0fe-8ed5b1eb2be8",
558    "metadata": {},
559    "outputs": [],
560    "source": [
561     "preds = model.predict(\n",
562     "    rnn_dat.X_test\n",
563     ")\n",
564     "\n",
565     "print(f\"{preds.shape = }\")"
566    ]
567   },
568   {
569    "cell_type": "markdown",
570    "id": "f28bdd0d-bbf3-47f9-9f96-bbee2ffd360f",
571    "metadata": {},
572    "source": [
573     "Finally, we calculate the RMSE for each location. If desired, you could calculate the overall RMSE, but we are choosing to group by location and then average the results at the end. This methodology prioritizes accuracy across space, and avoids the situation where large errors at one location get masked by small errors at the other locations. We use a utility `rmse_3d` for this purpose which calculates means and squares across a 3d array in the proper way."
574    ]
575   },
576   {
577    "cell_type": "code",
578    "execution_count": null,
579    "id": "5fc26cdc-0d58-4220-b679-eca95905217f",
580    "metadata": {},
581    "outputs": [],
582    "source": [
583     "from utils import rmse_3d\n",
584     "\n",
585     "print(f\"{rmse_3d(preds, rnn_dat.y_test) = }\")"
586    ]
587   },
588   {
589    "cell_type": "markdown",
590    "id": "565121a5-7c20-4d58-80a9-72418fedbb1b",
591    "metadata": {},
592    "source": [
593     "The `RNN` class has a method `run_model` which combines these steps based on just an input `RNNData` object. It prints out a lot of other information related to parameter configurations. We will reinitialize the model to show reproducibility. The method returns a list of model predictions for each test location and an RMSE associated with that location. Compare the printed weight hashes to before to ensure they match."
594    ]
595   },
596   {
597    "cell_type": "code",
598    "execution_count": null,
599    "id": "05aa8e6a-3e73-45b4-a2c5-5670ad4cd8cc",
600    "metadata": {},
601    "outputs": [],
602    "source": [
603     "reproducibility.set_seed(123)\n",
604     "model = RNN(params)\n",
605     "m, errs = model.run_model(rnn_dat)"
606    ]
607   },
608   {
609    "cell_type": "code",
610    "execution_count": null,
611    "id": "16ce4a1e-305f-4deb-92fb-b3e8c374d61e",
612    "metadata": {},
613    "outputs": [],
614    "source": [
615     "print(f\"{errs.mean() = }\")"
616    ]
617   },
618   {
619    "cell_type": "code",
620    "execution_count": null,
621    "id": "635767e7-f1ad-4c23-a57d-0990f937180e",
622    "metadata": {},
623    "outputs": [],
624    "source": []
625   },
626   {
627    "cell_type": "code",
628    "execution_count": null,
629    "id": "557ea9b5-ae97-495f-91fd-5a68e0102826",
630    "metadata": {},
631    "outputs": [],
632    "source": []
633   }
634  ],
635  "metadata": {
636   "kernelspec": {
637    "display_name": "Python 3 (ipykernel)",
638    "language": "python",
639    "name": "python3"
640   },
641   "language_info": {
642    "codemirror_mode": {
643     "name": "ipython",
644     "version": 3
645    },
646    "file_extension": ".py",
647    "mimetype": "text/x-python",
648    "name": "python",
649    "nbconvert_exporter": "python",
650    "pygments_lexer": "ipython3",
651    "version": "3.12.5"
652   }
653  },
654  "nbformat": 4,
655  "nbformat_minor": 5