From 1a6ebb4c83306ebcfbfd0b7c15aacd4073048d2e Mon Sep 17 00:00:00 2001 From: jh-206 Date: Tue, 22 Oct 2024 16:38:33 -0600 Subject: [PATCH] Update fmda_rnn_train_and_save.ipynb --- fmda/test_notebooks/fmda_rnn_train_and_save.ipynb | 25 ++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/fmda/test_notebooks/fmda_rnn_train_and_save.ipynb b/fmda/test_notebooks/fmda_rnn_train_and_save.ipynb index 354da31..8b396b0 100644 --- a/fmda/test_notebooks/fmda_rnn_train_and_save.ipynb +++ b/fmda/test_notebooks/fmda_rnn_train_and_save.ipynb @@ -5,9 +5,9 @@ "id": "83b774b3-ef55-480a-b999-506676e49145", "metadata": {}, "source": [ - "# v2.2 run RNN with Spatial Training\n", + "# v2.3 run RNN and Save\n", "\n", - "This notebook is intended to set up a test where the RNN is run serial by location and compared to the spatial training scheme. Additionally, the ODE model with the augmented KF will be run as a comparison, but note that the RNN models will be predicting entirely without knowledge of the heldout locations, while the augmented KF will be run directly on the test locations.\n" + "This notebook is intended to test traing and then saving a model object for later use.\n" ] }, { @@ -114,7 +114,11 @@ "source": [ "# Params used for setting up RNN\n", "params = read_yml(\"../params.yaml\", subkey='rnn') \n", - "params" + "params.update({\n", + " 'hidden_layers': ['dense', 'lstm', 'attention', 'dense'],\n", + " 'hidden_units': [64, 32, None, 32],\n", + " 'hidden_activation': ['relu', 'tanh', None, 'relu']\n", + "})" ] }, { @@ -262,7 +266,7 @@ "id": "62c1b049-304e-4c90-b1d2-b9b96b9a202f", "metadata": {}, "source": [ - "## Save Model" + "## Save " ] }, { @@ -274,7 +278,18 @@ "source": [ "outpath = \"../outputs/models\"\n", "filename = osp.join(outpath, f\"model_predict_raws_rocky.keras\")\n", - "rnn_sp.model_predict.save(filename)" + "rnn_sp.model_predict.save(filename) # save prediction model only" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf6231a9-0c7b-45ba-ac75-7fb5b6124c72", + "metadata": {}, + "outputs": [], + "source": [ + "with open(f\"{outpath}/rnn_data_rocky.pkl\", 'wb') as file:\n", + " pickle.dump(rnn_dat_sp, file)" ] }, { -- 2.11.4.GIT