4 "cell_type": "markdown",
5 "id": "83b774b3-ef55-480a-b999-506676e49145",
8 "# v2.1 run RNN strategy serial by Location\n",
10 "This version of the RNN runs the model on each location separately, one at a time. Two main runs:\n",
11 "1. Run a separate model at each location - training and prediction at each location independently - training mode periods 0:train_ind (was 0:h2), then prediction in test_ind:end. Validation data, if any, are from train_ind:test_ind\n",
12 "2. Run the same model with multiple fitting calls on 0:train_ind at different locations, then compare prediction accuracy in test_ind:end for all locations.\n"
17 "execution_count": null,
18 "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
22 "import numpy as np\n",
23 "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
26 "import os.path as osp\n",
27 "from moisture_rnn_pkl import pkl2train\n",
28 "from moisture_rnn import RNNParams, RNNData, RNN \n",
29 "from utils import hash2, read_yml, read_pkl, retrieve_url\n",
30 "from moisture_rnn import RNN\n",
31 "import reproducibility\n",
32 "from data_funcs import rmse, to_json, build_train_dict\n",
33 "from moisture_models import run_augmented_kf\n",
35 "import pandas as pd\n",
36 "import matplotlib.pyplot as plt\n",
42 "execution_count": null,
43 "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
52 "execution_count": null,
53 "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
58 " url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\", \n",
59 " dest_path = \"data/fmda_nw_202401-05_f05.pkl\")"
64 "execution_count": null,
65 "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
69 "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n",
70 "file_names=['fmda_nw_202401-05_f05.pkl']\n",
72 "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
77 "execution_count": null,
78 "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
82 "# read/write control\n",
83 "train_file='data/train.pkl'\n",
84 "train_create=True # if false, read\n",
85 "train_write=False\n",
91 "execution_count": null,
92 "id": "7b0e7a81-48e0-4b6a-b92e-e911db597607",
96 "params_data = read_yml(\"params_data.yaml\") "
101 "execution_count": null,
102 "id": "da8638f9-28e3-4c76-98e3-0dded261551c",
108 "train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=False, verbose=True)"
113 "execution_count": null,
114 "id": "65666e4c-e64e-4a85-adb9-a7729e309602",
118 "repro = read_pkl(repro_file)"
123 "execution_count": null,
124 "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
130 "# repro = read_pkl(repro_file)\n",
132 "# if train_create:\n",
133 "# logging.info('creating the training cases from files %s',file_paths)\n",
134 "# # osp.join works on windows too, joins paths using \\ or /\n",
135 "# # train = pkl2train(file_paths)\n",
136 "# train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=False, verbose=True)\n",
137 "# if train_write:\n",
138 "# with open(train_file, 'wb') as file:\n",
139 "# logging.info('Writing the rain cases into file %s',train_file)\n",
140 "# pickle.dump(train, file)\n",
141 "# if train_read:\n",
142 "# logging.info('Reading the train cases from file %s',train_file)\n",
143 "# train = read_pkl(train_file)"
148 "execution_count": null,
149 "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
153 "params_all = read_yml(\"params.yaml\")\n",
154 "print(params_all.keys())"
159 "execution_count": null,
160 "id": "698df86b-8550-4135-81df-45dbf503dd4e",
164 "# from module_param_sets import param_sets"
169 "execution_count": null,
170 "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
174 "param_sets_keys=['rnn']\n",
175 "cases=list(train.keys())[0:50]\n",
176 "# cases=list(train.keys())\n",
177 "# cases.remove('reproducibility')\n",
183 "execution_count": null,
184 "id": "dd22baf2-59d2-460e-8c47-b20116dd5982",
188 "logging.info('Running over parameter sets %s',param_sets_keys)\n",
189 "logging.info('Running over cases %s',cases)"
193 "cell_type": "markdown",
194 "id": "802f3eef-1702-4478-b6e3-2288a6edae24",
197 "## Run Reproducibility Case"
202 "execution_count": null,
203 "id": "69a3adb9-39fd-4c0c-9c9b-aaa2a9a3af40",
207 "params = repro['repro_info']['params']\n",
208 "print(type(params))\n",
211 "# Set up input data\n",
212 "rnn_dat = RNNData(repro, scaler = params['scaler'], features_list = params['features_list'])\n",
213 "rnn_dat.train_test_split(\n",
214 " time_fracs = params['time_fracs']\n",
216 "rnn_dat.scale_data()\n",
217 "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
222 "execution_count": null,
223 "id": "855703c4-d7a9-4579-bca7-7c737a81d0de",
227 "reproducibility.set_seed(123)\n",
228 "rnn = RNN(params)\n",
229 "m, errs = rnn.run_model(rnn_dat, reproducibility_run=True)"
233 "cell_type": "markdown",
234 "id": "49e31fdd-4c14-4a81-9e2b-4c6ba94d1f83",
237 "## Separate Models by Location"
242 "execution_count": null,
243 "id": "e11e7c83-183f-48ba-abd8-a6aedff66090",
247 "# Set up output dictionaries\n",
254 "execution_count": null,
255 "id": "dc5b47bd-4fbc-44b8-b2dd-d118e068b450",
262 "for k in param_sets_keys:\n",
263 " params = RNNParams(params_all[k])\n",
264 " print(\"~\"*80)\n",
265 " print(\"Running with params:\")\n",
267 " # Increase Val Frac so no errors, TODO fix validation\n",
268 " params.update({\n",
269 " 'train_frac': .9,\n",
270 " 'val_frac': .05,\n",
271 " 'activation': ['relu', 'relu'],\n",
273 " 'dense_units': 10,\n",
274 " 'rnn_layers': 2 \n",
276 " for case in cases:\n",
277 " print(\"~\"*50)\n",
278 " logging.info('Processing case %s',case)\n",
279 " print_dict_summary(train[case])\n",
280 " # Format data & Run Model\n",
281 " # rnn_dat = create_rnn_data2(train[case], params)\n",
282 " rnn_dat = RNNData(train[case], scaler = params['scaler'], features_list = params['features_list'])\n",
283 " rnn_dat.train_test_split(\n",
284 " time_fracs = [.9, .05, .05]\n",
286 " rnn_dat.scale_data()\n",
287 " rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
288 " params.update({'bmax': rnn_dat.hours})\n",
289 " reproducibility.set_seed()\n",
290 " rnn = RNN(params)\n",
291 " m, errs = rnn.run_model(rnn_dat, plot_period=\"predict\")\n",
292 " # Add model output to case\n",
293 " train[case]['m_rnn']=m\n",
294 " # Get RMSE Prediction Error\n",
295 " print(f\"RMSE: {errs}\")\n",
296 " outputs_rnn[case] = {'case':case, 'errs': errs.copy()}\n",
298 " # Run Augmented KF\n",
299 " print('Running Augmented KF')\n",
300 " train[case]['h2'] = rnn_dat.test_ind\n",
301 " train[case]['scale_fm'] = 1\n",
302 " m, Ec = run_augmented_kf(train[case])\n",
303 " y = rnn_dat['y'] \n",
304 " train[case]['m_kf'] = m\n",
305 " print(f\"KF RMSE: {rmse(m[rnn_dat.test_ind:],y[rnn_dat.test_ind:])}\")\n",
306 " outputs_kf[case] = {'case':case, 'errs': rmse(m[rnn_dat.test_ind:],y[rnn_dat.test_ind:])}\n",
308 " # Save Outputs \n",
309 " to_json(outputs_rnn, \"rnn_errs.json\")\n",
310 " to_json(outputs_kf, \"kf_errs.json\")"
315 "execution_count": null,
316 "id": "15384e4d-b8ec-4700-bdc2-83b0433d11c9",
320 "logging.info('fmda_rnn_serial.ipynb done')"
325 "execution_count": null,
326 "id": "d0e78fb3-b501-49d6-81a9-1a13da0134a0",
330 "import importlib\n",
331 "import moisture_rnn\n",
332 "importlib.reload(moisture_rnn)\n",
333 "from moisture_rnn import RNN"
338 "execution_count": null,
339 "id": "37053436-8dfe-4c40-8614-811817e83782",
343 "for k in outputs_rnn:\n",
344 " print(\"~\"*50)\n",
345 " print(outputs_rnn[k]['case'])\n",
346 " print(outputs_rnn[k]['errs']['prediction'])"
351 "execution_count": null,
352 "id": "9154d5f7-015f-4ef7-af45-020410a1ea65",
356 "for k in outputs_kf:\n",
357 " print(\"~\"*50)\n",
358 " print(outputs_kf[k]['case'])\n",
359 " print(outputs_kf[k]['errs'])"
364 "execution_count": null,
365 "id": "dfd90d87-fe08-48b5-8cc4-31bd19c5c20a",
371 "cell_type": "markdown",
372 "id": "f3c1c299-1655-4c64-a458-c7723db6ea6d",
375 "### TODO: FIX SCALING in Scheme below\n",
377 "Scaling is currently done separately for each case (location)."
381 "cell_type": "markdown",
382 "id": "0c0c3470-30f5-4915-98a7-dcdf5760d482",
385 "## Training at Multiple Locations\n",
392 "execution_count": null,
393 "id": "dd1aca73-7279-473e-b2a3-95aa1db7b1a8",
397 "params = RNNParams(params_all['rnn'])\n",
399 " 'epochs': 1, # less epochs since it is per location\n",
400 " 'activation': ['relu', 'relu'],\n",
401 " 'train_frac': .9,\n",
402 " 'val_frac': .05, \n",
403 " 'dense_units': 10,\n",
404 " 'rnn_layers': 2\n",
407 "# rnn_dat = create_rnn_data2(train[cases[0]], params)\n",
408 "rnn_dat = RNNData(train[cases[0]], params['scaler'], params['features_list'])\n",
409 "rnn_dat.train_test_split(\n",
410 " time_fracs = [.9, .05, .05]\n",
412 "rnn_dat.scale_data()\n",
413 "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
418 "execution_count": null,
419 "id": "65b2f9a3-a8f2-4ac1-8e4d-ba38a86eaf40",
423 "reproducibility.set_seed()\n",
429 "execution_count": null,
430 "id": "47a85ef2-8145-4de8-9f2e-86622306ffd8",
437 "print(\"Running with params:\")\n",
440 "for case in cases[0:10]:\n",
441 " print(\"~\"*50)\n",
442 " logging.info('Processing case %s',case)\n",
443 " print_dict_summary(train[case])\n",
444 " rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
445 " rnn_dat_temp.train_test_split(\n",
446 " time_fracs = [.9, .05, .05]\n",
448 " rnn_dat_temp.scale_data()\n",
449 " rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
450 " rnn.fit(rnn_dat_temp['X_train'], rnn_dat_temp['y_train'],\n",
451 " validation_data=(rnn_dat_temp['X_val'], rnn_dat_temp['y_val']))\n",
452 " # run_rnn_pkl(train[case],param_sets[i])"
456 "cell_type": "markdown",
457 "id": "a0421b8d-49aa-4409-8cbf-7732f1137838",
465 "execution_count": null,
466 "id": "63d7854a-94f7-425c-9561-4fe518e044bb",
472 "# Predict Cases Used in Training\n",
474 "inds = np.arange(0,10)\n",
475 "train_keys = list(train.keys())\n",
477 " print(\"~\"*50)\n",
478 " case = train_keys[i]\n",
479 " print(f\"Predicting case {case}\")\n",
480 " # rnn_dat = create_rnn_data2(train[case], params)\n",
481 " rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
482 " rnn_dat_temp.train_test_split(\n",
483 " time_fracs = [.9, .05, .05]\n",
485 " rnn_dat_temp.scale_data()\n",
486 " rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
487 " X_temp = rnn_dat_temp.scale_all_X()\n",
488 " X_temp = X_temp.reshape(1, X_temp.shape[0], X_temp.shape[1])\n",
489 " m = rnn.predict(X_temp).flatten()\n",
490 " test_ind = rnn_dat_temp['test_ind'] # use this case's own split, not the rnn_dat built from cases[0]\n",
491 " rmses.append(rmse(m[test_ind:], rnn_dat_temp['y_test'].flatten())) # score against this case's observations"
496 "execution_count": null,
497 "id": "2a5423e0-778b-4f69-9ed0-f0082a1fefe5",
506 "execution_count": null,
507 "id": "45c9caae-7ced-4f21-aa05-c9b125e8fdcb",
511 "pd.DataFrame({'Case': list(train.keys())[0:10], 'RMSE': rmses}).style.hide(axis=\"index\")"
516 "execution_count": null,
517 "id": "f710f482-b600-4ea5-9a8a-823a13b4ec7a",
523 "# Predict New Locations\n",
525 "for i, case in enumerate(list(train.keys())[10:100]):\n",
526 " print(\"~\"*50)\n",
527 " print(f\"Predicting case {case}\")\n",
528 " rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
529 " rnn_dat_temp.train_test_split(\n",
530 " time_fracs = [.9, .05, .05]\n",
532 " rnn_dat_temp.scale_data()\n",
533 " rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
534 " X_temp = rnn_dat_temp.scale_all_X()\n",
535 " X_temp = X_temp.reshape(1, X_temp.shape[0], X_temp.shape[1])\n",
536 " m = rnn.predict(X_temp).flatten()\n",
537 " train[case]['m'] = m\n",
538 " test_ind = rnn_dat_temp['test_ind'] # use this case's own split, not the rnn_dat built from cases[0]\n",
539 " rmses.append(rmse(m[test_ind:], rnn_dat_temp.y_test.flatten())) # score against this case's observations\n",
541 "df = pd.DataFrame({'Case': list(train.keys())[10:100], 'RMSE': rmses})"
546 "execution_count": null,
547 "id": "d793ac87-d94b-4b16-a271-46cdc259b4fe",
551 "df[0:5].style.hide(axis=\"index\")"
556 "execution_count": null,
557 "id": "b99606d1-bd46-4041-8303-1bcbb196f6f4",
566 "execution_count": null,
567 "id": "52ec264d-d4b7-444c-b623-002d6383da30",
576 "execution_count": null,
577 "id": "998922cd-46bb-4063-8284-0497e19c39b0",
586 "execution_count": null,
587 "id": "889f3bbb-9fb2-4621-9e93-1d0bc0f83e01",
594 "execution_count": null,
595 "id": "fe407f61-15f2-4086-a386-7d7a5bb90d26",
602 "execution_count": null,
603 "id": "2fdb63b3-68b8-4877-a7a2-f63257cb29d5",
610 "execution_count": null,
611 "id": "5c7563c5-a880-45c7-8381-8ce4e1a44216",
618 "execution_count": null,
619 "id": "ad5dae6c-1269-4674-a49e-2efe8b956911",
627 "display_name": "Python 3 (ipykernel)",
628 "language": "python",
636 "file_extension": ".py",
637 "mimetype": "text/x-python",
639 "nbconvert_exporter": "python",
640 "pygments_lexer": "ipython3",