Read pickle files with util
[notebooks.git] / fmda / fmda_rnn_serial.ipynb
blobd897439e391c6e8143fd92e57257c8cf246eb9ff
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2 run RNN strategy serial by Location\n",
9     "\n",
10     "This version of the RNN runs the model on each location separately, one at a time. Two main runs:\n",
11     "1. Run separate model at each location - training and prediction at each location independently - training on periods 0:train_ind (was 0:h2), then prediction in test_ind:end. Validation data, if any, are from train_ind:test_ind\n",
12     "2. Run same model with multiple fitting calls 0:train_ind at different locations, compare prediction accuracy in test_ind:end for all locations.\n"
13    ]
14   },
15   {
16    "cell_type": "code",
17    "execution_count": null,
18    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
19    "metadata": {},
20    "outputs": [],
21    "source": [
22     "import numpy as np\n",
23     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
24     "import pickle\n",
25     "import logging\n",
26     "import os.path as osp\n",
27     "from moisture_rnn_pkl import pkl2train\n",
28     "from moisture_rnn import RNNParams, RNN, create_rnn_data2\n",
29     "from utils import hash2, read_yml\n",
30     "# NOTE: duplicate 'from moisture_rnn import RNN' removed; RNN is imported above\n",
31     "import reproducibility\n",
32     "from data_funcs import rmse\n",
33     "from moisture_models import run_augmented_kf\n",
34     "import copy\n",
35     "import pandas as pd\n",
36     "import matplotlib.pyplot as plt\n",
37     "import yaml"
38    ]
39   },
40   {
41    "cell_type": "code",
42    "execution_count": null,
43    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
44    "metadata": {},
45    "outputs": [],
46    "source": [
47     "logging_setup()"
48    ]
49   },
50   {
51    "cell_type": "code",
52    "execution_count": null,
53    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
54    "metadata": {},
55    "outputs": [],
56    "source": [
57     "file_names=[\"reproducibility_dict2.pickle\",'test_CA_202401.pkl']\n",
58     "file_dir='data'\n",
59     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": null,
65    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
66    "metadata": {},
67    "outputs": [],
68    "source": [
69     "# read/write control\n",
70     "train_file='train.pkl'\n",
71     "train_create=True   # if false, read\n",
72     "train_write=True\n",
73     "train_read=True"
74    ]
75   },
76   {
77    "cell_type": "code",
78    "execution_count": null,
79    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
80    "metadata": {
81     "scrolled": true
82    },
83    "outputs": [],
84    "source": [
85     "if train_create:\n",
86     "    logging.info('creating the training cases from files %s',file_paths)\n",
87     "    # osp.join works on windows too, joins paths using \\ or /\n",
88     "    train = pkl2train(file_paths)\n",
89     "if train_write:\n",
90     "    with open(train_file, 'wb') as file:\n",
91     "        logging.info('Writing the train cases into file %s',train_file)\n",
92     "        pickle.dump(train, file)\n",
93     "if train_read:\n",
94     "    logging.info('Reading the train cases from file %s',train_file)\n",
95     "    with open(train_file,'rb') as file:\n",
96     "        train=pickle.load(file)"
97    ]
98   },
99   {
100    "cell_type": "code",
101    "execution_count": null,
102    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
103    "metadata": {},
104    "outputs": [],
105    "source": [
106     "params_all = read_yml(\"params.yaml\")\n",
107     "print(params_all.keys())"
108    ]
109   },
110   {
111    "cell_type": "code",
112    "execution_count": null,
113    "id": "698df86b-8550-4135-81df-45dbf503dd4e",
114    "metadata": {},
115    "outputs": [],
116    "source": [
117     "# from module_param_sets import param_sets"
118    ]
119   },
120   {
121    "cell_type": "code",
122    "execution_count": null,
123    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
124    "metadata": {},
125    "outputs": [],
126    "source": [
127     "param_sets_keys=['rnn_repro', 'rnn']\n",
128     "# cases=[list(train.keys())[0]]\n",
129     "cases=list(train.keys())[70:90]\n",
130     "# cases.remove('reproducibility')\n",
131     "cases"
132    ]
133   },
134   {
135    "cell_type": "code",
136    "execution_count": null,
137    "id": "dd22baf2-59d2-460e-8c47-b20116dd5982",
138    "metadata": {},
139    "outputs": [],
140    "source": [
141     "logging.info('Running over parameter sets %s',param_sets_keys)\n",
142     "logging.info('Running over cases %s',cases)"
143    ]
144   },
145   {
146    "cell_type": "markdown",
147    "id": "49e31fdd-4c14-4a81-9e2b-4c6ba94d1f83",
148    "metadata": {},
149    "source": [
150     "## Separate Models by Location"
151    ]
152   },
153   {
154    "cell_type": "code",
155    "execution_count": null,
156    "id": "e11e7c83-183f-48ba-abd8-a6aedff66090",
157    "metadata": {},
158    "outputs": [],
159    "source": [
160     "# Set up output dictionaries\n",
161     "outputs_kf = {}\n",
162     "outputs_rnn = {}"
163    ]
164   },
165   {
166    "cell_type": "code",
167    "execution_count": null,
168    "id": "dc5b47bd-4fbc-44b8-b2dd-d118e068b450",
169    "metadata": {},
170    "outputs": [],
171    "source": [
172     "\n",
173     "for k in param_sets_keys:\n",
174     "    params = RNNParams(params_all[k])\n",
175     "    print(\"~\"*80)\n",
176     "    print(\"Running with params:\")\n",
177     "    print(params)\n",
178     "    if k == \"rnn_repro\":\n",
179     "        case = 'reproducibility'\n",
180     "        logging.info('Running reproducibility case')\n",
181     "        rnn_dat = create_rnn_data2(train[case], params)\n",
182     "        reproducibility.set_seed()\n",
183     "        rnn = RNN(params)\n",
184     "        m, errs = rnn.run_model(rnn_dat)\n",
185     "        print(f\"RNN RMSE: {errs}\")\n",
186     "        outputs_rnn[case] = {'case':case, 'm': m.copy(), 'errs': errs.copy()}\n",
187     "\n",
188     "        \n",
189     "        # Run Augmented KF\n",
190     "        print('Running Augmented KF')\n",
191     "        train[case]['h2'] = train[case]['hours'] // 2\n",
192     "        train[case]['scale_fm'] = 1\n",
193     "        m, Ec = run_augmented_kf(train[case])\n",
194     "        m = m*rnn_dat['scale_fm']\n",
195     "        y = rnn_dat['y']*rnn_dat['scale_fm']\n",
196     "        train[case]['m'] = m\n",
197     "        print(f\"KF RMSE: {rmse(m,y)}\")\n",
198     "        outputs_kf[case] = {'case':case, 'm': m.copy(), 'errs': rmse(m,y)}\n",
199     "    else:\n",
200     "        for case in cases:\n",
201     "            # Increase Val Frac so no errors, TODO fix validation\n",
202     "            params.update({\n",
203     "                'train_frac': .5,\n",
204     "                'val_frac': .2\n",
205     "            })\n",
206     "            print(\"~\"*50)\n",
207     "            logging.info('Processing case %s',case)\n",
208     "            print_dict_summary(train[case])\n",
209     "            # Format data & Run Model\n",
210     "            rnn_dat = create_rnn_data2(train[case], params)\n",
211     "            reproducibility.set_seed()\n",
212     "            rnn = RNN(params)\n",
213     "            m, errs = rnn.run_model(rnn_dat)\n",
214     "            # Add model output to case\n",
215     "            train[case]['m']=m\n",
216     "            # Get RMSE Prediction Error\n",
217     "            print(f\"RMSE: {errs}\")\n",
218     "            outputs_rnn[case] = {'case':case, 'm': m.copy(), 'errs': errs.copy()}\n",
219     "            \n",
220     "            # Run Augmented KF\n",
221     "            print('Running Augmented KF')\n",
222     "            train[case]['h2'] = train[case]['hours'] // 2\n",
223     "            train[case]['scale_fm'] = 1\n",
224     "            m, Ec = run_augmented_kf(train[case])\n",
225     "            m = m*rnn_dat['scale_fm']\n",
226     "            y = rnn_dat['y']*rnn_dat['scale_fm']          \n",
227     "            train[case]['m'] = m\n",
228     "            print(f\"KF RMSE: {rmse(m,y)}\")\n",
229     "            outputs_kf[case] = {'case':case, 'm': m.copy(), 'errs': rmse(m,y)}"
230    ]
231   },
232   {
233    "cell_type": "code",
234    "execution_count": null,
235    "id": "15384e4d-b8ec-4700-bdc2-83b0433d11c9",
236    "metadata": {},
237    "outputs": [],
238    "source": [
239     "logging.info('fmda_rnn_serial.ipynb done')"
240    ]
241   },
242   {
243    "cell_type": "code",
244    "execution_count": null,
245    "id": "d0e78fb3-b501-49d6-81a9-1a13da0134a0",
246    "metadata": {},
247    "outputs": [],
248    "source": [
249     "import importlib\n",
250     "import moisture_rnn\n",
251     "importlib.reload(moisture_rnn)\n",
252     "from moisture_rnn import RNN"
253    ]
254   },
255   {
256    "cell_type": "code",
257    "execution_count": null,
258    "id": "37053436-8dfe-4c40-8614-811817e83782",
259    "metadata": {},
260    "outputs": [],
261    "source": [
262     "for k in outputs_rnn:\n",
263     "    print(\"~\"*50)\n",
264     "    print(outputs_rnn[k]['case'])\n",
265     "    print(outputs_rnn[k]['errs']['prediction'])"
266    ]
267   },
268   {
269    "cell_type": "code",
270    "execution_count": null,
271    "id": "9154d5f7-015f-4ef7-af45-020410a1ea65",
272    "metadata": {},
273    "outputs": [],
274    "source": [
275     "for k in outputs_kf:\n",
276     "    print(\"~\"*50)\n",
277     "    print(outputs_kf[k]['case'])\n",
278     "    print(outputs_kf[k]['errs'])"
279    ]
280   },
281   {
282    "cell_type": "markdown",
283    "id": "0c0c3470-30f5-4915-98a7-dcdf5760d482",
284    "metadata": {},
285    "source": [
286     "## Training at Multiple Locations\n",
287     "\n",
288     "Still sequential"
289    ]
290   },
291   {
292    "cell_type": "code",
293    "execution_count": null,
294    "id": "dd1aca73-7279-473e-b2a3-95aa1db7b1a8",
295    "metadata": {},
296    "outputs": [],
297    "source": [
298     "params = params_all['rnn']\n",
299     "params.update({\n",
300     "    'epochs': 100, # less epochs since it is per location\n",
301     "    'activation': ['sigmoid', 'linear'],\n",
302     "    'rnn_units': 10,\n",
303     "    'train_frac': .5,\n",
304     "    'val_frac': .2,\n",
305     "    'scale': True,\n",
306     "    'features_list': ['Ed', 'Ew', 'solar', 'wind', 'rain']\n",
307     "})\n",
308     "\n",
309     "rnn_dat = create_rnn_data2(train[cases[0]], params)"
310    ]
311   },
312   {
313    "cell_type": "code",
314    "execution_count": null,
315    "id": "65b2f9a3-a8f2-4ac1-8e4d-ba38a86eaf40",
316    "metadata": {},
317    "outputs": [],
318    "source": [
319     "reproducibility.set_seed()\n",
320     "rnn = RNN(params)"
321    ]
322   },
323   {
324    "cell_type": "code",
325    "execution_count": null,
326    "id": "47a85ef2-8145-4de8-9f2e-86622306ffd8",
327    "metadata": {},
328    "outputs": [],
329    "source": [
330     "print(\"~\"*80)\n",
331     "print(\"Running with params:\")\n",
332     "print(params)\n",
333     "\n",
334     "for case in cases:\n",
335     "    print(\"~\"*50)\n",
336     "    logging.info('Processing case %s',case)\n",
337     "    print_dict_summary(train[case])\n",
338     "    rnn_dat = create_rnn_data2(train[case], params)\n",
339     "    rnn.fit(rnn_dat['X_train'], rnn_dat['y_train'],\n",
340     "           validation_data=(rnn_dat['X_val'], rnn_dat['y_val']))\n",
341     "    # run_rnn_pkl(train[case],param_sets[i])"
342    ]
343   },
344   {
345    "cell_type": "markdown",
346    "id": "a0421b8d-49aa-4409-8cbf-7732f1137838",
347    "metadata": {},
348    "source": [
349     "### Predict "
350    ]
351   },
352   {
353    "cell_type": "code",
354    "execution_count": null,
355    "id": "63d7854a-94f7-425c-9561-4fe518e044bb",
356    "metadata": {
357     "scrolled": true
358    },
359    "outputs": [],
360    "source": [
361     "# Predict Cases Used in Training\n",
362     "rmses = []\n",
363     "for i, case in enumerate(list(train.keys())[200:210]):\n",
364     "    print(\"~\"*50)\n",
365     "    print(f\"Predicting case {case}\")\n",
366     "    rnn_dat = create_rnn_data2(train[case], params)\n",
367     "    m = rnn.predict(rnn_dat[\"X\"])\n",
368     "    test_ind = rnn_dat['test_ind']\n",
369     "    rmses.append(rmse(m[test_ind:], rnn_dat['y_test'].flatten()))"
370    ]
371   },
372   {
373    "cell_type": "code",
374    "execution_count": null,
375    "id": "2a5423e0-778b-4f69-9ed0-f0082a1fefe5",
376    "metadata": {},
377    "outputs": [],
378    "source": [
379     "rmses"
380    ]
381   },
382   {
383    "cell_type": "code",
384    "execution_count": null,
385    "id": "45c9caae-7ced-4f21-aa05-c9b125e8fdcb",
386    "metadata": {},
387    "outputs": [],
388    "source": [
389     "pd.DataFrame({'Case': list(train.keys())[200:210], 'RMSE': rmses}).style.hide(axis=\"index\")"
390    ]
391   },
392   {
393    "cell_type": "code",
394    "execution_count": null,
395    "id": "f710f482-b600-4ea5-9a8a-823a13b4ec7a",
396    "metadata": {
397     "scrolled": true
398    },
399    "outputs": [],
400    "source": [
401     "# Predict New Locations\n",
402     "rmses = []\n",
403     "for i, case in enumerate(list(train.keys())[150:180]):\n",
404     "    print(\"~\"*50)\n",
405     "    print(f\"Predicting case {case}\")\n",
406     "    rnn_dat = create_rnn_data2(train[case], params)\n",
407     "    m = rnn.predict(rnn_dat[\"X\"])\n",
408     "    train[case]['m'] = m\n",
409     "    test_ind = rnn_dat['test_ind']\n",
410     "    rmses.append(rmse(m[test_ind:], rnn_dat['y_test'].flatten()))\n",
411     "\n",
412     "df = pd.DataFrame({'Case': list(train.keys())[150:180], 'RMSE': rmses})"
413    ]
414   },
415   {
416    "cell_type": "code",
417    "execution_count": null,
418    "id": "d793ac87-d94b-4b16-a271-46cdc259b4fe",
419    "metadata": {},
420    "outputs": [],
421    "source": [
422     "df[0:5].style.hide(axis=\"index\")"
423    ]
424   },
425   {
426    "cell_type": "code",
427    "execution_count": null,
428    "id": "52ec264d-d4b7-444c-b623-002d6383da30",
429    "metadata": {},
430    "outputs": [],
431    "source": [
432     "df.RMSE.mean()"
433    ]
434   },
435   {
436    "cell_type": "code",
437    "execution_count": null,
438    "id": "998922cd-46bb-4063-8284-0497e19c39b0",
439    "metadata": {},
440    "outputs": [],
441    "source": [
442     "plt.hist(df.RMSE)"
443    ]
444   },
445   {
446    "cell_type": "code",
447    "execution_count": null,
448    "id": "889f3bbb-9fb2-4621-9e93-1d0bc0f83e01",
449    "metadata": {},
450    "outputs": [],
451    "source": []
452   },
453   {
454    "cell_type": "code",
455    "execution_count": null,
456    "id": "fe407f61-15f2-4086-a386-7d7a5bb90d26",
457    "metadata": {},
458    "outputs": [],
459    "source": []
460   },
461   {
462    "cell_type": "code",
463    "execution_count": null,
464    "id": "2fdb63b3-68b8-4877-a7a2-f63257cb29d5",
465    "metadata": {},
466    "outputs": [],
467    "source": []
468   },
469   {
470    "cell_type": "code",
471    "execution_count": null,
472    "id": "5c7563c5-a880-45c7-8381-8ce4e1a44216",
473    "metadata": {},
474    "outputs": [],
475    "source": []
476   },
477   {
478    "cell_type": "code",
479    "execution_count": null,
480    "id": "ad5dae6c-1269-4674-a49e-2efe8b956911",
481    "metadata": {},
482    "outputs": [],
483    "source": []
484   }
485  ],
486  "metadata": {
487   "kernelspec": {
488    "display_name": "Python 3 (ipykernel)",
489    "language": "python",
490    "name": "python3"
491   },
492   "language_info": {
493    "codemirror_mode": {
494     "name": "ipython",
495     "version": 3
496    },
497    "file_extension": ".py",
498    "mimetype": "text/x-python",
499    "name": "python",
500    "nbconvert_exporter": "python",
501    "pygments_lexer": "ipython3",
502    "version": "3.9.12"
503   }
504  },
505  "nbformat": 4,
506  "nbformat_minor": 5