From cc18fcca934c2efa47fc7fb29b99a5d3cb758c3e Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Thu, 18 Apr 2024 13:59:16 -0700 Subject: [PATCH 01/11] Major refactor of serialization in preparation for 3.0. _bulk_write() has been broken up into a series of smaller functions that each perform some operation on the tensors. _WriteSpec carries the state of each tensor as it gets processed. Many ops are threaded and tracked as Futures. Dependent steps do future chaining. The file schema has changed. Immediately following the metadata block is the full block of all header structures. We may want to combine the metadata and the header entries, but they are left separate for now for easier backwards compatibility. Previously, hash computations included the hash fields themselves, presumed to be zeroed. This has been changed so the hash computation segments omit the hash fields themselves. This breaks compatibility with 2.x files for now. This will probably get fixed before 3.0 release --- tensorizer/_NumpyTensor.py | 15 + tensorizer/_futuregroup.py | 61 ++ tensorizer/serialization.py | 1312 ++++++++++++++++++----------------- tests/test_serialization.py | 107 +++ 4 files changed, 875 insertions(+), 620 deletions(-) create mode 100644 tensorizer/_futuregroup.py diff --git a/tensorizer/_NumpyTensor.py b/tensorizer/_NumpyTensor.py index 37cb62c..8719892 100644 --- a/tensorizer/_NumpyTensor.py +++ b/tensorizer/_NumpyTensor.py @@ -82,6 +82,8 @@ k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES } +OPAQUE_DTYPE_SEP = "\0" + class _NumpyTensor(NamedTuple): data: numpy.ndarray @@ -224,6 +226,15 @@ def is_opaque(self): """ return self._is_opaque(self.numpy_dtype) + @property + def dtype_name(self): + if not self.is_opaque: + return self.numpy_dtype + + # The datatype name needs to contain both the numpy dtype that the + # data is serialized as and the original torch dtype. 
+ return self.numpy_dtype + OPAQUE_DTYPE_SEP + self.torch_dtype + @staticmethod def _intermediate_type(size: int) -> torch.dtype: """ @@ -338,3 +349,7 @@ def _decode_torch_dtype(self) -> torch.dtype: raise ValueError(f"Invalid torch_dtype: {self.torch_dtype}") from e return dtype + + @property + def tensor_memory(self): + return self.data.data diff --git a/tensorizer/_futuregroup.py b/tensorizer/_futuregroup.py new file mode 100644 index 0000000..8c24592 --- /dev/null +++ b/tensorizer/_futuregroup.py @@ -0,0 +1,61 @@ +import concurrent.futures +from collections.abc import Callable +from typing import Any, Optional, Sequence + + +class _FutureGroup(concurrent.futures.Future): + def __init__(self, futures: Sequence[concurrent.futures.Future]): + self.futures = futures + + def cancel(self) -> bool: + result = True + for f in self.futures: + result = result and f.cancel() + return result + + def cancelled(self) -> bool: + return all(f.cancelled() for f in self.futures) + + def running(self) -> bool: + return any(f.running() for f in self.futures) + + def done(self) -> bool: + return all(f.done() for f in self.futures) + + def result(self, timeout=None) -> Sequence[Any]: + return _future_wait_and_raise(self.futures, timeout=timeout) + + def exception(self, timeout=None) -> Optional[BaseException]: + for f in self.futures: + exc = f.exception(timeout=timeout) + if exc: + return exc + return None + + def add_done_callback(self, _): + raise NotImplementedError() + + def set_running_or_notify_cancel(self) -> bool: + raise NotImplementedError() + + def set_result(self, _) -> None: + raise NotImplementedError() + + def set_exception(self, _) -> None: + raise NotImplementedError() + + +def _future_wait_and_raise( + futures: Sequence[concurrent.futures.Future], timeout=None +) -> Sequence[Any]: + # Wait on a list of futures with a timeout. Raise any exceptions, including TimeoutErrors. + # otherwise return the list of results in the same order as the input futures. + results = [] + fs = concurrent.futures.wait(futures, timeout=timeout) + for f in fs.done: + # if the future has an exception, this will raise it + results.append(f.result()) + for f in fs.not_done: + # force raise of TimeoutError + results.append(f.result(0)) + return results diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 7b3b182..a2519b2 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -4,6 +4,7 @@ ############################################################################## import abc import bisect +import collections import collections.abc import concurrent.futures import contextlib @@ -61,9 +62,10 @@ from tensorizer._crypt._cgroup_cpu_count import ( effective_cpu_count as _effective_cpu_count, ) +from tensorizer._futuregroup import _future_wait_and_raise, _FutureGroup from tensorizer._internal_utils import Chunked as _Chunked from tensorizer._internal_utils import _variable_read -from tensorizer._NumpyTensor import _NumpyTensor +from tensorizer._NumpyTensor import OPAQUE_DTYPE_SEP, _NumpyTensor from tensorizer._tensor_path import ( _TensorPath, _TensorPathComponent, @@ -157,7 +159,9 @@ class TensorType(IntEnum): # Current version -TENSORIZER_VERSION = 4 +TENSORIZER_VERSION = 5 + +HEADERS_AT_TOP_TENSORIZER_VERSION = 5 # To serialize meta tensors into metadata-only tensors # that deserialize back into zeroed-out buffers, data version 4 is required. 
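The dependency chaining that the new _futuregroup module supports is the core concurrency pattern of this refactor: each pipeline stage receives the previous stage's future and blocks on it before touching the tensor memory, so upstream exceptions propagate instead of being lost. A minimal standalone sketch of that pattern (illustrative only, not code from this patch; the step names are made up):

    import concurrent.futures

    _TIMEOUT = 3600

    def step(label, dependency):
        # Block on the previous stage first; result() re-raises any
        # exception the upstream stage hit, halting the chain.
        if dependency is not None:
            dependency.result(_TIMEOUT)
        return label

    with concurrent.futures.ThreadPoolExecutor() as pool:
        hash_task = pool.submit(step, "hash", None)
        encrypt_task = pool.submit(step, "encrypt", hash_task)
        write_task = pool.submit(step, "write", encrypt_task)
        print(write_task.result(_TIMEOUT))  # -> "write"

_FutureGroup wraps a list of such futures so a stage with several prerequisites (e.g. both hashes) can be waited on as a single dependency.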
@@ -176,8 +180,6 @@ class TensorType(IntEnum): TENSORIZER_MAGIC = b"|TZR|" -OPAQUE_DTYPE_SEP = "\0" - _TIMEOUT: typing.Final[int] = 3600 @@ -221,7 +223,9 @@ def deserialized_length(self): if self.data_length > 0: return self.data_length element_size: int = numpy.dtype(self.dtype).itemsize - num_elements: int = numpy.product(self.shape) + num_elements: int = int( + numpy.product(self.shape) + ) # numpy.product([]) == 1.0 return element_size * num_elements @@ -397,17 +401,27 @@ def __init__( dtype: bytes, shape: Sequence[int], data_length: int, - file_offset: int, + file_offset: int, # location of header in file include_crc32: bool = True, include_sha256: bool = True, crypt_info: Optional[_crypt_info.CryptInfo] = None, ): + self.module_index = module_index + self.tensor_type = tensor_type + self.name = name + self.shape = shape + self.dtype = dtype + self.data_length = data_length + self.file_offset = file_offset + self.include_crc32 = include_crc32 + self.include_sha256 = include_sha256 + # Calculate the variable length segment - name_len = len(name) - dtype_len = len(dtype) + self.name_len = len(name) + self.dtype_len = len(dtype) # NB: shape_len is the number of dimensions, # not the encoded byte length - shape_len = len(shape) + self.shape_len = len(shape) self.crypt_info = crypt_info if crypt_info is None: crypt_info_len = 0 @@ -415,9 +429,9 @@ def __init__( crypt_info_len = crypt_info.sized_size self.variable_length_segment = struct.Struct( self.variable_length_segment_template.format( - name_len=name_len, - dtype_len=dtype_len, - shape_len=shape_len, + name_len=self.name_len, + dtype_len=self.dtype_len, + shape_len=self.shape_len, ) ) crc32_len = sha256_len = self.hash_count = 0 @@ -438,7 +452,7 @@ def __init__( self.sha256_hash_offset, self.crypt_info_offset, self.data_length_offset, - self.data_offset, + self.size, ) = itertools.accumulate( ( self.start_segment.size, @@ -450,25 +464,28 @@ def __init__( self.data_length_segment.size, ) ) - self.size = self.data_offset + + def build(self, tensor_data_offset: int): + # tensor_data_offset: location of tensor data in file + self.data_offset = tensor_data_offset self.buffer = bytearray(self.size) self.start_segment.pack_into( self.buffer, 0, # Offset self.size, # Tensor header size - module_index, # Module index. - tensor_type.value, # Whether this is a parameter or a buffer - name_len, # Parameter/buffer name length + self.module_index, # Module index. 
+ self.tensor_type.value, # Whether this is a parameter or a buffer + self.name_len, # Parameter/buffer name length ) self.variable_length_segment.pack_into( self.buffer, self.variable_length_offset, - name, # Parameter/buffer name UTF-8 bytes - dtype_len, # Tensor dtype length - dtype, # Tensor dtype UTF-8 bytes - shape_len, # Tensor shape length - *shape, # Tensor shape I array + self.name, # Parameter/buffer name UTF-8 bytes + self.dtype_len, # Tensor dtype length + self.dtype, # Tensor dtype UTF-8 bytes + self.shape_len, # Tensor shape length + *self.shape, # Tensor shape I array ) after_hashes = self.crypt_info_offset @@ -481,46 +498,46 @@ def __init__( ) # Placeholders - if include_crc32: + if self.include_crc32: self.add_crc32(0) - if include_sha256: + if self.include_sha256: self.add_sha256(b"") - if crypt_info is not None: - crypt_info.sized_pack_into(self.buffer, self.crypt_info_offset) + if self.crypt_info is not None: + self.crypt_info.sized_pack_into(self.buffer, self.crypt_info_offset) self.data_length_segment.pack_into( - self.buffer, self.data_length_offset, data_length + self.buffer, self.data_length_offset, self.data_length ) - metadata_entry_segment: struct.Struct = struct.Struct( - self.metadata_entry_segment_template.format( - name_len=name_len, - dtype_len=dtype_len, - shape_len=shape_len, - ) - ) + metadata_entry_segment = self.get_metadata_entry_segment() self.metadata_entry = metadata_entry_segment.pack( - name_len, # Name length - name, # Name - tensor_type.value, # Whether this is a parameter or a buffer - dtype_len, # Dtype length - dtype, # Dtype - shape_len, # Shape length - *shape, # Shape - file_offset, # Header start (relative to the file) + self.name_len, # Name length + self.name, # Name + self.tensor_type.value, # Whether this is a parameter or a buffer + self.dtype_len, # Dtype length + self.dtype, # Dtype + self.shape_len, # Shape length + *self.shape, # Shape + self.file_offset, # Header start (relative to the file) # Tensor data start (relative to the file): - file_offset + self.data_offset, - data_length, # Tensor length + self.data_offset, + self.data_length, # Tensor length + ) + + def get_metadata_entry_segment(self) -> struct.Struct: + return struct.Struct( + self.metadata_entry_segment_template.format( + name_len=self.name_len, + dtype_len=self.dtype_len, + shape_len=self.shape_len, + ) ) def _hashable_segment_views(self): - if self.crypt_info is None: - yield memoryview(self.buffer) - else: - yield memoryview(self.buffer)[: self.crypt_info_offset] - # Skip crypt_info - yield memoryview(self.buffer)[self.data_length_offset :] + # Skip areas where we store hashes and crypt_info + yield memoryview(self.buffer)[: self.hash_header_offset] + yield memoryview(self.buffer)[self.data_length_offset :] def compute_crc32(self) -> int: crc32 = 0 @@ -652,6 +669,7 @@ def __init__( self.shape, offset = self.read_shape(buffer, offset) # Read our hashes in. 
+ hash_start = offset hashes_slice, offset = self.read_hash_block(buffer, offset) with hashes_slice: self.hashes = self._decode_hashes(hashes_slice) @@ -659,14 +677,9 @@ def __init__( self._zero_hashes(hashes_slice) if check_crypt_info: - crypt_info_start = offset crypt_info_slice, offset = self.read_crypt_info_block( buffer, offset ) - self._hashable_segments = ( - slice(None, crypt_info_start), - slice(offset, None), - ) with crypt_info_slice: self.crypt_info = _crypt_info.CryptInfo.unpack_from( crypt_info_slice @@ -676,10 +689,14 @@ def __init__( self._hashable_segments = (slice(None, None),) # Finally, get the tensor data length. - offset = len(buffer) - self.data_length_segment.size + data_length_start = offset = len(buffer) - self.data_length_segment.size self.data_length = self.data_length_segment.unpack_from(buffer, offset)[ 0 ] + self._hashable_segments = ( + slice(None, hash_start), + slice(data_length_start, None), + ) def _hashable_segment_views(self): for segment_slice in self._hashable_segments: @@ -1656,13 +1673,6 @@ def __init__( "Tensor is encrypted, but decryption was not requested" ) - # The total size of the file. - # WARNING: this is not accurate. This field isn't used in the - # deserializer, but has been available as a public attribute, - # so it is kept how it was for compatibility until the next - # major version. - self.total_file_bytes = self._file_header.tensor_size - # Read the metadata index of tensors. # This is a list of offsets into the file where the per-tensor data # is stored. @@ -1674,6 +1684,30 @@ def __init__( ) if not self._metadata: raise ValueError("Tensor index in the file is empty") + + self._headers: Optional[ + Dict[_TensorPath, _TensorHeaderDeserializer] + ] + if version_number >= HEADERS_AT_TOP_TENSORIZER_VERSION: + metadata_ordered = sorted( + self._metadata.values(), key=operator.attrgetter("offset") + ) + self._headers = {} + for entry in metadata_ordered: + if self._file.tell() > entry.offset: + raise ValueError("Header offsets overlap or are wrong") + self._file.seek(entry.offset) + header = _TensorHeaderDeserializer.from_io( + self._file, + zero_hashes=True, + check_crypt_info=self._has_crypt_info, + ) + if header is None: + raise KeyError("Unexpected empty header") + self._headers[entry.name] = header + else: + self._headers = None + # filter_func is a test that determines the tensor names to read. # If filter_func is None, all tensors are read. 
if filter_func is not None: @@ -2243,7 +2277,7 @@ def keys(self): def _verify_hashes( self, - name: str, + name: _TensorPath, hashes: Iterable[TensorHash], header_hashes: Dict[HashType, Any], mv: Union[memoryview, bytes], @@ -2858,19 +2892,23 @@ def _copy_thread( if halt: break - header = _TensorHeaderDeserializer.from_io( - file_, - zero_hashes=True, - check_crypt_info=unsafe_self._has_crypt_info, - ) - - if header is None: - raise ValueError("Unexpected empty header") + if unsafe_self._headers is None: + header = _TensorHeaderDeserializer.from_io( + file_, + zero_hashes=True, + check_crypt_info=unsafe_self._has_crypt_info, + ) + if header is None: + raise KeyError("Unexpected empty header") - # Skip it if this tensor is not one we're supposed to load - if header.name not in tensor_sizes_by_name: - file_.seek(header.data_length, io.SEEK_CUR) - continue + # Skip it if this tensor is not one we're supposed to load + if header.name not in tensor_sizes_by_name: + file_.seek(header.data_length, io.SEEK_CUR) + continue + else: + header = unsafe_self._headers[ + tensor_items[tensors_read].name + ] numpy_dtype, *torch_dtype = header.dtype.split(OPAQUE_DTYPE_SEP) if not torch_dtype: @@ -2934,6 +2972,7 @@ def _copy_thread( if not is_meta: start = time.perf_counter_ns() if _perf_stats else 0 + file_.seek(unsafe_self._metadata[header.name].data_offset) if unsafe_self._encrypted and mv.nbytes > 0: TensorDeserializer._stream_decrypt( @@ -3266,6 +3305,7 @@ def __init__( *, encryption: Optional[EncryptionParams] = None, limit_cpu_concurrency: Optional[int] = None, + max_tensors: Optional[int] = None, ) -> None: if isinstance(file_obj, (str, bytes, os.PathLike, int)): self._file = stream_io.open_stream(file_obj, "wb+") @@ -3415,16 +3455,7 @@ def __init__( self._write(TENSORIZER_MAGIC) # Write file header metadata - if not self._encrypted: - # Can't tell if OPAQUE_TENSORIZER_VERSION - # or META_TENSOR_TENSORIZER_VERSION are needed - # until a tensor is written later with an opaque dtype - # or from the meta device, - # so assume it is compatible with version 1 until then. - version_number = NON_OPAQUE_TENSORIZER_VERSION - else: - # File encryption requires a newer tensorizer version - version_number = ENCRYPTION_TENSORIZER_VERSION + version_number = HEADERS_AT_TOP_TENSORIZER_VERSION feature_flags = _FileFeatureFlags(0) if self._encrypted: feature_flags |= _FileFeatureFlags.encrypted @@ -3437,14 +3468,26 @@ def __init__( ) self._write(self._file_header.to_bytes()) - # Reserve 256 KiB for metadata. - metadata_size = 256 * 1024 - self._write(struct.pack(" None: """ - Serializes a tensor, laying things out so that it can be read in three - calls from the input -- once for the size, once for the header, and - once for the tensor itself. + Serializes a tensor. Header data is appended to the metadata block at the top of the file, and + tensor data is appended to the bottom of the file. Args: idx: The index of the tensor in the module. 
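A usage sketch of the incremental path this enables (illustrative only: the file name and tensors are hypothetical, and max_tensors pre-sizes the metadata and header blocks written at the top of the file):

    import torch
    from tensorizer.serialization import TensorSerializer, TensorType

    # Reserve metadata/header space for two tensors up front; further
    # write_tensor() calls would eventually overflow the reserved blocks.
    serializer = TensorSerializer("example.tensors", max_tensors=2)
    serializer.write_tensor(0, "layer.weight", TensorType.PARAM, torch.zeros(16))
    serializer.write_tensor(0, "layer.bias", TensorType.PARAM, torch.zeros(16))
    serializer.close()
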
@@ -3684,6 +3721,7 @@ def write_tensor(

         Serialization format:

+         in header block near top of file:
          {uint64 header_sz,
           uint16 module_idx,
           uint8 type,
           uint16 name_sz,
           []char name,
           uint8 dtype_sz,
           []char dtype,
           {uint8 shape_elem_type,
           } shape_elements[uint8 shape_len],
           uint16 hashes_sz,
           uint8 num_hashes,
           {uint8 hash_type,
            uint8 hash_sz,
            []char hash_str,
           } hashes,
-          uint64 tensor_sz,
-          []byte tensor }
-        """
-        self._write_tensor(
-            idx=idx, name=name, tensor_type=tensor_type, tensor=tensor
-        )
-
-    def _write_tensor(
-        self,
-        idx,
-        name: Union[_TensorPath, str],
-        tensor_type: TensorType,
-        tensor: Union[torch.Tensor, numpy.ndarray],
-        *,
-        _synchronize: bool = True,
-        _start_pos: Optional[int] = None,
-        _temporary_buffer: bool = False,
-    ) -> int:
+          uint64 tensor_sz}
+         .....
+         after all headers, at the bottom of the file:
+         {[]byte tensor }
         """
-        Underlying implementation for `write_tensor()`,
-        providing additional controls for asynchronous writes
-
-        Args:
-            idx: The index of the tensor in the module.
-            name: The name of the tensor.
-            tensor_type: The type of the tensor. This is used to determine
-                how to interpret the tensor.
-            tensor: The tensor to serialize.
-            _synchronize: Whether to synchronize metadata after writing
-                and ensure that all data is written to the file before
-                the call returns. If false, data may continue to be written
-                asynchronously even after this call returns.
-            _start_pos:
-                Where in the file to write the tensor entry. If not specified,
-                writes starting at the current file offset.
-        """
-        self._path_registry.register_path(name)
-        if isinstance(tensor, torch.Tensor):
-            shape: Sequence[int] = tensor.size()
-            has_data: bool = not tensor.is_meta
-            if has_data:
-                if not tensor.is_contiguous():
-                    _temporary_buffer = True
-                numpy_tensor = _NumpyTensor.from_tensor(tensor.contiguous())
-            else:
-                self._file_header.version_number = max(
-                    META_TENSOR_TENSORIZER_VERSION,
-                    self._file_header.version_number,
-                )
-                _temporary_buffer = True
-                hollow_tensor = torch.empty(
-                    (0,) * tensor.ndim, device="cpu", dtype=tensor.dtype
-                )
-                numpy_tensor = _NumpyTensor.from_tensor(hollow_tensor)
-        else:
-            shape: Sequence[int] = tensor.shape
-            has_data: bool = True
-            if (
-                isinstance(tensor, numpy.ndarray)
-                and not tensor.flags.c_contiguous
-                and hasattr(numpy, "ascontiguousarray")
-            ):
-                numpy_tensor = _NumpyTensor.from_array(
-                    numpy.ascontiguousarray(tensor)
-                )
-                _temporary_buffer = True
-            else:
-                numpy_tensor = _NumpyTensor.from_array(tensor)
+        if isinstance(tensor, numpy.ndarray):
+            tensor = torch.from_numpy(tensor)

-        dtype_name = numpy_tensor.numpy_dtype
-        if numpy_tensor.is_opaque:
-            # The datatype name needs to contain both the numpy dtype that the
-            # data is serialized as and the original torch dtype.
-            dtype_name += OPAQUE_DTYPE_SEP + numpy_tensor.torch_dtype
-            self._file_header.version_number = max(
-                OPAQUE_TENSORIZER_VERSION,
-                self._file_header.version_number,
+        write_spec = self._WriteSpec(
+            module_index=idx, name=name, tensor_type=tensor_type, tensor=tensor
+        )
+        self._bulk_write([write_spec], incremental=True)
+
+    class _WriteSpec:
+        def __init__(
+            self,
+            module_index: int,
+            name: str | _TensorPath,
+            tensor_type: TensorType,
+            tensor: torch.Tensor,
+        ):
+            self.tensor = tensor
+            self.min_file_version = 0
+            self.user_owns_tensor_data = True
+
+            # Every parameter to _TensorHeaderSerializer() exists as an
+            # attribute except self.file_offset, defaulting to the simplest
+            # possible case:
+            #   CPU-based
+            #   contiguous
+            #   not hashing
+            #   not encrypted
+            #   not meta
+            #   not opaque
+            self.module_index = module_index
+            self.tensor_type: TensorType = tensor_type
+            self.name = _TensorPath.wrap_(name)
+            self.dtype: Optional[str] = None  # _prepare_for_write_numpy_tensor
+            self.shape = tensor.size()
+            self.data_length = tensor.nbytes
+            # self.file_offset is intentionally omitted; it is handled by
+            # _prepare_for_write_headers()
+            self.include_crc32 = True
+            self.include_sha256 = True
+            self.crypt_info: Optional[_crypt_info.CryptInfo] = (
+                None  # _prepare_for_write_encryption
             )
-            tensor: numpy.ndarray = numpy_tensor.data
-            tensor_memory: memoryview = numpy_tensor.data.data
-            tensor_size: int = tensor.nbytes
-            if tensor_memory.nbytes != tensor_size:
-                raise ValueError(
-                    f"Cannot serialize tensor {name!r}:"
-                    f" buffer size of underlying memory ({tensor_memory.nbytes})"
-                    f" doesn't match reported size ({tensor_size})"
+            # Additional payloads that get set and used during the
+            # prepare_for_write procedures
+            self.numpy_tensor: Optional[_NumpyTensor] = (
+                None  # Set in _prepare_for_write_numpy_tensor
             )
-        if isinstance(name, str):
-            name_bytes: bytes = name.encode("utf-8")
-        else:
-            name_bytes: bytes = name.serialized_()
-        dtype_bytes = dtype_name.encode("utf-8")
-        if len(dtype_bytes) >= 256:
-            raise ValueError("dtype name length should be less than 256")
-        header_pos = self._file.tell() if _start_pos is None else _start_pos
-
-        encrypted: bool = self._encrypted and has_data
-        if encrypted:
-            chunks = _Chunked(
-                total_size=tensor_memory.nbytes,
-                chunk_size=self._crypt_chunk_size,
+            self.header: Optional[_TensorHeaderSerializer] = (
+                None  # Set in _prepare_for_write_headers
             )
-            nonces = self._new_nonces(chunks.count)
-            encryptor = _crypt.ChunkedEncryption(
-                key=self._encryption.key,
-                buffer=tensor_memory,
-                chunk_size=self._crypt_chunk_size,
-                nonces=nonces,
-                executor=self._computation_pool,
+            self.metadata_pos = -1  # Set in _prepare_for_write_headers
+            self.encryptor: Optional[_crypt.ChunkedEncryption] = (
+                None  # Set in _do_encryption if encrypted
             )
-            key_derivation_chunk = self._encryption._crypt_info_chunk()
-            encryption_algorithm_chunk = _crypt_info.XSalsa20ParallelChunk(
-                chunk_size=self._crypt_chunk_size,
-                nonce=nonces[0],
-                macs=encryptor.macs,
-            )
-            if key_derivation_chunk is not None:
-                chunks = (key_derivation_chunk, encryption_algorithm_chunk)
-            else:
-                chunks = (encryption_algorithm_chunk,)
-            crypt_info = _crypt_info.CryptInfo(chunks)
-        else:
-            encryptor = None
-            if _FileFeatureFlags.encrypted in self._file_header.feature_flags:
-                # If the `encrypted` flag is present, all headers are expected
-                # to have crypt_info segments, so add an empty one
-                crypt_info = _crypt_info.CryptInfo()
-            else:
-                crypt_info = None
-
-        include_crc32: bool = not encrypted
-
-        header = _TensorHeaderSerializer(
-            idx,
tensor_type, - name_bytes, - dtype_bytes, - shape, - tensor_size, - header_pos, - include_crc32=include_crc32, - include_sha256=True, - crypt_info=crypt_info, - ) + # self.tensor_data_task is a future for processing some contents of self.tensor + # e.g. cuda transfer, make_contiguous, hashing, encryption, writing, or decryption. + # They are often chained from one step of the process to the next + self.tensor_data_task: Optional[concurrent.futures.Future] = None - tensor_pos = header_pos + header.data_offset - - # Add our tensor metadata to the index. - metadata = header.metadata_entry - # Check for overflow - if self._metadata_cur + len(metadata) > self._metadata_end: - raise RuntimeError("Metadata overflow") - - metadata_pos = self._metadata_cur - metadata_len = len(metadata) - self._metadata_cur += metadata_len - - # This task is I/O-bound and has no prerequisites, - # so it goes into the regular writer pool. - def write_metadata(): - self._pwrite(metadata, metadata_pos, verify=metadata_len) - - self._jobs.append(self._writer_pool.submit(write_metadata)) - - # Calculate the hashes. - - # These two tasks are CPU-bound and don't block the GIL, - # so they go into the computation thread pool. - def compute_crc32(prerequisite: Optional[concurrent.futures.Future]): - if prerequisite is not None: - prerequisite.result(_TIMEOUT) - crc32 = header.compute_crc32() - return zlib.crc32(tensor_memory, crc32) - - def compute_sha256(prerequisite: Optional[concurrent.futures.Future]): - if prerequisite is not None: - prerequisite.result(_TIMEOUT) - sha256 = header.compute_sha256() - sha256.update(tensor_memory) - return sha256.digest() - - # This task is I/O-bound and dependent on the previous two tasks, - # so it goes into the header writer pool. - def commit_header( - crc32_future: Optional[concurrent.futures.Future], - sha256_future: Optional[concurrent.futures.Future], - encrypt_future: Optional[concurrent.futures.Future], - ): - crc32 = sha256 = None - if crc32_future is not None: - crc32 = crc32_future.result(_TIMEOUT) - if sha256_future is not None: - sha256 = sha256_future.result(_TIMEOUT) - if encrypt_future is not None: - encrypt_future.result(_TIMEOUT) - # These must be written only after all other futures complete - # to prevent a race condition from other threads hashing - # a partially-filled-in hash section - if crc32_future is not None: - header.add_crc32(crc32) - if sha256_future is not None: - header.add_sha256(sha256) - if encrypt_future is not None: - header.update_crypt_info() - self._pwrite(header.buffer, header_pos, verify=header.data_offset) - - hash_tasks = [] - if encrypted and not _temporary_buffer: - # If multiple tensors share memory, and were encrypted in-place, - # then this must not start hashing until any previous decryption - # tasks have restored this memory to its original state - mem_pointer = tensor.__array_interface__["data"][0] - pending_decryption = self._decryption_jobs.get(mem_pointer, None) - else: - mem_pointer = None - pending_decryption = None - if include_crc32: - crc32_task = self._computation_pool.submit( - compute_crc32, pending_decryption - ) - hash_tasks.append(crc32_task) - else: - crc32_task = None - sha256_task = self._computation_pool.submit( - compute_sha256, pending_decryption - ) - hash_tasks.append(sha256_task) - self._jobs.extend(hash_tasks) - - def encrypt(prerequisites: Iterable[concurrent.futures.Future]): - fs = concurrent.futures.wait(prerequisites, timeout=_TIMEOUT) - for f in fs.done: - # Raise exceptions - f.result() - for f in 
fs.not_done: - # Raise timeouts - f.result(0) - try: - encryptor.encrypt_all( - wait=True, - timeout=_TIMEOUT, - ) - except _crypt.CryptographyError as e: - raise CryptographyError("Tensor encryption failed") from e + def set_min_file_version_number(self, version_number): + self.min_file_version = max(self.min_file_version, version_number) - # This task is I/O-bound, so it goes into the regular writer pool. - def write_tensor_data( - prerequisite: Optional[concurrent.futures.Future], size: int - ): - if prerequisite is not None: - prerequisite.result(_TIMEOUT) - if has_data: - bytes_written = self._pwrite( - tensor_memory, tensor_pos, verify=size - ) - else: - bytes_written = 0 - with self._tensor_count_update_lock: - self._file_header.tensor_count += 1 - self._file_header.tensor_size += bytes_written - - def decrypt(prerequisite: concurrent.futures.Future): - try: - prerequisite.result(_TIMEOUT) - finally: - # Try to decrypt again even if writing to disk failed - # to avoid exiting with the tensor memory in a modified state - fs = encryptor.decrypt_all(wait=False) - try: - _crypt.ChunkedEncryption.wait_or_raise( - fs, - timeout=_TIMEOUT, - return_when=concurrent.futures.ALL_COMPLETED, - ) - except _crypt.CryptographyError as e: - try: - original_exc = prerequisite.exception(timeout=0) - except ( - concurrent.futures.TimeoutError, - concurrent.futures.CancelledError, - ): - original_exc = None - raise CryptographyError( - "Restoring encrypted tensor data in memory failed" - ) from (original_exc if original_exc is not None else e) - - # Encrypt the tensor memory in-place before writing - if encrypted: - encrypt_task = self._encryption_pool.submit(encrypt, hash_tasks) - self._jobs.append(encrypt_task) - else: - encrypt_task = None - - commit_header_task = self._header_writer_pool.submit( - commit_header, crc32_task, sha256_task, encrypt_task - ) - self._jobs.append(commit_header_task) - - # Write the potentially-encrypted tensor memory to the file - write_task = self._writer_pool.submit( - write_tensor_data, encrypt_task, tensor_size - ) - self._jobs.append(write_task) - # Decrypt the memory after writing is finished, if it was encrypted - if encrypted and not _temporary_buffer: - decrypt_task = self._decryption_pool.submit(decrypt, write_task) - self._jobs.append(decrypt_task) - assert mem_pointer is not None - self._decryption_jobs[mem_pointer] = decrypt_task - - tensor_endpos = tensor_pos + tensor_size - - # Update our prologue. - if _synchronize: - self._synchronize_pools() - # Move to the end of our serialized tensor to prepare - # for the next one in the synchronized case. 
- self._file.seek(tensor_endpos) - self._sync_prologue_state() - - ds_size = tensor_endpos - header_pos - ds_bytes = f"{ds_size:,} bytes" - - typ = { - TensorType.PARAM: "p", - TensorType.BUFFER: "b", - TensorType.STATE_DICT: "sd", - }[tensor_type] + def _maybe_fallocate(self, tensors: Sequence[_WriteSpec]): + if not _syscalls.has_fallocate() or not self._fd: + return - # if self.compress_tensors: - # comp_report = ( - # f" - tensor:[raw: {tensor_raw_sz}," - # + f" compressed: {tensor_compressed_sz}," - # + f" ratio: {compression_ratio:.2f}]" - # ) - # else: - comp_report = "" - logger.debug( - f"{idx}:{typ}:{name} - {dtype_bytes.decode('utf-8')} - " - f"{tensor.shape} -> {ds_bytes}{comp_report}" + next_pos = self._file.tell() + size = sum(len(t.name.serialized_()) for t in tensors) + size += sum( + t.tensor.element_size() + * t.tensor.nelement() + * (not t.tensor.is_meta) + for t in tensors ) - return tensor_endpos - - @staticmethod - def _async_bulk_device_to_host_transfer( - tensors, max_read_ahead: Optional[int] = 32 - ) -> Tuple[Iterator[torch.Tensor], Callable]: - """ - Transfers CUDA tensors to host memory asynchronously in bulk. - - Args: - tensors: The list of tensors to transfer. - max_read_ahead: The maximum number of tensors to queue. - - Returns: - A tuple containing an iterator over CPU tensors, - and a callback to cancel the transfer early. - """ - if len(tensors) < max_read_ahead: - transferred = queue.SimpleQueue() - else: - transferred = queue.Queue(maxsize=max_read_ahead) - - tensor_sizes = [t.element_size() * t.nelement() for t in tensors] - staging_tensor = torch.empty( - (max(tensor_sizes),), - dtype=torch.uint8, - device="cpu", - pin_memory=True, + # Rough underestimate of header size + header_min_size = 24 + size += header_min_size * len(tensors) + _syscalls.try_fallocate( + self._fd, next_pos, size, suppress_all_errors=True ) - transfer_finished = False - - def _transfer(): - nonlocal transfer_finished - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - # This is in a separate CUDA stream because it shouldn't - # affect any other GPU operations, even though each - # of these transfers are synchronous - try: - for t, nbytes in zip(tensors, tensor_sizes): - if transfer_finished: - break - staging_tensor_view = ( - staging_tensor.narrow(0, 0, nbytes) - .view(t.dtype) - .view(t.shape) - ) - staging_tensor_view.copy_(t) - new_cpu_tensor = staging_tensor_view.clone() - transferred.put( - new_cpu_tensor.detach(), - timeout=_TIMEOUT, - ) - finally: - # Sentinel - transferred.put(None) - transfer_finished = True + def _bulk_write(self, write_specs: Iterable[_WriteSpec], incremental=False): + write_specs = list(write_specs) - transfer_thread = threading.Thread( - target=_transfer, name="TensorizerTransfer", daemon=True - ) - transfer_thread.start() + if not incremental: + # TODO: make into a future + self._maybe_fallocate(write_specs) - def _interrupt_transfer(): - nonlocal transfer_finished - if not transfer_finished: - # Signal the worker thread to end on its next loop - transfer_finished = True - try: - # Unstick the worker thread so that - # it isn't waiting for an open spot - # that will never arrive - transferred.get_nowait() - except queue.Empty: - pass + for w in write_specs: + self._path_registry.register_path(w.name) - return ( - iter(lambda: transferred.get(timeout=_TIMEOUT), None), - _interrupt_transfer, - ) + cuda_executor = self._prepare_for_write_cuda(write_specs) + try: + self._prepare_for_write_contiguous(write_specs) + 
self._prepare_for_write_meta(write_specs) + self._prepare_for_write_numpy_tensor(write_specs) + self._prepare_for_write_opaque(write_specs) + if self._encrypted: + self._prepare_for_write_encryption(write_specs) + self._prepare_for_write_headers(write_specs) + self._prepare_for_write_hashes(write_specs) - class _WriteSpec(typing.NamedTuple): - idx: int - path: Union[_TensorPath, str] - tensor_type: TensorType - tensor: torch.Tensor - callback: Optional[Callable] + if self._encrypted: + self._do_encryption(write_specs) + self._do_commit_headers(write_specs) + self._do_commit_tensor_data(write_specs) + if self._encrypted: + self._maybe_decrypt_data(write_specs) - def _bulk_write(self, tensors: Iterable[_WriteSpec]): - tensors = collections.deque(tensors) - next_pos = self._file.tell() - if _syscalls.has_fallocate() and self._fd: - size = sum( - len(_TensorPath.wrap_(t.path).serialized_()) for t in tensors - ) - size += sum( - t.tensor.element_size() - * t.tensor.nelement() - * (not t.tensor.is_meta) - for t in tensors - ) - # Rough underestimate of header size - header_min_size = 24 - size += header_min_size * len(tensors) - _syscalls.try_fallocate( - self._fd, next_pos, size, suppress_all_errors=True + self._file_header.version_number = max( + self._file_header.version_number, + max(w.min_file_version for w in write_specs), ) - cuda_tensors = [ - t.tensor for t in tensors if t.tensor.device.type == "cuda" - ] - if cuda_tensors: - ( - transferred, - interrupt_transfer, - ) = self._async_bulk_device_to_host_transfer(cuda_tensors) - else: - transferred = interrupt_transfer = None - del cuda_tensors - - if self._encrypted: - shared = [] - seen_addresses = set() - for t in reversed(tensors): - if t.tensor.device.type in ("cuda", "meta"): - shared.append(False) - else: - address = t.tensor.data_ptr() - shared.append(address in seen_addresses) - seen_addresses.add(address) - del seen_addresses - else: - shared = [False] * len(tensors) + self._synchronize_pools() + self._sync_prologue_state() + except Exception as e: + if cuda_executor is not None: + cuda_executor.shutdown(wait=False, cancel_futures=True) + raise e - try: - while tensors: - idx, name, tensor_type, tensor, callback = tensors.popleft() - is_shared = shared.pop() - self._idx = idx - if tensor.device.type == "cuda": - tensor = next(transferred) - temp_tensor = True - elif is_shared and self._encrypted: - # Un-shares tensor memory in preparation for in-place - # operations on the buffer that would otherwise conflict - # with one another. Full support for shared-memory tensors - # (e.g. if they were only written once) could make - # this unnecessary, once implemented. - # Another option would be to reuse the same encrypted - # weights and decrypt them at the end. This would require - # confirming that the tensor data regions are actually - # identical, and don't just overlap. 
- tensor = tensor.clone().detach() - temp_tensor = True - else: - temp_tensor = False - next_pos = self._write_tensor( - idx, - name, - tensor_type, - tensor, - _synchronize=False, - _start_pos=next_pos, - _temporary_buffer=temp_tensor, - ) - if callback is not None: - callback() - except Exception: - if interrupt_transfer is not None: - interrupt_transfer() - raise - self._synchronize_pools() - self._file.seek(next_pos) - self._sync_prologue_state() + if cuda_executor is not None: + cuda_executor.shutdown(wait=True, cancel_futures=False) def write_module( self, m: torch.nn.Module, - remove_tensors: bool = False, *, include_non_persistent_buffers: bool = True, ) -> None: @@ -4238,9 +3884,6 @@ def write_module( Args: m: The module to serialize. - remove_tensors: Whether to delete each tensor from `m` - after serializing it. - Deleted tensors are replaced with ``None``. include_non_persistent_buffers: Whether to serialize buffers registered with ``persistent=False``. Set to ``False`` to match the behaviour of @@ -4251,10 +3894,9 @@ def write_module( modules = tuple(m.named_modules()) - def extract_tensors(): + def extract_tensors() -> Iterator[TensorSerializer._WriteSpec]: chain = itertools.chain repeat = itertools.repeat - callback = None for idx, (module_name, module) in enumerate(modules): module: torch.nn.Module parameters = module.named_parameters(recurse=False) @@ -4264,14 +3906,11 @@ def extract_tensors(): zip(buffers, repeat(TensorType.BUFFER)), ): label = f"{module_name}.{name}" - if remove_tensors: - callback = partial(setattr, module, name, None) yield TensorSerializer._WriteSpec( - idx=idx, - path=label, + module_index=idx, + name=label, tensor_type=tensor_type, tensor=tensor, - callback=callback, ) def persistent_buffers() -> Set[str]: @@ -4311,7 +3950,7 @@ def persistent_buffers() -> Set[str]: spec for spec in all_tensors if spec.tensor_type != TensorType.BUFFER - or spec.path in persistent + or str(spec.name) in persistent ) self._bulk_write(all_tensors) @@ -4446,17 +4085,450 @@ def write_state_dict(self, state_dict: Union[Dict, List, Tuple]): idx = 0 self._bulk_write( TensorSerializer._WriteSpec( - idx=idx, - path=name, + module_index=idx, + name=name, tensor_type=TensorType.STATE_DICT, tensor=param, - callback=None, ) for name, param in _tensor_path.flatten_structure( torch.Tensor, state_dict ) ) + def _prepare_for_write_cuda( + self, write_specs: Sequence[_WriteSpec] + ) -> Optional[concurrent.futures.ThreadPoolExecutor]: + cuda_specs = [w for w in write_specs if w.tensor.device.type == "cuda"] + if not cuda_specs: + return None + + class CudaTransfer: + def __init__(self, max_size): + self.executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="TransferThread", + initializer=self._allocate_staging_tensor, + initargs=(max_size,), + ) + + def submit(self, write_spec) -> concurrent.futures.Future: + return self.executor.submit(self._transfer, write_spec) + + def _allocate_staging_tensor(self, max_size: int): + self._stream = torch.cuda.Stream() + self._staging_tensor = torch.empty( + (max_size,), + dtype=torch.uint8, + device="cpu", + pin_memory=True, + ) + + def _transfer(self, write_spec): + nbytes = ( + write_spec.tensor.element_size() + * write_spec.tensor.nelement() + ) + staging_tensor_view = ( + self._staging_tensor.narrow(0, 0, nbytes) + .view(write_spec.tensor.dtype) + .view(write_spec.shape) + ) + with torch.cuda.stream(self._stream): + staging_tensor_view.copy_(write_spec.tensor) + write_spec.user_owns_tensor_data = False + 
write_spec.tensor = staging_tensor_view.clone().detach() + + max_tensor_size = max( + [w.tensor.element_size() * w.tensor.nelement() for w in cuda_specs] + ) + cuda_transfer = CudaTransfer(max_tensor_size) + + for w in cuda_specs: + assert w.tensor_data_task is None + w.tensor_data_task = cuda_transfer.submit(w) + + return cuda_transfer.executor + + def _prepare_for_write_contiguous(self, write_specs: Sequence[_WriteSpec]): + def make_contiguous(write_spec, dependency): + if dependency is not None: + dependency.result(_TIMEOUT) + write_spec.tensor = write_spec.tensor.contiguous() + write_spec.data_length = write_spec.tensor.nbytes + write_spec.user_owns_tensor_data = False + + for w in write_specs: + # if there is a tensor_data_task it is a cuda tensor + if w.tensor_data_task is not None or w.tensor.is_contiguous(): + continue + w.tensor_data_task = self._computation_pool.submit( + make_contiguous, w, w.tensor_data_task + ) + + def _prepare_for_write_numpy_tensor( + self, write_specs: Sequence[_WriteSpec] + ): + for w in write_specs: + # all futures are resolved here. This step is not multi-threaded. + if w.tensor_data_task is not None: + w.tensor_data_task.result(_TIMEOUT) + w.tensor_data_task = None + w.numpy_tensor = _NumpyTensor.from_tensor(w.tensor) + w.dtype = w.numpy_tensor.numpy_dtype + if w.numpy_tensor.data.data.nbytes != w.tensor.nbytes: + raise ValueError( + f"Cannot serialize tensor {w.name!r}:" + f" buffer size of underlying memory ({w.numpy_tensor.data.data.nbytes})" + f" doesn't match reported size ({w.tensor.nbytes})" + ) + + def _prepare_for_write_opaque( + self, write_specs: Sequence[_WriteSpec] + ) -> None: + for w in write_specs: + if not w.numpy_tensor.is_opaque: # type: ignore + continue + # The datatype name needs to contain both the numpy dtype that the + # data is serialized as and the original torch dtype. 
+            w.dtype += OPAQUE_DTYPE_SEP + w.numpy_tensor.torch_dtype  # type: ignore
+            w.set_min_file_version_number(OPAQUE_TENSORIZER_VERSION)
+
+    @staticmethod
+    def _do_clone(write_spec, dependency: Optional[concurrent.futures.Future]):
+        if dependency is not None:
+            dependency.result(_TIMEOUT)
+        write_spec.tensor = write_spec.tensor.clone().detach()
+
+    def _prepare_for_write_encryption(
+        self, write_specs: Sequence[_WriteSpec]
+    ) -> None:
+        assert self._encrypted and self._encryption is not None
+
+        # If any tensors share memory, we need to clone all but one of them
+        # before encrypting
+        write_specs_by_addr: Dict[int, List[TensorSerializer._WriteSpec]] = (
+            collections.defaultdict(list)
+        )
+        for w in write_specs:
+            if w.tensor.device.type != "cpu":
+                continue
+            address = w.tensor.untyped_storage().data_ptr()
+            write_specs_by_addr[address].append(w)
+
+        for shared_write_specs in write_specs_by_addr.values():
+            if len(shared_write_specs) == 1:
+                continue
+
+            clone_dependencies = _FutureGroup(
+                [
+                    w.tensor_data_task
+                    for w in shared_write_specs
+                    if w.tensor_data_task is not None
+                ]
+            )
+
+            clone_tasks = []
+            for w in shared_write_specs[1:]:
+                clone_tasks.append(
+                    self._computation_pool.submit(
+                        self._do_clone, w, clone_dependencies
+                    )
+                )
+                w.user_owns_tensor_data = False
+
+            shared_write_specs[0].tensor_data_task = _FutureGroup(
+                clone_tasks + clone_dependencies.futures
+            )
+            for w in shared_write_specs[1:]:
+                w.tensor_data_task = _FutureGroup(clone_tasks)
+
+        for w in write_specs:
+            assert w.numpy_tensor is not None
+            w.include_crc32 = False
+
+            if w.data_length == 0:
+                # All headers are expected to have crypt_info segments, so add
+                # an empty one
+                w.crypt_info = _crypt_info.CryptInfo()
+                continue
+
+            if w.tensor_data_task is not None:
+                w.tensor_data_task.result(_TIMEOUT)
+                w.tensor_data_task = None
+
+            tensor_memory: memoryview = w.numpy_tensor.tensor_memory
+            chunked = _Chunked(
+                total_size=tensor_memory.nbytes,
+                chunk_size=self._crypt_chunk_size,
+            )
+            nonces = self._new_nonces(chunked.count)
+            w.encryptor = _crypt.ChunkedEncryption(
+                key=self._encryption.key,
+                buffer=tensor_memory,
+                chunk_size=self._crypt_chunk_size,
+                nonces=nonces,
+                executor=self._computation_pool,
+            )
+
+            key_derivation_chunk = self._encryption._crypt_info_chunk()
+            encryption_algorithm_chunk = _crypt_info.XSalsa20ParallelChunk(
+                chunk_size=self._crypt_chunk_size,
+                nonce=nonces[0],
+                macs=w.encryptor.macs,
+            )
+            chunks: Sequence[Any]
+            if key_derivation_chunk is not None:
+                chunks = (key_derivation_chunk, encryption_algorithm_chunk)
+            else:
+                chunks = (encryption_algorithm_chunk,)
+            w.crypt_info = _crypt_info.CryptInfo(chunks)
+
+    def _prepare_for_write_headers(
+        self, write_specs: Sequence[_WriteSpec]
+    ) -> None:
+        # We first need to construct the headers so that we know the size of
+        # each
+        for w in write_specs:
+            dtype_bytes = w.dtype.encode("utf-8")  # type: ignore
+            if len(dtype_bytes) >= 256:
+                raise ValueError("dtype name length should be less than 256")
+
+            w.header = _TensorHeaderSerializer(
+                w.module_index,
+                w.tensor_type,
+                w.name.serialized_(),  # name as bytes
+                dtype_bytes,
+                w.shape,
+                w.data_length,
+                0,  # bogus file_offset.
This gets filled in in build() + include_crc32=w.include_crc32, + include_sha256=w.include_sha256, + crypt_info=w.crypt_info, + ) + + # Specify the offsets for each metadata entry + file_offset = ( + self._metadata_cur + ) # position of next metadata entry to write + + ## metadata + for w in write_specs: + w.metadata_pos = file_offset + file_offset += w.header.get_metadata_entry_segment().size + + self._metadata_cur = file_offset + if self._metadata_end is None: + self._metadata_end = self._metadata_cur + elif file_offset > self._metadata_end: + raise RuntimeError("Metadata block is full. Increase max_tensors") + + ## headers + if self._header_cur is not None: + if self._header_cur < file_offset: + raise RuntimeError("Somehow wrote past metadata block") + file_offset = self._header_cur + + for w in write_specs: + w.header.file_offset = file_offset + file_offset += w.header.size + + self._header_cur = file_offset + if self._header_end is None: + self._header_end = self._header_cur + elif self._header_cur > self._header_end: + raise RuntimeError("Header block is full. Increase max_tensors") + + ## tensors + if self._tensor_cur is None: + # The block of tensor data starts on a page-aligned boundary + self._tensor_cur = (file_offset + 4095) & ~4095 + else: + if self._tensor_cur < file_offset: + raise RuntimeError("Somehow wrote past header block") + # Each tensor itself begins on an 8-byte aligned boundary + file_offset = (self._tensor_cur + 7) & ~7 + + # file_offset is now where we should start writing tensor data + for w in write_specs: + w.header.build(file_offset) # type: ignore + file_offset += w.data_length + + self._tensor_cur = file_offset + + def _prepare_for_write_meta( + self, write_specs: Sequence[_WriteSpec] + ) -> None: + for w in write_specs: + if not w.tensor.is_meta: + continue + w.tensor = torch.empty( + (0,) * w.tensor.ndim, device="cpu", dtype=w.tensor.dtype + ) + w.data_length = 0 + w.user_owns_tensor_data = False + + def _prepare_for_write_hashes( + self, write_specs: Sequence[_WriteSpec] + ) -> None: + def compute_crc32( + write_spec: TensorSerializer._WriteSpec, + dependency: Optional[concurrent.futures.Future], + ): + if dependency is not None: + dependency.result(_TIMEOUT) + header_crc32 = write_spec.header.compute_crc32() + crc32 = zlib.crc32( + write_spec.numpy_tensor.tensor_memory, header_crc32 + ) + write_spec.header.add_crc32(crc32) + + def compute_sha256( + write_spec: TensorSerializer._WriteSpec, + dependency: Optional[concurrent.futures.Future], + ): + if dependency is not None: + dependency.result(_TIMEOUT) + sha256 = write_spec.header.compute_sha256() + sha256.update(write_spec.numpy_tensor.tensor_memory) + write_spec.header.add_sha256(sha256.digest()) + + for w in write_specs: + old_tensor_data_task = w.tensor_data_task + + hash_tasks = [] + if w.include_crc32: + crc32_task = self._computation_pool.submit( + compute_crc32, w, old_tensor_data_task + ) + hash_tasks.append(crc32_task) + if w.include_sha256: + sha256_task = self._computation_pool.submit( + compute_sha256, w, old_tensor_data_task + ) + hash_tasks.append(sha256_task) + + if hash_tasks: + w.tensor_data_task = _FutureGroup(hash_tasks) + self._jobs.extend(hash_tasks) + + def _do_encryption(self, write_specs: Sequence[_WriteSpec]) -> None: + def encrypt(write_spec, dependency): + if dependency is not None: + dependency.result(_TIMEOUT) + try: + write_spec.encryptor.encrypt_all(wait=True, timeout=_TIMEOUT) + except _crypt.CryptographyError as e: + raise CryptographyError("Tensor encryption failed") 
from e + write_spec.header.update_crypt_info() + + for w in write_specs: + if not w.data_length: + continue + w.tensor_data_task = self._encryption_pool.submit( + encrypt, w, w.tensor_data_task + ) + self._jobs.append(w.tensor_data_task) + + def _do_commit_headers(self, write_specs: Sequence[_WriteSpec]) -> None: + # TODO: this is lots of tiny writes. Buffer them for performance + def commit_header(write_spec, dependency): + if dependency is not None: + dependency.result(_TIMEOUT) + self._pwrite( + write_spec.header.metadata_entry, + write_spec.metadata_pos, + verify=len(write_spec.header.metadata_entry), + ) + self._pwrite( + write_spec.header.buffer, + write_spec.header.file_offset, + verify=write_spec.header.size, + ) + + metadata_size = ( + self._metadata_cur - self._metadata_start - 8 + ) # 8 bytes for metadata length field + metadata_size_task = self._header_writer_pool.submit( + self._pwrite, + struct.pack(" Date: Thu, 18 Apr 2024 15:07:50 -0700 Subject: [PATCH 02/11] python 3.8 fixes --- tensorizer/serialization.py | 9 ++++++--- tests/test_serialization.py | 27 +++++++++++++++++++-------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index a2519b2..2e193cc 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -3755,7 +3755,7 @@ class _WriteSpec: def __init__( self, module_index: int, - name: str | _TensorPath, + name: Union[str, _TensorPath], tensor_type: TensorType, tensor: torch.Tensor, ): @@ -3859,12 +3859,14 @@ def _bulk_write(self, write_specs: Iterable[_WriteSpec], incremental=False): self._synchronize_pools() self._sync_prologue_state() except Exception as e: + for j in self._jobs: + j.cancel() if cuda_executor is not None: - cuda_executor.shutdown(wait=False, cancel_futures=True) + cuda_executor.shutdown(wait=False) raise e if cuda_executor is not None: - cuda_executor.shutdown(wait=True, cancel_futures=False) + cuda_executor.shutdown(wait=True) def write_module( self, @@ -4146,6 +4148,7 @@ def _transfer(self, write_spec): for w in cuda_specs: assert w.tensor_data_task is None w.tensor_data_task = cuda_transfer.submit(w) + self._jobs.append(w.tensor_data_task) return cuda_transfer.executor diff --git a/tests/test_serialization.py b/tests/test_serialization.py index aee1975..6c7904c 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -816,8 +816,8 @@ class TestIncrementalSerialization(unittest.TestCase): def test_too_many_no_max_tensors(self): # Any attempt to call write_tensor() more than once will fail if you haven't specified max_tensors - with tempfile.NamedTemporaryFile( - "wb+", delete=True, delete_on_close=False + with NamedTemporaryFileCloseNoDelete( + "wb+", delete=True ) as temporary_file: serializer = TensorSerializer(temporary_file) serializer.write_tensor( @@ -836,8 +836,8 @@ def test_too_many_no_max_tensors(self): def test_too_many(self): # If you set max_tensors too low you'll eventually run out of header space - with tempfile.NamedTemporaryFile( - "wb+", delete=True, delete_on_close=False + with NamedTemporaryFileCloseNoDelete( + "wb+", delete=True ) as temporary_file: serializer = TensorSerializer(temporary_file, max_tensors=2) last_success = 0 @@ -854,8 +854,8 @@ def test_too_many(self): def test_long_tensor_name(self): # Tensors with very long names could still cause space problems even if max_tensors is correct - with tempfile.NamedTemporaryFile( - "wb+", delete=True, delete_on_close=False + with NamedTemporaryFileCloseNoDelete( + 
"wb+", delete=True ) as temporary_file: serializer = TensorSerializer(temporary_file, max_tensors=2) serializer.write_tensor( @@ -877,8 +877,8 @@ def _test_model(self, ser_kwargs, deser_kwargs): model = AutoModelForCausalLM.from_pretrained(model_name) tensors = model.state_dict() - with tempfile.NamedTemporaryFile( - "wb+", delete=True, delete_on_close=False + with NamedTemporaryFileCloseNoDelete( + "wb+", delete=True ) as temporary_file: serializer = TensorSerializer( temporary_file, max_tensors=len(tensors), **ser_kwargs @@ -1478,3 +1478,14 @@ def test_module_verification_fail(self): self.assertTrue( status, f"Unexpected mismatch on {tensor_name}" ) + + +def NamedTemporaryFileCloseNoDelete(*args, **kwargs): + # NamedTemporaryFile(delete_on_close=False) is not available until Python 3.12 + if sys.version_info >= (3, 12): + kwargs["delete_on_close"] = False + return tempfile.NamedTemporaryFile(*args, **kwargs) + + f = tempfile.NamedTemporaryFile(*args, **kwargs) + f.close = f.file.close + return f From cab61c5ae8655dbdcc84811a2edffa5317b49db1 Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Thu, 18 Apr 2024 17:10:34 -0700 Subject: [PATCH 03/11] update numpy_tensor upon clone --- tensorizer/serialization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 2e193cc..e08d4ba 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -4201,6 +4201,7 @@ def _do_clone(write_spec, dependency: Optional[concurrent.futures.Future]): if dependency is not None: dependency.result(_TIMEOUT) write_spec.tensor = write_spec.tensor.clone().detach() + write_spec.numpy_tensor = _NumpyTensor.from_tensor(write_spec.tensor) def _prepare_for_write_encryption( self, write_specs: Sequence[_WriteSpec] From 94d60ea4b2b6023dccd3f4225662c4635edda4be Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Thu, 25 Apr 2024 17:15:41 -0700 Subject: [PATCH 04/11] flatten future groups for wait --- tensorizer/_futuregroup.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorizer/_futuregroup.py b/tensorizer/_futuregroup.py index 8c24592..085de3d 100644 --- a/tensorizer/_futuregroup.py +++ b/tensorizer/_futuregroup.py @@ -51,7 +51,13 @@ def _future_wait_and_raise( # Wait on a list of futures with a timeout. Raise any exceptions, including TimeoutErrors. # otherwise return the list of results in the same order as the input futures. 
results = [] - fs = concurrent.futures.wait(futures, timeout=timeout) + flattened_futures = [] + for f in futures: + if isinstance(f, _FutureGroup): + flattened_futures.extend(f.futures) + else: + flattened_futures.append(f) + fs = concurrent.futures.wait(flattened_futures, timeout=timeout) for f in fs.done: # if the future has an exception, this will raise it results.append(f.result()) From a55e119897b87f098f29dd487de94344601c05e5 Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Fri, 26 Apr 2024 09:19:28 -0700 Subject: [PATCH 05/11] Stop trying to subclass Future --- tensorizer/_futuregroup.py | 41 ++++++++++++++----------------------- tensorizer/serialization.py | 22 ++++++++++++-------- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/tensorizer/_futuregroup.py b/tensorizer/_futuregroup.py index 085de3d..68fd5bb 100644 --- a/tensorizer/_futuregroup.py +++ b/tensorizer/_futuregroup.py @@ -1,9 +1,8 @@ import concurrent.futures -from collections.abc import Callable -from typing import Any, Optional, Sequence +from typing import Any, List, Optional, Sequence, Union -class _FutureGroup(concurrent.futures.Future): +class _FutureGroup: def __init__(self, futures: Sequence[concurrent.futures.Future]): self.futures = futures @@ -22,8 +21,8 @@ def running(self) -> bool: def done(self) -> bool: return all(f.done() for f in self.futures) - def result(self, timeout=None) -> Sequence[Any]: - return _future_wait_and_raise(self.futures, timeout=timeout) + def result(self, timeout=None) -> None: + _future_wait_and_raise(self.futures, timeout=timeout) def exception(self, timeout=None) -> Optional[BaseException]: for f in self.futures: @@ -32,36 +31,26 @@ def exception(self, timeout=None) -> Optional[BaseException]: return exc return None - def add_done_callback(self, _): - raise NotImplementedError() - def set_running_or_notify_cancel(self) -> bool: - raise NotImplementedError() +_Future = Union[_FutureGroup, concurrent.futures.Future] - def set_result(self, _) -> None: - raise NotImplementedError() - def set_exception(self, _) -> None: - raise NotImplementedError() - - -def _future_wait_and_raise( - futures: Sequence[concurrent.futures.Future], timeout=None -) -> Sequence[Any]: +def _future_wait_and_raise(futures: Sequence[_Future], timeout=None) -> None: # Wait on a list of futures with a timeout. Raise any exceptions, including TimeoutErrors. - # otherwise return the list of results in the same order as the input futures. 
- results = [] - flattened_futures = [] - for f in futures: + + flattened_futures: List[concurrent.futures.Future] = [] + futures = list(futures) + while futures: + f = futures.pop() if isinstance(f, _FutureGroup): - flattened_futures.extend(f.futures) + futures.extend(f.futures) else: flattened_futures.append(f) + fs = concurrent.futures.wait(flattened_futures, timeout=timeout) for f in fs.done: # if the future has an exception, this will raise it - results.append(f.result()) + f.result() for f in fs.not_done: # force raise of TimeoutError - results.append(f.result(0)) - return results + f.result(0) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index e08d4ba..ad684c6 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -62,7 +62,11 @@ from tensorizer._crypt._cgroup_cpu_count import ( effective_cpu_count as _effective_cpu_count, ) -from tensorizer._futuregroup import _future_wait_and_raise, _FutureGroup +from tensorizer._futuregroup import ( + _Future, + _future_wait_and_raise, + _FutureGroup, +) from tensorizer._internal_utils import Chunked as _Chunked from tensorizer._internal_utils import _variable_read from tensorizer._NumpyTensor import OPAQUE_DTYPE_SEP, _NumpyTensor @@ -3437,7 +3441,7 @@ def __init__( # to each pool in the same relative order. # Tracks work submitted to all pools to wait for pending work to finish. - self._jobs: List[concurrent.futures.Future] = [] + self._jobs: List[_Future] = [] # Tracks work submitted to the decryption pool to prevent conflicting, # overlapping in-place operations on tensors using shared storage. self._decryption_jobs: typing.MutableMapping[ @@ -3799,7 +3803,7 @@ def __init__( # self.tensor_data_task is a future for processing some contents of self.tensor # e.g. cuda transfer, make_contiguous, hashing, encryption, writing, or decryption. # They are often chained from one step of the process to the next - self.tensor_data_task: Optional[concurrent.futures.Future] = None + self.tensor_data_task: Optional[_Future] = None def set_min_file_version_number(self, version_number): self.min_file_version = max(self.min_file_version, version_number) @@ -4374,7 +4378,7 @@ def _prepare_for_write_hashes( ) -> None: def compute_crc32( write_spec: TensorSerializer._WriteSpec, - dependency: Optional[concurrent.futures.Future], + dependency: Optional[_Future], ): if dependency is not None: dependency.result(_TIMEOUT) @@ -4386,7 +4390,7 @@ def compute_crc32( def compute_sha256( write_spec: TensorSerializer._WriteSpec, - dependency: Optional[concurrent.futures.Future], + dependency: Optional[_Future], ): if dependency is not None: dependency.result(_TIMEOUT) @@ -4414,7 +4418,7 @@ def compute_sha256( self._jobs.extend(hash_tasks) def _do_encryption(self, write_specs: Sequence[_WriteSpec]) -> None: - def encrypt(write_spec, dependency): + def encrypt(write_spec, dependency: _Future): if dependency is not None: dependency.result(_TIMEOUT) try: @@ -4433,7 +4437,7 @@ def encrypt(write_spec, dependency): def _do_commit_headers(self, write_specs: Sequence[_WriteSpec]) -> None: # TODO: this is lots of tiny writes. 
Buffer them for performance - def commit_header(write_spec, dependency): + def commit_header(write_spec, dependency: _Future): if dependency is not None: dependency.result(_TIMEOUT) self._pwrite( @@ -4468,7 +4472,7 @@ def commit_header(write_spec, dependency): def _do_commit_tensor_data(self, write_specs: Sequence[_WriteSpec]): def commit_tensor_data( write_spec: TensorSerializer._WriteSpec, - dependency: Optional[concurrent.futures.Future], + dependency: Optional[_Future], ): if dependency is not None: dependency.result(_TIMEOUT) @@ -4494,7 +4498,7 @@ def commit_tensor_data( def _maybe_decrypt_data(self, write_specs: Sequence[_WriteSpec]): def decrypt( write_spec: TensorSerializer._WriteSpec, - dependency: Optional[concurrent.futures.Future], + dependency: Optional[_Future], ): try: if dependency is not None: From f2009871eedbd03dd146ac65cd69f20d4691d75f Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Fri, 26 Apr 2024 11:11:04 -0700 Subject: [PATCH 06/11] more _FutureGroup typing --- tensorizer/_futuregroup.py | 2 +- tensorizer/serialization.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorizer/_futuregroup.py b/tensorizer/_futuregroup.py index 68fd5bb..ac0210c 100644 --- a/tensorizer/_futuregroup.py +++ b/tensorizer/_futuregroup.py @@ -3,7 +3,7 @@ class _FutureGroup: - def __init__(self, futures: Sequence[concurrent.futures.Future]): + def __init__(self, futures): # type: (Sequence[_Future]) -> None self.futures = futures def cancel(self) -> bool: diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index ad684c6..f5e57ff 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -4201,7 +4201,7 @@ def _prepare_for_write_opaque( w.set_min_file_version_number(OPAQUE_TENSORIZER_VERSION) @staticmethod - def _do_clone(write_spec, dependency: Optional[concurrent.futures.Future]): + def _do_clone(write_spec, dependency: Optional[_Future]): if dependency is not None: dependency.result(_TIMEOUT) write_spec.tensor = write_spec.tensor.clone().detach() @@ -4234,7 +4234,7 @@ def _prepare_for_write_encryption( ] ) - clone_tasks = [] + clone_tasks: List[_Future] = [] for w in shared_write_specs[1:]: clone_tasks.append( self._computation_pool.submit( @@ -4401,7 +4401,7 @@ def compute_sha256( for w in write_specs: old_tensor_data_task = w.tensor_data_task - hash_tasks = [] + hash_tasks: List[_Future] = [] if w.include_crc32: crc32_task = self._computation_pool.submit( compute_crc32, w, old_tensor_data_task @@ -4418,7 +4418,7 @@ def compute_sha256( self._jobs.extend(hash_tasks) def _do_encryption(self, write_specs: Sequence[_WriteSpec]) -> None: - def encrypt(write_spec, dependency: _Future): + def encrypt(write_spec, dependency: Optional[_Future]): if dependency is not None: dependency.result(_TIMEOUT) try: From 03af4b889ed96b69e306ba0ddee3e5b08654e637 Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Fri, 26 Apr 2024 14:18:12 -0700 Subject: [PATCH 07/11] axe numpy_tensor --- tensorizer/serialization.py | 75 ++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index f5e57ff..5840a52 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -8,6 +8,7 @@ import collections.abc import concurrent.futures import contextlib +import ctypes import dataclasses import enum import functools @@ -3778,7 +3779,7 @@ def __init__( self.module_index = module_index self.tensor_type: TensorType = tensor_type self.name = 
_TensorPath.wrap_(name) - self.dtype: Optional[str] = None # _prepare_for_write_numpy_tensor + self.dtype: Optional[str] = None # _prepare_for_write_dtype self.shape = tensor.size() self.data_length = tensor.nbytes # self.file_offset # intentionally omitted, handled by _write_headers() @@ -3789,9 +3790,6 @@ def __init__( ) # Additional payloads that get set and used during the prepare_for_write procedures - self.numpy_tensor: Optional[_NumpyTensor] = ( - None # $et in _prepare_for_write_numpy_tensor - ) self.header: Optional[_TensorHeaderSerializer] = ( None # $et in _prepare_for_write_headers ) @@ -3805,6 +3803,13 @@ def __init__( # They are often chained from one step of the process to the next self.tensor_data_task: Optional[_Future] = None + @property + def tensor_memoryview(self) -> memoryview: + nbytes = self.tensor.element_size() * self.tensor.nelement() + return memoryview( + (ctypes.c_char * nbytes).from_address(self.tensor.data_ptr()) + ) + def set_min_file_version_number(self, version_number): self.min_file_version = max(self.min_file_version, version_number) @@ -3841,8 +3846,7 @@ def _bulk_write(self, write_specs: Iterable[_WriteSpec], incremental=False): try: self._prepare_for_write_contiguous(write_specs) self._prepare_for_write_meta(write_specs) - self._prepare_for_write_numpy_tensor(write_specs) - self._prepare_for_write_opaque(write_specs) + self._prepare_for_write_dtype(write_specs) if self._encrypted: self._prepare_for_write_encryption(write_specs) self._prepare_for_write_headers(write_specs) @@ -4172,33 +4176,31 @@ def make_contiguous(write_spec, dependency): make_contiguous, w, w.tensor_data_task ) - def _prepare_for_write_numpy_tensor( - self, write_specs: Sequence[_WriteSpec] - ): - for w in write_specs: - # all futures are resolved here. This step is not multi-threaded. - if w.tensor_data_task is not None: - w.tensor_data_task.result(_TIMEOUT) - w.tensor_data_task = None - w.numpy_tensor = _NumpyTensor.from_tensor(w.tensor) - w.dtype = w.numpy_tensor.numpy_dtype - if w.numpy_tensor.data.data.nbytes != w.tensor.nbytes: - raise ValueError( - f"Cannot serialize tensor {w.name!r}:" - f" buffer size of underlying memory ({w.numpy_tensor.data.data.nbytes})" - f" doesn't match reported size ({w.tensor.nbytes})" - ) + def _prepare_for_write_dtype(self, write_specs: Sequence[_WriteSpec]): + torch_dtype_to_numpy_dtype_cache: Dict[str, str] = {} - def _prepare_for_write_opaque( - self, write_specs: Sequence[_WriteSpec] - ) -> None: for w in write_specs: - if not w.numpy_tensor.is_opaque: # type: ignore - continue - # The datatype name needs to contain both the numpy dtype that the - # data is serialized as and the original torch dtype. - w.dtype += OPAQUE_DTYPE_SEP + w.numpy_tensor.torch_dtype # type: ignore - w.set_min_file_version_number(OPAQUE_TENSORIZER_VERSION) + tensor_dtype_str = str(w.tensor.dtype) + if _NumpyTensor._is_asymmetric(w.tensor.dtype): + # is opaque + w.dtype = ( + f" Date: Thu, 25 Apr 2024 16:17:23 -0700 Subject: [PATCH 08/11] buffer header writes. 
fallocate is a future --- tensorizer/serialization.py | 90 ++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 35 deletions(-) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 5840a52..90f66e1 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -3813,9 +3813,11 @@ def tensor_memoryview(self) -> memoryview: def set_min_file_version_number(self, version_number): self.min_file_version = max(self.min_file_version, version_number) - def _maybe_fallocate(self, tensors: Sequence[_WriteSpec]): + def _maybe_fallocate( + self, tensors: Sequence[_WriteSpec] + ) -> Optional[concurrent.futures.Future]: if not _syscalls.has_fallocate() or not self._fd: - return + return None next_pos = self._file.tell() size = sum(len(t.name.serialized_()) for t in tensors) @@ -3828,16 +3830,21 @@ def _maybe_fallocate(self, tensors: Sequence[_WriteSpec]): # Rough underestimate of header size header_min_size = 24 size += header_min_size * len(tensors) - _syscalls.try_fallocate( - self._fd, next_pos, size, suppress_all_errors=True + + return self._header_writer_pool.submit( + _syscalls.try_fallocate, + self._fd, + next_pos, + size, + suppress_all_errors=True, ) def _bulk_write(self, write_specs: Iterable[_WriteSpec], incremental=False): write_specs = list(write_specs) + write_dependency: Optional[concurrent.futures.Future] = None if not incremental: - # TODO: make into a future - self._maybe_fallocate(write_specs) + write_dependency = self._maybe_fallocate(write_specs) for w in write_specs: self._path_registry.register_path(w.name) @@ -3854,6 +3861,9 @@ def _bulk_write(self, write_specs: Iterable[_WriteSpec], incremental=False): if self._encrypted: self._do_encryption(write_specs) + if write_dependency: + write_dependency.result(_TIMEOUT) + self._do_commit_headers(write_specs) self._do_commit_tensor_data(write_specs) if self._encrypted: @@ -4434,39 +4444,49 @@ def encrypt(write_spec, dependency: Optional[_Future]): ) self._jobs.append(w.tensor_data_task) - def _do_commit_headers(self, write_specs: Sequence[_WriteSpec]) -> None: - # TODO: this is lots of tiny writes. Buffer them for performance - def commit_header(write_spec, dependency: _Future): - if dependency is not None: - dependency.result(_TIMEOUT) - self._pwrite( - write_spec.header.metadata_entry, - write_spec.metadata_pos, - verify=len(write_spec.header.metadata_entry), - ) + def _do_commit_headers(self, write_specs_: Sequence[_WriteSpec]) -> None: + def do_commit( + write_specs: Sequence[TensorSerializer._WriteSpec], + dependencies: Sequence[_Future], + ): + + header_block_size = self._header_cur - self._metadata_start + header_buffer = bytearray(header_block_size) + + metadata_start = self._metadata_start + metadata_size = ( + self._metadata_cur - metadata_start - 8 + ) # 8 bytes for metadata length field + struct.pack_into(" Date: Fri, 26 Apr 2024 16:07:02 -0700 Subject: [PATCH 09/11] buffer header writes. fallocate is a future. 
Buffering headers won't work in incremental mode, so preserve the old path --- tensorizer/serialization.py | 55 ++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 90f66e1..3e2dc6f 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -4449,7 +4449,7 @@ def do_commit( write_specs: Sequence[TensorSerializer._WriteSpec], dependencies: Sequence[_Future], ): - + # Fast version: makes one buffer containing the size, metadata, and headers, and writes it one go header_block_size = self._header_cur - self._metadata_start header_buffer = bytearray(header_block_size) @@ -4478,15 +4478,50 @@ def do_commit( header_buffer, metadata_start, verify=header_block_size ) - deps = [ - w.tensor_data_task - for w in write_specs_ - if w.tensor_data_task is not None - ] - commit_header_task = self._header_writer_pool.submit( - do_commit, list(write_specs_), deps - ) - self._jobs.append(commit_header_task) + def do_commit_incremental(write_spec, dependency: _Future): + # Slow version: issues one write for each metadata and one write for each header + if dependency is not None: + dependency.result(_TIMEOUT) + self._pwrite( + write_spec.header.metadata_entry, + write_spec.metadata_pos, + verify=len(write_spec.header.metadata_entry), + ) + self._pwrite( + write_spec.header.buffer, + write_spec.header.file_offset, + verify=write_spec.header.size, + ) + + if write_specs_[0].metadata_pos == self._metadata_start + 8: + deps = [ + w.tensor_data_task + for w in write_specs_ + if w.tensor_data_task is not None + ] + commit_header_task = self._header_writer_pool.submit( + do_commit, list(write_specs_), deps + ) + + self._jobs.append(commit_header_task) + else: + # We've already written headers (we're in incremental mode) + # So we can't just batch up and write all of them at once + for w in write_specs_: + # Note this does _not_ set w.tensor_data_task, as committing headers is safe + self._header_writer_pool.submit( + do_commit_incremental, w, w.tensor_data_task + ) + metadata_size = ( + self._metadata_cur - self._metadata_start - 8 + ) # 8 bytes for metadata length field + metadata_size_task = self._header_writer_pool.submit( + self._pwrite, + struct.pack(" Date: Fri, 26 Apr 2024 16:44:01 -0700 Subject: [PATCH 10/11] support hash checks of old 2.x files --- tensorizer/_NumpyTensor.py | 15 --------------- tensorizer/serialization.py | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/tensorizer/_NumpyTensor.py b/tensorizer/_NumpyTensor.py index 8719892..37cb62c 100644 --- a/tensorizer/_NumpyTensor.py +++ b/tensorizer/_NumpyTensor.py @@ -82,8 +82,6 @@ k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES } -OPAQUE_DTYPE_SEP = "\0" - class _NumpyTensor(NamedTuple): data: numpy.ndarray @@ -226,15 +224,6 @@ def is_opaque(self): """ return self._is_opaque(self.numpy_dtype) - @property - def dtype_name(self): - if not self.is_opaque: - return self.numpy_dtype - - # The datatype name needs to contain both the numpy dtype that the - # data is serialized as and the original torch dtype. 
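The do_commit fast path introduced in patch 08 packs the metadata length field, every metadata entry, and every header structure into one contiguous buffer and commits it with a single _pwrite. In outline it amounts to the sketch below, where the 8-byte little-endian length field ("<Q") is an assumption (the exact format string is truncated above) and the per-tensor tuples stand in for the real _WriteSpec/_TensorHeaderSerializer pairs:

    import struct
    from typing import Sequence, Tuple

    def build_header_block(
        metadata_start: int,  # file offset of the metadata length field
        metadata_cur: int,    # file offset one past the last metadata entry
        header_cur: int,      # file offset one past the last header
        entries: Sequence[Tuple[bytes, int, bytes, int]],
    ) -> bytearray:
        # One buffer spanning [metadata_start, header_cur): the length
        # field, then every metadata entry and header at its final offset.
        block = bytearray(header_cur - metadata_start)
        metadata_size = metadata_cur - metadata_start - 8  # exclude length field
        struct.pack_into("<Q", block, 0, metadata_size)
        for metadata_entry, metadata_pos, header_bytes, header_offset in entries:
            rel = metadata_pos - metadata_start
            block[rel : rel + len(metadata_entry)] = metadata_entry
            rel = header_offset - metadata_start
            block[rel : rel + len(header_bytes)] = header_bytes
        return block  # the caller issues one pwrite(block, metadata_start)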
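The dtype_name property being removed here does not change the on-disk convention, it only moves it into serialization.py: a serialized dtype string is either a plain numpy dtype name or, for opaque torch dtypes with no numpy equivalent, the storage dtype and the original torch dtype joined by a NUL byte. A sketch of both directions (the helper names are illustrative):

    OPAQUE_DTYPE_SEP = "\0"

    def encode_dtype(numpy_dtype: str, torch_dtype: str, opaque: bool) -> str:
        # Opaque dtypes (e.g. torch.bfloat16) record both names so the
        # deserializer can reconstruct the real torch dtype.
        if not opaque:
            return numpy_dtype
        return numpy_dtype + OPAQUE_DTYPE_SEP + torch_dtype

    def decode_dtype(name: str):
        # Returns (numpy_dtype, torch_dtype-or-None).
        if OPAQUE_DTYPE_SEP in name:
            return tuple(name.split(OPAQUE_DTYPE_SEP, 1))
        return name, None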
- return self.numpy_dtype + OPAQUE_DTYPE_SEP + self.torch_dtype - @staticmethod def _intermediate_type(size: int) -> torch.dtype: """ @@ -349,7 +338,3 @@ def _decode_torch_dtype(self) -> torch.dtype: raise ValueError(f"Invalid torch_dtype: {self.torch_dtype}") from e return dtype - - @property - def tensor_memory(self): - return self.data.data diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 3e2dc6f..9a1b2e6 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -70,7 +70,7 @@ ) from tensorizer._internal_utils import Chunked as _Chunked from tensorizer._internal_utils import _variable_read -from tensorizer._NumpyTensor import OPAQUE_DTYPE_SEP, _NumpyTensor +from tensorizer._NumpyTensor import _NumpyTensor from tensorizer._tensor_path import ( _TensorPath, _TensorPathComponent, @@ -168,6 +168,10 @@ class TensorType(IntEnum): HEADERS_AT_TOP_TENSORIZER_VERSION = 5 +# The hashable_segment_views used to include the fields that include the hash results themselves. +# These fields were zero when computing hash +HEADER_HASHES_OMIT_HASH_FIELDS = 5 + # To serialize meta tensors into metadata-only tensors # that deserialize back into zeroed-out buffers, data version 4 is required. META_TENSOR_TENSORIZER_VERSION = 4 @@ -185,6 +189,8 @@ class TensorType(IntEnum): TENSORIZER_MAGIC = b"|TZR|" +OPAQUE_DTYPE_SEP = "\0" + _TIMEOUT: typing.Final[int] = 3600 @@ -621,6 +627,7 @@ class _TensorHeaderDeserializer: @classmethod def from_io( cls, + file_version: int, reader: io.BufferedIOBase, zero_hashes: bool = True, check_crypt_info: bool = False, @@ -638,15 +645,20 @@ def from_io( with memoryview(buffer) as mv: reader.readinto(mv[offset:]) return cls( - buffer, zero_hashes=zero_hashes, check_crypt_info=check_crypt_info + file_version, + buffer, + zero_hashes=zero_hashes, + check_crypt_info=check_crypt_info, ) def __init__( self, + file_version: int, buffer: bytearray, zero_hashes: bool = True, check_crypt_info: bool = False, ): + self.file_version = file_version self.buffer = buffer offset = self.header_len_segment.size self.module_idx, tensor_type = self.tensor_info_segment.unpack_from( @@ -682,6 +694,7 @@ def __init__( self._zero_hashes(hashes_slice) if check_crypt_info: + crypt_info_start = offset crypt_info_slice, offset = self.read_crypt_info_block( buffer, offset ) @@ -689,19 +702,27 @@ def __init__( self.crypt_info = _crypt_info.CryptInfo.unpack_from( crypt_info_slice ) + if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS: + self._hashable_segments = ( + slice(None, crypt_info_start), + slice(offset, None), + ) else: self.crypt_info = None - self._hashable_segments = (slice(None, None),) # Finally, get the tensor data length. 
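Whatever slice tuple _hashable_segments ends up holding for a given file version, verification consumes it the same way: each selected view of the header buffer is fed to the digests in order, so the hash fields themselves (and, on older encrypted files, the crypt_info block) are skipped exactly when the slices exclude them. A standalone sketch of that consumption side (not the module's own helper):

    import hashlib
    import zlib
    from typing import Iterable, Tuple

    def hash_header(buffer: bytes, segments: Iterable[slice]) -> Tuple[bytes, int]:
        # Hash the concatenation of the hashable header segments.
        sha = hashlib.sha256()
        crc = 0
        with memoryview(buffer) as mv:
            for s in segments:
                sha.update(mv[s])
                crc = zlib.crc32(mv[s], crc)
        return sha.digest(), crc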
data_length_start = offset = len(buffer) - self.data_length_segment.size self.data_length = self.data_length_segment.unpack_from(buffer, offset)[ 0 ] - self._hashable_segments = ( - slice(None, hash_start), - slice(data_length_start, None), - ) + if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS: + if not check_crypt_info: + self._hashable_segments = (slice(None, None),) + else: + self._hashable_segments = ( + slice(None, hash_start), + slice(data_length_start, None), + ) def _hashable_segment_views(self): for segment_slice in self._hashable_segments: @@ -1703,6 +1724,7 @@ def __init__( raise ValueError("Header offsets overlap or are wrong") self._file.seek(entry.offset) header = _TensorHeaderDeserializer.from_io( + version_number, self._file, zero_hashes=True, check_crypt_info=self._has_crypt_info, @@ -2899,6 +2921,7 @@ def _copy_thread( if unsafe_self._headers is None: header = _TensorHeaderDeserializer.from_io( + unsafe_self._file_header.version_number, file_, zero_hashes=True, check_crypt_info=unsafe_self._has_crypt_info, From 54ad70f6d4f5b18bc68478550d24e7ef0b332931 Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Tue, 25 Jun 2024 15:43:24 -0700 Subject: [PATCH 11/11] updates from PR --- tensorizer/serialization.py | 49 +++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 9a1b2e6..b88f483 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -170,7 +170,7 @@ class TensorType(IntEnum): # The hashable_segment_views used to include the fields that include the hash results themselves. # These fields were zero when computing hash -HEADER_HASHES_OMIT_HASH_FIELDS = 5 +HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION = 5 # To serialize meta tensors into metadata-only tensors # that deserialize back into zeroed-out buffers, data version 4 is required. 
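The tensor_memoryview property that patch 07 added, and that the hunks below guard against discontiguous tensors, is this technique in isolation: a zero-copy view over a tensor's bytes via ctypes. A minimal sketch (only meaningful for contiguous tensors resident in host memory, and only valid while the tensor is alive):

    import ctypes

    import torch

    def tensor_memoryview(t: torch.Tensor) -> memoryview:
        # Zero-copy window onto the tensor's underlying storage.
        if not t.is_contiguous():
            raise BufferError("cannot view a discontiguous tensor")
        nbytes = t.element_size() * t.nelement()
        return memoryview((ctypes.c_char * nbytes).from_address(t.data_ptr()))

    t = torch.arange(4, dtype=torch.int32)
    assert tensor_memoryview(t).nbytes == 16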
@@ -702,7 +702,10 @@ def __init__( self.crypt_info = _crypt_info.CryptInfo.unpack_from( crypt_info_slice ) - if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS: + if ( + self.file_version + < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION + ): self._hashable_segments = ( slice(None, crypt_info_start), slice(offset, None), @@ -715,7 +718,10 @@ def __init__( self.data_length = self.data_length_segment.unpack_from(buffer, offset)[ 0 ] - if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS: + if ( + self.file_version + < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION + ): if not check_crypt_info: self._hashable_segments = (slice(None, None),) else: @@ -1730,7 +1736,7 @@ def __init__( check_crypt_info=self._has_crypt_info, ) if header is None: - raise KeyError("Unexpected empty header") + raise ValueError("Unexpected empty header") self._headers[entry.name] = header else: self._headers = None @@ -3804,7 +3810,7 @@ def __init__( self.name = _TensorPath.wrap_(name) self.dtype: Optional[str] = None # _prepare_for_write_dtype self.shape = tensor.size() - self.data_length = tensor.nbytes + self.data_length = tensor.element_size() * tensor.nelement() # self.file_offset # intentionally omitted, handled by _write_headers() self.include_crc32 = True self.include_sha256 = True @@ -3814,11 +3820,11 @@ def __init__( # Additional payloads that get set and used during the prepare_for_write procedures self.header: Optional[_TensorHeaderSerializer] = ( - None # $et in _prepare_for_write_headers + None # Set in _prepare_for_write_headers ) self.metadata_pos = -1 # Set in _prepare_for_write_headers self.encryptor: Optional[_crypt.ChunkedEncryption] = ( - None # $et in _do_encryption if encrypted + None # Set in _do_encryption if encrypted ) # self.tensor_data_task is a future for processing some contents of self.tensor @@ -3828,6 +3834,12 @@ def __init__( @property def tensor_memoryview(self) -> memoryview: + if not self.tensor.is_contiguous(): + # This is actually possible, but probably not intended here, + # so throw an error instead of handling this case + raise BufferError( + "Cannot create a memoryview of a discontiguous tensor" + ) nbytes = self.tensor.element_size() * self.tensor.nelement() return memoryview( (ctypes.c_char * nbytes).from_address(self.tensor.data_ptr()) @@ -4198,7 +4210,9 @@ def make_contiguous(write_spec, dependency): if dependency is not None: dependency.result(_TIMEOUT) write_spec.tensor = write_spec.tensor.contiguous() - write_spec.data_length = write_spec.tensor.nbytes + write_spec.data_length = ( + write_spec.tensor.element_size() * write_spec.tensor.nelement() + ) write_spec.user_owns_tensor_data = False for w in write_specs: @@ -4218,10 +4232,9 @@ def _prepare_for_write_dtype(self, write_specs: Sequence[_WriteSpec]): # is opaque w.dtype = ( f" None: - assert self._encrypted and self._encryption is not None + if not self._encrypted or self._encryption is None: + raise RuntimeError( + "Tried to encrypt tensors without encryption parameters" + " having been provided" + ) # If any tensors are shared, so we need to clone all but one of them before encrypting write_specs_by_addr: Dict[int, List[TensorSerializer._WriteSpec]] = ( @@ -4270,7 +4287,7 @@ def _prepare_for_write_encryption( ) clone_tasks: List[_Future] = [] - for w in shared_write_specs[1:]: + for w in shared_write_specs[:-1]: clone_tasks.append( self._computation_pool.submit( self._do_clone, w, clone_dependencies @@ -4278,10 +4295,10 @@ def _prepare_for_write_encryption( ) w.user_owns_tensor_data = False - 
shared_write_specs[0].tensor_data_task = _FutureGroup( + shared_write_specs[-1].tensor_data_task = _FutureGroup( clone_tasks + clone_dependencies.futures ) - for w in shared_write_specs[1:]: + for w in shared_write_specs[:-1]: w.tensor_data_task = _FutureGroup(clone_tasks) for w in write_specs: @@ -4340,7 +4357,7 @@ def _prepare_for_write_headers( dtype_bytes, w.shape, w.data_length, - 0, # bogus file_offset. This gets filled in in build() + 0, # placeholder file_offset include_crc32=w.include_crc32, include_sha256=w.include_sha256, crypt_info=w.crypt_info,
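With the [1:] to [:-1] flip above, every spec except the last one sharing a storage gets a clone before in-place encryption; otherwise encrypting one tensor would scramble the plaintext of the others sharing its buffer. The grouping that feeds this step can be sketched as follows, where keying on untyped_storage().data_ptr() is an assumption (the diff only shows a Dict[int, List[_WriteSpec]] keyed by address):

    from collections import defaultdict
    from typing import Dict, List

    import torch

    def plan_clones(tensors: List[torch.Tensor]) -> Dict[int, List[int]]:
        # Group tensor indices by the address of their underlying storage.
        by_addr: Dict[int, List[int]] = defaultdict(list)
        for i, t in enumerate(tensors):
            by_addr[t.untyped_storage().data_ptr()].append(i)
        # All but one tensor per shared storage must be cloned before
        # in-place encryption touches the shared buffer.
        return {a: idxs[:-1] for a, idxs in by_addr.items() if len(idxs) > 1}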