diff --git a/spade_anomaly_detection/runner.py b/spade_anomaly_detection/runner.py
index e4caee3..92af074 100644
--- a/spade_anomaly_detection/runner.py
+++ b/spade_anomaly_detection/runner.py
@@ -98,7 +98,8 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
       self.input_data_loader = cast(
           data_loader.DataLoader, self.input_data_loader
       )
-      self.test_data_loader = self.input_data_loader
+      if not self.runner_parameters.upload_only:
+        self.test_data_loader = self.input_data_loader
     else:
       self.input_data_loader = csv_data_loader.CsvDataLoader(
           self.runner_parameters
@@ -107,13 +108,14 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
       self.input_data_loader = cast(
           csv_data_loader.CsvDataLoader, self.input_data_loader
       )
-      self.test_data_loader = csv_data_loader.CsvDataLoader(
-          self.runner_parameters
-      )
-      # Type hint to prevent linter errors.
-      self.test_data_loader = cast(
-          csv_data_loader.CsvDataLoader, self.test_data_loader
-      )
+      if not self.runner_parameters.upload_only:
+        self.test_data_loader = csv_data_loader.CsvDataLoader(
+            self.runner_parameters
+        )
+        # Type hint to prevent linter errors.
+        self.test_data_loader = cast(
+            csv_data_loader.CsvDataLoader, self.test_data_loader
+        )
 
     # TODO(b/247116870): Evaluate performance implications of using a global
     # testing array - the machine may not have enough memory to store the test
@@ -555,6 +557,8 @@ def preprocess_train_test_split(
     features = features[random_indices]
     labels = labels[random_indices]
 
+    test_x, test_y = None, None
+
    if self.runner_parameters.train_setting == parameters.TrainSetting.PNU:
      ground_truth_label_indices = np.where(
          labels != self.runner_parameters.unlabeled_data_value
@@ -570,8 +574,9 @@
       train_x = np.delete(features, test_index_subset, axis=0)
       train_y = np.delete(labels, test_index_subset, axis=0)
 
-      test_x = features[test_index_subset]
-      test_y = labels[test_index_subset]
+      if not self.runner_parameters.upload_only:
+        test_x = features[test_index_subset]
+        test_y = labels[test_index_subset]
 
       # TODO(b/247116870): Investigate the performance implications and user
       # interest for including some of the negative data in the training set
@@ -606,12 +611,14 @@
           axis=0,
       )
 
-      test_x = np.concatenate(
-          [features[negative_indices], features[test_positive_indices]], axis=0
-      )
-      test_y = np.concatenate(
-          [labels[negative_indices], labels[test_positive_indices]], axis=0
-      )
+      if not self.runner_parameters.upload_only:
+        test_x = np.concatenate(
+            [features[negative_indices], features[test_positive_indices]],
+            axis=0,
+        )
+        test_y = np.concatenate(
+            [labels[negative_indices], labels[test_positive_indices]], axis=0
+        )
 
     else:
       raise ValueError(
@@ -619,38 +626,42 @@
           f'datasets: {self.runner_parameters.train_setting}'
       )
 
-    # TODO(b/247116870): Implement a dedicated function in the runner class
-    # to load BQ test sets before evaluating the supervised model.
-    if (
-        self.runner_parameters.test_bigquery_table_path
-        or self.runner_parameters.data_test_gcs_uri
-    ):
-      test_tf_dataset = self._get_test_data()
-      test_x, test_y = test_tf_dataset.as_numpy_iterator().next()
-      self.test_x = np.array(test_x)
-      self.test_y = np.array(test_y)
-
-      if not (
-          np.any(test_y == self.runner_parameters.positive_data_value)
-          and np.any(test_y == self.runner_parameters.negative_data_value)
+    if not self.runner_parameters.upload_only:
+      # TODO(b/247116870): Implement a dedicated function in the runner class
+      # to load BQ test sets before evaluating the supervised model.
+      if (
+          self.runner_parameters.test_bigquery_table_path
+          or self.runner_parameters.data_test_gcs_uri
       ):
-        raise ValueError(
-            'Positive and negative labels must be in the testing set. Please '
-            'check the test table provided in the test_bigquery_table_path '
-            'parameter.'
-        )
-    else:
-      if self.test_x is not None:
-        self.test_x = np.concatenate([self.test_x, test_x], axis=0)
-        self.test_y = np.concatenate([self.test_y, test_y], axis=0)
+        test_tf_dataset = self._get_test_data()
+        test_x, test_y = test_tf_dataset.as_numpy_iterator().next()
+        self.test_x = np.array(test_x)
+        self.test_y = np.array(test_y)
+
+        if not (
+            np.any(test_y == self.runner_parameters.positive_data_value)
+            and np.any(test_y == self.runner_parameters.negative_data_value)
+        ):
+          raise ValueError(
+              'Positive and negative labels must be in the testing set. Please '
+              'check the test table provided in the test_bigquery_table_path '
+              'parameter.'
+          )
       else:
-        self.test_x = test_x
-        self.test_y = test_y
-
-    # Adjust the testing labels to values of 1 and 0 to align with the class
-    # the supervised model is trained on.
-    self.test_y[self.test_y == self.runner_parameters.positive_data_value] = 1
-    self.test_y[self.test_y == self.runner_parameters.negative_data_value] = 0
+        if self.test_x is not None:
+          self.test_x = np.concatenate([self.test_x, test_x], axis=0)
+          self.test_y = np.concatenate([self.test_y, test_y], axis=0)
+        else:
+          self.test_x = test_x
+          self.test_y = test_y
+
+      # Adjust the testing labels to values of 1 and 0 to align with the class
+      # the supervised model is trained on.
+      if self.test_y is not None:
+        self.test_y[self.test_y ==
+                    self.runner_parameters.positive_data_value] = 1
+        self.test_y[self.test_y ==
+                    self.runner_parameters.negative_data_value] = 0
 
     return (train_x, train_y)
 
diff --git a/spade_anomaly_detection/runner_test.py b/spade_anomaly_detection/runner_test.py
index 23a2abd..7760395 100644
--- a/spade_anomaly_detection/runner_test.py
+++ b/spade_anomaly_detection/runner_test.py
@@ -776,7 +776,7 @@ def test_evaluate_set_throw_error_not_initialized(self):
     ):
       runner_object.evaluate_model()
 
-  def test_supervised_model_not_instantiated_throw_error(self):
+  def test_upload_only_true_throw_error(self):
     self.runner_parameters.upload_only = True
     self.runner_parameters.output_bigquery_table_path = (
         'project.dataset.output_table'
@@ -785,7 +785,7 @@ def test_supervised_model_not_instantiated_throw_error(self):
     runner_object.run()
 
     with self.assertRaisesRegex(
-        ValueError, r'Evaluate called without a trained supervised model'
+        ValueError, r'There is no test set to evaluate on'
     ):
       runner_object.evaluate_model()
 