diff --git a/pyproject.toml b/pyproject.toml index 1a478ce..25085ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ dependencies = [ "psutil>=5.9.4", "boto3>=1.26.0", "redis>=5.0.0", - "hiredis>=2.2.0" + "hiredis>=2.2.0", + "pynacl>=1.5.0", ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index b2edffb..237cd6c 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -12,6 +12,7 @@ import io import itertools import logging +import math import mmap import os import queue @@ -39,6 +40,8 @@ Union, ) +import nacl.secret +import nacl.utils import numpy import redis import torch @@ -46,7 +49,7 @@ import tensorizer.stream_io as stream_io import tensorizer.utils as utils from tensorizer._NumpyTensor import _NumpyTensor -from tensorizer.stream_io import CURLStreamFile +from tensorizer.stream_io import CURLStreamFile, DecryptedStream if torch.cuda.is_available(): cudart = torch.cuda.cudart() @@ -82,6 +85,7 @@ class TensorType(Enum): class HashType(Enum): CRC32 = 0 SHA256 = 1 + XSALSA20 = 2 @dataclasses.dataclass(order=True) @@ -193,7 +197,7 @@ class _TensorHeaderSerializer: "B" # Hash count (fixed for a particular tensorizer version) ) hash_header_offset: int - hash_count: ClassVar[int] = 2 + hash_count: int crc32_hash_segment: ClassVar[struct.Struct] = struct.Struct( "<" @@ -211,6 +215,16 @@ class _TensorHeaderSerializer: ) sha256_hash_offset: int + xsalsa20_hash_segment: ClassVar[struct.Struct] = struct.Struct( + "<" + "B" # XSalsa20 hash type + "B" # XSalsa20 length + "32s" # 32-byte salt + "24s" # 24-byte nonce + "B" # Crypto block size in # of shifts + ) + xsalsa20_hash_offset: int + data_length_segment: ClassVar[struct.Struct] = struct.Struct( " Tuple[bytes, bytes, int]: + salt = None + nonce = None + block_sz = None + for hash_entry in hashes: + if hash_entry.type == HashType.XSALSA20: + salt = hash_entry.hash[:32] + nonce = hash_entry.hash[32:56] + 
def encrypt_tensor() -> Tuple[Optional[bytes], Optional[bytes]]:
    """
    Encrypt the tensor's bytes in fixed-size chunks with the serializer's
    lockbox.

    The plaintext is split into pieces of ``self._cleartext_chunk_size``
    bytes. Each piece is encrypted separately and stored at a stride of
    ``self._crypt_chunk_size`` bytes in the output; the final piece may
    be shorter.

    Returns:
        ``(None, None)`` when no lockbox is configured (encryption is
        disabled); otherwise ``(nonce, ciphertext)`` where ``nonce`` is
        the randomly generated base nonce and ``ciphertext`` is a
        ``bytearray`` holding every encrypted chunk back to back.
        A zero-length tensor yields an empty ``bytearray``.
    """
    if self._lockbox is None:
        return None, None

    start = time.monotonic()
    nonce_size = nacl.secret.SecretBox.NONCE_SIZE
    nonce = nacl.utils.random(nonce_size)
    nonce_int = int.from_bytes(nonce, "big", signed=False)

    total_bytes = tensor_memory.nbytes
    clear_chunk_sz = self._cleartext_chunk_size
    crypt_chunk_sz = self._crypt_chunk_size

    # Integer ceiling division; avoids the float round-trip of
    # math.ceil(a / b).
    num_chunks = (total_bytes + clear_chunk_sz - 1) // clear_chunk_sz
    cryptotext = bytearray(num_chunks * crypt_chunk_sz)
    cryptotext_end = 0
    tensor_bytes = memoryview(tensor_memory.tobytes())
    setup_end = time.monotonic()

    for i in range(num_chunks):
        # We XOR the nonce with the chunk index to avoid nonce reuse.
        step_nonce_bytes = (nonce_int ^ i).to_bytes(
            nonce_size, "big", signed=False
        )
        plaintext_begin = i * clear_chunk_sz
        plaintext_end = min(plaintext_begin + clear_chunk_sz, total_bytes)
        chunk = self._lockbox.encrypt(
            tensor_bytes[plaintext_begin:plaintext_end],
            step_nonce_bytes,
        ).ciphertext
        cryptotext_begin = i * crypt_chunk_sz
        cryptotext_end = cryptotext_begin + len(chunk)
        cryptotext[cryptotext_begin:cryptotext_end] = chunk

    end = time.monotonic()
    # Diagnostics go through logging rather than print() so that library
    # users control whether they are emitted. The per-chunk overhead is
    # derived from the configured sizes so it is defined even when the
    # loop above never ran (zero-length tensor).
    logging.getLogger(__name__).debug(
        "Pos: %s - Size: %s - Encryption time: %.2fms, setup: %.2fms,"
        " %d bytes, %d header size",
        tensor_pos,
        tensor_size,
        (end - start) * 1000,
        (setup_end - start) * 1000,
        cryptotext_end,
        crypt_chunk_sz - clear_chunk_sz,
    )

    # The last chunk may be shorter than the stride; trim the unused tail
    # in place instead of copying the whole buffer with a slice.
    del cryptotext[cryptotext_end:]
    return nonce, cryptotext
def commit_header( crc32_future: concurrent.futures.Future, sha256_future: concurrent.futures.Future, + encrypt_future: concurrent.futures.Future, ): crc32 = crc32_future.result(3600) sha256 = sha256_future.result(3600) + nonce, encrypted = encrypt_future.result(3600) header.add_crc32(crc32) header.add_sha256(sha256) + if encrypted is not None: + header.add_xsalsa20(self._salt, nonce, self._crypt_chunk_size) + header.update_data_length(len(encrypted)) self._pwrite(header.buffer, header_pos) crc32_task = self._computation_pool.submit(compute_crc32) sha256_task = self._computation_pool.submit(compute_sha256) + encrypt_task = self._computation_pool.submit(encrypt_tensor) + commit_header_task = self._header_writer_pool.submit( - commit_header, crc32_task, sha256_task + commit_header, crc32_task, sha256_task, encrypt_task + ) + self._jobs.extend( + (encrypt_task, crc32_task, sha256_task, commit_header_task) ) - self._jobs.extend((crc32_task, sha256_task, commit_header_task)) # This task is I/O-bound and has no prerequisites, # so it goes into the regular writer pool. def write_tensor_data(): - bytes_written = self._pwrite(tensor_memory, tensor_pos) + _, encrypted = encrypt_task.result(3600) + if encrypted is not None: + self._pwrite(encrypted, tensor_pos) + else: + self._pwrite(tensor_memory, tensor_pos) with self._tensor_count_update_lock: self._file_header.tensor_count += 1 - self._file_header.tensor_size += bytes_written + self._file_header.tensor_size += tensor_memory.nbytes self._jobs.append(self._writer_pool.submit(write_tensor_data)) - tensor_endpos = tensor_pos + tensor_size + tensor_encrypted_payload = encrypt_task.result(3600)[1] + if tensor_encrypted_payload is not None: + tensor_endpos = tensor_pos + len(tensor_encrypted_payload) + else: + tensor_endpos = tensor_pos + tensor_size # Update our prologue. 
class DecryptedStream(io.RawIOBase):
    """
    This class is a file-like object that wraps a mixed stream of encrypted
    and unencrypted data. It is intended to be used when it is known that the
    next read is going to be a decryption operation.

    Data is expected as a sequence of chunks, each holding up to
    ``chunk_size`` plaintext bytes plus ``MACBYTES`` of overhead. Chunk
    number ``i`` of every ``readinto`` / ``read`` call is decrypted with the
    base nonce XORed with ``i``; the counter restarts at zero on each call.

    Args:
        stream: The underlying stream positioned at the first encrypted chunk.
        key: Key passed to ``nacl.secret.SecretBox``.
        nonce: Base nonce, interpreted as a big-endian unsigned integer.
        chunk_size: Number of plaintext bytes per encrypted chunk.
    """

    def __init__(
        self,
        stream: io.RawIOBase,
        key: bytes,
        nonce: bytes,
        chunk_size: int = 1024 << 8,
    ):
        self._stream = stream
        self._lockbox = nacl.secret.SecretBox(key)
        self._nonce_int = int.from_bytes(nonce, "big", signed=False)
        self._chunk_size = chunk_size
        self._ciphertext_chunk_sz = chunk_size + self._lockbox.MACBYTES
        # Scratch buffer reused for every chunk read from the stream.
        self._ciphertext_buffer = bytearray(self._ciphertext_chunk_sz)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        self.close()

    def tell(self) -> int:
        """Return the position of the underlying (ciphertext) stream."""
        return self._stream.tell()

    def readinto(self, ba: bytearray) -> int:
        """
        Fill ``ba`` with decrypted bytes read from the underlying stream.

        Reads ``ceil(len(ba) / chunk_size)`` encrypted chunks; the last one
        is sized for whatever plaintext remains.

        Returns:
            The number of plaintext bytes written into ``ba``.
        """
        goal = len(ba)
        if goal == 0:
            return 0

        mac_sz = self._lockbox.MACBYTES
        nonce_sz = self._lockbox.NONCE_SIZE
        buffer_view = memoryview(self._ciphertext_buffer)

        plaintext_offset = 0
        step = 0
        while plaintext_offset < goal:
            plaintext_sz = min(self._chunk_size, goal - plaintext_offset)
            ciphertext = buffer_view[: plaintext_sz + mac_sz]
            # Each chunk uses the base nonce XORed with its index.
            step_nonce_bytes = (self._nonce_int ^ step).to_bytes(
                nonce_sz, "big", signed=False
            )
            # NOTE(review): the underlying readinto() result is not checked;
            # presumably not every wrapped stream reports a byte count.
            # A truncated read would surface as a decryption failure.
            self._stream.readinto(ciphertext)
            ba[plaintext_offset : plaintext_offset + plaintext_sz] = (
                self._lockbox.decrypt(ciphertext, step_nonce_bytes)
            )
            step += 1
            plaintext_offset += plaintext_sz
        return plaintext_offset

    def read(self, size=-1) -> bytes:
        """
        Read and decrypt exactly ``size`` plaintext bytes.

        Raises:
            ValueError: If ``size`` is ``None`` or negative. The length of
                the encrypted region is not known to this wrapper, so
                "read everything" cannot be honoured.
        """
        if size is None or size < 0:
            raise ValueError(
                "DecryptedStream.read() requires an explicit,"
                " non-negative size"
            )
        buf = bytearray(size)
        bytes_read = self.readinto(buf)
        return bytes(buf[:bytes_read])

    def writable(self) -> bool:
        return False

    def fileno(self) -> int:
        return self._stream.fileno()

    def close(self):
        # We are a passive wrapper, so we don't close the underlying stream.
        pass

    # NOTE(review): io.IOBase documents `closed` as an attribute, while this
    # is a method, so `stream.closed` is always truthy here. Kept as a
    # method so existing `closed()` callers keep working; confirm intent.
    def closed(self):
        return self._stream.closed

    def readline(self, size=-1) -> bytes:
        # Delegates to the raw stream; no decryption is applied.
        return self._stream.readline(size)

    def seek(self, position, whence=SEEK_SET):
        """Seek the underlying stream and return whatever it reports."""
        return self._stream.seek(position, whence)
It is intended for @@ -307,7 +406,7 @@ def __init__( self._curr = 0 if begin is None else begin self._end = end - self.closed = False + self._closed = False def _init_vars(self): self.popen_latencies: List[float] = getattr(self, "popen_latencies", []) @@ -444,7 +543,7 @@ def _read_until( ret_buff = ba self.bytes_read += ret_buff_sz if ret_buff_sz != rq_sz: - self.closed = True + self._closed = True self._curl.terminate() raise IOError(f"Requested {rq_sz} != {ret_buff_sz}") self._curr += ret_buff_sz @@ -464,23 +563,21 @@ def readinto(self, ba: bytearray) -> int: return self._read_until(goal_position, ba) def read(self, size=None) -> bytes: - if self.closed: + if self._closed: raise IOError("CURLStreamFile closed.") if size is None: return self._curl.stdout.read() goal_position = self._curr + size return self._read_until(goal_position) - @staticmethod - def writable() -> bool: + def writable(self) -> bool: return False - @staticmethod - def fileno() -> int: + def fileno(self) -> int: return -1 def close(self): - self.closed = True + self._closed = True if self._curl is not None: if self._curl.poll() is None: self._curl.stdout.close() @@ -492,7 +589,10 @@ def close(self): self._curl.stdout.close() self._curl = None - def readline(self): + def closed(self): + return self._closed + + def readline(self, size=-1) -> bytes: raise NotImplementedError("Unimplemented") """ diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 6c9e9da..31e4cf5 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -43,6 +43,8 @@ class SerializeMethod(enum.Enum): Module = 1 StateDict = 2 + EncryptedModule = 3 + EncryptedStateDict = 4 def serialize_model( @@ -55,10 +57,16 @@ def serialize_model( out_file = tempfile.NamedTemporaryFile("wb+", delete=False) try: start_time = time.monotonic() - serializer = TensorSerializer(out_file) - if method is SerializeMethod.Module: + if method is SerializeMethod.EncryptedModule: + serializer = 
def test_encryption(self):
    """
    Serialize a model with a passphrase and check that deserializing it
    with the same passphrase reproduces the original tensors.

    An unencrypted copy is serialized first as well; every file created
    here is removed even when a later step fails.
    """
    created_paths = []
    try:
        # Serialize inside the try block so that a failure in the second
        # call cannot leak the temporary file produced by the first.
        for method in (
            SerializeMethod.Module,
            SerializeMethod.EncryptedModule,
        ):
            path, _ = serialize_model(model_name, "cpu", method=method)
            created_paths.append(path)
        encrypted_model = created_paths[-1]

        with open(encrypted_model, "rb") as in_file:
            deserialized = TensorDeserializer(
                in_file, device="cpu", passphrase="test"
            )
            try:
                check_deserialized(
                    self,
                    deserialized,
                    model_name,
                    include_non_persistent_buffers=True,
                )
            finally:
                # Close the deserializer even if the comparison fails.
                deserialized.close()
                del deserialized
    finally:
        for path in created_paths:
            os.unlink(path)