{
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "83b774b3-ef55-480a-b999-506676e49145",
6    "metadata": {},
7    "source": [
8     "# v2.1 run RNN strategy serial by Location\n",
9     "\n",
10     "This version of the RNN runs the model on each location separately, one at a time. Two main runs:\n",
11     "1. Run separate model at each location - training and prediction at each location independently - training on periods 0:train_ind (was 0:h2), then prediction in test_ind:end. Validation data, if any, are from train_ind:test_ind\n",
12     "2. Run same model with multiple fitting calls 0:train_ind at different locations, compare prediction accuracy in test_ind:end for all locations.\n"
13    ]
14   },
15   {
16    "cell_type": "code",
17    "execution_count": null,
18    "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
19    "metadata": {},
20    "outputs": [],
21    "source": [
22     "import numpy as np\n",
23     "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
24     "import pickle\n",
25     "import logging\n",
26     "import os.path as osp\n",
27     "from moisture_rnn_pkl import pkl2train\n",
28     "from moisture_rnn import RNNParams, RNNData, RNN \n",
29     "from utils import hash2, read_yml, read_pkl, retrieve_url\n",
31     "import reproducibility\n",
32     "from data_funcs import rmse, to_json\n",
33     "from moisture_models import run_augmented_kf\n",
34     "import copy\n",
35     "import pandas as pd\n",
36     "import matplotlib.pyplot as plt\n",
37     "import yaml\n",
38     "import time"
39    ]
40   },
41   {
42    "cell_type": "code",
43    "execution_count": null,
44    "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
45    "metadata": {},
46    "outputs": [],
47    "source": [
48     "logging_setup()"
49    ]
50   },
51   {
52    "cell_type": "code",
53    "execution_count": null,
54    "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
55    "metadata": {},
56    "outputs": [],
57    "source": [
    "# NOTE(review): URL filename (test_CA_202401.pkl) differs from dest_path\n",
    "# (fmda_nw_202401-05_f05.pkl) -- confirm this is the intended dataset\n",
58     "retrieve_url(\n",
59     "    url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\", \n",
60     "    dest_path = \"fmda_nw_202401-05_f05.pkl\")"
61    ]
62   },
63   {
64    "cell_type": "code",
65    "execution_count": null,
66    "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
67    "metadata": {},
68    "outputs": [],
69    "source": [
70     "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n",
71     "file_names=['fmda_nw_202401-05_f05.pkl']\n",
72     "file_dir='data'\n",
73     "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
74    ]
75   },
76   {
77    "cell_type": "code",
78    "execution_count": null,
79    "id": "dcca6185-e799-4dd1-8acb-87ad33c411d7",
80    "metadata": {},
81    "outputs": [],
82    "source": [
83     "# read/write control\n",
84     "train_file='train.pkl'\n",
85     "train_create=False   # if false, read\n",
86     "train_write=False\n",
87     "train_read=True"
88    ]
89   },
90   {
91    "cell_type": "code",
92    "execution_count": null,
93    "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75",
94    "metadata": {
95     "scrolled": true
96    },
97    "outputs": [],
98    "source": [
99     "repro = read_pkl(repro_file)\n",
100     "\n",
101     "if train_create:\n",
102     "    logging.info('creating the training cases from files %s',file_paths)\n",
103     "    # osp.join works on windows too, joins paths using \\ or /\n",
104     "    train = pkl2train(file_paths)\n",
105     "if train_write:\n",
106     "    with open(train_file, 'wb') as file:\n",
107     "        logging.info('Writing the train cases into file %s',train_file)\n",
108     "        pickle.dump(train, file)\n",
109     "if train_read:\n",
110     "    logging.info('Reading the train cases from file %s',train_file)\n",
111     "    train = read_pkl(train_file)"
112    ]
113   },
114   {
115    "cell_type": "code",
116    "execution_count": null,
117    "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
118    "metadata": {},
119    "outputs": [],
120    "source": [
121     "params_all = read_yml(\"params.yaml\")\n",
122     "print(params_all.keys())"
123    ]
124   },
125   {
126    "cell_type": "code",
127    "execution_count": null,
128    "id": "698df86b-8550-4135-81df-45dbf503dd4e",
129    "metadata": {},
130    "outputs": [],
131    "source": [
132     "# from module_param_sets import param_sets"
133    ]
134   },
135   {
136    "cell_type": "code",
137    "execution_count": null,
138    "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
139    "metadata": {},
140    "outputs": [],
141    "source": [
142     "param_sets_keys=['rnn']\n",
143     "cases=list(train.keys())\n",
144     "# cases=list(train.keys())\n",
145     "# cases.remove('reproducibility')\n",
146     "train_cases = cases[0:100]"
147    ]
148   },
149   {
150    "cell_type": "code",
151    "execution_count": null,
152    "id": "22c87832-5cd5-4ed4-b755-25b4a192ddcb",
153    "metadata": {},
154    "outputs": [],
155    "source": [
156     "print(f\"Number of Training Locations: {len(train_cases)}\")"
157    ]
158   },
159   {
160    "cell_type": "markdown",
161    "id": "49e31fdd-4c14-4a81-9e2b-4c6ba94d1f83",
162    "metadata": {},
163    "source": [
164     "## Separate Models by Location"
165    ]
166   },
167   {
168    "cell_type": "code",
169    "execution_count": null,
170    "id": "e11e7c83-183f-48ba-abd8-a6aedff66090",
171    "metadata": {},
172    "outputs": [],
173    "source": [
174     "# Set up output dictionaries\n",
175     "outputs_const = {}\n",
176     "outputs_exp = {}"
177    ]
178   },
179   {
180    "cell_type": "code",
181    "execution_count": null,
182    "id": "dc5b47bd-4fbc-44b8-b2dd-d118e068b450",
183    "metadata": {
184     "scrolled": true
185    },
186    "outputs": [],
187    "source": [
188     "params = RNNParams(params_all['rnn'])\n",
189     "print(\"~\"*80)\n",
190     "print(\"Running with params:\")\n",
191     "print(params)\n",
192     "params.update({\n",
193     "    'activation': ['relu', 'relu'],\n",
194     "    'epochs': 20,\n",
195     "    'rnn_layers' : 2,\n",
196     "    'rnn_units' : 30,\n",
197     "    'dense_units': 20\n",
199     "})\n",
200     "for case in train_cases:\n",
201     "    print(\"~\"*50)\n",
202     "    logging.info('Processing case %s',case)\n",
203     "    print_dict_summary(train[case])\n",
204     "    # Format data & Run Model\n",
205     "    # rnn_dat = create_rnn_data2(train[case], params)\n",
206     "    rnn_dat = RNNData(train[case], scaler = params['scaler'], features_list = params['features_list'])\n",
207     "    rnn_dat.train_test_split(\n",
208     "        time_fracs = [.9, .05, .05]\n",
209     "    )\n",
210     "    rnn_dat.scale_data()\n",
211     "    rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n",
212     "    \n",
213     "    params.update({'batch_schedule_type':'constant', 'bmin':20})\n",
214     "    reproducibility.set_seed()\n",
215     "    rnn = RNN(params)\n",
216     "    m, errs, best_ep = rnn.run_model(rnn_dat, return_epochs=True)\n",
217     "    # Get RMSE Prediction Error\n",
218     "    print(f\"RMSE: {errs}\")\n",
219     "    outputs_const[case] = {'case':case, 'errs': errs.copy(), 'epochs': best_ep}\n",
220     "    \n",
221     "    ###########\n",
222     "    params.update({'batch_schedule_type':'exp', \n",
223     "                   'bmin':20, 'bmax': rnn_dat.hours})\n",
224     "    reproducibility.set_seed()\n",
225     "    rnn = RNN(params)\n",
226     "    m, errs, best_ep = rnn.run_model(rnn_dat, return_epochs=True)\n",
227     "    # Get RMSE Prediction Error\n",
228     "    print(f\"RMSE: {errs}\")\n",
229     "    outputs_exp[case] = {'case':case, 'errs': errs.copy(), 'epochs': best_ep}"
230    ]
231   },
232   {
233    "cell_type": "markdown",
234    "id": "31cba687-a66b-4c0f-b3db-c8f1ccb354ae",
235    "metadata": {},
236    "source": [
237     "## Compare"
238    ]
239   },
240   {
241    "cell_type": "code",
242    "execution_count": null,
243    "id": "2fdb63b3-68b8-4877-a7a2-f63257cb29d5",
244    "metadata": {},
245    "outputs": [],
246    "source": [
247     "# Prepare lists to store the extracted values\n",
248     "cases = []\n",
249     "predictions = []\n",
250     "epochs = []\n",
251     "\n",
252     "# Iterate through the dictionary to extract the needed values\n",
253     "for key, value in outputs_const.items():\n",
254     "    cases.append(value['case'])\n",
255     "    predictions.append(value['errs']['prediction'])\n",
256     "    epochs.append(value['epochs'])\n",
257     "\n",
258     "# Create the DataFrame\n",
259     "df1 = pd.DataFrame({\n",
260     "    'case': cases,\n",
261     "    'prediction': predictions,\n",
262     "    'epochs' : epochs\n",
263     "})\n"
264    ]
265   },
266   {
267    "cell_type": "code",
268    "execution_count": null,
269    "id": "5c7563c5-a880-45c7-8381-8ce4e1a44216",
270    "metadata": {},
271    "outputs": [],
272    "source": [
273     "# Prepare lists to store the extracted values\n",
274     "cases = []\n",
275     "predictions = []\n",
276     "epochs = []\n",
277     "\n",
278     "# Iterate through the dictionary to extract the needed values\n",
279     "for key, value in outputs_exp.items():\n",
280     "    cases.append(value['case'])\n",
281     "    predictions.append(value['errs']['prediction'])\n",
282     "    epochs.append(value['epochs'])\n",
283     "\n",
284     "# Create the DataFrame\n",
285     "df2 = pd.DataFrame({\n",
286     "    'case': cases,\n",
287     "    'prediction': predictions,\n",
288     "    'epochs' : epochs\n",
289     "})"
290    ]
291   },
292   {
293    "cell_type": "code",
294    "execution_count": null,
295    "id": "df2c2dfd-3896-4dff-bae6-6d9c3d25c2ab",
296    "metadata": {},
297    "outputs": [],
298    "source": [
299     "df1.head()"
300    ]
301   },
302   {
303    "cell_type": "code",
304    "execution_count": null,
305    "id": "ea0ddeb1-726d-4b15-9565-e098733521e4",
306    "metadata": {},
307    "outputs": [],
308    "source": [
309     "df2.head()"
310    ]
311   },
312   {
313    "cell_type": "code",
314    "execution_count": null,
315    "id": "6bce90d6-803a-4e61-b031-351bb5aa3071",
316    "metadata": {},
317    "outputs": [],
318    "source": [
319     "df1.prediction.mean()"
320    ]
321   },
322   {
323    "cell_type": "code",
324    "execution_count": null,
325    "id": "c458d837-25e3-4b73-a7a3-1dc8dbe533cf",
326    "metadata": {},
327    "outputs": [],
328    "source": [
329     "df2.prediction.mean()"
330    ]
331   },
332   {
333    "cell_type": "code",
334    "execution_count": null,
335    "id": "86a6f0c0-80b7-4200-8d46-c7baea2ec72e",
336    "metadata": {},
337    "outputs": [],
338    "source": [
339     "df1.epochs.mean()"
340    ]
341   },
342   {
343    "cell_type": "code",
344    "execution_count": null,
345    "id": "9994864d-6a8a-4cca-875f-de68ac9ca091",
346    "metadata": {},
347    "outputs": [],
348    "source": [
349     "df2.epochs.mean()"
350    ]
351   },
352   {
353    "cell_type": "code",
354    "execution_count": null,
355    "id": "22c36713-96e5-413d-8e60-ce87fe2896e6",
356    "metadata": {},
357    "outputs": [],
358    "source": []
359   }
360  ],
361  "metadata": {
362   "kernelspec": {
363    "display_name": "Python 3 (ipykernel)",
364    "language": "python",
365    "name": "python3"
366   },
367   "language_info": {
368    "codemirror_mode": {
369     "name": "ipython",
370     "version": 3
371    },
372    "file_extension": ".py",
373    "mimetype": "text/x-python",
374    "name": "python",
375    "nbconvert_exporter": "python",
376    "pygments_lexer": "ipython3",
377    "version": "3.12.5"
378   }
379  },
380  "nbformat": 4,
381  "nbformat_minor": 5