Commit ed9b6ba

add logic for handling large file uploads to s3
jakkdl committed Jan 26, 2024
1 parent 2894d20 commit ed9b6ba
Showing 2 changed files with 90 additions and 11 deletions.
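The commit threads a new max_part_size keyword through open() into MultipartWriter, capping how much data is buffered in memory before each part of a multipart upload is sent. A minimal usage sketch based on the new signature follows; the bucket and key names and the sizes are placeholders, not part of the commit:

import smart_open.s3

# "my-bucket"/"big-file.bin" are hypothetical; credentials and endpoint are
# whatever boto3 picks up from the environment.
with smart_open.s3.open(
    "my-bucket", "big-file.bin", "wb",
    min_part_size=50 * 1024**2,    # do not send parts smaller than 50 MiB
    max_part_size=100 * 1024**2,   # never buffer more than 100 MiB per part
) as fout:
    # A single 300 MiB write is buffered in max_part_size slices, so it is
    # uploaded as three 100 MiB parts rather than one oversized part.
    fout.write(bytes(300 * 1024**2))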
65 changes: 54 additions & 11 deletions smart_open/s3.py
@@ -6,12 +6,14 @@
# from the MIT License (MIT).
#
"""Implements file-like objects for reading and writing from/to AWS S3."""
from __future__ import annotations

import io
import functools
import logging
import time
import warnings
from typing import TYPE_CHECKING

try:
    import boto3
@@ -27,13 +29,22 @@

from smart_open import constants

if TYPE_CHECKING:
    from mypy_boto3_s3.client import S3Client
    from typing_extensions import Buffer

logger = logging.getLogger(__name__)

DEFAULT_MIN_PART_SIZE = 50 * 1024**2
"""Default minimum part size for S3 multipart uploads"""
MIN_MIN_PART_SIZE = 5 * 1024 ** 2
"""The absolute minimum permitted by Amazon."""

DEFAULT_MAX_PART_SIZE = 5 * 1024**3
"""Default maximum part size for S3 multipart uploads"""
MAX_MAX_PART_SIZE = 5 * 1024 ** 3
"""The absolute maximum permitted by Amazon."""

SCHEMES = ("s3", "s3n", 's3u', "s3a")
DEFAULT_PORT = 443
DEFAULT_HOST = 's3.amazonaws.com'
@@ -247,6 +258,7 @@ def open(
    client=None,
    client_kwargs=None,
    writebuffer=None,
    max_part_size=DEFAULT_MAX_PART_SIZE
):
    """Open an S3 object for reading or writing.
@@ -317,6 +329,7 @@
                client=client,
                client_kwargs=client_kwargs,
                writebuffer=writebuffer,
                max_part_size=max_part_size,
            )
        else:
            fileobj = SinglepartWriter(
@@ -779,14 +792,28 @@ def __init__(
        min_part_size=DEFAULT_MIN_PART_SIZE,
        client=None,
        client_kwargs=None,
        writebuffer: io.BytesIO | None = None,
        max_part_size=DEFAULT_MAX_PART_SIZE
    ):
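        # The checks below only warn, rather than raise, when the requested
        # part sizes fall outside what S3 accepts; an inverted pair is
        # resolved by clamping min_part_size down to max_part_size.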
        if min_part_size < MIN_MIN_PART_SIZE:
            logger.warning(f"min_part_size set to {min_part_size}; "
                           "S3 requires minimum part size >= 5MiB; "
                           "multipart upload may fail")
        if max_part_size > MAX_MAX_PART_SIZE:
            logger.warning(f"max_part_size set to {max_part_size}; "
                           "S3 requires maximum part size <= 5GiB; "
                           "multipart upload may fail")
        if max_part_size < min_part_size:
            logger.warning(f"max_part_size {max_part_size} smaller than "
                           f"min_part_size {min_part_size}. "
                           "Setting min_part_size to max_part_size")
            min_part_size = max_part_size
            # Raise error instead?
        self._min_part_size = min_part_size
        self._max_part_size = max_part_size

        _initialize_boto3(self, client, client_kwargs, bucket, key)
        self._client: S3Client
        self._bucket: str
        self._key: str

        try:
            partial = functools.partial(
@@ -809,12 +836,12 @@ def __init__(

        self._total_bytes = 0
        self._total_parts = 0
        self._parts: list[dict[str, object]] = []

        #
        # This member is part of the io.BufferedIOBase interface.
        #
        self.raw = None  # type: ignore[assignment]

    def flush(self):
        pass
@@ -890,22 +917,38 @@ def tell(self):
    def detach(self):
        raise io.UnsupportedOperation("detach() not supported")

    def write(self, b: Buffer) -> int:
        """Write the given buffer (bytes, bytearray, memoryview or any buffer
        interface implementation) to the S3 file.

        For more information about buffers, see https://docs.python.org/3/c-api/buffer.html

        There's buffering happening under the covers, so this may not actually
        do any HTTP transfer right away."""
        # Part size: 5 MiB to 5 GiB. There is no minimum size limit on the
        # last part of your multipart upload.

        # botocore does not accept memoryview, otherwise we could've gotten
        # away with not needing to write a copy to the buffer aside from
        # cases where b is smaller than min_part_size.

        i = 0
        mv = memoryview(b)
        self._total_bytes += len(mv)

        while i < len(mv):
            start = i
            end = i + self._max_part_size - self._buf.tell()

            self._buf.write(mv[start:end])

            if self._buf.tell() < self._min_part_size:
                assert end >= len(mv)
                return len(mv)

            if self._buf.tell() >= self._min_part_size:
                self._upload_next_part()
            i += end - start

        return len(mv)

    def terminate(self):
        """Cancel the underlying multipart upload."""
@@ -928,7 +971,7 @@ def to_boto3(self, resource):
#
    #
    # Internal methods.
    #
    def _upload_next_part(self) -> None:
        part_num = self._total_parts + 1
        logger.info(
            "%s: uploading part_num: %i, %i bytes (total %.3fGB)",
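For reference, the part-splitting behaviour that write() introduces above can be sketched in isolation. The helper below is hypothetical (it is not part of smart_open) and only mirrors the buffer-and-flush arithmetic: a part is flushed once the buffer reaches min_part, and a single fill never takes the buffer past max_part.

def split_into_parts(data: bytes, buffered: int, min_part: int, max_part: int):
    """Return the part sizes write() would flush plus the bytes left buffered."""
    parts = []
    i = 0
    while i < len(data):
        take = min(len(data) - i, max_part - buffered)  # top the buffer up
        buffered += take
        i += take
        if buffered >= min_part:
            parts.append(buffered)  # this would become one uploaded part
            buffered = 0
    return parts, buffered  # the leftover stays buffered until close()

# Mirrors test_max_part_size_2 below: a 15 MiB write with a 5 MiB cap
# yields three 5 MiB parts and an empty buffer.
assert split_into_parts(bytes(15 * 2**20), 0, 5 * 2**20, 5 * 2**20) == ([5 * 2**20] * 3, 0)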
36 changes: 36 additions & 0 deletions smart_open/tests/test_s3.py
@@ -565,6 +565,42 @@ def test_writebuffer(self):

        assert actual == contents

    def test_max_part_size_1(self) -> None:
        """Write successive chunks of 5 MiB - 1 bytes with min_part_size=5 MiB
        and max_part_size=7 MiB.

        There is no minimum size limit on the last part of a multipart upload,
        which is why test_write03 can get away with small test data; but since
        this test needs multiple parts, it cannot avoid larger chunks."""
        contents = bytes(5 * 2**20 - 1)

        with smart_open.s3.open(
            BUCKET_NAME, WRITE_KEY_NAME, 'wb',
            min_part_size=5 * 2**20, max_part_size=7 * 2**20,
        ) as fout:
            fout.write(contents)
            assert fout._total_parts == 0
            assert fout._buf.tell() == 5 * 2**20 - 1

            fout.write(contents)
            assert fout._total_parts == 1
            assert fout._buf.tell() == 3 * 2**20 - 2

            fout.write(contents)
            assert fout._total_parts == 2
            assert fout._buf.tell() == 1 * 2**20 - 3
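            # These buffer positions follow from the chunking arithmetic: the
            # second write first tops the buffered 5 MiB - 1 bytes up to the
            # 7 MiB ceiling (2 MiB + 1 bytes, flushing part 1) and then keeps
            # the remaining 3 MiB - 2 bytes; the third write repeats this and
            # leaves 1 MiB - 3 bytes buffered.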
        contents = b''

        output = list(smart_open.s3.open(BUCKET_NAME, WRITE_KEY_NAME, "rb"))
        assert len(output[0]) == 3 * (5 * 2**20 - 1)

    def test_max_part_size_2(self) -> None:
        """Do a single big write of 15 MiB with a max_part_size of 5 MiB."""
        contents = bytes(15 * 2**20)

        with smart_open.s3.open(
            BUCKET_NAME, WRITE_KEY_NAME, 'wb',
            min_part_size=5 * 2**20, max_part_size=5 * 2**20,
        ) as fout:
            fout.write(contents)
            assert fout._total_parts == 3
            assert fout._buf.tell() == 0
        contents = b''

        output = list(smart_open.s3.open(BUCKET_NAME, WRITE_KEY_NAME, "rb"))
        assert len(output[0]) == 15 * 2**20


@moto.mock_s3
class SinglepartWriterTest(unittest.TestCase):
