diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py
index 9a1b2e6..b88f483 100644
--- a/tensorizer/serialization.py
+++ b/tensorizer/serialization.py
@@ -170,7 +170,7 @@ class TensorType(IntEnum):
 
 # The hashable_segment_views used to include the fields that include the hash results themselves.
 # These fields were zero when computing hash
-HEADER_HASHES_OMIT_HASH_FIELDS = 5
+HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION = 5
 
 # To serialize meta tensors into metadata-only tensors
 # that deserialize back into zeroed-out buffers, data version 4 is required.
@@ -702,7 +702,10 @@ def __init__(
             self.crypt_info = _crypt_info.CryptInfo.unpack_from(
                 crypt_info_slice
             )
-            if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS:
+            if (
+                self.file_version
+                < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION
+            ):
                 self._hashable_segments = (
                     slice(None, crypt_info_start),
                     slice(offset, None),
@@ -715,7 +718,10 @@ def __init__(
         self.data_length = self.data_length_segment.unpack_from(buffer, offset)[
             0
         ]
-        if self.file_version < HEADER_HASHES_OMIT_HASH_FIELDS:
+        if (
+            self.file_version
+            < HEADER_HASHES_OMIT_HASH_FIELDS_TENSORIZER_VERSION
+        ):
             if not check_crypt_info:
                 self._hashable_segments = (slice(None, None),)
             else:
@@ -1730,7 +1736,7 @@ def __init__(
                     check_crypt_info=self._has_crypt_info,
                 )
                 if header is None:
-                    raise KeyError("Unexpected empty header")
+                    raise ValueError("Unexpected empty header")
                 self._headers[entry.name] = header
         else:
             self._headers = None
@@ -3804,7 +3810,7 @@ def __init__(
             self.name = _TensorPath.wrap_(name)
             self.dtype: Optional[str] = None  # _prepare_for_write_dtype
             self.shape = tensor.size()
-            self.data_length = tensor.nbytes
+            self.data_length = tensor.element_size() * tensor.nelement()
             # self.file_offset  # intentionally omitted, handled by _write_headers()
             self.include_crc32 = True
             self.include_sha256 = True
@@ -3814,11 +3820,11 @@ def __init__(
 
             # Additional payloads that get set and used during the prepare_for_write procedures
             self.header: Optional[_TensorHeaderSerializer] = (
-                None  # $et in _prepare_for_write_headers
+                None  # Set in _prepare_for_write_headers
             )
             self.metadata_pos = -1  # Set in _prepare_for_write_headers
             self.encryptor: Optional[_crypt.ChunkedEncryption] = (
-                None  # $et in _do_encryption if encrypted
+                None  # Set in _do_encryption if encrypted
             )
 
             # self.tensor_data_task is a future for processing some contents of self.tensor
@@ -3828,6 +3834,12 @@ def __init__(
 
         @property
         def tensor_memoryview(self) -> memoryview:
+            if not self.tensor.is_contiguous():
+                # This is actually possible, but probably not intended here,
+                # so throw an error instead of handling this case
+                raise BufferError(
+                    "Cannot create a memoryview of a discontiguous tensor"
+                )
             nbytes = self.tensor.element_size() * self.tensor.nelement()
             return memoryview(
                 (ctypes.c_char * nbytes).from_address(self.tensor.data_ptr())
@@ -4198,7 +4210,9 @@ def make_contiguous(write_spec, dependency):
             if dependency is not None:
                 dependency.result(_TIMEOUT)
             write_spec.tensor = write_spec.tensor.contiguous()
-            write_spec.data_length = write_spec.tensor.nbytes
+            write_spec.data_length = (
+                write_spec.tensor.element_size() * write_spec.tensor.nelement()
+            )
             write_spec.user_owns_tensor_data = False
 
         for w in write_specs:
@@ -4218,10 +4232,9 @@ def _prepare_for_write_dtype(self, write_specs: Sequence[_WriteSpec]):
             # is opaque
             w.dtype = (
                 f"
     ) -> None:
-        assert self._encrypted and self._encryption is not None
+        if not self._encrypted or self._encryption is None:
+            raise RuntimeError(
+                "Tried to encrypt tensors without encryption parameters"
+                " having been provided"
+            )
 
         # If any tensors are shared, so we need to clone all but one of them before encrypting
         write_specs_by_addr: Dict[int, List[TensorSerializer._WriteSpec]] = (
@@ -4270,7 +4287,7 @@ def _prepare_for_write_encryption(
             )
             clone_tasks: List[_Future] = []
-            for w in shared_write_specs[1:]:
+            for w in shared_write_specs[:-1]:
                 clone_tasks.append(
                     self._computation_pool.submit(
                         self._do_clone, w, clone_dependencies
                     )
                 )
                 w.user_owns_tensor_data = False
-            shared_write_specs[0].tensor_data_task = _FutureGroup(
+            shared_write_specs[-1].tensor_data_task = _FutureGroup(
                 clone_tasks + clone_dependencies.futures
             )
-            for w in shared_write_specs[1:]:
+            for w in shared_write_specs[:-1]:
                 w.tensor_data_task = _FutureGroup(clone_tasks)
 
         for w in write_specs:
@@ -4340,7 +4357,7 @@ def _prepare_for_write_headers(
                 dtype_bytes,
                 w.shape,
                 w.data_length,
-                0,  # bogus file_offset. This gets filled in in build()
+                0,  # placeholder file_offset
                 include_crc32=w.include_crc32,
                 include_sha256=w.include_sha256,
                 crypt_info=w.crypt_info,