From 24b90d43ac3834fcccedc4f673c58e26506eba05 Mon Sep 17 00:00:00 2001 From: Ben Chess Date: Fri, 26 Apr 2024 16:44:01 -0700 Subject: [PATCH] support hash checks of old 2.x files --- tensorizer/_NumpyTensor.py | 15 --------------- tensorizer/serialization.py | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/tensorizer/_NumpyTensor.py b/tensorizer/_NumpyTensor.py index 8719892..37cb62c 100644 --- a/tensorizer/_NumpyTensor.py +++ b/tensorizer/_NumpyTensor.py @@ -82,8 +82,6 @@ k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES } -OPAQUE_DTYPE_SEP = "\0" - class _NumpyTensor(NamedTuple): data: numpy.ndarray @@ -226,15 +224,6 @@ def is_opaque(self): """ return self._is_opaque(self.numpy_dtype) - @property - def dtype_name(self): - if not self.is_opaque: - return self.numpy_dtype - - # The datatype name needs to contain both the numpy dtype that the - # data is serialized as and the original torch dtype. - return self.numpy_dtype + OPAQUE_DTYPE_SEP + self.torch_dtype - @staticmethod def _intermediate_type(size: int) -> torch.dtype: """ @@ -349,7 +338,3 @@ def _decode_torch_dtype(self) -> torch.dtype: raise ValueError(f"Invalid torch_dtype: {self.torch_dtype}") from e return dtype - - @property - def tensor_memory(self): - return self.data.data diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 3e2dc6f..9a1b2e6 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -70,7 +70,7 @@ ) from tensorizer._internal_utils import Chunked as _Chunked from tensorizer._internal_utils import _variable_read -from tensorizer._NumpyTensor import OPAQUE_DTYPE_SEP, _NumpyTensor +from tensorizer._NumpyTensor import _NumpyTensor from tensorizer._tensor_path import ( _TensorPath, _TensorPathComponent, @@ -168,6 +168,10 @@ class TensorType(IntEnum): HEADERS_AT_TOP_TENSORIZER_VERSION = 5 +# The hashable_segment_views used to include the fields that include the hash results themselves. +# These fields were zero when computing hash +HEADER_HASHES_OMIT_HASH_FIELDS = 5 + # To serialize meta tensors into metadata-only tensors # that deserialize back into zeroed-out buffers, data version 4 is required. META_TENSOR_TENSORIZER_VERSION = 4 @@ -185,6 +189,8 @@ class TensorType(IntEnum): TENSORIZER_MAGIC = b"|TZR|" +OPAQUE_DTYPE_SEP = "\0" + _TIMEOUT: typing.Final[int] = 3600 @@ -621,6 +627,7 @@ class _TensorHeaderDeserializer: @classmethod def from_io( cls, + file_version: int, reader: io.BufferedIOBase, zero_hashes: bool = True, check_crypt_info: bool = False, @@ -638,15 +645,20 @@ def from_io( with memoryview(buffer) as mv: reader.readinto(mv[offset:]) return cls( - buffer, zero_hashes=zero_hashes, check_crypt_info=check_crypt_info + file_version, + buffer, + zero_hashes=zero_hashes, + check_crypt_info=check_crypt_info, ) def __init__( self, + file_version: int, buffer: bytearray, zero_hashes: bool = True, check_crypt_info: bool = False, ): + self.file_version = file_version self.buffer = buffer offset = self.header_len_segment.size self.module_idx, tensor_type = self.tensor_info_segment.unpack_from( @@ -682,6 +694,7 @@ def __init__( self._zero_hashes(hashes_slice) if check_crypt_info: + crypt_info_start = offset crypt_info_slice, offset = self.read_crypt_info_block( buffer, offset ) @@ -689,19 +702,27 @@ def __init__( self.crypt_info = _crypt_info.CryptInfo.unpack_from( crypt_info_slice ) + if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS: + self._hashable_segments = ( + slice(None, crypt_info_start), + slice(offset, None), + ) else: self.crypt_info = None - self._hashable_segments = (slice(None, None),) # Finally, get the tensor data length. data_length_start = offset = len(buffer) - self.data_length_segment.size self.data_length = self.data_length_segment.unpack_from(buffer, offset)[ 0 ] - self._hashable_segments = ( - slice(None, hash_start), - slice(data_length_start, None), - ) + if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS: + if not check_crypt_info: + self._hashable_segments = (slice(None, None),) + else: + self._hashable_segments = ( + slice(None, hash_start), + slice(data_length_start, None), + ) def _hashable_segment_views(self): for segment_slice in self._hashable_segments: @@ -1703,6 +1724,7 @@ def __init__( raise ValueError("Header offsets overlap or are wrong") self._file.seek(entry.offset) header = _TensorHeaderDeserializer.from_io( + version_number, self._file, zero_hashes=True, check_crypt_info=self._has_crypt_info, @@ -2899,6 +2921,7 @@ def _copy_thread( if unsafe_self._headers is None: header = _TensorHeaderDeserializer.from_io( + unsafe_self._file_header.version_number, file_, zero_hashes=True, check_crypt_info=unsafe_self._has_crypt_info,