From c0daaaa548fa380d12457a4cd21ba4502c640eb9 Mon Sep 17 00:00:00 2001 From: jh-206 Date: Mon, 17 Jun 2024 09:58:26 -0600 Subject: [PATCH] Update rnn_workshop.ipynb --- fmda/rnn_workshop.ipynb | 103 +++++++++--------------------------------------- 1 file changed, 19 insertions(+), 84 deletions(-) diff --git a/fmda/rnn_workshop.ipynb b/fmda/rnn_workshop.ipynb index 15e231c..5ab52ac 100644 --- a/fmda/rnn_workshop.ipynb +++ b/fmda/rnn_workshop.ipynb @@ -18,7 +18,7 @@ "import reproducibility\n", "from utils import print_dict_summary\n", "from data_funcs import load_and_fix_data, rmse\n", - "from moisture_rnn import RNN, create_rnn_data2\n", + "from moisture_rnn import RNN, RNN_LSTM, create_rnn_data2\n", "from moisture_rnn_pkl import pkl2train\n", "from tensorflow.keras.callbacks import Callback\n", "from sklearn.metrics import mean_squared_error\n", @@ -164,10 +164,27 @@ { "cell_type": "code", "execution_count": null, - "id": "88a73c40-0bab-4aa0-8265-00f427aa97ea", + "id": "71d4e441-9bf1-4d57-bb37-091553e23212", "metadata": {}, "outputs": [], "source": [ + "import importlib \n", + "import moisture_rnn\n", + "importlib.reload(moisture_rnn)\n", + "from moisture_rnn import RNN_LSTM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88a73c40-0bab-4aa0-8265-00f427aa97ea", + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ "from moisture_rnn import RNN_LSTM\n", "\n", "# from tensorflow.keras.layers import LSTM, Input, Dropout, Dense, SimpleRNN\n", @@ -388,76 +405,6 @@ { "cell_type": "code", "execution_count": null, - "id": "89f3ee62-4bef-4eb9-a599-405beaa0632d", - "metadata": {}, - "outputs": [], - "source": [ - "rnn_dat = create_rnn_data2(train['reproducibility'],params2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ac7b05a3-8788-4fcb-9a8b-3c8c28426803", - "metadata": {}, - "outputs": [], - "source": [ - "# from tensorflow.keras.layers import LSTM, Input, Dropout, Dense\n", - "reproducibility.set_seed()\n", - "params2.update({'epochs': 50})\n", - "lstm = RNN_LSTM(params2)\n", - "lstm.fit(rnn_dat[\"X_train\"], rnn_dat[\"y_train\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f51fbc8c-9b0f-42c9-bfbb-5f4c0cacaa27", - "metadata": {}, - "outputs": [], - "source": [ - "lstm = RNN_LSTM(params2)\n", - "lstm.fit(rnn_dat[\"X_train\"], rnn_dat[\"y_train\"], \n", - " validation_data = (rnn_dat['X_val'], rnn_dat['y_val']))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1df67d38-fbe0-4e3b-9dd9-ec25d8378e4b", - "metadata": {}, - "outputs": [], - "source": [ - "# from moisture_rnn import repro_hashes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "51ecf9fd-d026-4edf-91ba-6fd050dbd1d0", - "metadata": {}, - "outputs": [], - "source": [ - "lstm = RNN_LSTM(params2)\n", - "lstm.run_model(rnn_dat)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "71d4e441-9bf1-4d57-bb37-091553e23212", - "metadata": {}, - "outputs": [], - "source": [ - "import importlib \n", - "import moisture_rnn\n", - "importlib.reload(moisture_rnn)\n", - "from moisture_rnn import RNN_LSTM" - ] - }, - { - "cell_type": "code", - "execution_count": null, "id": "59480f19-3567-4b24-b6ff-d9292dc8c2ec", "metadata": {}, "outputs": [], @@ -481,18 +428,6 @@ { "cell_type": "code", "execution_count": null, - "id": "95d7ae31-e3fb-4a44-95bd-093ec34e0ce5", - "metadata": {}, - "outputs": [], - "source": [ - "reproducibility.set_seed()\n", - "lstm = RNN_LSTM(params)\n", - "lstm.fit(rnn_dat[\"X_train\"], rnn_dat[\"y_train\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, "id": "6a9d612e-8cd2-40ca-a789-91c99c3d6ccd", "metadata": {}, "outputs": [], -- 2.11.4.GIT