Create Batch Reset Hyperparameter tutorial notebook
[notebooks.git] / fmda / forecast_model.ipynb
blob531fdc9594768cbe88653600590b8873035ae62f
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "a1ade692-423b-4699-82e6-e8261772d6e5",
6    "metadata": {},
7    "source": [
8     "# OLD: fit model at one location, run prediction on HRRR grid"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "d6757008-a9ea-4e3a-8724-f5d241df0488",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "# Environment\n",
19     "import numpy as np\n",
20     "import pandas as pd\n",
21     "import tensorflow as tf\n",
22     "import matplotlib.pyplot as plt\n",
23     "import sys\n",
24     "# Local modules\n",
25     "sys.path.append('..')\n",
26     "import reproducibility\n",
27     "from utils import print_dict_summary\n",
28     "from data_funcs import load_and_fix_data, rmse, plot_data\n",
29     "from moisture_rnn import create_rnn_data_1, create_rnn_data_2, train_rnn, rnn_predict"
30    ]
31   },
32   {
33    "cell_type": "markdown",
34    "id": "f15ee8a4-44ff-4186-9962-c19b48acffd8",
35    "metadata": {},
36    "source": [
37     "## Train Model"
38    ]
39   },
40   {
41    "cell_type": "code",
42    "execution_count": null,
43    "id": "8e196ef6-1762-4705-818b-e5485ef64575",
44    "metadata": {},
45    "outputs": [],
46    "source": [
47     "# Saved reproducibility cases, looked up by case name below (e.g. 'case11')\n",
48     "reproducibility_file = 'data/reproducibility_dict.pickle'  # NOTE(review): .pickle suggests unpickling - only load trusted files\n",
49     "repro = dict(load_and_fix_data(reproducibility_file))\n",
50     "\n",
51     "print_dict_summary(repro)"
52    ]
53   },
54   {
55    "cell_type": "code",
56    "execution_count": null,
57    "id": "e763346c-3223-4be0-8300-1e998a58d693",
58    "metadata": {},
59    "outputs": [],
60    "source": [
61     "param_sets_ORIG = {\n",
62     "    'id': 0,\n",
63     "    'purpose': 'reproducibility',\n",
64     "    'batch_size': np.inf,       # NOTE(review): presumably inf = one full batch; confirm in moisture_rnn\n",
65     "    'training': None,\n",
66     "    'cases': ['case11'],\n",
67     "    'scale': 0,\n",
68     "    'rain_do': False,\n",
69     "    'verbose': 1,\n",
70     "    'timesteps': 5,\n",
71     "    'activation': ['linear', 'linear'],\n",
72     "    'centering': [0.0, 0.0],\n",
73     "    'hidden_units': 6,\n",
74     "    'dense_units': 1,\n",
75     "    'dense_layers': 1,\n",
76     "    'DeltaE': [0, -1],          # -1.0 is to correct E bias but put at the end\n",
77     "    'synthetic': False,         # run also synthetic cases\n",
78     "    'T1': 0.1,                  # 1/fuel class (10)\n",
79     "    'fm_raise_vs_rain': 2.0,    # fm increase per mm rain\n",
80     "    'epochs': 5000,\n",
81     "    'verbose_fit': 0,\n",
82     "    'verbose_weights': False,\n",
83     "    'note': 'check 5 should give zero error',\n",
84     "}"
85    ]
86   },
87   {
88    "cell_type": "code",
89    "execution_count": null,
90    "id": "402e4c74-ec55-40e7-9c36-5afeacb65ac4",
91    "metadata": {},
92    "outputs": [],
93    "source": [
94     "case_data = repro[\"case11\"]\n",
95     "h2 = case_data['h2']  # NOTE(review): h2 is not used elsewhere in this notebook\n",
96     "params = param_sets_ORIG  # alias (not a copy) of the parameter set above\n",
97     "reproducibility.set_seed()  # seed before data prep and training so runs are repeatable\n",
98     "rnn_dat = create_rnn_data_1(case_data, params)\n",
99     "create_rnn_data_2(rnn_dat, params)  # return value ignored; presumably fills rnn_dat in place (x_train/y_train read below)\n",
100     "\n",
101     "print(rnn_dat[\"x_train\"].shape)\n",
102     "print(rnn_dat[\"y_train\"].shape)"
103    ]
104   },
105   {
106    "cell_type": "code",
107    "execution_count": null,
108    "id": "f921b88f-21ef-438d-b657-5ef7dfb5bfb2",
109    "metadata": {},
110    "outputs": [],
111    "source": [
112     "# Train the RNN on the case11 sequences (fit=True)\n",
113     "model_predict = train_rnn(rnn_dat,\n",
114     "                          params,\n",
115     "                          rnn_dat['hours'],\n",
116     "                          fit=True,\n",
117     "                          )"
118    ]
119   },
120   {
121    "cell_type": "code",
122    "execution_count": null,
123    "id": "62a63c8f-f832-4a8d-9c3e-60202d029957",
124    "metadata": {},
125    "outputs": [],
126    "source": [
127     "# Predict with the trained model and store it on the case as 'm' (presumably drawn by plot_data)\n",
128     "m = case_data['m'] = rnn_predict(model_predict, params, rnn_dat)\n",
129     "plot_data(case_data)"
130    ]
131   },
132   {
133    "cell_type": "markdown",
134    "id": "b507a0e6-8830-451b-a557-139b3fd4c310",
135    "metadata": {},
136    "source": [
137     "## Format Data to Predict"
138    ]
139   },
140   {
141    "cell_type": "markdown",
142    "id": "65386a9b-f0d0-483c-a881-aee2f5432f49",
143    "metadata": {},
144    "source": [
145     "### Test Plot of One HRRR Grid"
146    ]
147   },
148   {
149    "cell_type": "code",
150    "execution_count": null,
151    "id": "f7e3f84b-804b-431c-993e-60ec591280fb",
152    "metadata": {},
153    "outputs": [],
154    "source": [
155     "# Directory with the HRRR GeoTIFFs (date folder 20240101)\n",
156     "url = \"https://demo.openwfm.org/web/data/fmda/tif/20240101/\"\n",
157     "\n",
158     "# Only the temperature and RH bands are needed to compute the equilibria Ed/Ew.\n",
159     "# Band numbers from https://www.nco.ncep.noaa.gov/pmb/products/hrrr/hrrr.t00z.wrfprsf00.grib2.shtml\n",
160     "bands = [616, 620]  # temp, RH\n",
161     "\n",
162     "# Hours to predict; each maps to the f00 file of cycle tHHz\n",
163     "pred_hours = [0, 1, 2]\n",
164     "\n",
165     "# files: 'hour_HH' -> [temp tif URL, RH tif URL], in the order of bands\n",
166     "files = {}\n",
167     "for h in pred_hours:\n",
168     "    hr = f\"{h:02d}\"\n",
169     "    files[f\"hour_{hr}\"] = []\n",
170     "    for b in bands:\n",
171     "        f = url + f\"hrrr.t{hr}z.wrfprsf00.{b}.tif\"\n",
172     "        files[f\"hour_{hr}\"].append(f)\n",
173     "        print(f\"Filename: {f}\")"
174    ]
175   },
176   {
177    "cell_type": "code",
178    "execution_count": null,
179    "id": "8ca89988-30bd-45dd-bcd8-6b7fde5c3565",
180    "metadata": {},
181    "outputs": [],
182    "source": [
183     "# Ed = 0.924*rh**0.679 + 0.000499*np.exp(0.1*rh) + 0.18*(21.1 + 273.15 - t2)*(1 - np.exp(-0.115*rh))  <- drying equilibrium, t2 in K\n",
184     "# Ew = 0.618*rh**0.753 + 0.000454*np.exp(0.1*rh) + 0.18*(21.1 + 273.15 - t2)*(1 - np.exp(-0.115*rh))  <- wetting equilibrium (computed below with t2 -> temp)"
185    ]
186   },
187   {
188    "cell_type": "code",
189    "execution_count": null,
190    "id": "eed839a1-35dd-4e1d-83ad-4989a8bdcf12",
191    "metadata": {},
192    "outputs": [],
193    "source": [
194     "import rioxarray"
195    ]
196   },
197   {
198    "cell_type": "code",
199    "execution_count": null,
200    "id": "bb67b7d1-a957-46a9-9e52-8e617c08018c",
201    "metadata": {},
202    "outputs": [],
203    "source": [
204     "files[\"hour_00\"][0]  # first URL of hour 00 = temperature (band 616)"
205    ]
206   },
207   {
208    "cell_type": "code",
209    "execution_count": null,
210    "id": "fbc16c08-347c-4e53-aa92-68b17266fab0",
211    "metadata": {},
212    "outputs": [],
213    "source": [
214     "temp = rioxarray.open_rasterio(files[\"hour_00\"][0])  # hour-00 temperature raster (band 616)"
215    ]
216   },
217   {
218    "cell_type": "code",
219    "execution_count": null,
220    "id": "db4779c2-578c-49ee-a54b-e32804697b70",
221    "metadata": {},
222    "outputs": [],
223    "source": [
224     "temp  # show the raw temperature DataArray before any unit conversion"
225    ]
226   },
227   {
228    "cell_type": "code",
229    "execution_count": null,
230    "id": "590fd20e-dbdf-4f4f-bb05-4e49d55d61ea",
231    "metadata": {},
232    "outputs": [],
233    "source": [
234     "# Convert C to K when any value is below 150 (heuristic: no plausible Kelvin air temperature is that low)\n",
235     "if (temp < 150).any(): temp += 273.15\n",
236     "plt.imshow(temp.sel(band=1))"
237    ]
238   },
239   {
240    "cell_type": "code",
241    "execution_count": null,
242    "id": "d75b4471-07fa-4d75-8f77-12730d382e06",
243    "metadata": {},
244    "outputs": [],
245    "source": [
246     "rh = rioxarray.open_rasterio(files[\"hour_00\"][1])  # hour-00 relative humidity raster (band 620)\n",
247     "plt.imshow(rh.sel(band=1))"
248    ]
249   },
250   {
251    "cell_type": "code",
252    "execution_count": null,
253    "id": "4ca412f4-f4b2-4269-b19f-e951eb776be1",
254    "metadata": {},
255    "outputs": [],
256    "source": [
257     "Ed = 0.924*rh**0.679 + 0.000499*np.exp(0.1*rh) + 0.18*(21.1 + 273.15 - temp)*(1 - np.exp(-0.115*rh))  # drying equilibrium; temp in K after the conversion above\n",
258     "Ew = 0.618*rh**0.753 + 0.000454*np.exp(0.1*rh) + 0.18*(21.1 + 273.15 - temp)*(1 - np.exp(-0.115*rh))  # wetting equilibrium; same formulas as get_eq_from_url"
259    ]
260   },
261   {
262    "cell_type": "code",
263    "execution_count": null,
264    "id": "09b02c01-6b6d-41b5-8948-0a0806b24f0b",
265    "metadata": {},
266    "outputs": [],
267    "source": [
268     "plt.imshow(Ed.sel(band=1))  # hour-00 drying equilibrium grid\n",
269     "plt.title(\"Drying Equilibrium\")"
270    ]
271   },
272   {
273    "cell_type": "code",
274    "execution_count": null,
275    "id": "3a4df045-f2aa-4115-884a-5b8c2889b771",
276    "metadata": {},
277    "outputs": [],
278    "source": [
279     "plt.imshow(Ew.sel(band=1))  # hour-00 wetting equilibrium grid\n",
280     "plt.title(\"Wetting Equilibrium\")"
281    ]
282   },
283   {
284    "cell_type": "code",
285    "execution_count": null,
286    "id": "dbe9cf25-86ab-4de0-b956-0b0a87318ecc",
287    "metadata": {},
288    "outputs": [],
289    "source": [
290     "def get_eq_from_url(files):\n",
291     "    \"\"\"Return (Ed, Ew) drying/wetting equilibrium grids for one hour; `files` must hold exactly one band-616 (temp) and one band-620 (RH) tif path/URL.\"\"\"\n",
292     "    # Get right bands\n",
293     "    tfile = [file for file in files if \".616.tif\" in file]\n",
294     "    rhfile = [file for file in files if \".620.tif\" in file]\n",
295     "\n",
296     "    # Data checks (zero matches is an error too, not only duplicates)\n",
297     "    assert len(tfile) == 1, f\"Expected 1 file with band 616 (temp), found {len(tfile)}; this func only processes 1hr\"\n",
298     "    assert len(rhfile) == 1, f\"Expected 1 file with band 620 (rh), found {len(rhfile)}; this func only processes 1hr\"\n",
299     "\n",
300     "    # Read Data\n",
301     "    temp = rioxarray.open_rasterio(tfile[0])\n",
302     "    rh = rioxarray.open_rasterio(rhfile[0])\n",
303     "    assert temp.data.shape == rh.data.shape, \"Temp and RH data different shapes\"\n",
304     "\n",
305     "    # Convert C to K if C detected, check is whether any value is less than 150 deg. TODO: do this w metadata\n",
306     "    # NOTE(review): a nodata fill value below 150 would also trigger this shift\n",
307     "    if np.any(temp < 150):\n",
308     "        temp += 273.15\n",
309     "    Ed = 0.924*rh**0.679 + 0.000499*np.exp(0.1*rh) + 0.18*(21.1 + 273.15 - temp)*(1 - np.exp(-0.115*rh))\n",
310     "    Ew = 0.618*rh**0.753 + 0.000454*np.exp(0.1*rh) + 0.18*(21.1 + 273.15 - temp)*(1 - np.exp(-0.115*rh))\n",
311     "    # Returned as DataArrays: callers use both .rio (CRS/transform) and .data\n",
312     "    return Ed, Ew"
313    ]
314   },
315   {
316    "cell_type": "code",
317    "execution_count": null,
318    "id": "31a84e0e-ecbc-4b57-9432-bdc6042249ab",
319    "metadata": {},
320    "outputs": [],
321    "source": [
322     "Ed, Ew = get_eq_from_url(files[\"hour_00\"])  # hour-00 equilibria as DataArrays; their CRS/transform are used for the bbox below"
323    ]
324   },
325   {
326    "cell_type": "code",
327    "execution_count": null,
328    "id": "8ff42b9f-a8c8-4767-a1d4-2e7b464c5019",
329    "metadata": {},
330    "outputs": [],
331    "source": [
332     "from pyproj import Transformer\n",
333     "\n",
334     "# Subset data with bbox. BBox from GACC (PNW), ordered [lat_min, lon_min, lat_max, lon_max] in degrees\n",
335     "bbox = [42,-124.6,49,-116.4] # PNW bbox\n",
336     "lat_min, lon_min, lat_max, lon_max = bbox\n",
337     "\n",
338     "# Convert to coord system of datasets\n",
339     "transform = Ed.rio.transform()\n",
340     "crs = Ed.rio.crs\n",
341     "transformer = Transformer.from_crs(\"EPSG:4326\", crs, always_xy=True)  # always_xy: input order is (lon, lat)\n",
342     "inv_transform = ~transform  # projected (x, y) -> pixel (col, row), used for plotting below\n",
343     "# Corners are taken from bbox instead of repeating the literals\n",
344     "x_low, y_low = transformer.transform(lon_min, lat_min)\n",
345     "x_up, y_up = transformer.transform(lon_max, lat_max)"
346    ]
347   },
348   {
349    "cell_type": "code",
350    "execution_count": null,
351    "id": "c5e93c40",
352    "metadata": {},
353    "outputs": [],
354    "source": [
355     "# NOTE(review): this cell repeats the previous one (same bbox, CRS and corners) and can be deleted\n",
356     "from pyproj import Transformer\n",
357     "# Subset data with bbox. BBox from GACC (PNW), ordered [lat_min, lon_min, lat_max, lon_max] in degrees\n",
358     "bbox = [42,-124.6,49,-116.4] # PNW bbox\n",
359     "lat_min, lon_min, lat_max, lon_max = bbox\n",
360     "\n",
361     "# Convert to coord system of datasets\n",
362     "transform = Ed.rio.transform()\n",
363     "crs = Ed.rio.crs\n",
364     "transformer = Transformer.from_crs(\"EPSG:4326\", crs, always_xy=True)  # always_xy: input order is (lon, lat)\n",
365     "inv_transform = ~transform  # projected (x, y) -> pixel (col, row), used for plotting below\n",
366     "# Corners are taken from bbox instead of repeating the literals\n",
367     "x_low, y_low = transformer.transform(lon_min, lat_min)\n",
368     "x_up, y_up = transformer.transform(lon_max, lat_max)"
369    ]
370   },
371   {
372    "cell_type": "code",
373    "execution_count": null,
374    "id": "14ba70f9-39b0-4a0f-b95a-db3324027a23",
375    "metadata": {},
376    "outputs": [],
377    "source": [
378     "zz = Ed.sel(x=slice(x_low, x_up), y=slice(y_up, y_low))  # y slice runs high->low, presumably because the raster's y coordinate decreases with row"
379    ]
380   },
381   {
382    "cell_type": "code",
383    "execution_count": null,
384    "id": "23bc285e-a52c-420c-a894-324aa4a8dc88",
385    "metadata": {},
386    "outputs": [],
387    "source": [
388     "# Full hour-00 Ed grid with the two bbox corners (mapped to pixel col/row) in red\n",
389     "plt.imshow(Ed.sel(band=1))\n",
390     "xx, yy = inv_transform * (x_low, y_low)\n",
391     "xx2, yy2 = inv_transform * (x_up, y_up)\n",
392     "plt.plot(xx, yy, marker='o', color='red', markersize=6)\n",
393     "plt.plot(xx2, yy2, marker='o', color='red', markersize=6)\n",
394     "plt.show()"
395    ]
396   },
397   {
398    "cell_type": "code",
399    "execution_count": null,
400    "id": "39169958-a2b4-4bb8-af4a-8e6f258b8037",
401    "metadata": {},
402    "outputs": [],
403    "source": [
404     "plt.imshow(zz.sel(band=1))  # corners must be in zz's own pixel grid; the full-grid xx/yy would be offset here\n",
405     "plt.plot(*(~zz.rio.transform(recalc=True) * (x_low, y_low)), marker='o', color='red', markersize=6)  # recalc: derive the transform from zz's coords\n",
406     "plt.plot(*(~zz.rio.transform(recalc=True) * (x_up, y_up)), marker='o', color='red', markersize=6)"
407    ]
408   },
409   {
410    "cell_type": "markdown",
411    "id": "d2994284-f3ce-4892-86c5-086604e5a2d7",
412    "metadata": {},
413    "source": [
414     "### Get Timeseries for Grid"
415    ]
416   },
417   {
418    "cell_type": "code",
419    "execution_count": null,
420    "id": "510ff586-a327-43ed-8ded-8109fbd6cdb9",
421    "metadata": {},
422    "outputs": [],
423    "source": [
424     "# Extract ndarray of Grid Eqs, one slice per hour stacked along axis 0\n",
425     "## TODO: subset to the bbox before stacking to save memory\n",
426     "\n",
427     "# Collect each hour's arrays and concatenate once (avoids re-copying the growing array every iteration)\n",
428     "Ed_list, Ew_list = [], []\n",
429     "for hr in files:  # dict insertion order = pred_hours order\n",
430     "    Ed_hr, Ew_hr = get_eq_from_url(files[hr])\n",
431     "    Ed_list.append(Ed_hr.data)\n",
432     "    Ew_list.append(Ew_hr.data)\n",
433     "    del Ed_hr, Ew_hr\n",
434     "\n",
435     "Ed = np.concatenate(Ed_list, axis=0)\n",
436     "Ew = np.concatenate(Ew_list, axis=0)\n",
437     "del Ed_list, Ew_list\n",
438     "\n",
439     "assert Ed.shape == Ew.shape, \"Ed and Ew stacks have different shapes\""
440    ]
441   },
442   {
443    "cell_type": "code",
444    "execution_count": null,
445    "id": "e9902f4e-ddd1-4f26-bdca-29f92667ee43",
446    "metadata": {},
447    "outputs": [],
448    "source": [
449     "# Ed and Ew stacks should have identical shapes\n",
450     "print(Ed.shape, Ew.shape, sep=\"\\n\")"
451    ]
452   },
453   {
454    "cell_type": "markdown",
455    "id": "acff6d85-2533-4dec-88c1-44f9b421764c",
456    "metadata": {},
457    "source": [
458     "## Apply Model to Grid"
459    ]
460   },
461   {
462    "cell_type": "code",
463    "execution_count": null,
464    "id": "9a07125e-9f1c-4007-9e0a-46949be9e174",
465    "metadata": {},
466    "outputs": [],
467    "source": [
468     "# Take grid point [:, 0, 0] and apply the model to its Ed/Ew series, input shape (1, hours, 2)\n",
469     "hours, features = Ed.shape[0], 2\n",
470     "# Stack on the last axis so X_new[0, t] == (Ed[t], Ew[t]); reshaping a (2, hours) array would interleave the two series\n",
471     "X_new = np.stack([Ed[:, 0, 0], Ew[:, 0, 0]], axis=-1).reshape(1, hours, features)\n",
472     "print(X_new)\n",
473     "print(X_new.shape)"
474    ]
475   },
476   {
477    "cell_type": "code",
478    "execution_count": null,
479    "id": "4a74deb2-cd3c-4e8e-845c-705a643f6282",
480    "metadata": {},
481    "outputs": [],
482    "source": [
483     "model_predict.predict(X_new)  # RNN output for grid point [:, 0, 0]"
484    ]
485   },
486   {
487    "cell_type": "markdown",
488    "id": "15518f20-736a-4d77-95e4-a2424bba8d4d",
489    "metadata": {},
490    "source": [
491     "### Test Simulated New Data"
492    ]
493   },
494   {
495    "cell_type": "code",
496    "execution_count": null,
497    "id": "7a2a4401-227e-472e-aeab-f35e9f0f2240",
498    "metadata": {},
499    "outputs": [],
500    "source": [
501     "# Synthetic constant input: 100 hours with Ed = Ew = 20.0\n",
502     "hours, features = 100, 2\n",
503     "XX = np.full((1, hours, features), 20.0)  # (batch, time, features), same layout as X_new\n",
504     "print(XX.shape)"
505    ]
506   },
507   {
508    "cell_type": "code",
509    "execution_count": null,
510    "id": "8e94ad4c-5e44-4139-9ab3-89511fbaeee6",
511    "metadata": {},
512    "outputs": [],
513    "source": [
514     "preds = model_predict.predict(XX)  # model response to the constant synthetic input"
515    ]
516   },
517   {
518    "cell_type": "code",
519    "execution_count": null,
520    "id": "43a24025-44d0-405d-bbb7-f641037c1d9a",
521    "metadata": {},
522    "outputs": [],
523    "source": [
524     "plt.plot(XX[0, :, 0], label=\"New Data\")  # first input feature of the synthetic series\n",
525     "plt.plot(np.squeeze(preds), label=\"Preds\")\n",
526     "plt.legend()"
527    ]
528   },
529   {
530    "cell_type": "code",
531    "execution_count": null,
532    "id": "9ade0ed6-bb5f-4603-a1e7-64c2e45ca457",
533    "metadata": {},
534    "outputs": [],
535    "source": []
536   }
537  ],
538  "metadata": {
539   "kernelspec": {
540    "display_name": "Python 3 (ipykernel)",
541    "language": "python",
542    "name": "python3"
543   },
544   "language_info": {
545    "codemirror_mode": {
546     "name": "ipython",
547     "version": 3
548    },
549    "file_extension": ".py",
550    "mimetype": "text/x-python",
551    "name": "python",
552    "nbconvert_exporter": "python",
553    "pygments_lexer": "ipython3",
554    "version": "3.10.9"
555   }
556  },
557  "nbformat": 4,
558  "nbformat_minor": 5