From f618e1848a8e8d69414778bf08e45a619229214e Mon Sep 17 00:00:00 2001 From: jh-206 Date: Wed, 11 Sep 2024 09:42:35 -0600 Subject: [PATCH] Update moisture_rnn.py Use dot operator for RNNData object --- fmda/moisture_rnn.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/fmda/moisture_rnn.py b/fmda/moisture_rnn.py index ab67206..3df0f0e 100644 --- a/fmda/moisture_rnn.py +++ b/fmda/moisture_rnn.py @@ -1150,12 +1150,12 @@ class RNNModel(ABC): print("Input data hashes, NOT formatted for rnn sequence/batches yet") dict0.print_hashes() # Extract Datasets - X_train, y_train, X_test, y_test = dict0['X_train'], dict0['y_train'], dict0["X_test"], dict0['y_test'] + X_train, y_train, X_test, y_test = dict0.X_train, dict0.y_train, dict0.X_test, dict0.y_test if 'X_val' in dict0: - X_val, y_val = dict0['X_val'], dict0['y_val'] + X_val, y_val = dict0.X_val, dict0.y_val else: X_val = None - case_id = dict0['case'] + case_id = dict0.case # Fit model if X_val is None: @@ -1166,13 +1166,13 @@ class RNNModel(ABC): # Generate Predictions and Evaluate Test Error if dict0.spatial: m, errs = self._eval_multi(dict0) + if save_outputs: + dict0['m']=m else: m, errs = self._eval_single(dict0, verbose_weights, reproducibility_run) - plot_data(dict0, title="RNN", title2=dict0['case'], plot_period=plot_period) - - - if save_outputs: - dict0['m']=m + if save_outputs: + dict0['m']=m + plot_data(dict0, title="RNN", title2=dict0.case, plot_period=plot_period) return m, errs @@ -1201,8 +1201,8 @@ class RNNModel(ABC): # Calculate Errors err = rmse(m, y) - train_ind = dict0["train_ind"] # index of final training set value - test_ind = dict0["test_ind"] # index of first test set value + train_ind = dict0.train_ind # index of final training set value + test_ind = dict0.test_ind # index of first test set value err_train = rmse(m[:train_ind], y[:train_ind].flatten()) err_pred = rmse(m[test_ind:], y[test_ind:].flatten()) -- 2.11.4.GIT