From 0d424c8911a0e00c370d7dccf3ab37944c36a3e9 Mon Sep 17 00:00:00 2001 From: Raj Sinha Date: Wed, 4 Sep 2024 13:57:48 -0700 Subject: [PATCH] Internal update. PiperOrigin-RevId: 671090263 --- spade_anomaly_detection/runner.py | 6 +++--- spade_anomaly_detection/runner_test.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/spade_anomaly_detection/runner.py b/spade_anomaly_detection/runner.py index 80656f7..e4caee3 100644 --- a/spade_anomaly_detection/runner.py +++ b/spade_anomaly_detection/runner.py @@ -293,7 +293,7 @@ def instantiate_and_fit_ensemble( logging.info('Batch size for OCC ensemble: %s', batch_size) if self.data_format == DataFormat.BIGQUERY: - logging.info('Loading training data from BigQuery.') + logging.info('Loading GMM training data from BigQuery.') self.input_data_loader = cast( data_loader.DataLoader, self.input_data_loader ) @@ -310,7 +310,7 @@ def instantiate_and_fit_ensemble( ], ) else: - logging.info('Loading training data from CSV.') + logging.info('Loading GMM training data from CSV.') self.input_data_loader = cast( csv_data_loader.CsvDataLoader, self.input_data_loader ) @@ -348,7 +348,7 @@ def write_verbose_logs( weights: Weights corresponding to pseudo labels - this is the alpha parameter. """ - updated_label_counts = pd.DataFrame(labels).value_counts() + updated_label_counts = pd.DataFrame(labels).value_counts().reset_index() logging.info('Updated label counts %s', updated_label_counts) if self.test_x is not None and self.test_y is not None: diff --git a/spade_anomaly_detection/runner_test.py b/spade_anomaly_detection/runner_test.py index cab146c..23a2abd 100644 --- a/spade_anomaly_detection/runner_test.py +++ b/spade_anomaly_detection/runner_test.py @@ -387,7 +387,8 @@ def test_verbose_logging_no_error(self): ) with self.subTest(name='LabelCountLogs'): self._assert_regex_in( - training_logs.output, r'Updated label counts 0 90\n1 10\n' + training_logs.output, + r'Updated label counts 0 count\n0 0 90\n1 1 10', ) with self.subTest(name='TrainFeatureShapeLogs'): self._assert_regex_in(