From 537c28e2d0467fa54c63f2f9df3c6801847cc05b Mon Sep 17 00:00:00 2001
From: Adeel Hassan
Date: Thu, 25 Apr 2024 17:38:35 -0400
Subject: [PATCH] make S3FileSystem.list_paths() only list direct children

This also makes it consistent with other FileSystems.
---
 .../rastervision/aws_s3/s3_file_system.py    | 96 +++++++++++--------
 .../pipeline/file_system/utils.py            |  9 +-
 2 files changed, 60 insertions(+), 45 deletions(-)

diff --git a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
index f6d4ff854..cfaa2ef98 100644
--- a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
+++ b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Any, Iterator, Tuple
 import io
 import os
 import subprocess
@@ -16,41 +16,38 @@
 
 # Code from https://alexwlchan.net/2017/07/listing-s3-keys/
-def get_matching_s3_objects(bucket, prefix='', suffix='',
-                            request_payer='None'):
-    """
-    Generate objects in an S3 bucket.
-
-    :param bucket: Name of the S3 bucket.
-    :param prefix: Only fetch objects whose key starts with
-        this prefix (optional).
-    :param suffix: Only fetch objects whose keys end with
-        this suffix (optional).
+def get_matching_s3_objects(
+        bucket: str,
+        prefix: str = '',
+        suffix: str = '',
+        delimiter: str = '/',
+        request_payer: str = 'None') -> Iterator[tuple[str, Any]]:
+    """Generate (key, object) pairs for matching objects in an S3 bucket.
+
+    Args:
+        bucket: Name of the S3 bucket.
+        prefix: Only fetch objects whose key starts with this prefix.
+        suffix: Only fetch objects whose keys end with this suffix.
+        delimiter: Group keys by this delimiter so that only direct
+            children of the prefix are listed. Defaults to '/'.
+        request_payer: Passed to the S3 API as RequestPayer.
+            Defaults to 'None'.
     """
     s3 = S3FileSystem.get_client()
-    kwargs = {'Bucket': bucket, 'RequestPayer': request_payer}
-
-    # If the prefix is a single string (not a tuple of strings), we can
-    # do the filtering directly in the S3 API.
-    if isinstance(prefix, str):
-        kwargs['Prefix'] = prefix
-
+    kwargs = dict(
+        Bucket=bucket,
+        RequestPayer=request_payer,
+        Delimiter=delimiter,
+        Prefix=prefix,
+    )
     while True:
-
-        # The S3 API response is a large blob of metadata.
-        # 'Contents' contains information about the listed objects.
-        resp = s3.list_objects_v2(**kwargs)
-
-        try:
-            contents = resp['Contents']
-        except KeyError:
-            return
-
-        for obj in contents:
+        resp: dict = s3.list_objects_v2(**kwargs)
+        # 'CommonPrefixes' holds the grouped "directory" prefixes and
+        # 'Contents' the objects themselves; either may be missing.
+        dirs: list[dict] = resp.get('CommonPrefixes', [])
+        files: list[dict] = resp.get('Contents', [])
+        for obj in dirs:
+            key = obj['Prefix']
+            if key.startswith(prefix) and key.endswith(suffix):
+                yield key, obj
+        for obj in files:
             key = obj['Key']
             if key.startswith(prefix) and key.endswith(suffix):
-                yield obj
-
+                yield key, obj
         # The S3 API is paginated, returning up to 1000 keys at a time.
         # Pass the continuation token into the next response, until we
         # reach the final page (when this field is missing).
@@ -60,16 +57,26 @@ def get_matching_s3_objects(bucket, prefix='', suffix='',
             break
 
 
-def get_matching_s3_keys(bucket, prefix='', suffix='', request_payer='None'):
-    """
-    Generate the keys in an S3 bucket.
+def get_matching_s3_keys(bucket: str,
+                         prefix: str = '',
+                         suffix: str = '',
+                         delimiter: str = '/',
+                         request_payer: str = 'None') -> Iterator[str]:
+    """Generate the keys in an S3 bucket.
 
-    :param bucket: Name of the S3 bucket.
-    :param prefix: Only fetch keys that start with this prefix (optional).
-    :param suffix: Only fetch keys that end with this suffix (optional).
+    Args:
+        bucket: Name of the S3 bucket.
+        prefix: Only fetch keys that start with this prefix.
+        suffix: Only fetch keys that end with this suffix.
+        delimiter: Group keys by this delimiter so that only direct
+            children of the prefix are listed. Defaults to '/'.
""" - for obj in get_matching_s3_objects(bucket, prefix, suffix, request_payer): - yield obj['Key'] + obj_iterator = get_matching_s3_objects( + bucket, + prefix=prefix, + suffix=suffix, + delimiter=delimiter, + request_payer=request_payer) + out = (key for key, _ in obj_iterator) + return out def progressbar(total_size: int, desc: str): @@ -284,11 +291,16 @@ def last_modified(uri: str) -> datetime: return head_data['LastModified'] @staticmethod - def list_paths(uri, ext=''): + def list_paths(uri: str, ext: str = '', delimiter: str = '/') -> list[str]: request_payer = S3FileSystem.get_request_payer() parsed_uri = urlparse(uri) bucket = parsed_uri.netloc prefix = os.path.join(parsed_uri.path[1:]) keys = get_matching_s3_keys( - bucket, prefix, suffix=ext, request_payer=request_payer) - return [os.path.join('s3://', bucket, key) for key in keys] + bucket, + prefix, + suffix=ext, + delimiter=delimiter, + request_payer=request_payer) + paths = [os.path.join('s3://', bucket, key) for key in keys] + return paths diff --git a/rastervision_pipeline/rastervision/pipeline/file_system/utils.py b/rastervision_pipeline/rastervision/pipeline/file_system/utils.py index 8e43a3dd3..5048d3053 100644 --- a/rastervision_pipeline/rastervision/pipeline/file_system/utils.py +++ b/rastervision_pipeline/rastervision/pipeline/file_system/utils.py @@ -219,8 +219,10 @@ def file_exists(uri, fs=None, include_dir=True) -> bool: return fs.file_exists(uri, include_dir) -def list_paths(uri: str, ext: str = '', - fs: Optional[FileSystem] = None) -> List[str]: +def list_paths(uri: str, + ext: str = '', + fs: Optional[FileSystem] = None, + **kwargs) -> List[str]: """List paths rooted at URI. Optionally only includes paths with a certain file extension. @@ -230,6 +232,7 @@ def list_paths(uri: str, ext: str = '', ext: the optional file extension to filter by fs: if supplied, use fs instead of automatically chosen FileSystem for uri + **kwargs: extra kwargs to pass to fs.list_paths(). """ if uri is None: return None @@ -237,7 +240,7 @@ def list_paths(uri: str, ext: str = '', if not fs: fs = FileSystem.get_file_system(uri, 'r') - return fs.list_paths(uri, ext=ext) + return fs.list_paths(uri, ext=ext, **kwargs) def upload_or_copy(src_path: str,