From 163e91d225a62021c4fd426298a5cd068039bd06 Mon Sep 17 00:00:00 2001 From: Raj Sinha Date: Wed, 4 Dec 2024 16:26:55 -0800 Subject: [PATCH] Internal update. PiperOrigin-RevId: 702899732 --- spade_anomaly_detection/runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spade_anomaly_detection/runner.py b/spade_anomaly_detection/runner.py index 674f6ad..e70a4f9 100644 --- a/spade_anomaly_detection/runner.py +++ b/spade_anomaly_detection/runner.py @@ -36,7 +36,6 @@ """ import enum -# TODO(b/247116870): Change to collections when Vertex supports python 3.9 from typing import Mapping, Optional, Tuple, cast from absl import logging @@ -49,6 +48,8 @@ from spade_anomaly_detection import supervised_model import tensorflow as tf +# TODO(b/247116870): Change to collections when Vertex supports python 3.9 + @enum.unique class DataFormat(enum.Enum): @@ -135,6 +136,7 @@ def __init__(self, runner_parameters: parameters.RunnerParameters): else: self.supervised_model_object = None + # If the thresholds are not set, use the thresholds from the input table. if ( self.runner_parameters.positive_threshold is None or self.runner_parameters.negative_threshold is None @@ -760,7 +762,7 @@ def run(self) -> None: batch_size=1, ) train_label_counts = self.input_data_loader.label_counts - # TODO(sinharaj): This is not ideal, we should not need to read the files + # This is not ideal, we should not need to read the files # again. Find a way to get the label counts without reading the files. # Assumes that data loader has already been used to read the input table. total_record_count = sum(train_label_counts.values()) @@ -885,6 +887,7 @@ def run(self) -> None: labels=updated_labels, weights=weights, ) + # End of pseudolabeling and supervised model training loop. if not self.runner_parameters.upload_only: self.evaluate_model()