updates from PR
bchess committed Jun 25, 2024
1 parent 24b90d4 commit 54ad70f
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions tensorizer/serialization.py
@@ -170,7 +170,7 @@ class TensorType(IntEnum):
 
 # The hashable_segment_views used to include the fields that include the hash results themselves.
 # These fields were zero when computing hash
-HEADER_HASHES_OMIT_HASH_FIELDS = 5
+HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION = 5
 
 # To serialize meta tensors into metadata-only tensors
 # that deserialize back into zeroed-out buffers, data version 4 is required.
@@ -702,7 +702,10 @@ def __init__(
             self.crypt_info = _crypt_info.CryptInfo.unpack_from(
                 crypt_info_slice
             )
-            if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS:
+            if (
+                self.file_version
+                < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION
+            ):
                 self._hashable_segments = (
                     slice(None, crypt_info_start),
                     slice(offset, None),
@@ -715,7 +718,10 @@ def __init__(
         self.data_length = self.data_length_segment.unpack_from(buffer, offset)[
             0
         ]
-        if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS:
+        if (
+            self.file_version
+            < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION
+        ):
             if not check_crypt_info:
                 self._hashable_segments = (slice(None, None),)
             else:
@@ -1730,7 +1736,7 @@ def __init__(
                     check_crypt_info=self._has_crypt_info,
                 )
                 if header is None:
-                    raise KeyError("Unexpected empty header")
+                    raise ValueError("Unexpected empty header")
                 self._headers[entry.name] = header
             else:
                 self._headers = None
@@ -3804,7 +3810,7 @@ def __init__(
             self.name = _TensorPath.wrap_(name)
             self.dtype: Optional[str] = None  # _prepare_for_write_dtype
             self.shape = tensor.size()
-            self.data_length = tensor.nbytes
+            self.data_length = tensor.element_size() * tensor.nelement()
             # self.file_offset  # intentionally omitted, handled by _write_headers()
             self.include_crc32 = True
             self.include_sha256 = True
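
For reference, the replacement expression computes the tensor's logical byte count directly from its element size and element count; a minimal sketch, not part of the commit, with illustrative shapes and dtypes:

import torch

# Minimal sketch: the logical byte count of a tensor is its element size
# times its element count.
t = torch.zeros(4, 3, dtype=torch.float16)
assert t.element_size() * t.nelement() == 2 * 12 == 24

# The product reflects the tensor's logical size even for a non-contiguous
# view, whose backing storage may be laid out differently.
view = torch.zeros(8, 3, dtype=torch.float16)[::2]
assert view.element_size() * view.nelement() == 24
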
@@ -3814,11 +3820,11 @@ def __init__(
 
             # Additional payloads that get set and used during the prepare_for_write procedures
             self.header: Optional[_TensorHeaderSerializer] = (
-                None  # $et in _prepare_for_write_headers
+                None  # Set in _prepare_for_write_headers
             )
             self.metadata_pos = -1  # Set in _prepare_for_write_headers
             self.encryptor: Optional[_crypt.ChunkedEncryption] = (
-                None  # $et in _do_encryption if encrypted
+                None  # Set in _do_encryption if encrypted
             )
 
             # self.tensor_data_task is a future for processing some contents of self.tensor
@@ -3828,6 +3834,12 @@ def __init__(
 
         @property
        def tensor_memoryview(self) -> memoryview:
+            if not self.tensor.is_contiguous():
+                # This is actually possible, but probably not intended here,
+                # so throw an error instead of handling this case
+                raise BufferError(
+                    "Cannot create a memoryview of a discontiguous tensor"
+                )
             nbytes = self.tensor.element_size() * self.tensor.nelement()
             return memoryview(
                 (ctypes.c_char * nbytes).from_address(self.tensor.data_ptr())
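
The guard added in this hunk matters because the ctypes window starts at data_ptr() and spans exactly element_size() * nelement() bytes, which only lines up with the tensor's elements when the layout is contiguous. A standalone sketch of the same idea for a CPU tensor (the helper name is hypothetical, not the library's API):

import ctypes

import torch

def tensor_as_memoryview(t: torch.Tensor) -> memoryview:
    # Hypothetical standalone helper mirroring the property above; assumes a
    # CPU tensor, since data_ptr() of a CUDA tensor is not host-addressable.
    if not t.is_contiguous():
        raise BufferError("Cannot create a memoryview of a discontiguous tensor")
    nbytes = t.element_size() * t.nelement()
    return memoryview((ctypes.c_char * nbytes).from_address(t.data_ptr()))

mv = tensor_as_memoryview(torch.arange(6, dtype=torch.int32))
assert len(mv) == 24  # 6 elements * 4 bytes

try:
    tensor_as_memoryview(torch.zeros(3, 4).t())  # transposed view is not contiguous
except BufferError:
    pass  # expected
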
@@ -4198,7 +4210,9 @@ def make_contiguous(write_spec, dependency):
             if dependency is not None:
                 dependency.result(_TIMEOUT)
             write_spec.tensor = write_spec.tensor.contiguous()
-            write_spec.data_length = write_spec.tensor.nbytes
+            write_spec.data_length = (
+                write_spec.tensor.element_size() * write_spec.tensor.nelement()
+            )
             write_spec.user_owns_tensor_data = False
 
         for w in write_specs:
@@ -4218,10 +4232,9 @@ def _prepare_for_write_dtype(self, write_specs: Sequence[_WriteSpec]):
                 # is opaque
                 w.dtype = (
                     f"<V{w.tensor.element_size():d}"
-                    + OPAQUE_DTYPE_SEP
-                    + tensor_dtype_str
+                    f"{OPAQUE_DTYPE_SEP}"
+                    f"{tensor_dtype_str}"
                 )
-                w.set_min_file_version_number(OPAQUE_TENSORIZER_VERSION)
             else:
                 w.dtype = torch_dtype_to_numpy_dtype_cache.get(
                     tensor_dtype_str, None
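
The rewrite in this hunk relies on implicit concatenation of adjacent (f-)string literals, so the opaque dtype tag comes out identical to the old "+"-based version; a sketch with made-up values (the separator below is a placeholder, not the library's actual OPAQUE_DTYPE_SEP):

OPAQUE_DTYPE_SEP = "\xff"  # placeholder separator for illustration only
element_size = 2
tensor_dtype_str = "bfloat16"

old_style = f"<V{element_size:d}" + OPAQUE_DTYPE_SEP + tensor_dtype_str
new_style = (
    f"<V{element_size:d}"
    f"{OPAQUE_DTYPE_SEP}"
    f"{tensor_dtype_str}"
)
assert old_style == new_style == "<V2\xffbfloat16"
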
@@ -4245,7 +4258,11 @@ def _do_clone(write_spec, dependency: Optional[_Future]):
     def _prepare_for_write_encryption(
         self, write_specs: Sequence[_WriteSpec]
     ) -> None:
-        assert self._encrypted and self._encryption is not None
+        if not self._encrypted or self._encryption is None:
+            raise RuntimeError(
+                "Tried to encrypt tensors without encryption parameters"
+                " having been provided"
+            )
 
         # If any tensors are shared, so we need to clone all but one of them before encrypting
         write_specs_by_addr: Dict[int, List[TensorSerializer._WriteSpec]] = (
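
Replacing the assert with an explicit raise keeps the precondition in force even when Python strips assert statements under -O; a minimal sketch of the pattern, with a hypothetical function rather than the serializer's API:

def require_encryption_params(encrypted: bool, encryption) -> None:
    # Hypothetical guard illustrating the pattern: an explicit raise, unlike
    # an assert, still fires when Python runs with -O / PYTHONOPTIMIZE.
    if not encrypted or encryption is None:
        raise RuntimeError(
            "Tried to encrypt tensors without encryption parameters"
            " having been provided"
        )
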
@@ -4270,18 +4287,18 @@ def _prepare_for_write_encryption(
             )
 
             clone_tasks: List[_Future] = []
-            for w in shared_write_specs[1:]:
+            for w in shared_write_specs[:-1]:
                 clone_tasks.append(
                     self._computation_pool.submit(
                         self._do_clone, w, clone_dependencies
                     )
                 )
                 w.user_owns_tensor_data = False
 
-            shared_write_specs[0].tensor_data_task = _FutureGroup(
+            shared_write_specs[-1].tensor_data_task = _FutureGroup(
                 clone_tasks + clone_dependencies.futures
            )
-            for w in shared_write_specs[1:]:
+            for w in shared_write_specs[:-1]:
                 w.tensor_data_task = _FutureGroup(clone_tasks)
 
         for w in write_specs:
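
For context on the index flip in this hunk: write specs are grouped by storage address so that tensors sharing memory are cloned, all but one, before encryption; after this change it is the last spec in each group, rather than the first, that keeps referencing the caller's original storage. An illustrative sketch with hypothetical names:

from collections import defaultdict

import torch

# Group tensors by storage address, then clone every one except the last,
# so exactly one entry per group keeps pointing at the original memory.
tensors = {"a": torch.ones(4), "c": torch.zeros(4)}
tensors["b"] = tensors["a"]  # "b" aliases "a"

by_addr = defaultdict(list)
for name, t in tensors.items():
    by_addr[t.data_ptr()].append(name)

for names in by_addr.values():
    for name in names[:-1]:  # mirrors shared_write_specs[:-1]
        tensors[name] = tensors[name].clone()

assert tensors["a"].data_ptr() != tensors["b"].data_ptr()
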
@@ -4340,7 +4357,7 @@ def _prepare_for_write_headers(
                 dtype_bytes,
                 w.shape,
                 w.data_length,
-                0,  # bogus file_offset. This gets filled in in build()
+                0,  # placeholder file_offset
                 include_crc32=w.include_crc32,
                 include_sha256=w.include_sha256,
                 crypt_info=w.crypt_info,
