From 4042a8f17c3db43cb8c53ccd20931c762d60b9bf Mon Sep 17 00:00:00 2001 From: jh-206 Date: Wed, 21 Aug 2024 09:58:14 -0600 Subject: [PATCH] Update rnn_workshop.ipynb --- fmda/rnn_workshop.ipynb | 84 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 79 insertions(+), 5 deletions(-) diff --git a/fmda/rnn_workshop.ipynb b/fmda/rnn_workshop.ipynb index 348a4c3..176774c 100644 --- a/fmda/rnn_workshop.ipynb +++ b/fmda/rnn_workshop.ipynb @@ -140,6 +140,7 @@ "outputs": [], "source": [ "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n", + "from utils import ndarray_hash, weights_hash\n", "\n", "scalers = {\n", " 'minmax': MinMaxScaler(),\n", @@ -219,7 +220,7 @@ "\n", " # Check for any potential issues with indices\n", " if test_ind > self.hours:\n", - " print(\"Setting test index to {self.hours}\")\n", + " print(f\"Setting test index to {self.hours}\")\n", " test_ind = self.hours\n", " if train_ind >= test_ind:\n", " raise ValueError(\"Train index must be less than test index.\") \n", @@ -257,6 +258,11 @@ " if hasattr(self, 'X_val'):\n", " self.X_val = self.scaler.transform(self.X_val)\n", " self.X_test = self.scaler.transform(self.X_test)\n", + " def print_hashes(self, attrs_to_check = ['X', 'y', 'X_train', 'y_train', 'X_val', 'y_val', 'X_test', 'y_test']):\n", + " for attr in attrs_to_check:\n", + " if hasattr(self, attr):\n", + " value = getattr(self, attr)\n", + " print(f\"Hash of {attr}: {ndarray_hash(value)}\") \n", " def __getattr__(self, key):\n", " try:\n", " return self[key]\n", @@ -319,11 +325,11 @@ { "cell_type": "code", "execution_count": null, - "id": "2659bb7f-b961-4721-8512-c609014daa02", + "id": "4a6f9a1a-1859-4fca-aee9-d330286f0e2f", "metadata": {}, "outputs": [], "source": [ - "d.X" + "d.print_hashes()" ] }, { @@ -339,6 +345,16 @@ { "cell_type": "code", "execution_count": null, + "id": "2cee12b2-931b-4f3f-ae05-f4bc23b9df37", + "metadata": {}, + "outputs": [], + "source": [ + "d.print_hashes()" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "becd6c36-cf9b-43f0-9f08-62dbe1a6ac5e", "metadata": {}, "outputs": [], @@ -347,6 +363,16 @@ ] }, { + "cell_type": "code", + "execution_count": null, + "id": "5c5b0907-ace7-448a-a19d-55647dbcbaf8", + "metadata": {}, + "outputs": [], + "source": [ + "d.print_hashes()" + ] + }, + { "cell_type": "markdown", "id": "2afc2cf7-eab1-4a85-8632-4d306aead358", "metadata": {}, @@ -383,9 +409,9 @@ "metadata": {}, "outputs": [], "source": [ - "params.update({'val_frac': .2, 'scale': True, 'scaler': 'standard', 'epochs': 200})\n", + "params.update({'val_frac': .2, 'scale': True, 'scaler': 'minmax', 'epochs': 500})\n", "# params.update({'features_list': ['wind', 'Ed', 'Ew', 'solar', 'rain']})\n", - "params.update({'rnn_layers': 3})\n", + "params.update({'rnn_layers': 1, 'dense_layers': 1})\n", "rnn_dat = create_rnn_data2(train[case], params)" ] }, @@ -404,6 +430,54 @@ { "cell_type": "code", "execution_count": null, + "id": "bfd419f0-9092-470d-81b7-d3b45e4bdc0b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "545ece65-9f4a-4b45-b87f-ea3a23032cac", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e9ec6f9-8598-4560-b71e-222f5b4c4968", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2a7840d-f7e4-424d-b343-06f913f9d3f6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52e2942b-3bed-4c3d-8082-c7069d791036", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "def73f2c-5d2f-42c6-8c2d-328ac5e8db20", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, "id": "888dd72a-4eef-414b-ac33-f6f4bfbefe60", "metadata": {}, "outputs": [], -- 2.11.4.GIT