Skip to content

Commit

Permalink
Add support for wildcards in GCS URIs in CSV data loader (allows data…
Browse files Browse the repository at this point in the history
…_input_gcs_uri to have form 'gs://bucket/dir/prefix*.csv').

PiperOrigin-RevId: 656185270
  • Loading branch information
The spade_anomaly_detection Authors committed Jul 26, 2024
1 parent 53a1db0 commit 4b440cd
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
25 changes: 18 additions & 7 deletions spade_anomaly_detection/csv_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,27 +94,33 @@ def _list_files(
return filenames


def _parse_gcs_uri(gcs_uri: str) -> tuple[str, str]:
def _parse_gcs_uri(gcs_uri: str) -> tuple[str, str, str]:
"""Parses a GCS URI into bucket name, prefix and suffix.
Args:
gcs_uri: GCS URI to parse.
Returns:
Bucket name and prefix.
Bucket name, prefix and suffix.
Raises:
ValueError: If the GCS URI is not valid.
"""
gcs_uri_prefix = 'gs://'
if not gcs_uri.startswith(gcs_uri_prefix):
raise ValueError(f'GCS URI {gcs_uri} does not start with "gs://".')
# Paths must be to folders, not files.
gcs_uri = f'{gcs_uri}/' if not gcs_uri.endswith('/') else gcs_uri
gcs_uri = gcs_uri.removeprefix(gcs_uri_prefix)
bucket_name = gcs_uri.split('/')[0]
rest = gcs_uri.removeprefix(f'{bucket_name}/')
return bucket_name, rest
split = rest.split('*')
if len(split) == 1:
# Paths must be to folders, not files.
rest = f'{rest}/' if not rest.endswith('/') else rest
return bucket_name, rest, ''
elif len(split) == 2:
return bucket_name, split[0], split[1]
else:
raise ValueError(f"GCS URI {gcs_uri} has more than one wildcard ('*').")


@dataclasses.dataclass
Expand Down Expand Up @@ -218,6 +224,7 @@ def get_inputs_metadata(
self,
bucket_name: str,
location_prefix: str,
location_suffix: str,
label_column_name: str,
) -> 'InputFilesMetadata':
"""Gets information about the CSVs containing the input data.
Expand All @@ -226,14 +233,17 @@ def get_inputs_metadata(
bucket_name: Name of the GCS bucket where the CSV files are located.
location_prefix: The prefix of location of the CSV files, excluding any
trailing unique identifiers.
location_suffix: The suffix of location of the CSV files (e.g. '.csv').
label_column_name: The name of the label column.
Returns:
Return a InputFilesMetadata instance.
"""
# Get the names of the CSV files containing the input data.
csv_filenames = _list_files(
bucket_name=bucket_name, input_blob_prefix=location_prefix
bucket_name=bucket_name,
input_blob_prefix=location_prefix,
input_blob_suffix=location_suffix,
)
logging.info(
'Collecting metadata for %d files at %s',
Expand Down Expand Up @@ -316,12 +326,13 @@ def load_tf_dataset_from_csv(
Returns:
A tf.data.Dataset.
"""
bucket, prefix = _parse_gcs_uri(input_path)
bucket, prefix, suffix = _parse_gcs_uri(input_path)
# Since we are reading a new set of CSV files, we need to get the metadata
# again.
self._last_read_metadata = self.get_inputs_metadata(
bucket_name=bucket,
location_prefix=prefix,
location_suffix=suffix,
label_column_name=label_col_name,
)
logging.info('Last read metadata: %s', self._last_read_metadata)
Expand Down
40 changes: 29 additions & 11 deletions spade_anomaly_detection/csv_data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from spade_anomaly_detection import csv_data_loader
from spade_anomaly_detection import parameters
import tensorflow as tf

import tensorflow_datasets as tfds

import pytest
Expand Down Expand Up @@ -110,19 +109,38 @@ def setUp(self):

# Params to test: gcs_uri.
@parameterized.named_parameters(
("single_file", "gs://bucket/dir/file.csv", "bucket", "dir/file.csv/"),
("folder_no_slash", "gs://bucket/dir", "bucket", "dir/"),
("folder_with_slash", "gs://bucket/dir/", "bucket", "dir/"),
(
"single_file",
"gs://bucket/dir/file.csv",
"bucket",
"dir/file.csv/",
"",
),
("folder_no_slash", "gs://bucket/dir", "bucket", "dir/", ""),
("folder_with_slash", "gs://bucket/dir/", "bucket", "dir/", ""),
(
"folder_with_wildcard",
"gs://bucket/dir/file*.csv",
"bucket",
"dir/file",
".csv",
),
)
def test_parse_gcs_uri_returns_bucket_name_and_prefix(
self, gcs_uri, expected_bucket, expected_prefix
def test_parse_gcs_uri_returns_bucket_name_prefix_and_suffix(
self, gcs_uri, expected_bucket, expected_prefix, expected_suffix
):
bucket_name, prefix = csv_data_loader._parse_gcs_uri(gcs_uri=gcs_uri)
bucket_name, prefix, suffix = csv_data_loader._parse_gcs_uri(
gcs_uri=gcs_uri
)
self.assertEqual(bucket_name, expected_bucket)
self.assertEqual(prefix, expected_prefix)
self.assertEqual(suffix, expected_suffix)

def test_parse_gcs_uri_incorrect_uri_raises(self):
gcs_uri = "bucket/dir/"
@parameterized.named_parameters(
("incorrect_folder", "bucket/dir/"),
("too_many_wildcards", "gs://bucket/*/file*.csv")
)
def test_parse_gcs_uri_incorrect_uri_raises(self, gcs_uri):
with self.assertRaises(ValueError):
_, _ = csv_data_loader._parse_gcs_uri(gcs_uri=gcs_uri)

Expand Down Expand Up @@ -259,7 +277,7 @@ def test_load_tf_dataset_from_csv_returns_expected_dataset(
tmp_dir = self.create_tempdir("tmp")
input_path = os.path.join(tmp_dir.full_path, self.dir)
tf.io.gfile.makedirs(input_path)
mock_parse_gcs_uri.return_value = ("doesnt_matter", input_path)
mock_parse_gcs_uri.return_value = ("doesnt_matter", input_path, "")
mock_file_reader.return_value = [
os.path.join(tmp_dir.full_path, self.csv_file1),
os.path.join(tmp_dir.full_path, self.csv_file2),
Expand Down Expand Up @@ -410,7 +428,7 @@ def test_upload_dataframe_to_gcs(self):
np.repeat([0], len(features2))
.reshape(len(features2), 1)
.astype(np.int64)
) # Upload batch 1.
) # Upload batch 1.
data_loader.upload_dataframe_to_gcs(
batch=1,
features=features1,
Expand Down

0 comments on commit 4b440cd

Please sign in to comment.