# v2 training and prediction class infrastructure

# Environment
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import sys
from tensorflow.keras.callbacks import Callback, EarlyStopping, TerminateOnNaN
# from sklearn.metrics import mean_squared_error
import logging
from tensorflow.keras.layers import LSTM, SimpleRNN, Input, Dropout, Dense
# Local modules
import reproducibility
# from utils import print_dict_summary
from abc import ABC, abstractmethod
from utils import hash2, all_items_exist, hash_ndarray, hash_weights
from data_funcs import rmse, plot_data, compare_dicts
import copy
# import yaml
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import warnings
#*************************************************************************************
# Data Formatting Functions

def staircase(x, y, timesteps, datapoints, return_sequences=False, verbose=False):
    # x [datapoints, features]  all inputs
    # y [datapoints, outputs]
    # timesteps: split x and y into samples of length timesteps, shifted by 1
    # datapoints: number of timesteps to use for training, no more than y.shape[0]
    if verbose:
        print('staircase: shape x = ', x.shape)
        print('staircase: shape y = ', y.shape)
        print('staircase: timesteps=', timesteps)
        print('staircase: datapoints=', datapoints)
        print('staircase: return_sequences=', return_sequences)
    outputs = y.shape[1]
    features = x.shape[1]
    samples = datapoints - timesteps + 1
    if verbose:
        print('staircase: samples=', samples, 'timesteps=', timesteps, 'features=', features)
    x_train = np.empty([samples, timesteps, features])
    if return_sequences:
        if verbose:
            print('returning all timesteps in a sample')
        y_train = np.empty([samples, timesteps, outputs])  # all
        for i in range(samples):
            for k in range(timesteps):
                x_train[i, k, :] = x[i + k, :]
                y_train[i, k, :] = y[i + k, :]
    else:
        if verbose:
            print('returning only the last timestep in a sample')
        y_train = np.empty([samples, outputs])
        for i in range(samples):
            for k in range(timesteps):
                x_train[i, k, :] = x[i + k, :]
            y_train[i, :] = y[i + timesteps - 1, :]

    return x_train, y_train
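
# Example (illustrative sketch): sliding-window samples from toy data.
# With 10 datapoints and timesteps=5 this yields 10-5+1 = 6 samples:
#
#   x = np.arange(30.).reshape(10, 3)    # (datapoints, features)
#   y = np.arange(10.).reshape(10, 1)    # (datapoints, outputs)
#   x_train, y_train = staircase(x, y, timesteps=5, datapoints=10)
#   print(x_train.shape, y_train.shape)  # (6, 5, 3) (6, 1)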
def staircase_2(x, y, timesteps, batch_size=None, trainsteps=np.inf, return_sequences=False, verbose=False):
    # create RNN training data in multiple batches
    # input:
    #     x (,features)
    #     y (,outputs)
    #     timesteps: split x and y into sequences of length timesteps
    #                a.k.a. lookback or sequence_length

    # print params if verbose

    if batch_size is None:
        raise ValueError('staircase_2 requires batch_size')
    if verbose:
        print('staircase_2: shape x = ', x.shape)
        print('staircase_2: shape y = ', y.shape)
        print('staircase_2: timesteps=', timesteps)
        print('staircase_2: batch_size=', batch_size)
        print('staircase_2: return_sequences=', return_sequences)

    nx, features = x.shape
    ny, outputs = y.shape
    datapoints = min(nx, ny, trainsteps)
    if verbose:
        print('staircase_2: datapoints=', datapoints)

    # sequence j in a given batch is assumed to be the continuation of sequence j in the previous batch
    # https://www.tensorflow.org/guide/keras/working_with_rnns  Cross-batch statefulness

    # example with timesteps=3 batch_size=3 datapoints=15
    #     batch 0: [0 1 2]      [1 2 3]      [2 3 4]
    #     batch 1: [3 4 5]      [4 5 6]      [5 6 7]
    #     batch 2: [6 7 8]      [7 8 9]      [8 9 10]
    #     batch 3: [9 10 11]    [10 11 12]   [11 12 13]
    #     batch 4: [12 13 14]   when the data runs out, the last batch can be shorter

    # TODO: implement for multiple locations, same starting time for each batch
    #              Loc 1        Loc 2        Loc 3
    #     batch 0: [0 1 2]      [0 1 2]      [0 1 2]
    #     batch 1: [3 4 5]      [3 4 5]      [3 4 5]
    #     batch 2: [6 7 8]      [6 7 8]      [6 7 8]
    # TODO: second epoch, shift starting time at batch 0 in time

    # TODO: implement for multiple locations, different starting times for each batch
    #              Loc 1        Loc 2        Loc 3
    #     batch 0: [0 1 2]      [1 2 3]      [2 3 4]
    #     batch 1: [3 4 5]      [4 5 6]      [5 6 7]
    #     batch 2: [6 7 8]      [7 8 9]      [8 9 10]

    # the first sample in batch j starts from timesteps*j and ends with timesteps*(j+1)-1
    # e.g. the final hidden state of the rnn after the sequence of steps [0 1 2] in batch 0
    # becomes the starting hidden state of the rnn in the sequence of steps [3 4 5] in batch 1, etc.

    # sample [0 1 2] means the rnn is used twice to map state 0 -> 1 -> 2
    # the state at time 0 is fixed, but the state is considered a variable at times 1 and 2
    # the loss is computed from the output at time 2, and the gradient of the loss by chain rule ends at time 0 because the state there is a constant -> derivative is zero
    # sample [3 4 5] means the rnn is used twice to map state 3 -> 4 -> 5; the state at time 3 is fixed to the output of the first sequence [0 1 2]
    # the loss is computed from the output at time 5, and the gradient of the loss by chain rule ends at time 3 because the state there is considered constant -> derivative is zero
    # how is the gradient computed? presumably keras adds the gradient wrt the weights at 2 5 8 ..., 3 6 9 ..., 4 7 ... and uses that to update the weights
    # there is only one set of weights: h(2) = f(h(1),w), h(1) = f(h(0),w), but w is always the same
    # each column is one successive evaluation of h(n+1) = f(h(n),w) for n = n_start, n_start+1,...
    # this cannot be evaluated efficiently on a gpu alone because a gpu is a parallel processor
    # think of it as each column served by one thread; the threads are independent because they execute in parallel, and there needs to be a large number of threads (32 is a good number)
    # each batch consists of independent calculations
    # but it can depend on the result of the previous batch (that's the recurrent part)

    max_batches = datapoints // timesteps
    max_sequences = max_batches * batch_size

    if verbose:
        print('staircase_2: max_batches=', max_batches)
        print('staircase_2: max_sequences=', max_sequences)

    x_train = np.zeros((max_sequences, timesteps, features))
    if return_sequences:
        y_train = np.empty((max_sequences, timesteps, outputs))
    else:
        y_train = np.empty((max_sequences, outputs))

    # build the sequences
    k = 0  # sequence counter
    for i in range(max_batches):
        for j in range(batch_size):
            begin = i * timesteps + j
            next = begin + timesteps
            if next > datapoints:
                break
            if verbose:
                print('sequence', k, 'batch', i, 'sample', j, 'data', begin, 'to', next - 1)
            x_train[k, :, :] = x[begin:next, :]
            if return_sequences:
                y_train[k, :, :] = y[begin:next, :]
            else:
                y_train[k, :] = y[next - 1, :]
            k += 1
    if verbose:
        print('staircase_2: shape x_train = ', x_train.shape)
        print('staircase_2: shape y_train = ', y_train.shape)
        print('staircase_2: sequences generated', k)
        print('staircase_2: batch_size=', batch_size)
    k = (k // batch_size) * batch_size
    if verbose:
        print('staircase_2: removing partial and empty batches at the end, keeping', k)
    x_train = x_train[:k, :, :]
    if return_sequences:
        y_train = y_train[:k, :, :]
    else:
        y_train = y_train[:k, :]

    if verbose:
        print('staircase_2: shape x_train = ', x_train.shape)
        print('staircase_2: shape y_train = ', y_train.shape)

    return x_train, y_train
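
# Example (illustrative sketch): reproduce the timesteps=3, batch_size=3,
# datapoints=15 layout from the comments above. 13 sequences are generated
# and the trailing partial batch is dropped, keeping 12:
#
#   x = np.arange(30.).reshape(15, 2)
#   y = np.arange(15.).reshape(15, 1)
#   x_train, y_train = staircase_2(x, y, timesteps=3, batch_size=3)
#   print(x_train.shape, y_train.shape)  # (12, 3, 2) (12, 1)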
# Dictionary of scalers, used to avoid multiple object creation and to avoid multiple if statements
scalers = {
    'minmax': MinMaxScaler(),
    'standard': StandardScaler()
}
def batch_setup(ids, batch_size):
    """
    Sets up stateful batched training data scheme for RNN training.

    This function takes a list or array of identifiers (`ids`) and divides them into batches of a specified size (`batch_size`). If the last batch does not have enough elements to meet the `batch_size`, the function will loop back to the start of the identifiers and continue filling the batch until it reaches the required size.

    Parameters:
    -----------
    ids : list or numpy array
        A list or numpy array containing the ids to be batched.
    batch_size : int
        The desired size of each batch.

    Returns:
    --------
    batches : list of lists
        A list where each element is a batch (itself a list) of identifiers. Each batch will contain exactly `batch_size` elements.

    Example:
    --------
    >>> ids = [1, 2, 3, 4, 5]
    >>> batch_size = 3
    >>> batch_setup(ids, batch_size)
    [[1, 2, 3], [4, 5, 1]]

    Notes:
    ------
    - If `ids` is shorter than `batch_size`, the returned list will contain a single batch where identifiers are repeated from the start of `ids` until the batch is filled.
    """
    # Ensure ids is a numpy array
    x = np.array(ids)

    # Initialize the list to hold the batches
    batches = []

    # Use a loop to slice the list/array into batches
    for i in range(0, len(x), batch_size):
        batch = list(x[i:i + batch_size])

        # If the batch is not full, continue from the start
        while len(batch) < batch_size:
            # Calculate the remaining number of items needed
            remaining = batch_size - len(batch)
            # Append the needed number of items from the start of the array
            batch.extend(x[:remaining])

        batches.append(batch)

    return batches
def staircase_spatial(X, y, batch_size, timesteps, hours=None, start_times=None, verbose=True):
    """
    Prepares spatially formatted time series data for RNN training by creating batches of sequences across different locations, stacked to be compatible with stateful models.

    This function processes multi-location time series data by slicing it into batches and formatting it to fit into a recurrent neural network (RNN) model. It utilizes a staircase-like approach to prepare sequences for each location and then interlaces them to align with stateful RNN structures.

    Parameters:
    -----------
    X : list of numpy arrays
        A list where each element is a numpy array containing features for a specific location. The shape of each array is `(total_time_steps, features)`.
    y : list of numpy arrays
        A list where each element is a numpy array containing the target values for a specific location. The shape of each array is `(total_time_steps,)`.
    batch_size : int
        The number of sequences to include in each batch.
    timesteps : int
        The number of time steps to include in each sequence for the RNN.
    hours : int, optional
        The length of each time series to consider for each location. If `None`, it defaults to the minimum length of `y` across all locations.
    start_times : numpy array, optional
        The initial time step for each location. If `None`, it defaults to an array starting from 0 and incrementing by 1 for each location.
    verbose : bool, optional
        If `True`, prints additional information during processing. Default is `True`.

    Returns:
    --------
    XX : numpy array
        A 3D numpy array with shape `(total_sequences, timesteps, features)` containing the prepared feature sequences for all locations.
    yy : numpy array
        A 2D numpy array with shape `(total_sequences, 1)` containing the corresponding target values for all locations.
    n_seqs : int
        Number of sequences per location. Used to reset states when the location changes: the hidden state of the RNN will be reset after n_seqs batches.

    Notes:
    ------
    - The function handles spatially distributed time series data by batching and formatting it for stateful RNNs.
    - `hours` determines how much of the time series is used for each location. If not provided, it defaults to the shortest series in `y`.
    - If `start_times` is not provided, it assumes each location starts its series at progressively later time steps.
    - The `batch_setup` function is used internally to manage the creation of location and time step batches.
    - The returned feature sequences `XX` and target sequences `yy` are interlaced to align with the expected input format of stateful RNNs.
    """
    # Generate ids based on number of distinct timeseries provided
    n_loc = len(y)  # assuming each list entry for y is a separate location
    loc_ids = np.arange(n_loc)

    # Generate hours and start_times if None
    if hours is None:
        print("Setting total hours to minimum length of y in provided dictionary")
        hours = min(len(yi) for yi in y)
    if start_times is None:
        print("Setting start times to offset by 1 hour by location")
        start_times = np.arange(n_loc)
    # Set up batches
    loc_batch, t_batch = batch_setup(loc_ids, batch_size), batch_setup(start_times, batch_size)
    if verbose:
        print(f"Location ID Batches: {loc_batch}")
        print(f"Start Times for Batches: {t_batch}")

    # Loop over batches and construct with staircase_2
    Xs = []
    ys = []
    for i in range(0, len(loc_batch)):
        locs_i = loc_batch[i]
        ts = t_batch[i]
        for j in range(0, len(locs_i)):
            t0 = ts[j]
            tend = t0 + hours
            # Subset data to given location and time from t0 to t0+hours
            X_temp = X[locs_i[j]][t0:tend, :]
            y_temp = y[locs_i[j]][t0:tend].reshape(-1, 1)

            # Format sequences
            Xi, yi = staircase_2(
                X_temp,
                y_temp,
                timesteps=timesteps,
                batch_size=1,  # note: using 1 here to format sequences for a single location, not the same as the target batch size for training data
                verbose=False)

            Xs.append(Xi)
            ys.append(yi)

    # Drop incomplete batches
    lens = [yi.shape[0] for yi in ys]
    n_seqs = min(lens)
    if verbose:
        print(f"Minimum number of sequences by location: {n_seqs}")
        print("Applying minimum length to other arrays.")
    Xs = [Xi[:n_seqs] for Xi in Xs]
    ys = [yi[:n_seqs] for yi in ys]

    # Interlace arrays to match stateful structure
    n_features = Xi.shape[2]
    XXs = []
    yys = []
    for i in range(0, len(loc_batch)):
        locs_i = loc_batch[i]
        XXi = np.empty((Xs[0].shape[0] * batch_size, timesteps, n_features))
        yyi = np.empty((Xs[0].shape[0] * batch_size, 1))
        for j in range(0, len(locs_i)):
            XXi[j::(batch_size)] = Xs[locs_i[j]]
            yyi[j::(batch_size)] = ys[locs_i[j]]
        XXs.append(XXi)
        yys.append(yyi)
    yy = np.concatenate(yys, axis=0)
    XX = np.concatenate(XXs, axis=0)

    if verbose:
        print(f"Spatially Formatted X Shape: {XX.shape}")
        print(f"Spatially Formatted y Shape: {yy.shape}")

    return XX, yy, n_seqs
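
# Example (illustrative sketch): three locations, each a (hours, features)
# array, formatted so sequences from the same location recur at the same
# within-batch position, as stateful RNNs require:
#
#   X = [np.random.rand(100, 2) for _ in range(3)]   # 3 locations, 2 features
#   y = [np.random.rand(100) for _ in range(3)]
#   XX, yy, n_seqs = staircase_spatial(X, y, batch_size=3, timesteps=5)
#   # XX: (total_sequences, timesteps, features), yy: (total_sequences, 1)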
#***********************************************************************************************
### RNN Class Functionality

class RNNParams(dict):
    """
    A custom dictionary class for handling RNN parameters. Automatically calculates certain params based on others. Overwrites the update method to protect from incompatible parameter choices. Inherits from dict.
    """
    def __init__(self, input_dict):
        """
        Initializes the RNNParams instance and runs checks and shape calculations.

        Parameters:
        -----------
        input_dict : dict
            A dictionary containing RNN parameters.
        """
        super().__init__(input_dict)
        # Automatically run checks on initialization
        self.run_checks()
        # Automatically calculate shapes on initialization
        self.calc_param_shapes()

    def run_checks(self, verbose=True):
        """
        Validates that required keys exist and are of the correct type.

        Parameters:
        -----------
        verbose : bool, optional
            If True, prints status messages. Default is True.
        """
        print("Checking params...")
        # Keys must exist and be integers
        int_keys = [
            'batch_size', 'timesteps', 'rnn_layers',
            'rnn_units', 'dense_layers', 'dense_units', 'epochs'
        ]
        for key in int_keys:
            assert key in self, f"Missing required key: {key}"
            assert isinstance(self[key], int), f"Key '{key}' must be an integer"

        # Keys must exist and be lists
        list_keys = ['activation', 'features_list', 'dropout']
        for key in list_keys:
            assert key in self, f"Missing required key: {key}"
            assert isinstance(self[key], list), f"Key '{key}' must be a list"

        # Keys must exist and be floats
        float_keys = ['learning_rate', 'train_frac', 'val_frac']
        for key in float_keys:
            assert key in self, f"Missing required key: {key}"
            assert isinstance(self[key], float), f"Key '{key}' must be a float"

        print("Input dictionary passed all checks.")

    def calc_param_shapes(self, verbose=True):
        """
        Calculates and updates the shapes of certain parameters based on input data.

        Parameters:
        -----------
        verbose : bool, optional
            If True, prints status messages. Default is True.
        """
        if verbose:
            print("Calculating shape params based on features list, timesteps, and batch size")
            print(f"Input Feature List: {self['features_list']}")
            print(f"Input Timesteps: {self['timesteps']}")
            print(f"Input Batch Size: {self['batch_size']}")

        n_features = len(self['features_list'])
        batch_shape = (self["batch_size"], self["timesteps"], n_features)
        if verbose:
            print("Calculated params:")
            print(f"Number of features: {n_features}")
            print(f"Batch Shape: {batch_shape}")

        # Update the dictionary
        super().update({
            'n_features': n_features,
            'batch_shape': batch_shape
        })
        if verbose:
            print(self)

    def update(self, *args, verbose=True, **kwargs):
        """
        Updates the dictionary, with restrictions on certain keys, and recalculates shapes if necessary.

        Parameters:
        -----------
        verbose : bool, optional
            If True, prints status messages. Default is True.
        """
        # Prevent updating n_features and batch_shape
        restricted_keys = {'n_features', 'batch_shape'}
        keys_to_check = {'features_list', 'timesteps', 'batch_size'}

        # Check for restricted keys in args
        if args:
            if isinstance(args[0], dict):
                if restricted_keys & args[0].keys():
                    raise KeyError(f"Cannot directly update keys: {restricted_keys & args[0].keys()}, \n Instead update one of: {keys_to_check}")
            elif isinstance(args[0], (tuple, list)) and all(isinstance(i, tuple) and len(i) == 2 for i in args[0]):
                if restricted_keys & {k for k, v in args[0]}:
                    raise KeyError(f"Cannot directly update keys: {restricted_keys & {k for k, v in args[0]}}, \n Instead update one of: {keys_to_check}")

        # Check for restricted keys in kwargs
        if restricted_keys & kwargs.keys():
            raise KeyError(f"Cannot update restricted keys: {restricted_keys & kwargs.keys()}")

        # Track if specific keys are updated
        keys_updated = set()

        # Update using the standard dict update method
        if args:
            if isinstance(args[0], dict):
                keys_updated.update(args[0].keys() & keys_to_check)
            elif isinstance(args[0], (tuple, list)) and all(isinstance(i, tuple) and len(i) == 2 for i in args[0]):
                keys_updated.update(k for k, v in args[0] if k in keys_to_check)
        if kwargs:
            keys_updated.update(kwargs.keys() & keys_to_check)

        # Call the parent update method
        super().update(*args, **kwargs)

        # Recalculate shapes if necessary
        if keys_updated:
            self.calc_param_shapes(verbose=verbose)
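
# Example (illustrative sketch): the keys below are the ones run_checks
# requires; the values are hypothetical. Derived keys 'n_features' and
# 'batch_shape' are computed automatically and protected from direct update.
#
#   params = RNNParams({
#       'batch_size': 32, 'timesteps': 5,
#       'rnn_layers': 1, 'rnn_units': 20,
#       'dense_layers': 1, 'dense_units': 5, 'epochs': 20,
#       'activation': ['tanh', 'tanh'],
#       'features_list': ['Ed', 'Ew', 'rain'],
#       'dropout': [0.2, 0.2],
#       'learning_rate': 0.001, 'train_frac': 0.8, 'val_frac': 0.1
#   })
#   params.update({'timesteps': 10})          # recalculates 'batch_shape'
#   params.update({'batch_shape': (1,1,1)})   # raises KeyError: protected key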
## Class for handling input data
class RNNData(dict):
    """
    A custom dictionary class for managing RNN data, with validation, scaling, and train-test splitting functionality.
    """
    required_keys = {"loc", "time", "X", "y", "features_list"}
    def __init__(self, input_dict, scaler=None, features_list=None):
        """
        Initializes the RNNData instance, performs checks, and prepares data.

        Parameters:
        -----------
        input_dict : dict
            A dictionary containing the initial data.
        scaler : str, optional
            The name of the scaler to be used (e.g., 'minmax', 'standard'). Default is None.
        features_list : list, optional
            A subset of features to be used. Default is None, which means all features.
        """
        # Copy to avoid changing external input
        input_data = input_dict.copy()
        # Initialize inherited dict class
        super().__init__(input_data)

        # Check if input data is one timeseries dataset or multiple
        if type(self.loc['STID']) == str:
            self.spatial = False
            print("Input data is single timeseries.")
        elif type(self.loc['STID']) == list:
            self.spatial = True
            print("Input data from multiple timeseries.")
        else:
            raise KeyError("Input locations not list or single string")

        # Set up Data Scaling
        self.scaler = None
        if scaler is not None:
            self.set_scaler(scaler)

        # Rename and define other stuff.
        if self.spatial:
            self['hours'] = min(arr.shape[0] for arr in self.y)
        else:
            self['hours'] = len(self['y'])

        self['all_features_list'] = self.pop('features_list')
        if features_list is None:
            print("Using all input features.")
            self.features_list = self.all_features_list
        else:
            self.features_list = features_list
        # self.run_checks()
        self.__dict__.update(self)
    # TODO: Fix checks for multilocation
    def run_checks(self, verbose=True):
        """
        Validates that required keys are present and checks the integrity of data shapes.

        Parameters:
        -----------
        verbose : bool, optional
            If True, prints status messages. Default is True.
        """
        missing_keys = self.required_keys - self.keys()
        if missing_keys:
            raise KeyError(f"Missing required keys: {missing_keys}")
        # # Check y 1-d
        # y_shape = np.shape(self.y)
        # if not (len(y_shape) == 1 or (len(y_shape) == 2 and y_shape[1] == 1)):
        #     raise ValueError(f"'y' must be one-dimensional, with shape (N,) or (N, 1). Current shape is {y_shape}.")

        # # Check if 'hours' is provided and matches len(y)
        # if 'hours' in self:
        #     if self.hours != len(self.y):
        #         raise ValueError(f"Provided 'hours' value {self.hours} does not match the length of 'y', which is {len(self.y)}.")
        # Check desired subset of features is in all input features
        if not all_items_exist(self.features_list, self.all_features_list):
            raise ValueError(f"Provided 'features_list' {self.features_list} has elements not in input features.")
    def set_scaler(self, scaler):
        """
        Sets the scaler to be used for data normalization.

        Parameters:
        -----------
        scaler : str
            The name of the scaler (e.g., 'minmax', 'standard').
        """
        recognized_scalers = ['minmax', 'standard']
        if scaler in recognized_scalers:
            print(f"Setting data scaler: {scaler}")
            self.scaler = scalers[scaler]
        else:
            raise ValueError(f"Unrecognized scaler '{scaler}'. Recognized scalers are: {recognized_scalers}.")
    def train_test_split(self, train_frac, val_frac=0.0, subset_features=True, features_list=None, split_time=True, split_space=False, verbose=True):
        """
        Splits the data into training, validation, and test sets.

        Parameters:
        -----------
        train_frac : float
            The fraction of data to be used for training.
        val_frac : float, optional
            The fraction of data to be used for validation. Default is 0.0.
        subset_features : bool, optional
            If True, subsets the data to the specified features list. Default is True.
        features_list : list, optional
            A list of features to use for subsetting. Default is None.
        split_time : bool, optional
            Whether to split the data based on time. Default is True.
        split_space : bool, optional
            Whether to split the data based on space. Default is False.
        verbose : bool, optional
            If True, prints status messages. Default is True.
        """
        # Indicate whether multi timeseries or not
        spatial = self.spatial

        # Extract data to desired features
        X = self.X.copy()
        y = self.y.copy()
        if subset_features:
            if verbose and self.features_list != self.all_features_list:
                print(f"Subsetting input data to features_list: {self.features_list}")
            # Indices to subset all features with based on params features
            indices = []
            for item in self.features_list:
                if item in self.all_features_list:
                    indices.append(self.all_features_list.index(item))
                else:
                    print(f"Warning: feature name '{item}' not found in list of all features from input data")
            if spatial:
                X = [Xi[:, indices] for Xi in X]
            else:
                X = X[:, indices]

        # Setup train/test in time
        train_ind = int(np.floor(self.hours * train_frac)); self.train_ind = train_ind
        test_ind = int(train_ind + round(self.hours * val_frac)); self.test_ind = test_ind

        # Check for any potential issues with indices
        if test_ind > self.hours:
            print(f"Setting test index to {self.hours}")
            test_ind = self.hours
        if train_ind >= test_ind:
            raise ValueError("Train index must be less than test index.")

        # Training data from 0 to train_ind
        # Validation data from train_ind to test_ind
        # Test data from test_ind to end
        if spatial:
            self.X_train = [Xi[:train_ind] for Xi in X]
            self.y_train = [yi[:train_ind].reshape(-1,1) for yi in y]
            if val_frac > 0:
                self.X_val = [Xi[train_ind:test_ind] for Xi in X]
                self.y_val = [yi[train_ind:test_ind].reshape(-1,1) for yi in y]
            self.X_test = [Xi[test_ind:] for Xi in X]
            self.y_test = [yi[test_ind:].reshape(-1,1) for yi in y]
        else:
            self.X_train = X[:train_ind]
            self.y_train = y[:train_ind].reshape(-1,1) # assumes y 1-d, change this if vector output
            if val_frac > 0:
                self.X_val = X[train_ind:test_ind]
                self.y_val = y[train_ind:test_ind].reshape(-1,1) # assumes y 1-d, change this if vector output
            self.X_test = X[test_ind:]
            self.y_test = y[test_ind:].reshape(-1,1) # assumes y 1-d, change this if vector output

        # Print statements if verbose
        if verbose:
            print(f"Train index: 0 to {train_ind}")
            print(f"Validation index: {train_ind} to {test_ind}")
            print(f"Test index: {test_ind} to {self.hours}")

            if spatial:
                print(f"X_train[0] shape: {self.X_train[0].shape}, y_train[0] shape: {self.y_train[0].shape}")
                print(f"X_val[0] shape: {self.X_val[0].shape}, y_val[0] shape: {self.y_val[0].shape}")
                print(f"X_test[0] shape: {self.X_test[0].shape}, y_test[0] shape: {self.y_test[0].shape}")
            else:
                print(f"X_train shape: {self.X_train.shape}, y_train shape: {self.y_train.shape}")
                print(f"X_val shape: {self.X_val.shape}, y_val shape: {self.y_val.shape}")
                print(f"X_test shape: {self.X_test.shape}, y_test shape: {self.y_test.shape}")
    def scale_data(self, verbose=True):
        """
        Scales the training data using the set scaler.

        Parameters:
        -----------
        verbose : bool, optional
            If True, prints status messages. Default is True.
        """
        # Indicate whether multi timeseries or not
        spatial = self.spatial
        if self.scaler is None:
            raise ValueError("Scaler is not set. Use 'set_scaler' method to set a scaler before scaling data.")
        if not hasattr(self, "X_train"):
            raise AttributeError("No X_train within object. Run train_test_split first. This is to avoid fitting the scaler with prediction data.")
        if verbose:
            print(f"Scaling training data with scaler {self.scaler}, fitting on X_train")

        if spatial:
            # Fit scaler on row-joined training data
            self.scaler.fit(np.vstack(self.X_train))
            # Transform data using fitted scaler
            self.X_train = [self.scaler.transform(Xi) for Xi in self.X_train]
            if hasattr(self, 'X_val'):
                self.X_val = [self.scaler.transform(Xi) for Xi in self.X_val]
            self.X_test = [self.scaler.transform(Xi) for Xi in self.X_test]
        else:
            # Fit the scaler on the training data
            self.scaler.fit(self.X_train)
            # Transform the data using the fitted scaler
            self.X_train = self.scaler.transform(self.X_train)
            if hasattr(self, 'X_val'):
                self.X_val = self.scaler.transform(self.X_val)
            self.X_test = self.scaler.transform(self.X_test)
    # NOTE: only works for non spatial
    def scale_all_X(self, verbose=True):
        """
        Scales all X data using the set scaler.

        Parameters:
        -----------
        verbose : bool, optional
            If True, prints status messages. Default is True.

        Returns:
        -------
        ndarray
            Scaled X matrix, subsetted to features_list.
        """
        if self.spatial:
            raise ValueError("Not implemented for spatial data")

        if self.scaler is None:
            raise ValueError("Scaler is not set. Use 'set_scaler' method to set a scaler before scaling data.")
        if verbose:
            print(f"Scaling all X data with scaler {self.scaler}, fitted on X_train")
        # Subset features
        indices = []
        for item in self.features_list:
            if item in self.all_features_list:
                indices.append(self.all_features_list.index(item))
            else:
                print(f"Warning: feature name '{item}' not found in list of all features from input data")
        X = self.X[:, indices]
        X = self.scaler.transform(X)

        return X
    def inverse_scale(self, return_X='all_hours', save_changes=False, verbose=True):
        """
        Inversely scales the data to its original form.

        Parameters:
        -----------
        return_X : str, optional
            Specifies what data to return after inverse scaling. Default is 'all_hours'.
        save_changes : bool, optional
            If True, updates the internal data with the inversely scaled values. Default is False.
        verbose : bool, optional
            If True, prints status messages. Default is True.
        """
        if verbose:
            print("Inverse scaling data...")
        X_train = self.scaler.inverse_transform(self.X_train)
        X_val = self.scaler.inverse_transform(self.X_val)
        X_test = self.scaler.inverse_transform(self.X_test)

        if save_changes:
            print("Inverse transformed data saved")
            self.X_train = X_train
            self.X_val = X_val
            self.X_test = X_test
        else:
            if verbose:
                print("Inverse scaled, but internal data not changed.")
        if verbose:
            print(f"Attempting to return {return_X}")
        if return_X == "all_hours":
            return np.concatenate((X_train, X_val, X_test), axis=0)
        else:
            print(f"Unrecognized or unimplemented return value {return_X}")
    def batch_reshape(self, timesteps, batch_size, hours=None, verbose=False):
        """
        Restructures input data for RNN training using batches and sequences.

        Parameters:
        ----------
        timesteps : int
            Number of timesteps, or sequence length, used for a single sequence in RNN training. Constitutes a single sample to the model.
        batch_size : int
            Number of sequences used within a batch of training.

        Returns:
        -------
        None
            This method reshapes the data in place.

        Raises:
        ------
        AttributeError
            If either 'X_train' or 'y_train' attributes do not exist within the instance.

        Notes:
        ------
        The reshaping method depends on self param "spatial".
        - spatial == False: Reshapes data assuming no spatial dimensions.
        - spatial == True: Reshapes data considering spatial dimensions.
        """
        if not hasattr(self, 'X_train') or not hasattr(self, 'y_train'):
            raise AttributeError("Both 'X_train' and 'y_train' must be set before reshaping batches.")

        # Indicator of spatial training scheme or not
        spatial = self.spatial

        if spatial:
            print(f"Reshaping spatial training data using batch size: {batch_size} and timesteps: {timesteps}")
            self.X_train, self.y_train, self.n_seqs = staircase_spatial(self.X_train, self.y_train, timesteps=timesteps, batch_size=batch_size, hours=hours, verbose=verbose)
            if hasattr(self, "X_val"):
                print(f"Reshaping validation data using batch size: {batch_size} and timesteps: {timesteps}")
                self.X_val, self.y_val, _ = staircase_spatial(self.X_val, self.y_val, timesteps=timesteps, batch_size=batch_size, hours=None, verbose=verbose)
        else:
            print(f"Reshaping training data using batch size: {batch_size} and timesteps: {timesteps}")
            self.X_train, self.y_train = staircase_2(self.X_train, self.y_train, timesteps=timesteps, batch_size=batch_size, verbose=verbose)
            if hasattr(self, "X_val"):
                print(f"Reshaping validation data using batch size: {batch_size} and timesteps: {timesteps}")
                self.X_val, self.y_val = staircase_2(self.X_val, self.y_val, timesteps=timesteps, batch_size=batch_size, verbose=verbose)
    def print_hashes(self, attrs_to_check=['X', 'y', 'X_train', 'y_train', 'X_val', 'y_val', 'X_test', 'y_test']):
        """
        Prints the hash of specified data attributes.

        Parameters:
        -----------
        attrs_to_check : list, optional
            A list of attribute names to hash and print. Default includes 'X', 'y', and split data.
        """
        for attr in attrs_to_check:
            if hasattr(self, attr):
                value = getattr(self, attr)
                if self.spatial:
                    pass
                else:
                    print(f"Hash of {attr}: {hash_ndarray(value)}")
    def __getattr__(self, key):
        """
        Allows attribute-style access to dictionary keys, i.e. enables the "." operator for getting elements.
        """
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"'rnn_data' object has no attribute '{key}'")

    def __setitem__(self, key, value):
        """
        Ensures dictionary and attribute updates stay in sync for required keys.
        """
        super().__setitem__(key, value) # Update the dictionary
        if key in self.required_keys:
            super().__setattr__(key, value) # Ensure the attribute is updated as well

    def __setattr__(self, key, value):
        """
        Ensures dictionary keys are updated when setting attributes.
        """
        self[key] = value
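
# Example (illustrative sketch): typical preprocessing pipeline, assuming
# `dat` is an input dictionary with the required keys
# ('loc', 'time', 'X', 'y', 'features_list'):
#
#   rnn_dat = RNNData(dat, scaler='standard', features_list=['Ed', 'Ew', 'rain'])
#   rnn_dat.train_test_split(train_frac=0.8, val_frac=0.1)
#   rnn_dat.scale_data()
#   rnn_dat.batch_reshape(timesteps=5, batch_size=32)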
# Function to check reproducibility hashes, environment info, and model parameters
def check_reproducibility(dict0, params, m_hash, w_hash):
    """
    Performs reproducibility checks on a model by comparing current settings and outputs with stored reproducibility information.

    Parameters:
    -----------
    dict0 : dict
        The data dictionary that should contain reproducibility information under the 'repro_info' attribute.
    params : dict
        The current model parameters to be checked against the reproducibility information.
    m_hash : str
        The hash of the current model predictions.
    w_hash : str
        The hash of the current fitted model weights.

    Returns:
    --------
    None
        The function returns None. It issues warnings if any reproducibility checks fail.

    Notes:
    ------
    - Checks are only performed if `dict0` contains the 'repro_info' attribute.
    - Issues warnings for mismatches in model weights, predictions, Python version, TensorFlow version, and model parameters.
    - Skips checks if physics-based initialization is used (not implemented).
    """
    if not hasattr(dict0, "repro_info"):
        warnings.warn("The provided data dictionary does not have the required 'repro_info' attribute. Not running reproducibility checks.")
        return

    repro_info = dict0.repro_info
    # Check Hashes
    if params['phys_initialize']:
        hashes = repro_info['phys_initialize']
        warnings.warn("Physics Initialization not implemented yet. Not running reproducibility checks.")
    else:
        hashes = repro_info['rand_initialize']
        print(f"Fitted weights hash: {w_hash} \n Reproducibility weights hash: {hashes['fitted_weights_hash']}")
        print(f"Model predictions hash: {m_hash} \n Reproducibility preds hash: {hashes['preds_hash']}")
        if (w_hash != hashes['fitted_weights_hash']) or (m_hash != hashes['preds_hash']):
            if w_hash != hashes['fitted_weights_hash']:
                warnings.warn("The fitted weights hash does not match the reproducibility weights hash.")
            if m_hash != hashes['preds_hash']:
                warnings.warn("The predictions hash does not match the reproducibility predictions hash.")
        else:
            print("***Reproducibility Checks passed - model weights and model predictions match expected.***")

    # Check Environment
    current_py_version = sys.version[0:6]
    current_tf_version = tf.__version__
    if current_py_version != repro_info['env_info']['py_version']:
        warnings.warn(f"Python version mismatch: Current Python version is {current_py_version}, "
                      f"expected {repro_info['env_info']['py_version']}.")

    if current_tf_version != repro_info['env_info']['tf_version']:
        warnings.warn(f"TensorFlow version mismatch: Current TensorFlow version is {current_tf_version}, "
                      f"expected {repro_info['env_info']['tf_version']}.")

    # Check Params
    repro_params = repro_info.get('params', {})

    for key, repro_value in repro_params.items():
        if key in params:
            if params[key] != repro_value:
                warnings.warn(f"Parameter mismatch for '{key}': Current value is {params[key]}, "
                              f"repro value is {repro_value}.")
        else:
            warnings.warn(f"Parameter '{key}' is missing in the current params.")

    return
class RNNModel(ABC):
    """
    Abstract base class for RNN models, providing structure for training, predicting, and running reproducibility checks.
    """
    def __init__(self, params: dict):
        """
        Initializes the RNNModel with the given parameters.

        Parameters:
        -----------
        params : dict
            A dictionary containing model parameters.
        """
        self.params = copy.deepcopy(params)
        self.params['n_features'] = len(self.params['features_list'])
        if type(self) is RNNModel:
            raise TypeError("RNNModel is an abstract class and cannot be instantiated directly")
        super().__init__()

    @abstractmethod
    def _build_model_train(self):
        """Abstract method to build the training model."""
        pass

    @abstractmethod
    def _build_model_predict(self, return_sequences=True):
        """Abstract method to build the prediction model. This model copies weights from the train model but with an input structure that allows for easier prediction of arbitrary-length timeseries. This model is not to be used for training, i.e. don't use it with .fit calls."""
        pass

    def is_stateful(self):
        """
        Checks whether any of the layers in the internal model (self.model_train) are stateful.

        Returns:
            bool: True if at least one layer in the model is stateful, False otherwise.

        This method iterates over all the layers in the model and checks if any of them
        have the 'stateful' attribute set to True. This is useful for determining if
        the model is designed to maintain state across batches during training.

        Example:
        --------
        model.is_stateful()
        """
        for layer in self.model_train.layers:
            if hasattr(layer, 'stateful') and layer.stateful:
                return True
        return False
    def fit(self, X_train, y_train, plot=True, plot_title='',
            weights=None, callbacks=[], validation_data=None, *args, **kwargs):
        """
        Trains the model on the provided training data.

        Parameters:
        -----------
        X_train : np.ndarray
            The input matrix data for training.
        y_train : np.ndarray
            The target vector data for training.
        plot : bool, optional
            If True, plots the training history. Default is True.
        plot_title : str, optional
            The title for the training plot. Default is an empty string.
        weights : optional
            Initial weights for the model. Default is None.
        callbacks : list, optional
            A list of callback functions to use during training. Default is an empty list.
        validation_data : tuple, optional
            Validation data to use during training, expected format (X_val, y_val). Default is None.
        """
        # verbose_fit argument is for printing out an update after each epoch, which gets very long
        # These print statements at the top could be turned off with a verbose argument, but then
        # there would be a bunch of different verbose params
        verbose_fit = self.params['verbose_fit']
        verbose_weights = self.params['verbose_weights']
        if verbose_weights:
            print(f"Training simple RNN with params: {self.params}")
        # X_train, y_train = self.format_train_data(X_train, y_train)
        if validation_data is not None:
            X_val, y_val = validation_data[0], validation_data[1]
        if verbose_weights:
            print(f"Formatted X_train hash: {hash_ndarray(X_train)}")
            print(f"Formatted y_train hash: {hash_ndarray(y_train)}")
            if validation_data is not None:
                print(f"Formatted X_val hash: {hash_ndarray(X_val)}")
                print(f"Formatted y_val hash: {hash_ndarray(y_val)}")
            print(f"Initial weights before training hash: {hash_weights(self.model_train)}")
        # Setup callbacks
        if self.params["reset_states"]:
            callbacks = callbacks + [ResetStatesCallback(self.params), TerminateOnNaN()]
        if validation_data is not None:
            print("Using early stopping callback.")
            callbacks = callbacks + [EarlyStoppingCallback(patience=self.params['early_stopping_patience'])]

        # Note: we overload the params here so that verbose_fit can be easily turned on/off at the .fit call

        # Evaluate Model once to set nonzero initial state
        if self.params["batch_size"] >= X_train.shape[0]:
            self.model_train(X_train)
        if validation_data is not None:
            history = self.model_train.fit(
                X_train, y_train + self.params['centering'][1],
                epochs=self.params['epochs'],
                batch_size=self.params['batch_size'],
                callbacks=callbacks,
                verbose=verbose_fit,
                validation_data=(X_val, y_val),
                *args, **kwargs
            )
        else:
            history = self.model_train.fit(
                X_train, y_train + self.params['centering'][1],
                epochs=self.params['epochs'],
                batch_size=self.params['batch_size'],
                callbacks=callbacks,
                verbose=verbose_fit,
                *args, **kwargs
            )

        if plot:
            self.plot_history(history, plot_title)
        if self.params["verbose_weights"]:
            print(f"Fitted Weights Hash: {hash_weights(self.model_train)}")

        # Update Weights for Prediction Model
        w_fitted = self.model_train.get_weights()
        self.model_predict.set_weights(w_fitted)

    def predict(self, X_test):
        """
        Generates predictions on the provided test data using the internal prediction model.

        Parameters:
        -----------
        X_test : np.ndarray
            The input data for generating predictions.

        Returns:
        --------
        np.ndarray
            The predicted values.
        """
        print("Predicting")
        X_test = self.format_pred_data(X_test)
        preds = self.model_predict.predict(X_test).flatten()
        return preds
    # DEPRECATED, USED WITHIN RNNData object now
    # def format_train_data(self, X, y, verbose=False):
    #     """
    #     Formats the training data for RNN input.
    #
    #     Parameters:
    #     -----------
    #     X : np.ndarray
    #         The input data.
    #     y : np.ndarray
    #         The target data.
    #     verbose : bool, optional
    #         If True, prints status messages. Default is False.
    #
    #     Returns:
    #     --------
    #     tuple
    #         Formatted input and target data.
    #     """
    #     X, y = staircase_2(X, y, timesteps = self.params["timesteps"], batch_size=self.params["batch_size"], verbose=verbose)
    #     return X, y

    def format_pred_data(self, X):
        """
        Formats the prediction data for RNN input.

        Parameters:
        -----------
        X : np.ndarray
            The input data.

        Returns:
        --------
        np.ndarray
            The formatted input data.
        """
        return np.reshape(X, (1, X.shape[0], self.params['n_features']))

    def plot_history(self, history, plot_title):
        """
        Plots the training history.

        Parameters:
        -----------
        history : History object
            The training history object from model fitting. Output of keras' .fit command.
        plot_title : str
            The title for the plot.
        """
        plt.figure()
        plt.semilogy(history.history['loss'], label='Training loss')
        if 'val_loss' in history.history:
            plt.semilogy(history.history['val_loss'], label='Validation loss')
        plt.title(f'{plot_title} Model loss')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend(loc='upper left')
        plt.show()
    def run_model(self, dict0, reproducibility_run=False, plot_period='all'):
        """
        Runs the RNN model, including training, prediction, and reproducibility checks.

        Parameters:
        -----------
        dict0 : dict
            The dictionary containing the input data and configuration.
        reproducibility_run : bool, optional
            If True, performs reproducibility checks after running the model. Default is False.

        Returns:
        --------
        tuple
            Model predictions and a dictionary of RMSE errors.
        """
        verbose_fit = self.params['verbose_fit']
        verbose_weights = self.params['verbose_weights']
        # Make copy to prevent changing in place
        dict1 = copy.deepcopy(dict0)
        if verbose_weights:
            print("Input data hashes, NOT formatted for rnn sequence/batches yet")
            dict1.print_hashes()
        # Extract Fields
        X_train, y_train, X_test, y_test = dict1['X_train'].copy(), dict1['y_train'].copy(), dict1["X_test"].copy(), dict1['y_test'].copy()
        if 'X_val' in dict1:
            X_val, y_val = dict1['X_val'].copy(), dict1['y_val'].copy()
        else:
            X_val = None
        case_id = dict1['case']

        # Fit model
        if X_val is None:
            self.fit(X_train, y_train, plot_title=case_id)
        else:
            self.fit(X_train, y_train, validation_data=(X_val, y_val), plot_title=case_id)
        # Generate Predictions:
        # run through training to get hidden state set properly for the forecast period
        X = dict1.scale_all_X()
        y = dict1.y.flatten()
        # Predict
        if verbose_weights:
            print("Predicting Training through Test")
            print(f"All X hash: {hash_ndarray(X)}")

        m = self.predict(X).flatten()
        if verbose_weights:
            print(f"Predictions Hash: {hash_ndarray(m)}")
        dict1['m'] = m
        dict0['m'] = m # add to outside env dictionary, should be the only place this happens

        if reproducibility_run:
            print("Checking Reproducibility")
            check_reproducibility(dict0, self.params, hash_ndarray(m), hash_weights(self.model_predict))

        # print(dict1.keys())
        # Plot final fit and data
        dict1['y'] = y
        plot_data(dict1, title="RNN", title2=dict1['case'], plot_period=plot_period)

        # Calculate Errors
        err = rmse(m, y)
        train_ind = dict1["train_ind"] # index of final training set value
        test_ind = dict1["test_ind"]   # index of first test set value

        err_train = rmse(m[:train_ind], y[:train_ind].flatten())
        err_pred = rmse(m[test_ind:], y[test_ind:].flatten())
        rmse_dict = {
            'all': err,
            'training': err_train,
            'prediction': err_pred
        }
        return m, rmse_dict
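
# Example (illustrative sketch): end-to-end run with the concrete RNN
# subclass defined below, assuming `params` is an RNNParams and `rnn_dat`
# is an RNNData prepared as in the pipeline sketch above (including a
# 'case' key used for plot titles):
#
#   rnn = RNN(params)
#   m, errs = rnn.run_model(rnn_dat)   # predictions and train/pred RMSE dict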
## Callbacks

# Helper functions for batch reset schedules
def calc_exp_intervals(bmin, bmax, n_epochs, force_bmax=True):
    # Calculate the exponential intervals for each epoch
    epochs = np.arange(n_epochs)
    factors = epochs / n_epochs
    intervals = bmin * (bmax / bmin) ** factors
    if force_bmax:
        intervals[-1] = bmax # Ensure the last value is exactly bmax
    return intervals.astype(int)

def calc_log_intervals(bmin, bmax, n_epochs, force_bmax=True):
    # Calculate the logarithmic intervals for each epoch
    epochs = np.arange(n_epochs)
    factors = np.log(1 + epochs) / np.log(1 + n_epochs)
    intervals = bmin + (bmax - bmin) * factors
    if force_bmax:
        intervals[-1] = bmax # Ensure the last value is exactly bmax
    return intervals.astype(int)
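
# Example (illustrative): both helpers return one integer reset interval per
# epoch, increasing from bmin to bmax, with the final value forced to bmax:
#
#   calc_exp_intervals(bmin=1, bmax=32, n_epochs=5)
#   # -> roughly geometric: array([ 1,  2,  4,  8, 32])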
class ResetStatesCallback(Callback):
    """
    Custom callback to reset the states of RNN layers at the end of each epoch and optionally after a specified number of batches.

    Parameters:
    -----------
    batch_reset : int, optional
        If provided, resets the states of RNN layers after every `batch_reset` batches. Default is None.
    """
    # def __init__(self, bmin=None, bmax=None, epochs=None, loc_batch_reset = None, batch_schedule_type='linear', verbose=True):
    def __init__(self, params=None, verbose=True):
        """
        Initializes the ResetStatesCallback with an optional batch reset interval.

        Parameters:
        -----------
        params : dict, optional
            Dictionary of parameters. If None provided, only on_epoch_end will trigger a reset of hidden states.
            - bmin : int
                Minimum for batch reset schedule
            - bmax : int
                Maximum for batch reset schedule
            - epochs : int
                Number of training epochs.
            - loc_batch_reset : int
                Interval of batches after which to reset the states of RNN layers for location changes. Triggers reset for training AND validation phases.
            - batch_schedule_type : str
                Type of batch scheduling to be used. Recognized methods are the following:
                - 'constant' : Uses a fixed batch reset interval throughout training
                - 'linear' : Increases the batch reset interval linearly over epochs from bmin to bmax.
                - 'exp' : Increases the batch reset interval exponentially over epochs from bmin to bmax.
                - 'log' : Increases the batch reset interval logarithmically over epochs from bmin to bmax.

        Returns:
        -----------
        Only in-place reset of hidden states of the RNN that uses this callback.
        """
        super(ResetStatesCallback, self).__init__()

        # Check for optional arguments, set None if missing in input params
        arg_list = ['bmin', 'bmax', 'epochs', 'loc_batch_reset', 'batch_schedule_type']
        for arg in arg_list:
            setattr(self, arg, params.get(arg, None))

        self.verbose = verbose
        if self.verbose:
            print(f"Using ResetStatesCallback with Batch Reset Schedule: {self.batch_schedule_type}")
        # Calculate the reset intervals for each epoch during initialization
        if self.batch_schedule_type is not None:
            if self.epochs is None:
                raise ValueError(f"Argument `epochs` cannot be None with self.batch_schedule_type: {self.batch_schedule_type}")
            self.batch_reset_intervals = self._calc_reset_intervals(self.batch_schedule_type)
            if self.verbose:
                print(f"batch_reset_intervals: {self.batch_reset_intervals}")
        else:
            self.batch_reset_intervals = None
    def on_epoch_end(self, epoch, logs=None):
        """
        Resets the states of RNN layers at the end of each epoch.

        Parameters:
        -----------
        epoch : int
            The index of the current epoch.
        logs : dict, optional
            A dictionary containing metrics from the epoch. Default is None.
        """
        # print(f" Resetting hidden state after epoch: {epoch+1}", flush=True)
        # Iterate over each layer in the model
        for layer in self.model.layers:
            # Check if the layer has a reset_states method
            if hasattr(layer, 'reset_states'):
                layer.reset_states()
    def _calc_reset_intervals(self, batch_schedule_type):
        methods = ['constant', 'linear', 'exp', 'log']
        if batch_schedule_type not in methods:
            raise ValueError(f"Batch schedule method {batch_schedule_type} not recognized. \n Available methods: {methods}")
        if batch_schedule_type == "constant":
            return np.repeat(self.bmin, self.epochs).astype(int)
        elif batch_schedule_type == "linear":
            return np.linspace(self.bmin, self.bmax, self.epochs).astype(int)
        elif batch_schedule_type == "exp":
            return calc_exp_intervals(self.bmin, self.bmax, self.epochs)
        elif batch_schedule_type == "log":
            return calc_log_intervals(self.bmin, self.bmax, self.epochs)
    def on_epoch_begin(self, epoch, logs=None):
        # Set the reset interval for the current epoch
        if self.batch_reset_intervals is not None:
            self.current_batch_reset = self.batch_reset_intervals[epoch]
        else:
            self.current_batch_reset = None
    def on_train_batch_end(self, batch, logs=None):
        """
        Resets the states of RNN layers during training after a specified number of batches, if `batch_reset` or `loc_batch_reset` are provided. The `batch_reset` is used for stability and to avoid exploding gradients at the beginning of training, when a hidden state is being passed with weights that haven't learned yet. The `loc_batch_reset` is used to reset the states when a particular batch is from a new location, so the hidden state should not be carried over.

        Parameters:
        -----------
        batch : int
            The index of the current batch.
        logs : dict, optional
            A dictionary containing metrics from the batch. Default is None.
        """
        batch_reset = self.current_batch_reset
        if (batch_reset is not None and batch % batch_reset == 0):
            # print(f" Resetting states after batch {batch + 1}")
            # Iterate over each layer in the model
            for layer in self.model.layers:
                # Check if the layer has a reset_states method
                if hasattr(layer, 'reset_states'):
                    layer.reset_states()
    def on_test_batch_end(self, batch, logs=None):
        """
        Resets the states of RNN layers during validation if `loc_batch_reset` is provided, to demarcate a new location and thus avoid passing a hidden state to the wrong location.

        Parameters:
        -----------
        batch : int
            The index of the current batch.
        logs : dict, optional
            A dictionary containing metrics from the batch. Default is None.
        """
        loc_batch_reset = self.loc_batch_reset
        if (loc_batch_reset is not None and batch % loc_batch_reset == 0):
            # print(f"Resetting states in Validation mode after batch {batch + 1}")
            # Iterate over each layer in the model
            for layer in self.model.layers:
                # Check if the layer has a reset_states method
                if hasattr(layer, 'reset_states'):
                    layer.reset_states()
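
# Example (illustrative sketch): attach the callback when fitting a stateful
# model; `params` may carry the optional schedule keys documented above:
#
#   callbacks = [ResetStatesCallback(params), TerminateOnNaN()]
#   model.fit(X_train, y_train, epochs=params['epochs'],
#             batch_size=params['batch_size'], callbacks=callbacks)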
# class ResetStatesCallback(Callback):
#     """
#     Custom callback to reset the states of RNN layers at the end of each epoch and optionally after a specified number of batches.
#
#     Parameters:
#     -----------
#     batch_reset : int, optional
#         If provided, resets the states of RNN layers after every `batch_reset` batches. Default is None.
#     """
#     def __init__(self, batch_reset=None, loc_batch_reset=None):
#         """
#         Initializes the ResetStatesCallback with an optional batch reset interval.
#
#         Parameters:
#         -----------
#         batch_reset : int, optional
#             The interval of batches after which to reset the states of RNN layers. Default is None.
#         loc_batch_reset : int, optional
#             The interval of batches after which the location changes for a given batch number, then reset the states of RNN layers. Default is None.
#         """
#         super(ResetStatesCallback, self).__init__()
#         self.batch_reset = batch_reset
#         self.loc_batch_reset = loc_batch_reset
#     def on_epoch_end(self, epoch, logs=None):
#         """
#         Resets the states of RNN layers at the end of each epoch.
#
#         Parameters:
#         -----------
#         epoch : int
#             The index of the current epoch.
#         logs : dict, optional
#             A dictionary containing metrics from the epoch. Default is None.
#         """
#         # Iterate over each layer in the model
#         for layer in self.model.layers:
#             # Check if the layer has a reset_states method
#             if hasattr(layer, 'reset_states'):
#                 layer.reset_states()
#     def on_train_batch_end(self, batch, logs=None):
#         """
#         Resets the states of RNN layers after a specified number of batches, if `batch_reset` or `loc_batch_reset` are provided. The batch_reset parameter resets the state for stability, and the loc_batch_reset parameter resets when the underlying timeseries for the batch changes (typically when location changes). If None provided for either parameter, set to inf so mod batch is never zero.
#
#         Parameters:
#         -----------
#         batch : int
#             The index of the current batch.
#         logs : dict, optional
#             A dictionary containing metrics from the batch. Default is None.
#         """
#         batch_reset = self.batch_reset
#         if batch_reset is None:
#             batch_reset = np.inf
#         loc_batch_reset = self.loc_batch_reset
#         if loc_batch_reset is None:
#             loc_batch_reset = np.inf
#         if (batch % batch_reset == 0) or (batch % loc_batch_reset == 0):
#             # print(f"Resetting states after batch {batch}")
#             # Iterate over each layer in the model
#             for layer in self.model.layers:
#                 # Check if the layer has a reset_states method
#                 if hasattr(layer, 'reset_states'):
#                     layer.reset_states()
1445 ## Learning Schedules
1446 ## NOT TESTED YET
1447 lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
1448 initial_learning_rate=0.001,
1449 decay_steps=1000,
1450 alpha=0.0,
1451 name='CosineDecay',
1452 # warmup_target=None,
1453 # warmup_steps=100
1456 def EarlyStoppingCallback(patience=5):
1458 Creates an EarlyStopping callback with the specified patience.
1460 Args:
1461 patience (int): Number of epochs with no improvement after which training will be stopped.
1463 Returns:
1464 EarlyStopping: Configured EarlyStopping callback.
1466 return EarlyStopping(
1467 monitor='val_loss',
1468 patience=patience,
1469 verbose=1,
1470 mode='min',
1471 restore_best_weights=True
1473 # early_stopping = EarlyStopping(
1474 # monitor='val_loss', # Metric to monitor, e.g., 'val_loss', 'val_accuracy'
1475 # patience=5, # Number of epochs with no improvement after which training will be stopped
1476 # verbose=1, # Print information about early stopping
1477 # mode='min', # 'min' means training will stop when the quantity monitored has stopped decreasing
1478 # restore_best_weights=True # Restore model weights from the epoch with the best value of the monitored quantity
# with open("params.yaml") as file:
#     phys_params = yaml.safe_load(file)["physics_initializer"]

phys_params = {
    'DeltaE': [0, -1],          # bias correction
    'T1': 0.1,                  # 1/fuel class (10)
    'fm_raise_vs_rain': 0.2     # fm increase per mm rain
}
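
# Quick numeric check of how these constants enter get_initial_weights below
# (values rounded): with T1 = 0.1,
#   (1 - exp(-T1)) / 2 = (1 - exp(-0.1)) / 2 ≈ 0.0476   # input weights for Ed, Ew
#   exp(-0.1) ≈ 0.9048                                  # hidden recurrent weight wh
# together a rough discrete-time solution of the moisture ODE.
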
def get_initial_weights(model_fit, params, scale_fm=1):
    # Given a RNN architecture and hyperparameter dictionary, return array of physics-initiated weights
    # Inputs:
    # model_fit: output of create_RNN_2 with no training
    # params: (dict) dictionary of hyperparameters
    # scale_fm: (float) scale factor applied to fuel-moisture data, used to rescale the physics weights
    # Returns: numpy ndarray of weights that should be a rough solution to the moisture ODE
    DeltaE = phys_params['DeltaE']
    T1 = phys_params['T1']
    fmr = phys_params['fm_raise_vs_rain']
    centering = params['centering']  # shift activation down

    # wx: input weight for each feature
    w0_initial = {'Ed': (1. - np.exp(-T1)) / 2,
                  'Ew': (1. - np.exp(-T1)) / 2,
                  'rain': fmr * scale_fm}
    # w_initial entries: [wx (placeholder, filled per feature from w0_initial), wh, bh, wd, bd]
    w_initial = np.array([np.nan, np.exp(-0.1), DeltaE[0] / scale_fm,   # layer 0: wx, wh, bh
                          1.0, -centering[0] + DeltaE[1] / scale_fm])   # layer 1: wd, bd
    if params['verbose_weights']:
        print('Equilibrium moisture correction bias', DeltaE[0],
              'in the hidden layer and', DeltaE[1], 'in the output layer')

    w_name = ['wx', 'wh', 'bh', 'wd', 'bd']

    w = model_fit.get_weights()
    for j in range(w[0].shape[0]):
        feature = params['features_list'][j]
        for k in range(w[0].shape[1]):
            w[0][j][k] = w0_initial[feature]
    for i in range(1, len(w)):           # index of the weight array
        for j in range(w[i].shape[0]):   # index of the input
            if w[i].ndim == 2:
                # initialize all entries of the weight matrix to the same number
                for k in range(w[i].shape[1]):
                    w[i][j][k] = w_initial[i] / w[i].shape[0]
            elif w[i].ndim == 1:
                w[i][j] = w_initial[i]
            else:
                print('weight', i, 'shape', w[i].shape)
                raise ValueError("Only 1 or 2 dimensions supported")
        if params['verbose_weights']:
            print('weight', i, w_name[i], 'shape', w[i].shape, 'ndim', w[i].ndim,
                  'initial: sum', np.sum(w[i], axis=0), '\nentries', w[i])

    return w, w_name
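
# With a single hidden unit and linear activations, the initialization above
# makes the untrained network compute approximately (a sketch, ignoring
# centering and data scaling):
#   h_t = exp(-0.1) * h_{t-1} + (1 - exp(-T1))/2 * (Ed + Ew) + fmr * rain
# i.e. exponential relaxation of fuel moisture toward the mean of the drying
# and wetting equilibria, plus a rain response.
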
class RNN(RNNModel):
    """
    A concrete implementation of the RNNModel abstract base class, using simple recurrent cells for hidden recurrent layers.

    Parameters:
    -----------
    params : dict
        A dictionary of model parameters.
    loss : str, optional
        The loss function to use during model training. Default is 'mean_squared_error'.
    """
    def __init__(self, params, loss='mean_squared_error'):
        """
        Initializes the RNN model by building the training and prediction models.

        Parameters:
        -----------
        params : dict or RNNParams
            A dictionary containing the model's parameters.
        loss : str, optional
            The loss function to use during model training. Default is 'mean_squared_error'.
        """
        super().__init__(params)
        self.model_train = self._build_model_train()
        self.model_predict = self._build_model_predict()

    def _build_model_train(self):
        """
        Builds and compiles the training model, with batch & sequence shape specifications for input.

        Returns:
        --------
        model : tf.keras.Model
            The compiled Keras model for training.
        """
        inputs = tf.keras.Input(batch_shape=self.params['batch_shape'])
        x = inputs
        for i in range(self.params['rnn_layers']):
            # Return sequences for all but the last recurrent layer
            return_sequences = i < self.params['rnn_layers'] - 1
            x = SimpleRNN(
                units=self.params['rnn_units'],
                activation=self.params['activation'][0],
                dropout=self.params["dropout"][0],
                recurrent_dropout=self.params["recurrent_dropout"],
                stateful=self.params['stateful'],
                return_sequences=return_sequences)(x)
        if self.params["dropout"][1] > 0:
            x = Dropout(self.params["dropout"][1])(x)
        for i in range(self.params['dense_layers']):
            x = Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)
        # Add final output layer, must be 1 dense cell with linear activation if continuous scalar output
        x = Dense(units=1, activation='linear')(x)
        model = tf.keras.Model(inputs=inputs, outputs=x)
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'])
        # optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'], clipvalue=self.params['clipvalue'])
        # optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
        model.compile(loss='mean_squared_error', optimizer=optimizer)

        if self.params["verbose_weights"]:
            print(f"Initial Weights Hash: {hash_weights(model)}")
            # print(model.get_weights())

        if self.params['phys_initialize']:
            assert self.params['scaler'] == 'reproducibility', f"Physics initialization is not yet implemented for data scaling '{self.params['scaler']}'"
            assert self.params['features_list'] == ['Ed', 'Ew', 'rain'], f"Physics initialization can only be done with features ['Ed', 'Ew', 'rain'], but given features {self.params['features_list']}"
            print("Initializing model with physics-based weights")
            w, w_name = get_initial_weights(model, self.params)
            model.set_weights(w)
            print('initial weights hash =', hash_weights(model))
        return model

    def _build_model_predict(self, return_sequences=True):
        """
        Builds and compiles the prediction model, which does not fix the batch shape or sequence length.

        Parameters:
        -----------
        return_sequences : bool, optional
            Whether to return the full sequence of outputs. Default is True.

        Returns:
        --------
        model : tf.keras.Model
            The compiled Keras model for prediction.
        """
        inputs = tf.keras.Input(shape=(None, self.params['n_features']))
        x = inputs
        for i in range(self.params['rnn_layers']):
            x = SimpleRNN(self.params['rnn_units'], activation=self.params['activation'][0],
                          stateful=False, return_sequences=return_sequences)(x)
        for i in range(self.params['dense_layers']):
            x = Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)
        # Add final output layer, must be 1 dense cell with linear activation if continuous scalar output
        x = Dense(units=1, activation='linear')(x)
        model = tf.keras.Model(inputs=inputs, outputs=x)
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'])
        model.compile(loss='mean_squared_error', optimizer=optimizer)

        # Copy fitted weights over from model_train
        w_fitted = self.model_train.get_weights()
        model.set_weights(w_fitted)

        return model
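
# A minimal usage sketch for the RNN class (illustrative; RNNParams and the
# data pipeline are defined earlier in this module, and the contents of the
# params dict are placeholders):
#
# params = RNNParams({...})     # hyperparameter dict with the keys used above
# mod = RNN(params)
# mod.model_train.summary()     # stateful model with fixed batch_shape, used for fitting
# mod.model_predict.summary()   # stateless model accepting any sequence length
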
class RNN_LSTM(RNNModel):
    """
    A concrete implementation of the RNNModel abstract base class, using LSTM cells for hidden recurrent layers.

    Parameters:
    -----------
    params : dict
        A dictionary of model parameters.
    loss : str, optional
        The loss function to use during model training. Default is 'mean_squared_error'.
    """
    def __init__(self, params, loss='mean_squared_error'):
        """
        Initializes the RNN model by building the training and prediction models.

        Parameters:
        -----------
        params : dict or RNNParams
            A dictionary containing the model's parameters.
        loss : str, optional
            The loss function to use during model training. Default is 'mean_squared_error'.
        """
        super().__init__(params)
        self.model_train = self._build_model_train()
        self.model_predict = self._build_model_predict()

    def _build_model_train(self):
        """
        Builds and compiles the training model, with batch & sequence shape specifications for input.

        Returns:
        --------
        model : tf.keras.Model
            The compiled Keras model for training.
        """
        inputs = tf.keras.Input(batch_shape=self.params['batch_shape'])
        x = inputs
        for i in range(self.params['rnn_layers']):
            # Return sequences for all but the last recurrent layer
            return_sequences = i < self.params['rnn_layers'] - 1
            x = LSTM(
                units=self.params['rnn_units'],
                activation=self.params['activation'][0],
                dropout=self.params["dropout"][0],
                recurrent_dropout=self.params["recurrent_dropout"],
                recurrent_activation=self.params["recurrent_activation"],
                stateful=self.params['stateful'],
                return_sequences=return_sequences)(x)
        if self.params["dropout"][1] > 0:
            x = Dropout(self.params["dropout"][1])(x)
        for i in range(self.params['dense_layers']):
            x = Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)
        # Add final output layer, as in RNN above: 1 dense cell with linear activation for scalar output
        x = Dense(units=1, activation='linear')(x)
        model = tf.keras.Model(inputs=inputs, outputs=x)
        # optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'], clipvalue=self.params['clipvalue'])
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'])
        model.compile(loss='mean_squared_error', optimizer=optimizer)

        if self.params["verbose_weights"]:
            print(f"Initial Weights Hash: {hash_weights(model)}")
        return model

    def _build_model_predict(self, return_sequences=True):
        """
        Builds and compiles the prediction model, which does not fix the batch shape or sequence length.

        Parameters:
        -----------
        return_sequences : bool, optional
            Whether to return the full sequence of outputs. Default is True.

        Returns:
        --------
        model : tf.keras.Model
            The compiled Keras model for prediction.
        """
        inputs = tf.keras.Input(shape=(None, self.params['n_features']))
        x = inputs
        for i in range(self.params['rnn_layers']):
            x = LSTM(
                units=self.params['rnn_units'],
                activation=self.params['activation'][0],
                recurrent_activation=self.params["recurrent_activation"],
                stateful=False, return_sequences=return_sequences)(x)
        for i in range(self.params['dense_layers']):
            x = Dense(self.params['dense_units'], activation=self.params['activation'][1])(x)
        # Add final output layer, as in the training model, so the copied weights line up
        x = Dense(units=1, activation='linear')(x)
        model = tf.keras.Model(inputs=inputs, outputs=x)
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['learning_rate'])
        model.compile(loss='mean_squared_error', optimizer=optimizer)

        # Copy fitted weights over from model_train
        w_fitted = self.model_train.get_weights()
        model.set_weights(w_fitted)

        return model
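
# RNN_LSTM is constructed the same way as RNN, but additionally reads
# 'recurrent_activation' from params for the LSTM gates. Sketch (placeholder
# params dict, not defaults):
#
# params['recurrent_activation'] = 'sigmoid'
# mod = RNN_LSTM(params)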