From 105f2b072daf77e7c558acaa05c9682d0d70d265 Mon Sep 17 00:00:00 2001 From: jh-206 Date: Sun, 13 Oct 2024 17:34:39 -0600 Subject: [PATCH] use data wrapper --- .../fmda_rnn_test_batch_reset_schedule.ipynb | 54 ++++++++++++++-------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/fmda/test_notebooks/fmda_rnn_test_batch_reset_schedule.ipynb b/fmda/test_notebooks/fmda_rnn_test_batch_reset_schedule.ipynb index 6a1606f..1c6aff9 100644 --- a/fmda/test_notebooks/fmda_rnn_test_batch_reset_schedule.ipynb +++ b/fmda/test_notebooks/fmda_rnn_test_batch_reset_schedule.ipynb @@ -22,7 +22,7 @@ "import sys\n", "sys.path.append('..')\n", "from moisture_rnn_pkl import pkl2train\n", - "from moisture_rnn import RNNParams, RNNData, RNN \n", + "from moisture_rnn import RNNParams, RNNData, RNN, rnn_data_wrap\n", "from utils import hash2, read_yml, read_pkl, retrieve_url, print_dict_summary, print_first, str2time, logging_setup\n", "from moisture_rnn import RNN\n", "import reproducibility\n", @@ -54,7 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "filename = \"fmda_nw_202401-05_f05.pkl\"\n", + "filename = \"fmda_rocky_202403-05_f05.pkl\"\n", "retrieve_url(\n", " url = f\"https://demo.openwfm.org/web/data/fmda/dicts/{filename}\", \n", " dest_path = f\"../data/{filename}\")" @@ -93,7 +93,7 @@ "outputs": [], "source": [ "params_data.update({\n", - " 'hours': 3000,\n", + " 'hours': 2205,\n", " 'max_intp_time': 12,\n", " 'zero_lag_threshold': 12\n", "})\n", @@ -131,7 +131,8 @@ " 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n", " 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n", " 'timesteps': 12,\n", - " 'batch_size': 50\n", + " 'batch_size': 50,\n", + " 'space_fracs': [.8, .1, .1]\n", " })" ] }, @@ -150,26 +151,39 @@ "metadata": {}, "outputs": [], "source": [ - "# train_sp = combine_nested(train)\n", - "rnn_dat = RNNData(\n", - " train, # input dictionary\n", - " scaler=\"standard\", # data scaling type\n", - " features_list = params['features_list'] # features for predicting outcome\n", - ")\n", + "# # train_sp = combine_nested(train)\n", + "# rnn_dat = RNNData(\n", + "# train, # input dictionary\n", + "# scaler=\"standard\", # data scaling type\n", + "# features_list = params['features_list'] # features for predicting outcome\n", + "# )\n", "\n", "\n", - "rnn_dat.train_test_split( \n", - " time_fracs = [.9, .05, .05], # Percent of total time steps used for train/val/test\n", - " space_fracs = [.40, .30, .30] # Percent of total timeseries used for train/val/test\n", - ")\n", - "rnn_dat.scale_data()\n", + "# rnn_dat.train_test_split( \n", + "# time_fracs = [.9, .05, .05], # Percent of total time steps used for train/val/test\n", + "# space_fracs = [.40, .30, .30] # Percent of total timeseries used for train/val/test\n", + "# )\n", + "# rnn_dat.scale_data()\n", "\n", - "rnn_dat.batch_reshape(\n", - " timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n", - " batch_size = params['batch_size'], # Number of samples of length timesteps for a single round of grad. descent\n", - " start_times = np.zeros(len(rnn_dat.loc['train_locs']))\n", - ")\n", + "# rnn_dat.batch_reshape(\n", + "# timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n", + "# batch_size = params['batch_size'], # Number of samples of length timesteps for a single round of grad. descent\n", + "# start_times = np.zeros(len(rnn_dat.loc['train_locs']))\n", + "# )\n", "\n", + "# params.update({\n", + "# 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n", + "# })" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adbba43e-603b-4801-8a35-35b8ccc053af", + "metadata": {}, + "outputs": [], + "source": [ + "rnn_dat = rnn_data_wrap(train, params)\n", "params.update({\n", " 'loc_batch_reset': rnn_dat.n_seqs # Used to reset hidden state when location changes for a given batch\n", "})" -- 2.11.4.GIT