diff --git a/tensorizer/_futuregroup.py b/tensorizer/_futuregroup.py new file mode 100644 index 0000000..ac0210c --- /dev/null +++ b/tensorizer/_futuregroup.py @@ -0,0 +1,56 @@ +import concurrent.futures +from typing import Any, List, Optional, Sequence, Union + + +class _FutureGroup: + def __init__(self, futures): # type: (Sequence[_Future]) -> None + self.futures = futures + + def cancel(self) -> bool: + result = True + for f in self.futures: + result = result and f.cancel() + return result + + def cancelled(self) -> bool: + return all(f.cancelled() for f in self.futures) + + def running(self) -> bool: + return any(f.running() for f in self.futures) + + def done(self) -> bool: + return all(f.done() for f in self.futures) + + def result(self, timeout=None) -> None: + _future_wait_and_raise(self.futures, timeout=timeout) + + def exception(self, timeout=None) -> Optional[BaseException]: + for f in self.futures: + exc = f.exception(timeout=timeout) + if exc: + return exc + return None + + +_Future = Union[_FutureGroup, concurrent.futures.Future] + + +def _future_wait_and_raise(futures: Sequence[_Future], timeout=None) -> None: + # Wait on a list of futures with a timeout. Raise any exceptions, including TimeoutErrors. + + flattened_futures: List[concurrent.futures.Future] = [] + futures = list(futures) + while futures: + f = futures.pop() + if isinstance(f, _FutureGroup): + futures.extend(f.futures) + else: + flattened_futures.append(f) + + fs = concurrent.futures.wait(flattened_futures, timeout=timeout) + for f in fs.done: + # if the future has an exception, this will raise it + f.result() + for f in fs.not_done: + # force raise of TimeoutError + f.result(0) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 7b3b182..b88f483 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -4,9 +4,11 @@ ############################################################################## import abc import bisect +import collections import collections.abc import concurrent.futures import contextlib +import ctypes import dataclasses import enum import functools @@ -61,6 +63,11 @@ from tensorizer._crypt._cgroup_cpu_count import ( effective_cpu_count as _effective_cpu_count, ) +from tensorizer._futuregroup import ( + _Future, + _future_wait_and_raise, + _FutureGroup, +) from tensorizer._internal_utils import Chunked as _Chunked from tensorizer._internal_utils import _variable_read from tensorizer._NumpyTensor import _NumpyTensor @@ -157,7 +164,13 @@ class TensorType(IntEnum): # Current version -TENSORIZER_VERSION = 4 +TENSORIZER_VERSION = 5 + +HEADERS_AT_TOP_TENSORIZER_VERSION = 5 + +# The hashable_segment_views used to include the fields that include the hash results themselves. +# These fields were zero when computing hash +HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION = 5 # To serialize meta tensors into metadata-only tensors # that deserialize back into zeroed-out buffers, data version 4 is required. 
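The new _FutureGroup helper lets a set of related futures (for example, the CRC32 and SHA-256 tasks submitted for one tensor) be tracked as a single awaitable unit, and _future_wait_and_raise flattens nested groups before waiting so that any exception, including a timeout, is re-raised. A minimal usage sketch, assuming the module is importable as tensorizer._futuregroup; the worker functions are hypothetical stand-ins:

    import concurrent.futures

    from tensorizer._futuregroup import _FutureGroup, _future_wait_and_raise


    def hash_chunk(i: int) -> int:
        # Stand-in for a CPU-bound task such as hashing one chunk
        return i * i


    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        hash_futures = [pool.submit(hash_chunk, i) for i in range(4)]
        io_future = pool.submit(print, "writing")  # stand-in for an I/O task

        # A group behaves like one future; groups may nest and are flattened
        # before waiting.
        group = _FutureGroup([_FutureGroup(hash_futures), io_future])
        group.result(timeout=5)  # blocks on every member, re-raising any exception
        assert group.done()

        # Equivalent module-level form, matching how the serializer drains
        # its _jobs list:
        _future_wait_and_raise([group], timeout=5)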
@@ -221,7 +234,9 @@ def deserialized_length(self): if self.data_length > 0: return self.data_length element_size: int = numpy.dtype(self.dtype).itemsize - num_elements: int = numpy.product(self.shape) + num_elements: int = int( + numpy.product(self.shape) + ) # numpy.product([]) == 1.0 return element_size * num_elements @@ -397,17 +412,27 @@ def __init__( dtype: bytes, shape: Sequence[int], data_length: int, - file_offset: int, + file_offset: int, # location of header in file include_crc32: bool = True, include_sha256: bool = True, crypt_info: Optional[_crypt_info.CryptInfo] = None, ): + self.module_index = module_index + self.tensor_type = tensor_type + self.name = name + self.shape = shape + self.dtype = dtype + self.data_length = data_length + self.file_offset = file_offset + self.include_crc32 = include_crc32 + self.include_sha256 = include_sha256 + # Calculate the variable length segment - name_len = len(name) - dtype_len = len(dtype) + self.name_len = len(name) + self.dtype_len = len(dtype) # NB: shape_len is the number of dimensions, # not the encoded byte length - shape_len = len(shape) + self.shape_len = len(shape) self.crypt_info = crypt_info if crypt_info is None: crypt_info_len = 0 @@ -415,9 +440,9 @@ def __init__( crypt_info_len = crypt_info.sized_size self.variable_length_segment = struct.Struct( self.variable_length_segment_template.format( - name_len=name_len, - dtype_len=dtype_len, - shape_len=shape_len, + name_len=self.name_len, + dtype_len=self.dtype_len, + shape_len=self.shape_len, ) ) crc32_len = sha256_len = self.hash_count = 0 @@ -438,7 +463,7 @@ def __init__( self.sha256_hash_offset, self.crypt_info_offset, self.data_length_offset, - self.data_offset, + self.size, ) = itertools.accumulate( ( self.start_segment.size, @@ -450,25 +475,28 @@ def __init__( self.data_length_segment.size, ) ) - self.size = self.data_offset + + def build(self, tensor_data_offset: int): + # tensor_data_offset: location of tensor data in file + self.data_offset = tensor_data_offset self.buffer = bytearray(self.size) self.start_segment.pack_into( self.buffer, 0, # Offset self.size, # Tensor header size - module_index, # Module index. - tensor_type.value, # Whether this is a parameter or a buffer - name_len, # Parameter/buffer name length + self.module_index, # Module index. 
+ self.tensor_type.value, # Whether this is a parameter or a buffer + self.name_len, # Parameter/buffer name length ) self.variable_length_segment.pack_into( self.buffer, self.variable_length_offset, - name, # Parameter/buffer name UTF-8 bytes - dtype_len, # Tensor dtype length - dtype, # Tensor dtype UTF-8 bytes - shape_len, # Tensor shape length - *shape, # Tensor shape I array + self.name, # Parameter/buffer name UTF-8 bytes + self.dtype_len, # Tensor dtype length + self.dtype, # Tensor dtype UTF-8 bytes + self.shape_len, # Tensor shape length + *self.shape, # Tensor shape I array ) after_hashes = self.crypt_info_offset @@ -481,46 +509,46 @@ def __init__( ) # Placeholders - if include_crc32: + if self.include_crc32: self.add_crc32(0) - if include_sha256: + if self.include_sha256: self.add_sha256(b"") - if crypt_info is not None: - crypt_info.sized_pack_into(self.buffer, self.crypt_info_offset) + if self.crypt_info is not None: + self.crypt_info.sized_pack_into(self.buffer, self.crypt_info_offset) self.data_length_segment.pack_into( - self.buffer, self.data_length_offset, data_length + self.buffer, self.data_length_offset, self.data_length ) - metadata_entry_segment: struct.Struct = struct.Struct( - self.metadata_entry_segment_template.format( - name_len=name_len, - dtype_len=dtype_len, - shape_len=shape_len, - ) - ) + metadata_entry_segment = self.get_metadata_entry_segment() self.metadata_entry = metadata_entry_segment.pack( - name_len, # Name length - name, # Name - tensor_type.value, # Whether this is a parameter or a buffer - dtype_len, # Dtype length - dtype, # Dtype - shape_len, # Shape length - *shape, # Shape - file_offset, # Header start (relative to the file) + self.name_len, # Name length + self.name, # Name + self.tensor_type.value, # Whether this is a parameter or a buffer + self.dtype_len, # Dtype length + self.dtype, # Dtype + self.shape_len, # Shape length + *self.shape, # Shape + self.file_offset, # Header start (relative to the file) # Tensor data start (relative to the file): - file_offset + self.data_offset, - data_length, # Tensor length + self.data_offset, + self.data_length, # Tensor length + ) + + def get_metadata_entry_segment(self) -> struct.Struct: + return struct.Struct( + self.metadata_entry_segment_template.format( + name_len=self.name_len, + dtype_len=self.dtype_len, + shape_len=self.shape_len, + ) ) def _hashable_segment_views(self): - if self.crypt_info is None: - yield memoryview(self.buffer) - else: - yield memoryview(self.buffer)[: self.crypt_info_offset] - # Skip crypt_info - yield memoryview(self.buffer)[self.data_length_offset :] + # Skip areas where we store hashes and crypt_info + yield memoryview(self.buffer)[: self.hash_header_offset] + yield memoryview(self.buffer)[self.data_length_offset :] def compute_crc32(self) -> int: crc32 = 0 @@ -599,6 +627,7 @@ class _TensorHeaderDeserializer: @classmethod def from_io( cls, + file_version: int, reader: io.BufferedIOBase, zero_hashes: bool = True, check_crypt_info: bool = False, @@ -616,15 +645,20 @@ def from_io( with memoryview(buffer) as mv: reader.readinto(mv[offset:]) return cls( - buffer, zero_hashes=zero_hashes, check_crypt_info=check_crypt_info + file_version, + buffer, + zero_hashes=zero_hashes, + check_crypt_info=check_crypt_info, ) def __init__( self, + file_version: int, buffer: bytearray, zero_hashes: bool = True, check_crypt_info: bool = False, ): + self.file_version = file_version self.buffer = buffer offset = self.header_len_segment.size self.module_idx, tensor_type = 
self.tensor_info_segment.unpack_from( @@ -652,6 +686,7 @@ def __init__( self.shape, offset = self.read_shape(buffer, offset) # Read our hashes in. + hash_start = offset hashes_slice, offset = self.read_hash_block(buffer, offset) with hashes_slice: self.hashes = self._decode_hashes(hashes_slice) @@ -663,23 +698,37 @@ def __init__( crypt_info_slice, offset = self.read_crypt_info_block( buffer, offset ) - self._hashable_segments = ( - slice(None, crypt_info_start), - slice(offset, None), - ) with crypt_info_slice: self.crypt_info = _crypt_info.CryptInfo.unpack_from( crypt_info_slice ) + if ( + self.file_version + < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION + ): + self._hashable_segments = ( + slice(None, crypt_info_start), + slice(offset, None), + ) else: self.crypt_info = None - self._hashable_segments = (slice(None, None),) # Finally, get the tensor data length. - offset = len(buffer) - self.data_length_segment.size + data_length_start = offset = len(buffer) - self.data_length_segment.size self.data_length = self.data_length_segment.unpack_from(buffer, offset)[ 0 ] + if ( + self.file_version + < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION + ): + if not check_crypt_info: + self._hashable_segments = (slice(None, None),) + else: + self._hashable_segments = ( + slice(None, hash_start), + slice(data_length_start, None), + ) def _hashable_segment_views(self): for segment_slice in self._hashable_segments: @@ -1656,13 +1705,6 @@ def __init__( "Tensor is encrypted, but decryption was not requested" ) - # The total size of the file. - # WARNING: this is not accurate. This field isn't used in the - # deserializer, but has been available as a public attribute, - # so it is kept how it was for compatibility until the next - # major version. - self.total_file_bytes = self._file_header.tensor_size - # Read the metadata index of tensors. # This is a list of offsets into the file where the per-tensor data # is stored. @@ -1674,6 +1716,31 @@ def __init__( ) if not self._metadata: raise ValueError("Tensor index in the file is empty") + + self._headers: Optional[ + Dict[_TensorPath, _TensorHeaderDeserializer] + ] + if version_number >= HEADERS_AT_TOP_TENSORIZER_VERSION: + metadata_ordered = sorted( + self._metadata.values(), key=operator.attrgetter("offset") + ) + self._headers = {} + for entry in metadata_ordered: + if self._file.tell() > entry.offset: + raise ValueError("Header offsets overlap or are wrong") + self._file.seek(entry.offset) + header = _TensorHeaderDeserializer.from_io( + version_number, + self._file, + zero_hashes=True, + check_crypt_info=self._has_crypt_info, + ) + if header is None: + raise ValueError("Unexpected empty header") + self._headers[entry.name] = header + else: + self._headers = None + # filter_func is a test that determines the tensor names to read. # If filter_func is None, all tensors are read. 
if filter_func is not None: @@ -2243,7 +2310,7 @@ def keys(self): def _verify_hashes( self, - name: str, + name: _TensorPath, hashes: Iterable[TensorHash], header_hashes: Dict[HashType, Any], mv: Union[memoryview, bytes], @@ -2858,19 +2925,24 @@ def _copy_thread( if halt: break - header = _TensorHeaderDeserializer.from_io( - file_, - zero_hashes=True, - check_crypt_info=unsafe_self._has_crypt_info, - ) - - if header is None: - raise ValueError("Unexpected empty header") + if unsafe_self._headers is None: + header = _TensorHeaderDeserializer.from_io( + unsafe_self._file_header.version_number, + file_, + zero_hashes=True, + check_crypt_info=unsafe_self._has_crypt_info, + ) + if header is None: + raise KeyError("Unexpected empty header") - # Skip it if this tensor is not one we're supposed to load - if header.name not in tensor_sizes_by_name: - file_.seek(header.data_length, io.SEEK_CUR) - continue + # Skip it if this tensor is not one we're supposed to load + if header.name not in tensor_sizes_by_name: + file_.seek(header.data_length, io.SEEK_CUR) + continue + else: + header = unsafe_self._headers[ + tensor_items[tensors_read].name + ] numpy_dtype, *torch_dtype = header.dtype.split(OPAQUE_DTYPE_SEP) if not torch_dtype: @@ -2934,6 +3006,7 @@ def _copy_thread( if not is_meta: start = time.perf_counter_ns() if _perf_stats else 0 + file_.seek(unsafe_self._metadata[header.name].data_offset) if unsafe_self._encrypted and mv.nbytes > 0: TensorDeserializer._stream_decrypt( @@ -3266,6 +3339,7 @@ def __init__( *, encryption: Optional[EncryptionParams] = None, limit_cpu_concurrency: Optional[int] = None, + max_tensors: Optional[int] = None, ) -> None: if isinstance(file_obj, (str, bytes, os.PathLike, int)): self._file = stream_io.open_stream(file_obj, "wb+") @@ -3397,7 +3471,7 @@ def __init__( # to each pool in the same relative order. # Tracks work submitted to all pools to wait for pending work to finish. - self._jobs: List[concurrent.futures.Future] = [] + self._jobs: List[_Future] = [] # Tracks work submitted to the decryption pool to prevent conflicting, # overlapping in-place operations on tensors using shared storage. self._decryption_jobs: typing.MutableMapping[ @@ -3415,16 +3489,7 @@ def __init__( self._write(TENSORIZER_MAGIC) # Write file header metadata - if not self._encrypted: - # Can't tell if OPAQUE_TENSORIZER_VERSION - # or META_TENSOR_TENSORIZER_VERSION are needed - # until a tensor is written later with an opaque dtype - # or from the meta device, - # so assume it is compatible with version 1 until then. - version_number = NON_OPAQUE_TENSORIZER_VERSION - else: - # File encryption requires a newer tensorizer version - version_number = ENCRYPTION_TENSORIZER_VERSION + version_number = HEADERS_AT_TOP_TENSORIZER_VERSION feature_flags = _FileFeatureFlags(0) if self._encrypted: feature_flags |= _FileFeatureFlags.encrypted @@ -3437,14 +3502,26 @@ def __init__( ) self._write(self._file_header.to_bytes()) - # Reserve 256 KiB for metadata. - metadata_size = 256 * 1024 - self._write(struct.pack(" None: """ - Serializes a tensor, laying things out so that it can be read in three - calls from the input -- once for the size, once for the header, and - once for the tensor itself. + Serializes a tensor. Header data is appended to the metadata block at the top of the file, and + tensor data is appended to the bottom of the file. Args: idx: The index of the tensor in the module. 
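With headers now collected in a block at the top of the file, the serializer no longer reserves a fixed 256 KiB for metadata; the header/metadata region is laid out using the new max_tensors argument when provided, and overflowing it raises the "Metadata block is full. Increase max_tensors" error seen further down. A hedged usage sketch, with a hypothetical file name and model:

    import torch

    from tensorizer import TensorSerializer

    model = torch.nn.Linear(16, 16)

    # Size the header/metadata block up front for the expected tensor count
    serializer = TensorSerializer(
        "model.tensors", max_tensors=len(model.state_dict())
    )
    serializer.write_module(model)
    serializer.close()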
@@ -3684,6 +3755,7 @@ def write_tensor( Serialization format: + in header block near top of file: { uint64 header_sz, uint16 module_idx, uint8 type, @@ -3700,529 +3772,158 @@ def write_tensor( uint8 hash_sz, []char hash_str, } hashes, - uint64 tensor_sz, - []byte tensor } - """ - self._write_tensor( - idx=idx, name=name, tensor_type=tensor_type, tensor=tensor - ) - - def _write_tensor( - self, - idx, - name: Union[_TensorPath, str], - tensor_type: TensorType, - tensor: Union[torch.Tensor, numpy.ndarray], - *, - _synchronize: bool = True, - _start_pos: Optional[int] = None, - _temporary_buffer: bool = False, - ) -> int: - """ - Underlying implementation for `write_tensor()`, - providing additional controls for asynchronous writes - - Args: - idx: The index of the tensor in the module. - name: The name of the tensor. - tensor_type: The type of the tensor. This is used to determine - how to interpret the tensor. - tensor: The tensor to serialize. - _synchronize: Whether to synchronize metadata after writing - and ensure that all data is written to the file before - the call returns. If false, data may continue to be written - asynchronously even after this call returns. - _start_pos: - Where in the file to write the tensor entry. If not specified, - writes starting at the current file offset. + uint64 tensor_sz} + ..... + affer all headers, bottom of file: + {[]byte tensor } """ - self._path_registry.register_path(name) - if isinstance(tensor, torch.Tensor): - shape: Sequence[int] = tensor.size() - has_data: bool = not tensor.is_meta - if has_data: - if not tensor.is_contiguous(): - _temporary_buffer = True - numpy_tensor = _NumpyTensor.from_tensor(tensor.contiguous()) - else: - self._file_header.version_number = max( - META_TENSOR_TENSORIZER_VERSION, - self._file_header.version_number, - ) - _temporary_buffer = True - hollow_tensor = torch.empty( - (0,) * tensor.ndim, device="cpu", dtype=tensor.dtype - ) - numpy_tensor = _NumpyTensor.from_tensor(hollow_tensor) - else: - shape: Sequence[int] = tensor.shape - has_data: bool = True - if ( - isinstance(tensor, numpy.ndarray) - and not tensor.flags.c_contiguous - and hasattr(numpy, "ascontiguousarray") - ): - numpy_tensor = _NumpyTensor.from_array( - numpy.ascontiguousarray(tensor) - ) - _temporary_buffer = True - else: - numpy_tensor = _NumpyTensor.from_array(tensor) + if isinstance(tensor, numpy.ndarray): + tensor = torch.from_numpy(tensor) - dtype_name = numpy_tensor.numpy_dtype - if numpy_tensor.is_opaque: - # The datatype name needs to contain both the numpy dtype that the - # data is serialized as and the original torch dtype. 
- dtype_name += OPAQUE_DTYPE_SEP + numpy_tensor.torch_dtype - self._file_header.version_number = max( - OPAQUE_TENSORIZER_VERSION, - self._file_header.version_number, - ) - - tensor: numpy.ndarray = numpy_tensor.data - tensor_memory: memoryview = numpy_tensor.data.data - tensor_size: int = tensor.nbytes - if tensor_memory.nbytes != tensor_size: - raise ValueError( - f"Cannot serialize tensor {name!r}:" - f" buffer size of underlying memory ({tensor_memory.nbytes})" - f" doesn't match reported size ({tensor_size})" - ) - if isinstance(name, str): - name_bytes: bytes = name.encode("utf-8") - else: - name_bytes: bytes = name.serialized_() - dtype_bytes = dtype_name.encode("utf-8") - if len(dtype_bytes) >= 256: - raise ValueError("dtype name length should be less than 256") - header_pos = self._file.tell() if _start_pos is None else _start_pos - - encrypted: bool = self._encrypted and has_data - if encrypted: - chunks = _Chunked( - total_size=tensor_memory.nbytes, - chunk_size=self._crypt_chunk_size, - ) - nonces = self._new_nonces(chunks.count) - encryptor = _crypt.ChunkedEncryption( - key=self._encryption.key, - buffer=tensor_memory, - chunk_size=self._crypt_chunk_size, - nonces=nonces, - executor=self._computation_pool, + write_spec = self._WriteSpec( + module_index=idx, name=name, tensor_type=tensor_type, tensor=tensor + ) + self._bulk_write([write_spec], incremental=True) + + class _WriteSpec: + def __init__( + self, + module_index: int, + name: Union[str, _TensorPath], + tensor_type: TensorType, + tensor: torch.Tensor, + ): + self.tensor = tensor + self.min_file_version = 0 + self.user_owns_tensor_data = True + + # Every parameter to _TensorHeaderSerializer() exists as an attribute except self.file_offset + # defaulting to the simplest possible case: + # CPU-based + # contiguous + # not hashing + # not encrypted + # not meta + # not opaque + self.module_index = module_index + self.tensor_type: TensorType = tensor_type + self.name = _TensorPath.wrap_(name) + self.dtype: Optional[str] = None # _prepare_for_write_dtype + self.shape = tensor.size() + self.data_length = tensor.element_size() * tensor.nelement() + # self.file_offset # intentionally omitted, handled by _write_headers() + self.include_crc32 = True + self.include_sha256 = True + self.crypt_info: Optional[_crypt_info.CryptInfo] = ( + None # _prepare_for_write_encryption ) - key_derivation_chunk = self._encryption._crypt_info_chunk() - encryption_algorithm_chunk = _crypt_info.XSalsa20ParallelChunk( - chunk_size=self._crypt_chunk_size, - nonce=nonces[0], - macs=encryptor.macs, + # Additional payloads that get set and used during the prepare_for_write procedures + self.header: Optional[_TensorHeaderSerializer] = ( + None # Set in _prepare_for_write_headers ) - if key_derivation_chunk is not None: - chunks = (key_derivation_chunk, encryption_algorithm_chunk) - else: - chunks = (encryption_algorithm_chunk,) - crypt_info = _crypt_info.CryptInfo(chunks) - else: - encryptor = None - if _FileFeatureFlags.encrypted in self._file_header.feature_flags: - # If the `encrypted` flag is present, all headers are expected - # to have crypt_info segments, so add an empty one - crypt_info = _crypt_info.CryptInfo() - else: - crypt_info = None - - include_crc32: bool = not encrypted - - header = _TensorHeaderSerializer( - idx, - tensor_type, - name_bytes, - dtype_bytes, - shape, - tensor_size, - header_pos, - include_crc32=include_crc32, - include_sha256=True, - crypt_info=crypt_info, - ) - - tensor_pos = header_pos + header.data_offset - - # Add our 
tensor metadata to the index. - metadata = header.metadata_entry - # Check for overflow - if self._metadata_cur + len(metadata) > self._metadata_end: - raise RuntimeError("Metadata overflow") - - metadata_pos = self._metadata_cur - metadata_len = len(metadata) - self._metadata_cur += metadata_len - - # This task is I/O-bound and has no prerequisites, - # so it goes into the regular writer pool. - def write_metadata(): - self._pwrite(metadata, metadata_pos, verify=metadata_len) - - self._jobs.append(self._writer_pool.submit(write_metadata)) - - # Calculate the hashes. - - # These two tasks are CPU-bound and don't block the GIL, - # so they go into the computation thread pool. - def compute_crc32(prerequisite: Optional[concurrent.futures.Future]): - if prerequisite is not None: - prerequisite.result(_TIMEOUT) - crc32 = header.compute_crc32() - return zlib.crc32(tensor_memory, crc32) - - def compute_sha256(prerequisite: Optional[concurrent.futures.Future]): - if prerequisite is not None: - prerequisite.result(_TIMEOUT) - sha256 = header.compute_sha256() - sha256.update(tensor_memory) - return sha256.digest() - - # This task is I/O-bound and dependent on the previous two tasks, - # so it goes into the header writer pool. - def commit_header( - crc32_future: Optional[concurrent.futures.Future], - sha256_future: Optional[concurrent.futures.Future], - encrypt_future: Optional[concurrent.futures.Future], - ): - crc32 = sha256 = None - if crc32_future is not None: - crc32 = crc32_future.result(_TIMEOUT) - if sha256_future is not None: - sha256 = sha256_future.result(_TIMEOUT) - if encrypt_future is not None: - encrypt_future.result(_TIMEOUT) - # These must be written only after all other futures complete - # to prevent a race condition from other threads hashing - # a partially-filled-in hash section - if crc32_future is not None: - header.add_crc32(crc32) - if sha256_future is not None: - header.add_sha256(sha256) - if encrypt_future is not None: - header.update_crypt_info() - self._pwrite(header.buffer, header_pos, verify=header.data_offset) - - hash_tasks = [] - if encrypted and not _temporary_buffer: - # If multiple tensors share memory, and were encrypted in-place, - # then this must not start hashing until any previous decryption - # tasks have restored this memory to its original state - mem_pointer = tensor.__array_interface__["data"][0] - pending_decryption = self._decryption_jobs.get(mem_pointer, None) - else: - mem_pointer = None - pending_decryption = None - if include_crc32: - crc32_task = self._computation_pool.submit( - compute_crc32, pending_decryption + self.metadata_pos = -1 # Set in _prepare_for_write_headers + self.encryptor: Optional[_crypt.ChunkedEncryption] = ( + None # Set in _do_encryption if encrypted ) - hash_tasks.append(crc32_task) - else: - crc32_task = None - sha256_task = self._computation_pool.submit( - compute_sha256, pending_decryption - ) - hash_tasks.append(sha256_task) - self._jobs.extend(hash_tasks) - - def encrypt(prerequisites: Iterable[concurrent.futures.Future]): - fs = concurrent.futures.wait(prerequisites, timeout=_TIMEOUT) - for f in fs.done: - # Raise exceptions - f.result() - for f in fs.not_done: - # Raise timeouts - f.result(0) - try: - encryptor.encrypt_all( - wait=True, - timeout=_TIMEOUT, - ) - except _crypt.CryptographyError as e: - raise CryptographyError("Tensor encryption failed") from e - # This task is I/O-bound, so it goes into the regular writer pool. 
- def write_tensor_data( - prerequisite: Optional[concurrent.futures.Future], size: int - ): - if prerequisite is not None: - prerequisite.result(_TIMEOUT) - if has_data: - bytes_written = self._pwrite( - tensor_memory, tensor_pos, verify=size + # self.tensor_data_task is a future for processing some contents of self.tensor + # e.g. cuda transfer, make_contiguous, hashing, encryption, writing, or decryption. + # They are often chained from one step of the process to the next + self.tensor_data_task: Optional[_Future] = None + + @property + def tensor_memoryview(self) -> memoryview: + if not self.tensor.is_contiguous(): + # This is actually possible, but probably not intended here, + # so throw an error instead of handling this case + raise BufferError( + "Cannot create a memoryview of a discontiguous tensor" ) - else: - bytes_written = 0 - with self._tensor_count_update_lock: - self._file_header.tensor_count += 1 - self._file_header.tensor_size += bytes_written - - def decrypt(prerequisite: concurrent.futures.Future): - try: - prerequisite.result(_TIMEOUT) - finally: - # Try to decrypt again even if writing to disk failed - # to avoid exiting with the tensor memory in a modified state - fs = encryptor.decrypt_all(wait=False) - try: - _crypt.ChunkedEncryption.wait_or_raise( - fs, - timeout=_TIMEOUT, - return_when=concurrent.futures.ALL_COMPLETED, - ) - except _crypt.CryptographyError as e: - try: - original_exc = prerequisite.exception(timeout=0) - except ( - concurrent.futures.TimeoutError, - concurrent.futures.CancelledError, - ): - original_exc = None - raise CryptographyError( - "Restoring encrypted tensor data in memory failed" - ) from (original_exc if original_exc is not None else e) - - # Encrypt the tensor memory in-place before writing - if encrypted: - encrypt_task = self._encryption_pool.submit(encrypt, hash_tasks) - self._jobs.append(encrypt_task) - else: - encrypt_task = None - - commit_header_task = self._header_writer_pool.submit( - commit_header, crc32_task, sha256_task, encrypt_task - ) - self._jobs.append(commit_header_task) - - # Write the potentially-encrypted tensor memory to the file - write_task = self._writer_pool.submit( - write_tensor_data, encrypt_task, tensor_size - ) - self._jobs.append(write_task) - # Decrypt the memory after writing is finished, if it was encrypted - if encrypted and not _temporary_buffer: - decrypt_task = self._decryption_pool.submit(decrypt, write_task) - self._jobs.append(decrypt_task) - assert mem_pointer is not None - self._decryption_jobs[mem_pointer] = decrypt_task - - tensor_endpos = tensor_pos + tensor_size - - # Update our prologue. - if _synchronize: - self._synchronize_pools() - # Move to the end of our serialized tensor to prepare - # for the next one in the synchronized case. 
- self._file.seek(tensor_endpos) - self._sync_prologue_state() + nbytes = self.tensor.element_size() * self.tensor.nelement() + return memoryview( + (ctypes.c_char * nbytes).from_address(self.tensor.data_ptr()) + ) - ds_size = tensor_endpos - header_pos - ds_bytes = f"{ds_size:,} bytes" + def set_min_file_version_number(self, version_number): + self.min_file_version = max(self.min_file_version, version_number) - typ = { - TensorType.PARAM: "p", - TensorType.BUFFER: "b", - TensorType.STATE_DICT: "sd", - }[tensor_type] + def _maybe_fallocate( + self, tensors: Sequence[_WriteSpec] + ) -> Optional[concurrent.futures.Future]: + if not _syscalls.has_fallocate() or not self._fd: + return None - # if self.compress_tensors: - # comp_report = ( - # f" - tensor:[raw: {tensor_raw_sz}," - # + f" compressed: {tensor_compressed_sz}," - # + f" ratio: {compression_ratio:.2f}]" - # ) - # else: - comp_report = "" - logger.debug( - f"{idx}:{typ}:{name} - {dtype_bytes.decode('utf-8')} - " - f"{tensor.shape} -> {ds_bytes}{comp_report}" + next_pos = self._file.tell() + size = sum(len(t.name.serialized_()) for t in tensors) + size += sum( + t.tensor.element_size() + * t.tensor.nelement() + * (not t.tensor.is_meta) + for t in tensors ) - return tensor_endpos - - @staticmethod - def _async_bulk_device_to_host_transfer( - tensors, max_read_ahead: Optional[int] = 32 - ) -> Tuple[Iterator[torch.Tensor], Callable]: - """ - Transfers CUDA tensors to host memory asynchronously in bulk. - - Args: - tensors: The list of tensors to transfer. - max_read_ahead: The maximum number of tensors to queue. - - Returns: - A tuple containing an iterator over CPU tensors, - and a callback to cancel the transfer early. - """ - if len(tensors) < max_read_ahead: - transferred = queue.SimpleQueue() - else: - transferred = queue.Queue(maxsize=max_read_ahead) - - tensor_sizes = [t.element_size() * t.nelement() for t in tensors] - staging_tensor = torch.empty( - (max(tensor_sizes),), - dtype=torch.uint8, - device="cpu", - pin_memory=True, + # Rough underestimate of header size + header_min_size = 24 + size += header_min_size * len(tensors) + + return self._header_writer_pool.submit( + _syscalls.try_fallocate, + self._fd, + next_pos, + size, + suppress_all_errors=True, ) - transfer_finished = False + def _bulk_write(self, write_specs: Iterable[_WriteSpec], incremental=False): + write_specs = list(write_specs) - def _transfer(): - nonlocal transfer_finished - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - # This is in a separate CUDA stream because it shouldn't - # affect any other GPU operations, even though each - # of these transfers are synchronous - try: - for t, nbytes in zip(tensors, tensor_sizes): - if transfer_finished: - break - staging_tensor_view = ( - staging_tensor.narrow(0, 0, nbytes) - .view(t.dtype) - .view(t.shape) - ) - staging_tensor_view.copy_(t) - new_cpu_tensor = staging_tensor_view.clone() - transferred.put( - new_cpu_tensor.detach(), - timeout=_TIMEOUT, - ) - finally: - # Sentinel - transferred.put(None) - transfer_finished = True + write_dependency: Optional[concurrent.futures.Future] = None + if not incremental: + write_dependency = self._maybe_fallocate(write_specs) - transfer_thread = threading.Thread( - target=_transfer, name="TensorizerTransfer", daemon=True - ) - transfer_thread.start() + for w in write_specs: + self._path_registry.register_path(w.name) - def _interrupt_transfer(): - nonlocal transfer_finished - if not transfer_finished: - # Signal the worker thread to end on its next loop - 
transfer_finished = True - try: - # Unstick the worker thread so that - # it isn't waiting for an open spot - # that will never arrive - transferred.get_nowait() - except queue.Empty: - pass + cuda_executor = self._prepare_for_write_cuda(write_specs) + try: + self._prepare_for_write_contiguous(write_specs) + self._prepare_for_write_meta(write_specs) + self._prepare_for_write_dtype(write_specs) + if self._encrypted: + self._prepare_for_write_encryption(write_specs) + self._prepare_for_write_headers(write_specs) + self._prepare_for_write_hashes(write_specs) - return ( - iter(lambda: transferred.get(timeout=_TIMEOUT), None), - _interrupt_transfer, - ) + if self._encrypted: + self._do_encryption(write_specs) + if write_dependency: + write_dependency.result(_TIMEOUT) - class _WriteSpec(typing.NamedTuple): - idx: int - path: Union[_TensorPath, str] - tensor_type: TensorType - tensor: torch.Tensor - callback: Optional[Callable] + self._do_commit_headers(write_specs) + self._do_commit_tensor_data(write_specs) + if self._encrypted: + self._maybe_decrypt_data(write_specs) - def _bulk_write(self, tensors: Iterable[_WriteSpec]): - tensors = collections.deque(tensors) - next_pos = self._file.tell() - if _syscalls.has_fallocate() and self._fd: - size = sum( - len(_TensorPath.wrap_(t.path).serialized_()) for t in tensors - ) - size += sum( - t.tensor.element_size() - * t.tensor.nelement() - * (not t.tensor.is_meta) - for t in tensors - ) - # Rough underestimate of header size - header_min_size = 24 - size += header_min_size * len(tensors) - _syscalls.try_fallocate( - self._fd, next_pos, size, suppress_all_errors=True + self._file_header.version_number = max( + self._file_header.version_number, + max(w.min_file_version for w in write_specs), ) - cuda_tensors = [ - t.tensor for t in tensors if t.tensor.device.type == "cuda" - ] - if cuda_tensors: - ( - transferred, - interrupt_transfer, - ) = self._async_bulk_device_to_host_transfer(cuda_tensors) - else: - transferred = interrupt_transfer = None - del cuda_tensors - - if self._encrypted: - shared = [] - seen_addresses = set() - for t in reversed(tensors): - if t.tensor.device.type in ("cuda", "meta"): - shared.append(False) - else: - address = t.tensor.data_ptr() - shared.append(address in seen_addresses) - seen_addresses.add(address) - del seen_addresses - else: - shared = [False] * len(tensors) + self._synchronize_pools() + self._sync_prologue_state() + except Exception as e: + for j in self._jobs: + j.cancel() + if cuda_executor is not None: + cuda_executor.shutdown(wait=False) + raise e - try: - while tensors: - idx, name, tensor_type, tensor, callback = tensors.popleft() - is_shared = shared.pop() - self._idx = idx - if tensor.device.type == "cuda": - tensor = next(transferred) - temp_tensor = True - elif is_shared and self._encrypted: - # Un-shares tensor memory in preparation for in-place - # operations on the buffer that would otherwise conflict - # with one another. Full support for shared-memory tensors - # (e.g. if they were only written once) could make - # this unnecessary, once implemented. - # Another option would be to reuse the same encrypted - # weights and decrypt them at the end. This would require - # confirming that the tensor data regions are actually - # identical, and don't just overlap. 
- tensor = tensor.clone().detach() - temp_tensor = True - else: - temp_tensor = False - next_pos = self._write_tensor( - idx, - name, - tensor_type, - tensor, - _synchronize=False, - _start_pos=next_pos, - _temporary_buffer=temp_tensor, - ) - if callback is not None: - callback() - except Exception: - if interrupt_transfer is not None: - interrupt_transfer() - raise - self._synchronize_pools() - self._file.seek(next_pos) - self._sync_prologue_state() + if cuda_executor is not None: + cuda_executor.shutdown(wait=True) def write_module( self, m: torch.nn.Module, - remove_tensors: bool = False, *, include_non_persistent_buffers: bool = True, ) -> None: @@ -4238,9 +3939,6 @@ def write_module( Args: m: The module to serialize. - remove_tensors: Whether to delete each tensor from `m` - after serializing it. - Deleted tensors are replaced with ``None``. include_non_persistent_buffers: Whether to serialize buffers registered with ``persistent=False``. Set to ``False`` to match the behaviour of @@ -4251,10 +3949,9 @@ def write_module( modules = tuple(m.named_modules()) - def extract_tensors(): + def extract_tensors() -> Iterator[TensorSerializer._WriteSpec]: chain = itertools.chain repeat = itertools.repeat - callback = None for idx, (module_name, module) in enumerate(modules): module: torch.nn.Module parameters = module.named_parameters(recurse=False) @@ -4264,14 +3961,11 @@ def extract_tensors(): zip(buffers, repeat(TensorType.BUFFER)), ): label = f"{module_name}.{name}" - if remove_tensors: - callback = partial(setattr, module, name, None) yield TensorSerializer._WriteSpec( - idx=idx, - path=label, + module_index=idx, + name=label, tensor_type=tensor_type, tensor=tensor, - callback=callback, ) def persistent_buffers() -> Set[str]: @@ -4311,7 +4005,7 @@ def persistent_buffers() -> Set[str]: spec for spec in all_tensors if spec.tensor_type != TensorType.BUFFER - or spec.path in persistent + or str(spec.name) in persistent ) self._bulk_write(all_tensors) @@ -4446,17 +4140,497 @@ def write_state_dict(self, state_dict: Union[Dict, List, Tuple]): idx = 0 self._bulk_write( TensorSerializer._WriteSpec( - idx=idx, - path=name, + module_index=idx, + name=name, tensor_type=TensorType.STATE_DICT, tensor=param, - callback=None, ) for name, param in _tensor_path.flatten_structure( torch.Tensor, state_dict ) ) + def _prepare_for_write_cuda( + self, write_specs: Sequence[_WriteSpec] + ) -> Optional[concurrent.futures.ThreadPoolExecutor]: + cuda_specs = [w for w in write_specs if w.tensor.device.type == "cuda"] + if not cuda_specs: + return None + + class CudaTransfer: + def __init__(self, max_size): + self.executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="TransferThread", + initializer=self._allocate_staging_tensor, + initargs=(max_size,), + ) + + def submit(self, write_spec) -> concurrent.futures.Future: + return self.executor.submit(self._transfer, write_spec) + + def _allocate_staging_tensor(self, max_size: int): + self._stream = torch.cuda.Stream() + self._staging_tensor = torch.empty( + (max_size,), + dtype=torch.uint8, + device="cpu", + pin_memory=True, + ) + + def _transfer(self, write_spec): + nbytes = ( + write_spec.tensor.element_size() + * write_spec.tensor.nelement() + ) + staging_tensor_view = ( + self._staging_tensor.narrow(0, 0, nbytes) + .view(write_spec.tensor.dtype) + .view(write_spec.shape) + ) + with torch.cuda.stream(self._stream): + staging_tensor_view.copy_(write_spec.tensor) + write_spec.user_owns_tensor_data = False + write_spec.tensor = 
staging_tensor_view.clone().detach() + + max_tensor_size = max( + [w.tensor.element_size() * w.tensor.nelement() for w in cuda_specs] + ) + cuda_transfer = CudaTransfer(max_tensor_size) + + for w in cuda_specs: + assert w.tensor_data_task is None + w.tensor_data_task = cuda_transfer.submit(w) + self._jobs.append(w.tensor_data_task) + + return cuda_transfer.executor + + def _prepare_for_write_contiguous(self, write_specs: Sequence[_WriteSpec]): + def make_contiguous(write_spec, dependency): + if dependency is not None: + dependency.result(_TIMEOUT) + write_spec.tensor = write_spec.tensor.contiguous() + write_spec.data_length = ( + write_spec.tensor.element_size() * write_spec.tensor.nelement() + ) + write_spec.user_owns_tensor_data = False + + for w in write_specs: + # if there is a tensor_data_task it is a cuda tensor + if w.tensor_data_task is not None or w.tensor.is_contiguous(): + continue + w.tensor_data_task = self._computation_pool.submit( + make_contiguous, w, w.tensor_data_task + ) + + def _prepare_for_write_dtype(self, write_specs: Sequence[_WriteSpec]): + torch_dtype_to_numpy_dtype_cache: Dict[str, str] = {} + + for w in write_specs: + tensor_dtype_str = str(w.tensor.dtype) + if _NumpyTensor._is_asymmetric(w.tensor.dtype): + # is opaque + w.dtype = ( + f" None: + if not self._encrypted or self._encryption is None: + raise RuntimeError( + "Tried to encrypt tensors without encryption parameters" + " having been provided" + ) + + # If any tensors are shared, so we need to clone all but one of them before encrypting + write_specs_by_addr: Dict[int, List[TensorSerializer._WriteSpec]] = ( + collections.defaultdict(list) + ) + for w in write_specs: + if w.tensor.device.type != "cpu": + continue + address = w.tensor.untyped_storage().data_ptr() + write_specs_by_addr[address].append(w) + + for shared_write_specs in write_specs_by_addr.values(): + if len(shared_write_specs) == 1: + continue + + clone_dependencies = _FutureGroup( + [ + w.tensor_data_task + for w in shared_write_specs + if w.tensor_data_task is not None + ] + ) + + clone_tasks: List[_Future] = [] + for w in shared_write_specs[:-1]: + clone_tasks.append( + self._computation_pool.submit( + self._do_clone, w, clone_dependencies + ) + ) + w.user_owns_tensor_data = False + + shared_write_specs[-1].tensor_data_task = _FutureGroup( + clone_tasks + clone_dependencies.futures + ) + for w in shared_write_specs[:-1]: + w.tensor_data_task = _FutureGroup(clone_tasks) + + for w in write_specs: + w.include_crc32 = False + + if w.data_length == 0: + # All headers are expected to have crypt_info segments, so add + # an empty one + w.crypt_info = _crypt_info.CryptInfo() + continue + + if w.tensor_data_task is not None: + w.tensor_data_task.result(_TIMEOUT) + w.tensor_data_task = None + + tensor_memory: memoryview = w.tensor_memoryview + chunked = _Chunked( + total_size=tensor_memory.nbytes, + chunk_size=self._crypt_chunk_size, + ) + nonces = self._new_nonces(chunked.count) + w.encryptor = _crypt.ChunkedEncryption( + key=self._encryption.key, + buffer=tensor_memory, + chunk_size=self._crypt_chunk_size, + nonces=nonces, + executor=self._computation_pool, + ) + + key_derivation_chunk = self._encryption._crypt_info_chunk() + encryption_algorithm_chunk = _crypt_info.XSalsa20ParallelChunk( + chunk_size=self._crypt_chunk_size, + nonce=nonces[0], + macs=w.encryptor.macs, + ) + chunks: Sequence[Any] + if key_derivation_chunk is not None: + chunks = (key_derivation_chunk, encryption_algorithm_chunk) + else: + chunks = (encryption_algorithm_chunk,) + 
w.crypt_info = _crypt_info.CryptInfo(chunks) + + def _prepare_for_write_headers( + self, write_specs: Sequence[_WriteSpec] + ) -> None: + # We first need to construct the headers so that we know the size of each + for w in write_specs: + dtype_bytes = w.dtype.encode("utf-8") # type: ignore + if len(dtype_bytes) >= 256: + raise ValueError("dtype name length should be less than 256") + + w.header = _TensorHeaderSerializer( + w.module_index, + w.tensor_type, + w.name.serialized_(), # name as bytes + dtype_bytes, + w.shape, + w.data_length, + 0, # placeholder file_offset + include_crc32=w.include_crc32, + include_sha256=w.include_sha256, + crypt_info=w.crypt_info, + ) + + # Specify the offsets for each metadata entry + file_offset = ( + self._metadata_cur + ) # position of next metadata entry to write + + ## metadata + for w in write_specs: + w.metadata_pos = file_offset + file_offset += w.header.get_metadata_entry_segment().size + + self._metadata_cur = file_offset + if self._metadata_end is None: + self._metadata_end = self._metadata_cur + elif file_offset > self._metadata_end: + raise RuntimeError("Metadata block is full. Increase max_tensors") + + ## headers + if self._header_cur is not None: + if self._header_cur < file_offset: + raise RuntimeError("Somehow wrote past metadata block") + file_offset = self._header_cur + + for w in write_specs: + w.header.file_offset = file_offset + file_offset += w.header.size + + self._header_cur = file_offset + if self._header_end is None: + self._header_end = self._header_cur + elif self._header_cur > self._header_end: + raise RuntimeError("Header block is full. Increase max_tensors") + + ## tensors + if self._tensor_cur is None: + # The block of tensor data starts on a page-aligned boundary + self._tensor_cur = (file_offset + 4095) & ~4095 + else: + if self._tensor_cur < file_offset: + raise RuntimeError("Somehow wrote past header block") + # Each tensor itself begins on an 8-byte aligned boundary + file_offset = (self._tensor_cur + 7) & ~7 + + # file_offset is now where we should start writing tensor data + for w in write_specs: + w.header.build(file_offset) # type: ignore + file_offset += w.data_length + + self._tensor_cur = file_offset + + def _prepare_for_write_meta( + self, write_specs: Sequence[_WriteSpec] + ) -> None: + for w in write_specs: + if not w.tensor.is_meta: + continue + w.tensor = torch.empty( + (0,) * w.tensor.ndim, device="cpu", dtype=w.tensor.dtype + ) + w.data_length = 0 + w.user_owns_tensor_data = False + + def _prepare_for_write_hashes( + self, write_specs: Sequence[_WriteSpec] + ) -> None: + def compute_crc32( + write_spec: TensorSerializer._WriteSpec, + dependency: Optional[_Future], + ): + if dependency is not None: + dependency.result(_TIMEOUT) + header_crc32 = write_spec.header.compute_crc32() + crc32 = zlib.crc32(write_spec.tensor_memoryview, header_crc32) + write_spec.header.add_crc32(crc32) + + def compute_sha256( + write_spec: TensorSerializer._WriteSpec, + dependency: Optional[_Future], + ): + if dependency is not None: + dependency.result(_TIMEOUT) + sha256 = write_spec.header.compute_sha256() + sha256.update(write_spec.tensor_memoryview) + write_spec.header.add_sha256(sha256.digest()) + + for w in write_specs: + old_tensor_data_task = w.tensor_data_task + + hash_tasks: List[_Future] = [] + if w.include_crc32: + crc32_task = self._computation_pool.submit( + compute_crc32, w, old_tensor_data_task + ) + hash_tasks.append(crc32_task) + if w.include_sha256: + sha256_task = self._computation_pool.submit( + compute_sha256, 
w, old_tensor_data_task + ) + hash_tasks.append(sha256_task) + + if hash_tasks: + w.tensor_data_task = _FutureGroup(hash_tasks) + self._jobs.extend(hash_tasks) + + def _do_encryption(self, write_specs: Sequence[_WriteSpec]) -> None: + def encrypt(write_spec, dependency: Optional[_Future]): + if dependency is not None: + dependency.result(_TIMEOUT) + try: + write_spec.encryptor.encrypt_all(wait=True, timeout=_TIMEOUT) + except _crypt.CryptographyError as e: + raise CryptographyError("Tensor encryption failed") from e + write_spec.header.update_crypt_info() + + for w in write_specs: + if not w.data_length: + continue + w.tensor_data_task = self._encryption_pool.submit( + encrypt, w, w.tensor_data_task + ) + self._jobs.append(w.tensor_data_task) + + def _do_commit_headers(self, write_specs_: Sequence[_WriteSpec]) -> None: + def do_commit( + write_specs: Sequence[TensorSerializer._WriteSpec], + dependencies: Sequence[_Future], + ): + # Fast version: makes one buffer containing the size, metadata, and headers, and writes it one go + header_block_size = self._header_cur - self._metadata_start + header_buffer = bytearray(header_block_size) + + metadata_start = self._metadata_start + metadata_size = ( + self._metadata_cur - metadata_start - 8 + ) # 8 bytes for metadata length field + struct.pack_into("= (3, 12): + kwargs["delete_on_close"] = False + return tempfile.NamedTemporaryFile(*args, **kwargs) + + f = tempfile.NamedTemporaryFile(*args, **kwargs) + f.close = f.file.close + return f