4 "cell_type": "markdown",
5 "id": "83b774b3-ef55-480a-b999-506676e49145",
8 "# v2.1 run RNN with Spatial Training\n",
10 "This notebook sets up a test in which the RNN is run serially by location and compared with the spatial training scheme. The ODE model with the augmented KF is also run as a comparison. Note that the RNN models predict with no knowledge of the held-out locations, while the augmented KF is run directly on the test locations.\n"
14 "cell_type": "markdown",
15 "id": "bbd84d61-a9cd-47b4-b538-4986fb10b98d",
18 "## Environment Setup"
23 "execution_count": null,
24 "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
28 "import numpy as np\n",
29 "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
32 "import os.path as osp\n",
33 "from moisture_rnn_pkl import pkl2train\n",
34 "from moisture_rnn import RNNParams, RNNData, RNN\n",
35 "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict\n",
37 "import reproducibility\n",
38 "from data_funcs import rmse, to_json, combine_nested\n",
39 "from moisture_models import run_augmented_kf\n",
41 "import pandas as pd\n",
42 "import matplotlib.pyplot as plt\n",
49 "execution_count": null,
50 "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
59 "execution_count": null,
60 "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
65 " url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\", \n",
66 " dest_path = \"fmda_nw_202401-05_f05.pkl\")"
71 "execution_count": null,
72 "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
76 "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n",
77 "file_names=['fmda_nw_202401-05_f05.pkl']\n",
79 "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
84 "execution_count": null,
85 "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
89 "# read/write control\n",
90 "train_file='train.pkl'\n",
91 "train_create=False # if False, read the training cases from train_file\n",
92 "train_write=False\n",
98 "execution_count": null,
99 "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
105 "repro = read_pkl(repro_file)\n",
107 "if train_create:\n",
108 " logging.info('creating the training cases from files %s',file_paths)\n",
109 " # osp.join works on windows too, joins paths using \\ or /\n",
110 " train = pkl2train(file_paths)\n",
112 " with open(train_file, 'wb') as file:\n",
113 " logging.info('Writing the train cases into file %s',train_file)\n",
114 " pickle.dump(train, file)\n",
116 " logging.info('Reading the train cases from file %s',train_file)\n",
117 " train = read_pkl(train_file)"
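The read/write control above follows a common build-or-load caching pattern. The following is a minimal sketch of that pattern, not the notebook's exact control flow: `load_or_create` is a hypothetical helper, and the toy dictionary stands in for the output of `pkl2train`.

```python
import os
import pickle
import tempfile


def load_or_create(path, create_fn, create=False, write=False):
    """Hypothetical helper: build an object, or read it from a pickle file."""
    if create:
        obj = create_fn()
        if write:
            # Cache the freshly built object for later runs
            with open(path, "wb") as f:
                pickle.dump(obj, f)
        return obj
    # Otherwise read the cached object back from disk
    with open(path, "rb") as f:
        return pickle.load(f)


# Toy usage: a temporary file stands in for train.pkl
tmp_path = os.path.join(tempfile.mkdtemp(), "train_demo.pkl")
built = load_or_create(tmp_path, lambda: {"case_a": [1, 2, 3]}, create=True, write=True)
loaded = load_or_create(tmp_path, lambda: {}, create=False)
```

Keeping the create and write switches separate makes it possible to rebuild the cases for a quick experiment without overwriting a cached file.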
122 "execution_count": null,
123 "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
127 "params = read_yml(\"params.yaml\", subkey='rnn')\n",
133 "execution_count": null,
134 "id": "78cf4dbc-4e7d-4c6d-ac2e-0bac513f92dd",
138 "# from itertools import islice\n",
139 "# train = {k: train[k] for k in islice(train, 100)}\n",
140 "dat = Dict(combine_nested(train))"
145 "execution_count": null,
146 "id": "e11e7c83-183f-48ba-abd8-a6aedff66090",
150 "# Set up output dictionaries\n",
152 "outputs_rnn_serial = {}\n",
153 "outputs_rnn_spatial = {}"
157 "cell_type": "markdown",
158 "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3",
161 "## Spatial Data Training"
166 "execution_count": null,
167 "id": "c58f9f89-46d8-407c-be8b-8e5f16dbcc51",
171 "params = RNNParams(params)"
176 "execution_count": null,
177 "id": "3b5371a9-c1e8-4df5-b360-210746f7cd52",
182 "start_time = time.time()"
187 "execution_count": null,
188 "id": "c0c7f5fb-4c33-45f8-9a2e-38c9ab1cd4e3",
192 "rnn_dat = RNNData(dat, scaler=\"standard\", \n",
193 " features_list = ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat',\n",
194 " 'solar', 'wind'])\n",
196 "rnn_dat.train_test_split( \n",
197 " time_fracs = [.9, .05, .05],\n",
198 " space_fracs = [.6, .2, .2]\n",
200 "rnn_dat.scale_data()\n",
202 "rnn_dat.batch_reshape(\n",
203 " timesteps = params['timesteps'], \n",
204 " batch_size = params['batch_size']\n",
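To make the `space_fracs` argument concrete, here is a hedged sketch of how a list of locations could be divided into train, validation, and test groups. The internals of `RNNData.train_test_split` are not shown in this notebook, so `split_locations`, the `train_locs` and `val_locs` keys, and the station names are assumptions for illustration only; `test_locs` is the key used later in this notebook.

```python
import numpy as np


def split_locations(loc_ids, space_fracs=(0.6, 0.2, 0.2), seed=0):
    """Hypothetical helper: shuffle location IDs and split them by fraction."""
    assert abs(sum(space_fracs) - 1.0) < 1e-9, "fractions must sum to 1"
    rng = np.random.default_rng(seed)
    ids = rng.permutation(np.asarray(loc_ids))
    n = len(ids)
    n_train = int(round(space_fracs[0] * n))
    n_val = int(round(space_fracs[1] * n))
    return {
        "train_locs": ids[:n_train].tolist(),
        "val_locs": ids[n_train:n_train + n_val].tolist(),
        # Whatever remains is held out for testing
        "test_locs": ids[n_train + n_val:].tolist(),
    }


locs = [f"station_{i}" for i in range(10)]
split = split_locations(locs)
```

Because the three groups are disjoint, a model trained on `train_locs` never sees observations from `test_locs`, which is the property the comparison in this notebook relies on.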
210 "execution_count": null,
211 "id": "59ddf393-2024-4093-927f-69f135a165b8",
215 "params.update({'batch_schedule_type': 'exp', 'bmin': 20, 'bmax': rnn_dat.hours,\n",
216 " 'loc_batch_reset': rnn_dat.n_seqs, \n",
217 " 'epochs': 100, 'learning_rate': 0.0001,\n",
218 " 'recurrent_layers': 2, 'recurrent_units': 40, 'dense_layers': 2, 'dense_units': 20,\n",
219 " 'features_list': rnn_dat.features_list})"
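The `batch_schedule_type`, `bmin`, and `bmax` parameters suggest a value that grows from `bmin` to `bmax` over the epochs. How the `RNN` class actually uses them is not shown here, so the sketch below is only one plausible reading: `exp_schedule` is a hypothetical function that spaces values geometrically, and `bmax=2000` is a placeholder for `rnn_dat.hours`.

```python
import numpy as np


def exp_schedule(bmin, bmax, epochs):
    """Hypothetical schedule: geometric growth from bmin to bmax, one value per epoch."""
    vals = np.geomspace(bmin, bmax, num=epochs)
    # Round to whole numbers, since the values index time steps
    return np.rint(vals).astype(int)


# bmax=2000 is an assumed stand-in for rnn_dat.hours
sched = exp_schedule(bmin=20, bmax=2000, epochs=100)
```

A geometric schedule changes slowly in early epochs and quickly in later ones, in contrast to a linear schedule with equal steps.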
224 "execution_count": null,
225 "id": "4bc11474-fed8-47f2-b9cf-dfdda0d3d3b2",
229 "reproducibility.set_seed(123)\n",
230 "rnn = RNN(params)\n",
231 "m, errs = rnn.run_model(rnn_dat)"
236 "execution_count": null,
237 "id": "704ad662-d81a-488d-be3d-e90bf775a5b8",
246 "execution_count": null,
247 "id": "d53571e3-b6cf-49aa-9848-e3c77053283d",
252 "end_time = time.time()\n",
254 "# Calculate Code Runtime\n",
255 "elapsed_time = end_time - start_time\n",
256 "print(f\"Spatial Training Elapsed time: {elapsed_time:.4f} seconds\")"
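When several training schemes are timed for comparison, a small context manager keeps the start and stop bookkeeping in one place. This is a sketch, not part of the notebook: `timer` and the toy workload are illustrative, and `time.perf_counter` is used in place of `time.time` because it is designed for measuring intervals.

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(label, results):
    """Record the elapsed wall-clock time of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


timings = {}
with timer("toy_training", timings):
    # Placeholder workload standing in for rnn.run_model(...)
    total = sum(i * i for i in range(100_000))

print(f"toy_training elapsed time: {timings['toy_training']:.4f} seconds")
```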
260 "cell_type": "markdown",
261 "id": "7d8292a2-418c-48ed-aff7-ccbe98b046d3",
264 "## Run ODE + KF and Compare"
269 "execution_count": null,
270 "id": "1e4ffc68-c775-41c6-ac42-f49c76824b43",
277 "for case in rnn_dat.loc['test_locs']:\n",
278 " print(\"~\"*50)\n",
280 " # Run Augmented KF\n",
281 " print('Running Augmented KF')\n",
282 " train[case]['h2'] = train[case]['hours'] // 2\n",
283 " train[case]['scale_fm'] = 1\n",
284 " m, Ec = run_augmented_kf(train[case])\n",
285 " y = train[case]['y'] \n",
286 " train[case]['m'] = m\n",
287 " print(f\"KF RMSE: {rmse(m,y)}\")\n",
288 " outputs_kf[case] = {'case':case, 'errs': rmse(m,y)}"
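The loop above scores each test location with `rmse` from `data_funcs`, whose implementation is not shown in this notebook. Assuming it follows the standard definition, the square root of the mean squared difference between prediction and observation, a minimal equivalent looks like this; `rmse_sketch` and the toy values are illustrative.

```python
import numpy as np


def rmse_sketch(a, b):
    """Standard root mean squared error between two equal-length sequences."""
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    return float(np.sqrt(np.mean((a - b) ** 2)))


# Toy model output and observations (not real fuel moisture data)
m_demo = [10.0, 12.0, 14.0]
y_demo = [11.0, 12.0, 12.0]
err = rmse_sketch(m_demo, y_demo)
```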
293 "execution_count": null,
294 "id": "57b19ec5-23f6-44ec-9f71-16d4d69aec68",
298 "df2 = pd.DataFrame.from_dict(outputs_kf).transpose()\n",
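The transpose is needed because `pd.DataFrame.from_dict` treats the outer dictionary keys as columns by default. A small sketch with made-up error values shows the resulting layout, one row per location; the names `outputs_demo`, `df_demo`, and `mean_err` are illustrative.

```python
import pandas as pd

# Toy results in the same shape as outputs_kf (values are made up)
outputs_demo = {
    "loc_a": {"case": "loc_a", "errs": 1.5},
    "loc_b": {"case": "loc_b", "errs": 2.5},
}

# Outer keys become columns by default, so transpose to get one row per location
df_demo = pd.DataFrame.from_dict(outputs_demo).transpose()

# Columns hold Python objects after the transpose, so cast before averaging
mean_err = df_demo["errs"].astype(float).mean()
```

Passing `orient="index"` to `from_dict` is an alternative that avoids the transpose.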
303 "cell_type": "markdown",
304 "id": "86795281-f8ea-4141-81ea-c53fae830e80",
312 "execution_count": null,
313 "id": "508a6392-49bc-4471-ad8e-814f60119283",
322 "execution_count": null,
323 "id": "73e8ca05-d17b-4e72-8def-fa77664e7bb0",
332 "execution_count": null,
333 "id": "104ea555-1a88-4293-b2a6-dd870fb4b1ed",
342 "execution_count": null,
343 "id": "dc1d5cd6-2321-43b2-ab88-7f44806dc73f",
352 "execution_count": null,
353 "id": "a73d22ee-707b-44a3-80ab-ad6e671731cf",
360 "execution_count": null,
361 "id": "272bfb32-e8e2-49dd-8f90-4b5b09c3a2a2",
369 "display_name": "Python 3 (ipykernel)",
370 "language": "python",
378 "file_extension": ".py",
379 "mimetype": "text/x-python",
381 "nbconvert_exporter": "python",
382 "pygments_lexer": "ipython3",