diff --git a/pyproject.toml b/pyproject.toml index d6cc581c..133fe082 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ dependencies = [ tests = [ "pytest>=3.2.1,!=3.3.0", "hypothesis>=3.27.0", + "mypy>=1.4.1", + "flake8>=5.0.4", ] docs = [ "sphinx<7", diff --git a/src/nacl/bindings/crypto_secretbox.py b/src/nacl/bindings/crypto_secretbox.py index d1ad1133..01a36268 100644 --- a/src/nacl/bindings/crypto_secretbox.py +++ b/src/nacl/bindings/crypto_secretbox.py @@ -44,6 +44,9 @@ def crypto_secretbox(message: bytes, nonce: bytes, key: bytes) -> bytes: if len(nonce) != crypto_secretbox_NONCEBYTES: raise exc.ValueError("Invalid nonce") + nonce = ffi.from_buffer(nonce) + key = ffi.from_buffer(key) + padded = b"\x00" * crypto_secretbox_ZEROBYTES + message ciphertext = ffi.new("unsigned char[]", len(padded)) @@ -72,6 +75,9 @@ def crypto_secretbox_open( if len(nonce) != crypto_secretbox_NONCEBYTES: raise exc.ValueError("Invalid nonce") + nonce = ffi.from_buffer(nonce) + key = ffi.from_buffer(key) + padded = b"\x00" * crypto_secretbox_BOXZEROBYTES + ciphertext plaintext = ffi.new("unsigned char[]", len(padded)) diff --git a/tests/test_bindings.py b/tests/test_bindings.py index a89c361c..bbcabd5a 100644 --- a/tests/test_bindings.py +++ b/tests/test_bindings.py @@ -14,8 +14,9 @@ import hashlib +import itertools from binascii import hexlify, unhexlify -from typing import List, Tuple +from typing import Callable, List, Tuple from hypothesis import given, settings from hypothesis.strategies import binary, integers @@ -94,6 +95,26 @@ def test_secretbox_easy(): ) +@pytest.mark.parametrize( + ("encoder", "decoder"), + itertools.product( + [bytes, bytearray, memoryview], + [bytes, bytearray, memoryview], + ), +) +def test_secretbox_byteslike( + encoder: Callable[[bytes], bytes], decoder: Callable[[bytes], bytes] +): + key = b"\x00" * c.crypto_secretbox_KEYBYTES + msg = b"message" + nonce = b"\x01" * c.crypto_secretbox_NONCEBYTES + ct = c.crypto_secretbox(encoder(msg), encoder(nonce), encoder(key)) + assert len(ct) == len(msg) + c.crypto_secretbox_BOXZEROBYTES + assert tohex(ct) == "3ae84dfb89728737bd6e2c8cacbaf8af3d34cc1666533a" + msg2 = c.crypto_secretbox_open(decoder(ct), decoder(nonce), decoder(key)) + assert msg2 == msg + + def test_secretbox_wrong_length(): with pytest.raises(ValueError): c.crypto_secretbox(b"", b"", b"")