Skip to content

Commit

Permalink
make S3FileSystem.list_paths() only list direct children
Browse files Browse the repository at this point in the history
This also makes it consistent with other FileSystems
  • Loading branch information
AdeelH committed May 2, 2024
1 parent 87d15e3 commit 537c28e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 45 deletions.
96 changes: 54 additions & 42 deletions rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Any, Iterator, Tuple
import io
import os
import subprocess
Expand All @@ -16,41 +16,38 @@


# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
def get_matching_s3_objects(
        bucket: str,
        prefix: str = '',
        suffix: str = '',
        delimiter: str = '/',
        request_payer: str = 'None') -> Iterator[tuple[str, Any]]:
    """Generate objects in an S3 bucket.

    Only direct children of ``prefix`` are listed (grouped via
    ``delimiter``), making this consistent with other FileSystems.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Only fetch objects whose key starts with this prefix.
        suffix: Only fetch objects whose keys end with this suffix.
        delimiter: Delimiter used by S3 to group keys. With the default
            ``'/'``, keys below a sub-"directory" are rolled up into
            ``CommonPrefixes`` instead of being listed individually.
        request_payer: Passed to the S3 API as ``RequestPayer``.

    Yields:
        ``(key, obj)`` tuples, where ``obj`` is the metadata dict returned
        by S3 — for "directories" (from ``CommonPrefixes``) and files
        (from ``Contents``).
    """
    s3 = S3FileSystem.get_client()
    kwargs = dict(
        Bucket=bucket,
        RequestPayer=request_payer,
        Delimiter=delimiter,
        Prefix=prefix,
    )
    while True:
        resp: dict = s3.list_objects_v2(**kwargs)
        # Default to an empty list (not {}) when the field is absent, to
        # match the list[dict] annotation; iteration behavior is the same.
        dirs: list[dict] = resp.get('CommonPrefixes', [])
        files: list[dict] = resp.get('Contents', [])
        for obj in dirs:
            key = obj['Prefix']
            if key.startswith(prefix) and key.endswith(suffix):
                yield key, obj
        for obj in files:
            key = obj['Key']
            if key.startswith(prefix) and key.endswith(suffix):
                yield key, obj
        # The S3 API is paginated, returning up to 1000 keys at a time.
        # Pass the continuation token into the next response, until we
        # reach the final page (when this field is missing).
        # NOTE(review): the pagination lines were collapsed in the diff;
        # reconstructed from the comment above and the original snippet —
        # confirm against the full file.
        try:
            kwargs['ContinuationToken'] = resp['NextContinuationToken']
        except KeyError:
            break


def get_matching_s3_keys(bucket: str,
                         prefix: str = '',
                         suffix: str = '',
                         delimiter: str = '/',
                         request_payer: str = 'None') -> Iterator[str]:
    """Generate the keys in an S3 bucket.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Only fetch keys that start with this prefix.
        suffix: Only fetch keys that end with this suffix.
        delimiter: Delimiter used by S3 to group keys. With the default
            ``'/'``, only direct children of ``prefix`` are listed.
        request_payer: Passed to the S3 API as ``RequestPayer``.

    Yields:
        Matching keys — both object keys and common prefixes
        ("directories").
    """
    obj_iterator = get_matching_s3_objects(
        bucket,
        prefix=prefix,
        suffix=suffix,
        delimiter=delimiter,
        request_payer=request_payer)
    # Drop the metadata dicts; callers only want the keys.
    return (key for key, _ in obj_iterator)


def progressbar(total_size: int, desc: str):
Expand Down Expand Up @@ -284,11 +291,16 @@ def last_modified(uri: str) -> datetime:
return head_data['LastModified']

@staticmethod
def list_paths(uri: str, ext: str = '',
               delimiter: str = '/') -> list[str]:
    """List the direct children of an S3 "directory" URI.

    Args:
        uri: s3:// URI of the prefix to list under.
        ext: If provided, only return paths ending with this suffix.
        delimiter: Delimiter used by S3 to group keys. With the default
            ``'/'``, only direct children are listed, consistent with
            other FileSystems.

    Returns:
        List of s3:// URIs.
    """
    request_payer = S3FileSystem.get_request_payer()
    parsed_uri = urlparse(uri)
    bucket = parsed_uri.netloc
    # Strip the leading '/' from the URI path to get the key prefix.
    # (The original single-argument os.path.join() was a no-op.)
    prefix = parsed_uri.path[1:]
    keys = get_matching_s3_keys(
        bucket,
        prefix,
        suffix=ext,
        delimiter=delimiter,
        request_payer=request_payer)
    paths = [os.path.join('s3://', bucket, key) for key in keys]
    return paths
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ def file_exists(uri, fs=None, include_dir=True) -> bool:
return fs.file_exists(uri, include_dir)


def list_paths(uri: str,
               ext: str = '',
               fs: Optional[FileSystem] = None,
               **kwargs) -> Optional[List[str]]:
    """List paths rooted at URI.

    Optionally only includes paths with a certain file extension.

    Args:
        uri: the URI of a directory.
        ext: the optional file extension to filter by.
        fs: if supplied, use fs instead of automatically chosen FileSystem
            for uri.
        **kwargs: extra kwargs to pass to fs.list_paths().

    Returns:
        List of paths, or None if uri is None.
    """
    # Preserve historical behavior: a None uri yields None, not an error.
    if uri is None:
        return None

    if not fs:
        fs = FileSystem.get_file_system(uri, 'r')

    return fs.list_paths(uri, ext=ext, **kwargs)


def upload_or_copy(src_path: str,
Expand Down

0 comments on commit 537c28e

Please sign in to comment.