Skip to content

Commit

Permalink
Do not handle test data when upload_only is True.
Browse files Browse the repository at this point in the history
When the upload_only flag is set, test data generation is not required. This
commit checks the flag and creates the test data loader and the test arrays only when upload_only is False.

PiperOrigin-RevId: 677013478
  • Loading branch information
Vineet Joshi authored and The spade_anomaly_detection Authors committed Sep 21, 2024
1 parent 7c5508a commit 2fe7521
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 48 deletions.
103 changes: 57 additions & 46 deletions spade_anomaly_detection/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
self.input_data_loader = cast(
data_loader.DataLoader, self.input_data_loader
)
self.test_data_loader = self.input_data_loader
if not self.runner_parameters.upload_only:
self.test_data_loader = self.input_data_loader
else:
self.input_data_loader = csv_data_loader.CsvDataLoader(
self.runner_parameters
Expand All @@ -107,13 +108,14 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
self.input_data_loader = cast(
csv_data_loader.CsvDataLoader, self.input_data_loader
)
self.test_data_loader = csv_data_loader.CsvDataLoader(
self.runner_parameters
)
# Type hint to prevent linter errors.
self.test_data_loader = cast(
csv_data_loader.CsvDataLoader, self.test_data_loader
)
if not self.runner_parameters.upload_only:
self.test_data_loader = csv_data_loader.CsvDataLoader(
self.runner_parameters
)
# Type hint to prevent linter errors.
self.test_data_loader = cast(
csv_data_loader.CsvDataLoader, self.test_data_loader
)

# TODO(b/247116870): Evaluate performance implications of using a global
# testing array - the machine may not have enough memory to store the test
Expand Down Expand Up @@ -555,6 +557,8 @@ def preprocess_train_test_split(
features = features[random_indices]
labels = labels[random_indices]

test_x, test_y = None, None

if self.runner_parameters.train_setting == parameters.TrainSetting.PNU:
ground_truth_label_indices = np.where(
labels != self.runner_parameters.unlabeled_data_value
Expand All @@ -570,8 +574,9 @@ def preprocess_train_test_split(
train_x = np.delete(features, test_index_subset, axis=0)
train_y = np.delete(labels, test_index_subset, axis=0)

test_x = features[test_index_subset]
test_y = labels[test_index_subset]
if not self.runner_parameters.upload_only:
test_x = features[test_index_subset]
test_y = labels[test_index_subset]

# TODO(b/247116870): Investigate the performance implications and user
# interest for including some of the negative data in the training set
Expand Down Expand Up @@ -606,51 +611,57 @@ def preprocess_train_test_split(
axis=0,
)

test_x = np.concatenate(
[features[negative_indices], features[test_positive_indices]], axis=0
)
test_y = np.concatenate(
[labels[negative_indices], labels[test_positive_indices]], axis=0
)
if not self.runner_parameters.upload_only:
test_x = np.concatenate(
[features[negative_indices], features[test_positive_indices]],
axis=0,
)
test_y = np.concatenate(
[labels[negative_indices], labels[test_positive_indices]], axis=0
)

else:
raise ValueError(
'Unknown train setting for preparing train/test '
f'datasets: {self.runner_parameters.train_setting}'
)

# TODO(b/247116870): Implement a dedicated function in the runner class
# to load BQ test sets before evaluating the supervised model.
if (
self.runner_parameters.test_bigquery_table_path
or self.runner_parameters.data_test_gcs_uri
):
test_tf_dataset = self._get_test_data()
test_x, test_y = test_tf_dataset.as_numpy_iterator().next()
self.test_x = np.array(test_x)
self.test_y = np.array(test_y)

if not (
np.any(test_y == self.runner_parameters.positive_data_value)
and np.any(test_y == self.runner_parameters.negative_data_value)
if not self.runner_parameters.upload_only:
# TODO(b/247116870): Implement a dedicated function in the runner class
# to load BQ test sets before evaluating the supervised model.
if (
self.runner_parameters.test_bigquery_table_path
or self.runner_parameters.data_test_gcs_uri
):
raise ValueError(
'Positive and negative labels must be in the testing set. Please '
'check the test table provided in the test_bigquery_table_path '
'parameter.'
)
else:
if self.test_x is not None:
self.test_x = np.concatenate([self.test_x, test_x], axis=0)
self.test_y = np.concatenate([self.test_y, test_y], axis=0)
test_tf_dataset = self._get_test_data()
test_x, test_y = test_tf_dataset.as_numpy_iterator().next()
self.test_x = np.array(test_x)
self.test_y = np.array(test_y)

if not (
np.any(test_y == self.runner_parameters.positive_data_value)
and np.any(test_y == self.runner_parameters.negative_data_value)
):
raise ValueError(
'Positive and negative labels must be in the testing set. Please '
'check the test table provided in the test_bigquery_table_path '
'parameter.'
)
else:
self.test_x = test_x
self.test_y = test_y

# Adjust the testing labels to values of 1 and 0 to align with the class
# the supervised model is trained on.
self.test_y[self.test_y == self.runner_parameters.positive_data_value] = 1
self.test_y[self.test_y == self.runner_parameters.negative_data_value] = 0
if self.test_x is not None:
self.test_x = np.concatenate([self.test_x, test_x], axis=0)
self.test_y = np.concatenate([self.test_y, test_y], axis=0)
else:
self.test_x = test_x
self.test_y = test_y

# Adjust the testing labels to values of 1 and 0 to align with the class
# the supervised model is trained on.
if self.test_y is not None:
self.test_y[self.test_y ==
self.runner_parameters.positive_data_value] = 1
self.test_y[self.test_y ==
self.runner_parameters.negative_data_value] = 0

return (train_x, train_y)

Expand Down
4 changes: 2 additions & 2 deletions spade_anomaly_detection/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_evaluate_set_throw_error_not_initialized(self):
):
runner_object.evaluate_model()

def test_supervised_model_not_instantiated_throw_error(self):
def test_upload_only_true_throw_error(self):
self.runner_parameters.upload_only = True
self.runner_parameters.output_bigquery_table_path = (
'project.dataset.output_table'
Expand All @@ -785,7 +785,7 @@ def test_supervised_model_not_instantiated_throw_error(self):
runner_object.run()

with self.assertRaisesRegex(
ValueError, r'Evaluate called without a trained supervised model'
ValueError, r'There is no test set to evaluate on'
):
runner_object.evaluate_model()

Expand Down

0 comments on commit 2fe7521

Please sign in to comment.