From 8004bbf940c1c152e6ee6fbc6f7be76909d79c61 Mon Sep 17 00:00:00 2001 From: jh-206 Date: Wed, 18 Sep 2024 13:27:01 -0600 Subject: [PATCH] Update fmda_rnn_spatial.ipynb --- fmda/fmda_rnn_spatial.ipynb | 356 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 286 insertions(+), 70 deletions(-) diff --git a/fmda/fmda_rnn_spatial.ipynb b/fmda/fmda_rnn_spatial.ipynb index ff0c689..4ccc244 100644 --- a/fmda/fmda_rnn_spatial.ipynb +++ b/fmda/fmda_rnn_spatial.ipynb @@ -35,7 +35,7 @@ "from utils import hash2, read_yml, read_pkl, retrieve_url, Dict\n", "from moisture_rnn import RNN\n", "import reproducibility\n", - "from data_funcs import rmse, to_json, combine_nested\n", + "from data_funcs import rmse, to_json, combine_nested, process_train_dict\n", "from moisture_models import run_augmented_kf\n", "import copy\n", "import pandas as pd\n", @@ -62,8 +62,8 @@ "outputs": [], "source": [ "retrieve_url(\n", - " url = \"https://demo.openwfm.org/web/data/fmda/dicts/test_CA_202401.pkl\", \n", - " dest_path = \"fmda_nw_202401-05_f05.pkl\")" + " url = \"https://demo.openwfm.org/web/data/fmda/dicts/fmda_nw_202401-05_f05.pkl\", \n", + " dest_path = \"data/fmda_nw_202401-05_f05.pkl\")" ] }, { @@ -73,10 +73,7 @@ "metadata": {}, "outputs": [], "source": [ - "repro_file = \"data/reproducibility_dict_v2_TEST.pkl\"\n", - "file_names=['fmda_nw_202401-05_f05.pkl']\n", - "file_dir='data'\n", - "file_paths = [osp.join(file_dir,file_name) for file_name in file_names]" + "file_paths = ['data/fmda_nw_202401-05_f05.pkl']" ] }, { @@ -87,27 +84,49 @@ "outputs": [], "source": [ "# read/write control\n", - "train_file='train.pkl'\n", - "train_create=False # if false, read\n", - "train_write=False\n", + "train_file='data/train.pkl'\n", + "train_create=True # if false, read\n", + "train_write=True\n", "train_read=True" ] }, { "cell_type": "code", "execution_count": null, + "id": "604388de-11ab-45c3-9f0d-80bdff0cca60", + "metadata": {}, + "outputs": [], + "source": [ + "# Params used for data filtering\n", + "params_data = read_yml(\"params_data.yaml\") \n", + "params_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26", + "metadata": {}, + "outputs": [], + "source": [ + "# Params used for setting up RNN\n", + "params = read_yml(\"params.yaml\", subkey='rnn') \n", + "params" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "bc0a775b-b587-42ef-8576-e36dc0be3a75", "metadata": { "scrolled": true }, "outputs": [], "source": [ - "repro = read_pkl(repro_file)\n", - "\n", "if train_create:\n", " logging.info('creating the training cases from files %s',file_paths)\n", " # osp.join works on windows too, joins paths using \\ or /\n", - " train = pkl2train(file_paths)\n", + " train = process_train_dict(file_paths, params_data = params_data, verbose=True)\n", "if train_write:\n", " with open(train_file, 'wb') as file:\n", " logging.info('Writing the rain cases into file %s',train_file)\n", @@ -120,37 +139,51 @@ { "cell_type": "code", "execution_count": null, - "id": "211a1c2f-ba8d-40b8-b29c-daa38af97a26", + "id": "23cd60c0-9865-4314-9a96-948c3d400c08", "metadata": {}, "outputs": [], "source": [ - "params = read_yml(\"params.yaml\", subkey='rnn')\n", - "params" + "from itertools import islice\n", + "train = {k: train[k] for k in islice(train, 150)}" + ] + }, + { + "cell_type": "markdown", + "id": "efc10cdc-f18b-4781-84da-b8e2eef39981", + "metadata": {}, + "source": [ + "## Setup Validation Runs" ] }, { "cell_type": "code", "execution_count": null, - "id": "78cf4dbc-4e7d-4c6d-ac2e-0bac513f92dd", + "id": "66f40c9f-c1c2-4b12-bf14-2ada8c26113d", "metadata": {}, "outputs": [], "source": [ - "# from itertools import islice\n", - "# train = {k: train[k] for k in islice(train, 100)}\n", - "dat = Dict(combine_nested(train))" + "params = RNNParams(params)\n", + "params.update({'epochs': 200, \n", + " 'learning_rate': 0.001,\n", + " 'activation': ['tanh', 'tanh'], # Activation for RNN Layers, Dense layers respectively.\n", + " 'recurrent_layers': 2, 'recurrent_units': 30, \n", + " 'dense_layers': 2, 'dense_units': 30,\n", + " 'early_stopping_patience': 30, # how many epochs of no validation accuracy gain to wait before stopping\n", + " 'batch_schedule_type': 'exp', # Hidden state batch reset schedule\n", + " 'bmin': 20, # Lower bound of hidden state batch reset, \n", + " 'bmax': params_data['hours'], # Upper bound of hidden state batch reset, using max hours\n", + " 'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind']\n", + " })" ] }, { "cell_type": "code", "execution_count": null, - "id": "e11e7c83-183f-48ba-abd8-a6aedff66090", + "id": "36823193-b93c-421e-b699-8c1ae5719309", "metadata": {}, "outputs": [], "source": [ - "# Set up output dictionaries\n", - "outputs_kf = {}\n", - "outputs_rnn_serial = {}\n", - "outputs_rnn_spatial = {}" + "reproducibility.set_seed(123)" ] }, { @@ -158,28 +191,29 @@ "id": "a24d76fc-6c25-43e7-99df-3cd5dbf84fc3", "metadata": {}, "source": [ - "## Spatial Data Traing" + "## Spatial Data Training" ] }, { "cell_type": "code", "execution_count": null, - "id": "c58f9f89-46d8-407c-be8b-8e5f16dbcc51", + "id": "3b5371a9-c1e8-4df5-b360-210746f7cd52", "metadata": {}, "outputs": [], "source": [ - "params = RNNParams(params)" + "# Start timer for code \n", + "start_time = time.time()" ] }, { "cell_type": "code", "execution_count": null, - "id": "3b5371a9-c1e8-4df5-b360-210746f7cd52", + "id": "faf93470-b55f-4770-9fa9-3288a2f13fcc", "metadata": {}, "outputs": [], "source": [ - "# Start timer\n", - "start_time = time.time()" + "# Combine Nested Dictionary into Spatial Data\n", + "train_sp = Dict(combine_nested(train))" ] }, { @@ -189,35 +223,36 @@ "metadata": {}, "outputs": [], "source": [ - "rnn_dat = RNNData(dat, scaler=\"standard\", \n", - " features_list = ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat',\n", - " 'solar', 'wind'])\n", + "rnn_dat_sp = RNNData(\n", + " train_sp, # input dictionary\n", + " scaler=\"standard\", # data scaling type\n", + " features_list = params['features_list'] # features for predicting outcome\n", + ")\n", "\n", - "rnn_dat.train_test_split( \n", - " time_fracs = [.9, .05, .05],\n", - " space_fracs = [.6, .2, .2]\n", + "\n", + "rnn_dat_sp.train_test_split( \n", + " time_fracs = [.8, .1, .1], # Percent of total time steps used for train/val/test\n", + " space_fracs = [.8, .1, .1] # Percent of total timeseries used for train/val/test\n", ")\n", - "rnn_dat.scale_data()\n", + "rnn_dat_sp.scale_data()\n", "\n", - "rnn_dat.batch_reshape(\n", - " timesteps = params['timesteps'], \n", - " batch_size = params['batch_size']\n", + "rnn_dat_sp.batch_reshape(\n", + " timesteps = params['timesteps'], # Timesteps aka sequence length for RNN input data. \n", + " batch_size = params['batch_size'] # Number of samples of length timesteps for a single round of grad. descent\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "59ddf393-2024-4093-927f-69f135a165b8", + "id": "7431bc95-d384-40fd-a622-bbc0ee68e5cd", "metadata": {}, "outputs": [], "source": [ - "params.update({'batch_schedule_type': 'exp', 'bmin': 20, 'bmax': rnn_dat.hours,\n", - " 'loc_batch_reset': rnn_dat.n_seqs, \n", - " 'epochs': 100, 'learning_rate': 0.0001,\n", - " 'activation': ['tanh', 'tanh'],\n", - " 'recurrent_layers': 2, 'recurrent_units': 20, 'dense_layers': 1, 'dense_units': 20,\n", - " 'features_list': rnn_dat.features_list})" + "# Update Params specific to spatial training\n", + "params.update({\n", + " 'loc_batch_reset': rnn_dat_sp.n_seqs # Used to reset hidden state when location changes for a given batch\n", + "})" ] }, { @@ -227,9 +262,8 @@ "metadata": {}, "outputs": [], "source": [ - "reproducibility.set_seed(123)\n", - "rnn = RNN(params)\n", - "m, errs = rnn.run_model(rnn_dat)" + "rnn_sp = RNN(params)\n", + "m, errs = rnn_sp.run_model(rnn_dat_sp)" ] }, { @@ -253,8 +287,8 @@ "end_time = time.time()\n", "\n", "# Calculate Code Runtime\n", - "elapsed_time = end_time - start_time\n", - "print(f\"Spatial Training Elapsed time: {elapsed_time:.4f} seconds\")" + "elapsed_time_sp = end_time - start_time\n", + "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")" ] }, { @@ -268,6 +302,29 @@ { "cell_type": "code", "execution_count": null, + "id": "cca12d8c-c0e1-4df4-b2ca-20440485f2f3", + "metadata": {}, + "outputs": [], + "source": [ + "# Get timeseries IDs from previous RNNData object\n", + "test_cases = rnn_dat_sp.loc['test_locs']\n", + "print(len(test_cases))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "997f2534-7e77-45b3-93bf-d988837dfc0b", + "metadata": {}, + "outputs": [], + "source": [ + "test_ind = rnn_dat_sp.test_ind # Time index for test period start\n", + "print(test_ind)" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "1e4ffc68-c775-41c6-ac42-f49c76824b43", "metadata": { "scrolled": true @@ -275,18 +332,18 @@ "outputs": [], "source": [ "outputs_kf = {}\n", - "for case in rnn_dat.loc['test_locs']:\n", + "for case in test_cases:\n", " print(\"~\"*50)\n", " print(case)\n", " # Run Augmented KF\n", " print('Running Augmented KF')\n", - " train[case]['h2'] = train[case]['hours'] // 2\n", + " train[case]['h2'] = test_ind\n", " train[case]['scale_fm'] = 1\n", " m, Ec = run_augmented_kf(train[case])\n", " y = train[case]['y'] \n", - " train[case]['m'] = m\n", - " print(f\"KF RMSE: {rmse(m,y)}\")\n", - " outputs_kf[case] = {'case':case, 'errs': rmse(m,y)}" + " train[case]['m_kf'] = m\n", + " print(f\"KF RMSE: {rmse(m[test_ind:],y[test_ind:])}\")\n", + " outputs_kf[case] = {'case':case, 'errs': rmse(m[test_ind:],y[test_ind:])}" ] }, { @@ -296,56 +353,204 @@ "metadata": {}, "outputs": [], "source": [ - "df2 = pd.DataFrame.from_dict(outputs_kf).transpose()\n", - "df2.head()" + "df_kf = pd.DataFrame.from_dict(outputs_kf).transpose()\n", + "df_kf.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25a9d2fe-83f7-4ef3-a04b-14c970b6e2ba", + "metadata": {}, + "outputs": [], + "source": [ + "df_kf.errs.mean()" ] }, { "cell_type": "markdown", - "id": "86795281-f8ea-4141-81ea-c53fae830e80", + "id": "f616bbf8-d89e-4c5b-9e47-59f02246b6f2", "metadata": {}, "source": [ - "## Compare" + "## Serial Training" ] }, { "cell_type": "code", "execution_count": null, - "id": "508a6392-49bc-4471-ad8e-814f60119283", + "id": "6fa20e9f-604a-4938-ab68-b71fbb7326df", "metadata": {}, "outputs": [], "source": [ - "df2.errs.mean()" + "# Start timer for code \n", + "start_time = time.time()" ] }, { "cell_type": "code", "execution_count": null, - "id": "73e8ca05-d17b-4e72-8def-fa77664e7bb0", + "id": "f033e78c-a506-4508-a23c-8e6574014872", "metadata": {}, "outputs": [], "source": [ - "df2.shape" + "# Update Params specific to Serial training\n", + "params.update({\n", + " 'loc_batch_reset': None, # Used to reset hidden state when location changes for a given batch\n", + " 'epochs': 2 # less epochs since fit will be run multiple times over locations\n", + "})" ] }, { "cell_type": "code", "execution_count": null, - "id": "104ea555-1a88-4293-b2a6-dd870fb4b1ed", + "id": "ff1788ec-081b-403f-bcfa-b625f0e3dbe1", "metadata": {}, "outputs": [], "source": [ - "errs.shape" + "train_cases = rnn_dat_sp.loc['train_locs']\n", + "test_cases = rnn_dat_sp.loc['test_locs']" ] }, { "cell_type": "code", "execution_count": null, - "id": "dc1d5cd6-2321-43b2-ab88-7f44806dc73f", + "id": "8a2af45e-e81b-421f-b940-e8779177dd5d", "metadata": {}, "outputs": [], "source": [ - "errs.mean()" + "# Initialize Model with first train case\n", + "rnn_dat = RNNData(train[train_cases[0]], params['scaler'], params['features_list'])\n", + "rnn_dat.train_test_split(\n", + " time_fracs = [.8, .1, .1]\n", + ")\n", + "rnn_dat.scale_data()\n", + "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac6fecc2-f614-4506-b5f9-05a6eca3b62e", + "metadata": {}, + "outputs": [], + "source": [ + "reproducibility.set_seed()\n", + "rnn = RNN(params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79b5af30-7d52-410c-9595-e89e9756fd38", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Train\n", + "for case in train_cases:\n", + " print(\"~\"*50)\n", + " print(f\"Training with Case {case}\")\n", + " rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n", + " rnn_dat_temp.train_test_split(\n", + " time_fracs = [.8, .1, .1]\n", + " )\n", + " rnn_dat_temp.scale_data()\n", + " rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])\n", + " rnn.fit(rnn_dat_temp['X_train'], rnn_dat_temp['y_train'],\n", + " validation_data=(rnn_dat_temp['X_val'], rnn_dat_temp['y_val'])) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03d716b4-0ff5-4b80-a241-440543ba9b46", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Predict\n", + "outputs_rnn_serial = {}\n", + "test_ind = rnn_dat.test_ind\n", + "for i, case in enumerate(test_cases):\n", + " print(\"~\"*50)\n", + " rnn_dat_temp = RNNData(train[case], params['scaler'], params['features_list'])\n", + " rnn_dat_temp.train_test_split(\n", + " time_fracs = [.8, .1, .1]\n", + " )\n", + " rnn_dat_temp.scale_data()\n", + " rnn_dat_temp.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size']) \n", + " X_temp = rnn_dat_temp.scale_all_X()\n", + " m = rnn.predict(X_temp)\n", + " outputs_rnn_serial[case] = {'case':case, 'errs': rmse(m[test_ind:], rnn_dat.y_test)}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5a80bae-fe1a-4ec9-b9ac-31d540eaba40", + "metadata": {}, + "outputs": [], + "source": [ + "df_rnn_serial = pd.DataFrame.from_dict(outputs_rnn_serial).transpose()\n", + "df_rnn_serial.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c5b866e-c2bf-4bc1-8f6f-3ba8a9448d07", + "metadata": {}, + "outputs": [], + "source": [ + "df_rnn_serial.errs.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5a364cb-01bf-49ad-a704-5aa3c9564967", + "metadata": {}, + "outputs": [], + "source": [ + "# End Timer\n", + "end_time = time.time()\n", + "\n", + "# Calculate Code Runtime\n", + "elapsed_time_ser = end_time - start_time\n", + "print(f\"Serial Training Elapsed time: {elapsed_time_ser:.4f} seconds\")" + ] + }, + { + "cell_type": "markdown", + "id": "86795281-f8ea-4141-81ea-c53fae830e80", + "metadata": {}, + "source": [ + "## Compare" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "508a6392-49bc-4471-ad8e-814f60119283", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Total Test Cases: {len(test_cases)}\")\n", + "print(f\"Total Test Hours: {rnn_dat_temp.y_test.shape[0]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73e8ca05-d17b-4e72-8def-fa77664e7bb0", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Spatial Training RMSE: {errs.mean()}\")\n", + "print(f\"Serial Training RMSE: {df_rnn_serial.errs.mean()}\")\n", + "print(f\"Augmented KF RMSE: {df_kf.errs.mean()}\")" ] }, { @@ -362,6 +567,17 @@ "id": "272bfb32-e8e2-49dd-8f90-4b5b09c3a2a2", "metadata": {}, "outputs": [], + "source": [ + "print(f\"Spatial Training Elapsed time: {elapsed_time_sp:.4f} seconds\")\n", + "print(f\"Serial Training Elapsed time: {elapsed_time_ser:.4f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38ab08fb-ac97-45be-8907-6f9cd124243b", + "metadata": {}, + "outputs": [], "source": [] } ], -- 2.11.4.GIT