Update fmda_rnn_spatial.ipynb
[notebooks.git] / fmda / presentations / rnn_data_structure.ipynb
blobdda26e99e5149be49ab194c343017bd494e32fde
2  "cells": [
3   {
4    "cell_type": "markdown",
5    "id": "55855dde-0a45-4430-b576-29c7eb078794",
6    "metadata": {},
7    "source": [
8     "# Input Data Structure for RNNs"
9    ]
10   },
11   {
12    "cell_type": "code",
13    "execution_count": null,
14    "id": "2fa3f116-148d-4251-a8a5-e908710bc3e4",
15    "metadata": {},
16    "outputs": [],
17    "source": [
18     "import numpy as np\n",
19     "import sys\n",
20     "sys.path.append(\"..\")\n",
21     "from moisture_rnn import staircase, staircase_2\n",
22     "import pandas as pd"
23    ]
24   },
25   {
26    "cell_type": "markdown",
27    "id": "e8155370-6650-4d18-b3e4-c39260d217b4",
28    "metadata": {},
29    "source": [
30     "\n",
31     "## Background\n",
32     "\n",
33     "RNNs are a type of timeseries model that relates the outcome at time $t$ to the outcome at previous times. Like other machine learning models, training is typically done by calculating the gradient of the output with respect to the weights, or parameters, of the model. With recursive or other type of autoregressive models, the gradient calculation at time $t$ ends up depending on the gradient at $t-1, t-2, ...,$ and to $t=0$. This ends up being computationally expensive, but more importantly can lead to \"vanishing\" or \"exploding\" gradient problems, where many gradients are multiplied together and either blow up or shrink. See LINK_TO_RECURSIVE_GRADIENT_LATEX for more info...\n",
34     "\n",
35     "RRNs and other timeseries neural network architectures* get around this issue by approximating the gradient in more stable ways. In addition to many model architecture and hyperparameter options, timeseries neural networks use two main ways of restructuring the input data.\n",
36     "\n",
37     "* **Sequence Length:** The input data is divided into smaller collections of ordered events, known as sequences. When calculating the gradient with respect to the model weights, the gradient only looks back this number of timesteps. Also known as `timesteps`, `sequence_length`, or just \"sample\" in `tensorflow` functions and related literature. For a sequence length of 3, the unique sequences of length 3 are: `[1,2,3], [2,3,4], ..., [T-2, T-1, T]`, for a total number of sequences `T-timesteps+1`\n",
38     "\n",
39     "* **Batch size:** Sequences are grouped together into batches. The batch size determines how many sequences the network processes in a single step of calculating loss and updating weights. Used as `batch_size` in `tensorflow`.\n",
40     "\n",
41     "The total number of batches is therefore determined from the total number of observations $T$ and the batch size. In a single batch, the loss is typically calculated for each sequence and then averaged to produce a single value. Then, the gradient of the loss with respect to the parameters (weights and biases) is computed for each sequence in the batch. So each batch will have a single gradient calculation that is the sum of the gradients of each sequence in the batch.\n",
42     "\n",
43     "**Note:* these same data principles apply to more complex versions of timeseries neural network layers, such as LSTM and GRU."
44    ]
45   },
46   {
47    "cell_type": "markdown",
48    "id": "39fa2a9d-c74b-445d-a875-e7591c7a6666",
49    "metadata": {},
50    "source": [
51     "## Stateless vs Stateful Networks\n",
52     "\n",
53     "RNNs have a hidden state that represents the recurrent layer output at a previous time. There is a weight and bias at each RNN cell that determines the relative contribution of the previous output to the current output. When updating weights in RNNs, there are two main types of training scheme:\n",
54     "\n",
55     "**Stateless:** the hidden state is reset to the initial state (often zero) at the start of each new sequence in a batch. So, the network treats each sequence independently, and no information is carried over in time between sequences. These models are simpler, but work better when time dependence is relatively short.\n",
56     "* **Input Data Shape:** (`n_sequence`, `timesteps`, `features`), where `n_sequence` is total number of sequences (a function of total observed times `T` and the user choice of timesteps). The input data does NOT need to be structured with batch size in stateless RNNs.\n",
57     "* **Tensorflow RNN Args:** for a stateless RNN, use the `input_shape` parameter, with `input_shape`=(`timesteps`, `features`). Then, `batch_size` can be declared in the fitting stage with `model.fit(X_train, batch_size = __)`. \n",
58     "\n",
59     "**Stateful:** the hidden states are carried over from one sequence to the next within a batch. Longer time dependencies can be learned in this way.\n",
60     "* **Input Data Shape:** (`batch_size`, `timesteps`, `features`). In order for the hidden state to be passed between sequences, the input data must be formatted using the `batch_size` hyperparameter.\n",
61     "* **Tensorflow RNN Args:** for a stateful RNN, use the `batch_input_shape` parameter, with `batch_input_shape`=(`batch_size`, `timesteps`, `features`)"
62    ]
63   },
64   {
65    "cell_type": "markdown",
66    "id": "50e9011c-ee92-4bc2-bf6b-461b5bf9c662",
67    "metadata": {},
68    "source": [
69     "## Examples\n",
70     "\n",
71     "### Data Description\n",
72     "\n",
73     "Consider $T=100$ observations of a variable in time $y_t$, so $t=1, ..., 100$. A feature matrix with $1$ features has dimensions $100\\times 1$, and must be restructured for use in RNNs."
74    ]
75   },
76   {
77    "cell_type": "code",
78    "execution_count": null,
79    "id": "03048248-0c81-491d-b0ce-a099bd984dae",
80    "metadata": {},
81    "outputs": [],
82    "source": [
83     "T=100 # total number of times obseved\n",
84     "\n",
85     "data = np.arange(0, 100).reshape(-1, 1)\n",
86     "# data = np.array([[f\"x{i}\"] for i in range(100)])\n",
87     "\n",
88     "\n",
89     "# Generate random response vector, needed by staircase func\n",
90     "y = np.arange(1, 101).reshape(-1, 1)\n",
91     "\n",
92     "print(f\"Response Data Shape: {y.shape}\")\n",
93     "print(\"First 10 rows\")\n",
94     "print(y[0:10])\n",
95     "\n",
96     "# Print head of data\n",
97     "print(f\"Feature Data Shape: {data.shape}\")\n",
98     "print(\"First 10 rows\")\n",
99     "data[0:10,:]"
100    ]
101   },
102   {
103    "cell_type": "markdown",
104    "id": "eb559e3f-b7ca-40a6-8d64-98df3bcb0f51",
105    "metadata": {},
106    "source": [
107     "The rows of the input data array represent all features at a single timepoint. The first digit represents the feature number, and the second digit represents time point. So value $13$ represents feature 1 at time 3. "
108    ]
109   },
110   {
111    "cell_type": "markdown",
112    "id": "571efd5b-7a38-4a79-9f34-8c5c7a64933f",
113    "metadata": {},
114    "source": [
115     "### Single Batch Example\n",
116     "\n",
117     "With a stateless RNN, the input data is structured to be of shape (`n_sequence`, `timesteps`, `features`). The `batch_size` is not needed to structure the data for a stateless RNN.\n",
118     "\n",
119     "When using functions that expect `batch_size` to structure the data, an option is to set `batch_size` to be some large number greater than the total number of observed times $T$, so that all the data is guarenteed to be in one batch. *NOTE:* here we trick the function by using a large batch size, but `batch_size` could still be declared at the fitting stage of the model.\n",
120     "\n",
121     "Suppose in we use `timesteps=5`, so we would get sequences of times `[1,2,3,4,5]` for the first sequence, then `[2,3,4,5,6]` for the next, and so on until `[96,97,98,99,100]`. \n",
122     "\n",
123     "Thus, there are `100-5+1=96` possible ordered sequences of length `5`. "
124    ]
125   },
126   {
127    "cell_type": "markdown",
128    "id": "7e17e4dc-62be-4001-ae07-1e9a985f5b9e",
129    "metadata": {},
130    "source": [
131     "We need to structure the input data to the RNN to be of shape (96, 5, 3). *Note:* since the model is stateless, and the sequences are treated independently, the actual order of the sequences doesn't matter."
132    ]
133   },
134   {
135    "cell_type": "markdown",
136    "id": "af506b7e-7ffd-484f-a1d5-a748da8f2f6d",
137    "metadata": {},
138    "source": [
139     "For a stateless RNN, the batches could consist of any collection of the sequences, since the sequences are indepenent. \n",
140     "\n",
141     "We want all of the data sequences to be in a single batch, but number of batches is not a direct user input for most built-in functions. To get around this, we make the batch size some number larger than the total number of observed times."
142    ]
143   },
144   {
145    "cell_type": "markdown",
146    "id": "865805e0-caf8-4931-b9d2-6bd64cda0189",
147    "metadata": {},
148    "source": [
149     "We now recreate this using the custom `staircase` function, which produces the same results for the input data:"
150    ]
151   },
152   {
153    "cell_type": "code",
154    "execution_count": null,
155    "id": "e350c887-af78-4036-bd66-f0e9fe574d1d",
156    "metadata": {
157     "scrolled": true
158    },
159    "outputs": [],
160    "source": [
161     "X_train, y_train = staircase(data, y, timesteps = 5, datapoints = len(y), verbose=True)"
162    ]
163   },
164   {
165    "cell_type": "code",
166    "execution_count": null,
167    "id": "7b51ca33-40ea-4ef4-a26d-b1a875e61313",
168    "metadata": {},
169    "outputs": [],
170    "source": [
171     "print(X_train.shape)\n",
172     "print(y_train.shape)"
173    ]
174   },
175   {
176    "cell_type": "markdown",
177    "id": "0bcca61f-c765-4553-861f-b2da60cb5fcc",
178    "metadata": {},
179    "source": [
180     "The first input sequence will be 5 observations of the features starting at time 0:"
181    ]
182   },
183   {
184    "cell_type": "code",
185    "execution_count": null,
186    "id": "32140b50-12e8-4352-80fd-507db473c5c5",
187    "metadata": {},
188    "outputs": [],
189    "source": [
190     "print(\"Sequence 1:\")\n",
191     "print(X_train[0,:,:])"
192    ]
193   },
194   {
195    "cell_type": "markdown",
196    "id": "b0a006c9-122c-4968-86cc-312472f26e02",
197    "metadata": {},
198    "source": [
199     "The second input sequence will be 5 observations of the features starting at time 1:"
200    ]
201   },
202   {
203    "cell_type": "code",
204    "execution_count": null,
205    "id": "418da438-16f4-4179-a997-6f0e1f3a974e",
206    "metadata": {},
207    "outputs": [],
208    "source": [
209     "print(\"Sequence 2:\")\n",
210     "X_train[1,:,:]"
211    ]
212   },
213   {
214    "cell_type": "markdown",
215    "id": "ec684b9c-a6bb-48a0-92f8-9350286a0dfd",
216    "metadata": {},
217    "source": [
218     "In this implementation, we structure the input data as all possible sequences of length 5, but there is no requirement to do it this way. With a stateless RNN, you can use any combination of sequences. The input data does not need to be highly structured, you can put any combination of sequences you want in an array."
219    ]
220   },
221   {
222    "cell_type": "markdown",
223    "id": "b14965b3-75cf-43e9-b3fe-85900e473a86",
224    "metadata": {},
225    "source": [
226     "### Stateful Multi-Batch Example"
227    ]
228   },
229   {
230    "cell_type": "markdown",
231    "id": "ce8685b5-7e79-462b-b782-6a493dc63184",
232    "metadata": {},
233    "source": [
234     "We now need the data in the format (`batch_size`, `timesteps`, `features`) for each batch. A stateful RNN maintains a hidden state between batches. So sequence $i$ in batch $j$ needs to be a continuation of sequence $i$ in batch $j-1$."
235    ]
236   },
237   {
238    "cell_type": "code",
239    "execution_count": null,
240    "id": "341283b8-f66a-4c6b-a314-a3df4d1f08bb",
241    "metadata": {},
242    "outputs": [],
243    "source": [
244     "X_train, y_train = staircase_2(data, y, timesteps = 5, batch_size = 32, verbose=False)"
245    ]
246   },
247   {
248    "cell_type": "markdown",
249    "id": "a28a11c9-c3f2-4e78-a455-7611a0edb920",
250    "metadata": {},
251    "source": [
252     "The first input sequence will again be 5 observations of the features starting at time 0. Since the batch size is 32, the 33rd input sequence will be a continuation of the 1st input sequence. So the 33rd input sequence starts at time 5:"
253    ]
254   },
255   {
256    "cell_type": "code",
257    "execution_count": null,
258    "id": "a7164cbf-8009-43bb-aa0b-1ed7c22d1212",
259    "metadata": {},
260    "outputs": [],
261    "source": [
262     "print(\"Sequence 1, Batch 1:\")\n",
263     "print(X_train[0,:,:])\n",
264     "print(\"Sequence 1, Batch 2:\")\n",
265     "print(X_train[0+32,:,:])\n",
266     "print(\"Sequence 1, Batch 3:\")\n",
267     "print(X_train[32+32,:,:])"
268    ]
269   },
270   {
271    "cell_type": "code",
272    "execution_count": null,
273    "id": "c74a00f8-8fa8-4f99-b0db-4100d9fdcd27",
274    "metadata": {},
275    "outputs": [],
276    "source": [
277     "print(\"Sequence 2, Batch 1:\")\n",
278     "print(X_train[1,:,:])\n",
279     "print(\"Sequence 2, Batch 2:\")\n",
280     "print(X_train[1+32,:,:])\n",
281     "print(\"Sequence 2, Batch 3:\")\n",
282     "print(X_train[1+32+32,:,:])"
283    ]
284   },
285   {
286    "cell_type": "markdown",
287    "id": "a261a94f-8fbb-4cb9-98ae-86df0ea8f6d8",
288    "metadata": {},
289    "source": [
290     "By setting a RNN to be stateful with batch size 32, in the first batch the model will run each sequence through the model. The hidden state at the end of sequence $i$ is used as the initial hidden state for seqeunce $i$ in the next batch.\n",
291     "\n",
292     "In this example we again structured the data to use all possible sequences of length 5. So within batch 1, the 1st sequence starts at time zero, the next sequence starts at time 1. But within a batch, there does not need to be any temporal relationship between the sequences. The only requirement is that the $i^{th}$ sequence within each batch lines up."
293    ]
294   },
295   {
296    "cell_type": "markdown",
297    "id": "6105be3d-c5c0-4faa-ac56-056121579ff0",
298    "metadata": {},
299    "source": [
300     "## Multiple Timeseries Data Structure"
301    ]
302   },
303   {
304    "cell_type": "markdown",
305    "id": "3f23cf68-224a-4b39-87db-827b6614fb3f",
306    "metadata": {},
307    "source": [
308     "Before we showed how to structure the input data if we have observations from 1 timeseries. In spatial contexts, there will be observation at multiple locations. There are many different ways to implement training for RNNs using multiple timeseries."
309    ]
310   },
311   {
312    "cell_type": "markdown",
313    "id": "94891cfc-b853-4fa4-a46b-fed2c8b1f4e6",
314    "metadata": {},
315    "source": [
316     "### Example: Stateful & Locations Equals Batch Size, Same Start Time\n",
317     "\n",
318     "As a simplifying first example, suppose we set the batch size equal to the number of unique locations. So suppose we have 3 timeseries of observations at 3 unique locations. In this example, we will suppose they occur at the same time, but the data structure used here does not depend on that. The only temporal dependence is between sequences that will share a hidden state across batches in a stateful RNN. We include a second column in the feature matrix for location index below, values 1, 2, or 3. "
319    ]
320   },
321   {
322    "cell_type": "code",
323    "execution_count": null,
324    "id": "df2514f5-8e04-4ec3-b934-145f92122803",
325    "metadata": {},
326    "outputs": [],
327    "source": [
328     "data2 = np.column_stack(\n",
329     "    (np.concatenate((np.arange(0, 100), np.arange(0, 100), np.arange(0, 100))),\n",
330     "    np.concatenate((np.repeat(1, 100), np.repeat(2, 100), np.repeat(3, 100))))\n",
331     ")"
332    ]
333   },
334   {
335    "cell_type": "code",
336    "execution_count": null,
337    "id": "613dd7e9-d518-4cb4-ac88-fc15baeb966a",
338    "metadata": {},
339    "outputs": [],
340    "source": [
341     "print('First 10 observations at location 1:')\n",
342     "data2[0:10,:]"
343    ]
344   },
345   {
346    "cell_type": "code",
347    "execution_count": null,
348    "id": "e8af0cc6-9985-4c65-a18a-0770a32636fd",
349    "metadata": {},
350    "outputs": [],
351    "source": [
352     "print('First 10 observations at location 2:')\n",
353     "data2[100:110,:]"
354    ]
355   },
356   {
357    "cell_type": "code",
358    "execution_count": null,
359    "id": "fcc9c682-1f68-43ca-a033-c189f34320ae",
360    "metadata": {},
361    "outputs": [],
362    "source": [
363     "print('First 10 observations at location 3:')\n",
364     "data2[200:210,:]"
365    ]
366   },
367   {
368    "cell_type": "markdown",
369    "id": "d2cbd92c-c6f7-467d-869a-e63588799e5a",
370    "metadata": {},
371    "source": [
372     "In this example, we construct a dataset with `batch_size` = 3."
373    ]
374   },
375   {
376    "cell_type": "code",
377    "execution_count": null,
378    "id": "3ff0bd5c-7934-47a9-9ad1-cad842c87c39",
379    "metadata": {},
380    "outputs": [],
381    "source": [
382     "X1, y1 = staircase_2(data2[data2[:,1] == 1], y, timesteps = 5, batch_size = 1, verbose=False)\n",
383     "X2, y2 = staircase_2(data2[data2[:,1] == 2], y, timesteps = 5, batch_size = 1, verbose=False)\n",
384     "X3, y3 = staircase_2(data2[data2[:,1] == 3], y, timesteps = 5, batch_size = 1, verbose=False)\n",
385     "\n",
386     "Xs = [X1, X2, X3]"
387    ]
388   },
389   {
390    "cell_type": "code",
391    "execution_count": null,
392    "id": "2cd24d17-9936-4ec9-9781-0312d307d34b",
393    "metadata": {},
394    "outputs": [],
395    "source": [
396     "[Xi.shape[0] for Xi in Xs]"
397    ]
398   },
399   {
400    "cell_type": "code",
401    "execution_count": null,
402    "id": "7a17fa42-70df-4da0-aae4-62b4306a8d85",
403    "metadata": {},
404    "outputs": [],
405    "source": [
406     "locs = len(Xs)\n",
407     "XX = np.empty((Xs[0].shape[0]*locs, 5, 2))"
408    ]
409   },
410   {
411    "cell_type": "code",
412    "execution_count": null,
413    "id": "7233bde1-8169-4731-b78d-6cc9efa4cd3f",
414    "metadata": {},
415    "outputs": [],
416    "source": [
417     "for i in range(0,locs):\n",
418     "    XX[i::locs] =  Xs[i]"
419    ]
420   },
421   {
422    "cell_type": "code",
423    "execution_count": null,
424    "id": "3c637c93-882a-4e09-9871-5c406b3efb88",
425    "metadata": {},
426    "outputs": [],
427    "source": [
428     "print(\"Sequence 1, Batch 1\")\n",
429     "print(XX[0,:,:])\n",
430     "print(\"Sequence 1, Batch 2\")\n",
431     "print(XX[3,:,:])"
432    ]
433   },
434   {
435    "cell_type": "code",
436    "execution_count": null,
437    "id": "7351737c-d0fa-42c9-b2cc-8530e3a88a76",
438    "metadata": {},
439    "outputs": [],
440    "source": [
441     "print(\"Sequence 2, Batch 1\")\n",
442     "print(XX[1,:,:])\n",
443     "print(\"Sequence 2, Batch 2\")\n",
444     "print(XX[4,:,:])"
445    ]
446   },
447   {
448    "cell_type": "code",
449    "execution_count": null,
450    "id": "74ed5a15-198a-4a4b-b334-f2720fee2185",
451    "metadata": {},
452    "outputs": [],
453    "source": [
454     "print(\"Sequence 3, Batch 1\")\n",
455     "print(XX[2,:,:])\n",
456     "print(\"Sequence 3, Batch 2\")\n",
457     "print(XX[5,:,:])"
458    ]
459   },
460   {
461    "cell_type": "markdown",
462    "id": "3ff9f26c-5c58-4450-b885-81e5e1f203ea",
463    "metadata": {},
464    "source": [
465     "### Example: Stateful & Locations Equals Batch Size, Staggered Start Time\n",
466     "\n",
467     "In the previous example, within a batch the sequences all start at the same time. This can lead to over-reliance on the particular ordering of the data. In this next example, we will use the same data from 3 locations as before, but we will stagger the start time of the sequences. This will result in losing a sequence at the end of the timeseries that are offset, so we filter out the data to match in dimensions."
468    ]
469   },
470   {
471    "cell_type": "code",
472    "execution_count": null,
473    "id": "d4940983-9568-4092-9705-98b5b5383103",
474    "metadata": {},
475    "outputs": [],
476    "source": [
477     "X1, y1 = staircase_2(data2[(data2[:,1] == 1) & (data2[:,0]>= 0)], y, timesteps = 5, batch_size = 1, verbose=False)\n",
478     "X2, y2 = staircase_2(data2[(data2[:,1] == 2) & (data2[:,0]>= 1)], y, timesteps = 5, batch_size = 1, verbose=False)\n",
479     "X3, y3 = staircase_2(data2[(data2[:,1] == 3) & (data2[:,0]>= 2)], y, timesteps = 5, batch_size = 1, verbose=False)\n",
480     "\n",
481     "Xs = [X1, X2, X3]"
482    ]
483   },
484   {
485    "cell_type": "code",
486    "execution_count": null,
487    "id": "577d1a36-67f5-42a2-ab1f-9d6ade3404dd",
488    "metadata": {},
489    "outputs": [],
490    "source": [
491     "lens = [Xi.shape[0] for Xi in Xs]\n",
492     "print(lens)\n",
493     "print(min(lens))"
494    ]
495   },
496   {
497    "cell_type": "code",
498    "execution_count": null,
499    "id": "cd43b870-99b4-43ff-857d-8a36a7e26cfd",
500    "metadata": {},
501    "outputs": [],
502    "source": [
503     "# Filter each array to be same length\n",
504     "min_shape = min(lens)\n",
505     "Xs = [Xi[:min_shape] for Xi in Xs]"
506    ]
507   },
508   {
509    "cell_type": "code",
510    "execution_count": null,
511    "id": "7648b9bb-df00-41b7-b942-76e6887e2ac7",
512    "metadata": {},
513    "outputs": [],
514    "source": [
515     "[Xi.shape[0] for Xi in Xs]"
516    ]
517   },
518   {
519    "cell_type": "code",
520    "execution_count": null,
521    "id": "06976280-3f5a-4d7a-8a86-211f1109152c",
522    "metadata": {},
523    "outputs": [],
524    "source": [
525     "locs = len(Xs)\n",
526     "XX = np.empty((Xs[0].shape[0]*locs, 5, 2))\n",
527     "\n",
528     "for i in range(0,locs):\n",
529     "    XX[i::locs] =  Xs[i]"
530    ]
531   },
532   {
533    "cell_type": "code",
534    "execution_count": null,
535    "id": "58241ac4-adcc-4cd5-b1af-9485788fe6fd",
536    "metadata": {},
537    "outputs": [],
538    "source": [
539     "print(\"Sequence 1, Batch 1\")\n",
540     "print(XX[0,:,:])\n",
541     "print(\"Sequence 1, Batch 2\")\n",
542     "print(XX[3,:,:])"
543    ]
544   },
545   {
546    "cell_type": "code",
547    "execution_count": null,
548    "id": "520d62d2-d735-4f0d-8481-8dfe86c3c563",
549    "metadata": {},
550    "outputs": [],
551    "source": [
552     "print(\"Sequence 2, Batch 1\")\n",
553     "print(XX[1,:,:])\n",
554     "print(\"Sequence 2, Batch 2\")\n",
555     "print(XX[4,:,:])"
556    ]
557   },
558   {
559    "cell_type": "code",
560    "execution_count": null,
561    "id": "3ffa7073-841d-444b-b6b1-71dd6b3f8040",
562    "metadata": {},
563    "outputs": [],
564    "source": [
565     "print(\"Sequence 3, Batch 1\")\n",
566     "print(XX[2,:,:])\n",
567     "print(\"Sequence 3, Batch 2\")\n",
568     "print(XX[5,:,:])"
569    ]
570   },
571   {
572    "cell_type": "markdown",
573    "id": "eed650eb-d76e-4b95-9cf7-f92d4746379a",
574    "metadata": {},
575    "source": [
576     "### Example: More Locations than Batch Size"
577    ]
578   },
579   {
580    "cell_type": "code",
581    "execution_count": null,
582    "id": "3e7697c1-bf9c-436d-b41e-ad040799329a",
583    "metadata": {},
584    "outputs": [],
585    "source": [
586     "def batch_setup(x, batch_size):\n",
587     "    # Ensure x is a numpy array\n",
588     "    x = np.array(x)\n",
589     "    \n",
590     "    # Initialize the list to hold the batches\n",
591     "    batches = []\n",
592     "    \n",
593     "    # Use a loop to slice the list/array into batches\n",
594     "    for i in range(0, len(x), batch_size):\n",
595     "        batch = list(x[i:i + batch_size])\n",
596     "        \n",
597     "        # If the batch is not full, continue from the start\n",
598     "        while len(batch) < batch_size:\n",
599     "            # Calculate the remaining number of items needed\n",
600     "            remaining = batch_size - len(batch)\n",
601     "            # Append the needed number of items from the start of the array\n",
602     "            batch.extend(x[:remaining])\n",
603     "        \n",
604     "        batches.append(batch)\n",
605     "    \n",
606     "    return batches"
607    ]
608   },
609   {
610    "cell_type": "code",
611    "execution_count": null,
612    "id": "d5460c73-d735-4dc0-a795-4f1418c108c2",
613    "metadata": {},
614    "outputs": [],
615    "source": [
616     "data2 = np.column_stack(\n",
617     "    (np.concatenate((np.arange(0, 100), np.arange(0, 100), np.arange(0, 100), np.arange(0, 100))),\n",
618     "    np.concatenate((np.repeat(1, 100), np.repeat(2, 100), np.repeat(3, 100), np.repeat(4, 100))))\n",
619     ")"
620    ]
621   },
622   {
623    "cell_type": "code",
624    "execution_count": null,
625    "id": "e8ca3d48-3b17-488a-b03c-2704455c91c1",
626    "metadata": {},
627    "outputs": [],
628    "source": [
629     "data_config = {\n",
630     "    'nloc': 4, # Unique locations\n",
631     "    'start_times': [0,2,4,6], # relative to first observation, must match number of locs\n",
632     "    'hours': 100, # total number of hours to use from data\n",
633     "    'batch_size': 2,\n",
634     "    'seq_length': 5\n",
635     "}"
636    ]
637   },
638   {
639    "cell_type": "code",
640    "execution_count": null,
641    "id": "f26d3350-497f-46d6-bd9d-82d2613cb75c",
642    "metadata": {},
643    "outputs": [],
644    "source": [
645     "print(f\"Unique Locations: {data_config['nloc']}\")"
646    ]
647   },
648   {
649    "cell_type": "code",
650    "execution_count": null,
651    "id": "8e4d76ac-c72a-43a8-bea0-5f22ad1b267d",
652    "metadata": {},
653    "outputs": [],
654    "source": [
655     "# Create array of location IDs\n",
656     "loc_ids = np.arange(data_config['nloc'])\n",
657     "loc_names = np.unique(data2[:,1])\n",
658     "\n",
659     "loc_batches, t_batch =  batch_setup(loc_ids, 2), batch_setup(data_config['start_times'], 2)\n",
660     "print(loc_batches)\n",
661     "print(t_batch)"
662    ]
663   },
664   {
665    "cell_type": "code",
666    "execution_count": null,
667    "id": "1dded017-6565-46ff-a6a6-a416d7b05f63",
668    "metadata": {},
669    "outputs": [],
670    "source": [
671     "# j = data_config['locs'][0] # starting location index\n",
672     "Xs = []\n",
673     "hours = data_config[\"hours\"]\n",
674     "for i in range(0, data_config[\"batch_size\"]):\n",
675     "    locs = loc_batches[i]\n",
676     "    ts = t_batch[i]\n",
677     "    for j in range(0, len(locs)):\n",
678     "        loc = loc_names[locs[j]]\n",
679     "        t0 = ts[j]\n",
680     "        # Subset data to given location and time from t0 to t0+hours\n",
681     "        dat_temp = data2[(data2[:,1] == loc) & (data2[:,0]>= t0) & (data2[:,0]< t0+hours)]\n",
682     "        # Format sequences\n",
683     "        Xi, yi = staircase_2(\n",
684     "            dat_temp, \n",
685     "            y, \n",
686     "            timesteps = data_config['seq_length'], \n",
687     "            batch_size = 1,  # note: using 1 here to format sequences for a single location, not same as target batch size for training data\n",
688     "            verbose=False)\n",
689     "    \n",
690     "        Xs.append(Xi)"
691    ]
692   },
693   {
694    "cell_type": "code",
695    "execution_count": null,
696    "id": "93e14f0a-cbfc-4ccc-9f53-33c4d733f994",
697    "metadata": {},
698    "outputs": [],
699    "source": [
700     "batch_size = data_config['batch_size']\n",
701     "lens = [Xi.shape[0] for Xi in Xs]\n",
702     "min_shape = min(lens)\n",
703     "Xs = [Xi[:min_shape] for Xi in Xs]"
704    ]
705   },
706   {
707    "cell_type": "code",
708    "execution_count": null,
709    "id": "d7fa971b-0018-477e-a400-766d46aa48c8",
710    "metadata": {},
711    "outputs": [],
712    "source": [
713     "XXs = []\n",
714     "for i in range(0, len(loc_batches)):\n",
715     "    locs = loc_batches[i]\n",
716     "    XXi = np.empty((Xs[0].shape[0]*batch_size, 5, 2))\n",
717     "    for j in range(0, len(locs)):\n",
718     "        XXi[j::(batch_size)] =  Xs[locs[j]]\n",
719     "    XXs.append(XXi)"
720    ]
721   },
722   {
723    "cell_type": "code",
724    "execution_count": null,
725    "id": "0d5e1c57-b045-4719-bd96-cd6bec8ed8e8",
726    "metadata": {},
727    "outputs": [],
728    "source": [
729     "XX = np.concatenate(XXs, axis=0)\n",
730     "print(XX.shape)"
731    ]
732   },
733   {
734    "cell_type": "markdown",
735    "id": "95bbbc3a-4146-4458-a09a-9646d317bf66",
736    "metadata": {},
737    "source": [
738     "The batches at the start of the data structure include only a subset of locations, since there are more locations than the batch size. So the first few sequences are structured the way they were before."
739    ]
740   },
741   {
742    "cell_type": "code",
743    "execution_count": null,
744    "id": "fb13356f-d71f-48be-ab30-47540a0c0f83",
745    "metadata": {},
746    "outputs": [],
747    "source": [
748     "print(\"Sequence 1, Batch 1\")\n",
749     "print(XX[0,:,:])\n",
750     "print(\"Sequence 1, Batch 2\")\n",
751     "print(XX[0+batch_size,:,:])\n",
752     "print(\"⋮\")\n",
753     "print(f\"Sequence 1, Batch {min_shape}\")\n",
754     "print(XX[0+(min_shape-1)*batch_size,:,:])"
755    ]
756   },
757   {
758    "cell_type": "code",
759    "execution_count": null,
760    "id": "3ec0be3f-60fd-4fe4-b1f5-7662e46a8f5a",
761    "metadata": {},
762    "outputs": [],
763    "source": [
764     "print(\"Sequence 2, Batch 1\")\n",
765     "print(XX[1,:,:])\n",
766     "print(\"Sequence 2, Batch 2\")\n",
767     "print(XX[1+batch_size,:,:])\n",
768     "print(\"⋮\")\n",
769     "print(f\"Sequence 1, Batch {min_shape}\")\n",
770     "print(XX[1+(min_shape-1)*batch_size,:,:])"
771    ]
772   },
773   {
774    "cell_type": "markdown",
775    "id": "309586a4-b34e-4f87-9668-96ad79e19046",
776    "metadata": {},
777    "source": [
778     "After all of the sequences from the first subset of locations are used, the locations change and then the remaining batches use those locations. \n",
779     "\n",
780     "Since this data structure is for a stateful RNN, at this point within an epoch of training the hidden states of the RNN must be reset using the `reset_states` function callback. This is to avoid the situation where the hidden state from a certain location and time is passed to a different location at a different time."
781    ]
782   },
783   {
784    "cell_type": "code",
785    "execution_count": null,
786    "id": "4f8206d9-e36d-4d96-adad-8dffff19f2bd",
787    "metadata": {},
788    "outputs": [],
789    "source": [
790     "print(f\"Sequence 1, Batch {min_shape+1}\")\n",
791     "print(XX[0+(min_shape)*batch_size,:,:])\n",
792     "print(f\"Sequence 1, Batch {min_shape+2}\")\n",
793     "print(XX[0+(min_shape)*batch_size+batch_size,:,:])\n",
794     "print(f\"Sequence 1, Batch {min_shape+3}\")\n",
795     "print(XX[0+(min_shape)*batch_size+2*batch_size,:,:])\n",
796     "print(\"⋮\")"
797    ]
798   },
799   {
800    "cell_type": "code",
801    "execution_count": null,
802    "id": "87af2279-5163-4d66-99f7-958be70afa4f",
803    "metadata": {},
804    "outputs": [],
805    "source": [
806     "print(f\"Sequence 2, Batch {min_shape+1}\")\n",
807     "print(XX[1+(min_shape)*batch_size,:,:])\n",
808     "print(f\"Sequence 2, Batch {min_shape+2}\")\n",
809     "print(XX[1+(min_shape)*batch_size+batch_size,:,:])\n",
810     "print(f\"Sequence 2, Batch {min_shape+3}\")\n",
811     "print(XX[1+(min_shape)*batch_size+2*batch_size,:,:])\n",
812     "print(\"⋮\")"
813    ]
814   },
815   {
816    "cell_type": "code",
817    "execution_count": null,
818    "id": "85a34346-5eca-4820-861a-3c574a601ad9",
819    "metadata": {},
820    "outputs": [],
821    "source": []
822   },
823   {
824    "cell_type": "code",
825    "execution_count": null,
826    "id": "09eb0cea-531e-47e9-9586-561b1ca62c49",
827    "metadata": {},
828    "outputs": [],
829    "source": [
830     "def format_spatial_stateful_data(data, loc_ids, start_times, hours, sequence_length, batch_size, verbose=True):\n",
831     "    # loc_ids: list or array of unique location names\n",
832     "\n",
833     "    # Create array of location indices\n",
834     "    inds = np.arange(len(loc_inds))\n",
835     "\n",
836     "    # Set up structure of times and batches\n",
837     "    loc_batches, t_batch =  batch_setup(inds, batch_size), batch_setup(start_times, batch_size)\n",
838     "    if verbose:\n",
839     "        print(loc_batches)\n",
840     "        print(t_batch)\n",
841     "\n",
842     "    # Loop over batches and construct sequences w staircase_2 \n",
843     "    Xs = []\n",
844     "    for i in range(0, batch_size):\n",
845     "        locs = loc_batches[i]\n",
846     "        ts = t_batch[i]\n",
847     "        for j in range(0, len(locs)):\n",
848     "            loc = loc_ids[locs[j]]\n",
849     "            t0 = ts[j]\n",
850     "            # Subset data to given location and time from t0 to t0+hours\n",
851     "            dat_temp = data2[(data2[:,1] == loc) & (data2[:,0]>= t0) & (data2[:,0]< t0+hours)]\n",
852     "            # Format sequences\n",
853     "            Xi, yi = staircase_2(\n",
854     "                dat_temp, \n",
855     "                y, \n",
856     "                timesteps = sequence_length, \n",
857     "                batch_size = 1,  # note: using 1 here to format sequences for a single location, not same as target batch size for training data\n",
858     "                verbose=False)\n",
859     "        \n",
860     "            Xs.append(Xi)\n",
861     "    \n",
862     "    return"
863    ]
864   },
865   {
866    "cell_type": "code",
867    "execution_count": null,
868    "id": "bc84c2c6-bfad-4262-bb4b-b9e9f74cfd9f",
869    "metadata": {},
870    "outputs": [],
871    "source": []
872   },
873   {
874    "cell_type": "code",
875    "execution_count": null,
876    "id": "1e3b033c-be48-41fb-b63d-3b9ed661328f",
877    "metadata": {},
878    "outputs": [],
879    "source": []
880   },
881   {
882    "cell_type": "markdown",
883    "id": "4a3db8d5-3df8-4d23-b2ba-47c9bd90e97c",
884    "metadata": {},
885    "source": [
886     "### Example: Fewer Locations than Batch Size"
887    ]
888   },
889   {
890    "cell_type": "code",
891    "execution_count": null,
892    "id": "08954a73-5b7b-4c05-bbdc-6324b07b389c",
893    "metadata": {},
894    "outputs": [],
895    "source": [
896     "# X1, y1 = staircase_2(data2[(data2[:,1] == 1) & (data2[:,0]>= 0)], y, timesteps = 5, batch_size = 1, verbose=False)\n",
897     "# X2, y2 = staircase_2(data2[(data2[:,1] == 2) & (data2[:,0]>= 1)], y, timesteps = 5, batch_size = 1, verbose=False)\n",
898     "# X3, y3 = staircase_2(data2[(data2[:,1] == 3) & (data2[:,0]>= 2)], y, timesteps = 5, batch_size = 1, verbose=False)\n",
899     "\n",
900     "# Xs = [X1, X2, X3]"
901    ]
902   },
903   {
904    "cell_type": "markdown",
905    "id": "f0905891-6b67-4a47-93f7-222dedcf74cb",
906    "metadata": {},
907    "source": [
908     "## References\n",
909     "\n",
910     "https://d2l.ai/chapter_recurrent-neural-networks/bptt.html\n",
911     "\n",
912     "https://www.tensorflow.org/guide/keras/working_with_rnns#cross-batch_statefulness\n",
913     "\n",
914     "Tensorflow `timeseries_dataset_from_array` tutorial: https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/timeseries_dataset_from_array\n",
915     "\n",
916     "Wiki BPTT: https://en.wikipedia.org/wiki/Backpropagation_through_time#:~:text=Backpropagation%20through%20time%20(BPTT)%20is,independently%20derived%20by%20numerous%20researchers.\n",
917     "\n",
918     "https://machinelearningmastery.com/understanding-stateful-lstm-recurrent-neural-networks-python-keras/"
919    ]
920   },
921   {
922    "cell_type": "code",
923    "execution_count": null,
924    "id": "b3898af7-b504-4308-87b2-54c217709a54",
925    "metadata": {},
926    "outputs": [],
927    "source": []
928   }
929  ],
930  "metadata": {
931   "kernelspec": {
932    "display_name": "Python 3 (ipykernel)",
933    "language": "python",
934    "name": "python3"
935   },
936   "language_info": {
937    "codemirror_mode": {
938     "name": "ipython",
939     "version": 3
940    },
941    "file_extension": ".py",
942    "mimetype": "text/x-python",
943    "name": "python",
944    "nbconvert_exporter": "python",
945    "pygments_lexer": "ipython3",
946    "version": "3.12.5"
947   }
948  },
949  "nbformat": 4,
950  "nbformat_minor": 5