Update rnn_workshop.ipynb
[notebooks.git] / fmda / rnn_workshop.ipynb
blobcbd32403ef33023129ce5fbfa92f3bdbaced28e1
2  "cells": [
3   {
4    "cell_type": "code",
5    "execution_count": null,
6    "id": "e6cc7920-e380-4b81-bac0-cd6840450e9a",
7    "metadata": {},
8    "outputs": [],
9    "source": [
10     "# Environment\n",
11     "import numpy as np\n",
12     "import pandas as pd\n",
13     "import tensorflow as tf\n",
14     "import matplotlib.pyplot as plt\n",
15     "import sys\n",
16     "# Local modules\n",
17     "sys.path.append('..')\n",
18     "import reproducibility\n",
19     "import pandas as pd\n",
20     "from utils import print_dict_summary\n",
21     "from data_funcs import load_and_fix_data, rmse\n",
22     "from moisture_rnn import RNN, RNN_LSTM, create_rnn_data2\n",
23     "from moisture_rnn_pkl import pkl2train\n",
24     "from tensorflow.keras.callbacks import Callback\n",
25     "from sklearn.metrics import mean_squared_error\n",
26     "from utils import hash2\n",
27     "import copy\n",
28     "import logging\n",
29     "from utils import logging_setup"
30    ]
31   },
32   {
33    "cell_type": "code",
34    "execution_count": null,
35    "id": "f58e8839-bf0e-4995-b966-c09e4df001ce",
36    "metadata": {},
37    "outputs": [],
38    "source": [
39     "logging_setup()"
40    ]
41   },
42   {
43    "cell_type": "markdown",
44    "id": "2298a1a1-b72c-4c7e-bcb6-2cdefe96fe3e",
45    "metadata": {},
46    "source": [
47     "## Test Data Creation"
48    ]
49   },
50   {
51    "cell_type": "code",
52    "execution_count": null,
53    "id": "56444dda-1e57-4b47-ad35-72ae7ed706e6",
54    "metadata": {},
55    "outputs": [],
56    "source": [
57     "import yaml\n",
58     "\n",
59     "with open(\"params.yaml\") as file:\n",
60     "    params = yaml.safe_load(file)[\"rnn\"]\n",
61     "# params.update({'scale': 1})"
62    ]
63   },
64   {
65    "cell_type": "code",
66    "execution_count": null,
67    "id": "4666d11f-aaa2-426e-a406-70603f2799f9",
68    "metadata": {
69     "scrolled": true
70    },
71    "outputs": [],
72    "source": [
73     "params.update({'features_list': ['Ed', 'rain', 'solar', 'wind']})\n",
74     "train = pkl2train(['data/reproducibility_dict2.pickle', \"data/test_CA_202401.pkl\"],\n",
75     "                 features_list = params['features_list'])\n",
76     "# train = pd.read_pickle('train.pkl')"
77    ]
78   },
79   {
80    "cell_type": "code",
81    "execution_count": null,
82    "id": "58d54d93-067c-4372-b285-5a6e394fdd75",
83    "metadata": {},
84    "outputs": [],
85    "source": [
86     "list(train.keys())[0:5]"
87    ]
88   },
89   {
90    "cell_type": "code",
91    "execution_count": null,
92    "id": "cb119a26-c08a-41a1-bea2-c6dfe3b5feb0",
93    "metadata": {},
94    "outputs": [],
95    "source": [
96     "print(params['features_list'])\n",
97     "print(train['FCHC1_202401']['X'].shape)"
98    ]
99   },
100   {
101    "cell_type": "code",
102    "execution_count": null,
103    "id": "f8a1c7ad-6529-4c34-8444-0ef11f44dc2c",
104    "metadata": {},
105    "outputs": [],
106    "source": [
107     "params2 = copy.deepcopy(params)"
108    ]
109   },
110   {
111    "cell_type": "code",
112    "execution_count": null,
113    "id": "fbf6a57e-8a8c-4494-aea4-08eb4bf9c438",
114    "metadata": {},
115    "outputs": [],
116    "source": [
117     "print(params2['scale'])\n",
118     "print(params2['scaler'])"
119    ]
120   },
121   {
122    "cell_type": "code",
123    "execution_count": null,
124    "id": "79e3c25a-cf08-4a6d-be2e-21afb861c976",
125    "metadata": {},
126    "outputs": [],
127    "source": [
128     "params2.update({'val_frac': .2, 'scale': True, 'scaler': 'standard'})\n",
129     "rnn_dat = create_rnn_data2(train['CRVC1_202401'], params2)"
130    ]
131   },
132   {
133    "cell_type": "code",
134    "execution_count": null,
135    "id": "681f9bb4-9f05-424c-ad24-c37d747283b8",
136    "metadata": {},
137    "outputs": [],
138    "source": [
139     "print(train['CRVC1_202401']['X'][0,:])\n",
140     "print(rnn_dat['X'][0,:])"
141    ]
142   },
143   {
144    "cell_type": "code",
145    "execution_count": null,
146    "id": "e0a1147c-e882-4b67-9e77-36a3b4b8bde4",
147    "metadata": {},
148    "outputs": [],
149    "source": [
150     "reproducibility.set_seed()\n",
151     "rnn = RNN(params2)\n",
152     "m, errs = rnn.run_model(rnn_dat)"
153    ]
154   },
155   {
156    "cell_type": "code",
157    "execution_count": null,
158    "id": "21ea67c5-26c3-4f2c-a2b8-03be57d7e013",
159    "metadata": {},
160    "outputs": [],
161    "source": [
162     "params2.update({'scale': False})\n",
163     "rnn_dat = create_rnn_data2(train['CRVC1_202401'], params2)\n",
164     "print(rnn_dat['X_train'][0,:])\n",
165     "reproducibility.set_seed()\n",
166     "rnn = RNN(params2)\n",
167     "m, errs = rnn.run_model(rnn_dat)"
168    ]
169   },
170   {
171    "cell_type": "code",
172    "execution_count": null,
173    "id": "888dd72a-4eef-414b-ac33-f6f4bfbefe60",
174    "metadata": {},
175    "outputs": [],
176    "source": [
177     "errs"
178    ]
179   },
180   {
181    "cell_type": "markdown",
182    "id": "d2360aef-e9c4-4a71-922d-336e53b82537",
183    "metadata": {},
184    "source": [
185     "## LSTM"
186    ]
187   },
188   {
189    "cell_type": "code",
190    "execution_count": null,
191    "id": "71d4e441-9bf1-4d57-bb37-091553e23212",
192    "metadata": {},
193    "outputs": [],
194    "source": [
195     "import importlib \n",
196     "import moisture_rnn\n",
197     "importlib.reload(moisture_rnn)\n",
198     "from moisture_rnn import RNN_LSTM"
199    ]
200   },
201   {
202    "cell_type": "code",
203    "execution_count": null,
204    "id": "88a73c40-0bab-4aa0-8265-00f427aa97ea",
205    "metadata": {
206     "jupyter": {
207      "source_hidden": true
208     }
209    },
210    "outputs": [],
211    "source": [
212     "from moisture_rnn import RNN_LSTM\n",
213     "\n",
214     "# from tensorflow.keras.layers import LSTM, Input, Dropout, Dense, SimpleRNN\n",
215     "# from moisture_rnn import staircase_2\n",
216     "# from abc import ABC, abstractmethod\n",
217     "# from data_funcs import compare_dicts\n",
218     "# class RNNModel(ABC):\n",
219     "#     def __init__(self, params: dict):\n",
220     "#         self.params = params\n",
221     "#         if type(self) is RNNModel:\n",
222     "#             raise TypeError(\"MLModel is an abstract class and cannot be instantiated directly\")\n",
223     "#         super().__init__()\n",
224     "\n",
225     "#     @abstractmethod\n",
226     "#     def fit(self, X_train, y_train, weights=None):\n",
227     "#         pass\n",
228     "\n",
229     "#     @abstractmethod\n",
230     "#     def predict(self, X):\n",
231     "#         pass\n",
232     "\n",
233     "#     def run_model(self, dict0):\n",
234     "#         # Make copy to prevent changing in place\n",
235     "#         dict1 = copy.deepcopy(dict0)\n",
236     "#         # Extract Fields\n",
237     "#         X_train, y_train, X_test, y_test = dict1['X_train'].copy(), dict1['y_train'].copy(), dict1[\"X_test\"].copy(), dict1['y_test'].copy()\n",
238     "#         if 'X_val' in dict1:\n",
239     "#             X_val, y_val = dict1['X_val'].copy(), dict1['y_val'].copy()\n",
240     "#         else:\n",
241     "#             X_val = None\n",
242     "#         case_id = dict1['case']\n",
243     "\n",
244     "#         # Fit model\n",
245     "#         if X_val is None:\n",
246     "#             self.fit(X_train, y_train)\n",
247     "#         else:\n",
248     "#             self.fit(X_train, y_train, validation_data = (X_val, y_val))\n",
249     "#         # Generate Predictions, \n",
250     "#         # run through training to get hidden state set properly for forecast period\n",
251     "#         if X_val is None:\n",
252     "#             X = np.concatenate((X_train, X_test))\n",
253     "#             y = np.concatenate((y_train, y_test)).flatten()\n",
254     "#         else:\n",
255     "#             X = np.concatenate((X_train, X_val, X_test))\n",
256     "#             y = np.concatenate((y_train, y_val, y_test)).flatten()\n",
257     "#         # Predict\n",
258     "#         print(f\"Predicting Training through Test \\n features hash: {hash2(X)} \\n response hash: {hash2(y)} \")\n",
259     "#         m = self.predict(X).flatten()\n",
260     "#         dict1['m']=m\n",
261     "#         dict0['m']=m # add to outside env dictionary, should be only place this happens\n",
262     "#         if self.params['scale']:\n",
263     "#             print(f\"Rescaling data using {self.params['scaler']}\")\n",
264     "#             if self.params['scaler'] == \"reproducibility\":\n",
265     "#                 m  *= self.params['scale_fm']\n",
266     "#                 y  *= self.params['scale_fm']\n",
267     "#                 y_train *= self.params['scale_fm']\n",
268     "#                 y_test *= self.params['scale_fm']\n",
269     "#         # Check Reproducibility, TODO: old dict calls it hidden_units not rnn_units, so this doesn't check that\n",
270     "#         if (case_id == \"reproducibility\") and compare_dicts(self.params, repro_hashes['params'], ['epochs', 'batch_size', 'scale', 'activation', 'learning_rate']):\n",
271     "#             print(\"Checking Reproducibility\")\n",
272     "#             checkm = m[350]\n",
273     "#             hv = hash2(self.model_predict.get_weights())\n",
274     "#             if self.params['phys_initialize']:\n",
275     "#                 hv5 = repro_hashes['phys_initialize']['fitted_weight_hash']\n",
276     "#                 mv = repro_hashes['phys_initialize']['predictions_hash']\n",
277     "#             else:\n",
278     "#                 hv5 = repro_hashes['rand_initialize']['fitted_weight_hash']\n",
279     "#                 mv = repro_hashes['rand_initialize']['predictions_hash']           \n",
280     "            \n",
281     "#             print(f\"Fitted weights hash (check 5): {hv}, Reproducibility weights hash: {hv5}, Error: {hv5-hv}\")\n",
282     "#             print(f\"Model predictions hash: {checkm}, Reproducibility preds hash: {mv}, Error: {mv-checkm}\")\n",
283     "\n",
284     "#         # print(dict1.keys())\n",
285     "#         # Plot final fit and data\n",
286     "#         # TODO: make plot_data specific to this context\n",
287     "#         dict1['y'] = y\n",
288     "#         plot_data(dict1, title=\"RNN\", title2=dict1['case'])\n",
289     "        \n",
290     "#         # Calculate Errors\n",
291     "#         err = rmse(m, y)\n",
292     "#         train_ind = dict1[\"train_ind\"] # index of final training set value\n",
293     "#         test_ind = dict1[\"test_ind\"] # index of first test set value\n",
294     "#         err_train = rmse(m[:train_ind], y_train.flatten())\n",
295     "#         err_pred = rmse(m[test_ind:], y_test.flatten())\n",
296     "#         rmse_dict = {\n",
297     "#             'all': err, \n",
298     "#             'training': err_train, \n",
299     "#             'prediction': err_pred\n",
300     "#         }\n",
301     "#         return rmse_dict\n",
302     "        \n",
303     "# class ResetStatesCallback(Callback):\n",
304     "#     def on_epoch_end(self, epoch, logs=None):\n",
305     "#         self.model.reset_states()\n",
306     "\n",
307     "\n",
308     "# class RNN_LSTM(RNNModel):\n",
309     "#     def __init__(self, params, loss='mean_squared_error'):\n",
310     "#         super().__init__(params)\n",
311     "#         self.model_train = self._build_model_train()\n",
312     "#         self.model_predict = self._build_model_predict()\n",
313     "\n",
314     "#     def _build_model_train(self, return_sequences=False):\n",
315     "#         inputs = tf.keras.Input(batch_shape=self.params['batch_shape'])\n",
316     "#         x = inputs\n",
317     "#         for i in range(self.params['rnn_layers']):\n",
318     "#             x = LSTM(\n",
319     "#                 units=self.params['rnn_units'],\n",
320     "#                 activation=self.params['activation'][0],\n",
321     "#                 dropout=self.params[\"dropout\"][0],\n",
322     "#                 stateful=self.params['stateful'],\n",
323     "#                 return_sequences=return_sequences)(x)\n",
324     "#         if self.params[\"dropout\"][1] > 0:\n",
325     "#             x = Dropout(self.params[\"dropout\"][1])(x)            \n",
326     "#         for i in range(self.params['dense_layers']):\n",
327     "#             x = Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)\n",
328     "#         model = tf.keras.Model(inputs=inputs, outputs=x)\n",
329     "#         optimizer=tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'])\n",
330     "#         model.compile(loss='mean_squared_error', optimizer=optimizer)\n",
331     "        \n",
332     "#         if self.params[\"verbose_weights\"]:\n",
333     "#             print(f\"Initial Weights Hash: {hash2(model.get_weights())}\")\n",
334     "#         return model\n",
335     "#     def _build_model_predict(self, return_sequences=True):\n",
336     "        \n",
337     "#         inputs = tf.keras.Input(shape=self.params['pred_input_shape'])\n",
338     "#         x = inputs\n",
339     "#         for i in range(self.params['rnn_layers']):\n",
340     "#             x = LSTM(\n",
341     "#                 units=self.params['rnn_units'],\n",
342     "#                 activation=self.params['activation'][0],\n",
343     "#                 stateful=False,return_sequences=return_sequences)(x)\n",
344     "#         for i in range(self.params['dense_layers']):\n",
345     "#             x = Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)\n",
346     "#         model = tf.keras.Model(inputs=inputs, outputs=x)\n",
347     "#         optimizer=tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'])\n",
348     "#         model.compile(loss='mean_squared_error', optimizer=optimizer)  \n",
349     "\n",
350     "#         # Set Weights to model_train\n",
351     "#         w_fitted = self.model_train.get_weights()\n",
352     "#         model.set_weights(w_fitted)\n",
353     "        \n",
354     "#         return model\n",
355     "#     def format_train_data(self, X, y, verbose=False):\n",
356     "#         X, y = staircase_2(X, y, timesteps = self.params[\"timesteps\"], batch_size=self.params[\"batch_size\"], verbose=verbose)\n",
357     "#         return X, y\n",
358     "#     def format_pred_data(self, X):\n",
359     "#         return np.reshape(X,(1, X.shape[0], self.params['features']))\n",
360     "#     def fit(self, X_train, y_train, plot=True, plot_title = '', \n",
361     "#             weights=None, callbacks=[], verbose_fit=None, validation_data=None, *args, **kwargs):\n",
362     "#         # verbose_fit argument is for printing out update after each epoch, which gets very long\n",
363     "#         # These print statements at the top could be turned off with a verbose argument, but then\n",
364     "#         # there would be a bunch of different verbose params\n",
365     "#         print(f\"Training simple RNN with params: {self.params}\")\n",
366     "#         X_train, y_train = self.format_train_data(X_train, y_train)\n",
367     "#         print(f\"X_train hash: {hash2(X_train)}\")\n",
368     "#         print(f\"y_train hash: {hash2(y_train)}\")\n",
369     "#         if validation_data is not None:\n",
370     "#             X_val, y_val = self.format_train_data(validation_data[0], validation_data[1])\n",
371     "#             print(f\"X_val hash: {hash2(X_val)}\")\n",
372     "#             print(f\"y_val hash: {hash2(y_val)}\")\n",
373     "#         print(f\"Initial weights before training hash: {hash2(self.model_train.get_weights())}\")\n",
374     "#         # Setup callbacks\n",
375     "#         if self.params[\"reset_states\"]:\n",
376     "#             callbacks=callbacks+[ResetStatesCallback()]\n",
377     "        \n",
378     "#         # Note: we overload the params here so that verbose_fit can be easily turned on/off at the .fit call \n",
379     "#         if verbose_fit is None:\n",
380     "#             verbose_fit = self.params['verbose_fit']\n",
381     "#         # Evaluate Model once to set nonzero initial state\n",
382     "#         if self.params[\"batch_size\"]>= X_train.shape[0]:\n",
383     "#             self.model_train(X_train)\n",
384     "#         if validation_data is not None:\n",
385     "#             history = self.model_train.fit(\n",
386     "#                 X_train, y_train+self.params['centering'][1], \n",
387     "#                 epochs=self.params['epochs'], \n",
388     "#                 batch_size=self.params['batch_size'],\n",
389     "#                 callbacks = callbacks,\n",
390     "#                 verbose=verbose_fit,\n",
391     "#                 validation_data = (X_val, y_val),\n",
392     "#                 *args, **kwargs\n",
393     "#             )\n",
394     "#         else:\n",
395     "#             history = self.model_train.fit(\n",
396     "#                 X_train, y_train+self.params['centering'][1], \n",
397     "#                 epochs=self.params['epochs'], \n",
398     "#                 batch_size=self.params['batch_size'],\n",
399     "#                 callbacks = callbacks,\n",
400     "#                 verbose=verbose_fit,\n",
401     "#                 *args, **kwargs\n",
402     "#             )\n",
403     "#         if plot:\n",
404     "#             self.plot_history(history,plot_title)\n",
405     "#         if self.params[\"verbose_weights\"]:\n",
406     "#             print(f\"Fitted Weights Hash: {hash2(self.model_train.get_weights())}\")\n",
407     "\n",
408     "#         # Update Weights for Prediction Model\n",
409     "#         w_fitted = self.model_train.get_weights()\n",
410     "#         self.model_predict.set_weights(w_fitted)\n",
411     "#     def predict(self, X_test):\n",
412     "#         print(\"Predicting with simple RNN\")\n",
413     "#         X_test = self.format_pred_data(X_test)\n",
414     "#         preds = self.model_predict.predict(X_test).flatten()\n",
415     "#         return preds\n",
416     "\n",
417     "\n",
418     "#     def plot_history(self, history, plot_title):\n",
419     "#         plt.semilogy(history.history['loss'], label='Training loss')\n",
420     "#         if 'val_loss' in history.history:\n",
421     "#             plt.semilogy(history.history['val_loss'], label='Validation loss')\n",
422     "#         plt.title(f'{plot_title} Model loss')\n",
423     "#         plt.ylabel('Loss')\n",
424     "#         plt.xlabel('Epoch')\n",
425     "#         plt.legend(loc='upper left')\n",
426     "#         plt.show()"
427    ]
428   },
429   {
430    "cell_type": "code",
431    "execution_count": null,
432    "id": "59480f19-3567-4b24-b6ff-d9292dc8c2ec",
433    "metadata": {},
434    "outputs": [],
435    "source": [
436     "with open(\"params.yaml\") as file:\n",
437     "    params = yaml.safe_load(file)[\"lstm\"]\n",
438     "    \n",
439     "rnn_dat = create_rnn_data2(train['reproducibility'],params)"
440    ]
441   },
442   {
443    "cell_type": "code",
444    "execution_count": null,
445    "id": "2adff592-7aa4-4e59-a229-cad4a133297e",
446    "metadata": {},
447    "outputs": [],
448    "source": [
449     "params"
450    ]
451   },
452   {
453    "cell_type": "code",
454    "execution_count": null,
455    "id": "6a9d612e-8cd2-40ca-a789-91c99c3d6ccd",
456    "metadata": {},
457    "outputs": [],
458    "source": [
459     "params.update({'epochs': 75})\n",
460     "reproducibility.set_seed()\n",
461     "lstm = RNN_LSTM(params)\n",
462     "m, errs = lstm.run_model(rnn_dat)"
463    ]
464   },
465   {
466    "cell_type": "code",
467    "execution_count": null,
468    "id": "f60a24c6-9a67-45aa-bc5c-8818aa0ca049",
469    "metadata": {},
470    "outputs": [],
471    "source": []
472   }
473  ],
474  "metadata": {
475   "kernelspec": {
476    "display_name": "Python 3 (ipykernel)",
477    "language": "python",
478    "name": "python3"
479   },
480   "language_info": {
481    "codemirror_mode": {
482     "name": "ipython",
483     "version": 3
484    },
485    "file_extension": ".py",
486    "mimetype": "text/x-python",
487    "name": "python",
488    "nbconvert_exporter": "python",
489    "pygments_lexer": "ipython3",
490    "version": "3.9.12"
491   }
492  },
493  "nbformat": 4,
494  "nbformat_minor": 5