Modify ResetStates callback to reset after every N batches given by a param
[notebooks.git] / fmda / version_control / reproducibility_file.ipynb
blob91a03a8747a6d970639624079f7822bf9db1d900
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "c887bd3d-4f89-4d8f-9cc8-b8a20c50b2a2",
6    "metadata": {},
7    "source": [
8     "# Utility to create stable file used for reproduciblity checks\n",
9     "\n",
10     "## v2.1 Code"
11    ]
12   },
13   {
14    "cell_type": "code",
15    "execution_count": 1,
16    "id": "58a6ff55-f5c6-41ee-8c71-fc97813f18d6",
17    "metadata": {},
18    "outputs": [],
19    "source": [
20     "import pickle\n",
21     "import numpy as np\n",
22     "import os.path as osp\n",
23     "import os\n",
24     "import pandas as pd\n",
25     "import tensorflow as tf\n",
26     "import sys\n",
27     "sys.path.append('..')\n",
28     "from moisture_rnn_pkl import pkl2train\n",
29     "from moisture_rnn import RNNParams\n",
30     "from utils import read_yml, read_pkl, print_dict_summary, load_and_fix_data"
31    ]
32   },
33   {
34    "cell_type": "code",
35    "execution_count": 2,
36    "id": "18a399f9-e66a-448f-82bc-8816a19444f0",
37    "metadata": {},
38    "outputs": [],
39    "source": [
40     "pkl_file = \"../data/test_CA_202401.pkl\"\n",
41     "case_name = \"NV020_202401\"\n",
42     "# Destination File\n",
43     "outfile = \"../data/reproducibility_dict_v2_TEST.pkl\""
44    ]
45   },
46   {
47    "cell_type": "markdown",
48    "id": "a4c7c4b3-9652-405f-b0b9-fbc26241afdc",
49    "metadata": {},
50    "source": [
51     "## Read Data and Extract Case"
52    ]
53   },
54   {
55    "cell_type": "markdown",
56    "id": "f9378d70-5e4a-4d26-8445-4ab4c959b3bd",
57    "metadata": {},
58    "source": [
59     "### Read subdict directly"
60    ]
61   },
62   {
63    "cell_type": "code",
64    "execution_count": 3,
65    "id": "513c5f9c-5ce1-440e-85aa-bfb986e74e20",
66    "metadata": {},
67    "outputs": [
68     {
69      "name": "stdout",
70      "output_type": "stream",
71      "text": [
72       "loading file ../data/test_CA_202401.pkl\n",
73       "loc\n",
74       "      STID : NV020\n",
75       "      lat : 38.7482\n",
76       "      lon : -119.53656\n",
77       "      elev : 5677\n",
78       "      pixel_x : 268.6997896202013\n",
79       "      pixel_y : 444.2995027841032\n",
80       "RAWS\n",
81       "     temp: NumPy array of shape (1009,), min: 263.15, max: 280.506\n",
82       "     fm: NumPy array of shape (1009,), min: 8.62, max: 16.2\n",
83       "     rh: NumPy array of shape (1009,), min: 14.52, max: 93.7\n",
84       "     wind: NumPy array of shape (1009,), min: nan, max: nan\n",
85       "     time_raws: NumPy array of shape (1009,), type object\n",
86       "      hours : 1009\n",
87       "     time: NumPy array of shape (168,), type object\n",
88       "     Ed: NumPy array of shape (1009,), min: 8.049790374634147, max: 30.42933179653953\n",
89       "     Ew: NumPy array of shape (1009,), min: 6.999615289304687, max: 28.609014502317358\n",
90       "HRRR\n",
91       "     time: NumPy array of shape (168,), type object\n",
92       "     f00\n",
93       "          temp: NumPy array of shape (168,), min: 262.1491059538274, max: 278.6557143358586\n",
94       "          rh: NumPy array of shape (168,), min: 19.766470197815824, max: 88.04762200131668\n",
95       "          wind: NumPy array of shape (168,), min: 0.4422897586312999, max: 11.740318800249659\n",
96       "          rain: NumPy array of shape (168,), min: 0.0, max: 0.0\n",
97       "          precip_accum: NumPy array of shape (168,), min: 0.0, max: 0.0\n",
98       "          solar: NumPy array of shape (168,), min: 0.0, max: 531.9806798402358\n",
99       "          soilm: NumPy array of shape (168,), min: 0.04917251652964121, max: 0.12315401646785898\n",
100       "          canopyw: NumPy array of shape (168,), min: 0.0, max: 0.0\n",
101       "          groundflux: NumPy array of shape (168,), min: -131.7843148684105, max: 323.2364124714159\n",
102       "          Ed: NumPy array of shape (168,), min: 10.01036196657367, max: 27.084001303686698\n",
103       "          Ew: NumPy array of shape (168,), min: 8.847334592947469, max: 25.46199682778882\n",
104       "           descr : Source: HRRR data from 3d pressure model, linear grid interpolated to RAWS location\n",
105       "     f01\n",
106       "          temp: NumPy array of shape (168,), min: 262.62389410746255, max: 277.81355191096424\n",
107       "          rh: NumPy array of shape (168,), min: 19.11316672559181, max: 91.87131998621496\n",
108       "          wind: NumPy array of shape (168,), min: 1.2391147980713344, max: 13.282109257495303\n",
109       "          rain: NumPy array of shape (168,), min: 0.0, max: 0.00039006309462551956\n",
110       "          precip_accum: NumPy array of shape (168,), min: 0.0, max: 0.945306784924014\n",
111       "          solar: NumPy array of shape (168,), min: 0.0, max: 545.6303389769764\n",
112       "          soilm: NumPy array of shape (168,), min: 0.04868231485555667, max: 0.12315401646785898\n",
113       "          canopyw: NumPy array of shape (168,), min: 0.0, max: 0.0\n",
114       "          groundflux: NumPy array of shape (168,), min: -128.93263270355502, max: 84.48231533782075\n",
115       "          Ed: NumPy array of shape (168,), min: 10.55464038561572, max: 28.68516905374082\n",
116       "          Ew: NumPy array of shape (168,), min: 9.403674257216629, max: 26.943124103890092\n",
117       "           descr : Source: HRRR data from 3d pressure model, linear grid interpolated to RAWS location\n",
118       " case : NV020_202401\n",
119       " filename : ../data/test_CA_202401.pkl\n",
120       " title : NV020_202401\n",
121       " descr : NV020_202401 FMDA dictionary\n"
122      ]
123     }
124    ],
125    "source": [
126     "dat = load_and_fix_data(pkl_file)\n",
127     "print_dict_summary(dat[case_name])"
128    ]
129   },
130   {
131    "cell_type": "markdown",
132    "id": "40eff829-d24e-4ef9-bcb2-773112063359",
133    "metadata": {},
134    "source": [
135     "### Extract processed case"
136    ]
137   },
138   {
139    "cell_type": "code",
140    "execution_count": 4,
141    "id": "1e3403e6-de9c-47a2-af31-190e212210a1",
142    "metadata": {},
143    "outputs": [],
144    "source": [
145     "train = pkl2train([pkl_file])"
146    ]
147   },
148   {
149    "cell_type": "code",
150    "execution_count": 5,
151    "id": "e6a76831-a898-41c0-9caa-0f060f978417",
152    "metadata": {},
153    "outputs": [
154     {
155      "name": "stdout",
156      "output_type": "stream",
157      "text": [
158       " id : NV020_202401\n",
159       " case : NV020_202401\n",
160       " filename : ../data/test_CA_202401.pkl\n",
161       "loc\n",
162       "      STID : NV020\n",
163       "      lat : 38.7482\n",
164       "      lon : -119.53656\n",
165       "      elev : 5677\n",
166       "      pixel_x : 268.6997896202013\n",
167       "      pixel_y : 444.2995027841032\n",
168       " hours : 168\n",
169       " h2 : 168\n",
170       "time: NumPy array of shape (168,), type object\n",
171       " scale_fm : 1\n",
172       "X: NumPy array of shape (168, 8), min: -119.53656, max: 5677.0\n",
173       "features_list: Array of 8 items\n",
174       "y: NumPy array of shape (168,), min: 8.62, max: 15.98\n"
175      ]
176     }
177    ],
178    "source": [
179     "print_dict_summary(train[case_name])"
180    ]
181   },
182   {
183    "cell_type": "markdown",
184    "id": "4597d3f5-2833-41df-b070-f99d4f8b4ff9",
185    "metadata": {},
186    "source": [
187     "## Add Reproducibility Info"
188    ]
189   },
190   {
191    "cell_type": "code",
192    "execution_count": 6,
193    "id": "b7525bd2-5c3e-430a-87e9-3c3ec048553e",
194    "metadata": {},
195    "outputs": [
196     {
197      "data": {
198       "text/plain": [
199        "{'batch_size': 32,\n",
200        " 'timesteps': 5,\n",
201        " 'optimizer': 'adam',\n",
202        " 'rnn_layers': 1,\n",
203        " 'rnn_units': 20,\n",
204        " 'dense_layers': 1,\n",
205        " 'dense_units': 5,\n",
206        " 'activation': ['linear', 'linear'],\n",
207        " 'centering': [0.0, 0.0],\n",
208        " 'dropout': [0.2, 0.2],\n",
209        " 'recurrent_dropout': 0.2,\n",
210        " 'reset_states': True,\n",
211        " 'batch_reset': None,\n",
212        " 'epochs': 300,\n",
213        " 'learning_rate': 0.001,\n",
214        " 'clipvalue': 10.0,\n",
215        " 'phys_initialize': False,\n",
216        " 'stateful': True,\n",
217        " 'verbose_weights': True,\n",
218        " 'verbose_fit': False,\n",
219        " 'features_list': ['Ed', 'Ew', 'solar', 'wind', 'rain'],\n",
220        " 'scale': True,\n",
221        " 'scaler': 'minmax',\n",
222        " 'train_frac': 0.5,\n",
223        " 'val_frac': 0.2}"
224       ]
225      },
226      "execution_count": 6,
227      "metadata": {},
228      "output_type": "execute_result"
229     }
230    ],
231    "source": [
232     "params = read_yml('../params.yaml', subkey=\"rnn_repro\")\n",
233     "params"
234    ]
235   },
236   {
237    "cell_type": "code",
238    "execution_count": 7,
239    "id": "e820012b-6f48-4c58-bc6e-6c9811935ec9",
240    "metadata": {},
241    "outputs": [
242     {
243      "name": "stdout",
244      "output_type": "stream",
245      "text": [
246       "Checking params...\n",
247       "Input dictionary passed all checks.\n",
248       "Calculating shape params based on features list, timesteps, and batch size\n",
249       "Input Feature List: ['Ed', 'Ew', 'solar', 'wind', 'rain']\n",
250       "Input Timesteps: 5\n",
251       "Input Batch Size: 32\n",
252       "Calculated params:\n",
253       "Number of features: 5\n",
254       "Batch Shape: (32, 5, 5)\n",
255       "{'batch_size': 32, 'timesteps': 5, 'optimizer': 'adam', 'rnn_layers': 1, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 5, 'activation': ['linear', 'linear'], 'centering': [0.0, 0.0], 'dropout': [0.2, 0.2], 'recurrent_dropout': 0.2, 'reset_states': True, 'batch_reset': None, 'epochs': 300, 'learning_rate': 0.001, 'clipvalue': 10.0, 'phys_initialize': False, 'stateful': True, 'verbose_weights': True, 'verbose_fit': False, 'features_list': ['Ed', 'Ew', 'solar', 'wind', 'rain'], 'scale': True, 'scaler': 'minmax', 'train_frac': 0.5, 'val_frac': 0.2, 'n_features': 5, 'batch_shape': (32, 5, 5)}\n"
256      ]
257     }
258    ],
259    "source": [
260     "repro_info = {\n",
261     "    'phys_initialize': \"NOT YET IMPLEMENTED WITH v2.1\",\n",
262     "    'rand_initialize':{\n",
263     "        'fitted_weights_hash': '01513ac086d842dc67d40eb94ee1110c',\n",
264     "        'preds_hash': '4999d10893207f2b40086e3f84c214a3'\n",
265     "    },\n",
266     "    'env_info':{\n",
267     "        'py_version': sys.version[0:6],\n",
268     "        'tf_version': tf.__version__,\n",
269     "        'seed': 123\n",
270     "    },\n",
271     "    'params': RNNParams(params)\n",
272     "}\n",
273     "\n",
274     "train[case_name]['repro_info'] = repro_info"
275    ]
276   },
277   {
278    "cell_type": "markdown",
279    "id": "d15f56fe-a5b1-4d66-91c2-0c9a7f7929bf",
280    "metadata": {},
281    "source": [
282     "## Write Output"
283    ]
284   },
285   {
286    "cell_type": "code",
287    "execution_count": 8,
288    "id": "4cc57cd9-4dd7-488d-abd5-b253aa2cee7d",
289    "metadata": {},
290    "outputs": [
291     {
292      "name": "stdout",
293      "output_type": "stream",
294      "text": [
295       "Writing file: ../data/reproducibility_dict_v2_TEST.pkl\n"
296      ]
297     }
298    ],
299    "source": [
300     "with open(outfile, 'wb') as file:\n",
301     "    print(f\"Writing file: {outfile}\")\n",
302     "    pickle.dump(train[case_name], file)"
303    ]
304   },
305   {
306    "cell_type": "code",
307    "execution_count": null,
308    "id": "d1846862-af98-4781-bcad-e8f32a87412c",
309    "metadata": {},
310    "outputs": [],
311    "source": []
312   }
313  ],
314  "metadata": {
315   "kernelspec": {
316    "display_name": "Python 3 (ipykernel)",
317    "language": "python",
318    "name": "python3"
319   },
320   "language_info": {
321    "codemirror_mode": {
322     "name": "ipython",
323     "version": 3
324    },
325    "file_extension": ".py",
326    "mimetype": "text/x-python",
327    "name": "python",
328    "nbconvert_exporter": "python",
329    "pygments_lexer": "ipython3",
330    "version": "3.12.5"
331   }
332  },
333  "nbformat": 4,
334  "nbformat_minor": 5