Major Restructure: redo nested dictionary code with better functions
[notebooks.git] / fmda / fmda_rnn_serial.ipynb
blob 659d686012e3423f9f59d240539eae5d30a5d530
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.1 run RNN strategy serial by Location\n",
9     "\n",
10     "This version of the RNN runs the model on each location separately, one at a time. Two main runs:\n",
11     "1. Run separate model at each location - training and prediction at least location independently - training mode periods 0:train_ind (was 0:h2), then prediction in test_ind:end. Validation data, if any, are from train_ind:test_ind\n",
12     "2. Run same model with multiple fitting calls 0:train_ind at different locations, compare prediction accuracy in test_ind:end  at for all location. \n"
13    ]
14   },
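  {
   "cell_type": "markdown",
   "id": "train-test-split-sketch",
   "metadata": {},
   "source": [
    "A minimal sketch of the index-based split described above (an illustration only, not the `RNNData` implementation). It assumes `time_fracs = [train, val, test]` fractions of an hours-long series; the names `hours` and `time_fracs` below are placeholders for this example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical illustration of the split: train on 0:train_ind,\n",
    "# validate on train_ind:test_ind, predict on test_ind:end.\n",
    "hours = 1000                       # length of the hourly series\n",
    "time_fracs = [0.9, 0.05, 0.05]     # train / val / test fractions\n",
    "train_ind = int(hours * time_fracs[0])                    # end of training period (was h2)\n",
    "test_ind = int(hours * (time_fracs[0] + time_fracs[1]))   # start of prediction period\n",
    "\n",
    "y = np.arange(hours)\n",
    "y_train, y_val, y_test = y[:train_ind], y[train_ind:test_ind], y[test_ind:]\n",
    "print(train_ind, test_ind, len(y_train), len(y_val), len(y_test))\n",
    "```"
   ]
  },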
15   {
16    "cell_type": "code",
17    "execution_count": null,
18    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
19    "metadata": {},
20    "outputs": [],
21    "source": [
22     "import numpy as np\n",
23     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
24     "import pickle\n",
25     "import logging\n",
26     "import os.path as osp\n",
27     "from moisture_rnn_pkl import pkl2train\n",
28     "from moisture_rnn import RNNParams, RNNData, RNN \n",
29     "from utils import hash2, read_yml, read_pkl, retrieve_url\n",
30     "from moisture_rnn import RNN\n",
31     "import reproducibility\n",
32     "from data_funcs import rmse, to_json, build_train_dict\n",
33     "from moisture_models import run_augmented_kf\n",
34     "import copy\n",
35     "import pandas as pd\n",
36     "import matplotlib.pyplot as plt\n",
37     "import yaml"
38    ]
39   },
40   {
41    "cell_type": "code",
42    "execution_count": null,
43    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
44    "metadata": {},
45    "outputs": [],
46    "source": [
47     "logging_setup()"
48    ]
49   },
50   {
51    "cell_type": "code",
52    "execution_count": null,
53    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
54    "metadata": {},
55    "outputs": [],
56    "source": [
57     "retrieve_url(\n",
58     "    url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\", \n",
59     "    dest_path = \"data/fmda_nw_202401-05_f05.pkl\")"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
66    "metadata": {},
67    "outputs": [],
68    "source": [
69     "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n",
70     "file_names=['fmda_nw_202401-05_f05.pkl']\n",
71     "file_dir='data'\n",
72     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "# read/write control\n",
83     "train_file='data/train.pkl'\n",
84     "train_create=True   # if false, read\n",
85     "train_write=False\n",
86     "train_read=False"
87    ]
88   },
89   {
90    "cell_type": "code",
91    "execution_count": null,
92    "id": "7b0e7a81-48e0-4b6a-b92e-e911db597607",
93    "metadata": {},
94    "outputs": [],
95    "source": [
96     "params_data = read_yml(\"params_data.yaml\") "
97    ]
98   },
99   {
100    "cell_type": "code",
101    "execution_count": null,
102    "id": "da8638f9-28e3-4c76-98e3-0dded261551c",
103    "metadata": {
104     "scrolled": true
105    },
106    "outputs": [],
107    "source": [
108     "train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=False, verbose=True)"
109    ]
110   },
111   {
112    "cell_type": "code",
113    "execution_count": null,
114    "id": "65666e4c-e64e-4a85-adb9-a7729e309602",
115    "metadata": {},
116    "outputs": [],
117    "source": [
118     "repro = read_pkl(repro_file)"
119    ]
120   },
121   {
122    "cell_type": "code",
123    "execution_count": null,
124    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
125    "metadata": {
126     "scrolled": true
127    },
128    "outputs": [],
129    "source": [
130     "# repro = read_pkl(repro_file)\n",
131     "\n",
132     "# if train_create:\n",
133     "#     logging.info('creating the training cases from files %s',file_paths)\n",
134     "#     # osp.join works on windows too, joins paths using \\ or /\n",
135     "#     # train = pkl2train(file_paths)\n",
136     "#     train = build_train_dict(file_paths, atm_source=\"HRRR\", params_data = params_data, spatial=False, verbose=True)\n",
137     "# if train_write:\n",
138     "#     with open(train_file, 'wb') as file:\n",
139     "#         logging.info('Writing the rain cases into file %s',train_file)\n",
140     "#         pickle.dump(train, file)\n",
141     "# if train_read:\n",
142     "#     logging.info('Reading the train cases from file %s',train_file)\n",
143     "#     train = read_pkl(train_file)"
144    ]
145   },
146   {
147    "cell_type": "code",
148    "execution_count": null,
149    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
150    "metadata": {},
151    "outputs": [],
152    "source": [
153     "params_all = read_yml(\"params.yaml\")\n",
154     "print(params_all.keys())"
155    ]
156   },
157   {
158    "cell_type": "code",
159    "execution_count": null,
160    "id": "698df86b-8550-4135-81df-45dbf503dd4e",
161    "metadata": {},
162    "outputs": [],
163    "source": [
164     "# from module_param_sets import param_sets"
165    ]
166   },
167   {
168    "cell_type": "code",
169    "execution_count": null,
170    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
171    "metadata": {},
172    "outputs": [],
173    "source": [
174     "param_sets_keys=['rnn']\n",
175     "cases=list(train.keys())[0:50]\n",
176     "# cases=list(train.keys())\n",
177     "# cases.remove('reproducibility')\n",
178     "cases"
179    ]
180   },
181   {
182    "cell_type": "code",
183    "execution_count": null,
184    "id": "dd22baf2-59d2-460e-8c47-b20116dd5982",
185    "metadata": {},
186    "outputs": [],
187    "source": [
188     "logging.info('Running over parameter sets %s',param_sets_keys)\n",
189     "logging.info('Running over cases %s',cases)"
190    ]
191   },
192   {
193    "cell_type": "markdown",
194    "id": "802f3eef-1702-4478-b6e3-2288a6edae24",
195    "metadata": {},
196    "source": [
197     "## Run Reproducibility Case"
198    ]
199   },
200   {
201    "cell_type": "code",
202    "execution_count": null,
203    "id": "69a3adb9-39fd-4c0c-9c9b-aaa2a9a3af40",
204    "metadata": {},
205    "outputs": [],
206    "source": [
207     "params = repro['repro_info']['params']\n",
208     "print(type(params))\n",
209     "print(params)\n",
210     "\n",
211     "# Set up input data\n",
212     "rnn_dat = RNNData(repro, scaler = params['scaler'], features_list = params['features_list'])\n",
213     "rnn_dat.train_test_split(\n",
214     "    time_fracs = params['time_fracs']\n",
215     ")\n",
216     "rnn_dat.scale_data()\n",
217     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
218    ]
219   },
220   {
221    "cell_type": "code",
222    "execution_count": null,
223    "id": "855703c4-d7a9-4579-bca7-7c737a81d0de",
224    "metadata": {},
225    "outputs": [],
226    "source": [
227     "reproducibility.set_seed(123)\n",
228     "rnn = RNN(params)\n",
229     "m, errs = rnn.run_model(rnn_dat, reproducibility_run=True)"
230    ]
231   },
232   {
233    "cell_type": "markdown",
234    "id": "49e31fdd-4c14-4a81-9e2b-4c6ba94d1f83",
235    "metadata": {},
236    "source": [
237     "## Separate Models by Location"
238    ]
239   },
240   {
241    "cell_type": "code",
242    "execution_count": null,
243    "id": "e11e7c83-183f-48ba-abd8-a6aedff66090",
244    "metadata": {},
245    "outputs": [],
246    "source": [
247     "# Set up output dictionaries\n",
248     "outputs_kf = {}\n",
249     "outputs_rnn = {}"
250    ]
251   },
252   {
253    "cell_type": "code",
254    "execution_count": null,
255    "id": "dc5b47bd-4fbc-44b8-b2dd-d118e068b450",
256    "metadata": {
257     "scrolled": true
258    },
259    "outputs": [],
260    "source": [
261     "\n",
262     "for k in param_sets_keys:\n",
263     "    params = RNNParams(params_all[k])\n",
264     "    print(\"~\"*80)\n",
265     "    print(\"Running with params:\")\n",
266     "    print(params)\n",
267     "    # Increase Val Frac so no errors, TODO fix validation\n",
268     "    params.update({\n",
269     "        'train_frac': .9,\n",
270     "        'val_frac': .05,\n",
271     "        'activation': ['relu', 'relu'],\n",
272     "        'epochs': 10,\n",
273     "        'dense_units': 10,\n",
274     "        'rnn_layers': 2       \n",
275     "    })\n",
276     "    for case in cases:\n",
277     "        print(\"~\"*50)\n",
278     "        logging.info('Processing case %s',case)\n",
279     "        print_dict_summary(train[case])\n",
280     "        # Format data & Run Model\n",
281     "        # rnn_dat = create_rnn_data2(train[case], params)\n",
282     "        rnn_dat = RNNData(train[case], scaler = params['scaler'], features_list = params['features_list'])\n",
283     "        rnn_dat.train_test_split(\n",
284     "            time_fracs = [.9, .05, .05]\n",
285     "        )\n",
286     "        rnn_dat.scale_data()\n",
287     "        rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
288     "        params.update({'bmax': rnn_dat.hours})\n",
289     "        reproducibility.set_seed()\n",
290     "        rnn = RNN(params)\n",
291     "        m, errs = rnn.run_model(rnn_dat, plot_period=\"predict\")\n",
292     "        # Add model output to case\n",
293     "        train[case]['m_rnn']=m\n",
294     "        # Get RMSE Prediction Error\n",
295     "        print(f\"RMSE: {errs}\")\n",
296     "        outputs_rnn[case] = {'case':case, 'errs': errs.copy()}\n",
297     "        \n",
298     "        # Run Augmented KF\n",
299     "        print('Running Augmented KF')\n",
300     "        train[case]['h2'] = rnn_dat.test_ind\n",
301     "        train[case]['scale_fm'] = 1\n",
302     "        m, Ec = run_augmented_kf(train[case])\n",
303     "        y = rnn_dat['y']        \n",
304     "        train[case]['m_kf'] = m\n",
305     "        print(f\"KF RMSE: {rmse(m[rnn_dat.test_ind:],y[rnn_dat.test_ind:])}\")\n",
306     "        outputs_kf[case] = {'case':case, 'errs': rmse(m[rnn_dat.test_ind:],y[rnn_dat.test_ind:])}\n",
307     "\n",
308     "        # Save Outputs \n",
309     "        to_json(outputs_rnn, \"rnn_errs.json\")\n",
310     "        to_json(outputs_kf, \"kf_errs.json\")"
311    ]
312   },
313   {
314    "cell_type": "code",
315    "execution_count": null,
316    "id": "15384e4d-b8ec-4700-bdc2-83b0433d11c9",
317    "metadata": {},
318    "outputs": [],
319    "source": [
320     "logging.info('fmda_rnn_serial.ipynb done')"
321    ]
322   },
323   {
324    "cell_type": "code",
325    "execution_count": null,
326    "id": "d0e78fb3-b501-49d6-81a9-1a13da0134a0",
327    "metadata": {},
328    "outputs": [],
329    "source": [
330     "import importlib\n",
331     "import moisture_rnn\n",
332     "importlib.reload(moisture_rnn)\n",
333     "from moisture_rnn import RNN"
334    ]
335   },
336   {
337    "cell_type": "code",
338    "execution_count": null,
339    "id": "37053436-8dfe-4c40-8614-811817e83782",
340    "metadata": {},
341    "outputs": [],
342    "source": [
343     "for k in outputs_rnn:\n",
344     "    print(\"~\"*50)\n",
345     "    print(outputs_rnn[k]['case'])\n",
346     "    print(outputs_rnn[k]['errs']['prediction'])"
347    ]
348   },
349   {
350    "cell_type": "code",
351    "execution_count": null,
352    "id": "9154d5f7-015f-4ef7-af45-020410a1ea65",
353    "metadata": {},
354    "outputs": [],
355    "source": [
356     "for k in outputs_kf:\n",
357     "    print(\"~\"*50)\n",
358     "    print(outputs_kf[k]['case'])\n",
359     "    print(outputs_kf[k]['errs'])"
360    ]
361   },
362   {
363    "cell_type": "code",
364    "execution_count": null,
365    "id": "dfd90d87-fe08-48b5-8cc4-31bd19c5c20a",
366    "metadata": {},
367    "outputs": [],
368    "source": []
369   },
370   {
371    "cell_type": "markdown",
372    "id": "f3c1c299-1655-4c64-a458-c7723db6ea6d",
373    "metadata": {},
374    "source": [
375     "### TODO: FIX SCALING in Scheme below\n",
376     "\n",
377     "Scaling is done separately in each now."
378    ]
379   },
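  {
   "cell_type": "markdown",
   "id": "scaling-issue-sketch",
   "metadata": {},
   "source": [
    "A hedged sketch of the scaling issue flagged above, using scikit-learn's `MinMaxScaler` as a stand-in for the scaler configured in `params` (the two-location arrays below are made up). Fitting a separate scaler at each location maps every location onto its own range, so scaled features are not comparable across locations; fitting one scaler on the training locations and reusing it keeps them comparable.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "# Hypothetical feature values for two locations with different ranges.\n",
    "loc_a = np.array([[10.0], [20.0], [30.0]])\n",
    "loc_b = np.array([[100.0], [200.0], [300.0]])\n",
    "\n",
    "# Current scheme: a separate scaler per location maps both onto [0, 1],\n",
    "# so the scaled values no longer reflect cross-location differences.\n",
    "print(MinMaxScaler().fit_transform(loc_a).ravel())   # [0.  0.5 1. ]\n",
    "print(MinMaxScaler().fit_transform(loc_b).ravel())   # [0.  0.5 1. ]\n",
    "\n",
    "# Possible fix: fit one scaler on the training locations and reuse it everywhere.\n",
    "shared = MinMaxScaler().fit(np.vstack([loc_a, loc_b]))\n",
    "print(shared.transform(loc_a).ravel())\n",
    "print(shared.transform(loc_b).ravel())\n",
    "```"
   ]
  },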
380   {
381    "cell_type": "markdown",
382    "id": "0c0c3470-30f5-4915-98a7-dcdf5760d482",
383    "metadata": {},
384    "source": [
385     "## Training at Multiple Locations\n",
386     "\n",
387     "Still sequential"
388    ]
389   },
390   {
391    "cell_type": "code",
392    "execution_count": null,
393    "id": "dd1aca73-7279-473e-b2a3-95aa1db7b1a8",
394    "metadata": {},
395    "outputs": [],
396    "source": [
397     "params = RNNParams(params_all['rnn'])\n",
398     "params.update({\n",
399     "    'epochs': 1, # less epochs since it is per location\n",
400     "    'activation': ['relu', 'relu'],\n",
401     "    'train_frac': .9,\n",
402     "    'val_frac': .05,    \n",
403     "    'dense_units': 10,\n",
404     "    'rnn_layers': 2\n",
405     "})\n",
406     "\n",
407     "# rnn_dat = create_rnn_data2(train[cases[0]], params)\n",
408     "rnn_dat = RNNData(train[cases[0]], params['scaler'], params['features_list'])\n",
409     "rnn_dat.train_test_split(\n",
410     "    time_fracs = [.9, .05, .05]\n",
411     ")\n",
412     "rnn_dat.scale_data()\n",
413     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
414    ]
415   },
416   {
417    "cell_type": "code",
418    "execution_count": null,
419    "id": "65b2f9a3-a8f2-4ac1-8e4d-ba38a86eaf40",
420    "metadata": {},
421    "outputs": [],
422    "source": [
423     "reproducibility.set_seed()\n",
424     "rnn = RNN(params)"
425    ]
426   },
427   {
428    "cell_type": "code",
429    "execution_count": null,
430    "id": "47a85ef2-8145-4de8-9f2e-86622306ffd8",
431    "metadata": {
432     "scrolled": true
433    },
434    "outputs": [],
435    "source": [
436     "print(\"~\"*80)\n",
437     "print(\"Running with params:\")\n",
438     "print(params)\n",
439     "\n",
440     "for case in cases[0:10]:\n",
441     "    print(\"~\"*50)\n",
442     "    logging.info('Processing case %s',case)\n",
443     "    print_dict_summary(train[case])\n",
444     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
445     "    rnn_dat_temp.train_test_split(\n",
446     "        time_fracs = [.9, .05, .05]\n",
447     "    )\n",
448     "    rnn_dat_temp.scale_data()\n",
449     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
450     "    rnn.fit(rnn_dat_temp['X_train'], rnn_dat_temp['y_train'],\n",
451     "           validation_data=(rnn_dat_temp['X_val'], rnn_dat_temp['y_val']))\n",
452     "    # run_rnn_pkl(train[case],param_sets[i])"
453    ]
454   },
455   {
456    "cell_type": "markdown",
457    "id": "a0421b8d-49aa-4409-8cbf-7732f1137838",
458    "metadata": {},
459    "source": [
460     "### Predict "
461    ]
462   },
463   {
464    "cell_type": "code",
465    "execution_count": null,
466    "id": "63d7854a-94f7-425c-9561-4fe518e044bb",
467    "metadata": {
468     "scrolled": true
469    },
470    "outputs": [],
471    "source": [
472     "# Predict Cases Used in Training\n",
473     "rmses = []\n",
474     "inds = np.arange(0,10)\n",
475     "train_keys = list(train.keys())\n",
476     "for i in inds:\n",
477     "    print(\"~\"*50)\n",
478     "    case = train_keys[i]\n",
479     "    print(f\"Predicting case {case}\")\n",
480     "    # rnn_dat = create_rnn_data2(train[case], params)\n",
481     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
482     "    rnn_dat_temp.train_test_split(\n",
483     "        time_fracs = [.9, .05, .05]\n",
484     "    )\n",
485     "    rnn_dat_temp.scale_data()\n",
486     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
487     "    X_temp = rnn_dat_temp.scale_all_X()\n",
488     "    X_temp = X_temp.reshape(1, X_temp.shape[0], X_temp.shape[1])\n",
489     "    m = rnn.predict(X_temp).flatten()\n",
490     "    test_ind = rnn_dat['test_ind']\n",
491     "    rmses.append(rmse(m[test_ind:], rnn_dat['y_test'].flatten()))"
492    ]
493   },
494   {
495    "cell_type": "code",
496    "execution_count": null,
497    "id": "2a5423e0-778b-4f69-9ed0-f0082a1fefe5",
498    "metadata": {},
499    "outputs": [],
500    "source": [
501     "rmses"
502    ]
503   },
504   {
505    "cell_type": "code",
506    "execution_count": null,
507    "id": "45c9caae-7ced-4f21-aa05-c9b125e8fdcb",
508    "metadata": {},
509    "outputs": [],
510    "source": [
511     "pd.DataFrame({'Case': list(train.keys())[0:10], 'RMSE': rmses}).style.hide(axis=\"index\")"
512    ]
513   },
514   {
515    "cell_type": "code",
516    "execution_count": null,
517    "id": "f710f482-b600-4ea5-9a8a-823a13b4ec7a",
518    "metadata": {
519     "scrolled": true
520    },
521    "outputs": [],
522    "source": [
523     "# Predict New Locations\n",
524     "rmses = []\n",
525     "for i, case in enumerate(list(train.keys())[10:100]):\n",
526     "    print(\"~\"*50)\n",
527     "    print(f\"Predicting case {case}\")\n",
528     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
529     "    rnn_dat_temp.train_test_split(\n",
530     "        time_fracs = [.9, .05, .05]\n",
531     "    )\n",
532     "    rnn_dat_temp.scale_data()\n",
533     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
534     "    X_temp = rnn_dat_temp.scale_all_X()\n",
535     "    X_temp = X_temp.reshape(1, X_temp.shape[0], X_temp.shape[1])\n",
536     "    m = rnn.predict(X_temp).flatten()\n",
537     "    train[case]['m'] = m\n",
538     "    test_ind = rnn_dat['test_ind']\n",
539     "    rmses.append(rmse(m[test_ind:], rnn_dat.y_test.flatten()))\n",
540     "\n",
541     "df = pd.DataFrame({'Case': list(train.keys())[10:100], 'RMSE': rmses})"
542    ]
543   },
544   {
545    "cell_type": "code",
546    "execution_count": null,
547    "id": "d793ac87-d94b-4b16-a271-46cdc259b4fe",
548    "metadata": {},
549    "outputs": [],
550    "source": [
551     "df[0:5].style.hide(axis=\"index\")"
552    ]
553   },
554   {
555    "cell_type": "code",
556    "execution_count": null,
557    "id": "b99606d1-bd46-4041-8303-1bcbb196f6f4",
558    "metadata": {},
559    "outputs": [],
560    "source": [
561     "df"
562    ]
563   },
564   {
565    "cell_type": "code",
566    "execution_count": null,
567    "id": "52ec264d-d4b7-444c-b623-002d6383da30",
568    "metadata": {},
569    "outputs": [],
570    "source": [
571     "df.RMSE.mean()"
572    ]
573   },
574   {
575    "cell_type": "code",
576    "execution_count": null,
577    "id": "998922cd-46bb-4063-8284-0497e19c39b0",
578    "metadata": {},
579    "outputs": [],
580    "source": [
581     "plt.hist(df.RMSE)"
582    ]
583   },
584   {
585    "cell_type": "code",
586    "execution_count": null,
587    "id": "889f3bbb-9fb2-4621-9e93-1d0bc0f83e01",
588    "metadata": {},
589    "outputs": [],
590    "source": []
591   },
592   {
593    "cell_type": "code",
594    "execution_count": null,
595    "id": "fe407f61-15f2-4086-a386-7d7a5bb90d26",
596    "metadata": {},
597    "outputs": [],
598    "source": []
599   },
600   {
601    "cell_type": "code",
602    "execution_count": null,
603    "id": "2fdb63b3-68b8-4877-a7a2-f63257cb29d5",
604    "metadata": {},
605    "outputs": [],
606    "source": []
607   },
608   {
609    "cell_type": "code",
610    "execution_count": null,
611    "id": "5c7563c5-a880-45c7-8381-8ce4e1a44216",
612    "metadata": {},
613    "outputs": [],
614    "source": []
615   },
616   {
617    "cell_type": "code",
618    "execution_count": null,
619    "id": "ad5dae6c-1269-4674-a49e-2efe8b956911",
620    "metadata": {},
621    "outputs": [],
622    "source": []
623   }
624  ],
625  "metadata": {
626   "kernelspec": {
627    "display_name": "Python 3 (ipykernel)",
628    "language": "python",
629    "name": "python3"
630   },
631   "language_info": {
632    "codemirror_mode": {
633     "name": "ipython",
634     "version": 3
635    },
636    "file_extension": ".py",
637    "mimetype": "text/x-python",
638    "name": "python",
639    "nbconvert_exporter": "python",
640    "pygments_lexer": "ipython3",
641    "version": "3.12.5"
642   }
643  },
644  "nbformat": 4,
645  "nbformat_minor": 5