From 42df74738c3586f799d243b6c0a398a3bf0fcd88 Mon Sep 17 00:00:00 2001 From: Jan Date: Sat, 2 Jul 2022 13:26:29 -0600 Subject: [PATCH] consolidate creating RNN --- fmda_kf_rnn.ipynb | 87 ++++++++++++++----------------------------------------- 1 file changed, 21 insertions(+), 66 deletions(-) diff --git a/fmda_kf_rnn.ipynb b/fmda_kf_rnn.ipynb index 9077786..1db8668 100644 --- a/fmda_kf_rnn.ipynb +++ b/fmda_kf_rnn.ipynb @@ -1395,7 +1395,7 @@ " y = tf.keras.layers.Dense(hidden_units, activation=activation[1])(x)\n", " outputs = tf.keras.layers.Dense(dense_units, activation=activation[1])(y)\n", " model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", - " model.compile(loss='mean_squared_error', optimizer='sgd')\n", + " model.compile(loss='mean_squared_error', optimizer='adam')\n", " return model" ], "metadata": { @@ -1412,84 +1412,39 @@ }, "outputs": [], "source": [ - "fmda_model=create_RNN_2(hidden_units=7, dense_units=1, \n", - " batch_shape=(samples, timesteps, features),\n", - " input_shape=(timesteps, features),\n", - " stateful = True,\n", - " activation=['tanh', 'tanh'])\n", - "print(fmda_model.summary())" + "def create_fit_predict_RNN(hidden_units, dense_units, \n", + " samples, timesteps, features, activation):\n", + " # statefull model version with with fixed number of batches\n", + " model_fit=create_RNN_2(hidden_units=hidden_units, dense_units=dense_units, \n", + " batch_shape=(samples, timesteps, features),stateful = True,\n", + " activation=activation)\n", + " print(model_fit.summary())\n", + " # same model for prediction on the entire dataset\n", + " model_predict=create_RNN_2(hidden_units=hidden_units, dense_units=dense_units, \n", + " input_shape=(None,features),stateful = False,\n", + " activation=activation)\n", + " print(model_predict.summary())\n", + " return model_fit, model_predict\n", + "\n", + "fmda_model, fmda_model_eval = create_fit_predict_RNN(hidden_units=7, dense_units=1, \n", + " samples=samples, timesteps=timesteps, features=1, \n", + " activation=['tanh', 'tanh'])" ] }, { "cell_type": "code", "source": [ - "fmda_model_eval=create_RNN_2(hidden_units=7, dense_units=1, \n", - " batch_shape=(samples, timesteps, features),\n", - " input_shape=(None,features),\n", - " stateful = False,\n", - " activation=['tanh', 'tanh'])\n", - "print(fmda_model_eval.summary())" - ], - "metadata": { - "id": "Bq-e6mfZLfnE" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "fmda_model.fit(x_train, y_train, epochs=40, verbose=2,batch_size=samples)" - ], - "metadata": { - "id": "VZoa3tlQWbBG" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "" - ], - "metadata": { - "id": "Rl5lbJ7kM3tr" - } - }, - { - "cell_type": "code", - "source": [ + "fmda_model.fit(x_train, y_train, epochs=40, verbose=2,batch_size=samples)\n", "# Same model as stateless for prediction:\n", "w=fmda_model.get_weights()\n", "fmda_model_eval.set_weights(w)\n", "# prediction on the entire dataset from zero state\n", "mt = fmda_model_eval.predict(Et)\n", - "m = scalery.inverse_transform(mt)" - ], - "metadata": { - "id": "R2jkoZlAIaSb" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ + "m = scalery.inverse_transform(mt)\n", "plot_m(m,title='RNN prediction')" ], "metadata": { - "id": "eTeOP9I7sBJ4" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "" - ], - "metadata": { - "id": "jqlUYh8Zt8Cc" + "id": "R2jkoZlAIaSb" }, "execution_count": null, "outputs": [] -- 2.11.4.GIT