From f36c1c229dad893a7403c1521ecf57c9209f6791 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 1 May 2024 17:33:03 -0400 Subject: [PATCH] pass RequestPayer when calling head_object() on s3 client --- .../rastervision/aws_s3/s3_file_system.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py index cfaa2ef98..4c33deb78 100644 --- a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py +++ b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py @@ -187,8 +187,9 @@ def read_bytes(uri: str) -> bytes: bucket, key = S3FileSystem.parse_uri(uri) with io.BytesIO() as file_buffer: try: - file_size = s3.head_object( - Bucket=bucket, Key=key)['ContentLength'] + obj = s3.head_object( + Bucket=bucket, Key=key, RequestPayer=request_payer) + file_size = obj['ContentLength'] with progressbar(file_size, desc='Downloading') as bar: s3.download_fileobj( Bucket=bucket, @@ -263,7 +264,9 @@ def copy_from(src_uri: str, dst_path: str) -> None: request_payer = S3FileSystem.get_request_payer() bucket, key = S3FileSystem.parse_uri(src_uri) try: - file_size = s3.head_object(Bucket=bucket, Key=key)['ContentLength'] + obj = s3.head_object( + Bucket=bucket, Key=key, RequestPayer=request_payer) + file_size = obj['ContentLength'] with progressbar(file_size, desc=f'Downloading') as bar: s3.download_file( Bucket=bucket,