{
2  "cells": [
3   {
4    "cell_type": "code",
5    "execution_count": null,
6    "id": "8530bc7e-61ae-4463-a14f-d5eb42f0b83e",
7    "metadata": {},
8    "outputs": [],
9    "source": [
10     "# Environment\n",
11     "import numpy as np\n",
12     "import pandas as pd\n",
13     "import tensorflow as tf\n",
14     "import matplotlib.pyplot as plt\n",
15     "import sys\n",
16     "# Local modules\n",
17     "sys.path.append('..')\n",
18     "from moisture_rnn import RNN\n",
19     "import reproducibility\n",
20     "from utils import print_dict_summary\n",
21     "from data_funcs import load_and_fix_data, rmse\n",
22     "from moisture_rnn0 import run_case\n",
23     "from moisture_rnn_pkl import pkl2train\n",
24     "from utils import hash2"
25    ]
26   },
27   {
28    "cell_type": "markdown",
29    "id": "63d275da-b13a-405e-9e1a-aa3f972119b5",
30    "metadata": {},
31    "source": [
32     "### Reproducibility Datasets"
33    ]
34   },
35   {
36    "cell_type": "code",
37    "execution_count": null,
38    "id": "2dfa08fa-01c9-4fd7-927b-541b0a532e4f",
39    "metadata": {},
40    "outputs": [],
41    "source": [
42     "# Original File\n",
43     "reproducibility_file='reproducibility_dict0.pickle'\n",
44     "\n",
45     "repro={}\n",
46     "repro.update(load_and_fix_data(reproducibility_file))\n",
47     "print_dict_summary(repro)"
48    ]
49   },
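  {
   "cell_type": "markdown",
   "id": "added-repro-spotcheck-md",
   "metadata": {},
   "source": [
    "Quick spot-check of the loaded dictionary (a sketch, not part of the original workflow): it assumes `case11` and the fields used later in this notebook (`fm`, `Ed`, `Ew`, `h2`) are present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-repro-spotcheck",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one case from the loaded reproducibility dictionary.\n",
    "# Assumes 'case11' with the fields used later in this notebook.\n",
    "case = 'case11'\n",
    "print(list(repro[case].keys()))\n",
    "for key in ['fm', 'Ed', 'Ew']:\n",
    "    print(f\"{key}: length {len(repro[case][key])}\")\n",
    "print(f\"h2 (train/test split index): {repro[case].get('h2')}\")"
   ]
  },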
50   {
51    "cell_type": "code",
52    "execution_count": null,
53    "id": "f5f4c46a-6e9f-46a0-86e7-3503b83d1816",
54    "metadata": {},
55    "outputs": [],
56    "source": [
57     "# Restructured original file\n",
58     "reproducibility_file='../data/reproducibility_dict2.pickle'\n",
59     "repro2 = pkl2train([reproducibility_file])\n",
60     "print_dict_summary(repro2)"
61    ]
62   },
63   {
64    "cell_type": "markdown",
65    "id": "87f59db1-fa7b-44f4-bbe4-7e49221226e9",
66    "metadata": {},
67    "source": [
68     "## RNN with Stateful Batch Training\n"
69    ]
70   },
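  {
   "cell_type": "markdown",
   "id": "added-stateful-sketch-md",
   "metadata": {},
   "source": [
    "Before the custom class, a minimal toy sketch (random data, not the fmda case) of what stateful batch training means in Keras: with `stateful=True` the hidden state carries over between consecutive batches, so it has to be reset explicitly, which is what the `ResetStatesCallback` defined later in this notebook is for."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-stateful-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of a stateful SimpleRNN (random data, not the fmda case).\n",
    "# With stateful=True the hidden state persists across batches and must be\n",
    "# reset explicitly between epochs/runs.\n",
    "toy_batch, toy_steps, toy_feats = 4, 5, 2\n",
    "toy_inputs = tf.keras.Input(batch_shape=(toy_batch, toy_steps, toy_feats))\n",
    "toy_x = tf.keras.layers.SimpleRNN(3, stateful=True)(toy_inputs)\n",
    "toy_outputs = tf.keras.layers.Dense(1)(toy_x)\n",
    "toy_model = tf.keras.Model(toy_inputs, toy_outputs)\n",
    "toy_model.compile(loss='mean_squared_error', optimizer='adam')\n",
    "X_toy = np.random.rand(toy_batch, toy_steps, toy_feats)\n",
    "y_toy = np.random.rand(toy_batch, 1)\n",
    "toy_model.fit(X_toy, y_toy, epochs=2, batch_size=toy_batch, verbose=0)\n",
    "toy_model.reset_states()  # clear the carried hidden state"
   ]
  },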
71   {
72    "cell_type": "markdown",
73    "id": "d135d6e5-505c-474b-8ed6-b70a37691b2c",
74    "metadata": {},
75    "source": [
76     "### Custom Class"
77    ]
78   },
79   {
80    "cell_type": "code",
81    "execution_count": null,
82    "id": "4ede9568-cbfa-41b1-ab93-69155344383a",
83    "metadata": {},
84    "outputs": [],
85    "source": [
86     "from moisture_rnn import create_rnn_data2, RNN\n",
87     "import logging\n",
88     "from utils import logging_setup\n",
89     "logging_setup()"
90    ]
91   },
92   {
93    "cell_type": "code",
94    "execution_count": null,
95    "id": "0040bb22-ce88-43b7-aead-36a26bc59ad0",
96    "metadata": {},
97    "outputs": [],
98    "source": []
99   },
100   {
101    "cell_type": "code",
102    "execution_count": null,
103    "id": "e84f8219-3f71-43d2-b159-4f200a7cf1c0",
104    "metadata": {},
105    "outputs": [],
106    "source": [
107     "import yaml\n",
108     "\n",
109     "with open(\"../params.yaml\") as file:\n",
110     "    params = yaml.safe_load(file)[\"rnn_repro\"]\n",
111     "params.update({'scale': 1})\n",
112     "params"
113    ]
114   },
115   {
116    "cell_type": "code",
117    "execution_count": null,
118    "id": "749a4a1c-e3b7-4a6c-9305-ed2349ca7647",
119    "metadata": {},
120    "outputs": [],
121    "source": [
122     "rnn_dat = create_rnn_data2(repro2[\"reproducibility\"], params)"
123    ]
124   },
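  {
   "cell_type": "markdown",
   "id": "added-rnn-dat-inspect-md",
   "metadata": {},
   "source": [
    "A quick look at what `create_rnn_data2` returned; the exact contents depend on `moisture_rnn`, so this only prints the top-level structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-rnn-dat-inspect",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the object returned by create_rnn_data2 (contents depend on moisture_rnn)\n",
    "print(type(rnn_dat))\n",
    "if isinstance(rnn_dat, dict):\n",
    "    print_dict_summary(rnn_dat)"
   ]
  },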
125   {
126    "cell_type": "code",
127    "execution_count": null,
128    "id": "458ea3bb-2b19-42b3-8ac8-fdec96c6d315",
129    "metadata": {},
130    "outputs": [],
131    "source": [
132     "reproducibility.set_seed()\n",
133     "rnn = RNN(params)\n",
134     "rnn.run_model(rnn_dat)"
135    ]
136   },
137   {
138    "cell_type": "markdown",
139    "id": "a5ee76d9-2a0f-4bca-9fef-5d0f9361a02d",
140    "metadata": {},
141    "source": [
142     "### Using Old Code with `run_case`"
143    ]
144   },
145   {
146    "cell_type": "code",
147    "execution_count": null,
148    "id": "6875a379-bf28-438a-b7fc-d5f81879b26f",
149    "metadata": {},
150    "outputs": [],
151    "source": [
152     "from module_param_sets0 import param_sets\n",
153     "params = param_sets['0']"
154    ]
155   },
156   {
157    "cell_type": "code",
158    "execution_count": null,
159    "id": "29596ede-689a-4629-8f07-412289e51122",
160    "metadata": {},
161    "outputs": [],
162    "source": [
163     "reproducibility.set_seed()\n",
164     "params['initialize']=False \n",
165     "case = 'case11'\n",
166     "case_data=repro[case]\n",
167     "case_data['h2']=427\n",
168     "run_case(case_data,params)"
169    ]
170   },
171   {
172    "cell_type": "code",
173    "execution_count": null,
174    "id": "4303e754-581c-4ffd-a81b-4b9064d281a1",
175    "metadata": {},
176    "outputs": [],
177    "source": []
178   },
179   {
180    "cell_type": "markdown",
181    "id": "1cad6dcd-7a0a-4afb-aec7-6cd5aa7ee568",
182    "metadata": {},
183    "source": [
184     "## Original Case - Single Batch"
185    ]
186   },
187   {
188    "cell_type": "code",
189    "execution_count": null,
190    "id": "db563737-4aa8-409a-a3f2-86495678ba2b",
191    "metadata": {},
192    "outputs": [],
193    "source": [
194     "# NOTE: the original param sets live in module_param_sets0 but are commented out there, so they are reproduced manually here\n",
195     "param_sets_ORIG = {'id':0,\n",
196     "        'purpose':'reproducibility',\n",
197     "        'batch_size':np.inf,\n",
198     "        'training':None,\n",
199     "        'cases':['case11'],\n",
200     "        'scale':0,\n",
201     "        'rain_do':False,\n",
202     "#        'verbose':False,\n",
203     "        'verbose':1,\n",
204     "        'timesteps':5,\n",
205     "        'activation':['linear','linear'],\n",
206     "        'centering':[0.0,0.0],\n",
207     "        'hidden_units':6,\n",
208     "        'dense_units':1,\n",
209     "        'dense_layers':1,\n",
210     "        'DeltaE':[0,-1],    # the -1.0 corrects the E bias but is applied at the end\n",
211     "        'synthetic':False,  # run also synthetic cases\n",
212     "        'T1': 0.1,          # 1/fuel class (10)\n",
213     "        'fm_raise_vs_rain': 2.0,         # fm increase per mm rain                              \n",
214     "        'epochs':5000,\n",
215     "        'verbose_fit':0,\n",
216     "        'verbose_weights':True,\n",
217     "        'note':'check 5 should give zero error'\n",
218     "        }"
219    ]
220   },
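  {
   "cell_type": "markdown",
   "id": "added-param-compare-md",
   "metadata": {},
   "source": [
    "Sanity check (sketch): print the manually reproduced `param_sets_ORIG` side by side with `param_sets['0']` loaded earlier from `module_param_sets0`, for whatever keys the two dictionaries share."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-param-compare",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side view of shared keys between the reproduced and imported param sets\n",
    "for k in sorted(set(param_sets_ORIG) & set(params)):\n",
    "    print(f\"{k}: ORIG={param_sets_ORIG[k]!r}  param_sets['0']={params[k]!r}\")"
   ]
  },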
221   {
222    "cell_type": "markdown",
223    "id": "785d0e34-35c7-4a8d-802d-4a176e2c19bd",
224    "metadata": {},
225    "source": [
226     "### Using Old RNN Code\n",
227     "\n",
228     "Code is deployed through the `run_case` function."
229    ]
230   },
231   {
232    "cell_type": "code",
233    "execution_count": null,
234    "id": "49cb8274-7c1a-40c0-ad81-d1fbcb498884",
235    "metadata": {},
236    "outputs": [],
237    "source": [
238     "reproducibility.set_seed()\n",
239     "print('Running reproducibility')\n",
240     "assert param_sets_ORIG['purpose'] == 'reproducibility'\n",
241     "param_sets_ORIG['initialize']=False \n",
242     "case = 'case11'\n",
243     "case_data=repro[case]\n",
244     "case_data[\"h2\"]=300\n",
245     "run_case(case_data,param_sets_ORIG)"
246    ]
247   },
248   {
249    "cell_type": "markdown",
250    "id": "6ddef6dc-c56b-4257-b603-783001a94262",
251    "metadata": {},
252    "source": [
253     "### Reproduce with Class Code\n",
254     "\n",
255     "Code is deployed through the custom class defined below, and parameters come from the yaml file."
256    ]
257   },
258   {
259    "cell_type": "code",
260    "execution_count": null,
261    "id": "05cc8cf3-48db-44df-9649-4a453a271957",
262    "metadata": {},
263    "outputs": [],
264    "source": [
265     "from tensorflow.keras.callbacks import Callback\n",
266     "from abc import ABC, abstractmethod\n",
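    "# Callback to reset the stateful RNN's hidden state at the end of every epoch\n",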
267     "class ResetStatesCallback(Callback):\n",
268     "    def on_epoch_end(self, epoch, logs=None):\n",
269     "        self.model.reset_states()\n",
270     "        \n",
271     "from sklearn.metrics import mean_squared_error\n",
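    "# note: this redefines the rmse imported from data_funcs at the top of the notebook\n",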
272     "def rmse(a, b):\n",
273     "    return np.sqrt(mean_squared_error(a.flatten(), b.flatten()))\n",
274     "\n",
275     "\n",
276     "class RNNModel(ABC):\n",
277     "    def __init__(self, params: dict):\n",
278     "        self.params = params\n",
279     "        if type(self) is RNNModel:\n",
280     "            raise TypeError(\"RNNModel is an abstract class and cannot be instantiated directly\")\n",
281     "        super().__init__()\n",
282     "\n",
283     "    @abstractmethod\n",
284     "    def fit(self, X_train, y_train, weights=None):\n",
285     "        pass\n",
286     "\n",
287     "    @abstractmethod\n",
288     "    def predict(self, X):\n",
289     "        pass\n",
290     "\n",
291     "class RNN(RNNModel):\n",
292     "    def __init__(self, params, loss='mean_squared_error'):\n",
293     "        super().__init__(params)\n",
294     "        self.model_train = self._build_model_train()\n",
295     "        self.model_predict = self._build_model_predict()\n",
296     "        # self.compile_model()\n",
297     "\n",
298     "    def _build_model_train(self, return_sequences=False):\n",
299     "        inputs = tf.keras.Input(batch_shape=self.params['batch_shape'])\n",
300     "        x = inputs\n",
301     "        for i in range(self.params['rnn_layers']):\n",
302     "            x = tf.keras.layers.SimpleRNN(self.params['rnn_units'],activation=self.params['activation'][0],\n",
303     "                  stateful=self.params['stateful'],return_sequences=return_sequences)(x)\n",
304     "        for i in range(self.params['dense_layers']):\n",
305     "            x = tf.keras.layers.Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)\n",
306     "        model = tf.keras.Model(inputs=inputs, outputs=x)\n",
307     "        model.compile(loss='mean_squared_error', optimizer='adam')\n",
308     "        if self.params[\"verbose_weights\"]:\n",
309     "            print(f\"Initial Weights Hash: {hash2(model.get_weights())}\")\n",
310     "        \n",
311     "        return model\n",
312     "    def _build_model_predict(self, return_sequences=True):\n",
313     "        \n",
314     "        inputs = tf.keras.Input(shape=self.params['pred_input_shape'])\n",
315     "        x = inputs\n",
316     "        for i in range(self.params['rnn_layers']):\n",
317     "            x = tf.keras.layers.SimpleRNN(self.params['rnn_units'],activation=self.params['activation'][0],\n",
318     "                  stateful=False,return_sequences=return_sequences)(x)\n",
319     "        for i in range(self.params['dense_layers']):\n",
320     "            x = tf.keras.layers.Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)\n",
321     "        model = tf.keras.Model(inputs=inputs, outputs=x)\n",
322     "        model.compile(loss='mean_squared_error', optimizer='adam')  \n",
323     "\n",
324     "        # Copy the current weights from model_train so the prediction model matches it\n",
325     "        w_fitted = self.model_train.get_weights()\n",
326     "        model.set_weights(w_fitted)\n",
327     "        \n",
328     "        return model\n",
329     "\n",
330     "    def fit(self, X_train, y_train, plot=True, weights=None, callbacks=[], verbose_fit=None):\n",
331     "        # verbose_fit controls the per-epoch progress printout from keras, which gets very long\n",
332     "        # The print statements below could be gated behind a separate verbose flag,\n",
333     "        # but that would multiply the number of verbose parameters\n",
334     "        print(f\"Training simple RNN with params: {self.params}\")\n",
335     "        print(f\"X_train hash: {hash2(X_train)}\")\n",
336     "        print(f\"y_train hash: {hash2(y_train)}\")\n",
337     "        print(f\"Initial weights before training hash: {hash2(self.model_train.get_weights())}\")\n",
338     "        # Note: the argument overrides the value in self.params so verbose_fit can be toggled at the .fit call\n",
339     "        if verbose_fit is None:\n",
340     "            verbose_fit = self.params['verbose_fit']\n",
341     "        # Evaluate Model once to set nonzero initial state\n",
342     "        if self.params[\"batch_size\"]>= X_train.shape[0]:\n",
343     "            self.model_train(X_train)\n",
344     "        # Fit Model\n",
345     "        history = self.model_train.fit(\n",
346     "            X_train, y_train+self.params['centering'][1], \n",
347     "            epochs=self.params['epochs'], \n",
348     "            batch_size=self.params['batch_size'],\n",
349     "            callbacks = callbacks,\n",
350     "            verbose=verbose_fit)\n",
351     "        if plot:\n",
352     "            self.plot_history(history)\n",
353     "        if self.params[\"verbose_weights\"]:\n",
354     "            print(f\"Fitted Weights Hash: {hash2(self.model_train.get_weights())}\")\n",
355     "\n",
356     "        # Update Weights for Prediction Model\n",
357     "        w_fitted = self.model_train.get_weights()\n",
358     "        self.model_predict.set_weights(w_fitted)\n",
359     "    def predict(self, X_test):\n",
360     "        print(\"Predicting with simple RNN\")\n",
361     "        preds = self.model_predict.predict(X_test)\n",
362     "        return preds\n",
363     "    def plot_history(self, history):\n",
364     "        plt.semilogy(history.history['loss'], label='Training loss')\n",
365     "        if 'val_loss' in history.history:\n",
366     "            plt.semilogy(history.history['val_loss'], label='Validation loss')\n",
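    "        # note: relies on the global variable case defined elsewhere in the notebook\n",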
367     "        plt.title(case + ' Model loss')\n",
368     "        plt.ylabel('Loss')\n",
369     "        plt.xlabel('Epoch')\n",
370     "        plt.legend(loc='upper left')\n",
371     "        plt.show()\n"
372    ]
373   },
374   {
375    "cell_type": "code",
376    "execution_count": null,
377    "id": "73243595-fa7c-426e-a4ca-ee3395e3e740",
378    "metadata": {},
379    "outputs": [],
380    "source": [
381     "import yaml\n",
382     "\n",
383     "with open(\"../params.yaml\") as file:\n",
384     "    params = yaml.safe_load(file)[\"rnn\"]\n",
385     "\n",
386     "params.update({\n",
387     "    'dropout': [0, 0], # NOTE: length must match total number of layers, default is 1 hidden recurrent layer and 1 dense output layer\n",
388     "    'recurrent_dropout': 0, # Length must match number of recurrent layers\n",
389     "})"
390    ]
391   },
392   {
393    "cell_type": "code",
394    "execution_count": null,
395    "id": "9507496c-d9d1-4b18-a05f-af3143d2f9e6",
396    "metadata": {},
397    "outputs": [],
398    "source": [
399     "N = len(repro[case][\"fm\"]) # total observations\n",
400     "train_ind = repro[case]['h2']\n",
401     "\n",
402     "X = np.vstack((repro[case][\"Ed\"], repro[case][\"Ew\"])).T\n",
403     "y = repro[case][\"fm\"]\n",
404     "\n",
405     "X_train = X[:train_ind]\n",
406     "X_test = X[train_ind:]\n",
407     "y_train = y[:train_ind].reshape(-1,1)\n",
408     "y_test = y[train_ind:].reshape(-1,1)\n",
409     "\n",
410     "print(f\"Total Observations: {N}\")\n",
411     "print(f\"Num Training: {X_train.shape[0]}\")\n",
412     "print(f\"Num Test: {X_test.shape[0]}\")\n",
413     "\n",
414     "from moisture_rnn import staircase\n",
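    "# staircase restructures the flat training series into (samples, timesteps, features) sequences for the RNN\n",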
415     "X_train, y_train = staircase(X_train, y_train, timesteps = params[\"timesteps\"], datapoints = len(y_train), verbose=True)\n",
416     "print(\"~\"*50)\n",
417     "phours, features = X_test.shape\n",
418     "X_test = np.reshape(X_test,(1, phours, features))\n",
419     "print(f\"X_test shape: {X_test.shape}\")"
420    ]
421   },
422   {
423    "cell_type": "code",
424    "execution_count": null,
425    "id": "77b675d8-ac1c-473f-acf5-6682bfb2e2dd",
426    "metadata": {},
427    "outputs": [],
428    "source": [
429     "samples, timesteps, features = X_train.shape\n",
430     "batch_size = samples # Single batch for testing\n",
431     "\n",
432     "params.update({\n",
433     "    'batch_shape': (batch_size,timesteps,features),\n",
434     "    'batch_size': batch_size, # Single Batch for testing\n",
435     "    'pred_input_shape': (X.shape[0], X.shape[1]),\n",
436     "    'epochs': 5000,\n",
437     "    'stateful': True,\n",
438     "    'features': features\n",
439     "})"
440    ]
441   },
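  {
   "cell_type": "markdown",
   "id": "added-shape-check-md",
   "metadata": {},
   "source": [
    "Consistency check (sketch): for this single-batch stateful setup, `batch_shape` should match `X_train` exactly and `pred_input_shape` should cover the full series `X`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-shape-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The training batch shape must equal X_train's shape for single-batch stateful training,\n",
    "# and the prediction input shape must equal the full series shape.\n",
    "assert params['batch_shape'] == X_train.shape\n",
    "assert params['pred_input_shape'] == X.shape\n",
    "print(f\"batch_shape: {params['batch_shape']}, pred_input_shape: {params['pred_input_shape']}\")"
   ]
  },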
442   {
443    "cell_type": "code",
444    "execution_count": null,
445    "id": "0523b883-08e1-4158-8e07-27466f2588d5",
446    "metadata": {},
447    "outputs": [],
448    "source": [
449     "reproducibility.set_seed()\n",
450     "rnn = RNN(params)\n",
451     "m = rnn.predict(np.reshape(X,(1, X.shape[0], features)))\n",
452     "print(hash2(m))\n",
453     "rnn.fit(X_train, y_train)"
454    ]
455   },
456   {
457    "cell_type": "code",
458    "execution_count": null,
459    "id": "773c5c20-1278-47e7-9b69-08065fb7bc8b",
460    "metadata": {},
461    "outputs": [],
462    "source": [
463     "preds = rnn.predict(np.reshape(X,(1, X.shape[0], features)))\n",
464     "rmse(preds, y)"
465    ]
466   },
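  {
   "cell_type": "markdown",
   "id": "added-pred-plot-md",
   "metadata": {},
   "source": [
    "A visual check to go with the RMSE above (a sketch; assumes `preds` has shape `(1, N, 1)` as returned by the prediction model, and uses the `train_ind` split defined earlier)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-pred-plot",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot predicted vs observed fuel moisture with the train/test boundary marked\n",
    "plt.figure(figsize=(10, 4))\n",
    "plt.plot(y, label='observed fm')\n",
    "plt.plot(preds.squeeze(), label='predicted fm')\n",
    "plt.axvline(train_ind, linestyle='--', color='k', label='train/test split')\n",
    "plt.xlabel('Hour')\n",
    "plt.ylabel('Fuel moisture')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },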
467   {
468    "cell_type": "code",
469    "execution_count": null,
470    "id": "9bfbd951-c89a-4b69-be94-7ac4783f5eb7",
471    "metadata": {},
472    "outputs": [],
473    "source": []
474   }
475  ],
476  "metadata": {
477   "kernelspec": {
478    "display_name": "Python 3 (ipykernel)",
479    "language": "python",
480    "name": "python3"
481   },
482   "language_info": {
483    "codemirror_mode": {
484     "name": "ipython",
485     "version": 3
486    },
487    "file_extension": ".py",
488    "mimetype": "text/x-python",
489    "name": "python",
490    "nbconvert_exporter": "python",
491    "pygments_lexer": "ipython3",
492    "version": "3.9.12"
493   }
494  },
495  "nbformat": 4,
496  "nbformat_minor": 5