Merge pull request #17 from openwfm/restructure
[notebooks.git] / fmda / OLD / results_summary.ipynb
blob68e0701d2f35cab76a000e6732537b7f6871b07e
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "a2442d5d-ad18-422a-8972-ac877a5d7772",
6    "metadata": {},
7    "source": [
8     "# Results Analysis"
9    ]
10   },
11   {
12    "cell_type": "markdown",
13    "id": "b121638e-72c5-4199-b076-5040b30e521e",
14    "metadata": {},
15    "source": [
16     "The purpose of this notebook is to analyze the results from the RNN training experiments.\n",
17     "\n",
18     "Unless otherwise stated, we will refer to the trained RNN as \"the RNN\". All validation numbers are RMSE."
19    ]
20   },
21   {
22    "cell_type": "markdown",
23    "id": "b9fcba93-c2ae-4838-a867-c803b9674d43",
24    "metadata": {},
25    "source": [
26     "## Environment Setup"
27    ]
28   },
29   {
30    "cell_type": "code",
31    "execution_count": null,
32    "id": "a9a5579c-007b-49fd-a022-0463f430a854",
33    "metadata": {},
34    "outputs": [],
35    "source": [
36     "import numpy as np\n",
37     "import matplotlib.pyplot as plt\n",
38     "import pandas as pd\n",
39     "import seaborn as sns\n",
40     "\n",
41     "from data_funcs import from_json"
42    ]
43   },
44   {
45    "cell_type": "code",
46    "execution_count": null,
47    "id": "1693cc97-a154-4a1f-8dbb-724b4669abb7",
48    "metadata": {},
49    "outputs": [],
50    "source": [
51     "results = from_json(\"data/output.json\")"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "adef561c-e182-4aa8-ab5a-a55a3a9a91c6",
57    "metadata": {},
58    "source": [
59     "## Control Case\n",
60     "\n",
61     "Case 11 with param set 0 was the reproducibility case."
62    ]
63   },
64   {
65    "cell_type": "code",
66    "execution_count": null,
67    "id": "4031235e-b679-46f2-984a-9529167c3137",
68    "metadata": {},
69    "outputs": [],
70    "source": [
71     "pd.DataFrame(results['0']['cases']['case11'])"
72    ]
73   },
74   {
75    "cell_type": "markdown",
76    "id": "628d1153-b085-42cd-be06-e1a73b3b539a",
77    "metadata": {},
78    "source": [
79     "The RNN outperforms the KF in this case. Note the prediction RMSE is lower than the training RMSE for the RNN, indicating there are not signs of overfitting. The KF, by contrast, has very low training error but a prediction error over 3x larger."
80    ]
81   },
82   {
83    "cell_type": "markdown",
84    "id": "220ea3e3-f266-4ec5-9738-9f9fb212524f",
85    "metadata": {},
86    "source": [
87     "## Summarise Results\n",
88     "\n",
89     "### Param Set Descriptions"
90    ]
91   },
92   {
93    "cell_type": "code",
94    "execution_count": null,
95    "id": "ea73ac10-7b75-4819-a4e5-78195cfc1655",
96    "metadata": {},
97    "outputs": [],
98    "source": [
99     "for i in range(1, len(results)):\n",
100     "    print('~'*50)\n",
101     "    print(results[str(i)]['params'])"
102    ]
103   },
104   {
105    "cell_type": "markdown",
106    "id": "e27d6191-98f6-4284-aedf-6137ddf1f183",
107    "metadata": {},
108    "source": [
109     "The main differences in these param sets are:\n",
110     "\n",
111     "* Activation functions: linear for the first case, then tanh, and then sigmoid.\n",
112     "* Epochs: 1,000 for set 1 versus 10,000 for sets 2 and 3\n",
113     "* Scaling: 1, .8, .8\n",
114     "* Centering: 0, 0, .5\n",
115     "\n",
116     "Each param set was run on 7 cases:"
117    ]
118   },
119   {
120    "cell_type": "markdown",
121    "id": "53b80168-c2a7-4202-bb17-36703523766d",
122    "metadata": {},
123    "source": [
124     "### Extract Results\n",
125     "\n",
126     "Excluding param set 0, as that was only run on case 11."
127    ]
128   },
129   {
130    "cell_type": "code",
131    "execution_count": null,
132    "id": "6d0d3878-957f-44cc-afca-287b6fa327f2",
133    "metadata": {},
134    "outputs": [],
135    "source": [
136     "for i in range(1, len(results)):\n",
137     "    print('~'*50)\n",
138     "    print(results[str(i)]['cases'].keys())"
139    ]
140   },
141   {
142    "cell_type": "markdown",
143    "id": "5fc6c35c-490f-4c33-9f9a-4d3a6d0136b9",
144    "metadata": {},
145    "source": [
146     "We summarise the RMSE for the param sets:"
147    ]
148   },
149   {
150    "cell_type": "markdown",
151    "id": "67acb399-b14e-4f88-816b-b26ef540c9a8",
152    "metadata": {},
153    "source": [
154     "Each case has 9 RMSE values:"
155    ]
156   },
157   {
158    "cell_type": "code",
159    "execution_count": null,
160    "id": "8edad271-93a0-4fb4-991e-b0d71096a449",
161    "metadata": {},
162    "outputs": [],
163    "source": [
164     "pd.DataFrame(results[str(1)]['cases']['case10'])"
165    ]
166   },
167   {
168    "cell_type": "markdown",
169    "id": "cb62add6-6f16-4e19-97b3-b5e285cfc928",
170    "metadata": {},
171    "source": [
172     "We next build a long-format dataframe with all of the results from the results dictionary. There are 3 param sets, 3 models, 3 time periods, and 7 cases. So we expect a dataframe of $3\\cdot3\\cdot3\\cdot7=189$ rows"
173    ]
174   },
175   {
176    "cell_type": "code",
177    "execution_count": null,
178    "id": "50d8743d-790d-4d73-9a9e-b641f1fb105a",
179    "metadata": {},
180    "outputs": [],
181    "source": [
182     "df = pd.DataFrame(columns=['Period', 'Case', 'RMSE', 'Model'])\n",
183     "for i in range(1, len(results)):\n",
184     "    for case in results[str(i)]['cases']:\n",
185     "        df_temp = pd.DataFrame(results[str(i)]['cases'][case])\n",
186     "        df_temp=df_temp.rename_axis(\"Period\").reset_index()\n",
187     "        df_temp['Case']=np.repeat(case, 3)\n",
188     "        df_temp['param_set']=np.repeat(int(i), 3)\n",
189     "        df_temp=pd.melt(df_temp, id_vars=['Period', 'Case', 'param_set'], value_vars=['Augmented KF', 'RNN initial', 'RNN trained'],\n",
190     "                     var_name='Model', value_name='RMSE')\n",
191     "        df = pd.concat((df, df_temp))\n",
192     "\n",
193     "df"
194    ]
195   },
196   {
197    "cell_type": "markdown",
198    "id": "a738a933-e880-412c-ba0f-75f0f7548d9d",
199    "metadata": {},
200    "source": [
201     "### Results by Param Set\n",
202     "\n",
203     "Excluding RNN initial."
204    ]
205   },
206   {
207    "cell_type": "code",
208    "execution_count": null,
209    "id": "6be03ccb-02e8-412d-803d-113a46cee9d6",
210    "metadata": {
211     "scrolled": true
212    },
213    "outputs": [],
214    "source": [
215     "df2 = df[df.Model != 'RNN initial']\n",
216     "sns.boxplot(\n",
217     "    x=df2['param_set'],\n",
218     "    y=df2['RMSE'],\n",
219     "    hue=df2['Period']\n",
220     ").set_title('Results by Param Set')"
221    ]
222   },
223   {
224    "cell_type": "markdown",
225    "id": "398bdc79-e3f0-48ce-9edf-d84b0bc445c4",
226    "metadata": {},
227    "source": [
228     "We print the group means..."
229    ]
230   },
231   {
232    "cell_type": "code",
233    "execution_count": null,
234    "id": "7a576d0b-d287-44f7-ad7c-94d11697aa01",
235    "metadata": {},
236    "outputs": [],
237    "source": [
238     "x=df2.groupby(['param_set', 'Period']).agg({'RMSE': 'mean'})\n",
239     "pd.DataFrame({\n",
240     "    'Period': ['all', 'predict', 'train'],\n",
241     "    'Set 1': list(x.RMSE[0:3]),\n",
242     "    'Set 2': list(x.RMSE[3:6]),\n",
243     "    'Set 3': list(x.RMSE[6:9])\n",
244     "})"
245    ]
246   },
247   {
248    "cell_type": "markdown",
249    "id": "bb1a2f6b-914c-4d36-a9e5-5e4af776a895",
250    "metadata": {},
251    "source": [
252     "Param sets 2 and 3 have similar rates of prediction error, though the boxplots show there is substantial overlap."
253    ]
254   },
255   {
256    "cell_type": "markdown",
257    "id": "f0b95428-ee27-49b2-ab16-045405bac854",
258    "metadata": {},
259    "source": [
260     "### Results by Model\n",
261     "\n",
262     "Here we just look at results from Param set 2 so we are not double (triple) counting results.\n",
263     "\n",
264     "Again we exclude the untrained RNN from the plot as there are extreme values that distort the plot margins."
265    ]
266   },
267   {
268    "cell_type": "code",
269    "execution_count": null,
270    "id": "6a3ab944-4dc7-459b-aa17-2c85f6e7c681",
271    "metadata": {},
272    "outputs": [],
273    "source": [
274     "df2 = df[(df.Model != 'RNN initial') & (df.param_set == 2)]\n",
275     "sns.boxplot(\n",
276     "    x=df2['Model'],\n",
277     "    y=df2['RMSE'],\n",
278     "    hue=df2['Period']\n",
279     ").set_title('Results by Model')"
280    ]
281   },
282   {
283    "cell_type": "code",
284    "execution_count": null,
285    "id": "67446b8a-a6c5-4963-9d56-8de91d1d9301",
286    "metadata": {},
287    "outputs": [],
288    "source": [
289     "x=df2.groupby(['Model', 'Period']).agg({'RMSE': 'mean'})\n",
290     "pd.DataFrame({\n",
291     "    'Period': ['all', 'predict', 'train'],\n",
292     "    'KF': list(x.RMSE[0:3]),\n",
293     "    'RNN Trained': list(x.RMSE[3:6])\n",
294     "})"
295    ]
296   },
297   {
298    "cell_type": "markdown",
299    "id": "0034c5af-3a14-480f-8e6c-b1b26790de0e",
300    "metadata": {},
301    "source": [
302     "The trained RNN has a lower prediction error on average than the KF. \n",
303     "\n",
304     "The augmented Kalman Filter gets very low training error, but a much higher prediction error, over 5x. This is clear signs of overfitting."
305    ]
306   },
307   {
308    "cell_type": "code",
309    "execution_count": null,
310    "id": "6c6d0595-c8ac-4a0b-82e4-6d069089578a",
311    "metadata": {},
312    "outputs": [],
313    "source": [
314     "df1=df[(df.Model == \"Augmented KF\") & (df.param_set==2)]\n",
315     "df2=df[(df.Model == \"RNN trained\") & (df.param_set==2)]\n",
316     "\n",
317     "# Check equality of other cols\n",
318     "print(df1['Period'].equals(df1['Period']))\n",
319     "print(df1['Case'].equals(df1['Case']))"
320    ]
321   },
322   {
323    "cell_type": "code",
324    "execution_count": null,
325    "id": "fbb1f874-4a3e-4109-82f1-fbc633afc52a",
326    "metadata": {},
327    "outputs": [],
328    "source": [
329     "# Rename RMSE's then Add RMSE from df2 to df1\n",
330     "df1=df1.rename(columns={\"RMSE\": \"RMSE KF\"})\n",
331     "df2=df2.rename(columns={\"RMSE\": \"RMSE RNN\"})\n",
332     "# df1.join(df2['RMSE RNN'])\n",
333     "df1['RMSE RNN'] = df2['RMSE RNN'].to_numpy()"
334    ]
335   },
336   {
337    "cell_type": "code",
338    "execution_count": null,
339    "id": "2e6e1fdc-a7d4-4572-ad12-e7c8862e723f",
340    "metadata": {},
341    "outputs": [],
342    "source": [
343     "sns.scatterplot(\n",
344     "    data=df1, \n",
345     "    x='RMSE KF', \n",
346     "    y='RMSE RNN', \n",
347     "    hue='Period')\n",
348     "plt.legend(loc=\"upper left\")\n",
349     "plt.ylim(0,8)\n",
350     "plt.xlim(0,8)\n",
351     "plt.title(\"RMSE - KF vs RNN (Param Set 2)\")\n",
352     "plt.axline((0, 0), slope=1, c='k', linestyle=':', alpha=.6)\n",
353     "plt.text(6,6.2,\"equal RMSE\",rotation=37, alpha=.6)\n",
354     "plt.text(3,7,\"KF Better\", alpha=.6)\n",
355     "plt.text(6,1,\"RNN Better\", alpha=.6)"
356    ]
357   },
358   {
359    "cell_type": "markdown",
360    "id": "486568fe-18b6-4cbb-b011-8fab6859e985",
361    "metadata": {},
362    "source": [
363     "## Where the RNN goes wrong\n",
364     "\n"
365    ]
366   },
367   {
368    "cell_type": "markdown",
369    "id": "f6dd5f74-92e9-4546-9239-a88392a35be4",
370    "metadata": {},
371    "source": [
372     "The initial RNN, with physics-initiated weights, has some extreme values for the initial accuracy."
373    ]
374   },
375   {
376    "cell_type": "code",
377    "execution_count": null,
378    "id": "b44c5349-14ad-4dbd-b50d-8566b09487b7",
379    "metadata": {},
380    "outputs": [],
381    "source": [
382     "df1 = df[df['Model']!= \"Augmented KF\"]"
383    ]
384   },
385   {
386    "cell_type": "code",
387    "execution_count": null,
388    "id": "59377737-eb73-4b81-88fe-a0d066ab56ce",
389    "metadata": {},
390    "outputs": [],
391    "source": [
392     "sns.histplot(df1[df1['Model']==\"RNN initial\"]['RMSE'])"
393    ]
394   },
395   {
396    "cell_type": "markdown",
397    "id": "e63a8b6a-7c2b-44f3-95ee-e51d0a706176",
398    "metadata": {},
399    "source": [
400     "The pattern is far from clean and linear, but generally the largest RMSE after training corresponds to the largest errors from the initial, untrained RNN models. We should investigate why these large initial RNN errors exist and whether it is indivative of a data issue or modeling issue."
401    ]
402   },
403   {
404    "cell_type": "code",
405    "execution_count": null,
406    "id": "39927f1f-4114-49bd-8371-13efccde6d42",
407    "metadata": {},
408    "outputs": [],
409    "source": [
410     "plt.scatter(\n",
411     "    df1[df1['Model']==\"RNN initial\"]['RMSE'],\n",
412     "    df1[df1['Model']==\"RNN trained\"]['RMSE']\n",
413     ")\n",
414     "plt.xlabel(\"Initial RMSE\")\n",
415     "plt.ylabel(\"Trained RMSE\")\n",
416     "plt.title(\"RNN RMSE - Initial vs Trained\")"
417    ]
418   },
419   {
420    "cell_type": "code",
421    "execution_count": null,
422    "id": "4d5a968c-0b89-4f6f-955c-2fdc949617c3",
423    "metadata": {},
424    "outputs": [],
425    "source": []
426   }
427  ],
428  "metadata": {
429   "kernelspec": {
430    "display_name": "Python 3 (ipykernel)",
431    "language": "python",
432    "name": "python3"
433   },
434   "language_info": {
435    "codemirror_mode": {
436     "name": "ipython",
437     "version": 3
438    },
439    "file_extension": ".py",
440    "mimetype": "text/x-python",
441    "name": "python",
442    "nbconvert_exporter": "python",
443    "pygments_lexer": "ipython3",
444    "version": "3.9.12"
445   }
446  },
447  "nbformat": 4,
448  "nbformat_minor": 5