splitting moisture_rnn.py off moisture_models.py
[notebooks.git] / fmda / fmda_rnn_rain.ipynb
blob917dfc9cf7522ad7b2a27a6593017081c702db52
2  "cells": [
3   {
4    "cell_type": "code",
5    "execution_count": null,
6    "id": "c7291842-a72d-4c4e-9312-6c0c31df18e0",
7    "metadata": {},
8    "outputs": [],
9    "source": [
10     "# both can change\n",
11     "# Environment\n",
12     "import numpy as np\n",
13     "import pandas as pd\n",
14     "from MesoPy import Meso\n",
15     "import matplotlib.pyplot as plt\n",
16     "from datetime import datetime, timedelta\n",
17     "\n",
18     "import tensorflow as tf\n",
19     "from keras.models import Sequential\n",
20     "from keras.layers import Dense, SimpleRNN\n",
21     "from keras.utils.vis_utils import plot_model\n",
22     "from sklearn.preprocessing import MinMaxScaler\n",
23     "from sklearn.metrics import mean_squared_error\n",
24     "import math\n",
25     "import matplotlib.pyplot as plt\n",
26     "import tensorflow as tf\n",
27     "import keras.backend as K\n",
28     "from keras.utils.vis_utils import plot_model\n",
29     "from scipy.interpolate import LinearNDInterpolator, interpn\n",
30     "from scipy.optimize import root\n",
31     "\n",
32     "# Local modules for handling data and running moisture models\n",
33     "import data_funcs as datf\n",
34     "from data_funcs import format_raws, retrieve_raws, format_precip, fixnan\n",
35     "import moisture_models as mod\n",
36     "from moisture_rnn import create_RNN, create_RNN_2, staircase, seq2batches, create_rnn_data, train_rnn, rnn_predict\n",
37     "import os  # stdlib; used only to read the Meso API token from the environment\n",
38     "meso_token=os.environ.get(\"MESO_TOKEN\", \"4192c18707b848299783d59a9317c6e1\")  # TODO(security): revoke and remove this committed fallback token\n",
39     "m=Meso(meso_token)"
40    ]
41   },
42   {
43    "cell_type": "code",
44    "execution_count": null,
45    "id": "7d299dea-9c39-4410-a4a5-a7f72d23ba99",
46    "metadata": {},
47    "outputs": [],
48    "source": [
49     "# Mean squared error of the element-wise difference a - b\n",
50     "def mse(a, b):\n",
51     "    return pow(a - b, 2).mean()\n",
52     "# Mean absolute error of a - b (despite the name, this is not a percentage error)\n",
53     "def mape(a, b):\n",
54     "    return abs(a - b).mean()"
55    ]
56   },
57   {
58    "cell_type": "code",
59    "execution_count": null,
60    "id": "04d3b14a-20e8-4e4e-802f-88404d151991",
61    "metadata": {},
62    "outputs": [],
63    "source": [
64     "def vprint(*args):\n",
65     "    \"\"\"Print the arguments separated by spaces, only when the global `verbose` is truthy.\"\"\"\n",
66     "    if verbose:\n",
67     "        # print() joins with ' ' and ends with a newline itself, and copes with an empty argument list\n",
68     "        print(*args)"
69    ]
70   },
71   {
72    "cell_type": "markdown",
73    "id": "947f972c-9a02-4550-8d82-1b9cf69d7e9c",
74    "metadata": {},
75    "source": [
76     "## Validation Setup"
77    ]
78   },
79   {
80    "cell_type": "code",
81    "execution_count": null,
82    "id": "d6236755-11e7-4b31-bf50-ffedb6077795",
83    "metadata": {},
84    "outputs": [],
85    "source": [
86     "time_start = \"201806010800\"\n",
87     "hours = 1200 # total simulation time\n",
88     "time_end = datetime.strptime(time_start, \"%Y%m%d%H%M\")+timedelta(hours = hours+1) # end time, plus a buffer to control for time shift\n",
89     "time_end = str(int(time_end.strftime(\"%Y%m%d%H%M\"))) # back to the compact string form used for time_start\n",
90     "h2 = 300 # training period\n",
91     "train_hrs = np.arange(0, h2) # training time\n",
92     "test_hrs = np.arange(h2, hours) # forecast time\n",
93     "\n",
94     "print('Time Parameters:')\n",
95     "print('-'*50)\n",
96     "print('Time Start:', datetime.strptime(time_start, \"%Y%m%d%H%M\").strftime(\"%Y/%m/%d %H:%M\")) # %m = month (%M would print minutes)\n",
97     "print('Time End:', datetime.strptime(time_end, \"%Y%m%d%H%M\").strftime(\"%Y/%m/%d %H:%M\"))\n",
98     "print('Total Runtime:', hours, 'hours')\n",
99     "print('Training Time:', h2, 'hours')\n",
100     "print('-'*50)"
101    ]
102   },
103   {
104    "cell_type": "markdown",
105    "id": "8320fb68-771e-4441-8849-e5bdec432bd1",
106    "metadata": {},
107    "source": [
108     "## Retrieve RAWS Data"
109    ]
110   },
111   {
112    "cell_type": "code",
113    "execution_count": null,
114    "id": "6dae1750-0656-4369-a114-2ccaa885ff55",
115    "metadata": {},
116    "outputs": [],
117    "source": [
118     "raws_vars='air_temp,relative_humidity,precip_accum,fuel_moisture'"
119    ]
120   },
121   {
122    "cell_type": "code",
123    "execution_count": null,
124    "id": "a5d8616e-3f4d-4f56-8e52-f065e8d3d5ae",
125    "metadata": {},
126    "outputs": [],
127    "source": [
128     "station, raws_dat = retrieve_raws(m, \"BKCU1\", raws_vars, time_start, time_end)"
129    ]
130   },
131   {
132    "cell_type": "code",
133    "execution_count": null,
134    "id": "0d751525-58c9-4668-9755-f1aaeb34aa40",
135    "metadata": {},
136    "outputs": [],
137    "source": [
138     "def plot_dat(stn, dat, val):  # line plot of dat[val], titled with the station id stn['STID']\n",
139     "    plt.figure(figsize=(16,4))\n",
140     "    plt.plot(dat[val],linestyle='-',c='k')\n",
141     "    plt.title(stn['STID']+' '+ val)\n",
142     "    plt.xlabel('Time (hours)') \n",
143     "    plt.ylabel(val)  # label the axis with the plotted variable's name, not the literal 'val'"
144    ]
145   },
146   {
147    "cell_type": "code",
148    "execution_count": null,
149    "id": "013a6794-9c56-4a32-8699-15763465544f",
150    "metadata": {},
151    "outputs": [],
152    "source": [
153     "%matplotlib inline\n",
154     "plot_dat(station, raws_dat, 'fm')"
155    ]
156   },
157   {
158    "cell_type": "code",
159    "execution_count": null,
160    "id": "ce5e033c-5d16-44ec-9b58-4a55ff76d04d",
161    "metadata": {},
162    "outputs": [],
163    "source": [
164     "print('Data Read:')\n",
165     "print('-'*50)\n",
166     "print('Station ID:', station['STID'])\n",
167     "print('Lat / Lon:', station['LATITUDE'],', ',station['LONGITUDE'])\n",
168     "if(station['QC_FLAGGED']): print('WARNING: station flagged for QC')\n",
169     "print('-'*50)"
170    ]
171   },
172   {
173    "cell_type": "markdown",
174    "id": "b45fcc1f-8394-418f-89ab-0cfbaa04d65f",
175    "metadata": {},
176    "source": [
177     "## Retrieve RTMA Function\n",
178     "\n",
179     "<mark>Not needed?</mark>"
180    ]
181   },
182   {
183    "cell_type": "markdown",
184    "id": "c7b64034",
185    "metadata": {},
186    "source": [
187     "% Jonathon changes above: create each case as a dictionary, then a dictionary of dictionaries; figure out how to store and load dictionaries as a file. JSON is possible, but it cannot contain datetime objects.\n",
188     "% Look into pickle, which also compresses, while JSON is plain text. Clone wrfxpy and look there for idioms; pickle added jan/angel lager\n",
189     "% Jan will edit from here below.\n",
190     "% Cases will be extracted from the dictionary as global variables, for now at least."
191    ]
192   },
193   {
194    "cell_type": "markdown",
195    "id": "6c42a886-ecff-4379-8a12-db9a77d64045",
196    "metadata": {},
197    "source": [
198     "## Fit Augmented KF"
199    ]
200   },
201   {
202    "cell_type": "code",
203    "execution_count": null,
204    "id": "fda8aa6b-a241-47e3-881f-6e75373f1a2c",
205    "metadata": {},
206    "outputs": [],
207    "source": [
208     "m,Ec = mod.run_augmented_kf(raws_dat['fm'],raws_dat['Ed'],raws_dat['Ew'],raws_dat['rain'],h2,hours)  # extract from state"
209    ]
210   },
211   {
212    "cell_type": "code",
213    "execution_count": null,
214    "id": "c531fad3-1d6f-4738-a019-f587a7ab7139",
215    "metadata": {},
216    "outputs": [],
217    "source": [
218     "def plot_moisture(hmin,hmax):\n",
219     "    \"\"\"Plot equilibria, observations, the filtered state up to h2 and the forecast after it, for hours [hmin,hmax); reads the notebook globals raws_dat, Ec, m, h2 and station.\"\"\"\n",
220     "    print('training from 0 to',h2,'plot from',hmin,'to',hmax)\n",
221     "    plt.figure(figsize=(16,4))\n",
222     "    plt.plot(range(hmin,hmax),raws_dat['Ed'][hmin:hmax],linestyle='--',c='r',label='Drying Equilibrium (%)')\n",
223     "    plt.plot(range(hmin,hmax),raws_dat['Ew'][hmin:hmax],linestyle='--',c='b',label='Wetting Equilibrium (%)')\n",
224     "    plt.plot(range(hmin,hmax),Ec[hmin:hmax],linestyle='--',c='g',label='Equilibrium Correction (%)')\n",
225     "    plt.plot(range(hmin,hmax),raws_dat['fm'][hmin:hmax],linestyle='-',c='b',label='RAWS data (%)')\n",
226     "    plt.plot(range(hmin,hmax),raws_dat['rain'][hmin:hmax],linestyle='-',c='b',label='RTMA rain (mm/h)')\n",
227     "    h1 = min(max(hmin,h2),hmax)  # end of training clamped to the window: filtered before h1, forecast from h1 on\n",
228     "    if hmin<h1:\n",
229     "        plt.plot(range(hmin,h1),m[hmin:h1],linestyle='-',c='k',label='filtered')\n",
230     "    plt.plot(range(h1,hmax),m[h1:hmax],linestyle='-',c='r',label='Forecast (%)')\n",
231     "    plt.title(station['STID'] +' Kalman filtering and forecast with augmented state, real data. Training 0:%i hmax' % h2)\n",
232     "    plt.xlabel('Time (hours)') \n",
233     "    plt.ylabel('Fuel moisture content (%)')\n",
234     "    plt.legend()"
235    ]
236   },
237   {
238    "cell_type": "code",
239    "execution_count": null,
240    "id": "4d5388f2-1c21-4b4e-860f-a7ea7c7e2bbc",
241    "metadata": {},
242    "outputs": [],
243    "source": [
244     "plot_moisture(0, hours)"
245    ]
246   },
247   {
248    "cell_type": "code",
249    "execution_count": null,
250    "id": "ff9f0f20-b38f-4643-97cd-969914fca2dc",
251    "metadata": {},
252    "outputs": [],
253    "source": [
254     "# Forecast Error\n",
255     "print('Forecast MSE: ' + str(np.round(mse(m[h2:hours], raws_dat['fm'][h2:hours]), 4)))"
256    ]
257   },
258   {
259    "cell_type": "markdown",
260    "id": "f41a26c2-4a85-4c7e-b818-a1a2906dfb25",
261    "metadata": {},
262    "source": [
263     "## Fit RNN Model"
264    ]
265   },
266   {
267    "cell_type": "code",
268    "execution_count": null,
269    "id": "01143521-8222-4e69-9cde-7dc7c7c780e0",
270    "metadata": {},
271    "outputs": [],
272    "source": [
273     "# Set seed for reproducibility\n",
274     "tf.random.set_seed(123)"
275    ]
276   },
277   {
278    "cell_type": "code",
279    "execution_count": null,
280    "id": "cff5a394-c6f0-4af1-9890-37bf65ba0e68",
281    "metadata": {},
282    "outputs": [],
283    "source": [
284     "verbose = False\n",
285     "scale = False\n",
286     "rnn_dat = create_rnn_data(raws_dat, hours, h2, scale)"
287    ]
288   },
289   {
290    "cell_type": "code",
291    "execution_count": null,
292    "id": "871821a9-bcd9-47db-9bd6-1933094ac137",
293    "metadata": {},
294    "outputs": [],
295    "source": [
296     "model_predict = train_rnn(\n",
297     "    rnn_dat,\n",
298     "    hours,\n",
299     "    activation=['linear','linear'],\n",
300     "    hidden_units=3,\n",
301     "    dense_units=1,\n",
302     "    dense_layers=1,\n",
303     "    verbose = verbose\n",
304     ")"
305    ]
306   },
307   {
308    "cell_type": "code",
309    "execution_count": null,
310    "id": "5dc3ec60-292d-4526-a793-7d466f4ce9c7",
311    "metadata": {},
312    "outputs": [],
313    "source": [
314     "verbose = 0\n",
315     "m = rnn_predict(model_predict, rnn_dat, hours)"
316    ]
317   },
318   {
319    "cell_type": "code",
320    "execution_count": null,
321    "id": "6ccc15fc-5d08-4df1-a4d5-4f0107171c15",
322    "metadata": {},
323    "outputs": [],
324    "source": [
325     "hour=np.arange(hours)\n",
326     "title=\"RNN forecast\"\n",
327     "plt.figure(figsize=(16,4))\n",
328     "plt.plot(hour,rnn_dat['Et'][:,0],linestyle='--',c='r',label='E=Equilibrium data')\n",
329     "# slice the observations to the simulated window, as the error cells do, so x and y have equal length\n",
330     "plt.scatter(hour,raws_dat['fm'][:hours],c='b',label='data=10-h fuel data')\n",
331     "if m is not None:\n",
332     "    plt.plot(hour[:h2],m[:h2],linestyle='-',c='k',label='m=filtered')\n",
333     "    plt.plot(hour[h2:hours],m[h2:hours],linestyle='-',c='r',label='m=forecast')\n",
334     "plt.title(title) \n",
335     "plt.legend()"
336    ]
337   },
338   {
339    "cell_type": "code",
340    "execution_count": null,
341    "id": "1d06f6ee-2c05-4473-956c-ba490cf773d2",
342    "metadata": {},
343    "outputs": [],
344    "source": [
345     "# Overall Error\n",
346     "print(mse(m, raws_dat['fm'][0:hours]))\n",
347     "\n",
348     "# Forecast Error\n",
349     "print(mse(m[h2:hours], raws_dat['fm'][h2:hours]))"
350    ]
351   },
352   {
353    "cell_type": "code",
354    "execution_count": null,
355    "id": "1e7253c4-0e4c-4076-8f2a-f661e057efd8",
356    "metadata": {},
357    "outputs": [],
358    "source": []
359   }
360  ],
361  "metadata": {
362   "kernelspec": {
363    "display_name": "Python 3 (ipykernel)",
364    "language": "python",
365    "name": "python3"
366   },
367   "language_info": {
368    "codemirror_mode": {
369     "name": "ipython",
370     "version": 3
371    },
372    "file_extension": ".py",
373    "mimetype": "text/x-python",
374    "name": "python",
375    "nbconvert_exporter": "python",
376    "pygments_lexer": "ipython3",
377    "version": "3.10.8"
378   }
379  },
380  "nbformat": 4,
381  "nbformat_minor": 5