support hash checks of old 2.x files
bchess committed Apr 26, 2024
1 parent 613e555 commit 24b90d4
Showing 2 changed files with 30 additions and 22 deletions.
15 changes: 0 additions & 15 deletions tensorizer/_NumpyTensor.py
@@ -82,8 +82,6 @@
     k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES
 }
 
-OPAQUE_DTYPE_SEP = "\0"
-
 
 class _NumpyTensor(NamedTuple):
     data: numpy.ndarray
@@ -226,15 +224,6 @@ def is_opaque(self):
         """
         return self._is_opaque(self.numpy_dtype)
 
-    @property
-    def dtype_name(self):
-        if not self.is_opaque:
-            return self.numpy_dtype
-
-        # The datatype name needs to contain both the numpy dtype that the
-        # data is serialized as and the original torch dtype.
-        return self.numpy_dtype + OPAQUE_DTYPE_SEP + self.torch_dtype
-
     @staticmethod
     def _intermediate_type(size: int) -> torch.dtype:
         """
@@ -349,7 +338,3 @@ def _decode_torch_dtype(self) -> torch.dtype:
             raise ValueError(f"Invalid torch_dtype: {self.torch_dtype}") from e
 
         return dtype
-
-    @property
-    def tensor_memory(self):
-        return self.data.data
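
Note: the removed `dtype_name` property encoded an opaque tensor's dtype by joining the serialized numpy dtype and the original torch dtype with `OPAQUE_DTYPE_SEP`. A minimal sketch of that scheme; the standalone function names here are illustrative, not tensorizer's API:

```python
OPAQUE_DTYPE_SEP = "\0"  # now defined in tensorizer/serialization.py

def encode_dtype_name(numpy_dtype: str, torch_dtype: str, is_opaque: bool) -> str:
    """Build the serialized dtype name for a tensor."""
    if not is_opaque:
        return numpy_dtype
    # An opaque dtype name must carry both the numpy dtype the data is
    # serialized as and the original torch dtype.
    return numpy_dtype + OPAQUE_DTYPE_SEP + torch_dtype

def decode_dtype_name(name: str):
    """Split a dtype name back into (numpy_dtype, torch_dtype or None)."""
    numpy_dtype, sep, torch_dtype = name.partition(OPAQUE_DTYPE_SEP)
    return numpy_dtype, (torch_dtype if sep else None)
```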
37 changes: 30 additions & 7 deletions tensorizer/serialization.py
@@ -70,7 +70,7 @@
 )
 from tensorizer._internal_utils import Chunked as _Chunked
 from tensorizer._internal_utils import _variable_read
-from tensorizer._NumpyTensor import OPAQUE_DTYPE_SEP, _NumpyTensor
+from tensorizer._NumpyTensor import _NumpyTensor
 from tensorizer._tensor_path import (
     _TensorPath,
     _TensorPathComponent,
@@ -168,6 +168,10 @@ class TensorType(IntEnum):
 
 HEADERS_AT_TOP_TENSORIZER_VERSION = 5
 
+# The hashable segment views used to include the fields that contain the hash
+# results themselves; those fields were zeroed when computing the hashes.
+HEADER_HASHES_OMIT_HASH_FIELDS = 5
+
 # To serialize meta tensors into metadata-only tensors
 # that deserialize back into zeroed-out buffers, data version 4 is required.
 META_TENSOR_TENSORIZER_VERSION = 4
@@ -185,6 +189,8 @@ class TensorType(IntEnum):
 
 TENSORIZER_MAGIC = b"|TZR|"
 
+OPAQUE_DTYPE_SEP = "\0"
+
 _TIMEOUT: typing.Final[int] = 3600
 

@@ -621,6 +627,7 @@ class _TensorHeaderDeserializer:
     @classmethod
     def from_io(
         cls,
+        file_version: int,
         reader: io.BufferedIOBase,
         zero_hashes: bool = True,
         check_crypt_info: bool = False,
@@ -638,15 +645,20 @@ def from_io(
         with memoryview(buffer) as mv:
             reader.readinto(mv[offset:])
         return cls(
-            buffer, zero_hashes=zero_hashes, check_crypt_info=check_crypt_info
+            file_version,
+            buffer,
+            zero_hashes=zero_hashes,
+            check_crypt_info=check_crypt_info,
         )
 
     def __init__(
         self,
+        file_version: int,
         buffer: bytearray,
         zero_hashes: bool = True,
         check_crypt_info: bool = False,
     ):
+        self.file_version = file_version
         self.buffer = buffer
         offset = self.header_len_segment.size
         self.module_idx, tensor_type = self.tensor_info_segment.unpack_from(
@@ -682,26 +694,35 @@ def __init__(
             self._zero_hashes(hashes_slice)
 
         if check_crypt_info:
+            crypt_info_start = offset
             crypt_info_slice, offset = self.read_crypt_info_block(
                 buffer, offset
             )
             with crypt_info_slice:
                 self.crypt_info = _crypt_info.CryptInfo.unpack_from(
                     crypt_info_slice
                 )
+            if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS:
+                self._hashable_segments = (
+                    slice(None, crypt_info_start),
+                    slice(offset, None),
+                )
         else:
             self.crypt_info = None
-            self._hashable_segments = (slice(None, None),)
 
         # Finally, get the tensor data length.
         data_length_start = offset = len(buffer) - self.data_length_segment.size
         self.data_length = self.data_length_segment.unpack_from(buffer, offset)[
             0
         ]
-        self._hashable_segments = (
-            slice(None, hash_start),
-            slice(data_length_start, None),
-        )
+        if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS:
+            if not check_crypt_info:
+                self._hashable_segments = (slice(None, None),)
+        else:
+            self._hashable_segments = (
+                slice(None, hash_start),
+                slice(data_length_start, None),
+            )
 
     def _hashable_segment_views(self):
         for segment_slice in self._hashable_segments:
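
Note: the version gate above is the core of the fix. Files older than format version 5 hashed the whole header with the hash-result fields zeroed in place; version 5 and later omit the hash fields (and any crypt-info block) from the hashed bytes entirely. A rough sketch of the selection, assuming a SHA-256 digest and illustrative names (the crypt-info variant for old files is omitted for brevity):

```python
import hashlib

HEADER_HASHES_OMIT_HASH_FIELDS = 5

def header_digest(buffer: bytes, file_version: int,
                  hash_start: int, data_length_start: int) -> bytes:
    if file_version < HEADER_HASHES_OMIT_HASH_FIELDS:
        # Old 2.x files: hash the entire header; the hash-result fields were
        # zeroed in the buffer beforehand (what zero_hashes=True arranges).
        segments = (slice(None, None),)
    else:
        # Newer files: skip the hash fields and crypt info altogether.
        segments = (slice(None, hash_start), slice(data_length_start, None))
    digest = hashlib.sha256()
    for s in segments:
        digest.update(buffer[s])
    return digest.digest()
```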
@@ -1703,6 +1724,7 @@ def __init__(
                 raise ValueError("Header offsets overlap or are wrong")
             self._file.seek(entry.offset)
             header = _TensorHeaderDeserializer.from_io(
+                version_number,
                 self._file,
                 zero_hashes=True,
                 check_crypt_info=self._has_crypt_info,
@@ -2899,6 +2921,7 @@ def _copy_thread(
 
         if unsafe_self._headers is None:
             header = _TensorHeaderDeserializer.from_io(
+                unsafe_self._file_header.version_number,
                 file_,
                 zero_hashes=True,
                 check_crypt_info=unsafe_self._has_crypt_info,
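
Note: with the format version threaded through both call sites, hash verification now also succeeds on files written by tensorizer 2.x. A hedged usage sketch; the file path is a placeholder:

```python
from tensorizer import TensorDeserializer

# verify_hash=True checks the stored tensor hashes during deserialization;
# after this commit that also works for pre-v5 (tensorizer 2.x) files.
deserializer = TensorDeserializer("model.tensors", verify_hash=True)
```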