Create Batch Reset Hyperparameter tutorial notebook
[notebooks.git] / fmda / version_control / rnn_train_versions.ipynb
blobf433b7f21b4c12726cc4c46edee1492b1bfc4064
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "e20166f4-1a8b-4471-a9a9-e944cc4b1087",
6    "metadata": {},
7    "source": [
8     "# Use to Check Reproducibility - v2.1 Code\n",
9     "\n",
10     "Version 2.1 relies on a conda environment built from the YAML file `fmda/install/fmda_ml.yml`. This environment uses Python version `>=3.12` and TensorFlow version `>=2.16`. These upgrades led to substantial changes from the reproducibility results for code v2.0, so the old v2.0 reproducibility hashes are no longer supported.\n",
11     "\n",
12     "To see the old code and reproducibility results, see commit \"911c6d7a\" or PR #11."
13    ]
14   },
15   {
16    "cell_type": "markdown",
17    "id": "ccbfa419-70b9-484f-ada7-82fcc70b5b38",
18    "metadata": {},
19    "source": [
20     "## Setup"
21    ]
22   },
23   {
24    "cell_type": "code",
25    "execution_count": null,
26    "id": "8530bc7e-61ae-4463-a14f-d5eb42f0b83e",
27    "metadata": {},
28    "outputs": [],
29    "source": [
30     "# Environment\n",
31     "import numpy as np\n",
32     "import pandas as pd\n",
33     "import tensorflow as tf\n",
34     "import matplotlib.pyplot as plt\n",
35     "import sys\n",
36     "# Local modules\n",
37     "sys.path.append('..')\n",
38     "from moisture_rnn import RNN, RNNParams, RNNData\n",
39     "import reproducibility\n",
40     "from utils import print_dict_summary, read_yml, read_pkl\n",
41     "from moisture_rnn_pkl import pkl2train\n",
42     "from moisture_rnn import RNN, RNNData, RNNParams\n",
43     "import logging\n",
44     "from utils import logging_setup\n",
45     "logging_setup()"
46    ]
47   },
48   {
49    "cell_type": "markdown",
50    "id": "63d275da-b13a-405e-9e1a-aa3f972119b5",
51    "metadata": {},
52    "source": [
53     "### Reproducibility Dataset"
54    ]
55   },
56   {
57    "cell_type": "code",
58    "execution_count": null,
59    "id": "2dfa08fa-01c9-4fd7-927b-541b0a532e4f",
60    "metadata": {},
61    "outputs": [],
62    "source": [
63     "# Original File\n",
64     "repro_file='../data/reproducibility_dict_v2_TEST.pkl'\n",
65     "repro = read_pkl(repro_file)"
66    ]
67   },
68   {
69    "cell_type": "markdown",
70    "id": "87f59db1-fa7b-44f4-bbe4-7e49221226e9",
71    "metadata": {},
72    "source": [
73     "## RNN with Stateful Batch Training\n"
74    ]
75   },
76   {
77    "cell_type": "code",
78    "execution_count": null,
79    "id": "134d13b9-f329-49fb-8b53-16b4b109ae18",
80    "metadata": {},
81    "outputs": [],
82    "source": [
83     "# Set up params\n",
84     "params = repro['repro_info']['params']\n",
85     "print(type(params))\n",
86     "print(params)"
87    ]
88   },
89   {
90    "cell_type": "code",
91    "execution_count": null,
92    "id": "e9867f1b-f0fe-4032-a302-ec093782b227",
93    "metadata": {},
94    "outputs": [],
95    "source": [
96     "# Set up input data\n",
97     "rnn_dat = RNNData(repro, scaler = params['scaler'], features_list = params['features_list'])\n",
98     "rnn_dat.train_test_split(\n",
99     "    train_frac = params['train_frac'],\n",
100     "    val_frac = params['val_frac']\n",
101     ")\n",
102     "rnn_dat.scale_data()"
103    ]
104   },
105   {
106    "cell_type": "code",
107    "execution_count": null,
108    "id": "458ea3bb-2b19-42b3-8ac8-fdec96c6d315",
109    "metadata": {},
110    "outputs": [],
111    "source": [
112     "reproducibility.set_seed()\n",
113     "rnn = RNN(params)\n",
114     "m, errs = rnn.run_model(rnn_dat, reproducibility_run=True)"
115    ]
116   },
117   {
118    "cell_type": "code",
119    "execution_count": null,
120    "id": "ac65f5e4-b9a6-4592-998c-3e5161125449",
121    "metadata": {},
122    "outputs": [],
123    "source": []
124   },
125   {
126    "cell_type": "code",
127    "execution_count": null,
128    "id": "0e66d653-92d9-40cb-8690-dc9aa9c7b504",
129    "metadata": {},
130    "outputs": [],
131    "source": []
132   },
133   {
134    "cell_type": "code",
135    "execution_count": null,
136    "id": "591cd4d8-8b15-4d94-b258-76e92176f298",
137    "metadata": {},
138    "outputs": [],
139    "source": []
140   },
141   {
142    "cell_type": "code",
143    "execution_count": null,
144    "id": "67b6d000-b41b-4d5f-bead-814167e79fce",
145    "metadata": {},
146    "outputs": [],
147    "source": []
148   },
149   {
150    "cell_type": "markdown",
151    "id": "d4103c38-e067-4e72-a694-27080fa5265e",
152    "metadata": {},
153    "source": [
154     "### Physics Initialized"
155    ]
156   },
157   {
158    "cell_type": "code",
159    "execution_count": null,
160    "id": "383e2870-a581-4ac2-8669-f16ac41a64a4",
161    "metadata": {},
162    "outputs": [],
163    "source": [
164     "print(\"NOT YET IMPLEMENTED\")\n",
165     "# params.update({'phys_initialize': True})\n",
166     "# reproducibility.set_seed()\n",
167     "# rnn = RNN(params)\n",
168     "# m, errs = rnn.run_model(rnn_dat)"
169    ]
170   },
171   {
172    "cell_type": "code",
173    "execution_count": null,
174    "id": "33584d39-cd5d-4613-b330-48ab960cb42e",
175    "metadata": {},
176    "outputs": [],
177    "source": []
178   },
179   {
180    "cell_type": "code",
181    "execution_count": null,
182    "id": "9bfbd951-c89a-4b69-be94-7ac4783f5eb7",
183    "metadata": {},
184    "outputs": [],
185    "source": []
186   }
187  ],
188  "metadata": {
189   "kernelspec": {
190    "display_name": "Python 3 (ipykernel)",
191    "language": "python",
192    "name": "python3"
193   },
194   "language_info": {
195    "codemirror_mode": {
196     "name": "ipython",
197     "version": 3
198    },
199    "file_extension": ".py",
200    "mimetype": "text/x-python",
201    "name": "python",
202    "nbconvert_exporter": "python",
203    "pygments_lexer": "ipython3",
204    "version": "3.12.5"
205   }
206  },
207  "nbformat": 4,
208  "nbformat_minor": 5