From e0b05b947e9e133e85fe78d69837f2c372faa41e Mon Sep 17 00:00:00 2001 From: jh-206 Date: Mon, 16 Sep 2024 21:06:28 -0600 Subject: [PATCH] Update rnn_workshop.ipynb --- fmda/rnn_workshop.ipynb | 216 +++++++++++++++++------------------------------- 1 file changed, 77 insertions(+), 139 deletions(-) diff --git a/fmda/rnn_workshop.ipynb b/fmda/rnn_workshop.ipynb index ea84754..2442622 100644 --- a/fmda/rnn_workshop.ipynb +++ b/fmda/rnn_workshop.ipynb @@ -137,130 +137,12 @@ { "cell_type": "code", "execution_count": null, - "id": "821fdad6-4e2a-4e7a-81c9-8baa55de741b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9e168412-83f0-4ad1-8f2b-4a5b3ac48239", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "be7737c5-5d91-4af0-962e-07a5e30928eb", - "metadata": {}, - "outputs": [], - "source": [ - "reproducibility.set_seed()\n", - "params.update({'batch_schedule_type': None})\n", - "rnn = RNN(params)\n", - "\n", - "lr_schedule = tf.keras.optimizers.schedules.CosineDecay(\n", - " initial_learning_rate=0.00001,\n", - " decay_steps=1000,\n", - " alpha=0.0,\n", - " name='CosineDecay'\n", - " # warmup_target=None,\n", - " # warmup_steps=100\n", - ")\n", - "\n", - "optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "72731c9a-5f56-4901-83b9-a7db332fd3bc", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3b9cfb12-a19e-4f5a-b181-7962527d963a", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "558111b9-5d6f-4839-ac78-1b8f8e358760", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "162aba9b-3aba-4874-9bea-5420943642a7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2a7d607c-8f29-4a18-948b-4d939ebd5a34", - "metadata": {}, - "outputs": [], - "source": [ - "rnn_dat.spatial" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "552c6e02-4a2d-4f50-9d6a-7e11bdbcfffc", - "metadata": {}, - "outputs": [], - "source": [ - "reproducibility.set_seed()\n", - "params.update({'batch_schedule_type': 'log', 'bmin': 20, 'bmax': rnn_dat.hours})\n", - "rnn = RNN(params)\n", - "m, errs = rnn.run_model(rnn_dat, plot_period=\"predict\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, "id": "db2abad4-16d4-4afc-a0d8-b2dec6b872c2", "metadata": {}, "outputs": [], "source": [] }, { - "cell_type": "code", - "execution_count": null, - "id": "74158e6e-c84f-4a90-9f0a-c35cb711d9ed", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c022cce2-8863-43f4-96e8-c604ba2fe8bc", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9f8a650-330f-493e-a158-a683e2fd872d", - "metadata": {}, - "outputs": [], - "source": [] - }, - { "cell_type": "markdown", "id": "b62f4360-e9d1-4510-bb5d-1d79a3a5ac75", "metadata": {}, @@ -470,80 +352,126 @@ { "cell_type": "code", "execution_count": null, - "id": "2fdb0213-1898-4d99-93dd-e38b72f53ceb", + "id": "303b5071-4254-4a37-95d2-989b7e87be5e", "metadata": {}, "outputs": [], "source": [ - "errs.shape" + "errs.mean()" ] }, { "cell_type": "code", "execution_count": null, - "id": "77ec9ff9-34ae-4221-ab5f-978c85a35c60", + "id": "055d98f5-4028-4822-b409-b03d437490da", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3abc849-6eb7-4b8c-a222-b136998f9db1", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b228d5ff-c60e-47fa-b0b1-25446791d232", "metadata": {}, "outputs": [], "source": [ - "errs.mean()" + "import importlib\n", + "import data_funcs\n", + "importlib.reload(data_funcs)\n", + "from data_funcs import process_train_dict" ] }, { "cell_type": "code", "execution_count": null, - "id": "4cc444eb-623f-47b4-9341-a432a14d3953", - "metadata": {}, + "id": "95b4a55f-2dbd-4320-a656-a6a0a3ca8f5a", + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ - "np.median(errs)" + "from data_funcs import process_train_dict\n", + "data_params = read_yml(\"params_data.yaml\")\n", + "data_params.update({\n", + " 'hours': 3648\n", + "})\n", + "train2 = process_train_dict(\"data/fmda_nw_202401-05_f05.pkl\", data_params=data_params, verbose=True)" ] }, { "cell_type": "code", "execution_count": null, - "id": "7f922046-e74f-424e-aa1c-d6d4b2eb3a46", + "id": "b78b45b4-1406-4657-9c31-08c9f24d93a1", "metadata": {}, "outputs": [], "source": [ - "new_data = np.stack(rnn_dat.X_test, axis=0)\n", - "y_array = np.stack(rnn_dat.y_test, axis=0)" + "import importlib\n", + "import moisture_rnn\n", + "importlib.reload(moisture_rnn)\n", + "from moisture_rnn import RNNData" ] }, { "cell_type": "code", "execution_count": null, - "id": "47a098c2-28c3-483d-b062-da1d534f7766", + "id": "6caf2bc4-e9c5-460b-bae5-b15b6221aaa1", "metadata": {}, "outputs": [], "source": [ - "y_array.shape" + "dat = {k: train2[k] for k in islice(train2, 100)}\n", + "dd = combine_nested(dat)\n", + "dd = Dict(dd)\n", + "rnn_dat = RNNData(dd, scaler=\"standard\", \n", + " features_list = ['Ed', 'Ew', 'rain', 'elev', 'lon', 'lat'])\n", + "rnn_dat.train_test_split( \n", + " time_fracs = [.7, .15, .15],\n", + " space_fracs = [.8, .1, .1]\n", + ")\n", + "params.update({'batch_size': 32})\n", + "rnn_dat.batch_reshape(\n", + " timesteps = params['timesteps'], \n", + " batch_size = params['batch_size'],\n", + " start_times = np.zeros(len(rnn_dat.case)).astype(int),\n", + " verbose=False\n", + ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "4a581630-2dc0-4cdb-8647-c81d41e149bc", + "id": "72fee44b-4e1f-44c7-bd00-76ae05d5684c", "metadata": {}, "outputs": [], "source": [ - "preds = rnn.model_predict.predict(new_data)\n", - "preds.shape" + "rnn_dat.X_train.shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "303b5071-4254-4a37-95d2-989b7e87be5e", + "id": "63994abb-de21-4e37-9e39-83fdc7043d83", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "rnn_dat.X_val.shape" + ] }, { "cell_type": "code", "execution_count": null, - "id": "055d98f5-4028-4822-b409-b03d437490da", + "id": "c34b46d6-3ac7-4880-92d9-d75c8897c003", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "rnn_dat.X_test[0].shape" + ] }, { "cell_type": "code", @@ -551,7 +479,15 @@ "id": "beb357ab-16dc-4c91-a121-6dfc509f4ff6", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "params.update({'epochs': 25, 'learning_rate': 0.0001, 'verbose_fit': False, 'rnn_layers': 2, 'rnn_units': 20, 'dense_layers': 1, 'dense_units': 10,\n", + " 'activation': ['relu', 'relu'],\n", + " 'features_list': rnn_dat.features_list})\n", + "params.update({'batch_schedule_type': 'exp', 'bmin': 20, 'bmax': rnn_dat.hours})\n", + "reproducibility.set_seed(123)\n", + "rnn = RNN(params)\n", + "m, errs = rnn.run_model(rnn_dat)" + ] }, { "cell_type": "code", @@ -559,7 +495,9 @@ "id": "a319b314-b156-47af-8541-f97145352e5c", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "errs.mean()" + ] }, { "cell_type": "code", -- 2.11.4.GIT