From e71d5397f8df10e595686b451fed288fbcd1a630 Mon Sep 17 00:00:00 2001 From: jh-206 Date: Mon, 9 Sep 2024 16:40:21 -0600 Subject: [PATCH] Update batch_reset_param_tutorial.ipynb --- .../presentations/batch_reset_param_tutorial.ipynb | 140 +++++++-------------- 1 file changed, 45 insertions(+), 95 deletions(-) diff --git a/fmda/presentations/batch_reset_param_tutorial.ipynb b/fmda/presentations/batch_reset_param_tutorial.ipynb index 0800307..695f056 100644 --- a/fmda/presentations/batch_reset_param_tutorial.ipynb +++ b/fmda/presentations/batch_reset_param_tutorial.ipynb @@ -214,64 +214,8 @@ "id": "4958e7e0-1216-4392-a239-399f47917d98", "metadata": {}, "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "86019514-5972-4a7e-b154-a5ef6d100b51", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7c2e8c55-a749-4bf9-a11d-467f6b2759bd", - "metadata": {}, - "outputs": [], - "source": [ - "# import importlib\n", - "# import utils\n", - "# importlib.reload(utils)\n", - "# from utils import linear_increase, exp_increase, log_incraese" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "760288e8-6397-4283-bcc3-5cbfc49678fd", - "metadata": {}, - "outputs": [], "source": [ - "def linear_increase(max_epochs, bmin, bmax):\n", - " def func(epoch):\n", - " if epoch < max_epochs:\n", - " return bmin + (bmax - bmin) * (epoch / max_epochs)\n", - " else:\n", - " return bmax\n", - " return func\n", - "\n", - "\n", - "def exp_increase(max_epochs, bmin, bmax):\n", - " def func(epoch):\n", - " if epoch < max_epochs:\n", - " factor = epoch / max_epochs\n", - " return bmin * (bmax / bmin) ** factor\n", - " else:\n", - " return bmax\n", - " return func\n", - "\n", - "import numpy as np\n", - "\n", - "def log_increase(max_epochs, bmin, bmax):\n", - " def func(epoch):\n", - " if epoch < max_epochs:\n", - " factor = np.log(1 + epoch) / np.log(1 + max_epochs)\n", - " return bmin + (bmax - bmin) * factor\n", - " else:\n", - " return bmax\n", - " return func\n" + "from moisture_rnn import calc_exp_intervals, calc_log_intervals" ] }, { @@ -281,26 +225,11 @@ "metadata": {}, "outputs": [], "source": [ - "max_epochs = 10\n", + "epochs = 50\n", "bmin = 10\n", "bmax = 200\n", "\n", - "linear_func = linear_increase(max_epochs, bmin, bmax)\n", - "exp_func = exp_increase(max_epochs, bmin, bmax)\n", - "log_func = log_increase(max_epochs, bmin, bmax)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "58d2661a-ea65-4e64-92ba-252a2e82eeab", - "metadata": {}, - "outputs": [], - "source": [ - "epochs = range(max_epochs)\n", - "linear_values = [linear_func(epoch) for epoch in epochs]\n", - "exp_values = [exp_func(epoch) for epoch in epochs]\n", - "log_values = [log_func(epoch) for epoch in epochs]" + "egrid = np.arange(epochs)" ] }, { @@ -310,9 +239,9 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(epochs, linear_values, label='Linear Increase')\n", - "plt.plot(epochs, exp_values, label='Exponential Increase')\n", - "plt.plot(epochs, log_values, label='Logarithmic Increase')\n", + "plt.plot(egrid, np.linspace(bmin, bmax, epochs), label='Linear Increase')\n", + "plt.plot(egrid, calc_exp_intervals(bmin, bmax, epochs), label='Exponential Increase')\n", + "plt.plot(egrid, calc_log_intervals(bmin, bmax, epochs), label='Logarithmic Increase')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Batch Reset Value')\n", "plt.legend()\n", @@ -321,12 +250,12 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "8d86d3fb-fd30-4435-b18b-2ed3e48ade12", + "cell_type": "markdown", + "id": "92e06001-4b27-47e8-aa5e-03bffdb9ba03", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "### Linear Schedule" + ] }, { "cell_type": "code", @@ -334,15 +263,22 @@ "id": "cf4819b7-7e8e-4da9-8dee-217eda8f274f", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "params.update({'verbose_fit': False, 'stateful': True, \n", + " 'batch_schedule_type':'linear', 'bmin': 20, 'bmax': 200})\n", + "params.update({'epochs': 40})\n", + "reproducibility.set_seed(123)\n", + "rnn = RNN(params)\n", + "m, errs = rnn.run_model(rnn_dat, plot_period = \"predict\")" + ] }, { - "cell_type": "code", - "execution_count": null, - "id": "d1f642d3-c0a1-49d2-bc7c-69c5152626ed", + "cell_type": "markdown", + "id": "362dad1f-7584-4a04-a146-c13c2da2dc84", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "### Exponential Increase" + ] }, { "cell_type": "code", @@ -350,15 +286,22 @@ "id": "d8e590af-4bdc-4e3f-bced-af2f4479e015", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "params.update({'verbose_fit': False, 'stateful': True, \n", + " 'batch_schedule_type':'exp', 'bmin': 20, 'bmax': 200})\n", + "params.update({'epochs': 40})\n", + "reproducibility.set_seed(123)\n", + "rnn = RNN(params)\n", + "m, errs = rnn.run_model(rnn_dat, plot_period = \"predict\")" + ] }, { - "cell_type": "code", - "execution_count": null, - "id": "071cc121-8607-4dee-9ab2-9495de92dc0f", + "cell_type": "markdown", + "id": "2424801b-3482-4542-9236-dbebae2c1143", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "### Log Increase" + ] }, { "cell_type": "code", @@ -366,7 +309,14 @@ "id": "00d32580-6c20-4916-b241-1af281fedbd9", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "params.update({'verbose_fit': False, 'stateful': True, \n", + " 'batch_schedule_type':'log', 'bmin': 20, 'bmax': 200})\n", + "params.update({'epochs': 40})\n", + "reproducibility.set_seed(123)\n", + "rnn = RNN(params)\n", + "m, errs = rnn.run_model(rnn_dat, plot_period = \"predict\")" + ] }, { "cell_type": "code", -- 2.11.4.GIT