Update moisture_rnn.py
[notebooks.git] / fmda / version_control / rnn_train_versions.ipynb
blob5aca18a89a52426d85b096fdd04564038c3c446c
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "e20166f4-1a8b-4471-a9a9-e944cc4b1087",
6    "metadata": {},
7    "source": [
8     "# Use to Check Reproducibility - v2.1 Code\n",
9     "\n",
10     "Version 2.1 relies on a conda environment built from the yaml file `fmda/install/fmda_ml.yml`. This environment uses Python version `>=3.12` and TensorFlow version `>=2.16`. This led to substantial changes from the old reproducibility results for code v2.0, so those old reproducibility hashes are no longer supported.\n",
11     "\n",
12     "To see the old code and reproducibility results, see commit \"911c6d7a\" or PR#11."
13    ]
14   },
15   {
16    "cell_type": "markdown",
17    "id": "ccbfa419-70b9-484f-ada7-82fcc70b5b38",
18    "metadata": {},
19    "source": [
20     "## Setup"
21    ]
22   },
23   {
24    "cell_type": "code",
25    "execution_count": null,
26    "id": "8530bc7e-61ae-4463-a14f-d5eb42f0b83e",
27    "metadata": {},
28    "outputs": [],
29    "source": [
30     "# Environment\n",
31     "import numpy as np\n",
32     "import pandas as pd\n",
33     "import tensorflow as tf\n",
34     "import matplotlib.pyplot as plt\n",
35     "import sys\n",
36     "# Local modules\n",
37     "sys.path.append('..')\n",
38     "from moisture_rnn import RNN, RNNParams, RNNData\n",
39     "import reproducibility\n",
40     "from utils import print_dict_summary, read_yml, read_pkl\n",
41     "from moisture_rnn_pkl import pkl2train\n",
42     "from moisture_rnn import RNN, RNNData, RNNParams\n",
43     "import logging\n",
44     "from utils import logging_setup\n",
45     "logging_setup()"
46    ]
47   },
48   {
49    "cell_type": "markdown",
50    "id": "63d275da-b13a-405e-9e1a-aa3f972119b5",
51    "metadata": {},
52    "source": [
53     "### Reproducibility Dataset"
54    ]
55   },
56   {
57    "cell_type": "code",
58    "execution_count": null,
59    "id": "2dfa08fa-01c9-4fd7-927b-541b0a532e4f",
60    "metadata": {},
61    "outputs": [],
62    "source": [
63     "# Original File\n",
64     "repro_file='../data/reproducibility_dict_v2_TEST.pkl'\n",
65     "repro = read_pkl(repro_file)"
66    ]
67   },
68   {
69    "cell_type": "markdown",
70    "id": "87f59db1-fa7b-44f4-bbe4-7e49221226e9",
71    "metadata": {},
72    "source": [
73     "## RNN with Stateful Batch Training\n"
74    ]
75   },
76   {
77    "cell_type": "code",
78    "execution_count": null,
79    "id": "134d13b9-f329-49fb-8b53-16b4b109ae18",
80    "metadata": {},
81    "outputs": [],
82    "source": [
83     "# Set up params\n",
84     "params = repro['repro_info']['params']\n",
85     "print(type(params))\n",
86     "print(params)"
87    ]
88   },
89   {
90    "cell_type": "code",
91    "execution_count": null,
92    "id": "e9867f1b-f0fe-4032-a302-ec093782b227",
93    "metadata": {},
94    "outputs": [],
95    "source": [
96     "# Set up input data\n",
97     "rnn_dat = RNNData(repro, scaler = params['scaler'], features_list = params['features_list'])\n",
98     "rnn_dat.train_test_split(\n",
99     "    time_fracs = params['time_fracs']\n",
100     ")\n",
101     "rnn_dat.scale_data()\n",
102     "rnn_dat.batch_reshape(timesteps = params['timesteps'], batch_size = params['batch_size'])"
103    ]
104   },
105   {
106    "cell_type": "code",
107    "execution_count": null,
108    "id": "458ea3bb-2b19-42b3-8ac8-fdec96c6d315",
109    "metadata": {},
110    "outputs": [],
111    "source": [
112     "reproducibility.set_seed()\n",
113     "rnn = RNN(params)\n",
114     "m, errs = rnn.run_model(rnn_dat, reproducibility_run=True)"
115    ]
116   },
117   {
118    "cell_type": "code",
119    "execution_count": null,
120    "id": "ac65f5e4-b9a6-4592-998c-3e5161125449",
121    "metadata": {},
122    "outputs": [],
123    "source": []
124   },
125   {
126    "cell_type": "code",
127    "execution_count": null,
128    "id": "0e66d653-92d9-40cb-8690-dc9aa9c7b504",
129    "metadata": {},
130    "outputs": [],
131    "source": []
132   },
133   {
134    "cell_type": "code",
135    "execution_count": null,
136    "id": "591cd4d8-8b15-4d94-b258-76e92176f298",
137    "metadata": {},
138    "outputs": [],
139    "source": []
140   },
141   {
142    "cell_type": "code",
143    "execution_count": null,
144    "id": "67b6d000-b41b-4d5f-bead-814167e79fce",
145    "metadata": {},
146    "outputs": [],
147    "source": []
148   },
149   {
150    "cell_type": "markdown",
151    "id": "d4103c38-e067-4e72-a694-27080fa5265e",
152    "metadata": {},
153    "source": [
154     "### Physics Initialized"
155    ]
156   },
157   {
158    "cell_type": "code",
159    "execution_count": null,
160    "id": "383e2870-a581-4ac2-8669-f16ac41a64a4",
161    "metadata": {},
162    "outputs": [],
163    "source": [
164     "print(\"NOT YET IMPLEMENTED\")\n",
165     "# params.update({'phys_initialize': True})\n",
166     "# reproducibility.set_seed()\n",
167     "# rnn = RNN(params)\n",
168     "# m, errs = rnn.run_model(rnn_dat)"
169    ]
170   },
171   {
172    "cell_type": "code",
173    "execution_count": null,
174    "id": "33584d39-cd5d-4613-b330-48ab960cb42e",
175    "metadata": {},
176    "outputs": [],
177    "source": []
178   },
179   {
180    "cell_type": "code",
181    "execution_count": null,
182    "id": "9bfbd951-c89a-4b69-be94-7ac4783f5eb7",
183    "metadata": {},
184    "outputs": [],
185    "source": []
186   }
187  ],
188  "metadata": {
189   "kernelspec": {
190    "display_name": "Python 3 (ipykernel)",
191    "language": "python",
192    "name": "python3"
193   },
194   "language_info": {
195    "codemirror_mode": {
196     "name": "ipython",
197     "version": 3
198    },
199    "file_extension": ".py",
200    "mimetype": "text/x-python",
201    "name": "python",
202    "nbconvert_exporter": "python",
203    "pygments_lexer": "ipython3",
204    "version": "3.12.5"
205   }
206  },
207  "nbformat": 4,
208  "nbformat_minor": 5