Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to allow labels class values to be unique arbitrary strings. Also allows the unlabeled value to be the empty string. #39

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 92 additions & 23 deletions spade_anomaly_detection/csv_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,22 @@
import tensorflow as tf


# Types are from //cloud/ml/research/data_utils/feature_metadata.py
# Type name applied to every feature column when building column metadata.
_FEATURES_TYPE: Final[str] = 'FLOAT64'
# Type name of the label column as read from the source CSV files; also used
# as the key dtype of the tf.lookup remap table.
_SOURCE_LABEL_TYPE: Final[str] = 'STRING'
# Default value assigned to the label column when a row has no label.
_SOURCE_LABEL_DEFAULT_VALUE: Final[str] = '-1'
# Type name of the label column after remapping to integers.
_LABEL_TYPE: Final[str] = 'INT64'
# Canonical mapping from raw label values (ints or strings) to internal
# integer labels: 1 = positive, 0 = negative, -1 = unlabeled. The empty
# string maps to unlabeled because the data loader defaults a missing label
# to ''.
# NOTE(review): this module-level dict is also extended at runtime with the
# runner-configured positive/negative/unlabeled data values, so it is not a
# true constant and is shared across loader instances.
_STRING_TO_INTEGER_LABEL_MAP: dict[str | int, int] = {
    1: 1,
    0: 0,
    -1: -1,
    '': -1,
    '-1': -1,
    '0': 0,
    '1': 1,
    'positive': 1,
    'negative': 0,
    'unlabeled': -1,
}

# Setting the shuffle buffer size to 1M seems to be necessary to get the CSV
# reader to provide a diversity of data to the model.
Expand Down Expand Up @@ -167,12 +178,12 @@ def from_inputs_file(
raise ValueError(
f'Label column {label_column_name} not found in the header: {header}'
)
num_features = len(all_columns) - 1
features_types = [_FEATURES_TYPE] * len(all_columns)
column_names_dict = collections.OrderedDict(
zip(all_columns, features_types)
)
column_names_dict[label_column_name] = _SOURCE_LABEL_DEFAULT_VALUE
num_features = len(all_columns) - 1
return ColumnNamesInfo(
column_names_dict=column_names_dict,
header=header,
Expand Down Expand Up @@ -216,6 +227,13 @@ def __init__(self, runner_parameters: parameters.RunnerParameters):
self.runner_parameters.negative_data_value,
self.runner_parameters.unlabeled_data_value,
]
# Add any labels that are not already in the map.
_STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.positive_data_value] = 1
_STRING_TO_INTEGER_LABEL_MAP[self.runner_parameters.negative_data_value] = 0
_STRING_TO_INTEGER_LABEL_MAP[
self.runner_parameters.unlabeled_data_value
] = -1

# Construct a label remap from string labels to integers. The table is not
# necessary for the case when the labels are all integers. But instead of
# checking if the labels are all integers, we construct the table and use
Expand Down Expand Up @@ -286,7 +304,8 @@ def get_inputs_metadata(
)
# Get information about the columns.
column_names_info = ColumnNamesInfo.from_inputs_file(
csv_filenames[0], label_column_name
csv_filenames[0],
label_column_name,
)
logging.info(
'Obtained metadata for data with CSV prefix %s (number of features=%d)',
Expand Down Expand Up @@ -360,22 +379,19 @@ def filter_func(features: tf.Tensor, label: tf.Tensor) -> bool: # pylint: disab
@classmethod
def convert_str_to_int(cls, value: str) -> int:
  """Converts a raw label value (string or int) to its integer label.

  Args:
    value: The raw label value, e.g. '1', 'positive', or -1.

  Returns:
    The integer label that `value` maps to.

  Raises:
    ValueError: If `value` has no known integer mapping.
  """
  # Sentinel object rather than None: 0 and -1 are legitimate mapped values.
  unknown = object()
  mapped = _STRING_TO_INTEGER_LABEL_MAP.get(value, unknown)
  if mapped is unknown:
    raise ValueError(
        f'Label {value} of type {type(value)} is not a string integer or '
        'mappable to an integer.'
    )
  return mapped

@classmethod
def _get_label_remap_table(
cls, labels_mapping: dict[str, int]
) -> tf.lookup.StaticHashTable:
"""Returns a label remap table that converts string labels to integers."""
# The possible keys are '', '-1, '0', '1'. None is not included because the
# Data Loader will default to '' if the label is None.
keys_tensor = tf.constant(
list(labels_mapping.keys()),
dtype=tf.dtypes.as_dtype(_SOURCE_LABEL_TYPE.lower()),
Expand All @@ -390,6 +406,14 @@ def _get_label_remap_table(
)
return label_remap_table

def remap_label(self, label: str | tf.Tensor) -> int | tf.Tensor:
  """Converts a string label to its integer form via the remap table.

  Integer labels (or non-string tensors) pass through unchanged.
  """
  is_python_string = isinstance(label, str)
  is_string_tensor = (
      isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
  )
  # Only string-typed labels need the lookup; everything else is already int.
  if not (is_python_string or is_string_tensor):
    return label
  return self._label_remap_table.lookup(label)

def load_tf_dataset_from_csv(
self,
input_path: str,
Expand Down Expand Up @@ -441,6 +465,7 @@ def load_tf_dataset_from_csv(
self._last_read_metadata.column_names_info.column_names_dict.values()
)
]
logging.info('column_defaults: %s', column_defaults)

# Construct a single dataset out of multiple CSV files.
# TODO(sinharaj): Remove the determinism after testing.
Expand All @@ -456,7 +481,7 @@ def load_tf_dataset_from_csv(
na_value='',
header=True,
num_epochs=1,
shuffle=True,
shuffle=False,
shuffle_buffer_size=_SHUFFLE_BUFFER_SIZE,
shuffle_seed=self.runner_parameters.random_seed,
prefetch_buffer_size=tf.data.AUTOTUNE,
Expand All @@ -473,17 +498,9 @@ def load_tf_dataset_from_csv(
'created.'
)

def remap_label(label: str | tf.Tensor) -> int | tf.Tensor:
"""Remaps the label to an integer."""
if isinstance(label, str) or (
isinstance(label, tf.Tensor) and label.dtype == tf.dtypes.string
):
return self._label_remap_table.lookup(label)
return label

# The Dataset can have labels of type int or str. Cast them to int.
dataset = dataset.map(
lambda features, label: (features, remap_label(label)),
lambda features, label: (features, self.remap_label(label)),
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=True,
)
Expand Down Expand Up @@ -535,7 +552,6 @@ def combine_features_dict_into_tensor(
self._label_counts = {
k: v.numpy() for k, v in self.counts_by_label(dataset).items()
}
logging.info('Label counts: %s', self._label_counts)

return dataset

Expand All @@ -554,11 +570,11 @@ def counts_by_label(self, dataset: tf.data.Dataset) -> Dict[int, tf.Tensor]:

@tf.function
def count_class(
counts: Dict[int, int], # Keys are always strings.
counts: Dict[int, int],
batch: Tuple[tf.Tensor, tf.Tensor],
) -> Dict[int, int]:
_, labels = batch
# Keys are always strings.
labels = self.remap_label(labels)
new_counts: Dict[int, int] = counts.copy()
for i in self.all_labels:
# This function is called after the Dataset is constructed and the
Expand All @@ -582,6 +598,59 @@ def count_class(
)
return counts

def counts_by_original_label(
    self, dataset: tf.data.Dataset
) -> tuple[dict[int, tf.Tensor], dict[str, tf.Tensor]]:
  """Counts samples per original (un-remapped) label class in the dataset.

  The label map can contain both integer and string keys, so the dataset is
  reduced twice: once counting occurrences of the integer label values and
  once counting occurrences of the string label values.

  Args:
    dataset: A Dataset of (features, label) batches to count over.

  Returns:
    A tuple `(int_counts, str_counts)`: `int_counts` maps each integer label
    to its count and `str_counts` maps each string label to its count.
    (Fixed: the previous return annotation listed the two dicts in the
    wrong order relative to the returned tuple.)
  """
  all_int_labels = [l for l in self.all_labels if isinstance(l, int)]
  logging.info('all_int_labels: %s', all_int_labels)
  all_str_labels = [l for l in self.all_labels if isinstance(l, str)]
  logging.info('all_str_labels: %s', all_str_labels)

  @tf.function
  def count_original_class(
      counts: Dict[int | str, int],
      batch: Tuple[tf.Tensor, tf.Tensor],
  ) -> Dict[int | str, int]:
    # The key type of the initial state decides whether this trace counts
    # the integer labels or the string labels.
    keys_are_int = all(isinstance(k, int) for k in counts.keys())
    all_labels = all_int_labels if keys_are_int else all_str_labels
    _, labels = batch
    new_counts: Dict[int | str, int] = counts.copy()
    for label in all_labels:
      matches: tf.Tensor = tf.cast(labels == label, tf.int32)
      # Direct membership test instead of materializing the key list.
      if label in new_counts:
        new_counts[label] += tf.reduce_sum(matches)
      else:
        new_counts[label] = tf.reduce_sum(matches)
    return new_counts

  # Only the keys of the label map are needed here; the mapped values are
  # irrelevant to counting.
  int_labels = [k for k in _STRING_TO_INTEGER_LABEL_MAP if isinstance(k, int)]
  initial_int_state = {label: 0 for label in int_labels}
  if initial_int_state:
    int_counts = dataset.reduce(
        initial_state=initial_int_state, reduce_func=count_original_class
    )
  else:
    # No integer keys in the map: skip the reduce entirely.
    int_counts = {}
  str_labels = [k for k in _STRING_TO_INTEGER_LABEL_MAP if isinstance(k, str)]
  initial_str_state = {label: 0 for label in str_labels}
  str_counts = dataset.reduce(
      initial_state=initial_str_state, reduce_func=count_original_class
  )
  return int_counts, str_counts

def get_label_thresholds(self) -> Mapping[str, float]:
"""Computes positive and negative thresholds based on label ratios.

Expand Down
Loading
Loading