print arguments of create_RNN_2
[notebooks.git] / fmda / fmda_rnn_rain.ipynb
blob8ac9486e4c1c3f93babe284291110cb46311974d
2  "cells": [
3   {
4    "cell_type": "code",
5    "execution_count": null,
6    "id": "d70c330d",
7    "metadata": {},
8    "outputs": [],
9    "source": [
10     "import reproducibility"
11    ]
12   },
13   {
14    "cell_type": "code",
15    "execution_count": null,
16    "id": "c7291842-a72d-4c4e-9312-6c0c31df18e0",
17    "metadata": {},
18    "outputs": [],
19    "source": [
20     "# both can change\n",
21     "# Environment\n",
22     "import numpy as np\n",
23     "import pandas as pd\n",
24     "import tensorflow as tf\n",
25     "from keras.models import Sequential\n",
26     "from keras.layers import Dense, SimpleRNN\n",
27     "from keras.utils.vis_utils import plot_model\n",
28     "from sklearn.preprocessing import MinMaxScaler\n",
29     "from sklearn.metrics import mean_squared_error\n",
30     "import math\n",
31     "import json\n",
32     "import os\n",
33     "import matplotlib.pyplot as plt\n",
34     "import keras.backend as K\n",
35     "from scipy.interpolate import LinearNDInterpolator, interpn\n",
36     "from scipy.optimize import root\n",
37     "from utils import hash2\n",
38     "\n",
39     "# Local modules for handling data and running moisture models\n",
40     "import data_funcs as datf\n",
41     "from data_funcs import format_raws, retrieve_raws, format_precip, fixnan\n",
42     "from data_funcs import raws_data, synthetic_data, plot_data, check_data, mse_data, to_json, from_json\n",
43     "from moisture_rnn import create_RNN_2, staircase, create_rnn_data, train_rnn, rnn_predict\n",
44     "import moisture_models as mod\n",
45     "\n",
46     "# API token for RAWS retrieval: read from the MESO_TOKEN environment variable so the secret is never stored in the notebook\n",
47     "meso_token=os.environ.get(\"MESO_TOKEN\", \"\")\n"
48    ]
49   },
50   {
51    "cell_type": "markdown",
52    "id": "8320fb68-771e-4441-8849-e5bdec432bd1",
53    "metadata": {},
54    "source": [
55     "## Retrieve RAWS Data"
56    ]
57   },
58   {
59    "cell_type": "code",
60    "execution_count": null,
61    "id": "ce5e033c-5d16-44ec-9b58-4a55ff76d04d",
62    "metadata": {
63     "scrolled": true
64    },
65    "outputs": [],
66    "source": [
67     "# raws_dat=raws_data(start='201806010800', hours=1200, h2=300, stid=\"BKCU1\",meso_token=meso_token)\n",
68     "raws_dat=from_json('kf_orig.json')\n",
69     "# hash2(raws_dat)"
70    ]
71   },
72   {
73    "cell_type": "code",
74    "execution_count": null,
75    "id": "8c1766b6",
76    "metadata": {},
77    "outputs": [],
78    "source": [
79     "%matplotlib inline\n",
80     "plot_data(raws_dat)"
81    ]
82   },
83   {
84    "cell_type": "code",
85    "execution_count": null,
86    "id": "c655b1b9",
87    "metadata": {},
88    "outputs": [],
89    "source": [
90     "plot_data(raws_dat,hmax=600)"
91    ]
92   },
93   {
94    "cell_type": "code",
95    "execution_count": null,
96    "id": "87ae5df3-b9a6-4102-83fa-99c21151ecc8",
97    "metadata": {},
98    "outputs": [],
99    "source": [
100     "raws_dat.keys()"
101    ]
102   },
103   {
104    "cell_type": "code",
105    "execution_count": null,
106    "id": "382c9d3d-8e11-461a-a4aa-276ae46a2cde",
107    "metadata": {},
108    "outputs": [],
109    "source": [
110     "print('fm mean:', np.nanmean(raws_dat['fm']))\n",
111     "print('rain mean:', np.nanmean(raws_dat['rain']))\n",
112     "print('Ed mean:', np.nanmean(raws_dat['Ed']))\n",
113     "print('Ew mean:', np.nanmean(raws_dat['Ew']))\n",
114     "print('m mean:', np.nanmean(raws_dat['m']))"
115    ]
116   },
117   {
118    "cell_type": "markdown",
119    "id": "b45fcc1f-8394-418f-89ab-0cfbaa04d65f",
120    "metadata": {},
121    "source": [
122     "## Retrieve RTMA Function\n",
123     "\n",
124     "<mark>Not needed?</mark>"
125    ]
126   },
127   {
128    "cell_type": "markdown",
129    "id": "c7b64034",
130    "metadata": {},
131    "source": [
132     "## Interface\n",
133     "Jonathon changes the cells above: create each case as a dictionary, then a dictionary of dictionaries, and figure out how to store and load the dictionaries as a file. JSON is possible, but it cannot contain datetime objects.\n",
134     "Look into pickle, which also compresses, while JSON is plain text. Clone wrfxpy and look there for idioms (pickle added jan/angel lager).\n",
135     "Jan will edit from here on below.\n",
136     "Cases will be extracted from the dictionary as global variables, for now at least."
137    ]
138   },
139   {
140    "cell_type": "code",
141    "execution_count": null,
142    "id": "b1614583",
143    "metadata": {},
144    "outputs": [],
145    "source": [
146     "np.random.seed(seed=123)"
147    ]
148   },
149   {
150    "cell_type": "code",
151    "execution_count": null,
152    "id": "521b453c",
153    "metadata": {},
154    "outputs": [],
155    "source": [
156     "# dictionary raws_dat has all that is needed for the run \n",
157     "# keeping the name raws_dat for now even if it may not be raws data\n",
158     "\n"
159    ]
160   },
161   {
162    "cell_type": "code",
163    "execution_count": null,
164    "id": "df5a0e60",
165    "metadata": {},
166    "outputs": [],
167    "source": [
168     "synt_dat=synthetic_data()  # just testing\n",
169     "%matplotlib inline\n",
170     "plot_data(synt_dat)"
171    ]
172   },
173   {
174    "cell_type": "markdown",
175    "id": "6c42a886-ecff-4379-8a12-db9a77d64045",
176    "metadata": {},
177    "source": [
178     "## Fit Augmented KF"
179    ]
180   },
181   {
182    "cell_type": "code",
183    "execution_count": null,
184    "id": "fda8aa6b-a241-47e3-881f-6e75373f1a2c",
185    "metadata": {},
186    "outputs": [],
187    "source": [
188     "m,Ec = mod.run_augmented_kf(raws_dat)  # extract from state\n",
189     "raws_dat['m']=m\n",
190     "raws_dat['Ec']=Ec\n",
191     "plot_data(raws_dat,title2='augmented KF')"
192    ]
193   },
194   {
195    "cell_type": "code",
196    "execution_count": null,
197    "id": "e98780d3",
198    "metadata": {},
199    "outputs": [],
200    "source": [
201     "plot_data(raws_dat,hmin=0,hmax=600)"
202    ]
203   },
204   {
205    "cell_type": "code",
206    "execution_count": null,
207    "id": "4d5388f2-1c21-4b4e-860f-a7ea7c7e2bbc",
208    "metadata": {},
209    "outputs": [],
210    "source": [
211     "plot_data(raws_dat,hmin=900,hmax=1200,title2='augmented KF prediction detail')"
212    ]
213   },
214   {
215    "cell_type": "code",
216    "execution_count": null,
217    "id": "ff9f0f20-b38f-4643-97cd-969914fca2dc",
218    "metadata": {},
219    "outputs": [],
220    "source": [
221     "mse_data(raws_dat)\n"
222    ]
223   },
224   {
225    "cell_type": "markdown",
226    "id": "f41a26c2-4a85-4c7e-b818-a1a2906dfb25",
227    "metadata": {},
228    "source": [
229     "## Fit RNN Model"
230    ]
231   },
232   {
233    "cell_type": "code",
234    "execution_count": null,
235    "id": "01143521-8222-4e69-9cde-7dc7c7c780e0",
236    "metadata": {},
237    "outputs": [],
238    "source": [
239     "# Set seed for reproducibility\n",
240     "tf.random.set_seed(123)"
241    ]
242   },
243   {
244    "cell_type": "code",
245    "execution_count": null,
246    "id": "cff5a394-c6f0-4af1-9890-37bf65ba0e68",
247    "metadata": {},
248    "outputs": [],
249    "source": [
250     "case_data = from_json('rnn_orig.json')"
251    ]
252   },
253   {
254    "cell_type": "code",
255    "execution_count": null,
256    "id": "1fcfc8e5",
257    "metadata": {},
258    "outputs": [],
259    "source": [
260     "plot_data(case_data,title2=' from rnn_orig.json',hmin=0,hmax=600)"
261    ]
262   },
263   {
264    "cell_type": "code",
265    "execution_count": null,
266    "id": "249b93d6",
267    "metadata": {},
268    "outputs": [],
269    "source": [
270     "plot_data(case_data,title2='RNN prediction',hmin=300,hmax=600)"
271    ]
272   },
273   {
274    "cell_type": "code",
275    "execution_count": null,
276    "id": "46b34ba2",
277    "metadata": {},
278    "outputs": [],
279    "source": [
280     "if 'm' in case_data:\n",
281     "    mse_data(case_data)  # just check solution if there\n",
282     "    del case_data['m']   # cleanup - remove old solution if any"
283    ]
284   },
285   {
286    "cell_type": "code",
287    "execution_count": null,
288    "id": "58615e60",
289    "metadata": {},
290    "outputs": [],
291    "source": [
292     "verbose = False\n",
293     "# Set seed for reproducibility\n",
294     "tf.random.set_seed(123)\n",
295     "rnn_dat = create_rnn_data(case_data,scale=False, hours=None, h2=None, verbose=verbose)"
296    ]
297   },
298   {
299    "cell_type": "code",
300    "execution_count": null,
301    "id": "eb655e2b-7288-4c69-ac4d-28079835270b",
302    "metadata": {},
303    "outputs": [],
304    "source": [
305     "## Check 1: equilibrium input data the same\n",
306     "\n",
307     "print(hash2(rnn_dat['Et']))\n",
308     "print(hash2(rnn_dat['x_train']))\n",
309     "print(hash2(rnn_dat['y_train']))"
310    ]
311   },
312   {
313    "cell_type": "code",
314    "execution_count": null,
315    "id": "871821a9-bcd9-47db-9bd6-1933094ac137",
316    "metadata": {},
317    "outputs": [],
318    "source": [
319     "model_predict = train_rnn(\n",
320     "    rnn_dat,\n",
321     "    rnn_dat['hours'],\n",
322     "    activation=['linear','linear'],\n",
323     "    hidden_units=6,\n",
324     "    dense_units=1,\n",
325     "    dense_layers=1,\n",
326     "    verbose = verbose\n",
327     ")"
328    ]
329   },
330   {
331    "cell_type": "code",
332    "execution_count": null,
333    "id": "5dc3ec60-292d-4526-a793-7d466f4ce9c7",
334    "metadata": {},
335    "outputs": [],
336    "source": [
337     "m = rnn_predict(model_predict, rnn_dat, rnn_dat['hours'], verbose = verbose)\n",
338     "case_data['m'] = m\n",
339     "note = 'm replaced by a solution from fmda_rnn_rain'\n",
340     "if 'note' in case_data:\n",
341     "    case_data['note'] = case_data['note'] + '\\n' + note\n",
342     "else:\n",
343     "    case_data['note'] = note\n",
344     "check_data(case_data)"
345    ]
346   },
347   {
348    "cell_type": "code",
349    "execution_count": null,
350    "id": "6ccc15fc-5d08-4df1-a4d5-4f0107171c15",
351    "metadata": {},
352    "outputs": [],
353    "source": [
354     "plot_data(case_data,title2='with trained RNN',hmin=0,hmax=600)\n"
355    ]
356   },
357   {
358    "cell_type": "code",
359    "execution_count": null,
360    "id": "1d06f6ee-2c05-4473-956c-ba490cf773d2",
361    "metadata": {},
362    "outputs": [],
363    "source": [
364     "mse_data(case_data)"
365    ]
366   },
367   {
368    "cell_type": "code",
369    "execution_count": null,
370    "id": "4e1668f8",
371    "metadata": {},
372    "outputs": [],
373    "source": [
374     "plot_data(case_data,title2='RNN prediction',hmin=300,hmax=600)"
375    ]
376   },
377   {
378    "cell_type": "markdown",
379    "id": "4a22cf6a-dfa8-45d5-9a4a-bba9f2fb386d",
380    "metadata": {},
381    "source": [
382     "---\n",
383     "---"
384    ]
385   },
386   {
387    "cell_type": "markdown",
388    "id": "f62ad24d-7974-4f22-a797-3419e51e03f9",
389    "metadata": {},
390    "source": [
391     "<mark>Start Here after Check 1</mark>"
392    ]
393   },
394   {
395    "cell_type": "code",
396    "execution_count": null,
397    "id": "e3b68921-3745-4067-857c-9d4329d9c979",
398    "metadata": {},
399    "outputs": [],
400    "source": [
401     "from utils import hash2"
402    ]
403   },
404   {
405    "cell_type": "code",
406    "execution_count": null,
407    "id": "bfbc82c0-2117-4b2b-b063-dbde46c856fd",
408    "metadata": {},
409    "outputs": [],
410    "source": [
411     "#tf.keras.utils.set_random_seed(123)\n",
412     "#tf.random.set_seed(123)\n",
413     "reproducibility.set_seed()"
414    ]
415   },
416   {
417    "cell_type": "code",
418    "execution_count": null,
419    "id": "9b09eb58-cd1f-4146-86d4-e60260d5f035",
420    "metadata": {},
421    "outputs": [],
422    "source": [
423     "from utils import vprint\n",
424     "\n",
425     "hours = rnn_dat['hours']\n",
426     "    \n",
427     "samples = rnn_dat['samples']\n",
428     "features = rnn_dat['features']\n",
429     "timesteps = rnn_dat['timesteps']\n",
430     "    \n",
431     "model_fit=create_RNN_2(hidden_units=6, \n",
432     "                        dense_units=1, \n",
433     "                        batch_shape=(samples,timesteps,features),\n",
434     "                        stateful=True,\n",
435     "                        return_sequences=False,\n",
436     "                        # initial_state=h0,\n",
437     "                        activation=['linear','linear'],\n",
438     "                        dense_layers=1)\n",
439     "\n",
440     "from keras.utils.vis_utils import plot_model\n",
441     "plot_model(model_fit, to_file='model_plot.png', \n",
442     "           show_shapes=True, show_layer_names=True)"
443    ]
444   },
445   {
446    "cell_type": "code",
447    "execution_count": null,
448    "id": "030953b3-be17-40d7-b0b6-7e73e8b92869",
449    "metadata": {},
450    "outputs": [],
451    "source": [
452     "## Check 2: Untrained RNN initialized with same weights\n",
453     "\n",
454     "hash2(model_fit.get_weights())"
455    ]
456   },
457   {
458    "cell_type": "code",
459    "execution_count": null,
460    "id": "d18e09e0-7f47-4bc5-8cee-23752ac5bd5f",
461    "metadata": {},
462    "outputs": [],
463    "source": [
464     "Et = rnn_dat['Et']\n",
465     "model_predict=create_RNN_2(hidden_units=6, dense_units=1,  \n",
466     "                            input_shape=(hours,features),stateful = False,\n",
467     "                            return_sequences=True,\n",
468     "                            activation=['linear','linear'],dense_layers=1)"
469    ]
470   },
471   {
472    "cell_type": "code",
473    "execution_count": null,
474    "id": "625a24ef-92ab-4657-87c6-aede6b7b89b2",
475    "metadata": {},
476    "outputs": [],
477    "source": [
478     "## Check 3: Second model initialization same weights\n",
479     "\n",
480     "hash2(model_predict.get_weights())"
481    ]
482   },
483   {
484    "cell_type": "code",
485    "execution_count": null,
486    "id": "797d5c31-9bff-4df9-83f5-cf688f3291d0",
487    "metadata": {},
488    "outputs": [],
489    "source": [
490     "print(rnn_dat)\n",
491     "x_train = rnn_dat['x_train']\n",
492     "y_train = rnn_dat['y_train']\n",
493     "type(x_train)\n",
494     "\n",
495     "# fitting\n",
496     "DeltaE = 0\n",
497     "w_exact=  [np.array([[1.-np.exp(-0.1)]]), np.array([[np.exp(-0.1)]]), np.array([0.]),np.array([[1.0]]),np.array([-1.*DeltaE])]\n",
498     "    \n",
499     "w_initial=[np.array([[1.-np.exp(-0.1)]]), np.array([[np.exp(-0.1)]]), np.array([0.]),np.array([[1.0]]),np.array([-1.0])]\n",
500     "w=model_fit.get_weights()\n",
501     "for i in range(len(w)):\n",
502     "    vprint('weight',i,'shape',w[i].shape,'ndim',w[i].ndim,'given',w_initial[i].shape)\n",
503     "    for j in range(w[i].shape[0]):\n",
504     "        if w[i].ndim==2:\n",
505     "            for k in range(w[i].shape[1]):\n",
506     "                w[i][j][k]=w_initial[i][0][0]/w[i].shape[0]\n",
507     "        else:\n",
508     "            w[i][j]=w_initial[i][0]\n",
509     "model_fit.set_weights(w)"
510    ]
511   },
512   {
513    "cell_type": "code",
514    "execution_count": null,
515    "id": "580caac4-ba9a-48b1-86a3-8e97357bf666",
516    "metadata": {},
517    "outputs": [],
518    "source": [
519     "## Check 4: weights and inputs the same after this step \n",
520     "\n",
521     "print(hash2(model_fit.get_weights()))\n",
522     "print(hash2(x_train))\n",
523     "print(hash2(y_train))"
524    ]
525   },
526   {
527    "cell_type": "code",
528    "execution_count": null,
529    "id": "33d6b35c",
530    "metadata": {},
531    "outputs": [],
532    "source": [
533     "print('model_fit input shape',x_train.shape,'output shape',y_train.shape)\n",
534     "print('x_train',x_train)\n",
535     "print('y_train',y_train)"
536    ]
537   },
538   {
539    "cell_type": "code",
540    "execution_count": null,
541    "id": "7a2be797",
542    "metadata": {},
543    "outputs": [],
544    "source": [
545     "reproducibility.set_seed()"
546    ]
547   },
548   {
549    "cell_type": "code",
550    "execution_count": null,
551    "id": "c7857fb4",
552    "metadata": {},
553    "outputs": [],
554    "source": [
555     "model_fit.get_weights()"
556    ]
557   },
558   {
559    "cell_type": "code",
560    "execution_count": null,
561    "id": "9de62d29-2c91-48b5-a92f-da3deae470ca",
562    "metadata": {},
563    "outputs": [],
564    "source": [
565     "model_fit.fit(x_train, y_train, epochs=5000, verbose=2, batch_size=samples)\n",
566     "w_fitted=model_fit.get_weights()\n",
567     "for i in range(len(w)):\n",
568     "    vprint('weight',i,' exact:',w_exact[i],':  initial:',w_initial[i],' fitted:',w_fitted[i])\n",
569     "    \n",
570     "model_predict.set_weights(w_fitted)"
571    ]
572   },
573   {
574    "cell_type": "code",
575    "execution_count": null,
576    "id": "e108b89d-0747-41bd-91fc-5716375d9d2b",
577    "metadata": {},
578    "outputs": [],
579    "source": [
580     "## Check 5: Weights NOT the same after fitting\n",
581     "\n",
582     "hash2(model_fit.get_weights())"
583    ]
584   },
585   {
586    "cell_type": "code",
587    "execution_count": null,
588    "id": "0fabcef2",
589    "metadata": {},
590    "outputs": [],
591    "source": [
592     "model_fit.get_weights()"
593    ]
594   },
595   {
596    "cell_type": "code",
597    "execution_count": null,
598    "id": "2cece158-e7db-4d33-be09-03022e805b06",
599    "metadata": {},
600    "outputs": [],
601    "source": [
602     "model_fit.get_config()"
603    ]
604   },
605   {
606    "cell_type": "code",
607    "execution_count": null,
608    "id": "a7c0a46d-4fe3-4d8c-8411-f95ce2b465ce",
609    "metadata": {},
610    "outputs": [],
611    "source": [
612     "## RNN weights repeated in odd way that looks like untrained\n",
613     "\n",
614     "model_fit.get_weights()"
615    ]
616   },
617   {
618    "cell_type": "code",
619    "execution_count": null,
620    "id": "438b3b9c-7e7a-48a0-ac22-0e3ab4460e57",
621    "metadata": {},
622    "outputs": [],
623    "source": []
624   },
625   {
626    "cell_type": "code",
627    "execution_count": null,
628    "id": "94e9a5c6",
629    "metadata": {},
630    "outputs": [],
631    "source": []
632   }
633  ],
634  "metadata": {
635   "kernelspec": {
636    "display_name": "Python 3 (ipykernel)",
637    "language": "python",
638    "name": "python3"
639   },
640   "language_info": {
641    "codemirror_mode": {
642     "name": "ipython",
643     "version": 3
644    },
645    "file_extension": ".py",
646    "mimetype": "text/x-python",
647    "name": "python",
648    "nbconvert_exporter": "python",
649    "pygments_lexer": "ipython3",
650    "version": "3.10.9"
651   }
652  },
653  "nbformat": 4,
654  "nbformat_minor": 5