From 7735aad335b9a7f3c5bdef182b81e5cb6c7b8fb2 Mon Sep 17 00:00:00 2001
From: jh-206
Date: Wed, 16 Oct 2024 16:43:22 -0600
Subject: [PATCH] Update rnn_workshop.ipynb

---
 fmda/rnn_workshop.ipynb | 176 ++++++++++++++++++++++++++++++------------------
 1 file changed, 111 insertions(+), 65 deletions(-)

diff --git a/fmda/rnn_workshop.ipynb b/fmda/rnn_workshop.ipynb
index 3d69b29..09dc0a5 100644
--- a/fmda/rnn_workshop.ipynb
+++ b/fmda/rnn_workshop.ipynb
@@ -449,7 +449,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3e86a5e2-f8df-4bee-b0db-365e46e9a6f9",
+   "id": "cb4d2490-9a74-4cd4-b496-2c0397e4913b",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -475,13 +475,9 @@
    "        self.model_train.compile(loss='mean_squared_error', optimizer=optimizer)\n",
    "        self.model_predict.compile(loss='mean_squared_error', optimizer=optimizer)\n",
    "\n",
-   "    def _build_model_train(self):\n",
+   "    def _build_hidden_layers(self, x, stateful):\n",
    "        params = self.params\n",
    "        \n",
-   "        # Define the input layer with the specified batch size, timesteps, and features\n",
-   "        inputs = tf.keras.Input(batch_shape=(params['batch_size'], params['timesteps'], params['n_features']))\n",
-   "        x = inputs\n",
-   "        \n",
    "        # Loop over each layer specified in 'hidden_layers'\n",
    "        for i, layer_type in enumerate(params['hidden_layers']):\n",
    "            units = params['hidden_units'][i]\n",
@@ -495,16 +491,26 @@
    "            activation = params['hidden_activation'][i]\n",
    "            \n",
    "            if layer_type == 'dense':\n",
    "                x = layers.Dense(units=units, activation=activation)(x)\n",
    "            \n",
    "            elif layer_type == 'rnn':\n",
    "                x = layers.SimpleRNN(units=units, activation=activation, dropout=params['dropout'], recurrent_dropout=params['recurrent_dropout'],\n",
-   "                                     return_sequences=True, stateful=True)(x)\n",
+   "                                     return_sequences=True, stateful=stateful)(x)\n",
    "            \n",
    "            elif layer_type == 'lstm':\n",
    "                x = layers.LSTM(units=units, activation=activation, dropout=params['dropout'], recurrent_dropout=params['recurrent_dropout'],\n",
-   "                                return_sequences=True, stateful=True)(x)\n",
+   "                                return_sequences=True, stateful=stateful)(x)\n",
    "            \n",
    "            elif layer_type == 'attention':\n",
    "                # Self-attention mechanism\n",
    "                x = layers.Attention()([x, x])\n",
-   "        \n",
+   "        return x\n",
+   "    \n",
+   "    def _build_model_train(self):\n",
+   "        params = self.params\n",
+   "        \n",
+   "        # Define the input layer with the specified batch size, timesteps, and features\n",
+   "        inputs = tf.keras.Input(batch_shape=(params['batch_size'], params['timesteps'], params['n_features']))\n",
+   "        x = inputs\n",
+   "        # Build hidden layers\n",
+   "        x = self._build_hidden_layers(x, stateful=params['stateful'])\n",
    "\n",
    "        # Add the output layer\n",
    "        if params['output_layer'] == 'dense':\n",
    "            outputs = layers.Dense(units=params['output_dimension'], activation=params['output_activation'])(x)\n",
@@ -521,24 +527,8 @@
    "        # Define the input layer with flexible batch size and sequence length\n",
    "        inputs = tf.keras.Input(shape=(None, params['n_features']))\n",
    "        x = inputs\n",
-   "        \n",
-   "        # Loop over each layer specified in 'hidden_layers'\n",
-   "        for i, layer_type in enumerate(params['hidden_layers']):\n",
-   "            units = params['hidden_units'][i]\n",
-   "            activation = params['hidden_activation'][i]\n",
-   "            \n",
-   "            if layer_type == 'dense':\n",
-   "                x = layers.Dense(units=units, activation=activation)(x)\n",
-   "            \n",
-   "            elif layer_type == 'rnn':\n",
-   "                x = layers.SimpleRNN(units=units, activation=activation, return_sequences=True, stateful=False)(x)\n",
-   "            \n",
-   "            elif layer_type == 'lstm':\n",
-   "                x = layers.LSTM(units=units, activation=activation, return_sequences=True, stateful=False)(x)\n",
-   "            \n",
-   "            elif layer_type == 'attention':\n",
-   "                # Self-attention mechanism\n",
-   "                x = layers.Attention()([x, x])\n",
+   "        # Build hidden layers\n",
+   "        x = self._build_hidden_layers(x, stateful=False)\n",
    "        \n",
    "        # Add the output layer\n",
    "        if params['output_layer'] == 'dense':\n",
@@ -675,12 +665,63 @@
    "            print(\"Predicting test data\")\n",
    "            preds = self.model_predict.predict(X_test)\n",
    "        \n",
-   "        return preds\n"
+   "        return preds\n",
+   "\n",
+   "    def run_model(self, data, reproducibility_run=False, plot_period='all', return_epochs=False):\n",
+   "\n",
+   "        # Set up print statements with verbose args\n",
+   "        verbose_fit = self.params['verbose_fit']\n",
+   "        verbose_weights = self.params['verbose_weights']\n",
+   "        if verbose_weights:\n",
+   "            data.print_hashes()\n",
+   "\n",
+   "        # Set up name for run, used for plotting\n",
+   "        case_id = \"Spatial Training Set\" if data.spatial else data.id\n",
+   "        print(f\"Running {case_id}\")\n",
+   "\n",
+   "        # Extract datasets\n",
+   "        X_train, y_train, X_test, y_test = data.X_train, data.y_train, data.X_test, data.y_test\n",
+   "        X_val, y_val = data.get('X_val', None), data.get('y_val', None)\n",
+   "        validation_data = (X_val, y_val) if X_val is not None and y_val is not None else None\n",
+   "        \n",
+   "        # Fit the model and assign epochs to the object; assigns None if return_epochs is False\n",
+   "        # NOTE: when using early stopping, the number of epochs must be extracted here at the fit call\n",
+   "        eps = self.fit(X_train, y_train, validation_data=validation_data, plot_title=case_id, return_epochs=return_epochs, verbose_fit=verbose_fit)\n",
+   "\n",
+   "        # Generate predictions\n",
+   "        m = self.predict(X_test)\n",
+   "        errs = eval_errs(m, y_test)\n",
+   "\n",
+   "        return m, errs\n",
+   "\n",
+   "\n",
+   "def eval_errs(preds, y_test):\n",
+   "    \"\"\"\n",
+   "    Calculate RMSE for ndarrays structured as (batch_size, timesteps, features).\n",
+   "    The first dimension, batch_size, could denote distinct locations. The second, timesteps, is the length of the sequence.\n",
+   "    \"\"\"\n",
+   "    squared_diff = np.square(preds - y_test)\n",
+   "    \n",
+   "    # Mean squared error along the timestep and feature dimensions (axes 1 and 2)\n",
+   "    mse = np.mean(squared_diff, axis=(1, 2))\n",
+   "    \n",
+   "    # Average the root mean squared error (RMSE) across timeseries\n",
+   "    rmses = np.mean(np.sqrt(mse))\n",
+   "    \n",
+   "    return rmses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
+  "id": "e31ca700-7871-4fb5-b5d2-9b63b85f6688",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
   "id": "42459583-a634-4dd9-a94b-4535302f481d",
   "metadata": {},
   "outputs": [],
@@ -689,9 +730,9 @@
   "source": [
    "    'n_features': 8,\n",
    "    'timesteps': 12,\n",
    "    'batch_size': 32,\n",
-   "    'hidden_layers': ['dense', 'lstm', 'attention', 'dense'],\n",
-   "    'hidden_units': [64, 32, None, 32],\n",
-   "    'hidden_activation': ['relu', 'tanh', None, 'relu'],\n",
+   "    'hidden_layers': ['dense', 'lstm', 'attention', 'dense', 'dense'],\n",
+   "    'hidden_units': [64, 32, None, 32, 16],\n",
+   "    'hidden_activation': ['relu', 'tanh', None, 'relu', 'relu'],\n",
    "    'dropout': 0.2,\n",
    "    'recurrent_dropout': 0.2,\n",
@@ -699,12 +740,20 @@
    "    'output_layer': 'dense',\n",
    "    'output_activation': 'linear',\n",
    "    'output_dimension': 1,\n",
    "    'learning_rate': 0.001,\n",
    "    'early_stopping_patience': 5,\n",
-   "    'epochs': 10,\n",
+   "    'epochs': 3,\n",
    "    'reset_states': True,\n",
    "    'bmin': 10,\n",
    "    'bmax': 200,\n",
    "    'batch_schedule_type': 'step',\n",
-   "    'estep': 5\n",
+   "    'estep': 5,\n",
+   "    'features_list': ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat', 'solar', 'wind'],\n",
+   "    'scaler': 'standard',\n",
+   "    'time_fracs': [.8, .1, .1],\n",
+   "    'space_fracs': [.8, .1, .1],\n",
+   "    'stateful': True,\n",
+   "    'verbose_fit': True,\n",
+   "    'verbose_weights': False,\n",
+   "    # 'return_sequences': True  # whether the LAST recurrent layer should return sequences; if there are multiple recurrent layers, all earlier ones must return sequences\n",
    "}"
@@ -721,10 +770,21 @@
  {
   "cell_type": "code",
   "execution_count": null,
+  "id": "aa9117f1-39be-4ce0-81bc-c3bbbdf0aa95",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "params_test['scaler']"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
   "id": "7c1627f9-f011-4159-98a2-1b5973929e71",
   "metadata": {},
   "outputs": [],
   "source": [
+   "reproducibility.set_seed()\n",
    "mod = RNN2(params_test)"
   ]
  },
@@ -745,76 +805,64 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "rnn_dat = rnn_data_wrap(combine_nested(train3), params)"
+   "rnn_dat = rnn_data_wrap(combine_nested(train3), params_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "49f0bd69-aa76-4d7e-a4eb-3c7a800b6d2f",
+  "id": "e213ffd7-d26c-41ce-8e2b-b17368fdd7a8",
   "metadata": {},
   "outputs": [],
   "source": [
-   "reproducibility.set_seed()\n",
-   "mod.fit(rnn_dat.X_train, rnn_dat.y_train, \n",
-   "        validation_data = (rnn_dat.X_val, rnn_dat.y_val), verbose_fit = True)"
+   "params_test.update({\n",
+   "    'loc_batch_reset': rnn_dat.n_seqs  # Used to reset hidden state when location changes for a given batch\n",
+   "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "e0398824-16aa-4f30-a2c2-6e53d54778c1",
+  "id": "eb6f246f-c02e-434e-89d1-f6d9ec1c6b26",
   "metadata": {},
   "outputs": [],
   "source": [
-   "hash_weights(mod.model_train)"
+   "reproducibility.set_seed()\n",
+   "mod3 = RNN3(params_test)\n",
+   "mod3.model_train.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "b6ed3cbd-b78e-4d8c-be61-c818055fb539",
+  "id": "74e599b6-7f4d-4175-a5f1-de892e72ebd4",
   "metadata": {},
   "outputs": [],
-  "source": [
-   "hash_weights(mod.model_predict)"
-  ]
+  "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "49ed89b7-1685-4f9f-8e8f-d7cac0c7035f",
+  "id": "f894d203-d277-48f3-bb57-a610f162361f",
   "metadata": {},
   "outputs": [],
-  "source": [
-   "preds = mod.predict(rnn_dat.X_test)"
-  ]
+  "source": []
 },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "7a2a2be2-6a16-49a5-860a-9e5f19f95318",
+  "id": "6a3f38aa-ed3f-4511-a7ec-f3403aa4c717",
   "metadata": {},
   "outputs": [],
-  "source": [
-   "preds.shape"
-  ]
+  "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "b470aaec-4eed-4327-b5ee-3e6a4b24d7cb",
+  "id": "66ddc2cd-0308-4622-bbb1-d26d08292159",
   "metadata": {},
   "outputs": [],
-  "source": [
-   "squared_diff = np.square(preds - rnn_dat.y_test)\n",
-   "\n",
-   "# Mean squared error along the timesteps and dimensions (axis 1 and 2)\n",
-   "mse = np.mean(squared_diff, axis=(1, 2))\n",
-   "\n",
-   "# Root mean squared error (RMSE) for each timeseries\n",
-   "np.mean(np.sqrt(mse))"
-  ]
+  "source": []
  },
  {
   "cell_type": "markdown",
@@ -971,9 +1019,7 @@
  {
   "cell_type": "markdown",
   "id": "5ef092ff-8af1-491a-b0bf-cc3e674330e0",
-  "metadata": {
-   "jp-MarkdownHeadingCollapsed": true
-  },
+  "metadata": {},
   "source": [
    "## Phys Initialized"
   ]
  },
-- 
2.11.4.GIT
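
Note: the refactor above moves the shared layer stack into _build_hidden_layers() so that the stateful training model and the stateless prediction model are built from one code path, differing only in the `stateful` flag and the input's batch shape, with trained weights copied across before prediction. Below is a minimal, self-contained sketch of that pattern, assuming TensorFlow 2.x/Keras; the names, layer sizes, and toy data are illustrative and not taken from the repository.

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

BATCH, TIMESTEPS, FEATURES = 4, 6, 3  # illustrative sizes

def build_hidden(x, stateful):
    # One builder used by both models; only the `stateful` flag differs.
    x = layers.SimpleRNN(8, return_sequences=True, stateful=stateful)(x)
    return layers.Dense(1)(x)

# Training model: fixed batch shape, stateful recurrence across batches.
inp_train = tf.keras.Input(batch_shape=(BATCH, TIMESTEPS, FEATURES))
model_train = tf.keras.Model(inp_train, build_hidden(inp_train, stateful=True))
model_train.compile(loss='mean_squared_error', optimizer='adam')

# Prediction model: flexible batch size and sequence length, stateless.
inp_pred = tf.keras.Input(shape=(None, FEATURES))
model_predict = tf.keras.Model(inp_pred, build_hidden(inp_pred, stateful=False))

X = np.random.rand(BATCH, TIMESTEPS, FEATURES).astype('float32')
y = np.random.rand(BATCH, TIMESTEPS, 1).astype('float32')
model_train.fit(X, y, epochs=1, batch_size=BATCH, shuffle=False, verbose=0)

# The two models share a topology, so trained weights copy across directly.
model_predict.set_weights(model_train.get_weights())
preds = model_predict.predict(X, verbose=0)

Because each model is built by the same function in the same order, get_weights()/set_weights() line up one-to-one; this is why the patch routes both _build_model_train and _build_model_predict through the single hidden-layer builder rather than duplicating the loop.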
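
Note: the eval_errs() metric added above reduces over axes 1 and 2 first, giving one MSE per location, and then averages the per-location RMSEs. That is the mean of per-series RMSEs, which is not the same number as pooling all errors into a single RMSE. Restated as a standalone snippet with hypothetical shapes (5 locations, 100 timesteps, 1 output feature):

import numpy as np

def eval_errs(preds, y_test):
    # Per-series MSE over timesteps and features (axes 1 and 2)...
    mse = np.mean(np.square(preds - y_test), axis=(1, 2))
    # ...then the mean of the per-series RMSEs.
    return np.mean(np.sqrt(mse))

preds = np.random.rand(5, 100, 1)
y_test = np.random.rand(5, 100, 1)
print(eval_errs(preds, y_test))  # single scalar summarizing all locations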