Skip to content

Commit

Permalink
Allow labels to take on integer values represented by the string type.
Browse files Browse the repository at this point in the history
This change allows the input label type to be a string containg an integer, as well as the empty string which is mapped to the unlabeled data value.

This feature is supported and tested only for the CSV data loader.

Summary:

1. dtype of label column will be str or int.
2. empty string is a legal value.
3. empty string will be internally mapped to the unlabeled value.
4. internally the map is: {'1': 1, '0': 0, '': -1, '-1': -1}.

PiperOrigin-RevId: 702784656
  • Loading branch information
raj-sinha authored and The spade_anomaly_detection Authors committed Dec 4, 2024
1 parent bd1a77d commit e3ec5e3
Show file tree
Hide file tree
Showing 12 changed files with 976 additions and 214 deletions.
236 changes: 196 additions & 40 deletions spade_anomaly_detection/csv_data_loader.py

Large diffs are not rendered by default.

285 changes: 220 additions & 65 deletions spade_anomaly_detection/csv_data_loader_test.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions spade_anomaly_detection/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def setUp(self):
label_col_name='label',
positive_data_value=5,
negative_data_value=3,
labels_are_strings=False,
unlabeled_data_value=-100,
positive_threshold=5,
negative_threshold=95,
Expand Down
34 changes: 19 additions & 15 deletions spade_anomaly_detection/occ_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,16 @@ def _score_unlabeled_data(
'negative_indices': negative_indices
}

def pseudo_label(self,
features: np.ndarray,
labels: np.ndarray,
positive_data_value: int,
negative_data_value: Optional[int],
unlabeled_data_value: int,
alpha: float = 1.0,
verbose: Optional[bool] = False) -> Sequence[np.ndarray]:
def pseudo_label(
self,
features: np.ndarray,
labels: np.ndarray,
positive_data_value: str | int,
negative_data_value: str | int | None,
unlabeled_data_value: str | int,
alpha: float = 1.0,
verbose: Optional[bool] = False,
) -> Sequence[np.ndarray]:
"""Labels unlabeled data using the trained ensemble of OCCs.
Args:
Expand Down Expand Up @@ -270,13 +272,15 @@ def pseudo_label(self,
negative_features,
],
axis=0)
new_labels = np.concatenate([
np.ones(len(new_positive_indices)),
np.zeros(len(new_negative_indices)),
np.ones(len(original_positive_idx)),
np.zeros(len(original_negative_idx))
],
axis=0)
new_labels = np.concatenate(
[
np.full(len(new_positive_indices), positive_data_value),
np.full(len(new_negative_indices), negative_data_value),
np.full(len(original_positive_idx), positive_data_value),
np.full(len(original_negative_idx), negative_data_value),
],
axis=0,
)
weights = np.concatenate([
np.repeat(alpha, len(new_positive_indices)),
np.repeat(alpha, len(new_negative_indices)),
Expand Down
56 changes: 41 additions & 15 deletions spade_anomaly_detection/occ_ensemble_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,39 @@ def test_ensemble_training_no_error(
msg='Model count in ensemble not equal to specified ensemble size.',
)

def test_score_unlabeled_data_no_error(self):
@parameterized.named_parameters(
('labels_are_integers', False),
('labels_are_strings', True),
)
def test_score_unlabeled_data_no_error(self, labels_are_strings: bool):
batches_per_occ = 1
positive_threshold = 2
negative_threshold = 90
positive_data_value = 1
negative_data_value = 0
unlabeled_data_value = -1
alpha = 0.8

if labels_are_strings:
positive_data_value = b'1'
negative_data_value = b'0'
unlabeled_data_value = b'-1'
else:
positive_data_value = 1
negative_data_value = 0
unlabeled_data_value = -1

occ_train_dataset = data_loader.load_tf_dataset_from_csv(
dataset_name='drug_train_pu_labeled',
batch_size=None,
filter_label_value=unlabeled_data_value,
# Coerce `unlabeled_data_value` to int since the test dataset contains
# only integer labels.
filter_label_value=int(unlabeled_data_value),
)
features_len = occ_train_dataset.cardinality().numpy()
if labels_are_strings:
# Treat the labels as strings for testing. Note that the test dataset
# contains only integer labels.
occ_train_dataset = occ_train_dataset.map(
lambda x, y: (x, tf.as_string(y))
)

ensemble_obj = occ_ensemble.GmmEnsemble(
n_components=1,
Expand All @@ -114,18 +132,23 @@ def test_score_unlabeled_data_no_error(self):
)
ensemble_obj.fit(occ_train_dataset, batches_per_occ)

features, labels = (
data_loader.load_tf_dataset_from_csv(
dataset_name='drug_train_pu_labeled',
batch_size=500,
filter_label_value=None,
)
.as_numpy_iterator()
.next()
occ_train_dataset = data_loader.load_tf_dataset_from_csv(
dataset_name='drug_train_pu_labeled',
batch_size=500,
filter_label_value=None,
)
if labels_are_strings:
# Treat the labels as strings for testing. Note that the test dataset
# contains only integer labels.
occ_train_dataset = occ_train_dataset.map(
lambda x, y: (x, tf.as_string(y))
)
features, labels = occ_train_dataset.as_numpy_iterator().next()

label_count_before_labeling = len(
np.where((labels == 0) | (labels == 1))[0]
np.where(
(labels == negative_data_value) | (labels == positive_data_value)
)[0]
)

updated_features, updated_labels, weights, pseudolabel_flags = (
Expand All @@ -140,7 +163,10 @@ def test_score_unlabeled_data_no_error(self):
)

label_count_after_labeling = len(
np.where((updated_labels == 0) | (updated_labels == 1))[0]
np.where(
(updated_labels == negative_data_value)
| (updated_labels == positive_data_value)
)[0]
)

new_label_count = label_count_after_labeling - label_count_before_labeling
Expand Down
46 changes: 41 additions & 5 deletions spade_anomaly_detection/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Holds dataclasses and enums leveraged by the SPADE algorithm.
"""
"""Holds dataclasses and enums leveraged by the SPADE algorithm."""


import dataclasses
import enum
Expand All @@ -46,6 +46,19 @@ class TrainSetting(str, enum.Enum):
PNU = 'PNU'


# An immutable mapping of label values to their corresponding integer values.
# This supports label values that the empty string. This value is mapped to -1,
# which is the value used to denote unlabeled data.
# None is not included in this mapping, because the Data Loader will default
# to the empty string if the label column is empty.
labels_mapping: Final[dict[str | None, int]] = {
'': -1,
'1': 1,
'0': 0,
'-1': -1,
}


@dataclasses.dataclass
class RunnerParameters:
"""Stores runner related parameters for helper functions in the module.
Expand All @@ -71,6 +84,8 @@ class RunnerParameters:
will be added to the end of the folder so that multiple runs of this won't
overwrite previous runs.
label_col_name: The name of the label column in the input BigQuery table.
labels_are_strings: Whether the labels in the input dataset are strings or
integers.
positive_data_value: The value used in the label column to denote positive
data - data points that are anomalous.
negative_data_value: The value used in the label column to denote negative
Expand Down Expand Up @@ -163,9 +178,10 @@ class RunnerParameters:
data_input_gcs_uri: str
output_gcs_uri: str
label_col_name: str
positive_data_value: int
negative_data_value: int
unlabeled_data_value: int
positive_data_value: int | str
negative_data_value: int | str
unlabeled_data_value: int | str
labels_are_strings: bool = True
positive_threshold: Optional[float] = None
negative_threshold: Optional[float] = None
ignore_columns: Optional[Sequence[str]] = None
Expand All @@ -188,6 +204,8 @@ class RunnerParameters:
verbose: bool = False

def __post_init__(self):
"""Validates the parameters and sets default values."""
# Parameter checks.
if not (self.input_bigquery_table_path or self.data_input_gcs_uri):
raise ValueError(
'`input_bigquery_table_path` or `data_input_gcs_uri` must be set.'
Expand All @@ -212,3 +230,21 @@ def __post_init__(self):
'`positive_data_value`, `negative_data_value` and'
' `unlabeled_data_value` must all be different from each other.'
)
if self.labels_are_strings and not self._check_labels_are_strings():
raise TypeError(
'`labels_are_strings` must be True if `positive_data_value`, '
'`negative_data_value` and `unlabeled_data_value` are strings.'
)
# Adjust the labels if needed.
if not self.labels_are_strings:
self.positive_data_value = int(self.positive_data_value)
self.negative_data_value = int(self.negative_data_value)
self.unlabeled_data_value = int(self.unlabeled_data_value)

def _check_labels_are_strings(self) -> bool:
"""Returns True if the labels are strings."""
return (
isinstance(self.positive_data_value, str)
and isinstance(self.negative_data_value, str)
and isinstance(self.unlabeled_data_value, str)
)
28 changes: 21 additions & 7 deletions spade_anomaly_detection/parameters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def test_none_required_parameter_raises(self):
data_input_gcs_uri='gs://some_bucket/some_data_input_path',
output_gcs_uri='gs://some_bucket/some_path',
label_col_name='y',
positive_data_value=1,
negative_data_value=0,
unlabeled_data_value=-1,
positive_data_value='1',
negative_data_value='0',
unlabeled_data_value='-1',
)
with self.subTest(name='no_input_sources_specified'):
with self.assertRaises(ValueError):
Expand All @@ -58,9 +58,9 @@ def test_none_required_parameter_raises(self):
data_input_gcs_uri=None,
output_gcs_uri='gs://some_bucket/some_path',
label_col_name='y',
positive_data_value=1,
negative_data_value=0,
unlabeled_data_value=-1,
positive_data_value='1',
negative_data_value='0',
unlabeled_data_value='-1',
)

def test_equal_data_value_parameter_raises(self):
Expand All @@ -71,9 +71,23 @@ def test_equal_data_value_parameter_raises(self):
data_input_gcs_uri=None,
output_gcs_uri='gs://some_bucket/some_path',
label_col_name='y',
positive_data_value='1',
negative_data_value='0',
unlabeled_data_value='0',
)

def test_labels_are_strings_discrepancy_raises(self):
with self.assertRaises(TypeError):
_ = parameters.RunnerParameters(
train_setting=parameters.TrainSetting.PNU,
input_bigquery_table_path='some_project.some_dataset.some_table',
data_input_gcs_uri=None,
output_gcs_uri='gs://some_bucket/some_path',
label_col_name='y',
labels_are_strings=True,
positive_data_value=1,
negative_data_value=0,
unlabeled_data_value=0,
unlabeled_data_value=-1,
)


Expand Down
38 changes: 34 additions & 4 deletions spade_anomaly_detection/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
individual modules and functions.
"""


from unittest import mock

from absl.testing import parameterized
from spade_anomaly_detection import csv_data_loader
from spade_anomaly_detection import data_loader
from spade_anomaly_detection import parameters
Expand Down Expand Up @@ -66,6 +66,7 @@ def setUp(self):
positive_data_value=1,
negative_data_value=0,
unlabeled_data_value=-1,
labels_are_strings=False,
positive_threshold=10,
negative_threshold=90,
test_label_col_name='y',
Expand Down Expand Up @@ -240,7 +241,7 @@ def test_spade_auc_performance_pu_single_batch(self):
self.assertAlmostEqual(auc, 0.9178, delta=0.02)


class PerformanceTestOnCSVData(tf.test.TestCase):
class PerformanceTestOnCSVData(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
Expand All @@ -262,6 +263,7 @@ def setUp(self):
positive_data_value=1,
negative_data_value=0,
unlabeled_data_value=-1,
labels_are_strings=False,
positive_threshold=10,
negative_threshold=90,
test_label_col_name='y',
Expand Down Expand Up @@ -399,8 +401,22 @@ def setUp(self):
)
)

def test_spade_auc_performance_pnu_single_batch(self):
@parameterized.named_parameters([
('labels_are_ints', False, 1, 0, -1),
('labels_are_strings', True, '1', '0', '-1'),
])
def test_spade_auc_performance_pnu_single_batch(
self,
labels_are_strings: bool,
positive_data_value: str | int,
negative_data_value: str | int,
unlabeled_data_value: str | int,
):
self.runner_parameters.train_setting = parameters.TrainSetting.PNU
self.runner_parameters.labels_are_strings = labels_are_strings
self.runner_parameters.positive_data_value = positive_data_value
self.runner_parameters.negative_data_value = negative_data_value
self.runner_parameters.unlabeled_data_value = unlabeled_data_value
self.runner_parameters.positive_threshold = 0.1
self.runner_parameters.negative_threshold = 95
self.runner_parameters.alpha = 0.1
Expand Down Expand Up @@ -433,8 +449,22 @@ def test_spade_auc_performance_pnu_single_batch(self):
# performance seen on the ~580k row Coertype dataset in the PNU setting.
self.assertAlmostEqual(auc, 0.9755, delta=0.02)

def test_spade_auc_performance_pu_single_batch(self):
@parameterized.named_parameters([
('labels_are_ints', False, 1, 0, -1),
('labels_are_strings', True, '1', '0', '-1'),
])
def test_spade_auc_performance_pu_single_batch(
self,
labels_are_strings: bool,
positive_data_value: str | int,
negative_data_value: str | int,
unlabeled_data_value: str | int,
):
self.runner_parameters.train_setting = parameters.TrainSetting.PU
self.runner_parameters.labels_are_strings = labels_are_strings
self.runner_parameters.positive_data_value = positive_data_value
self.runner_parameters.negative_data_value = negative_data_value
self.runner_parameters.unlabeled_data_value = unlabeled_data_value
self.runner_parameters.positive_threshold = 10
self.runner_parameters.negative_threshold = 50
self.runner_parameters.labeling_and_model_training_batch_size = (
Expand Down
Loading

0 comments on commit e3ec5e3

Please sign in to comment.