Merge pull request #17 from openwfm/restructure
[notebooks.git] / fmda / presentations / output_layer_physics_constrained.ipynb
blobcca26d784663e31964c68d87fe427bd0c3a50bb7
2  "cells": [
3   {
4    "cell_type": "code",
5    "execution_count": null,
6    "id": "95909853-fb1d-43fd-82fe-d215a215b24d",
7    "metadata": {},
8    "outputs": [],
9    "source": [
10     "import sys\n",
11     "sys.path.append(\"..\")"
12    ]
13   },
14   {
15    "cell_type": "code",
16    "execution_count": null,
17    "id": "a19b50d3-6524-4365-9855-ad487e5d87f2",
18    "metadata": {},
19    "outputs": [],
20    "source": [
21     "# Setup\n",
22     "import reproducibility\n",
23     "\n",
24     "# both can change\n",
25     "# Environment\n",
26     "import numpy as np\n",
27     "import pandas as pd\n",
28     "import tensorflow as tf\n",
29     "\n",
30     "import matplotlib.pyplot as plt\n",
31     "import tensorflow as tf\n",
32     "import pickle, os\n",
33     "\n",
34     "from data_funcs import load_and_fix_data, plot_data\n",
35     "from moisture_rnn import create_rnn_data, train_rnn, rnn_predict"
36    ]
37   },
38   {
39    "cell_type": "markdown",
40    "id": "268f2737-feec-4d6e-accf-dd21e7fbb2e9",
41    "metadata": {},
42    "source": [
43     "# Physics-Initiated Neural Networks - Output Layers\n",
44     "\n",
45     "The purpose of this notebook is to discuss the final output layer of physics-initiated Neural Networks. Physics might constrain the range of model outputs, so it might make sense to hard-code this into the output layer. Furthermore, a physical system may have different behavior when it approaches the minimum and maximum allowable values."
46    ]
47   },
48   {
49    "cell_type": "markdown",
50    "id": "9b94cb41-9d27-415b-b033-dcce2f139d82",
51    "metadata": {},
52    "source": [
53     "## Fuel Moisture Models\n",
54     "\n",
55     "In the fuel moisture content (FM) modeling framework, there are constraints on the model outputs to make it physically reasonable. First, FM cannot go below 0%, when there is no water content in the fuels. The maximum possible water content depends on the fuel type. Live fuels range from 200-250% maximum allowable moisture. For dead fuels, Mandel et.al. 2014 use the model below with a \"saturation\" moisture level $S=250\\%$:\n",
56     "\n",
57     "\n",
58     "$$\n",
59     "\\frac{\\mathrm{d}m}{\\mathrm{d}t}=\\frac{S-m}{t_{\\mathrm{r}}}\\left(1-\\exp\\left(-\\frac{r-r_0}{r_{\\mathrm{s}}}\n",
60     "\\right)  \\right),\\ \\text{if}\\ r>r_0, \n",
61     "$$\n",
62     "\n",
63     "A simple approach would be to constrain the outputs with a piece-wise linear function, or a \"clipped relu\" function as depicted below:\n",
64     "\n",
65     "![activation functions](https://www.researchgate.net/profile/Md-Hossain-331/publication/343096012/figure/fig2/AS:915583516278784@1595303529166/Activation-Functions-f-x-and-their-corresponding-1-st-derivatives-Dx-The-proposed.png)\n",
66     "\n",
67     "For the purposes of this notebook, we will constrain dead fuel moisture values to be less than or equal to 250%. Additionally, a physically reasonable process would have moisture content approach the maximum logarithmically, but the minimum moisture content of 0% could be reached more easily. Thus, the \"log-tailed relu\" depicted above could be preferrable, though this function will approach infinity instead of being capped at a maximum value. We will explore augmenting the log-tailed relu idea to have the range of outputs exponentially approach a maximum value."
68    ]
69   },
70   {
71    "cell_type": "markdown",
72    "id": "c5c5b843-2e7b-4079-aab9-b19d177bb851",
73    "metadata": {},
74    "source": [
75     "Sources:\n",
76     "* [Live Fuel Moistures](https://www.nwcg.gov/publications/pms437/fuel-moisture/live-fuel-moisture-content)\n",
77     "* [Fuel Moisture Processes](https://www.nwcg.gov/publications/pms425-1/weather-and-fuel-moisture), includes discussion of fiber-saturation level"
78    ]
79   },
80   {
81    "cell_type": "markdown",
82    "id": "a027cc49-2a4b-4d44-a449-5fd1117a6f32",
83    "metadata": {},
84    "source": [
85     "## Alternative ReLU Functions \n",
86     "\n",
87     "Below we define the activation functions and plot the range. (The primary source for this section is Hossain 2020.)\n",
88     "\n",
89     "The Clipped-ReLU function is identical to the ReLU up to a threshold value, after which it is constant with zero slope. The mathematical form of the clipped-ReLU with threshold value $A$ is:\n",
90     "\n",
91     "$$\n",
92     "f(x)=\\begin{cases}\n",
93     "    \\max(0, x) &  0<x\\leq A\\\\\n",
94     "    A &  x> A\n",
95     "\\end{cases}\n",
96     "$$\n",
97     "\n",
98     "This can be easily programmed as a piecewise linear function by taking the maximum of 0 and the input $x$, and then the minimum of that output with the threshold value:"
99    ]
100   },
101   {
102    "cell_type": "code",
103    "execution_count": null,
104    "id": "ab050a17-87c6-4be7-a4ff-f4a77d45f45f",
105    "metadata": {},
106    "outputs": [],
107    "source": [
108     "# Define standard ReLU function\n",
109     "def relu(x):\n",
110     "    return tf.keras.backend.maximum(0., x)\n",
111     "\n",
112     "# Define clipped ReLU function\n",
113     "def clipped_relu(x, threshold=250):\n",
114     "    return tf.keras.backend.minimum(tf.keras.backend.maximum(0., x), threshold)"
115    ]
116   },
117   {
118    "cell_type": "code",
119    "execution_count": null,
120    "id": "2170e86b-63e2-4702-b495-8c653a23e7e8",
121    "metadata": {},
122    "outputs": [],
123    "source": [
124     "xgrid = np.linspace(-100, 400, 50)"
125    ]
126   },
127   {
128    "cell_type": "code",
129    "execution_count": null,
130    "id": "3f090c69-7423-4998-a9c7-9222fd699793",
131    "metadata": {},
132    "outputs": [],
133    "source": [
134     "plt.ylim(-50, 400)\n",
135     "plt.axline((-1, 0), (0, 0), color=\"k\", linestyle=\":\") # x axis line\n",
136     "plt.axline((0, 0), (0, 1), color=\"k\", linestyle=\":\") # y axis line\n",
137     "plt.plot(xgrid, relu(xgrid), label = \"Standard Relu\", linestyle=\"dashed\")\n",
138     "plt.plot(xgrid, clipped_relu(xgrid), label = \"Clipped Relu\")\n",
139     "plt.legend()\n",
140     "plt.grid()"
141    ]
142   },
143   {
144    "cell_type": "markdown",
145    "id": "da021082-8e29-4e65-b9f3-860e9dafa9fd",
146    "metadata": {},
147    "source": [
148     "The log-tailed ReLU function is similarly identical to the standard ReLU up to a threshold value, and then proceeds logarithmically from there. The mathematical specification, for a threshold value of $A$, is:\n",
149     "\n",
150     "$$\n",
151     "f(x)=\\begin{cases}\n",
152     "    0 &  x\\leq 0\\\\\n",
153     "    x &  0<x\\leq A\\\\\n",
154     "    A+\\log(x-A) &  x> A\n",
155     "\\end{cases}\n",
156     "$$"
157    ]
158   },
159   {
160    "cell_type": "code",
161    "execution_count": null,
162    "id": "e6333440-2e6b-45bd-b654-3c55830cfabd",
163    "metadata": {},
164    "outputs": [],
165    "source": [
166     "# Define Log-Tailed Relu\n",
167     "def logtailed_relu(x, threshold=240):\n",
168     "    fx = np.maximum(0., x)\n",
169     "    x2 = x[x>threshold]\n",
170     "    fx[np.where(fx>threshold)]=threshold+np.log(x2-threshold)\n",
171     "    return fx"
172    ]
173   },
174   {
175    "cell_type": "code",
176    "execution_count": null,
177    "id": "b5b3c053-16f4-470e-af87-335c06b6b852",
178    "metadata": {
179     "scrolled": true
180    },
181    "outputs": [],
182    "source": [
183     "plt.ylim(-50, 400)\n",
184     "plt.axline((-1, 0), (0, 0), color=\"k\", linestyle=\":\") # x axis line\n",
185     "plt.axline((0, 0), (0, 1), color=\"k\", linestyle=\":\") # y axis line\n",
186     "plt.plot(xgrid, relu(xgrid), label = \"Standard Relu\", linestyle=\"dashed\")\n",
187     "plt.plot(xgrid, clipped_relu(xgrid), label = \"Clipped Relu\")\n",
188     "plt.plot(xgrid, logtailed_relu(xgrid), label = \"Log-Tailed Relu\")\n",
189     "plt.legend()\n",
190     "plt.grid()"
191    ]
192   },
193   {
194    "cell_type": "markdown",
195    "id": "51accc9a-838a-453e-82c9-3fb68b47ce3f",
196    "metadata": {},
197    "source": [
198     "The log-tailed ReLU as presented above is virtually identical to the clipped ReLU until very near the saturation level.\n",
199     "\n",
200     "It might be worth exploring other functions that approach the saturation level slower."
201    ]
202   },
203   {
204    "cell_type": "markdown",
205    "id": "f534117b-2634-41c6-9274-c4fc5aba69eb",
206    "metadata": {},
207    "source": [
208     "## Testing Models with Various Output Layer Activation"
209    ]
210   },
211   {
212    "cell_type": "code",
213    "execution_count": null,
214    "id": "e7d92a91-6d66-4ddf-9fff-18e6c896c94c",
215    "metadata": {},
216    "outputs": [],
217    "source": [
218     "# Linear Activation Params\n",
219     "param_sets={}\n",
220     "param_sets.update({1:{'id':1,\n",
221     "        'purpose':'test 1',\n",
222     "        'cases':'all',\n",
223     "        'scale':1,        # every feature in [0, scale]\n",
224     "        'rain_do':True,\n",
225     "        'verbose':False,\n",
226     "        'timesteps':5,\n",
227     "        'activation':['linear','linear'],\n",
228     "        'hidden_units':20,  \n",
229     "        'dense_units':1,    # do not change\n",
230     "        'dense_layers':1,   # do not change\n",
231     "        'centering':[0.0,0.0],  # should be activation at 0\n",
232     "        'DeltaE':[0,-1],    # bias correction\n",
233     "        'synthetic':False,  # run also synthetic cases\n",
234     "        'T1': 0.1,          # 1/fuel class (10)\n",
235     "        'fm_raise_vs_rain': 2.0,         # fm increase per mm rain \n",
236     "        'train_frac':0.5,  # time fraction to spend on training\n",
237     "        'epochs':1000,\n",
238     "        'verbose_fit':False,\n",
239     "        'verbose_weights':False,\n",
240     "}})"
241    ]
242   },
243   {
244    "cell_type": "code",
245    "execution_count": null,
246    "id": "f57d3c90-eb69-40e3-90fa-1b06b308ad01",
247    "metadata": {},
248    "outputs": [],
249    "source": [
250     "param_sets.update({2:{'id':2,\n",
251     "        'purpose':'test 1',\n",
252     "        'cases':'all',\n",
253     "        'scale':1,        # every feature in [0, scale]\n",
254     "        'rain_do':True,\n",
255     "        'verbose':False,\n",
256     "        'timesteps':5,\n",
257     "        'activation':['linear',clipped_relu],\n",
258     "        'hidden_units':20,  \n",
259     "        'dense_units':1,    # do not change\n",
260     "        'dense_layers':1,   # do not change\n",
261     "        'centering':[0.0,0.0],  # should be activation at 0\n",
262     "        'DeltaE':[0,-1],    # bias correction\n",
263     "        'synthetic':False,  # run also synthetic cases\n",
264     "        'T1': 0.1,          # 1/fuel class (10)\n",
265     "        'fm_raise_vs_rain': 2.0,         # fm increase per mm rain \n",
266     "        'train_frac':0.5,  # time fraction to spend on training\n",
267     "        'epochs':1000,\n",
268     "        'verbose_fit':False,\n",
269     "        'verbose_weights':False,\n",
270     "}})"
271    ]
272   },
273   {
274    "cell_type": "code",
275    "execution_count": null,
276    "id": "45f9484b-881d-4e57-b9e0-bbd5e5ad62a0",
277    "metadata": {},
278    "outputs": [],
279    "source": [
280     "# Data\n",
281     "# Change directory for data read/write\n",
282     "\n",
283     "dict_file='../data/raws_CO_202306.pickle' # input path of FMDA dictionaries\n",
284     "reproducibility_file='../data/reproducibility_dict.pickle'\n",
285     "\n",
286     "# read test datasets\n",
287     "test_dict={}\n",
288     "test_dict.update(load_and_fix_data(dict_file))\n",
289     "print(test_dict.keys())\n",
290     "\n",
291     "repro_dict={}\n",
292     "repro_dict.update(load_and_fix_data(reproducibility_file))\n",
293     "print(repro_dict.keys())\n",
294     "# Build Case Data\n",
295     "id = \"CPTC2_202306010000\"\n",
296     "case_data=test_dict[id]\n",
297     "case_data[\"hours\"]=len(case_data['fm'])\n",
298     "case_data[\"h2\"]=int(24*20)"
299    ]
300   },
301   {
302    "cell_type": "markdown",
303    "id": "b5477a1a-42b0-4a71-b47c-aea423f50cee",
304    "metadata": {},
305    "source": [
306     "### Check Initial Fit"
307    ]
308   },
309   {
310    "cell_type": "code",
311    "execution_count": null,
312    "id": "0bf897cb-acfb-4433-960c-687fd579ccca",
313    "metadata": {},
314    "outputs": [],
315    "source": [
316     "rnn_dat = create_rnn_data(case_data,param_sets[1])\n",
317     "model1 = train_rnn(\n",
318     "    rnn_dat,\n",
319     "    param_sets[1],\n",
320     "    rnn_dat['hours'],\n",
321     "    fit=False\n",
322     ")\n",
323     "fit1 = rnn_predict(model1, param_sets[1], rnn_dat)\n",
324     "rnn_dat = create_rnn_data(case_data,param_sets[2])\n",
325     "model1 = train_rnn(\n",
326     "    rnn_dat,\n",
327     "    param_sets[2],\n",
328     "    rnn_dat['hours'],\n",
329     "    fit=False\n",
330     ")\n",
331     "fit2 = rnn_predict(model1, param_sets[2], rnn_dat)\n",
332     "plt.plot(fit1, label=\"Linear Output\")\n",
333     "plt.plot(fit2, label=\"Clipped Relu Output\")\n",
334     "plt.legend()"
335    ]
336   },
337   {
338    "cell_type": "markdown",
339    "id": "5cbfb9ea-370b-46b5-a8d1-7d83b1e56415",
340    "metadata": {},
341    "source": [
342     "### Check Trained Fit"
343    ]
344   },
345   {
346    "cell_type": "code",
347    "execution_count": null,
348    "id": "9174fc29-a60b-4811-a71e-1b1e9cbfd2b5",
349    "metadata": {},
350    "outputs": [],
351    "source": [
352     "reproducibility.set_seed() # Set seed for reproducibility\n",
353     "params = param_sets[1]\n",
354     "\n",
355     "rnn_dat = create_rnn_data(case_data,params)\n",
356     "model1 = train_rnn(\n",
357     "    rnn_dat,\n",
358     "    params,\n",
359     "    rnn_dat['hours'],\n",
360     "    fit=True\n",
361     ")\n",
362     "case_data['m'] = rnn_predict(model1, params, rnn_dat)\n",
363     "fit_linear = case_data['m']"
364    ]
365   },
366   {
367    "cell_type": "code",
368    "execution_count": null,
369    "id": "0262a9db-59fd-4f49-8217-a171d818d33d",
370    "metadata": {},
371    "outputs": [],
372    "source": [
373     "plot_data(case_data,title2='Initial RNN Linear')"
374    ]
375   },
376   {
377    "cell_type": "code",
378    "execution_count": null,
379    "id": "16b0fcaa-e027-4acb-a21e-618d0eba05e4",
380    "metadata": {},
381    "outputs": [],
382    "source": [
383     "reproducibility.set_seed() # Set seed for reproducibility\n",
384     "params = param_sets[2]\n",
385     "\n",
386     "rnn_dat = create_rnn_data(case_data,params)\n",
387     "model1 = train_rnn(\n",
388     "    rnn_dat,\n",
389     "    params,\n",
390     "    rnn_dat['hours'],\n",
391     "    fit=True\n",
392     ")\n",
393     "case_data['m'] = rnn_predict(model1, params, rnn_dat)\n",
394     "fit_clipped = case_data['m']"
395    ]
396   },
397   {
398    "cell_type": "code",
399    "execution_count": null,
400    "id": "3bf4bec6-9fce-4e82-b1f0-6cae9dd8bbce",
401    "metadata": {},
402    "outputs": [],
403    "source": [
404     "plot_data(case_data,title2='Initial RNN Clipped')"
405    ]
406   },
407   {
408    "cell_type": "code",
409    "execution_count": null,
410    "id": "32875d7d-914a-438e-b5c7-29fdeb016a37",
411    "metadata": {},
412    "outputs": [],
413    "source": [
414     "print(np.max(fit_linear - fit_clipped))"
415    ]
416   },
417   {
418    "cell_type": "markdown",
419    "id": "da7d66c7-d16e-4492-92f9-4b7bfc71d4d9",
420    "metadata": {},
421    "source": [
422     "The maximum difference in the fitted values is about half of a tenth of a percent, so there was no dramatic effect after training."
423    ]
424   },
425   {
426    "cell_type": "markdown",
427    "id": "c94d6765-b17c-41ff-a08d-32dc3709b73d",
428    "metadata": {},
429    "source": [
430     "## Sources\n",
431     "\n",
432     "Hossain, Md & Teng, Shyh & Sohel, Ferdous & Lu, Guojun. (2020). Robust Image Classification Using A Low-Pass Activation Function and DCT Augmentation. "
433    ]
434   },
435   {
436    "cell_type": "code",
437    "execution_count": null,
438    "id": "138af0b6-031b-4b80-acbe-8dd474a70b5d",
439    "metadata": {},
440    "outputs": [],
441    "source": []
442   }
443  ],
444  "metadata": {
445   "kernelspec": {
446    "display_name": "Python 3 (ipykernel)",
447    "language": "python",
448    "name": "python3"
449   },
450   "language_info": {
451    "codemirror_mode": {
452     "name": "ipython",
453     "version": 3
454    },
455    "file_extension": ".py",
456    "mimetype": "text/x-python",
457    "name": "python",
458    "nbconvert_exporter": "python",
459    "pygments_lexer": "ipython3",
460    "version": "3.9.12"
461   }
462  },
463  "nbformat": 4,
464  "nbformat_minor": 5