Skip to content

Commit

Permalink
Update to allow labels class values to be unique arbitrary strings. A…
Browse files Browse the repository at this point in the history
…lso allows the unlabeled value to be the empty string.

For example:
positive label value = "positive"
negative label value = "negative"
unlabeled label value = ""

PiperOrigin-RevId: 714146904
  • Loading branch information
raj-sinha authored and The spade_anomaly_detection Authors committed Jan 11, 2025
1 parent 462e53e commit 41f0758
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 78 deletions.
115 changes: 92 additions & 23 deletions spade_anomaly_detection/csv_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,22 @@
import tensorflow as tf


# Types are from //cloud/ml/research/data_utils/feature_metadata.py
_FEATURES_TYPE: Final[str] = 'FLOAT64'
_SOURCE_LABEL_TYPE: Final[str] = 'STRING'
_SOURCE_LABEL_DEFAULT_VALUE: Final[str] = '-1'
_LABEL_TYPE: Final[str] = 'INT64'
_STRING_TO_INTEGER_LABEL_MAP: dict[str | int, int] = {
1: 1,
0: 0,
-1: -1,
'': -1,
'-1': -1,
'0': 0,
'1': 1,
'positive': 1,
'negative': 0,
'unlabeled': -1,
}

# Setting the shuffle buffer size to 1M seems to be necessary to get the CSV
# reader to provide a diversity of data to the model.
Expand Down Expand Up @@ -167,12 +178,12 @@ def from_inputs_file(
raise ValueError(
f'Label column {label_column_name} not found in the header: {header}'
)
num_features = len(all_columns) - 1
features_types = [_FEATURES_TYPE] * len(all_columns)
column_names_dict = collections.OrderedDict(
zip(all_columns, features_types)
)
column_names_dict[label_column_name] = _SOURCE_LABEL_DEFAULT_VALUE
num_features = len(all_columns) - 1
return ColumnNamesInfo(
column_names_dict=column_names_dict,
header=header,
Expand Down Expand Up @@ -216,6 +227,13 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
self.runner_parameters.negative_data_value,
self.runner_parameters.unlabeled_data_value,
]
# Add any labels that are not already in the map.
_STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.positive_data_value] = 1
_STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.negative_data_value] = 0
_STRING_TO_INTEGER_LABEL_MAP[
self.runner_parameters.unlabeled_data_value
] = -1

# Construct a label remap from string labels to integers. The table is not
# necessary for the case when the labels are all integers. But instead of
# checking if the labels are all integers, we construct the table and use
Expand Down Expand Up @@ -286,7 +304,8 @@ def get_inputs_metadata(
)
# Get information about the columns.
column_names_info = ColumnNamesInfo.from_inputs_file(
csv_filenames[0], label_column_name
csv_filenames[0],
label_column_name,
)
logging.info(
'Obtained metadata for data with CSV prefix %s (number of features=%d)',
Expand Down Expand Up @@ -360,22 +379,19 @@ def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool: # pylint: disab
@classmethod
def convert_str_to_int(cls, value: str) -> int:
"""Converts a string integer label to an integer label."""
if isinstance(value, str) and value.lstrip('-').isdigit():
return int(value)
elif isinstance(value, int):
return value
if value in _STRING_TO_INTEGER_LABEL_MAP:
return _STRING_TO_INTEGER_LABEL_MAP[value]
else:
raise ValueError(
f'Label {value} of type {type(value)} is not a string integer.'
f'Label {value} of type {type(value)} is not a string integer or '
'mappable to an integer.'
)

@classmethod
def _get_label_remap_table(
cls, labels_mapping: dict[str, int]
) -> tf.lookup.StaticHashTable:
"""Returns a label remap table that converts string labels to integers."""
# The possible keys are '', '-1, '0', '1'. None is not included because the
# Data Loader will default to '' if the label is None.
keys_tensor = tf.constant(
list(labels_mapping.keys()),
dtype=tf.dtypes.as_dtype(_SOURCE_LABEL_TYPE.lower()),
Expand All @@ -390,6 +406,14 @@ def _get_label_remap_table(
)
return label_remap_table

def remap_label(self, label: str | tf.Tensor) -> int | tf.Tensor:
"""Remaps the label to an integer."""
if isinstance(label, str) or (
isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
):
return self._label_remap_table.lookup(label)
return label

def load_tf_dataset_from_csv(
self,
input_path: str,
Expand Down Expand Up @@ -441,6 +465,7 @@ def load_tf_dataset_from_csv(
self._last_read_metadata.column_names_info.column_names_dict.values()
)
]
logging.info('column_defaults: %s', column_defaults)

# Construct a single dataset out of multiple CSV files.
# TODO(sinharaj): Remove the determinism after testing.
Expand All @@ -456,7 +481,7 @@ def load_tf_dataset_from_csv(
na_value='',
header=True,
num_epochs=1,
shuffle=True,
shuffle=False,
shuffle_buffer_size=_SHUFFLE_BUFFER_SIZE,
shuffle_seed=self.runner_parameters.random_seed,
prefetch_buffer_size=tf.data.AUTOTUNE,
Expand All @@ -473,17 +498,9 @@ def load_tf_dataset_from_csv(
'created.'
)

def remap_label(label: str | tf.Tensor) -> int | tf.Tensor:
"""Remaps the label to an integer."""
if isinstance(label, str) or (
isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
):
return self._label_remap_table.lookup(label)
return label

# The Dataset can have labels of type int or str. Cast them to int.
dataset = dataset.map(
lambda features, label: (features, remap_label(label)),
lambda features, label: (features, self.remap_label(label)),
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=True,
)
Expand Down Expand Up @@ -535,7 +552,6 @@ def combine_features_dict_into_tensor(
self._label_counts = {
k: v.numpy() for k, v in self.counts_by_label(dataset).items()
}
logging.info('Label counts: %s', self._label_counts)

return dataset

Expand All @@ -554,11 +570,11 @@ def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:

@tf.function
def count_class(
counts: Dict[int, int], # Keys are always strings.
counts: Dict[int, int],
batch: Tuple[tf.Tensor, tf.Tensor],
) -> Dict[int, int]:
_, labels = batch
# Keys are always strings.
labels = self.remap_label(labels)
new_counts: Dict[int, int] = counts.copy()
for i in self.all_labels:
# This function is called after the Dataset is constructed and the
Expand All @@ -582,6 +598,59 @@ def count_class(
)
return counts

def counts_by_original_label(
self, dataset: tf.data.Dataset
) -> tuple[dict[str, tf.Tensor], dict[int, tf.Tensor]]:
"""Counts the number of samples in each label class in the dataset."""

all_int_labels = [l for l in self.all_labels if isinstance(l, int)]
logging.info('all_int_labels: %s', all_int_labels)
all_str_labels = [l for l in self.all_labels if isinstance(l, str)]
logging.info('all_str_labels: %s', all_str_labels)

@tf.function
def count_original_class(
counts: Dict[int | str, int],
batch: Tuple[tf.Tensor, tf.Tensor],
) -> Dict[int | str, int]:
keys_are_int = all(isinstance(k, int) for k in counts.keys())
if keys_are_int:
all_labels = all_int_labels
else:
all_labels = all_str_labels
_, labels = batch
new_counts: Dict[int | str, int] = counts.copy()
for label in all_labels:
cc: tf.Tensor = tf.cast(labels == label, tf.int32)
if label in list(new_counts.keys()):
new_counts[label] += tf.reduce_sum(cc)
else:
new_counts[label] = tf.reduce_sum(cc)
return new_counts

int_keys_map = {
k: v
for k, v in _STRING_TO_INTEGER_LABEL_MAP.items()
if isinstance(k, int)
}
initial_int_state = dict((int(label), 0) for label in int_keys_map.keys())
if initial_int_state:
int_counts = dataset.reduce(
initial_state=initial_int_state, reduce_func=count_original_class
)
else:
int_counts = {}
str_keys_map = {
k: v
for k, v in _STRING_TO_INTEGER_LABEL_MAP.items()
if isinstance(k, str)
}
initial_str_state = dict((str(label), 0) for label in str_keys_map.keys())
str_counts = dataset.reduce(
initial_state=initial_str_state, reduce_func=count_original_class
)
return int_counts, str_counts

def get_label_thresholds(self) -> Mapping[str, float]:
"""Computes positive and negative thresholds based on label ratios.
Expand Down
Loading

0 comments on commit 41f0758

Please sign in to comment.