Update rnn_workshop.ipynb
[notebooks.git] / fmda / presentations / rnn_loss_tutorial.ipynb
blobc361757adad059755c66be5397ffd42d3da15b5e
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "5f0d2c1e-3b7a-4e59-9c1d-8a2e6b4f7d10",
6    "metadata": {},
7    "source": [
8     "# RNN loss tutorial\n",
9     "\n",
10     "Build a tiny stateful RNN (1 hidden unit, 2 timesteps, batch size 3, 1 epoch), run training data through it, compute the squared error against the targets by hand, then train with `model.fit` and compare the weights before and after training."
11    ]
12   },
13   {
14    "cell_type": "code",
15    "execution_count": null,
16    "id": "2106ddc4-8330-44ca-8a69-742f53796740",
17    "metadata": {},
18    "outputs": [],
19    "source": [
20     "import sys\n",
21     "sys.path.append(\"..\")  # make the project modules in the parent directory importable\n",
22     "import reproducibility\n",
23     "import pandas as pd\n",
24     "from data_funcs import load_and_fix_data\n",
25     "from module_param_sets import param_sets\n",
26     "from moisture_rnn import create_RNN_2, create_rnn_data_1, create_rnn_data_2, train_rnn"
27    ]
28   },
29   {
30    "cell_type": "code",
31    "execution_count": null,
32    "id": "f5b25c8a-e879-4d3f-b144-ae7876aa70c3",
33    "metadata": {},
34    "outputs": [],
35    "source": [
36     "# NOTE: the data file is a pickle; only load files that come from a trusted source\n",
37     "reproducibility_file='../data/reproducibility_dict.pickle'\n",
38     "repro=load_and_fix_data(reproducibility_file)"
39    ]
40   },
41   {
42    "cell_type": "code",
43    "execution_count": null,
44    "id": "4a5a4137-2c9b-4fd9-8ad0-00b773ee99cd",
45    "metadata": {},
46    "outputs": [],
47    "source": [
48     "# Work on a copy: without it, the assignments in the next cell would also modify the shared\n",
49     "# param_sets['0'] dict, so the default parameter set would be silently changed for the whole kernel\n",
50     "params = param_sets['0'].copy()"
51    ]
52   },
53   {
54    "cell_type": "code",
55    "execution_count": null,
56    "id": "b93faa27-c0f8-46cb-9084-7405f93470ea",
57    "metadata": {},
58    "outputs": [],
59    "source": [
60     "# Simplify params: a tiny model and a single epoch keep every step easy to inspect by hand\n",
61     "params['batch_size']=3\n",
62     "params[\"timesteps\"]=2\n",
63     "params[\"epochs\"]=1\n",
64     "params[\"initialize\"]=False\n",
65     "params[\"hidden_units\"]=1\n",
66     "params[\"rain_do\"]=False\n",
67     "params"
68    ]
69   },
66   {
67    "cell_type": "code",
68    "execution_count": null,
69    "id": "0be61c7d-59c4-4b3c-afec-6393a9289439",
70    "metadata": {},
71    "outputs": [],
72    "source": [
73     "# Format Data: build the RNN input/target arrays for case11 from the simplified params\n",
74     "case_data = repro[\"case11\"]\n",
75     "reproducibility.set_seed() # Set seed for reproducibility\n",
76     "rnn_dat = create_rnn_data_1(case_data,params)\n",
77     "# presumably adds the model-ready arrays (e.g. x_train, y_train) to rnn_dat in place — TODO confirm\n",
78     "create_rnn_data_2(rnn_dat,params)"
79    ]
80   },
81   {
82    "cell_type": "code",
83    "execution_count": null,
84    "id": "d1ced401-0b81-4404-b8dd-3bbb756be244",
85    "metadata": {},
86    "outputs": [],
87    "source": [
88     "# Number of input features = number of columns of X\n",
89     "features = rnn_dat[\"X\"].shape[1]\n",
90     "features"
91    ]
92   },
91   },
92   {
93    "cell_type": "code",
94    "execution_count": null,
95    "id": "26d61ce5-f492-4d20-88d8-f5bbd41f31a4",
96    "metadata": {},
97    "outputs": [],
98    "source": [
99     "# Setup Model\n",
100     "reproducibility.set_seed()\n",
101     "model = create_RNN_2(\n",
102     "    hidden_units=params[\"hidden_units\"], \n",
103     "    dense_units=1, \n",
104     "    activation=params[\"activation\"],\n",
105     "    batch_shape=(params[\"batch_size\"],params[\"timesteps\"],features),\n",
106     "    stateful=True\n",
107     ")\n",
108     "# Print initial weights\n",
109     "model.get_weights()"
110    ]
111   },
112   {
113    "cell_type": "code",
114    "execution_count": null,
115    "id": "bca3b758-dde0-491b-9566-4a868f45286c",
116    "metadata": {},
117    "outputs": [],
118    "source": [
119     "# Run one batch through the model. The model is stateful with a fixed batch_shape, so every call\n",
120     "# must supply exactly params[\"batch_size\"] samples; a single sample would not match that shape.\n",
121     "X = rnn_dat[\"x_train\"][:params[\"batch_size\"]].reshape(-1,params[\"timesteps\"],features)\n",
122     "assert X.shape[0] == params[\"batch_size\"], \"x_train has fewer samples than one batch\"\n",
123     "X"
124    ]
125   },
126   {
127    "cell_type": "code",
128    "execution_count": null,
129    "id": "0e8b055c-426e-430c-ace8-87ea482876c7",
130    "metadata": {},
131    "outputs": [],
132    "source": [
133     "preds = model.predict(X)\n",
134     "preds"
135    ]
136   },
137   {
138    "cell_type": "code",
139    "execution_count": null,
140    "id": "fa267c9b-e76f-46c7-a10a-d3434c913aaa",
141    "metadata": {},
142    "outputs": [],
143    "source": [
144     "# Targets for the same samples, shaped like the predictions so the error below is elementwise\n",
145     "y = rnn_dat[\"y_train\"][:params[\"batch_size\"]].reshape(preds.shape)"
146    ]
147   },
148   {
149    "cell_type": "code",
150    "execution_count": null,
151    "id": "9d11f8c4-d8c9-4665-8c93-6fd887b864d7",
152    "metadata": {},
153    "outputs": [],
154    "source": [
155     "# Calculate MSE: the mean of the squared errors over every sample and output\n",
156     "squared_errors = (y - preds)**2\n",
157     "mse = squared_errors.mean()\n",
158     "mse"
159    ]
160   },
156   {
157    "cell_type": "code",
158    "execution_count": null,
159    "id": "d6b9b56e-52b7-4636-a524-ab310c059fe7",
160    "metadata": {},
161    "outputs": [],
162    "source": [
163     "# The model's symbolic output tensor; its repr shows the output shape.\n",
164     "# TODO: use the loss calculated above to update the weights manually (not implemented yet)\n",
165     "model.output"
166    ]
167   },
168   {
169    "cell_type": "code",
170    "execution_count": null,
171    "id": "c3c47f67-52a2-4c13-ac5b-d1fc0de077bd",
172    "metadata": {},
173    "outputs": [],
174    "source": [
175     "# Current learning rate as a plain number. On a TF2 variable `.value` is a method, so displaying it\n",
176     "# shows a bound method rather than the rate; .numpy() reads the actual value.\n",
177     "# (Assumes the learning rate is a fixed value, not a LearningRateSchedule.)\n",
178     "model.optimizer.learning_rate.numpy()"
179    ]
180   },
176   },
177   {
178    "cell_type": "code",
179    "execution_count": null,
180    "id": "d373fa86-2dc8-4a80-8882-cacc0db281c6",
181    "metadata": {},
182    "outputs": [],
183    "source": [
184     "reproducibility.set_seed()\n",
185     "history = model.fit(X, \n",
186     "            y, \n",
187     "            epochs=params[\"epochs\"], \n",
188     "            batch_size=params[\"batch_size\"])"
189    ]
190   },
191   {
192    "cell_type": "code",
193    "execution_count": null,
194    "id": "34bfa036-dbed-4fe9-a4cb-9e90b2de9185",
195    "metadata": {},
196    "outputs": [],
197    "source": [
198     "history.history"
199    ]
200   },
201   {
202    "cell_type": "code",
203    "execution_count": null,
204    "id": "ceac8930-bcc5-4a06-b929-7c5c943d6581",
205    "metadata": {},
206    "outputs": [],
207    "source": [
208     "# Print Trained Weights\n",
209     "model.get_weights()"
210    ]
211   },
212   {
213    "cell_type": "code",
214    "execution_count": null,
215    "id": "3624a803-a45b-4314-b3d9-6628b1d2161b",
216    "metadata": {},
217    "outputs": [],
218    "source": []
219   }
220  ],
221  "metadata": {
222   "kernelspec": {
223    "display_name": "Python 3 (ipykernel)",
224    "language": "python",
225    "name": "python3"
226   },
227   "language_info": {
228    "codemirror_mode": {
229     "name": "ipython",
230     "version": 3
231    },
232    "file_extension": ".py",
233    "mimetype": "text/x-python",
234    "name": "python",
235    "nbconvert_exporter": "python",
236    "pygments_lexer": "ipython3",
237    "version": "3.9.12"
238   }
239  },
240  "nbformat": 4,
241  "nbformat_minor": 5