From 7159238c1d062add64625f75a039364e050a8066 Mon Sep 17 00:00:00 2001 From: jh-206 Date: Mon, 16 Sep 2024 20:55:19 -0600 Subject: [PATCH] Update moisture_rnn.py Add start_times argument to batch_reshape class func --- fmda/moisture_rnn.py | 93 ++-------------------------------------------------- 1 file changed, 3 insertions(+), 90 deletions(-) diff --git a/fmda/moisture_rnn.py b/fmda/moisture_rnn.py index b5df472..4cb7d26 100644 --- a/fmda/moisture_rnn.py +++ b/fmda/moisture_rnn.py @@ -588,93 +588,6 @@ class RNNData(dict): self.scaler = scalers[scaler] else: raise ValueError(f"Unrecognized scaler '{scaler}'. Recognized scalers are: {recognized_scalers}.") - # def train_test_split(self, train_frac, val_frac=0.0, subset_features=True, features_list=None, split_space=False, verbose=True): - # """ - # Splits the data into training, validation, and test sets. - - # Parameters: - # ----------- - # train_frac : float - # The fraction of data to be used for training. - # val_frac : float, optional - # The fraction of data to be used for validation. Default is 0.0. - # subset_features : bool, optional - # If True, subsets the data to the specified features list. Default is True. - # features_list : list, optional - # A list of features to use for subsetting. Default is None. - # split_space : bool, optional - # Whether to split the data based on space. Default is False. - # verbose : bool, optional - # If True, prints status messages. Default is True. - # """ - # # Indicate whether multi timeseries or not - # spatial = self.spatial - - # # Extract data to desired features, copy to avoid changing input objects - # X = self.X.copy() - # y = self.y.copy() - # if subset_features: - # if verbose and self.features_list != self.all_features_list: - # print(f"Subsetting input data to features_list: {self.features_list}") - # # Indices to subset all features with based on params features - # indices = [] - # for item in self.features_list: - # if item in self.all_features_list: - # indices.append(self.all_features_list.index(item)) - # else: - # print(f"Warning: feature name '{item}' not found in list of all features from input data") - # if spatial: - # X = [Xi[:, indices] for Xi in X] - # else: - # X = X[:, indices] - - # # Setup train/test in time - # train_ind = int(np.floor(self.hours * train_frac)); self.train_ind = train_ind - # test_ind= int(train_ind + round(self.hours * val_frac)); self.test_ind = test_ind - - # # Check for any potential issues with indices - # if test_ind > self.hours: - # print(f"Setting test index to {self.hours}") - # test_ind = self.hours - # if train_ind >= test_ind: - # raise ValueError("Train index must be less than test index.") - - # # Training data from 0 to train_ind - # # Validation data from train_ind to test_ind - # # Test data from test_ind to end - # if spatial: - # self.X_train = [Xi[:train_ind] for Xi in X] - # self.y_train = [yi[:train_ind].reshape(-1,1) for yi in y] - # if val_frac >0: - # self.X_val = [Xi[train_ind:test_ind] for Xi in X] - # self.y_val = [yi[train_ind:test_ind].reshape(-1,1) for yi in y] - # self.X_test = [Xi[test_ind:] for Xi in X] - # self.y_test = [yi[test_ind:].reshape(-1,1) for yi in y] - # else: - # self.X_train = X[:train_ind] - # self.y_train = y[:train_ind].reshape(-1,1) # assumes y 1-d, change this if vector output - # if val_frac >0: - # self.X_val = X[train_ind:test_ind] - # self.y_val = y[train_ind:test_ind].reshape(-1,1) # assumes y 1-d, change this if vector output - # self.X_test = X[test_ind:] - # self.y_test = y[test_ind:].reshape(-1,1) # assumes y 1-d, change this if vector output - - - - # # Print statements if verbose - # if verbose: - # print(f"Train index: 0 to {train_ind}") - # print(f"Validation index: {train_ind} to {test_ind}") - # print(f"Test index: {test_ind} to {self.hours}") - - # if spatial: - # print(f"X_train[0] shape: {self.X_train[0].shape}, y_train[0] shape: {self.y_train[0].shape}") - # print(f"X_val[0] shape: {self.X_val[0].shape}, y_val[0] shape: {self.y_val[0].shape}") - # print(f"X_test[0] shape: {self.X_test[0].shape}, y_test[0] shape: {self.y_test[0].shape}") - # else: - # print(f"X_train shape: {self.X_train.shape}, y_train shape: {self.y_train.shape}") - # print(f"X_val shape: {self.X_val.shape}, y_val shape: {self.y_val.shape}") - # print(f"X_test shape: {self.X_test.shape}, y_test shape: {self.y_test.shape}") def train_test_split(self, time_fracs=[1.,0.,0.], space_fracs=[1.,0.,0.], subset_features=True, features_list=None, verbose=True): """ Splits the data into training, validation, and test sets. @@ -904,7 +817,7 @@ class RNNData(dict): return np.concatenate((X_train, X_val, X_test), axis=0) else: print(f"Unrecognized or unimplemented return value {return_X}") - def batch_reshape(self, timesteps, batch_size, hours=None, verbose=False): + def batch_reshape(self, timesteps, batch_size, hours=None, verbose=False, start_times=None): """ Restructures input data to RNN using batches and sequences. @@ -945,10 +858,10 @@ class RNNData(dict): if spatial: print(f"Reshaping spatial training data using batch size: {batch_size} and timesteps: {timesteps}") - self.X_train, self.y_train, self.n_seqs = staircase_spatial(self.X_train, self.y_train, timesteps = timesteps, batch_size=batch_size, hours=hours, verbose=verbose) + self.X_train, self.y_train, self.n_seqs = staircase_spatial(self.X_train, self.y_train, timesteps = timesteps, batch_size=batch_size, hours=hours, verbose=verbose, start_times=start_times) if hasattr(self, "X_val"): print(f"Reshaping validation data using batch size: {batch_size} and timesteps: {timesteps}") - self.X_val, self.y_val, _ = staircase_spatial(self.X_val, self.y_val, timesteps = timesteps, batch_size=batch_size, hours=None, verbose=verbose) + self.X_val, self.y_val, _ = staircase_spatial(self.X_val, self.y_val, timesteps = timesteps, batch_size=batch_size, hours=None, verbose=verbose, start_times=start_times) else: print(f"Reshaping training data using batch size: {batch_size} and timesteps: {timesteps}") self.X_train, self.y_train = staircase_2(self.X_train, self.y_train, timesteps = timesteps, batch_size=batch_size, verbose=verbose) -- 2.11.4.GIT