1 {
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.1 run RNN strategy serial by Location\n",
9     "\n",
10     "This version of the RNN runs the model on each location separately, one at a time. Two main runs:\n",
11     "1. Run separate model at each location - training and prediction at least location independently - training mode periods 0:train_ind (was 0:h2), then prediction in test_ind:end. Validation data, if any, are from train_ind:test_ind\n",
12     "2. Run same model with multiple fitting calls 0:train_ind at different locations, compare prediction accuracy in test_ind:end  at for all location. \n"
13    ]
14   },
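  {
   "cell_type": "code",
   "execution_count": null,
   "id": "index-split-sketch-added",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not part of the original pipeline): how the train/val/test\n",
    "# index ranges described above partition one location's hourly series.\n",
    "# The variable names and arithmetic here are assumptions for illustration; the actual\n",
    "# split is done later by RNNData.train_test_split using train_frac and val_frac.\n",
    "hours = 1000                                  # length of one location's series\n",
    "train_frac, val_frac = 0.9, 0.05\n",
    "train_ind = int(hours * train_frac)           # fit on 0:train_ind\n",
    "test_ind = train_ind + int(hours * val_frac)  # validate on train_ind:test_ind\n",
    "print(f\"train 0:{train_ind}, val {train_ind}:{test_ind}, predict {test_ind}:{hours}\")"
   ]
  },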
15   {
16    "cell_type": "code",
17    "execution_count": null,
18    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
19    "metadata": {},
20    "outputs": [],
21    "source": [
22     "import numpy as np\n",
23     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
24     "import pickle\n",
25     "import logging\n",
26     "import os.path as osp\n",
27     "from moisture_rnn_pkl import pkl2train\n",
28     "from moisture_rnn import RNNParams, RNNData, RNN \n",
29     "from utils import hash2, read_yml, read_pkl, retrieve_url\n",
30     "from moisture_rnn import RNN\n",
31     "import reproducibility\n",
32     "from data_funcs import rmse, to_json\n",
33     "from moisture_models import run_augmented_kf\n",
34     "import copy\n",
35     "import pandas as pd\n",
36     "import matplotlib.pyplot as plt\n",
37     "import yaml"
38    ]
39   },
40   {
41    "cell_type": "code",
42    "execution_count": null,
43    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
44    "metadata": {},
45    "outputs": [],
46    "source": [
47     "logging_setup()"
48    ]
49   },
50   {
51    "cell_type": "code",
52    "execution_count": null,
53    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
54    "metadata": {},
55    "outputs": [],
56    "source": [
57     "retrieve_url(\n",
58     "    url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\", \n",
59     "    dest_path = \"data/test_CA_202401.pkl\")"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
66    "metadata": {},
67    "outputs": [],
68    "source": [
69     "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n",
70     "file_names=['fmda_nw_202401-05_f05.pkl']\n",
71     "file_dir='data'\n",
72     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
73    ]
74   },
75   {
76    "cell_type": "code",
77    "execution_count": null,
78    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
79    "metadata": {},
80    "outputs": [],
81    "source": [
82     "# read/write control\n",
83     "train_file='train.pkl'\n",
84     "train_create=False   # if false, read\n",
85     "train_write=False\n",
86     "train_read=True"
87    ]
88   },
89   {
90    "cell_type": "code",
91    "execution_count": null,
92    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
93    "metadata": {
94     "scrolled": true
95    },
96    "outputs": [],
97    "source": [
98     "repro = read_pkl(repro_file)\n",
99     "\n",
100     "if train_create:\n",
101     "    logging.info('creating the training cases from files %s',file_paths)\n",
102     "    # osp.join works on windows too, joins paths using \\ or /\n",
103     "    train = pkl2train(file_paths)\n",
104     "if train_write:\n",
105     "    with open(train_file, 'wb') as file:\n",
106     "        logging.info('Writing the rain cases into file %s',train_file)\n",
107     "        pickle.dump(train, file)\n",
108     "if train_read:\n",
109     "    logging.info('Reading the train cases from file %s',train_file)\n",
110     "    train = read_pkl(train_file)"
111    ]
112   },
113   {
114    "cell_type": "code",
115    "execution_count": null,
116    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
117    "metadata": {},
118    "outputs": [],
119    "source": [
120     "params_all = read_yml(\"params.yaml\")\n",
121     "print(params_all.keys())"
122    ]
123   },
124   {
125    "cell_type": "code",
126    "execution_count": null,
127    "id": "698df86b-8550-4135-81df-45dbf503dd4e",
128    "metadata": {},
129    "outputs": [],
130    "source": [
131     "# from module_param_sets import param_sets"
132    ]
133   },
134   {
135    "cell_type": "code",
136    "execution_count": null,
137    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
138    "metadata": {},
139    "outputs": [],
140    "source": [
141     "param_sets_keys=['rnn']\n",
142     "cases=list(train.keys())[0:10]\n",
143     "# cases=list(train.keys())\n",
144     "# cases.remove('reproducibility')\n",
145     "cases"
146    ]
147   },
148   {
149    "cell_type": "code",
150    "execution_count": null,
151    "id": "dd22baf2-59d2-460e-8c47-b20116dd5982",
152    "metadata": {},
153    "outputs": [],
154    "source": [
155     "logging.info('Running over parameter sets %s',param_sets_keys)\n",
156     "logging.info('Running over cases %s',cases)"
157    ]
158   },
159   {
160    "cell_type": "markdown",
161    "id": "802f3eef-1702-4478-b6e3-2288a6edae24",
162    "metadata": {},
163    "source": [
164     "## Run Reproducibility Case"
165    ]
166   },
167   {
168    "cell_type": "code",
169    "execution_count": null,
170    "id": "69a3adb9-39fd-4c0c-9c9b-aaa2a9a3af40",
171    "metadata": {},
172    "outputs": [],
173    "source": [
174     "params = repro['repro_info']['params']\n",
175     "print(type(params))\n",
176     "print(params)\n",
177     "\n",
178     "# Set up input data\n",
179     "rnn_dat = RNNData(repro, scaler = params['scaler'], features_list = params['features_list'])\n",
180     "rnn_dat.train_test_split(\n",
181     "    train_frac = params['train_frac'],\n",
182     "    val_frac = params['val_frac']\n",
183     ")\n",
184     "rnn_dat.scale_data()\n",
185     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
186    ]
187   },
188   {
189    "cell_type": "code",
190    "execution_count": null,
191    "id": "855703c4-d7a9-4579-bca7-7c737a81d0de",
192    "metadata": {},
193    "outputs": [],
194    "source": [
195     "reproducibility.set_seed(123)\n",
196     "rnn = RNN(params)\n",
197     "m, errs = rnn.run_model(rnn_dat, reproducibility_run=True)"
198    ]
199   },
200   {
201    "cell_type": "markdown",
202    "id": "49e31fdd-4c14-4a81-9e2b-4c6ba94d1f83",
203    "metadata": {},
204    "source": [
205     "## Separate Models by Location"
206    ]
207   },
208   {
209    "cell_type": "code",
210    "execution_count": null,
211    "id": "e11e7c83-183f-48ba-abd8-a6aedff66090",
212    "metadata": {},
213    "outputs": [],
214    "source": [
215     "# Set up output dictionaries\n",
216     "outputs_kf = {}\n",
217     "outputs_rnn = {}"
218    ]
219   },
220   {
221    "cell_type": "code",
222    "execution_count": null,
223    "id": "dc5b47bd-4fbc-44b8-b2dd-d118e068b450",
224    "metadata": {
225     "scrolled": true
226    },
227    "outputs": [],
228    "source": [
229     "\n",
230     "for k in param_sets_keys:\n",
231     "    params = RNNParams(params_all[k])\n",
232     "    print(\"~\"*80)\n",
233     "    print(\"Running with params:\")\n",
234     "    print(params)\n",
235     "    # Increase Val Frac so no errors, TODO fix validation\n",
236     "    params.update({\n",
237     "        'train_frac': .9,\n",
238     "        'val_frac': .05,\n",
239     "        'activation': ['relu', 'relu'],\n",
240     "        'epochs': 10,\n",
241     "        'dense_units': 10,\n",
242     "        'rnn_layers': 2       \n",
243     "    })\n",
244     "    for case in cases:\n",
245     "        print(\"~\"*50)\n",
246     "        logging.info('Processing case %s',case)\n",
247     "        print_dict_summary(train[case])\n",
248     "        # Format data & Run Model\n",
249     "        # rnn_dat = create_rnn_data2(train[case], params)\n",
250     "        rnn_dat = RNNData(train[case], scaler = params['scaler'], features_list = params['features_list'])\n",
251     "        rnn_dat.train_test_split(\n",
252     "            train_frac = params['train_frac'],\n",
253     "            val_frac = params['val_frac']\n",
254     "        )\n",
255     "        rnn_dat.scale_data()\n",
256     "        rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
257     "        reproducibility.set_seed()\n",
258     "        rnn = RNN(params)\n",
259     "        m, errs = rnn.run_model(rnn_dat)\n",
260     "        # Add model output to case\n",
261     "        train[case]['m']=m\n",
262     "        # Get RMSE Prediction Error\n",
263     "        print(f\"RMSE: {errs}\")\n",
264     "        outputs_rnn[case] = {'case':case, 'errs': errs.copy()}\n",
265     "        \n",
266     "        # Run Augmented KF\n",
267     "        print('Running Augmented KF')\n",
268     "        train[case]['h2'] = train[case]['hours'] // 2\n",
269     "        train[case]['scale_fm'] = 1\n",
270     "        m, Ec = run_augmented_kf(train[case])\n",
271     "        m = m*rnn_dat['scale_fm']\n",
272     "        y = rnn_dat['y']*rnn_dat['scale_fm']          \n",
273     "        train[case]['m'] = m\n",
274     "        print(f\"KF RMSE: {rmse(m,y)}\")\n",
275     "        outputs_kf[case] = {'case':case, 'errs': rmse(m,y)}\n",
276     "\n",
277     "        # Save Outputs \n",
278     "        to_json(outputs_rnn, \"rnn_errs.json\")\n",
279     "        to_json(outputs_kf, \"kf_errs.json\")"
280    ]
281   },
282   {
283    "cell_type": "code",
284    "execution_count": null,
285    "id": "15384e4d-b8ec-4700-bdc2-83b0433d11c9",
286    "metadata": {},
287    "outputs": [],
288    "source": [
289     "logging.info('fmda_rnn_serial.ipynb done')"
290    ]
291   },
292   {
293    "cell_type": "code",
294    "execution_count": null,
295    "id": "d0e78fb3-b501-49d6-81a9-1a13da0134a0",
296    "metadata": {},
297    "outputs": [],
298    "source": [
299     "import importlib\n",
300     "import moisture_rnn\n",
301     "importlib.reload(moisture_rnn)\n",
302     "from moisture_rnn import RNN"
303    ]
304   },
305   {
306    "cell_type": "code",
307    "execution_count": null,
308    "id": "37053436-8dfe-4c40-8614-811817e83782",
309    "metadata": {},
310    "outputs": [],
311    "source": [
312     "for k in outputs_rnn:\n",
313     "    print(\"~\"*50)\n",
314     "    print(outputs_rnn[k]['case'])\n",
315     "    print(outputs_rnn[k]['errs']['prediction'])"
316    ]
317   },
318   {
319    "cell_type": "code",
320    "execution_count": null,
321    "id": "9154d5f7-015f-4ef7-af45-020410a1ea65",
322    "metadata": {},
323    "outputs": [],
324    "source": [
325     "for k in outputs_kf:\n",
326     "    print(\"~\"*50)\n",
327     "    print(outputs_kf[k]['case'])\n",
328     "    print(outputs_kf[k]['errs'])"
329    ]
330   },
331   {
332    "cell_type": "markdown",
333    "id": "f3c1c299-1655-4c64-a458-c7723db6ea6d",
334    "metadata": {},
335    "source": [
336     "### TODO: FIX SCALING in Scheme below\n",
337     "\n",
338     "Scaling is done separately in each now."
339    ]
340   },
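  {
   "cell_type": "code",
   "execution_count": null,
   "id": "shared-scaler-sketch-added",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch of one possible fix for the TODO above: fit a single scaler on one\n",
    "# location's training features and reuse it for every other location, instead of\n",
    "# letting each RNNData instance scale its own data. Assumes a scikit-learn style\n",
    "# scaler; the arrays below are synthetic placeholders, not data from this pipeline.\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "X_train_loc0 = np.random.rand(100, 3)   # stand-in for location-0 training features\n",
    "X_other_loc = np.random.rand(100, 3)    # stand-in for another location's features\n",
    "\n",
    "shared_scaler = MinMaxScaler().fit(X_train_loc0)        # fit once on one location\n",
    "X_other_scaled = shared_scaler.transform(X_other_loc)   # apply the same scaling elsewhere"
   ]
  },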
341   {
342    "cell_type": "markdown",
343    "id": "0c0c3470-30f5-4915-98a7-dcdf5760d482",
344    "metadata": {},
345    "source": [
346     "## Training at Multiple Locations\n",
347     "\n",
348     "Still sequential"
349    ]
350   },
351   {
352    "cell_type": "code",
353    "execution_count": null,
354    "id": "dd1aca73-7279-473e-b2a3-95aa1db7b1a8",
355    "metadata": {},
356    "outputs": [],
357    "source": [
358     "params = RNNParams(params_all['rnn'])\n",
359     "params.update({\n",
360     "    'epochs': 1, # less epochs since it is per location\n",
361     "    'activation': ['relu', 'relu'],\n",
362     "    'train_frac': .9,\n",
363     "    'val_frac': .05,    \n",
364     "    'dense_units': 10,\n",
365     "    'rnn_layers': 2\n",
366     "})\n",
367     "\n",
368     "# rnn_dat = create_rnn_data2(train[cases[0]], params)\n",
369     "rnn_dat = RNNData(train[cases[0]], params['scaler'], params['features_list'])\n",
370     "rnn_dat.train_test_split(\n",
371     "    train_frac = params['train_frac'],\n",
372     "    val_frac = params['val_frac']\n",
373     ")\n",
374     "rnn_dat.scale_data()\n",
375     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
376    ]
377   },
378   {
379    "cell_type": "code",
380    "execution_count": null,
381    "id": "65b2f9a3-a8f2-4ac1-8e4d-ba38a86eaf40",
382    "metadata": {},
383    "outputs": [],
384    "source": [
385     "reproducibility.set_seed()\n",
386     "rnn = RNN(params)"
387    ]
388   },
389   {
390    "cell_type": "code",
391    "execution_count": null,
392    "id": "47a85ef2-8145-4de8-9f2e-86622306ffd8",
393    "metadata": {
394     "scrolled": true
395    },
396    "outputs": [],
397    "source": [
398     "print(\"~\"*80)\n",
399     "print(\"Running with params:\")\n",
400     "print(params)\n",
401     "\n",
402     "for case in cases[0:10]:\n",
403     "    print(\"~\"*50)\n",
404     "    logging.info('Processing case %s',case)\n",
405     "    print_dict_summary(train[case])\n",
406     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
407     "    rnn_dat_temp.train_test_split(\n",
408     "        train_frac = params['train_frac'],\n",
409     "        val_frac = params['val_frac']\n",
410     "    )\n",
411     "    rnn_dat_temp.scale_data()\n",
412     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
413     "    rnn.fit(rnn_dat_temp['X_train'], rnn_dat_temp['y_train'],\n",
414     "           validation_data=(rnn_dat_temp['X_val'], rnn_dat_temp['y_val']))\n",
415     "    # run_rnn_pkl(train[case],param_sets[i])"
416    ]
417   },
418   {
419    "cell_type": "markdown",
420    "id": "a0421b8d-49aa-4409-8cbf-7732f1137838",
421    "metadata": {},
422    "source": [
423     "### Predict "
424    ]
425   },
426   {
427    "cell_type": "code",
428    "execution_count": null,
429    "id": "63d7854a-94f7-425c-9561-4fe518e044bb",
430    "metadata": {
431     "scrolled": true
432    },
433    "outputs": [],
434    "source": [
435     "# Predict Cases Used in Training\n",
436     "rmses = []\n",
437     "inds = np.arange(0,10)\n",
438     "train_keys = list(train.keys())\n",
439     "for i in inds:\n",
440     "    print(\"~\"*50)\n",
441     "    case = train_keys[i]\n",
442     "    print(f\"Predicting case {case}\")\n",
443     "    # rnn_dat = create_rnn_data2(train[case], params)\n",
444     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
445     "    rnn_dat_temp.train_test_split(\n",
446     "        train_frac = params['train_frac'],\n",
447     "        val_frac = params['val_frac']\n",
448     "    )\n",
449     "    rnn_dat_temp.scale_data()\n",
450     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
451     "    X_temp = rnn_dat_temp.scale_all_X()\n",
452     "    m = rnn.predict(X_temp)\n",
453     "    test_ind = rnn_dat['test_ind']\n",
454     "    rmses.append(rmse(m[test_ind:], rnn_dat['y_test'].flatten()))"
455    ]
456   },
457   {
458    "cell_type": "code",
459    "execution_count": null,
460    "id": "2a5423e0-778b-4f69-9ed0-f0082a1fefe5",
461    "metadata": {},
462    "outputs": [],
463    "source": [
464     "rmses"
465    ]
466   },
467   {
468    "cell_type": "code",
469    "execution_count": null,
470    "id": "45c9caae-7ced-4f21-aa05-c9b125e8fdcb",
471    "metadata": {},
472    "outputs": [],
473    "source": [
474     "pd.DataFrame({'Case': list(train.keys())[0:10], 'RMSE': rmses}).style.hide(axis=\"index\")"
475    ]
476   },
477   {
478    "cell_type": "code",
479    "execution_count": null,
480    "id": "f710f482-b600-4ea5-9a8a-823a13b4ec7a",
481    "metadata": {
482     "scrolled": true
483    },
484    "outputs": [],
485    "source": [
486     "# Predict New Locations\n",
487     "rmses = []\n",
488     "for i, case in enumerate(list(train.keys())[10:100]):\n",
489     "    print(\"~\"*50)\n",
490     "    print(f\"Predicting case {case}\")\n",
491     "    rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n",
492     "    rnn_dat_temp.train_test_split(\n",
493     "        train_frac = params['train_frac'],\n",
494     "        val_frac = params['val_frac']\n",
495     "    )\n",
496     "    rnn_dat_temp.scale_data()\n",
497     "    rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
498     "    X = rnn_dat_temp.scale_all_X()\n",
499     "    m = rnn.predict(X)\n",
500     "    train[case]['m'] = m\n",
501     "    test_ind = rnn_dat['test_ind']\n",
502     "    rmses.append(rmse(m[test_ind:], rnn_dat.y_test.flatten()))\n",
503     "\n",
504     "df = pd.DataFrame({'Case': list(train.keys())[10:100], 'RMSE': rmses})"
505    ]
506   },
507   {
508    "cell_type": "code",
509    "execution_count": null,
510    "id": "d793ac87-d94b-4b16-a271-46cdc259b4fe",
511    "metadata": {},
512    "outputs": [],
513    "source": [
514     "df[0:5].style.hide(axis=\"index\")"
515    ]
516   },
517   {
518    "cell_type": "code",
519    "execution_count": null,
520    "id": "b99606d1-bd46-4041-8303-1bcbb196f6f4",
521    "metadata": {},
522    "outputs": [],
523    "source": [
524     "df"
525    ]
526   },
527   {
528    "cell_type": "code",
529    "execution_count": null,
530    "id": "52ec264d-d4b7-444c-b623-002d6383da30",
531    "metadata": {},
532    "outputs": [],
533    "source": [
534     "df.RMSE.mean()"
535    ]
536   },
537   {
538    "cell_type": "code",
539    "execution_count": null,
540    "id": "998922cd-46bb-4063-8284-0497e19c39b0",
541    "metadata": {},
542    "outputs": [],
543    "source": [
544     "plt.hist(df.RMSE)"
545    ]
546   },
547   {
548    "cell_type": "code",
549    "execution_count": null,
550    "id": "889f3bbb-9fb2-4621-9e93-1d0bc0f83e01",
551    "metadata": {},
552    "outputs": [],
553    "source": []
554   },
555   {
556    "cell_type": "code",
557    "execution_count": null,
558    "id": "fe407f61-15f2-4086-a386-7d7a5bb90d26",
559    "metadata": {},
560    "outputs": [],
561    "source": []
562   },
563   {
564    "cell_type": "code",
565    "execution_count": null,
566    "id": "2fdb63b3-68b8-4877-a7a2-f63257cb29d5",
567    "metadata": {},
568    "outputs": [],
569    "source": []
570   },
571   {
572    "cell_type": "code",
573    "execution_count": null,
574    "id": "5c7563c5-a880-45c7-8381-8ce4e1a44216",
575    "metadata": {},
576    "outputs": [],
577    "source": []
578   },
579   {
580    "cell_type": "code",
581    "execution_count": null,
582    "id": "ad5dae6c-1269-4674-a49e-2efe8b956911",
583    "metadata": {},
584    "outputs": [],
585    "source": []
586   }
587  ],
588  "metadata": {
589   "kernelspec": {
590    "display_name": "Python 3 (ipykernel)",
591    "language": "python",
592    "name": "python3"
593   },
594   "language_info": {
595    "codemirror_mode": {
596     "name": "ipython",
597     "version": 3
598    },
599    "file_extension": ".py",
600    "mimetype": "text/x-python",
601    "name": "python",
602    "nbconvert_exporter": "python",
603    "pygments_lexer": "ipython3",
604    "version": "3.12.5"
605   }
606  },
607  "nbformat": 4,
608  "nbformat_minor": 5