4 "cell_type": "markdown",
5 "id": "83b774b3-ef55-480a-b999-506676e49145",
8 "# Compare Batch Resetting Schedules\n"
13 "execution_count": null,
14 "id": "83cc1dc4-3dcb-4325-9263-58101a3dc378",
18 "import numpy as np\n",
19 "from utils import print_dict_summary, print_first, str2time, logging_setup\n",
22 "import os.path as osp\n",
23 "from moisture_rnn_pkl import pkl2train\n",
24 "from moisture_rnn import RNNParams, RNNData, RNN \n",
25 "from utils import hash2, read_yml, read_pkl, retrieve_url\n",
26 "from moisture_rnn import RNN\n",
27 "import reproducibility\n",
28 "from data_funcs import rmse, to_json, combine_nested, process_train_dict\n",
29 "from moisture_models import run_augmented_kf\n",
31 "import pandas as pd\n",
32 "import matplotlib.pyplot as plt\n",
35 "import reproducibility\n",
36 "import tensorflow as tf"
41 "execution_count": null,
42 "id": "17db9b90-a931-4674-a447-5b8ffbcdc86a",
51 "execution_count": null,
52 "id": "35319c1c-7849-4b8c-8262-f5aa6656e0c7",
57 "    url = \"https://demo.openwfm.org/web/data/fmda/dicts/fmda_nw_202401-05_f05.pkl\", \n",
58 " dest_path = \"fmda_nw_202401-05_f05.pkl\")"
63 "execution_count": null,
64 "id": "eabdbd9c-07d9-4bae-9851-cca79f321895",
68 "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n",
69 "file_names=['fmda_nw_202401-05_f05.pkl']\n",
71 "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]"
76 "execution_count": null,
77 "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26",
81 "params = RNNParams(read_yml(\"params.yaml\", subkey='rnn'))\n",
82 "params_data = read_yml(\"params_data.yaml\")"
87 "execution_count": null,
88 "id": "5f3d09b4-c2fd-4556-90b7-e547431ca523",
94 "data_params = read_yml(\"params_data.yaml\")\n",
95 "data_params.update({\n",
97 " 'max_intp_time': 24,\n",
98 " 'zero_lag_threshold': 24\n",
100 "train = process_train_dict([\"data/fmda_nw_202401-05_f05.pkl\"], params_data=params_data, verbose=True)"
105 "execution_count": null,
106 "id": "4b0c9a9b-dd02-4251-aa4a-2acc1101e153",
110 "from itertools import islice\n",
111 "train = {k: train[k] for k in islice(train, 250)}"
116 "execution_count": null,
117 "id": "35ae0bdb-a209-429f-8116-c5e1dccafb89",
121 "## params = RNNParams(read_yml(\"params.yaml\", subkey=\"rnn\"))\n",
122 "params.update({'epochs': 200, \n",
123 " 'learning_rate': 0.001,\n",
124 " 'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n",
125 " 'recurrent_layers': 1, 'recurrent_units': 30, \n",
126 " 'dense_layers': 1, 'dense_units': 30,\n",
127 "                    'early_stopping_patience': 30, # epochs with no validation improvement to wait before stopping\n",
128 " 'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n",
129 "                    'bmin': 20, # Lower bound of hidden state batch reset\n",
130 " 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n",
131 " 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
132 " 'timesteps': 12\n",
137 "cell_type": "markdown",
138 "id": "d6751dcc-ba4c-47d5-90d2-60f4a61e96fa",
146 "execution_count": null,
147 "id": "cc23e05f-282d-424b-9a50-6e32d9ae4095",
151 "train_sp = combine_nested(train)\n",
152 "rnn_dat = RNNData(\n",
153 " train_sp, # input dictionary\n",
154 " scaler=\"standard\", # data scaling type\n",
155 " features_list = params['features_list'] # features for predicting outcome\n",
159 "rnn_dat.train_test_split( \n",
160 " time_fracs = [.9, .05, .05], # Percent of total time steps used for train/val/test\n",
161 "    space_fracs = [.40, .30, .30] # Percent of spatial locations (timeseries) used for train/val/test\n",
163 "rnn_dat.scale_data()\n",
165 "rnn_dat.batch_reshape(\n",
166 " timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n",
167 " batch_size = params['batch_size'], # Number of samples of length timesteps for a single round of grad. descent\n",
168 " start_times = np.zeros(len(rnn_dat.loc['train_locs']))\n",
172 " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n",
177 "cell_type": "markdown",
178 "id": "703dca05-5371-409e-b0a6-c430594bb76f",
186 "execution_count": null,
187 "id": "4eaaf547-1967-4325-be61-b5e8ed33141f",
192 " 'stateful': False,\n",
193 " 'batch_schedule_type': None\n",
199 "execution_count": null,
200 "id": "bfebe87d-1bbb-48c5-9b32-836d19d16787",
204 "reproducibility.set_seed(123)\n",
205 "rnn = RNN(params)\n",
206 "m0, errs0, epochs0 = rnn.run_model(rnn_dat, return_epochs=True)"
211 "execution_count": null,
212 "id": "25581ed5-7fff-402f-b902-ed32bbcf1c0c",
220 "cell_type": "markdown",
221 "id": "80026da4-e8ec-4803-b791-66110c4b10d9",
224 "## Constant Batch Schedule (Stateful)"
229 "execution_count": null,
230 "id": "e7335f3d-3d8a-4733-9105-faed311a7df3",
235 " 'stateful': True, \n",
236 " 'batch_schedule_type':'constant', \n",
242 "execution_count": null,
243 "id": "e1febc02-35af-4325-ad02-4f6a2ce065fd",
247 "reproducibility.set_seed(123)\n",
248 "rnn = RNN(params)\n",
249 "m2, errs2, epochs2 = rnn.run_model(rnn_dat, return_epochs=True)"
254 "execution_count": null,
255 "id": "4ad5e3cc-a4a4-4a54-9bc3-9b8b88b454ee",
263 "cell_type": "markdown",
264 "id": "7082eeb3-5173-4226-85db-4bb5a26f67a4",
267 "## Exponential ('exp') Batch Schedule (Stateful)"
272 "execution_count": null,
273 "id": "2a428391-400d-476b-990e-c5e1d9cba7f8",
278 " 'stateful': True, \n",
279 " 'batch_schedule_type':'exp', \n",
281 " 'bmax': rnn_dat.hours\n",
287 "execution_count": null,
288 "id": "9f77ee03-01f8-415f-9748-b986a77f2982",
292 "reproducibility.set_seed(123)\n",
293 "rnn = RNN(params)\n",
294 "m3, errs3, epochs3 = rnn.run_model(rnn_dat, return_epochs=True)"
299 "execution_count": null,
300 "id": "d2b1eaf5-b710-431a-a10e-0098f713c325",
310 "display_name": "Python 3 (ipykernel)",
311 "language": "python",
319 "file_extension": ".py",
320 "mimetype": "text/x-python",
322 "nbconvert_exporter": "python",
323 "pygments_lexer": "ipython3",